Esempio n. 1
0
    def test_aggreate_predictions(self):
        """AggregatePredictions stores each validation step's Y_PRED and exposes their concatenation at the end."""
        aggregator = AggregatePredictions()

        y_pred_1 = torch.Tensor([1, 2, 3])
        y_pred_2 = torch.Tensor([3, 4, 5])

        state_1 = {torchbearer.Y_PRED: y_pred_1}
        state_2 = {torchbearer.Y_PRED: y_pred_2}
        final_state = {}

        # Each validation step should append that step's prediction, in order
        aggregator.on_step_validation(state_1)
        self.assertEqual(aggregator.predictions_list[0].tolist(), y_pred_1.tolist())

        aggregator.on_step_validation(state_2)
        self.assertEqual(aggregator.predictions_list[1].tolist(), y_pred_2.tolist())

        # At the end of validation FINAL_PREDICTIONS should equal the concatenated step predictions
        aggregate = torch.cat([y_pred_1, y_pred_2])
        aggregator.on_end_validation(final_state)
        self.assertEqual(final_state[torchbearer.FINAL_PREDICTIONS].tolist(), aggregate.tolist())

    def test_aggreate_predictions_multiple_calls(self):
        """on_end_epoch clears the predictions collected during the epoch's validation steps."""
        aggregator = AggregatePredictions()

        y_pred_1 = torch.Tensor([1, 2, 3])
        y_pred_2 = torch.Tensor([3, 4, 5])

        state_1 = {torchbearer.Y_PRED: y_pred_1}
        state_2 = {torchbearer.Y_PRED: y_pred_2}

        aggregator.on_step_validation(state_1)
        self.assertEqual(aggregator.predictions_list[0].tolist(), y_pred_1.tolist())

        aggregator.on_step_validation(state_2)
        self.assertEqual(aggregator.predictions_list[1].tolist(), y_pred_2.tolist())

        # Ending the epoch should reset the aggregator so predictions do not carry over into the next epoch
        aggregator.on_end_epoch(state_2)
        self.assertEqual(list(aggregator.predictions_list), [])

    def test_none_predictions(self):
        """Non-tensor predictions are collected as a list of the per-step values rather than concatenated."""
        aggregator = AggregatePredictions()

        # Record (and so silence) any warnings the aggregator may emit for non-tensor predictions
        with warnings.catch_warnings(record=True):
            state_1 = {torchbearer.Y_PRED: [None]}

            aggregator.on_step_validation(state_1)
            aggregator.on_step_validation(state_1)

            self.assertEqual(list(aggregator.predictions_list), [[None], [None]])

            aggregator.on_end_validation(state_1)
            self.assertEqual(state_1[torchbearer.FINAL_PREDICTIONS], [[None], [None]])

    def test_aggreate_predictions(self):
        """AggregatePredictions stores each validation step's Y_PRED and exposes their concatenation at the end."""
        aggregator = AggregatePredictions()

        y_pred_1 = torch.Tensor([1, 2, 3])
        y_pred_2 = torch.Tensor([3, 4, 5])

        state_1 = {tb.Y_PRED: y_pred_1}
        state_2 = {tb.Y_PRED: y_pred_2}
        final_state = {}

        # Each validation step should append that step's prediction, in order
        aggregator.on_step_validation(state_1)
        self.assertEqual(aggregator.predictions_list[0].tolist(), y_pred_1.tolist())

        aggregator.on_step_validation(state_2)
        self.assertEqual(aggregator.predictions_list[1].tolist(), y_pred_2.tolist())

        # At the end of validation FINAL_PREDICTIONS should equal the concatenated step predictions
        aggregate = torch.cat([y_pred_1, y_pred_2])
        aggregator.on_end_validation(final_state)
        self.assertEqual(final_state[tb.FINAL_PREDICTIONS].tolist(), aggregate.tolist())
