def test_unit_train_classify(tmpdir):
    """Train a classifier through the CLI, reload it and check its class scores.

    The trained model is pickled into ``tmpdir``. Predictions are made on one
    DataFrame per ``sequence_id`` group of the Pfam CSV. The test expects one
    row of scores per group, thresholded at 0.5.
    """
    model_path = os.path.join(str(tmpdir), 'model.pkl')
    train_args = [
        'train',
        '--model', get_test_file('random_forest_test.json'),
        '--classes', get_test_file('BGC0000015.classes.csv'),
        '--output', model_path,
        get_test_file('BGC0000015.pfam.csv'),
    ]
    run(train_args)
    assert os.path.exists(model_path)

    classifier = SequenceModelWrapper.load(model_path)
    pfam_table = pd.read_csv(get_test_file('BGC0000015.pfam.csv'))
    # Split the table into one sample DataFrame per sequence.
    samples = [group for _, group in pfam_table.groupby('sequence_id')]
    predicted = classifier.predict(samples)

    assert isinstance(predicted, pd.DataFrame)
    assert list(predicted.columns) == ['class1', 'class2', 'class3', 'class4']
    assert len(predicted.index) == 2

    expected_rows = [
        [True, False, True, False],
        [False, True, False, True],
    ]
    for row_idx, expected in enumerate(expected_rows):
        assert list(predicted.iloc[row_idx] > 0.5) == expected
def test_integration_pfam_annotator(tmpdir):
    """Annotate the first GenBank record with ``HmmscanPfamRecordAnnotator``.

    Uses the PF00005 test HMM database and clans file. The test then checks
    the number of resulting Pfam features, the location and qualifiers of the
    first one, and that the record's features stay sorted.
    """
    prefix = os.path.join(str(tmpdir), 'test')
    genbank_iter = SeqIO.parse(get_test_file('BGC0000015.gbk'), format='genbank')
    record = next(genbank_iter)

    HmmscanPfamRecordAnnotator(
        record=record,
        tmp_path_prefix=prefix,
        db_path=get_test_file('Pfam-A.PF00005.hmm'),
        clans_path=get_test_file('Pfam-A.PF00005.clans.tsv'),
    ).annotate()

    pfam_features = util.get_pfam_features(record)
    assert len(pfam_features) == 2

    first = pfam_features[0]
    loc = first.location
    assert loc.start == 249
    assert loc.end == 696
    assert loc.strand == -1

    expected_qualifiers = {
        'db_xref': ['PF00005.26'],
        'locus_tag': ['AAK73498.1'],
        'description': ['ABC transporter'],
        'database': ['31.0'],
    }
    for key, value in expected_qualifiers.items():
        assert first.qualifiers.get(key) == value

    assert_sorted_features(record)
