# Example 1
def test_receptor_neuron_server_response_with_nans():
    """A successful response containing NaNs is accepted, with NaNs scrubbed to zero."""
    # Build a mock response tensor and poison one entry with a NaN.
    response_tensor = torch.rand(3, 3, bittensor.__network_dim__)
    response_tensor[0][0][0] = float('nan')

    serializer = serialization.get_serializer(
        serialzer_type=bittensor.proto.Serializer.MSGPACK)
    serialized_response = serializer.serialize(
        response_tensor,
        modality=bittensor.proto.Modality.TENSOR,
        from_type=bittensor.proto.TensorType.TORCH)

    # Mock a successful Forward RPC that returns the NaN-poisoned tensor.
    mock_return_val = bittensor.proto.TensorMessage(
        version=bittensor.__version__,
        public_key=wallet.hotkey.public_key,
        return_code=bittensor.proto.ReturnCode.Success,
        tensors=[serialized_response])

    stub.Forward = MagicMock(return_value=mock_return_val)
    receptor.stub = stub

    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    out, ops = receptor.forward(inputs, bittensor.proto.Modality.TENSOR)
    assert ops.item() == bittensor.proto.ReturnCode.Success
    # The receptor is expected to replace NaNs with 0.0.
    assert out[0][0][0] == 0
# Example 2
    def test_serialize_deserialize_image(self):
        """MSGPACK round-trip of a real MNIST image preserves grad flag, shape, modality and values.

        NOTE(review): a method with this exact name appears again just below;
        if both live in the same class, this earlier definition is shadowed at
        class-creation time and never collected by pytest — confirm and rename.
        """
        # Let's grab some image data
        data = torchvision.datasets.MNIST(root='data/datasets/',
                                          train=True,
                                          download=True,
                                          transform=transforms.ToTensor())

        # Let's grab a random image, and give it a crazy type to break the system
        image = data[randrange(len(data))][0]

        serializer = serialization.get_serializer(
            serialzer_type=bittensor.proto.Serializer.MSGPACK)
        serialized_image_tensor_message = serializer.serialize(
            image,
            modality=bittensor.proto.Modality.IMAGE,
            from_type=bittensor.proto.TensorType.TORCH)

        # Wire message must mirror the source tensor's metadata.
        assert image.requires_grad == serialized_image_tensor_message.requires_grad
        assert list(image.shape) == serialized_image_tensor_message.shape
        assert serialized_image_tensor_message.modality == bittensor.proto.Modality.IMAGE
        assert serialized_image_tensor_message.dtype != bittensor.proto.DataType.UNKNOWN

        deserialized_image_tensor_message = serializer.deserialize(
            serialized_image_tensor_message,
            to_type=bittensor.proto.TensorType.TORCH)
        # Round-trip must preserve grad flag, shape, a known dtype, and values.
        assert serialized_image_tensor_message.requires_grad == deserialized_image_tensor_message.requires_grad
        assert serialized_image_tensor_message.shape == list(
            deserialized_image_tensor_message.shape)
        assert serialization.torch_dtype_to_bittensor_dtype(
            deserialized_image_tensor_message.dtype
        ) != bittensor.proto.DataType.UNKNOWN

        assert torch.all(torch.eq(deserialized_image_tensor_message, image))
    def test_serialize_deserialize_image(self):
        """MSGPACK round-trip of a synthetic image tensor preserves grad flag, shape, modality and values.

        NOTE(review): this reuses the name of the method defined just above;
        if both are in the same class, this later definition shadows the
        earlier one — confirm and rename one of them.
        """
        # Let's grab some image data
        # Let's grab a random image, and give it a crazy type to break the system
        image = torch.ones([1, 28, 28])

        serializer = serialization.get_serializer(
            serialzer_type=bittensor.proto.Serializer.MSGPACK)
        serialized_image_tensor_message = serializer.serialize(
            image,
            modality=bittensor.proto.Modality.IMAGE,
            from_type=bittensor.proto.TensorType.TORCH)

        # Wire message must mirror the source tensor's metadata.
        assert image.requires_grad == serialized_image_tensor_message.requires_grad
        assert list(image.shape) == serialized_image_tensor_message.shape
        assert serialized_image_tensor_message.modality == bittensor.proto.Modality.IMAGE
        assert serialized_image_tensor_message.dtype != bittensor.proto.DataType.UNKNOWN

        deserialized_image_tensor_message = serializer.deserialize(
            serialized_image_tensor_message,
            to_type=bittensor.proto.TensorType.TORCH)
        # Round-trip must preserve grad flag, shape, a known dtype, and values.
        assert serialized_image_tensor_message.requires_grad == deserialized_image_tensor_message.requires_grad
        assert serialized_image_tensor_message.shape == list(
            deserialized_image_tensor_message.shape)
        assert serialization.torch_dtype_to_bittensor_dtype(
            deserialized_image_tensor_message.dtype
        ) != bittensor.proto.DataType.UNKNOWN

        assert torch.all(torch.eq(deserialized_image_tensor_message, image))
