Example #1
0
class CreateConstantTest(parameterized.TestCase):
    """Tests for `tensorflow_computation_factory.create_constant`."""

    # pyformat: disable
    @parameterized.named_parameters(
        ('int', 10,
         computation_types.TensorType(tf.int32, [3]),
         [10] * 3),
        ('float', 10.0,
         computation_types.TensorType(tf.float32, [3]),
         [10.0] * 3),
        ('unnamed_tuple', 10,
         computation_types.StructType([tf.int32] * 3),
         structure.Struct([(None, 10)] * 3)),
        ('named_tuple', 10,
         computation_types.StructType(
             [('a', tf.int32), ('b', tf.int32), ('c', tf.int32)]),
         structure.Struct([('a', 10), ('b', 10), ('c', 10)])),
        ('nested_tuple', 10,
         computation_types.StructType([[tf.int32] * 3] * 3),
         structure.Struct(
             [(None, structure.Struct([(None, 10)] * 3))] * 3)),
    )
    # pyformat: enable
    def test_returns_computation(self, value, type_signature, expected_result):
        """Builds a constant computation and checks its type and its output."""
        computation_proto, _ = tensorflow_computation_factory.create_constant(
            value, type_signature)
        self.assertIsInstance(computation_proto, pb.Computation)

        # The computation takes no argument and returns `type_signature`.
        deserialized_type = type_serialization.deserialize_type(
            computation_proto.type)
        computation_types.FunctionType(
            None, type_signature).check_assignable_from(deserialized_type)

        result = test_utils.run_tensorflow(computation_proto)
        # List expectations are compared order-insensitively; anything else is
        # compared with plain equality.
        if isinstance(expected_result, list):
            compare = self.assertCountEqual
        else:
            compare = self.assertEqual
        compare(result, expected_result)

    @parameterized.named_parameters(
        ('non_scalar_value', np.zeros([1]),
         computation_types.TensorType(tf.int32)),
        ('none_type', 10, None),
        ('federated_type', 10, computation_types.at_server(tf.int32)),
        ('bad_type', 10.0, computation_types.TensorType(tf.int32)),
    )
    def test_raises_type_error(self, value, type_signature):
        """Invalid value/type combinations are rejected with `TypeError`."""
        self.assertRaises(TypeError,
                          tensorflow_computation_factory.create_constant,
                          value, type_signature)
 def test_identity_with_structure(self):
     """`identity` returns a distinct struct whose `foo` still evaluates to 10."""
     with tf.Graph().as_default() as test_graph:
         scalar = tf.constant(10, dtype=tf.int32, shape=[])
         original = structure.Struct([('foo', scalar)])
         copied = tf_computation_utils.identity(original)
     self.assertIsNot(copied, original)
     with tf.compat.v1.Session(graph=test_graph) as session:
         foo_value = session.run(copied.foo)
     self.assertEqual(foo_value, 10)
Example #3
0
  def test_federated_generic_divide_with_unnamed_tuples(self):
    """GENERIC_DIVIDE of an unnamed <int32,float32> client tuple by itself."""
    bodies = intrinsic_bodies.get_intrinsic_bodies(
        context_stack_impl.context_stack)
    client_type = computation_types.FederatedType(
        [tf.int32, tf.float32], placement_literals.CLIENTS)

    @computations.federated_computation(client_type)
    def foo(x):
      divide = bodies[intrinsic_defs.GENERIC_DIVIDE.uri]
      return divide([x, x])

    # Per the expected signature, the int32 quotient comes back as float64.
    self.assertEqual(
        str(foo.type_signature),
        '({<int32,float32>}@CLIENTS -> {<float64,float32>}@CLIENTS)')

    ones = structure.Struct([(None, 1.), (None, 1.)])
    self.assertEqual(foo([[1, 1.]]), [ones])
    self.assertEqual(foo([[1, 1.], [1, 2.], [3, 3.]]), [ones] * 3)
