def test_federated_min_on_nested_scalars(self):
    """`federated_min` reduces each named scalar member independently."""
    member_type = computation_types.StructType([
        ('x', tf.float32),
        ('y', tf.float32),
    ])
    clients_type = computation_types.FederatedType(member_type,
                                                   placements.CLIENTS)

    @computations.federated_computation(clients_type)
    def call_federated_min(value):
      return federated_aggregations.federated_min(value)

    nested_scalars = collections.namedtuple('NestedScalars', ['x', 'y'])
    client_values = [
        nested_scalars(0.0, 1.0),
        nested_scalars(-1.0, 5.0),
        nested_scalars(2.0, -10.0),
    ]
    result = call_federated_min(client_values)
    # Expected: min(0.0, -1.0, 2.0) for 'x' and min(1.0, 5.0, -10.0) for 'y'.
    self.assertDictEqual(result._asdict(), {'x': -1.0, 'y': -10.0})
    def test_returns_value_with_source_and_index_structure(self):
        """Selecting index 0 of a 3-element struct yields its first member."""
        executor = create_test_executor()
        raw_element, element_type = (
            executor_test_utils.create_dummy_value_unplaced())
        embedded = self.run_sync(
            executor.create_value(raw_element, element_type))

        struct_type = computation_types.StructType([element_type] * 3)
        source = self.run_sync(executor.create_struct([embedded] * 3))
        selected = self.run_sync(executor.create_selection(source, 0))

        self.assertIsInstance(selected, executor_value_base.ExecutorValue)
        self.assertEqual(selected.type_signature.compact_representation(),
                         struct_type[0].compact_representation())
        computed_selection = self.run_sync(selected.compute())
        computed_source = self.run_sync(source.compute())
        self.assertEqual(computed_selection, computed_source[0])
    def test_returns_value_with_elements_fn_and_arg(self, fn, fn_type, arg,
                                                    arg_type):
        """A struct of three identical call results has the expected type/value."""
        executor = create_test_executor()

        embedded_fn = self.run_sync(executor.create_value(fn, fn_type))
        embedded_arg = self.run_sync(executor.create_value(arg, arg_type))
        call_result = self.run_sync(
            executor.create_call(embedded_fn, embedded_arg))
        expected_type = computation_types.StructType([fn_type.result] * 3)
        result = self.run_sync(executor.create_struct([call_result] * 3))

        self.assertIsInstance(result, executor_value_base.ExecutorValue)
        self.assertEqual(result.type_signature.compact_representation(),
                         expected_type.compact_representation())
        computed = self.run_sync(result.compute())
        single = self.run_sync(call_result.compute())
        self.assertCountEqual(computed, [single] * 3)
Beispiel #4
0
    def test_federated_max_on_nested_scalars(self):
        """`federated_max` reduces each named int member independently."""
        member_type = computation_types.StructType([
            ('a', tf.int32),
            ('b', tf.int32),
        ])
        clients_type = computation_types.FederatedType(member_type,
                                                       placements.CLIENTS)

        @computations.federated_computation(clients_type)
        def call_federated_max(value):
            return federated_aggregations.federated_max(value)

        nested_scalars = collections.namedtuple('NestedScalars', ['a', 'b'])
        client_values = [
            nested_scalars(1, 5),
            nested_scalars(2, 3),
            nested_scalars(1, 8),
        ]
        result = call_federated_max(client_values)
        # Expected: max(1, 2, 1) for 'a' and max(5, 3, 8) for 'b'.
        self.assertDictEqual(result._asdict(), {'a': 2, 'b': 8})
    def test_generic_add_with_unplaced_named_tuple_and_tensor(self):
        """GENERIC_PLUS adds a scalar to each member of a named pair."""
        intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
            context_stack_impl.context_stack)
        param_type = computation_types.StructType(
            [[('a', tf.float32), ('b', tf.float32)], tf.float32])

        @computations.federated_computation(param_type)
        def foo(x):
            generic_plus = intrinsic_impls[intrinsic_defs.GENERIC_PLUS.uri]
            return generic_plus(x)

        self.assertEqual(
            str(foo.type_signature),
            '(<<a=float32,b=float32>,float32> -> <a=float32,b=float32>)')
        result = foo([[1., 1.], 1.])
        self.assertEqual(
            result, anonymous_tuple.AnonymousTuple([('a', 2.), ('b', 2.)]))
