def test_receptor_neuron_server_response_with_nans():
    """NaN values in a successful server response must be sanitized to zero."""
    serializer = serialization.get_serializer(
        serialzer_type=bittensor.proto.Serializer.MSGPACK)
    # Build a response tensor whose first element is NaN.
    response_tensor = torch.rand(3, 3, bittensor.__network_dim__)
    response_tensor[0][0][0] = float('nan')
    serialized_response = serializer.serialize(
        response_tensor,
        modality=bittensor.proto.Modality.TENSOR,
        from_type=bittensor.proto.TensorType.TORCH)
    mocked_message = bittensor.proto.TensorMessage(
        version=bittensor.__version__,
        public_key=wallet.hotkey.public_key,
        return_code=bittensor.proto.ReturnCode.Success,
        tensors=[serialized_response])
    stub.Forward = MagicMock(return_value=mocked_message)
    receptor.stub = stub
    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    out, ops = receptor.forward(inputs, bittensor.proto.Modality.TENSOR)
    assert ops.item() == bittensor.proto.ReturnCode.Success
    # The NaN entry must have been replaced with 0 by the receptor.
    assert out[0][0][0] == 0
def test_serialize_deserialize_image(self):
    """Round-trip a random MNIST image through the MSGPACK serializer."""
    mnist = torchvision.datasets.MNIST(root='data/datasets/',
                                       train=True,
                                       download=True,
                                       transform=transforms.ToTensor())
    # Pick a random sample to serialize.
    image = mnist[randrange(len(mnist))][0]
    serializer = serialization.get_serializer(
        serialzer_type=bittensor.proto.Serializer.MSGPACK)
    proto_msg = serializer.serialize(
        image,
        modality=bittensor.proto.Modality.IMAGE,
        from_type=bittensor.proto.TensorType.TORCH)
    # Serialized message must mirror the tensor's metadata.
    assert proto_msg.requires_grad == image.requires_grad
    assert proto_msg.shape == list(image.shape)
    assert proto_msg.modality == bittensor.proto.Modality.IMAGE
    assert proto_msg.dtype != bittensor.proto.DataType.UNKNOWN
    roundtrip = serializer.deserialize(
        proto_msg, to_type=bittensor.proto.TensorType.TORCH)
    assert roundtrip.requires_grad == proto_msg.requires_grad
    assert list(roundtrip.shape) == proto_msg.shape
    assert serialization.torch_dtype_to_bittensor_dtype(
        roundtrip.dtype) != bittensor.proto.DataType.UNKNOWN
    # Data must be identical after the round trip.
    assert torch.all(torch.eq(roundtrip, image))
def test_serialize_deserialize_image(self):
    """Round-trip a synthetic single-channel image through the MSGPACK serializer."""
    # A fixed all-ones 1x28x28 image stands in for real data.
    image = torch.ones([1, 28, 28])
    serializer = serialization.get_serializer(
        serialzer_type=bittensor.proto.Serializer.MSGPACK)
    proto_msg = serializer.serialize(
        image,
        modality=bittensor.proto.Modality.IMAGE,
        from_type=bittensor.proto.TensorType.TORCH)
    # Serialized message must mirror the tensor's metadata.
    assert proto_msg.requires_grad == image.requires_grad
    assert proto_msg.shape == list(image.shape)
    assert proto_msg.modality == bittensor.proto.Modality.IMAGE
    assert proto_msg.dtype != bittensor.proto.DataType.UNKNOWN
    roundtrip = serializer.deserialize(
        proto_msg, to_type=bittensor.proto.TensorType.TORCH)
    assert roundtrip.requires_grad == proto_msg.requires_grad
    assert list(roundtrip.shape) == proto_msg.shape
    assert serialization.torch_dtype_to_bittensor_dtype(
        roundtrip.dtype) != bittensor.proto.DataType.UNKNOWN
    # Data must be identical after the round trip.
    assert torch.all(torch.eq(roundtrip, image))
def test_serialize_deserialize_tensor(self):
    """Round-trip a random float tensor; dtype must survive as FLOAT32."""
    serializer = serialization.get_serializer(
        serialzer_type=bittensor.proto.Serializer.MSGPACK)
    data = torch.rand([12, 23])
    proto_msg = serializer.serialize(
        data,
        modality=bittensor.proto.Modality.TENSOR,
        from_type=bittensor.proto.TensorType.TORCH)
    # Serialized message must mirror the tensor's metadata.
    assert proto_msg.requires_grad == data.requires_grad
    assert proto_msg.shape == list(data.shape)
    assert proto_msg.modality == bittensor.proto.Modality.TENSOR
    assert proto_msg.dtype == bittensor.proto.DataType.FLOAT32
    roundtrip = serializer.deserialize(
        proto_msg, to_type=bittensor.proto.TensorType.TORCH)
    assert roundtrip.requires_grad == proto_msg.requires_grad
    assert list(roundtrip.shape) == proto_msg.shape
    assert serialization.torch_dtype_to_bittensor_dtype(
        roundtrip.dtype) == bittensor.proto.DataType.FLOAT32
    # Values must be bit-identical after the round trip.
    assert torch.all(torch.eq(roundtrip, data))
