def create_whimsy_block(comp, variable_name, variable_type=tf.int32):
  r"""Returns a block that binds a whimsy `Data` value and evaluates `comp`.

             Block
            /     \
  [x=data]         Comp

  This is not an identity block: the single local binds `variable_name` to a
  `Data` computation with URI `'data'`, and `comp` is used unchanged as the
  block's result.

  Args:
    comp: The computation to use as the result.
    variable_name: The name of the variable bound in the block's locals.
    variable_type: The type of the whimsy `Data` bound to the variable.
  """
  data = building_blocks.Data('data', variable_type)
  return building_blocks.Block([(variable_name, data)], comp)
def create_identity_block_with_dummy_data(variable_name,
                                          variable_type=tf.int32):
  r"""Builds an identity block whose bound value is a dummy `Data`.

       Block
      /     \
  [x=data]   Ref(x)

  Args:
    variable_name: The name bound in the block's locals.
    variable_type: The type of the dummy `Data` computation (URI `'data'`).

  Returns:
    The result of `create_identity_block` for `variable_name` and the dummy
    `Data` value.
  """
  return create_identity_block(variable_name,
                               building_blocks.Data('data', variable_type))
def create_dummy_called_federated_broadcast(value_type=tf.int32):
  r"""Returns a dummy called federated broadcast.

                  Call
                 /    \
  federated_broadcast  data

  Args:
    value_type: The member type of the value. The value is wrapped in a
      `FederatedType` placed at `SERVER` before it is broadcast.
  """
  federated_type = computation_types.FederatedType(value_type,
                                                   placements.SERVER)
  value = building_blocks.Data('data', federated_type)
  return building_block_factory.create_federated_broadcast(value)
def create_whimsy_called_federated_sum(value_type=tf.int32):
  r"""Builds a call to `federated_sum` over a whimsy clients-placed `Data`.

            Call
           /    \
  federated_sum  data

  Args:
    value_type: The member type of the `Data` value, which is placed at
      `CLIENTS`.
  """
  clients_data = building_blocks.Data(
      'data', computation_types.FederatedType(value_type, placements.CLIENTS))
  return building_block_factory.create_federated_sum(clients_data)
def create_dummy_called_federated_aggregate(accumulate_parameter_name,
                                            merge_parameter_name,
                                            report_parameter_name,
                                            value_type=tf.int32):
  r"""Builds a call to `federated_aggregate` whose operands are all dummies.

                    Call
                   /    \
  federated_aggregate    Tuple
                         |
                         [data, data, Lambda(x), Lambda(x), Lambda(x)]
                                      |          |          |
                                      data       data       data

  The zero is a `tf.float32` `Data`. Accumulate maps <float32, value_type> to
  float32, merge maps <float32, float32> to float32, and report maps float32
  to bool. Each lambda body is a constant `Data`.

  Args:
    accumulate_parameter_name: Parameter name of the accumulate lambda.
    merge_parameter_name: Parameter name of the merge lambda.
    report_parameter_name: Parameter name of the report lambda.
    value_type: The TFF type of the value to be aggregated, placed at CLIENTS.
  """
  clients_value = building_blocks.Data(
      'data', computation_types.FederatedType(value_type, placements.CLIENTS))
  zero = building_blocks.Data('data', tf.float32)

  # accumulate: <float32, value_type> -> float32
  accumulate = building_blocks.Lambda(
      accumulate_parameter_name,
      computation_types.NamedTupleType((tf.float32, value_type)),
      building_blocks.Data('data', tf.float32))
  # merge: <float32, float32> -> float32
  merge = building_blocks.Lambda(
      merge_parameter_name,
      computation_types.NamedTupleType((tf.float32, tf.float32)),
      building_blocks.Data('data', tf.float32))
  # report: float32 -> bool
  report = building_blocks.Lambda(report_parameter_name, tf.float32,
                                  building_blocks.Data('data', tf.bool))

  return building_block_factory.create_federated_aggregate(
      clients_value, zero, accumulate, merge, report)