Esempio n. 5
0
class Trial(object):
    """
    The trial class contains all of the required hyper-parameters for model running in torchbearer and presents an
    API for model fitting, evaluating and predicting.

    Args:
        model (torch.nn.Module): The base pytorch model
        optimizer (torch.optim.Optimizer): The optimizer used for pytorch model weight updates
        criterion (func / None): The final loss criterion that provides a loss value to the optimizer
        metrics (list): The list of :class:`torchbearer.Metric <.Metric>` instances to process during fitting
        callbacks (list): The list of :class:`torchbearer.Callback <.Callback>` instances to call during fitting
        verbose (int): Global verbosity .If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training
            progress
    """
    def __init__(self, model, optimizer=None, criterion=None, metrics=None, callbacks=None, verbose=2):
        """Build the trial state from the given model, optimizer, criterion, metrics and callbacks.

        Args:
            model (torch.nn.Module): The base pytorch model
            optimizer (torch.optim.Optimizer): Optimizer used for weight updates; a ``MockOptimizer`` is used if None
            criterion (func / None): Loss criterion; if None, a criterion returning a zero tensor with
                ``requires_grad=True`` on the trial's current device and dtype is used
            metrics (list): Metrics to process during fitting; None is treated as an empty list
            callbacks (list): Callbacks to call during fitting; None is treated as an empty list
            verbose (int): Global verbosity. If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no
                training progress
        """
        # None sentinels instead of ``[]`` defaults so no list object is shared between Trial instances
        metrics = [] if metrics is None else metrics
        callbacks = [] if callbacks is None else callbacks

        if criterion is None:
            def criterion(_, __):
                # DEVICE / DATA_TYPE are read at call time, so the zero loss follows the state's current values
                return torch.zeros(1, device=self.state[torchbearer.DEVICE], dtype=self.state[torchbearer.DATA_TYPE], requires_grad=True)

        self.verbose = verbose

        self.closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE, torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
        self.state = State()
        self.state.update({
            torchbearer.MODEL: model,
            torchbearer.CRITERION: criterion,
            torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
            torchbearer.METRIC_LIST: MetricList(metrics),
            torchbearer.CALLBACK_LIST: CallbackList(callbacks),
            torchbearer.DEVICE: 'cpu',
            torchbearer.DATA_TYPE: torch.float32,
            torchbearer.SELF: self,
            torchbearer.HISTORY: [],
            torchbearer.BACKWARD_ARGS: {},
            torchbearer.TRAIN_GENERATOR: None,
            torchbearer.VALIDATION_GENERATOR: None,
            torchbearer.TEST_GENERATOR: None,
            torchbearer.TRAIN_STEPS: None,
            torchbearer.VALIDATION_STEPS: None,
            torchbearer.TEST_STEPS: None,
            torchbearer.TRAIN_DATA: None,
            torchbearer.VALIDATION_DATA: None,
            torchbearer.TEST_DATA: None,
            torchbearer.INF_TRAIN_LOADING: False,
            torchbearer.LOADER: None
        })

        self.state[torchbearer.CALLBACK_LIST].on_init(self.state)

    def __str__(self):
        """Return a summary of the trial's optimizer, criterion, metrics, callbacks and model.

        Each component is printed under its upper-cased title framed by dashes (52 characters wide for titles of up
        to 50 characters), followed by the component's ``str`` and a blank line.
        """
        import math

        def state_string(name, state_key):
            # Split the 50 characters not used by the title across both sides of it
            half = (50 - len(name)) / 2
            header = "-" * int(math.floor(half)) + " " + name.upper() + " " + "-" * int(math.ceil(half))
            # Only triggers when ``/`` floors (Python 2 integer division) and the name length is odd
            if len(header) < 52:
                header = header + "-"
            return header + "\n" + str(self.state[state_key]) + "\n\n"

        sections = [
            ("Optimizer", torchbearer.OPTIMIZER),
            ("Criterion", torchbearer.CRITERION),
            ("Metrics", torchbearer.METRIC_LIST),
            ("Callbacks", torchbearer.CALLBACK_LIST),
            ("Model", torchbearer.MODEL),
        ]
        return "".join(state_string(name, key) for name, key in sections)

    def __repr__(self):
        # The repr is the same multi-section summary that __str__ produces
        return self.__str__()

    def for_train_steps(self, steps):
        """Set how many training steps each epoch runs, and return self for chaining.

        If no train generator has been set, the generator yields (None, None), which is useful for differentiable
        programming. A count larger than the dataset makes the loader refresh as if a new epoch had started; -1 keeps
        refreshing until stopped by the STOP_TRAINING flag or similar.

        Args:
            steps (int): Training steps per epoch. Non-int values are cast to int with a warning.

        Returns:
            Trial: self
        """
        if isinstance(steps, int):
            step_count = steps
        else:
            warnings.warn("Number of training steps is not an int, casting to int")
            step_count = int(steps)

        self.state[torchbearer.TRAIN_STEPS] = step_count
        # Keep the (generator, steps) pair used by the sampler injection in sync
        train_pair = (self.state[torchbearer.TRAIN_GENERATOR], self.state[torchbearer.TRAIN_STEPS])
        self.state[torchbearer.TRAIN_DATA] = train_pair
        return self

    def with_train_generator(self, generator, steps=None):
        """Attach a training data generator, and return self for chaining.

        Args:
            generator: The train data generator used by :meth:`.run`
            steps (int): Steps per epoch. If None, any previously set training step count is kept; if that is also
                None, ``len(generator)`` is used.

        Returns:
            Trial: self
        """
        self.state[torchbearer.TRAIN_GENERATOR] = generator

        if steps is None:
            steps = self.state[torchbearer.TRAIN_STEPS]
        if steps is None:
            steps = len(generator)

        self.for_train_steps(steps)
        return self

    def with_train_data(self, x, y, batch_size=1, shuffle=True, num_workers=1, steps=None):
        """Wrap in-memory tensors in a DataLoader and use it as the training generator.

        Args:
            x (torch.Tensor): Training inputs used by :meth:`.run`
            y (torch.Tensor): Training targets used by :meth:`.run`
            batch_size (int): Number of samples per batch
            shuffle (bool): Whether to reshuffle the data every epoch
            num_workers (int): Worker count for the data loader
            steps (int): Steps per epoch when using this data

        Returns:
            Trial: self
        """
        loader = DataLoader(TensorDataset(x, y), batch_size, shuffle=shuffle, num_workers=num_workers)
        self.with_train_generator(loader, steps=steps)
        return self

    def for_val_steps(self, steps):
        """Set how many validation steps each epoch runs, and return self for chaining.

        If no validation generator has been set, the generator yields (None, None), which is useful for
        differentiable programming. A count larger than the dataset makes the loader refresh as if a new epoch had
        started; -1 keeps refreshing until stopped by the STOP_TRAINING flag or similar.

        Args:
            steps (int): Validation steps per epoch. Non-int values are cast to int with a warning.

        Returns:
            Trial: self
        """
        if isinstance(steps, int):
            step_count = steps
        else:
            warnings.warn("Number of validation steps is not an int, casting to int")
            step_count = int(steps)

        self.state[torchbearer.VALIDATION_STEPS] = step_count
        # Keep the (generator, steps) pair used by the sampler injection in sync
        val_pair = (self.state[torchbearer.VALIDATION_GENERATOR], self.state[torchbearer.VALIDATION_STEPS])
        self.state[torchbearer.VALIDATION_DATA] = val_pair
        return self

    def with_val_generator(self, generator, steps=None):
        """Attach a validation data generator, and return self for chaining.

        Args:
            generator: The validation data generator used by :meth:`.run` and :meth:`.evaluate`
            steps (int): Steps per epoch. If None, any previously set validation step count is kept; if that is also
                None, ``len(generator)`` is used.

        Returns:
            Trial: self
        """
        self.state[torchbearer.VALIDATION_GENERATOR] = generator

        if steps is None:
            steps = self.state[torchbearer.VALIDATION_STEPS]
        if steps is None:
            steps = len(generator)

        self.for_val_steps(steps)
        return self

    def with_val_data(self, x, y, batch_size=1, shuffle=True, num_workers=1, steps=None):
        """Wrap in-memory tensors in a DataLoader and use it as the validation generator.

        Args:
            x (torch.Tensor): Validation inputs used by :meth:`.run` and :meth:`.evaluate`
            y (torch.Tensor): Validation targets used by :meth:`.run` and :meth:`.evaluate`
            batch_size (int): Number of samples per batch
            shuffle (bool): Whether to reshuffle the data every epoch
            num_workers (int): Worker count for the data loader
            steps (int): Steps per epoch when using this data

        Returns:
            Trial: self
        """
        loader = DataLoader(TensorDataset(x, y), batch_size, shuffle=shuffle, num_workers=num_workers)
        self.with_val_generator(loader, steps=steps)
        return self

    def for_test_steps(self, steps):
        """Set how many test steps :meth:`.predict` runs, and return self for chaining.

        If no test generator has been set, the generator yields (None, None), which is useful for differentiable
        programming. A count larger than the dataset makes the loader refresh as if a new epoch had started; -1 keeps
        refreshing until stopped by the STOP_TRAINING flag or similar.

        Args:
            steps (int): Test steps per epoch. Non-int values are cast to int with a warning.

        Returns:
            Trial: self
        """
        if isinstance(steps, int):
            step_count = steps
        else:
            warnings.warn("Number of test steps is not an int, casting to int")
            step_count = int(steps)

        self.state[torchbearer.TEST_STEPS] = step_count
        # Keep the (generator, steps) pair used by the sampler injection in sync
        test_pair = (self.state[torchbearer.TEST_GENERATOR], self.state[torchbearer.TEST_STEPS])
        self.state[torchbearer.TEST_DATA] = test_pair
        return self

    def with_test_generator(self, generator, steps=None):
        """Attach a test data generator, and return self for chaining.

        Args:
            generator: The test data generator used by :meth:`.predict`
            steps (int): Steps per epoch. If None, any previously set test step count is kept; if that is also None,
                ``len(generator)`` is used.

        Returns:
            Trial: self
        """
        self.state[torchbearer.TEST_GENERATOR] = generator

        if steps is None:
            steps = self.state[torchbearer.TEST_STEPS]
        if steps is None:
            steps = len(generator)

        self.for_test_steps(steps)
        return self

    def with_test_data(self, x, batch_size=1, num_workers=1, steps=None):
        """Wrap an in-memory input tensor in a (non-shuffled) DataLoader and use it as the test generator.

        Args:
            x (torch.Tensor): Test inputs used by :meth:`.predict`
            batch_size (int): Number of samples per batch
            num_workers (int): Worker count for the data loader
            steps (int): Steps per epoch when using this data

        Returns:
            Trial: self
        """
        loader = DataLoader(TensorDataset(x), batch_size, num_workers=num_workers)
        self.with_test_generator(loader, steps=steps)
        return self

    def for_steps(self, train_steps=None, val_steps=None, test_steps=None):
        """Use this trial for the given number of train, val and test steps. Returns self so that methods can be chained
        for convenience. If steps larger than dataset size then loader will be refreshed like if it was a new epoch. If
        steps -1 then loader will be refreshed until stopped by STOP_TRAINING flag or similar.

        Args:
            train_steps (int): The number of training steps per epoch to run
            val_steps (int): The number of validation steps per epoch to run
            test_steps (int): The number of test steps per epoch to run (when using :meth:`.predict`)

        Returns:
            Trial: self
        """
        if train_steps is not None:
            self.for_train_steps(train_steps)
        if val_steps is not None:
            self.for_val_steps(val_steps)
        if test_steps is not None:
            self.for_test_steps(test_steps)

        return self

    def with_generators(self, train_generator=None, val_generator=None, test_generator=None, train_steps=None, val_steps=None, test_steps=None):
        """Use this trial with the given generators. Returns self so that methods can be chained for convenience.

        Args:
            train_generator: The training data generator to use during calls to :meth:`.run`
            val_generator: The validation data generator to use during calls to :meth:`.run` and :meth:`.evaluate`
            test_generator: The testing data generator to use during calls to :meth:`.predict`
            train_steps (int): The number of steps per epoch to take when using the training generator
            val_steps (int): The number of steps per epoch to take when using the validation generator
            test_steps (int): The number of steps per epoch to take when using the testing generator

        Returns:
            Trial: self
        """
        if train_generator is not None:
            self.with_train_generator(train_generator, train_steps)
        if val_generator is not None:
            self.with_val_generator(val_generator, val_steps)
        if test_generator is not None:
            self.with_test_generator(test_generator, test_steps)

        return self

    def for_inf_train_steps(self):
        """Run an unbounded number of training steps, until stopped by the STOP_TRAINING flag or similar.

        Returns:
            Trial: self
        """
        # -1 is the step-count sentinel for "no fixed limit"
        unbounded = -1
        self.for_train_steps(unbounded)
        return self

    def for_inf_val_steps(self):
        """Run an unbounded number of validation steps, until stopped by the STOP_TRAINING flag or similar.

        Returns:
            Trial: self
        """
        # -1 is the step-count sentinel for "no fixed limit"
        unbounded = -1
        self.for_val_steps(unbounded)
        return self

    def for_inf_test_steps(self):
        """Run an unbounded number of test steps, until stopped by the STOP_TRAINING flag or similar.

        Returns:
            Trial: self
        """
        # -1 is the step-count sentinel for "no fixed limit"
        unbounded = -1
        self.for_test_steps(unbounded)
        return self

    def for_inf_steps(self, train=True, val=True, test=True):
        """Use this trail with infinite steps. Returns self so that methods can be chained for convenience.
        
        Args:
            train (bool): Use an infinite number of training steps
            val (bool): Use an infinite number of validation steps
            test (bool): Use an infinite number of test steps

        Returns:
            Trial: self
        """
        if train: self.for_inf_train_steps()
        if val: self.for_inf_val_steps()
        if test: self.for_inf_test_steps()

        return self

    def with_inf_train_loader(self):
        """Refresh the training iterator only when it is exhausted, rather than at the start of every epoch.

        This lets the training step count be smaller than the generator length while still covering every training
        sample, given enough "epochs".

        Returns:
            Trial: self
        """
        enable_inf_loading = True
        self.state[torchbearer.INF_TRAIN_LOADING] = enable_inf_loading
        return self

    def with_loader(self, batch_loader):
        """Install a custom batch loader, and return self for chaining.

        A batch loader typically calls ``next`` on ``state[torchbearer.ITERATOR]`` and fills
        ``state[torchbearer.X]`` and ``state[torchbearer.Y_TRUE]``.

        Args:
            batch_loader (function): Function of state that pulls data from the iterator (stored under
                torchbearer.ITERATOR), stores it in state and sends it to the correct device

        Returns:
            Trial: self
        """
        loader_key = torchbearer.LOADER
        self.state[loader_key] = batch_loader
        return self

    def with_closure(self, closure):
        """Use this trial with custom closure

        Args:
            closure (function): Function of state that defines the custom closure

        Returns:
            Trial: self:
        """
        self.closure = closure

        return self

    @inject_printer()
    def run(self, epochs=1, verbose=-1):
        r"""Run this trial for the given number of epochs, starting from the last trained epoch.

        Nothing is run unless a train/validation generator or step count has been set; in that case the existing
        history is returned unchanged. Epoch numbering starts at ``len(history)``, so ``epochs`` is the total epoch
        count to reach, not the number of additional epochs. If ``STOP_TRAINING`` is set during the fit pass, the
        loop stops before validation, ``on_end_epoch``, the history append and ``on_checkpoint``.

        Args:
            epochs (int, optional): The number of epochs to run for
            verbose (int, optional): If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training
            progress, If -1: Automatic. Not read in this body; presumably consumed by ``inject_printer``

        State Requirements:
            - :attr:`torchbearer.state.MODEL`: Model should be callable and not none, set on Trial init

        Returns:
            list: The model history (list of tuple of steps summary and epoch metric dicts)
        """
        # Per-run state: run-specific keys first, then everything from the trial's persistent state
        state = State()
        state.update({
            torchbearer.MAX_EPOCHS: epochs,
            torchbearer.STOP_TRAINING: False,
        })

        state.update(self.state)  # TODO: Swap this for something which makes `self.state` still mutable

        if state[torchbearer.MODEL] is None or not callable(state[torchbearer.MODEL]):
            warnings.warn('The Model is None or not callable which may cause issues if not deliberate')
            # NOTE(review): _fit_pass/_validation_pass call state.update(self.state), which appears to restore the
            # original MODEL and undo this fallback if they receive this same state object — confirm
            state[torchbearer.MODEL] = lambda *args, **kwargs: None

        if state[torchbearer.TRAIN_GENERATOR] is not None \
                or state[torchbearer.TRAIN_STEPS] is not None \
                or state[torchbearer.VALIDATION_GENERATOR] is not None \
                or state[torchbearer.VALIDATION_STEPS] is not None:

            state[torchbearer.CALLBACK_LIST].on_start(state)

            # The loop assigns EPOCH directly into state; resuming continues from the epochs already in history
            for state[torchbearer.EPOCH] in range(len(state[torchbearer.HISTORY]), state[torchbearer.MAX_EPOCHS]):
                state[torchbearer.CALLBACK_LIST].on_start_epoch(state)

                final_metrics = self._fit_pass(state)[torchbearer.METRICS]

                # Stopping during training skips validation, on_end_epoch, history and checkpointing for this epoch
                if state[torchbearer.STOP_TRAINING]:
                    break

                final_metrics.update(self._validation_pass(state))
                state[torchbearer.METRICS] = final_metrics
                state[torchbearer.CALLBACK_LIST].on_end_epoch(state)
                steps_summary = (state[torchbearer.TRAIN_STEPS], state[torchbearer.VALIDATION_STEPS])
                # History is recorded on self.state so it persists after this run returns
                self.state[torchbearer.HISTORY].append((steps_summary, state[torchbearer.METRICS]))
                state[torchbearer.CALLBACK_LIST].on_checkpoint(state)

                if state[torchbearer.STOP_TRAINING]:
                    break

            state[torchbearer.CALLBACK_LIST].on_end(state)

        return self.state[torchbearer.HISTORY]

    @staticmethod
    def _new_iter(generator):
        if generator is None:
            return None
        if hasattr(generator, 'inf') and generator.inf:  # Inf train loader deals with the iterator itself
            return generator.tb_iter
        else:
            return iter(generator)

    @inject_sampler(torchbearer.TRAIN_DATA, load_batch_standard)
    def _fit_pass(self, state):
        """Run one training pass over the injected training data.

        Puts the model and metrics into training mode and resets the metrics. Then, for each batch, it samples,
        steps the optimizer with the trial closure, processes metrics and fires the training callbacks. A step
        count of -1 iterates indefinitely until ``STOP_TRAINING`` is set.

        Args:
            state: The per-run state. GENERATOR, STEPS and SAMPLER are presumably supplied by ``inject_sampler`` from
                TRAIN_DATA — confirm against the decorator.

        Returns:
            The same ``state`` object, with METRICS holding the running metrics updated by the final metric values
        """
        state.update(self.state)  # TODO: Hack to make injection work, should be removed if `self.state` is mutable
        self.train()

        # Fresh iterator each pass (or the persistent one for infinite loaders, see _new_iter)
        state[torchbearer.ITERATOR] = Trial._new_iter(state[torchbearer.GENERATOR])

        state[torchbearer.METRIC_LIST].reset(state)
        state[torchbearer.METRICS] = {}

        state[torchbearer.CALLBACK_LIST].on_start_training(state)
        # -1 steps means "no fixed limit": count upward until STOP_TRAINING breaks the loop
        for state[torchbearer.BATCH] in (range(state[torchbearer.STEPS]) if state[torchbearer.STEPS] != -1 else itertools.count()):
            state[torchbearer.SAMPLER](state)
            state[torchbearer.CALLBACK_LIST].on_sample(state)

            # Update parameters; the closure presumably runs forward, loss and backward (see base_closure in __init__)
            state[torchbearer.OPTIMIZER].step(lambda: self.closure(state))

            state[torchbearer.METRICS] = state[torchbearer.METRIC_LIST].process(state.data)
            state[torchbearer.CALLBACK_LIST].on_step_training(state)

            if state[torchbearer.STOP_TRAINING]:
                break

        state[torchbearer.METRICS].update(state[torchbearer.METRIC_LIST].process_final(state.data))

        state[torchbearer.CALLBACK_LIST].on_end_training(state)
        return state

    def _test_pass(self, state):
        """Run one evaluation pass (validation or prediction) over the generator currently set in ``state``.

        Runs under ``torch.no_grad()``. For each batch it samples and runs the forward pass. Only when
        ``torchbearer.Y_TRUE`` is present in state does it also compute the loss and metrics; final metrics are
        likewise only computed when targets are present. A step count of -1 iterates until ``STOP_TRAINING`` is set,
        matching :meth:`_fit_pass` and the ``for_*_steps`` documentation.

        Args:
            state: The per-call state. GENERATOR, STEPS and SAMPLER are presumably supplied by ``inject_sampler`` on
                the calling method — confirm against the decorator.

        Returns:
            The same ``state`` object, updated in place
        """
        with torch.no_grad():
            state[torchbearer.ITERATOR] = Trial._new_iter(state[torchbearer.GENERATOR])

            state[torchbearer.METRIC_LIST].reset(state)
            state[torchbearer.METRICS] = {}

            state[torchbearer.CALLBACK_LIST].on_start_validation(state)

            # -1 means "no fixed limit" (set by for_inf_val_steps / for_inf_test_steps); range(-1) would silently run
            # zero batches, so count upward until STOP_TRAINING breaks the loop, as _fit_pass does
            steps = state[torchbearer.STEPS]
            batches = range(steps) if steps != -1 else itertools.count()
            for state[torchbearer.BATCH] in batches:
                state[torchbearer.SAMPLER](state)
                state[torchbearer.CALLBACK_LIST].on_sample_validation(state)

                # Forward pass: models that do not accept a ``state`` kwarg are called with the input alone
                try:
                    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](state[torchbearer.X], state=state)
                except TypeError:
                    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](state[torchbearer.X])

                state[torchbearer.CALLBACK_LIST].on_forward_validation(state)

                # Loss and metrics only make sense when targets were loaded
                if torchbearer.Y_TRUE in state:
                    # Criteria may take the whole state, or (y_pred, y_true)
                    try:
                        state[torchbearer.LOSS] = state[torchbearer.CRITERION](state)
                    except TypeError:
                        state[torchbearer.LOSS] = state[torchbearer.CRITERION](state[torchbearer.Y_PRED],
                                                                           state[torchbearer.Y_TRUE])
                    state[torchbearer.CALLBACK_LIST].on_criterion_validation(state)
                    state[torchbearer.METRICS] = state[torchbearer.METRIC_LIST].process(state.data)

                state[torchbearer.CALLBACK_LIST].on_step_validation(state)
                if state[torchbearer.STOP_TRAINING]:
                    break

            if torchbearer.Y_TRUE in state:
                state[torchbearer.METRICS].update(state[torchbearer.METRIC_LIST].process_final(state.data))
            state[torchbearer.CALLBACK_LIST].on_end_validation(state)
        return state

    @inject_sampler(torchbearer.VALIDATION_DATA, load_batch_standard)
    def _validation_pass(self, state):
        """Run an evaluation pass when validation data (a generator or a step count) has been configured.

        Returns:
            dict: ``state[torchbearer.METRICS]`` after the pass; left as it was when no validation data is set
        """
        state.update(self.state)  # TODO: Hack to make injection work, should be removed if `self.state` is mutable

        skip_validation = state[torchbearer.VALIDATION_GENERATOR] is None and state[torchbearer.VALIDATION_STEPS] is None
        if not skip_validation:
            self.eval()
            self._test_pass(state)

        return state[torchbearer.METRICS]

    @inject_sampler(torchbearer.VALIDATION_DATA, load_batch_standard)
    @inject_printer(validation_label_letter='e')
    def evaluate(self, verbose=-1, data_key=None):  # Note: kwargs appear unused but are inspected in inject_sampler
        """Evaluate this trial on the validation data.

        Runs a single evaluation epoch, firing on_start/on_start_epoch/on_end_epoch/on_end around the pass. If the
        trial has history, the resulting metrics are merged into the most recent epoch's metric dict.

        Args:
            verbose (int): If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
            data_key (StateKey): Optional :class:`.StateKey` for the data to evaluate on. Default: torchbearer.VALIDATION_DATA

        Returns:
            dict: The final metric values, or an empty dict when neither a generator nor a step count is available
        """
        state = State()
        state.update({
            torchbearer.MAX_EPOCHS: 1,
            torchbearer.EPOCH: 0,
            torchbearer.STOP_TRAINING: False
        })
        state.update(self.state)  # TODO: Hack to make injection work, should be removed if `self.state` is mutable

        if state[torchbearer.GENERATOR] is not None or state[torchbearer.STEPS] is not None:
            state[torchbearer.CALLBACK_LIST].on_start(state)
            state[torchbearer.CALLBACK_LIST].on_start_epoch(state)

            self.eval()
            state = self._test_pass(state)

            state[torchbearer.CALLBACK_LIST].on_end_epoch(state)

            # Attach the evaluation metrics to the latest history entry's metric dict
            if len(self.state[torchbearer.HISTORY]) != 0:
                self.state[torchbearer.HISTORY][-1][1].update(state[torchbearer.METRICS])

            state[torchbearer.CALLBACK_LIST].on_end(state)
            return state[torchbearer.METRICS]
        return {}

    @inject_callback(AggregatePredictions())
    @inject_sampler(torchbearer.TEST_DATA, load_batch_predict)
    @inject_printer(validation_label_letter='p')
    def predict(self, verbose=-1, data_key=None):  # Note: kwargs appear unused but are inspected in inject_sampler
        """Determine predictions for this trial on the test data.

        Runs a single evaluation epoch, firing on_start/on_start_epoch/on_end_epoch/on_end around the pass, and
        returns the predictions collected in ``state[torchbearer.FINAL_PREDICTIONS]`` (populated by the injected
        ``AggregatePredictions`` callback).

        Args:
            verbose (int): If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
            data_key (StateKey): Optional :class:`.StateKey` for the data to predict on. Default: torchbearer.TEST_DATA

        Returns:
            list: Model outputs as a list, or an empty list when neither a generator nor a step count is available
        """
        # Use a State, as run/evaluate/replay do, rather than a plain dict: _test_pass reads ``state.data``, which a
        # plain dict does not have
        state = State()
        state.update({
            torchbearer.MAX_EPOCHS: 1,
            torchbearer.EPOCH: 0,
            torchbearer.STOP_TRAINING: False
        })

        state.update(self.state)  # TODO: Hack to make injection work, should be removed if `self.state` is mutable

        if state[torchbearer.GENERATOR] is not None or state[torchbearer.STEPS] is not None:
            state[torchbearer.CALLBACK_LIST].on_start(state)
            state[torchbearer.CALLBACK_LIST].on_start_epoch(state)

            self.eval()
            res = self._test_pass(state)[torchbearer.FINAL_PREDICTIONS]

            state[torchbearer.CALLBACK_LIST].on_end_epoch(state)
            state[torchbearer.CALLBACK_LIST].on_end(state)
            return res
        return []

    def replay(self, callbacks=None, verbose=2, one_batch=False):  # TODO: Should we track if testing passes have happened?
        """Replay the fit passes stored in history with the given callbacks, useful when reloading a saved Trial.

        Only progress and metric information is populated in state during a replay.

        Args:
            callbacks (list): Callbacks to run during the replay. The list is copied, never modified. Default: none
            verbose (int): If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
            one_batch (bool): If True, only one batch per epoch is replayed. If False, all batches are replayed

        Returns:
            Trial: self
        """
        history = self.state[torchbearer.HISTORY]
        # Copy before adding the printer. Appending to the caller's list (or to a shared ``[]`` default) would make
        # printers accumulate across replay calls
        callbacks = list(callbacks) if callbacks is not None else []
        callbacks.append(get_printer(verbose=verbose, validation_label_letter='v'))
        callback_list = CallbackList(callbacks)

        state = State()
        state.update(self.state)
        state[torchbearer.STOP_TRAINING] = False
        state[torchbearer.MAX_EPOCHS] = len(history)

        callback_list.on_start(state)
        for epoch, entry in enumerate(history):
            state[torchbearer.EPOCH] = epoch
            # Each history entry is (steps_summary, metrics); steps_summary is (train_steps, val_steps)
            if not one_batch:
                state[torchbearer.TRAIN_STEPS], state[torchbearer.VALIDATION_STEPS] = entry[0]
            else:
                state[torchbearer.TRAIN_STEPS], state[torchbearer.VALIDATION_STEPS] = 1, 1
            state[torchbearer.METRICS] = entry[1]

            self._replay_pass(state, callback_list)
        callback_list.on_end(state)

        return self

    def _replay_pass(self, state, callback_list):
        """Replay one recorded epoch by firing callbacks without running the model.

        Training callbacks see only the metrics whose names do not contain ``"val_"``. Validation callbacks (skipped
        if a callback set STOP_TRAINING during the training replay) see only those that do. The full metric dict is
        restored before ``on_end_epoch``.

        Args:
            state: Replay state holding EPOCH, TRAIN_STEPS, VALIDATION_STEPS and the epoch's METRICS
            callback_list (CallbackList): Callbacks to fire

        Returns:
            Trial: self
        """
        callback_list.on_start_epoch(state)
        all_metrics = state[torchbearer.METRICS]

        # Training pass
        state[torchbearer.STEPS] = state[torchbearer.TRAIN_STEPS] if state[torchbearer.TRAIN_STEPS] is not None else 0
        # NOTE(review): substring test, so a training metric whose name merely contains "val_" (e.g. "interval_x")
        # would be treated as validation; confirm whether startswith("val_") is intended
        state[torchbearer.METRICS] = {key: all_metrics[key] for key in all_metrics.keys() if "val_" not in key}
        callback_list.on_start_training(state)
        for state[torchbearer.BATCH] in range(state[torchbearer.STEPS]):
            callback_list.on_sample(state)
            callback_list.on_forward(state)
            callback_list.on_criterion(state)
            callback_list.on_backward(state)
            callback_list.on_step_training(state)
            if state[torchbearer.STOP_TRAINING]:
                break
        callback_list.on_end_training(state)

        # Validation pass
        if not state[torchbearer.STOP_TRAINING]:
            state[torchbearer.STEPS] = state[torchbearer.VALIDATION_STEPS] if state[torchbearer.VALIDATION_STEPS] is not None else 0
            state[torchbearer.METRICS] = {key: all_metrics[key] for key in all_metrics.keys() if "val_" in key}
            callback_list.on_start_validation(state)
            for state[torchbearer.BATCH] in range(state[torchbearer.STEPS]):
                callback_list.on_sample_validation(state)
                callback_list.on_forward_validation(state)
                callback_list.on_criterion_validation(state)
                callback_list.on_step_validation(state)
                if state[torchbearer.STOP_TRAINING]:
                    break
            callback_list.on_end_validation(state)

        # Restore the complete metric dict for end-of-epoch callbacks
        state[torchbearer.METRICS] = all_metrics
        callback_list.on_end_epoch(state)

        return self

    def train(self):
        """Switch the model and the metric list into training mode.

        Returns:
            Trial: self
        """
        model = self.state[torchbearer.MODEL]
        model.train()
        metric_list = self.state[torchbearer.METRIC_LIST]
        metric_list.train()
        return self

    def eval(self):
        """Switch the wrapped model, then the metric list, into evaluation mode.

        When ``torchbearer.DATA`` is present in state, its value is forwarded to the metric list as ``data_key``.

        Returns:
            Trial: self
        """
        self.state[torchbearer.MODEL].eval()

        metric_list = self.state[torchbearer.METRIC_LIST]
        if torchbearer.DATA not in self.state:
            metric_list.eval()
        else:
            metric_list.eval(data_key=self.state[torchbearer.DATA])

        return self

    def to(self, *args, **kwargs):
        """ Moves and/or casts the model, every tensor in the optimizer's per-parameter state, and the device/dtype
        recorded in the trial state.

        Args:
            args: See: `torch.nn.Module.to <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.to>`_
            kwargs: See: `torch.nn.Module.to <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.to>`_

        Returns:
            Trial: self
        """
        self.state[torchbearer.MODEL].to(*args, **kwargs)

        optimizer = self.state[torchbearer.OPTIMIZER]
        for param_state in optimizer.state.values():
            # Collect first, then write back; only tensor entries (e.g. momentum buffers) are moved.
            tensor_entries = [(name, value) for name, value in param_state.items() if torch.is_tensor(value)]
            for name, value in tensor_entries:
                param_state[name] = value.to(*args, **kwargs)

        self.state = update_device_and_dtype(self.state, *args, **kwargs)

        return self

    def cuda(self, device=None):
        """ Moves the trial to a GPU via :meth:`to`.

        Args:
            device (int): Index of the target GPU. Defaults to ``torch.cuda.current_device()``.

        Returns:
            Trial: self
        """
        target_index = torch.cuda.current_device() if device is None else device
        self.to('cuda:' + str(target_index))

        return self

    def cpu(self):
        """ Moves the trial to host memory by delegating to :meth:`to` with ``'cpu'``.

        Returns:
            Trial: self
        """
        host_device = 'cpu'
        self.to(host_device)

        return self

    def state_dict(self, **kwargs):
        """Snapshot everything needed to resume this trial.

        The returned dict holds the torchbearer version (with any ``.dev`` marker stripped), the model and
        optimizer state dicts, the epoch history and the callback list's state.

        Args:
            kwargs: See: `torch.nn.Module.state_dict <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.state_dict>`_

        Returns:
            dict: A dict containing parameters and persistent buffers.
        """
        release_version = torchbearer.__version__.replace('.dev', '')
        model_state = self.state[torchbearer.MODEL].state_dict(**kwargs)
        optimizer_state = self.state[torchbearer.OPTIMIZER].state_dict()
        history = self.state[torchbearer.HISTORY]
        callback_state = self.state[torchbearer.CALLBACK_LIST].state_dict()

        return {
            torchbearer.VERSION: release_version,
            torchbearer.MODEL: model_state,
            torchbearer.OPTIMIZER: optimizer_state,
            torchbearer.HISTORY: history,
            torchbearer.CALLBACK_LIST: callback_state,
        }

    def load_state_dict(self, state_dict, resume=True, **kwargs):
        """Resume this trial from the given state. Expects that this trial was constructed in the same way. Optionally,
        just load the model state when resume=False.

        Three kinds of input are handled:

        * A torchbearer state dict (one containing ``torchbearer.MODEL``) with ``resume=True``: the model is
          restored, along with whichever optimizer, history and callback-list entries are present. A warning is
          raised if the dict records a different torchbearer version.
        * A torchbearer state dict with ``resume=False``: only the model weights are loaded.
        * Anything else: a warning is raised and the whole dict is handed to the model's ``load_state_dict``.

        Args:
            state_dict (dict): The state dict to reload
            resume (bool): If True, resume from the given state. Else, just load in the model weights.
            kwargs: See: `torch.nn.Module.load_state_dict <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.load_state_dict>`_

        Returns:
            Trial: self
        """
        model = self.state[torchbearer.MODEL]

        if torchbearer.MODEL not in state_dict:  # something else, e.g. a raw model state dict
            warnings.warn('Not a torchbearer state dict, passing to model')
            model.load_state_dict(state_dict, **kwargs)
            return self

        if not resume:  # torchbearer dict, but only the weights are wanted
            model.load_state_dict(state_dict[torchbearer.MODEL], **kwargs)
            return self

        # Full resume from a torchbearer dict.
        current_version = torchbearer.__version__.replace('.dev', '')
        if torchbearer.VERSION in state_dict and state_dict[torchbearer.VERSION] != current_version:
            warnings.warn('This state dict was saved with a different torchbearer version, loading available keys. Consider setting resume=False')

        # MODEL is guaranteed present at this point, so no further membership check is needed.
        model.load_state_dict(state_dict[torchbearer.MODEL], **kwargs)

        if torchbearer.OPTIMIZER in state_dict:
            self.state[torchbearer.OPTIMIZER].load_state_dict(state_dict[torchbearer.OPTIMIZER])

        if torchbearer.HISTORY in state_dict:
            self.state[torchbearer.HISTORY] = state_dict[torchbearer.HISTORY]

        if torchbearer.CALLBACK_LIST in state_dict:
            self.state[torchbearer.CALLBACK_LIST].load_state_dict(state_dict[torchbearer.CALLBACK_LIST])

        return self