Example #4
0
 def test_capture_result_with_struct_of_constants(self):
     """Capturing a `Struct` of constants yields a plain `StructType`."""
     constants = structure.Struct([
         ('x', tf.constant(10)),
         (None, tf.constant(True)),
         ('y', tf.constant(0.66)),
     ])
     result_type = self._checked_capture_result(constants)
     self.assertEqual(str(result_type), '<x=int32,bool,y=float32>')
     self.assertIsInstance(result_type, computation_types.StructType)
     # The result must not be the Python-container-annotated subtype.
     self.assertNotIsInstance(result_type,
                              computation_types.StructWithPythonType)
Example #5
0
def _unwrap(value):
    """Recursively converts `tf.Tensor`s inside `value` via `.numpy()`.

    `Struct`s (names kept) and lists are rebuilt with their elements unwrapped;
    any other value is returned unchanged.
    """
    if isinstance(value, tf.Tensor):
        return value.numpy()
    if isinstance(value, structure.Struct):
        unwrapped_elements = [
            (name, _unwrap(element))
            for name, element in structure.iter_elements(value)
        ]
        return structure.Struct(unwrapped_elements)
    if isinstance(value, list):
        return list(map(_unwrap, value))
    return value
Example #6
0
 def test_to_canonical_value_with_ordered_dict(self):
     """An `OrderedDict` canonicalizes to a `Struct` in the same key order."""
     for pairs in ([('a', 1), ('b', 0.1)], [('b', 0.1), ('a', 1)]):
         canonical = type_utils.to_canonical_value(
             collections.OrderedDict(pairs))
         self.assertEqual(canonical, structure.Struct(list(pairs)))
Example #7
0
    def test_setattr_federated_named_tuple_type_bool(self):
        """Setting `b` on a federated named tuple overwrites it on every client."""
        self.skipTest(
            'TODO(b/148685415): Calling setatter on a FederatedType constructs a '
            'lambda passed to intrinsic containing a reference to a captured '
            'variable; a native execution stack can not handle this structure.'
        )

        client_type = computation_types.FederatedType(
            [('a', tf.int32), ('b', tf.bool)], placement_literals.CLIENTS)

        @computations.federated_computation(client_type)
        def foo(x):
            x.b = False
            return x

        expected = [
            structure.Struct([('a', a), ('b', False)]) for a in (5, 0, -5)
        ]
        self.assertEqual(foo([[5, True], [0, False], [-5, True]]), expected)
Example #8
0
 def test_create_scalar_multiply_operator_float32(self):
     """A float32 operand times a float32 scalar: 10 * 5 == 50."""
     operand_type = computation_types.to_type(np.float32)
     scalar_type = computation_types.to_type(np.float32)
     comp, comp_type = self._factory.create_scalar_multiply_operator(
         operand_type, scalar_type)
     # The operator takes an unnamed <operand, scalar> pair.
     arg = structure.Struct([(None, np.float32(10.0)),
                             (None, np.float32(5.0))])
     self.assertEqual(self._run_comp(comp, comp_type, arg=arg), 50.0)
Example #9
0
def _unwrap_execution_context_value(val):
    """Recursively strips `ExecutionContextValue` wrappers from `val`.

    `Struct`s are rebuilt element by element with names kept; a wrapper is
    replaced by its recursively unwrapped `.value`; anything else is returned
    as is.
    """
    if isinstance(val, structure.Struct):
        return structure.Struct([
            (name, _unwrap_execution_context_value(element))
            for name, element in structure.iter_elements(val)
        ])
    if isinstance(val, ExecutionContextValue):
        return _unwrap_execution_context_value(val.value)
    return val
Example #10
0
    def test_returns_computation(self):
        """`create_empty_tuple` yields a no-arg computation returning `<>`."""
        proto, _ = tensorflow_computation_factory.create_empty_tuple()
        self.assertIsInstance(proto, pb.Computation)

        deserialized_type = type_serialization.deserialize_type(proto.type)
        computation_types.FunctionType(None, []).check_assignable_from(
            deserialized_type)

        self.assertEqual(test_utils.run_tensorflow(proto), structure.Struct([]))
