def test_federated_select_autozips_server_val(self, federated_select):
  """A dict of server values is zipped into one struct before selection."""
  client_keys, max_key, _, _ = self.basic_federated_select_args()

  element = intrinsics.federated_value(['first', 'second', 'third'],
                                       placements.SERVER)
  server_dict = collections.OrderedDict(e1=element, e2=element)
  member_type = element.type_signature.member
  state_type = computation_types.StructType([
      ('e1', member_type),
      ('e2', member_type),
  ])

  # Gathers the element at `key` from both entries of the zipped state.
  def _select_pair(arg):
    packed_state, key = arg[0], arg[1]
    state = type_conversions.type_to_py_container(packed_state, state_type)
    return (tf.gather(state['e1'], key), tf.gather(state['e2'], key))

  proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _select_pair, computation_types.StructType([state_type, tf.int32]))
  select_pair = computation_impl.ConcreteComputation(
      proto, context_stack_impl.context_stack)

  result = federated_select(client_keys, max_key, server_dict, select_pair)
  self.assert_value(result, '{<string,string>*}@CLIENTS')
def test_type_properties(self, value_type):
  """The created process has the expected `initialize` and `next` types."""
  factory = stochastic_discretization.StochasticDiscretizationFactory(
      step_size=0.1,
      inner_agg_factory=_measurement_aggregator,
      distortion_aggregation_factory=mean.UnweightedMeanFactory())
  value_type = computation_types.to_type(value_type)
  # Each tensor leaf of `value_type` is quantized to int32 with the same shape.
  quantize_type = type_conversions.structure_from_tensor_type_tree(
      lambda x: (tf.int32, x.shape), value_type)

  process = factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  server_state_type = computation_types.at_server(
      computation_types.StructType([
          ('step_size', tf.float32),
          ('inner_agg_process', ()),
      ]))
  type_test_utils.assert_types_equivalent(
      process.initialize.type_signature,
      computation_types.FunctionType(parameter=None, result=server_state_type))

  measurements_type = computation_types.at_server(
      computation_types.StructType([
          ('stochastic_discretization', quantize_type),
          ('distortion', tf.float32),
      ]))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=server_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=server_state_type,
          result=computation_types.at_server(value_type),
          measurements=measurements_type))
  type_test_utils.assert_types_equivalent(process.next.type_signature,
                                          next_type)
def test_deserialize_federated_value_promotes_types(self):
  """Members with unnamed and named struct types deserialize to the named one."""
  x = [10]
  int32_type = computation_types.to_type(tf.int32)
  smaller_type = computation_types.StructType([(None, int32_type)])
  larger_type = computation_types.StructType([('a', int32_type)])

  smaller_member_proto, _ = value_serialization.serialize_value(
      x, smaller_type)
  larger_member_proto, _ = value_serialization.serialize_value(x, larger_type)

  # Reuse only the CLIENTS placement; leave the member type unspecified.
  clients_type_proto = type_serialization.serialize_type(
      computation_types.at_clients(tf.int32))
  federated_type_proto = computation_pb2.FederatedType(
      placement=clients_type_proto.federated.placement, all_equal=False)
  value_proto = executor_pb2.Value(
      federated=executor_pb2.Value.Federated(
          type=federated_type_proto,
          value=[larger_member_proto, smaller_member_proto]))

  _, deserialized_type = value_serialization.deserialize_value(value_proto)
  type_test_utils.assert_types_identical(
      deserialized_type, computation_types.at_clients(larger_type))
def test_nested_tuple(self):
  """Every tensor of a nested struct value is recorded in the broadcast history."""
  ex = sizing_executor.SizingExecutor(eager_tf_executor.EagerTFExecutor())

  inner_type = computation_types.StructType([
      ('a', computation_types.TensorType(tf.int32, [4])),
      ('b', computation_types.TensorType(tf.bool, [2])),
      ('c', computation_types.TensorType(tf.int64, [2, 3])),
  ])
  outer_type = computation_types.StructType([('a', inner_type),
                                             ('b', inner_type)])

  inner_value = collections.OrderedDict([
      ('a', [0, 1, 2, 3]),
      ('b', [True, False]),
      ('c', [[1, 2, 3], [4, 5, 6]]),
  ])
  outer_value = collections.OrderedDict([('a', inner_value),
                                         ('b', inner_value)])

  async def _create_and_compute():
    value = await ex.create_value(outer_value, outer_type)
    return await value.compute()

  asyncio.run(_create_and_compute())

  # One [element count, dtype] entry per tensor, for both copies of the inner
  # struct.
  expected_history = [[4, tf.int32], [2, tf.bool], [6, tf.int64]] * 2
  self.assertCountEqual(ex.broadcast_history, expected_history)
