def test_deserialize_federated_value_with_incompatible_member_types_raises(
      self):
    """Checks that federated members with no common type fail to deserialize.

    The federated type built below deliberately omits the member type, so the
    deserializer must derive it from the members themselves. An int32 member
    and a float32 member are expected to make that fail with a `TypeError`.
    """
    int_proto, _ = value_serialization.serialize_value(
        10, computation_types.to_type(tf.int32))
    float_proto, _ = value_serialization.serialize_value(
        10., computation_types.to_type(tf.float32))
    # Only the placement is taken from this serialized type; the member type
    # is intentionally left unset on the proto constructed below.
    clients_placement = type_serialization.serialize_type(
        computation_types.at_clients(tf.int32)).federated.placement
    member_agnostic_type = computation_pb2.FederatedType(
        placement=clients_placement, all_equal=False)
    federated_value_proto = serialization_bindings.Value(
        federated=serialization_bindings.Federated(
            type=member_agnostic_type, value=[int_proto, float_proto]))

    for proto in (int_proto, float_proto, federated_value_proto):
      self.assertIsInstance(proto, serialization_bindings.Value)

    with self.assertRaises(TypeError):
      value_serialization.deserialize_value(federated_value_proto)
Exemplo n.º 2
0
    def test_executor_service_create_one_arg_computation_value_and_call(self):
        """Creates a computation and an int32 remotely and calls one on the other."""
        env = TestEnv(
            executor_test_utils.BasicTestExFactory(
                eager_tf_executor.EagerTFExecutor()))

        @tensorflow_computation.tf_computation(tf.int32)
        def comp(x):
            return tf.add(x, 1)

        def push_value(test_env, proto):
            # Sends `proto` to the service and returns the new value reference.
            reply = test_env.stub.CreateValue(
                executor_pb2.CreateValueRequest(executor=test_env.executor_pb,
                                                value=proto))
            self.assertIsInstance(reply, executor_pb2.CreateValueResponse)
            return reply.value_ref

        comp_proto, _ = value_serialization.serialize_value(comp)
        comp_ref = push_value(env, comp_proto)
        arg_proto, _ = value_serialization.serialize_value(10, tf.int32)
        arg_ref = push_value(env, arg_proto)

        call_reply = env.stub.CreateCall(
            executor_pb2.CreateCallRequest(executor=env.executor_pb,
                                           function_ref=comp_ref,
                                           argument_ref=arg_ref))
        self.assertIsInstance(call_reply, executor_pb2.CreateCallResponse)
        result = env.get_value(str(call_reply.value_ref.id))
        self.assertEqual(result, 11)
        del env
 def test_serialize_struct_with_type_element_mismatch(self):
   """A one-element dict cannot be serialized as a two-element struct."""
   expected_message = ('Cannot serialize a struct value of 1 elements to a '
                       'struct type requiring 2 elements.')
   one_field_value = {'a': 1}
   with self.assertRaisesRegex(TypeError, expected_message):
     two_field_type = computation_types.StructType([('a', tf.int32),
                                                    ('b', tf.int32)])
     value_serialization.serialize_value(one_field_value, two_field_type)
  def test_deserialize_federated_value_promotes_types(self):
    """Members typed as unnamed vs. named structs deserialize to the named one.

    Both members hold `[10]`. One is serialized with an unnamed int32 element
    and the other with the element named `a`. The federated type omits the
    member type, and the result is expected to be placed at clients with the
    named (larger) struct type.
    """
    member_value = [10]
    smaller_type = computation_types.StructType(
        [(None, computation_types.to_type(tf.int32))])
    smaller_proto, _ = value_serialization.serialize_value(
        member_value, smaller_type)
    larger_type = computation_types.StructType(
        [('a', computation_types.to_type(tf.int32))])
    larger_proto, _ = value_serialization.serialize_value(
        member_value, larger_type)
    clients_placement = type_serialization.serialize_type(
        computation_types.at_clients(tf.int32)).federated.placement

    federated_value_proto = executor_pb2.Value(
        federated=executor_pb2.Value.Federated(
            type=computation_pb2.FederatedType(
                placement=clients_placement, all_equal=False),
            value=[larger_proto, smaller_proto]))

    _, deserialized_type_spec = value_serialization.deserialize_value(
        federated_value_proto)
    type_test_utils.assert_types_identical(
        deserialized_type_spec, computation_types.at_clients(larger_type))
Exemplo n.º 5
0
    def test_create_value(self):
        """Creates tensor, struct and computation values and checks their IDs."""
        executor = executor_bindings.create_reference_resolving_executor(
            executor_bindings.create_tensorflow_executor())

        def check_owned_id(owned, expected_ref):
            # Each newly created value is expected to get the next sequential ref.
            self.assertIsInstance(owned, executor_bindings.OwnedValueId)
            self.assertEqual(owned.ref, expected_ref)
            self.assertEqual(str(owned), str(expected_ref))
            self.assertEqual(repr(owned), f'<OwnedValueId: {expected_ref}>')

        # 1. A rank-1 int64 tensor round-trips with its exact type.
        tensor_type = TensorType(shape=[3], dtype=tf.int64)
        serialized, _ = value_serialization.serialize_value([1, 2, 3],
                                                            tensor_type)
        owned = executor.create_value(serialized)
        check_owned_id(owned, 0)
        roundtrip_value, roundtrip_type = value_serialization.deserialize_value(
            executor.materialize(owned.ref))
        type_test_utils.assert_types_identical(roundtrip_type, tensor_type)
        self.assertAllEqual(roundtrip_value, [1, 2, 3])

        # 2. A named struct of tensors gets the next ID. The names `a` and `b`
        # are lost on materialization, so the result is only checked for
        # assignability to the (stricter) named type.
        struct_type = StructType([
            ('a', TensorType(shape=[3], dtype=tf.int64)),
            ('b', TensorType(shape=[], dtype=tf.float32))
        ])
        serialized, _ = value_serialization.serialize_value(
            collections.OrderedDict(a=tf.constant([1, 2, 3]),
                                    b=tf.constant(42.0)), struct_type)
        owned = executor.create_value(serialized)
        check_owned_id(owned, 1)
        roundtrip_value, roundtrip_type = value_serialization.deserialize_value(
            executor.materialize(owned.ref))
        self.assertTrue(struct_type.is_assignable_from(roundtrip_type))
        self.assertAllClose(
            type_conversions.type_to_py_container(roundtrip_value, struct_type),
            collections.OrderedDict(a=(1, 2, 3), b=42.0))

        # 3. A computation value also gets the next ID.
        @tensorflow_computation.tf_computation(tf.int32, tf.int32)
        def foo(a, b):
            return tf.add(a, b)

        serialized, _ = value_serialization.serialize_value(foo)
        owned = executor.create_value(serialized)
        check_owned_id(owned, 2)
Exemplo n.º 6
0
    def test_dispose_does_not_trigger_cleanup(self):
        """Disposing a value must not call `clean_up_executor` on the factory."""

        class MockFactory(executor_factory.ExecutorFactory, mock.MagicMock):

            def create_executor(self, *args, **kwargs):
                return mock.MagicMock()

            def clean_up_executor(self):
                return

        ex_factory = MockFactory()
        # Swap the no-op for a mock so its calls can be asserted on below.
        ex_factory.clean_up_executor = mock.MagicMock()
        env = TestEnv(ex_factory)

        ten_proto, _ = value_serialization.serialize_value(
            tf.constant(10.0).numpy(), tf.float32)
        create_reply = env.stub.CreateValue(
            executor_pb2.CreateValueRequest(executor=env.executor_pb,
                                            value=ten_proto))
        self.assertIsInstance(create_reply, executor_pb2.CreateValueResponse)
        # Checks that the value appears in the service's `_values` map.
        env.get_value_future_directly(str(create_reply.value_ref.id))

        dispose_request = executor_pb2.DisposeRequest(executor=env.executor_pb)
        dispose_request.value_ref.append(create_reply.value_ref)
        env.stub.Dispose(dispose_request)

        # Propagating close down the executor stack on Dispose would close the
        # bidi stream and hang streaming setups with intermediate aggregation.
        # Python GC is what pushes Dispose requests from aggregators to workers.
        ex_factory.clean_up_executor.assert_not_called()
Exemplo n.º 7
0
 def test_create_selection(self):
     """Builds a struct from one tensor twice, then selects element 0 back out."""
     executor = executor_bindings.create_reference_resolving_executor(
         executor_bindings.create_tensorflow_executor())
     expected_type_spec = TensorType(shape=[3], dtype=tf.int64)
     value_pb, _ = value_serialization.serialize_value(
         tf.constant([1, 2, 3]), expected_type_spec)
     value = executor.create_value(value_pb)
     self.assertEqual(value.ref, 0)
     # 1. Create a struct from duplicated values.
     struct_value = executor.create_struct([value.ref, value.ref])
     self.assertEqual(struct_value.ref, 1)
     materialized_value = executor.materialize(struct_value.ref)
     deserialized_value, type_spec = value_serialization.deserialize_value(
         materialized_value)
     struct_type_spec = computation_types.to_type(
         [expected_type_spec, expected_type_spec])
     type_test_utils.assert_types_equivalent(type_spec, struct_type_spec)
     deserialized_value = type_conversions.type_to_py_container(
         deserialized_value, struct_type_spec)
     self.assertAllClose([(1, 2, 3), (1, 2, 3)], deserialized_value)
     # 2. Select the first value out of the struct.
     new_value = executor.create_selection(struct_value.ref, 0)
     materialized_value = executor.materialize(new_value.ref)
     deserialized_value, type_spec = value_serialization.deserialize_value(
         materialized_value)
     type_test_utils.assert_types_equivalent(type_spec, expected_type_spec)
     # The selection is a single tensor, so convert it with the tensor type,
     # not the two-element struct type used for step 1.
     deserialized_value = type_conversions.type_to_py_container(
         deserialized_value, expected_type_spec)
     self.assertAllClose((1, 2, 3), deserialized_value)
Exemplo n.º 8
0
    def test_raising_failed_precondition_destroys_executor(self):
        """After FAILED_PRECONDITION, the service drops the executor for good."""

        class GrpcFailedPrecondition(grpc.RpcError):
            """A simple class mocking the FAILED_PRECONDITION errors gRPC raises."""

            def code(self):
                return grpc.StatusCode.FAILED_PRECONDITION

        def _return_error():
            return GrpcFailedPrecondition('Raising failed precondition')

        env = TestEnv(
            executor_test_utils.BasicTestExFactory(
                executor_test_utils.RaisingExecutor(_return_error)))
        ten_proto, _ = value_serialization.serialize_value(10, tf.int32)

        def new_create_request():
            # A fresh request per call, mirroring two independent clients.
            return executor_pb2.CreateValueRequest(executor=env.executor_pb,
                                                   value=ten_proto)

        create_reply = env.stub.CreateValue(new_create_request())
        # The service works future-to-future, so the executor's error is raised
        # lazily, only when the created value is actually forced.
        with self.assertRaises(grpc.RpcError):
            env.get_value(create_reply.value_ref.id)

        # The executor has now been destroyed, so any further call (here another
        # CreateValue) is expected to fail eagerly.
        with self.assertRaises(grpc.RpcError) as rpc_err:
            env.stub.CreateValue(new_create_request())
        self.assertEqual(rpc_err.exception.code(),
                         grpc.StatusCode.FAILED_PRECONDITION)
        self.assertIn('No executor found',
                      rpc_err.exception.exception().details())
Exemplo n.º 9
0
    def test_create_tuple_of_value_sequence(self):
        """Runs a TF computation mapping over a pair of dataset arguments."""
        datasets = (tf.data.Dataset.range(5), tf.data.Dataset.range(5))
        executor = executor_bindings.create_tensorflow_executor()
        # Two unnamed sequence elements, both typed after the first dataset.
        struct_of_sequence_type = StructType([
            (None, SequenceType(datasets[0].element_spec)) for _ in range(2)
        ])
        arg_pb, _ = value_serialization.serialize_value(
            datasets, struct_of_sequence_type)
        arg = executor.create_value(arg_pb)

        @tensorflow_computation.tf_computation(struct_of_sequence_type)
        def preprocess(datasets):

            def double_value(x):
                return 2 * x

            @tf.function
            def add_preprocessing(ds1, ds2):
                return ds1.map(double_value), ds2.map(double_value)

            return add_preprocessing(*datasets)

        comp = executor.create_value(
            executor_pb2.Value(computation=preprocess.get_proto(preprocess)))
        # Keep the returned handle bound so the value stays alive until it has
        # been materialized.
        call_result = executor.create_call(comp.ref, arg.ref)
        _, result_type_spec = value_serialization.deserialize_value(
            executor.materialize(call_result.ref),
            type_hint=struct_of_sequence_type)
        type_test_utils.assert_types_identical(result_type_spec,
                                               struct_of_sequence_type)
 def test_serialize_sequence_bad_container_type(self):
   """Rejects a dataset of namedtuples whose field names are unsorted."""
   # A Python non-ODict container whose names (`y`, then `x`) are not in
   # alphabetical order.
   bad_container = collections.namedtuple('BadContainer', 'y x')
   dataset = tf.data.Dataset.range(5).map(
       lambda elem: bad_container(elem, elem * 2))
   with self.assertRaisesRegex(ValueError, r'non-OrderedDict container'):
     bad_sequence_type = computation_types.SequenceType(
         bad_container(tf.int64, tf.int64))
     _ = value_serialization.serialize_value(dataset, bad_sequence_type)
 def test_serialize_deserialize_tensor_value_with_different_dtype(self):
   """An int32 constant serialized with a float32 type comes back as 10.0."""
   float_type = TensorType(tf.float32)
   value_proto, value_type = value_serialization.serialize_value(
       tf.constant(10), float_type)
   type_test_utils.assert_types_identical(value_type, float_type)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   type_test_utils.assert_types_identical(roundtrip_type, float_type)
   self.assertEqual(roundtrip, 10.0)
 def test_serialize_deserialize_nested_tuple_value_without_names(self):
   """An unnamed int32 pair deserializes to an equivalent unnamed struct."""
   pair = (10, 20)
   pair_type = computation_types.to_type((tf.int32, tf.int32))
   value_proto, value_type = value_serialization.serialize_value(
       pair, pair_type)
   type_test_utils.assert_types_identical(value_type, pair_type)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   type_test_utils.assert_types_equivalent(roundtrip_type, pair_type)
   self.assertEqual(roundtrip, structure.from_container((10, 20)))
 def test_serialize_deserialize_variable_as_tensor_value(self):
   """A float `tf.Variable` round-trips as a plain float32 tensor value."""
   variable = tf.Variable(10.0)
   variable_type = TensorType(tf.as_dtype(variable.dtype), variable.shape)
   value_proto, value_type = value_serialization.serialize_value(
       variable, variable_type)
   type_test_utils.assert_types_identical(value_type, TensorType(tf.float32))
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   type_test_utils.assert_types_identical(roundtrip_type,
                                          TensorType(tf.float32))
   self.assertAllEqual(variable, roundtrip)
