Пример #1
0
 def _testDropoutWrapper(self, batch_size=None, time_steps=None,
                         parallel_iterations=None, **kwargs):
   """Runs DropoutWrapper(LSTMCell(3)) through dynamic_rnn and checks shapes.

   Args:
     batch_size: Batch dimension of the time-major input. Must be supplied
       together with `time_steps`. When both are None, a fixed input with
       2 time steps, batch size 1 and depth 3 is used.
     time_steps: Number of time steps of the input. Must be supplied together
       with `batch_size`; the input is then drawn from `np.random.randn` with
       depth 3.
     parallel_iterations: Forwarded to `rnn.dynamic_rnn`.
     **kwargs: Forwarded to the `DropoutWrapper` constructor.

   Returns:
     The evaluated `[outputs, final_state]`; `outputs` has shape
     `(time_steps, batch_size, 3)` and `final_state.c` / `final_state.h`
     have shape `(batch_size, 3)`.

   Raises:
     ValueError: If exactly one of `batch_size` and `time_steps` is None.
   """
   if (batch_size is None) != (time_steps is None):
     # Without this check a half-specified shape crashes later with an opaque
     # TypeError from np.random.randn / list repetition on None.
     raise ValueError(
         "batch_size and time_steps must both be set or both be None "
         "(got batch_size=%r, time_steps=%r)" % (batch_size, time_steps))
   with self.test_session() as sess:
     with variable_scope.variable_scope(
         "root", initializer=init_ops.constant_initializer(0.5)):
       if batch_size is None and time_steps is None:
         # 2 time steps, batch size 1, depth 3
         batch_size = 1
         time_steps = 2
         x = constant_op.constant(
             [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
         # Initial LSTM state: the same 0.1-filled tensor is used for c and h.
         m = core_rnn_cell_impl.LSTMStateTuple(
             *[constant_op.constant([[0.1, 0.1, 0.1]],
                                    dtype=dtypes.float32)] * 2)
       else:
         x = constant_op.constant(
             np.random.randn(time_steps, batch_size, 3).astype(np.float32))
         m = core_rnn_cell_impl.LSTMStateTuple(
             *[constant_op.constant([[0.1, 0.1, 0.1]] * batch_size,
                                    dtype=dtypes.float32)] * 2)
       outputs, final_state = rnn.dynamic_rnn(
           cell=core_rnn_cell_impl.DropoutWrapper(
               core_rnn_cell_impl.LSTMCell(3),
               dtype=x.dtype,
               **kwargs),
           time_major=True,
           parallel_iterations=parallel_iterations,
           inputs=x, initial_state=m)
       sess.run([variables_lib.global_variables_initializer()])
       res = sess.run([outputs, final_state])
       self.assertEqual(res[0].shape, (time_steps, batch_size, 3))
       self.assertEqual(res[1].c.shape, (batch_size, 3))
       self.assertEqual(res[1].h.shape, (batch_size, 3))
       return res
Пример #2
0
    def get_rnncell(cell_type, cell_size, keep_prob, num_layer):
        """Build a (possibly dropout-wrapped, possibly stacked) RNN cell.

        Args:
            cell_type: "gru" selects ``rnn_cell.GRUCell``; any other value
                selects ``rnn_cell.LSTMCell`` (no peepholes, forget_bias=1.0).
            cell_size: Number of units in each cell.
            keep_prob: Output keep probability; each cell is wrapped in a
                ``DropoutWrapper`` only when this is below 1.0.
            num_layer: Number of layers; more than one stacks independent
                cells in a ``MultiRNNCell`` with tuple state.

        Returns:
            A single cell when ``num_layer <= 1``, otherwise a ``MultiRNNCell``.
        """
        def _make_cell():
            # Build one fresh, optionally dropout-wrapped, cell per call.
            if cell_type == "gru":
                cell = rnn_cell.GRUCell(cell_size)
            else:
                cell = rnn_cell.LSTMCell(cell_size, use_peepholes=False, forget_bias=1.0)
            if keep_prob < 1.0:
                cell = rnn_cell.DropoutWrapper(cell, output_keep_prob=keep_prob)
            return cell

        if num_layer > 1:
            # Each layer needs its own cell instance: ``[cell] * num_layer``
            # repeats a single object, so every layer would share one set of
            # weights (newer TensorFlow versions reject this outright).
            return rnn_cell.MultiRNNCell(
                [_make_cell() for _ in range(num_layer)], state_is_tuple=True)

        return _make_cell()
 def testDropoutWrapper(self):
     """Smoke-tests a GRUCell(3) wrapped in DropoutWrapper with keep prob 1."""
     with self.test_session() as sess:
         initializer = init_ops.constant_initializer(0.5)
         with variable_scope.variable_scope("root", initializer=initializer):
             inputs = array_ops.zeros([1, 3])
             state = array_ops.zeros([1, 3])
             # Scalar keep-probability tensor equal to 1, passed positionally
             # as the wrapper's input and output keep probabilities.
             keep_prob = array_ops.zeros([]) + 1
             wrapped_cell = core_rnn_cell_impl.DropoutWrapper(
                 core_rnn_cell_impl.GRUCell(3), keep_prob, keep_prob)
             output, new_state = wrapped_cell(inputs, state)
             sess.run([variables_lib.global_variables_initializer()])
             # The zero tensors above are overridden by feeding them by name.
             feed = {
                 inputs.name: np.array([[1., 1., 1.]]),
                 state.name: np.array([[0.1, 0.1, 0.1]]),
             }
             output_val, state_val = sess.run([output, new_state], feed)
             self.assertEqual(state_val.shape, (1, 3))
             # Expected values were not derived analytically; smoke test only.
             self.assertAllClose(output_val, [[0.154605, 0.154605, 0.154605]])