def test_remove_mapped_or_applied_identity_does_not_remove_called_lambda(
    self):
  """Checks that an identity lambda called directly on data is left as-is.

  The call is built from a `Lambda` applied to `Data` (no intrinsic is
  involved), and the transformation must return an equivalent computation.
  """
  param_ref = computation_building_blocks.Reference('x', tf.int32)
  identity = computation_building_blocks.Lambda('x', tf.int32, param_ref)
  data = computation_building_blocks.Data('test', tf.int32)
  called_identity = computation_building_blocks.Call(identity, data)
  expected_repr = '(x -> x)(test)'

  self.assertEqual(str(called_identity), expected_repr)
  result = transformations.remove_mapped_or_applied_identity(called_identity)
  self.assertEqual(str(result), expected_repr)
def test_remove_mapped_or_applied_identity_does_not_remove_called_lambda(
    self):
  """Checks that an identity lambda called directly on data is not removed."""
  identity_fn = _create_lambda_to_identity(tf.int32)
  data = computation_building_blocks.Data('x', tf.int32)
  comp = computation_building_blocks.Call(identity_fn, data)

  transformed_comp = transformations.remove_mapped_or_applied_identity(comp)

  # Both the input and the output should print identically.
  expected = '(arg -> arg)(x)'
  self.assertEqual(comp.tff_repr, expected)
  self.assertEqual(transformed_comp.tff_repr, expected)
def test_remove_mapped_or_applied_identity_removes_identity(
    self, uri, type_spec, comp_factory):
  """Checks that a mapped/applied identity collapses to its data argument.

  Args:
    uri: The intrinsic URI expected to appear in the built call's repr.
    type_spec: The type given to the `Data` argument.
    comp_factory: Callable building the call from `(identity_fn, data)`.
  """
  identity_fn = _create_lambda_to_identity(tf.int32)
  data = computation_building_blocks.Data('x', type_spec)
  comp = comp_factory(identity_fn, data)

  transformed_comp = transformations.remove_mapped_or_applied_identity(comp)

  self.assertEqual(comp.tff_repr, '{}(<(arg -> arg),x>)'.format(uri))
  self.assertEqual(transformed_comp.tff_repr, 'x')
def test_remove_mapped_or_applied_identity_removes_identity(
    self, uri, data_type):
  """Checks that an identity passed to intrinsic `uri` collapses to the data.

  Args:
    uri: The URI of the intrinsic called on `<identity, data>`.
    data_type: The type given to the `Data` argument.
  """
  data = computation_building_blocks.Data('x', data_type)
  identity = computation_building_blocks.Lambda(
      'arg', tf.float32,
      computation_building_blocks.Reference('arg', tf.float32))
  args = computation_building_blocks.Tuple([identity, data])

  # The intrinsic takes <identity_type, data_type> and returns data_type.
  tuple_sig = args.type_signature
  intrinsic = computation_building_blocks.Intrinsic(
      uri,
      computation_types.FunctionType([tuple_sig[0], tuple_sig[1]],
                                     tuple_sig[1]))
  call = computation_building_blocks.Call(intrinsic, args)

  self.assertEqual(str(call), '{}(<(arg -> arg),x>)'.format(uri))
  self.assertEqual(
      str(transformations.remove_mapped_or_applied_identity(call)), 'x')
def test_remove_mapped_or_applied_identity_removes_multiple_identities(
    self):
  """Checks that two chained `federated_map` identity calls collapse to data."""
  identity_fn = _create_lambda_to_identity(tf.int32)
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  data = computation_building_blocks.Data('x', clients_int)
  comp = _create_chained_called_federated_map(identity_fn, data, 2)

  transformed_comp = transformations.remove_mapped_or_applied_identity(comp)

  self.assertEqual(
      comp.tff_repr,
      'federated_map(<(arg -> arg),federated_map(<(arg -> arg),x>)>)')
  self.assertEqual(transformed_comp.tff_repr, 'x')
def test_remove_mapped_or_applied_identity_does_not_remove_other_intrinsic(
    self):
  """Checks that an identity passed to a made-up 'dummy' intrinsic is kept."""
  identity_fn = _create_lambda_to_identity(tf.int32)
  data = computation_building_blocks.Data('x', tf.int32)
  dummy_intrinsic = computation_building_blocks.Intrinsic(
      'dummy',
      computation_types.FunctionType(
          [identity_fn.type_signature, data.type_signature],
          data.type_signature))
  comp = computation_building_blocks.Call(
      dummy_intrinsic, computation_building_blocks.Tuple((identity_fn, data)))

  transformed_comp = transformations.remove_mapped_or_applied_identity(comp)

  expected = 'dummy(<(arg -> arg),x>)'
  self.assertEqual(comp.tff_repr, expected)
  self.assertEqual(transformed_comp.tff_repr, expected)
