def instantiate_expr_datamodule() -> WeightedExprDataset:
    """Create a WeightedExprDataset with a fixed, rank-weighted configuration."""
    hparams = Namespace(
        ignore_percentile=65,
        good_percentile=5,
        data_seed=0,
        weight_type='rank',
        rank_weight_k=1,
        weight_quantile=None,
        val_frac=0.1,
        property_key='scores',
        second_key='expr',
        batch_size=128,
        predict_target=False,
        metric_loss=None,
        dataset_path=os.path.join(ROOT_PROJECT, 'weighted_retraining/data/expr'),
    )
    hparams.dataset_path = get_filepath(hparams.ignore_percentile, hparams.dataset_path,
                                        hparams.data_seed, good_percentile=hparams.good_percentile)

    datamodule = WeightedExprDataset(hparams, utils.DataWeighter(hparams), add_channel=False)
    datamodule.setup()
    return datamodule
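# Usage sketch for the factory above. It assumes WeightedExprDataset follows the
# standard LightningDataModule API (train_dataloader available after setup()),
# which the setup() call above suggests; the exact batch structure depends on
# the dataset and is not assumed here.
def _demo_expr_datamodule():
    dm = instantiate_expr_datamodule()
    batch = next(iter(dm.train_dataloader()))  # structure depends on the dataset
    print(type(batch))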
def main_aux(args, result_dir: str):
    """ main """
    # Seeding
    pl.seed_everything(args.seed)

    # Make results directory
    data_dir = os.path.join(result_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    setup_logger(os.path.join(result_dir, "log.txt"))

    result_filepath = os.path.join(result_dir, 'results.pkl')
    if not args.overwrite and os.path.exists(result_filepath):
        print(f"Already exists: {result_dir}")
        return

    # Load data
    datamodule = WeightedJTNNDataset(args, utils.DataWeighter(args))
    datamodule.setup("fit", n_init_points=args.n_init_bo_points)

    # Print the python command that was run
    cmd = ' '.join(sys.argv[1:])
    print_flush(f"{cmd}\n")

    # Load model
    assert args.use_pretrained
    if args.predict_target:
        if 'pred_y' in args.pretrained_model_file:
            # Fully supervised setup from a model trained with target prediction
            ckpt = torch.load(args.pretrained_model_file)
            ckpt['hyper_parameters']['hparams'].beta_target_pred_loss = args.beta_target_pred_loss
            ckpt['hyper_parameters']['hparams'].predict_target = True
            ckpt['hyper_parameters']['hparams'].target_predictor_hdims = args.target_predictor_hdims
            torch.save(ckpt, args.pretrained_model_file)
    vae: JTVAE = JTVAE.load_from_checkpoint(args.pretrained_model_file, vocab=datamodule.vocab)
    vae.beta = vae.hparams.beta_final  # Override any beta annealing
    vae.metric_loss = args.metric_loss
    vae.hparams.metric_loss = args.metric_loss
    vae.beta_metric_loss = args.beta_metric_loss
    vae.hparams.beta_metric_loss = args.beta_metric_loss
    vae.metric_loss_kw = args.metric_loss_kw
    vae.hparams.metric_loss_kw = args.metric_loss_kw
    vae.predict_target = args.predict_target
    vae.hparams.predict_target = args.predict_target
    vae.beta_target_pred_loss = args.beta_target_pred_loss
    vae.hparams.beta_target_pred_loss = args.beta_target_pred_loss
    vae.target_predictor_hdims = args.target_predictor_hdims
    vae.hparams.target_predictor_hdims = args.target_predictor_hdims
    if vae.predict_target and vae.target_predictor is None:
        vae.hparams.target_predictor_hdims = args.target_predictor_hdims
        vae.hparams.predict_target = args.predict_target
        vae.build_target_predictor()
    vae.eval()

    # Set up some stuff for the progress bar
    postfix = dict(
        n_train=len(datamodule.train_dataset.data),
        save_path=result_dir,
    )

    # Set up results tracking
    start_time = time.time()

    train_chosen_indices = _choose_best_rand_points(n_rand_points=args.n_rand_points,
                                                    n_best_points=args.n_best_points,
                                                    dataset=datamodule.train_dataset)
    train_mol_trees = [datamodule.train_dataset.data[i] for i in train_chosen_indices]
    train_targets = datamodule.train_dataset.data_properties[train_chosen_indices]
    train_chosen_smiles = [datamodule.train_dataset.canonic_smiles[i] for i in train_chosen_indices]

    test_chosen_indices = _choose_best_rand_points(n_rand_points=args.n_test_points,
                                                   n_best_points=0,
                                                   dataset=datamodule.val_dataset)
    test_mol_trees = [datamodule.val_dataset.data[i] for i in test_chosen_indices]

    # Main loop
    with tqdm(total=1, dynamic_ncols=True, smoothing=0.0, file=sys.stdout) as pbar:
        if vae.predict_target and vae.metric_loss is not None:
            vae.training_m = datamodule.training_m
            vae.training_M = datamodule.training_M
            vae.validation_m = datamodule.validation_m
            vae.validation_M = datamodule.validation_M

        torch.cuda.empty_cache()  # Free the memory up for tensorflow

        pbar.set_postfix(postfix)
        pbar.set_description("retraining")
        print(result_dir)

        # Optionally do retraining
        num_epochs = args.n_init_retrain_epochs
        if num_epochs > 0:
            retrain_dir = os.path.join(result_dir, "retraining")
            version = "retrain_0"
            retrain_model(model=vae, datamodule=datamodule, save_dir=retrain_dir,
                          version_str=version, num_epochs=num_epochs, gpu=args.gpu)
        vae.eval()
        del num_epochs
        model = vae

        # Update progress bar
        pbar.set_postfix(postfix)

        # Do querying!
        gp_dir = os.path.join(result_dir, "gp")
        os.makedirs(gp_dir, exist_ok=True)
        gp_data_file = os.path.join(gp_dir, "data.npz")

        # Next, encode these mol trees
        if args.gpu:
            model = model.cuda()
        train_latent_points = _encode_mol_trees(model, train_mol_trees)
        test_latent_points = _encode_mol_trees(model, test_mol_trees)
        if args.use_decoded:
            print("Use targets from decoded latent test points")
            _, test_targets = _batch_decode_z_and_props(
                model,
                torch.as_tensor(test_latent_points, device=model.device),
                datamodule,
                invalid_score=args.invalid_score,
                pbar=pbar,
            )
            test_targets = np.array(test_targets)
        else:
            test_targets = datamodule.val_dataset.data_properties[test_chosen_indices]
        model = model.cpu()  # Make sure to free up GPU memory
        torch.cuda.empty_cache()  # Free the memory up for tensorflow

        # Save points to file
        def _save_gp_data(x, y, test_x, y_test, s, file, flip_sign=True):
            # Prevent overfitting to bad points
            y = np.maximum(y, args.invalid_score)
            if flip_sign:
                y = -y.reshape(-1, 1)  # Since it is a maximization problem
                y_test = -y_test.reshape(-1, 1)
            else:
                y = y.reshape(-1, 1)
                y_test = y_test.reshape(-1, 1)

            # Save the file
            np.savez_compressed(
                file,
                X_train=x.astype(np.float32),
                X_test=test_x.astype(np.float32),
                y_train=y.astype(np.float32),
                y_test=y_test.astype(np.float32),
                smiles=s,
            )

        _save_gp_data(train_latent_points, train_targets, test_latent_points,
                      test_targets, train_chosen_smiles, gp_data_file)

        current_n_inducing_points = min(train_latent_points.shape[0], args.n_inducing_points)

        new_gp_file = os.path.join(gp_dir, "new.npz")
        log_path = os.path.join(gp_dir, "gp_fit.log")
        iter_seed = int(np.random.randint(10000))
        gp_train_command = [
            "python",
            GP_TRAIN_FILE,
            f"--nZ={current_n_inducing_points}",
            f"--seed={iter_seed}",
            f"--data_file={str(gp_data_file)}",
            f"--save_file={str(new_gp_file)}",
            f"--logfile={str(log_path)}",
            "--use_test_set",
        ]
        gp_fit_desc = "GP initial fit"
        gp_train_command += [
            "--init",
            "--kmeans_init",
            f"--save_metrics_file={str(result_filepath)}",
        ]

        # Set pbar status for user
        if pbar is not None:
            old_desc = pbar.desc
            pbar.set_description(gp_fit_desc)

        _run_command(gp_train_command, "GP train 0")
        curr_gp_file = new_gp_file

    print_flush("=== DONE ({:.3f}s) ===".format(time.time() - start_time))
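# _choose_best_rand_points is used above but defined elsewhere in the repo.
# Below is a minimal sketch of the selection it appears to perform, under the
# assumption that it returns the n_best highest-property indices plus n_rand
# random others; this is a hypothetical reimplementation, not the project's
# actual helper.
def _choose_best_rand_points_sketch(n_rand_points: int, n_best_points: int, dataset):
    props = np.asarray(dataset.data_properties)
    best_idx = np.argsort(-props)[:n_best_points]
    remainder = np.setdiff1d(np.arange(len(props)), best_idx)
    rand_idx = np.random.choice(remainder, size=min(n_rand_points, len(remainder)),
                                replace=False)
    return np.concatenate([best_idx, rand_idx]).astype(int)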
def main():
    # Create arg parser
    parser = argparse.ArgumentParser()
    parser = TopologyVAE.add_model_specific_args(parser)
    parser = WeightedNumpyDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root="")
    parser.add_argument(
        "--augment_dataset",
        action='store_true',
        help="Use data augmentation or not",
    )
    parser.add_argument(
        "--use_binary_data",
        action='store_true',
        help="Binarize images in the dataset",
    )

    # Parse arguments
    hparams = parser.parse_args()
    hparams.root_dir = topology_get_path(
        k=hparams.rank_weight_k,
        n_max_epochs=hparams.max_epochs,
        predict_target=hparams.predict_target,
        hdims=hparams.target_predictor_hdims,
        metric_loss=hparams.metric_loss,
        metric_loss_kw=hparams.metric_loss_kw,
        beta_target_pred_loss=hparams.beta_target_pred_loss,
        beta_metric_loss=hparams.beta_metric_loss,
        latent_dim=hparams.latent_dim,
        beta_final=hparams.beta_final,
        use_binary_data=hparams.use_binary_data)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)
    pl.seed_everything(hparams.seed)

    # Create data
    if hparams.use_binary_data:
        if not os.path.exists(os.path.join(get_data_root(), 'topology_data/target_bin.npy')):
            gen_binary_dataset_from_all_files(get_data_root())
        hparams.dataset_path = os.path.join(ROOT_PROJECT, get_topology_binary_dataset_path())
    else:
        if not os.path.exists(os.path.join(get_data_root(), 'topology_data/target.npy')):
            gen_dataset_from_all_files(get_data_root())
        hparams.dataset_path = os.path.join(ROOT_PROJECT, get_topology_dataset_path())

    if hparams.augment_dataset:
        aug = transforms.Compose([
            # transforms.Normalize(mean=, std=),
            # transforms.RandomCrop(30, padding=10),
            transforms.RandomRotation(45),
            transforms.RandomRotation(90),
            transforms.RandomRotation(180),
            transforms.RandomVerticalFlip(0.5),
        ])
    else:
        aug = None
    datamodule = WeightedNumpyDataset(hparams, utils.DataWeighter(hparams), transform=aug)

    # Load model
    model = TopologyVAE(hparams)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        period=max(1, hparams.max_epochs // 10),
        monitor="loss/val",
        save_top_k=-1,
        save_last=True,
        mode='min',
    )

    if hparams.load_from_checkpoint is not None:
        model = TopologyVAE.load_from_checkpoint(hparams.load_from_checkpoint)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
            resume_from_checkpoint=hparams.load_from_checkpoint)
        print('Loading from checkpoint')
    else:
        # Main trainer
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
            terminate_on_nan=True,
            progress_bar_refresh_rate=5,
            # gradient_clip_val=20.0,
        )

    # Fit
    trainer.fit(model, datamodule=datamodule)
    print(f"Training finished; end of script: rename {checkpoint_callback.best_model_path}")
    shutil.copyfile(
        checkpoint_callback.best_model_path,
        os.path.join(os.path.dirname(checkpoint_callback.best_model_path), 'best.ckpt'))
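# Sanity-check sketch for the augmentation pipeline above. Assumptions: a
# torchvision version (>= 0.8) whose transforms accept tensors directly, and a
# single-channel image size that is purely illustrative here.
def _check_aug_shapes():
    aug = transforms.Compose([
        transforms.RandomRotation(45),
        transforms.RandomRotation(90),
        transforms.RandomRotation(180),
        transforms.RandomVerticalFlip(0.5),
    ])
    x = torch.rand(1, 40, 40)  # (C, H, W); 40x40 is a placeholder size
    assert aug(x).shape == x.shape  # rotations with expand=False preserve shape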
def main_aux(args, result_dir: str):
    """ main """
    # Seeding
    pl.seed_everything(args.seed)

    if args.train_only and os.path.exists(args.save_model_path) and not args.overwrite:
        print_flush(f'--- JTVAE already trained in {args.save_model_path} ---')
        return

    # Make results directory
    data_dir = os.path.join(result_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    setup_logger(os.path.join(result_dir, "log.txt"))

    # Load data
    datamodule = WeightedJTNNDataset(args, utils.DataWeighter(args))
    datamodule.setup("fit", n_init_points=args.n_init_bo_points)

    # Print the python command that was run
    cmd = ' '.join(sys.argv[1:])
    print_flush(f"{cmd}\n")

    # Load model
    if args.use_pretrained:
        if args.predict_target:
            if 'pred_y' in args.pretrained_model_file:
                # Fully supervised training from a model already trained with target prediction
                ckpt = torch.load(args.pretrained_model_file)
                ckpt['hyper_parameters']['hparams'].beta_target_pred_loss = args.beta_target_pred_loss
                ckpt['hyper_parameters']['hparams'].predict_target = True
                ckpt['hyper_parameters']['hparams'].target_predictor_hdims = args.target_predictor_hdims
                torch.save(ckpt, args.pretrained_model_file)
        print(os.path.abspath(args.pretrained_model_file))
        vae: JTVAE = JTVAE.load_from_checkpoint(args.pretrained_model_file, vocab=datamodule.vocab)
        vae.beta = vae.hparams.beta_final  # Override any beta annealing
        vae.metric_loss = args.metric_loss
        vae.hparams.metric_loss = args.metric_loss
        vae.beta_metric_loss = args.beta_metric_loss
        vae.hparams.beta_metric_loss = args.beta_metric_loss
        vae.metric_loss_kw = args.metric_loss_kw
        vae.hparams.metric_loss_kw = args.metric_loss_kw
        vae.predict_target = args.predict_target
        vae.hparams.predict_target = args.predict_target
        vae.beta_target_pred_loss = args.beta_target_pred_loss
        vae.hparams.beta_target_pred_loss = args.beta_target_pred_loss
        vae.target_predictor_hdims = args.target_predictor_hdims
        vae.hparams.target_predictor_hdims = args.target_predictor_hdims
        if vae.predict_target and vae.target_predictor is None:
            vae.hparams.target_predictor_hdims = args.target_predictor_hdims
            vae.hparams.predict_target = args.predict_target
            vae.build_target_predictor()
    else:
        print("Initialising VAE from scratch!")
        vae: JTVAE = JTVAE(hparams=args, vocab=datamodule.vocab)
    vae.eval()

    # Set up some stuff for the progress bar
    num_retrain = int(np.ceil(args.query_budget / args.retraining_frequency))
    postfix = dict(retrain_left=num_retrain,
                   best=float(datamodule.train_dataset.data_properties.max()),
                   n_train=len(datamodule.train_dataset.data),
                   save_path=result_dir)

    start_num_retrain = 0

    # Set up results tracking
    results = dict(
        opt_points=[],
        opt_point_properties=[],
        opt_model_version=[],
        params=str(sys.argv),
        sample_points=[],
        sample_versions=[],
        sample_properties=[],
    )

    result_filepath = os.path.join(result_dir, 'results.npz')
    if not args.overwrite and os.path.exists(result_filepath):
        with np.load(result_filepath, allow_pickle=True) as npz:
            results = {}
            for k in list(npz.keys()):
                if k != 'params':
                    results[k] = list(npz[k])
                else:
                    results[k] = npz[k].item()
        start_num_retrain = results['opt_model_version'][-1] + 1

        prev_retrain_model = args.retraining_frequency * (start_num_retrain - 1)
        num_sampled_points = len(results['opt_points'])

        if args.n_init_retrain_epochs == 0 and prev_retrain_model == 0:
            pretrained_model_path = args.pretrained_model_file
        else:
            pretrained_model_path = os.path.join(result_dir, 'retraining',
                                                 f'retrain_{prev_retrain_model}',
                                                 'checkpoints', 'last.ckpt')
        print(f"Found checkpoint at {pretrained_model_path}")
        ckpt = torch.load(pretrained_model_path)
        ckpt['hyper_parameters']['hparams'].metric_loss = args.metric_loss
        ckpt['hyper_parameters']['hparams'].metric_loss_kw = args.metric_loss_kw
        ckpt['hyper_parameters']['hparams'].beta_metric_loss = args.beta_metric_loss
        ckpt['hyper_parameters']['hparams'].beta_target_pred_loss = args.beta_target_pred_loss
        if args.predict_target:
            ckpt['hyper_parameters']['hparams'].predict_target = True
            ckpt['hyper_parameters']['hparams'].target_predictor_hdims = args.target_predictor_hdims
        torch.save(ckpt, pretrained_model_path)
        print(f"Loading model from {pretrained_model_path}")
        # load_from_checkpoint is a classmethod returning a new model, so bind the result
        vae = JTVAE.load_from_checkpoint(pretrained_model_path, vocab=datamodule.vocab)
        if args.predict_target and not hasattr(vae.hparams, 'predict_target'):
            vae.hparams.target_predictor_hdims = args.target_predictor_hdims
            vae.hparams.predict_target = args.predict_target
        # vae.hparams.cuda = args.cuda
        vae.beta = vae.hparams.beta_final  # Override any beta annealing
        vae.eval()

        # Set up some stuff for the progress bar
        num_retrain = int(np.ceil(args.query_budget / args.retraining_frequency)) - start_num_retrain

        print("Append existing points and properties to datamodule...")
        datamodule.append_train_data(np.array(results['opt_points']),
                                     np.array(results['opt_point_properties']))
        postfix = dict(retrain_left=num_retrain,
                       best=float(datamodule.train_dataset.data_properties.max()),
                       n_train=len(datamodule.train_dataset.data),
                       initial=num_sampled_points,
                       save_path=result_dir)
        print(f"Retrain from {result_dir} | Best: {max(results['opt_point_properties'])}")

    start_time = time.time()

    # Main loop
    with tqdm(total=args.query_budget, dynamic_ncols=True, smoothing=0.0, file=sys.stdout) as pbar:
        for ret_idx in range(start_num_retrain, start_num_retrain + num_retrain):
            if vae.predict_target and vae.metric_loss is not None:
                vae.training_m = datamodule.training_m
                vae.training_M = datamodule.training_M
                vae.validation_m = datamodule.validation_m
                vae.validation_M = datamodule.validation_M

            torch.cuda.empty_cache()  # Free the memory up for tensorflow
            pbar.set_postfix(postfix)
            pbar.set_description("retraining")
            print(result_dir)

            # Decide whether to retrain
            samples_so_far = args.retraining_frequency * ret_idx

            # Optionally do retraining
            num_epochs = args.n_retrain_epochs
            if ret_idx == 0 and args.n_init_retrain_epochs is not None:
                num_epochs = args.n_init_retrain_epochs
            if num_epochs > 0:
                retrain_dir = os.path.join(result_dir, "retraining")
                version = f"retrain_{samples_so_far}"
                retrain_model(model=vae, datamodule=datamodule, save_dir=retrain_dir,
                              version_str=version, num_epochs=num_epochs, gpu=args.gpu,
                              store_best=args.train_only, best_ckpt_path=args.save_model_path)
                vae.eval()
                if args.train_only:
                    return
            del num_epochs
            model = vae

            # Update progress bar
            postfix["retrain_left"] -= 1
            pbar.set_postfix(postfix)

            # Draw samples for logs!
            if args.samples_per_model > 0:
                pbar.set_description("sampling")
                with trange(args.samples_per_model, desc="sampling", leave=False) as sample_pbar:
                    sample_x, sample_y = latent_sampling(args, model, datamodule,
                                                         args.samples_per_model,
                                                         pbar=sample_pbar)

                # Append to results dict
                results["sample_points"].append(sample_x)
                results["sample_properties"].append(sample_y)
                results["sample_versions"].append(ret_idx)

            # Do querying!
            pbar.set_description("querying")
            num_queries_to_do = min(args.retraining_frequency, args.query_budget - samples_so_far)
            if args.lso_strategy == "opt":
                gp_dir = os.path.join(result_dir, "gp", f"iter{samples_so_far}")
                os.makedirs(gp_dir, exist_ok=True)
                gp_data_file = os.path.join(gp_dir, "data.npz")
                gp_err_data_file = os.path.join(gp_dir, "data_err.npz")
                x_new, y_new = latent_optimization(
                    model=model,
                    datamodule=datamodule,
                    n_inducing_points=args.n_inducing_points,
                    n_best_points=args.n_best_points,
                    n_rand_points=args.n_rand_points,
                    num_queries_to_do=num_queries_to_do,
                    gp_data_file=gp_data_file,
                    gp_err_data_file=gp_err_data_file,
                    gp_run_folder=gp_dir,
                    gpu=args.gpu,
                    invalid_score=args.invalid_score,
                    pbar=pbar,
                    postfix=postfix,
                    error_aware_acquisition=args.error_aware_acquisition,
                )
            elif args.lso_strategy == "sample":
                x_new, y_new = latent_sampling(args, model, datamodule,
                                               num_queries_to_do, pbar=pbar)
            else:
                raise NotImplementedError(args.lso_strategy)

            # Update dataset
            datamodule.append_train_data(x_new, y_new)

            # Add new results
            results["opt_points"] += list(x_new)
            results["opt_point_properties"] += list(y_new)
            results["opt_model_version"] += [ret_idx] * len(x_new)

            postfix["best"] = max(postfix["best"], float(max(y_new)))
            postfix["n_train"] = len(datamodule.train_dataset.data)
            pbar.set_postfix(postfix)

            # Save results
            np.savez_compressed(os.path.join(result_dir, "results.npz"), **results)

            # Keep a record of the dataset here
            new_data_file = os.path.join(data_dir,
                                         f"train_data_iter{samples_so_far + num_queries_to_do}.txt")
            with open(new_data_file, "w") as f:
                f.write("\n".join(datamodule.train_dataset.canonic_smiles))

    print_flush("=== DONE ({:.3f}s) ===".format(time.time() - start_time))
def main():
    # Create arg parser
    parser = argparse.ArgumentParser()
    parser = EquationVaeTorch.add_model_specific_args(parser)
    parser = WeightedExprDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root='')
    parser.add_argument("--ignore_percentile", type=int, default=50,
                        help="Percentile of scores to ignore")
    parser.add_argument("--good_percentile", type=int, default=0,
                        help="Percentile of good scores to select")
    parser.add_argument("--data_seed", type=int, required=True,
                        help="Seed that was used to generate the dataset")

    # Parse arguments
    hparams = parser.parse_args()
    hparams.dataset_path = get_filepath(hparams.ignore_percentile, hparams.dataset_path,
                                        hparams.data_seed, good_percentile=hparams.good_percentile)
    hparams.root_dir = get_path(
        k=hparams.rank_weight_k,
        ignore_percentile=hparams.ignore_percentile,
        good_percentile=hparams.good_percentile,
        n_max_epochs=hparams.max_epochs,
        predict_target=hparams.predict_target,
        beta_final=hparams.beta_final,
        beta_target_pred_loss=hparams.beta_target_pred_loss,
        beta_metric_loss=hparams.beta_metric_loss,
        latent_dim=hparams.latent_dim,
        hdims=hparams.target_predictor_hdims,
        metric_loss=hparams.metric_loss,
        metric_loss_kw=hparams.metric_loss_kw)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)
    pl.seed_everything(hparams.seed)

    # Create data
    datamodule = WeightedExprDataset(hparams, utils.DataWeighter(hparams), add_channel=False)

    device = hparams.cuda
    if device is not None:
        torch.cuda.set_device(device)
    data_info = G.gram.split('\n')

    # Load model
    model = EquationVaeTorch(hparams, len(data_info), MAX_LEN)
    # model.decoder.apply(torch_weight_init)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        period=max(1, hparams.max_epochs // 20),
        monitor="loss/val",
        save_top_k=-1,
        save_last=True,
        mode='min',
    )

    if hparams.load_from_checkpoint is not None:
        model = EquationVaeTorch.load_from_checkpoint(hparams.load_from_checkpoint,
                                                      len(data_info), MAX_LEN)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
            resume_from_checkpoint=hparams.load_from_checkpoint)
        print('Loading from checkpoint')
    else:
        # Main trainer
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
            terminate_on_nan=True,
            progress_bar_refresh_rate=100)

    # Fit
    trainer.fit(model, datamodule=datamodule)
    print(f"Training finished; end of script: rename {checkpoint_callback.best_model_path}")
    shutil.copyfile(
        checkpoint_callback.best_model_path,
        os.path.join(os.path.dirname(checkpoint_callback.best_model_path), 'best.ckpt'))
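# Reload sketch: as the load_from_checkpoint call above shows, Lightning
# forwards the extra positional arguments (charset size and MAX_LEN) to
# EquationVaeTorch.__init__ when restoring. The checkpoint path is a
# placeholder; main() copies the best checkpoint to 'best.ckpt' next to it.
def load_best_equation_vae(ckpt_path: str):
    data_info = G.gram.split('\n')
    model = EquationVaeTorch.load_from_checkpoint(ckpt_path, len(data_info), MAX_LEN)
    model.eval()
    return model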
parser = argparse.ArgumentParser()
parser = JTVAE.add_model_specific_args(parser)
parser = WeightedJTNNDataset.add_model_specific_args(parser)
parser = utils.DataWeighter.add_weight_args(parser)
utils.add_default_trainer_args(parser, default_root=None)

# Parse arguments
hparams = parser.parse_args()
hparams.root_dir = os.path.join(get_storage_root(), hparams.root_dir)
pl.seed_everything(hparams.seed)
print_flush(' '.join(sys.argv[1:]))

# Create data
datamodule = WeightedJTNNDataset(hparams, utils.DataWeighter(hparams))
datamodule.setup("fit")

# Load model
model = JTVAE(hparams, datamodule.vocab)

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    period=1,
    monitor="loss/val",
    save_top_k=1,
    save_last=True,
    mode='min',
)

if hparams.load_from_checkpoint is not None:
    # model = JTVAE.load_from_checkpoint(hparams.load_from_checkpoint)
    # utils.update_hparams(hparams, model)
    trainer = pl.Trainer(
        gpus=[hparams.cuda] if hparams.cuda is not None else 0,
        default_root_dir=hparams.root_dir,
        max_epochs=hparams.max_epochs,
        # The original snippet is truncated here; the closing arguments below
        # follow the pattern of the sibling training scripts and are an assumption.
        callbacks=[checkpoint_callback],
        resume_from_checkpoint=hparams.load_from_checkpoint)
def main():
    # Create arg parser
    parser = argparse.ArgumentParser()
    parser = ShapesVAE.add_model_specific_args(parser)
    parser = WeightedNumpyDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root="")

    # Parse arguments
    hparams = parser.parse_args()
    hparams.root_dir = shape_get_path(k=hparams.rank_weight_k,
                                      predict_target=hparams.predict_target,
                                      hdims=hparams.target_predictor_hdims,
                                      metric_loss=hparams.metric_loss,
                                      metric_loss_kw=hparams.metric_loss_kw,
                                      latent_dim=hparams.latent_dim)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)
    pl.seed_everything(hparams.seed)

    # Create data
    datamodule = WeightedNumpyDataset(hparams, utils.DataWeighter(hparams))

    # Load model
    model = ShapesVAE(hparams)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        period=max(1, hparams.max_epochs // 20),
        monitor="loss/val",
        save_top_k=1,
        save_last=True,
        mode='min',
    )

    if hparams.load_from_checkpoint is not None:
        model = ShapesVAE.load_from_checkpoint(hparams.load_from_checkpoint)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
            resume_from_checkpoint=hparams.load_from_checkpoint)
        print('Loading from checkpoint')
    else:
        # Main trainer
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
            terminate_on_nan=True,
            progress_bar_refresh_rate=100,
        )

    # Fit
    trainer.fit(model, datamodule=datamodule)
    print(f"Training finished; end of script: rename {checkpoint_callback.best_model_path}")
    shutil.copyfile(
        checkpoint_callback.best_model_path,
        os.path.join(os.path.dirname(checkpoint_callback.best_model_path), 'best.ckpt'))
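# Downstream reload sketch for the ShapesVAE trained above. The directory is a
# placeholder: the script copies the best checkpoint to 'best.ckpt' inside the
# checkpoint directory chosen by the ModelCheckpoint callback.
def load_best_shapes_vae(ckpt_dir: str):
    model = ShapesVAE.load_from_checkpoint(os.path.join(ckpt_dir, 'best.ckpt'))
    model.eval()
    return model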