Code example #1
Score: 0
File: main.py — Project: aashnamsft/elastic-stale
def main(local_rank, c10d_backend, rdzv_init_url, max_world_size, classy_args):
    """Entry point for a single elastic trainer process.

    Builds a Classy Vision task from the JSON config, restores any
    checkpoints, attaches logging/checkpoint/profiling hooks, and runs
    training through a peer-to-peer elastic coordinator.

    Args:
        local_rank: Rank of this process on the local host.
        c10d_backend: torch.distributed backend (NCCL or GLOO only).
        rdzv_init_url: Rendezvous init URL for the coordinator.
        max_world_size: Maximum number of trainers allowed to join.
        classy_args: Parsed Classy Vision command-line arguments.
    """
    torch.manual_seed(0)
    set_video_backend(classy_args.video_backend)

    # Build the task from the JSON config file.
    config = load_json(classy_args.config_file)
    task = build_task(config)

    # Resume from a prior checkpoint when one is available.
    task.set_checkpoint(load_checkpoint(classy_args.checkpoint_folder))

    pretrained = load_checkpoint(classy_args.pretrained_checkpoint_folder)
    if pretrained is not None:
        # Pretrained weights are only meaningful for fine-tuning tasks.
        assert isinstance(
            task, FineTuningTask
        ), "Can only use a pretrained checkpoint for fine tuning tasks"
        task.set_pretrained_checkpoint(pretrained)

    hooks = [
        LossLrMeterLoggingHook(classy_args.log_freq),
        ModelComplexityHook(),
        TimeMetricsHook(),
    ]
    if classy_args.checkpoint_folder != "":
        # Persist the full argument set (plus the resolved config)
        # alongside the checkpoints.
        args_dict = vars(classy_args)
        args_dict["config"] = config
        hooks.append(
            CheckpointHook(
                classy_args.checkpoint_folder,
                args_dict,
                checkpoint_period=classy_args.checkpoint_period,
            )
        )
    if classy_args.profiler:
        hooks.append(ProfilerHook())
    task.set_hooks(hooks)

    assert c10d_backend in (Backend.NCCL, Backend.GLOO)
    if c10d_backend == torch.distributed.Backend.NCCL:
        # needed to enable NCCL error handling
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

    coordinator = CoordinatorP2P(
        c10d_backend=c10d_backend,
        init_method=rdzv_init_url,
        max_num_trainers=max_world_size,
        process_group_timeout=60000,
    )
    trainer = ElasticTrainer(
        use_gpu=classy_args.device == "gpu",
        num_dataloader_workers=classy_args.num_workers,
        local_rank=local_rank,
        elastic_coordinator=coordinator,
        input_args={},
    )
    trainer.train(task)
Code example #2
Score: 0
File: args.py — Project: simran2905/ClassyVision-1
def parse_args():
    """Parse command-line arguments.

    Uses hydra-based argument parsing when hydra is installed
    (experimental); otherwise falls back to plain argparse.

    Returns:
        Tuple of (args, config).
    """
    # NOTE: a `global` statement applies to the whole function body, so
    # declaring it once here matches the original per-branch behavior.
    global args, config
    if hydra_available:
        _parse_hydra_args()
    else:
        args = parse_train_arguments()
        config = load_json(args.config_file)
    return args, config
Code example #3
Score: 0
    @hydra.main(config_path="hydra_configs", config_name="args")
    def hydra_main(cfg):
        """Hydra entry point: validate the config and dispatch to main().

        The omegaconf config sub-tree is converted to a plain container
        before being handed to main().
        """
        check_generic_args(cfg)
        main(cfg, omegaconf.OmegaConf.to_container(cfg.config))


# run all the things:
if __name__ == "__main__":
    # Configure the root logger explicitly (handler + INFO level).  The
    # original set only the level and relied on logging.info() implicitly
    # calling basicConfig() to attach a handler; doing it up front makes
    # the configuration deterministic and survives earlier handler setup.
    logging.basicConfig(level=logging.INFO)

    logging.info("Classy Vision's default training script.")

    # This imports all modules in the same directory as classy_train.py
    # Because of the way Classy Vision's registration decorators work,
    # importing a module has a side effect of registering it with Classy
    # Vision. This means you can give classy_train.py a config referencing your
    # custom module (e.g. my_dataset) and it'll actually know how to
    # instantiate it.
    file_root = Path(__file__).parent
    import_all_packages_from_directory(file_root)

    # Prefer the hydra entry point when hydra is installed; otherwise
    # fall back to the plain argparse flow.
    if hydra_available:
        hydra_main()
    else:
        args = parse_train_arguments()
        config = load_json(args.config_file)
        main(args, config)
Code example #4
Score: 0
    def test_load_config(self):
        """load_json should parse the JSON config file into the expected dict."""
        loaded = util.load_json(self._json_config_file)
        self.assertEqual(loaded, self._get_config())