# Experiment configuration for the DUN run: each (attribute, value) pair below is
# assigned onto `args`, in order, via setattr.
for _attr_name, _attr_value in (
    # output locations, wandb labels and checkpoint to start from
    ('save_dir', '/home/willlamb/checkpoints/dun'),
    ('results_dir', '/home/willlamb/results/dun'),
    ('wandb_proj', 'official_dun'),
    ('wandb_name', 'dun'),
    ('checkpoint_path', None),

    # ensembling and samples
    ('ensemble_size', 1),
    ('ensemble_start_idx', 0),
    ('pytorch_seeds', [0, 1, 2, 3, 4]),
    ('samples', 100),

    # dun: enable flag and depth range
    ('dun', True),
    ('depth_min', 1),
    ('depth_max', 5),

    # dun: epoch counts
    ('epochs', 0),
    ('epochs_dun', 350),

    # dun: batch size, learning-rate bounds and prior scale
    ('batch_size_dun', 50),
    ('lr_dun_min', 1e-4),
    ('lr_dun_max', 1e-3),
    ('prior_sig_dun', 0.05),

    # dun: rho initialisation range and sample count
    ('rho_min_dun', -5.5),
    ('rho_max_dun', -5),
    ('samples_dun', 5),

    ('presave_dun', 150),
):
    setattr(args, _attr_name, _attr_value)

# Drop the loop temporaries so the module namespace is unchanged.
del _attr_name, _attr_value
# Example #2
def _checkpoint_file(args: 'TrainArgs', model_idx: int, filename: str) -> str:
    """Return ``<args.checkpoint_path>/model_<model_idx>/<filename>``."""
    return os.path.join(args.checkpoint_path, f'model_{model_idx}', filename)


def _init_bayes_rho(template, rho_min: float, rho_max: float) -> None:
    """Call ``init_rho(rho_min, rho_max)`` on every ``BayesLinear`` child of ``template`` and of
    ``template.encoder.encoder``, in that order."""
    for module in (template, template.encoder.encoder):
        for layer in module.children():
            if isinstance(layer, BayesLinear):
                layer.init_rho(rho_min, rho_max)


def _load_model(args: 'TrainArgs', model_idx: int, train_data_loader, logger: Logger = None):
    """
    Load the saved checkpoint of ensemble member ``model_idx`` for ``args.method``.

    :param args: Arguments; ``args.method`` selects which checkpoint file / template is used.
    :param model_idx: Index of the ensemble member (selects the ``model_<idx>`` folder).
    :param train_data_loader: Training loader, only used by the ``'gp'`` method to pick inducing points.
    :param logger: Logger passed to ``load_checkpoint`` for the plain (non-template) methods.
    :return: The loaded model.
    :raises ValueError: If ``args.method`` is not a supported method.
    """
    method = args.method

    if method in ('map', 'swag', 'sgld', 'dropR', 'dropA'):
        return load_checkpoint(_checkpoint_file(args, model_idx, 'model.pt'),
                               device=args.device,
                               logger=logger)

    if method == 'gp':
        args.num_inducing_points = 1200
        # An untrained featurizer picks initial inducing points so the GP layer can be built
        # before loading; presumably the checkpoint then overwrites them — TODO confirm.
        fake_model = MoleculeModel(args)
        fake_model.featurizer = True
        inducing_points = initial_inducing_points(train_data_loader, fake_model, args)
        gp_layer = GPLayer(inducing_points, args.num_tasks)
        template = DKLMoleculeModel(MoleculeModel(args, featurizer=True), gp_layer)
        return load_checkpoint(_checkpoint_file(args, model_idx, 'DKN_model.pt'),
                               device=args.device,
                               logger=None,
                               template=template)

    if method == 'bbp':
        template = MoleculeModelBBP(args)
        _init_bayes_rho(template, args.rho_min_bbp, args.rho_max_bbp)
        return load_checkpoint(_checkpoint_file(args, model_idx, 'model_bbp.pt'),
                               device=args.device,
                               logger=None,
                               template=template)

    if method == 'dun':
        # NOTE(review): these fixed values overwrite whatever the caller configured on `args`
        # before the template is built — confirm this is intended.
        args.prior_sig_dun = 0.05
        args.depth_min = 1
        args.depth_max = 5
        args.rho_min_dun = -5.5
        args.rho_max_dun = -5
        args.log_cat_init = 0
        template = MoleculeModelDUN(args)
        _init_bayes_rho(template, args.rho_min_dun, args.rho_max_dun)
        template.create_log_cat(args)
        return load_checkpoint(_checkpoint_file(args, model_idx, 'model_dun.pt'),
                               device=args.device,
                               logger=None,
                               template=template)

    raise ValueError(f'Unsupported method: {method!r}')


