Beispiel #1
0
    def test_predict(self):
        """Predictions restored from a checkpoint must match the trained model's.

        Trains ``SimpleModel`` for one epoch with a ``CheckpointsManager``
        subscribed to the trainer, records the data processor's prediction on a
        fixed input, then builds a ``Predictor`` from a fresh continue-mode
        ``FileStructManager`` and checks it yields the identical output.
        """
        test_data = {'data': torch.rand(1, 3)}

        def make_producer():
            # 20 random samples shaped like SimpleModel's input/target.
            return TestDataProducer([{
                'data': torch.rand(1, 3),
                'target': torch.rand(1)
            } for _ in range(20)])

        model = SimpleModel()
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        stages = [TrainStage(make_producer()), ValidationStage(make_producer())]
        trainer = Trainer(
            BaseTrainConfig(model, stages, SimpleLoss(),
                            torch.optim.SGD(model.parameters(), lr=1)),
            fsm).set_epoch_num(1)
        cm.subscribe2trainer(trainer)
        trainer.train()
        real_predict = trainer.data_processor().predict(test_data,
                                                        is_train=False)

        # Re-open the same base dir in continue mode so the Predictor can pick
        # up whatever `cm` stored during training (presumably the last state).
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=True)
        cm = CheckpointsManager(fsm)

        predict = Predictor(model, checkpoints_manager=cm).predict(test_data)

        # np.array_equal reduces to one bool; asserting on np.equal's
        # element-wise result is ambiguous for outputs with >1 element.
        self.assertTrue(
            np.array_equal(real_predict.cpu().detach().numpy(),
                           predict.cpu().detach().numpy()))
Beispiel #2
0
    def test_creation(self):
        """FileStructManager accepts a missing, an empty and a non-empty base dir.

        Also checks that constructing it with ``is_continue=False`` does not
        create the base directory on disk by itself.
        """
        # Start from a clean slate: the existence check below requires the
        # base directory to be absent before the first construction.
        shutil.rmtree(self.base_dir, ignore_errors=True)
        # Remove the directory even if an assertion below fails.
        self.addCleanup(shutil.rmtree, self.base_dir, ignore_errors=True)

        try:
            FileStructManager(base_dir=self.base_dir, is_continue=False)
        except FileStructManager.FSMException as err:
            self.fail(
                "Raise error when base directory doesn't exist: [{}]".format(err))

        self.assertFalse(os.path.exists(self.base_dir))

        # Existing but empty base directory.
        os.makedirs(self.base_dir, exist_ok=True)
        try:
            FileStructManager(base_dir=self.base_dir, is_continue=False)
        except FileStructManager.FSMException as err:
            self.fail("Raise error when base directory exists but empty: [{}]".
                      format(err))

        # Base directory containing a folder that nobody registered.
        os.makedirs(os.path.join(self.base_dir, 'new_dir'))
        try:
            FileStructManager(base_dir=self.base_dir, is_continue=False)
        except Exception as err:
            self.fail(
                "Error initialize when exists non-registered folders in base "
                "directory: [{}]".format(err))
Beispiel #3
0
    def __init__(self,
                 fsm: FileStructManager,
                 is_continue: bool,
                 network_name: str = None):
        """Create a TensorBoard writer and a plain-text log in a registered dir.

        Args:
            fsm: file structure manager; this monitor registers its own
                directory with it and writes under ``fsm.get_path(self)``.
            is_continue: when True (or when ``fsm.in_continue_mode()`` is
                True) an existing log directory is reused and ``log.txt`` is
                appended to; otherwise a fresh directory is chosen and
                ``log.txt`` is opened for writing.
            network_name: optional sub-directory appended to the registered
                path.

        If ``fsm.get_path(self)`` returns ``None`` the monitor stays inert
        (``_writer`` and ``_txt_log_file`` remain ``None``).

        When not continuing and the target directory already exists, the first
        unused sibling ``<dir>_v<N>`` (N = 0, 1, ...) is used instead, so
        earlier logs are not overwritten.
        """
        super().__init__()
        self._writer = None
        self._txt_log_file = None

        fsm.register_dir(self)
        directory = fsm.get_path(self)
        if directory is None:
            return

        if network_name is not None:
            directory = os.path.join(directory, network_name)

        # Either the caller or the file-structure manager may request
        # continuation. Use one decision for both directory reuse and the log
        # file mode, so a reused directory's log is not truncated.
        continue_mode = fsm.in_continue_mode() or is_continue

        if not continue_mode and os.path.isdir(directory):
            idx = 0
            tmp_dir = directory + "_v{}".format(idx)
            # Skip any existing path (dir or file); makedirs fails on a file.
            while os.path.exists(tmp_dir):
                idx += 1
                tmp_dir = directory + "_v{}".format(idx)
            directory = tmp_dir

        os.makedirs(directory, exist_ok=True)
        self._writer = SummaryWriter(directory)
        # NOTE(review): the handle is not closed here; presumably a close()
        # elsewhere in the class releases it.
        self._txt_log_file = open(os.path.join(directory, "log.txt"),
                                  'a' if continue_mode else 'w')
