def test_raises_type_error_with_none_accumulate(self):
  # Every argument is a valid building block except `accumulate`, which is
  # passed as None; construction is expected to raise TypeError.
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS,
                                                False)
  value = computation_building_blocks.Data('v', clients_int)
  zero = computation_building_blocks.Data('z', tf.int32)
  merge = computation_building_blocks.Lambda(
      'x',
      computation_types.NamedTupleType((tf.int32, tf.int32)),
      computation_building_blocks.Data('m', tf.int32))
  r = computation_building_blocks.Reference('r', tf.int32)
  report = computation_building_blocks.Lambda(r.name, r.type_signature, r)

  with self.assertRaises(TypeError):
    computation_constructing_utils.create_federated_aggregate(
        value, zero, None, merge, report)
def test_two_tuple_zip_with_named_client_all_equal_int_and_bool(self):
  # A named two-tuple of CLIENTS-placed values (int32 'a', bool 'b'), each
  # built with the third FederatedType argument set to True.
  int_at_clients = computation_types.FederatedType(tf.int32,
                                                   placements.CLIENTS, True)
  bool_at_clients = computation_types.FederatedType(tf.bool,
                                                    placements.CLIENTS, True)
  pair_type = computation_types.NamedTupleType([
      ('a', int_at_clients),
      ('b', bool_at_clients),
  ])
  test_ref = computation_building_blocks.Reference('test', pair_type)

  wrapped = value_impl.to_value(test_ref, None, _context_stack)
  zipped = value_utils.zip_two_tuple(wrapped, _context_stack)

  self.assertEqual(str(zipped.type_signature), '{<a=int32,b=bool>}@CLIENTS')
def _extract_from_selection(comp):
  """Returns a new computation with all intrinsics extracted.

  When the source of the selection is a called intrinsic, it is bound to a
  freshly generated name and the selection is re-rooted on a reference to that
  name. Otherwise the source's `locals` and `result` are reused directly. In
  both cases the selection is wrapped in a block that is handed to
  `_extract_from_block`.

  Args:
    comp: The selection to rewrite; its `source`, `name` and `index` are read.

  Returns:
    Whatever `_extract_from_block` returns for the constructed block.
  """
  source = comp.source
  if _is_called_intrinsic(source):
    fresh_name = six.next(name_generator)
    bindings = ((fresh_name, source),)
    selected_from = computation_building_blocks.Reference(
        fresh_name, source.type_signature)
  else:
    # Otherwise `source` is used as a block: its bindings are kept and the
    # selection is applied to its result.
    bindings = source.locals
    selected_from = source.result
  rerooted = computation_building_blocks.Selection(
      selected_from, name=comp.name, index=comp.index)
  return _extract_from_block(
      computation_building_blocks.Block(bindings, rerooted))
def _create_lambda_to_identity(dtype):
  r"""Builds a lambda whose result is its own parameter.

  Lambda(arg)
  \
   Reference(arg)

  Args:
    dtype: The type of the lambda's parameter.

  Returns:
    A `computation_building_blocks.Lambda` with parameter name 'arg'.
  """
  parameter = computation_building_blocks.Reference('arg', dtype)
  return computation_building_blocks.Lambda(
      parameter.name, parameter.type_signature, parameter)
def test_value_impl_with_lambda(self):
  arg_name = 'arg'
  arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
              ('x', tf.int32)]

  # Build `arg.f(arg.f(arg.x))` by operating on a ValueImpl-wrapped reference.
  arg_value = value_impl.ValueImpl(
      computation_building_blocks.Reference(arg_name, arg_type),
      context_stack_impl.context_stack)
  result_value = arg_value.f(arg_value.f(arg_value.x))

  # Wrap the resulting computation in a lambda over the same parameter.
  lambda_comp = computation_building_blocks.Lambda(
      arg_name, arg_type, value_impl.ValueImpl.get_comp(result_value))
  x = value_impl.ValueImpl(lambda_comp, context_stack_impl.context_stack)

  self.assertIsInstance(x, value_base.Value)
  self.assertEqual(
      str(x.type_signature), '(<f=(int32 -> int32),x=int32> -> int32)')
  self.assertEqual(str(x), '(arg -> arg.f(arg.f(arg.x)))')
