# Logging config
flags.DEFINE_integer('report_loss_every', 100,
                     'Number of iterations between reporting minibatch loss.')
flags.DEFINE_integer('train_epochs', 20, 'Maximum number of training epochs.')
# Experiment config
flags.DEFINE_integer('batch_size', 32, 'Mini-batch size.')
flags.DEFINE_float('learning_rate', 1e-5, 'SGD learning rate.')

# Parse flags
config = forge.config()

# Prepare environment
logdir = osp.join(config.results_dir, config.run_name)
logdir, resume_checkpoint = fet.init_checkpoint(
    logdir, config.data_config, config.model_config, config.resume)
checkpoint_name = osp.join(logdir, 'model.ckpt')

# Load data
train_loader = fet.load(config.data_config, config)
# Load model
model = fet.load(config.model_config, config)

# Print flags
fet.print_flags()
# Print model info
print(model)

# Setup optimizer
optimizer = optim.RMSprop(model.parameters(),
                          config.learning_rate)  # NOTE: call was truncated in the source; the learning-rate argument is assumed
def main():
    # ------------------------
    # SETUP
    # ------------------------

    # Parse flags
    config = forge.config()
    if config.debug:
        config.num_workers = 0
        config.batch_size = 2

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Setup checkpoint or resume
    logdir = osp.join(config.results_dir, config.run_name)
    logdir, resume_checkpoint = fet.init_checkpoint(
        logdir, config.data_config, config.model_config, config.resume)
    checkpoint_name = osp.join(logdir, 'model.ckpt')

    # Using GPU(S)?
    if torch.cuda.is_available() and config.gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        config.gpu = False
        torch.set_default_tensor_type('torch.FloatTensor')
    fprint(f"Use GPU: {config.gpu}")
    if config.gpu and config.multi_gpu and torch.cuda.device_count() > 1:
        fprint(f"Using {torch.cuda.device_count()} GPUs!")
        config.num_workers = torch.cuda.device_count() * config.num_workers
    else:
        config.multi_gpu = False

    # Print flags
    # fet.print_flags()  # TODO(martin) make this cleaner
    fprint(json.dumps(fet._flags.FLAGS.__flags, indent=4, sort_keys=True))

    # Setup TensorboardX SummaryWriter
    writer = SummaryWriter(logdir)

    # Load data
    train_loader, val_loader, test_loader = fet.load(config.data_config, config)

    num_elements = 3 * config.img_size**2  # Assume three input channels

    # Load model
    model = fet.load(config.model_config, config)
    fprint(model)

    if config.geco:
        # Goal is specified per pixel & channel so it doesn't need to
        # be changed for different resolutions etc.
        geco_goal = config.g_goal * num_elements
        # Scale step size to get similar update at different resolutions
        geco_lr = config.g_lr * (64**2 / config.img_size**2)
        geco = GECO(geco_goal, geco_lr, config.g_alpha, config.g_init,
                    config.g_min, config.g_speedup)
        beta = geco.beta
    else:
        beta = torch.tensor(config.beta)

    # Setup optimiser
    if config.optimiser == 'rmsprop':
        optimiser = optim.RMSprop(model.parameters(), config.learning_rate)
    elif config.optimiser == 'adam':
        optimiser = optim.Adam(model.parameters(), config.learning_rate)
    elif config.optimiser == 'sgd':
        optimiser = optim.SGD(model.parameters(), config.learning_rate, 0.9)

    # Try to restore model and optimiser from checkpoint
    iter_idx = 0
    if resume_checkpoint is not None:
        fprint(f"Restoring checkpoint from {resume_checkpoint}")
        checkpoint = torch.load(resume_checkpoint, map_location='cpu')
        # Restore model & optimiser
        model_state_dict = checkpoint['model_state_dict']
        model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None)
        model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None)
        model.load_state_dict(model_state_dict)
        optimiser.load_state_dict(checkpoint['optimiser_state_dict'])
        # Restore GECO
        if config.geco and 'beta' in checkpoint:
            geco.beta = checkpoint['beta']
        if config.geco and 'err_ema' in checkpoint:
            geco.err_ema = checkpoint['err_ema']
        # Update starting iter
        iter_idx = checkpoint['iter_idx'] + 1
    fprint(f"Starting training at iter = {iter_idx}")

    # Push model to GPU(s)?
    if config.multi_gpu:
        fprint("Wrapping model in DataParallel.")
        model = nn.DataParallel(model)
    if config.gpu:
        fprint("Pushing model to GPU.")
        model = model.cuda()
        if config.geco:
            geco.to_cuda()

    # ------------------------
    # TRAINING
    # ------------------------

    model.train()
    timer = time.time()
    while iter_idx <= config.train_iter:
        for train_batch in train_loader:
            # Parse data
            train_input = train_batch['input']
            if config.gpu:
                train_input = train_input.cuda()

            # Forward propagation
            optimiser.zero_grad()
            output, losses, stats, att_stats, comp_stats = model(train_input)

            # Reconstruction error
            err = losses.err.mean(0)
            # KL divergences
            kl_m, kl_l = torch.tensor(0), torch.tensor(0)
            # -- KL stage 1
            if 'kl_m' in losses:
                kl_m = losses.kl_m.mean(0)
            elif 'kl_m_k' in losses:
                kl_m = torch.stack(losses.kl_m_k, dim=1).mean(dim=0).sum()
            # -- KL stage 2
            if 'kl_l' in losses:
                kl_l = losses.kl_l.mean(0)
            elif 'kl_l_k' in losses:
                kl_l = torch.stack(losses.kl_l_k, dim=1).mean(dim=0).sum()

            # Compute ELBO
            elbo = (err + kl_l + kl_m).detach()
            err_new = err.detach()
            kl_new = (kl_m + kl_l).detach()
            # Compute MSE / RMSE
            mse_batched = ((train_input - output)**2).mean((1, 2, 3)).detach()
            rmse_batched = mse_batched.sqrt()
            mse, rmse = mse_batched.mean(0), rmse_batched.mean(0)

            # Main objective
            if config.geco:
                loss = geco.loss(err, kl_l + kl_m)
                beta = geco.beta
            else:
                if config.beta_warmup:
                    # Increase beta linearly over 20% of training
                    beta = config.beta * iter_idx / (0.2 * config.train_iter)
                    beta = torch.tensor(beta).clamp(0, config.beta)
                else:
                    beta = config.beta
                loss = err + beta * (kl_l + kl_m)

            # Backprop and optimise
            loss.backward()
            optimiser.step()

            # Heartbeat log
            if (iter_idx % config.report_loss_every == 0 or
                    float(elbo) > ELBO_DIV or config.debug):
                # Print output and write to file
                ps = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                ps += f' {config.run_name} | '
                ps += f'[{iter_idx}/{config.train_iter:.0e}]'
                ps += f' elb: {float(elbo):.0f} err: {float(err):.0f} '
                if 'kl_m' in losses or 'kl_m_k' in losses:
                    ps += f' klm: {float(kl_m):.1f}'
                if 'kl_l' in losses or 'kl_l_k' in losses:
                    ps += f' kll: {float(kl_l):.1f}'
                ps += f' bet: {float(beta):.1e}'
                s_per_b = (time.time() - timer)
                if not config.debug:
                    s_per_b /= config.report_loss_every
                timer = time.time()  # Reset timer
                ps += f' - {s_per_b:.2f} s/b'
                fprint(ps)

                # TensorBoard logging
                # -- Optimisation stats
                writer.add_scalar('optim/beta', beta, iter_idx)
                writer.add_scalar('optim/s_per_batch', s_per_b, iter_idx)
                if config.geco:
                    writer.add_scalar('optim/geco_err_ema', geco.err_ema, iter_idx)
                    writer.add_scalar('optim/geco_err_ema_element',
                                      geco.err_ema / num_elements, iter_idx)
                # -- Main loss terms
                writer.add_scalar('train/err', err, iter_idx)
                writer.add_scalar('train/err_element', err / num_elements, iter_idx)
                writer.add_scalar('train/kl_m', kl_m, iter_idx)
                writer.add_scalar('train/kl_l', kl_l, iter_idx)
                writer.add_scalar('train/elbo', elbo, iter_idx)
                writer.add_scalar('train/loss', loss, iter_idx)
                writer.add_scalar('train/mse', mse, iter_idx)
                writer.add_scalar('train/rmse', rmse, iter_idx)
                # -- Per step loss terms
                for key in ['kl_l_k', 'kl_m_k']:
                    if key not in losses:
                        continue
                    for step, val in enumerate(losses[key]):
                        writer.add_scalar(f'train_steps/{key}{step}',
                                          val.mean(0), iter_idx)
                # -- Attention stats
                if config.log_distributions and att_stats is not None:
                    for key in ['mu_k', 'sigma_k', 'pmu_k', 'psigma_k']:
                        if key not in att_stats:
                            continue
                        for step, val in enumerate(att_stats[key]):
                            writer.add_histogram(f'att_{key}_{step}', val, iter_idx)
                # -- Component stats
                if config.log_distributions and comp_stats is not None:
                    for key in ['mu_k', 'sigma_k', 'pmu_k', 'psigma_k']:
                        if key not in comp_stats:
                            continue
                        for step, val in enumerate(comp_stats[key]):
                            writer.add_histogram(f'comp_{key}_{step}', val, iter_idx)

            # Save checkpoints
            ckpt_freq = config.train_iter / config.num_checkpoints
            if iter_idx % ckpt_freq == 0:
                ckpt_file = '{}-{}'.format(checkpoint_name, iter_idx)
                fprint(f"Saving model training checkpoint to: {ckpt_file}")
                if config.multi_gpu:
                    model_state_dict = model.module.state_dict()
                else:
                    model_state_dict = model.state_dict()
                ckpt_dict = {'iter_idx': iter_idx,
                             'model_state_dict': model_state_dict,
                             'optimiser_state_dict': optimiser.state_dict(),
                             'elbo': elbo}
                if config.geco:
                    ckpt_dict['beta'] = geco.beta
                    ckpt_dict['err_ema'] = geco.err_ema
                torch.save(ckpt_dict, ckpt_file)

            # Run validation and log images
            if (iter_idx % config.run_validation_every == 0 or
                    float(elbo) > ELBO_DIV):
                # Weight and gradient histograms
                if config.log_grads_and_weights:
                    for name, param in model.named_parameters():
                        writer.add_histogram(f'weights/{name}', param.data, iter_idx)
                        writer.add_histogram(f'grads/{name}', param.grad, iter_idx)
                # TensorboardX logging - images
                visualise_inference(model, train_batch, writer, 'train', iter_idx)
                # Validation
                fprint("Running validation...")
                eval_model = model.module if config.multi_gpu else model
                evaluation(eval_model, val_loader, writer, config, iter_idx,
                           N_eval=config.N_eval)

            # Increment counter
            iter_idx += 1
            if iter_idx > config.train_iter:
                break

            # Exit if training has diverged
            if elbo.item() > ELBO_DIV:
                fprint(f"ELBO: {elbo.item()}")
                fprint(f"ELBO has exceeded {ELBO_DIV} - training has diverged.")
                sys.exit()

    # ------------------------
    # TESTING
    # ------------------------

    # Save final checkpoint
    ckpt_file = '{}-{}'.format(checkpoint_name, 'FINAL')
    fprint(f"Saving model training checkpoint to: {ckpt_file}")
    if config.multi_gpu:
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()
    ckpt_dict = {'iter_idx': iter_idx,
                 'model_state_dict': model_state_dict,
                 'optimiser_state_dict': optimiser.state_dict()}
    if config.geco:
        ckpt_dict['beta'] = geco.beta
        ckpt_dict['err_ema'] = geco.err_ema
    torch.save(ckpt_dict, ckpt_file)

    # Test evaluation
    fprint("STARTING TESTING...")
    eval_model = model.module if config.gpu and config.multi_gpu else model
    final_elbo = evaluation(eval_model, test_loader, None, config, iter_idx,
                            N_eval=config.N_eval)
    fprint(f"TEST ELBO = {float(final_elbo)}")

    # FID computation
    try:
        fid_from_model(model, test_loader, img_dir=osp.join('/tmp', logdir))
    except NotImplementedError:
        fprint("Sampling not implemented for this model.")

    # Close writer
    writer.close()
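
# For reference, a minimal sketch of the kind of constrained-optimisation update the
# `GECO` helper used above is assumed to perform. The class name, constructor arguments
# and the exact EMA/clamping details here are illustrative only (the actual GECO class
# above also takes a speed-up factor, among other things): the idea is to hold the
# reconstruction error near a target `goal` by adapting the KL weight `beta`
# multiplicatively from an EMA of the constraint (goal - err).
import torch  # redundant if torch is already imported at the top of this module


class GECOSketch:
    """Illustrative GECO-style KL-weight controller (simplified assumption)."""

    def __init__(self, goal, step_size, alpha=0.99, beta_init=1.0,
                 beta_min=1e-10, beta_max=1e10):
        self.goal = goal                     # target reconstruction error
        self.step_size = step_size           # multiplier learning rate
        self.alpha = alpha                   # EMA decay for the error
        self.beta = torch.tensor(beta_init)  # current KL weight
        self.beta_min, self.beta_max = beta_min, beta_max
        self.err_ema = None                  # EMA of the reconstruction error

    def loss(self, err, kl):
        # Objective evaluated with the current beta, as in `geco.loss(err, kl_l + kl_m)` above
        loss = err + self.beta * kl
        # Update beta without tracking gradients
        with torch.no_grad():
            err_det = err.detach()
            if self.err_ema is None:
                self.err_ema = err_det
            else:
                self.err_ema = (1.0 - self.alpha) * err_det + self.alpha * self.err_ema
            constraint = self.goal - self.err_ema
            # err below goal -> constraint > 0 -> beta grows, weighting the KL more;
            # err above goal -> beta shrinks, prioritising reconstruction.
            self.beta = (self.beta * torch.exp(self.step_size * constraint)).clamp(
                self.beta_min, self.beta_max)
        return loss
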
def main():
    # Parse flags
    config = forge.config()

    # Set device
    if torch.cuda.is_available():
        device = f"cuda:{config.device}"
        torch.cuda.set_device(device)
    else:
        device = "cpu"

    # Load data
    dataloaders, data_name = fet.load(config.data_config, config=config)
    train_loader = dataloaders["train"]
    test_loader = dataloaders["test"]
    val_loader = dataloaders["val"]

    # Load model
    model, model_name = fet.load(config.model_config, config)
    model = model.to(device)
    print(model)

    # Prepare environment
    params_in_run_name = [
        ("batch_size", "bs"),
        ("learning_rate", "lr"),
        ("num_heads", "nheads"),
        ("num_layers", "nlayers"),
        ("dim_hidden", "hdim"),
        ("kernel_dim", "kdim"),
        ("location_attention", "locatt"),
        ("model_seed", "mseed"),
        ("lr_schedule", "lrsched"),
        ("layer_norm", "ln"),
        ("batch_norm", "bn"),
        ("channel_width", "width"),
        ("attention_fn", "attfn"),
        ("output_mlp_scale", "mlpscale"),
        ("train_epochs", "epochs"),
        ("block_norm", "block"),
        ("kernel_type", "ktype"),
        ("architecture", "arch"),
        ("activation_function", "act"),
        ("space_dim", "spacedim"),
        ("num_particles", "prtcls"),
        ("n_train", "ntrain"),
        ("group", "group"),
        ("lift_samples", "ls"),
    ]

    run_name = ""
    for config_param in params_in_run_name:
        attr = config_param[0]
        abbrev = config_param[1]
        if hasattr(config, attr):
            run_name += abbrev
            run_name += str(getattr(config, attr))
            run_name += "_"

    if config.clip_grad_norm:
        run_name += "clipnorm" + str(config.max_grad_norm) + "_"

    results_folder_name = osp.join(
        data_name,
        model_name,
        config.run_name,
        run_name,
    )

    logdir = osp.join(config.results_dir, results_folder_name.replace(".", "_"))
    logdir, resume_checkpoint = fet.init_checkpoint(
        logdir, config.data_config, config.model_config, config.resume)
    checkpoint_name = osp.join(logdir, "model.ckpt")

    # Print flags
    fet.print_flags()

    # Setup optimizer
    model_params = model.predictor.parameters()
    opt_learning_rate = config.learning_rate
    model_opt = torch.optim.Adam(model_params,
                                 lr=opt_learning_rate,
                                 betas=(config.beta1, config.beta2))

    if config.lr_schedule == "cosine_annealing":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            model_opt, config.train_epochs)
    elif config.lr_schedule == "cosine_annealing_warmup":
        num_warmup_epochs = int(0.05 * config.train_epochs)
        num_warmup_steps = len(train_loader) * num_warmup_epochs
        num_training_steps = len(train_loader) * config.train_epochs
        scheduler = transformers.get_cosine_schedule_with_warmup(
            model_opt,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    # Try to restore model and optimizer from checkpoint
    if resume_checkpoint is not None:
        start_epoch = load_checkpoint(resume_checkpoint, model, model_opt)
    else:
        start_epoch = 1

    train_iter = (start_epoch - 1) * (
        len(train_loader.dataset) // config.batch_size) + 1
    print("Starting training at epoch = {}, iter = {}".format(start_epoch, train_iter))

    # Setup tensorboard writing
    summary_writer = SummaryWriter(logdir)

    train_reports = []
    report_all = {}
    report_all_val = {}

    # Saving model at epoch 0 before training
    print("saving model at epoch 0 before training ... ")
    save_checkpoint(checkpoint_name, 0, model, model_opt, loss=0.0)
    print("finished saving model at epoch 0 before training")

    # if (
    #     config.debug
    #     and config.model_config == "configs/dynamics/eqv_transformer_model.py"
    # ):
    #     model_components = (
    #         [(0, [], "embedding_layer")]
    #         + list(
    #             chain.from_iterable(
    #                 (
    #                     (k, [], f"ema_{k}"),
    #                     (
    #                         k,
    #                         ["ema", "kernel", "location_kernel"],
    #                         f"ema_{k}_location_kernel",
    #                     ),
    #                     (
    #                         k,
    #                         ["ema", "kernel", "feature_kernel"],
    #                         f"ema_{k}_feature_kernel",
    #                     ),
    #                 )
    #                 for k in range(1, config.num_layers + 1)
    #             )
    #         )
    #         + [(config.num_layers + 2, [], "output_mlp")]
    #     )  # components to track for debugging

    num_params = param_count(model)
    print(f"Number of model parameters: {num_params}")

    # Training
    start_t = time.time()

    total_train_iters = len(train_loader) * config.train_epochs
    iters_per_eval = int(total_train_iters / 100)  # evaluate 100 times over the course of training

    assert (
        config.n_train % min(config.batch_size, config.n_train) == 0
    ), "Batch size doesn't divide dataset size. Can be problematic for loss computation (see below)."

    training_failed = False
    best_val_loss_so_far = 1e+7
    for epoch in tqdm(range(start_epoch, config.train_epochs + 1)):
        model.train()

        for batch_idx, data in enumerate(train_loader):
            data = nested_to(data, device, torch.float32)
            # the format is ((z0, sys_params, ts), true_zs) for data
            outputs = model(data)

            if torch.isnan(outputs.loss):
                if not training_failed:
                    epoch_of_nan = epoch
                if (epoch > epoch_of_nan + 1) and training_failed:
                    raise ValueError("Loss Nan-ed.")
                training_failed = True

            model_opt.zero_grad()
            outputs.loss.backward(retain_graph=False)
            if config.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config.max_grad_norm,
                                               norm_type=1)
            model_opt.step()

            train_reports.append(parse_reports_cpu(outputs.reports))

            if config.log_train_values:
                reports = parse_reports(outputs.reports)
                if batch_idx % config.report_loss_every == 0:
                    log_tensorboard(summary_writer, train_iter, reports, "train/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="train",
                    )
                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {"lr": model_opt.param_groups[0]["lr"]},
                        "hyperparams/",
                    )

            # Do learning rate schedule steps per STEP for cosine_annealing_warmup
            if config.lr_schedule == "cosine_annealing_warmup":
                scheduler.step()

            # Track stuff for debugging
            if config.debug:
                model.eval()
                with torch.no_grad():
                    curr_params = {
                        k: v.detach().clone()
                        for k, v in model.state_dict().items()
                    }

                    # updates norm
                    if train_iter == 1:
                        update_norm = 0
                    else:
                        update_norms = []
                        for (k_prev, v_prev), (k_curr, v_curr) in zip(
                                prev_params.items(), curr_params.items()):
                            assert k_prev == k_curr
                            if "tracked" not in k_prev:
                                update_norms.append((v_curr - v_prev).norm(1).item())
                        update_norm = sum(update_norms)

                    # gradient norm
                    grad_norm = 0
                    for p in model.parameters():
                        try:
                            grad_norm += p.grad.norm(1)
                        except AttributeError:
                            pass

                    # Average V size:
                    (z0, sys_params, ts), true_zs = data
                    z = z0
                    m = sys_params[..., 0]  # assume the first component encodes masses
                    D = z.shape[-1]  # of ODE dims, 2*num_particles*space_dim
                    q = z[:, :D // 2].reshape(*m.shape, -1)
                    k = sys_params[..., 1]
                    V = model.predictor.compute_V((q, sys_params))
                    V_true = SpringV(q, k)

                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {
                            "avg_update_norm1": update_norm / num_params,
                            "avg_grad_norm1": grad_norm / num_params,
                            "avg_predicted_potential_norm": V.norm(1) / V.numel(),
                            "avg_true_potential_norm": V_true.norm(1) / V_true.numel(),
                        },
                        "debug/",
                    )
                    prev_params = curr_params

                model.train()

            # Logging
            if train_iter % iters_per_eval == 0 or (
                    train_iter == total_train_iters):  # batch_idx % config.evaluate_every == 0:
                model.eval()
                with torch.no_grad():
                    reports = None
                    for data in test_loader:
                        data = nested_to(data, device, torch.float32)
                        outputs = model(data)

                        if reports is None:
                            reports = {
                                k: v.detach().clone().cpu()
                                for k, v in outputs.reports.items()
                            }
                        else:
                            for k, v in outputs.reports.items():
                                reports[k] += v.detach().clone().cpu()

                    for k, v in reports.items():
                        reports[k] = v / len(test_loader)

                    reports = parse_reports(reports)
                    reports["time"] = time.time() - start_t
                    if report_all == {}:
                        report_all = deepcopy(reports)
                        for d in reports.keys():
                            report_all[d] = [report_all[d]]
                    else:
                        for d in reports.keys():
                            report_all[d].append(reports[d])

                    log_tensorboard(summary_writer, train_iter, reports, "test/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="test",
                    )

                    if config.kill_if_poor:
                        if epoch > config.train_epochs * 0.2:
                            if reports["mse"] > 0.01:
                                raise RuntimeError("Killed run due to poor performance.")

                    # repeat for validation data
                    reports = None
                    for data in val_loader:
                        data = nested_to(data, device, torch.float32)
                        outputs = model(data)

                        if reports is None:
                            reports = {
                                k: v.detach().clone().cpu()
                                for k, v in outputs.reports.items()
                            }
                        else:
                            for k, v in outputs.reports.items():
                                reports[k] += v.detach().clone().cpu()

                    for k, v in reports.items():
                        reports[k] = v / len(val_loader)

                    reports = parse_reports(reports)
                    reports["time"] = time.time() - start_t
                    if report_all_val == {}:
                        report_all_val = deepcopy(reports)
                        for d in reports.keys():
                            report_all_val[d] = [report_all_val[d]]
                    else:
                        for d in reports.keys():
                            report_all_val[d].append(reports[d])

                    log_tensorboard(summary_writer, train_iter, reports, "val/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="val",
                    )

                    if report_all_val['mse'][-1] < best_val_loss_so_far:
                        save_checkpoint(checkpoint_name, "early_stop", model,
                                        model_opt, loss=outputs.loss)
                        best_val_loss_so_far = report_all_val['mse'][-1]

                model.train()

            train_iter += 1

        if config.lr_schedule == "cosine_annealing":
            scheduler.step()

        if epoch % config.save_check_points == 0:
            save_checkpoint(checkpoint_name, train_iter, model, model_opt,
                            loss=outputs.loss)

        dd.io.save(logdir + "/results_dict_train.h5", train_reports)
        dd.io.save(logdir + "/results_dict.h5", report_all)
        dd.io.save(logdir + "/results_dict_val.h5", report_all_val)

    # always save final model
    save_checkpoint(checkpoint_name, train_iter, model, model_opt, loss=outputs.loss)

    if config.save_test_predictions:
        print("Starting to make model predictions on test sets for *final model*.")
        for chunk_len in [5, 100]:
            start_t_preds = time.time()
            data_config = SimpleNamespace(**{
                **config.__dict__["__flags"],
                **{"chunk_len": chunk_len, "batch_size": 500},
            })
            dataloaders, data_name = fet.load(config.data_config, config=data_config)
            test_loader_preds = dataloaders["test"]

            torch.cuda.empty_cache()
            with torch.no_grad():
                preds = []
                true = []
                num_datapoints = 0
                for idx, d in enumerate(test_loader_preds):
                    true.append(d[-1])
                    d = nested_to(d, device, torch.float32)
                    outputs = model(d)

                    pred_zs = outputs.prediction
                    preds.append(pred_zs)

                    num_datapoints += len(pred_zs)
                    if num_datapoints >= 2000:
                        break

                preds = torch.cat(preds, dim=0).cpu()
                true = torch.cat(true, dim=0).cpu()

                save_dir = osp.join(logdir, f"traj_preds_{chunk_len}_steps_2k_test.pt")
                torch.save(preds, save_dir)
                save_dir = osp.join(logdir, f"traj_true_{chunk_len}_steps_2k_test.pt")
                torch.save(true, save_dir)

            print(f"Completed making test predictions for chunk_len = {chunk_len} in {time.time() - start_t_preds:.2f} seconds.")
def main():
    # Parse flags
    config = forge.config()

    # Load data
    dataloaders, num_species, charge_scale, ds_stats, data_name = fet.load(
        config.data_config, config=config)

    config.num_species = num_species
    config.charge_scale = charge_scale
    config.ds_stats = ds_stats

    # Load model
    model, model_name = fet.load(config.model_config, config)
    model.to(device)

    config.charge_scale = float(config.charge_scale.numpy())
    config.ds_stats = [float(stat.numpy()) for stat in config.ds_stats]

    # Prepare environment
    run_name = (config.run_name + "_bs" + str(config.batch_size) + "_lr" +
                str(config.learning_rate))

    if config.batch_fit != 0:
        run_name += "_bf" + str(config.batch_fit)

    if config.lr_schedule != "none":
        run_name += "_" + config.lr_schedule

    # Print flags
    fet.print_flags()

    # Setup optimizer
    model_params = model.predictor.parameters()
    opt_learning_rate = config.learning_rate
    model_opt = torch.optim.Adam(
        model_params,
        lr=opt_learning_rate,
        betas=(config.beta1, config.beta2),
        eps=1e-8,
    )
    # model_opt = torch.optim.SGD(model_params, lr=opt_learning_rate)

    # Cosine annealing learning rate
    if config.lr_schedule == "cosine":
        cos = cosLr(config.train_epochs)
        lr_sched = lambda e: max(cos(e), config.lr_floor * config.learning_rate)
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    elif config.lr_schedule == "cosine_warmup":
        cos = cosLr(config.train_epochs)
        lr_sched = lambda e: max(
            min(e / (config.warmup_length * config.train_epochs), 1) * cos(e),
            config.lr_floor * config.learning_rate,
        )
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    elif config.lr_schedule == "quadratic_warmup":
        lr_sched = lambda e: min(e / (0.01 * config.train_epochs), 1) * (
            1.0 / sqrt(1.0 + 10000.0 * (e / config.train_epochs))  # finish at 1/100 of initial lr
        )
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    elif config.lr_schedule == "none":
        lr_sched = lambda e: 1.0
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    else:
        raise ValueError(
            f"{config.lr_schedule} is not a recognised learning rate schedule")

    num_params = param_count(model)

    if config.parameter_count:
        for (name, parameter) in model.predictor.named_parameters():
            print(name, parameter.dtype)
        print(model)
        print("============================================================")
        print(f"{model_name} parameters: {num_params:.5e}")
        print("============================================================")

        # from torchsummary import summary
        # data = next(iter(dataloaders["train"]))
        # data = {k: v.to(device) for k, v in data.items()}
        # print(
        #     summary(
        #         model.predictor,
        #         data,
        #         batch_size=config.batch_size,
        #     )
        # )

        parameters = sum(parameter.numel()
                         for parameter in model.predictor.parameters())
        parameters_grad = sum(parameter.numel() if parameter.requires_grad else 0
                              for parameter in model.predictor.parameters())
        print(f"Parameters: {parameters:,}")
        print(f"Parameters grad: {parameters_grad:,}")

        memory_allocations = []
        for batch_idx, data in enumerate(dataloaders["train"]):
            print(batch_idx)
            data = {k: v.to(device) for k, v in data.items()}
            model_opt.zero_grad()
            outputs = model(data, compute_loss=True)
            # torch.cuda.empty_cache()
            # memory_allocations.append(torch.cuda.memory_reserved() / 1024 / 1024 / 1024)
            # outputs.loss.backward()

        print(f"max memory reserved in one pass: {max(memory_allocations):0.4}GB")

        sys.exit(0)
    else:
        print(f"{model_name} parameters: {num_params:.5e}")

    # set up results folders
    results_folder_name = osp.join(
        data_name,
        model_name,
        run_name,
    )

    logdir = osp.join(config.results_dir, results_folder_name.replace(".", "_"))
    logdir, resume_checkpoint = fet.init_checkpoint(
        logdir, config.data_config, config.model_config, config.resume)
    checkpoint_name = osp.join(logdir, "model.ckpt")

    # Try to restore model and optimizer from checkpoint
    if resume_checkpoint is not None:
        start_epoch, best_valid_mae = load_checkpoint(resume_checkpoint, model,
                                                      model_opt, lr_schedule)
    else:
        start_epoch = 1
        best_valid_mae = 1e12

    train_iter = (start_epoch - 1) * (
        len(dataloaders["train"].dataset) // config.batch_size) + 1
    print("Starting training at epoch = {}, iter = {}".format(start_epoch, train_iter))

    # Setup tensorboard writing
    summary_writer = SummaryWriter(logdir)

    report_all = defaultdict(list)

    # Saving model at epoch 0 before training
    print("saving model at epoch 0 before training ... ")
    save_checkpoint(checkpoint_name, 0, model, model_opt, lr_schedule, 0.0)
    print("finished saving model at epoch 0 before training")

    if (config.debug and
            config.model_config == "configs/dynamics/eqv_transformer_model.py"):
        model_components = (
            [(0, [], "embedding_layer")]
            + list(
                chain.from_iterable((
                    (k, [], f"ema_{k}"),
                    (
                        k,
                        ["ema", "kernel", "location_kernel"],
                        f"ema_{k}_location_kernel",
                    ),
                    (
                        k,
                        ["ema", "kernel", "feature_kernel"],
                        f"ema_{k}_feature_kernel",
                    ),
                ) for k in range(1, config.num_layers + 1))
            )
            + [(config.num_layers + 2, [], "output_mlp")]
        )  # components to track for debugging

    grad_flows = []

    if config.init_activations:
        activation_tracked = [
            (name, module) for name, module in model.named_modules()
            if isinstance(module, Expression) | isinstance(module, nn.Linear)
            | isinstance(module, MultiheadLinear) | isinstance(module, MaskBatchNormNd)
        ]

        activations = {}

        def save_activation(name, mod, inpt, otpt):
            if isinstance(inpt, tuple):
                if isinstance(inpt[0], list) | isinstance(inpt[0], tuple):
                    activations[name + "_inpt"] = inpt[0][1].detach().cpu()
                else:
                    if len(inpt) == 1:
                        activations[name + "_inpt"] = inpt[0].detach().cpu()
                    else:
                        activations[name + "_inpt"] = inpt[1].detach().cpu()
            else:
                activations[name + "_inpt"] = inpt.detach().cpu()

            if isinstance(otpt, tuple):
                if isinstance(otpt[0], list):
                    activations[name + "_otpt"] = otpt[0][1].detach().cpu()
                else:
                    if len(otpt) == 1:
                        activations[name + "_otpt"] = otpt[0].detach().cpu()
                    else:
                        activations[name + "_otpt"] = otpt[1].detach().cpu()
            else:
                activations[name + "_otpt"] = otpt.detach().cpu()

        for name, tracked_module in activation_tracked:
            tracked_module.register_forward_hook(partial(save_activation, name))

    # Training
    start_t = time.perf_counter()

    iters_per_epoch = len(dataloaders["train"])

    last_valid_loss = 1000.0
    for epoch in tqdm(range(start_epoch, config.train_epochs + 1)):
        model.train()

        for batch_idx, data in enumerate(dataloaders["train"]):
            data = {k: v.to(device) for k, v in data.items()}

            model_opt.zero_grad()
            outputs = model(data, compute_loss=True)
            outputs.loss.backward()
            if config.clip_grad:
                # Clip gradient L2-norm at 1
                torch.nn.utils.clip_grad.clip_grad_norm_(
                    model.predictor.parameters(), 1.0)
            model_opt.step()

            if config.init_activations:
                model_opt.zero_grad()
                outputs = model(data, compute_loss=True)
                outputs.loss.backward()
                for name, activation in activations.items():
                    print(name)
                    summary_writer.add_histogram(f"activations/{name}",
                                                 activation.numpy(), 0)
                sys.exit(0)

            if config.log_train_values:
                reports = parse_reports(outputs.reports)
                if batch_idx % config.report_loss_every == 0:
                    log_tensorboard(summary_writer, train_iter, reports, "train/")
                    report_all = log_reports(report_all, train_iter, reports, "train")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(dataloaders["train"].dataset) // config.batch_size,
                        prefix="train",
                    )

            # Logging
            if batch_idx % config.evaluate_every == 0:
                model.eval()
                with torch.no_grad():
                    valid_mae = 0.0
                    for data in dataloaders["valid"]:
                        data = {k: v.to(device) for k, v in data.items()}
                        outputs = model(data, compute_loss=True)
                        valid_mae = valid_mae + outputs.mae
                model.train()

                outputs["reports"].valid_mae = valid_mae / len(dataloaders["valid"])
                reports = parse_reports(outputs.reports)

                log_tensorboard(summary_writer, train_iter, reports, "valid")
                report_all = log_reports(report_all, train_iter, reports, "valid")
                print_reports(
                    reports,
                    start_t,
                    epoch,
                    batch_idx,
                    len(dataloaders["train"].dataset) // config.batch_size,
                    prefix="valid",
                )

                loss_diff = (last_valid_loss -
                             (valid_mae / len(dataloaders["valid"])).item())
                # Save a "spike" checkpoint if validation MAE jumped up by more than 0.1
                # and spike tracking is enabled.
                if loss_diff < -0.1 and config.find_spikes:
                    save_checkpoint(
                        checkpoint_name + "_spike",
                        epoch,
                        model,
                        model_opt,
                        lr_schedule,
                        outputs.loss,
                    )

                last_valid_loss = (valid_mae / len(dataloaders["valid"])).item()

                if outputs["reports"].valid_mae < best_valid_mae:
                    save_checkpoint(
                        checkpoint_name,
                        "best_valid_mae",
                        model,
                        model_opt,
                        lr_schedule,
                        best_valid_mae,
                    )
                    best_valid_mae = outputs["reports"].valid_mae

            train_iter += 1

            # Step the LR schedule
            lr_schedule.step(train_iter / iters_per_epoch)

            # Track stuff for debugging
            if config.debug:
                model.eval()
                with torch.no_grad():
                    curr_params = {
                        k: v.detach().clone()
                        for k, v in model.state_dict().items()
                    }

                    # updates norm
                    if train_iter == 1:
                        update_norm = 0
                    else:
                        update_norms = []
                        for (k_prev, v_prev), (k_curr, v_curr) in zip(
                                prev_params.items(), curr_params.items()):
                            assert k_prev == k_curr
                            if "tracked" not in k_prev:
                                # ignore batch norm tracking. TODO: should this be ignored? if not, fix!
                                update_norms.append((v_curr - v_prev).norm(1).item())
                        update_norm = sum(update_norms)

                    # gradient norm
                    grad_norm = 0
                    for p in model.parameters():
                        try:
                            grad_norm += p.grad.norm(1)
                        except AttributeError:
                            pass

                    # weights norm
                    if (config.model_config ==
                            "configs/dynamics/eqv_transformer_model.py"):
                        model_norms = {}
                        for comp_name in model_components:
                            comp = get_component(model.predictor.net, comp_name)
                            norm = get_average_norm(comp)
                            model_norms[comp_name[2]] = norm

                        log_tensorboard(
                            summary_writer,
                            train_iter,
                            model_norms,
                            "debug/avg_model_norms/",
                        )

                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {
                            "avg_update_norm1": update_norm / num_params,
                            "avg_grad_norm1": grad_norm / num_params,
                        },
                        "debug/",
                    )
                    prev_params = curr_params

                    # gradient flow
                    ave_grads = []
                    max_grads = []
                    layers = []
                    for n, p in model.named_parameters():
                        if (p.requires_grad) and ("bias" not in n):
                            layers.append(n)
                            ave_grads.append(p.grad.abs().mean().item())
                            max_grads.append(p.grad.abs().max().item())
                    grad_flow = {
                        "layers": layers,
                        "ave_grads": ave_grads,
                        "max_grads": max_grads,
                    }
                    grad_flows.append(grad_flow)

                model.train()

        # Test model at end of batch
        with torch.no_grad():
            model.eval()
            test_mae = 0.0
            for data in dataloaders["test"]:
                data = {k: v.to(device) for k, v in data.items()}
                outputs = model(data, compute_loss=True)
                test_mae = test_mae + outputs.mae
            outputs["reports"].test_mae = test_mae / len(dataloaders["test"])

        reports = parse_reports(outputs.reports)
        log_tensorboard(summary_writer, train_iter, reports, "test")
        report_all = log_reports(report_all, train_iter, reports, "test")
        print_reports(
            reports,
            start_t,
            epoch,
            batch_idx,
            len(dataloaders["train"].dataset) // config.batch_size,
            prefix="test",
        )

        reports = {
            "lr": lr_schedule.get_lr()[0],
            "time": time.perf_counter() - start_t,
            "epoch": epoch,
        }
        log_tensorboard(summary_writer, train_iter, reports, "stats")
        report_all = log_reports(report_all, train_iter, reports, "stats")

        # Save the reports
        dd.io.save(logdir + "/results_dict.h5", report_all)

        # Save a checkpoint
        if epoch % config.save_check_points == 0:
            save_checkpoint(
                checkpoint_name,
                epoch,
                model,
                model_opt,
                lr_schedule,
                best_valid_mae,
            )
            if config.only_store_last_checkpoint:
                delete_checkpoint(checkpoint_name, epoch - config.save_check_points)

    save_checkpoint(
        checkpoint_name,
        "final",
        model,
        model_opt,
        lr_schedule,
        outputs.loss,
    )
def main():
    # Parse flags
    config = forge.config()

    # Set device
    if torch.cuda.is_available():
        device = f"cuda:{config.device}"
        torch.cuda.set_device(device)
    else:
        device = "cpu"

    # Load data
    train_loader, test_loader, data_name = fet.load(config.data_config,
                                                    config=config)

    # Load model
    model, model_name = fet.load(config.model_config, config)
    model = model.to(device)
    print(model)

    torch.manual_seed(0)

    # Prepare environment
    if "set_transformer" in config.model_config:
        params_in_run_name = [
            ("batch_size", "bs"),
            ("learning_rate", "lr"),
            ("num_heads", "nheads"),
            ("patterns_reps", "reps"),
            ("model_seed", "mseed"),
            ("train_epochs", "epochs"),
            ("naug", "naug"),
        ]
    else:
        params_in_run_name = [
            ("batch_size", "bs"),
            ("learning_rate", "lr"),
            ("num_heads", "nheads"),
            ("num_layers", "nlayers"),
            ("dim_hidden", "hdim"),
            ("kernel_dim", "kdim"),
            ("location_attention", "locatt"),
            ("model_seed", "mseed"),
            ("lr_schedule", "lrsched"),
            ("layer_norm", "ln"),
            ("batch_norm_att", "bnatt"),
            ("batch_norm", "bn"),
            ("batch_norm_final_mlp", "bnfinalmlp"),
            ("k", "k"),
            ("attention_fn", "attfn"),
            ("output_mlp_scale", "mlpscale"),
            ("train_epochs", "epochs"),
            ("block_norm", "block"),
            ("kernel_type", "ktype"),
            ("architecture", "arch"),
            ("kernel_act", "actv"),
            ("patterns_reps", "reps"),
            ("lift_samples", "nsamples"),
            ("content_type", "content"),
            ("naug", "naug"),
        ]

    run_name = ""  # config.run_name
    for config_param in params_in_run_name:
        attr = config_param[0]
        abbrev = config_param[1]
        if hasattr(config, attr):
            run_name += abbrev
            run_name += str(getattr(config, attr))
            run_name += "_"

    results_folder_name = osp.join(
        data_name,
        model_name,
        config.run_name,
        run_name,
    )
    # results_folder_name = osp.join(data_name, model_name, run_name,)

    logdir = osp.join(config.results_dir, results_folder_name.replace(".", "_"))
    logdir, resume_checkpoint = fet.init_checkpoint(
        logdir, config.data_config, config.model_config, config.resume)
    checkpoint_name = osp.join(logdir, "model.ckpt")

    # Print flags
    fet.print_flags()

    # Setup optimizer
    model_params = model.encoder.parameters()
    opt_learning_rate = config.learning_rate
    model_opt = torch.optim.Adam(model_params,
                                 lr=opt_learning_rate,
                                 betas=(config.beta1, config.beta2))

    # Try to restore model and optimizer from checkpoint
    if resume_checkpoint is not None:
        start_epoch = load_checkpoint(resume_checkpoint, model, model_opt)
    else:
        start_epoch = 1

    train_iter = (start_epoch - 1) * (
        len(train_loader.dataset) // config.batch_size) + 1
    print("Starting training at epoch = {}, iter = {}".format(start_epoch, train_iter))

    # Setup tensorboard writing
    summary_writer = SummaryWriter(logdir)

    train_reports = []
    report_all = {}

    # Saving model at epoch 0 before training
    print("saving model at epoch 0 before training ... ")
    save_checkpoint(checkpoint_name, 0, model, model_opt, loss=0.0)
    print("finished saving model at epoch 0 before training")

    if (config.debug and
            config.model_config == "configs/constellation/eqv_transformer_model.py"):
        model_components = (
            [(0, [], "embedding_layer")]
            + list(
                chain.from_iterable((
                    (k, [], f"ema_{k}"),
                    (
                        k,
                        ["ema", "kernel", "location_kernel"],
                        f"ema_{k}_location_kernel",
                    ),
                    (
                        k,
                        ["ema", "kernel", "feature_kernel"],
                        f"ema_{k}_feature_kernel",
                    ),
                ) for k in range(1, config.num_layers + 1))
            )
            + [(config.num_layers + 2, [], "output_mlp")]
        )  # components to track for debugging

    num_params = param_count(model)
    print(f"Number of model parameters: {num_params}")

    # Training
    start_t = time.time()

    grad_flows = []
    training_failed = False
    for epoch in tqdm(range(start_epoch, config.train_epochs + 1)):
        model.train()

        for batch_idx, data in enumerate(train_loader):
            data, presence, target = [d.to(device).float() for d in data]

            if config.train_aug_t2:
                unifs = np.random.normal(1)  # torch.rand(1) * 2 * math.pi # torch.randn(1).to(device) * 10
                data += unifs

            if config.train_aug_se2:
                # angle = torch.rand(1) * 2 * math.pi
                angle = np.random.random(size=1) * 2 * math.pi
                unifs = np.random.normal(1)  # torch.randn(1).to(device) * 10
                data = rotate(data, angle.item())
                data += unifs

            outputs = model([data, presence], target)

            if torch.isnan(outputs.loss):
                if not training_failed:
                    epoch_of_nan = epoch
                if (epoch > epoch_of_nan + 1) and training_failed:
                    raise ValueError("Loss Nan-ed.")
                training_failed = True

            model_opt.zero_grad()
            outputs.loss.backward(retain_graph=False)
            model_opt.step()

            train_reports.append(parse_reports_cpu(outputs.reports))

            if config.log_train_values:
                reports = parse_reports(outputs.reports)
                if batch_idx % config.report_loss_every == 0:
                    log_tensorboard(summary_writer, train_iter, reports, "train/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="train",
                    )

            # Track stuff for debugging
            if config.debug:
                model.eval()
                with torch.no_grad():
                    curr_params = {
                        k: v.detach().clone()
                        for k, v in model.state_dict().items()
                    }

                    # updates norm
                    if train_iter == 1:
                        update_norm = 0
                    else:
                        update_norms = []
                        for (k_prev, v_prev), (k_curr, v_curr) in zip(
                                prev_params.items(), curr_params.items()):
                            assert k_prev == k_curr
                            if "tracked" not in k_prev:
                                # ignore batch norm tracking. TODO: should this be ignored? if not, fix!
                                update_norms.append((v_curr - v_prev).norm(1).item())
                        update_norm = sum(update_norms)

                    # gradient norm
                    grad_norm = 0
                    for p in model.parameters():
                        try:
                            grad_norm += p.grad.norm(1)
                        except AttributeError:
                            pass

                    # # weights norm
                    # if config.model_config == 'configs/constellation/eqv_transformer_model.py':
                    #     model_norms = {}
                    #     for comp_name in model_components:
                    #         comp = get_component(model.encoder.net, comp_name)
                    #         norm = get_average_norm(comp)
                    #         model_norms[comp_name[2]] = norm

                    #     log_tensorboard(
                    #         summary_writer,
                    #         train_iter,
                    #         model_norms,
                    #         "debug/avg_model_norms/",
                    #     )

                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {
                            "avg_update_norm1": update_norm / num_params,
                            "avg_grad_norm1": grad_norm / num_params,
                        },
                        "debug/",
                    )
                    prev_params = curr_params

                    # # gradient flow
                    # ave_grads = []
                    # max_grads = []
                    # layers = []
                    # for n, p in model.named_parameters():
                    #     if (p.requires_grad) and (p.grad is not None):  # and ("bias" not in n):
                    #         layers.append(n)
                    #         ave_grads.append(p.grad.abs().mean().item())
                    #         max_grads.append(p.grad.abs().max().item())
                    # grad_flow = {"layers": layers, "ave_grads": ave_grads, "max_grads": max_grads}
                    # grad_flows.append(grad_flow)

                model.train()

            # Logging
            if batch_idx % config.evaluate_every == 0:
                model.eval()
                with torch.no_grad():
                    reports = None
                    for data in test_loader:
                        data, presence, target = [d.to(device).float() for d in data]

                        # if config.data_config == "configs/constellation/constellation.py":
                        #     if config.global_rotation_angle != 0.0:
                        #         data = rotate(data, config.global_rotation_angle)

                        outputs = model([data, presence], target)

                        if reports is None:
                            reports = {
                                k: v.detach().clone().cpu()
                                for k, v in outputs.reports.items()
                            }
                        else:
                            for k, v in outputs.reports.items():
                                reports[k] += v.detach().clone().cpu()

                    for k, v in reports.items():
                        reports[k] = v / len(test_loader)
                        # XXX: note this is slightly incorrect since mini-batch sizes can vary
                        # (if batch_size doesn't divide train_size), but approximately correct.

                    reports = parse_reports(reports)
                    reports["time"] = time.time() - start_t
                    if report_all == {}:
                        report_all = deepcopy(reports)
                        for d in reports.keys():
                            report_all[d] = [report_all[d]]
                    else:
                        for d in reports.keys():
                            report_all[d].append(reports[d])

                    log_tensorboard(summary_writer, train_iter, reports, "test/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="test",
                    )

                model.train()

            train_iter += 1

        if epoch % config.save_check_points == 0:
            save_checkpoint(checkpoint_name, train_iter, model, model_opt,
                            loss=outputs.loss)

        dd.io.save(logdir + "/results_dict_train.h5", train_reports)
        dd.io.save(logdir + "/results_dict.h5", report_all)

    # if config.debug:
    #     dd.io.save(logdir + "/grad_flows.h5", grad_flows)

    save_checkpoint(checkpoint_name, train_iter, model, model_opt, loss=outputs.loss)