Example #1
0
    def test_replace_intrinsic_raises_type_error_none_uri(self):
        # A None URI is invalid input and must be rejected with TypeError.
        lambda_comp = _create_lambda_to_add_one(tf.int32)
        replacement = lambda x: 100

        self.assertRaises(
            TypeError,
            transformations.replace_intrinsic_with_callable,
            lambda_comp,
            None,
            replacement,
            context_stack_impl.context_stack)
Example #2
0
    def test_replace_intrinsic_raises_type_error_none_body(self):
        # A None replacement body for GENERIC_PLUS must raise TypeError.
        lambda_comp = _create_lambda_to_add_one(tf.int32)
        plus_uri = intrinsic_defs.GENERIC_PLUS.uri

        self.assertRaises(
            TypeError,
            transformations.replace_intrinsic_with_callable,
            lambda_comp,
            plus_uri,
            None,
            context_stack_impl.context_stack)
    def test_replace_intrinsic_raises_type_error_none_body(self):
        # A None replacement body for the 'dummy' intrinsic must raise TypeError.
        dummy_comp = _create_lambda_to_dummy_intrinsic(tf.int32)
        dummy_uri = 'dummy'

        self.assertRaises(
            TypeError,
            transformations.replace_intrinsic_with_callable,
            dummy_comp,
            dummy_uri,
            None,
            context_stack_impl.context_stack)
Example #4
0
    def test_replace_intrinsic_raises_type_error_none_comp(self):
        # Passing None as the computation to transform must raise TypeError.
        plus_uri = intrinsic_defs.GENERIC_PLUS.uri
        replacement = lambda x: 100

        self.assertRaises(
            TypeError,
            transformations.replace_intrinsic_with_callable,
            None,
            plus_uri,
            replacement,
            context_stack_impl.context_stack)
    def test_replace_intrinsic_raises_type_error_none_comp(self):
        # Passing None as the computation to transform must raise TypeError.
        dummy_uri = 'dummy'
        identity = lambda x: x

        self.assertRaises(
            TypeError,
            transformations.replace_intrinsic_with_callable,
            None,
            dummy_uri,
            identity,
            context_stack_impl.context_stack)
Example #6
0
    def test_replace_intrinsic_replaces_multiple_intrinsics(self):
        """Replaces every GENERIC_PLUS in a chain of nested add-one calls."""
        # Number of chained add-one lambdas; each contributes one intrinsic.
        num_intrinsics = 10
        calling_arg = computation_building_blocks.Reference('arg', tf.int32)
        arg_type = calling_arg.type_signature
        arg = calling_arg
        # Build lam(lam(...lam(arg)...)), feeding each call's result type into
        # the next lambda's parameter type.
        for _ in range(num_intrinsics):
            lam = _create_lambda_to_add_one(arg_type)
            call = computation_building_blocks.Call(lam, arg)
            arg_type = call.function.type_signature.result
            arg = call
        # `arg` now holds the outermost call. Using it instead of the loop-local
        # `call` avoids depending on a name bound only inside the loop body.
        calling_lambda = computation_building_blocks.Lambda(
            calling_arg.name, calling_arg.type_signature, arg)
        comp = calling_lambda
        uri = intrinsic_defs.GENERIC_PLUS.uri
        body = lambda x: 100

        self.assertEqual(_get_number_of_intrinsics(comp, uri), num_intrinsics)
        comp_impl = _to_comp(comp)
        # Each add-one lambda increments the input once.
        self.assertEqual(comp_impl(1), 1 + num_intrinsics)

        transformed_comp = transformations.replace_intrinsic_with_callable(
            comp, uri, body, context_stack_impl.context_stack)

        self.assertEqual(_get_number_of_intrinsics(transformed_comp, uri), 0)
        transformed_comp_impl = _to_comp(transformed_comp)
        self.assertEqual(transformed_comp_impl(1), 100)
    def test_replace_intrinsic_does_not_replace_other_intrinsic(self):
        # Only intrinsics whose URI matches are replaced: 'dummy' != 'other',
        # so both the input and the output keep the original structure.
        unchanged_repr = '(arg -> dummy(arg))'
        original = _create_lambda_to_dummy_intrinsic(tf.int32)

        result = transformations.replace_intrinsic_with_callable(
            original, 'other', lambda x: x, context_stack_impl.context_stack)

        self.assertEqual(original.tff_repr, unchanged_repr)
        self.assertEqual(result.tff_repr, unchanged_repr)
