def train(self, conditional=True):
    if conditional:
        print('USING CONDITIONAL DSM')

    if self.config.data.random_flip is False:
        tran_transform = test_transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])
    else:
        tran_transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])
        test_transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

    if self.config.data.dataset == 'CIFAR10':
        dataset = CIFAR10(os.path.join(self.args.run, 'datasets'), train=True, download=True,
                          transform=tran_transform)
    elif self.config.data.dataset == 'MNIST':
        print('RUNNING REDUCED MNIST')
        dataset = MNIST(os.path.join(self.args.run, 'datasets'), train=True, download=True,
                        transform=tran_transform)
    elif self.config.data.dataset == 'FashionMNIST':
        dataset = FashionMNIST(os.path.join(self.args.run, 'datasets'), train=True, download=True,
                               transform=tran_transform)
    elif self.config.data.dataset == 'MNIST_transferBaseline':
        # use the same dataset as transfer_nets.py
        # we can also use the train dataset since the digits are unseen anyway
        dataset = MNIST(os.path.join(self.args.run, 'datasets'), train=False, download=True,
                        transform=test_transform)
        print('TRANSFER BASELINES !! Subset size: ' + str(self.subsetSize))
    elif self.config.data.dataset == 'CIFAR10_transferBaseline':
        # use the same dataset as transfer_nets.py
        # we can also use the train dataset since the digits are unseen anyway
        dataset = CIFAR10(os.path.join(self.args.run, 'datasets'), train=False, download=True,
                          transform=test_transform)
        print('TRANSFER BASELINES !! Subset size: ' + str(self.subsetSize))
    elif self.config.data.dataset == 'FashionMNIST_transferBaseline':
        # use the same dataset as transfer_nets.py
        # we can also use the train dataset since the digits are unseen anyway
        dataset = FashionMNIST(os.path.join(self.args.run, 'datasets'), train=False, download=True,
                               transform=test_transform)
        print('TRANSFER BASELINES !! Subset size: ' + str(self.subsetSize))
    else:
        raise ValueError('Unknown config dataset {}'.format(self.config.data.dataset))

    # apply collation
    if self.config.data.dataset in ['MNIST', 'CIFAR10', 'FashionMNIST']:
        collate_helper = lambda batch: my_collate(batch, nSeg=self.nSeg)
        print('Subset size: ' + str(self.subsetSize))
        id_range = list(range(self.subsetSize))
        dataset = torch.utils.data.Subset(dataset, id_range)
        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                num_workers=0, collate_fn=collate_helper)
    elif self.config.data.dataset in ['MNIST_transferBaseline', 'CIFAR10_transferBaseline',
                                      'FashionMNIST_transferBaseline']:
        # trains a model on only digits 8, 9 from scratch
        print('Subset size: ' + str(self.subsetSize))
        id_range = list(range(self.subsetSize))
        dataset = torch.utils.data.Subset(dataset, id_range)
        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                num_workers=0, drop_last=True, collate_fn=my_collate_rev)
        print('loaded reduced subset')
    else:
        raise ValueError('Unknown config dataset {}'.format(self.config.data.dataset))

    self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

    # define the g network
    energy_net_finalLayer = torch.ones((self.config.data.image_size * self.config.data.image_size,
                                        self.nSeg)).to(self.config.device)
    energy_net_finalLayer.requires_grad_()

    # define the f network
    enet = RefineNetDilated(self.config).to(self.config.device)
    enet = torch.nn.DataParallel(enet)

    # training
    optimizer = self.get_optimizer(list(enet.parameters()) + [energy_net_finalLayer])
    step = 0
    loss_track_epochs = []
    for epoch in range(self.config.training.n_epochs):
        loss_vals = []
        for i, (X, y) in enumerate(dataloader):
            step += 1
            enet.train()
            X = X.to(self.config.device)
            X = X / 256. * 255. + torch.rand_like(X) / 256.
            if self.config.data.logit_transform:
                X = self.logit_transform(X)

            y -= y.min()  # need to ensure labels start at zero
            if conditional:
                loss = conditional_dsm(enet, X, y, energy_net_finalLayer, sigma=0.01)
            else:
                loss = dsm(enet, X, sigma=0.01)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            logging.info("step: {}, loss: {}, maxLabel: {}".format(step, loss.item(), y.max()))
            loss_vals.append(loss.item())
            loss_track_epochs.append(loss.item())

            if step >= self.config.training.n_iters:
                # save final checkpoints for distribution!
                states = [
                    enet.state_dict(),
                    optimizer.state_dict(),
                ]
                torch.save(states, os.path.join(self.args.checkpoints, 'checkpoint_{}.pth'.format(step)))
                torch.save(states, os.path.join(self.args.checkpoints, 'checkpoint.pth'))
                torch.save([energy_net_finalLayer],
                           os.path.join(self.args.checkpoints, 'finalLayerweights_.pth'))
                pickle.dump(energy_net_finalLayer,
                            open(os.path.join(self.args.checkpoints, 'finalLayerweights.p'), 'wb'))
                return 0

            if step % self.config.training.snapshot_freq == 0:
                print('checkpoint at step: {}'.format(step))
                # save checkpoint for transfer learning!
                torch.save([energy_net_finalLayer], os.path.join(self.args.log, 'finalLayerweights_.pth'))
                pickle.dump(energy_net_finalLayer,
                            open(os.path.join(self.args.log, 'finalLayerweights.p'), 'wb'))
                states = [
                    enet.state_dict(),
                    optimizer.state_dict(),
                ]
                torch.save(states, os.path.join(self.args.log, 'checkpoint_{}.pth'.format(step)))
                torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))

        if self.config.data.dataset in ['MNIST_transferBaseline', 'CIFAR10_transferBaseline']:
            # save loss track during this epoch for the transfer baseline
            pickle.dump(loss_vals,
                        open(os.path.join(self.args.run,
                                          self.args.dataset + '_Baseline_Size' + str(self.subsetSize) +
                                          '_Seed' + str(self.seed) + '.p'), 'wb'))

    if self.config.data.dataset in ['MNIST_transferBaseline', 'CIFAR10_transferBaseline']:
        # save loss track across all epochs for the transfer baseline
        pickle.dump(loss_track_epochs,
                    open(os.path.join(self.args.run,
                                      self.args.dataset + '_Baseline_epochs_Size' + str(self.subsetSize) +
                                      '_Seed' + str(self.seed) + '.p'), 'wb'))

    # save final checkpoints for distribution!
    states = [
        enet.state_dict(),
        optimizer.state_dict(),
    ]
    torch.save(states, os.path.join(self.args.checkpoints, 'checkpoint_{}.pth'.format(step)))
    torch.save(states, os.path.join(self.args.checkpoints, 'checkpoint.pth'))
    torch.save([energy_net_finalLayer], os.path.join(self.args.checkpoints, 'finalLayerweights_.pth'))
    pickle.dump(energy_net_finalLayer,
                open(os.path.join(self.args.checkpoints, 'finalLayerweights.p'), 'wb'))
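
# Hedged sketch (an assumption, not the repo's dsm / conditional_dsm): a minimal
# denoising score matching loss of the kind called in train() above, written for
# an energy-based parameterization. The model score at a noise-perturbed input
# is regressed onto -(x_tilde - x) / sigma^2. Relies on the module-level torch
# import used throughout this file; `energy_net` maps a batch to per-sample energies.
def dsm_loss_sketch(energy_net, samples, sigma=0.01):
    noise = torch.randn_like(samples) * sigma
    perturbed = (samples + noise).requires_grad_(True)
    energy = energy_net(perturbed).sum()
    score = -torch.autograd.grad(energy, perturbed, create_graph=True)[0]
    target = -noise / (sigma ** 2)  # score of the Gaussian perturbation kernel
    return 0.5 * ((score - target) ** 2).flatten(1).sum(-1).mean()
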
def finalize(self, dkef, tb_logger, train_data, val_data, test_data, collate_fn, train_mode):
    lambda_params = [param for (name, param) in dkef.named_parameters() if "lambd" in name]
    optimizer = optim.Adam(lambda_params, lr=0.001)
    batch_size = self.config.training.fval_batch_size
    val_loader = DataLoader(
        val_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        collate_fn=collate_fn,
    )
    dkef.save_alpha_matrices(train_data, collate_fn, self.config.device, override=True)

    def energy_net(inputs):
        return -dkef(None, inputs, stage="finalize")

    step = 0
    while step < 1000:
        for val_batch in val_loader:
            if step >= 1000:
                break
            val_batch = val_batch.to(self.config.device)
            if train_mode == "exact":
                val_loss = exact_score_matching(energy_net, val_batch, train=True)
            elif train_mode == "sliced":
                val_loss, _, _ = single_sliced_score_matching(energy_net, val_batch)
            elif train_mode == "sliced_fd":
                val_loss = efficient_score_matching_conjugate(energy_net, val_batch)
            elif train_mode == "sliced_VR":
                val_loss, _, _ = sliced_VR_score_matching(energy_net, val_batch)
            elif train_mode == "dsm":
                val_loss = dsm(energy_net, val_batch, sigma=self.dsm_sigma)
            elif train_mode == "dsm_fd":
                val_loss = dsm_fd(energy_net, val_batch, sigma=self.dsm_sigma)
            elif train_mode == "kingma":
                logp, grad1, grad2 = dkef.approx_bp_forward(None, val_batch, stage="finalize", mode=train_mode)
                val_loss = (0.5 * grad1 ** 2).sum(1) + grad2.sum(1)
            elif train_mode == "CP":
                logp, grad1, S_r, S_i = dkef.approx_bp_forward(None, val_batch, stage="finalize", mode=train_mode)
                grad2 = S_r ** 2 - S_i ** 2
                val_loss = (0.5 * grad1 ** 2).sum(1) + grad2.sum(1)

            val_loss = val_loss.mean()
            optimizer.zero_grad()
            val_loss.backward()
            tn = nn.utils.clip_grad_norm_(dkef.parameters(), 1.0)
            optimizer.step()

            logging.info("Val loss: {:.3f}".format(val_loss))
            tb_logger.add_scalar("finalize/loss", val_loss, global_step=step)
            step += 1

    val_losses = []
    for data_v in val_loader:
        data_v = data_v.to(self.config.device)
        batch_val_loss = exact_score_matching(energy_net, data_v, train=False)
        val_losses.append(batch_val_loss.mean())
    val_loss = sum(val_losses) / len(val_losses)
    logging.info("Overall val exact score matching: {:.3f}".format(val_loss))
    tb_logger.add_scalar("finalize/final_valid_score", val_loss, global_step=0)
    self.results["final_valid_score"] = np.asscalar(val_loss.cpu().numpy())

    test_losses = []
    for data_t in test_loader:
        data_t = data_t.to(self.config.device)
        batch_test_loss = exact_score_matching(energy_net, data_t, train=False)
        test_losses.append(batch_test_loss.mean())
    test_loss = sum(test_losses) / len(test_losses)
    logging.info("Overall test exact score matching: {:.3f}".format(test_loss))
    tb_logger.add_scalar("finalize/final_test_score", test_loss, global_step=0)
    self.results["final_test_score"] = np.asscalar(test_loss.cpu().numpy())
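
# Hedged sketch (an assumption, not the repo's exact_score_matching): the exact
# Hyvarinen score matching objective evaluated in the val/test loops above,
#   E_x[ 0.5 * ||grad_x log p(x)||^2 + tr(grad_x^2 log p(x)) ],
# with the Hessian trace computed one coordinate at a time, i.e. D extra
# backward passes for D-dimensional inputs. Relies on the module's torch import.
def exact_score_matching_sketch(energy_net, samples):
    samples = samples.requires_grad_(True)
    logp = -energy_net(samples).sum()
    grad1 = torch.autograd.grad(logp, samples, create_graph=True)[0]  # model scores
    loss = 0.5 * (grad1 ** 2).sum(1)
    for d in range(samples.shape[1]):
        grad2_d = torch.autograd.grad(grad1[:, d].sum(), samples, create_graph=True)[0][:, d]
        loss = loss + grad2_d
    return loss  # per-sample losses; callers above take .mean()
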
def train_stage1(self, dkef, tb_logger, train_data, val_data, collate_fn, train_mode):
    optimizer = self.get_optimizer(dkef.parameters())
    step = 0
    num_mb = len(train_data) // self.config.training.batch_size
    split_size = self.config.training.batch_size // 2

    best_val_step = 0
    best_val_loss = 1e5
    best_model = None

    train_losses = np.zeros(30)
    val_loss_window = np.zeros(15)

    torch.cuda.synchronize()
    prev_time = time.time()

    val_batch_size = len(val_data)
    num_val_iters = 1
    val_loader = DataLoader(
        val_data,
        batch_size=val_batch_size,
        shuffle=True,
        num_workers=2,
        collate_fn=collate_fn,
    )
    train_loader = DataLoader(
        train_data,
        batch_size=split_size,
        shuffle=True,
        num_workers=2,
        collate_fn=collate_fn,
    )
    train_iter = iter(train_loader)
    val_iter = iter(val_loader)

    total_time = 0.0
    time_dur = 0.0
    secs_per_it = []

    for _ in range(self.config.training.n_epochs):
        for _ in range(num_mb):
            train_iter, X_t = self.sample(train_iter, train_loader)
            train_iter, X_v = self.sample(train_iter, train_loader)

            start_point = time.time()

            def energy_net(inputs):
                return -dkef(X_t, inputs)

            if train_mode == "exact":
                train_loss = exact_score_matching(energy_net, X_v, train=True)
            elif train_mode == "sliced":
                train_loss, _, _ = single_sliced_score_matching(energy_net, X_v)
            elif train_mode == "sliced_fd":
                train_loss = efficient_score_matching_conjugate(energy_net, X_v)
            elif train_mode == "sliced_VR":
                train_loss, _, _ = sliced_VR_score_matching(energy_net, X_v)
            elif train_mode == "dsm":
                train_loss = dsm(energy_net, X_v, sigma=self.dsm_sigma)
            elif train_mode == "dsm_fd":
                train_loss = dsm_fd(energy_net, X_v, sigma=self.dsm_sigma)
            elif train_mode == "kingma":
                logp, grad1, grad2 = dkef.approx_bp_forward(X_t, X_v, stage="train", mode=train_mode)
                train_loss = (0.5 * grad1 ** 2).sum(1) + grad2.sum(1)
            elif train_mode == "CP":
                logp, grad1, S_r, S_i = dkef.approx_bp_forward(X_t, X_v, stage="train", mode=train_mode)
                grad2 = S_r ** 2 - S_i ** 2
                train_loss = (0.5 * grad1 ** 2).sum(1) + grad2.sum(1)

            train_loss = train_loss.mean()
            optimizer.zero_grad()
            train_loss.backward()
            train_losses[step % 30] = train_loss.detach()

            # Their code clips by overall gradient norm at 100.
            tn = nn.utils.clip_grad_norm_(dkef.parameters(), 1.0)
            optimizer.step()

            time_dur += time.time() - start_point

            idx = np.random.choice(len(train_data), 1000, replace=False)
            train_data_for_val = torch.utils.data.Subset(train_data, idx)
            dkef.save_alpha_matrices(train_data_for_val, collate_fn, self.config.device)

            # Compute validation loss
            def energy_net_val(inputs):
                return -dkef(None, inputs, stage="eval")

            val_losses = []
            for val_step in range(num_val_iters):
                val_iter, data_v = self.sample(val_iter, val_loader)
                if train_mode == "exact":
                    batch_val_loss = exact_score_matching(energy_net_val, data_v, train=False)
                elif train_mode == "sliced":
                    batch_val_loss, _, _ = single_sliced_score_matching(energy_net_val, data_v, detach=True)
                elif train_mode == "sliced_fd":
                    batch_val_loss = efficient_score_matching_conjugate(energy_net_val, data_v, detach=True)
                elif train_mode == "sliced_VR":
                    batch_val_loss, _, _ = sliced_VR_score_matching(energy_net_val, data_v, detach=True)
                elif train_mode == "dsm":
                    batch_val_loss = dsm(energy_net_val, data_v, sigma=self.dsm_sigma)
                elif train_mode == "dsm_fd":
                    batch_val_loss = dsm_fd(energy_net_val, data_v, sigma=self.dsm_sigma)
                elif train_mode == "kingma":
                    logp, grad1, grad2 = dkef.approx_bp_forward(None, X_v, stage="eval", mode=train_mode)
                    batch_val_loss = (0.5 * grad1 ** 2).sum(1) + grad2.sum(1)
                elif train_mode == "CP":
                    logp, grad1, S_r, S_i = dkef.approx_bp_forward(None, X_v, stage="eval", mode=train_mode)
                    grad2 = S_r ** 2 - S_i ** 2
                    batch_val_loss = (0.5 * grad1 ** 2).sum(1) + grad2.sum(1)

                val_losses.append(batch_val_loss.mean())

            val_loss = sum(val_losses) / len(val_losses)
            val_loss_window[step % 15] = val_loss.detach()
            smoothed_val_loss = (
                val_loss_window[: step + 1].mean() if step < 15 else val_loss_window.mean()
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_val_step = step
                best_model = copy.deepcopy(dkef.state_dict())
            elif step - best_val_step > self.config.training.patience:
                self.results["secs_per_it"] = sum(secs_per_it) / len(secs_per_it)
                self.results["its_per_sec"] = 1.0 / self.results["secs_per_it"]
                logging.info(
                    "Validation loss has not improved in {} steps. Finalizing model!".format(
                        self.config.training.patience
                    )
                )
                return best_model

            mean_train_loss = (
                train_losses[: step + 1].mean() if step < 30 else train_losses.mean()
            )
            logging.info(
                "Step {}, Training loss: {:.2f}, validation loss: {:.2f}".format(
                    step, mean_train_loss, best_val_loss
                )
            )
            tb_logger.add_scalar("train/train_loss_smoothed", mean_train_loss, global_step=step)
            tb_logger.add_scalar("train/best_val_loss", best_val_loss, global_step=step)
            tb_logger.add_scalar("train/train_loss", train_loss, global_step=step)
            tb_logger.add_scalar("train/val_loss", val_loss, global_step=step)

            if step % 20 == 0:
                torch.cuda.synchronize()
                new_time = time.time()
                logging.info("#" * 80)
                if step > 0:
                    secs_per_it.append((new_time - prev_time) / 20.0)
                    logging.info("Iterations per second: {:.3f}".format(20.0 / (new_time - prev_time)))
                    logging.info("Only Training Time: {:.3f}".format(time_dur))
                    time_dur = 0.0
                    tb_logger.add_scalar(
                        "train/its_per_sec",
                        20.0 / (new_time - prev_time),
                        global_step=step,
                    )
                if step > 0:
                    total_time += new_time - prev_time

                val_losses_exact = []
                for val_step in range(num_val_iters):
                    val_iter, data_v = self.sample(val_iter, val_loader)
                    vle = exact_score_matching(energy_net_val, data_v, train=False)
                    val_losses_exact.append(vle.mean())
                val_loss_exact = sum(val_losses_exact) / len(val_losses_exact)

                logging.info("Exact score matching loss on val: {:.2f}".format(val_loss_exact.mean()))
                tb_logger.add_scalar(
                    "eval/exact_score_matching",
                    val_loss_exact.mean(),
                    global_step=step,
                )
                logging.info("#" * 80)
                torch.cuda.synchronize()
                prev_time = time.time()

            step += 1

    logging.info("Completed training")
    self.results["secs_per_it"] = sum(secs_per_it) / len(secs_per_it)
    self.results["its_per_sec"] = 1.0 / self.results["secs_per_it"]
    return best_model
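
# Hedged sketch of the `self.sample(iterator, loader)` helper that train_stage1
# relies on (an assumption -- the real helper is defined elsewhere on this
# runner): draw the next batch, restarting the iterator once the DataLoader is
# exhausted, and move the batch to the configured device.
def sample_sketch(self, iterator, loader):
    try:
        batch = next(iterator)
    except StopIteration:
        iterator = iter(loader)
        batch = next(iterator)
    return iterator, batch.to(self.config.device)
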
def train(self):
    transform = transforms.Compose([
        transforms.Resize(self.config.data.image_size),
        transforms.ToTensor()
    ])

    if self.config.data.dataset == 'CIFAR10':
        dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=True, download=True,
                          transform=transform)
        test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=False, download=True,
                               transform=transform)
    elif self.config.data.dataset == 'MNIST':
        dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=True, download=True,
                        transform=transform)
        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, val_indices = indices[:int(num_items * 0.9)], indices[int(num_items * 0.9):]
        val_dataset = Subset(dataset, val_indices)
        dataset = Subset(dataset, train_indices)
        test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=False, download=True,
                             transform=transform)

    dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=2)
    val_iter = iter(val_loader)

    self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

    tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
    if os.path.exists(tb_path):
        shutil.rmtree(tb_path)
    model_path = os.path.join(self.args.run, 'results', self.args.doc)
    if os.path.exists(model_path):
        shutil.rmtree(model_path)
    os.makedirs(model_path)

    tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
    flow = NICE(self.config.input_dim, self.config.model.hidden_size, self.config.model.num_layers).to(
        self.config.device)
    optimizer = self.get_optimizer(flow.parameters())

    # Set up test data
    noise_sigma = self.config.data.noise_sigma

    step = 0

    def energy_net(inputs):
        energy, _ = flow(inputs, inv=False)
        return -energy

    def grad_net_kingma(inputs):
        energy, _ = flow(inputs, inv=False)
        grad1, grad2 = flow.grads_backward(inv=False)
        return -grad1, -grad2

    def grad_net_UT(inputs):
        energy, _ = flow(inputs, inv=False)
        grad1, T, U = flow.grads_backward_TU(inv=False)
        grad2 = T * U / 2.
        return -grad1, -grad2

    def grad_net_S(inputs):
        energy, _ = flow(inputs, inv=False)
        grad1, S_r, S_i = flow.grads_backward_S(inv=False)
        grad2 = (S_r ** 2 - S_i ** 2)
        return -grad1, -grad2

    def sample_net(z):
        samples, _ = flow(z, inv=True)
        samples, _ = Logit()(samples, mode='inverse')
        return samples

    # Use this to select the sigma for DSM losses
    if self.config.training.algo == 'dsm':
        sigma = self.args.dsm_sigma
        # if noise_sigma is None:
        #     sigma = select_sigma(iter(dataloader), iter(val_loader))
        # else:
        #     sigma = select_sigma(iter(dataloader), iter(val_loader), noise_sigma=noise_sigma)

    if self.args.load_path != "":
        flow.load_state_dict(torch.load(self.args.load_path))

    best_model = {"val": None, "ll": None, "esm": None}
    best_val_loss = {"val": 1e+10, "ll": -1e+10, "esm": 1e+10}
    best_val_iter = {"val": 0, "ll": 0, "esm": 0}

    for _ in range(self.config.training.n_epochs):
        for _, (X, y) in enumerate(dataloader):
            X = X + (torch.rand_like(X) - 0.5) / 256.
            flattened_X = X.type(torch.float32).to(self.config.device).view(X.shape[0], -1)
            flattened_X.clamp_(1e-3, 1 - 1e-3)
            flattened_X, _ = Logit()(flattened_X, mode='direct')
            if noise_sigma is not None:
                flattened_X += torch.randn_like(flattened_X) * noise_sigma
            flattened_X.requires_grad_(True)

            logp = -energy_net(flattened_X)
            logp = logp.mean()

            if self.config.training.algo == 'kingma':
                loss = approx_backprop_score_matching(grad_net_kingma, flattened_X)
            elif self.config.training.algo == 'UT':
                loss = approx_backprop_score_matching(grad_net_UT, flattened_X)
            elif self.config.training.algo == 'S':
                loss = approx_backprop_score_matching(grad_net_S, flattened_X)
            elif self.config.training.algo == 'mle':
                loss = -logp
            elif self.config.training.algo == 'ssm':
                loss, *_ = single_sliced_score_matching(energy_net, flattened_X,
                                                        noise_type=self.config.training.noise_type)
            elif self.config.training.algo == 'ssm_vr':
                loss, *_ = sliced_VR_score_matching(energy_net, flattened_X,
                                                    noise_type=self.config.training.noise_type)
            elif self.config.training.algo == 'dsm':
                loss = dsm(energy_net, flattened_X, sigma=sigma)
            elif self.config.training.algo == "exact":
                loss = exact_score_matching(energy_net, flattened_X, train=True).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 10 == 0:
                try:
                    val_X, _ = next(val_iter)
                except StopIteration:
                    val_iter = iter(val_loader)
                    val_X, _ = next(val_iter)

                val_X = val_X + (torch.rand_like(val_X) - 0.5) / 256.
                val_X = val_X.type(torch.float32).to(self.config.device)
                val_X.clamp_(1e-3, 1 - 1e-3)
                val_X, _ = Logit()(val_X, mode='direct')
                val_X = val_X.view(val_X.shape[0], -1)
                if noise_sigma is not None:
                    val_X += torch.randn_like(val_X) * noise_sigma

                val_logp = -energy_net(val_X).mean()

                if self.config.training.algo == 'kingma':
                    val_loss = approx_backprop_score_matching(grad_net_kingma, val_X)
                elif self.config.training.algo == 'UT':
                    val_loss = approx_backprop_score_matching(grad_net_UT, val_X)
                elif self.config.training.algo == 'S':
                    val_loss = approx_backprop_score_matching(grad_net_S, val_X)
                elif self.config.training.algo == 'ssm':
                    val_loss, *_ = single_sliced_score_matching(energy_net, val_X,
                                                                noise_type=self.config.training.noise_type)
                elif self.config.training.algo == 'ssm_vr':
                    val_loss, *_ = sliced_VR_score_matching(energy_net, val_X,
                                                            noise_type=self.config.training.noise_type)
                elif self.config.training.algo == 'dsm':
                    val_loss = dsm(energy_net, val_X, sigma=sigma)
                elif self.config.training.algo == 'mle':
                    val_loss = -val_logp
                elif self.config.training.algo == "exact":
                    val_loss = exact_score_matching(energy_net, val_X, train=False).mean()

                logging.info("logp: {:.3f}, val_logp: {:.3f}, loss: {:.3f}, val_loss: {:.3f}".format(
                    logp.item(), val_logp.item(), loss.item(), val_loss.item()))
                tb_logger.add_scalar('logp', logp, global_step=step)
                tb_logger.add_scalar('loss', loss, global_step=step)
                tb_logger.add_scalar('val_logp', val_logp, global_step=step)
                tb_logger.add_scalar('val_loss', val_loss, global_step=step)

                if val_loss < best_val_loss['val']:
                    best_val_loss['val'] = val_loss
                    best_val_iter['val'] = step
                    best_model['val'] = copy.deepcopy(flow.state_dict())
                if val_logp > best_val_loss['ll']:
                    best_val_loss['ll'] = val_logp
                    best_val_iter['ll'] = step
                    best_model['ll'] = copy.deepcopy(flow.state_dict())

            if step % 100 == 0:
                with torch.no_grad():
                    z = torch.normal(torch.zeros(100, flattened_X.shape[1], device=self.config.device))
                    samples = sample_net(z)
                    samples = samples.view(100, self.config.data.channels, self.config.data.image_size,
                                           self.config.data.image_size)
                    samples = torch.clamp(samples, 0.0, 1.0)
                    image_grid = make_grid(samples, 10)
                    tb_logger.add_image('samples', image_grid, global_step=step)
                    data = X
                    data_grid = make_grid(data[:100], 10)
                    tb_logger.add_image('data', data_grid, global_step=step)

                logging.info("Computing exact score matching....")
                try:
                    val_X, _ = next(val_iter)
                except StopIteration:
                    val_iter = iter(val_loader)
                    val_X, _ = next(val_iter)

                val_X = val_X + (torch.rand_like(val_X) - 0.5) / 256.
                val_X = val_X.type(torch.float32).to(self.config.device)
                val_X.clamp_(1e-3, 1 - 1e-3)
                val_X, _ = Logit()(val_X, mode='direct')
                val_X = val_X.view(val_X.shape[0], -1)
                if noise_sigma is not None:
                    val_X += torch.randn_like(val_X) * noise_sigma

                sm_loss = exact_score_matching(energy_net, val_X, train=False).mean()
                if sm_loss < best_val_loss['esm']:
                    best_val_loss['esm'] = sm_loss
                    best_val_iter['esm'] = step
                    best_model['esm'] = copy.deepcopy(flow.state_dict())

                logging.info('step: {}, exact score matching loss: {}'.format(step, sm_loss.item()))
                tb_logger.add_scalar('exact_score_matching_loss', sm_loss, global_step=step)

            if step % 500 == 0:
                torch.save(flow.state_dict(), os.path.join(model_path, 'nice.pth'))

            step += 1

    self.results = {}
    self.evaluate_model(flow.state_dict(), "final", val_loader, test_loader, model_path)
    self.evaluate_model(best_model['val'], "best_on_val", val_loader, test_loader, model_path)
    self.evaluate_model(best_model['ll'], "best_on_ll", val_loader, test_loader, model_path)
    self.evaluate_model(best_model['esm'], "best_on_esm", val_loader, test_loader, model_path)

    self.results['final']['num_iters'] = step
    self.results['best_on_val']['num_iters'] = best_val_iter['val']
    self.results['best_on_ll']['num_iters'] = best_val_iter['ll']
    self.results['best_on_esm']['num_iters'] = best_val_iter['esm']

    pickle_out = open(model_path + "/results.pkl", "wb")
    pickle.dump(self.results, pickle_out)
    pickle_out.close()
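
# Hedged sketch (an assumption, not part of the repo) consolidating the input
# pipeline repeated in train() above: uniform dequantization, flattening,
# clamping away from {0, 1}, a logit transform, and optional Gaussian noise.
# torch.logit stands in for the repo's Logit() flow layer here; relies on the
# module-level torch import.
def preprocess_sketch(X, device, noise_sigma=None):
    X = X + (torch.rand_like(X) - 0.5) / 256.      # dequantize 8-bit pixels
    X = X.float().to(device).view(X.shape[0], -1)  # flatten to (N, D)
    X = X.clamp(1e-3, 1 - 1e-3)                    # keep the logit finite
    X = torch.logit(X)                             # map (0, 1) to the real line
    if noise_sigma is not None:
        X = X + torch.randn_like(X) * noise_sigma  # optional noise smoothing
    return X
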
def transfer(args, config):
    """
    Once an ICE-BeeM has been pretrained on some labels (0-7), we train only the
    secondary network (g in our manuscript) on unseen labels 8-9 (these are new datasets).
    """
    conditional = args.subset_size != 0

    # load data
    dataloader, dataset, cond_size = get_dataset(args, config, test=False, rev=True, one_hot=True, subset=True)

    # load the feature network f
    ckpt_path = os.path.join(args.checkpoints, 'checkpoint.pth')
    print('loading weights from: {}'.format(ckpt_path))
    states = torch.load(ckpt_path, map_location=config.device)
    f = feature_net(config).to(config.device)
    f.load_state_dict(states[0])

    if conditional:
        # define the secondary network g
        g = SimpleLinear(cond_size, f.output_size, bias=False).to(config.device)
        energy_net = ModularUnnormalizedConditionalEBM(f, g, augment=config.model.augment,
                                                       positive=config.model.positive)
        # define the optimizer
        parameters = energy_net.g.parameters()
        optimizer = get_optimizer(config, parameters)
    else:
        # no learning is involved: just evaluate f on the new labels, with g = 1
        energy_net = ModularUnnormalizedEBM(f)
        optimizer = None

    # start optimizing!
    eCount = 10
    loss_track_epochs = []
    for epoch in range(eCount):
        print('epoch: ' + str(epoch))
        loss_track = []
        for i, (X, y) in enumerate(dataloader):
            X = X.to(config.device)
            X = X / 256. * 255. + torch.rand_like(X) / 256.

            if conditional:
                loss = cdsm(energy_net, X, y, sigma=0.01)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            else:
                # just evaluate the DSM loss using the pretrained f --- no learning
                loss = dsm(energy_net, X, sigma=0.01)
                loss.backward()  # strangely, without this line, the script requires twice as much GPU memory

            loss_track.append(loss.item())
            loss_track_epochs.append(loss.item())

        pickle.dump(loss_track,
                    open(os.path.join(args.output,
                                      'size{}_seed{}.p'.format(args.subset_size, args.seed)), 'wb'))

    print('saving loss track under: {}'.format(args.output))
    pickle.dump(loss_track_epochs,
                open(os.path.join(args.output,
                                  'all_epochs_SIZE{}_SEED{}.p'.format(args.subset_size, args.seed)), 'wb'))
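
# Hedged sketch of the SimpleLinear secondary network g used in transfer() and
# train() (an assumption -- the real class lives elsewhere in the repo): a
# bias-free linear map from the one-hot label space to the output dimension of
# the feature network f, so the conditional energy can be an inner product
# between f(x) and g(y).
import torch.nn as nn  # assumed available; used only by the sketch below


class SimpleLinearSketch(nn.Module):
    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.output_size = out_features

    def forward(self, y):
        return self.linear(y)
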
def train(args, config, conditional=True):
    # we don't need to save the weights for the transfer-baseline runs
    save_weights = 'baseline' not in config.data.dataset.lower()

    if args.subset_size == 0:
        conditional = False

    # load dataset
    dataloader, dataset, cond_size = get_dataset(args, config, one_hot=True)

    # define the energy model
    if conditional:
        f = feature_net(config).to(config.device)
        g = SimpleLinear(cond_size, f.output_size, bias=False).to(config.device)
        energy_net = ModularUnnormalizedConditionalEBM(f, g, augment=config.model.augment,
                                                       positive=config.model.positive)
    else:
        f = feature_net(config).to(config.device)
        energy_net = ModularUnnormalizedEBM(f)

    # get optimizer
    optimizer = get_optimizer(config, energy_net.parameters())

    # train
    step = 0
    loss_track_epochs = []
    for epoch in range(config.training.n_epochs):
        loss_track = []
        for i, (X, y) in enumerate(dataloader):
            step += 1
            energy_net.train()
            X = X.to(config.device)
            X = X / 256. * 255. + torch.rand_like(X) / 256.
            if config.data.logit_transform:
                X = logit_transform(X)

            # compute loss
            if conditional:
                loss = cdsm(energy_net, X, y, sigma=0.01)
            else:
                loss = dsm(energy_net, X, sigma=0.01)

            # optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_track.append(loss.item())
            loss_track_epochs.append(loss.item())

            if step >= config.training.n_iters and save_weights:
                enet, energy_net_finalLayer = energy_net.f, energy_net.g
                # save final checkpoints for distribution!
                states = [
                    enet.state_dict(),
                    optimizer.state_dict(),
                ]
                print('saving weights under: {}'.format(args.checkpoints))
                # torch.save(states, os.path.join(args.checkpoints, 'checkpoint_{}.pth'.format(step)))
                torch.save(states, os.path.join(args.checkpoints, 'checkpoint.pth'))
                torch.save([energy_net_finalLayer], os.path.join(args.checkpoints, 'finalLayerweights_.pth'))
                pickle.dump(energy_net_finalLayer,
                            open(os.path.join(args.checkpoints, 'finalLayerweights.p'), 'wb'))
                return 0

            if step % config.training.snapshot_freq == 0:
                enet, energy_net_finalLayer = energy_net.f, energy_net.g
                print('checkpoint at step: {}'.format(step))
                # save checkpoint for transfer learning!
                # torch.save([energy_net_finalLayer], os.path.join(args.log, 'finalLayerweights_.pth'))
                # pickle.dump(energy_net_finalLayer,
                #             open(os.path.join(args.log, 'finalLayerweights.p'), 'wb'))
                # states = [
                #     enet.state_dict(),
                #     optimizer.state_dict(),
                # ]
                # torch.save(states, os.path.join(args.log, 'checkpoint_{}.pth'.format(step)))
                # torch.save(states, os.path.join(args.log, 'checkpoint.pth'))

        if config.data.dataset.lower() in ['mnist_transferbaseline', 'cifar10_transferbaseline',
                                           'fashionmnist_transferbaseline', 'cifar100_transferbaseline']:
            # save loss track during this epoch for the transfer baseline
            pickle.dump(loss_track,
                        open(os.path.join(args.output,
                                          'size{}_seed{}.p'.format(args.subset_size, args.seed)), 'wb'))

    if config.data.dataset.lower() in ['mnist_transferbaseline', 'cifar10_transferbaseline',
                                       'fashionmnist_transferbaseline', 'cifar100_transferbaseline']:
        # save loss track across all epochs for the transfer baseline
        print('saving loss track under: {}'.format(args.output))
        pickle.dump(loss_track_epochs,
                    open(os.path.join(args.output,
                                      'all_epochs_SIZE{}_SEED{}.p'.format(args.subset_size, args.seed)), 'wb'))

    # save final checkpoints for distribution!
    if save_weights:
        enet, energy_net_finalLayer = energy_net.f, energy_net.g
        states = [
            enet.state_dict(),
            optimizer.state_dict(),
        ]
        print('saving weights under: {}'.format(args.checkpoints))
        # torch.save(states, os.path.join(args.checkpoints, 'checkpoint_{}.pth'.format(step)))
        torch.save(states, os.path.join(args.checkpoints, 'checkpoint.pth'))
        torch.save([energy_net_finalLayer], os.path.join(args.checkpoints, 'finalLayerweights_.pth'))
        pickle.dump(energy_net_finalLayer,
                    open(os.path.join(args.checkpoints, 'finalLayerweights.p'), 'wb'))
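
# Hedged sketch (an assumption, not the repo's cdsm): a conditional variant of
# the denoising score matching objective used in train() and transfer() above,
# where the energy also receives the one-hot label y so that each class has its
# own score function. Assumes the conditional EBM can be called as
# energy_net(x, y); relies on the module-level torch import.
def cdsm_loss_sketch(energy_net, X, y, sigma=0.01):
    noise = torch.randn_like(X) * sigma
    perturbed = (X + noise).requires_grad_(True)
    energy = energy_net(perturbed, y).sum()
    score = -torch.autograd.grad(energy, perturbed, create_graph=True)[0]
    target = -noise / (sigma ** 2)
    return 0.5 * ((score - target) ** 2).flatten(1).sum(-1).mean()
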