def _extract_from_tuple(comp):
  """Returns a new computation with all intrinsics extracted.

  Each element for which `_is_called_intrinsic_or_block` is true is bound to a
  freshly generated name and replaced in the tuple by a reference to that
  name; all other elements are kept as-is. The rewritten tuple becomes the
  result of a block holding the new bindings, which is then passed to
  `_extract_from_block`.

  Args:
    comp: The tuple whose elements are inspected.

  Returns:
    Whatever `_extract_from_block` returns for the constructed block.
  """
  bindings = []
  rewritten = []
  for name, element in anonymous_tuple.to_elements(comp):
    if not _is_called_intrinsic_or_block(element):
      rewritten.append((name, element))
      continue
    fresh_name = six.next(name_generator)
    bindings.append((fresh_name, element))
    rewritten.append(
        (name,
         computation_building_blocks.Reference(fresh_name,
                                               element.type_signature)))
  new_tuple = computation_building_blocks.Tuple(rewritten)
  return _extract_from_block(
      computation_building_blocks.Block(bindings, new_tuple))
def create_dummy_called_intrinsic(parameter_name, parameter_type=tf.int32):
  r"""Builds a call of a placeholder intrinsic on a reference.

            Call
           /    \
  intrinsic      Ref(x)

  The intrinsic has URI 'intrinsic' and a function type mapping
  `parameter_type` to `parameter_type`.

  Args:
    parameter_name: The name of the referenced parameter.
    parameter_type: The type of the referenced parameter.

  Returns:
    A `computation_building_blocks.Call`.
  """
  fn_type = computation_types.FunctionType(parameter_type, parameter_type)
  fn = computation_building_blocks.Intrinsic('intrinsic', fn_type)
  argument = computation_building_blocks.Reference(parameter_name,
                                                   parameter_type)
  return computation_building_blocks.Call(fn, argument)
def test_flatten_function(self, n):
  # Identity function over an n-element int32 tuple.
  tuple_ref = computation_building_blocks.Reference('test', [tf.int32] * n)
  identity_fn = computation_building_blocks.Lambda(
      'test', tuple_ref.type_signature, tuple_ref)
  type_to_add = (None, computation_types.to_type(tf.int32))

  # Expected signature: <<n int32s>, int32> -> <(n + 1) int32s>.
  expected_type = computation_types.FunctionType(
      computation_types.NamedTupleType(
          [tuple_ref.type_signature, type_to_add]),
      computation_types.to_type([tf.int32] * (n + 1)))

  wrapped_fn = value_impl.to_value(identity_fn, None, _context_stack)
  new_func = value_utils.flatten_first_index(wrapped_fn, type_to_add,
                                             _context_stack)

  self.assertEqual(str(new_func.type_signature), str(expected_type))
def test_propogates_dependence_into_binding_to_reference(self):
  zero_uri = intrinsic_defs.GENERIC_ZERO.uri

  def is_generic_zero(node):
    # True only for Intrinsic nodes carrying the GENERIC_ZERO URI.
    if not isinstance(node, computation_building_blocks.Intrinsic):
      return False
    return node.uri == zero_uri

  fed_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  ref_to_x = computation_building_blocks.Reference('x', fed_type)
  zero = computation_building_blocks.Intrinsic(zero_uri, fed_type)
  # `x` is bound to the zero intrinsic, and the block's result refers to `x`.
  block = computation_building_blocks.Block([('x', zero)], ref_to_x)

  dependent_nodes = tree_analysis.extract_nodes_consuming(
      block, is_generic_zero)

  self.assertIn(ref_to_x, dependent_nodes)
def test_basic_functionality_of_selection_class(self):
  compact = computation_building_blocks.compact_representation
  source = computation_building_blocks.Reference(
      'foo', [('bar', tf.int32), ('baz', tf.bool)])

  # Selection by name.
  by_bar = computation_building_blocks.Selection(source, name='bar')
  self.assertEqual(by_bar.name, 'bar')
  self.assertEqual(by_bar.index, None)
  self.assertEqual(str(by_bar.type_signature), 'int32')
  self.assertEqual(
      repr(by_bar), 'Selection(Reference(\'foo\', NamedTupleType(['
      '(\'bar\', TensorType(tf.int32)), (\'baz\', TensorType(tf.bool))]))'
      ', name=\'bar\')')
  self.assertEqual(compact(by_bar), 'foo.bar')

  by_baz = computation_building_blocks.Selection(source, name='baz')
  self.assertEqual(str(by_baz.type_signature), 'bool')
  self.assertEqual(compact(by_baz), 'foo.baz')

  # A name that is not present in the source type is rejected.
  with self.assertRaises(ValueError):
    computation_building_blocks.Selection(source, name='bak')

  # Selection by index.
  first = computation_building_blocks.Selection(source, index=0)
  self.assertEqual(first.name, None)
  self.assertEqual(first.index, 0)
  self.assertEqual(str(first.type_signature), 'int32')
  self.assertEqual(
      repr(first), 'Selection(Reference(\'foo\', NamedTupleType(['
      '(\'bar\', TensorType(tf.int32)), (\'baz\', TensorType(tf.bool))]))'
      ', index=0)')
  self.assertEqual(compact(first), 'foo[0]')

  second = computation_building_blocks.Selection(source, index=1)
  self.assertEqual(str(second.type_signature), 'bool')
  self.assertEqual(compact(second), 'foo[1]')

  # Indices outside the two-element source are rejected.
  with self.assertRaises(ValueError):
    computation_building_blocks.Selection(source, index=2)
  with self.assertRaises(ValueError):
    computation_building_blocks.Selection(source, index=-1)

  # Proto form of the name-based selection.
  bar_proto = by_bar.proto
  self.assertEqual(
      type_serialization.deserialize_type(bar_proto.type),
      by_bar.type_signature)
  self.assertEqual(bar_proto.WhichOneof('computation'), 'selection')
  self.assertEqual(str(bar_proto.selection.source), str(source.proto))
  self.assertEqual(bar_proto.selection.name, 'bar')

  for selection in (by_bar, by_baz, first, second):
    self._serialize_deserialize_roundtrip_test(selection)
