def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """
    Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually
    To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``.

    Subclassing:
       This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;

       It should return gradients w.r.t. inputs that follow ``nested_flatten(self.forward_schema)``;

       Runtime doesn't guarantee that backward will be performed in the same order and for the same data
       as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward.

       .. todo correct state handling (see forward)

       Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train

    Inputs of any floating-point dtype (including bfloat16) receive gradients; other inputs
    (e.g. integer/bool) get an all-zero gradient of their own shape and dtype.
    """
    (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)

    def _as_leaf(tensor: torch.Tensor) -> torch.Tensor:
        # Detach from any incoming graph; only floating-point tensors can require grad.
        # The previous (half, float, double) whitelist silently produced zero gradients for bfloat16 inputs.
        leaf = tensor.detach()
        return leaf.requires_grad_(True) if leaf.is_floating_point() else leaf

    with torch.enable_grad():
        args = [_as_leaf(tensor) for tensor in args]
        kwargs = {input_key: _as_leaf(tensor) for input_key, tensor in kwargs.items()}

        # re-run the expert forward pass so that backward does not depend on state saved by forward
        outputs = self.expert(*args, **kwargs)
        assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"

        outputs_flat = tuple(nested_flatten(outputs))
        # align each incoming gradient with the device/dtype of the output it belongs to
        grad_outputs_flat = tuple(
            grad.to(device=out.device, dtype=out.dtype, non_blocking=True)
            for grad, out in zip(nested_flatten(grad_outputs), outputs_flat)
        )
        torch.autograd.backward(outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False)
        self.apply_gradients()

    # inputs that did not receive a gradient (or cannot require grad) get zeros of matching shape
    return tuple(x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x)
                 for x in nested_flatten((args, kwargs)))
def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
    """
    Send (inputs, grad_outputs) of every (sample, expert) pair that survived forward to that expert's remote
    backward, then sum the returned input gradients over experts for each sample.

    :param raw_grads: gradient w.r.t. the expert mask (only its [num_samples, max_experts] shape is used),
      followed by gradients w.r.t. each flat output of forward
    :returns: DUMMY, eight ``None`` placeholders (one per remaining non-input argument of forward), then one
      gradient per flat input, moved to the device of the first output gradient
    :raises ValueError: if ``detect_anomalies`` is set and an output gradient contains nan/inf
    :raises TimeoutError: if no expert returned gradients within the timeout
    """
    assert not torch.is_grad_enabled()
    (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
     detect_anomalies) = ctx._saved_non_tensors
    alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors
    dummy_grad_mask, *flat_grad_outputs = raw_grads

    if detect_anomalies:
        for grad_output in flat_grad_outputs:
            if not grad_output.isfinite().all():
                raise ValueError("One of gradients has nan/inf values")
    grad_outputs_cpu = [grad_output.cpu() for grad_output in flat_grad_outputs]

    num_samples, max_experts = dummy_grad_mask.shape

    # one tuple of [1, ...]-shaped slices per surviving (sample, expert) pair
    inputs_per_pair = zip(*(flat_input[alive_ii].split(1, dim=0) for flat_input in flat_inputs_cpu))
    grad_outputs_per_pair = zip(*(grad[alive_ii, alive_jj].split(1, dim=0) for grad in grad_outputs_cpu))
    backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))

    # dispatch tasks to all remote experts, collect responses
    pending_tasks = {}
    pair_iter = zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_pair, grad_outputs_per_pair)
    for sample_idx, expert_idx, pair_inputs, pair_grad_outputs in pair_iter:
        expert = expert_per_sample[sample_idx.item()][expert_idx.item()]
        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
        payload = tuple(nested_flatten((pair_inputs, pair_grad_outputs)))
        serialized = [serialize_torch_tensor(tensor, proto.compression)
                      for tensor, proto in zip(payload, backward_schema)]
        request = runtime_pb2.ExpertRequest(uid=expert.uid, tensors=serialized)
        pending_tasks[stub.backward.future(request)] = (sample_idx, expert_idx)

    survivor_indices, survivor_grad_inputs = cls._collect_responses(
        pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
    if len(survivor_indices) == 0:
        raise TimeoutError("Backward pass: no alive experts responded within timeout.")

    # survivor_indices is non-empty here, so zip(*) always yields exactly (rows, cols)
    survivor_rows, survivor_cols = (torch.as_tensor(axis) for axis in zip(*survivor_indices))

    grad_inputs = []
    for input_idx, expert_grads in enumerate(zip(*survivor_grad_inputs)):
        # shape [num_backward_survivors, *flat_inputs_cpu[input_idx].shape[1:]]
        stacked = torch.cat(expert_grads)
        # gradient tensor with individual contributions from each expert
        per_expert = torch.zeros((num_samples, max_experts, *flat_inputs_cpu[input_idx].shape[1:]),
                                 device=stacked.device, dtype=stacked.dtype)
        per_expert[survivor_rows, survivor_cols] = stacked
        # sum gradients from each expert
        grad_inputs.append(per_expert.to(flat_grad_outputs[0].device).sum(dim=1))

    return (DUMMY,) + (None,) * 8 + tuple(grad_inputs)
def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
    """
    Send saved forward inputs together with output gradients to the remote expert and return its input gradients.

    Each tensor is serialized with the compression of its descriptor in
    ``(info["forward_schema"], info["outputs_schema"])``.
    """
    flat_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
    payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
    request_tensors = [serialize_torch_tensor(tensor, descriptor.compression)
                       for tensor, descriptor in zip(payload, flat_schema)]

    response = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=request_tensors))
    grad_inputs = [deserialize_torch_tensor(tensor) for tensor in response.tensors]

    # leading entries: DUMMY for the dummy argument, then None for three non-tensor arguments
    # (presumably uid, stub, info of the matching forward — confirm against its signature)
    return (DUMMY, None, None, None, *grad_inputs)
