def testBasicRNNCellNotTrainable(self):
        """Variables created through a non-trainable custom getter stay frozen.

        Every variable requested inside the "root" scope is forced to
        ``trainable=False``. The cell must then report no trainable
        variables, expose its weights and bias as non-trainable variables,
        and still produce a (1, 2) output when run.
        """

        def not_trainable_getter(getter, *args, **kwargs):
            # Override whatever the caller asked for: nothing is trainable.
            kwargs["trainable"] = False
            return getter(*args, **kwargs)

        weights_name = core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME
        bias_name = core_rnn_cell_impl._BIAS_VARIABLE_NAME
        expected_names = [
            "root/basic_rnn_cell/%s:0" % weights_name,
            "root/basic_rnn_cell/%s:0" % bias_name,
        ]

        with self.test_session() as sess:
            const_init = init_ops.constant_initializer(0.5)
            with variable_scope.variable_scope(
                    "root",
                    initializer=const_init,
                    custom_getter=not_trainable_getter):
                inputs = array_ops.zeros([1, 2])
                prev_state = array_ops.zeros([1, 2])
                cell = core_rnn_cell_impl.BasicRNNCell(2)
                output, _ = cell(inputs, prev_state)

                self.assertFalse(cell.trainable_variables)
                self.assertEqual(
                    expected_names,
                    [var.name for var in cell.non_trainable_variables])

                sess.run([variables_lib.global_variables_initializer()])
                feed = {
                    inputs.name: np.array([[1., 1.]]),
                    prev_state.name: np.array([[0.1, 0.1]]),
                }
                (output_value,) = sess.run([output], feed)
                self.assertEqual(output_value.shape, (1, 2))
示例#2
0
    def testTimeReversedFusedRNN(self):
        """Fused forward + time-reversed cells match a static bidirectional RNN.

        A reference bidirectional RNN is built with
        ``core_rnn.static_bidirectional_rnn`` under scope "basic". An
        equivalent pair of fused cells (a forward ``FusedRNNCellAdaptor`` and
        a ``TimeReversedFusedRNN`` wrapping it) is built under scope "fused".
        Outputs, final forward/backward states, input gradients and weight
        gradients of the two constructions must agree.
        """
        with self.test_session() as sess:
            initializer = init_ops.random_uniform_initializer(-0.01,
                                                              0.01,
                                                              seed=19890213)
            cell = core_rnn_cell_impl.BasicRNNCell(10)
            batch_size = 5
            input_size = 20
            timelen = 15
            # Shape (timelen, batch_size, input_size); unstacked along axis 0
            # below to feed the static (list-of-steps) RNN.
            inputs = constant_op.constant(
                np.random.randn(timelen, batch_size, input_size))

            # test bi-directional rnn
            with variable_scope.variable_scope("basic",
                                               initializer=initializer):
                unpacked_inputs = array_ops.unstack(inputs)
                outputs, fw_state, bw_state = core_rnn.static_bidirectional_rnn(
                    cell, cell, unpacked_inputs, dtype=dtypes.float64)
                packed_outputs = array_ops.stack(outputs)
                basic_vars = [
                    v for v in variables.trainable_variables()
                    if v.name.startswith("basic/")
                ]
                sess.run([variables.global_variables_initializer()])
                basic_outputs, basic_fw_state, basic_bw_state = sess.run(
                    [packed_outputs, fw_state, bw_state])
                basic_grads = sess.run(
                    gradients_impl.gradients(packed_outputs, inputs))
                basic_wgrads = sess.run(
                    gradients_impl.gradients(packed_outputs, basic_vars))

            with variable_scope.variable_scope("fused",
                                               initializer=initializer):
                fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(cell)
                fused_bw_cell = fused_rnn_cell.TimeReversedFusedRNN(fused_cell)
                fw_outputs, fw_state = fused_cell(inputs,
                                                  dtype=dtypes.float64,
                                                  scope="fw")
                bw_outputs, bw_state = fused_bw_cell(inputs,
                                                     dtype=dtypes.float64,
                                                     scope="bw")
                # Concatenate forward and backward outputs on the feature axis,
                # mirroring what the bidirectional reference produces.
                outputs = array_ops.concat([fw_outputs, bw_outputs], 2)
                fused_vars = [
                    v for v in variables.trainable_variables()
                    if v.name.startswith("fused/")
                ]
                sess.run([variables.global_variables_initializer()])
                fused_outputs, fused_fw_state, fused_bw_state = sess.run(
                    [outputs, fw_state, bw_state])
                fused_grads = sess.run(
                    gradients_impl.gradients(outputs, inputs))
                fused_wgrads = sess.run(
                    gradients_impl.gradients(outputs, fused_vars))

            self.assertAllClose(basic_outputs, fused_outputs)
            self.assertAllClose(basic_fw_state, fused_fw_state)
            self.assertAllClose(basic_bw_state, fused_bw_state)
            self.assertAllClose(basic_grads, fused_grads)
            # zip() silently stops at the shorter list: without these checks an
            # empty or partial variable list under "fused/" would make the
            # weight-gradient comparison below pass without checking anything.
            self.assertTrue(basic_wgrads)
            self.assertEqual(len(basic_wgrads), len(fused_wgrads))
            for basic, fused in zip(basic_wgrads, fused_wgrads):
                self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)