def test_serialize_object_type_exception(self):
    """Serializing with an unsupported from_type must raise."""
    serializer = serialization.get_serializer(
        serialzer_type=bittensor.proto.Serializer.MSGPACK)
    sample = torch.ones([1, 28, 28])
    # 11 is not a valid TensorType; the serializer must reject it.
    with pytest.raises(
            serialization.SerializationTypeNotImplementedException):
        serializer.serialize(sample,
                             modality=bittensor.proto.Modality.IMAGE,
                             from_type=11)
def test_serialize(self):
    """Serialize/deserialize round trip over several random tensors.

    Bug fix: the original discarded the result of ``torch.all(torch.eq(...))``,
    so the round-trip equality was never actually checked and the test could
    not fail on data corruption. The comparison is now asserted. The serializer
    is also hoisted out of the loop since it is loop-invariant.
    """
    serializer = serialization.get_serializer(
        serialzer_type=bittensor.proto.Serializer.MSGPACK)
    for _ in range(10):
        tensor_a = torch.rand([12, 23])
        content = serializer.serialize(
            tensor_a,
            modality=bittensor.proto.Modality.TENSOR,
            from_type=bittensor.proto.TensorType.TORCH)
        tensor_b = serializer.deserialize(
            content, to_type=bittensor.proto.TensorType.TORCH)
        assert torch.all(torch.eq(tensor_a, tensor_b))
def test_deserialization_object_type_exception(self):
    """Deserializing with an unsupported to_type must raise."""
    serializer = serialization.get_serializer(
        serialzer_type=bittensor.proto.Serializer.MSGPACK)
    payload = torch.rand([12, 23])
    tensor_message = serializer.serialize(
        payload,
        modality=bittensor.proto.Modality.TEXT,
        from_type=bittensor.proto.TensorType.TORCH)
    # 11 is not a valid TensorType; deserialization must reject it.
    with pytest.raises(
            serialization.SerializationTypeNotImplementedException):
        serializer.deserialize(tensor_message, to_type=11)
def test_single_item_backward_request():
    """A backward request carrying only one tensor is rejected as InvalidRequest."""
    axon.serve(synapse)
    serializer = serialization.get_serializer(
        serialzer_type=bittensor_pb2.Serializer.MSGPACK)
    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    inputs_serialized = serializer.serialize(
        inputs,
        modality=bittensor_pb2.Modality.TENSOR,
        from_type=bittensor_pb2.TensorType.TORCH)
    # Backward expects two tensors (inputs + grads); send only one.
    request = bittensor_pb2.TensorMessage(version=bittensor.__version__,
                                          public_key=keypair.public_key,
                                          tensors=[inputs_serialized])
    response = axon.Backward(request, None)
    assert response.return_code == bittensor_pb2.ReturnCode.InvalidRequest
def test_serialize_object_type_exception(self):
    """Serializing an MNIST image with an unsupported from_type must raise."""
    mnist = torchvision.datasets.MNIST(root='data/datasets/',
                                       train=True,
                                       download=True,
                                       transform=transforms.ToTensor())
    # Pick a random image to feed to the serializer.
    image = mnist[randrange(len(mnist))][0]
    serializer = serialization.get_serializer(
        serialzer_type=bittensor.proto.Serializer.MSGPACK)
    # 11 is not a valid TensorType; the serializer must reject it.
    with pytest.raises(
            serialization.SerializationTypeNotImplementedException):
        serializer.serialize(image,
                             modality=bittensor.proto.Modality.IMAGE,
                             from_type=11)
def test_receptor_neuron_serve_timeout():
    """A Timeout return code yields a correctly shaped (zeroed) output tensor."""
    serializer = serialization.get_serializer(
        serialzer_type=bittensor.proto.Serializer.MSGPACK)
    response_tensor = torch.rand(3, 3, bittensor.__network_dim__)
    serialized_response = serializer.serialize(
        response_tensor,
        modality=bittensor.proto.Modality.TENSOR,
        from_type=bittensor.proto.TensorType.TORCH)
    # Server responds with Timeout even though a tensor is attached.
    mocked_message = bittensor.proto.TensorMessage(
        version=bittensor.__version__,
        public_key=wallet.keypair.public_key,
        return_code=bittensor.proto.ReturnCode.Timeout,
        tensors=[serialized_response])
    stub.Forward = MagicMock(return_value=mocked_message)
    receptor.stub = stub
    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    out, ops = receptor.forward(inputs, bittensor.proto.Modality.TENSOR)
    assert ops.item() == bittensor.proto.ReturnCode.Timeout
    assert list(out.shape) == [3, 3, bittensor.__network_dim__]