Example #11
0
 async def _make():
     """Calls `add` on <floats, ints> through executor `ex`, returns the result.

     `ex`, `add`, `float_type` and `int_type` come from the enclosing scope.
     """
     add_fn = await ex.create_value(add)
     float_values = await ex.create_value(list(range(10)), float_type)
     int_values = await ex.create_value(list(range(10)), int_type)
     arg_struct = await ex.create_struct(
         structure.Struct([(None, float_values), (None, int_values)]))
     call_result = await ex.create_call(add_fn, arg_struct)
     return await call_result.compute()
Example #12
0
 def test_anon_tuple_without_names_to_container_without_names(self):
   """An unnamed `Struct` converts to a plain list or tuple as requested."""
   anon_tuple = structure.Struct([(None, 1), (None, 2.0)])
   types = [tf.int32, tf.float32]
   for container_cls, expected in ((list, [1, 2.0]), (tuple, (1, 2.0))):
     target_type = computation_types.StructWithPythonType(types, container_cls)
     converted = type_conversions.type_to_py_container(anon_tuple, target_type)
     self.assertSequenceEqual(converted, expected)
Example #13
0
    def test_federated_weighted_mean_named_tuple_with_tensor(self):
        """FEDERATED_WEIGHTED_MEAN over <<a,b>, weight> client pairs."""
        bodies = intrinsic_bodies.get_intrinsic_bodies(
            context_stack_impl.context_stack)
        client_type = computation_types.FederatedType(
            [[('a', tf.float32), ('b', tf.float32)], tf.float32],
            placement_literals.CLIENTS)

        @computations.federated_computation(client_type)
        def foo(x):
            return bodies[intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri](x)

        expected_signature = ('({<<a=float32,b=float32>,float32>}@CLIENTS -> '
                              '<a=float32,b=float32>@SERVER)')
        self.assertEqual(str(foo.type_signature), expected_signature)

        single_client = [[[1., 1.], 1.]]
        self.assertEqual(foo(single_client),
                         structure.Struct([('a', 1.), ('b', 1.)]))

        # b: (1*1 + 2*2 + 4*4) / (1 + 2 + 4) = 21 / 7 = 3.
        three_clients = [[[1., 1.], 1.], [[1., 2.], 2.], [[1., 4.], 4.]]
        self.assertEqual(foo(three_clients),
                         structure.Struct([('a', 1.), ('b', 3.)]))
Example #14
0
 async def _evaluate_struct(
     self,
     comp: pb.Computation,
     scope: ReferenceResolvingExecutorScope,
 ) -> ReferenceResolvingExecutorValue:
     """Evaluates each element of a struct computation and packs the results.

     The elements are evaluated in `scope` and awaited together with
     `asyncio.gather`. Element names are kept, and an empty name becomes `None`.
     """
     elements = comp.struct.element
     element_names = [str(e.name) if e.name else None for e in elements]
     evaluated = await asyncio.gather(
         *[self._evaluate(e.value, scope=scope) for e in elements])
     return await self.create_struct(
         structure.Struct(list(zip(element_names, evaluated))))
Example #15
0
 def test_pack_sequence_as_fails_non_struct(self):
   """`pack_sequence_as` raises when the template nests a plain dict."""
   template = structure.Struct([
       ('a', 10),
       ('b', {'d': 20}),
       ('c', 30),
   ])
   flat_values = [10, 20, 30]
   with self.assertRaisesRegex(TypeError, 'Cannot pack sequence'):
     structure.pack_sequence_as(template, flat_values)
Example #16
0
def _deserialize_struct_value(
    value_proto: executor_pb2.Value) -> _DeserializeReturnType:
  """Deserializes a value of struct type.

  Args:
    value_proto: A `Value` proto whose `struct` elements are deserialized one
      by one with `deserialize_value`.

  Returns:
    A `(structure.Struct, computation_types.StructType)` pair. Elements with an
    empty name are unnamed in both the value and the type.
  """
  named_values = []
  named_types = []
  for element in value_proto.struct.element:
    element_name = element.name or None
    element_value, element_type = deserialize_value(element.value)
    named_values.append((element_name, element_value))
    if element_name:
      named_types.append((element_name, element_type))
    else:
      named_types.append(element_type)
  struct_value = structure.Struct(named_values)
  struct_type = computation_types.StructType(named_types)
  return struct_value, struct_type
