def __init__(self, training_update_function, validation_inference_function=None):
    """
    Store the per-batch callables and reset the trainer's bookkeeping state.

    Parameters
    ----------
    training_update_function : callable
        stored as-is on the instance; not invoked here
    validation_inference_function : callable, optional
        stored as-is on the instance; may be None [default=None]
    """
    self._logger = self._get_logger()

    # user-supplied callables
    self._training_update_function = training_update_function
    self._validation_inference_function = validation_inference_function

    # event -> registered handlers; starts out empty
    self._event_handlers = dict()

    # one History per phase
    self.training_history = History()
    self.validation_history = History()

    # every counter starts at zero
    self.current_iteration = self.current_validation_iteration = 0
    self.current_epoch = self.max_epochs = 0

    self.should_terminate = False
def test_history_clear():
    """After five appends followed by clear(), the history must have length 0."""
    recorded = History()
    for value in range(5):
        recorded.append(value)

    recorded.clear()

    assert len(recorded) == 0
def test_exponential_moving_average():
    """
    Check exponential_moving_average over the values 1, 3, 4 with window 3,
    once on raw numbers and once on dict entries read through `transform`.
    """
    alpha = 0.3
    # Both scenarios share the same reference value: the newest entry (4) has
    # weight 1 and each older entry is discounted by one more factor (1 - alpha).
    expected_epa = (4 + (1 - alpha) * 3 + (1 - alpha)**2 * 1) / (1 + (1 - alpha) + (1 - alpha)**2)

    # plain numeric entries
    scalar_history = History()
    for value in (1, 3, 4):
        scalar_history.append(value)
    epa = scalar_history.exponential_moving_average(window_size=3, alpha=alpha)
    np.testing.assert_almost_equal(epa, expected_epa)

    # dict entries; only the 'loss' field takes part in the average
    dict_history = History()
    for entry in ({'loss': 1, 'other': 2},
                  {'loss': 3, 'other': 6},
                  {'loss': 4, 'other': 5}):
        dict_history.append(entry)
    epa = dict_history.exponential_moving_average(window_size=3,
                                                  alpha=alpha,
                                                  transform=lambda entry: entry['loss'])
    np.testing.assert_almost_equal(epa, expected_epa)
def test_weighted_moving_average():
    """
    Check weighted_moving_average over the values 1, 3, 4 with weights
    0.6, 0.8, 1.0, once on raw numbers and once on dict entries read
    through `transform`.
    """
    weights = [0.6, 0.8, 1.0]
    # weighted sum of the three values divided by the sum of the weights
    expected_wma = (0.6 + (0.8 * 3) + 4) / 2.4

    # plain numeric entries
    scalar_history = History()
    for value in (1, 3, 4):
        scalar_history.append(value)
    wma = scalar_history.weighted_moving_average(window_size=3, weights=weights)
    np.testing.assert_almost_equal(wma, expected_wma)

    # dict entries; only the 'loss' field takes part in the average
    dict_history = History()
    for entry in ({'loss': 1, 'other': 2},
                  {'loss': 3, 'other': 6},
                  {'loss': 4, 'other': 5}):
        dict_history.append(entry)
    wma = dict_history.weighted_moving_average(window_size=3,
                                               weights=[0.6, 0.8, 1.0],
                                               transform=lambda entry: entry['loss'])
    np.testing.assert_almost_equal(wma, expected_wma)
def test_history_simple_moving_average():
    """
    Check simple_moving_average over the values 1, 3, 4 with window 3,
    once on raw numbers and once on dict entries read through `transform`.
    """
    expected_sma = 8 / 3.0

    # plain numeric entries
    scalar_history = History()
    for value in (1, 3, 4):
        scalar_history.append(value)
    sma = scalar_history.simple_moving_average(window_size=3)
    np.testing.assert_almost_equal(sma, expected_sma)

    # dict entries; only the 'loss' field takes part in the average
    dict_history = History()
    for entry in ({'loss': 1, 'other': 2},
                  {'loss': 3, 'other': 6},
                  {'loss': 4, 'other': 5}):
        dict_history.append(entry)
    sma = dict_history.simple_moving_average(window_size=3,
                                             transform=lambda entry: entry['loss'])
    np.testing.assert_almost_equal(sma, expected_sma)
