def test_nested_lambda_block_overwrite_scope_snapshot(self):
  """Snapshot has one entry per binder when `x` is rebound at every level."""
  # Innermost binder: (x -> x), called on a free `x`.
  identity_body = computation_building_blocks.Reference('x', tf.int32)
  identity_fn = computation_building_blocks.Lambda('x', tf.int32, identity_body)
  identity_arg = computation_building_blocks.Reference('x', tf.int32)
  identity_call = computation_building_blocks.Call(identity_fn, identity_arg)

  # Middle binders: (block_in -> (let x=block_in in ...)), called on `x`.
  block_in_ref = computation_building_blocks.Reference('block_in', tf.int32)
  middle_block = computation_building_blocks.Block([('x', block_in_ref)],
                                                   identity_call)
  block_in_fn = computation_building_blocks.Lambda('block_in', tf.int32,
                                                   middle_block)
  outer_arg = computation_building_blocks.Reference('x', tf.int32)
  outer_call = computation_building_blocks.Call(block_in_fn, outer_arg)

  # Outermost binder: (let x=test_data in ...).
  test_data = computation_building_blocks.Data('test_data', tf.int32)
  outer_block = computation_building_blocks.Block([('x', test_data)],
                                                  outer_call)

  snapshot = transformations.scope_count_snapshot(outer_block)

  self.assertEqual(
      str(outer_block),
      '(let x=test_data in (block_in -> (let x=block_in in (x -> x)(x)))(x))'
  )
  self.assertLen(snapshot, 4)
  expected_counts = (
      (identity_fn, {'x': 1}),
      (middle_block, {'x': 1}),
      (block_in_fn, {'block_in': 1}),
      (outer_block, {'x': 1}),
  )
  for binder, counts in expected_counts:
    self.assertEqual(snapshot[str(binder)], counts)
def test_simple_block_snapshot(self):
  """A block's snapshot entry counts only the names that block binds."""
  used1_ref = computation_building_blocks.Reference('used1', tf.int32)
  used2_ref = computation_building_blocks.Reference('used2', tf.int32)
  x_ref = computation_building_blocks.Reference('x', used1_ref.type_signature)
  inner_block = computation_building_blocks.Block([('x', used1_ref)], x_ref)
  outer_block = computation_building_blocks.Block([('used1', used2_ref)],
                                                  inner_block)
  self.assertEqual(
      str(outer_block), '(let used1=used2 in (let x=used1 in x))')

  snapshot = transformations.scope_count_snapshot(outer_block)
  self.assertEqual(snapshot[str(inner_block)]['x'], 1)
  self.assertEqual(snapshot[str(outer_block)]['used1'], 1)
  # `x` is bound only by the inner block, so the outer entry lacks it.
  self.assertIsNone(snapshot[str(outer_block)].get('x'))
def test_conflicting_name_resolved_inlining(self):
  """An outer `x` shadowed by an inner `x` is dropped, not substituted."""
  shadowed_value = computation_building_blocks.Reference('redherring',
                                                         tf.int32)
  used_value = computation_building_blocks.Reference('used', tf.int32)
  inner_result = computation_building_blocks.Reference(
      'x', used_value.type_signature)
  inner_block = computation_building_blocks.Block([('x', used_value)],
                                                  inner_result)
  outer_block = computation_building_blocks.Block([('x', shadowed_value)],
                                                  inner_block)
  self.assertEqual(
      str(outer_block), '(let x=redherring in (let x=used in x))')

  inlined = transformations.inline_blocks_with_n_referenced_locals(
      outer_block)
  self.assertEqual(str(inlined), '(let in (let in used))')
def test_with_block(self):
  """A lambda over a block of selections computes `f(f(x))`: 10 -> 12."""
  executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  run = asyncio.get_event_loop().run_until_complete
  fn_type = computation_types.FunctionType(tf.int32, tf.int32)
  param = computation_building_blocks.Reference(
      'a',
      computation_types.NamedTupleType([('f', fn_type), ('x', tf.int32)]))

  # Body: (let f=a.f, x=a.x in f(f(x)))
  inner_call = computation_building_blocks.Call(
      computation_building_blocks.Reference('f', fn_type),
      computation_building_blocks.Reference('x', tf.int32))
  outer_call = computation_building_blocks.Call(
      computation_building_blocks.Reference('f', fn_type), inner_call)
  body = computation_building_blocks.Block(
      [('f', computation_building_blocks.Selection(param, name='f')),
       ('x', computation_building_blocks.Selection(param, name='x'))],
      outer_call)
  comp = computation_building_blocks.Lambda(param.name, param.type_signature,
                                            body)

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  fn_value = run(executor.create_value(comp.proto, comp.type_signature))
  add_one_value = run(executor.create_value(add_one))
  ten_value = run(executor.create_value(10, tf.int32))
  arg_value = run(
      executor.create_tuple(
          anonymous_tuple.AnonymousTuple([('f', add_one_value),
                                          ('x', ten_value)])))
  call_value = run(executor.create_call(fn_value, arg_value))
  result = run(call_value.compute())
  self.assertEqual(result.numpy(), 12)
