Example #1
0
    def __init__(self,
                 training_update_function,
                 validation_inference_function=None):
        """
        Store the step functions and reset all bookkeeping state.

        Parameters
        ----------
        training_update_function : callable
            stored and later used as the per-batch training step
        validation_inference_function : callable, optional
            stored as the per-batch validation step [default=None]
        """
        # Private collaborators.
        self._logger = self._get_logger()
        self._training_update_function = training_update_function
        self._validation_inference_function = validation_inference_function
        self._event_handlers = {}

        # One independent history per phase.
        self.training_history, self.validation_history = History(), History()

        # All counters start from zero; no termination requested yet.
        self.current_iteration = 0
        self.current_validation_iteration = 0
        self.current_epoch = 0
        self.max_epochs = 0
        self.should_terminate = False
Example #2
0
def test_history_clear():
    """A history that received five entries reports length 0 after clear()."""
    history = History()
    for value in (0, 1, 2, 3, 4):
        history.append(value)

    history.clear()

    assert len(history) == 0
Example #3
0
def test_exponential_moving_average():
    """
    Check exponential_moving_average against a hand-computed value, once for
    plain numeric entries and once for dict entries read through `transform`.
    """
    alpha = 0.3
    # Expected value over the entries 1, 3, 4 with the newest entry weighted 1.
    expected_epa = (4 + (1 - alpha) * 3 +
                    (1 - alpha)**2 * 1) / (1 + (1 - alpha) + (1 - alpha)**2)

    # Case 1: entries are bare numbers.
    numeric_history = History()
    for value in (1, 3, 4):
        numeric_history.append(value)
    epa = numeric_history.exponential_moving_average(window_size=3,
                                                     alpha=alpha)
    np.testing.assert_almost_equal(epa, expected_epa)

    # Case 2: entries are dicts; only the 'loss' field feeds the average.
    dict_history = History()
    for loss, other in ((1, 2), (3, 6), (4, 5)):
        dict_history.append({'loss': loss, 'other': other})
    epa = dict_history.exponential_moving_average(
        window_size=3,
        alpha=alpha,
        transform=lambda entry: entry['loss'])
    np.testing.assert_almost_equal(epa, expected_epa)
Example #4
0
def test_weighted_moving_average():
    """
    Check weighted_moving_average against a hand-computed value, once for
    plain numeric entries and once for dict entries read through `transform`.
    """
    weights = (0.6, 0.8, 1.0)
    expected_wma = (0.6 + (0.8 * 3) + 4) / 2.4

    # Case 1: entries are bare numbers.
    numeric_history = History()
    for value in (1, 3, 4):
        numeric_history.append(value)
    wma = numeric_history.weighted_moving_average(window_size=3,
                                                  weights=list(weights))
    np.testing.assert_almost_equal(wma, expected_wma)

    # Case 2: entries are dicts; only the 'loss' field feeds the average.
    dict_history = History()
    for loss, other in ((1, 2), (3, 6), (4, 5)):
        dict_history.append({'loss': loss, 'other': other})
    wma = dict_history.weighted_moving_average(
        window_size=3,
        weights=list(weights),
        transform=lambda entry: entry['loss'])
    np.testing.assert_almost_equal(wma, expected_wma)
Example #5
0
def test_history_simple_moving_average():
    """
    Check simple_moving_average over entries 1, 3, 4 equals 8/3, once for
    plain numeric entries and once for dict entries read through `transform`.
    """
    expected_sma = 8 / 3.0

    # Case 1: entries are bare numbers.
    numeric_history = History()
    for value in (1, 3, 4):
        numeric_history.append(value)
    sma = numeric_history.simple_moving_average(window_size=3)
    np.testing.assert_almost_equal(sma, expected_sma)

    # Case 2: entries are dicts; only the 'loss' field feeds the average.
    dict_history = History()
    for loss, other in ((1, 2), (3, 6), (4, 5)):
        dict_history.append({'loss': loss, 'other': other})
    sma = dict_history.simple_moving_average(
        window_size=3, transform=lambda entry: entry['loss'])
    np.testing.assert_almost_equal(sma, expected_sma)
