def test_flatten_state(self):
    """Unrolls a pre-bottleneck cell built with flatten_state=True.

    Even steps pass the 5-channel input through pre-bottleneck 0; odd steps
    pass the 3-channel input through pre-bottleneck 1. After evaluation the
    output keeps its (batch, rows, cols, units) layout while both state
    tensors come back flattened to
    (batch, output_size[0] * output_size[1] * num_units).
    """
    filter_size = [3, 3]
    output_size = [10, 10]
    num_units = 16
    state_name = 'lstm_state'
    batch_size = 4
    dtype = tf.float32
    unroll = 10
    learned_state = False

    # Entry i is always fed through pre-bottleneck index i.
    step_inputs = (
        tf.zeros([batch_size] + output_size + [5], dtype=dtype),
        tf.zeros([batch_size] + output_size + [3], dtype=dtype),
    )
    cell = lstm_cells.GroupedConvLSTMCell(
        filter_size=filter_size,
        output_size=output_size,
        num_units=num_units,
        is_training=True,
        pre_bottleneck=True,
        flatten_state=True)
    state = cell.init_state(state_name, batch_size, dtype, learned_state)

    for step in range(unroll):
        bottleneck_index = step % 2
        inputs = cell.pre_bottleneck(
            step_inputs[bottleneck_index], state[1], bottleneck_index)
        output, state = cell(inputs, state)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        output_result, state_result = sess.run([output, state])

    flat_state_size = output_size[0] * output_size[1] * num_units
    self.assertAllEqual(
        (batch_size, output_size[0], output_size[1], num_units),
        output_result.shape)
    self.assertAllEqual((batch_size, flat_state_size), state_result[0].shape)
    self.assertAllEqual((batch_size, flat_state_size), state_result[1].shape)
def test_prebottleneck(self):
    """Unrolls a pre-bottleneck cell and checks the static result shapes.

    Inputs alternate between a 5-channel tensor (pre-bottleneck 0, even
    steps) and a 3-channel tensor (pre-bottleneck 1, odd steps). The output
    and both state tensors must have static shape
    [batch, output_size[0], output_size[1], num_units].
    """
    filter_size = [3, 3]
    output_size = [10, 10]
    num_units = 16
    state_name = 'lstm_state'
    batch_size = 4
    dtype = tf.float32
    unroll = 10
    learned_state = False

    # Entry i is always fed through pre-bottleneck index i.
    step_inputs = (
        tf.zeros([batch_size] + output_size + [5], dtype=dtype),
        tf.zeros([batch_size] + output_size + [3], dtype=dtype),
    )
    cell = lstm_cells.GroupedConvLSTMCell(
        filter_size=filter_size,
        output_size=output_size,
        num_units=num_units,
        is_training=True,
        pre_bottleneck=True)
    state = cell.init_state(state_name, batch_size, dtype, learned_state)

    for step in range(unroll):
        bottleneck_index = step % 2
        inputs = cell.pre_bottleneck(
            step_inputs[bottleneck_index], state[1], bottleneck_index)
        output, state = cell(inputs, state)

    expected_shape = [batch_size] + output_size + [num_units]
    self.assertAllEqual(expected_shape, output.shape.as_list())
    self.assertAllEqual(expected_shape, state[0].shape.as_list())
    self.assertAllEqual(expected_shape, state[1].shape.as_list())
