def test_collectives(self):
    dist.Backend.register_backend("dummy", PythonProcessGroupTest.create_dummy)

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '6789'
    dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)

    # test all_gather
    input_tensor = torch.ones(2, 2) * 7
    output_tensor_list = [torch.zeros(2, 2) for _ in range(self.world_size)]
    dist.all_gather(output_tensor_list, input_tensor)

    for tensor in output_tensor_list:
        self.assertEqual(tensor, input_tensor)

    # test all_reduce
    input_tensor = torch.ones(2, 2) * 7
    dist.all_reduce(input_tensor)
    self.assertEqual(input_tensor, torch.ones(2, 2) * 7 + 2)

    # test broadcast
    input_tensor = torch.zeros(2, 2)
    dist.broadcast(input_tensor, 0, async_op=True).wait()
    self.assertEqual(torch.ones(2, 2), input_tensor)

    # test reduce_scatter
    output_tensor = torch.zeros(2, 2)
    input_tensor_list = [torch.ones(2, 2) for _ in range(self.world_size)]
    dist.reduce_scatter(output_tensor, input_tensor_list)
    self.assertEqual(output_tensor, torch.zeros(2, 2) + 1)

    dist.destroy_process_group()
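The test above runs against a dummy backend, so the collective results are synthetic. For reference, a minimal, self-contained sketch of the real dist.reduce_scatter semantics: every rank contributes a list of world_size tensors, and rank k receives the elementwise sum of the k-th entries from all ranks. This sketch assumes the NCCL backend with one GPU per process; the port number is arbitrary.

# Minimal reduce_scatter demo (not part of the test above).
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _demo(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"  # arbitrary free port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Rank r contributes the value r for every destination slot.
    inputs = [torch.full((4,), float(rank), device="cuda") for _ in range(world_size)]
    output = torch.empty(4, device="cuda")
    dist.reduce_scatter(output, inputs)  # default op is ReduceOp.SUM

    # Every rank ends up with sum(range(world_size)) in each element.
    expected = float(sum(range(world_size)))
    assert torch.allclose(output, torch.full((4,), expected, device="cuda"))
    dist.destroy_process_group()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    if n_gpus >= 2:
        mp.spawn(_demo, args=(n_gpus,), nprocs=n_gpus)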
def forward(ctx, op, group, tensor, *input_tensor_list):
    ctx.group = group
    dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
    return tensor
def _run_partial_tensor_n_reshard(
    self, reshard_spec, input_size, world_size, reduce_op, dtype=torch.float
):
    results = []
    results_compare = []
    for _ in range(0, world_size):
        tensor = torch.rand(*input_size, dtype=dtype).cuda(self.rank)
        results.append(tensor)
        results_compare.append(tensor.clone().detach())
    pg = dist.distributed_c10d._get_default_group()
    partial_tensor = _PartialTensor(torch.cat(results), pg, reduce_op=reduce_op)
    local_sharded_result = partial_tensor.reshard(reshard_spec)
    local_shards = local_sharded_result.local_shards()
    if pg.size() > world_size:
        chunk_mode_res = (input_size[0] * world_size) % pg.size()
        padding = [0] * (len(input_size) * 2)
        padding[-1] = pg.size() - chunk_mode_res
        results_compare = list(
            torch.nn.functional.pad(
                torch.cat(results_compare),
                tuple(padding),
                "constant",
                0,
            ).chunk(pg.size()))
    local_result_compare = torch.empty_like(results_compare[0])
    dist.reduce_scatter(local_result_compare, results_compare, op=reduce_op)
    self.assertEqual(1, len(local_shards))
    self.assertEqual(local_shards[0].tensor, local_result_compare)
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)
        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input,
                           input_t,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
def forward(ctx, op, group, tensor, *input_tensor_list):
    ctx.group = group
    input_tensor_list = tuple(t.contiguous() for t in input_tensor_list)
    dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
    return tensor
def backward(ctx, *grads):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    grad_list = list(grads)
    grad_out = torch.zeros_like(grad_list[rank], requires_grad=True)
    dist.reduce_scatter(grad_out, grad_list, op=ReduceOp.SUM)
    # Gradient correction for DistCrossEntropy
    grad_out = grad_out * world_size
    return (grad_out, None)
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[int] = []
        for placement in weight._sharding_spec.placements:
            sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size,
                                                    placement.rank())
            input_idx = placement.rank() * split_size
            indices += range(input_idx, input_idx + sharded_dim_size)
        input_t = input_t.index_select(
            0, torch.tensor(indices, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input,
                           input_t,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
def forward_backward(self, label, features, optimizer):
    total_label, norm_weight = self.prepare(label, optimizer)
    total_features = torch.zeros(
        size=[self.batch_size * self.world_size, self.embedding_size],
        device=self.device)
    dist.all_gather(list(total_features.chunk(self.world_size, dim=0)),
                    features.data)
    total_features.requires_grad = True

    logits = self.forward(total_features, norm_weight)
    logits = self.margin_softmax(logits, total_label)

    with torch.no_grad():
        max_fc = torch.max(logits, dim=1, keepdim=True)[0]
        dist.all_reduce(max_fc, dist.ReduceOp.MAX)

        # calculate exp(logits) and all-reduce
        logits_exp = torch.exp(logits - max_fc)
        logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
        dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

        # calculate prob
        logits_exp.div_(logits_sum_exp)

        # get one-hot
        grad = logits_exp
        index = torch.where(total_label != -1)[0]
        one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]],
                              device=grad.device)
        one_hot.scatter_(1, total_label[index, None], 1)

        # calculate loss
        loss = torch.zeros(grad.size()[0], 1, device=grad.device)
        loss[index] = grad[index].gather(1, total_label[index, None])
        dist.all_reduce(loss, dist.ReduceOp.SUM)
        loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

        # calculate grad
        grad[index] -= one_hot
        grad.div_(self.batch_size * self.world_size)

    logits.backward(grad)
    if total_features.grad is not None:
        total_features.grad.detach_()

    x_grad: torch.Tensor = torch.zeros_like(features)
    # feature gradient reduce-scatter
    dist.reduce_scatter(
        x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
    # scale after the reduce-scatter so the world-size correction is applied
    # to the gathered gradient, not to the zero-initialized buffer
    x_grad.mul_(self.world_size)
    # backward backbone
    return x_grad, loss_v
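The manual gradient used in this family of forward_backward implementations is the classic cross-entropy gradient, softmax(logits) minus the one-hot target, divided by the global batch size. A small single-process check (no model parallelism, arbitrary sizes) showing that it matches what autograd produces for a standard cross-entropy loss:

# Verify (softmax - one_hot) / N against autograd's cross-entropy gradient.
import torch
import torch.nn.functional as F

n, c = 4, 10
logits = torch.randn(n, c, requires_grad=True)
labels = torch.randint(0, c, (n,))

with torch.no_grad():
    prob = F.softmax(logits, dim=1)
    one_hot = torch.zeros_like(prob).scatter_(1, labels[:, None], 1)
    manual_grad = (prob - one_hot) / n  # same formula as the code above

F.cross_entropy(logits, labels).backward()
assert torch.allclose(logits.grad, manual_grad, atol=1e-6)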
def _torch_reduce_scatter(output_tensors, input_tensors, partition_sizes, rank,
                          world_size):
    """Reduce-scatter each partition: every rank receives the sum of its slice."""
    for part_idx, part_size in enumerate(partition_sizes):
        output_t = output_tensors[part_idx]
        input_t = input_tensors[part_idx]
        input_list = []
        for i in range(world_size):
            _input = input_t.narrow(0, i * part_size, part_size)
            input_list.append(_input)
        dist.reduce_scatter(output_t, input_list)
    torch.cuda.synchronize()
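A usage sketch for the helper above, assuming an already initialized NCCL process group; the partition sizes and shapes here are illustrative assumptions, not taken from the original code. Each flat input holds world_size partitions of part_size elements laid out back to back, and each output receives the summed partition owned by this rank.

# Hypothetical shapes for _torch_reduce_scatter (process group assumed initialized).
world_size = dist.get_world_size()
rank = dist.get_rank()
part_size = 1024
partition_sizes = [part_size, part_size]

input_tensors = [torch.ones(part_size * world_size, device="cuda")
                 for _ in partition_sizes]
output_tensors = [torch.empty(part_size, device="cuda") for _ in partition_sizes]

_torch_reduce_scatter(output_tensors, input_tensors, partition_sizes, rank, world_size)
# every element of each output tensor is now world_size (sum of ones across ranks)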
def flush(self) -> None:
    if self.offset == 0:
        assert len(self.callbacks) == 0
        return
    # reduce-scatter bucket
    dist.reduce_scatter(self.output_shard[:self.offset],
                        list(self.data[:, :self.offset].unbind(0)),
                        group=self.group)
    # execute post-reduction callbacks
    for callback_fn in self.callbacks:
        callback_fn()
    # reuse input bucket but allocate a fresh output shard
    self.data[:, :self.offset].zero_()
    self.offset = 0
    self.callbacks.clear()
    self.output_shard = torch.zeros_like(self.data[0])
def reduce_scatter_fn(output_tensor, input_tensor, group=None, async_op=False):
    from torch import chunk
    from torch.distributed import get_world_size, reduce_scatter

    input_tensor_lst = list(chunk(input_tensor, get_world_size(group)))
    # forward async_op so callers that request a handle actually get one
    return reduce_scatter(output_tensor,
                          input_tensor_lst,
                          group=group,
                          async_op=async_op)
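On recent PyTorch releases the same chunk-and-scatter can usually be expressed with the flat-tensor collective instead of a Python list of chunks. A sketch, assuming dist.reduce_scatter_tensor is available (it is the renamed successor of _reduce_scatter_base); the hasattr guard falls back to the list-based call used above:

def reduce_scatter_flat_fn(output_tensor, input_tensor, group=None, async_op=False):
    # input_tensor is the concatenation of world_size equally sized shards;
    # output_tensor receives the shard that belongs to this rank.
    import torch.distributed as dist
    if hasattr(dist, "reduce_scatter_tensor"):
        return dist.reduce_scatter_tensor(output_tensor, input_tensor,
                                          group=group, async_op=async_op)
    # fall back to the list-based collective
    chunks = list(input_tensor.chunk(dist.get_world_size(group)))
    return dist.reduce_scatter(output_tensor, chunks, group=group, async_op=async_op)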
def _mp_fn(index):
    device = xm.xla_device()
    if xm.xla_device_hw(device) in ('TPU', 'GPU'):
        world_size = xm.xrt_world_size()
        rank = xm.get_ordinal()

        dist.init_process_group('xla', world_size=world_size, rank=rank)

        input_size = (32, 3)
        inputs = torch.ones(input_size).split(input_size[0] // world_size)
        output = torch.zeros_like(inputs[0])
        xinputs = [i.to(device) for i in inputs]
        xoutput = output.to(device)
        dist.reduce_scatter(xoutput, xinputs)

        expected = torch.ones_like(inputs[0]) * world_size
        assert torch.all(xoutput.cpu() == expected), f'{xoutput} != {expected}'
    else:
        print('Default device {} is not a TPU or GPU device'.format(device),
              file=sys.stderr)
def reduce_scatter(self, collectiveArgs, retFlag=False, pair=False):
    retObj = dist.reduce_scatter(
        output=collectiveArgs.opTensor,
        input_list=collectiveArgs.ipTensor,
        group=collectiveArgs.group,
        async_op=collectiveArgs.asyncOp,
    )  # synchronicity is maintained in runColl

    if collectiveArgs.asyncOp:
        collectiveArgs.waitObj.append(retObj)

    if retFlag:
        return retObj
def flush(self) -> None:
    """Flush content of the bucket."""
    if self.offset == 0:
        assert len(self.callbacks) == 0
        return
    # reduce-scatter bucket
    if hasattr(dist, "_reduce_scatter_base"):
        dist._reduce_scatter_base(  # type: ignore
            self.output_shard[:self.offset],
            self.data[:, :self.offset].contiguous(),
            group=self.group)
    else:
        dist.reduce_scatter(self.output_shard[:self.offset],
                            list(self.data[:, :self.offset].unbind(0)),
                            group=self.group)
    # execute post-reduction callbacks
    for callback_fn in self.callbacks:
        callback_fn()
    # reuse input bucket but allocate a fresh output shard
    self.data[:, :self.offset].zero_()
    self.offset = 0
    self.callbacks.clear()
    self.output_shard = torch.zeros_like(self.data[0])
def _run_partial_tensor_n_reshard(
    self, reshard_spec, input_size, world_size, reduce_op, dtype=torch.float
):
    results = []
    results_compare = []
    for _ in range(0, world_size):
        tensor = torch.rand(*input_size, dtype=dtype).cuda(self.rank)
        results.append(tensor)
        results_compare.append(tensor.clone().detach())
    pg = dist.distributed_c10d._get_default_group()
    partial_tensor = _PartialTensor(torch.cat(results), pg, reduce_op=reduce_op)
    local_sharded_result = partial_tensor.reshard(reshard_spec)
    local_shards = local_sharded_result.local_shards()
    local_result_compare = torch.empty_like(results_compare[0])
    dist.reduce_scatter(local_result_compare, results_compare, op=reduce_op)
    self.assertEqual(1, len(local_shards))
    self.assertEqual(local_shards[0].tensor, local_result_compare)
def main(local_rank):
    dist.init_process_group(backend='nccl', init_method='env://')
    cfg.local_rank = local_rank
    torch.cuda.set_device(local_rank)
    cfg.rank = dist.get_rank()
    cfg.world_size = dist.get_world_size()
    trainset = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    train_loader = DataLoaderX(local_rank=local_rank,
                               dataset=trainset,
                               batch_size=cfg.batch_size,
                               sampler=train_sampler,
                               num_workers=0,
                               pin_memory=True,
                               drop_last=False)
    backbone = backbones.iresnet100(False).to(local_rank)
    backbone.train()

    # Broadcast init parameters
    for ps in backbone.parameters():
        dist.broadcast(ps, 0)

    # DDP
    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[cfg.local_rank])
    backbone.train()

    # Memory classifier
    dist_sample_classifer = DistSampleClassifier(rank=dist.get_rank(),
                                                 local_rank=local_rank,
                                                 world_size=cfg.world_size)

    # Margin softmax
    margin_softmax = MarginSoftmax(s=64.0, m=0.4)

    # Optimizer for backbone and classifier
    optimizer = SGD([{
        'params': backbone.parameters()
    }, {
        'params': dist_sample_classifer.parameters()
    }],
                    lr=cfg.lr,
                    momentum=0.9,
                    weight_decay=cfg.weight_decay,
                    rescale=cfg.world_size)

    # Lr scheduler
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                  lr_lambda=cfg.lr_func)
    n_epochs = cfg.num_epoch
    start_epoch = 0

    if local_rank == 0:
        writer = SummaryWriter(log_dir='logs/shows')

    #
    total_step = int(
        len(trainset) / cfg.batch_size / dist.get_world_size() * cfg.num_epoch)
    if dist.get_rank() == 0:
        print("Total Step is: %d" % total_step)

    losses = AverageMeter()
    global_step = 0
    train_start = time.time()
    for epoch in range(start_epoch, n_epochs):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            total_label, norm_weight = dist_sample_classifer.prepare(
                label, optimizer)
            features = F.normalize(backbone(img))

            # Features all-gather
            total_features = torch.zeros(features.size()[0] * cfg.world_size,
                                         cfg.embedding_size,
                                         device=local_rank)
            dist.all_gather(list(total_features.chunk(cfg.world_size, dim=0)),
                            features.data)
            total_features.requires_grad = True

            # Calculate logits
            logits = dist_sample_classifer(total_features, norm_weight)
            logits = margin_softmax(logits, total_label)

            with torch.no_grad():
                max_fc = torch.max(logits, dim=1, keepdim=True)[0]
                dist.all_reduce(max_fc, dist.ReduceOp.MAX)

                # Calculate exp(logits) and all-reduce
                logits_exp = torch.exp(logits - max_fc)
                logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
                dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

                # Calculate prob
                logits_exp.div_(logits_sum_exp)

                # Get one-hot
                grad = logits_exp
                index = torch.where(total_label != -1)[0]
                one_hot = torch.zeros(index.size()[0],
                                      grad.size()[1],
                                      device=grad.device)
                one_hot.scatter_(1, total_label[index, None], 1)

                # Calculate loss
                loss = torch.zeros(grad.size()[0], 1, device=grad.device)
                loss[index] = grad[index].gather(1, total_label[index, None])
                dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

                # Calculate grad
                grad[index] -= one_hot
                grad.div_(features.size()[0])

            logits.backward(grad)
            if total_features.grad is not None:
                total_features.grad.detach_()
            x_grad = torch.zeros_like(features)

            # Feature gradient reduce-scatter
            dist.reduce_scatter(
                x_grad, list(total_features.grad.chunk(cfg.world_size, dim=0)))
            x_grad.mul_(cfg.world_size)

            # Backward backbone
            features.backward(x_grad)
            optimizer.step()

            # Update classifier
            dist_sample_classifer.update()
            optimizer.zero_grad()

            losses.update(loss_v, 1)
            if cfg.local_rank == 0 and step % 50 == 0:
                time_now = (time.time() - train_start) / 3600
                time_total = time_now / ((global_step + 1) / total_step)
                time_for_end = time_total - time_now
                writer.add_scalar('time_for_end', time_for_end, global_step)
                writer.add_scalar('loss', loss_v, global_step)
                print(
                    "Speed %d samples/sec Loss %.4f Epoch: %d Global Step: %d Required: %.1f hours"
                    % ((cfg.batch_size * global_step /
                        (time.time() - train_start) * cfg.world_size),
                       losses.avg, epoch, global_step, time_for_end))
                losses.reset()

            global_step += 1
        scheduler.step()
        if dist.get_rank() == 0:
            import os
            if not os.path.exists(cfg.output):
                os.makedirs(cfg.output)
            torch.save(backbone.module.state_dict(),
                       os.path.join(cfg.output, str(epoch) + 'backbone.pth'))
    dist.destroy_process_group()
def reduce_scatter_gradients(self, postscale_gradients,
                             gradient_predivide_factor, gradient_average):
    world_size = dist.get_world_size(group=self.dp_process_group)
    local_rank = dist.get_rank(group=self.dp_process_group)

    for i, group in enumerate(self.fp16_groups):
        partition_param_map = {}
        param_partition_map = {}
        my_params = set()

        # [rank] -> [comm] -> partition
        num_comm_intervals = self.num_comm_intervals_per_group[i]
        all_sub_partitions = []
        for rank in range(world_size):
            # gsp is list of partitions indexed by comm_idx
            # FIXME: currently hardcoding fp16, should infer dtype
            grad_sub_partitions, partition_params, param_offsets = self.get_flat_sub_partitions(
                comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
                comm_param_offsets=self.
                params_in_rank_sub_partitions_offsets[i][rank],
                sub_partition_size=self.sub_partition_sizes[i],
                dtype=torch.half,  # self.params_in_rank_sub_partitions[i][rank][0][0].dtype
                num_comm_intervals=self.num_comm_intervals_per_group[i],
                default_device='cuda',  # self.params_in_rank_sub_partitions[i][rank][0][0].device
                return_partition_params=True)
            all_sub_partitions.append(grad_sub_partitions)

            # create map from partition -> params in that partition
            for comm_idx, part in enumerate(grad_sub_partitions):
                partition_param_map[part] = (partition_params[comm_idx],
                                             param_offsets[comm_idx])

            for comm_idx, params in enumerate(partition_params):
                for pidx, p in enumerate(params):
                    # store the parameters we care about locally
                    if rank == local_rank:
                        my_params.add(p)
                    # map from param -> partitions
                    if p in param_partition_map:
                        param_partition_map[p].append(
                            grad_sub_partitions[comm_idx])
                    else:
                        param_partition_map[p] = [grad_sub_partitions[comm_idx]]

            assert len(grad_sub_partitions) == num_comm_intervals

        if not postscale_gradients:
            raise NotImplementedError("pre-scale_gradients is not implemented")

        all_comm_partitions = []
        for comm_idx in range(num_comm_intervals):
            single_comm_all_partitions = []
            for rank in range(world_size):
                single_comm_all_partitions.append(
                    all_sub_partitions[rank][comm_idx])
            dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
                                input_list=single_comm_all_partitions,
                                group=self.dp_process_group)

            if gradient_average:
                for partition in single_comm_all_partitions:
                    partition.mul_(gradient_predivide_factor / world_size)

            all_comm_partitions.append(single_comm_all_partitions)

        for p in my_params:
            partitions = param_partition_map[p]
            parts = []
            for part in partitions:
                params, offsets = partition_param_map[part]
                found = False
                for p_idx, _p in enumerate(params):
                    if p.__hash__() == _p.__hash__():
                        found = True
                        if offsets[p_idx][0] is not None:
                            my_part = part.narrow(0, offsets[p_idx][0],
                                                  offsets[p_idx][1])
                            parts.append(my_part)
                assert found
            if p is not None:
                updated_grad = _unflatten_dense_tensors(torch.cat(parts), [p])
                p.grad.copy_(updated_grad[0])
def main(local_rank):
    cfg.local_rank = local_rank
    # cfg.rank = dist.get_rank()
    # cfg.world_size = dist.get_world_size()

    backbone = backbones.iresnet50(False)
    weights = torch.load("pytorch/partial_fc_glint360k_r50/16backbone.pth",
                         map_location=torch.device('cpu'))
    backbone.load_state_dict(weights)
    backbone = backbone.float()
    backbone = backbone.eval()

    # embedding 512
    img1 = cv2.imread('boy_1.jpg')
    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
    img1 = image_preprocessing(img1)
    img1 = np.ones([112, 112, 3], dtype=np.float32)
    img1 = img1.transpose([2, 0, 1])
    img1 = np.expand_dims(img1, axis=0)
    img1 = torch.from_numpy(img1).float()
    img1 = torch.autograd.Variable(img1, requires_grad=False).to('cpu')

    img2 = cv2.imread('man_2.jpg')
    img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
    img2 = image_preprocessing(img2)
    img2 = img2.transpose([2, 0, 1])
    img2 = np.expand_dims(img2, axis=0)
    img2 = torch.from_numpy(img2).float()
    img2 = torch.autograd.Variable(img2, requires_grad=False).to('cpu')

    with torch.no_grad():
        v1 = backbone.forward(img1)
        v2 = backbone.forward(img2)

    v1 = np.asarray(v1)
    import pickle
    pickle.dump(v1, open("sample.pkl", "wb"))
    print(v1)
    result = cosine_dist(v1, v2)
    print(result)
    exit(0)

    # NOTE: everything below is unreachable after exit(0); it is the original
    # distributed training path kept from the training script.

    # Broadcast init parameters
    for ps in backbone.parameters():
        dist.broadcast(ps, 0)

    # DDP
    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[cfg.local_rank])
    backbone.train()

    # Memory classifier
    dist_sample_classifer = DistSampleClassifier(rank=dist.get_rank(),
                                                 local_rank=local_rank,
                                                 world_size=cfg.world_size)

    # Margin softmax
    margin_softmax = MarginSoftmax(s=64.0, m=0.4)

    # Optimizer for backbone and classifier
    optimizer = SGD([{
        'params': backbone.parameters()
    }, {
        'params': dist_sample_classifer.parameters()
    }],
                    lr=cfg.lr,
                    momentum=0.9,
                    weight_decay=cfg.weight_decay,
                    rescale=cfg.world_size)

    # Lr scheduler
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                  lr_lambda=cfg.lr_func)
    n_epochs = cfg.num_epoch
    start_epoch = 0

    if local_rank == 0:
        writer = SummaryWriter(log_dir='logs/shows')

    #
    total_step = int(
        len(trainset) / cfg.batch_size / dist.get_world_size() * cfg.num_epoch)
    if dist.get_rank() == 0:
        print("Total Step is: %d" % total_step)

    losses = AverageMeter()
    global_step = 0
    train_start = time.time()
    for epoch in range(start_epoch, n_epochs):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            total_label, norm_weight = dist_sample_classifer.prepare(
                label, optimizer)
            features = F.normalize(backbone(img))

            # Features all-gather
            total_features = torch.zeros(features.size()[0] * cfg.world_size,
                                         cfg.embedding_size,
                                         device=local_rank)
            dist.all_gather(list(total_features.chunk(cfg.world_size, dim=0)),
                            features.data)
            total_features.requires_grad = True

            # Calculate logits
            logits = dist_sample_classifer(total_features, norm_weight)
            logits = margin_softmax(logits, total_label)

            with torch.no_grad():
                max_fc = torch.max(logits, dim=1, keepdim=True)[0]
                dist.all_reduce(max_fc, dist.ReduceOp.MAX)

                # Calculate exp(logits) and all-reduce
                logits_exp = torch.exp(logits - max_fc)
                logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
                dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

                # Calculate prob
                logits_exp.div_(logits_sum_exp)

                # Get one-hot
                grad = logits_exp
                index = torch.where(total_label != -1)[0]
                one_hot = torch.zeros(index.size()[0],
                                      grad.size()[1],
                                      device=grad.device)
                one_hot.scatter_(1, total_label[index, None], 1)

                # Calculate loss
                loss = torch.zeros(grad.size()[0], 1, device=grad.device)
                loss[index] = grad[index].gather(1, total_label[index, None])
                dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

                # Calculate grad
                grad[index] -= one_hot
                grad.div_(features.size()[0])

            logits.backward(grad)
            if total_features.grad is not None:
                total_features.grad.detach_()
            x_grad = torch.zeros_like(features)

            # Feature gradient reduce-scatter
            dist.reduce_scatter(
                x_grad, list(total_features.grad.chunk(cfg.world_size, dim=0)))
            x_grad.mul_(cfg.world_size)

            # Backward backbone
            features.backward(x_grad)
            optimizer.step()

            # Update classifier
            dist_sample_classifer.update()
            optimizer.zero_grad()

            losses.update(loss_v, 1)
            if cfg.local_rank == 0 and step % 50 == 0:
                time_now = (time.time() - train_start) / 3600
                time_total = time_now / ((global_step + 1) / total_step)
                time_for_end = time_total - time_now
                writer.add_scalar('time_for_end', time_for_end, global_step)
                writer.add_scalar('loss', loss_v, global_step)
                print(
                    "Speed %d samples/sec Loss %.4f Epoch: %d Global Step: %d Required: %.1f hours"
                    % ((cfg.batch_size * global_step /
                        (time.time() - train_start) * cfg.world_size),
                       losses.avg, epoch, global_step, time_for_end))
                losses.reset()

            global_step += 1
        scheduler.step()
        if dist.get_rank() == 0:
            import os
            if not os.path.exists(cfg.output):
                os.makedirs(cfg.output)
            torch.save(backbone.module.state_dict(),
                       os.path.join(cfg.output, str(epoch) + 'backbone.pth'))
    dist.destroy_process_group()
def reduce_scatter_async(
    self,
    input_list: List[Tensor],
    group: ProcessGroup,
    callback_fn: Optional[Callable] = None,
) -> None:
    """
    Reduce-scatter a list of tensors asynchronously, so smaller reductions
    can be bucketed together. The given callback (``callback_fn``) will be
    called with the reduced result at some later time. Call ``flush()`` to
    force all queued ops and callbacks to be executed.

    Note that large inputs will be reduced immediately, and this function
    may also flush the relevant bucket to make room for ``input_list``.

    Args:
        input_list (List[Tensor]): list of tensors to reduce-scatter. List
            should contain ``group.size()`` tensors and each tensor should
            have identical shape, dtype and device.
        group (ProcessGroup): process group for reduction
        callback_fn (Callable, Optional): callback function to call after
            the reduction executes. Function will be called with a single
            argument corresponding to the reduced result.
    """
    world_size = group.size()

    assert (
        len(input_list) == world_size
    ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

    first_input = input_list[0]
    first_input_size = first_input.numel()

    bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
    if first_input_size > bucket_shard_size:
        # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
        # input is too big to fit in the bucket, reduce-scatter directly
        output = torch.zeros_like(input_list[0])
        if hasattr(dist, "_reduce_scatter_base"):
            input_flattened = torch.cat(input_list)
            dist._reduce_scatter_base(output, input_flattened, group=group)  # type: ignore
        else:
            # fallback
            dist.reduce_scatter(output, input_list, group=group)
        if callback_fn is not None:
            callback_fn(output)
        return

    bucket = self._get_bucket(first_input, group)
    if first_input_size > bucket.data.size(1) - bucket.offset:
        # not enough space remaining in bucket, flush it now
        bucket.flush()

    # copy data from input_list into bucket
    stacked_input = torch.stack(input_list).view(world_size, first_input_size)
    offset = bucket.offset
    bucket.data[:, offset:offset + first_input_size].copy_(stacked_input)
    bucket.offset += first_input_size

    # callback will be given the reduced result
    if callback_fn is not None:
        result_view = bucket.output_shard[offset:offset + first_input_size].view_as(first_input)
        bucket.callbacks.append(functools.partial(callback_fn, result_view))
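A sketch of how a bucketer exposing the method above is typically driven. The class name, the bucket_cap_mb constructor argument, and the import path are taken as assumptions from fairscale and may differ between versions; the list of input chunks is made up for illustration.

# Hypothetical driver for a fairscale-style ReduceScatterBucketer: queue several
# small reduce-scatters, let them share one bucket, then flush at the end.
import torch
import torch.distributed as dist
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer  # path may vary

group = dist.distributed_c10d._get_default_group()
bucketer = ReduceScatterBucketer(bucket_cap_mb=25)  # assumed constructor argument

reduced_shards = []
my_chunk_lists = [  # made-up inputs: 8 calls, each with group.size() equal chunks
    [torch.ones(64, device="cuda") for _ in range(group.size())]
    for _ in range(8)
]
for chunks in my_chunk_lists:
    # each call may be coalesced with others into a single reduce-scatter
    bucketer.reduce_scatter_async(chunks, group=group,
                                  callback_fn=reduced_shards.append)

bucketer.flush()  # force any queued reductions (and their callbacks) to run
# reduced_shards now holds this rank's summed shard for each queued call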
def main(local_rank, world_size, init_method='tcp://127.0.0.1:23499'):
    dist.init_process_group(backend='nccl',
                            init_method=init_method,
                            rank=local_rank,
                            world_size=world_size)
    cfg.local_rank = local_rank
    torch.cuda.set_device(local_rank)
    cfg.rank = dist.get_rank()
    cfg.world_size = world_size
    print(cfg.rank, dist.get_world_size())
    trainset = MXFaceDataset(root_dir='/root/face_datasets/webface/',
                             local_rank=local_rank)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    trainloader = DataLoaderX(local_rank=local_rank,
                              dataset=trainset,
                              batch_size=cfg.batch_size,
                              sampler=train_sampler,
                              num_workers=0,
                              pin_memory=True,
                              drop_last=False)
    backbone = iresnet50(False).to(cfg.local_rank)
    backbone.train()
    # backbone = nn.SyncBatchNorm.convert_sync_batchnorm(backbone)

    for ps in backbone.parameters():
        dist.broadcast(ps, 0)

    backbone = torch.nn.parallel.DistributedDataParallel(
        backbone, broadcast_buffers=False, device_ids=[dist.get_rank()])
    backbone.train()

    sub_start, sub_classnum = get_sub_class(cfg.rank, dist.get_world_size())
    print(sub_start, sub_classnum)
    classifier_head = classifier(cfg.embedding_size,
                                 sub_classnum,
                                 sample_rate=0.4)
    cosface = CosFace(s=64.0, m=0.4)
    optimizer = SGD([{
        'params': backbone.parameters()
    }, {
        'params': classifier_head.parameters()
    }],
                    0.1,
                    momentum=0.9,
                    weight_decay=cfg.weight_decay,
                    rescale=cfg.world_size)
    warm_up_with_multistep_lr = lambda epoch: (
        (epoch + 1) / (4 + 1))**2 if epoch < -1 else 0.1**len(
            [m for m in [20, 29] if m - 1 <= epoch])
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=warm_up_with_multistep_lr)

    n_epochs = 33
    start_epoch = 0

    if cfg.local_rank == 0:
        writer = SummaryWriter(log_dir='logs/shows')

    global_step = 0
    loss_fun = nn.CrossEntropyLoss()
    for epoch in range(start_epoch, n_epochs):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(trainloader):
            start = time.time()
            lable_gather, norm_weight = classifier_head.prepare(
                label, optimizer)
            x = F.normalize(backbone(img))
            x_gather = torch.zeros(x.size()[0] * cfg.world_size,
                                   cfg.embedding_size,
                                   device=cfg.local_rank)
            dist.all_gather(list(x_gather.chunk(cfg.world_size, dim=0)),
                            x.data)
            x_gather.requires_grad = True

            logits = classifier_head(x_gather, norm_weight)
            logits = cosface(logits, lable_gather)

            with torch.no_grad():
                max_v = torch.max(logits, dim=1, keepdim=True)[0]
                dist.all_reduce(max_v, dist.ReduceOp.MAX)
                exp = torch.exp(logits - max_v)
                sum_exp = exp.sum(dim=1, keepdims=True)
                dist.all_reduce(sum_exp, dist.ReduceOp.SUM)
                exp.div_(sum_exp.clamp_min(1e-20))
                grad = exp
                index = torch.where(lable_gather != -1)[0]
                one_hot = torch.zeros(index.size()[0],
                                      grad.size()[1],
                                      device=grad.device)
                one_hot.scatter_(1, lable_gather[index, None], 1)

                loss = torch.zeros(grad.size()[0], 1, device=grad.device)
                loss[index] = grad[index].gather(1, lable_gather[index, None])
                dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss_v = loss.clamp_min_(1e-20).log_().mean() * (-1)

                grad[index] -= one_hot
                grad.div_(grad.size()[0])

            logits.backward(grad)
            if x_gather.grad is not None:
                x_gather.grad.detach_()
            x_grad = torch.zeros_like(x)
            dist.reduce_scatter(
                x_grad, list(x_gather.grad.chunk(cfg.world_size, dim=0)))
            x.backward(x_grad)
            optimizer.step()
            classifier_head.update()
            optimizer.zero_grad()
            if cfg.rank == 0:
                print(x_gather.grad.max(), x_gather.grad.min())
                print('loss_v', loss_v.item(), global_step)
                writer.add_scalar('loss', loss_v, global_step)
                print('lr', optimizer.state_dict()['param_groups'][0]['lr'],
                      global_step)
                print(cfg.batch_size / (time.time() - start))
            global_step += 1
        scheduler.step()
        if cfg.rank == 0:
            torch.save(backbone.module.state_dict(),
                       "models/" + str(epoch) + 'backbone.pth')
    dist.destroy_process_group()
def forward_backward(self, label, features, optimizer, feature_w):
    self._iters += 1
    total_label, norm_weight, index_positive = self.prepare(label, optimizer)

    total_features = torch.zeros(
        size=[self.batch_size * self.world_size, self.embedding_size],
        device=self.device)
    dist.all_gather(list(total_features.chunk(self.world_size, dim=0)),
                    features.data)
    total_features.requires_grad = True

    if feature_w is not None:
        total_feature_w = torch.zeros(
            size=[self.batch_size * self.world_size, self.embedding_size],
            device=self.device)
        dist.all_gather(list(total_feature_w.chunk(self.world_size, dim=0)),
                        feature_w.data)

    if self.vpl_mode >= 0:
        self.prepare_queue_lambda(total_label, self._iters)
        _lambda = self.queue_lambda.view(self.num_local, 1)
        injected_weight = norm_weight * (1.0 - _lambda) + self.queue * _lambda
        injected_norm_weight = normalize(injected_weight)
        logits = self.forward(total_features, injected_norm_weight)
    else:
        logits = self.forward(total_features, norm_weight)

    logits = self.margin_softmax(logits, total_label)

    with torch.no_grad():
        max_fc = torch.max(logits, dim=1, keepdim=True)[0]
        dist.all_reduce(max_fc, dist.ReduceOp.MAX)

        # calculate exp(logits) and all-reduce
        logits_exp = torch.exp(logits - max_fc)
        logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
        dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

        # calculate prob
        logits_exp.div_(logits_sum_exp)

        # get one-hot
        grad = logits_exp
        index = torch.where(total_label != -1)[0]
        one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]],
                              device=grad.device)
        one_hot.scatter_(1, total_label[index, None], 1)

        # calculate loss
        loss = torch.zeros(grad.size()[0], 1, device=grad.device)
        loss[index] = grad[index].gather(1, total_label[index, None])
        dist.all_reduce(loss, dist.ReduceOp.SUM)
        loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

        # calculate grad
        grad[index] -= one_hot
        grad.div_(self.batch_size * self.world_size)

    logits.backward(grad)
    if total_features.grad is not None:
        total_features.grad.detach_()
    x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
    # feature gradient reduce-scatter
    dist.reduce_scatter(
        x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
    x_grad = x_grad * self.world_size

    # vpl set queue
    if self.vpl_mode >= 0:
        if feature_w is None:
            self.set_queue(total_features.detach(), total_label,
                           index_positive, self._iters)
        else:
            self.set_queue(total_feature_w, total_label, index_positive,
                           self._iters)

    # backward backbone
    return x_grad, loss_v
def forward_backward(self, label, features, optimizer, x):
    """
    Partial fc forward and backward with model parallel

    label: tensor
        Label tensor on each rank(GPU)
    features: tensor
        Features tensor on each rank(GPU)
    optimizer: optimizer
        Optimizer for partial fc

    Returns:
    --------
    x_grad: tensor
        The gradient of features.
    loss_v: tensor
        Loss value for cross entropy.
    """
    total_label, norm_weight = self.prepare(label, optimizer)
    total_features = torch.zeros(
        size=[self.batch_size * self.world_size, self.embedding_size],
        device=self.device)
    dist.all_gather(list(total_features.chunk(self.world_size, dim=0)),
                    features.data)
    total_features.requires_grad = True

    logits = self.forward(total_features, norm_weight)
    # print(logits.size())
    logits, g = self.margin_softmax(logits, total_label, x)

    with torch.no_grad():
        max_fc = torch.max(logits, dim=1, keepdim=True)[0]
        dist.all_reduce(max_fc, dist.ReduceOp.MAX)

        # calculate exp(logits) and all-reduce
        logits_exp = torch.exp(logits - max_fc)
        logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
        dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

        # calculate prob
        logits_exp.div_(logits_sum_exp)

        # get one-hot
        grad = logits_exp
        index = torch.where(total_label != -1)[0]
        one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]],
                              device=grad.device)
        one_hot.scatter_(1, total_label[index, None], 1)

        # calculate loss
        loss = torch.zeros(grad.size()[0], 1, device=grad.device)
        loss[index] = grad[index].gather(1, total_label[index, None])
        dist.all_reduce(loss, dist.ReduceOp.SUM)
        loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1) + g

        # calculate grad
        grad[index] -= one_hot
        grad.div_(self.batch_size * self.world_size)

    logits.backward(grad)
    if total_features.grad is not None:
        total_features.grad.detach_()
    x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
    # feature gradient reduce-scatter
    dist.reduce_scatter(
        x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
    x_grad = x_grad * self.world_size
    # backward backbone
    return x_grad, loss_v
def reduce_scatter_gradients(self, postscale_gradients,
                             gradient_predivide_factor, gradient_average):
    world_size = dist.get_world_size(group=self.dp_process_group)
    local_rank = dist.get_rank(group=self.dp_process_group)

    for i, group in enumerate(self.fp16_groups):
        partition_param_map = {}
        param_partition_map = {}
        my_params = set()

        # [rank] -> [comm] -> partition
        num_comm_intervals = self.num_comm_intervals_per_group[i]
        all_sub_partitions = []
        for rank in range(world_size):
            # gsp is list of partitions indexed by comm_idx
            # FIXME: currently hardcoding fp16, should infer dtype
            grad_sub_partitions, partition_params, param_offsets = self.get_flat_sub_partitions(
                comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
                comm_param_offsets=self.
                params_in_rank_sub_partitions_offsets[i][rank],
                sub_partition_size=self.sub_partition_sizes[i],
                dtype=torch.half,  # self.params_in_rank_sub_partitions[i][rank][0][0].dtype
                num_comm_intervals=self.num_comm_intervals_per_group[i],
                default_device='cuda',  # self.params_in_rank_sub_partitions[i][rank][0][0].device
                return_partition_params=True)
            all_sub_partitions.append(grad_sub_partitions)

            # create map from partition -> params in that partition
            for comm_idx, part in enumerate(grad_sub_partitions):
                partition_param_map[part] = (partition_params[comm_idx],
                                             param_offsets[comm_idx])

            for comm_idx, params in enumerate(partition_params):
                for pidx, p in enumerate(params):
                    # store the parameters we care about locally
                    if rank == local_rank:
                        my_params.add(p)
                    # map from param -> partitions
                    if p in param_partition_map:
                        param_partition_map[p].append(
                            grad_sub_partitions[comm_idx])
                    else:
                        param_partition_map[p] = [grad_sub_partitions[comm_idx]]

            assert len(grad_sub_partitions) == num_comm_intervals

        if not postscale_gradients:
            raise NotImplementedError("pre-scale_gradients is not implemented")

        all_comm_partitions = []
        for comm_idx in range(num_comm_intervals):
            single_comm_all_partitions = []
            for rank in range(world_size):
                single_comm_all_partitions.append(
                    all_sub_partitions[rank][comm_idx])
            dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
                                input_list=single_comm_all_partitions,
                                group=self.dp_process_group)

            if gradient_average:
                for partition in single_comm_all_partitions:
                    partition.mul_(gradient_predivide_factor / world_size)

            all_comm_partitions.append(single_comm_all_partitions)

        # stitch together all rank sub partitions for each comm idx
        flat_comm_grads = []
        for comm_idx, rank_partitions in enumerate(all_comm_partitions):
            flat_comm_grads.append(torch.cat(rank_partitions))

        flat_all_grads = torch.cat(flat_comm_grads)

        # copy back reduced gradients but only those needed for this local rank
        for param, updated_grad in zip(
                self.fp16_groups[i],
                _unflatten_dense_tensors(flat_all_grads, self.fp16_groups[i])):
            if param in my_params:
                param.grad.copy_(updated_grad)
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: # of cuda process.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        final result of linear operation.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)
        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input,
                           input_t,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
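Why a reduce_scatter (a sum) is the right aggregation here: when the weight's input dimension is split across ranks, each rank can only compute a partial matmul, and the full result is the sum of those partials. A local illustration with no process group and arbitrary sizes; in the sharded linear above, dist.reduce_scatter performs this sum while simultaneously leaving each rank with only its own portion of the output.

# Local sketch (single process): summing per-shard partial matmuls reproduces x @ W^T.
import torch

world_size = 4
batch, in_features, out_features = 8, 16, 5
x = torch.randn(batch, in_features)
weight = torch.randn(out_features, in_features)

shard = in_features // world_size
partials = []
for r in range(world_size):
    x_r = x[:, r * shard:(r + 1) * shard]       # input columns handled by "rank" r
    w_r = weight[:, r * shard:(r + 1) * shard]  # matching weight shard on "rank" r
    partials.append(x_r.matmul(w_r.t()))        # partial result computed locally

assert torch.allclose(sum(partials), x.matmul(weight.t()), atol=1e-5)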
import os

import torch
import torch.distributed as dist
from torch.multiprocessing import Process

os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
backend = 'nccl'
dist.init_process_group(backend)
rank = dist.get_rank()
world_size = dist.get_world_size()

# simple dist
print("pid is {}, rank is {}".format(os.getpid(), rank))

torch.cuda.set_device(rank)
var = torch.tensor(rank + 10).to("cuda:{}".format(rank))
# one input tensor per rank; the list length must match the group size
var_list = [torch.tensor(i).to("cuda:{}".format(rank)) for i in range(world_size)]
dist.reduce_scatter(var, var_list, op=dist.ReduceOp.SUM, async_op=False)
print('Pid is ', os.getpid(), ', Rank ', rank, ', tensor data is: ', var)
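Because init_process_group(backend) is called without an explicit rank or world size, the script above relies on the launcher to set RANK and WORLD_SIZE in the environment; it is typically started with something like `torchrun --nproc_per_node=4 script.py` (or the older `python -m torch.distributed.launch`). With NCCL each process should own exactly one GPU, which is why torch.cuda.set_device(rank) is called before the collective.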
def forward_backward(self, features, targets, optimizer):
    """
    Partial FC forward, which will sample positive weights and part of
    negative weights, then compute logits and get the grad of features.
    """
    total_targets = self.prepare(targets, optimizer)

    if self.world_size > 1:
        total_features = concat_all_gather(features)
    else:
        total_features = features.detach()

    total_features.requires_grad_(True)

    logits = self.forward(total_features)
    logits = self.cls_layer(logits, total_targets)
    # from ipdb import set_trace; set_trace()

    with torch.no_grad():
        max_fc = torch.max(logits, dim=1, keepdim=True)[0]
        if self.world_size > 1:
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)

        # calculate exp(logits) and all-reduce
        logits_exp = torch.exp(logits - max_fc)
        logits_sum_exp = logits_exp.sum(dim=1, keepdim=True)
        if self.world_size > 1:
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

        # calculate prob
        logits_exp.div_(logits_sum_exp)

        # get one-hot
        grad = logits_exp
        index = torch.where(total_targets != -1)[0]
        one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]],
                              device=grad.device)
        one_hot.scatter_(1, total_targets[index, None], 1)

        # calculate loss
        loss = torch.zeros(grad.size()[0], 1, device=grad.device)
        loss[index] = grad[index].gather(1, total_targets[index, None])
        if self.world_size > 1:
            dist.all_reduce(loss, dist.ReduceOp.SUM)
        loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

        # calculate grad
        grad[index] -= one_hot
        grad.div_(logits.size(0))

    logits.backward(grad)
    if total_features.grad is not None:
        total_features.grad.detach_()
    x_grad: torch.Tensor = torch.zeros_like(features)
    # feature gradient reduce-scatter
    if self.world_size > 1:
        dist.reduce_scatter(
            x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
    else:
        x_grad = total_features.grad
    x_grad = x_grad * self.world_size
    # backward backbone
    return x_grad, loss_v
def _handle_row_wise_sharding(
    input,
    world_size,
    weight,
    local_shard,
    offsets,
    per_sample_weights,
    mode,
    max_norm,
    norm_type,
    padding_idx,
    rank,
    pg,
):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for embeddingBag. (Detailed explanations of the logic can be found in
    the comment for sharded_embedding_bag.)

    Args:
        input: list of ID used for lookup and aggregation.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: row-wise sharded local weight used for lookup.
        offsets: list of start positions of each bag for 1D input.
        per_sample_weights: weights for weighted sum mode.
        mode: aggregation method of each bag.
        max_norm: If given, each embedding vector with norm larger than
            max_norm is renormalized to have norm max_norm.
            Note: this will modify weight in-place.
        norm_type: The p in the p-norm to compute for the max_norm option.
        padding_idx: If specified, the entries at padding_idx do not contribute
            to the gradient; therefore, the embedding vector at padding_idx is
            not updated during training, i.e. it remains as a fixed “pad”.
            Note that the embedding vector at padding_idx is excluded from the
            reduction.
        rank: # of cuda process.
        pg: process group.

    Returns:
        gathered_output: final result of lookup and aggregation.
    """
    # We sort each interval defined by offset. If 2D, each interval is a row.
    input_size = input.size()
    (
        input_split_sorted_list,
        input_split_sorted_indices,
        split_sizes_1d,
        split_sizes_1d_with_padding,
    ) = _input_split_sort(input, offsets, padding_idx)

    # Within each interval of the sorted list, we first need to distribute
    # each ID to different bucket(rank) and also ensure the rearrangement
    # has been done in case the placement idx not equal to rank.
    # We then perform some simple stats on each interval for the next step
    # If user specifies per_sample_weights we need to rearrange them
    # to be sync with IDs and then distribute them to each rank
    (
        input_combined,
        input_combined_split_sizes,
        offsets_rearrange_list,
        offsets_rearrange_sizes,
        per_sample_weights,
        sharded_dim_size_max,
        padding_idx,
    ) = _sorted_input_distribute_prepare(
        input_split_sorted_list,
        input_split_sorted_indices,
        world_size,
        input,
        weight,
        per_sample_weights,
        rank,
        padding_idx,
    )

    # Send ID/offsets/per_sample_weights to different bucket(rank).
    (
        gathered_input,
        output_offsets_tensor_list,
        output_split_sizes,
        gathered_per_sample_weights,
    ) = _distribute_input(
        input_combined,
        input_combined_split_sizes,
        offsets_rearrange_list,
        offsets_rearrange_sizes,
        sharded_dim_size_max,
        world_size,
        input,
        per_sample_weights,
        pg,
    )

    # Perform the embedding bag look-up and aggregation
    results = []
    for i, inp in enumerate(gathered_input):
        per_sample_weights = (
            gathered_per_sample_weights[i]
            if gathered_per_sample_weights is not None
            else None
        )
        # If input is None, passing in max_norm causes
        # errors in CUDA.
        if max_norm is not None and inp.size(0) == 0:
            max_norm = None

        # Perform local embedding look up and aggregation.
        result = torch.nn.functional.embedding_bag(
            inp,
            local_shard,
            offsets=output_offsets_tensor_list[i],
            mode=mode if mode != "mean" else "sum",
            per_sample_weights=per_sample_weights,
            max_norm=max_norm,
            norm_type=norm_type,
            padding_idx=padding_idx,
        )
        if mode != "max":
            results.append(result)
        # For the max case, if there is no look-up from some ranks
        # the result will be all zeros for those rows. In that case we need
        # to set the row to neg inf; otherwise, in the final
        # aggregation negative values will be rounded up to zero.
        elif inp.size(0) == 0:
            result[:] = -float("Inf")
            results.append(result)
        else:
            for idx, current_offset in enumerate(output_offsets_tensor_list[i]):
                next_offset = current_offset
                if idx == len(output_offsets_tensor_list[i]) - 1:
                    next_offset = output_split_sizes[i]
                else:
                    next_offset = output_offsets_tensor_list[i][idx + 1]
                # When there is no interval in the current rank or all IDs
                # are equal to padding_idx, we then need to ensure they
                # don't contribute to the final result.
                if (current_offset == next_offset) or (
                    padding_idx is not None
                    and not torch.any(
                        torch.ne(inp[current_offset:next_offset], padding_idx)
                    )
                ):
                    result[idx] = -float("Inf")
            results.append(result)

    # Gather all the aggregated results appropriately by using reduce_scatter.
    row_size = input.size(0) if len(input_size) > 1 else len(split_sizes_1d)
    gathered_output = torch.empty(row_size, weight.size(1), device=input.device)
    op = ReduceOp.SUM if mode != "max" else ReduceOp.MAX
    dist.reduce_scatter(gathered_output, results, op=op, group=pg)

    # For Mean, we cannot do the division until the very end because the sum of
    # means is not equal to the mean of the sum. (The divisor is different.)
    if mode == "mean":
        split_sizes_1d_tensor = torch.tensor(
            split_sizes_1d_with_padding, dtype=torch.float, device=input.device
        )
        # Make sure divisor is not zero.
        split_sizes_1d_tensor[split_sizes_1d_tensor == 0.0] = 1.0
        return (
            torch.div(gathered_output.t().contiguous(), split_sizes_1d_tensor)
            .t()
            .contiguous()
        )

    # Return the appropriate local result.
    return gathered_output
def _handle_row_wise_sharding(input, world_size, weight, local_shard, offsets,
                              per_sample_weights, mode, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for embeddingBag. (Detailed explanations of the logic can be found in
    the comment for sharded_embedding_bag.)

    Args:
        input: list of ID used for lookup and aggregation.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: row-wise sharded local weight used for lookup.
        offsets: list of start positions of each bag for 1D input.
        per_sample_weights: weights for weighted sum mode.
        mode: aggregation method of each bag.
        pg: process group.

    Returns:
        gathered_output: final result of lookup and aggregation.
    """
    # We sort each interval defined by offset. If 2D, each interval is a row.
    input_size = input.size()
    (
        input_split_sorted_list,
        input_split_sorted_indices,
        split_sizes_1d,
    ) = _input_split_sort(input, offsets)

    # Within each interval of the sorted list, we first need to distribute
    # each ID to different bucket(rank) and also ensure the rearrangement
    # has been done in case the placement idx not equal to rank.
    # We then perform some simple stats on each interval for the next step
    # If user specifies per_sample_weights we need to rearrange them
    # to be sync with IDs and then distribute them to each rank
    (
        input_combined,
        input_combined_split_sizes,
        offsets_rearrange_list,
        offsets_rearrange_sizes,
        per_sample_weights,
        sharded_dim_size_max,
    ) = _sorted_input_distribute_prepare(
        input_split_sorted_list,
        input_split_sorted_indices,
        world_size,
        input,
        weight,
        per_sample_weights,
    )

    # Send ID/offsets/per_sample_weights to different bucket(rank).
    (
        gathered_input,
        output_offsets_tensor_list,
        output_split_sizes,
        gathered_per_sample_weights,
    ) = _distribute_input(
        input_combined,
        input_combined_split_sizes,
        offsets_rearrange_list,
        offsets_rearrange_sizes,
        sharded_dim_size_max,
        world_size,
        input,
        per_sample_weights,
        pg,
    )

    # Perform the embedding bag look-up and aggregation
    results = []
    for i, inp in enumerate(gathered_input):
        per_sample_weights = (gathered_per_sample_weights[i]
                              if gathered_per_sample_weights is not None
                              else None)

        result = torch.nn.functional.embedding_bag(
            inp,
            local_shard,
            offsets=output_offsets_tensor_list[i],
            mode=mode if mode != "mean" else "sum",
            per_sample_weights=per_sample_weights,
        )
        if mode != "max":
            results.append(result)
        # For the max case, if there is no look-up from some ranks
        # the result will be all zeros for those rows. In that case we need
        # to set the row to neg inf; otherwise, in the final
        # aggregation negative values will be rounded up to zero.
        elif inp.size(0) == 0:
            result[:] = -float("Inf")
            results.append(result)
        else:
            for idx, current_offset in enumerate(
                    output_offsets_tensor_list[i]):
                next_offset = current_offset
                if idx == len(output_offsets_tensor_list[i]) - 1:
                    next_offset = output_split_sizes[i]
                else:
                    next_offset = output_offsets_tensor_list[i][idx + 1]
                if current_offset == next_offset:
                    result[idx] = -float("Inf")
            results.append(result)

    # Gather all the aggregated results appropriately by using reduce_scatter.
    row_size = input.size(0) if len(input_size) > 1 else len(split_sizes_1d)
    gathered_output = torch.empty(row_size, weight.size(1), device=input.device)
    op = ReduceOp.SUM if mode != "max" else ReduceOp.MAX
    dist.reduce_scatter(gathered_output, results, op=op, group=pg)

    # For Mean, we cannot do the division until the very end because the sum of
    # means is not equal to the mean of the sum. (The divisor is different.)
    if mode == "mean":
        # For a 2D tensor, we just average by col number
        if len(input_size) > 1:
            return torch.div(gathered_output, float(input.size(1)))
        # For a 1D tensor, we need to average by each offset
        else:
            split_sizes_1d_tensor = torch.tensor(split_sizes_1d,
                                                 dtype=torch.float,
                                                 device=input.device)
            # Make sure divisor is not zero.
            split_sizes_1d_tensor[split_sizes_1d_tensor == 0.0] = 1.0
            return (torch.div(gathered_output.t().contiguous(),
                              split_sizes_1d_tensor).t().contiguous())

    # Return the appropriate local result.
    return gathered_output