Example #17
0
 async def create_struct(self, elements):
     """Combines `SequenceExecutorValue`s into one struct-typed value.

     Args:
       elements: A container accepted by `structure.from_container` whose
         members must all be `SequenceExecutorValue`s.

     Returns:
       A `SequenceExecutorValue` holding a `Struct` of the members' internal
       representations, typed by the matching `StructType`.
     """
     member_values = []
     member_types = []
     for name, member in structure.iter_elements(
             structure.from_container(elements)):
         py_typecheck.check_type(member, SequenceExecutorValue)
         member_values.append((name, member.internal_representation))
         member_types.append((name, member.type_signature))
     return SequenceExecutorValue(
         structure.Struct(member_values),
         computation_types.StructType(member_types))
 def test_infer_cardinalities_structure_failure(self):
     """Client lists of different lengths within one struct are rejected."""
     struct_type = computation_types.StructType([
         (name, computation_types.FederatedType(tf.int32, placements.CLIENTS))
         for name in ('A', 'B')
     ])
     mismatched_value = structure.Struct([('A', [1, 2, 3]), ('B', [1, 2])])
     with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
         cardinalities_utils.infer_cardinalities(mismatched_value, struct_type)
Example #19
0
    async def _embed_value_in_target_exec(
        self, value: ReferenceResolvingExecutorValue
    ) -> executor_value_base.ExecutorValue:
        """Inserts a value into the target executor.

        This function is called in order to prepare the argument being passed
        to a `self._target_executor.create_call`, which happens when a
        non-`Lambda` function is passed as the `comp` argument to
        `create_value`.

        Args:
          value: An instance of `ReferenceResolvingExecutorValue`.

        Returns:
          An instance of `executor_value_base.ExecutorValue` that represents a
          value embedded in the target executor.

        Raises:
          RuntimeError: Upon encountering a request to delegate a computation
            that is in a form that cannot be delegated.
        """
        py_typecheck.check_type(value, ReferenceResolvingExecutorValue)
        representation = value.internal_representation

        if isinstance(representation, executor_value_base.ExecutorValue):
            # Already an executor value; returned unchanged.
            return representation

        if isinstance(representation, structure.Struct):
            # Embed all members concurrently, then rebuild the struct in the
            # target executor under the original element names.
            named_members = list(structure.iter_elements(representation))
            embedded_members = await asyncio.gather(*[
                self._embed_value_in_target_exec(member)
                for _, member in named_members
            ])
            member_names = (name for name, _ in named_members)
            return await self._target_executor.create_struct(
                structure.Struct(zip(member_names, embedded_members)))

        py_typecheck.check_type(representation, ScopedLambda)
        # Pull `comp` out of the `ScopedLambda`, asserting that it doesn't
        # reference any scope variables. There is no way to replace the
        # references inside the lambda `pb.Computation` with actual computed
        # values, so such lambdas are rejected.
        unbound_refs = _unbound_refs(representation.comp)
        if len(unbound_refs) != 0:  # pylint: disable=g-explicit-length-test
            # Note: "passed to intrinsic" is an assumption about what the user
            # is doing. Typechecking should reject a lambda passed to
            # TensorFlow code, and intrinsics are the only other functional
            # construct in TFF.
            tree = building_blocks.ComputationBuildingBlock.from_proto(
                representation.comp)
            message_template = (
                'lambda passed to intrinsic contains references to captured '
                'variables. This is not currently supported. For more '
                'information, see b/148685415. '
                'Found references {} in computation {} with type {}')
            raise RuntimeError(
                message_template.format(unbound_refs, tree,
                                        tree.type_signature))
        return await self._target_executor.create_value(
            representation.comp, value.type_signature)
