Пример #1
0
 def test_remove_mapped_or_applied_identity_does_not_remove_called_lambda(
         self):
     """A directly called identity lambda keeps the same string form."""
     ref = computation_building_blocks.Reference('x', tf.int32)
     identity_fn = computation_building_blocks.Lambda('x', tf.int32, ref)
     data = computation_building_blocks.Data('test', tf.int32)
     comp = computation_building_blocks.Call(identity_fn, data)
     expected = '(x -> x)(test)'
     self.assertEqual(str(comp), expected)
     result = transformations.remove_mapped_or_applied_identity(comp)
     self.assertEqual(str(result), expected)
Пример #2
0
    def test_remove_mapped_or_applied_identity_does_not_remove_called_lambda(
            self):
        """Checks that a directly called identity lambda is left untouched."""
        identity = _create_lambda_to_identity(tf.int32)
        data = computation_building_blocks.Data('x', tf.int32)
        comp = computation_building_blocks.Call(identity, data)

        result = transformations.remove_mapped_or_applied_identity(comp)

        expected_repr = '(arg -> arg)(x)'
        self.assertEqual(comp.tff_repr, expected_repr)
        self.assertEqual(result.tff_repr, expected_repr)
Пример #3
0
    def test_remove_mapped_or_applied_identity_removes_identity(
            self, uri, type_spec, comp_factory):
        """Checks that an identity passed to intrinsic `uri` collapses to `x`.

        `uri`, `type_spec` and `comp_factory` are presumably supplied by a
        parameterization decorator that is not visible here; `comp_factory`
        builds the call under test from the identity and the data.
        """
        identity = _create_lambda_to_identity(tf.int32)
        data = computation_building_blocks.Data('x', type_spec)
        comp = comp_factory(identity, data)

        result = transformations.remove_mapped_or_applied_identity(comp)

        self.assertEqual(comp.tff_repr, '{}(<(arg -> arg),x>)'.format(uri))
        self.assertEqual(result.tff_repr, 'x')
Пример #4
0
 def test_remove_mapped_or_applied_identity_removes_identity(
         self, uri, data_type):
     """Checks that intrinsic `uri` applied to an identity reduces to `x`."""
     x_data = computation_building_blocks.Data('x', data_type)
     identity = computation_building_blocks.Lambda(
         'arg', tf.float32,
         computation_building_blocks.Reference('arg', tf.float32))
     fn_and_data = computation_building_blocks.Tuple([identity, x_data])
     element_types = fn_and_data.type_signature
     # The intrinsic takes <identity, data> and returns the data's type.
     intrinsic_type = computation_types.FunctionType(
         [element_types[0], element_types[1]], element_types[1])
     intrinsic = computation_building_blocks.Intrinsic(uri, intrinsic_type)
     comp = computation_building_blocks.Call(intrinsic, fn_and_data)
     self.assertEqual(str(comp), '{}(<(arg -> arg),x>)'.format(uri))
     self.assertEqual(
         str(transformations.remove_mapped_or_applied_identity(comp)), 'x')
Пример #5
0
    def test_remove_mapped_or_applied_identity_removes_multiple_identities(
            self):
        """Checks that two chained identity `federated_map`s reduce to `x`."""
        identity = _create_lambda_to_identity(tf.int32)
        clients_int = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        data = computation_building_blocks.Data('x', clients_int)
        comp = _create_chained_called_federated_map(identity, data, 2)

        result = transformations.remove_mapped_or_applied_identity(comp)

        self.assertEqual(
            comp.tff_repr,
            'federated_map(<(arg -> arg),federated_map(<(arg -> arg),x>)>)')
        self.assertEqual(result.tff_repr, 'x')
Пример #6
0
    def test_remove_mapped_or_applied_identity_does_not_remove_other_intrinsic(
            self):
        """Checks that an identity passed to an unrelated intrinsic is kept."""
        identity = _create_lambda_to_identity(tf.int32)
        data = computation_building_blocks.Data('x', tf.int32)
        data_type = data.type_signature
        dummy = computation_building_blocks.Intrinsic(
            'dummy',
            computation_types.FunctionType(
                [identity.type_signature, data_type], data_type))
        comp = computation_building_blocks.Call(
            dummy, computation_building_blocks.Tuple((identity, data)))

        result = transformations.remove_mapped_or_applied_identity(comp)

        expected_repr = 'dummy(<(arg -> arg),x>)'
        self.assertEqual(comp.tff_repr, expected_repr)
        self.assertEqual(result.tff_repr, expected_repr)
