def get_binary_input_func(files_spec_path, input_length, reference_fasta):
    """
    Returns a data function needed to run binary models. This data function will
    take in an N-array of bin indices, and return the corresponding data needed
    to run the model.
    Arguments:
        `files_spec_path`: path to the JSON files spec for the model
        `input_length`: length of input sequence
        `reference_fasta`: path to reference fasta
    Returns a function that takes in an N-array of bin indices and returns the
    following: the N x I x 4 one-hot encoded sequences (where I is the input
    length), the N x T array of output values (where T is the number of
    tasks), and the N x 3 object array of input sequence coordinates.
    """
    with open(files_spec_path, "r") as f:
        files_spec = json.load(f)

    # Maps coordinates to 1-hot encoded sequence
    coords_to_seq = feature_util.CoordsToSeq(reference_fasta,
                                             center_size_to_use=input_length)

    # Maps bin indices to their coordinates and output values
    bins_to_vals = make_binary_dataset.BinsToVals(files_spec["labels_hdf5"])

    def input_func(bin_inds):
        coords, output_vals = bins_to_vals(bin_inds)
        input_seqs = coords_to_seq(coords)
        return input_seqs, output_vals, coords

    return input_func
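

# A minimal usage sketch for `get_binary_input_func` (not part of the original
# module): the files-spec path, input length, FASTA path, and bin indices below
# are illustrative assumptions, not values prescribed by this code.
def _example_get_binary_input_func():
    input_func = get_binary_input_func(
        "/path/to/files_spec.json",  # hypothetical JSON files spec
        input_length=1000,  # assumed input sequence length
        reference_fasta="/path/to/hg38.fasta"  # hypothetical reference FASTA
    )
    # Fetch the data for three arbitrary bin indices
    input_seqs, output_vals, coords = input_func(np.array([0, 1, 2]))
    # input_seqs: 3 x 1000 x 4, output_vals: 3 x T, coords: 3 x 3
    return input_seqs, output_vals, coords

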
def get_profile_input_func(files_spec_path, input_length, profile_length,
                           reference_fasta):
    """
    Returns a data function needed to run profile models. This data function
    will take in an N x 3 object array of coordinates, and return the
    corresponding data needed to run the model.
    Arguments:
        `files_spec_path`: path to the JSON files spec for the model
        `input_length`: length of input sequence
        `profile_length`: length of output profiles
        `reference_fasta`: path to reference fasta
    Returns a function that takes in an N x 3 object array of coordinates and
    returns the following: the N x I x 4 one-hot encoded sequences (where I is
    the input length), and the N x (T or T + 1 or 2T) x O x 2 profiles, where
    O is the profile length and the task dimension may include controls.
    """
    with open(files_spec_path, "r") as f:
        files_spec = json.load(f)

    # Maps coordinates to 1-hot encoded sequence
    coords_to_seq = feature_util.CoordsToSeq(reference_fasta,
                                             center_size_to_use=input_length)

    # Maps coordinates to profiles
    coords_to_vals = make_profile_dataset.CoordsToVals(
        files_spec["profile_hdf5"], profile_length)

    def input_func(coords):
        input_seq = coords_to_seq(coords)
        profs = coords_to_vals(coords)
        # Swap the task and length axes so the profiles are returned as
        # N x (T or T + 1 or 2T) x O x 2, as documented above
        return input_seq, np.swapaxes(profs, 1, 2)

    return input_func
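

# A minimal usage sketch for `get_profile_input_func` (not part of the original
# module): the paths, lengths, and coordinates below are illustrative
# assumptions. Coordinates are passed as an N x 3 object array of
# (chrom, start, end).
def _example_get_profile_input_func():
    input_func = get_profile_input_func(
        "/path/to/files_spec.json",  # hypothetical JSON files spec
        input_length=2114,  # assumed input sequence length
        profile_length=1000,  # assumed output profile length
        reference_fasta="/path/to/hg38.fasta"  # hypothetical reference FASTA
    )
    coords = np.array([
        ["chr1", 1000000, 1001000],
        ["chr2", 2000000, 2001000]
    ], dtype=object)
    input_seqs, profiles = input_func(coords)
    # input_seqs: 2 x 2114 x 4; profiles: 2 x (T or T + 1 or 2T) x 1000 x 2
    return input_seqs, profiles

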
def create_data_loader(peaks_bed_paths,
                       peak_bed_trans_paths,
                       profile_hdf5_path,
                       profile_trans_hdf5_path,
                       sampling_type,
                       batch_size,
                       reference_fasta,
                       chrom_sizes_tsv,
                       input_length,
                       profile_length,
                       negative_ratio,
                       peak_tiling_stride,
                       peak_retention,
                       num_workers,
                       revcomp,
                       jitter_size,
                       negative_seed,
                       shuffle_seed,
                       jitter_seed,
                       sig_thresh,
                       chrom_set=None,
                       shuffle=True,
                       return_coords=False):
    """
    Creates an IterableDataset object, which iterates through batches of
    coordinates and returns profiles for the coordinates.
    Arguments:
        `peaks_bed_paths`: a list of paths to gzipped 6-column BED files
            containing the coordinates of positive-binding regions
        `peak_bed_trans_paths`: a list of paths to gzipped 6-column BED files
            for the trans peak set, in the same format as `peaks_bed_paths`
        `profile_hdf5_path`: path to HDF5 containing reads mapped to each
            coordinate; this HDF5 must be organized by chromosome, with each
            dataset being L x S x 2, where L is the length of the chromosome,
            S is the number of tracks stored, and 2 is for each strand
        `profile_trans_hdf5_path`: path to HDF5 containing reads for the trans
            tracks, organized the same way as `profile_hdf5_path`
        `sampling_type`: one of the "SamplingCoordsBatcher" variants
            ("SamplingCoordsBatcher", "SamplingCoordsBatcherIntersect", or
            "SamplingCoordsBatcherUnion"), which sample positive and negative
            regions; one of the "SummitCenteringCoordsBatcher" variants (e.g.
            "SummitCenteringCoordsBatcherToSig" or
            "SummitCenteringCoordsBatcherUnion"), which take only positive
            regions centered around summits; any other value yields a
            "PeakTilingCoordsBatcher", which takes only positive regions tiled
            across peaks
        `chrom_set`: a list of chromosomes to restrict to for the positives and
            sampled negatives; defaults to all coordinates in the given BEDs and
            sampling over the entire genome
        `shuffle`: if specified, shuffle the coordinates before each epoch
        `return_coords`: if specified, also return the underlying coordinates
            and peak data along with the profiles in each batch
    """

    # Maps set of coordinates to profiles
    coords_to_vals = CoordsToVals(profile_hdf5_path, profile_length)
    coords_to_vals_trans = CoordsToVals(profile_trans_hdf5_path,
                                        profile_length)

    if sampling_type == "SamplingCoordsBatcher":
        # Randomly samples from genome
        genome_sampler = GenomeIntervalSampler(chrom_sizes_tsv,
                                               input_length,
                                               chroms_keep=chrom_set,
                                               seed=negative_seed)
        # Yields batches of positive and negative coordinates
        coords_batcher = SamplingCoordsBatcher(peaks_bed_paths,
                                               peak_bed_trans_paths,
                                               batch_size,
                                               negative_ratio,
                                               jitter_size,
                                               chrom_sizes_tsv,
                                               input_length,
                                               genome_sampler,
                                               chroms_keep=chrom_set,
                                               peak_retention=peak_retention,
                                               return_peaks=return_coords,
                                               shuffle_before_epoch=shuffle,
                                               jitter_seed=jitter_seed,
                                               shuffle_seed=shuffle_seed)

    elif sampling_type == "SamplingCoordsBatcherIntersect":
        # Randomly samples from genome
        genome_sampler = GenomeIntervalSampler(chrom_sizes_tsv,
                                               input_length,
                                               chroms_keep=chrom_set,
                                               seed=negative_seed)
        # Yields batches of positive and negative coordinates
        coords_batcher = SamplingCoordsBatcher(peaks_bed_paths,
                                               peak_bed_trans_paths,
                                               batch_size,
                                               negative_ratio,
                                               jitter_size,
                                               chrom_sizes_tsv,
                                               input_length,
                                               genome_sampler,
                                               chroms_keep=chrom_set,
                                               peak_retention=peak_retention,
                                               return_peaks=return_coords,
                                               shuffle_before_epoch=shuffle,
                                               jitter_seed=jitter_seed,
                                               shuffle_seed=shuffle_seed,
                                               peaks_trans_thresh_type="sig")

    elif sampling_type == "SamplingCoordsBatcherUnion":
        # Randomly samples from genome
        genome_sampler = GenomeIntervalSampler(chrom_sizes_tsv,
                                               input_length,
                                               chroms_keep=chrom_set,
                                               seed=negative_seed)
        # Yields batches of positive and negative coordinates
        coords_batcher = SamplingCoordsBatcher(peak_bed_trans_paths,
                                               peaks_bed_paths,
                                               batch_size,
                                               negative_ratio,
                                               jitter_size,
                                               chrom_sizes_tsv,
                                               input_length,
                                               genome_sampler,
                                               chroms_keep=chrom_set,
                                               peak_retention=peak_retention,
                                               return_peaks=return_coords,
                                               shuffle_before_epoch=shuffle,
                                               jitter_seed=jitter_seed,
                                               shuffle_seed=shuffle_seed,
                                               peaks_trans_thresh_type="union")

    elif sampling_type == "SummitCenteringCoordsBatcherToSig":
        # Yields batches of positive coordinates, centered at summits
        coords_batcher = SummitCenteringCoordsBatcher(
            peaks_bed_paths,
            peak_bed_trans_paths,
            batch_size,
            chrom_sizes_tsv,
            input_length,
            chroms_keep=chrom_set,
            return_peaks=return_coords,
            shuffle_before_epoch=shuffle,
            shuffle_seed=shuffle_seed)
    # elif sampling_type == "SummitCenteringCoordsBatcherToSig":
    #     # Yields batches of positive coordinates, centered at summits
    #     coords_batcher = SummitCenteringCoordsBatcher(
    #         peaks_bed_paths, peak_bed_trans_paths, batch_size, chrom_sizes_tsv, input_length,
    #         chroms_keep=chrom_set, return_peaks=return_coords,
    #         shuffle_before_epoch=shuffle, shuffle_seed=shuffle_seed,
    #         sig_thresh=sig_thresh, peaks_thresh_type="sig", peaks_trans_thresh_type=None
    #     )
    elif sampling_type == "SummitCenteringCoordsBatcherFromSig":
        # Yields batches of positive coordinates, centered at summits
        coords_batcher = SummitCenteringCoordsBatcher(
            peak_bed_trans_paths,
            peaks_bed_paths,
            batch_size,
            chrom_sizes_tsv,
            input_length,
            chroms_keep=chrom_set,
            return_peaks=return_coords,
            shuffle_before_epoch=shuffle,
            shuffle_seed=shuffle_seed)
    elif sampling_type == "SummitCenteringCoordsBatcherToSigFromSig":
        # Yields batches of positive coordinates, centered at summits
        coords_batcher = SummitCenteringCoordsBatcher(
            peaks_bed_paths,
            peak_bed_trans_paths,
            batch_size,
            chrom_sizes_tsv,
            input_length,
            chroms_keep=chrom_set,
            return_peaks=return_coords,
            shuffle_before_epoch=shuffle,
            shuffle_seed=shuffle_seed,
            peaks_trans_thresh_type="sig")
    elif sampling_type == "SummitCenteringCoordsBatcherToInsigFromSig":
        # Yields batches of positive coordinates, centered at summits
        coords_batcher = SummitCenteringCoordsBatcher(
            peak_bed_trans_paths,
            peaks_bed_paths,
            batch_size,
            chrom_sizes_tsv,
            input_length,
            chroms_keep=chrom_set,
            return_peaks=return_coords,
            shuffle_before_epoch=shuffle,
            shuffle_seed=shuffle_seed,
            peaks_trans_thresh_type="insig")
    elif sampling_type == "SummitCenteringCoordsBatcherToSigFromInsig":
        # Yields batches of positive coordinates, centered at summits
        coords_batcher = SummitCenteringCoordsBatcher(
            peaks_bed_paths,
            peak_bed_trans_paths,
            batch_size,
            chrom_sizes_tsv,
            input_length,
            chroms_keep=chrom_set,
            return_peaks=return_coords,
            shuffle_before_epoch=shuffle,
            shuffle_seed=shuffle_seed,
            peaks_trans_thresh_type="insig")
    elif sampling_type == "SummitCenteringCoordsBatcherUnion":
        # Yields batches of positive coordinates, centered at summits
        coords_batcher = SummitCenteringCoordsBatcher(
            peaks_bed_paths,
            peak_bed_trans_paths,
            batch_size,
            chrom_sizes_tsv,
            input_length,
            chroms_keep=chrom_set,
            return_peaks=return_coords,
            shuffle_before_epoch=shuffle,
            shuffle_seed=shuffle_seed,
            peaks_trans_thresh_type="union")
    # elif sampling_type == "SummitCenteringCoordsBatcherToInsigFromInsig":
    #     # Yields batches of positive coordinates, centered at summits
    #     coords_batcher = SummitCenteringCoordsBatcher(
    #         peaks_bed_paths, peak_bed_trans_paths, batch_size, chrom_sizes_tsv, input_length,
    #         chroms_keep=chrom_set, return_peaks=return_coords,
    #         shuffle_before_epoch=shuffle, shuffle_seed=shuffle_seed,
    #         sig_thresh=sig_thresh, peaks_thresh_type="insig", peaks_trans_thresh_type="insig"
    #     )
    else:
        # Yields batches of positive coordinates, tiled across peaks
        coords_batcher = PeakTilingCoordsBatcher(peaks_bed_paths,
                                                 peak_bed_trans_paths,
                                                 peak_tiling_stride,
                                                 batch_size,
                                                 chrom_sizes_tsv,
                                                 input_length,
                                                 chroms_keep=chrom_set,
                                                 return_peaks=return_coords,
                                                 shuffle_before_epoch=shuffle,
                                                 shuffle_seed=shuffle_seed)

    # Maps set of coordinates to 1-hot encoding, padded
    coords_to_seq = util.CoordsToSeq(reference_fasta,
                                     center_size_to_use=input_length)

    # Dataset
    dataset = CoordDatasetTrans(coords_batcher,
                                coords_to_seq,
                                coords_to_vals,
                                coords_to_vals_trans,
                                revcomp=revcomp,
                                return_coords=return_coords)

    # Dataset loader: dataset is iterable and already returns batches
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=None,
                                         num_workers=num_workers,
                                         collate_fn=lambda x: x)

    return loader
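

# A minimal sketch of constructing the profile data loader (not part of the
# original module), assuming this profile-model `create_data_loader` is
# imported from its own module; every path and hyperparameter below is an
# illustrative assumption, not a value prescribed by this code.
def _example_create_profile_data_loader():
    loader = create_data_loader(
        peaks_bed_paths=["/path/to/peaks_task0.bed.gz"],
        peak_bed_trans_paths=["/path/to/peaks_trans_task0.bed.gz"],
        profile_hdf5_path="/path/to/profiles.h5",
        profile_trans_hdf5_path="/path/to/profiles_trans.h5",
        sampling_type="SamplingCoordsBatcher",
        batch_size=64,
        reference_fasta="/path/to/hg38.fasta",
        chrom_sizes_tsv="/path/to/hg38.chrom.sizes",
        input_length=2114,
        profile_length=1000,
        negative_ratio=1,
        peak_tiling_stride=25,  # only used by PeakTilingCoordsBatcher
        peak_retention=None,
        num_workers=4,
        revcomp=True,
        jitter_size=128,
        negative_seed=None,
        shuffle_seed=None,
        jitter_seed=None,
        sig_thresh=None,
        chrom_set=["chr1"],
        shuffle=True,
        return_coords=False
    )
    # The underlying IterableDataset already yields full batches, so iterate
    # the loader directly, e.g. `for batch in loader: ...`
    return loader

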
def create_data_loader(labels_hdf5_path,
                       bin_labels_npy_or_array,
                       batch_size,
                       reference_fasta,
                       simulate_seqs,
                       motif_path,
                       motif_bound,
                       gc_prob,
                       input_length,
                       negative_ratio,
                       peak_retention,
                       num_workers,
                       revcomp,
                       negative_seed,
                       shuffle_seed,
                       peak_signals_npy_or_array=None,
                       chrom_set=None,
                       shuffle=True,
                       return_coords=False):
    """
    Creates an IterableDataset object, which iterates through batches of
    bins and returns values for the bins.
    Arguments:
        `labels_hdf5_path`: path to HDF5 containing labels; this HDF5 must be a
            single dataset created by `generate_ENCODE_TFChIP_binary_labels.sh`;
            each row must be: (index, values, end, start, chrom), where `values`
            is a T-array with one entry per task, each being 0, 1, or NaN
        `bin_labels_npy_or_array`: either the path to a pickled N x 2 object
            array, or the array already imported; this array must be generated
            by `create_ENCODE_binary_bins.py`
        `peak_signals_npy_or_array`: either the path to an N x T array, or the
            array already imported; this array must be generated by
            `create_ENCODE_binary_bins.py`; this is only required if
            `peak_retention` is used
        `chrom_set`: a list of chromosomes to restrict to for the positives and
            negatives; defaults to all coordinates in the HDF5
        `shuffle`: if specified, shuffle the coordinates before each epoch
        `return_coords`: if specified, also return the underlying coordinates
            along with the values in each batch
    """
    # Maps set of bin indices to coordinates and values
    bins_to_vals = BinsToVals(labels_hdf5_path)

    # Yields batches of positive and negative bin indices
    if isinstance(bin_labels_npy_or_array, str):
        bin_labels_array = np.load(bin_labels_npy_or_array, allow_pickle=True)
    else:
        bin_labels_array = bin_labels_npy_or_array
    if isinstance(peak_signals_npy_or_array, str):
        peak_signals_array = np.load(peak_signals_npy_or_array)
    else:
        peak_signals_array = peak_signals_npy_or_array  # Could be None
    bins_batcher = SamplingBinsBatcher(bin_labels_array,
                                       batch_size,
                                       negative_ratio,
                                       chroms_keep=chrom_set,
                                       peak_retention=peak_retention,
                                       peak_signals_array=peak_signals_array,
                                       shuffle_before_epoch=shuffle,
                                       shuffle_seed=shuffle_seed)

    print("Total class counts:")
    num_pos, num_neg = len(bins_batcher.pos_inds), len(bins_batcher.neg_inds)
    print("\tPos: %d, Neg: %d" % (num_pos, num_neg))
    if num_pos:
        print("\tNeg/Pos = %f" % (num_neg / num_pos))

    if simulate_seqs:
        # Maps set of 1s/0s to 1-hot encoding
        seq_mapper = util.StatusToSimulatedSeq(input_length, motif_path,
                                               motif_bound, gc_prob)
    else:
        # Maps set of coordinates to 1-hot encoding, padded
        seq_mapper = util.CoordsToSeq(reference_fasta,
                                      center_size_to_use=input_length)

    # Dataset
    dataset = BinDataset(bins_batcher,
                         seq_mapper,
                         bins_to_vals,
                         revcomp=revcomp,
                         return_coords=return_coords)

    # Dataset loader: dataset is iterable and already returns batches
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=None,
                                         num_workers=num_workers,
                                         collate_fn=lambda x: x)

    return loader
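

# A minimal sketch of constructing the binary data loader (not part of the
# original module): the paths and hyperparameters below are illustrative
# assumptions, not values prescribed by this code.
def _example_create_binary_data_loader():
    loader = create_data_loader(
        labels_hdf5_path="/path/to/binary_labels.h5",
        bin_labels_npy_or_array="/path/to/bin_labels.npy",
        batch_size=128,
        reference_fasta="/path/to/hg38.fasta",
        simulate_seqs=False,
        motif_path=None,  # unused when `simulate_seqs` is False
        motif_bound=None,  # unused when `simulate_seqs` is False
        gc_prob=None,  # unused when `simulate_seqs` is False
        input_length=1000,
        negative_ratio=10,
        peak_retention=None,
        num_workers=4,
        revcomp=True,
        negative_seed=None,
        shuffle_seed=None,
        chrom_set=["chr1"],
        shuffle=True,
        return_coords=False
    )
    # The underlying IterableDataset already yields full batches, so iterate
    # the loader directly, e.g. `for batch in loader: ...`
    return loader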