Пример #1
0
def test_static_dropout_rnn_cell():
    """Check that a dropout-wrapped GRU drops outputs only when training.

    Two GRU cells (hidden size 100) are built with ``rnn_cell_w_dropout``,
    passing 0.9999999999 as the dropout argument. Given the assertions below,
    that argument is presumably the drop probability (TODO confirm). One cell
    has ``training=True`` and the other ``training=False``. Both run over the
    same random input. The training cell must produce almost no non-zero
    outputs; the inference cell must produce many.
    """
    # A private graph keeps repeated runs in one process from colliding on
    # the "DropoutIsOn"/"DropoutIsOff" variable scopes in the default graph.
    graph = tf.Graph()
    x = np.random.randn(1, 10, 50).astype(np.float32)
    # dynamic_rnn wants an integer length vector; ``np.int`` no longer exists
    # in NumPy >= 1.24.
    seq_len = np.array([10], dtype=np.int32)
    with graph.as_default():
        with tf.variable_scope("DropoutIsOn"):
            rnn_drop_cell = rnn_cell_w_dropout(100,
                                               0.9999999999,
                                               'gru',
                                               training=True)
            rnn_drop, _ = tf.nn.dynamic_rnn(rnn_drop_cell,
                                            x,
                                            sequence_length=seq_len,
                                            dtype=tf.float32)
        with tf.variable_scope("DropoutIsOff"):
            rnn_no_drop_cell = rnn_cell_w_dropout(100,
                                                  0.9999999999,
                                                  'gru',
                                                  training=False)
            rnn_no_drop, _ = tf.nn.dynamic_rnn(rnn_no_drop_cell,
                                               x,
                                               sequence_length=seq_len,
                                               dtype=tf.float32)
        # Must be created inside the graph whose variables it initializes.
        init_op = tf.global_variables_initializer()

    # The context manager guarantees the session is closed, even on failure.
    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        # count_nonzero is safe for any number of survivors; the
        # len(arr[np.nonzero(arr)].squeeze()) form breaks on exactly one.
        out_ten = sess.run(rnn_drop)
        assert np.count_nonzero(out_ten) < 20
        out_ten = sess.run(rnn_no_drop)
        assert np.count_nonzero(out_ten) > 20
Пример #2
0
def test_static_dropout_rnn_cell():
    """Check, with ops pinned to the CPU, that dropout applies only when training.

    Two GRU cells (hidden size 100) are built with ``rnn_cell_w_dropout``,
    passing 0.9999999999 as the dropout argument. Given the assertions below,
    that argument is presumably the drop probability (TODO confirm). One cell
    has ``training=True`` and the other ``training=False``. The training cell
    must produce almost no non-zero outputs; the inference cell must produce
    many.
    """
    # A private graph keeps repeated runs from colliding on variable scopes.
    graph = tf.Graph()
    x = np.random.randn(1, 10, 50).astype(np.float32)
    # ``np.int`` no longer exists in NumPy >= 1.24; use an explicit int dtype.
    seq_len = np.array([10], dtype=np.int32)
    with graph.as_default():
        # The device scope belongs to the graph, so it is entered inside it.
        with tf.device('/cpu:0'):
            with tf.variable_scope("DropoutIsOn"):
                rnn_drop_cell = rnn_cell_w_dropout(100, 0.9999999999, 'gru', training=True)
                rnn_drop, _ = tf.nn.dynamic_rnn(rnn_drop_cell, x, sequence_length=seq_len,
                                                dtype=tf.float32)
            with tf.variable_scope("DropoutIsOff"):
                rnn_no_drop_cell = rnn_cell_w_dropout(100, 0.9999999999, 'gru', training=False)
                rnn_no_drop, _ = tf.nn.dynamic_rnn(rnn_no_drop_cell, x, sequence_length=seq_len,
                                                   dtype=tf.float32)
            init_op = tf.global_variables_initializer()

    # The context manager guarantees the session is closed, even on failure.
    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        # count_nonzero is safe for any number of survivors, unlike
        # len(arr[np.nonzero(arr)].squeeze()), which fails on exactly one.
        out_ten = sess.run(rnn_drop)
        assert np.count_nonzero(out_ten) < 20
        out_ten = sess.run(rnn_no_drop)
        assert np.count_nonzero(out_ten) > 20
