Ejemplo n.º 1
0
    def test_raises_type_error_with_none_comp(self):
        """Passing `None` as the computation must raise `TypeError`."""
        intrinsic_uri = 'intrinsic'
        identity = lambda x: x

        with self.assertRaises(TypeError):
            value_transformations.replace_intrinsics_with_callable(
                None,
                intrinsic_uri,
                identity,
                context_stack_impl.context_stack,
            )
Ejemplo n.º 2
0
    def test_raises_type_error_with_none_body(self):
        """Passing `None` as the replacement callable must raise `TypeError`."""
        lambda_comp = test_utils.create_lambda_to_dummy_called_intrinsic(
            parameter_name='a')
        intrinsic_uri = 'intrinsic'

        with self.assertRaises(TypeError):
            value_transformations.replace_intrinsics_with_callable(
                lambda_comp,
                intrinsic_uri,
                None,
                context_stack_impl.context_stack,
            )
Ejemplo n.º 3
0
    def test_replaces_intrinsic(self):
        """Replaces a directly called intrinsic with a lambda built from `body`.

        Also checks that the input computation keeps its original
        representation, that the type signature is preserved, and that the
        returned `modified` flag is truthy.
        """
        original = test_utils.create_lambda_to_dummy_called_intrinsic(
            parameter_name='a')
        identity = lambda x: x

        result, was_modified = (
            value_transformations.replace_intrinsics_with_callable(
                original, 'intrinsic', identity,
                context_stack_impl.context_stack))

        self.assertEqual(original.compact_representation(),
                         '(a -> intrinsic(a))')
        self.assertEqual(result.compact_representation(),
                         '(a -> (intrinsic_arg -> intrinsic_arg)(a))')
        self.assertEqual(result.type_signature, original.type_signature)
        self.assertTrue(was_modified)

    def test_replaces_nested_intrinsic(self):
        """Replaces an intrinsic nested as the result of a `let` block.

        The intrinsic-calling lambda is the result of a block binding `b`;
        the transformation must rewrite it there while the `b=data` binding
        and the overall type signature stay the same.
        """
        # NOTE(review): the other tests in this file use `test_utils`; confirm
        # `computation_test_utils` is actually imported in this module.
        lambda_fn = (
            computation_test_utils.create_lambda_to_dummy_called_intrinsic(
                parameter_name='a'))
        comp = computation_test_utils.create_dummy_block(
            lambda_fn, variable_name='b')
        identity = lambda x: x

        result, was_modified = (
            value_transformations.replace_intrinsics_with_callable(
                comp, 'intrinsic', identity,
                context_stack_impl.context_stack))

        expected_before = '(let b=data in (a -> intrinsic(a)))'
        expected_after = (
            '(let b=data in (a -> (intrinsic_arg -> intrinsic_arg)(a)))')
        self.assertEqual(comp.compact_representation(), expected_before)
        self.assertEqual(result.compact_representation(), expected_after)
        self.assertEqual(result.type_signature, comp.type_signature)
        self.assertTrue(was_modified)
Ejemplo n.º 5
0
  def test_replaces_chained_intrinsics(self):
    """Replaces every intrinsic in a chain of calls to the same lambda.

    The computation applies the intrinsic-calling lambda twice in a row to
    `data`; both occurrences must be rewritten, the input computation keeps
    its original representation, and the type signature is unchanged.
    """
    lambda_fn = test_utils.create_lambda_to_dummy_called_intrinsic(
        parameter_name='a')
    arg_data = building_blocks.Data('data', tf.int32)
    chained = test_utils.create_chained_calls([lambda_fn, lambda_fn], arg_data)
    identity = lambda x: x

    result, was_modified = (
        value_transformations.replace_intrinsics_with_callable(
            chained, 'intrinsic', identity, context_stack_impl.context_stack))

    expected_before = '(a -> intrinsic(a))((a -> intrinsic(a))(data))'
    expected_after = ('(a -> (intrinsic_arg -> intrinsic_arg)(a))'
                      '((a -> (intrinsic_arg -> intrinsic_arg)(a))(data))')
    self.assertEqual(chained.compact_representation(), expected_before)
    self.assertEqual(result.compact_representation(), expected_after)
    self.assertEqual(result.type_signature, chained.type_signature)
    self.assertTrue(was_modified)