Esempio n. 1
0
def test_vcf_outputting(monkeypatch):
    """Write inference output into vcf files and verify the written genotypes.

    Two mocked VCF inputs are loaded while ``vcf.Reader.__init__`` is patched
    with ``MockPyVCFReader.new_vcf_reader_init``. One inferred zygosity per
    variant is then written through ``VCFResultWriter``. Finally each output
    record's ``GT`` field is compared with
    ``result_writer.zygosity_to_vcf_genotype`` for the matching inferred value.
    The output directory is removed even when a validation assertion fails.
    """
    first_vcf_bam_tuple = VCFReader.VcfBamPath(vcf="/dummy/path1.gz",
                                               bam="temp.bam",
                                               is_fp=False)
    second_vcf_bam_tuple = VCFReader.VcfBamPath(vcf="/dummy/path2.gz",
                                                bam="temp.bam",
                                                is_fp=False)
    with monkeypatch.context() as mp:
        mp.setattr(vcf.Reader, "__init__", MockPyVCFReader.new_vcf_reader_init)
        vcf_loader = VCFReader([first_vcf_bam_tuple, second_vcf_bam_tuple])
    inferred_results = [
        VariantZygosity.HOMOZYGOUS, VariantZygosity.HOMOZYGOUS,
        VariantZygosity.HETEROZYGOUS, VariantZygosity.HETEROZYGOUS,
        VariantZygosity.HOMOZYGOUS, VariantZygosity.HETEROZYGOUS
    ]
    assert len(inferred_results) == len(vcf_loader)
    with monkeypatch.context() as mp:
        mp.setattr(vcf.Reader, "__init__", MockPyVCFReader.new_vcf_reader_init)
        result_writer = VCFResultWriter(vcf_loader, inferred_results)
        result_writer.write_output()
    # From here on the output directory exists; always clean it up, even if a
    # check below fails, so failed runs do not leave stale files behind.
    try:
        # Validate output files format and make sure the outputted genotype for
        # each record matches the network output (the real vcf.Reader is used
        # here because the monkeypatch context has exited).
        i = 0
        for f in ['inferred_path1.vcf', 'inferred_path2.vcf']:
            vcf_reader = vcf.Reader(
                filename=os.path.join(result_writer.output_location, f))
            for record in vcf_reader:
                assert (record.samples[0]['GT'] == result_writer.
                        zygosity_to_vcf_genotype[inferred_results[i]])
                i += 1
        # Every inferred result must have been written exactly once.
        assert i == len(inferred_results)
    finally:
        shutil.rmtree(result_writer.output_location)
Esempio n. 2
0
def test_vcf_load_variant_from_multiple_files(get_created_vcf_tabix_files):
    """Loading the same mocked VCF file twice yields twice as many variants."""
    vcf_path, tbi_path = get_created_vcf_tabix_files(mock_file_input())
    single_entry = VCFReader.VcfBamPath(vcf=vcf_path, bam=tbi_path, is_fp=False)
    repeated_entry = VCFReader.VcfBamPath(vcf=vcf_path, bam=tbi_path, is_fp=False)
    single_reader = VCFReader([single_entry])
    double_reader = VCFReader([single_entry, repeated_entry])
    assert len(double_reader) == len(single_reader) * 2
Esempio n. 3
0
def test_vcf_load_variant_from_multiple_files(monkeypatch):
    """Loading the same mocked VCF input twice yields twice as many variants."""
    first_entry, second_entry = (
        VCFReader.VcfBamPath(vcf="/dummy/path.gz", bam="temp.bam", is_fp=False)
        for _ in range(2)
    )
    single_reader = MockPyVCFReader.get_vcf(monkeypatch, [first_entry])
    double_reader = MockPyVCFReader.get_vcf(monkeypatch,
                                            [first_entry, second_entry])
    assert len(double_reader) == len(single_reader) * 2
Esempio n. 4
0
def test_load_vcf_content_with_wrong_format(get_created_vcf_tabix_files):
    """Constructing a VCFReader over a malformed VCF file raises RuntimeError."""
    vcf_path, tbi_path = get_created_vcf_tabix_files(mock_invalid_file_input())
    malformed_entry = VCFReader.VcfBamPath(vcf=vcf_path, bam=tbi_path, is_fp=False)
    with pytest.raises(RuntimeError):
        VCFReader([malformed_entry])
