def initialize_graph(self, model, input_statistics=None):
    """Adds required operations to the graph.

    Runs the parent class's graph initialization, records the model's start
    state on ``self._start_state``, and creates ``self._cached_states``: a
    ``math_utils.TupleOfTensorsLookup`` keyed by int64 whose default value
    for any missing key is that start state.

    Args:
      model: Object exposing ``get_start_state()``; also forwarded to the
        parent ``initialize_graph``.
      input_statistics: Optional value forwarded unchanged to the parent
        ``initialize_graph``.
    """
    super(ChainingStateManager, self).initialize_graph(
        model=model, input_statistics=input_statistics)
    self._start_state = model.get_start_state()
    # Lookups of keys that were never inserted fall back to the start state.
    # The table takes two sentinel keys, presumably marking empty (-1) and
    # deleted (-2) slots; they must be distinct and must never collide with
    # a real key.
    # NOTE(review): assumes stored keys are non-negative int64 values -- confirm.
    self._cached_states = math_utils.TupleOfTensorsLookup(
        key_dtype=dtypes.int64,
        default_values=self._start_state,
        empty_key=-1,
        deleted_key=-2,
        name="cached_states",
        checkpoint=self._checkpoint_state)
def test_tuple_of_tensors_lookup(self):
    """Checks TupleOfTensorsLookup results before and after an insert.

    The table holds a nested value ``((float32 [3, 2], float64 [5]),
    int64 [7, 7])``. With nothing inserted, keys 2, 1 and 0 all return the
    defaults. After keys 1 and 2 are inserted with ``base + 1`` and
    ``base + 2``, the same query returns ``base + 2``, ``base + 1`` and the
    untouched default for key 0.
    """

    def default_structure():
        """Builds fresh nested tensors shaped like the table's defaults."""
        return [[array_ops.ones([3, 2], dtype=dtypes.float32),
                 array_ops.zeros([5], dtype=dtypes.float64)],
                array_ops.ones([7, 7], dtype=dtypes.int64)]

    def stacked(base):
        # Row 0 becomes the value for key 1, row 1 the value for key 2.
        return array_ops.stack([base + 1, base + 2])

    def verify(fetched, offsets):
        """Asserts each fetched component equals [base + o for o in offsets]."""
        (float_value, double_value), int_value = fetched
        numpy_bases = (numpy.ones([3, 2]),
                       numpy.zeros([5]),
                       numpy.ones([7, 7], dtype=numpy.int64))
        expected = [[base + offset for offset in offsets]
                    for base in numpy_bases]
        self.assertAllClose(expected[0], float_value)
        self.assertAllClose(expected[1], double_value)
        self.assertAllEqual(expected[2], int_value)

    query_keys = [2, 1, 0]
    hash_table = math_utils.TupleOfTensorsLookup(
        key_dtype=dtypes.int64,
        default_values=default_structure(),
        empty_key=-1,
        deleted_key=-2,
        name="test_lookup")

    with self.cached_session() as session:
        # Nothing inserted yet: every queried key yields the defaults.
        verify(session.run(hash_table.lookup(query_keys)), (0, 0, 0))

        (float_base, double_base), int_base = default_structure()
        hash_table.insert(
            keys=[1, 2],
            values=[[stacked(float_base), stacked(double_base)],
                    stacked(int_base)]).run()

        # Keys 2 and 1 now map to base + 2 and base + 1; key 0 was never
        # inserted and still yields the defaults.
        verify(session.run(hash_table.lookup(query_keys)), (2, 1, 0))