# Example 4
    def test_serialize_deserialize_tensor(self):
        """Round-trip a random float tensor through the MSGPACK serializer."""
        tensor = torch.rand([12, 23])

        serializer = serialization.get_serializer(
            serialzer_type=bittensor.proto.Serializer.MSGPACK)
        message = serializer.serialize(
            tensor,
            modality=bittensor.proto.Modality.TENSOR,
            from_type=bittensor.proto.TensorType.TORCH)

        # Serialized metadata mirrors the source tensor.
        assert tensor.requires_grad == message.requires_grad
        assert list(tensor.shape) == message.shape
        assert message.modality == bittensor.proto.Modality.TENSOR
        assert message.dtype == bittensor.proto.DataType.FLOAT32

        restored = serializer.deserialize(
            message, to_type=bittensor.proto.TensorType.TORCH)
        # Deserialized tensor matches the wire metadata and original values.
        assert message.requires_grad == restored.requires_grad
        assert message.shape == list(restored.shape)
        assert serialization.torch_dtype_to_bittensor_dtype(
            restored.dtype) == bittensor.proto.DataType.FLOAT32

        assert torch.all(torch.eq(restored, tensor))
    def test_serialize_object_type_exception(self):
        """Serializing with an unknown from_type must raise."""
        # Let's grab a random image, and try and de-serialize it incorrectly.
        image = torch.ones([1, 28, 28])

        serializer = serialization.get_serializer(
            serialzer_type=bittensor.proto.Serializer.MSGPACK)
        # 11 is not a valid bittensor.proto.TensorType member.
        with pytest.raises(
                serialization.SerializationTypeNotImplementedException):
            serializer.serialize(
                image, modality=bittensor.proto.Modality.IMAGE, from_type=11)
# Example 6
 def test_serialize(self):
     """Round-trip random tensors through the MSGPACK serializer.

     Bug fix: the original computed ``torch.all(torch.eq(tensor_a, tensor_b))``
     but discarded the result, so a corrupted round-trip could never fail the
     test. The comparison is now asserted.
     """
     for _ in range(10):
         tensor_a = torch.rand([12, 23])
         serializer = serialization.get_serializer(
             serialzer_type=bittensor.proto.Serializer.MSGPACK)
         content = serializer.serialize(
             tensor_a,
             modality=bittensor.proto.Modality.TENSOR,
             from_type=bittensor.proto.TensorType.TORCH)
         tensor_b = serializer.deserialize(
             content, to_type=bittensor.proto.TensorType.TORCH)
         assert torch.all(torch.eq(tensor_a, tensor_b))
# Example 7
    def test_deserialization_object_type_exception(self):
        """Deserializing to an unknown to_type must raise."""
        payload = torch.rand([12, 23])

        serializer = serialization.get_serializer(
            serialzer_type=bittensor.proto.Serializer.MSGPACK)
        tensor_message = serializer.serialize(
            payload,
            modality=bittensor.proto.Modality.TEXT,
            from_type=bittensor.proto.TensorType.TORCH)

        # 11 is not a valid bittensor.proto.TensorType member.
        with pytest.raises(
                serialization.SerializationTypeNotImplementedException):
            serializer.deserialize(tensor_message, to_type=11)
# Example 8
def test_single_item_backward_request():
    """A Backward request with only one tensor is rejected as InvalidRequest."""
    axon.serve(synapse)
    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    serializer = serialization.get_serializer(
        serialzer_type=bittensor_pb2.Serializer.MSGPACK)
    serialized_inputs = serializer.serialize(
        inputs,
        modality=bittensor_pb2.Modality.TENSOR,
        from_type=bittensor_pb2.TensorType.TORCH)

    # Backward expects two tensors (inputs + gradients); send only one.
    request = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=keypair.public_key,
        tensors=[serialized_inputs])
    response = axon.Backward(request, None)
    assert response.return_code == bittensor_pb2.ReturnCode.InvalidRequest
# Example 9
    def test_serialize_object_type_exception(self):
        """Serializing a real MNIST image with an invalid from_type must raise."""
        # Let's grab some image data
        data = torchvision.datasets.MNIST(root='data/datasets/',
                                          train=True,
                                          download=True,
                                          transform=transforms.ToTensor())

        # Let's grab a random image, and try and de-serialize it incorrectly.
        image = data[randrange(len(data))][0]

        serializer = serialization.get_serializer(
            serialzer_type=bittensor.proto.Serializer.MSGPACK)
        # 11 is not a valid bittensor.proto.TensorType member.
        with pytest.raises(
                serialization.SerializationTypeNotImplementedException):
            serializer.serialize(image,
                                 modality=bittensor.proto.Modality.IMAGE,
                                 from_type=11)
