예제 #1
0
    def __init__(self, trainer: Trainer):
        """Initialize the detector and bind it to ``trainer``'s lifecycle.

        Registers a ``BEST_STATE_ACHIEVED`` event owned by this instance and
        resets accumulated state when the trainer signals ``TRAIN_DONE``.

        :param trainer: trainer whose training-finished event clears this detector
        """
        self._rules = []
        self._prev_states = None
        # Event this detector fires when a new best state is reached.
        self._best_state_achieved_event = events_container.add_event(
            "BEST_STATE_ACHIEVED", Event(self))

        # Drop accumulated comparison state once a training run completes.
        train_done_event = events_container.event(trainer, 'TRAIN_DONE')
        train_done_event.add_callback(lambda t: self.reset())
예제 #2
0
    def test_saving_states(self):
        """End-to-end check that CheckpointsManager writes both the 'last' and
        the 'best' checkpoint files, and that loading each restores the
        trainer to the epoch at which it was saved."""
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        metrics_processor = MetricsProcessor()
        # Stage fed with 20 random (data, target) pairs per epoch.
        stage = TrainStage(
            TestDataProducer([{
                'data': torch.rand(1, 3),
                'target': torch.rand(1)
            } for _ in list(range(20))]))

        class Losses:
            # Overwrites the stage's real losses with a fixed schedule so the
            # "best" epoch is deterministic: per-epoch means are 5, 4, 0, 2, 1,
            # which makes epoch index 2 (mean 0) the best state.
            def __init__(self):
                self.v = []
                self._fake_losses = [[i for _ in list(range(20))]
                                     for i in [5, 4, 0, 2, 1]]

            def on_stage_end(self, s: TrainStage):
                # Replace this epoch's accumulated losses, consume one entry
                # of the schedule, and record the resulting mean.
                s._losses = self._fake_losses[0]
                del self._fake_losses[0]
                self.v.append(np.mean(s.get_losses()))

        losses = Losses()
        events_container.event(stage,
                               'EPOCH_END').add_callback(losses.on_stage_end)

        trainer = Trainer(
            BaseTrainConfig(model, [stage], SimpleLoss(),
                            torch.optim.SGD(model.parameters(), lr=0.1)),
            fsm).set_epoch_num(5)
        metrics_processor.subscribe_to_stage(stage)

        # Expected on-disk locations for the regular and the best checkpoint.
        checkpoint_file = os.path.join(self.base_dir, 'checkpoints', 'last',
                                       'last_checkpoint.zip')
        best_checkpoint_file = os.path.join(self.base_dir, 'checkpoints',
                                            'best', 'best_checkpoint.zip')

        # 'cm' saves after every epoch; 'best_cm' saves only when the
        # BestStateDetector fires BEST_STATE_ACHIEVED.
        cm = CheckpointsManager(fsm).subscribe2trainer(trainer)
        best_cm = CheckpointsManager(fsm, prefix='best')
        bsd = BestStateDetector(trainer).subscribe2stage(stage).add_rule(
            lambda: np.mean(stage.get_losses()))
        events_container.event(bsd, 'BEST_STATE_ACHIEVED').add_callback(
            lambda b: best_cm.save_trainer_state(trainer))

        trainer.train()

        # Best state was reached at epoch index 2 (the fake mean loss of 0);
        # later epochs (means 2 and 1) never improve on it.
        self.assertTrue(os.path.exists(best_checkpoint_file))
        best_cm.load_trainer_state(trainer)
        self.assertEqual(2, trainer.cur_epoch_id() - 1)

        # The regular manager keeps the final epoch (index 4 of 5).
        self.assertTrue(os.path.exists(checkpoint_file))
        cm.load_trainer_state(trainer)
        self.assertEqual(4, trainer.cur_epoch_id() - 1)
