class CreateConstantTest(parameterized.TestCase):
  """Tests `tensorflow_computation_factory.create_constant`."""

  # pyformat: disable
  @parameterized.named_parameters(
      ('int',
       10,
       computation_types.TensorType(tf.int32, [3]),
       [10] * 3),
      ('float',
       10.0,
       computation_types.TensorType(tf.float32, [3]),
       [10.0] * 3),
      ('unnamed_tuple',
       10,
       computation_types.StructType([tf.int32] * 3),
       structure.Struct([(None, 10)] * 3)),
      ('named_tuple',
       10,
       computation_types.StructType([
           ('a', tf.int32), ('b', tf.int32), ('c', tf.int32)]),
       structure.Struct([('a', 10), ('b', 10), ('c', 10)])),
      ('nested_tuple',
       10,
       computation_types.StructType([[tf.int32] * 3] * 3),
       structure.Struct(
           [(None, structure.Struct([(None, 10)] * 3))] * 3)),
  )
  # pyformat: enable
  def test_returns_computation(self, value, type_signature, expected_result):
    """Checks the returned proto, its type signature and its result."""
    proto, _ = tensorflow_computation_factory.create_constant(
        value, type_signature)
    self.assertIsInstance(proto, pb.Computation)

    # The serialized type must fit a no-argument function that returns
    # `type_signature`.
    deserialized_type = type_serialization.deserialize_type(proto.type)
    computation_types.FunctionType(
        None, type_signature).check_assignable_from(deserialized_type)

    result = test_utils.run_tensorflow(proto)
    # List expectations (the tensor cases above) are compared as multisets;
    # structured expectations use plain equality.
    if not isinstance(expected_result, list):
      self.assertEqual(result, expected_result)
    else:
      self.assertCountEqual(result, expected_result)

  @parameterized.named_parameters(
      ('non_scalar_value', np.zeros([1]),
       computation_types.TensorType(tf.int32)),
      ('none_type', 10, None),
      ('federated_type', 10, computation_types.at_server(tf.int32)),
      ('bad_type', 10.0, computation_types.TensorType(tf.int32)),
  )
  def test_raises_type_error(self, value, type_signature):
    """Checks that invalid value/type combinations raise `TypeError`."""
    with self.assertRaises(TypeError):
      tensorflow_computation_factory.create_constant(value, type_signature)
def test_identity_with_structure(self):
  """`identity` on a Struct yields a new object holding the same value."""
  with tf.Graph().as_default() as graph:
    original = structure.Struct(
        [('foo', tf.constant(10, dtype=tf.int32, shape=[]))])
    copied = tf_computation_utils.identity(original)
    self.assertIsNot(copied, original)
  with tf.compat.v1.Session(graph=graph) as sess:
    foo_value = sess.run(copied.foo)
  self.assertEqual(foo_value, 10)
def test_federated_generic_divide_with_unnamed_tuples(self):
  """GENERIC_DIVIDE of a clients-placed unnamed pair by itself yields ones."""
  bodies_by_uri = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  clients_pair_type = computation_types.FederatedType(
      [tf.int32, tf.float32], placement_literals.CLIENTS)

  @computations.federated_computation(clients_pair_type)
  def foo(x):
    return bodies_by_uri[intrinsic_defs.GENERIC_DIVIDE.uri]([x, x])

  self.assertEqual(
      str(foo.type_signature),
      '({<int32,float32>}@CLIENTS -> {<float64,float32>}@CLIENTS)')
  ones = structure.Struct([(None, 1.), (None, 1.)])
  self.assertEqual(foo([[1, 1.]]), [ones])
  self.assertEqual(foo([[1, 1.], [1, 2.], [3, 3.]]), [ones] * 3)
def test_capture_result_with_struct_of_constants(self):
  """A captured Struct of constants has a plain (non-Python) StructType."""
  members = [
      ('x', tf.constant(10)),
      (None, tf.constant(True)),
      ('y', tf.constant(0.66)),
  ]
  captured_type = self._checked_capture_result(structure.Struct(members))
  self.assertEqual(str(captured_type), '<x=int32,bool,y=float32>')
  self.assertIsInstance(captured_type, computation_types.StructType)
  self.assertNotIsInstance(captured_type,
                           computation_types.StructWithPythonType)