def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
    """
    Choose k best experts with beam search, then call chosen experts and average their outputs.
    Input tensor is averaged over all dimensions except for first and last
    (we assume that extra dimensions represent sequence length or image height/width)

    :param input: a tensor of values that are used to estimate gating function, batch-first.
    :param args: extra positional parameters that will be passed to each expert after input, batch-first
    :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
    :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
    :raises RuntimeError: if expert info is not yet known and beam search found no experts at all
    """
    if input.ndim != 2:
        input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
    else:
        input_for_gating = input

    # 1. compute scores and find most appropriate experts with beam search
    grid_scores = self.proj(input_for_gating).split_with_sizes(self.grid_size, dim=-1)

    chosen_experts: List[List[RemoteExpert]] = self.dht.batch_find_best_experts(
        self.uid_prefix, [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best, **self.dht_kwargs)

    if self._expert_info is None:
        try:
            self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
        except StopIteration:
            # previously a bare StopIteration escaped here, with no hint about the cause
            raise RuntimeError(
                "No responding experts found during beam search. Check that UID prefixes and "
                "the grid size are consistent with running Server instances."
            ) from None
        except grpc.RpcError as e:
            logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")

    expert_mask, *expert_outputs = _RemoteCallMany.apply(
        DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
        self.backward_timeout, self.info, *nested_flatten(((input, *args), kwargs)))
    # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

    # experts that did not respond (mask False) get -inf logits before softmax
    expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
    masked_logits = torch.full((1,), float('-inf'), device=expert_logits.device, dtype=expert_logits.dtype)
    expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
    expert_weights = torch.softmax(expert_logits, dim=1)

    averaged_outputs_flat = [
        (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
        for tensor in expert_outputs
    ]  # ^-- multiply by softmax weights along first 2 axes

    return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
def dump_optimizer_state(optimizer: torch.optim.Optimizer):
    """
    Flatten ``optimizer.state_dict()`` into two parallel lists.

    :returns: ``(flat_metadata, flat_tensors)``. Each flattened state_dict entry gets one metadata dict:
      ``dict(type='tensor', index=k)`` pointing at ``flat_tensors[k]`` (a CPU copy), or
      ``dict(type='value', value=...)`` for non-tensor entries.
    """
    flat_metadata, flat_tensors = [], []
    with torch.no_grad():
        for item in nested_flatten(optimizer.state_dict()):
            if not isinstance(item, torch.Tensor):
                flat_metadata.append(dict(type='value', value=item))
                continue
            flat_metadata.append(dict(type='tensor', index=len(flat_tensors)))
            flat_tensors.append(item.cpu())
    return flat_metadata, flat_tensors
def dump_optimizer_state(opt: torch.optim.Optimizer):
    """
    Convert optimizer state into a format of DecentralizedAverager's get_current_state/load_state_from_peers

    Tensors from ``opt.state_dict()`` are copied to CPU and referenced from the metadata by position;
    every other value is stored inline in its metadata entry.
    """
    metadata: list = []
    tensors: list = []
    with torch.no_grad():
        for entry in nested_flatten(opt.state_dict()):
            if isinstance(entry, torch.Tensor):
                position = len(tensors)
                tensors.append(entry.cpu())
                metadata.append(dict(type='tensor', index=position))
            else:
                metadata.append(dict(type='value', value=entry))
    return metadata, tensors
def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
    """
    Choose k best experts with beam search, then call chosen experts and average their outputs.
    Input tensor is averaged over all dimensions except first and last
    (we assume that extra dimensions represent sequence length or image dimensions)

    :param input: a tensor of values that are used to estimate gating function, batch-first.
    :param args: extra positional parameters that will be passed to each expert after input, batch-first
    :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
    :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
    """
    gating_input = input if input.ndim == 2 else input.mean(dim=tuple(range(1, input.ndim - 1)))

    # 1. score every grid dimension, then run one beam search per sample concurrently on self.loop
    grid_scores = self.proj(gating_input).split_with_sizes(self.grid_size, dim=-1)

    async def _search():
        search_tasks = []
        for sample_index in range(len(input)):
            sample_scores = [dim_scores[sample_index] for dim_scores in grid_scores]
            search_tasks.append(asyncio.create_task(self.beam_search(sample_scores, self.k_best)))
        return list(await asyncio.gather(*search_tasks))

    # List[batch_size] of List[RemoteExpert] chosen for every input in batch
    chosen_experts: List[List[RemoteExpert]] = self.loop.run_until_complete(_search())

    expert_mask, *expert_outputs = _RemoteCallMany.apply(
        DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min,
        self.forward_timeout, self.backward_timeout, self.loop, *nested_flatten(((input, *args), kwargs)))
    # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

    # logits of experts that did not respond (mask False) are replaced by -inf before softmax
    expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
    negative_inf = torch.full((1,), float('-inf'), device=expert_logits.device, dtype=expert_logits.dtype)
    expert_weights = torch.softmax(torch.where(expert_mask, expert_logits, negative_inf), dim=1)

    # multiply by softmax weights along first 2 axes, then sum over experts
    averaged_outputs_flat = []
    for tensor in expert_outputs:
        weighted = expert_weights[..., None] * tensor.flatten(start_dim=2)
        averaged_outputs_flat.append(weighted.view(tensor.shape).sum(dim=1))

    return nested_pack(averaged_outputs_flat, self.outputs_schema)
def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
    """
    Ship saved inputs and output gradients to the remote expert (default serialization) and return
    the deserialized input gradients, preceded by DUMMY and two ``None`` placeholders.
    """
    flat_payload = nested_flatten((ctx.saved_tensors, grad_outputs))
    request = runtime_pb2.ExpertRequest(
        uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in flat_payload])
    response = ctx.stub.backward(request)
    grads = tuple(deserialize_torch_tensor(tensor) for tensor in response.tensors)
    return (DUMMY, None, None) + grads
def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
            timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
            detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
    """
    Send each sample's inputs to all of its chosen experts and gather the responses.

    :returns: ``(mask, *outputs)`` where ``mask`` is a bool tensor [num_samples, max_experts] marking the
      (sample, expert) pairs that responded, and each output is [num_samples, max_experts, ...] with zeros
      for pairs that did not respond; all are moved to the device of ``flat_inputs[0]``
    :raises ValueError: if ``detect_anomalies`` is set and an input contains nan/inf
    :raises TimeoutError: if no expert responded
    """
    assert not torch.is_grad_enabled()
    num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))

    flat_inputs_cpu = []
    for tensor in flat_inputs:
        if detect_anomalies and not tensor.isfinite().all():
            raise ValueError("One of inputs has nan/inf values")
        flat_inputs_cpu.append(tensor.cpu())

    flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
    assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

    # dispatch tasks to all remote experts collect responses
    forward_schema_flat = tuple(nested_flatten(info['forward_schema']))  # loop-invariant, flatten once
    pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
    for i in range(num_samples):
        if not experts_per_sample[i]:
            continue
        # every expert of sample i receives identical inputs: serialize them once per sample, not once per expert
        input_tensors = [serialize_torch_tensor(tensor, proto.compression)
                         for tensor, proto in zip(flat_inputs_per_sample[i], forward_schema_flat)]
        for j, expert in enumerate(experts_per_sample[i]):
            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
            new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
            pending_tasks[new_task] = (i, j)

    alive_grid_indices, alive_flat_outputs = cls._collect_responses(
        pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
    if len(alive_grid_indices) == 0:
        raise TimeoutError("Forward pass: no alive experts responded within timeout.")

    # assemble responses
    alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
    mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
    mask[alive_ii, alive_jj] = True

    # i-th stacked tensor is of shape [num_responded, *expert_outputs[i].shape[1:]]
    alive_flat_outputs_stacked = (torch.cat(responses) for responses in zip(*alive_flat_outputs))
    outputs = []
    for response_stacked in alive_flat_outputs_stacked:
        output = torch.zeros(
            [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
            dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
        output[alive_ii, alive_jj] = response_stacked
        outputs.append(output.to(flat_inputs[0].device))

    # save individual outputs for backward pass
    ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
    ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
                              detect_anomalies)
    return (mask,) + tuple(outputs)
async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
    """
    Handle a backward RPC: deserialize inputs and output gradients, run them through the expert's
    backward pool, and reply with gradients serialized per ``grad_inputs_schema``.
    """
    inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
    expert_backend = self.experts[request.uid]
    grad_inputs = await expert_backend.backward_pool.submit_task(*inputs_and_grad_outputs)

    grad_inputs_schema = nested_flatten(expert_backend.grad_inputs_schema)
    serialized_response = []
    for tensor, proto in zip(grad_inputs, grad_inputs_schema):
        serialized_response.append(serialize_torch_tensor(tensor, proto.compression, allow_inplace=True))
    return runtime_pb2.ExpertResponse(tensors=serialized_response)
def forward(self, *args, **kwargs):
    """
    Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd.

    :raises TypeError: if the keyword argument names differ from ``self.info['keyword_names']``,
      or if the inputs do not match ``self.info['forward_schema']``
    """
    keyword_names = self.info['keyword_names']
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``, and a wrong key set
    # with the right count used to surface as a bare KeyError below.
    if len(kwargs) != len(keyword_names) or set(kwargs) != set(keyword_names):
        raise TypeError(f"Keyword args should be {keyword_names}")

    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
    kwargs = {key: kwargs[key] for key in keyword_names}

    forward_inputs = (args, kwargs)
    if not nested_compare(forward_inputs, self.info['forward_schema']):
        raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")

    flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, *nested_flatten(forward_inputs))
    # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
    return nested_pack(flat_outputs, structure=self.info['outputs_schema'])
def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """
    Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
    To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.

    Subclassing:
       This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;

       It should return the expert's outputs (not gradients), flattened to follow
       ``nested_flatten(self.outputs_schema)``;

       .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
       .. For now, either register all buffers as outputs or avoid stateful experts

    """
    expert_args, expert_kwargs = nested_pack(inputs, structure=self.forward_schema)

    # no autograd graph is built during forward
    with torch.no_grad():
        raw_outputs = self.expert(*expert_args, **expert_kwargs)

    # TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
    return tuple(nested_flatten(raw_outputs))
async def _backward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, inputs: Tuple[torch.Tensor],
                               grad_outputs: Tuple[torch.Tensor]):
    """
    Request input gradients from one remote expert over the async stub.

    :returns: ``(grid_indices, grad_inputs)`` on success; ``None`` (after logging a warning) if the RPC
      fails with ``AioRpcError``
    """
    stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
    flat_payload = tuple(nested_flatten((inputs, grad_outputs)))
    request = runtime_pb2.ExpertRequest(uid=expert.uid,
                                        tensors=[serialize_torch_tensor(tensor) for tensor in flat_payload])
    try:
        response = await stub.backward(request)
    except grpc.experimental.aio.AioRpcError as error:
        logger.warning(f"RemoteExpert {expert} failed backward: {error.code()} ({inputs}, {grad_outputs})")
        return None
    return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in response.tensors)
