Пример #1
0
 def test_restore_from_experiment(self, tmp_path_factory):
     """Run ``classification_task`` once per architecture variant, restoring from an experiment.

     For variant ``<index>`` a config named ``restore_experiment_<index>`` is built
     and ``classification_task`` receives ``<basetemp>/experiments/teacher_<index>``
     as its fifth positional argument (presumably the experiment directory to
     restore from; it is expected to already exist, likely produced by the
     teacher-training test -- TODO confirm test ordering).
     """
     for idx, variant in enumerate(self.arch_variants):
         root = tmp_path_factory.getbasetemp()
         experiments_dir = root / 'experiments'
         cfg = get_base_config_template(root, f'restore_experiment_{idx}', variant)
         restore_dir = experiments_dir / f'teacher_{idx}'
         classification_task(
             cfg,
             experiments_dir,
             RandomQuantDataLoader,
             get_tensorboard_hooks,
             restore_dir,
         )
Пример #2
0
def test_run_experiment_on_platform(tmp_path):
    """Run a dummy experiment through ``LocalComputePlatform`` and check its artifacts.

    After ``platform.run`` the directory ``tmp_path / experiment.name`` must contain
    ``config.yaml`` and ``metrics/test.csv``.
    """
    quant_settings = {
        'x_quant': 'ls-2',
        'w_quant': 'ls-1'
    }
    config = get_base_config_template(tmp_path, 'dummy_experiment', quant_settings)

    local_platform = LocalComputePlatform(str(tmp_path))

    experiment = Experiment(
        classification_task,
        config,
        RandomQuantDataLoader,
        get_tensorboard_hooks,
    )
    local_platform.run(experiment)

    # Output location is derived from the experiment name after the run.
    run_dir = tmp_path / experiment.name
    assert (run_dir / 'config.yaml').exists()
    assert (run_dir / 'metrics' / 'test.csv').exists()
Пример #3
0
 def test_init_from_checkpoint(self, tmp_path_factory):
     """Initialise each architecture variant from its teacher's first checkpoint.

     Sets ``config['init_from_checkpoint']`` to the string path
     ``<basetemp>/experiments/teacher_<index>/checkpoints/checkpoint_1.pt`` and runs
     ``classification_task`` under the experiment name ``init_from_checkpoint_<index>``.
     The checkpoint file is expected to exist already (presumably written by the
     teacher-training test -- TODO confirm test ordering).
     """
     for idx, variant in enumerate(self.arch_variants):
         root = tmp_path_factory.getbasetemp()
         experiments_dir = root / 'experiments'
         cfg = get_base_config_template(root, f'init_from_checkpoint_{idx}', variant)
         checkpoint = experiments_dir / f'teacher_{idx}' / 'checkpoints' / 'checkpoint_1.pt'
         # The config stores the path as a plain string.
         cfg['init_from_checkpoint'] = str(checkpoint)
         classification_task(
             cfg,
             experiments_dir,
             RandomQuantDataLoader,
             get_tensorboard_hooks,
         )
Пример #4
0
    def test_train_regular_classification_task(self, tmp_path_factory):
        """Train one model from scratch per architecture variant, to serve as the teacher.

        Each run is named ``teacher_<index>``. After training, the config used is
        written with ``yaml.dump`` to
        ``<basetemp>/experiments/teacher_<index>/config.yaml``. Presumably the
        student test reads it through ``teacher_config_path``.
        """
        for idx, variant in enumerate(self.arch_variants):
            root = tmp_path_factory.getbasetemp()
            experiments_dir = root / 'experiments'
            cfg = get_base_config_template(root, f'teacher_{idx}', variant)
            classification_task(
                cfg,
                experiments_dir,
                RandomQuantDataLoader,
                get_tensorboard_hooks,
            )

            # Persist the exact config next to the run. Mode 'w' replaces any
            # config.yaml that classification_task may have written there.
            config_path = str(experiments_dir / f'teacher_{idx}' / 'config.yaml')
            with open(config_path, 'w') as handle:
                yaml.dump(cfg, handle)
Пример #5
0
    def test_train_student(self, tmp_path_factory):
        """Train a student model per architecture variant using the matching teacher.

        For variant ``<index>`` the student config (experiment name ``student_<index>``)
        gets a ``model.kd_config`` entry. It points at
        ``teacher_<index>/config.yaml`` and
        ``teacher_<index>/checkpoints/checkpoint_1.pt``, sets ``freeze_teacher`` and
        ``train_mode`` to True, and sets the criterion ``temperature`` to 1.
        """
        for idx, variant in enumerate(self.arch_variants):
            root = tmp_path_factory.getbasetemp()
            experiments_dir = root / 'experiments'
            teacher_dir = experiments_dir / f'teacher_{idx}'
            cfg = get_base_config_template(root, f'student_{idx}', variant)

            # Paths are stored as plain strings in the config.
            kd_settings = {
                'teacher_config_path': str(teacher_dir / 'config.yaml'),
                'teacher_checkpoint_path': str(
                    teacher_dir / 'checkpoints' / 'checkpoint_1.pt'
                ),
                'freeze_teacher': True,
                'train_mode': True,
                'criterion_config': {'temperature': 1},
            }
            cfg['model']['kd_config'] = kd_settings

            classification_task(
                cfg,
                experiments_dir,
                RandomQuantDataLoader,
                get_tensorboard_hooks,
            )