def _start_reduction_threads(self):
    """Spawn one daemon reduction thread per gradient bucket.

    Builds, per bucket, a queue for handing off work and a per-device list of
    CUDA streams, then starts a thread running ``self._reduction_thread_fn``.
    Also records the current (default) stream and a fresh NCCL stream for each
    device in ``self.device_ids``.
    """
    num_buckets = len(self.bucket_sizes)
    # One hand-off queue and one stream list per bucket.
    self._reduction_queues = [queue.Queue() for _ in range(num_buckets)]
    self._reduction_threads = []
    self._reduction_streams = [[] for _ in range(num_buckets)]
    self._nccl_streams = []
    self._default_streams = []
    for dev_id in self.device_ids:
        with torch.cuda.device(dev_id):
            # TODO: don't assume we're on a default stream
            self._default_streams.append(torch.cuda.current_stream())
            self._nccl_streams.append(torch.cuda.Stream())
    for reduction_queue, reduction_streams in zip(self._reduction_queues, self._reduction_streams):
        for dev_id in self.device_ids:
            with torch.cuda.device(dev_id):
                reduction_streams.append(torch.cuda.Stream())
        # We only use the first device for distributed reductions
        dist._register_stream(reduction_streams[0])
        # NCCL reuses the world group; other backends get a dedicated group
        # so concurrent reductions do not interleave.
        if dist._backend == dist.dist_backend.NCCL:
            group_id = dist.group.WORLD
        else:
            group_id = dist.new_group()
        self._reduction_threads.append(threading.Thread(
            target=self._reduction_thread_fn,
            args=(reduction_queue, group_id, self.device_ids, reduction_streams, self._nccl_streams)))
        # Daemon threads so a hung reduction cannot block interpreter exit.
        self._reduction_threads[-1].daemon = True
        self._reduction_threads[-1].start()
def _init_group_test(self):
    """Create the fixed [1, 2] subgroup used by the group tests.

    Every rank must call new_group, so the group is created before the
    membership check. Ranks outside the group get ([], None, rank).
    """
    member_ranks = [1, 2]
    group_handle = dist.new_group(member_ranks)
    my_rank = dist.get_rank()
    if my_rank in member_ranks:
        return (member_ranks, group_handle, my_rank)
    return ([], None, my_rank)
def _remote_worker_process(self, ddp_mode):
    """Run a remote worker: join the gloo group, participate in trainer
    group creation when DDP is active, then block until shutdown is signalled."""
    gLogger.info("The remote worker is running.")
    init_kwargs = dict(
        backend="gloo",
        init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
        world_size=self.world_size,
        rank=self.rank,
    )
    dist.init_process_group(**init_kwargs)

    if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
        # new_group needs to be called on ranks.
        dist.new_group(TRAINER_RANKS)

    global shutdown_signal
    with shutdown_signal:
        shutdown_signal.wait()

    gLogger.info("Exiting remote worker.")
    dist.destroy_process_group()
def initTorchDist(self):
    """Initialise the NCCL process group and build the worker subgroup
    (every rank except rank 0)."""
    self.printToLog("dist args:", 'nccl', self.dist_addr, self.rank, self.world_size)
    dist.init_process_group('nccl', rank=self.rank, world_size=self.world_size)
    # dist.init_process_group('nccl', init_method=self.dist_addr,
    #                         rank=self.rank, world_size=self.world_size)
    worker_ranks = list(range(1, self.world_size))
    self.worker_group = dist.new_group(worker_ranks)
def run_all_reduce(rank, size):
    """Sum a ones-tensor over ranks 0 and 1 with all_reduce and print the result."""
    pair_group = dist.new_group([0, 1])
    payload = torch.ones(1)
    dist.all_reduce(payload, op=dist.ReduceOp.SUM, group=pair_group)
    print("rank ", rank, "has data ", payload[0])
