def load_cortex(path, args):
    """Loads a cortex from path."""
    bn = False if args.loss_type == 'wasserstein' else True

    inference = Inference(args.noise_dim,
                          args.n_filters,
                          1 if args.dataset == 'mnist' else 3,
                          image_size=args.image_size,
                          bn=bn,
                          hard_norm=args.divisive_normalization,
                          spec_norm=args.spec_norm,
                          derelu=False)
    generator = Generator(args.noise_dim,
                          args.n_filters,
                          1 if args.dataset == 'mnist' else 3,
                          image_size=args.image_size,
                          hard_norm=args.divisive_normalization)

    if os.path.isfile(path):
        print("=> loading checkpoint '{}'".format(path))
        # load onto the CPU
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        inference.load_state_dict(checkpoint['inference_state_dict'])
        generator.load_state_dict(checkpoint['generator_state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            path, checkpoint['epoch']))
    else:
        raise IOError("=> no checkpoint found at '{}'".format(path))

    return inference, generator


def load_checkpoint(args):
    """Load the generator from the checkpoint stored under args.path."""
    path = os.path.join(args.path, 'checkpoint.pth.tar')

    generator = Generator(args.noise_dim,
                          args.n_filters,
                          1 if args.dataset == 'mnist' else 3,
                          image_size=args.image_size,
                          hard_norm=args.divisive_normalization)

    if os.path.isfile(path):
        print("=> loading checkpoint '{}'".format(path))
        # load onto the CPU
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        generator.load_state_dict(checkpoint['generator_state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            path, checkpoint['epoch']))
    else:
        raise IOError("=> no checkpoint found at '{}'".format(path))

    return generator
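
A minimal usage sketch for the two loaders above, assuming an argparse-style namespace; every field value here is illustrative, not taken from the original code:

import argparse

# Hypothetical settings; real values depend on how the cortex was trained.
args = argparse.Namespace(
    noise_dim=100, n_filters=64, dataset='mnist', image_size=64,
    loss_type='gan', divisive_normalization=False, spec_norm=False,
    path='runs/exp1')

inference, generator = load_cortex('runs/exp1/checkpoint.pth.tar', args)
generator_only = load_checkpoint(args)  # reads <args.path>/checkpoint.pth.tar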
Example 3
def morph_project_only(im1: np.ndarray,
                       im2: np.ndarray,
                       Generator: Generator,
                       Encoder: Encoder,
                       pix2pix: networks.UnetGenerator,
                       epsilon: float = 20.0,
                       L: int = 9,
                       dcgan_size: int = 64,
                       pix2pix_size: int = 128,
                       simulation_name: str = "image_interpolation",
                       results_path: str = "results") -> None:
    """Generates 3 morphing processes given two images. 
    The first is simple Wasserstein Barycenters, the second is our algorithm and 
    the third is a simple GAN latent space linear interpolation
    
    Arguments:
        im1 {np.ndarray} -- source image
        im2 {np.ndarray} -- destination image
        Generator {Generator} -- DCGAN generator (latent space to pixel space)
        Encoder {Encoder} -- DCGAN encoder (pixel space to latent space)
        pix2pix {networks.UnetGenerator} -- pix2pix model trained to increase image resolution
    
    Keyword Arguments:
        epsilon {float} -- entropic regularization parameter (default: {20.0})
        L {int} -- number of images in the transformation (default: {9})
        dcgan_size {int} -- DCGAN image size (low resolution) (default: {64})
        pix2pix_size {int} -- Pix2Pix image size (high resolution) (default: {128})
        simulation_name {str} -- name of the simulation. Affects the saved file names (default: {"image_interpolation"})
        results_path {str} -- the path to save the results in (default: {"results"})
    """
    img_size = im1.shape[:2]
    im1, im2 = (I.transpose(2, 0, 1).reshape(3, -1, 1) for I in (im1, im2))

    print("Preparing transportation cost matrix...")
    C = generate_metric(img_size)
    Q = np.concatenate([im1, im2], axis=-1)
    Q, max_val, Q_counts = preprocess_Q(Q)
    out_ours = []
    out_GAN = []
    out_OT = []

    print("Computing transportation plan...")
    for dim in range(3):
        print(f"Color space {dim+1}/3")
        out_OT.append([])
        P = sinkhorn(Q[dim, :, 0], Q[dim, :, 1], C, img_size[0], img_size[1],
                     epsilon)
        for t in tqdm(np.linspace(0, 1, L)):
            out_OT[-1].append(
                max_val -
                generate_interpolation(img_size[0], img_size[1], P, t) *
                ((1 - t) * Q_counts[dim, 0, 0] + t * Q_counts[dim, 0, 1]))
    out_OT = [np.stack(im_channels, axis=0) for im_channels in zip(*out_OT)]

    print("Computing GAN projections...")
    # Project OT results on GAN
    GAN_projections = [
        project_on_generator(Generator,
                             pix2pix,
                             I,
                             Encoder,
                             dcgan_img_size=dcgan_size,
                             pix2pix_img_size=pix2pix_size) for I in out_OT
    ]
    GAN_projections_images, GAN_projections_noises = zip(*GAN_projections)
    out_ours = GAN_projections_images

    # Linearly interpolate in the GAN's latent space
    noise1 = GAN_projections_noises[0].cuda()
    noise2 = GAN_projections_noises[-1].cuda()
    for t in np.linspace(0, 1, L):
        t = float(t)  # cast numpy scalar to a primitive float
        GAN_image = Generator((1 - t) * noise1 + t * noise2)
        GAN_image = F.interpolate(GAN_image, scale_factor=2, mode='bilinear')
        pix_outputs = pix2pix(GAN_image)
        GAN_image = utils.denorm(pix_outputs.detach()).cpu().numpy().reshape(
            3, -1, 1)
        out_GAN.append(GAN_image.clip(0, 1))

    # Save results:
    print("Saving results...")
    out_ours = torch.stack(
        [torch.Tensor(im).reshape(3, *img_size) for im in out_ours])
    out_OT = torch.stack(
        [torch.Tensor(im).reshape(3, *img_size) for im in out_OT])
    out_GAN = torch.stack(
        [torch.Tensor(im).reshape(3, *img_size) for im in out_GAN])
    os.makedirs(results_path, exist_ok=True)
    output_path = join(results_path, simulation_name + '.png')
    save_image(torch.cat([out_OT, out_ours, out_GAN], dim=0),
               output_path,
               nrow=L,
               normalize=False,
               scale_each=False,
               range=(0, 1))
    print(f"Image saved in {output_path}")
Example 4
def train_gan(latent_dim=100,
              num_filters=[1024, 512, 256, 128],
              batch_size=128,
              num_epochs=100,
              h5_file_path='shoes_images/shoes.hdf5',
              save_dir='networks/',
              train_log_dir='dcgan_log_dir',
              learning_rate=0.0002,
              betas=(0.5, 0.999)):
    # Models
    G = Generator(latent_dim, num_filters)
    D = Discriminator(num_filters[::-1])
    G.cuda()
    D.cuda()

    # Loss function
    criterion = torch.nn.BCELoss()

    # Optimizers
    G_optimizer = optim.Adam(G.parameters(),
                             lr=learning_rate,
                             betas=betas,
                             weight_decay=1e-5)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=learning_rate,
                             betas=betas,
                             weight_decay=1e-5)

    # Schedulers
    G_scheduler = optim.lr_scheduler.MultiStepLR(G_optimizer,
                                                 milestones=[25, 50, 75])
    D_scheduler = optim.lr_scheduler.MultiStepLR(D_optimizer,
                                                 milestones=[25, 50, 75])

    # loss arrays
    D_avg_losses = []
    G_avg_losses = []

    # Fixed noise for test
    num_test_samples = 6 * 6
    fixed_noise = torch.randn(num_test_samples, latent_dim, 1, 1).cuda()

    # Dataloader
    data_loader = dataloader.get_h5_dataset(path=h5_file_path,
                                            batch_size=batch_size)

    for epoch in range(num_epochs):
        D_epoch_losses = []
        G_epoch_losses = []

        for i, images in enumerate(data_loader):
            mini_batch = images.size()[0]
            x = images.cuda()

            y_real = torch.ones(mini_batch).cuda()
            y_fake = torch.zeros(mini_batch).cuda()

            # Train discriminator
            D_real_decision = D(x).squeeze()
            D_real_loss = criterion(D_real_decision, y_real)

            z = torch.randn(mini_batch, latent_dim, 1, 1)
            z = z.cuda()
            generated_images = G(z)

            D_fake_decision = D(generated_images).squeeze()
            D_fake_loss = criterion(D_fake_decision, y_fake)

            # Backprop
            D_loss = D_real_loss + D_fake_loss
            D.zero_grad()
            if i % 2 == 0:  # Update discriminator only once every 2 batches
                D_loss.backward()
                D_optimizer.step()

            # Train generator
            z = torch.randn(mini_batch, latent_dim, 1, 1)
            z = z.cuda()
            generated_images = G(z)

            D_fake_decision = D(generated_images).squeeze()
            G_loss = criterion(D_fake_decision, y_real)

            # Backprop Generator
            D.zero_grad()
            G.zero_grad()
            G_loss.backward()
            G_optimizer.step()

            # loss values
            D_epoch_losses.append(D_loss.item())
            G_epoch_losses.append(G_loss.item())

            print('Epoch [%d/%d], Step [%d/%d], D_loss: %.4f, G_loss: %.4f' %
                  (epoch + 1, num_epochs, i + 1, len(data_loader),
                   D_loss.item(), G_loss.item()))

        D_avg_loss = torch.mean(torch.FloatTensor(D_epoch_losses)).item()
        G_avg_loss = torch.mean(torch.FloatTensor(G_epoch_losses)).item()
        D_avg_losses.append(D_avg_loss)
        G_avg_losses.append(G_avg_loss)

        # Plots
        plot_loss(D_avg_losses,
                  G_avg_losses,
                  num_epochs,
                  log_dir=train_log_dir)

        G.eval()
        generated_images = G(fixed_noise).detach()
        generated_images = denorm(generated_images)
        G.train()
        plot_result(generated_images, epoch, log_dir=train_log_dir)

        # Save models
        torch.save(G.state_dict(), join(save_dir, 'generator'))
        torch.save(D.state_dict(), join(save_dir, 'discriminator'))

        # Decrease learning-rate
        G_scheduler.step()
        D_scheduler.step()
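
The denorm helper used here (and as utils.denorm earlier) is not shown in these excerpts; a plausible minimal version, assuming the usual DCGAN convention of images normalized to [-1, 1] (matching the Normalize((0.5, ...), (0.5, ...)) transforms in Example 8):

def denorm(x):
    # Invert Normalize((0.5, ...), (0.5, ...)): map [-1, 1] back to [0, 1].
    return (x * 0.5 + 0.5).clamp(0, 1)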
Example 5
def test_encoder(latent_dim=100,
                 num_filters=[1024, 512, 256, 128],
                 batch_size=128,
                 num_epochs=100,
                 h5_file_path='shoes_images/shoes.hdf5',
                 save_dir='networks/',
                 train_log_dir='dcgan_log_dir',
                 alpha=0.002):
    # load alexnet:
    alexnet = models.alexnet(pretrained=True).cuda()
    alexnet.eval()
    for param in alexnet.parameters():
        param.requires_grad = False

    G = Generator(latent_dim, num_filters).cuda()
    generator_path = join(save_dir, 'generator')
    G.load_state_dict(torch.load(generator_path))
    G.eval()
    for param in G.parameters():
        param.requires_grad = False

    E = Encoder(num_filters[::-1], latent_dim).cuda()
    encoder_path = join(save_dir, 'encoder')
    E.load_state_dict(torch.load(encoder_path))
    E.eval()
    for param in E.parameters():
        param.requires_grad = False

    # Dataloader
    data_loader = dataloader.get_h5_dataset(path=h5_file_path,
                                            batch_size=batch_size)

    def interpolate(x):
        return F.interpolate(x, scale_factor=4, mode='bilinear')

    images = next(iter(data_loader))
    x = images.cuda()
    x_features = alexnet.features(alexnet_norm(interpolate(denorm(x))))

    # Encode
    z = E(x)
    out_images = torch.stack((denorm(x), denorm(G(z))), dim=1)

    z.requires_grad_(True)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam([z], lr=1e-3)

    for _ in range(num_epochs):
        outputs = G(z)
        loss = criterion(x, outputs) + alpha * criterion(
            x_features,
            alexnet.features(alexnet_norm(interpolate(denorm(outputs)))))
        z.grad = None
        loss.backward()
        optimizer.step()
    out_images = torch.cat((out_images, denorm(G(z)).unsqueeze(1)), dim=1)

    nrow = out_images.shape[1]
    out_images = out_images.reshape(-1, *x.shape[1:])
    save_image(out_images,
               join(train_log_dir, 'encoder_images.png'),
               nrow=nrow,
               normalize=False,
               scale_each=False,
               range=(0, 1))
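
Likewise, alexnet_norm is not defined in these excerpts; a reasonable sketch, assuming it applies the standard ImageNet statistics that torchvision's pretrained AlexNet expects:

import torch

# ImageNet channel statistics expected by torchvision's pretrained models.
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()

def alexnet_norm(x):
    # x: batch of images in [0, 1], shape (B, 3, H, W).
    return (x - _IMAGENET_MEAN) / _IMAGENET_STD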
Example 6
def finetune_encoder_with_samples(latent_dim=100,
                                  num_filters=[1024, 512, 256, 128],
                                  batch_size=128,
                                  num_epochs=100,
                                  h5_file_path='shoes_images/shoes.hdf5',
                                  save_dir='networks/',
                                  train_log_dir='dcgan_log_dir',
                                  learning_rate=0.0002,
                                  betas=(0.5, 0.999),
                                  alpha=0.002):
    # load alexnet:
    alexnet = models.alexnet(pretrained=True).cuda()
    alexnet.eval()
    for param in alexnet.parameters():
        param.requires_grad = False

    # Load generator and fix weights
    G = Generator(latent_dim, num_filters).cuda()
    generator_path = join(save_dir, 'generator')
    G.load_state_dict(torch.load(generator_path))
    G.eval()
    for param in G.parameters():
        param.requires_grad = False

    # Load encoder
    E = Encoder(num_filters[::-1], latent_dim).cuda()
    encoder_path = join(save_dir, 'encoder')
    E.load_state_dict(torch.load(encoder_path))
    E.train()

    # Loss function
    criterion = torch.nn.MSELoss()

    # Optimizers
    E_optimizer = optim.Adam(E.parameters(),
                             lr=learning_rate,
                             betas=betas,
                             weight_decay=1e-5)

    E_avg_losses = []

    # Dataloader
    data_loader = dataloader.get_h5_dataset(path=h5_file_path,
                                            batch_size=batch_size)

    def interpolate(x):
        return F.interpolate(x, scale_factor=4, mode='bilinear')

    def get_features(x):
        return alexnet.features(alexnet_norm(interpolate(denorm(x))))

    for epoch in range(num_epochs):
        E_losses = []

        # minibatch training
        for i, images in enumerate(data_loader):

            # real images
            x = images.cuda()

            # Train Encoder
            out_images = G(E(x))
            E_loss = criterion(x, out_images) + alpha * criterion(
                get_features(x), get_features(out_images))

            # Backprop
            E.zero_grad()
            E_loss.backward()
            E_optimizer.step()

            # loss values
            E_losses.append(E_loss.item())

            print('Epoch [%d/%d], Step [%d/%d], E_loss: %.4f' %
                  (epoch + 1, num_epochs, i + 1, len(data_loader),
                   E_loss.item()))

        E_avg_loss = torch.mean(torch.FloatTensor(E_losses)).item()

        # avg loss values for plot
        E_avg_losses.append(E_avg_loss)

        plot_loss(E_avg_losses,
                  None,
                  num_epochs,
                  log_dir=train_log_dir,
                  model1='Encoder',
                  model2='')

        # Save models
        torch.save(E.state_dict(), join(save_dir, 'encoder'))
Example 7
def train_encoder_with_noise(latent_dim=100,
                             num_filters=[1024, 512, 256, 128],
                             batch_size=128,
                             num_epochs=100,
                             h5_file_path='shoes_images/shoes.hdf5',
                             save_dir='networks/',
                             train_log_dir='dcgan_log_dir',
                             learning_rate=0.0002,
                             betas=(0.5, 0.999)):
    # Load generator and fix weights
    G = Generator(latent_dim, num_filters).cuda()
    generator_path = join(save_dir, 'generator')
    G.load_state_dict(torch.load(generator_path))
    G.eval()
    for param in G.parameters():
        param.requires_grad = False

    E = Encoder(num_filters[::-1], latent_dim)
    E.cuda()

    # Loss function
    criterion = torch.nn.MSELoss()

    # Optimizer
    E_optimizer = optim.Adam(E.parameters(),
                             lr=learning_rate,
                             betas=betas,
                             weight_decay=1e-5)

    E_avg_losses = []

    # Dataloader
    data_loader = dataloader.get_h5_dataset(path=h5_file_path,
                                            batch_size=batch_size)

    for epoch in range(num_epochs):
        E_losses = []

        # minibatch training
        for i, images in enumerate(data_loader):

            # generate_noise
            z = torch.randn(images.shape[0], latent_dim, 1, 1).cuda()
            x = G(z)

            # Train Encoder
            out_latent = E(x)
            E_loss = criterion(z, out_latent)

            # Back propagation
            E.zero_grad()
            E_loss.backward()
            E_optimizer.step()

            # loss values
            E_losses.append(E_loss.item())

            print('Epoch [%d/%d], Step [%d/%d], E_loss: %.4f' %
                  (epoch + 1, num_epochs, i + 1, len(data_loader),
                   E_loss.item()))

        E_avg_loss = torch.mean(torch.FloatTensor(E_losses)).item()

        # avg loss values for plot
        E_avg_losses.append(E_avg_loss)

        plot_loss(E_avg_losses,
                  None,
                  num_epochs,
                  log_dir=train_log_dir,
                  model1='Encoder',
                  model2='')

        # Save models
        torch.save(E.state_dict(), join(save_dir, 'encoder'))
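
Taken together, Examples 4-7 suggest a training pipeline; the ordering below is inferred from which checkpoints each function loads and saves, not something stated in the code:

train_gan()                        # trains G and D; saves 'networks/generator'
train_encoder_with_noise()         # trains E to invert G on sampled z; saves 'networks/encoder'
finetune_encoder_with_samples()    # refines E on real images with a perceptual loss
test_encoder()                     # qualitative check: x vs. G(E(x)) vs. optimized G(z)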
Example 8
def main_worker(gpu, ngpus_per_node, args):
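    # These flags arrive as the strings "True"/"False" (e.g. from argparse);
    # convert them to real booleans before use.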
    args.scale_gen_surprisal_by_D = args.scale_gen_surprisal_by_D == "True"
    args.prioritized_replay = args.prioritized_replay == "True"
    args.divisive_normalization = args.divisive_normalization == "True"
    args.spectral_norm = args.spectral_norm == "True"

    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # ----- Get dataset ------ #

    image_size = args.image_size

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if args.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize,
            ]))
        nc = 3
    elif args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=args.data,
                                         download=True,
                                         transform=transforms.Compose([
                                             transforms.Resize(image_size),
                                             transforms.ToTensor(),
                                             transforms.Normalize(
                                                 (0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5)),
                                         ]))
        nc = 3

    elif args.dataset == 'mnist':
        train_dataset = datasets.MNIST(root=args.data,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.Resize(image_size),
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.5, ),
                                                                (0.5, )),
                                       ]))
        nc = 1

    else:
        raise ValueError("unsupported dataset: '{}'".format(args.dataset))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # ----- Create models ------ #

    inference = Inference(args.noise_dim,
                          args.n_filters,
                          nc,
                          image_size=image_size,
                          noise_before=False,
                          hard_norm=args.divisive_normalization,
                          spec_norm=args.spectral_norm)
    generator = Generator(args.noise_dim,
                          args.n_filters,
                          nc,
                          image_size=image_size,
                          hard_norm=args.divisive_normalization)

    discriminator = Discriminator(args.noise_dim,
                                  args.n_filters,
                                  nc,
                                  image_size=image_size,
                                  hard_norm=args.divisive_normalization,
                                  hidden_dim=128)

    readout_disc = (ReadoutDiscriminator(args.n_filters,
                                         image_size,
                                         spec_norm=args.spectral_norm)
                    if args.gamma < 1 else None)

    # get to proper GPU
    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            inference.cuda(args.gpu)
            generator.cuda(args.gpu)
            if args.gamma < 1:
                readout_disc.cuda(args.gpu)
            discriminator.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            inference = torch.nn.parallel.DistributedDataParallel(
                inference, device_ids=[args.gpu], broadcast_buffers=False)
            generator = torch.nn.parallel.DistributedDataParallel(
                generator, device_ids=[args.gpu], broadcast_buffers=False)
            if args.gamma < 1:
                readout_disc = torch.nn.parallel.DistributedDataParallel(
                    readout_disc,
                    device_ids=[args.gpu],
                    broadcast_buffers=False)
            discriminator = torch.nn.parallel.DistributedDataParallel(
                discriminator, device_ids=[args.gpu], broadcast_buffers=False)
        else:
            inference.cuda()
            generator.cuda()
            if args.gamma < 1:
                readout_disc.cuda()
            discriminator.cuda()

            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            generator = torch.nn.parallel.DistributedDataParallel(generator)
            inference = torch.nn.parallel.DistributedDataParallel(inference)
            readout_disc = torch.nn.parallel.DistributedDataParallel(
                readout_disc) if args.gamma < 1 else None
            discriminator = torch.nn.parallel.DistributedDataParallel(
                discriminator)

        # give intermediate state
        promote_attributes(inference)
        promote_attributes(generator)

    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        inference = inference.cuda(args.gpu)
        discriminator = discriminator.cuda(args.gpu)
        generator = generator.cuda(args.gpu)
        readout_disc = readout_disc.cuda(args.gpu) if args.gamma < 1 else None
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        inference = torch.nn.DataParallel(inference).cuda()
        generator = torch.nn.DataParallel(generator).cuda()
        discriminator = torch.nn.DataParallel(discriminator).cuda()
        readout_disc = torch.nn.DataParallel(
            readout_disc).cuda() if args.gamma < 1 else None

        promote_attributes(inference)
        promote_attributes(generator)

    # ------ Build optimizer ------ #
    optimizerD = optim.Adam(discriminator.parameters(),
                            lr=args.lr_d,
                            betas=(args.beta1, args.beta2),
                            weight_decay=args.wd)
    # we want the lr to be slower for upper layers as they get more gradient flow
    optimizerG = optim.Adam(generator.parameters(),
                            lr=args.lr_g,
                            betas=(args.beta1, args.beta2),
                            weight_decay=args.wd)

    # similarly, for the encoder the lower layers should have slower lrs
    optimizerF = optim.Adam(inference.parameters(),
                            lr=args.lr_e,
                            betas=(args.beta1, args.beta2),
                            weight_decay=args.wd)

    optimizerRD = (optim.Adam(readout_disc.parameters(),
                              lr=args.lr_rd,
                              betas=(args.beta1, args.beta2),
                              weight_decay=args.wd)
                   if args.gamma < 1 else None)

    # ------ optionally resume from a checkpoint ------- #
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']

            inference.load_state_dict(checkpoint['inference_state_dict'])
            generator.load_state_dict(checkpoint['generator_state_dict'])

            discriminator.load_state_dict(
                checkpoint['discriminator_state_dict'])
            optimizerD.load_state_dict(checkpoint['optimizerD'])
            optimizerG.load_state_dict(checkpoint['optimizerG'])
            optimizerF.load_state_dict(checkpoint['optimizerF'])
            train_history = checkpoint['train_history']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        train_history = {
            'D_losses': [],
            'GF_losses': [],
            'ML_losses': [],
            'reconstruction_error': []
        }
        args.start_epoch = 0

    decoding_error_history = []
    reconstruction_history = []
    decoding_error_std_history = []
    reconstruction_std_history = []

    if args.detailed_logging:
        # how well can we decode from layers?
        accuracies, reconstructions = decode_classes_from_layers(
            0 if args.gpu is None else args.gpu,
            inference,
            generator,
            image_size,
            args.n_filters,
            args.noise_dim,
            args.data,
            args.dataset,
            nonlinear=False,
            lr=1,
            folds=4,
            epochs=20,
            hidden_size=1000,
            wd=1e-3,
            opt='sgd',
            lr_schedule=True,
            verbose=False,
            batch_size=args.batch_size,
            workers=args.workers)
        print("Epoch {}".format(-1))
        for i in range(6):
            print("Layer{}: Accuracy {} +/- {}".format(
                i,
                accuracies.mean(dim=0)[i],
                accuracies.std(dim=0)[i]))
        decoding_error_history.append(accuracies.mean(dim=0).detach().cpu())
        reconstruction_history.append(
            reconstructions.mean(dim=0).detach().cpu())
        decoding_error_std_history.append(accuracies.std(dim=0).detach().cpu())
        reconstruction_std_history.append(
            reconstructions.std(dim=0).detach().cpu())

    for epoch in range(args.start_epoch, args.epochs):

        adjust_learning_rates(
            [optimizerF, optimizerD, optimizerG, optimizerRD], epoch, args,
            inference, generator, discriminator)

        if args.distributed:
            train_sampler.set_epoch(epoch)

        train(args, inference, generator, train_loader, discriminator,
              optimizerD, optimizerG, optimizerF, epoch, readout_disc,
              optimizerRD)
        generator.eval()
        inference.eval()

        if args.save_imgs:
            os.makedirs("gen_images", exist_ok=True)
            noise = torch.empty(100, args.noise_dim, 1, 1).normal_().cuda()
            to_visualize = generator(noise).detach().cpu()
            grid = utils.make_grid(to_visualize,
                                   nrow=10,
                                   padding=5,
                                   normalize=True,
                                   range=None,
                                   scale_each=False,
                                   pad_value=0)
            sv_img(grid, "gen_images/imgs_epoch{}.png".format(epoch), epoch)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if args.detailed_logging or (epoch == args.epochs - 1):
                # how well can we decode from layers?
                accuracies, reconstructions = decode_classes_from_layers(
                    0 if args.gpu is None else args.gpu,
                    inference,
                    generator,
                    image_size,
                    args.n_filters,
                    args.noise_dim,
                    args.data,
                    args.dataset,
                    nonlinear=False,
                    lr=1,
                    folds=4,
                    epochs=20,
                    hidden_size=1000,
                    wd=1e-3,
                    opt='sgd',
                    lr_schedule=True,
                    verbose=False,
                    batch_size=args.batch_size,
                    workers=args.workers)
                print("Epoch {}".format(epoch))
                for i in range(6):
                    print("Layer{}: Accuracy {} +/- {}".format(
                        i,
                        accuracies.mean(dim=0)[i],
                        accuracies.std(dim=0)[i]))
                decoding_error_history.append(
                    accuracies.mean(dim=0).detach().cpu())
                reconstruction_history.append(
                    reconstructions.mean(dim=0).detach().cpu())
                decoding_error_std_history.append(
                    accuracies.std(dim=0).detach().cpu())
                reconstruction_std_history.append(
                    reconstructions.std(dim=0).detach().cpu())

            torch.save(
                {
                    'epoch': epoch + 1,
                    'inference_state_dict': inference.state_dict(),
                    'generator_state_dict': generator.state_dict(),
                    'readout_dict_state_dict':
                        readout_disc.state_dict() if args.gamma < 1 else None,
                    'discriminator_state_dict': discriminator.state_dict(),
                    'args': args,
                    'optimizerD': optimizerD.state_dict(),
                    'optimizerG': optimizerG.state_dict(),
                    'optimizerF': optimizerF.state_dict(),
                    'train_history': {
                        'decoding_error_history': decoding_error_history,
                        'reconstruction_history': reconstruction_history,
                        'decoding_error_std_history': decoding_error_std_history,
                        'reconstruction_std_history': reconstruction_std_history,
                    },
                }, 'checkpoint.pth.tar')
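
main_worker follows the standard PyTorch distributed-training worker pattern (one process per GPU, rank passed as the first argument), so it is presumably launched with torch.multiprocessing.spawn. A hedged sketch of a typical launcher; parse_args and the world-size arithmetic are assumptions modeled on PyTorch's imagenet example, not shown in the excerpt:

import torch
import torch.multiprocessing as mp

if __name__ == '__main__':
    args = parse_args()  # hypothetical: builds the argparse namespace used above
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # One process per GPU; each process receives its gpu index as `gpu`.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)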