def test_forward_not_implemented():
    """Axon propagates NotImplemented from the nucleus forward call."""
    axon.serve(synapse)
    nucleus.forward = MagicMock(return_value=[
        None, 'not implemented', bittensor_pb2.ReturnCode.NotImplemented
    ])
    serializer = serialization.get_serializer(
        serialzer_type=bittensor_pb2.Serializer.MSGPACK)
    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    inputs_serialized = serializer.serialize(
        inputs,
        modality=bittensor_pb2.Modality.TENSOR,
        from_type=bittensor_pb2.TensorType.TORCH)
    request = bittensor_pb2.TensorMessage(version=bittensor.__version__,
                                          public_key=keypair.public_key,
                                          tensors=[inputs_serialized])
    response = axon.Forward(request, None)
    assert response.return_code == bittensor_pb2.ReturnCode.NotImplemented
def test_backward_success():
    """A well-formed two-tensor backward request succeeds with one float32 output."""
    axon.serve(synapse)
    serializer = serialization.get_serializer(
        serialzer_type=bittensor_pb2.Serializer.MSGPACK)
    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    inputs_serialized = serializer.serialize(
        inputs,
        modality=bittensor_pb2.Modality.TENSOR,
        from_type=bittensor_pb2.TensorType.TORCH)
    # Backward expects [inputs, grads]; reuse the same tensor for both.
    request = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=keypair.public_key,
        tensors=[inputs_serialized, inputs_serialized])
    nucleus.backward = MagicMock(
        return_value=[inputs, 'success', bittensor_pb2.ReturnCode.Success])
    response = axon.Backward(request, None)
    assert response.return_code == bittensor_pb2.ReturnCode.Success
    assert len(response.tensors) == 1
    assert response.tensors[0].shape == [3, 3, bittensor.__network_dim__]
    assert serialization_utils.bittensor_dtype_to_torch_dtype(
        response.tensors[0].dtype) == torch.float32
def test_remote_neuron_mock_server_shape_error():
    """A response whose batch dim mismatches the request yields ResponseShapeException."""
    serializer = serialization.get_serializer(
        serialzer_type=bittensor_pb2.Serializer.MSGPACK)
    # Server replies with batch size 1 while the request uses batch size 3.
    bad_response = torch.rand(1, 3, bittensor.__network_dim__)
    bad_serialized = serializer.serialize(
        bad_response,
        modality=bittensor_pb2.Modality.TENSOR,
        from_type=bittensor_pb2.TensorType.TORCH)
    mocked_message = bittensor_pb2.TensorMessage(
        version=bittensor.__version__,
        public_key=keypair.public_key,
        return_code=bittensor_pb2.ReturnCode.Success,
        tensors=[bad_serialized])
    stub.Forward = MagicMock(return_value=mocked_message)
    remote.stub = stub
    inputs = torch.rand(3, 3, bittensor.__network_dim__)
    out, ops = remote.forward(inputs, bittensor_pb2.Modality.TENSOR)
    assert ops.item() == bittensor_pb2.ReturnCode.ResponseShapeException
    # Output falls back to a tensor matching the request shape.
    assert list(out.shape) == [3, 3, bittensor.__network_dim__]
def test_serialize_deserialize_text(self):
    """Round-trip a zero-padded matrix of utf-8 byte values through the serializer."""
    words = ["This", "is", "a", "word", "list"]
    # Encode each word as a ByteTensor of its utf-8 bytes.
    encoded = [torch.ByteTensor(list(bytes(w, 'utf8'))) for w in words]
    longest = max(t.size()[0] for t in encoded)
    # Pack the words into a single int64 matrix, zero-padded on the right.
    data = torch.zeros((len(encoded), longest), dtype=torch.int64)
    for row, word_bytes in enumerate(encoded):
        data[row, 0:word_bytes.size()[0]] = word_bytes
    serializer = serialization.get_serializer(
        serialzer_type=bittensor.proto.Serializer.MSGPACK)
    proto_msg = serializer.serialize(
        data,
        modality=bittensor.proto.Modality.TEXT,
        from_type=bittensor.proto.TensorType.TORCH)
    # Serialized message must mirror the tensor's metadata.
    assert proto_msg.requires_grad == data.requires_grad
    assert proto_msg.shape == list(data.shape)
    assert proto_msg.modality == bittensor.proto.Modality.TEXT
    assert proto_msg.dtype != bittensor.proto.DataType.UNKNOWN
    roundtrip = serializer.deserialize(
        proto_msg, to_type=bittensor.proto.TensorType.TORCH)
    assert roundtrip.requires_grad == proto_msg.requires_grad
    assert list(roundtrip.shape) == proto_msg.shape
    assert serialization.torch_dtype_to_bittensor_dtype(
        roundtrip.dtype) != bittensor.proto.DataType.UNKNOWN
    # Data must be identical after the round trip.
    assert torch.all(torch.eq(roundtrip, data))
