Ejemplo n.º 1
0
def test_predict(architecture, weights, data, fformat, tolerance):
    """ Test correct prediction output shapes as well as satisfying
        prediction performance.

        Prediction performance is checked through sequences from SIMAP with
        known class labels. Class labels are stored as the id in the given
        fasta file. Tolerance defines how many sequences the algorithm
        is allowed to misclassify before the test fails.
    """
    module, cls = _get_module_cls_from_arch(architecture)

    # Set up device
    cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    # Start test
    model_dict = torch.load(weights, map_location=device)
    model = load_nn((module, cls), model_dict, phase='infer', device=device)
    dataset = ProteinIterableDataset(data, f_format=fformat)
    preds, confs, ids, indices = predict(model, dataset, device)
    # Test correct output shape
    assert preds.shape[0] == confs.shape[0]
    assert confs.shape[0] == len(ids)
    assert len(ids) == len(indices)
    # Test satisfying prediction accuracy: the true class label of every
    # sequence is stored as its (integer) sequence id.
    n = len(ids)
    true_labels = torch.tensor([int(seq_id) for seq_id in ids])
    # Reduce on the tensor itself instead of iterating it with builtin sum().
    n_correct = int((true_labels == preds.cpu()).sum())
    assert n_correct >= n - tolerance, (
        f'{n - n_correct} of {n} sequences misclassified '
        f'(tolerance: {tolerance})')
Ejemplo n.º 2
0
def annotate_with_deepnog(identifier: str,
                          protein_list: List[SeqRecord],
                          database: str = 'eggNOG5',
                          tax_level: int = 2,
                          confidence_threshold: float = None,
                          verb: bool = True) -> GenotypeRecord:
    """
    Assign proteins belonging to a sample to orthologous groups using deepnog.

    :param identifier: The name associated with the sample.
    :param protein_list: A list of SeqRecords containing protein sequences.
    :param database: Orthologous group/family database to use.
    :param tax_level: The NCBI taxon ID of the taxonomic level to use from the given database.
    :param confidence_threshold: Confidence threshold of deepnog annotations below which annotations
                                 will be discarded. An explicitly given value takes precedence over
                                 the threshold stored on the loaded model. If None, the model's
                                 ``threshold`` attribute is used when present; otherwise
                                 ``threshold=None`` is passed on to ``create_df``.
    :param verb: Whether to print verbose progress messages.
    :returns: a GenotypeRecord suitable for use with phenotrex.
    :raises RuntimeError: if (database, tax_level) is not listed in DEEPNOG_VALID_CONFIG.
    """
    if (database, tax_level) not in DEEPNOG_VALID_CONFIG:
        raise RuntimeError(
            f'Unknown database and/or tax level: {database}/{tax_level}')

    device = set_device('auto')
    # Restrict torch intra-op parallelism to a single thread.
    torch.set_num_threads(1)
    weights_path = get_weights_path(
        database=database,
        level=str(tax_level),
        architecture=DEEPNOG_ARCH,
    )
    model_dict = torch.load(weights_path, map_location=device)
    model = load_nn(
        architecture=DEEPNOG_ARCH,
        model_dict=model_dict,
        device=device,
    )
    class_labels = model_dict['classes']
    dataset = PreloadedProteinDataset(protein_list)
    preds, confs, ids, indices = predict(model,
                                         dataset,
                                         device,
                                         batch_size=1,
                                         num_workers=1,
                                         verbose=3 if verb else 0)

    # A caller-supplied threshold must win over the model default; previously
    # it was silently ignored whenever the model carried a threshold.
    if confidence_threshold is not None:
        threshold = float(confidence_threshold)
    elif hasattr(model, 'threshold'):
        threshold = float(model.threshold)
    else:
        threshold = None

    df = create_df(
        class_labels,
        preds,
        confs,
        ids,
        indices,
        threshold=threshold,
    )

    # Keep only truthy predictions (drops None/empty labels; presumably these
    # are sequences below the threshold -- verify against create_df).
    cogs = [x for x in df.prediction.unique() if x]
    feature_type_str = f'{database}-tax-{tax_level}'
    return GenotypeRecord(identifier=identifier,
                          feature_type=feature_type_str,
                          features=cogs)
