Example 1
import pytest
from tempfile import TemporaryDirectory

from deepnog.utils import get_weights_path


def test_get_weights_impossible(capsys):
    with TemporaryDirectory(prefix='deepnog_test_data_dir_') as tmpdir:
        # Unknown weights with downloads disabled must raise an IOError.
        with pytest.raises(IOError, match='Data not found'):
            _ = get_weights_path(database='testdb',
                                 level='1',
                                 architecture='do_not_delete',
                                 data_home=tmpdir,
                                 download_if_missing=False,
                                 verbose=3)
        _ = get_weights_path(database="unavailable_db",
                             level='0',
                             architecture="deepnog",
                             data_home=tmpdir,
                             verbose=3)
        assert "Download failed" in capsys.readouterr().err
Example 2
def annotate_with_deepnog(identifier: str,
                          protein_list: List[SeqRecord],
                          database: str = 'eggNOG5',
                          tax_level: int = 2,
                          confidence_threshold: float = None,
                          verb: bool = True) -> GenotypeRecord:
    """
    Assign proteins belonging to a sample to orthologous groups using deepnog.

    :param identifier: The name associated with the sample.
    :param protein_list: A list of SeqRecords containing protein sequences.
    :param database: Orthologous group/family database to use.
    :param tax_level: The NCBI taxon ID of the taxonomic level to use from the given database.
    :param confidence_threshold: Confidence threshold of deepnog annotations below which annotations
                                 will be discarded.
    :param verb: Whether to print verbose progress messages.
    :returns: a GenotypeRecord suitable for use with phenotrex.
    """
    if (database, tax_level) not in DEEPNOG_VALID_CONFIG:
        raise RuntimeError(
            f'Unknown database and/or tax level: {database}/{tax_level}')

    device = set_device('auto')
    torch.set_num_threads(1)
    weights_path = get_weights_path(
        database=database,
        level=str(tax_level),
        architecture=DEEPNOG_ARCH,
    )
    model_dict = torch.load(weights_path, map_location=device)
    model = load_nn(
        architecture=DEEPNOG_ARCH,
        model_dict=model_dict,
        device=device,
    )
    class_labels = model_dict['classes']
    dataset = PreloadedProteinDataset(protein_list)
    preds, confs, ids, indices = predict(model,
                                         dataset,
                                         device,
                                         batch_size=1,
                                         num_workers=1,
                                         verbose=3 if verb else 0)
    # A threshold stored with the model takes precedence over the
    # user-supplied confidence_threshold.
    if hasattr(model, 'threshold'):
        threshold = float(model.threshold)
    else:
        threshold = confidence_threshold
    df = create_df(
        class_labels,
        preds,
        confs,
        ids,
        indices,
        threshold=threshold,
    )

    cogs = [x for x in df.prediction.unique() if x]
    feature_type_str = f'{database}-tax-{tax_level}'
    return GenotypeRecord(identifier=identifier,
                          feature_type=feature_type_str,
                          features=cogs)
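A minimal usage sketch for annotate_with_deepnog follows, assuming Biopython is available; the input file name 'sample_proteins.faa' and the sample identifier are hypothetical placeholders.

from Bio import SeqIO

# Hypothetical FASTA file of protein sequences, parsed into SeqRecord objects.
proteins = list(SeqIO.parse('sample_proteins.faa', 'fasta'))

record = annotate_with_deepnog(identifier='sample_1',
                               protein_list=proteins,
                               database='eggNOG5',
                               tax_level=2)
# record.features holds the assigned orthologous groups for this sample.
print(record.feature_type, len(record.features))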
Example 3
from pathlib import Path
from tempfile import TemporaryDirectory

from deepnog.utils import get_weights_path


def test_get_weights():
    with TemporaryDirectory(prefix='deepnog_test_data_dir_') as tmpdir:
        p = get_weights_path(database='testdb',
                             level='1',
                             architecture='do_not_delete',
                             data_home=tmpdir,
                             download_if_missing=True,
                             verbose=3)
        assert Path(p).is_file()
Example 4
from pathlib import Path
from tempfile import TemporaryDirectory

from deepnog.utils import get_weights_path


# database and level are supplied by pytest parametrization in the
# original test module (see the sketch after this example).
def test_get_weights_all(database, level, architecture="deepnog"):
    with TemporaryDirectory(
            prefix=f"deepnog_test_data_dir_{database}_") as tmp:
        p = get_weights_path(
            database=database,
            level=level,
            architecture=architecture,
            data_home=tmp,
            download_if_missing=True,
            verbose=3,
        )
        p = Path(p)
        assert p.is_file(), "Could not find file. Possibly, download failed."
        assert p.suffix == ".pth", f"Wrong file format: {p.suffix}"
        assert p.stat().st_size, "File is empty"  # File size in bytes
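The test above expects (database, level) pairs from pytest parametrization; a minimal sketch of such a decorator is shown below, with the value pairs chosen as illustrative placeholders rather than the project's actual test matrix.

import pytest

# Illustrative parametrization only; the real database/level combinations
# live in the deepnog test suite.
@pytest.mark.parametrize('database, level', [
    ('eggNOG5', '1'),
    ('eggNOG5', '2'),
])
def test_get_weights_all(database, level, architecture="deepnog"):
    ...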
Example 5
def _start_inference(args, arch_module, arch_cls):
    import sys
    from pathlib import Path

    from pandas import read_csv, DataFrame
    import torch
    from deepnog.data import ProteinIterableDataset
    from deepnog.learning import predict
    from deepnog.utils import create_df, get_logger, get_weights_path, load_nn
    from deepnog.utils.metrics import estimate_performance

    logger = get_logger(__name__, verbose=args.verbose)
    # Intra-op parallelization appears rather inefficient.
    # Users may override via the environment variable: export OMP_NUM_THREADS=8
    torch.set_num_threads(1)

    # Construct path to saved parameters of NN
    if args.weights is not None:
        weights_path = args.weights
    else:
        weights_path = get_weights_path(
            database=args.database,
            level=str(args.tax),
            architecture=args.architecture,
            verbose=args.verbose,
        )
    # Load neural network parameters
    logger.info(f'Loading NN-parameters from {weights_path} ...')
    model_dict = torch.load(weights_path, map_location=args.device)

    # Load dataset
    logger.info(f'Accessing dataset from {args.file} ...')
    dataset = ProteinIterableDataset(args.file,
                                     labels_file=args.test_labels,
                                     f_format=args.fformat)

    # Load class names
    try:
        class_labels = model_dict['classes']
    except KeyError:
        class_labels = dataset.label_encoder.classes_

    # Load neural network model
    model = load_nn(architecture=(arch_module, arch_cls),
                    model_dict=model_dict,
                    phase=args.phase,
                    device=args.device)

    # If given, set confidence threshold for prediction
    if args.confidence_threshold is not None:
        if 0.0 < args.confidence_threshold <= 1.0:
            threshold = float(args.confidence_threshold)
        else:
            logger.error(f'Invalid confidence threshold specified: '
                         f'{args.confidence_threshold} not in range (0, 1].')
            sys.exit(1)
    elif hasattr(model, 'threshold'):
        threshold = float(model.threshold)
        logger.info(f'Applying confidence threshold from model: {threshold}')
    else:
        threshold = None

    # Predict labels of given data
    logger.info('Starting protein sequence group/family inference ...')
    logger.debug(
        f'Processing {args.batch_size} sequences per iteration (minibatch)')
    preds, confs, ids, indices = predict(model,
                                         dataset,
                                         args.device,
                                         batch_size=args.batch_size,
                                         num_workers=args.num_workers,
                                         verbose=args.verbose)

    # Construct results dataframe
    df = create_df(class_labels,
                   preds,
                   confs,
                   ids,
                   indices,
                   threshold=threshold)

    if args.out is None:
        save_file = sys.stdout
        logger.info('Writing predictions to stdout')
    else:
        save_file = args.out
        Path(args.out).parent.mkdir(parents=True, exist_ok=True)
        logger.info(f'Writing prediction to {save_file}')

    columns = ['sequence_id', 'prediction', 'confidence']
    separator = {'csv': ',', 'tsv': '\t', 'legacy': ';'}.get(args.outformat)
    df.to_csv(save_file, sep=separator, index=False, columns=columns)

    # Measure test set performance, if labels were provided
    if args.test_labels is not None:
        if args.out is None:
            perf_file = sys.stderr
            logger.info('Writing test set performance to stderr')
        else:
            perf_file = Path(save_file).with_suffix('.performance.csv')
            logger.info(f'Writing test set performance to {perf_file}')
        # Ensure object dtype to avoid int-str mismatches
        df_true = read_csv(args.test_labels, dtype=object, index_col=0)
        df = df.astype(dtype={columns[1]: object})
        perf = estimate_performance(df_true=df_true, df_pred=df)
        df_perf = DataFrame(data=[perf])
        df_perf['experiment'] = args.file
        df_perf.to_csv(perf_file)
    logger.info('All done.')
    return
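For reference, _start_inference only reads the argparse namespace fields listed below. The sketch builds such a namespace by hand; the field names are taken from the function above, but every value (file names, architecture module/class, phase) is an illustrative placeholder, not the CLI's documented default.

import torch
from argparse import Namespace

# Placeholder values throughout; arch_module/arch_cls and phase are guesses.
args = Namespace(
    file='proteins.faa',            # input protein sequences
    fformat='fasta',
    database='eggNOG5',
    tax=2,
    architecture='deepnog',
    weights=None,                   # fall back to get_weights_path()
    device=torch.device('cpu'),
    phase='infer',
    confidence_threshold=None,
    batch_size=64,
    num_workers=1,
    verbose=3,
    out='predictions.csv',
    outformat='csv',
    test_labels=None,
)
_start_inference(args, arch_module='deepnog', arch_cls='DeepNOG')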