Beispiel #6
0
 def test_create_selection_does_not_cache_error(self):
   """A failed `create_selection` is not cached, so a retry hits the target."""
   loop = asyncio.get_event_loop()
   mock_executor = mock.create_autospec(executor_base.Executor)
   mock_executor.create_value.side_effect = create_test_value
   mock_executor.create_selection.side_effect = raise_error
   cached_executor = caching_executor.CachingExecutor(mock_executor)
   value = loop.run_until_complete(
       cached_executor.create_value((1, 2),
                                    computation_types.StructType(
                                        (tf.int32, tf.int32))))
   with self.assertRaises(TestError):
     _ = loop.run_until_complete(cached_executor.create_selection(value, 1))
   with self.assertRaises(TestError):
     _ = loop.run_until_complete(cached_executor.create_selection(value, 1))
   # The error must not be cached: each of the two sequential calls has to
   # reach the wrapped executor's create_selection. (`assert_has_calls([])`
   # is vacuously true and would not detect an erroneously cached failure.)
   self.assertEqual(mock_executor.create_selection.call_count, 2)
Beispiel #7
0
 def test_call_tf_comp_with_int_tuple(self):
     """Calls a two-arg TF computation with a named int32 pair argument."""
     seq_ex = self._sequence_executor
     add_comp = computations.tf_computation(lambda x, y: x + y, tf.int32,
                                            tf.int32)
     add_proto = computation_impl.ComputationImpl.get_proto(add_comp)
     embedded_comp = _run_sync(
         seq_ex.create_value(add_proto, add_comp.type_signature))

     arg_value = collections.OrderedDict([('a', 10), ('b', 20)])
     arg_type = computation_types.StructType(
         collections.OrderedDict([('a', tf.int32), ('b', tf.int32)]))
     embedded_arg = _run_sync(seq_ex.create_value(arg_value, arg_type))

     result = _run_sync(seq_ex.create_call(embedded_comp, embedded_arg))
     self.assertIsInstance(result, sequence_executor.SequenceExecutorValue)
     self.assertEqual(str(result.type_signature), 'int32')
     self.assertEqual(_run_sync(result.compute()), 30)
 async def compute_federated_zip_at_server(
     self,
     arg: FederatedComposingStrategyValue) -> FederatedComposingStrategyValue:
   """Zips a pair of SERVER-placed values into one SERVER-placed struct.

   Args:
     arg: A value whose type is a 2-element struct of all-equal federated
       types placed at SERVER, and whose internal representation is a
       2-element `structure.Struct`.

   Returns:
     A `FederatedComposingStrategyValue` wrapping the struct created on the
     server executor, typed as `<T0,T1>@SERVER` where `T0`/`T1` are the
     member types of the two inputs.
   """
   arg_type = arg.type_signature
   arg_repr = arg.internal_representation
   py_typecheck.check_type(arg_type, computation_types.StructType)
   py_typecheck.check_len(arg_type, 2)
   py_typecheck.check_type(arg_repr, structure.Struct)
   py_typecheck.check_len(arg_repr, 2)
   first_type, second_type = arg_type[0], arg_type[1]
   for member_type in (first_type, second_type):
     type_analysis.check_federated_type(
         member_type, placement=placement_literals.SERVER, all_equal=True)

   zipped = await self._server_executor.create_struct(
       [arg_repr[0], arg_repr[1]])
   zipped_type = computation_types.at_server(
       computation_types.StructType([first_type.member, second_type.member]))
   return FederatedComposingStrategyValue(zipped, zipped_type)