class IsSumCompatibleTest(parameterized.TestCase):
  """Tests for `type_analysis.check_is_sum_compatible`."""

  @parameterized.named_parameters(
      ('tensor_type', computation_types.TensorType(tf.int32)),
      ('tuple_type_int', computation_types.StructType([tf.int32, tf.int32])),
      ('tuple_type_float',
       computation_types.StructType([tf.complex128, tf.float32, tf.float64])),
      ('federated_type',
       computation_types.FederatedType(tf.int32, placements.CLIENTS)),
  )
  def test_positive_examples(self, type_spec):
    # The check passes when it raises nothing.
    type_analysis.check_is_sum_compatible(type_spec)

  @parameterized.named_parameters(
      ('tensor_type_bool', computation_types.TensorType(tf.bool)),
      ('tensor_type_string', computation_types.TensorType(tf.string)),
      ('partially_defined_shape',
       computation_types.TensorType(tf.int32, shape=[None])),
      ('tuple_type', computation_types.StructType([tf.int32, tf.bool])),
      ('sequence_type', computation_types.SequenceType(tf.int32)),
      ('placement_type', computation_types.PlacementType()),
      ('function_type', computation_types.FunctionType(tf.int32, tf.int32)),
      ('abstract_type', computation_types.AbstractType('T')),
      ('ragged_tensor',
       computation_types.StructWithPythonType([], tf.RaggedTensor)),
      ('sparse_tensor',
       computation_types.StructWithPythonType([], tf.SparseTensor)),
  )
  def test_negative_examples(self, type_spec):
    with self.assertRaises(type_analysis.SumIncompatibleError):
      type_analysis.check_is_sum_compatible(type_spec)
def test_executor_stacks(self, executor_stack):
  """A nested struct value is sized correctly through a stack of executors.

  Args:
    executor_stack: A non-empty sequence of callables. Each one wraps an
      executor and returns a new executor. They are applied in reverse order
      around an eager TF executor, so `executor_stack[0]` ends up outermost.
  """
  # Use a unittest assertion instead of a bare `assert`, which is stripped
  # under `python -O`.
  self.assertTrue(executor_stack)
  ex = eager_tf_executor.EagerTFExecutor()
  for wrapping_executor in reversed(executor_stack):
    ex = wrapping_executor(ex)

  a = computation_types.TensorType(tf.int32, [4])
  b = computation_types.TensorType(tf.bool, [2])
  c = computation_types.TensorType(tf.int64, [2, 3])
  inner_type = computation_types.StructType([('a', a), ('b', b), ('c', c)])
  outer_type = computation_types.StructType([('a', inner_type),
                                             ('b', inner_type)])
  inner_type_val = collections.OrderedDict()
  inner_type_val['a'] = [0, 1, 2, 3]
  inner_type_val['b'] = [True, False]
  inner_type_val['c'] = [[1, 2, 3], [4, 5, 6]]
  outer_type_val = collections.OrderedDict()
  outer_type_val['a'] = inner_type_val
  outer_type_val['b'] = inner_type_val

  async def _make():
    v1 = await ex.create_value(outer_type_val, outer_type)
    return await v1.compute()

  # `asyncio.run` creates and closes its own event loop. Calling
  # `asyncio.get_event_loop()` with no running loop is deprecated, and this
  # matches the other executor tests.
  asyncio.run(_make())
  # NOTE(review): assumes the outermost executor exposes `broadcast_history`
  # (e.g. a SizingExecutor); confirm against the parameterized stacks.
  self.assertCountEqual(ex.broadcast_history,
                        [[4, tf.int32], [2, tf.bool], [6, tf.int64],
                         [4, tf.int32], [2, tf.bool], [6, tf.int64]])