def _backward(self, request):
    r""" Performs validity checks on the grpc request before calling nucleus backward.
        Returns the output, message and code from the backend backward call.

        Args:
            request (:obj:`bittensor_pb2`, `required`): Tensor request proto.
        Returns:
            response: (:obj:`bittensor_pb2.Tensor, `required`):
                serialized tensor response from the nucleus call or None.
            message: (str, `required`):
                message associated with backward call, potentially error, or 'success'.
            code: (:obj:`bittensor_pb2.ReturnCode, `required`)
                return code associated with backward call i.e. Success or Timeout.
    """
    # ---- Check that we have a synapse ----.
    if self.synapse == None:
        message = "Remote axon not serving a synapse"
        code = bittensor_pb2.ReturnCode.NotServingSynapse
        return None, message, code

    # ---- Check request inputs ----.
    # A backward call carries exactly two tensors: the forward inputs and
    # the gradients of the outputs.
    if len(request.tensors) == 2:
        inputs_x = request.tensors[0]
        grads_dy = request.tensors[1]
        modality_x = inputs_x.modality
    else:
        message = "During backward: There are {} tensors in the request, expected 2.".format(
            len(request.tensors))
        code = bittensor_pb2.ReturnCode.InvalidRequest
        return None, message, code

    # ---- Deserialize request ----
    # Both tensors were produced by the same serializer on the caller side.
    try:
        serializer = serialization.get_serializer(inputs_x.serializer)
        inputs_x = serializer.deserialize(
            inputs_x, to_type=bittensor_pb2.TensorType.TORCH)
        grads_dy = serializer.deserialize(
            grads_dy, to_type=bittensor_pb2.TensorType.TORCH)
    except Exception as e:
        message = "Backward request deserialization failed with unknown error {}".format(
            e)
        code = bittensor_pb2.ReturnCode.RequestDeserializationException
        return None, message, code

    # --- Get call priority ----
    # random.random() breaks ties between callers with equal priority; the
    # fallback gives unknown public keys a baseline priority of 1.
    try:
        call_priority = self.priority[request.public_key] + random.random()
    except:
        call_priority = 1 + random.random()

    # ---- Save gradients to buffer for later use ----
    # Best-effort: a full queue drops the gradient rather than blocking the RPC.
    try:
        self.gradients.put(
            (call_priority, (request.public_key, inputs_x, grads_dy,
                             modality_x)),
            block=False)
    except queue.Full:
        logger.trace('gradient queue is full at size: {}',
                     self.gradients.qsize())

    # ---- Nucleus backward call ----
    try:
        outputs, message, code = self._nucleus.backward(
            synapse=self.synapse,
            inputs_x=inputs_x,
            grads_dy=grads_dy,
            modality=modality_x,
            priority=call_priority)
    except Exception as e:
        message = "Unkown exception when calling backward with error {}".format(
            e)
        code = bittensor_pb2.ReturnCode.UnknownException
        return None, message, code

    # ---- Serialize response ----
    # (Note: this serializes the nucleus output for the wire; always MSGPACK.)
    try:
        serializer = serialization.get_serializer(
            bittensor_pb2.Serializer.MSGPACK)
        outputs_serialized = serializer.serialize(
            outputs,
            modality=bittensor_pb2.Modality.TENSOR,
            from_type=bittensor_pb2.TensorType.TORCH)
    except Exception as e:
        message = "Backward request serialization failed with error {} and inputs {}".format(
            e, outputs)
        code = bittensor_pb2.ReturnCode.ResponseSerializationException
        return None, message, code

    # ---- Finaly return ----
    return outputs_serialized, message, code
