예제 #1
0
def test_predict(architecture, weights, data, fformat, tolerance):
    """ Check that ``predict`` returns consistently sized outputs and
        classifies the reference sequences well enough.

        Prediction quality is judged on SIMAP sequences whose class labels
        are known; each label is stored as the record id in the given fasta
        file. The ids are parsed as integers and compared element-wise
        against the predicted labels. ``tolerance`` is the number of
        sequences the algorithm may misclassify before the test fails.
    """
    # Run on the GPU when one is available, otherwise on the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load the stored weights and build the network from them
    state = torch.load(weights, map_location=device)
    network = load_nn(architecture, state, device)
    proteins = ProteinDataset(data, f_format=fformat)
    preds, confs, ids, indices = predict(network, proteins, device)
    # All four outputs must describe the same number of sequences
    assert preds.shape[0] == confs.shape[0]
    assert confs.shape[0] == len(ids)
    assert len(ids) == len(indices)
    # At most `tolerance` predictions may differ from the expected labels
    n_sequences = len(ids)
    expected = torch.tensor([int(label) for label in ids])
    n_correct = sum((expected == preds.cpu()).long())
    assert n_correct >= n_sequences - tolerance
예제 #2
0
def test_load_nn(architecture, weights):
    """ Verify that ``load_nn`` builds a torch module from stored weights. """
    # Run on the GPU when one is available, otherwise on the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Deserialize the weights onto that device and build the network
    state = torch.load(weights, map_location=device)
    network = load_nn(architecture, state, device)
    # Both the concrete class and the instance must derive from nn.Module
    assert issubclass(type(network), nn.Module)
    assert isinstance(network, nn.Module)
예제 #3
0
def annotate_with_deepnog(identifier: str,
                          protein_list: List[SeqRecord],
                          database: str = 'eggNOG5',
                          tax_level: int = 2,
                          verb: bool = True) -> GenotypeRecord:
    """
    Assign orthologous groups to the proteins of one sample using deepnog.

    :param identifier: The name associated with the sample.
    :param protein_list: A list of SeqRecords containing protein sequences.
    :param database: Orthologous group/family database to use.
    :param tax_level: The NCBI taxon ID of the taxonomic level to use from
        the given database.
    :param verb: Whether to print verbose progress messages.
    :returns: a GenotypeRecord suitable for use with phenotrex. Its features
        are the distinct, truthy values of the prediction column and its
        feature type is ``'{database}-tax-{tax_level}'``.
    :raises RuntimeError: if ``(database, tax_level)`` is not contained in
        ``DEEPNOG_VALID_CONFIG``.
    """
    # Reject unsupported database / taxonomic level combinations up front
    if (database, tax_level) not in DEEPNOG_VALID_CONFIG:
        raise RuntimeError(
            f'Unknown database and/or tax level: {database}/{tax_level}')

    device = set_device('auto')
    # Global torch setting; it is not restored when this function returns
    torch.set_num_threads(1)

    # Locate and load the trained weights for the requested configuration
    weights_path = get_weights_path(database=database,
                                    level=str(tax_level),
                                    architecture=DEEPNOG_ARCH)
    model_dict = torch.load(weights_path, map_location=device)
    model = load_nn(DEEPNOG_ARCH, model_dict, device)
    class_labels = model_dict['classes']

    # Run inference over the already-parsed protein records
    dataset = PreloadedProteinDataset(protein_list)
    verbosity = 3 if verb else 0
    preds, confs, ids, indices = predict(model, dataset, device,
                                         batch_size=1,
                                         num_workers=1,
                                         verbose=verbosity)

    # Use the model's own confidence threshold when it defines one
    if hasattr(model, 'threshold'):
        threshold = float(model.threshold)
    else:
        threshold = None
    df = create_df(class_labels, preds, confs, ids, indices,
                   threshold=threshold, verbose=0)

    # Keep each distinct prediction once and drop falsy entries
    unique_predictions = df.prediction.unique()
    cogs = [group for group in unique_predictions if group]
    return GenotypeRecord(identifier=identifier,
                          feature_type=f'{database}-tax-{tax_level}',
                          features=cogs)
예제 #4
0
def test_sync_counter_of_many_empty_sequences():
    """ Check the skipped-sequence counter on input holding many sequences
        with empty ids.
    """
    # Configure threading before picking GPU or CPU
    torch.set_num_threads(16)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the 'deepencoding' network from the stored weights
    state = torch.load(weights_path, map_location=device)
    network = load_nn('deepencoding', state, device)
    proteins = ProteinDataset(data_skip_path, f_format='fasta')
    # Prediction must warn about the undetectable ids; its result is unused
    with pytest.warns(UserWarning, match='no sequence id could be detected'):
        predict(network, proteins, device)

    # The dataset must report exactly 2**16 skipped sequences
    expected_skipped = 2**16
    assert int(proteins.n_skipped) == expected_skipped
예제 #5
0
def test_skip_empty_sequences(architecture, weights, data, fformat):
    """ Check that sequences with empty ids are left out of the predictions
        and tallied by the dataset.
    """
    # Run on the GPU when one is available, otherwise on the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load the stored weights and build the network from them
    state = torch.load(weights, map_location=device)
    network = load_nn(architecture, state, device)
    proteins = ProteinDataset(data, f_format=fformat)
    # Prediction must warn about the undetectable ids
    with pytest.warns(UserWarning, match='no sequence id could be detected'):
        preds, _, _, _ = predict(network, proteins, device)
    # 70 predictions are expected ...
    assert preds.shape[0] == 70
    # ... and 20 sequences must have been counted as skipped
    assert int(proteins.n_skipped) == 20