Example #20
0
    def test_returns_computation(self, type_signature, count, value):
        """`create_replicate_input` returns its input repeated `count` times."""
        proto, _ = tensorflow_computation_factory.create_replicate_input(
            type_signature, count)
        self.assertIsInstance(proto, pb.Computation)

        deserialized_type = type_serialization.deserialize_type(proto.type)
        replicated_type = [type_signature] * count
        computation_types.FunctionType(
            type_signature, replicated_type).check_assignable_from(
                deserialized_type)

        replicated_value = structure.Struct([(None, value)] * count)
        self.assertEqual(test_utils.run_tensorflow(proto, value),
                         replicated_value)
    def test_tuple_argument_can_accept_unnamed_elements(self):
        """A two-argument TF computation accepts an unnamed 2-element Struct."""

        @tensorflow_computation.tf_computation(tf.int32, tf.int32)
        def foo(x, y):
            return x + y

        local_executor = python_executor_stacks.local_executor_factory()
        unnamed_args = structure.Struct([(None, 2), (None, 3)])
        with _install_executor_in_synchronous_context(local_executor):
            # pylint:disable=no-value-for-parameter
            total = foo(unnamed_args)
            # pylint:enable=no-value-for-parameter

        self.assertEqual(total, 5)
    def test_federated_weighted_mean(self):
        """`federated_mean` with a weight argument, over the test executor."""

        @computations.federated_computation(
            type_factory.at_clients(tf.float32),
            type_factory.at_clients(tf.float32))
        def comp(x, y):
            return intrinsics.federated_mean(x, y)

        executor, num_clients = _create_test_executor()
        client_values = [float(i + 1) for i in range(num_clients)]
        # NOTE(review): the weights list has 12 entries, so this presumably
        # assumes 12 clients; with values 1..12 the weighted mean is
        # 164 / 24 ~= 6.8333.
        client_weights = [1.0, 2.0, 3.0] * 4
        arg = structure.Struct([('x', client_values), ('y', client_weights)])
        result = _invoke(executor, comp, arg)
        self.assertAlmostEqual(result, 6.83333333333, places=3)
Example #23
0
 async def create_struct(self, elements):
   """Combines `XlaValue`s into one struct-typed `XlaValue`.

   Members whose name is `None` appear as bare types in the resulting
   `StructType`; named members appear as `(name, type)` pairs.
   """
   value_pairs = []
   type_entries = []
   for name, member in structure.iter_elements(
       structure.from_container(elements)):
     py_typecheck.check_type(member, XlaValue)
     value_pairs.append((name, member.internal_representation))
     member_type = member.type_signature
     type_entries.append(member_type if name is None else (name, member_type))
   return XlaValue(
       structure.Struct(value_pairs),
       computation_types.StructType(type_entries),
       self._backend)
Example #24
0
 def test_create_scalar_multiply_operator_2xfloat32(self):
     """Each float32 field of an <a,b> struct is scaled by a float32 scalar."""
     operand_type = computation_types.to_type(
         collections.OrderedDict([('a', np.float32), ('b', np.float32)]))
     scalar_type = computation_types.to_type(np.float32)
     comp, comp_type = self._factory.create_scalar_multiply_operator(
         operand_type, scalar_type)

     operand = collections.OrderedDict([('a', np.float32(10.0)),
                                        ('b', np.float32(11.0))])
     arg = structure.Struct([(None, operand), (None, np.float32(5.0))])
     product = self._run_comp(comp, comp_type, arg=arg)
     self.assertEqual(str(product), '<a=50.0,b=55.0>')