Пример #7
0
    def compile(self, computation_to_compile):
        """Compiles `computation_to_compile`.

        Deserializes the computation into a building block, runs a fixed
        sequence of transformations over it, and wraps the result in a new
        `computation_impl.ComputationImpl` bound to `self._context_stack`.

        Args:
          computation_to_compile: An instance of `computation_base.Computation`
            to compile.

        Returns:
          An instance of `computation_base.Computation` that represents the
          result.

        Raises:
          TypeError: Presumably raised by `py_typecheck.check_type` when
            `computation_to_compile` is not a `computation_base.Computation`
            or its proto is not a `pb.Computation` (confirm against
            `py_typecheck`).
        """
        py_typecheck.check_type(computation_to_compile,
                                computation_base.Computation)
        computation_proto = computation_impl.ComputationImpl.get_proto(
            computation_to_compile)
        py_typecheck.check_type(computation_proto, pb.Computation)
        comp = building_blocks.ComputationBuildingBlock.from_proto(
            computation_proto)

        # TODO(b/113123410): Add a compiler options argument that characterizes the
        # desired form of the output. To be driven by what the specific backend the
        # pipeline is targeting is able to understand. Pending a more fleshed out
        # design of the backend API.

        # Replace intrinsics with their bodies, for now manually in a fixed order.
        # TODO(b/113123410): Replace this with a more automated implementation that
        # does not rely on manual maintenance.
        # Each transformation below returns a pair; only the transformed
        # building block is kept and the second element is discarded.
        comp, _ = value_transformations.replace_all_intrinsics_with_bodies(
            comp, self._context_stack)

        # Replaces called lambdas with LET constructs with a single local symbol.
        comp, _ = transformations.replace_called_lambda_with_block(comp)

        # Removes mapped or applied identities.
        comp, _ = transformations.remove_mapped_or_applied_identity(comp)

        # Remove duplicate computations. This is important! Otherwise the semantics
        # of non-deterministic computations (e.g. a `tff.tf_computation` depending
        # on `tf.random`) will give unexpected behavior. Additionally, this may
        # reduce the number of calls into TF for some ASTs.
        comp, _ = transformations.uniquify_reference_names(comp)
        comp, _ = transformations.extract_computations(comp)
        comp, _ = transformations.remove_duplicate_computations(comp)

        return computation_impl.ComputationImpl(comp.proto,
                                                self._context_stack)
Пример #8
0
    def test_remove_mapped_or_applied_identity_does_not_remove_other_intrinsic(
            self):
        """Checks that an identity passed to a 'dummy' intrinsic is kept."""
        uri = 'dummy'
        data = computation_building_blocks.Data('x', tf.int32)
        # NOTE(review): the identity is typed tf.float32 while the data is
        # tf.int32; this test does not rely on the two types agreeing.
        identity = computation_building_blocks.Lambda(
            'arg', tf.float32,
            computation_building_blocks.Reference('arg', tf.float32))
        fn_and_data = computation_building_blocks.Tuple([identity, data])
        element_types = fn_and_data.type_signature
        intrinsic = computation_building_blocks.Intrinsic(
            uri,
            computation_types.FunctionType(
                [element_types[0], element_types[1]], element_types[1]))
        comp = computation_building_blocks.Call(intrinsic, fn_and_data)

        result = transformations.remove_mapped_or_applied_identity(comp)

        expected = '{}(<(arg -> arg),x>)'.format(uri)
        self.assertEqual(str(comp), expected)
        self.assertEqual(str(result), expected)
Пример #9
0
    def test_remove_mapped_or_applied_identity_removes_multiple_identities(
            self):
        """Checks that two nested identity federated maps reduce to `x`."""
        clients_type = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        comp = computation_building_blocks.Data('x', clients_type)
        member_type = comp.type_signature.member
        # Wrap the current computation in an identity federated map twice,
        # tracking the member type of each map's result for the next lambda.
        for _ in range(2):
            comp = _create_call_to_federated_map(
                _create_lambda_to_identity(member_type), comp)
            member_type = comp.function.type_signature.result.member

        result = transformations.remove_mapped_or_applied_identity(comp)

        self.assertEqual(
            str(comp),
            '{uri}(<(arg -> arg),{uri}(<(arg -> arg),x>)>)'.format(
                uri=intrinsic_defs.FEDERATED_MAP.uri))
        self.assertEqual(str(result), 'x')
Пример #10
0
 def transformation_fn(x):
   """Runs three `transformations` passes over `x`, in order.

   The passes are identity removal, block-local inlining, and selection-from-
   tuple replacement. Each pass's result is unpacked; only the first element
   is fed into the next pass and the second is discarded.
   """
   pipeline = (
       transformations.remove_mapped_or_applied_identity,
       transformations.inline_block_locals,
       transformations.replace_selection_from_tuple_with_element,
   )
   for transform in pipeline:
     x, _ = transform(x)
   return x
Пример #11
0
 def transformation_fn(x):
   """Applies `remove_mapped_or_applied_identity` and returns the new comp."""
   # The transformation returns a pair; the second element is not needed here.
   result, _ = transformations.remove_mapped_or_applied_identity(x)
   return result
Пример #12
0
 def test_remove_mapped_or_applied_identity_raises_type_error(self):
     """Checks that passing `None` instead of a computation raises TypeError."""
     self.assertRaises(
         TypeError, transformations.remove_mapped_or_applied_identity, None)