def test_does_not_remove_called_lambda(self):
  # A directly-called identity lambda is expected to be left in place.
  identity_fn = building_block_test_utils.create_identity_function(
      'a', tf.int32)
  comp = building_blocks.Call(identity_fn,
                              building_blocks.Data('data', tf.int32))

  result, modified = tree_transformations.remove_mapped_or_applied_identity(
      comp)

  self.assertEqual(result.compact_representation(),
                   comp.compact_representation())
  self.assertEqual(result.compact_representation(), '(a -> a)(data)')
  self.assertEqual(result.type_signature, comp.type_signature)
  self.assertFalse(modified)
def test_single_level_block(self):
  # Three locals all named 'a'; every rebinding after the first is renamed.
  ref_a = building_blocks.Reference('a', tf.int32)
  bound_data = building_blocks.Data('data', tf.int32)
  block = building_blocks.Block(
      (('a', bound_data), ('a', ref_a), ('a', ref_a)), ref_a)

  result, modified = tree_transformations.uniquify_reference_names(block)

  # The input block itself must be unchanged by the transformation.
  self.assertEqual(block.compact_representation(),
                   '(let a=data,a=a,a=a in a)')
  self.assertEqual(result.compact_representation(),
                   '(let a=data,_var1=a,_var2=_var1 in _var2)')
  tree_analysis.check_has_unique_names(result)
  self.assertTrue(modified)
def test_returns_string_for_tuple_with_names(self):
  # A named two-element tuple rendered in compact, formatted and structural
  # forms.
  element = building_blocks.Data('data', tf.int32)
  comp = building_blocks.Tuple((('a', element), ('b', element)))

  self.assertEqual(comp.compact_representation(), '<a=data,b=data>')

  # pyformat: disable
  expected_formatted = (
      '<\n'
      ' a=data,\n'
      ' b=data\n'
      '>'
  )
  expected_structural = (
      'Tuple\n'
      '|\n'
      '[a=data, b=data]'
  )
  # pyformat: enable
  self.assertEqual(comp.formatted_representation(), expected_formatted)
  self.assertEqual(comp.structural_representation(), expected_structural)
def test_basic_functionality_of_data_class(self):
  # Checks accessors, repr, compact form and proto round-trip of `Data`.
  uri = '/tmp/mydata'
  x = building_blocks.Data(uri, computation_types.SequenceType(tf.int32))

  self.assertEqual(str(x.type_signature), 'int32*')
  self.assertEqual(x.uri, uri)
  self.assertEqual(
      repr(x), 'Data(\'/tmp/mydata\', SequenceType(TensorType(tf.int32)))')
  self.assertEqual(x.compact_representation(), uri)

  proto = x.proto
  self.assertEqual(
      type_serialization.deserialize_type(proto.type), x.type_signature)
  self.assertEqual(proto.WhichOneof('computation'), 'data')
  self.assertEqual(proto.data.uri, x.uri)
  self._serialize_deserialize_roundtrip_test(x)
def create_whimsy_called_sequence_map(parameter_name, parameter_type=tf.int32):
  r"""Builds a `sequence_map` call applying an identity fn to whimsy data.

           Call
          /    \
  sequence_map  data

  Args:
    parameter_name: The name of the identity function's parameter.
    parameter_type: The parameter type. The `Data` argument is a sequence of
      this type.
  """
  identity_fn = create_identity_function(parameter_name, parameter_type)
  sequence_data = building_blocks.Data(
      'data', computation_types.SequenceType(parameter_type))
  return building_block_factory.create_sequence_map(identity_fn, sequence_data)
def test_returns_string_for_call_with_arg(self):
  # A reference to an int32 -> int32 function called on a `Data` argument.
  fn = building_blocks.Reference(
      'a', computation_types.FunctionType(tf.int32, tf.int32))
  comp = building_blocks.Call(fn, building_blocks.Data('data', tf.int32))

  self.assertEqual(comp.compact_representation(), 'a(data)')
  self.assertEqual(comp.formatted_representation(), 'a(data)')
  # pyformat: disable
  expected_structural = (
      ' Call\n'
      ' / \\\n'
      'Ref(a) data'
  )
  # pyformat: enable
  self.assertEqual(comp.structural_representation(), expected_structural)
