Beispiel #1
0
def set_features(args: PredictArgs, train_args: TrainArgs):
    """
    Resets the featurization parameters and re-applies the options needed for prediction.

    Extra atom/bond feature dimensions, the explicit-H flag and the reaction settings are
    taken from the training arguments; the adding-H flag is taken from the prediction arguments.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param train_args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    """
    # Start from a clean slate before applying any option.
    reset_featurization_parameters()

    # Extra atom features are only registered when descriptors are used as features.
    uses_atom_features = args.atom_descriptors == 'feature'
    if uses_atom_features:
        set_extra_atom_fdim(train_args.atom_features_size)

    # Extra bond features are only registered when a bond feature file was supplied.
    bond_features_path = args.bond_features_path
    if bond_features_path is not None:
        set_extra_bond_fdim(train_args.bond_features_size)

    # Explicit H option, adding H option and reaction option.
    set_explicit_h(train_args.explicit_h)
    set_adding_hs(args.adding_h)
    set_reaction(train_args.reaction, train_args.reaction_mode)
Beispiel #2
0
def make_predictions(
        args: PredictArgs,
        smiles: List[List[str]] = None) -> List[List[Optional[float]]]:
    """
    Loads data and a trained model and uses the model to make predictions on the data.

    If SMILES are provided, then makes predictions on smiles.
    Otherwise makes predictions on the file at :code:`args.test_path`.

    Predictions of all checkpoints in :code:`args.checkpoint_paths` are averaged and
    written as a CSV file to :code:`args.preds_path`. Datapoints with an invalid
    SMILES get the placeholder string "Invalid SMILES" in every prediction column.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param smiles: List of list of SMILES to make predictions on.
    :return: A list of lists of averaged target predictions, one entry per valid datapoint.
             If no datapoint is valid, a list of :code:`None` with one entry per input
             datapoint is returned and nothing is written.
    """
    print("Loading training args")
    train_args = load_args(args.checkpoint_paths[0])
    num_tasks, task_names = train_args.num_tasks, train_args.task_names

    update_prediction_args(predict_args=args, train_args=train_args)
    args: Union[PredictArgs, TrainArgs]

    is_multiclass = args.dataset_type == "multiclass"

    if args.atom_descriptors == "feature":
        set_extra_atom_fdim(train_args.atom_features_size)

    if args.bond_features_path is not None:
        set_extra_bond_fdim(train_args.bond_features_size)

    # Set explicit H option and reaction option
    set_explicit_h(train_args.explicit_h)
    set_reaction(train_args.reaction, train_args.reaction_mode)

    print("Loading data")
    if smiles is not None:
        full_data = get_data_from_smiles(
            smiles=smiles,
            skip_invalid_smiles=False,
            features_generator=args.features_generator,
        )
    else:
        full_data = get_data(
            path=args.test_path,
            smiles_columns=args.smiles_columns,
            target_columns=[],
            ignore_columns=[],
            skip_invalid_smiles=False,
            args=args,
            store_row=not args.drop_extra_columns,
        )

    print("Validating SMILES")
    # Maps the position of each valid datapoint in full_data to its position in test_data.
    full_to_valid_indices = {}
    for full_index, datapoint in enumerate(full_data):
        if all(mol is not None for mol in datapoint.mol):
            full_to_valid_indices[full_index] = len(full_to_valid_indices)

    test_data = MoleculeDataset(
        [full_data[i] for i in sorted(full_to_valid_indices)])

    # Edge case if empty list of smiles is provided
    if len(test_data) == 0:
        return [None] * len(full_data)

    print(f"Test size = {len(test_data):,}")

    # Predict with each model individually and sum predictions
    if is_multiclass:
        sum_preds = np.zeros(
            (len(test_data), num_tasks, args.multiclass_num_classes))
    else:
        sum_preds = np.zeros((len(test_data), num_tasks))

    # Create data loader
    test_data_loader = MoleculeDataLoader(
        dataset=test_data,
        batch_size=args.batch_size,
        num_workers=0 if sys.platform == "darwin" else args.num_workers,
    )

    # Partial results for variance robust calculation.
    # NOTE(review): this buffer has no class axis, so ensemble_variance combined with
    # multiclass predictions looks unsupported -- confirm before relying on it.
    if args.ensemble_variance:
        all_preds = np.zeros(
            (len(test_data), num_tasks, len(args.checkpoint_paths)))

    print(
        f"Predicting with an ensemble of {len(args.checkpoint_paths)} models")
    for index, checkpoint_path in enumerate(
            tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths))):
        # Load model and scalers
        model = load_checkpoint(checkpoint_path, device=args.device)
        (
            scaler,
            features_scaler,
            atom_descriptor_scaler,
            bond_feature_scaler,
        ) = load_scalers(checkpoint_path)

        # Normalize features; the data is reset first because each checkpoint
        # carries its own scalers.
        if (args.features_scaling or train_args.atom_descriptor_scaling
                or train_args.bond_feature_scaling):
            test_data.reset_features_and_targets()
            if args.features_scaling:
                test_data.normalize_features(features_scaler)
            if (train_args.atom_descriptor_scaling
                    and args.atom_descriptors is not None):
                test_data.normalize_features(atom_descriptor_scaler,
                                             scale_atom_descriptors=True)
            if train_args.bond_feature_scaling and args.bond_features_size > 0:
                test_data.normalize_features(bond_feature_scaler,
                                             scale_bond_features=True)

        # Make predictions
        model_preds = predict(model=model,
                              data_loader=test_data_loader,
                              scaler=scaler)
        sum_preds += np.array(model_preds)
        if args.ensemble_variance:
            all_preds[:, :, index] = model_preds

    # Ensemble predictions
    avg_preds = sum_preds / len(args.checkpoint_paths)
    avg_preds = avg_preds.tolist()

    if args.ensemble_variance:
        all_epi_uncs = np.var(all_preds, axis=2)
        all_epi_uncs = all_epi_uncs.tolist()

    # Save predictions
    print(f"Saving predictions to {args.preds_path}")
    assert len(test_data) == len(avg_preds)
    if args.ensemble_variance:
        assert len(test_data) == len(all_epi_uncs)
    makedirs(args.preds_path, isfile=True)

    # Get prediction column names: one column per (task, class) pair for multiclass
    if is_multiclass:
        task_names = [
            f"{name}_class_{i}" for name in task_names
            for i in range(args.multiclass_num_classes)
        ]

    # Copy predictions over to full_data
    for full_index, datapoint in enumerate(full_data):
        valid_index = full_to_valid_indices.get(full_index, None)
        if valid_index is not None:
            preds = avg_preds[valid_index]
            if is_multiclass:
                # Predictions are nested as [task][class]; flatten them in the same
                # task-major order as the column names so every class probability
                # lands in its own column.
                preds = [
                    class_pred for task_preds in preds
                    for class_pred in task_preds
                ]
        else:
            preds = ["Invalid SMILES"] * len(task_names)
        if args.ensemble_variance:
            epi_uncs = (all_epi_uncs[valid_index] if valid_index is not None
                        else ["Invalid SMILES"] * len(task_names))

        # If extra columns have been dropped, add back in SMILES columns
        if args.drop_extra_columns:
            datapoint.row = OrderedDict()

            for column, smiles_string in zip(args.smiles_columns,
                                             datapoint.smiles):
                datapoint.row[column] = smiles_string

        # Add predictions columns
        if args.ensemble_variance:
            for pred_name, pred, epi_unc in zip(task_names, preds, epi_uncs):
                datapoint.row[pred_name] = pred
                datapoint.row[pred_name + "_epi_unc"] = epi_unc
        else:
            for pred_name, pred in zip(task_names, preds):
                datapoint.row[pred_name] = pred

    # Save. newline="" is required by the csv module so that it controls the line
    # endings itself (otherwise rows are separated by blank lines on Windows).
    with open(args.preds_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=full_data[0].row.keys())
        writer.writeheader()

        for datapoint in full_data:
            writer.writerow(datapoint.row)

    return avg_preds