def _forward(self, request):
    r""" Performs validity checks on the grpc request before calling nucleus forward.
        Returns the output, message and code from the backend forward call.

        Args:
            request (:obj:`bittensor_pb2`, `required`): Tensor request proto.
        Returns:
            response: (:obj:`bittensor_pb2.Tensor, `required`):
                serialized tensor response from the nucleus call or None.
            message: (str, `required`):
                message associated with forward call, potentially error, or 'success'.
            code: (:obj:`bittensor_pb2.ReturnCode, `required`)
                return code associated with forward call i.e. Success or Timeout.
    """
    # ---- Check synapse exists ----
    if self.synapse is None:
        message = "Remote axon not serving a synapse"
        code = bittensor_pb2.ReturnCode.NotServingSynapse
        return None, message, code

    # ---- Check Empty request ----
    if len(request.tensors) == 0:
        message = "Forward request contains {} tensors, expected 1 tensor in the forward call".format(
            len(request.tensors))
        code = bittensor_pb2.ReturnCode.EmptyRequest
        return None, message, code

    # ---- Check deserialization ----
    inputs = request.tensors[0]
    try:
        deserializer = serialization.get_serializer(
            serialzer_type=inputs.serializer)
        x = deserializer.deserialize(
            inputs, to_type=bittensor_pb2.TensorType.TORCH)
    except Exception as e:
        message = "Forward request deserialization failed with error {}".format(
            e)
        code = bittensor_pb2.ReturnCode.RequestDeserializationException
        return None, message, code

    # ---- Check shape and modality ----
    if x.shape[0] < 1:
        # Fixed typo in error text: 'Froward' -> 'Forward'.
        message = "Forward request batch dim exception with batch_size = {} ".format(
            x.shape[0])
        code = bittensor_pb2.ReturnCode.RequestShapeException
        return None, message, code

    if x.shape[1] < 1:
        message = "Forward request sequence dim exception with sequence_dim = {} ".format(
            x.shape[1])
        code = bittensor_pb2.ReturnCode.RequestShapeException
        return None, message, code

    # Per-modality rank requirements: TEXT=2, IMAGE=5, TENSOR=3.
    if inputs.modality == bittensor_pb2.Modality.TEXT:
        if len(x.shape) != 2:
            message = "Forward text input shape exception with len(request.shape) = {} must have rank 2.".format(
                len(x.shape))
            code = bittensor_pb2.ReturnCode.RequestShapeException
            return None, message, code

    if inputs.modality == bittensor_pb2.Modality.IMAGE:
        if len(x.shape) != 5:
            message = "Forward image input shape exception for len(shape) = {} must have rank 5".format(
                len(x.shape))
            code = bittensor_pb2.ReturnCode.RequestShapeException
            return None, message, code

    if inputs.modality == bittensor_pb2.Modality.TENSOR:
        if len(x.shape) != 3:
            message = "Forward message tensor input shape exception len(shape) = {} must have rank 3".format(
                len(x.shape))
            code = bittensor_pb2.ReturnCode.RequestShapeException
            return None, message, code

    # --- Get call priority ----
    call_priority = self.get_call_priority(request)

    # ---- Make Nucleus forward call. ----
    try:
        outputs, message, code = self._nucleus.forward(
            synapse=self.synapse.to(self.synapse.device),
            inputs=x.to(self.synapse.device),
            mode=inputs.modality,
            priority=call_priority)

        # ---- Catch Nucleus errors ----
        if code != bittensor_pb2.ReturnCode.Success:
            return None, message, code

    except Exception as e:
        message = "Unknown exception when calling nucleus forward {}".format(
            e)
        code = bittensor_pb2.ReturnCode.UnknownException
        return None, message, code

    # ---- Serialize response ----
    try:
        serializer = serialization.get_serializer(
            bittensor_pb2.Serializer.MSGPACK)
        outputs_serialized = serializer.serialize(
            outputs,
            modality=bittensor_pb2.Modality.TENSOR,
            from_type=bittensor_pb2.TensorType.TORCH)
    except Exception as e:
        # Fixed: this is a SERIALIZATION failure, so return
        # ResponseSerializationException (matching _backward's handling)
        # instead of the previous ResponseDeserializationException. Also
        # fixed the 'Serializtion' typo in the message.
        message = "Serialization of forward response failed with error {} and inputs: {}".format(
            e, outputs)
        code = bittensor_pb2.ReturnCode.ResponseSerializationException
        return None, message, code

    # ---- Return successful response ----
    return outputs_serialized, message, code