def _unwrap(value):
  """Recursively replaces `tf.Tensor`s in `value` with their `.numpy()` data.

  `structure.Struct`s are rebuilt with names preserved and each element
  unwrapped; lists are rebuilt element-wise; any other object is returned
  unchanged. Tensors are converted via `.numpy()`, so they are presumably
  eager tensors.
  """
  if isinstance(value, tf.Tensor):
    return value.numpy()
  if isinstance(value, structure.Struct):
    unwrapped_members = [
        (name, _unwrap(member))
        for name, member in structure.iter_elements(value)
    ]
    return structure.Struct(unwrapped_members)
  if isinstance(value, list):
    return list(map(_unwrap, value))
  return value
def test_to_canonical_value_with_ordered_dict(self):
  """An OrderedDict canonicalizes to a Struct preserving key order."""
  for ordered_items in ([('a', 1), ('b', 0.1)], [('b', 0.1), ('a', 1)]):
    self.assertEqual(
        type_utils.to_canonical_value(collections.OrderedDict(ordered_items)),
        structure.Struct(ordered_items))
def test_setattr_federated_named_tuple_type_bool(self):
  """Assigning a member of a clients-placed named tuple (currently skipped)."""
  self.skipTest(
      'TODO(b/148685415): Calling setatter on a FederatedType constructs a '
      'lambda passed to intrinsic containing a reference to a captured '
      'variable; a native execution stack can not handle this structure.')
  # `skipTest` raises, so nothing below runs while the skip is in place.
  clients_type = computation_types.FederatedType(
      [('a', tf.int32), ('b', tf.bool)], placement_literals.CLIENTS)

  @computations.federated_computation(clients_type)
  def foo(x):
    x.b = False
    return x

  self.assertEqual(
      foo([[5, True], [0, False], [-5, True]]),
      [structure.Struct([('a', a), ('b', False)]) for a in (5, 0, -5)])
def test_create_scalar_multiply_operator_float32(self):
  """Multiplying a float32 operand by a float32 scalar: 10 * 5 == 50."""
  comp, comp_type = self._factory.create_scalar_multiply_operator(
      computation_types.to_type(np.float32),
      computation_types.to_type(np.float32))
  operand_and_scalar = structure.Struct([
      (None, np.float32(10.0)),
      (None, np.float32(5.0)),
  ])
  product = self._run_comp(comp, comp_type, arg=operand_and_scalar)
  self.assertEqual(product, 50.0)
def _unwrap_execution_context_value(val):
  """Recursively strips `ExecutionContextValue` wrappers from `val`.

  A `structure.Struct` is rebuilt element by element, with names preserved.
  An `ExecutionContextValue` is replaced by its own (recursively unwrapped)
  `.value`. Any other object is returned unchanged.
  """
  if isinstance(val, structure.Struct):
    return structure.Struct([
        (name, _unwrap_execution_context_value(member))
        for name, member in structure.iter_elements(val)
    ])
  if isinstance(val, ExecutionContextValue):
    return _unwrap_execution_context_value(val.value)
  return val
def test_returns_computation(self):
  """`create_empty_tuple` yields a no-arg computation returning `<>`."""
  proto, _ = tensorflow_computation_factory.create_empty_tuple()
  self.assertIsInstance(proto, pb.Computation)

  deserialized_type = type_serialization.deserialize_type(proto.type)
  computation_types.FunctionType(None, []).check_assignable_from(
      deserialized_type)

  self.assertEqual(test_utils.run_tensorflow(proto), structure.Struct([]))
async def _make():
  """Calls `add` on a (float, int) pair of 0..9 vectors inside `ex`.

  `ex`, `add`, `float_type` and `int_type` come from the enclosing scope.
  """
  embedded_add = await ex.create_value(add)
  float_values = await ex.create_value(list(range(10)), float_type)
  int_values = await ex.create_value(list(range(10)), int_type)
  call_arg = await ex.create_struct(
      structure.Struct([(None, float_values), (None, int_values)]))
  call_result = await ex.create_call(embedded_add, call_arg)
  return await call_result.compute()
def test_anon_tuple_without_names_to_container_without_names(self):
  """An unnamed Struct converts to a plain list or tuple container."""
  anon_tuple = structure.Struct([(None, 1), (None, 2.0)])
  member_types = [tf.int32, tf.float32]
  for container_cls, expected in ((list, [1, 2.0]), (tuple, (1, 2.0))):
    container_type = computation_types.StructWithPythonType(
        member_types, container_cls)
    self.assertSequenceEqual(
        type_conversions.type_to_py_container(anon_tuple, container_type),
        expected)
