Example #1
    def test_saving_states(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        metrics_processor = MetricsProcessor()
        stage = TrainStage(
            TestDataProducer([{
                'data': torch.rand(1, 3),
                'target': torch.rand(1)
            } for _ in range(20)]))

        class Losses:
            def __init__(self):
                self.v = []
                self._fake_losses = [[i] * 20 for i in [5, 4, 0, 2, 1]]

            def on_stage_end(self, s: TrainStage):
                s._losses = self._fake_losses[0]
                del self._fake_losses[0]
                self.v.append(np.mean(s.get_losses()))

        losses = Losses()
        events_container.event(stage,
                               'EPOCH_END').add_callback(losses.on_stage_end)

        trainer = Trainer(
            BaseTrainConfig(model, [stage], SimpleLoss(),
                            torch.optim.SGD(model.parameters(), lr=0.1)),
            fsm).set_epoch_num(5)
        metrics_processor.subscribe_to_stage(stage)

        checkpoint_file = os.path.join(self.base_dir, 'checkpoints', 'last',
                                       'last_checkpoint.zip')
        best_checkpoint_file = os.path.join(self.base_dir, 'checkpoints',
                                            'best', 'best_checkpoint.zip')

        cm = CheckpointsManager(fsm).subscribe2trainer(trainer)
        best_cm = CheckpointsManager(fsm, prefix='best')
        bsd = BestStateDetector(trainer).subscribe2stage(stage).add_rule(
            lambda: np.mean(stage.get_losses()))
        events_container.event(bsd, 'BEST_STATE_ACHIEVED').add_callback(
            lambda b: best_cm.save_trainer_state(trainer))

        trainer.train()

        self.assertTrue(os.path.exists(best_checkpoint_file))
        best_cm.load_trainer_state(trainer)
        # the fake per-epoch losses [5, 4, 0, 2, 1] hit their minimum at epoch 2
        self.assertEqual(2, trainer.cur_epoch_id() - 1)

        self.assertTrue(os.path.exists(checkpoint_file))
        cm.load_trainer_state(trainer)
        # the regular checkpoint holds the last trained epoch (index 4 of 5)
        self.assertEqual(4, trainer.cur_epoch_id() - 1)
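
The helper fixtures used throughout these examples (SimpleModel, SimpleLoss, TestDataProducer) are defined in the test suite rather than shown here. A minimal sketch of what they plausibly look like, assuming a plain PyTorch module and an MSE-style loss; the real definitions may differ:

    import torch
    from torch import nn

    class SimpleModel(nn.Module):
        # Hypothetical stand-in: one linear layer mapping the (1, 3)
        # 'data' tensors used above to a single prediction.
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(3, 1)

        def forward(self, x):
            return self.fc(x)

    class SimpleLoss(nn.Module):
        # Hypothetical stand-in: plain mean-squared error.
        def forward(self, output, target):
            return ((output - target) ** 2).mean()

TestDataProducer is presumably a thin DataProducer subclass wrapping the list of {'data', 'target'} dicts passed to it.
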
Example #2
    def test_lr_decaying(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        stages = [
            TrainStage(
                TestDataProducer([{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)])),
            ValidationStage(
                TestDataProducer([{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)]))
        ]
        trainer = Trainer(
            BaseTrainConfig(model, stages, SimpleLoss(),
                            torch.optim.SGD(model.parameters(), lr=0.1)),
            fsm).set_epoch_num(10)

        def target_value_clbk() -> float:
            return 1

        trainer.enable_lr_decaying(0.5, 3, target_value_clbk)
        trainer.train()

        self.assertAlmostEqual(trainer.data_processor().get_lr(),
                               0.1 * (0.5**3),
                               delta=1e-6)
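
The asserted value follows directly from the decay arithmetic: target_value_clbk always returns 1, so the monitored value never improves and the learning rate is multiplied by the 0.5 coefficient after every window of 3 epochs. A sketch of the expected-value computation, assuming exactly one decay step per window:

    initial_lr, coeff, patience, epochs = 0.1, 0.5, 3, 10
    expected_lr = initial_lr * coeff ** (epochs // patience)  # 0.1 * 0.5**3 = 0.0125
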
Example #3
    def __init__(self, fold_indices: dict):
        model = self.create_model(pretrained=False).cuda()

        train_dts = []
        for indices in fold_indices['train']:
            train_dts.append(
                create_augmented_dataset(is_train=True,
                                         indices_path=os.path.join(
                                             INDICES_DIR, indices + '.npy')))

        val_dts = create_augmented_dataset(is_train=False,
                                           indices_path=os.path.join(
                                               INDICES_DIR,
                                               fold_indices['val'] + '.npy'))

        workers_num = 4
        self._train_data_producer = DataProducer(
            DatasetsContainer(train_dts), batch_size=batch_size,
            num_workers=workers_num).global_shuffle(True)  # .pin_memory(True)
        self._val_data_producer = DataProducer(
            val_dts, batch_size=batch_size,
            num_workers=workers_num).global_shuffle(True)  # .pin_memory(True)

        self.train_stage = TrainStage(self._train_data_producer)
        self.val_stage = ValidationStage(self._val_data_producer)

        loss = RMSELoss().cuda()
        optimizer = Adam(params=model.parameters(), lr=1e-4)

        super().__init__(model, [self.train_stage, self.val_stage], loss,
                         optimizer)
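
This constructor leans on names defined elsewhere in its module. A minimal sketch of the assumed surroundings; INDICES_DIR, batch_size, and the create_augmented_dataset factory are project-specific placeholders, not part of the library:

    INDICES_DIR = 'data/indices'  # hypothetical directory holding fold index files
    batch_size = 32               # hypothetical module-level constant

    def create_augmented_dataset(is_train: bool, indices_path: str):
        # Hypothetical factory: loads the .npy index file at indices_path and
        # returns a dataset, applying augmentations when is_train is True.
        ...
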
Example #4
    def test_predict(self):
        test_data = {'data': torch.rand(1, 3)}

        model = SimpleModel()
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        stages = [
            TrainStage(
                TestDataProducer([{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)])),
            ValidationStage(
                TestDataProducer([{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)]))
        ]
        trainer = Trainer(BaseTrainConfig(model, stages, SimpleLoss(), torch.optim.SGD(model.parameters(), lr=1)), fsm)\
            .set_epoch_num(1)
        cm.subscribe2trainer(trainer)
        trainer.train()
        real_predict = trainer.data_processor().predict(test_data,
                                                        is_train=False)

        fsm = FileStructManager(base_dir=self.base_dir, is_continue=True)
        cm = CheckpointsManager(fsm)

        predict = Predictor(model, checkpoints_manager=cm).predict(test_data)

        self.assertTrue(
            np.array_equal(real_predict.cpu().detach().numpy(),
                           predict.cpu().detach().numpy()))
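
Because the second FileStructManager is built with is_continue=True, it attaches to the directory tree written during training instead of creating a fresh one, and Predictor restores the saved weights through the CheckpointsManager before running inference. A condensed sketch of that restore-and-predict path in a fresh process; 'train_dir' is a hypothetical path that must point at the directory the trainer populated:

    fsm = FileStructManager(base_dir='train_dir', is_continue=True)
    cm = CheckpointsManager(fsm)
    output = Predictor(SimpleModel(), checkpoints_manager=cm).predict({'data': torch.rand(1, 3)})
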
Example #5
    def test_train(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        stages = [
            TrainStage(
                TestDataProducer([{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)])),
            ValidationStage(
                TestDataProducer([{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)]))
        ]
        Trainer(BaseTrainConfig(model, stages, SimpleLoss(), torch.optim.SGD(model.parameters(), lr=1)), fsm) \
            .set_epoch_num(1).train()
Example #6
    def test_metric_calc_in_train_loop(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        stages = [
            TrainStage(
                TestDataProducer([{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)])),
            ValidationStage(
                TestDataProducer([{
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                } for _ in range(20)]))
        ]
        trainer = Trainer(BaseTrainConfig(model, stages, SimpleLoss(), torch.optim.SGD(model.parameters(), lr=1)), fsm) \
            .set_epoch_num(2)

        mp = MetricsProcessor()
        metric1 = SimpleMetric(coeff=1, collect_values=True)
        # metric2 = SimpleMetric(coeff=1.7, collect_values=True)
        mp.add_metrics_group(MetricsGroup('grp1').add(metric1))
        # mp.add_metrics_group(MetricsGroup('grp2').add(metric2))

        mp.subscribe_to_stage(stages[0])  # .subscribe_to_stage(stages[1])
        # mp.subscribe_to_trainer(trainer)

        file_monitor = FileLogMonitor(fsm).write_final_metrics()
        MonitorHub(trainer).subscribe2metrics_processor(mp).add_monitor(
            file_monitor)

        trainer.train()

        with open(os.path.join(file_monitor.get_dir(), 'metrics.json'),
                  'r') as metrics_file:
            metrics = json.load(metrics_file)

        self.assertAlmostEqual(
            metrics['grp1/SimpleMetric'],
            float(
                np.mean([
                    F.pairwise_distance(i[0], i[1],
                                        p=2).cpu().detach().numpy()
                    for i in metric1._inputs
                ])),
            delta=1e-2)
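
The closing assertion implies what the SimpleMetric fixture does: with collect_values=True it records every (output, target) pair in _inputs, and its value is the p=2 pairwise distance between them. A rough sketch under those assumptions; the method name calc and the exact reduction are guesses, not the library's confirmed API:

    import torch.nn.functional as F

    class SimpleMetric:
        # Hypothetical fixture: only the behavior implied by the
        # assertion above is reproduced here.
        def __init__(self, coeff=1, collect_values=False):
            self.coeff = coeff
            self.collect_values = collect_values
            self._inputs = []

        def calc(self, output, target):
            if self.collect_values:
                self._inputs.append((output, target))
            return self.coeff * F.pairwise_distance(output, target, p=2)
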
Example #7
    def test_train_stage(self):
        data_producer = DataProducer([{
            'data': torch.rand(1, 3),
            'target': torch.rand(1)
        } for _ in range(20)])
        metrics_processor = FakeMetricsProcessor()
        train_stage = TrainStage(data_producer).enable_hard_negative_mining(
            0.1)

        metrics_processor.subscribe_to_stage(train_stage)

        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        Trainer(BaseTrainConfig(model, [train_stage], SimpleLoss(), torch.optim.SGD(model.parameters(), lr=1)), fsm) \
            .set_epoch_num(1).train()

        self.assertEqual(metrics_processor.call_num, len(data_producer))
Example #8
    def test_events(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        stage = TrainStage(
            TestDataProducer([{
                'data': torch.rand(1, 3),
                'target': torch.rand(1)
            } for _ in range(20)]))
        trainer = Trainer(
            BaseTrainConfig(model, [stage], SimpleLoss(),
                            torch.optim.SGD(model.parameters(), lr=0.1)),
            fsm).set_epoch_num(3)

        metrics_processor = MetricsProcessor().subscribe_to_stage(stage)
        metrics_processor.add_metric(DummyMetric())

        with MonitorHub(trainer) as mh:

            def on_epoch_start(local_trainer: Trainer):
                self.assertIs(local_trainer, trainer)

            def on_epoch_end(local_trainer: Trainer):
                self.assertIs(local_trainer, trainer)
                self.assertIsNone(
                    local_trainer.train_config().stages()[0].get_losses())

            def stage_on_epoch_end(local_stage: TrainStage):
                self.assertIs(local_stage, stage)
                self.assertEqual(20, local_stage.get_losses().size)

            mh.subscribe2metrics_processor(metrics_processor)

            events_container.event(
                stage, 'EPOCH_END').add_callback(stage_on_epoch_end)
            events_container.event(trainer,
                                   'EPOCH_START').add_callback(on_epoch_start)
            events_container.event(trainer,
                                   'EPOCH_END').add_callback(on_epoch_end)

            trainer.train()

            self.assertIsNone(
                trainer.train_config().stages()[0].get_losses())
Example #9
    def test_hard_negatives_mining(self):
        # the mining fraction must lie strictly between 0 and 1
        with self.assertRaises(ValueError):
            TrainStage(None).enable_hard_negative_mining(0)
        with self.assertRaises(ValueError):
            TrainStage(None).enable_hard_negative_mining(1)
        with self.assertRaises(ValueError):
            TrainStage(None).enable_hard_negative_mining(-1)
        with self.assertRaises(ValueError):
            TrainStage(None).enable_hard_negative_mining(1.1)

        dp = TestDataProducer([{
            'data': torch.Tensor([i]),
            'target': torch.rand(1)
        } for i in range(20)]).pass_indices(True)
        stage = TrainStage(dp).enable_hard_negative_mining(0.1)
        losses = np.random.rand(20)
        samples = []

        def on_batch(batch, data_processor):
            samples.append(batch)
            stage.hnm._losses = np.array([0])

        # intercept re-processing of the mined samples to capture which
        # batches the miner selects
        stage.hnm._process_batch = on_batch
        stage.hnm.exec(None, losses, [[str(i)] for i in range(20)])

        self.assertEqual(len(samples), 2)

        losses = [float(v) for v in losses]
        idxs = [int(s['data']) for s in samples]
        max_losses = [losses[i] for i in idxs]
        idxs.sort(reverse=True)
        for i in idxs:
            del losses[i]

        # every sample the miner skipped must have a smaller loss than the
        # smallest loss among the selected hard negatives
        for remaining_loss in losses:
            self.assertLess(remaining_loss, min(max_losses))

        stage.on_epoch_end()

        self.assertIsNone(stage.hnm._losses)

        stage.disable_hard_negative_mining()
        self.assertIsNone(stage.hnm)

        for data in dp:
            self.assertIn('data_idx', data)
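
The expected sample count follows from the mining fraction: 10% of the 20 items gives 2 hard negatives, which is what the assertEqual above checks. A sketch of that arithmetic; the exact rounding rule inside the library is an assumption:

    ratio, n_samples = 0.1, 20
    expected_hard_negatives = int(ratio * n_samples)  # 2 samples get re-processed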