def test_returns_string_for_call_with_arg(self):
  # (a -> a)(data): an identity lambda applied to a data block.
  param = computation_building_blocks.Reference('a', tf.int32)
  identity = computation_building_blocks.Lambda(param.name,
                                                param.type_signature, param)
  comp = computation_building_blocks.Call(
      identity, computation_building_blocks.Data('data', tf.int32))

  self.assertEqual(comp.compact_representation(), '(a -> a)(data)')
  self.assertEqual(comp.formatted_representation(), '(a -> a)(data)')

  # NOTE(review): the expected diagram is compared exactly, so its spacing
  # matters -- confirm it against the output of structural_representation().
  # pyformat: disable
  self.assertEqual(
      comp.structural_representation(),
      ' Call\n'
      ' / \\\n'
      'Lambda(a) data\n'
      '|\n'
      'Ref(a)')
  # pyformat: enable
def test_raises_type_error_with_none_value(self):
  # Every argument is a valid building block except `value`, which is passed
  # as None; construction is expected to raise TypeError.
  int_pair = (tf.int32, tf.int32)
  zero = computation_building_blocks.Data('z', tf.int32)
  accumulate = computation_building_blocks.Lambda(
      'x',
      computation_types.NamedTupleType(int_pair),
      computation_building_blocks.Data('a', tf.int32))
  merge = computation_building_blocks.Lambda(
      'x',
      computation_types.NamedTupleType(int_pair),
      computation_building_blocks.Data('m', tf.int32))
  r = computation_building_blocks.Reference('r', tf.int32)
  report = computation_building_blocks.Lambda(r.name, r.type_signature, r)

  with self.assertRaises(TypeError):
    computation_constructing_utils.create_federated_aggregate(
        None, zero, accumulate, merge, report)
