コード例 #1
0
 def test_subset_uses_placeholder_y(self, dataset_cls, data,
                                    uses_placeholder_y, cv_split_cls):
     X, _ = data
     ds = dataset_cls(X, y=None)
     ds_train, ds_valid = cv_split_cls(cv=2)(ds)
     assert uses_placeholder_y(ds_train)
     assert uses_placeholder_y(ds_valid)
コード例 #2
0
 def test_subset_dataset_uses_non_y_placeholder(self, dataset_cls, data,
                                                uses_placeholder_y,
                                                cv_split_cls):
     X, y = data
     ds = dataset_cls(X, y)
     ds_train, ds_valid = cv_split_cls(cv=2)(ds)
     assert not uses_placeholder_y(ds_train)
     assert not uses_placeholder_y(ds_valid)
コード例 #3
0
 def test_subset_of_subset_uses_non_placeholder_y(self, dataset_cls, data,
                                                  uses_placeholder_y,
                                                  cv_split_cls):
     X, y = data
     ds = dataset_cls(X, y)
     ds_split, _ = cv_split_cls(cv=4)(ds)
     ds_train, ds_valid = cv_split_cls(cv=3)(ds_split)
     assert not uses_placeholder_y(ds_train)
     assert not uses_placeholder_y(ds_valid)
コード例 #4
0
    def run_single_epoch(self, dataset, training, prefix, step_fn,
                         **fit_params):
        """Run one training or validation pass over ``dataset``.

        Parameters
        ----------
        dataset : torch Dataset
            The initialized dataset whose batches are iterated via
            ``self.get_iterator``.

        training : bool
            Forwarded to ``self.get_iterator`` and to the
            ``on_batch_begin``/``on_batch_end`` callbacks.

        prefix : str
            Prefix for the history keys written here
            (``<prefix>_distance``, ``<prefix>_batch_size`` and
            ``<prefix>_batch_count``).

        step_fn : callable
            Called once per batch as
            ``step_fn(Xi, yi, train_generator=flag, **fit_params)``. It must
            return a mapping with a ``"distance"`` entry that supports
            ``.item()``. ``flag`` is True for batch indices (0-based) that
            are multiples of ``self.train_generator_every``.

        **fit_params : dict
            Additional keyword arguments passed to ``step_fn``.
        """
        y_is_placeholder = uses_placeholder_y(dataset)
        batches = self.get_iterator(dataset, training=training)

        # Stays -1 when the iterator yields nothing, so the recorded batch
        # count is 0 in that case.
        batch_idx = -1
        for batch_idx, batch in enumerate(batches):
            Xi, yi = unpack_data(batch)
            # Callbacks receive None instead of a placeholder target.
            y_for_callbacks = None if y_is_placeholder else yi
            self.notify("on_batch_begin", X=Xi, y=y_for_callbacks,
                        training=training)

            is_generator_batch = batch_idx % self.train_generator_every == 0
            step = step_fn(Xi, yi, train_generator=is_generator_batch,
                           **fit_params)

            distance = step["distance"].item()
            self.history.record_batch(prefix + "_distance", distance)
            self.history.record_batch(prefix + "_batch_size", get_len(Xi))
            self.notify("on_batch_end", X=Xi, y=y_for_callbacks,
                        training=training, **step)

        self.history.record(prefix + "_batch_count", batch_idx + 1)
コード例 #5
0
    def run_single_epoch(self, dataset, training, prefix, step_fn,
                         **fit_params):
        """Run one training or validation pass over ``dataset``.

        Parameters
        ----------
        dataset : torch Dataset
            The initialized dataset to loop over. Each batch yielded by
            ``self.get_iterator`` is an ``(inputs, dummy_target)`` pair. The
            real target sits at position 1 of ``inputs['query']``.

        training : bool
            Forwarded to ``self.get_iterator`` and to the
            ``on_batch_begin``/``on_batch_end`` callbacks.

        prefix : str
            Prefix for the history keys written here
            (``<prefix>_loss``, ``<prefix>_batch_size`` and
            ``<prefix>_batch_count``).

        step_fn : callable
            Called once per batch as ``step_fn(inputs, target, **fit_params)``.
            It must return a mapping with a ``"loss"`` entry that supports
            ``.item()``.

        **fit_params : dict
            Additional keyword arguments passed to ``step_fn``.
        """
        # NOTE(review): this checks the dataset's (dummy) y, not the target
        # popped from inputs['query'] below. Confirm that is intended.
        y_is_placeholder = uses_placeholder_y(dataset)

        n_batches = 0
        for batch in self.get_iterator(dataset, training=training):
            inputs, _ = batch  # discard the dummy target

            # Pop the real target out of inputs['query'] (mutating it in
            # place) so it is not available to the model as an input.
            target = inputs['query'].pop(1)
            target_for_callbacks = None if y_is_placeholder else target

            self.notify("on_batch_begin", X=inputs, y=target_for_callbacks,
                        training=training)
            step = step_fn(inputs, target, **fit_params)

            loss = step["loss"].item()
            self.history.record_batch(prefix + "_loss", loss)
            self.history.record_batch(prefix + "_batch_size", get_len(inputs))
            self.notify("on_batch_end", X=inputs, y=target_for_callbacks,
                        training=training, **step)
            n_batches += 1

        self.history.record(prefix + "_batch_count", n_batches)
