Example #1
    def vdom_plot(self, kind='seq', width=80, letter_width=0.2, letter_height=0.8, as_html=False):
        """Get the html
        """
        from basepair.plot.vdom import vdom_pssm

        if kind == 'contrib':
            # average the contribution scores across tasks
            arr = mean(list(self.contrib.values()))
        elif kind == 'hyp_contrib':
            # average the hypothetical contribution scores across tasks
            arr = mean(list(self.hyp_contrib.values()))
        elif kind == 'seq':
            # get the IC
            arr = self.get_seq_ic()
        else:
            self._validate_kind(kind)
            arr = self._get_track(kind)

        vdom_obj = vdom_pssm(arr,
                             letter_width=letter_width,
                             letter_height=letter_height)
        if as_html:
            return vdom_obj.to_html().replace("<img", f"<img width={width}")  # hack
        else:
            return vdom_obj
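Every example in this list calls a `mean` helper that is not itself shown. It is applied both to lists of scalars and to lists of equally-shaped numpy arrays (e.g. to average contribution scores across tasks or strands), so it is assumed to act elementwise on arrays; a minimal sketch under that assumption:

def mean(values):
    # Elementwise average of a list of numbers or of equally-shaped
    # numpy arrays (sum() adds numpy arrays elementwise).
    return sum(values) / len(values)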
Example #2
def average_counts(pe):
    """Average the count metrics across tasks: pe[task][metric] -> {metric: mean}"""
    tasks = list(pe)
    metrics = list(pe[tasks[0]])
    return {
        metric: mean([pe[task][metric] for task in tasks])
        for metric in metrics
    }
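For illustration, `average_counts` expects a nested dict keyed as `pe[task][metric]`; a hypothetical input (task and metric names are made up):

pe = {
    "Oct4": {"mse": 0.10, "pearson": 0.85},
    "Sox2": {"mse": 0.20, "pearson": 0.75},
}
average_counts(pe)
# -> {'mse': 0.15, 'pearson': 0.8} (up to float rounding), averaged across tasks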
Example #3
def modisco_score(modisco_dir,
                  imp_scores,
                  output_tsv,
                  output_seqlets_pkl=None,
                  seqlet_len=25,
                  n_cores=1,
                  method="rank",
                  trim_pattern=False):
    """Find seqlet instances using modisco
    """
    add_file_logging(os.path.dirname(output_tsv), logger, 'modisco-score')
    mr, tasks, grad_type = load_modisco_results(modisco_dir)

    # load importance scores we want to score
    d = HDF5Reader.load(imp_scores)
    if 'hyp_imp' not in d:
        # backwards compatibility: older files store the scores under 'grads'
        d['hyp_imp'] = d['grads']

    if isinstance(d['inputs'], dict):
        one_hot = d['inputs']['seq']
    else:
        one_hot = d['inputs']
    hypothetical_contribs = {
        f"{task}/{gt}": mean(d['hyp_imp'][task][gt])
        for task in tasks for gt in grad_type.split(",")
    }
    contrib_scores = {
        f"{task}/{gt}": hypothetical_contribs[f"{task}/{gt}"] * one_hot
        for task in tasks for gt in grad_type.split(",")
    }

    seqlets = find_instances(mr,
                             tasks,
                             contrib_scores,
                             hypothetical_contribs,
                             one_hot,
                             seqlet_len=seqlet_len,
                             n_cores=n_cores,
                             method=method,
                             trim_pattern=trim_pattern)
    if len(seqlets) == 0:
        print("ERROR: no seqlets found!!")
        return [], None

    if output_seqlets_pkl:
        write_pkl(seqlets, output_seqlets_pkl)
    df = labelled_seqlets2df(seqlets)

    dfm = pd.DataFrame(d['metadata']['range'])
    dfm.columns = ["example_" + v for v in dfm.columns]

    df = df.merge(dfm,
                  left_on="example_idx",
                  how='left',
                  right_on="example_id")

    df.to_csv(output_tsv, sep='\t')

    return seqlets, df
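A hypothetical invocation (all paths and the core count are placeholders):

seqlets, df = modisco_score(
    modisco_dir="output/modisco",           # directory produced by modisco_run
    imp_scores="output/imp_scores.h5",      # HDF5 file with the importance scores
    output_tsv="output/modisco/instances.tsv",
    seqlet_len=25,
    n_cores=4,
)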
Example #4
def logo_imp(pattern, data, width=80):
    """Render the task-averaged importance scores of a pattern as an HTML logo."""
    arr = mean([data.get_imp(pattern, task, 'profile').mean(axis=0)
                for task in data.get_tasks()])

    # trim array to match the pwm
    i, j = data.get_trim_idx(pattern)
    arr = arr[i:j]
    return vdom_pssm(arr).to_html().replace("<img", f"<img width={width}")  # hack
Example #5
def average_profile(pe):
    """Average the per-binsize auprc metrics across tasks."""
    tasks = list(pe)
    binsizes = list(pe[tasks[0]])
    return {
        binsize: {
            "auprc": mean([pe[task][binsize]['auprc'] for task in tasks])
        }
        for binsize in binsizes
    }
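As with `average_counts`, the expected input is a nested dict, here keyed as `pe[task][binsize]['auprc']`; a hypothetical example:

pe = {
    "Oct4": {1: {"auprc": 0.6}, 10: {"auprc": 0.8}},
    "Sox2": {1: {"auprc": 0.4}, 10: {"auprc": 0.7}},
}
average_profile(pe)
# -> {1: {'auprc': 0.5}, 10: {'auprc': 0.75}}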
Example #6
def get_task(d, k):
    """Select task `k` from dict `d`; 'mean' or 'sum' aggregate across tasks."""
    if k is None:
        return d
    elif k == 'mean':
        return mean(list(d.values()))
    elif k == 'sum':
        return sum(d.values())
    # TODO - add weighting based on the importance scores
    else:
        return d[k]
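Hypothetical usage, showing the three aggregation modes on a per-task dict (the same works when the values are numpy arrays):

d = {"Oct4": 1.0, "Sox2": 3.0}   # made-up task names and values
get_task(d, None)    # -> the whole dict
get_task(d, 'mean')  # -> 2.0
get_task(d, 'sum')   # -> 4.0
get_task(d, 'Oct4')  # -> 1.0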
Example #7
    def aligned_distance_seq(self, pattern, metric='simmetric_kl', pseudo_p=1e-3):
        """Average per-base distribution distance between two aligned patterns
        """
        # introduce pseudo-counts and re-normalize each position to a distribution
        sp1 = self.seq + pseudo_p
        sp1 = sp1 / sp1.sum(1, keepdims=True)
        sp2 = pattern.seq + pseudo_p
        sp2 = sp2 / sp2.sum(1, keepdims=True)
        m = get_metric(metric)
        return mean([m(sp1[i], sp2[i]) for i in range(len(sp1))])
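`get_metric('simmetric_kl')` is not shown in these snippets; for illustration, a plausible definition of a symmetrized KL divergence between two per-base probability vectors (the pseudo-counts above keep both vectors strictly positive, so the logs are safe). The metric actually returned by `get_metric` may differ:

import numpy as np

def simmetric_kl(p, q):
    # Symmetrized KL divergence between two probability vectors (sketch only).
    return 0.5 * (np.sum(p * np.log(p / q)) + np.sum(q * np.log(q / p)))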
Example #8
    def evaluate(self,
                 dataset,
                 eval_metric=None,
                 num_workers=8,
                 batch_size=256):
        """Evaluate the model on `dataset`. With `eval_metric=None`, compute the
        per-head metrics along with their across-task averages."""
        lpreds = []
        llabels = []
        for inputs, targets in tqdm(dataset.batch_train_iter(
                cycle=False, num_workers=num_workers, batch_size=batch_size),
                                    total=len(dataset) // batch_size):
            assert isinstance(targets, dict)
            target_keys = list(targets)
            llabels.append(deepcopy(targets))
            bpreds = {
                k: v
                for k, v in self.predict(inputs, batch_size=None).items()
                if k in target_keys
            }  # keep only the target key predictions
            lpreds.append(bpreds)
            del inputs
            del targets
        preds = numpy_collate_concat(lpreds)
        labels = numpy_collate_concat(llabels)
        del lpreds
        del llabels

        if eval_metric is not None:
            return eval_metric(labels, preds)
        else:
            task_avg_tape = defaultdict(list)
            out = {}
            for task, heads in self.all_heads.items():
                for head_i, head in enumerate(heads):
                    target_name = head.get_target(task)
                    if target_name not in labels:
                        print(
                            f"Target {target_name} not found. Skipping evaluation"
                        )
                        continue
                    res = head.metric(labels[target_name], preds[target_name])
                    out[target_name] = res
                    metrics_dict = flatten(res, separator='/')
                    for k, v in metrics_dict.items():
                        task_avg_tape[
                            head.target_name.replace("{task}", "avg") + "/" +
                            k].append(v)
            for k, v in task_avg_tape.items():
                # get the average
                out[k] = mean(v)

        # flatten everything
        out = flatten(out, separator='/')
        return out
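`numpy_collate_concat` is assumed here to merge a list of per-batch dicts of arrays into one dict concatenated along the batch axis; a rough stand-in (the real helper also handles nested structures):

import numpy as np

def numpy_collate_concat_sketch(batches):
    # [{'task1': (b, ...), ...}, ...] -> {'task1': (N, ...), ...}
    return {k: np.concatenate([b[k] for b in batches], axis=0)
            for k in batches[0]}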
Example #9
    def aligned_distance_profile(self, pattern, metric='simmetric_kl', pseudo_p=1e-8):
        """Compare two profile distributions (average across strands)
        """
        m = get_metric(metric)

        # introduce pseudo-counts
        o = dict()
        for t in self.tasks():
            pp1 = self.profile[t] + pseudo_p
            pp1 = pp1 / pp1.sum(0, keepdims=True)
            pp2 = pattern.profile[t] + pseudo_p
            pp2 = pp2 / pp2.sum(0, keepdims=True)
            o[t] = mean([m(pp1[:, i], pp2[:, i]) for i in range(pp1.shape[1])])
        return o
Example #10
def load_modisco_results(modisco_dir):
    """Load modisco_result - return

    Args:
      modisco_dir: directory path `output_dir` in `basepair.cli.modisco.modisco_run`
        contains: modisco.h5, strand_distances.h5, kwargs.json

    Returns:
      TfModiscoResults object containing original track_set
    """
    import modisco
    from modisco.tfmodisco_workflow import workflow
    modisco_kwargs = read_json(f"{modisco_dir}/kwargs.json")
    grad_type = modisco_kwargs['grad_type']

    # load used strand distance filter
    included_samples = HDF5Reader.load(
        f"{modisco_dir}/strand_distances.h5")['included_samples']

    # load importance scores
    d = HDF5Reader.load(modisco_kwargs['imp_scores'])
    if 'hyp_imp' not in d:
        # backwards compatibility: older files store the scores under 'grads'
        d['hyp_imp'] = d['grads']

    tasks = list(d['targets']['profile'])
    if isinstance(d['inputs'], dict):
        one_hot = d['inputs']['seq']
    else:
        one_hot = d['inputs']
    thr_hypothetical_contribs = {
        f"{task}/{gt}": mean(d['hyp_imp'][task][gt])[included_samples]
        for task in tasks for gt in grad_type.split(",")
    }
    thr_one_hot = one_hot[included_samples]
    thr_contrib_scores = {
        f"{task}/{gt}": thr_hypothetical_contribs[f"{task}/{gt}"] * thr_one_hot
        for task in tasks for gt in grad_type.split(",")
    }

    track_set = modisco.tfmodisco_workflow.workflow.prep_track_set(
        task_names=tasks,
        contrib_scores=thr_contrib_scores,
        hypothetical_contribs=thr_hypothetical_contribs,
        one_hot=thr_one_hot)

    with h5py.File(os.path.join(modisco_dir, "modisco.h5"), "r") as grp:
        mr = workflow.TfModiscoResults.from_hdf5(grp, track_set=track_set)
    return mr, tasks, grad_type
Example #11
    def imp_score_all(self,
                      seq,
                      method='grad',
                      aggregate_strand=False,
                      batch_size=512,
                      pred_summaries=['weighted', 'count']):
        """Compute all importance scores

        Args:
          seq: one-hot encoded DNA sequences
          method: 'grad', 'deeplift' or 'ism'
          aggregate_strand: if True, the average importance score across strands is returned
          batch_size: batch size when computing the importance scores
          pred_summaries: which prediction summaries to compute (e.g. 'weighted', 'count')

        Returns:
          dictionary with keys: {task}/{pred_summary}/{strand_i} or {task}/{pred_summary}
          and values with the same shape as `seq` corresponding to importance scores
        """
        d_n_channels = {task: 2 for task in self.tasks}
        # TODO - update
        # preds_dict['counts'][task_id].shape[-1]

        # TODO - implement the ism version
        # if method == 'ism':
        #     return self.ism()

        out = {
            f"{task}/{pred_summary}/{strand_i}":
            self.imp_score(seq,
                           task=task,
                           strand=strand,
                           method=method,
                           pred_summary=pred_summary,
                           batch_size=batch_size)
            for task in self.tasks for strand_i, strand in enumerate(
                ['pos', 'neg'][:d_n_channels[task]])
            for pred_summary in pred_summaries
        }
        if aggregate_strand:
            return {
                f"{task}/{pred_summary}": mean([
                    out[f"{task}/{pred_summary}/{strand_i}"] for strand_i,
                    strand in enumerate(['pos', 'neg'][:d_n_channels[task]])
                ])
                for pred_summary in pred_summaries
                for task in self.tasks
            }
        else:
            return out
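With hypothetical tasks `['Oct4', 'Sox2']` and the default `pred_summaries`, the returned dict is keyed like this:

# aggregate_strand=False:
#   'Oct4/weighted/0', 'Oct4/weighted/1', 'Oct4/count/0', 'Oct4/count/1',
#   'Sox2/weighted/0', ...
# aggregate_strand=True (strand axis averaged away):
#   'Oct4/weighted', 'Oct4/count', 'Sox2/weighted', 'Sox2/count'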
Example #12
    def get_hyp_contrib(self, imp_score=None, idx=None, pred_summary=None):
        if pred_summary is not None:
            warnings.warn("pred_summary is deprecated. Use `imp_score`")
            imp_score = pred_summary

        imp_score = (imp_score if imp_score is not None
                     else self.default_imp_score)
        if imp_score in self._hyp_contrib_cache and idx is None:
            return self._hyp_contrib_cache[imp_score]
        else:
            # NOTE: this line averages any additional axes after {imp_score} like
            # strands denoted with:
            # /hyp_imp/{task}/{imp_score}/{strand}, where strand = 0 or 1
            out = {task: mean([self._subset(self.data[k], idx)
                               for k in self._data_subkeys(f'/hyp_imp/{task}/{imp_score}')])
                   for task in self.get_tasks()
                   }
            if idx is None:
                self._hyp_contrib_cache[imp_score] = out
            return out
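`_data_subkeys` is not shown; given the NOTE above, it presumably lists every flat key under a prefix so that `mean` runs over the per-strand arrays. A sketch of that assumption (key names are hypothetical):

def data_subkeys(data, prefix):
    # stand-in for self._data_subkeys: all flat keys starting with `prefix`,
    # e.g. '/hyp_imp/Oct4/profile/0' and '/hyp_imp/Oct4/profile/1'
    return [k for k in data if k.startswith(prefix)]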
Example #13
def modisco_run(
    imp_scores,
    output_dir,
    null_imp_scores=None,
    hparams=None,
    override_hparams="",
    grad_type="weighted",
    subset_tasks=None,
    filter_subset_tasks=False,
    filter_npy=None,
    exclude_chr="",
    seqmodel=False,  # interpretation glob
    num_workers=10,
    max_strand_distance=0.1,
    overwrite=False,
    skip_dist_filter=False,
    use_all_seqlets=False,
    merge_tasks=False,
    gpu=None,
):
    """
    Run modisco

    Args:
      imp_scores: path to the hdf5 file of importance scores
      null_imp_scores: Path to the null importance scores
      grad_type: for which output to compute the importance scores
      hparams: modisco hyper-parameters: either None, a path to a modisco.yaml
        file, or a ModiscoHParams object
      override_hparams: hyper-parameters overriding the settings in the hparams file
      output_dir: output file directory
      filter_npy: path to a npy file containing a boolean vector used for subsetting
      exclude_chr: comma-separated list of chromosomes to exclude
      seqmodel: if enabled, the importance scores came from `imp-score-seqmodel`
      subset_tasks: comma-separated list of task names to use as a subset
      filter_subset_tasks: if True, run modisco only in the regions for that TF
      num_workers: number of workers (cores) used by the modisco clustering
      max_strand_distance: per-example strand-distance threshold for the strand filter
      overwrite: if True, overwrite existing output files
      skip_dist_filter: if True, distances are not used to filter
      use_all_seqlets: if True, don't restrict the number of seqlets
      merge_tasks: if True, importance scores for the tasks will be merged
      gpu: which gpu to use. If None, don't use any GPUs

    Note: when using subset_tasks, modisco will run on all the importance scores. If you wish
      to run it only for the importance scores for a particular task you should subset it to
      the peak regions of interest using `filter_npy`
    """
    plt.switch_backend('agg')
    add_file_logging(output_dir, logger, 'modisco-run')
    import os
    if gpu is not None:
        create_tf_session(gpu)
    else:
        # Don't use any GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
    os.environ['MKL_THREADING_LAYER'] = 'GNU'
    # import theano
    import modisco
    import modisco.tfmodisco_workflow.workflow

    if seqmodel:
        assert '/' in grad_type

    if subset_tasks == '':
        logger.warning("subset_tasks == ''. Not using subset_tasks")
        subset_tasks = None

    if subset_tasks == 'all':
        # use all the tasks, i.e. don't subset
        subset_tasks = None

    if subset_tasks is not None:
        subset_tasks = subset_tasks.split(",")
        if len(subset_tasks) == 0:
            raise ValueError("Provide one or more subset_tasks. Found None")

    if filter_subset_tasks and subset_tasks is None:
        print("Using filter_subset_tasks=False since `subset_tasks` is None")
        filter_subset_tasks = False

    if exclude_chr:
        exclude_chr = exclude_chr.split(",")
    else:
        exclude_chr = []

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    output_path = output_dir / "modisco.h5"
    remove_exists(output_path, overwrite)

    output_distances = output_dir / "strand_distances.h5"
    remove_exists(output_distances, overwrite)

    if filter_npy is not None:
        filter_npy = os.path.abspath(filter_npy)

    # save the hyper-parameters
    write_json(
        dict(
            imp_scores=os.path.abspath(imp_scores),
            grad_type=grad_type,
            output_dir=str(output_dir),
            subset_tasks=subset_tasks,
            filter_subset_tasks=filter_subset_tasks,
            hparams=hparams,
            null_imp_scores=null_imp_scores,
            # TODO - pack into hyper-parameters as well?
            filter_npy=filter_npy,
            exclude_chr=",".join(exclude_chr),
            skip_dist_filter=skip_dist_filter,
            use_all_seqlets=use_all_seqlets,
            max_strand_distance=max_strand_distance,
            gpu=gpu),
        os.path.join(output_dir, "kwargs.json"))

    print("-" * 40)
    # parse the hyper-parameters
    if hparams is None:
        print(f"Using default hyper-parameters")
        hp = ModiscoHParams()
    else:
        if isinstance(hparams, str):
            print(f"Loading hyper-parameters from file: {hparams}")
            hp = ModiscoHParams.load(hparams)
        else:
            assert isinstance(hparams, ModiscoHParams)
            hp = hparams
    if override_hparams:
        print(f"Overriding the following hyper-parameters: {override_hparams}")
    hp = tf.contrib.training.HParams(
        **hp.get_modisco_kwargs()).parse(override_hparams)

    if use_all_seqlets:
        hp.max_seqlets_per_metacluster = None

    # save the hyper-parameters
    print("Using the following hyper-parameters for modisco:")
    print("-" * 40)
    related_dump_yaml(ModiscoHParams(**hp.values()),
                      os.path.join(output_dir, "hparams.yaml"),
                      verbose=True)
    print("-" * 40)

    # TODO - replace with imp_scores
    d = HDF5Reader.load(imp_scores)
    if 'hyp_imp' not in d:
        # backwards compatibility: older files store the scores under 'grads'
        d['hyp_imp'] = d['grads']

    if seqmodel:
        tasks = list(d['targets'])
    else:
        tasks = list(d['targets']['profile'])

    if subset_tasks is not None:
        # validate that all the `subset_tasks`
        # are present in `tasks`
        for st in subset_tasks:
            if st not in tasks:
                raise ValueError(
                    f"subset task {st} not found in tasks: {tasks}")
        logger.info(
            f"Using the following tasks: {subset_tasks} instead of the original tasks: {tasks}"
        )
        tasks = subset_tasks

    if isinstance(d['inputs'], dict):
        one_hot = d['inputs']['seq']
    else:
        one_hot = d['inputs']

    n = len(one_hot)

    # --------------------
    # apply filters
    if not skip_dist_filter:
        print("Using profile prediction for the strand filtering")
        grad_type_filtered = 'weighted'
        distances = np.array([
            np.array([
                correlation(
                    np.ravel(d['hyp_imp'][task][grad_type_filtered][0][i]),
                    np.ravel(d['hyp_imp'][task][grad_type_filtered][1][i]))
                for i in range(n)
            ]) for task in tasks
            if len(d['hyp_imp'][task][grad_type_filtered]) == 2
        ]).T.mean(axis=-1)  # average the distances across tasks

        dist_filter = distances < max_strand_distance
        print(f"Fraction of sequences kept: {dist_filter.mean()}")

        HDF5BatchWriter.dump(output_distances, {
            "distances": distances,
            "included_samples": dist_filter
        })
    else:
        dist_filter = np.ones((n, ), dtype=bool)

    # add also the filter numpy
    if filter_npy is not None:
        print(f"Loading a filter file from {filter_npy}")
        filter_vec = np.load(filter_npy)
        dist_filter = dist_filter & filter_vec

    if filter_subset_tasks:
        assert subset_tasks is not None
        interval_from_task = pd.Series(d['metadata']['interval_from_task'])
        print(
            f"Subsetting the intervals according to subset_tasks: {subset_tasks}"
        )
        print(f"Number of original regions: {dist_filter.sum()}")
        dist_filter = dist_filter & interval_from_task.isin(
            subset_tasks).values
        print(
            f"Number of filtered regions after filter_subset_tasks: {dist_filter.sum()}"
        )

    # filter by chromosome
    if exclude_chr:
        logger.info(f"Excluding chromosomes: {exclude_chr}")
        chromosomes = d['metadata']['range']['chr']
        dist_filter = dist_filter & (
            ~pd.Series(chromosomes).isin(exclude_chr)).values
    # -------------------------------------------------------------
    # setup importance scores

    if seqmodel:
        thr_one_hot = one_hot[dist_filter]
        thr_hypothetical_contribs = {
            f"{task}/{gt}":
            d['hyp_imp'][task][gt.split("/")[0]][gt.split("/")[1]][dist_filter]
            for task in tasks for gt in grad_type.split(",")
        }
        thr_contrib_scores = {
            f"{task}/{gt}":
            thr_hypothetical_contribs[f"{task}/{gt}"] * thr_one_hot
            for task in tasks for gt in grad_type.split(",")
        }
        task_names = [
            f"{task}/{gt}" for task in tasks for gt in grad_type.split(",")
        ]

    else:
        if merge_tasks:
            thr_one_hot = np.concatenate([
                one_hot[dist_filter] for task in tasks
                for gt in grad_type.split(",")
            ])
            thr_hypothetical_contribs = {
                "merged":
                np.concatenate([
                    mean(d['hyp_imp'][task][gt])[dist_filter] for task in tasks
                    for gt in grad_type.split(",")
                ])
            }

            thr_contrib_scores = {
                "merged": thr_hypothetical_contribs['merged'] * thr_one_hot
            }
            task_names = ['merged']
        else:
            thr_one_hot = one_hot[dist_filter]
            thr_hypothetical_contribs = {
                f"{task}/{gt}": mean(d['hyp_imp'][task][gt])[dist_filter]
                for task in tasks for gt in grad_type.split(",")
            }
            thr_contrib_scores = {
                f"{task}/{gt}":
                thr_hypothetical_contribs[f"{task}/{gt}"] * thr_one_hot
                for task in tasks for gt in grad_type.split(",")
            }
            task_names = [
                f"{task}/{gt}" for task in tasks for gt in grad_type.split(",")
            ]

    if null_imp_scores is not None:
        logger.info(f"Using null_imp_scores: {null_imp_scores}")
        null_isf = ImpScoreFile(null_imp_scores)
        null_per_pos_scores = {
            f"{task}/{gt}": v.sum(axis=-1)
            for gt in grad_type.split(",")
            for task, v in null_isf.get_contrib(imp_score=gt).items()
            if task in tasks
        }
    else:
        # default null distribution (requires modisco >= 0.5)
        logger.info("Using the default LaplaceNullDist null distribution")
        null_per_pos_scores = modisco.coordproducers.LaplaceNullDist(
            num_to_samp=10000)

    # -------------------------------------------------------------
    # run modisco
    tfmodisco_results = modisco.tfmodisco_workflow.workflow.TfModiscoWorkflow(
        # Modisco defaults
        sliding_window_size=hp.sliding_window_size,
        flank_size=hp.flank_size,
        target_seqlet_fdr=hp.target_seqlet_fdr,
        min_passing_windows_frac=hp.min_passing_windows_frac,
        max_passing_windows_frac=hp.max_passing_windows_frac,
        min_metacluster_size=hp.min_metacluster_size,
        max_seqlets_per_metacluster=hp.max_seqlets_per_metacluster,
        seqlets_to_patterns_factory=modisco.tfmodisco_workflow.
        seqlets_to_patterns.TfModiscoSeqletsToPatternsFactory(
            trim_to_window_size=hp.trim_to_window_size,  # default: 30
            initial_flank_to_add=hp.initial_flank_to_add,  # default: 10
            kmer_len=hp.kmer_len,  # default: 8
            num_gaps=hp.num_gaps,  # default: 3
            num_mismatches=hp.num_mismatches,  # default: 2
            n_cores=num_workers,
            final_min_cluster_size=hp.final_min_cluster_size)  # default: 30
    )(
        task_names=task_names,
        contrib_scores=thr_contrib_scores,  # -> task score
        hypothetical_contribs=thr_hypothetical_contribs,
        one_hot=thr_one_hot,
        null_per_pos_scores=null_per_pos_scores)
    # -------------------------------------------------------------
    # save the results
    with h5py.File(output_path, "w") as grp:
        tfmodisco_results.save_hdf5(grp)
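The strand filter in the middle of `modisco_run` keeps examples whose two strands' hypothetical contributions agree. Factored out for a single task, the per-example computation is roughly the following sketch (assuming `correlation` is scipy's correlation distance, as the snippet's usage suggests):

import numpy as np
from scipy.spatial.distance import correlation

def strand_distances(hyp_imp_task, n):
    # Correlation distance between the flattened pos- and neg-strand
    # hypothetical contribution scores; small distance = strands agree.
    pos, neg = hyp_imp_task[0], hyp_imp_task[1]
    return np.array([correlation(np.ravel(pos[i]), np.ravel(neg[i]))
                     for i in range(n)])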
Example #14
    def imp_score(self,
                  x,
                  task,
                  strand='both',
                  method='grad',
                  pred_summary='weighted',
                  batch_size=512):
        """Compute the importance score

        Args:
          x: one-hot encoded DNA sequence
          task: name of the task. See `self.tasks` for the available tasks
          strand: for which strand to run it ('pos', 'neg' or 'both').
            For 'both', the average across both strands is returned
          method: which importance score to use. Available: grad, ism, deeplift
        """
        assert task in self.tasks
        # figure out the task id
        task_id = self.tasks.index(task)

        if strand == 'both':
            # average the importance scores across both strands
            return mean([
                self.imp_score(x,
                               task,
                               strand=strand,
                               method=method,
                               pred_summary=pred_summary,
                               batch_size=batch_size)
                for strand in ['pos', 'neg']
            ])

        def input_to_list(input_names, x):
            if isinstance(x, list):
                return x
            elif isinstance(x, dict):
                return [x[k] for k in input_names]
            else:
                return [x]

        input_names = self.model.input_names
        assert input_names[0] == "seq"

        # get the importance scoring function
        # if method == "grad":
        #     fn = self._imp_grad_fn(strand, task_id, pred_summary) #returns fxn
        #     fn_applied = fn(input_to_list(input_names, x))[0]
        # elif method == "ism":
        #     fn = self._imp_ism_fn(strand, task_id, pred_summary) #returns fxn
        #     fn_applied = fn(input_to_list(input_names, x))[0]
        # elif method == "deeplift":
        #     fn = self._imp_deeplift_fn(x, strand, task_id, pred_summary) #returns numpy.ndarray
        #     fn_applied = fn
        # else:
        #     raise ValueError("Please provide a valid importance scoring method: grad, ism or deeplift")

        # if batch_size is None:
        #     return fn_applied
        # else:
        #     return numpy_collate_concat([fn_applied for batch in nested_numpy_minibatch(x, batch_size=batch_size)])

        if method == "grad":
            fn = self._imp_grad_fn(strand, task_id, pred_summary)
        elif method == "ism":
            fn = self._imp_ism_fn(strand, task_id, pred_summary)
        elif method == "deeplift":
            fn = self._imp_deeplift_fn(x, strand, task_id, pred_summary)
        else:
            raise ValueError(
                "Please provide a valid importance scoring method: grad, ism or deeplift"
            )

        if batch_size is None:
            return fn(input_to_list(input_names, x))[0]
        else:
            return numpy_collate_concat([
                fn(input_to_list(input_names, batch))[0]
                for batch in nested_numpy_minibatch(x, batch_size=batch_size)
            ])
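Hypothetical usage (`model` and the task name are placeholders; `seq` is a one-hot encoded batch of shape `(N, L, 4)`):

scores = model.imp_score(seq, task="Oct4", strand="both",
                         method="grad", batch_size=256)
# strand='both' recursively averages the per-strand scores, so `scores`
# has the same shape as `seq`.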