Example #1
    def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
        To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``.

        Subclassing:
           This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;

           It should return gradients w.r.t. inputs that follow ``nested_flatten(self.forward_schema)``;

           Runtime doesn't guarantee that backward will be performed in the same order and for the same data
           as forward, so we recommend a stateless backward pass that re-runs the expert's forward pass inside backward.

           .. todo correct state handling (see forward)

           Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train
        """
        (args,
         kwargs), grad_outputs = nested_pack(inputs,
                                             structure=self.backward_schema)

        with torch.enable_grad():
            args = [
                tensor.detach().requires_grad_(True)
                if tensor.dtype in (torch.half, torch.float,
                                    torch.double) else tensor.detach()
                for tensor in args
            ]
            kwargs = {
                input_key:
                (tensor.detach().requires_grad_(True)
                 if tensor.dtype in (torch.half, torch.float,
                                     torch.double) else tensor.detach())
                for input_key, tensor in kwargs.items()
            }

            outputs = self.expert(*args, **kwargs)
            assert nested_compare(
                outputs, grad_outputs
            ), "outputs and grad_outputs must have the same structure"

            outputs_flat = tuple(nested_flatten(outputs))

            grad_outputs_flat = tuple(
                map(
                    lambda grad, out: grad.to(
                        device=out.device, dtype=out.dtype, non_blocking=True),
                    nested_flatten(grad_outputs), outputs_flat))
            torch.autograd.backward(outputs_flat,
                                    grad_tensors=grad_outputs_flat,
                                    create_graph=False,
                                    retain_graph=False)
            self.apply_gradients()

        return tuple(
            x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x)
            for x in nested_flatten((args, kwargs)))
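The docstring recommends a stateless backward that re-runs the expert's forward pass. Below is a minimal, self-contained sketch of that pattern in plain PyTorch; the toy expert, shapes, and the helper name stateless_backward are illustrative assumptions, not part of the code above.

import torch

def stateless_backward(expert: torch.nn.Module, inputs: tuple, grad_outputs: tuple) -> tuple:
    # Detach floating-point inputs and require grad, so re-running forward
    # builds a fresh graph rooted at them (no state carried over from forward).
    inputs = tuple(t.detach().requires_grad_(True) if t.is_floating_point() else t.detach()
                   for t in inputs)
    with torch.enable_grad():
        outputs = expert(*inputs)
        outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
        torch.autograd.backward(outputs, grad_tensors=tuple(grad_outputs))
    # Parameter gradients have now accumulated on the expert (hence apply_gradients above);
    # non-differentiable inputs get zero gradients, mirroring the return statement above.
    return tuple(t.grad if t.grad is not None else torch.zeros_like(t) for t in inputs)

# illustrative usage with a toy expert
expert = torch.nn.Linear(4, 4)
(input_grad,) = stateless_backward(expert, (torch.randn(2, 4),), (torch.ones(2, 4),))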
Example #2
    def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        assert not torch.is_grad_enabled()
        (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
         detect_anomalies) = ctx._saved_non_tensors
        alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors

        dummy_grad_mask, *flat_grad_outputs = raw_grads

        flat_grad_outputs_cpu = []
        for tensor in flat_grad_outputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of gradients has nan/inf values")
            flat_grad_outputs_cpu.append(tensor.cpu())

        num_samples, max_experts = dummy_grad_mask.shape

        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu))
        backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))

        # dispatch tasks to all remote experts, collect responses
        pending_tasks = {}
        for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                    inputs_per_expert, grad_outputs_per_expert):
            expert = expert_per_sample[i.item()][j.item()]
            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
            inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
            tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
                                  for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
            new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
            pending_tasks[new_task] = (i, j)

        backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
            pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
        if len(backward_survivor_indices) == 0:
            raise TimeoutError("Backward pass: no alive experts responded within timeout.")

        # assemble responses
        backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices) or ([], []))

        survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
        # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]

        grad_inputs = []
        for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
            grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
                (num_samples, max_experts, *flat_inputs_cpu[i].shape[1:]),
                device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
            grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked

            # sum gradients from each expert
            grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))

        return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
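The gradient assembly above is a scatter-and-sum: each surviving expert's gradient is written into a zero tensor indexed by (sample, expert slot) and then reduced over the expert dimension. A toy, self-contained illustration of that indexing; all shapes and values here are invented.

import torch

num_samples, max_experts, dim = 3, 2, 4
survivor_ii = torch.tensor([0, 2])     # sample indices of surviving responses
survivor_jj = torch.tensor([1, 0])     # expert-slot indices of surviving responses
survivor_grads = torch.randn(2, dim)   # stacked gradients, one row per survivor

grad_per_expert = torch.zeros(num_samples, max_experts, dim)
grad_per_expert[survivor_ii, survivor_jj] = survivor_grads  # scatter into (sample, slot) cells
grad_input = grad_per_expert.sum(dim=1)                     # sum contributions per sample
assert grad_input.shape == (num_samples, dim)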
Example #3
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        inputs_and_grad_outputs = tuple(
            nested_flatten((ctx.saved_tensors, grad_outputs)))
        backward_schema = tuple(
            nested_flatten(
                (ctx.info["forward_schema"], ctx.info["outputs_schema"])))
        serialized_tensors = [
            serialize_torch_tensor(tensor, proto.compression)
            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
        ]

        grad_inputs = ctx.stub.backward(
            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))

        deserialized_grad_inputs = [
            deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors
        ]
        return (DUMMY, None, None, None, *deserialized_grad_inputs)
Example #4
    def forward(self, input: torch.Tensor, *args: torch.Tensor,
                **kwargs: torch.Tensor):
        """
        Choose k best experts with beam search, then call chosen experts and average their outputs.
        Input tensor is averaged over all dimensions except for the first and last
        (we assume that extra dimensions represent sequence length or image height/width)

        :param input: a tensor of values that are used to estimate gating function, batch-first.
        :param args: extra positional parameters that will be passed to each expert after input, batch-first
        :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
        :returns: averaged predictions of all experts that delivered a result on time, as a nested structure of batch-first tensors
        """
        if input.ndim != 2:
            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
        else:
            input_for_gating = input

        # 1. compute scores and find most appropriate experts with beam search
        grid_scores = self.proj(input_for_gating).split_with_sizes(
            self.grid_size, dim=-1)

        chosen_experts: List[
            List[RemoteExpert]] = self.dht.batch_find_best_experts(
                self.uid_prefix,
                [scores.detach().cpu().numpy() for scores in grid_scores],
                self.k_best, **self.dht_kwargs)

        if self._expert_info is None:
            try:
                self._expert_info = next((expert.info
                                          for experts_i in chosen_experts
                                          for expert in experts_i))
            except grpc.RpcError as e:
                logger.warning(
                    f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")

        expert_mask, *expert_outputs = _RemoteCallMany.apply(
            DUMMY, chosen_experts, self.k_min, self.backward_k_min,
            self.timeout_after_k_min, self.forward_timeout,
            self.backward_timeout, self.info,
            *nested_flatten(((input, *args), kwargs)))
        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
        masked_logits = torch.full((1, ),
                                   float('-inf'),
                                   device=expert_logits.device,
                                   dtype=expert_logits.dtype)
        expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
        expert_weights = torch.softmax(expert_logits, dim=1)
        averaged_outputs_flat = [
            (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(
                tensor.shape).sum(dim=1) for tensor in expert_outputs
        ]  # ^-- multiply by softmax weights along first 2 axes
        return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
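The mixing step above masks the logits of experts that never responded, softmaxes over the expert dimension, and takes a weighted sum of the stacked outputs. A compact stand-alone sketch of that arithmetic with invented shapes; the flatten/view handling of extra output dimensions is omitted for brevity.

import torch

batch, max_experts, hidden = 2, 3, 5
expert_outputs = torch.randn(batch, max_experts, hidden)  # stacked per-expert outputs
expert_logits = torch.randn(batch, max_experts)
expert_mask = torch.tensor([[True, True, False],
                            [True, False, False]])        # which experts responded

# dead experts get -inf so softmax assigns them zero weight
masked_logits = torch.where(expert_mask, expert_logits,
                            torch.full_like(expert_logits, float('-inf')))
expert_weights = torch.softmax(masked_logits, dim=1)                 # [batch, max_experts]
averaged = (expert_weights[..., None] * expert_outputs).sum(dim=1)   # [batch, hidden]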
Example #5
def dump_optimizer_state(optimizer: torch.optim.Optimizer):
    with torch.no_grad():
        flat_metadata, flat_tensors = [], []
        for elem in nested_flatten(optimizer.state_dict()):
            if isinstance(elem, torch.Tensor):
                flat_metadata.append(
                    dict(type='tensor', index=len(flat_tensors)))
                flat_tensors.append(elem.cpu())
            else:
                flat_metadata.append(dict(type='value', value=elem))
        return flat_metadata, flat_tensors
Example #6
def dump_optimizer_state(opt: torch.optim.Optimizer):
    """ Convert optimizer state into a format of DecentralizedAverager's get_current_state/load_state_from_peers """
    with torch.no_grad():
        flat_metadata, flat_tensors = [], []
        for elem in nested_flatten(opt.state_dict()):
            if isinstance(elem, torch.Tensor):
                flat_metadata.append(
                    dict(type='tensor', index=len(flat_tensors)))
                flat_tensors.append(elem.cpu())
            else:
                flat_metadata.append(dict(type='value', value=elem))
        return flat_metadata, flat_tensors
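A possible inverse of this dump, rebuilding the flat state sequence from the metadata and tensor lists. This is only a sketch, not the library's loading routine; restoring a real state_dict would additionally require packing the flat sequence back into its nested structure (e.g. with hivemind's nested_pack).

import torch

def restore_flat_state(flat_metadata, flat_tensors):
    # Hypothetical helper: rebuild the flat sequence produced by dump_optimizer_state
    flat_state = []
    for meta in flat_metadata:
        if meta['type'] == 'tensor':
            flat_state.append(flat_tensors[meta['index']])
        else:  # meta['type'] == 'value'
            flat_state.append(meta['value'])
    return flat_state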
Example #7
    def forward(self, input: torch.Tensor, *args: torch.Tensor,
                **kwargs: torch.Tensor):
        """
        Choose k best experts with beam search, then call chosen experts and average their outputs. The input tensor is averaged over all
        dimensions except the first and last (we assume that extra dimensions represent sequence length or image dimensions)

        :param input: a tensor of values that are used to estimate gating function, batch-first.
        :param args: extra positional parameters that will be passed to each expert after input, batch-first
        :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
        :returns: averaged predictions of all experts that delivered a result on time, as a nested structure of batch-first tensors
        """
        if input.ndim != 2:
            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
        else:
            input_for_gating = input

        # 1. compute scores and find most appropriate experts with beam search
        grid_scores = self.proj(input_for_gating).split_with_sizes(
            self.grid_size, dim=-1)

        async def _search():
            coroutines = [
                asyncio.create_task(
                    self.beam_search(
                        [dim_scores[i]
                         for dim_scores in grid_scores], self.k_best))
                for i in range(len(input))
            ]
            return list(await asyncio.gather(*coroutines))

        chosen_experts: List[
            List[RemoteExpert]] = self.loop.run_until_complete(_search())
        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch

        expert_mask, *expert_outputs = _RemoteCallMany.apply(
            DUMMY, chosen_experts, self.k_min, self.backward_k_min,
            self.timeout_after_k_min, self.forward_timeout,
            self.backward_timeout, self.loop,
            *nested_flatten(((input, *args), kwargs)))
        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
        masked_logits = torch.full((1, ),
                                   float('-inf'),
                                   device=expert_logits.device,
                                   dtype=expert_logits.dtype)
        expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
        expert_weights = torch.softmax(expert_logits, dim=1)
        averaged_outputs_flat = [
            (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(
                tensor.shape).sum(dim=1) for tensor in expert_outputs
        ]  # ^-- multiply by softmax weights along first 2 axes
        return nested_pack(averaged_outputs_flat, self.outputs_schema)
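The _search helper above launches one beam-search coroutine per sample and gathers them concurrently. A minimal stand-alone illustration of that fan-out pattern, where find_best is a made-up placeholder rather than the real beam search.

import asyncio

async def find_best(scores, k):
    # stand-in for the per-sample beam search over grid scores (hypothetical helper)
    return sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k]

async def batch_search(per_sample_scores, k):
    # one task per sample, gathered concurrently, mirroring _search above
    tasks = [asyncio.create_task(find_best(scores, k)) for scores in per_sample_scores]
    return list(await asyncio.gather(*tasks))

results = asyncio.run(batch_search([[0.1, 0.9, 0.3], [0.5, 0.2, 0.8]], k=2))
# results holds one list of top-k indices per sample, e.g. [[1, 2], [2, 0]]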
Example #8
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))

        grad_inputs = ctx.stub.backward(
            runtime_pb2.ExpertRequest(
                uid=ctx.uid,
                tensors=[serialize_torch_tensor(tensor)
                         for tensor in payload]))

        deserialized_grad_inputs = [
            deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors
        ]
        return (DUMMY, None, None, *deserialized_grad_inputs)
Example #9
    def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
                detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
        assert not torch.is_grad_enabled()
        num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))

        flat_inputs_cpu = []
        for tensor in flat_inputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of inputs has nan/inf values")
            flat_inputs_cpu.append(tensor.cpu())

        flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
        assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

        # dispatch tasks to all remote experts, collect responses
        pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
        for i in range(num_samples):
            for j, expert in enumerate(experts_per_sample[i]):
                input_tensors = [serialize_torch_tensor(tensor, proto.compression) for tensor, proto in zip(
                    flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
                new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
                pending_tasks[new_task] = (i, j)

        alive_grid_indices, alive_flat_outputs = cls._collect_responses(
            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
        if len(alive_grid_indices) == 0:
            raise TimeoutError("Forward pass: no alive experts responded within timeout.")

        # assemble responses
        alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
        mask[alive_ii, alive_jj] = True

        alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
        # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]

        outputs = []
        for response_stacked in alive_flat_outputs_stacked:
            output = torch.zeros(
                [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
                dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
            output[alive_ii, alive_jj] = response_stacked
            outputs.append(output.to(flat_inputs[0].device))

        # save individual outputs for backward pass
        ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
        ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
                                  detect_anomalies)
        return (mask,) + tuple(outputs)
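The per-sample regrouping above (flat_inputs_per_sample) relies on a split/zip idiom: a tuple of batch-first tensors becomes a list with one tuple of single-row tensors per sample. A tiny self-contained demonstration with invented shapes:

import torch

flat_inputs = (torch.arange(6).reshape(3, 2), torch.zeros(3, 1))  # two tensors, batch of 3
per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs)))
# per_sample[i] is a tuple holding the i-th row of every input, keeping a batch dim of size 1
assert len(per_sample) == 3
assert per_sample[0][0].shape == (1, 2) and per_sample[0][1].shape == (1, 1)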
Example #10
 async def backward(self, request: runtime_pb2.ExpertRequest,
                    context: grpc.ServicerContext):
     inputs_and_grad_outputs = [
         deserialize_torch_tensor(tensor) for tensor in request.tensors
     ]
     future = self.experts[request.uid].backward_pool.submit_task(
         *inputs_and_grad_outputs)
     serialized_response = [
         serialize_torch_tensor(tensor,
                                proto.compression,
                                allow_inplace=True)
         for tensor, proto in zip(
             await future,
             nested_flatten(self.experts[request.uid].grad_inputs_schema))
     ]
     return runtime_pb2.ExpertResponse(tensors=serialized_response)
Example #11
    def forward(self, *args, **kwargs):
        """ Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd. """
        assert len(kwargs) == len(
            self.info['keyword_names']
        ), f"Keyword args should be {self.info['keyword_names']}"
        kwargs = {key: kwargs[key] for key in self.info['keyword_names']}

        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors

        forward_inputs = (args, kwargs)

        if not nested_compare(forward_inputs, self.info['forward_schema']):
            raise TypeError(
                "Inputs do not match expert input schema. Did you pass the right number of parameters?"
            )

        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub,
                                               *nested_flatten(forward_inputs))
        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
        return nested_pack(flat_outputs, structure=self.info['outputs_schema'])
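The DUMMY argument mentioned in the note keeps the call inside autograd even when none of the real inputs require grad. A stripped-down, self-contained illustration of that trick with a local stand-in function (not the library's _RemoteModuleCall):

import torch

DUMMY = torch.empty(0, requires_grad=True)  # placeholder that always requires grad

class _CallWithDummy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy, x):
        # pretend this is a remote call; only the result matters for the sketch
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        # reached even when x.requires_grad is False, because dummy requires grad
        return torch.zeros(0), grad_output * 2.0

x = torch.randn(3)                  # does NOT require grad
y = _CallWithDummy.apply(DUMMY, x)
assert y.requires_grad              # autograd keeps the op alive through DUMMY
y.sum().backward()                  # the custom backward above is actually invoked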
Example #12
    def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
        To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.

        Subclassing:
           This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;

           It should return outputs that follow ``nested_flatten(self.outputs_schema)``;

           .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
           .. For now, either register all buffers as outputs or avoid stateful experts

        """
        args, kwargs = nested_pack(inputs, structure=self.forward_schema)

        with torch.no_grad():
            outputs = self.expert(*args, **kwargs)

        # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
        return tuple(nested_flatten(outputs))
Example #13
 async def _backward_one_expert(grid_indices: Tuple[int, ...],
                                expert: RemoteExpert,
                                inputs: Tuple[torch.Tensor],
                                grad_outputs: Tuple[torch.Tensor]):
     stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(
         expert.endpoint, aio=True)
     payload = tuple(nested_flatten((inputs, grad_outputs)))
     try:
         grad_inputs = await stub.backward(
             runtime_pb2.ExpertRequest(uid=expert.uid,
                                       tensors=[
                                           serialize_torch_tensor(tensor)
                                           for tensor in payload
                                       ]))
         return grid_indices, tuple(
             deserialize_torch_tensor(tensor)
             for tensor in grad_inputs.tensors)
     except grpc.experimental.aio.AioRpcError as error:
         logger.warning(
             f"RemoteExpert {expert} failed backward: {error.code()} ({inputs}, {grad_outputs})"
         )
Example #14
    def forward(ctx, dummy: torch.Tensor, uid: str,
                stub: runtime_grpc.ConnectionHandlerStub, info: Dict[str, Any],
                *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
        # detach to avoid pickling the computation graph
        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
        ctx.uid, ctx.stub, ctx.info = uid, stub, info
        ctx.save_for_backward(*inputs)

        serialized_tensors = [
            serialize_torch_tensor(inp, proto.compression) for inp, proto in
            zip(inputs, nested_flatten(info["forward_schema"]))
        ]

        outputs = stub.forward(
            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))

        deserialized_outputs = [
            deserialize_torch_tensor(tensor) for tensor in outputs.tensors
        ]

        return tuple(deserialized_outputs)
Example #15
    def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]],
                k_min: int, backward_k_min: int, timeout_after_k_min: float,
                forward_timeout: Optional[float],
                backward_timeout: Optional[float], detect_anomalies: bool,
                allow_zero_outputs: bool, info: Dict[str, Any],
                *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
        assert not torch.is_grad_enabled()
        num_samples, max_experts = len(experts_per_sample), max(
            map(len, experts_per_sample))

        flat_inputs_cpu = []
        for tensor in flat_inputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of inputs has nan/inf values")
            flat_inputs_cpu.append(tensor.cpu())

        flat_inputs_per_sample = list(
            zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
        assert len(experts_per_sample) == len(
            flat_inputs_per_sample) == num_samples

        # dispatch tasks to all remote experts, collect responses
        pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
        for i in range(num_samples):
            for j, expert in enumerate(experts_per_sample[i]):
                input_tensors = [
                    serialize_torch_tensor(tensor, proto.compression)
                    for tensor, proto in zip(
                        flat_inputs_per_sample[i],
                        nested_flatten(info['forward_schema']))
                ]
                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(
                    expert.endpoint)
                new_task = stub.forward.future(
                    runtime_pb2.ExpertRequest(uid=expert.uid,
                                              tensors=input_tensors))
                pending_tasks[new_task] = (i, j)

        responded_inds, alive_flat_outputs = cls._collect_responses(
            pending_tasks, num_samples, k_min, forward_timeout,
            timeout_after_k_min, detect_anomalies)
        if len(responded_inds) < k_min:
            raise TimeoutError(
                f"Forward pass: less than {k_min} responded within timeout.")

        if not isinstance(info['outputs_schema'], tuple):
            outputs_schema = (info['outputs_schema'], )
        else:
            outputs_schema = info['outputs_schema']
        outputs = nested_map(
            lambda descriptor: descriptor.make_empty(
                num_samples, max_experts, device=flat_inputs[0].device).zero_(
                ), outputs_schema)

        # assemble responses
        if len(responded_inds) > 0 or allow_zero_outputs:
            batch_inds, expert_inds = map(
                lambda x: torch.as_tensor(
                    x, device=flat_inputs[0].device, dtype=torch.long),
                list(zip(*responded_inds)) or ([], []))

            alive_flat_outputs_stacked = (torch.cat(outputs)
                                          for outputs in zip(
                                              *alive_flat_outputs))
            # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]

            for output, response_stacked in zip(outputs,
                                                alive_flat_outputs_stacked):
                output[batch_inds,
                       expert_inds] = response_stacked.to(output.device)

        else:
            raise RuntimeError(
                'Forward pass: 0 experts responded within timeout and allow_zero_outputs is False'
            )

        mask = torch.zeros([num_samples, max_experts],
                           dtype=torch.bool,
                           device=flat_inputs[0].device)
        mask[batch_inds, expert_inds] = True

        # save individual outputs for backward pass
        ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
        ctx._saved_non_tensors = (info, backward_k_min, backward_timeout,
                                  timeout_after_k_min, experts_per_sample,
                                  detect_anomalies)

        return (mask, ) + outputs
Example #16
    def forward(self, input: torch.Tensor, *args: torch.Tensor,
                **kwargs: torch.Tensor):
        if input.ndim != 2:
            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
        else:
            input_for_gating = input

        # Multiplicative jitter for regularized routing
        jitter_noise = torch.empty_like(input_for_gating).uniform_(
            1 - self.jitter_eps, 1 + self.jitter_eps)
        input_for_gating *= jitter_noise

        # Compute scores, find most appropriate experts with beam search
        grid_scores = self.proj(input_for_gating).split_with_sizes(
            self.beam_search.grid_size, dim=-1)

        grid_dropout_masks = ((torch.rand(size=(dim_size, ),
                                          dtype=input_for_gating.dtype,
                                          device=input_for_gating.device) <
                               self.grid_dropout)
                              for dim_size in self.beam_search.grid_size)
        grid_scores_dropout = [
            torch.where(
                dropout_mask, grid_score,
                torch.full((1, ),
                           float('-inf'),
                           device=grid_score.device,
                           dtype=grid_score.dtype)) for grid_score,
            dropout_mask in zip(grid_scores, grid_dropout_masks)
        ]

        grid_softmax = [
            torch.softmax(grid_score, dim=-1)
            for grid_score in grid_scores_dropout
        ]
        chosen_experts: List[
            List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
                [scores.detach().cpu() for scores in grid_scores_dropout],
                self.k_best)

        if self._expert_info is None:
            try:
                self._expert_info = next((expert.info
                                          for experts_i in chosen_experts
                                          for expert in experts_i))
            except StopIteration:
                raise RuntimeError(
                    "No responding experts found during beam search. Check that UID prefixes and "
                    "the grid size are consistent with running Server instances."
                )
            except grpc.RpcError as e:
                logger.warning(
                    f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}"
                )

        expert_mask, *expert_outputs = _RemoteCallMany.apply(
            DUMMY, chosen_experts, self.k_min, self.backward_k_min,
            self.timeout_after_k_min, self.forward_timeout,
            self.backward_timeout, self.detect_anomalies,
            self.allow_zero_outputs, self.info,
            *nested_flatten(((input, *args), kwargs)))
        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

        batch_utilization = self._compute_batch_utilization(
            chosen_experts, expert_mask)
        self.grid_utilization = \
            self.utilization_alpha * self.grid_utilization + (1 - self.utilization_alpha) * batch_utilization

        # compute expert probabilities as product across grid dimensions
        expert_probs = self.compute_expert_scores(grid_softmax, chosen_experts)
        masked_probs = torch.zeros((1, ),
                                   device=expert_probs.device,
                                   dtype=expert_probs.dtype)
        expert_probs = torch.where(expert_mask, expert_probs, masked_probs)

        # multiply outputs by expert probabilities
        averaged_outputs_flat = [
            (expert_probs[..., None] * tensor.flatten(start_dim=2)).view(
                tensor.shape).sum(dim=1) for tensor in expert_outputs
        ]  # ^-- multiply by softmax weights along first 2 axes

        packed_outputs = nested_pack(averaged_outputs_flat,
                                     self.info['outputs_schema'])

        # Load balancing loss: multiply fractions of probability mass and fractions of routed examples
        # for each grid dimension, sum across all indices for a dimension. Optimizing this leads to uniform allocation
        balancing_loss = torch.stack([
            torch.mean(dim_softmax.mean(0) * dim_utilization) * (dim_size**2)
            for dim_softmax, dim_utilization, dim_size in zip(
                grid_softmax, self.grid_utilization,
                self.beam_search.grid_size)
        ]).sum()

        # residual connection
        if isinstance(packed_outputs, torch.Tensor):
            packed_outputs = packed_outputs + input
        else:
            packed_outputs[0] = packed_outputs[0] + input

        return packed_outputs, balancing_loss
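The balancing_loss above pairs, per grid dimension, the mean routing probability of each index with the fraction of samples routed to it, scaled by the squared dimension size. A self-contained numeric sketch of that formula for a single dimension (values invented); per the comment above, optimizing this term encourages uniform allocation across indices.

import torch

# one grid dimension with 4 indices, batch of 8 samples
dim_size = 4
dim_softmax = torch.softmax(torch.randn(8, dim_size), dim=-1)   # routing probabilities
dim_utilization = torch.tensor([0.5, 0.25, 0.125, 0.125])       # fraction of samples per index

balancing_loss = torch.mean(dim_softmax.mean(0) * dim_utilization) * dim_size ** 2
# in the module above, one such term is computed per grid dimension and the terms are summed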