def testBasicRNNCellNotTrainable(self):
  """Tests that a custom getter can mark cell variables non-trainable."""
  with self.test_session() as sess:

    def not_trainable_getter(getter, *args, **kwargs):
      kwargs["trainable"] = False
      return getter(*args, **kwargs)

    with variable_scope.variable_scope(
        "root",
        initializer=init_ops.constant_initializer(0.5),
        custom_getter=not_trainable_getter):
      x = array_ops.zeros([1, 2])
      m = array_ops.zeros([1, 2])
      cell = core_rnn_cell_impl.BasicRNNCell(2)
      g, _ = cell(x, m)
      self.assertFalse(cell.trainable_variables)
      self.assertEqual([
          "root/basic_rnn_cell/%s:0" %
          core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
          "root/basic_rnn_cell/%s:0" % core_rnn_cell_impl._BIAS_VARIABLE_NAME
      ], [v.name for v in cell.non_trainable_variables])
      sess.run([variables_lib.global_variables_initializer()])
      res = sess.run([g], {
          x.name: np.array([[1., 1.]]),
          m.name: np.array([[0.1, 0.1]])
      })
      self.assertEqual(res[0].shape, (1, 2))

def testTimeReversedFusedRNN(self):
  """Tests that TimeReversedFusedRNN matches static_bidirectional_rnn."""
  with self.test_session() as sess:
    initializer = init_ops.random_uniform_initializer(
        -0.01, 0.01, seed=19890213)
    cell = core_rnn_cell_impl.BasicRNNCell(10)
    batch_size = 5
    input_size = 20
    timelen = 15
    inputs = constant_op.constant(
        np.random.randn(timelen, batch_size, input_size))

    # Test bi-directional RNN.
    with variable_scope.variable_scope("basic", initializer=initializer):
      unpacked_inputs = array_ops.unstack(inputs)
      outputs, fw_state, bw_state = core_rnn.static_bidirectional_rnn(
          cell, cell, unpacked_inputs, dtype=dtypes.float64)
      packed_outputs = array_ops.stack(outputs)
      basic_vars = [
          v for v in variables.trainable_variables()
          if v.name.startswith("basic/")
      ]
      sess.run([variables.global_variables_initializer()])
      basic_outputs, basic_fw_state, basic_bw_state = sess.run(
          [packed_outputs, fw_state, bw_state])
      basic_grads = sess.run(gradients_impl.gradients(packed_outputs, inputs))
      basic_wgrads = sess.run(
          gradients_impl.gradients(packed_outputs, basic_vars))

    with variable_scope.variable_scope("fused", initializer=initializer):
      fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(cell)
      fused_bw_cell = fused_rnn_cell.TimeReversedFusedRNN(fused_cell)
      fw_outputs, fw_state = fused_cell(
          inputs, dtype=dtypes.float64, scope="fw")
      bw_outputs, bw_state = fused_bw_cell(
          inputs, dtype=dtypes.float64, scope="bw")
      outputs = array_ops.concat([fw_outputs, bw_outputs], 2)
      fused_vars = [
          v for v in variables.trainable_variables()
          if v.name.startswith("fused/")
      ]
      sess.run([variables.global_variables_initializer()])
      fused_outputs, fused_fw_state, fused_bw_state = sess.run(
          [outputs, fw_state, bw_state])
      fused_grads = sess.run(gradients_impl.gradients(outputs, inputs))
      fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars))

    self.assertAllClose(basic_outputs, fused_outputs)
    self.assertAllClose(basic_fw_state, fused_fw_state)
    self.assertAllClose(basic_bw_state, fused_bw_state)
    self.assertAllClose(basic_grads, fused_grads)
    for basic, fused in zip(basic_wgrads, fused_wgrads):
      self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)

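# Illustrative note, not part of the original test: TimeReversedFusedRNN wraps
# a fused cell, reverses its time-major input along the time axis, runs the
# wrapped cell, and reverses the outputs back, which is why the fw/bw pair
# above lines up with static_bidirectional_rnn.  A minimal construction
# sketch, assuming the same `cell` and time-major `inputs` of shape
# [time, batch, input_size]:
#
#   fw = fused_rnn_cell.FusedRNNCellAdaptor(cell)
#   bw = fused_rnn_cell.TimeReversedFusedRNN(fw)
#   fw_out, fw_st = fw(inputs, dtype=dtypes.float64, scope="fw")
#   bw_out, bw_st = bw(inputs, dtype=dtypes.float64, scope="bw")
#   bidirectional_out = array_ops.concat([fw_out, bw_out], 2)
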
def testBasicRNNCell(self):
  """Tests that BasicRNNCell produces output of the expected shape."""
  with self.test_session() as sess:
    with variable_scope.variable_scope(
        "root", initializer=init_ops.constant_initializer(0.5)):
      x = array_ops.zeros([1, 2])
      m = array_ops.zeros([1, 2])
      g, _ = core_rnn_cell_impl.BasicRNNCell(2)(x, m)
      sess.run([variables_lib.global_variables_initializer()])
      res = sess.run([g], {
          x.name: np.array([[1., 1.]]),
          m.name: np.array([[0.1, 0.1]])
      })
      self.assertEqual(res[0].shape, (1, 2))

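# Illustrative sketch, not part of the original test: with the scope
# initializer filling the cell's weight matrix with 0.5, BasicRNNCell computes
# output = new_state = tanh([x, m] @ W + b).  Assuming the bias variable
# starts at zero, the helper below reproduces the expected activation for the
# inputs above with plain NumPy (the helper name is hypothetical and only
# serves as a reference computation).
def _expected_basic_rnn_output(x, m, weight_value=0.5, bias_value=0.0):
  """Reference tanh([x, m] @ W + b) with constant-filled parameters."""
  concat = np.concatenate([x, m], axis=1)
  num_units = m.shape[1]
  w = np.full((concat.shape[1], num_units), weight_value)
  b = np.full((num_units,), bias_value)
  # For x = [[1., 1.]] and m = [[0.1, 0.1]] this gives tanh(1.1) ~= 0.8005.
  return np.tanh(concat.dot(w) + b)
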
def setUp(self):
  super(DynamicRnnEstimatorTest, self).setUp()
  self.rnn_cell = core_rnn_cell_impl.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
  self.mock_target_column = MockTargetColumn(
      num_label_columns=self.NUM_LABEL_COLUMNS)

  location = feature_column.sparse_column_with_keys(
      'location', keys=['west_side', 'east_side', 'nyc'])
  location_onehot = feature_column.one_hot_column(location)
  self.context_feature_columns = [location_onehot]

  wire_cast = feature_column.sparse_column_with_keys(
      'wire_cast', ['marlo', 'omar', 'stringer'])
  wire_cast_embedded = feature_column.embedding_column(wire_cast, dimension=8)
  measurements = feature_column.real_valued_column(
      'measurements', dimension=2)
  self.sequence_feature_columns = [measurements, wire_cast_embedded]

def testBasicRNNCellMatch(self):
  """Tests that _SlimRNNCell matches BasicRNNCell on the same variables."""
  batch_size = 32
  input_size = 100
  num_units = 10
  with self.test_session() as sess:
    with variable_scope.variable_scope(
        "root", initializer=init_ops.constant_initializer(0.5)):
      inputs = random_ops.random_uniform((batch_size, input_size))
      _, initial_state = basic_rnn_cell(inputs, None, num_units)

      rnn_cell = core_rnn_cell_impl.BasicRNNCell(num_units)
      outputs, state = rnn_cell(inputs, initial_state)
      variable_scope.get_variable_scope().reuse_variables()

      my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
      # pylint: disable=protected-access
      slim_cell = core_rnn_cell_impl._SlimRNNCell(my_cell)
      # pylint: enable=protected-access
      slim_outputs, slim_state = slim_cell(inputs, initial_state)

      self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
      self.assertEqual(slim_state.get_shape(), state.get_shape())
      sess.run([variables_lib.global_variables_initializer()])
      res = sess.run([slim_outputs, slim_state, outputs, state])
      self.assertAllClose(res[0], res[2])
      self.assertAllClose(res[1], res[3])

def testBasicRNNFusedWrapper(self):
  """This test checks that using a wrapper for BasicRNN works as expected."""
  with self.test_session() as sess:
    initializer = init_ops.random_uniform_initializer(
        -0.01, 0.01, seed=19890212)
    cell = core_rnn_cell_impl.BasicRNNCell(10)
    batch_size = 5
    input_size = 20
    timelen = 15
    inputs = constant_op.constant(
        np.random.randn(timelen, batch_size, input_size))

    with variable_scope.variable_scope("basic", initializer=initializer):
      unpacked_inputs = array_ops.unstack(inputs)
      outputs, state = core_rnn.static_rnn(
          cell, unpacked_inputs, dtype=dtypes.float64)
      packed_outputs = array_ops.stack(outputs)
      basic_vars = [
          v for v in variables.trainable_variables()
          if v.name.startswith("basic/")
      ]
      sess.run([variables.global_variables_initializer()])
      basic_outputs, basic_state = sess.run([packed_outputs, state])
      basic_grads = sess.run(gradients_impl.gradients(packed_outputs, inputs))
      basic_wgrads = sess.run(
          gradients_impl.gradients(packed_outputs, basic_vars))

    with variable_scope.variable_scope("fused_static",
                                       initializer=initializer):
      fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(cell)
      outputs, state = fused_cell(inputs, dtype=dtypes.float64)
      fused_static_vars = [
          v for v in variables.trainable_variables()
          if v.name.startswith("fused_static/")
      ]
      sess.run([variables.global_variables_initializer()])
      fused_static_outputs, fused_static_state = sess.run([outputs, state])
      fused_static_grads = sess.run(gradients_impl.gradients(outputs, inputs))
      fused_static_wgrads = sess.run(
          gradients_impl.gradients(outputs, fused_static_vars))

    self.assertAllClose(basic_outputs, fused_static_outputs)
    self.assertAllClose(basic_state, fused_static_state)
    self.assertAllClose(basic_grads, fused_static_grads)
    for basic, fused in zip(basic_wgrads, fused_static_wgrads):
      self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)

    with variable_scope.variable_scope("fused_dynamic",
                                       initializer=initializer):
      fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
          cell, use_dynamic_rnn=True)
      outputs, state = fused_cell(inputs, dtype=dtypes.float64)
      fused_dynamic_vars = [
          v for v in variables.trainable_variables()
          if v.name.startswith("fused_dynamic/")
      ]
      sess.run([variables.global_variables_initializer()])
      fused_dynamic_outputs, fused_dynamic_state = sess.run([outputs, state])
      fused_dynamic_grads = sess.run(
          gradients_impl.gradients(outputs, inputs))
      fused_dynamic_wgrads = sess.run(
          gradients_impl.gradients(outputs, fused_dynamic_vars))

    self.assertAllClose(basic_outputs, fused_dynamic_outputs)
    self.assertAllClose(basic_state, fused_dynamic_state)
    self.assertAllClose(basic_grads, fused_dynamic_grads)
    for basic, fused in zip(basic_wgrads, fused_dynamic_wgrads):
      self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)

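# Illustrative note, not part of the original test: FusedRNNCellAdaptor
# consumes time-major inputs of shape [time, batch, input_size], matching the
# `inputs` constant built above.  A minimal construction sketch; with
# use_dynamic_rnn=True the adaptor unrolls the wrapped cell dynamically
# instead of via the static unstack/stack path exercised in "fused_static":
#
#   adaptor = fused_rnn_cell.FusedRNNCellAdaptor(
#       core_rnn_cell_impl.BasicRNNCell(10), use_dynamic_rnn=True)
#   outputs, final_state = adaptor(inputs, dtype=dtypes.float64)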