Esempio n. 1
0
    def test_single_dynamic_lstm_seq_length_is_not_const(self):
        """Single LSTMCell dynamic_rnn whose sequence lengths are fed at runtime.

        The lengths come from the ``input_2`` placeholder (wrapped in
        ``tf.identity``), so they are not a graph constant. The converted graph
        is validated with ``check_lstm_count(g, 1)`` (presumably: one LSTM node).
        """
        num_units = 5
        batch = 6
        steps = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]],
                         dtype=np.float32)
        inputs_val = np.stack([steps] * batch)
        x = tf.placeholder(tf.float32, inputs_val.shape, name="input_1")
        initializer = init_ops.constant_initializer(0.5)

        # one length per batch entry, each between 1 and the 5 time steps
        lengths_val = np.array([4, 3, 4, 5, 2, 1], dtype=np.int32)
        seq_length = tf.placeholder(tf.int32, lengths_val.shape, name="input_2")

        # no scope
        lstm = rnn.LSTMCell(num_units, initializer=initializer, state_is_tuple=True)
        outputs, cell_state = tf.nn.dynamic_rnn(
            lstm, x, dtype=tf.float32, sequence_length=tf.identity(seq_length))

        tf.identity(outputs, name="output")
        tf.identity(cell_state, name="cell_state")

        def validate(graph):
            return check_lstm_count(graph, 1)

        self.run_test_case({"input_1:0": inputs_val, "input_2:0": lengths_val},
                           ["input_1:0", "input_2:0"],
                           ["output:0", "cell_state:0"],
                           rtol=1e-06,
                           graph_validator=validate)
Esempio n. 2
0
    def test_single_dynamic_lstm_placeholder_input(self):
        """LSTMBlockCell dynamic_rnn built inside ``func`` with the default initial state.

        No placeholder is declared here; ``func`` receives its input from
        run_test_case, fed under the name ``input_1``.
        """
        num_units = 5
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]],
                            dtype=np.float32)
        x_val = np.stack([sequence] * 6)

        def func(x):
            # no scope
            block_cell = LSTMBlockCell(num_units)
            # no initial_state is passed, so dynamic_rnn starts from the cell's zero state
            outputs, cell_state = dynamic_rnn(block_cell, x, dtype=tf.float32)
            named_output = tf.identity(outputs, name="output")
            named_state = tf.identity(cell_state, name="cell_state")
            return named_output, named_state

        self.run_test_case(func,
                           {"input_1:0": x_val},
                           ["input_1:0"],
                           ["output:0", "cell_state:0"],
                           rtol=1e-3,
                           atol=1e-06,
                           graph_validator=lambda graph: check_lstm_count(graph, 1))
Esempio n. 3
0
    def test_single_dynamic_lstm_seq_length_is_const(self):
        """LSTMBlockCell dynamic_rnn with a constant (Python list) sequence_length."""
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]],
                            dtype=np.float32)
        x_val = np.stack([sequence] * batch)
        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")

        # no scope; one constant length per batch entry (none exceeds the 5 time steps)
        const_lengths = [4, 3, 4, 5, 2, 1]
        outputs, cell_state = tf.nn.dynamic_rnn(rnn.LSTMBlockCell(num_units),
                                                x,
                                                dtype=tf.float32,
                                                sequence_length=const_lengths)

        tf.identity(outputs, name="output")
        tf.identity(cell_state, name="cell_state")

        self.run_test_case({"input_1:0": x_val},
                           ["input_1:0"],
                           ["output:0", "cell_state:0"],
                           rtol=1e-3,
                           atol=1e-06,
                           graph_validator=lambda graph: check_lstm_count(graph, 1))
Esempio n. 4
0
    def test_single_dynamic_lstm_forget_bias(self):
        """Time-major LSTMCell dynamic_rnn with ``forget_bias=0.5``."""
        num_units = 5
        seq_len = 6
        step_batch = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]],
                              dtype=np.float32)
        # time_major=True below, so axis 0 (length seq_len) is the time axis
        x_val = np.stack([step_batch] * seq_len)

        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
        const_init = init_ops.constant_initializer(0.5)

        # no scope
        lstm = rnn.LSTMCell(num_units, initializer=const_init, forget_bias=0.5)
        outputs, cell_state = tf.nn.dynamic_rnn(lstm, x, time_major=True, dtype=tf.float32)

        for tensor, node_name in ((outputs, "output"), (cell_state, "cell_state")):
            tf.identity(tensor, name=node_name)

        self.run_test_case({"input_1:0": x_val},
                           ["input_1:0"],
                           ["output:0", "cell_state:0"],
                           rtol=1e-06,
                           graph_validator=lambda graph: check_lstm_count(graph, 1))
