def conditional_mle_training(model, data, valid_data=None,
                             optimizer=torch.optim.Adam,
                             optimizer_kwargs=None, eval_no_grad=True,
                             **kwargs):
    """Builds a TrainingContext for conditional maximum-likelihood training.

    Args:
        model: network to train.
        data: training data set.
        valid_data: optional validation data set; when given, a validation
            step is scheduled every `ctx.report_interval` steps.
        optimizer: optimizer class (default torch.optim.Adam).
        optimizer_kwargs: optional dict of optimizer keyword options.
        eval_no_grad: run validation without gradient tracking.
        **kwargs: remaining options forwarded to TrainingContext.

    Returns:
        The configured TrainingContext.
    """
    options = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**options.ctx)
    ctx.optimizer = optimizer
    # move network and data to the context's device
    ctx.register(data=to_device(data, ctx.device),
                 model=to_device(model, ctx.device))
    ctx.add(train_step=UpdateStep(
        partial(maximum_likelihood_step, ctx.model, ctx.data),
        Update([ctx.model], optimizer=ctx.optimizer,
               **(optimizer_kwargs or {})),
        ctx=ctx))
    if valid_data is not None:
        ctx.register(valid_data=to_device(valid_data, ctx.device))
        ctx.add(
            valid_step=EvalStep(
                partial(maximum_likelihood_step, ctx.model, ctx.valid_data),
                modules=[ctx.model], no_grad=eval_no_grad, ctx=ctx),
            every=ctx.report_interval)
    return ctx
def supervised_training(net, data, valid_data=None, losses=None,
                        optimizer=torch.optim.Adam, optimizer_kwargs=None,
                        eval_no_grad=True, **kwargs):
    """Builds a TrainingContext for plain supervised training.

    Args:
        net: network to train.
        data: training data set.
        valid_data: optional validation data set; when given, a validation
            step is scheduled every `ctx.report_interval` steps.
        losses: loss functions forwarded to the supervised step.
        optimizer: optimizer class (default torch.optim.Adam).
        optimizer_kwargs: optional dict of optimizer keyword options.
        eval_no_grad: run validation without gradient tracking.
        **kwargs: remaining options forwarded to TrainingContext.

    Returns:
        The configured TrainingContext.
    """
    options = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**options.ctx)
    ctx.optimizer = optimizer
    ctx.losses = losses
    # move network and data to the context's device
    ctx.register(data=to_device(data, ctx.device),
                 net=to_device(net, ctx.device))
    ctx.add(train_step=UpdateStep(
        partial(supervised_step, ctx.net, ctx.data, losses=ctx.losses),
        Update([ctx.net], optimizer=ctx.optimizer,
               **(optimizer_kwargs or {})),
        ctx=ctx))
    if valid_data is not None:
        ctx.register(valid_data=to_device(valid_data, ctx.device))
        ctx.add(
            valid_step=EvalStep(
                partial(supervised_step, ctx.net, ctx.valid_data,
                        losses=ctx.losses),
                modules=[ctx.net], no_grad=eval_no_grad, ctx=ctx),
            every=ctx.report_interval)
    return ctx
def decompose_batch(self, data, *args):
    """Lazily splits a batch (plus aligned extras) into per-sample CPU tuples.

    Args:
        data: either a list/tuple of per-field batches or a single
            indexable batch.
        *args: extra batched values, indexed in lockstep with `data`.

    Returns:
        Generator of per-sample tuples, each moved to the CPU.
    """
    if isinstance(data, (list, tuple)):
        # list-of-fields layout: sample i gathers element i of every field
        return (
            to_device(([field[sample] for field in data],
                       *[extra[sample] for extra in args]), "cpu")
            for sample in range(len(data[0]))
        )
    # single-tensor layout: index the batch directly
    return (
        to_device((data[sample], *[extra[sample] for extra in args]), "cpu")
        for sample in range(len(data))
    )
def sample(self):
    """Integrates prepared data under the score network (in eval mode).

    Returns:
        (result, data, *args) moved to self.device, with the integrated
        result detached from the graph.
    """
    score_net = self.score
    score_net.eval()
    prepared = to_device(self.prepare_sample(), self.device)
    data, *args = self.data_key(prepared)
    result = self.integrator.integrate(score_net, data, *args).detach()
    # restore training mode after sampling
    score_net.train()
    return to_device((result, data, *args), self.device)