def test_federated_weighted_mean_named_tuple_with_tensor(self):
  """Weighted mean of a named-tuple value with a scalar weight per client."""
  bodies = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  value_and_weight_type = computation_types.FederatedType(
      [[('a', tf.float32), ('b', tf.float32)], tf.float32],
      placement_literals.CLIENTS)

  @computations.federated_computation(value_and_weight_type)
  def foo(x):
    return bodies[intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri](x)

  self.assertEqual(
      str(foo.type_signature),
      '({<<a=float32,b=float32>,float32>}@CLIENTS -> '
      '<a=float32,b=float32>@SERVER)')
  self.assertEqual(
      foo([[[1., 1.], 1.]]),
      structure.Struct([('a', 1.), ('b', 1.)]))
  self.assertEqual(
      foo([[[1., 1.], 1.], [[1., 2.], 2.], [[1., 4.], 4.]]),
      structure.Struct([('a', 1.), ('b', 3.)]))
async def _evaluate_struct(
    self,
    comp: pb.Computation,
    scope: ReferenceResolvingExecutorScope,
) -> ReferenceResolvingExecutorValue:
  """Evaluates each element of a struct computation and packs the results.

  All elements are evaluated in `scope` concurrently via `asyncio.gather`.
  Empty element names become `None`. The results are handed to
  `self.create_struct`.
  """
  struct_elements = comp.struct.element
  names = [str(e.name) if e.name else None for e in struct_elements]
  evaluated = await asyncio.gather(
      *(self._evaluate(e.value, scope=scope) for e in struct_elements))
  return await self.create_struct(structure.Struct(zip(names, evaluated)))
def test_pack_sequence_as_fails_non_struct(self):
  """Packing into a structure containing a plain dict raises `TypeError`."""
  template = structure.Struct([
      ('a', 10),
      ('b', {'d': 20}),
      ('c', 30),
  ])
  flat_values = [10, 20, 30]
  with self.assertRaisesRegex(TypeError, 'Cannot pack sequence'):
    structure.pack_sequence_as(template, flat_values)
def _deserialize_struct_value(
    value_proto: executor_pb2.Value) -> _DeserializeReturnType:
  """Deserializes a value of struct type.

  Each element of `value_proto.struct` is deserialized with
  `deserialize_value`. Empty element names are treated as `None`.

  Returns:
    A `(structure.Struct, computation_types.StructType)` pair. Unnamed
    elements contribute a bare member type, not a `(None, type)` pair, to the
    StructType.
  """
  named_values = []
  member_types = []
  for element in value_proto.struct.element:
    name = element.name or None
    element_value, element_type = deserialize_value(element.value)
    named_values.append((name, element_value))
    member_types.append(
        element_type if name is None else (name, element_type))
  return (structure.Struct(named_values),
          computation_types.StructType(member_types))
async def create_struct(self, elements):
  """Packs `SequenceExecutorValue` elements into one struct-valued value.

  `elements` is normalized with `structure.from_container`. Each member is
  checked with `py_typecheck.check_type` before use. The result pairs a Struct
  of internal representations with a StructType of `(name, type)` members.
  """
  representations = []
  member_types = []
  normalized = structure.from_container(elements)
  for name, member in structure.iter_elements(normalized):
    py_typecheck.check_type(member, SequenceExecutorValue)
    representations.append((name, member.internal_representation))
    member_types.append((name, member.type_signature))
  return SequenceExecutorValue(
      structure.Struct(representations),
      computation_types.StructType(member_types))
def test_infer_cardinalities_structure_failure(self):
  """Clients-placed members with lengths 3 and 2 conflict."""
  with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
    mismatched = structure.Struct([('A', [1, 2, 3]), ('B', [1, 2])])
    struct_type = computation_types.StructType([
        (name, computation_types.FederatedType(tf.int32, placements.CLIENTS))
        for name in ('A', 'B')
    ])
    cardinalities_utils.infer_cardinalities(mismatched, struct_type)
