# Imports assumed from the surrounding PyText module; exact paths may vary
# across PyText versions. run_single is assumed to be defined elsewhere in
# this module: it builds the task and runs the training loop for one worker.
import tempfile
from typing import List, Optional

import torch
from torch.multiprocessing import spawn

from pytext.config import PyTextConfig
from pytext.config.serialize import config_to_json
from pytext.metric_reporters.channel import Channel
from pytext.workflow import prepare_task_metadata


def train_model_distributed(config, metric_channels: Optional[List[Channel]]):
    """Train a model, spawning one worker process per GPU when CUDA is used;
    otherwise run a single worker in-process on CPU."""
    assert (
        config.use_cuda_if_available and torch.cuda.is_available()
    ) or config.distributed_world_size == 1, (
        "distributed training is only available for GPU training"
    )
    assert (
        config.distributed_world_size == 1
        or config.distributed_world_size <= torch.cuda.device_count()
    ), (
        f"Only {torch.cuda.device_count()} GPUs are available, "
        f"{config.distributed_world_size} GPUs were requested"
    )

    print(f"\n=== Starting training, World size is {config.distributed_world_size}")

    if not config.use_cuda_if_available or not torch.cuda.is_available():
        # CPU path: run a single worker in-process; no process group is needed.
        run_single(
            rank=0,
            config_json=config_to_json(PyTextConfig, config),
            world_size=1,
            dist_init_method=None,
            metadata=None,
            metric_channels=metric_channels,
        )
    else:
        # GPU path: workers rendezvous through a shared file (file:// init
        # method); delete=False keeps the sync file alive for the workers.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".dist_sync"
        ) as sync_file:
            dist_init_method = "file://" + sync_file.name
            # Build task metadata once in the parent so all workers share it.
            metadata = prepare_task_metadata(config)
            # spawn() prepends each worker's rank to these args when it
            # invokes run_single; metric channels are not forwarded to the
            # spawned workers (hence the empty list).
            spawn(
                run_single,
                (
                    config_to_json(PyTextConfig, config),
                    config.distributed_world_size,
                    dist_init_method,
                    metadata,
                    [],
                ),
                config.distributed_world_size,
            )
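
# A minimal usage sketch, not part of the module above. Assumptions are
# marked: "my_config.json" is a hypothetical config path, and parse_config
# (from pytext.config.serialize) is assumed to accept the loaded JSON dict;
# check your PyText version before relying on either.
if __name__ == "__main__":
    import json

    from pytext.config.serialize import parse_config  # assumed import path

    with open("my_config.json") as f:  # hypothetical config file
        config = parse_config(json.load(f))

    # An empty metric_channels list keeps only the task's default reporting.
    train_model_distributed(config, metric_channels=[])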