def test_continue_from_checkpoint(self):
    """Save a ``DataProcessor`` state and load it into a freshly built model.

    ``save_state`` and ``load`` must raise ``Model.ModelException`` while no
    ``CheckpointsManager`` is attached and succeed once one is attached; the
    reloaded model is then compared to the saved one with
    ``compare_two_models``.
    """
    model = SimpleModel().train()
    dp = DataProcessor(model=model)

    # Without a checkpoints manager, saving must raise.
    with self.assertRaises(Model.ModelException):
        dp.save_state()
    try:
        fsm = FileStructManager(self.base_dir, is_continue=False)
        dp.set_checkpoints_manager(CheckpointsManager(fsm))
        dp.save_state()
    except Exception as err:
        self.fail('Fail to DataProcessor save state when CheckpointsManager '
                  'was defined: [{}]'.format(err))

    del dp

    model_new = SimpleModel().train()
    dp = DataProcessor(model=model_new)

    # Without a checkpoints manager, loading must raise as well.
    with self.assertRaises(Model.ModelException):
        dp.load()
    try:
        # is_continue=True reopens the base directory written above.
        fsm = FileStructManager(self.base_dir, is_continue=True)
        dp.set_checkpoints_manager(CheckpointsManager(fsm))
        dp.load()
    except Exception as err:
        self.fail('Fail to DataProcessor load when CheckpointsManager was '
                  'defined: [{}]'.format(err))

    compare_two_models(self, model, model_new)
Beispiel #2
0
    def test_creation(self):
        """Check that ``FileStructManager`` can be constructed repeatedly.

        Construction must not raise, must not create the base directory by
        itself, and must tolerate unregistered sub-folders already present in
        the base directory.
        """
        # NOTE(review): the guard checks ``self.base_dir`` but removes
        # ``self.checkpoints_dir``; confirm which directory should be reset.
        if os.path.exists(self.base_dir):
            shutil.rmtree(self.checkpoints_dir, ignore_errors=True)

        try:
            FileStructManager(base_dir=self.base_dir, is_continue=False)
        except FileStructManager.FSMException as err:
            self.fail(
                "Raise error when base directory exists: [{}]".format(err))

        # Constructing the manager must not create the base directory.
        self.assertFalse(os.path.exists(self.base_dir))

        try:
            FileStructManager(base_dir=self.base_dir, is_continue=False)
        except FileStructManager.FSMException as err:
            self.fail("Raise error when base directory exists but empty: [{}]".
                      format(err))

        try:
            # Unregistered folders inside the base directory must be accepted.
            os.makedirs(os.path.join(self.base_dir, 'new_dir'))
            try:
                FileStructManager(base_dir=self.base_dir, is_continue=False)
            except Exception as err:
                self.fail(
                    "Error initialize when exists non-registered folders in "
                    "base directory: [{}]".format(err))
        finally:
            # Clean up even when an assertion above fails.
            shutil.rmtree(self.base_dir, ignore_errors=True)