Beispiel #4
0
    def test_pack(self):
        """pack() archives the state files and keeps the previous archive.

        Checks that pack() raises ``CMException`` with no state files and with
        directories in their place. With real files it must succeed, remove
        the weights/optimizer files and produce ``last_checkpoint.zip``. A
        second pack must leave ``last_checkpoint.zip.old`` behind.
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        def touch_state_files():
            # Create empty weights / optimizer / trainer files for pack().
            for path in (cm.weights_file(), cm.optimizer_state_file(),
                         cm.trainer_file()):
                with open(path, 'w'):
                    pass

        def archive_path(name):
            return os.path.join(
                fsm.get_path(cm, check=False, create_if_non_exists=False),
                name)

        # No state files on disk yet.
        with self.assertRaises(CheckpointsManager.CMException):
            cm.pack()

        # Directories in place of the state files are rejected too.
        os.mkdir(cm.weights_file())
        os.mkdir(cm.optimizer_state_file())
        with self.assertRaises(CheckpointsManager.CMException):
            cm.pack()

        shutil.rmtree(cm.weights_file())
        shutil.rmtree(cm.optimizer_state_file())

        touch_state_files()
        try:
            cm.pack()
        except Exception as err:
            self.fail('Exception on packing files: [{}]'.format(err))

        for f in [cm.weights_file(), cm.optimizer_state_file()]:
            if os.path.isfile(f):
                self.fail("File '{}' doesn't remove after pack".format(f))

        self.assertTrue(os.path.isfile(archive_path('last_checkpoint.zip')))

        # Packing again must rotate the previous archive to *.old. Only the
        # pack() call is guarded, so assertion failures are reported as-is.
        touch_state_files()
        try:
            cm.pack()
        except Exception as err:
            self.fail('Fail to pack with existing previous state file: [{}]'
                      .format(err))

        self.assertTrue(os.path.isfile(archive_path('last_checkpoint.zip.old')))
Beispiel #5
0
    def test_initialisation(self):
        """Only one CheckpointsManager may be created per FileStructManager.

        The first construction must succeed. Later ones must raise
        ``FSMException``, both before and after the checkpoints directory
        gains content.
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)

        try:
            first_manager = CheckpointsManager(fsm)
        except Exception as err:
            self.fail("Fail init CheckpointsManager; err: ['{}']".format(err))

        # Presumably rejected because its directory is already registered.
        with self.assertRaises(FileStructManager.FSMException):
            CheckpointsManager(fsm)

        checkpoints_path = fsm.get_path(first_manager)
        os.mkdir(os.path.join(checkpoints_path, 'test_dir'))
        with self.assertRaises(FileStructManager.FSMException):
            CheckpointsManager(fsm)