def _get_torch_dist_group(ranks):
    """Return a cached torch.distributed process group for ``ranks``.

    Creates the group on first request (under ``_TORCH_DIST_LOCK``) and
    memoizes it in ``_TORCH_DIST_GROUPS`` keyed by the ranks tuple.
    """
    import torch.distributed as dist
    with _TORCH_DIST_LOCK:
        pg = _TORCH_DIST_GROUPS.get(ranks)
        # Fix: test for absence with `is None` rather than `if not pg:` —
        # a cached ProcessGroup must not be recreated just because it
        # happens to evaluate as falsy.
        if pg is None:
            pg = dist.new_group(ranks=ranks)
            _TORCH_DIST_GROUPS[ranks] = pg
        return pg
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.
    """
    backend = dist.get_backend()
    if backend == "nccl":
        # NCCL cannot move CPU tensors; build a parallel gloo group.
        return dist.new_group(backend="gloo")
    return dist.group.WORLD
def __init__(self, local_master_rank, dest, src, local_rank):
    """Create the point-to-point group between ``src`` and ``dest`` and warm
    it up with one full-precision and one half-precision all_reduce (only on
    the endpoints' local master processes)."""
    self.src = src
    self.dest = dest
    self.process_group = dist.new_group([src, dest])
    is_endpoint_master = (local_master_rank in (self.src, self.dest)
                          and local_rank == 0)
    if is_endpoint_master:
        warmup_fp32 = torch.Tensor([1]).cuda()
        dist.all_reduce(warmup_fp32, group=self.process_group)
        warmup_fp16 = torch.Tensor([1]).cuda().half()
        dist.all_reduce(warmup_fp16, group=self.process_group)
def simple_group_split(world_size, rank, num_groups):
    """Split all ranks into ``num_groups`` equal contiguous chunks and return
    the process-group handle of the chunk containing ``rank``.

    Every rank must call new_group for every chunk, so all groups are built
    before the local one is selected.
    """
    chunks = np.split(np.arange(world_size), num_groups)
    rank_list = [[int(r) for r in chunk] for chunk in chunks]
    groups = [dist.new_group(ranks) for ranks in rank_list]
    group_size = world_size // num_groups
    my_chunk = rank // group_size
    print("Rank no.{} start sync BN on the process group of {}".format(rank, rank_list[my_chunk]))
    return groups[my_chunk]
def rendezvous(self, world_size):
    # Join the TCP rendezvous and return a new group over `world_size` ranks.
    # NOTE(review): init_process_group hard-codes world_size=2 (and
    # self.peers assumes ranks [0, 1]) while the `world_size` argument is
    # only used for the new_group below — confirm this asymmetry is
    # intentional before reusing with world_size != 2.
    dist.init_process_group(self.backend,
                            rank=self.rank,
                            timeout=datetime.timedelta(seconds=10),
                            world_size=2,
                            init_method='tcp://{}:60000'.format(os.environ['MASTER_ADDR']))
    # Peers are every rank in {0, 1} other than this one.
    self.peers = list(filter(lambda x: x != self.rank, [0, 1]))
    return dist.new_group(range(world_size))
def consistent_indices(self, indices, shuffle):
    """Optionally shuffle ``indices`` on rank 0, then broadcast so every rank
    ends up with the same ordering."""
    if shuffle and self.args.graph.rank == 0:
        random.shuffle(indices)
    # broadcast.
    index_tensor = torch.IntTensor(indices)
    sync_group = dist.new_group(self.args.graph.ranks)
    dist.broadcast(index_tensor, src=0, group=sync_group)
    return list(index_tensor)
def run_DSGD_my(rank, size, lr, epoches, q=16):
    """ Distributed Synchronous SGD Example """
    # Seed all ranks identically so model initialisation agrees.
    torch.manual_seed(1234)
    device = torch.device("cuda:" + str(rank) if torch.cuda.is_available() else "cpu")
    # One ring-neighbourhood group per rank: {i-1, i, i+1} modulo size.
    group_list = []
    for i in range(size):
        group_list.append(
            dist.new_group([i, (i + 1) % size, ((i - 1) + size) % size]))
    train_set_large, bsz_large = partition_dataset_cifar_large()
    train_set, bsz = partition_dataset_cifar()
    # train_set, bsz = class_partition_cifar()
    if args.model_load:
        model = torch.load(args.DSGD_model_path)
        model.to(device)
    else:
        model = Net(ResidualBlock).to(device)
    # Shadow copy used to evaluate the ring-averaged model's loss.
    model_loss = copy.deepcopy(model)
    # optimizer = my_SGD.my_SGD(model.parameters(), lr=lr)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    model.train()
    num_batches = ceil(len(train_set.dataset) / float(bsz))
    loss_iteration = []
    for epoch in range(epoches):
        epoch_loss = 0.0
        # Re-partition each epoch; the large loader is consumed lazily.
        train_set_large, bsz_large = partition_dataset_cifar_large()
        train_set_large = enumerate(train_set_large)
        train_set, bsz = partition_dataset_cifar()
        # train_set, bsz = class_partition_cifar()
        index = 0
        for data, target in train_set:
            data, target = Variable(data.to(device)), Variable(target.to(device))
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            # Barriers keep the ring reductions in lock-step across ranks.
            dist.barrier()
            dist_sgd_size8(model, rank, group_list)
            optimizer.step()
            dist.barrier()
            # calculate the loss in the average model
            if index % q == 0:
                avg_model(model, model_loss)
                output_loss = loss_compute(model_loss, train_set_large)
                loss_iteration.append(output_loss.item())
            index += 1
        print('Rank ', dist.get_rank(), ', epoch ', epoch, ': ', epoch_loss / num_batches)
    file_name = args.DSGD_file_name
    print_loss_file(file_name, loss_iteration, rank, size)
    avg_model(model, model_loss)
    if rank == 0:
        PATH = args.DSGD_model_path_save
        torch.save(model, PATH)
def test_new_group(self):
    # Shard dim 0 over two cuda placements inside a [1, 2, 3] subgroup and
    # check local shards, global metadata, and remote-shard access.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:1/cuda:2",
            "rank:2/cuda:3",
        ],
    )
    pg = dist.new_group(ranks=[1, 2, 3])
    if self.rank >= 1:
        sharded_tensor = _sharded_tensor.empty(spec, 10, 20, process_group=pg)
        # Validate local shard.
        local_shards = sharded_tensor.local_shards()
        if self.rank >= 2:
            # Ranks 2 and 3 each own one 5x20 shard on their own GPU.
            self.assertEqual(1, len(local_shards))
            local_shard = local_shards[0].tensor
            self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
            self.assertEqual((5, 20), local_shard.size())
        else:
            self.assertEqual(0, len(local_shards))
        # Validate global metadata.
        sharded_tensor_metadata = sharded_tensor.metadata()
        shards_metadata = sharded_tensor_metadata.shards_metadata
        self.assertEqual(2, len(shards_metadata))
        for shard_rank, shard_metadata in enumerate(shards_metadata):
            self.assertEqual([shard_rank * 5, 0], shard_metadata.shard_offsets)
            self.assertEqual([5, 20], shard_metadata.shard_lengths)
            # Placement ranks are offset by one relative to shard index.
            self.assertEqual(
                f'rank:{shard_rank + 1}/cuda:{shard_rank + 2}',
                shard_metadata.placement)
        # Validate remote shards.
        remote_shards = sharded_tensor.remote_shards
        if self.rank >= 2:
            self.assertEqual(1, len(remote_shards))
        else:
            self.assertEqual(2, len(remote_shards))
        for rpc_rank, shards in remote_shards.items():
            self.assertEqual(1, len(shards))
            for remote_shard in shards:
                shard = remote_shard.to_here()
                self.assertEqual(rpc_rank, remote_shard.owner().id)
                self.assertEqual(f'rank:{rpc_rank - 1}/cuda:{rpc_rank}',
                                 shard.metadata.placement)
                self.assertEqual((5, 20), shard.tensor.size())
