示例#1
0
    def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        """Send output gradients to every (sample, expert) pair saved by forward and sum the returned input grads.

        :param ctx: autograd context holding ``alive_ii``, ``alive_jj`` and the CPU copies of the flat inputs in
            ``saved_tensors``, plus the non-tensor settings in ``_saved_non_tensors``
        :param raw_grads: gradients w.r.t. the forward outputs; the first one (``dummy_grad_mask``) is only used for
            its shape ``(num_samples, max_experts)``, the rest are sent to the experts
        :returns: ``DUMMY`` for the first forward argument, ``None`` for each following non-tensor argument, then
            one gradient per flat input, summed over the expert dimension and placed on the device of the first
            output gradient
        :raises ValueError: if ``detect_anomalies`` is set and an output gradient contains nan/inf
        :raises TimeoutError: if ``cls._collect_responses`` reports no surviving experts
        """
        assert not torch.is_grad_enabled()
        (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
         detect_anomalies) = ctx._saved_non_tensors
        alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors

        dummy_grad_mask, *flat_grad_outputs = raw_grads

        flat_grad_outputs_cpu = []
        for tensor in flat_grad_outputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of gradients has nan/inf values")
            flat_grad_outputs_cpu.append(tensor.cpu())

        num_samples, max_experts = dummy_grad_mask.shape

        # inputs are indexed by sample only, output gradients by (sample, expert); split(1) yields one slice per task
        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu))
        backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))

        # dispatch tasks to all remote experts, collect responses
        pending_tasks = {}
        for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                    inputs_per_expert, grad_outputs_per_expert):
            expert = expert_per_sample[i.item()][j.item()]
            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
            inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
            tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
                                  for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
            new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
            pending_tasks[new_task] = (i, j)

        backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
            pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
        if len(backward_survivor_indices) == 0:
            raise TimeoutError("Backward pass: no alive experts responded within timeout.")

        # assemble responses; the survivor list is non-empty here (the empty case raised above),
        # so unpacking the transposed indices into two tensors always succeeds
        backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices))

        # i-th tensor concatenates the i-th input gradient of every survivor along dim 0, giving
        # [num_backward_survivors, *flat_inputs_cpu[i].shape[1:]] (required by the indexed assignment below)
        survivor_grad_inputs_stacked = (torch.cat(expert_grads) for expert_grads in zip(*survivor_grad_inputs))

        grad_inputs = []
        for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
            grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
                (num_samples, max_experts, *flat_inputs_cpu[i].shape[1:]),
                device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
            grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked

            # sum gradients from each expert; experts that did not respond contribute zeros
            grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))

        return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
示例#2
0
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        """Ask the remote expert for input gradients.

        The request carries ``ctx.saved_tensors`` followed by ``grad_outputs``, flattened in that order and
        serialized without an explicit compression argument.

        :returns: ``DUMMY`` and two ``None`` placeholders, followed by the gradients deserialized from the reply
        """
        flat_payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
        request = runtime_pb2.ExpertRequest(
            uid=ctx.uid,
            tensors=[serialize_torch_tensor(item) for item in flat_payload])
        reply = ctx.stub.backward(request)
        return (DUMMY, None, None, *map(deserialize_torch_tensor, reply.tensors))