Beispiel #6
0
def train():
    """Train a single-output ResNet-18 on ``train_stage`` / ``val_stage``.

    Runs 2 epochs on ``cuda:0`` with BCE-with-logits loss and Adam (lr 1e-4),
    writing artifacts under ``data``. The mean training loss drives both
    best-state saving and LR decaying (coeff 0.5, patience 10). The current LR
    is reported to TensorBoard as ``params/lr`` via an epoch-end callback.
    """
    net = resnet18(classes_num=1, in_channels=3, pretrained=True)
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
    train_config = TrainConfig(net, [train_stage, val_stage], criterion,
                               optimizer)

    fsm = FileStructManager(base_dir='data', is_continue=False)
    trainer = Trainer(train_config, fsm,
                      torch.device('cuda:0')).set_epoch_num(2)

    tensorboard = TensorboardMonitor(fsm,
                                     is_continue=False,
                                     network_name='PortraitSegmentation')
    log = LogMonitor(fsm).write_final_metrics()
    trainer.monitor_hub.add_monitor(tensorboard).add_monitor(log)

    def mean_train_loss():
        return np.mean(train_stage.get_losses())

    def report_learning_rate():
        return tensorboard.update_scalar('params/lr',
                                         trainer.data_processor().get_lr())

    trainer.enable_best_states_saving(mean_train_loss)
    trainer.enable_lr_decaying(coeff=0.5,
                               patience=10,
                               target_val_clbk=mean_train_loss)
    trainer.add_on_epoch_end_callback(report_learning_rate)
    trainer.train()
Beispiel #7
0
    def test_unpack(self):
        """unpack() must restore exactly the files that pack() archived.

        Each state file is written with distinct content ('1', '2', '3'), so a
        mix-up between files is detected after the pack/unpack round trip.
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        def state_files():
            return [cm.weights_file(),
                    cm.optimizer_state_file(),
                    cm.trainer_file()]

        for idx, path in enumerate(state_files(), start=1):
            with open(path, 'w') as f:
                f.write(str(idx))

        cm.pack()

        try:
            cm.unpack()
        except Exception as err:
            self.fail('Exception on unpacking: [{}]'.format(err))

        for idx, path in enumerate(state_files(), start=1):
            if not os.path.isfile(path):
                self.fail("File '{}' doesn't exist after unpack".format(path))
            with open(path, 'r') as file:
                if file.read() != str(idx):
                    self.fail("File content corrupted")
Beispiel #8
0
    def test_lr_decaying(self):
        """LR decaying with a target value that never changes.

        The target callback always returns 1. With coeff 0.5 and patience 3,
        the test expects the learning rate after 10 epochs to be
        ``0.1 * 0.5 ** 3``.
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()

        def random_producer():
            samples = [{'data': torch.rand(1, 3), 'target': torch.rand(1)}
                       for _ in range(20)]
            return TestDataProducer(samples)

        stages = [TrainStage(random_producer()),
                  ValidationStage(random_producer())]
        loss = SimpleLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        config = BaseTrainConfig(model, stages, loss, optimizer)
        trainer = Trainer(config, fsm).set_epoch_num(10)

        trainer.enable_lr_decaying(0.5, 3, lambda: 1)
        trainer.train()

        expected_lr = 0.1 * (0.5 ** 3)
        self.assertAlmostEqual(trainer.data_processor().get_lr(),
                               expected_lr,
                               delta=1e-6)
Beispiel #9
0
    def test_base_ops(self):
        """Training with an empty list of stages must raise TrainerException."""
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()

        no_stages = []
        criterion = torch.nn.L1Loss()
        sgd = torch.optim.SGD(model.parameters(), lr=1)
        trainer = Trainer(BaseTrainConfig(model, no_stages, criterion, sgd),
                          fsm)

        with self.assertRaises(Trainer.TrainerException):
            trainer.train()
