def get_latent_encodings(use_test_set, use_full_data_for_gp, model, data_file, data_set: WeightedMolTreeFolder, n_best, n_rand, true_vals: bool, tkwargs: Dict[str, Any], bs=64, return_inds: bool = False):
    """ get latent encodings and split data into train and test data """
    print_flush(
        "\tComputing latent training data encodings and corresponding scores..."
    )
    n_total = len(data_set)

    # If the dataset has fewer points than requested, shrink n_best / n_rand
    # proportionally; compensate a possible off-by-one from the double floor.
    if n_total < n_best + n_rand:
        requested = n_best + n_rand
        n_best = int(n_best / requested * n_total)
        n_rand = int(n_rand / requested * n_total)
        if n_best + n_rand < n_total:
            n_rand += 1

    # Either take every point, or a mix of the best and some random points.
    if use_full_data_for_gp:
        chosen_indices = np.arange(n_total)
    else:
        chosen_indices = _choose_best_rand_points(n_best, n_rand, data_set)

    mol_trees = [data_set.data[idx] for idx in chosen_indices]
    targets = data_set.data_properties[chosen_indices]

    # Next, encode these mol trees
    latent_points = _encode_mol_trees(model, mol_trees, batch_size=bs)
    targets = targets.reshape((-1, 1))

    # problem with train_inds returned by ubsample_dataset is they are train indices within passed points and not
    # indices of the original dataset
    if use_full_data_for_gp:
        return subsample_dataset(latent_points, targets, data_file,
                                 use_test_set, use_full_data_for_gp,
                                 n_best, n_rand, return_inds=return_inds)

    # Subsampled case: save and return directly (no train/test split).
    assert not use_test_set
    X_mean, X_std = latent_points.mean(), latent_points.std()
    y_mean, y_std = targets.mean(), targets.std()
    save_data(latent_points, targets, None, None, X_mean, X_std, y_mean,
              y_std, data_file)
    base = (latent_points, targets, None, None, X_mean, y_mean, X_std, y_std)
    if return_inds:
        return base + (chosen_indices, None)
    return base
def get_latent_encodings(use_test_set, use_full_data_for_gp, model, data_file, data_scores, data_imgs, n_best, n_rand, tkwargs: Dict[str, Any], bs=5000, bs_true_eval: int = 256, repeat: int = 10, return_inds: bool = False):
    """ get latent encodings and split data into train and test data

    Encodes the images in batches with `model.encode_to_params` and pairs
    them with their scores. When `n_best + n_rand` is smaller than the
    dataset, only the `n_best` top-scoring points plus `n_rand` random
    other points are encoded, shuffled, saved via `save_data` and returned
    directly; otherwise everything is encoded and `subsample_dataset`
    handles the split.
    """
    print_flush("\tComputing latent training data encodings and corresponding scores...")
    n_batches = int(np.ceil(len(data_imgs) / bs))
    # Only encode a subset when both counts are requested and they do not
    # already cover the whole dataset.
    subsample = n_best > 0 and n_rand > 0 and (n_best + n_rand) < len(data_scores)
    if subsample:
        # do not encode all data, it's too long, only encode the number of points needed (w.r.t. n_best+n_rand)
        sorted_idx = np.argsort(-data_scores)  # descending score order
        best_idx = sorted_idx[:n_best]
        # BUG FIX: the random pool previously started at `n_best + 1`, which
        # silently excluded the point ranked right after the best block; the
        # pool now starts at `n_best`.
        rand_idx = sorted_idx[np.random.choice(
            list(range(n_best, len(data_scores))), n_rand, replace=False)]
        n_best_scores = data_scores[best_idx]
        n_best_data = data_imgs[best_idx]
        n_rand_scores = data_scores[rand_idx]
        n_rand_data = data_imgs[rand_idx]
        # concatenate and then shuffle
        scores_best_cat_rand = np.concatenate([n_best_scores, n_rand_scores])
        data_best_cat_rand = np.concatenate([n_best_data, n_rand_data])
        cat_idx = np.arange(len(scores_best_cat_rand))
        # BUG FIX: `np.random.choice(cat_idx, len(cat_idx))` samples WITH
        # replacement, so the intended shuffle duplicated some points and
        # dropped others; a permutation performs a true shuffle.
        cat_shuffled_idx = np.random.permutation(cat_idx)
        scores_best_cat_rand = scores_best_cat_rand[cat_shuffled_idx]
        data_best_cat_rand = data_best_cat_rand[cat_shuffled_idx]
        n_batches = int(np.ceil(len(data_best_cat_rand) / bs))
        Xs = [model.encode_to_params(
            torch.from_numpy(data_best_cat_rand[i * bs:(i + 1) * bs]).to(**tkwargs).unsqueeze(1)
        )[0].detach().cpu().numpy() for i in tqdm(range(n_batches))]
    else:
        Xs = [model.encode_to_params(
            torch.from_numpy(data_imgs[i * bs:(i + 1) * bs]).to(**tkwargs).unsqueeze(1)
        )[0].detach().cpu().numpy() for i in tqdm(range(n_batches))]
    X = np.concatenate(Xs, axis=0)
    y = scores_best_cat_rand if subsample else data_scores
    y = y.reshape((-1, 1))
    if subsample:
        assert not use_test_set
        assert not use_full_data_for_gp
        X_mean, X_std = X.mean(), X.std()
        y_mean, y_std = y.mean(), y.std()
        save_data(X, y, None, None, X_mean, X_std, y_mean, y_std, data_file)
        if return_inds:
            # Map the shuffled train positions back to indices of the
            # original dataset.
            train_inds = np.concatenate([best_idx, rand_idx])[cat_shuffled_idx]
            return X, y, None, None, X_mean, y_mean, X_std, y_std, train_inds, None
        else:
            return X, y, None, None, X_mean, y_mean, X_std, y_std
    return subsample_dataset(X, y, data_file, use_test_set,
                             use_full_data_for_gp, n_best, n_rand,
                             return_inds=return_inds)