def test_execute_with_block(self):
  """A block rebinding `x` four times evaluates to 10 + 1 + 1 + 1 == 13."""

  def _as_building_block(fn):
    # Round-trips a computation through its proto into a building block.
    return computation_building_blocks.ComputationBuildingBlock.from_proto(
        computation_impl.ComputationImpl.get_proto(fn))

  add_one = _as_building_block(
      computations.tf_computation(lambda x: x + 1, tf.int32))
  make_10 = _as_building_block(
      computations.tf_computation(lambda: tf.constant(10)))

  initial_binding = ('x', computation_building_blocks.Call(make_10))
  increments = [
      ('x',
       computation_building_blocks.Call(
           add_one, computation_building_blocks.Reference('x', tf.int32)))
      for _ in range(3)
  ]
  make_13 = computation_building_blocks.Block(
      [initial_binding] + increments,
      computation_building_blocks.Reference('x', tf.int32))

  make_13_computation = computation_impl.ComputationImpl(
      make_13.proto, context_stack_impl.context_stack)
  self.assertEqual(make_13_computation(), 13)
def test_basic_functionality_of_block_class(self):
  """Block exposes locals/result, renders repr/tff_repr, and serializes."""
  pair_type = (tf.int32, tf.int32)
  first_of_x = computation_building_blocks.Selection(
      computation_building_blocks.Reference('x', pair_type), index=0)
  block = computation_building_blocks.Block(
      [('x', computation_building_blocks.Reference('arg', pair_type)),
       ('y', first_of_x)],
      computation_building_blocks.Reference('y', tf.int32))

  self.assertEqual(str(block.type_signature), 'int32')
  self.assertEqual([(name, value.tff_repr) for name, value in block.locals],
                   [('x', 'arg'), ('y', 'x[0]')])
  self.assertEqual(block.result.tff_repr, 'y')
  self.assertEqual(
      repr(block), 'Block([(\'x\', Reference(\'arg\', '
      'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)]))), '
      '(\'y\', Selection(Reference(\'x\', '
      'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)])), '
      'index=0))], '
      'Reference(\'y\', TensorType(tf.int32)))')
  self.assertEqual(block.tff_repr, '(let x=arg,y=x[0] in y)')

  block_proto = block.proto
  self.assertEqual(
      type_serialization.deserialize_type(block_proto.type),
      block.type_signature)
  self.assertEqual(block_proto.WhichOneof('computation'), 'block')
  self.assertEqual(str(block_proto.block.result), str(block.result.proto))
  for position, local_proto in enumerate(block_proto.block.local):
    expected_name, expected_value = block.locals[position]
    self.assertEqual(local_proto.name, expected_name)
    self.assertEqual(str(local_proto.value), str(expected_value.proto))
  self._serialize_deserialize_roundtrip_test(block)
def _traverse_block(comp, transform, context_tree, identifier_seq):
  """Helper function holding traversal logic for block nodes.

  Transforms each local value of `comp` in declaration order. Each new binding
  is registered in `context_tree` right after its value is transformed, before
  the next local is visited:
  - the first binding uses `MutationMode.CHILD` with an id drawn from
    `identifier_seq`;
  - later bindings use `MutationMode.SIBLING`.

  The result is then transformed, and `transform` is applied to the rebuilt
  Block. If any binding was registered, `context_tree` is moved back to its
  parent context.

  Args:
    comp: The block whose locals and result are traversed.
    transform: Callable invoked as `transform(block, context_tree)` on the
      rebuilt block.
    context_tree: Symbol-binding tree updated as bindings are encountered.
    identifier_seq: Iterator supplying the id for this block's context. One
      id is consumed even when the block has no locals.

  Returns:
    The value returned by `transform` for the rebuilt block.
  """
  comp_id = six.next(identifier_seq)
  local_bindings = comp.locals
  transformed_locals = []
  for position, (local_name, local_value) in enumerate(local_bindings):
    new_value = _transform_postorder_with_symbol_bindings_switch(
        local_value, transform, context_tree, identifier_seq)
    transformed_locals.append((local_name, new_value))
    if position == 0:
      # The first local opens a new child context for this block.
      context_tree.ingest_variable_binding(
          name=local_name,
          value=new_value,
          mode=MutationMode.CHILD,
          comp_id=comp_id)
    else:
      context_tree.ingest_variable_binding(
          name=local_name, value=new_value, mode=MutationMode.SIBLING)
  transformed_result = _transform_postorder_with_symbol_bindings_switch(
      comp.result, transform, context_tree, identifier_seq)
  transformed_comp = transform(
      computation_building_blocks.Block(transformed_locals,
                                        transformed_result), context_tree)
  if local_bindings:
    # Only step back out if a child context was opened above.
    context_tree.move_to_parent_context()
  return transformed_comp
