Exemplo n.º 1
0
 def test_cfg_fail_on_empty(self):
     """Loading a config without any overrides must be rejected.

     ``SSLHydraConfig.from_configs()`` is called with no arguments and is
     expected to raise ``ConfigCompositionException``. The test fails if the
     call returns normally; any other exception type propagates as an error.
     """
     # we must specify the base config otherwise it fails
     with self.assertRaises(
         ConfigCompositionException,
         msg="We should fail if config is not specified",
     ):
         SSLHydraConfig.from_configs()
Exemplo n.º 2
0
 def test_cfg_fail_composition(self):
     """An override for a config group given without ``+`` must be rejected.

     The second override below names a config group without the ``+``
     prefix. Per the failure message, ``+`` is required because the group is
     not part of the defaults, so ``from_configs`` is expected to raise
     ``HydraException``. The test fails if the call returns normally.
     """
     # compose the configs with an invalid override and expect a failure
     with self.assertRaises(
         HydraException,
         msg=(
             "We should fail for invalid composition. "
             "+ is necessary as the group does not exists in defaults"
         ),
     ):
         SSLHydraConfig.from_configs([
             "config=test/integration_test/quick_simclr",
             "config/pretrain/simclr/models=resnext101",
         ])
Exemplo n.º 3
0
    def test_run(self, config_file_path: str):
        """
        Load the given config and run a training job from it.

        The test passes when ``train_main`` returns without raising.

        Arguments:
            config_file_path {str} -- path to the config for the task to be run
        """
        logger.info(f"Loading {config_file_path}")
        hydra_cfg = SSLHydraConfig.from_configs([config_file_path])
        _, config = convert_to_attrdict(hydra_cfg.default_cfg)
        ckpt_dir = get_checkpoint_folder(config)

        # The training data location is only known at runtime: resolve the
        # "test_data" resource relative to this module.
        data_dir = pkg_resources.resource_filename(__name__, "test_data")
        config.DATA.TRAIN.DATA_PATHS = [data_dir]

        # Launch training with local rank 0 on node 0 and no starting
        # checkpoint; any exception raised here fails the test.
        train_main(
            config,
            dist_run_id=get_dist_run_id(config, config.DISTRIBUTED.NUM_NODES),
            checkpoint_path=None,
            checkpoint_folder=ckpt_dir,
            local_rank=0,
            node_id=0,
            hook_generator=default_hook_generator,
        )
Exemplo n.º 4
0
 def test_benchmark_model(self, filepath: str):
     """Build the model described by the config at ``filepath``.

     The config is loaded with ``NUM_PROC_PER_NODE`` overridden to 1. Configs
     for which ``is_fsdp_model_config`` returns true are loaded but their
     model is not built.
     """
     logger.info(f"Loading {filepath}")
     overrides = [filepath, "config.DISTRIBUTED.NUM_PROC_PER_NODE=1"]
     hydra_cfg = SSLHydraConfig.from_configs(overrides)
     _, config = convert_to_attrdict(hydra_cfg.default_cfg)
     if is_fsdp_model_config(config):
         # FSDP model configs are skipped by this test
         return
     build_model(config.MODEL, config.OPTIMIZER)
Exemplo n.º 5
0
    def test_run(self, config_file_path: str):
        """
        Load the given config and run a training job from it.

        Any already-initialized torch.distributed process group is destroyed
        first. The test passes when ``train_main`` returns without raising.

        Arguments:
            config_file_path {str} -- path to the config for the task to be run
        """
        logger.info(f"Loading {config_file_path}")
        hydra_cfg = SSLHydraConfig.from_configs([config_file_path])
        _, config = convert_to_attrdict(hydra_cfg.default_cfg)
        ckpt_dir = get_checkpoint_folder(config)

        # The training data location is only known at runtime: resolve the
        # "test_data" resource relative to this module.
        data_dir = pkg_resources.resource_filename(__name__, "test_data")
        config.DATA.TRAIN.DATA_PATHS = [data_dir]

        # torch may already be initialized with NCCL, which is incompatible
        # with test_cpu_regnet_moco.yaml, so tear down any live process group.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

        # Launch training with local rank 0 on node 0 and no starting
        # checkpoint; any exception raised here fails the test.
        train_main(
            config,
            dist_run_id=get_dist_run_id(config, config.DISTRIBUTED.NUM_NODES),
            checkpoint_path=None,
            checkpoint_folder=ckpt_dir,
            local_rank=0,
            node_id=0,
            hook_generator=default_hook_generator,
        )
