def test_raises_type_error_with_none_body(self):
  """Passing `None` as the replacement body must raise a `TypeError`."""
  intrinsic_uri = 'intrinsic'
  comp = test_utils.create_lambda_to_dummy_called_intrinsic(
      parameter_name='a')
  with self.assertRaises(TypeError):
    value_transformations.replace_intrinsics_with_callable(
        comp, intrinsic_uri, None, context_stack_impl.context_stack)
def test_replaces_intrinsic(self):
  """Replaces a single `intrinsic` call with an identity callable.

  Verifies that:
    * the input computation still has its original representation,
    * the intrinsic call is rewritten to `(intrinsic_arg -> intrinsic_arg)`,
    * the type signature is preserved, and
    * the transformation reports that it modified the computation.
  """
  comp = test_utils.create_lambda_to_dummy_called_intrinsic(
      parameter_name='a')
  result, was_modified = (
      value_transformations.replace_intrinsics_with_callable(
          comp, 'intrinsic', lambda x: x, context_stack_impl.context_stack))

  # The original building block must be unchanged after the call.
  self.assertEqual(comp.compact_representation(), '(a -> intrinsic(a))')
  self.assertEqual(result.compact_representation(),
                   '(a -> (intrinsic_arg -> intrinsic_arg)(a))')
  self.assertEqual(result.type_signature, comp.type_signature)
  self.assertTrue(was_modified)
def test_replaces_chained_intrinsics(self):
  """Replaces every `intrinsic` call in a chain of two identical lambdas.

  The computation under test is `fn(fn(data))`, where `fn` is
  `(a -> intrinsic(a))`. Both occurrences of the intrinsic must be rewritten
  to the identity body. The input computation must keep its original
  representation, the type signature must be preserved, and the modified
  flag must be set.
  """
  fn = test_utils.create_lambda_to_dummy_called_intrinsic(parameter_name='a')
  arg = building_blocks.Data('data', tf.int32)
  comp = test_utils.create_chained_calls([fn, fn], arg)
  uri = 'intrinsic'
  body = lambda x: x

  transformed_comp, modified = (
      value_transformations.replace_intrinsics_with_callable(
          comp, uri, body, context_stack_impl.context_stack))

  self.assertEqual(comp.compact_representation(),
                   '(a -> intrinsic(a))((a -> intrinsic(a))(data))')
  # Implicit string concatenation keeps the expected value on short lines;
  # the two pieces join into the full chained representation.
  self.assertEqual(
      transformed_comp.compact_representation(),
      '(a -> (intrinsic_arg -> intrinsic_arg)(a))'
      '((a -> (intrinsic_arg -> intrinsic_arg)(a))(data))')
  self.assertEqual(transformed_comp.type_signature, comp.type_signature)
  self.assertTrue(modified)