コード例 #1
0
    def test_flatten_state(self):
        """Unroll a flatten_state cell and check output vs. flattened state shapes.

        Even steps feed the 5-channel input through pre-bottleneck branch 0 and
        odd steps feed the 3-channel input through branch 1. After the unroll
        the output must keep its spatial layout, while both state tensors must
        be flattened to ``(batch, height * width * num_units)``.
        """
        filter_size = [3, 3]
        output_size = [10, 10]
        num_units = 16
        batch_size = 4
        unroll = 10
        dtype = tf.float32
        use_learned_state = False

        # One input per pre-bottleneck branch; only the channel depth differs.
        inputs_large = tf.zeros([batch_size] + output_size + [5], dtype=dtype)
        inputs_small = tf.zeros([batch_size] + output_size + [3], dtype=dtype)
        cell = lstm_cells.GroupedConvLSTMCell(
            filter_size=filter_size,
            output_size=output_size,
            num_units=num_units,
            is_training=True,
            pre_bottleneck=True,
            flatten_state=True)
        state = cell.init_state('lstm_state', batch_size, dtype,
                                use_learned_state)

        branch_inputs = (inputs_large, inputs_small)
        for step in range(unroll):
            # Alternate branches: index 0 (large) on even steps, 1 (small) on odd.
            branch = step % 2
            inputs = cell.pre_bottleneck(branch_inputs[branch], state[1],
                                         branch)
            output, state = cell(inputs, state)

        spatial_shape = [batch_size] + output_size + [num_units]
        flat_shape = [batch_size, output_size[0] * output_size[1] * num_units]
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            output_result, state_result = sess.run([output, state])
            self.assertAllEqual(spatial_shape, output_result.shape)
            for index in (0, 1):
                self.assertAllEqual(flat_shape, state_result[index].shape)
コード例 #2
0
    def test_prebottleneck(self):
        """Unroll a pre-bottleneck cell on alternating inputs; check static shapes.

        Even steps use the 5-channel input with branch index 0, odd steps the
        3-channel input with branch index 1. The final output and both state
        tensors must have the spatial shape ``[batch, h, w, num_units]``.
        """
        filter_size = [3, 3]
        output_size = [10, 10]
        num_units = 16
        batch_size = 4
        unroll = 10
        dtype = tf.float32
        use_learned_state = False

        inputs_large = tf.zeros([batch_size] + output_size + [5], dtype=dtype)
        inputs_small = tf.zeros([batch_size] + output_size + [3], dtype=dtype)
        cell = lstm_cells.GroupedConvLSTMCell(
            filter_size=filter_size,
            output_size=output_size,
            num_units=num_units,
            is_training=True,
            pre_bottleneck=True)
        state = cell.init_state('lstm_state', batch_size, dtype,
                                use_learned_state)

        # (input tensor, pre-bottleneck branch index), selected by step parity.
        branches = [(inputs_large, 0), (inputs_small, 1)]
        for step in range(unroll):
            branch_input, branch_index = branches[step % 2]
            inputs = cell.pre_bottleneck(branch_input, state[1], branch_index)
            output, state = cell(inputs, state)

        expected_shape = [batch_size] + output_size + [num_units]
        self.assertAllEqual(expected_shape, output.shape.as_list())
        for index in (0, 1):
            self.assertAllEqual(expected_shape, state[index].shape.as_list())
    def create_lstm_cell(self,
                         batch_size,
                         output_size,
                         state_saver,
                         state_name,
                         dtype=tf.float32):
        """Create the LSTM cell, and initialize state if necessary.

        Args:
          batch_size: input batch size.
          output_size: output size of the lstm cell, [width, height].
          state_saver: a state saver object with methods `state` and
            `save_state`, or None to use the cell's own initial state.
          state_name: string, the name to use with the state_saver.
          dtype: dtype to initialize lstm state.

        Returns:
          lstm_cell: the lstm cell unit.
          init_state: initial state representations. Without a state saver
            this comes from `lstm_cell.init_state`. Otherwise it is the
            `(c, h)` pair read from the saver, with the leading dimension
            set to `batch_size`.
          step: the step. None without a state saver, otherwise the saver's
            `<state_name>_step` state.
        """
        num_units = max(self._min_depth, self._lstm_state_depth)
        lstm_cell = lstm_cells.GroupedConvLSTMCell(
            filter_size=(3, 3),
            output_size=output_size,
            num_units=num_units,
            is_training=self._is_training,
            activation=tf.nn.relu6,
            flatten_state=self._flatten_state,
            scale_state=self._scale_state,
            clip_state=self._clip_state,
            output_bottleneck=True,
            pre_bottleneck=self._pre_bottleneck,
            is_quantized=self._is_quantized,
            visualize_gates=False)

        if state_saver is None:
            # The literal 'lstm_state' (not state_name) is what is passed here.
            fresh_state = lstm_cell.init_state('lstm_state', batch_size, dtype)
            return lstm_cell, fresh_state, None

        step = state_saver.state(state_name + '_step')
        c, h = [state_saver.state(state_name + suffix)
                for suffix in ('_c', '_h')]
        for saved in (c, h):
            # Pin the leading (batch) dimension; keep the remaining dims as-is.
            saved.set_shape([batch_size] + saved.get_shape().as_list()[1:])
        return lstm_cell, (c, h), step