Exemplo n.º 6
0
 def test_load_cfg_success(self):
     """Loading the quick_simclr integration config yields a truthy object."""
     # a plain load with only the base config selected should pass
     overrides = ["config=test/integration_test/quick_simclr"]
     loaded = SSLHydraConfig.from_configs(overrides)
     self.assertTrue(loaded, "config must be loaded successfully")
Exemplo n.º 7
0
 def test_cfg_composition(self):
     """Adding the resnext101 models group must set the ResNet depth to 101."""
     # base config plus an extra config group appended with "+"
     overrides = [
         "config=test/integration_test/quick_simclr",
         "+config/pretrain/simclr/models=resnext101",
     ]
     hydra_cfg = SSLHydraConfig.from_configs(overrides)
     _, config = convert_to_attrdict(hydra_cfg.default_cfg)
     depth = config.MODEL.TRUNK.RESNETS.DEPTH
     self.assertEqual(depth, 101, "config composition failed")
Exemplo n.º 8
0
 def test_loss_build(self, filepath):
     """Build the loss for the task configured at ``filepath``.

     Train and test data sources are overridden to ``synthetic`` before the
     task is created. The test passes when ``_build_loss`` returns a truthy
     value.
     """
     logger.info(f"Loading {filepath}")
     overrides = [
         filepath,
         "config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
         "config.DATA.TEST.DATA_SOURCES=[synthetic]",
     ]
     hydra_cfg = SSLHydraConfig.from_configs(overrides)
     _, config = convert_to_attrdict(hydra_cfg.default_cfg)
     task = SelfSupervisionTask.from_config(config)
     # datasets are attached to the task before the loss is built
     datasets, _ = task.build_datasets()
     task.datasets = datasets
     loss = task._build_loss()
     self.assertTrue(loss, "failed to build loss")
Exemplo n.º 9
0
 def test_cfg_key_addition(self):
     """A ``+``-prefixed override must insert a new key into the config.

     ``MY_TEST_KEY`` is added under
     ``LOSS.simclr_info_nce_loss.buffer_params`` through an override, and the
     converted config is checked for that key.
     """
     # compose the configs and check that the new key is inserted
     cfg = SSLHydraConfig.from_configs([
         "config=test/integration_test/quick_simclr",
         "+config.LOSS.simclr_info_nce_loss.buffer_params.MY_TEST_KEY=dummy",
     ])
     _, config = convert_to_attrdict(cfg.default_cfg)
     # assertIn performs the same containment check as ``in`` but reports the
     # member and container when it fails
     self.assertIn(
         "MY_TEST_KEY",
         config.LOSS.simclr_info_nce_loss.buffer_params,
         "something went wrong, new key not added. Fail.",
     )
Exemplo n.º 10
0
 def test_pytorch_loss(self):
     """Build a task loss when ``LOSS.name`` is set to ``CosineEmbeddingLoss``.

     The loss name and a ``margin`` parameter are supplied as overrides, and
     the data sources are switched to ``synthetic``. The test passes when
     ``_build_loss`` returns a truthy value.
     """
     overrides = [
         "config=test/integration_test/quick_simclr",
         "config.LOSS.name=CosineEmbeddingLoss",
         "+config.LOSS.CosineEmbeddingLoss.margin=1.0",
         "config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
         "config.DATA.TEST.DATA_SOURCES=[synthetic]",
     ]
     hydra_cfg = SSLHydraConfig.from_configs(overrides)
     _, config = convert_to_attrdict(hydra_cfg.default_cfg)
     task = SelfSupervisionTask.from_config(config)
     # datasets are attached to the task before the loss is built
     datasets, _ = task.build_datasets()
     task.datasets = datasets
     loss = task._build_loss()
     self.assertTrue(loss, "failed to build loss")