def _fit_residual_tstats(preds, targets) -> np.ndarray:
    """
    Fit a zero-location Student-t distribution to the residuals of each task.

    :param preds: Predictions, one row per molecule and one column per task.
    :param targets: True targets with the same layout; ``None`` entries are treated as missing.
    :return: Array of shape ``(num_tasks, 3)``; row ``i`` holds the ``(df, loc, scale)`` returned by
             ``scipy.stats.t.fit`` for task ``i`` (``loc`` is fixed at 0).
    """
    preds = np.asarray(preds, dtype=float)
    targets = np.asarray(targets, dtype=float)  # None -> NaN
    num_tasks = preds.shape[1]

    tstats = np.ones((num_tasks, 3))
    for task in range(num_tasks):
        resid = preds[:, task] - targets[:, task]
        resid = resid[~np.isnan(resid)]  # ignore molecules with a missing target for this task
        tstats[task] = stats.t.fit(resid, floc=0.0)
    return tstats


def new_noise(args: TrainArgs, logger: Logger = None) -> None:
    """
    Fit Student-t noise models to the training residuals of saved checkpoints.

    Each ensemble member in ``range(args.ensemble_start_idx, args.ensemble_start_idx + args.ensemble_size)``
    is handled as follows:

    1. The checkpoint matching ``args.method`` is loaded from ``args.checkpoint_path``.
    2. The training set is predicted.
    3. A zero-location Student-t distribution is fitted to each task's residuals
       (prediction minus unscaled target).
    4. The resulting ``(num_tasks, 3)`` array is saved ``args.samples`` times, as
       ``<args.results_dir>/model_<idx>/tstats_<sample_idx>.npz``.

    :param args: Arguments.
    :param logger: Logger; debug messages go to ``logger.debug`` when given, otherwise to ``print``.
    :raises ValueError: If ``args.method`` is not a supported method.
    """
    debug = logger.debug if logger is not None else print

    # Get data
    args.task_names = args.target_columns or get_task_names(args.data_path)
    data = get_data(path=args.data_path, args=args, logger=logger)
    args.num_tasks = data.num_tasks()
    args.features_size = data.features_size()

    # Split data
    debug(f'Splitting data with seed {args.seed}')
    train_data, val_data, test_data = split_data(data=data,
                                                 split_type=args.split_type,
                                                 sizes=args.split_sizes,
                                                 seed=args.seed,
                                                 args=args,
                                                 logger=logger)

    if args.features_scaling:
        features_scaler = train_data.normalize_features(replace_nan_token=0)
        val_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)

    args.train_data_size = len(train_data)

    # Keep the *unscaled* training targets: residuals are computed against them. Fetched for
    # every dataset type (previously only bound for regression, causing a NameError otherwise).
    train_targets = train_data.targets()

    # Regression only: scale the training targets by subtracting the mean and dividing by the std.
    # `scaler` is passed to predict() below, presumably to map predictions back to the original scale.
    if args.dataset_type == 'regression':
        debug('Fitting scaler')
        scaler = StandardScaler().fit(train_targets)
        train_data.set_targets(scaler.transform(train_targets).tolist())
    else:
        scaler = None

    # Small datasets: cache=True with no loader workers; otherwise use args.num_workers.
    cache = len(data) <= args.cache_cutoff
    num_workers = 0 if cache else args.num_workers

    train_data_loader = MoleculeDataLoader(dataset=train_data,
                                           batch_size=args.batch_size,
                                           num_workers=num_workers,
                                           cache=cache)

    # Outer loop over ensemble members
    for model_idx in range(args.ensemble_start_idx,
                           args.ensemble_start_idx + args.ensemble_size):
        model = _load_model(args, model_idx, train_data_loader, logger)

        results_dir = os.path.join(args.results_dir, f'model_{model_idx}')
        makedirs(results_dir)

        train_preds = predict(model=model,
                              data_loader=train_data_loader,
                              args=args,
                              scaler=scaler,
                              test_data=False,
                              bbp_sample=False)
        tstats = _fit_residual_tstats(train_preds, train_targets)

        # The same fit is written once per sample index (one file per sample).
        for sample_idx in range(args.samples):
            np.savez(os.path.join(results_dir, f'tstats_{sample_idx}'), tstats)
            debug('done one')