def __init__(self, mpu=None):
    """Choose the communication group: the full world when no model-parallel
    unit is given, otherwise the mpu's data-parallel group. Caches rank/size
    for that group and sets up the cupy compression backend."""
    if mpu is None:
        world_ranks = range(dist.get_world_size())
        self.world_group = dist.new_group(ranks=world_ranks)
    else:
        # Only set self.mpu when one was actually provided.
        self.mpu = mpu
        self.world_group = self.mpu.get_data_parallel_group()
    self.rank = dist.get_rank(group=self.world_group)
    self.size = dist.get_world_size(group=self.world_group)
    self.compression_backend = CupyBackend()
def set_online_clients(args):
    # Define online clients for the current round of communication for Federated Learning setting
    candidate_ranks = args.graph.ranks
    shuffled = np.random.permutation(candidate_ranks)
    num_online = int(args.online_client_rate * len(candidate_ranks))
    online_clients = torch.IntTensor(shuffled[:num_online])
    sync_group = dist.new_group(args.graph.ranks)
    # Rank 0's draw wins: the broadcast overwrites every other rank's selection.
    dist.broadcast(online_clients, src=0, group=sync_group)
    return list(online_clients.numpy())
def master_run(size):
    """Parameter-server master loop: forever broadcast the current model to
    all workers, then average the workers' summed updates back into it.

    The master contributes a zero tensor to each reduce, so dividing by
    ``size - 1`` averages over the workers only. Runs until the process is
    terminated externally.
    """
    modell = model.CNN()
    group = dist.new_group(range(size))
    while True:
        # Push current parameters to every worker.
        for param in modell.parameters():
            dist.broadcast(param.data, src=0, group=group)
        # Gather and average the workers' parameters.
        for param in modell.parameters():
            tensor_temp = torch.zeros_like(param.data)
            # Fix: dist.reduce_op is a deprecated alias (removed in recent
            # PyTorch releases); use dist.ReduceOp, consistent with the rest
            # of this file.
            dist.reduce(tensor_temp, dst=0, op=dist.ReduceOp.SUM, group=group)
            param.data = tensor_temp / (size - 1)
