Example #1
def run_dist_training(rank_id: int, world_size: int, task: str,
                      task_cfg: CfgNode, parsed_args, model, dist_url):
    """method to run on distributed process
       passed to multiprocessing.spawn
    
    Parameters
    ----------
    rank_id : int
        rank id, ith spawned process 
    world_size : int
        total number of spawned process
    task : str
        task name (passed to builder)
    task_cfg : CfgNode
        task builder (passed to builder)
    parsed_args : [type]
        parsed arguments from command line
    """
    devs = ["cuda:{}".format(rank_id)]
    # set up distributed
    setup(rank_id, world_size, dist_url)
    dist_utils.synchronize()
    # move model to device before building optimizer.
    # quick fix for resuming of DDP
    # TODO: need to be refined in future
    model.set_device(devs[0])
    # build optimizer
    optimizer = optim_builder.build(task, task_cfg.optim, model)
    # build dataloader with trainer
    with Timer(name="Dataloader building", verbose=True):
        dataloader = dataloader_builder.build(task, task_cfg.data, seed=rank_id)
    # build trainer
    trainer = engine_builder.build(task, task_cfg.trainer, "trainer", optimizer,
                                   dataloader)
    # needs to be called after the optimizer is built (potential pytorch issue)
    trainer.set_device(devs)
    trainer.resume(parsed_args.resume)
    # trainer.init_train()
    logger.info("Start training")
    while not trainer.is_completed():
        trainer.train()
        if rank_id == 0:
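            # only the rank-0 process saves snapshots, so workers do not write the same file concurrently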
            trainer.save_snapshot()
        dist_utils.synchronize()  # one synchronization per epoch

    # clean up distributed
    cleanup()
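The helpers setup and cleanup called above are not shown on this page. Below is a minimal sketch of what they and the launching code could look like using the standard torch.distributed / torch.multiprocessing APIs; the "nccl" backend choice and the launch_distributed wrapper are assumptions for illustration, not the repository's actual implementation.

# Minimal sketch (assumed, not the repository's actual code) of the distributed
# plumbing used by run_dist_training above.
import torch.distributed as dist
import torch.multiprocessing as mp


def setup(rank_id: int, world_size: int, dist_url: str):
    # join the process group; "nccl" is the usual backend for multi-GPU training
    dist.init_process_group(backend="nccl",
                            init_method=dist_url,
                            rank=rank_id,
                            world_size=world_size)


def cleanup():
    # tear down the process group before the worker process exits
    dist.destroy_process_group()


def launch_distributed(world_size, task, task_cfg, parsed_args, model, dist_url):
    # one worker per GPU; mp.spawn passes the rank as the first argument,
    # matching run_dist_training's signature
    mp.spawn(run_dist_training,
             args=(world_size, task, task_cfg, parsed_args, model, dist_url),
             nprocs=world_size,
             join=True)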
Example #2
    logger.info("Task configuration backed up at %s" % cfg_bak_file)
    # device config
    if task_cfg.device == "cuda":
        world_size = task_cfg.num_processes
        assert torch.cuda.is_available(), "please check your devices"
        assert torch.cuda.device_count() >= world_size, \
            "cuda device count {} is less than world size {}".format(
                torch.cuda.device_count(), world_size)
        devs = ["cuda:{}".format(i) for i in range(world_size)]
    else:
        devs = ["cpu"]
    # build model
    model = model_builder.build(task, task_cfg.model)
    model.set_device(devs[0])
    # load data
    with Timer(name="Dataloader building", verbose=True):
        dataloader = dataloader_builder.build(task, task_cfg.data)
    # build optimizer
    optimizer = optim_builder.build(task, task_cfg.optim, model)
    # build trainer
    trainer = engine_builder.build(task, task_cfg.trainer, "trainer", optimizer,
                                   dataloader)
    trainer.set_device(devs)
    trainer.resume(parsed_args.resume)
    # trainer.init_train()
    logger.info("Start training")
    while not trainer.is_completed():
        trainer.train()
        trainer.save_snapshot()
    logger.info("Training completed.")