示例#1
0
    def _on_epoch_end(self):
        """Call ``self.on_epoch_end()``, then fire ``Event.ON_EPOCH_END``.

        The event wraps a ``Moment`` built from the epoch and phase read from
        ``self.epoch_and_phase`` (after the hook has run) and is fired
        together with the epoch monitors of that phase.
        """
        self.on_epoch_end()

        current_epoch, current_phase = self.epoch_and_phase
        moment = Moment(current_epoch, Frequency.EPOCH, current_phase)
        event = Event.ON_EPOCH_END(moment)

        self.fire(event, self.epoch_monitors(current_phase))
示例#2
0
    def _on_epoch_begin(self):
        """Reset the model trainers, call ``self.on_epoch_begin()``, then fire
        ``Event.ON_EPOCH_BEGIN``.

        The event wraps a ``Moment`` built from the epoch and phase read from
        ``self.epoch_and_phase``; no monitors are passed to ``fire``.
        """
        self._reset_model_trainers()
        self.on_epoch_begin()

        current_epoch, current_phase = self.epoch_and_phase
        moment = Moment(current_epoch, Frequency.EPOCH, current_phase)

        self.fire(Event.ON_EPOCH_BEGIN(moment))
示例#3
0
    def _on_batch_begin(self):
        """Call ``self.on_batch_begin()``, then fire ``Event.ON_BATCH_BEGIN``.

        The event wraps a step-frequency ``Moment`` built from the iteration
        and phase read from ``self.iteration_and_phase``; no monitors are
        passed to ``fire``.
        """
        self.on_batch_begin()

        current_iteration, current_phase = self.iteration_and_phase
        moment = Moment(current_iteration, Frequency.STEP, current_phase)

        self.fire(Event.ON_BATCH_BEGIN(moment))
示例#4
0
    def _on_batch_end(self):
        """Call ``self.on_batch_end()``, then fire ``Event.ON_BATCH_END``.

        The event wraps a step-frequency ``Moment`` built from the iteration
        and phase read from ``self.iteration_and_phase`` and is fired together
        with the step monitors of that phase.
        """
        self.on_batch_end()

        current_iteration, current_phase = self.iteration_and_phase
        moment = Moment(current_iteration, Frequency.STEP, current_phase)
        event = Event.ON_BATCH_END(moment)

        self.fire(event, self.step_monitors(current_phase))
示例#5
0
 def _on_valid_batch_end(self):
     """Call ``self.on_valid_batch_end()``, then fire
     ``Event.ON_VALID_BATCH_END``.

     The event wraps a step-frequency ``Moment`` for the validation phase
     built from ``self.current_valid_step`` and is fired together with the
     validation step monitors.
     """
     # The end-of-batch handler must run the *end* hook; the previous code
     # invoked ``on_valid_batch_begin`` here, which re-ran the begin hook and
     # never ran the end hook.
     self.on_valid_batch_end()
     self.fire(
         Event.ON_VALID_BATCH_END(
             Moment(self.current_valid_step,
                    Frequency.STEP, Phase.VALIDATION)),
         self.step_monitors(Phase.VALIDATION))
    def test_should_not_save_model_with_higher_valid_losses(
            self, model_state_mock, optimizer_states_mock):
        """Invoke a spied ``Checkpoint`` handler twice and check no file exists.

        The handler is called for two ``ON_EPOCH_END`` events at the same
        moment: first with validation losses MSELoss=0.5 / L1Loss=0.5, then
        with MSELoss=0.5 / L1Loss=0.6. Afterwards the test asserts that
        ``<SAVE_PATH>/<MODEL_NAME>/<MODEL_NAME>.tar`` does not exist.
        """
        model_state_mock.return_value = {}
        optimizer_states_mock.return_value = []
        moment = Moment(200, Frequency.EPOCH, Phase.VALIDATION)

        checkpoint = Checkpoint(self.SAVE_PATH, self.MODEL_NAME,
                                ["MSELoss", "L1Loss"], 0.01, MonitorMode.MIN)
        handler_mock = mockito.spy(checkpoint)

        def validation_monitors(mse_loss, l1_loss):
            # Only the validation phase carries losses; the training and test
            # phases have empty metric/loss dictionaries.
            return {
                self.MODEL_NAME: {
                    Phase.TRAINING: {
                        Monitor.METRICS: {},
                        Monitor.LOSS: {}
                    },
                    Phase.VALIDATION: {
                        Monitor.METRICS: {},
                        Monitor.LOSS: {
                            "MSELoss": torch.tensor([mse_loss]),
                            "L1Loss": torch.tensor([l1_loss])
                        }
                    },
                    Phase.TEST: {
                        Monitor.METRICS: {},
                        Monitor.LOSS: {}
                    }
                }
            }

        # The second pair differs from the first only in its larger L1 loss.
        for mse_loss, l1_loss in ((0.5, 0.5), (0.5, 0.6)):
            monitors = validation_monitors(mse_loss, l1_loss)
            handler_mock(TemporalEvent(Event.ON_EPOCH_END, moment), monitors,
                         self._trainer_mock)

        checkpoint_file = os.path.join(self.SAVE_PATH, self.MODEL_NAME,
                                       self.MODEL_NAME + ".tar")
        assert_that(not os.path.exists(checkpoint_file))
