コード例 #1
0
    def fit_loop(self, X, y=None, epochs=None, *, n_folds=10, **fit_params):
        """Fit with a rotating K-fold split: one fold validates per epoch.

        The dataset built from ``X`` and ``y`` is split once into
        ``n_folds`` folds of indices by ``self.split_k_fold``. In epoch
        ``e`` the fold ``e % n_folds`` is used as the validation set and all
        other folds together form the training set, so consecutive epochs
        cycle through the folds.

        Parameters
        ----------
        X : input data
            Passed to ``self.get_dataset`` together with ``y``.

        y : target data or None (default=None)
            Passed to ``self.get_dataset`` together with ``X``.

        epochs : int or None (default=None)
            Number of epochs; if None, ``self.max_epochs`` is used.

        n_folds : int (default=10), keyword-only
            Number of folds to split the dataset into. The default of 10
            matches the fold count this loop always used before.

        **fit_params : dict
            Forwarded to ``self.train_step`` and ``self.validation_step``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``n_folds`` is smaller than 2, since no fold would be left
            over for training.
        """
        if n_folds < 2:
            raise ValueError(
                "n_folds must be at least 2, got {}.".format(n_folds))
        epochs = epochs if epochs is not None else self.max_epochs

        # Split the dataset indices into K folds once, before training.
        dataset = self.get_dataset(X, y)
        fold_indices = self.split_k_fold(n_folds, len(dataset))

        for epoch in range(epochs):
            # Rotate the validation fold; every other fold is training data.
            valid_fold_idx = epoch % n_folds
            idx_train = np.array(
                [idx
                 for fold_idx, fold in enumerate(fold_indices)
                 if fold_idx != valid_fold_idx
                 for idx in fold],
                dtype=int,
            )
            idx_valid = np.array(fold_indices[valid_fold_idx], dtype=int)

            dataset_train = torch.utils.data.Subset(dataset, idx_train)
            dataset_valid = torch.utils.data.Subset(dataset, idx_valid)
            on_epoch_kwargs = {
                'dataset_train': dataset_train,
                'dataset_valid': dataset_valid,
            }

            self.notify('on_epoch_begin', **on_epoch_kwargs)
            self._run_fold_phase(dataset_train, True, 'train',
                                 self.train_step, fit_params)
            self._run_fold_phase(dataset_valid, False, 'valid',
                                 self.validation_step, fit_params)
            self.notify('on_epoch_end', **on_epoch_kwargs)
        return self

    def _run_fold_phase(self, dataset, training, prefix, step_fn,
                        fit_params):
        """Run one pass over ``dataset`` and record it in the history.

        Fires ``on_batch_begin``/``on_batch_end`` around every call of
        ``step_fn``, records ``<prefix>_loss`` and ``<prefix>_batch_size``
        per batch and ``<prefix>_batch_count`` once the pass is done.

        ``fit_params`` is taken as a plain dict (not ``**kwargs``) so that
        user fit parameters can never collide with this helper's own
        argument names.
        """
        batch_count = 0
        for xi, yi in self.get_iterator(dataset, training=training):
            self.notify('on_batch_begin', X=xi, y=yi, training=training)
            step = step_fn(xi, yi, **fit_params)
            self.history.record_batch(prefix + '_loss', step['loss'].item())
            self.history.record_batch(prefix + '_batch_size', get_len(xi))
            self.notify('on_batch_end', X=xi, y=yi, training=training,
                        **step)
            batch_count += 1
        self.history.record(prefix + '_batch_count', batch_count)
コード例 #2
0
    def run_single_epoch(self, dataset, training, prefix, step_fn,
                         **fit_params):
        """Run one pass of training or validation over ``dataset``.

        For every batch yielded by ``self.get_iterator`` this fires the
        ``on_batch_begin``/``on_batch_end`` callbacks, calls ``step_fn`` and
        records ``<prefix>_distance`` and ``<prefix>_batch_size`` in the
        batch history. ``step_fn`` receives ``train_generator=True`` for
        batch indices that are multiples of ``self.train_generator_every``
        (0, N, 2N, ...) and ``False`` otherwise. Once the loop finishes,
        ``<prefix>_batch_count`` is recorded.

        Parameters
        ----------
        dataset : torch Dataset
            The initialized dataset to loop over.

        training : bool
            Whether this is a training pass; forwarded to
            ``self.get_iterator`` and to the batch callbacks.

        prefix : str
            Prefix used for the history keys that are written.

        step_fn : callable
            Called once per batch; its result must provide a ``"distance"``
            entry.

        **fit_params : dict
            Extra keyword arguments forwarded to ``step_fn``.
        """
        y_is_placeholder = uses_placeholder_y(dataset)
        batches = self.get_iterator(dataset, training=training)

        n_batches = 0
        for batch_idx, batch in enumerate(batches):
            X_batch, y_batch = unpack_data(batch)
            # Callbacks get ``None`` instead of a placeholder target.
            y_for_callbacks = None if y_is_placeholder else y_batch
            self.notify("on_batch_begin", X=X_batch, y=y_for_callbacks,
                        training=training)

            update_generator = batch_idx % self.train_generator_every == 0
            step = step_fn(X_batch, y_batch,
                           train_generator=update_generator, **fit_params)

            distance = step["distance"].item()
            self.history.record_batch(prefix + "_distance", distance)
            self.history.record_batch(prefix + "_batch_size",
                                      get_len(X_batch))
            self.notify("on_batch_end", X=X_batch, y=y_for_callbacks,
                        training=training, **step)
            n_batches += 1

        self.history.record(prefix + "_batch_count", n_batches)
