Example #1
def conditional_mle_training(model,
                             data,
                             valid_data=None,
                             optimizer=torch.optim.Adam,
                             optimizer_kwargs=None,
                             eval_no_grad=True,
                             **kwargs):
    opt = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**opt.ctx)
    ctx.optimizer = optimizer

    # networks to device
    ctx.register(data=to_device(data, ctx.device),
                 model=to_device(model, ctx.device))

    ctx.add(train_step=UpdateStep(
        partial(maximum_likelihood_step, ctx.model, ctx.data),
        Update(
            [ctx.model], optimizer=ctx.optimizer, **(optimizer_kwargs or {})),
        ctx=ctx))

    if valid_data is not None:
        ctx.register(valid_data=to_device(valid_data, ctx.device))
        ctx.add(valid_step=EvalStep(partial(maximum_likelihood_step, ctx.model,
                                            ctx.valid_data),
                                    modules=[ctx.model],
                                    no_grad=eval_no_grad,
                                    ctx=ctx),
                every=ctx.report_interval)
    return ctx
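Every example on this page routes data and networks through to_device, which in torchsupport-style code recursively moves tensors, modules, and nested containers onto a target device. A minimal sketch of such a helper, assuming plain PyTorch and standard containers only (not the library's actual implementation):

import torch

def to_device_sketch(obj, device):
    # Tensors and modules expose .to(device) directly.
    if isinstance(obj, (torch.Tensor, torch.nn.Module)):
        return obj.to(device)
    # Recurse into common containers, preserving their type.
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device_sketch(item, device) for item in obj)
    if isinstance(obj, dict):
        return {key: to_device_sketch(val, device) for key, val in obj.items()}
    # Anything else (ints, strings, None) passes through unchanged.
    return obj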
Example #2
def supervised_training(net,
                        data,
                        valid_data=None,
                        losses=None,
                        optimizer=torch.optim.Adam,
                        optimizer_kwargs=None,
                        eval_no_grad=True,
                        **kwargs):
    opt = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**opt.ctx)
    ctx.optimizer = optimizer
    ctx.losses = losses

    # networks to device
    ctx.register(data=to_device(data, ctx.device),
                 net=to_device(net, ctx.device))

    ctx.add(train_step=UpdateStep(
        partial(supervised_step, ctx.net, ctx.data, losses=ctx.losses),
        Update([ctx.net], optimizer=ctx.optimizer, **(optimizer_kwargs or {})),
        ctx=ctx))

    if valid_data is not None:
        ctx.register(valid_data=to_device(valid_data, ctx.device))
        ctx.add(valid_step=EvalStep(partial(supervised_step,
                                            ctx.net,
                                            ctx.valid_data,
                                            losses=ctx.losses),
                                    modules=[ctx.net],
                                    no_grad=eval_no_grad,
                                    ctx=ctx),
                every=ctx.report_interval)
    return ctx
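A hypothetical invocation of the helpers above (net, train_set, and valid_set are assumed to exist, and the device keyword is an assumption based on extra kwargs being filtered into TrainingContext, which the helpers read back as ctx.device):

import torch.nn as nn

ctx = supervised_training(net, train_set,
                          valid_data=valid_set,
                          losses=[nn.CrossEntropyLoss()],
                          optimizer_kwargs={"lr": 1e-4},
                          device="cuda:0")

The returned TrainingContext then drives the registered train and validation steps.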
Example #3
 def decompose_batch(self, data, *args):
     if isinstance(data, (list, tuple)):
         return (to_device(([item[idx]
                             for item in data], *[arg[idx]
                                                  for arg in args]), "cpu")
                 for idx in range(len(data[0])))
     return (to_device((data[idx], *[arg[idx] for arg in args]), "cpu")
             for idx in range(len(data)))
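decompose_batch yields one CPU-resident tuple per sample, which is convenient for pushing batch elements into a replay buffer one at a time. An illustrative use, with sampler and buffer as hypothetical stand-ins:

batch = torch.randn(4, 3, 32, 32)   # four samples
extra = torch.arange(4)             # one per-sample argument
for sample, arg in sampler.decompose_batch(batch, extra):
    buffer.append((sample, arg))    # every item already lives on the CPU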
Example #4
 def sample(self):
     scr = self.score
     self.score.eval()
     integrator = self.integrator
     prep = to_device(self.prepare_sample(), self.device)
     data, *args = self.data_key(prep)
     result = integrator.integrate(scr, data, *args).detach()
     self.score.train()
     return to_device((result, data, *args), self.device)
Example #5
 def validate(self, data):
     with torch.no_grad():
         self.net.eval()
         outputs = self.run_networks(data)
         self.valid_loss(outputs)
         self.each_validate()
         self.valid_callback(self, to_device(data, "cpu"),
                             to_device(outputs, "cpu"))
         self.net.train()
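This validate (and the similar variants in later examples) toggles eval() and train() by hand, so an exception inside run_networks would leave the network stuck in eval mode. A more defensive variant wraps the switch in a small context manager; a sketch, not taken from the source:

from contextlib import contextmanager
import torch

@contextmanager
def evaluating(net):
    # switch to eval mode and disable autograd, restoring
    # training mode even if validation raises
    net.eval()
    try:
        with torch.no_grad():
            yield net
    finally:
        net.train()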
Example #6
    def out_of_sample(self):
        result = []
        for idx in range(self.batch_size):
            sample = self.bad_prepare()
            result.append(sample)
        data, *args = to_device(default_collate(result), self.device)
        # integration step left disabled in this example:
        # data = self.integrator.integrate(self.score, data, *args).detach()
        return to_device((data, *args), self.device)
Example #7
 def sample(self):
     self.score.eval()
     with torch.no_grad():
         integrator = AnnealedLangevin(
             [self.sigma * self.factor**idx for idx in range(self.n_sigma)])
         prep = to_device(self.prepare_sample(), self.device)
         data, *args = self.data_key(prep)
         result = integrator.integrate(self.score, data, *args).detach()
     self.score.train()
     return to_device((result, data, *args), self.device)
Example #8
    def sample(self):
        buffer_iter = iter(self.buffer_loader(self.buffer))
        data, *args = to_device(self.data_key(next(buffer_iter)), self.device)
        data = self.integrator.integrate(self.score, data, *args).detach()
        detached = data.detach().cpu()
        update = (to_device((detached[idx], *[arg[idx]
                                              for arg in args]), "cpu")
                  for idx in range(data.size(0)))
        make_differentiable(update, toggle=False)
        self.buffer.update(update)

        return to_device((detached, *args), self.device)
Example #9
def base_sm_training(energy, data, optimizer=torch.optim.Adam, **kwargs):
    opt = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**opt.ctx)
    ctx.optimizer = optimizer

    # networks to device
    energy_target = deepcopy(energy)
    ctx.register(data=to_device(data, ctx.device),
                 energy=to_device(energy, ctx.device),
                 energy_target=to_device(energy_target, ctx.device))

    return ctx
Example #10
    def sample(self):
        buffer_iter = iter(self.buffer_loader(self.buffer))
        data, *args = to_device(self.data_key(next(buffer_iter)), self.device)
        self.score.eval()
        data = detach(self.integrator.integrate(self.score, data, *args))
        self.score.train()
        detached = to_device(data, "cpu")
        args = detach(args)
        update = self.decompose_batch(detached, *args)
        self.buffer.update(update)

        return to_device((detached, *args), self.device)
Example #11
    def sample(self):
        buffer_iter = iter(self.buffer_loader(self.buffer))
        data, *args = to_device(self.data_key(next(buffer_iter)), self.device)
        self.score.eval()
        data = self.integrator.integrate(self.create_score(), data,
                                         *args).detach()
        self.score.train()
        detached = to_device(data.detach(), "cpu")
        update = self.decompose_batch(detached, *args)
        make_differentiable(update, toggle=False)
        self.buffer.update(update)

        return to_device((detached, *args), self.device)
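Examples #8, #10, and #11 share the persistent-chain pattern used when training energy-based models: draw earlier samples from a replay buffer, refine them with a few integrator (e.g. Langevin) steps, and write the refined samples back so later iterations continue the same chains. A stripped-down sketch with all names hypothetical:

def persistent_sample(buffer, buffer_loader, score, integrator, device):
    # draw a previous batch of chain states from the replay buffer
    data, *args = to_device(next(iter(buffer_loader(buffer))), device)
    # refine them with a few MCMC steps; gradients are not needed
    refined = integrator.integrate(score, data, *args).detach()
    # push the refined states back, one CPU-resident sample at a time
    buffer.update((to_device((refined[idx], *[a[idx] for a in args]), "cpu")
                   for idx in range(refined.size(0))))
    return to_device((refined, *args), device)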
Example #12
 def validate(self, data):
     with torch.no_grad():
         self.net.eval()
         if self.accumulate is not None:
             point = self.chunk(data, self.accumulate)[0]
             outputs = self.run_networks(point)
         else:
             outputs = self.run_networks(data)
         self.valid_loss(outputs)
         self.each_validate()
         self.valid_callback(self, to_device(data, "cpu"),
                             to_device(outputs, "cpu"))
         self.net.train()
Example #13
 def validate(self, data):
   with torch.no_grad():
     self.net.eval()
     inputs, *label = data
      inputs = inputs.to(self.device)
      label = [item.to(self.device) for item in label]
     support, support_label = next(self.valid_support_loader)
     support = support[0].to(self.device)
     support_label = support_label[0].to(self.device)
     outputs = self.run_networks(inputs, support, support_label)
     self.valid_loss(outputs)
     self.each_validate()
     self.valid_callback(self, to_device(inputs, "cpu"), to_device(outputs, "cpu"))
     self.net.train()
Example #14
    def run_energy(self, data):
        data, labels = data

        make_differentiable(data)
        input_data, *data_args = self.data_key(data)
        real_logits = self.score(input_data, *data_args)

        # sample after first pass over real data, to catch
        # possible batch-norm shenanigans without blowing up.
        fake = self.sample()

        if self.step_id % self.report_interval == 0:
            detached, *args = self.data_key(to_device(fake, "cpu"))
            self.each_generate(detached.detach(), *args)

        make_differentiable(fake)
        input_fake, *fake_args = self.data_key(fake)
        fake_logits = self.score(input_fake, *fake_args)

        real_result = self.logit_energy(real_logits, *data_args)
        fake_result = self.logit_energy(fake_logits, *fake_args)

        # set the integrator target on first use, then track it as an
        # exponential moving average of the mean real energy.
        if self.integrator.target is None:
            self.integrator.target = real_result.detach().mean(dim=0)
        self.integrator.target = (0.6 * self.integrator.target
                                  + 0.4 * real_result.detach().mean(dim=0))
        return real_result, fake_result, real_logits, labels
Example #15
 def step(self, data):
     data = to_device(data, self.device)
     real, *args = self.data_key(data)
     latent, fake = self.posterior_step(real, args)
     self.energy_step(real, fake, args)
     self.generator_step(latent, fake, args)
     self.each_step()
Example #16
    def step(self, data):
        data = to_device(data, self.device)
        data, *netargs = self.preprocess(data)

        self.critic_optimizer.zero_grad()
        pass_through, *critic_args = self.run_critic(data, *netargs)
        loss_val = self.critic_loss(*critic_args)
        loss_val.backward(retain_graph=True)
        self.writer.add_scalar("critic loss", float(loss_val), self.step_id)
        self.critic_optimizer.step()

        self.generator_optimizer.zero_grad()
        generator_args = self.run_generator(data, *pass_through, *netargs)
        loss_val = self.generator_loss(*generator_args)
        loss_val.backward()
        self.writer.add_scalar("generator loss", float(loss_val), self.step_id)
        self.generator_optimizer.step()

        if self.verbose:
            for loss_name in self.current_losses:
                loss_float = self.current_losses[loss_name]
                self.writer.add_scalar(f"{loss_name} loss", loss_float,
                                       self.step_id)

        self.each_step()

        return float(loss_val)
Example #17
  def step(self, data):
    """Performs a single step of VAE training.

    Args:
      data: data points used for training."""
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    data, *netargs = self.preprocess(data)
    args = self.run_networks(data, *netargs)

    loss_val = self.loss(*args)

    if self.verbose:
      if self.step_id % self.report_interval == 0:
        self.each_generate(*args)
      for loss_name in self.current_losses:
        loss_float = self.current_losses[loss_name]
        self.writer.add_scalar(f"{loss_name} loss", loss_float, self.step_id)
    self.writer.add_scalar("total loss", float(loss_val), self.step_id)

    loss_val.backward()
    parameters = [
      param
      for val in self.get_netlist(self.network_names).values()
      for param in val.parameters()
    ]
    gn = nn.utils.clip_grad_norm_(parameters, self.gradient_clip)
    if (not torch.isnan(gn).any()) and (gn < self.gradient_skip).all():
      self.optimizer.step()
    self.each_step()

    return float(loss_val)
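The gradient guard at the end of this step is worth noting: the global norm is clipped first, and the optimizer step is skipped outright whenever the pre-clip norm is NaN or exceeds gradient_skip, so a single pathological batch cannot derail training. Condensed, with hypothetical names:

grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=gradient_clip)
if torch.isfinite(grad_norm) and grad_norm < gradient_skip:
    optimizer.step()   # update only on well-behaved gradients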
Example #18
    def train(self):
        """Trains a VAE until the maximum number of epochs is reached."""
        for epoch_id in range(self.max_epochs):
            self.epoch_id = epoch_id
            self.train_data = None
            self.train_data = DataLoader(self.data,
                                         batch_size=self.batch_size,
                                         num_workers=8,
                                         shuffle=True)
            if self.valid is not None:
                self.valid_data = DataLoader(self.valid,
                                             batch_size=self.batch_size,
                                             num_workers=8,
                                             shuffle=True)
                # initialize the validation iterator so the StopIteration
                # handler below can re-create and cycle it
                valid_iter = iter(self.valid_data)
            for data in self.train_data:
                self.step(data)
                if self.step_id % self.checkpoint_interval == 0:
                    self.checkpoint()
                if self.valid is not None and self.step_id % self.report_interval == 0:
                    vdata = None
                    try:
                        vdata = next(valid_iter)
                    except StopIteration:
                        valid_iter = iter(self.valid_data)
                        vdata = next(valid_iter)
                    vdata = to_device(vdata, self.device)
                    self.validate(vdata)
                self.step_id += 1

        netlist = [getattr(self, name) for name in self.network_names]

        return netlist
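The try/except StopIteration block in this and later train loops implements an endlessly cycling validation iterator. The same idea can be factored into a tiny generator; unlike itertools.cycle, which caches the first pass, this variant re-iterates the loader and so re-shuffles on every pass (a sketch, not from the source):

def cycle(loader):
    # restart the loader whenever it is exhausted
    while True:
        for batch in loader:
            yield batch

valid_iter = cycle(valid_loader)
vdata = next(valid_iter)   # never raises StopIteration for a non-empty loader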
Example #19
 def sample(self):
     batch = next(iter(self.train_data))
     _, *args = self.data_key(batch)
     args = to_device(args, self.device)
     _, fake = self.generator.sample(*args, sample_shape=self.batch_size)
     improved = self.integrator.integrate(self.score, fake, *args).detach()
     return (fake, improved, *args)
Example #20
 def run_report(self):
     vdata = None
     try:
         vdata = next(self.valid_iter)
     except StopIteration:
         self.valid_iter = iter(self.validate_data)
         vdata = next(self.valid_iter)
     vdata = to_device(vdata, self.device)
     self.validate(vdata)
Example #21
 def move_to(self, device):
     return Experience(initial_state=to_device(self.initial_state, device),
                       final_state=to_device(self.final_state, device),
                       action=to_device(self.action, device),
                       reward=to_device(self.reward, device),
                       terminal=to_device(self.terminal, device),
                       logits=to_device(self.logits, device),
                       outputs=to_device(self.outputs, device))
Example #22
    def sample(self):
        self.score_clone.eval()
        batch = next(iter(self.train_data))
        data, *args = self.data_key(to_device(batch, self.device))
        data = torch.randn_like(data)

        improved = self.integrator.integrate(self.score_clone, data,
                                             *args).detach()
        self.score_clone.train()
        return (improved, *args)
Example #23
 def init_buffer(self):
     diter = iter(self.train_data)
     for _ in range(self.buffer_size // self.batch_size):
         data = to_device(next(diter), self.device)
         with torch.no_grad():
             key = self.target(data[0]).cpu()
         if self.buffer is None:
             self.buffer = torch.zeros(self.buffer_size, key.size(1))
             self.index = 0
         self.update_buffer(key)
Example #24
    def decompose_batch(self, data, *args):
        count = len(data)
        targets = [self.device] * count
        gt, protein = args
        protein = protein.chunk(targets)
        result = [
            to_device((data[idx].detach(), gt[idx].detach(), protein[idx]),
                      "cpu") for idx in range(count)
        ]

        return result
Example #25
 def train(self):
   for epoch_id in range(self.max_epochs):
     for data in self.train_data:
       data = to_device(data, self.device)
       self.step(data)
       self.log()
       self.step_id += 1
     self.schedule_step()
     self.each_epoch()
     self.epoch_id += 1
   return self.net
Example #26
    def sampler_step(self, data, idx):
        self.sampler_optimizer.zero_grad()
        data = to_device(data, self.device)
        args = self.run_sampler(data)
        loss_val = self.sampler_loss(*args)

        self.writer.add_scalar("sampler total loss", float(loss_val),
                               self.step_id * self.n_sampler + idx)

        loss_val.backward()
        self.sampler_optimizer.step()
Example #27
    def run_energy(self, data):
        make_differentiable(data)
        input_data, *data_args = self.data_key(data)
        real_result = self.score(input_data, *data_args)

        # sample after first pass over real data, to catch
        # possible batch-norm shenanigans without blowing up.
        fake = self.sample()

        if self.step_id % self.report_interval == 0:
            detached, *args = self.data_key(to_device(fake, "cpu"))
            self.each_generate(detach(detached), *args)

        make_differentiable(fake)
        input_fake, *fake_args = self.data_key(fake)

        self.score.eval()
        fake_result = None
        fake_update = None
        fake_update_result = None
        if self.sampler_likelihood:
            # Sampler log likelihood:
            fake_result = self.score(input_fake, *fake_args)
            fake_update = self.integrator.step(self.score, input_fake,
                                               *fake_args)
            self.update_target()
            fake_update_result = self.target_score(fake_update, *fake_args)

        comparison = None
        if self.maximum_entropy:
            # Sampler entropy: assumes the likelihood branch above ran
            # and produced fake_update.
            compare_update = fake_update
            compare_target = to_device(
                self.data_key(self.buffer.sample(self.batch_size))[0],
                compare_update.device)
            if hasattr(self.score, "embed"):
                compare_update = self.target_score.embedding(compare_update)
                compare_target = self.target_score.embedding(compare_target)
            comparison = self.compare(compare_update, compare_target.detach())
        self.score.train()
        return real_result, fake_result, fake_update_result, comparison
Example #28
    def sample(self, *args, sample_shape=None):
        sample_shape = sample_shape or []
        if not isinstance(sample_shape, (list, tuple, torch.Size)):
            sample_shape = [sample_shape]
        sample_shape = torch.Size(sample_shape)
        total = sample_shape.numel()

        latents = to_device(self.prior.sample((total, )),
                            self.log_conditional_variance.device)
        result = self(latents, *args)
        result = result.view(*sample_shape, *result.shape[1:])
        latents = latents.view(*sample_shape, *latents.shape[1:])
        return latents, result
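sample normalizes sample_shape much like torch.distributions does: a bare integer is promoted to a one-element shape, and the leading dimensions of both returned tensors mirror the requested shape. Hypothetical calls (decoder is assumed):

latents, result = decoder.sample(sample_shape=8)       # shapes (8, ...)
latents, result = decoder.sample(sample_shape=(2, 3))  # shapes (2, 3, ...)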
Example #29
 def train(self):
     for epoch_id in range(self.max_epochs):
         self.epoch_id = epoch_id
         valid_iter = iter(self.validate_data)
         for data in self.train_data:
             data = to_device(data, self.device)
             self.step(data)
             if self.step_id % self.report_interval == 0:
                 vdata = None
                 try:
                     vdata = next(valid_iter)
                 except StopIteration:
                     valid_iter = iter(self.validate_data)
                     vdata = next(valid_iter)
                 vdata = to_device(vdata, self.device)
                 self.validate(vdata)
             if self.step_id % self.checkpoint_interval == 0:
                 self.checkpoint()
             self.step_id += 1
         self.schedule_step()
         self.each_epoch()
     return self.net
Example #30
    def auxiliary_step(self):
        self.auxiliary_optimizer.zero_grad()

        data = self.buffer.sample(self.batch_size)
        data = to_device(data, self.device)

        args = self.run_auxiliary(data)
        loss = self.auxiliary_loss(*args)
        loss.backward()

        self.current_losses["auxiliary"] = float(loss)

        self.auxiliary_optimizer.step()