コード例 #1
0
    def test_serialize_deserialize_tensor(self):
        """Round-trip a random 12x23 tensor through the MSGPACK serializer.

        Checks that the serialized message records the tensor's
        ``requires_grad`` flag, its shape, TENSOR modality and FLOAT32 dtype.
        Then checks that deserializing the message back to a torch tensor
        restores the same metadata and exactly the same values.
        """
        original = torch.rand([12, 23])
        msgpack_serializer = serialization.get_serializer(
            serialzer_type=bittensor.proto.Serializer.MSGPACK)

        # Forward direction: torch tensor -> proto message.
        message = msgpack_serializer.serialize(
            original,
            modality=bittensor.proto.Modality.TENSOR,
            from_type=bittensor.proto.TensorType.TORCH,
        )
        assert original.requires_grad == message.requires_grad
        assert list(original.shape) == message.shape
        assert message.modality == bittensor.proto.Modality.TENSOR
        assert message.dtype == bittensor.proto.DataType.FLOAT32

        # Reverse direction: proto message -> torch tensor.
        restored = msgpack_serializer.deserialize(
            message, to_type=bittensor.proto.TensorType.TORCH)
        assert message.requires_grad == restored.requires_grad
        assert message.shape == list(restored.shape)
        restored_dtype = serialization.torch_dtype_to_bittensor_dtype(
            restored.dtype)
        assert restored_dtype == bittensor.proto.DataType.FLOAT32

        # Every element must survive the round trip unchanged.
        assert torch.all(torch.eq(restored, original))
コード例 #2
0
    def test_serialize_deserialize_image(self):
        """Round-trip one randomly chosen MNIST image through the MSGPACK serializer.

        Loads the MNIST training split from ``data/datasets/``.
        ``download=True`` lets torchvision fetch it if it is missing, which
        needs network access. The test then picks one image at random and
        serializes it with IMAGE modality. It checks that the message
        records the gradient flag, the shape and a known dtype, and that
        deserializing back to a torch tensor restores the same metadata and
        identical values.

        The sampled index is attached to every assertion message, so a
        failure on one particular image can be reproduced.
        """
        mnist = torchvision.datasets.MNIST(root='data/datasets/',
                                           train=True,
                                           download=True,
                                           transform=transforms.ToTensor())

        # Pick a random sample and remember which one: without the index,
        # a failure triggered by a specific image cannot be reproduced.
        index = randrange(len(mnist))
        image = mnist[index][0]  # element 0 of each sample is the transformed image
        context = 'MNIST train image index {}'.format(index)

        serializer = serialization.get_serializer(
            serialzer_type=bittensor.proto.Serializer.MSGPACK)
        serialized_image_tensor_message = serializer.serialize(
            image,
            modality=bittensor.proto.Modality.IMAGE,
            from_type=bittensor.proto.TensorType.TORCH)

        assert image.requires_grad == serialized_image_tensor_message.requires_grad, context
        assert list(image.shape) == serialized_image_tensor_message.shape, context
        assert serialized_image_tensor_message.modality == bittensor.proto.Modality.IMAGE, context
        assert serialized_image_tensor_message.dtype != bittensor.proto.DataType.UNKNOWN, context

        deserialized_image_tensor_message = serializer.deserialize(
            serialized_image_tensor_message,
            to_type=bittensor.proto.TensorType.TORCH)
        assert serialized_image_tensor_message.requires_grad == deserialized_image_tensor_message.requires_grad, context
        assert serialized_image_tensor_message.shape == list(
            deserialized_image_tensor_message.shape), context
        assert serialization.torch_dtype_to_bittensor_dtype(
            deserialized_image_tensor_message.dtype
        ) != bittensor.proto.DataType.UNKNOWN, context

        # Pixel values must be bit-for-bit identical after the round trip.
        assert torch.all(torch.eq(deserialized_image_tensor_message, image)), context
コード例 #3
0
    def test_serialize_deserialize_image(self):
        """Round-trip a synthetic 1x28x28 all-ones tensor as an IMAGE message.

        Uses a fixed constant tensor instead of real image data, so no
        dataset is needed. It checks that MSGPACK serialization records the
        gradient flag, the shape, IMAGE modality and a known dtype, and that
        deserializing back to a torch tensor restores the same metadata and
        identical values.
        """
        image = torch.ones([1, 28, 28])
        msgpack_serializer = serialization.get_serializer(
            serialzer_type=bittensor.proto.Serializer.MSGPACK)

        # Forward direction: torch tensor -> proto message.
        message = msgpack_serializer.serialize(
            image,
            modality=bittensor.proto.Modality.IMAGE,
            from_type=bittensor.proto.TensorType.TORCH,
        )
        assert image.requires_grad == message.requires_grad
        assert list(image.shape) == message.shape
        assert message.modality == bittensor.proto.Modality.IMAGE
        assert message.dtype != bittensor.proto.DataType.UNKNOWN

        # Reverse direction: proto message -> torch tensor.
        restored = msgpack_serializer.deserialize(
            message, to_type=bittensor.proto.TensorType.TORCH)
        assert message.requires_grad == restored.requires_grad
        assert message.shape == list(restored.shape)
        restored_dtype = serialization.torch_dtype_to_bittensor_dtype(
            restored.dtype)
        assert restored_dtype != bittensor.proto.DataType.UNKNOWN

        # Values must come back unchanged.
        assert torch.all(torch.eq(restored, image))
コード例 #4
0
    def test_serialize_deserialize_text(self):
        """Round-trip a zero-padded matrix of UTF-8 byte codes as a TEXT message.

        Each word is encoded to its UTF-8 bytes and copied into one row of an
        int64 matrix. The matrix is as wide as the longest encoding, and
        shorter rows keep trailing zeros. The matrix is serialized with the
        MSGPACK serializer and deserialized back to a torch tensor. The test
        checks the gradient flag, the shape, TEXT modality, a known dtype and
        exact values.
        """
        words = ["This", "is", "a", "word", "list"]
        encoded_words = [torch.ByteTensor(list(word.encode('utf8')))
                         for word in words]
        width = max((encoded.size()[0] for encoded in encoded_words), default=0)

        # Right-pad every word with zeros up to the common width.
        padded = torch.zeros((len(encoded_words), width), dtype=torch.int64)
        for row, encoded in enumerate(encoded_words):
            padded[row, :encoded.size()[0]] = encoded

        msgpack_serializer = serialization.get_serializer(
            serialzer_type=bittensor.proto.Serializer.MSGPACK)

        # Forward direction: torch tensor -> proto message.
        message = msgpack_serializer.serialize(
            padded,
            modality=bittensor.proto.Modality.TEXT,
            from_type=bittensor.proto.TensorType.TORCH,
        )
        assert padded.requires_grad == message.requires_grad
        assert list(padded.shape) == message.shape
        assert message.modality == bittensor.proto.Modality.TEXT
        assert message.dtype != bittensor.proto.DataType.UNKNOWN

        # Reverse direction: proto message -> torch tensor.
        restored = msgpack_serializer.deserialize(
            message, to_type=bittensor.proto.TensorType.TORCH)
        assert message.requires_grad == restored.requires_grad
        assert message.shape == list(restored.shape)
        restored_dtype = serialization.torch_dtype_to_bittensor_dtype(
            restored.dtype)
        assert restored_dtype != bittensor.proto.DataType.UNKNOWN

        # Byte codes (including the zero padding) must match exactly.
        assert torch.all(torch.eq(restored, padded))