def get_latent_encodings(use_test_set: bool, use_full_data_for_gp: bool, model: ShapesVAE, data_file: str, data_imgs: np.ndarray, data_scores: np.ndarray, n_best: int, n_rand: int, tkwargs: Dict[str, Any], bs=1000):
    """ get latent encodings and split data into train and test data """
    print_flush(
        "\tComputing latent training data encodings and corresponding scores..."
    )
    # Encode all images in batches, pair them with column-vector scores, and
    # let the subsampling helper perform the train/test split.
    encodings = get_latent_encodings_aux(model=model,
                                         data_imgs=data_imgs,
                                         bs=bs,
                                         tkwargs=tkwargs)
    scores_col = data_scores.reshape((-1, 1))
    return _subsample_dataset(encodings, scores_col, data_file, use_test_set,
                              use_full_data_for_gp, n_best, n_rand)
def main_aux(args, result_dir: str):
    """ main: encode a (optionally retrained) JTVAE's latents and fit a GP on them once.

    Loads a pretrained JTVAE, optionally retrains it, encodes chosen
    train/validation molecules into latent space, saves the GP dataset, and
    launches the external GP training script as a subprocess.
    """
    # Seeding
    pl.seed_everything(args.seed)

    # Make results directory
    data_dir = os.path.join(result_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    setup_logger(os.path.join(result_dir, "log.txt"))
    result_filepath = os.path.join(result_dir, 'results.pkl')
    # Skip the whole run if results already exist and overwriting is disabled.
    if not args.overwrite and os.path.exists(result_filepath):
        print(f"Already exists: {result_dir}")
        return

    # Load data
    datamodule = WeightedJTNNDataset(args, utils.DataWeighter(args))
    datamodule.setup("fit", n_init_points=args.n_init_bo_points)

    # print python command run
    cmd = ' '.join(sys.argv[1:])
    print_flush(f"{cmd}\n")

    # Load model
    assert args.use_pretrained
    if args.predict_target:
        if 'pred_y' in args.pretrained_model_file:
            # fully supervised setup from a model trained with target prediction:
            # patch the stored hparams in the checkpoint before loading it.
            ckpt = torch.load(args.pretrained_model_file)
            ckpt['hyper_parameters']['hparams'].beta_target_pred_loss = args.beta_target_pred_loss
            ckpt['hyper_parameters']['hparams'].predict_target = True
            ckpt['hyper_parameters']['hparams'].target_predictor_hdims = args.target_predictor_hdims
            torch.save(ckpt, args.pretrained_model_file)
    vae: JTVAE = JTVAE.load_from_checkpoint(args.pretrained_model_file, vocab=datamodule.vocab)
    vae.beta = vae.hparams.beta_final  # Override any beta annealing
    # Propagate current CLI settings onto both the live attributes and the
    # stored hparams so later checkpoint saves stay consistent.
    vae.metric_loss = args.metric_loss
    vae.hparams.metric_loss = args.metric_loss
    vae.beta_metric_loss = args.beta_metric_loss
    vae.hparams.beta_metric_loss = args.beta_metric_loss
    vae.metric_loss_kw = args.metric_loss_kw
    vae.hparams.metric_loss_kw = args.metric_loss_kw
    vae.predict_target = args.predict_target
    vae.hparams.predict_target = args.predict_target
    vae.beta_target_pred_loss = args.beta_target_pred_loss
    vae.hparams.beta_target_pred_loss = args.beta_target_pred_loss
    vae.target_predictor_hdims = args.target_predictor_hdims
    vae.hparams.target_predictor_hdims = args.target_predictor_hdims
    if vae.predict_target and vae.target_predictor is None:
        # Target prediction requested but the checkpoint lacks a predictor
        # head: build one now.
        vae.hparams.target_predictor_hdims = args.target_predictor_hdims
        vae.hparams.predict_target = args.predict_target
        vae.build_target_predictor()
    vae.eval()

    # Set up some stuff for the progress bar
    postfix = dict(
        n_train=len(datamodule.train_dataset.data),
        save_path=result_dir
    )

    # Set up results tracking
    start_time = time.time()

    # GP training points: a mix of the best and random training molecules.
    train_chosen_indices = _choose_best_rand_points(n_rand_points=args.n_rand_points,
                                                    n_best_points=args.n_best_points,
                                                    dataset=datamodule.train_dataset)
    train_mol_trees = [datamodule.train_dataset.data[i] for i in train_chosen_indices]
    train_targets = datamodule.train_dataset.data_properties[train_chosen_indices]
    train_chosen_smiles = [datamodule.train_dataset.canonic_smiles[i] for i in train_chosen_indices]

    # GP test points: random points from the validation set (n_best_points=0).
    test_chosen_indices = _choose_best_rand_points(n_rand_points=args.n_test_points,
                                                   n_best_points=0,
                                                   dataset=datamodule.val_dataset)
    test_mol_trees = [datamodule.val_dataset.data[i] for i in test_chosen_indices]

    # Main loop
    with tqdm(
            total=1, dynamic_ncols=True, smoothing=0.0, file=sys.stdout
    ) as pbar:

        if vae.predict_target and vae.metric_loss is not None:
            # Hand the datamodule's target normalisation constants to the model.
            vae.training_m = datamodule.training_m
            vae.training_M = datamodule.training_M
            vae.validation_m = datamodule.validation_m
            vae.validation_M = datamodule.validation_M

        torch.cuda.empty_cache()  # Free the memory up for tensorflow
        pbar.set_postfix(postfix)
        pbar.set_description("retraining")
        print(result_dir)
        # Optionally do retraining
        num_epochs = args.n_init_retrain_epochs
        if num_epochs > 0:
            retrain_dir = os.path.join(result_dir, "retraining")
            version = f"retrain_0"
            retrain_model(
                model=vae, datamodule=datamodule, save_dir=retrain_dir,
                version_str=version, num_epochs=num_epochs, gpu=args.gpu
            )
            vae.eval()
        del num_epochs
        model = vae

        # Update progress bar
        pbar.set_postfix(postfix)

        # Do querying!
        gp_dir = os.path.join(result_dir, "gp")
        os.makedirs(gp_dir, exist_ok=True)
        gp_data_file = os.path.join(gp_dir, "data.npz")

        # Next, encode these mol trees
        if args.gpu:
            model = model.cuda()
        train_latent_points = _encode_mol_trees(model, train_mol_trees)
        test_latent_points = _encode_mol_trees(model, test_mol_trees)
        if args.use_decoded:
            print("Use targets from decoded latent test points")
            # Score test points by decoding their latents rather than using
            # the dataset's stored properties.
            _, test_targets = _batch_decode_z_and_props(
                model,
                torch.as_tensor(test_latent_points, device=model.device),
                datamodule,
                invalid_score=args.invalid_score,
                pbar=pbar,
            )
            test_targets = np.array(test_targets)
        else:
            test_targets = datamodule.val_dataset.data_properties[test_chosen_indices]
        model = model.cpu()  # Make sure to free up GPU memory
        torch.cuda.empty_cache()  # Free the memory up for tensorflow

        # Save points to file
        def _save_gp_data(x, y, test_x, y_test, s, file, flip_sign=True):
            # Prevent overfitting to bad points
            y = np.maximum(y, args.invalid_score)
            if flip_sign:
                y = -y.reshape(-1, 1)  # Since it is a maximization problem
                y_test = -y_test.reshape(-1, 1)
            else:
                y = y.reshape(-1, 1)
                y_test = y_test.reshape(-1, 1)

            # Save the file
            np.savez_compressed(
                file,
                X_train=x.astype(np.float32),
                X_test=test_x.astype(np.float32),
                y_train=y.astype(np.float32),
                y_test=y_test.astype(np.float32),
                smiles=s,
            )

        _save_gp_data(train_latent_points, train_targets, test_latent_points,
                      test_targets, train_chosen_smiles, gp_data_file)

        # Cap the number of inducing points by the number of training points.
        current_n_inducing_points = min(train_latent_points.shape[0], args.n_inducing_points)

        new_gp_file = os.path.join(gp_dir, f"new.npz")
        log_path = os.path.join(gp_dir, f"gp_fit.log")
        iter_seed = int(np.random.randint(10000))
        # The GP is fit in a subprocess (separate python runtime).
        gp_train_command = [
            "python",
            GP_TRAIN_FILE,
            f"--nZ={current_n_inducing_points}",
            f"--seed={iter_seed}",
            f"--data_file={str(gp_data_file)}",
            f"--save_file={str(new_gp_file)}",
            f"--logfile={str(log_path)}",
            "--use_test_set"
        ]
        gp_fit_desc = "GP initial fit"
        gp_train_command += [
            "--init",
            "--kmeans_init",
            f"--save_metrics_file={str(result_filepath)}"
        ]

        # Set pbar status for user
        if pbar is not None:
            old_desc = pbar.desc
            pbar.set_description(gp_fit_desc)

        _run_command(gp_train_command, f"GP train {0}")
        curr_gp_file = new_gp_file

    print_flush("=== DONE ({:.3f}s) ===".format(time.time() - start_time))
def main():
    """Train a TopologyVAE on the topology dataset (optionally binarized / augmented)."""
    # Create arg parser
    parser = argparse.ArgumentParser()
    parser = TopologyVAE.add_model_specific_args(parser)
    parser = WeightedNumpyDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root="")
    parser.add_argument(
        "--augment_dataset",
        action='store_true',
        help="Use data augmentation or not",
    )
    parser.add_argument(
        "--use_binary_data",
        action='store_true',
        help="Binarize images in the dataset",
    )

    # Parse arguments
    hparams = parser.parse_args()
    # Result directory is derived deterministically from the hyperparameters.
    hparams.root_dir = topology_get_path(
        k=hparams.rank_weight_k,
        n_max_epochs=hparams.max_epochs,
        predict_target=hparams.predict_target,
        hdims=hparams.target_predictor_hdims,
        metric_loss=hparams.metric_loss,
        metric_loss_kw=hparams.metric_loss_kw,
        beta_target_pred_loss=hparams.beta_target_pred_loss,
        beta_metric_loss=hparams.beta_metric_loss,
        latent_dim=hparams.latent_dim,
        beta_final=hparams.beta_final,
        use_binary_data=hparams.use_binary_data)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)

    pl.seed_everything(hparams.seed)

    # Create data: generate the dataset files on disk first if missing.
    if hparams.use_binary_data:
        if not os.path.exists(
                os.path.join(get_data_root(), 'topology_data/target_bin.npy')):
            gen_binary_dataset_from_all_files(get_data_root())
        hparams.dataset_path = os.path.join(ROOT_PROJECT, get_topology_binary_dataset_path())
    else:
        if not os.path.exists(
                os.path.join(get_data_root(), 'topology_data/target.npy')):
            gen_dataset_from_all_files(get_data_root())
        hparams.dataset_path = os.path.join(ROOT_PROJECT, get_topology_dataset_path())

    if hparams.augment_dataset:
        aug = transforms.Compose([
            # transforms.Normalize(mean=, std=),
            # transforms.RandomCrop(30, padding=10),
            transforms.RandomRotation(45),
            transforms.RandomRotation(90),
            transforms.RandomRotation(180),
            transforms.RandomVerticalFlip(0.5)
        ])
    else:
        aug = None
    datamodule = WeightedNumpyDataset(hparams, utils.DataWeighter(hparams), transform=aug)

    # Load model
    model = TopologyVAE(hparams)

    # Checkpoint roughly every 10% of training; keep all checkpoints + 'last'.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(period=max(
        1, hparams.max_epochs // 10),
        monitor="loss/val",
        save_top_k=-1,
        save_last=True,
        mode='min')

    if hparams.load_from_checkpoint is not None:
        # Resume training from an existing checkpoint.
        model = TopologyVAE.load_from_checkpoint(hparams.load_from_checkpoint)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(
            # NOTE(review): this branch uses truthiness (`if hparams.cuda`),
            # so cuda device 0 falls back to CPU, unlike the `is not None`
            # check in the else-branch below — confirm which is intended.
            gpus=[hparams.cuda] if hparams.cuda else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            resume_from_checkpoint=hparams.load_from_checkpoint)
        print(f'Load from checkpoint')
    else:
        # Main trainer
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            terminate_on_nan=True,
            progress_bar_refresh_rate=5,
            # gradient_clip_val=20.0,
        )

    # Fit
    trainer.fit(model, datamodule=datamodule)
    print(
        f"Training finished; end of script: rename {checkpoint_callback.best_model_path}"
    )
    # Copy the best checkpoint to a stable 'best.ckpt' name next to it.
    shutil.copyfile(
        checkpoint_callback.best_model_path,
        os.path.join(os.path.dirname(checkpoint_callback.best_model_path),
                     'best.ckpt'))