Example #6
0
class Trainer(object):
    """
    Generic trainer class.

    Training update and validation functions receive batches of data and return values which will
    be stored in the `training_history` and `validation_history`. The trainer defines multiple
    events in `TrainingEvents` to which the user can attach event handlers. Handlers are called
    with the trainer as their first argument, so they can access the training/validation history.


    Parameters
    ----------
    training_update_function : callable
        Update function receiving the current training batch in each iteration

    validation_inference_function : callable
        Function receiving data and performing a feed forward without update
    """
    def __init__(self,
                 training_update_function,
                 validation_inference_function=None):

        self._logger = self._get_logger()
        self._training_update_function = training_update_function
        self._validation_inference_function = validation_inference_function
        # Maps an event to a list of (handler, args, kwargs) tuples.
        self._event_handlers = {}

        self.training_history = History()
        self.validation_history = History()
        self.current_iteration = 0
        self.current_validation_iteration = 0
        self.current_epoch = 0
        self.max_epochs = 0
        self.should_terminate = False

    def _get_logger(self):
        """
        Return the logger named `<module>.<class name>`.

        A `NullHandler` is attached only if the logger does not have one yet,
        so creating many trainers does not pile up handlers on the same
        (process-wide) logger object.
        """
        logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
        has_null_handler = any(
            isinstance(existing, logging.NullHandler)
            for existing in logger.handlers)
        if not has_null_handler:
            logger.addHandler(logging.NullHandler())
        return logger

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        """
        Add an event handler to be executed when the specified event is fired

        Parameters
        ----------
        event_name: enum
            event from ignite.trainer.TrainingEvents to attach the
            handler to
        handler: Callable
            the callable event handler that should be invoked; it is called
            as `handler(trainer, *args, **kwargs)`
        args:
            optional args to be passed to `handler`
        kwargs:
            optional keyword args to be passed to `handler`

        Returns
        -------
        None

        Raises
        ------
        ValueError
            if `event_name` is not a member of `TrainingEvents`
        """
        if event_name not in TrainingEvents.__members__.values():
            self._logger.error(
                "attempt to add event handler to non-existent event %s ",
                event_name)
            raise ValueError(
                "Event {} not a valid training event".format(event_name))

        self._event_handlers.setdefault(event_name, []).append(
            (handler, args, kwargs))
        # The placeholder must be a complete "%s": a bare "%" makes the
        # message fail to format when the record is emitted.
        self._logger.debug("added handler for event %s ", event_name)

    def on(self, event_name, *args, **kwargs):
        """
        Decorator shortcut for add_event_handler

        Parameters
        ----------
        event_name: enum
            event from ignite.trainer.TrainingEvents to attach the
            handler to
        args:
            optional args to be passed to `handler`
        kwargs:
            optional keyword args to be passed to `handler`

        Returns
        -------
        Callable
            a decorator that registers the decorated function as a handler
            for `event_name` and returns the function unchanged
        """
        def decorator(f):
            self.add_event_handler(event_name, f, *args, **kwargs)
            return f

        return decorator

    def _fire_event(self, event_name):
        """
        Call every handler registered for `event_name`, in registration
        order, as `handler(self, *args, **kwargs)`. Events without registered
        handlers are ignored.
        """
        if event_name in self._event_handlers:
            self._logger.debug("firing handlers for event %s ", event_name)
            for func, args, kwargs in self._event_handlers[event_name]:
                func(self, *args, **kwargs)

    def _train_one_epoch(self, training_data):
        """
        Run the training update function over every batch of `training_data`.

        Non-None step results are appended to `training_history`. If
        `should_terminate` is set after an iteration, the method returns
        immediately, without logging the epoch time or firing
        TRAINING_EPOCH_COMPLETED.
        """
        self._fire_event(TrainingEvents.TRAINING_EPOCH_STARTED)
        start_time = time.time()

        # Reset at the start of every epoch; nothing in this class fills it.
        # NOTE(review): kept because external handlers may read it - confirm.
        self.epoch_losses = []
        for batch in training_data:
            self._fire_event(TrainingEvents.TRAINING_ITERATION_STARTED)

            training_step_result = self._training_update_function(batch)
            if training_step_result is not None:
                self.training_history.append(training_step_result)

            self.current_iteration += 1

            self._fire_event(TrainingEvents.TRAINING_ITERATION_COMPLETED)
            if self.should_terminate:
                return

        time_taken = time.time() - start_time
        hours, mins, secs = _to_hours_mins_secs(time_taken)
        self._logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d",
                          self.current_epoch, hours, mins, secs)

        self._fire_event(TrainingEvents.TRAINING_EPOCH_COMPLETED)

    def validate(self, validation_data):
        """
        Evaluates the validation set

        Runs the validation inference function over every batch of
        `validation_data`, appending non-None results to `validation_history`.
        `current_validation_iteration` is reset to 0 first. If
        `should_terminate` is set, the loop stops early, but the elapsed time
        is still logged and VALIDATION_COMPLETED is still fired.

        Parameters
        ----------
        validation_data : Iterable
            Collection of validation batches

        Returns
        -------
        None

        Raises
        ------
        ValueError
            if the trainer was created without a validation_inference_function
        """
        if self._validation_inference_function is None:
            raise ValueError(
                "Trainer must have a validation_inference_function in order to validate"
            )

        self.current_validation_iteration = 0
        self._fire_event(TrainingEvents.VALIDATION_STARTING)
        start_time = time.time()

        for batch in validation_data:
            self._fire_event(TrainingEvents.VALIDATION_ITERATION_STARTED)
            validation_step_result = self._validation_inference_function(batch)
            if validation_step_result is not None:
                self.validation_history.append(validation_step_result)

            self.current_validation_iteration += 1
            self._fire_event(TrainingEvents.VALIDATION_ITERATION_COMPLETED)
            if self.should_terminate:
                break

        time_taken = time.time() - start_time
        hours, mins, secs = _to_hours_mins_secs(time_taken)
        self._logger.info("Validation Complete. Time taken: %02d:%02d:%02d",
                          hours, mins, secs)

        self._fire_event(TrainingEvents.VALIDATION_COMPLETED)

    def terminate(self):
        """
        Sends terminate signal to trainer, so that training terminates after the current iteration
        """
        self._logger.info(
            "Terminate signaled to trainer. " +
            "Training will stop after current iteration is finished")
        self.should_terminate = True

    def run(self, training_data, max_epochs=1):
        """
        Train on `training_data` until `current_epoch` reaches `max_epochs` or
        termination is requested.

        Each epoch fires EPOCH_STARTED, trains on every batch, then fires
        EPOCH_COMPLETED and increments `current_epoch`. This method does not
        run validation itself; call `validate` from an event handler for that.
        `current_epoch` is not reset here, so a second call continues from the
        epoch count already reached.

        Any exception (including KeyboardInterrupt, since BaseException is
        caught) is logged, EXCEPTION_RAISED is fired, and the exception is
        re-raised.

        Parameters
        ----------
        training_data : Iterable
            Collection of training batches allowing repeated iteration (e.g., list or DataLoader)
        max_epochs: int, optional
            max epochs to train for [default=1]

        Returns
        -------
        None
        """

        try:
            self._logger.info("Training starting with max_epochs=%s",
                              max_epochs)

            self.max_epochs = max_epochs

            start_time = time.time()

            self._fire_event(TrainingEvents.TRAINING_STARTED)
            while self.current_epoch < max_epochs and not self.should_terminate:
                self._fire_event(TrainingEvents.EPOCH_STARTED)
                self._train_one_epoch(training_data)
                if self.should_terminate:
                    # Terminated mid-epoch: skip EPOCH_COMPLETED and leave
                    # current_epoch unchanged.
                    break

                self._fire_event(TrainingEvents.EPOCH_COMPLETED)
                self.current_epoch += 1

            self._fire_event(TrainingEvents.TRAINING_COMPLETED)
            time_taken = time.time() - start_time
            mins, secs = divmod(time_taken, 60)
            hours, mins = divmod(mins, 60)
            self._logger.info("Training complete. Time taken %02d:%02d:%02d",
                              hours, mins, secs)
        except BaseException as e:
            self._logger.error("Training is terminating due to exception: %s",
                               str(e))
            self._fire_event(TrainingEvents.EXCEPTION_RAISED)
            # Bare raise re-raises the active exception with its traceback.
            raise