def molecule_fingerprint(
        args: FingerprintArgs,
        smiles: List[List[str]] = None) -> List[List[Optional[float]]]:
    """
    Loads data and a trained model and uses the model to encode fingerprint vectors for the data.

    One fingerprint is computed per checkpoint in :code:`args.checkpoint_paths`. The
    fingerprints are written as a CSV file to :code:`args.preds_path`, with the
    placeholder string 'Invalid SMILES' for datapoints that could not be parsed.

    :param args: A :class:`~chemprop.args.FingerprintArgs` object containing arguments for
                 loading data and a model and encoding fingerprints.
    :param smiles: List of list of SMILES to make predictions on.
    :return: A numpy array of fingerprints indexed as (valid datapoint, fingerprint
             position, checkpoint). If no datapoint is valid, a list of :code:`None`
             with one entry per input datapoint is returned and nothing is written.
    :raises ValueError: If the fingerprint type is unknown, if 'MPN' is requested for a
                        features_only model, or if 'last_FFN' is requested with a
                        single FFN layer.
    """

    print('Loading training args')
    train_args = load_args(args.checkpoint_paths[0])

    # Update args with training arguments.
    # Input features only need to be supplied (and validated) when using the FFN
    # latent representation and the model calls for them.
    update_prediction_args(
        predict_args=args,
        train_args=train_args,
        validate_feature_sources=args.fingerprint_type != 'MPN')
    args: Union[FingerprintArgs, TrainArgs]

    # Set explicit H option and reaction option
    reset_featurization_parameters()
    if args.atom_descriptors == 'feature':
        set_extra_atom_fdim(train_args.atom_features_size)

    if args.bond_features_path is not None:
        set_extra_bond_fdim(train_args.bond_features_size)

    set_explicit_h(train_args.explicit_h)
    set_adding_hs(args.adding_h)
    if train_args.reaction:
        set_reaction(train_args.reaction, train_args.reaction_mode)
    elif train_args.reaction_solvent:
        set_reaction(True, train_args.reaction_mode)

    print('Loading data')
    if smiles is not None:
        full_data = get_data_from_smiles(
            smiles=smiles,
            skip_invalid_smiles=False,
            features_generator=args.features_generator)
    else:
        full_data = get_data(path=args.test_path,
                             smiles_columns=args.smiles_columns,
                             target_columns=[],
                             ignore_columns=[],
                             skip_invalid_smiles=False,
                             args=args,
                             store_row=True)

    print('Validating SMILES')
    # Maps the position of each valid datapoint in full_data to its position in test_data.
    full_to_valid_indices = {}
    for full_index, datapoint in enumerate(full_data):
        if all(mol is not None for mol in datapoint.mol):
            full_to_valid_indices[full_index] = len(full_to_valid_indices)

    test_data = MoleculeDataset(
        [full_data[i] for i in sorted(full_to_valid_indices)])

    # Edge case if empty list of smiles is provided
    if len(test_data) == 0:
        return [None] * len(full_data)

    print(f'Test size = {len(test_data):,}')

    # Create data loader
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=args.num_workers)

    num_models = len(args.checkpoint_paths)

    # Set fingerprint size
    if args.fingerprint_type == 'MPN':
        if args.atom_descriptors == "descriptor":  # special case when we have 'descriptor' extra dimensions need to be added
            total_fp_size = (
                args.hidden_size +
                test_data.atom_descriptors_size()) * args.number_of_molecules
        else:
            if args.reaction_solvent:
                total_fp_size = args.hidden_size + args.hidden_size_solvent
            else:
                total_fp_size = args.hidden_size * args.number_of_molecules
        if args.features_only:
            raise ValueError(
                'With features_only models, there is no latent MPN representation. Use last_FFN fingerprint type instead.'
            )
    elif args.fingerprint_type == 'last_FFN':
        if args.ffn_num_layers != 1:
            total_fp_size = args.ffn_hidden_size
        else:
            raise ValueError(
                'With a ffn_num_layers of 1, there is no latent FFN representation. Use MPN fingerprint type instead.'
            )
    else:
        raise ValueError(
            f'Fingerprint type {args.fingerprint_type} not supported')
    all_fingerprints = np.zeros((len(test_data), total_fp_size, num_models))

    # Load model
    print(
        f'Encoding smiles into a fingerprint vector from {len(args.checkpoint_paths)} models.'
    )

    for index, checkpoint_path in enumerate(
            tqdm(args.checkpoint_paths, total=num_models)):
        model = load_checkpoint(checkpoint_path, device=args.device)
        scaler, features_scaler, atom_descriptor_scaler, bond_feature_scaler = load_scalers(
            checkpoint_path)

        # Normalize features; the data is reset first because each checkpoint
        # carries its own scalers.
        if args.features_scaling or train_args.atom_descriptor_scaling or train_args.bond_feature_scaling:
            test_data.reset_features_and_targets()
            if args.features_scaling:
                test_data.normalize_features(features_scaler)
            if train_args.atom_descriptor_scaling and args.atom_descriptors is not None:
                test_data.normalize_features(atom_descriptor_scaler,
                                             scale_atom_descriptors=True)
            if train_args.bond_feature_scaling and args.bond_features_size > 0:
                test_data.normalize_features(bond_feature_scaler,
                                             scale_bond_features=True)

        # Make fingerprints
        model_fp = model_fingerprint(model=model,
                                     data_loader=test_data_loader,
                                     fingerprint_type=args.fingerprint_type)
        if args.fingerprint_type == 'MPN' and (
                args.features_path is not None or args.features_generator
        ):  # truncate any features from MPN fingerprint
            model_fp = np.array(model_fp)[:, :total_fp_size]
        all_fingerprints[:, :, index] = model_fp

    # Save predictions
    print(f'Saving predictions to {args.preds_path}')
    # assert len(test_data) == len(all_fingerprints) #TODO: add unit test for this
    makedirs(args.preds_path, isfile=True)

    # Set column names
    # NOTE(review): values are flattened below as (fingerprint position, model), while
    # the MPN labels iterate (position, model, molecule); with several molecules the
    # labels may not line up with the values -- confirm against model_fingerprint.
    if args.fingerprint_type == 'MPN':
        per_molecule_size = total_fp_size // args.number_of_molecules
        if num_models == 1:
            fingerprint_columns = [
                f'fp_{j}_mol_{k}' for j in range(per_molecule_size)
                for k in range(args.number_of_molecules)
            ]
        else:
            fingerprint_columns = [
                f'fp_{j}_mol_{k}_model_{i}' for j in range(per_molecule_size)
                for i in range(num_models)
                for k in range(args.number_of_molecules)
            ]
    else:  # args.fingerprint_type == 'last_FFN'
        if num_models == 1:
            fingerprint_columns = [f'fp_{j}' for j in range(total_fp_size)]
        else:
            fingerprint_columns = [
                f'fp_{j}_model_{i}' for j in range(total_fp_size)
                for i in range(num_models)
            ]

    # Copy predictions over to full_data
    flat_size = num_models * total_fp_size
    for full_index, datapoint in enumerate(full_data):
        valid_index = full_to_valid_indices.get(full_index, None)
        if valid_index is not None:
            preds = all_fingerprints[valid_index].reshape((flat_size, ))
        else:
            preds = ['Invalid SMILES'] * flat_size

        for position, column in enumerate(fingerprint_columns):
            datapoint.row[column] = preds[position]

    # Write predictions. newline='' is required by the csv module so that it controls
    # the line endings itself (otherwise rows are separated by blank lines on Windows).
    with open(args.preds_path, 'w', newline='') as f:
        writer = csv.DictWriter(f,
                                fieldnames=args.smiles_columns +
                                fingerprint_columns,
                                extrasaction='ignore')
        writer.writeheader()
        for datapoint in full_data:
            writer.writerow(datapoint.row)

    return all_fingerprints
