class FlowExperiment(BaseExperiment):
    """Training/evaluation harness for manifold flow models.

    Adds TensorBoard / Weights & Biases logging, checkpoint resumption,
    optional automatic mixed precision (AMP), beta annealing for the NDP
    portion of the flow, bits/dim training, sample dumping and early
    stopping on top of ``BaseExperiment``.
    """

    # Root directory for all manifold experiment logs.
    log_base = os.path.join(get_survae_path(), 'experiments/manifold/log')
    # Bookkeeping-only args that should not be logged as hyperparameters.
    no_log_keys = ['project', 'name', 'log_tb', 'log_wandb', 'check_every',
                   'eval_every', 'device', 'parallel', 'pin_memory',
                   'num_workers']

    def __init__(self, args, data_id, model_id, optim_id, train_loader,
                 eval_loader, model, optimizer, scheduler_iter,
                 scheduler_epoch):
        # Normalize args: None/0 for the periodic actions means
        # "only at the final epoch".
        if args.eval_every is None or args.eval_every == 0:
            args.eval_every = args.epochs
        if args.check_every is None or args.check_every == 0:
            args.check_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        # Debug runs get an extra timestamped level so they never clobber
        # real results.
        if args.name == "debug":
            log_path = os.path.join(self.log_base, "debug", data_id,
                                    model_id, optim_id, f"seed{args.seed}",
                                    time.strftime("%Y-%m-%d_%H-%M-%S"))
        else:
            log_path = os.path.join(self.log_base, data_id, model_id,
                                    optim_id, f"seed{args.seed}", args.name)

        # Move model to the target device, optionally wrapping for
        # multi-GPU DataParallel execution.
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        super(FlowExperiment, self).__init__(
            model=model,
            optimizer=optimizer,
            scheduler_iter=scheduler_iter,
            scheduler_epoch=scheduler_epoch,
            log_path=log_path,
            eval_every=args.eval_every,
            check_every=args.check_every,
            save_samples=args.save_samples)

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.data_id = data_id
        self.model_id = model_id
        self.optim_id = optim_id

        # Store data loaders
        self.train_loader = train_loader
        self.eval_loader = eval_loader

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text(
                "args", get_args_table(args_dict).get_html_string(),
                global_step=0)
        if args.log_wandb:
            wandb.init(config=args_dict, project=args.project, id=args.name,
                       dir=self.log_path)

        # training params
        self.max_grad_norm = args.max_grad_norm

        # Automatic mixed precision. Bigger changes are needed to make AMP
        # work with DataParallel (an @autocast() decoration on each forward),
        # so it is disabled for 'dp'.
        # BUG FIX: the old check `int(str(torch.__version__)[2]) >= 7`
        # inspected a single character and mis-classified e.g. "1.10.0"
        # and "2.0.1" as < 1.7; parse (major, minor) properly instead.
        pytorch_170 = self._torch_version_at_least(1, 7)
        self.amp = args.amp and args.parallel != 'dp' and pytorch_170
        if self.amp:
            # GradScaler is only available in pytorch 1.7.0+
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        # save model architecture for reference
        self.save_architecture()

    @staticmethod
    def _torch_version_at_least(major, minor):
        """Return True when ``torch.__version__`` is >= ``major.minor``.

        Tolerates local-version suffixes such as "1.7.0+cu101"; returns
        False when the version string cannot be parsed.
        """
        try:
            parts = str(torch.__version__).split('.')
            return (int(parts[0]), int(parts[1])) >= (major, minor)
        except (ValueError, IndexError):
            return False

    def log_fn(self, epoch, train_dict, eval_dict):
        """Log per-epoch metrics to TensorBoard and/or W&B.

        ``eval_dict`` may be None/empty on epochs without evaluation.
        """
        # Tensorboard
        if self.args.log_tb:
            for metric_name, metric_value in train_dict.items():
                self.writer.add_scalar('base/{}'.format(metric_name),
                                       metric_value, global_step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    self.writer.add_scalar('eval/{}'.format(metric_name),
                                           metric_value,
                                           global_step=epoch + 1)
        # Weights & Biases
        if self.args.log_wandb:
            for metric_name, metric_value in train_dict.items():
                wandb.log({'base/{}'.format(metric_name): metric_value},
                          step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    wandb.log({'eval/{}'.format(metric_name): metric_value},
                              step=epoch + 1)

    def resume(self):
        """Reload the checkpoint named by ``args.resume`` and replay its
        logged metrics so external loggers are consistent with the
        restored epoch counter."""
        resume_path = os.path.join(self.log_base, self.data_id,
                                   self.model_id, self.optim_id,
                                   self.args.resume, 'check')
        self.checkpoint_load(resume_path)
        for epoch in range(self.current_epoch):
            train_dict = {}
            for metric_name, metric_values in self.train_metrics.items():
                train_dict[metric_name] = metric_values[epoch]
            if epoch in self.eval_epochs:
                eval_dict = {}
                for metric_name, metric_values in self.eval_metrics.items():
                    eval_dict[metric_name] = metric_values[
                        self.eval_epochs.index(epoch)]
            else:
                eval_dict = None
            self.log_fn(epoch, train_dict=train_dict, eval_dict=eval_dict)

    def beta_annealing(self, epoch, min_beta=0.0, max_beta=1.0):
        """
        Beta annealing term to control the weight of the NDP portion of the
        flow. Often it can help to let the bijective portion of the flow
        begin to map to the first base distribution before expecting the NDP
        portion of the flow to find a low dimensional representation.

        Only applicable when using 2 (or more) base distributions; otherwise
        beta is fixed at 1.0.
        """
        if self.args.annealing_schedule > 0 and \
                self.args.compression == "vae" and \
                len(self.args.base_distributions) > 1:
            # Linear ramp from 0 to max_beta over `annealing_schedule`
            # epochs, clamped to [min_beta, max_beta].
            return max(
                min([(epoch * 1.0) /
                     max([self.args.annealing_schedule, 1.0]), max_beta]),
                min_beta)
        else:
            return 1.0

    def run(self):
        if self.args.resume:
            self.resume()
        super(FlowExperiment, self).run(epochs=self.args.epochs)

    def train_fn(self, epoch):
        """Dispatch to the AMP or full-precision training loop."""
        if self.amp:
            # use automatic mixed precision
            return self._train_amp(epoch)
        else:
            return self._train(epoch)

    def _train_amp(self, epoch):
        """
        Same training procedure as ``_train`` but uses half precision to
        speed up training on GPUs. Only runs on pytorch 1.7.0+.
        """
        self.model.train()
        beta = self.beta_annealing(epoch)
        loss_sum = 0.0
        loss_count = 0
        for x in self.train_loader:
            # Cast operations to mixed precision
            with torch.cuda.amp.autocast():
                # Negative log-likelihood in bits per dimension.
                loss = -self.model.log_prob(
                    x.to(self.args.device),
                    beta=beta).sum() / (math.log(2) * x.shape.numel())

            # Scale loss and call backward() to create scaled gradients
            self.scaler.scale(loss).backward()

            if self.max_grad_norm > 0:
                # Unscales the gradients of optimizer's assigned params
                # in-place so clipping operates on true gradient norms.
                self.scaler.unscale_(self.optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            # Unscale gradients and call (or skip) optimizer.step()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            if self.scheduler_iter:
                self.scheduler_iter.step()
            self.optimizer.zero_grad(set_to_none=True)

            # accumulate loss and report
            loss_sum += loss.detach().cpu().item() * len(x)
            loss_count += len(x)
            print('Training. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'
                  .format(epoch + 1, self.args.epochs, loss_count,
                          len(self.train_loader.dataset),
                          loss_sum / loss_count),
                  end='\r')
        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()
        return {'bpd': loss_sum / loss_count}

    def _train(self, epoch):
        """Full-precision training loop; returns mean bits/dim."""
        self.model.train()
        beta = self.beta_annealing(epoch)
        loss_sum = 0.0
        loss_count = 0
        for x in self.train_loader:
            self.optimizer.zero_grad()
            loss = -self.model.log_prob(
                x.to(self.args.device),
                beta=beta).sum() / (math.log(2) * x.shape.numel())
            loss.backward()
            if self.max_grad_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()
            if self.scheduler_iter:
                self.scheduler_iter.step()
            loss_sum += loss.detach().cpu().item() * len(x)
            loss_count += len(x)
            print('Training. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'
                  .format(epoch + 1, self.args.epochs, loss_count,
                          len(self.train_loader.dataset),
                          loss_sum / loss_count),
                  end='\r')
        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()
        return {'bpd': loss_sum / loss_count}

    def eval_fn(self, epoch):
        """Evaluate mean bits/dim over the eval loader (no grad)."""
        self.model.eval()
        with torch.no_grad():
            loss_sum = 0.0
            loss_count = 0
            for x in self.eval_loader:
                loss = elbo_bpd(self.model, x.to(self.args.device))
                loss_sum += loss.detach().cpu().item() * len(x)
                loss_count += len(x)
                print(
                    'Evaluating. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'
                    .format(epoch + 1, self.args.epochs, loss_count,
                            len(self.eval_loader.dataset),
                            loss_sum / loss_count),
                    end='\r')
            print('')
        return {'bpd': loss_sum / loss_count}

    def sample_fn(self):
        """Save a grid of model samples (and matching real samples) tagged
        with the checkpointed epoch and the run seed."""
        self.model.eval()
        path_check = '{}/check/checkpoint.pt'.format(self.log_path)
        checkpoint = torch.load(path_check)
        path_samples = '{}/samples/sample_ep{}_s{}.png'.format(
            self.log_path, checkpoint['current_epoch'], self.args.seed)
        # BUG FIX: os.mkdir fails when intermediate directories are
        # missing and raises if the directory already exists.
        os.makedirs(os.path.dirname(path_samples), exist_ok=True)

        # save model samples, rescaled from quantized ints to [0, 1]
        samples = self.model.sample(
            self.args.samples).cpu().float() / (2**self.args.num_bits - 1)
        vutils.save_image(samples, path_samples, nrow=self.args.nrow)

        # save real samples too
        path_true_samples = '{}/samples/true_ep{}_s{}.png'.format(
            self.log_path, checkpoint['current_epoch'], self.args.seed)
        imgs = next(iter(self.eval_loader))[:self.args.samples].cpu().float()
        if imgs.max().item() > 2:
            # heuristic: values > 2 indicate unnormalized quantized data
            imgs /= (2**self.args.num_bits - 1)
        vutils.save_image(imgs, path_true_samples, nrow=self.args.nrow)

    def stop_early(self, loss_dict, epoch):
        """Early-stopping check on the 'bpd' metric.

        Returns ``(early_stop_flag, model_improved)``. Stopping is disabled
        entirely while beta annealing is still in progress.

        NOTE(review): relies on ``self.best_loss`` / ``self.best_loss_epoch``
        being initialized elsewhere (presumably BaseExperiment) — confirm.
        """
        if self.args.early_stop == 0 or epoch < self.args.annealing_schedule:
            return False, True

        # else check if we've passed the early stopping threshold
        current_loss = loss_dict['bpd']
        model_improved = current_loss < self.best_loss
        if model_improved:
            early_stop_flag = False
            self.best_loss = current_loss
            self.best_loss_epoch = epoch
        else:
            # model didn't improve, do we consider it converged yet?
            early_stop_count = (epoch - self.best_loss_epoch)
            early_stop_flag = early_stop_count >= self.args.early_stop

        if early_stop_flag:
            print(
                f'Stopping training early: no improvement after {self.args.early_stop} epochs (last improvement at epoch {self.best_loss_epoch})'
            )
        return early_stop_flag, model_improved
class GaussianProcessExperiment(BaseExperiment):
    """Gaussian-process baseline for the student/teacher experiments.

    Draws (x, y) pairs from a fixed teacher model, fits a GP regressor
    (scikit-learn style ``fit``/``predict``/``score`` interface), evaluates
    MC test log-likelihood / RMSE / R², and renders sample and density
    plots matching the flow-student plotting layout.
    """

    # Root directory for all student/teacher experiment logs.
    log_base = os.path.join(get_survae_path(),
                            'experiments/student_teacher/log')
    # Bookkeeping/architecture args that should not be logged.
    no_log_keys = ['project', 'name', 'log_tb', 'log_wandb', 'device',
                   'parallel', 'eval_every', 'num_flows', 'actnorm',
                   'scale_fn', 'hidden_units', 'range_flow', 'base_dist',
                   'affine', 'augment_size', 'pin_memory', 'num_workers']

    def __init__(self, args, data_id, model_id, model, teacher):
        # Edit args.
        # NOTE(review): this sets `args.epoch` (singular); if the intent was
        # to force a single training epoch this likely should be
        # `args.epochs` — confirm against the arg parser.
        args.epoch = 1
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        # Move models; the teacher is frozen and only used for sampling.
        self.teacher = teacher.to(args.device)
        self.teacher.eval()

        # 'split*'/'multiply*' conditioning transforms produce a scalar
        # context, all others a 2-d context.
        cond_id = args.cond_trans.lower()
        self.cond_size = 1 if cond_id.startswith('split') or \
            cond_id.startswith('multiply') else 2

        seed_id = f"seed{args.seed}"
        arch_id = (f"Gaussian_Process_{args.kernel}_kernel_"
                   f"a{int(args.gp_alpha)}_s{int(100*args.gp_length_scale)}")
        if args.name == "debug":
            log_path = os.path.join(self.log_base, "debug", model_id,
                                    data_id, cond_id, arch_id, seed_id,
                                    time.strftime("%Y-%m-%d_%H-%M-%S"))
        else:
            log_path = os.path.join(self.log_base, model_id, data_id,
                                    cond_id, arch_id, seed_id, args.name)

        # Init parent (no optimizer/schedulers: GP fitting is closed-form).
        super(GaussianProcessExperiment, self).__init__(
            model=model,
            optimizer=None,
            scheduler_iter=None,
            scheduler_epoch=None,
            log_path=log_path,
            eval_every=args.eval_every)

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.model_id = model_id
        self.data_id = data_id
        self.cond_id = cond_id
        self.arch_id = arch_id
        self.seed_id = seed_id

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text(
                "args", get_args_table(args_dict).get_html_string(),
                global_step=0)

    def run(self):
        """Single-shot pipeline: fit, evaluate, persist metrics, plot."""
        # Train
        train_dict = self.train_fn()
        self.log_train_metrics(train_dict)
        # Eval
        eval_dict = self.eval_fn()
        self.log_eval_metrics(eval_dict)
        # Log
        self.save_metrics()
        # Plotting
        self.plot_fn()

    def train_fn(self):
        """Fit the GP on teacher samples; return fit diagnostics."""
        self.teacher.eval()
        with torch.no_grad():
            y = self.teacher.sample(num_samples=self.args.train_samples)
            x = self.cond_fn(y)
        x = x.cpu().numpy()
        y = y.cpu().numpy()
        self.model.fit(x, y)
        # NOTE(review): `log_marginal_likelihood()` returns the (higher is
        # better) log marginal likelihood, despite the `nll` key name.
        nll = self.model.log_marginal_likelihood()
        r2 = self.model.score(x, y)
        print(f"Baseline: Log-marginal Likelihood={nll}, R-squared={r2}")
        return {'nll': nll, 'rsquared': r2}

    def eval_fn(self):
        """Evaluate MC predictive log-likelihood, RMSE and R² on fresh
        teacher samples."""
        K_test = 3  # number of MC samples
        self.teacher.eval()
        with torch.no_grad():
            y = self.teacher.sample(num_samples=self.args.test_samples)
            x = self.cond_fn(y)
        x = x.cpu().numpy()
        y = y.cpu().numpy()

        MC_samples = [self.model.predict(x, return_std=True)
                      for _ in range(K_test)]
        means = np.stack([tup[0] for tup in MC_samples])
        # NOTE(review): predict(return_std=True) yields standard
        # deviations, but the value is consumed below as a log-variance
        # (`np.exp(-logvar)`) — verify this is intended before relying on
        # the 'lhood' metric. Preserved as-is to keep metrics comparable.
        logvar = np.stack([np.tile(tup[1], (2, 1)).T for tup in MC_samples])
        test_ll = (-0.5 * np.exp(-logvar) * (means - y.squeeze())**2.0
                   - 0.5 * logvar - 0.5 * np.log(2 * np.pi))
        test_ll = np.sum(np.sum(test_ll, -1), -1)
        test_ll = logsumexp(test_ll) - np.log(K_test)
        rmse = np.mean((np.mean(means, 0) - y.squeeze())**2.0)**0.5
        r2 = self.model.score(x, y)
        print(f"Baseline: R-squared={r2}, rmse={rmse}, likelihood={test_ll}")
        return {'rsquared': r2, 'rmse': rmse, 'lhood': test_ll}

    def plot_fn(self):
        """Render teacher data, GP samples and GP densities under both a
        varying and several fixed conditioning contexts."""
        plot_path = os.path.join(self.log_path, "samples/")
        # BUG FIX: os.mkdir fails when intermediate directories are
        # missing; makedirs(exist_ok=True) is also race-free.
        os.makedirs(plot_path, exist_ok=True)
        if self.args.dataset == 'face_einstein':
            bounds = [[0, 1], [0, 1]]
        else:
            bounds = [[-4, 4], [-4, 4]]

        # plot true data
        test_data = self.teacher.sample(
            num_samples=self.args.test_samples).data.numpy()
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.hist2d(test_data[..., 0], test_data[..., 1], bins=256,
                   range=bounds)
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.savefig(os.path.join(plot_path, f'teacher.png'),
                    bbox_inches='tight', pad_inches=0)

        # Plot samples while varying the context
        with torch.no_grad():
            y = self.teacher.sample(num_samples=self.args.num_samples)
            x = self.cond_fn(y).cpu().numpy()
        y_mean, y_std = self.model.predict(x, return_std=True)
        temperature = 0.4
        samples = [np.random.normal(y_mean[:, i], y_std * temperature,
                                    self.args.num_samples).T[:, np.newaxis]
                   for i in range(y_mean.shape[1])]
        samples = np.hstack(samples)
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.hist2d(samples[..., 0], samples[..., 1], bins=256, range=bounds)
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.savefig(os.path.join(plot_path,
                                 f'varying_context_flow_samples.png'),
                    bbox_inches='tight', pad_inches=0)

        # Plot density over a regular grid of y values.
        xv, yv = torch.meshgrid([
            torch.linspace(bounds[0][0], bounds[0][1], self.args.grid_size),
            torch.linspace(bounds[1][0], bounds[1][1], self.args.grid_size)])
        y = torch.cat([xv.reshape(-1, 1), yv.reshape(-1, 1)], dim=-1)
        x = self.cond_fn(y).numpy()
        means, logvar = self.model.predict(x, return_std=True)
        # NOTE(review): std consumed as log-variance here too — see eval_fn.
        logvar = np.tile(logvar, (2, 1)).T
        logprobs = (-0.5 * np.exp(-logvar) * (means - y.numpy())**2.0
                    - 0.5 * logvar - 0.5 * np.log(2 * np.pi))
        logprobs = np.sum(logprobs, -1)
        # Shift so max logprob is 0 for a stable exp / color scale.
        logprobs = logprobs - logprobs.max()
        probs = np.exp(logprobs)
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.pcolormesh(xv, yv, probs.reshape(xv.shape), shading='auto')
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.clim(0, self.args.clim)
        plt.savefig(os.path.join(plot_path,
                                 f'varying_context_flow_density.png'),
                    bbox_inches='tight', pad_inches=0)

        # plot samples with fixed context
        for i, x in enumerate(self.context_permutations()):
            x = x.view((1, self.cond_size)).cpu().numpy()
            y_mean, y_std = self.model.predict(x, return_std=True)
            temperature = 1.5
            samples = np.hstack(
                [np.random.normal(y_mean[:, i], y_std * temperature,
                                  self.args.num_samples).T[:, np.newaxis]
                 for i in range(y_mean.shape[1])])
            plt.figure(figsize=(self.args.pixels / self.args.dpi,
                                self.args.pixels / self.args.dpi),
                       dpi=self.args.dpi)
            plt.hist2d(samples[..., 0], samples[..., 1], bins=256,
                       range=bounds)
            plt.xlim(bounds[0])
            plt.ylim(bounds[1])
            plt.axis('off')
            plt.savefig(os.path.join(
                plot_path, f'fixed_context_flow_samples_{i}.png'),
                bbox_inches='tight', pad_inches=0)

            # Plot density for this fixed context.
            xv, yv = torch.meshgrid([
                torch.linspace(bounds[0][0], bounds[0][1],
                               self.args.grid_size),
                torch.linspace(bounds[1][0], bounds[1][1],
                               self.args.grid_size)])
            y = torch.cat([xv.reshape(-1, 1), yv.reshape(-1, 1)],
                          dim=-1).cpu().numpy()
            means, logvar = self.model.predict(x, return_std=True)
            logvar = np.tile(logvar, (2, 1)).T
            logprobs = (-0.5 * np.exp(-logvar) * (means - y)**2.0
                        - 0.5 * logvar - 0.5 * np.log(2 * np.pi))
            logprobs = np.sum(logprobs, -1)
            logprobs = logprobs - logprobs.max()
            probs = np.exp(logprobs)
            plt.figure(figsize=(self.args.pixels / self.args.dpi,
                                self.args.pixels / self.args.dpi),
                       dpi=self.args.dpi)
            plt.pcolormesh(xv, yv, probs.reshape(xv.shape), shading='auto')
            plt.xlim(bounds[0])
            plt.ylim(bounds[1])
            plt.axis('off')
            plt.clim(0, self.args.clim)
            plt.savefig(os.path.join(
                plot_path, f'fixed_context_flow_density_{i}.png'),
                bbox_inches='tight', pad_inches=0)
class FlowExperiment(BaseExperiment):
    """Training/evaluation harness for image flow models.

    Simpler sibling of the manifold FlowExperiment: TensorBoard / W&B
    logging, checkpoint resumption and bits/dim train/eval loops on top of
    ``BaseExperiment``.
    """

    # Root directory for all image experiment logs.
    log_base = os.path.join(get_survae_path(), 'experiments/image/log')
    # Bookkeeping-only args that should not be logged as hyperparameters.
    no_log_keys = ['project', 'name', 'log_tb', 'log_wandb', 'check_every',
                   'eval_every', 'device', 'parallel', 'pin_memory',
                   'num_workers']

    def __init__(self, args, data_id, model_id, optim_id, train_loader,
                 eval_loader, model, optimizer, scheduler_iter,
                 scheduler_epoch):
        # Normalize args. CONSISTENCY FIX: also treat 0 as "only at the
        # final epoch", matching the other experiment classes in this file
        # (previously only None was handled here).
        if args.eval_every is None or args.eval_every == 0:
            args.eval_every = args.epochs
        if args.check_every is None or args.check_every == 0:
            args.check_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        # Move model to target device, optionally wrapping for DataParallel.
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        log_path = os.path.join(self.log_base, data_id, model_id, optim_id,
                                args.name)
        super(FlowExperiment, self).__init__(
            model=model,
            optimizer=optimizer,
            scheduler_iter=scheduler_iter,
            scheduler_epoch=scheduler_epoch,
            log_path=log_path,
            eval_every=args.eval_every,
            check_every=args.check_every)

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.data_id = data_id
        self.model_id = model_id
        self.optim_id = optim_id

        # Store data loaders
        self.train_loader = train_loader
        self.eval_loader = eval_loader

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text(
                "args", get_args_table(args_dict).get_html_string(),
                global_step=0)
        if args.log_wandb:
            wandb.init(config=args_dict, project=args.project, id=args.name,
                       dir=self.log_path)

    def log_fn(self, epoch, train_dict, eval_dict):
        """Log per-epoch metrics to TensorBoard and/or W&B; ``eval_dict``
        may be None on non-eval epochs."""
        # Tensorboard
        if self.args.log_tb:
            for metric_name, metric_value in train_dict.items():
                self.writer.add_scalar('base/{}'.format(metric_name),
                                       metric_value, global_step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    self.writer.add_scalar('eval/{}'.format(metric_name),
                                           metric_value,
                                           global_step=epoch + 1)
        # Weights & Biases
        if self.args.log_wandb:
            for metric_name, metric_value in train_dict.items():
                wandb.log({'base/{}'.format(metric_name): metric_value},
                          step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    wandb.log({'eval/{}'.format(metric_name): metric_value},
                              step=epoch + 1)

    def resume(self):
        """Reload the checkpoint named by ``args.resume`` and replay its
        logged metrics for the restored epochs."""
        resume_path = os.path.join(self.log_base, self.data_id,
                                   self.model_id, self.optim_id,
                                   self.args.resume, 'check')
        self.checkpoint_load(resume_path)
        for epoch in range(self.current_epoch):
            train_dict = {}
            for metric_name, metric_values in self.train_metrics.items():
                train_dict[metric_name] = metric_values[epoch]
            if epoch in self.eval_epochs:
                eval_dict = {}
                for metric_name, metric_values in self.eval_metrics.items():
                    eval_dict[metric_name] = metric_values[
                        self.eval_epochs.index(epoch)]
            else:
                eval_dict = None
            self.log_fn(epoch, train_dict=train_dict, eval_dict=eval_dict)

    def run(self):
        if self.args.resume:
            self.resume()
        super(FlowExperiment, self).run(epochs=self.args.epochs)

    def train_fn(self, epoch):
        """One training epoch; returns mean bits/dim over the epoch."""
        self.model.train()
        loss_sum = 0.0
        loss_count = 0
        for x in self.train_loader:
            self.optimizer.zero_grad()
            loss = elbo_bpd(self.model, x.to(self.args.device))
            loss.backward()
            self.optimizer.step()
            if self.scheduler_iter:
                self.scheduler_iter.step()
            loss_sum += loss.detach().cpu().item() * len(x)
            loss_count += len(x)
            print('Training. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'
                  .format(epoch + 1, self.args.epochs, loss_count,
                          len(self.train_loader.dataset),
                          loss_sum / loss_count),
                  end='\r')
        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()
        return {'bpd': loss_sum / loss_count}

    def eval_fn(self, epoch):
        """Evaluate mean bits/dim over the eval loader (no grad)."""
        self.model.eval()
        with torch.no_grad():
            loss_sum = 0.0
            loss_count = 0
            for x in self.eval_loader:
                loss = elbo_bpd(self.model, x.to(self.args.device))
                loss_sum += loss.detach().cpu().item() * len(x)
                loss_count += len(x)
                print(
                    'Evaluating. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'
                    .format(epoch + 1, self.args.epochs, loss_count,
                            len(self.eval_loader.dataset),
                            loss_sum / loss_count),
                    end='\r')
            print('')
        return {'bpd': loss_sum / loss_count}
class DropoutExperiment(BaseExperiment):
    """MC-dropout student baseline for the student/teacher experiments.

    Trains a dropout network to predict teacher samples from a
    conditioning transform of those samples, evaluates an MC predictive
    log-likelihood / RMSE, and renders the standard sample/density plots.
    """

    # Root directory for all student/teacher experiment logs.
    log_base = os.path.join(get_survae_path(),
                            'experiments/student_teacher/log')
    # Bookkeeping-only args that should not be logged as hyperparameters.
    no_log_keys = ['project', 'name', 'log_tb', 'log_wandb', 'eval_every',
                   'device', 'parallel', 'pin_memory', 'num_workers']

    def __init__(self, args, data_id, model_id, optim_id, model, teacher,
                 optimizer, scheduler_iter, scheduler_epoch):
        # Normalize args: None/0 eval cadence means "only at final epoch".
        if args.eval_every is None or args.eval_every == 0:
            args.eval_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        cond_id = args.cond_trans.lower()
        arch_id = f"Dropout_h{args.hidden_units}"
        seed_id = f"seed{args.seed}"
        if args.name == "debug":
            log_path = os.path.join(self.log_base, "debug", model_id,
                                    data_id, cond_id, arch_id, seed_id,
                                    time.strftime("%Y-%m-%d_%H-%M-%S"))
        else:
            log_path = os.path.join(self.log_base, model_id, data_id,
                                    cond_id, arch_id, seed_id, args.name)

        # Move models
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        super(DropoutExperiment, self).__init__(
            model=model,
            optimizer=optimizer,
            scheduler_iter=scheduler_iter,
            scheduler_epoch=scheduler_epoch,
            log_path=log_path,
            eval_every=args.eval_every)

        # student teacher args: teacher is frozen, used only for sampling.
        teacher = teacher.to(args.device)
        self.teacher = teacher
        self.teacher.eval()
        # 'split*'/'multiply*' conditioning transforms produce a scalar
        # context, all others a 2-d context.
        self.cond_size = 1 if cond_id.startswith('split') or \
            cond_id.startswith('multiply') else 2

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.model_id = model_id
        self.data_id = data_id
        self.cond_id = cond_id
        self.optim_id = optim_id
        self.arch_id = arch_id
        self.seed_id = seed_id

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text(
                "args", get_args_table(args_dict).get_html_string(),
                global_step=0)
        if args.log_wandb:
            wandb.init(config=args_dict, project=args.project, id=args.name,
                       dir=self.log_path)

        # training params
        self.max_grad_norm = args.max_grad_norm

        # Automatic mixed precision; disabled under DataParallel (bigger
        # changes needed: @autocast() decoration on each forward).
        # BUG FIX: the old check `int(str(torch.__version__)[2]) >= 7`
        # inspected a single character and mis-classified e.g. "1.10.0"
        # and "2.0.1"; parse (major, minor) properly instead.
        pytorch_170 = self._torch_version_at_least(1, 7)
        self.amp = args.amp and args.parallel != 'dp' and pytorch_170
        if self.amp:
            # GradScaler is only available in pytorch 1.7.0+
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        # save model architecture for reference
        self.save_architecture()

    @staticmethod
    def _torch_version_at_least(major, minor):
        """Return True when ``torch.__version__`` is >= ``major.minor``.

        Tolerates local-version suffixes such as "1.7.0+cu101"; returns
        False when the version string cannot be parsed.
        """
        try:
            parts = str(torch.__version__).split('.')
            return (int(parts[0]), int(parts[1])) >= (major, minor)
        except (ValueError, IndexError):
            return False

    def log_fn(self, epoch, train_dict, eval_dict):
        """Log per-epoch metrics to TensorBoard and/or W&B; ``eval_dict``
        may be None on non-eval epochs."""
        # Tensorboard
        if self.args.log_tb:
            for metric_name, metric_value in train_dict.items():
                self.writer.add_scalar('base/{}'.format(metric_name),
                                       metric_value, global_step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    self.writer.add_scalar('eval/{}'.format(metric_name),
                                           metric_value,
                                           global_step=epoch + 1)
        # Weights & Biases
        if self.args.log_wandb:
            for metric_name, metric_value in train_dict.items():
                wandb.log({'base/{}'.format(metric_name): metric_value},
                          step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    wandb.log({'eval/{}'.format(metric_name): metric_value},
                              step=epoch + 1)

    def resume(self):
        """Reload the checkpoint named by ``args.resume`` and replay its
        logged metrics for the restored epochs."""
        resume_path = os.path.join(self.log_base, self.model_id,
                                   self.data_id, self.cond_id, self.arch_id,
                                   self.seed_id, self.args.resume, 'check')
        self.checkpoint_load(resume_path)
        for epoch in range(self.current_epoch):
            train_dict = {}
            for metric_name, metric_values in self.train_metrics.items():
                train_dict[metric_name] = metric_values[epoch]
            if epoch in self.eval_epochs:
                eval_dict = {}
                for metric_name, metric_values in self.eval_metrics.items():
                    eval_dict[metric_name] = metric_values[
                        self.eval_epochs.index(epoch)]
            else:
                eval_dict = None
            self.log_fn(epoch, train_dict=train_dict, eval_dict=eval_dict)

    def run(self):
        if self.args.resume:
            self.resume()
        super(DropoutExperiment, self).run(epochs=self.args.epochs)

    def train_fn(self, epoch):
        """Dispatch to the AMP or full-precision training loop."""
        if self.amp:
            # use automatic mixed precision
            return self._train_amp(epoch)
        else:
            return self._train(epoch)

    def _train_amp(self, epoch):
        """
        Same training procedure as ``_train`` but uses half precision to
        speed up training on GPUs. Only works on SOME GPUs and recent
        versions of Pytorch.
        """
        self.model.train()
        self.teacher.eval()
        loss_sum = 0.0
        loss_count = 0
        # CONSISTENCY/BUG FIX: guard with max(1, ...) like _train does,
        # otherwise train_samples < batch_size yields zero iterations and
        # a ZeroDivisionError in the return below.
        iters_per_epoch = max(
            1, self.args.train_samples // self.args.batch_size)
        for i in range(iters_per_epoch):
            batch_size = self.args.batch_size

            # Draw a fresh training batch from the frozen teacher.
            with torch.no_grad():
                y = self.teacher.sample(num_samples=self.args.batch_size)
                x = self.cond_fn(y)

            # Cast operations to mixed precision
            with torch.cuda.amp.autocast():
                loss = self.model.loss(x, y)

            # Scale loss and call backward() to create scaled gradients
            self.scaler.scale(loss).backward()

            if self.max_grad_norm > 0:
                # Unscales the gradients of optimizer's assigned params
                # in-place so clipping operates on true gradient norms.
                self.scaler.unscale_(self.optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            # Unscale gradients and call (or skip) optimizer.step()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            if self.scheduler_iter:
                self.scheduler_iter.step()
            self.optimizer.zero_grad(set_to_none=True)

            # accumulate loss and report
            loss_sum += loss.detach().cpu().item() * batch_size
            loss_count += batch_size
            self.log_epoch("Training", loss_count, self.args.train_samples,
                           loss_sum)
        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()
        return {'nll': loss_sum / loss_count}

    def _train(self, epoch):
        """Full-precision training loop; returns mean NLL."""
        self.model.train()
        self.teacher.eval()
        loss_sum = 0.0
        loss_count = 0
        iters_per_epoch = max(
            1, self.args.train_samples // self.args.batch_size)
        for i in range(iters_per_epoch):
            with torch.no_grad():
                y = self.teacher.sample(num_samples=self.args.batch_size)
                x = self.cond_fn(y)
            self.optimizer.zero_grad()
            loss = self.model.loss(x, y)
            loss.backward()
            if self.max_grad_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()
            if self.scheduler_iter:
                self.scheduler_iter.step()
            loss_sum += loss.detach().cpu().item() * self.args.batch_size
            loss_count += self.args.batch_size
            self.log_epoch("Training", loss_count, self.args.train_samples,
                           loss_sum)
        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()
        return {'nll': loss_sum / loss_count}

    def eval_fn(self, epoch):
        """MC-dropout evaluation: K forward passes, MC predictive
        log-likelihood and RMSE on fresh teacher samples."""
        K_test = 2  # Number of MC samples
        self.model.eval()
        self.teacher.eval()
        with torch.no_grad():
            y = self.teacher.sample(num_samples=self.args.test_samples)
            x = self.cond_fn(y)
            MC_samples = [self.model(x) for _ in range(K_test)]
            # BUG FIX: move to CPU before .numpy() — the stacked means
            # below already do this, but x/y previously called .numpy()
            # directly, which raises on CUDA tensors.
            x = x.cpu().numpy()
            y = y.cpu().numpy()
            means = torch.stack([tup[0] for tup in MC_samples]).view(
                K_test, x.shape[0], 2).cpu().data.numpy()
            logvar = torch.stack([tup[1] for tup in MC_samples]).view(
                K_test, x.shape[0], 2).cpu().data.numpy()
            test_ll = (-0.5 * np.exp(-logvar) * (means - y.squeeze())**2.0
                       - 0.5 * logvar - 0.5 * np.log(2 * np.pi))
            test_ll = np.sum(np.sum(test_ll, -1), -1)
            test_ll = logsumexp(test_ll) - np.log(K_test)
            rmse = np.mean((np.mean(means, 0) - y)**2.0)**0.5
            print("Eval:", rmse, "test LL:", test_ll)
            return {'rmse': rmse}

    def plot_fn(self):
        """Render teacher data, model samples and model densities under
        both a varying and several fixed conditioning contexts.

        NOTE(review): tensors here call ``.numpy()`` without ``.cpu()`` in
        places, implicitly assuming CPU execution — confirm before running
        plotting on GPU. ``y_var`` is ``exp(logvar)`` (a variance) but is
        passed to ``np.random.normal`` where a std is expected — preserved
        as-is to keep plots comparable.
        """
        plot_path = os.path.join(self.log_path, "samples/")
        # BUG FIX: os.mkdir fails when intermediate dirs are missing;
        # makedirs(exist_ok=True) is also race-free.
        os.makedirs(plot_path, exist_ok=True)
        if self.args.dataset == 'face_einstein':
            bounds = [[0, 1], [0, 1]]
        else:
            bounds = [[-4, 4], [-4, 4]]

        # plot true data
        test_data = self.teacher.sample(
            num_samples=self.args.test_samples).data.numpy()
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.hist2d(test_data[..., 0], test_data[..., 1], bins=256,
                   range=bounds)
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.savefig(os.path.join(plot_path, f'teacher.png'),
                    bbox_inches='tight', pad_inches=0)

        # Plot samples while varying the context
        with torch.no_grad():
            y = self.teacher.sample(num_samples=self.args.num_samples)
            x = self.cond_fn(y)
            y_mean, y_logvar, _ = self.model(x)
        y_mean = y_mean.cpu().numpy()
        y_var = np.exp(y_logvar.cpu().numpy())
        temperature = 0.4
        samples = [np.random.normal(y_mean[:, i], y_var[:, i] * temperature,
                                    self.args.num_samples).T[:, np.newaxis]
                   for i in range(y_mean.shape[1])]
        samples = np.hstack(samples)
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.hist2d(samples[..., 0], samples[..., 1], bins=256, range=bounds)
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.savefig(os.path.join(plot_path,
                                 f'varying_context_flow_samples.png'),
                    bbox_inches='tight', pad_inches=0)

        # Plot density over a regular grid of y values.
        xv, yv = torch.meshgrid([
            torch.linspace(bounds[0][0], bounds[0][1], self.args.grid_size),
            torch.linspace(bounds[1][0], bounds[1][1], self.args.grid_size)])
        y = torch.cat([xv.reshape(-1, 1), yv.reshape(-1, 1)],
                      dim=-1).to(self.args.device)
        with torch.no_grad():
            x = self.cond_fn(y)
            means, logvar, _ = self.model(x)
            logprobs = (-0.5 * torch.exp(-logvar) * (means - y)**2.0
                        - 0.5 * logvar - 0.5 * np.log(2 * np.pi))
            logprobs = torch.sum(logprobs, dim=1)
            # Shift so max logprob is 0 for a stable exp / color scale.
            logprobs = logprobs - logprobs.max()
            probs = torch.exp(logprobs)
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.pcolormesh(xv, yv, probs.reshape(xv.shape), shading='auto')
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.clim(0, self.args.clim)
        plt.savefig(os.path.join(plot_path,
                                 f'varying_context_flow_density.png'),
                    bbox_inches='tight', pad_inches=0)

        # plot samples with fixed context
        for i, x in enumerate(self.context_permutations()):
            with torch.no_grad():
                y_mean, y_logvar, _ = self.model(
                    x.expand((self.args.num_samples, self.cond_size)))
            y_mean = y_mean.cpu().numpy()
            y_var = np.exp(y_logvar.cpu().numpy())
            temperature = 1.5
            samples = [np.random.normal(y_mean[:, i],
                                        y_var[:, i] * temperature,
                                        self.args.num_samples).T[:, np.newaxis]
                       for i in range(y_mean.shape[1])]
            samples = np.hstack(samples)
            plt.figure(figsize=(self.args.pixels / self.args.dpi,
                                self.args.pixels / self.args.dpi),
                       dpi=self.args.dpi)
            plt.hist2d(samples[..., 0], samples[..., 1], bins=256,
                       range=bounds)
            plt.xlim(bounds[0])
            plt.ylim(bounds[1])
            plt.axis('off')
            plt.savefig(os.path.join(
                plot_path, f'fixed_context_flow_samples_{i}.png'),
                bbox_inches='tight', pad_inches=0)

            # Plot density for this fixed context.
            xv, yv = torch.meshgrid([
                torch.linspace(bounds[0][0], bounds[0][1],
                               self.args.grid_size),
                torch.linspace(bounds[1][0], bounds[1][1],
                               self.args.grid_size)])
            y = torch.cat([xv.reshape(-1, 1), yv.reshape(-1, 1)],
                          dim=-1).to(self.args.device)
            with torch.no_grad():
                means, logvar, _ = self.model(x.view(1, self.cond_size))
                # Broadcast the single-context prediction over the grid.
                means = means.expand((y.size(0), 2))
                logvar = logvar.expand((y.size(0), 2))
                logprobs = (-0.5 * torch.exp(-logvar) * (means - y)**2.0
                            - 0.5 * logvar - 0.5 * np.log(2 * np.pi))
                logprobs = torch.sum(logprobs, dim=1)
                logprobs = logprobs - logprobs.max()
                probs = torch.exp(logprobs)
            plt.figure(figsize=(self.args.pixels / self.args.dpi,
                                self.args.pixels / self.args.dpi),
                       dpi=self.args.dpi)
            plt.pcolormesh(xv, yv, probs.reshape(xv.shape), shading='auto')
            plt.xlim(bounds[0])
            plt.ylim(bounds[1])
            plt.axis('off')
            plt.clim(0, self.args.clim)
            plt.savefig(os.path.join(
                plot_path, f'fixed_context_flow_density_{i}.png'),
                bbox_inches='tight', pad_inches=0)
class StudentExperiment(BaseExperiment):
    """Train a conditional student flow to imitate a fixed teacher density.

    Each "epoch" draws fresh samples from the (frozen) teacher and minimizes
    the student's NLL on them, conditioned on a context derived from the
    samples via ``self.cond_fn`` (provided outside this class — presumably by
    the base class or a mixin; TODO confirm).
    """

    # Root directory under which all run logs/checkpoints are placed.
    log_base = os.path.join(get_survae_path(),
                            'experiments/student_teacher/log')

    # argparse keys excluded from the tensorboard/wandb config dump.
    no_log_keys = [
        'project', 'name', 'log_tb', 'log_wandb', 'eval_every', 'device',
        'parallel', 'pin_memory', 'num_workers'
    ]

    def __init__(self, args, data_id, model_id, optim_id, model, teacher,
                 optimizer, scheduler_iter, scheduler_epoch):
        """Set up log paths, move models to the device, and init logging.

        ``model`` is the trainable student; ``teacher`` is frozen and only
        used as a data source.
        """
        # Fill in defaults for unset args.
        if args.eval_every is None or args.eval_every == 0:
            args.eval_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        # Identifiers that encode the architecture into the log path.
        aug_or_abs = 'abs' if args.augment_size == 0 else f"aug{args.augment_size}"
        hidden = '_'.join([str(u) for u in args.hidden_units])
        cond_id = args.cond_trans.lower()
        arch_id = (f"flow_{aug_or_abs}_k{args.num_flows}_h{hidden}_"
                   f"{'affine' if args.affine else 'additive'}"
                   f"{'_actnorm' if args.actnorm else ''}")
        seed_id = f"seed{args.seed}"
        if args.name == "debug":
            # Debug runs get a unique timestamped folder so they never
            # clobber real experiments.
            log_path = os.path.join(self.log_base, "debug", model_id, data_id,
                                    cond_id, arch_id, seed_id,
                                    time.strftime("%Y-%m-%d_%H-%M-%S"))
        else:
            log_path = os.path.join(self.log_base, model_id, data_id, cond_id,
                                    arch_id, seed_id, args.name)

        # Move model to the target device (and optionally wrap for data
        # parallelism).
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        super(StudentExperiment, self).__init__(
            model=model,
            optimizer=optimizer,
            scheduler_iter=scheduler_iter,
            scheduler_epoch=scheduler_epoch,
            log_path=log_path,
            eval_every=args.eval_every)

        # The teacher is frozen: it only provides training samples.
        teacher = teacher.to(args.device)
        self.teacher = teacher
        self.teacher.eval()
        # 'split*'/'multiply*' conditioning transforms yield a 1-d context,
        # everything else a 2-d one.
        self.cond_size = 1 if cond_id.startswith(
            'split') or cond_id.startswith('multiply') else 2

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.model_id = model_id
        self.data_id = data_id
        self.cond_id = cond_id
        self.optim_id = optim_id
        self.arch_id = arch_id
        self.seed_id = seed_id

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args",
                                 get_args_table(args_dict).get_html_string(),
                                 global_step=0)
        if args.log_wandb:
            wandb.init(config=args_dict,
                       project=args.project,
                       id=args.name,
                       dir=self.log_path)

        # training params
        self.max_grad_norm = args.max_grad_norm

        # Automatic mixed precision requires torch>=1.7 and is incompatible
        # with the 'dp' wrapper (would need @autocast() on each forward).
        # BUGFIX: the old check `int(str(torch.__version__)[2]) >= 7` read a
        # single character and misclassified versions like 1.10.x or 2.x.
        major, minor = (int(v) for v in torch.__version__.split('.')[:2])
        pytorch_170 = (major, minor) >= (1, 7)
        self.amp = args.amp and args.parallel != 'dp' and pytorch_170
        if self.amp:
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        # save model architecture for reference
        self.save_architecture()

    def log_fn(self, epoch, train_dict, eval_dict):
        """Write epoch metrics to tensorboard and/or wandb (1-indexed step)."""
        # Tensorboard
        if self.args.log_tb:
            for metric_name, metric_value in train_dict.items():
                self.writer.add_scalar('base/{}'.format(metric_name),
                                       metric_value,
                                       global_step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    self.writer.add_scalar('eval/{}'.format(metric_name),
                                           metric_value,
                                           global_step=epoch + 1)
        # Weights & Biases
        if self.args.log_wandb:
            for metric_name, metric_value in train_dict.items():
                wandb.log({'base/{}'.format(metric_name): metric_value},
                          step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    wandb.log({'eval/{}'.format(metric_name): metric_value},
                              step=epoch + 1)

    def resume(self):
        """Load the checkpoint named by ``args.resume`` and replay its metric
        history into the loggers so curves stay continuous."""
        resume_path = os.path.join(self.log_base, self.model_id, self.data_id,
                                   self.cond_id, self.arch_id, self.seed_id,
                                   self.args.resume, 'check')
        self.checkpoint_load(resume_path)
        for epoch in range(self.current_epoch):
            train_dict = {}
            for metric_name, metric_values in self.train_metrics.items():
                train_dict[metric_name] = metric_values[epoch]
            if epoch in self.eval_epochs:
                eval_dict = {}
                for metric_name, metric_values in self.eval_metrics.items():
                    eval_dict[metric_name] = metric_values[
                        self.eval_epochs.index(epoch)]
            else:
                eval_dict = None
            self.log_fn(epoch, train_dict=train_dict, eval_dict=eval_dict)

    def run(self):
        """Optionally resume, then run the standard training loop."""
        if self.args.resume:
            self.resume()
        super(StudentExperiment, self).run(epochs=self.args.epochs)

    def train_fn(self, epoch):
        """Dispatch to the AMP or full-precision training step."""
        if self.amp:
            # use automatic mixed precision
            return self._train_amp(epoch)
        return self._train(epoch)

    def _train_amp(self, epoch):
        """One training epoch under automatic mixed precision.

        Same procedure as ``_train`` but uses half precision to speed up
        training on GPUs; only works on some GPUs and recent PyTorch.
        Returns {'nll': mean negative log-likelihood over the epoch}.
        """
        self.model.train()
        self.teacher.eval()
        loss_sum = 0.0
        loss_count = 0
        iters_per_epoch = self.args.train_samples // self.args.batch_size
        for i in range(iters_per_epoch):
            batch_size = self.args.batch_size
            # Draw a fresh batch from the frozen teacher.
            with torch.no_grad():
                y = self.teacher.sample(num_samples=self.args.batch_size)
                x = self.cond_fn(y)

            with torch.cuda.amp.autocast():
                loss = -self.model.log_prob(y, context=x).mean()

            # Scale loss and call backward() to create scaled gradients.
            self.scaler.scale(loss).backward()
            if self.max_grad_norm > 0:
                # Unscale in-place so clipping sees true gradient magnitudes.
                self.scaler.unscale_(self.optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            # Step (or skip, on inf/nan grads) and update the scale factor.
            self.scaler.step(self.optimizer)
            self.scaler.update()
            if self.scheduler_iter:
                self.scheduler_iter.step()
            self.optimizer.zero_grad(set_to_none=True)

            # accumulate loss and report
            loss_sum += loss.detach().cpu().item() * batch_size
            loss_count += batch_size
            self.log_epoch("Training", loss_count, self.args.train_samples,
                           loss_sum)
        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()
        return {'nll': loss_sum / loss_count}

    def _train(self, epoch):
        """One full-precision training epoch; returns {'nll': mean NLL}."""
        self.model.train()
        self.teacher.eval()
        loss_sum = 0.0
        loss_count = 0
        iters_per_epoch = self.args.train_samples // self.args.batch_size
        for i in range(iters_per_epoch):
            # Draw a fresh batch from the frozen teacher.
            with torch.no_grad():
                y = self.teacher.sample(num_samples=self.args.batch_size)
                x = self.cond_fn(y)
            self.optimizer.zero_grad()
            loss = -self.model.log_prob(y, context=x).mean()
            loss.backward()
            if self.max_grad_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()
            if self.scheduler_iter:
                self.scheduler_iter.step()
            loss_sum += loss.detach().cpu().item() * self.args.batch_size
            loss_count += self.args.batch_size
            self.log_epoch("Training", loss_count, self.args.train_samples,
                           loss_sum)
        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()
        return {'nll': loss_sum / loss_count}

    def eval_fn(self):
        """Evaluate mean NLL on freshly drawn teacher samples."""
        self.model.eval()
        self.teacher.eval()
        with torch.no_grad():
            loss_sum = 0.0
            loss_count = 0
            iters_per_epoch = self.args.test_samples // self.args.batch_size
            for i in range(iters_per_epoch):
                y = self.teacher.sample(num_samples=self.args.batch_size)
                x = self.cond_fn(y)
                loss = -self.model.log_prob(y, context=x).mean()
                loss_sum += loss.detach().cpu().item() * self.args.batch_size
                loss_count += self.args.batch_size
                print(
                    'Evaluating. Epoch: {}/{}, Datapoint: {}/{}, NLL: {:.3f}'.
                    format(self.current_epoch + 1, self.args.epochs,
                           loss_count, self.args.test_samples,
                           loss_sum / loss_count),
                    end='\r')
        print('')
        return {'nll': loss_sum / loss_count}

    def plot_fn(self):
        """Save 2-d histograms/densities of teacher data and student output.

        Writes PNGs into ``<log_path>/samples/``: the teacher distribution,
        student samples/density with varying context, and student samples/
        density for each fixed context from ``self.context_permutations()``.
        """
        plot_path = os.path.join(self.log_path, "samples/")
        if not os.path.exists(plot_path):
            os.mkdir(plot_path)

        # Axis bounds depend on the dataset's natural support.
        if self.args.dataset == 'face_einstein':
            bounds = [[0, 1], [0, 1]]
        else:
            bounds = [[-4, 4], [-4, 4]]

        # plot true data
        # BUGFIX: move to CPU before .numpy() — the teacher lives on
        # args.device, and .numpy() raises on CUDA tensors.
        test_data = self.teacher.sample(
            num_samples=self.args.test_samples).cpu().data.numpy()
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.hist2d(test_data[..., 0], test_data[..., 1], bins=256,
                   range=bounds)
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.savefig(os.path.join(plot_path, 'teacher.png'),
                    bbox_inches='tight',
                    pad_inches=0)

        # Plot samples while varying the context (low temperature).
        with torch.no_grad():
            y = self.teacher.sample(num_samples=self.args.num_samples)
            x = self.cond_fn(y)
            samples = self.model.sample(context=x, temperature=0.4)
            samples = samples.cpu().numpy()
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.hist2d(samples[..., 0], samples[..., 1], bins=256, range=bounds)
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.savefig(os.path.join(plot_path,
                                 'varying_context_flow_samples.png'),
                    bbox_inches='tight',
                    pad_inches=0)

        # Plot model density on a regular grid over the plot bounds.
        xv, yv = torch.meshgrid([
            torch.linspace(bounds[0][0], bounds[0][1], self.args.grid_size),
            torch.linspace(bounds[1][0], bounds[1][1], self.args.grid_size)
        ])
        y = torch.cat([xv.reshape(-1, 1), yv.reshape(-1, 1)],
                      dim=-1).to(self.args.device)
        x = self.cond_fn(y)
        with torch.no_grad():
            logprobs = self.model.log_prob(y, context=x)
        # Shift by the max so exp() is numerically stable for plotting.
        logprobs = logprobs - logprobs.max().item()
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.pcolormesh(xv, yv, logprobs.exp().reshape(xv.shape))
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        print('Range:', logprobs.exp().min().item(),
              logprobs.exp().max().item())
        print('Limits:', 0.0, self.args.clim)
        plt.clim(0, self.args.clim)
        plt.savefig(os.path.join(plot_path,
                                 'varying_context_flow_density.png'),
                    bbox_inches='tight',
                    pad_inches=0)

        # plot samples with fixed context (higher temperature)
        for i, x in enumerate(self.context_permutations()):
            with torch.no_grad():
                samples = self.model.sample(context=x.expand(
                    (self.args.num_samples, self.cond_size)),
                                            temperature=1.5).cpu().numpy()
            plt.figure(figsize=(self.args.pixels / self.args.dpi,
                                self.args.pixels / self.args.dpi),
                       dpi=self.args.dpi)
            plt.hist2d(samples[..., 0], samples[..., 1], bins=256,
                       range=bounds)
            plt.xlim(bounds[0])
            plt.ylim(bounds[1])
            plt.axis('off')
            plt.savefig(os.path.join(plot_path,
                                     f'fixed_context_flow_samples_{i}.png'),
                        bbox_inches='tight',
                        pad_inches=0)

            # Density for this fixed context over the same grid.
            xv, yv = torch.meshgrid([
                torch.linspace(bounds[0][0], bounds[0][1],
                               self.args.grid_size),
                torch.linspace(bounds[1][0], bounds[1][1],
                               self.args.grid_size)
            ])
            y = torch.cat([xv.reshape(-1, 1),
                           yv.reshape(-1, 1)],
                          dim=-1).to(self.args.device)
            with torch.no_grad():
                logprobs = self.model.log_prob(y,
                                               context=x.expand(
                                                   (y.size(0),
                                                    self.cond_size)))
            logprobs = logprobs - logprobs.max().item()
            plt.figure(figsize=(self.args.pixels / self.args.dpi,
                                self.args.pixels / self.args.dpi),
                       dpi=self.args.dpi)
            plt.pcolormesh(xv, yv, logprobs.exp().reshape(xv.shape))
            plt.xlim(bounds[0])
            plt.ylim(bounds[1])
            plt.axis('off')
            print('Range:', logprobs.exp().min().item(),
                  logprobs.exp().max().item())
            print('Limits:', 0.0, self.args.clim)
            plt.clim(0, self.args.clim)
            plt.savefig(os.path.join(plot_path,
                                     f'fixed_context_flow_density_{i}.png'),
                        bbox_inches='tight',
                        pad_inches=0)
class FlowExperiment(BaseExperiment):
    """Flow experiment trained from data loaders with a bits-per-dim loss.

    Supports unconditional training (``elbo_bpd``) and conditional /
    super-resolution training (``cond_elbo_bpd``), optional AMP, sampling,
    and early stopping.
    """

    # Root directory under which all run logs/checkpoints are placed.
    log_base = os.path.join(get_survae_path(), 'experiments/gbnf/log')

    # argparse keys excluded from the tensorboard/wandb config dump.
    no_log_keys = [
        'project', 'name', 'log_tb', 'log_wandb', 'check_every', 'eval_every',
        'device', 'parallel', 'pin_memory', 'num_workers'
    ]

    def __init__(self, args, data_id, model_id, optim_id, train_loader,
                 eval_loader, model, optimizer, scheduler_iter,
                 scheduler_epoch):
        """Set up log paths, move the model to the device, and init logging."""
        # Fill in defaults for unset args.
        if args.eval_every is None or args.eval_every == 0:
            args.eval_every = args.epochs
        if args.check_every is None or args.check_every == 0:
            args.check_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        # Identifiers that encode the architecture into the log path.
        arch_id = (f"{args.coupling_network}_coupling_"
                   f"scales{args.num_scales}_steps{args.num_steps}")
        seed_id = f"seed{args.seed}"
        if args.name == "debug":
            # Debug runs get a unique timestamped folder so they never
            # clobber real experiments.
            log_path = os.path.join(self.log_base, "debug", data_id, model_id,
                                    arch_id, optim_id, seed_id,
                                    time.strftime("%Y-%m-%d_%H-%M-%S"))
        else:
            log_path = os.path.join(self.log_base, data_id, model_id, arch_id,
                                    optim_id, seed_id, args.name)

        # Move model
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        super(FlowExperiment, self).__init__(model=model,
                                             optimizer=optimizer,
                                             scheduler_iter=scheduler_iter,
                                             scheduler_epoch=scheduler_epoch,
                                             log_path=log_path,
                                             eval_every=args.eval_every,
                                             check_every=args.check_every)

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.data_id = data_id
        self.model_id = model_id
        self.optim_id = optim_id
        self.arch_id = arch_id
        self.seed_id = seed_id

        # Store data loaders
        self.train_loader = train_loader
        self.eval_loader = eval_loader

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args",
                                 get_args_table(args_dict).get_html_string(),
                                 global_step=0)
        if args.log_wandb:
            wandb.init(config=args_dict,
                       project=args.project,
                       id=args.name,
                       dir=self.log_path)

        # training params
        self.max_grad_norm = args.max_grad_norm

        # Automatic mixed precision requires torch>=1.7 and is incompatible
        # with the 'dp' wrapper (would need @autocast() on each forward).
        # BUGFIX: the old check `int(str(torch.__version__)[2]) >= 7` read a
        # single character and misclassified versions like 1.10.x or 2.x.
        major, minor = (int(v) for v in torch.__version__.split('.')[:2])
        pytorch_170 = (major, minor) >= (1, 7)
        self.amp = args.amp and args.parallel != 'dp' and pytorch_170
        if self.amp:
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        # save model architecture for reference
        self.save_architecture()

    def log_fn(self, epoch, train_dict, eval_dict):
        """Write epoch metrics to tensorboard and/or wandb (1-indexed step)."""
        # Tensorboard
        if self.args.log_tb:
            for metric_name, metric_value in train_dict.items():
                self.writer.add_scalar('base/{}'.format(metric_name),
                                       metric_value,
                                       global_step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    self.writer.add_scalar('eval/{}'.format(metric_name),
                                           metric_value,
                                           global_step=epoch + 1)
        # Weights & Biases
        if self.args.log_wandb:
            for metric_name, metric_value in train_dict.items():
                wandb.log({'base/{}'.format(metric_name): metric_value},
                          step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    wandb.log({'eval/{}'.format(metric_name): metric_value},
                              step=epoch + 1)

    def resume(self):
        """Load the checkpoint named by ``args.resume`` and replay its metric
        history into the loggers so curves stay continuous."""
        resume_path = os.path.join(self.log_base, self.data_id, self.model_id,
                                   self.arch_id, self.optim_id, self.seed_id,
                                   self.args.resume, 'check')
        self.checkpoint_load(resume_path)
        for epoch in range(self.current_epoch):
            train_dict = {}
            for metric_name, metric_values in self.train_metrics.items():
                train_dict[metric_name] = metric_values[epoch]
            if epoch in self.eval_epochs:
                eval_dict = {}
                for metric_name, metric_values in self.eval_metrics.items():
                    eval_dict[metric_name] = metric_values[
                        self.eval_epochs.index(epoch)]
            else:
                eval_dict = None
            self.log_fn(epoch, train_dict=train_dict, eval_dict=eval_dict)

    def run(self):
        """Optionally resume, then run the standard training loop."""
        if self.args.resume:
            self.resume()
        super(FlowExperiment, self).run(epochs=self.args.epochs)

    def train_fn(self, epoch):
        """Dispatch to the AMP or full-precision training step."""
        if self.amp:
            # use automatic mixed precision
            return self._train_amp(epoch)
        return self._train(epoch)

    def _train_amp(self, epoch):
        """One training epoch under automatic mixed precision.

        Same procedure as ``_train`` but uses half precision to speed up
        training on GPUs; only works on some GPUs and recent PyTorch.
        Returns {'bpd': mean bits-per-dim over the epoch}.
        """
        self.model.train()
        loss_sum = 0.0
        loss_count = 0
        # Report roughly 20 times per epoch.
        print_every = max(
            1, (len(self.train_loader.dataset) // self.args.batch_size) // 20)
        for i, x in enumerate(self.train_loader):
            # Cast operations to mixed precision
            if self.args.super_resolution or self.args.conditional:
                # Conditional batches arrive as (data, context) pairs.
                batch_size = len(x[0])
                with torch.cuda.amp.autocast():
                    loss = cond_elbo_bpd(self.model,
                                         x[0].to(self.args.device),
                                         context=x[1].to(self.args.device))
            else:
                batch_size = len(x)
                with torch.cuda.amp.autocast():
                    loss = elbo_bpd(self.model, x.to(self.args.device))

            # Scale loss and call backward() to create scaled gradients.
            self.scaler.scale(loss).backward()
            if self.max_grad_norm > 0:
                # Unscale in-place so clipping sees true gradient magnitudes.
                self.scaler.unscale_(self.optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            # Step (or skip, on inf/nan grads) and update the scale factor.
            self.scaler.step(self.optimizer)
            self.scaler.update()
            if self.scheduler_iter:
                self.scheduler_iter.step()
            self.optimizer.zero_grad(set_to_none=True)

            # accumulate loss and report
            loss_sum += loss.detach().cpu().item() * batch_size
            loss_count += batch_size
            if i % print_every == 0 or i == (len(self.train_loader) - 1):
                self.log_epoch("Training", loss_count,
                               len(self.train_loader.dataset), loss_sum)
        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()
        return {'bpd': loss_sum / loss_count}

    def _train(self, epoch):
        """One full-precision training epoch; returns {'bpd': mean bpd}."""
        self.model.train()
        loss_sum = 0.0
        loss_count = 0
        # Report roughly 20 times per epoch.
        print_every = max(
            1, (len(self.train_loader.dataset) // self.args.batch_size) // 20)
        for i, x in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            if self.args.super_resolution or self.args.conditional:
                # Conditional batches arrive as (data, context) pairs.
                batch_size = len(x[0])
                loss = cond_elbo_bpd(self.model,
                                     x[0].to(self.args.device),
                                     context=x[1].to(self.args.device))
            else:
                batch_size = len(x)
                loss = elbo_bpd(self.model, x.to(self.args.device))
            loss.backward()
            if self.max_grad_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()
            if self.scheduler_iter:
                self.scheduler_iter.step()
            loss_sum += loss.detach().cpu().item() * batch_size
            loss_count += batch_size
            if i % print_every == 0 or i == (len(self.train_loader) - 1):
                self.log_epoch("Training", loss_count,
                               len(self.train_loader.dataset), loss_sum)
        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()
        return {'bpd': loss_sum / loss_count}

    def eval_fn(self, epoch):
        """Evaluate mean bits-per-dim over the eval loader."""
        self.model.eval()
        with torch.no_grad():
            loss_sum = 0.0
            loss_count = 0
            for x in self.eval_loader:
                if self.args.super_resolution or self.args.conditional:
                    batch_size = len(x[0])
                    loss = cond_elbo_bpd(self.model,
                                         x[0].to(self.args.device),
                                         context=x[1].to(self.args.device))
                else:
                    batch_size = len(x)
                    loss = elbo_bpd(self.model, x.to(self.args.device))
                loss_sum += loss.detach().cpu().item() * batch_size
                loss_count += batch_size
                self.log_epoch("Evaluating", loss_count,
                               len(self.eval_loader.dataset), loss_sum)
        print('')
        return {'bpd': loss_sum / loss_count}

    def sample_fn(self, temperature=None, sample_new_batch=False):
        """Save model samples; on a fresh batch also save the real images.

        ``self.sample_batch`` is cached between calls (initialized elsewhere
        — presumably by the base class; TODO confirm) so repeated sampling
        uses the same reference batch unless ``sample_new_batch`` is set.
        """
        if self.args.samples < 1:
            return
        self.model.eval()
        get_new_batch = self.sample_batch is None or sample_new_batch
        if get_new_batch:
            self.sample_batch = next(iter(self.eval_loader))
        if self.args.super_resolution or self.args.conditional:
            imgs = self.sample_batch[0][:self.args.samples]
            context = self.sample_batch[1][:self.args.samples]
            self._cond_sample_fn(context,
                                 temperature=temperature,
                                 save_context=get_new_batch)
        else:
            imgs = self.sample_batch[:self.args.samples]
            self._sample_fn(temperature=temperature)
        if get_new_batch:
            # save real samples
            path_true_samples = '{}/samples/true_e{}_s{}.png'.format(
                self.log_path, self.current_epoch + 1, self.args.seed)
            self.save_images(imgs, path_true_samples)

    def _sample_fn(self, temperature=None):
        """Draw unconditional samples and save them as an image grid."""
        path_samples = '{}/samples/sample_e{}_s{}.png'.format(
            self.log_path, self.current_epoch + 1, self.args.seed)
        samples = self.model.sample(self.args.samples,
                                    temperature=temperature)
        self.save_images(samples, path_samples)

    def _cond_sample_fn(self, context, temperature=None, save_context=True):
        """Draw samples conditioned on ``context``; optionally save it too."""
        if save_context:
            # save low-resolution samples
            path_context = '{}/samples/context_e{}_s{}.png'.format(
                self.log_path, self.current_epoch + 1, self.args.seed)
            self.save_images(context, path_context)
        # save samples from model conditioned on context
        path_samples = '{}/samples/sample_e{}_s{}.png'.format(
            self.log_path, self.current_epoch + 1, self.args.seed)
        samples = self.model.sample(context.to(self.args.device),
                                    temperature=temperature)
        self.save_images(samples, path_samples)

    def save_images(self, imgs, file_path):
        """Save an image grid, rescaling integer-coded pixels to [0, 1]."""
        # BUGFIX: makedirs with exist_ok=True — the old exists()+mkdir()
        # check raced with concurrent runs and failed when the parent
        # directory was missing.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        out = imgs.cpu().float()
        if out.max().item() > 2:
            # Heuristic: values above 2 indicate raw bit-coded pixel values.
            out /= (2**self.args.num_bits - 1)
        vutils.save_image(out, file_path, nrow=self.args.nrow)

    def stop_early(self, loss_dict, epoch):
        """Early-stopping check against the best eval bpd seen so far.

        Returns ``(early_stop_flag, model_improved)``. With
        ``args.early_stop == 0`` the feature is disabled and every epoch is
        reported as an improvement.
        """
        if self.args.early_stop == 0:
            return False, True
        # else check if we've passed the early stopping threshold
        current_loss = loss_dict['bpd']
        model_improved = current_loss < self.best_loss
        if model_improved:
            early_stop_flag = False
            self.best_loss = current_loss
            self.best_loss_epoch = epoch
        else:
            # model didn't improve, do we consider it converged yet?
            early_stop_count = (epoch - self.best_loss_epoch)
            early_stop_flag = early_stop_count >= self.args.early_stop
            if early_stop_flag:
                print(
                    f'Stopping training early: no improvement after {self.args.early_stop} epochs (last improvement at epoch {self.best_loss_epoch})'
                )
        return early_stop_flag, model_improved