Esempio n. 6
0
class Trial(object):
    """ The trial class contains all of the required hyper-parameters for model running in torchbearer and presents an
    API for model fitting, evaluating and predicting.

    :param model: The base pytorch model
    :type model: torch.nn.Module
    :param optimizer: The optimizer used for pytorch model weight updates
    :type optimizer: torch.optim.Optimizer
    :param criterion: The final loss criterion that provides a loss value to the optimizer
    :type criterion: function or None
    :param metrics: The list of :class:`torchbearer.Metric <.Metric>` instances to process during fitting
    :type metrics: list
    :param callbacks: The list of :class:`torchbearer.Callback <.Callback>` instances to call during fitting
    :type callbacks: list
    :param pass_state: If True, the torchbearer state will be passed to the model during fitting
    :type pass_state: bool
    """
    def __init__(self,
                 model,
                 optimizer=None,
                 criterion=None,
                 metrics=None,
                 callbacks=None,
                 pass_state=False):
        """Build the trial state. ``metrics`` and ``callbacks`` default to empty lists; ``optimizer`` defaults to a
        ``MockOptimizer``; ``criterion`` defaults to a constant zero loss.
        """
        # None sentinels instead of ``[]`` defaults, so no list object is shared between Trial instances.
        metrics = [] if metrics is None else metrics
        callbacks = [] if callbacks is None else callbacks

        if criterion is None:

            def criterion(_, y_true):
                # Placeholder zero loss. y_true is None when no generator has been set (see for_train_steps), so
                # fall back to the trial's recorded device instead of dereferencing ``y_true.device``.
                # requires_grad=True lets ``loss.backward()`` in _fit_pass run even though the loss does not
                # depend on the model.
                device = getattr(y_true, 'device', None)
                if device is None:
                    device = self.state[torchbearer.DEVICE]
                return torch.zeros(1, device=device, requires_grad=True)

        self.pass_state = pass_state

        self.state = State()
        self.state.update({
            torchbearer.MODEL: model,
            torchbearer.CRITERION: criterion,
            torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
            torchbearer.METRIC_LIST: MetricList(metrics),
            torchbearer.CALLBACK_LIST: CallbackList(callbacks),
            torchbearer.DEVICE: 'cpu',
            torchbearer.DATA_TYPE: torch.float32,
            torchbearer.SELF: self,
            # One ((train_steps, val_steps), metrics) tuple is appended per completed epoch by run().
            torchbearer.HISTORY: [],
            torchbearer.BACKWARD_ARGS: {},
            torchbearer.TRAIN_GENERATOR: None,
            torchbearer.VALIDATION_GENERATOR: None,
            torchbearer.TEST_GENERATOR: None,
            torchbearer.TRAIN_STEPS: None,
            torchbearer.VALIDATION_STEPS: None,
            torchbearer.TEST_STEPS: None,
            # (generator, steps) pairs, kept in sync by the for_*_steps methods.
            torchbearer.TRAIN_DATA: None,
            torchbearer.VALIDATION_DATA: None,
            torchbearer.TEST_DATA: None,
        })

    @fluent
    def for_train_steps(self, steps):
        """Set how many training steps to run per epoch. Note that the generator will output (None, None) if it has
        not been set, which is useful for differentiable programming. Returns self so that methods can be chained
        for convenience.

        A non-int ``steps`` is cast with ``int`` (with a warning). If a train generator is set and ``steps`` exceeds
        its length, ``steps`` is clamped to that length (with a warning). The result is stored as TRAIN_STEPS and,
        paired with the generator, as TRAIN_DATA.

        :param steps: The number of training steps per epoch to run
        :type steps: int
        :return: self
        :rtype: Trial
        """
        if not isinstance(steps, int):
            warnings.warn(
                "Number of training steps is not an int, casting to int")
            steps = int(steps)

        loader = self.state[torchbearer.TRAIN_GENERATOR]
        if loader is not None:
            available = len(loader)
            if steps > available:
                warnings.warn(
                    "Number of training steps exceeds number of data items, limiting to number of items"
                )
                steps = available

        self.state[torchbearer.TRAIN_STEPS] = steps
        self.state[torchbearer.TRAIN_DATA] = (loader, steps)

    @fluent
    def with_train_generator(self, generator, steps=None):
        """Attach a train generator to this trial. Returns self so that methods can be chained for convenience.

        :param generator: The train data generator to use during calls to :meth:`.run`
        :type generator: DataLoader
        :param steps: The number of steps per epoch to take when using this generator; defaults to ``len(generator)``
        :type steps: int
        :return: self
        :rtype: Trial
        """
        self.state[torchbearer.TRAIN_GENERATOR] = generator
        if steps is None:
            steps = len(generator)
        self.for_train_steps(steps)

    @fluent
    def with_train_data(self, x, y, batch_size=1, shuffle=True, num_workers=1, steps=None):
        """Wrap in-memory train tensors in a ``DataLoader`` and use it as the train generator. Returns self so that
        methods can be chained for convenience.

        :param x: The train x data to use during calls to :meth:`.run`
        :type x: torch.Tensor
        :param y: The train labels to use during calls to :meth:`.run`
        :type y: torch.Tensor
        :param batch_size: The size of each batch to sample from the data
        :type batch_size: int
        :param shuffle: If True, then data will be shuffled each epoch
        :type shuffle: bool
        :param num_workers: Number of worker threads to use in the data loader
        :type num_workers: int
        :param steps: The number of steps per epoch to take when using this data
        :type steps: int
        :return: self
        :rtype: Trial
        """
        train_loader = DataLoader(TensorDataset(x, y), batch_size, shuffle=shuffle, num_workers=num_workers)
        self.with_train_generator(train_loader, steps=steps)

    @fluent
    def for_val_steps(self, steps):
        """Set how many validation steps to run per epoch. Note that the generator will output (None, None) if it has
        not been set, which is useful for differentiable programming. Returns self so that methods can be chained
        for convenience.

        A non-int ``steps`` is cast with ``int`` (with a warning). If a validation generator is set and ``steps``
        exceeds its length, ``steps`` is clamped to that length (with a warning). The result is stored as
        VALIDATION_STEPS and, paired with the generator, as VALIDATION_DATA.

        :param steps: The number of validation steps per epoch to run
        :type steps: int
        :return: self
        :rtype: Trial
        """
        if not isinstance(steps, int):
            warnings.warn(
                "Number of validation steps is not an int, casting to int")
            steps = int(steps)

        loader = self.state[torchbearer.VALIDATION_GENERATOR]
        if loader is not None:
            available = len(loader)
            if steps > available:
                warnings.warn(
                    "Number of validation steps exceeds number of data items, limiting to number of items"
                )
                steps = available

        self.state[torchbearer.VALIDATION_STEPS] = steps
        self.state[torchbearer.VALIDATION_DATA] = (loader, steps)

    @fluent
    def with_val_generator(self, generator, steps=None):
        """Attach a validation generator to this trial. Returns self so that methods can be chained for
        convenience.

        :param generator: The validation data generator to use during calls to :meth:`.run` and :meth:`.evaluate`
        :type generator: DataLoader
        :param steps: The number of steps per epoch to take when using this generator; defaults to ``len(generator)``
        :type steps: int
        :return: self
        :rtype: Trial
        """
        self.state[torchbearer.VALIDATION_GENERATOR] = generator
        if steps is None:
            steps = len(generator)
        self.for_val_steps(steps)

    @fluent
    def with_val_data(self, x, y, batch_size=1, shuffle=True, num_workers=1, steps=None):
        """Wrap in-memory validation tensors in a ``DataLoader`` and use it as the validation generator. Returns self
        so that methods can be chained for convenience.

        :param x: The validation x data to use during calls to :meth:`.run` and :meth:`.evaluate`
        :type x: torch.Tensor
        :param y: The validation labels to use during calls to :meth:`.run` and :meth:`.evaluate`
        :type y: torch.Tensor
        :param batch_size: The size of each batch to sample from the data
        :type batch_size: int
        :param shuffle: If True, then data will be shuffled each epoch
        :type shuffle: bool
        :param num_workers: Number of worker threads to use in the data loader
        :type num_workers: int
        :param steps: The number of steps per epoch to take when using this data
        :type steps: int
        :return: self
        :rtype: Trial
        """
        val_loader = DataLoader(TensorDataset(x, y), batch_size, shuffle=shuffle, num_workers=num_workers)
        self.with_val_generator(val_loader, steps=steps)

    @fluent
    def for_test_steps(self, steps):
        """Set how many test steps :meth:`.predict` should run. Note that the generator will output (None, None) if it
        has not been set, which is useful for differentiable programming. Returns self so that methods can be chained
        for convenience.

        A non-int ``steps`` is cast with ``int`` (with a warning). If a test generator is set and ``steps`` exceeds
        its length, ``steps`` is clamped to that length (with a warning). The result is stored as TEST_STEPS and,
        paired with the generator, as TEST_DATA.

        :param steps: The number of test steps per epoch to run (when using :meth:`.predict`)
        :type steps: int
        :return: self
        :rtype: Trial
        """
        if not isinstance(steps, int):
            warnings.warn("Number of test steps is not an int, casting to int")
            steps = int(steps)

        loader = self.state[torchbearer.TEST_GENERATOR]
        if loader is not None:
            available = len(loader)
            if steps > available:
                warnings.warn(
                    "Number of test steps exceeds number of data items, limiting to number of items"
                )
                steps = available

        self.state[torchbearer.TEST_STEPS] = steps
        self.state[torchbearer.TEST_DATA] = (loader, steps)

    @fluent
    def with_test_generator(self, generator, steps=None):
        """Attach a test generator to this trial. Returns self so that methods can be chained for convenience.

        :param generator: The test data generator to use during calls to :meth:`.predict`
        :type generator: DataLoader
        :param steps: The number of steps per epoch to take when using this generator; defaults to ``len(generator)``
        :type steps: int
        :return: self
        :rtype: Trial
        """
        self.state[torchbearer.TEST_GENERATOR] = generator
        if steps is None:
            steps = len(generator)
        self.for_test_steps(steps)

    @fluent
    def with_test_data(self, x, batch_size=1, num_workers=1, steps=None):
        """Wrap in-memory test inputs in an unshuffled ``DataLoader`` and use it as the test generator. Returns self
        so that methods can be chained for convenience.

        :param x: The test x data to use during calls to :meth:`.predict`
        :type x: torch.Tensor
        :param batch_size: The size of each batch to sample from the data
        :type batch_size: int
        :param num_workers: Number of worker threads to use in the data loader
        :type num_workers: int
        :param steps: The number of steps per epoch to take when using this data
        :type steps: int
        :return: self
        :rtype: Trial
        """
        test_loader = DataLoader(TensorDataset(x), batch_size, num_workers=num_workers)
        self.with_test_generator(test_loader, steps=steps)

    @fluent
    def for_steps(self, train_steps=None, val_steps=None, test_steps=None):
        """Set the train, validation and/or test step counts in one call; any count left as ``None`` is not
        changed. Returns self so that methods can be chained for convenience.

        :param train_steps: The number of training steps per epoch to run
        :type train_steps: int, optional
        :param val_steps: The number of validation steps per epoch to run
        :type val_steps: int, optional
        :param test_steps: The number of test steps per epoch to run (when using :meth:`.predict`)
        :type test_steps: int, optional
        :return: self
        :rtype: Trial
        """
        requested = (
            (train_steps, self.for_train_steps),
            (val_steps, self.for_val_steps),
            (test_steps, self.for_test_steps),
        )
        for count, apply_steps in requested:
            if count is not None:
                apply_steps(count)

    @fluent
    def with_generators(self,
                        train_generator=None,
                        val_generator=None,
                        test_generator=None,
                        train_steps=None,
                        val_steps=None,
                        test_steps=None):
        """Attach any of the train, validation and test generators in one call; a generator left as ``None`` is not
        changed. Returns self so that methods can be chained for convenience.

        :param train_generator: The training data generator to use during calls to :meth:`.run`
        :type train_generator: DataLoader
        :param val_generator: The validation data generator to use during calls to :meth:`.run` and :meth:`.evaluate`
        :type val_generator: DataLoader
        :param test_generator: The testing data generator to use during calls to :meth:`.predict`
        :type test_generator: DataLoader
        :param train_steps: The number of steps per epoch to take when using the training generator
        :type train_steps: int
        :param val_steps: The number of steps per epoch to take when using the validation generator
        :type val_steps: int
        :param test_steps: The number of steps per epoch to take when using the testing generator
        :type test_steps: int
        :return: self
        :rtype: Trial
        """
        attachments = (
            (train_generator, train_steps, self.with_train_generator),
            (val_generator, val_steps, self.with_val_generator),
            (test_generator, test_steps, self.with_test_generator),
        )
        for loader, step_count, attach in attachments:
            if loader is not None:
                attach(loader, step_count)

    @inject_printer()
    def run(self, epochs=1, verbose=2):
        """Run this trial for the given number of epochs, starting from the last trained epoch.

        Each epoch runs a training pass and then a validation pass, calls ``on_end_epoch``, and appends a
        ``((train_steps, val_steps), metrics)`` tuple to the trial history.

        If ``STOP_TRAINING`` is set during the training pass, the loop exits immediately. That epoch gets no
        validation pass, no ``on_end_epoch`` call and no history entry. If it is set during the validation pass or
        ``on_end_epoch``, the epoch is still recorded before the loop exits.

        :param epochs: The total number of epochs the trial should reach; epochs already in the history are skipped
        :type epochs: int
        :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
        :type verbose: int
        :return: The trial history: one ``((train_steps, val_steps), metrics)`` tuple per completed epoch
        :rtype: list
        """
        state = State()
        state.update({
            torchbearer.MAX_EPOCHS: epochs,
            torchbearer.STOP_TRAINING: False
        })

        state.update(
            self.state
        )  # TODO: Swap this for something which makes `self.state` still mutable

        state[torchbearer.CALLBACK_LIST].on_start(state)

        # Resume counting from the number of epochs already recorded in the history.
        for state[torchbearer.EPOCH] in range(len(state[torchbearer.HISTORY]),
                                              state[torchbearer.MAX_EPOCHS]):
            state[torchbearer.CALLBACK_LIST].on_start_epoch(state)

            final_metrics = self._fit_pass(state)[torchbearer.METRICS]

            # A stop requested during training skips validation, on_end_epoch and the history entry.
            if state[torchbearer.STOP_TRAINING]:
                break

            # Merge the validation metrics into the training metrics dict.
            final_metrics.update(self._validation_pass(state))
            state[torchbearer.METRICS] = final_metrics
            state[torchbearer.CALLBACK_LIST].on_end_epoch(state)
            # Step counts are stored with the metrics so replay() can rebuild the batch loops.
            steps_summary = (state[torchbearer.TRAIN_STEPS],
                             state[torchbearer.VALIDATION_STEPS])
            self.state[torchbearer.HISTORY].append(
                (steps_summary, state[torchbearer.METRICS]))

            if state[torchbearer.STOP_TRAINING]:
                break

        state[torchbearer.CALLBACK_LIST].on_end(state)

        return self.state[torchbearer.HISTORY]

    @inject_sampler(torchbearer.TRAIN_DATA)
    def _fit_pass(self, state):
        """Run one training pass over the configured train data and return ``state``.

        Before the loop:

        * the model and metrics are put in training mode;
        * an iterator is built over ``state[torchbearer.GENERATOR]`` (``None`` if there is no generator);
        * the metric list is reset.

        For each of ``state[torchbearer.STEPS]`` batches:

        * ``state[torchbearer.SAMPLER].sample(state)`` is called;
        * a closure performing zero_grad, forward, loss and backward (with the matching callbacks) is handed to
          ``optimizer.step``;
        * the running metrics are recomputed.

        The loop ends early if a callback sets ``STOP_TRAINING``. Final metric values are merged into
        ``state[torchbearer.METRICS]`` before ``on_end_training`` fires.

        NOTE(review): GENERATOR, STEPS and SAMPLER appear to be placed into ``self.state`` by the
        ``inject_sampler(torchbearer.TRAIN_DATA)`` decorator and copied in by the ``state.update`` below — confirm.

        :param state: The run state for the current epoch
        :type state: State
        :return: The same ``state``, with METRICS holding this epoch's training metrics
        :rtype: State
        """
        state.update(
            self.state
        )  # TODO: Hack to make injection work, should be removed if `self.state` is mutable

        self.train()

        state[torchbearer.ITERATOR] = iter(
            state[torchbearer.GENERATOR]) if state[
                torchbearer.
                GENERATOR] is not None else None  # TODO: Inject this?

        state[torchbearer.METRIC_LIST].reset(state)
        state[torchbearer.METRICS] = {}

        state[torchbearer.CALLBACK_LIST].on_start_training(state)

        for state[torchbearer.BATCH] in range(0, state[torchbearer.STEPS]):
            # Presumably populates X and Y_TRUE for this batch — TODO confirm against the sampler.
            state[torchbearer.SAMPLER].sample(state)
            state[torchbearer.CALLBACK_LIST].on_sample(state)

            # Handed to optimizer.step below; the optimizer decides when to invoke it.
            def closure():
                # Zero grads
                state[torchbearer.OPTIMIZER].zero_grad()

                # Forward Pass
                if self.pass_state:
                    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](
                        state[torchbearer.X], state=state)
                else:
                    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](
                        state[torchbearer.X])

                state[torchbearer.CALLBACK_LIST].on_forward(state)

                # Loss Calculation
                state[torchbearer.LOSS] = state[torchbearer.CRITERION](
                    state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE])

                state[torchbearer.CALLBACK_LIST].on_criterion(state)

                # Backwards pass
                state[torchbearer.LOSS].backward(
                    **state[torchbearer.BACKWARD_ARGS])
                state[torchbearer.CALLBACK_LIST].on_backward(state)

            # Update parameters
            state[torchbearer.OPTIMIZER].step(closure)
            # Running (per-batch) metric values; replaced each batch.
            state[torchbearer.METRICS] = state[
                torchbearer.METRIC_LIST].process(state)
            state[torchbearer.CALLBACK_LIST].on_step_training(state)

            if state[torchbearer.STOP_TRAINING]:
                break

        # Add the end-of-pass metric values on top of the last running values.
        state[torchbearer.METRICS].update(
            state[torchbearer.METRIC_LIST].process_final(state))

        state[torchbearer.CALLBACK_LIST].on_end_training(state)
        return state

    def _test_pass(self, state):
        """Run one gradient-free evaluation pass and return ``state``.

        This is shared by the validation pass in :meth:`run`, by :meth:`evaluate` and by :meth:`predict`, and it
        always fires the ``on_*_validation`` callbacks. For each of ``state[torchbearer.STEPS]`` batches it:

        * calls the sampler;
        * runs the model forward;
        * computes the loss and running metrics, but only when ``torchbearer.Y_TRUE`` is a key in ``state``.

        The loop ends early if a callback sets ``STOP_TRAINING``. Final metric values are merged in, again only when
        ``Y_TRUE`` is present.

        NOTE(review): presumably the predict sampler does not set Y_TRUE, which is what skips loss/metrics there —
        confirm.

        :param state: A state already populated with GENERATOR, STEPS and SAMPLER for the data being evaluated
        :type state: dict
        :return: The same ``state``
        :rtype: dict
        """
        with torch.no_grad():
            state[torchbearer.ITERATOR] = iter(
                state[torchbearer.GENERATOR]) if state[
                    torchbearer.
                    GENERATOR] is not None else None  # TODO: Inject this?

            state[torchbearer.METRIC_LIST].reset(state)
            state[torchbearer.METRICS] = {}

            state[torchbearer.CALLBACK_LIST].on_start_validation(state)

            for state[torchbearer.BATCH] in range(state[torchbearer.STEPS]):
                state[torchbearer.SAMPLER].sample(state)
                state[torchbearer.CALLBACK_LIST].on_sample_validation(state)

                # Forward Pass
                if self.pass_state:
                    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](
                        state[torchbearer.X], state=state)
                else:
                    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](
                        state[torchbearer.X])

                state[torchbearer.CALLBACK_LIST].on_forward_validation(state)

                # Loss and metrics (only when targets are available)
                if torchbearer.Y_TRUE in state:
                    state[torchbearer.LOSS] = state[torchbearer.CRITERION](
                        state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE])
                    state[torchbearer.CALLBACK_LIST].on_criterion_validation(
                        state)
                    state[torchbearer.METRICS] = state[
                        torchbearer.METRIC_LIST].process(state)

                state[torchbearer.CALLBACK_LIST].on_step_validation(state)
                if state[torchbearer.STOP_TRAINING]:
                    break

            if torchbearer.Y_TRUE in state:
                state[torchbearer.METRICS].update(
                    state[torchbearer.METRIC_LIST].process_final(state))
            state[torchbearer.CALLBACK_LIST].on_end_validation(state)
        return state

    @inject_sampler(torchbearer.VALIDATION_DATA)
    def _validation_pass(self, state):
        """Run the validation phase of an epoch if any validation data or step count is configured.

        :param state: The run state for the current epoch
        :type state: State
        :return: ``state[torchbearer.METRICS]`` after the pass. If no validation is configured, this is whatever the
            training pass left there.
        :rtype: dict
        """
        state.update(
            self.state
        )  # TODO: Hack to make injection work, should be removed if `self.state` is mutable

        has_validation = (state[torchbearer.VALIDATION_GENERATOR] is not None
                          or state[torchbearer.VALIDATION_STEPS] is not None)
        if has_validation:
            self.eval()
            self._test_pass(state)

        return state[torchbearer.METRICS]

    @inject_sampler(torchbearer.VALIDATION_DATA)
    @inject_printer(validation_label_letter='e')
    def evaluate(self, verbose=2, data_key=None):  # Note: kwargs appear unused but are inspected in inject_sampler
        """Evaluate this trial on the validation data (or on the data named by ``data_key``).

        :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
        :type verbose: int
        :param data_key: Optional key for the data to evaluate on. Default: torchbearer.VALIDATION_DATA
        :type data_key: StateKey
        :return: The final metric values, or an empty dict when neither a generator nor a step count is set
        :rtype: dict
        """
        state = State()
        state.update({
            torchbearer.MAX_EPOCHS: 1,
            torchbearer.EPOCH: 0,
            torchbearer.STOP_TRAINING: False
        })
        state.update(
            self.state
        )  # TODO: Hack to make injection work, should be removed if `self.state` is mutable

        # Nothing to evaluate on.
        if state[torchbearer.GENERATOR] is None and state[torchbearer.STEPS] is None:
            return {}

        self.eval()
        return self._test_pass(state)[torchbearer.METRICS]

    @inject_callback(AggregatePredictions())
    @inject_sampler(torchbearer.TEST_DATA, predict=True)
    @inject_printer(validation_label_letter='p')
    def predict(self, verbose=2, data_key=None):  # Note: kwargs appear unused but are inspected in inject_sampler
        """Determine predictions for this trial on the test data (or on the data named by ``data_key``).

        Predictions are collected by the ``AggregatePredictions`` callback injected by this method's decorator and
        read back from ``state[torchbearer.FINAL_PREDICTIONS]`` after the pass.

        :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
        :type verbose: int
        :param data_key: Optional key for the data to predict on. Default: torchbearer.TEST_DATA
        :type data_key: StateKey
        :return: The aggregated model outputs, or an empty list when neither a generator nor a step count is set.
            ``AggregatePredictions`` concatenates the per-batch predictions, so this is typically a tensor.
        :rtype: torch.Tensor or list
        """
        # Use State (as run, evaluate and replay do) rather than a plain dict, so callbacks always see the same
        # state type.
        state = State()
        state.update({
            torchbearer.MAX_EPOCHS: 1,
            torchbearer.EPOCH: 0,
            torchbearer.STOP_TRAINING: False
        })
        state.update(
            self.state
        )  # TODO: Hack to make injection work, should be removed if `self.state` is mutable

        # Nothing to predict on.
        if state[torchbearer.GENERATOR] is None and state[torchbearer.STEPS] is None:
            return []

        self.eval()
        return self._test_pass(state)[torchbearer.FINAL_PREDICTIONS]

    @fluent
    def replay(self, callbacks=None, verbose=2):  # TODO: Should we track if testing passes have happened?
        """ Replay the fit passes stored in history with given callbacks, useful when reloading a saved Trial. Note
        that only progress and metric information is populated in state during a replay.

        A printer callback for ``verbose`` is appended to a copy of ``callbacks``. The caller's list (and the
        default) is never modified, so repeated calls do not accumulate printers.

        :param callbacks: List of callbacks to be run during the replay
        :type callbacks: list, optional
        :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
        :type verbose: int
        :return: self
        :rtype: Trial
        """
        history = self.state[torchbearer.HISTORY]

        replay_callbacks = [] if callbacks is None else list(callbacks)
        replay_callbacks.append(
            get_printer(verbose=verbose, validation_label_letter='v'))
        callback_list = CallbackList(replay_callbacks)

        state = State()
        state.update(self.state)
        state[torchbearer.STOP_TRAINING] = False
        state[torchbearer.MAX_EPOCHS] = len(history)

        callback_list.on_start(state)
        for epoch, record in enumerate(history):
            # Each history record is ((train_steps, val_steps), metrics), as appended by run().
            state[torchbearer.EPOCH] = epoch
            state[torchbearer.TRAIN_STEPS], state[
                torchbearer.VALIDATION_STEPS] = record[0]
            state[torchbearer.METRICS] = record[1]

            self._replay_pass(state, callback_list)
        callback_list.on_end(state)

    @fluent
    def _replay_pass(self, state, callback_list):
        """Drive ``callback_list`` through one recorded epoch without running the model.

        Training-phase callbacks see only the metrics whose names do not contain ``"val_"``. Validation-phase
        callbacks see only those that do. The full metric dict is restored before ``on_end_epoch``.

        A ``None`` step count is replayed as zero batches instead of raising ``TypeError`` from ``range(None)``. This
        happens, for example, with the VALIDATION_STEPS that :meth:`run` records when no validation data is set.

        :param state: Replay state holding this epoch's METRICS, TRAIN_STEPS, VALIDATION_STEPS and STOP_TRAINING
        :type state: State
        :param callback_list: The callbacks to invoke
        :type callback_list: CallbackList
        """
        callback_list.on_start_epoch(state)
        all_metrics = state[torchbearer.METRICS]

        # Training pass
        train_steps = state[torchbearer.TRAIN_STEPS]
        state[torchbearer.STEPS] = train_steps if train_steps is not None else 0
        state[torchbearer.METRICS] = {
            key: all_metrics[key]
            for key in all_metrics.keys() if "val_" not in key
        }
        callback_list.on_start_training(state)
        for state[torchbearer.BATCH] in range(state[torchbearer.STEPS]):
            callback_list.on_sample(state)
            callback_list.on_forward(state)
            callback_list.on_criterion(state)
            callback_list.on_backward(state)
            callback_list.on_step_training(state)
            if state[torchbearer.STOP_TRAINING]:
                break
        callback_list.on_end_training(state)

        # Validation pass (skipped if a callback requested a stop)
        if not state[torchbearer.STOP_TRAINING]:
            val_steps = state[torchbearer.VALIDATION_STEPS]
            state[torchbearer.STEPS] = val_steps if val_steps is not None else 0
            state[torchbearer.METRICS] = {
                key: all_metrics[key]
                for key in all_metrics.keys() if "val_" in key
            }
            callback_list.on_start_validation(state)
            for state[torchbearer.BATCH] in range(state[torchbearer.STEPS]):
                callback_list.on_sample_validation(state)
                callback_list.on_forward_validation(state)
                callback_list.on_criterion_validation(state)
                callback_list.on_step_validation(state)
                if state[torchbearer.STOP_TRAINING]:
                    break
            callback_list.on_end_validation(state)

        # End-of-epoch callbacks see the full, unfiltered metrics.
        state[torchbearer.METRICS] = all_metrics
        callback_list.on_end_epoch(state)

    @fluent
    def train(self):
        """Switch the wrapped model, then the metric list, into training mode.

        :return: self
        :rtype: Trial
        """
        for component_key in (torchbearer.MODEL, torchbearer.METRIC_LIST):
            self.state[component_key].train()

    @fluent
    def eval(self):
        """Switch the wrapped model, then the metric list, into evaluation mode.

        :return: self
        :rtype: Trial
        """
        for component_key in (torchbearer.MODEL, torchbearer.METRIC_LIST):
            self.state[component_key].eval()

    @fluent
    def to(self, *args, **kwargs):
        """ Moves and/or casts the model, every tensor in the optimizer's per-parameter state, and the device/dtype
        recorded in the trial state.

        :param args: See: `torch.nn.Module.to <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.to>`_
        :param kwargs: See: `torch.nn.Module.to <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.to>`_
        :return: self
        :rtype: Trial
        """
        self.state[torchbearer.MODEL].to(*args, **kwargs)

        optimizer = self.state[torchbearer.OPTIMIZER]
        for param_state in optimizer.state.values():
            # Collect first, then write back; only tensor entries are moved.
            tensor_entries = [(name, value) for name, value in param_state.items() if torch.is_tensor(value)]
            for name, value in tensor_entries:
                param_state[name] = value.to(*args, **kwargs)

        self.state = update_device_and_dtype(self.state, *args, **kwargs)

    @fluent
    def cuda(self, device=None):
        """ Moves the trial to a GPU via :meth:`to`.

        :param device: Index of the target GPU; defaults to ``torch.cuda.current_device()``
        :type device: int, optional
        :return: self
        :rtype: Trial
        """
        target_index = torch.cuda.current_device() if device is None else device
        self.to('cuda:' + str(target_index))

    @fluent
    def cpu(self):
        """ Moves all model parameters and buffers to the CPU.

        This is shorthand for ``self.to('cpu')``.

        :return: self
        :rtype: Trial
        """
        target = 'cpu'
        self.to(target)

    def state_dict(self, **kwargs):
        """Get a dict containing the model and optimizer states, as well as the model history.

        The returned dict is keyed by torchbearer state keys:

        - the torchbearer version, with any ``.dev`` marker stripped;
        - the model state dict (built with ``kwargs``);
        - the optimizer state dict;
        - the history, stored by reference rather than copied;
        - the callback list state dict.

        :param kwargs: See: `torch.nn.Module.state_dict <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.state_dict>`_
        :return: A dict containing parameters and persistent buffers.
        :rtype: dict
        """
        trial_state = self.state
        version = torchbearer.__version__.replace('.dev', '')
        model_state = trial_state[torchbearer.MODEL].state_dict(**kwargs)
        optimizer_state = trial_state[torchbearer.OPTIMIZER].state_dict()
        history = trial_state[torchbearer.HISTORY]
        callbacks_state = trial_state[torchbearer.CALLBACK_LIST].state_dict()

        return {
            torchbearer.VERSION: version,
            torchbearer.MODEL: model_state,
            torchbearer.OPTIMIZER: optimizer_state,
            torchbearer.HISTORY: history,
            torchbearer.CALLBACK_LIST: callbacks_state,
        }

    @fluent
    def load_state_dict(self, state_dict, resume=True, **kwargs):
        """Resume this trial from the given state. Expects that this trial was constructed in the same way. Optionally,
        just load the model state when resume=False.

        A dict containing a ``torchbearer.MODEL`` key is treated as a torchbearer state dict.

        - With ``resume=True``, the model, optimizer, history and callback list are restored from whichever of those
          keys are present. A warning is issued if the stored version differs from the running torchbearer version.
        - With ``resume=False``, only the model weights are loaded.
        - Any other dict triggers a warning and is passed directly to the model's ``load_state_dict``.

        :param state_dict: The state dict to reload
        :type state_dict: dict
        :param resume: If True, resume from the given state. Else, just load in the model weights.
        :param kwargs: See: `torch.nn.Module.load_state_dict <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.load_state_dict>`_
            (forwarded to the model only)
        :return: self
        :rtype: Trial
        """
        is_torchbearer_dict = torchbearer.MODEL in state_dict

        if resume and is_torchbearer_dict:
            current_version = torchbearer.__version__.replace('.dev', '')
            # Compare by value: a version string read back from a saved dict is never the same object as
            # current_version, so an identity check (`is not`) would warn on every single load.
            if torchbearer.VERSION in state_dict and state_dict[torchbearer.VERSION] != current_version:
                warnings.warn(
                    'This state dict was saved with a different torchbearer version, loading available keys. Consider setting resume=False'
                )

            # MODEL is guaranteed present by is_torchbearer_dict.
            self.state[torchbearer.MODEL].load_state_dict(
                state_dict[torchbearer.MODEL], **kwargs)

            if torchbearer.OPTIMIZER in state_dict:
                self.state[torchbearer.OPTIMIZER].load_state_dict(
                    state_dict[torchbearer.OPTIMIZER])

            if torchbearer.HISTORY in state_dict:
                self.state[torchbearer.HISTORY] = state_dict[
                    torchbearer.HISTORY]

            if torchbearer.CALLBACK_LIST in state_dict:
                self.state[torchbearer.CALLBACK_LIST].load_state_dict(
                    state_dict[torchbearer.CALLBACK_LIST])
        elif is_torchbearer_dict:
            # resume=False: only the model weights are restored.
            self.state[torchbearer.MODEL].load_state_dict(
                state_dict[torchbearer.MODEL], **kwargs)
        else:  # something else
            warnings.warn('Not a torchbearer state dict, passing to model')
            self.state[torchbearer.MODEL].load_state_dict(state_dict, **kwargs)