class Trainer(object):
    """
    Generic trainer class.

    Training update and validation functions receive batches of data and return values which will
    be stored in the `training_history` and `validation_history`. The trainer defines multiple
    events in `TrainingEvents` for which the user can attach event handlers to. The events get
    passed the trainer, so they can access the training/validation history

    Parameters
    ----------
    training_update_function : callable
        Update function receiving the current training batch in each iteration

    validation_inference_function : callable
        Function receiving data and performing a feed forward without update
    """

    def __init__(self, training_update_function, validation_inference_function=None):
        self._logger = self._get_logger()
        self._training_update_function = training_update_function
        self._validation_inference_function = validation_inference_function
        # maps an event to a list of (handler, args, kwargs) tuples
        self._event_handlers = {}

        self.training_history = History()
        self.validation_history = History()
        self.current_iteration = 0
        self.current_validation_iteration = 0
        self.current_epoch = 0
        self.max_epochs = 0
        self.should_terminate = False

    def _get_logger(self):
        """Return the logger named ``<module>.<ClassName>`` with a NullHandler attached."""
        logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
        # logging.getLogger returns the same object for the same name, so attach
        # the NullHandler only once instead of stacking one per call/instance.
        if not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
            logger.addHandler(logging.NullHandler())
        return logger

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        """
        Add an event handler to be executed when the specified event is fired

        Parameters
        ----------
        event_name: enum
            event from ignite.trainer.TrainingEvents to attach the
            handler to
        handler: Callable
            the callable event handler that should be invoked; it is called
            with the trainer as first argument, followed by `args` and `kwargs`
        args:
            optional args to be passed to `handler`
        kwargs:
            optional keyword args to be passed to `handler`

        Returns
        -------
        None

        Raises
        ------
        ValueError
            if `event_name` is not a member of `TrainingEvents`
        """
        if event_name not in TrainingEvents.__members__.values():
            self._logger.error(
                "attempt to add event handler to non-existent event %s ", event_name)
            raise ValueError(
                "Event {} not a valid training event".format(event_name))

        self._event_handlers.setdefault(event_name, []).append((handler, args, kwargs))
        self._logger.debug("added handler for event %s ", event_name)

    def on(self, event_name, *args, **kwargs):
        """
        Decorator shortcut for add_event_handler

        Parameters
        ----------
        event_name: enum
            event from ignite.trainer.TrainingEvents to attach the
            handler to
        args:
            optional args to be passed to `handler`
        kwargs:
            optional keyword args to be passed to `handler`

        Returns
        -------
        callable
            a decorator that registers the decorated function as a handler for
            `event_name` and returns that function unchanged
        """
        def decorator(f):
            self.add_event_handler(event_name, f, *args, **kwargs)
            return f
        return decorator

    def _fire_event(self, event_name):
        """Call every handler registered for `event_name`, in registration order."""
        if event_name in self._event_handlers:
            self._logger.debug("firing handlers for event %s ", event_name)
            for func, args, kwargs in self._event_handlers[event_name]:
                func(self, *args, **kwargs)

    def _train_one_epoch(self, training_data):
        """
        Run the training update function over every batch of `training_data`.

        Non-None results are appended to `training_history`. If
        `should_terminate` is set after an iteration, the method returns
        immediately, without logging the epoch time and without firing
        TRAINING_EPOCH_COMPLETED.
        """
        self._fire_event(TrainingEvents.TRAINING_EPOCH_STARTED)
        start_time = time.time()

        # reset at the start of every epoch; nothing in this class reads it
        self.epoch_losses = []
        for batch in training_data:
            self._fire_event(TrainingEvents.TRAINING_ITERATION_STARTED)

            training_step_result = self._training_update_function(batch)
            if training_step_result is not None:
                self.training_history.append(training_step_result)

            self.current_iteration += 1

            self._fire_event(TrainingEvents.TRAINING_ITERATION_COMPLETED)
            if self.should_terminate:
                return

        time_taken = time.time() - start_time
        hours, mins, secs = _to_hours_mins_secs(time_taken)
        self._logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d", self.current_epoch, hours,
                          mins, secs)

        self._fire_event(TrainingEvents.TRAINING_EPOCH_COMPLETED)

    def validate(self, validation_data):
        """
        Evaluates the validation set

        Parameters
        ----------
        validation_data : Iterable
            batches passed one by one to the validation inference function;
            non-None results are appended to `validation_history`

        Returns
        -------
        None

        Raises
        ------
        ValueError
            if the trainer was created without a validation_inference_function
        """
        if self._validation_inference_function is None:
            raise ValueError(
                "Trainer must have a validation_inference_function in order to validate"
            )

        self.current_validation_iteration = 0
        self._fire_event(TrainingEvents.VALIDATION_STARTING)
        start_time = time.time()

        for batch in validation_data:
            self._fire_event(TrainingEvents.VALIDATION_ITERATION_STARTED)
            validation_step_result = self._validation_inference_function(batch)
            if validation_step_result is not None:
                self.validation_history.append(validation_step_result)
            self.current_validation_iteration += 1
            self._fire_event(TrainingEvents.VALIDATION_ITERATION_COMPLETED)
            # unlike training, an early stop here still logs the time and
            # fires VALIDATION_COMPLETED below
            if self.should_terminate:
                break

        time_taken = time.time() - start_time
        hours, mins, secs = _to_hours_mins_secs(time_taken)
        self._logger.info("Validation Complete. Time taken: %02d:%02d:%02d", hours, mins, secs)

        self._fire_event(TrainingEvents.VALIDATION_COMPLETED)

    def terminate(self):
        """
        Sends terminate signal to trainer, so that training terminates after the current iteration
        """
        self._logger.info(
            "Terminate signaled to trainer. "
            "Training will stop after current iteration is finished")
        self.should_terminate = True

    def run(self, training_data, max_epochs=1):
        """
        Train for up to `max_epochs` epochs over `training_data`.

        Fires TRAINING_STARTED, then EPOCH_STARTED / EPOCH_COMPLETED around
        each epoch, and finally TRAINING_COMPLETED. The loop stops early once
        `should_terminate` is set; in that case EPOCH_COMPLETED is not fired
        for the interrupted epoch. This method does not run validation itself.

        Parameters
        ----------
        training_data : Iterable
            Collection of training batches allowing repeated iteration (e.g., list or DataLoader)
        max_epochs: int, optional
            max epochs to train for [default=1]

        Returns
        -------
        None

        Raises
        ------
        BaseException
            any exception raised while training is logged, EXCEPTION_RAISED is
            fired, and the exception is then re-raised
        """
        try:
            self._logger.info("Training starting with max_epochs=%s", max_epochs)

            self.max_epochs = max_epochs

            start_time = time.time()

            self._fire_event(TrainingEvents.TRAINING_STARTED)
            while self.current_epoch < max_epochs and not self.should_terminate:
                self._fire_event(TrainingEvents.EPOCH_STARTED)
                self._train_one_epoch(training_data)
                if self.should_terminate:
                    break

                self._fire_event(TrainingEvents.EPOCH_COMPLETED)
                self.current_epoch += 1

            self._fire_event(TrainingEvents.TRAINING_COMPLETED)
            time_taken = time.time() - start_time
            # same helper the per-epoch and validation timing logs use
            hours, mins, secs = _to_hours_mins_secs(time_taken)
            self._logger.info("Training complete. Time taken %02d:%02d:%02d", hours, mins, secs)
        except BaseException as e:
            # BaseException also covers KeyboardInterrupt/SystemExit; nothing is
            # swallowed because the exception is always re-raised below.
            self._logger.error("Training is terminating due to exception: %s", str(e))
            self._fire_event(TrainingEvents.EXCEPTION_RAISED)
            # bare raise keeps the original traceback intact
            raise