def test_returns_string_for_block(self):
  # A block with two `Data` locals whose result references the unbound 'c'.
  bound_value = building_blocks.Data('data', tf.int32)
  result_ref = building_blocks.Reference('c', tf.int32)
  comp = building_blocks.Block((('a', bound_value), ('b', bound_value)),
                               result_ref)

  self.assertEqual(comp.compact_representation(), '(let a=data,b=data in c)')
  # pyformat: disable
  expected_formatted = (
      '(let\n'
      ' a=data,\n'
      ' b=data\n'
      ' in c)'
  )
  expected_structural = (
      ' Block\n'
      ' / \\\n'
      '[a=data, b=data] Ref(c)'
  )
  # pyformat: enable
  self.assertEqual(comp.formatted_representation(), expected_formatted)
  self.assertEqual(comp.structural_representation(), expected_structural)
def test_nested_blocks(self):
  # Two identically-shaped blocks, one nested as the result of the other.
  ref_a = building_blocks.Reference('a', tf.int32)
  bound_data = building_blocks.Data('data', tf.int32)
  inner = building_blocks.Block([('a', bound_data), ('a', ref_a)], ref_a)
  outer = building_blocks.Block([('a', bound_data), ('a', ref_a)], inner)

  result, modified = tree_transformations.uniquify_reference_names(outer)

  self.assertEqual(outer.compact_representation(),
                   '(let a=data,a=a in (let a=data,a=a in a))')
  self.assertEqual(
      result.compact_representation(),
      '(let a=data,_var1=a in (let _var2=data,_var3=_var2 in _var3))')
  tree_analysis.check_has_unique_names(result)
  self.assertTrue(modified)
def data(uri: str, type_spec: computation_types.Type):
  """Constructs a TFF `data` computation with the given URI and TFF type.

  Args:
    uri: The URI of the data, which must be a `str`.
    type_spec: The TFF type of this data, as an instance of `tff.Type`. It is
      passed through `computation_types.to_type` before use.

  Returns:
    A representation of the data with the given URI and TFF type in the body
    of a federated computation.

  Raises:
    TypeError: If the arguments are not of the types specified above.
  """
  py_typecheck.check_type(uri, str)
  normalized_type = computation_types.to_type(type_spec)
  data_block = building_blocks.Data(uri, normalized_type)
  return value_impl.to_value(data_block, normalized_type)
def test_removes_chained_federated_maps(self):
  # Two chained identity federated_maps should both be stripped to the data.
  identity_fn = building_block_test_utils.create_identity_function(
      'a', tf.int32)
  clients_data = building_blocks.Data(
      'data', computation_types.FederatedType(tf.int32, placements.CLIENTS))
  comp = _create_chained_whimsy_federated_maps([identity_fn, identity_fn],
                                               clients_data)

  result, modified = tree_transformations.remove_mapped_or_applied_identity(
      comp)

  self.assertEqual(
      comp.compact_representation(),
      'federated_map(<(a -> a),federated_map(<(a -> a),data>)>)')
  self.assertEqual(result.compact_representation(), 'data')
  self.assertEqual(result.type_signature, comp.type_signature)
  self.assertTrue(modified)
def test_with_structure_replacing_federated_map(self):
  # `arg` is a <fn, int> pair; the block result calls arg[0] on arg[1].
  fn_type = computation_types.FunctionType(tf.int32, tf.int32)
  arg_ref = building_blocks.Reference('arg', [
      fn_type,
      tf.int32,
  ])
  called_fn = building_blocks.Call(
      building_blocks.Selection(arg_ref, index=0),
      building_blocks.Selection(arg_ref, index=1))

  # Bind `arg` to a concrete <identity lambda, data> tuple.
  identity = building_blocks.Lambda('x', tf.int32,
                                    building_blocks.Reference('x', tf.int32))
  bound_tuple = building_blocks.Tuple(
      [identity, building_blocks.Data('a', tf.int32)])
  comp = building_blocks.Block([('arg', bound_tuple)], called_fn)

  result, modified = compiler_transformations.remove_lambdas_and_blocks(comp)

  self.assertTrue(modified)
  self.assertNoLambdasOrBlocks(result)
