def test_client():
    """Smoke test: boot a server, hit Forward and Backward once over an
    insecure local channel, then shut the server down."""
    grpc_server = create_server()
    grpc_server.start()
    channel = grpc.insecure_channel("localhost:8812")
    client = proto_grpc.BittensorStub(channel)
    # Exercise both RPC endpoints with empty messages.
    client.Forward(proto_pb2.TensorMessage())
    client.Backward(proto_pb2.TensorMessage())
    grpc_server.stop(0)
def Backward(self, request: bittensor_pb2.TensorMessage, context: grpc.ServicerContext) -> bittensor_pb2.TensorMessage:
    r""" Serves remote GRPC Backward requests from other neurons: the
        equivalent of a gradient-descent 'backward' pass through a network.
        The request is delegated to the nucleus via self._backward; see
        bittensor_pb2.ReturnCode for the possible outcome codes.

        Args:
            request (:obj:`bittensor_pb2.TensorMessage`, `required`):
                Tensor request proto.
            context (:obj:`grpc.ServicerContext`, `required`):
                grpc server context.

        Returns:
            response (:obj:`bittensor_pb2.TensorMessage`):
                Proto response carrying the synapse backward output, or an
                empty tensor list under failure.
    """
    output_tensor, status_message, status_code = self._backward(request)
    # Only attach a payload when the nucleus actually produced one.
    if output_tensor is not None:
        payload = [output_tensor]
    else:
        payload = []
    response = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=self.__keypair.public_key,
        return_code=status_code,
        message=status_message,
        tensors=payload,
    )
    self.update_stats_for_request(request, response)
    return response
def forward(ctx, caller: RemoteSynapse, dummy: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    """Autograd forward: serialize inputs, send a Forward RPC through the
    caller's stub, and return the deserialized remote output."""
    # Stash the caller so backward() can reuse the same stub and keys.
    ctx.caller = caller
    # Serialize once and remember the bytes — backward() resends them.
    ctx.serialized_inputs = PyTorchSerializer.serialize(inputs)
    # NOTE(review): 'nounce' mirrors the proto field spelling; do not rename.
    forward_request = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        neuron_key=ctx.caller.local_neuron_key,
        synapse_key=ctx.caller.synapse.synapse_key,
        nounce=ctx.caller.nounce,
        signature=ctx.caller.signature,
        tensors=[ctx.serialized_inputs])
    reply = ctx.caller.stub.Forward(forward_request)
    # First (and only) tensor in the reply is the remote output.
    return PyTorchSerializer.deserialize(reply.tensors[0])
def backward(ctx, grads: torch.Tensor) -> Optional[torch.Tensor]:
    """Autograd backward: send the saved inputs plus output grads through a
    Backward RPC and return the remote gradient w.r.t. the inputs."""
    serialized_grads = PyTorchSerializer.serialize(grads)
    # Build the backward request: original inputs followed by output grads.
    backward_request = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        neuron_key=ctx.caller.local_neuron_key,
        synapse_key=ctx.caller.synapse.synapse_key,
        nounce=ctx.caller.nounce,
        signature=ctx.caller.signature,
        tensors=[ctx.serialized_inputs, serialized_grads])
    reply = ctx.caller.stub.Backward(backward_request)
    input_grads = PyTorchSerializer.deserialize(reply.tensors[0])
    # One None per non-tensor forward argument (caller, dummy), then grads.
    return (None, None, input_grads)
def test_empty_backward_request():
    """A Backward request carrying no tensors is rejected as InvalidRequest."""
    axon.serve(synapse)
    empty_message = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=keypair.public_key,
    )
    result = axon.Backward(empty_message, None)
    assert result.return_code == bittensor_pb2.ReturnCode.InvalidRequest
def test_backward_not_serving():
    """Backward against an axon with no served synapse returns NotServingSynapse."""
    axon.synapse = None
    message = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=keypair.public_key,
    )
    result = axon.Backward(message, None)
    assert result.return_code == bittensor_pb2.ReturnCode.NotServingSynapse
