def main():
    # =========================================================================
    # PARSE EXPERIMENT SETTINGS, SETUP SNAPSHOTS DIRECTORY, LOGGING
    # =========================================================================
    args, kwargs = parse_args()

    # =========================================================================
    # LOAD DATA
    # =========================================================================
    logger.info('LOADING DATA:')
    train_loader, val_loader, test_loader, args = load_image_dataset(args, **kwargs)
    args.z_size = args.input_size

    # =========================================================================
    # SAVE EXPERIMENT SETTINGS
    # =========================================================================
    logger.info(f'EXPERIMENT SETTINGS:\n{args}\n')
    torch.save(args, os.path.join(args.snap_dir, 'config.pt'))

    # =========================================================================
    # INITIALIZE MODEL AND OPTIMIZATION
    # =========================================================================
    model = init_model(args)
    optimizer, scheduler = init_optimizer(model, args)
    num_params = sum([param.nelement() for param in model.parameters()])
    logger.info(f"MODEL:\nNumber of model parameters={num_params}\n{model}\n")

    # =========================================================================
    # TRAINING
    # =========================================================================
    logger.info('TRAINING:')
    train(model, train_loader, val_loader, optimizer, scheduler, args)

    # =========================================================================
    # VALIDATION
    # =========================================================================
    logger.info('VALIDATION:')
    val_loss = evaluate(model, val_loader, args)

    # =========================================================================
    # TESTING
    # =========================================================================
    if args.testing:
        logger.info("TESTING:")
        test_loss = evaluate(model, test_loader, args)
def train(model, data_loaders, optimizer, scheduler, args):
    writer = SummaryWriter(args.snap_dir) if args.tensorboard else None

    header_msg = f'| Epoch | {"TRAIN": <14}{"Loss": >4} | {"VALIDATION": <14}{"Loss": >4} | {"TIMING":<8}{"(sec)":>4} | {"Improved": >8} |'
    header_msg += f' {"Component": >9} | {"All Trained": >11} | {"Rho": >{min(8, args.num_components) * 6}} |' if args.boosted else ''
    logger.info('|' + "-" * (len(header_msg) - 2) + '|')
    logger.info(header_msg)
    logger.info('|' + "-" * (len(header_msg) - 2) + '|')

    best_loss = np.array([np.inf] * args.num_components)
    early_stop_count = 0
    converged_epoch = 0  # for boosting, helps keep track of how long the current component has been training
    if args.boosted:
        #model.component = 0
        prev_lr = init_boosted_lr(model, optimizer, args)
    else:
        prev_lr = []

    grad_norm = None
    epoch_times = []
    epoch_train = []
    epoch_valid = []
    pval_loss = 0.0
    val_losses = {'g_nll': 9999999.9}
    step = 0
    for epoch in range(args.init_epoch, args.epochs + 1):
        model.train()
        train_loss = []
        t_start = time.time()
        for batch_id, (x, _) in enumerate(data_loaders['train']):
            # initialize data and optimizer
            x = x.to(args.device)
            optimizer.zero_grad()

            # initialize ActNorm on first steps
            if (args.flow == 'glow' or args.component_type == 'glow') and step < args.num_init_batches:
                with torch.no_grad():
                    if args.boosted:
                        for i in range(args.num_components):
                            model(x=x, components=i)
                    else:
                        model(x=x)
                step += 1
                continue

            # compute loss and gradients
            losses = compute_kl_pq_loss(model, x, args)
            train_loss.append(losses['nll'])
            losses['nll'].backward()
            if args.max_grad_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            # Adjust learning rates for boosted model, keep fixed components frozen
            if args.boosted:
                update_learning_rates(prev_lr, model, optimizer, step, args)

            # batch level reporting
            batch_reporting(writer, optimizer, losses, grad_norm, step, args)

            # Perform gradient update, modify learning rate according to learning rate schedule
            optimizer.step()
            if not args.no_lr_schedule:
                prev_lr = update_scheduler(prev_lr, model, optimizer, scheduler, val_losses['g_nll'], step, args)

                if args.lr_schedule == "test":
                    if step % 50 == 0:
                        pval_loss = evaluate(model, data_loaders['val'], args)['nll']
                        writer.add_scalar('step/val_nll', pval_loss, step)

            step += 1

        # Validation, collect results
        val_losses = evaluate(model, data_loaders['val'], args)
        train_loss = torch.stack(train_loss).mean().item()
        epoch_times.append(time.time() - t_start)
        epoch_train.append(train_loss)
        epoch_valid.append(val_losses['nll'])

        # Assess convergence
        component = (model.component, model.all_trained) if args.boosted else 0
        converged, model_improved, early_stop_count, best_loss = check_convergence(
            early_stop_count, val_losses, best_loss, epoch - converged_epoch, component, args)
        if model_improved:
            fname = f'model_c{model.component}.pt' if args.boosted and args.save_intermediate_checkpoints else 'model.pt'
            save(model, optimizer, args.snap_dir + fname, scheduler)

        # epoch level reporting
        epoch_msg = epoch_reporting(writer, model, train_loss, val_losses, epoch_times, model_improved, epoch, args)

        if converged:
            logger.info(epoch_msg + ' |')
            logger.info("-" * (len(header_msg)))

            if args.boosted:
                converged_epoch = epoch

                # revert back to the last best version of the model and update rho
                fname = f'model_c{model.component}.pt' if args.save_intermediate_checkpoints else 'model.pt'
                load(model=model, optimizer=optimizer, path=args.snap_dir + fname, args=args, scheduler=scheduler, verbose=False)
                model.update_rho(data_loaders['train'])

                last_component = model.component == (args.num_components - 1)
                no_fine_tuning = args.epochs <= args.epochs_per_component * args.num_components
                fine_tuning_done = model.all_trained and last_component  # no early stopping if burnin employed
                if (fine_tuning_done or no_fine_tuning) and last_component:
                    # stop the full model after all components have been trained
                    logger.info(f"Model converged, training complete, saving: {args.snap_dir + 'model.pt'}")
                    model.all_trained = True
                    save(model, optimizer, args.snap_dir + 'model.pt', scheduler)
                    break

                # else if not done training: save model with updated rho
                save(model, optimizer, args.snap_dir + fname, scheduler)

                # temporary: look at results after each component
                test_loss = evaluate(model, data_loaders['test'], args)
                logger.info(f"Loss after training {model.component + 1} components: {test_loss['nll']:8.3f}")
                logger.info("-" * (len(header_msg)))

                # reset optimizer, scheduler, and early_stop_count and train the next component
                model.increment_component()
                early_stop_count = 0
                val_losses = {'g_nll': 9999999.9}
                optimizer, scheduler = init_optimizer(model, args, verbose=False)
                prev_lr = init_boosted_lr(model, optimizer, args)
            else:
                # if a standard model converges once, break
                logger.info("Model converged, stopping training.")
                break

        else:
            logger.info(epoch_msg + ' |')
            if epoch == args.epochs:
                if args.boosted and args.save_intermediate_checkpoints:
                    # Save the best version of the model trained up to the current component with filename model.pt
                    # This is to protect against times when the model is trained/re-trained but doesn't run long enough
                    # for all components to converge / train completely
                    copyfile(args.snap_dir + f'model_c{model.component}.pt', args.snap_dir + 'model.pt')
                    logger.info(f"Resaving last improved version of {f'model_c{model.component}.pt'} as 'model.pt' for future testing")
                else:
                    logger.info(f"Stopping training after {epoch} epochs of training.")

    logger.info('|' + "-" * (len(header_msg) - 2) + '|\n')
    if args.tensorboard:
        writer.close()

    epoch_times, epoch_train, epoch_valid = np.array(epoch_times), np.array(epoch_train), np.array(epoch_valid)
    timing_msg = f"Stopped after {epoch_times.shape[0]} epochs. "
    timing_msg += f"Average train time per epoch: {np.mean(epoch_times):.2f} +/- {np.std(epoch_times, ddof=1):.2f}"
    logger.info(timing_msg + '\n')

    if args.save_results:
        np.savetxt(args.snap_dir + '/train_loss.csv', epoch_train, fmt='%f', delimiter=',')
        np.savetxt(args.snap_dir + '/valid_loss.csv', epoch_valid, fmt='%f', delimiter=',')
        np.savetxt(args.snap_dir + '/epoch_times.csv', epoch_times, fmt='%f', delimiter=',')
        with open(args.exp_log, 'a') as ff:
            timestamp = str(datetime.datetime.now())[0:19].replace(' ', '_')
            setup_msg = '\n'.join([timestamp, args.snap_dir]) + '\n' + repr(args)
            print('\n' + setup_msg + '\n' + timing_msg, file=ff)
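

# NOTE: check_convergence(), epoch_reporting(), and the other helpers called above are defined
# elsewhere in this repo. The function below is only an illustrative sketch of the contract that
# train() relies on, NOT the actual implementation: the model "improved" when validation NLL beats
# the best value recorded for the component currently being trained, and that component "converges"
# once no improvement is seen for a patience window. The `early_stopping_epochs` field is a
# hypothetical args attribute used here purely for illustration.
def check_convergence_sketch(early_stop_count, val_losses, best_loss, epochs_since_converged, component, args):
    # boosted models pass component as (component index, all_trained); standard models pass 0
    c = component[0] if isinstance(component, tuple) else component
    improved = val_losses['nll'] < best_loss[c]
    if improved:
        best_loss[c] = val_losses['nll']
        early_stop_count = 0
    else:
        early_stop_count += 1
    converged = args.early_stopping_epochs > 0 and early_stop_count >= args.early_stopping_epochs
    return converged, improved, early_stop_count, best_loss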
def train_boosted(train_loader, val_loader, model, optimizer, scheduler, args):
    train_times = []
    train_loss = []
    train_rec = []
    train_G = []
    train_p = []
    train_entropy = []

    val_loss = []
    val_rec = []
    val_kl = []

    # for early stopping
    best_loss = np.array([np.inf] * args.num_components)
    best_tr_ratio = np.array([-np.inf] * args.num_components)
    early_stop_count = 0
    converged_epoch = 0  # corrects the annealing schedule when a component converges early
    v_loss = 9999999.9

    # initialize learning rates for boosted components
    prev_lr = init_boosted_lr(model, optimizer, args)

    args.step = 0
    for epoch in range(args.init_epoch, args.epochs + 1):
        # compute annealing rate for KL loss term
        beta = kl_annealing_rate(epoch - converged_epoch, model.component, model.all_trained, args)

        # occasionally sample from all components to keep decoder from focusing solely on new component
        prob_all = sample_from_all_prob(epoch - converged_epoch, model.component, model.all_trained, args)

        # Train model
        t_start = time.time()
        tr_loss, tr_rec, tr_G, tr_p, tr_entropy, tr_ratio, prev_lr = train_epoch_boosted(
            epoch, train_loader, model, optimizer, scheduler, beta, prob_all, prev_lr, v_loss, args)
        train_times.append(time.time() - t_start)
        train_loss.append(tr_loss)
        train_rec.append(tr_rec)
        train_G.append(tr_G)
        train_p.append(tr_p)
        train_entropy.append(tr_entropy)

        # Evaluate model
        v_loss, v_rec, v_kl = evaluate(val_loader, model, args, epoch=epoch)
        val_loss.append(v_loss)
        val_rec.append(v_rec)
        val_kl.append(v_kl)

        # Assess convergence
        component_converged, model_improved, early_stop_count, best_loss, best_tr_ratio = check_convergence(
            early_stop_count, v_loss, best_loss, tr_ratio, best_tr_ratio, epoch - converged_epoch, model, args)

        # epoch level reporting
        epoch_msg = epoch_reporting(model, tr_loss, tr_rec, tr_G, tr_p, tr_entropy, tr_ratio,
                                    v_loss, v_rec, v_kl, beta, prob_all, train_times, epoch, model_improved, args)

        if model_improved:
            fname = f'model_c{model.component}.pt' if args.boosted and args.save_intermediate_checkpoints else 'model.pt'
            save(model, optimizer, args.snap_dir + fname, scheduler)

        if component_converged:
            logger.info(epoch_msg + f'{"| ": >4}')
            logger.info("-" * 206)

            converged_epoch = epoch

            # revert back to the last best version of the model and update rho
            fname = f'model_c{model.component}.pt' if args.save_intermediate_checkpoints else 'model.pt'
            load(model=model, optimizer=optimizer, path=args.snap_dir + fname, args=args, scheduler=scheduler, verbose=False)
            model.update_rho(train_loader)

            last_component = model.component == (args.num_components - 1)
            no_fine_tuning = args.epochs <= args.epochs_per_component * args.num_components
            fine_tuning_done = model.all_trained and last_component
            if (fine_tuning_done or no_fine_tuning) and last_component:
                # stop the full model after all components have been trained
                logger.info(f"Model converged, training complete, saving: {args.snap_dir + 'model.pt'}")
                model.all_trained = True
                save(model, optimizer, args.snap_dir + 'model.pt', scheduler)
                break

            save(model, optimizer, args.snap_dir + f'model_c{model.component}.pt', scheduler)

            # reset early_stop_count and train the next component
            model.increment_component()
            early_stop_count = 0
            v_loss = 9999999.9
            optimizer, scheduler = init_optimizer(model, args, verbose=False)
            prev_lr = init_boosted_lr(model, optimizer, args)
        else:
            logger.info(epoch_msg + f'{"| ": >4}')
            if epoch == args.epochs:
                if args.boosted and args.save_intermediate_checkpoints:
                    # Save the best version of the model trained up to the current component with filename model.pt
                    # This is to protect against times when the model is trained/re-trained but doesn't run long enough
                    # for all components to converge / train completely
                    copyfile(args.snap_dir + f'model_c{model.component}.pt', args.snap_dir + 'model.pt')
                    logger.info(f"Resaving last improved version of {f'model_c{model.component}.pt'} as 'model.pt' for future testing")
                else:
                    logger.info(f"Stopping training after {epoch} epochs of training.")

    train_loss = np.hstack(train_loss)
    train_rec = np.hstack(train_rec)
    train_G = np.hstack(train_G)
    train_p = np.hstack(train_p)
    train_entropy = np.hstack(train_entropy)
    val_loss = np.array(val_loss)
    val_rec = np.array(val_rec)
    val_kl = np.array(val_kl)
    train_times = np.array(train_times)
    return train_loss, train_rec, train_G, train_p, train_entropy, val_loss, val_rec, val_kl, train_times
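

# kl_annealing_rate() and sample_from_all_prob() are defined elsewhere in this repo. For reference,
# the sketch below shows a typical linear KL warm-up of the kind train_boosted() consumes as `beta`;
# it assumes a hypothetical args.annealing_schedule field (number of warm-up epochs) and is NOT
# necessarily how the repo's schedule restarts per boosted component.
def linear_kl_annealing_sketch(epochs_since_restart, args):
    warmup = max(1, getattr(args, 'annealing_schedule', 100))
    beta = min(1.0, epochs_since_restart / warmup)  # ramp beta from 0 to 1, then hold at 1
    return beta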
def main(main_args=None):
    """use main_args to run this script as function in another script"""
    # =========================================================================
    # PARSE EXPERIMENT SETTINGS, SETUP SNAPSHOTS DIRECTORY, LOGGING
    # =========================================================================
    args, kwargs = parse_args(main_args)

    # =========================================================================
    # LOAD DATA
    # =========================================================================
    logger.info('LOADING DATA:')
    data_loaders, args = load_density_dataset(args)

    # =========================================================================
    # SAVE EXPERIMENT SETTINGS
    # =========================================================================
    logger.info(f'EXPERIMENT SETTINGS:\n{args}\n')
    torch.save(args, os.path.join(args.snap_dir, 'config.pt'))

    # =========================================================================
    # INITIALIZE MODEL AND OPTIMIZATION
    # =========================================================================
    model = init_model(args)
    optimizer, scheduler = init_optimizer(model, args)
    num_params = sum([param.nelement() for param in model.parameters()])
    logger.info(f"MODEL:\nNumber of model parameters={num_params}\n{model}\n")

    if args.load:
        logger.info(f'LOADING CHECKPOINT FROM PRE-TRAINED MODEL: {args.load}')
        init_with_args = args.flow == "boosted" and args.loaded_init_component is not None and args.loaded_all_trained is not None
        load(model=model, optimizer=optimizer, path=args.load, args=args, init_with_args=init_with_args, scheduler=scheduler)
        logger.info('Warning: boosted models may only be loaded to train a new component (until pytorch bug is fixed), '
                    'optimizer and scheduler will be reset. Non-boosted models may not be loaded at all (will fail).')
        optimizer, scheduler = init_optimizer(model, args, verbose=False)

    # =========================================================================
    # TRAINING
    # =========================================================================
    if args.epochs > 0:
        logger.info('TRAINING:')
        if args.tensorboard:
            logger.info(f'Follow progress on tensorboard: tb {args.snap_dir}')
        train(model, data_loaders, optimizer, scheduler, args)

    # =========================================================================
    # VALIDATION
    # =========================================================================
    logger.info('VALIDATION:')
    load(model=model, optimizer=optimizer, path=args.snap_dir + 'model.pt', args=args)
    val_loss = evaluate(model, data_loaders['val'], args, results_type='Validation')

    # =========================================================================
    # TESTING
    # =========================================================================
    if args.testing:
        logger.info("TESTING:")
        test_loss = evaluate(model, data_loaders['test'], args, results_type='Test')
def main(main_args=None):
    """use main_args to run this script as function in another script"""
    # =========================================================================
    # PARSE EXPERIMENT SETTINGS, SETUP SNAPSHOTS DIRECTORY, LOGGING
    # =========================================================================
    args, kwargs = parse_args(main_args)

    # =========================================================================
    # LOAD DATA
    # =========================================================================
    logger.info('LOADING DATA:')
    train_loader, val_loader, test_loader, args = load_image_dataset(args, **kwargs)

    # =========================================================================
    # SAVE EXPERIMENT SETTINGS
    # =========================================================================
    logger.info(f'EXPERIMENT SETTINGS:\n{args}\n')
    torch.save(args, os.path.join(args.snap_dir, 'config.pt'))

    # =========================================================================
    # INITIALIZE MODEL AND OPTIMIZATION
    # =========================================================================
    model = init_model(args)
    optimizer, scheduler = init_optimizer(model, args)
    num_params = sum([param.nelement() for param in model.parameters()])
    logger.info(f"MODEL:\nNumber of model parameters={num_params}\n{model}\n")

    if args.load:
        logger.info(f'LOADING CHECKPOINT FROM PRE-TRAINED MODEL: {args.load}')
        init_with_args = args.flow == "boosted" and args.loaded_init_component is not None and args.loaded_all_trained is not None
        load(model, optimizer, args.load, args, init_with_args)

    # =========================================================================
    # TRAINING
    # =========================================================================
    training_required = args.epochs > 0 or args.load is None
    if training_required:
        logger.info('TRAINING:')
        if args.tensorboard:
            logger.info(f'Follow progress on tensorboard: tb {args.snap_dir}')
        train_loss, val_loss = train(train_loader, val_loader, model, optimizer, scheduler, args)

    # =========================================================================
    # VALIDATION
    # =========================================================================
    logger.info('VALIDATION:')
    if training_required:
        load(model, optimizer, args.snap_dir + 'model.pt', args)
    val_loss, val_rec, val_kl = evaluate(val_loader, model, args, results_type='Validation')

    # =========================================================================
    # TESTING
    # =========================================================================
    if args.testing:
        logger.info("TESTING:")
        test_loss, test_rec, test_kl = evaluate(test_loader, model, args, results_type='Test')
        test_nll = evaluate_likelihood(test_loader, model, args, S=args.nll_samples, MB=args.nll_mb, results_type='Test')
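

# Assumed entry point when this script is run directly; the guard is not shown in the excerpt
# above, so this is a minimal sketch of the usual pattern rather than the repo's exact code.
if __name__ == "__main__":
    main()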