Ejemplo n.º 1
0
    def evaluate(self,
                 dataset,
                 eval_metric=None,
                 num_workers=8,
                 batch_size=256):
        """Predict on every batch of `dataset` and score against its targets.

        Args:
          dataset: must provide `batch_train_iter(cycle=..., num_workers=...,
            batch_size=...)` yielding `(inputs, targets)` pairs where
            `targets` is a dict, and must support `len()`.
          eval_metric: optional callable. When given, the method returns
            `eval_metric(labels, preds)` and the per-head metrics are skipped.
          num_workers: forwarded to `dataset.batch_train_iter`.
          batch_size: forwarded to `dataset.batch_train_iter`; also used to
            compute the progress-bar total.

        Returns:
          `eval_metric(labels, preds)` when `eval_metric` is given. Otherwise
          a dict flattened with '/' as separator that holds each head's
          metric result under its target name, plus the mean of every metric
          over the heads under the head's `target_name` with "{task}"
          replaced by "avg".
        """
        batch_preds = []
        batch_labels = []
        batches = dataset.batch_train_iter(cycle=False,
                                           num_workers=num_workers,
                                           batch_size=batch_size)
        for inputs, targets in tqdm(batches, total=len(dataset) // batch_size):
            assert isinstance(targets, dict)
            # store an independent copy of the targets
            batch_labels.append(deepcopy(targets))
            predicted = self.predict(inputs, batch_size=None)
            # keep only the target key predictions
            batch_preds.append({
                key: value
                for key, value in predicted.items() if key in targets
            })
            del inputs
            del targets
        preds = numpy_collate_concat(batch_preds)
        labels = numpy_collate_concat(batch_labels)
        del batch_preds
        del batch_labels

        if eval_metric is not None:
            return eval_metric(labels, preds)

        values_per_avg_key = defaultdict(list)
        results = {}
        for task, heads in self.all_heads.items():
            for head in heads:
                target_name = head.get_target(task)
                if target_name not in labels:
                    print(
                        f"Target {target_name} not found. Skipping evaluation"
                    )
                    continue
                head_result = head.metric(labels[target_name],
                                          preds[target_name])
                results[target_name] = head_result
                flat_result = flatten(head_result, separator='/')
                for metric_name, value in flat_result.items():
                    avg_key = (head.target_name.replace("{task}", "avg") +
                               "/" + metric_name)
                    values_per_avg_key[avg_key].append(value)

        # get the average of every collected metric
        for avg_key, values in values_per_avg_key.items():
            results[avg_key] = mean(values)

        # flatten everything
        return flatten(results, separator='/')
Ejemplo n.º 2
0
    def load_all(self, **kwargs):
        """Load every batch and return them concatenated into one dataset

        Arguments:
            **kwargs: forwarded unchanged to batch_iter()
        """
        batches = list(tqdm(self.batch_iter(**kwargs)))
        return numpy_collate_concat(batches)
Ejemplo n.º 3
0
    def batch_write(self, batch):
        """Add one batch to the write buffer and flush once a chunk is full.

        The batch is flattened to a single-level mapping with '/'-separated
        keys. On the first call one resizable dataset per key is created in
        `self.f` (NOTE(review): presumably an h5py file -- confirm). The batch
        is then merged into `self.write_buffer`, and `self._flush_buffer()` is
        called as soon as at least `self.chunk_size` rows are buffered.

        Args:
          batch: nested structure of arrays accepted by `flatten`; every
            array must have the same size along its first axis.

        Raises:
          AssertionError: if the arrays do not share exactly one first-axis
            size (this includes an empty batch).
        """
        fbatch = flatten(batch, separator="/")

        # all arrays must agree on the first (batch) dimension
        batch_sizes = [fbatch[k].shape[0] for k in fbatch]
        assert len(set(batch_sizes)) == 1, (
            "all arrays in a batch must share the same first dimension, "
            "got sizes: {}".format(batch_sizes))
        batch_size = batch_sizes[0]

        if self.first_pass:
            # create one empty, resizable dataset per flattened key
            for k in fbatch:
                arr = fbatch[k]
                # dtype.kind 'S' (bytes) / 'U' (unicode) selects the same
                # dtypes as the np.string_ / np.str_ / np.unicode_ aliases,
                # but also works on NumPy >= 2.0 where two of them are gone
                if arr.dtype.kind in ("S", "U"):
                    dtype = self.string_type
                else:
                    dtype = arr.dtype

                trailing_shape = arr.shape[1:]
                self.f.create_dataset(k,
                                      shape=(0, ) + trailing_shape,
                                      dtype=dtype,
                                      maxshape=(None, ) + trailing_shape,
                                      compression=self.compression,
                                      chunks=(self.chunk_size, ) +
                                      trailing_shape)
            self.first_pass = False

        # add data to the buffer
        if self.write_buffer is None:
            self.write_buffer = fbatch
            self.write_buffer_size = batch_size
        else:
            self.write_buffer = numpy_collate_concat(
                [self.write_buffer, fbatch])
            self.write_buffer_size += batch_size

        # the buffer is always set at this point, only its size matters
        if self.write_buffer_size >= self.chunk_size:
            self._flush_buffer()
Ejemplo n.º 4
0
def feature_importance(model,
                       dataloader,
                       importance_score,
                       importance_score_kwargs=None,
                       batch_size=32,
                       num_workers=0):
    """Return feature importance scores

    Args:
      model: model to score; must be accepted by the importance score's
        `is_compatible`
      dataloader: object whose `batch_iter(batch_size=..., num_workers=...)`
        yields batches indexable with 'inputs'
      importance_score: identifier resolved through `get_importance_score`
      importance_score_kwargs: optional dict of keyword arguments used to
        instantiate the importance score (default: no extra arguments)
      batch_size: forwarded to `dataloader.batch_iter`
      num_workers: forwarded to `dataloader.batch_iter`

    Returns:
      All batches concatenated with `numpy_collate_concat`, each one extended
      with an "importance_scores" entry.

    Raises:
      ValueError: if the model is not compatible with the importance score
    """
    # None instead of a shared mutable `{}` default
    if importance_score_kwargs is None:
        importance_score_kwargs = {}

    ImpScore = get_importance_score(importance_score)
    if not ImpScore.is_compatible(model):
        raise ValueError(
            "model not compatible with score: {0}".format(importance_score))
    impscore = ImpScore(model, **importance_score_kwargs)

    def append_key(d, k, v):
        # stores the value in-place and hands the same dict back
        d[k] = v
        return d

    # TODO - handle the reference-based importance scores...
    return numpy_collate_concat([
        append_key(batch, "importance_scores", impscore.score(batch['inputs']))
        for batch in tqdm(
            dataloader.batch_iter(batch_size=batch_size,
                                  num_workers=num_workers))
    ])
Ejemplo n.º 5
0
 def load_all(self, batch_size=32, **kwargs):
     """Load the whole dataset into memory

     Arguments:
         batch_size (int, optional): how many samples per batch to load
             (default: 32).
         **kwargs: forwarded to `batch_iter()`

     Returns:
         all batches concatenated with `numpy_collate_concat`
     """
     batches = list(tqdm(self.batch_iter(batch_size, **kwargs)))
     return numpy_collate_concat(batches)
Ejemplo n.º 6
0
    def __add__(self, stacked_seqlets):
        """Return a new `StackedSeqletImp` holding both operands' data.

        `seq`, `contrib`, `hyp_contrib` and `profile` of `self` and
        `stacked_seqlets` are concatenated (self first). `dfi` is concatenated
        only when both operands have one, otherwise it is None. `name` and
        `attrs` are taken from `self`.
        """
        from kipoi.data_utils import numpy_collate_concat

        other = stacked_seqlets

        seq = np.concatenate([self.seq, other.seq])
        contrib = numpy_collate_concat([self.contrib, other.contrib])
        hyp_contrib = numpy_collate_concat(
            [self.hyp_contrib, other.hyp_contrib])
        profile = numpy_collate_concat([self.profile, other.profile])

        # a missing table on either side leaves the result without one
        if self.dfi is None or other.dfi is None:
            dfi = None
        else:
            dfi = pd.concat([self.dfi, other.dfi])

        return StackedSeqletImp(name=self.name,
                                seq=seq,
                                contrib=contrib,
                                hyp_contrib=hyp_contrib,
                                profile=profile,
                                dfi=dfi,
                                attrs=self.attrs)
Ejemplo n.º 7
0
 def load_all_with_metadata(self,
                            batch_size=64,
                            num_workers=64,
                            shuffle=True,
                            drop_last=False):
     """Load all batches and return them concatenated.

     All four arguments are forwarded as keyword arguments to
     `self.batch_iter`; the collected batches are merged with
     `numpy_collate_concat` and returned as-is.
     """
     batches = list(
         tqdm(
             self.batch_iter(batch_size=batch_size,
                             num_workers=num_workers,
                             shuffle=shuffle,
                             drop_last=drop_last)))
     return numpy_collate_concat(batches)
Ejemplo n.º 8
0
 def load_all(self, batch_size=32, num_workers=0, **kwargs):
     """Load the whole dataset into memory

     Arguments:
         batch_size (int, optional): how many samples per batch to load
             (default: 32).
         num_workers (int, optional): how many subprocesses to use for data
             loading. 0 means that the data will be loaded in the main process
             (default: 0)
         **kwargs: forwarded to `batch_iter()`; previously these were
             accepted but silently discarded

     Returns:
         all batches concatenated with `numpy_collate_concat`
     """
     batches = list(
         tqdm(
             self.batch_iter(batch_size,
                             num_workers=num_workers,
                             **kwargs)))
     return numpy_collate_concat(batches)
Ejemplo n.º 9
0
 def load_all(self,
              batch_size=64,
              num_workers=64,
              shuffle=True,
              drop_last=False):
     """Load all batches and return `(inputs, targets)`.

     The arguments are forwarded as keyword arguments to `self.batch_iter`.
     The batches are concatenated with `numpy_collate_concat` and the
     'inputs' and 'targets' entries of the result are returned as a tuple.
     """
     batches = list(
         tqdm(
             self.batch_iter(batch_size=batch_size,
                             num_workers=num_workers,
                             shuffle=shuffle,
                             drop_last=drop_last)))
     merged = numpy_collate_concat(batches)
     return merged['inputs'], merged['targets']
Ejemplo n.º 10
0
 def load_all(self, batch_size=32, **kwargs):
     """Load the whole dataset into memory

     Arguments:
         batch_size (int, optional): how many samples per batch to load
             (default: 32).
         **kwargs: forwarded to `batch_iter()`

     Returns:
         deep copies of all batches concatenated with `numpy_collate_concat`
     """
     from copy import deepcopy

     # floor division: the progress total does not count a trailing
     # partial batch
     batches = tqdm(self.batch_iter(batch_size, **kwargs),
                    total=len(self) // batch_size)
     return numpy_collate_concat([deepcopy(batch) for batch in batches])
Ejemplo n.º 11
0
 def _flush_buffer(self):
     """Append the buffered data to `self.root` and reset the buffer.

     The buffer is merged with `numpy_collate_concat`; each key's array is
     appended to `self.root[key]`. Under Python 2 string arrays are converted
     with `astype(unicode)` first.
     """
     merged = numpy_collate_concat(self.write_buffer)  # merge the buffer
     running_py2 = sys.version_info[0] == 2
     for key in merged:
         # the dtype aliases and `unicode` are only touched on Python 2
         needs_unicode = running_py2 and merged[key].dtype.type in [
             np.string_, np.str_, np.unicode_
         ]
         if needs_unicode:
             self.root[key].append(merged[key].astype(unicode))
         else:
             self.root[key].append(merged[key])
     self.write_buffer_size = 0
     self.write_buffer = None
Ejemplo n.º 12
0
 def _flush_buffer(self):
     """Write the buffered rows to the datasets in `self.f` and reset.

     Every dataset is grown along axis 0 by `self.write_buffer_size` rows and
     the merged buffer is written into the new tail. Afterwards `self.f` is
     flushed and the buffer is cleared.
     """
     merged = numpy_collate_concat(self.write_buffer)
     for key in merged:
         target = self.f[key]
         old_len = target.shape[0]
         # make room for the buffered rows, then fill the new tail
         target.resize(old_len + self.write_buffer_size, axis=0)
         target[old_len:] = merged[key]
     self.f.flush()
     self.write_buffer_size = 0
     self.write_buffer = None
Ejemplo n.º 13
0
    def imp_score(self,
                  x,
                  name,
                  method='deeplift',
                  batch_size=512,
                  preact_only=False):
        """Compute the importance score

        Args:
          x: model input. A dict or a list is used as given; anything else
            is stored under 'seq' and merged with the result of
            `self.neutral_bias_inputs`
          name: forwarded to `self._imp_deeplift_fn`
          method: importance scoring method; only "deeplift" is accepted by
            this implementation
          batch_size: mini-batch size; with None `x` is scored in one call
          preact_only: forwarded to `self._imp_deeplift_fn`

        Raises:
          ValueError: if `method` is not "deeplift"
        """
        # anything that is not a dict/list gets wrapped together with the
        # neutral bias inputs
        if not isinstance(x, (dict, list)):
            seq_length = x.shape[1]
            x = {
                'seq': x,
                **self.neutral_bias_inputs(len(x), seqlen=seq_length)
            }

        if method != "deeplift":
            raise ValueError(
                "Please provide a valid importance scoring method: grad, ism or deeplift"
            )
        score_fn = self._imp_deeplift_fn(x, name, preact_only=preact_only)

        def as_input_list(names, value):
            # lists pass through, dicts are ordered by `names`,
            # everything else becomes a one-element list
            if isinstance(value, list):
                return value
            if isinstance(value, dict):
                return [value[key] for key in names]
            return [value]

        input_names = self.model.input_names
        assert input_names[0] == "seq"

        if batch_size is None:
            return score_fn(as_input_list(input_names, x))[0]

        return numpy_collate_concat([
            score_fn(as_input_list(input_names, minibatch))[0]
            for minibatch in nested_numpy_minibatch(x, batch_size=batch_size)
        ])
Ejemplo n.º 14
0
 def concat(cls, objects):
     """Build a new `cls` instance from `objects`.

     `objects` is merged with `numpy_collate_concat` and passed as `data`;
     `attrs` of the new instance is None.
     """
     merged = numpy_collate_concat(objects)
     return cls(data=merged, attrs=None)
Ejemplo n.º 15
0
 def append(self, datax):
     """Append `datax` to this dataset in place.

     The parent class' `__init__` is re-run on this instance with the
     concatenation of `self.data` and `datax.data` and a deep copy of
     `self.attrs`. The return value is whatever that `__init__` call returns
     (NOTE(review): presumably None -- confirm callers do not expect a new
     object).
     """
     merged = numpy_collate_concat([self.data, datax.data])
     attrs_copy = deepcopy(self.attrs)
     return super().__init__(data=merged, attrs=attrs_copy)
Ejemplo n.º 16
0
    def imp_score(self,
                  x,
                  task,
                  strand='both',
                  method='grad',
                  pred_summary='weighted',
                  batch_size=512):
        """Compute the importance score

        Args:
          x: model input; a list is used as given, a dict is ordered by
            `self.model.input_names`, anything else is used as single input
          task: task name; has to be contained in `self.tasks`
          strand: for which strand to run it ('pos', 'neg' or 'both'). With
            'both' the scores of 'pos' and 'neg' are combined with `mean`
          method: which importance score to use. Available: grad, ism, deeplift
          pred_summary: forwarded to the importance scoring function factory
          batch_size: mini-batch size; with None `x` is scored in one call

        Raises:
          ValueError: if `method` is not one of grad, ism or deeplift
        """
        assert task in self.tasks
        # figure out the task id
        matching_ids = [
            idx for idx, candidate in enumerate(self.tasks)
            if candidate == task
        ]
        task_id = matching_ids[0]

        if strand == 'both':
            # average across strands
            per_strand = [
                self.imp_score(x,
                               task,
                               strand=single_strand,
                               method=method,
                               pred_summary=pred_summary,
                               batch_size=batch_size)
                for single_strand in ['pos', 'neg']
            ]
            return mean(per_strand)

        def as_input_list(names, value):
            # lists pass through, dicts are ordered by `names`,
            # everything else becomes a one-element list
            if isinstance(value, list):
                return value
            if isinstance(value, dict):
                return [value[key] for key in names]
            return [value]

        input_names = self.model.input_names
        assert input_names[0] == "seq"

        # get the importance scoring function
        if method == "grad":
            score_fn = self._imp_grad_fn(strand, task_id, pred_summary)
        elif method == "ism":
            score_fn = self._imp_ism_fn(strand, task_id, pred_summary)
        elif method == "deeplift":
            score_fn = self._imp_deeplift_fn(x, strand, task_id, pred_summary)
        else:
            raise ValueError(
                "Please provide a valid importance scoring method: grad, ism or deeplift"
            )

        if batch_size is None:
            return score_fn(as_input_list(input_names, x))[0]

        return numpy_collate_concat([
            score_fn(as_input_list(input_names, minibatch))[0]
            for minibatch in nested_numpy_minibatch(x, batch_size=batch_size)
        ])
Ejemplo n.º 17
0
    def evaluate(self,
                 metric,
                 batch_size=256,
                 num_workers=8,
                 eval_train=False,
                 eval_skip=(),
                 save=True,
                 **kwargs):
        """Evaluate the model on the validation set

        Args:
          metric: default callable invoked as `metric(labels, preds)`; a
            dataset entry with a third element overrides it for that dataset
          batch_size: forwarded to `dataset.batch_train_iter`
          num_workers: forwarded to `dataset.batch_train_iter`
          eval_train: if True, also compute the evaluation metrics on the training set
          eval_skip: names of the datasets to leave out of the evaluation
          save: save the json file to the output directory
          **kwargs: ignored; a warning is logged when any are given

        Returns:
          dict with the entries of `self.metrics` overlaid by one entry per
          evaluated dataset name.

        Raises:
          ValueError: if a dataset entry does not have 2 or 3 elements
        """
        from copy import deepcopy

        if len(kwargs) > 0:
            logger.warning(
                f"Extra kwargs were provided to trainer.evaluate: {kwargs}")
        # construct a list of datasets to evaluate
        if eval_train:
            eval_datasets = [('train', self.train_dataset)
                             ] + self.valid_dataset
        else:
            eval_datasets = self.valid_dataset

        try:
            if len(eval_skip) > 0:
                # match on the name only, so that 3-element entries
                # (name, dataset, metric) can be skipped as well
                eval_datasets = [
                    d for d in eval_datasets if d[0] not in eval_skip
                ]
        except (TypeError, ValueError, IndexError, KeyError):
            # best effort: keep all datasets when the entries are not
            # indexable the expected way
            logger.warning(
                f"eval datasets don't contain tuples. Unable to skip them using {eval_skip}"
            )

        metric_res = OrderedDict()
        for d in eval_datasets:
            if len(d) == 2:
                dataset_name, dataset = d
                eval_metric = metric  # use the default eval metric
            elif len(d) == 3:
                # specialized evaluation metric was passed
                dataset_name, dataset, eval_metric = d
            else:
                # TODO - this should be made more explicit with classes
                raise ValueError(
                    "Valid dataset needs to be a list of tuples of 2 or 3 elements"
                    "(name, dataset) or (name, dataset, metric)")
            logger.info(f"Evaluating dataset: {dataset_name}")
            lpreds = []
            llabels = []
            for inputs, targets in tqdm(
                    dataset.batch_train_iter(cycle=False,
                                             num_workers=num_workers,
                                             batch_size=batch_size),
                    total=len(dataset) // batch_size):
                lpreds.append(self.model.predict_on_batch(inputs))
                # store an independent copy of the targets
                llabels.append(deepcopy(targets))
                del inputs
                del targets
            preds = numpy_collate_concat(lpreds)
            labels = numpy_collate_concat(llabels)
            del lpreds
            del llabels
            metric_res[dataset_name] = eval_metric(labels, preds)

        if save:
            write_json(metric_res, self.evaluation_path, indent=2)
            logger.info("Saved metrics to {}".format(self.evaluation_path))

        if self.cometml_experiment is not None:
            self.cometml_experiment.log_multiple_metrics(
                flatten(metric_res, separator='/'), prefix="eval/")

        if self.wandb_run is not None:
            self.wandb_run.summary.update(
                flatten(prefix_dict(metric_res, prefix="eval/"),
                        separator='/'))
        # the saved/logged results hold the evaluation only; the returned
        # dict additionally carries the entries of self.metrics
        metric_res = {**self.metrics, **metric_res}
        return metric_res
Ejemplo n.º 18
0
    else:
        gpu = args.gpu
    create_tf_session(gpu)

    model = load_model(args.model)

    # Get the dataloader
    # Dl = kipoi.get_dataloader_factory(args.dataloader, args.dataloader_source)
    Dl = kipoi.get_model(args.dataloader,
                         args.dataloader_source).default_dataloader
    dl = Dl(intervals_file=args.intervals_file,
            fasta_file=args.fasta_file,
            ignore_targets=False,
            num_chr_fasta=True)

    metric_fns = {"auprc": auprc, "auc": auc, "accuracy": accuracy}

    y_true = dl.seq_dl.bed.df[3].values
    y_pred = numpy_collate_concat([
        model.predict_on_batch(x['inputs']) for x in tqdm(
            dl.batch_iter(batch_size=args.batch_size,
                          num_workers=args.num_workers))
    ])

    # ---------------
    metrics = {
        k: m(np.ravel(y_true), np.ravel(y_pred))
        for k, m in metric_fns.items()
    }
    output.write_text(json.dumps(metrics))