Beispiel #3
0
    def test_pack(self):
        """Check ``CheckpointsManager.pack`` error handling and archive output.

        ``pack`` must raise ``SMException`` when the state files are missing or
        are directories. It must remove the packed weights/optimizer files and
        write ``last_checkpoint.zip``. Packing a second time must keep the
        previous archive as ``last_checkpoint.zip.old``.
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        def touch_state_files():
            # Create empty weights, optimizer-state and trainer files.
            for path in (cm.weights_file(), cm.optimizer_state_file(),
                         cm.trainer_file()):
                with open(path, 'w'):
                    pass

        # Nothing to pack yet.
        with self.assertRaises(CheckpointsManager.SMException):
            cm.pack()

        # Directories in place of the state files are rejected too.
        os.mkdir(cm.weights_file())
        os.mkdir(cm.optimizer_state_file())
        with self.assertRaises(CheckpointsManager.SMException):
            cm.pack()

        shutil.rmtree(cm.weights_file())
        shutil.rmtree(cm.optimizer_state_file())

        touch_state_files()
        try:
            cm.pack()
        except Exception as err:
            self.fail('Exception on packing files: [{}]'.format(err))

        for path in (cm.weights_file(), cm.optimizer_state_file()):
            if os.path.isfile(path):
                self.fail("File '{}' doesn't remove after pack".format(path))

        checkpoints_dir = fsm.get_path(cm, check=False,
                                       create_if_non_exists=False)
        result = os.path.join(checkpoints_dir, 'last_checkpoint.zip')
        self.assertTrue(os.path.isfile(result))

        # A second pack must preserve the previous archive as '*.zip.old'.
        touch_state_files()
        try:
            cm.pack()
        except Exception as err:
            self.fail('Fail to pack with existing previous state file: [{}]'
                      .format(err))
        # Asserted outside the try so a failure is reported as itself.
        result = os.path.join(checkpoints_dir, 'last_checkpoint.zip.old')
        self.assertTrue(os.path.isfile(result))
Beispiel #4
0
    def test_initialisation(self):
        """Check that only one ``CheckpointsManager`` can be built per manager.

        The first construction must succeed. A second one on the same
        ``FileStructManager`` must raise ``FSMException``, both immediately and
        after a sub-directory appears in the checkpoints directory.
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)

        try:
            manager = CheckpointsManager(fsm)
        except Exception as exc:
            self.fail("Fail init CheckpointsManager; err: ['{}']".format(exc))

        duplicate_error = FileStructManager.FSMException
        with self.assertRaises(duplicate_error):
            CheckpointsManager(fsm)

        manager_dir = fsm.get_path(manager)
        os.mkdir(os.path.join(manager_dir, 'test_dir'))
        with self.assertRaises(duplicate_error):
            CheckpointsManager(fsm)
    def test_continue_from_checkpoint(self):
        """Round-trip ``TrainDataProcessor`` state through a checkpoint.

        For each optimizer (SGD and Adam):
        - ``save_state`` and ``load`` must raise ``Model.ModelException``
          until a ``CheckpointsManager`` is attached, and succeed afterwards.
        - After loading, the model weights must equal a snapshot taken before
          saving.
        - Both processors must report the same learning rate.
        """
        import copy

        def on_node(n1, n2):
            self.assertTrue(np.array_equal(n1.numpy(), n2.numpy()))

        model = SimpleModel().train()
        loss = SimpleLoss()

        for optim in [
                torch.optim.SGD(model.parameters(), lr=0.1),
                torch.optim.Adam(model.parameters(), lr=0.1)
        ]:
            train_config = TrainConfig([], loss, optim)

            dp_before = TrainDataProcessor(model=model,
                                           train_config=train_config)
            # Deep copy: ``state_dict()`` tensors share storage with the live
            # parameters. A shallow ``.copy()`` would therefore be rewritten in
            # place by a later weight load, making the comparison vacuous.
            before_state_dict = copy.deepcopy(model.state_dict())
            dp_before.update_lr(0.023)

            try:
                # Without a checkpoints manager, saving must raise.
                with self.assertRaises(Model.ModelException):
                    dp_before.save_state()
                try:
                    fsm = FileStructManager(base_dir=self.base_dir,
                                            is_continue=False)
                    dp_before.set_checkpoints_manager(CheckpointsManager(fsm))
                    dp_before.save_state()
                except Exception as err:
                    self.fail(
                        "Exception on saving state when 'CheckpointsManager' "
                        "specified: [{}]".format(err))

                fsm = FileStructManager(base_dir=self.base_dir,
                                        is_continue=True)
                # NOTE(review): dp_after wraps the same ``model`` and
                # ``train_config`` (hence the same optimizer) as dp_before, so
                # the lr comparison reads one shared optimizer. A fresh
                # model/optimizer would make this a true round-trip test.
                dp_after = TrainDataProcessor(model=model,
                                              train_config=train_config)
                with self.assertRaises(Model.ModelException):
                    dp_after.load()
                try:
                    cm = CheckpointsManager(fsm)
                    dp_after.set_checkpoints_manager(cm)
                    dp_after.load()
                except Exception as err:
                    self.fail('DataProcessor loading raises exception: [{}]'
                              .format(err))

                after_state_dict = model.state_dict()

                dict_pair_recursive_bypass(before_state_dict, after_state_dict,
                                           on_node)
                self.assertEqual(dp_before.get_lr(), dp_after.get_lr())
            finally:
                # Reset the directory even on failure so the next optimizer
                # (and the next test) starts from a clean state.
                shutil.rmtree(self.base_dir, ignore_errors=True)