def molecule_fingerprint(
        args: PredictArgs,
        smiles: List[List[str]] = None) -> List[List[Optional[float]]]:
    """
    Loads data and a trained model and uses the model to encode fingerprint vectors for the data.

    Only a single checkpoint is supported. The fingerprints are written as a CSV file
    to :code:`args.preds_path`, with the placeholder string 'Invalid SMILES' for
    datapoints that could not be parsed.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param smiles: List of list of SMILES to make predictions on.
    :return: A list of fingerprint vectors (list of floats), one per valid datapoint.
             If no datapoint is valid, a list of :code:`None` with one entry per input
             datapoint is returned and nothing is written.
    :raises ValueError: If more than one checkpoint path is given.
    """

    print('Loading training args')
    train_args = load_args(args.checkpoint_paths[0])

    # Update args with training arguments
    update_prediction_args(predict_args=args,
                           train_args=train_args,
                           validate_feature_sources=False)
    args: Union[PredictArgs, TrainArgs]

    # Set explicit H option and reaction option
    set_explicit_h(train_args.explicit_h)
    set_reaction(train_args.reaction, train_args.reaction_mode)

    print('Loading data')
    if smiles is not None:
        full_data = get_data_from_smiles(
            smiles=smiles,
            skip_invalid_smiles=False,
            features_generator=args.features_generator)
    else:
        full_data = get_data(path=args.test_path,
                             smiles_columns=args.smiles_columns,
                             target_columns=[],
                             ignore_columns=[],
                             skip_invalid_smiles=False,
                             args=args,
                             store_row=True)

    print('Validating SMILES')
    # Maps the position of each valid datapoint in full_data to its position in test_data.
    full_to_valid_indices = {}
    for full_index, datapoint in enumerate(full_data):
        if all(mol is not None for mol in datapoint.mol):
            full_to_valid_indices[full_index] = len(full_to_valid_indices)

    test_data = MoleculeDataset(
        [full_data[i] for i in sorted(full_to_valid_indices)])

    # Edge case if empty list of smiles is provided
    if len(test_data) == 0:
        return [None] * len(full_data)

    print(f'Test size = {len(test_data):,}')

    # Create data loader
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=args.num_workers)

    # Load model
    print('Encoding smiles into a fingerprint vector from a single model')
    if len(args.checkpoint_paths) != 1:
        raise ValueError(
            "Fingerprint generation only supports one model, cannot use an ensemble"
        )

    checkpoint_path = args.checkpoint_paths[0]
    model = load_checkpoint(checkpoint_path, device=args.device)
    scaler, features_scaler, atom_descriptor_scaler, bond_feature_scaler = load_scalers(
        checkpoint_path)

    # Normalize features
    if args.features_scaling or train_args.atom_descriptor_scaling or train_args.bond_feature_scaling:
        test_data.reset_features_and_targets()
        if args.features_scaling:
            test_data.normalize_features(features_scaler)
        if train_args.atom_descriptor_scaling and args.atom_descriptors is not None:
            test_data.normalize_features(atom_descriptor_scaler,
                                         scale_atom_descriptors=True)
        if train_args.bond_feature_scaling and args.bond_features_size > 0:
            test_data.normalize_features(bond_feature_scaler,
                                         scale_bond_features=True)

    # Make fingerprints
    model_preds = model_fingerprint(model=model, data_loader=test_data_loader)

    # Save predictions
    print(f'Saving predictions to {args.preds_path}')
    assert len(test_data) == len(model_preds)
    makedirs(args.preds_path, isfile=True)

    # Column names are the same for every datapoint, so build them once up front.
    total_hidden_size = args.hidden_size * args.number_of_molecules
    fingerprint_columns = [f'fp_{i}' for i in range(total_hidden_size)]

    # Copy predictions over to full_data
    for full_index, datapoint in enumerate(full_data):
        valid_index = full_to_valid_indices.get(full_index, None)
        if valid_index is not None:
            preds = model_preds[valid_index]
        else:
            preds = ['Invalid SMILES'] * total_hidden_size

        # Only the first total_hidden_size entries of each fingerprint are stored.
        for position, column in enumerate(fingerprint_columns):
            datapoint.row[column] = preds[position]

    # Write predictions. newline='' is required by the csv module so that it controls
    # the line endings itself (otherwise rows are separated by blank lines on Windows).
    with open(args.preds_path, 'w', newline='') as f:
        writer = csv.DictWriter(f,
                                fieldnames=args.smiles_columns +
                                fingerprint_columns,
                                extrasaction='ignore')
        writer.writeheader()
        for datapoint in full_data:
            writer.writerow(datapoint.row)

    return model_preds