def forward(ctx, caller: Receptor, dummy: torch.Tensor,
            inputs: torch.Tensor,
            mode: bittensor.proto.Modality) -> Tuple[torch.Tensor, int]:
    """ Internal autograd-friendly Forward RPC call to a remote neuron (calls the Forward method on an Axon terminal.)

        Args:
            ctx: (:obj:`torch.autograd.ctx`, `required`):
                Autograd context, saves state information between forward and backward calls. i.e. inputs for gradient computation.
            caller: (:obj:`Receptor`, `required`):
                Caller object the remote neuron containing the endpoint information, RPC channel etc.
            dummy: (:obj:`torch.Tensor`, `required`):
                Dummy torch tensor used to ensure that torch.backward computation is called on this function
                regardless of the input types.
            inputs (:obj:`List[torch.Tensor]` of shape :obj:`(shape)`, `required`):
                Torch tensor to be sent to the caller associated endpoint neurons.
            mode (:obj:`bittensor.proto.Modality` of shape :obj:`(1)`, `required`):
                Bittensor forward modality type. Enum in [TEXT, IMAGE, TENSOR]

        Returns:
            output (:obj:`Tuple[torch.FloatTensor, torch.LongTensor]`, `optional`):
                Result from forward call. May be None in the case of failure.
            code (:obj:`bittensor.proto.ReturnCode`, `required`):
                Return code associated with forward call.
    """
    # ---- Save for backward call ----
    ctx.caller = caller
    ctx.mode = mode
    ctx.inputs = inputs

    # Zero tensor matching the expected response shape, returned on any failure.
    zeros = nill_response_for(inputs)
    try:
        # ---- Check inputs size ----
        if torch.numel(inputs) == 0:
            return zeros, torch.tensor(
                bittensor.proto.ReturnCode.EmptyRequest)

        # ---- Inputs Serialization ----
        try:
            serializer = serialization.get_serializer(
                bittensor.proto.Serializer.MSGPACK)
            serialized_inputs = serializer.serialize(
                inputs,
                modality=mode,
                from_type=bittensor.proto.TensorType.TORCH)
        except Exception as e:
            logger.warning('Serialization error with error {}', e)
            return zeros, torch.tensor(
                bittensor.proto.ReturnCode.RequestSerializationException)
        # Saved so backward can resend the same serialized inputs.
        ctx.serialized_inputs = serialized_inputs

        # ---- Build request ----
        request = bittensor.proto.TensorMessage(
            version=bittensor.__version__,
            public_key=ctx.caller.wallet.hotkey.public_key,
            nounce=ctx.caller.nounce,
            signature=ctx.caller.signature,
            tensors=[serialized_inputs])

        # ---- Make RPC call ----
        try:
            start_time = time.time()
            ctx.caller.stats.forward_qps.update(1)
            ctx.caller.stats.forward_bytes_out.update(
                sys.getsizeof(request))
            response = ctx.caller.stub.Forward(
                request, timeout=caller.config.receptor.timeout)
            ctx.caller.stats.forward_bytes_in.update(
                sys.getsizeof(response))
            ctx.caller.stats.forward_elapsed_time.update(
                (time.time() - start_time))

            # ---- Catch missing return code ----
            try:
                bittensor_code = response.return_code
            except Exception:
                # BUG FIX: 'bittensor_code' is unbound when the attribute
                # read fails, so the previous
                # 'return zeros, torch.tensor(bittensor_code)' raised a
                # NameError (masked by the outer catch-all). Return
                # UnknownException explicitly instead.
                logger.error(
                    'Unknown exception returned from remote host with message {}, {}',
                    response.message, traceback.format_exc())
                return zeros, torch.tensor(
                    bittensor.proto.ReturnCode.UnknownException)

            # ---- Catch bittensor errors ----
            if bittensor_code == bittensor.proto.ReturnCode.UnknownException:
                logger.error(
                    'Unknown exception returned from remote host with message {}, {}',
                    response.message, traceback.format_exc())
                return zeros, torch.tensor(bittensor_code)

            elif bittensor_code != bittensor.proto.ReturnCode.Success:
                return zeros, torch.tensor(bittensor_code)

        # ---- Catch GRPC Errors ----
        except grpc.RpcError as rpc_error_call:
            grpc_code = rpc_error_call.code()

            if grpc_code == grpc.StatusCode.DEADLINE_EXCEEDED:
                return zeros, torch.tensor(
                    bittensor.proto.ReturnCode.Timeout)

            elif grpc_code == grpc.StatusCode.UNAVAILABLE:
                return zeros, torch.tensor(
                    bittensor.proto.ReturnCode.Unavailable)

            else:
                logger.error(
                    'Uncaught GPRC error exception with code {} from endpoint {}',
                    grpc_code, caller.endpoint)
                return zeros, torch.tensor(
                    bittensor.proto.ReturnCode.UnknownException)

        # ---- Catch Unknown Errors ----
        except Exception as e:
            logger.error(
                'Uncaught error in forward call with error {} and endpoint',
                e, caller.endpoint)
            return zeros, torch.tensor(
                bittensor.proto.ReturnCode.UnknownException)

        # ---- Check tensor response length ----
        if len(response.tensors) == 0:
            return zeros, torch.tensor(
                bittensor.proto.ReturnCode.EmptyResponse)

        # ---- Deserialize response ----
        try:
            outputs = response.tensors[0]
            deserializer = serialization.get_serializer(outputs.serializer)
            outputs = deserializer.deserialize(
                outputs, to_type=bittensor.proto.TensorType.TORCH)
        except Exception as e:
            logger.error(
                'Failed to serialize responses from forward call with error {}',
                e)
            return zeros, torch.tensor(
                bittensor.proto.ReturnCode.ResponseDeserializationException)

        # ---- Check response shape ----
        # Response must match the request's batch and sequence dims and the
        # fixed network dimension.
        if outputs.size(0) != inputs.size(0) \
                or outputs.size(1) != inputs.size(1) \
                or outputs.size(2) != bittensor.__network_dim__:
            logger.error(
                'Forward request returned tensor with incorrect shape {}',
                list(outputs.shape))
            return zeros, torch.tensor(
                bittensor.proto.ReturnCode.ResponseShapeException)

        # ---- Safe catch NaNs and replace with 0.0 ----
        outputs = torch.where(torch.isnan(outputs),
                              torch.zeros_like(outputs), outputs)

    # ---- Catch all ----
    except Exception as e:
        logger.error('Forward request returned unknown error {}', e)
        return zeros, torch.tensor(
            bittensor.proto.ReturnCode.UnknownException)

    # ---- Return ----
    return outputs, torch.tensor(response.return_code)
