def test_subset_uses_placeholder_y(
        self, dataset_cls, data, uses_placeholder_y, cv_split_cls):
    """Check that splitting a dataset built with ``y=None`` keeps the placeholder.

    The dataset is created from the features only, split in two with
    ``cv_split_cls(cv=2)``, and both resulting parts must be reported
    by ``uses_placeholder_y`` as using a placeholder target.
    """
    features, _ = data
    dataset = dataset_cls(features, y=None)
    splitter = cv_split_cls(cv=2)
    train_part, valid_part = splitter(dataset)
    assert uses_placeholder_y(train_part)
    assert uses_placeholder_y(valid_part)
def test_subset_dataset_uses_non_y_placeholder(
        self, dataset_cls, data, uses_placeholder_y, cv_split_cls):
    """Check that splitting a dataset with a real ``y`` yields no placeholder.

    The dataset is created from both features and targets, split in two
    with ``cv_split_cls(cv=2)``, and neither part may be reported by
    ``uses_placeholder_y`` as using a placeholder target.
    """
    features, targets = data
    dataset = dataset_cls(features, targets)
    splitter = cv_split_cls(cv=2)
    train_part, valid_part = splitter(dataset)
    assert not uses_placeholder_y(train_part)
    assert not uses_placeholder_y(valid_part)
def test_subset_of_subset_uses_non_placeholder_y(
        self, dataset_cls, data, uses_placeholder_y, cv_split_cls):
    """Check that a split of a split still carries the real target.

    A dataset with a real ``y`` is split with ``cv=4``; the first part of
    that split is split again with ``cv=3``. Neither of the nested parts
    may be reported by ``uses_placeholder_y`` as using a placeholder.
    """
    features, targets = data
    dataset = dataset_cls(features, targets)

    outer_splitter = cv_split_cls(cv=4)
    first_part, _ = outer_splitter(dataset)

    inner_splitter = cv_split_cls(cv=3)
    train_part, valid_part = inner_splitter(first_part)

    assert not uses_placeholder_y(train_part)
    assert not uses_placeholder_y(valid_part)
def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
    """Run one pass over ``dataset`` for either training or validation.

    Every batch is unpacked with ``unpack_data``, announced through the
    ``on_batch_begin`` / ``on_batch_end`` callbacks and handed to
    ``step_fn``. The step's ``"distance"`` value and the batch size are
    written to the history per batch; the number of batches is written
    once the pass is finished.

    Parameters
    ----------
    dataset : torch Dataset
        The initialized dataset to loop over.

    training : bool
        Forwarded to ``self.get_iterator`` and to the batch callbacks
        to tell a training pass from a validation pass.

    prefix : str
        Prefix of the history keys (``<prefix>_distance``,
        ``<prefix>_batch_size``, ``<prefix>_batch_count``).

    step_fn : callable
        Called for each batch with the inputs, the targets, a
        ``train_generator`` flag and ``**fit_params``. Its result must
        be a mapping containing a ``"distance"`` entry.

    **fit_params : dict
        Additional parameters passed to ``step_fn``.
    """
    placeholder_target = uses_placeholder_y(dataset)
    batches = self.get_iterator(dataset, training=training)

    n_batches = 0
    for batch_idx, batch in enumerate(batches):
        inputs, targets = unpack_data(batch)
        # Callbacks must not see a placeholder target; they get None instead.
        callback_targets = None if placeholder_target else targets

        self.notify(
            "on_batch_begin", X=inputs, y=callback_targets, training=training)

        # The flag is true on every ``train_generator_every``-th batch,
        # starting with the first one.
        generator_turn = batch_idx % self.train_generator_every == 0
        outcome = step_fn(
            inputs, targets, train_generator=generator_turn, **fit_params)

        self.history.record_batch(
            prefix + "_distance", outcome["distance"].item())
        self.history.record_batch(prefix + "_batch_size", get_len(inputs))

        self.notify(
            "on_batch_end",
            X=inputs,
            y=callback_targets,
            training=training,
            **outcome
        )
        n_batches += 1

    self.history.record(prefix + "_batch_count", n_batches)