def distributed_worker(local_rank, main_func, nprocs, dist_url, args):
    """Join the gloo process group, register the local process group with the
    comm module, and hand control to ``main_func``."""
    dist.init_process_group(backend="gloo",
                            init_method=dist_url,
                            world_size=nprocs,
                            rank=local_rank)
    comm.synchronize()
    assert comm._LOCAL_PROCESS_GROUP is None
    local_pg = dist.new_group(list(range(nprocs)))
    comm._LOCAL_PROCESS_GROUP = local_pg
    main_func(*args)
def init_groups():
    """ Creating communication groups at servers and workers. Since init groups should be executed the same way for all machines in the group, all machines will store groups of itself and other machines as well. """
    # Rebinding all_ps requires the global declaration; the dict updates to
    # all_workers_ps / all_ps_worker are item assignments and would work
    # without one, but all_workers_ps is declared for clarity.
    global all_ps
    global all_workers_ps
    if fps > 0 or mar == 'crash':
        #Now we need to tolerate Byzantine PS -> Do GuanYu (for now)
        #Create groups of communication
        all_ps = dist.new_group([i for i in range(num_ps)])
        #Creating groups of all workers with each of the PS....useful in collecting gradients from workers on the PS side
        for ps in range(num_ps):
            # Group = all workers plus this one PS (last slot replaced).
            g = [i+num_ps for i in range(num_workers+1)]
            g[-1] = ps
            all_workers_ps[ps] = dist.new_group(g)
        #Create groups of all PS with one worker...useful in collecting aggregated gradients by workers in GuanYu
        for worker in range(num_workers):
            # Group = all PS plus this one worker (last slot replaced).
            g = [i for i in range(num_ps+1)]
            g[-1] = worker+num_ps
            all_ps_worker[worker] = dist.new_group(g)
def __init__(self, *args, **kwargs):
    """Pre-create one pairwise gossip group for every potential partner rank.

    Partner for offset k is (rank + k) mod world_size, k = 1..world_size-1,
    so self.groups[i] pairs this rank with its (i+1)-th neighbour.
    """
    super(AsynchronousDistributedTraining, self).__init__(*args, **kwargs)
    self.gossip_step = 0
    self.world_size = distributed.get_world_size()
    self.rank = distributed.get_rank()
    self.groups = []
    for offset in range(1, self.world_size):
        partner = (self.rank + offset) % self.world_size
        self.groups.append(distributed.new_group(ranks=[self.rank, partner]))
def test_new_group_invalid_ranks(self):
    """new_group must raise ValueError for a rank list that does not divide
    the world evenly."""
    set_world_size(12)
    # unevenly distributed
    ranks = [1, 5, 10]
    world_rank = 5
    set_world_rank(world_rank)
    with new_group_barrier_disabled(), self.assertRaises(ValueError):
        pg = dist.new_group(ranks=ranks)
def _create_model_parallel_group(self):
    """Initialise distributed state, pin this process to its local GPU, and
    build the model-parallel process group over ranks [0, mp_world_size)."""
    # Call the init process
    init_distributed()
    local_rank = int(os.getenv('LOCAL_RANK', '0'))
    torch.cuda.set_device(local_rank)
    mp_ranks = list(range(self.mp_world_size))
    self.mp_group = dist.new_group(mp_ranks)
