예제 #1
0
파일: hooks_test.py 프로젝트: kmader/PySyft
    def test_types_guard(self):
        """types_guard rejects non-type inputs and maps tensor type names to classes."""
        hook = TorchHook(verbose=False)

        # An int is not a serialized type name, so types_guard must refuse it.
        # NOTE(review): the intended error is presumably TypeError; the broad
        # Exception is kept because the guard's exact exception isn't visible here.
        with self.assertRaises(Exception):
            hook.types_guard(3)

        # An arbitrary string must not resolve to a type either.
        with self.assertRaises(Exception):
            hook.types_guard("asdf")

        tensor_types = {
            'torch.FloatTensor': torch.FloatTensor,
            'torch.DoubleTensor': torch.DoubleTensor,
            'torch.HalfTensor': torch.HalfTensor,
            'torch.ByteTensor': torch.ByteTensor,
            'torch.CharTensor': torch.CharTensor,
            'torch.ShortTensor': torch.ShortTensor,
            'torch.IntTensor': torch.IntTensor,
            'torch.LongTensor': torch.LongTensor
        }

        # Every whitelisted name must resolve to its torch class; the msg
        # identifies which name failed.
        for type_name, tensor_type in tensor_types.items():
            self.assertEqual(hook.types_guard(type_name), tensor_type,
                             msg=type_name)
예제 #2
0
    def test_types_guard(self):
        """hook.guard.types_guard rejects non-type inputs and maps tensor type names to classes."""
        hook = TorchHook(verbose=False)
        guard = hook.guard

        # An int is not a serialized type name, so types_guard must refuse it.
        # Calling through `guard` (like the rest of this test) ensures the
        # expected exception comes from types_guard itself, not from an
        # AttributeError on a missing `hook.types_guard`.
        # NOTE(review): the intended error is presumably TypeError; the broad
        # Exception is kept because the guard's exact exception isn't visible here.
        with self.assertRaises(Exception):
            guard.types_guard(3)

        # An arbitrary string must not resolve to a type either.
        with self.assertRaises(Exception):
            guard.types_guard("asdf")

        tensor_types = {
            'torch.FloatTensor': torch.FloatTensor,
            'torch.DoubleTensor': torch.DoubleTensor,
            'torch.HalfTensor': torch.HalfTensor,
            'torch.ByteTensor': torch.ByteTensor,
            'torch.CharTensor': torch.CharTensor,
            'torch.ShortTensor': torch.ShortTensor,
            'torch.IntTensor': torch.IntTensor,
            'torch.LongTensor': torch.LongTensor
        }

        # Every whitelisted name must resolve to its torch class; the msg
        # identifies which name failed.
        for type_name, tensor_type in tensor_types.items():
            self.assertEqual(guard.types_guard(type_name), tensor_type,
                             msg=type_name)
예제 #3
0
파일: torch_test.py 프로젝트: kmader/PySyft
    def test_deser_tensor_from_message(self):
        """Deserializing a FloatTensor JSON message restores its data but not its id.

        The resulting tensor "has not been registered", so its ``id`` must
        differ from the ``id`` carried in the message.
        """
        hook = TorchHook(verbose=False)

        message_id = 9756847736
        # Build the message as a dict and round-trip it through JSON so the
        # deserializer sees exactly what a decoded wire message contains.
        # A hand-wrapped JSON string literal is error-prone: a backslash-newline
        # inside the literal keeps the next line's indentation, corrupting keys.
        message_obj = json.loads(json.dumps({
            "torch_type": "torch.FloatTensor",
            "data": [1.0, 2.0, 3.0, 4.0, 5.0],
            "id": message_id,
            "owners": [1],
            "is_pointer": False,
        }))
        obj_type = hook.types_guard(message_obj['torch_type'])
        unregistered_tensor = torch.FloatTensor.deser(obj_type, message_obj)

        # Element-wise comparison: all 5 values must match the message data.
        assert (unregistered_tensor == torch.FloatTensor(
            [1, 2, 3, 4, 5])).float().sum() == 5

        # Has not been registered: the deserialized tensor does not keep the
        # id carried in the message.
        assert unregistered_tensor.id != message_id