def test_reduce_scatter_coalesced_single_input():
    """Reduce-scatter a single 6-element half tensor filled with the caller's rank.

    Each rank expects a 3-element shard whose values are all 0.5.
    NOTE(review): the expected values presumably assume a world size of 2
    (ranks 0 and 1) and an averaging reduction -- confirm against the fixture.
    """
    rank_filled = torch.full((6, ), dist.get_rank(), dtype=torch.half, device=torch.cuda.current_device())

    # Exactly one output is expected for the single input tensor.
    (shard, ) = reduce_scatter_coalesced([rank_filled], dist.get_world_group())

    assert shard.shape == (3, )
    expected = torch.full_like(shard, 0.5)
    assert torch.allclose(shard, expected)
def has_overflow(self, params, has_moe_params=None):
    """Return whether an overflow was flagged, after syncing the flag across ranks.

    The local flag comes from ``self.has_overflow_serial(params)`` and is
    MAX-reduced over the process groups selected by the optimizer's
    configuration.

    Args:
        params: Forwarded unchanged to ``self.has_overflow_serial``.
        has_moe_params: When ``None``, ``self.has_moe_params`` is used instead.

    Returns:
        bool: The reduced overflow flag.
    """
    moe_enabled = self.has_moe_params if has_moe_params is None else has_moe_params

    # One-element byte tensor on the GPU so the flag can take part in collectives.
    flag = torch.cuda.ByteTensor([self.has_overflow_serial(params)])

    def _max_reduce(group):
        # MAX-reduce the flag in place over `group`.
        dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)

    if moe_enabled:
        # Reduce across the expert-parallel group so that an overflow in any
        # expert is detected here.
        _max_reduce(groups._get_max_expert_parallel_group())

    if self.zero_reduce_scatter:
        _max_reduce(dist.get_world_group())
    elif self.mpu is not None:
        if self.deepspeed is not None:
            # Pipeline engines expose their own backward-allreduce switch;
            # only the switch that applies is read.
            if hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce'):
                backward_allreduce = self.deepspeed.pipeline_enable_backward_allreduce
            else:
                backward_allreduce = self.deepspeed.enable_backward_allreduce
            if backward_allreduce is False:
                _max_reduce(self.mpu.get_data_parallel_group())
        # Each model-parallel GPU carries only part of the model, so the flag
        # is always synced across the model-parallel group on this path.
        _max_reduce(self.mpu.get_model_parallel_group())
    elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
        _max_reduce(dist.get_world_group())

    return bool(flag[0].item())
def test_reduce_scatter_coalesced_tensor_smaller_than_world_sz():
    """Reduce-scatter a one-element zero tensor and check each rank's shard.

    Rank 0 expects the single (zero) element; rank 1 expects an empty shard.
    Other ranks are not checked.
    """
    single_zero = torch.zeros((1, ), dtype=torch.half, device=torch.cuda.current_device())
    (shard, ) = reduce_scatter_coalesced([single_zero], dist.get_world_group())

    rank = dist.get_rank()
    if rank == 0:
        assert shard.shape == (1, )
        assert torch.allclose(shard, torch.zeros_like(shard))
    elif rank == 1:
        assert shard.shape == (0, )
def test_reduce_scatter_coalesced_two_inputs():
    """Reduce-scatter two rank-scaled ``arange`` tensors (6 and 3 elements).

    Ranks 0 and 1 each check the shape and values of both of their shards;
    the expected values are halved ``arange`` slices. Other ranks are not
    checked.
    """
    tensor_kwargs = {"device": torch.cuda.current_device(), "dtype": torch.half}

    inputs = [dist.get_rank() * torch.arange(lo, hi, **tensor_kwargs) for lo, hi in ((0, 6), (6, 9))]

    # Exactly two outputs are expected, one per input tensor.
    first, second = reduce_scatter_coalesced(inputs, dist.get_world_group())

    # Per rank: the (start, stop) arange bounds each shard should equal once halved.
    expected_ranges_by_rank = {
        0: ((0, 3), (6, 8)),
        1: ((3, 6), (8, 9)),
    }
    expected_ranges = expected_ranges_by_rank.get(dist.get_rank(), ())

    for shard, (start, stop) in zip((first, second), expected_ranges):
        assert shard.shape == (stop - start, )
        assert torch.allclose(shard, torch.arange(start, stop, **tensor_kwargs) / 2)