def test_long_formatted_with_diff(self):
  """The mismatch message for two long structs matches the golden file."""
  unnamed_int32 = (None, computation_types.TensorType(tf.int32))
  twenty_elements = computation_types.StructType([unnamed_int32] * 20)
  twenty_one_elements = computation_types.StructType([unnamed_int32] * 21)
  message = computation_types.type_mismatch_error_message(
      twenty_elements, twenty_one_elements,
      computation_types.TypeRelation.EQUIVALENT)
  golden.check_string('long_formatted_with_diff.expected', message)
def _get_accumulator_type(member_type):
  """Builds the `tff.Type` of the accumulator used by sample aggregation.

  Args:
    member_type: A `tff.Type` for the member of the federated value being
      sampled. It is either a struct (every leaf of which must expose `dtype`
      and `shape`) or a single tensor type.

  Returns:
    A `tff.StructType` with two fields, `accumulators` and `rands`. They are
    parallel: the i-th entry of one corresponds to the i-th entry of the
    other.
    - `accumulators` mirrors `member_type` with an extra leading dimension of
      unknown size.
    - `rands` is a `float32[?]` vector.
    Together they are used to sample from the accumulators with equal
    probability.
  """

  def _prepend_unknown_dim(tensor_type):
    # Adds a leading `None` dimension so that values can be stacked.
    return computation_types.TensorType(tensor_type.dtype,
                                        [None] + tensor_type.shape.dims)

  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if member_type.is_struct():
    stacked = structure.map_structure(_prepend_unknown_dim, member_type)
    accumulators_type = computation_types.StructType(
        structure.to_odict(stacked, True))
  else:
    accumulators_type = _prepend_unknown_dim(member_type)

  rands_type = computation_types.TensorType(tf.float32, shape=[None])
  return computation_types.StructType(
      collections.OrderedDict([
          ('accumulators', accumulators_type),
          ('rands', rands_type),
      ]))
def test_transform_same_compiled_computation_different_type_signature(self):
  """Equivalent graphs with different type signatures get different IDs."""
  add_ids = compiled_computation_transformations.AddUniqueIDs()

  def _transform_and_get_id(comp):
    # Applies `add_ids` and checks that a non-zero cache key was attached.
    transformed, mutated = add_ids.transform(comp)
    self.assertTrue(mutated)
    self.assertIsInstance(transformed, building_blocks.CompiledComputation)
    self.assertTrue(transformed.proto.tensorflow.HasField('cache_key'))
    self.assertNotEqual(transformed.proto.tensorflow.cache_key.id, 0)
    return transformed.proto.tensorflow.cache_key.id

  # An operator on an empty-struct operand that returns an empty struct.
  empty_tuple_computation = (
      building_block_factory.create_tensorflow_unary_operator(
          lambda x: (), operand_type=computation_types.StructType([])))
  empty_id = _transform_and_get_id(empty_tuple_computation)

  # The same no-op graph, but returning a nested empty struct. This gives it a
  # different binding and type signature.
  nested_empty_tuple_computation = (
      building_block_factory.create_tensorflow_unary_operator(
          lambda x: ((),), operand_type=computation_types.StructType([])))
  nested_id = _transform_and_get_id(nested_empty_tuple_computation)

  # The type signature alone must be enough to produce a distinct ID.
  self.assertNotEqual(empty_id, nested_id)
class CreateReplicateInputTest(parameterized.TestCase):
  """Tests for `tensorflow_computation_factory.create_replicate_input`."""

  @parameterized.named_parameters([
      ('int', computation_types.TensorType(tf.int32), 3, 10),
      ('float', computation_types.TensorType(tf.float32), 3, 10.0),
      ('unnamed_tuple',
       computation_types.StructType([tf.int32, tf.float32]),
       3,
       structure.Struct([(None, 10), (None, 10.0)])),
      ('named_tuple',
       computation_types.StructType([('a', tf.int32), ('b', tf.float32)]),
       3,
       structure.Struct([('a', 10), ('b', 10.0)])),
      ('sequence', computation_types.SequenceType(tf.int32), 3, [10] * 3),
  ])
  def test_returns_computation(self, type_signature, count, value):
    proto, _ = tensorflow_computation_factory.create_replicate_input(
        type_signature, count)
    self.assertIsInstance(proto, pb.Computation)

    # The expected `(T -> <T, ..., T>)` type must accept the serialized type.
    deserialized_type = type_serialization.deserialize_type(proto.type)
    replicated_type = computation_types.FunctionType(
        type_signature, [type_signature] * count)
    replicated_type.check_assignable_from(deserialized_type)

    # The output is `count` unnamed copies of the input.
    self.assertEqual(
        test_utils.run_tensorflow(proto, value),
        structure.Struct([(None, value)] * count))

  @parameterized.named_parameters([
      ('none_type', None, 3),
      ('none_count', computation_types.TensorType(tf.int32), None),
      ('federated_type', computation_types.at_server(tf.int32), 3),
  ])
  def test_raises_type_error(self, type_signature, count):
    with self.assertRaises(TypeError):
      tensorflow_computation_factory.create_replicate_input(
          type_signature, count)