def _transform(comp):
  """Rewrites `comp` as a Block when `_should_transform(comp)` holds.

  `comp.function` is expected to expose `parameter_name` and `result`
  (presumably a Lambda). The call `fn(arg)` becomes
  `(let <parameter_name>=arg in <fn.result>)`.

  Returns:
    A `(computation, modified)` pair. `modified` is False, and `comp` is
    returned unchanged, when `_should_transform(comp)` is falsy.
  """
  if _should_transform(comp):
    function = comp.function
    binding = (function.parameter_name, comp.argument)
    return computation_building_blocks.Block([binding], function.result), True
  return comp, False
def __call__(self, comp):
  """Counts references to locals under Block and performs inlining.

  If `comp` is a `computation_building_blocks.Block`, this selects the locals
  to inline from the snapshot of the AST taken before any transformation ran
  (`self.counts`, keyed by `self.initial_comp_names[self.idx]`). A local is
  selected when its reference count is at most `self.inlining_threshold`.
  References to the selected locals in the block's result are then replaced
  by their bound values. The declarations of the selected locals are removed.

  Args:
    comp: The `computation_building_blocks.ComputationBuildingBlock` to be
      checked for the possibility of inlining.

  Returns:
    `comp` unchanged if it is not a Block. Otherwise a new Block in which the
    selected locals are substituted into the result and their declarations
    are dropped.
  """
  # `idx` advances for every visited node, Block or not, so that it stays
  # aligned with `initial_comp_names` (presumably built in the same
  # traversal order).
  self.idx += 1
  if not isinstance(comp, computation_building_blocks.Block):
    return comp

  bound_dict = self.counts[self.initial_comp_names[self.idx]]
  values_to_replace = {
      name for name, count in bound_dict.items()
      if count <= self.inlining_threshold
  }
  names_and_values = {
      name: value
      for name, value in comp.locals
      if name in values_to_replace
  }

  def _execute_inlining_from_bound_dict(inner_comp):
    """Replaces references to selected locals with their bound values.

    Args:
      inner_comp: The `computation_building_blocks.ComputationBuildingBlock`
        to potentially inline.

    Returns:
      The local's bound value if `inner_comp` is a
      `computation_building_blocks.Reference` whose name is one of the locals
      selected for inlining. Otherwise `inner_comp` unchanged.
    """
    # Test membership rather than truthiness. A bound value that happens to
    # be falsy (e.g. an empty building block defining `__len__`) would
    # otherwise be skipped while its declaration is still dropped below,
    # leaving a dangling reference.
    if (isinstance(inner_comp, computation_building_blocks.Reference) and
        inner_comp.name in names_and_values):
      replacement = names_and_values[inner_comp.name]
      py_typecheck.check_type(
          replacement, computation_building_blocks.ComputationBuildingBlock)
      return replacement
    return inner_comp

  # NOTE(review): substitution is applied to `comp.result` only, not to the
  # values of the remaining locals. Confirm the count snapshot never selects
  # a local that a sibling local's value still references.
  remaining_locals = [(name, value)
                      for name, value in comp.locals
                      if name not in values_to_replace]
  return computation_building_blocks.Block(
      remaining_locals,
      transform_postorder(comp.result, _execute_inlining_from_bound_dict))
def test_conflicting_nested_name_inlining(self):
  """An inner rebinding of `x` inlines against its own binding."""

  def _int_ref(name):
    return computation_building_blocks.Reference(name, tf.int32)

  inner_block = computation_building_blocks.Block([('x', _int_ref('y'))],
                                                  _int_ref('x'))
  mediate_tuple = computation_building_blocks.Tuple(
      [_int_ref('x'), inner_block])
  outer_block = computation_building_blocks.Block(
      [('x', _int_ref('used')), ('y', _int_ref('used1'))], mediate_tuple)
  self.assertEqual(
      str(outer_block), '(let x=used,y=used1 in <x,(let x=y in x)>)')

  inlined = transformations.inline_blocks_with_n_referenced_locals(
      outer_block)
  self.assertEqual(str(inlined), '(let in <used,(let in used1)>)')