Esempio n. 5
0
def test_vcf_loader_snps(get_created_vcf_tabix_files):
    """Load all variants from the mocked VCF; 13 remain after SNP, multi-allele and multi-sample filtering."""
    vcf_path, tbi_path = get_created_vcf_tabix_files(mock_file_input())
    reader = VCFReader([
        VCFReader.VcfBamPath(vcf=vcf_path, bam=tbi_path, is_fp=False)
    ])
    expected_variant_count = 13
    assert len(reader) == expected_variant_count
Esempio n. 6
0
def test_vcf_load_fp(get_created_vcf_tabix_files):
    """Load a false-positive VCF file and check the zygosity of every variant.

    Every variant read from an entry whose ``is_fp`` flag is True must have
    zygosity ``VariantZygosity.NO_VARIANT``. The loader is first asserted to be
    non-empty so that the per-variant check cannot pass vacuously.
    """
    vcf_file_path, tabix_file_path = get_created_vcf_tabix_files(mock_file_input())
    vcf_bam_tuple = VCFReader.VcfBamPath(vcf=vcf_file_path, bam=tabix_file_path, is_fp=True)
    vcf_loader = VCFReader([vcf_bam_tuple])
    # Without this guard an empty loader would skip the loop and the test would
    # pass without checking anything.
    assert len(vcf_loader) > 0
    for v in vcf_loader:
        assert v.zygosity == VariantZygosity.NO_VARIANT
Esempio n. 7
0
def test_vcf_loader_snps(monkeypatch):
    """Load all variants from the mocked VCF stream; 13 remain after SNP, multi-allele and multi-sample filtering."""
    dummy_entry = VCFReader.VcfBamPath(vcf="/dummy/path.gz", bam="temp.bam", is_fp=False)
    loaded_variants = MockPyVCFReader.get_vcf(monkeypatch, [dummy_entry])
    expected_variant_count = 13
    assert len(loaded_variants) == expected_variant_count
Esempio n. 8
0
def test_load_vcf_content_with_wrong_format(monkeypatch):
    """Loading mocked malformed VCF content raises RuntimeError."""
    dummy_entry = VCFReader.VcfBamPath(vcf="/dummy/path.gz", bam="temp.bam", is_fp=False)
    with pytest.raises(RuntimeError):
        MockPyVCFReader.get_invalid_vcf(monkeypatch, [dummy_entry])
Esempio n. 9
0
def generate_hdf5(args):
    """Serialize encodings to HDF5.

    Builds a ``VCFReader`` over every true-positive (``args.tp_files``) and
    false-positive (``args.fp_files``) VCF file, each paired with ``args.bam``.
    Every variant is encoded in a worker pool of ``args.threads`` processes.
    The results are written to ``args.output_file`` as two datasets:

    * ``"encodings"``: float32, shape ``(N, depth, height, width)`` taken from
      the pileup encoder.
    * ``"labels"``: int64, shape ``(N,)``.

    N is ``len(vcf_reader)``. The HDF5 file is closed and the worker pool is
    shut down even if encoding or writing raises.

    Args:
        args: Parsed command-line namespace; reads ``bam``, ``tp_files``,
            ``fp_files``, ``output_file`` and ``threads``.
    """
    # Get list of files from arguments: true positives first, then false positives.
    bam = args.bam
    file_list = [VCFReader.VcfBamPath(vcf=tp_file, bam=bam, is_fp=False)
                 for tp_file in args.tp_files]
    file_list.extend(VCFReader.VcfBamPath(vcf=fp_file, bam=bam, is_fp=True)
                     for fp_file in args.fp_files)

    # Generate the variant entries using VCF reader.
    vcf_reader = VCFReader(file_list)
    num_entries = len(vcf_reader)

    # Setup encoder for samples and labels.
    sample_encoder = PileupEncoder(window_size=100, max_reads=100,
                                   layers=[PileupEncoder.Layer.READ, PileupEncoder.Layer.BASE_QUALITY])
    label_encoder = ZygosityLabelEncoder()

    encode_func = partial(encode, sample_encoder, label_encoder)

    # Context managers close the HDF5 file and tear down the worker processes
    # on both success and failure; previously neither was released on error
    # and the pool was never closed at all.
    with h5py.File(args.output_file, "w") as h5_file:
        # Create HDF5 datasets.
        encoded_data = h5_file.create_dataset("encodings",
                                              shape=(num_entries, sample_encoder.depth,
                                                     sample_encoder.height, sample_encoder.width),
                                              dtype=np.float32, fillvalue=0)
        label_data = h5_file.create_dataset("labels",
                                            shape=(num_entries,), dtype=np.int64, fillvalue=0)

        with mp.Pool(args.threads) as pool:
            print("Serializing {} entries...".format(num_entries))
            # Pool.imap yields results in input order, so row i corresponds to
            # the i-th variant produced by vcf_reader.
            for i, (encoding, label) in enumerate(pool.imap(encode_func, vcf_reader)):
                if i % 1000 == 0:
                    print("Saved {} entries".format(i))
                encoded_data[i] = encoding
                label_data[i] = label
        print("Saved {} entries".format(num_entries))
