Beispiel #1
0
  def testTimeReversedFusedRNN(self):
    """Checks a fused forward/time-reversed cell pair against a bidirectional RNN.

    A reference graph is built with `rnn.static_bidirectional_rnn` over two
    `BasicRNNCell(10)` cells in the "basic" scope. In the "fused" scope a
    `FusedRNNCellAdaptor` (forward) and a `TimeReversedFusedRNN`-wrapped
    `FusedRNNCellAdaptor` (backward) are applied to the same inputs and their
    outputs are concatenated along axis 2. The test asserts that outputs, both
    final states, input gradients and variable gradients agree.
    """
    with self.cached_session() as sess:
      # The same seeded initializer object is handed to both variable scopes;
      # presumably this is what makes the two sets of weights start out equal.
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890213)
      fw_cell = rnn_cell.BasicRNNCell(10)
      bw_cell = rnn_cell.BasicRNNCell(10)
      batch_size = 5
      input_size = 20
      timelen = 15
      # Input array has shape (timelen, batch_size, input_size).
      inputs = constant_op.constant(
          np.random.randn(timelen, batch_size, input_size))

      # test bi-directional rnn
      with variable_scope.variable_scope("basic", initializer=initializer):
        # The static API takes a list of per-step tensors, so split axis 0.
        unpacked_inputs = array_ops.unstack(inputs)
        outputs, fw_state, bw_state = rnn.static_bidirectional_rnn(
            fw_cell, bw_cell, unpacked_inputs, dtype=dtypes.float64)
        packed_outputs = array_ops.stack(outputs)
        basic_vars = [
            v for v in variables.trainable_variables()
            if v.name.startswith("basic/")
        ]
        sess.run([variables.global_variables_initializer()])
        basic_outputs, basic_fw_state, basic_bw_state = sess.run(
            [packed_outputs, fw_state, bw_state])
        basic_grads = sess.run(gradients_impl.gradients(packed_outputs, inputs))
        basic_wgrads = sess.run(
            gradients_impl.gradients(packed_outputs, basic_vars))

      with variable_scope.variable_scope("fused", initializer=initializer):
        fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
            rnn_cell.BasicRNNCell(10))
        fused_bw_cell = fused_rnn_cell.TimeReversedFusedRNN(
            fused_rnn_cell.FusedRNNCellAdaptor(rnn_cell.BasicRNNCell(10)))
        fw_outputs, fw_state = fused_cell(
            inputs, dtype=dtypes.float64, scope="fw")
        bw_outputs, bw_state = fused_bw_cell(
            inputs, dtype=dtypes.float64, scope="bw")
        # Join forward and backward outputs along the last (feature) axis.
        outputs = array_ops.concat([fw_outputs, bw_outputs], 2)
        fused_vars = [
            v for v in variables.trainable_variables()
            if v.name.startswith("fused/")
        ]
        sess.run([variables.global_variables_initializer()])
        fused_outputs, fused_fw_state, fused_bw_state = sess.run(
            [outputs, fw_state, bw_state])
        fused_grads = sess.run(gradients_impl.gradients(outputs, inputs))
        fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars))

      self.assertAllClose(basic_outputs, fused_outputs)
      self.assertAllClose(basic_fw_state, fused_fw_state)
      self.assertAllClose(basic_bw_state, fused_bw_state)
      self.assertAllClose(basic_grads, fused_grads)
      # zip() stops at the shorter sequence, so without these checks an empty
      # or shorter gradient list would let the loop below pass vacuously.
      self.assertGreater(len(basic_wgrads), 0)
      self.assertEqual(len(basic_wgrads), len(fused_wgrads))
      # Variable gradients are compared with a looser tolerance than the rest.
      for basic, fused in zip(basic_wgrads, fused_wgrads):
        self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)