def test_propogates_dependence_up_through_block_locals(self):
  """A block whose local binds the dummy intrinsic is reported as dependent."""
  bound_intrinsic = computation_building_blocks.Intrinsic(
      'dummy_intrinsic', tf.int32)
  block = computation_building_blocks.Block(
      [('x', bound_intrinsic)],
      computation_building_blocks.Reference('int', tf.int32))

  dependent_nodes = tree_analysis.extract_nodes_consuming(
      block, dummy_intrinsic_predicate)
  self.assertIn(block, dependent_nodes)
def test_scope_snapshot_block_overwrite(self):
  """Snapshot entries exist per binding block; the plain tuple gets none."""

  def _int_ref(name):
    return computation_building_blocks.Reference(name, tf.int32)

  inner_block = computation_building_blocks.Block([('x', _int_ref('y'))],
                                                  _int_ref('x'))
  mediate_tuple = computation_building_blocks.Tuple(
      [_int_ref('x'), inner_block])
  outer_block = computation_building_blocks.Block(
      [('x', _int_ref('used')), ('y', _int_ref('used1'))], mediate_tuple)
  self.assertEqual(
      str(outer_block), '(let x=used,y=used1 in <x,(let x=y in x)>)')

  snapshot = transformations.scope_count_snapshot(outer_block)
  self.assertEqual(snapshot[str(inner_block)], {'x': 1})
  self.assertEqual(snapshot[str(outer_block)], {'x': 1, 'y': 1})
  self.assertIsNone(snapshot.get(str(mediate_tuple)))
def _create_chain_zipped_values(value):
  r"""Creates a chain of called federated zip with two values.

                Block--------
               /             \
  [value=Tuple]               Call
        |                    /    \
  [Comp1,           Intrinsic      Tuple
   Comp2,                          |
   ...]                    [Call,       Sel(n)]
                           /    \             \
                  Intrinsic      Tuple         Ref(value)
                                 |
                         [Sel(0),       Sel(1)]
                                \             \
                                 Ref(value)    Ref(value)

  NOTE: This function is intended to be used in conjunction with
  `_create_fn_to_append_chain_zipped_values` and will drop the tuple names.
  The names will be added back to the resulting computation when the zipped
  values are mapped to a function that flattens the chain. This nested
  zip -> flatten structure must be used because the length of a named tuple
  type in the TFF type system is part of the type proper. A named tuple type
  of length 2 is a different type from one of length 3; they are not simply
  items with the same type and different values, as they would be if you were
  thinking of them as Python `list`s. It may be better to think of named
  tuple types in TFF as more like `struct`s.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing
      at least two elements.

  Returns:
    A `computation_building_blocks.Block` that binds `value` to the local
    name `value` and whose result is the chain of zip calls.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` does not contain at least two elements.
  """
  py_typecheck.check_type(
      value, computation_building_blocks.ComputationBuildingBlock)
  named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
  length = len(named_type_signatures)
  if length < 2:
    # Report the element count, which is what the message describes.
    raise ValueError(
        'Expected a value with at least two elements, received {} elements.'
        .format(length))
  ref = computation_building_blocks.Reference('value', value.type_signature)
  symbols = ((ref.name, value),)
  result = computation_building_blocks.Selection(ref, index=0)
  for index in range(1, length):
    # Zip the chain built so far with the next element.
    sel = computation_building_blocks.Selection(ref, index=index)
    values = computation_building_blocks.Tuple((result, sel))
    result = _create_zip_two_values(values)
  return computation_building_blocks.Block(symbols, result)
def test_simple_block_inlining(self):
  """A local referenced once is substituted and its declaration dropped."""
  test_data = computation_building_blocks.Data('test_data', tf.int32)
  test_x_ref = computation_building_blocks.Reference(
      'test_x', test_data.type_signature)
  simple_block = computation_building_blocks.Block([('test_x', test_data)],
                                                   test_x_ref)
  self.assertEqual(str(simple_block), '(let test_x=test_data in test_x)')

  inlined = transformations.inline_blocks_with_n_referenced_locals(
      simple_block)
  self.assertEqual(str(inlined), '(let in test_data)')