Beispiel #10
0
    def test_saving_states(self):
        """Last and best checkpoints are written and restore the right epoch.

        At the end of each of the 5 epochs, the stage's losses are overwritten
        with constant fake values taken in order from ``[5, 4, 0, 2, 1]``. The
        best-state rule is the mean stage loss, so the best checkpoint is
        expected to hold epoch index 2 (fake loss 0). The regular ("last")
        checkpoint is expected to hold epoch index 4.
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        metrics_processor = MetricsProcessor()
        stage = TrainStage(
            TestDataProducer([{
                'data': torch.rand(1, 3),
                'target': torch.rand(1)
            } for _ in list(range(20))]))

        class Losses:
            # Replaces the stage's real losses with scripted values: one
            # 20-element constant list per epoch, consumed in order.
            def __init__(self):
                self.v = []  # mean of the injected losses, one per epoch
                self._fake_losses = [[i for _ in list(range(20))]
                                     for i in [5, 4, 0, 2, 1]]

            def on_stage_end(self, s: TrainStage):
                # Pop the next scripted list (FIFO) and write it straight into
                # the stage's private _losses attribute.
                s._losses = self._fake_losses[0]
                del self._fake_losses[0]
                self.v.append(np.mean(s.get_losses()))

        losses = Losses()
        # NOTE(review): registered before BestStateDetector subscribes to the
        # stage below. Presumably callbacks fire in registration order, so the
        # detector sees the injected losses — confirm in events_container.
        events_container.event(stage,
                               'EPOCH_END').add_callback(losses.on_stage_end)

        trainer = Trainer(
            BaseTrainConfig(model, [stage], SimpleLoss(),
                            torch.optim.SGD(model.parameters(), lr=0.1)),
            fsm).set_epoch_num(5)
        metrics_processor.subscribe_to_stage(stage)

        checkpoint_file = os.path.join(self.base_dir, 'checkpoints', 'last',
                                       'last_checkpoint.zip')
        best_checkpoint_file = os.path.join(self.base_dir, 'checkpoints',
                                            'best', 'best_checkpoint.zip')

        # `cm` is subscribed to the trainer and is expected to produce the
        # "last" checkpoint checked below. `best_cm` (prefix 'best') is written
        # only from the BEST_STATE_ACHIEVED callback. Presumably lower mean
        # loss counts as better, matching the expected best epoch.
        cm = CheckpointsManager(fsm).subscribe2trainer(trainer)
        best_cm = CheckpointsManager(fsm, prefix='best')
        bsd = BestStateDetector(trainer).subscribe2stage(stage).add_rule(
            lambda: np.mean(stage.get_losses()))
        events_container.event(bsd, 'BEST_STATE_ACHIEVED').add_callback(
            lambda b: best_cm.save_trainer_state(trainer))

        trainer.train()

        # Best state: epoch index 2. After loading, cur_epoch_id() is
        # presumably one past the saved epoch, hence the "- 1".
        self.assertTrue(os.path.exists(best_checkpoint_file))
        best_cm.load_trainer_state(trainer)
        self.assertEqual(2, trainer.cur_epoch_id() - 1)

        # Last state: the final epoch, index 4.
        self.assertTrue(os.path.exists(checkpoint_file))
        cm.load_trainer_state(trainer)
        self.assertEqual(4, trainer.cur_epoch_id() - 1)
Beispiel #11
0
    def test_clear_files(self):
        """clear_files() must delete the weights and optimizer state files."""
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        state_files = [cm.weights_file(), cm.optimizer_state_file()]
        for path in state_files:
            # Create an empty file.
            with open(path, 'w'):
                pass

        cm.clear_files()

        for path in state_files:
            if os.path.isfile(path):
                self.fail(
                    "File '{}' wasn't removed by clear_files".format(path))
Beispiel #12
0
    def test_train_stage(self):
        """Hard-negative mining must still expose every sample to metrics.

        After one epoch, the fake metrics processor's call count is expected
        to equal the number of samples in the producer.
        """
        samples = [{'data': torch.rand(1, 3), 'target': torch.rand(1)}
                   for _ in range(20)]
        data_producer = DataProducer(samples)
        metrics_processor = FakeMetricsProcessor()
        train_stage = TrainStage(data_producer)
        train_stage = train_stage.enable_hard_negative_mining(0.1)

        metrics_processor.subscribe_to_stage(train_stage)

        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        config = BaseTrainConfig(model, [train_stage], SimpleLoss(),
                                 torch.optim.SGD(model.parameters(), lr=1))
        trainer = Trainer(config, fsm)
        trainer.set_epoch_num(1).train()

        self.assertEqual(metrics_processor.call_num, len(data_producer))
Beispiel #13
0
 def test_train(self):
     """Smoke test: one epoch of training plus validation runs cleanly."""
     def random_producer():
         samples = [{'data': torch.rand(1, 3), 'target': torch.rand(1)}
                    for _ in range(20)]
         return TestDataProducer(samples)

     fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
     model = SimpleModel()
     stages = [TrainStage(random_producer()),
               ValidationStage(random_producer())]
     loss = SimpleLoss()
     optimizer = torch.optim.SGD(model.parameters(), lr=1)
     trainer = Trainer(BaseTrainConfig(model, stages, loss, optimizer), fsm)
     trainer.set_epoch_num(1).train()
Beispiel #14
0
    def test_metric_calc_in_train_loop(self):
        """The final metric written by FileLogMonitor matches collected inputs.

        Trains for two epochs with one ``SimpleMetric`` (group ``grp1``)
        attached to the train stage. Then compares the ``grp1/SimpleMetric``
        entry of ``metrics.json`` against the mean pairwise L2 distance over
        the inputs the metric collected (tolerance 1e-2).
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()

        def random_producer():
            samples = [{'data': torch.rand(1, 3), 'target': torch.rand(1)}
                       for _ in range(20)]
            return TestDataProducer(samples)

        stages = [TrainStage(random_producer()),
                  ValidationStage(random_producer())]
        config = BaseTrainConfig(model, stages, SimpleLoss(),
                                 torch.optim.SGD(model.parameters(), lr=1))
        trainer = Trainer(config, fsm).set_epoch_num(2)

        metrics_processor = MetricsProcessor()
        metric = SimpleMetric(coeff=1, collect_values=True)
        metrics_processor.add_metrics_group(MetricsGroup('grp1').add(metric))
        metrics_processor.subscribe_to_stage(stages[0])

        file_monitor = FileLogMonitor(fsm).write_final_metrics()
        hub = MonitorHub(trainer).subscribe2metrics_processor(metrics_processor)
        hub.add_monitor(file_monitor)

        trainer.train()

        metrics_path = os.path.join(file_monitor.get_dir(), 'metrics.json')
        with open(metrics_path, 'r') as metrics_file:
            metrics = json.load(metrics_file)

        distances = [
            F.pairwise_distance(pair[0], pair[1], p=2).cpu().detach().numpy()
            for pair in metric._inputs
        ]
        expected = float(np.mean(distances))
        self.assertAlmostEqual(metrics['grp1/SimpleMetric'],
                               expected,
                               delta=1e-2)