class ContainsOnlyTypesTest(parameterized.TestCase):
  """Tests for `type_analysis.contains_only` with an `isinstance` predicate."""

  # pyformat: disable
  @parameterized.named_parameters(
      ('one_type',
       computation_types.TensorType(tf.int32),
       computation_types.TensorType),
      ('two_types',
       computation_types.StructType([tf.int32]),
       (computation_types.StructType, computation_types.TensorType)),
      ('less_types',
       computation_types.TensorType(tf.int32),
       (computation_types.StructType, computation_types.TensorType)),
  )
  # pyformat: enable
  def test_returns_true(self, type_signature, types):

    def _is_allowed(candidate):
      return isinstance(candidate, types)

    self.assertTrue(type_analysis.contains_only(type_signature, _is_allowed))

  # pyformat: disable
  @parameterized.named_parameters(
      ('one_type',
       computation_types.TensorType(tf.int32),
       computation_types.StructType),
      ('more_types',
       computation_types.StructType([tf.int32]),
       computation_types.TensorType),
  )
  # pyformat: enable
  def test_returns_false(self, type_signature, types):

    def _is_allowed(candidate):
      return isinstance(candidate, types)

    self.assertFalse(type_analysis.contains_only(type_signature, _is_allowed))
def test_raises_type_error(self):
  """float32 client pairs pass the check; int32 client pairs raise TypeError."""
  check = type_analysis.check_valid_federated_weighted_mean_argument_tuple_type
  check(
      computation_types.StructType(
          [computation_types.at_clients(tf.float32)] * 2))
  with self.assertRaises(TypeError):
    check(
        computation_types.StructType(
            [computation_types.at_clients(tf.int32)] * 2))
def test_serialize_deserialize_named_tuple_types(self):
  """Unnamed, nested and named struct types survive a serialization round trip."""
  unnamed_pair = computation_types.StructType([tf.int32, tf.bool])
  with_nested_named = computation_types.StructType(
      [tf.int32, computation_types.StructType([('x', tf.bool)])])
  single_named = computation_types.StructType([('x', tf.int32)])
  self._serialize_deserialize_roundtrip_test(
      [unnamed_pair, with_nested_named, single_named])
def test_succeeds_function_different_parameter_and_return_types(self):
  """`(<int32,float32> -> float32)` is an instance of `(<U,T> -> T)`."""
  abstract_u = computation_types.AbstractType('U')
  abstract_t = computation_types.AbstractType('T')
  generic = computation_types.FunctionType(
      computation_types.StructType([abstract_u, abstract_t]), abstract_t)
  concrete = computation_types.FunctionType(
      computation_types.StructType([tf.int32, tf.float32]), tf.float32)
  type_analysis.check_concrete_instance_of(concrete, generic)
