Code example #1
def test_resolve_distributed_mode_slurm3():
    args = argparse.Namespace(
        multiprocessing_distributed=True,
        dist_world_size=None,
        dist_rank=None,
        ngpu=1,
        local_rank=None,
        dist_launcher="slurm",
        dist_backend="nccl",
        dist_init_method="env://",
        dist_master_addr=None,
        dist_master_port=10000,
    )
    env = dict(
        SLURM_PROCID="0",
        SLURM_NTASKS="1",
        SLURM_STEP_NUM_NODES="1",
        SLURM_STEP_NODELIST="localhost",
        SLURM_NODEID="0",
        CUDA_VISIBLE_DEVICES="0,1",
    )

    e = ProcessPoolExecutor(max_workers=2)
    with unittest.mock.patch.dict("os.environ", dict(env, SLURM_LOCALID="0")):
        resolve_distributed_mode(args)
        option = build_dataclass(DistributedOption, args)
        fn = e.submit(option.init)

    with unittest.mock.patch.dict("os.environ", dict(env, SLURM_LOCALID="0")):
        option2 = build_dataclass(DistributedOption, args)
        fn2 = e.submit(option2.init)

    fn.result()
    fn2.result()
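
The snippets on this page are shown without their module-level imports. Assuming they come from the ESPnet test suite around espnet2.train.distributed_utils and espnet2.utils.build_dataclass, a header along the following lines would be needed to run them (the exact module paths are an assumption based on the names used above):

import argparse
import unittest.mock
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

import pytest

# Assumed locations; adjust to the actual project layout.
from espnet2.tasks.abs_task import AbsTask
from espnet2.train.distributed_utils import (
    DistributedOption,
    free_port,
    resolve_distributed_mode,
)
from espnet2.utils.build_dataclass import build_dataclass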
Code example #2
def test_default_work():
    parser = AbsTask.get_parser()
    args = parser.parse_args([])
    resolve_distributed_mode(args)
    option = build_dataclass(DistributedOption, args)
    option.init_options()
    option.init_torch_distributed()
Code example #3
def test_init_cpu5():
    args = argparse.Namespace(
        multiprocessing_distributed=True,
        dist_world_size=2,
        dist_rank=None,
        ngpu=0,
        local_rank=None,
        dist_launcher=None,
        distributed=True,
        dist_backend="gloo",
        dist_init_method="env://",
        dist_master_addr="localhost",
        dist_master_port=free_port(),
    )
    args.dist_rank = 0
    option = build_dataclass(DistributedOption, args)
    args.dist_rank = 1
    option2 = build_dataclass(DistributedOption, args)
    with ProcessPoolExecutor(max_workers=2) as e:
        fn = e.submit(option.init)
        fn2 = e.submit(option2.init)
        fn.result()
        fn2.result()
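
free_port() is not defined in this listing. A minimal sketch of such a helper, assuming it simply asks the OS for a currently unused TCP port (another process may still grab the port between this call and its actual use, so it is only a best-effort choice):

import socket


def free_port():
    # Bind to port 0 so the OS assigns a free port, then report it.
    # The port is not reserved, so the later bind by torch.distributed
    # may still race with other processes.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]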
Code example #4
def test_init_cpu3():
    args = argparse.Namespace(
        multiprocessing_distributed=True,
        dist_world_size=2,
        dist_rank=None,
        ngpu=0,
        local_rank=None,
        dist_launcher=None,
        distributed=True,
        dist_backend="gloo",
        dist_init_method="env://",
        dist_master_addr="localhost",
        dist_master_port=None,
    )
    args.dist_rank = 0
    option = build_dataclass(DistributedOption, args)
    args.dist_rank = 1
    option2 = build_dataclass(DistributedOption, args)
    with ThreadPoolExecutor(max_workers=2) as e:
        fn = e.submit(_init, option)
        fn2 = e.submit(_init, option2)
        with pytest.raises(RuntimeError):
            fn.result()
            fn2.result()
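
The _init helper submitted to the executor above is also not part of this listing. A plausible minimal definition, assuming it does nothing more than run the (expected-to-fail) initialization inside the worker:

def _init(option):
    # Assumed helper: forward to DistributedOption.init() so the executor
    # captures any RuntimeError raised during initialization.
    option.init()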
Code example #5
File: gan_trainer.py  Project: sadhusamik/espnet
    def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
        """Build options consumed by train(), eval(), and plot_attention()."""
        assert check_argument_types()
        return build_dataclass(GANTrainerOptions, args)
Code example #6
    def main_worker(cls, args: argparse.Namespace):
        assert check_argument_types()

        # 0. Init distributed process
        distributed_option = build_dataclass(DistributedOption, args)
        # Setting distributed_option.dist_rank, etc.
        distributed_option.init_options()

        # NOTE(kamo): Don't use logging before invoking logging.basicConfig()
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            if not distributed_option.distributed:
                _rank = ""
            else:
                _rank = (f":{distributed_option.dist_rank}/"
                         f"{distributed_option.dist_world_size}")

            # NOTE(kamo):
            # logging.basicConfig() is invoked in main_worker() instead of main()
            # because it can be invoked only once in a process.
            # FIXME(kamo): Should we use logging.getLogger()?
            logging.basicConfig(
                level=args.log_level,
                format=f"[{os.uname()[1].split('.')[0]}{_rank}]"
                f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
            )
        else:
            # Suppress logging if RANK != 0
            logging.basicConfig(
                level="ERROR",
                format=f"[{os.uname()[1].split('.')[0]}"
                f":{distributed_option.dist_rank}/{distributed_option.dist_world_size}]"
                f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
            )
        # Invoking torch.distributed.init_process_group
        distributed_option.init_torch_distributed()

        # 1. Set random-seed
        set_all_random_seed(args.seed)
        torch.backends.cudnn.enabled = args.cudnn_enabled
        torch.backends.cudnn.benchmark = args.cudnn_benchmark
        torch.backends.cudnn.deterministic = args.cudnn_deterministic
        if args.detect_anomaly:
            logging.info("Invoking torch.autograd.set_detect_anomaly(True)")
            torch.autograd.set_detect_anomaly(args.detect_anomaly)

        # 2. Build model
        model = cls.build_model(args=args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        model = model.to(
            dtype=getattr(torch, args.train_dtype),
            device="cuda" if args.ngpu > 0 else "cpu",
        )
        for t in args.freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False

        # 3. Build optimizer
        optimizers = cls.build_optimizers(args, model=model)

        # 4. Build schedulers
        schedulers = []
        for i, optim in enumerate(optimizers, 1):
            suf = "" if i == 1 else str(i)
            name = getattr(args, f"scheduler{suf}")
            conf = getattr(args, f"scheduler{suf}_conf")
            if name is not None:
                # cls_ = scheduler_classes.get(name)
                cls_ = TriStageLR
                if cls_ is None:
                    raise ValueError(
                        f"must be one of {list(scheduler_classes)}: {name}")
                scheduler = cls_(optim, **conf)
            else:
                scheduler = None

            schedulers.append(scheduler)

        logging.info(pytorch_cudnn_version())
        logging.info(model_summary(model))
        for i, (o, s) in enumerate(zip(optimizers, schedulers), 1):
            suf = "" if i == 1 else str(i)
            logging.info(f"Optimizer{suf}:\n{o}")
            logging.info(f"Scheduler{suf}: {s}")

        # 5. Dump "args" to config.yaml
        # NOTE(kamo): "args" should be saved after object-buildings are done
        #  because they are allowed to modify "args".
        output_dir = Path(args.output_dir)
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            output_dir.mkdir(parents=True, exist_ok=True)
            with (output_dir / "config.yaml").open("w", encoding="utf-8") as f:
                logging.info(
                    f'Saving the configuration in {output_dir / "config.yaml"}'
                )
                yaml_no_alias_safe_dump(vars(args),
                                        f,
                                        indent=4,
                                        sort_keys=False)

        # 6. Load pre-trained models
        for p in args.init_param:
            logging.info(f"Loading pretrained params from {p}")
            load_pretrained_model(
                model=model,
                init_param=p,
                # NOTE(kamo): "cuda" for torch.load always indicates cuda:0
                #   in PyTorch<=1.4
                map_location=f"cuda:{torch.cuda.current_device()}"
                if args.ngpu > 0 else "cpu",
            )

        if args.dry_run:
            pass
        elif args.collect_stats:
            # Run in collect_stats mode. This mode has two roles:
            # - Derive the length and dimension of all input data
            # - Accumulate feats, square values, and the length for whitening
            logging.info(args)

            if args.valid_batch_size is None:
                args.valid_batch_size = args.batch_size

            if len(args.train_shape_file) != 0:
                train_key_file = args.train_shape_file[0]
            else:
                train_key_file = None
            if len(args.train_pseudo_shape_file) != 0:
                train_pseudo_key_file = args.train_pseudo_shape_file[0]
            else:
                train_pseudo_key_file = None
            if len(args.valid_shape_file) != 0:
                valid_key_file = args.valid_shape_file[0]
            else:
                valid_key_file = None

            collect_stats(
                model=model,
                train_iter=cls.build_streaming_iterator(
                    data_path_and_name_and_type=args.train_data_path_and_name_and_type,
                    key_file=train_key_file,
                    batch_size=args.batch_size,
                    dtype=args.train_dtype,
                    num_workers=args.num_workers,
                    allow_variable_data_keys=args.allow_variable_data_keys,
                    ngpu=args.ngpu,
                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
                    collate_fn=cls.build_collate_fn(args, train=False),
                ),
                train_pseudo_iter=cls.build_streaming_iterator(
                    data_path_and_name_and_type=args.train_pseudo_data_path_and_name_and_type,
                    key_file=train_pseudo_key_file,
                    batch_size=args.batch_size,
                    dtype=args.train_dtype,
                    num_workers=args.num_workers,
                    allow_variable_data_keys=args.allow_variable_data_keys,
                    ngpu=args.ngpu,
                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
                    collate_fn=cls.build_collate_fn(args, train=False),
                ),
                valid_iter=cls.build_streaming_iterator(
                    data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
                    key_file=valid_key_file,
                    batch_size=args.valid_batch_size,
                    dtype=args.train_dtype,
                    num_workers=args.num_workers,
                    allow_variable_data_keys=args.allow_variable_data_keys,
                    ngpu=args.ngpu,
                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
                    collate_fn=cls.build_collate_fn(args, train=False),
                ),
                output_dir=output_dir,
                ngpu=args.ngpu,
                log_interval=args.log_interval,
                write_collected_feats=args.write_collected_feats,
            )
        else:

            # 7. Build iterator factories
            assert not args.multiple_iterator
            train_iter_factory = cls.build_iter_factory(
                args=args,
                distributed_option=distributed_option,
                mode="train",
            )

            train_pseudo_iter_factory = cls.build_iter_factory(
                args=args,
                distributed_option=distributed_option,
                mode="pseudo",
            )

            valid_iter_factory = cls.build_iter_factory(
                args=args,
                distributed_option=distributed_option,
                mode="valid",
            )
            if args.num_att_plot != 0:
                plot_attention_iter_factory = cls.build_iter_factory(
                    args=args,
                    distributed_option=distributed_option,
                    mode="plot_att",
                )
            else:
                plot_attention_iter_factory = None

            # 8. Start training
            if args.use_wandb:
                if (not distributed_option.distributed
                        or distributed_option.dist_rank == 0):
                    if args.wandb_project is None:
                        project = ("ESPnet_" + cls.__name__ +
                                   str(Path(".").resolve()).replace("/", "_"))
                    else:
                        project = args.wandb_project
                    if args.wandb_id is None:
                        wandb_id = str(output_dir).replace("/", "_")
                    else:
                        wandb_id = args.wandb_id

                    wandb.init(
                        project=project,
                        dir=output_dir,
                        id=wandb_id,
                        resume="allow",
                    )
                    wandb.config.update(args)
                else:
                    # wandb also supports grouping for distributed training,
                    # but we only log aggregated data,
                    # so it is enough to do this on the rank-0 node.
                    args.use_wandb = False

            # Don't pass args to trainer.run() directly!
            # Instead, define an "Options" object and build it here.
            trainer_options = cls.trainer.build_options(args)
            cls.trainer.run(
                model=model,
                optimizers=optimizers,
                schedulers=schedulers,
                train_iter_factory=train_iter_factory,
                train_pseudo_iter_factory=train_pseudo_iter_factory,
                valid_iter_factory=valid_iter_factory,
                plot_attention_iter_factory=plot_attention_iter_factory,
                trainer_options=trainer_options,
                distributed_option=distributed_option,
            )
Code example #7
def test_build_dataclass_insufficient():
    args = Namespace(a="foo")
    with pytest.raises(ValueError):
        build_dataclass(A, args)
Code example #8
def test_build_dataclass():
    args = Namespace(a="foo", b="bar")
    a = build_dataclass(A, args)
    assert a.a == args.a
    assert a.b == args.b
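
The dataclass A used by the last two tests is not shown, and neither is build_dataclass itself. The following is a minimal sketch consistent with the behaviour these tests exercise (one field per attribute, values copied from the Namespace, ValueError when an attribute is missing); it illustrates the contract, not the library's actual implementation:

import argparse
import dataclasses


@dataclasses.dataclass
class A:
    a: str
    b: str


def build_dataclass(dataclass_type, args: argparse.Namespace):
    # Collect one keyword argument per dataclass field, failing loudly
    # when the parsed arguments do not provide it.
    kwargs = {}
    for field in dataclasses.fields(dataclass_type):
        if not hasattr(args, field.name):
            raise ValueError(f"args has no attribute '{field.name}'")
        kwargs[field.name] = getattr(args, field.name)
    return dataclass_type(**kwargs)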