Ejemplo n.º 3
0
def test_count_params(architecture, weights):
    """ Test parameter counting on a model loaded for inference.

        Counting only tunable parameters must give the same number as
        counting all parameters, i.e. every parameter of the loaded model
        is expected to be tunable.
    """
    cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    model_dict = torch.load(weights, map_location=device)
    # NOTE(review): unlike the other tests, `architecture` is passed to
    # load_nn as-is rather than via _get_module_cls_from_arch -- confirm the
    # fixture provides a value load_nn accepts directly.
    model = load_nn(architecture, model_dict, phase='infer', device=device)
    n_params_tuned = count_parameters(model, tunable_only=True)
    n_params_total = count_parameters(model, tunable_only=False)
    assert n_params_total == n_params_tuned, (
        f'{n_params_total - n_params_tuned} of {n_params_total} '
        f'parameters are not tunable')
Ejemplo n.º 4
0
def test_load_nn(architecture, weights):
    """ Check that restoring stored weights yields a torch ``nn.Module``. """
    arch_module, arch_cls = _get_module_cls_from_arch(architecture)

    # Prefer the GPU when one is present, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    state = torch.load(weights, map_location=device)
    net = load_nn((arch_module, arch_cls), state, phase='infer', device=device)

    net_type = type(net)
    assert issubclass(net_type, nn.Module)
    assert isinstance(net, nn.Module)
Ejemplo n.º 5
0
def test_sync_counter_of_many_empty_sequences():
    """ Test if many sequences with empty ids are counted correctly.

        Predicting on DATA_SKIP_PATH must warn about missing sequence ids,
        and the dataset must report exactly 2**16 skipped sequences.
    """
    # Run with two intra-op threads, but restore the previous value afterwards
    # so this global torch setting does not leak into other tests.
    n_threads_before = torch.get_num_threads()
    torch.set_num_threads(2)
    try:
        # Set up device
        cuda = torch.cuda.is_available()
        device = torch.device('cuda' if cuda else 'cpu')
        # Start test
        model_dict = torch.load(WEIGHTS_PATH, map_location=device)
        model = load_nn(['deepnog', 'DeepNOG'], model_dict,
                        phase='infer', device=device)
        dataset = ProteinIterableDataset(DATA_SKIP_PATH, f_format='fasta')
        with pytest.warns(UserWarning,
                          match='no sequence id could be detected'):
            _ = predict(model, dataset, device)

        # Test correct counted skipped sequences
        assert int(dataset.n_skipped) == 2**16
    finally:
        torch.set_num_threads(n_threads_before)
Ejemplo n.º 6
0
def test_skip_empty_sequences(architecture, weights, data, fformat):
    """ Records without a sequence id must be skipped and tallied.

        For the given data, the test expects 70 predictions and a skip
        counter of 20 on the dataset, along with a UserWarning about the
        missing ids.
    """
    arch_module, arch_cls = _get_module_cls_from_arch(architecture)

    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')

    state = torch.load(weights, map_location=device)
    net = load_nn((arch_module, arch_cls), state, phase='infer', device=device)
    records = ProteinIterableDataset(data, f_format=fformat)

    expected_warning = 'no sequence id could be detected'
    with pytest.warns(UserWarning, match=expected_warning):
        predictions, _confs, _ids, _indices = predict(net, records, device)

    n_predicted = predictions.shape[0]
    assert n_predicted == 70
    n_skipped = int(records.n_skipped)
    assert n_skipped == 20