def test_backward_deserialization_error():
    """Payloads that are not serialized tensors trigger RequestDeserializationException."""
    axon.serve(synapse)
    # Plain dicts cannot be deserialized into tensors.
    bad_inputs = dict()
    bad_grads = dict()
    message = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=keypair.public_key,
        tensors=[bad_inputs, bad_grads])
    result = axon.Backward(message, None)
    assert result.return_code == bittensor_pb2.ReturnCode.RequestDeserializationException
def Backward(self, request: bittensor_pb2.TensorMessage, context: grpc.ServicerContext):
    """Serve a remote Backward call: run the targeted local synapse's
    backward pass on the request's (inputs, grads) pair and reply with the
    serialized input gradients."""
    # TODO (const): optionally check signature.
    # Unknown synapse target -> empty (null) response.
    if request.synapse_key not in self._local_synapses:
        return bittensor_pb2.TensorMessage()
    target_synapse = self._local_synapses[request.synapse_key]
    # tensors[0] holds the forward inputs, tensors[1] the output gradients.
    inputs = PyTorchSerializer.deserialize(request.tensors[0])
    output_grads = PyTorchSerializer.deserialize(request.tensors[1])
    input_grads = target_synapse.call_backward(inputs, output_grads)
    return bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        neuron_key=self._config.neuron_key,
        synapse_key=request.synapse_key,
        tensors=[PyTorchSerializer.serialize(input_grads)])
def test_single_item_backward_request():
    """Backward needs both inputs and grads; a single tensor is an InvalidRequest."""
    axon.serve(synapse)
    payload = torch.rand(3, 3, bittensor.__network_dim__)
    # NOTE(review): 'serialzer_type' mirrors the library keyword's spelling.
    msgpack_serializer = serialization.get_serializer(
        serialzer_type=bittensor_pb2.Serializer.MSGPACK)
    payload_bytes = msgpack_serializer.serialize(
        payload,
        modality=bittensor_pb2.Modality.TENSOR,
        from_type=bittensor_pb2.TensorType.TORCH)
    message = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=keypair.public_key,
        tensors=[payload_bytes])
    result = axon.Backward(message, None)
    assert result.return_code == bittensor_pb2.ReturnCode.InvalidRequest
def test_remote_neuron_mock_server_deserialization_error():
    """A Success response whose tensor payload cannot be deserialized yields
    ResponseDeserializationException and a zero tensor of the input shape."""
    bad_payload = dict()  # not a serialized tensor
    mocked_response = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=keypair.public_key,
        return_code=bittensor_pb2.ReturnCode.Success,
        tensors=[bad_payload])
    stub.Forward = MagicMock(return_value=mocked_response)
    remote.stub = stub
    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    out, ops = remote.forward(inputs, bittensor_pb2.Modality.TENSOR)
    assert ops.item() == bittensor_pb2.ReturnCode.ResponseDeserializationException
    assert list(out.shape) == [3, 3, bittensor.__network_dim__]
def test_forward_not_implemented():
    """When the nucleus reports NotImplemented, the axon relays that code."""
    axon.serve(synapse)
    nucleus.forward = MagicMock(return_value=[
        None, 'not implemented', bittensor_pb2.ReturnCode.NotImplemented
    ])
    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    msgpack_serializer = serialization.get_serializer(
        serialzer_type=bittensor_pb2.Serializer.MSGPACK)
    inputs_serialized = msgpack_serializer.serialize(
        inputs,
        modality=bittensor_pb2.Modality.TENSOR,
        from_type=bittensor_pb2.TensorType.TORCH)
    message = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=keypair.public_key,
        tensors=[inputs_serialized])
    result = axon.Forward(message, None)
    assert result.return_code == bittensor_pb2.ReturnCode.NotImplemented
def test_backward_success():
    """A valid (inputs, grads) pair produces Success with exactly one
    float32 tensor of the input shape."""
    axon.serve(synapse)
    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    msgpack_serializer = serialization.get_serializer(
        serialzer_type=bittensor_pb2.Serializer.MSGPACK)
    inputs_serialized = msgpack_serializer.serialize(
        inputs,
        modality=bittensor_pb2.Modality.TENSOR,
        from_type=bittensor_pb2.TensorType.TORCH)
    # Reuse the same serialized blob for both the inputs and grads slots.
    message = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=keypair.public_key,
        tensors=[inputs_serialized, inputs_serialized])
    nucleus.backward = MagicMock(
        return_value=[inputs, 'success', bittensor_pb2.ReturnCode.Success])
    result = axon.Backward(message, None)
    assert result.return_code == bittensor_pb2.ReturnCode.Success
    assert len(result.tensors) == 1
    assert result.tensors[0].shape == [3, 3, bittensor.__network_dim__]
    assert serialization_utils.bittensor_dtype_to_torch_dtype(
        result.tensors[0].dtype) == torch.float32
