def testSparseInput(self):
    """Checks that the "sparse" signature accepts a SparseTensorValue input.

    The module spec is loaded with the {"special"} tag set. The input is a
    1-D sparse value with dense_shape [2], a single entry 1 at index 0
    (dense form [1, 0]). The expected output is [2, 0].
    """
    with module.eval_function_for_module(_ModuleSpec(),
                                         tags={"special"}) as f:
        # Use the tf.compat.v1 path, as the ragged test does: the top-level
        # tf.SparseTensorValue alias does not exist in the TF2 namespace.
        sparse_value = tf.compat.v1.SparseTensorValue(
            indices=[[0]], values=[1], dense_shape=[2])  # Value is [1, 0].
        self.assertAllEqual(f(sparse_value, signature="sparse"), [2, 0])
def testRaggedInput(self):
    """Checks that the "ragged" signature accepts a ragged value.

    A ragged_rank=2 value is built with tf.compat.v1.ragged.constant_value
    and fed to the module loaded with the {"special"} tag set. The result
    must equal the nested list spelled out in `expected`.
    """
    expected = [
        [[[2, 4, 6], [8, 10, 12]], [[14, 16, 18]]],
        [[[40, 42, 44], [60, 70, 76], [0, 4, 0]]],
    ]
    with module.eval_function_for_module(_ModuleSpec(),
                                         tags={"special"}) as ragged_fn:
        ragged_input = tf.compat.v1.ragged.constant_value(
            [
                [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]],
                [[[20, 21, 22], [30, 35, 38], [0, 2, 0]]],
            ],
            ragged_rank=2)
        actual = ragged_fn(ragged_input, signature="ragged")
        self.assertAllEqual(actual, expected)
def testExplicitSignatureAndTags(self):
    """Calls the "extra" signature under the {"special"} tags, as a dict.

    With inputs x=[1] and y=[2], the "default" output must be [2] and the
    "z" output must be [8].
    """
    inputs = {"x": [1], "y": [2]}
    with module.eval_function_for_module(_ModuleSpec(),
                                         tags={"special"}) as eval_fn:
        outputs = eval_fn(inputs, signature="extra", as_dict=True)
        self.assertAllEqual(outputs["default"], [2])
        self.assertAllEqual(outputs["z"], [8])
def testSignature(self):
    """Calls the module with no explicit signature or tags on a plain list.

    The input [1, 2] must produce [2, 4].
    """
    with module.eval_function_for_module(_ModuleSpec()) as default_fn:
        outputs = default_fn([1, 2])
        self.assertAllEqual(outputs, [2, 4])
def testDictOutput(self):
    """Checks that as_dict=True returns a dict keyed by output name.

    The module is called with the dict input {"x": [1, 2]} and no explicit
    signature. The result must be a dict whose only key is "default".
    """
    with module.eval_function_for_module(_ModuleSpec()) as f:
        result = f({"x": [1, 2]}, as_dict=True)
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)), which only reports "False is not true".
        self.assertIsInstance(result, dict)
        self.assertEqual(list(result), ["default"])
def testDictInput(self):
    """Calls the module with its input given as a dict keyed by name.

    The input {"x": [1, 2]} must produce [2, 4].
    """
    inputs = {"x": [1, 2]}
    with module.eval_function_for_module(_ModuleSpec()) as eval_fn:
        self.assertAllEqual(eval_fn(inputs), [2, 4])