예제 #1
0
    def test_run_lstm_cell(self):
        """Call the cell once on an all-zero input and check result shapes.

        The output and both entries of the returned state tuple are expected
        to have shape [4, 10, 10, 15].
        """
        batch_size = 4
        num_units = 15
        dtype = tf.float32
        output_size = [10, 10]
        learned_state = False
        # [batch, height, width, units] assembled from the settings above.
        expected_shape = [batch_size] + output_size + [num_units]

        # Zero input with 3 channels in the last dimension.
        inputs = tf.zeros([batch_size] + output_size + [3], dtype=dtype)
        cell = lstm_cells.BottleneckConvLSTMCell(
            filter_size=[3, 3],
            output_size=output_size,
            num_units=num_units)
        init_state = cell.init_state(
            'lstm_state', batch_size, dtype, learned_state)
        output, state_tuple = cell(inputs, init_state)

        self.assertAllEqual(expected_shape, output.shape.as_list())
        # Both entries of the state tuple must match the output shape.
        for idx in range(2):
            self.assertAllEqual(expected_shape,
                                state_tuple[idx].shape.as_list())
예제 #2
0
    def test_get_init_learned_state(self):
        """Check dtype and shape of the initial state with learned_state=True.

        Both tensors returned by ``init_state`` are expected to be float32
        with shape [4, 10, 10, 15].
        """
        batch_size = 4
        num_units = 15
        dtype = tf.float32
        output_size = [10, 10]
        learned_state = True
        # [batch, height, width, units] assembled from the settings above.
        expected_shape = [batch_size] + output_size + [num_units]

        cell = lstm_cells.BottleneckConvLSTMCell(
            filter_size=[3, 3],
            output_size=output_size,
            num_units=num_units)
        initial_states = cell.init_state(
            'lstm_state', batch_size, dtype, learned_state)
        init_c, init_h = initial_states

        # Dtypes first, then shapes, for each of the two state tensors.
        for state in (init_c, init_h):
            self.assertEqual(tf.float32, state.dtype)
        for state in (init_c, init_h):
            self.assertAllEqual(expected_shape, state.shape.as_list())
예제 #3
0
    def test_get_init_state(self):
        """Check the initial state with learned_state=False is all zeros.

        Both tensors returned by ``init_state`` are expected to be float32
        and to evaluate to zero arrays of shape (4, 10, 10, 15).
        """
        batch_size = 4
        num_units = 15
        output_dim = 10
        dtype = tf.float32
        learned_state = False
        output_size = [output_dim, output_dim]

        cell = lstm_cells.BottleneckConvLSTMCell(
            filter_size=[3, 3],
            output_size=output_size,
            num_units=num_units)
        initial_states = cell.init_state(
            'lstm_state', batch_size, dtype, learned_state)
        init_c, init_h = initial_states

        for state in (init_c, init_h):
            self.assertEqual(tf.float32, state.dtype)

        # Evaluate both tensors in one run and compare against zeros.
        with self.test_session() as sess:
            init_c_res, init_h_res = sess.run([init_c, init_h])
            expected = np.zeros(
                (batch_size, output_dim, output_dim, num_units))
            self.assertAllClose(expected, init_c_res)
            self.assertAllClose(expected, init_h_res)