def backward(ctx, grads: torch.FloatTensor,
             code: torch.FloatTensor) -> Optional[torch.Tensor]:
    """ Internal autograd-friendly Backward RPC call to a remote neuron (calls the Backward method on an remote Axon terminal.)

        Args:
            ctx: (:obj:`torch.autograd.ctx`, `required`):
                Autograd context, saves state information between forward and backward calls. i.e. inputs for gradient computation.
            grads (:obj:`List[torch.Tensor]` of shape :obj:`(shape)`, `required`):
                Gradients of this function's outputs computed during the loss.backward() call.
            code (:obj:`bittensor.proto.Modality` of shape :obj:`(1)`, `required`):
                Code output from the forward call.

        Returns:
            output (:obj:`Tuple[torch.FloatTensor`, torch.LongTensor]`, `optional`):
                Gradients of the inputs with respect to the inputs and grads of the outputs.
                Always a zeros tensor (see NOTE below); the 4-tuple matches the
                forward signature (caller, dummy, inputs, mode).
    """
    # ---- Zeros response in the case of failure ----
    zeros = nill_response_for(ctx.inputs)

    # ---- Check if are passing gradients ----
    if not ctx.caller.config.receptor.pass_gradients:
        return (None, None, zeros, None)

    # ---- Check that forward query was a success ----
    if code.item() != bittensor.proto.ReturnCode.Success:
        return (None, None, zeros, None)

    # ---- Try to pass gradients ----
    else:
        try:

            # ---- Get forward call serialzied inputs ----
            # ctx.serialized_inputs is only set when the forward-side
            # serialization succeeded.
            try:
                serialized_inputs = ctx.serialized_inputs
            except:
                logger.trace(
                    'backward failed because forward previously failed.')
                return (None, None, zeros, None)

            # ---- Serialization ----
            try:
                # ---- Get serializer ----
                serializer = serialization.get_serializer(
                    bittensor.proto.Serializer.MSGPACK)

                # ---- Serialize grads to bittensor proto Tensors ----
                serialized_grads = serializer.serialize(
                    grads,
                    modality=bittensor.proto.Modality.TENSOR,
                    from_type=bittensor.proto.TensorType.TORCH)

            except Exception as e:
                logger.trace(
                    'backward failed during serialization of gradients.')
                return (None, None, zeros, None)

            # ---- Build request for backward ----
            request = bittensor.proto.TensorMessage(
                version=bittensor.__version__,
                public_key=ctx.caller.wallet.hotkey.public_key,
                nounce=ctx.caller.nounce,
                signature=ctx.caller.signature,
                tensors=[serialized_inputs, serialized_grads])

            # --- Send non blocking grad request ----
            # NOTE(const): we dont care about the response.
            # NOTE(review): 'backwar_bytes_out'/'backwar_bytes_in' look
            # misspelled, but they must match the stat attribute names
            # defined on caller.stats — confirm there before renaming.
            try:
                ctx.caller.stats.backward_qps.update(1)
                ctx.caller.stats.backwar_bytes_out.update(
                    sys.getsizeof(request))
                # Fire-and-forget future; the RPC result is intentionally dropped.
                ctx.caller.stub.Backward.future(
                    request, timeout=ctx.caller.config.receptor.timeout)
                ctx.caller.stats.backwar_bytes_in.update(
                    0.0)  # responses are dropped.

            except:
                logger.trace(
                    'backward failed during backward call. Do not care.')
                return (None, None, zeros, None)

            # ---- Always return zeros ----
            # NOTE(const): We can return non zeros but a remote host could mess with your training
            # without you knowing about it. i.e. by passing you malicious gradients.
            return (None, None, zeros, None)

        except:

            # ---- Catch all exceptions in Backward ----
            rollbar.send_exception()
            return (None, None, zeros, None)