Esempio n. 5
0
    def test_dynamic_multi_bilstm_with_same_input_hidden_size(self):
        """Two bidirectional LSTMCell RNNs (hidden sizes 5 and 10) over one shared input."""
        batch = 10
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
        x_val = np.stack([sequence] * batch)

        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")

        # each bilstm gets its own scope and its own (forward, backward) pair of cells
        bilstm_results = []
        for hidden_size, scope_name in ((5, "bilstm_1"), (10, "bilstm_2")):
            fw_cell = rnn.LSTMCell(hidden_size)
            bw_cell = rnn.LSTMCell(hidden_size)
            bilstm_results.append(tf.nn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell, x, dtype=tf.float32, scope=scope_name))
        (outputs_1, cell_state_1), (outputs_2, cell_state_2) = bilstm_results

        tf.identity(outputs_1, name="output_1")
        tf.identity(cell_state_1, name="cell_state_1")
        tf.identity(outputs_2, name="output_2")
        tf.identity(cell_state_2, name="cell_state_2")

        self.run_test_case({"input_1:0": x_val},
                           ["input_1:0"],
                           ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"],
                           rtol=1e-3,
                           atol=1e-06,
                           graph_validator=lambda graph: check_lstm_count(graph, 2))
Esempio n. 6
0
    def test_single_dynamic_lstm_random_weights2(self, state_is_tuple=True):
        """LSTMCell dynamic_rnn with a wide random input and random weights.

        The input is a single time step of 133 standard-normal features, and the
        cell's weight initializer draws uniformly from [0, 1). Both random sources
        are seeded, so a tolerance failure can be reproduced from run to run. The
        numpy RNG is a local RandomState, which leaves global numpy state untouched.

        Args:
            state_is_tuple: forwarded to ``rnn.LSTMCell``.
        """
        hidden_size = 128
        batch_size = 1
        seed = 42  # fixed so the random input and weights are reproducible
        rng = np.random.RandomState(seed)
        x_val = rng.randn(1, 133).astype('f')
        x_val = np.stack([x_val] * batch_size)

        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
        initializer = tf.random_uniform_initializer(0.0, 1.0, seed=seed)
        # no scope
        cell = rnn.LSTMCell(hidden_size,
                            initializer=initializer,
                            state_is_tuple=state_is_tuple)

        outputs, cell_state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)

        _ = tf.identity(outputs, name="output")
        _ = tf.identity(cell_state, name="cell_state")

        feed_dict = {"input_1:0": x_val}
        input_names_with_port = ["input_1:0"]
        output_names_with_port = ["output:0", "cell_state:0"]
        self.run_test_case(feed_dict,
                           input_names_with_port,
                           output_names_with_port,
                           rtol=0.01,
                           graph_validator=lambda g: check_lstm_count(g, 1))
Esempio n. 7
0
    def test_single_dynamic_lstm_placeholder_input(self):
        """LSTMCell dynamic_rnn fed through a placeholder whose leading dimension is None."""
        num_units = 5
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]],
                            dtype=np.float32)
        x_val = np.stack([sequence] * 6)
        # only the leading (batch) axis is left unknown; 4 time steps x 2 features are fixed
        x = tf.placeholder(tf.float32, shape=(None, 4, 2), name="input_1")
        const_init = init_ops.constant_initializer(0.5)

        # no scope
        lstm = rnn.LSTMCell(num_units, initializer=const_init, state_is_tuple=True)
        # no initial_state is passed, so dynamic_rnn starts from the cell's zero state
        outputs, cell_state = tf.nn.dynamic_rnn(lstm, x, dtype=tf.float32)

        tf.identity(outputs, name="output")
        tf.identity(cell_state, name="cell_state")

        def validate(graph):
            return check_lstm_count(graph, 1)

        self.run_test_case({"input_1:0": x_val},
                           ["input_1:0"],
                           ["output:0", "cell_state:0"],
                           rtol=1e-06,
                           graph_validator=validate)