Esempio n. 10
0
def test_vcf_load_fp(monkeypatch):
    """Load a false-positive mocked VCF stream and check the zygosity of every variant.

    Every variant read from an entry whose ``is_fp`` flag is True must have
    zygosity ``VariantZygosity.NO_VARIANT``. The loader is first asserted to be
    non-empty so that the per-variant check cannot pass vacuously.
    """
    vcf_bam_tuple = VCFReader.VcfBamPath(vcf="/dummy/path.gz",
                                         bam="temp.bam",
                                         is_fp=True)
    vcf_loader = MockPyVCFReader.get_vcf(monkeypatch, [vcf_bam_tuple])
    # Without this guard an empty loader would skip the loop and the test would
    # pass without checking anything.
    assert len(vcf_loader) > 0
    for v in vcf_loader:
        assert v.zygosity == VariantZygosity.NO_VARIANT
Esempio n. 11
0
def test_vcf_fetch_variant(get_created_vcf_tabix_files):
    """Fetch the first variant from a mocked VCF file and check it is a ``Variant``.

    An ``IndexError`` from ``vcf_loader[0]`` is reported as a test failure
    rather than an error. Only the indexing sits inside the ``try``, so an
    ``IndexError`` raised elsewhere is not misreported.
    """
    vcf_file_path, tabix_file_path = get_created_vcf_tabix_files(mock_file_input())
    vcf_bam_tuple = VCFReader.VcfBamPath(vcf=vcf_file_path, bam=tabix_file_path, is_fp=False)
    vcf_loader = VCFReader([vcf_bam_tuple])
    try:
        first_variant = vcf_loader[0]
    except IndexError:
        pytest.fail("Can not retrieve first element from VCFReader")
    assert isinstance(first_variant, Variant)
Esempio n. 12
0
def test_vcf_fetch_variant(monkeypatch):
    """Fetch the first variant from a mocked VCF stream and check it is a ``Variant``.

    An ``IndexError`` from ``vcf_loader[0]`` is reported as a test failure
    rather than an error. Only the indexing sits inside the ``try``, so an
    ``IndexError`` raised elsewhere is not misreported.
    """
    vcf_bam_tuple = VCFReader.VcfBamPath(vcf="/dummy/path.gz",
                                         bam="temp.bam",
                                         is_fp=False)
    vcf_loader = MockPyVCFReader.get_vcf(monkeypatch, [vcf_bam_tuple])
    try:
        first_variant = vcf_loader[0]
    except IndexError:
        pytest.fail("Can not retrieve first element from VCFReader")
    assert isinstance(first_variant, Variant)