def create_lstm_cell(self, batch_size, output_size, state_saver, state_name,
                     dtype=tf.float32):
    """Builds the grouped conv LSTM cell and its initial state.

    Args:
      batch_size: input batch size.
      output_size: output size of the lstm cell, [width, height].
      state_saver: a state saver object with methods `state` and
        `save_state`, or None to start from the cell's own initial state.
      state_name: string prefix used to look up the `_step`, `_c` and `_h`
        entries in state_saver.
      dtype: dtype used when the cell creates its own initial state.

    Returns:
      lstm_cell: the GroupedConvLSTMCell instance.
      init_state: (c, h) tuple. Without a state saver this comes from
        `lstm_cell.init_state`; otherwise it is read from state_saver and
        each tensor gets its leading dimension statically set to batch_size.
      step: the `state_name + '_step'` tensor from state_saver, or None when
        state_saver is None.
    """
    lstm_cell = lstm_cells.GroupedConvLSTMCell(
        filter_size=(3, 3),
        output_size=output_size,
        num_units=max(self._min_depth, self._lstm_state_depth),
        is_training=self._is_training,
        activation=tf.nn.relu6,
        flatten_state=self._flatten_state,
        scale_state=self._scale_state,
        clip_state=self._clip_state,
        output_bottleneck=True,
        pre_bottleneck=self._pre_bottleneck,
        is_quantized=self._is_quantized,
        visualize_gates=False)

    if state_saver is None:
        # NOTE(review): state_name is not used on this path; the name is
        # fixed to 'lstm_state'. Kept as-is in case it matters for existing
        # graphs or checkpoints -- confirm before changing.
        return lstm_cell, lstm_cell.init_state('lstm_state', batch_size, dtype), None

    step = state_saver.state(state_name + '_step')
    saved_c, saved_h = [state_saver.state(state_name + suffix)
                        for suffix in ('_c', '_h')]
    for saved in (saved_c, saved_h):
        # The saver's first dimension may be unknown; pin it to batch_size.
        saved.set_shape([batch_size] + saved.get_shape().as_list()[1:])
    return lstm_cell, (saved_c, saved_h), step
def test_get_init_learned_state(self):
    """Checks dtype and static shape of a learned initial (c, h) state."""
    filter_size = [3, 3]
    output_size = [10, 10]
    num_units = 16
    state_name = 'lstm_state'
    batch_size = 4
    dtype = tf.float32
    learned_state = True

    cell = lstm_cells.GroupedConvLSTMCell(
        filter_size=filter_size,
        output_size=output_size,
        num_units=num_units,
        is_training=True)
    init_c, init_h = cell.init_state(
        state_name, batch_size, dtype, learned_state)

    expected_shape = [batch_size] + output_size + [num_units]
    for state_tensor in (init_c, init_h):
        self.assertEqual(tf.float32, state_tensor.dtype)
    for state_tensor in (init_c, init_h):
        self.assertAllEqual(expected_shape, state_tensor.shape.as_list())
def test_run_lstm_cell(self):
    """Runs one cell step and checks the static shapes of its results.

    A 3-channel zero input is fed through the cell once. The output and both
    entries of the returned state must have static shape
    [batch, output_size[0], output_size[1], num_units].
    """
    filter_size = [3, 3]
    output_size = [10, 10]
    num_units = 16
    state_name = 'lstm_state'
    batch_size = 4
    dtype = tf.float32
    learned_state = False

    inputs = tf.zeros([batch_size] + output_size + [3], dtype=dtype)
    cell = lstm_cells.GroupedConvLSTMCell(
        filter_size=filter_size,
        output_size=output_size,
        num_units=num_units,
        is_training=True)
    init_state = cell.init_state(state_name, batch_size, dtype, learned_state)
    output, state_tuple = cell(inputs, init_state)

    expected_shape = [batch_size] + output_size + [num_units]
    for result in (output, state_tuple[0], state_tuple[1]):
        self.assertAllEqual(expected_shape, result.shape.as_list())
def test_get_init_state(self):
    """Checks that a non-learned initial (c, h) state evaluates to zeros."""
    filter_size = [3, 3]
    output_dim = 10
    output_size = [output_dim] * 2
    num_units = 16
    state_name = 'lstm_state'
    batch_size = 4
    dtype = tf.float32
    learned_state = False

    cell = lstm_cells.GroupedConvLSTMCell(
        filter_size=filter_size,
        output_size=output_size,
        num_units=num_units,
        is_training=True)
    init_c, init_h = cell.init_state(
        state_name, batch_size, dtype, learned_state)
    self.assertEqual(tf.float32, init_c.dtype)
    self.assertEqual(tf.float32, init_h.dtype)

    expected = np.zeros((batch_size, output_dim, output_dim, num_units))
    with self.test_session() as sess:
        init_c_res, init_h_res = sess.run([init_c, init_h])
        self.assertAllClose(expected, init_c_res)
        self.assertAllClose(expected, init_h_res)