예제 #1
0
파일: base.py 프로젝트: hleu/torchtuples
    def _predict_func(self,
                      func,
                      input,
                      batch_size=8224,
                      numpy=None,
                      eval_=True,
                      grads=False,
                      to_cpu=False,
                      num_workers=0,
                      is_dataloader=None,
                      **kwargs):
        """Get predictions from `input` which can be data or a DataLoader.

        `func` can be anything and is not concatenated to `self.net` or
        `self.net.predict`. This is different from `predict` and `predict_net`,
        which both call `self.net`.

        Arguments:
            func -- Callable forwarded to `self._predict_func_dl` together with the dataloader.
            input -- Data (as recognized by `is_data`) or a dataloader (as recognized by `is_dl`).

        Keyword Arguments:
            batch_size {int} -- Batch size used when a dataloader is built from `input` (default: {8224})
            numpy {bool or None} -- Forwarded to `_predict_func_dl` and `array_or_tensor`;
                a truthy value also turns on `to_cpu` (default: {None})
            eval_ {bool} -- Forwarded to `_predict_func_dl` (default: {True})
            grads {bool} -- Forwarded to `_predict_func_dl` (default: {False})
            to_cpu {bool} -- Forwarded to `_predict_func_dl` (default: {False})
            num_workers {int} -- Number of workers used when a dataloader is built (default: {0})
            is_dataloader {bool or None} -- `False` forces `input` to be wrapped in a new
                dataloader; `True` uses `input` as a dataloader when it is not recognized
                as data; `None` relies on auto-detection only (default: {None})
            **kwargs -- Passed to `self.make_dataloader_predict`.

        Returns:
            The predictions from `_predict_func_dl`, converted by
            `array_or_tensor(preds, numpy, input)`.

        Raises:
            ValueError -- If `input` is recognized neither as data nor as a dataloader
                and `is_dataloader` is neither `True` nor `False`.
        """
        # NOTE: data detection is checked first, so `is_dataloader=True` only takes
        # effect for inputs that `is_data` does not recognize.
        if is_data(input) or (is_dataloader is False):
            dl = self.make_dataloader_predict(input,
                                              batch_size,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              **kwargs)
        elif is_dl(input) or (is_dataloader is True):
            dl = input
        else:
            # Adjacent literals are concatenated at compile time; a stray unary `+`
            # on a str would raise TypeError instead of this ValueError.
            raise ValueError(
                "Did not recognize data type. You can set `is_dataloader` to `True`"
                " or `False` to force usage.")

        # Returning numpy arrays implies the predictions must be moved to the CPU.
        to_cpu = numpy or to_cpu
        preds = self._predict_func_dl(func, dl, numpy, eval_, grads, to_cpu)
        return array_or_tensor(preds, numpy, input)
예제 #2
0
파일: base.py 프로젝트: havakv/torchtuples
    def fit(
        self,
        input,
        target=None,
        batch_size=256,
        epochs=1,
        callbacks=None,
        verbose=True,
        num_workers=0,
        shuffle=True,
        metrics=None,
        val_data=None,
        val_batch_size=8224,
        **kwargs,
    ):
        """Fit model with inputs and targets.

        Arguments:
            input {np.array, tensor or tuple} -- Input (x) passed to net.
            target {np.array, tensor or tuple} -- Target (y) passed to loss function.
                When given, `(input, target)` is used as the training data (default: {None})

        Keyword Arguments:
            batch_size {int} -- Elements in each batch (default: {256})
            epochs {int} -- Number of epochs (default: {1})
            callbacks {list} -- list of callbacks (default: {None})
            verbose {bool} -- Print progress (default: {True})
            num_workers {int} -- Number of workers used in the dataloader (default: {0})
            shuffle {bool} -- If we should shuffle the order of the dataset (default: {True})
            metrics -- Forwarded to `fit_dataloader` (default: {None})
            val_data -- Validation data, or a dataloader (as recognized by `is_dl`) that is
                used as-is. Other validation data is wrapped with `make_dataloader`
                without shuffling (default: {None})
            val_batch_size {int} -- Batch size of the validation dataloader when one is
                built from `val_data` (default: {8224})
            **kwargs -- Passed to the 'make_dataloader' method (for the training dataloader
                and, when built here, the validation dataloader). Set e.g. `torch_ds_dl` to
                use the TensorDataset and DataLoader provided by torch instead of the
                torchtuples implementations.

        Returns:
            TrainingLogger -- Training log
        """
        if target is not None:
            input = (input, target)
        dataloader = self.make_dataloader(input, batch_size, shuffle,
                                          num_workers, **kwargs)
        val_dataloader = val_data
        # Only wrap validation data that is not already a dataloader; the None check
        # comes first so `is_dl` is never consulted when there is no validation data.
        if val_data is not None and not is_dl(val_data):
            val_dataloader = self.make_dataloader(val_data,
                                                  val_batch_size,
                                                  shuffle=False,
                                                  num_workers=num_workers,
                                                  **kwargs)
        log = self.fit_dataloader(dataloader, epochs, callbacks, verbose,
                                  metrics, val_dataloader)
        return log