# --- CLI arguments (tail of a larger argparse section; the add_argument call
#     that this first fragment closes begins before this chunk) ---
help="int_t ||df_i/dx_i||_F")  # NOTE(review): continuation of an add_argument started above — confirm against full file
parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F")  # off-diagonal Jacobian Frobenius regularizer weight; None disables it
parser.add_argument('--resume', type=str, default=None)  # checkpoint path to resume from
parser.add_argument('--save', type=str, default='experiments/cnf')  # output directory for logs/checkpoints
parser.add_argument('--evaluate', action='store_true')
parser.add_argument('--val_freq', type=int, default=200)  # validate every N steps
parser.add_argument('--log_freq', type=int, default=10)   # log every N steps
args = parser.parse_args()

# logger: create the save dir, then log to <save>/logs and snapshot this script file
utils.makedirs(args.save)
logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))

if args.layer_type == "blend":
    # Blend layers require a fixed integration horizon of 1.0, so the
    # time length is pinned and not trained.
    logger.info(
        "!! Setting time_length from None to 1.0 due to use of Blend layers.")
    args.time_length = 1.0
    args.train_T = False

logger.info(args)

# fall back to the training batch size when no test batch size was given
test_batch_size = args.test_batch_size if args.test_batch_size else args.batch_size


def batch_iter(X, batch_size=args.batch_size, shuffle=False):
    # NOTE(review): the default binds args.batch_size once, at def time.
    # NOTE(review): definition truncated mid-docstring at the chunk boundary.
    """ X: feature tensor (shape: num_instances x num_features)
def run(args, kwargs): # ================================================================================================================== # SNAPSHOTS # ================================================================================================================== args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_') args.model_signature = args.model_signature.replace(':', '_') snapshots_path = os.path.join(args.out_dir, 'vae_' + args.dataset + '_') snap_dir = snapshots_path + args.flow if args.flow != 'no_flow': snap_dir += '_' + 'num_flows_' + str(args.num_flows) if args.flow == 'orthogonal': snap_dir = snap_dir + '_num_vectors_' + str(args.num_ortho_vecs) elif args.flow == 'orthogonalH': snap_dir = snap_dir + '_num_householder_' + str(args.num_householder) elif args.flow == 'iaf': snap_dir = snap_dir + '_madehsize_' + str(args.made_h_size) elif args.flow == 'permutation': snap_dir = snap_dir + '_' + 'kernelsize_' + str(args.kernel_size) elif args.flow == 'mixed': snap_dir = snap_dir + '_' + 'num_householder_' + str( args.num_householder) elif args.flow == 'cnf_rank': snap_dir = snap_dir + '_rank_' + str( args.rank) + '_' + args.dims + '_num_blocks_' + str( args.num_blocks) elif 'cnf' in args.flow: snap_dir = snap_dir + '_' + args.dims + '_num_blocks_' + str( args.num_blocks) if args.retrain_encoder: snap_dir = snap_dir + '_retrain-encoder_' elif args.evaluate: snap_dir = snap_dir + '_evaluate_' snap_dir = snap_dir + '__' + args.model_signature + '/' args.snap_dir = snap_dir if not os.path.exists(snap_dir): os.makedirs(snap_dir) # logger utils.makedirs(args.snap_dir) logger = utils.get_logger(logpath=os.path.join(args.snap_dir, 'logs'), filepath=os.path.abspath(__file__)) logger.info(args) # SAVING torch.save(args, snap_dir + args.flow + '.config') # ================================================================================================================== # LOAD DATA # 
================================================================================================================== train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs) if not args.evaluate: # ============================================================================================================== # SELECT MODEL # ============================================================================================================== # flow parameters and architecture choice are passed on to model through args if args.flow == 'no_flow': model = VAE.VAE(args) elif args.flow == 'planar': model = VAE.PlanarVAE(args) elif args.flow == 'iaf': model = VAE.IAFVAE(args) elif args.flow == 'orthogonal': model = VAE.OrthogonalSylvesterVAE(args) elif args.flow == 'householder': model = VAE.HouseholderSylvesterVAE(args) elif args.flow == 'triangular': model = VAE.TriangularSylvesterVAE(args) elif args.flow == 'cnf': model = CNFVAE.CNFVAE(args) elif args.flow == 'cnf_bias': model = CNFVAE.AmortizedBiasCNFVAE(args) elif args.flow == 'cnf_hyper': model = CNFVAE.HypernetCNFVAE(args) elif args.flow == 'cnf_lyper': model = CNFVAE.LypernetCNFVAE(args) elif args.flow == 'cnf_rank': model = CNFVAE.AmortizedLowRankCNFVAE(args) else: raise ValueError('Invalid flow choice') if args.retrain_encoder: logger.info(f"Initializing decoder from {args.model_path}") dec_model = torch.load(args.model_path) dec_sd = {} for k, v in dec_model.state_dict().items(): if 'p_x' in k: dec_sd[k] = v model.load_state_dict(dec_sd, strict=False) if args.cuda: logger.info("Model on GPU") model.cuda() logger.info(model) if args.retrain_encoder: parameters = [] logger.info('Optimizing over:') for name, param in model.named_parameters(): if 'p_x' not in name: logger.info(name) parameters.append(param) else: parameters = model.parameters() optimizer = optim.Adamax(parameters, lr=args.learning_rate, eps=1.e-7) # 
================================================================================================================== # TRAINING # ================================================================================================================== train_loss = [] val_loss = [] # for early stopping best_loss = np.inf best_bpd = np.inf e = 0 epoch = 0 train_times = [] for epoch in range(1, args.epochs + 1): t_start = time.time() tr_loss = train(epoch, train_loader, model, optimizer, args, logger) train_loss.append(tr_loss) train_times.append(time.time() - t_start) logger.info('One training epoch took %.2f seconds' % (time.time() - t_start)) v_loss, v_bpd = evaluate(val_loader, model, args, logger, epoch=epoch) val_loss.append(v_loss) # early-stopping if v_loss < best_loss: e = 0 best_loss = v_loss if args.input_type != 'binary': best_bpd = v_bpd logger.info('->model saved<-') torch.save(model, snap_dir + args.flow + '.model') # torch.save(model, snap_dir + args.flow + '_' + args.architecture + '.model') elif (args.early_stopping_epochs > 0) and (epoch >= args.warmup): e += 1 if e > args.early_stopping_epochs: break if args.input_type == 'binary': logger.info( '--> Early stopping: {}/{} (BEST: loss {:.4f})\n'.format( e, args.early_stopping_epochs, best_loss)) else: logger.info( '--> Early stopping: {}/{} (BEST: loss {:.4f}, bpd {:.4f})\n' .format(e, args.early_stopping_epochs, best_loss, best_bpd)) if math.isnan(v_loss): raise ValueError('NaN encountered!') train_loss = np.hstack(train_loss) val_loss = np.array(val_loss) plot_training_curve(train_loss, val_loss, fname=snap_dir + '/training_curve_%s.pdf' % args.flow) # training time per epoch train_times = np.array(train_times) mean_train_time = np.mean(train_times) std_train_time = np.std(train_times, ddof=1) logger.info('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time)) # ================================================================================================================== # 
EVALUATION # ================================================================================================================== logger.info(args) logger.info('Stopped after %d epochs' % epoch) logger.info('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time)) final_model = torch.load(snap_dir + args.flow + '.model') validation_loss, validation_bpd = evaluate(val_loader, final_model, args, logger) else: validation_loss = "N/A" validation_bpd = "N/A" logger.info(f"Loading model from {args.model_path}") final_model = torch.load(args.model_path) test_loss, test_bpd = evaluate(test_loader, final_model, args, logger, testing=True) logger.info( 'FINAL EVALUATION ON VALIDATION SET. ELBO (VAL): {:.4f}'.format( validation_loss)) logger.info( 'FINAL EVALUATION ON TEST SET. NLL (TEST): {:.4f}'.format(test_loss)) if args.input_type != 'binary': logger.info( 'FINAL EVALUATION ON VALIDATION SET. ELBO (VAL) BPD : {:.4f}'. format(validation_bpd)) logger.info( 'FINAL EVALUATION ON TEST SET. NLL (TEST) BPD: {:.4f}'.format( test_bpd))