def _index_tied_modules(self):
    ''' Build communication structures for tied modules. '''
    tied_comms = {}
    # Nothing to tie when there is only one pipeline stage.
    if self._topo.get_dim('pipe') == 1:
        return tied_comms
    specs = self._layer_specs
    tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec))
    for key in tie_keys:
        # Find the layers that the tied module appears in
        tied_layers = []
        for idx, layer in enumerate(specs):
            if isinstance(layer, TiedLayerSpec) and layer.key == key:
                tied_layers.append(idx)
        # Find all stages with this tied module
        # TODO: Would be nice to remove the nested data/model parallelism loops and
        # TODO: instead generalize in some way, since we really just care about the
        # TODO: stage that owns the tied layer. Then loop over each (dp, mp, ...)
        # TODO: fiber to generate process groups.
        tied_stages = set(self.stage_owner(idx) for idx in tied_layers)
        for dp in range(self._grid.data_parallel_size):
            for mp in range(self._grid.model_parallel_size):
                tied_ranks = []
                for s in sorted(tied_stages):
                    if self._grid.model_parallel_size > 1:
                        tied_ranks.append(
                            self._grid.stage_to_global(stage_id=s, data=dp, model=mp))
                    else:
                        tied_ranks.append(
                            self._grid.stage_to_global(stage_id=s, data=dp))
                # Every rank executes this new_group call for every fiber,
                # as required by torch.distributed.
                group = dist.new_group(ranks=tied_ranks)
                # Record this tied module if we own a local copy of it.
                if self.global_rank in tied_ranks:
                    # NOTE(review): the assert makes the following `if`
                    # redundant — it can only be reached when True.
                    assert key in self.tied_modules
                    if key in self.tied_modules:
                        tied_comms[key] = {
                            'ranks': tied_ranks,
                            'group': group,
                            'weight_attr': self.tied_weight_attrs[key],
                            'module': self.tied_modules[key],
                        }
                        # Only count the tied module once in the eyes of the FP16 optimizer
                        if self.global_rank != tied_ranks[0]:
                            for p in self.tied_modules[key].parameters():
                                p.model_parallel = False
    '''
    if len(tied_comms) > 0:
        print(f'RANK={self.global_rank} tied_comms={tied_comms}')
    '''
    return tied_comms