def test_integration_train_detect_fail_fasta():
    """Training on a raw FASTA file must raise ``NotImplementedError``.

    The input sequence has not been processed (no protein / Pfam annotation),
    so ``train`` is expected to refuse it.

    The output path lives in a throw-away temporary directory rather than the
    current working directory. That way a regression (training unexpectedly
    writing a model) cannot leave a stray ``bar.pkl`` behind.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        out_path = os.path.join(tmpdir, 'bar.pkl')
        # Should fail due to unprocessed input sequence
        with pytest.raises(NotImplementedError):
            run([
                'train',
                '--model', get_test_file('clusterfinder_geneborder_test.json'),
                '--output', out_path,
                get_test_file('BGC0000015.fa'),
            ])
def test_unit_train_detect(model, tmpdir):
    """Train a detection model via the CLI and check it separates positives from negatives.

    :param model: name of a model-configuration test file, resolved through
        ``get_test_file`` (presumably supplied by a fixture/parametrization).
    :param tmpdir: pytest temporary directory receiving the trained model pickle.
    """
    out_path = os.path.join(str(tmpdir), 'model.pkl')
    pos_path = get_test_file('BGC0000015.pfam.csv')
    neg_path = get_test_file('negative.pfam.csv')
    run([
        'train',
        '--model', get_test_file(model),
        '--config', 'PFAM2VEC', get_test_file('pfam2vec.test.tsv'),
        '--output', out_path,
        pos_path,
        neg_path,
    ])
    assert os.path.exists(out_path)

    # Keep the trained object under its own name rather than overwriting the
    # ``model`` parameter (the config file name) with it.
    trained_model = SequenceModelWrapper.load(out_path)
    pos_domains = pd.read_csv(pos_path)
    neg_domains = pd.read_csv(neg_path)
    pos_prediction = trained_model.predict(pos_domains)
    neg_prediction = trained_model.predict(neg_domains)

    # Predictions are per-domain scores aligned to the input rows.
    assert isinstance(pos_prediction, pd.Series)
    assert isinstance(neg_prediction, pd.Series)
    assert pos_prediction.index.equals(pos_domains.index)
    assert neg_prediction.index.equals(neg_domains.index)

    assert pos_prediction.mean() > 0.5
    assert neg_prediction.mean() < 0.5
def test_integration_pipeline_labelled(tmpdir):
    """Run the pipeline on a labelled GenBank file and check the evaluation plots exist."""
    report_dir = os.path.join(str(tmpdir), 'report')
    run(['pipeline', '--output', report_dir, get_test_file('labelled.gbk')])

    evaluation_files = os.listdir(os.path.join(report_dir, 'evaluation'))
    # Echo the listing so it appears in the captured output when an assert fails.
    for name in evaluation_files:
        print(name)

    expected_plots = (
        'report.bgc.png',
        'report.score.png',
        'report.roc.png',
        'report.pr.png',
    )
    for plot in expected_plots:
        assert plot in evaluation_files
def test_integration_prepare_default(tmpdir):
    """Run ``prepare`` on a raw FASTA file and check the annotated GenBank and TSV outputs.

    The GenBank output is expected to hold two records with known protein and
    Pfam feature counts. The TSV output must contain matching per-sequence
    domain rows.
    """
    out_gbk = os.path.join(str(tmpdir), 'outfile.gbk')
    out_tsv = os.path.join(str(tmpdir), 'outfile.tsv')
    run([
        'prepare',
        '--output-gbk', out_gbk,
        '--output-tsv', out_tsv,
        get_test_file('BGC0000015.fa'),
    ])

    gbk_records = list(SeqIO.parse(out_gbk, 'genbank'))
    assert len(gbk_records) == 2
    first_record, second_record = gbk_records

    assert_sorted_features(first_record)
    first_proteins = util.get_protein_features(first_record)
    first_pfams = util.get_pfam_features(first_record)
    assert len(first_proteins) == 18
    print([util.get_protein_id(feature) for feature in first_proteins])
    assert len(first_pfams) == 111

    assert_sorted_features(second_record)
    second_proteins = util.get_protein_features(second_record)
    second_pfams = util.get_pfam_features(second_record)
    assert len(second_proteins) == 27
    assert len(second_pfams) == 36

    domain_table = pd.read_csv(out_tsv, sep='\t')
    by_sequence = domain_table.groupby('sequence_id')
    assert len(by_sequence) == 2

    # Proteins without any Pfam domain produce no TSV rows, so fewer distinct
    # protein IDs appear here than protein features in the GenBank record.
    seq1_rows = by_sequence.get_group('BGC0000015.1')
    print(seq1_rows['protein_id'].unique())
    assert len(seq1_rows['protein_id'].unique()) == 17
    assert len(seq1_rows) == 111

    seq2_rows = by_sequence.get_group('BGC0000015.2')
    assert len(seq2_rows['protein_id'].unique()) == 11
    assert len(seq2_rows) == 36
def test_integration_protein_annotator(tmpdir):
    """Annotate proteins on the first FASTA record with ``ProdigalProteinRecordAnnotator``.

    Checks the number of protein features, the first feature's location, id
    and locus tag, and that the record's features stay sorted.
    """
    prefix = os.path.join(str(tmpdir), 'test')
    fasta_iter = SeqIO.parse(get_test_file('BGC0000015.fa'), format='fasta')
    record = next(fasta_iter)

    ProdigalProteinRecordAnnotator(record=record, tmp_path_prefix=prefix).annotate()

    protein_features = util.get_protein_features(record)
    assert len(protein_features) == 18

    first = protein_features[0]
    assert first.location.start == 3
    assert first.location.end == 1824
    assert first.id == 'BGC0000015.1_1'
    assert first.qualifiers.get('locus_tag') == ['BGC0000015.1_BGC0000015.1_1']

    assert_sorted_features(record)
def test_integration_pipeline_default(tmpdir, input_file):
    """Run the pipeline with default options and check report files and detected clusters.

    :param input_file: name of a test input file resolved through
        ``get_test_file`` (presumably supplied by a fixture/parametrization).
    """
    report_dir = os.path.join(str(tmpdir), 'report')
    run(['pipeline', '--output', report_dir, get_test_file(input_file)])

    report_files = os.listdir(report_dir)
    for name in report_files:
        print(name)
    expected_reports = (
        'README.txt',
        'report.bgc.gbk',
        'report.bgc.tsv',
        'report.full.gbk',
        'report.pfam.tsv',
    )
    for report in expected_reports:
        assert report in report_files

    evaluation_files = os.listdir(os.path.join(report_dir, 'evaluation'))
    for name in evaluation_files:
        print(name)
    for plot in ('report.bgc.png', 'report.score.png'):
        assert plot in evaluation_files

    full_gbk = os.path.join(report_dir, 'report.full.gbk')
    full_records = list(SeqIO.parse(full_gbk, 'genbank'))
    assert len(full_records) == 2
    # Each of the two annotated records must carry at least one cluster feature.
    for full_record in full_records:
        assert len(util.get_cluster_features(full_record)) >= 1

    bgc_gbk = os.path.join(report_dir, 'report.bgc.gbk')
    cluster_records = list(SeqIO.parse(bgc_gbk, 'genbank'))
    assert len(cluster_records) >= 2