コード例 #3
0
    def run_single_epoch(self, dataset, training, prefix, step_fn,
                         **fit_params):
        """Run one pass of training or validation over ``dataset``.

        Each batch from ``self.get_iterator`` is unpacked as
        ``(inputs, dummy_target)``; the dummy target is discarded. The real
        target is popped (in place) from position 1 of
        ``inputs['query']``, so neither ``step_fn`` nor the callbacks see it
        as part of ``X``.

        Parameters
        ----------
        dataset : torch Dataset
            The initialized dataset to loop over.

        training : bool
            Whether this is a training pass; forwarded to
            ``self.get_iterator`` and to the batch callbacks.

        prefix : str
            Prefix of the history keys written (``<prefix>_loss``,
            ``<prefix>_batch_size``, ``<prefix>_batch_count``).

        step_fn : callable
            Called as ``step_fn(inputs, target, **fit_params)`` per batch;
            its result must provide a ``"loss"`` entry.

        **fit_params : dict
            Extra keyword arguments forwarded to ``step_fn``.
        """
        y_is_placeholder = uses_placeholder_y(dataset)

        n_batches = 0
        for batch in self.get_iterator(dataset, training=training):
            # The second element is a dummy target and is thrown away.
            inputs, _dummy_target = batch

            # Take the query target out of the inputs so the model cannot
            # peek at it. NOTE(review): assumes inputs['query'] is a mutable
            # sequence such as a list — confirm against the dataset.
            target = inputs['query'].pop(1)
            target_for_callbacks = None if y_is_placeholder else target

            self.notify("on_batch_begin", X=inputs, y=target_for_callbacks,
                        training=training)
            step = step_fn(inputs, target, **fit_params)
            self.history.record_batch(prefix + "_loss", step["loss"].item())
            self.history.record_batch(prefix + "_batch_size",
                                      get_len(inputs))
            self.notify("on_batch_end", X=inputs, y=target_for_callbacks,
                        training=training, **step)
            n_batches += 1

        self.history.record(prefix + "_batch_count", n_batches)
コード例 #4
0
    def run_single_epoch(self, dataset, training, prefix, step_fn,
                         **fit_params):
        """Run one training or validation pass over ``dataset``.

        Parameters
        ----------
        dataset : torch Dataset
            The initialized dataset to loop over.

        training : bool
            Whether this is a training pass; forwarded to
            ``self.get_iterator`` and to the batch callbacks.

        prefix : str
            Prefix of the history keys written (``<prefix>_loss``,
            ``<prefix>_batch_size``, ``<prefix>_batch_count``).

        step_fn : callable
            Called as ``step_fn(Xi, yi, **fit_params)`` for each batch; its
            result must provide a ``"loss"`` entry.

        **fit_params : dict
            Extra keyword arguments forwarded to ``step_fn``.

        Notes
        -----
        The batch size is measured on ``Xi["nodes"]`` rather than on ``Xi``
        itself, so ``Xi`` must be indexable by ``"nodes"`` (presumably a
        dict of graph inputs — confirm against the dataset).
        """
        y_is_placeholder = uses_placeholder_y(dataset)

        n_batches = 0
        for batch in self.get_iterator(dataset, training=training):
            X_batch, y_batch = unpack_data(batch)
            # Callbacks get ``None`` instead of a placeholder target.
            y_for_callbacks = None if y_is_placeholder else y_batch
            self.notify("on_batch_begin", X=X_batch, y=y_for_callbacks,
                        training=training)

            step = step_fn(X_batch, y_batch, **fit_params)
            loss_value = step["loss"].item()
            self.history.record_batch(prefix + "_loss", loss_value)
            n_samples = get_len(X_batch["nodes"])
            self.history.record_batch(prefix + "_batch_size", n_samples)

            self.notify("on_batch_end", X=X_batch, y=y_for_callbacks,
                        training=training, **step)
            n_batches += 1
        self.history.record(prefix + "_batch_count", n_batches)