def test_remote_neuron_mock_server_shape_error():
    """A response tensor whose batch dimension mismatches the request yields
    ResponseShapeException and a zero tensor of the input shape."""
    wrong_shape = torch.rand(1, 3, bittensor.__network_dim__)
    msgpack_serializer = serialization.get_serializer(
        serialzer_type=bittensor_pb2.Serializer.MSGPACK)
    wrong_shape_serialized = msgpack_serializer.serialize(
        wrong_shape,
        modality=bittensor_pb2.Modality.TENSOR,
        from_type=bittensor_pb2.TensorType.TORCH)
    mocked_response = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=keypair.public_key,
        return_code=bittensor_pb2.ReturnCode.Success,
        tensors=[wrong_shape_serialized])
    stub.Forward = MagicMock(return_value=mocked_response)
    remote.stub = stub
    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    out, ops = remote.forward(inputs, bittensor_pb2.Modality.TENSOR)
    assert ops.item() == bittensor_pb2.ReturnCode.ResponseShapeException
    assert list(out.shape) == [3, 3, bittensor.__network_dim__]
def forward(ctx, caller: RemoteNeuron, dummy: torch.Tensor,
            inputs: torch.Tensor,
            mode: bittensor_pb2.Modality) -> Tuple[torch.Tensor, int]:
    """ Internal autograd-friendly Forward RPC call to a remote neuron (calls the Forward method on an Axon terminal.)

        Args:
            ctx: (:obj:`torch.autograd.ctx`, `required`):
                Autograd context, saves state information between forward and backward calls. i.e. inputs for gradient computation.
            caller: (:obj:`RemoteNeuron`, `required`):
                Caller object the remote neuron containing the endpoint information, RPC channel etc.
            dummy: (:obj:`torch.Tensor`, `required`):
                Dummy torch tensor used to ensure that torch.backward computation is called on this function regardless of the input types.
            inputs (:obj:`List[torch.Tensor]` of shape :obj:`(shape)`, `required`):
                Torch tensor to be sent to the caller associated endpoint neurons.
            mode (:obj:`bittensor_pb2.Modality` of shape :obj:`(1)`, `required`):
                Bittensor forward modality type. Enum in [TEXT, IMAGE, TENSOR]

        Returns:
            output (:obj:`Tuple[torch.FloatTensor, torch.LongTensor]`, `optional`):
                Result from forward call. May be zeros in the case of failure.
            code (:obj:`bittensor_pb2.ReturnCode`, `required`):
                Return code associated with forward call.
    """
    # ---- Save for backward call ----
    ctx.caller = caller
    ctx.mode = mode
    ctx.inputs = inputs

    # Zero tensor of the expected output shape, returned on every failure path.
    zeros = nill_response_for(inputs)
    try:
        # ---- Check inputs size ----
        if torch.numel(inputs) == 0:
            return zeros, torch.tensor(bittensor_pb2.ReturnCode.EmptyRequest)

        # ---- Inputs Serialization ----
        try:
            serializer = serialization.get_serializer(
                bittensor_pb2.Serializer.MSGPACK)
            serialized_inputs = serializer.serialize(
                inputs, modality=mode, from_type=bittensor_pb2.TensorType.TORCH)
        except Exception as e:
            logger.warning('Serialization error with error {}', e)
            return zeros, torch.tensor(
                bittensor_pb2.ReturnCode.RequestSerializationException)
        # Saved so backward() can resend the same bytes with the gradients.
        ctx.serialized_inputs = serialized_inputs

        # ---- Build request ----
        request = bittensor_pb2.TensorMessage(
            version=bittensor.__version__,
            public_key=ctx.caller.keypair.public_key,
            nounce=ctx.caller.nounce,
            signature=ctx.caller.signature,
            tensors=[serialized_inputs])

        # ---- Make RPC call ----
        try:
            start_time = time.time()
            ctx.caller.stats.forward_qps.update(1)
            ctx.caller.stats.forward_bytes_out.update(sys.getsizeof(request))
            response = ctx.caller.stub.Forward(
                request, timeout=caller.config.dendrite.timeout)
            ctx.caller.stats.forward_bytes_in.update(sys.getsizeof(response))
            ctx.caller.stats.forward_elapsed_time.update(
                (time.time() - start_time))

            # ---- Catch non-code responses ----
            try:
                bittensor_code = response.return_code
            except Exception:
                # BUG FIX: the original returned torch.tensor(bittensor_code)
                # here, but bittensor_code is unbound precisely when this
                # except fires (NameError masked the real failure). Report
                # UnknownException instead.
                logger.error(
                    'Unknown exception returned from remote host with message {}, {}',
                    response.message, traceback.format_exc())
                return zeros, torch.tensor(
                    bittensor_pb2.ReturnCode.UnknownException)

            # ---- Catch bittensor errors ----
            if bittensor_code == bittensor_pb2.ReturnCode.UnknownException:
                logger.error(
                    'Unknown exception returned from remote host with message {}, {}',
                    response.message, traceback.format_exc())
                return zeros, torch.tensor(bittensor_code)
            elif bittensor_code != bittensor_pb2.ReturnCode.Success:
                return zeros, torch.tensor(bittensor_code)

        # ---- Catch GRPC Errors ----
        except grpc.RpcError as rpc_error_call:
            grpc_code = rpc_error_call.code()
            if grpc_code == grpc.StatusCode.DEADLINE_EXCEEDED:
                return zeros, torch.tensor(bittensor_pb2.ReturnCode.Timeout)
            elif grpc_code == grpc.StatusCode.UNAVAILABLE:
                return zeros, torch.tensor(bittensor_pb2.ReturnCode.Unavailable)
            else:
                logger.error(
                    'Uncaught GPRC error exception with code {} from endpoint {}',
                    grpc_code, caller.endpoint)
                return zeros, torch.tensor(
                    bittensor_pb2.ReturnCode.UnknownException)

        # ---- Catch Unknown Errors ----
        except Exception as e:
            # BUG FIX: format string was missing the second placeholder, so
            # the endpoint argument was silently dropped from the log line.
            logger.error(
                'Uncaught error in forward call with error {} and endpoint {}',
                e, caller.endpoint)
            return zeros, torch.tensor(
                bittensor_pb2.ReturnCode.UnknownException)

        # ---- Check tensor response length ----
        if len(response.tensors) == 0:
            return zeros, torch.tensor(bittensor_pb2.ReturnCode.EmptyResponse)

        # ---- Deserialize response ----
        try:
            outputs = response.tensors[0]
            deserializer = serialization.get_serializer(outputs.serializer)
            outputs = deserializer.deserialize(
                outputs, to_type=bittensor_pb2.TensorType.TORCH)
        except Exception as e:
            logger.error(
                'Failed to serialize responses from forward call with error {}',
                e)
            return zeros, torch.tensor(
                bittensor_pb2.ReturnCode.ResponseDeserializationException)

        # ---- Check response shape ----
        if outputs.size(0) != inputs.size(0) \
                or outputs.size(1) != inputs.size(1) \
                or outputs.size(2) != bittensor.__network_dim__:
            logger.error(
                'Forward request returned tensor with incorrect shape {}',
                list(outputs.shape))
            return zeros, torch.tensor(
                bittensor_pb2.ReturnCode.ResponseShapeException)

        # ---- Safe catch NaNs and replace with 0.0 ----
        outputs = torch.where(torch.isnan(outputs), torch.zeros_like(outputs),
                              outputs)

    # ---- Catch all ----
    except Exception as e:
        logger.error('Forward request returned unknown error {}', e)
        return zeros, torch.tensor(bittensor_pb2.ReturnCode.UnknownException)

    # ---- Return ----
    return outputs, torch.tensor(response.return_code)