Esempio n. 8
0
    def internal_test_single_dynamic_lstm(self, state_is_tuple):
        """Shared body for single-LSTMCell dynamic_rnn tests.

        Args:
            state_is_tuple: forwarded to ``rnn.LSTMCell``; selects whether the
                final state is a (c, h) tuple or a single concatenated tensor.
        """
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]],
                            dtype=np.float32)
        x_val = np.stack([sequence] * batch)

        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
        const_init = init_ops.constant_initializer(0.5)

        # no scope
        lstm = rnn.LSTMCell(num_units,
                            initializer=const_init,
                            state_is_tuple=state_is_tuple)
        outputs, cell_state = tf.nn.dynamic_rnn(lstm, x, dtype=tf.float32)

        tf.identity(outputs, name="output")
        tf.identity(cell_state, name="cell_state")

        feeds = {"input_1:0": x_val}
        self.run_test_case(feeds,
                           list(feeds),
                           ["output:0", "cell_state:0"],
                           rtol=1e-06,
                           graph_validator=lambda graph: check_lstm_count(graph, 1))
Esempio n. 9
0
    def test_single_dynamic_lstm(self):
        """LSTMBlockCell (peepholes disabled) dynamic_rnn built inside ``func``."""
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]],
                            dtype=np.float32)
        x_val = np.stack([sequence] * batch)

        def func(x):
            # no scope
            outputs, cell_state = dynamic_rnn(LSTMBlockCell(num_units, use_peephole=False),
                                              x,
                                              dtype=tf.float32)
            named_output = tf.identity(outputs, name="output")
            named_state = tf.identity(cell_state, name="cell_state")
            return named_output, named_state

        feeds = {"input_1:0": x_val}
        self.run_test_case(func,
                           feeds,
                           list(feeds),
                           ["output:0", "cell_state:0"],
                           rtol=1e-3,
                           atol=1e-06,
                           graph_validator=lambda graph: check_lstm_count(graph, 1))
Esempio n. 10
0
    def test_dynamic_bilstm_state_consumed_only(self, state_is_tuple=True):
        """Bidirectional LSTMCell where only the final states are exported as outputs."""
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
        x_val = np.stack([sequence] * batch)

        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
        const_init = init_ops.constant_initializer(0.5)

        # bilstm, no scope
        # state_is_tuple will impact Pack node (for cell_state)'s usage pattern
        fw_cell, bw_cell = [
            rnn.LSTMCell(num_units, initializer=const_init, state_is_tuple=state_is_tuple)
            for _ in range(2)
        ]
        # the per-direction outputs are intentionally left unused
        _, cell_state = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, x, dtype=tf.float32)

        tf.identity(cell_state, name="cell_state")

        self.run_test_case({"input_1:0": x_val},
                           ["input_1:0"],
                           ["cell_state:0"],
                           rtol=1e-06,
                           graph_validator=lambda graph: check_lstm_count(graph, 1))
Esempio n. 11
0
    def test_dynamic_bilstm_output_consumed_only(self, state_is_tuple=True):
        """Bidirectional LSTMCell built in ``func`` where only the outputs are exported."""
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
        x_val = np.stack([sequence] * batch)

        def func(x):
            const_init = init_ops.constant_initializer(0.5)

            # bilstm, no scope
            # state_is_tuple will impact Pack node (for cell_state)'s usage pattern
            fw_cell, bw_cell = (LSTMCell(num_units,
                                         initializer=const_init,
                                         state_is_tuple=state_is_tuple)
                                for _ in range(2))
            # the final states are intentionally discarded
            outputs, _ = bidirectional_dynamic_rnn(fw_cell, bw_cell, x, dtype=tf.float32)
            return tf.identity(outputs, name="output")

        self.run_test_case(func,
                           {"input_1:0": x_val},
                           ["input_1:0"],
                           ["output:0"],
                           rtol=1e-06,
                           graph_validator=lambda graph: check_lstm_count(graph, 1))