# Example 10
def test_receptor_neuron_serve_timeout():
    """A Timeout return code from the remote is propagated; output keeps the request shape."""
    response_tensor = torch.rand(3, 3, bittensor.__network_dim__)

    serializer = serialization.get_serializer(
        serialzer_type=bittensor.proto.Serializer.MSGPACK)
    serialized_response = serializer.serialize(
        response_tensor,
        modality=bittensor.proto.Modality.TENSOR,
        from_type=bittensor.proto.TensorType.TORCH)

    # Mock a Forward RPC that reports a Timeout.
    mock_return_val = bittensor.proto.TensorMessage(
        version=bittensor.__version__,
        public_key=wallet.keypair.public_key,
        return_code=bittensor.proto.ReturnCode.Timeout,
        tensors=[serialized_response])

    stub.Forward = MagicMock(return_value=mock_return_val)
    receptor.stub = stub

    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    out, ops = receptor.forward(inputs, bittensor.proto.Modality.TENSOR)
    assert ops.item() == bittensor.proto.ReturnCode.Timeout
    # On failure the receptor still returns a correctly-shaped tensor.
    assert list(out.shape) == [3, 3, bittensor.__network_dim__]
# Example 11
def test_forward_not_implemented():
    """When the nucleus reports NotImplemented, Axon.Forward relays that code."""
    axon.serve(synapse)
    # Force the nucleus to report a NotImplemented result.
    nucleus.forward = MagicMock(return_value=[
        None, 'not implemented', bittensor_pb2.ReturnCode.NotImplemented
    ])
    x = torch.rand(3, 3, bittensor.__network_dim__)

    serializer = serialization.get_serializer(
        serialzer_type=bittensor_pb2.Serializer.MSGPACK)
    x_serialized = serializer.serialize(
        x,
        modality=bittensor_pb2.Modality.TENSOR,
        from_type=bittensor_pb2.TensorType.TORCH)

    request = bittensor_pb2.TensorMessage(version=bittensor.__version__,
                                          public_key=keypair.public_key,
                                          tensors=[x_serialized])
    response = axon.Forward(request, None)
    assert response.return_code == bittensor_pb2.ReturnCode.NotImplemented
# Example 12
def test_backward_success():
    """A well-formed two-tensor Backward request succeeds and returns one float32 tensor."""
    axon.serve(synapse)
    x = torch.rand(3, 3, bittensor.__network_dim__)
    serializer = serialization.get_serializer(
        serialzer_type=bittensor_pb2.Serializer.MSGPACK)
    x_serialized = serializer.serialize(
        x,
        modality=bittensor_pb2.Modality.TENSOR,
        from_type=bittensor_pb2.TensorType.TORCH)

    # Backward expects exactly two tensors: inputs and output gradients.
    request = bittensor_pb2.TensorMessage(version=bittensor.__version__,
                                          public_key=keypair.public_key,
                                          tensors=[x_serialized, x_serialized])
    # Force the nucleus to report a successful backward pass.
    nucleus.backward = MagicMock(
        return_value=[x, 'success', bittensor_pb2.ReturnCode.Success])
    response = axon.Backward(request, None)

    assert response.return_code == bittensor_pb2.ReturnCode.Success
    assert len(response.tensors) == 1
    assert response.tensors[0].shape == [3, 3, bittensor.__network_dim__]
    assert serialization_utils.bittensor_dtype_to_torch_dtype(
        response.tensors[0].dtype) == torch.float32
def test_remote_neuron_mock_server_shape_error():
    """A response whose batch dim disagrees with the request yields ResponseShapeException."""
    # Server replies with batch size 1 while the request below uses batch size 3.
    y = torch.rand(1, 3, bittensor.__network_dim__)

    serializer = serialization.get_serializer(
        serialzer_type=bittensor_pb2.Serializer.MSGPACK)
    y_serialized = serializer.serialize(
        y,
        modality=bittensor_pb2.Modality.TENSOR,
        from_type=bittensor_pb2.TensorType.TORCH)

    mock_return_val = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=keypair.public_key,
        return_code=bittensor_pb2.ReturnCode.Success,
        tensors=[y_serialized])

    stub.Forward = MagicMock(return_value=mock_return_val)
    remote.stub = stub

    x = torch.rand(3, 3, bittensor.__network_dim__)
    out, ops = remote.forward(x, bittensor_pb2.Modality.TENSOR)
    assert ops.item() == bittensor_pb2.ReturnCode.ResponseShapeException
    # On shape failure the output falls back to the request's shape.
    assert list(out.shape) == [3, 3, bittensor.__network_dim__]
# Example 14
    def test_serialize_deserialize_text(self):
        """Round-trip a padded batch of byte-encoded words through the MSGPACK serializer."""
        # Encode each word as a tensor of its utf8 byte values.
        words = ["This", "is", "a", "word", "list"]
        encoded = [torch.ByteTensor(list(bytes(w, 'utf8'))) for w in words]
        longest = max(t.size()[0] for t in encoded)

        # Pack the ragged rows into a single zero-padded int64 matrix.
        data = torch.zeros((len(encoded), longest), dtype=torch.int64)
        for row, token in enumerate(encoded):
            data[row, 0:token.size()[0]] = token

        serializer = serialization.get_serializer(
            serialzer_type=bittensor.proto.Serializer.MSGPACK)
        serialized_data_tensor_message = serializer.serialize(
            data,
            modality=bittensor.proto.Modality.TEXT,
            from_type=bittensor.proto.TensorType.TORCH)

        # Wire message mirrors the source tensor's metadata.
        assert data.requires_grad == serialized_data_tensor_message.requires_grad
        assert list(data.shape) == serialized_data_tensor_message.shape
        assert serialized_data_tensor_message.modality == bittensor.proto.Modality.TEXT
        assert serialized_data_tensor_message.dtype != bittensor.proto.DataType.UNKNOWN

        deserialized_data_tensor_message = serializer.deserialize(
            serialized_data_tensor_message,
            to_type=bittensor.proto.TensorType.TORCH)
        # Round-trip preserves grad flag, shape, a known dtype, and values.
        assert serialized_data_tensor_message.requires_grad == deserialized_data_tensor_message.requires_grad
        assert serialized_data_tensor_message.shape == list(
            deserialized_data_tensor_message.shape)
        assert serialization.torch_dtype_to_bittensor_dtype(
            deserialized_data_tensor_message.dtype
        ) != bittensor.proto.DataType.UNKNOWN

        assert torch.all(torch.eq(deserialized_data_tensor_message, data))