Beispiel #9
0
    def test_n_tuple_federated_zip_tensor_args(self, n):
        """Zipping `n` CLIENTS int32 values yields `<int32,...>@CLIENTS`."""
        client_int = computation_types.FederatedType(tf.int32,
                                                     placements.CLIENTS)
        zipped_type = computation_types.FederatedType([tf.int32] * n,
                                                      placements.CLIENTS)
        expected_type = computation_types.FunctionType(
            computation_types.StructType([client_int] * n), zipped_type)

        @computations.federated_computation([client_int] * n)
        def foo(x):
            zipped = intrinsics.federated_zip(x)
            self.assertIsInstance(zipped, value_base.Value)
            return zipped

        self.assert_type(foo, expected_type.compact_representation())
 async def compute_federated_sum(
     self,
     arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
   """Computes `federated_sum` as a `federated_reduce` with zero and `tf.add`.

   Args:
     arg: A value whose type signature is a `computation_types.FederatedType`.

   Returns:
     The result of `compute_federated_reduce` applied to the triple
     `<arg, zero, plus>`, where `zero` is an embedded TF constant `0` and
     `plus` an embedded `tf.add`, both over the member type of `arg`.
   """
   py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
   member_type = arg.type_signature.member
   zero, plus = await asyncio.gather(
       executor_utils.embed_tf_constant(self._executor, member_type, 0),
       executor_utils.embed_tf_binary_operator(self._executor, member_type,
                                               tf.add))
   reduce_repr = structure.Struct([
       (None, arg.internal_representation),
       (None, zero.internal_representation),
       (None, plus.internal_representation),
   ])
   reduce_type = computation_types.StructType(
       (arg.type_signature, zero.type_signature, plus.type_signature))
   return await self.compute_federated_reduce(
       FederatedResolvingStrategyValue(reduce_repr, reduce_type))
Beispiel #11
0
  def test_with_selection_by_index(self):
    """Index selections get derived identifiers and are returned memoized."""
    ex, _ = _make_executor_and_tracer_for_test()
    run = asyncio.get_event_loop().run_until_complete

    pair = run(
        ex.create_value([10, 20],
                        computation_types.StructType([tf.int32, tf.int32])))
    self.assertEqual(str(pair.identifier), '1')
    first = run(ex.create_selection(pair, index=0))
    self.assertEqual(str(first.identifier), '1[0]')
    second = run(ex.create_selection(pair, index=1))
    self.assertEqual(str(second.identifier), '1[1]')

    # Repeating a selection must hand back the very same object.
    self.assertIs(run(ex.create_selection(pair, index=0)), first)
    second_again = run(ex.create_selection(pair, index=1))
    self.assertIs(second_again, second)
    self.assertEqual(run(second_again.compute()).numpy(), 20)
Beispiel #12
0
  def test_returns_value_with_source_and_name_structure(self):
    """Selecting by name from a named struct yields that member."""
    executor = create_test_executor()
    raw_value, value_type = executor_test_utils.create_dummy_value_unplaced()
    member_names = ['a', 'b', 'c']

    embedded = self.run_sync(executor.create_value(raw_value, value_type))
    named_members = structure.Struct((name, embedded) for name in member_names)
    struct_type = computation_types.StructType(
        (name, value_type) for name in member_names)
    source = self.run_sync(executor.create_struct(named_members))
    selected = self.run_sync(executor.create_selection(source, name='a'))

    self.assertIsInstance(selected, executor_value_base.ExecutorValue)
    self.assertEqual(selected.type_signature.compact_representation(),
                     struct_type['a'].compact_representation())
    selected_value = self.run_sync(selected.compute())
    source_value = self.run_sync(source.compute())
    self.assertEqual(selected_value, source_value['a'])
 def test_is_assignable_from(self):
     """Struct assignability with respect to names, dtypes and arity."""
     named_int_bool = computation_types.StructType([tf.int32, ('a', tf.bool)])
     same_as_named = computation_types.StructType([tf.int32, ('a', tf.bool)])
     renamed = computation_types.StructType([tf.int32, ('b', tf.bool)])
     retyped = computation_types.StructType([tf.int32, ('a', tf.string)])
     shorter = computation_types.StructType([tf.int32])
     unnamed = computation_types.StructType([tf.int32, tf.bool])

     self.assertTrue(named_int_bool.is_assignable_from(same_as_named))
     for incompatible in (renamed, retyped, shorter):
         self.assertFalse(named_int_bool.is_assignable_from(incompatible))
     # An unnamed struct is assignable to the named one, but not vice versa.
     self.assertTrue(named_int_bool.is_assignable_from(unnamed))
     self.assertFalse(unnamed.is_assignable_from(named_int_bool))
Beispiel #14
0
 def test_to_representation_for_type_with_noarg_to_2xint32_comp(self):
     """A no-arg XLA computation returning two int32s becomes a callable."""
     builder = xla_client.XlaBuilder('comp')
     # The XLA computation declares one empty-tuple parameter even though the
     # TFF function type below takes no argument.
     xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(tuple()))
     constants = [
         xla_client.ops.Constant(builder, np.int32(10)),
         xla_client.ops.Constant(builder, np.int32(20)),
     ]
     xla_client.ops.Tuple(builder, constants)
     xla_comp = builder.build()

     result_type = computation_types.StructType([('a', np.int32),
                                                 ('b', np.int32)])
     comp_type = computation_types.FunctionType(None, result_type)
     comp_pb = xla_serialization.create_xla_tff_computation(
         xla_comp, [0, 1], comp_type)
     rep = executor.to_representation_for_type(comp_pb, comp_type,
                                               self._backend)
     self.assertTrue(callable(rep))
     self.assertEqual(str(rep()), '<a=10,b=20>')
 async def create_struct(self, elements):
   """Creates a struct of `RemoteValue`s via the remote executor service.

   Args:
     elements: A container accepted by `structure.from_container` whose
       elements are all `RemoteValue`s.

   Returns:
     A `RemoteValue` for the newly created struct, typed as the struct of
     the elements' type signatures (named where the element has a name).
   """
   struct_elements = structure.iter_elements(
       structure.from_container(elements))
   request_elements = []
   member_types = []
   for name, value in struct_elements:
     py_typecheck.check_type(value, RemoteValue)
     # Falsy names (None or '') are sent and typed as unnamed elements.
     request_elements.append(
         executor_pb2.CreateStructRequest.Element(
             name=name or None, value_ref=value.value_ref))
     member_types.append(
         (name, value.type_signature) if name else value.type_signature)

   result_type = computation_types.StructType(member_types)
   request = executor_pb2.CreateStructRequest(element=request_elements)
   if self._bidi_stream is not None:
     wrapped = executor_pb2.ExecuteRequest(create_struct=request)
     response = (await self._bidi_stream.send_request(wrapped)).create_struct
   else:
     response = _request(self._stub.CreateStruct, request)
   py_typecheck.check_type(response, executor_pb2.CreateStructResponse)
   return RemoteValue(response.value_ref, result_type, self)
  def test_federated_max_nested_tensor_value(self):
    """`federated_max` takes the element-wise max of each nested tensor."""
    tuple_type = computation_types.StructType([
        ('a', (tf.int32, [2])),
        ('b', (tf.int32, [3])),
    ])

    @computations.federated_computation(
        computation_types.FederatedType(tuple_type, placements.CLIENTS))
    def call_federated_max(value):
      return federated_aggregations.federated_max(value)

    test_type = collections.namedtuple('NestedScalars', ['a', 'b'])
    client1 = test_type(
        np.array([4, 5], dtype=np.int32), np.array([1, -2, 3], dtype=np.int32))
    client2 = test_type(
        np.array([9, 0], dtype=np.int32), np.array([5, 1, -2], dtype=np.int32))
    value = call_federated_max([client1, client2])
    # The max is element-wise, so each entry's position matters;
    # `assertCountEqual` would also accept any permutation (e.g. [5, 9]).
    np.testing.assert_array_equal(value[0], [9, 5])
    np.testing.assert_array_equal(value[1], [5, 1, 3])