async def _embed_value_in_target_exec(
    self, value: ReferenceResolvingExecutorValue
) -> executor_value_base.ExecutorValue:
  """Embeds `value` in the target executor.

  Used to prepare the argument of a `self._target_executor.create_call`.
  That call happens when `create_value` receives a `comp` that is not a
  `Lambda`.

  Args:
    value: A `ReferenceResolvingExecutorValue` to embed.

  Returns:
    An `executor_value_base.ExecutorValue` living in the target executor.
    Values that are already embedded are returned as-is. Structs are embedded
    member-wise (concurrently) and re-packed. Scoped lambdas are embedded
    from their `comp` proto.

  Raises:
    RuntimeError: If a scoped lambda references captured (unbound) variables
      and therefore cannot be delegated.
  """
  py_typecheck.check_type(value, ReferenceResolvingExecutorValue)
  value_repr = value.internal_representation

  if isinstance(value_repr, executor_value_base.ExecutorValue):
    return value_repr

  if isinstance(value_repr, structure.Struct):
    member_names = [name for name, _ in structure.iter_elements(value_repr)]
    embedded_members = await asyncio.gather(
        *[self._embed_value_in_target_exec(member) for member in value_repr])
    return await self._target_executor.create_struct(
        structure.Struct(zip(member_names, embedded_members)))

  py_typecheck.check_type(value_repr, ScopedLambda)
  # Only a self-contained lambda can be delegated. Nothing here substitutes
  # computed values for references inside the lambda's pb.Computation, so
  # any scope reference is an error.
  unbound_refs = _unbound_refs(value_repr.comp)
  if len(unbound_refs) != 0:  # pylint: disable=g-explicit-length-test
    # "passed to intrinsic" is an assumption about the caller: typechecking
    # should reject a lambda passed to TensorFlow code, which leaves
    # intrinsics as the only other functional construct in TFF.
    tree = building_blocks.ComputationBuildingBlock.from_proto(value_repr.comp)
    raise RuntimeError(
        'lambda passed to intrinsic contains references to captured '
        'variables. This is not currently supported. For more information, '
        'see b/148685415. '
        'Found references {} in computation {} with type {}'.format(
            unbound_refs, tree, tree.type_signature))
  return await self._target_executor.create_value(
      value_repr.comp, value.type_signature)
def test_returns_computation(self, type_signature, count, value):
  """`create_replicate_input` returns `count` copies of its argument."""
  proto, _ = tensorflow_computation_factory.create_replicate_input(
      type_signature, count)
  self.assertIsInstance(proto, pb.Computation)

  deserialized_type = type_serialization.deserialize_type(proto.type)
  replicated_types = [type_signature] * count
  computation_types.FunctionType(
      type_signature, replicated_types).check_assignable_from(
          deserialized_type)

  self.assertEqual(
      test_utils.run_tensorflow(proto, value),
      structure.Struct([(None, value)] * count))
def test_tuple_argument_can_accept_unnamed_elements(self):
  """A two-argument computation accepts one unnamed two-element Struct."""

  @tensorflow_computation.tf_computation(tf.int32, tf.int32)
  def foo(x, y):
    return x + y

  local_executor = python_executor_stacks.local_executor_factory()
  unnamed_pair = structure.Struct([(None, 2), (None, 3)])
  with _install_executor_in_synchronous_context(local_executor):
    # pylint:disable=no-value-for-parameter
    total = foo(unnamed_pair)
    # pylint:enable=no-value-for-parameter
  self.assertEqual(total, 5)
def test_federated_weighted_mean(self):
  """`federated_mean` with per-client weights over the test executor."""

  @computations.federated_computation(
      type_factory.at_clients(tf.float32),
      type_factory.at_clients(tf.float32))
  def comp(x, y):
    return intrinsics.federated_mean(x, y)

  executor, num_clients = _create_test_executor()
  client_values = [float(i) for i in range(1, num_clients + 1)]
  # NOTE(review): the weights list has 12 entries, so this presumably assumes
  # `_create_test_executor` yields 12 clients. With 12 clients the expected
  # weighted mean is 164 / 24 ~= 6.8333.
  client_weights = [1.0, 2.0, 3.0] * 4
  arg = structure.Struct([('x', client_values), ('y', client_weights)])
  self.assertAlmostEqual(
      _invoke(executor, comp, arg), 6.83333333333, places=3)
async def create_struct(self, elements):
  """Packs `XlaValue` elements into a single struct-valued `XlaValue`.

  `elements` is normalized with `structure.from_container`. Each member is
  checked with `py_typecheck.check_type`. Unnamed members contribute a bare
  type to the resulting StructType, named ones a `(name, type)` pair. The
  result is bound to `self._backend`.
  """
  representations = []
  member_types = []
  normalized = structure.from_container(elements)
  for name, member in structure.iter_elements(normalized):
    py_typecheck.check_type(member, XlaValue)
    representations.append((name, member.internal_representation))
    member_type = member.type_signature
    member_types.append(member_type if name is None else (name, member_type))
  return XlaValue(
      structure.Struct(representations),
      computation_types.StructType(member_types),
      self._backend)
