示例#1
0
    def test_history_save_load_cycle_file_path(self, history, tmpdir):
        """A History written to a path and read back from it compares equal."""
        target = str(tmpdir.mkdir('skorch').join('history.json'))

        history.to_file(target)

        assert history == History.from_file(target)
示例#2
0
    def test_history_save_load_cycle_file_obj(self, history, tmpdir):
        """Same round trip as the path-based test, but via open file handles."""
        target = str(tmpdir.mkdir('skorch').join('history.json'))

        # Write through a text-mode handle opened by the caller ...
        with open(target, 'w') as writer:
            history.to_file(writer)

        # ... and read it back through a separate handle.
        with open(target, 'r') as reader:
            restored = History.from_file(reader)

        assert history == restored
示例#3
0
    def load_params(self,
                    f_params=None,
                    f_optimizer=None,
                    f_history=None,
                    checkpoint=None):
        """Load the module's and critic's parameters, the history, and the
        optimizers' state -- not the whole object.

        To save and load the whole object, use pickle.

        ``f_params`` and ``f_optimizer`` are read with PyTorch's
        :func:`~torch.load` (the counterpart of :func:`~torch.save`). Next
        to each of them, a companion file whose name is the given path with
        ``'_critic'`` appended is loaded into ``critic_`` and
        ``critic_optimizer_`` respectively. Because that companion name is
        derived from the given path, ``f_params`` and ``f_optimizer`` must be
        paths (``str``, ``bytes`` or :class:`os.PathLike`).

        Parameters
        ----------
        f_params : str, bytes, os.PathLike, None (default=None)
          Path of module parameters; the critic's parameters are read from
          ``f_params + '_critic'``. Pass ``None`` to not load.

        f_optimizer : str, bytes, os.PathLike, None (default=None)
          Path of optimizer state; the critic optimizer's state is read
          from ``f_optimizer + '_critic'``. Pass ``None`` to not load.

        f_history : file-like object, str, None (default=None)
          Path to history. Pass ``None`` to not load.

        checkpoint : :class:`.Checkpoint`, None (default=None)
          Checkpoint to load params from. If a checkpoint and a ``f_*``
          path is passed in, the ``f_*`` will be loaded. Pass
          ``None`` to not load.

        Raises
        ------
        TypeError
          If ``f_params`` or ``f_optimizer`` (after resolving the
          checkpoint) is not a path, since the ``'_critic'`` companion
          file name cannot be derived from it.
        """
        import os

        def _get_state_dict(f):
            map_location = get_map_location(self.device)
            self.device = self._check_device(self.device, map_location)
            return torch.load(f, map_location=map_location)

        def _critic_path(f, name):
            # The critic's state lives in a sibling file named "<f>_critic".
            # Resolve it before anything is read so a bad argument fails fast
            # instead of after the main state dict has been loaded.
            if isinstance(f, (str, bytes, os.PathLike)):
                return os.fsdecode(f) + '_critic'
            raise TypeError(
                "{} must be a path (str, bytes or os.PathLike) so that the "
                "critic's file '<{}>_critic' can be located; got {}.".format(
                    name, name, type(f).__name__))

        if checkpoint is not None and not self.initialized_:
            self.initialize()

        # History is loaded only after the possible initialize() above, in
        # case initialization resets it, and before get_formatted_files()
        # below, which may format file names from the (loaded) history.
        if f_history is not None:
            self.history = History.from_file(f_history)
        elif checkpoint is not None and checkpoint.f_history is not None:
            self.history = History.from_file(checkpoint.f_history_)

        if checkpoint is not None:
            formatted_files = checkpoint.get_formatted_files(self)
            f_params = f_params or formatted_files['f_params']
            f_optimizer = f_optimizer or formatted_files['f_optimizer']

        if f_params is not None:
            msg = ("Cannot load parameters of an un-initialized model. "
                   "Please initialize first by calling .initialize() "
                   "or by fitting the model with .fit(...).")
            self.check_is_fitted(msg=msg)
            f_params_critic = _critic_path(f_params, 'f_params')
            # Read both state dicts before applying either, so a missing
            # critic file leaves the module untouched.
            state_dict = _get_state_dict(f_params)
            state_dict_critic = _get_state_dict(f_params_critic)
            self.module_.load_state_dict(state_dict)
            self.critic_.load_state_dict(state_dict_critic)

        if f_optimizer is not None:
            msg = ("Cannot load state of an un-initialized optimizer. "
                   "Please initialize first by calling .initialize() "
                   "or by fitting the model with .fit(...).")
            self.check_is_fitted(attributes=['optimizer_'], msg=msg)
            f_optimizer_critic = _critic_path(f_optimizer, 'f_optimizer')
            state_dict = _get_state_dict(f_optimizer)
            state_dict_critic = _get_state_dict(f_optimizer_critic)
            self.optimizer_.load_state_dict(state_dict)
            self.critic_optimizer_.load_state_dict(state_dict_critic)
示例#4
0
                                                5,
                                                new_train_set,
                                                n_perturbations,
                                                batch_size=200)
    amp_pert_mdiff = perturbation_correlation(amp_perturbation_additive,
                                              mean_diff_feature_maps,
                                              pred_fn,
                                              5,
                                              new_train_set,
                                              n_perturbations,
                                              batch_size=200)

    freqs = np.fft.rfftfreq(new_train_set.shape[2], d=1 / 250.)
    history = History()
    history = history.from_file(
        home +
        '/logs/model_1_lr_0.001/histories/history_{last_epoch[epoch]}.json')
    plot_history(history, None)

    correlation_monitor = CorrelationMonitor1D(
        input_time_length=input_time_length, setname='idk')

    all_preds = []
    all_targets = []
    dataset = test_set
    for X, y in zip(train_set.X, train_set.y):
        preds = model(np_to_var(X).double())
        all_preds.append(preds)
        all_targets.append(y)

    preds_2d = [p[:, None] for p in all_preds]