def test_mean_squared_error():
    """
    mean_squared_error must give ~0.25 for two fixed (y_pred, y) batches,
    both when stored as tuples and when stored as dicts read via `transform`.
    """
    def make_batches():
        # fresh tensors on every call so the two scenarios share no objects
        return [
            (torch.FloatTensor([[4.5], [4.0]]), torch.FloatTensor([5.0, 3.5])),
            (torch.FloatTensor([[3.5], [3.0]]), torch.FloatTensor([3.0, 3.5])),
        ]

    # entries stored directly as (y_pred, y) tuples
    tuple_history = History()
    for batch in make_batches():
        tuple_history.append(batch)
    result = mean_squared_error(tuple_history)
    assert result == approx(0.25)

    # entries stored as dicts and unpacked by the transform
    dict_history = History()
    for y_pred, y in make_batches():
        dict_history.append({'y_pred': y_pred, 'y': y})
    result = mean_squared_error(dict_history,
                                transform=lambda entry: (entry['y_pred'], entry['y']))
    assert result == approx(0.25)
def test_binary_accuracy():
    """
    binary_accuracy must give ~0.5 for two fixed (y_pred, y) batches,
    both when stored as tuples and when stored as dicts read via `transform`.
    """
    def make_batches():
        # fresh tensors on every call so the two scenarios share no objects
        return [
            (torch.FloatTensor([[0.8], [0.6]]), torch.LongTensor([1, 0])),
            (torch.FloatTensor([[0.4], [0.2]]), torch.LongTensor([1, 0])),
        ]

    # entries stored directly as (y_pred, y) tuples
    tuple_history = History()
    for batch in make_batches():
        tuple_history.append(batch)
    result = binary_accuracy(tuple_history)
    assert result == approx(0.5)

    # entries stored as dicts and unpacked by the transform
    dict_history = History()
    for y_pred, y in make_batches():
        dict_history.append({'y_pred': y_pred, 'y': y})
    result = binary_accuracy(dict_history,
                             transform=lambda entry: (entry['y_pred'], entry['y']))
    assert result == approx(0.5)