Esempio n. 12
0
    def test_dynamic_bilstm_unknown_batch_size(self, state_is_tuple=True):
        """Bidirectional LSTMCell built in ``func`` that returns only the final states.

        NOTE(review): the name promises an unknown batch size, but nothing in this
        body marks the batch axis as unknown. Whether it is unknown depends on how
        run_test_case builds the input signature from the feed dict — confirm.
        """
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
        x_val = np.stack([sequence] * batch)

        def func(x):
            const_init = init_ops.constant_initializer(0.5)

            def make_cell():
                return LSTMCell(num_units,
                                initializer=const_init,
                                state_is_tuple=state_is_tuple)

            # forward cell is built first, then backward, matching argument order
            _, cell_state = bidirectional_dynamic_rnn(make_cell(), make_cell(), x,
                                                      dtype=tf.float32)
            return tf.identity(cell_state, name="cell_state")

        def validate(graph):
            return check_lstm_count(graph, 1)

        self.run_test_case(func,
                           {"input_1:0": x_val},
                           ["input_1:0"],
                           ["cell_state:0"],
                           rtol=1e-06,
                           graph_validator=validate)
Esempio n. 13
0
    def test_single_dynamic_lstm_random_weights(self, state_is_tuple=True):
        """LSTMCell dynamic_rnn built in ``func`` with uniformly random weights.

        The weight initializer draws from [-1, 1) with a fixed op-level seed, so
        a tolerance failure can be reproduced from run to run.

        Args:
            state_is_tuple: forwarded to ``LSTMCell``.
        """
        hidden_size = 5
        batch_size = 6
        x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]],
                         dtype=np.float32)
        x_val = np.stack([x_val] * batch_size)

        def func(x):
            # seeded: unseeded random weights made failures non-reproducible
            initializer = tf.random_uniform_initializer(-1.0, 1.0, seed=42)

            # no scope
            cell = LSTMCell(hidden_size,
                            initializer=initializer,
                            state_is_tuple=state_is_tuple)

            outputs, cell_state = dynamic_rnn(cell, x, dtype=tf.float32)

            return (tf.identity(outputs, name="output"),
                    tf.identity(cell_state, name="cell_state"))

        feed_dict = {"input_1:0": x_val}
        input_names_with_port = ["input_1:0"]
        output_names_with_port = ["output:0", "cell_state:0"]
        self.run_test_case(func,
                           feed_dict,
                           input_names_with_port,
                           output_names_with_port,
                           rtol=0.0001,
                           graph_validator=lambda g: check_lstm_count(g, 1))
Esempio n. 14
0
    def test_dynamic_bilstm_output_consumed_only(self):
        """Bidirectional LSTMBlockCell where only the outputs are exported as graph outputs."""
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
        x_val = np.stack([sequence] * batch)

        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")

        # bilstm, no scope; the final states are intentionally discarded
        fw_cell, bw_cell = rnn.LSTMBlockCell(num_units), rnn.LSTMBlockCell(num_units)
        outputs, _ = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, x, dtype=tf.float32)

        tf.identity(outputs, name="output")

        def validate(graph):
            return check_lstm_count(graph, 1)

        self.run_test_case({"input_1:0": x_val},
                           ["input_1:0"],
                           ["output:0"],
                           rtol=1e-3,
                           atol=1e-06,
                           graph_validator=validate)
Esempio n. 15
0
    def test_single_dynamic_lstm_consume_one_of_ch_tuple(self):
        """LSTMCell dynamic_rnn in ``func`` exporting the outputs and only the ``c`` state.

        With ``state_is_tuple=True`` the final state is a (c, h) tuple. Only its
        ``c`` member is exported, as ``cell_state_c``.
        """
        units = 5
        batch_size = 6
        x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]],
                         dtype=np.float32)
        x_val = np.stack([x_val] * batch_size)

        def func(x):
            initializer = init_ops.constant_initializer(0.5)
            state_is_tuple = True
            # no scope
            cell = LSTMCell(units,
                            initializer=initializer,
                            state_is_tuple=state_is_tuple)
            outputs, cell_state = dynamic_rnn(cell, x, dtype=tf.float32)

            return (tf.identity(outputs, name="output"),
                    tf.identity(cell_state.c, name="cell_state_c"))

        feed_dict = {"input_1:0": x_val}
        input_names_with_port = ["input_1:0"]
        output_names_with_port = ["output:0", "cell_state_c:0"]
        self.run_test_case(func,
                           feed_dict,
                           input_names_with_port,
                           output_names_with_port,
                           rtol=1e-06,
                           graph_validator=lambda g: check_lstm_count(g, 1))