def validate(self, data):
    """Runs a single validation pass without gradient tracking.

    Computes outputs, accumulates the validation loss, and invokes the
    per-validation hooks with CPU copies of the data and outputs.
    """
    with torch.no_grad():
        self.net.eval()
        net_outputs = self.run_networks(data)
        self.valid_loss(net_outputs)
        self.each_validate()
        self.valid_callback(
            self,
            to_device(data, "cpu"),
            to_device(net_outputs, "cpu"))
        # restore training mode before returning
        self.net.train()
def out_of_sample(self):
    """Collates a batch of out-of-distribution samples on the target device.

    Returns:
        (data, *args) tuple moved to self.device.
    """
    # Cleanup: removed a dead `detached` local that forced a pointless
    # device->CPU transfer of the batch, and dropped commented-out
    # integrator code that was never executed.
    samples = [self.bad_prepare() for _ in range(self.batch_size)]
    data, *args = to_device(default_collate(samples), self.device)
    return to_device((data, *args), self.device)
def sample(self):
    """Draws samples via annealed Langevin dynamics over a sigma ladder.

    Builds a geometric noise schedule sigma * factor**k for k in
    [0, n_sigma) and integrates from the prepared data.

    Returns:
        (result, data, *args) moved to self.device, result detached.
    """
    self.score.eval()
    with torch.no_grad():
        sigma_ladder = [
            self.sigma * self.factor ** level
            for level in range(self.n_sigma)
        ]
        sampler = AnnealedLangevin(sigma_ladder)
        prepared = to_device(self.prepare_sample(), self.device)
        data, *args = self.data_key(prepared)
        result = sampler.integrate(self.score, data, *args).detach()
        self.score.train()
        return to_device((result, data, *args), self.device)
def sample(self):
    """Samples a batch from the replay buffer, refines it with the
    integrator, and writes the refined per-sample tuples back into the
    buffer before returning the refined batch on self.device."""
    buffer_iter = iter(self.buffer_loader(self.buffer))
    data, *args = to_device(self.data_key(next(buffer_iter)), self.device)
    data = self.integrator.integrate(self.score, data, *args).detach()
    # second .detach() is redundant (data is already detached) but harmless
    detached = data.detach().cpu()
    # lazily decompose the batch into per-sample CPU tuples
    update = (to_device((detached[idx], *[arg[idx] for arg in args]), "cpu")
              for idx in range(data.size(0)))
    # NOTE(review): `update` is a generator; if make_differentiable consumes
    # it, buffer.update would then receive an exhausted iterator — confirm
    # make_differentiable toggles requires_grad without iterating eagerly.
    make_differentiable(update, toggle=False)
    self.buffer.update(update)
    return to_device((detached, *args), self.device)
def base_sm_training(energy, data, optimizer=torch.optim.Adam, **kwargs):
    """Builds a TrainingContext for score-matching training.

    Registers the data, the energy network, and a deep-copied target
    version of the energy network, all moved to the context device.

    Args:
        energy: energy network to train.
        data: training data set.
        optimizer: optimizer class (default torch.optim.Adam).
        **kwargs: remaining options forwarded to TrainingContext.

    Returns:
        The configured TrainingContext.
    """
    options = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**options.ctx)
    ctx.optimizer = optimizer
    # keep an independent target copy of the energy network
    target_copy = deepcopy(energy)
    ctx.register(
        data=to_device(data, ctx.device),
        energy=to_device(energy, ctx.device),
        energy_target=to_device(target_copy, ctx.device))
    return ctx
def sample(self):
    """Refines a replay-buffer batch with the integrator and refreshes the
    buffer with the refined, detached samples.

    Returns:
        (refined_data, *args) moved to self.device.
    """
    batch_iter = iter(self.buffer_loader(self.buffer))
    data, *args = to_device(self.data_key(next(batch_iter)), self.device)
    self.score.eval()
    data = detach(self.integrator.integrate(self.score, data, *args))
    self.score.train()
    cpu_data = to_device(data, "cpu")
    args = detach(args)
    # push the refined per-sample tuples back into the buffer
    buffer_update = self.decompose_batch(cpu_data, *args)
    self.buffer.update(buffer_update)
    return to_device((cpu_data, *args), self.device)
