Example #1
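The snippets below omit their imports. A minimal preamble that makes them self-contained, assuming the VariantWorks package layout (module paths are a best guess and may differ by version):

import multiprocessing as mp
from functools import partial

import h5py
import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset as TorchDataset

from variantworks.base_encoder import base_enum_encoder
from variantworks.dataloader import ReadPileupDataLoader
from variantworks.encoders import PileupEncoder, ZygosityLabelEncoder
from variantworks.io.vcfio import VCFReader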
def test_pileup_unknown_layer():
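    # Referencing an undefined layer enum member should raise AttributeError.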
    max_reads = 100
    window_size = 5
    with pytest.raises(AttributeError):
        layers = [PileupEncoder.Layer.BLAH]
        PileupEncoder(window_size=window_size,
                      max_reads=max_reads,
                      layers=layers)
def test_snp_allele_encoding(snp_variant):
    max_reads = 1
    window_size = 5
    layers = [PileupEncoder.Layer.ALLELE]

    encoder = PileupEncoder(window_size=window_size,
                            max_reads=max_reads, layers=layers)

    variant = snp_variant
    encoding = encoder(variant)
    assert(encoding[0, 0, window_size] == base_enum_encoder[variant.allele])
def test_snp_ref_encoding(snp_variant):
    max_reads = 1
    window_size = 5
    layers = [PileupEncoder.Layer.REFERENCE]

    encoder = PileupEncoder(window_size=window_size,
                            max_reads=max_reads, layers=layers)

    variant = snp_variant
    encoding = encoder(variant)
    assert(encoding[0, 0, window_size] == base_enum_encoder[variant.ref])
def generate_hdf5(args):
    """Serialize encodings to HDF5.

    Generate encodings in a multiprocessing loop and save the tensors to HDF5.
    """
    # Get list of files from arguments
    # and generate the variant entries using VCF reader.
    bam = args.bam
    vcf_readers = []
    for tp_file in args.tp_files:
        vcf_readers.append(VCFReader(vcf=tp_file, bams=[bam], is_fp=False))
    for fp_file in args.fp_files:
        vcf_readers.append(VCFReader(vcf=fp_file, bams=[bam], is_fp=True))
    total_labels = sum([len(reader) for reader in vcf_readers])

    # Setup encoder for samples and labels.
    sample_encoder = PileupEncoder(
        window_size=100,
        max_reads=100,
        layers=[PileupEncoder.Layer.READ, PileupEncoder.Layer.BASE_QUALITY])
    label_encoder = ZygosityLabelEncoder()

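    # `encode` (defined elsewhere in this module) applies the sample and
    # label encoders to a variant and returns an (encoding, label) pair.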
    encode_func = partial(encode, sample_encoder, label_encoder)

    # Create HDF5 datasets.
    h5_file = h5py.File(args.output_file, "w")
    encoded_data = h5_file.create_dataset(
        "encodings",
        shape=(total_labels, sample_encoder.depth, sample_encoder.height,
               sample_encoder.width),
        dtype=np.float32,
        fillvalue=0)
    label_data = h5_file.create_dataset("labels",
                                        shape=(total_labels, ),
                                        dtype=np.int64,
                                        fillvalue=0)

    pool = mp.Pool(args.threads)
    print("Serializing {} entries...".format(total_labels))
    # Keep a single running index across all readers so entries from one
    # reader do not overwrite those written by the previous reader.
    label_idx = 0
    for vcf_reader in vcf_readers:
        for out in pool.imap(encode_func, vcf_reader):
            if label_idx % 1000 == 0:
                print("Saved {} entries".format(label_idx))
            encoding, label = out
            encoded_data[label_idx] = encoding
            label_data[label_idx] = label
            label_idx += 1
    print("Saved {} entries".format(total_labels))

    h5_file.close()
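A hypothetical invocation of generate_hdf5, using an argparse-style namespace in place of parsed CLI arguments (field names are taken from the function body; file paths are illustrative):

import argparse

args = argparse.Namespace(
    bam="sample.bam",                  # BAM shared by all VCFs
    tp_files=["true_positives.vcf"],   # VCFs labeled true positive
    fp_files=["false_positives.vcf"],  # VCFs labeled false positive
    output_file="encodings.h5",
    threads=4,
)
generate_hdf5(args)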
def test_snp_encoder_basic(snp_variant):
    max_reads = 100
    window_size = 10
    width = 2 * window_size + 1
    height = max_reads
    layers = [PileupEncoder.Layer.READ]

    encoder = PileupEncoder(window_size=window_size,
                            max_reads=max_reads, layers=layers)

    variant = snp_variant

    encoding = encoder(variant)
    assert(encoding.size() == torch.Size([len(layers), height, width]))
def test_deletion_read_encoding(deletion_variant):
    max_reads = 100
    window_size = 10
    width = 2 * window_size + 1
    height = max_reads
    layers = [PileupEncoder.Layer.READ, PileupEncoder.Layer.REFERENCE, PileupEncoder.Layer.ALLELE]

    encoder = PileupEncoder(window_size=window_size,
                            max_reads=max_reads, layers=layers)

    variant = deletion_variant

    encoding = encoder(variant)
    assert(encoding.size() == torch.Size([len(layers), height, width]))
Example #7
def generate_hdf5(args):
    """Serialize encodings to HDF5.

    Generate encodings in a multiprocessing loop and save the tensors to HDF5.
    """
    # Get list of files from arguments.
    bam = args.bam
    file_list = []
    for tp_file in args.tp_files:
        file_list.append(VCFReader.VcfBamPath(
            vcf=tp_file, bam=bam, is_fp=False))
    for fp_file in args.fp_files:
        file_list.append(VCFReader.VcfBamPath(
            vcf=fp_file, bam=bam, is_fp=True))

    # Generate the variant entries using VCF reader.
    vcf_reader = VCFReader(file_list)

    # Setup encoder for samples and labels.
    sample_encoder = PileupEncoder(window_size=100, max_reads=100,
                                   layers=[PileupEncoder.Layer.READ, PileupEncoder.Layer.BASE_QUALITY])
    label_encoder = ZygosityLabelEncoder()

    encode_func = partial(encode, sample_encoder, label_encoder)

    # Create HDF5 datasets.
    h5_file = h5py.File(args.output_file, "w")
    encoded_data = h5_file.create_dataset("encodings",
                                          shape=(len(vcf_reader), sample_encoder.depth,
                                                 sample_encoder.height, sample_encoder.width),
                                          dtype=np.float32, fillvalue=0)
    label_data = h5_file.create_dataset("labels",
                                        shape=(len(vcf_reader),), dtype=np.int64, fillvalue=0)

    pool = mp.Pool(args.threads)
    print("Serializing {} entries...".format(len(vcf_reader)))
    for i, out in enumerate(pool.imap(encode_func, vcf_reader)):
        if i % 1000 == 0:
            print("Saved {} entries".format(i))
        encoding, label = out
        encoded_data[i] = encoding
        label_data[i] = label
    print("Saved {} entries".format(len(vcf_reader)))

    h5_file.close()
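Reading the serialized datasets back is plain h5py; a short sketch (file name illustrative):

import h5py

with h5py.File("encodings.h5", "r") as f:
    encodings = f["encodings"][:]  # shape (N, depth, height, width), float32
    labels = f["labels"][:]        # shape (N,), int64 zygosity labels
print(encodings.shape, labels.shape)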
def test_snp_encoder_base_quality(snp_variant):
    max_reads = 100
    window_size = 5
    width = 2 * window_size + 1
    height = max_reads
    layers = [PileupEncoder.Layer.BASE_QUALITY]

    encoder = PileupEncoder(window_size=window_size,
                            max_reads=max_reads, layers=layers)

    variant = snp_variant

    encoding = encoder(variant)
    assert(encoding.size() == torch.Size([len(layers), height, width]))

    # Verify that all elements are <= 1 by first computing a bool tensor,
    # then converting it to a long tensor and summing all elements to match
    # against the total size.
    all_le_1 = (encoding <= 1.0).long()
    assert(torch.sum(all_le_1) == (height * width))
Example #9
    def __init__(self,
                 data_loader_type,
                 variant_loaders,
                 batch_size=32,
                 shuffle=True,
                 num_workers=4,
                 sample_encoder=PileupEncoder(
                     window_size=100,
                     max_reads=100,
                     layers=[PileupEncoder.Layer.READ]),
                 label_encoder=ZygosityLabelEncoder()):
        """Construct a data loader.

        Args:
            data_loader_type : Type of data loader (ReadPileupDataLoader.Type.TRAIN/EVAL/TEST)
            variant_loaders : A list of loader classes for variants
            batch_size : batch size for data loader [32]
            shuffle : shuffle dataset [True]
            num_workers : number of parallel data loader threads [4]
            sample_encoder : Custom pileup encoder for variant [READ pileup encoding, window size 100]
            label_encoder : Custom label encoder for variant [ZygosityLabelEncoder] (only applicable
                when type=TRAIN/EVAL)

        Returns:
            Instance of class.
        """
        super().__init__()
        self.data_loader_type = data_loader_type
        self.variant_loaders = variant_loaders
        self.sample_encoder = sample_encoder
        self.label_encoder = label_encoder

        class DatasetWrapper(TorchDataset):
            """A wrapper around Torch dataset class to generate individual samples."""
            def __init__(self, data_loader_type, sample_encoder,
                         variant_loaders, label_encoder):
                """Construct a dataset wrapper.

                Args:
                    data_loader_type : Type of data loader
                    sample_encoder : Custom pileup encoder for variant
                    variant_loaders : A list of loader classes for variants
                    label_encoder : Custom label encoder for variant

                Returns:
                    Instance of class.
                """
                super().__init__()
                self.variant_loaders = variant_loaders
                self.label_encoder = label_encoder
                self.sample_encoder = sample_encoder
                self.data_loader_type = data_loader_type

                self._len = sum(
                    [len(loader) for loader in self.variant_loaders])

            def _map_idx_to_sample(self, sample_idx):
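                # Walk the loaders in order, subtracting each loader's length
                # from the global index until it falls inside one of them.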
                file_idx = 0
                while file_idx < len(self.variant_loaders):
                    if sample_idx < len(self.variant_loaders[file_idx]):
                        return self.variant_loaders[file_idx][sample_idx]
                    else:
                        sample_idx -= len(self.variant_loaders[file_idx])
                        file_idx += 1
                raise RuntimeError(
                    "Could not map sample index to file. This is a bug.")

            def __len__(self):
                return self._len

            def __getitem__(self, idx):
                sample = self._map_idx_to_sample(idx)

                if self.data_loader_type == ReadPileupDataLoader.Type.TEST:
                    sample = self.sample_encoder(sample)

                    return sample
                else:
                    encoding = self.sample_encoder(sample)
                    label = self.label_encoder(sample)

                    return label, encoding

        dataset = DatasetWrapper(data_loader_type, self.sample_encoder,
                                 self.variant_loaders, self.label_encoder)
        self.dataloader = TorchDataLoader(dataset,
                                          batch_size=batch_size,
                                          shuffle=shuffle,
                                          num_workers=num_workers)
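A hypothetical training-time construction of this loader, assuming VCFReader instances (built as in the first generate_hdf5 example) serve as the variant loaders; TRAIN/EVAL items are (label, encoding) pairs, as __getitem__ above shows:

tp_reader = VCFReader(vcf="tp.vcf", bams=["sample.bam"], is_fp=False)
fp_reader = VCFReader(vcf="fp.vcf", bams=["sample.bam"], is_fp=True)
train_loader = ReadPileupDataLoader(ReadPileupDataLoader.Type.TRAIN,
                                    [tp_reader, fp_reader],
                                    batch_size=64)
for labels, encodings in train_loader.dataloader:
    pass  # feed each batch to the model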
Example #10
    def __init__(self,
                 data_loader_type,
                 variant_loader,
                 batch_size=32,
                 shuffle=True,
                 num_workers=4,
                 sample_encoder=PileupEncoder(
                     window_size=100,
                     max_reads=100,
                     layers=[PileupEncoder.Layer.READ]),
                 label_encoder=ZygosityLabelEncoder()):
        """Construct a data loader.

        Args:
            data_loader_type : Type of data loader (ReadPileupDataLoader.Type.TRAIN/EVAL/TEST)
            variant_loader : A loader class for variants
            batch_size : batch size for data loader [32]
            shuffle : shuffle dataset [True]
            num_workers : number of parallel data loader threads [4]
            sample_encoder : Custom pileup encoder for variant [READ pileup encoding, window size 100]
            label_encoder : Custom label encoder for variant [ZygosityLabelEncoder] (only applicable
                when type=TRAIN/EVAL)

        Returns:
            Instance of class.
        """
        super().__init__()
        self.data_loader_type = data_loader_type
        self.variant_loader = variant_loader
        self.sample_encoder = sample_encoder
        self.label_encoder = label_encoder

        class DatasetWrapper(TorchDataset):
            """A wrapper around Torch dataset class to generate individual samples."""
            def __init__(self, data_loader_type, sample_encoder,
                         variant_loader, label_encoder):
                """Construct a dataset wrapper.

                Args:
                    data_loader_type : Type of data loader
                    sample_encoder : Custom pileup encoder for variant
                    variant_loader : A loader class for variants
                    label_encoder : Custom label encoder for variant

                Returns:
                    Instance of class.
                """
                super().__init__()
                self.variant_loader = variant_loader
                self.label_encoder = label_encoder
                self.sample_encoder = sample_encoder
                self.data_loader_type = data_loader_type

            def __len__(self):
                return len(self.variant_loader)

            def __getitem__(self, idx):
                sample = self.variant_loader[idx]

                if self.data_loader_type == ReadPileupDataLoader.Type.TEST:
                    sample = self.sample_encoder(sample)

                    return sample
                else:
                    encoding = self.sample_encoder(sample)
                    label = self.label_encoder(sample)

                    return label, encoding

        dataset = DatasetWrapper(data_loader_type, self.sample_encoder,
                                 self.variant_loader, self.label_encoder)
        self.dataloader = TorchDataLoader(dataset,
                                          batch_size=batch_size,
                                          shuffle=shuffle,
                                          num_workers=num_workers)
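For this single-loader variant, a hypothetical inference-time sketch; with Type.TEST each item is just the pileup encoding, so batches carry no labels (the reader construction reuses Example #7's file_list):

vcf_reader = VCFReader(file_list)
test_loader = ReadPileupDataLoader(ReadPileupDataLoader.Type.TEST,
                                   vcf_reader,
                                   shuffle=False)
for encodings in test_loader.dataloader:
    pass  # run the trained model on each batch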