# The test snippets below are excerpted from the CoCoNet test suite. They
# assume the following imports, the project's test helpers and constants
# (generate_h5_file, generate_fasta_file, generate_pair_file, generate_rd_model,
# TEST_SHAPES, TEST_ARCHITECTURE, TEST_CTG_LENGTHS, TEST_LEARN_PRMS, FL, WSIZE,
# WSTEP, STEP, LOCAL_DIR), and the coconet modules under test; exact module
# paths may differ between versions:
from pathlib import Path
import shutil

import numpy as np
import pandas as pd
import h5py
import torch
import igraph
import pytest
from Bio import SeqIO


def test_remove_singletons():
    '''
    Test that contigs observed in fewer than min_prevalence samples are
    flagged as singletons and filtered out of the fasta file
    '''

    lengths = [60, 100, 80]
    h5_data = generate_h5_file(*lengths, n_samples=3,
                               baselines=[20, 40, 30],
                               empty_samples=[[False]*3, [True, True, False], [False]*3],
                               filename='coverage.h5')
    fasta = generate_fasta_file(*lengths)
    singleton_file = Path('singletons.txt')

    compo = CompositionFeature(path=dict(filt_fasta=fasta))
    cover = CoverageFeature(path=dict(h5=h5_data))

    cover.remove_singletons(output=singleton_file, min_prevalence=2)
    compo.filter_by_ids(output='filt.fasta', ids_file=singleton_file)

    singletons = pd.read_csv(singleton_file, sep='\t', header=None).values
    n_filt = sum(1 for _ in SeqIO.parse('filt.fasta', 'fasta'))

    for f in [fasta, h5_data, singleton_file, 'filt.fasta']:
        Path(f).unlink()

    assert singletons.shape == (1, 3)
    assert n_filt == 2
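
# A minimal sketch of the prevalence idea the test exercises, assuming the h5
# layout produced by generate_h5_file (one dataset per contig, shaped
# (n_samples, contig_length)); illustrative only, not CoCoNet's own code:
def find_singletons(h5_path, min_prevalence):
    singletons = []
    with h5py.File(h5_path, 'r') as handle:
        for ctg, data in handle.items():
            coverage = data[:]
            # A contig counts as present in a sample if it has any coverage
            prevalence = int(np.sum(coverage.sum(axis=1) > 0))
            if prevalence < min_prevalence:
                singletons.append((ctg, prevalence))
    return singletons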
Example #2
def test_save_repr():
    '''
    Test that latent representations are saved to h5 for both composition and coverage
    '''

    model = initialize_model('CoCoNet', TEST_SHAPES, TEST_ARCHITECTURE)
    fasta = generate_fasta_file(*TEST_CTG_LENGTHS)
    coverage = generate_h5_file(*TEST_CTG_LENGTHS, filename='coverage.h5')

    output = {k: Path('repr_{}.h5'.format(k))
              for k in ['composition', 'coverage']}

    save_repr_all(model, fasta, coverage, n_frags=5, frag_len=FL, output=output,
                  min_ctg_len=0, wsize=WSIZE, wstep=WSTEP)

    assert all(out.is_file() for out in output.values())

    handles = {k: h5py.File(v, 'r') for (k, v) in output.items()}
    # Inspect the shape of the first dataset in each repr file
    firsts = {k: handle.get(list(handle.keys())[0]).shape
              for k, handle in handles.items()}

    latent_dim = (TEST_ARCHITECTURE['composition']['neurons'][-1]
                  + TEST_ARCHITECTURE['coverage']['neurons'][-1])

    assert firsts['composition'] == (5, latent_dim)
    assert firsts['coverage'] == (5, latent_dim)

    fasta.unlink()
    coverage.unlink()
    for key, filename in output.items():
        handles[key].close()
        filename.unlink()
Example #3
def test_load_model():
    '''
    Test if model can be loaded
    '''

    args = {'compo_neurons': TEST_ARCHITECTURE['composition']['neurons'],
            'cover_neurons': TEST_ARCHITECTURE['coverage']['neurons'],
            'cover_filters': TEST_ARCHITECTURE['coverage']['n_filters'],
            'cover_kernel': TEST_ARCHITECTURE['coverage']['kernel_size'],
            'cover_stride': TEST_ARCHITECTURE['coverage']['conv_stride'],
            'merge_neurons': TEST_ARCHITECTURE['merge']['neurons'],
            'kmer': 4, 'no_rc': True,
            'fragment_length': FL, 'wsize': WSIZE, 'wstep': WSTEP}

    cfg = Configuration()
    cfg.init_config(output='.', **args)
    cfg.io['h5'] = generate_h5_file(FL, filename='coverage.h5')

    model = initialize_model('CoCoNet', cfg.get_input_shapes(), cfg.get_architecture())
    model_path = Path('CoCoNet.pth')

    torch.save({
        'state': model.state_dict()
    }, model_path)

    loaded_model = load_model(cfg)

    model_path.unlink()
    cfg.io['h5'].unlink()

    assert isinstance(loaded_model, CoCoNet)
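
# The checkpoint above stores the weights under a 'state' key, so restoring it
# follows the standard torch pattern below (a sketch; the real load_model also
# rebuilds the architecture from the Configuration):
def load_checkpoint(model, path):
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['state'])
    model.eval()
    return model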
Example #4
def test_learn_save_load_model():
    '''
    Check:
    - if the training goes through
    - if the model is saved
    '''

    model = initialize_model('CoCoNet', TEST_SHAPES, TEST_ARCHITECTURE)
    model_file = Path('{}/test_model.pth'.format(LOCAL_DIR))
    results_file = Path('{}/test_res.csv'.format(LOCAL_DIR))

    pair_files = {'train': Path('{}/pairs_train.npy'.format(LOCAL_DIR)),
                  'test': Path('{}/pairs_test.npy'.format(LOCAL_DIR))}

    coverage_file = generate_h5_file(*TEST_CTG_LENGTHS, filename='coverage.h5')
    fasta_file = generate_fasta_file(*TEST_CTG_LENGTHS, save=True)

    fasta = [(seq.id, str(seq.seq)) for seq in SeqIO.parse(fasta_file, 'fasta')]

    make_pairs(fasta, STEP, FL, output=pair_files['train'], n_examples=50)
    make_pairs(fasta, STEP, FL, output=pair_files['test'], n_examples=5)

    train(model, fasta_file, coverage_file, pair_files, results_file, output=model_file,
          **TEST_LEARN_PRMS)

    tests = model_file.is_file() and results_file.is_file()

    for path in list(pair_files.values()) + [fasta_file, coverage_file, model_file, results_file]:
        path.unlink()

    assert tests
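
# Note the pattern here and in several other tests: compute the check first,
# delete the temporary artifacts, then assert, so the working directory stays
# clean even when the assertion is going to fail.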
Example #5
def test_load_data_cover():
    '''
    Test that the coverage batch generator yields paired windows of the expected shape
    '''

    contigs = generate_fasta_file(*TEST_CTG_LENGTHS, save=False)
    coverage_file = generate_h5_file(*TEST_CTG_LENGTHS, filename='coverage.h5')
    pairs_file = Path('pairs.npy').resolve()

    contigs = [(seq.id, str(seq.seq)) for seq in contigs]
    make_pairs(contigs, STEP, FL, output=pairs_file, n_examples=50)

    gen = CoverageGenerator(pairs_file, coverage_file,
                            batch_size=TEST_LEARN_PRMS['batch_size'],
                            load_batch=TEST_LEARN_PRMS['load_batch'],
                            wsize=TEST_LEARN_PRMS['wsize'],
                            wstep=TEST_LEARN_PRMS['wstep'])

    X1, X2 = next(gen)

    pairs_file.unlink()
    coverage_file.unlink()

    assert X1.shape == X2.shape
    assert X1.shape == (TEST_LEARN_PRMS['batch_size'], 2, 9)
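
# The trailing 9 in the asserted shape presumably counts the sliding windows
# per fragment, i.e. (FL - wsize) // wstep + 1 == 9 for these test constants,
# with the middle axis (2) likely being the number of samples in the h5 file.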
Example #6
def test_get_coverage_with_unmatched_ctg(window_size=4):
    '''
    Test that coverage extraction raises a KeyError when the pairs reference
    a contig that is missing from the h5 file
    '''

    pairs = generate_pair_file(save=False)
    data_h5 = generate_h5_file(30, filename='coverage.h5')

    with pytest.raises(KeyError):
        get_coverage(pairs, data_h5, window_size, window_size // 2)

    data_h5.unlink()
Example #7
def test_coverage_feature():
    '''
    Test that contig names can be listed from a coverage h5 file
    '''

    h5 = generate_h5_file(10, 20, filename='coverage.h5')
    f = CoverageFeature(path={'h5': h5})
    ctg = f.get_contigs()

    found_2_ctg = len(ctg) == 2
    h5.unlink()

    assert found_2_ctg
Example #8
def test_get_coverage(window_size=4):
    '''
    Test the coverage extraction against a slow, pure-python reference
    '''

    pairs = generate_pair_file(save=False)
    data_h5 = generate_h5_file(30, 40, filename='coverage.h5')

    (X1, X2) = get_coverage(pairs, data_h5, window_size, window_size // 2)
    (T1, T2) = slow_coverage(pairs, data_h5, window_size, window_size // 2)

    data_h5.unlink()

    assert np.sum(X1 != T1) + np.sum(X2 != T2) == 0
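
# get_coverage is checked against a pure-python reference. The core of such a
# reference is a sliding-window mean over raw per-base coverage; a minimal
# sketch of that step (window_average is a hypothetical helper, not part of
# coconet):
def window_average(values, wsize, wstep):
    # values: (n_samples, fragment_length) raw coverage for one fragment
    n_windows = (values.shape[1] - wsize) // wstep + 1
    return np.stack([values[:, i*wstep : i*wstep + wsize].mean(axis=1)
                     for i in range(n_windows)], axis=1)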
Example #9
def test_make_pregraph():
    '''
    Test that the contig pre-graph is computed and written to disk
    '''

    output = Path('pregraph.pkl')
    model = generate_rd_model()

    h5_data = [(name, generate_h5_file(8, 8, 8, n_samples=5))
               for name in ['composition', 'coverage']]

    make_pregraph(model, h5_data, output)

    assert output.is_file()

    output.unlink()
Example #10
def test_pairwise_comparisons():
    '''
    Test that edges are computed only for the requested contig pairs
    '''

    model = generate_rd_model()
    h5_data = [(f, generate_h5_file(8, 8, 8, n_samples=5))
               for f in ['composition', 'coverage']]

    pair_generators = (((x, y) for (x, y) in [('V0', 'V1'), ('V0', 'V2')]),)

    edges = compute_pairwise_comparisons(model, h5_data, pair_generators, vote_threshold=0.5)

    assert ('V0', 'V1') in edges
    assert ('V0', 'V2') in edges
    assert ('V1', 'V2') not in edges
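
# Only the pairs yielded by pair_generators are scored, which is why
# ('V1', 'V2') never gets an edge. For each candidate pair, the model
# presumably compares fragment-level latent representations and keeps a vote
# count of comparisons whose match probability exceeds vote_threshold.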
Example #11
def test_input_sizes():
    '''
    Check if input sizes are correct
    '''

    cfg = Configuration()
    cfg.init_config(output='test123',
                    kmer=4,
                    no_rc=False,
                    fragment_length=10,
                    wsize=4,
                    wstep=2)
    cfg.io['h5'] = generate_h5_file(10, filename='coverage.h5')

    input_shapes = {'composition': 136, 'coverage': (4, 2)}

    auto_shapes = cfg.get_input_shapes()

    shutil.rmtree(cfg.io['output'])

    assert input_shapes == auto_shapes
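
# The composition size follows from counting canonical 4-mers: with reverse
# complements collapsed (no_rc=False), there are (4**4 + 4**2) / 2 == 136
# distinct 4-mers. The coverage shape presumably combines the window count,
# (fragment_length - wsize) // wstep + 1 == (10 - 4) // 2 + 1 == 4, with the
# number of samples in the generated h5 file.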
Example #12
def test_refine_clustering():
    '''
    Test that clustering refinement recovers the expected contig clusters
    '''

    model = generate_rd_model()
    h5_data = [(k, generate_h5_file(*[8]*5, n_samples=5))
               for k in ['composition', 'coverage']]

    files = ['pre_graph.pkl', 'graph.pkl', 'assignments.csv']

    adj = np.array([[25, 24,  0, -1,  0],
                    [24, 25, -1,  0, -1],
                    [ 0, -1, 25, 25, 24],
                    [-1,  0, 25, 25, 23],
                    [ 0, -1, 23, 23, 25]])

    contigs = ["V{}".format(i) for i in range(5)]
    edges = [(f"V{i}", f"V{j}", adj[i, j]) for i in range(len(contigs)) for j in range(len(contigs))
             if adj[i, j] >= 0 and i>j]

    graph = igraph.Graph()
    graph.add_vertices(contigs)
    for i, j, w in edges:
        graph.add_edge(i, j, weight=w)

    graph.write_pickle(files[0])

    refine_clustering(model, h5_data, files[0],
                      graph_file=files[1],
                      assignments_file=files[2])

    clustering = pd.read_csv(files[2], header=None, index_col=0)[1]

    all_files = all(Path(f).is_file() for f in files)

    for f in files:
        Path(f).unlink()

    assert all_files
    assert clustering.loc['V0'] == clustering.loc['V1']
    assert clustering.loc['V2'] == clustering.loc['V3']
    assert clustering.loc['V3'] == clustering.loc['V4']
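
# The adjacency fixture encodes edge weights that look like vote counts with a
# maximum of 25: high weights tie (V0, V1) together and (V2, V3, V4) together,
# zeros separate the two groups, and -1 marks pairs excluded from the
# pre-graph (filtered by adj[i, j] >= 0), so the asserts check that refinement
# recovers exactly those two clusters.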