def sample(self):
    """Refines a replay-buffer batch using a freshly created score and
    refreshes the buffer with non-differentiable per-sample tuples.

    Returns:
        (refined_data, *args) moved to self.device.
    """
    batch_iter = iter(self.buffer_loader(self.buffer))
    data, *args = to_device(self.data_key(next(batch_iter)), self.device)
    self.score.eval()
    data = self.integrator.integrate(self.create_score(), data, *args).detach()
    self.score.train()
    cpu_data = to_device(data.detach(), "cpu")
    buffer_update = self.decompose_batch(cpu_data, *args)
    # ensure buffered samples do not track gradients
    make_differentiable(buffer_update, toggle=False)
    self.buffer.update(buffer_update)
    return to_device((cpu_data, *args), self.device)
def validate(self, data):
    """Runs a validation pass without gradient tracking.

    When gradient accumulation is active, validates on the first chunk of
    the batch only; otherwise validates on the whole batch.
    """
    with torch.no_grad():
        self.net.eval()
        if self.accumulate is None:
            outputs = self.run_networks(data)
        else:
            # validate on the first accumulation chunk only
            first_chunk = self.chunk(data, self.accumulate)[0]
            outputs = self.run_networks(first_chunk)
        self.valid_loss(outputs)
        self.each_validate()
        self.valid_callback(
            self,
            to_device(data, "cpu"),
            to_device(outputs, "cpu"))
        self.net.train()
def validate(self, data):
    """Few-shot validation: scores a query batch against a support set.

    Pulls one support batch from the validation support loader, moves all
    inputs to the device, and runs the network on (inputs, support,
    support_label).
    """
    with torch.no_grad():
        self.net.eval()
        inputs, *label = data
        inputs = inputs.to(self.device)
        label = [item.to(self.device) for item in label]
        support, support_label = next(self.valid_support_loader)
        # support batches arrive wrapped; take the first element
        support = support[0].to(self.device)
        support_label = support_label[0].to(self.device)
        outputs = self.run_networks(inputs, support, support_label)
        self.valid_loss(outputs)
        self.each_validate()
        self.valid_callback(
            self,
            to_device(inputs, "cpu"),
            to_device(outputs, "cpu"))
        self.net.train()
def run_energy(self, data):
    """Scores real and sampled ("fake") data with the energy network and
    maintains a running mean real energy as the integrator's target.

    NOTE(review): the statement layout was reconstructed from a collapsed
    source line; the EMA target update below is assumed to run on every
    call (after one-time initialization) — confirm against the original.
    """
    data, labels = data
    make_differentiable(data)
    input_data, *data_args = self.data_key(data)
    real_logits = self.score(input_data, *data_args)
    # sample after first pass over real data, to catch
    # possible batch-norm shenanigans without blowing up.
    fake = self.sample()
    if self.step_id % self.report_interval == 0:
        detached, *args = self.data_key(to_device(fake, "cpu"))
        self.each_generate(detached.detach(), *args)
    make_differentiable(fake)
    input_fake, *fake_args = self.data_key(fake)
    fake_logits = self.score(input_fake, *fake_args)
    real_result = self.logit_energy(real_logits, *data_args)
    fake_result = self.logit_energy(fake_logits, *fake_args)
    # set integrator target, if appropriate.
    if self.integrator.target is None:
        self.integrator.target = real_result.detach().mean(dim=0)
    # exponential moving average of the mean real energy (decay 0.6)
    self.integrator.target = 0.6 * self.integrator.target + 0.4 * real_result.detach().mean(dim=0)
    return real_result, fake_result, real_logits, labels
def step(self, data):
    """Runs one training step: posterior, energy, then generator updates.

    Args:
        data: raw batch; moved to the device and decomposed via data_key.
    """
    batch = to_device(data, self.device)
    real, *extras = self.data_key(batch)
    latent, fake = self.posterior_step(real, extras)
    self.energy_step(real, fake, extras)
    self.generator_step(latent, fake, extras)
    self.each_step()