Example #7
0
def test_mean_squared_error():
    """
    Check mean_squared_error yields 0.25 over two batches, once with
    (y_pred, y) tuples and once with dict entries unpacked by `transform`.
    """
    def make_batches():
        # Fresh tensors on every call so the two cases share no state.
        return [
            (torch.FloatTensor([[4.5], [4.0]]), torch.FloatTensor([5.0, 3.5])),
            (torch.FloatTensor([[3.5], [3.0]]), torch.FloatTensor([3.0, 3.5])),
        ]

    # Case 1: entries are (y_pred, y) tuples.
    tuple_history = History()
    for pair in make_batches():
        tuple_history.append(pair)
    result = mean_squared_error(tuple_history)
    assert result == approx(0.25)

    # Case 2: entries are dicts; `transform` turns them back into tuples.
    dict_history = History()
    for y_pred, y in make_batches():
        dict_history.append({'y_pred': y_pred, 'y': y})
    result = mean_squared_error(
        dict_history, transform=lambda entry: (entry['y_pred'], entry['y']))
    assert result == approx(0.25)
Example #8
0
def test_binary_accuracy():
    """
    Check binary_accuracy yields 0.5 over two batches, once with
    (y_pred, y) tuples and once with dict entries unpacked by `transform`.
    """
    def make_batches():
        # Fresh tensors on every call so the two cases share no state.
        return [
            (torch.FloatTensor([[0.8], [0.6]]), torch.LongTensor([1, 0])),
            (torch.FloatTensor([[0.4], [0.2]]), torch.LongTensor([1, 0])),
        ]

    # Case 1: entries are (y_pred, y) tuples.
    tuple_history = History()
    for pair in make_batches():
        tuple_history.append(pair)
    result = binary_accuracy(tuple_history)
    assert result == approx(0.5)

    # Case 2: entries are dicts; `transform` turns them back into tuples.
    dict_history = History()
    for y_pred, y in make_batches():
        dict_history.append({'y_pred': y_pred, 'y': y})
    result = binary_accuracy(
        dict_history, transform=lambda entry: (entry['y_pred'], entry['y']))

    assert result == approx(0.5)