# Example 15
    def _backward(self, request):
        r""" Performs validity checks on the grpc request before calling nucleus backward.
            Returns the output, message and code from the backend backward call.

            Args:
                request (:obj:`bittensor_pb2`, `required`):
                    Tensor request proto carrying [inputs, gradients].
            Returns:
                response: (:obj:`bittensor_pb2.Tensor`, `optional`):
                    serialized tensor response from the nucleus call or None on failure.
                message: (str, `required`):
                    message associated with backward call, potentially error, or 'success'.
                code: (:obj:`bittensor_pb2.ReturnCode`, `required`):
                    return code associated with backward call i.e. Success or Timeout.
        """
        # ---- Check that we have a synapse ----.
        if self.synapse is None:  # fix: identity check, not ==
            message = "Remote axon not serving a synapse"
            code = bittensor_pb2.ReturnCode.NotServingSynapse
            return None, message, code

        # ---- Check request inputs ----.
        # Backward requires exactly two tensors: the forward inputs and the
        # gradients of the outputs.
        if len(request.tensors) == 2:
            inputs_x = request.tensors[0]
            grads_dy = request.tensors[1]
            modality_x = inputs_x.modality
        else:
            message = "During backward: There are {} tensors in the request, expected 2.".format(
                len(request.tensors))
            code = bittensor_pb2.ReturnCode.InvalidRequest
            return None, message, code

        # ---- Deserialize request ---
        try:
            serializer = serialization.get_serializer(inputs_x.serializer)
            inputs_x = serializer.deserialize(
                inputs_x, to_type=bittensor_pb2.TensorType.TORCH)
            grads_dy = serializer.deserialize(
                grads_dy, to_type=bittensor_pb2.TensorType.TORCH)

        except Exception as e:
            message = "Backward request deserialization failed with unknown error {}".format(
                e)
            code = bittensor_pb2.ReturnCode.RequestDeserializationException
            return None, message, code

        # --- Get call priority ----
        try:
            call_priority = self.priority[request.public_key] + random.random()
        except Exception:
            # fix: was a bare `except:` which also swallowed SystemExit and
            # KeyboardInterrupt. Unknown callers fall back to base priority.
            call_priority = 1 + random.random()

        # ---- Save gradients to buffer for later use. ---
        try:
            self.gradients.put(
                (call_priority,
                 (request.public_key, inputs_x, grads_dy, modality_x)),
                block=False)
        except queue.Full:
            # Best effort: dropping gradients is acceptable when the buffer is full.
            logger.trace('gradient queue is full at size: {}',
                         self.gradients.qsize())

        # ---- Nucleus backward call ----
        try:
            outputs, message, code = self._nucleus.backward(
                synapse=self.synapse,
                inputs_x=inputs_x,
                grads_dy=grads_dy,
                modality=modality_x,
                priority=call_priority)
        except Exception as e:
            message = "Unkown exception when calling backward with error {}".format(
                e)
            code = bittensor_pb2.ReturnCode.UnknownException
            return None, message, code

        # ---- Serialize response ----
        # (fix: the section header previously said "Deserialize response",
        # but this block serializes the nucleus outputs for the wire.)
        try:
            serializer = serialization.get_serializer(
                bittensor_pb2.Serializer.MSGPACK)
            outputs_serialized = serializer.serialize(
                outputs,
                modality=bittensor_pb2.Modality.TENSOR,
                from_type=bittensor_pb2.TensorType.TORCH)

        except Exception as e:
            message = "Backward request serialization failed with error {} and inputs {}".format(
                e, outputs)
            code = bittensor_pb2.ReturnCode.ResponseSerializationException
            return None, message, code

        # ---- Finaly return ----
        return outputs_serialized, message, code
