def testGetOutputAlternatives(self):
  """Checks `_get_output_alternatives` for several prediction setups.

  For `SINGLE_VALUE` predictions the expected result is a one-entry dict
  keyed by 'dynamic_rnn_output' that pairs the problem type with the
  prediction keys only (no state entries). For `MULTIPLE_VALUE` predictions
  the expected result is `None`.
  """
  classes = prediction_key.PredictionKey.CLASSES
  probabilities = prediction_key.PredictionKey.PROBABILITIES
  scores = prediction_key.PredictionKey.SCORES
  classification = constants.ProblemType.CLASSIFICATION
  regression = constants.ProblemType.LINEAR_REGRESSION

  def state(index):
    return dynamic_rnn_estimator._get_state_name(index)

  test_cases = [
      # Single-value classification: expected output keeps only the
      # class/probability keys.
      (dynamic_rnn_estimator.PredictionType.SINGLE_VALUE, classification,
       {classes: True, probabilities: True, state(0): True},
       {'dynamic_rnn_output': (classification,
                               {classes: True, probabilities: True})}),
      # Single-value regression with two state entries in the input.
      (dynamic_rnn_estimator.PredictionType.SINGLE_VALUE, regression,
       {scores: True, state(0): True, state(1): True},
       {'dynamic_rnn_output': (regression, {scores: True})}),
      # Multi-value predictions are expected to produce no alternatives.
      (dynamic_rnn_estimator.PredictionType.MULTIPLE_VALUE, classification,
       {classes: True, probabilities: True, state(0): True},
       None),
  ]
  for prediction_type, problem_type, predictions, expected in test_cases:
    actual = dynamic_rnn_estimator._get_output_alternatives(
        prediction_type, problem_type, predictions)
    self.assertEqual(expected, actual)
def testGetOutputAlternatives(self):
  """Checks `_get_output_alternatives` for several prediction setups.

  For `SINGLE_VALUE` predictions the expected result is a one-entry dict
  keyed by 'dynamic_rnn_output' that pairs the problem type with the
  prediction keys only (no state entries). For `MULTIPLE_VALUE` predictions
  the expected result is `None`.
  """
  classes = prediction_key.PredictionKey.CLASSES
  probabilities = prediction_key.PredictionKey.PROBABILITIES
  scores = prediction_key.PredictionKey.SCORES
  classification = constants.ProblemType.CLASSIFICATION
  regression = constants.ProblemType.LINEAR_REGRESSION

  def state(index):
    return dynamic_rnn_estimator._get_state_name(index)

  test_cases = [
      # Single-value classification: expected output keeps only the
      # class/probability keys.
      (rnn_common.PredictionType.SINGLE_VALUE, classification,
       {classes: True, probabilities: True, state(0): True},
       {'dynamic_rnn_output': (classification,
                               {classes: True, probabilities: True})}),
      # Single-value regression with two state entries in the input.
      (rnn_common.PredictionType.SINGLE_VALUE, regression,
       {scores: True, state(0): True, state(1): True},
       {'dynamic_rnn_output': (regression, {scores: True})}),
      # Multi-value predictions are expected to produce no alternatives.
      (rnn_common.PredictionType.MULTIPLE_VALUE, classification,
       {classes: True, probabilities: True, state(0): True},
       None),
  ]
  for prediction_type, problem_type, predictions, expected in test_cases:
    actual = dynamic_rnn_estimator._get_output_alternatives(
        prediction_type, problem_type, predictions)
    self.assertEqual(expected, actual)