Esempio n. 16
0
    def test_single_dynamic_lstm_seq_length_is_not_const(self):
        """LSTMBlockCell dynamic_rnn in ``func`` with sequence lengths fed at runtime.

        ``func`` takes the lengths as its second argument (fed as ``input_2``)
        and wraps them in ``tf.identity`` before handing them to dynamic_rnn.
        """
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]],
                            dtype=np.float32)
        x_val = np.stack([sequence] * batch)
        # one length per batch entry, none longer than the 5 time steps
        lengths_val = np.array([4, 3, 4, 5, 2, 1], dtype=np.int32)

        def func(x, seq_length):
            # no scope
            outputs, cell_state = dynamic_rnn(LSTMBlockCell(num_units),
                                              x,
                                              dtype=tf.float32,
                                              sequence_length=tf.identity(seq_length))
            named_output = tf.identity(outputs, name="output")
            named_state = tf.identity(cell_state, name="cell_state")
            return named_output, named_state

        # keep the key order in step with func's (x, seq_length) parameters
        feeds = {"input_1:0": x_val, "input_2:0": lengths_val}
        self.run_test_case(func,
                           feeds,
                           list(feeds),
                           ["output:0", "cell_state:0"],
                           rtol=1e-3,
                           atol=1e-06,
                           graph_validator=lambda graph: check_lstm_count(graph, 1))
Esempio n. 17
0
    def test_dynamic_bilstm_state_consumed_only(self):
        """Bidirectional LSTMBlockCell in ``func`` that returns only the final states."""
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
        x_val = np.stack([sequence] * batch)

        def func(x):
            # bilstm, no scope
            fw_cell = LSTMBlockCell(num_units)
            bw_cell = LSTMBlockCell(num_units)
            # the per-direction outputs are intentionally discarded
            _, cell_state = bidirectional_dynamic_rnn(fw_cell, bw_cell, x, dtype=tf.float32)
            return tf.identity(cell_state, name="cell_state")

        def validate(graph):
            return check_lstm_count(graph, 1)

        self.run_test_case(func,
                           {"input_1:0": x_val},
                           ["input_1:0"],
                           ["cell_state:0"],
                           rtol=1e-3,
                           atol=1e-06,
                           graph_validator=validate)
Esempio n. 18
0
    def test_single_dynamic_lstm_time_major(self):
        """LSTMCell dynamic_rnn built in ``func`` with time-major input."""
        num_units = 5
        seq_len = 6
        step_batch = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]],
                              dtype=np.float32)
        # time_major=True below, so axis 0 (length seq_len) is the time axis
        x_val = np.stack([step_batch] * seq_len)

        def func(x):
            const_init = init_ops.constant_initializer(0.5)

            # no scope
            outputs, cell_state = dynamic_rnn(LSTMCell(num_units, initializer=const_init),
                                              x,
                                              time_major=True,
                                              dtype=tf.float32)
            return (tf.identity(outputs, name="output"),
                    tf.identity(cell_state, name="cell_state"))

        feeds = {"input_1:0": x_val}
        self.run_test_case(func,
                           feeds,
                           list(feeds),
                           ["output:0", "cell_state:0"],
                           rtol=1e-06,
                           graph_validator=lambda graph: check_lstm_count(graph, 1))
Esempio n. 19
0
    def internal_test_multiple_dynamic_lstm_with_parameters(
            self, state_is_tuple):
        """Shared body: two LSTMCell dynamic_rnn calls over the same input.

        The first call runs without a scope and without sequence_length. The
        second runs inside variable scope "root1" with an explicit
        sequence_length that covers every time step of every batch entry. The
        lists of both outputs and final states are passed through tf.identity
        as the "output" and "cell_state" tensors.

        Args:
            state_is_tuple: forwarded to both ``rnn.LSTMCell`` instances.
        """
        units = 5
        batch_size = 6
        x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]],
                         dtype=np.float32)
        x_val = np.stack([x_val] * batch_size)
        # x_val is (batch_size, num_steps, 2); dynamic_rnn defaults to batch-major
        num_steps = x_val.shape[1]

        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
        # NOTE(review): "input_2" is never fed nor connected to an output; kept
        # as-is in case the harness relies on an extra unused graph input.
        _ = tf.placeholder(tf.float32, x_val.shape, name="input_2")
        initializer = init_ops.constant_initializer(0.5)

        lstm_output_list = []
        lstm_cell_state_list = []
        # no scope
        cell = rnn.LSTMCell(units,
                            initializer=initializer,
                            state_is_tuple=state_is_tuple)
        outputs, cell_state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
        lstm_output_list.append(outputs)
        lstm_cell_state_list.append(cell_state)

        # given scope
        cell = rnn.LSTMCell(units,
                            initializer=initializer,
                            state_is_tuple=state_is_tuple)
        with variable_scope.variable_scope("root1") as scope:
            outputs, cell_state = tf.nn.dynamic_rnn(
                cell,
                x,
                dtype=tf.float32,
                # full length for every batch entry, derived from x_val so it
                # stays consistent if batch_size or the step count changes
                sequence_length=[num_steps] * batch_size,
                scope=scope)
        lstm_output_list.append(outputs)
        lstm_cell_state_list.append(cell_state)

        _ = tf.identity(lstm_output_list, name="output")
        _ = tf.identity(lstm_cell_state_list, name="cell_state")

        feed_dict = {"input_1:0": x_val}
        input_names_with_port = ["input_1:0"]
        output_names_with_port = ["output:0", "cell_state:0"]
        self.run_test_case(feed_dict,
                           input_names_with_port,
                           output_names_with_port,
                           rtol=1e-06,
                           graph_validator=lambda g: check_lstm_count(g, 2))
