def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """
    Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
    To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``.

    Subclassing:
        This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;

        It should return gradients w.r.t. inputs that follow ``nested_flatten(self.forward_schema)``;

        Runtime doesn't guarantee that backward will be performed in the same order and for the same data
        as forward, so we recommend a stateless backward pass that re-runs the expert's forward pass inside backward.

        .. todo correct state handling (see forward)

        Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train.
    """
    (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)

    with torch.enable_grad():
        args = [
            tensor.detach().requires_grad_(True)
            if tensor.dtype in (torch.half, torch.float, torch.double) else tensor.detach()
            for tensor in args
        ]
        kwargs = {
            input_key: (tensor.detach().requires_grad_(True)
                        if tensor.dtype in (torch.half, torch.float, torch.double) else tensor.detach())
            for input_key, tensor in kwargs.items()
        }

        outputs = self.expert(*args, **kwargs)
        assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"

        outputs_flat = tuple(nested_flatten(outputs))
        grad_outputs_flat = tuple(map(
            lambda grad, out: grad.to(device=out.device, dtype=out.dtype, non_blocking=True),
            nested_flatten(grad_outputs), outputs_flat))
        torch.autograd.backward(outputs_flat, grad_tensors=grad_outputs_flat,
                                create_graph=False, retain_graph=False)
        self.apply_gradients()

    return tuple(
        x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x)
        for x in nested_flatten((args, kwargs)))
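# A minimal, self-contained sketch of the "stateless backward" pattern recommended in the
# docstring above: detach the received inputs, mark floating-point tensors as requiring grad,
# re-run the module's forward pass, then backpropagate the incoming output gradients.
# The toy `expert` module and the shapes below are assumptions made for illustration only.
import torch

expert = torch.nn.Linear(4, 4)                     # stand-in for self.expert
saved_input = torch.randn(2, 4)                    # inputs as received by backward (no graph attached)
grad_output = torch.randn(2, 4)                    # gradients w.r.t. the expert's outputs

with torch.enable_grad():
    x = saved_input.detach().requires_grad_(True)  # rebuild a fresh autograd graph
    out = expert(x)
    torch.autograd.backward((out,), grad_tensors=(grad_output.to(dtype=out.dtype),))

grad_wrt_input = x.grad                            # what backward() returns to the caller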
def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
    """
    Choose k best experts with beam search, then call chosen experts and average their outputs.
    Input tensor is averaged over all dimensions except for first and last
    (we assume that extra dimensions represent sequence length or image height/width)

    :param input: a tensor of values that are used to estimate gating function, batch-first.
    :param args: extra positional parameters that will be passed to each expert after input, batch-first
    :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
    :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
    """
    if input.ndim != 2:
        input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
    else:
        input_for_gating = input

    # 1. compute scores and find most appropriate experts with beam search
    grid_scores = self.proj(input_for_gating).split_with_sizes(self.grid_size, dim=-1)
    chosen_experts: List[List[RemoteExpert]] = self.dht.batch_find_best_experts(
        self.uid_prefix, [scores.detach().cpu().numpy() for scores in grid_scores],
        self.k_best, **self.dht_kwargs)

    if self._expert_info is None:
        try:
            self._expert_info = next(expert.info for experts_i in chosen_experts for expert in experts_i)
        except grpc.RpcError as e:
            logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")

    expert_mask, *expert_outputs = _RemoteCallMany.apply(
        DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min,
        self.forward_timeout, self.backward_timeout, self.info, *nested_flatten(((input, *args), kwargs)))
    # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

    expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
    masked_logits = torch.full((1,), float('-inf'), device=expert_logits.device, dtype=expert_logits.dtype)
    expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
    expert_weights = torch.softmax(expert_logits, dim=1)

    averaged_outputs_flat = [
        (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
        for tensor in expert_outputs
    ]  # ^-- multiply by softmax weights along first 2 axes
    return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
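# Illustrative sketch with assumed shapes (not part of the class): how the gating input is
# formed for a [batch, seq_len, hidden] tensor and how per-expert outputs are mixed with
# softmax weights, masking out experts that never responded.
import torch

batch, seq_len, hidden, max_experts = 3, 5, 8, 4
x = torch.randn(batch, seq_len, hidden)
input_for_gating = x.mean(dim=tuple(range(1, x.ndim - 1)))         # -> [batch, hidden]

expert_outputs = torch.randn(batch, max_experts, seq_len, hidden)  # [batch, max_experts, ...output_shape]
expert_logits = torch.randn(batch, max_experts)
expert_mask = torch.ones(batch, max_experts, dtype=torch.bool)
expert_mask[:, -1] = False                                         # pretend the last expert timed out
expert_logits = torch.where(expert_mask, expert_logits, torch.tensor(float('-inf')))
expert_weights = torch.softmax(expert_logits, dim=1)

mixed = (expert_weights[..., None] * expert_outputs.flatten(start_dim=2)) \
    .view(expert_outputs.shape).sum(dim=1)                         # -> [batch, seq_len, hidden]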
def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Sequence[Dict],
                         flat_tensors: Sequence[torch.Tensor]):
    flat_optimizer_state = []
    for elem in flat_metadata:
        if elem.get('type') == 'tensor' and isinstance(elem.get('index'), int):
            flat_optimizer_state.append(flat_tensors[elem['index']])
        elif elem.get('type') == 'value' and 'value' in elem:
            flat_optimizer_state.append(elem['value'])
    with torch.no_grad():
        return optimizer.load_state_dict(
            nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
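# A hedged sketch of the serialization format load_optimizer_state expects: the flattened
# optimizer state_dict is split into torch tensors (referenced by index) and plain values
# stored inline. The dump_state helper below is a hypothetical counterpart written only to
# illustrate that format; it is not necessarily how checkpoints are produced elsewhere.
from typing import Dict, List, Tuple
import torch

def dump_state(optimizer: torch.optim.Optimizer) -> Tuple[List[Dict], List[torch.Tensor]]:
    flat_metadata, flat_tensors = [], []
    for elem in nested_flatten(optimizer.state_dict()):
        if isinstance(elem, torch.Tensor):
            flat_metadata.append(dict(type='tensor', index=len(flat_tensors)))
            flat_tensors.append(elem.cpu())
        else:
            flat_metadata.append(dict(type='value', value=elem))
    return flat_metadata, flat_tensors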
def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
    """
    Choose k best experts with beam search, then call chosen experts and average their outputs.
    Input tensor is averaged over all dimensions except first and last
    (we assume that extra dimensions represent sequence length or image dimensions)

    :param input: a tensor of values that are used to estimate gating function, batch-first.
    :param args: extra positional parameters that will be passed to each expert after input, batch-first
    :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
    :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
    """
    if input.ndim != 2:
        input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
    else:
        input_for_gating = input

    # 1. compute scores and find most appropriate experts with beam search
    grid_scores = self.proj(input_for_gating).split_with_sizes(self.grid_size, dim=-1)

    async def _search():
        coroutines = [
            asyncio.create_task(self.beam_search(
                [dim_scores[i] for dim_scores in grid_scores], self.k_best))
            for i in range(len(input))
        ]
        return list(await asyncio.gather(*coroutines))

    chosen_experts: List[List[RemoteExpert]] = self.loop.run_until_complete(_search())
    # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch

    expert_mask, *expert_outputs = _RemoteCallMany.apply(
        DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min,
        self.forward_timeout, self.backward_timeout, self.loop, *nested_flatten(((input, *args), kwargs)))
    # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

    expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
    masked_logits = torch.full((1,), float('-inf'), device=expert_logits.device, dtype=expert_logits.dtype)
    expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
    expert_weights = torch.softmax(expert_logits, dim=1)

    averaged_outputs_flat = [
        (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
        for tensor in expert_outputs
    ]  # ^-- multiply by softmax weights along first 2 axes
    return nested_pack(averaged_outputs_flat, self.outputs_schema)
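# Illustrative sketch (made-up grid sizes and coordinates): with a factorized expert grid, an
# expert's logit is obtained by summing the scores of its coordinates across grid dimensions,
# which is what compute_expert_scores is expected to produce for every chosen expert.
import torch

grid_size = (4, 8)                                            # two grid dimensions
grid_scores = [torch.randn(1, dim) for dim in grid_size]      # per-dimension scores for one sample
expert_coords = (2, 5)                                        # e.g. an expert with uid "prefix.2.5"

expert_logit = sum(scores[0, coord] for scores, coord in zip(grid_scores, expert_coords))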
def forward(self, *args, **kwargs):
    """
    Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd.
    """
    assert len(kwargs) == len(self.info['keyword_names']), \
        f"Keyword args should be {self.info['keyword_names']}"
    kwargs = {key: kwargs[key] for key in self.info['keyword_names']}
    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors

    forward_inputs = (args, kwargs)

    if not nested_compare(forward_inputs, self.info['forward_schema']):
        raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")

    flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, *nested_flatten(forward_inputs))
    # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
    return nested_pack(flat_outputs, structure=self.info['outputs_schema'])
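# A self-contained sketch of the DUMMY trick mentioned in the note above: passing a zero-size
# tensor with requires_grad=True into the autograd.Function guarantees that its backward is
# invoked even when none of the real inputs require grad. The _ExampleCall function and the
# local multiply-by-two "remote call" are illustrative stand-ins, not library code.
import torch

DUMMY = torch.empty(0, requires_grad=True)

class _ExampleCall(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return x * 2                                 # stand-in for the actual remote forward call

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # reached even if x.requires_grad is False, because dummy always requires grad
        return torch.zeros_like(DUMMY), grad_output * 2

x = torch.randn(3)                                   # note: does not require grad
out = _ExampleCall.apply(DUMMY, x)
out.sum().backward()                                 # backward still runs thanks to DUMMY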
def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """
    Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
    To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.

    Subclassing:
        This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;

        It should return outputs that follow ``nested_flatten(self.outputs_schema)``;

        .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
        .. For now, either register all buffers as outputs or avoid stateful experts
    """
    args, kwargs = nested_pack(inputs, structure=self.forward_schema)

    with torch.no_grad():
        outputs = self.expert(*args, **kwargs)

    # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
    return tuple(nested_flatten(outputs))
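# Illustrative round-trip with a toy schema (placeholder leaves, assumed shapes): the Runtime
# hands forward() a flat tuple of tensors, nested_pack restores the (args, kwargs) structure
# declared by forward_schema, and the expert's outputs are flattened again for the TaskPool.
import torch

forward_schema = ((None, None), dict(mask=None))          # placeholder structure: two args, one kwarg
flat_inputs = (torch.randn(2, 4), torch.randn(2, 4), torch.ones(2, 4))

args, kwargs = nested_pack(flat_inputs, structure=forward_schema)
outputs = (args[0] * kwargs['mask'], args[1])             # stand-in for self.expert(*args, **kwargs)
flat_outputs = tuple(nested_flatten(outputs))             # what forward() returns to the Runtime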
def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
    if input.ndim != 2:
        input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
    else:
        input_for_gating = input

    # Multiplicative jitter for regularized routing (out-of-place so the caller's input is not modified)
    jitter_noise = torch.empty_like(input_for_gating).uniform_(1 - self.jitter_eps, 1 + self.jitter_eps)
    input_for_gating = input_for_gating * jitter_noise

    # Compute scores, find most appropriate experts with beam search
    grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)

    grid_dropout_masks = (
        (torch.rand(size=(dim_size,), dtype=input_for_gating.dtype, device=input_for_gating.device)
         < self.grid_dropout)
        for dim_size in self.beam_search.grid_size
    )
    grid_scores_dropout = [
        torch.where(dropout_mask, grid_score,
                    torch.full((1,), float('-inf'), device=grid_score.device, dtype=grid_score.dtype))
        for grid_score, dropout_mask in zip(grid_scores, grid_dropout_masks)
    ]

    grid_softmax = [torch.softmax(grid_score, dim=-1) for grid_score in grid_scores_dropout]
    chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
        [scores.detach().cpu() for scores in grid_scores_dropout], self.k_best)

    if self._expert_info is None:
        try:
            self._expert_info = next(expert.info for experts_i in chosen_experts for expert in experts_i)
        except StopIteration:
            raise RuntimeError(
                "No responding experts found during beam search. Check that UID prefixes and "
                "the grid size are consistent with running Server instances.")
        except grpc.RpcError as e:
            logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")

    expert_mask, *expert_outputs = _RemoteCallMany.apply(
        DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min,
        self.forward_timeout, self.backward_timeout, self.detect_anomalies, self.allow_zero_outputs,
        self.info, *nested_flatten(((input, *args), kwargs)))
    # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

    batch_utilization = self._compute_batch_utilization(chosen_experts, expert_mask)
    self.grid_utilization = (
        self.utilization_alpha * self.grid_utilization + (1 - self.utilization_alpha) * batch_utilization)

    # compute expert probabilities as product across grid dimensions
    expert_probs = self.compute_expert_scores(grid_softmax, chosen_experts)
    masked_probs = torch.zeros((1,), device=expert_probs.device, dtype=expert_probs.dtype)
    expert_probs = torch.where(expert_mask, expert_probs, masked_probs)

    # multiply outputs by expert probabilities
    averaged_outputs_flat = [
        (expert_probs[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
        for tensor in expert_outputs
    ]  # ^-- multiply by expert probabilities along first 2 axes
    packed_outputs = nested_pack(averaged_outputs_flat, self.info['outputs_schema'])

    # Load balancing loss: multiply fractions of probability mass and fractions of routed examples
    # for each grid dimension, sum across all indices for a dimension. Optimizing this leads to uniform allocation
    balancing_loss = torch.stack([
        torch.mean(dim_softmax.mean(0) * dim_utilization) * (dim_size ** 2)
        for dim_softmax, dim_utilization, dim_size in zip(
            grid_softmax, self.grid_utilization, self.beam_search.grid_size)
    ]).sum()

    # residual connection
    if isinstance(packed_outputs, torch.Tensor):
        packed_outputs = packed_outputs + input
    else:
        packed_outputs[0] = packed_outputs[0] + input

    return packed_outputs, balancing_loss
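# A self-contained toy computation of the load-balancing term above, with made-up grid sizes,
# batch size, and utilization statistics: for each grid dimension, the average routed probability
# per index is multiplied by the fraction of examples routed to that index, scaled by the squared
# dimension size, and the per-dimension terms are summed. Under perfectly uniform routing each
# term approaches 1, so minimizing this loss encourages spreading load evenly across the grid.
import torch

grid_size = (4, 8)
grid_softmax = [torch.softmax(torch.randn(16, dim), dim=-1) for dim in grid_size]  # per-sample routing probs
grid_utilization = [torch.full((dim,), 1.0 / dim) for dim in grid_size]            # fraction of examples per index

balancing_loss = torch.stack([
    torch.mean(dim_softmax.mean(0) * dim_utilization) * (dim_size ** 2)
    for dim_softmax, dim_utilization, dim_size in zip(grid_softmax, grid_utilization, grid_size)
]).sum()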