Ejemplo n.º 1
0
 def initialize_graph(self, model, input_statistics=None):
     """Set up the graph operations this state manager relies on.

     Runs the parent class's graph initialization, records the model's start
     state, and builds a lookup table of cached states with int64 keys whose
     default values are that start state.

     Args:
       model: The model being managed; `model.get_start_state()` supplies both
         `self._start_state` and the table's default values.
       input_statistics: Optional; forwarded unchanged to the parent class's
         `initialize_graph`.
     """
     super(ChainingStateManager, self).initialize_graph(
         input_statistics=input_statistics, model=model)
     start_state = model.get_start_state()
     self._start_state = start_state
     # Lookups of keys that were never inserted fall back to the start state.
     # NOTE(review): other TupleOfTensorsLookup call sites (e.g. its unit test)
     # also pass `deleted_key=-2`; confirm whether the math_utils version in
     # use requires it here.
     self._cached_states = math_utils.TupleOfTensorsLookup(
         key_dtype=dtypes.int64,
         default_values=start_state,
         empty_key=-1,
         name="cached_states",
         checkpoint=self._checkpoint_state)
Ejemplo n.º 2
0
  def test_tuple_of_tensors_lookup(self):
    """Checks default values for missing keys, then inserted values.

    The table's values are a nested structure ((float32, float64), int64).
    Looking up keys [2, 1, 0] before any insert must yield the defaults for
    every key. After inserting keys [1, 2], those keys must yield their
    inserted values while key 0 still yields the default.
    """
    default_values = [
        [array_ops.ones([3, 2], dtype=dtypes.float32),
         array_ops.zeros([5], dtype=dtypes.float64)],
        array_ops.ones([7, 7], dtype=dtypes.int64),
    ]
    hash_table = math_utils.TupleOfTensorsLookup(
        key_dtype=dtypes.int64,
        default_values=default_values,
        empty_key=-1,
        deleted_key=-2,
        name="test_lookup")

    def two_versions(tensor):
      # Row 0 is tensor + 1 (for key 1), row 1 is tensor + 2 (for key 2).
      return array_ops.stack([tensor + 1, tensor + 2])

    def expected(base, offsets):
      # One expected entry per looked-up key, shifted by that key's offset.
      return [base + offset for offset in offsets]

    float_base = numpy.ones([3, 2])
    double_base = numpy.zeros([5])
    int_base = numpy.ones([7, 7], dtype=numpy.int64)

    with self.cached_session() as session:
      # Nothing inserted yet: all three keys resolve to the defaults.
      ((float_output, double_output), int_output) = session.run(
          hash_table.lookup([2, 1, 0]))
      self.assertAllClose(expected(float_base, (0, 0, 0)), float_output)
      self.assertAllClose(expected(double_base, (0, 0, 0)), double_output)
      self.assertAllEqual(expected(int_base, (0, 0, 0)), int_output)

      inserted_values = [
          [two_versions(array_ops.ones([3, 2], dtype=dtypes.float32)),
           two_versions(array_ops.zeros([5], dtype=dtypes.float64))],
          two_versions(array_ops.ones([7, 7], dtype=dtypes.int64)),
      ]
      hash_table.insert(keys=[1, 2], values=inserted_values).run()

      # Key 2 -> +2, key 1 -> +1, key 0 was never inserted -> default.
      ((float_output, double_output), int_output) = session.run(
          hash_table.lookup([2, 1, 0]))
      self.assertAllClose(expected(float_base, (2, 1, 0)), float_output)
      self.assertAllClose(expected(double_base, (2, 1, 0)), double_output)
      self.assertAllEqual(expected(int_base, (2, 1, 0)), int_output)