# Example 16
    def _forward(self, request):
        r""" Performs validity checks on the grpc request before calling nucleus forward.
            Returns the output, message and code from the backend forward call.

            Args:
                request (:obj:`bittensor_pb2`, `required`):
                    Tensor request proto carrying one input tensor.
            Returns:
                response: (:obj:`bittensor_pb2.Tensor`, `optional`):
                    serialized tensor response from the nucleus call or None on failure.
                message: (str, `required`):
                    message associated with forward call, potentially error, or 'success'.
                code: (:obj:`bittensor_pb2.ReturnCode`, `required`):
                    return code associated with forward call i.e. Success or Timeout.
        """

        # ---- Check synapse exists ----
        if self.synapse is None:  # fix: identity check, not ==
            message = "Remote axon not serving a synapse"
            code = bittensor_pb2.ReturnCode.NotServingSynapse
            return None, message, code

        # ---- Check Empty request ----
        if len(request.tensors) == 0:
            message = "Forward request contains {} tensors, expected 1 tensor in the forward call".format(
                len(request.tensors))
            code = bittensor_pb2.ReturnCode.EmptyRequest
            return None, message, code

        # ---- Check deserialization ----
        inputs = request.tensors[0]
        try:
            deserializer = serialization.get_serializer(
                serialzer_type=inputs.serializer)
            x = deserializer.deserialize(
                inputs, to_type=bittensor_pb2.TensorType.TORCH)
        except Exception as e:
            message = "Forward request deserialization failed with error {}".format(
                e)
            code = bittensor_pb2.ReturnCode.RequestDeserializationException
            return None, message, code

        # ---- Check shape and modality ----
        # Batch and sequence dimensions must both be non-empty.
        if x.shape[0] < 1:
            message = "Froward request batch dim exception with batch_size = {} ".format(
                x.shape[0])
            code = bittensor_pb2.ReturnCode.RequestShapeException
            return None, message, code

        if x.shape[1] < 1:
            message = "Forward request sequence dim exception with sequence_dim = {} ".format(
                x.shape[1])
            code = bittensor_pb2.ReturnCode.RequestShapeException
            return None, message, code

        # Per-modality rank checks: TEXT rank 2, IMAGE rank 5, TENSOR rank 3.
        if inputs.modality == bittensor_pb2.Modality.TEXT:
            if len(x.shape) != 2:
                message = "Forward text input shape exception with len(request.shape) = {} must have rank 2.".format(
                    len(x.shape))
                code = bittensor_pb2.ReturnCode.RequestShapeException
                return None, message, code

        if inputs.modality == bittensor_pb2.Modality.IMAGE:
            if len(x.shape) != 5:
                message = "Forward image input shape exception for len(shape) = {}  must have rank 5".format(
                    len(x.shape))
                code = bittensor_pb2.ReturnCode.RequestShapeException
                return None, message, code

        if inputs.modality == bittensor_pb2.Modality.TENSOR:
            if len(x.shape) != 3:
                message = "Forward message tensor input shape exception len(shape) = {} must have rank 3".format(
                    len(x.shape))
                code = bittensor_pb2.ReturnCode.RequestShapeException
                return None, message, code

        # --- Get call priority ----
        call_priority = self.get_call_priority(request)

        # ---- Make Nucleus forward call. ----
        try:
            outputs, message, code = self._nucleus.forward(
                synapse=self.synapse.to(self.synapse.device),
                inputs=x.to(self.synapse.device),
                mode=inputs.modality,
                priority=call_priority)

            # ---- Catch Nucleus errors ----
            if code != bittensor_pb2.ReturnCode.Success:
                return None, message, code

        except Exception as e:
            message = "Unknown exception when calling nucleus forward {}".format(
                e)
            code = bittensor_pb2.ReturnCode.UnknownException
            return None, message, code

        # ---- Serialize response ----
        try:
            serializer = serialization.get_serializer(
                bittensor_pb2.Serializer.MSGPACK)
            outputs_serialized = serializer.serialize(
                outputs,
                modality=bittensor_pb2.Modality.TENSOR,
                from_type=bittensor_pb2.TensorType.TORCH)

        except Exception as e:
            message = "Serializtion of forward response failed with error {} and inputs: {}".format(
                e, outputs)
            # fix: this failure is a serialization error. Previously it
            # returned ResponseDeserializationException, inconsistent with
            # _backward which reports ResponseSerializationException for the
            # same failure mode.
            code = bittensor_pb2.ReturnCode.ResponseSerializationException
            return None, message, code

        # ---- Return successful response ----
        return outputs_serialized, message, code