Ejemplo n.º 7
0
def _start_inference(args, arch_module, arch_cls):
    """Run protein group/family inference as configured by ``args``.

    Loads network parameters (from ``args.weights`` or, when that is None,
    from the path returned by ``get_weights_path``), reads the sequences in
    ``args.file``, predicts one group/family per sequence and writes the
    ``sequence_id``, ``prediction`` and ``confidence`` columns to
    ``args.out`` (stdout when None). If ``args.test_labels`` is given, the
    test-set performance is additionally written to stderr (when predictions
    go to stdout) or to a ``.performance.csv`` file next to ``args.out``.

    :param args: Parsed command-line options (presumably an
                 ``argparse.Namespace``). Attributes read: ``verbose``,
                 ``weights``, ``database``, ``tax``, ``architecture``,
                 ``device``, ``file``, ``test_labels``, ``fformat``,
                 ``phase``, ``confidence_threshold``, ``batch_size``,
                 ``num_workers``, ``out``, ``outformat``.
    :param arch_module: First element of the architecture tuple passed to
                        ``load_nn``.
    :param arch_cls: Second element of the architecture tuple passed to
                     ``load_nn``.
    :raises SystemExit: with status 1 if ``args.confidence_threshold`` is not
                        in (0, 1] or ``args.outformat`` is unknown.
    """
    import sys
    from pathlib import Path

    from pandas import read_csv, DataFrame
    import torch
    from deepnog.data import ProteinIterableDataset
    from deepnog.learning import predict
    from deepnog.utils import create_df, get_logger, get_weights_path, load_nn
    from deepnog.utils.metrics import estimate_performance

    logger = get_logger(__name__, verbose=args.verbose)

    # Validate user options before any expensive loading happens.
    if (args.confidence_threshold is not None
            and not 0.0 < args.confidence_threshold <= 1.0):
        logger.error(f'Invalid confidence threshold specified: '
                     f'{args.confidence_threshold} not in range (0, 1].')
        sys.exit(1)
    separator = {'csv': ',', 'tsv': '\t', 'legacy': ';'}.get(args.outformat)
    if separator is None:
        # Previously sep=None was handed to pandas, which fails opaquely.
        logger.error(f'Invalid output format specified: {args.outformat}')
        sys.exit(1)

    # Intra-op parallelization appears rather inefficient.
    # Users may override with environmental variable: export OMP_NUM_THREADS=8
    torch.set_num_threads(1)

    # Construct path to saved parameters of NN
    if args.weights is not None:
        weights_path = args.weights
    else:
        weights_path = get_weights_path(
            database=args.database,
            level=str(args.tax),
            architecture=args.architecture,
            verbose=args.verbose,
        )
    # Load neural network parameters
    logger.info(f'Loading NN-parameters from {weights_path} ...')
    model_dict = torch.load(weights_path, map_location=args.device)

    # Load dataset
    logger.info(f'Accessing dataset from {args.file} ...')
    dataset = ProteinIterableDataset(args.file,
                                     labels_file=args.test_labels,
                                     f_format=args.fformat)

    # Load class names; fall back to the dataset's label encoder when the
    # stored parameters carry no class list.
    try:
        class_labels = model_dict['classes']
    except KeyError:
        class_labels = dataset.label_encoder.classes_

    # Load neural network model
    model = load_nn(architecture=(arch_module, arch_cls),
                    model_dict=model_dict,
                    phase=args.phase,
                    device=args.device)

    # A user-given threshold (validated above) wins over the model default.
    if args.confidence_threshold is not None:
        threshold = float(args.confidence_threshold)
    elif hasattr(model, 'threshold'):
        threshold = float(model.threshold)
        logger.info(f'Applying confidence threshold from model: {threshold}')
    else:
        threshold = None

    # Predict labels of given data
    logger.info('Starting protein sequence group/family inference ...')
    logger.debug(
        f'Processing {args.batch_size} sequences per iteration (minibatch)')
    preds, confs, ids, indices = predict(model,
                                         dataset,
                                         args.device,
                                         batch_size=args.batch_size,
                                         num_workers=args.num_workers,
                                         verbose=args.verbose)

    # Construct results dataframe
    df = create_df(class_labels,
                   preds,
                   confs,
                   ids,
                   indices,
                   threshold=threshold)

    if args.out is None:
        save_file = sys.stdout
        logger.info('Writing predictions to stdout')
    else:
        save_file = args.out
        Path(args.out).parent.mkdir(parents=True, exist_ok=True)
        logger.info(f'Writing prediction to {save_file}')

    columns = ['sequence_id', 'prediction', 'confidence']
    df.to_csv(save_file, sep=separator, index=False, columns=columns)

    # Measure test set performance, if labels were provided
    if args.test_labels is not None:
        if args.out is None:
            perf_file = sys.stderr
            logger.info('Writing test set performance to stderr')
        else:
            perf_file = Path(save_file).with_suffix('.performance.csv')
            logger.info(f'Writing test set performance to {perf_file}')
        # Ensure object dtype to avoid int-str mismatches
        df_true = read_csv(args.test_labels, dtype=object, index_col=0)
        df = df.astype(dtype={columns[1]: object})
        perf = estimate_performance(df_true=df_true, df_pred=df)
        df_perf = DataFrame(data=[perf])
        df_perf['experiment'] = args.file
        df_perf.to_csv(perf_file)
    logger.info('All done.')