def _predict_func(self, func, input, batch_size=8224, numpy=None, eval_=True, grads=False, to_cpu=False,
                  num_workers=0, is_dataloader=None, **kwargs):
    """Get predictions from `input` which can be data or a DataLoader.

    `func` can be anything and is not concatenated to `self.net` or `self.net.predict`.
    This is different from `predict` and `predict_net` which both use `self.net`.

    Arguments:
        func -- Callable forwarded to `self._predict_func_dl` to produce the predictions.
        input -- Data, or an already-built DataLoader.

    Keyword Arguments:
        batch_size {int} -- Batch size used only when a dataloader is built from `input`
            (default: {8224})
        numpy {bool or None} -- Forwarded to `self._predict_func_dl` and `array_or_tensor`.
            A truthy value also forces `to_cpu` to be truthy. (default: {None})
        eval_ {bool} -- Forwarded to `self._predict_func_dl`. (default: {True})
        grads {bool} -- Forwarded to `self._predict_func_dl`. (default: {False})
        to_cpu {bool} -- Forwarded to `self._predict_func_dl`. (default: {False})
        num_workers {int} -- Workers for the dataloader built from `input`. (default: {0})
        is_dataloader {bool or None} -- `None` auto-detects the type of `input`.
            A truthy value forces `input` to be used as a DataLoader as-is.
            A falsy non-None value forces a dataloader to be built from `input`.
            (default: {None})
        **kwargs -- Passed to `self.make_dataloader_predict`.

    Returns:
        The result of `array_or_tensor(preds, numpy, input)`.

    Raises:
        ValueError -- If `is_dataloader` is None and `input` is recognized neither as
            data nor as a DataLoader.
    """
    if is_dataloader is None:
        # Auto-detect: data is checked first, matching the original precedence.
        if is_data(input):
            is_dataloader = False
        elif is_dl(input):
            is_dataloader = True
        else:
            # Adjacent literals are concatenated at compile time. The previous `+ +`
            # applied unary plus to a str and raised TypeError instead of ValueError.
            raise ValueError(
                "Did not recognize data type. You can set `is_dataloader` to `True`"
                " or `False` to force usage."
            )
    if is_dataloader:
        dl = input
    else:
        dl = self.make_dataloader_predict(input, batch_size, shuffle=False, num_workers=num_workers, **kwargs)
    # Requesting numpy output implies the predictions must be moved to the CPU.
    to_cpu = numpy or to_cpu
    preds = self._predict_func_dl(func, dl, numpy, eval_, grads, to_cpu)
    return array_or_tensor(preds, numpy, input)
def fit(
    self,
    input,
    target=None,
    batch_size=256,
    epochs=1,
    callbacks=None,
    verbose=True,
    num_workers=0,
    shuffle=True,
    metrics=None,
    val_data=None,
    val_batch_size=8224,
    **kwargs,
):
    """Fit the model on inputs and (optional) targets.

    Arguments:
        input {np.array, tensor or tuple} -- Input (x) passed to net.
        target {np.array, tensor or tuple} -- Target (y) passed to loss function. When given,
            the training data becomes the tuple `(input, target)`. (default: {None})

    Keyword Arguments:
        batch_size {int} -- Elements in each training batch. (default: {256})
        epochs {int} -- Number of epochs. (default: {1})
        callbacks {list} -- List of callbacks. (default: {None})
        verbose {bool} -- Print progress. (default: {True})
        num_workers {int} -- Number of workers used in the dataloader(s). (default: {0})
        shuffle {bool} -- Whether to shuffle the order of the training dataset. (default: {True})
        metrics -- Forwarded unchanged to `self.fit_dataloader`. (default: {None})
        val_data -- Validation data or a DataLoader. Anything not recognized as a DataLoader
            (and not None) is wrapped with `self.make_dataloader` using `val_batch_size`
            and `shuffle=False`. (default: {None})
        val_batch_size {int} -- Batch size for the validation dataloader built from
            `val_data`. (default: {8224})
        **kwargs -- Passed to the `make_dataloader` method for both the training and
            validation dataloaders. Set e.g. `torch_ds_dl` to use the TensorDataset and
            DataLoader provided by torch instead of the torchtuples implementations.

    Returns:
        TrainingLogger -- Training log, as returned by `self.fit_dataloader`.
    """
    train_data = input if target is None else (input, target)
    train_dl = self.make_dataloader(train_data, batch_size, shuffle, num_workers, **kwargs)

    # `is_dl` is evaluated before the None check, exactly as in the original condition.
    val_is_loader = is_dl(val_data)
    val_dataloader = val_data
    if val_is_loader is False and val_data is not None:
        # Validation data is never shuffled.
        val_dataloader = self.make_dataloader(
            val_data, val_batch_size, shuffle=False, num_workers=num_workers, **kwargs
        )

    return self.fit_dataloader(train_dl, epochs, callbacks, verbose, metrics, val_dataloader)