Beispiel #2
0
    def testBasicRNNFusedWrapper(self):
        """Checks that wrapping BasicRNNCell in FusedRNNCellAdaptor works as expected.

        A reference graph is built with `core_rnn.static_rnn` in the "basic"
        scope. The same cell is then wrapped in `FusedRNNCellAdaptor`, once
        with default arguments ("fused_static" scope) and once with
        `use_dynamic_rnn=True` ("fused_dynamic" scope). For each wrapper the
        outputs, final state, input gradients and variable gradients are
        asserted to agree with the reference.
        """

        with self.test_session() as sess:
            # The same seeded initializer object is handed to every variable
            # scope; presumably this is what makes the weights start out equal.
            initializer = init_ops.random_uniform_initializer(-0.01,
                                                              0.01,
                                                              seed=19890212)
            cell = core_rnn_cell_impl.BasicRNNCell(10)
            batch_size = 5
            input_size = 20
            timelen = 15
            # Input array has shape (timelen, batch_size, input_size).
            inputs = constant_op.constant(
                np.random.randn(timelen, batch_size, input_size))
            with variable_scope.variable_scope("basic",
                                               initializer=initializer):
                # The static API takes a list of per-step tensors.
                unpacked_inputs = array_ops.unstack(inputs)
                outputs, state = core_rnn.static_rnn(cell,
                                                     unpacked_inputs,
                                                     dtype=dtypes.float64)
                packed_outputs = array_ops.stack(outputs)
                basic_vars = [
                    v for v in variables.trainable_variables()
                    if v.name.startswith("basic/")
                ]
                sess.run([variables.global_variables_initializer()])
                basic_outputs, basic_state = sess.run([packed_outputs, state])
                basic_grads = sess.run(
                    gradients_impl.gradients(packed_outputs, inputs))
                basic_wgrads = sess.run(
                    gradients_impl.gradients(packed_outputs, basic_vars))

            # zip() stops at the shorter sequence; make sure the reference
            # actually produced variable gradients to compare against.
            self.assertGreater(len(basic_wgrads), 0)

            with variable_scope.variable_scope("fused_static",
                                               initializer=initializer):
                fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(cell)
                outputs, state = fused_cell(inputs, dtype=dtypes.float64)
                fused_static_vars = [
                    v for v in variables.trainable_variables()
                    if v.name.startswith("fused_static/")
                ]
                sess.run([variables.global_variables_initializer()])
                fused_static_outputs, fused_static_state = sess.run(
                    [outputs, state])
                fused_static_grads = sess.run(
                    gradients_impl.gradients(outputs, inputs))
                fused_static_wgrads = sess.run(
                    gradients_impl.gradients(outputs, fused_static_vars))

            self.assertAllClose(basic_outputs, fused_static_outputs)
            self.assertAllClose(basic_state, fused_static_state)
            self.assertAllClose(basic_grads, fused_static_grads)
            # Guard against a silently truncated comparison in the loop below.
            self.assertEqual(len(basic_wgrads), len(fused_static_wgrads))
            # Variable gradients are compared with a looser tolerance.
            for basic, fused in zip(basic_wgrads, fused_static_wgrads):
                self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)

            with variable_scope.variable_scope("fused_dynamic",
                                               initializer=initializer):
                fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
                    cell, use_dynamic_rnn=True)
                outputs, state = fused_cell(inputs, dtype=dtypes.float64)
                fused_dynamic_vars = [
                    v for v in variables.trainable_variables()
                    if v.name.startswith("fused_dynamic/")
                ]
                sess.run([variables.global_variables_initializer()])
                fused_dynamic_outputs, fused_dynamic_state = sess.run(
                    [outputs, state])
                fused_dynamic_grads = sess.run(
                    gradients_impl.gradients(outputs, inputs))
                fused_dynamic_wgrads = sess.run(
                    gradients_impl.gradients(outputs, fused_dynamic_vars))

            self.assertAllClose(basic_outputs, fused_dynamic_outputs)
            self.assertAllClose(basic_state, fused_dynamic_state)
            self.assertAllClose(basic_grads, fused_dynamic_grads)
            # Guard against a silently truncated comparison in the loop below.
            self.assertEqual(len(basic_wgrads), len(fused_dynamic_wgrads))
            for basic, fused in zip(basic_wgrads, fused_dynamic_wgrads):
                self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)