Example #1
    def evaluate_generator(self,
                           generator,
                           verbose=1,
                           steps=None,
                           pass_state=False):
        """ Perform an evaluation loop on given data generator to evaluate metrics

        :param generator: The evaluation data generator (usually a pytorch DataLoader)
        :type generator: DataLoader
        :param verbose: If 1 use tqdm progress frontend, else display no training progress
        :type verbose: int
        :param steps: The number of evaluation mini-batches to run
        :type steps: int
        :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
        :type pass_state: bool
        :return: The dictionary containing final metrics
        :rtype: dict[str,any]
        """

        state = {
            torchbearer.EPOCH: 0,
            torchbearer.MAX_EPOCHS: 1,
            torchbearer.STOP_TRAINING: False,
            torchbearer.VALIDATION_GENERATOR: generator
        }
        state.update(self.main_state)

        # Attach the tqdm printer when verbose ('e' labels the bar as evaluation)
        _callbacks = []
        if verbose == 1:
            _callbacks.append(Tqdm('e'))
        self._test_loop(state, CallbackList(_callbacks), pass_state,
                        self._load_batch_standard, steps)

        return state[torchbearer.METRICS]
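
For context, a minimal usage sketch of evaluate_generator. Everything below (the toy network, optimizer, criterion, and data) is an illustrative assumption, not part of the source; it assumes the torchbearer.Model wrapper that these methods belong to:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    import torchbearer

    # Hypothetical toy setup for illustration only
    net = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    loader = DataLoader(TensorDataset(torch.randn(64, 10),
                                      torch.randint(0, 2, (64,))),
                        batch_size=16)

    model = torchbearer.Model(net, optimizer, criterion, metrics=['acc', 'loss'])
    metrics = model.evaluate_generator(loader, verbose=1)  # final metrics dict
    print(metrics)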
Example #2
    def predict_generator(self,
                          generator,
                          verbose=1,
                          steps=None,
                          pass_state=False):
        """Perform a prediction loop on given data generator to predict labels

        :param generator: The prediction data generator (usually a pytorch DataLoader)
        :type generator: DataLoader
        :param verbose: If 1 use tqdm progress frontend, else display no training progress
        :type verbose: int
        :param steps: The number of evaluation mini-batches to run
        :type steps: int
        :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
        :type pass_state: bool
        :return: Tensor of final predicted labels
        :rtype: torch.Tensor
        """
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.MAX_EPOCHS: 1,
            torchbearer.STOP_TRAINING: False,
            torchbearer.VALIDATION_GENERATOR: generator
        }
        state.update(self.main_state)

        # Aggregate per-batch predictions into one tensor; attach tqdm when verbose
        # ('p' labels the bar as prediction)
        _callbacks = [AggregatePredictions()]
        if verbose == 1:
            _callbacks.append(Tqdm('p'))
        self._test_loop(state, CallbackList(_callbacks), pass_state,
                        self._load_batch_predict, steps)

        return state[torchbearer.FINAL_PREDICTIONS]
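
A similarly hedged sketch of predict_generator, reusing the illustrative model and loader from the sketch above; the AggregatePredictions callback concatenates the per-batch outputs into the single tensor that is returned:

    # Continuing the toy setup above (assumed, not from the source)
    predictions = model.predict_generator(loader, verbose=1)
    print(predictions.shape)  # e.g. torch.Size([64, 2]) for the toy network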
Example #3
    def _add_printer(callbacks, verbose, validation_label_letter='v'):
        """Static method used to add the printer callback to the given list for the given verbose level

        :param callbacks: The list to add to
        :type callbacks: list
        :param verbose: 2, 1 or 0, Most -> Least verbose
        :type verbose: int
        :param validation_label_letter: Pass to Tqdm
        :type validation_label_letter: str
        :return: The updated list
        :rtype: list
        """
        if verbose >= 2:
            return [Tqdm(validation_label_letter=validation_label_letter)] + callbacks
        elif verbose >= 1:
            return [Tqdm(validation_label_letter=validation_label_letter, on_epoch=True)] + callbacks
        else:
            return callbacks
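
A small sketch of the three verbose levels, assuming _add_printer is a static method on the same Model class as the other examples (the enclosing class is not visible in this excerpt):

    import torchbearer
    from torchbearer.callbacks import Callback

    user_callbacks = [Callback()]  # stand-in for any user callbacks

    # verbose=2 prepends a batch-level Tqdm bar,
    # verbose=1 prepends an epoch-level Tqdm bar (on_epoch=True),
    # verbose=0 returns the list unchanged.
    with_batch_bar = torchbearer.Model._add_printer(user_callbacks, 2)
    with_epoch_bar = torchbearer.Model._add_printer(user_callbacks, 1)
    unchanged = torchbearer.Model._add_printer(user_callbacks, 0)

Being underscore-prefixed, this is an internal helper; user code would normally just pass verbose to the public fit/evaluate/predict methods and let them attach the printer.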
Example #4
    def fit_generator(self,
                      generator,
                      train_steps=None,
                      epochs=1,
                      verbose=1,
                      callbacks=[],
                      validation_generator=None,
                      validation_steps=None,
                      initial_epoch=0,
                      pass_state=False):
        """ Perform fitting of a model to given data generator

        :param generator: The training data generator (usually a pytorch DataLoader)
        :type generator: DataLoader
        :param train_steps: The number of training mini-batches to run per epoch
        :type train_steps: int
        :param epochs: The number of training epochs to be run (each sample from the dataset is viewed exactly once)
        :type epochs: int
        :param verbose: If 1 use tqdm progress frontend, else display no training progress
        :type verbose: int
        :param callbacks: The list of torchbearer callbacks to be called during training and validation
        :type callbacks: list
        :param validation_generator: The validation data generator (usually a pytorch DataLoader)
        :type validation_generator: DataLoader
        :param validation_steps: The number of validation mini-batches to run per epoch
        :type validation_steps: int
        :param initial_epoch: The integer value representing the first epoch - useful for continuing training after a number of epochs
        :type initial_epoch: int
        :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
        :type pass_state: bool
        :return: The final state context dictionary
        :rtype: dict[str,any]
        """
        if verbose == 1:
            callbacks = [Tqdm()] + callbacks
        _callbacks = CallbackList(callbacks)

        # Get train and validation steps
        if validation_steps is None and validation_generator is not None:
            validation_steps = len(validation_generator)
        if train_steps is None or train_steps > len(generator):
            train_steps = len(generator)
        if not isinstance(train_steps, int):
            train_steps = int(train_steps)
            warnings.warn(
                "Number of training steps is not an int, converting to int")

        if not isinstance(epochs, int):
            if isinstance(epochs, float):
                epochs = int(epochs)
                warnings.warn("Number of epochs is a float, converting to int")
            else:
                warnings.warn(
                    "Number of epochs is neither float nor int, setting to 0")
                epochs = 0

        # Init state
        state = {
            torchbearer.MAX_EPOCHS: epochs,
            torchbearer.TRAIN_STEPS: train_steps,
            torchbearer.BATCH: 0,
            torchbearer.GENERATOR: generator,
            torchbearer.STOP_TRAINING: False
        }
        state.update(self.main_state)

        _callbacks.on_start(state)

        for state[torchbearer.EPOCH] in range(initial_epoch, epochs):
            _callbacks.on_start_epoch(state)

            state[torchbearer.TRAIN_ITERATOR] = iter(
                state[torchbearer.GENERATOR])
            self.train()

            _callbacks.on_start_training(state)
            state[torchbearer.METRIC_LIST].reset(state)
            state[torchbearer.METRICS] = {}

            for state[torchbearer.BATCH] in range(
                    0, state[torchbearer.TRAIN_STEPS]):
                # Extract batch
                self._load_batch_standard('train', state)
                _callbacks.on_sample(state)

                # Zero grads
                state[torchbearer.OPTIMIZER].zero_grad()

                # Forward pass
                if pass_state:
                    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](
                        state[torchbearer.X], state=state)
                else:
                    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](
                        state[torchbearer.X])
                _callbacks.on_forward(state)

                # Loss Calculation
                state[torchbearer.LOSS] = state[torchbearer.CRITERION](
                    state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE])

                _callbacks.on_criterion(state)
                state[torchbearer.METRICS] = state[
                    torchbearer.METRIC_LIST].process(state)

                # Backwards pass
                state[torchbearer.LOSS].backward()
                _callbacks.on_backward(state)

                # Update parameters
                state[torchbearer.OPTIMIZER].step()
                _callbacks.on_step_training(state)

                if state[torchbearer.STOP_TRAINING]:
                    break

            state[torchbearer.METRICS].update(
                state[torchbearer.METRIC_LIST].process_final(state))
            final_metrics = state[torchbearer.METRICS]

            _callbacks.on_end_training(state)

            # Validate
            if validation_generator is not None:
                state[torchbearer.VALIDATION_GENERATOR] = validation_generator
                state[torchbearer.VALIDATION_STEPS] = validation_steps
                self.eval()
                self._validate(state, _callbacks, pass_state)

            final_metrics.update(state[torchbearer.METRICS])
            state[torchbearer.METRICS] = final_metrics
            _callbacks.on_end_epoch(state)

            if state[torchbearer.STOP_TRAINING]:
                break
        _callbacks.on_end(state)

        return state
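
Finally, a hedged end-to-end sketch of fit_generator; as before, the network, data, and hyperparameters are illustrative assumptions rather than part of the source:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    import torchbearer

    # Hypothetical training and validation data
    train_loader = DataLoader(TensorDataset(torch.randn(256, 10),
                                            torch.randint(0, 2, (256,))),
                              batch_size=32)
    val_loader = DataLoader(TensorDataset(torch.randn(64, 10),
                                          torch.randint(0, 2, (64,))),
                            batch_size=32)

    net = nn.Linear(10, 2)
    model = torchbearer.Model(net,
                              torch.optim.SGD(net.parameters(), lr=0.01),
                              nn.CrossEntropyLoss(),
                              metrics=['acc', 'loss'])

    # Train for 5 epochs, validating after each; the final state dict is returned
    final_state = model.fit_generator(train_loader,
                                      epochs=5,
                                      verbose=1,
                                      validation_generator=val_loader)
    print(final_state[torchbearer.METRICS])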