def test_create_scalar_multiply_operator_2xfloat32(self):
  """Scaling a named pair of float32 values by a float32 scalar."""
  operand_type = computation_types.to_type(
      collections.OrderedDict([('a', np.float32), ('b', np.float32)]))
  scalar_type = computation_types.to_type(np.float32)
  comp, comp_type = self._factory.create_scalar_multiply_operator(
      operand_type, scalar_type)

  operand = collections.OrderedDict(a=np.float32(10.0), b=np.float32(11.0))
  arg = structure.Struct([(None, operand), (None, np.float32(5.0))])
  scaled = self._run_comp(comp, comp_type, arg=arg)
  self.assertEqual(str(scaled), '<a=50.0,b=55.0>')
def test_federated_generic_add_with_named_tuples(self):
  """GENERIC_PLUS of a clients-placed named pair with itself doubles it."""
  bodies = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  named_clients_type = computation_types.FederatedType(
      [('a', tf.int32), ('b', tf.float32)], placement_literals.CLIENTS)

  @computations.federated_computation(named_clients_type)
  def foo(x):
    return bodies[intrinsic_defs.GENERIC_PLUS.uri]([x, x])

  self.assertEqual(
      str(foo.type_signature),
      '({<a=int32,b=float32>}@CLIENTS -> {<a=int32,b=float32>}@CLIENTS)')
  self.assertEqual(
      foo([[1, 1.]]), [structure.Struct([('a', 2), ('b', 2.)])])
  self.assertEqual(
      foo([[1, 1.], [1, 2.], [1, 3.]]),
      [structure.Struct([('a', 2), ('b', doubled)])
       for doubled in (2., 4., 6.)])
def assert_is_add_two_implied_name_args_fn(self, fn):
  """Asserts `fn` binds (10, 20) to x, y however they are passed."""
  expected = Result(
      arg=structure.Struct([('x', 10), ('y', 20)]),
      arg_type=computation_types.to_type(
          collections.OrderedDict(x=tffint32, y=tffint32)),
      zero_result=0,
  )
  # (positional args, keyword args, failure message) for each call style.
  call_styles = (
      ((10, 20), {}, 'without names'),
      ((), {'x': 10, 'y': 20}, 'with names'),
      ((), {'y': 20, 'x': 10}, 'with names reversed'),
      ((10,), {'y': 20}, 'with only one name'),
  )
  for args, kwargs, msg in call_styles:
    self.assertEqual(fn(*args, **kwargs), expected, msg)
def test_tuple_argument_can_accept_unnamed_elements(self):
  """A two-argument computation accepts one unnamed two-element Struct."""

  @computations.tf_computation(tf.int32, tf.int32)
  def foo(x, y):
    return x + y

  local_executor = executor_stacks.local_executor_factory()
  unnamed_pair = structure.Struct([(None, 2), (None, 3)])
  with executor_test_utils.install_executor(local_executor):
    # pylint:disable=no-value-for-parameter
    total = foo(unnamed_pair)
    # pylint:enable=no-value-for-parameter
  self.assertEqual(total, 5)
def test_unnamed_tuple(self):
  """Sizing history records one entry per element of an unnamed pair."""
  ex = sizing_executor.SizingExecutor(eager_tf_executor.EagerTFExecutor())
  pair_type = computation_types.StructType([tf.int32, tf.int32])
  pair_value = structure.Struct([(None, 0), (None, 1)])

  async def _make():
    embedded = await ex.create_value(pair_value, pair_type)
    return await embedded.compute()

  asyncio.get_event_loop().run_until_complete(_make())
  expected_history = [[1, tf.int32], [1, tf.int32]]
  self.assertCountEqual(ex.broadcast_history, expected_history)
  self.assertCountEqual(ex.aggregate_history, expected_history)
def test_struct(self):
  """`to_type` maps a Struct of dtypes to the matching StructType."""
  dtype_struct = structure.Struct(((None, tf.int32), ('b', tf.int64)))
  converted = computation_types.to_type(dtype_struct)
  expected = computation_types.StructType([
      (None, computation_types.TensorType(tf.int32)),
      ('b', computation_types.TensorType(tf.int64)),
  ])
  self.assertEqual(converted, expected)
def test_to_value_for_structure(self):
  """`to_value` on a Struct of values yields a Value printing as a tuple."""
  stack = context_stack_impl.context_stack
  foo_value = value_impl.ValueImpl(
      building_blocks.Reference('foo', tf.int32), stack)
  bar_value = value_impl.ValueImpl(
      building_blocks.Reference('bar', tf.bool), stack)
  converted = value_impl.to_value(
      structure.Struct([('a', foo_value), ('b', bar_value)]), None, stack)
  self.assertIsInstance(converted, value_base.Value)
  self.assertEqual(str(converted), '<a=foo,b=bar>')