def latent_optimization(
    model: JTVAE,
    datamodule: WeightedJTNNDataset,
    n_inducing_points: int,
    n_best_points: int,
    n_rand_points: int,
    num_queries_to_do: int,
    invalid_score: float,
    gp_data_file: str,
    gp_run_folder: str,
    gpu: bool,
    error_aware_acquisition: bool,
    gp_err_data_file: Optional[str],
    pbar=None,
    postfix=None,
):
    """Iterative GP-based latent-space optimization.

    Fits a sparse GP on latent encodings of chosen training molecules
    (via an external training script run as a subprocess), then repeatedly
    optimizes the acquisition function to query new latent points, decodes
    them to SMILES, and appends them to the GP data.

    Returns:
        (all_new_smiles, all_new_props): the decoded query molecules and
        their property scores across all `num_queries_to_do` iterations.
    """
    ##################################################
    # Prepare GP
    ##################################################

    # First, choose GP points to train!
    dset = datamodule.train_dataset

    chosen_indices = _choose_best_rand_points(n_rand_points=n_rand_points,
                                              n_best_points=n_best_points,
                                              dataset=dset)
    mol_trees = [dset.data[i] for i in chosen_indices]
    targets = dset.data_properties[chosen_indices]
    chosen_smiles = [dset.canonic_smiles[i] for i in chosen_indices]

    # Next, encode these mol trees
    if gpu:
        model = model.cuda()
    latent_points = _encode_mol_trees(model, mol_trees)
    model = model.cpu()  # Make sure to free up GPU memory
    torch.cuda.empty_cache()  # Free the memory up for tensorflow

    # Save points to file
    def _save_gp_data(x, y, s, file, flip_sign=True):
        # Prevent overfitting to bad points
        y = np.maximum(y, invalid_score)
        if flip_sign:
            y = -y.reshape(-1, 1)  # Since it is a maximization problem
        else:
            y = y.reshape(-1, 1)

        # Save the file
        np.savez_compressed(
            file,
            X_train=x.astype(np.float32),
            X_test=[],
            y_train=y.astype(np.float32),
            y_test=[],
            smiles=s,
        )

    # If using error-aware acquisition, compute reconstruction error of selected points
    if error_aware_acquisition:
        assert gp_err_data_file is not None, "Please provide a data file for the error GP"
        if gpu:
            model = model.cuda()
        error_train, safe_idx = get_rec_x_error(
            model,
            tkwargs={'dtype': torch.float},
            data=[datamodule.train_dataset.data[i] for i in chosen_indices],
        )
        # exclude points for which we could not compute the reconstruction error from the objective GP dataset
        if len(safe_idx) < latent_points.shape[0]:
            failed = [
                i for i in range(latent_points.shape[0]) if i not in safe_idx
            ]
            print_flush(
                f"Could not compute the recon. err. of {len(failed)} points -> excluding them."
            )
            latent_points_err = latent_points[safe_idx]
            chosen_smiles_err = [chosen_smiles[i] for i in safe_idx]
        else:
            latent_points_err = latent_points
            chosen_smiles_err = chosen_smiles
        model = model.cpu()  # Make sure to free up GPU memory
        torch.cuda.empty_cache()  # Free the memory up for tensorflow
        # NOTE(review): the filtered `latent_points_err` / `chosen_smiles_err`
        # are computed above but the full arrays are saved here — confirm
        # whether the error-GP file should use the filtered versions instead.
        _save_gp_data(latent_points, error_train.cpu().numpy(), chosen_smiles,
                      gp_err_data_file)
    _save_gp_data(latent_points, targets, chosen_smiles, gp_data_file,
                  flip_sign=False)

    ##################################################
    # Run iterative GP fitting/optimization
    ##################################################
    curr_gp_file = None
    curr_gp_err_file = None
    all_new_smiles = []
    all_new_props = []
    all_new_err = []

    for gp_iter in range(num_queries_to_do):

        # Initial fit on the first iteration (or whenever the dataset size
        # exactly matches the inducing-point budget).
        gp_initial_train = gp_iter == 0
        current_n_inducing_points = min(latent_points.shape[0], n_inducing_points)
        if latent_points.shape[0] == n_inducing_points:
            gp_initial_train = True

        # Part 1: fit GP
        # ===============================
        new_gp_file = os.path.join(gp_run_folder,
                                   f"gp_train_res{gp_iter:04d}.npz")
        new_gp_err_file = os.path.join(
            gp_run_folder,
            f"gp_err_train_res0000.npz")  # no incremental fit of error-GP
        log_path = os.path.join(gp_run_folder, f"gp_train{gp_iter:04d}.log")
        err_log_path = os.path.join(gp_run_folder, f"gp_err_train0000.log")
        try:
            iter_seed = int(np.random.randint(10000))
            # Both GPs are trained by an external script run as a subprocess.
            gp_train_command = [
                "python",
                GP_TRAIN_FILE,
                f"--nZ={current_n_inducing_points}",
                f"--seed={iter_seed}",
                f"--data_file={str(gp_data_file)}",
                f"--save_file={str(new_gp_file)}",
                f"--logfile={str(log_path)}",
                f"--normal_inputs",
                f"--standard_targets"
            ]
            gp_err_train_command = [
                "python",
                GP_TRAIN_FILE,
                f"--nZ={n_inducing_points}",
                f"--seed={iter_seed}",
                f"--data_file={str(gp_err_data_file)}",
                f"--save_file={str(new_gp_err_file)}",
                f"--logfile={str(err_log_path)}",
            ]
            if gp_initial_train:
                # Add commands for initial fitting
                gp_fit_desc = "GP initial fit"
                gp_train_command += [
                    "--init",
                    "--kmeans_init",
                ]
                gp_err_train_command += [
                    "--init",
                    "--kmeans_init",
                ]
            else:
                gp_fit_desc = "GP incremental fit"
                gp_train_command += [
                    f"--gp_file={str(curr_gp_file)}",
                    f"--n_perf_measure=1",  # specifically see how well it fits the last point!
                ]
                gp_err_train_command += [
                    f"--gp_file={str(curr_gp_err_file)}",
                    f"--n_perf_measure=1",  # specifically see how well it fits the last point!
                ]

            # Set pbar status for user
            if pbar is not None:
                old_desc = pbar.desc
                pbar.set_description(gp_fit_desc)

            # Run command
            print_flush("Training objective GP...")
            _run_command(gp_train_command, f"GP train {gp_iter}")
            curr_gp_file = new_gp_file
            if error_aware_acquisition:
                if gp_initial_train:
                    # currently we do not incrementally refit this GP as we do not estimate rec. err.
                    _run_command(gp_err_train_command, f"GP err train {gp_iter}")
                    curr_gp_err_file = new_gp_err_file
        except AssertionError as e:
            # GP training failed: retry once with a fresh seed (raises again
            # on a second failure inside _run_command).
            logs = traceback.format_exc()
            print(logs)
            print_flush(
                f'Got an error in GP training. Retrying with different seed or crash...'
            )
            iter_seed = int(np.random.randint(10000))
            # NOTE(review): unlike the try-branch above, here the
            # --normal_inputs/--standard_targets flags are on the *error* GP
            # command instead of the objective GP command — looks like a
            # copy/paste swap; confirm which command should carry them.
            gp_train_command = [
                "python",
                GP_TRAIN_FILE,
                f"--nZ={current_n_inducing_points}",
                f"--seed={iter_seed}",
                f"--data_file={str(gp_data_file)}",
                f"--save_file={str(new_gp_file)}",
                f"--logfile={str(log_path)}",
            ]
            gp_err_train_command = [
                "python",
                GP_TRAIN_FILE,
                f"--nZ={n_inducing_points}",
                f"--seed={iter_seed}",
                f"--data_file={str(gp_err_data_file)}",
                f"--save_file={str(new_gp_err_file)}",
                f"--logfile={str(err_log_path)}",
                f"--normal_inputs",
                f"--standard_targets"
            ]
            if gp_initial_train:
                # Add commands for initial fitting
                gp_fit_desc = "GP initial fit"
                gp_train_command += [
                    "--init",
                    "--kmeans_init",
                ]
                gp_err_train_command += [
                    "--init",
                    "--kmeans_init",
                ]
            else:
                gp_fit_desc = "GP incremental fit"
                gp_train_command += [
                    f"--gp_file={str(curr_gp_file)}",
                    f"--n_perf_measure=1",  # specifically see how well it fits the last point!
                ]
                gp_err_train_command += [
                    f"--gp_file={str(curr_gp_err_file)}",
                    f"--n_perf_measure=1",  # specifically see how well it fits the last point!
                ]

            # Set pbar status for user
            if pbar is not None:
                old_desc = pbar.desc
                pbar.set_description(gp_fit_desc)

            # Run command
            _run_command(gp_train_command, f"GP train {gp_iter}")
            curr_gp_file = new_gp_file
            if error_aware_acquisition:
                if gp_initial_train:
                    # currently we do not incrementally refit this GP as we do not estimate rec. err.
                    _run_command(gp_err_train_command, f"GP err train {gp_iter}")
                    curr_gp_err_file = new_gp_err_file

        # Part 2: optimize GP acquisition func to query point
        # ===============================
        max_retry = 3
        n_retry = 0
        good = False
        while not good:
            try:
                # Run GP opt script
                opt_path = os.path.join(gp_run_folder,
                                        f"gp_opt_res{gp_iter:04d}.npy")
                log_path = os.path.join(gp_run_folder,
                                        f"gp_opt_{gp_iter:04d}.log")
                gp_opt_command = [
                    "python",
                    GP_OPT_FILE,
                    f"--seed={iter_seed}",
                    f"--gp_file={str(curr_gp_file)}",
                    f"--data_file={str(gp_data_file)}",
                    f"--save_file={str(opt_path)}",
                    f"--n_out={1}",  # hard coded
                    f"--logfile={str(log_path)}",
                ]
                if error_aware_acquisition:
                    gp_opt_command += [
                        f"--gp_err_file={str(curr_gp_err_file)}",
                        f"--data_err_file={str(gp_err_data_file)}",
                    ]

                if pbar is not None:
                    pbar.set_description("optimizing acq func")
                print_flush("Start running gp_opt_command")
                _run_command(gp_opt_command, f"GP opt {gp_iter}")

                # Load point
                z_opt = np.load(opt_path)

                # Decode point
                smiles_opt, prop_opt = _batch_decode_z_and_props(
                    model,
                    torch.as_tensor(z_opt, device=model.device),
                    datamodule,
                    invalid_score=invalid_score,
                    pbar=pbar,
                )
                good = True
            except AssertionError:
                # Acquisition optimization failed: reseed and retry, up to
                # max_retry times.
                iter_seed = int(np.random.randint(10000))
                n_retry += 1
                print_flush(
                    f'Got an error in optimization......trial {n_retry} / {max_retry}'
                )
                if n_retry >= max_retry:
                    raise

        # Reset pbar description
        if pbar is not None:
            pbar.set_description(old_desc)

        # Update best point in progress bar
        if postfix is not None:
            postfix["best"] = max(postfix["best"], float(max(prop_opt)))
            pbar.set_postfix(postfix)

        # Append to new GP data
        latent_points = np.concatenate([latent_points, z_opt], axis=0)
        targets = np.concatenate([targets, prop_opt], axis=0)
        chosen_smiles.append(smiles_opt)
        _save_gp_data(latent_points, targets, chosen_smiles, gp_data_file)

        # Append to overall list
        all_new_smiles += smiles_opt
        all_new_props += prop_opt
        if error_aware_acquisition:
            pass

    # Update datamodule with ALL data points
    return all_new_smiles, all_new_props