def construct_federated_getitem_comp(comp, key):
  """Builds the lambda used to apply `__getitem__` to a federated value.

  The returned lambda takes an argument of type `comp.type_signature.member`
  (a `computation_types.NamedTupleType`). For an `int` key it selects the
  element at that index; for a `slice` key it returns a tuple of the selected
  elements, each keeping the name it has in the member type.

  Args:
    comp: Instance of `computation_building_blocks.ComputationBuildingBlock`
      with type signature `computation_types.FederatedType` whose `member`
      attribute is of type `computation_types.NamedTupleType`.
    key: Instance of `int` or `slice`, the key used to grab elements from the
      member of `comp`.

  Returns:
    Instance of `computation_building_blocks.Lambda` which grabs the element
    or slice designated by `key` from its argument.
  """
  py_typecheck.check_type(
      comp, computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(comp.type_signature, computation_types.FederatedType)
  py_typecheck.check_type(comp.type_signature.member,
                          computation_types.NamedTupleType)
  py_typecheck.check_type(key, (int, slice))

  member_type = comp.type_signature.member
  arg_ref = computation_building_blocks.Reference('x', member_type)
  if isinstance(key, int):
    body = computation_building_blocks.Selection(arg_ref, index=key)
  else:
    elems = anonymous_tuple.to_elements(member_type)
    # slice.indices resolves the slice against the number of elements.
    body = computation_building_blocks.Tuple([
        (elems[k][0],
         computation_building_blocks.Selection(arg_ref, index=k))
        for k in range(*key.indices(len(elems)))
    ])
  return computation_building_blocks.Lambda('x', arg_ref.type_signature, body)
def _transform(comp):
  """Fuses a nested pair of calls into one call, when applicable.

  If `_should_transform(comp)` is false, `comp` is returned unchanged.
  Otherwise `comp.argument` is read as `<outer_fn, inner>` and
  `inner.argument` as `<inner_fn, map_arg>`. The result is a call to an
  intrinsic with `comp.function.uri` applied to
  `<(inner_arg -> outer_fn(inner_fn(inner_arg))), map_arg>`.

  Args:
    comp: The computation to inspect.

  Returns:
    Either `comp` itself or the newly built `Call`.
  """
  if not _should_transform(comp):
    return comp

  outer_fn = comp.argument[0]
  inner = comp.argument[1]
  inner_fn = inner.argument[0]
  map_arg = inner.argument[1]

  param = computation_building_blocks.Reference(
      'inner_arg', map_arg.type_signature.member)
  composed = computation_building_blocks.Call(
      outer_fn, computation_building_blocks.Call(inner_fn, param))
  fused_fn = computation_building_blocks.Lambda(param.name,
                                                param.type_signature, composed)
  fused_args = computation_building_blocks.Tuple([fused_fn, map_arg])

  # The new intrinsic keeps the original URI and result type, but takes the
  # fused argument tuple as its parameter.
  fused_type = computation_types.FunctionType(
      fused_args.type_signature, comp.function.type_signature.result)
  fused_intrinsic = computation_building_blocks.Intrinsic(
      comp.function.uri, fused_type)
  return computation_building_blocks.Call(fused_intrinsic, fused_args)
def test_returns_string_for_block(self):
  # (let a=data,b=data in c): two bindings to the same data block.
  data = computation_building_blocks.Data('data', tf.int32)
  result_ref = computation_building_blocks.Reference('c', tf.int32)
  comp = computation_building_blocks.Block([('a', data), ('b', data)],
                                           result_ref)

  self.assertEqual(comp.compact_representation(), '(let a=data,b=data in c)')

  # NOTE(review): the expected strings below are compared exactly, so their
  # spacing matters -- confirm it against the actual representation output.
  # pyformat: disable
  self.assertEqual(
      comp.formatted_representation(),
      '(let\n'
      ' a=data,\n'
      ' b=data\n'
      ' in c)')
  # pyformat: enable

  # pyformat: disable
  self.assertEqual(
      comp.structural_representation(),
      ' Block\n'
      ' / \\\n'
      '[a=data, b=data] Ref(c)')
  # pyformat: enable
def _create_lambda_to_identity(type_spec):
  r"""Builds a lambda whose result is its own parameter.

  Lambda
  \
   Ref(arg)

  Args:
    type_spec: The type of the lambda's parameter.

  Returns:
    A `computation_building_blocks.Lambda` with parameter name 'arg'.

  Raises:
    TypeError: If `type_spec` is not a `tf.dtypes.DType`.
  """
  py_typecheck.check_type(type_spec, tf.dtypes.DType)
  parameter = computation_building_blocks.Reference('arg', type_spec)
  return computation_building_blocks.Lambda(
      parameter.name, parameter.type_signature, parameter)
def test_replace_chained_federated_maps_replaces_federated_maps_with_different_types(
    self):
  # Inner map: int32 -> float32 cast over a CLIENTS-placed reference `x`.
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  inner_call = _create_called_federated_map(
      _create_lambda_to_dummy_cast(tf.int32, tf.float32),
      computation_building_blocks.Reference('x', clients_int))
  # Outer map: float32 identity applied to the inner map's result.
  comp = _create_called_federated_map(
      _create_lambda_to_identity(tf.float32), inner_call)

  transformed_comp = (
      transformations.replace_chained_federated_maps_with_federated_map(comp))

  self.assertEqual(
      comp.tff_repr,
      'federated_map(<(arg -> arg),federated_map(<(arg -> data),x>)>)')
  self.assertEqual(
      transformed_comp.tff_repr,
      'federated_map(<(arg -> (arg -> arg)((arg -> data)(arg))),x>)')
def test_remove_mapped_or_applied_identity_removes_nested_identity(
    self, uri, data_type):
  data = computation_building_blocks.Data('x', data_type)
  identity = computation_building_blocks.Lambda(
      'arg', tf.float32,
      computation_building_blocks.Reference('arg', tf.float32))
  args = computation_building_blocks.Tuple([identity, data])

  # The intrinsic maps <identity, data> to the type of `data`.
  fn_sig = args.type_signature[0]
  data_sig = args.type_signature[1]
  intrinsic = computation_building_blocks.Intrinsic(
      uri, computation_types.FunctionType([fn_sig, data_sig], data_sig))
  call = computation_building_blocks.Call(intrinsic, args)

  # Nest the call inside a tuple inside a lambda.
  lambda_wrapped_tuple = computation_building_blocks.Lambda(
      'y', tf.int32, computation_building_blocks.Tuple([call]))
  self.assertEqual(
      str(lambda_wrapped_tuple), '(y -> <{}(<(arg -> arg),x>)>)'.format(uri))

  reduced = transformations.remove_mapped_or_applied_identity(
      lambda_wrapped_tuple)

  self.assertEqual(str(reduced), '(y -> <x>)')
def _transform(comp, context_tree):
  """Renames References in `comp` to unique names.

  Names are looked up through `context_tree.get_payload_with_name(...)`, whose
  payload supplies `new_name`. For blocks and lambdas,
  `context_tree.walk_down_one_variable_binding()` is called once per binding
  before that binding's new name is looked up.

  Args:
    comp: The computation to rewrite.
    context_tree: Object providing `get_payload_with_name` and
      `walk_down_one_variable_binding`.

  Returns:
    A `(computation, modified)` pair. `modified` is True when `comp` is a
    Reference, Block or Lambda (a rebuilt node is returned), and False
    otherwise (`comp` is returned unchanged).
  """
  if isinstance(comp, computation_building_blocks.Reference):
    renamed = context_tree.get_payload_with_name(comp.name).new_name
    return computation_building_blocks.Reference(renamed, comp.type_signature,
                                                 comp.context), True

  if isinstance(comp, computation_building_blocks.Block):
    renamed_locals = []
    for local_name, value in comp.locals:
      # The tree must be advanced before each lookup.
      context_tree.walk_down_one_variable_binding()
      renamed = context_tree.get_payload_with_name(local_name).new_name
      renamed_locals.append((renamed, value))
    return computation_building_blocks.Block(renamed_locals,
                                             comp.result), True

  if isinstance(comp, computation_building_blocks.Lambda):
    context_tree.walk_down_one_variable_binding()
    renamed = context_tree.get_payload_with_name(comp.parameter_name).new_name
    return computation_building_blocks.Lambda(renamed, comp.parameter_type,
                                              comp.result), True

  return comp, False
def test_returns_string_for_comp_with_left_overhang(self):
  # a(comp#bbbbb()): a function reference applied to a no-arg compiled call.
  fn = computation_building_blocks.Reference(
      'a', computation_types.FunctionType(tf.int32, tf.int32))
  proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(1), None, context_stack_impl.context_stack)
  compiled_call = computation_building_blocks.Call(
      computation_building_blocks.CompiledComputation(proto, 'bbbbb'))
  comp = computation_building_blocks.Call(fn, compiled_call)

  self.assertEqual(comp.compact_representation(), 'a(comp#bbbbb())')
  self.assertEqual(comp.formatted_representation(), 'a(comp#bbbbb())')

  # NOTE(review): the expected diagram is compared exactly, so its spacing
  # matters -- confirm it against the output of structural_representation().
  # pyformat: disable
  self.assertEqual(
      comp.structural_representation(),
      ' Call\n'
      ' / \\\n'
      ' Ref(a) Call\n'
      ' /\n'
      'Compiled(bbbbb)')
  # pyformat: enable