def testMultiRNNState(self):
  """Test that state flattening/reconstruction works for `MultiRNNCell`.

  A three-layer LSTM `DynamicRnnEstimator` is fit on a shift-by-one task.
  Random initial-state tensors are fed as features. The test then checks
  that `predict` returns one flattened state piece per LSTM state component,
  each shaped `[batch_size, layer_size]`.
  """
  batch_size = 11
  sequence_length = 16
  train_steps = 5
  cell_sizes = [4, 8, 7]
  learning_rate = 0.1
  # Each LSTM layer contributes two flattened state pieces of the layer's
  # size. Derive them from `cell_sizes` so the two stay in sync.
  state_sizes = [size for size in cell_sizes for _ in range(2)]

  def get_shift_input_fn(batch_size, sequence_length, seed=None):

    def input_fn():
      random_sequence = random_ops.random_uniform(
          [batch_size, sequence_length + 1], 0, 2, dtype=dtypes.int32,
          seed=seed)
      labels = array_ops.slice(random_sequence, [0, 0],
                               [batch_size, sequence_length])
      inputs = array_ops.expand_dims(
          math_ops.cast(
              array_ops.slice(random_sequence, [0, 1],
                              [batch_size, sequence_length]),
              dtypes.float32), 2)
      # Give each state piece a distinct seed. When no seed was given, leave
      # the op unseeded instead of failing on `(i + 1) * None`.
      input_dict = {
          dynamic_rnn_estimator._get_state_name(i): random_ops.random_uniform(
              [batch_size, state_size],
              seed=None if seed is None else (i + 1) * seed)
          for i, state_size in enumerate(state_sizes)
      }
      input_dict['inputs'] = inputs
      return input_dict, labels

    return input_fn

  seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
  config = run_config.RunConfig(tf_random_seed=21212)
  cell_type = 'lstm'
  sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
      problem_type=constants.ProblemType.CLASSIFICATION,
      prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
      num_classes=2,
      num_units=cell_sizes,
      sequence_feature_columns=seq_columns,
      cell_type=cell_type,
      learning_rate=learning_rate,
      config=config,
      predict_probabilities=True)

  train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
  eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

  sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

  prediction_dict = sequence_estimator.predict(
      input_fn=eval_input_fn, as_iterable=False)
  for i, state_size in enumerate(state_sizes):
    state_piece = prediction_dict[dynamic_rnn_estimator._get_state_name(i)]
    self.assertListEqual(list(state_piece.shape), [batch_size, state_size])
def testMultiRNNState(self):
  """Test that state flattening/reconstruction works for `MultiRNNCell`.

  A three-layer LSTM `DynamicRnnEstimator` is fit on a shift-by-one task.
  Random initial-state tensors are fed as features. The test then checks
  that `predict` returns one flattened state piece per LSTM state component,
  each shaped `[batch_size, layer_size]`.
  """
  batch_size = 11
  sequence_length = 16
  train_steps = 5
  cell_sizes = [4, 8, 7]
  learning_rate = 0.1
  # Each LSTM layer contributes two flattened state pieces of the layer's
  # size. Derive them from `cell_sizes` so the two stay in sync.
  state_sizes = [size for size in cell_sizes for _ in range(2)]

  def get_shift_input_fn(batch_size, sequence_length, seed=None):

    def input_fn():
      random_sequence = random_ops.random_uniform(
          [batch_size, sequence_length + 1], 0, 2, dtype=dtypes.int32,
          seed=seed)
      labels = array_ops.slice(random_sequence, [0, 0],
                               [batch_size, sequence_length])
      inputs = array_ops.expand_dims(
          math_ops.to_float(
              array_ops.slice(random_sequence, [0, 1],
                              [batch_size, sequence_length])), 2)
      # Give each state piece a distinct seed. When no seed was given, leave
      # the op unseeded instead of failing on `(i + 1) * None`.
      input_dict = {
          dynamic_rnn_estimator._get_state_name(i): random_ops.random_uniform(
              [batch_size, state_size],
              seed=None if seed is None else (i + 1) * seed)
          for i, state_size in enumerate(state_sizes)
      }
      input_dict['inputs'] = inputs
      return input_dict, labels

    return input_fn

  seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
  config = run_config.RunConfig(tf_random_seed=21212)
  cell_type = 'lstm'
  sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
      problem_type=constants.ProblemType.CLASSIFICATION,
      prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
      num_classes=2,
      num_units=cell_sizes,
      sequence_feature_columns=seq_columns,
      cell_type=cell_type,
      learning_rate=learning_rate,
      config=config,
      predict_probabilities=True)

  train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
  eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

  sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

  prediction_dict = sequence_estimator.predict(
      input_fn=eval_input_fn, as_iterable=False)
  for i, state_size in enumerate(state_sizes):
    state_piece = prediction_dict[dynamic_rnn_estimator._get_state_name(i)]
    self.assertListEqual(list(state_piece.shape), [batch_size, state_size])