def main_aux(args, result_dir: str):
    """ main: Bayesian-optimization loop with periodic JTVAE retraining.

    Loads (or initialises) a JTVAE, optionally resumes from previous
    results, then alternates retraining the VAE with GP-based querying
    (or plain latent sampling) until the query budget is spent.
    """
    # Seeding
    pl.seed_everything(args.seed)

    # Nothing to do if we only wanted a trained model and it already exists.
    if args.train_only and os.path.exists(
            args.save_model_path) and not args.overwrite:
        print_flush(f'--- JTVAE already trained in {args.save_model_path} ---')
        return

    # Make results directory
    data_dir = os.path.join(result_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    setup_logger(os.path.join(result_dir, "log.txt"))

    # Load data
    datamodule = WeightedJTNNDataset(args, utils.DataWeighter(args))
    datamodule.setup("fit", n_init_points=args.n_init_bo_points)

    # print python command run
    cmd = ' '.join(sys.argv[1:])
    print_flush(f"{cmd}\n")

    # Load model
    if args.use_pretrained:
        if args.predict_target:
            if 'pred_y' in args.pretrained_model_file:
                # fully supervised training from a model already trained with target prediction:
                # patch the stored hparams inside the checkpoint before loading.
                ckpt = torch.load(args.pretrained_model_file)
                ckpt['hyper_parameters'][
                    'hparams'].beta_target_pred_loss = args.beta_target_pred_loss
                ckpt['hyper_parameters']['hparams'].predict_target = True
                ckpt['hyper_parameters'][
                    'hparams'].target_predictor_hdims = args.target_predictor_hdims
                torch.save(ckpt, args.pretrained_model_file)
        print(os.path.abspath(args.pretrained_model_file))
        vae: JTVAE = JTVAE.load_from_checkpoint(args.pretrained_model_file,
                                                vocab=datamodule.vocab)
        vae.beta = vae.hparams.beta_final  # Override any beta annealing
        # Propagate current CLI settings onto both the live attributes and
        # the stored hparams so later checkpoint saves stay consistent.
        vae.metric_loss = args.metric_loss
        vae.hparams.metric_loss = args.metric_loss
        vae.beta_metric_loss = args.beta_metric_loss
        vae.hparams.beta_metric_loss = args.beta_metric_loss
        vae.metric_loss_kw = args.metric_loss_kw
        vae.hparams.metric_loss_kw = args.metric_loss_kw
        vae.predict_target = args.predict_target
        vae.hparams.predict_target = args.predict_target
        vae.beta_target_pred_loss = args.beta_target_pred_loss
        vae.hparams.beta_target_pred_loss = args.beta_target_pred_loss
        vae.target_predictor_hdims = args.target_predictor_hdims
        vae.hparams.target_predictor_hdims = args.target_predictor_hdims
        if vae.predict_target and vae.target_predictor is None:
            # Target prediction requested but checkpoint lacks a predictor
            # head: build one now.
            vae.hparams.target_predictor_hdims = args.target_predictor_hdims
            vae.hparams.predict_target = args.predict_target
            vae.build_target_predictor()
    else:
        print("initialising VAE from scratch !")
        vae: JTVAE = JTVAE(hparams=args, vocab=datamodule.vocab)
    vae.eval()

    # Set up some stuff for the progress bar
    num_retrain = int(np.ceil(args.query_budget / args.retraining_frequency))
    postfix = dict(retrain_left=num_retrain,
                   best=float(datamodule.train_dataset.data_properties.max()),
                   n_train=len(datamodule.train_dataset.data),
                   save_path=result_dir)
    start_num_retrain = 0

    # Set up results tracking
    results = dict(
        opt_points=[],
        opt_point_properties=[],
        opt_model_version=[],
        params=str(sys.argv),
        sample_points=[],
        sample_versions=[],
        sample_properties=[],
    )

    result_filepath = os.path.join(result_dir, 'results.npz')
    if not args.overwrite and os.path.exists(result_filepath):
        # Resume mode: reload previous results and the latest retrained
        # checkpoint, then continue from where the run stopped.
        with np.load(result_filepath, allow_pickle=True) as npz:
            results = {}
            for k in list(npz.keys()):
                results[k] = npz[k]
                if k != 'params':
                    results[k] = list(results[k])
                else:
                    results[k] = npz[k].item()
        start_num_retrain = results['opt_model_version'][-1] + 1

        prev_retrain_model = args.retraining_frequency * (start_num_retrain - 1)
        num_sampled_points = len(results['opt_points'])

        if args.n_init_retrain_epochs == 0 and prev_retrain_model == 0:
            # No retraining has happened yet: resume from the pretrained model.
            pretrained_model_path = args.pretrained_model_file
        else:
            pretrained_model_path = os.path.join(result_dir, 'retraining',
                                                 f'retrain_{prev_retrain_model}',
                                                 'checkpoints', 'last.ckpt')
        print(f"Found checkpoint at {pretrained_model_path}")
        # Patch the checkpoint's stored hparams with the current CLI settings.
        ckpt = torch.load(pretrained_model_path)
        ckpt['hyper_parameters']['hparams'].metric_loss = args.metric_loss
        ckpt['hyper_parameters'][
            'hparams'].metric_loss_kw = args.metric_loss_kw
        ckpt['hyper_parameters'][
            'hparams'].beta_metric_loss = args.beta_metric_loss
        ckpt['hyper_parameters'][
            'hparams'].beta_target_pred_loss = args.beta_target_pred_loss
        if args.predict_target:
            ckpt['hyper_parameters']['hparams'].predict_target = True
            ckpt['hyper_parameters'][
                'hparams'].target_predictor_hdims = args.target_predictor_hdims
        torch.save(ckpt, pretrained_model_path)
        print(f"Loading model from {pretrained_model_path}")
        vae.load_from_checkpoint(pretrained_model_path,
                                 vocab=datamodule.vocab)
        if args.predict_target and not hasattr(vae.hparams, 'predict_target'):
            vae.hparams.target_predictor_hdims = args.target_predictor_hdims
            vae.hparams.predict_target = args.predict_target
        # vae.hparams.cuda = args.cuda
        vae.beta = vae.hparams.beta_final  # Override any beta annealing
        vae.eval()

        # Set up some stuff for the progress bar
        num_retrain = int(
            np.ceil(args.query_budget /
                    args.retraining_frequency)) - start_num_retrain

        print(f"Append existing points and properties to datamodule...")
        datamodule.append_train_data(np.array(results['opt_points']),
                                     np.array(results['opt_point_properties']))
        postfix = dict(retrain_left=num_retrain,
                       best=float(
                           datamodule.train_dataset.data_properties.max()),
                       n_train=len(datamodule.train_dataset.data),
                       initial=num_sampled_points,
                       save_path=result_dir)
        print(
            f"Retrain from {result_dir} | Best: {max(results['opt_point_properties'])}"
        )

    start_time = time.time()

    # Main loop
    with tqdm(total=args.query_budget,
              dynamic_ncols=True,
              smoothing=0.0,
              file=sys.stdout) as pbar:

        for ret_idx in range(start_num_retrain,
                             start_num_retrain + num_retrain):

            if vae.predict_target and vae.metric_loss is not None:
                # Hand the datamodule's target normalisation constants to
                # the model.
                vae.training_m = datamodule.training_m
                vae.training_M = datamodule.training_M
                vae.validation_m = datamodule.validation_m
                vae.validation_M = datamodule.validation_M

            torch.cuda.empty_cache()  # Free the memory up for tensorflow
            pbar.set_postfix(postfix)
            pbar.set_description("retraining")
            print(result_dir)

            # Decide whether to retrain
            samples_so_far = args.retraining_frequency * ret_idx

            # Optionally do retraining
            num_epochs = args.n_retrain_epochs
            if ret_idx == 0 and args.n_init_retrain_epochs is not None:
                num_epochs = args.n_init_retrain_epochs
            if num_epochs > 0:
                retrain_dir = os.path.join(result_dir, "retraining")
                version = f"retrain_{samples_so_far}"
                retrain_model(model=vae,
                              datamodule=datamodule,
                              save_dir=retrain_dir,
                              version_str=version,
                              num_epochs=num_epochs,
                              gpu=args.gpu,
                              store_best=args.train_only,
                              best_ckpt_path=args.save_model_path)
                vae.eval()
                if args.train_only:
                    # train-only mode: stop right after the first retraining.
                    return
            del num_epochs

            model = vae

            # Update progress bar
            postfix["retrain_left"] -= 1
            pbar.set_postfix(postfix)

            # Draw samples for logs!
            if args.samples_per_model > 0:
                pbar.set_description("sampling")
                with trange(args.samples_per_model,
                            desc="sampling",
                            leave=False) as sample_pbar:
                    sample_x, sample_y = latent_sampling(
                        args, model, datamodule, args.samples_per_model,
                        pbar=sample_pbar)

                # Append to results dict
                results["sample_points"].append(sample_x)
                results["sample_properties"].append(sample_y)
                results["sample_versions"].append(ret_idx)

            # Do querying!
            pbar.set_description("querying")
            num_queries_to_do = min(args.retraining_frequency,
                                    args.query_budget - samples_so_far)
            if args.lso_strategy == "opt":
                gp_dir = os.path.join(result_dir, "gp",
                                      f"iter{samples_so_far}")
                os.makedirs(gp_dir, exist_ok=True)
                gp_data_file = os.path.join(gp_dir, "data.npz")
                gp_err_data_file = os.path.join(gp_dir, "data_err.npz")
                x_new, y_new = latent_optimization(
                    model=model,
                    datamodule=datamodule,
                    n_inducing_points=args.n_inducing_points,
                    n_best_points=args.n_best_points,
                    n_rand_points=args.n_rand_points,
                    num_queries_to_do=num_queries_to_do,
                    gp_data_file=gp_data_file,
                    gp_err_data_file=gp_err_data_file,
                    gp_run_folder=gp_dir,
                    gpu=args.gpu,
                    invalid_score=args.invalid_score,
                    pbar=pbar,
                    postfix=postfix,
                    error_aware_acquisition=args.error_aware_acquisition,
                )
            elif args.lso_strategy == "sample":
                x_new, y_new = latent_sampling(
                    args, model, datamodule, num_queries_to_do,
                    pbar=pbar,
                )
            else:
                raise NotImplementedError(args.lso_strategy)

            # Update dataset
            datamodule.append_train_data(x_new, y_new)

            # Add new results
            results["opt_points"] += list(x_new)
            results["opt_point_properties"] += list(y_new)
            results["opt_model_version"] += [ret_idx] * len(x_new)

            postfix["best"] = max(postfix["best"], float(max(y_new)))
            postfix["n_train"] = len(datamodule.train_dataset.data)
            pbar.set_postfix(postfix)

            # Save results
            np.savez_compressed(os.path.join(result_dir, "results.npz"),
                                **results)

            # Keep a record of the dataset here
            new_data_file = os.path.join(
                data_dir,
                f"train_data_iter{samples_so_far + num_queries_to_do}.txt")
            with open(new_data_file, "w") as f:
                f.write("\n".join(datamodule.train_dataset.canonic_smiles))

    print_flush("=== DONE ({:.3f}s) ===".format(time.time() - start_time))
def main():
    """Train an EquationVaeTorch model on the weighted expression dataset.

    Builds the CLI parser from the model/data/weighting components, resolves
    dataset and output paths from the parsed hyperparameters, then either
    resumes from a checkpoint or trains from scratch. On completion, copies
    the best checkpoint to ``best.ckpt`` next to it.
    """
    # Create arg parser; each component contributes its own arguments.
    parser = argparse.ArgumentParser()
    parser = EquationVaeTorch.add_model_specific_args(parser)
    parser = WeightedExprDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root='')
    parser.add_argument("--ignore_percentile", type=int, default=50,
                        help="percentile of scores to ignore")
    parser.add_argument("--good_percentile", type=int, default=0,
                        help="percentile of good scores selected")
    parser.add_argument("--data_seed", type=int, required=True,
                        help="Seed that has been used to generate the dataset")

    # Parse arguments and derive dataset / result paths from them.
    hparams = parser.parse_args()
    hparams.dataset_path = get_filepath(
        hparams.ignore_percentile, hparams.dataset_path, hparams.data_seed,
        good_percentile=hparams.good_percentile)
    hparams.root_dir = get_path(
        k=hparams.rank_weight_k,
        ignore_percentile=hparams.ignore_percentile,
        good_percentile=hparams.good_percentile,
        n_max_epochs=hparams.max_epochs,
        predict_target=hparams.predict_target,
        beta_final=hparams.beta_final,
        beta_target_pred_loss=hparams.beta_target_pred_loss,
        beta_metric_loss=hparams.beta_metric_loss,
        latent_dim=hparams.latent_dim,
        hdims=hparams.target_predictor_hdims,
        metric_loss=hparams.metric_loss,
        metric_loss_kw=hparams.metric_loss_kw)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)
    pl.seed_everything(hparams.seed)

    # Create data
    datamodule = WeightedExprDataset(hparams, utils.DataWeighter(hparams),
                                     add_channel=False)

    device = hparams.cuda
    if device is not None:
        torch.cuda.set_device(device)
    # Grammar production rules, one per line; the count sizes the decoder.
    data_info = G.gram.split('\n')

    # Load model
    model = EquationVaeTorch(hparams, len(data_info), MAX_LEN)
    # model.decoder.apply(torch_weight_init)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        period=max(1, hparams.max_epochs // 20),
        monitor="loss/val",
        save_top_k=-1,
        save_last=True,
        mode='min')

    if hparams.load_from_checkpoint is not None:
        model = EquationVaeTorch.load_from_checkpoint(
            hparams.load_from_checkpoint, len(data_info), MAX_LEN)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(
            # FIX: test `is not None` (not truthiness) so CUDA device 0 is
            # not silently dropped to CPU; matches the else branch below.
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            resume_from_checkpoint=hparams.load_from_checkpoint)
        print(f'Load from checkpoint')
    else:
        # Main trainer
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            terminate_on_nan=True,
            progress_bar_refresh_rate=100)

    # Fit
    trainer.fit(model, datamodule=datamodule)
    print(
        f"Training finished; end of script: rename {checkpoint_callback.best_model_path}"
    )
    # Keep the best checkpoint under a stable, predictable name.
    shutil.copyfile(
        checkpoint_callback.best_model_path,
        os.path.join(os.path.dirname(checkpoint_callback.best_model_path),
                     'best.ckpt'))