def top1gating(logits: Tensor,
               capacity_factor: float,
               min_capacity: int,
               used_token: Tensor = None,
               noisy_gate_policy: Optional[str] = None,
               drop_tokens: bool = True,
               use_rts: bool = True,
               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits.

    Args:
        logits: 2-D gating scores; softmax and argmax are taken over dim 1 and
            ``shape[1]`` is used as the number of experts.
        capacity_factor: Passed (as a tensor) to ``_capacity``.
        min_capacity: Passed (as a tensor) to ``_capacity``; the number of rows
            in ``logits`` is asserted to be at least this value.
        used_token: Optional per-token tensor multiplied into the one-hot
            expert mask via ``einsum("s,se->se", ...)``.
        noisy_gate_policy: When equal to ``'RSample'``, the expert is chosen
            from ``logits`` plus ``gumbel_rsample`` noise instead of the gates.
        drop_tokens: When false, the capacity is replaced by the largest
            per-expert token count, MAX-reduced over the world group.
        use_rts: When true, the mask is scaled by uniform random samples
            before ``_top_idx`` selects which tokens are kept.
        use_tutel: When true, take the tutel code path (see Returns).

    Returns:
        Default path: ``(l_aux, combine_weights, dispatch_mask, exp_counts)``.
        With ``use_tutel=True`` a 7-tuple is returned instead:
        ``(l_aux, capacity, num_experts, [indices1_s], [locations1_s],
        [gates1_s], exp_counts)``.
    """
    noisy = noisy_gate_policy == 'RSample'
    if noisy:
        # Draw the noise before anything else so the RNG call order is fixed.
        selection_scores = logits + gumbel_rsample(logits.shape, device=logits.device)

    # NOTE: carried over from the original comment -- everything is expected to
    # be fp32 in this function (not enforced here).
    gates = F.softmax(logits, dim=1)
    capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))

    if not noisy:
        selection_scores = gates

    # One-hot mask of the chosen (first) expert for every token.
    indices1_s = torch.argmax(selection_scores, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # Restrict the mask to the used tokens.
    if used_token is not None:
        mask1 = einsum("s,se->se", used_token, mask1)

    # Gating decisions: tokens per expert, moved to the CPU.
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

    if not drop_tokens:
        # Nothing is to be dropped: use the largest expert load across all ranks.
        new_capacity = torch.max(exp_counts).to(logits.device)
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
        capacity = new_capacity

    # Auxiliary loss.
    mean_gates = torch.mean(gates, dim=0)
    mean_mask = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(mean_gates * mean_mask) * num_experts

    # Random Token Selection.
    mask1_rand = mask1
    if use_rts:
        # One uniform sampler is cached per device.
        sampler = exp_selection_uniform_map.get(logits.device)
        if sampler is None:
            sampler = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device),
                                                          high=torch.tensor(1.0, device=logits.device)).rsample
            exp_selection_uniform_map[logits.device] = sampler
        mask1_rand = mask1 * sampler(mask1.shape)

    assert logits.shape[
        0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."

    # Keep only the entries selected by _top_idx.
    top_idx = _top_idx(mask1_rand, capacity)
    keep = torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = mask1 * keep

    if use_tutel:
        # Tutel doesn't support index values masked with zero,
        # so masked indices are replaced with -1.
        indices_mask = mask1.sum(dim=1) * num_experts - 1
        indices1_s = torch.min(indices1_s, indices_mask)

        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
        gates1_s = (gates * mask1).sum(dim=1)
        locations1_s = torch.sum(locations1 * mask1, dim=1)
        return l_aux, capacity, num_experts, [
            indices1_s,
        ], [
            locations1_s,
        ], [
            gates1_s,
        ], exp_counts

    # Location of every token inside the capacity buffer.
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Zero the gate probabilities of tokens that were not kept.
    gates = gates * mask1.float()

    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts