Exemplo n.º 1
0
 def test_create_selection(self):
     """Selecting an element of a struct yields that element's value.

     Builds a two-element struct out of the same tensor value and checks that
     the struct materializes as expected. It then selects index 0 and checks
     that the selection materializes with the element's type and value.
     """
     executor = executor_bindings.create_reference_resolving_executor(
         executor_bindings.create_tensorflow_executor())
     expected_type_spec = TensorType(shape=[3], dtype=tf.int64)
     value_pb, _ = value_serialization.serialize_value(
         tf.constant([1, 2, 3]), expected_type_spec)
     value = executor.create_value(value_pb)
     self.assertEqual(value.ref, 0)
     # 1. Create a struct from duplicated values.
     struct_value = executor.create_struct([value.ref, value.ref])
     self.assertEqual(struct_value.ref, 1)
     materialized_value = executor.materialize(struct_value.ref)
     deserialized_value, type_spec = value_serialization.deserialize_value(
         materialized_value)
     struct_type_spec = computation_types.to_type(
         [expected_type_spec, expected_type_spec])
     type_test_utils.assert_types_equivalent(type_spec, struct_type_spec)
     deserialized_value = type_conversions.type_to_py_container(
         deserialized_value, struct_type_spec)
     self.assertAllClose([(1, 2, 3), (1, 2, 3)], deserialized_value)
     # 2. Select the first value out of the struct.
     new_value = executor.create_selection(struct_value.ref, 0)
     materialized_value = executor.materialize(new_value.ref)
     deserialized_value, type_spec = value_serialization.deserialize_value(
         materialized_value)
     type_test_utils.assert_types_equivalent(type_spec, expected_type_spec)
     # The selection is a single tensor, so convert it with the element type,
     # not with the two-element struct type it was selected from.
     deserialized_value = type_conversions.type_to_py_container(
         deserialized_value, expected_type_spec)
     self.assertAllClose((1, 2, 3), deserialized_value)
Exemplo n.º 2
0
 def _executor_fn(
     cardinalities: executor_factory.CardinalitiesType
 ) -> executor_bindings.Executor:
     """Builds the executor stack for the given placement cardinalities.

     If no CLIENTS cardinality is present (or it is None), the passed-in
     `cardinalities` dict is updated in place with `default_num_clients`.

     When `max_concurrent_computation_calls` is positive but smaller than the
     number of clients, the sequential-execution warning helper is called.

     Returns:
       A reference-resolving executor wrapping a federating executor. The
       federating executor's child is a reference-resolving executor over a
       TensorFlow executor.
     """
     if cardinalities.get(placements.CLIENTS) is None:
         cardinalities[placements.CLIENTS] = default_num_clients
     client_count = cardinalities[placements.CLIENTS]
     max_calls = max_concurrent_computation_calls
     # A non-positive limit disables the concurrency check entirely.
     runs_sequentially = max_calls > 0 and client_count > max_calls
     if runs_sequentially:
         _log_and_warn_on_sequential_execution(
             max_calls, client_count, math.ceil(client_count / max_calls))
     leaf_executor = executor_bindings.create_reference_resolving_executor(
         executor_bindings.create_tensorflow_executor(max_calls))
     federating_executor = executor_bindings.create_federating_executor(
         leaf_executor, cardinalities)
     return executor_bindings.create_reference_resolving_executor(
         federating_executor)