if __name__ == "__main__": # Create arg parser parser = argparse.ArgumentParser() parser = JTVAE.add_model_specific_args(parser) parser = WeightedJTNNDataset.add_model_specific_args(parser) parser = utils.DataWeighter.add_weight_args(parser) utils.add_default_trainer_args(parser, default_root=None) # Parse arguments hparams = parser.parse_args() hparams.root_dir = os.path.join(get_storage_root(), hparams.root_dir) pl.seed_everything(hparams.seed) print_flush(' '.join(sys.argv[1:])) # Create data datamodule = WeightedJTNNDataset(hparams, utils.DataWeighter(hparams)) datamodule.setup("fit") # Load model model = JTVAE(hparams, datamodule.vocab) checkpoint_callback = pl.callbacks.ModelCheckpoint( period=1, monitor="loss/val", save_top_k=1, save_last=True, mode='min' ) if hparams.load_from_checkpoint is not None: # .load_from_checkpoint(hparams.load_from_checkpoint)
def main():
    """Entry point: train a ShapesVAE on the weighted numpy shapes dataset.

    Assembles the argument parser from the model, dataset, and data-weighting
    components, resolves the result directory, trains (optionally resuming
    from a checkpoint), and finally copies the best checkpoint to
    ``best.ckpt`` alongside it.
    """
    # Each component registers its own CLI arguments on the shared parser.
    arg_parser = argparse.ArgumentParser()
    arg_parser = ShapesVAE.add_model_specific_args(arg_parser)
    arg_parser = WeightedNumpyDataset.add_model_specific_args(arg_parser)
    arg_parser = utils.DataWeighter.add_weight_args(arg_parser)
    utils.add_default_trainer_args(arg_parser, default_root="")

    hparams = arg_parser.parse_args()
    # Result directory is derived from the weighting / loss hyperparameters.
    hparams.root_dir = shape_get_path(
        k=hparams.rank_weight_k,
        predict_target=hparams.predict_target,
        hdims=hparams.target_predictor_hdims,
        metric_loss=hparams.metric_loss,
        metric_loss_kw=hparams.metric_loss_kw,
        latent_dim=hparams.latent_dim)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)
    pl.seed_everything(hparams.seed)

    # Data module and model.
    datamodule = WeightedNumpyDataset(hparams, utils.DataWeighter(hparams))
    model = ShapesVAE(hparams)

    # Checkpoint roughly 20 times over training, tracking validation loss.
    ckpt_cb = pl.callbacks.ModelCheckpoint(
        period=max(1, hparams.max_epochs // 20),
        monitor="loss/val",
        save_top_k=1,
        save_last=True,
        mode='min')

    # Trainer settings shared by both branches below.
    gpu_spec = [hparams.cuda] if hparams.cuda is not None else 0
    trainer_callbacks = [ckpt_cb, LearningRateMonitor(logging_interval='step')]

    if hparams.load_from_checkpoint is not None:
        # Resume: restore weights, merge hparams, and continue training.
        model = ShapesVAE.load_from_checkpoint(hparams.load_from_checkpoint)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(
            gpus=gpu_spec,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            callbacks=trainer_callbacks,
            resume_from_checkpoint=hparams.load_from_checkpoint)
        print(f'Load from checkpoint')
    else:
        # Fresh run.
        trainer = pl.Trainer(
            gpus=gpu_spec,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=trainer_callbacks,
            terminate_on_nan=True,
            progress_bar_refresh_rate=100)

    trainer.fit(model, datamodule=datamodule)

    print(f"Training finished; end of script: rename {ckpt_cb.best_model_path}")
    # Expose the best checkpoint under a stable name for downstream scripts.
    shutil.copyfile(
        ckpt_cb.best_model_path,
        os.path.join(os.path.dirname(ckpt_cb.best_model_path), 'best.ckpt'))