Example #8
0
    def test_replace_intrinsic_replaces_intrinsic(self):
        # A single GENERIC_PLUS is swapped for a constant-returning callable.
        original = _create_lambda_to_add_one(tf.int32)
        plus_uri = intrinsic_defs.GENERIC_PLUS.uri
        replacement = lambda x: 100

        # Sanity-check the input: one intrinsic, and it adds one.
        self.assertEqual(_get_number_of_intrinsics(original, plus_uri), 1)
        self.assertEqual(_to_comp(original)(1), 2)

        result = transformations.replace_intrinsic_with_callable(
            original, plus_uri, replacement, context_stack_impl.context_stack)

        # The intrinsic is gone and the replacement's constant is produced.
        self.assertEqual(_get_number_of_intrinsics(result, plus_uri), 0)
        self.assertEqual(_to_comp(result)(1), 100)
    def test_replace_intrinsic_replaces_nested_intrinsic(self):
        # An intrinsic nested inside a block's result is still replaced.
        wrapped = _create_dummy_block(_create_lambda_to_dummy_intrinsic(tf.int32))
        identity = lambda x: x

        result = transformations.replace_intrinsic_with_callable(
            wrapped, 'dummy', identity, context_stack_impl.context_stack)

        before = '(let local=data in (arg -> dummy(arg)))'
        after = '(let local=data in (arg -> (dummy_arg -> dummy_arg)(arg)))'
        self.assertEqual(wrapped.tff_repr, before)
        self.assertEqual(result.tff_repr, after)
    def test_replace_intrinsic_replaces_multiple_intrinsics(self):
        # Every occurrence of 'dummy' in a two-deep chained call is replaced.
        dummy_fn = _create_lambda_to_dummy_intrinsic(tf.int32)
        data = computation_building_blocks.Data('x', tf.int32)
        chained = _create_chained_call(dummy_fn, data, 2)

        result = transformations.replace_intrinsic_with_callable(
            chained, 'dummy', lambda x: x, context_stack_impl.context_stack)

        before = '(arg -> dummy(arg))((arg -> dummy(arg))(x))'
        after = ('(arg -> (dummy_arg -> dummy_arg)(arg))'
                 '((arg -> (dummy_arg -> dummy_arg)(arg))(x))')
        self.assertEqual(chained.tff_repr, before)
        self.assertEqual(result.tff_repr, after)
    def compile(self, computation_to_compile):
        """Compiles `computation_to_compile`.

        The computation is converted to a building block. Each intrinsic listed
        in `self._intrinsic_bodies` is replaced by its body, and called lambdas
        are rewritten as blocks. The result is then wrapped back into a
        computation bound to `self._context_stack`.

        Args:
          computation_to_compile: An instance of `computation_base.Computation`
            to compile.

        Returns:
          An instance of `computation_base.Computation` that represents the
          result.
        """
        py_typecheck.check_type(computation_to_compile,
                                computation_base.Computation)
        proto = computation_impl.ComputationImpl.get_proto(
            computation_to_compile)

        # TODO(b/113123410): Add a compiler options argument that characterizes
        # the desired form of the output. To be driven by what the specific
        # backend the pipeline is targeting is able to understand. Pending a more
        # fleshed out design of the backend API.

        py_typecheck.check_type(proto, pb.Computation)
        building_block = (
            computation_building_blocks.ComputationBuildingBlock.from_proto(
                proto))

        # Swap each intrinsic for its body, one URI at a time, in the iteration
        # order of `self._intrinsic_bodies`.
        # TODO(b/113123410): Replace this with a more automated implementation
        # that does not rely on manual maintenance.
        for intrinsic_uri, intrinsic_body in six.iteritems(
                self._intrinsic_bodies):
            building_block, _ = transformations.replace_intrinsic_with_callable(
                building_block, intrinsic_uri, intrinsic_body,
                self._context_stack)

        # Rewrite called lambdas as LET constructs with a single local symbol.
        building_block, _ = transformations.replace_called_lambda_with_block(
            building_block)
        # TODO(b/113123410): Add more transformations to simplify and optimize
        # the structure, for example:
        # * removing unnecessary lambdas,
        # * flattening the structure,
        # * merging TensorFlow blocks where appropriate,
        # * ...and so on.

        return computation_impl.ComputationImpl(building_block.proto,
                                                self._context_stack)