def create_federated_unzip(value):
  r"""Creates a tuple of called federated maps or applies.

            Block
           /     \
  [value=Comp]     Tuple
                   |
                   [Call,                      Call, ...]
                   /    \                     /    \
          Intrinsic      Tuple       Intrinsic      Tuple
                         |                          |
          [Lambda(arg), Ref(value)]     [Lambda(arg), Ref(value)]
                  \                             \
                   Sel(0)                        Sel(1)
                        \                             \
                         Ref(arg)                      Ref(arg)

  This function returns a tuple of federated values given a `value` with a
  federated tuple type signature. The i-th element maps a selector of member
  element i over `value`, and keeps that element's name.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` whose
      `type_signature` is a federated type with a
      `computation_types.NamedTupleType` member containing at least one
      element.

  Returns:
    A `computation_building_blocks.Block`.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` does not contain any elements.
  """
  py_typecheck.check_type(
      value, computation_building_blocks.ComputationBuildingBlock)
  named_type_signatures = anonymous_tuple.to_elements(
      value.type_signature.member)
  if not named_type_signatures:
    raise ValueError('federated_unzip is only supported on non-empty tuples.')
  value_ref = computation_building_blocks.Reference('value',
                                                    value.type_signature)
  # One parameter reference typed as the whole member tuple; each selector
  # lambda picks out a single element from it.
  fn_ref = computation_building_blocks.Reference('arg', named_type_signatures)
  elements = []
  for index, (name, _) in enumerate(named_type_signatures):
    sel = computation_building_blocks.Selection(fn_ref, index=index)
    fn = computation_building_blocks.Lambda(fn_ref.name,
                                            fn_ref.type_signature, sel)
    elements.append((name, create_federated_map_or_apply(fn, value_ref)))
  result = computation_building_blocks.Tuple(elements)
  symbols = ((value_ref.name, value),)
  return computation_building_blocks.Block(symbols, result)
def _create_chain_zipped_values(value):
  r"""Creates a chain of called federated zip with two values.

                Block--------
               /             \
  [value=Tuple]               Call
        |                    /    \
  [Comp1,           Intrinsic      Tuple
   Comp2,                          |
   ...]                    [Call,       Sel(n)]
                           /    \             \
                  Intrinsic      Tuple         Ref(value)
                                 |
                         [Sel(0),       Sel(1)]
                                \             \
                                 Ref(value)    Ref(value)

  NOTE: This function is intended to be used in conjunction with
  `_create_fn_to_append_chain_zipped_values`. Element names are attached to
  the two-element tuples passed to each zip:
  - the first zip gets the names of elements 0 and 1;
  - each later zip names only its right-hand element.
  Whether the names survive in the zipped result depends on
  `_create_zip_two_values` (not visible here); the flattening function is
  presumably what restores the full set of names.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing
      at least two elements.

  Returns:
    A `computation_building_blocks.Block` that binds `value` to the local
    name `value` and whose result is the chain of zip calls.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` does not contain at least two elements.
  """
  py_typecheck.check_type(
      value, computation_building_blocks.ComputationBuildingBlock)
  named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
  length = len(named_type_signatures)
  if length < 2:
    # Report the element count, which is what the message describes.
    raise ValueError(
        'Expected a value with at least two elements, received {} elements.'
        .format(length))
  first_name, _ = named_type_signatures[0]
  ref = computation_building_blocks.Reference('value', value.type_signature)
  symbols = ((ref.name, value),)
  sel_0 = computation_building_blocks.Selection(ref, index=0)
  # Before the first zip, `result` is a (name, computation) pair. Afterwards
  # it is the unnamed zip call.
  result = (first_name, sel_0)
  for index in range(1, length):
    name, _ = named_type_signatures[index]
    sel = computation_building_blocks.Selection(ref, index=index)
    values = computation_building_blocks.Tuple((result, (name, sel)))
    result = _create_zip_two_values(values)
  return computation_building_blocks.Block(symbols, result)
def test_inline_conflicting_locals(self):
  """A local shadowing the lambda parameter inlines to `arg[0]`."""
  pair_arg = computation_building_blocks.Reference('arg',
                                                   [tf.int32, tf.int32])
  first_element = computation_building_blocks.Selection(pair_arg, index=0)
  shadowing_block = computation_building_blocks.Block(
      [('arg', first_element)],
      computation_building_blocks.Reference('arg', tf.int32))
  lam = computation_building_blocks.Lambda('arg', pair_arg.type_signature,
                                           shadowing_block)
  self.assertEqual(str(lam), '(arg -> (let arg=arg[0] in arg))')

  inlined = transformations.inline_blocks_with_n_referenced_locals(lam)
  self.assertEqual(str(inlined), '(arg -> (let in arg[0]))')