# Example 17
    def forward(ctx, caller: Receptor, dummy: torch.Tensor,
                inputs: torch.Tensor,
                mode: bittensor.proto.Modality) -> Tuple[torch.Tensor, int]:
        """ Internal autograd-friendly Forward RPC call to a remote neuron (calls the Forward method on an Axon terminal.)

            Args:
                ctx: (:obj:`torch.autograd.ctx`, `required`):
                    Autograd context, saves state information between forward and backward calls. i.e. inputs for gradient computation.

                caller: (:obj:`Receptor`, `required`):
                    Caller object the remote neuron containing the endpoint information, RPC channel etc.

                dummy: (:obj:`torch.Tensor`, `required`):
                    Dummy torch tensor used to ensure that torch.backward computation is called on this function
                    regardless of the input types.

                inputs (:obj:`List[torch.Tensor]` of shape :obj:`(shape)`, `required`):
                    Torch tensor to be sent to the caller associated endpoint neurons.

                mode (:obj:`bittensor.proto.Modality` of shape :obj:`(1)`, `required`):
                    Bittensor forward modality type. Enum in [TEXT, IMAGE, TENSOR]

            Returns:
                output (:obj:`Tuple[torch.FloatTensor`, torch.LongTensor]`, `optional`):
                    Result from forward call. May be None in the case of failure.

                code (:obj:`bittensor.proto.ReturnCode`, `required`):
                    Return code associated with forward call.
        """

        # ---- Save for backward call ----
        ctx.caller = caller
        ctx.mode = mode
        ctx.inputs = inputs

        # Zeros fallback returned on every failure path.
        zeros = nill_response_for(inputs)
        try:
            # ---- Check inputs size ----
            if torch.numel(inputs) == 0:
                return zeros, torch.tensor(
                    bittensor.proto.ReturnCode.EmptyRequest)

            # ---- Inputs Serialization ----
            try:
                serializer = serialization.get_serializer(
                    bittensor.proto.Serializer.MSGPACK)
                serialized_inputs = serializer.serialize(
                    inputs,
                    modality=mode,
                    from_type=bittensor.proto.TensorType.TORCH)
            except Exception as e:
                logger.warning('Serialization error with error {}', e)
                return zeros, torch.tensor(
                    bittensor.proto.ReturnCode.RequestSerializationException)
            # Stashed so the backward pass can resend the same inputs.
            ctx.serialized_inputs = serialized_inputs

            # ---- Build request ----
            request = bittensor.proto.TensorMessage(
                version=bittensor.__version__,
                public_key=ctx.caller.wallet.hotkey.public_key,
                nounce=ctx.caller.nounce,
                signature=ctx.caller.signature,
                tensors=[serialized_inputs])

            # ---- Make RPC call ----
            try:

                start_time = time.time()
                ctx.caller.stats.forward_qps.update(1)
                ctx.caller.stats.forward_bytes_out.update(
                    sys.getsizeof(request))
                response = ctx.caller.stub.Forward(
                    request, timeout=caller.config.receptor.timeout)
                ctx.caller.stats.forward_bytes_in.update(
                    sys.getsizeof(response))
                ctx.caller.stats.forward_elapsed_time.update(
                    (time.time() - start_time))

                # ---- Catch non-code ----
                try:
                    bittensor_code = response.return_code
                except Exception:
                    # fix: the original bare `except:` returned
                    # `torch.tensor(bittensor_code)` here, but `bittensor_code`
                    # is unbound when the attribute access itself failed,
                    # raising NameError instead of reporting the failure.
                    logger.error(
                        'Unknown exception returned from remote host with message {}, {}',
                        response.message, traceback.format_exc())
                    return zeros, torch.tensor(
                        bittensor.proto.ReturnCode.UnknownException)

                # ---- Catch bittensor errors ----
                if bittensor_code == bittensor.proto.ReturnCode.UnknownException:
                    logger.error(
                        'Unknown exception returned from remote host with message {}, {}',
                        response.message, traceback.format_exc())
                    return zeros, torch.tensor(bittensor_code)

                elif bittensor_code != bittensor.proto.ReturnCode.Success:
                    return zeros, torch.tensor(bittensor_code)

            # ---- Catch GRPC Errors ----
            except grpc.RpcError as rpc_error_call:
                grpc_code = rpc_error_call.code()

                if grpc_code == grpc.StatusCode.DEADLINE_EXCEEDED:
                    return zeros, torch.tensor(
                        bittensor.proto.ReturnCode.Timeout)

                elif grpc_code == grpc.StatusCode.UNAVAILABLE:
                    return zeros, torch.tensor(
                        bittensor.proto.ReturnCode.Unavailable)

                else:
                    logger.error(
                        'Uncaught GPRC error exception with code {} from endpoint {}',
                        grpc_code, caller.endpoint)
                    return zeros, torch.tensor(
                        bittensor.proto.ReturnCode.UnknownException)

            # ---- Catch Unknown Errors ----
            except Exception as e:
                logger.error(
                    'Uncaught error in forward call with error {} and endpoint',
                    e, caller.endpoint)
                return zeros, torch.tensor(
                    bittensor.proto.ReturnCode.UnknownException)

            # ---- Check tensor response length ----
            if len(response.tensors) == 0:
                return zeros, torch.tensor(
                    bittensor.proto.ReturnCode.EmptyResponse)

            # ---- Deserialize response ----
            try:
                outputs = response.tensors[0]
                deserializer = serialization.get_serializer(outputs.serializer)
                outputs = deserializer.deserialize(
                    outputs, to_type=bittensor.proto.TensorType.TORCH)

            except Exception as e:
                logger.error(
                    'Failed to serialize responses from forward call with error {}',
                    e)
                return zeros, torch.tensor(bittensor.proto.ReturnCode.
                                           ResponseDeserializationException)

            # ---- Check response shape ----
            # Response must match the request's batch and sequence dims and
            # the network's hidden dimension.
            if  outputs.size(0) != inputs.size(0) \
                or outputs.size(1) != inputs.size(1) \
                or outputs.size(2) != bittensor.__network_dim__:
                logger.error(
                    'Forward request returned tensor with incorrect shape {}',
                    list(outputs.shape))
                return zeros, torch.tensor(
                    bittensor.proto.ReturnCode.ResponseShapeException)

            # ---- Safe catch NaNs and replace with 0.0 ----
            outputs = torch.where(torch.isnan(outputs),
                                  torch.zeros_like(outputs), outputs)

        # ---- Catch all ----
        except Exception as e:
            logger.error('Forward request returned unknown error {}', e)
            return zeros, torch.tensor(
                bittensor.proto.ReturnCode.UnknownException)

        # ---- Return ----
        return outputs, torch.tensor(response.return_code)