Example #25
0
  def test_federated_generic_add_with_named_tuples(self):
    """GENERIC_PLUS adds a named <a,b> client tuple to itself."""
    bodies = intrinsic_bodies.get_intrinsic_bodies(
        context_stack_impl.context_stack)
    client_type = computation_types.FederatedType(
        [('a', tf.int32), ('b', tf.float32)], placement_literals.CLIENTS)

    @computations.federated_computation(client_type)
    def foo(x):
      return bodies[intrinsic_defs.GENERIC_PLUS.uri]([x, x])

    self.assertEqual(
        str(foo.type_signature),
        '({<a=int32,b=float32>}@CLIENTS -> {<a=int32,b=float32>}@CLIENTS)')

    def doubled(a, b):
      return structure.Struct([('a', 2 * a), ('b', 2 * b)])

    self.assertEqual(foo([[1, 1.]]), [doubled(1, 1.)])
    client_values = [[1, 1.], [1, 2.], [1, 3.]]
    self.assertEqual(
        foo(client_values), [doubled(a, b) for a, b in client_values])
Example #26
0
    def assert_is_add_two_implied_name_args_fn(self, fn):
        """Asserts `fn` binds 10 and 20 to `x` and `y` in every call style."""
        expected = Result(
            arg=structure.Struct([('x', 10), ('y', 20)]),
            arg_type=computation_types.to_type(
                collections.OrderedDict(x=tffint32, y=tffint32)),
            zero_result=0,
        )

        # Each entry holds: positional args, keyword args, failure message.
        call_styles = (
            ((10, 20), {}, 'without names'),
            ((), {'x': 10, 'y': 20}, 'with names'),
            ((), {'y': 20, 'x': 10}, 'with names reversed'),
            ((10,), {'y': 20}, 'with only one name'),
        )
        for args, kwargs, message in call_styles:
            self.assertEqual(fn(*args, **kwargs), expected, message)
Example #27
0
    def test_tuple_argument_can_accept_unnamed_elements(self):
        """An unnamed two-element `Struct` can feed a two-parameter function."""

        @computations.tf_computation(tf.int32, tf.int32)
        def foo(x, y):
            return x + y

        pair = structure.Struct([(None, 2), (None, 3)])
        local_executor = executor_stacks.local_executor_factory()
        with executor_test_utils.install_executor(local_executor):
            # pylint:disable=no-value-for-parameter
            total = foo(pair)
            # pylint:enable=no-value-for-parameter

        self.assertEqual(total, 5)
  def test_unnamed_tuple(self):
    """Sizing an unnamed <int32,int32> value logs one entry per element."""
    executor = sizing_executor.SizingExecutor(
        eager_tf_executor.EagerTFExecutor())
    unnamed_type = computation_types.StructType([tf.int32, tf.int32])
    unnamed_value = structure.Struct([(None, 0), (None, 1)])

    async def _make():
      embedded = await executor.create_value(unnamed_value, unnamed_type)
      return await embedded.compute()

    # NOTE(review): `asyncio.get_event_loop()` with no running loop is
    # deprecated in newer Python versions; consider `asyncio.run` if the
    # executor does not depend on a shared loop.
    asyncio.get_event_loop().run_until_complete(_make())
    expected_history = [[1, tf.int32], [1, tf.int32]]
    self.assertCountEqual(executor.broadcast_history, expected_history)
    self.assertCountEqual(executor.aggregate_history, expected_history)
Example #29
0
 def test_struct(self):
     """`to_type` maps a `Struct` of dtypes to the matching `StructType`."""
     struct_of_dtypes = structure.Struct([(None, tf.int32), ('b', tf.int64)])
     converted = computation_types.to_type(struct_of_dtypes)
     expected = computation_types.StructType([
         (None, computation_types.TensorType(tf.int32)),
         ('b', computation_types.TensorType(tf.int64)),
     ])
     self.assertEqual(converted, expected)
Example #30
0
 def test_to_value_for_structure(self):
   """`to_value` on a `Struct` of `ValueImpl`s yields a struct `Value`."""
   stack = context_stack_impl.context_stack
   foo_value = value_impl.ValueImpl(
       building_blocks.Reference('foo', tf.int32), stack)
   bar_value = value_impl.ValueImpl(
       building_blocks.Reference('bar', tf.bool), stack)
   packed = structure.Struct([('a', foo_value), ('b', bar_value)])
   converted = value_impl.to_value(packed, None, stack)
   self.assertIsInstance(converted, value_base.Value)
   self.assertEqual(str(converted), '<a=foo,b=bar>')