Beispiel #17
0
    def test_ordered_dict(self):
        """Embedding an OrderedDict records one broadcast entry per tensor."""
        string_type = computation_types.TensorType(tf.string, [4])
        int_type = computation_types.TensorType(tf.int64, [2, 3])
        struct_type = computation_types.StructType([('a', string_type),
                                                    ('b', int_type)])
        ex = sizing_executor.SizingExecutor(
            eager_tf_executor.EagerTFExecutor())
        value = collections.OrderedDict([
            ('a', ['some', 'arbitrary', 'string', 'here']),
            ('b', [[3, 4, 1], [6, 8, -5]]),
        ])
        total_string_length = sum(len(s) for s in value['a'])

        async def _create_and_compute():
            embedded = await ex.create_value(value, struct_type)
            return await embedded.compute()

        asyncio.get_event_loop().run_until_complete(_create_and_compute())
        # Expected: total characters for the string tensor, and 2 * 3 = 6
        # elements for the int64 tensor.
        self.assertCountEqual(
            ex.broadcast_history,
            [[total_string_length, tf.string], [6, tf.int64]])
Beispiel #18
0
class IsAverageCompatibleTest(parameterized.TestCase):
    """Tests for `type_analysis.is_average_compatible`."""

    @parameterized.named_parameters([
        ('tensor_type_float32', computation_types.TensorType(tf.float32)),
        ('tensor_type_float64', computation_types.TensorType(tf.float64)),
        ('tuple_type',
         computation_types.StructType([('x', tf.float32), ('y', tf.float64)])),
        ('federated_type',
         computation_types.FederatedType(tf.float32, placements.CLIENTS)),
    ])
    def test_returns_true(self, type_spec):
        """Float tensors, and structs/federated values of floats, qualify."""
        compatible = type_analysis.is_average_compatible(type_spec)
        self.assertTrue(compatible)

    @parameterized.named_parameters([
        ('tensor_type_int32', computation_types.TensorType(tf.int32)),
        ('tensor_type_int64', computation_types.TensorType(tf.int64)),
        ('sequence_type', computation_types.SequenceType(tf.float32)),
    ])
    def test_returns_false(self, type_spec):
        """Integer tensors and sequences do not qualify."""
        compatible = type_analysis.is_average_compatible(type_spec)
        self.assertFalse(compatible)
  def test_serialize_jax_with_two_args(self):
    """Serializes `x + y` over a named int32 pair into an XLA computation."""
    param_type = computation_types.StructType([('a', np.int32),
                                               ('b', np.int32)])
    # Maps the struct argument onto `traced_func`'s keyword arguments.
    arg_func = lambda arg: ([], {'x': arg[0], 'y': arg[1]})

    def traced_func(x, y):
      return x + y

    comp_pb = jax_serialization.serialize_jax_computation(
        traced_func, arg_func, param_type, context_stack_impl.context_stack)
    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
    deserialized_type = type_serialization.deserialize_type(comp_pb.type)
    self.assertEqual(str(deserialized_type), '(<a=int32,b=int32> -> int32)')

    xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
    # pylint: disable=line-too-long
    expected_hlo_text = (
        'HloModule xla_computation_traced_func__3.7\n\n'
        'ENTRY xla_computation_traced_func__3.7 {\n'
        '  constant.4 = pred[] constant(false)\n'
        '  parameter.1 = (s32[], s32[]) parameter(0)\n'
        '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
        '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
        '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
        '  ROOT tuple.6 = (s32[]) tuple(add.5)\n'
        '}\n\n')
    self.assertEqual(xla_comp.as_hlo_text(), expected_hlo_text)

    expected_parameter = ('struct {\n'
                          '  element {\n'
                          '    tensor {\n'
                          '      index: 0\n'
                          '    }\n'
                          '  }\n'
                          '  element {\n'
                          '    tensor {\n'
                          '      index: 1\n'
                          '    }\n'
                          '  }\n'
                          '}\n')
    self.assertEqual(str(comp_pb.xla.parameter), expected_parameter)
    expected_result = 'tensor {\n' '  index: 0\n' '}\n'
    self.assertEqual(str(comp_pb.xla.result), expected_result)