def step(self, data):
    """One adversarial training step: critic update, then generator update.

    Returns:
        float value of the generator loss for this step.
    """
    data = to_device(data, self.device)
    data, *netargs = self.preprocess(data)

    # --- critic update ---
    self.critic_optimizer.zero_grad()
    pass_through, *critic_args = self.run_critic(data, *netargs)
    loss_val = self.critic_loss(*critic_args)
    # retain_graph: the generator pass below reuses this graph
    loss_val.backward(retain_graph=True)
    self.writer.add_scalar("critic loss", float(loss_val), self.step_id)
    self.critic_optimizer.step()

    # --- generator update ---
    self.generator_optimizer.zero_grad()
    generator_args = self.run_generator(data, *pass_through, *netargs)
    loss_val = self.generator_loss(*generator_args)
    loss_val.backward()
    self.writer.add_scalar("generator loss", float(loss_val), self.step_id)
    self.generator_optimizer.step()

    if self.verbose:
        for loss_name, loss_float in self.current_losses.items():
            self.writer.add_scalar(f"{loss_name} loss", loss_float,
                                   self.step_id)
    self.each_step()
    return float(loss_val)
def step(self, data):
    """Performs a single step of VAE training.

    Args:
        data: data points used for training.

    Returns:
        float total loss value of this step.
    """
    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    data, *netargs = self.preprocess(data)
    args = self.run_networks(data, *netargs)
    loss_val = self.loss(*args)
    if self.verbose:
        if self.step_id % self.report_interval == 0:
            self.each_generate(*args)
        for loss_name, loss_float in self.current_losses.items():
            self.writer.add_scalar(f"{loss_name} loss", loss_float,
                                   self.step_id)
        self.writer.add_scalar("total loss", float(loss_val), self.step_id)
    loss_val.backward()
    # gather every parameter of every registered network
    parameters = [
        parameter
        for network in self.get_netlist(self.network_names).values()
        for parameter in network.parameters()
    ]
    grad_norm = nn.utils.clip_grad_norm_(parameters, self.gradient_clip)
    # skip the optimizer step entirely on NaN or exploding gradients
    if (not torch.isnan(grad_norm).any()) and (grad_norm < self.gradient_skip).all():
        self.optimizer.step()
    self.each_step()
    return float(loss_val)
def train(self):
    """Trains a VAE until the maximum number of epochs is reached.

    Rebuilds the train (and optional validation) DataLoader each epoch,
    runs steps, checkpoints every `checkpoint_interval` steps, and
    validates every `report_interval` steps when validation data exists.

    Returns:
        list of the trained networks named in self.network_names.
    """
    for epoch_id in range(self.max_epochs):
        self.epoch_id = epoch_id
        # drop the previous loader (and its workers) before rebuilding
        self.train_data = None
        self.train_data = DataLoader(
            self.data, batch_size=self.batch_size, num_workers=8,
            shuffle=True)
        valid_iter = None
        if self.valid is not None:
            self.valid_data = DataLoader(
                self.valid, batch_size=self.batch_size, num_workers=8,
                shuffle=True)
            # BUG FIX: valid_iter was previously never initialized before
            # its first use, so the first validation attempt raised
            # NameError (which the StopIteration handler did not catch).
            valid_iter = iter(self.valid_data)
        for data in self.train_data:
            self.step(data)
            if self.step_id % self.checkpoint_interval == 0:
                self.checkpoint()
            if self.valid is not None and self.step_id % self.report_interval == 0:
                try:
                    vdata = next(valid_iter)
                except StopIteration:
                    # validation loader exhausted: restart it
                    valid_iter = iter(self.valid_data)
                    vdata = next(valid_iter)
                vdata = to_device(vdata, self.device)
                self.validate(vdata)
            self.step_id += 1
    netlist = [getattr(self, name) for name in self.network_names]
    return netlist
def sample(self):
    """Draws generator samples and refines them with the integrator.

    Returns:
        (generated, refined, *args) where refined is detached.
    """
    batch = next(iter(self.train_data))
    _, *conditioning = self.data_key(batch)
    conditioning = to_device(conditioning, self.device)
    _, generated = self.generator.sample(
        *conditioning, sample_shape=self.batch_size)
    refined = self.integrator.integrate(
        self.score, generated, *conditioning).detach()
    return (generated, refined, *conditioning)