def test_removes_federated_map_with_named_result(self):
  # An identity federated_map over a named-tuple member type is removed.
  parameter_type = [('a', tf.int32), ('b', tf.int32)]
  identity_fn = building_block_test_utils.create_identity_function(
      'c', parameter_type)
  clients_data = building_blocks.Data(
      'data',
      computation_types.FederatedType(parameter_type, placements.CLIENTS))
  comp = building_block_factory.create_federated_map(identity_fn,
                                                     clients_data)

  result, modified = tree_transformations.remove_mapped_or_applied_identity(
      comp)

  self.assertEqual(comp.compact_representation(),
                   'federated_map(<(c -> c),data>)')
  self.assertEqual(result.compact_representation(), 'data')
  self.assertEqual(result.type_signature, comp.type_signature)
  self.assertTrue(modified)
def test_nested_lambdas(self):
  # (b -> b)((a -> a)(data)): names are already distinct, so none change.
  bound_data = building_blocks.Data('data', tf.int32)
  ref_a = building_blocks.Reference('a', bound_data.type_signature)
  inner_call = building_blocks.Call(
      building_blocks.Lambda('a', ref_a.type_signature, ref_a), bound_data)
  ref_b = building_blocks.Reference('b', inner_call.type_signature)
  outer_call = building_blocks.Call(
      building_blocks.Lambda('b', ref_b.type_signature, ref_b), inner_call)

  result, modified = tree_transformations.uniquify_reference_names(outer_call)

  self.assertEqual(result.compact_representation(),
                   '(b -> b)((a -> a)(data))')
  tree_analysis.check_has_unique_names(result)
  self.assertFalse(modified)
def test_with_higher_level_lambdas(self):
  # Currently skipped (see b/146904968); the rest of the body never runs.
  self.skipTest('b/146904968')
  arg_data = building_blocks.Data('a', tf.int32)
  z_ref = building_blocks.Reference('z', tf.int32)
  # z -> <z, x>, where `x` is free here and bound by `middle_lambda`.
  lowest_lambda = building_blocks.Lambda(
      'z', tf.int32,
      building_blocks.Tuple([z_ref, building_blocks.Reference('x', tf.int32)]))
  middle_lambda = building_blocks.Lambda('x', tf.int32, lowest_lambda)

  # A lambda taking a function-typed `x` and applying it to `arg_data`.
  fn_ref = building_blocks.Reference('x', middle_lambda.type_signature)
  applied = building_blocks.Call(fn_ref, arg_data)
  left_lambda = building_blocks.Lambda('x', middle_lambda.type_signature,
                                       applied)

  higher_call = building_blocks.Call(left_lambda, middle_lambda)
  high_call = building_blocks.Call(higher_call, arg_data)

  result, modified = compiler_transformations.remove_lambdas_and_blocks(
      high_call)

  self.assertTrue(modified)
  self.assertNoLambdasOrBlocks(result)
def test_returns_string_for_struct_with_no_names(self):
  # An unnamed two-element struct in compact, formatted and structural forms.
  element = building_blocks.Data('data', tf.int32)
  comp = building_blocks.Struct([element, element])

  self.assertEqual(comp.compact_representation(), '<data,data>')
  # pyformat: disable
  expected_formatted = (
      '<\n'
      ' data,\n'
      ' data\n'
      '>'
  )
  expected_structural = (
      'Struct\n'
      '|\n'
      '[data, data]'
  )
  # pyformat: enable
  self.assertEqual(comp.formatted_representation(), expected_formatted)
  self.assertEqual(comp.structural_representation(), expected_structural)
def test_blocks_nested_inside_of_locals(self):
  # Blocks nested three deep inside the *locals* (not the results) of blocks.
  result_data = building_blocks.Data('data', tf.int32)

  def nest_three_levels(innermost_value):
    # Wraps the value as the sole 'a' local of a block, three times over.
    value = innermost_value
    for _ in range(3):
      value = building_blocks.Block([('a', value)], result_data)
    return value

  higher_block = nest_three_levels(result_data)
  higher_block_with_ref = nest_three_levels(
      building_blocks.Reference('a', tf.int32))

  comp = building_blocks.Block(
      [('a', higher_block), ('a', higher_block_with_ref)],
      higher_block_with_ref)

  transformed_comp = self.assert_transforms(
      comp, 'uniquify_names_blocks_nested_inside_of_locals.expected')
  tree_analysis.check_has_unique_names(transformed_comp)