Beispiel #20
0
 async def create_struct(self, elements):
     """Creates, or reuses from the cache, a struct of `CachedValue`s.

     The cache key is built from the element names and element identifiers
     (for example `<a=1,2>`), not from the element values. On a cache miss,
     the struct is created on the target executor and a new `CachedValue`
     is stored under that key. On a hit, the existing `CachedValue` is
     reused.

     Args:
       elements: A `structure.Struct`, or a container accepted by
         `structure.from_container`, whose elements are all `CachedValue`s
         (any names must be `str`).

     Returns:
       The `CachedValue` representing the struct.

     Raises:
       Any exception raised while awaiting the target executor's struct. In
       that case the whole cache is reset to an empty dict before the
       exception is re-raised. `py_typecheck.check_type` failures are raised
       for non-`CachedValue` elements or non-`str` names.
     """
     if not isinstance(elements, structure.Struct):
         elements = structure.from_container(elements)
     # Per-element pieces of the cache identifier, e.g. `a=1` or `2`.
     element_strings = []
     element_kv_pairs = structure.to_elements(elements)
     to_gather = []
     type_elements = []
     for k, v in element_kv_pairs:
         py_typecheck.check_type(v, CachedValue)
         to_gather.append(v.target_future)
         if k is not None:
             py_typecheck.check_type(k, str)
             element_strings.append('{}={}'.format(k, v.identifier))
             type_elements.append((k, v.type_signature))
         else:
             element_strings.append(str(v.identifier))
             type_elements.append(v.type_signature)
     type_spec = computation_types.StructType(type_elements)
     # Resolve each element's target-executor value; these become the
     # members of the struct created on the target executor below.
     gathered = await asyncio.gather(*to_gather)
     identifier = CachedValueIdentifier('<{}>'.format(
         ','.join(element_strings)))
     try:
         cached_value = self._cache[identifier]
     except KeyError:
         # Cache miss: schedule creation on the target executor and store the
         # still-pending future right away, so a later call with the same
         # identifier finds this entry instead of creating another struct.
         target_future = asyncio.ensure_future(
             self._target_executor.create_struct(
                 structure.Struct(
                     (k, v)
                     for (k, _), v in zip(element_kv_pairs, gathered))))
         cached_value = CachedValue(identifier, None, type_spec,
                                    target_future)
         self._cache[identifier] = cached_value
     try:
         target_value = await cached_value.target_future
     except Exception:
         # TODO(b/145514490): This is a bit heavy handed, there maybe caches where
         # only the current cache item needs to be invalidated; however this
         # currently only occurs when an inner RemoteExecutor has the backend go
         # down.
         self._cache = {}
         raise
     # Check that the type derived from the elements accepts the target
     # executor's reported struct type.
     type_spec.check_assignable_from(target_value.type_signature)
     return cached_value
 def test_serialize_type_with_tensor_tuple(self):
     """Mixed named/unnamed scalar members serialize as ordered elements."""
     type_signature = computation_types.StructType([
         ('x', tf.int32),
         ('y', tf.string),
         tf.float32,
         ('z', tf.bool),
     ])
     actual_proto = type_serialization.serialize_type(type_signature)

     def _element(dtype, name=None):
         # Unnamed members are built without passing a `name` field at all.
         fields = {'value': _create_scalar_tensor_type(dtype)}
         if name is not None:
             fields['name'] = name
         return pb.StructType.Element(**fields)

     expected_proto = pb.Type(struct=pb.StructType(element=[
         _element(tf.int32, 'x'),
         _element(tf.string, 'y'),
         _element(tf.float32),
         _element(tf.bool, 'z'),
     ]))
     self.assertEqual(actual_proto, expected_proto)
