def execute_graph(model, conditional, data_loader, loss_fn, scheduler, optimizer, use_visdom, use_tb):
    """Run one training + validation pass, step the scheduler and log results.

    Calls ``train_validate`` once with ``train=True`` and once with
    ``train=False``, then passes the validation loss to ``scheduler.step``.
    It prints both average losses and, when enabled, logs scalars and
    generated/reconstructed images to TensorBoard and/or Visdom.

    NOTE(review): ``epoch``, ``logger``, ``log_dir``, ``latent_size``,
    ``args``, ``vis`` and ``tvu`` are not parameters. They are read from the
    enclosing module scope (presumably the training loop variable and
    script-level setup). Confirm they are defined before this is called.

    Args:
        model: The model being trained.
        conditional: Forwarded to ``train_validate`` and the example helpers.
        data_loader: Forwarded to ``train_validate`` and the example helpers.
            The same object is used for both passes, so presumably
            ``train_validate`` selects the split via ``train``. Verify.
        loss_fn: Loss function forwarded to ``train_validate``.
        scheduler: Learning-rate scheduler. ``step`` receives the validation
            loss (looks like ``ReduceLROnPlateau``; confirm).
        optimizer: Optimizer forwarded to ``train_validate``.
        use_visdom: If true, log losses and example images to Visdom (``vis``).
        use_tb: If true, log losses and example images to TensorBoard
            (``logger``).

    Returns:
        The validation loss returned by ``train_validate(..., train=False)``.
    """
    # Training loss
    t_loss = train_validate(model, data_loader, loss_fn, optimizer, conditional, train=True)

    # Validation loss
    v_loss = train_validate(model, data_loader, loss_fn, optimizer, conditional, train=False)

    # Step the scheduler based on the validation loss
    scheduler.step(v_loss)

    print('====> Epoch: {} Average Train loss: {:.4f}'.format(epoch, t_loss))
    print('====> Epoch: {} Average Validation loss: {:.4f}'.format(
        epoch, v_loss))

    if use_tb:
        # Training and validation loss
        logger.add_scalar(log_dir + '/validation-loss', v_loss, epoch)
        logger.add_scalar(log_dir + '/training-loss', t_loss, epoch)

        # TODO: log gradient values of the model

        # Image generation examples
        sample = generation_example(model, latent_size, data_loader, conditional, args.cuda)
        sample = sample.detach()
        sample = tvu.make_grid(sample, normalize=False, scale_each=True)
        logger.add_image('generation example', sample, epoch)

        # Image reconstruction examples
        comparison = reconstruction_example(model, data_loader, conditional, args.cuda)
        comparison = comparison.detach()
        comparison = tvu.make_grid(comparison, normalize=False, scale_each=True)
        logger.add_image('reconstruction example', comparison, epoch)

    if use_visdom:
        # Visdom: update training and validation loss plots
        vis.add_scalar(t_loss, epoch, 'Training loss', idtag='train')
        vis.add_scalar(v_loss, epoch, 'Validation loss', idtag='valid')

        # Visdom: show generated images.
        # .cpu() is required before .numpy(): Tensor.numpy() raises for CUDA
        # tensors (the helpers are given args.cuda). It is a no-op on CPU.
        sample = generation_example(model, latent_size, data_loader, conditional, args.cuda)
        sample = sample.detach().cpu().numpy()
        vis.add_image(sample, 'Generated sample ' + str(epoch), 'generated')

        # Visdom: show an example reconstruction
        comparison = reconstruction_example(model, data_loader, conditional, args.cuda)
        comparison = comparison.detach().cpu().numpy()
        vis.add_image(comparison, 'Reconstruction sample ' + str(epoch), 'recon')

    return v_loss
{ 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'val_loss': v_loss }, 'models/INFOVAE_{:04.4f}.pt'.format(v_loss)) if stop: print('Early stopping at epoch: {}'.format(epoch)) break # Write a final sample to disk sample = generation_example(model, latent_size, data_loader, conditional, args.cuda) save_image(sample, 'output/sample_' + str(num_epochs) + '.png') # Make a final reconstruction, and write to disk comparison = reconstruction_example(model, data_loader, conditional, args.cuda) save_image(comparison, 'output/comparison_' + str(num_epochs) + '.png') # latent space scatter example if args.latent_size == 2: centroids, labels = latentcluster2d_example(model, data_loader, args.conditional, args.cuda) cmap = ['b', 'g', 'r', 'c', 'y', 'm', 'k'] colors = [cmap[(int(i) % 7)] for i in labels] fig = plt.figure() plt.scatter(centroids[:, 0], centroids[:, 1], c=colors, cmap=plt.cm.Spectral) plt.savefig('output/InfoVAE_z_cluster.png') plt.close(fig)