# Example 18
    def backward(ctx, grads: torch.FloatTensor,
                 code: torch.FloatTensor) -> Optional[torch.Tensor]:
        """ Internal autograd-friendly Backward RPC call to a remote neuron (calls the Backward method on a remote Axon terminal.)

            Sends the forward-call inputs together with the output gradients to the
            remote peer as a fire-and-forget request; the response is intentionally
            ignored and zeros are always returned as the input gradients.

            Args:
                ctx: (:obj:`torch.autograd.ctx`, `required`):
                    Autograd context, saves state information between forward and backward calls. i.e. inputs for gradient computation.

                grads (:obj:`List[torch.Tensor]` of shape :obj:`(shape)`, `required`):
                    Gradients of this function's outputs computed during the loss.backward() call.

                code (:obj:`bittensor.proto.Modality` of shape :obj:`(1)`, `required`):
                    Code output from the forward call.

            Returns:
                output (:obj:`Tuple[torch.FloatTensor`, torch.LongTensor]`, `optional`):
                    Gradients of the inputs with respect to the inputs and grads of the outputs.
        """
        # ---- Zeros response in the case of failure ----
        zeros = nill_response_for(ctx.inputs)

        # ---- Check if we are passing gradients at all ----
        if not ctx.caller.config.receptor.pass_gradients:
            return (None, None, zeros, None)

        # ---- Check that the forward query was a success ----
        if code.item() != bittensor.proto.ReturnCode.Success:
            return (None, None, zeros, None)

        # ---- Try to pass gradients ----
        try:

            # ---- Get forward call serialized inputs ----
            # ctx.serialized_inputs is only set on a successful forward pass;
            # its absence raises AttributeError (was a bare except before).
            try:
                serialized_inputs = ctx.serialized_inputs
            except AttributeError:
                logger.trace(
                    'backward failed because forward previously failed.')
                return (None, None, zeros, None)

            # ---- Serialization ----
            try:
                # ---- Get serializer ----
                serializer = serialization.get_serializer(
                    bittensor.proto.Serializer.MSGPACK)

                # ---- Serialize grads to bitensor_pb2.Tensors ----
                serialized_grads = serializer.serialize(
                    grads,
                    modality=bittensor.proto.Modality.TENSOR,
                    from_type=bittensor.proto.TensorType.TORCH)

            except Exception:
                logger.trace(
                    'backward failed during serialization of gradients.')
                return (None, None, zeros, None)

            # ---- Build request for backward ----
            # NOTE: 'nounce' (sic) is the field name defined by the proto schema.
            request = bittensor.proto.TensorMessage(
                version=bittensor.__version__,
                public_key=ctx.caller.wallet.hotkey.public_key,
                nounce=ctx.caller.nounce,
                signature=ctx.caller.signature,
                tensors=[serialized_inputs, serialized_grads])

            # --- Send non blocking grad request ----
            # NOTE(const): we dont care about the response.
            try:
                ctx.caller.stats.backward_qps.update(1)
                # NOTE(review): 'backwar_bytes_*' is the (misspelled) attribute name
                # on the caller's stats object; renaming it here would break callers.
                ctx.caller.stats.backwar_bytes_out.update(
                    sys.getsizeof(request))
                ctx.caller.stub.Backward.future(
                    request, timeout=ctx.caller.config.receptor.timeout)
                ctx.caller.stats.backwar_bytes_in.update(
                    0.0)  # responses are dropped.

            except Exception:
                logger.trace(
                    'backward failed during backward call. Do not care.')
                return (None, None, zeros, None)

            # ---- Always return zeros ----
            # NOTE(const): We can return non zeros but a remote host could mess with your training
            # without you knowing about it. i.e. by passing you malicious gradients.
            return (None, None, zeros, None)

        except Exception:

            # ---- Catch all exceptions in Backward ----
            rollbar.send_exception()
            return (None, None, zeros, None)
