# Shared imports assumed by the training loops below; `util`/`utils` and the
# model-specific helpers (RealNVPLoss, EvaluationModel, sample, ...) are
# project-local modules.
import copy
import os
import time

import numpy as np
import torch
import wandb
from torchvision.utils import save_image
from tqdm import tqdm

import util
import utils


def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm):
    """One epoch of flow training: minimize the NLL of (z, sum-log-det-Jacobian)."""
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            x = x.to(device)
            optimizer.zero_grad()
            z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            scheduler.step(global_step)
            global_step += x.size(0)

def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm):
    """One epoch of conditional flow training; saves a checkpoint at the end."""
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, cond_x in trainloader:
            x, cond_x = x.to(device), cond_x.to(device)
            optimizer.zero_grad()
            z, sldj = net(x, cond_x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)
    print('Saving...')
    state = {
        'net': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(state, 'savemodel/cINN/checkpoint_' + str(epoch) + '.tar')

def plugin_estimator_training_loop(real_nvp_model, dataloader, learning_rate,
                                   optim, device, total_iters,
                                   checkpoint_intervals, batchsize, algorithm,
                                   weight_decay, max_grad_norm, save_dir,
                                   save_suffix):
    """Train the RealNVP model for total_iters iterations with learning_rate,
    aggregating per-sample gradients with a plugin mean-estimation algorithm.
    """
    param_groups = util.get_param_groups(real_nvp_model, weight_decay,
                                         norm_suffix='weight_g')
    optimizer_cons = utils.get_optimizer_cons(optim, learning_rate)
    optimizer = optimizer_cons(param_groups)
    loss_fn = RealNVPLoss()
    flag = False
    iteration = 0
    real_nvp_model.train()
    while not flag:
        for x, _ in dataloader:
            # Update iteration counter
            iteration += 1
            x = x.to(device)
            optimizer.zero_grad()  # avoid accumulating gradients across iterations
            z, sldj = real_nvp_model(x, reverse=False)
            unaggregated_loss = loss_fn(z, sldj, aggregate=False)
            if algorithm.__name__ == 'mean':
                agg_loss = unaggregated_loss.mean()
                agg_loss.backward()
            else:
                # First sample per-example gradients
                sgradients = utils.gradient_sampler(unaggregated_loss,
                                                    real_nvp_model)
                # Then get the estimate with the mean-estimation algorithm
                stoc_grad = algorithm(sgradients)
                # Write the estimate into the .grad attributes
                with torch.no_grad():
                    utils.update_grad_attributes(
                        real_nvp_model.parameters(),
                        torch.as_tensor(stoc_grad, device=device))
            # Clip gradients if required
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            # Perform the update
            optimizer.step()
            if iteration in checkpoint_intervals:
                print(f"Completed {iteration}")
                torch.save(
                    real_nvp_model.state_dict(),
                    f"{save_dir}/real_nvp_{algorithm.__name__}_{iteration}_{save_suffix}.pt"
                )
            if iteration == total_iters:
                flag = True
                break
    return real_nvp_model

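# Hypothetical usage sketch for plugin_estimator_training_loop. The estimator
# below and every argument value are illustrative assumptions, not part of this
# project: `algorithm` can be any callable mapping a (batch, n_params) array of
# sampled per-example gradients to a single gradient estimate (its __name__ is
# also used in the checkpoint filename).
def coordinate_wise_median(sgradients):
    # Robust plugin mean estimate: per-coordinate median across the batch.
    return torch.median(torch.as_tensor(sgradients), dim=0).values

# model = plugin_estimator_training_loop(
#     real_nvp_model=model, dataloader=trainloader, learning_rate=1e-3,
#     optim='adam', device=device, total_iters=10000,
#     checkpoint_intervals={1000, 5000, 10000}, batchsize=64,
#     algorithm=coordinate_wise_median, weight_decay=5e-5,
#     max_grad_norm=100., save_dir='checkpoints', save_suffix='run0')
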
def train_single_step(net, x, device, optimizer, loss_fn, max_grad_norm):
    """Run a single optimizer step on one batch x."""
    net.train()
    x = x.to(device)
    optimizer.zero_grad()
    z, sldj = net(x, reverse=False)
    loss = loss_fn(z, sldj)
    loss.backward()
    if max_grad_norm > 0:
        util.clip_grad_norm(optimizer, max_grad_norm)
    optimizer.step()

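# Minimal usage sketch for train_single_step; the loader and hyperparameter
# values below are placeholders:
# for x, _ in trainloader:
#     train_single_step(net, x, device, optimizer, loss_fn, max_grad_norm=100.)
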
def train(epoch, net, trainloader, device, optimizer, loss_fn, max_grad_norm):
    """One epoch of flow training without a learning-rate scheduler."""
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            x = x.to(device)
            optimizer.zero_grad()
            z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))

def train_iter(self):
    """Run a training iteration (forward/backward) on a single batch.

    Important: Call `set_inputs` prior to each call to this function.
    """
    # Forward
    self.forward()

    # Backprop the generators
    self.opt_g.zero_grad()
    self.backward_g()
    util.clip_grad_norm(self.opt_g, self.max_grad_norm)
    self.opt_g.step()

    # Backprop the discriminators
    self.opt_d.zero_grad()
    self.backward_d()
    util.clip_grad_norm(self.opt_d, self.max_grad_norm)
    self.opt_d.step()

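# Expected call pattern for train_iter (the model object and loader are
# assumed to exist elsewhere; `set_inputs` must precede every call):
# for batch in loader:
#     model.set_inputs(batch)
#     model.train_iter()
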
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm):
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    correct_class = 0
    correct_domain = 0
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, y, d, yd in trainloader:
            x, y, d, yd = (x.to(device), y.to(device), d.to(device),
                           yd.to(device))
            optimizer.zero_grad()
            z1, z2 = net(x)
            # Joint objective: class head (z1) and domain head (z2)
            loss2 = loss_fn(z2, d.argmax(dim=1))
            loss1 = loss_fn(z1, y.argmax(dim=1))
            loss_meter.update(loss2.item(), x.size(0))
            loss_meter.update(loss1.item(), x.size(0))
            loss = loss1 + loss2
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)
            values_class, pred_class = torch.max(z1, 1)
            values_domain, pred_domain = torch.max(z2, 1)
            correct_class += pred_class.eq(y.argmax(dim=1)).sum().item()
            correct_domain += pred_domain.eq(d.argmax(dim=1)).sum().item()
    accuracy_class = correct_class * 100. / len(trainloader.dataset)
    accuracy_domain = correct_domain * 100. / len(trainloader.dataset)
    print('train accuracy class', accuracy_class)
    print('train accuracy domain', accuracy_domain)
    return accuracy_class, accuracy_domain

def train(epoch, net, trainloader, device, optimizer, loss_fn, max_grad_norm,
          base_path, save=False):
    """One epoch of VAE training; optionally saves the per-batch logvar and
    output_var tensors for later inspection."""
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meters = [util.AverageMeter() for _ in range(3)]
    logvars = []
    output_vars = []
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x in trainloader:
            # Some loaders yield (input, target) pairs; keep only the input.
            if isinstance(x, list) and len(x) == 2:
                x = x[0]
            x = x.to(device)
            optimizer.zero_grad()
            x_hat, mu, logvar, output_var = net(x)
            loss, reconstruction_loss, kl_loss = loss_fn(
                x, x_hat, mu, logvar, output_var)
            loss_meters[0].update(loss.item(), x.size(0))
            loss_meters[1].update(reconstruction_loss.item(), x.size(0))
            loss_meters[2].update(kl_loss.item(), x.size(0))
            loss.backward()
            util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            progress_bar.set_postfix(loss=loss_meters[0].avg,
                                     rc_loss=loss_meters[1].avg,
                                     kl_loss=loss_meters[2].avg)
            progress_bar.update(x.size(0))
            logvars.append(logvar.unsqueeze(0))
            output_vars.append(output_var.unsqueeze(0))
    if save:
        logvarfile = 'logvar' + str(epoch) + '.pt'
        output_varfile = 'output_var' + str(epoch) + '.pt'
        torch.save(logvars, base_path / logvarfile)
        torch.save(output_vars, base_path / output_varfile)

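# The loss_fn above is expected to return (total, reconstruction, KL) terms.
# A minimal sketch of one compatible definition -- a Gaussian-likelihood VAE
# loss with a learned output variance; this is an assumption for illustration,
# not the project's actual loss:
def vae_gaussian_loss(x, x_hat, mu, logvar, output_var):
    # Gaussian negative log-likelihood of x under N(x_hat, output_var),
    # averaged over the batch
    reconstruction = 0.5 * (((x - x_hat) ** 2) / output_var
                            + torch.log(output_var)).sum() / x.size(0)
    # KL(q(z|x) || N(0, I)) in closed form, averaged over the batch
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum() / x.size(0)
    return reconstruction + kl, reconstruction, kl
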
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm, mode):
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    correct = 0
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, y, d, yd in trainloader:
            x, y, d, yd = (x.to(device), y.to(device), d.to(device),
                           yd.to(device))
            optimizer.zero_grad()
            z = net(x)
            # Train against domain labels d or class labels y, per `mode`
            if mode == 'domain':
                loss = loss_fn(z, d.argmax(dim=1))
            else:
                loss = loss_fn(z, y.argmax(dim=1))
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)
            values, pred = torch.max(z, 1)
            if mode == 'label':
                correct += pred.eq(y.argmax(dim=1)).sum().item()
            else:
                correct += pred.eq(d.argmax(dim=1)).sum().item()
    accuracy = correct * 100. / len(trainloader.dataset)
    print('train accuracy', accuracy)
    return accuracy

def streaming_approx_training_loop(real_nvp_model, dataloader, learning_rate,
                                   optim, device, total_iters,
                                   checkpoint_intervals, alpha, batchsize,
                                   n_discard, weight_decay, max_grad_norm,
                                   save_dir, save_suffix):
    """Train the RealNVP model for total_iters iterations with optimizer
    optim, estimating gradients with the streaming rank-1 approximation
    algorithm."""
    param_groups = util.get_param_groups(real_nvp_model, weight_decay,
                                         norm_suffix='weight_g')
    optimizer_cons = utils.get_optimizer_cons(optim, learning_rate)
    optimizer = optimizer_cons(param_groups)
    loss_fn = RealNVPLoss()
    flag = False
    iteration = 0
    top_eigvec, top_eigval, running_mean = None, None, None
    real_nvp_model.train()
    while not flag:
        for x, _ in dataloader:
            # Update iteration counter
            iteration += 1
            x = x.to(device)
            z, sldj = real_nvp_model(x, reverse=False)
            unaggregated_loss = loss_fn(z, sldj, aggregate=False)
            # First sample per-example gradients
            sgradients = utils.gradient_sampler(unaggregated_loss,
                                                real_nvp_model)
            # Then get the estimate with the previously computed direction
            stoc_grad, top_eigvec, top_eigval, running_mean = \
                streaming_update_algorithm(sgradients,
                                           n_discard=n_discard,
                                           top_v=top_eigvec,
                                           top_lambda=top_eigval,
                                           old_mean=running_mean,
                                           alpha=alpha)
            # Write the estimate into the .grad attributes
            with torch.no_grad():
                utils.update_grad_attributes(
                    real_nvp_model.parameters(),
                    torch.as_tensor(stoc_grad, device=device))
            # Clip gradients if required
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            # Perform the update
            optimizer.step()
            if iteration in checkpoint_intervals:
                print(f"Completed {iteration}")
                torch.save(
                    real_nvp_model.state_dict(),
                    f"{save_dir}/real_nvp_streaming_approx_{iteration}_{save_suffix}.pt"
                )
            if iteration == total_iters:
                flag = True
                break
    return real_nvp_model

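# `streaming_update_algorithm` is defined elsewhere in this project. For
# orientation only, here is a minimal sketch of one plausible form: a
# power-iteration (Oja-style) update that tracks the top eigendirection of the
# centered gradients and discards the n_discard samples most aligned with it
# before averaging. Every detail below is an assumption, not the actual
# implementation.
def streaming_update_sketch(sgradients, n_discard, top_v, top_lambda,
                            old_mean, alpha):
    g = np.asarray(sgradients)  # (batch, n_params) sampled gradients
    batch_mean = g.mean(axis=0)
    # Exponentially weighted running mean of the gradients
    mean = (batch_mean if old_mean is None
            else alpha * old_mean + (1 - alpha) * batch_mean)
    centered = g - mean
    if top_v is None:
        top_v = np.random.randn(g.shape[1])
        top_v /= np.linalg.norm(top_v)
    # One power-iteration step on the batch covariance
    cov_v = centered.T @ (centered @ top_v) / g.shape[0]
    top_lambda = float(top_v @ cov_v)
    top_v = cov_v / (np.linalg.norm(cov_v) + 1e-12)
    # Drop the n_discard samples most aligned with the top direction,
    # then average the rest as the gradient estimate
    scores = np.abs(centered @ top_v)
    keep = np.argsort(scores)[:g.shape[0] - n_discard]
    stoc_grad = g[keep].mean(axis=0)
    return stoc_grad, top_v, top_lambda, mean
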
def train(model, embedder, optimizer, scheduler, train_loader, val_loader,
          opt, writer, device=None):
    print("TRAINING STARTS")
    global global_step
    for epoch in range(opt.n_epochs):
        print("[Epoch %d/%d]" % (epoch + 1, opt.n_epochs))
        model = model.train()
        loss_to_log = 0.0
        loss_fn = util.NLLLoss().to(device)
        with tqdm(total=len(train_loader.dataset)) as progress_bar:
            for i, (imgs, labels, captions) in enumerate(train_loader):
                imgs = imgs.to(device)
                labels = labels.to(device)
                with torch.no_grad():
                    if opt.conditioning == 'unconditional':
                        condition_embd = None
                    else:
                        condition_embd = embedder(labels, captions)
                optimizer.zero_grad()
                z, sldj = model.forward(imgs, condition_embd, reverse=False)
                # Normalize the NLL by the number of dimensions per image
                loss = loss_fn(z, sldj) / np.prod(imgs.size()[1:])
                loss.backward()
                if opt.max_grad_norm > 0:
                    util.clip_grad_norm(optimizer, opt.max_grad_norm)
                optimizer.step()
                scheduler.step(global_step)
                batches_done = epoch * len(train_loader) + i
                writer.add_scalar('train/bpd', loss / np.log(2), batches_done)
                loss_to_log += loss.item()
                progress_bar.set_postfix(bpd=(loss_to_log / np.log(2)),
                                         lr=optimizer.param_groups[0]['lr'])
                progress_bar.update(imgs.size(0))
                global_step += imgs.size(0)
                loss_to_log = 0.0
                if (batches_done + 1) % opt.sample_interval == 0:
                    print("sampling_images")
                    model = model.eval()
                    sample_image(model, embedder, opt.output_dir, n_row=4,
                                 batches_done=batches_done,
                                 dataloader=val_loader, device=device)
                    model = model.train()  # resume training mode after sampling
        val_bpd = eval(model, embedder, val_loader, opt, writer, device=device)
        writer.add_scalar("val/bpd", val_bpd, (epoch + 1) * len(train_loader))
        torch.save(
            model.state_dict(),
            os.path.join(opt.output_dir, 'models',
                         'epoch_{}.pt'.format(epoch)))

def train(epochs, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm):
    """Full training driver: trains the conditional flow, evaluates generated
    samples on both test condition sets each epoch, and keeps the best
    weights for each."""
    global global_step
    loss_meter = util.AverageMeter()
    evaluator = EvaluationModel()
    test_conditions = get_test_conditions(os.path.join('test.json')).to(device)
    new_test_conditions = get_test_conditions(
        os.path.join('new_test.json')).to(device)
    best_score = 0
    new_best_score = 0
    for epoch in range(1, epochs + 1):
        print('\nEpoch: ', epoch)
        net.train()  # re-enable train mode after the eval pass below
        with tqdm(total=len(trainloader.dataset)) as progress_bar:
            for x, cond_x in trainloader:
                x, cond_x = (x.to(device, dtype=torch.float),
                             cond_x.to(device, dtype=torch.float))
                optimizer.zero_grad()
                z, sldj = net(x, cond_x, reverse=False)
                loss = loss_fn(z, sldj)
                wandb.log({'loss': loss})
                loss_meter.update(loss.item(), x.size(0))
                loss.backward()
                if max_grad_norm > 0:
                    util.clip_grad_norm(optimizer, max_grad_norm)
                optimizer.step()
                # scheduler.step(global_step)
                progress_bar.set_postfix(nll=loss_meter.avg,
                                         bpd=util.bits_per_dim(
                                             x, loss_meter.avg),
                                         lr=optimizer.param_groups[0]['lr'])
                progress_bar.update(x.size(0))
                global_step += x.size(0)
        net.eval()
        # Evaluate on the original test conditions
        with torch.no_grad():
            gen_imgs = sample(net, test_conditions, device)
            score = evaluator.eval(gen_imgs, test_conditions)
        wandb.log({'score': score})
        if score > best_score:
            best_score = score
            best_model_wts = copy.deepcopy(net.state_dict())
            torch.save(
                best_model_wts,
                os.path.join('weightings/test',
                             f'epoch{epoch}_score{score:.2f}.pt'))
        # Evaluate on the new test conditions
        with torch.no_grad():
            new_gen_imgs = sample(net, new_test_conditions, device)
            new_score = evaluator.eval(new_gen_imgs, new_test_conditions)
        wandb.log({'new_score': new_score})
        if new_score > new_best_score:
            new_best_score = new_score
            new_best_model_wts = copy.deepcopy(net.state_dict())
            torch.save(
                new_best_model_wts,
                os.path.join('weightings/new_test',
                             f'epoch{epoch}_score{new_score:.2f}.pt'))
        save_image(gen_imgs,
                   os.path.join('results/test', f'epoch{epoch}.png'),
                   nrow=8, normalize=True)
        save_image(new_gen_imgs,
                   os.path.join('results/new_test', f'epoch{epoch}.png'),
                   nrow=8, normalize=True)