Пример #3
0
def test_placeholder_dropout_rnn_cell():
    """Check that dropout can be toggled at run time through a boolean placeholder.

    A GRU cell (hidden size 100) is built with ``rnn_cell_w_dropout``, passing
    0.9999999999 as the dropout argument (presumably the drop probability,
    TODO confirm). Its ``training`` flag is a placeholder that defaults to
    False. Feeding True must leave almost no non-zero outputs; running without
    a feed must leave many.
    """
    # A private graph keeps repeated runs from colliding on variable scopes.
    graph = tf.Graph()
    x = np.random.randn(1, 10, 50).astype(np.float32)
    # ``np.int`` no longer exists in NumPy >= 1.24; use an explicit int dtype.
    seq_len = np.array([10], dtype=np.int32)
    with graph.as_default():
        train_flag = tf.placeholder_with_default(False, shape=(), name='TEST_TRAIN_FLAG')
        with tf.variable_scope("DropoutMightBeOn"):
            rnn_cell = rnn_cell_w_dropout(100, 0.9999999999, 'gru', training=train_flag)
            rnn, _ = tf.nn.dynamic_rnn(rnn_cell, x, sequence_length=seq_len, dtype=tf.float32)
        init_op = tf.global_variables_initializer()

    # The context manager guarantees the session is closed, even on failure.
    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        out_ten = sess.run(rnn, {train_flag: True})
        # count_nonzero is safe for any number of survivors, unlike
        # len(arr[np.nonzero(arr)].squeeze()), which fails on exactly one.
        assert np.count_nonzero(out_ten) < 20
        # With no feed, the placeholder falls back to its default, False.
        out_ten = sess.run(rnn)
        assert np.count_nonzero(out_ten) > 20
Пример #4
0
def test_placeholder_dropout_rnn_cell():
    """Check, with ops pinned to the CPU, that a placeholder toggles dropout.

    A GRU cell (hidden size 100) is built with ``rnn_cell_w_dropout``, passing
    0.9999999999 as the dropout argument (presumably the drop probability,
    TODO confirm). Its ``training`` flag is a placeholder that defaults to
    False. Feeding True must leave almost no non-zero outputs; running without
    a feed must leave many.
    """
    # A private graph keeps repeated runs from colliding on variable scopes.
    graph = tf.Graph()
    x = np.random.randn(1, 10, 50).astype(np.float32)
    # ``np.int`` no longer exists in NumPy >= 1.24; use an explicit int dtype.
    seq_len = np.array([10], dtype=np.int32)
    with graph.as_default():
        # The device scope belongs to the graph, so it is entered inside it.
        with tf.device('/cpu:0'):
            train_flag = tf.placeholder_with_default(False, shape=(), name='TEST_TRAIN_FLAG')
            with tf.variable_scope("DropoutMightBeOn"):
                rnn_cell = rnn_cell_w_dropout(100, 0.9999999999, 'gru', training=train_flag)
                rnn, _ = tf.nn.dynamic_rnn(rnn_cell, x, sequence_length=seq_len, dtype=tf.float32)
            init_op = tf.global_variables_initializer()

    # The context manager guarantees the session is closed, even on failure.
    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        out_ten = sess.run(rnn, {train_flag: True})
        # count_nonzero is safe for any number of survivors, unlike
        # len(arr[np.nonzero(arr)].squeeze()), which fails on exactly one.
        assert np.count_nonzero(out_ten) < 20
        # With no feed, the placeholder falls back to its default, False.
        out_ten = sess.run(rnn)
        assert np.count_nonzero(out_ten) > 20