def instantiate_expr_datamodule() -> WeightedExprDataset:
    """ Create and set up a WeightedExprDataset datamodule with fixed default hyperparameters """

    hparams = Namespace(
        ignore_percentile=65,
        dataset_path=os.path.join(ROOT_PROJECT, 'weighted_retraining/data/expr'),
        data_seed=0,
        good_percentile=5,
        weight_type='rank',
        rank_weight_k=1,
        weight_quantile=None,
        val_frac=0.1,
        property_key='scores',
        second_key='expr',
        batch_size=128,
        predict_target=False,
        metric_loss=None,
    )

    hparams.dataset_path = get_filepath(hparams.ignore_percentile, hparams.dataset_path, hparams.data_seed,
                                        good_percentile=hparams.good_percentile)

    datamodule = WeightedExprDataset(hparams, utils.DataWeighter(hparams), add_channel=False)
    datamodule.setup()
    return datamodule
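
# Usage sketch (assumption: WeightedExprDataset is a LightningDataModule that exposes
# the usual dataloaders once setup() has run, and the expression dataset for
# data_seed=0 has already been generated under dataset_path):
#
#     dm = instantiate_expr_datamodule()
#     train_loader = dm.train_dataloader()
#     batch = next(iter(train_loader))
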
def main_aux(args, result_dir: str):
    """ Optionally retrain a pretrained JTVAE, encode selected molecules, and run a single sparse-GP fit on their latent points """

    # Seeding
    pl.seed_everything(args.seed)

    # Make results directory
    data_dir = os.path.join(result_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    setup_logger(os.path.join(result_dir, "log.txt"))

    result_filepath = os.path.join(result_dir, 'results.pkl')
    if not args.overwrite and os.path.exists(result_filepath):
        print(f"Already exists: {result_dir}")
        return

    # Load data
    datamodule = WeightedJTNNDataset(args, utils.DataWeighter(args))
    datamodule.setup("fit", n_init_points=args.n_init_bo_points)

    # print python command run
    cmd = ' '.join(sys.argv[1:])
    print_flush(f"{cmd}\n")

    # Load model
    assert args.use_pretrained

    if args.predict_target:
        if 'pred_y' in args.pretrained_model_file:
            # fully supervised setup from a model trained with target prediction
            ckpt = torch.load(args.pretrained_model_file)
            ckpt['hyper_parameters']['hparams'].beta_target_pred_loss = args.beta_target_pred_loss
            ckpt['hyper_parameters']['hparams'].predict_target = True
            ckpt['hyper_parameters']['hparams'].target_predictor_hdims = args.target_predictor_hdims
            torch.save(ckpt, args.pretrained_model_file)
    vae: JTVAE = JTVAE.load_from_checkpoint(args.pretrained_model_file, vocab=datamodule.vocab)
    vae.beta = vae.hparams.beta_final  # Override any beta annealing
    vae.metric_loss = args.metric_loss
    vae.hparams.metric_loss = args.metric_loss
    vae.beta_metric_loss = args.beta_metric_loss
    vae.hparams.beta_metric_loss = args.beta_metric_loss
    vae.metric_loss_kw = args.metric_loss_kw
    vae.hparams.metric_loss_kw = args.metric_loss_kw
    vae.predict_target = args.predict_target
    vae.hparams.predict_target = args.predict_target
    vae.beta_target_pred_loss = args.beta_target_pred_loss
    vae.hparams.beta_target_pred_loss = args.beta_target_pred_loss
    vae.target_predictor_hdims = args.target_predictor_hdims
    vae.hparams.target_predictor_hdims = args.target_predictor_hdims
    if vae.predict_target and vae.target_predictor is None:
        vae.hparams.target_predictor_hdims = args.target_predictor_hdims
        vae.hparams.predict_target = args.predict_target
        vae.build_target_predictor()
    vae.eval()

    # Set up some stuff for the progress bar
    postfix = dict(
        n_train=len(datamodule.train_dataset.data),
        save_path=result_dir
    )

    # Set up results tracking
    start_time = time.time()

    train_chosen_indices = _choose_best_rand_points(n_rand_points=args.n_rand_points, n_best_points=args.n_best_points,
                                                    dataset=datamodule.train_dataset)
    train_mol_trees = [datamodule.train_dataset.data[i] for i in train_chosen_indices]
    train_targets = datamodule.train_dataset.data_properties[train_chosen_indices]
    train_chosen_smiles = [datamodule.train_dataset.canonic_smiles[i] for i in train_chosen_indices]

    test_chosen_indices = _choose_best_rand_points(n_rand_points=args.n_test_points, n_best_points=0,
                                                   dataset=datamodule.val_dataset)
    test_mol_trees = [datamodule.val_dataset.data[i] for i in test_chosen_indices]
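
    # _choose_best_rand_points presumably mixes uniformly random indices with the
    # top-scoring ones; for the test split only random points are requested
    # (n_best_points=0), so the test targets stay representative of the data distribution.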

    # Main loop
    with tqdm(
            total=1, dynamic_ncols=True, smoothing=0.0, file=sys.stdout
    ) as pbar:

        if vae.predict_target and vae.metric_loss is not None:
            vae.training_m = datamodule.training_m
            vae.training_M = datamodule.training_M
            vae.validation_m = datamodule.validation_m
            vae.validation_M = datamodule.validation_M

        torch.cuda.empty_cache()  # Free the memory up for tensorflow
        pbar.set_postfix(postfix)
        pbar.set_description("retraining")
        print(result_dir)

        # Optionally do retraining
        num_epochs = args.n_init_retrain_epochs
        if num_epochs > 0:
            retrain_dir = os.path.join(result_dir, "retraining")
            version = "retrain_0"
            retrain_model(
                model=vae, datamodule=datamodule, save_dir=retrain_dir,
                version_str=version, num_epochs=num_epochs, gpu=args.gpu
            )
            vae.eval()
        del num_epochs

        model = vae

        # Update progress bar
        pbar.set_postfix(postfix)

        # Do querying!
        gp_dir = os.path.join(result_dir, "gp")
        os.makedirs(gp_dir, exist_ok=True)
        gp_data_file = os.path.join(gp_dir, "data.npz")

        # Next, encode these mol trees
        if args.gpu:
            model = model.cuda()
        train_latent_points = _encode_mol_trees(model, train_mol_trees)
        test_latent_points = _encode_mol_trees(model, test_mol_trees)
        if args.use_decoded:
            print("Use targets from decoded latent test points")
            _, test_targets = _batch_decode_z_and_props(
                model,
                torch.as_tensor(test_latent_points, device=model.device),
                datamodule,
                invalid_score=args.invalid_score,
                pbar=pbar,
            )
            test_targets = np.array(test_targets)
        else:
            test_targets = datamodule.val_dataset.data_properties[test_chosen_indices]

        model = model.cpu()  # Make sure to free up GPU memory
        torch.cuda.empty_cache()  # Free the memory up for tensorflow

        # Save points to file
        def _save_gp_data(x, y, test_x, y_test, s, file, flip_sign=True):

            # Prevent overfitting to bad points
            y = np.maximum(y, args.invalid_score)
            if flip_sign:
                y = -y.reshape(-1, 1)  # Since it is a maximization problem
                y_test = -y_test.reshape(-1, 1)
            else:
                y = y.reshape(-1, 1)
                y_test = y_test.reshape(-1, 1)

            # Save the file
            np.savez_compressed(
                file,
                X_train=x.astype(np.float32),
                X_test=test_x.astype(np.float32),
                y_train=y.astype(np.float32),
                y_test=y_test.astype(np.float32),
                smiles=s,
            )

        _save_gp_data(train_latent_points, train_targets, test_latent_points, test_targets, train_chosen_smiles,
                      gp_data_file)
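
        # For reference, the GP training script can recover the arrays saved above
        # with plain numpy (key names match the np.savez_compressed call in _save_gp_data):
        #
        #     data = np.load(gp_data_file)
        #     X_train, y_train = data["X_train"], data["y_train"]
        #     X_test, y_test, smiles = data["X_test"], data["y_test"], data["smiles"]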
        current_n_inducing_points = min(train_latent_points.shape[0], args.n_inducing_points)

        new_gp_file = os.path.join(gp_dir, "new.npz")
        log_path = os.path.join(gp_dir, "gp_fit.log")

        iter_seed = int(np.random.randint(10000))
        gp_train_command = [
            "python",
            GP_TRAIN_FILE,
            f"--nZ={current_n_inducing_points}",
            f"--seed={iter_seed}",
            f"--data_file={str(gp_data_file)}",
            f"--save_file={str(new_gp_file)}",
            f"--logfile={str(log_path)}",
            "--use_test_set"
        ]
        gp_fit_desc = "GP initial fit"
        gp_train_command += [
            "--init",
            "--kmeans_init",
            f"--save_metrics_file={str(result_filepath)}"
        ]
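        # Rendered, the subprocess invocation looks roughly like this (paths and the
        # inducing-point count are illustrative, not actual values):
        #
        #     python <GP_TRAIN_FILE> --nZ=500 --seed=1234 \
        #         --data_file=.../gp/data.npz --save_file=.../gp/new.npz \
        #         --logfile=.../gp/gp_fit.log --use_test_set \
        #         --init --kmeans_init --save_metrics_file=.../results.pkl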
        # Set pbar status for user
        if pbar is not None:
            old_desc = pbar.desc
            pbar.set_description(gp_fit_desc)

        _run_command(gp_train_command, "GP train 0")
        curr_gp_file = new_gp_file

    print_flush("=== DONE ({:.3f}s) ===".format(time.time() - start_time))
def main():
    # Create arg parser
    parser = argparse.ArgumentParser()
    parser = TopologyVAE.add_model_specific_args(parser)
    parser = WeightedNumpyDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root="")

    parser.add_argument(
        "--augment_dataset",
        action='store_true',
        help="Use data augmentation or not",
    )
    parser.add_argument(
        "--use_binary_data",
        action='store_true',
        help="Binarize images in the dataset",
    )

    # Parse arguments
    hparams = parser.parse_args()
    hparams.root_dir = topology_get_path(
        k=hparams.rank_weight_k,
        n_max_epochs=hparams.max_epochs,
        predict_target=hparams.predict_target,
        hdims=hparams.target_predictor_hdims,
        metric_loss=hparams.metric_loss,
        metric_loss_kw=hparams.metric_loss_kw,
        beta_target_pred_loss=hparams.beta_target_pred_loss,
        beta_metric_loss=hparams.beta_metric_loss,
        latent_dim=hparams.latent_dim,
        beta_final=hparams.beta_final,
        use_binary_data=hparams.use_binary_data)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)
    pl.seed_everything(hparams.seed)

    # Create data
    if hparams.use_binary_data:
        if not os.path.exists(
                os.path.join(get_data_root(), 'topology_data/target_bin.npy')):
            gen_binary_dataset_from_all_files(get_data_root())
        hparams.dataset_path = os.path.join(ROOT_PROJECT,
                                            get_topology_binary_dataset_path())
    else:
        if not os.path.exists(
                os.path.join(get_data_root(), 'topology_data/target.npy')):
            gen_dataset_from_all_files(get_data_root())
        hparams.dataset_path = os.path.join(ROOT_PROJECT,
                                            get_topology_dataset_path())
    if hparams.augment_dataset:
        aug = transforms.Compose([
            # transforms.Normalize(mean=, std=),
            # transforms.RandomCrop(30, padding=10),
            transforms.RandomRotation(45),
            transforms.RandomRotation(90),
            transforms.RandomRotation(180),
            transforms.RandomVerticalFlip(0.5)
        ])
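        # Note: torchvision's RandomRotation(d) samples an angle uniformly from
        # [-d, +d] on every call, so chaining the 45/90/180-degree ranges with a
        # vertical flip yields a different random orientation for each sample.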
    else:
        aug = None
    datamodule = WeightedNumpyDataset(hparams,
                                      utils.DataWeighter(hparams),
                                      transform=aug)

    # Load model
    model = TopologyVAE(hparams)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        period=max(1, hparams.max_epochs // 10),
        monitor="loss/val",
        save_top_k=-1,
        save_last=True,
        mode='min')

    if hparams.load_from_checkpoint is not None:
        model = TopologyVAE.load_from_checkpoint(hparams.load_from_checkpoint)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            resume_from_checkpoint=hparams.load_from_checkpoint)

        print('Loaded model from checkpoint')
    else:
        # Main trainer
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            terminate_on_nan=True,
            progress_bar_refresh_rate=5,
            # gradient_clip_val=20.0,
        )

    # Fit
    trainer.fit(model, datamodule=datamodule)

    print(
        f"Training finished; end of script: rename {checkpoint_callback.best_model_path}"
    )

    shutil.copyfile(
        checkpoint_callback.best_model_path,
        os.path.join(os.path.dirname(checkpoint_callback.best_model_path),
                     'best.ckpt'))
def main_aux(args, result_dir: str):
    """ Weighted-retraining latent-space optimisation loop: retrain the JTVAE, query new points, and append them to the dataset """

    # Seeding
    pl.seed_everything(args.seed)

    if args.train_only and os.path.exists(
            args.save_model_path) and not args.overwrite:
        print_flush(f'--- JTVAE already trained in {args.save_model_path} ---')
        return

    # Make results directory
    data_dir = os.path.join(result_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    setup_logger(os.path.join(result_dir, "log.txt"))

    # Load data
    datamodule = WeightedJTNNDataset(args, utils.DataWeighter(args))
    datamodule.setup("fit", n_init_points=args.n_init_bo_points)

    # print python command run
    cmd = ' '.join(sys.argv[1:])
    print_flush(f"{cmd}\n")

    # Load model
    if args.use_pretrained:
        if args.predict_target:
            if 'pred_y' in args.pretrained_model_file:
                # fully supervised training from a model already trained with target prediction
                ckpt = torch.load(args.pretrained_model_file)
                hp = ckpt['hyper_parameters']['hparams']
                hp.beta_target_pred_loss = args.beta_target_pred_loss
                hp.predict_target = True
                hp.target_predictor_hdims = args.target_predictor_hdims
                torch.save(ckpt, args.pretrained_model_file)
        print(os.path.abspath(args.pretrained_model_file))
        vae: JTVAE = JTVAE.load_from_checkpoint(args.pretrained_model_file,
                                                vocab=datamodule.vocab)
        vae.beta = vae.hparams.beta_final  # Override any beta annealing
        vae.metric_loss = args.metric_loss
        vae.hparams.metric_loss = args.metric_loss
        vae.beta_metric_loss = args.beta_metric_loss
        vae.hparams.beta_metric_loss = args.beta_metric_loss
        vae.metric_loss_kw = args.metric_loss_kw
        vae.hparams.metric_loss_kw = args.metric_loss_kw
        vae.predict_target = args.predict_target
        vae.hparams.predict_target = args.predict_target
        vae.beta_target_pred_loss = args.beta_target_pred_loss
        vae.hparams.beta_target_pred_loss = args.beta_target_pred_loss
        vae.target_predictor_hdims = args.target_predictor_hdims
        vae.hparams.target_predictor_hdims = args.target_predictor_hdims
        if vae.predict_target and vae.target_predictor is None:
            vae.hparams.target_predictor_hdims = args.target_predictor_hdims
            vae.hparams.predict_target = args.predict_target
            vae.build_target_predictor()
    else:
        print("initialising VAE from scratch !")
        vae: JTVAE = JTVAE(hparams=args, vocab=datamodule.vocab)
    vae.eval()

    # Set up some stuff for the progress bar
    num_retrain = int(np.ceil(args.query_budget / args.retraining_frequency))
    postfix = dict(retrain_left=num_retrain,
                   best=float(datamodule.train_dataset.data_properties.max()),
                   n_train=len(datamodule.train_dataset.data),
                   save_path=result_dir)

    start_num_retrain = 0

    # Set up results tracking
    results = dict(
        opt_points=[],
        opt_point_properties=[],
        opt_model_version=[],
        params=str(sys.argv),
        sample_points=[],
        sample_versions=[],
        sample_properties=[],
    )

    result_filepath = os.path.join(result_dir, 'results.npz')
    if not args.overwrite and os.path.exists(result_filepath):
        with np.load(result_filepath, allow_pickle=True) as npz:
            results = {}
            for k in list(npz.keys()):
                results[k] = npz[k]
                if k != 'params':
                    results[k] = list(results[k])
                else:
                    results[k] = npz[k].item()
        start_num_retrain = results['opt_model_version'][-1] + 1

        prev_retrain_model = args.retraining_frequency * (start_num_retrain -
                                                          1)
        num_sampled_points = len(results['opt_points'])
        if args.n_init_retrain_epochs == 0 and prev_retrain_model == 0:
            pretrained_model_path = args.pretrained_model_file
        else:
            pretrained_model_path = os.path.join(
                result_dir, 'retraining', f'retrain_{prev_retrain_model}',
                'checkpoints', 'last.ckpt')
        print(f"Found checkpoint at {pretrained_model_path}")
        ckpt = torch.load(pretrained_model_path)
        hp = ckpt['hyper_parameters']['hparams']
        hp.metric_loss = args.metric_loss
        hp.metric_loss_kw = args.metric_loss_kw
        hp.beta_metric_loss = args.beta_metric_loss
        hp.beta_target_pred_loss = args.beta_target_pred_loss
        if args.predict_target:
            hp.predict_target = True
            hp.target_predictor_hdims = args.target_predictor_hdims
        torch.save(ckpt, pretrained_model_path)
        print(f"Loading model from {pretrained_model_path}")
        vae = JTVAE.load_from_checkpoint(pretrained_model_path, vocab=datamodule.vocab)  # classmethod returns a new model; rebind it
        if args.predict_target and not hasattr(vae.hparams, 'predict_target'):
            vae.hparams.target_predictor_hdims = args.target_predictor_hdims
            vae.hparams.predict_target = args.predict_target
        # vae.hparams.cuda = args.cuda
        vae.beta = vae.hparams.beta_final  # Override any beta annealing
        vae.eval()

        # Set up some stuff for the progress bar
        num_retrain = int(
            np.ceil(args.query_budget /
                    args.retraining_frequency)) - start_num_retrain

        print(f"Append existing points and properties to datamodule...")
        datamodule.append_train_data(np.array(results['opt_points']),
                                     np.array(results['opt_point_properties']))
        postfix = dict(retrain_left=num_retrain,
                       best=float(
                           datamodule.train_dataset.data_properties.max()),
                       n_train=len(datamodule.train_dataset.data),
                       initial=num_sampled_points,
                       save_path=result_dir)
        print(
            f"Retrain from {result_dir} | Best: {max(results['opt_point_properties'])}"
        )
    start_time = time.time()

    # Main loop
    with tqdm(total=args.query_budget,
              dynamic_ncols=True,
              smoothing=0.0,
              file=sys.stdout) as pbar:

        for ret_idx in range(start_num_retrain,
                             start_num_retrain + num_retrain):

            if vae.predict_target and vae.metric_loss is not None:
                vae.training_m = datamodule.training_m
                vae.training_M = datamodule.training_M
                vae.validation_m = datamodule.validation_m
                vae.validation_M = datamodule.validation_M

            torch.cuda.empty_cache()  # Free the memory up for tensorflow
            pbar.set_postfix(postfix)
            pbar.set_description("retraining")
            print(result_dir)
            # Decide whether to retrain
            samples_so_far = args.retraining_frequency * ret_idx

            # Optionally do retraining
            num_epochs = args.n_retrain_epochs
            if ret_idx == 0 and args.n_init_retrain_epochs is not None:
                num_epochs = args.n_init_retrain_epochs
            if num_epochs > 0:
                retrain_dir = os.path.join(result_dir, "retraining")
                version = f"retrain_{samples_so_far}"
                retrain_model(model=vae,
                              datamodule=datamodule,
                              save_dir=retrain_dir,
                              version_str=version,
                              num_epochs=num_epochs,
                              gpu=args.gpu,
                              store_best=args.train_only,
                              best_ckpt_path=args.save_model_path)
                vae.eval()
                if args.train_only:
                    return
            del num_epochs

            model = vae

            # Update progress bar
            postfix["retrain_left"] -= 1
            pbar.set_postfix(postfix)

            # Draw samples for logs!
            if args.samples_per_model > 0:
                pbar.set_description("sampling")
                with trange(args.samples_per_model,
                            desc="sampling",
                            leave=False) as sample_pbar:
                    sample_x, sample_y = latent_sampling(
                        args,
                        model,
                        datamodule,
                        args.samples_per_model,
                        pbar=sample_pbar)

                # Append to results dict
                results["sample_points"].append(sample_x)
                results["sample_properties"].append(sample_y)
                results["sample_versions"].append(ret_idx)

            # Do querying!
            pbar.set_description("querying")
            num_queries_to_do = min(args.retraining_frequency,
                                    args.query_budget - samples_so_far)
            if args.lso_strategy == "opt":
                gp_dir = os.path.join(result_dir, "gp",
                                      f"iter{samples_so_far}")
                os.makedirs(gp_dir, exist_ok=True)
                gp_data_file = os.path.join(gp_dir, "data.npz")
                gp_err_data_file = os.path.join(gp_dir, "data_err.npz")
                x_new, y_new = latent_optimization(
                    model=model,
                    datamodule=datamodule,
                    n_inducing_points=args.n_inducing_points,
                    n_best_points=args.n_best_points,
                    n_rand_points=args.n_rand_points,
                    num_queries_to_do=num_queries_to_do,
                    gp_data_file=gp_data_file,
                    gp_err_data_file=gp_err_data_file,
                    gp_run_folder=gp_dir,
                    gpu=args.gpu,
                    invalid_score=args.invalid_score,
                    pbar=pbar,
                    postfix=postfix,
                    error_aware_acquisition=args.error_aware_acquisition,
                )
            elif args.lso_strategy == "sample":
                x_new, y_new = latent_sampling(
                    args,
                    model,
                    datamodule,
                    num_queries_to_do,
                    pbar=pbar,
                )
            else:
                raise NotImplementedError(args.lso_strategy)

            # Update dataset
            datamodule.append_train_data(x_new, y_new)

            # Add new results
            results["opt_points"] += list(x_new)
            results["opt_point_properties"] += list(y_new)
            results["opt_model_version"] += [ret_idx] * len(x_new)

            postfix["best"] = max(postfix["best"], float(max(y_new)))
            postfix["n_train"] = len(datamodule.train_dataset.data)
            pbar.set_postfix(postfix)

            # Save results
            np.savez_compressed(os.path.join(result_dir, "results.npz"),
                                **results)

            # Keep a record of the dataset here
            new_data_file = os.path.join(
                data_dir,
                f"train_data_iter{samples_so_far + num_queries_to_do}.txt")
            with open(new_data_file, "w") as f:
                f.write("\n".join(datamodule.train_dataset.canonic_smiles))

    print_flush("=== DONE ({:.3f}s) ===".format(time.time() - start_time))
def main():
    # Create arg parser
    parser = argparse.ArgumentParser()
    parser = EquationVaeTorch.add_model_specific_args(parser)
    parser = WeightedExprDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root='')

    parser.add_argument("--ignore_percentile",
                        type=int,
                        default=50,
                        help="percentile of scores to ignore")
    parser.add_argument("--good_percentile",
                        type=int,
                        default=0,
                        help="percentile of good scores selected")
    parser.add_argument("--data_seed",
                        type=int,
                        required=True,
                        help="Seed that has been used to generate the dataset")

    # Parse arguments
    hparams = parser.parse_args()
    hparams.dataset_path = get_filepath(
        hparams.ignore_percentile,
        hparams.dataset_path,
        hparams.data_seed,
        good_percentile=hparams.good_percentile)
    hparams.root_dir = get_path(
        k=hparams.rank_weight_k,
        ignore_percentile=hparams.ignore_percentile,
        good_percentile=hparams.good_percentile,
        n_max_epochs=hparams.max_epochs,
        predict_target=hparams.predict_target,
        beta_final=hparams.beta_final,
        beta_target_pred_loss=hparams.beta_target_pred_loss,
        beta_metric_loss=hparams.beta_metric_loss,
        latent_dim=hparams.latent_dim,
        hdims=hparams.target_predictor_hdims,
        metric_loss=hparams.metric_loss,
        metric_loss_kw=hparams.metric_loss_kw)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)

    pl.seed_everything(hparams.seed)

    # Create data
    datamodule = WeightedExprDataset(hparams,
                                     utils.DataWeighter(hparams),
                                     add_channel=False)

    device = hparams.cuda
    if device is not None:
        torch.cuda.set_device(device)

    data_info = G.gram.split('\n')

    # Load model
    model = EquationVaeTorch(hparams, len(data_info), MAX_LEN)
    # model.decoder.apply(torch_weight_init)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        period=max(1, hparams.max_epochs // 20),
        monitor="loss/val",
        save_top_k=-1,
        save_last=True,
        mode='min')

    if hparams.load_from_checkpoint is not None:
        # .load_from_checkpoint(hparams.load_from_checkpoint)
        model = EquationVaeTorch.load_from_checkpoint(
            hparams.load_from_checkpoint, len(data_info), MAX_LEN)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            resume_from_checkpoint=hparams.load_from_checkpoint)

        print('Loaded model from checkpoint')
    else:
        # Main trainer
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            terminate_on_nan=True,
            progress_bar_refresh_rate=100)

    # Fit
    trainer.fit(model, datamodule=datamodule)

    print(
        f"Training finished; end of script: rename {checkpoint_callback.best_model_path}"
    )

    shutil.copyfile(
        checkpoint_callback.best_model_path,
        os.path.join(os.path.dirname(checkpoint_callback.best_model_path),
                     'best.ckpt'))


def main():
    # Create arg parser
    parser = argparse.ArgumentParser()
    parser = JTVAE.add_model_specific_args(parser)
    parser = WeightedJTNNDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root=None)

    # Parse arguments
    hparams = parser.parse_args()

    hparams.root_dir = os.path.join(get_storage_root(), hparams.root_dir)

    pl.seed_everything(hparams.seed)
    print_flush(' '.join(sys.argv[1:]))

    # Create data
    datamodule = WeightedJTNNDataset(hparams, utils.DataWeighter(hparams))
    datamodule.setup("fit")

    # Load model
    model = JTVAE(hparams, datamodule.vocab)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        period=1, monitor="loss/val", save_top_k=1,
        save_last=True, mode='min'
    )

    if hparams.load_from_checkpoint is not None:
        # .load_from_checkpoint(hparams.load_from_checkpoint)
        # utils.update_hparams(hparams, model)
        trainer = pl.Trainer(gpus=[hparams.cuda] if hparams.cuda is not None else 0,
                             default_root_dir=hparams.root_dir,
                             max_epochs=hparams.max_epochs,
                             callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
                             resume_from_checkpoint=hparams.load_from_checkpoint)
def main():
    # Create arg parser
    parser = argparse.ArgumentParser()
    parser = ShapesVAE.add_model_specific_args(parser)
    parser = WeightedNumpyDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root="")

    # Parse arguments
    hparams = parser.parse_args()

    hparams.root_dir = shape_get_path(k=hparams.rank_weight_k, predict_target=hparams.predict_target,
                                      hdims=hparams.target_predictor_hdims, metric_loss=hparams.metric_loss,
                                      metric_loss_kw=hparams.metric_loss_kw, latent_dim=hparams.latent_dim)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)

    pl.seed_everything(hparams.seed)

    # Create data
    datamodule = WeightedNumpyDataset(hparams, utils.DataWeighter(hparams))

    # Load model
    model = ShapesVAE(hparams)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        period=max(1, hparams.max_epochs // 20),
        monitor="loss/val", save_top_k=1,
        save_last=True, mode='min'
    )

    if hparams.load_from_checkpoint is not None:
        model = ShapesVAE.load_from_checkpoint(hparams.load_from_checkpoint)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(gpus=[hparams.cuda] if hparams.cuda is not None else 0,
                             default_root_dir=hparams.root_dir,
                             max_epochs=hparams.max_epochs,
                             callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
                             resume_from_checkpoint=hparams.load_from_checkpoint)

        print('Loaded model from checkpoint')
    else:
        # Main trainer
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
            terminate_on_nan=True,
            progress_bar_refresh_rate=100
        )

    # Fit
    trainer.fit(model, datamodule=datamodule)

    print(f"Training finished; end of script: rename {checkpoint_callback.best_model_path}")

    shutil.copyfile(checkpoint_callback.best_model_path, os.path.join(
        os.path.dirname(checkpoint_callback.best_model_path), 'best.ckpt'
    ))
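
    # A downstream script can then restore the selected weights from the copied file,
    # e.g. (the checkpoint directory below is an assumed placeholder, not a fixed path):
    #
    #     best = ShapesVAE.load_from_checkpoint(os.path.join(ckpt_dir, 'best.ckpt'))
    #     best.eval()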