Example 1
    def predict_generator(self,
                          generator,
                          verbose=2,
                          steps=None,
                          pass_state=False):
        """Perform a prediction loop on given data generator to predict labels

        :param generator: The prediction data generator (usually a pytorch DataLoader)
        :type generator: DataLoader
        :param verbose: If 2: use tqdm on batch, if 1: use tqdm on epoch, else: display no progress
        :type verbose: int
        :param steps: The number of evaluation mini-batches to run
        :type steps: int
        :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
        :type pass_state: bool
        :return: Tensor of final predicted labels
        :rtype: torch.Tensor
        """
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.MAX_EPOCHS: 1,
            torchbearer.STOP_TRAINING: False,
            torchbearer.VALIDATION_GENERATOR: generator
        }
        state.update(self.main_state)

        _callbacks = Model._add_printer([AggregatePredictions()],
                                        verbose,
                                        validation_label_letter='p')

        self._test_loop(state, CallbackList(_callbacks), pass_state,
                        self._load_batch_predict, steps)

        return state[torchbearer.FINAL_PREDICTIONS]
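To ground the signature above, here is a minimal usage sketch. The toy network, data, and loader are illustrative assumptions, not part of the snippet; only the torchbearer Model wrapper and the predict_generator call reflect the API shown:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchbearer import Model

# Illustrative data and network -- any nn.Module / DataLoader pair will do.
inputs = torch.rand(100, 3)
targets = torch.randint(0, 2, (100,))
net = torch.nn.Linear(3, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

model = Model(net, optimizer, torch.nn.CrossEntropyLoss(), metrics=['loss'])

# A standard (input, target) loader; targets are loaded but are not needed
# for the returned predictions.
loader = DataLoader(TensorDataset(inputs, targets), batch_size=10)
predictions = model.predict_generator(loader, verbose=2)
print(predictions.shape)  # stacked model outputs, one row per input sample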
Example 2
    def evaluate_generator(self, generator, verbose=2, steps=None, pass_state=False):
        """ Perform an evaluation loop on given data generator to evaluate metrics

        :param generator: The evaluation data generator (usually a pytorch DataLoader)
        :type generator: DataLoader
        :param verbose: If 2: use tqdm on batch, if 1: use tqdm on epoch, else: display no progress
        :type verbose: int
        :param steps: The number of evaluation mini-batches to run
        :type steps: int
        :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
        :type pass_state: bool
        :return: The dictionary containing final metrics
        :rtype: dict[str,any]
        """

        state = {torchbearer.EPOCH: 0, torchbearer.MAX_EPOCHS: 1, torchbearer.STOP_TRAINING: False, torchbearer.VALIDATION_GENERATOR: generator}
        state.update(self.main_state)

        _callbacks = Model._add_printer([], verbose, validation_label_letter='e')

        if state[torchbearer.VALIDATION_GENERATOR] is None:
            batch_loader = self._load_batch_none
        else:
            batch_loader = self._load_batch_standard

        self._test_loop(state, CallbackList(_callbacks), pass_state, batch_loader, steps)

        return state[torchbearer.METRICS]
Example 3
    def evaluate_generator(self,
                           generator,
                           verbose=1,
                           steps=None,
                           pass_state=False):
        """ Perform an evaluation loop on given data generator to evaluate metrics

        :param generator: The evaluation data generator (usually a pytorch DataLoader)
        :type generator: DataLoader
        :param verbose: If 1: use the tqdm progress frontend, else: display no progress
        :type verbose: int
        :param steps: The number of evaluation mini-batches to run
        :type steps: int
        :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
        :type pass_state: bool
        :return: The dictionary containing final metrics
        :rtype: dict[str,any]
        """

        state = {
            torchbearer.EPOCH: 0,
            torchbearer.MAX_EPOCHS: 1,
            torchbearer.STOP_TRAINING: False,
            torchbearer.VALIDATION_GENERATOR: generator
        }
        state.update(self.main_state)

        _callbacks = []
        if verbose == 1:
            _callbacks.append(Tqdm('e'))
        self._test_loop(state, CallbackList(_callbacks), pass_state,
                        self._load_batch_standard, steps)

        return state[torchbearer.METRICS]
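Evaluation follows the same pattern; continuing the sketch under Example 1 (the exact keys of the returned dictionary depend on the metrics the Model was configured with):

# Reuses `model` and `loader` from the sketch under Example 1.
# Pass verbose=1 instead when using the Example 3 variant of this method.
metrics = model.evaluate_generator(loader, verbose=2)
print(metrics)  # a dict of final metric values, keyed by metric name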
Example 4
    def fit_generator(self,
                      generator,
                      train_steps=None,
                      epochs=1,
                      verbose=2,
                      callbacks=[],
                      validation_generator=None,
                      validation_steps=None,
                      initial_epoch=0,
                      pass_state=False):
        """ Perform fitting of a model to given data generator

        :param generator: The training data generator (usually a pytorch DataLoader)
        :type generator: DataLoader
        :param train_steps: The number of training mini-batches to run per epoch
        :type train_steps: int
        :param epochs: The number of training epochs to be run (each sample from the dataset is viewed exactly once per epoch)
        :type epochs: int
        :param verbose: If 2: use tqdm on batch, if 1: use tqdm on epoch, else: display no training progress
        :type verbose: int
        :param callbacks: The list of torchbearer callbacks to be called during training and validation
        :type callbacks: list
        :param validation_generator: The validation data generator (usually a pytorch DataLoader)
        :type validation_generator: DataLoader
        :param validation_steps: The number of validation mini-batches to run per epoch
        :type validation_steps: int
        :param initial_epoch: The integer value representing the first epoch - useful for continuing training after a number of epochs
        :type initial_epoch: int
        :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
        :type pass_state: bool
        :return: The final state context dictionary
        :rtype: dict[str,any]
        """
        callbacks = Model._add_printer(callbacks, verbose)
        _callbacks = CallbackList(callbacks)

        # Get train and validation steps
        if validation_steps is None and validation_generator is not None:
            validation_steps = len(validation_generator)
        if train_steps is None:
            train_steps = len(generator)
        if generator is not None and train_steps > len(generator):
            train_steps = len(generator)
        if not isinstance(train_steps, int):
            train_steps = int(train_steps)
            warnings.warn(
                "Number of training steps is not an int, converting to int")

        if not isinstance(epochs, int):
            if isinstance(epochs, float):
                epochs = int(epochs)
                warnings.warn("Number of epochs is a float, converting to int")
            else:
                warnings.warn(
                    "Number of epochs is neither float nor int, setting to 0")
                epochs = 0

        # Init state
        state = {
            torchbearer.MAX_EPOCHS: epochs,
            torchbearer.TRAIN_STEPS: train_steps,
            torchbearer.BATCH: 0,
            torchbearer.GENERATOR: generator,
            torchbearer.STOP_TRAINING: False
        }
        state.update(self.main_state)
        state[torchbearer.CALLBACK_LIST] = state[
            torchbearer.CALLBACK_LIST].copy()
        state[torchbearer.CALLBACK_LIST].append(_callbacks)

        state[torchbearer.CALLBACK_LIST].on_start(state)

        for state[torchbearer.EPOCH] in range(initial_epoch, epochs):
            state[torchbearer.CALLBACK_LIST].on_start_epoch(state)

            if state[torchbearer.GENERATOR] is not None:
                state[torchbearer.TRAIN_ITERATOR] = iter(
                    state[torchbearer.GENERATOR])
            self.train()

            state[torchbearer.CALLBACK_LIST].on_start_training(state)
            state[torchbearer.METRIC_LIST].reset(state)
            state[torchbearer.METRICS] = {}

            for state[torchbearer.BATCH] in range(0, state[torchbearer.TRAIN_STEPS]):
                # Extract batch
                # TODO: Replace with flag check
                if state[torchbearer.GENERATOR] is None:
                    self._load_batch_none('train', state)
                else:
                    self._load_batch_standard('train', state)

                state[torchbearer.CALLBACK_LIST].on_sample(state)

                # Zero grads
                state[torchbearer.OPTIMIZER].zero_grad()

                # Forward pass
                if pass_state:
                    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](
                        state[torchbearer.X], state=state)
                else:
                    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](
                        state[torchbearer.X])
                state[torchbearer.CALLBACK_LIST].on_forward(state)

                # Loss Calculation
                state[torchbearer.LOSS] = state[torchbearer.CRITERION](
                    state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE])

                state[torchbearer.CALLBACK_LIST].on_criterion(state)
                state[torchbearer.METRICS] = state[
                    torchbearer.METRIC_LIST].process(state)

                # Backwards pass
                state[torchbearer.LOSS].backward()
                state[torchbearer.CALLBACK_LIST].on_backward(state)

                # Update parameters
                state[torchbearer.OPTIMIZER].step()
                state[torchbearer.CALLBACK_LIST].on_step_training(state)

                if state[torchbearer.STOP_TRAINING]:
                    break

            state[torchbearer.METRICS].update(
                state[torchbearer.METRIC_LIST].process_final(state))
            final_metrics = state[torchbearer.METRICS]

            state[torchbearer.CALLBACK_LIST].on_end_training(state)

            # Validate
            if validation_generator is not None or validation_steps is not None:
                state[torchbearer.VALIDATION_GENERATOR] = validation_generator
                state[torchbearer.VALIDATION_STEPS] = validation_steps
                self.eval()
                self._validate(state, state[torchbearer.CALLBACK_LIST],
                               pass_state)

            final_metrics.update(state[torchbearer.METRICS])
            state[torchbearer.METRICS] = final_metrics
            state[torchbearer.CALLBACK_LIST].on_end_epoch(state)

            if state[torchbearer.STOP_TRAINING]:
                break
        state[torchbearer.CALLBACK_LIST].on_end(state)

        return state
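A corresponding training sketch, again with illustrative data and network; only the fit_generator call and its keyword arguments come from the method above:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchbearer import Model

# Toy regression data; any nn.Module / DataLoader pair will do.
x, y = torch.rand(128, 3), torch.rand(128, 1)
train_loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(x, y), batch_size=16)

net = torch.nn.Linear(3, 1)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
model = Model(net, optimizer, torch.nn.MSELoss(), metrics=['loss'])

# Runs the loop above: tqdm per batch (verbose=2) and a validation pass
# at the end of every epoch; the final state dictionary is returned.
final_state = model.fit_generator(train_loader,
                                  epochs=5,
                                  verbose=2,
                                  validation_generator=val_loader)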
Example 5
    def fit_generator(self, generator, train_steps=None, epochs=1, verbose=1, callbacks=[],
                      validation_generator=None, validation_steps=None, initial_epoch=0, pass_state=False):
        """ Perform fitting of a model to given data generator

        :param generator: The training data generator (usually a pytorch DataLoader)
        :type generator: DataLoader
        :param train_steps: The number of training mini-batches to run per epoch
        :type train_steps: int
        :param epochs: The number of training epochs to be run (each sample from the dataset is viewed exactly once per epoch)
        :type epochs: int
        :param verbose: If 1: use the tqdm progress frontend, else: display no training progress
        :type verbose: int
        :param callbacks: The list of torchbearer callbacks to be called during training and validation
        :type callbacks: list
        :param validation_generator: The validation data generator (usually a pytorch DataLoader)
        :type validation_generator: DataLoader
        :param validation_steps: The number of validation mini-batches to run per epoch
        :type validation_steps: int
        :param initial_epoch: The integer value representing the first epoch - useful for continuing training after a number of epochs
        :type initial_epoch: int
        :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
        :type pass_state: bool
        :return: The final state context dictionary
        :rtype: dict[str,any]
        """
        if verbose == 1:
            callbacks = [Tqdm()] + callbacks
        _callbacks = CallbackList(callbacks)

        # Get train and validation steps
        if validation_steps is None and validation_generator is not None:
            validation_steps = len(validation_generator)
        if train_steps is None or train_steps > len(generator):
            train_steps = len(generator)
        if not isinstance(train_steps, int):
            train_steps = int(train_steps)
            warnings.warn("Number of training steps is not an int, converting to int")

        if not isinstance(epochs, int):
            if isinstance(epochs, float):
                epochs = int(epochs)
                warnings.warn("Number of epochs is a float, converting to int")
            else:
                warnings.warn("Number of epochs is neither float nor int, setting to 0")
                epochs = 0

        # Init state
        state = {
            torchbearer.MAX_EPOCHS: epochs,
            torchbearer.TRAIN_STEPS: train_steps,
            torchbearer.BATCH: 0,
            torchbearer.GENERATOR: generator,
            torchbearer.STOP_TRAINING: False
        }
        state.update(self.main_state)

        _callbacks.on_start(state)

        for state[torchbearer.EPOCH] in range(initial_epoch, epochs):
            _callbacks.on_start_epoch(state)

            state[torchbearer.TRAIN_ITERATOR] = iter(state[torchbearer.GENERATOR])
            self.train()

            _callbacks.on_start_training(state)
            state[torchbearer.METRIC_LIST].reset(state)
            state[torchbearer.METRICS] = {}

            for state[torchbearer.BATCH] in range(0, state[torchbearer.TRAIN_STEPS]):
                # Extract batch
                self._load_batch_standard('train', state)
                _callbacks.on_sample(state)

                # Zero grads
                state[torchbearer.OPTIMIZER].zero_grad()

                # Forward pass
                if pass_state:
                    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](state[torchbearer.X], state=state)
                else:
                    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](state[torchbearer.X])
                _callbacks.on_forward(state)

                # Loss Calculation
                state[torchbearer.LOSS] = state[torchbearer.CRITERION](state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE])

                _callbacks.on_criterion(state)
                state[torchbearer.METRICS] = state[torchbearer.METRIC_LIST].process(state)

                # Backwards pass
                state[torchbearer.LOSS].backward()
                _callbacks.on_backward(state)

                # Update parameters
                state[torchbearer.OPTIMIZER].step()
                _callbacks.on_step_training(state)

                if state[torchbearer.STOP_TRAINING]:
                    break

            state[torchbearer.METRICS].update(state[torchbearer.METRIC_LIST].process_final(state))
            final_metrics = state[torchbearer.METRICS]

            _callbacks.on_end_training(state)

            # Validate
            if validation_generator is not None:
                state[torchbearer.VALIDATION_GENERATOR] = validation_generator
                state[torchbearer.VALIDATION_STEPS] = validation_steps
                self.eval()
                self._validate(state, _callbacks, pass_state)

            final_metrics.update(state[torchbearer.METRICS])
            state[torchbearer.METRICS] = final_metrics
            _callbacks.on_end_epoch(state)

            if state[torchbearer.STOP_TRAINING]:
                break
        _callbacks.on_end(state)

        return state
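For this older variant the call is identical apart from the verbose convention; continuing the sketch under Example 4:

# verbose=1 simply enables the tqdm frontend; this version has no
# per-batch / per-epoch distinction.
final_state = model.fit_generator(train_loader, epochs=5, verbose=1,
                                  validation_generator=val_loader)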