示例#7
0
 def _on_test_epoch_end(self):
     """Reset batch counters, call ``self.on_test_epoch_end()``, then fire
     ``Event.ON_TEST_EPOCH_END``.

     The event wraps an epoch-frequency ``Moment`` for the test phase built
     from ``self.epoch`` and is fired together with the test epoch monitors.
     """
     # NOTE(review): the validation batch counter is reset here as well as
     # the test one; kept as-is, confirm that this is intended.
     self._current_valid_batch = 0
     self._current_test_batch = 0
     self.on_test_epoch_end()
     # Epoch-level events use Frequency.EPOCH, matching the other
     # *_epoch_begin / *_epoch_end handlers (this one used Frequency.STEP).
     self.fire(
         Event.ON_TEST_EPOCH_END(
             Moment(self.epoch, Frequency.EPOCH, Phase.TEST)),
         self.epoch_monitors(Phase.TEST))
示例#8
0
 def _on_valid_batch_begin(self):
     """Call ``self.on_valid_batch_begin()``, then fire
     ``Event.ON_VALID_BATCH_BEGIN``.

     The event wraps a step-frequency ``Moment`` for the validation phase
     built from ``self.current_valid_step``; no monitors are passed.
     """
     self.on_valid_batch_begin()
     moment = Moment(self.current_valid_step, Frequency.STEP,
                     Phase.VALIDATION)
     self.fire(Event.ON_VALID_BATCH_BEGIN(moment))
示例#9
0
 def _on_train_batch_begin(self):
     """Call ``self.on_train_batch_begin()``, then fire
     ``Event.ON_TRAIN_BATCH_BEGIN``.

     The event wraps a step-frequency ``Moment`` for the training phase
     built from ``self.current_train_step``; no monitors are passed.
     """
     self.on_train_batch_begin()
     moment = Moment(self.current_train_step, Frequency.STEP,
                     Phase.TRAINING)
     self.fire(Event.ON_TRAIN_BATCH_BEGIN(moment))
示例#10
0
 def _on_train_batch_end(self):
     """Call ``self.on_train_batch_end()``, then fire
     ``Event.ON_TRAIN_BATCH_END``.

     The event wraps a step-frequency ``Moment`` for the training phase
     built from ``self.current_train_step`` and is fired together with the
     training step monitors.
     """
     self.on_train_batch_end()
     moment = Moment(self.current_train_step, Frequency.STEP,
                     Phase.TRAINING)
     event = Event.ON_TRAIN_BATCH_END(moment)
     self.fire(event, self.step_monitors(Phase.TRAINING))
示例#11
0
 def _on_test_batch_end(self):
     """Call ``self.on_test_batch_end()``, then fire
     ``Event.ON_TEST_BATCH_END``.

     The event wraps a step-frequency ``Moment`` for the test phase built
     from ``self.current_test_step`` and is fired together with the test
     step monitors.
     """
     self.on_test_batch_end()
     moment = Moment(self.current_test_step, Frequency.STEP, Phase.TEST)
     event = Event.ON_TEST_BATCH_END(moment)
     self.fire(event, self.step_monitors(Phase.TEST))
示例#12
0
 def _on_test_batch_begin(self):
     """Call ``self.on_test_batch_begin()``, then fire
     ``Event.ON_TEST_BATCH_BEGIN``.

     The event wraps a step-frequency ``Moment`` for the test phase built
     from ``self.current_test_step``; no monitors are passed.
     """
     self.on_test_batch_begin()
     moment = Moment(self.current_test_step, Frequency.STEP, Phase.TEST)
     self.fire(Event.ON_TEST_BATCH_BEGIN(moment))
示例#13
0
 def _on_test_begin(self):
     """Call ``self.on_test_begin()``, set the status to ``Status.TESTING``,
     then fire ``Event.ON_TEST_BEGIN``.

     The event wraps a phase-frequency ``Moment`` with index 0 for the test
     phase. The status is updated only after the hook has run.
     """
     self.on_test_begin()
     self._status = Status.TESTING
     moment = Moment(0, Frequency.PHASE, Phase.TEST)
     self.fire(Event.ON_TEST_BEGIN(moment))
示例#14
0
 def _on_test_end(self):
     """Call ``self.on_test_end()``, then fire ``Event.ON_TEST_END``.

     The event wraps a phase-frequency ``Moment`` with index 0 for the test
     phase; no monitors are passed.
     """
     self.on_test_end()
     moment = Moment(0, Frequency.PHASE, Phase.TEST)
     self.fire(Event.ON_TEST_END(moment))
