Example #1
    def evaluate(self, verbose=-1, data_key=None):  # Note: kwargs appear unused but are inspected in inject_sampler
        """Evaluate this trial on the validation data.

        Args:
            verbose (int): If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
            data_key (StateKey): Optional :class:`.StateKey` for the data to evaluate on. Default: torchbearer.VALIDATION_DATA

        Returns:
            dict: The final metric values
        """
        state = State()
        state.update({
            torchbearer.MAX_EPOCHS: 1,
            torchbearer.EPOCH: 0,
            torchbearer.STOP_TRAINING: False
        })
        state.update(self.state)  # TODO: Hack to make injection work, should be removed if `self.state` is mutable

        if state[torchbearer.GENERATOR] is not None or state[torchbearer.STEPS] is not None:
            state[torchbearer.CALLBACK_LIST].on_start(state)
            state[torchbearer.CALLBACK_LIST].on_start_epoch(state)

            self.eval()
            state = self._test_pass(state)

            state[torchbearer.CALLBACK_LIST].on_end_epoch(state)

            if len(self.state[torchbearer.HISTORY]) != 0:
                self.state[torchbearer.HISTORY][-1][1].update(state[torchbearer.METRICS])

            state[torchbearer.CALLBACK_LIST].on_end(state)
            return state[torchbearer.METRICS]
        return {}
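
A minimal usage sketch for the evaluate method above. The toy model, loss, metric name and the fluent with_val_generator helper are assumptions about a typical torchbearer setup, not part of the excerpt:

import torch
from torch.utils.data import DataLoader, TensorDataset
import torchbearer
from torchbearer import Trial

# Toy model and validation loader purely for illustration
model = torch.nn.Linear(10, 2)
val_data = TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,)))
val_loader = DataLoader(val_data, batch_size=8)

trial = Trial(model, criterion=torch.nn.CrossEntropyLoss(), metrics=['loss'])
trial = trial.with_val_generator(val_loader)

# Defaults to torchbearer.VALIDATION_DATA; returns the final metric dict
val_metrics = trial.evaluate(verbose=1)

# With test data attached, data_key can redirect the evaluation:
# test_metrics = trial.evaluate(data_key=torchbearer.TEST_DATA)
print(val_metrics)
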
Example #2
    def replay(self, callbacks=[], verbose=2, one_batch=False):  # TODO: Should we track if testing passes have happened?
        """ Replay the fit passes stored in history with given callbacks, useful when reloading a saved Trial. Note that only progress and metric information is populated in state during a replay.

        Args:
            callbacks (list): List of callbacks to be run during the replay
            verbose (int): If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
            one_batch (bool): If True, only one batch per epoch is replayed. If False, all batches are replayed

        Returns:
            Trial: self
        """
        history = self.state[torchbearer.HISTORY]
        callbacks.append(get_printer(verbose=verbose, validation_label_letter='v'))
        callbacks = CallbackList(callbacks)

        state = State()
        state.update(self.state)
        state[torchbearer.STOP_TRAINING] = False
        state[torchbearer.MAX_EPOCHS] = len(history)

        callbacks.on_start(state)
        for i in range(len(history)):
            state[torchbearer.EPOCH] = i
            if not one_batch:
                state[torchbearer.TRAIN_STEPS], state[torchbearer.VALIDATION_STEPS] = history[i][0]
            else:
                state[torchbearer.TRAIN_STEPS], state[torchbearer.VALIDATION_STEPS] = 1, 1
            state[torchbearer.METRICS] = history[i][1]

            self._replay_pass(state, callbacks)
        callbacks.on_end(state)

        return self
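
A hedged sketch of how replay might be used after reloading a saved Trial. The checkpoint path, the placeholder model and the custom callback are illustrative; Trial.load_state_dict and the Callback base class are assumed from the torchbearer API:

import torch
import torchbearer
from torchbearer import Trial
from torchbearer.callbacks import Callback

class LogEpoch(Callback):
    """Print the epoch index and the metrics re-populated from history."""
    def on_end_epoch(self, state):
        print(state[torchbearer.EPOCH], state[torchbearer.METRICS])

model = torch.nn.Linear(10, 2)                 # placeholder model
trial = Trial(model)
trial.load_state_dict(torch.load('trial.pt'))  # hypothetical checkpoint file

# Re-emit progress and metrics for every recorded epoch; one_batch=True keeps it fast
trial.replay(callbacks=[LogEpoch()], verbose=1, one_batch=True)
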
Example #3
    def __init__(self, model, optimizer=None, criterion=None, metrics=[], callbacks=[], verbose=2):
        if criterion is None:
            def criterion(_, __):
                return torch.zeros(1, device=self.state[torchbearer.DEVICE], dtype=self.state[torchbearer.DATA_TYPE], requires_grad=True)

        self.verbose = verbose

        self.closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE, torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
        self.state = State()
        self.state.update({
            torchbearer.MODEL: model,
            torchbearer.CRITERION: criterion,
            torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
            torchbearer.METRIC_LIST: MetricList(metrics),
            torchbearer.CALLBACK_LIST: CallbackList(callbacks),
            torchbearer.DEVICE: 'cpu',
            torchbearer.DATA_TYPE: torch.float32,
            torchbearer.SELF: self,
            torchbearer.HISTORY: [],
            torchbearer.BACKWARD_ARGS: {},
            torchbearer.TRAIN_GENERATOR: None,
            torchbearer.VALIDATION_GENERATOR: None,
            torchbearer.TEST_GENERATOR: None,
            torchbearer.TRAIN_STEPS: None,
            torchbearer.VALIDATION_STEPS: None,
            torchbearer.TEST_STEPS: None,
            torchbearer.TRAIN_DATA: None,
            torchbearer.VALIDATION_DATA: None,
            torchbearer.TEST_DATA: None,
            torchbearer.INF_TRAIN_LOADING: False,
            torchbearer.LOADER: None
        })

        self.state[torchbearer.CALLBACK_LIST].on_init(self.state)
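
A short construction sketch based on the defaults visible in this __init__: with no optimizer or criterion, the Trial falls back to a MockOptimizer and a zero-valued loss. The model, optimiser and metric names here are illustrative only:

import torch
from torchbearer import Trial

model = torch.nn.Linear(10, 2)

# Fully specified trial: explicit optimiser, loss and metrics
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
trial = Trial(model, optimizer, torch.nn.CrossEntropyLoss(), metrics=['loss', 'acc'])

# Minimal trial: a MockOptimizer and a zero loss are substituted, so it can still
# be used for evaluation-only or callback-driven workflows
bare_trial = Trial(model)
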
Example #4
    def replay(self,
               callbacks=[],
               verbose=2
               ):  # TODO: Should we track if testing passes have happened?
        """ Replay the fit passes stored in history with given callbacks, useful when reloading a saved Trial. Note that only progress and metric information is populated in state during a replay.

        :param callbacks: List of callbacks to be run during the replay
        :type callbacks: list
        :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
        :type verbose: int
        :return: self
        :rtype: Trial
        """
        history = self.state[torchbearer.HISTORY]
        callbacks.append(
            get_printer(verbose=verbose, validation_label_letter='v'))
        callbacks = CallbackList(callbacks)

        state = State()
        state.update(self.state)
        state[torchbearer.STOP_TRAINING] = False
        state[torchbearer.MAX_EPOCHS] = len(history)

        callbacks.on_start(state)
        for i in range(len(history)):
            state[torchbearer.EPOCH] = i
            state[torchbearer.TRAIN_STEPS], state[
                torchbearer.VALIDATION_STEPS] = history[i][0]
            state[torchbearer.METRICS] = history[i][1]

            self._replay_pass(state, callbacks)
        callbacks.on_end(state)
Example #5
    def evaluate(
        self,
        verbose=2,
        data_key=None
    ):  # Note: kwargs appear unused but are inspected in inject_sampler
        """Evaluate this trial on the validation data.

        :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
        :type verbose: int
        :param data_key: Optional key for the data to evaluate on. Default: torchbearer.VALIDATION_DATA
        :type data_key: StateKey
        :return: The final metric values
        :rtype: dict
        """
        state = State()
        state.update({
            torchbearer.MAX_EPOCHS: 1,
            torchbearer.EPOCH: 0,
            torchbearer.STOP_TRAINING: False
        })
        state.update(
            self.state
        )  # TODO: Hack to make injection work, should be removed if `self.state` is mutable

        if state[torchbearer.GENERATOR] is not None or state[
                torchbearer.STEPS] is not None:
            self.eval()

            return self._test_pass(state)[torchbearer.METRICS]
        return {}
Example #6
    def evaluate(self, verbose=2):
        """Evaluate this trial on the validation data.

        :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
        :type verbose: int
        :return: The final metric values
        :rtype: dict
        """
        state = State()
        state.update({
            torchbearer.MAX_EPOCHS: 1,
            torchbearer.EPOCH: 0,
            torchbearer.STOP_TRAINING: False
        })

        state.update(self.state)  # TODO: Hack to make injection work, should be removed if `self.state` is mutable

        if state[torchbearer.VALIDATION_GENERATOR] is not None or state[torchbearer.VALIDATION_STEPS] is not None:
            self.eval()

            state[torchbearer.STEPS] = state[torchbearer.VALIDATION_STEPS]
            state[torchbearer.GENERATOR] = state[torchbearer.VALIDATION_GENERATOR]

            return self._test_pass(state)[torchbearer.METRICS]
        return {}
Example #7
    def __init__(self, model, optimizer=None, criterion=None, metrics=[], callbacks=[], pass_state=False):
        if criterion is None:
            def criterion(_, y_true):
                return torch.zeros(1, device=y_true.device)

        self.pass_state = pass_state

        self.state = State()
        self.state.update({
            torchbearer.MODEL: model,
            torchbearer.CRITERION: criterion,
            torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
            torchbearer.METRIC_LIST: MetricList(metrics),
            torchbearer.CALLBACK_LIST: CallbackList(callbacks),
            torchbearer.DEVICE: 'cpu',
            torchbearer.DATA_TYPE: torch.float32,
            torchbearer.SELF: self,
            torchbearer.HISTORY: [],
            torchbearer.BACKWARD_ARGS: {},
            torchbearer.TRAIN_GENERATOR: None,
            torchbearer.VALIDATION_GENERATOR: None,
            torchbearer.TEST_GENERATOR: None,
            torchbearer.TRAIN_STEPS: None,
            torchbearer.VALIDATION_STEPS: None,
            torchbearer.TEST_STEPS: None
        })
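
A hedged sketch of how the pass_state flag in this older constructor is typically used: when True, the model's forward is assumed to receive the state dictionary as a keyword argument. torchbearer.state_key and the state= keyword are assumptions about that API, and the model itself is a toy:

import torch
import torchbearer
from torchbearer import Trial

HIDDEN = torchbearer.state_key('hidden')  # custom key for an intermediate value

class StateAwareModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 8)
        self.fc2 = torch.nn.Linear(8, 2)

    def forward(self, x, state=None):
        h = torch.relu(self.fc1(x))
        if state is not None:
            state[HIDDEN] = h  # visible to callbacks and metrics during the pass
        return self.fc2(h)

model = StateAwareModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
trial = Trial(model, optimizer, torch.nn.CrossEntropyLoss(), pass_state=True)
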
Example #8
    def run(self, epochs=1, verbose=-1):
        r"""Run this trial for the given number of epochs, starting from the last trained epoch.

        Args:
            epochs (int, optional): The number of epochs to run for
            verbose (int, optional): If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training
            progress, If -1: Automatic

        State Requirements:
            - :attr:`torchbearer.state.MODEL`: Model should be callable and not none, set on Trial init

        Returns:
            list: The model history (list of tuple of steps summary and epoch metric dicts)
        """
        state = State()
        state.update({
            torchbearer.MAX_EPOCHS: epochs,
            torchbearer.STOP_TRAINING: False,
        })

        state.update(self.state)  # TODO: Swap this for something which makes `self.state` still mutable

        if state[torchbearer.MODEL] is None or not callable(state[torchbearer.MODEL]):
            warnings.warn('The Model is None or not callable which may cause issues if not deliberate')
            state[torchbearer.MODEL] = lambda *args, **kwargs: None

        if state[torchbearer.TRAIN_GENERATOR] is not None \
                or state[torchbearer.TRAIN_STEPS] is not None \
                or state[torchbearer.VALIDATION_GENERATOR] is not None \
                or state[torchbearer.VALIDATION_STEPS] is not None:

            state[torchbearer.CALLBACK_LIST].on_start(state)

            for state[torchbearer.EPOCH] in range(len(state[torchbearer.HISTORY]), state[torchbearer.MAX_EPOCHS]):
                state[torchbearer.CALLBACK_LIST].on_start_epoch(state)

                final_metrics = self._fit_pass(state)[torchbearer.METRICS]

                if state[torchbearer.STOP_TRAINING]:
                    break

                final_metrics.update(self._validation_pass(state))
                state[torchbearer.METRICS] = final_metrics
                state[torchbearer.CALLBACK_LIST].on_end_epoch(state)
                steps_summary = (state[torchbearer.TRAIN_STEPS], state[torchbearer.VALIDATION_STEPS])
                self.state[torchbearer.HISTORY].append((steps_summary, state[torchbearer.METRICS]))
                state[torchbearer.CALLBACK_LIST].on_checkpoint(state)

                if state[torchbearer.STOP_TRAINING]:
                    break

            state[torchbearer.CALLBACK_LIST].on_end(state)

        return self.state[torchbearer.HISTORY]
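
A usage sketch for run, assuming the fluent with_train_generator/with_val_generator helpers and toy data; the returned history is the list of (steps summary, metrics) tuples appended in the loop above:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchbearer import Trial

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=8)
val_loader = DataLoader(TensorDataset(torch.randn(16, 10), torch.randint(0, 2, (16,))), batch_size=8)

trial = Trial(model, optimizer, torch.nn.CrossEntropyLoss(), metrics=['loss', 'acc'])
trial = trial.with_train_generator(train_loader).with_val_generator(val_loader)

history = trial.run(epochs=3, verbose=1)
# One entry per epoch: ((train_steps, validation_steps), {epoch metric dict})
print(history[-1])
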
Example #9
    def run(self, epochs=1, verbose=2):
        """Run this trial for the given number of epochs, starting from the last trained epoch.

        :param epochs: The number of epochs to run for
        :type epochs: int
        :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
        :type verbose: int
        :return: The model history (list of tuples of steps summary and epoch metric dicts)
        :rtype: list
        """
        state = State()
        state.update({
            torchbearer.MAX_EPOCHS: epochs,
            torchbearer.STOP_TRAINING: False
        })

        state.update(
            self.state
        )  # TODO: Swap this for something which makes `self.state` still mutable

        state[torchbearer.CALLBACK_LIST].on_start(state)

        for state[torchbearer.EPOCH] in range(len(state[torchbearer.HISTORY]),
                                              state[torchbearer.MAX_EPOCHS]):
            state[torchbearer.CALLBACK_LIST].on_start_epoch(state)

            final_metrics = self._fit_pass(state)[torchbearer.METRICS]

            if state[torchbearer.STOP_TRAINING]:
                break

            final_metrics.update(self._validation_pass(state))
            state[torchbearer.METRICS] = final_metrics
            state[torchbearer.CALLBACK_LIST].on_end_epoch(state)
            steps_summary = (state[torchbearer.TRAIN_STEPS],
                             state[torchbearer.VALIDATION_STEPS])
            self.state[torchbearer.HISTORY].append(
                (steps_summary, state[torchbearer.METRICS]))

            if state[torchbearer.STOP_TRAINING]:
                break

        state[torchbearer.CALLBACK_LIST].on_end(state)

        return self.state[torchbearer.HISTORY]