def train_mlp(G_mlp, opt_G, psi, opt_psi, dataloader, device, loader_kwargs,
              config, times, losses, hist_dict, checkpt):
    if config['mnist']:
        import poc
        poc.train_MNIST(G_mlp, opt_G, psi, opt_psi, dataloader, device,
                        loader_kwargs, config, times, losses, hist_dict, checkpt)
        return

    # Set fixed noise vector for testing
    z_fixed = utils.get_z(config, device, sample=False)
    z_fixed.resize_((config['batch_size'], config['z_dim']))

    # Create python version of cpp operation
    if config['dist'] == 'W1':
        from torch.utils.cpp_extension import load
        my_ops = load(name="my_ops",
                      sources=["W1_extension/my_ops.cpp",
                               "W1_extension/my_ops_kernel.cu"],
                      verbose=False)
        import my_ops

    n_dim = len(dataloader)

    # Training loop
    for epoch in range(config['num_epochs']):
        epoch_start_time = time.time()

        # Save list of losses for the end-of-training determination
        loss_memory = []

        # Set up memory tensors: simple feed-forward distribution, transfer plan
        mu = torch.zeros(config['mem_size'], config['batch_size'], config['z_dim'])
        transfer = torch.zeros(config['mem_size'], config['batch_size'], dtype=torch.long)
        mem_idx = 0

        # Compute Optimal Transport Solver (OTS) over every training example
        ot_start = time.time()
        # for ots_iter in range(0, config['dset_size']):
        for ots_iter in range(0, loader_kwargs['batch_size']):
            opt_psi.zero_grad()

            # Generate samples from feed-forward distribution
            z_batch = utils.get_z(config, device, sample=False)
            z_batch.resize_((config['batch_size'], config['z_dim']))
            y_fake = G_mlp(z_batch)  # [B, n_dim]

            # Compute cost between sample batch and target distribution
            if config['dist'] == 'W1':
                score = -my_ops.l1_t(y_fake, dataloader) - psi
            else:
                score = torch.matmul(y_fake, dataloader.t()) - psi  # score: [B, N], psi: [N]

            # phi, hit = torch.max(score, 1)
            phi, hit = torch.min(score, 1)  # [B], [B]

            # Wasserstein distance computation: d(x, y)^p
            if config['dist'] == 'W1':
                loss_primal = torch.mean(torch.abs(y_fake - dataloader[hit])) * config['out_dim']
            else:
                loss_primal = torch.mean((y_fake - dataloader[hit])**2) * config['out_dim']

            # Loss computation
            # loss = -(torch.mean(phi) + torch.mean(psi))  # Testing this
            loss = -torch.mean(psi[hit])  # equiv. to loss?

            # Backprop
            loss.backward()

            # Gradient ascent
            opt_psi.step()

            # Append losses to dict
            losses['ot_loss'].append(loss.item())
            losses['w2_estim'].append(loss_primal.item())

            # Update memory tensors
            mu[mem_idx] = z_batch
            transfer[mem_idx] = hit
            mem_idx = (mem_idx + 1) % config['mem_size']

            if ots_iter % 500 == 0:
                print('OTS Iteration {} | Epoch {}'.format(ots_iter, epoch))
            if ots_iter % (config['num_epochs'] * 10) == 0:
                # Display histogram stats
                hist_dict, stop = utils.update_histogram(
                    transfer, n_dim, epoch, ots_iter, config, losses, hist_dict)
                # Empirical stopping criterion
                if stop:
                    break

        # Compute OTS time and append
        ot_end = time.time()
        times['ot_time'].append(ot_end - ot_start)

        # Compute Fitting Optimal Transport Plan (FIT)
        fit_start = time.time()
        for fit_iter in range(config['mem_size']):
            opt_G.zero_grad()

            # Get stored batch of generated samples
            z_batch = mu[fit_iter].to(device)
            y_fake = G_mlp(z_batch)  # G'(z)

            # Get transfer plan from OTS: T(G_{t-1}(z))
            y0_hit = dataloader[transfer[fit_iter].to(device)]

            # Compute Wasserstein distance between G and T
            if config['dist'] == 'W1':
                loss_g = torch.mean(torch.abs(y0_hit - y_fake)) * config['out_dim']
            else:
                loss_g = torch.mean((y0_hit - y_fake)**2) * config['out_dim']

            # Backprop
            loss_g.backward()

            # Gradient descent
            opt_G.step()

            # Append losses to dict
            losses['g_loss'].append(loss_g.item())
            loss_memory.append(loss_g.item())

            if fit_iter % 500 == 0:
                print('Fit_iter: {} | Epoch: {} | Loss: {:.2f} | Best Loss: {:.2f}'.format(
                    fit_iter, epoch, loss_g, checkpt['best']))

            # Check if best loss value and save checkpoint
            threshold = checkpt['best'] - round(checkpt['best'] * 0.5)
            best = loss_g.item() < threshold
            if best:
                checkpt['best'] = loss_g.item()
                chkpt_dict = utils.checkpoint_dict(fit_iter, epoch, G_mlp, opt_G)
                utils.save_checkpoint(chkpt_dict, best, epoch, -1, config['weights_root'])

            # Save periodic checkpoint
            if fit_iter % (config['num_epochs'] * 5) == 0:
                chkpt_dict = utils.checkpoint_dict(fit_iter, epoch, G_mlp, opt_G)
                utils.save_checkpoint(chkpt_dict, False, epoch, fit_iter, config['weights_root'])

            # Get random sample from G
            if fit_iter % config['num_epochs'] == 0:
                z_rand = utils.get_z(config, device, sample=True)
                z_rand.resize_((config['batch_size'], config['z_dim']))
                sample = G_mlp(z_rand).view(-1, 1, config['imsize'], config['imsize'])
                utils.save_sample(sample, epoch, fit_iter, config['random'])

            # Check if loss is changing - stop training if no change
            if len(loss_memory) > (config['mem_size'] // 2):
                if (loss_g <= (mean(loss_memory) * .999)) and \
                   (loss_g >= (mean(loss_memory) * .995)):
                    break

        # Compute FIT time
        fit_end = time.time()
        times['fit_time'].append(fit_end - fit_start)

        # Compute epoch time
        times['epoch_times'].append(time.time() - epoch_start_time)

        # Output to terminal
        print('Best loss: {}'.format(checkpt['best']))
        print('Epoch_time: {}'.format(time.time() - epoch_start_time))
        print('Num epochs: {}'.format(epoch))
        print("FIT loss: {:.2f}".format(np.mean(losses['g_loss'])))

        # Save fixed sample at end of training epoch
        sample = G_mlp(z_fixed).view(-1, 1, config['imsize'], config['imsize'])
        utils.save_sample(sample, epoch, 0, config['fixed'])

    return [times, losses, hist_dict]
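# --- Illustrative sketch (not part of the training code) ---------------------
# The OTS phase above is stochastic gradient ascent on the semi-discrete OT
# dual: each generated batch is assigned to its nearest target under the
# psi-shifted cost (the "hit"), and psi is adjusted so that, over many
# iterations, every target point is hit roughly equally often. A minimal,
# self-contained toy version (all names here are illustrative):
import torch

torch.manual_seed(0)
N_toy, B_toy, D_toy = 32, 8, 2                      # target points, batch size, dimension
toy_target = torch.randn(N_toy, D_toy)              # stands in for `dataloader`
toy_psi = torch.zeros(N_toy, requires_grad=True)
toy_opt = torch.optim.Adam([toy_psi], lr=1e-1)

for _ in range(2000):
    toy_opt.zero_grad()
    y = torch.randn(B_toy, D_toy)                   # stands in for G_mlp(z_batch)
    cost = torch.cdist(y, toy_target, p=1)          # pairwise L1 cost, [B, N]
    phi, hit = torch.min(cost - toy_psi, dim=1)     # c-transform and assignment
    dual = -(torch.mean(phi) + torch.mean(toy_psi)) # negate the dual for ascent
    dual.backward()
    toy_opt.step()

# After many updates the hit histogram of a batch is roughly uniform.
print(torch.bincount(hit, minlength=N_toy))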
def train(config):
    '''
        Training function for EWM generator model.
    '''
    # Create python version of cpp operation
    # (Credit: Chen, arXiv:1906.03471, GitHub: https://github.com/chen0706/EWM)
    from torch.utils.cpp_extension import load
    my_ops = load(name="my_ops",
                  sources=["W1_extension/my_ops.cpp",
                           "W1_extension/my_ops_kernel.cu"],
                  verbose=False)
    import my_ops

    # Set up GPU device ordinal
    device = torch.device(config['gpu'])

    # Get model kwargs for the convolutional generator
    config['model'] = 'ewm_conv'  # was a no-op comparison (`==`); the conv kwargs are intended here
    emw_kwargs = setup_model.ewm_kwargs(config)

    # Setup convolutional generator model on GPU
    G = ewm.ewm_convG(**emw_kwargs).to(device)
    # ewm.weights_init(G)
    G.train()

    # Setup model optimizer
    model_params = {'g_params': G.parameters()}
    G_optim = utils.get_optim(config, model_params)

    # Testing: MSE loss for conv_generator reconstruction
    loss_fn = nn.MSELoss().to(device)

    # Print G -- make sure it's right before continuing
    print(G)
    input('Press any key to continue')

    # Setup source of structured noise on GPU (trained EWM_MLP generator model)
    # Add these configs to the experiment source file
    if not config['ewm_root']:
        raise Exception('Path to trained EWM_MLP model must be specified')

    # Print the list of available EWM models
    EWM_paths = []
    EWM_root = config['ewm_root']
    for path in os.listdir(EWM_root):
        EWM_paths.append(os.path.join(EWM_root, path))
    print("-" * 60)
    for i in range(len(EWM_paths)):
        EWM_name = EWM_paths[i].split('/')[-1]
        print("\n Exp_{}:".format(str(i)), EWM_name, '\n')
    print("-" * 60)

    # Select the trained model
    model_num = input('Select EWM_MLP model (enter integer): ')
    EWM_dir = EWM_paths[int(model_num)]

    # Create the full path to the EWM model (EWM_dir already includes EWM_root)
    EWM_path = EWM_dir + '/'
    print("Path to EWM Generator Model set as: \n{}".format(EWM_path))

    config_csv = EWM_path + "config.csv"
    config_df = pd.read_csv(config_csv, delimiter=",")

    # Get the model architecture from the config dataframe
    n_layers = int(config_df[config_df['Unnamed: 0'].str.contains("n_layers") == True]['0'].values.item())
    n_hidden = int(config_df[config_df['Unnamed: 0'].str.contains("n_hidden") == True]['0'].values.item())
    l_dim = int(config_df[config_df['Unnamed: 0'].str.contains("l_dim") == True]['0'].values.item())
    im_size = int(config_df[config_df['Unnamed: 0'].str.contains("dataset") == True]['0'].values.item())
    z_dim = int(config_df[config_df['Unnamed: 0'].str.contains("z_dim") == True]['0'].values.item())

    # Model kwargs
    ewm_kwargs = {'z_dim': z_dim, 'fc_sizes': [n_hidden] * n_layers, 'n_out': l_dim}

    # Send model to GPU
    Gz = ewm.ewm_G(**ewm_kwargs).to(device)

    # Load the model checkpoint
    # Get checkpoint name(s)
    EWM_checkpoint_path = EWM_path + 'weights/'
    EWM_checkpoint_names = []
    for file in os.listdir(EWM_checkpoint_path):
        EWM_checkpoint_names.append(os.path.join(EWM_checkpoint_path, file))
    print("-" * 60)
    for i in range(len(EWM_checkpoint_names)):
        name = EWM_checkpoint_names[i].split('/')[-1]
        print("\n {} :".format(str(i)), name, '\n')
    print("-" * 60)
    file_num = input("Select a checkpoint file for EWM_MLP (enter integer): ")
    EWM_checkpoint = EWM_checkpoint_names[int(file_num)]

    # Load the model checkpoint
    # Keys: ['state_dict', 'epoch', 'optimizer']
    checkpoint = torch.load(EWM_checkpoint)

    # Load the model's state dictionary
    Gz.load_state_dict(checkpoint['state_dict'])

    # Use the code_vector model in evaluation mode -- no need for gradients here
    Gz.eval()
    print(Gz)
    input('Press any key to continue')

    # Set up full_dataloader (single batch)
    dataloader = utils.get_dataloader(config).to(device)  # Full Dataloader
    dset_size = len(dataloader)

    # Flatten the dataloader into a Tensor of shape [dset_size, l_dim]
    # dataloader = dataloader.view(dset_size, -1).to(device)

    # Set up psi optimizer (psi must be a leaf tensor on the target device)
    psi = torch.zeros(dset_size, device=device, requires_grad=True)
    psi_optim = torch.optim.Adam([psi], lr=config['psi_lr'])

    # Set up directories for saving training stats and outputs
    config = utils.directories(config)

    # Set up dict for saving checkpoints
    checkpoint_kwargs = {'G': G, 'G_optim': G_optim}

    # Set up stats logging
    hist_dict = {'hist_min': [], 'hist_max': [], 'ot_loss': []}
    losses = {'ot_loss': [], 'fit_loss': []}
    history = {'dset_size': dset_size, 'epoch': 0, 'iter': 0,
               'losses': losses, 'hist_dict': hist_dict}
    config['early_end'] = (200, 320)  # Empirical stopping criterion from the EWM author
    stop_counter = 0

    # Compute how the input vectors need to be reshaped, based on the conv_G input layer
    in_f = G.main[0][2].in_channels
    out_f = G.main[0][2].out_channels
    # Set the height and width of the feature maps. Note: manually setting this to 8 is
    # hackish, but computing the actual value requires knowing the number of layers in
    # the AutoEncoder whose code layer was used to train the generator model being loaded
    # here. I'm avoiding loading multiple paths and dataframes by simply setting it to 8,
    # but maybe you can do better than I did...
    H = W = 8
    print('Vectors will be reshaped as: [{}] --> [{},{},{}]'.format(l_dim, in_f, H, W))

    # Set a fixed feature tensor for testing
    noise = torch.randn(config['sample_size'], config['z_dim']).to(device)
    z_fixed = Gz(noise).view(-1, in_f, H, W).to(device)

    # Training Loop
    input('\nPress any key to launch -- good luck out there\n')

    # Set up progress bar for terminal output and enumeration
    epoch_bar = tqdm([i for i in range(config['num_epochs'])])
    for epoch, _ in enumerate(epoch_bar):
        history['epoch'] = epoch

        # Set up memory lists:
        #   - mu: simple feed-forward distribution
        #   - transfer: transfer plan given by lists of indices
        # Rule of thumb: do not save the tensors themselves; instead, save the
        # data as a list and convert it to a tensor as needed.
        mu = [0] * config['mem_size']
        transfer = [0] * config['mem_size']
        mem_idx = 0

        # Compute the Optimal Transport Solver
        print("\nOptimal Transport Solver")
        ots_bar = tqdm([i for i in range(dset_size // 10)])
        for ots_iter, _ in enumerate(ots_bar):
            history['iter'] = ots_iter
            psi_optim.zero_grad()

            # Generate samples from the code_vector distribution
            z_batch = torch.randn(config['batch_size'], config['z_dim']).to(device)
            z_batch = Gz(z_batch).view(-1, in_f, H, W)

            # Push structured noise vector through convolutional generator
            # y_fake = G(z_batch).view(config['batch_size'], -1)  # Flatten the output to match dataloader
            y_fake = G(z_batch)

            # Compute the W1 distance between the model output and the target distribution
            score = my_ops.l1_t(y_fake, dataloader) - psi
            phi, hit = torch.max(score, 1)

            # Standard loss computation
            # This loss defines the sample mean of the marginal distribution
            # of the dataset. This is the only computation that generalizes.
            loss = -torch.mean(psi[hit])

            # Backprop
            loss.backward()
            psi_optim.step()

            # Update memory tensors (lists)
            mu[mem_idx] = z_batch.data.cpu().numpy().tolist()
            transfer[mem_idx] = hit.data.cpu().numpy().tolist()
            mem_idx = (mem_idx + 1) % config['mem_size']

            # Update losses
            history['losses']['ot_loss'].append(loss.item())

            if ots_iter % 50 == 0:
                avg_loss = np.mean(history['losses']['ot_loss'])
                # print('OTS Iteration {} | Epoch {} | Avg Loss Value: {}'.format(ots_iter, epoch, round(avg_loss, 3)))
            if ots_iter % 2000 == 0:
                # Occasionally save a random sample from the generator during OTS
                sample = y_fake[0:config['sample_size']].view(-1, 1, config['dataset'], config['dataset'])
                utils.save_sample(sample, epoch, ots_iter, config['random_samples'])
                # # Display histogram stats
                # hist_dict, stop = utils.update_histogram(transfer, history, config)
                # # Empirical stopping criterion
                # if stop:
                #     break

        # Compute the Optimal Fitting Transport Plan
        print("\nFitting Optimal Transport Plan")
        fit_bar = tqdm([i for i in range(config['mem_size'])])
        for fit_iter, _ in enumerate(fit_bar):
            G_optim.zero_grad()

            # Retrieve stored batch of generated samples
            # z_batch = torch.tensor(mu[fit_iter]).view(-1, in_f, H, W).to(device)
            z_batch = torch.tensor(mu[fit_iter]).to(device)

            # Flatten the model output to match dataloader
            # y_fake = G(z_batch).view(config['batch_size'], -1)
            y_fake = G(z_batch)

            # Get transfer plan from OTS: T(G_{t-1}(Gz))
            t_plan = torch.tensor(transfer[fit_iter])
            y0_hit = dataloader[t_plan].to(device)

            # Compute Wasserstein distance between G and T
            # G_loss = torch.mean(torch.abs(y0_hit - y_fake)) * l_dim
            G_loss = loss_fn(y_fake, y0_hit)

            # Backprop
            G_loss.backward()

            # Gradient descent
            G_optim.step()

            # Update losses
            history['losses']['fit_loss'].append(G_loss.item())

            # Check if best loss value and save checkpoint
            if 'best_loss' not in history:
                history.update({'best_loss': G_loss.item()})
            best = G_loss.item() < (history['best_loss'] * 0.5)
            if best:
                history['best_loss'] = G_loss.item()
                checkpoint = utils.get_checkpoint(history['epoch'], checkpoint_kwargs, config)
                utils.save_checkpoint(checkpoint, config)

            if fit_iter % 50 == 0:
                avg_loss = np.mean(history['losses']['fit_loss'])
                # print('FIT Iteration {} | Epoch {} | Avg Loss Value: {}'.format(fit_iter, epoch, round(avg_loss, 3)))

        # Save a fixed sample of the generator's output at the end of FIT
        sample = G(z_fixed).view(-1, 1, config['dataset'], config['dataset'])
        utils.save_sample(sample, epoch, fit_iter, config['fixed_samples'])

    # Save a checkpoint at end of training
    checkpoint = utils.get_checkpoint(history['epoch'], checkpoint_kwargs, config)
    utils.save_checkpoint(checkpoint, config)

    # Save training data to csv's after training ends
    utils.save_train_hist(history, config, times=None, histogram=history['hist_dict'])
    print("Stop Counter Triggered {} Times".format(stop_counter))

    # For Spike
    print("See you, Space Cowboy")
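# --- Pure-PyTorch stand-in for the W1 cost kernel (assumption) ----------------
# The compiled `my_ops.l1_t` extension is assumed here to return the [B, N]
# matrix of pairwise L1 distances between generator outputs and target rows;
# check W1_extension/my_ops.cpp before relying on this. A fallback like the one
# below can be handy when the CUDA extension will not build, at the price of
# extra memory when N is large.
import torch

def l1_t_fallback(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Pairwise L1 distances between rows of a ([B, ...]) and b ([N, ...])."""
    return torch.cdist(a.flatten(1), b.flatten(1), p=1)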
def train_MNIST(config):
    # Create python version of cpp operation
    if config['dist'] == 'W1':
        print("Building C++ extension for W1 (requires PyTorch >= 1.0.0)...")
        from torch.utils.cpp_extension import load
        my_ops = load(name="my_ops",
                      sources=["W1_extension/my_ops.cpp",
                               "W1_extension/my_ops_kernel.cu"],
                      verbose=False)
        import my_ops
        print("Building complete")

    # Centralize stats logging
    times, losses, hist_dict, checkpt = utils.centralized_logs()

    # Select device
    device = torch.device(config['gpu'])

    # Update config dict with MNIST params and get MNIST dataset as one batch
    # config, mnist_data = utils.MNIST(config)
    # Returns MNIST training data as a single batch of data
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))])
    dataset = dset.MNIST(root=config['data_root'], train=True,
                         download=True, transform=transform)

    def get_data(dataset):
        full_dataloader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset))
        for y_batch, l_batch in full_dataloader:
            return y_batch

    y_t = get_data(dataset)
    n_dim = len(y_t)
    mnist_data = y_t.view(n_dim, -1).to(device)
    config.update({'dset_size': n_dim, 'imsize': 28, 'out_dim': 784,
                   'batch_size': 64, 'sample_size': 16, 'early_end': (200, 320)})

    # Set MLP architecture
    G_arch = utils.get_mlp_arch(config)

    # Create G_model
    G_mlp = utils.get_model(G_arch, device)

    # Get optimizer
    opt_G = utils.get_optim(G_arch, G_mlp.parameters(), MNIST=True)
    print(G_mlp)

    # Initialize G_model weights
    # G_mlp = utils.weights_init(G_mlp, MNIST=True)
    def initialize_weights(net):
        for m in net.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)
                if hasattr(m, "bias") and m.bias is not None:
                    m.bias.data.zero_()

    initialize_weights(G_mlp)

    # Create labels for experiment
    labels = utils.create_labels(config)

    # Update config with labels and save locations
    config = utils.update_with_labels(config, labels)

    # Setup psi optimizer
    psi = torch.zeros(n_dim, requires_grad=True, device=device)
    opt_psi = torch.optim.Adam([psi], lr=1e-1)

    # Set fixed noise vector for testing
    z_fixed = utils.get_z(config, device, sample=False)
    z_fixed = z_fixed.resize_((config['batch_size'], config['z_dim'])).to(device)

    # Training loop
    for epoch in range(config['num_epochs']):
        epoch_start_time = time.time()

        # Save list of losses for the end-of-training determination
        loss_memory = []

        # Set up memory tensors: simple feed-forward distribution, transfer plan
        mu = torch.zeros(config['mem_size'], config['batch_size'], config['z_dim'])
        transfer = torch.zeros(config['mem_size'], config['batch_size'], dtype=torch.long)
        mem_idx = 0

        # Compute Optimal Transport Solver (OTS) over every training example
        ot_start = time.time()
        # for ots_iter in range(0, config['dset_size']):
        for ots_iter in range(1, 20001):
            opt_psi.zero_grad()

            # Generate samples from feed-forward distribution
            z_batch = utils.get_z(config, device, sample=False)
            z_batch = z_batch.resize_((config['batch_size'], config['z_dim'])).to(device)
            y_fake = G_mlp(z_batch)  # [B, n_dim]

            # Compute cost between sample batch and target distribution
            if config['dist'] == 'W1':
                score = -my_ops.l1_t(y_fake, mnist_data) - psi
            else:
                score = torch.matmul(y_fake, mnist_data.t()) - psi  # score: [B, N], psi: [N]

            phi, hit = torch.max(score, 1)
            # phi, hit = torch.min(score, 1)  # [B], [B]

            # Wasserstein distance computation: d(x, y)^p
            if config['dist'] == 'W1':
                loss_primal = torch.mean(torch.abs(y_fake - mnist_data[hit])) * config['out_dim']
            else:
                loss_primal = torch.mean((y_fake - mnist_data[hit])**2) * config['out_dim']

            # Loss computation
            # loss = (torch.mean(phi) + torch.mean(psi))  # Testing this
            loss = -torch.mean(psi[hit])  # equiv. to loss?

            # Backprop
            loss.backward()

            # Gradient ascent
            opt_psi.step()

            # Append losses to dict
            losses['ot_loss'].append(loss.item())
            losses['w2_estim'].append(loss_primal.item())

            # Update memory tensors
            mu[mem_idx] = z_batch
            transfer[mem_idx] = hit
            mem_idx = (mem_idx + 1) % config['mem_size']

            if ots_iter % 500 == 0:
                print('OTS Iteration {} | Epoch {}'.format(ots_iter, epoch))
            if ots_iter % 2000 == 0:
                # Display histogram stats
                hist_dict, stop = utils.update_histogram(
                    transfer, n_dim, epoch, ots_iter, config, losses, hist_dict)
                # Empirical stopping criterion
                if stop:
                    break

        # Compute OTS time and append
        ot_end = time.time()
        times['ot_time'].append(ot_end - ot_start)

        # Compute Fitting Optimal Transport Plan (FIT)
        fit_start = time.time()
        for fit_iter in range(config['mem_size']):
            opt_G.zero_grad()

            # Get stored batch of generated samples
            z_batch = mu[fit_iter].to(device)
            y_fake = G_mlp(z_batch)  # G'(z)

            # Get transfer plan from OTS: T(G_{t-1}(z))
            y0_hit = mnist_data[transfer[fit_iter].to(device)]

            # Compute Wasserstein distance between G and T
            if config['dist'] == 'W1':
                loss_g = torch.mean(torch.abs(y0_hit - y_fake)) * config['out_dim']
            else:
                loss_g = torch.mean((y0_hit - y_fake)**2) * config['out_dim']

            # Backprop
            loss_g.backward()

            # Gradient descent
            opt_G.step()

            # Append losses to dict
            losses['g_loss'].append(loss_g.item())
            loss_memory.append(loss_g.item())

            if fit_iter % 500 == 0:
                print('Fit_iter: {} | Epoch: {} | Loss: {:.2f} | Best Loss: {:.2f}'.format(
                    fit_iter, epoch, loss_g, checkpt['best']))

            # Check if best loss value and save checkpoint
            # threshold = (checkpt['best'] - round(checkpt['best'] * 0.5))
            # best = (loss_g.item() < threshold)
            # if best:
            #     checkpt['best'] = loss_g.item()
            #     chkpt_dict = utils.checkpoint_dict(fit_iter, epoch, G_mlp, opt_G)
            #     utils.save_checkpoint(chkpt_dict, best, epoch, -1, config['weights_root'])

            # Save periodic checkpoint
            # if (fit_iter % 2000 == 0):
            #     chkpt_dict = utils.checkpoint_dict(fit_iter, epoch, G_mlp, opt_G)
            #     utils.save_checkpoint(chkpt_dict, False, epoch, fit_iter, config['weights_root'])

            # Get random sample from G
            if fit_iter % 1000 == 0:
                z_rand = utils.get_z(config, device, sample=True)
                z_rand = z_rand.resize_((config['sample_size'], config['z_dim'])).to(device)
                sample = G_mlp(z_rand).view(-1, 1, config['imsize'], config['imsize'])
                utils.save_sample(sample, epoch, fit_iter, config['random'])

            # Check if loss is changing - stop training if no change
            if len(loss_memory) > (config['mem_size'] // 2):
                if (loss_g <= (mean(loss_memory) * .999)) and \
                   (loss_g >= (mean(loss_memory) * .995)):
                    break

        # Compute FIT time
        fit_end = time.time()
        times['fit_time'].append(fit_end - fit_start)

        # Compute epoch time
        times['epoch_times'].append(time.time() - epoch_start_time)

        # Output to terminal
        print('Best loss: {}'.format(checkpt['best']))
        print('Epoch_time: {}'.format(time.time() - epoch_start_time))
        print('Num epochs: {}'.format(epoch))
        print("FIT loss: {:.2f}".format(np.mean(losses['g_loss'])))

        # Save fixed sample at end of training epoch
        sample = G_mlp(z_fixed).view(-1, 1, config['imsize'], config['imsize'])
        utils.save_sample(sample, epoch, 0, config['fixed'])

    # Save training data to csv after training completion
    utils.save_stats(times, losses, hist_dict, G_arch, config['save_root'])
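# --- Illustrative helper (not in utils) ---------------------------------------
# Both MLP training loops stop the FIT phase when the current generator loss
# settles into a narrow band around the running mean (between 99.5% and 99.9%
# of it), once enough losses have accumulated. The band check, pulled out into
# a small function for clarity (the length gate stays in the caller):
from statistics import mean

def loss_plateaued(current_loss, loss_history, lo=0.995, hi=0.999):
    """True when current_loss lies within [lo, hi] of the running mean loss."""
    if not loss_history:
        return False
    avg = mean(loss_history)
    return avg * lo <= current_loss <= avg * hi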
def train(config):
    '''
        Training function for EWM generator model.
    '''
    # Create python version of cpp operation
    # (Credit: Chen, arXiv:1906.03471, GitHub: https://github.com/chen0706/EWM)
    from torch.utils.cpp_extension import load
    my_ops = load(name="my_ops",
                  sources=["W1_extension/my_ops.cpp",
                           "W1_extension/my_ops_kernel.cu"],
                  verbose=False)
    import my_ops

    # Set up GPU device ordinal - if this fails, try the CUDA_LAUNCH_BLOCKING environment variable
    device = torch.device(config['gpu'])

    # Get model kwargs
    emw_kwargs = setup_model.ewm_kwargs(config)

    # Setup model on GPU
    G = ewm_G(**emw_kwargs).to(device)
    G.weights_init()
    print(G)
    input('Press any key to launch')

    # Setup model optimizer
    model_params = {'g_params': G.parameters()}
    G_optim = utils.get_optim(config, model_params)

    # Set up full_dataloader (single batch)
    dataloader = utils.get_dataloader(config)  # Full Dataloader
    dset_size = len(dataloader)

    # Flatten the dataloader into a Tensor of shape [dset_size, l_dim]
    dataloader = dataloader.view(dset_size, -1).to(device)

    # Set up psi optimizer (psi must be a leaf tensor on the target device)
    psi = torch.zeros(dset_size, device=device, requires_grad=True)
    psi_optim = torch.optim.Adam([psi], lr=config['psi_lr'])

    # Set up directories for saving training stats and outputs
    config = utils.directories(config)

    # Set up dict for saving checkpoints
    checkpoint_kwargs = {'G': G, 'G_optim': G_optim}

    # Standard deviation of the tessellation vectors (config stores the variance)
    tess_var = config['tess_var']**0.5

    # Compute the stopping criterion using a set of test vectors
    # and computing the 'ideal' loss between the test/target.
    print(line(60))
    print("Computing stopping criterion")
    print(line(60))
    stop_criterion = []
    test_loader = utils.get_test_loader(config)
    for _, test_vecs in enumerate(test_loader):
        # Add Gaussian noise to test_vectors
        test_vecs = test_vecs.view(config['batch_size'], -1).to(device)  # 'Perfect' generator model
        t1 = tess_var * torch.randn(test_vecs.shape[0], test_vecs.shape[1]).to(device)
        test_vecs += t1
        # Add Gaussian noise to target data
        t2 = tess_var * torch.randn(dataloader.shape[0], dataloader.shape[1]).to(device)
        test_target = dataloader + t2
        # Compute the stop score
        stop_score = my_ops.l1_t(test_vecs, test_target)
        stop_loss = -torch.mean(stop_score)
        stop_criterion.append(stop_loss.cpu().detach().numpy())
    del test_loader

    # Set stopping criterion variables
    stop_min, stop_mean, stop_max = np.min(stop_criterion), np.mean(stop_criterion), np.max(stop_criterion)
    print(line(60))
    print('Stop Criterion: min: {}, mean: {}, max: {}'.format(
        round(stop_min, 3), round(stop_mean, 3), round(stop_max, 3)))
    print(line(60))

    # Set up stats logging
    hist_dict = {'hist_min': [], 'hist_max': [], 'ot_loss': []}
    losses = {'ot_loss': [], 'fit_loss': []}
    history = {'dset_size': dset_size, 'epoch': 0, 'iter': 0,
               'losses': losses, 'hist_dict': hist_dict}
    config['early_end'] = (200, 320)  # Empirical stopping criterion from the EWM author
    stop_counter = 0

    # Set up progress bar for terminal output and enumeration
    epoch_bar = tqdm([i for i in range(config['num_epochs'])])

    # Training Loop
    for epoch, _ in enumerate(epoch_bar):
        history['epoch'] = epoch

        # Set up memory lists:
        #   - mu: simple feed-forward distribution
        #   - transfer: transfer plan given by lists of indices
        # Rule of thumb: do not save the tensors themselves; instead, save the
        # data as a list and convert it to a tensor as needed.
        mu = [0] * config['mem_size']
        transfer = [0] * config['mem_size']
        mem_idx = 0

        # Compute the Optimal Transport Solver
        for ots_iter in range(1, dset_size // 2):
            history['iter'] = ots_iter
            psi_optim.zero_grad()

            # Generate samples from feed-forward distribution
            z_batch = torch.randn(config['batch_size'], config['z_dim']).to(device)
            y_fake = G(z_batch)  # [B, dset_size]

            # Add Gaussian noise (tessellation vectors) to the generator output and to the data
            t1 = tess_var * torch.randn(y_fake.shape[0], y_fake.shape[1]).to(device)
            t2 = tess_var * torch.randn(dataloader.shape[0], dataloader.shape[1]).to(device)
            y_fake += t1
            dataloader += t2

            # Compute the W1 distance between the model output and the target distribution
            score = my_ops.l1_t(y_fake, dataloader) - psi
            phi, hit = torch.max(score, 1)

            # Remove the tessellation from the dataloader
            dataloader -= t2

            # Standard loss computation
            # This loss defines the sample mean of the marginal distribution
            # of the dataset. This is the only computation that generalizes.
            loss = -torch.mean(psi[hit])

            # Backprop
            loss.backward()
            psi_optim.step()

            # Update memory tensors
            mu[mem_idx] = z_batch.data.cpu().numpy().tolist()
            transfer[mem_idx] = hit.data.cpu().numpy().tolist()
            mem_idx = (mem_idx + 1) % config['mem_size']

            # Update losses
            history['losses']['ot_loss'].append(loss.item())

            if ots_iter % 500 == 0:
                avg_loss = np.mean(history['losses']['ot_loss'])
                print('OTS Iteration {} | Epoch {} | Avg Loss Value: {}'.format(
                    ots_iter, epoch, round(avg_loss, 3)))
            # if (iter % 2000 == 0):
            #     # Display histogram stats
            #     hist_dict, stop = utils.update_histogram(transfer, history, config)
            #     # Empirical stopping criterion
            #     if stop:
            #         break

            if ots_iter > (dset_size // 3):
                if stop_min <= np.mean(history['losses']['ot_loss']) <= stop_max:
                    stop_counter += 1
                    break

        # Compute the Optimal Fitting Transport Plan
        for fit_iter in range(config['mem_size']):
            G_optim.zero_grad()

            # Retrieve stored batch of generated samples
            z_batch = torch.tensor(mu[fit_iter]).to(device)
            y_fake = G(z_batch)  # G'(z)

            # Get transfer plan from OTS: T(G_{t-1}(z))
            t_plan = torch.tensor(transfer[fit_iter]).to(device)
            y0_hit = dataloader[t_plan].to(device)

            # Tessellate the output of the generator function and the data
            # t1 = tess_var * torch.randn(y_fake.shape[0], y_fake.shape[1]).to(device)
            # t2 = tess_var * torch.randn(y0_hit.shape[0], y0_hit.shape[1]).to(device)
            # y_fake *= t1
            # y0_hit *= t1

            # Compute Wasserstein distance between G and T
            G_loss = torch.mean(torch.abs(y0_hit - y_fake)) * config['l_dim']

            # Backprop
            G_loss.backward()

            # Gradient descent
            G_optim.step()

            # Update losses
            history['losses']['fit_loss'].append(G_loss.item())

            # Check if best loss value and save checkpoint
            if 'best_loss' not in history:
                history.update({'best_loss': G_loss.item()})
            best = G_loss.item() < (history['best_loss'] * 0.70)
            if best:
                history['best_loss'] = G_loss.item()
                checkpoint = utils.get_checkpoint(history['epoch'], checkpoint_kwargs, config)
                utils.save_checkpoint(checkpoint, config)

            if fit_iter % 500 == 0:
                avg_loss = np.mean(history['losses']['fit_loss'])
                print('FIT Iteration {} | Epoch {} | Avg Loss Value: {}'.format(
                    fit_iter, epoch, round(avg_loss, 3)))

    # Save a checkpoint at end of training
    checkpoint = utils.get_checkpoint(history['epoch'], checkpoint_kwargs, config)
    utils.save_checkpoint(checkpoint, config)

    # Save training data to csv's after training ends
    utils.save_train_hist(history, config, times=None, histogram=history['hist_dict'])
    print("Stop Counter Triggered {} Times".format(stop_counter))

    # For Aiur
    print("I see you have an appetite for destruction.")
    print("And you have learned to use your illusion.")
    print("But I find your lack of control disturbing.")
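# --- Illustrative sketch of the tessellation trick (names are illustrative) ---
# In the loop above, both the generator output and the target batch are
# perturbed with zero-mean Gaussian noise of variance config['tess_var'] before
# the transport cost is computed, and the same noise scale is used when
# estimating the stopping band. As a standalone helper this is simply:
import torch

def tessellate(x: torch.Tensor, tess_var: float) -> torch.Tensor:
    """Add zero-mean Gaussian noise with variance tess_var to each entry."""
    return x + (tess_var ** 0.5) * torch.randn_like(x)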
z_output = torch.randn(batch_size, nz, device=device)

for it in range(1, 101):
    # OTS
    ot_loss = []
    w1_estimate = []
    memory_p = 0
    memory_z = torch.zeros(memory_size, batch_size, nz)
    memory_y = torch.zeros(memory_size, batch_size, dtype=torch.long)
    for ots_iter in range(1, 20001):
        opt_psi.zero_grad()
        z_batch = torch.randn(batch_size, nz, device=device)
        y_fake = G(z_batch)
        if distance == "W1":
            score = -my_ops.l1_t(y_fake, y_t) - psi
        elif distance == "W2" or distance == "Hybrid":
            score = torch.matmul(y_fake, y_t.t()) - psi
        phi, hit = torch.max(score, 1)
        loss = torch.mean(phi) + torch.mean(psi)
        if distance == "W1" or distance == "Hybrid":
            loss_primal = torch.mean(torch.abs(y_fake - y_t[hit])) * d
        elif distance == "W2":
            loss_primal = torch.mean((y_fake - y_t[hit])**2) * d
        loss_back = -torch.mean(psi[hit])  # equivalent to loss
        loss_back.backward()
        opt_psi.step()
        ot_loss.append(loss.item())
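# --- Why the W2 branch can use an inner-product score (illustrative check) ----
# With the quadratic cost, argmin_j(||y - x_j||^2 - psi_j) equals
# argmax_j(<y, x_j> - psi'_j) with psi'_j = (||x_j||^2 - psi_j) / 2, because
# ||y||^2 does not depend on j. So the matmul-based score above is the
# quadratic-cost assignment with a reparameterised dual variable. Small
# numerical check with illustrative toy tensors:
import torch

torch.manual_seed(0)
y_toy, x_toy, psi_toy = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16)
_, hit_a = torch.min(torch.cdist(y_toy, x_toy) ** 2 - psi_toy, dim=1)
psi_prime = (x_toy.pow(2).sum(dim=1) - psi_toy) / 2
_, hit_b = torch.max(y_toy @ x_toy.t() - psi_prime, dim=1)
assert torch.equal(hit_a, hit_b)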