def create_dummy_called_federated_map(parameter_name, parameter_type=tf.int32):
  r"""Builds a `federated_map` call applying an identity fn to dummy data.

            Call
           /    \
  federated_map  Tuple
                 |
                 [Lambda(x), data]
                  |
                  Ref(x)

  Args:
    parameter_name: The name of the identity function's parameter.
    parameter_type: The parameter type; the `Data` argument is this type
      placed at `CLIENTS`.
  """
  identity_fn = create_identity_function(parameter_name, parameter_type)
  clients_data = building_blocks.Data(
      'data',
      computation_types.FederatedType(parameter_type, placements.CLIENTS))
  return building_block_factory.create_federated_map(identity_fn,
                                                     clients_data)
def test_replaces_chained_intrinsics(self):
  # Every call to the 'intrinsic' URI in a two-deep chain is replaced by an
  # identity callable.
  fn = test_utils.create_lambda_to_dummy_called_intrinsic(parameter_name='a')
  comp = test_utils.create_chained_calls(
      [fn, fn], building_blocks.Data('data', tf.int32))
  uri = 'intrinsic'
  body = lambda x: x

  result, modified = value_transformations.replace_intrinsics_with_callable(
      comp, uri, body, context_stack_impl.context_stack)

  self.assertEqual(comp.compact_representation(),
                   '(a -> intrinsic(a))((a -> intrinsic(a))(data))')
  self.assertEqual(
      result.compact_representation(),
      '(a -> (intrinsic_arg -> intrinsic_arg)(a))((a -> (intrinsic_arg -> intrinsic_arg)(a))(data))'
  )
  self.assertEqual(result.type_signature, comp.type_signature)
  self.assertTrue(modified)
def create_whimsy_called_federated_apply(parameter_name,
                                         parameter_type=tf.int32):
  r"""Builds a `federated_apply` call applying an identity fn to whimsy data.

             Call
            /    \
  federated_apply  Tuple
                   |
                   [Lambda(x), data]
                    |
                    Ref(x)

  Args:
    parameter_name: The name of the identity function's parameter.
    parameter_type: The parameter type; the `Data` argument is this type
      placed at `SERVER`.
  """
  identity_fn = create_identity_function(parameter_name, parameter_type)
  server_data = building_blocks.Data(
      'data',
      computation_types.FederatedType(parameter_type, placements.SERVER))
  return building_block_factory.create_federated_apply(identity_fn,
                                                       server_data)
def test_with_multiple_reference_indirection(self):
  # 'a' is an identity lambda, 'b' wraps a reference to 'a' in a 1-tuple, and
  # 'c' selects it back out; the block result calls 'c' on a `Data`.
  identity_lam = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  tuple_wrapping_ref = building_blocks.Tuple(
      [building_blocks.Reference('a', identity_lam.type_signature)])
  selection_from_ref = building_blocks.Selection(
      building_blocks.Reference('b', tuple_wrapping_ref.type_signature),
      index=0)
  call_arg = building_blocks.Data('a', tf.int32)
  call_through_c = building_blocks.Call(
      building_blocks.Reference('c', selection_from_ref.type_signature),
      call_arg)

  local_bindings = [
      ('a', identity_lam),
      ('b', tuple_wrapping_ref),
      ('c', selection_from_ref),
  ]
  comp = building_blocks.Block(local_bindings, call_through_c)

  result, modified = compiler_transformations.remove_lambdas_and_blocks(comp)

  self.assertTrue(modified)
  self.assertNoLambdasOrBlocks(result)
def test_handles_federated_broadcasts_nested_in_tuple(self):
  # A broadcast sits inside a tuple next to server data; a second broadcast
  # is applied to element 0 of that tuple.
  first_broadcast = (
      compiler_test_utils.create_dummy_called_federated_broadcast())
  server_int_type = computation_types.FederatedType(
      computation_types.TensorType(tf.int32), placements.SERVER)
  packed_broadcast = building_blocks.Tuple(
      [building_blocks.Data('a', server_int_type), first_broadcast])
  second_broadcast = building_block_factory.create_federated_broadcast(
      building_blocks.Selection(packed_broadcast, index=0))
  comp = building_blocks.Lambda('a', tf.int32, second_broadcast)
  uri = [intrinsic_defs.FEDERATED_BROADCAST.uri]

  before, after = transformations.force_align_and_split_by_intrinsics(
      comp, uri)

  # Both halves must be lambdas with no remaining called broadcast.
  for part in (before, after):
    self.assertIsInstance(part, building_blocks.Lambda)
    self.assertFalse(tree_analysis.contains_called_intrinsic(part, uri))