Example #9
0
def test_top_k_categorical_accuracy():
    """
    Check top_k_categorical_accuracy for k=2 (expects 1.0) and k=1 (expects
    0.5, and agreement with categorical_accuracy), once with (y_pred, y)
    tuples and once with dict entries unpacked by `transform`.
    """
    def make_batches():
        # Fresh tensors on every call so the two cases share no state.
        return [
            (torch.FloatTensor([[3, 2, 1], [6, 5, 4]]),
             torch.LongTensor([1, 0])),
            (torch.FloatTensor([[9, 8, 7], [12, 11, 10]]),
             torch.LongTensor([1, 0])),
        ]

    def as_pair(entry):
        return (entry['y_pred'], entry['y'])

    # Case 1: entries are (y_pred, y) tuples.
    tuple_history = History()
    for pair in make_batches():
        tuple_history.append(pair)
    top_2_result = top_k_categorical_accuracy(tuple_history, k=2)
    top_1_result = top_k_categorical_accuracy(tuple_history, k=1)
    assert top_2_result == approx(1.0)
    assert top_1_result == approx(0.5)
    assert top_1_result == approx(categorical_accuracy(tuple_history))

    # Case 2: entries are dicts; `transform` turns them back into tuples.
    dict_history = History()
    for y_pred, y in make_batches():
        dict_history.append({'y_pred': y_pred, 'y': y})
    top_2_result = top_k_categorical_accuracy(dict_history,
                                              k=2,
                                              transform=as_pair)
    top_1_result = top_k_categorical_accuracy(dict_history,
                                              k=1,
                                              transform=as_pair)
    assert top_2_result == approx(1.0)
    assert top_1_result == approx(0.5)
    assert top_1_result == approx(
        categorical_accuracy(dict_history, transform=as_pair))
Example #10
0
def test_categorical_accuracy():
    """
    Check categorical_accuracy yields 0.5 over two batches, once with
    (y_pred, y) tuples and once with dict entries unpacked by `transform`.
    """
    def make_batches():
        # Fresh tensors on every call so the two cases share no state.
        return [
            (torch.FloatTensor([[2, 1], [4, 3]]), torch.LongTensor([1, 0])),
            (torch.FloatTensor([[6, 5], [8, 7]]), torch.LongTensor([1, 0])),
        ]

    # Case 1: entries are (y_pred, y) tuples.
    tuple_history = History()
    for pair in make_batches():
        tuple_history.append(pair)
    result = categorical_accuracy(tuple_history)
    assert result == approx(0.5)

    # Case 2: entries are dicts; `transform` turns them back into tuples.
    dict_history = History()
    for y_pred, y in make_batches():
        dict_history.append({'y_pred': y_pred, 'y': y})
    result = categorical_accuracy(
        dict_history, transform=lambda entry: (entry['y_pred'], entry['y']))
    assert result == approx(0.5)