Beispiel #22
0
def _create_xla_binary_op_computation(type_spec, xla_binary_op_constructor):
    """Builds an XLA computation of type `(<T,T> -> T)` for a binary op.

    `T` is `type_spec`. The op is applied pairwise to corresponding tensors
    of the two operands.

    Args:
      type_spec: The `computation_types.Type` of a single operand. It must be
        a tensor or a structure of tensors.
      xla_binary_op_constructor: A callable that takes two XLA ops (such as
        `xla_client.ops.Add`) and returns the XLA op combining them.

    Returns:
      A `(comp_pb, comp_type)` pair, i.e. an instance of
      `local_computation_factory_base.ComputationProtoAndType`.

    Raises:
      ValueError: If `type_spec` is not a tensor or a structure of tensors.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if not type_analysis.is_structure_of_tensors(type_spec):
        raise ValueError('Not a tensor or a structure of tensors: {}'.format(
            str(type_spec)))

    shapes = _xla_tensor_shape_list_from_from_tff_tensor_or_struct_type(
        type_spec)
    count = len(shapes)
    builder = xla_client.XlaBuilder('comp')
    # The single XLA parameter is a flat tuple holding the tensors of the
    # first operand followed by the tensors of the second operand.
    param = xla_client.ops.Parameter(
        builder, 0, xla_client.Shape.tuple_shape(shapes * 2))
    outputs = [
        xla_binary_op_constructor(
            xla_client.ops.GetTupleElement(param, i),
            xla_client.ops.GetTupleElement(param, count + i))
        for i in range(count)
    ]
    xla_client.ops.Tuple(builder, outputs)
    xla_computation = builder.build()

    operand_pair = computation_types.StructType([(None, type_spec)] * 2)
    comp_type = computation_types.FunctionType(operand_pair, type_spec)
    comp_pb = xla_serialization.create_xla_tff_computation(
        xla_computation, list(range(2 * count)), comp_type)
    return comp_pb, comp_type
Beispiel #23
0
 def test_with_two_level_tuple(self):
     """Nested OrderedDict structs map to nested dicts of `tf.TensorSpec`s."""
     inner_type = computation_types.StructWithPythonType([
         ('c', computation_types.TensorType(tf.float32)),
         ('d', computation_types.TensorType(tf.int32, [20])),
     ], collections.OrderedDict)
     type_signature = computation_types.StructWithPythonType([
         ('a', tf.bool),
         ('b', inner_type),
         ('e', computation_types.StructType([])),
     ], collections.OrderedDict)
     expected_specs = {
         'a': tf.TensorSpec([], tf.bool),
         'b': {
             'c': tf.TensorSpec([], tf.float32),
             'd': tf.TensorSpec([20], tf.int32),
         },
         'e': (),
     }
     tensor_specs = type_conversions.type_to_tf_tensor_specs(type_signature)
     test.assert_nested_struct_eq(tensor_specs, expected_specs)
Beispiel #24
0
    async def create_struct(self, elements):
        """Creates a struct out of `elements`.

        Args:
          elements: As documented in `executor_base.Executor`; every element
            must be an `EagerValue`.

        Returns:
          An `EagerValue` wrapping a `structure.Struct` of the elements'
          internal representations, typed as the matching `StructType`
          (members keep their names where they have one).
        """
        named_values = structure.to_elements(structure.from_container(elements))
        representation = []
        member_types = []
        for name, value in named_values:
            py_typecheck.check_type(value, EagerValue)
            representation.append((name, value.internal_representation))
            if name is None:
                member_types.append(value.type_signature)
            else:
                member_types.append((name, value.type_signature))
        return EagerValue(
            structure.Struct(representation), self._tf_function_cache,
            computation_types.StructType(member_types))
Beispiel #25
0
  def test_serialize_type_with_nested_tuple(self):
    """Three levels of single-member named structs serialize recursively."""
    type_signature = computation_types.StructType([
        ('x', [('y', [('z', tf.bool)])]),
    ])
    actual_proto = type_serialization.serialize_type(type_signature)

    def _struct_of(name, value_proto):
      # Wraps `value_proto` as the only, named element of a struct type.
      return pb.Type(struct=pb.StructType(element=[
          pb.StructType.Element(name=name, value=value_proto)
      ]))

    bool_proto = _create_scalar_tensor_type(tf.bool)
    expected_proto = _struct_of('x', _struct_of('y',
                                                _struct_of('z', bool_proto)))
    self.assertEqual(actual_proto, expected_proto)
 def test_str(self):
   """`str()` of a struct lists members in `<...>`, prefixing names."""
   nested = computation_types.StructType([('x', tf.string), ('y', tf.bool)])
   cases = [
       (computation_types.StructType([tf.int32]), '<int32>'),
       (computation_types.StructType([('a', tf.int32)]), '<a=int32>'),
       (computation_types.StructType(('a', tf.int32)), '<a=int32>'),
       (computation_types.StructType([tf.int32, tf.bool]), '<int32,bool>'),
       (computation_types.StructType([('a', tf.int32), tf.float32]),
        '<a=int32,float32>'),
       (computation_types.StructType([('a', tf.int32), ('b', tf.float32)]),
        '<a=int32,b=float32>'),
       (computation_types.StructType([('a', tf.int32), ('b', nested)]),
        '<a=int32,b=<x=string,y=bool>>'),
   ]
   for struct_type, expected in cases:
     self.assertEqual(str(struct_type), expected)
    def test_returns_computation(self, operator, type_signature, operands,
                                 expected_result):
        """Checks the built operator's parameter type and computed result."""
        # TODO(b/142795960): parameterized arguments are evaluated before test
        # main, and `tf.constant` errors out on GPU/TPU without proper
        # initialization. The suggested workaround is to pass numpy values as
        # arguments and convert them to TF tensors inside the test.
        tensor_operands = tf.nest.map_structure(tf.constant, operands)
        proto, _ = tensorflow_computation_factory.create_binary_operator_with_upcast(
            type_signature, operator)

        self.assertIsInstance(proto, pb.Computation)
        deserialized_type = type_serialization.deserialize_type(proto.type)
        self.assertIsInstance(deserialized_type, computation_types.FunctionType)
        # Only the parameter type is checked: the result type depends on the
        # `operator` used, not on `create_binary_operator_with_upcast`.
        self.assertEqual(deserialized_type.parameter,
                         computation_types.StructType(type_signature))
        self.assertEqual(test_utils.run_tensorflow(proto, tensor_operands),
                         expected_result)
    def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type_with_names(
            self):
        """A lambda feeding selected arg members to TF becomes one TF block.

        The lambda selects members 0 (int32) and 2 (bool) of a named
        3-element argument and passes them to a compiled identity. Parsing
        must replace the whole lambda with a `CompiledComputation` of the
        same type that computes the same result.
        """
        arg_type = [('a', tf.int32), ('b', tf.float32), ('c', tf.bool)]
        identity_tf_block_type = computation_types.StructType(
            [tf.int32, tf.bool])
        identity_tf_block = building_block_factory.create_compiled_identity(
            identity_tf_block_type)
        tuple_ref = building_blocks.Reference('x', arg_type)
        selected_int = building_blocks.Selection(tuple_ref, index=0)
        selected_bool = building_blocks.Selection(tuple_ref, index=2)
        created_tuple = building_blocks.Struct([selected_int, selected_bool])
        called_tf_block = building_blocks.Call(identity_tf_block,
                                               created_tuple)
        lambda_wrapper = building_blocks.Lambda('x', arg_type, called_tf_block)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)

        # The original and parsed forms must agree on a concrete argument.
        # (Previously these computations were built twice; once suffices.)
        exec_lambda = computation_wrapper_instances.building_block_to_computation(
            lambda_wrapper)
        exec_tf = computation_wrapper_instances.building_block_to_computation(
            parsed)
        arg = {'a': 9, 'b': 10., 'c': False}
        self.assertEqual(exec_lambda(arg), exec_tf(arg))
