Example #1: checks that LSTMCell, LSTMBlockCell, and LSTMBlockFusedCell create trainable variables with matching names and shapes.
    def testCompatibleNames(self):
        with self.test_session(use_gpu=self._use_gpu, graph=ops.Graph()):
            cell = core_rnn_cell_impl.LSTMCell(10)
            pcell = core_rnn_cell_impl.LSTMCell(10, use_peepholes=True)
            inputs = [array_ops.zeros([4, 5])] * 6
            core_rnn.static_rnn(cell,
                                inputs,
                                dtype=dtypes.float32,
                                scope="basic")
            core_rnn.static_rnn(pcell,
                                inputs,
                                dtype=dtypes.float32,
                                scope="peephole")
            basic_names = {
                v.name: v.get_shape()
                for v in variables.trainable_variables()
            }

        with self.test_session(use_gpu=self._use_gpu, graph=ops.Graph()):
            cell = lstm_ops.LSTMBlockCell(10)
            pcell = lstm_ops.LSTMBlockCell(10, use_peephole=True)
            inputs = [array_ops.zeros([4, 5])] * 6
            core_rnn.static_rnn(cell,
                                inputs,
                                dtype=dtypes.float32,
                                scope="basic")
            core_rnn.static_rnn(pcell,
                                inputs,
                                dtype=dtypes.float32,
                                scope="peephole")
            block_names = {
                v.name: v.get_shape()
                for v in variables.trainable_variables()
            }

        with self.test_session(use_gpu=self._use_gpu, graph=ops.Graph()):
            cell = lstm_ops.LSTMBlockFusedCell(10)
            pcell = lstm_ops.LSTMBlockFusedCell(10, use_peephole=True)
            inputs = [array_ops.zeros([4, 5])] * 6
            cell(inputs, dtype=dtypes.float32, scope="basic/lstm_cell")
            pcell(inputs, dtype=dtypes.float32, scope="peephole/lstm_cell")
            fused_names = {
                v.name: v.get_shape()
                for v in variables.trainable_variables()
            }

        self.assertEqual(basic_names, block_names)
        self.assertEqual(basic_names, fused_names)
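
As a reference point, here is a sketch of the variable layout all three implementations are expected to share in this test; the names and shapes below are inferred from LSTMCell(10) fed [4, 5] inputs under the "basic" and "peephole" scopes, not captured from a real run:

# Inferred layout: kernel is [input_size + num_units, 4 * num_units],
# bias is [4 * num_units], and each peephole weight has one entry per unit.
expected = {
    "basic/lstm_cell/kernel:0": (15, 40),
    "basic/lstm_cell/bias:0": (40,),
    "peephole/lstm_cell/kernel:0": (15, 40),
    "peephole/lstm_cell/bias:0": (40,),
    "peephole/lstm_cell/w_f_diag:0": (10,),
    "peephole/lstm_cell/w_i_diag:0": (10,),
    "peephole/lstm_cell/w_o_diag:0": (10,),
}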
Example #2: constructor of an LSTM prediction model pairing two LSTMBlockFusedCell instances as encoder and decoder.
  def __init__(self,
               num_features,
               input_window_size,
               output_window_size,
               num_units=128):
    """Construct the LSTM prediction model.

    Args:
      num_features: Number of input features per time step.
      input_window_size: Number of past time steps of data to look at when doing
        the regression.
      output_window_size: Number of future time steps to predict. Note that
        setting it to > 1 empirically seems to give a better fit.
      num_units: The number of units in the encoder and decoder LSTM cells.
    """
    super(LSTMPredictionModel, self).__init__()
    self._encoder = lstm_ops.LSTMBlockFusedCell(
        num_units=num_units, name="encoder")
    self._decoder = lstm_ops.LSTMBlockFusedCell(
        num_units=num_units, name="decoder")
    self._mean_transform = core.Dense(num_features,
                                      name="mean_transform")
    self._covariance_transform = core.Dense(num_features,
                                            name="covariance_transform")
Example #3: a helper that pushes the same weights through LSTMCell, the raw block_lstm op, and LSTMBlockFusedCell and returns their outputs, states, and gradients for comparison.
def blocks_match(sess, use_peephole):
    batch_size = 2
    input_size = 3
    cell_size = 4
    sequence_length = 4

    inputs = []
    for _ in range(sequence_length):
        inp = ops.convert_to_tensor(np.random.randn(batch_size, input_size),
                                    dtype=dtypes.float32)
        inputs.append(inp)
    stacked_inputs = array_ops.stack(inputs)

    initializer = init_ops.random_uniform_initializer(-0.01,
                                                      0.01,
                                                      seed=19890212)

    with variable_scope.variable_scope("test", initializer=initializer):
        # magic naming so that the cells pick up these variables and reuse them
        if use_peephole:
            wci = variable_scope.get_variable("rnn/lstm_cell/w_i_diag",
                                              shape=[cell_size],
                                              dtype=dtypes.float32)
            wcf = variable_scope.get_variable("rnn/lstm_cell/w_f_diag",
                                              shape=[cell_size],
                                              dtype=dtypes.float32)
            wco = variable_scope.get_variable("rnn/lstm_cell/w_o_diag",
                                              shape=[cell_size],
                                              dtype=dtypes.float32)

        w = variable_scope.get_variable(
            "rnn/lstm_cell/kernel",
            shape=[input_size + cell_size, cell_size * 4],
            dtype=dtypes.float32)
        b = variable_scope.get_variable(
            "rnn/lstm_cell/bias",
            shape=[cell_size * 4],
            dtype=dtypes.float32,
            initializer=init_ops.zeros_initializer())

        basic_cell = rnn_cell.LSTMCell(cell_size,
                                       use_peepholes=use_peephole,
                                       state_is_tuple=True,
                                       reuse=True)
        basic_outputs_op, basic_state_op = rnn.static_rnn(basic_cell,
                                                          inputs,
                                                          dtype=dtypes.float32)

        if use_peephole:
            _, _, _, _, _, _, block_outputs_op = block_lstm(
                ops.convert_to_tensor(sequence_length, dtype=dtypes.int64),
                inputs,
                w,
                b,
                wci=wci,
                wcf=wcf,
                wco=wco,
                cell_clip=0,
                use_peephole=True)
        else:
            _, _, _, _, _, _, block_outputs_op = block_lstm(
                ops.convert_to_tensor(sequence_length, dtype=dtypes.int64),
                inputs,
                w,
                b,
                cell_clip=0)

        fused_cell = lstm_ops.LSTMBlockFusedCell(cell_size,
                                                 cell_clip=0,
                                                 use_peephole=use_peephole,
                                                 reuse=True,
                                                 name="rnn/lstm_cell")
        fused_outputs_op, fused_state_op = fused_cell(stacked_inputs,
                                                      dtype=dtypes.float32)

        sess.run([variables.global_variables_initializer()])
        basic_outputs, basic_state = sess.run(
            [basic_outputs_op, basic_state_op[0]])
        basic_grads = sess.run(
            gradients_impl.gradients(basic_outputs_op, inputs))
        xs = [w, b]
        if use_peephole:
            xs += [wci, wcf, wco]
        basic_wgrads = sess.run(gradients_impl.gradients(basic_outputs_op, xs))

        block_outputs = sess.run(block_outputs_op)
        block_grads = sess.run(
            gradients_impl.gradients(block_outputs_op, inputs))
        block_wgrads = sess.run(gradients_impl.gradients(block_outputs_op, xs))

        fused_outputs, fused_state = sess.run(
            [fused_outputs_op, fused_state_op[0]])
        fused_grads = sess.run(
            gradients_impl.gradients(fused_outputs_op, inputs))
        fused_wgrads = sess.run(gradients_impl.gradients(fused_outputs_op, xs))

        return (basic_state, fused_state, basic_outputs, block_outputs,
                fused_outputs, basic_grads, block_grads, fused_grads,
                basic_wgrads, block_wgrads, fused_wgrads)
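
A sketch of how this helper might be driven from a test; the test method below is hypothetical, but the assertions mirror the cross-checks the surrounding examples perform:

def testBlocksMatch(self):
    for use_peephole in (False, True):
        with self.test_session(use_gpu=True, graph=ops.Graph()) as sess:
            (basic_state, fused_state, basic_outputs, block_outputs,
             fused_outputs, basic_grads, block_grads, fused_grads,
             basic_wgrads, block_wgrads, fused_wgrads) = blocks_match(
                 sess, use_peephole=use_peephole)

            # All three implementations should agree on forward outputs,
            # final cell state, and gradients w.r.t. inputs and weights.
            self.assertAllClose(basic_outputs, block_outputs)
            self.assertAllClose(basic_outputs, fused_outputs)
            self.assertAllClose(basic_state, fused_state)
            self.assertAllClose(basic_grads, block_grads)
            self.assertAllClose(basic_grads, fused_grads)
            for basic, block, fused in zip(basic_wgrads, block_wgrads,
                                           fused_wgrads):
                self.assertAllClose(basic, block)
                self.assertAllClose(basic, fused)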
Example #4: verifies sequence-length support in LSTMBlockFusedCell by comparing one fused run against single-step ("unfused") calls of the same cell.
    def testLSTMFusedSequenceLengths(self):
        """Verify proper support for sequence lengths in LSTMBlockFusedCell."""
        with self.session(use_gpu=True) as sess:
            batch_size = 3
            input_size = 4
            cell_size = 5
            max_sequence_length = 6

            inputs = []
            for _ in range(max_sequence_length):
                inp = ops.convert_to_tensor(
                    np.random.randn(batch_size, input_size),
                    dtype=dtypes.float32)
                inputs.append(inp)
            seq_lengths = constant_op.constant([3, 4, 5])
            cell_inputs = array_ops.stack(inputs)

            initializer = init_ops.random_uniform_initializer(-0.01,
                                                              0.01,
                                                              seed=19890213)

            with variable_scope.variable_scope("lstm_cell",
                                               initializer=initializer):
                # magic naming so that the cells pick up these variables and reuse them
                variable_scope.get_variable(
                    "kernel",
                    shape=[input_size + cell_size, cell_size * 4],
                    dtype=dtypes.float32)

                variable_scope.get_variable(
                    "bias",
                    shape=[cell_size * 4],
                    dtype=dtypes.float32,
                    initializer=init_ops.zeros_initializer())

            cell = lstm_ops.LSTMBlockFusedCell(cell_size,
                                               cell_clip=0,
                                               use_peephole=False,
                                               reuse=True,
                                               name="lstm_cell")

            fused_outputs_op, fused_state_op = cell(
                cell_inputs, dtype=dtypes.float32, sequence_length=seq_lengths)

            cell_vars = [
                v for v in variables.trainable_variables()
                if v.name.endswith("kernel") or v.name.endswith("bias")
            ]

            # Verify that state propagation works if we turn our sequence into
            # tiny (single-time) subsequences, i.e. unfuse the cell
            unfused_outputs_op = []
            state = None
            with variable_scope.variable_scope(
                    variable_scope.get_variable_scope(), reuse=True):
                for i, inp in enumerate(inputs):
                    lengths = [int(i < l) for l in seq_lengths.eval()]
                    output, state = cell(array_ops.expand_dims(inp, 0),
                                         initial_state=state,
                                         dtype=dtypes.float32,
                                         sequence_length=lengths)
                    unfused_outputs_op.append(output[0])
            unfused_outputs_op = array_ops.stack(unfused_outputs_op)

            sess.run([variables.global_variables_initializer()])
            unfused_outputs, unfused_state = sess.run(
                [unfused_outputs_op, state[0]])
            unfused_grads = sess.run(
                gradients_impl.gradients(unfused_outputs_op, inputs))
            unfused_wgrads = sess.run(
                gradients_impl.gradients(unfused_outputs_op, cell_vars))

            fused_outputs, fused_state = sess.run(
                [fused_outputs_op, fused_state_op[0]])
            fused_grads = sess.run(
                gradients_impl.gradients(fused_outputs_op, inputs))
            fused_wgrads = sess.run(
                gradients_impl.gradients(fused_outputs_op, cell_vars))

            self.assertAllClose(fused_outputs, unfused_outputs)
            self.assertAllClose(fused_state, unfused_state)
            self.assertAllClose(fused_grads, unfused_grads)
            for fused, unfused in zip(fused_wgrads, unfused_wgrads):
                self.assertAllClose(fused, unfused, rtol=1e-6, atol=1e-6)
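
A hypothetical follow-up check, assuming the fused wrapper zeroes its outputs past each sequence's end (which is what lets the fused and single-step runs above agree):

# Assumption: output rows at or beyond each sequence's length are zero.
fused_outputs = sess.run(fused_outputs_op)
for b, length in enumerate([3, 4, 5]):  # the seq_lengths fed above
    self.assertAllEqual(fused_outputs[length:, b, :],
                        np.zeros_like(fused_outputs[length:, b, :]))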
Example #5: an earlier variant of the sequence-length test, comparing LSTMBlockFusedCell against BasicLSTMCell and against its own single-step unrolling.
    def testLSTMFusedSequenceLengths(self):
        """Verify proper support for sequence lengths in LSTMBlockFusedCell."""
        with self.test_session(use_gpu=True) as sess:
            batch_size = 3
            input_size = 4
            cell_size = 5
            max_sequence_length = 6

            inputs = []
            for _ in range(max_sequence_length):
                inp = ops.convert_to_tensor(
                    np.random.randn(batch_size, input_size),
                    dtype=dtypes.float32)
                inputs.append(inp)
            seq_lengths = constant_op.constant([3, 4, 5])

            initializer = init_ops.random_uniform_initializer(-0.01,
                                                              0.01,
                                                              seed=19890213)
            with variable_scope.variable_scope("basic",
                                               initializer=initializer):
                cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
                outputs, state = rnn.static_rnn(cell,
                                                inputs,
                                                dtype=dtypes.float32,
                                                sequence_length=seq_lengths)
                sess.run([variables.global_variables_initializer()])
                basic_outputs, basic_state = sess.run([outputs, state[0]])
                basic_grads = sess.run(
                    gradients_impl.gradients(outputs, inputs))
                basic_wgrads = sess.run(
                    gradients_impl.gradients(outputs,
                                             variables.trainable_variables()))

            with variable_scope.variable_scope("fused",
                                               initializer=initializer):
                cell = lstm_ops.LSTMBlockFusedCell(cell_size,
                                                   cell_clip=0,
                                                   use_peephole=False)
                outputs, state = cell(inputs,
                                      dtype=dtypes.float32,
                                      sequence_length=seq_lengths)

                sess.run([variables.global_variables_initializer()])
                fused_outputs, fused_state = sess.run([outputs, state[0]])
                fused_grads = sess.run(
                    gradients_impl.gradients(outputs, inputs))
                fused_vars = [
                    v for v in variables.trainable_variables()
                    if v.name.startswith("fused/")
                ]
                fused_wgrads = sess.run(
                    gradients_impl.gradients(outputs, fused_vars))

            self.assertAllClose(basic_outputs, fused_outputs)
            self.assertAllClose(basic_state, fused_state)
            self.assertAllClose(basic_grads, fused_grads)
            for basic, fused in zip(basic_wgrads, fused_wgrads):
                self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)

            # Verify that state propagation works if we turn our sequence into
            # tiny (single-time) subsequences, i.e. unfuse the cell
            with variable_scope.variable_scope("unfused",
                                               initializer=initializer) as vs:
                cell = lstm_ops.LSTMBlockFusedCell(cell_size,
                                                   cell_clip=0,
                                                   use_peephole=False)
                outputs = []
                state = None
                for i, inp in enumerate(inputs):
                    lengths = [int(i < l) for l in seq_lengths.eval()]
                    output, state = cell([inp],
                                         initial_state=state,
                                         dtype=dtypes.float32,
                                         sequence_length=lengths)
                    vs.reuse_variables()
                    outputs.append(output[0])
                outputs = array_ops.stack(outputs)

                sess.run([variables.global_variables_initializer()])
                unfused_outputs, unfused_state = sess.run([outputs, state[0]])
                unfused_grads = sess.run(
                    gradients_impl.gradients(outputs, inputs))
                unfused_vars = [
                    v for v in variables.trainable_variables()
                    if v.name.startswith("unfused/")
                ]
                unfused_wgrads = sess.run(
                    gradients_impl.gradients(outputs, unfused_vars))

            self.assertAllClose(basic_outputs, unfused_outputs)
            self.assertAllClose(basic_state, unfused_state)
            self.assertAllClose(basic_grads, unfused_grads)
            for basic, unfused in zip(basic_wgrads, unfused_wgrads):
                self.assertAllClose(basic, unfused, rtol=1e-2, atol=1e-2)
Example #6: compares a peephole LSTMCell against the raw block_lstm op and LSTMBlockFusedCell, including gradients.
    def testLSTMBasicToBlockPeeping(self):
        with self.test_session(use_gpu=True) as sess:
            batch_size = 2
            input_size = 3
            cell_size = 4
            sequence_length = 5

            inputs = []
            for _ in range(sequence_length):
                inp = ops.convert_to_tensor(
                    np.random.randn(batch_size, input_size),
                    dtype=dtypes.float32)
                inputs.append(inp)

            initializer = init_ops.random_uniform_initializer(-0.01,
                                                              0.01,
                                                              seed=19890212)
            with variable_scope.variable_scope("basic",
                                               initializer=initializer):
                cell = rnn_cell.LSTMCell(cell_size,
                                         use_peepholes=True,
                                         state_is_tuple=True)
                outputs, state = rnn.static_rnn(cell,
                                                inputs,
                                                dtype=dtypes.float32)

                sess.run([variables.global_variables_initializer()])
                basic_outputs, basic_state = sess.run([outputs, state[0]])
                basic_grads = sess.run(
                    gradients_impl.gradients(outputs, inputs))
                basic_wgrads = sess.run(
                    gradients_impl.gradients(outputs,
                                             variables.trainable_variables()))

            with variable_scope.variable_scope("block",
                                               initializer=initializer):
                w = variable_scope.get_variable(
                    "w",
                    shape=[input_size + cell_size, cell_size * 4],
                    dtype=dtypes.float32)
                b = variable_scope.get_variable(
                    "b",
                    shape=[cell_size * 4],
                    dtype=dtypes.float32,
                    initializer=init_ops.zeros_initializer())

                wci = variable_scope.get_variable("wci",
                                                  shape=[cell_size],
                                                  dtype=dtypes.float32)
                wcf = variable_scope.get_variable("wcf",
                                                  shape=[cell_size],
                                                  dtype=dtypes.float32)
                wco = variable_scope.get_variable("wco",
                                                  shape=[cell_size],
                                                  dtype=dtypes.float32)

                _, _, _, _, _, _, outputs = block_lstm(
                    ops.convert_to_tensor(sequence_length, dtype=dtypes.int64),
                    inputs,
                    w,
                    b,
                    wci=wci,
                    wcf=wcf,
                    wco=wco,
                    cell_clip=0,
                    use_peephole=True)

                sess.run([variables.global_variables_initializer()])
                block_outputs = sess.run(outputs)
                block_grads = sess.run(
                    gradients_impl.gradients(outputs, inputs))
                block_wgrads = sess.run(
                    gradients_impl.gradients(outputs, [w, b, wci, wcf, wco]))

            self.assertAllClose(basic_outputs, block_outputs)
            self.assertAllClose(basic_grads, block_grads)
            for basic, block in zip(basic_wgrads, block_wgrads):
                self.assertAllClose(basic, block, rtol=1e-2, atol=1e-2)

            with variable_scope.variable_scope("fused",
                                               initializer=initializer):
                cell = lstm_ops.LSTMBlockFusedCell(cell_size,
                                                   cell_clip=0,
                                                   use_peephole=True)
                outputs, state = cell(inputs, dtype=dtypes.float32)

                sess.run([variables.global_variables_initializer()])
                fused_outputs, fused_state = sess.run([outputs, state[0]])
                fused_grads = sess.run(
                    gradients_impl.gradients(outputs, inputs))
                fused_vars = [
                    v for v in variables.trainable_variables()
                    if v.name.startswith("fused/")
                ]
                fused_wgrads = sess.run(
                    gradients_impl.gradients(outputs, fused_vars))

            self.assertAllClose(basic_outputs, fused_outputs)
            self.assertAllClose(basic_state, fused_state)
            self.assertAllClose(basic_grads, fused_grads)
            for basic, fused in zip(basic_wgrads, fused_wgrads):
                self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)
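
A note on the state[0] indexing used throughout these comparisons: both cell families return an LSTMStateTuple, a namedtuple ordered (c, h), so index 0 selects the final cell state:

final_c = state[0]  # equivalent to state.c, the final cell state
final_h = state[1]  # equivalent to state.h, the final hidden output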
Example #7: a scratch function toggling between several RNN variants (the free variables such as content_embeddings and content_lengths come from surrounding code), one branch using LSTMBlockFusedCell with sequence lengths.
def tmp():
    initializer = init_ops.random_uniform_initializer(-0.01, 0.01)

    def lstm_cell():
        hidden_size = RNN_UNIT_SIZE
        input_size = CONTENT_DIM
        cell = tf.contrib.rnn.LSTMCell(hidden_size,
                                       input_size,
                                       initializer=initializer,
                                       state_is_tuple=True)
        return cell

    if True:  # only this first branch runs; the later "elif True" blocks are disabled variants kept for reference
        attn_length = 16
        cells = [lstm_cell() for _ in range(2)]
        cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
        cell = tf.contrib.rnn.AttentionCellWrapper(cell,
                                                   attn_length,
                                                   state_is_tuple=True)
        outputs, states = tf.nn.dynamic_rnn(cell,
                                            content_embeddings,
                                            sequence_length=content_lengths,
                                            dtype=tf.float32)
        #last_outputs = states[0][-1].h
        last_outputs = tf.concat([states[0][-1].h, states[-1]], 1)
    elif True:
        content_embeddings = tf.unstack(content_embeddings, 200, 1)
        cell = lstm_ops.LSTMBlockFusedCell(RNN_UNIT_SIZE)
        content_lengths = tf.cast(content_lengths, tf.int32)
        outputs, state = cell(content_embeddings,
                              sequence_length=content_lengths,
                              dtype=tf.float32)
        last_outputs = state.h
    elif True:
        layer_sizes = [RNN_UNIT_SIZE, RNN_UNIT_SIZE]
        cell = make_rnn_cell(layer_sizes,
                             dropout_keep_prob=dropout_keep_prob,
                             base_cell=lstm_ops.LSTMBlockCell,
                             attn_length=16)
        outputs, final_state = tf.nn.dynamic_rnn(
            cell,
            content_embeddings,
            sequence_length=content_lengths,
            swap_memory=True,
            dtype=tf.float32)
        last_outputs = final_state[-1].h
        #last_outputs = tf.concat([final_state[-1].h, final_state[0][1]], 1)
    elif True:
        cell = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(2)],
                                           state_is_tuple=True)
        outputs, states = tf.nn.dynamic_rnn(cell,
                                            content_embeddings,
                                            sequence_length=content_lengths,
                                            dtype=tf.float32)
        last_outputs = states[-1].h
    elif True:
        num_hidden = RNN_UNIT_SIZE
        cell_fw = tf.nn.rnn_cell.LSTMCell(num_units=num_hidden,
                                          state_is_tuple=True)
        cell_bw = tf.nn.rnn_cell.LSTMCell(num_units=num_hidden,
                                          state_is_tuple=True)
        outputs, states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw,
            cell_bw,
            content_embeddings,
            sequence_length=content_lengths,
            dtype=tf.float32)
        output_fw, output_bw = outputs
        output_state_fw, output_state_bw = states
        #last_outputs = tf.concat([output_fw[:, 0], output_state_bw.h], 1)
        last_outputs = tf.concat([output_state_fw.h, output_state_bw.h], 1)
    elif True:
        num_hidden = RNN_UNIT_SIZE
        lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(num_hidden,
                                                    forget_bias=1.0)
        lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(num_hidden,
                                                    forget_bias=1.0)
        content_embeddings = tf.unstack(content_embeddings, 200, 1)
        outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(
            lstm_fw_cell,
            lstm_bw_cell,
            content_embeddings,
            sequence_length=content_lengths,
            dtype=tf.float32)
        last_outputs = outputs[-1]
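
The third branch calls a make_rnn_cell helper that is not defined in this snippet; below is one plausible sketch matching the parameters it passes (hypothetical, not the original author's implementation):

def make_rnn_cell(layer_sizes, dropout_keep_prob=1.0,
                  base_cell=tf.contrib.rnn.BasicLSTMCell, attn_length=0):
    # Stack one base_cell per entry in layer_sizes, with output dropout.
    cells = []
    for num_units in layer_sizes:
        cell = base_cell(num_units)
        cell = tf.contrib.rnn.DropoutWrapper(
            cell, output_keep_prob=dropout_keep_prob)
        cells.append(cell)
    cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
    if attn_length:
        cell = tf.contrib.rnn.AttentionCellWrapper(
            cell, attn_length, state_is_tuple=True)
    return cell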
Example #8: the peephole comparison again, this time sharing weights across implementations through scope-matched ("magic") variable names.
  def testLSTMBasicToBlockPeeping(self):
    with self.test_session(use_gpu=True) as sess:
      batch_size = 2
      input_size = 3
      cell_size = 4
      sequence_length = 4

      inputs = []
      for _ in range(sequence_length):
        inp = ops.convert_to_tensor(
            np.random.randn(batch_size, input_size), dtype=dtypes.float32)
        inputs.append(inp)

      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890212)

      with variable_scope.variable_scope("test", initializer=initializer):
        # magic naming so that the cells pick up these variables and reuse them
        wci = variable_scope.get_variable(
            "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtypes.float32)
        wcf = variable_scope.get_variable(
            "rnn/lstm_cell/w_f_diag", shape=[cell_size], dtype=dtypes.float32)
        wco = variable_scope.get_variable(
            "rnn/lstm_cell/w_o_diag", shape=[cell_size], dtype=dtypes.float32)

        w = variable_scope.get_variable(
            "rnn/lstm_cell/kernel",
            shape=[input_size + cell_size, cell_size * 4],
            dtype=dtypes.float32)
        b = variable_scope.get_variable(
            "rnn/lstm_cell/bias",
            shape=[cell_size * 4],
            dtype=dtypes.float32,
            initializer=init_ops.zeros_initializer())

        wci_block = variable_scope.get_variable(
            "rnn/lstm_cell/lstm_block_wrapper/w_i_diag",
            initializer=wci.initialized_value())
        wcf_block = variable_scope.get_variable(
            "rnn/lstm_cell/lstm_block_wrapper/w_f_diag",
            initializer=wcf.initialized_value())
        wco_block = variable_scope.get_variable(
            "rnn/lstm_cell/lstm_block_wrapper/w_o_diag",
            initializer=wco.initialized_value())
        w_block = variable_scope.get_variable(
            "rnn/lstm_cell/lstm_block_wrapper/kernel",
            initializer=w.initialized_value())
        b_block = variable_scope.get_variable(
            "rnn/lstm_cell/lstm_block_wrapper/bias",
            initializer=b.initialized_value())

        basic_cell = rnn_cell.LSTMCell(
            cell_size, use_peepholes=True, state_is_tuple=True, reuse=True)
        basic_outputs_op, basic_state_op = rnn.static_rnn(
            basic_cell, inputs, dtype=dtypes.float32)

        _, _, _, _, _, _, block_outputs_op = block_lstm(
            ops.convert_to_tensor(sequence_length, dtype=dtypes.int64),
            inputs,
            w,
            b,
            wci=wci,
            wcf=wcf,
            wco=wco,
            cell_clip=0,
            use_peephole=True)

        with variable_scope.variable_scope("rnn/lstm_cell", reuse=True):
          fused_cell = lstm_ops.LSTMBlockFusedCell(
              cell_size, cell_clip=0, use_peephole=True)
          fused_outputs_op, fused_state_op = fused_cell(
              inputs, dtype=dtypes.float32)

        sess.run([variables.global_variables_initializer()])
        basic_outputs, basic_state = sess.run(
            [basic_outputs_op, basic_state_op[0]])
        basic_grads = sess.run(
            gradients_impl.gradients(basic_outputs_op, inputs))
        basic_wgrads = sess.run(
            gradients_impl.gradients(basic_outputs_op, [w, b, wci, wcf, wco]))

        block_outputs = sess.run(block_outputs_op)
        block_grads = sess.run(
            gradients_impl.gradients(block_outputs_op, inputs))
        block_wgrads = sess.run(
            gradients_impl.gradients(block_outputs_op, [w, b, wci, wcf, wco]))

        fused_outputs, fused_state = sess.run(
            [fused_outputs_op, fused_state_op[0]])
        fused_grads = sess.run(
            gradients_impl.gradients(fused_outputs_op, inputs))
        fused_wgrads = sess.run(
            gradients_impl.gradients(
                fused_outputs_op,
                [w_block, b_block, wci_block, wcf_block, wco_block]))

      self.assertAllClose(basic_outputs, block_outputs)
      self.assertAllClose(basic_grads, block_grads)
      for basic, block in zip(basic_wgrads, block_wgrads):
        self.assertAllClose(basic, block, rtol=1e-6, atol=1e-6)

      self.assertAllClose(basic_outputs, fused_outputs)
      self.assertAllClose(basic_state, fused_state)
      self.assertAllClose(basic_grads, fused_grads)
      for block, fused in zip(block_wgrads, fused_wgrads):
        self.assertAllClose(block, fused, rtol=1e-6, atol=1e-6)
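
The weight sharing above rests on initialized_value(): each *_block variable is created with the corresponding basic variable's initial value as its initializer, so the two graphs start from identical weights while remaining independent variables. A minimal illustration of the pattern:

w = variable_scope.get_variable("w", shape=[2, 2], dtype=dtypes.float32)
# w_copy gets exactly w's initial value, but updates independently afterward.
w_copy = variable_scope.get_variable(
    "w_copy", initializer=w.initialized_value())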