Exemplo n.º 14
0
    def test_executor_service_create_and_select_from_tuple(self):
        """Builds a named struct from two remote ints and selects each back."""
        env = TestEnv(
            executor_test_utils.BasicTestExFactory(
                eager_tf_executor.EagerTFExecutor()))

        def create_scalar(test_env, number):
            # Serializes `number` as int32, creates it remotely, returns its ref.
            proto, _ = value_serialization.serialize_value(number, tf.int32)
            reply = test_env.stub.CreateValue(
                executor_pb2.CreateValueRequest(executor=test_env.executor_pb,
                                                value=proto))
            self.assertIsInstance(reply, executor_pb2.CreateValueResponse)
            return reply.value_ref

        ten_ref = create_scalar(env, 10)
        self.assertEqual(env.get_value(ten_ref.id), 10)
        twenty_ref = create_scalar(env, 20)
        self.assertEqual(env.get_value(twenty_ref.id), 20)

        elements = [
            executor_pb2.CreateStructRequest.Element(name=name, value_ref=ref)
            for name, ref in (('a', ten_ref), ('b', twenty_ref))
        ]
        struct_reply = env.stub.CreateStruct(
            executor_pb2.CreateStructRequest(executor=env.executor_pb,
                                             element=elements))
        self.assertIsInstance(struct_reply, executor_pb2.CreateStructResponse)
        tuple_ref = struct_reply.value_ref
        self.assertEqual(str(env.get_value(tuple_ref.id)), '<a=10,b=20>')

        for index, expected in enumerate((10, 20)):
            selection_reply = env.stub.CreateSelection(
                executor_pb2.CreateSelectionRequest(executor=env.executor_pb,
                                                    source_ref=tuple_ref,
                                                    index=index))
            self.assertIsInstance(selection_reply,
                                  executor_pb2.CreateSelectionResponse)
            self.assertEqual(env.get_value(selection_reply.value_ref.id),
                             expected)

        del env
