def load_cortex(path, args):
    """Loads an inference/generator pair (a "cortex") from the checkpoint at `path`."""
    # Batch norm is disabled when the model was trained with the Wasserstein loss
    bn = False if args.loss_type == 'wasserstein' else True
    inference = Inference(args.noise_dim,
                          args.n_filters,
                          1 if args.dataset == 'mnist' else 3,
                          image_size=args.image_size,
                          bn=bn,
                          hard_norm=args.divisive_normalization,
                          spec_norm=args.spec_norm,
                          derelu=False)
    generator = Generator(args.noise_dim,
                          args.n_filters,
                          1 if args.dataset == 'mnist' else 3,
                          image_size=args.image_size,
                          hard_norm=args.divisive_normalization)
    if os.path.isfile(path):
        print("=> loading checkpoint '{}'".format(path))
        # Load onto the CPU so the checkpoint can be read on machines without a GPU
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        inference.load_state_dict(checkpoint['inference_state_dict'])
        generator.load_state_dict(checkpoint['generator_state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            path, checkpoint['epoch']))
    else:
        raise IOError("=> no checkpoint found at '{}'".format(path))
    return inference, generator
def load_checkpoint(args):
    """Loads a trained generator from the checkpoint stored under args.path."""
    path = args.path + '/checkpoint.pth.tar'
    generator = Generator(args.noise_dim,
                          args.n_filters,
                          1 if args.dataset == 'mnist' else 3,
                          image_size=args.image_size,
                          hard_norm=args.divisive_normalization)
    if os.path.isfile(path):
        print("=> loading checkpoint '{}'".format(path))
        # Load onto the CPU so the checkpoint can be read on machines without a GPU
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        generator.load_state_dict(checkpoint['generator_state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            path, checkpoint['epoch']))
    else:
        raise IOError("=> no checkpoint found at '{}'".format(path))
    return generator
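# --- Hedged usage sketch (not part of the original file) ---------------------
# The attribute names below are exactly the ones load_cortex reads from `args`;
# the values are placeholders and must match whatever architecture produced the
# checkpoint being loaded.
def _example_load_cortex():
    """Illustrative only: shows the minimal `args` namespace load_cortex expects."""
    from argparse import Namespace
    args = Namespace(
        loss_type='gan',               # anything except 'wasserstein' keeps batch norm on
        noise_dim=100,                 # placeholder: must match the trained model
        n_filters=64,                  # placeholder: must match the trained model
        dataset='cifar10',
        image_size=32,
        divisive_normalization=False,
        spec_norm=False)
    inference, generator = load_cortex('checkpoint.pth.tar', args)
    return inference, generator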
def morph_project_only(im1: np.ndarray,
                       im2: np.ndarray,
                       Generator: Generator,
                       Encoder: Encoder,
                       pix2pix: networks.UnetGenerator,
                       epsilon: float = 20.0,
                       L: int = 9,
                       dcgan_size: int = 64,
                       pix2pix_size: int = 128,
                       simulation_name: str = "image_interpolation",
                       results_path: str = "results") -> None:
    """Generates three morphing processes given two images.

    The first is a plain Wasserstein-barycenter interpolation, the second is our
    algorithm (the barycenter frames projected onto the GAN), and the third is a
    simple linear interpolation in the GAN's latent space.

    Arguments:
        im1 {np.ndarray} -- source image
        im2 {np.ndarray} -- destination image
        Generator {Generator} -- DCGAN generator (latent space to pixel space)
        Encoder {Encoder} -- DCGAN encoder (pixel space to latent space)
        pix2pix {networks.UnetGenerator} -- pix2pix model trained to increase image resolution

    Keyword Arguments:
        epsilon {float} -- entropic regularization parameter (default: {20.0})
        L {int} -- number of images in the transformation (default: {9})
        dcgan_size {int} -- DCGAN image size (low resolution) (default: {64})
        pix2pix_size {int} -- pix2pix image size (high resolution) (default: {128})
        simulation_name {str} -- name of the simulation; affects the saved file names
            (default: {"image_interpolation"})
        results_path {str} -- the path to save the results in (default: {"results"})
    """
    img_size = im1.shape[:2]
    im1, im2 = (I.transpose(2, 0, 1).reshape(3, -1, 1) for I in (im1, im2))

    print("Preparing transportation cost matrix...")
    C = generate_metric(img_size)
    Q = np.concatenate([im1, im2], axis=-1)
    Q, max_val, Q_counts = preprocess_Q(Q)

    out_ours = []
    out_GAN = []
    out_OT = []

    print("Computing transportation plan...")
    for dim in range(3):
        print(f"Color space {dim + 1}/3")
        out_OT.append([])
        P = sinkhorn(Q[dim, :, 0], Q[dim, :, 1], C, img_size[0], img_size[1],
                     epsilon)
        for t in tqdm(np.linspace(0, 1, L)):
            out_OT[-1].append(
                max_val -
                generate_interpolation(img_size[0], img_size[1], P, t) *
                ((1 - t) * Q_counts[dim, 0, 0] + t * Q_counts[dim, 0, 1]))
    out_OT = [np.stack(im_channels, axis=0) for im_channels in zip(*out_OT)]

    print("Computing GAN projections...")
    # Project the OT interpolation frames onto the GAN
    GAN_projections = [
        project_on_generator(Generator,
                             pix2pix,
                             I,
                             Encoder,
                             dcgan_img_size=dcgan_size,
                             pix2pix_img_size=pix2pix_size) for I in out_OT
    ]
    GAN_projections_images, GAN_projections_noises = zip(*GAN_projections)
    out_ours = GAN_projections_images

    # Linearly interpolate in the GAN's latent space
    noise1, noise2 = GAN_projections_noises[0].cuda(
    ), GAN_projections_noises[-1].cuda()
    for t in np.linspace(0, 1, L):
        t = float(t)  # cast numpy scalar to a primitive type
        GAN_image = Generator((1 - t) * noise1 + t * noise2)
        GAN_image = F.interpolate(GAN_image, scale_factor=2, mode='bilinear')
        pix_outputs = pix2pix(GAN_image)
        GAN_image = utils.denorm(pix_outputs.detach()).cpu().numpy().reshape(
            3, -1, 1)
        out_GAN.append(GAN_image.clip(0, 1))

    # Save results
    print("Saving results...")
    out_ours = torch.stack(
        [torch.Tensor(im).reshape(3, *img_size) for im in out_ours])
    out_OT = torch.stack(
        [torch.Tensor(im).reshape(3, *img_size) for im in out_OT])
    out_GAN = torch.stack(
        [torch.Tensor(im).reshape(3, *img_size) for im in out_GAN])
    if not os.path.exists(results_path):
        os.mkdir(results_path)
    output_path = join(results_path, simulation_name + '.png')
    save_image(torch.cat([out_OT, out_ours, out_GAN], dim=0),
               output_path,
               nrow=L,
               normalize=False,
               scale_each=False,
               range=(0, 1))
    print(f"Image saved in {output_path}")
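# --- Hedged usage sketch (not part of the original pipeline) -----------------
# Everything below is illustrative: the image paths, the pix2pix checkpoint path
# and its constructor arguments, and the [0, 1] float image range are
# assumptions; the DCGAN Generator / Encoder arguments and checkpoint names
# mirror the training defaults used elsewhere in this repo.
def _example_morph_run():
    """Illustrative only: one plausible way to call morph_project_only."""
    from PIL import Image

    def _load_rgb(path, size=(64, 64)):
        # Small images keep the (H*W) x (H*W) transport cost matrix tractable
        img = Image.open(path).convert("RGB").resize(size)
        return np.asarray(img).astype(np.float32) / 255.0

    im1 = _load_rgb("data/shoe_a.png")   # hypothetical source image
    im2 = _load_rgb("data/shoe_b.png")   # hypothetical destination image

    G = Generator(100, [1024, 512, 256, 128]).cuda().eval()
    G.load_state_dict(torch.load("networks/generator"))
    E = Encoder([128, 256, 512, 1024], 100).cuda().eval()
    E.load_state_dict(torch.load("networks/encoder"))
    pix2pix = networks.UnetGenerator(3, 3, 7).cuda().eval()        # assumed constructor args
    pix2pix.load_state_dict(torch.load("networks/pix2pix.pth"))    # assumed checkpoint path

    morph_project_only(im1, im2, G, E, pix2pix,
                       L=9, simulation_name="shoe_morph")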
def train_gan(latent_dim=100,
              num_filters=[1024, 512, 256, 128],
              batch_size=128,
              num_epochs=100,
              h5_file_path='shoes_images/shoes.hdf5',
              save_dir='networks/',
              train_log_dir='dcgan_log_dir',
              learning_rate=0.0002,
              betas=(0.5, 0.999)):
    # Models
    G = Generator(latent_dim, num_filters)
    D = Discriminator(num_filters[::-1])
    G.cuda()
    D.cuda()

    # Loss function
    criterion = torch.nn.BCELoss()

    # Optimizers
    G_optimizer = optim.Adam(G.parameters(),
                             lr=learning_rate,
                             betas=betas,
                             weight_decay=1e-5)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=learning_rate,
                             betas=betas,
                             weight_decay=1e-5)

    # Schedulers
    G_scheduler = optim.lr_scheduler.MultiStepLR(G_optimizer,
                                                 milestones=[25, 50, 75])
    D_scheduler = optim.lr_scheduler.MultiStepLR(D_optimizer,
                                                 milestones=[25, 50, 75])

    # Loss histories
    D_avg_losses = []
    G_avg_losses = []

    # Fixed noise, so progress is visualized on the same samples every epoch
    num_test_samples = 6 * 6
    fixed_noise = torch.randn(num_test_samples, latent_dim, 1, 1).cuda()

    # Dataloader
    data_loader = dataloader.get_h5_dataset(path=h5_file_path,
                                            batch_size=batch_size)

    for epoch in range(num_epochs):
        D_epoch_losses = []
        G_epoch_losses = []
        for i, images in enumerate(data_loader):
            mini_batch = images.size()[0]
            x = images.cuda()
            y_real = torch.ones(mini_batch).cuda()
            y_fake = torch.zeros(mini_batch).cuda()

            # Train discriminator
            D_real_decision = D(x).squeeze()
            D_real_loss = criterion(D_real_decision, y_real)
            z = torch.randn(mini_batch, latent_dim, 1, 1)
            z = z.cuda()
            generated_images = G(z)
            D_fake_decision = D(generated_images).squeeze()
            D_fake_loss = criterion(D_fake_decision, y_fake)

            # Backprop discriminator
            D_loss = D_real_loss + D_fake_loss
            D.zero_grad()
            if i % 2 == 0:
                # Update the discriminator only once every 2 batches
                D_loss.backward()
                D_optimizer.step()

            # Train generator
            z = torch.randn(mini_batch, latent_dim, 1, 1)
            z = z.cuda()
            generated_images = G(z)
            D_fake_decision = D(generated_images).squeeze()
            G_loss = criterion(D_fake_decision, y_real)

            # Backprop generator
            D.zero_grad()
            G.zero_grad()
            G_loss.backward()
            G_optimizer.step()

            # Loss values
            D_epoch_losses.append(D_loss.data.item())
            G_epoch_losses.append(G_loss.data.item())

            print('Epoch [%d/%d], Step [%d/%d], D_loss: %.4f, G_loss: %.4f' %
                  (epoch + 1, num_epochs, i + 1, len(data_loader),
                   D_loss.data.item(), G_loss.data.item()))

        D_avg_loss = torch.mean(torch.FloatTensor(D_epoch_losses)).item()
        G_avg_loss = torch.mean(torch.FloatTensor(G_epoch_losses)).item()
        D_avg_losses.append(D_avg_loss)
        G_avg_losses.append(G_avg_loss)

        # Plots
        plot_loss(D_avg_losses, G_avg_losses, num_epochs, log_dir=train_log_dir)
        G.eval()
        generated_images = G(fixed_noise).detach()
        generated_images = denorm(generated_images)
        G.train()
        plot_result(generated_images, epoch, log_dir=train_log_dir)

        # Save models
        torch.save(G.state_dict(), join(save_dir, 'generator'))
        torch.save(D.state_dict(), join(save_dir, 'discriminator'))

        # Decrease the learning rate
        G_scheduler.step()
        D_scheduler.step()
def test_encoder(latent_dim=100,
                 num_filters=[1024, 512, 256, 128],
                 batch_size=128,
                 num_epochs=100,
                 h5_file_path='shoes_images/shoes.hdf5',
                 save_dir='networks/',
                 train_log_dir='dcgan_log_dir',
                 alpha=0.002):
    # Load AlexNet as a fixed feature extractor for the perceptual loss
    alexnet = models.alexnet(pretrained=True).cuda()
    alexnet.eval()
    for param in alexnet.parameters():
        param.requires_grad = False

    # Load the trained generator and freeze it
    G = Generator(latent_dim, num_filters).cuda()
    generator_path = join(save_dir, 'generator')
    G.load_state_dict(torch.load(generator_path))
    G.eval()
    for param in G.parameters():
        param.requires_grad = False

    # Load the trained encoder and freeze it
    E = Encoder(num_filters[::-1], latent_dim).cuda()
    encoder_path = join(save_dir, 'encoder')
    E.load_state_dict(torch.load(encoder_path))
    E.eval()
    for param in E.parameters():
        param.requires_grad = False

    # Dataloader
    data_loader = dataloader.get_h5_dataset(path=h5_file_path,
                                            batch_size=batch_size)
    interpolate = lambda x: F.interpolate(x, scale_factor=4, mode='bilinear')

    images = next(iter(data_loader))
    mini_batch = images.size()[0]
    x = images.cuda()
    x_features = alexnet.features(alexnet_norm(interpolate(denorm(x))))

    # Encode, then refine the latent code by direct optimization
    z = E(x)
    out_images = torch.stack((denorm(x), denorm(G(z))), dim=1)
    z.requires_grad_(True)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam([z], lr=1e-3)
    for num_epoch in range(100):
        outputs = G(z)
        # Pixel reconstruction loss plus an AlexNet-feature (perceptual) loss
        loss = criterion(x, outputs) + alpha * criterion(
            x_features,
            alexnet.features(alexnet_norm(interpolate(denorm(outputs)))))
        z.grad = None
        loss.backward()
        optimizer.step()

    out_images = torch.cat((out_images, denorm(G(z)).unsqueeze(1)), dim=1)
    nrow = out_images.shape[1]
    out_images = out_images.reshape(-1, *x.shape[1:])
    save_image(out_images,
               join(train_log_dir, 'encoder_images.png'),
               nrow=nrow,
               normalize=False,
               scale_each=False,
               range=(0, 1))
def finetune_encoder_with_samples(latent_dim=100,
                                  num_filters=[1024, 512, 256, 128],
                                  batch_size=128,
                                  num_epochs=100,
                                  h5_file_path='shoes_images/shoes.hdf5',
                                  save_dir='networks/',
                                  train_log_dir='dcgan_log_dir',
                                  learning_rate=0.0002,
                                  betas=(0.5, 0.999),
                                  alpha=0.002):
    # Load AlexNet as a fixed feature extractor for the perceptual loss
    alexnet = models.alexnet(pretrained=True).cuda()
    alexnet.eval()
    for param in alexnet.parameters():
        param.requires_grad = False

    # Load the generator and fix its weights
    G = Generator(latent_dim, num_filters).cuda()
    generator_path = join(save_dir, 'generator')
    G.load_state_dict(torch.load(generator_path))
    G.eval()
    for param in G.parameters():
        param.requires_grad = False

    # Load the encoder
    E = Encoder(num_filters[::-1], latent_dim).cuda()
    encoder_path = join(save_dir, 'encoder')
    E.load_state_dict(torch.load(encoder_path))
    E.train()

    # Loss function
    criterion = torch.nn.MSELoss()

    # Optimizer
    E_optimizer = optim.Adam(E.parameters(),
                             lr=learning_rate,
                             betas=betas,
                             weight_decay=1e-5)
    E_avg_losses = []

    # Dataloader
    data_loader = dataloader.get_h5_dataset(path=h5_file_path,
                                            batch_size=batch_size)
    interpolate = lambda x: F.interpolate(x, scale_factor=4, mode='bilinear')
    get_features = lambda x: alexnet.features(
        alexnet_norm(interpolate(denorm(x))))

    for epoch in range(num_epochs):
        E_losses = []
        # Minibatch training
        for i, images in enumerate(data_loader):
            mini_batch = images.size()[0]
            x = images.cuda()

            # Train encoder: pixel reconstruction loss plus perceptual loss
            out_images = G(E(x))
            E_loss = criterion(x, out_images) + alpha * criterion(
                get_features(x), get_features(out_images))

            # Backprop
            E.zero_grad()
            E_loss.backward()
            E_optimizer.step()

            # Loss values
            E_losses.append(E_loss.data.item())
            print('Epoch [%d/%d], Step [%d/%d], E_loss: %.4f' %
                  (epoch + 1, num_epochs, i + 1, len(data_loader),
                   E_loss.data.item()))

        # Average loss value for the plot
        E_avg_loss = torch.mean(torch.FloatTensor(E_losses)).item()
        E_avg_losses.append(E_avg_loss)
        plot_loss(E_avg_losses,
                  None,
                  num_epochs,
                  log_dir=train_log_dir,
                  model1='Encoder',
                  model2='')

        # Save the model
        torch.save(E.state_dict(), join(save_dir, 'encoder'))
def train_encoder_with_noise(latent_dim=100,
                             num_filters=[1024, 512, 256, 128],
                             batch_size=128,
                             num_epochs=100,
                             h5_file_path='shoes_images/shoes.hdf5',
                             save_dir='networks/',
                             train_log_dir='dcgan_log_dir',
                             learning_rate=0.0002,
                             betas=(0.5, 0.999)):
    # Load the generator and fix its weights
    G = Generator(latent_dim, num_filters).cuda()
    generator_path = join(save_dir, 'generator')
    G.load_state_dict(torch.load(generator_path))
    G.eval()
    for param in G.parameters():
        param.requires_grad = False

    E = Encoder(num_filters[::-1], latent_dim)
    E.cuda()

    # Loss function
    criterion = torch.nn.MSELoss()

    # Optimizer
    E_optimizer = optim.Adam(E.parameters(),
                             lr=learning_rate,
                             betas=betas,
                             weight_decay=1e-5)
    E_avg_losses = []

    # Dataloader
    data_loader = dataloader.get_h5_dataset(path=h5_file_path,
                                            batch_size=batch_size)

    for epoch in range(num_epochs):
        E_losses = []
        # Minibatch training
        for i, images in enumerate(data_loader):
            # Generate noise and the corresponding generated images
            z = torch.randn(images.shape[0], latent_dim, 1, 1).cuda()
            x = G(z)

            # Train encoder to recover the latent code
            out_latent = E(x)
            E_loss = criterion(z, out_latent)

            # Back propagation
            E.zero_grad()
            E_loss.backward()
            E_optimizer.step()

            # Loss values
            E_losses.append(E_loss.data.item())
            print('Epoch [%d/%d], Step [%d/%d], E_loss: %.4f' %
                  (epoch + 1, num_epochs, i + 1, len(data_loader),
                   E_loss.data.item()))

        # Average loss value for the plot
        E_avg_loss = torch.mean(torch.FloatTensor(E_losses)).item()
        E_avg_losses.append(E_avg_loss)
        plot_loss(E_avg_losses,
                  None,
                  num_epochs,
                  log_dir=train_log_dir,
                  model1='Encoder',
                  model2='')

        # Save the model
        torch.save(E.state_dict(), join(save_dir, 'encoder'))
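# --- Hedged usage sketch (not part of the original file) ---------------------
# The ordering below (GAN first, then encoder pre-training on generated
# samples, then fine-tuning on real data, then a qualitative test) is inferred
# from which checkpoints each routine loads and saves; it is not prescribed
# anywhere in this file.
def _example_training_pipeline():
    """Illustrative only: one plausible end-to-end ordering of the routines above."""
    train_gan()                        # trains G and D, saves 'networks/generator'
    train_encoder_with_noise()         # trains E on (z, G(z)) pairs, saves 'networks/encoder'
    finetune_encoder_with_samples()    # fine-tunes E on real images with a perceptual loss
    test_encoder()                     # reconstructs a batch and saves a comparison grid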
def main_worker(gpu, ngpus_per_node, args):
    # Command-line flags arrive as strings; convert them to booleans
    args.scale_gen_surprisal_by_D = args.scale_gen_surprisal_by_D == "True"
    args.prioritized_replay = args.prioritized_replay == "True"
    args.divisive_normalization = args.divisive_normalization == "True"
    args.spectral_norm = args.spectral_norm == "True"
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # ----- Get dataset ------ #
    image_size = args.image_size

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if args.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize,
            ]))
        nc = 3
    elif args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=args.data,
                                         download=True,
                                         transform=transforms.Compose([
                                             transforms.Resize(image_size),
                                             transforms.ToTensor(),
                                             transforms.Normalize(
                                                 (0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5)),
                                         ]))
        nc = 3
    elif args.dataset == 'mnist':
        train_dataset = datasets.MNIST(root=args.data,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.Resize(image_size),
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.5, ),
                                                                (0.5, )),
                                       ]))
        nc = 1
    assert train_dataset

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # ----- Create models ------ #
    inference = Inference(args.noise_dim,
                          args.n_filters,
                          nc,
                          image_size=image_size,
                          noise_before=False,
                          hard_norm=args.divisive_normalization,
                          spec_norm=args.spectral_norm)
    generator = Generator(args.noise_dim,
                          args.n_filters,
                          nc,
                          image_size=image_size,
                          hard_norm=args.divisive_normalization)
    discriminator = Discriminator(args.noise_dim,
                                  args.n_filters,
                                  nc,
                                  image_size=image_size,
                                  hard_norm=args.divisive_normalization,
                                  hidden_dim=128)
    # The readout discriminator is only needed when gamma < 1
    readout_disc = ReadoutDiscriminator(
        args.n_filters, image_size,
        spec_norm=args.spectral_norm) if (args.gamma < 1) else None

    # Get to the proper GPU
    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            inference.cuda(args.gpu)
            generator.cuda(args.gpu)
            if args.gamma < 1:
                readout_disc.cuda(args.gpu)
            discriminator.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            inference = torch.nn.parallel.DistributedDataParallel(
                inference, device_ids=[args.gpu], broadcast_buffers=False)
            generator = torch.nn.parallel.DistributedDataParallel(
                generator, device_ids=[args.gpu], broadcast_buffers=False)
            if args.gamma < 1:
                readout_disc = torch.nn.parallel.DistributedDataParallel(
                    readout_disc,
                    device_ids=[args.gpu],
                    broadcast_buffers=False)
            discriminator = torch.nn.parallel.DistributedDataParallel(
                discriminator, device_ids=[args.gpu], broadcast_buffers=False)
        else:
            inference.cuda()
            generator.cuda()
            if args.gamma < 1:
                readout_disc.cuda()
            discriminator.cuda()
            # DistributedDataParallel will divide and allocate batch_size to
            # all available GPUs if device_ids are not set
            generator = torch.nn.parallel.DistributedDataParallel(generator)
            inference = torch.nn.parallel.DistributedDataParallel(inference)
            readout_disc = torch.nn.parallel.DistributedDataParallel(
                readout_disc) if args.gamma < 1 else None
            discriminator = torch.nn.parallel.DistributedDataParallel(
                discriminator)
        # Give access to intermediate state through the DDP wrappers
        promote_attributes(inference)
        promote_attributes(generator)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        inference = inference.cuda(args.gpu)
        discriminator = discriminator.cuda(args.gpu)
        generator = generator.cuda(args.gpu)
        readout_disc = readout_disc.cuda(args.gpu) if args.gamma < 1 else None
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        inference = torch.nn.DataParallel(inference).cuda()
        generator = torch.nn.DataParallel(generator).cuda()
        discriminator = torch.nn.DataParallel(discriminator).cuda()
        readout_disc = torch.nn.DataParallel(
            readout_disc).cuda() if args.gamma < 1 else None
        promote_attributes(inference)
        promote_attributes(generator)

    # ------ Build optimizers ------ #
    optimizerD = optim.Adam(discriminator.parameters(),
                            lr=args.lr_d,
                            betas=(args.beta1, args.beta2),
                            weight_decay=args.wd)
    # We want the lr to be slower for upper layers as they get more gradient flow
    optimizerG = optim.Adam(generator.parameters(),
                            lr=args.lr_g,
                            betas=(args.beta1, args.beta2),
                            weight_decay=args.wd)
    # Similarly for the encoder, lower layers should have slower lrs
    optimizerF = optim.Adam(inference.parameters(),
                            lr=args.lr_e,
                            betas=(args.beta1, args.beta2),
                            weight_decay=args.wd)
    optimizerRD = optim.Adam(readout_disc.parameters(),
                             lr=args.lr_rd,
                             betas=(args.beta1, args.beta2),
                             weight_decay=args.wd) if args.gamma < 1 else None

    # ------ Optionally resume from a checkpoint ------- #
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            inference.load_state_dict(checkpoint['inference_state_dict'])
            generator.load_state_dict(checkpoint['generator_state_dict'])
            discriminator.load_state_dict(
                checkpoint['discriminator_state_dict'])
            optimizerD.load_state_dict(checkpoint['optimizerD'])
            optimizerG.load_state_dict(checkpoint['optimizerG'])
            optimizerF.load_state_dict(checkpoint['optimizerF'])
            train_history = checkpoint['train_history']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        train_history = {
            'D_losses': [],
            'GF_losses': [],
            'ML_losses': [],
            'reconstruction_error': []
        }
        args.start_epoch = 0

    decoding_error_history = []
    reconstruction_history = []
    decoding_error_std_history = []
    reconstruction_std_history = []

    if args.detailed_logging:
        # How well can we decode from the layers before training?
        accuracies, reconstructions = decode_classes_from_layers(
            0 if args.gpu is None else args.gpu,
            inference,
            generator,
            image_size,
            args.n_filters,
            args.noise_dim,
            args.data,
            args.dataset,
            nonlinear=False,
            lr=1,
            folds=4,
            epochs=20,
            hidden_size=1000,
            wd=1e-3,
            opt='sgd',
            lr_schedule=True,
            verbose=False,
            batch_size=args.batch_size,
            workers=args.workers)
        print("Epoch {}".format(-1))
        for i in range(6):
            print("Layer{}: Accuracy {} +/- {}".format(
                i,
                accuracies.mean(dim=0)[i],
                accuracies.std(dim=0)[i]))
        decoding_error_history.append(accuracies.mean(dim=0).detach().cpu())
        reconstruction_history.append(
            reconstructions.mean(dim=0).detach().cpu())
        decoding_error_std_history.append(accuracies.std(dim=0).detach().cpu())
        reconstruction_std_history.append(
            reconstructions.std(dim=0).detach().cpu())

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rates(
            [optimizerF, optimizerD, optimizerG, optimizerRD], epoch, args,
            inference, generator, discriminator)
        if args.distributed:
            train_sampler.set_epoch(epoch)

        train(args, inference, generator, train_loader, discriminator,
              optimizerD, optimizerG, optimizerF, epoch, readout_disc,
              optimizerRD)

        generator.eval()
        inference.eval()

        if args.save_imgs:
            try:
                os.mkdir("gen_images")
            except:
                pass
            noise = torch.empty(100, args.noise_dim, 1, 1).normal_().cuda()
            to_visualize = generator(noise).detach().cpu()
            grid = utils.make_grid(to_visualize,
                                   nrow=10,
                                   padding=5,
                                   normalize=True,
                                   range=None,
                                   scale_each=False,
                                   pad_value=0)
            sv_img(grid, "gen_images/imgs_epoch{}.png".format(epoch), epoch)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if args.detailed_logging or (epoch == args.epochs - 1):
                # How well can we decode from the layers?
                accuracies, reconstructions = decode_classes_from_layers(
                    0 if args.gpu is None else args.gpu,
                    inference,
                    generator,
                    image_size,
                    args.n_filters,
                    args.noise_dim,
                    args.data,
                    args.dataset,
                    nonlinear=False,
                    lr=1,
                    folds=4,
                    epochs=20,
                    hidden_size=1000,
                    wd=1e-3,
                    opt='sgd',
                    lr_schedule=True,
                    verbose=False,
                    batch_size=args.batch_size,
                    workers=args.workers)
                print("Epoch {}".format(epoch))
                for i in range(6):
                    print("Layer{}: Accuracy {} +/- {}".format(
                        i,
                        accuracies.mean(dim=0)[i],
                        accuracies.std(dim=0)[i]))
                decoding_error_history.append(
                    accuracies.mean(dim=0).detach().cpu())
                reconstruction_history.append(
                    reconstructions.mean(dim=0).detach().cpu())
                decoding_error_std_history.append(
                    accuracies.std(dim=0).detach().cpu())
                reconstruction_std_history.append(
                    reconstructions.std(dim=0).detach().cpu())

            torch.save(
                {
                    'epoch': epoch + 1,
                    'inference_state_dict': inference.state_dict(),
                    'generator_state_dict': generator.state_dict(),
                    'readout_dict_state_dict': readout_disc.state_dict()
                    if args.gamma < 1 else None,
                    'discriminator_state_dict': discriminator.state_dict(),
                    'args': args,
                    'optimizerD': optimizerD.state_dict(),
                    'optimizerG': optimizerG.state_dict(),
                    'optimizerF': optimizerF.state_dict(),
                    'train_history': {
                        "decoding_error_history": decoding_error_history,
                        "reconstruction_history": reconstruction_history,
                        "decoding_error_std_history":
                        decoding_error_std_history,
                        "reconstruction_std_history":
                        reconstruction_std_history
                    }
                }, 'checkpoint.pth.tar')
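# --- Hedged usage sketch (not part of the original file) ---------------------
# main_worker(gpu, ngpus_per_node, args) has the signature expected by
# torch.multiprocessing.spawn, so a typical entry point probably resembles the
# sketch below; the actual argument parsing and launch logic live elsewhere in
# the repo, and the world_size bookkeeping shown here is an assumption.
def _example_launch(args):
    """Illustrative only: `args` is assumed to be the repo's parsed argparse namespace."""
    import torch.multiprocessing as mp
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # One process per GPU; world_size counts GPUs across all nodes
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker,
                 nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args))
    else:
        # Single-process path (optionally pinned to one GPU via args.gpu)
        main_worker(args.gpu, ngpus_per_node, args)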