def test_empty_protos():
    # Constructing a TensorMessage with no fields must not raise.
    bittensor_pb2.TensorMessage()
def backward(ctx, grads: torch.FloatTensor, code: torch.FloatTensor) -> Optional[torch.Tensor]:
    """ Internal autograd-friendly Backward RPC call to a remote neuron (calls the Backward method on an remote Axon terminal.)

        Args:
            ctx: (:obj:`torch.autograd.ctx`, `required`):
                Autograd context, saves state information between forward and backward calls. i.e. inputs for gradient computation.
            grads (:obj:`List[torch.Tensor]` of shape :obj:`(shape)`, `required`):
                Gradients of this function's outputs computed during the loss.backward() call.
            code (:obj:`bittensor_pb2.Modality` of shape :obj:`(1)`, `required`):
                Code output from the forward call.

        Returns:
            output (:obj:`Tuple[torch.FloatTensor`, torch.LongTensor]`, `optional`):
                Gradients of the inputs with respect to the inputs and grads of the outputs.
    """
    # ---- Zeros response in the case of failure ----
    # Every exit path returns a 4-tuple (one slot per forward argument:
    # caller, dummy, inputs, mode), with zeros in the inputs-gradient slot.
    zeros = nill_response_for(ctx.inputs)

    # ---- Check if are passing gradients ----
    if not ctx.caller.config.dendrite.pass_gradients:
        return (None, None, zeros, None)

    # ---- Check that forward query was a success ----
    if code.item() != bittensor_pb2.ReturnCode.Success:
        return (None, None, zeros, None)

    # ---- Try to pass gradients ----
    else:
        try:
            # ---- Get forward call serialzied inputs ----
            # ctx.serialized_inputs is only set when the forward serialization
            # succeeded, so its absence means forward already failed.
            try:
                serialized_inputs = ctx.serialized_inputs
            except:
                logger.trace(
                    'backward failed because forward previously failed.')
                return (None, None, zeros, None)

            # ---- Serialization ----
            try:
                # ---- Get serializer ----
                serializer = serialization.get_serializer(
                    bittensor_pb2.Serializer.MSGPACK)
                # ---- Serialize grads to bitensor_pb2.Tensors ----
                serialized_grads = serializer.serialize(
                    grads,
                    modality=bittensor_pb2.Modality.TENSOR,
                    from_type=bittensor_pb2.TensorType.TORCH)
            except Exception as e:
                logger.trace(
                    'backward failed during serialization of gradients.')
                return (None, None, zeros, None)

            # ---- Build request for backward ----
            request = bittensor_pb2.TensorMessage(
                version=bittensor.__version__,
                public_key=ctx.caller.keypair.public_key,
                nounce=ctx.caller.nounce,
                signature=ctx.caller.signature,
                tensors=[serialized_inputs, serialized_grads])

            # --- Send non blocking grad request ----
            # NOTE(const): we dont care about the response.
            # NOTE(review): 'backwar_bytes_out/in' look misspelled but must
            # match the stats object's attribute names defined elsewhere —
            # confirm before renaming.
            try:
                ctx.caller.stats.backward_qps.update(1)
                ctx.caller.stats.backwar_bytes_out.update(
                    sys.getsizeof(request))
                # .future() fires the RPC asynchronously; the result is never
                # awaited, so failures here are best-effort only.
                ctx.caller.stub.Backward.future(
                    request, timeout=ctx.caller.config.dendrite.timeout)
                ctx.caller.stats.backwar_bytes_in.update(
                    0.0)  # responses are dropped.
            except:
                logger.trace(
                    'backward failed during backward call. Do not care.')
                return (None, None, zeros, None)

            # ---- Always return zeros ----
            # NOTE(const): We can return non zeros but a remote host could mess with your training
            # without you knowing about it. i.e. by passing you malicious gradients.
            return (None, None, zeros, None)

        except:
            # ---- Catch all exceptions in Backward ----
            rollbar.send_exception()
            return (None, None, zeros, None)
def Backward(self, context, request):
    """Mock Backward endpoint: always replies with an empty TensorMessage.

    Fix: the first parameter was misspelled 'contect'; grpc dispatches
    positionally, so renaming it to 'context' is safe for callers.
    """
    response = proto_pb2.TensorMessage()
    return response
def Forward(self, context, request):
    """Mock Forward endpoint: always replies with an empty TensorMessage."""
    return proto_pb2.TensorMessage()