Example #1
def get_class_sizes(data: MoleculeDataset) -> List[List[float]]:
    """
    Determines the proportions of the different classes in the classification dataset.

    :param data: A classification dataset
    :return: A list of lists of class proportions. Each inner list contains the class proportions
    for a task.
    """
    targets = data.targets()

    # Filter out Nones
    valid_targets = [[] for _ in range(data.num_tasks())]
    for mol_targets in targets:
        for task_num, target in enumerate(mol_targets):
            if target is not None:
                valid_targets[task_num].append(target)

    class_sizes = []
    for task_targets in valid_targets:
        # Make sure we're dealing with a binary classification task
        assert set(np.unique(task_targets)) <= {0, 1}

        try:
            ones = np.count_nonzero(task_targets) / len(task_targets)
        except ZeroDivisionError:
            ones = float('nan')
            print('Warning: class has no targets')
        class_sizes.append([1 - ones, ones])

    return class_sizes
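A minimal usage sketch for get_class_sizes. The _ToyClassificationData class below is a hypothetical stand-in, not part of the codebase; it exposes only the two MoleculeDataset methods the function needs.

from typing import List

# Hypothetical stand-in for MoleculeDataset: just targets() and num_tasks().
class _ToyClassificationData:
    def __init__(self, targets: List[List[float]]):
        self._targets = targets

    def targets(self) -> List[List[float]]:
        return self._targets

    def num_tasks(self) -> int:
        return len(self._targets[0])

data = _ToyClassificationData([[1, 0], [0, None], [1, 1], [0, 0]])
print(get_class_sizes(data))
# Task 0: 2 of 4 positives -> [0.5, 0.5]; task 1: 1 of 3 valid targets positive -> roughly [0.667, 0.333]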
Example #2
def generate_fingerprints(args: Namespace, logger: Logger = None) -> List[List[float]]:
    """
    Generate the fingerprints.

    :param args: Arguments.
    :param logger: Logger. If None, a new one is created.
    :return: A list of fingerprint vectors, one list of floats per molecule.
    """

    checkpoint_path = args.checkpoint_paths[0]
    if logger is None:
        logger = create_logger('fingerprints', quiet=False)
    logger.info('Loading data')
    test_data = get_data(path=args.data_path,
                         args=args,
                         use_compound_names=False,
                         max_data_size=float("inf"),
                         skip_invalid_smiles=False)
    test_data = MoleculeDataset(test_data)

    logger.info(f'Total size = {len(test_data):,}')
    logger.info('Generating...')
    # Load model
    model = load_checkpoint(checkpoint_path, cuda=args.cuda, current_args=args, logger=logger)
    model_preds = do_generate(
        model=model,
        data=test_data,
        args=args
    )

    return model_preds
Example #3
def get_data_from_smiles(smiles: List[str],
                         skip_invalid_smiles: bool = True,
                         logger: Logger = None,
                         args: Namespace = None) -> MoleculeDataset:
    """
    Converts SMILES to a MoleculeDataset.

    :param smiles: A list of SMILES strings.
    :param skip_invalid_smiles: Whether to skip and filter out invalid smiles.
    :param logger: Logger.
    :param args: Arguments, passed through to each MoleculeDatapoint.
    :return: A MoleculeDataset with all of the provided SMILES.
    """
    debug = logger.debug if logger is not None else print

    data = MoleculeDataset(
        [MoleculeDatapoint(line=[smile], args=args) for smile in smiles])

    # Filter out invalid SMILES
    if skip_invalid_smiles:
        original_data_len = len(data)
        data = filter_invalid_smiles(data)

        if len(data) < original_data_len:
            debug(
                f'Warning: {original_data_len - len(data)} SMILES are invalid.'
            )

    return data
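A small call sketch, assuming MoleculeDatapoint accepts args=None and that MoleculeDataset exposes smiles() as used elsewhere in these examples.

# Toy SMILES strings, for illustration only.
smiles = ['CCO', 'c1ccccc1', 'CC(=O)O']

dataset = get_data_from_smiles(smiles=smiles)
print(len(dataset))      # 3, assuming all three SMILES parse
print(dataset.smiles())  # the original SMILES strings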
Example #4
def evaluate(model: nn.Module,
             data: MoleculeDataset,
             num_tasks: int,
             metric_func,
             loss_func,
             batch_size: int,
             dataset_type: str,
             args: Namespace,
             shared_dict,
             scaler: StandardScaler = None,
             logger = None) -> Tuple[List[float], float]:
    """
    Evaluates a model on a dataset.

    :param model: A model.
    :param data: A MoleculeDataset.
    :param num_tasks: Number of tasks.
    :param metric_func: Metric function which takes in a list of targets and a list of predictions.
    :param loss_func: Loss function used to compute the average loss over the dataset.
    :param batch_size: Batch size.
    :param dataset_type: Dataset type.
    :param args: Arguments.
    :param shared_dict: A dictionary shared across prediction calls.
    :param scaler: A StandardScaler object fit on the training targets.
    :param logger: Logger.
    :return: A tuple of (scores, average loss), where scores is a list with the score for each task
    based on `metric_func`.
    """
    preds, loss_avg = predict(
        model=model,
        data=data,
        loss_func=loss_func,
        batch_size=batch_size,
        scaler=scaler,
        shared_dict=shared_dict,
        logger=logger,
        args=args
    )

    targets = data.targets()
    if scaler is not None:
        targets = scaler.inverse_transform(targets)

    results = evaluate_predictions(
        preds=preds,
        targets=targets,
        num_tasks=num_tasks,
        metric_func=metric_func,
        dataset_type=dataset_type,
        logger=logger
    )

    return results, loss_avg
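The metric_func argument is any callable that maps a list of targets and a list of predictions for one task to a single score. Below is a sketch of one compatible metric using scikit-learn; it is not the project's own metric factory.

from sklearn.metrics import roc_auc_score

# A metric_func compatible with evaluate(): per-task binary targets and
# predicted probabilities in, one score out.
def auc_metric(targets, preds):
    return roc_auc_score(targets, preds)

print(auc_metric([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75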
Example #5
def filter_invalid_smiles(data: MoleculeDataset) -> MoleculeDataset:
    """
    Filters out invalid SMILES.

    :param data: A MoleculeDataset.
    :return: A MoleculeDataset with only valid molecules.
    """
    datapoint_list = []
    for idx, datapoint in enumerate(data):
        if datapoint.smiles == '':
            print(f'invalid smiles {idx}: {datapoint.smiles}')
            continue
        mol = Chem.MolFromSmiles(datapoint.smiles)
        if mol is None:
            print(f'invalid smiles {idx}: {datapoint.smiles}')
            continue
        if mol.GetNumHeavyAtoms() == 0:
            print(f'invalid heavy {idx}')
            continue
        datapoint_list.append(datapoint)
    return MoleculeDataset(datapoint_list)
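For reference, the two failure modes checked above can be reproduced directly with RDKit: an unparsable string makes Chem.MolFromSmiles return None, while an empty string parses to a molecule with zero heavy atoms.

from rdkit import Chem

for s in ['CCO', 'not_a_smiles', '']:
    mol = Chem.MolFromSmiles(s)  # RDKit logs a parse error for the invalid string
    if mol is None or mol.GetNumHeavyAtoms() == 0:
        print(f'invalid: {s!r}')
    else:
        print(f'valid: {s!r} ({mol.GetNumHeavyAtoms()} heavy atoms)')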
Example #6
def make_predictions(args: Namespace, newest_train_args=None, smiles: List[str] = None):
    """
    Makes predictions. If smiles is provided, makes predictions on smiles.
    Otherwise makes predictions on the data at args.data_path.

    :param args: Arguments.
    :param newest_train_args: Arguments from the most recent training run, used to fill in arguments
    not already present in args.
    :param smiles: SMILES strings to make predictions on.
    :return: A tuple of (average ensemble predictions, test SMILES), or the raw model output when
    args.fingerprint is set.
    """
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    print('Loading training args')

    path = args.checkpoint_paths[0]
    scaler, features_scaler = load_scalars(path)
    train_args = load_args(path)

    # Update args with training arguments saved in checkpoint
    for key, value in vars(train_args).items():
        if not hasattr(args, key):
            setattr(args, key, value)

    # update args with newest training args
    if newest_train_args is not None:
        for key, value in vars(newest_train_args).items():
            if not hasattr(args, key):
                setattr(args, key, value)

    # Work around a multiprocessing issue
    args.debug = True

    logger = create_logger('predict', quiet=False)
    print('Loading data')
    args.task_names = get_task_names(args.data_path)
    if smiles is not None:
        test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False)
    else:
        test_data = get_data(path=args.data_path, args=args,
                             use_compound_names=args.use_compound_names, skip_invalid_smiles=False)

    args.num_tasks = test_data.num_tasks()
    args.features_size = test_data.features_size()

    print('Validating SMILES')
    valid_indices = list(range(len(test_data)))
    full_data = test_data
    test_data = MoleculeDataset([test_data[i] for i in valid_indices])

    # Edge case if empty list of smiles is provided
    if len(test_data) == 0:
        return [None] * len(full_data)

    print(f'Test size = {len(test_data):,}')

    # Normalize features
    if hasattr(train_args, 'features_scaling'):
        if train_args.features_scaling:
            test_data.normalize_features(features_scaler)

    # Predict with each model individually and sum predictions
    sum_preds = np.zeros((len(test_data), args.num_tasks))  # args.num_tasks is set above
    print('Predicting...')
    shared_dict = {}
    for checkpoint_path in tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths)):
        # Load model
        model = load_checkpoint(checkpoint_path, cuda=args.cuda, current_args=args, logger=logger)
        model_preds, _ = predict(
            model=model,
            data=test_data,
            batch_size=args.batch_size,
            scaler=scaler,
            shared_dict=shared_dict,
            args=args,
            logger=logger,
            loss_func=None
        )

        if args.fingerprint:
            return model_preds

        sum_preds += np.array(model_preds, dtype=float)

    # Ensemble predictions
    avg_preds = sum_preds / len(args.checkpoint_paths)

    # Save predictions
    assert len(test_data) == len(avg_preds)

    # Record the indices that were used (no SMILES were skipped above)
    args.valid_indices = valid_indices
    avg_preds = np.array(avg_preds)
    test_smiles = full_data.smiles()
    return avg_preds, test_smiles
Example #7
def scaffold_split(
    data: MoleculeDataset,
    sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1),
    balanced: bool = False,
    seed: int = 0,
    logger: logging.Logger = None
) -> Tuple[MoleculeDataset, MoleculeDataset, MoleculeDataset]:
    """
    Splits a dataset by scaffold so that no molecules sharing a scaffold are in different splits.

    :param data: A MoleculeDataset.
    :param sizes: A length-3 tuple with the proportions of data in the
    train, validation, and test sets.
    :param balanced: Whether to try to balance the sizes of scaffolds in each set, rather than just
    putting the smallest scaffold sets in the test set.
    :param seed: Seed for shuffling when doing balanced splitting.
    :param logger: A logger.
    :return: A tuple containing the train, validation, and test splits of the data.
    """
    assert sum(sizes) == 1

    # Split
    train_size, val_size, test_size = sizes[0] * len(data), sizes[1] * len(
        data), sizes[2] * len(data)
    train, val, test = [], [], []
    train_scaffold_count, val_scaffold_count, test_scaffold_count = 0, 0, 0

    # Map from scaffold to index in the data
    scaffold_to_indices = scaffold_to_smiles(data.smiles(), use_indices=True)

    if balanced:  # Put stuff that's bigger than half the val/test size into train, rest just order randomly
        index_sets = list(scaffold_to_indices.values())
        big_index_sets = []
        small_index_sets = []
        for index_set in index_sets:
            if len(index_set) > val_size / 2 or len(index_set) > test_size / 2:
                big_index_sets.append(index_set)
            else:
                small_index_sets.append(index_set)
        random.seed(seed)
        random.shuffle(big_index_sets)
        random.shuffle(small_index_sets)
        index_sets = big_index_sets + small_index_sets
    else:  # Sort from largest to smallest scaffold sets
        index_sets = sorted(list(scaffold_to_indices.values()),
                            key=lambda index_set: len(index_set),
                            reverse=True)

    for index_set in index_sets:
        if len(train) + len(index_set) <= train_size:
            train += index_set
            train_scaffold_count += 1
        elif len(val) + len(index_set) <= val_size:
            val += index_set
            val_scaffold_count += 1
        else:
            test += index_set
            test_scaffold_count += 1

    if logger is not None:
        logger.debug(f'Total scaffolds = {len(scaffold_to_indices):,} | '
                     f'train scaffolds = {train_scaffold_count:,} | '
                     f'val scaffolds = {val_scaffold_count:,} | '
                     f'test scaffolds = {test_scaffold_count:,}')

    log_scaffold_stats(data, index_sets, logger=logger)

    # Map from indices to data
    train = [data[i] for i in train]
    val = [data[i] for i in val]
    test = [data[i] for i in test]

    return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)
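The groups returned by scaffold_to_smiles are presumably Bemis-Murcko scaffolds; the effect can be seen directly with RDKit, where ring-free molecules collapse to an empty scaffold.

from rdkit.Chem.Scaffolds import MurckoScaffold

# Molecules sharing a Bemis-Murcko scaffold end up in the same split.
for s in ['CCc1ccccc1', 'Oc1ccccc1', 'CCCC']:
    print(s, '->', repr(MurckoScaffold.MurckoScaffoldSmiles(smiles=s)))
# The two benzene derivatives share the scaffold 'c1ccccc1'; 'CCCC' has no ring, so its scaffold is ''.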
Example #8
def split_data(
    data: MoleculeDataset,
    split_type: str = 'random',
    sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1),
    seed: int = 0,
    args: Namespace = None,
    logger: Logger = None
) -> Tuple[MoleculeDataset, MoleculeDataset, MoleculeDataset]:
    """
    Splits data into training, validation, and test splits.

    :param data: A MoleculeDataset.
    :param split_type: Split type, one of 'random', 'scaffold_balanced', 'crossval',
    'index_predetermined', or 'predetermined'.
    :param sizes: A length-3 tuple with the proportions of data in the
    train, validation, and test sets.
    :param seed: The random seed to use before shuffling data.
    :param args: Namespace of arguments.
    :param logger: A logger.
    :return: A tuple containing the train, validation, and test splits of the data.
    """
    assert len(sizes) == 3 and sum(sizes) == 1

    if args is not None:
        folds_file, val_fold_index, test_fold_index = \
            args.folds_file, args.val_fold_index, args.test_fold_index
    else:
        folds_file = val_fold_index = test_fold_index = None

    if split_type == 'crossval':
        index_set = args.crossval_index_sets[args.seed]
        data_split = []
        for split in range(3):
            split_indices = []
            for index in index_set[split]:
                with open(
                        os.path.join(args.crossval_index_dir, f'{index}.pkl'),
                        'rb') as rf:
                    split_indices.extend(pickle.load(rf))
            data_split.append([data[i] for i in split_indices])
        train, val, test = tuple(data_split)
        return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(
            test)

    elif split_type == 'index_predetermined':
        split_indices = args.crossval_index_sets[args.seed]
        assert len(split_indices) == 3
        data_split = []
        for split in range(3):
            data_split.append([data[i] for i in split_indices[split]])
        train, val, test = tuple(data_split)
        return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(
            test)

    elif split_type == 'predetermined':
        if val_fold_index is None:
            # Test set is created separately, so use all of the other data for train and val
            assert sizes[2] == 0
        assert folds_file is not None
        assert test_fold_index is not None

        try:
            with open(folds_file, 'rb') as f:
                all_fold_indices = pickle.load(f)
        except UnicodeDecodeError:
            with open(folds_file, 'rb') as f:
                all_fold_indices = pickle.load(
                    f, encoding='latin1'
                )  # in case we're loading indices from python2
        # assert len(data) == sum([len(fold_indices) for fold_indices in all_fold_indices])

        log_scaffold_stats(data, all_fold_indices, logger=logger)

        folds = [[data[i] for i in fold_indices]
                 for fold_indices in all_fold_indices]

        test = folds[test_fold_index]
        if val_fold_index is not None:
            val = folds[val_fold_index]

        train_val = []
        for i in range(len(folds)):
            if i != test_fold_index and (val_fold_index is None
                                         or i != val_fold_index):
                train_val.extend(folds[i])

        if val_fold_index is not None:
            train = train_val
        else:
            random.seed(seed)
            random.shuffle(train_val)
            train_size = int(sizes[0] * len(train_val))
            train = train_val[:train_size]
            val = train_val[train_size:]

        return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(
            test)

    elif split_type == 'scaffold_balanced':
        return scaffold_split(data,
                              sizes=sizes,
                              balanced=True,
                              seed=seed,
                              logger=logger)

    elif split_type == 'random':
        data.shuffle(seed=seed)

        train_size = int(sizes[0] * len(data))
        train_val_size = int((sizes[0] + sizes[1]) * len(data))

        train = data[:train_size]
        val = data[train_size:train_val_size]
        test = data[train_val_size:]

        return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(
            test)

    else:
        raise ValueError(f'split_type "{split_type}" not supported.')
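A minimal call sketch. Here `data` stands for any MoleculeDataset loaded elsewhere (for example via get_data), so the snippet is illustrative rather than self-contained.

# Assumes `data` is an existing MoleculeDataset.
train, val, test = split_data(data, split_type='random', sizes=(0.8, 0.1, 0.1), seed=0)
print(len(train), len(val), len(test))  # roughly an 80/10/10 split

# Scaffold-balanced splitting keeps molecules with a shared scaffold in the same set.
train, val, test = split_data(data, split_type='scaffold_balanced', sizes=(0.8, 0.1, 0.1), seed=0)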
Example #9
def get_data(path: str,
             skip_invalid_smiles: bool = True,
             args: Namespace = None,
             features_path: List[str] = None,
             max_data_size: int = None,
             use_compound_names: bool = None,
             logger: Logger = None) -> MoleculeDataset:
    """
    Gets SMILES strings and target values (and optionally compound names if provided) from a CSV file.

    :param path: Path to a CSV file.
    :param skip_invalid_smiles: Whether to skip and filter out invalid smiles.
    :param args: Arguments.
    :param features_path: A list of paths to files containing features. If provided, it is used
    in place of args.features_path.
    :param max_data_size: The maximum number of data points to load.
    :param use_compound_names: Whether the file has compound names in addition to SMILES strings.
    :param logger: Logger.
    :return: A MoleculeDataset containing smiles strings and target values along
    with other info such as additional features and compound names when desired.
    """
    debug = logger.debug if logger is not None else print

    if args is not None:
        # Prefer explicit function arguments but default to args if not provided
        features_path = features_path if features_path is not None else args.features_path
        max_data_size = max_data_size if max_data_size is not None else args.max_data_size
        use_compound_names = use_compound_names if use_compound_names is not None else args.use_compound_names
    else:
        use_compound_names = False

    max_data_size = max_data_size or float('inf')

    # Load features
    if features_path is not None:
        features_data = []
        for feat_path in features_path:
            features_data.append(
                load_features(feat_path))  # each is num_data x num_features
        features_data = np.concatenate(features_data, axis=1)
        args.features_dim = len(features_data[0])
    else:
        features_data = None
        if args is not None:
            args.features_dim = 0

    skip_smiles = set()

    # Load data
    with open(path) as f:
        reader = csv.reader(f)
        next(reader)  # skip header

        lines = []
        for line in reader:
            smiles = line[0]

            if smiles in skip_smiles:
                continue

            lines.append(line)

            if len(lines) >= max_data_size:
                break

        data = MoleculeDataset([
            MoleculeDatapoint(line=line,
                              args=args,
                              features=features_data[i] if features_data is not None else None,
                              use_compound_names=use_compound_names)
            for i, line in tqdm(enumerate(lines), total=len(lines), disable=True)
        ])

    # Filter out invalid SMILES
    if skip_invalid_smiles:
        original_data_len = len(data)
        data = filter_invalid_smiles(data)

        if len(data) < original_data_len:
            debug(
                f'Warning: {original_data_len - len(data)} SMILES are invalid.'
            )

    return data
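A small end-to-end sketch: write a toy classification CSV (header row, then a SMILES column and one target column) and load it. The file name is made up for illustration, and args=None assumes MoleculeDatapoint tolerates missing arguments.

import csv

# Build a tiny CSV: header, then one SMILES and one binary target per row.
rows = [['smiles', 'activity'], ['CCO', '0'], ['c1ccccc1', '1'], ['CC(=O)O', '0']]
with open('toy_dataset.csv', 'w', newline='') as f:
    csv.writer(f).writerows(rows)

dataset = get_data(path='toy_dataset.csv', args=None)
print(len(dataset))       # 3, assuming all SMILES are valid
print(dataset.targets())  # one target list per molecule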