示例#15
0
 def _on_valid_begin(self):
     """Call ``self.on_valid_begin()``, set the status to
     ``Status.VALIDATING``, then fire ``Event.ON_VALID_BEGIN``.

     The event wraps a phase-frequency ``Moment`` with index 0 for the
     validation phase. The status is updated only after the hook has run.
     """
     self.on_valid_begin()
     self._status = Status.VALIDATING
     moment = Moment(0, Frequency.PHASE, Phase.VALIDATION)
     self.fire(Event.ON_VALID_BEGIN(moment))
示例#16
0
 def _on_valid_end(self):
     """Call ``self.on_valid_end()``, then fire ``Event.ON_VALID_END``.

     The event wraps a phase-frequency ``Moment`` with index 0 for the
     validation phase; no monitors are passed.
     """
     self.on_valid_end()
     moment = Moment(0, Frequency.PHASE, Phase.VALIDATION)
     self.fire(Event.ON_VALID_END(moment))
示例#17
0
 def _on_training_begin(self):
     """Call ``self.on_training_begin()``, set the status to
     ``Status.TRAINING``, then fire ``Event.ON_TRAINING_BEGIN``.

     The event wraps a phase-frequency ``Moment`` with index 0 for the
     training phase. The status is updated only after the hook has run.
     """
     self.on_training_begin()
     self._status = Status.TRAINING
     moment = Moment(0, Frequency.PHASE, Phase.TRAINING)
     self.fire(Event.ON_TRAINING_BEGIN(moment))
示例#18
0
 def _on_training_end(self):
     """Call ``self.on_training_end()`` and ``self.scheduler_step()``, then
     fire ``Event.ON_TRAINING_END``.

     The event wraps a phase-frequency ``Moment`` with index 0 for the
     training phase; no monitors are passed.
     """
     self.on_training_end()
     self.scheduler_step()
     moment = Moment(0, Frequency.PHASE, Phase.TRAINING)
     self.fire(Event.ON_TRAINING_END(moment))
示例#19
0
 def _on_train_epoch_begin(self):
     """Set the status to ``Status.TRAINING``, call
     ``self.on_train_epoch_begin()``, then fire
     ``Event.ON_TRAIN_EPOCH_BEGIN``.

     The event wraps an epoch-frequency ``Moment`` for the training phase
     built from ``self.epoch``. Here the status is set before the hook runs.
     """
     self._status = Status.TRAINING
     self.on_train_epoch_begin()
     moment = Moment(self.epoch, Frequency.EPOCH, Phase.TRAINING)
     self.fire(Event.ON_TRAIN_EPOCH_BEGIN(moment))
示例#20
0
 def _on_test_epoch_begin(self):
     """Call ``self.on_test_epoch_begin()``, then fire
     ``Event.ON_TEST_EPOCH_BEGIN``.

     The event wraps an epoch-frequency ``Moment`` for the test phase built
     from ``self.epoch``; no monitors are passed.
     """
     self.on_test_epoch_begin()
     moment = Moment(self.epoch, Frequency.EPOCH, Phase.TEST)
     self.fire(Event.ON_TEST_EPOCH_BEGIN(moment))
示例#21
0
 def _on_valid_epoch_end(self):
     """Call ``self.on_valid_epoch_end()``, then fire
     ``Event.ON_VALID_EPOCH_END``.

     The event wraps an epoch-frequency ``Moment`` for the validation phase
     built from ``self.epoch`` and is fired together with the validation
     epoch monitors.
     """
     self.on_valid_epoch_end()
     moment = Moment(self.epoch, Frequency.EPOCH, Phase.VALIDATION)
     event = Event.ON_VALID_EPOCH_END(moment)
     self.fire(event, self.epoch_monitors(Phase.VALIDATION))
示例#22
0
 def _on_valid_epoch_begin(self):
     """Call ``self.on_valid_epoch_begin()``, then fire
     ``Event.ON_VALID_EPOCH_BEGIN``.

     The event wraps an epoch-frequency ``Moment`` for the validation phase
     built from ``self.epoch``; no monitors are passed.
     """
     self.on_valid_epoch_begin()
     moment = Moment(self.epoch, Frequency.EPOCH, Phase.VALIDATION)
     self.fire(Event.ON_VALID_EPOCH_BEGIN(moment))
示例#23
0
 def _on_train_epoch_end(self):
     """Call ``self.on_train_epoch_end()``, then fire
     ``Event.ON_TRAIN_EPOCH_END``.

     The event wraps an epoch-frequency ``Moment`` for the training phase
     built from ``self.epoch`` and is fired together with the training
     epoch monitors.
     """
     self.on_train_epoch_end()
     moment = Moment(self.epoch, Frequency.EPOCH, Phase.TRAINING)
     event = Event.ON_TRAIN_EPOCH_END(moment)
     self.fire(event, self.epoch_monitors(Phase.TRAINING))