def input_fn():
  """Returns one batch of shift-by-one features and labels plus state inputs.

  `batch_size`, `sequence_length` and `seed` come from the enclosing scope.
  A random 0/1 int32 sequence of length `sequence_length + 1` is drawn.
  Labels are its first `sequence_length` columns. The float `inputs`
  feature (with a trailing dimension of 1) is the same window shifted right
  by one column. Six random state tensors of widths [4, 4, 8, 8, 7, 7] are
  added under `_get_state_name(i)` keys.
  """
  random_sequence = tf.random_uniform(
      [batch_size, sequence_length + 1], 0, 2, dtype=tf.int32, seed=seed)
  labels = tf.slice(random_sequence, [0, 0], [batch_size, sequence_length])
  inputs = tf.expand_dims(
      tf.to_float(
          tf.slice(random_sequence, [0, 1], [batch_size, sequence_length])),
      2)
  # Give each state piece a distinct seed. When `seed` is None, leave the op
  # unseeded instead of failing on `(i + 1) * None`.
  input_dict = {
      dynamic_rnn_estimator._get_state_name(i): tf.random_uniform(
          [batch_size, cell_size],
          seed=None if seed is None else (i + 1) * seed)
      for i, cell_size in enumerate([4, 4, 8, 8, 7, 7])
  }
  input_dict['inputs'] = inputs
  return input_dict, labels
def testMultiRNNState(self):
  """Test that state flattening/reconstruction works for `MultiRNNCell`.

  A `MultiRNNCell` of three `BasicLSTMCell`s is wrapped in a multi-value
  classifier and fit on a shift-by-one task. Random initial-state tensors
  are fed as features. The test then checks that `predict` returns one
  flattened state piece per LSTM state component, each shaped
  `[batch_size, layer_size]`.
  """
  batch_size = 11
  sequence_length = 16
  train_steps = 5
  cell_sizes = [4, 8, 7]
  learning_rate = 0.1
  # Each LSTM layer contributes two flattened state pieces of the layer's
  # size. Derive them from `cell_sizes` so the two stay in sync.
  state_sizes = [size for size in cell_sizes for _ in range(2)]

  def get_shift_input_fn(batch_size, sequence_length, seed=None):

    def input_fn():
      random_sequence = tf.random_uniform(
          [batch_size, sequence_length + 1], 0, 2, dtype=tf.int32, seed=seed)
      labels = tf.slice(
          random_sequence, [0, 0], [batch_size, sequence_length])
      inputs = tf.expand_dims(
          tf.to_float(tf.slice(
              random_sequence, [0, 1], [batch_size, sequence_length])), 2)
      # Give each state piece a distinct seed. When no seed was given, leave
      # the op unseeded instead of failing on `(i + 1) * None`.
      input_dict = {
          dynamic_rnn_estimator._get_state_name(i): tf.random_uniform(
              [batch_size, state_size],
              seed=None if seed is None else (i + 1) * seed)
          for i, state_size in enumerate(state_sizes)
      }
      input_dict['inputs'] = inputs
      return input_dict, labels

    return input_fn

  seq_columns = [tf.contrib.layers.real_valued_column(
      'inputs', dimension=1)]
  config = tf.contrib.learn.RunConfig(tf_random_seed=21212)
  cell = tf.contrib.rnn.MultiRNNCell(
      [tf.contrib.rnn.BasicLSTMCell(size) for size in cell_sizes])
  sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_classifier(
      num_classes=2,
      num_units=None,
      sequence_feature_columns=seq_columns,
      cell_type=cell,
      learning_rate=learning_rate,
      config=config,
      predict_probabilities=True)

  train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
  eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

  sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

  prediction_dict = sequence_estimator.predict(
      input_fn=eval_input_fn, as_iterable=False)
  for i, state_size in enumerate(state_sizes):
    state_piece = prediction_dict[dynamic_rnn_estimator._get_state_name(i)]
    self.assertListEqual(list(state_piece.shape), [batch_size, state_size])
def testStateTupleDictConversion(self):
  """Test `state_tuple_to_dict` and `dict_to_state_tuple`.

  A flat dict of state pieces is converted into the nested state structure
  of a `MultiRNNCell` of `LSTMCell`s and then flattened back. The nested
  result must match the expected `LSTMStateTuple`s exactly, and the round
  trip must reproduce every original piece.
  """
  cell_sizes = [5, 3, 7]
  # A MultiRNNCell of LSTMCells is both a common choice and an interesting
  # test case, because it has two levels of nesting, with an inner class that
  # is not a plain tuple.
  cell = core_rnn_cell_impl.MultiRNNCell(
      [core_rnn_cell_impl.LSTMCell(size) for size in cell_sizes])

  # Two pieces per layer, each a [1, size] row holding 0 .. size - 1.
  piece_sizes = [size for size in cell_sizes for _ in range(2)]
  state_dict = {
      dynamic_rnn_estimator._get_state_name(index):
          array_ops.expand_dims(math_ops.range(size), 0)
      for index, size in enumerate(piece_sizes)
  }

  def _row(size):
    return np.reshape(np.arange(size), [1, -1])

  expected_state = tuple(
      core_rnn_cell_impl.LSTMStateTuple(_row(size), _row(size))
      for size in cell_sizes)

  actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell)
  flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(actual_state)

  with self.test_session() as sess:
    state_dict_val, actual_state_val, flattened_state_val = sess.run(
        [state_dict, actual_state, flattened_state])

  def _recursive_assert_equal(x, y):
    # Compare the nested structures type-for-type, then leaf arrays by value.
    self.assertEqual(type(x), type(y))
    if isinstance(x, (list, tuple)):
      self.assertEqual(len(x), len(y))
      for x_item, y_item in zip(x, y):
        _recursive_assert_equal(x_item, y_item)
    elif isinstance(x, np.ndarray):
      np.testing.assert_array_equal(x, y)
    else:
      self.fail('Unexpected type: {}'.format(type(x)))

  for name, original_piece in state_dict_val.items():
    np.testing.assert_array_almost_equal(
        original_piece,
        flattened_state_val[name],
        err_msg='Wrong value for state component {}.'.format(name))
  _recursive_assert_equal(expected_state, actual_state_val)
