Example #1
def MNIST_GAN(epoch, epoch_start, G, G_optim, D, D_optim, dataloader, train_fn,
              history, best_stat, times, config, z_fixed):
    '''
        MNIST dataset training loop for the GAN model. Used to train the GAN as
        a proof-of-concept, i.e. to show that the linear GAN model is able to
        reproduce MNIST data and is therefore (possibly) suited to reproducing
        the LArCV1 dataset. Hopefully this will extend to other datasets as well.
        - Args: epoch (int): current training epoch
                epoch_start (float): start time of the current epoch
                G (Torch model): Generator model
                G_optim (function): G optimizer (either adam or sgd)
                D (Torch model): Discriminator model
                D_optim (function): D optimizer (either adam or sgd)
                dataloader (iterable): Torch dataloader object wrapped as a
                                       tqdm progress bar for terminal output
                train_fn (function): GAN training function selected in train.py
                history, best_stat, times, config (dicts): logging and
                                                           configuration dicts
                z_fixed (Torch tensor): Fixed vector for sampling G at the
                                        end of a training epoch
        - Returns: updated history, best_stat, and times dicts
    '''
    for itr, (x, _) in enumerate(dataloader):
        tr_loop_start = time.time()

        metrics = train_fn(x)
        history, best_stat, best = utils.train_logger(history, best_stat,
                                                      metrics)

        # Save checkpoint periodically
        if (itr % 2000 == 0):
            # G Checkpoint
            chkpt_G = utils.get_checkpoint(itr, epoch, G, G_optim)
            utils.save_checkpoint(chkpt_G, best, 'G', config['weights_save'])

            # D Checkpoint
            chkpt_D = utils.get_checkpoint(itr, epoch, D, D_optim)
            utils.save_checkpoint(chkpt_D, best, 'D', config['weights_save'])

        # Save Generator output periodically
        if (itr % 1000 == 0):
            z_rand = torch.randn(config['sample_size'],
                                 config['z_dim']).to(config['gpu'])
            sample = G(z_rand).view(-1, 1, config['dataset'],
                                    config['dataset'])
            utils.save_sample(sample, epoch, itr, config['random_samples'])

        # Log the time at the end of training loop
        times['tr_loop_times'].append(time.time() - tr_loop_start)

    # Log the time at the end of the training epoch
    times['epoch_times'].append(time.time() - epoch_start)

    # Save Generator output using fixed vector at end of epoch
    sample = G(z_fixed).view(-1, 1, config['dataset'], config['dataset'])
    utils.save_sample(sample, epoch, itr, config['fixed_samples'])

    return history, best_stat, times
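For reference, the train_fn argument above is expected to be a closure that performs one GAN update on the batch x and returns a metrics dict for utils.train_logger. The sketch below is a minimal, hypothetical version of such a closure, assuming a standard non-saturating BCE objective, a discriminator that outputs one logit per sample, and the G, D, G_optim, D_optim, and config objects from the surrounding scope; it is not the selection logic actually used in train.py.

# Hypothetical sketch of a GAN train_fn compatible with "metrics = train_fn(x)"
# in MNIST_GAN above; illustrative only, not the repository's implementation.
import torch
import torch.nn as nn

def make_gan_train_fn(G, D, G_optim, D_optim, config):
    # Assumes D maps a flattened image to a single logit per sample
    bce = nn.BCEWithLogitsLoss().to(config['gpu'])

    def train_fn(x):
        G.train(); D.train()
        x = x.view(x.size(0), -1).to(config['gpu'])   # flatten for a linear GAN
        real = torch.ones(x.size(0), 1, device=x.device)
        fake = torch.zeros(x.size(0), 1, device=x.device)

        # Discriminator update: real batch vs. detached generated batch
        D_optim.zero_grad()
        z = torch.randn(x.size(0), config['z_dim'], device=x.device)
        d_loss = bce(D(x), real) + bce(D(G(z).detach()), fake)
        d_loss.backward()
        D_optim.step()

        # Generator update: non-saturating loss (label generated samples as real)
        G_optim.zero_grad()
        z = torch.randn(x.size(0), config['z_dim'], device=x.device)
        g_loss = bce(D(G(z)), real)
        g_loss.backward()
        G_optim.step()

        return {'g_loss': g_loss.item(), 'd_loss': d_loss.item()}

    return train_fn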
Example #2
    def train(x, itr, epoch):
        '''
            AE training function
            Does: performs one training step of the AE model on a batch
            Args: x (Torch tensor): Real data input image
                  itr, epoch (ints): current iteration and epoch (used when
                                     saving periodic samples)
            Returns: dict of training metrics
        '''
        # Make sure model is in training mode
        AE.train()

        # Move input to gpu -- no need to flatten
        x = x.to(config['gpu'])

        # Forward pass
        output = AE(x)

        # Compare output to real data
        loss = loss_fn(output, x)

        # Backprop and update weights
        AE_optim.zero_grad()
        loss.backward()
        AE_optim.step()

        # TODO: check that the samples are saving correctly
        # Save output periodically - concatenate the model outputs with the
        # images it was supposed to reconstruct in order to visualize the
        # model evolution during training.
        if itr % 20 == 0:
            # Arrange training data and model outputs on
            # alternating rows for easy visual comparison.
            row1 = x[0:config['sample_size'] // 2, :]
            row2 = output[0:config['sample_size'] // 2, :]
            row3 = x[config['sample_size'] // 2:config['sample_size'], :]
            row4 = output[config['sample_size'] // 2:config['sample_size'], :]
            sample = torch.cat([row1, row2, row3, row4])
            sample = sample.view(sample.size(0), 1, config['dataset'],
                                 config['dataset'])
            utils.save_sample(sample, epoch, itr, config['random_samples'])

        # Return training metrics
        metrics = {'ae_loss': float(loss.item())}

        return metrics
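The train closure above relies on AE, AE_optim, loss_fn, config, and utils being defined in its enclosing scope. A minimal sketch of that surrounding setup is given below; SimpleAE, the layer sizes, the optimizer settings, and the toy config values are illustrative assumptions rather than the repository's model definitions.

# Hypothetical enclosing setup for the AE train() closure above; illustrative only.
import torch
import torch.nn as nn

class SimpleAE(nn.Module):
    '''Toy fully-connected autoencoder standing in for the project's AE model.'''
    def __init__(self, n_pix, code_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(n_pix, code_dim), nn.ReLU())
        self.decoder = nn.Sequential(nn.Linear(code_dim, n_pix), nn.Tanh())

    def forward(self, x):
        # Reconstruct, then restore the input's original shape
        return self.decoder(self.encoder(x)).view_as(x)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
config = {'gpu': device, 'dataset': 28, 'sample_size': 16,
          'random_samples': './samples/'}          # illustrative values only
AE = SimpleAE(config['dataset'] ** 2).to(device)
AE_optim = torch.optim.Adam(AE.parameters(), lr=1e-3)
loss_fn = nn.MSELoss().to(device)

# The closure is then driven over a dataloader, e.g.:
# for itr, (x, _) in enumerate(dataloader):
#     metrics = train(x, itr, epoch)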
Example #3
def train_mlp(G_mlp, opt_G, psi, opt_psi, dataloader, device, loader_kwargs,
              config, times, losses, hist_dict, checkpt):

    if config['mnist']:
        import poc
        poc.train_MNIST(G_mlp, opt_G, psi, opt_psi, dataloader, device,
                        loader_kwargs, config, times, losses, hist_dict,
                        checkpt)
        return
    # Set fixed noise vector for testing
    z_fixed = utils.get_z(config, device, sample=False)
    z_fixed.resize_((config['batch_size'], config['z_dim']))

    # Create python version of cpp operation
    if config['dist'] == 'W1':
        from torch.utils.cpp_extension import load
        my_ops = load(name="my_ops",
                      sources=[
                          "W1_extension/my_ops.cpp",
                          "W1_extension/my_ops_kernel.cu"
                      ],
                      verbose=False)
        import my_ops
    n_dim = len(dataloader)
    # Training loop
    for epoch in range(config['num_epochs']):

        epoch_start_time = time.time()

        # Save list of losses for end-training determination
        loss_memory = []

        # Set up memory tensors: simple feed-forward distribution, transfer plan
        mu = torch.zeros(config['mem_size'], config['batch_size'],
                         config['z_dim'])
        transfer = torch.zeros(config['mem_size'],
                               config['batch_size'],
                               dtype=torch.long)
        mem_idx = 0

        # Compute Optimal Transport Solver (OTS) over every training example
        ot_start = time.time()

        # for ots_iter in range(0, config['dset_size']):
        for iter in range(0, loader_kwargs['batch_size']):

            opt_psi.zero_grad()

            # Generate samples from feed-forward distribution
            z_batch = utils.get_z(config, device, sample=False)
            z_batch.resize_((config['batch_size'], config['z_dim']))
            y_fake = G_mlp(z_batch)  # [B, n_dim]

            # Compute cost between sample batch and target distribution
            if (config['dist'] == 'W1'):
                score = -my_ops.l1_t(y_fake, dataloader) - psi
            else:
                score = torch.matmul(
                    y_fake, dataloader.t()) - psi  # score: [B, N], psi: [N]

            # phi, hit = torch.max(score, 1)
            phi, hit = torch.min(score, 1)  # [B], [B]

            # Wasserstein distance computation: d(x,y)^p
            if (config['dist'] == 'W1'):
                loss_primal = torch.mean(
                    torch.abs(y_fake - dataloader[hit])) * config['out_dim']
            else:
                loss_primal = torch.mean(
                    (y_fake - dataloader[hit])**2) * config['out_dim']

            # Loss computation
            # loss = -(torch.mean(phi) + torch.mean(psi)) # Testing this
            loss = -torch.mean(psi[hit])  # equiv. to loss?

            # Backprop
            loss.backward()  # Gradient ascent
            opt_psi.step()

            # Append losses to dict
            losses['ot_loss'].append(loss.item())
            losses['w2_estim'].append(loss_primal.item())

            # Update memory tensors
            mu[mem_idx] = z_batch
            transfer[mem_idx] = hit
            mem_idx = (mem_idx + 1) % config['mem_size']

            if (iter % 500 == 0):
                print('OTS Iteration {} | Epoch {}'.format(iter, epoch))
            if (iter % (config['num_epochs'] * 10) == 0):
                # Display histogram stats
                hist_dict, stop = utils.update_histogram(
                    transfer, n_dim, epoch, iter, config, losses, hist_dict)
                # Empirical stopping criterion
                if stop:
                    break

        # Compute OTS time and append
        ot_end = time.time()
        times['ot_time'].append(ot_end - ot_start)

        # Compute Fitting Optimal Transport Plan (FIT)
        fit_start = time.time()
        for fit_iter in range(config['mem_size']):

            opt_G.zero_grad()

            # Get stored batch of generated samples
            z_batch = mu[fit_iter].to(device)
            y_fake = G_mlp(z_batch)  # G'(z)

            # Get Transfer plan from OTS: T(G_{t-1}(z))
            y0_hit = dataloader[transfer[fit_iter].to(device)]

            # Compute Wasserstein distance between G and T
            if (config['dist'] == 'W1'):
                loss_g = torch.mean(
                    torch.abs(y0_hit - y_fake)) * config['out_dim']
            else:
                loss_g = torch.mean((y0_hit - y_fake)**2) * config['out_dim']

            # Backprop
            loss_g.backward()  # Gradient descent
            opt_G.step()

            # Append losses to dict
            losses['g_loss'].append(loss_g.item())
            loss_memory.append(loss_g.item())

            if (fit_iter % 500 == 0):
                print(
                    'Fit_iter: {} | Epoch: {} | Loss: {:.2f} | Best Loss: {:.2f}'
                    .format(fit_iter, epoch, loss_g, checkpt['best']))

            # Check if best loss value and save checkpoint
            threshold = (checkpt['best'] - round(checkpt['best'] * 0.5))
            best = (loss_g.item() < threshold)
            if best:
                checkpt['best'] = loss_g.item()
                chkpt_dict = utils.checkpoint_dict(fit_iter, epoch, G_mlp,
                                                   opt_G)
                utils.save_checkpoint(chkpt_dict, best, epoch, -1,
                                      config['weights_root'])

            # Save periodic checkpoint
            if (fit_iter % (config['num_epochs'] * 5) == 0):
                chkpt_dict = utils.checkpoint_dict(fit_iter, epoch, G_mlp,
                                                   opt_G)
                utils.save_checkpoint(chkpt_dict, False, epoch, fit_iter,
                                      config['weights_root'])

            # Get random sample from G
            if (fit_iter % (config['num_epochs']) == 0):
                z_rand = utils.get_z(config, device, sample=True)
                z_rand.resize_((config['batch_size'], config['z_dim']))
                sample = G_mlp(z_rand).view(-1, 1, config['imsize'],
                                            config['imsize'])
                utils.save_sample(sample, epoch, fit_iter, config['random'])

            # Check if loss is changing - stop training if no change
            if (len(loss_memory) > (config['mem_size'] // 2)):
                if ((loss_g <= (mean(loss_memory)*.999)) and \
                    (loss_g >= (mean(loss_memory)*.995))):
                    break

        # Compute FIT time
        fit_end = time.time()
        times['fit_time'].append(fit_end - fit_start)

        # Compute epoch time
        times['epoch_times'].append(time.time() - epoch_start_time)

        # Output to terminal
        print('Best loss:  {}'.format(checkpt['best']))
        print('Epoch_time: {}'.format(time.time() - epoch_start_time))
        print('Epoch num:  {}'.format(epoch))
        print("FIT loss: {:.2f}".format(np.mean(losses['g_loss'])))

        # Save fixed sample at end of training epoch
        sample = G_mlp(z_fixed).view(-1, 1, config['imsize'], config['imsize'])
        utils.save_sample(sample, epoch, 0, config['fixed'])

    return [times, losses, hist_dict]
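The OTS phase above is a stochastic dual ascent for semi-discrete optimal transport: psi holds one dual potential per real sample, each generated batch is matched to its best-scoring real samples (hit), and psi is updated through -mean(psi[hit]). The toy sketch below isolates that mechanism on random stand-in data, following the matmul-score variant; the shapes, learning rate, and iteration count are illustrative assumptions, not values from the repository.

# Self-contained toy illustration of the OTS dual-ascent step; not repository code.
import torch

N, B, out_dim = 1000, 64, 784
real = torch.randn(N, out_dim)              # stand-in for the flattened dataset tensor
psi = torch.zeros(N, requires_grad=True)    # one dual potential per real sample
opt_psi = torch.optim.Adam([psi], lr=1e-1)

for it in range(200):
    opt_psi.zero_grad()
    y_fake = torch.randn(B, out_dim)                 # stand-in for G_mlp(z_batch)
    score = torch.matmul(y_fake, real.t()) - psi     # [B, N]
    phi, hit = torch.max(score, 1)                   # matched real index per fake sample
    loss = -torch.mean(psi[hit])                     # ascend psi on the matched entries
    loss.backward()
    opt_psi.step()

# The hit indices play the role of the transfer plan stored in the memory tensors above.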
Example #4
def train(config):
    '''
        Training function for EWM generator model.
    '''
    # Create python version of cpp operation
    # (Credit: Chen, arXiv:1906.03471, GitHub: https://github.com/chen0706/EWM)
    from torch.utils.cpp_extension import load
    my_ops = load(name = "my_ops",
                  sources = ["W1_extension/my_ops.cpp",
                             "W1_extension/my_ops_kernel.cu"],
                  verbose = False)
    import my_ops

    # Set up GPU device ordinal
    device = torch.device(config['gpu'])

    # Get model kwargs for convolutional generator
    config['model'] = 'ewm_conv'
    emw_kwargs = setup_model.ewm_kwargs(config)

    # Setup convolutional generator model on GPU
    G = ewm.ewm_convG(**emw_kwargs).to(device)
#     ewm.weights_init(G)
    G.train()
    
    # Setup model optimizer
    model_params = {'g_params': G.parameters()}
    G_optim = utils.get_optim(config, model_params)
    
    # Testing: MSE loss for conv_generator reconstruction
    loss_fn = nn.MSELoss().to(device)
    
    # Print G -- make sure it's right before continuing
    print(G)
    input('Press any key to continue')
    
    # Setup source of structured noise on GPU (Trained EWM_MLP_Generator Model)
    # Add these configs to the experiment source file
    if not config['ewm_root']:
        raise Exception('Path to trained EWM_MLP Model must be specified')

    # Print the list of trained EWM_MLP models available for selection
    EWM_paths = []
    EWM_root = config['ewm_root']
    for path in os.listdir(EWM_root):
        EWM_paths.append(os.path.join(EWM_root, path))
    print("-"*60)
    for i in range(len(EWM_paths)):
        EWM_name = EWM_paths[i].split('/')[-1]
        print("\n Exp_{}:".format(str(i)), EWM_name, '\n')
        print("-"*60)
    
    # Select the trained model
    model_num = input('Select EWM_MLP model (enter integer): ')
    EWM_dir = EWM_paths[int(model_num)]
    
    # Create the full path to the EWM model
    EWM_path = os.path.join(EWM_root, EWM_dir) + '/'
    print("Path to EWM Generator Model set as: \n{}".format(EWM_path))
    
    config_csv = EWM_path + "config.csv"
    config_df = pd.read_csv(config_csv, delimiter = ",")

    # Get the model architecture from config df
    n_layers = int(config_df[config_df['Unnamed: 0'].str.contains("n_layers")==True]['0'].values.item())
    n_hidden = int(config_df[config_df['Unnamed: 0'].str.contains("n_hidden")==True]['0'].values.item())
    l_dim    = int(config_df[config_df['Unnamed: 0'].str.contains("l_dim")==True]['0'].values.item())
    im_size  = int(config_df[config_df['Unnamed: 0'].str.contains("dataset")==True]['0'].values.item())
    z_dim    = int(config_df[config_df['Unnamed: 0'].str.contains("z_dim")==True]['0'].values.item())

    # Model kwargs
    ewm_kwargs = {'z_dim': z_dim, 'fc_sizes': [n_hidden]*n_layers, 'n_out': l_dim}

    # Send model to GPU
    Gz = ewm.ewm_G(**ewm_kwargs).to(device)
 
    # Load the model checkpoint
    # Get checkpoint name(s)
    EWM_checkpoint_path  = EWM_path + 'weights/'
    EWM_checkpoint_names = []
    for file in os.listdir(EWM_checkpoint_path):
        EWM_checkpoint_names.append(os.path.join(EWM_checkpoint_path, file))
    print("-"*60)
    for i in range(len(EWM_checkpoint_names)):
        name = EWM_checkpoint_names[i].split('/')[-1]
        print("\n {} :".format(str(i)), name, '\n')
        print("-"*60)
    file_num = input("Select a checkpoint file for EWM_MLP (enter integer): ")
    EWM_checkpoint = EWM_checkpoint_names[int(file_num)]

    # Load the model checkpoint
    # Keys: ['state_dict', 'epoch', 'optimizer']
    checkpoint = torch.load(EWM_checkpoint)

    # Load the model's state dictionary
    Gz.load_state_dict(checkpoint['state_dict'])

    # Use the code_vector model in evaluation mode -- no need for gradients here
    Gz.eval()
    print(Gz)
    input('Press any key to continue')

    # Set up full_dataloader (single batch)
    dataloader = utils.get_dataloader(config).to(device) # Full Dataloader
    dset_size  = len(dataloader)

    # Flatten the dataloader into a Tensor of shape [dset_size, l_dim]
#     dataloader = dataloader.view(dset_size, -1).to(device)

    # Set up psi optimizer
    psi = torch.zeros(dset_size, device=device, requires_grad=True)
    psi_optim = torch.optim.Adam([psi], lr=config['psi_lr'])

    # Set up directories for saving training stats and outputs
    config = utils.directories(config)

    # Set up dict for saving checkpoints
    checkpoint_kwargs = {'G':G, 'G_optim':G_optim}

    # Set up stats logging
    hist_dict = {'hist_min':[], 'hist_max':[], 'ot_loss':[]}
    losses    = {'ot_loss': [], 'fit_loss': []}
    history   = {'dset_size': dset_size, 'epoch': 0, 'iter': 0,
                 'losses'   : losses, 'hist_dict': hist_dict}
    config['early_end'] = (200, 320) # Empirical stopping criterion from EWM author

    stop_counter = 0

    # Compute how the input vectors need to be reshaped, based on conv_G input layer
    in_f = G.main[0][2].in_channels
    out_f = G.main[0][2].out_channels
    
    # Set the height and width of the feature maps. Note: manually setting this to 8 is
    # hackish, but computing the actual value requires knowing the number of layers in
    # the AutoEncoder whose code layer was used to train the generator model being loaded
    # here. I'm avoiding loading multiple paths and dataframes by simply setting it to 8,
    # but maybe you can do better than I did (see the helper sketch after this example).
    H = W = 8
    print('Vectors will be reshaped as: [{}] --> [{},{},{}]'.format(l_dim, in_f, H, W))
    
    # Set a fixed feature tensor for testing
    noise  = torch.randn(config['sample_size'], config['z_dim']).to(device)
    z_fixed = Gz(noise).view(-1, in_f, H, W).to(device)

    # Training Loop
    input('\nPress any key to launch -- good luck out there\n')

    # Set up progress bar for terminal output and enumeration
    epoch_bar  = tqdm([i for i in range(config['num_epochs'])])
    for epoch, _ in enumerate(epoch_bar):

        history['epoch'] = epoch

        # Set up memory lists:
        #     - mu: simple feed-forward distribution
        #     - transfer: transfer plan given by lists of indices
        # Rule of thumb: do not save the tensors themselves; instead, save the
        #                data as a list and convert it to a tensor as needed.
        mu = [0] * config['mem_size']
        transfer = [0] * config['mem_size']
        mem_idx = 0

        # Compute the Optimal Transport Solver
        print("\nOptimal Transport Solver")
        ots_bar  = tqdm([i for i in range(dset_size//10)])
        for ots_iter, _ in enumerate(ots_bar):
            history['iter'] = ots_iter

            psi_optim.zero_grad()

            # Generate samples from the code_vector distribution
            z_batch = torch.randn(config['batch_size'], config['z_dim']).to(device)
            z_batch = Gz(z_batch).view(-1, in_f, H, W)

            # Push structured noise vector through convolutional generator
#             y_fake = G(z_batch).view(config['batch_size'], -1) # Flatten the output to match dataloader
            y_fake = G(z_batch)
            
            # Compute the W1 distance between the model output and the target distribution
            score = my_ops.l1_t(y_fake, dataloader) - psi

            phi, hit = torch.max(score, 1)

            # Standard loss computation
            # This loss defines the sample mean of the marginal distribution
            # of the dataset. This is the only computation that generalizes.
            loss = -torch.mean(psi[hit])

            # Backprop
            loss.backward()
            psi_optim.step()

            # Update memory tensors (lists)
            mu[mem_idx] = z_batch.data.cpu().numpy().tolist()
            transfer[mem_idx] = hit.data.cpu().numpy().tolist()
            mem_idx = (mem_idx + 1) % config['mem_size']

            # Update losses
            history['losses']['ot_loss'].append(loss.item())

            if (ots_iter % 50 == 0):
                avg_loss = np.mean(history['losses']['ot_loss'])
#                 print('OTS Iteration {} | Epoch {} | Avg Loss Value: {}'.format(ots_iter, epoch, round(avg_loss, 3)))
            if (ots_iter % 2000 == 0):
                # Occasionally save a random sample from the generator during OTS
                sample = y_fake[0:config['sample_size']].view(-1, 1, config['dataset'], config['dataset'])
                utils.save_sample(sample, epoch, ots_iter, config['random_samples'])
                
#                 # Display histogram stats
#                 hist_dict, stop = utils.update_histogram(transfer, history, config)
#                 # Empirical stopping criterion
#                 if stop:
#                     break

        # Compute the Optimal Fitting Transport Plan
        print("\nFitting Optimal Transport Plan")
        fit_bar  = tqdm([i for i in range(config['mem_size'])])
        for fit_iter, _ in enumerate(fit_bar):
            G_optim.zero_grad()
            
            # Retrieve stored batch of generated samples
#             z_batch = torch.tensor(mu[fit_iter]).view(-1, in_f, H, W).to(device)
            z_batch = torch.tensor(mu[fit_iter]).to(device)
            
            # Flatten the model output to match dataloader
#             y_fake = G(z_batch).view(config['batch_size'], -1)
            y_fake = G(z_batch)
            
            # Get Transfer plan from OTS: T(G_{t-1}(Gz))
            t_plan = torch.tensor(transfer[fit_iter])
            y0_hit = dataloader[t_plan].to(device)

            # Compute Wasserstein distance between G and T
#             G_loss = torch.mean(torch.abs(y0_hit - y_fake)) * l_dim
            G_loss = loss_fn(y_fake, y0_hit)
            
            # Backprop
            G_loss.backward() # Gradient descent
            G_optim.step()

            # Update losses
            history['losses']['fit_loss'].append(G_loss.item())

            # Check if best loss value and save checkpoint
            if 'best_loss' not in history:
                history.update({ 'best_loss' : G_loss.item() })

            best = G_loss.item() < (history['best_loss'] * 0.5)
            if best:
                history['best_loss'] = G_loss.item()
                checkpoint = utils.get_checkpoint(history['epoch'], checkpoint_kwargs, config)
                utils.save_checkpoint(checkpoint, config)

            if (fit_iter % 50 == 0):
                avg_loss = np.mean(history['losses']['fit_loss'])
#                 print('FIT Iteration {} | Epoch {} | Avg Loss Value: {}'.format(fit_iter, epoch, round(avg_loss,3)))
         
        # Save a fixed sample of the generator's output at the end of FIT
        sample = G(z_fixed).view(-1, 1, config['dataset'], config['dataset'])
        utils.save_sample(sample, epoch, fit_iter, config['fixed_samples'])

    # Save a checkpoint at end of training
    checkpoint = utils.get_checkpoint(history['epoch'], checkpoint_kwargs, config)
    utils.save_checkpoint(checkpoint, config)

    # Save training data to csv's after training end
    utils.save_train_hist(history, config, times=None, histogram=history['hist_dict'])
    print("Stop Counter Triggered {} Times".format(stop_counter))

    # For Spike
    print("See you, Space Cowboy")
Example #5
def train_MNIST(config):
    # Create python version of cpp operation
    if config['dist'] == 'W1':
        print("Building C++ extension for W1 (requires PyTorch >= 1.0.0)...")
        from torch.utils.cpp_extension import load
        my_ops = load(name="my_ops",
                      sources=[
                          "W1_extension/my_ops.cpp",
                          "W1_extension/my_ops_kernel.cu"
                      ],
                      verbose=False)
        import my_ops
        print("Building complete")

    # Centralize stats logging
    times, losses, hist_dict, checkpt = utils.centralized_logs()

    # Select device
    device = torch.device(config['gpu'])

    # Update config dict with MNIST params and get MNIST dataset as one batch.
    # The block below is inlined in place of: config, mnist_data = utils.MNIST(config)
    # and returns the MNIST training data as a single batch of data.
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, ), (0.5, ))])
    dataset = dset.MNIST(root=config['data_root'],
                         train=True,
                         download=True,
                         transform=transform)

    def get_data(dataset):
        full_dataloader = torch.utils.data.DataLoader(dataset,
                                                      batch_size=len(dataset))
        for y_batch, l_batch in full_dataloader:
            return y_batch

    y_t = get_data(dataset)
    n_dim = len(y_t)
    mnist_data = y_t.view(n_dim, -1).to(device)

    config.update({
        'dset_size': n_dim,
        'imsize': 28,
        'out_dim': 784,
        'batch_size': 64,
        'sample_size': 16,
        'early_end': (200, 320)
    })

    # Set MLP architecture
    G_arch = utils.get_mlp_arch(config)

    # Create G_model
    G_mlp = utils.get_model(G_arch, device)

    # Get optimizer
    opt_G = utils.get_optim(G_arch, G_mlp.parameters(), MNIST=True)
    print(G_mlp)

    # Initialize G_model weights and get the number of layers
    # G_mlp = utils.weights_init(G_mlp, MNIST=True)
    def initialize_weights(net):
        for m in net.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)
            if hasattr(m, "bias") and m.bias is not None:
                m.bias.data.zero_()

    initialize_weights(G_mlp)

    # Create labels for experiment
    labels = utils.create_labels(config)

    # Update config with labels and save locations
    config = utils.update_with_labels(config, labels)

    # Setup psi optimizer
    psi = torch.zeros(n_dim, requires_grad=True, device=device)
    opt_psi = torch.optim.Adam([psi], lr=1e-1)

    # Set fixed noise vector for testing
    z_fixed = utils.get_z(config, device, sample=False)
    z_fixed = z_fixed.resize_((config['batch_size'], config['z_dim'])).to(device)

    # Training loop
    for epoch in range(config['num_epochs']):

        epoch_start_time = time.time()

        # Save list of losses for end-training determination
        loss_memory = []

        # Set up memory tensors: simple feed-forward distribution, transfer plan
        mu = torch.zeros(config['mem_size'], config['batch_size'],
                         config['z_dim'])
        transfer = torch.zeros(config['mem_size'],
                               config['batch_size'],
                               dtype=torch.long)
        mem_idx = 0

        # Compute Optimal Transport Solver (OTS) over every training example
        ot_start = time.time()

        # for ots_iter in range(0, config['dset_size']):
        for iter in range(1, 20001):

            opt_psi.zero_grad()

            # Generate samples from feed-forward distribution
            z_batch = utils.get_z(config, device, sample=False)
            z_batch = z_batch.resize_((config['batch_size'], config['z_dim'])).to(device)
            y_fake = G_mlp(z_batch)  # [B, n_dim]

            # Compute cost between sample batch and target distribution
            if (config['dist'] == 'W1'):
                score = -my_ops.l1_t(y_fake, mnist_data) - psi
            else:
                score = torch.matmul(
                    y_fake, mnist_data.t()) - psi  # score: [B, N], psi: [N]

            phi, hit = torch.max(score, 1)
            # phi, hit = torch.min(score, 1) # [B], [B]

            # Wasserstein distance computation: d(x,y)^p
            if (config['dist'] == 'W1'):
                loss_primal = torch.mean(
                    torch.abs(y_fake - mnist_data[hit])) * config['out_dim']
            else:
                loss_primal = torch.mean(
                    (y_fake - mnist_data[hit])**2) * config['out_dim']

            # Loss computation
            # loss = (torch.mean(phi) + torch.mean(psi)) # Testing this
            loss = -torch.mean(psi[hit])  # equiv. to loss?

            # Backprop
            loss.backward()  # Gradient ascent
            opt_psi.step()

            # Append losses to dict
            losses['ot_loss'].append(loss.item())
            losses['w2_estim'].append(loss_primal.item())

            # Update memory tensors
            mu[mem_idx] = z_batch
            transfer[mem_idx] = hit
            mem_idx = (mem_idx + 1) % config['mem_size']

            if (iter % 500 == 0):
                print('OTS Iteration {} | Epoch {}'.format(iter, epoch))
            if (iter % 2000 == 0):
                # Display histogram stats
                hist_dict, stop = utils.update_histogram(
                    transfer, n_dim, epoch, iter, config, losses, hist_dict)
                # Empirical stopping criterion
                if stop:
                    break

        # Compute OTS time and append
        ot_end = time.time()
        times['ot_time'].append(ot_end - ot_start)

        # Compute Fitting Optimal Transport Plan (FIT)
        fit_start = time.time()
        for fit_iter in range(config['mem_size']):

            opt_G.zero_grad()

            # Get stored batch of generated samples
            z_batch = mu[fit_iter].to(device)
            y_fake = G_mlp(z_batch)  # G'(z)

            # Get Transfer plan from OTS: T(G_{t-1}(z))
            y0_hit = mnist_data[transfer[fit_iter].to(device)]

            # Compute Wasserstein distance between G and T
            if (config['dist'] == 'W1'):
                loss_g = torch.mean(
                    torch.abs(y0_hit - y_fake)) * config['out_dim']
            else:
                loss_g = torch.mean((y0_hit - y_fake)**2) * config['out_dim']

            # Backprop
            loss_g.backward()  # Gradient descent
            opt_G.step()

            # Append losses to dict
            losses['g_loss'].append(loss_g.item())
            loss_memory.append(loss_g.item())

            if (fit_iter % 500 == 0):
                print(
                    'Fit_iter: {} | Epoch: {} | Loss: {:.2f} | Best Loss: {:.2f}'
                    .format(fit_iter, epoch, loss_g, checkpt['best']))

            # Check if best loss value and save checkpoint
            # threshold = (checkpt['best'] - round(checkpt['best']*0.5))
            # best = ( loss_g.item() < threshold )
            # if best:
            #     checkpt['best'] = loss_g.item()
            #     chkpt_dict = utils.checkpoint_dict(fit_iter, epoch, G_mlp, opt_G)
            #     utils.save_checkpoint(chkpt_dict, best, epoch, -1, config['weights_root'])

            # Save periodic checkpoint
            # if (fit_iter % 2000 == 0):
            #     chkpt_dict = utils.checkpoint_dict(fit_iter, epoch, G_mlp, opt_G)
            #     utils.save_checkpoint(chkpt_dict, False, epoch, fit_iter, config['weights_root'])

            # Get random sample from G
            if (fit_iter % 1000 == 0):
                z_rand = utils.get_z(config, device, sample=True)
                z_rand = z_rand.resize_(
                    (config['sample_size'], config['z_dim'])).to(device)
                sample = G_mlp(z_rand).view(-1, 1, config['imsize'],
                                            config['imsize'])
                utils.save_sample(sample, epoch, fit_iter, config['random'])

            # Check if loss is changing - stop training if no change
            if (len(loss_memory) > (config['mem_size'] // 2)):
                if ((loss_g <= (mean(loss_memory)*.999)) and \
                    (loss_g >= (mean(loss_memory)*.995))):
                    break

        # Compute FIT time
        fit_end = time.time()
        times['fit_time'].append(fit_end - fit_start)

        # Compute epoch time
        times['epoch_times'].append(time.time() - epoch_start_time)

        # Output to terminal
        print('Best loss:  {}'.format(checkpt['best']))
        print('Epoch_time: {}'.format(time.time() - epoch_start_time))
        print('Epoch num:  {}'.format(epoch))
        print("FIT loss: {:.2f}".format(np.mean(losses['g_loss'])))

        # Save fixed sample at end of training epoch
        sample = G_mlp(z_fixed).view(-1, 1, config['imsize'], config['imsize'])
        utils.save_sample(sample, epoch, 0, config['fixed'])

    # Save training data to csv after training completion
    utils.save_stats(times, losses, hist_dict, G_arch, config['save_root'])
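The loss-plateau check that ends the FIT loop in Examples #3 and #5 (break once loss_g sits between 99.5% and 99.9% of the running mean, i.e. the loss has effectively stopped changing) could be factored into a small helper. The version below is an illustrative sketch, not repository code, and assumes the same statistics.mean used by the inline check.

# Illustrative refactor of the inline plateau check used in Examples #3 and #5.
from statistics import mean

def loss_has_plateaued(loss_memory, current_loss, min_history):
    '''True once the latest loss is within [99.5%, 99.9%] of the running mean.'''
    if len(loss_memory) <= min_history:
        return False
    avg = mean(loss_memory)
    return 0.995 * avg <= current_loss <= 0.999 * avg

# Usage inside the FIT loop (mirrors the inline condition above):
# if loss_has_plateaued(loss_memory, loss_g.item(), config['mem_size'] // 2):
#     break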