def _extract_from_block(comp):
  """Returns a new computation with all intrinsics extracted.

  Three cases are handled, in order:

  1. The block's result is a called intrinsic: the intrinsic is appended to
     the block's bindings under a freshly generated name, and the result
     becomes a reference to that name.
  2. The block's result is itself a block: the two binding lists are
     concatenated (outer first) and the inner result is used.
  3. Otherwise: every binding whose value is a block is flattened, i.e. the
     inner block's bindings are hoisted in front of a binding of the same
     name to the inner block's result.

  Args:
    comp: The block to rewrite; its `locals` and `result` are read.

  Returns:
    A new `computation_building_blocks.Block`.
  """
  if _is_called_intrinsic(comp.result):
    called_intrinsic = comp.result
    name = six.next(name_generator)
    # Work on a fresh list: appending must never mutate the sequence exposed
    # by `comp.locals`, and it must also work when that sequence is a tuple.
    variables = list(comp.locals)
    variables.append((name, called_intrinsic))
    result = computation_building_blocks.Reference(
        name, called_intrinsic.type_signature)
    return computation_building_blocks.Block(variables, result)

  if isinstance(comp.result, computation_building_blocks.Block):
    # Normalize both sides to lists so concatenation cannot fail on a
    # list/tuple mismatch.
    merged = list(comp.locals) + list(comp.result.locals)
    return computation_building_blocks.Block(merged, comp.result.result)

  variables = []
  for name, variable in comp.locals:
    if isinstance(variable, computation_building_blocks.Block):
      # Hoist the nested block's bindings ahead of the binding they feed.
      variables.extend(variable.locals)
      variables.append((name, variable.result))
    else:
      variables.append((name, variable))
  return computation_building_blocks.Block(variables, comp.result)