Exemplo n.º 11
0
 def test_sqrt_lr_scaling(self):
     """Check the LR end value when ``sqrt`` auto LR scaling is enabled.

     The overrides switch on ``auto_lr_scaling.auto_scale``, select the
     ``linear`` scheduler and the ``sqrt`` scaling type. The resulting
     ``end_value`` is expected to be ``0.3 * sqrt(0.125)``.

     NOTE(review): presumably 0.3 is the unscaled end value and 0.125 the
     scaling ratio derived from the config -- confirm against quick_simclr.
     """
     # compose the configs and check that the LR is changed
     cfg = SSLHydraConfig.from_configs([
         "config=test/integration_test/quick_simclr",
         "+config/pretrain/simclr/models=resnext101",
         "config.OPTIMIZER.param_schedulers.lr.auto_lr_scaling.auto_scale=True",
         'config.OPTIMIZER.param_schedulers.lr.name="linear"',
         'config.OPTIMIZER.param_schedulers.lr.auto_lr_scaling.scaling_type="sqrt"',
     ])
     _, config = convert_to_attrdict(cfg.default_cfg)
     param_schedulers = config.OPTIMIZER.param_schedulers.lr
     # Compare with a tolerance: the expected value goes through a floating
     # point square root, so bit-for-bit equality would be brittle.
     self.assertAlmostEqual(0.3 * (0.125**0.5), param_schedulers.end_value)
Exemplo n.º 12
0
 def test_cfg_cli_composition(self):
     """Dotted-key overrides must be reflected in the composed config.

     ``GROUPS`` and ``WIDTH_PER_GROUP`` under
     ``MODEL.TRUNK.TRUNK_PARAMS.RESNETS`` are overridden to 32 and 16, and the
     converted config is checked for both values.
     """
     overrides = [
         "config=test/integration_test/quick_simclr",
         "+config/pretrain/simclr/models=resnext101",
         "config.MODEL.TRUNK.TRUNK_PARAMS.RESNETS.GROUPS=32",
         "config.MODEL.TRUNK.TRUNK_PARAMS.RESNETS.WIDTH_PER_GROUP=16",
     ]
     hydra_cfg = SSLHydraConfig.from_configs(overrides)
     _, config = convert_to_attrdict(hydra_cfg.default_cfg)
     failure_msg = "config composition failed"
     resnets = config.MODEL.TRUNK.TRUNK_PARAMS.RESNETS
     self.assertEqual(resnets.GROUPS, 32, failure_msg)
     self.assertEqual(resnets.WIDTH_PER_GROUP, 16, failure_msg)
Exemplo n.º 13
0
 def test_integration_test_model(self, filepath: str):
     """Build the model described by the config at ``filepath``.

     Configs for which ``is_fsdp_model_config`` returns true are loaded but
     their model is not built.
     """
     logger.info(f"Loading {filepath}")
     hydra_cfg = SSLHydraConfig.from_configs([filepath])
     _, config = convert_to_attrdict(hydra_cfg.default_cfg)
     if is_fsdp_model_config(config):
         # FSDP model configs are skipped by this test
         return
     build_model(config.MODEL, config.OPTIMIZER)
Exemplo n.º 14
0
 def test_pretrain_model(self, filepath):
     """Build the model described by the config at ``filepath``.

     The test passes when ``build_model`` returns without raising.
     """
     logger.info(f"Loading {filepath}")
     hydra_cfg = SSLHydraConfig.from_configs([filepath])
     _, config = convert_to_attrdict(hydra_cfg.default_cfg)
     model_cfg, optimizer_cfg = config.MODEL, config.OPTIMIZER
     build_model(model_cfg, optimizer_cfg)
Exemplo n.º 15
0
 def test_meter_build(self, filepath):
     """Build the meters for the task configured at ``filepath``.

     NOTE(review): ``len(meters) >= 0`` holds for any sized object, so the
     assertion cannot fail on its own. In effect the test checks that meter
     construction completes and returns something ``len()`` accepts.
     """
     logger.info(f"Loading {filepath}")
     hydra_cfg = SSLHydraConfig.from_configs([filepath])
     _, config = convert_to_attrdict(hydra_cfg.default_cfg)
     task = SelfSupervisionTask.from_config(config)
     meters = task._build_meters()
     self.assertGreaterEqual(len(meters), 0, "Failed to build meters")
Exemplo n.º 16
0
 def test_integration_test_config(self, filepath):
     """Loading the config at ``filepath`` must yield a truthy object."""
     logger.warning(f"Loading {filepath}")
     loaded = SSLHydraConfig.from_configs([filepath])
     self.assertTrue(loaded, "config must be loaded successfully")