def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
    """Run one pass over ``dataset`` for either training or validation.

    Each batch is a pair whose second element is discarded. The target
    is taken out of the batch's ``'query'`` entry before the batch is
    handed to the callbacks and to ``step_fn``. The step's loss and the
    batch size are written to the history per batch; the number of
    batches is written once the pass is finished.

    Parameters
    ----------
    dataset : torch Dataset
        The initialized dataset to loop over.

    training : bool
        Forwarded to ``self.get_iterator`` and to the batch callbacks
        to tell a training pass from a validation pass.

    prefix : str
        Prefix of the history keys (``<prefix>_loss``,
        ``<prefix>_batch_size``, ``<prefix>_batch_count``).

    step_fn : callable
        Called for each batch with the inputs, the target and
        ``**fit_params``. Its result must be a mapping containing a
        ``"loss"`` entry.

    **fit_params : dict
        Additional parameters passed to ``step_fn``.
    """
    placeholder_target = uses_placeholder_y(dataset)

    n_batches = 0
    for batch in self.get_iterator(dataset, training=training):
        # The second element of the batch is only a dummy target.
        inputs, _ = batch
        # Pull the real target out of the query entry so that it is no
        # longer part of the inputs given to the step function.
        targets = inputs['query'].pop(1)
        # Callbacks must not see a placeholder target; they get None instead.
        callback_targets = None if placeholder_target else targets

        self.notify(
            "on_batch_begin", X=inputs, y=callback_targets, training=training)

        outcome = step_fn(inputs, targets, **fit_params)

        self.history.record_batch(prefix + "_loss", outcome["loss"].item())
        self.history.record_batch(prefix + "_batch_size", get_len(inputs))

        self.notify(
            "on_batch_end",
            X=inputs,
            y=callback_targets,
            training=training,
            **outcome
        )
        n_batches += 1

    self.history.record(prefix + "_batch_count", n_batches)
def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
    """Run one pass over ``dataset`` for either training or validation.

    Every batch is unpacked with ``unpack_data``, announced through the
    ``on_batch_begin`` / ``on_batch_end`` callbacks and handed to
    ``step_fn``. The step's loss and the batch size (measured on the
    ``"nodes"`` entry of the inputs) are written to the history per
    batch; the number of batches is written once the pass is finished.

    Parameters
    ----------
    dataset : torch Dataset
        The initialized dataset to loop over.

    training : bool
        Forwarded to ``self.get_iterator`` and to the batch callbacks
        to tell a training pass from a validation pass.

    prefix : str
        Prefix of the history keys (``<prefix>_loss``,
        ``<prefix>_batch_size``, ``<prefix>_batch_count``).

    step_fn : callable
        Called for each batch with the inputs, the targets and
        ``**fit_params``. Its result must be a mapping containing a
        ``"loss"`` entry.

    **fit_params : dict
        Additional parameters passed to ``step_fn``.
    """
    placeholder_target = uses_placeholder_y(dataset)

    n_batches = 0
    for batch in self.get_iterator(dataset, training=training):
        inputs, targets = unpack_data(batch)
        # Callbacks must not see a placeholder target; they get None instead.
        callback_targets = None if placeholder_target else targets

        self.notify(
            "on_batch_begin", X=inputs, y=callback_targets, training=training)

        outcome = step_fn(inputs, targets, **fit_params)

        self.history.record_batch(prefix + "_loss", outcome["loss"].item())
        # The inputs are a mapping; the batch size is taken from its
        # "nodes" entry.
        self.history.record_batch(
            prefix + "_batch_size", get_len(inputs["nodes"]))

        self.notify(
            "on_batch_end",
            X=inputs,
            y=callback_targets,
            training=training,
            **outcome
        )
        n_batches += 1

    self.history.record(prefix + "_batch_count", n_batches)