Esempio n. 13
0
def test_vcf_outputting(get_created_vcf_tabix_files):
    """Write inference output into vcf files and verify the written genotypes.

    Two small filtered VCF files are created and loaded. One inferred zygosity
    per variant is written with ``VCFResultWriter``. Each output record's
    ``GT`` field is then compared with
    ``result_writer.zygosity_to_vcf_genotype`` for the matching inferred value.
    The output directory is removed even when a validation assertion fails.
    """
    first_vcf_file_path, first_tabix_file_path = get_created_vcf_tabix_files(
        mock_small_filtered_file_input())
    second_vcf_file_path, second_tabix_file_path = get_created_vcf_tabix_files(
        mock_small_filtered_file_input())
    first_vcf_bam_tuple = VCFReader.VcfBamPath(vcf=first_vcf_file_path,
                                               bam=first_tabix_file_path,
                                               is_fp=False)
    second_vcf_bam_tuple = VCFReader.VcfBamPath(vcf=second_vcf_file_path,
                                                bam=second_tabix_file_path,
                                                is_fp=False)
    vcf_loader = VCFReader([first_vcf_bam_tuple, second_vcf_bam_tuple])

    inferred_results = [
        VariantZygosity.HOMOZYGOUS, VariantZygosity.HOMOZYGOUS,
        VariantZygosity.HETEROZYGOUS, VariantZygosity.HETEROZYGOUS,
        VariantZygosity.HOMOZYGOUS, VariantZygosity.HETEROZYGOUS
    ]
    assert len(inferred_results) == len(vcf_loader)

    result_writer = VCFResultWriter(vcf_loader, inferred_results)
    result_writer.write_output()

    def expected_output_name(vcf_path):
        # "inferred_" + the basename with its last two dot-separated parts
        # (e.g. ".vcf.gz") dropped and the remaining parts joined with no
        # separator + ".vcf". Presumably this mirrors VCFResultWriter's naming
        # scheme; verify against the writer if either changes.
        stem = "".join(os.path.basename(vcf_path).split('.')[0:-2])
        return '{}_{}.{}'.format("inferred", stem, 'vcf')

    # From here on the output directory exists; always clean it up, even if a
    # check below fails, so failed runs do not leave stale files behind.
    try:
        # Validate output files format and make sure the outputted genotype for
        # each record matches the network output.
        i = 0
        for f in [expected_output_name(first_vcf_file_path),
                  expected_output_name(second_vcf_file_path)]:
            vcf_reader = vcf.Reader(
                filename=os.path.join(result_writer.output_location, f))
            for record in vcf_reader:
                assert (record.samples[0]['GT'] == result_writer.
                        zygosity_to_vcf_genotype[inferred_results[i]])
                i += 1
        # Every inferred result must have been written exactly once.
        assert i == len(inferred_results)
    finally:
        shutil.rmtree(result_writer.output_location)
Esempio n. 14
0
def test_simple_vc_infer():
    """Run inference from the checkpoint in ``<data folder>/.test_model`` and decode the results.

    This test loads a model checkpoint from ``.test_model`` in the test data
    folder, which is the same directory that ``test_simple_vc_trainer`` writes
    to. NOTE(review): it appears to depend on that test having run first;
    confirm the intended ordering.

    Predictions from every batch are decoded to zygosity labels. The test
    asserts that there is one label per variant, then deletes the model
    directory.
    """
    # Load checkpointed model and run inference
    test_data_dir = get_data_folder()
    model_dir = os.path.join(test_data_dir, ".test_model")

    # Create neural factory
    nf = nemo.core.NeuralModuleFactory(
        placement=nemo.core.neural_factory.DeviceType.GPU,
        checkpoint_dir=model_dir)

    # Generate dataset
    bam = os.path.join(test_data_dir, "small_bam.bam")
    labels = os.path.join(test_data_dir, "candidates.vcf.gz")
    vcf_bam_tuple = VCFReader.VcfBamPath(vcf=labels, bam=bam, is_fp=False)
    vcf_loader = VCFReader([vcf_bam_tuple])
    test_dataset = ReadPileupDataLoader(ReadPileupDataLoader.Type.TEST,
                                        vcf_loader,
                                        batch_size=32,
                                        shuffle=False)

    # Neural Network
    alexnet = AlexNet(num_input_channels=1, num_output_logits=3)

    # Create inference DAG
    encoding = test_dataset()
    vz = alexnet(encoding=encoding)

    # Invoke the "infer" action.
    results = nf.infer([vz], checkpoint_dir=model_dir, verbose=True)

    # Decode inference results to labels. Accumulate across all batches.
    # Previously the list was rebound on every batch, so only the last batch
    # was counted, and it was undefined when there were no batches.
    zyg_decoder = ZygosityLabelDecoder()
    inferred_zygosity = []
    for tensor_batches in results:
        for batch in tensor_batches:
            predicted_classes = torch.argmax(batch, dim=1)
            inferred_zygosity.extend(
                zyg_decoder(pred) for pred in predicted_classes)

    assert len(inferred_zygosity) == len(vcf_loader)

    shutil.rmtree(model_dir)