def _extract_from_lambda(comp):
  """Returns a new computation with all intrinsics extracted.

  Handles two cases:
  - If the lambda's result is a called intrinsic, that intrinsic is bound to
    a fresh name from `name_generator`. The binding is placed outside the
    lambda when the intrinsic does not reference the lambda parameter, and
    inside it otherwise.
  - Otherwise the result is treated as a Block. Its locals that reference
    neither the parameter nor any local kept so far are hoisted into a new
    block around the lambda. The rest stay inside. The hoisted block is then
    passed through `_extract_from_block`.
  """
  if not _is_called_intrinsic(comp.result):
    inner_block = comp.result
    hoisted = []
    kept = []
    for local_name, local_value in inner_block.locals:
      kept_names = [kept_name for kept_name, _ in kept]
      depends_on_scope = (
          _contains_unbound_reference(local_value, comp.parameter_name) or
          _contains_unbound_reference(local_value, kept_names))
      if depends_on_scope:
        kept.append((local_name, local_value))
      else:
        hoisted.append((local_name, local_value))
    if kept:
      body = computation_building_blocks.Block(kept, inner_block.result)
    else:
      body = inner_block.result
    fn = computation_building_blocks.Lambda(comp.parameter_name,
                                            comp.parameter_type, body)
    return _extract_from_block(computation_building_blocks.Block(hoisted, fn))

  called_intrinsic = comp.result
  fresh_name = six.next(name_generator)
  bindings = ((fresh_name, called_intrinsic),)
  ref = computation_building_blocks.Reference(fresh_name,
                                              called_intrinsic.type_signature)
  if _contains_unbound_reference(comp.result, comp.parameter_name):
    # The intrinsic uses the parameter, so it must stay under the lambda.
    return computation_building_blocks.Lambda(
        comp.parameter_name, comp.parameter_type,
        computation_building_blocks.Block(bindings, ref))
  return computation_building_blocks.Block(
      bindings,
      computation_building_blocks.Lambda(comp.parameter_name,
                                         comp.parameter_type, ref))
def test_replace_called_lambda_does_not_replace_separated_called_lambda(
    self):
  """A lambda wrapped in an empty block and then called is left unchanged."""
  arg = computation_building_blocks.Reference('arg', tf.int32)
  wrapped_lambda = computation_building_blocks.Block(
      [], _create_lambda_to_identity(arg.type_signature))
  comp = computation_building_blocks.Call(wrapped_lambda, arg)

  transformed_comp = transformations.replace_called_lambda_with_block(comp)

  self.assertEqual(str(transformed_comp), str(comp))
  self.assertEqual(str(transformed_comp), '(let in (arg -> arg))(arg)')
def _extract_from_block(comp):
  """Returns a new computation with all intrinsics extracted.

  `comp` is a Block. Three shapes are handled:
  - The result is a called intrinsic: it is bound to a fresh name from
    `name_generator`, appended after the existing locals, and the new block
    returns a reference to it.
  - The result is itself a Block: the two blocks are merged, outer locals
    first.
  - Otherwise: any Block-valued local is flattened one level. Its locals are
    spliced in before the local, which is rebound to the inner block's
    result.

  The input block is never mutated; new local lists are always built.
  """
  if _is_called_intrinsic(comp.result):
    called_intrinsic = comp.result
    name = six.next(name_generator)
    # Copy before appending so the sequence held by `comp` is never mutated
    # (and so a tuple of locals is accepted too).
    variables = list(comp.locals)
    variables.append((name, called_intrinsic))
    result = computation_building_blocks.Reference(
        name, called_intrinsic.type_signature)
    return computation_building_blocks.Block(variables, result)
  elif isinstance(comp.result, computation_building_blocks.Block):
    # Normalize both sides to lists so list/tuple locals can be concatenated.
    merged_locals = list(comp.locals) + list(comp.result.locals)
    return computation_building_blocks.Block(merged_locals,
                                             comp.result.result)
  else:
    variables = []
    for name, variable in comp.locals:
      if isinstance(variable, computation_building_blocks.Block):
        variables.extend(variable.locals)
        variables.append((name, variable.result))
      else:
        variables.append((name, variable))
    return computation_building_blocks.Block(variables, comp.result)
def test_multiple_inline_for_nested_block(self):
  """Inlining cascades through nested blocks but stops at extra references."""
  used1 = computation_building_blocks.Reference('used1', tf.int32)
  used2 = computation_building_blocks.Reference('used2', tf.int32)
  lower_block = computation_building_blocks.Block(
      [('x', used1)],
      computation_building_blocks.Reference('x', used1.type_signature))
  higher_block = computation_building_blocks.Block([('used1', used2)],
                                                   lower_block)

  # Inlining runs before the input is printed, so the next assertion also
  # checks that `higher_block` was left intact.
  inlined = transformations.inline_blocks_with_n_referenced_locals(
      higher_block)
  self.assertEqual(
      str(higher_block), '(let used1=used2 in (let x=used1 in x))')
  self.assertEqual(str(inlined), '(let in (let in used2))')

  manual_lower = computation_building_blocks.Block([('x', used1)], used1)
  manual_higher = computation_building_blocks.Block([('used1', used2)],
                                                    manual_lower)
  self.assertEqual(
      str(manual_higher), '(let used1=used2 in (let x=used1 in used1))')
  inlined_noop = transformations.inline_blocks_with_n_referenced_locals(
      manual_higher)
  self.assertEqual(str(inlined_noop), '(let used1=used2 in (let in used1))')