def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub, info: Dict[str, Any],
            *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """
    Call the remote expert ``uid`` through ``stub`` and return its deserialized outputs.

    ``*inputs`` are flattened input tensors that follow the expert's ``info["forward_schema"]``; each is
    serialized with its descriptor's compression. CPU, detached copies of the inputs are saved for backward.
    """
    # detach to avoid pickling the computation graph
    cpu_inputs = tuple(tensor.cpu().detach() for tensor in inputs)
    ctx.uid, ctx.stub, ctx.info = uid, stub, info
    ctx.save_for_backward(*cpu_inputs)

    flat_schema = nested_flatten(info["forward_schema"])
    request = runtime_pb2.ExpertRequest(
        uid=ctx.uid,
        tensors=[serialize_torch_tensor(tensor, descriptor.compression)
                 for tensor, descriptor in zip(cpu_inputs, flat_schema)])
    response = stub.forward(request)
    return tuple(deserialize_torch_tensor(tensor) for tensor in response.tensors)
def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
            timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
            detect_anomalies: bool, allow_zero_outputs: bool, info: Dict[str, Any],
            *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
    """
    Send each sample's inputs to all of its chosen experts and gather the responses.

    :returns: ``(mask, *outputs)`` where ``mask`` is a bool tensor [num_samples, max_experts] marking the
      (sample, expert) pairs that responded, and each output (built from ``info['outputs_schema']``) is
      zero-filled except at responding pairs; everything lives on the device of ``flat_inputs[0]``
    :raises ValueError: if ``detect_anomalies`` is set and an input contains nan/inf
    :raises TimeoutError: if fewer than ``k_min`` experts responded
    :raises RuntimeError: if no expert responded and ``allow_zero_outputs`` is False
    """
    assert not torch.is_grad_enabled()
    num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))

    flat_inputs_cpu = []
    for tensor in flat_inputs:
        if detect_anomalies and not tensor.isfinite().all():
            raise ValueError("One of inputs has nan/inf values")
        flat_inputs_cpu.append(tensor.cpu())

    flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
    assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

    # dispatch tasks to all remote experts collect responses
    forward_schema_flat = tuple(nested_flatten(info['forward_schema']))  # loop-invariant, flatten once
    pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
    for i in range(num_samples):
        if not experts_per_sample[i]:
            continue
        # every expert of sample i receives identical inputs: serialize them once per sample, not once per expert
        input_tensors = [serialize_torch_tensor(tensor, proto.compression)
                         for tensor, proto in zip(flat_inputs_per_sample[i], forward_schema_flat)]
        for j, expert in enumerate(experts_per_sample[i]):
            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
            new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
            pending_tasks[new_task] = (i, j)

    responded_inds, alive_flat_outputs = cls._collect_responses(
        pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
    if len(responded_inds) < k_min:
        raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout.")
    # reachable only when k_min <= 0
    if len(responded_inds) == 0 and not allow_zero_outputs:
        raise RuntimeError(
            'Forward pass: 0 experts responded within timeout and allow_zero_outputs is False')

    if not isinstance(info['outputs_schema'], tuple):
        outputs_schema = (info['outputs_schema'],)
    else:
        outputs_schema = info['outputs_schema']
    outputs = nested_map(
        lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
        outputs_schema)

    # assemble responses; the ([], []) fallback keeps unpacking valid when nobody responded
    batch_inds, expert_inds = map(
        lambda x: torch.as_tensor(x, device=flat_inputs[0].device, dtype=torch.long),
        list(zip(*responded_inds)) or ([], []))

    # i-th stacked tensor is of shape [num_responded, *expert_outputs[i].shape[1:]]
    alive_flat_outputs_stacked = (torch.cat(responses) for responses in zip(*alive_flat_outputs))
    for output, response_stacked in zip(outputs, alive_flat_outputs_stacked):
        output[batch_inds, expert_inds] = response_stacked.to(output.device)

    mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
    mask[batch_inds, expert_inds] = True

    # save individual outputs for backward pass
    ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
    ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
                              detect_anomalies)
    return (mask,) + outputs