示例#3
0
 def testBasicRNNCell(self):
   """Running BasicRNNCell(2) on one (1, 2) input yields a (1, 2) output."""
   with self.test_session() as sess:
     const_init = init_ops.constant_initializer(0.5)
     with variable_scope.variable_scope("root", initializer=const_init):
       inputs = array_ops.zeros([1, 2])
       state = array_ops.zeros([1, 2])
       rnn_cell = core_rnn_cell_impl.BasicRNNCell(2)
       output, _ = rnn_cell(inputs, state)
       sess.run([variables_lib.global_variables_initializer()])
       feed_dict = {
           inputs.name: np.array([[1., 1.]]),
           state.name: np.array([[0.1, 0.1]]),
       }
       (output_value,) = sess.run([output], feed_dict)
       self.assertEqual(output_value.shape, (1, 2))
  def setUp(self):
    """Builds the RNN cell, mock target column and feature columns for tests."""
    super(DynamicRnnEstimatorTest, self).setUp()
    self.rnn_cell = core_rnn_cell_impl.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
    self.mock_target_column = MockTargetColumn(
        num_label_columns=self.NUM_LABEL_COLUMNS)

    # Context features: one-hot encoding of the sparse "location" column.
    location_column = feature_column.sparse_column_with_keys(
        'location', keys=['west_side', 'east_side', 'nyc'])
    self.context_feature_columns = [
        feature_column.one_hot_column(location_column),
    ]

    # Sequence features: a real-valued "measurements" column (dimension=2)
    # followed by an 8-dimensional embedding of the sparse "wire_cast" column.
    cast_column = feature_column.sparse_column_with_keys(
        'wire_cast', ['marlo', 'omar', 'stringer'])
    cast_embedding = feature_column.embedding_column(cast_column, dimension=8)
    self.sequence_feature_columns = [
        feature_column.real_valued_column('measurements', dimension=2),
        cast_embedding,
    ]
示例#5
0
 def testBasicRNNCellMatch(self):
   """A _SlimRNNCell around basic_rnn_cell matches BasicRNNCell.

   Both cells are built in the same "root" scope with a constant 0.5
   initializer; their static shapes and evaluated outputs/states must agree.
   """
   batch_size, input_size, num_units = 32, 100, 10
   with self.test_session() as sess:
     with variable_scope.variable_scope(
         "root", initializer=init_ops.constant_initializer(0.5)):
       inputs = random_ops.random_uniform((batch_size, input_size))
       _, initial_state = basic_rnn_cell(inputs, None, num_units)
       core_cell = core_rnn_cell_impl.BasicRNNCell(num_units)
       core_outputs, core_state = core_cell(inputs, initial_state)
       # Put the current variable scope in reuse mode before building the
       # slim cell.
       variable_scope.get_variable_scope().reuse_variables()
       slim_fn = functools.partial(basic_rnn_cell, num_units=num_units)
       # pylint: disable=protected-access
       slim_cell = core_rnn_cell_impl._SlimRNNCell(slim_fn)
       # pylint: enable=protected-access
       slim_outputs, slim_state = slim_cell(inputs, initial_state)
       self.assertEqual(slim_outputs.get_shape(), core_outputs.get_shape())
       self.assertEqual(slim_state.get_shape(), core_state.get_shape())
       sess.run([variables_lib.global_variables_initializer()])
       # Evaluate all four tensors in one run so they see the same random
       # inputs.
       (slim_out_val, slim_state_val,
        core_out_val, core_state_val) = sess.run(
            [slim_outputs, slim_state, core_outputs, core_state])
       self.assertAllClose(slim_out_val, core_out_val)
       self.assertAllClose(slim_state_val, core_state_val)