Esempio n. 20
0
    def test_dynamic_lstm_state_consumed_only(self):
        """LSTMCell dynamic_rnn in ``func`` where only the final state is exported."""
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
        x_val = np.stack([sequence] * batch)

        def func(x):
            # the per-step outputs are intentionally discarded
            _, cell_state = dynamic_rnn(LSTMCell(num_units, state_is_tuple=True),
                                        x,
                                        dtype=tf.float32)
            return tf.identity(cell_state, name="cell_state")

        self.run_test_case(func,
                           {"input_1:0": x_val},
                           ["input_1:0"],
                           ["cell_state:0"],
                           rtol=0.0001,
                           graph_validator=lambda graph: check_lstm_count(graph, 1))
Esempio n. 21
0
    def test_dynamic_lstm_output_consumed_only(self):
        """LSTMBlockCell dynamic_rnn in ``func`` where only the outputs are exported."""
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
        x_val = np.stack([sequence] * batch)

        def func(x):
            # the final state is intentionally discarded
            outputs, _ = dynamic_rnn(LSTMBlockCell(num_units), x, dtype=tf.float32)
            return tf.identity(outputs, name="output")

        feeds = {"input_1:0": x_val}
        self.run_test_case(func,
                           feeds,
                           list(feeds),
                           ["output:0"],
                           rtol=1e-3,
                           atol=1e-06,
                           graph_validator=lambda graph: check_lstm_count(graph, 1))
Esempio n. 22
0
    def test_dynamic_lstm_output_consumed_only(self):
        """LSTMCell dynamic_rnn where only the per-step outputs are exported."""
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
        x_val = np.stack([sequence] * batch)

        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
        lstm = rnn.LSTMCell(num_units, state_is_tuple=True)

        # the final state is intentionally discarded
        outputs, _ = tf.nn.dynamic_rnn(lstm, x, dtype=tf.float32)
        tf.identity(outputs, name="output")

        def validate(graph):
            return check_lstm_count(graph, 1)

        self.run_test_case({"input_1:0": x_val},
                           ["input_1:0"],
                           ["output:0"],
                           rtol=0.0001,
                           atol=1e-07,
                           graph_validator=validate)
Esempio n. 23
0
    def test_layered_lstm(self):
        """Two LSTMCell layers stacked with MultiRNNCell and run through dynamic_rnn in ``func``."""
        num_units = 5
        batch = 6
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]],
                            dtype=np.float32)
        x_val = np.stack([sequence] * batch)

        def func(x):
            const_init = init_ops.constant_initializer(0.5)
            num_layers = 2

            # no scope; each layer gets its own LSTMCell instance
            stacked_lstm = MultiRNNCell(
                [LSTMCell(num_units, initializer=const_init, state_is_tuple=True)
                 for _ in range(num_layers)])
            outputs, cell_state = dynamic_rnn(stacked_lstm, x, dtype=tf.float32)
            return (tf.identity(outputs, name="output"),
                    tf.identity(cell_state, name="cell_state"))

        feeds = {"input_1:0": x_val}
        self.run_test_case(func,
                           feeds,
                           list(feeds),
                           ["output:0", "cell_state:0"],
                           rtol=1e-06,
                           # presumably one converted LSTM node per stacked layer
                           graph_validator=lambda graph: check_lstm_count(graph, 2))