Beispiel #29
0
    def test_fails_with_bad_types(self):
        """`tensorflow_wrapper` rejects non-TensorFlow parameter types.

        Each rejected type must produce a `TypeError` whose message names
        the offending type.
        """
        function = computation_types.FunctionType(
            None, computation_types.TensorType(tf.int32))
        federated = computation_types.FederatedType(tf.int32,
                                                    placement_literals.CLIENTS)
        tuple_on_function = computation_types.StructType([federated, function])

        def foo(x):  # pylint: disable=unused-variable
            del x  # Unused.

        with self.assertRaisesRegex(
                TypeError,
                r'you have attempted to create one with the type {int32}@CLIENTS'
        ):
            computation_wrapper_instances.tensorflow_wrapper(foo, federated)

        with self.assertRaisesRegex(
                TypeError,
                r'you have attempted to create one with the type \( -> int32\)'
        ):
            computation_wrapper_instances.tensorflow_wrapper(foo, function)

        with self.assertRaisesRegex(
                TypeError,
                r'you have attempted to create one with the type placement'):
            computation_wrapper_instances.tensorflow_wrapper(
                foo, computation_types.PlacementType())

        with self.assertRaisesRegex(
                TypeError,
                r'you have attempted to create one with the type T'):
            computation_wrapper_instances.tensorflow_wrapper(
                foo, computation_types.AbstractType('T'))

        # Both halves of this implicitly concatenated pattern are raw strings,
        # so `\)` reaches the regex engine without an invalid-escape warning.
        with self.assertRaisesRegex(
                TypeError,
                r'you have attempted to create one with the type <{int32}@CLIENTS,\( '
                r'-> int32\)>'):
            computation_wrapper_instances.tensorflow_wrapper(
                foo, tuple_on_function)
Beispiel #30
0
 def test_create_selection_does_not_cache_error_avoids_double_cache_delete(
         self):
     """Two concurrent failing selections both surface the target's error."""
     loop = asyncio.get_event_loop()
     target = mock.create_autospec(executor_base.Executor)
     target.create_value.side_effect = create_test_value
     target.create_selection.side_effect = raise_error
     cached_executor = caching_executor.CachingExecutor(target)
     pair_type = computation_types.StructType((tf.int32, tf.int32))
     value = loop.run_until_complete(
         cached_executor.create_value((1, 2), pair_type))

     # Issue the same selection twice and run both concurrently.
     pending = [cached_executor.create_selection(value, 1) for _ in range(2)]
     results = loop.run_until_complete(
         asyncio.gather(*pending, return_exceptions=True))

     # NOTE(review): `assert_has_calls([])` is vacuously true and does not
     # verify how often the wrapped executor was called; consider asserting
     # on `target.create_selection.call_count` instead.
     target.create_selection.assert_has_calls([])
     self.assertLen(results, 2)
     for outcome in results:
         self.assertIsInstance(outcome, TestError)