Beispiel #19
0
    def _backward(self, request):
        r""" Performs validity checks on the grpc request before calling nucleus backward.
            Returns the output, message and code from the backend backward call.
            Args:
                request (:obj:`bittensor.proto`, `required`): 
                    Tensor request proto.
            Returns:
                response: (:obj:`bittensor.proto.Tensor, `required`): 
                    serialized tensor response from the nucleus call or None.
                message: (str, `required`): 
                    message associated with forward call, potentially error, or 'success'.
                code: (:obj:`bittensor.proto.ReturnCode, `required`)
                    return code associated with forward call i.e. Success of Timeout.
        """
        # ---- Check that we have a synapse ----.
        if self.synapse is None:
            message = "Remote axon not serving a synapse"
            code = bittensor.proto.ReturnCode.NotServingSynapse
            return None, message, code

        # ---- Check request inputs ----.
        # A backward request must carry exactly two tensors: inputs and grads.
        if len(request.tensors) == 2:
            inputs_x = request.tensors[0]
            grads_dy = request.tensors[1]
            modality_x = inputs_x.modality
        else:
            message = "During backward: There are {} tensors in the request, expected 2.".format(
                len(request.tensors))
            logger.debug(
                '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>InvalidRequest</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
                request.public_key, message)
            code = bittensor.proto.ReturnCode.InvalidRequest
            return None, message, code

        # ---- Deserialize request ---
        try:
            serializer = serialization.get_serializer(inputs_x.serializer)
            inputs_x = serializer.deserialize(
                inputs_x, to_type=bittensor.proto.TensorType.TORCH)
            grads_dy = serializer.deserialize(
                grads_dy, to_type=bittensor.proto.TensorType.TORCH)

        except Exception as e:
            message = "Request serialization exception with error: {}".format(
                str(e))
            logger.debug(
                '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>RequestDeserializationException</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
                request.public_key, message)
            code = bittensor.proto.ReturnCode.RequestDeserializationException
            return None, message, code

        # ---- Check shapes ----
        # Expected ranks per modality: TEXT=2, IMAGE=5, TENSOR=3.
        if modality_x == bittensor.proto.Modality.TEXT:
            if len(inputs_x.shape) != 2:
                message = "Backward text input shape exception with len(request.shape) = {} must have rank 2.".format(
                    len(inputs_x.shape))
                logger.debug(
                    '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>RequestShapeException</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
                    request.public_key, message)
                return None, message, bittensor.proto.ReturnCode.RequestShapeException

        if modality_x == bittensor.proto.Modality.IMAGE:
            if len(inputs_x.shape) != 5:
                message = "Backward image input shape exception for len(shape) = {}  must have rank 5".format(
                    len(inputs_x.shape))
                logger.debug(
                    '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>RequestShapeException</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
                    request.public_key, message)
                return None, message, bittensor.proto.ReturnCode.RequestShapeException

        if modality_x == bittensor.proto.Modality.TENSOR:
            if len(inputs_x.shape) != 3:
                message = "Backward message tensor input shape exception len(shape) = {} must have rank 3".format(
                    len(inputs_x.shape))
                logger.debug(
                    '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>RequestShapeException</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
                    request.public_key, message)
                return None, message, bittensor.proto.ReturnCode.RequestShapeException

        if len(grads_dy.shape) != 3:
            message = "Passed gradients must have rank 3 but got {}".format(
                len(grads_dy.shape))
            logger.debug(
                '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>RequestShapeException</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
                request.public_key, message)
            return None, message, bittensor.proto.ReturnCode.RequestShapeException

        # Gradients must align with the inputs along batch and sequence dims.
        if grads_dy.shape[0] != inputs_x.shape[0] or grads_dy.shape[
                1] != inputs_x.shape[1]:
            message = "Passed gradients must same first and second dimension as passed inputs got shapes {} and {}".format(
                grads_dy.shape, inputs_x.shape)
            logger.debug(
                '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>RequestShapeException</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
                request.public_key, message)
            return None, message, bittensor.proto.ReturnCode.RequestShapeException

        # --- Get call priority ----
        # Unknown public keys (or an unset priority map) fall back to a default
        # priority of 1; the random jitter breaks ties in the priority queue.
        try:
            call_priority = self.priority[request.public_key] + random.random()
        except (KeyError, TypeError):
            call_priority = 1 + random.random()

        # ---- Save gradients to buffer for later use. ---
        # Best-effort: a full queue only drops this sample, it is not an error.
        try:
            self.gradients.put(
                (call_priority,
                 (request.public_key, inputs_x, grads_dy, modality_x)),
                block=False)
        except queue.Full:
            logger.trace('gradient queue is full at size: {}',
                         self.gradients.qsize())

        # ---- nucleus.Nucleus backward call ----
        try:
            outputs, message, code = self.nucleus.backward(
                synapse=self.synapse,
                inputs_x=inputs_x,
                grads_dy=grads_dy,
                modality=modality_x,
                priority=call_priority)
        except Exception as e:
            message = "Unknown exception when calling backward with error {}".format(
                e)
            code = bittensor.proto.ReturnCode.UnknownException
            return None, message, code

        # ---- Serialize response ----
        try:
            serializer = serialization.get_serializer(
                bittensor.proto.Serializer.MSGPACK)
            outputs_serialized = serializer.serialize(
                outputs,
                modality=bittensor.proto.Modality.TENSOR,
                from_type=bittensor.proto.TensorType.TORCH)

        except Exception as e:
            message = "Backward request serialization failed with error {} and inputs {}".format(
                e, outputs)
            code = bittensor.proto.ReturnCode.ResponseSerializationException
            return None, message, code

        # ---- Finally return ----
        logger.debug(
            '<white>Axon</white> <green>Backward Response</green> <--- <white>code</white>:<green>Success</green>, <white>to</white>:<cyan>{}</cyan>, <white>outputs</white>:<cyan>{}</cyan>',
            request.public_key, outputs.shape)
        return outputs_serialized, message, code