def testStateTupleDictConversion(self):
  """Test `state_tuple_to_dict` and `dict_to_state_tuple`.

  A flat dict of state pieces is converted into the nested state structure
  of a `MultiRNNCell` of `LSTMCell`s and then flattened back. The nested
  result must match the expected `LSTMStateTuple`s exactly, and the round
  trip must reproduce every original piece.
  """
  cell_sizes = [5, 3, 7]
  # A MultiRNNCell of LSTMCells is both a common choice and an interesting
  # test case, because it has two levels of nesting, with an inner class that
  # is not a plain tuple.
  cell = core_rnn_cell_impl.MultiRNNCell(
      [core_rnn_cell_impl.LSTMCell(size) for size in cell_sizes])

  # Two pieces per layer, each a [1, size] row holding 0 .. size - 1.
  piece_sizes = [size for size in cell_sizes for _ in range(2)]
  state_dict = {
      dynamic_rnn_estimator._get_state_name(index):
          array_ops.expand_dims(math_ops.range(size), 0)
      for index, size in enumerate(piece_sizes)
  }

  def _row(size):
    return np.reshape(np.arange(size), [1, -1])

  expected_state = tuple(
      core_rnn_cell_impl.LSTMStateTuple(_row(size), _row(size))
      for size in cell_sizes)

  actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell)
  flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(actual_state)

  with self.test_session() as sess:
    state_dict_val, actual_state_val, flattened_state_val = sess.run(
        [state_dict, actual_state, flattened_state])

  def _recursive_assert_equal(x, y):
    # Compare the nested structures type-for-type, then leaf arrays by value.
    self.assertEqual(type(x), type(y))
    if isinstance(x, (list, tuple)):
      self.assertEqual(len(x), len(y))
      for x_item, y_item in zip(x, y):
        _recursive_assert_equal(x_item, y_item)
    elif isinstance(x, np.ndarray):
      np.testing.assert_array_equal(x, y)
    else:
      self.fail('Unexpected type: {}'.format(type(x)))

  for name, original_piece in state_dict_val.items():
    np.testing.assert_array_almost_equal(
        original_piece,
        flattened_state_val[name],
        err_msg='Wrong value for state component {}.'.format(name))
  _recursive_assert_equal(expected_state, actual_state_val)