class CreateBinaryOperatorTest(parameterized.TestCase):
  """Tests for `tensorflow_computation_factory.create_binary_operator`."""

  # pyformat: disable
  @parameterized.named_parameters([
      ('add_int', tf.math.add,
       computation_types.TensorType(tf.int32),
       [1, 2],
       3),
      ('add_float', tf.math.add,
       computation_types.TensorType(tf.float32),
       [1.0, 2.25],
       3.25),
      ('add_unnamed_tuple', tf.math.add,
       computation_types.StructType([tf.int32, tf.float32]),
       [[1, 1.0], [2, 2.25]],
       structure.Struct([(None, 3), (None, 3.25)])),
      ('add_named_tuple', tf.math.add,
       computation_types.StructType([('a', tf.int32), ('b', tf.float32)]),
       [[1, 1.0], [2, 2.25]],
       structure.Struct([('a', 3), ('b', 3.25)])),
      ('multiply_int', tf.math.multiply,
       computation_types.TensorType(tf.int32),
       [2, 2],
       4),
      ('multiply_float', tf.math.multiply,
       computation_types.TensorType(tf.float32),
       [2.0, 2.25],
       4.5),
      ('divide_int', tf.math.divide,
       computation_types.TensorType(tf.int32),
       [4, 2],
       2.0),
      ('divide_float', tf.math.divide,
       computation_types.TensorType(tf.float32),
       [4.0, 2.0],
       2.0),
      ('divide_inf', tf.math.divide,
       computation_types.TensorType(tf.int32),
       [1, 0],
       np.inf),
  ])
  # pyformat: enable
  def test_returns_computation(self, operator, type_signature, operands,
                               expected_result):
    proto, _ = tensorflow_computation_factory.create_binary_operator(
        operator, type_signature)
    self.assertIsInstance(proto, pb.Computation)

    function_type = type_serialization.deserialize_type(proto.type)
    self.assertIsInstance(function_type, computation_types.FunctionType)
    # Only the parameter type is checked. The result type is determined by
    # `operator`, not by the implementation of `create_binary_operator`.
    self.assertEqual(
        function_type.parameter,
        computation_types.StructType([type_signature, type_signature]))

    self.assertEqual(
        test_utils.run_tensorflow(proto, operands), expected_result)

  @parameterized.named_parameters([
      ('non_callable_operator', 1, computation_types.TensorType(tf.int32)),
      ('none_type', tf.math.add, None),
      ('federated_type', tf.math.add, computation_types.at_server(tf.int32)),
      ('sequence_type', tf.math.add,
       computation_types.SequenceType(tf.int32)),
  ])
  def test_raises_type_error(self, operator, type_signature):
    with self.assertRaises(TypeError):
      tensorflow_computation_factory.create_binary_operator(
          operator, type_signature)
def test_succeeds_under_tuple(self):
  """`<T1,T1>` is matched when both struct elements have the same concrete type."""
  generic = self.func_with_param(
      computation_types.StructType([computation_types.AbstractType('T1')] * 2))
  int32_type = computation_types.TensorType(tf.int32)
  concrete = self.func_with_param(
      computation_types.StructType([int32_type, int32_type]))
  type_analysis.check_concrete_instance_of(concrete, generic)
def test_fails_under_tuple_conflicting_concrete_types(self):
  """`<T1,T1>` cannot bind T1 to both int32 and float32."""
  generic = self.func_with_param(
      computation_types.StructType([computation_types.AbstractType('T1')] * 2))
  concrete = self.func_with_param(
      computation_types.StructType([
          computation_types.TensorType(tf.int32),
          computation_types.TensorType(tf.float32),
      ]))
  with self.assertRaises(type_analysis.MismatchedConcreteTypesError):
    type_analysis.check_concrete_instance_of(concrete, generic)
def test_transforms_unnamed_tuple_type(self):
  """Post-order transforms rewrite tensor leaves and report whether they mutated."""
  orig_type = computation_types.StructType([tf.int32, tf.float64])

  float_type, float_mutated = type_transformations.transform_type_postorder(
      orig_type, _convert_tensor_to_float)
  untouched_type, untouched_mutated = (
      type_transformations.transform_type_postorder(
          orig_type, _convert_abstract_type_to_tensor))

  self.assertEqual(float_type,
                   computation_types.StructType([tf.float32, tf.float32]))
  self.assertEqual(untouched_type, orig_type)
  self.assertTrue(float_mutated)
  self.assertFalse(untouched_mutated)
def test_generates_tf_with_block(self):
  """A block that binds a constant and applies an identity lambda compiles to TF."""
  x_ref = building_blocks.Reference(
      'x', computation_types.StructType([tf.int32, tf.float32]))
  identity = building_blocks.Lambda(x_ref.name, x_ref.type_signature, x_ref)

  zero = building_block_factory.create_tensorflow_constant(
      computation_types.StructType([tf.int32, tf.float32]), 0)
  z_ref = building_blocks.Reference('z', [tf.int32, tf.float32])

  block = building_blocks.Block([('z', zero)],
                                building_blocks.Call(identity, z_ref))
  self.assert_compiles_to_tensorflow(block)