def _extract_from_lambda(comp):
  """Returns a new computation with all intrinsics extracted.

  If the lambda's result is a called intrinsic, the intrinsic is bound to a
  freshly generated name. When `_contains_unbound_reference` reports that the
  intrinsic does not reference the lambda's parameter, the binding is placed
  outside the lambda; otherwise it stays inside.

  If the result is not a called intrinsic, it is used as a block. Each of its
  bindings is moved outside the lambda unless it references the lambda's
  parameter or a binding that was already kept inside. The outer block is then
  passed to `_extract_from_block`.

  Args:
    comp: The lambda to rewrite.

  Returns:
    A `Block` or `Lambda` in the called-intrinsic case; otherwise whatever
    `_extract_from_block` returns.
  """
  param_name = comp.parameter_name
  param_type = comp.parameter_type

  if _is_called_intrinsic(comp.result):
    intrinsic_call = comp.result
    fresh_name = six.next(name_generator)
    bindings = ((fresh_name, intrinsic_call),)
    fresh_ref = computation_building_blocks.Reference(
        fresh_name, intrinsic_call.type_signature)
    if _contains_unbound_reference(comp.result, param_name):
      # The intrinsic depends on the parameter, so it must stay inside.
      inner = computation_building_blocks.Block(bindings, fresh_ref)
      return computation_building_blocks.Lambda(param_name, param_type, inner)
    outer_fn = computation_building_blocks.Lambda(param_name, param_type,
                                                  fresh_ref)
    return computation_building_blocks.Block(bindings, outer_fn)

  body = comp.result
  hoisted = []
  kept = []
  for name, variable in body.locals:
    kept_names = [kept_name for kept_name, _ in kept]
    must_stay = (
        _contains_unbound_reference(variable, param_name) or
        _contains_unbound_reference(variable, kept_names))
    if must_stay:
      kept.append((name, variable))
    else:
      hoisted.append((name, variable))

  new_body = (
      computation_building_blocks.Block(kept, body.result)
      if kept else body.result)
  new_fn = computation_building_blocks.Lambda(param_name, param_type, new_body)
  return _extract_from_block(
      computation_building_blocks.Block(hoisted, new_fn))
def test_remove_mapped_or_applied_identity_does_not_remove_other_intrinsic(
    self):
  uri = 'dummy'
  data = computation_building_blocks.Data('x', tf.int32)
  identity = computation_building_blocks.Lambda(
      'arg', tf.float32,
      computation_building_blocks.Reference('arg', tf.float32))
  args = computation_building_blocks.Tuple([identity, data])

  # An intrinsic with an unrelated URI, applied to <identity, data>.
  fn_sig = args.type_signature[0]
  data_sig = args.type_signature[1]
  intrinsic = computation_building_blocks.Intrinsic(
      uri, computation_types.FunctionType([fn_sig, data_sig], data_sig))
  comp = computation_building_blocks.Call(intrinsic, args)

  transformed_comp = transformations.remove_mapped_or_applied_identity(comp)

  # Both the input and the output keep the same string form.
  expected = '{}(<(arg -> arg),x>)'.format(uri)
  self.assertEqual(str(comp), expected)
  self.assertEqual(str(transformed_comp), expected)
def _construct_naming_function(tuple_type_to_name, names_to_add):
  """Builds a lambda that attaches names to the elements of a tuple.

  Args:
    tuple_type_to_name: Instance of `computation_types.NamedTupleType`, the
      type of the argument to be named.
    names_to_add: Python `list` or `tuple` of names, one per element of
      `tuple_type_to_name`, in order.

  Returns:
    An instance of `computation_building_blocks.Lambda` taking an argument of
    type `tuple_type_to_name` and returning a tuple whose i-th element is the
    i-th element of the argument, named `names_to_add[i]`.

  Raises:
    ValueError: If `tuple_type_to_name` and `names_to_add` have different
      lengths.
  """
  py_typecheck.check_type(tuple_type_to_name, computation_types.NamedTupleType)
  num_names = len(names_to_add)
  num_elements = len(tuple_type_to_name)
  if num_names != num_elements:
    raise ValueError(
        'Number of elements in `names_to_add` must match number of element in '
        'the named tuple type `tuple_type_to_name`; here, `names_to_add` has '
        '{} elements and `tuple_type_to_name` has {}.'.format(
            num_names, num_elements))

  arg_ref = computation_building_blocks.Reference('x', tuple_type_to_name)
  named_elements = [
      (name, computation_building_blocks.Selection(arg_ref, index=position))
      for position, name in enumerate(names_to_add)
  ]
  return computation_building_blocks.Lambda(
      'x', arg_ref.type_signature,
      computation_building_blocks.Tuple(named_elements))