def testLearnMajority(self):
  """Test learning the 'majority' function.

  Each example is a random 0/1 sequence. Its label is 1 when the sum of its
  entries exceeds half the sequence length. A single-value LSTM classifier
  trained with Momentum must beat `accuracy_threshold` on held-out batches.
  Its `predict` output must then hold classes, probabilities and two state
  pieces.
  """
  batch_size = 16
  sequence_length = 7
  train_steps = 200
  eval_steps = 20
  cell_type = 'lstm'
  cell_size = 4
  optimizer_type = 'Momentum'
  learning_rate = 2.0
  momentum = 0.9
  accuracy_threshold = 0.9

  def get_majority_input_fn(batch_size, sequence_length, seed=None):
    # NOTE(review): this seeds whichever graph is the default when the
    # factory is called, not necessarily the graph the estimator builds.
    # Confirm it is still needed.
    random_seed.set_random_seed(seed)

    def input_fn():
      bits = random_ops.random_uniform(
          [batch_size, sequence_length], 0, 2, dtype=dtypes.int32, seed=seed)
      inputs = array_ops.expand_dims(math_ops.to_float(bits), 2)
      ones_per_sequence = math_ops.reduce_sum(inputs, reduction_indices=[1])
      is_majority = ones_per_sequence > (sequence_length / 2.0)
      labels = math_ops.to_int32(array_ops.squeeze(is_majority))
      return {'inputs': inputs}, labels

    return input_fn

  seq_columns = [
      # NOTE(review): `inputs` is built with a trailing dimension of 1
      # (`expand_dims(..., 2)`), but this column declares
      # `dimension=cell_size`. Confirm the mismatch is intended.
      feature_column.real_valued_column('inputs', dimension=cell_size)
  ]
  config = run_config.RunConfig(tf_random_seed=77)
  sequence_classifier = dynamic_rnn_estimator.single_value_rnn_classifier(
      num_classes=2,
      num_units=cell_size,
      sequence_feature_columns=seq_columns,
      cell_type=cell_type,
      optimizer_type=optimizer_type,
      learning_rate=learning_rate,
      momentum=momentum,
      config=config,
      predict_probabilities=True)

  train_input_fn = get_majority_input_fn(batch_size, sequence_length, 1111)
  eval_input_fn = get_majority_input_fn(batch_size, sequence_length, 2222)

  sequence_classifier.fit(input_fn=train_input_fn, steps=train_steps)
  accuracy = sequence_classifier.evaluate(
      input_fn=eval_input_fn, steps=eval_steps)['accuracy']
  self.assertGreater(
      accuracy, accuracy_threshold,
      'Accuracy should be higher than {}; got {}'.format(
          accuracy_threshold, accuracy))

  # Testing `predict` when `predict_probabilities=True`.
  prediction_dict = sequence_classifier.predict(
      input_fn=eval_input_fn, as_iterable=False)
  predictions_key = dynamic_rnn_estimator.RNNKeys.PREDICTIONS_KEY
  probabilities_key = dynamic_rnn_estimator.RNNKeys.PROBABILITIES_KEY
  expected_keys = [
      predictions_key,
      probabilities_key,
      dynamic_rnn_estimator._get_state_name(0),
      dynamic_rnn_estimator._get_state_name(1),
  ]
  self.assertListEqual(sorted(prediction_dict.keys()), sorted(expected_keys))

  predictions = prediction_dict[predictions_key]
  probabilities = prediction_dict[probabilities_key]
  self.assertListEqual(list(predictions.shape), [batch_size])
  self.assertListEqual(list(probabilities.shape), [batch_size, 2])
def testLearnShiftByOne(self):
  """Tests learning a 'shift-by-one' sequence task.

  Each label sequence consists of the input sequence 'shifted' by one place,
  so the RNN must learn to 'remember' the previous input. A multi-value
  classifier must beat `accuracy_threshold` on held-out batches. Its
  `predict` output must then hold per-step classes, per-step probabilities
  and a single state piece.
  """
  batch_size = 16
  sequence_length = 32
  train_steps = 200
  eval_steps = 20
  cell_size = 4
  learning_rate = 0.3
  accuracy_threshold = 0.9

  def get_shift_input_fn(batch_size, sequence_length, seed=None):

    def input_fn():
      # Draw one extra step so labels and inputs can be offset windows of
      # the same random 0/1 sequence.
      bits = random_ops.random_uniform(
          [batch_size, sequence_length + 1], 0, 2, dtype=dtypes.int32,
          seed=seed)
      window = [batch_size, sequence_length]
      labels = array_ops.slice(bits, [0, 0], window)
      shifted = array_ops.slice(bits, [0, 1], window)
      inputs = array_ops.expand_dims(math_ops.to_float(shifted), 2)
      return {'inputs': inputs}, labels

    return input_fn

  seq_columns = [
      # NOTE(review): `inputs` is built with a trailing dimension of 1
      # (`expand_dims(..., 2)`), but this column declares
      # `dimension=cell_size`. Confirm the mismatch is intended.
      feature_column.real_valued_column('inputs', dimension=cell_size)
  ]
  config = run_config.RunConfig(tf_random_seed=21212)
  sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_classifier(
      num_classes=2,
      num_units=cell_size,
      sequence_feature_columns=seq_columns,
      learning_rate=learning_rate,
      config=config,
      predict_probabilities=True)

  train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
  eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

  sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
  accuracy = sequence_estimator.evaluate(
      input_fn=eval_input_fn, steps=eval_steps)['accuracy']
  self.assertGreater(
      accuracy, accuracy_threshold,
      'Accuracy should be higher than {}; got {}'.format(
          accuracy_threshold, accuracy))

  # Testing `predict` when `predict_probabilities=True`.
  prediction_dict = sequence_estimator.predict(
      input_fn=eval_input_fn, as_iterable=False)
  predictions_key = dynamic_rnn_estimator.RNNKeys.PREDICTIONS_KEY
  probabilities_key = dynamic_rnn_estimator.RNNKeys.PROBABILITIES_KEY
  expected_keys = [
      predictions_key,
      probabilities_key,
      dynamic_rnn_estimator._get_state_name(0),
  ]
  self.assertListEqual(sorted(prediction_dict.keys()), sorted(expected_keys))

  predictions = prediction_dict[predictions_key]
  probabilities = prediction_dict[probabilities_key]
  self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
  self.assertListEqual(
      list(probabilities.shape), [batch_size, sequence_length, 2])