def test_federated_aggregate_with_unknown_dimension(self):
  """Aggregates into an accumulator whose `samples` tensor has unknown size.

  Each client value is appended to an `int32[?]` tensor, so the result of the
  aggregation has type `<samples=int32[?]>@SERVER`.
  """
  Accumulator = collections.namedtuple('Accumulator', ['samples'])  # pylint: disable=invalid-name
  accumulator_type = computation_types.to_type(
      Accumulator(
          samples=computation_types.TensorType(dtype=tf.int32, shape=[None])))
  x = _mock_data_of_type(computation_types.at_clients(tf.int32))

  # The initial accumulator holds an empty (length-0) `samples` tensor.
  def initialize_fn():
    return Accumulator(samples=tf.zeros(shape=[0], dtype=tf.int32))

  initialize_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      initialize_fn, None)
  initialize = computation_impl.ConcreteComputation(
      initialize_proto, context_stack_impl.context_stack)
  zero = initialize()

  # The operator to use during the first stage appends one client element to
  # `samples`, growing the tensor by one.
  def _accumulate(arg):
    return Accumulator(
        samples=tf.concat(
            [arg[0].samples, tf.expand_dims(arg[1], axis=0)], axis=0))

  accumulate_type = computation_types.StructType(
      [accumulator_type, tf.int32])
  accumulate_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _accumulate, accumulate_type)
  accumulate = computation_impl.ConcreteComputation(
      accumulate_proto, context_stack_impl.context_stack)

  # The operator to use during the second stage concatenates the `samples` of
  # two partial accumulators.
  def _merge(arg):
    return Accumulator(
        samples=tf.concat([arg[0].samples, arg[1].samples], axis=0))

  merge_type = computation_types.StructType(
      [accumulator_type, accumulator_type])
  merge_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _merge, merge_type)
  merge = computation_impl.ConcreteComputation(
      merge_proto, context_stack_impl.context_stack)

  # The operator to use during the final stage returns the accumulator
  # unchanged (identity).
  report_proto, _ = tensorflow_computation_factory.create_identity(
      accumulator_type)
  report = computation_impl.ConcreteComputation(
      report_proto, context_stack_impl.context_stack)

  value = intrinsics.federated_aggregate(x, zero, accumulate, merge, report)
  self.assert_value(value, '<samples=int32[?]>@SERVER')
def test_str(self):
  """`str()` of struct types renders names, nesting and element order."""
  nested = computation_types.StructType([('x', tf.string), ('y', tf.bool)])
  cases = [
      (computation_types.StructType([tf.int32]), '<int32>'),
      (computation_types.StructType([('a', tf.int32)]), '<a=int32>'),
      (computation_types.StructType(('a', tf.int32)), '<a=int32>'),
      (computation_types.StructType([tf.int32, tf.bool]), '<int32,bool>'),
      (computation_types.StructType([('a', tf.int32), tf.float32]),
       '<a=int32,float32>'),
      (computation_types.StructType([('a', tf.int32), ('b', tf.float32)]),
       '<a=int32,b=float32>'),
      (computation_types.StructType([('a', tf.int32), ('b', nested)]),
       '<a=int32,b=<x=string,y=bool>>'),
  ]
  for struct_type, expected in cases:
    self.assertEqual(str(struct_type), expected)