def test_replace_chained_federated_maps_does_not_replace_one_federated_maps(
    self):
  uri = intrinsic_defs.FEDERATED_MAP.uri

  # arg -> federated_map(<add_one, arg>): a lambda holding one federated map.
  map_arg = computation_building_blocks.Reference(
      'arg', computation_types.FederatedType(tf.int32, placements.CLIENTS))
  add_one = _create_lambda_to_add_one(map_arg.type_signature.member)
  comp = computation_building_blocks.Lambda(
      map_arg.name, map_arg.type_signature,
      _create_call_to_federated_map(add_one, map_arg))

  self.assertEqual(_get_number_of_intrinsics(comp, uri), 1)
  self.assertEqual(_to_comp(comp)([1]), [2])

  transformed_comp = (
      transformations.replace_chained_federated_maps_with_federated_map(comp))

  # The intrinsic count and the computed result are both unchanged.
  self.assertEqual(_get_number_of_intrinsics(transformed_comp, uri), 1)
  self.assertEqual(_to_comp(transformed_comp)([1]), [2])
def test_basic_functionality_of_lambda_class(self):
  compact = computation_building_blocks.compact_representation
  arg_name = 'arg'
  arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
              ('x', tf.int32)]

  # Build (arg -> arg.f(arg.f(arg.x))).
  arg = computation_building_blocks.Reference(arg_name, arg_type)
  select_f = computation_building_blocks.Selection(arg, name='f')
  select_x = computation_building_blocks.Selection(arg, name='x')
  inner_call = computation_building_blocks.Call(select_f, select_x)
  x = computation_building_blocks.Lambda(
      arg_name, arg_type, computation_building_blocks.Call(
          select_f, inner_call))

  # Types and parameter.
  self.assertEqual(
      str(x.type_signature), '(<f=(int32 -> int32),x=int32> -> int32)')
  self.assertEqual(x.parameter_name, arg_name)
  self.assertEqual(str(x.parameter_type), '<f=(int32 -> int32),x=int32>')

  # String representations.
  self.assertEqual(compact(x.result), 'arg.f(arg.f(arg.x))')
  arg_type_repr = (
      'NamedTupleType(['
      '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
      '(\'x\', TensorType(tf.int32))])')
  self.assertEqual(
      repr(x), 'Lambda(\'arg\', {0}, '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Selection(Reference(\'arg\', {0}), name=\'x\'))))'.format(
          arg_type_repr))
  self.assertEqual(compact(x), '(arg -> arg.f(arg.f(arg.x)))')

  # Proto form. `lambda` is a Python keyword, so the field is read via
  # getattr.
  x_proto = x.proto
  self.assertEqual(
      type_serialization.deserialize_type(x_proto.type), x.type_signature)
  self.assertEqual(x_proto.WhichOneof('computation'), 'lambda')
  lambda_proto = getattr(x_proto, 'lambda')
  self.assertEqual(lambda_proto.parameter_name, arg_name)
  self.assertEqual(str(lambda_proto.result), str(x.result.proto))

  self._serialize_deserialize_roundtrip_test(x)
def construct_federated_getattr_comp(comp, name):
  """Builds the lambda used to apply `__getattr__` to a federated value.

  The returned lambda takes an argument of type `comp.type_signature.member`
  (a `computation_types.NamedTupleType`) and selects the element called
  `name` from it.

  Args:
    comp: Instance of `computation_building_blocks.ComputationBuildingBlock`
      with type signature `computation_types.FederatedType` whose `member`
      attribute is of type `computation_types.NamedTupleType`.
    name: String name of the attribute to grab.

  Returns:
    Instance of `computation_building_blocks.Lambda` which grabs the attribute
    `name` from its argument.

  Raises:
    ValueError: If the member type of `comp` has no element named `name`.
  """
  py_typecheck.check_type(
      comp, computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(comp.type_signature, computation_types.FederatedType)
  py_typecheck.check_type(comp.type_signature.member,
                          computation_types.NamedTupleType)
  py_typecheck.check_type(name, six.string_types)

  member_type = comp.type_signature.member
  known_names = [
      element_name
      for element_name, _ in anonymous_tuple.to_elements(member_type)
  ]
  if name not in known_names:
    raise ValueError(
        'The federated value {} has no element of name {}'.format(comp, name))

  arg_ref = computation_building_blocks.Reference('x', member_type)
  return computation_building_blocks.Lambda(
      'x', arg_ref.type_signature,
      computation_building_blocks.Selection(arg_ref, name=name))
def test_returns_string_for_comp_with_right_overhang(self):
  # (a -> <a,data,data,data,data>[0])(data)
  ref = computation_building_blocks.Reference('a', tf.int32)
  data = computation_building_blocks.Data('data', tf.int32)
  five_tuple = computation_building_blocks.Tuple([ref] + [data] * 4)
  first = computation_building_blocks.Selection(five_tuple, index=0)
  fn = computation_building_blocks.Lambda(ref.name, ref.type_signature, first)
  comp = computation_building_blocks.Call(fn, data)

  self.assertEqual(
      computation_building_blocks.compact_representation(comp),
      '(a -> <a,data,data,data,data>[0])(data)')

  # NOTE(review): the expected strings below are compared exactly, so their
  # spacing matters -- confirm it against the actual representation output.
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.formatted_representation(comp),
      '(a -> <\n'
      ' a,\n'
      ' data,\n'
      ' data,\n'
      ' data,\n'
      ' data\n'
      '>[0])(data)')
  # pyformat: enable

  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(comp),
      ' Call\n'
      ' / \\\n'
      'Lambda(a) data\n'
      '|\n'
      'Sel(0)\n'
      '|\n'
      'Tuple\n'
      '|\n'
      '[Ref(a), data, data, data, data]')
  # pyformat: enable
def _create_lambda_to_cast(dtype1, dtype2):
  r"""Builds a lambda that applies a TensorFlow cast to its argument.

  Lambda
  \
   Call
   /   \
  Compiled     Reference
  Computation

  The compiled computation is serialized from `lambda x: tf.cast(x, dtype2)`
  with parameter type `dtype1`. Either argument may be a
  `computation_types.TensorType`, in which case its `dtype` is used.

  Args:
    dtype1: The type of the argument.
    dtype2: The type to cast the argument to.

  Returns:
    An instance of `computation_building_blocks.Lambda` with parameter 'arg'
    of type `dtype1` whose result is the cast applied to that parameter.
  """

  def _unwrap(dtype):
    # TensorType instances are replaced by the dtype they carry.
    if isinstance(dtype, computation_types.TensorType):
      return dtype.dtype
    return dtype

  dtype1 = _unwrap(dtype1)
  dtype2 = _unwrap(dtype2)
  py_typecheck.check_type(dtype1, tf.dtypes.DType)
  py_typecheck.check_type(dtype2, tf.dtypes.DType)

  parameter = computation_building_blocks.Reference('arg', dtype1)
  # NOTE(review): another caller in this file unpacks a 2-tuple from
  # serialize_py_fn_as_tf_computation, while here the return value is used
  # whole -- confirm which return shape applies to this version.
  serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda x: tf.cast(x, dtype2), dtype1, context_stack_impl.context_stack)
  cast_call = computation_building_blocks.Call(
      computation_building_blocks.CompiledComputation(serialized), parameter)
  return computation_building_blocks.Lambda(parameter.name, dtype1, cast_call)
def test_returns_federated_aggregate(self):
  int_pair = (tf.int32, tf.int32)
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS,
                                                False)
  value = computation_building_blocks.Data('v', clients_int)
  zero = computation_building_blocks.Data('z', tf.int32)
  # accumulate and merge both take an <int32,int32> pair named `x`.
  accumulate = computation_building_blocks.Lambda(
      'x',
      computation_types.NamedTupleType(int_pair),
      computation_building_blocks.Data('a', tf.int32))
  merge = computation_building_blocks.Lambda(
      'x',
      computation_types.NamedTupleType(int_pair),
      computation_building_blocks.Data('m', tf.int32))
  # report is the identity on int32.
  r = computation_building_blocks.Reference('r', tf.int32)
  report = computation_building_blocks.Lambda(r.name, r.type_signature, r)

  comp = computation_constructing_utils.create_federated_aggregate(
      value, zero, accumulate, merge, report)

  self.assertEqual(comp.tff_repr,
                   'federated_aggregate(<v,z,(x -> a),(x -> m),(r -> r)>)')
  self.assertEqual(str(comp.type_signature), 'int32@SERVER')