Esempio n. 15
0
def test_simple_vc_trainer():
    """Train AlexNet for one epoch on the bundled test data and check that a checkpoint is written.

    Uses ``small_bam.bam`` and ``candidates.vcf.gz`` from ``get_data_folder()``,
    with modules placed on GPU (``DeviceType.GPU``). The model is evaluated on
    the same data it trains on. After training, the test asserts that
    ``AlexNet-EPOCH-1.pt`` exists in ``<data folder>/.test_model``.
    This test does not delete that directory; ``test_simple_vc_infer`` uses the
    same path and removes it.
    """
    # Train a sample model with test data

    # Create neural factory; checkpoints are written under model_dir.
    model_dir = os.path.join(get_data_folder(), ".test_model")
    nf = nemo.core.NeuralModuleFactory(
        placement=nemo.core.neural_factory.DeviceType.GPU,
        checkpoint_dir=model_dir)

    # Generate dataset
    bam = os.path.join(get_data_folder(), "small_bam.bam")
    labels = os.path.join(get_data_folder(), "candidates.vcf.gz")
    vcf_bam_tuple = VCFReader.VcfBamPath(vcf=labels, bam=bam, is_fp=False)
    vcf_loader = VCFReader([vcf_bam_tuple])

    # Neural Network (shared by the training and evaluation DAGs below)
    alexnet = AlexNet(num_input_channels=1, num_output_logits=3)

    # Create train DAG
    dataset_train = ReadPileupDataLoader(ReadPileupDataLoader.Type.TRAIN,
                                         vcf_loader,
                                         batch_size=32,
                                         shuffle=True)
    vz_ce_loss = CrossEntropyLossNM(logits_ndim=2)
    vz_labels, encoding = dataset_train()
    vz = alexnet(encoding=encoding)
    vz_loss = vz_ce_loss(logits=vz, labels=vz_labels)

    # Create evaluation DAG using same dataset as training
    dataset_eval = ReadPileupDataLoader(ReadPileupDataLoader.Type.EVAL,
                                        vcf_loader,
                                        batch_size=32,
                                        shuffle=False)
    vz_ce_loss_eval = CrossEntropyLossNM(logits_ndim=2)
    vz_labels_eval, encoding_eval = dataset_eval()
    vz_eval = alexnet(encoding=encoding_eval)
    vz_loss_eval = vz_ce_loss_eval(logits=vz_eval, labels=vz_labels_eval)

    # Logger callback: reports the training loss/logits/labels every step.
    logger_callback = nemo.core.SimpleLossLoggerCallback(
        tensors=[vz_loss, vz, vz_labels],
        step_freq=1,
    )

    # Evaluation callback run every step (eval_step=1); the per-iteration and
    # end-of-epoch hooks (eval_iter_callback / eval_epochs_done_callback) are
    # defined elsewhere.
    evaluator_callback = nemo.core.EvaluatorCallback(
        eval_tensors=[vz_loss_eval, vz_eval, vz_labels_eval],
        user_iter_callback=eval_iter_callback,
        user_epochs_done_callback=eval_epochs_done_callback,
        eval_step=1,
    )

    # Checkpointing models through NeMo callback
    checkpoint_callback = nemo.core.CheckpointCallback(
        folder=nf.checkpoint_dir,
        load_from_folder=None,
        # Checkpointing frequency in steps
        step_freq=-1,
        # Checkpointing frequency in epochs
        epoch_freq=1,
        # Number of checkpoints to keep
        checkpoints_to_keep=1,
        # If True, CheckpointCallback will raise an Error if restoring fails
        force_load=False)

    # Invoke the "train" action: one epoch of Adam at lr=0.001.
    nf.train(
        [vz_loss],
        callbacks=[logger_callback, checkpoint_callback, evaluator_callback],
        optimization_params={
            "num_epochs": 1,
            "lr": 0.001
        },
        optimizer="adam")

    # The epoch-level checkpoint for AlexNet must have been written.
    assert (os.path.exists(os.path.join(model_dir, "AlexNet-EPOCH-1.pt")))