Beispiel #6
0
    def test_unpack(self):
        """Check that ``unpack`` restores the files and contents ``pack`` archived.

        The weights, optimizer-state and trainer files get the distinct
        contents '1', '2' and '3', so a mix-up or truncation is detectable.
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        state_files = [
            cm.weights_file(),
            cm.optimizer_state_file(),
            cm.trainer_file(),
        ]
        # Each handle is closed (and flushed) before pack() reads the file.
        for i, path in enumerate(state_files):
            with open(path, 'w') as file:
                file.write(str(i + 1))

        cm.pack()

        try:
            cm.unpack()
        except Exception as err:
            self.fail('Exception on unpacking: [{}]'.format(err))

        for i, path in enumerate(state_files):
            if not os.path.isfile(path):
                self.fail("File '{}' doesn't exist after unpack".format(path))
            with open(path, 'r') as file:
                if file.read() != str(i + 1):
                    self.fail("File content corrupted")
    def test_base_execution(self):
        """Check that ``LogMonitor`` uses and creates the expected log file."""
        fsm = FileStructManager(base_dir='data', is_continue=False)
        expected_out = os.path.join('data', 'monitors', 'metrics_log',
                                    'metrics_log.json')
        try:
            with LogMonitor(fsm) as m:
                self.assertEqual(m._file, expected_out)
        except self.failureException:
            # Let a genuine assertion failure surface with its own message
            # instead of being masked as an initialisation failure.
            raise
        except Exception as err:
            self.fail('Fail initialisation: [{}]'.format(err))

        self.assertTrue(os.path.isfile(expected_out))
Beispiel #8
0
    def test_clear_files(self):
        """Check that ``clear_files`` deletes the weights and optimizer-state files."""
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        # Create empty state files.
        for path in (cm.weights_file(), cm.optimizer_state_file()):
            with open(path, 'w'):
                pass

        cm.clear_files()

        for path in (cm.weights_file(), cm.optimizer_state_file()):
            if os.path.isfile(path):
                self.fail("File '{}' doesn't remove after clear_files"
                          .format(path))
Beispiel #9
0
def continue_training():
    """Resume the 'PortraitSegmentation' example training from a checkpoint.

    Rebuilds the model, loss and optimizer and reopens the ``data`` directory
    with ``is_continue=True``. It then attaches TensorBoard and JSON-log
    monitors, best-state saving and lr decaying, and resumes training for 10
    epochs on ``cuda:0``. Relies on module-level ``train_stage`` and
    ``val_stage`` defined elsewhere in the file.
    """
    # Everything the trainer needs has to be recreated before resuming.
    network = resnet18(classes_num=1, in_channels=3, pretrained=True)
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)
    config = TrainConfig([train_stage, val_stage], criterion, optimizer)

    # Reopening an existing experiment directory: pass is_continue=True.
    fsm = FileStructManager(base_dir='data', is_continue=True)

    trainer = Trainer(network, config, fsm, torch.device('cuda:0'))
    trainer = trainer.set_epoch_num(10)

    tensorboard = TensorboardMonitor(fsm,
                                     is_continue=False,
                                     network_name='PortraitSegmentation')
    log_monitor = LogMonitor(fsm).write_final_metrics()
    trainer.monitor_hub.add_monitor(tensorboard).add_monitor(log_monitor)

    def mean_train_loss():
        # Target value for both best-state tracking and lr decaying.
        return np.mean(train_stage.get_losses())

    def log_lr():
        # Publish the current learning rate to TensorBoard after each epoch.
        return tensorboard.update_scalar('params/lr',
                                         trainer.data_processor().get_lr())

    trainer.enable_best_states_saving(mean_train_loss)
    trainer.enable_lr_decaying(coeff=0.5,
                               patience=10,
                               target_val_clbk=mean_train_loss)
    trainer.add_on_epoch_end_callback(log_lr)

    # resume() puts the trainer in resume mode; from_best_checkpoint=False
    # presumably selects the last checkpoint rather than the best one.
    trainer.resume(from_best_checkpoint=False).train()
    def metrics_processing(self,
                           with_final_file: bool,
                           final_file: str = None):
        """Drive ``LogMonitor`` for ten epochs and verify its JSON metrics log.

        Metric layout:
        - top level: ``d`` (coeff 4);
        - group ``lv1``: ``a`` (coeff 1) and ``b`` (coeff 2);
        - group ``lv2``, nested in ``lv1``: ``c`` (coeff 3).

        Each epoch feeds ten random ``(1, 3)`` output/target pairs.

        Args:
            with_final_file: whether to call ``write_final_metrics`` on the
                monitor before logging.
            final_file: argument forwarded to ``write_final_metrics``.

        Returns:
            Per-epoch mean of ``np.linalg.norm(output - target)``. Each logged
            metric is asserted to equal this value times its ``coeff``.
        """
        fsm = FileStructManager(base_dir='data', is_continue=False)
        expected_out = os.path.join('data', 'monitors', 'metrics_log',
                                    'metrics_log.json')

        metrics_group_lv1 = MetricsGroup('lv1').add(
            SimpleMetric(name='a',
                         coeff=1)).add(SimpleMetric(name='b', coeff=2))
        metrics_group_lv2 = MetricsGroup('lv2').add(
            SimpleMetric(name='c', coeff=3))
        metrics_group_lv1.add(metrics_group_lv2)
        m = SimpleMetric(name='d', coeff=4)

        values = []
        with LogMonitor(fsm) as monitor:
            if with_final_file:
                monitor.write_final_metrics(final_file)
            for epoch in range(10):
                cur_vals = []
                for _ in range(10):
                    output, target = torch.rand(1, 3), torch.rand(1, 3)
                    metrics_group_lv1.calc(output, target)
                    m._calc(output, target)
                    cur_vals.append(
                        np.linalg.norm(output.numpy() - target.numpy()))

                values.append(float(np.mean(cur_vals)))
                monitor.set_epoch_num(epoch)
                monitor.update_metrics({
                    'metrics': [m],
                    'groups': [metrics_group_lv1]
                })
                # Reset between epochs so each logged value covers one epoch.
                m.reset()
                metrics_group_lv1.reset()

        self.assertTrue(os.path.isfile(expected_out))

        with open(expected_out, 'r') as file:
            data = json.load(file)

        # Groups appear as nested dicts keyed by group name; each metric maps
        # to a list with one entry per epoch.
        self.assertIn('d', data)
        self.assertIn('lv1', data)
        self.assertIn('lv2', data['lv1'])
        self.assertIn('a', data['lv1'])
        self.assertIn('b', data['lv1'])
        self.assertIn('c', data['lv1']['lv2'])

        self.assertEqual(len(data['d']), len(values))
        self.assertEqual(len(data['lv1']['a']), len(values))
        self.assertEqual(len(data['lv1']['b']), len(values))
        self.assertEqual(len(data['lv1']['lv2']['c']), len(values))

        for i, value in enumerate(values):
            self.assertAlmostEqual(data['d'][i], value * 4, delta=1e-5)
            self.assertAlmostEqual(data['lv1']['a'][i], value, delta=1e-5)
            self.assertAlmostEqual(data['lv1']['b'][i],
                                   value * 2,
                                   delta=1e-5)
            self.assertAlmostEqual(data['lv1']['lv2']['c'][i],
                                   value * 3,
                                   delta=1e-5)

        return values
Beispiel #11
0
    def test_object_registration(self):
        """Check ``FileStructManager.register_dir`` duplicate detection.

        Re-registering a directory or a name must raise ``FSMException``
        unless ``check_dir_registered`` / ``check_name_registered`` is
        disabled. A directory that already exists on disk is rejected by the
        default manager but accepted by one built with ``exists_ok=True``.
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        fsm_exist_ok = FileStructManager(base_dir=self.base_dir,
                                         is_continue=False,
                                         exists_ok=True)
        o = self.TestObj(fsm, 'test_dir', 'test_name')
        fsm.register_dir(o)

        # Registration alone does not create the directory on disk.
        expected_path = os.path.join(self.base_dir, 'test_dir')
        self.assertFalse(os.path.exists(expected_path))
        self.assertEqual(fsm.get_path(o), expected_path)

        # Reusing a registered directory fails unless the check is disabled.
        with self.assertRaises(FileStructManager.FSMException):
            fsm.register_dir(self.TestObj(fsm, 'test_dir', 'another_name'))
        try:
            fsm.register_dir(self.TestObj(fsm, 'test_dir', 'another_name'),
                             check_dir_registered=False)
            fsm_exist_ok.register_dir(
                self.TestObj(fsm, 'test_dir', 'another_name'))
            fsm_exist_ok.register_dir(self.TestObj(fsm, 'test_dir',
                                                   'another_name2'),
                                      check_dir_registered=False)
        except Exception as err:
            self.fail("Folder registrable test fail when it's disabled: [{}]"
                      .format(err))

        # Reusing a registered name fails unless the check is disabled.
        # NOTE(review): in each block below only the first call is observable;
        # once it raises, the second statement never runs.
        with self.assertRaises(FileStructManager.FSMException):
            fsm.register_dir(self.TestObj(fsm, 'another_dir', 'test_name'))
            fsm.register_dir(self.TestObj(fsm, 'another_dir', 'another_name'))
        with self.assertRaises(FileStructManager.FSMException):
            fsm_exist_ok.register_dir(
                self.TestObj(fsm, 'another_dir', 'test_name'))
            fsm_exist_ok.register_dir(
                self.TestObj(fsm, 'another_dir', 'another_name'))

        try:
            fsm.register_dir(self.TestObj(fsm, 'another_dir2', 'test_name'),
                             check_name_registered=False)
            fsm_exist_ok.register_dir(self.TestObj(fsm, 'another_dir2',
                                                   'test_name'),
                                      check_name_registered=False)
        except Exception as err:
            self.fail("Folder registrable test fail when it's disabled: [{}]"
                      .format(err))

        # A directory already present on disk is only accepted with exists_ok.
        os.makedirs(os.path.join(self.base_dir, 'dir_dir', 'dir'))
        with self.assertRaises(FileStructManager.FSMException):
            fsm.register_dir(self.TestObj(fsm, 'dir_dir', 'name_name'))

        try:
            fsm_exist_ok.register_dir(self.TestObj(fsm, 'dir_dir',
                                                   'name_name'))
        except Exception as err:
            self.fail("Folder registrable test fail when exists_ok=True: [{}]"
                      .format(err))