コード例 #5
0
 def _get_batches_per_epoch_phase(self, net, dataset, training):
     if dataset is None:
         return 0
     batch_size = self._get_batch_size(net, training)
     return int(np.ceil(get_len(dataset) / batch_size))
コード例 #6
0
 def test_inconsistent_lengths(self, get_len, data):
     """``get_len`` must raise ``ValueError`` for inconsistently sized data."""
     pytest.raises(ValueError, get_len, data)
コード例 #7
0
 def test_valid_lengths(self, get_len, data, expected):
     length = get_len(data)
     assert length == expected
コード例 #8
0
ファイル: logging.py プロジェクト: YangHaha11514/skorch
 def _get_batches_per_epoch_phase(self, net, X, training):
     if X is None:
         return 0
     batch_size = self._get_batch_size(net, training)
     return int(np.ceil(get_len(X) / batch_size))
コード例 #9
0
ファイル: test_dataset.py プロジェクト: YangHaha11514/skorch
 def test_inconsistent_lengths(self, get_len, data):
     """``get_len`` must raise ``ValueError`` for inconsistently sized data."""
     pytest.raises(ValueError, get_len, data)
コード例 #10
0
ファイル: test_dataset.py プロジェクト: YangHaha11514/skorch
 def test_valid_lengths(self, get_len, data, expected):
     length = get_len(data)
     assert length == expected
コード例 #11
0
ファイル: net.py プロジェクト: rain1024/skorch
    def fit_loop(self, X, y=None, epochs=None, **fit_params):
        """The proper fit loop.

        Contains the logic of what actually happens during the fit
        loop: optionally split off a validation set with
        ``self.train_split``, then for each epoch run all training batches
        and, if a validation set exists, all validation batches, recording
        per-batch losses and batch sizes in ``self.history``.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
          By default, you should be able to pass:

            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * a dictionary of the former three
            * a list/tuple of the former three

          If this doesn't work with your data, you have to pass a
          ``Dataset`` that can deal with the data.

        y : target data, compatible with skorch.dataset.Dataset
          The same data types as for ``X`` are supported.

        epochs : int or None (default=None)
          If int, train for this number of epochs; if None, use
          ``self.max_epochs``.

        **fit_params : dict
          Additional parameters passed to the ``forward`` method of
          the module and to the train_split call.

        Returns
        -------
        self

        """
        self.check_data(X, y)
        epochs = epochs if epochs is not None else self.max_epochs

        if self.train_split:
            X_train, X_valid, y_train, y_valid = self.train_split(
                X, y, **fit_params)
            dataset_valid = self.get_dataset(X_valid, y_valid)
        else:
            X_train, X_valid, y_train, y_valid = X, None, y, None
            dataset_valid = None
        dataset_train = self.get_dataset(X_train, y_train)

        on_epoch_kwargs = {
            'X': X_train,
            'X_valid': X_valid,
            'y': y_train,
            'y_valid': y_valid,
        }

        for _ in range(epochs):
            self.notify('on_epoch_begin', **on_epoch_kwargs)

            for Xi, yi in self.get_iterator(dataset_train, training=True):
                self.notify('on_batch_begin', X=Xi, y=yi, training=True)
                step = self.train_step(Xi, yi, **fit_params)
                # ``.item()`` converts the loss to a Python number; indexing
                # a 0-dim loss tensor with ``.data[0]`` raises IndexError in
                # current PyTorch.
                self.history.record_batch('train_loss', step['loss'].item())
                self.history.record_batch('train_batch_size', get_len(Xi))
                self.notify('on_batch_end', X=Xi, y=yi, training=True, **step)

            # Without a validation split the epoch ends after training.
            if X_valid is None:
                self.notify('on_epoch_end', **on_epoch_kwargs)
                continue

            for Xi, yi in self.get_iterator(dataset_valid, training=False):
                self.notify('on_batch_begin', X=Xi, y=yi, training=False)
                step = self.validation_step(Xi, yi, **fit_params)
                self.history.record_batch('valid_loss', step['loss'].item())
                self.history.record_batch('valid_batch_size', get_len(Xi))
                self.notify('on_batch_end', X=Xi, y=yi, training=False, **step)

            self.notify('on_epoch_end', **on_epoch_kwargs)
        return self