Beispiel #1
0
 def _testDropoutWrapper(self, batch_size=None, time_steps=None,
                         parallel_iterations=None, **kwargs):
   """Runs a DropoutWrapper around a 3-unit LSTMCell through dynamic_rnn.

   Args:
     batch_size: Batch dimension of the input. When both `batch_size` and
       `time_steps` are None, a fixed input with batch size 1 and 2 time
       steps is used; otherwise a random input of depth 3 is generated.
     time_steps: Number of time steps of the (time-major) input.
     parallel_iterations: Forwarded unchanged to `rnn.dynamic_rnn`.
     **kwargs: Forwarded unchanged to `rnn_cell_impl.DropoutWrapper`.

   Returns:
     The list `[outputs, final_state]` as evaluated by the session.
   """
   with self.test_session() as sess:
     with variable_scope.variable_scope(
         "root", initializer=init_ops.constant_initializer(0.5)):
       if batch_size is None and time_steps is None:
         # Default fixture: 2 time steps, batch size 1, depth 3.
         batch_size, time_steps = 1, 2
         inputs = constant_op.constant(
             [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
       else:
         inputs = constant_op.constant(
             np.random.randn(time_steps, batch_size, 3).astype(np.float32))

       # A single constant is used for both the c and h parts of the state.
       state_part = constant_op.constant(
           [[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
       initial_state = rnn_cell_impl.LSTMStateTuple(state_part, state_part)

       wrapped_cell = rnn_cell_impl.DropoutWrapper(
           rnn_cell_impl.LSTMCell(3), dtype=inputs.dtype, **kwargs)
       outputs, final_state = rnn.dynamic_rnn(
           cell=wrapped_cell,
           time_major=True,
           parallel_iterations=parallel_iterations,
           inputs=inputs,
           initial_state=initial_state)

       sess.run(variables_lib.global_variables_initializer())
       res = sess.run([outputs, final_state])
       output_values, state_values = res

       # Shape checks: time-major outputs, and (batch, depth) for c and h.
       expected_state_shape = (batch_size, 3)
       self.assertEqual(output_values.shape, (time_steps, batch_size, 3))
       self.assertEqual(state_values.c.shape, expected_state_shape)
       self.assertEqual(state_values.h.shape, expected_state_shape)
       return res
  def testDropoutWrapperKerasStyle(self):
    """Tests if DropoutWrapperV2 cell is instantiated in keras style scope."""
    wrapped_cell_v2 = rnn_cell_impl.DropoutWrapperV2(
        rnn_cell_impl.BasicRNNCell(1))
    self.assertTrue(wrapped_cell_v2._keras_style)

    wrapped_cell = rnn_cell_impl.DropoutWrapper(rnn_cell_impl.BasicRNNCell(1))
    self.assertFalse(wrapped_cell._keras_style)
Beispiel #3
0
def _build_multi_lstm_cell(num_units,
                           num_layers,
                           train_test_predict,
                           keep_prob=1.0):
    """Build a stacked LSTM cell with optional output dropout.

    Each layer gets its own `BasicLSTMCell` instance. Placing one instance
    in every slot of the `MultiRNNCell` would make all layers refer to the
    same cell object rather than independent layers.

    Args:
        num_units: Number of units passed to each `BasicLSTMCell`.
        num_layers: Number of stacked layers.
        train_test_predict: Mode string. Only the value 'train' creates
            variables (`reuse=False`) and enables dropout; any other value
            builds the cells with `reuse=True`.
        keep_prob: Output keep probability. Dropout is applied only in
            'train' mode and only when this is below 1.0.

    Returns:
        A `rnn_cell_impl.MultiRNNCell` wrapping `num_layers` distinct cells.
    """
    is_training = train_test_predict == 'train'
    use_dropout = is_training and keep_prob < 1.0

    def _make_layer():
        # A fresh cell per call, so no two layers share a cell object.
        layer = rnn_cell_impl.BasicLSTMCell(num_units, reuse=not is_training)
        if use_dropout:
            layer = rnn_cell_impl.DropoutWrapper(
                layer, output_keep_prob=keep_prob)
        return layer

    cells = [_make_layer() for _ in range(num_layers)]
    return rnn_cell_impl.MultiRNNCell(cells)
Beispiel #4
0
    def get_rnncell(cell_type, cell_size, keep_prob, num_layer):
        """Build a (possibly stacked) GRU or LSTM cell with optional dropout.

        When stacking, every layer receives its own cell instance. Repeating
        one instance with `[cell] * num_layer` would place the very same cell
        object in every layer instead of creating independent layers.

        Args:
            cell_type: "gru" selects `GRUCell`; any other value selects
                `LSTMCell` (without peepholes, forget bias 1.0).
            cell_size: Number of units passed to the cell constructor.
            keep_prob: Output keep probability; a `DropoutWrapper` is added
                around each cell only when this is below 1.0.
            num_layer: Number of layers. Values above 1 produce a
                `MultiRNNCell`; otherwise a single cell is returned.

        Returns:
            A single cell, or a `MultiRNNCell` of `num_layer` distinct cells.
        """
        def _make_cell():
            # Construct one independent layer, including its dropout wrapper.
            if cell_type == "gru":
                layer = rnn_cell.GRUCell(cell_size)
            else:
                layer = rnn_cell.LSTMCell(cell_size,
                                          use_peepholes=False,
                                          forget_bias=1.0)

            if keep_prob < 1.0:
                layer = rnn_cell.DropoutWrapper(layer,
                                                output_keep_prob=keep_prob)
            return layer

        if num_layer > 1:
            return rnn_cell.MultiRNNCell(
                [_make_cell() for _ in range(num_layer)],
                state_is_tuple=True)

        return _make_cell()
Beispiel #5
0
 def _rnn_cell(cell_type,
               cell_size,
               is_drop_out=True,
               _dropout_in=1.0,
               _dropout_out=0.5,
               use_peepholes=True,
               activation=None):
     """Instantiate an RNN cell class, optionally wrapped with dropout.

     `cell_type` is called as a constructor, so it is a class (or factory),
     not an instance. The LSTM check therefore has to be a subclass test;
     an `isinstance` test against `LSTMCell` does not match a class object,
     which meant `use_peepholes` was never forwarded.

     Args:
         cell_type: Cell class or factory, called with `num_units` and
             `activation` (plus `use_peepholes` for `LSTMCell` subclasses).
         cell_size: Value passed as `num_units`.
         is_drop_out: When true, wrap the cell in a `DropoutWrapper`.
         _dropout_in: Input keep probability for the wrapper.
         _dropout_out: Output keep probability for the wrapper.
         use_peepholes: Forwarded only when `cell_type` is `LSTMCell` or a
             subclass of it.
         activation: Forwarded to the cell constructor.

     Returns:
         The constructed cell, wrapped in `DropoutWrapper` if requested.
     """
     # issubclass raises TypeError for non-class callables, so guard it.
     is_lstm_class = (isinstance(cell_type, type) and
                      issubclass(cell_type, rnn_cell_impl.LSTMCell))
     if is_lstm_class:
         cell = cell_type(num_units=cell_size,
                          use_peepholes=use_peepholes,
                          activation=activation)
     else:
         cell = cell_type(num_units=cell_size, activation=activation)
     if is_drop_out:
         cell = rnn_cell_impl.DropoutWrapper(cell=cell,
                                             input_keep_prob=_dropout_in,
                                             output_keep_prob=_dropout_out)
     return cell
Beispiel #6
0
    def get_rnncell(cell_type, cell_size, keep_prob, num_layer):
        """Build `num_layer` independent GRU/LSTM cells and stack them if needed.

        Args:
            cell_type: "gru" selects `GRUCell`; anything else selects
                `LSTMCell` (no peepholes, forget bias 1.0).
            cell_size: Number of units passed to the cell constructor.
            keep_prob: Output keep probability; each cell is wrapped in a
                `DropoutWrapper` only when this is below 1.0.
            num_layer: Number of cells to build.

        Returns:
            A `MultiRNNCell` over all cells when `num_layer > 1`, otherwise
            the first (only) cell that was built.
        """
        # Per-layer construction: thanks for this solution from @dimeldo
        def _single_layer():
            if cell_type == "gru":
                layer = rnn_cell.GRUCell(cell_size)
            else:
                layer = rnn_cell.LSTMCell(
                    cell_size, use_peepholes=False, forget_bias=1.0)
            if keep_prob < 1.0:
                return rnn_cell.DropoutWrapper(
                    layer, output_keep_prob=keep_prob)
            return layer

        layers = [_single_layer() for _ in range(num_layer)]

        if num_layer <= 1:
            # Indexing keeps the failure on an empty layer list.
            return layers[0]
        return rnn_cell.MultiRNNCell(layers, state_is_tuple=True)
Beispiel #7
0
 def testWrappedCellProperty(self):
     """Checks that DropoutWrapper exposes the wrapped cell as `wrapped_cell`."""
     inner_cell = rnn_cell_impl.BasicRNNCell(10)
     dropout_cell = rnn_cell_impl.DropoutWrapper(inner_cell)
     # Regression check for Github issue 15810.
     self.assertEqual(dropout_cell.wrapped_cell, inner_cell)