Exemplo n.º 15
0
    async def invoke(self, comp, arg):
        """Compiles and executes `comp` on `arg`, returning a Python-level result.

        Args:
          comp: The computation to run. It is compiled with
            `self._compiler_pipeline` and serialized against
            `comp.type_signature` before execution.
          arg: The argument to pass to `comp`, or `None` to call it without one.

        Returns:
          The materialized result, converted with
          `type_conversions.type_to_py_container` to the container implied by
          `comp.type_signature.result`.

        Raises:
          TypeError: If `arg` cannot be serialized as a value of
            `comp.type_signature.parameter`.
          absl_status.StatusNotOk: Re-raised after logging the computation's
            signature and argument when the executor reports an error.
        """
        compiled_comp = self._compiler_pipeline.compile(comp)
        serialized_comp, _ = value_serialization.serialize_value(
            compiled_comp, comp.type_signature)
        cardinalities = self._cardinality_inference_fn(
            arg, comp.type_signature.parameter)

        try:
            with self._reset_factory_on_error(self._executor_factory,
                                              cardinalities) as executor:
                fn = executor.create_value(serialized_comp)
                if arg is not None:
                    try:
                        serialized_arg, _ = value_serialization.serialize_value(
                            arg, comp.type_signature.parameter)
                    except Exception as e:
                        raise TypeError(
                            f'Failed to serialize argument:\n{arg}\nas a value of type:\n'
                            f'{comp.type_signature.parameter}') from e
                    arg_value = executor.create_value(serialized_arg)
                    call = executor.create_call(fn.ref, arg_value.ref)
                else:
                    call = executor.create_call(fn.ref, None)
                # Delaying grabbing the event loop til now ensures that the call below
                # is attached to the loop running the invoke.
                running_loop = asyncio.get_running_loop()
                result_pb = await running_loop.run_in_executor(
                    self._futures_executor_pool,
                    lambda: executor.materialize(call.ref))
        except absl_status.StatusNotOk as e:

            def indent(text):
                return textwrap.indent(text, prefix='\t')

            if arg is None:
                arg_str = 'without any arguments'
            else:
                arg_str = f'with argument:\n{indent(pprint.pformat(arg))}'
            logging.error(
                'Error invoking computation with signature:\n%s\n%s\n',
                indent(comp.type_signature.formatted_representation()),
                arg_str)
            logging.error('Error: \n%s', e.status.message())
            # A bare `raise` re-raises the active exception with its original
            # traceback unchanged.
            raise
        result_value, _ = value_serialization.deserialize_value(
            result_pb, comp.type_signature.result)
        return type_conversions.type_to_py_container(
            result_value, comp.type_signature.result)
 def test_serialize_deserialize_federated_at_server(self):
   """A server-placed int32 round-trips with its federated type intact."""
   server_int_type = computation_types.at_server(tf.int32)
   value_proto, value_type = value_serialization.serialize_value(
       10, server_int_type)
   type_test_utils.assert_types_identical(
       value_type, computation_types.at_server(tf.int32))
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   type_test_utils.assert_types_identical(roundtrip_type, server_int_type)
   self.assertEqual(roundtrip, 10)
 def test_serialize_deserialize_tensor_value_with_nontrivial_shape(self):
   """A length-3 int32 vector keeps its shape through a round trip."""
   tensor = tf.constant([10, 20, 30])
   vector_type = TensorType(tf.int32, [3])
   value_proto, value_type = value_serialization.serialize_value(
       tensor, vector_type)
   type_test_utils.assert_types_identical(value_type, vector_type)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   type_test_utils.assert_types_identical(roundtrip_type, vector_type)
   self.assertAllEqual(tensor, roundtrip)
 def test_serialize_deserialize_tensor_value_without_hint(
     self, x, serialize_type_spec):
   """Round-trips `x` with no type hint; type, dtype and contents must match."""
   value_proto, value_type = value_serialization.serialize_value(
       x, serialize_type_spec)
   type_test_utils.assert_types_identical(value_type, serialize_type_spec)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   type_test_utils.assert_types_identical(roundtrip_type, serialize_type_spec)
   expected_np_dtype = serialize_type_spec.dtype.as_numpy_dtype
   self.assertEqual(roundtrip.dtype, expected_np_dtype)
   self.assertAllEqual(x, roundtrip)
 def test_serialize_deserialize_tensor_value_unk_shape_without_hint(self):
   """A float32 vector typed with an unknown length still round-trips."""
   values = np.asarray([1., 2.])
   serialize_type_spec = TensorType(tf.float32, [None])
   value_proto, value_type = value_serialization.serialize_value(
       values, serialize_type_spec)
   # The declared shape is only partially known, so assignability rather
   # than identity is asserted on the types.
   type_test_utils.assert_type_assignable_from(value_type, serialize_type_spec)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   type_test_utils.assert_type_assignable_from(serialize_type_spec,
                                               roundtrip_type)
   self.assertEqual(roundtrip.dtype, serialize_type_spec.dtype.as_numpy_dtype)
   self.assertAllEqual(values, roundtrip)
 def test_serialize_deserialize_sequence_of_scalars_graph_mode(self):
   """A dataset built inside a graph serializes as a sequence of int64."""
   int64_sequence = computation_types.SequenceType(tf.int64)
   with tf.Graph().as_default():
     doubled = tf.data.Dataset.range(5).map(lambda x: x * 2)
     value_proto, value_type = value_serialization.serialize_value(
         doubled, int64_sequence)
   type_test_utils.assert_types_identical(value_type, int64_sequence)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   type_test_utils.assert_types_identical(roundtrip_type, int64_sequence)
   self.assertAllEqual(list(roundtrip), [2 * i for i in range(5)])
 def test_serialize_deserialize_sequence_of_scalars(self, dataset_fn):
   """Round-trips doubled int64 scalars, as represented by `dataset_fn`."""
   int64_sequence = computation_types.SequenceType(tf.int64)
   doubled = tf.data.Dataset.range(5).map(lambda x: x * 2)
   value_proto, value_type = value_serialization.serialize_value(
       dataset_fn(doubled), int64_sequence)
   type_test_utils.assert_types_identical(value_type, int64_sequence)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   type_test_utils.assert_types_identical(roundtrip_type, int64_sequence)
   self.assertAllEqual(list(roundtrip), [2 * i for i in range(5)])
 def test_serialize_deserialize_tensor_value_with_nontrivial_shape(self):
   """A length-3 int32 vector serializes to a `Value` and keeps its shape."""
   tensor = tf.constant([10, 20, 30])
   vector_type = computation_types.TensorType(tf.int32, [3])
   value_proto, value_type = value_serialization.serialize_value(
       tensor, vector_type)
   self.assertIsInstance(value_proto, serialization_bindings.Value)
   self.assert_types_identical(value_type, vector_type)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   self.assert_types_identical(roundtrip_type, vector_type)
   self.assertAllEqual(tensor, roundtrip)
 def test_serialize_deserialize_tensor_value(self):
   """A scalar float32 tensor serializes to a `Value` and round-trips."""
   scalar = tf.constant(10.0)
   scalar_type = computation_types.TensorType(
       tf.as_dtype(scalar.dtype), scalar.shape)
   value_proto, value_type = value_serialization.serialize_value(
       scalar, scalar_type)
   self.assertIsInstance(value_proto, serialization_bindings.Value)
   float_type = computation_types.TensorType(tf.float32)
   self.assert_types_identical(value_type, float_type)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   self.assert_types_identical(roundtrip_type, float_type)
   self.assertAllEqual(scalar, roundtrip)
 def test_serialize_deserialize_tensor_value_with_different_dtype(self):
   """A float constant serialized with an int32 type comes back as 10."""
   int_type = computation_types.TensorType(tf.int32)
   value_proto, value_type = value_serialization.serialize_value(
       tf.constant(10.0), int_type)
   self.assertIsInstance(value_proto, serialization_bindings.Value)
   self.assert_types_identical(value_type, int_type)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   self.assert_types_identical(roundtrip_type, int_type)
   self.assertEqual(roundtrip, 10)
 def test_serialize_deserialize_sequence_of_scalars(self):
   """A dataset of doubled int64 scalars serializes to a `Value` and back."""
   int64_sequence = computation_types.SequenceType(tf.int64)
   doubled = tf.data.Dataset.range(5).map(lambda x: x * 2)
   value_proto, value_type = value_serialization.serialize_value(
       doubled, int64_sequence)
   self.assertIsInstance(value_proto, serialization_bindings.Value)
   self.assert_types_identical(value_type, int64_sequence)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   self.assert_types_identical(roundtrip_type, int64_sequence)
   self.assertAllEqual(list(roundtrip), [2 * i for i in range(5)])
 def test_serialize_deserialize_federated_at_clients(self):
   """Two client-placed int32 members round-trip as a list."""
   clients_int_type = computation_types.at_clients(tf.int32)
   value_proto, value_type = value_serialization.serialize_value(
       [10, 20], clients_int_type)
   self.assertIsInstance(value_proto, serialization_bindings.Value)
   self.assert_types_identical(value_type, clients_int_type)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto)
   self.assert_types_identical(roundtrip_type, clients_int_type)
   self.assertEqual(roundtrip, [10, 20])
 def test_serialize_deserialize_string_value(self, x):
   """Round-trips string value `x` with its own type as the hint.

   The deserialized value is expected to be `bytes`.
   """
   string_type = TensorType(tf.as_dtype(x.dtype), x.shape)
   value_proto, value_type = value_serialization.serialize_value(
       x, string_type)
   type_test_utils.assert_types_identical(value_type, string_type)
   roundtrip, roundtrip_type = value_serialization.deserialize_value(
       value_proto, type_hint=string_type)
   type_test_utils.assert_types_identical(roundtrip_type, string_type)
   self.assertIsInstance(roundtrip, bytes)
   self.assertAllEqual(x, roundtrip)
 def test_serialize_deserialize_sequence_of_ragged_tensors(self, dataset_fn):
   """Round-trips a dataset of ragged string tensors (currently skipped)."""
   self.skipTest('b/235492749')
   ragged_ds = tf.data.Dataset.from_tensor_slices(
       tf.strings.split(['a b c', 'd e']))
   ragged_repr = dataset_fn(ragged_ds)
   value_proto, value_type = value_serialization.serialize_value(
       ragged_repr,
       computation_types.SequenceType(element=ragged_ds.element_spec))
   expected_type = computation_types.SequenceType(ragged_ds.element_spec)
   type_test_utils.assert_types_identical(value_type, expected_type)
   _, roundtrip_type = value_serialization.deserialize_value(value_proto)
   # The Python container is not available after deserialization, so only
   # equivalence (not identity) is checked.
   type_test_utils.assert_types_equivalent(roundtrip_type, expected_type)