def create_dummy_block(comp, variable_name, variable_type=tf.int32):
  r"""Returns a block binding `variable_name` to placeholder data.

  This is not an identity block: the result is `comp`, and the bound variable
  is a `Data('data', variable_type)` placeholder.

                 Block
                /     \
  [variable_name=data]  Comp

  Args:
    comp: The computation to use as the result.
    variable_name: The name of the variable.
    variable_type: The type of the placeholder data bound to the variable.

  Returns:
    A `computation_building_blocks.Block` with a single local
    `(variable_name, Data('data', variable_type))` and result `comp`.
  """
  data = computation_building_blocks.Data('data', variable_type)
  return computation_building_blocks.Block([(variable_name, data)], comp)
def test_propogates_dependence_into_binding_to_reference(self):
  """A reference to a local bound to GENERIC_ZERO is reported as dependent."""
  fed_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  x_ref = computation_building_blocks.Reference('x', fed_type)
  federated_zero = computation_building_blocks.Intrinsic(
      intrinsic_defs.GENERIC_ZERO.uri, fed_type)

  def is_generic_zero(node):
    return (isinstance(node, computation_building_blocks.Intrinsic) and
            node.uri == intrinsic_defs.GENERIC_ZERO.uri)

  block = computation_building_blocks.Block([('x', federated_zero)], x_ref)
  dependent_nodes = tree_analysis.extract_nodes_consuming(
      block, is_generic_zero)
  self.assertIn(x_ref, dependent_nodes)
def create_computation_appending(comp1, comp2):
  r"""Returns a block appending `comp2` to `comp1`.

                  Block
                 /     \
  [comps=Tuple]         Tuple
        |               |
  [Comp, Comp]    [Sel(0), ..., Sel(n), Sel(1)]
                      \            \         \
                       Sel(0)       Sel(0)    Ref(comps)
                          \            \
                           Ref(comps)   Ref(comps)

  Every element of `comp1` is re-selected through `comps[0]`, keeping its
  name. `comp2` is appended as `comps[1]` under `comp2`'s name, or unnamed.

  Args:
    comp1: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_type.NamedTupleType`.
    comp2: A `computation_building_blocks.ComputationBuildingBlock` or a
      named computation (a tuple pair of name, computation) representing a
      single element of an `anonymous_tuple.AnonymousTuple`.

  Returns:
    A `computation_building_blocks.Block`.

  Raises:
    TypeError: If any of the types do not match.
  """
  py_typecheck.check_type(
      comp1, computation_building_blocks.ComputationBuildingBlock)
  if isinstance(comp2, computation_building_blocks.ComputationBuildingBlock):
    name2 = None
  elif py_typecheck.is_name_value_pair(
      comp2,
      name_required=False,
      value_type=computation_building_blocks.ComputationBuildingBlock):
    name2, comp2 = comp2
  else:
    raise TypeError('Unexpected tuple element: {}.'.format(comp2))

  comps = computation_building_blocks.Tuple((comp1, comp2))
  comps_ref = computation_building_blocks.Reference('comps',
                                                    comps.type_signature)
  first = computation_building_blocks.Selection(comps_ref, index=0)
  elements = [
      (element_name, computation_building_blocks.Selection(first, index=i))
      for i, (element_name, _) in enumerate(
          anonymous_tuple.to_elements(comp1.type_signature))
  ]
  elements.append(
      (name2, computation_building_blocks.Selection(comps_ref, index=1)))
  return computation_building_blocks.Block(
      ((comps_ref.name, comps),), computation_building_blocks.Tuple(elements))
def create_identity_block(variable_name, comp):
  r"""Returns a block that binds `comp` and returns the binding.

          Block
         /     \
  [x=comp]      Ref(x)

  Args:
    variable_name: The name of the variable.
    comp: The computation to use as the variable.

  Returns:
    A `computation_building_blocks.Block` with the single local
    `(variable_name, comp)` whose result references that local.
  """
  binding = (variable_name, comp)
  self_ref = computation_building_blocks.Reference(variable_name,
                                                   comp.type_signature)
  return computation_building_blocks.Block([binding], self_ref)
