def test_replace_intrinsic_raises_type_error_none_uri(self):
  """A `None` URI is rejected with a TypeError."""
  add_one_comp = _create_lambda_to_add_one(tf.int32)
  with self.assertRaises(TypeError):
    transformations.replace_intrinsic_with_callable(
        add_one_comp, None, lambda x: 100, context_stack_impl.context_stack)
def test_replace_intrinsic_raises_type_error_none_body_generic_plus(self):
  """A `None` body is rejected with a TypeError (GENERIC_PLUS variant).

  Named distinctly from the dummy-intrinsic `..._none_body` test defined
  later in the class. Under the shared name, the later definition rebound
  the attribute and this test was never collected.
  """
  add_one_comp = _create_lambda_to_add_one(tf.int32)
  plus_uri = intrinsic_defs.GENERIC_PLUS.uri
  with self.assertRaises(TypeError):
    transformations.replace_intrinsic_with_callable(
        add_one_comp, plus_uri, None, context_stack_impl.context_stack)
def test_replace_intrinsic_raises_type_error_none_body(self):
  """Passing `None` as the replacement body raises a TypeError."""
  dummy_comp = _create_lambda_to_dummy_intrinsic(tf.int32)
  with self.assertRaises(TypeError):
    transformations.replace_intrinsic_with_callable(
        dummy_comp, 'dummy', None, context_stack_impl.context_stack)
def test_replace_intrinsic_raises_type_error_none_comp_generic_plus(self):
  """A `None` computation is rejected with a TypeError (GENERIC_PLUS URI).

  Named distinctly from the dummy-URI `..._none_comp` test defined later in
  the class. Under the shared name, the later definition rebound the
  attribute and this test was never collected.
  """
  plus_uri = intrinsic_defs.GENERIC_PLUS.uri
  with self.assertRaises(TypeError):
    transformations.replace_intrinsic_with_callable(
        None, plus_uri, lambda x: 100, context_stack_impl.context_stack)
def test_replace_intrinsic_raises_type_error_none_comp(self):
  """Passing `None` as the computation raises a TypeError."""
  identity = lambda x: x
  with self.assertRaises(TypeError):
    transformations.replace_intrinsic_with_callable(
        None, 'dummy', identity, context_stack_impl.context_stack)
def test_replace_intrinsic_replaces_multiple_generic_plus_intrinsics(self):
  """Replaces every GENERIC_PLUS in a chain of ten nested add-one calls.

  Named distinctly from the dummy-intrinsic `..._multiple_intrinsics` test
  defined later in the class. Under the shared name, the later definition
  rebound the attribute and this test was never collected.
  """
  calling_arg = computation_building_blocks.Reference('arg', tf.int32)
  arg_type = calling_arg.type_signature
  arg = calling_arg
  # Wrap `arg` in ten successive add-one calls; each call's result type is
  # the parameter type of the next lambda in the chain.
  for _ in range(10):
    lam = _create_lambda_to_add_one(arg_type)
    call = computation_building_blocks.Call(lam, arg)
    arg_type = call.function.type_signature.result
    arg = call
  comp = computation_building_blocks.Lambda(
      calling_arg.name, calling_arg.type_signature, call)
  uri = intrinsic_defs.GENERIC_PLUS.uri
  body = lambda x: 100

  self.assertEqual(_get_number_of_intrinsics(comp, uri), 10)
  comp_impl = _to_comp(comp)
  self.assertEqual(comp_impl(1), 11)

  transformed_comp = transformations.replace_intrinsic_with_callable(
      comp, uri, body, context_stack_impl.context_stack)

  self.assertEqual(_get_number_of_intrinsics(transformed_comp, uri), 0)
  transformed_comp_impl = _to_comp(transformed_comp)
  self.assertEqual(transformed_comp_impl(1), 100)
def test_replace_intrinsic_does_not_replace_other_intrinsic(self):
  """An intrinsic whose URI differs from the requested one is left intact."""
  original = _create_lambda_to_dummy_intrinsic(tf.int32)
  result = transformations.replace_intrinsic_with_callable(
      original, 'other', lambda x: x, context_stack_impl.context_stack)
  expected_repr = '(arg -> dummy(arg))'
  self.assertEqual(original.tff_repr, expected_repr)
  self.assertEqual(result.tff_repr, expected_repr)
