Example 1
def train_mlp(G_mlp, opt_G, psi, opt_psi, dataloader, device, loader_kwargs,
              config, times, losses, hist_dict, checkpt):

    if config['mnist']:
        import poc
        poc.train_MNIST(G_mlp, opt_G, psi, opt_psi, dataloader, device,
                        loader_kwargs, config, times, losses, hist_dict,
                        checkpt)
        return
    # Set fixed noise vector for testing
    z_fixed = utils.get_z(config, device, sample=False)
    z_fixed.resize_((config['batch_size'], config['z_dim']))

    # Create python version of cpp operation
    if config['dist'] == 'W1':
        from torch.utils.cpp_extension import load
        my_ops = load(name="my_ops",
                      sources=[
                          "W1_extension/my_ops.cpp",
                          "W1_extension/my_ops_kernel.cu"
                      ],
                      verbose=False)
        import my_ops
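        # Note (assumption): my_ops.l1_t(a, b) is taken to return a [B, N]
        # pairwise L1 score between the rows of a ([B, D]) and b ([N, D]);
        # a pure-PyTorch stand-in would be roughly
        #     cost = torch.cdist(a, b, p=1)   # [B, N] cityblock distances
        # with the CUDA extension existing mainly to make this affordable for
        # the full dataset. The sign applied to l1_t differs across these
        # examples, so treat this sketch as illustrative only.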
    n_dim = len(dataloader)
    # Training loop
    for epoch in range(config['num_epochs']):

        epoch_start_time = time.time()

        # Save list of losses for end-training determination
        loss_memory = []

        # Set up memory tensors: simple feed-forward distribution, transfer plan
        mu = torch.zeros(config['mem_size'], config['batch_size'],
                         config['z_dim'])
        transfer = torch.zeros(config['mem_size'],
                               config['batch_size'],
                               dtype=torch.long)
        mem_idx = 0

        # Compute Optimal Transport Solver (OTS) over every training example
        ot_start = time.time()

        # for ots_iter in range(0, config['dset_size']):
        for iter in range(0, loader_kwargs['batch_size']):

            opt_psi.zero_grad()

            # Generate samples from feed-forward distribution
            z_batch = utils.get_z(config, device, sample=False)
            z_batch.resize_((config['batch_size'], config['z_dim']))
            y_fake = G_mlp(z_batch)  # [B, n_dim]

            # Compute cost between sample batch and target distribution
            if (config['dist'] == 'W1'):
                score = -my_ops.l1_t(y_fake, dataloader) - psi
            else:
                score = torch.matmul(
                    y_fake, dataloader.t()) - psi  # score: [B, N], psi: [N]

            # phi, hit = torch.max(score, 1)
            phi, hit = torch.min(score, 1)  # [B], [B]

            # Wasserstein distance computation: d(x,y)^p
            if (config['dist'] == 'W1'):
                loss_primal = torch.mean(
                    torch.abs(y_fake - dataloader[hit])) * config['out_dim']
            else:
                loss_primal = torch.mean(
                    (y_fake - dataloader[hit])**2) * config['out_dim']
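            # Multiplying the per-element mean above by out_dim turns it into a
            # per-image sum (averaged over the batch), so loss_primal reads as a
            # whole-image W1 / squared-W2 estimate.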

            # Loss computation
            # loss = -(torch.mean(phi) + torch.mean(psi)) # Testing this
            loss = -torch.mean(psi[hit])  # equiv. to loss?

            # Backprop
            loss.backward()  # Gradient ascent
            opt_psi.step()

            # Append losses to dict
            losses['ot_loss'].append(loss.item())
            losses['w2_estim'].append(loss_primal.item())

            # Update memory tensors
            mu[mem_idx] = z_batch
            transfer[mem_idx] = hit
            mem_idx = (mem_idx + 1) % config['mem_size']
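            # mu / transfer act as a ring buffer over the last mem_size
            # (z_batch, assignment) pairs; the FIT phase below replays exactly
            # these pairs when fitting G to the transport plan found here.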

            if (iter % 500 == 0):
                print('OTS Iteration {} | Epoch {}'.format(iter, epoch))
            if (iter % (config['num_epochs'] * 10) == 0):
                # Display histogram stats
                hist_dict, stop = utils.update_histogram(
                    transfer, n_dim, epoch, iter, config, losses, hist_dict)
                # Empirical stopping criterion
                if stop:
                    break

        # Compute OTS time and append
        ot_end = time.time()
        times['ot_time'].append(ot_end - ot_start)

        # Compute Fitting Optimal Transport Plan (FIT)
        fit_start = time.time()
        for fit_iter in range(config['mem_size']):

            opt_G.zero_grad()

            # Get stored batch of generated samples
            z_batch = mu[fit_iter].to(device)
            y_fake = G_mlp(z_batch)  # G'(z)

            # Get Transfer plan from OTS: T(G_{t-1}(z))
            y0_hit = dataloader[transfer[fit_iter].to(device)]

            # Compute Wasserstein distance between G and T
            if (config['dist'] == 'W1'):
                loss_g = torch.mean(
                    torch.abs(y0_hit - y_fake)) * config['out_dim']
            else:
                loss_g = torch.mean((y0_hit - y_fake)**2) * config['out_dim']

            # Backprop
            loss_g.backward()  # Gradient descent
            opt_G.step()

            # Append losses to dict
            losses['g_loss'].append(loss_g.item())
            loss_memory.append(loss_g.item())

            if (fit_iter % 500 == 0):
                print(
                    'Fit_iter: {} | Epoch: {} | Loss: {:.2f} | Best Loss: {:.2f}'
                    .format(fit_iter, epoch, loss_g.item(), checkpt['best']))

            # Check if best loss value and save checkpoint
            threshold = (checkpt['best'] - round(checkpt['best'] * 0.5))
            best = (loss_g.item() < threshold)
            if best:
                checkpt['best'] = loss_g.item()
                chkpt_dict = utils.checkpoint_dict(fit_iter, epoch, G_mlp,
                                                   opt_G)
                utils.save_checkpoint(chkpt_dict, best, epoch, -1,
                                      config['weights_root'])

            # Save periodic checkpoint
            if (fit_iter % (config['num_epochs'] * 5) == 0):
                chkpt_dict = utils.checkpoint_dict(fit_iter, epoch, G_mlp,
                                                   opt_G)
                utils.save_checkpoint(chkpt_dict, False, epoch, fit_iter,
                                      config['weights_root'])

            # Get random sample from G
            if (fit_iter % (config['num_epochs']) == 0):
                z_rand = utils.get_z(config, device, sample=True)
                z_rand.resize_((config['batch_size'], config['z_dim']))
                sample = G_mlp(z_rand).view(-1, 1, config['imsize'],
                                            config['imsize'])
                utils.save_sample(sample, epoch, fit_iter, config['random'])

            # Check if loss is changing - stop training if no change
            if (len(loss_memory) > (config['mem_size'] // 2)):
                if ((loss_g <= (mean(loss_memory)*.999)) and \
                    (loss_g >= (mean(loss_memory)*.995))):
                    break

        # Compute FIT time
        fit_end = time.time()
        times['fit_time'].append(fit_end - fit_start)

        # Compute epoch time
        times['epoch_times'].append(time.time() - epoch_start_time)

        # Output to terminal
        print('Best loss:  {}'.format(checkpt['best']))
        print('Epoch_time: {}'.format(time.time() - epoch_start_time))
        print('Num epochs: {}'.format(epoch))
        print("FIT loss: {:.2f}".format(np.mean(losses['g_loss'])))

        # Save fixed sample at end of training epoch
        sample = G_mlp(z_fixed).view(-1, 1, config['imsize'], config['imsize'])
        utils.save_sample(sample, epoch, 0, config['fixed'])

    return [times, losses, hist_dict]
Example 2
def train(config):
    '''
        Training function for EWM generator model.
    '''
    # Create python version of cpp operation
    # (Credit: Chen, arXiv:1906.03471, GitHub: https://github.com/chen0706/EWM)
    from torch.utils.cpp_extension import load
    my_ops = load(name = "my_ops",
                  sources = ["W1_extension/my_ops.cpp",
                             "W1_extension/my_ops_kernel.cu"],
                  verbose = False)
    import my_ops

    # Set up GPU device ordinal
    device = torch.device(config['gpu'])

    # Get model kwargs for convolutional generator
    config['model'] = 'ewm_conv'
    emw_kwargs = setup_model.ewm_kwargs(config)

    # Setup convolutional generator model on GPU
    G = ewm.ewm_convG(**emw_kwargs).to(device)
#     ewm.weights_init(G)
    G.train()
    
    # Setup model optimizer
    model_params = {'g_params': G.parameters()}
    G_optim = utils.get_optim(config, model_params)
    
    # Testing: MSE loss for conv_generator reconstruction
    loss_fn = nn.MSELoss().to(device)
    
    # Print G -- make sure it's right before continuing
    print(G)
    input('Press any key to continue')
    
    # Setup source of structured noise on GPU (Trained EWM_MLP_Generator Model)
    # Add these configs to the experiment source file
    if not config['ewm_root']:
        raise Exception('Path to trained EWM_MLP Model must be specified')

    # Print the list of available trained EWM models
    EWM_paths = []; EWM_root = config['ewm_root']
    for path in os.listdir(EWM_root):
        EWM_paths.append(os.path.join(EWM_root, path))
    print("-"*60)
    for i in range(len(EWM_paths)):
        EWM_name = EWM_paths[i].split('/')[-1]
        print("\n Exp_{}:".format(str(i)), EWM_name, '\n')
        print("-"*60)
    
    # Select the trained model
    model_num = input('Select EWM_MLP model (enter integer): ')
    EWM_dir = EWM_paths[int(model_num)]
    
    # Create the full path to the EWM model
    EWM_path = EWM_dir + '/'  # EWM_dir is already rooted at EWM_root
    print("Path to EWM Generator Model set as: \n{}".format(EWM_path))
    
    config_csv = EWM_path + "config.csv"
    config_df = pd.read_csv(config_csv, delimiter = ",")

    # Get the model architecture from config df
    n_layers = int(config_df[config_df['Unnamed: 0'].str.contains("n_layers")==True]['0'].values.item())
    n_hidden = int(config_df[config_df['Unnamed: 0'].str.contains("n_hidden")==True]['0'].values.item())
    l_dim    = int(config_df[config_df['Unnamed: 0'].str.contains("l_dim")==True]['0'].values.item())
    im_size  = int(config_df[config_df['Unnamed: 0'].str.contains("dataset")==True]['0'].values.item())
    z_dim    = int(config_df[config_df['Unnamed: 0'].str.contains("z_dim")==True]['0'].values.item())

    # Model kwargs
    ewm_kwargs = {'z_dim': z_dim, 'fc_sizes': [n_hidden]*n_layers, 'n_out': l_dim}

    # Send model to GPU
    Gz = ewm.ewm_G(**ewm_kwargs).to(device)
 
    # Load the model checkpoint
    # Get checkpoint name(s)
    EWM_checkpoint_path  = EWM_path + 'weights/'
    EWM_checkpoint_names = []
    for file in os.listdir(EWM_checkpoint_path):
        EWM_checkpoint_names.append(os.path.join(EWM_checkpoint_path, file))
    print("-"*60)
    for i in range(len(EWM_checkpoint_names)):
        name = EWM_checkpoint_names[i].split('/')[-1]
        print("\n {} :".format(str(i)), name, '\n')
        print("-"*60)
    file_num = input("Select a checkpoint file for EWM_MLP (enter integer): ")
    EWM_checkpoint = EWM_checkpoint_names[int(file_num)]

    # Load the model checkpoint
    # Keys: ['state_dict', 'epoch', 'optimizer']
    checkpoint = torch.load(EWM_checkpoint)

    # Load the model's state dictionary
    Gz.load_state_dict(checkpoint['state_dict'])

    # Use the code_vector model in evaluation mode -- no need for gradients here
    Gz.eval()
    print(Gz)
    input('Press any key to continue')

    # Set up full_dataloader (single batch)
    dataloader = utils.get_dataloader(config).to(device) # Full Dataloader
    dset_size  = len(dataloader)

    # Flatten the dataloader into a Tensor of shape [dset_size, l_dim]
#     dataloader = dataloader.view(dset_size, -1).to(device)

    # Set up psi optimizer
    psi = torch.zeros(dset_size, device=device, requires_grad=True)
    psi_optim = torch.optim.Adam([psi], lr=config['psi_lr'])

    # Set up directories for saving training stats and outputs
    config = utils.directories(config)

    # Set up dict for saving checkpoints
    checkpoint_kwargs = {'G':G, 'G_optim':G_optim}

    # Set up stats logging
    hist_dict = {'hist_min':[], 'hist_max':[], 'ot_loss':[]}
    losses    = {'ot_loss': [], 'fit_loss': []}
    history   = {'dset_size': dset_size, 'epoch': 0, 'iter': 0,
                 'losses'   : losses, 'hist_dict': hist_dict}
    config['early_end'] = (200, 320) # Empirical stopping criterion from EWM author

    stop_counter = 0

    # Compute how the input vectors need to be reshaped, based on conv_G input layer
    in_f  = G.main[0][2].in_channels; out_f = G.main[0][2].out_channels
    
    # Set the height and width of the feature maps. Note: manually setting this to 8 is
    # hackish, but computing the actual value requires knowing the number of layers in
    # the AutoEncoder whose code layer was used to train the generator model being loaded
    # here. Hard-coding 8 avoids loading additional paths and dataframes, but you may be
    # able to do better.
    H = W = 8
    print('Vectors will be reshaped as: [{}] --> [{},{},{}]'.format(l_dim, in_f, H, W))
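    # Sanity check: the flat code length is assumed to factor into the conv
    # generator's expected input volume; the .view() calls below rely on this.
    assert l_dim == in_f * H * W, \
        'l_dim ({}) does not match in_f*H*W ({})'.format(l_dim, in_f * H * W)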
    
    # Set a fixed feature tensor for testing
    noise  = torch.randn(config['sample_size'], config['z_dim']).to(device)
    z_fixed = Gz(noise).view(-1, in_f, H, W).to(device)

    # Training Loop
    input('\nPress any key to launch -- good luck out there\n')

    # Set up progress bar for terminal output and enumeration
    epoch_bar  = tqdm([i for i in range(config['num_epochs'])])
    for epoch, _ in enumerate(epoch_bar):

        history['epoch'] = epoch

        # Set up memory lists: 
        #     - mu: simple feed-forward distribution 
        #     - transfer: transfer plan given by lists of indices
        # Rule of thumb: do not save the tensors themselves; instead, save the
        #                data as a list and convert it to a tensor as needed.
        mu = [0] * config['mem_size']
        transfer = [0] * config['mem_size']
        mem_idx = 0

        # Compute the Optimal Transport Solver
        print("\nOptimal Transport Solver")
        ots_bar  = tqdm([i for i in range(dset_size//10)])
        for ots_iter, _ in enumerate(ots_bar):
            history['iter'] = ots_iter

            psi_optim.zero_grad()

            # Generate samples from the code_vector distribution
            z_batch = torch.randn(config['batch_size'], config['z_dim']).to(device)
            z_batch = Gz(z_batch).view(-1, in_f, H, W)

            # Push structured noise vector through convolutional generator
#             y_fake = G(z_batch).view(config['batch_size'], -1) # Flatten the output to match dataloader
            y_fake = G(z_batch)
            
            # Compute the W1 distance between the model output and the target distribution
            score = my_ops.l1_t(y_fake, dataloader) - psi

            phi, hit = torch.max(score, 1)

            # Standard loss computation
            # This loss defines the sample mean of the marginal distribution
            # of the dataset. This is the only computation that generalizes.
            loss = -torch.mean(psi[hit])

            # Backprop
            loss.backward()
            psi_optim.step()

            # Update memory tensors (lists)
            mu[mem_idx] = z_batch.data.cpu().numpy().tolist()
            transfer[mem_idx] = hit.data.cpu().numpy().tolist()
            mem_idx = (mem_idx + 1) % config['mem_size']

            # Update losses
            history['losses']['ot_loss'].append(loss.item())

            if (ots_iter % 50 == 0):
                avg_loss = np.mean(history['losses']['ot_loss'])
#                 print('OTS Iteration {} | Epoch {} | Avg Loss Value: {}'.format(ots_iter, epoch, round(avg_loss, 3)))
            if (ots_iter % 2000 == 0):
                # Occasionally save a random sample from the generator during OTS
                sample = y_fake[0:config['sample_size']].view(-1, 1, config['dataset'], config['dataset'])
                utils.save_sample(sample, epoch, ots_iter, config['random_samples'])
                
#                 # Display histogram stats
#                 hist_dict, stop = utils.update_histogram(transfer, history, config)
#                 # Emperical stopping criterion
#                 if stop:
#                     break

        # Compute the Optimal Fitting Transport Plan
        print("\nFitting Optimal Transport Plan")
        fit_bar  = tqdm([i for i in range(config['mem_size'])])
        for fit_iter, _ in enumerate(fit_bar):
            G_optim.zero_grad()
            
            # Retrieve stored batch of generated samples
#             z_batch = torch.tensor(mu[fit_iter]).view(-1, in_f, H, W).to(device)
            z_batch = torch.tensor(mu[fit_iter]).to(device)
            
            # Flatten the model output to match dataloader
#             y_fake = G(z_batch).view(config['batch_size'], -1)
            y_fake = G(z_batch)
            
            # Get Transfer plan from OTS: T(G_{t-1}(Gz))
            t_plan = torch.tensor(transfer[fit_iter])
            y0_hit = dataloader[t_plan].to(device)

            # Compute Wasserstein distance between G and T
#             G_loss = torch.mean(torch.abs(y0_hit - y_fake)) * l_dim
            G_loss = loss_fn(y_fake, y0_hit)
            
            # Backprop
            G_loss.backward() # Gradient descent
            G_optim.step()

            # Update losses
            history['losses']['fit_loss'].append(G_loss.item())

            # Check if best loss value and save checkpoint
            if 'best_loss' not in history:
                history.update({ 'best_loss' : G_loss.item() })

            best = G_loss.item() < (history['best_loss'] * 0.5)
            if best:
                history['best_loss'] = G_loss.item()
                checkpoint = utils.get_checkpoint(history['epoch'], checkpoint_kwargs, config)
                utils.save_checkpoint(checkpoint, config)

            if (fit_iter % 50 == 0):
                avg_loss = np.mean(history['losses']['fit_loss'])
#                 print('FIT Iteration {} | Epoch {} | Avg Loss Value: {}'.format(fit_iter, epoch, round(avg_loss,3)))
         
        # Save a fixed sample of the generator's output at the end of FIT
        sample = G(z_fixed).view(-1, 1, config['dataset'], config['dataset'])
        utils.save_sample(sample, epoch, fit_iter, config['fixed_samples'])

    # Save a checkpoint at end of training
    checkpoint = utils.get_checkpoint(history['epoch'], checkpoint_kwargs, config)
    utils.save_checkpoint(checkpoint, config)

    # Save training data to csv's after training end
    utils.save_train_hist(history, config, times=None, histogram=history['hist_dict'])
    print("Stop Counter Triggered {} Times".format(stop_counter))

    # For Spike
    print("See you, Space Cowboy")
Example 3
def train_MNIST(config):
    # Create python version of cpp operation
    if config['dist'] == 'W1':
        print("Building C++ extension for W1 (requires PyTorch >= 1.0.0)...")
        from torch.utils.cpp_extension import load
        my_ops = load(name="my_ops",
                      sources=[
                          "W1_extension/my_ops.cpp",
                          "W1_extension/my_ops_kernel.cu"
                      ],
                      verbose=False)
        import my_ops
        print("Building complete")

    # Centralize stats logging
    times, losses, hist_dict, checkpt = utils.centralized_logs()

    # Select device
    device = torch.device(config['gpu'])

    # Update config dict with MNIST params and get MNIST dataset as one batch
    # config, mnist_data = utils.MNIST(config)
    # (Inlined from utils.MNIST: returns the MNIST training data as a single batch)
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, ), (0.5, ))])
    dataset = dset.MNIST(root=config['data_root'],
                         train=True,
                         download=True,
                         transform=transform)

    def get_data(dataset):
        full_dataloader = torch.utils.data.DataLoader(dataset,
                                                      batch_size=len(dataset))
        for y_batch, l_batch in full_dataloader:
            return y_batch

    y_t = get_data(dataset)
    n_dim = len(y_t)
    mnist_data = y_t.view(n_dim, -1).to(device)
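    # y_t comes back as [60000, 1, 28, 28]; mnist_data is the flattened
    # [60000, 784] tensor that the OTS cost computations below index into.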

    config.update({
        'dset_size': n_dim,
        'imsize': 28,
        'out_dim': 784,
        'batch_size': 64,
        'sample_size': 16,
        'early_end': (200, 320)
    })

    # Set MLP architecture
    G_arch = utils.get_mlp_arch(config)

    # Create G_model
    G_mlp = utils.get_model(G_arch, device)

    # Get optimizer
    opt_G = utils.get_optim(G_arch, G_mlp.parameters(), MNIST=True)
    print(G_mlp)

    # Initialize G_model weights and get the number of layers
    # G_mlp = utils.weights_init(G_mlp, MNIST=True)
    def initialize_weights(net):
        for m in net.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)
            if hasattr(m, "bias") and m.bias is not None:
                m.bias.data.zero_()

    initialize_weights(G_mlp)

    # Create labels for experiment
    labels = utils.create_labels(config)

    # Update config with labels and save locations
    config = utils.update_with_labels(config, labels)

    # Setup psi optimizer
    psi = torch.zeros(n_dim, requires_grad=True, device=device)
    opt_psi = torch.optim.Adam([psi], lr=1e-1)
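    # psi holds one dual variable per training image (n_dim entries); only psi
    # is updated during OTS, while the generator weights stay fixed.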

    # Set fixed noise vector for testing
    z_fixed = utils.get_z(config, device, sample=False)
    z_fixed = z_fixed.resize_((config['batch_size'], config['z_dim'])).to(device)

    # Training loop
    for epoch in range(config['num_epochs']):

        epoch_start_time = time.time()

        # Save list of losses for end-training determination
        loss_memory = []

        # Set up memory tensors: simple feed-forward distribution, transfer plan
        mu = torch.zeros(config['mem_size'], config['batch_size'],
                         config['z_dim'])
        transfer = torch.zeros(config['mem_size'],
                               config['batch_size'],
                               dtype=torch.long)
        mem_idx = 0

        # Compute Optimal Transport Solver (OTS) over every training example
        ot_start = time.time()

        # for ots_iter in range(0, config['dset_size']):
        for iter in range(1, 20001):

            opt_psi.zero_grad()

            # Generate samples from feed-forward distribution
            z_batch = utils.get_z(config, device, sample=False)
            z_batch = z_batch.resize_((config['batch_size'], config['z_dim'])).to(device)
            y_fake = G_mlp(z_batch)  # [B, n_dim]

            # Compute cost between sample batch and target distribution
            if (config['dist'] == 'W1'):
                score = -my_ops.l1_t(y_fake, mnist_data) - psi
            else:
                score = torch.matmul(
                    y_fake, mnist_data.t()) - psi  # score: [B, N], psi: [N]

            phi, hit = torch.max(score, 1)
            # phi, hit = torch.min(score, 1) # [B], [B]

            # Wasserstein distance computation: d(x,y)^p
            if (config['dist'] == 'W1'):
                loss_primal = torch.mean(
                    torch.abs(y_fake - mnist_data[hit])) * config['out_dim']
            else:
                loss_primal = torch.mean(
                    (y_fake - mnist_data[hit])**2) * config['out_dim']

            # Loss computation
            # loss = (torch.mean(phi) + torch.mean(psi)) # Testing this
            loss = -torch.mean(psi[hit])  # equiv. to loss?

            # Backprop
            loss.backward()  # Gradient ascent
            opt_psi.step()

            # Append losses to dict
            losses['ot_loss'].append(loss.item())
            losses['w2_estim'].append(loss_primal.item())

            # Update memory tensors
            mu[mem_idx] = z_batch
            transfer[mem_idx] = hit
            mem_idx = (mem_idx + 1) % config['mem_size']

            if (iter % 500 == 0):
                print('OTS Iteration {} | Epoch {}'.format(iter, epoch))
            if (iter % 2000 == 0):
                # Display histogram stats
                hist_dict, stop = utils.update_histogram(
                    transfer, n_dim, epoch, iter, config, losses, hist_dict)
                # Empirical stopping criterion
                if stop:
                    break

        # Compute OTS time and append
        ot_end = time.time()
        times['ot_time'].append(ot_end - ot_start)

        # Compute Fitting Optimal Transport Plan (FIT)
        fit_start = time.time()
        for fit_iter in range(config['mem_size']):

            opt_G.zero_grad()

            # Get stored batch of generated samples
            z_batch = mu[fit_iter].to(device)
            y_fake = G_mlp(z_batch)  # G'(z)

            # Get Transfer plan from OTS: T(G_{t-1}(z))
            y0_hit = mnist_data[transfer[fit_iter].to(device)]

            # Compute Wasserstein distance between G and T
            if (config['dist'] == 'W1'):
                loss_g = torch.mean(
                    torch.abs(y0_hit - y_fake)) * config['out_dim']
            else:
                loss_g = torch.mean((y0_hit - y_fake)**2) * config['out_dim']

            # Backprop
            loss_g.backward()  # Gradient descent
            opt_G.step()

            # Append losses to dict
            losses['g_loss'].append(loss_g.item())
            loss_memory.append(loss_g.item())

            if (fit_iter % 500 == 0):
                print(
                    'Fit_iter: {} | Epoch: {} | Loss: {:.2f} | Best Loss: {:.2f}'
                    .format(fit_iter, epoch, loss_g.item(), checkpt['best']))

            # Check if best loss value and save checkpoint
            # threshold = (checkpt['best'] - round(checkpt['best']*0.5))
            # best = ( loss_g.item() < threshold )
            # if best:
            #     checkpt['best'] = loss_g.item()
            #     chkpt_dict = utils.checkpoint_dict(fit_iter, epoch, G_mlp, opt_G)
            #     utils.save_checkpoint(chkpt_dict, best, epoch, -1, config['weights_root'])

            # Save periodic checkpoint
            # if (fit_iter % 2000 == 0):
            #     chkpt_dict = utils.checkpoint_dict(fit_iter, epoch, G_mlp, opt_G)
            #     utils.save_checkpoint(chkpt_dict, False, epoch, fit_iter, config['weights_root'])

            # Get random sample from G
            if (fit_iter % 1000 == 0):
                z_rand = utils.get_z(config, device, sample=True)
                z_rand = z_rand.resize_(
                    (config['sample_size'], config['z_dim'])).to(device)
                sample = G_mlp(z_rand).view(-1, 1, config['imsize'],
                                            config['imsize'])
                utils.save_sample(sample, epoch, fit_iter, config['random'])

            # Check if loss is changing - stop training if no change
            if (len(loss_memory) > (config['mem_size'] // 2)):
                if ((loss_g <= (mean(loss_memory)*.999)) and \
                    (loss_g >= (mean(loss_memory)*.995))):
                    break

        # Compute FIT time
        fit_end = time.time()
        times['fit_time'].append(fit_end - fit_start)

        # Compute epoch time
        times['epoch_times'].append(time.time() - epoch_start_time)

        # Output to terminal
        print('Best loss:  {}'.format(checkpt['best']))
        print('Epoch_time: {}'.format(time.time() - epoch_start_time))
        print('Num epochs: {}'.format(epoch))
        print("FIT loss: {:.2f}".format(np.mean(losses['g_loss'])))

        # Save fixed sample at end of training epoch
        sample = G_mlp(z_fixed).view(-1, 1, config['imsize'], config['imsize'])
        utils.save_sample(sample, epoch, 0, config['fixed'])

    # Save training data to csv after training completion
    utils.save_stats(times, losses, hist_dict, G_arch, config['save_root'])
Example 4
def train(config):
    '''
        Training function for EWM generator model.
    '''
    # Create python version of cpp operation
    # (Credit: Chen, arXiv:1906.03471, GitHub: https://github.com/chen0706/EWM)
    from torch.utils.cpp_extension import load
    my_ops = load(name = "my_ops",
                  sources = ["W1_extension/my_ops.cpp",
                             "W1_extension/my_ops_kernel.cu"],
                  verbose = False)
    import my_ops

    # Set up GPU device ordinal - if this fails, try setting the CUDA_LAUNCH_BLOCKING environment variable
    device = torch.device(config['gpu'])

    # Get model kwargs
    emw_kwargs = setup_model.ewm_kwargs(config)

    # Setup model on GPU
    G = ewm_G(**emw_kwargs).to(device)
    G.weights_init()

    print(G)
    input('Press any key to launch')

    # Setup model optimizer
    model_params = {'g_params': G.parameters()}
    G_optim = utils.get_optim(config, model_params)

    # Set up full_dataloader (single batch)
    dataloader = utils.get_dataloader(config) # Full Dataloader
    dset_size  = len(dataloader)

    # Flatten the dataloader into a Tensor of shape [dset_size, l_dim]
    dataloader = dataloader.view(dset_size, -1).to(device)

    # Set up psi optimizer
    psi = torch.zeros(dset_size, device=device, requires_grad=True)
    psi_optim = torch.optim.Adam([psi], lr=config['psi_lr'])

    # Set up directories for saving training stats and outputs
    config = utils.directories(config)

    # Set up dict for saving checkpoints
    checkpoint_kwargs = {'G':G, 'G_optim':G_optim}

    # Variance argument for the tessellation vectors
    tess_var = config['tess_var']**0.5
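    # config['tess_var'] is assumed to store a variance, so its square root is
    # used below as the standard deviation of the Gaussian tessellation noise
    # added to both the generator output and the target data.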
    
    # Compute the stopping criterion using set of test vectors
    # and computing the 'ideal' loss between the test/target.
    print(line(60))
    print("Computing stopping criterion")
    print(line(60))
    stop_criterion = []
    test_loader = utils.get_test_loader(config)
    for _, test_vecs in enumerate(test_loader):
        # Add Gaussian noise to test_vectors
        test_vecs = test_vecs.view(config['batch_size'], -1).to(device) # 'Perfect' generator model
        t1 = tess_var*torch.randn(test_vecs.shape[0], test_vecs.shape[1]).to(device)
        test_vecs += t1
        # Add Gaussian noise to target data
        t2 = tess_var*torch.randn(dataloader.shape[0], dataloader.shape[1]).to(device)
        test_target  = dataloader + t2
        # Compute the stop score
        stop_score = my_ops.l1_t(test_vecs, test_target)
        stop_loss = -torch.mean(stop_score)
        stop_criterion.append(stop_loss.cpu().detach().numpy())
    del test_loader
    # Set stopping criterion variables
    stop_min, stop_mean, stop_max = np.min(stop_criterion), np.mean(stop_criterion), np.max(stop_criterion)
    print(line(60))
    print('Stop Criterion: min: {}, mean: {}, max: {}'.format(round(stop_min, 3), round(stop_mean, 3), round(stop_max, 3)))
    print(line(60))

    # Set up stats logging
    hist_dict = {'hist_min':[], 'hist_max':[], 'ot_loss':[]}
    losses    = {'ot_loss': [], 'fit_loss': []}
    history   = {'dset_size': dset_size, 'epoch': 0, 'iter': 0, 'losses'   : losses, 'hist_dict': hist_dict}
    config['early_end'] = (200, 320) # Empirical stopping criterion from EWM author
    stop_counter = 0
    
    # Set up progress bar for terminal output and enumeration
    epoch_bar  = tqdm([i for i in range(config['num_epochs'])])
    
    # Training Loop
    for epoch, _ in enumerate(epoch_bar):

        history['epoch'] = epoch

        # Set up memory lists: 
        #     - mu: simple feed-forward distribution 
        #     - transfer: transfer plan given by lists of indices
        # Rule of thumb: do not save the tensors themselves; instead, save the
        #                data as a list and convert it to a tensor as needed.
        mu = [0] * config['mem_size']
        transfer = [0] * config['mem_size']
        mem_idx = 0

        # Compute the Optimal Transport Solver
        for ots_iter in range(1, dset_size//2):

            history['iter'] = ots_iter

            psi_optim.zero_grad()

            # Generate samples from feed-forward distribution
            z_batch = torch.randn(config['batch_size'], config['z_dim']).to(device)
            y_fake  = G(z_batch) # [B, dset_size]
            
#             # Add Gaussian noise to the output of the generator function and to the data with tessellation vectors
            t1 = tess_var*torch.randn(y_fake.shape[0], y_fake.shape[1]).to(device)
            t2 = tess_var*torch.randn(dataloader.shape[0], dataloader.shape[1]).to(device)
            
            y_fake  += t1
            dataloader += t2
            
            # Compute the W1 distance between the model output and the target distribution
            score = my_ops.l1_t(y_fake, dataloader) - psi
            phi, hit = torch.max(score, 1)

            # Remove the tessellation noise from the dataloader
            dataloader -= t2
            
            # Standard loss computation
            # This loss defines the sample mean of the marginal distribution
            # of the dataset. This is the only computation that generalizes.
            loss = -torch.mean(psi[hit])

            # Backprop
            loss.backward()
            psi_optim.step()

            # Update memory tensors
            mu[mem_idx] = z_batch.data.cpu().numpy().tolist()
            transfer[mem_idx] = hit.data.cpu().numpy().tolist()
            mem_idx = (mem_idx + 1) % config['mem_size']

            # Update losses
            history['losses']['ot_loss'].append(loss.item())

            if (ots_iter % 500 == 0):
                avg_loss = np.mean(history['losses']['ot_loss'])
                print('OTS Iteration {} | Epoch {} | Avg Loss Value: {}'.format(ots_iter, epoch, round(avg_loss, 3)))
#             if (iter % 2000 == 0):
#                 # Display histogram stats
#                 hist_dict, stop = utils.update_histogram(transfer, history, config)
#                 # Emperical stopping criterion
#                 if stop:
#                     break

            if ots_iter > (dset_size//3):
                if  stop_min <= np.mean(history['losses']['ot_loss']) <= stop_max:
                    stop_counter += 1
                    break
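            # Early exit: once a third of the dataset-sized iteration budget has
            # passed, OTS stops as soon as the running-average OT loss falls
            # inside the [stop_min, stop_max] band estimated above from the
            # noisy test vectors; stop_counter records how often this fires.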

        # Compute the Optimal Fitting Transport Plan
        for fit_iter in range(config['mem_size']):
            G_optim.zero_grad()

            # Retrieve stored batch of generated samples
            z_batch = torch.tensor(mu[fit_iter]).to(device)
            y_fake  = G(z_batch) # G'(z)
            
            # Get Transfer plan from OTS: T(G_{t-1}(z))
            t_plan = torch.tensor(transfer[fit_iter]).to(device)
            y0_hit = dataloader[t_plan].to(device)
            
#            Tesselate the output of the generator function and the data
#             t1 = tess_var*torch.randn(y_fake.shape[0], y_fake.shape[1]).to(device)
#             t2 = tess_var*torch.randn(y0_hit.shape[0], y0_hit.shape[1]).to(device)
            
#             y_fake *= t1
#             y0_hit *= t1
            
            # Compute Wasserstein distance between G and T
            G_loss = torch.mean(torch.abs(y0_hit - y_fake)) * config['l_dim']

            # Backprop
            G_loss.backward() # Gradient descent
            G_optim.step()

            # Update losses
            history['losses']['fit_loss'].append(G_loss.item())

            # Check if best loss value and save checkpoint
            if 'best_loss' not in history:
                history.update({ 'best_loss' : G_loss.item() })

            best = G_loss.item() < (history['best_loss'] * 0.70)
            if best:
                history['best_loss'] = G_loss.item()
                checkpoint = utils.get_checkpoint(history['epoch'], checkpoint_kwargs, config)
                utils.save_checkpoint(checkpoint, config)

            if (fit_iter % 500 == 0):
                avg_loss = np.mean(history['losses']['fit_loss'])
                print('FIT Iteration {} | Epoch {} | Avg Loss Value: {}'.format(fit_iter, epoch, round(avg_loss,3)))

    # Save a checkpoint at end of training
    checkpoint = utils.get_checkpoint(history['epoch'], checkpoint_kwargs, config)
    utils.save_checkpoint(checkpoint, config)

    # Save training data to csv's after training end
    utils.save_train_hist(history, config, times=None, histogram=history['hist_dict'])
    print("Stop Counter Triggered {} Times".format(stop_counter))

    # For Aiur
    print("I see you have an appetite for destruction.")
    print("And you have learned to use your illusion.")
    print("But I find your lack of control disturbing.")
Example 5
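This fragment comes from the middle of a training script, so G, psi, opt_psi, y_t, my_ops, and the scalars batch_size, nz, d, memory_size, and distance are assumed to be defined earlier. A minimal, hypothetical setup consistent with the fragment (names and values chosen here for illustration only) might look like:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size, nz, d = 64, 100, 784   # batch size, latent dim, flattened data dim
memory_size = 1000                 # length of the OTS replay memory
distance = 'W1'                    # 'W1', 'W2', or 'Hybrid'

# Placeholder target data [N, d]; in the real script this is the flattened dataset.
y_t = torch.randn(60000, d, device=device)

# Simple MLP generator standing in for G.
G = nn.Sequential(nn.Linear(nz, 512), nn.ReLU(), nn.Linear(512, d)).to(device)

# One dual variable per target sample, updated by Adam during OTS.
psi = torch.zeros(len(y_t), device=device, requires_grad=True)
opt_psi = torch.optim.Adam([psi], lr=1e-1)

# my_ops is the W1 CUDA extension loaded as in the examples above; it is only
# needed when distance == 'W1'.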
z_output = torch.randn(batch_size, nz, device=device)
for it in range(1, 101):
    # OTS
    ot_loss = []
    w1_estimate = []
    memory_p = 0
    memory_z = torch.zeros(memory_size, batch_size, nz)
    memory_y = torch.zeros(memory_size, batch_size, dtype=torch.long)
    for ots_iter in range(1, 20001):
        opt_psi.zero_grad()

        z_batch = torch.randn(batch_size, nz, device=device)
        y_fake = G(z_batch)

        if distance == "W1":
            score = -my_ops.l1_t(y_fake, y_t) - psi
        elif distance == "W2" or distance == "Hybrid":
            score = torch.matmul(y_fake, y_t.t()) - psi
        phi, hit = torch.max(score, 1)

        loss = torch.mean(phi) + torch.mean(psi)
        if distance == "W1" or distance == "Hybrid":
            loss_primal = torch.mean(torch.abs(y_fake - y_t[hit])) * d
        elif distance == "W2":
            loss_primal = torch.mean((y_fake - y_t[hit])**2) * d
        loss_back = -torch.mean(psi[hit])  # equivalent to loss
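        # "Equivalent" in the sense that loss and loss_back have the same
        # gradient w.r.t. psi up to a uniform 1/N offset, and shifting every
        # entry of psi by a constant leaves the semi-dual objective unchanged;
        # the cheaper loss_back is therefore backpropagated while loss itself
        # is only logged.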
        loss_back.backward()

        opt_psi.step()

        ot_loss.append(loss.item())