def _transform(comp):
  """Internal function to break down Call-Lambda and build Block.

  Rewrites `(p -> r)(a)` as `(let p=a in r)`. Anything that is not a Call
  whose function is a Lambda is returned unchanged.

  Raises:
    TypeError: If the called lambda has no `ComputationBuildingBlock`
      argument.
  """
  is_called_lambda = (
      isinstance(comp, computation_building_blocks.Call) and
      isinstance(comp.function, computation_building_blocks.Lambda))
  if not is_called_lambda:
    return comp
  py_typecheck.check_type(
      comp.argument, computation_building_blocks.ComputationBuildingBlock)
  lam = comp.function
  return computation_building_blocks.Block(
      [(lam.parameter_name, comp.argument)], lam.result)
def test_no_inlining_if_referenced_twice(self):
  """A local referenced twice exceeds the threshold and is kept as-is."""
  test_data = computation_building_blocks.Data('test_data', tf.int32)
  twin_refs = [
      computation_building_blocks.Reference('test_x', test_data.type_signature)
      for _ in range(2)
  ]
  simple_block = computation_building_blocks.Block(
      [('test_x', test_data)], computation_building_blocks.Tuple(twin_refs))
  self.assertEqual(
      str(simple_block), '(let test_x=test_data in <test_x,test_x>)')

  inlined = transformations.inline_blocks_with_n_referenced_locals(
      simple_block)
  self.assertEqual(str(inlined), str(simple_block))
def _transform_functional_args(comps):
  r"""Transforms the functional computations `comps`.

  Given a computation containing `n` called intrinsics with `m` arguments,
  this function constructs the following computation from the functional
  arguments of the called intrinsic:

                  Block
                 /     \
     [fn=Tuple]         Lambda(arg)
         |                         \
  [Comp(f1), Comp(f2), ...]         Tuple
                                    |
                               [Call,                  Call, ...]
                               /    \                 /    \
                         Sel(0)      Sel(0)     Sel(1)      Sel(1)
                        /           /           /           /
                 Ref(fn)     Ref(arg)     Ref(fn)     Ref(arg)

  with one `computation_building_blocks.Call` for each `n`. This computation
  represents one of the `m` arguments that should be passed to the call of
  the transformed computation.

  Args:
    comps: a Python list of computations.

  Returns:
    A `computation_building_blocks.Block`.
  """
  functions = computation_building_blocks.Tuple(comps)
  # Draw the two fresh names in the same order as before: functions, then arg.
  functions_ref = computation_building_blocks.Reference(
      six.next(name_generator), functions.type_signature)
  arg_ref = computation_building_blocks.Reference(
      six.next(name_generator),
      [fn.type_signature.parameter for fn in comps])
  calls = computation_building_blocks.Tuple([
      computation_building_blocks.Call(
          computation_building_blocks.Selection(functions_ref, index=i),
          computation_building_blocks.Selection(arg_ref, index=i))
      for i in range(len(comps))
  ])
  fn = computation_building_blocks.Lambda(arg_ref.name, arg_ref.type_signature,
                                          calls)
  return computation_building_blocks.Block(
      ((functions_ref.name, functions),), fn)
def _extract_from_call(comp):
  """Returns a new computation with all intrinsics extracted.

  If the call's argument is a called intrinsic, it is bound to a fresh name
  and the call takes a reference to it instead. Otherwise the argument is
  treated as a Block: its locals are lifted above the call and its result
  becomes the new argument. The resulting block is then normalized by
  `_extract_from_block`.
  """
  argument = comp.argument
  if _is_called_intrinsic(argument):
    fresh_name = six.next(name_generator)
    bindings = ((fresh_name, argument),)
    new_argument = computation_building_blocks.Reference(
        fresh_name, argument.type_signature)
  else:
    bindings = argument.locals
    new_argument = argument.result
  rebuilt_call = computation_building_blocks.Call(comp.function, new_argument)
  return _extract_from_block(
      computation_building_blocks.Block(bindings, rebuilt_call))
def test_no_reduce_separated_lambda_and_call(self):
  """A lambda separated from its call by a block is not reduced."""

  @computations.federated_computation(tf.int32)
  def foo(x):
    return x

  separated_lambda = computation_building_blocks.Block(
      [], _to_building_block(foo))
  called_block = computation_building_blocks.Call(
      separated_lambda, computation_building_blocks.Data('test', tf.int32))

  lambda_reduced_comp = transformations.replace_called_lambdas_with_block(
      called_block)

  self.assertEqual(
      str(called_block), '(let in (foo_arg -> foo_arg))(test)')
  self.assertEqual(str(called_block), str(lambda_reduced_comp))