def test_replace_intrinsic_replaces_intrinsic(self):
  """A single GENERIC_PLUS is swapped for a body that always yields 100."""
  add_one = _create_lambda_to_add_one(tf.int32)
  plus_uri = intrinsic_defs.GENERIC_PLUS.uri

  # Before: exactly one GENERIC_PLUS, and the lambda adds one.
  self.assertEqual(_get_number_of_intrinsics(add_one, plus_uri), 1)
  self.assertEqual(_to_comp(add_one)(1), 2)

  replaced = transformations.replace_intrinsic_with_callable(
      add_one, plus_uri, lambda x: 100, context_stack_impl.context_stack)

  # After: no GENERIC_PLUS remains, and the replacement body is in effect.
  self.assertEqual(_get_number_of_intrinsics(replaced, plus_uri), 0)
  self.assertEqual(_to_comp(replaced)(1), 100)
def test_replace_intrinsic_replaces_nested_intrinsic(self):
  """An intrinsic nested inside a block's result expression is replaced."""
  block = _create_dummy_block(_create_lambda_to_dummy_intrinsic(tf.int32))
  transformed = transformations.replace_intrinsic_with_callable(
      block, 'dummy', lambda x: x, context_stack_impl.context_stack)
  self.assertEqual(block.tff_repr, '(let local=data in (arg -> dummy(arg)))')
  self.assertEqual(
      transformed.tff_repr,
      '(let local=data in (arg -> (dummy_arg -> dummy_arg)(arg)))')
def test_replace_intrinsic_replaces_multiple_intrinsics(self):
  """Every occurrence of the intrinsic in a chained call is replaced."""
  dummy_fn = _create_lambda_to_dummy_intrinsic(tf.int32)
  data = computation_building_blocks.Data('x', tf.int32)
  chained = _create_chained_call(dummy_fn, data, 2)
  transformed = transformations.replace_intrinsic_with_callable(
      chained, 'dummy', lambda x: x, context_stack_impl.context_stack)
  self.assertEqual(chained.tff_repr,
                   '(arg -> dummy(arg))((arg -> dummy(arg))(x))')
  expected = '(arg -> (dummy_arg -> dummy_arg)(arg))((arg -> (dummy_arg -> dummy_arg)(arg))(x))'
  self.assertEqual(transformed.tff_repr, expected)
def compile(self, computation_to_compile):
  """Compiles `computation_to_compile`.

  Args:
    computation_to_compile: An instance of `computation_base.Computation` to
      compile.

  Returns:
    An instance of `computation_base.Computation` that represents the result.
  """
  py_typecheck.check_type(computation_to_compile, computation_base.Computation)
  proto = computation_impl.ComputationImpl.get_proto(computation_to_compile)

  # TODO(b/113123410): Add a compiler options argument that characterizes the
  # desired form of the output. To be driven by what the specific backend the
  # pipeline is targeting is able to understand. Pending a more fleshed out
  # design of the backend API.

  py_typecheck.check_type(proto, pb.Computation)
  building_block = (
      computation_building_blocks.ComputationBuildingBlock.from_proto(proto))

  # Replace intrinsics with their bodies, for now manually in a fixed order.
  # TODO(b/113123410): Replace this with a more automated implementation that
  # does not rely on manual maintenance.
  # NOTE(review): the result is unpacked as a pair here, but the tests in this
  # file bind `replace_intrinsic_with_callable`'s result to a single name and
  # read `.tff_repr` from it -- confirm which return signature is current.
  for uri, body in six.iteritems(self._intrinsic_bodies):
    building_block, _ = transformations.replace_intrinsic_with_callable(
        building_block, uri, body, self._context_stack)

  # Replaces called lambdas with LET constructs with a single local symbol.
  building_block, _ = transformations.replace_called_lambda_with_block(
      building_block)

  # TODO(b/113123410): Add more transformations to simplify and optimize the
  # structure, e.g., such as:
  # * removing unnecessary lambdas,
  # * flattening the structure,
  # * merging TensorFlow blocks where appropriate,
  # * ...and so on.
  return computation_impl.ComputationImpl(building_block.proto,
                                          self._context_stack)