def setup_distributed(self):
    # Initialise NCCL and derive team/rank bookkeeping. With distillation,
    # the world is split into two teams of equal size plus cross-team
    # "model comm" groups; otherwise everyone is one team.
    dist.init_process_group(backend='nccl', init_method='env://')
    self.rank = dist.get_rank()
    self.size = dist.get_world_size()
    self.gpu = self.options.local_rank
    torch.cuda.set_device(self.gpu)
    if self.options.distillation:
        assert self.size % 2 == 0
        self.team_size = self.size // 2
        self.team_rank = self.rank % self.team_size
        self.team = self.rank // self.team_size
        team_groups = []
        model_comm_groups = []
        # Cross-team groups: each team leader plus the whole other team.
        self.team_ranks = [
            [0] + list(range(self.team_size, self.team_size * 2)),
            [self.team_size] + list(range(0, self.team_size))
        ]
        for i in range(2):
            # NOTE(review): this rank list depends on self.team, so the two
            # teams pass different rank lists to the same new_group call —
            # torch.distributed requires identical arguments on all ranks;
            # confirm this works as intended with the backend in use.
            team_groups.append(
                dist.new_group(ranks=list(
                    range(self.team * self.team_size,
                          (self.team + 1) * self.team_size))))
            model_comm_groups.append(
                dist.new_group(ranks=self.team_ranks[i]))
        self.team_group = team_groups[self.team]
        self.team_leader = self.team * self.team_size
        self.model_comm_groups = model_comm_groups
        if self.options.equalize_data:
            # Pair rank i of team 0 with rank i of team 1; keep only the pair
            # whose index matches this process' team_rank.
            for i in range(self.size // 2):
                equalize_distillation_group = dist.new_group(
                    ranks=[i, i + self.size // 2])
                if i == self.team_rank:
                    self.equalize_distillation_group = equalize_distillation_group
    else:
        self.team_size = self.size
        self.team_rank = self.rank
        self.team = 0
        self.team_group = dist.new_group(ranks=list(range(self.size)))
        self.team_leader = 0
def get_global_group():
    """Return the lazily-created singleton group spanning all ranks, or None
    when distributed is not initialised (XLA gets its own group type)."""
    if use_xla():
        return new_groups([list(range(get_global_world_size()))])
    if not torch.distributed.is_initialized():
        return None
    if not hasattr(get_global_group, "_global_group"):
        # ideally we could use torch.distributed.group.WORLD, but it seems
        # to cause random NCCL hangs in some cases
        get_global_group._global_group = dist.new_group()
    return get_global_group._global_group
def __init__(self, rank_list, logger=print):
    """Set up hypercube communication state for this rank.

    Creates the default group over ``rank_list`` and, for each hypercube
    dimension (turn), a pair group with this rank's partner — or a singleton
    group when the partner falls outside the group.

    :param rank_list: ranks participating in the hypercube
    :param logger: logging callable (defaults to print)
    """
    self.printToLog = logger
    self.rank = dist.get_rank()
    self.rank_list = rank_list
    self.default_group = dist.new_group(rank_list)
    self.cube_phase = 0
    self.def_group_size = len(rank_list)
    self.cube_dim = int(np.log2(self.def_group_size))
    self.partial_groups = []
    for turn in range(self.cube_dim):
        partner = cube_correspond(self.rank, turn)
        if partner < self.def_group_size:
            # Fix: the original called list.sort() inline, which returns
            # None, so dist.new_group(None) was created instead of the
            # intended ordered pair group. sorted() returns the pair.
            self.partial_groups.append(
                dist.new_group(sorted([self.rank, partner])))
        else:
            self.partial_groups.append(dist.new_group([self.rank]))
def get_global_group():
    """
    Singleton pytorch distributed group
    Inspired by https://github.com/pytorch/fairseq
    """
    if not dist.is_initialized():
        return None
    if not hasattr(get_global_group, "_global_group"):
        # Create once and memoize on the function object.
        get_global_group._global_group = dist.new_group()
    return get_global_group._global_group
def test_invalid_sharding(self):
    # Exercise every validation path of _sharded_tensor.empty: bad process
    # group membership, bad dim, bad rank, unsupported torch functions /
    # layouts / memory formats, and RPC-not-initialized errors.
    self.init_pg()
    spec = ChunkShardingSpec(dim=0, placements=["rank:1/cuda:1"])
    pg = dist.new_group(ranks=[2, 3])
    if self.rank < 2:
        # Ranks outside the [2, 3] group must be rejected.
        with self.assertRaisesRegex(ValueError, 'not part of process group'):
            _sharded_tensor.empty(spec, 10, 20, process_group=pg)
    spec = ChunkShardingSpec(dim='H', placements=["rank:1/cuda:1"])
    with self.assertRaisesRegex(ValueError, 'needs to be an integer'):
        _sharded_tensor.empty(spec, 10, 20)
    # Any dim outside [-2, 1] is invalid for a 2-D tensor.
    for dim in [2, 3, 4, -3, -4, -5]:
        spec = ChunkShardingSpec(dim=dim, placements=["rank:1/cuda:1"])
        with self.assertRaisesRegex(ValueError, 'Invalid sharding dim'):
            _sharded_tensor.empty(spec, 10, 20)
    spec = ChunkShardingSpec(dim=0, placements=["rank:5/cuda:1"])
    with self.assertRaisesRegex(ValueError, 'Invalid rank'):
        _sharded_tensor.empty(spec, 10, 20)
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
    sharded_tensor = _sharded_tensor.empty(spec, 10, 20)
    tensor = torch.empty(10, 20)
    with self.assertRaisesRegex(RuntimeError, "torch function 'add' not supported for ShardedTensor!"):
        torch.add(sharded_tensor, tensor)
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
    with self.assertRaisesRegex(ValueError, 'Only torch.strided layout is currently supported'):
        _sharded_tensor.empty(spec, 10, 20, layout=torch.sparse)
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
    with self.assertRaisesRegex(ValueError, 'Only torch.contiguous_format memory_format is currently supported'):
        _sharded_tensor.empty(spec, 10, 20, memory_format=torch.channels_last)
    spec = ChunkShardingSpec(dim=0, placements=["worker0/cuda:1"])
    with self.assertRaisesRegex(RuntimeError, 'RPC framework needs to be initialized'):
        _sharded_tensor.empty(spec, 10, 20)
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
    with self.assertRaisesRegex(RuntimeError, 'RPC was not initialized'):
        st = _sharded_tensor.empty(spec, 10, 20)
        st.remote_shards
    self.init_rpc()
    # ShardedTensor was initialized before RPC.
    with self.assertRaisesRegex(RuntimeError, 'RPC was not initialized'):
        st.remote_shards
    spec = ChunkShardingSpec(dim=0, placements=["workerfoo/cuda:1"])
    with self.assertRaisesRegex(ValueError, 'Invalid worker name'):
        _sharded_tensor.empty(spec, 10, 20)
def test_new_group(self):
    # Two enumerated 5x5 shards placed on ranks 0 and 2, created inside a
    # [1, 2, 3] subgroup; verify local shards, global metadata, and remote
    # shard access from each member.
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:2/cuda:3",
        ),
    ])
    pg = dist.new_group(ranks=[1, 2, 3])
    if self.rank >= 1:
        sharded_tensor = _sharded_tensor.empty(spec, 10, 5, process_group=pg)
        self.assertEqual((10, 5), sharded_tensor.size())
        if self.rank == 1 or self.rank == 3:
            # Verify local shard.
            local_shard = sharded_tensor.local_shards()[0]
            self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
            self.assertEqual((5, 5), local_shard.tensor.size())
            # Verify local shard metadata.
            self.assertEqual((self.rank // 2 * 5, 0), local_shard.metadata.shard_offsets)
            self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
            # Spec ranks are relative to the subgroup (global rank - 1).
            self.assertEqual(f'rank:{self.rank - 1}/cuda:{self.rank}', local_shard.metadata.placement)
        # Verify global metadata.
        sharded_tensor_metadata = sharded_tensor.metadata()
        shards_metadata = sharded_tensor_metadata.shards_metadata
        self.assertEqual(2, len(shards_metadata))
        for rank, shard_metadata in enumerate(shards_metadata):
            self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
            self.assertEqual((5, 5), shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{rank * 2}/cuda:{rank * 2 + 1}', shard_metadata.placement)
        # Validate remote shards.
        remote_shards = sharded_tensor.remote_shards
        if self.rank == 1 or self.rank == 3:
            # Shard owners see only the other owner's shard.
            self.assertEqual(1, len(remote_shards))
        else:
            self.assertEqual(2, len(remote_shards))
        owners = {}
        for rpc_rank, shards in remote_shards.items():
            self.assertEqual(1, len(shards))
            for remote_shard in shards:
                self.assertEqual(rpc_rank, remote_shard.owner().id)
                shard = remote_shard.to_here()
                self.assertEqual((5, 5), shard.tensor.size())
def __init__(self, rank, world_size, state, model_tag='', callback=None, all_workers=False):
    """
    Constructor for ClusterManager()
    :param rank: The agent's rank (unique Id)
    :param world_size: Number of agents used for training
    :param state: Dictionary used to encode training state
    :param model_tag: Tag used in the name of the checkpoint file
    :param callback: function to execute when SIGTERM is received
    :param all_workers: Whether to save all workers' models in checkpoints
    """
    assert ClusterManager.CHECKPOINT_DIR is not None
    self.rank = rank
    self.world_size = world_size
    self.state = state
    self.all_workers = all_workers
    self.main_pid = os.getpid()
    # Tensor used to propagate the termination signal across ranks.
    self.signal_tensor = torch.zeros(1)
    if torch.cuda.is_available():
        self.signal_tensor = self.signal_tensor.cuda()
    self.signal_handlers_installed = False
    self.logger = make_logger(rank)
    # Fix: the constructor accepted `callback` but discarded it with
    # `self.callback = None`, so the documented SIGTERM callback never ran.
    self.callback = callback
    # Checkpoints are named after this rank only when saving all workers;
    # otherwise everyone shares the master's checkpoint name.
    if all_workers:
        model_rank = rank
    else:
        model_rank = ClusterManager.MASTER_RANK
    self.model_tag = model_tag
    self.checkpoint_fname = 'checkpoint_r' + str(model_rank) + \
        '_n' + str(world_size) + \
        '.pth.tar'
    self.model_best_fname = 'model_best_r' + str(model_rank) + \
        '_n' + str(world_size) + \
        '.pth.tar'
    self.checkpoint_fpath = ClusterManager.CHECKPOINT_DIR \
        + self.model_tag + self.checkpoint_fname
    self.model_best_fpath = ClusterManager.CHECKPOINT_DIR \
        + self.model_tag + self.model_best_fname
    self.install_signal_handlers()
    if self.world_size > 1:
        assert dist.is_initialized()
        self.process_group = dist.new_group(list(range(self.world_size)))
def get_proc_groups(rank, size, replication):
    """Create the row and column process groups for a replicated rank layout.

    Ranks are arranged as size/replication contiguous rows of `replication`
    ranks; each column takes every `replication`-th rank. Returns
    (row_groups, col_groups), lists of process-group handles.

    Note: `rank` is not used (the original computed an unused
    rank // replication, removed here); the parameter is kept for caller
    compatibility.
    """
    row_procs = [list(range(i, i + replication))
                 for i in range(0, size, replication)]
    col_procs = [list(range(i, size, replication))
                 for i in range(replication)]
    # Every rank must create every group, as torch.distributed requires.
    row_groups = [dist.new_group(procs) for procs in row_procs]
    col_groups = [dist.new_group(procs) for procs in col_procs]
    return row_groups, col_groups
def init_groups(size):
    """ Initialization of all distributed groups for the whole training process. We do this in advance so as not to hurt the performance of training. The server initializes the group and send it to all workers so that everybody can agree on the working group at some round. Args size The total number of machines in the current setup """
    global all_groups
    all_groups = []
    # One pair group (server rank 0, worker w) per worker.
    for worker in range(1, size):
        all_groups.append(dist.new_group([0, worker]))
def _test_sequence_num_incremented_subgroup(self, backend_name):
    """Verify sequence numbers advance correctly on a [0, 1, 2] subgroup."""
    torch.cuda.set_device(self.rank)
    store = dist.FileStore(self.file_name, self.world_size)
    dist.init_process_group(backend_name,
                            world_size=self.world_size,
                            rank=self.rank,
                            store=store)
    subgroup_ranks = [0, 1, 2]
    subgroup = dist.new_group(subgroup_ranks)
    self._test_sequence_num_incremented(subgroup, subgroup_ranks)
def _register_nccl_grad_hook(self):
    """
    This function registers the callback all-reduction function for the
    NCCL backend. All gradients will be all reduced in one single step.
    The NCCL reduction will directly be enqueued into the default CUDA
    stream. Therefore, no synchronization is needed.
    """
    # Creating a new group
    self.nccl_reduction_group_id = dist.new_group()

    def reduction_fn_nccl():
        # This function only needs to be called once
        if not self.need_reduction:
            return
        self.need_reduction = False
        all_grads = [[] for _ in range(len(self._module_copies))]
        all_grads_buckets_iters = []
        # Bucketing all the gradients
        for dev_idx, module in enumerate(self._module_copies):
            for param in module.parameters():
                if not param.requires_grad or param.grad is None:
                    continue
                if param.grad.requires_grad:
                    raise RuntimeError("DistributedDataParallel only works "
                                       "with gradients that don't require "
                                       "grad")
                # Adding the gradients for reduction
                all_grads[dev_idx].append(param.grad.data)
            # Now bucketing the parameters
            dev_grads_buckets = _take_tensors(all_grads[dev_idx],
                                              self.nccl_reduce_bucket_size)
            all_grads_buckets_iters.append(dev_grads_buckets)
        # Now reduce each bucket one after another
        for grads_batch in zip(*all_grads_buckets_iters):
            grads_batch_coalesced = []
            # Coalesce each bucket
            for dev_idx, dev_grads_batch in enumerate(grads_batch):
                dev_id = self.device_ids[dev_idx]
                with torch.cuda.device(dev_id):
                    dev_grads_batch_coalesced = _flatten_dense_tensors(dev_grads_batch)
                    grads_batch_coalesced.append(dev_grads_batch_coalesced)
            # We will only use device 0's results, but this single op should be
            # faster than doing the following two operation sequentially:
            # (1) intra-node reduce to lead GPU, followed by
            # (2) inter-node allreduce for all the first lead GPUs in all nodes
            dist.all_reduce_multigpu(grads_batch_coalesced,
                                     group=self.nccl_reduction_group_id)
            # Now only work on the first device of self.device_ids, uncoalesce
            # the gradients for each bucket
            grads_batch_coalesced[0] /= dist.get_world_size()
            grads_batch_reduced = _unflatten_dense_tensors(grads_batch_coalesced[0],
                                                           grads_batch[0])
            for grad, reduced in zip(grads_batch[0], grads_batch_reduced):
                grad.copy_(reduced)
        # clear the gradients and save memory for replicas
        for module in self._module_copies[1:]:
            for param in module.parameters():
                if param.requires_grad:
                    param.grad = None
                    param.data.set_()

    # Now register the reduction hook on the parameters
    for p in self.module.parameters():
        if not p.requires_grad:
            continue

        # The hook only enqueues the reduction once per backward pass; the
        # need_reduction flag guards against repeated invocations.
        def allreduce_hook(*unused):
            Variable._execution_engine.queue_callback(reduction_fn_nccl)

        p.register_hook(allreduce_hook)