Example #1
    def __call__(self, smiles: List[str], batch_size: int = 500) -> List[List[float]]:
        """
        Makes predictions on a list of SMILES.

        :param smiles: A list of SMILES to make predictions on.
        :param batch_size: The batch size.
        :return: A list of lists of floats containing the predicted values.
        """
        test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False, features_generator=self.args.features_generator)
        valid_indices = [i for i in range(len(test_data)) if test_data[i].mol is not None]
        test_data = MoleculeDataset([test_data[i] for i in valid_indices])

        if self.train_args.features_scaling:
            test_data.normalize_features(self.features_scaler)

        test_data_loader = MoleculeDataLoader(dataset=test_data, batch_size=batch_size)

        sum_preds = []
        for model in self.checkpoints:
            model_preds = predict(
                model=model,
                data_loader=test_data_loader,
                scaler=self.scaler,
                disable_progress_bar=True
            )
            sum_preds.append(np.array(model_preds))

        # Ensemble predictions: average over all checkpoints
        sum_preds = sum(sum_preds)
        avg_preds = sum_preds / len(self.checkpoints)

        return avg_preds.tolist()  # List[List[float]], matching the docstring
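For intuition, here is a minimal, self-contained sketch (plain NumPy, toy numbers) of the ensemble averaging performed above: each checkpoint contributes a (num_molecules x num_tasks) array, and the final prediction is their element-wise mean.

import numpy as np

# Toy stand-ins for the per-checkpoint predictions collected in sum_preds above.
model_preds = [np.array([[0.25, 1.0], [0.5, 2.0]]),   # checkpoint 1
               np.array([[0.75, 3.0], [1.0, 4.0]])]   # checkpoint 2

# sum() over a list of arrays adds them element-wise; dividing by the number of
# checkpoints gives the ensemble mean (equivalent to np.mean(model_preds, axis=0)).
avg_preds = sum(model_preds) / len(model_preds)
print(avg_preds.tolist())  # [[0.5, 2.0], [0.75, 3.0]]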
Example #2
def load_data(args: PredictArgs, smiles: List[List[str]]):
    """
    Function to load data from a list of smiles or a file.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param smiles: A list of lists of SMILES, or None if the data is to be read from a file.
    :return: A tuple of a :class:`~chemprop.data.MoleculeDataset` containing all datapoints, a :class:`~chemprop.data.MoleculeDataset`
                 containing only the valid datapoints, a :class:`~chemprop.data.MoleculeDataLoader`, and a dictionary mapping full to valid indices.
    """
    print("Loading data")
    if smiles is not None:
        full_data = get_data_from_smiles(
            smiles=smiles,
            skip_invalid_smiles=False,
            features_generator=args.features_generator,
        )
    else:
        full_data = get_data(
            path=args.test_path,
            smiles_columns=args.smiles_columns,
            target_columns=[],
            ignore_columns=[],
            skip_invalid_smiles=False,
            args=args,
            store_row=not args.drop_extra_columns,
        )

    print("Validating SMILES")
    full_to_valid_indices = {}
    valid_index = 0
    for full_index in range(len(full_data)):
        if all(mol is not None for mol in full_data[full_index].mol):
            full_to_valid_indices[full_index] = valid_index
            valid_index += 1

    test_data = MoleculeDataset(
        [full_data[i] for i in sorted(full_to_valid_indices.keys())])

    print(f"Test size = {len(test_data):,}")

    # Create data loader
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=args.num_workers)

    return full_data, test_data, test_data_loader, full_to_valid_indices
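The full-to-valid index bookkeeping above is what lets predictions, which are computed only for valid molecules, be written back in the original row order. A minimal sketch with plain lists standing in for MoleculeDatapoints (None marks a SMILES that failed to parse):

full_data = ['CCO', None, 'c1ccccc1']   # index 1 failed to parse
full_to_valid_indices = {}
valid_index = 0
for full_index, mol in enumerate(full_data):
    if mol is not None:
        full_to_valid_indices[full_index] = valid_index
        valid_index += 1

print(full_to_valid_indices)  # {0: 0, 2: 1} -- only these rows get real predictions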
Example #3
def featurize_file(input_df, output_path, pretrained_model):
    """Computes molecule vectors for the SMILES in the first column of input_df and appends them to a CSV at output_path."""
    smiles_list = input_df[input_df.columns[0]].tolist()
    print(f"Featurizing {len(smiles_list)} molecules")
    data = get_data_from_smiles(smiles=[[smiles] for smiles in smiles_list])
    print("Starting molecule vector computation...")
    descriptors = compute_molecule_vectors(model=pretrained_model,
                                           data=data,
                                           batch_size=64)
    print("Computation finished, saving result...")
    smiles_descriptors_dict = {
        'smiles': smiles_list,
        'descriptors': descriptors
    }
    output_df = pd.DataFrame(smiles_descriptors_dict)
    output_df.to_csv(output_path,
                     mode='a+',
                     header=not os.path.exists(output_path),
                     encoding="ascii",
                     index=False)
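A hedged usage sketch for the helper above; the file names are placeholders, and the checkpoint is loaded with chemprop's load_checkpoint (device defaults to CPU), which returns the trained MoleculeModel that compute_molecule_vectors expects.

import pandas as pd
from chemprop.utils import load_checkpoint

input_df = pd.read_csv('molecules.csv')           # SMILES expected in the first column
model = load_checkpoint('model_checkpoint.pt')    # trained chemprop MoleculeModel
featurize_file(input_df, output_path='descriptors.csv', pretrained_model=model)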
Example #4
def make_predictions(
        args: PredictArgs,
        smiles: List[List[str]] = None) -> List[List[Optional[float]]]:
    """
    Loads data and a trained model and uses the model to make predictions on the data.

    If SMILES are provided, then makes predictions on those SMILES.
    Otherwise makes predictions on the data at :code:`args.test_path`.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param smiles: List of list of SMILES to make predictions on.
    :return: A list of lists of target predictions.
    """
    print('Loading training args')
    train_args = load_args(args.checkpoint_paths[0])
    num_tasks, task_names = train_args.num_tasks, train_args.task_names

    # If features were used during training, they must be used when predicting
    if ((train_args.features_path is not None
         or train_args.features_generator is not None)
            and args.features_path is None
            and args.features_generator is None):
        raise ValueError(
            'Features were used during training so they must be specified again during prediction '
            'using the same type of features as before (with either --features_generator or '
            '--features_path and using --no_features_scaling if applicable).')

    # Update predict args with training arguments to create a merged args object
    for key, value in vars(train_args).items():
        if not hasattr(args, key):
            setattr(args, key, value)
    args: Union[PredictArgs, TrainArgs]

    print('Loading data')
    if smiles is not None:
        full_data = get_data_from_smiles(
            smiles=smiles,
            skip_invalid_smiles=False,
            features_generator=args.features_generator)
    else:
        full_data = get_data(path=args.test_path,
                             target_columns=[],
                             ignore_columns=[],
                             skip_invalid_smiles=False,
                             args=args,
                             store_row=True)

    print('Validating SMILES')
    full_to_valid_indices = {}
    valid_index = 0
    for full_index in range(len(full_data)):
        if all(mol is not None for mol in full_data[full_index].mol):
            full_to_valid_indices[full_index] = valid_index
            valid_index += 1

    test_data = MoleculeDataset(
        [full_data[i] for i in sorted(full_to_valid_indices.keys())])

    # Edge case if empty list of smiles is provided
    if len(test_data) == 0:
        return [None] * len(full_data)

    print(f'Test size = {len(test_data):,}')

    # Predict with each model individually and sum predictions
    if args.dataset_type == 'multiclass':
        sum_preds = np.zeros(
            (len(test_data), num_tasks, args.multiclass_num_classes))
    else:
        sum_preds = np.zeros((len(test_data), num_tasks))

    # Create data loader
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=args.num_workers)

    print(
        f'Predicting with an ensemble of {len(args.checkpoint_paths)} models')
    for checkpoint_path in tqdm(args.checkpoint_paths,
                                total=len(args.checkpoint_paths)):
        # Load model and scalers
        model = load_checkpoint(checkpoint_path, device=args.device)
        scaler, features_scaler = load_scalers(checkpoint_path)

        # Normalize features
        if args.features_scaling:
            test_data.reset_features_and_targets()
            test_data.normalize_features(features_scaler)

        # Make predictions
        model_preds = predict(model=model,
                              data_loader=test_data_loader,
                              scaler=scaler)
        sum_preds += np.array(model_preds)

    # Ensemble predictions
    avg_preds = sum_preds / len(args.checkpoint_paths)
    avg_preds = avg_preds.tolist()

    # Save predictions
    print(f'Saving predictions to {args.preds_path}')
    assert len(test_data) == len(avg_preds)
    makedirs(args.preds_path, isfile=True)

    # Get prediction column names (expand one column per class for multiclass)
    if args.dataset_type == 'multiclass':
        task_names = [
            f'{name}_class_{i}' for name in task_names
            for i in range(args.multiclass_num_classes)
        ]

    # Copy predictions over to full_data
    for full_index, datapoint in enumerate(full_data):
        valid_index = full_to_valid_indices.get(full_index, None)
        preds = avg_preds[valid_index] if valid_index is not None else [
            'Invalid SMILES'
        ] * len(task_names)

        for pred_name, pred in zip(task_names, preds):
            datapoint.row[pred_name] = pred

    # Save
    with open(args.preds_path, 'w') as f:
        writer = csv.DictWriter(f, fieldnames=full_data[0].row.keys())
        writer.writeheader()

        for datapoint in full_data:
            writer.writerow(datapoint.row)

    return avg_preds
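For context, a minimal sketch of how this function is typically driven through chemprop's Python API (the paths are placeholders): PredictArgs is parsed from CLI-style flags and passed to make_predictions, optionally together with in-memory SMILES, in which case args.test_path is ignored by the data-loading branch above.

from chemprop.args import PredictArgs
from chemprop.train import make_predictions

args = PredictArgs().parse_args([
    '--test_path', 'data/test.csv',        # read when no SMILES are passed in
    '--checkpoint_dir', 'checkpoints',     # every model found here joins the ensemble
    '--preds_path', 'preds.csv',
])

# File-based prediction:
preds = make_predictions(args)

# Or predict on in-memory SMILES (one inner list per datapoint):
preds = make_predictions(args, smiles=[['CCO'], ['c1ccccc1']])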
Example #5
def make_predictions(
        args: PredictArgs,
        smiles: List[List[str]] = None) -> List[List[Optional[float]]]:
    """
    Loads data and a trained model and uses the model to make predictions on the data.

    If SMILES are provided, then makes predictions on those SMILES.
    Otherwise makes predictions on the data at :code:`args.test_path`.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param smiles: List of list of SMILES to make predictions on.
    :return: A list of lists of target predictions.
    """
    print("Loading training args")
    train_args = load_args(args.checkpoint_paths[0])
    num_tasks, task_names = train_args.num_tasks, train_args.task_names

    update_prediction_args(predict_args=args, train_args=train_args)
    args: Union[PredictArgs, TrainArgs]

    if args.atom_descriptors == "feature":
        set_extra_atom_fdim(train_args.atom_features_size)

    if args.bond_features_path is not None:
        set_extra_bond_fdim(train_args.bond_features_size)

    # set explicit H option and reaction option
    set_explicit_h(train_args.explicit_h)
    set_reaction(train_args.reaction, train_args.reaction_mode)

    print("Loading data")
    if smiles is not None:
        full_data = get_data_from_smiles(
            smiles=smiles,
            skip_invalid_smiles=False,
            features_generator=args.features_generator,
        )
    else:
        full_data = get_data(
            path=args.test_path,
            smiles_columns=args.smiles_columns,
            target_columns=[],
            ignore_columns=[],
            skip_invalid_smiles=False,
            args=args,
            store_row=not args.drop_extra_columns,
        )

    print("Validating SMILES")
    full_to_valid_indices = {}
    valid_index = 0
    for full_index in range(len(full_data)):
        if all(mol is not None for mol in full_data[full_index].mol):
            full_to_valid_indices[full_index] = valid_index
            valid_index += 1

    test_data = MoleculeDataset(
        [full_data[i] for i in sorted(full_to_valid_indices.keys())])

    # Edge case if empty list of smiles is provided
    if len(test_data) == 0:
        return [None] * len(full_data)

    print(f"Test size = {len(test_data):,}")

    # Predict with each model individually and sum predictions
    if args.dataset_type == "multiclass":
        sum_preds = np.zeros(
            (len(test_data), num_tasks, args.multiclass_num_classes))
    else:
        sum_preds = np.zeros((len(test_data), num_tasks))

    # Create data loader
    test_data_loader = MoleculeDataLoader(
        dataset=test_data,
        batch_size=args.batch_size,
        num_workers=0 if sys.platform == "darwin" else args.num_workers,
    )

    # Keep per-model predictions for the ensemble variance calculation.
    if args.ensemble_variance:
        all_preds = np.zeros(
            (len(test_data), num_tasks, len(args.checkpoint_paths)))

    print(
        f"Predicting with an ensemble of {len(args.checkpoint_paths)} models")
    for index, checkpoint_path in enumerate(
            tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths))):
        # Load model and scalers
        model = load_checkpoint(checkpoint_path, device=args.device)
        (
            scaler,
            features_scaler,
            atom_descriptor_scaler,
            bond_feature_scaler,
        ) = load_scalers(checkpoint_path)

        # Normalize features
        if (args.features_scaling or train_args.atom_descriptor_scaling
                or train_args.bond_feature_scaling):
            test_data.reset_features_and_targets()
            if args.features_scaling:
                test_data.normalize_features(features_scaler)
            if (train_args.atom_descriptor_scaling
                    and args.atom_descriptors is not None):
                test_data.normalize_features(atom_descriptor_scaler,
                                             scale_atom_descriptors=True)
            if train_args.bond_feature_scaling and args.bond_features_size > 0:
                test_data.normalize_features(bond_feature_scaler,
                                             scale_bond_features=True)

        # Make predictions
        model_preds = predict(model=model,
                              data_loader=test_data_loader,
                              scaler=scaler)
        sum_preds += np.array(model_preds)
        if args.ensemble_variance:
            all_preds[:, :, index] = model_preds

    # Ensemble predictions
    avg_preds = sum_preds / len(args.checkpoint_paths)
    avg_preds = avg_preds.tolist()

    if args.ensemble_variance:
        all_epi_uncs = np.var(all_preds, axis=2)
        all_epi_uncs = all_epi_uncs.tolist()

    # Save predictions
    print(f"Saving predictions to {args.preds_path}")
    assert len(test_data) == len(avg_preds)
    if args.ensemble_variance:
        assert len(test_data) == len(all_epi_uncs)
    makedirs(args.preds_path, isfile=True)

    # Get prediction column names (expand one column per class for multiclass)
    if args.dataset_type == "multiclass":
        task_names = [
            f"{name}_class_{i}" for name in task_names
            for i in range(args.multiclass_num_classes)
        ]

    # Copy predictions over to full_data
    for full_index, datapoint in enumerate(full_data):
        valid_index = full_to_valid_indices.get(full_index, None)
        preds = (avg_preds[valid_index] if valid_index is not None else
                 ["Invalid SMILES"] * len(task_names))
        if args.ensemble_variance:
            epi_uncs = (all_epi_uncs[valid_index] if valid_index is not None
                        else ["Invalid SMILES"] * len(task_names))

        # If extra columns have been dropped, add back in SMILES columns
        if args.drop_extra_columns:
            datapoint.row = OrderedDict()

            smiles_columns = args.smiles_columns

            for column, smiles in zip(smiles_columns, datapoint.smiles):
                datapoint.row[column] = smiles

        # Add predictions columns
        if args.ensemble_variance:
            for pred_name, pred, epi_unc in zip(task_names, preds, epi_uncs):
                datapoint.row[pred_name] = pred
                datapoint.row[pred_name + "_epi_unc"] = epi_unc
        else:
            for pred_name, pred in zip(task_names, preds):
                datapoint.row[pred_name] = pred

    # Save
    with open(args.preds_path, "w") as f:
        writer = csv.DictWriter(f, fieldnames=full_data[0].row.keys())
        writer.writeheader()

        for datapoint in full_data:
            writer.writerow(datapoint.row)

    return avg_preds
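A sketch of the ensemble-variance path added in this version (the flag and column names are taken from the code above; paths are placeholders): with ensemble_variance enabled, each task additionally gets a "<task>_epi_unc" column holding the variance of the per-model predictions.

from chemprop.args import PredictArgs
from chemprop.train import make_predictions

args = PredictArgs().parse_args([
    '--test_path', 'data/test.csv',
    '--checkpoint_dir', 'checkpoints',     # several checkpoints are needed for a meaningful variance
    '--preds_path', 'preds_with_unc.csv',
    '--ensemble_variance',                 # assumes this PredictArgs version exposes the flag read above
])
avg_preds = make_predictions(args)         # the epistemic uncertainties are written only to the CSV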
Example #6
def molecule_fingerprint(
        args: FingerprintArgs,
        smiles: List[List[str]] = None) -> List[List[Optional[float]]]:
    """
    Loads data and a trained model and uses the model to encode fingerprint vectors for the data.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param smiles: List of list of SMILES to make predictions on.
    :return: A list of fingerprint vectors (list of floats)
    """

    print('Loading training args')
    train_args = load_args(args.checkpoint_paths[0])

    # Update args with training arguments. Input feature sources only need to be
    # validated when using the FFN latent representation and the model calls for them.
    validate_feature_sources = args.fingerprint_type != 'MPN'
    update_prediction_args(predict_args=args,
                           train_args=train_args,
                           validate_feature_sources=validate_feature_sources)
    args: Union[FingerprintArgs, TrainArgs]

    # Set explicit H option and reaction option
    reset_featurization_parameters()
    if args.atom_descriptors == 'feature':
        set_extra_atom_fdim(train_args.atom_features_size)

    if args.bond_features_path is not None:
        set_extra_bond_fdim(train_args.bond_features_size)

    set_explicit_h(train_args.explicit_h)
    set_adding_hs(args.adding_h)
    if train_args.reaction:
        set_reaction(train_args.reaction, train_args.reaction_mode)
    elif train_args.reaction_solvent:
        set_reaction(True, train_args.reaction_mode)

    print('Loading data')
    if smiles is not None:
        full_data = get_data_from_smiles(
            smiles=smiles,
            skip_invalid_smiles=False,
            features_generator=args.features_generator)
    else:
        full_data = get_data(path=args.test_path,
                             smiles_columns=args.smiles_columns,
                             target_columns=[],
                             ignore_columns=[],
                             skip_invalid_smiles=False,
                             args=args,
                             store_row=True)

    print('Validating SMILES')
    full_to_valid_indices = {}
    valid_index = 0
    for full_index in range(len(full_data)):
        if all(mol is not None for mol in full_data[full_index].mol):
            full_to_valid_indices[full_index] = valid_index
            valid_index += 1

    test_data = MoleculeDataset(
        [full_data[i] for i in sorted(full_to_valid_indices.keys())])

    # Edge case if empty list of smiles is provided
    if len(test_data) == 0:
        return [None] * len(full_data)

    print(f'Test size = {len(test_data):,}')

    # Create data loader
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=args.num_workers)

    # Set fingerprint size
    if args.fingerprint_type == 'MPN':
        if args.atom_descriptors == "descriptor":  # special case when we have 'descriptor' extra dimensions need to be added
            total_fp_size = (
                args.hidden_size +
                test_data.atom_descriptors_size()) * args.number_of_molecules
        else:
            if args.reaction_solvent:
                total_fp_size = args.hidden_size + args.hidden_size_solvent
            else:
                total_fp_size = args.hidden_size * args.number_of_molecules
        if args.features_only:
            raise ValueError(
                'With features_only models, there is no latent MPN representation. Use last_FFN fingerprint type instead.'
            )
    elif args.fingerprint_type == 'last_FFN':
        if args.ffn_num_layers != 1:
            total_fp_size = args.ffn_hidden_size
        else:
            raise ValueError(
                'With a ffn_num_layers of 1, there is no latent FFN representation. Use MPN fingerprint type instead.'
            )
    else:
        raise ValueError(
            f'Fingerprint type {args.fingerprint_type} not supported')
    all_fingerprints = np.zeros(
        (len(test_data), total_fp_size, len(args.checkpoint_paths)))

    # Load model
    print(
        f'Encoding smiles into a fingerprint vector from {len(args.checkpoint_paths)} models.'
    )

    for index, checkpoint_path in enumerate(
            tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths))):
        model = load_checkpoint(checkpoint_path, device=args.device)
        scaler, features_scaler, atom_descriptor_scaler, bond_feature_scaler = load_scalers(
            args.checkpoint_paths[index])

        # Normalize features
        if args.features_scaling or train_args.atom_descriptor_scaling or train_args.bond_feature_scaling:
            test_data.reset_features_and_targets()
            if args.features_scaling:
                test_data.normalize_features(features_scaler)
            if train_args.atom_descriptor_scaling and args.atom_descriptors is not None:
                test_data.normalize_features(atom_descriptor_scaler,
                                             scale_atom_descriptors=True)
            if train_args.bond_feature_scaling and args.bond_features_size > 0:
                test_data.normalize_features(bond_feature_scaler,
                                             scale_bond_features=True)

        # Make fingerprints
        model_fp = model_fingerprint(model=model,
                                     data_loader=test_data_loader,
                                     fingerprint_type=args.fingerprint_type)
        if args.fingerprint_type == 'MPN' and (
                args.features_path is not None or args.features_generator
        ):  # truncate any features from MPN fingerprint
            model_fp = np.array(model_fp)[:, :total_fp_size]
        all_fingerprints[:, :, index] = model_fp

    # Save predictions
    print(f'Saving predictions to {args.preds_path}')
    # assert len(test_data) == len(all_fingerprints) #TODO: add unit test for this
    makedirs(args.preds_path, isfile=True)

    # Set column names
    fingerprint_columns = []
    if args.fingerprint_type == 'MPN':
        if len(args.checkpoint_paths) == 1:
            for j in range(total_fp_size // args.number_of_molecules):
                for k in range(args.number_of_molecules):
                    fingerprint_columns.append(f'fp_{j}_mol_{k}')
        else:
            for j in range(total_fp_size // args.number_of_molecules):
                for i in range(len(args.checkpoint_paths)):
                    for k in range(args.number_of_molecules):
                        fingerprint_columns.append(f'fp_{j}_mol_{k}_model_{i}')

    else:  # args.fingerprint_type == 'last_FFN'
        if len(args.checkpoint_paths) == 1:
            for j in range(total_fp_size):
                fingerprint_columns.append(f'fp_{j}')
        else:
            for j in range(total_fp_size):
                for i in range(len(args.checkpoint_paths)):
                    fingerprint_columns.append(f'fp_{j}_model_{i}')

    # Copy predictions over to full_data
    for full_index, datapoint in enumerate(full_data):
        valid_index = full_to_valid_indices.get(full_index, None)
        preds = all_fingerprints[valid_index].reshape(
            (len(args.checkpoint_paths) * total_fp_size
             )) if valid_index is not None else ['Invalid SMILES'] * len(
                 args.checkpoint_paths) * total_fp_size

        for i in range(len(fingerprint_columns)):
            datapoint.row[fingerprint_columns[i]] = preds[i]

    # Write predictions
    with open(args.preds_path, 'w') as f:
        writer = csv.DictWriter(f,
                                fieldnames=args.smiles_columns +
                                fingerprint_columns,
                                extrasaction='ignore')
        writer.writeheader()
        for datapoint in full_data:
            writer.writerow(datapoint.row)

    return all_fingerprints
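A hedged usage sketch for this fingerprint entry point (paths are placeholders, assuming FingerprintArgs and molecule_fingerprint are exposed as in chemprop v1): fingerprint_type selects between the MPN latent representation and the last FFN hidden layer.

from chemprop.args import FingerprintArgs
from chemprop.train import molecule_fingerprint

args = FingerprintArgs().parse_args([
    '--test_path', 'data/test.csv',
    '--checkpoint_dir', 'checkpoints',
    '--preds_path', 'fingerprints.csv',
    '--fingerprint_type', 'MPN',           # or 'last_FFN' for the final hidden layer
])
fps = molecule_fingerprint(args)           # array of shape (num_molecules, fp_size, num_models)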
Example #7
def molecule_fingerprint(
        args: PredictArgs,
        smiles: List[List[str]] = None) -> List[List[Optional[float]]]:
    """
    Loads data and a trained model and uses the model to encode fingerprint vectors for the data.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param smiles: List of list of SMILES to make predictions on.
    :return: A list of fingerprint vectors (list of floats)
    """

    print('Loading training args')
    train_args = load_args(args.checkpoint_paths[0])

    # Update args with training arguments
    update_prediction_args(predict_args=args,
                           train_args=train_args,
                           validate_feature_sources=False)
    args: Union[PredictArgs, TrainArgs]

    # Set explicit H option and reaction option
    set_explicit_h(train_args.explicit_h)
    set_reaction(train_args.reaction, train_args.reaction_mode)

    print('Loading data')
    if smiles is not None:
        full_data = get_data_from_smiles(
            smiles=smiles,
            skip_invalid_smiles=False,
            features_generator=args.features_generator)
    else:
        full_data = get_data(path=args.test_path,
                             smiles_columns=args.smiles_columns,
                             target_columns=[],
                             ignore_columns=[],
                             skip_invalid_smiles=False,
                             args=args,
                             store_row=True)

    print('Validating SMILES')
    full_to_valid_indices = {}
    valid_index = 0
    for full_index in range(len(full_data)):
        if all(mol is not None for mol in full_data[full_index].mol):
            full_to_valid_indices[full_index] = valid_index
            valid_index += 1

    test_data = MoleculeDataset(
        [full_data[i] for i in sorted(full_to_valid_indices.keys())])

    # Edge case if empty list of smiles is provided
    if len(test_data) == 0:
        return [None] * len(full_data)

    print(f'Test size = {len(test_data):,}')

    # Create data loader
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=args.num_workers)

    # Load model
    print('Encoding SMILES into a fingerprint vector from a single model')
    if len(args.checkpoint_paths) != 1:
        raise ValueError(
            "Fingerprint generation only supports one model, cannot use an ensemble"
        )

    model = load_checkpoint(args.checkpoint_paths[0], device=args.device)
    scaler, features_scaler, atom_descriptor_scaler, bond_feature_scaler = load_scalers(
        args.checkpoint_paths[0])

    # Normalize features
    if args.features_scaling or train_args.atom_descriptor_scaling or train_args.bond_feature_scaling:
        test_data.reset_features_and_targets()
        if args.features_scaling:
            test_data.normalize_features(features_scaler)
        if train_args.atom_descriptor_scaling and args.atom_descriptors is not None:
            test_data.normalize_features(atom_descriptor_scaler,
                                         scale_atom_descriptors=True)
        if train_args.bond_feature_scaling and args.bond_features_size > 0:
            test_data.normalize_features(bond_feature_scaler,
                                         scale_bond_features=True)

    # Make fingerprints
    model_preds = model_fingerprint(model=model, data_loader=test_data_loader)

    # Save predictions
    print(f'Saving predictions to {args.preds_path}')
    assert len(test_data) == len(model_preds)
    makedirs(args.preds_path, isfile=True)

    # Copy predictions over to full_data
    total_hidden_size = args.hidden_size * args.number_of_molecules
    for full_index, datapoint in enumerate(full_data):
        valid_index = full_to_valid_indices.get(full_index, None)
        preds = model_preds[valid_index] if valid_index is not None else [
            'Invalid SMILES'
        ] * total_hidden_size

        fingerprint_columns = [f'fp_{i}' for i in range(total_hidden_size)]
        for i in range(len(fingerprint_columns)):
            datapoint.row[fingerprint_columns[i]] = preds[i]

    # Write predictions
    with open(args.preds_path, 'w') as f:
        writer = csv.DictWriter(f,
                                fieldnames=args.smiles_columns +
                                fingerprint_columns,
                                extrasaction='ignore')
        writer.writeheader()
        for datapoint in full_data:
            writer.writerow(datapoint.row)

    return model_preds
Example #8
def find_similar_mols(test_smiles: List[str],
                      train_smiles: List[str],
                      distance_measure: str,
                      model: MoleculeModel = None,
                      num_neighbors: int = None,
                      batch_size: int = 50) -> List[OrderedDict]:
    """
    For each test molecule, finds the N most similar training molecules according to some distance measure.

    :param test_smiles: A list of test SMILES strings.
    :param train_smiles: A list of train SMILES strings.
    :param distance_measure: The distance measure to use to determine nearest neighbors.
    :param model: A trained MoleculeModel (only needed for distance_measure == 'embedding').
    :param num_neighbors: The number of nearest training molecules to find for each test molecule.
    :param batch_size: Batch size.
    :return: A list of OrderedDicts containing the test SMILES, the num_neighbors nearest training SMILES,
             and other relevant distance info.
    """
    test_data = get_data_from_smiles(smiles=[[smiles]
                                             for smiles in test_smiles])
    train_data = get_data_from_smiles(smiles=[[smiles]
                                              for smiles in train_smiles])
    train_smiles_set = set(train_smiles)

    # Create data loaders (no args object is in scope here, so the loaders' default num_workers is used)
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=batch_size)
    train_data_loader = MoleculeDataLoader(dataset=train_data,
                                           batch_size=batch_size)

    print(f'Computing {distance_measure} vectors')
    if distance_measure == 'embedding':
        assert model is not None
        test_vecs = np.array(
            model_fingerprint(model=model,
                              data_loader=test_data_loader,
                              fingerprint_type='last_FFN'))
        train_vecs = np.array(
            model_fingerprint(model=model,
                              data_loader=train_data_loader,
                              fingerprint_type='last_FFN'))
        metric = 'cosine'
    elif distance_measure == 'morgan':
        test_vecs = np.array([
            morgan_binary_features_generator(smiles)
            for smiles in tqdm(test_smiles, total=len(test_smiles))
        ])
        train_vecs = np.array([
            morgan_binary_features_generator(smiles)
            for smiles in tqdm(train_smiles, total=len(train_smiles))
        ])
        metric = 'jaccard'
    elif distance_measure == 'tanimoto':
        # Generate RDKit topological fingerprints
        test_fps = [Chem.RDKFingerprint(m.mol[0]) for m in tqdm(test_data)]
        train_fps = [Chem.RDKFingerprint(m.mol[0]) for m in tqdm(train_data)]

        # Compute pairwise similarity
        print('Computing distances')
        similarity = np.zeros([len(test_fps), len(train_fps)])
        for (x, y), _ in np.ndenumerate(similarity):
            similarity[x, y] = DataStructs.FingerprintSimilarity(
                test_fps[x], train_fps[y])

        # Convert the tanimoto similarity to a distance
        distances = 1 - similarity
        metric = 'tanimoto'
    else:
        raise ValueError(
            f'Distance measure "{distance_measure}" not supported.')

    if distance_measure in ('embedding', 'morgan'):
        print('Computing distances')
        distances = cdist(test_vecs, train_vecs, metric=metric)

    print('Finding neighbors')
    neighbors = []
    for test_index, test_smile in enumerate(test_smiles):
        # Find the num_neighbors molecules in the training set which are most similar to the test molecule
        nearest_train_indices = np.argsort(
            distances[test_index])[:num_neighbors]

        # Build dictionary with distance info
        neighbor = OrderedDict()
        neighbor['test_smiles'] = test_smile
        neighbor['test_in_train'] = test_smile in train_smiles_set

        for i, train_index in enumerate(nearest_train_indices):
            neighbor[f'train_{i + 1}_smiles'] = train_smiles[train_index]
            neighbor[
                f'train_{i + 1}_{distance_measure}_{metric}_distance'] = distances[
                    test_index][train_index]

        neighbors.append(neighbor)

    return neighbors
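A small usage sketch for the helper above, using the Morgan-fingerprint distance so that no trained model is required (the SMILES are arbitrary examples, and the column names come from the OrderedDict built in the loop above).

test_smiles = ['CCO', 'CC(=O)O']
train_smiles = ['CCN', 'CCCO', 'c1ccccc1', 'CC(=O)N']

neighbors = find_similar_mols(test_smiles=test_smiles,
                              train_smiles=train_smiles,
                              distance_measure='morgan',
                              num_neighbors=2)

for row in neighbors:
    print(row['test_smiles'], row['train_1_smiles'], row['train_2_smiles'])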