def test_custom_dataset_uses_non_y_placeholder(
        self, custom_dataset_cls, uses_placeholder_y):
    """Check that a default-constructed custom dataset has no placeholder y."""
    assert not uses_placeholder_y(custom_dataset_cls())
def test_dataset_uses_non_y_placeholder(
        self, dataset_cls, data, uses_placeholder_y):
    """Check that a dataset built with a real ``y`` has no placeholder y."""
    features, targets = data
    assert not uses_placeholder_y(dataset_cls(features, targets))
def test_dataset_uses_y_placeholder(
        self, dataset_cls, data, uses_placeholder_y):
    """Check that a dataset built with ``y=None`` uses a placeholder y."""
    features, _ = data
    assert uses_placeholder_y(dataset_cls(features, y=None))
def fit_loop(self, X, y=None, epochs=None, **fit_params):
    """Run the training loop proper.

    The data is checked and split into a training and a validation
    dataset. Each epoch then runs a training pass and, when a
    validation dataset exists, a validation pass. Losses and batch
    sizes are written to the history per batch, and the epoch and batch
    callbacks are notified along the way.

    Parameters
    ----------
    X : input data, compatible with skorch.dataset.Dataset
        By default, you should be able to pass:

          * numpy arrays
          * torch tensors
          * pandas DataFrame or Series
          * a dictionary of the former three
          * a list/tuple of the former three
          * a Dataset

        If this doesn't work with your data, you have to pass a
        ``Dataset`` that can deal with the data.

    y : target data, compatible with skorch.dataset.Dataset
        The same data types as for ``X`` are supported. If your X is
        a Dataset that contains the target, ``y`` may be set to
        None.

    epochs : int or None (default=None)
        If int, train for this number of epochs; if None, use
        ``self.max_epochs``.

    **fit_params : dict
        Additional parameters forwarded to ``self.get_split_datasets``,
        ``self.train_step`` and ``self.validation_step``.

    Returns
    -------
    self
        The instance the method was called on.
    """
    self.check_data(X, y)
    if epochs is None:
        epochs = self.max_epochs

    train_set, valid_set = self.get_split_datasets(X, y, **fit_params)
    epoch_cb_kwargs = dict(dataset_train=train_set, dataset_valid=valid_set)

    # Callbacks receive None as target when the dataset only carries a
    # placeholder y.
    train_has_placeholder = uses_placeholder_y(train_set)
    valid_has_placeholder = uses_placeholder_y(valid_set)

    for _ in range(epochs):
        self.notify('on_epoch_begin', **epoch_cb_kwargs)

        for inputs, targets in self.get_iterator(train_set, training=True):
            cb_targets = None if train_has_placeholder else targets
            self.notify(
                'on_batch_begin', X=inputs, y=cb_targets, training=True)
            outcome = self.train_step(inputs, targets, **fit_params)
            self.history.record_batch('train_loss', outcome['loss'].item())
            self.history.record_batch('train_batch_size', get_len(inputs))
            self.notify(
                'on_batch_end',
                X=inputs,
                y=cb_targets,
                training=True,
                **outcome
            )

        # The validation pass is skipped when there is no validation data.
        if valid_set is not None:
            for inputs, targets in self.get_iterator(
                    valid_set, training=False):
                cb_targets = None if valid_has_placeholder else targets
                self.notify(
                    'on_batch_begin', X=inputs, y=cb_targets, training=False)
                outcome = self.validation_step(inputs, targets, **fit_params)
                self.history.record_batch(
                    'valid_loss', outcome['loss'].item())
                self.history.record_batch(
                    'valid_batch_size', get_len(inputs))
                self.notify(
                    'on_batch_end',
                    X=inputs,
                    y=cb_targets,
                    training=False,
                    **outcome
                )

        self.notify('on_epoch_end', **epoch_cb_kwargs)

    return self