示例#3
0
    def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
                detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
        """Send every sample to each of its remote experts and gather the replies into dense tensors.

        :param dummy: placeholder argument, not read by this method
        :param experts_per_sample: ``experts_per_sample[i]`` lists the experts that sample ``i`` is sent to
        :param k_min: passed to ``cls._collect_responses`` for the forward pass
        :param backward_k_min: stored in ``ctx`` for the backward pass
        :param timeout_after_k_min: passed to ``cls._collect_responses`` and stored for backward
        :param forward_timeout: passed to ``cls._collect_responses`` for the forward pass
        :param backward_timeout: stored in ``ctx`` for the backward pass
        :param detect_anomalies: if set, raise on nan/inf inputs; also passed to ``cls._collect_responses``
        :param info: expert info; ``info['forward_schema']`` supplies per-input compression settings
        :param flat_inputs: flat input tensors, split along dim 0 into one slice per sample
        :returns: ``(mask, *outputs)``; ``mask[i, j]`` is True iff ``(i, j)`` was returned as alive by
            ``cls._collect_responses``, and each output has shape ``[num_samples, max_experts, ...]`` with zeros
            where no response arrived. All are placed on the device of ``flat_inputs[0]``.
        :raises ValueError: if ``detect_anomalies`` is set and an input contains nan/inf
        :raises TimeoutError: if no expert responded
        """
        assert not torch.is_grad_enabled()
        num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))

        flat_inputs_cpu = []
        for tensor in flat_inputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of inputs has nan/inf values")
            flat_inputs_cpu.append(tensor.cpu())

        flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
        assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

        # the flattened schema is the same for every task: compute it once
        forward_schema = tuple(nested_flatten(info['forward_schema']))

        # dispatch tasks to all remote experts, collect responses
        pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
        for i in range(num_samples):
            experts = experts_per_sample[i]
            if not experts:
                continue
            # all experts of sample i receive the same payload, so serialize it once per sample, not once per expert
            input_tensors = [serialize_torch_tensor(tensor, proto.compression)
                             for tensor, proto in zip(flat_inputs_per_sample[i], forward_schema)]
            for j, expert in enumerate(experts):
                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
                new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
                pending_tasks[new_task] = (i, j)

        alive_grid_indices, alive_flat_outputs = cls._collect_responses(
            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
        if len(alive_grid_indices) == 0:
            raise TimeoutError("Forward pass: no alive experts responded within timeout.")

        # assemble responses (non-empty here, so unpacking the transposed indices is safe)
        alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
        mask[alive_ii, alive_jj] = True

        # i-th tensor concatenates the i-th output of every responding expert along dim 0
        alive_flat_outputs_stacked = (torch.cat(expert_outputs) for expert_outputs in zip(*alive_flat_outputs))

        outputs = []
        for response_stacked in alive_flat_outputs_stacked:
            output = torch.zeros(
                [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
                dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
            output[alive_ii, alive_jj] = response_stacked
            outputs.append(output.to(flat_inputs[0].device))

        # save alive indices and CPU inputs (plus non-tensor settings) for the backward pass
        ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
        ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
                                  detect_anomalies)
        return (mask,) + tuple(outputs)
示例#4
0
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        """Request input gradients from the remote expert, compressing each tensor per the expert's schema.

        The saved inputs and ``grad_outputs`` are flattened together and paired one-to-one with the flattened
        ``(info["forward_schema"], info["outputs_schema"])`` descriptors; each tensor is serialized with its
        descriptor's ``compression``.

        :returns: ``DUMMY`` and three ``None`` placeholders, followed by the deserialized gradients
        """
        payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
        descriptors = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))

        request = runtime_pb2.ExpertRequest(
            uid=ctx.uid,
            tensors=[serialize_torch_tensor(item, descriptor.compression)
                     for item, descriptor in zip(payload, descriptors)])
        reply = ctx.stub.backward(request)

        return (DUMMY, None, None, None, *(deserialize_torch_tensor(t) for t in reply.tensors))
示例#5
0
 async def _forward_one_expert(grid_indices: Tuple[int, ...],
                               expert: RemoteExpert,
                               inputs: Tuple[torch.Tensor]):
     """Run one remote expert's forward pass through an asyncio gRPC stub.

     :returns: ``(grid_indices, outputs)`` with the deserialized output tensors on success; if the call raises
         ``AioRpcError``, a warning is logged and ``None`` is returned instead
     """
     stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
     request = runtime_pb2.ExpertRequest(uid=expert.uid,
                                         tensors=[serialize_torch_tensor(item) for item in inputs])
     try:
         reply = await stub.forward(request)
     except grpc.experimental.aio.AioRpcError as error:
         logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")
         return None
     return grid_indices, tuple(deserialize_torch_tensor(item) for item in reply.tensors)
示例#6
0
    def forward(ctx, dummy: torch.Tensor, uid: str,
                stub: runtime_grpc.ConnectionHandlerStub,
                *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Call the remote expert's forward and return its deserialized outputs.

        ``inputs`` are the flattened input tensors that follow the expert's info['input_schema']. They are
        detached (to avoid pickling the computation graph) and saved for backward, together with ``uid`` and
        ``stub`` on ``ctx``. ``dummy`` is not read here.
        """
        detached_inputs = tuple(tensor.detach() for tensor in inputs)
        ctx.uid = uid
        ctx.stub = stub
        ctx.save_for_backward(*detached_inputs)

        request = runtime_pb2.ExpertRequest(
            uid=ctx.uid,
            tensors=[serialize_torch_tensor(tensor) for tensor in detached_inputs])
        reply = stub.forward(request)

        return tuple(deserialize_torch_tensor(tensor) for tensor in reply.tensors)
示例#7
0
    def forward(ctx, dummy: torch.Tensor, uid: str,
                stub: runtime_grpc.ConnectionHandlerStub, info: Dict[str, Any],
                *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Call the remote expert's forward, compressing each input per ``info["forward_schema"]``.

        ``inputs`` are flattened tensors paired one-to-one with the flattened ``info["forward_schema"]``
        descriptors. They are moved to CPU and detached before being saved for backward; ``uid``, ``stub`` and
        ``info`` are stored on ``ctx``. ``dummy`` is not read here.

        :returns: the deserialized output tensors from the expert's reply
        """
        # detach to avoid pickling the computation graph
        cpu_inputs = tuple(tensor.cpu().detach() for tensor in inputs)
        ctx.uid, ctx.stub, ctx.info = uid, stub, info
        ctx.save_for_backward(*cpu_inputs)

        descriptors = nested_flatten(info["forward_schema"])
        payload = [serialize_torch_tensor(tensor, descriptor.compression)
                   for tensor, descriptor in zip(cpu_inputs, descriptors)]
        reply = stub.forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=payload))

        return tuple(map(deserialize_torch_tensor, reply.tensors))