def test_top_k_categorical_accuracy():
    """
    top_k_categorical_accuracy must give ~1.0 for k=2 and ~0.5 for k=1 on two
    fixed batches, and the k=1 result must match categorical_accuracy. Checked
    for tuple entries and for dict entries read via `transform`.
    """
    def make_batches():
        # fresh tensors on every call so the two scenarios share no objects
        return [
            (torch.FloatTensor([[3, 2, 1], [6, 5, 4]]), torch.LongTensor([1, 0])),
            (torch.FloatTensor([[9, 8, 7], [12, 11, 10]]), torch.LongTensor([1, 0])),
        ]

    def as_pair(entry):
        return (entry['y_pred'], entry['y'])

    # entries stored directly as (y_pred, y) tuples
    tuple_history = History()
    for batch in make_batches():
        tuple_history.append(batch)

    top_2_result = top_k_categorical_accuracy(tuple_history, k=2)
    top_1_result = top_k_categorical_accuracy(tuple_history, k=1)
    assert top_2_result == approx(1.0)
    assert top_1_result == approx(0.5)
    assert top_1_result == approx(categorical_accuracy(tuple_history))

    # entries stored as dicts and unpacked by the transform
    dict_history = History()
    for y_pred, y in make_batches():
        dict_history.append({'y_pred': y_pred, 'y': y})

    top_2_result = top_k_categorical_accuracy(dict_history, k=2, transform=as_pair)
    top_1_result = top_k_categorical_accuracy(dict_history, k=1, transform=as_pair)
    assert top_2_result == approx(1.0)
    assert top_1_result == approx(0.5)
    assert top_1_result == approx(
        categorical_accuracy(dict_history, transform=as_pair))
def test_categorical_accuracy():
    """
    categorical_accuracy must give ~0.5 for two fixed (y_pred, y) batches,
    both when stored as tuples and when stored as dicts read via `transform`.
    """
    def make_batches():
        # fresh tensors on every call so the two scenarios share no objects
        return [
            (torch.FloatTensor([[2, 1], [4, 3]]), torch.LongTensor([1, 0])),
            (torch.FloatTensor([[6, 5], [8, 7]]), torch.LongTensor([1, 0])),
        ]

    # entries stored directly as (y_pred, y) tuples
    tuple_history = History()
    for batch in make_batches():
        tuple_history.append(batch)
    result = categorical_accuracy(tuple_history)
    assert result == approx(0.5)

    # entries stored as dicts and unpacked by the transform
    dict_history = History()
    for y_pred, y in make_batches():
        dict_history.append({'y_pred': y_pred, 'y': y})
    result = categorical_accuracy(dict_history,
                                  transform=lambda entry: (entry['y_pred'], entry['y']))
    assert result == approx(0.5)