예제 #3
0
    def test_events(self):
        """Verify EPOCH_START / EPOCH_END callbacks fire with the expected
        subjects, that stage losses are populated while the stage-level
        EPOCH_END runs, and that they are cleared by the time the
        trainer-level EPOCH_END fires (and after training ends)."""
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        # Stage fed with 20 random (data, target) pairs per epoch.
        stage = TrainStage(
            TestDataProducer([{
                'data': torch.rand(1, 3),
                'target': torch.rand(1)
            } for _ in list(range(20))]))
        trainer = Trainer(
            BaseTrainConfig(model, [stage], SimpleLoss(),
                            torch.optim.SGD(model.parameters(), lr=0.1)),
            fsm).set_epoch_num(3)

        metrics_processor = MetricsProcessor().subscribe_to_stage(stage)
        metrics_processor.add_metric(DummyMetric())

        with MonitorHub(trainer) as mh:

            def on_epoch_start(local_trainer: Trainer):
                # Callback must receive the very trainer we subscribed to.
                self.assertIs(local_trainer, trainer)

            def on_epoch_end(local_trainer: Trainer):
                self.assertIs(local_trainer, trainer)
                # By the time the trainer-level EPOCH_END fires, the stage's
                # losses have already been reset.
                self.assertIsNone(
                    local_trainer.train_config().stages()[0].get_losses())

            def stage_on_epoch_end(local_stage: TrainStage):
                self.assertIs(local_stage, stage)
                # All 20 batches of the epoch contributed a loss value.
                self.assertEqual(20, local_stage.get_losses().size)

            mh.subscribe2metrics_processor(metrics_processor)

            events_container.event(
                stage, 'EPOCH_END').add_callback(stage_on_epoch_end)
            events_container.event(trainer,
                                   'EPOCH_START').add_callback(on_epoch_start)
            events_container.event(trainer,
                                   'EPOCH_END').add_callback(on_epoch_end)

            trainer.train()

            # assertIsNone is the idiomatic None check (was
            # assertEqual(None, ...)) and matches on_epoch_end above.
            self.assertIsNone(
                trainer.train_config().stages()[0].get_losses())
예제 #4
0
 def subscribe2metrics_processor(self, metrics_processor: MetricsProcessor) -> 'MonitorHub':
     """Forward the processor's metrics to this hub just before each reset.

     :param metrics_processor: processor whose BEFORE_METRICS_RESET event is observed
     :return: self, for fluent chaining
     """
     reset_event = events_container.event(metrics_processor, "BEFORE_METRICS_RESET")
     reset_event.add_callback(lambda mp: self.update_metrics(mp.get_metrics()))
     return self
예제 #5
0
 def __init__(self, trainer: Trainer):
     """Create an empty hub and keep its epoch counter synced with ``trainer``.

     :param trainer: trainer whose EPOCH_START event drives the epoch number
     """
     self.monitors = []
     # Mirror the trainer's current epoch id into this hub at each epoch start.
     epoch_start_event = events_container.event(trainer, 'EPOCH_START')
     epoch_start_event.add_callback(lambda t: self.set_epoch_num(t.cur_epoch_id()))
예제 #6
0
 def subscribe2trainer(self, trainer: Trainer) -> 'CheckpointsManager':
     """Save the trainer's state automatically at the end of every epoch.

     :param trainer: trainer whose EPOCH_END event triggers a checkpoint save
     :return: self, for fluent chaining
     """
     epoch_end_event = events_container.event(trainer, 'EPOCH_END')
     epoch_end_event.add_callback(self.save_trainer_state)
     return self
예제 #7
0
 def subscribe2stage(self, stage: AbstractStage) -> 'BestStateDetector':
     """Run the best-state check after every epoch of ``stage``.

     :param stage: stage whose EPOCH_END event triggers the check
     :return: self, for fluent chaining
     """
     epoch_end_event = events_container.event(stage, 'EPOCH_END')
     # The event subject is ignored; the rules capture their own state.
     epoch_end_event.add_callback(lambda _: self.check_best_state_achieved())
     return self
예제 #8
0
 def subscribe_to_stage(self, stage: AbstractStage) -> 'MetricsProcessor':
     """Compute metrics on each processed batch and reset them when the stage ends.

     :param stage: stage whose BATCH_PROCESSED and STAGE_END events are observed
     :return: self, for fluent chaining
     """
     # Feed every batch result into the metric calculators...
     batch_event = events_container.event(stage, 'BATCH_PROCESSED')
     batch_event.add_callback(lambda s: self.calc_metrics(**s.get_last_result()))
     # ...and clear accumulated values once the stage finishes.
     stage_end_event = events_container.event(stage, 'STAGE_END')
     stage_end_event.add_callback(lambda s: self.reset_metrics())
     return self