Exemplo n.º 29
0
    def test_executor_service_slowly_create_tensor_value(self):
        """CreateValue responds while the executor is still creating the value.

        The executor below blocks inside `create_value` until the test sets
        `done`. The test checks that the executor is observed as 'busy' after
        the CreateValue RPC has returned, and as 'done' once the value has been
        fetched, which yields the original 10.
        """

        # Minimal ExecutorValue that returns its stored value and type unchanged.
        class SlowExecutorValue(executor_value_base.ExecutorValue):
            def __init__(self, v, t):
                self._v = v
                self._t = t

            @property
            def type_signature(self):
                return self._t

            async def compute(self):
                return self._v

        # Executor whose `create_value` reports progress through `status` and
        # blocks on `done`; every other operation is unsupported.
        class SlowExecutor(executor_base.Executor):
            def __init__(self):
                self.status = 'idle'
                # Set once `create_value` has started.
                self.busy = threading.Event()
                # Set by the test to let `create_value` finish.
                self.done = threading.Event()

            async def create_value(self, value, type_spec=None):
                self.status = 'busy'
                self.busy.set()
                # threading.Event.wait blocks the calling thread itself (not just
                # this coroutine) until the test sets `done`.
                self.done.wait()
                self.status = 'done'
                return SlowExecutorValue(value, type_spec)

            async def create_call(self, comp, arg=None):
                raise NotImplementedError

            async def create_struct(self, elements):
                raise NotImplementedError

            async def create_selection(self, source, index):
                raise NotImplementedError

            def close(self):
                pass

        ex = SlowExecutor()
        ex_factory = executor_test_utils.BasicTestExFactory(ex)
        env = TestEnv(ex_factory)
        self.assertEqual(ex.status, 'idle')
        value_proto, _ = value_serialization.serialize_value(10, tf.int32)
        response = env.stub.CreateValue(
            executor_pb2.CreateValueRequest(executor=env.executor_pb,
                                            value=value_proto))
        # Wait until `create_value` has started; it is still blocked on `done`,
        # so the RPC above returned before the value was fully created.
        ex.busy.wait()
        self.assertEqual(ex.status, 'busy')
        # Release the executor, then fetch the value it produced.
        ex.done.set()
        value = env.get_value(response.value_ref.id)
        self.assertEqual(ex.status, 'done')
        self.assertEqual(value, 10)
 def test_serialize_deserialize_tensor_value_without_hint_graph_mode(
     self, x, serialize_type_spec):
   """Round-trips a graph-mode (non-eager) tensor built from `x`, unhinted."""
   with tf.Graph().as_default():
     graph_tensor = tf.convert_to_tensor(x)
     value_proto, value_type = value_serialization.serialize_value(
         graph_tensor, serialize_type_spec)
   type_test_utils.assert_types_identical(value_type, serialize_type_spec)
   roundtrip, deserialize_type_spec = value_serialization.deserialize_value(
       value_proto)
   type_test_utils.assert_types_identical(deserialize_type_spec,
                                          serialize_type_spec)
   expected_np_dtype = serialize_type_spec.dtype.as_numpy_dtype
   self.assertEqual(roundtrip.dtype, expected_np_dtype)
   self.assertAllEqual(x, roundtrip)