def synchronize(self):
    for h in self.handles:
        handle, names, tensors, comm_tensors = h
        hvd.synchronize(handle)
        #name = 'merged_tensor_comm_'+','.join(names)
        name = ','.join(names)
        offset = 0
        buf = self.merged_tensors[name]
        if self.fp16:
            buf = buf.float()
        for i, t in enumerate(tensors):
            numel = comm_tensors[i].numel()
            comm_tensor = buf.data[offset:offset + numel]
            if self.symmetric:
                lower_indices = torch.tril_indices(t.shape[0], t.shape[1], device=t.device)
                upper_indices = torch.triu_indices(t.shape[0], t.shape[1], device=t.device)
                t[upper_indices[0], upper_indices[1]] = comm_tensor.view(comm_tensors[i].shape)
                t[lower_indices[0], lower_indices[1]] = t.t()[lower_indices[0], lower_indices[1]]
            else:
                t.copy_(comm_tensor.view(t.shape))
            offset += numel
    self.handles.clear()
def synchronize(self):
    for h in self.handles:
        hvd.synchronize(h)
    self.handles.clear()
    if self.merge:
        self._tensor_group.pull_alltensors()
        self._tensor_group.clear_group_flags()
def synchronize(self):
    for h in self.handles:
        hvd.synchronize(h)
    if self.merge:
        self._tensor_group.pull_alltensors()
        self._tensor_group.clear_group_flags()
    for name in self._name_tensors:
        tensor, comm_tensor = self._name_tensors[name]
        if self.symmetric:
            if self.fp16:
                comm_tensor = comm_tensor.float()
            lower_indices = torch.tril_indices(tensor.shape[0], tensor.shape[1], device=tensor.device)
            upper_indices = torch.triu_indices(tensor.shape[0], tensor.shape[1], device=tensor.device)
            tensor[upper_indices[0], upper_indices[1]] = comm_tensor
            tensor[lower_indices[0], lower_indices[1]] = tensor.t()[lower_indices[0], lower_indices[1]]
        else:
            if self.fp16:
                comm_tensor = comm_tensor.float()
            tensor.copy_(comm_tensor)
        if self.op == hvd.Average:
            tensor.div_(hvd.size())
    self._name_tensors.clear()
    self.handles.clear()
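# A minimal, standalone sketch of the symmetric-tensor trick used above: only
# the upper triangle is communicated, and the lower triangle is mirrored back
# from the transpose afterwards. This is an illustration, not part of the
# original class.
import torch

n = 4
full = torch.arange(n * n, dtype=torch.float32).view(n, n)
full = full + full.t()                      # make a symmetric matrix

upper = torch.triu_indices(n, n)
packed = full[upper[0], upper[1]]           # what would be sent over the wire

restored = torch.zeros(n, n)
restored[upper[0], upper[1]] = packed       # fill the upper triangle
lower = torch.tril_indices(n, n)
restored[lower[0], lower[1]] = restored.t()[lower[0], lower[1]]  # mirror it
assert torch.equal(restored, full)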
def _broadcast_eigendecomp(self):
    """Broadcasts the eigendecompositions for all layers

    Note: each factor's eigendecomposition is broadcast in place from the
    rank that computed it.
    """
    handles = []
    rank = hvd.rank()

    for i, m in enumerate(self.modules):
        rank_a = self.m_dA_ranks[m]
        rank_g = self.m_dG_ranks[m]
        name = self.module_names[i]

        h = hvd.broadcast_async_(self.m_QA[m], rank_a, name=name + 'mQA')
        handles.append(h)
        h = hvd.broadcast_async_(self.m_dA[m], rank_a, name=name + 'mdA')
        handles.append(h)
        h = hvd.broadcast_async_(self.m_QG[m], rank_g, name=name + 'mQG')
        handles.append(h)
        h = hvd.broadcast_async_(self.m_dG[m], rank_g, name=name + 'mdG')
        handles.append(h)

    for handle in handles:
        hvd.synchronize(handle)
def _allreduce_factors(self):
    """Allreduce the factors for all layers"""
    handles = []

    for m in self.modules:
        handles.append(hvd.allreduce_async_(self.m_A[m].data, op=hvd.Average))
        handles.append(hvd.allreduce_async_(self.m_G[m].data, op=hvd.Average))

    for handle in handles:
        hvd.synchronize(handle)
def test_allgather():
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    tensor = torch.rand(10).float().cuda()
    print('rank: ', rank, ', tensor: ', tensor)
    #handle = hvd.allgather_async(tensor)
    #tensor = hvd.synchronize(handle)
    handle = hvd.broadcast_async(tensor, 0)
    # broadcast_async is out-of-place, so capture the synchronized result
    tensor = hvd.synchronize(handle)
    print('---------')
    print('rank: ', rank, ', tensor: ', tensor)
def _broadcast_precon_grads(self):
    handles = []

    for i, m in enumerate(self.modules):
        rank_a, rank_g = self.module_ranks[m]
        assert rank_a == rank_g
        name = self.module_names[i]

        v = self.m_precon_grad[m]
        h = hvd.broadcast_async_(v, rank_a, name=name + 'preconGrad')
        handles.append(h)

    for handle in handles:
        hvd.synchronize(handle)
def _broadcast_inverse_factors(self):
    handles = []

    for i, m in enumerate(self.modules):
        rank_a, rank_g = self.module_ranks[m]
        name = self.module_names[i]

        h = hvd.broadcast_async_(self.m_inv_A[m], rank_a, name=name + 'inverseA')
        handles.append(h)
        h = hvd.broadcast_async_(self.m_inv_G[m], rank_g, name=name + 'inverseG')
        handles.append(h)

    for handle in handles:
        hvd.synchronize(handle)
def _reduce_factors(self, eigen_ranks):
    """Allreduce the factors for all layers"""
    handles = []

    for m in self.modules:
        name = self.module_name_map[m]
        ranks_a, ranks_g = eigen_ranks[m]
        rank_a = ranks_a[0]
        rank_g = ranks_g[0]

        handles.append(hvd.allreduce_async_(self.m_A[m].data, op=hvd.Average))
        handles.append(hvd.allreduce_async_(self.m_G[m].data, op=hvd.Average))

    for handle in handles:
        hvd.synchronize(handle)
def _allreduce_eigendecomp(self):
    """Allreduce the eigendecompositions for all layers

    Note: we use `op=hvd.Sum` to simulate an allgather. Each rank will
    either compute the eigendecomposition for a factor or just return
    zeros, so we sum instead of averaging.
    """
    handles = []

    for m in self.modules:
        handles.append(hvd.allreduce_async_(self.m_QA[m].data, op=hvd.Sum))
        handles.append(hvd.allreduce_async_(self.m_QG[m].data, op=hvd.Sum))

    for handle in handles:
        hvd.synchronize(handle)
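# A minimal sketch of the "sum as allgather" idea described in the docstring
# above. It assumes hvd.init() has been called and that all ranks pass a
# tensor of the same shape; `sum_as_allgather` and `owner` are illustrative
# names, not part of the original code.
import torch
import horovod.torch as hvd

def sum_as_allgather(local_value, owner):
    # Every rank allocates an identically-shaped buffer; only the owning rank
    # fills it in, the others contribute zeros, so a Sum-allreduce reproduces
    # the owner's data on every rank.
    buf = torch.zeros_like(local_value)
    if hvd.rank() == owner:
        buf.copy_(local_value)
    handle = hvd.allreduce_async_(buf, op=hvd.Sum)
    return hvd.synchronize(handle)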
def fsp_matrix_transfer(self):
    '''
    Obtain the feature maps of the bottlenecks (h*w*m), reshape them to (hw*m),
    compute the FSP matrix via matrix multiplication (m*n), allgather the
    matrices, and apply an L2 loss on them.
    :return:
    '''
    handles = []
    matrix_group = []
    for key in self.activation:
        if 'in' in key:
            fm_in = self.activation[key]
        if 'out' in key:
            fm_out = self.activation[key]
            fm_in = fm_in.view(fm_in.shape[0], fm_in.shape[1], -1)
            fm_out = fm_out.view(fm_out.shape[0], fm_out.shape[1], -1)
            fm_out = torch.transpose(fm_out, 1, 2)
            fsp_matrix = torch.bmm(fm_in, fm_out) / fm_in.shape[-1]
            matrix_group.append(fsp_matrix)
            fsp_matrix = fsp_matrix.unsqueeze(0)
            handle = hvd.allgather_async(fsp_matrix, key)
            handles.append(handle)

    fsp_loss = 0
    for idx, handle in enumerate(handles):
        rec_fsp = hvd.synchronize(handle)
        for i in range(0, hvd.size()):
            if i != self.task_id:
                fsp_loss += self.norm_loss(matrix_group[idx], rec_fsp[i])
    fsp_loss /= (hvd.size() - 1)
    self.log_dict['transfer_count'] += 1
    return fsp_loss
def delayedupdate(self, val):
    if self.handle is None:
        self.sum += 0
    else:
        self.sum += hvd.synchronize(self.handle)
    self.handle = hvd.allreduce_async_(val.detach().cpu(), name=self.name)
    self.n += 1
def forward(ctx, tensor, name):
    ctx.dim = tensor.shape[0]
    # we try to put all sync ops in the forward pass
    ctx.all_dims = hvd.allgather(
        torch.tensor([ctx.dim], device=tensor.device)).view(hvd.size())
    handle = hvd.allgather_async(tensor, name)
    return hvd.synchronize(handle)
def attention_transfer(self):
    def at(x):
        return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))

    handles = []
    att_group = []
    for key in self.activation:
        at_out = at(self.activation[key])
        att_group.append(at_out)
        at_numpy = at_out.data.unsqueeze(0)
        handle = hvd.allgather_async(at_numpy, key)
        handles.append(handle)

    # self.norm_loss
    att_loss = 0
    for idx, handle in enumerate(handles):
        rec_att = hvd.synchronize(handle)
        # att_loss += self.norm_loss(att_group[idx], rec_att.mean(0).cuda(self.device))
        for i in range(0, hvd.size()):
            if i != self.task_id:
                att_loss += self.norm_loss(att_group[idx], rec_att[i].cuda(self.device))
    att_loss /= (hvd.size() - 1)
    self.log_dict['transfer_count'] += 1
    return att_loss
def _broadcast_sparse_inv(self):
    handles = []
    rank = hvd.rank()

    for i, m in enumerate(self.modules):
        rank_a = self.m_dA_ranks[m]
        rank_g = self.m_dG_ranks[m]
        name = self.module_names[i]

        h = hvd.broadcast_async_(self.m_QA[m], rank_a, name=name + 'mQA')
        handles.append(h)
        h = hvd.broadcast_async_(self.m_QG[m], rank_g, name=name + 'mQG')
        handles.append(h)

    for handle in handles:
        hvd.synchronize(handle)
def allgather_sync(self, tensors, ranks):
    nworkers = hvd.size()
    rank = hvd.rank()
    start = 0
    sub_ranks = ranks[start:start+nworkers]
    sub_tensors = tensors[start:start+nworkers]
    while len(sub_ranks) > 0:
        #print('len(sub_ranks): ', len(sub_ranks))
        #print('len(sub_tensors): ', len(sub_tensors))
        try:
            idx = sub_ranks.index(rank)
        except ValueError:
            idx = -1
        if idx < 0:
            tensor = sub_tensors[0].new(0)
        else:
            tensor = sub_tensors[idx]
        handle = hvd.allgather_async(tensor.view(-1))
        sync_tensors = hvd.synchronize(handle)
        offset = 0
        for i, r in enumerate(sub_ranks):
            if idx < 0:
                continue
            original_t = sub_tensors[r]
            numel = original_t.numel()
            t = sync_tensors[offset:offset+numel]
            original_t.copy_(t.view(original_t.shape))
            offset += numel
        start += nworkers
        sub_ranks = ranks[start:start+nworkers]
        sub_tensors = tensors[start:start+nworkers]
def forward(self, handle):
    """
    Arguments:
        handle: Handle returned by an `AsyncAllReduce`, `AsyncAllGather`, or
            `AsyncBroadcast`, which will be used to retrieve a `torch.Tensor`.
    """
    return hvd.synchronize(handle)
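# A hedged usage sketch showing how a synchronize-style wrapper like the
# forward() above pairs with an async collective. `Synchronize` is a
# hypothetical module name introduced here for illustration only.
import torch
import torch.nn as nn
import horovod.torch as hvd

class Synchronize(nn.Module):
    def forward(self, handle):
        # blocks until the async op behind `handle` completes and returns its result
        return hvd.synchronize(handle)

hvd.init()
handle = hvd.allreduce_async(torch.ones(3), name='example')
result = Synchronize()(handle)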
def _allgather_factors(self):
    """Allgather the factors for all layers"""
    handles = []

    def _get_value_and_idx(sparse_tensor):
        tensor = sparse_tensor.data.view(-1)
        one_indexes = tensor != 0
        indexes = one_indexes.nonzero().data.squeeze().view(-1)
        values = tensor.data[indexes]
        return values, indexes.int()

    for i, m in enumerate(self.modules):
        module_name = self.module_names[i]

        A_values, A_indexes = _get_value_and_idx(self.m_A[m].data)
        A_value_name = module_name + '_A_value'
        A_idx_name = module_name + '_A_idx'
        h_value = allgather_async(A_values, A_value_name)
        h_idx = allgather_async(A_indexes, A_idx_name)

        G_values, G_indexes = _get_value_and_idx(self.m_G[m].data)
        G_value_name = module_name + '_G_value'
        G_idx_name = module_name + '_G_idx'
        h_value_G = allgather_async(G_values, G_value_name)
        h_idx_G = allgather_async(G_indexes, G_idx_name)

        handles.append((h_value, h_idx, h_value_G, h_idx_G))

    for i, handle in enumerate(handles):
        module_name = self.module_names[i]
        module = self.modules[i]
        m_A = self.m_A[module].view(-1)
        m_A.fill_(0.0)
        m_G = self.m_G[module].view(-1)
        m_G.fill_(0.0)

        h_value_A, h_idx_A, h_value_G, h_idx_G = handle
        A_values = hvd.synchronize(h_value_A)
        A_indexes = hvd.synchronize(h_idx_A).long()
        m_A.scatter_add_(0, A_indexes, A_values)
        m_A.div_(hvd.size())

        G_values = hvd.synchronize(h_value_G)
        G_indexes = hvd.synchronize(h_idx_G).long()
        m_G.scatter_add_(0, G_indexes, G_values)
        m_G.div_(hvd.size())
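# A hedged, standalone sketch of the sparse allgather pattern above: each rank
# sends the (values, indices) of its non-zero entries, and every rank rebuilds
# an averaged dense tensor with scatter_add_. Assumes hvd.init() has been
# called; the function name is illustrative, not from the original code.
import torch
import horovod.torch as hvd

def sparse_allgather_average(dense):
    flat = dense.view(-1)
    idx = flat.nonzero().view(-1)           # positions of non-zero entries
    vals = flat[idx]
    h_vals = hvd.allgather_async(vals)
    h_idx = hvd.allgather_async(idx.int())
    out = torch.zeros_like(flat)
    # gathered indices/values line up rank by rank, so one scatter_add_ suffices
    out.scatter_add_(0, hvd.synchronize(h_idx).long(), hvd.synchronize(h_vals))
    return out.div_(hvd.size()).view_as(dense)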
def allreduce_parameters(params):
    if isinstance(params, dict):
        params = sorted(params.items())
    elif isinstance(params, list):
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run asynchronous allreduces.
    handles = []
    for name, p in params:
        handle = hvd.allreduce_async_(p, average=True, name=name)
        handles.append(handle)

    # Wait for completion.
    for handle in handles:
        hvd.synchronize(handle)
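# A minimal usage sketch for the helper above, assuming hvd.init() has been
# called and allreduce_parameters is in scope. state_dict() provides the
# name -> tensor mapping that the helper sorts and averages in place.
import torch
import horovod.torch as hvd

hvd.init()
model = torch.nn.Linear(4, 2)
allreduce_parameters(model.state_dict())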
def test_horovod_allreduce_async_fused(self):
    """Test that the allreduce correctly sums 1D, 2D, 3D tensors
    with Tensor Fusion."""
    hvd.init()
    size = hvd.size()
    dtypes = [torch.IntTensor, torch.LongTensor,
              torch.FloatTensor, torch.DoubleTensor]
    if _fp16_supported:
        dtypes += [torch.HalfTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor,
                   torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
        if _fp16_supported:
            dtypes += [torch.cuda.HalfTensor]
    dims = [1, 2, 3]
    tests = []
    is_hvd_poll_false_once = False
    for dtype, dim in itertools.product(dtypes, dims):
        torch.manual_seed(1234)
        tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
        tensor = tensor.type(dtype)
        handle = hvd.allreduce_async(tensor, average=False)
        if not hvd.poll(handle):
            is_hvd_poll_false_once = True
        tensor, = self.convert_cpu_fp16_to_fp32(tensor)
        multiplied = tensor * size
        tests.append((dtype, multiplied, handle))

    # Make sure it's an asynchronous operation.
    assert is_hvd_poll_false_once, 'hvd.poll() always returns True, not an async op?'

    for dtype, multiplied, handle in tests:
        summed = hvd.synchronize(handle)
        summed, = self.convert_cpu_fp16_to_fp32(summed)
        max_difference = summed.sub(multiplied).max()

        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3 or dtype in [torch.IntTensor, torch.LongTensor,
                                  torch.cuda.IntTensor, torch.cuda.LongTensor]:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            break

        assert max_difference <= threshold, 'hvd.allreduce produces incorrect results'
def wait_receive(self, handles, ctx):
    tensors_decompressed = []
    for ranki in handles:
        tensors_compressed = [synchronize(h) for h in ranki]
        tensor_decompressed = self.compressor.decompress(tensors_compressed, ctx)
        tensors_decompressed.append(tensor_decompressed)
    tensor_aggregated = self.compressor.aggregate(tensors_decompressed)
    return (tensor_aggregated / self.world_size) if self.compressor.average else tensor_aggregated
def maybe_allreduce_grads(model):
    if hvd.size() > 1:
        tstart_reduce = time.time()
        named_parameters = list(
            sorted(model.named_parameters(), key=lambda a: a[0]))
        grad_handles = []
        for name, p in named_parameters:
            if p.requires_grad:
                if p.grad is None:
                    p.grad = torch.zeros_like(p)
                with torch.no_grad():
                    grad_handles.append(hvd.allreduce_async_(p.grad, name=name))
        for handle in grad_handles:
            hvd.synchronize(handle)
        tlogger.record_tabular("TimeElapsedAllReduce", time.time() - tstart_reduce)
        if time.time() - tstart_reduce > 5:
            import socket
            tlogger.info(
                "Allreduce took more than 5 seconds for node {} (rank {})".format(
                    socket.gethostname(), hvd.rank()))
def wait_receive(self, result, ctx):
    handles, tensor_sizes = result
    tensors_ag = []
    for handle, sizes in zip(handles, tensor_sizes):
        gathered = synchronize(handle)
        tensors_ag.append(gathered.split(sizes))

    list_tensor_decompressed = []
    for tensor_compressed in zip(*tensors_ag):
        tensor_decompressed = self.compressor.decompress(tensor_compressed, ctx)
        list_tensor_decompressed.append(tensor_decompressed)

    tensors_aggregated = self.compressor.aggregate(list_tensor_decompressed)
    return (tensors_aggregated / self.world_size) if self.compressor.average else tensors_aggregated
def weights_transfer(self):
    # transfer model weights
    weights = copy.deepcopy(self.network.state_dict())
    handles = []
    for name in weights:
        # TODO: need to consider bias
        if 'weight' in name:
            # print(self.task_id, 'send', name)
            handle = hvd.allgather_async(weights[name], name)
            handles.append(handle)

    hidx = 0
    for name, param in self.network.named_parameters():
        if 'weight' in name:
            # print(self.task_id, 'rec', name)
            rec_weights = hvd.synchronize(handles[hidx])
            hidx += 1
            # print(rec_weights.shape)
            n_num = param.shape[0]
            rec_weights = list(torch.split(rec_weights, n_num, 0))
            del rec_weights[self.task_id]
            # TODO weights cat in the first dim, 2*[64,3]--> [128,3]
            # logging.info(type(rec_weights), rec_weights.shape)

            # calculate IOM of each filter
            im_list = []
            for i in range(param.shape[0]):
                im_list.append(torch.sum(torch.abs(param[i])).data.cpu().numpy())
            im_list = np.array(im_list)
            # print('minimal weight sum is {} size {}'.format(im_list.min(), im_list.shape[0]))
            for i, im in enumerate(im_list):
                prob = 1 - stats.norm(0, 2).cdf(im)
                if np.random.rand() < prob:
                    random_sender = np.random.randint(0, len(rec_weights))
                    new_param = rec_weights[random_sender].clone()
                    # random pick
                    random_filter = np.random.randint(0, new_param.shape[0])
                    # TODO give larger weights more chance
                    weights[name][i] = new_param[random_filter]
                    self.log_dict['transfer_count'] += 1
            # self.network.state_dict()[name].copy_(param.clone())

    # TODO: maybe modify the optimizer
    self.network.load_state_dict(weights)
    hvd.allreduce(torch.zeros(1), name='Barrier')
def distributed_matmul_tn(left: Tensor, right: Tensor) -> Tensor:
    r"""
    Multiply two sequence tensors to obtain the result of :math:`A^{T} B`.

    Left and right inputs can be N-dimensional tensors, where the first one
    must be of size :math:`* \times \frac{T}{N} \times T` and the second one
    of size :math:`* \times \frac{T}{N} \times D`, where :math:`T` is the
    total length, :math:`N` is the total number of processes available and
    :math:`D` is the dimension of the sequence. The result of this function
    is a tensor of size :math:`* \times \frac{T}{N} \times D` that contains
    the result chunk for each process of the resulting operation.

    Inputs
    ------
    left: Tensor
        :math:`A` in :math:`A^T B`, must be of size
        :math:`* \times \frac{T}{N} \times T`
    right: Tensor
        :math:`B` in :math:`A^T B`, must be of size
        :math:`* \times \frac{T}{N} \times D`

    Returns
    -------
    result: Tensor
        For each process, this function computes the corresponding segment
        of the operation :math:`A^T B`, of size
        :math:`* \times \frac{T}{N} \times D`
    """
    cols = left.size(-1)
    world_size = get_world_size()
    rank = get_rank()
    split_size = cols // world_size
    splits = left.split(split_size, -1)
    rank_block = None
    synchronize()
    for r in range(world_size):
        rank_split = splits[r]
        rank_multiplication = torch.matmul(rank_split.transpose(-1, -2), right)
        handle = hvd.allreduce_async(rank_multiplication,
                                     name=f'matmul_tn_{r}',
                                     op=hvd.Sum)
        if r == rank:
            rank_block = hvd.synchronize(handle)
    return rank_block.contiguous()
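# A hedged, worked shape example for distributed_matmul_tn (illustrative sizes
# only; assumes T is divisible by the number of processes N):
#   T = 8, D = 16, N = 4 processes
#   left  per rank: (T/N, T) = (2, 8)      -> local slice of A
#   right per rank: (T/N, D) = (2, 16)     -> local slice of B
#   split r of left: (2, 8/4) = (2, 2); transposed: (2, 2)
#   matmul(split_r^T, right) -> (2, 16) partial block
# The Sum-allreduce over all ranks at iteration r yields rank r's
# (T/N, D) = (2, 16) block of A^T B.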
def wait_receive(self, handles, ctx):
    tensors_compressed = []
    for h in handles:
        tensor_compressed = synchronize(h)
        tensors_compressed.append(tensor_compressed.chunk(self.world_size))

    tensors_decompressed = []
    if len(tensors_compressed) == 1:
        for tensor in tensors_compressed[0]:
            tensors_decompressed.append(self.compressor.decompress([tensor], ctx))
    elif len(tensors_compressed) == 2:
        for tensor, meta in zip(tensors_compressed[0], tensors_compressed[1]):
            tensors_decompressed.append(self.compressor.decompress((tensor, meta), ctx))

    tensors_decompressed = self.memory.aggregate(tensors_decompressed)
    return tensors_decompressed
def test_horovod_allreduce_async_fused(self):
    """Test that the allreduce correctly sums 1D, 2D, 3D tensors
    with Tensor Fusion."""
    hvd.init()
    size = hvd.size()
    dtypes = [torch.IntTensor, torch.LongTensor,
              torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor,
                   torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
    dims = [1, 2, 3]
    tests = []
    is_hvd_poll_false_once = False
    for dtype, dim in itertools.product(dtypes, dims):
        torch.manual_seed(1234)
        tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
        tensor = tensor.type(dtype)
        handle = hvd.allreduce_async(tensor, average=False)
        if not hvd.poll(handle):
            is_hvd_poll_false_once = True
        multiplied = tensor * size
        tests.append((dtype, multiplied, handle))

    # Make sure it's an asynchronous operation.
    assert is_hvd_poll_false_once, 'hvd.poll() always returns True, not an async op?'

    for dtype, multiplied, handle in tests:
        summed = hvd.synchronize(handle)
        max_difference = summed.sub(multiplied).max()

        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3 or dtype in [torch.IntTensor, torch.LongTensor,
                                  torch.cuda.IntTensor, torch.cuda.LongTensor]:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            break

        assert max_difference <= threshold, 'hvd.allreduce produces incorrect results'
def _allgather_factors(self):
    """Allgather the factors for all layers"""
    handles = []

    def _get_value_and_idx(sparse_tensor):
        tensor = sparse_tensor.data.view(-1)
        one_indexes = tensor != 0.0
        indexes = one_indexes.nonzero().data.squeeze().view(-1)
        values = tensor.data[indexes]
        return values, indexes.int()

    for i, m in enumerate(self.modules):
        module_name = self.module_names[i]

        A_values, A_indexes = self.m_sparseA[m]  #_get_value_and_idx(self.m_A[m].data)
        if A_values.numel() == 0:
            continue
        A_value_name = module_name + '_A_value'
        A_idx_name = module_name + '_A_idx'
        #h_value = hvd.allgather_async(A_values, A_value_name)
        #h_idx = hvd.allgather_async(A_indexes, A_idx_name)
        h_value = hvd.allgather_async(A_values)
        h_idx = hvd.allgather_async(A_indexes)

        G_values, G_indexes = self.m_sparseG[m]  #_get_value_and_idx(self.m_G[m].data)
        G_value_name = module_name + '_G_value'
        G_idx_name = module_name + '_G_idx'
        #h_value_G = hvd.allgather_async(G_values, G_value_name)
        #h_idx_G = hvd.allgather_async(G_indexes, G_idx_name)
        if G_values is not None and G_values.numel() > 0:
            h_value_G = hvd.allgather_async(G_values)
            h_idx_G = hvd.allgather_async(G_indexes)
        handles.append((h_value, h_idx, h_value_G, h_idx_G))

    num_of_workers = hvd.size()

    def _decompress(values, indices, output):
        numel = indices.numel()
        real_num_values = numel // num_of_workers
        for i in range(num_of_workers):
            tmp_values = values.data[i * real_num_values:(i + 1) * real_num_values]
            tmp_indices = indices.data[i * real_num_values:(i + 1) * real_num_values]
            output[tmp_indices] += tmp_values

    for i, handle in enumerate(handles):
        module_name = self.module_names[i]
        module = self.modules[i]
        m_A = self.m_A[module].view(-1)
        m_A.fill_(0.0)
        m_G = self.m_G[module].view(-1)
        m_G.fill_(0.0)

        h_value_A, h_idx_A, h_value_G, h_idx_G = handle
        A_values = hvd.synchronize(h_value_A)
        A_indexes = hvd.synchronize(h_idx_A).long()
        _decompress(A_values, A_indexes, m_A)
        #print(A_indexes[0])
        #print(A_values[0])
        #m_A.scatter_add_(0, A_indexes, A_values)
        m_A.div_(hvd.size())

        G_values = hvd.synchronize(h_value_G)
        G_indexes = hvd.synchronize(h_idx_G).long()
        #print('G_I: ', G_indexes[0])
        #print('G_V: ', G_values[0])
        #m_G.scatter_add_(0, G_indexes, G_values)
        _decompress(G_values, G_indexes, m_G)
        m_G.div_(hvd.size())
def upsnet_train():

    if is_master:
        logger.info('training config:{}\n'.format(pprint.pformat(config)))
    gpus = [torch.device('cuda', int(_)) for _ in config.gpus.split(',')]
    num_replica = hvd.size() if config.train.use_horovod else len(gpus)
    num_gpus = 1 if config.train.use_horovod else len(gpus)

    # create models
    train_model = eval(config.symbol)().cuda()

    # create optimizer
    params_lr = train_model.get_params_lr()
    # we use custom optimizer and pass lr=1 to support different lr for different weights
    optimizer = SGD(params_lr, lr=1, momentum=config.train.momentum, weight_decay=config.train.wd)
    if config.train.use_horovod:
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=train_model.named_parameters())
    optimizer.zero_grad()

    # create data loader
    train_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.image_set.split('+'),
                                                 flip=config.train.flip, result_path=final_output_path)
    val_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.test_image_set.split('+'),
                                               flip=False, result_path=final_output_path, phase='val')
    if config.train.use_horovod:
        train_sampler = distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        val_sampler = distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size,
                                                   sampler=train_sampler, num_workers=num_gpus * 4,
                                                   drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size,
                                                 sampler=val_sampler, num_workers=num_gpus * 4,
                                                 drop_last=False, collate_fn=val_dataset.collate)
    else:
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size,
                                                   shuffle=config.train.shuffle,
                                                   num_workers=num_gpus * 4 if not config.debug_mode else num_gpus * 4,
                                                   drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size, shuffle=False,
                                                 num_workers=num_gpus * 4 if not config.debug_mode else num_gpus * 4,
                                                 drop_last=False, collate_fn=val_dataset.collate)

    # preparing
    curr_iter = config.train.begin_iteration
    batch_end_callback = [Speedometer(num_replica * config.train.batch_size, config.train.display_iter)]
    metrics = []
    metrics_name = []
    if config.network.has_rpn:
        metrics.extend([AvgMetric(name='rpn_cls_loss'), AvgMetric(name='rpn_bbox_loss')])
        metrics_name.extend(['rpn_cls_loss', 'rpn_bbox_loss'])
    if config.network.has_rcnn:
        metrics.extend([AvgMetric(name='rcnn_accuracy'), AvgMetric(name='cls_loss'), AvgMetric(name='bbox_loss')])
        metrics_name.extend(['rcnn_accuracy', 'cls_loss', 'bbox_loss'])
    if config.network.has_mask_head:
        metrics.extend([AvgMetric(name='mask_loss')])
        metrics_name.extend(['mask_loss'])
    if config.network.has_fcn_head:
        metrics.extend([AvgMetric(name='fcn_loss')])
        metrics_name.extend(['fcn_loss'])
        if config.train.fcn_with_roi_loss:
            metrics.extend([AvgMetric(name='fcn_roi_loss')])
            metrics_name.extend(['fcn_roi_loss'])
    if config.network.has_panoptic_head:
        metrics.extend([AvgMetric(name='panoptic_accuracy'), AvgMetric(name='panoptic_loss')])
        metrics_name.extend(['panoptic_accuracy', 'panoptic_loss'])

    if config.train.resume:
        train_model.load_state_dict(torch.load(os.path.join(final_output_path,
                                                            config.model_prefix + str(curr_iter) + '.pth')), resume=True)
        optimizer.load_state_dict(torch.load(os.path.join(final_output_path,
                                                          config.model_prefix + str(curr_iter) + '.state.pth')))
        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)
    else:
        if is_master:
            train_model.load_state_dict(torch.load(config.network.pretrained))
        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)

    if not config.train.use_horovod:
        train_model = DataParallel(train_model, device_ids=[int(_) for _ in config.gpus.split(',')]).to(gpus[0])

    if is_master:
        batch_end_callback[0](0, 0)

    train_model.eval()

    # start training
    while curr_iter < config.train.max_iteration:
        if config.train.use_horovod:
            train_sampler.set_epoch(curr_iter)
            if config.network.use_syncbn:
                train_model.train()
                if config.network.backbone_freeze_at > 0:
                    train_model.freeze_backbone(config.network.backbone_freeze_at)
                if config.network.backbone_fix_bn:
                    train_model.resnet_backbone.eval()

            for inner_iter, batch in enumerate(train_loader):
                data, label, _ = batch
                for k, v in data.items():
                    data[k] = v if not torch.is_tensor(v) else v.cuda()
                for k, v in label.items():
                    label[k] = v if not torch.is_tensor(v) else v.cuda()

                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                output = train_model(data, label)
                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean() * config.train.bbox_loss_weight
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)

                losses = []
                losses.append(allreduce_async(loss, name='train_total_loss'))
                for l in metrics_name:
                    losses.append(allreduce_async(output[l].mean(), name=l))

                loss = hvd.synchronize(losses[0]).item()
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = hvd.synchronize(losses[i + 1]).item()
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                    metric.update(_, _, loss)
                curr_iter += 1

                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        if 'momentum_buffer' in optimizer.state_dict()['state'][k]:
                            optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)
                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.state_dict(),
                                   os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
                        torch.save(optimizer.state_dict(),
                                   os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
        else:
            inner_iter = 0
            train_iterator = train_loader.__iter__()
            while inner_iter + num_gpus <= len(train_loader):
                batch = []
                for gpu_id in gpus:
                    data, label, _ = train_iterator.next()
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    batch.append((data, label))
                    inner_iter += 1

                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                if config.train.use_horovod:
                    output = train_model(data, label)
                else:
                    output = train_model(*batch)
                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean()
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)

                losses = []
                losses.append(loss.item())
                for l in metrics_name:
                    losses.append(output[l].mean().item())

                loss = losses[0]
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = losses[i + 1]
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                    metric.update(_, _, loss)
                curr_iter += 1

                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)
                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.module.state_dict(),
                                   os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
                        torch.save(optimizer.state_dict(),
                                   os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))

            while True:
                try:
                    train_iterator.next()
                except Exception:
                    break

        for metric in metrics:
            metric.reset()

        if config.train.eval_data:
            train_model.eval()
            if config.train.use_horovod:
                for inner_iter, batch in enumerate(val_loader):
                    data, label, _ = batch
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)
                    with torch.no_grad():
                        output = train_model(data, label)
                    for metric, l in zip(metrics, metrics_name):
                        loss = hvd.allreduce(output[l].mean()).item()
                        if is_master:
                            metric.update(_, _, loss)
            else:
                inner_iter = 0
                val_iterator = val_loader.__iter__()
                while inner_iter + len(gpus) <= len(val_loader):
                    batch = []
                    for gpu_id in gpus:
                        data, label, _ = val_iterator.next()
                        for k, v in data.items():
                            data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        for k, v in label.items():
                            label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        batch.append((data, label))
                        inner_iter += 1
                    with torch.no_grad():
                        if config.train.use_horovod:
                            output = train_model(data, label)
                        else:
                            output = train_model(*batch)
                    losses = []
                    for l in metrics_name:
                        losses.append(allreduce_async(output[l].mean(), name=l)
                                      if config.train.use_horovod else output[l].mean().item())
                    for metric, loss in zip(metrics, losses):
                        loss = hvd.synchronize(loss).item() if config.train.use_horovod else loss
                        if is_master:
                            metric.update(_, _, loss)
                while True:
                    try:
                        val_iterator.next()
                    except Exception:
                        break

            s = 'Batch [%d]\t Epoch[%d]\t' % (curr_iter, curr_iter // len(train_loader))
            for metric in metrics:
                m, v = metric.get()
                s += 'Val-%s=%f,\t' % (m, v)
                if is_master:
                    writer.add_scalar('val_' + m, v, curr_iter)
                metric.reset()
            if is_master:
                logger.info(s)

        if is_master and config.train.use_horovod:
            logger.info('taking snapshot ...')
            torch.save(train_model.state_dict(),
                       os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
            torch.save(optimizer.state_dict(),
                       os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
        elif not config.train.use_horovod:
            logger.info('taking snapshot ...')
            torch.save(train_model.module.state_dict(),
                       os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
            torch.save(optimizer.state_dict(),
                       os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
def step(self, closure=None, epoch=None):
    """Perform one K-FAC step

    Note:
    - this function should always be called before `optimizer.step()`
    - gradients must be averaged across ranks before calling `step()`

    Args:
      closure: for compatibility with the base optimizer class.
          `closure` is ignored by KFAC
      epoch (int, optional): epoch to use for determining when to end
          the `diag_warmup` period. `epoch` is not necessary if not using
          `diag_warmup`
    """

    # Update params, used for compatibility with `KFACParamScheduler`
    group = self.param_groups[0]
    self.lr = group['lr']
    self.damping = group['damping']
    self.fac_update_freq = group['fac_update_freq']
    self.kfac_update_freq = group['kfac_update_freq']

    updates = {}
    handles = []

    if epoch is None:
        if self.diag_warmup > 0:
            print("WARNING: diag_warmup > 0 but epoch was not passed to "
                  "KFAC.step(). Defaulting to no diag_warmup")
        diag_blocks = self.diag_blocks
    else:
        diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

    if hvd.size() > 1 and self.steps % self.fac_update_freq == 0:
        self.fw_merged_comm.synchronize()
        self.bw_merged_comm.synchronize()
        #for handle in self.fw_factor_handles:
        #    hvd.synchronize(handle)
        #self.fw_factor_handles.clear()
        #for handle in self.bw_factor_handles:
        #    hvd.synchronize(handle)
        #self.bw_factor_handles.clear()

    # if we are switching from no diag approx to approx, we need to clear
    # off-block-diagonal elements
    if not self.have_cleared_Q and \
            epoch == self.diag_warmup and \
            self.steps % self.kfac_update_freq == 0:
        self._clear_eigen()
        self.have_cleared_Q = True

    if self.steps % self.kfac_update_freq == 0:
        # reset rank iter so devices get the same layers
        # to compute to take advantage of caching
        self.rank_iter.reset()

        handles = []

        #eigen_ranks = self._generate_eigen_ranks(epoch)
        eigen_ranks = self._generate_eigen_ranks_uniform(epoch)
        #eigen_ranks = self._generate_eigen_ranks_naive(epoch)

        for module in self.modules:
            ranks_a, ranks_g = eigen_ranks[module]
            self.m_dA_ranks[module] = ranks_a[0]
            self.m_dG_ranks[module] = ranks_g[0]
            rank_a = ranks_a[0]
            rank_g = ranks_g[0]

            self._update_eigen_A(module, ranks_a)
            h1 = hvd.broadcast_async_(self.m_QA[module], rank_a)
            h2 = hvd.broadcast_async_(self.m_dA[module], rank_a)
            self._update_eigen_G(module, ranks_g)
            h3 = hvd.broadcast_async_(self.m_QG[module], rank_g)
            h4 = hvd.broadcast_async_(self.m_dG[module], rank_g)
            handles.append((h1, h2, h3, h4))

        if hvd.size() > 1:
            #for handle in handles:
            #    hvd.synchronize(handle)
            #self._allreduce_eigendecomp()
            #self._broadcast_eigendecomp()
            pass

    for i, module in enumerate(self.modules):
        if hvd.size() > 1 and len(handles) > 0:
            h1, h2, h3, h4 = handles[i]
            hvd.synchronize(h1)
            hvd.synchronize(h2)
            hvd.synchronize(h3)
            hvd.synchronize(h4)
        grad = self._get_grad(module)
        precon_grad = self._get_preconditioned_grad(module, grad)
        updates[module] = precon_grad

    #self._update_scale_grad(updates)

    self.steps += 1
def barrier():
    torch.cuda.synchronize()
    handle = hvd.broadcast_async_(sync_tensor, root_rank=0)
    hvd.synchronize(handle)