コード例 #6
0
    def run_single_epoch(self, dataset, training, prefix, step_fn,
                         **fit_params):
        """Run one training or validation pass over ``dataset``.

        Parameters
        ----------
        dataset : torch Dataset
            The initialized dataset whose batches are iterated via
            ``self.get_iterator``.

        training : bool
            Forwarded to ``self.get_iterator`` and to the
            ``on_batch_begin``/``on_batch_end`` callbacks.

        prefix : str
            Prefix for the history keys written here
            (``<prefix>_loss``, ``<prefix>_batch_size`` and
            ``<prefix>_batch_count``).

        step_fn : callable
            Called once per batch as ``step_fn(Xi, yi, **fit_params)``.
            It must return a mapping with a ``"loss"`` entry that supports
            ``.item()``.

        **fit_params : dict
            Additional keyword arguments passed to ``step_fn``.
        """
        y_is_placeholder = uses_placeholder_y(dataset)
        batch_iter = self.get_iterator(dataset, training=training)

        n_batches = 0
        for batch in batch_iter:
            Xi, yi = unpack_data(batch)
            # Callbacks receive None instead of a placeholder target.
            y_for_callbacks = None if y_is_placeholder else yi
            self.notify("on_batch_begin", X=Xi, y=y_for_callbacks,
                        training=training)
            step = step_fn(Xi, yi, **fit_params)

            loss = step["loss"].item()
            self.history.record_batch(prefix + "_loss", loss)
            # The batch size is measured on Xi["nodes"], not on Xi itself.
            nodes_len = get_len(Xi["nodes"])
            self.history.record_batch(prefix + "_batch_size", nodes_len)
            self.notify("on_batch_end", X=Xi, y=y_for_callbacks,
                        training=training, **step)
            n_batches += 1
        self.history.record(prefix + "_batch_count", n_batches)
コード例 #7
0
 def test_custom_dataset_uses_non_y_placeholder(self, custom_dataset_cls,
                                                uses_placeholder_y):
     ds = custom_dataset_cls()
     assert not uses_placeholder_y(ds)
コード例 #8
0
 def test_dataset_uses_non_y_placeholder(self, dataset_cls, data,
                                         uses_placeholder_y):
     X, y = data
     ds = dataset_cls(X, y)
     assert not uses_placeholder_y(ds)
コード例 #9
0
 def test_dataset_uses_y_placeholder(self, dataset_cls, data,
                                     uses_placeholder_y):
     X, _ = data
     ds = dataset_cls(X, y=None)
     assert uses_placeholder_y(ds)
コード例 #10
0
    def fit_loop(self, X, y=None, epochs=None, **fit_params):
        """The proper fit loop.

        Contains the logic of what actually happens during the fit
        loop. Each epoch makes one training pass over the training split
        and, when a validation split exists, one validation pass, with
        epoch and batch callbacks fired around them.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
          By default, you should be able to pass:

            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

          If this doesn't work with your data, you have to pass a
          ``Dataset`` that can deal with the data.

        y : target data, compatible with skorch.dataset.Dataset
          The same data types as for ``X`` are supported. If your X is
          a Dataset that contains the target, ``y`` may be set to
          None.

        epochs : int or None (default=None)
          If int, train for this number of epochs; if None, use
          ``self.max_epochs``.

        **fit_params : dict
          Additional parameters passed to the ``forward`` method of
          the module and to the ``self.train_split`` call.

        Returns
        -------
        self

        """
        self.check_data(X, y)
        epochs = epochs if epochs is not None else self.max_epochs

        dataset_train, dataset_valid = self.get_split_datasets(
            X, y, **fit_params)
        on_epoch_kwargs = {
            'dataset_train': dataset_train,
            'dataset_valid': dataset_valid,
        }

        y_train_is_ph = uses_placeholder_y(dataset_train)
        y_valid_is_ph = uses_placeholder_y(dataset_valid)

        def run_batches(dataset, training, y_is_ph, prefix):
            # One pass over ``dataset``. This is shared by the training and
            # validation phases, which differ only in these arguments.
            for Xi, yi in self.get_iterator(dataset, training=training):
                # Callbacks receive None instead of a placeholder target.
                yi_res = yi if not y_is_ph else None
                self.notify('on_batch_begin', X=Xi, y=yi_res,
                            training=training)
                # The step method is resolved per batch, as the original
                # inline loops did.
                step_fn = (self.train_step if training
                           else self.validation_step)
                step = step_fn(Xi, yi, **fit_params)
                self.history.record_batch(prefix + '_loss',
                                          step['loss'].item())
                self.history.record_batch(prefix + '_batch_size',
                                          get_len(Xi))
                self.notify('on_batch_end',
                            X=Xi,
                            y=yi_res,
                            training=training,
                            **step)

        for _ in range(epochs):
            self.notify('on_epoch_begin', **on_epoch_kwargs)
            run_batches(dataset_train, True, y_train_is_ph, 'train')
            # Without a validation split, the epoch ends after training.
            if dataset_valid is not None:
                run_batches(dataset_valid, False, y_valid_is_ph, 'valid')
            self.notify('on_epoch_end', **on_epoch_kwargs)
        return self