def synchronize(self):
    completed = set()
    for x in self._handles.keys():
        completed.update(x) if isinstance(x, tuple) else completed.add(x)
    missing_p = self._requires_update - completed
    for p in missing_p:
        handle, ctx = self._allreduce_grad_async(p)
        self._handles[p] = (handle, ctx)

    for p, (handle, ctx) in self._handles.items():
        if handle is None:
            handle, ctx = self._allreduce_grad_async(p)
            self._handles[p] = (handle, ctx)

    for p, (handle, ctx) in self._handles.items():
        if isinstance(p, tuple):
            # This was a grouped result, need to unpack
            outputs = synchronize(handle)
            for gp, output, gctx in zip(p, outputs, ctx):
                self._allreduce_delay[gp] = self.backward_passes_per_step
                gp.grad.set_(self._compression.decompress(output, gctx))
        else:
            output = synchronize(handle)
            self._allreduce_delay[p] = self.backward_passes_per_step
            p.grad.set_(self._compression.decompress(output, ctx))

    self._handles.clear()
    self._synchronized = True
def synchronize(self):
    missing_p = self._requires_update - set(self._handles.keys())
    for p in missing_p:
        if self._sparse:
            handle, ctx = self._sparse_allreduce_async(p)
            self._handles[p] = (handle, ctx)
        else:
            handle, ctx = self._allreduce_grad_async(p)
            self._handles[p] = (handle, ctx)

    num_of_workers = size()
    for p, value in self._handles.items():
        name = self._parameter_names.get(p)
        if self._sparse:
            handle, ctx = value
            output = synchronize(handle)
            # The allgathered output packs each worker's (values, indices) pair.
            new_grad = p.grad.data.view(-1).fill_(0.0)
            numel = output.numel()
            real_num_values = numel // num_of_workers
            for i in range(num_of_workers):
                values_and_indexes = output.data[i * real_num_values:(i + 1) * real_num_values]
                values = values_and_indexes[0:real_num_values // 2]
                indexes = values_and_indexes[real_num_values // 2:].long()
                new_grad[indexes] = values
            new_grad = new_grad.reshape(p.grad.data.shape)
            #print('name: ', name, ' output shape: ', output.shape)
            #p.grad.data.set_(self._compression.decompress(output, None, name=name))
            p.grad.data.set_(new_grad)
        else:
            handle, ctx = value
            output = synchronize(handle)
            p.grad.data.set_(self._compression.decompress(output, ctx, name=name))
    self._handles.clear()
def broadcast_parameters(params, root_rank):
    """
    Broadcasts the parameters from root rank to all other processes.
    Typical usage is to broadcast the `model.state_dict()`,
    `model.named_parameters()`, or `model.parameters()`.

    Arguments:
        params: One of the following:
            - list of parameters to broadcast
            - dict of parameters to broadcast
        root_rank: The rank of the process from which parameters will be
            broadcasted to all other processes.
    """
    if isinstance(params, dict):
        params = sorted(params.items())
    elif isinstance(params, list):
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run asynchronous broadcasts.
    handles = []
    for name, p in params:
        if isinstance(p, torch.autograd.Variable):
            p = p.data
        handle = broadcast_async_(p, root_rank, name)
        handles.append(handle)

    # Wait for completion.
    for handle in handles:
        synchronize(handle)
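# A minimal usage sketch for broadcast_parameters, based on the docstring above.
# It assumes a Horovod-style `hvd` module exposing init() and broadcast_parameters();
# the toy model below is only for illustration, not part of the original code.
import torch
import horovod.torch as hvd

hvd.init()
model = torch.nn.Linear(10, 2)

# Make every worker start from the parameter values held by rank 0.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)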
def synchronize(self):
    if hvd.size() > 1:
        for p in self._handles:
            handle = self._handles[p]
            synchronize(handle)
            begin_time = time.time()
            p_size = np.prod(p.size())
            torch.cuda.synchronize()
            begin_comm_time = time.time()
            if self._use_allgather and p_size > 1024:
                #fjr decompress
                name = self._parameter_names.get(p)
                msg_size = self._compressed_msg_size[name]
                g_size = p.grad.data.size()
                p_flatten = p.grad.data.view(-1)
                p_flatten.zero_()
                p_flatten[self._compressed_idx[name]] = self._compressed_val[name]
                p.grad.data = p.grad.data.view(g_size)
                if self._debug:
                    print("diff : ", torch.sum(self._v_ref[name] - p.grad.data))
            torch.cuda.synchronize()
            end_comm_time = time.time()
            self.pack_time += end_comm_time - begin_comm_time
            torch.cuda.synchronize()
            end_time = time.time()
            self.pruning_time += end_time - begin_time
    self._handles.clear()
def broadcast_parameters(params, root_rank):
    """
    Broadcasts the parameters from root rank to all other processes.
    Typical usage is to broadcast the `model.state_dict()`,
    `model.named_parameters()`, or `model.parameters()`.

    Arguments:
        params: The list of parameters to broadcast.
        root_rank: The rank of the process from which parameters will be
            broadcasted to all other processes.
    """
    if isinstance(params, dict):
        params = sorted(params.items())
    else:
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]

    # Run asynchronous broadcasts.
    handles = []
    for name, p in params:
        if isinstance(p, torch.autograd.Variable):
            p = p.data
        handle = broadcast_async_(p, root_rank, name)
        handles.append(handle)

    # Wait for completion.
    for handle in handles:
        synchronize(handle)
def broadcast_parameters(params, root_rank):
    """
    Broadcasts the parameters from root rank to all other processes.
    Typical usage is to broadcast the `model.state_dict()`,
    `model.named_parameters()`, or `model.parameters()`.

    Arguments:
        params: One of the following:
            - list of parameters to broadcast
            - dict of parameters to broadcast
        root_rank: The rank of the process from which parameters will be
            broadcasted to all other processes.
    """
    if isinstance(params, dict):
        params = sorted(params.items())
    elif isinstance(params, list):
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run asynchronous broadcasts.
    handles = []
    for name, p in params:
        handle = broadcast_async_(p, root_rank, name)
        handles.append(handle)

    # Wait for completion.
    for handle in handles:
        synchronize(handle)
def synchronize(self):
    if hvd.size() > 1:
        for p in self._handles:
            handle = self._handles[p]
            synchronize(handle)
            #p_size = np.prod(p.size())
            p_size = torch.numel(p)
            if self._use_allgather and p_size > self._plan1:
                handle = self._handles_val[p]
                synchronize(handle)
                torch.cuda.synchronize()
                begin_time_sync = time.time()
                #fjr decompress
                name = self._parameter_names.get(p)
                g_size = p.grad.data.size()
                p_flatten = p.grad.data.view(-1)
                p_flatten.zero_()
                torch.cuda.synchronize()
                begin_unpack_time = time.time()
                if self._use_gpu:
                    if p_size > self._plan3:
                        #count_nnz = 0
                        offset = 0
                        for node_idx in range(hvd.size()):
                            msg_size = self._compressed_idx[name][offset]
                            offset += 1
                            p_flatten[self._compressed_idx[name][offset: offset + msg_size]] += \
                                self._compressed_val[name][node_idx]
                            offset += msg_size
                            #count_nnz += msg_size
                        #if hvd.rank() == 0:
                        #    print("sparsity ", name, count_nnz.cpu().numpy()/(p_size))
                    else:
                        msg_size = self._compressed_msg_size[name]
                        for node_idx in range(hvd.size()):
                            p_flatten[self._compressed_idx[name][node_idx*msg_size: node_idx*msg_size + msg_size]] += \
                                self._compressed_val[name][node_idx]
                p.grad.data = p_flatten.view(g_size)
                torch.cuda.synchronize()
                self.unpack_time += time.time() - begin_unpack_time
                torch.cuda.synchronize()
                self.pruning_time += time.time() - begin_time_sync
                if self._debug:
                    diff = torch.sum(self._v_ref[name] - p.grad.data)
                    if torch.abs(diff) > 1e-3:
                        print("error diff is, ", diff, name, p.size())
            else:
                pass
    self._handles.clear()
    self._handles_val.clear()
def synchronize(self):
    for param_group in self.param_groups:
        for p in param_group['params']:
            name = self._parameter_names.get(p)
            handle = allreduce_async_(p.grad.data, average=True, name=name)
            self._handles[p] = handle

    for handle in self._handles.values():
        synchronize(handle)
    self._handles.clear()
def synchronize(self):
    if hvd.size() > 1:
        for p in self._handles:
            handle = self._handles[p]
            synchronize(handle)
            torch.cuda.synchronize()
            begin_time = time.time()
            p_size = np.prod(p.size())
            if self._use_allgather and p_size > 1024:
                #fjr decompress
                name = self._parameter_names.get(p)
                torch.cuda.synchronize()
                begin_pack_time = time.time()
                g_size = p.grad.data.size()
                p_flatten = p.grad.data.view(-1)
                p_flatten.zero_()
                #print("p_flatten size is ,", p_flatten.size())
                #print("compressed msg, ", self._compressed_msg[name], 'rank, ', hvd.local_size())
                #print("hand is ", handle)
                offset = 0
                for node_idx in range(hvd.size()):
                    if self._use_gpu:
                        msg_size = self._compressed_msg[name][offset].type('torch.cuda.LongTensor')
                        offset += 1
                        p_flatten[self._compressed_msg[name][offset: offset + msg_size].type('torch.cuda.LongTensor')] += \
                            self._compressed_msg[name][offset + msg_size: offset + 2*msg_size]
                        offset += msg_size * 2
                    else:
                        msg_size = self._compressed_msg[name][offset].type('torch.LongTensor')
                        offset += 1
                        p_flatten[self._compressed_msg[name][offset: offset + msg_size].type('torch.LongTensor')] += \
                            self._compressed_msg[name][offset + msg_size: offset + 2*msg_size]
                        offset += msg_size * 2
                torch.cuda.synchronize()
                self.pack_time += time.time() - begin_pack_time
                p.grad.data = p_flatten.view(g_size)
                if self._debug:
                    diff = torch.sum(self._v_ref[name] - p.grad.data)
                    if torch.abs(diff) > 1e-3:
                        print("error diff is, ", diff, name, p.size())
            torch.cuda.synchronize()
            end_time = time.time()
            self.pruning_time += end_time - begin_time
    self._handles.clear()
def synchronize(self):
    for p in self._handles:
        handle = self._handles[p]
        synchronize(handle)
        begin_time = time.time()
        torch.cuda.synchronize()
        end_time = time.time()
        self.pruning_time += end_time - begin_time
    self._handles.clear()
def synchronize(self):
    for p in self._handles:
        handle = self._handles[p]
        synchronize(handle)
        p_size = np.prod(p.size())
        begin_time = time.time()
        if self._use_allgather and p_size > 1024:
            torch.cuda.synchronize()
            begin_unpack_time = time.time()
            #fjr decompress
            if p_size < 1500 * 10000:
                handle = self._handles_val[p]
                synchronize(handle)
            name = self._parameter_names.get(p)
            g_size = p.grad.data.size()
            p_flatten = p.grad.data.view(-1)
            p_flatten.zero_()
            offset = 0
            if p_size < 1500 * 10000:
                for node_idx in range(hvd.size()):
                    if self._use_gpu:
                        msg_size = self._compressed_idx[name][offset]
                        offset += 1
                        p_flatten[self._compressed_idx[name][offset: offset + msg_size]] += \
                            self._compressed_val[name][node_idx]
                        offset += msg_size
            else:
                for node_idx in range(hvd.size()):
                    if self._use_gpu:
                        msg_size = self._compressed_msg[name][offset].type('torch.cuda.LongTensor')
                        offset += 1
                        p_flatten[self._compressed_msg[name][offset: offset + msg_size].type('torch.cuda.LongTensor')] += \
                            self._compressed_msg[name][offset + msg_size: offset + 2*msg_size]
                        offset += msg_size * 2
            p.grad.data = p.grad.data.view(g_size)
            if self._debug:
                print("diff : ", torch.sum(self._v_ref[name] - p.grad.data))
            torch.cuda.synchronize()
            self.unpack_time += time.time() - begin_unpack_time
        torch.cuda.synchronize()
        end_time = time.time()
        self.pruning_time += end_time - begin_time
    self._handles.clear()
    self._handles_val.clear()
def synchronize(self):
    for p in self._handles:
        handle = self._handles[p]
        synchronize(handle)
        p_size = np.prod(p.size())
        begin_time = time.time()
        torch.cuda.synchronize()
        begin_comm_time = time.time()
        if self._use_allgather and p_size > 1024:
            #fjr decompress
            name = self._parameter_names.get(p)
            msg_size = self._compressed_msg_size[name]
            #print("rank, msg_size is ", hvd.local_rank(), msg_size)
            g_size = p.grad.data.size()
            p_flatten = p.grad.data.view(-1)
            p_flatten.zero_()
            #print("p_flatten size is ,", p_flatten.size())
            #print("compressed msg, ", self._compressed_msg[name], 'rank, ', hvd.local_size())
            #print("hand is ", handle)
            for node_idx in range(hvd.size()):
                if self._use_gpu:
                    if p_size == 1500 * 10000:
                        p_flatten[self._compressed_msg[name][node_idx*msg_size*2: node_idx*msg_size*2 + msg_size].type('torch.cuda.LongTensor')] += \
                            self._compressed_msg[name][node_idx*msg_size*2 + msg_size: node_idx*msg_size*2 + 2*msg_size]
                    else:
                        mean_val = torch.mean(self._compressed_msg[name][node_idx*msg_size*2 + msg_size: node_idx*msg_size*2 + 2*msg_size])
                        p_flatten[self._compressed_msg[name][node_idx*msg_size*2: node_idx*msg_size*2 + msg_size].type('torch.cuda.LongTensor')] += \
                            mean_val
                else:
                    p_flatten[self._compressed_msg[name][node_idx*msg_size*2: node_idx*msg_size*2 + msg_size].type('torch.LongTensor')] += \
                        self._compressed_msg[name][node_idx*msg_size*2 + msg_size: node_idx*msg_size*2 + 2*msg_size]
            p.grad.data = p.grad.data.view(g_size)
            if self._debug:
                print("diff : ", torch.sum(self._v_ref[name] - p.grad.data))
        torch.cuda.synchronize()
        end_time = time.time()
        self.pruning_time += end_time - begin_time
        self.comm_time += time.time() - begin_comm_time
    self._handles.clear()
def backward(self, grad_output):
    grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd, count_all = self.saved_tensors
    need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[0:3]

    # calculate local stats as well as grad_weight / grad_bias
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output,
        saved_input,
        mean,
        invstd,
        weight,
        need_input_grad,
        need_weight_grad,
        need_bias_grad)

    if need_input_grad:
        # synchronizing stats used to calculate input gradient.
        sum_dy_handle = allreduce_async(sum_dy, op=Sum, name='sync_batch_norm.sum_dy')
        sum_dy_xmu_handle = allreduce_async(sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')

        # wait on the async communication to finish
        sum_dy = synchronize(sum_dy_handle)
        sum_dy_xmu = synchronize(sum_dy_xmu_handle)

        if _SYNC_BN_V2 or _SYNC_BN_V3:
            count_all_sum = count_all.sum()
            mean_dy = sum_dy / count_all_sum
            mean_dy_xmu = sum_dy_xmu / count_all_sum
        else:
            # before 1.5.0, sum_dy was sum of means from every worker, so we just
            # need to divide it by number of workers
            mean_dy = sum_dy / size()
            mean_dy_xmu = sum_dy_xmu / size()

        # backward pass for gradient calculation
        grad_input = torch.batch_norm_backward_elemt(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            mean_dy,
            mean_dy_xmu)
    else:
        grad_input = None

    # synchronizing of grad_weight / grad_bias is not needed as distributed
    # training would handle all reduce.
    if weight is None or not need_weight_grad:
        grad_weight = None

    if weight is None or not need_bias_grad:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
def synchronize(self):
    if not self.process_set.included():
        self._synchronized = True
        return

    completed = set()
    for x in self._handles.keys():
        completed.update(x) if isinstance(x, tuple) else completed.add(x)
    missing_p = self._requires_update - completed
    for p in missing_p:
        handle, ctx = self._allreduce_grad_async(p)
        self._handles[p] = (handle, ctx)

    for p, (handle, ctx) in self._handles.items():
        if handle is None:
            handle, ctx = self._allreduce_grad_async(p)
            self._handles[p] = (handle, ctx)

    for p, (handle, ctx) in self._handles.items():
        if isinstance(p, tuple):
            # This was a grouped result, need to unpack
            outputs = synchronize(handle)
            for gp, output, gctx in zip(p, outputs, ctx):
                self._allreduce_delay[gp] = self.backward_passes_per_step
                gp.grad.set_(self._compression.decompress(output, gctx))
            if self._groups is not None and self._group_counts[p] != 0:
                self._group_counts[p] = 0
        else:
            # When handle is a callable function, it returns the aggregated tensor result
            output = synchronize(handle) if not callable(handle) else handle()
            self._allreduce_delay[p] = self.backward_passes_per_step
            if self._groups is not None:
                group = self._p_to_group[p]
                if self._group_counts[group] != 0:
                    self._group_counts[group] = 0
            if p.grad.is_sparse:
                aggregated = self._compression.decompress(output, ctx)
                if not aggregated.is_sparse:
                    # When sparse_as_dense=True we need to convert the grad back to sparse before update
                    aggregated = aggregated.to_sparse()

                # Sparse grads do not support set_ for some reason, so we do this as an equivalent
                p.grad.zero_().add_(aggregated)
            else:
                p.grad.set_(self._compression.decompress(output, ctx))

    self._handles.clear()
    self._synchronized = True
def synchronize(self):
    synced = False
    if self.count_down == 0:
        missing_p = self._requires_update - set(self._handles.keys())
        for p in missing_p:
            self._allreduce_tensor(p)

        if self._multi_node:
            for p, value in self._handles.items():
                handle, ctx = value
                output = synchronize(handle)
                p.grad.set_(self._compression.decompress(output, ctx) / self.accumulation_step)
        else:
            buckets = OrderedDict()
            for tensor in self._handles.values():
                tp = tensor.type()
                if tp not in buckets:
                    buckets[tp] = []
                buckets[tp].append(tensor)
            for tp in buckets:
                bucket = buckets[tp]
                coalesced = flatten(bucket) / self.world_size / self.accumulation_step
                torch.distributed.all_reduce_multigpu([coalesced])
                for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
                    buf.copy_(synced)

        self._handles.clear()
        synced = True
        self.count_down = self.accumulation_step

    self.count_down -= 1
    return synced
def synchronize(self):
    if hvd.size() > 1:
        for p in self._handles:
            handle = self._handles[p]
            synchronize(handle)
            torch.cuda.synchronize()
            begin_time = time.time()
            p_size = np.prod(p.size())
            if self._use_allgather and p_size > 1024:
                handle = self._handles_val[p]
                synchronize(handle)
                #fjr decompress
                name = self._parameter_names.get(p)
                #msg_size = self._compressed_msg_size[name]
                #print("rank, msg_size is ", hvd.local_rank(), msg_size)
                torch.cuda.synchronize()
                begin_pack_time = time.time()
                g_size = p.grad.data.size()
                p_flatten = p.grad.data.view(-1)
                p_flatten.zero_()
                offset = 0
                for node_idx in range(hvd.size()):
                    if self._use_gpu:
                        msg_size = self._compressed_idx[name][offset]
                        offset += 1
                        p_flatten[self._compressed_idx[name][offset: offset + msg_size]] += \
                            self._compressed_val[name][node_idx]
                        offset += msg_size
                p.grad.data = p_flatten.view(g_size)
                torch.cuda.synchronize()
                self.pack_time += time.time() - begin_pack_time
                if self._debug:
                    diff = torch.sum(self._v_ref[name] - p.grad.data)
                    if torch.abs(diff) > 1e-3:
                        print("error diff is, ", diff, name, p.size())
            torch.cuda.synchronize()
            end_time = time.time()
            self.pruning_time += end_time - begin_time
    self._handles.clear()
    self._handles_val.clear()
def synchronize(self):
    for p in self._handles:
        handle = self._handles[p]
        synchronize(handle)
        begin_time = time.time()
        torch.cuda.synchronize()
        begin_comm_time = time.time()
        #if self._use_allgather and p_size > 1024 and hvd.size() > 1:
        #    #fjr decompress
        #    name = self._parameter_names.get(p)
        #    msg_size = self._compressed_msg_size[name]
        #    #print("rank, msg_size is ", hvd.local_rank(), msg_size)
        #    g_size = p.grad.data.size()
        #    p_flatten = p.grad.data.view(-1)
        #    p_flatten.zero_()
        #    #print("p_flatten size is ,", p_flatten.size())
        #    #print("compressed msg, ", self._compressed_msg[name], 'rank, ', hvd.local_size())
        #    #print("hand is ", handle)
        #    for node_idx in range(hvd.size()):
        #        if self._use_gpu:
        #            p_flatten[self._compressed_msg[name][node_idx*msg_size*2: node_idx*msg_size*2 + msg_size].type('torch.cuda.LongTensor')] += \
        #                self._compressed_msg[name][node_idx*msg_size*2 + msg_size: node_idx*msg_size*2 + 2*msg_size]
        #        else:
        #            p_flatten[self._compressed_msg[name][node_idx*msg_size*2: node_idx*msg_size*2 + msg_size].type('torch.LongTensor')] += \
        #                self._compressed_msg[name][node_idx*msg_size*2 + msg_size: node_idx*msg_size*2 + 2*msg_size]
        #    p.grad.data = p.grad.data.view(g_size)
        #    if self._debug:
        #        print("diff : ", torch.sum(self._v_ref[name] - p.grad.data))
        torch.cuda.synchronize()
        end_comm_time = time.time()
        self.pack_time += end_comm_time - begin_comm_time
        torch.cuda.synchronize()
        end_time = time.time()
        self.pruning_time += end_time - begin_time
    self._handles.clear()
def synchronize(self):
    missing_p = self._requires_update - set(self._handles.keys())
    for p in missing_p:
        self._allreduce_grad(p)

    for p, value in self._handles.items():
        handle, ctx = value
        output = synchronize(handle)
        p.grad.data.set_(self._compression.decompress(output, ctx))
    self._handles.clear()
def synchronize(self):
    for p in self._handles:
        handle = self._handles[p]
        synchronize(handle)
        p_size = np.prod(p.size())
        begin_time = time.time()
        torch.cuda.synchronize()
        begin_comm_time = time.time()
        if self._use_allgather and p_size > 1024:
            #fjr decompress
            handle = self._handles_val[p]
            synchronize(handle)
            name = self._parameter_names.get(p)
            msg_size = self._compressed_msg_size[name]
            #print("rank, msg_size is ", hvd.local_rank(), msg_size)
            g_size = p.grad.data.size()
            p_flatten = p.grad.data.view(-1)
            p_flatten.zero_()
            #print("p_flatten size is ,", p_flatten.size())
            #print("compressed msg, ", self._compressed_msg[name], 'rank, ', hvd.local_size())
            #print("hand is ", handle)
            offset = 0
            for node_idx in range(hvd.size()):
                if self._use_gpu:
                    msg_size = self._compressed_idx[name][offset]
                    offset += 1
                    p_flatten[self._compressed_idx[name][offset: offset + msg_size]] += \
                        self._compressed_val[name][node_idx]
                    offset += msg_size
            p.grad.data = p.grad.data.view(g_size)
            if self._debug:
                print("diff : ", torch.sum(self._v_ref[name] - p.grad.data))
        torch.cuda.synchronize()
        end_time = time.time()
        self.pruning_time += end_time - begin_time
        self.comm_time += time.time() - begin_comm_time
    self._handles.clear()
def synchronize(self):
    for p, value in self._handles.items():
        handle, ctx = value
        if handle is None:
            handle, ctx = self._allreduce_grad(p)
            self._handles[p] = (handle, ctx)

    # Unpack ctx here as well, so decompression uses each parameter's own context
    # rather than whatever ctx was left over from the loop above.
    for p, (handle, ctx) in self._handles.items():
        output = mpi_ops.synchronize(handle)
        self._allreduce_delay[p] = self.backward_passes_per_step
        p.grad.data.set_(self._compression.decompress(output, ctx))
    self._handles.clear()
def test_parallel(self):
    hvd.init()
    # TODO support non-MPI Adasum operation
    # Only do this test if there are GPUs available.
    if not hvd.mpi_enabled() or not torch.cuda.is_available():
        self.skipTest("No GPUs available")

    device = torch.device('cuda:{}'.format(hvd.local_rank()))
    np.random.seed(2)
    torch.manual_seed(2)
    size = hvd.size()
    local_size = hvd.local_size()
    rank = hvd.rank()

    for data_type in self.data_types:
        all_Ns = [size * 20 - 13, size * 2 + 1, size + 2, 2**19]
        tensors = []
        all_qs = []
        for N in all_Ns:
            a = np.random.normal(0, 1, (N, 1)).astype(np.float64)
            r = np.random.normal(0, 1, (size, 1)).astype(np.float64)
            q = np.dot(a, r.T)
            q = q.astype(data_type)
            all_qs.append(q.astype(np.float64))
            tensors.append(q[:, hvd.rank()])

        tensors = list(map(lambda x: torch.from_numpy(x).to(device), tensors))

        handles = [
            hvd.allreduce_async(tensor, op=hvd.Adasum)
            for tensor in tensors
        ]

        reduced_tensors = [synchronize(h) for h in handles]

        expected = [np.sum(q, axis=1) / size for q in all_qs]
        all_comp = [
            self.are_close(data_type, e, rt.cpu().numpy())
            for e, rt in zip(expected, reduced_tensors)
        ]
        if np.alltrue(all_comp):
            print('Parallel test passed')
        else:
            for c, e, rt in zip(all_comp, expected, reduced_tensors):
                if not c:
                    print('computed: ', rt)
                    print('expected: ', e)
                    print('off by: ', self.diff_ratio(e, rt.cpu().numpy()))
        assert np.alltrue(all_comp)
def synchronize(self):
    if hvd.size() > 1:
        for p in self._handles:
            handle = self._handles[p]
            synchronize(handle)
            begin_time = time.time()
            p_size = np.prod(p.size())
            torch.cuda.synchronize()
            begin_comm_time = time.time()
            if self._use_allgather and p_size > 1024:
                handle = self._handles_val[p]
                synchronize(handle)
                name = self._parameter_names.get(p)
                msg_size = self._compressed_msg_size[name]
                g_size = p.grad.data.size()
                p_flatten = p.grad.data.view(-1)
                p_flatten.zero_()
                for node_idx in range(hvd.size()):
                    if self._use_gpu:
                        p_flatten[self._compressed_idx[name][node_idx*msg_size: node_idx*msg_size + msg_size]] += \
                            self._compressed_val[name][node_idx]
                p.grad.data = p.grad.data.view(g_size)
                if self._debug:
                    diff = torch.sum(self._v_ref[name] - p.grad.data)
                    if torch.abs(diff) > 1e-3:
                        print(diff, name)
            torch.cuda.synchronize()
            end_comm_time = time.time()
            self.pack_time += end_comm_time - begin_comm_time
            torch.cuda.synchronize()
            end_time = time.time()
            self.pruning_time += end_time - begin_time
    self._handles.clear()
    self._handles_val.clear()
def forward(self, input, weight, bias, running_mean, running_var, eps, momentum):
    input = input.contiguous()

    size = input.numel() // input.size(1)
    count = torch.tensor([size])

    # calculate mean/invstd for input.
    mean, invstd = torch.batch_norm_stats(input, eps)

    count_handle = allgather_async(count.unsqueeze(0), name='sync_batch_norm.count')
    mean_handle = allgather_async(mean.unsqueeze(0), name='sync_batch_norm.mean')
    invstd_handle = allgather_async(invstd.unsqueeze(0), name='sync_batch_norm.invstd')

    # wait on the async communication to finish
    count_all = synchronize(count_handle)
    mean_all = synchronize(mean_handle)
    invstd_all = synchronize(invstd_handle)

    if _SYNC_BN_V2:
        counts_for_bngswc = count_all.view(-1).float().to(input.device)
    else:
        # backwards compatibility
        counts_for_bngswc = count_all.view(-1).tolist()

    # calculate global mean & invstd
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
        input,
        mean_all,
        invstd_all,
        running_mean,
        running_var,
        momentum,
        eps,
        counts_for_bngswc)

    self.save_for_backward(input, weight, mean, invstd, count_all)

    # apply element-wise normalization
    return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
def synchronize(self):
    missing_p = self._requires_update - set(self._handles.keys())
    for p in missing_p:
        handle, ctx = self._allreduce_grad_async(p)
        self._handles[p] = (handle, ctx)

    for p, value in self._handles.items():
        handle, ctx = value
        if handle is None:
            handle, ctx = self._allreduce_grad_async(p)
            self._handles[p] = (handle, ctx)

    # Unpack ctx here so decompression uses each parameter's own context
    # instead of the last ctx bound in the loop above.
    for p, (handle, ctx) in self._handles.items():
        output = synchronize(handle)
        self._allreduce_delay[p] = self.backward_passes_per_step
        p.grad.set_(self._compression.decompress(output, ctx))
    self._handles.clear()
def synchronize(self):
    missing_p = self._requires_update - set(self._handles.keys())
    for p in missing_p:
        handle, ctx = self._broadcast_grad_async(p)
        self._handles[p] = (handle, ctx)

    for p, value in self._handles.items():
        handle, ctx = value
        if handle is None:
            handle, ctx = self._broadcast_grad_async(p)
            self._handles[p] = (handle, ctx)

    # Unpack ctx here so decompression uses each parameter's own context
    # instead of the last ctx bound in the loop above.
    for p, (handle, ctx) in self._handles.items():
        outputs = [synchronize(hd) for hd in handle]
        self._allreduce_delay[p] = self.backward_passes_per_step
        compressor = self._compressors[p]
        if compressor is None:
            p.grad.set_(self._compression.decompress(outputs[0], ctx))
        else:
            p.grad.set_(compressor.decompress(outputs))
    self._handles.clear()
    self._synchronized = True
def synchronize(self):
    for p, value in self._handles.items():
        name = self._merged_parameter_names.get(p)
        handle, ctx, density = value
        stime = time.time()
        output = synchronize(handle)
        if self._profiling:
            utils.force_insert_item(self._allreduce_timers, name, time.time() - stime)
        stime = time.time()

        if self._norm_clip is not None:
            norm_clip = np.sqrt(1.0 / size()) * self._norm_clip
            norm_type = 2.0
            param_norm = output.norm(norm_type)
            total_norm = param_norm.item()
            clip_coef = norm_clip / (total_norm + 1e-6)
            if clip_coef < 1:
                output.mul_(clip_coef)

        p.set_(output)
        if self._profiling:
            utils.force_insert_item(self._update_times, name, time.time() - stime)

    if len(self._groups) != len(self._sequential_keys):
        for merged_p, value in self._handles.items():
            new_name = self._merged_parameter_names.get(merged_p)
            tensors = self._pull_from_buffer(new_name, merged_p)
            for n in tensors:
                p = self._named_parameters.get(n)
                if settings.FP16:
                    p.grad.set_(tensors[n].data.type(p.grad.type()))
                else:
                    p.grad.set_(tensors[n].data)

    self.train_iter += 1
    self._handles.clear()
    self._print_profiling()
def step(self, closure=None):
    loss = None
    if closure is not None:
        loss = closure()

    missing_p = self._requires_update - set(self._handles.keys())
    for p in missing_p:
        handle, ctx = self._allreduce_grad_async(p)
        self._handles[p] = (handle, ctx)

    for p, (handle, ctx) in self._handles.items():
        # This means step() is called before backward_passes_per_steps finished.
        # We do a synchronous allreduce here.
        if not handle:
            handle, ctx = self._allreduce_grad_async(p)
            self._handles[p] = (handle, ctx)

        delta = synchronize(handle)
        delta = self._compression.decompress(delta, ctx)
        start = self._starting_models[p]
        start.data.add_(delta.data)
        p.data.copy_(start)
        self._allreduce_delay[p] = self.backward_passes_per_step
    self._handles.clear()
    return loss
def synchronize(self):
    for p, value in self._handles.items():
        handle, ctx = value
        output = synchronize(handle)
        p.grad.data.set_(self._compression.decompress(output, ctx))
    self._handles.clear()
def synchronize(self):
    num_of_workers = size()
    for p, value in self._handles.items():
        name = self._merged_parameter_names.get(p)
        handle, ctx, density = value
        if self._sparse and density < 1:
            stime = time.time()
            handle_idx = None
            all_indexes = None
            if type(handle) is tuple:
                handle, handle_idx = handle[0], handle[1]
            output = synchronize(handle)
            if handle_idx is not None:
                all_indexes = synchronize(handle_idx)

            if self._profiling:
                utils.force_insert_item(self._allreduce_timers, name, time.time() - stime)
            stime = time.time()
            new_grad = p.data.view(-1)
            new_grad.fill_(0.0)
            numel = output.size(0)
            real_num_values = numel // num_of_workers
            for i in range(num_of_workers):
                values_and_indexes = output.data[i * real_num_values:(i + 1) * real_num_values]
                if all_indexes is None:
                    values = values_and_indexes
                    indexes = None
                    per_values = values
                    per_values = self._compression.decompress(per_values, p.size())
                    new_grad += per_values.view(-1)
                else:
                    values = values_and_indexes
                    indexes = all_indexes.data[i * real_num_values:(i + 1) * real_num_values].long()
                    per_values = values[0:indexes.numel()]
                    per_values = self._compression.decompress(per_values, p.size())
                    new_grad[indexes[0:indexes.numel()]] += per_values
            new_grad /= num_of_workers
            if self._profiling:
                utils.force_insert_item(self._update_times, name, time.time() - stime)
        else:
            stime = time.time()
            output = synchronize(handle)
            if self._profiling:
                utils.force_insert_item(self._allreduce_timers, name, time.time() - stime)
            stime = time.time()

            if self._norm_clip is not None:
                norm_clip = np.sqrt(1.0 / size()) * self._norm_clip
                norm_type = 2.0
                param_norm = output.norm(norm_type)
                total_norm = param_norm.item()
                clip_coef = norm_clip / (total_norm + 1e-6)
                if clip_coef < 1:
                    output.mul_(clip_coef)
            if self._compression:
                output = self._compression.decompress(output, p.size())
            p.set_(output)
            if self._profiling:
                utils.force_insert_item(self._update_times, name, time.time() - stime)

    if len(self._groups) != len(self._sequential_keys):
        for merged_p, value in self._handles.items():
            new_name = self._merged_parameter_names.get(merged_p)
            tensors = self._pull_from_buffer(new_name, merged_p)
            for n in tensors:
                p = self._named_parameters.get(n)
                if self._fp16:
                    p.grad.set_(tensors[n].data.type(p.grad.type()))
                else:
                    p.grad.set_(tensors[n].data)

    self.train_iter += 1
    self._handles.clear()
    self._print_profiling()
def synchronize(self):
    for handle in self._handles.values():
        synchronize(handle)
    self._handles.clear()
def backward(self, grad_output):
    grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd, count_all = self.saved_tensors
    need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[0:3]

    # calculate local stats as well as grad_weight / grad_bias
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output,
        saved_input,
        mean,
        invstd,
        weight,
        need_input_grad,
        need_weight_grad,
        need_bias_grad)

    if need_input_grad:
        # synchronizing stats used to calculate input gradient.
        sum_dy_handle = allreduce_async(sum_dy, op=Sum, name='sync_batch_norm.sum_dy')
        sum_dy_xmu_handle = allreduce_async(sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')

        # wait on the async communication to finish
        sum_dy = synchronize(sum_dy_handle)
        sum_dy_xmu = synchronize(sum_dy_xmu_handle)

        if _SYNC_BN_V4:
            # from 1.9.0 on we need a count tensor on all devices
            # count_all is calculated as total count across all ranks in forward function
            count_all = count_all.to(dtype=torch.int, device=grad_output.device)
        elif _SYNC_BN_V2 or _SYNC_BN_V3:
            # before 1.9.0 we need the count as an integer to compute means values
            count = count_all.sum()
        else:
            # before 1.5.0, sum_dy was sum of means from every worker, so we just
            # need to divide it by number of workers
            count = size()

        # backward pass for gradient calculation
        # we are calling into a non-public undocumented function which broke moving to 1.9.0
        # https://github.com/pytorch/pytorch/issues/57900
        if _SYNC_BN_V4:
            # from 1.9.0 on, sums and count parameters expected
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                sum_dy,
                sum_dy_xmu,
                count_all)
        else:
            # before 1.9.0, mean parameters expected, not sums and count
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                sum_dy / count,
                sum_dy_xmu / count)
    else:
        grad_input = None

    # synchronizing of grad_weight / grad_bias is not needed as distributed
    # training would handle all reduce.
    if weight is None or not need_weight_grad:
        grad_weight = None

    if weight is None or not need_bias_grad:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None