def test(model, device, epoch, ema, data_loader, tag, root_process):
    # convert model to evaluation mode (no Dropout etc.)
    model.eval()

    # set up the reconstruction batch
    recon_dataset = None
    nbatches = data_loader.batch_sampler.sampler.num_samples // data_loader.batch_size
    recon_batch_idx = int(torch.Tensor(1).random_(0, nbatches - 1))

    # setup testing metrics
    if root_process:
        logrecons = torch.zeros((nbatches), device=device)
        logdecs = torch.zeros((nbatches, model.nz), device=device)
        logencs = torch.zeros((nbatches, model.nz), device=device)

    elbos = []

    # switch to EMA parameters for evaluation
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data = ema.get_ema(name)

    # allocate memory for the input data
    data = torch.zeros((data_loader.batch_size,) + model.xs, device=device)

    # enumerate over the batches
    for batch_idx, (batch, _) in enumerate(data_loader):
        # copy the mini-batch into the pre-allocated data variable
        data.copy_(batch)

        # save a copy of the selected batch for reconstruction
        # (clone, because `data` is overwritten on every iteration)
        if batch_idx == recon_batch_idx:
            recon_dataset = data.clone()

        with torch.no_grad():
            # evaluate the data under the model and calculate ELBO components
            logrecon, logdec, logenc, _ = model.loss(data)

            # construct the ELBO
            elbo = -logrecon + torch.sum(-logdec + logenc)

            # compute the inference- and generative-model loss
            logdec = torch.sum(logdec, dim=1)
            logenc = torch.sum(logenc, dim=1)

        if root_process:
            # scale by image dimensions to get "bits/dim"
            elbo *= model.perdimsscale
            logrecon *= model.perdimsscale
            logdec *= model.perdimsscale
            logenc *= model.perdimsscale

            elbos.append(elbo.item())

            # log
            logrecons[batch_idx] += logrecon
            logdecs[batch_idx] += logdec
            logencs[batch_idx] += logenc

    if root_process:
        elbo = np.mean(elbos)

        entrecon = -torch.mean(logrecons).detach().cpu().numpy()
        entdec = -torch.mean(logdecs, dim=0).detach().cpu().numpy()
        entenc = -torch.mean(logencs, dim=0).detach().cpu().numpy()
        kl = entdec - entenc

        # print metrics to the console and log them to Tensorboard
        print(f'\nEpoch: {epoch}\tTest loss: {elbo:.6f}')
        model.logger.add_scalar('elbo/test', elbo, epoch)
        model.logger.add_scalar('x/reconstruction/test', entrecon, epoch)
        for i in range(1, logdec.shape[0] + 1):
            model.logger.add_scalar(f'z{i}/encoder/test', entenc[i - 1], epoch)
            model.logger.add_scalar(f'z{i}/decoder/test', entdec[i - 1], epoch)
            model.logger.add_scalar(f'z{i}/KL/test', kl[i - 1], epoch)

        # if the current ELBO is better than any previous ELBO, save the parameters
        if elbo < model.best_elbo and not np.isnan(elbo):
            model.logger.add_scalar('elbo/besttest', elbo, epoch)
            if not os.path.exists('params/mnist/'):
                os.makedirs('params/mnist/')
            torch.save(model.state_dict(), f'params/mnist/{tag}')
            if epoch % 25 == 0:
                torch.save(model.state_dict(), f'params/mnist/epoch{epoch}_{tag}')
            print("saved params\n")
            model.best_elbo = elbo

            model.sample(device, epoch)
            model.reconstruct(recon_dataset, device, epoch)
        else:
            print("loss did not improve\n")
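
# Both `test` (above) and `train` (below) rely on an `ema` helper that keeps two
# copies of every trainable parameter: the raw training values ("default") and an
# exponentially decayed shadow used for evaluation. The class below is only a
# minimal sketch of that interface (register_default / get_default / get_ema /
# __call__), written to make the expected behaviour explicit; the project's actual
# EMA implementation may differ, and this sketch is not used by the code above.
class EMAHelperSketch:
    def __init__(self, decay=0.999):
        self.decay = decay
        self.shadow = {}    # EMA parameter values, returned by get_ema
        self.default = {}   # raw training values, returned by get_default

    def register_default(self, name, value):
        # remember the current training value so it can be restored after evaluation
        self.default[name] = value.clone()

    def get_default(self, name):
        return self.default[name]

    def get_ema(self, name):
        return self.shadow[name]

    def __call__(self, name, value):
        # standard EMA update: shadow = decay * shadow + (1 - decay) * value
        if name not in self.shadow:
            self.shadow[name] = value.clone()
        else:
            self.shadow[name].mul_(self.decay).add_(value, alpha=1. - self.decay)
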
def train(model, device, epoch, data_loader, optimizer, ema, log_interval, root_process, schedule=True, decay=0.99995):
    # convert model to train mode (activate Dropout etc.)
    model.train()

    # get number of batches
    nbatches = data_loader.batch_sampler.sampler.num_samples // data_loader.batch_size

    # switch to the parameters not affected by the exponential moving average decay
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data = ema.get_default(name)

    # setup training metrics
    if root_process:
        elbos = torch.zeros((nbatches), device=device)
        logrecons = torch.zeros((nbatches), device=device)
        logdecs = torch.zeros((nbatches, model.nz), device=device)
        logencs = torch.zeros((nbatches, model.nz), device=device)
        start_time = time.time()

    # allocate memory for data
    data = torch.zeros((data_loader.batch_size,) + model.xs, device=device)

    # enumerate over the batches
    for batch_idx, (batch, _) in enumerate(data_loader):
        # keep track of the global step
        global_step = (epoch - 1) * len(data_loader) + (batch_idx + 1)

        # update the learning rate according to the schedule
        if schedule:
            for param_group in optimizer.param_groups:
                lr = param_group['lr']
                lr = lr_step(global_step, lr, decay=decay)
                param_group['lr'] = lr

        # empty all the gradients stored
        optimizer.zero_grad()

        # copy the mini-batch into the pre-allocated data variable
        data.copy_(batch)

        # evaluate the data under the model and calculate ELBO components
        logrecon, logdec, logenc, zsamples = model.loss(data)

        # free-bits technique, in order to prevent posterior collapse
        bits_pc = 1.
        kl = torch.sum(torch.max(-logdec + logenc, bits_pc * torch.ones((model.nz, model.zdim[0]), device=device)))

        # compute the inference- and generative-model loss
        logdec = torch.sum(logdec, dim=1)
        logenc = torch.sum(logenc, dim=1)

        # construct the ELBO
        elbo = -logrecon + kl

        # scale by image dimensions to get "bits/dim"
        elbo *= model.perdimsscale
        logrecon *= model.perdimsscale
        logdec *= model.perdimsscale
        logenc *= model.perdimsscale

        # calculate gradients
        elbo.backward()

        # clip the gradients and take a gradient step
        total_norm = nn.utils.clip_grad_norm_(model.parameters(), 1., norm_type=2)
        optimizer.step()

        # log gradient norm
        if root_process:
            model.logger.add_scalar('gnorm', total_norm, global_step)

        # do EMA update on the parameters used for evaluation
        if root_process:
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        ema(name, param.data)

        # log (detach so the metric buffers do not keep the autograd graph alive)
        if root_process:
            elbos[batch_idx] += elbo.detach()
            logrecons[batch_idx] += logrecon.detach()
            logdecs[batch_idx] += logdec.detach()
            logencs[batch_idx] += logenc.detach()

        # log and save parameters
        if root_process and batch_idx % log_interval == 0 and log_interval < nbatches:
            # print metrics to the console
            print(f'Train Epoch: {epoch} [{batch_idx}/{nbatches} ({100. * batch_idx / len(data_loader):.0f}%)]\t'
                  f'Loss: {elbo.item():.6f}\tGnorm: {total_norm:.2f}\t'
                  f'Sec/step: {(time.time() - start_time) / (batch_idx + 1):.3f}')
            model.logger.add_scalar('step-sec', (time.time() - start_time) / (batch_idx + 1), global_step)

            entrecon = -logrecon
            entdec = -logdec
            entenc = -logenc
            kl = entdec - entenc

            # log to Tensorboard
            model.logger.add_scalar('elbo/train', elbo, global_step)
            for param_group in optimizer.param_groups:
                lr = param_group['lr']
            model.logger.add_scalar('lr', lr, global_step)
            model.logger.add_scalar('x/reconstruction/train', entrecon, global_step)
            for i in range(1, logdec.shape[0] + 1):
                model.logger.add_scalar(f'z{i}/encoder/train', entenc[i - 1], global_step)
                model.logger.add_scalar(f'z{i}/decoder/train', entdec[i - 1], global_step)
                model.logger.add_scalar(f'z{i}/KL/train', kl[i - 1], global_step)

    # save the training parameters, to be able to return to these values after evaluation
    with torch.no_grad():
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.register_default(name, param.data)

    # print the average loss of the epoch to the console
    if root_process:
        elbo = torch.mean(elbos).detach().cpu().numpy()
        print(f'====> Epoch: {epoch} Average loss: {elbo:.4f}')
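
# `train` anneals the learning rate through `lr_step(global_step, lr, decay=decay)`,
# which is defined elsewhere in the project. The function below is a hypothetical
# sketch of one plausible behaviour (per-step exponential decay clamped at an
# assumed `min_lr` floor); it is illustrative only and not the project's actual
# schedule, which is why it carries a different name.
def lr_step_sketch(global_step, curr_lr, decay=0.99995, min_lr=5e-4):
    # multiplicative decay per step, never going below min_lr
    # (a real schedule might also use global_step, e.g. for warm-up; this sketch does not)
    return max(curr_lr * decay, min_lr)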