示例#6
0
    def testBasicRNNFusedWrapper(self):
        """This test checks that using a wrapper for BasicRNN works as expected.

        A reference static RNN is built under scope "basic". The same cell is
        then wrapped in ``FusedRNNCellAdaptor``, once with default arguments
        (scope "fused_static") and once with ``use_dynamic_rnn=True`` (scope
        "fused_dynamic"). Each fused variant must reproduce the reference
        outputs, final state, input gradients and weight gradients.
        """

        with self.test_session() as sess:
            initializer = init_ops.random_uniform_initializer(-0.01,
                                                              0.01,
                                                              seed=19890212)
            cell = core_rnn_cell_impl.BasicRNNCell(10)
            batch_size = 5
            input_size = 20
            timelen = 15
            # Shape (timelen, batch_size, input_size); unstacked along axis 0
            # below to feed the static (list-of-steps) RNN.
            inputs = constant_op.constant(
                np.random.randn(timelen, batch_size, input_size))
            with variable_scope.variable_scope("basic",
                                               initializer=initializer):
                unpacked_inputs = array_ops.unstack(inputs)
                outputs, state = core_rnn.static_rnn(cell,
                                                     unpacked_inputs,
                                                     dtype=dtypes.float64)
                packed_outputs = array_ops.stack(outputs)
                basic_vars = [
                    v for v in variables.trainable_variables()
                    if v.name.startswith("basic/")
                ]
                sess.run([variables.global_variables_initializer()])
                basic_outputs, basic_state = sess.run([packed_outputs, state])
                basic_grads = sess.run(
                    gradients_impl.gradients(packed_outputs, inputs))
                basic_wgrads = sess.run(
                    gradients_impl.gradients(packed_outputs, basic_vars))

            # The reference must own variables, otherwise the weight-gradient
            # comparisons below are empty.
            self.assertTrue(basic_wgrads)

            def _check_fused(scope_name, **adaptor_kwargs):
                # Build a FusedRNNCellAdaptor under `scope_name` and compare it
                # against the "basic" reference computed above.
                with variable_scope.variable_scope(scope_name,
                                                   initializer=initializer):
                    fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
                        cell, **adaptor_kwargs)
                    fused_out, fused_st = fused_cell(inputs,
                                                     dtype=dtypes.float64)
                    fused_vars = [
                        v for v in variables.trainable_variables()
                        if v.name.startswith(scope_name + "/")
                    ]
                    sess.run([variables.global_variables_initializer()])
                    fused_outputs, fused_state = sess.run(
                        [fused_out, fused_st])
                    fused_grads = sess.run(
                        gradients_impl.gradients(fused_out, inputs))
                    fused_wgrads = sess.run(
                        gradients_impl.gradients(fused_out, fused_vars))

                self.assertAllClose(basic_outputs, fused_outputs)
                self.assertAllClose(basic_state, fused_state)
                self.assertAllClose(basic_grads, fused_grads)
                # zip() silently stops at the shorter list; require matching
                # variable counts so every weight gradient is really compared.
                self.assertEqual(len(basic_wgrads), len(fused_wgrads))
                for basic, fused in zip(basic_wgrads, fused_wgrads):
                    self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)

            _check_fused("fused_static")
            _check_fused("fused_dynamic", use_dynamic_rnn=True)