Example #1
def get_latent_encodings(use_test_set,
                         use_full_data_for_gp,
                         model,
                         data_file,
                         data_set: WeightedMolTreeFolder,
                         n_best,
                         n_rand,
                         true_vals: bool,
                         tkwargs: Dict[str, Any],
                         bs=64,
                         return_inds: bool = False):
    """ get latent encodings and split data into train and test data """

    print_flush(
        "\tComputing latent training data encodings and corresponding scores..."
    )

    if len(data_set) < n_best + n_rand:
        n_best, n_rand = int(n_best / (n_best + n_rand) * len(data_set)), int(
            n_rand / (n_best + n_rand) * len(data_set))
        n_rand += 1 if n_best + n_rand < len(data_set) else 0

    if use_full_data_for_gp:
        chosen_indices = np.arange(len(data_set))
    else:
        chosen_indices = _choose_best_rand_points(n_best, n_rand, data_set)
    mol_trees = [data_set.data[i] for i in chosen_indices]
    targets = data_set.data_properties[chosen_indices]

    # Next, encode these mol trees
    latent_points = _encode_mol_trees(model, mol_trees, batch_size=bs)

    targets = targets.reshape((-1, 1))

    # Note: the train_inds returned by subsample_dataset are indices within the points
    # passed in, not indices into the original dataset
    if not use_full_data_for_gp:
        assert not use_test_set
        X_mean, X_std = latent_points.mean(), latent_points.std()
        y_mean, y_std = targets.mean(), targets.std()
        save_data(latent_points, targets, None, None, X_mean, X_std, y_mean,
                  y_std, data_file)
        if return_inds:
            return latent_points, targets, None, None, X_mean, y_mean, X_std, y_std, chosen_indices, None
        else:
            return latent_points, targets, None, None, X_mean, y_mean, X_std, y_std

    return subsample_dataset(latent_points,
                             targets,
                             data_file,
                             use_test_set,
                             use_full_data_for_gp,
                             n_best,
                             n_rand,
                             return_inds=return_inds)
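# A minimal standalone sketch of the proportional rescaling branch used above, for when the
# dataset holds fewer points than n_best + n_rand (hypothetical helper, not part of the
# original function):
def _rescale_counts(n_best: int, n_rand: int, n_total: int):
    """Shrink (n_best, n_rand) proportionally so that they fit into n_total points."""
    if n_total >= n_best + n_rand:
        return n_best, n_rand
    frac_best = n_best / (n_best + n_rand)
    n_best_new = int(frac_best * n_total)
    n_rand_new = int((1 - frac_best) * n_total)
    # int() truncation can leave one point unused; hand it to the random pool
    if n_best_new + n_rand_new < n_total:
        n_rand_new += 1
    return n_best_new, n_rand_new

# _rescale_counts(2000, 8000, 100) -> (20, 80)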
Example #2
def get_latent_encodings(use_test_set, use_full_data_for_gp, model, data_file, data_scores, data_imgs,
                         n_best, n_rand, tkwargs: Dict[str, Any],
                         bs=5000, bs_true_eval: int = 256, repeat: int = 10, return_inds: bool = False):
    """ get latent encodings and split data into train and test data """

    print_flush("\tComputing latent training data encodings and corresponding scores...")
    n_batches = int(np.ceil(len(data_imgs) / bs))

    subsample = n_best > 0 and n_rand > 0 and (n_best + n_rand) < len(data_scores)
    if subsample:
        # do not encode the full dataset (too slow); only encode the points actually needed (n_best + n_rand)
        sorted_idx = np.argsort(-data_scores)
        best_idx = sorted_idx[:n_best]
        rand_idx = sorted_idx[np.random.choice(list(range(n_best, len(data_scores))), n_rand, replace=False)]  # draw from the non-best remainder
        n_best_scores = data_scores[best_idx]
        n_best_data = data_imgs[best_idx]
        n_rand_scores = data_scores[rand_idx]
        n_rand_data = data_imgs[rand_idx]
        # concatenate and then shuffle
        scores_best_cat_rand = np.concatenate([n_best_scores, n_rand_scores])
        data_best_cat_rand = np.concatenate([n_best_data, n_rand_data])
        cat_idx = np.arange(len(scores_best_cat_rand))
        cat_shuffled_idx = np.random.permutation(cat_idx)  # true shuffle (sampling with replacement would duplicate points)
        scores_best_cat_rand = scores_best_cat_rand[cat_shuffled_idx]
        data_best_cat_rand = data_best_cat_rand[cat_shuffled_idx]
        n_batches = int(np.ceil(len(data_best_cat_rand) / bs))
        Xs = [model.encode_to_params(
            torch.from_numpy(data_best_cat_rand[i * bs:(i + 1) * bs]).to(**tkwargs).unsqueeze(1)
        )[0].detach().cpu().numpy() for i in tqdm(range(n_batches))]
    else:
        Xs = [model.encode_to_params(
            torch.from_numpy(data_imgs[i * bs:(i + 1) * bs]).to(**tkwargs).unsqueeze(1)
        )[0].detach().cpu().numpy() for i in tqdm(range(n_batches))]
    X = np.concatenate(Xs, axis=0)

    y = scores_best_cat_rand if subsample else data_scores
    y = y.reshape((-1, 1))

    if subsample:
        assert not use_test_set
        assert not use_full_data_for_gp
        X_mean, X_std = X.mean(), X.std()
        y_mean, y_std = y.mean(), y.std()
        save_data(X, y, None, None, X_mean, X_std, y_mean, y_std, data_file)
        if return_inds:
            train_inds = np.concatenate([best_idx, rand_idx])[cat_shuffled_idx]
            return X, y, None, None, X_mean, y_mean, X_std, y_std, train_inds, None
        else:
            return X, y, None, None, X_mean, y_mean, X_std, y_std
    return subsample_dataset(X, y, data_file, use_test_set, use_full_data_for_gp, n_best, n_rand,
                             return_inds=return_inds)
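# A compact sketch of the selection scheme used above: keep the n_best top scorers, draw
# n_rand distinct points from the remainder, then shuffle the concatenation with a true
# permutation. Hypothetical helper, assuming higher scores are better.
import numpy as np

def _choose_best_and_random(scores: np.ndarray, n_best: int, n_rand: int, seed: int = 0):
    """Return shuffled dataset indices: the n_best top scorers plus n_rand random others."""
    rng = np.random.default_rng(seed)
    order = np.argsort(-scores)                     # descending by score
    best_idx = order[:n_best]
    rand_idx = rng.choice(order[n_best:], size=n_rand, replace=False)
    return rng.permutation(np.concatenate([best_idx, rand_idx]))

# idx = _choose_best_and_random(np.random.rand(1000), n_best=50, n_rand=200)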
Example #3
def get_latent_encodings(use_test_set: bool,
                         use_full_data_for_gp: bool,
                         model: ShapesVAE,
                         data_file: str,
                         data_imgs: np.ndarray,
                         data_scores: np.ndarray,
                         n_best: int,
                         n_rand: int,
                         tkwargs: Dict[str, Any],
                         bs=1000):
    """ get latent encodings and split data into train and test data """

    print_flush(
        "\tComputing latent training data encodings and corresponding scores..."
    )
    X = get_latent_encodings_aux(model=model,
                                 data_imgs=data_imgs,
                                 bs=bs,
                                 tkwargs=tkwargs)
    y = data_scores.reshape((-1, 1))

    return _subsample_dataset(X, y, data_file, use_test_set,
                              use_full_data_for_gp, n_best, n_rand)
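# get_latent_encodings_aux is referenced above but not shown. A plausible minimal version
# (an assumption, mirroring the batched encoding loop of Example #2, where encode_to_params
# returns the posterior mean first) could look like this:
import numpy as np
import torch

def get_latent_encodings_aux(model, data_imgs: np.ndarray, bs: int, tkwargs: dict) -> np.ndarray:
    """Encode images to latent means in batches of size bs (sketch)."""
    n_batches = int(np.ceil(len(data_imgs) / bs))
    chunks = []
    with torch.no_grad():
        for i in range(n_batches):
            batch = torch.from_numpy(data_imgs[i * bs:(i + 1) * bs]).to(**tkwargs)
            mu = model.encode_to_params(batch.unsqueeze(1))[0]  # add channel dim, keep the mean
            chunks.append(mu.cpu().numpy())
    return np.concatenate(chunks, axis=0)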
Example #4
def main_aux(args, result_dir: str):
    """ main """

    # Seeding
    pl.seed_everything(args.seed)

    # Make results directory
    data_dir = os.path.join(result_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    setup_logger(os.path.join(result_dir, "log.txt"))

    result_filepath = os.path.join(result_dir, 'results.pkl')
    if not args.overwrite and os.path.exists(result_filepath):
        print(f"Already exists: {result_dir}")
        return

    # Load data
    datamodule = WeightedJTNNDataset(args, utils.DataWeighter(args))
    datamodule.setup("fit", n_init_points=args.n_init_bo_points)

    # print python command run
    cmd = ' '.join(sys.argv[1:])
    print_flush(f"{cmd}\n")

    # Load model
    assert args.use_pretrained

    if args.predict_target:
        if 'pred_y' in args.pretrained_model_file:
            # fully supervised setup from a model trained with target prediction
            ckpt = torch.load(args.pretrained_model_file)
            ckpt['hyper_parameters']['hparams'].beta_target_pred_loss = args.beta_target_pred_loss
            ckpt['hyper_parameters']['hparams'].predict_target = True
            ckpt['hyper_parameters']['hparams'].target_predictor_hdims = args.target_predictor_hdims
            torch.save(ckpt, args.pretrained_model_file)
    vae: JTVAE = JTVAE.load_from_checkpoint(args.pretrained_model_file, vocab=datamodule.vocab)
    vae.beta = vae.hparams.beta_final  # Override any beta annealing
    vae.metric_loss = args.metric_loss
    vae.hparams.metric_loss = args.metric_loss
    vae.beta_metric_loss = args.beta_metric_loss
    vae.hparams.beta_metric_loss = args.beta_metric_loss
    vae.metric_loss_kw = args.metric_loss_kw
    vae.hparams.metric_loss_kw = args.metric_loss_kw
    vae.predict_target = args.predict_target
    vae.hparams.predict_target = args.predict_target
    vae.beta_target_pred_loss = args.beta_target_pred_loss
    vae.hparams.beta_target_pred_loss = args.beta_target_pred_loss
    vae.target_predictor_hdims = args.target_predictor_hdims
    vae.hparams.target_predictor_hdims = args.target_predictor_hdims
    if vae.predict_target and vae.target_predictor is None:
        vae.hparams.target_predictor_hdims = args.target_predictor_hdims
        vae.hparams.predict_target = args.predict_target
        vae.build_target_predictor()
    vae.eval()

    # Set up some stuff for the progress bar
    postfix = dict(
        n_train=len(datamodule.train_dataset.data),
        save_path=result_dir
    )

    # Set up results tracking
    start_time = time.time()

    train_chosen_indices = _choose_best_rand_points(n_rand_points=args.n_rand_points, n_best_points=args.n_best_points,
                                                    dataset=datamodule.train_dataset)
    train_mol_trees = [datamodule.train_dataset.data[i] for i in train_chosen_indices]
    train_targets = datamodule.train_dataset.data_properties[train_chosen_indices]
    train_chosen_smiles = [datamodule.train_dataset.canonic_smiles[i] for i in train_chosen_indices]

    test_chosen_indices = _choose_best_rand_points(n_rand_points=args.n_test_points, n_best_points=0,
                                                   dataset=datamodule.val_dataset)
    test_mol_trees = [datamodule.val_dataset.data[i] for i in test_chosen_indices]

    # Main loop
    with tqdm(
            total=1, dynamic_ncols=True, smoothing=0.0, file=sys.stdout
    ) as pbar:

        if vae.predict_target and vae.metric_loss is not None:
            vae.training_m = datamodule.training_m
            vae.training_M = datamodule.training_M
            vae.validation_m = datamodule.validation_m
            vae.validation_M = datamodule.validation_M

        torch.cuda.empty_cache()  # Free the memory up for tensorflow
        pbar.set_postfix(postfix)
        pbar.set_description("retraining")
        print(result_dir)

        # Optionally do retraining
        num_epochs = args.n_init_retrain_epochs
        if num_epochs > 0:
            retrain_dir = os.path.join(result_dir, "retraining")
            version = "retrain_0"
            retrain_model(
                model=vae, datamodule=datamodule, save_dir=retrain_dir,
                version_str=version, num_epochs=num_epochs, gpu=args.gpu
            )
            vae.eval()
        del num_epochs

        model = vae

        # Update progress bar
        pbar.set_postfix(postfix)

        # Do querying!
        gp_dir = os.path.join(result_dir, "gp")
        os.makedirs(gp_dir, exist_ok=True)
        gp_data_file = os.path.join(gp_dir, "data.npz")

        # Next, encode these mol trees
        if args.gpu:
            model = model.cuda()
        train_latent_points = _encode_mol_trees(model, train_mol_trees)
        test_latent_points = _encode_mol_trees(model, test_mol_trees)
        if args.use_decoded:
            print("Use targets from decoded latent test points")
            _, test_targets = _batch_decode_z_and_props(
                model,
                torch.as_tensor(test_latent_points, device=model.device),
                datamodule,
                invalid_score=args.invalid_score,
                pbar=pbar,
            )
            test_targets = np.array(test_targets)
        else:
            test_targets = datamodule.val_dataset.data_properties[test_chosen_indices]

        model = model.cpu()  # Make sure to free up GPU memory
        torch.cuda.empty_cache()  # Free the memory up for tensorflow

        # Save points to file
        def _save_gp_data(x, y, test_x, y_test, s, file, flip_sign=True):

            # Prevent overfitting to bad points
            y = np.maximum(y, args.invalid_score)
            if flip_sign:
                y = -y.reshape(-1, 1)  # Since it is a maximization problem
                y_test = -y_test.reshape(-1, 1)
            else:
                y = y.reshape(-1, 1)
                y_test = y_test.reshape(-1, 1)

            # Save the file
            np.savez_compressed(
                file,
                X_train=x.astype(np.float32),
                X_test=test_x.astype(np.float32),
                y_train=y.astype(np.float32),
                y_test=y_test.astype(np.float32),
                smiles=s,
            )

        _save_gp_data(train_latent_points, train_targets, test_latent_points, test_targets, train_chosen_smiles,
                      gp_data_file)
        current_n_inducing_points = min(train_latent_points.shape[0], args.n_inducing_points)

        new_gp_file = os.path.join(gp_dir, "new.npz")
        log_path = os.path.join(gp_dir, "gp_fit.log")

        iter_seed = int(np.random.randint(10000))
        gp_train_command = [
            "python",
            GP_TRAIN_FILE,
            f"--nZ={current_n_inducing_points}",
            f"--seed={iter_seed}",
            f"--data_file={str(gp_data_file)}",
            f"--save_file={str(new_gp_file)}",
            f"--logfile={str(log_path)}",
            "--use_test_set"
        ]
        gp_fit_desc = "GP initial fit"
        gp_train_command += [
            "--init",
            "--kmeans_init",
            f"--save_metrics_file={str(result_filepath)}"
        ]
        # Set pbar status for user
        if pbar is not None:
            old_desc = pbar.desc
            pbar.set_description(gp_fit_desc)

        _run_command(gp_train_command, f"GP train {0}")
        curr_gp_file = new_gp_file

    print_flush("=== DONE ({:.3f}s) ===".format(time.time() - start_time))
Example #5
def main():
    # Create arg parser
    parser = argparse.ArgumentParser()
    parser = TopologyVAE.add_model_specific_args(parser)
    parser = WeightedNumpyDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root="")

    parser.add_argument(
        "--augment_dataset",
        action='store_true',
        help="Use data augmentation or not",
    )
    parser.add_argument(
        "--use_binary_data",
        action='store_true',
        help="Binarize images in the dataset",
    )

    # Parse arguments
    hparams = parser.parse_args()
    hparams.root_dir = topology_get_path(
        k=hparams.rank_weight_k,
        n_max_epochs=hparams.max_epochs,
        predict_target=hparams.predict_target,
        hdims=hparams.target_predictor_hdims,
        metric_loss=hparams.metric_loss,
        metric_loss_kw=hparams.metric_loss_kw,
        beta_target_pred_loss=hparams.beta_target_pred_loss,
        beta_metric_loss=hparams.beta_metric_loss,
        latent_dim=hparams.latent_dim,
        beta_final=hparams.beta_final,
        use_binary_data=hparams.use_binary_data)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)
    pl.seed_everything(hparams.seed)

    # Create data
    if hparams.use_binary_data:
        if not os.path.exists(
                os.path.join(get_data_root(), 'topology_data/target_bin.npy')):
            gen_binary_dataset_from_all_files(get_data_root())
        hparams.dataset_path = os.path.join(ROOT_PROJECT,
                                            get_topology_binary_dataset_path())
    else:
        if not os.path.exists(
                os.path.join(get_data_root(), 'topology_data/target.npy')):
            gen_dataset_from_all_files(get_data_root())
        hparams.dataset_path = os.path.join(ROOT_PROJECT,
                                            get_topology_dataset_path())
    if hparams.augment_dataset:
        aug = transforms.Compose([
            # transforms.Normalize(mean=, std=),
            # transforms.RandomCrop(30, padding=10),
            transforms.RandomRotation(45),
            transforms.RandomRotation(90),
            transforms.RandomRotation(180),
            transforms.RandomVerticalFlip(0.5)
        ])
    else:
        aug = None
    datamodule = WeightedNumpyDataset(hparams,
                                      utils.DataWeighter(hparams),
                                      transform=aug)

    # Load model
    model = TopologyVAE(hparams)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        period=max(1, hparams.max_epochs // 10),
        monitor="loss/val",
        save_top_k=-1,
        save_last=True,
        mode='min')

    if hparams.load_from_checkpoint is not None:
        model = TopologyVAE.load_from_checkpoint(hparams.load_from_checkpoint)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,  # `is not None` so GPU index 0 is not dropped
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            resume_from_checkpoint=hparams.load_from_checkpoint)

        print('Loaded from checkpoint')
    else:
        # Main trainer
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            terminate_on_nan=True,
            progress_bar_refresh_rate=5,
            # gradient_clip_val=20.0,
        )

    # Fit
    trainer.fit(model, datamodule=datamodule)

    print(
        f"Training finished; end of script: rename {checkpoint_callback.best_model_path}"
    )

    shutil.copyfile(
        checkpoint_callback.best_model_path,
        os.path.join(os.path.dirname(checkpoint_callback.best_model_path),
                     'best.ckpt'))
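# The pl.Trainer calls above select GPUs with `[hparams.cuda] if hparams.cuda is not None else 0`.
# A tiny helper (illustrative only) making the pitfall explicit: a bare truthiness test would
# silently drop GPU 0, because `if 0:` is falsy.
def select_gpus(cuda_index):
    """Map an optional CUDA device index to the `gpus` argument of pl.Trainer."""
    return [cuda_index] if cuda_index is not None else 0

# select_gpus(0)    -> [0]  (train on GPU 0)
# select_gpus(None) -> 0    (train on CPU)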
Example #6
def latent_optimization(
    model: JTVAE,
    datamodule: WeightedJTNNDataset,
    n_inducing_points: int,
    n_best_points: int,
    n_rand_points: int,
    num_queries_to_do: int,
    invalid_score: float,
    gp_data_file: str,
    gp_run_folder: str,
    gpu: bool,
    error_aware_acquisition: bool,
    gp_err_data_file: Optional[str],
    pbar=None,
    postfix=None,
):
    ##################################################
    # Prepare GP
    ##################################################

    # First, choose GP points to train!
    dset = datamodule.train_dataset

    chosen_indices = _choose_best_rand_points(n_rand_points=n_rand_points,
                                              n_best_points=n_best_points,
                                              dataset=dset)
    mol_trees = [dset.data[i] for i in chosen_indices]
    targets = dset.data_properties[chosen_indices]
    chosen_smiles = [dset.canonic_smiles[i] for i in chosen_indices]

    # Next, encode these mol trees
    if gpu:
        model = model.cuda()
    latent_points = _encode_mol_trees(model, mol_trees)
    model = model.cpu()  # Make sure to free up GPU memory
    torch.cuda.empty_cache()  # Free the memory up for tensorflow

    # Save points to file
    def _save_gp_data(x, y, s, file, flip_sign=True):

        # Prevent overfitting to bad points
        y = np.maximum(y, invalid_score)
        if flip_sign:
            y = -y.reshape(-1, 1)  # Since it is a maximization problem
        else:
            y = y.reshape(-1, 1)

        # Save the file
        np.savez_compressed(
            file,
            X_train=x.astype(np.float32),
            X_test=[],
            y_train=y.astype(np.float32),
            y_test=[],
            smiles=s,
        )

    # If using error-aware acquisition, compute reconstruction error of selected points
    if error_aware_acquisition:
        assert gp_err_data_file is not None, "Please provide a data file for the error GP"
        if gpu:
            model = model.cuda()
        error_train, safe_idx = get_rec_x_error(
            model,
            tkwargs={'dtype': torch.float},
            data=[datamodule.train_dataset.data[i] for i in chosen_indices],
        )
        # exclude points for which we could not compute the reconstruction error from the objective GP dataset
        if len(safe_idx) < latent_points.shape[0]:
            failed = [
                i for i in range(latent_points.shape[0]) if i not in safe_idx
            ]
            print_flush(
                f"Could not compute the recon. err. of {len(failed)} points -> excluding them."
            )
            latent_points_err = latent_points[safe_idx]
            chosen_smiles_err = [chosen_smiles[i] for i in safe_idx]
        else:
            latent_points_err = latent_points
            chosen_smiles_err = chosen_smiles
        model = model.cpu()  # Make sure to free up GPU memory
        torch.cuda.empty_cache()  # Free the memory up for tensorflow
        _save_gp_data(latent_points,
                      error_train.cpu().numpy(), chosen_smiles,
                      gp_err_data_file)
    _save_gp_data(latent_points,
                  targets,
                  chosen_smiles,
                  gp_data_file,
                  flip_sign=False)

    ##################################################
    # Run iterative GP fitting/optimization
    ##################################################
    curr_gp_file = None
    curr_gp_err_file = None
    all_new_smiles = []
    all_new_props = []
    all_new_err = []

    for gp_iter in range(num_queries_to_do):
        gp_initial_train = gp_iter == 0
        current_n_inducing_points = min(latent_points.shape[0],
                                        n_inducing_points)
        if latent_points.shape[0] == n_inducing_points:
            gp_initial_train = True

        # Part 1: fit GP
        # ===============================
        new_gp_file = os.path.join(gp_run_folder,
                                   f"gp_train_res{gp_iter:04d}.npz")
        new_gp_err_file = os.path.join(
            gp_run_folder,
            f"gp_err_train_res0000.npz")  # no incremental fit of error-GP
        log_path = os.path.join(gp_run_folder, f"gp_train{gp_iter:04d}.log")
        err_log_path = os.path.join(gp_run_folder, f"gp_err_train0000.log")
        try:
            iter_seed = int(np.random.randint(10000))
            gp_train_command = [
                "python", GP_TRAIN_FILE, f"--nZ={current_n_inducing_points}",
                f"--seed={iter_seed}", f"--data_file={str(gp_data_file)}",
                f"--save_file={str(new_gp_file)}",
                f"--logfile={str(log_path)}", f"--normal_inputs",
                f"--standard_targets"
            ]
            gp_err_train_command = [
                "python",
                GP_TRAIN_FILE,
                f"--nZ={n_inducing_points}",
                f"--seed={iter_seed}",
                f"--data_file={str(gp_err_data_file)}",
                f"--save_file={str(new_gp_err_file)}",
                f"--logfile={str(err_log_path)}",
            ]
            if gp_initial_train:

                # Add commands for initial fitting
                gp_fit_desc = "GP initial fit"
                gp_train_command += [
                    "--init",
                    "--kmeans_init",
                ]
                gp_err_train_command += [
                    "--init",
                    "--kmeans_init",
                ]
            else:
                gp_fit_desc = "GP incremental fit"
                gp_train_command += [
                    f"--gp_file={str(curr_gp_file)}",
                    f"--n_perf_measure=1",  # specifically see how well it fits the last point!
                ]
                gp_err_train_command += [
                    f"--gp_file={str(curr_gp_err_file)}",
                    f"--n_perf_measure=1",  # specifically see how well it fits the last point!
                ]

            # Set pbar status for user
            if pbar is not None:
                old_desc = pbar.desc
                pbar.set_description(gp_fit_desc)

            # Run command
            print_flush("Training objective GP...")
            _run_command(gp_train_command, f"GP train {gp_iter}")
            curr_gp_file = new_gp_file
            if error_aware_acquisition:
                if gp_initial_train:  # currently we do not incrementally refit this GP as we do not estimate rec. err.
                    _run_command(gp_err_train_command,
                                 f"GP err train {gp_iter}")
                    curr_gp_err_file = new_gp_err_file
        except AssertionError as e:
            logs = traceback.format_exc()
            print(logs)
            print_flush(
                f'Got an error in GP training. Retrying with different seed or crash...'
            )
            iter_seed = int(np.random.randint(10000))
            gp_train_command = [
                "python",
                GP_TRAIN_FILE,
                f"--nZ={current_n_inducing_points}",
                f"--seed={iter_seed}",
                f"--data_file={str(gp_data_file)}",
                f"--save_file={str(new_gp_file)}",
                f"--logfile={str(log_path)}",
            ]
            gp_err_train_command = [
                "python", GP_TRAIN_FILE, f"--nZ={n_inducing_points}",
                f"--seed={iter_seed}", f"--data_file={str(gp_err_data_file)}",
                f"--save_file={str(new_gp_err_file)}",
                f"--logfile={str(err_log_path)}", f"--normal_inputs",
                f"--standard_targets"
            ]
            if gp_initial_train:

                # Add commands for initial fitting
                gp_fit_desc = "GP initial fit"
                gp_train_command += [
                    "--init",
                    "--kmeans_init",
                ]
                gp_err_train_command += [
                    "--init",
                    "--kmeans_init",
                ]
            else:
                gp_fit_desc = "GP incremental fit"
                gp_train_command += [
                    f"--gp_file={str(curr_gp_file)}",
                    f"--n_perf_measure=1",  # specifically see how well it fits the last point!
                ]
                gp_err_train_command += [
                    f"--gp_file={str(curr_gp_err_file)}",
                    f"--n_perf_measure=1",  # specifically see how well it fits the last point!
                ]

            # Set pbar status for user
            if pbar is not None:
                old_desc = pbar.desc
                pbar.set_description(gp_fit_desc)

            # Run command
            _run_command(gp_train_command, f"GP train {gp_iter}")
            curr_gp_file = new_gp_file
            if error_aware_acquisition:
                if gp_initial_train:  # currently we do not incrementally refit this GP as we do not estimate rec. err.
                    _run_command(gp_err_train_command,
                                 f"GP err train {gp_iter}")
                    curr_gp_err_file = new_gp_err_file

        # Part 2: optimize GP acquisition func to query point
        # ===============================

        max_retry = 3
        n_retry = 0
        good = False
        while not good:
            try:
                # Run GP opt script
                opt_path = os.path.join(gp_run_folder,
                                        f"gp_opt_res{gp_iter:04d}.npy")
                log_path = os.path.join(gp_run_folder,
                                        f"gp_opt_{gp_iter:04d}.log")
                gp_opt_command = [
                    "python",
                    GP_OPT_FILE,
                    f"--seed={iter_seed}",
                    f"--gp_file={str(curr_gp_file)}",
                    f"--data_file={str(gp_data_file)}",
                    f"--save_file={str(opt_path)}",
                    f"--n_out={1}",  # hard coded
                    f"--logfile={str(log_path)}",
                ]
                if error_aware_acquisition:
                    gp_opt_command += [
                        f"--gp_err_file={str(curr_gp_err_file)}",
                        f"--data_err_file={str(gp_err_data_file)}",
                    ]

                if pbar is not None:
                    pbar.set_description("optimizing acq func")
                print_flush("Start running gp_opt_command")
                _run_command(gp_opt_command, f"GP opt {gp_iter}")

                # Load point
                z_opt = np.load(opt_path)

                # Decode point
                smiles_opt, prop_opt = _batch_decode_z_and_props(
                    model,
                    torch.as_tensor(z_opt, device=model.device),
                    datamodule,
                    invalid_score=invalid_score,
                    pbar=pbar,
                )
                good = True
            except AssertionError:
                iter_seed = int(np.random.randint(10000))
                n_retry += 1
                print_flush(
                    f'Got an error in optimization......trial {n_retry} / {max_retry}'
                )
                if n_retry >= max_retry:
                    raise

        # Reset pbar description
        if pbar is not None:
            pbar.set_description(old_desc)

            # Update best point in progress bar
            if postfix is not None:
                postfix["best"] = max(postfix["best"], float(max(prop_opt)))
                pbar.set_postfix(postfix)

        # Append to new GP data
        latent_points = np.concatenate([latent_points, z_opt], axis=0)
        targets = np.concatenate([targets, prop_opt], axis=0)
        chosen_smiles.append(smiles_opt)
        _save_gp_data(latent_points, targets, chosen_smiles, gp_data_file)

        # Append to overall list
        all_new_smiles += smiles_opt
        all_new_props += prop_opt

        if error_aware_acquisition:
            pass

    # Update datamodule with ALL data points
    return all_new_smiles, all_new_props
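# _run_command is used above to launch the external GP train/opt scripts but is not shown in
# these examples. A minimal stand-in (an assumption) that raises AssertionError on failure,
# matching the `except AssertionError` handling in latent_optimization:
import subprocess

def _run_command(command, description: str = "") -> None:
    """Run an external command and fail loudly on a non-zero exit code (sketch)."""
    cmd = list(map(str, command))
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise AssertionError(
            f"{description or cmd[0]} failed with code {result.returncode}:\n{result.stderr}")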
Example #7
def main_aux(args, result_dir: str):
    """ main """

    # Seeding
    pl.seed_everything(args.seed)

    if args.train_only and os.path.exists(
            args.save_model_path) and not args.overwrite:
        print_flush(f'--- JTVAE already trained in {args.save_model_path} ---')
        return

    # Make results directory
    data_dir = os.path.join(result_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    setup_logger(os.path.join(result_dir, "log.txt"))

    # Load data
    datamodule = WeightedJTNNDataset(args, utils.DataWeighter(args))
    datamodule.setup("fit", n_init_points=args.n_init_bo_points)

    # print python command run
    cmd = ' '.join(sys.argv[1:])
    print_flush(f"{cmd}\n")

    # Load model
    if args.use_pretrained:
        if args.predict_target:
            if 'pred_y' in args.pretrained_model_file:
                # fully supervised training from a model already trained with target prediction
                ckpt = torch.load(args.pretrained_model_file)
                ckpt['hyper_parameters'][
                    'hparams'].beta_target_pred_loss = args.beta_target_pred_loss
                ckpt['hyper_parameters']['hparams'].predict_target = True
                ckpt['hyper_parameters'][
                    'hparams'].target_predictor_hdims = args.target_predictor_hdims
                torch.save(ckpt, args.pretrained_model_file)
        print(os.path.abspath(args.pretrained_model_file))
        vae: JTVAE = JTVAE.load_from_checkpoint(args.pretrained_model_file,
                                                vocab=datamodule.vocab)
        vae.beta = vae.hparams.beta_final  # Override any beta annealing
        vae.metric_loss = args.metric_loss
        vae.hparams.metric_loss = args.metric_loss
        vae.beta_metric_loss = args.beta_metric_loss
        vae.hparams.beta_metric_loss = args.beta_metric_loss
        vae.metric_loss_kw = args.metric_loss_kw
        vae.hparams.metric_loss_kw = args.metric_loss_kw
        vae.predict_target = args.predict_target
        vae.hparams.predict_target = args.predict_target
        vae.beta_target_pred_loss = args.beta_target_pred_loss
        vae.hparams.beta_target_pred_loss = args.beta_target_pred_loss
        vae.target_predictor_hdims = args.target_predictor_hdims
        vae.hparams.target_predictor_hdims = args.target_predictor_hdims
        if vae.predict_target and vae.target_predictor is None:
            vae.hparams.target_predictor_hdims = args.target_predictor_hdims
            vae.hparams.predict_target = args.predict_target
            vae.build_target_predictor()
    else:
        print("initialising VAE from scratch !")
        vae: JTVAE = JTVAE(hparams=args, vocab=datamodule.vocab)
    vae.eval()

    # Set up some stuff for the progress bar
    num_retrain = int(np.ceil(args.query_budget / args.retraining_frequency))
    postfix = dict(retrain_left=num_retrain,
                   best=float(datamodule.train_dataset.data_properties.max()),
                   n_train=len(datamodule.train_dataset.data),
                   save_path=result_dir)

    start_num_retrain = 0

    # Set up results tracking
    results = dict(
        opt_points=[],
        opt_point_properties=[],
        opt_model_version=[],
        params=str(sys.argv),
        sample_points=[],
        sample_versions=[],
        sample_properties=[],
    )

    result_filepath = os.path.join(result_dir, 'results.npz')
    if not args.overwrite and os.path.exists(result_filepath):
        with np.load(result_filepath, allow_pickle=True) as npz:
            results = {}
            for k in list(npz.keys()):
                results[k] = npz[k]
                if k != 'params':
                    results[k] = list(results[k])
                else:
                    results[k] = npz[k].item()
        start_num_retrain = results['opt_model_version'][-1] + 1

        prev_retrain_model = args.retraining_frequency * (start_num_retrain -
                                                          1)
        num_sampled_points = len(results['opt_points'])
        if args.n_init_retrain_epochs == 0 and prev_retrain_model == 0:
            pretrained_model_path = args.pretrained_model_file
        else:
            pretrained_model_path = os.path.join(
                result_dir, 'retraining', f'retrain_{prev_retrain_model}',
                'checkpoints', 'last.ckpt')
        print(f"Found checkpoint at {pretrained_model_path}")
        ckpt = torch.load(pretrained_model_path)
        ckpt['hyper_parameters']['hparams'].metric_loss = args.metric_loss
        ckpt['hyper_parameters'][
            'hparams'].metric_loss_kw = args.metric_loss_kw
        ckpt['hyper_parameters'][
            'hparams'].beta_metric_loss = args.beta_metric_loss
        ckpt['hyper_parameters'][
            'hparams'].beta_target_pred_loss = args.beta_target_pred_loss
        if args.predict_target:
            ckpt['hyper_parameters']['hparams'].predict_target = True
            ckpt['hyper_parameters'][
                'hparams'].target_predictor_hdims = args.target_predictor_hdims
        torch.save(ckpt, pretrained_model_path)
        print(f"Loading model from {pretrained_model_path}")
        # load_from_checkpoint is a classmethod: rebind the returned model instead of discarding it
        vae = JTVAE.load_from_checkpoint(pretrained_model_path, vocab=datamodule.vocab)
        if args.predict_target and not hasattr(vae.hparams, 'predict_target'):
            vae.hparams.target_predictor_hdims = args.target_predictor_hdims
            vae.hparams.predict_target = args.predict_target
        # vae.hparams.cuda = args.cuda
        vae.beta = vae.hparams.beta_final  # Override any beta annealing
        vae.eval()

        # Set up some stuff for the progress bar
        num_retrain = int(
            np.ceil(args.query_budget /
                    args.retraining_frequency)) - start_num_retrain

        print(f"Append existing points and properties to datamodule...")
        datamodule.append_train_data(np.array(results['opt_points']),
                                     np.array(results['opt_point_properties']))
        postfix = dict(retrain_left=num_retrain,
                       best=float(
                           datamodule.train_dataset.data_properties.max()),
                       n_train=len(datamodule.train_dataset.data),
                       initial=num_sampled_points,
                       save_path=result_dir)
        print(
            f"Retrain from {result_dir} | Best: {max(results['opt_point_properties'])}"
        )
    start_time = time.time()

    # Main loop
    with tqdm(total=args.query_budget,
              dynamic_ncols=True,
              smoothing=0.0,
              file=sys.stdout) as pbar:

        for ret_idx in range(start_num_retrain,
                             start_num_retrain + num_retrain):

            if vae.predict_target and vae.metric_loss is not None:
                vae.training_m = datamodule.training_m
                vae.training_M = datamodule.training_M
                vae.validation_m = datamodule.validation_m
                vae.validation_M = datamodule.validation_M

            torch.cuda.empty_cache()  # Free the memory up for tensorflow
            pbar.set_postfix(postfix)
            pbar.set_description("retraining")
            print(result_dir)
            # Decide whether to retrain
            samples_so_far = args.retraining_frequency * ret_idx

            # Optionally do retraining
            num_epochs = args.n_retrain_epochs
            if ret_idx == 0 and args.n_init_retrain_epochs is not None:
                num_epochs = args.n_init_retrain_epochs
            if num_epochs > 0:
                retrain_dir = os.path.join(result_dir, "retraining")
                version = f"retrain_{samples_so_far}"
                retrain_model(model=vae,
                              datamodule=datamodule,
                              save_dir=retrain_dir,
                              version_str=version,
                              num_epochs=num_epochs,
                              gpu=args.gpu,
                              store_best=args.train_only,
                              best_ckpt_path=args.save_model_path)
                vae.eval()
                if args.train_only:
                    return
            del num_epochs

            model = vae

            # Update progress bar
            postfix["retrain_left"] -= 1
            pbar.set_postfix(postfix)

            # Draw samples for logs!
            if args.samples_per_model > 0:
                pbar.set_description("sampling")
                with trange(args.samples_per_model,
                            desc="sampling",
                            leave=False) as sample_pbar:
                    sample_x, sample_y = latent_sampling(
                        args,
                        model,
                        datamodule,
                        args.samples_per_model,
                        pbar=sample_pbar)

                # Append to results dict
                results["sample_points"].append(sample_x)
                results["sample_properties"].append(sample_y)
                results["sample_versions"].append(ret_idx)

            # Do querying!
            pbar.set_description("querying")
            num_queries_to_do = min(args.retraining_frequency,
                                    args.query_budget - samples_so_far)
            if args.lso_strategy == "opt":
                gp_dir = os.path.join(result_dir, "gp",
                                      f"iter{samples_so_far}")
                os.makedirs(gp_dir, exist_ok=True)
                gp_data_file = os.path.join(gp_dir, "data.npz")
                gp_err_data_file = os.path.join(gp_dir, "data_err.npz")
                x_new, y_new = latent_optimization(
                    model=model,
                    datamodule=datamodule,
                    n_inducing_points=args.n_inducing_points,
                    n_best_points=args.n_best_points,
                    n_rand_points=args.n_rand_points,
                    num_queries_to_do=num_queries_to_do,
                    gp_data_file=gp_data_file,
                    gp_err_data_file=gp_err_data_file,
                    gp_run_folder=gp_dir,
                    gpu=args.gpu,
                    invalid_score=args.invalid_score,
                    pbar=pbar,
                    postfix=postfix,
                    error_aware_acquisition=args.error_aware_acquisition,
                )
            elif args.lso_strategy == "sample":
                x_new, y_new = latent_sampling(
                    args,
                    model,
                    datamodule,
                    num_queries_to_do,
                    pbar=pbar,
                )
            else:
                raise NotImplementedError(args.lso_strategy)

            # Update dataset
            datamodule.append_train_data(x_new, y_new)

            # Add new results
            results["opt_points"] += list(x_new)
            results["opt_point_properties"] += list(y_new)
            results["opt_model_version"] += [ret_idx] * len(x_new)

            postfix["best"] = max(postfix["best"], float(max(y_new)))
            postfix["n_train"] = len(datamodule.train_dataset.data)
            pbar.set_postfix(postfix)

            # Save results
            np.savez_compressed(os.path.join(result_dir, "results.npz"),
                                **results)

            # Keep a record of the dataset here
            new_data_file = os.path.join(
                data_dir,
                f"train_data_iter{samples_so_far + num_queries_to_do}.txt")
            with open(new_data_file, "w") as f:
                f.write("\n".join(datamodule.train_dataset.canonic_smiles))

    print_flush("=== DONE ({:.3f}s) ===".format(time.time() - start_time))
Example #8
def main():
    # Create arg parser
    parser = argparse.ArgumentParser()
    parser = EquationVaeTorch.add_model_specific_args(parser)
    parser = WeightedExprDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root='')

    parser.add_argument("--ignore_percentile",
                        type=int,
                        default=50,
                        help="percentile of scores to ignore")
    parser.add_argument("--good_percentile",
                        type=int,
                        default=0,
                        help="percentile of good scores selected")
    parser.add_argument("--data_seed",
                        type=int,
                        required=True,
                        help="Seed that has been used to generate the dataset")

    # Parse arguments
    hparams = parser.parse_args()
    hparams.dataset_path = get_filepath(
        hparams.ignore_percentile,
        hparams.dataset_path,
        hparams.data_seed,
        good_percentile=hparams.good_percentile)
    hparams.root_dir = get_path(
        k=hparams.rank_weight_k,
        ignore_percentile=hparams.ignore_percentile,
        good_percentile=hparams.good_percentile,
        n_max_epochs=hparams.max_epochs,
        predict_target=hparams.predict_target,
        beta_final=hparams.beta_final,
        beta_target_pred_loss=hparams.beta_target_pred_loss,
        beta_metric_loss=hparams.beta_metric_loss,
        latent_dim=hparams.latent_dim,
        hdims=hparams.target_predictor_hdims,
        metric_loss=hparams.metric_loss,
        metric_loss_kw=hparams.metric_loss_kw)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)

    pl.seed_everything(hparams.seed)

    # Create data
    datamodule = WeightedExprDataset(hparams,
                                     utils.DataWeighter(hparams),
                                     add_channel=False)

    device = hparams.cuda
    if device is not None:
        torch.cuda.set_device(device)

    data_info = G.gram.split('\n')

    # Load model
    model = EquationVaeTorch(hparams, len(data_info), MAX_LEN)
    # model.decoder.apply(torch_weight_init)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        period=max(1, hparams.max_epochs // 20),
        monitor="loss/val",
        save_top_k=-1,
        save_last=True,
        mode='min')

    if hparams.load_from_checkpoint is not None:
        # .load_from_checkpoint(hparams.load_from_checkpoint)
        model = EquationVaeTorch.load_from_checkpoint(
            hparams.load_from_checkpoint, len(data_info), MAX_LEN)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,  # `is not None` so GPU index 0 is not dropped
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            resume_from_checkpoint=hparams.load_from_checkpoint)

        print('Loaded from checkpoint')
    else:
        # Main trainer
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            terminate_on_nan=True,
            progress_bar_refresh_rate=100)

    # Fit
    trainer.fit(model, datamodule=datamodule)

    print(
        f"Training finished; end of script: rename {checkpoint_callback.best_model_path}"
    )

    shutil.copyfile(
        checkpoint_callback.best_model_path,
        os.path.join(os.path.dirname(checkpoint_callback.best_model_path),
                     'best.ckpt'))
Example #9
if __name__ == "__main__":

    # Create arg parser
    parser = argparse.ArgumentParser()
    parser = JTVAE.add_model_specific_args(parser)
    parser = WeightedJTNNDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root=None)

    # Parse arguments
    hparams = parser.parse_args()

    hparams.root_dir = os.path.join(get_storage_root(), hparams.root_dir)

    pl.seed_everything(hparams.seed)
    print_flush(' '.join(sys.argv[1:]))

    # Create data
    datamodule = WeightedJTNNDataset(hparams, utils.DataWeighter(hparams))
    datamodule.setup("fit")

    # Load model
    model = JTVAE(hparams, datamodule.vocab)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        period=1, monitor="loss/val", save_top_k=1,
        save_last=True, mode='min'
    )

    if hparams.load_from_checkpoint is not None:
        # The original snippet is truncated here; the remainder below mirrors the
        # checkpoint-loading / trainer pattern of the other training scripts above.
        model = JTVAE.load_from_checkpoint(hparams.load_from_checkpoint,
                                           vocab=datamodule.vocab)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            callbacks=[checkpoint_callback,
                       LearningRateMonitor(logging_interval='step')],
            resume_from_checkpoint=hparams.load_from_checkpoint)
        print('Loaded from checkpoint')
    else:
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[checkpoint_callback,
                       LearningRateMonitor(logging_interval='step')],
            terminate_on_nan=True,
            progress_bar_refresh_rate=100)

    # Fit
    trainer.fit(model, datamodule=datamodule)
Example #10
def main():
    # Create arg parser
    parser = argparse.ArgumentParser()
    parser = ShapesVAE.add_model_specific_args(parser)
    parser = WeightedNumpyDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root="")

    # Parse arguments
    hparams = parser.parse_args()

    hparams.root_dir = shape_get_path(k=hparams.rank_weight_k, predict_target=hparams.predict_target,
                                      hdims=hparams.target_predictor_hdims, metric_loss=hparams.metric_loss,
                                      metric_loss_kw=hparams.metric_loss_kw, latent_dim=hparams.latent_dim)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)

    pl.seed_everything(hparams.seed)

    # Create data
    datamodule = WeightedNumpyDataset(hparams, utils.DataWeighter(hparams))

    # Load model
    model = ShapesVAE(hparams)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        period=max(1, hparams.max_epochs // 20),
        monitor="loss/val", save_top_k=1,
        save_last=True, mode='min'
    )

    if hparams.load_from_checkpoint is not None:
        model = ShapesVAE.load_from_checkpoint(hparams.load_from_checkpoint)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(gpus=[hparams.cuda] if hparams.cuda is not None else 0,
                             default_root_dir=hparams.root_dir,
                             max_epochs=hparams.max_epochs,
                             callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
                             resume_from_checkpoint=hparams.load_from_checkpoint)

        print('Loaded from checkpoint')
    else:
        # Main trainer
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
            terminate_on_nan=True,
            progress_bar_refresh_rate=100
        )

    # Fit
    trainer.fit(model, datamodule=datamodule)

    print(f"Training finished; end of script: rename {checkpoint_callback.best_model_path}")

    shutil.copyfile(checkpoint_callback.best_model_path, os.path.join(
        os.path.dirname(checkpoint_callback.best_model_path), 'best.ckpt'
    ))
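# The trailing shutil.copyfile block recurs in the training scripts above. One way to factor
# it out (hypothetical helper) while guarding against an empty best_model_path:
import os
import shutil

def copy_best_checkpoint(checkpoint_callback) -> None:
    """Copy the best checkpoint next to itself as 'best.ckpt' (sketch)."""
    best = checkpoint_callback.best_model_path
    if not best or not os.path.exists(best):
        print("No best checkpoint recorded; skipping copy.")
        return
    shutil.copyfile(best, os.path.join(os.path.dirname(best), "best.ckpt"))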