def test_federated_aggregate_with_client_int(self):
  """Averages int32 client values through a `<total, count>` accumulator."""
  # The accumulator is a named tuple with two integers: `total`, the sum of
  # the elements seen so far, and `count`, the number of those elements.
  Accumulator = collections.namedtuple('Accumulator', 'total count')  # pylint: disable=invalid-name
  accumulator_type = computation_types.to_type(
      Accumulator(
          total=computation_types.TensorType(dtype=tf.int32),
          count=computation_types.TensorType(dtype=tf.int32)))
  x = _mock_data_of_type(computation_types.at_clients(tf.int32))
  zero = Accumulator(0, 0)

  def _to_computation(py_fn, parameter_type):
    # Wraps a Python function as a concrete TensorFlow computation.
    proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
        py_fn, parameter_type)
    return computation_impl.ConcreteComputation(
        proto, context_stack_impl.context_stack)

  # First stage: add one client element to the total and increment the count.
  def _accumulate(arg):
    acc, element = arg[0], arg[1]
    return Accumulator(acc.total + element, acc.count + 1)

  # Second stage: combine two partial accumulators field by field.
  def _merge(arg):
    left, right = arg[0], arg[1]
    return Accumulator(left.total + right.total, left.count + right.count)

  # Final stage: the ratio total / count, as float32.
  def _report(arg):
    return tf.cast(arg.total, tf.float32) / tf.cast(arg.count, tf.float32)

  accumulate = _to_computation(
      _accumulate, computation_types.StructType([accumulator_type, tf.int32]))
  merge = _to_computation(
      _merge,
      computation_types.StructType([accumulator_type, accumulator_type]))
  report = _to_computation(_report, accumulator_type)

  value = intrinsics.federated_aggregate(x, zero, accumulate, merge, report)
  self.assert_value(value, 'float32@SERVER')
def test_abstract_parameters_contravariant(self):
  """Structs with different names can all be bound to the same abstract `A`.

  The concrete type uses an unnamed struct, a struct named `bar` and a struct
  named `foo` in the positions that `(<A,(A -> A)> -> A)` binds to `A`.
  """

  def _int_struct(name):
    return computation_types.StructType([(name, tf.int32)])

  unnamed = _int_struct(None)
  concrete = computation_types.FunctionType(
      computation_types.StructType([
          unnamed,
          computation_types.FunctionType(_int_struct('bar'), unnamed),
      ]), _int_struct('foo'))

  abstract_a = computation_types.AbstractType('A')
  generic = computation_types.FunctionType(
      computation_types.StructType(
          [abstract_a,
           computation_types.FunctionType(abstract_a, abstract_a)]),
      abstract_a)

  type_analysis.check_concrete_instance_of(concrete, generic)
class CountTypesTest(parameterized.TestCase):
  """Tests for `type_analysis.count`."""

  # pyformat: disable
  @parameterized.named_parameters(
      ('one',
       computation_types.TensorType(tf.int32),
       lambda t: t.is_tensor(),
       1),
      ('three',
       computation_types.StructType([tf.int32] * 3),
       lambda t: t.is_tensor(),
       3),
      ('nested',
       computation_types.StructType([[tf.int32] * 3] * 3),
       lambda t: t.is_tensor(),
       9),
  )
  # pyformat: enable
  def test_returns_result(self, type_signature, predicate, expected_result):
    self.assertEqual(
        type_analysis.count(type_signature, predicate), expected_result)
class VisitPreorderTest(parameterized.TestCase):
  """Tests for `type_transformations.visit_preorder`."""

  # pyformat: disable
  @parameterized.named_parameters(
      ('abstract_type', computation_types.AbstractType('T'), 1),
      ('nested_function_type',
       computation_types.FunctionType(
           computation_types.FunctionType(
               computation_types.FunctionType(tf.int32, tf.int32),
               tf.int32),
           tf.int32),
       7),
      ('named_tuple_type',
       computation_types.StructType(
           [tf.int32, tf.bool, computation_types.SequenceType(tf.int32)]),
       5),
      ('placement_type', computation_types.PlacementType(), 1),
  )
  # pyformat: enable
  def test_preorder_call_count(self, type_signature, expected_count):
    visit_count = 0

    def _count_visit(given_type, arg):
      # Counts how many times the visitor is called and passes `arg` through.
      del given_type  # Unused.
      nonlocal visit_count
      visit_count += 1
      return arg

    type_transformations.visit_preorder(type_signature, _count_visit, None)
    self.assertEqual(visit_count, expected_count)
def test_inlines_structs(self):
  """`to_call_dominant` inlines struct-valued block locals into the result."""
  bb = building_blocks
  int_type = computation_types.to_type(tf.int32)
  single_struct = computation_types.StructType([int_type])
  double_struct = computation_types.StructType([single_struct])

  y_value = bb.Struct([bb.Reference('x', int_type)])
  z_value = bb.Struct([bb.Reference('y', single_struct)])
  before = bb.Lambda(
      'x', int_type,
      bb.Block([('y', y_value), ('z', z_value)],
               bb.Reference('z', double_struct)))

  after = transformations.to_call_dominant(before)

  expected = bb.Lambda(
      'x', int_type, bb.Struct([bb.Struct([bb.Reference('x', int_type)])]))
  self.assert_compact_representations_equal(after, expected)