def test_block_lambda_block_lambda(self):
  # Alternating blocks and lambdas that all rebind the same name 'a'.
  ref_a = building_blocks.Reference('a', tf.int32)
  innermost_call = building_blocks.Call(
      building_blocks.Lambda('a', tf.int32, ref_a), ref_a)
  lower_block = building_blocks.Block([('a', ref_a), ('a', ref_a)],
                                      innermost_call)
  second_call = building_blocks.Call(
      building_blocks.Lambda('a', tf.int32, lower_block), ref_a)
  outer_data = building_blocks.Data('data', tf.int32)
  last_block = building_blocks.Block([('a', outer_data), ('a', ref_a)],
                                     second_call)

  result, modified = tree_transformations.uniquify_reference_names(last_block)

  self.assertEqual(
      last_block.compact_representation(),
      '(let a=data,a=a in (a -> (let a=a,a=a in (a -> a)(a)))(a))')
  self.assertEqual(
      result.compact_representation(),
      '(let a=data,_var1=a in (_var2 -> (let _var3=_var2,_var4=_var3 in (_var5 -> _var5)(_var4)))(_var1))'
  )
  tree_analysis.check_has_unique_names(result)
  self.assertTrue(modified)
def test_compiles_lambda_under_federated_comp_to_tf(self):
  # An identity lambda under federated_apply should be compiled to TF while
  # the surrounding intrinsic call and its data argument are preserved.
  ref_to_x = building_blocks.Reference(
      'x', computation_types.StructType([tf.int32, tf.float32]))
  identity_lambda = building_blocks.Lambda(ref_to_x.name,
                                           ref_to_x.type_signature, ref_to_x)
  server_struct_type = computation_types.FederatedType(
      computation_types.StructType([tf.int32, tf.float32]), placements.SERVER)
  federated_data = building_blocks.Data('a', server_struct_type)
  applied = building_block_factory.create_federated_apply(
      identity_lambda, federated_data)

  transformed = compiler.compile_local_subcomputations_to_tensorflow(applied)

  self.assertIsInstance(transformed, building_blocks.Call)
  self.assertIsInstance(transformed.function, building_blocks.Intrinsic)
  self.assertIsInstance(transformed.argument[0],
                        building_blocks.CompiledComputation)
  self.assertEqual(transformed.argument[1], federated_data)
  self.assertEqual(transformed.argument[0].type_signature,
                   identity_lambda.type_signature)
def test_returns_string_for_comp_with_right_overhang(self):
  # (a -> <a,data,data,data,data>[0])(data): a wide tuple inside a lambda.
  ref = building_blocks.Reference('a', tf.int32)
  element = building_blocks.Data('data', tf.int32)
  selection = building_blocks.Selection(
      building_blocks.Tuple([ref] + [element] * 4), index=0)
  comp = building_blocks.Call(
      building_blocks.Lambda(ref.name, ref.type_signature, selection), element)

  # pyformat: disable
  expected_compact = '(a -> <a,data,data,data,data>[0])(data)'
  expected_formatted = (
      '(a -> <\n'
      ' a,\n'
      ' data,\n'
      ' data,\n'
      ' data,\n'
      ' data\n'
      '>[0])(data)'
  )
  expected_structural = (
      ' Call\n'
      ' / \\\n'
      'Lambda(a) data\n'
      '|\n'
      'Sel(0)\n'
      '|\n'
      'Tuple\n'
      '|\n'
      '[Ref(a), data, data, data, data]'
  )
  # pyformat: enable
  self.assertEqual(comp.compact_representation(), expected_compact)
  self.assertEqual(comp.formatted_representation(), expected_formatted)
  self.assertEqual(comp.structural_representation(), expected_structural)
def test_ok_lambda_binding_of_new_variable(self):
  # A lambda binding 'y' under a block binding 'x' is expected to pass the
  # unique-names check without raising.
  y_identity = building_blocks.Lambda(
      'y', tf.int32, building_blocks.Reference('y', tf.int32))
  comp = building_blocks.Block([('x', building_blocks.Data('x', tf.int32))],
                               y_identity)
  tree_analysis.check_has_unique_names(comp)