def test_callable_object(self):
    """map_fn accepts any callable object, not only plain functions."""

    class _PassThrough(object):
        # A callable instance that hands back whatever it receives.
        def __call__(self, value):
            return value

    passthrough_fn = _PassThrough()
    mapped_lt = ops.map_fn(passthrough_fn, self.original_lt)
    self.assertLabeledTensorsEqual(mapped_lt, self.original_lt)
def test_string(self):
    """map_fn works element-wise over a string LabeledTensor along 'batch'."""

    def append_world(element_lt):
        # Join each scalar string element with the literal suffix 'world'.
        joined = string_ops.string_join([element_lt, 'world'])
        # Each per-element result is a scalar, hence no axes.
        return core.LabeledTensor(joined, [])

    source_lt = ops.constant(['hi', 'bye'], axes=['batch'])
    mapped_lt = ops.map_fn(append_world, source_lt)
    expected_lt = ops.constant(['hiworld', 'byeworld'], axes=['batch'])
    self.assertLabeledTensorsEqual(mapped_lt, expected_lt)
def test_slice(self):
    """Mapping a per-element slice matches slicing the whole tensor at once."""

    def pick_channel(labeled):
        # Select index 1 along the 'channel' axis.
        return core.slice_function(labeled, {'channel': 1})

    mapped_lt = ops.map_fn(pick_channel, self.original_lt)
    expected_lt = pick_channel(self.original_lt)
    self.assertLabeledTensorsEqual(mapped_lt, expected_lt)
def test_identity(self):
    """Mapping core.identity reproduces the input LabeledTensor."""
    expected_lt = self.original_lt
    mapped_lt = ops.map_fn(core.identity, expected_lt)
    self.assertLabeledTensorsEqual(mapped_lt, expected_lt)
def test_name(self):
    """The op produced by map_fn carries 'lt_map_fn' in its name."""
    mapped = ops.map_fn(core.identity, self.original_lt)
    op_name = mapped.name
    self.assertIn('lt_map_fn', op_name)