示例#8
0
    def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]],
                k_min: int, backward_k_min: int, timeout_after_k_min: float,
                forward_timeout: Optional[float],
                backward_timeout: Optional[float], detect_anomalies: bool,
                allow_zero_outputs: bool, info: Dict[str, Any],
                *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
        """Send every sample to each of its remote experts and gather the replies into preallocated outputs.

        :param dummy: placeholder argument, not read by this method
        :param experts_per_sample: ``experts_per_sample[i]`` lists the experts that sample ``i`` is sent to
        :param k_min: a TimeoutError is raised if fewer experts responded; also passed to ``cls._collect_responses``
        :param backward_k_min: stored in ``ctx`` for the backward pass
        :param timeout_after_k_min: passed to ``cls._collect_responses`` and stored for backward
        :param forward_timeout: passed to ``cls._collect_responses`` for the forward pass
        :param backward_timeout: stored in ``ctx`` for the backward pass
        :param detect_anomalies: if set, raise on nan/inf inputs; also passed to ``cls._collect_responses``
        :param allow_zero_outputs: if False, raise RuntimeError when no expert responded; if True, return the
            zero-filled outputs and an all-False mask in that case
        :param info: expert info; ``info['forward_schema']`` gives per-input compression, ``info['outputs_schema']``
            (a descriptor or a tuple of them) is used to allocate the outputs via ``make_empty``
        :param flat_inputs: flat input tensors, split along dim 0 into one slice per sample
        :returns: ``(mask, *outputs)``; ``mask[i, j]`` is True iff ``(i, j)`` was returned by
            ``cls._collect_responses``, and outputs are zero wherever no response arrived
        :raises ValueError: if ``detect_anomalies`` is set and an input contains nan/inf
        :raises TimeoutError: if fewer than ``k_min`` experts responded
        :raises RuntimeError: if no expert responded and ``allow_zero_outputs`` is False
        """
        assert not torch.is_grad_enabled()
        num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))

        flat_inputs_cpu = []
        for tensor in flat_inputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of inputs has nan/inf values")
            flat_inputs_cpu.append(tensor.cpu())

        flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
        assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
        device = flat_inputs[0].device

        # the flattened schema is the same for every task: compute it once
        forward_schema = tuple(nested_flatten(info['forward_schema']))

        # dispatch tasks to all remote experts, collect responses
        pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
        for i in range(num_samples):
            experts = experts_per_sample[i]
            if not experts:
                continue
            # all experts of sample i receive the same payload, so serialize it once per sample, not once per expert
            input_tensors = [serialize_torch_tensor(tensor, proto.compression)
                             for tensor, proto in zip(flat_inputs_per_sample[i], forward_schema)]
            for j, expert in enumerate(experts):
                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
                new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
                pending_tasks[new_task] = (i, j)

        responded_inds, alive_flat_outputs = cls._collect_responses(
            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
        if len(responded_inds) < k_min:
            raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout.")

        outputs_schema = info['outputs_schema']
        if not isinstance(outputs_schema, tuple):
            outputs_schema = (outputs_schema,)
        outputs = nested_map(
            lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=device).zero_(),
            outputs_schema)

        if len(responded_inds) == 0 and not allow_zero_outputs:
            raise RuntimeError(
                'Forward pass: 0 experts responded within timeout and allow_zero_outputs is False'
            )

        # assemble responses; with zero responders both index tensors are empty and nothing is written
        batch_inds, expert_inds = map(
            lambda x: torch.as_tensor(x, device=device, dtype=torch.long),
            list(zip(*responded_inds)) or ([], []))

        # i-th tensor concatenates the i-th output of every responding expert along dim 0
        alive_flat_outputs_stacked = (torch.cat(expert_outputs) for expert_outputs in zip(*alive_flat_outputs))
        for output, response_stacked in zip(outputs, alive_flat_outputs_stacked):
            output[batch_inds, expert_inds] = response_stacked.to(output.device)

        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=device)
        mask[batch_inds, expert_inds] = True

        # save responder indices and CPU inputs (plus non-tensor settings) for the backward pass
        ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
        ctx._saved_non_tensors = (info, backward_k_min, backward_timeout,
                                  timeout_after_k_min, experts_per_sample,
                                  detect_anomalies)

        return (mask, ) + outputs