def testLearnMajority(self):
  """Test learning the 'majority' function.

  Each example is a random 0/1 sequence. Its label is 1 when the sum of its
  entries exceeds half the sequence length. A single-value LSTM classifier
  trained with Momentum must beat `accuracy_threshold` on held-out batches.
  Its `predict` output must then hold classes, probabilities and two state
  pieces.
  """
  batch_size = 16
  sequence_length = 7
  train_steps = 200
  eval_steps = 20
  cell_type = 'lstm'
  cell_size = 4
  optimizer_type = 'Momentum'
  learning_rate = 2.0
  momentum = 0.9
  accuracy_threshold = 0.9

  def get_majority_input_fn(batch_size, sequence_length, seed=None):
    # NOTE(review): this seeds whichever graph is the default when the
    # factory is called, not necessarily the graph the estimator builds.
    # Confirm it is still needed.
    tf.set_random_seed(seed)

    def input_fn():
      bits = tf.random_uniform(
          [batch_size, sequence_length], 0, 2, dtype=tf.int32, seed=seed)
      inputs = tf.expand_dims(tf.to_float(bits), 2)
      ones_per_sequence = tf.reduce_sum(inputs, reduction_indices=[1])
      is_majority = ones_per_sequence > (sequence_length / 2.0)
      labels = tf.to_int32(tf.squeeze(is_majority))
      return {'inputs': inputs}, labels

    return input_fn

  seq_columns = [
      # NOTE(review): `inputs` is built with a trailing dimension of 1
      # (`expand_dims(..., 2)`), but this column declares
      # `dimension=cell_size`. Confirm the mismatch is intended.
      tf.contrib.layers.real_valued_column('inputs', dimension=cell_size)
  ]
  config = tf.contrib.learn.RunConfig(tf_random_seed=77)
  sequence_classifier = dynamic_rnn_estimator.single_value_rnn_classifier(
      num_classes=2,
      num_units=cell_size,
      sequence_feature_columns=seq_columns,
      cell_type=cell_type,
      optimizer_type=optimizer_type,
      learning_rate=learning_rate,
      momentum=momentum,
      config=config,
      predict_probabilities=True)

  train_input_fn = get_majority_input_fn(batch_size, sequence_length, 1111)
  eval_input_fn = get_majority_input_fn(batch_size, sequence_length, 2222)

  sequence_classifier.fit(input_fn=train_input_fn, steps=train_steps)
  accuracy = sequence_classifier.evaluate(
      input_fn=eval_input_fn, steps=eval_steps)['accuracy']
  self.assertGreater(
      accuracy, accuracy_threshold,
      'Accuracy should be higher than {}; got {}'.format(
          accuracy_threshold, accuracy))

  # Testing `predict` when `predict_probabilities=True`.
  prediction_dict = sequence_classifier.predict(
      input_fn=eval_input_fn, as_iterable=False)
  predictions_key = dynamic_rnn_estimator.RNNKeys.PREDICTIONS_KEY
  probabilities_key = dynamic_rnn_estimator.RNNKeys.PROBABILITIES_KEY
  expected_keys = [
      predictions_key,
      probabilities_key,
      dynamic_rnn_estimator._get_state_name(0),
      dynamic_rnn_estimator._get_state_name(1),
  ]
  self.assertListEqual(sorted(prediction_dict.keys()), sorted(expected_keys))

  predictions = prediction_dict[predictions_key]
  probabilities = prediction_dict[probabilities_key]
  self.assertListEqual(list(predictions.shape), [batch_size])
  self.assertListEqual(list(probabilities.shape), [batch_size, 2])
