def fit_loop(self, X, y=None, epochs=None, **fit_params):
    epochs = epochs if epochs is not None else self.max_epochs

    # split dataset indices into K folds
    dataset = self.get_dataset(X, y)
    k = 10
    fold_indices = self.split_k_fold(k, len(dataset))

    for e in range(epochs):
        # rotate the validation fold each epoch; train on the remaining folds
        valid_fold_idx = e % k
        idx_train = []
        for i in range(k):
            if i == valid_fold_idx:
                continue
            idx_train += fold_indices[i]
        idx_train = np.array(idx_train, dtype=int)
        idx_valid = np.array(fold_indices[valid_fold_idx], dtype=int)

        dataset_train = torch.utils.data.Subset(dataset, idx_train)
        dataset_valid = torch.utils.data.Subset(dataset, idx_valid)

        on_epoch_kwargs = {
            'dataset_train': dataset_train,
            'dataset_valid': dataset_valid,
        }
        self.notify('on_epoch_begin', **on_epoch_kwargs)

        train_batch_count = 0
        for data in self.get_iterator(dataset_train, training=True):
            xi, yi = data
            self.notify('on_batch_begin', X=xi, y=yi, training=True)
            step = self.train_step(xi, yi, **fit_params)
            self.history.record_batch('train_loss', step['loss'].item())
            self.history.record_batch('train_batch_size', get_len(xi))
            self.notify('on_batch_end', X=xi, y=yi, training=True, **step)
            train_batch_count += 1
        self.history.record("train_batch_count", train_batch_count)

        valid_batch_count = 0
        for data in self.get_iterator(dataset_valid, training=False):
            xi, yi = data
            self.notify('on_batch_begin', X=xi, y=yi, training=False)
            step = self.validation_step(xi, yi, **fit_params)
            self.history.record_batch('valid_loss', step['loss'].item())
            self.history.record_batch('valid_batch_size', get_len(xi))
            self.notify('on_batch_end', X=xi, y=yi, training=False, **step)
            valid_batch_count += 1
        self.history.record("valid_batch_count", valid_batch_count)

        self.notify('on_epoch_end', **on_epoch_kwargs)
    return self
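# ``split_k_fold`` is called above but not shown. A minimal sketch of what
# it might look like, assuming it only needs to shuffle the sample indices
# and partition them into k roughly equal folds; the list-of-lists return
# type is implied by how ``fold_indices[i]`` is concatenated above, but the
# body below is an assumption, not the original implementation.
import numpy as np

def split_k_fold(self, k, n_samples):
    rng = np.random.RandomState(getattr(self, 'random_state', None))
    indices = rng.permutation(n_samples)
    # np.array_split tolerates n_samples not being divisible by k
    return [fold.tolist() for fold in np.array_split(indices, k)]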
def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
    """Compute a single epoch of train or validation.

    Parameters
    ----------
    dataset : torch Dataset
      The initialized dataset to loop over.

    training : bool
      Whether to set the module to train mode or not.

    prefix : str
      Prefix to use when saving to the history.

    step_fn : callable
      Function to call for each batch.

    **fit_params : dict
      Additional parameters passed to the ``step_fn``.

    """
    is_placeholder_y = uses_placeholder_y(dataset)

    batch_count = 0
    for i, data in enumerate(self.get_iterator(dataset, training=training)):
        Xi, yi = unpack_data(data)
        yi_res = yi if not is_placeholder_y else None
        self.notify("on_batch_begin", X=Xi, y=yi_res, training=training)
        step = step_fn(
            Xi, yi,
            train_generator=(i % self.train_generator_every == 0),
            **fit_params,
        )
        self.history.record_batch(prefix + "_distance", step["distance"].item())
        self.history.record_batch(prefix + "_batch_size", get_len(Xi))
        self.notify("on_batch_end", X=Xi, y=yi_res, training=training, **step)
        batch_count += 1

    self.history.record(prefix + "_batch_count", batch_count)
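# The variant above passes a ``train_generator`` flag to the step function
# every ``self.train_generator_every`` batches, a common pattern for
# GAN-style training where the generator is updated less often than the
# critic. Below is a sketch of a compatible ``step_fn``; all attribute
# names (``critic_``, ``generator_``, ``*_optimizer_``, ``latent_dim_``)
# are assumptions. Only the ``train_generator`` keyword and the returned
# ``'distance'`` key are dictated by the loop above.
import torch

def train_step(self, Xi, yi, train_generator=False, **fit_params):
    # Critic update on every batch: maximize D(real) - D(fake).
    fake = self.generator_(torch.randn(len(Xi), self.latent_dim_))
    distance = self.critic_(Xi).mean() - self.critic_(fake.detach()).mean()
    self.critic_optimizer_.zero_grad()
    (-distance).backward()
    self.critic_optimizer_.step()

    # Generator update only every ``train_generator_every`` batches.
    if train_generator:
        gen_loss = -self.critic_(self.generator_(
            torch.randn(len(Xi), self.latent_dim_))).mean()
        self.generator_optimizer_.zero_grad()
        gen_loss.backward()
        self.generator_optimizer_.step()

    # The loop reads step["distance"], so return the 0-dim tensor here.
    return {'distance': distance}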
def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
    """Compute a single epoch of train or validation.

    Parameters
    ----------
    dataset : torch Dataset
      The initialized dataset to loop over.

    training : bool
      Whether to set the module to train mode or not.

    prefix : str
      Prefix to use when saving to the history.

    step_fn : callable
      Function to call for each batch.

    **fit_params : dict
      Additional parameters passed to the ``step_fn``.

    """
    is_placeholder_y = uses_placeholder_y(dataset)

    batch_count = 0
    for data in self.get_iterator(dataset, training=training):
        # Discard the dummy (placeholder) target attached to the batch.
        data, _ = data
        # Remove the query target so the model can never see it during
        # the forward pass.
        yi = data['query'].pop(1)
        yi_res = yi if not is_placeholder_y else None
        self.notify("on_batch_begin", X=data, y=yi_res, training=training)
        step = step_fn(data, yi, **fit_params)
        self.history.record_batch(prefix + "_loss", step["loss"].item())
        self.history.record_batch(prefix + "_batch_size", get_len(data))
        self.notify("on_batch_end", X=data, y=yi_res, training=training, **step)
        batch_count += 1

    self.history.record(prefix + "_batch_count", batch_count)
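# ``data['query'].pop(1)`` implies an episodic (few-shot-style) batch
# layout where each batch carries a query set whose target sits at index 1.
# A runnable illustration of that structure; the 'support' key and all
# tensor shapes are assumptions, only the 'query' key and the position of
# its target are visible in the loop above.
import torch

X_support, y_support = torch.randn(5, 3, 32, 32), torch.arange(5)
X_query, y_query = torch.randn(15, 3, 32, 32), torch.randint(0, 5, (15,))

data = (
    {'support': [X_support, y_support], 'query': [X_query, y_query]},
    torch.zeros(1),  # dummy target discarded by ``data, _ = data``
)

Xi, _ = data
yi = Xi['query'].pop(1)  # same move as in the loop: hide the query targets
assert len(Xi['query']) == 1 and yi is y_query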
def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
    is_placeholder_y = uses_placeholder_y(dataset)

    batch_count = 0
    for data in self.get_iterator(dataset, training=training):
        Xi, yi = unpack_data(data)
        yi_res = yi if not is_placeholder_y else None
        self.notify("on_batch_begin", X=Xi, y=yi_res, training=training)
        step = step_fn(Xi, yi, **fit_params)
        self.history.record_batch(prefix + "_loss", step["loss"].item())
        # Xi is a dict here; the batch size is read from its node tensor.
        self.history.record_batch(prefix + "_batch_size", get_len(Xi["nodes"]))
        self.notify("on_batch_end", X=Xi, y=yi_res, training=training, **step)
        batch_count += 1

    self.history.record(prefix + "_batch_count", batch_count)
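# Why read the batch size from Xi["nodes"] alone? In a graph-style batch
# the entries legitimately differ in length, and ``get_len`` on a whole
# dict raises ValueError when its entries disagree (exactly what the
# ``test_inconsistent_lengths`` tests below check). A small demonstration;
# the "edges" key and all shapes are illustrative assumptions, and the
# import assumes ``get_len`` lives in skorch.dataset as in recent skorch
# versions.
import torch
from skorch.dataset import get_len

Xi = {
    "nodes": torch.randn(10, 16),            # 10 nodes, 16 features each
    "edges": torch.randint(0, 10, (2, 30)),  # 30 edges as index pairs
}

print(get_len(Xi["nodes"]))  # 10: the batch size the history records
# get_len(Xi) would raise ValueError: inconsistent lengths (10 vs. 2)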
def _get_batches_per_epoch_phase(self, net, dataset, training):
    if dataset is None:
        return 0
    batch_size = self._get_batch_size(net, training)
    return int(np.ceil(get_len(dataset) / batch_size))
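# Sanity check of the ceiling division above: a trailing partial batch
# still counts as one batch. With 103 samples and a batch size of 32 there
# are three full batches plus one batch of 7, i.e. 4 batches per epoch.
import numpy as np

n_samples, batch_size = 103, 32
assert int(np.ceil(n_samples / batch_size)) == 4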
def test_inconsistent_lengths(self, get_len, data):
    with pytest.raises(ValueError):
        get_len(data)
def test_valid_lengths(self, get_len, data, expected):
    length = get_len(data)
    assert length == expected
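# The two tests above receive ``get_len``, ``data``, and ``expected`` as
# fixtures/parameters, which suggests a lazily imported helper plus
# parametrized cases. A sketch of how the surrounding test class might be
# wired; the specific cases are illustrative, not taken from the real test
# suite, and the import path assumes ``get_len`` lives in skorch.dataset.
import numpy as np
import pytest
import torch

class TestGetLen:
    @pytest.fixture
    def get_len(self):
        from skorch.dataset import get_len
        return get_len

    @pytest.mark.parametrize('data, expected', [
        (np.zeros(5), 5),
        (torch.zeros(8, 3), 8),
        ({'a': np.zeros(4), 'b': np.zeros(4)}, 4),
        ([np.zeros(7), np.zeros(7)], 7),
    ])
    def test_valid_lengths(self, get_len, data, expected):
        assert get_len(data) == expected

    @pytest.mark.parametrize('data', [
        {'a': np.zeros(3), 'b': np.zeros(4)},  # dict entries disagree
        [np.zeros(2), np.zeros(5)],            # list entries disagree
    ])
    def test_inconsistent_lengths(self, get_len, data):
        with pytest.raises(ValueError):
            get_len(data)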
def _get_batches_per_epoch_phase(self, net, X, training):
    if X is None:
        return 0
    batch_size = self._get_batch_size(net, training)
    return int(np.ceil(get_len(X) / batch_size))
def fit_loop(self, X, y=None, epochs=None, **fit_params):
    """The proper fit loop.

    Contains the logic of what actually happens during the fit loop.

    Parameters
    ----------
    X : input data, compatible with skorch.dataset.Dataset
      By default, you should be able to pass:

        * numpy arrays
        * torch tensors
        * pandas DataFrame or Series
        * a dictionary of the former three
        * a list/tuple of the former three

      If this doesn't work with your data, you have to pass a
      ``Dataset`` that can deal with the data.

    y : target data, compatible with skorch.dataset.Dataset
      The same data types as for ``X`` are supported.

    epochs : int or None (default=None)
      If int, train for this number of epochs; if None, use
      ``self.max_epochs``.

    **fit_params : dict
      Additional parameters passed to the ``forward`` method of the
      module and to the train_split call.

    """
    self.check_data(X, y)
    epochs = epochs if epochs is not None else self.max_epochs

    if self.train_split:
        X_train, X_valid, y_train, y_valid = self.train_split(
            X, y, **fit_params)
        dataset_valid = self.get_dataset(X_valid, y_valid)
    else:
        X_train, X_valid, y_train, y_valid = X, None, y, None
        dataset_valid = None
    dataset_train = self.get_dataset(X_train, y_train)

    on_epoch_kwargs = {
        'X': X_train,
        'X_valid': X_valid,
        'y': y_train,
        'y_valid': y_valid,
    }

    for _ in range(epochs):
        self.notify('on_epoch_begin', **on_epoch_kwargs)

        for Xi, yi in self.get_iterator(dataset_train, training=True):
            self.notify('on_batch_begin', X=Xi, y=yi, training=True)
            step = self.train_step(Xi, yi, **fit_params)
            # ``loss.data[0]`` is PyTorch 0.3-era indexing of a 0-dim
            # tensor; ``item()`` is the modern equivalent, matching the
            # other loop variants above.
            self.history.record_batch('train_loss', step['loss'].item())
            self.history.record_batch('train_batch_size', get_len(Xi))
            self.notify('on_batch_end', X=Xi, y=yi, training=True, **step)

        if X_valid is None:
            self.notify('on_epoch_end', **on_epoch_kwargs)
            continue

        for Xi, yi in self.get_iterator(dataset_valid, training=False):
            self.notify('on_batch_begin', X=Xi, y=yi, training=False)
            step = self.validation_step(Xi, yi, **fit_params)
            self.history.record_batch('valid_loss', step['loss'].item())
            self.history.record_batch('valid_batch_size', get_len(Xi))
            self.notify('on_batch_end', X=Xi, y=yi, training=False, **step)

        self.notify('on_epoch_end', **on_epoch_kwargs)
    return self
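# ``fit_loop`` is normally driven indirectly through ``fit``. A minimal
# usage sketch in the spirit of the skorch README; the module definition
# and all hyperparameter values are illustrative, not part of the loop
# above.
import numpy as np
import torch.nn.functional as F
from torch import nn
from skorch import NeuralNetClassifier

class ToyModule(nn.Module):
    def __init__(self, num_units=10):
        super().__init__()
        self.dense = nn.Linear(20, num_units)
        self.output = nn.Linear(num_units, 2)

    def forward(self, X):
        X = F.relu(self.dense(X))
        return F.softmax(self.output(X), dim=-1)

X = np.random.randn(100, 20).astype(np.float32)
y = np.random.randint(0, 2, size=100).astype(np.int64)

net = NeuralNetClassifier(ToyModule, max_epochs=5, lr=0.1)
net.fit(X, y)                          # runs fit_loop internally
print(net.history[-1, 'train_loss'])   # per-batch records aggregated here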