Beispiel #5
0
def cross_validate(args: TrainArgs,
                   train_func: Callable[[TrainArgs, MoleculeDataset, Logger], Dict[str, List[float]]]
                   ) -> Tuple[float, float]:
    """
    Runs k-fold cross-validation.

    For each of k splits (folds) of the data, trains and tests a model on that split
    and aggregates the performance across folds.

    Note that :code:`args` is modified in place: the task names and feature sizes are
    filled in, and :code:`args.seed` and :code:`args.save_dir` are left at the values
    used for the last fold.

    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for
                 loading data and training the Chemprop model.
    :param train_func: Function which runs training.
    :return: A tuple containing the mean and standard deviation performance across folds.
    :raises ValueError: If target weights are given but their number differs from the number of tasks.
    """
    logger = create_logger(name=TRAIN_LOGGER_NAME, save_dir=args.save_dir, quiet=args.quiet)
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Initialize relevant variables
    init_seed = args.seed
    save_dir = args.save_dir
    args.task_names = get_task_names(path=args.data_path, smiles_columns=args.smiles_columns,
                                     target_columns=args.target_columns, ignore_columns=args.ignore_columns)

    # Print command line
    debug('Command line')
    debug(f'python {" ".join(sys.argv)}')

    # Print args
    debug('Args')
    debug(args)

    # Save args; fall back to saving without the reproducibility section if it cannot be collected
    makedirs(args.save_dir)
    try:
        args.save(os.path.join(args.save_dir, 'args.json'))
    except subprocess.CalledProcessError:
        debug('Could not write the reproducibility section of the arguments to file, thus omitting this section.')
        args.save(os.path.join(args.save_dir, 'args.json'), with_reproducibility=False)

    # Set explicit H option and reaction option
    reset_featurization_parameters(logger=logger)
    set_explicit_h(args.explicit_h)
    set_adding_hs(args.adding_h)
    if args.reaction:
        set_reaction(args.reaction, args.reaction_mode)
    elif args.reaction_solvent:
        set_reaction(True, args.reaction_mode)

    # Get data
    debug('Loading data')
    data = get_data(
        path=args.data_path,
        args=args,
        logger=logger,
        skip_none_targets=True,
        data_weights_path=args.data_weights_path
    )
    validate_dataset_type(data, dataset_type=args.dataset_type)
    args.features_size = data.features_size()

    if args.atom_descriptors == 'descriptor':
        args.atom_descriptors_size = data.atom_descriptors_size()
        args.ffn_hidden_size += args.atom_descriptors_size
    elif args.atom_descriptors == 'feature':
        args.atom_features_size = data.atom_features_size()
        set_extra_atom_fdim(args.atom_features_size)
    if args.bond_features_path is not None:
        args.bond_features_size = data.bond_features_size()
        set_extra_bond_fdim(args.bond_features_size)

    debug(f'Number of tasks = {args.num_tasks}')

    if args.target_weights is not None and len(args.target_weights) != args.num_tasks:
        raise ValueError('The number of provided target weights must match the number and order of the prediction tasks')

    # Run training on different random seeds for each fold
    all_scores = defaultdict(list)
    for fold_num in range(args.num_folds):
        info(f'Fold {fold_num}')
        args.seed = init_seed + fold_num
        args.save_dir = os.path.join(save_dir, f'fold_{fold_num}')
        makedirs(args.save_dir)
        data.reset_features_and_targets()

        # If resuming experiment, load results from trained models
        test_scores_path = os.path.join(args.save_dir, 'test_scores.json')
        if args.resume_experiment and os.path.exists(test_scores_path):
            print('Loading scores')
            with open(test_scores_path) as f:
                model_scores = json.load(f)
        # Otherwise, train the models
        else:
            model_scores = train_func(args, data, logger)

        for metric, scores in model_scores.items():
            all_scores[metric].append(scores)

    # Convert scores to numpy arrays, one row per fold
    all_scores = {metric: np.array(scores) for metric, scores in all_scores.items()}

    # Detect NaN scores directly from the score arrays so that the warning below is
    # emitted whether or not individual task scores are being reported.
    contains_nan_scores = any(np.isnan(scores).any() for scores in all_scores.values())

    # Report results
    info(f'{args.num_folds}-fold cross validation')

    # Report scores for each fold
    for fold_num in range(args.num_folds):
        for metric, scores in all_scores.items():
            info(f'\tSeed {init_seed + fold_num} ==> test {metric} = {multitask_mean(scores[fold_num], metric):.6f}')

            if args.show_individual_scores:
                for task_name, score in zip(args.task_names, scores[fold_num]):
                    info(f'\t\tSeed {init_seed + fold_num} ==> test {task_name} {metric} = {score:.6f}')

    # Report scores across folds
    for metric, scores in all_scores.items():
        avg_scores = multitask_mean(scores, axis=1, metric=metric)  # average score for each model across tasks
        mean_score, std_score = np.mean(avg_scores), np.std(avg_scores)
        info(f'Overall test {metric} = {mean_score:.6f} +/- {std_score:.6f}')

        if args.show_individual_scores:
            for task_num, task_name in enumerate(args.task_names):
                info(f'\tOverall test {task_name} {metric} = '
                     f'{np.mean(scores[:, task_num]):.6f} +/- {np.std(scores[:, task_num]):.6f}')

    if contains_nan_scores:
        info("The metric scores observed for some fold test splits contain 'nan' values. \
            This can occur when the test set does not meet the requirements \
            for a particular metric, such as having no valid instances of one \
            task in the test set or not having positive examples for some classification metrics. \
            Before v1.5.1, the default behavior was to ignore nan values in individual folds or tasks \
            and still return an overall average for the remaining folds or tasks. The behavior now \
            is to include them in the average, converting overall average metrics to 'nan' as well.")

    # Save scores. newline='' is required by the csv module so that it controls the
    # line endings itself (otherwise rows are separated by blank lines on Windows).
    with open(os.path.join(save_dir, TEST_SCORES_FILE_NAME), 'w', newline='') as f:
        writer = csv.writer(f)

        header = ['Task']
        for metric in args.metrics:
            header += [f'Mean {metric}', f'Standard deviation {metric}'] + \
                      [f'Fold {i} {metric}' for i in range(args.num_folds)]
        writer.writerow(header)

        if args.dataset_type == 'spectra':  # spectra data type has only one score to report
            row = ['spectra']
            for metric, scores in all_scores.items():
                task_scores = scores[:, 0]
                mean, std = np.mean(task_scores), np.std(task_scores)
                row += [mean, std] + task_scores.tolist()
            writer.writerow(row)
        else:  # all other data types, separate scores by task
            for task_num, task_name in enumerate(args.task_names):
                row = [task_name]
                for metric, scores in all_scores.items():
                    task_scores = scores[:, task_num]
                    mean, std = np.mean(task_scores), np.std(task_scores)
                    row += [mean, std] + task_scores.tolist()
                writer.writerow(row)

    # Determine mean and std score of main metric
    avg_scores = multitask_mean(all_scores[args.metric], metric=args.metric, axis=1)
    mean_score, std_score = np.mean(avg_scores), np.std(avg_scores)

    # Optionally merge and save test preds
    if args.save_preds:
        all_preds = pd.concat([pd.read_csv(os.path.join(save_dir, f'fold_{fold_num}', 'test_preds.csv'))
                               for fold_num in range(args.num_folds)])
        all_preds.to_csv(os.path.join(save_dir, 'test_preds.csv'), index=False)

    return mean_score, std_score