Exemplo n.º 3
0
    def test_create_value(self):
        """Creating values hands out sequential `OwnedValueId`s.

        Covers a plain tensor, a named struct of tensors and a TF computation.
        It checks the id each one receives and round-trips the first two
        through `materialize`.
        """

        def assert_owned_id(owned, ref, as_str, as_repr):
            # Checks the id's type, numeric ref and both string renderings.
            self.assertIsInstance(owned, executor_bindings.OwnedValueId)
            self.assertEqual(owned.ref, ref)
            self.assertEqual(str(owned), as_str)
            self.assertEqual(repr(owned), as_repr)

        ex = executor_bindings.create_reference_resolving_executor(
            executor_bindings.create_tensorflow_executor())

        # 1. A single int64 tensor receives id 0 and round-trips unchanged.
        tensor_type = TensorType(shape=[3], dtype=tf.int64)
        tensor_pb, _ = value_serialization.serialize_value([1, 2, 3],
                                                           tensor_type)
        tensor_id = ex.create_value(tensor_pb)
        assert_owned_id(tensor_id, 0, '0', r'<OwnedValueId: 0>')
        materialized_tensor = ex.materialize(tensor_id.ref)
        round_tripped, round_tripped_type = (
            value_serialization.deserialize_value(materialized_tensor))
        type_test_utils.assert_types_identical(round_tripped_type, tensor_type)
        self.assertAllEqual(round_tripped, [1, 2, 3])

        # 2. A named struct of tensors receives the next id (1).
        struct_type = StructType([
            ('a', TensorType(shape=[3], dtype=tf.int64)),
            ('b', TensorType(shape=[], dtype=tf.float32)),
        ])
        struct_pb, _ = value_serialization.serialize_value(
            collections.OrderedDict(a=tf.constant([1, 2, 3]),
                                    b=tf.constant(42.0)), struct_type)
        struct_id = ex.create_value(struct_pb)
        assert_owned_id(struct_id, 1, '1', r'<OwnedValueId: 1>')
        materialized_struct = ex.materialize(struct_id.ref)
        round_tripped, round_tripped_type = (
            value_serialization.deserialize_value(materialized_struct))
        # The names `a` and `b` are lost in the output, so the deserialized
        # type is only required to be assignable to the original.
        self.assertTrue(struct_type.is_assignable_from(round_tripped_type))
        as_container = type_conversions.type_to_py_container(
            round_tripped, struct_type)
        self.assertAllClose(as_container,
                            collections.OrderedDict(a=(1, 2, 3), b=42.0))

        # 3. A serialized TF computation receives id 2.
        @tensorflow_computation.tf_computation(tf.int32, tf.int32)
        def foo(a, b):
            return tf.add(a, b)

        comp_pb, _ = value_serialization.serialize_value(foo)
        assert_owned_id(ex.create_value(comp_pb), 2, '2', '<OwnedValueId: 2>')
  def test_call_no_arg(self):
    """Calling a no-argument TF computation yields its constant result."""
    ex = executor_bindings.create_reference_resolving_executor(
        executor_bindings.create_tensorflow_executor())

    @computations.tf_computation
    def foo():
      return tf.constant(123.0)

    fn_id = ex.create_value(
        serialization_bindings.Value(computation=foo.get_proto(foo)))
    # `foo` takes no parameters, so no argument id is supplied.
    call_id = ex.create_call(fn_id.ref, None)
    output, _ = value_serialization.deserialize_value(
        ex.materialize(call_id.ref))
    self.assertEqual(output, 123.0)
  def test_call_with_arg(self):
    """Calling a two-argument TF computation on a struct argument adds them."""
    ex = executor_bindings.create_reference_resolving_executor(
        executor_bindings.create_tensorflow_executor())
    operand_pb, _ = value_serialization.serialize_value(
        tf.constant([1, 2, 3]), TensorType(shape=[3], dtype=tf.int64))
    operand = ex.create_value(operand_pb)
    # Pack the same value twice so it binds to both `a` and `b` below.
    pair = ex.create_struct((operand.ref, operand.ref))

    @computations.tf_computation(tf.int64, tf.int64)
    def foo(a, b):
      return tf.add(a, b)

    fn_pb = serialization_bindings.Value(computation=foo.get_proto(foo))
    fn_id = ex.create_value(fn_pb)
    call_id = ex.create_call(fn_id.ref, pair.ref)
    output, _ = value_serialization.deserialize_value(
        ex.materialize(call_id.ref))
    self.assertAllEqual(output, [2, 4, 6])
Exemplo n.º 6
0
 def test_create(self):
     """Building a reference-resolving executor over TF must not raise."""
     try:
         inner = executor_bindings.create_tensorflow_executor()
         executor_bindings.create_reference_resolving_executor(inner)
     except Exception as err:  # pylint: disable=broad-except
         # Report any failure as a test failure rather than a test error.
         self.fail(f'Exception: {err}')