import math

import numpy as np
import torch
from torch.nn.utils import clip_grad_norm_

# NOTE: project-local helpers (cudable, compute_param_by_scheme, itos_many) are
# assumed to be imported from the repo's own utility modules.


def loss_on_batch(self, batch):
    noiseness = compute_param_by_scheme(self.noiseness_scheme, self.num_iters_done)
    dropword_p = compute_param_by_scheme(self.decoder_dropword_scheme, self.num_iters_done)
    batch.text = cudable(batch.text)

    (means, log_stds), predictions = self.vae(batch.text, noiseness, dropword_p)

    # Predictions are one step ahead of the inputs, so targets drop the first token.
    rec_loss = self.rec_criterion(
        predictions.view(-1, len(self.vocab)),
        batch.text[:, 1:].contiguous().view(-1))
    kl_loss = self.kl_criterion(means, log_stds.exp())

    return rec_loss, kl_loss, (means, log_stds)
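# `kl_criterion` is defined elsewhere in the repo; given the call above it receives
# posterior means and stds. For a diagonal Gaussian posterior N(mu, sigma^2) against
# a standard normal prior, the closed-form KL is
# -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2). A minimal sketch of such a
# criterion (the name and batch-mean reduction are assumptions, not the repo's API):
def gaussian_kl_sketch(means, stds):
    """Mean over the batch of the analytic KL(N(mu, std^2) || N(0, I))."""
    log_vars = 2 * stds.log()
    kl_per_example = -0.5 * (1 + log_vars - means.pow(2) - log_vars.exp()).sum(dim=1)
    return kl_per_example.mean()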
def train_on_batch(self, batch):
    self.train_mode()

    # Computing losses
    desired_kl = compute_param_by_scheme(self.desired_kl_val, self.num_iters_done)
    force_kl = compute_param_by_scheme(self.force_kl, self.num_iters_done)
    rec_loss, kl_loss, (means, log_stds) = self.loss_on_batch(batch)
    # Instead of minimizing KL outright, pull it towards a scheduled target value.
    loss = rec_loss + force_kl * abs(kl_loss - desired_kl)

    # Optimization step
    self.optimizer.zero_grad()
    loss.backward()

    # Norms are computed before clipping so the raw gradient magnitude is logged.
    grad_norm = math.sqrt(sum(
        w.grad.norm().item() ** 2 for w in self.vae.parameters() if w.grad is not None))
    weights_norm = math.sqrt(sum(w.norm().item() ** 2 for w in self.vae.parameters()))
    weights_l_inf_norm = max(w.abs().max().item() for w in self.vae.parameters())

    if 'grad_clip' in self.config['hp']:
        clip_grad_norm_(self.vae.parameters(), self.config['hp']['grad_clip'])

    self.optimizer.step()

    # Logging stuff
    self.writer.add_scalar('Total loss', loss, self.num_iters_done)
    self.writer.add_scalar('CE loss', rec_loss, self.num_iters_done)
    self.writer.add_scalar('KL loss', kl_loss, self.num_iters_done)
    self.writer.add_scalar('Desired KL', desired_kl, self.num_iters_done)
    self.writer.add_scalar('Force KL', force_kl, self.num_iters_done)
    self.writer.add_scalar('Means norm', means.norm(dim=1).mean(), self.num_iters_done)
    self.writer.add_scalar('Stds norm', log_stds.exp().norm(dim=1).mean(), self.num_iters_done)
    self.writer.add_scalar('Grad norm', grad_norm, self.num_iters_done)
    self.writer.add_scalar('Weights norm', weights_norm, self.num_iters_done)
    self.writer.add_scalar('Weights l_inf norm', weights_l_inf_norm, self.num_iters_done)
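# `compute_param_by_scheme` is used throughout to anneal hyperparameters over
# training iterations. Its real definition lives elsewhere in the repo; a plausible
# minimal sketch, assuming a scheme is either a constant or a
# (start_value, end_value, num_iters) linear-annealing triple (this format is an
# assumption, not the repo's actual API):
def compute_param_by_scheme_sketch(scheme, num_iters_done):
    if not isinstance(scheme, (tuple, list)):
        return scheme  # constant schedule
    start, end, num_iters = scheme
    progress = min(num_iters_done / num_iters, 1.0)
    return start + (end - start) * progress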
def train_on_batch(self, batch):
    loss = self.loss_on_batch(batch)
    noiseness = compute_param_by_scheme(self.noiseness_scheme, self.num_iters_done)

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    self.writer.add_scalar('CE loss', loss, self.num_iters_done)
    self.writer.add_scalar('Noiseness', noiseness, self.num_iters_done)
def inference(self, dataloader):
    """Produces predictions for a given dataloader"""
    seqs = []
    originals = []
    noiseness = compute_param_by_scheme(self.noiseness_scheme, self.num_iters_done)

    for batch in dataloader:
        inputs = cudable(batch.text)
        seqs.extend(self.vae.inference(inputs, self.vocab, noiseness))
        originals.extend(inputs.detach().cpu().numpy().tolist())

    return itos_many(seqs, self.vocab), itos_many(originals, self.vocab)
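# `itos_many` converts batches of token indices back into strings via the vocab.
# A minimal sketch, assuming a torchtext-style vocab with an `itos` list and
# decoding that stops at special tokens (the token names here are assumptions):
def itos_many_sketch(seqs, vocab, sep=' '):
    stop_tokens = {'<pad>', '<eos>'}
    texts = []
    for seq in seqs:
        tokens = []
        for idx in seq:
            token = vocab.itos[idx]
            if token in stop_tokens:
                break
            tokens.append(token)
        texts.append(sep.join(tokens))
    return texts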
def compute_noise(self, size):
    noiseness = compute_param_by_scheme(self.noiseness_scheme, self.num_iters_done)
    # Standard normal samples, scaled by the current (scheduled) noiseness level.
    noise = cudable(torch.from_numpy(np.random.normal(size=size)).float())

    return noise * noiseness
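# Hypothetical usage: scheduled noise of this kind is typically added to latent
# codes before decoding, e.g. (the attribute names below are assumptions):
#
#   encodings = self.vae.encoder(batch.text)               # (batch_size, latent_dim)
#   noisy_encodings = encodings + self.compute_noise(encodings.size())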