Beispiel #6
0
def cross_validate(
    args: TrainArgs, train_func: Callable[[TrainArgs, MoleculeDataset, Logger],
                                          Dict[str, List[float]]]
) -> Tuple[float, float]:
    """
    Runs k-fold cross-validation.

    For each of k splits (folds) of the data, trains and tests a model on that split
    and aggregates the performance across folds. NaN scores are ignored when
    averaging (the nan-aware numpy reductions are used throughout).

    Note that :code:`args` is modified in place: the task names and feature sizes are
    filled in, and :code:`args.seed` and :code:`args.save_dir` are left at the values
    used for the last fold.

    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for
                 loading data and training the Chemprop model.
    :param train_func: Function which runs training.
    :return: A tuple containing the mean and standard deviation performance across folds.
    """
    logger = create_logger(name=TRAIN_LOGGER_NAME,
                           save_dir=args.save_dir,
                           quiet=args.quiet)
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Initialize relevant variables
    init_seed = args.seed
    save_dir = args.save_dir
    args.task_names = get_task_names(path=args.data_path,
                                     smiles_columns=args.smiles_columns,
                                     target_columns=args.target_columns,
                                     ignore_columns=args.ignore_columns)

    # Print command line
    debug('Command line')
    debug(f'python {" ".join(sys.argv)}')

    # Print args
    debug('Args')
    debug(args)

    # Save args
    makedirs(args.save_dir)
    args.save(os.path.join(args.save_dir, 'args.json'))

    # Set explicit H option and reaction option
    set_explicit_h(args.explicit_h)
    set_reaction(args.reaction, args.reaction_mode)

    # Get data
    debug('Loading data')
    data = get_data(path=args.data_path,
                    args=args,
                    smiles_columns=args.smiles_columns,
                    logger=logger,
                    skip_none_targets=True)
    validate_dataset_type(data, dataset_type=args.dataset_type)
    args.features_size = data.features_size()

    if args.atom_descriptors == 'descriptor':
        args.atom_descriptors_size = data.atom_descriptors_size()
        args.ffn_hidden_size += args.atom_descriptors_size
    elif args.atom_descriptors == 'feature':
        args.atom_features_size = data.atom_features_size()
        set_extra_atom_fdim(args.atom_features_size)
    if args.bond_features_path is not None:
        args.bond_features_size = data.bond_features_size()
        set_extra_bond_fdim(args.bond_features_size)

    debug(f'Number of tasks = {args.num_tasks}')

    # Run training on different random seeds for each fold
    all_scores = defaultdict(list)
    for fold_num in range(args.num_folds):
        info(f'Fold {fold_num}')
        args.seed = init_seed + fold_num
        args.save_dir = os.path.join(save_dir, f'fold_{fold_num}')
        makedirs(args.save_dir)
        data.reset_features_and_targets()

        # If resuming experiment, load results from trained models
        test_scores_path = os.path.join(args.save_dir, 'test_scores.json')
        if args.resume_experiment and os.path.exists(test_scores_path):
            print('Loading scores')
            with open(test_scores_path) as f:
                model_scores = json.load(f)
        # Otherwise, train the models
        else:
            model_scores = train_func(args, data, logger)

        for metric, scores in model_scores.items():
            all_scores[metric].append(scores)

    # Convert scores to numpy arrays, one row per fold
    all_scores = {
        metric: np.array(scores)
        for metric, scores in all_scores.items()
    }

    # Report results
    info(f'{args.num_folds}-fold cross validation')

    # Report scores for each fold
    for fold_num in range(args.num_folds):
        for metric, scores in all_scores.items():
            info(
                f'\tSeed {init_seed + fold_num} ==> test {metric} = {np.nanmean(scores[fold_num]):.6f}'
            )

            if args.show_individual_scores:
                for task_name, score in zip(args.task_names, scores[fold_num]):
                    info(
                        f'\t\tSeed {init_seed + fold_num} ==> test {task_name} {metric} = {score:.6f}'
                    )

    # Report scores across folds
    for metric, scores in all_scores.items():
        avg_scores = np.nanmean(
            scores, axis=1)  # average score for each model across tasks
        mean_score, std_score = np.nanmean(avg_scores), np.nanstd(avg_scores)
        info(f'Overall test {metric} = {mean_score:.6f} +/- {std_score:.6f}')

        if args.show_individual_scores:
            for task_num, task_name in enumerate(args.task_names):
                info(
                    f'\tOverall test {task_name} {metric} = '
                    f'{np.nanmean(scores[:, task_num]):.6f} +/- {np.nanstd(scores[:, task_num]):.6f}'
                )

    # Save scores. newline='' is required by the csv module so that it controls the
    # line endings itself (otherwise rows are separated by blank lines on Windows).
    with open(os.path.join(save_dir, TEST_SCORES_FILE_NAME), 'w',
              newline='') as f:
        writer = csv.writer(f)

        header = ['Task']
        for metric in args.metrics:
            header += [f'Mean {metric}', f'Standard deviation {metric}'] + \
                      [f'Fold {i} {metric}' for i in range(args.num_folds)]
        writer.writerow(header)

        for task_num, task_name in enumerate(args.task_names):
            row = [task_name]
            for metric, scores in all_scores.items():
                task_scores = scores[:, task_num]
                mean, std = np.nanmean(task_scores), np.nanstd(task_scores)
                row += [mean, std] + task_scores.tolist()
            writer.writerow(row)

    # Determine mean and std score of main metric
    avg_scores = np.nanmean(all_scores[args.metric], axis=1)
    mean_score, std_score = np.nanmean(avg_scores), np.nanstd(avg_scores)

    # Optionally merge and save test preds
    if args.save_preds:
        all_preds = pd.concat([
            pd.read_csv(
                os.path.join(save_dir, f'fold_{fold_num}', 'test_preds.csv'))
            for fold_num in range(args.num_folds)
        ])
        all_preds.to_csv(os.path.join(save_dir, 'test_preds.csv'), index=False)

    return mean_score, std_score