def compile(self, computation_to_compile):
  """Compiles `computation_to_compile`.

  Args:
    computation_to_compile: An instance of `computation_base.Computation` to
      compile.

  Returns:
    An instance of `computation_base.Computation` that represents the result.
  """
  py_typecheck.check_type(computation_to_compile, computation_base.Computation)
  proto = computation_impl.ComputationImpl.get_proto(computation_to_compile)
  py_typecheck.check_type(proto, pb.Computation)
  comp_tree = building_blocks.ComputationBuildingBlock.from_proto(proto)

  # TODO(b/113123410): Add a compiler options argument that characterizes the
  # desired form of the output. To be driven by what the specific backend the
  # pipeline is targeting is able to understand. Pending a more fleshed out
  # design of the backend API.

  # Passes run strictly in the order listed. Each returns a pair whose first
  # element is the transformed building block; the second is discarded here.
  passes = (
      # Replace intrinsics with their bodies, for now manually in a fixed
      # order.
      # TODO(b/113123410): Replace this with a more automated implementation
      # that does not rely on manual maintenance.
      lambda tree: value_transformations.replace_all_intrinsics_with_bodies(
          tree, self._context_stack),
      # Replace called lambdas with LET constructs with a single local symbol.
      transformations.replace_called_lambda_with_block,
      # Remove mapped or applied identities.
      transformations.remove_mapped_or_applied_identity,
      # Remove duplicate computations. This is important! Otherwise the
      # semantics of non-deterministic computations (e.g. a
      # `tff.tf_computation` depending on `tf.random`) will give unexpected
      # behavior. Additionally, this may reduce the amount of calls into TF
      # for some ASTs.
      transformations.uniquify_reference_names,
      transformations.extract_computations,
      transformations.remove_duplicate_computations,
  )
  for transform in passes:
    comp_tree, _ = transform(comp_tree)

  return computation_impl.ComputationImpl(comp_tree.proto,
                                          self._context_stack)
def test_remove_mapped_or_applied_identity_does_not_remove_other_intrinsic(
    self):
  """Checks that an identity passed to a non-map/apply intrinsic is kept.

  The identity lambda is typed with `data_type`, so the dummy intrinsic's
  signature is `[(data_type -> data_type), data_type] -> data_type`. The
  intrinsic URI is then the only thing distinguishing this call from a
  mapped/applied identity. The test no longer also depends on a type
  mismatch between the lambda and the data.
  """
  data_type = tf.int32
  uri = 'dummy'
  data = computation_building_blocks.Data('x', data_type)
  # Type the identity with the same type as the data it is paired with.
  identity_arg = computation_building_blocks.Reference('arg', data_type)
  identity_lam = computation_building_blocks.Lambda('arg', data_type,
                                                    identity_arg)
  arg_tuple = computation_building_blocks.Tuple([identity_lam, data])
  function_type = computation_types.FunctionType(
      [arg_tuple.type_signature[0], arg_tuple.type_signature[1]],
      arg_tuple.type_signature[1])
  intrinsic = computation_building_blocks.Intrinsic(uri, function_type)
  comp = computation_building_blocks.Call(intrinsic, arg_tuple)

  transformed_comp = transformations.remove_mapped_or_applied_identity(comp)

  self.assertEqual(str(comp), '{}(<(arg -> arg),x>)'.format(uri))
  self.assertEqual(str(transformed_comp), '{}(<(arg -> arg),x>)'.format(uri))
def test_remove_mapped_or_applied_identity_removes_multiple_identities(
    self):
  """Checks that two nested `federated_map` identity calls collapse to data."""
  clients_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  comp = computation_building_blocks.Data('x', clients_type)
  member_type = comp.type_signature.member
  # Wrap the data twice; each wrap maps an identity over the previous call.
  for _ in range(2):
    comp = _create_call_to_federated_map(
        _create_lambda_to_identity(member_type), comp)
    member_type = comp.function.type_signature.result.member
  uri = intrinsic_defs.FEDERATED_MAP.uri

  transformed_comp = transformations.remove_mapped_or_applied_identity(comp)

  self.assertEqual(
      str(comp),
      '{uri}(<(arg -> arg),{uri}(<(arg -> arg),x>)>)'.format(uri=uri))
  self.assertEqual(str(transformed_comp), 'x')
def transformation_fn(x):
  """Runs three transformations over `x` in sequence and returns the result.

  The passes are `remove_mapped_or_applied_identity`, `inline_block_locals`
  and `replace_selection_from_tuple_with_element`. Each returns a pair, and
  only the first element is carried forward.
  """
  result = x
  for transform in (transformations.remove_mapped_or_applied_identity,
                    transformations.inline_block_locals,
                    transformations.replace_selection_from_tuple_with_element):
    result, _ = transform(result)
  return result
def transformation_fn(x):
  """Applies `remove_mapped_or_applied_identity` to `x`.

  Returns:
    The first element of the pair returned by the transformation.
  """
  transformed, _ = transformations.remove_mapped_or_applied_identity(x)
  return transformed
def test_remove_mapped_or_applied_identity_raises_type_error(self):
  """Checks that passing `None` instead of a computation raises TypeError."""
  self.assertRaises(
      TypeError, transformations.remove_mapped_or_applied_identity, None)