def train(data, model, opt, report_freq):
    """Perform a training epoch."""
    ravg = RunningAverage()
    model.train()
    for step, task in enumerate(data):
        y_mean, y_std = model(task['x_context'], task['y_context'],
                              task['x_target'])
        obj = -gaussian_logpdf(task['y_target'], y_mean, y_std,
                               'batched_mean')
        obj.backward()
        opt.step()
        opt.zero_grad()
        ravg.update(obj.item() / data.batch_size, data.batch_size)
        report_loss('Training', ravg.avg, step, report_freq)
    return ravg.avg
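# The objective above relies on `gaussian_logpdf`, which is not defined in
# this snippet. A minimal sketch of it, assuming the 'batched_mean' reduction
# sums the log-density over each task's target points and then averages over
# the batch:
import torch
from torch.distributions import Normal

def gaussian_logpdf(inputs, mean, sigma, reduction=None):
    """Gaussian log-density of `inputs` under N(mean, sigma^2),
    optionally reduced to a scalar."""
    logp = Normal(loc=mean, scale=sigma).log_prob(inputs)
    if not reduction:
        return logp
    elif reduction == 'sum':
        return torch.sum(logp)
    elif reduction == 'mean':
        return torch.mean(logp)
    elif reduction == 'batched_mean':
        # Sum over target points, then average over the batch dimension.
        return torch.mean(torch.sum(logp, 1))
    else:
        raise ValueError('Unknown reduction "{}".'.format(reduction))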
def validate(data, model, report_freq=None):
    """Compute the validation loss."""
    ravg = RunningAverage()
    model.eval()
    with torch.no_grad():
        for step, task in enumerate(data):
            y_mean, y_std = model(task['x_context'], task['y_context'],
                                  task['x_target'])
            obj = -gaussian_logpdf(task['y_target'], y_mean, y_std,
                                   'batched_mean')
            ravg.update(obj.item() / data.batch_size, data.batch_size)
            if report_freq:
                report_loss('Validation', ravg.avg, step, report_freq)
    return ravg.avg
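# `RunningAverage` is used by both functions above but not shown. A minimal
# sketch, assuming `update(val, n)` receives a value `val` that is an average
# over `n` items:
class RunningAverage:
    """Maintain a running average of a quantity."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count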
# Alternative version of `validate` that reports the average log-likelihood
# per target point rather than a running average of the loss.
def validate(data, model, report_freq=None):
    """Compute the average validation log-likelihood per target point."""
    model.eval()
    likelihoods = []
    with torch.no_grad():
        for step, task in enumerate(data):
            num_target = task['y_target'].shape[1]
            y_mean, y_std = model(task['x_context'], task['y_context'],
                                  task['x_target'])
            obj = gaussian_logpdf(task['y_target'], y_mean, y_std,
                                  'batched_mean')
            likelihoods.append(obj.item() / num_target)
            if report_freq:
                avg_ll = np.array(likelihoods).mean()
                report_loss('Validation', avg_ll, step, report_freq)
    avg_ll = np.array(likelihoods).mean()
    return avg_ll
# Alternative version of `train` that tracks the list of per-step losses
# instead of a `RunningAverage`.
def train(data, model, opt, report_freq):
    """Perform a training epoch."""
    model.train()
    losses = []
    for step, task in enumerate(data):
        y_mean, y_std = model(task['x_context'], task['y_context'],
                              task['x_target'])
        obj = -gaussian_logpdf(task['y_target'], y_mean, y_std,
                               'batched_mean')

        # Optimization
        obj.backward()
        opt.step()
        opt.zero_grad()

        # Track training progress
        losses.append(obj.item())
        avg_loss = np.array(losses).mean()
        report_loss('Training', avg_loss, step, report_freq)
    return avg_loss
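# `report_loss` is called both with an integer step plus a reporting frequency
# (inside the loops above) and with the string 'epoch' (in the script below).
# A hypothetical implementation consistent with both call sites:
def report_loss(name, loss, step, freq=None):
    """Print `loss` every `freq` steps, or once at the end of an epoch."""
    if step == 'epoch':
        print('{} loss (epoch): {:8.3f}'.format(name, loss))
    elif freq and (step + 1) % freq == 0:
        print('{} loss (step {}): {:8.3f}'.format(name, loss, step + 1))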
model.to(device)

# Perform training.
opt = torch.optim.Adam(model.parameters(), args.learning_rate,
                       weight_decay=args.weight_decay)
if args.train:
    # Run the training loop, maintaining the best objective value.
    best_obj = np.inf
    for epoch in range(args.epochs):
        print('\nEpoch: {}/{}'.format(epoch + 1, args.epochs))

        # Compute training objective.
        train_obj = train(gen, model, opt, report_freq=50)
        report_loss('Training', train_obj, 'epoch')

        # Compute validation objective.
        val_obj = validate(gen_val, model, report_freq=20)
        report_loss('Validation', val_obj, 'epoch')

        # Update the best objective value and checkpoint the model.
        is_best = False
        if val_obj < best_obj:
            best_obj = val_obj
            is_best = True
        save_checkpoint(wd,
                        {'epoch': epoch + 1,
                         'state_dict': model.state_dict(),
                         'best_acc_top1': best_obj,
                         'optimizer': opt.state_dict()},
                        is_best=is_best)
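# `save_checkpoint` is not shown either. A minimal sketch, assuming `wd` is a
# path to a working directory: the latest state is saved every epoch, and a
# separate copy is kept whenever the validation objective improves.
import os
import shutil
import torch

def save_checkpoint(wd, state, is_best):
    """Checkpoint the training state, keeping a copy of the best model."""
    latest = os.path.join(wd, 'checkpoint.pth.tar')
    torch.save(state, latest)
    if is_best:
        shutil.copyfile(latest, os.path.join(wd, 'model_best.pth.tar'))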