def testLearnShiftByOne(self):
  """Tests learning a 'shift-by-one' sequence task.

  Each label sequence consists of the input sequence 'shifted' by one place,
  so the RNN must learn to 'remember' the previous input. A multi-value
  classifier must beat `accuracy_threshold` on held-out batches. Its
  `predict` output must then hold per-step classes, per-step probabilities
  and a single state piece.
  """
  batch_size = 16
  sequence_length = 32
  train_steps = 200
  eval_steps = 20
  cell_size = 4
  learning_rate = 0.3
  accuracy_threshold = 0.9

  def get_shift_input_fn(batch_size, sequence_length, seed=None):

    def input_fn():
      # Draw one extra step so labels and inputs can be offset windows of
      # the same random 0/1 sequence.
      bits = tf.random_uniform(
          [batch_size, sequence_length + 1], 0, 2, dtype=tf.int32, seed=seed)
      window = [batch_size, sequence_length]
      labels = tf.slice(bits, [0, 0], window)
      shifted = tf.slice(bits, [0, 1], window)
      inputs = tf.expand_dims(tf.to_float(shifted), 2)
      return {'inputs': inputs}, labels

    return input_fn

  seq_columns = [
      # NOTE(review): `inputs` is built with a trailing dimension of 1
      # (`expand_dims(..., 2)`), but this column declares
      # `dimension=cell_size`. Confirm the mismatch is intended.
      tf.contrib.layers.real_valued_column('inputs', dimension=cell_size)
  ]
  config = tf.contrib.learn.RunConfig(tf_random_seed=21212)
  sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_classifier(
      num_classes=2,
      num_units=cell_size,
      sequence_feature_columns=seq_columns,
      learning_rate=learning_rate,
      config=config,
      predict_probabilities=True)

  train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
  eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

  sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
  accuracy = sequence_estimator.evaluate(
      input_fn=eval_input_fn, steps=eval_steps)['accuracy']
  self.assertGreater(
      accuracy, accuracy_threshold,
      'Accuracy should be higher than {}; got {}'.format(
          accuracy_threshold, accuracy))

  # Testing `predict` when `predict_probabilities=True`.
  prediction_dict = sequence_estimator.predict(
      input_fn=eval_input_fn, as_iterable=False)
  predictions_key = dynamic_rnn_estimator.RNNKeys.PREDICTIONS_KEY
  probabilities_key = dynamic_rnn_estimator.RNNKeys.PROBABILITIES_KEY
  expected_keys = [
      predictions_key,
      probabilities_key,
      dynamic_rnn_estimator._get_state_name(0),
  ]
  self.assertListEqual(sorted(prediction_dict.keys()), sorted(expected_keys))

  predictions = prediction_dict[predictions_key]
  probabilities = prediction_dict[probabilities_key]
  self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
  self.assertListEqual(
      list(probabilities.shape), [batch_size, sequence_length, 2])
def testLearnMajority(self):
  """Test learning the 'majority' function.

  Each example is a random 0/1 sequence. Its label is 1 when the sum of its
  entries exceeds half the sequence length. A single-value
  `DynamicRnnEstimator` trained with Momentum must beat
  `accuracy_threshold` on held-out batches. Its `predict` output must then
  hold classes, probabilities and two state pieces.
  """
  batch_size = 16
  sequence_length = 7
  train_steps = 200
  eval_steps = 20
  cell_type = 'lstm'
  cell_size = 4
  optimizer_type = 'Momentum'
  learning_rate = 2.0
  momentum = 0.9
  accuracy_threshold = 0.9

  def get_majority_input_fn(batch_size, sequence_length, seed=None):
    # NOTE(review): this seeds whichever graph is the default when the
    # factory is called, not necessarily the graph the estimator builds.
    # Confirm it is still needed.
    random_seed.set_random_seed(seed)

    def input_fn():
      bits = random_ops.random_uniform(
          [batch_size, sequence_length], 0, 2, dtype=dtypes.int32, seed=seed)
      inputs = array_ops.expand_dims(math_ops.to_float(bits), 2)
      ones_per_sequence = math_ops.reduce_sum(inputs, reduction_indices=[1])
      is_majority = ones_per_sequence > (sequence_length / 2.0)
      labels = math_ops.to_int32(array_ops.squeeze(is_majority))
      return {'inputs': inputs}, labels

    return input_fn

  seq_columns = [
      # NOTE(review): `inputs` is built with a trailing dimension of 1
      # (`expand_dims(..., 2)`), but this column declares
      # `dimension=cell_size`. Confirm the mismatch is intended.
      feature_column.real_valued_column('inputs', dimension=cell_size)
  ]
  config = run_config.RunConfig(tf_random_seed=77)
  sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
      problem_type=constants.ProblemType.CLASSIFICATION,
      prediction_type=dynamic_rnn_estimator.PredictionType.SINGLE_VALUE,
      num_classes=2,
      num_units=cell_size,
      sequence_feature_columns=seq_columns,
      cell_type=cell_type,
      optimizer=optimizer_type,
      learning_rate=learning_rate,
      momentum=momentum,
      config=config,
      predict_probabilities=True)

  train_input_fn = get_majority_input_fn(batch_size, sequence_length, 1111)
  eval_input_fn = get_majority_input_fn(batch_size, sequence_length, 2222)

  sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
  accuracy = sequence_estimator.evaluate(
      input_fn=eval_input_fn, steps=eval_steps)['accuracy']
  self.assertGreater(
      accuracy, accuracy_threshold,
      'Accuracy should be higher than {}; got {}'.format(
          accuracy_threshold, accuracy))

  # Testing `predict` when `predict_probabilities=True`.
  prediction_dict = sequence_estimator.predict(
      input_fn=eval_input_fn, as_iterable=False)
  classes_key = prediction_key.PredictionKey.CLASSES
  probabilities_key = prediction_key.PredictionKey.PROBABILITIES
  expected_keys = [
      classes_key,
      probabilities_key,
      dynamic_rnn_estimator._get_state_name(0),
      dynamic_rnn_estimator._get_state_name(1),
  ]
  self.assertListEqual(sorted(prediction_dict.keys()), sorted(expected_keys))

  predictions = prediction_dict[classes_key]
  probabilities = prediction_dict[probabilities_key]
  self.assertListEqual(list(predictions.shape), [batch_size])
  self.assertListEqual(list(probabilities.shape), [batch_size, 2])