def run_report(self):
    """Fetches the next validation batch and validates on it.

    Restarts the validation iterator transparently when it is exhausted.
    """
    try:
        batch = next(self.valid_iter)
    except StopIteration:
        # exhausted: start a fresh pass over the validation data
        self.valid_iter = iter(self.validate_data)
        batch = next(self.valid_iter)
    self.validate(to_device(batch, self.device))
def move_to(self, device):
    """Returns a copy of this Experience with every field moved to `device`."""
    fields = dict(
        initial_state=self.initial_state,
        final_state=self.final_state,
        action=self.action,
        reward=self.reward,
        terminal=self.terminal,
        logits=self.logits,
        outputs=self.outputs,
    )
    return Experience(**{
        name: to_device(value, device)
        for name, value in fields.items()
    })
def sample(self):
    """Samples from pure Gaussian noise using the cloned score network.

    Returns:
        (improved, *args) with the integrated result detached.
    """
    self.score_clone.eval()
    batch = next(iter(self.train_data))
    data, *conditioning = self.data_key(to_device(batch, self.device))
    # start integration from noise shaped like the data batch
    noise = torch.randn_like(data)
    improved = self.integrator.integrate(
        self.score_clone, noise, *conditioning).detach()
    self.score_clone.train()
    return (improved, *conditioning)