def test_roundtrip_with_named_struct(self):
  """A dict with a server-placed entry round-trips to a named `Struct`."""
  struct_type = computation_types.StructType(
      [('a', computation_types.at_server(tf.int32))])
  self.assertRoundTripEqual(
      collections.OrderedDict(a=0), struct_type, structure.Struct([('a', 0)]))
def test_with_block(self):
  """A lambda whose block body computes `f(f(x))` evaluates to 12 for x=10."""
  ex = reference_resolving_executor.ReferenceResolvingExecutor(
      eager_tf_executor.EagerTFExecutor())
  f_type = computation_types.FunctionType(tf.int32, tf.int32)
  a = building_blocks.Reference(
      'a', computation_types.StructType([('f', f_type), ('x', tf.int32)]))

  locals_ = [
      ('f', building_blocks.Selection(a, name='f')),
      ('x', building_blocks.Selection(a, name='x')),
  ]
  f_of_x = building_blocks.Call(
      building_blocks.Reference('f', f_type),
      building_blocks.Reference('x', tf.int32))
  f_of_f_of_x = building_blocks.Call(
      building_blocks.Reference('f', f_type), f_of_x)
  comp = building_blocks.Lambda(a.name, a.type_signature,
                                building_blocks.Block(locals_, f_of_f_of_x))

  @tensorflow_computation.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  comp_value = asyncio.run(ex.create_value(comp.proto, comp.type_signature))
  fn_value = asyncio.run(ex.create_value(add_one))
  arg_value = asyncio.run(ex.create_value(10, tf.int32))
  struct_value = asyncio.run(
      ex.create_struct(structure.Struct([('f', fn_value), ('x', arg_value)])))
  call_value = asyncio.run(ex.create_call(comp_value, struct_value))
  result = asyncio.run(call_value.compute())
  self.assertEqual(result.numpy(), 12)
def _remove_struct_element_names_from_tff_type(type_spec):
  """Returns a copy of `type_spec` in which every struct element is unnamed.

  Args:
    type_spec: `None`, or a `computation_types.Type` that is a tensor, a
      (possibly nested) struct of tensors, or a function whose parameter and
      result are themselves such types.

  Returns:
    `None` if `type_spec` is `None`. Tensor types are returned unchanged.
    Otherwise, a type of the same structure whose struct elements all have
    the name `None`.

  Raises:
    TypeError: if `type_spec`, or any type nested inside it, is of a kind not
      listed above.
  """
  if type_spec is None:
    return None
  if isinstance(type_spec, computation_types.TensorType):
    return type_spec
  if isinstance(type_spec, computation_types.FunctionType):
    strip_names = _remove_struct_element_names_from_tff_type
    return computation_types.FunctionType(
        strip_names(type_spec.parameter), strip_names(type_spec.result))

  # Anything that is not a struct is rejected here with a TypeError.
  py_typecheck.check_type(type_spec, computation_types.StructType)
  unnamed_elements = []
  for _, element in structure.iter_elements(type_spec):
    unnamed_elements.append(
        (None, _remove_struct_element_names_from_tff_type(element)))
  return computation_types.StructType(unnamed_elements)
def create_whimsy_computation_tensorflow_add():
  """Builds a TensorFlow computation that adds two float32 values.

  Returns:
    A `(computation_proto, type_signature)` pair. `type_signature` is
    `(<float32,float32> -> float32)`.
  """
  type_spec = tf.float32
  with tf.Graph().as_default() as graph:
    x_value, x_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', type_spec, graph)
    y_value, y_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', type_spec, graph)
    sum_value = tf.add(x_value, y_value)
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        sum_value, graph)

  type_signature = computation_types.FunctionType(
      computation_types.StructType([type_spec, type_spec]), result_type)

  # The two stamped parameters are bound positionally as a struct.
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(element=[x_binding, y_binding]))
  tensorflow_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  computation_proto = pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow_proto)
  return computation_proto, type_signature