Beispiel #15
0
    def test_events(self):
        """Trainer and stage events fire with the expected arguments and state.

        EPOCH_START / EPOCH_END on the trainer must receive the trainer
        itself. The test expects the first stage's losses to be ``None`` by the
        trainer's EPOCH_END and after training. The stage's EPOCH_END must
        receive the stage, holding 20 losses.
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        model = SimpleModel()
        samples = [{'data': torch.rand(1, 3), 'target': torch.rand(1)}
                   for _ in range(20)]
        stage = TrainStage(TestDataProducer(samples))
        config = BaseTrainConfig(model, [stage], SimpleLoss(),
                                 torch.optim.SGD(model.parameters(), lr=0.1))
        trainer = Trainer(config, fsm).set_epoch_num(3)

        metrics_processor = MetricsProcessor().subscribe_to_stage(stage)
        metrics_processor.add_metric(DummyMetric())

        def first_stage_losses(local_trainer: Trainer):
            return local_trainer.train_config().stages()[0].get_losses()

        def on_epoch_start(local_trainer: Trainer):
            self.assertIs(local_trainer, trainer)

        def on_epoch_end(local_trainer: Trainer):
            self.assertIs(local_trainer, trainer)
            self.assertIsNone(first_stage_losses(local_trainer))

        def stage_on_epoch_end(local_stage: TrainStage):
            self.assertIs(local_stage, stage)
            self.assertEqual(20, local_stage.get_losses().size)

        with MonitorHub(trainer) as mh:
            mh.subscribe2metrics_processor(metrics_processor)

            events_container.event(stage, 'EPOCH_END').add_callback(
                stage_on_epoch_end)
            events_container.event(trainer, 'EPOCH_START').add_callback(
                on_epoch_start)
            events_container.event(trainer, 'EPOCH_END').add_callback(
                on_epoch_end)

            trainer.train()

            self.assertEqual(None, first_stage_losses(trainer))
Beispiel #16
0
    def train_fold(self, init_trainer: callable, model_name: str, out_dir: str, fold_num: int):
        """Train one cross-validation split and append its entry to meta.json.

        The fold at index ``fold_num`` of ``self._folds`` is held out for
        validation; all remaining folds form the training set.
        ``self._folds`` itself is left unchanged, so repeated calls with
        different ``fold_num`` always see the same fold list.

        Args:
            init_trainer: factory called as ``init_trainer(fsm, folds)`` with
                ``folds = {'train': [...], 'val': <fold>}``. Must return an
                object with a ``train()`` method.
            model_name: model name; first path component under ``out_dir``.
            out_dir: root output directory. Training artifacts go to
                ``out_dir/model_name/<val_fold>``; the run is recorded in
                ``out_dir/meta.json``.
            fold_num: index into ``self._folds`` of the validation fold.
        """
        train_folds = list(self._folds)
        val_fold = train_folds.pop(fold_num)
        folds = {'train': train_folds, 'val': val_fold}

        fold_dir = os.path.join(model_name, val_fold)
        fsm = FileStructManager(base_dir=os.path.join(out_dir, fold_dir), is_continue=False)
        trainer = init_trainer(fsm, folds)
        trainer.train()

        # Append this run to any entries already recorded by earlier folds.
        meta_path = os.path.join(out_dir, 'meta.json')
        meta_info = []
        if os.path.exists(meta_path):
            with open(meta_path, 'r') as meta_file:
                meta_info = json.load(meta_file)
        meta_info.append({'model': model_name, 'fold': val_fold, 'path': fold_dir})

        with open(meta_path, 'w') as meta_file:
            json.dump(meta_info, meta_file, indent=4)
Beispiel #17
0
    def metrics_processing(self,
                           with_final_file: bool,
                           final_file: str = None):
        """Log nested metric groups via FileLogMonitor and verify the CSVs.

        For 10 epochs, 10 random (output, target) pairs are fed to these
        metrics:

        * ``a`` (coeff 1) and ``b`` (coeff 2) in group ``lv1``;
        * ``c`` (coeff 3) in nested group ``lv2``;
        * ``d`` (coeff 4) at top level.

        The per-epoch mean L2 distance is recorded. Each metric's CSV must then
        hold that mean scaled by the metric's coeff.

        Args:
            with_final_file: if True, call ``write_final_metrics(final_file)``.
            final_file: forwarded to ``write_final_metrics``.

        Returns:
            List of per-epoch mean L2 distances between output and target.
        """
        fsm = FileStructManager(base_dir='data', is_continue=False)
        base_dir = os.path.join('data', 'monitors', 'metrics_log')
        expected_outs = [
            os.path.join(base_dir, f) for f in [
                'meta.json', 'd.csv',
                os.path.join('lv1', 'a.csv'),
                os.path.join('lv1', 'b.csv'),
                os.path.join('lv1', 'lv2', 'c.csv')
            ]
        ]

        metrics_group_lv1 = MetricsGroup('lv1').add(
            SimpleMetric(name='a',
                         coeff=1)).add(SimpleMetric(name='b', coeff=2))
        metrics_group_lv2 = MetricsGroup('lv2').add(
            SimpleMetric(name='c', coeff=3))
        metrics_group_lv1.add(metrics_group_lv2)
        m = SimpleMetric(name='d', coeff=4)

        values = []
        with FileLogMonitor(fsm) as monitor:
            if with_final_file:
                monitor.write_final_metrics(final_file)
            for epoch in range(10):
                cur_vals = []
                for _ in range(10):
                    output, target = torch.rand(1, 3), torch.rand(1, 3)
                    metrics_group_lv1.calc(output, target)
                    m._calc(output, target)

                    cur_vals.append(
                        np.linalg.norm(output.numpy() - target.numpy()))

                values.append(float(np.mean(cur_vals)))
                monitor.set_epoch_num(epoch)
                monitor.update_metrics({
                    'metrics': [m],
                    'groups': [metrics_group_lv1]
                })
                m.reset()
                metrics_group_lv1.reset()

        for out in expected_outs:
            self.assertTrue(os.path.isfile(out), out)

        with open(os.path.join(base_dir, 'meta.json'), 'r') as file:
            meta = json.load(file)

        # Build the key the same way as expected_outs, instead of hard-coding
        # '/' separators.
        d_path = os.path.join(base_dir, 'd.csv')
        self.assertIn(d_path, meta)
        self.assertEqual(meta[d_path], {
            "name": "d",
            "path": []
        })

        metrics_values = {
            v['name']: np.loadtxt(path, delimiter=',')
            for path, v in meta.items()
        }

        # Presumably each CSV row is [epoch, value]. Compare the scalar in
        # column 1 rather than an array slice.
        for i, v in enumerate(values):
            self.assertAlmostEqual(metrics_values['d'][i][1], v * 4,
                                   delta=1e-5)
            self.assertAlmostEqual(metrics_values['a'][i][1], v,
                                   delta=1e-5)
            self.assertAlmostEqual(metrics_values['b'][i][1], v * 2,
                                   delta=1e-5)
            self.assertAlmostEqual(metrics_values['c'][i][1], v * 3,
                                   delta=1e-5)

        return values
Beispiel #18
0
    def test_object_registration(self):
        """register_dir enforces unique dirs and names unless checks are off.

        ``fsm`` must reject:

        * a directory already registered (unless ``check_dir_registered=False``);
        * a name already registered (unless ``check_name_registered=False``);
        * a directory that already exists on disk with content.

        ``fsm_exist_ok`` (``exists_ok=True``) must accept that last case.
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        fsm_exist_ok = FileStructManager(base_dir=self.base_dir,
                                         is_continue=False,
                                         exists_ok=True)
        o = self.TestObj(fsm, 'test_dir', 'test_name')
        fsm.register_dir(o)

        # Registration alone must not create the directory.
        expected_path = os.path.join(self.base_dir, 'test_dir')
        self.assertFalse(os.path.exists(expected_path))
        self.assertEqual(fsm.get_path(o), expected_path)

        # Same dir, new name: rejected unless the dir check is disabled.
        with self.assertRaises(FileStructManager.FSMException):
            fsm.register_dir(self.TestObj(fsm, 'test_dir', 'another_name'))
        try:
            fsm.register_dir(self.TestObj(fsm, 'test_dir', 'another_name'),
                             check_dir_registered=False)
            fsm_exist_ok.register_dir(
                self.TestObj(fsm, 'test_dir', 'another_name'))
            fsm_exist_ok.register_dir(self.TestObj(fsm, 'test_dir',
                                                   'another_name2'),
                                      check_dir_registered=False)
        except Exception as err:
            self.fail("Folder registrable test fail when it's disabled: [{}]"
                      .format(err))

        # New dir, already-registered name: rejected.
        # NOTE(review): each block passes as soon as either call raises; the
        # second call only runs if the first one does not raise.
        with self.assertRaises(FileStructManager.FSMException):
            fsm.register_dir(self.TestObj(fsm, 'another_dir', 'test_name'))
            fsm.register_dir(self.TestObj(fsm, 'another_dir', 'another_name'))
        with self.assertRaises(FileStructManager.FSMException):
            fsm_exist_ok.register_dir(
                self.TestObj(fsm, 'another_dir', 'test_name'))
            fsm_exist_ok.register_dir(
                self.TestObj(fsm, 'another_dir', 'another_name'))

        try:
            fsm.register_dir(self.TestObj(fsm, 'another_dir2', 'test_name'),
                             check_name_registered=False)
            fsm_exist_ok.register_dir(self.TestObj(fsm, 'another_dir2',
                                                   'test_name'),
                                      check_name_registered=False)
        except Exception as err:
            self.fail("Folder registrable test fail when it's disabled: [{}]"
                      .format(err))

        # A non-empty directory that already exists on disk.
        os.makedirs(os.path.join(self.base_dir, 'dir_dir', 'dir'))
        with self.assertRaises(FileStructManager.FSMException):
            fsm.register_dir(self.TestObj(fsm, 'dir_dir', 'name_name'))

        try:
            fsm_exist_ok.register_dir(self.TestObj(fsm, 'dir_dir',
                                                   'name_name'))
        except Exception as err:
            self.fail("Folder registrable test fail when exists_ok=True: [{}]"
                      .format(err))