def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
    """
    Route each sample to experts found by beam search over jittered, grid-dropout gating scores, weight the
    expert outputs by their probabilities, and add a residual connection to ``input``.

    :param input: batch-first tensor used for gating; dimensions between the first and last are averaged out
    :param args: extra positional parameters passed to each expert after input, batch-first
    :param kwargs: extra keyword parameters passed to each expert, batch-first
    :returns: ``(outputs, balancing_loss)``. ``input`` is added to ``outputs`` if it is a tensor, otherwise
      to its first element.
    :raises RuntimeError: if expert info is not yet known and beam search found no experts at all
    """
    if input.ndim != 2:
        input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
    else:
        input_for_gating = input

    # Multiplicative jitter for regularized routing: factors drawn uniformly from [1 - eps, 1 + eps].
    # Out-of-place on purpose. For 2D inputs input_for_gating *is* `input`, so an in-place `*=` mutated
    # the caller's tensor, changed what experts and the residual receive, and fails on leaves requiring grad.
    jitter_noise = torch.empty_like(input_for_gating).uniform_(1 - self.jitter_eps, 1 + self.jitter_eps)
    input_for_gating = input_for_gating * jitter_noise

    # Compute scores, find most appropriate experts with beam search
    grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)

    # Grid dropout: an index stays routable when its uniform draw is below self.grid_dropout;
    # dropped indices get a -inf score
    grid_dropout_masks = (
        (torch.rand(size=(dim_size,), dtype=input_for_gating.dtype, device=input_for_gating.device)
         < self.grid_dropout)
        for dim_size in self.beam_search.grid_size
    )
    grid_scores_dropout = [
        torch.where(dropout_mask, grid_score,
                    torch.full((1,), float('-inf'), device=grid_score.device, dtype=grid_score.dtype))
        for grid_score, dropout_mask in zip(grid_scores, grid_dropout_masks)
    ]
    grid_softmax = [torch.softmax(grid_score, dim=-1) for grid_score in grid_scores_dropout]

    chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
        [scores.detach().cpu() for scores in grid_scores_dropout], self.k_best)

    if self._expert_info is None:
        try:
            self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
        except StopIteration:
            raise RuntimeError(
                "No responding experts found during beam search. Check that UID prefixes and "
                "the grid size are consistent with running Server instances."
            ) from None
        except grpc.RpcError as e:
            logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")

    expert_mask, *expert_outputs = _RemoteCallMany.apply(
        DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
        self.backward_timeout, self.detect_anomalies, self.allow_zero_outputs, self.info,
        *nested_flatten(((input, *args), kwargs)))
    # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

    # exponential moving average of per-grid utilization
    batch_utilization = self._compute_batch_utilization(chosen_experts, expert_mask)
    self.grid_utilization = \
        self.utilization_alpha * self.grid_utilization + (1 - self.utilization_alpha) * batch_utilization

    # compute expert probabilities as product across grid dimensions; non-responding experts get 0
    expert_probs = self.compute_expert_scores(grid_softmax, chosen_experts)
    masked_probs = torch.zeros((1,), device=expert_probs.device, dtype=expert_probs.dtype)
    expert_probs = torch.where(expert_mask, expert_probs, masked_probs)

    # multiply outputs by expert probabilities along first 2 axes and sum over experts
    averaged_outputs_flat = [
        (expert_probs[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
        for tensor in expert_outputs
    ]
    packed_outputs = nested_pack(averaged_outputs_flat, self.info['outputs_schema'])

    # Load balancing loss: multiply fractions of probability mass and fractions of routed examples
    # for each grid dimension, sum across all indices for a dimension. Optimizing this leads to uniform allocation
    balancing_loss = torch.stack([
        torch.mean(dim_softmax.mean(0) * dim_utilization) * (dim_size ** 2)
        for dim_softmax, dim_utilization, dim_size in zip(
            grid_softmax, self.grid_utilization, self.beam_search.grid_size)
    ]).sum()

    # residual connection
    if isinstance(packed_outputs, torch.Tensor):
        packed_outputs = packed_outputs + input
    elif isinstance(packed_outputs, tuple):
        # tuples are immutable: item assignment would raise TypeError, so rebuild instead
        packed_outputs = (packed_outputs[0] + input, *packed_outputs[1:])
    else:
        packed_outputs[0] = packed_outputs[0] + input

    return packed_outputs, balancing_loss