def init_buffer(self):
    """Fills the key buffer from the training data using the target network.

    Lazily allocates the buffer on the first batch (when its key width
    becomes known) and resets the write index.
    """
    loader_iter = iter(self.train_data)
    for _ in range(self.buffer_size // self.batch_size):
        batch = to_device(next(loader_iter), self.device)
        with torch.no_grad():
            key = self.target(batch[0]).cpu()
        if self.buffer is None:
            # allocate once the key dimensionality is known
            self.buffer = torch.zeros(self.buffer_size, key.size(1))
            self.index = 0
        self.update_buffer(key)
def decompose_batch(self, data, *args):
    """Splits a batched (data, gt, protein) triple into per-sample CPU tuples.

    Args:
        data: indexable batch of per-sample values (len == batch size).
        *args: expected to be exactly (gt, protein).

    Returns:
        list of per-sample (data, gt, protein) tuples moved to the CPU.
    """
    count = len(data)
    targets = [self.device] * count
    gt, protein = args
    # NOTE(review): torch.Tensor.chunk expects an int; passing a list of
    # device targets suggests `protein` is a project-specific packed
    # structure whose chunk() splits per-sample — confirm against its type.
    protein = protein.chunk(targets)
    result = [
        to_device((data[idx].detach(), gt[idx].detach(), protein[idx]), "cpu")
        for idx in range(count)
    ]
    return result
def train(self):
    """Runs the full training loop and returns the trained network.

    Per batch: moves data to the device, steps, logs, and advances the
    step counter. Per epoch: steps the scheduler and epoch hooks.
    """
    for _ in range(self.max_epochs):
        for batch in self.train_data:
            self.step(to_device(batch, self.device))
            self.log()
            self.step_id += 1
        self.schedule_step()
        self.each_epoch()
        self.epoch_id += 1
    return self.net
def sampler_step(self, data, idx):
    """Runs one optimization step of the sampler network.

    Args:
        data: batch to train the sampler on.
        idx: index of this sampler sub-step within the current step
            (used to spread log points across the step axis).
    """
    self.sampler_optimizer.zero_grad()
    batch = to_device(data, self.device)
    loss_val = self.sampler_loss(*self.run_sampler(batch))
    self.writer.add_scalar(
        "sampler total loss", float(loss_val),
        self.step_id * self.n_sampler + idx)
    loss_val.backward()
    self.sampler_optimizer.step()
def run_energy(self, data):
    """Computes energies for real data and integrator-refined samples.

    Returns:
        (real_result, fake_result, fake_update_result, comparison) where
        fake_result and fake_update_result are None unless
        sampler_likelihood is enabled, and comparison is None unless
        maximum_entropy is enabled.
    """
    make_differentiable(data)
    input_data, *data_args = self.data_key(data)
    real_result = self.score(input_data, *data_args)
    # sample after first pass over real data, to catch
    # possible batch-norm shenanigans without blowing up.
    fake = self.sample()
    if self.step_id % self.report_interval == 0:
        detached, *args = self.data_key(to_device(fake, "cpu"))
        self.each_generate(detach(detached), *args)
    make_differentiable(fake)
    input_fake, *fake_args = self.data_key(fake)
    self.score.eval()
    # BUG FIX: fake_result was previously only bound inside the
    # sampler_likelihood branch, raising NameError at the return statement
    # whenever sampler_likelihood was disabled.
    fake_result = None
    fake_update = None
    fake_update_result = None
    if self.sampler_likelihood:
        # Sampler log likelihood:
        fake_result = self.score(input_fake, *fake_args)
        fake_update = self.integrator.step(self.score, input_fake, *fake_args)
        self.update_target()
        fake_update_result = self.target_score(fake_update, *fake_args)
    comparison = None
    if self.maximum_entropy:
        # Sampler entropy:
        # NOTE(review): fake_update is only computed when sampler_likelihood
        # is enabled; maximum_entropy appears to assume that invariant —
        # confirm callers never enable maximum_entropy alone.
        compare_update = fake_update
        compare_target = to_device(
            self.data_key(self.buffer.sample(self.batch_size))[0],
            compare_update.device)
        # NOTE(review): checks for `embed` but calls `embedding` — confirm.
        if hasattr(self.score, "embed"):
            compare_update = self.target_score.embedding(compare_update)
            compare_target = self.target_score.embedding(compare_target)
        comparison = self.compare(compare_update, compare_target.detach())
    self.score.train()
    return real_result, fake_result, fake_update_result, comparison
def sample(self, *args, sample_shape=None):
    """Draws latent samples from the prior and decodes them.

    Args:
        *args: extra conditioning arguments forwarded to the decoder call.
        sample_shape: int, list/tuple, or torch.Size describing the leading
            sample dimensions; falsy values yield a single unbatched sample.

    Returns:
        (latents, result), both reshaped to `sample_shape` leading dims.
    """
    # preserve `or []` semantics: any falsy value means an empty shape
    shape = sample_shape or []
    if not isinstance(shape, (list, tuple, torch.Size)):
        shape = [shape]
    shape = torch.Size(shape)
    count = shape.numel()
    latents = to_device(
        self.prior.sample((count,)),
        self.log_conditional_variance.device)
    decoded = self(latents, *args)
    decoded = decoded.view(*shape, *decoded.shape[1:])
    latents = latents.view(*shape, *latents.shape[1:])
    return latents, decoded
def train(self):
    """Epoch training loop with periodic validation and checkpointing.

    Validates every `report_interval` steps (restarting the validation
    iterator when exhausted) and checkpoints every `checkpoint_interval`
    steps.

    Returns:
        the trained network.
    """
    for epoch_id in range(self.max_epochs):
        self.epoch_id = epoch_id
        validation_iter = iter(self.validate_data)
        for batch in self.train_data:
            batch = to_device(batch, self.device)
            self.step(batch)
            if self.step_id % self.report_interval == 0:
                try:
                    vbatch = next(validation_iter)
                except StopIteration:
                    # validation data exhausted: restart the iterator
                    validation_iter = iter(self.validate_data)
                    vbatch = next(validation_iter)
                self.validate(to_device(vbatch, self.device))
            if self.step_id % self.checkpoint_interval == 0:
                self.checkpoint()
            self.step_id += 1
        self.schedule_step()
        self.each_epoch()
    return self.net
def auxiliary_step(self):
    """Runs one optimization step of the auxiliary network on a buffer sample."""
    self.auxiliary_optimizer.zero_grad()
    batch = to_device(self.buffer.sample(self.batch_size), self.device)
    loss = self.auxiliary_loss(*self.run_auxiliary(batch))
    loss.backward()
    # record the scalar loss for reporting
    self.current_losses["auxiliary"] = float(loss)
    self.auxiliary_optimizer.step()