コード例 #4
0
    def test_get_init_learned_state(self):
        """A learned initial state has dtype float32 and spatial state shape."""
        output_size = [10, 10]
        num_units = 16
        batch_size = 4
        dtype = tf.float32
        use_learned_state = True

        cell = lstm_cells.GroupedConvLSTMCell(
            filter_size=[3, 3],
            output_size=output_size,
            num_units=num_units,
            is_training=True)
        init_c, init_h = cell.init_state('lstm_state', batch_size, dtype,
                                         use_learned_state)

        expected_shape = [batch_size] + output_size + [num_units]
        # Check both dtypes first, then both shapes (same order as before).
        for tensor in (init_c, init_h):
            self.assertEqual(tf.float32, tensor.dtype)
        for tensor in (init_c, init_h):
            self.assertAllEqual(expected_shape, tensor.shape.as_list())
コード例 #5
0
    def test_run_lstm_cell(self):
        """One cell step on a zero input yields spatial output and state shapes."""
        output_size = [10, 10]
        num_units = 16
        batch_size = 4
        dtype = tf.float32
        use_learned_state = False

        inputs = tf.zeros([batch_size] + output_size + [3], dtype=dtype)
        cell = lstm_cells.GroupedConvLSTMCell(
            filter_size=[3, 3],
            output_size=output_size,
            num_units=num_units,
            is_training=True)
        start_state = cell.init_state('lstm_state', batch_size, dtype,
                                      use_learned_state)
        output, next_state = cell(inputs, start_state)

        expected_shape = [batch_size] + output_size + [num_units]
        self.assertAllEqual(expected_shape, output.shape.as_list())
        for index in (0, 1):
            self.assertAllEqual(expected_shape,
                                next_state[index].shape.as_list())
コード例 #6
0
    def test_get_init_state(self):
        """A non-learned initial state is float32 and evaluates to all zeros."""
        side = 10
        output_size = [side, side]
        num_units = 16
        batch_size = 4
        dtype = tf.float32
        use_learned_state = False

        cell = lstm_cells.GroupedConvLSTMCell(
            filter_size=[3, 3],
            output_size=output_size,
            num_units=num_units,
            is_training=True)
        init_c, init_h = cell.init_state('lstm_state', batch_size, dtype,
                                         use_learned_state)

        self.assertEqual(tf.float32, init_c.dtype)
        self.assertEqual(tf.float32, init_h.dtype)
        expected = np.zeros([batch_size] + output_size + [num_units])
        with self.test_session() as sess:
            c_value, h_value = sess.run([init_c, init_h])
            self.assertAllClose(expected, c_value)
            self.assertAllClose(expected, h_value)