def _backward(self, request):
    r""" Performs validity checks on the grpc request before calling nucleus backward.
        Returns the output, message and code from the backend backward call.

        Args:
            request (:obj:`bittensor.proto`, `required`): Tensor request proto.
        Returns:
            response: (:obj:`bittensor.proto.Tensor, `required`):
                serialized tensor response from the nucleus call or None.
            message: (str, `required`):
                message associated with backward call, potentially error, or 'success'.
            code: (:obj:`bittensor.proto.ReturnCode, `required`)
                return code associated with backward call i.e. Success or Timeout.
    """
    # ---- Check that we have a synapse ----.
    if self.synapse == None:
        message = "Remote axon not serving a synapse"
        code = bittensor.proto.ReturnCode.NotServingSynapse
        return None, message, code

    # ---- Check request inputs ----.
    # A backward call carries exactly two tensors: forward inputs + grads.
    if len(request.tensors) == 2:
        inputs_x = request.tensors[0]
        grads_dy = request.tensors[1]
        modality_x = inputs_x.modality
    else:
        message = "During backward: There are {} tensors in the request, expected 2.".format(
            len(request.tensors))
        logger.debug(
            '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>InvalidRequest</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
            request.public_key, message)
        code = bittensor.proto.ReturnCode.InvalidRequest
        return None, message, code

    # ---- Deserialize request ---
    # Both tensors were produced by the same serializer on the caller side.
    try:
        serializer = serialization.get_serializer(inputs_x.serializer)
        inputs_x = serializer.deserialize(
            inputs_x, to_type=bittensor.proto.TensorType.TORCH)
        grads_dy = serializer.deserialize(
            grads_dy, to_type=bittensor.proto.TensorType.TORCH)
    except Exception as e:
        message = "Request serialization exception with error: {}".format(
            str(e))
        logger.debug(
            '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>RequestDeserializationException</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
            request.public_key, message)
        code = bittensor.proto.ReturnCode.RequestDeserializationException
        return None, message, code

    # ---- Check shapes ----
    # Per-modality rank requirements: TEXT=2, IMAGE=5, TENSOR=3.
    if modality_x == bittensor.proto.Modality.TEXT:
        if len(inputs_x.shape) != 2:
            message = "Forward text input shape exception with len(request.shape) = {} must have rank 2.".format(
                len(inputs_x.shape))
            logger.debug(
                '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>RequestShapeException</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
                request.public_key, message)
            return None, message, bittensor.proto.ReturnCode.RequestShapeException

    if modality_x == bittensor.proto.Modality.IMAGE:
        if len(inputs_x.shape) != 5:
            message = "Forward image input shape exception for len(shape) = {} must have rank 5".format(
                len(inputs_x.shape))
            logger.debug(
                '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>RequestShapeException</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
                request.public_key, message)
            return None, message, bittensor.proto.ReturnCode.RequestShapeException

    if modality_x == bittensor.proto.Modality.TENSOR:
        if len(inputs_x.shape) != 3:
            message = "Forward message tensor input shape exception len(shape) = {} must have rank 3".format(
                len(inputs_x.shape))
            logger.debug(
                '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>RequestShapeException</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
                request.public_key, message)
            return None, message, bittensor.proto.ReturnCode.RequestShapeException

    # Gradients must be rank 3 and share batch/sequence dims with the inputs.
    if len(grads_dy.shape) != 3:
        message = "Passed gradients must have rank 3 but got {}".format(
            len(grads_dy.shape))
        logger.debug(
            '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>RequestShapeException</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
            request.public_key, message)
        return None, message, bittensor.proto.ReturnCode.RequestShapeException

    if grads_dy.shape[0] != inputs_x.shape[0] or grads_dy.shape[
            1] != inputs_x.shape[1]:
        message = "Passed gradients must same first and second dimension as passed inputs got shapes {} and {}".format(
            grads_dy.shape, inputs_x.shape)
        logger.debug(
            '<white>Axon</white> <red>Backward Request</red> --->x <white>code</white>:<yellow>RequestShapeException</yellow>, <white>from</white>:<cyan>{}</cyan>, <white>message</white>:<red>{}</red>',
            request.public_key, message)
        return None, message, bittensor.proto.ReturnCode.RequestShapeException

    # --- Get call priority ----
    # random.random() breaks ties; unknown public keys get baseline priority 1.
    try:
        call_priority = self.priority[request.public_key] + random.random()
    except:
        call_priority = 1 + random.random()

    # ---- Save gradients to buffer for later use. ---
    # Best-effort: a full queue drops the gradient rather than blocking the RPC.
    try:
        self.gradients.put(
            (call_priority, (request.public_key, inputs_x, grads_dy,
                             modality_x)),
            block=False)
    except queue.Full:
        logger.trace('gradient queue is full at size: {}',
                     self.gradients.qsize())

    # ---- nucleus.Nucleus backward call ----
    try:
        outputs, message, code = self.nucleus.backward(
            synapse=self.synapse,
            inputs_x=inputs_x,
            grads_dy=grads_dy,
            modality=modality_x,
            priority=call_priority)
    except Exception as e:
        message = "Unkown exception when calling backward with error {}".format(
            e)
        code = bittensor.proto.ReturnCode.UnknownException
        return None, message, code

    # ---- Serialize response ----
    # (Note: this serializes the nucleus output for the wire; always MSGPACK.)
    try:
        serializer = serialization.get_serializer(
            bittensor.proto.Serializer.MSGPACK)
        outputs_serialized = serializer.serialize(
            outputs,
            modality=bittensor.proto.Modality.TENSOR,
            from_type=bittensor.proto.TensorType.TORCH)
    except Exception as e:
        message = "Backward request serialization failed with error {} and inputs {}".format(
            e, outputs)
        code = bittensor.proto.ReturnCode.ResponseSerializationException
        return None, message, code

    # ---- Finaly return ----
    logger.debug(
        '<white>Axon</white> <green>Backward Response</green> <--- <white>code</white>:<green>Success</green>, <white>to</white>:<cyan>{}</cyan>, <white>outputs</white>:<cyan>{}</cyan>',
        request.public_key, outputs.shape)
    return outputs_serialized, message, code