# Assumes module-level `import torch`, `from tqdm import tqdm`, and the
# project's `print_in_epoch_summary` helper are in scope.
def train_epoch(self, corpus, train_data, criterion, optimizer, epoch, args):
    self.train()
    dataset_size = len(train_data.data())  # this will be approximate
    total_loss = 0.
    total_tokens = 0
    batch_idx = 0
    for batch in train_data:
        # ramp the KL weight toward 1 once past the annealing delay,
        # capped so it never overshoots
        if epoch > args.kl_anneal_delay:
            args.anneal = min(args.anneal + args.kl_anneal_rate, 1.)
        optimizer.zero_grad()
        data, targets = batch.text, batch.target
        logits, mean, logvar = self.forward(data, targets, args)
        elbo, NLL, KL, tokens = self.elbo(logits, data, criterion, mean,
                                          logvar, args, args.anneal)
        # normalize by token count so the gradient scale is per-token
        loss = elbo / tokens
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), args.clip)
        optimizer.step()
        total_loss += elbo.item()
        total_tokens += tokens
        # print if necessary
        if batch_idx % args.log_interval == 0 and batch_idx > 0:
            print_in_epoch_summary(epoch, batch_idx, args.batch_size, dataset_size,
                                   loss.item(), NLL.item() / tokens,
                                   {'normal': KL.item() / tokens}, tokens,
                                   "anneal={:.2f}".format(args.anneal))
        batch_idx += 1  # the iterator exposes no index, so count batches manually
    return total_loss / total_tokens
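# A minimal usage sketch for `train_epoch` (hypothetical driver: `model`,
# `train_iter`, `vocab`, and the constructor name are assumptions, not part
# of this module). `args.anneal` should hold the initial KL weight, e.g. 0.0,
# before the first epoch so the annealing schedule starts from it:
#
#   model = RNNVAE(args, vocab)  # hypothetical constructor
#   criterion = nn.CrossEntropyLoss(reduction='sum')
#   optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
#   for epoch in range(1, args.epochs + 1):
#       avg_elbo = model.train_epoch(corpus, train_iter, criterion,
#                                    optimizer, epoch, args)
#       print("epoch {}: ELBO/token = {:.4f}".format(epoch, avg_elbo))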
def train_loop(profile=False):
    # Closure over the enclosing training method's scope: `self`, `epoch`,
    # `args`, `train_data`, `optimizer`, `criterion`,
    # `num_importance_samples`, and `dataset_size`.
    total_loss = 0.
    total_tokens = 0
    total_resamples = 0
    batch_idx = 0
    # running totals as of the last log, for pretty-printing per-chunk stats
    last_chunk_loss = 0.
    last_chunk_tokens = 0
    last_chunk_resamples = 0
    for batch in tqdm(train_data):
        if profile and batch_idx > 10:
            print("stopping early: profiling run finished")
            break
        if epoch > args.kl_anneal_delay:
            args.anneal = min(args.anneal + args.kl_anneal_rate, 1.)
        optimizer.zero_grad()
        data, targets = batch.text, batch.target
        elbo, NLL, tokens, resamples = self.forward(
            data, targets, args, num_importance_samples, criterion)
        loss = elbo / tokens
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), args.clip)
        optimizer.step()
        total_loss += elbo.item()
        total_tokens += tokens
        total_resamples += resamples
        # print if necessary
        if batch_idx % args.log_interval == 0 and batch_idx > 0:
            chunk_loss = total_loss - last_chunk_loss
            chunk_tokens = total_tokens - last_chunk_tokens
            chunk_resamples = ((total_resamples - last_chunk_resamples)
                               / args.log_interval)
            print_in_epoch_summary(
                epoch, batch_idx, args.batch_size, dataset_size,
                loss.item(), NLL / tokens,
                {'Chunk Loss': chunk_loss / chunk_tokens,
                 'resamples': chunk_resamples},
                tokens, "anneal={:.2f}".format(args.anneal))
            last_chunk_loss = total_loss
            last_chunk_tokens = total_tokens
            last_chunk_resamples = total_resamples
        batch_idx += 1  # the iterator exposes no index, so count batches manually
    return total_loss, total_tokens
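# `train_loop` only runs correctly as a nested function, since it reads the
# names listed above from its enclosing scope. A minimal sketch of such a
# wrapper (hypothetical name, signature, and `args.profile` flag; the repo's
# actual enclosing method may differ):
#
#   def train_epoch_iw(self, train_data, criterion, optimizer, epoch, args,
#                      num_importance_samples):
#       self.train()
#       dataset_size = len(train_data.data())  # approximate, as above
#       # ... define train_loop(profile=False) here, as written above ...
#       total_loss, total_tokens = train_loop(profile=args.profile)
#       return total_loss / total_tokens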