def testLearnShiftByOne(self):
  """Tests learning a 'shift-by-one' sequence task.

  Each label sequence consists of the input sequence 'shifted' by one place,
  so the RNN must learn to 'remember' the previous input. A multi-value
  `DynamicRnnEstimator` must beat `accuracy_threshold` on held-out batches.
  Its `predict` output must then hold per-step classes, per-step
  probabilities and a single state piece.
  """
  batch_size = 16
  sequence_length = 32
  train_steps = 200
  eval_steps = 20
  cell_size = 4
  learning_rate = 0.3
  accuracy_threshold = 0.9

  def get_shift_input_fn(batch_size, sequence_length, seed=None):

    def input_fn():
      # Draw one extra step so labels and inputs can be offset windows of
      # the same random 0/1 sequence.
      bits = random_ops.random_uniform(
          [batch_size, sequence_length + 1], 0, 2, dtype=dtypes.int32,
          seed=seed)
      window = [batch_size, sequence_length]
      labels = array_ops.slice(bits, [0, 0], window)
      shifted = array_ops.slice(bits, [0, 1], window)
      inputs = array_ops.expand_dims(math_ops.to_float(shifted), 2)
      return {'inputs': inputs}, labels

    return input_fn

  seq_columns = [
      # NOTE(review): `inputs` is built with a trailing dimension of 1
      # (`expand_dims(..., 2)`), but this column declares
      # `dimension=cell_size`. Confirm the mismatch is intended.
      feature_column.real_valued_column('inputs', dimension=cell_size)
  ]
  config = run_config.RunConfig(tf_random_seed=21212)
  sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
      problem_type=constants.ProblemType.CLASSIFICATION,
      prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
      num_classes=2,
      num_units=cell_size,
      sequence_feature_columns=seq_columns,
      learning_rate=learning_rate,
      config=config,
      predict_probabilities=True)

  train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
  eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

  sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
  accuracy = sequence_estimator.evaluate(
      input_fn=eval_input_fn, steps=eval_steps)['accuracy']
  self.assertGreater(
      accuracy, accuracy_threshold,
      'Accuracy should be higher than {}; got {}'.format(
          accuracy_threshold, accuracy))

  # Testing `predict` when `predict_probabilities=True`.
  prediction_dict = sequence_estimator.predict(
      input_fn=eval_input_fn, as_iterable=False)
  classes_key = prediction_key.PredictionKey.CLASSES
  probabilities_key = prediction_key.PredictionKey.PROBABILITIES
  expected_keys = [
      classes_key,
      probabilities_key,
      dynamic_rnn_estimator._get_state_name(0),
  ]
  self.assertListEqual(sorted(prediction_dict.keys()), sorted(expected_keys))

  predictions = prediction_dict[classes_key]
  probabilities = prediction_dict[probabilities_key]
  self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
  self.assertListEqual(
      list(probabilities.shape), [batch_size, sequence_length, 2])