Example #1
0
    def testGetOutputAlternatives(self):
        """Check `_get_output_alternatives` for each prediction/problem combo.

        For SINGLE_VALUE prediction the state entries must be stripped and
        the remaining predictions wrapped in a 'dynamic_rnn_output'
        alternative; for MULTIPLE_VALUE prediction no alternative is
        produced (None).
        """
        single_value = dynamic_rnn_estimator.PredictionType.SINGLE_VALUE
        multiple_value = dynamic_rnn_estimator.PredictionType.MULTIPLE_VALUE
        classes_key = prediction_key.PredictionKey.CLASSES
        probabilities_key = prediction_key.PredictionKey.PROBABILITIES
        scores_key = prediction_key.PredictionKey.SCORES
        state_name = dynamic_rnn_estimator._get_state_name

        classification_case = (
            single_value,
            constants.ProblemType.CLASSIFICATION,
            {classes_key: True,
             probabilities_key: True,
             state_name(0): True},
            {'dynamic_rnn_output': (constants.ProblemType.CLASSIFICATION,
                                    {classes_key: True,
                                     probabilities_key: True})})
        regression_case = (
            single_value,
            constants.ProblemType.LINEAR_REGRESSION,
            {scores_key: True,
             state_name(0): True,
             state_name(1): True},
            {'dynamic_rnn_output': (constants.ProblemType.LINEAR_REGRESSION,
                                    {scores_key: True})})
        multiple_value_case = (
            multiple_value,
            constants.ProblemType.CLASSIFICATION,
            {classes_key: True,
             probabilities_key: True,
             state_name(0): True},
            None)

        all_cases = (classification_case, regression_case, multiple_value_case)
        for pred_type, prob_type, pred_dict, expected_alternatives in all_cases:
            actual_alternatives = dynamic_rnn_estimator._get_output_alternatives(
                pred_type, prob_type, pred_dict)
            self.assertEqual(expected_alternatives, actual_alternatives)
  def testGetOutputAlternatives(self):
    """Exercise `_get_output_alternatives` over all prediction-type combos.

    SINGLE_VALUE problems yield a 'dynamic_rnn_output' alternative with the
    state entries removed; MULTIPLE_VALUE problems yield None.
    """
    classes = prediction_key.PredictionKey.CLASSES
    probabilities = prediction_key.PredictionKey.PROBABILITIES
    scores = prediction_key.PredictionKey.SCORES
    state_0 = dynamic_rnn_estimator._get_state_name(0)
    state_1 = dynamic_rnn_estimator._get_state_name(1)

    test_cases = [
        (rnn_common.PredictionType.SINGLE_VALUE,
         constants.ProblemType.CLASSIFICATION,
         {classes: True, probabilities: True, state_0: True},
         {'dynamic_rnn_output': (constants.ProblemType.CLASSIFICATION,
                                 {classes: True, probabilities: True})}),
        (rnn_common.PredictionType.SINGLE_VALUE,
         constants.ProblemType.LINEAR_REGRESSION,
         {scores: True, state_0: True, state_1: True},
         {'dynamic_rnn_output': (constants.ProblemType.LINEAR_REGRESSION,
                                 {scores: True})}),
        (rnn_common.PredictionType.MULTIPLE_VALUE,
         constants.ProblemType.CLASSIFICATION,
         {classes: True, probabilities: True, state_0: True},
         None),
    ]

    for pred_type, prob_type, pred_dict, expected_alternatives in test_cases:
      actual_alternatives = dynamic_rnn_estimator._get_output_alternatives(
          pred_type, prob_type, pred_dict)
      self.assertEqual(expected_alternatives, actual_alternatives)
  def testMultiRNNState(self):
    """Test that state flattening/reconstruction works for `MultiRNNCell`.

    Builds a three-layer LSTM stack of sizes [4, 8, 7], feeds one random
    initial-state tensor per flat state piece through the input dict, and
    checks that `predict` returns the final state under the same
    `_get_state_name(i)` keys with the expected per-piece shapes.
    """
    batch_size = 11
    sequence_length = 16
    train_steps = 5
    cell_sizes = [4, 8, 7]
    learning_rate = 0.1

    def get_shift_input_fn(batch_size, sequence_length, seed=None):

      def input_fn():
        # Draw sequence_length + 1 random bits so inputs and labels can be
        # offset from each other by one time step.
        random_sequence = random_ops.random_uniform(
            [batch_size, sequence_length + 1],
            0,
            2,
            dtype=dtypes.int32,
            seed=seed)
        # Labels are positions [0, sequence_length); inputs are positions
        # [1, sequence_length] as floats with a trailing feature dim.
        labels = array_ops.slice(random_sequence, [0, 0],
                                 [batch_size, sequence_length])
        inputs = array_ops.expand_dims(
            math_ops.cast(
                array_ops.slice(random_sequence, [0, 1],
                                [batch_size, sequence_length]),
                dtypes.float32), 2)
        # One random initial-state tensor per flat state piece. With LSTM
        # layers of sizes [4, 8, 7], each layer contributes two pieces
        # (presumably c and h), hence [4, 4, 8, 8, 7, 7].
        input_dict = {
            dynamic_rnn_estimator._get_state_name(i): random_ops.random_uniform(
                [batch_size, cell_size], seed=((i + 1) * seed))
            for i, cell_size in enumerate([4, 4, 8, 8, 7, 7])
        }
        input_dict['inputs'] = inputs
        return input_dict, labels

      return input_fn

    seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
    config = run_config.RunConfig(tf_random_seed=21212)
    cell_type = 'lstm'
    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.CLASSIFICATION,
        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
        num_classes=2,
        num_units=cell_sizes,
        sequence_feature_columns=seq_columns,
        cell_type=cell_type,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True)

    train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    # The prediction dict must expose every flat state piece with shape
    # [batch_size, piece_size].
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    for i, state_size in enumerate([4, 4, 8, 8, 7, 7]):
      state_piece = prediction_dict[dynamic_rnn_estimator._get_state_name(i)]
      self.assertListEqual(list(state_piece.shape), [batch_size, state_size])
  def testMultiRNNState(self):
    """Test that state flattening/reconstruction works for `MultiRNNCell`.

    Same scenario as the sibling variant: a three-layer LSTM stack of sizes
    [4, 8, 7] whose flattened state pieces are fed in through the input dict
    and must come back out of `predict` under the same
    `_get_state_name(i)` keys with the expected shapes.
    """
    batch_size = 11
    sequence_length = 16
    train_steps = 5
    cell_sizes = [4, 8, 7]
    learning_rate = 0.1

    def get_shift_input_fn(batch_size, sequence_length, seed=None):

      def input_fn():
        # sequence_length + 1 random bits so inputs/labels can be offset by
        # one time step.
        random_sequence = random_ops.random_uniform(
            [batch_size, sequence_length + 1],
            0,
            2,
            dtype=dtypes.int32,
            seed=seed)
        labels = array_ops.slice(random_sequence, [0, 0],
                                 [batch_size, sequence_length])
        # Use math_ops.cast(..., dtypes.float32) instead of the deprecated
        # math_ops.to_float, for consistency with the sibling
        # testMultiRNNState variant; the result is identical.
        inputs = array_ops.expand_dims(
            math_ops.cast(
                array_ops.slice(random_sequence, [0, 1],
                                [batch_size, sequence_length]),
                dtypes.float32), 2)
        # One random initial-state tensor per flat state piece
        # ([4, 4, 8, 8, 7, 7]: two pieces per LSTM layer of sizes [4, 8, 7]).
        input_dict = {
            dynamic_rnn_estimator._get_state_name(i): random_ops.random_uniform(
                [batch_size, cell_size], seed=((i + 1) * seed))
            for i, cell_size in enumerate([4, 4, 8, 8, 7, 7])
        }
        input_dict['inputs'] = inputs
        return input_dict, labels

      return input_fn

    seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
    config = run_config.RunConfig(tf_random_seed=21212)
    cell_type = 'lstm'
    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.CLASSIFICATION,
        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
        num_classes=2,
        num_units=cell_sizes,
        sequence_feature_columns=seq_columns,
        cell_type=cell_type,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True)

    train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    # Every flat state piece must appear in the prediction dict with shape
    # [batch_size, piece_size].
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    for i, state_size in enumerate([4, 4, 8, 8, 7, 7]):
      state_piece = prediction_dict[dynamic_rnn_estimator._get_state_name(i)]
      self.assertListEqual(list(state_piece.shape), [batch_size, state_size])
 def input_fn():
   """Build one batch of shift-by-one data plus random initial RNN state.

   NOTE(review): this fragment reads `batch_size`, `sequence_length` and
   `seed` from an enclosing scope that is not visible here — presumably a
   `get_shift_input_fn` factory like the other copies in this file; confirm
   against the surrounding code.
   """
   # sequence_length + 1 bits so labels (positions [0, L)) and inputs
   # (positions [1, L]) are offset by one step.
   random_sequence = tf.random_uniform(
       [batch_size, sequence_length + 1], 0, 2, dtype=tf.int32, seed=seed)
   labels = tf.slice(
       random_sequence, [0, 0], [batch_size, sequence_length])
   inputs = tf.expand_dims(
       tf.to_float(tf.slice(
           random_sequence, [0, 1], [batch_size, sequence_length])), 2)
   # One random tensor per flat state piece; [4, 4, 8, 8, 7, 7] presumably
   # pairs (c, h) for LSTM layers of sizes [4, 8, 7] — TODO confirm.
   input_dict = {
       dynamic_rnn_estimator._get_state_name(i): tf.random_uniform(
           [batch_size, cell_size], seed=((i + 1) * seed))
       for i, cell_size in enumerate([4, 4, 8, 8, 7, 7])}
   input_dict['inputs'] = inputs
   return input_dict, labels
  def testMultiRNNState(self):
    """Test that state flattening/reconstruction works for `MultiRNNCell`.

    Older tf.* API variant: the `MultiRNNCell` is constructed explicitly and
    handed to `multi_value_rnn_classifier`, and flattened state pieces are
    round-tripped through the input dict and `predict` output.
    """
    batch_size = 11
    sequence_length = 16
    train_steps = 5
    cell_sizes = [4, 8, 7]
    learning_rate = 0.1

    def get_shift_input_fn(batch_size, sequence_length, seed=None):
      def input_fn():
        # sequence_length + 1 bits so labels/inputs can be offset by one.
        random_sequence = tf.random_uniform(
            [batch_size, sequence_length + 1], 0, 2, dtype=tf.int32, seed=seed)
        labels = tf.slice(
            random_sequence, [0, 0], [batch_size, sequence_length])
        inputs = tf.expand_dims(
            tf.to_float(tf.slice(
                random_sequence, [0, 1], [batch_size, sequence_length])), 2)
        # One random initial-state tensor per flat state piece; two pieces
        # per LSTM layer of sizes [4, 8, 7], hence [4, 4, 8, 8, 7, 7].
        input_dict = {
            dynamic_rnn_estimator._get_state_name(i): tf.random_uniform(
                [batch_size, cell_size], seed=((i + 1) * seed))
            for i, cell_size in enumerate([4, 4, 8, 8, 7, 7])}
        input_dict['inputs'] = inputs
        return input_dict, labels
      return input_fn

    seq_columns = [tf.contrib.layers.real_valued_column(
        'inputs', dimension=1)]
    config = tf.contrib.learn.RunConfig(tf_random_seed=21212)
    cell = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.BasicLSTMCell(size) for size in cell_sizes])
    # NOTE(review): a constructed cell *instance* is passed as `cell_type`
    # (with num_units=None) — presumably this API variant accepts either a
    # string or a cell object; confirm against dynamic_rnn_estimator.
    sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_classifier(
        num_classes=2,
        num_units=None,
        sequence_feature_columns=seq_columns,
        cell_type=cell,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True)

    train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    # Every flat state piece must come back with shape
    # [batch_size, piece_size].
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    for i, state_size in enumerate([4, 4, 8, 8, 7, 7]):
      state_piece = prediction_dict[dynamic_rnn_estimator._get_state_name(i)]
      self.assertListEqual(list(state_piece.shape), [batch_size, state_size])
Example #7
0
    def testStateTupleDictConversion(self):
        """Round-trip a `MultiRNNCell` state through dict and tuple forms.

        A MultiRNNCell of LSTMCells is both a common choice and an
        interesting test case, because it has two levels of nesting, with an
        inner class (`LSTMStateTuple`) that is not a plain tuple.
        """
        cell_sizes = [5, 3, 7]
        cell = core_rnn_cell_impl.MultiRNNCell(
            [core_rnn_cell_impl.LSTMCell(n) for n in cell_sizes])

        # Each LSTM layer contributes two flat state pieces, hence the
        # doubled size list [5, 5, 3, 3, 7, 7].
        state_dict = {}
        for idx, piece_size in enumerate([5, 5, 3, 3, 7, 7]):
            name = dynamic_rnn_estimator._get_state_name(idx)
            state_dict[name] = array_ops.expand_dims(
                math_ops.range(piece_size), 0)

        def _ramp(n):
            # A [1, n] array of 0..n-1, matching math_ops.range above.
            return np.reshape(np.arange(n), [1, -1])

        expected_state = tuple(
            core_rnn_cell_impl.LSTMStateTuple(_ramp(n), _ramp(n))
            for n in cell_sizes)

        actual_state = dynamic_rnn_estimator.dict_to_state_tuple(
            state_dict, cell)
        flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(
            actual_state)

        with self.test_session() as sess:
            state_dict_val, actual_state_val, flattened_state_val = sess.run(
                [state_dict, actual_state, flattened_state])

        def _recursive_assert_equal(x, y):
            self.assertEqual(type(x), type(y))
            if isinstance(x, (list, tuple)):
                self.assertEqual(len(x), len(y))
                for i, _ in enumerate(x):
                    _recursive_assert_equal(x[i], y[i])
            elif isinstance(x, np.ndarray):
                np.testing.assert_array_equal(x, y)
            else:
                self.fail('Unexpected type: {}'.format(type(x)))

        # Flattening the reconstructed tuple must reproduce the input dict.
        for k in state_dict_val.keys():
            np.testing.assert_array_almost_equal(
                state_dict_val[k],
                flattened_state_val[k],
                err_msg='Wrong value for state component {}.'.format(k))
        _recursive_assert_equal(expected_state, actual_state_val)
  def testStateTupleDictConversion(self):
    """Verify `state_tuple_to_dict` and `dict_to_state_tuple` are inverses.

    Uses a `MultiRNNCell` of `LSTMCell`s: two levels of nesting with an
    inner `LSTMStateTuple` that is not a plain tuple, which is the
    interesting structural case.
    """
    cell = core_rnn_cell_impl.MultiRNNCell(
        [core_rnn_cell_impl.LSTMCell(size) for size in [5, 3, 7]])

    # Two flat pieces per LSTM layer of sizes [5, 3, 7].
    state_dict = {}
    for index, piece_size in enumerate([5, 5, 3, 3, 7, 7]):
      key = dynamic_rnn_estimator._get_state_name(index)
      state_dict[key] = array_ops.expand_dims(math_ops.range(piece_size), 0)

    def _expected_piece(size):
      # [1, size] array of 0..size-1, matching math_ops.range above.
      return np.reshape(np.arange(size), [1, -1])

    expected_state = (
        core_rnn_cell_impl.LSTMStateTuple(_expected_piece(5),
                                          _expected_piece(5)),
        core_rnn_cell_impl.LSTMStateTuple(_expected_piece(3),
                                          _expected_piece(3)),
        core_rnn_cell_impl.LSTMStateTuple(_expected_piece(7),
                                          _expected_piece(7)))

    actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell)
    flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(actual_state)

    with self.test_session() as sess:
      state_dict_val, actual_state_val, flattened_state_val = sess.run(
          [state_dict, actual_state, flattened_state])

    def _recursive_assert_equal(left, right):
      self.assertEqual(type(left), type(right))
      if isinstance(left, (list, tuple)):
        self.assertEqual(len(left), len(right))
        for left_item, right_item in zip(left, right):
          _recursive_assert_equal(left_item, right_item)
      elif isinstance(left, np.ndarray):
        np.testing.assert_array_equal(left, right)
      else:
        self.fail('Unexpected type: {}'.format(type(left)))

    # Flattening the reconstructed tuple must reproduce the original dict.
    for key in state_dict_val.keys():
      np.testing.assert_array_almost_equal(
          state_dict_val[key],
          flattened_state_val[key],
          err_msg='Wrong value for state component {}.'.format(key))
    _recursive_assert_equal(expected_state, actual_state_val)
  def testLearnMajority(self):
    """Test learning the 'majority' function.

    Each label is 1 iff more than half of the `sequence_length` random input
    bits are 1, so the classifier must aggregate over the whole sequence.
    """
    batch_size = 16
    sequence_length = 7
    train_steps = 200
    eval_steps = 20
    cell_type = 'lstm'
    cell_size = 4
    optimizer_type = 'Momentum'
    learning_rate = 2.0
    momentum = 0.9
    accuracy_threshold = 0.9

    def get_majority_input_fn(batch_size, sequence_length, seed=None):
      # Graph-level seed plus per-op seed makes the data deterministic.
      random_seed.set_random_seed(seed)

      def input_fn():
        random_sequence = random_ops.random_uniform(
            [batch_size, sequence_length], 0, 2, dtype=dtypes.int32, seed=seed)
        inputs = array_ops.expand_dims(math_ops.to_float(random_sequence), 2)
        # Label is 1 iff the bit-sum exceeds half the sequence length.
        labels = math_ops.to_int32(
            array_ops.squeeze(
                math_ops.reduce_sum(
                    inputs, reduction_indices=[1]) > (sequence_length / 2.0)))
        return {'inputs': inputs}, labels

      return input_fn

    seq_columns = [
        feature_column.real_valued_column(
            'inputs', dimension=cell_size)
    ]
    config = run_config.RunConfig(tf_random_seed=77)
    sequence_classifier = dynamic_rnn_estimator.single_value_rnn_classifier(
        num_classes=2,
        num_units=cell_size,
        sequence_feature_columns=seq_columns,
        cell_type=cell_type,
        optimizer_type=optimizer_type,
        learning_rate=learning_rate,
        momentum=momentum,
        config=config,
        predict_probabilities=True)

    train_input_fn = get_majority_input_fn(batch_size, sequence_length, 1111)
    eval_input_fn = get_majority_input_fn(batch_size, sequence_length, 2222)

    sequence_classifier.fit(input_fn=train_input_fn, steps=train_steps)
    evaluation = sequence_classifier.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))

    # Testing `predict` when `predict_probabilities=True`.
    # Two state keys are expected — presumably the LSTM cell's two state
    # pieces; confirm against `_get_state_name` usage in the estimator.
    prediction_dict = sequence_classifier.predict(
        input_fn=eval_input_fn, as_iterable=False)
    self.assertListEqual(
        sorted(list(prediction_dict.keys())),
        sorted([
            dynamic_rnn_estimator.RNNKeys.PREDICTIONS_KEY,
            dynamic_rnn_estimator.RNNKeys.PROBABILITIES_KEY,
            dynamic_rnn_estimator._get_state_name(0),
            dynamic_rnn_estimator._get_state_name(1)
        ]))
    predictions = prediction_dict[dynamic_rnn_estimator.RNNKeys.PREDICTIONS_KEY]
    probabilities = prediction_dict[
        dynamic_rnn_estimator.RNNKeys.PROBABILITIES_KEY]
    # Single-value prediction: one class and one 2-way distribution per
    # example.
    self.assertListEqual(list(predictions.shape), [batch_size])
    self.assertListEqual(list(probabilities.shape), [batch_size, 2])
  def testLearnShiftByOne(self):
    """Tests learning a 'shift-by-one' example.

    Each label sequence consists of the input sequence 'shifted' by one place.
    The RNN must learn to 'remember' the previous input.
    """
    batch_size = 16
    sequence_length = 32
    train_steps = 200
    eval_steps = 20
    cell_size = 4
    learning_rate = 0.3
    accuracy_threshold = 0.9

    def get_shift_input_fn(batch_size, sequence_length, seed=None):

      def input_fn():
        # sequence_length + 1 random bits so labels (positions [0, L)) and
        # inputs (positions [1, L]) are offset by one time step.
        random_sequence = random_ops.random_uniform(
            [batch_size, sequence_length + 1],
            0,
            2,
            dtype=dtypes.int32,
            seed=seed)
        labels = array_ops.slice(random_sequence, [0, 0],
                                 [batch_size, sequence_length])
        inputs = array_ops.expand_dims(
            math_ops.to_float(
                array_ops.slice(random_sequence, [0, 1],
                                [batch_size, sequence_length])), 2)
        return {'inputs': inputs}, labels

      return input_fn

    seq_columns = [
        feature_column.real_valued_column(
            'inputs', dimension=cell_size)
    ]
    config = run_config.RunConfig(tf_random_seed=21212)
    sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_classifier(
        num_classes=2,
        num_units=cell_size,
        sequence_feature_columns=seq_columns,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True)

    train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    evaluation = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    self.assertListEqual(
        sorted(list(prediction_dict.keys())),
        sorted([
            dynamic_rnn_estimator.RNNKeys.PREDICTIONS_KEY,
            dynamic_rnn_estimator.RNNKeys.PROBABILITIES_KEY,
            dynamic_rnn_estimator._get_state_name(0)
        ]))
    predictions = prediction_dict[dynamic_rnn_estimator.RNNKeys.PREDICTIONS_KEY]
    probabilities = prediction_dict[
        dynamic_rnn_estimator.RNNKeys.PROBABILITIES_KEY]
    # Multi-value prediction: one class and one 2-way distribution per time
    # step.
    self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
    self.assertListEqual(
        list(probabilities.shape), [batch_size, sequence_length, 2])
  def testLearnMajority(self):
    """Test learning the 'majority' function.

    Older tf.* API variant: each label is 1 iff more than half of the
    `sequence_length` random input bits are 1.
    """
    batch_size = 16
    sequence_length = 7
    train_steps = 200
    eval_steps = 20
    cell_type = 'lstm'
    cell_size = 4
    optimizer_type = 'Momentum'
    learning_rate = 2.0
    momentum = 0.9
    accuracy_threshold = 0.9

    def get_majority_input_fn(batch_size, sequence_length, seed=None):
      # Graph-level seed plus per-op seed makes the data deterministic.
      tf.set_random_seed(seed)
      def input_fn():
        random_sequence = tf.random_uniform(
            [batch_size, sequence_length], 0, 2, dtype=tf.int32, seed=seed)
        inputs = tf.expand_dims(tf.to_float(random_sequence), 2)
        # Label is 1 iff the bit-sum exceeds half the sequence length.
        labels = tf.to_int32(
            tf.squeeze(
                tf.reduce_sum(
                    inputs, reduction_indices=[1]) > (sequence_length / 2.0)))
        return {'inputs': inputs}, labels
      return input_fn

    seq_columns = [tf.contrib.layers.real_valued_column(
        'inputs', dimension=cell_size)]
    config = tf.contrib.learn.RunConfig(tf_random_seed=77)
    sequence_classifier = dynamic_rnn_estimator.single_value_rnn_classifier(
        num_classes=2,
        num_units=cell_size,
        sequence_feature_columns=seq_columns,
        cell_type=cell_type,
        optimizer_type=optimizer_type,
        learning_rate=learning_rate,
        momentum=momentum,
        config=config,
        predict_probabilities=True)

    train_input_fn = get_majority_input_fn(batch_size, sequence_length, 1111)
    eval_input_fn = get_majority_input_fn(batch_size, sequence_length, 2222)

    sequence_classifier.fit(input_fn=train_input_fn, steps=train_steps)
    evaluation = sequence_classifier.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))

    # Testing `predict` when `predict_probabilities=True`.
    # Two state keys are expected — presumably the LSTM cell's two state
    # pieces; confirm against `_get_state_name` usage in the estimator.
    prediction_dict = sequence_classifier.predict(
        input_fn=eval_input_fn, as_iterable=False)
    self.assertListEqual(
        sorted(list(prediction_dict.keys())),
        sorted([dynamic_rnn_estimator.RNNKeys.PREDICTIONS_KEY,
                dynamic_rnn_estimator.RNNKeys.PROBABILITIES_KEY,
                dynamic_rnn_estimator._get_state_name(0),
                dynamic_rnn_estimator._get_state_name(1)]))
    predictions = prediction_dict[dynamic_rnn_estimator.RNNKeys.PREDICTIONS_KEY]
    probabilities = prediction_dict[
        dynamic_rnn_estimator.RNNKeys.PROBABILITIES_KEY]
    # Single-value prediction: one class and one 2-way distribution per
    # example.
    self.assertListEqual(list(predictions.shape), [batch_size])
    self.assertListEqual(list(probabilities.shape), [batch_size, 2])
  def testLearnShiftByOne(self):
    """Tests learning a 'shift-by-one' example.

    Each label sequence consists of the input sequence 'shifted' by one place.
    The RNN must learn to 'remember' the previous input.
    """
    batch_size = 16
    sequence_length = 32
    train_steps = 200
    eval_steps = 20
    cell_size = 4
    learning_rate = 0.3
    accuracy_threshold = 0.9

    def get_shift_input_fn(batch_size, sequence_length, seed=None):
      def input_fn():
        # sequence_length + 1 random bits so labels (positions [0, L)) and
        # inputs (positions [1, L]) are offset by one time step.
        random_sequence = tf.random_uniform(
            [batch_size, sequence_length + 1], 0, 2, dtype=tf.int32, seed=seed)
        labels = tf.slice(
            random_sequence, [0, 0], [batch_size, sequence_length])
        inputs = tf.expand_dims(
            tf.to_float(tf.slice(
                random_sequence, [0, 1], [batch_size, sequence_length])), 2)
        return {'inputs': inputs}, labels
      return input_fn

    seq_columns = [tf.contrib.layers.real_valued_column(
        'inputs', dimension=cell_size)]
    config = tf.contrib.learn.RunConfig(tf_random_seed=21212)
    sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_classifier(
        num_classes=2,
        num_units=cell_size,
        sequence_feature_columns=seq_columns,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True)

    train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    evaluation = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    self.assertListEqual(
        sorted(list(prediction_dict.keys())),
        sorted([dynamic_rnn_estimator.RNNKeys.PREDICTIONS_KEY,
                dynamic_rnn_estimator.RNNKeys.PROBABILITIES_KEY,
                dynamic_rnn_estimator._get_state_name(0)]))
    predictions = prediction_dict[dynamic_rnn_estimator.RNNKeys.PREDICTIONS_KEY]
    probabilities = prediction_dict[
        dynamic_rnn_estimator.RNNKeys.PROBABILITIES_KEY]
    # Multi-value prediction: one class and one 2-way distribution per time
    # step.
    self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
    self.assertListEqual(
        list(probabilities.shape), [batch_size, sequence_length, 2])
Example #13
0
    def testLearnMajority(self):
        """Test learning the 'majority' function.

        `DynamicRnnEstimator` variant: each label is 1 iff more than half of
        the `sequence_length` random input bits are 1, and predictions are
        exposed under `prediction_key.PredictionKey` keys.
        """
        batch_size = 16
        sequence_length = 7
        train_steps = 200
        eval_steps = 20
        cell_type = 'lstm'
        cell_size = 4
        optimizer_type = 'Momentum'
        learning_rate = 2.0
        momentum = 0.9
        accuracy_threshold = 0.9

        def get_majority_input_fn(batch_size, sequence_length, seed=None):
            # Graph-level seed plus per-op seed makes the data deterministic.
            random_seed.set_random_seed(seed)

            def input_fn():
                random_sequence = random_ops.random_uniform(
                    [batch_size, sequence_length],
                    0,
                    2,
                    dtype=dtypes.int32,
                    seed=seed)
                inputs = array_ops.expand_dims(
                    math_ops.to_float(random_sequence), 2)
                # Label is 1 iff the bit-sum exceeds half the sequence
                # length.
                labels = math_ops.to_int32(
                    array_ops.squeeze(
                        math_ops.reduce_sum(inputs, reduction_indices=[1]) > (
                            sequence_length / 2.0)))
                return {'inputs': inputs}, labels

            return input_fn

        seq_columns = [
            feature_column.real_valued_column('inputs', dimension=cell_size)
        ]
        config = run_config.RunConfig(tf_random_seed=77)
        sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
            problem_type=constants.ProblemType.CLASSIFICATION,
            prediction_type=dynamic_rnn_estimator.PredictionType.SINGLE_VALUE,
            num_classes=2,
            num_units=cell_size,
            sequence_feature_columns=seq_columns,
            cell_type=cell_type,
            optimizer=optimizer_type,
            learning_rate=learning_rate,
            momentum=momentum,
            config=config,
            predict_probabilities=True)

        train_input_fn = get_majority_input_fn(batch_size, sequence_length,
                                               1111)
        eval_input_fn = get_majority_input_fn(batch_size, sequence_length,
                                              2222)

        sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
        evaluation = sequence_estimator.evaluate(input_fn=eval_input_fn,
                                                 steps=eval_steps)
        accuracy = evaluation['accuracy']
        self.assertGreater(
            accuracy, accuracy_threshold,
            'Accuracy should be higher than {}; got {}'.format(
                accuracy_threshold, accuracy))

        # Testing `predict` when `predict_probabilities=True`.
        # Two state keys are expected — presumably the LSTM cell's two state
        # pieces; confirm against `_get_state_name` usage in the estimator.
        prediction_dict = sequence_estimator.predict(input_fn=eval_input_fn,
                                                     as_iterable=False)
        self.assertListEqual(
            sorted(list(prediction_dict.keys())),
            sorted([
                prediction_key.PredictionKey.CLASSES,
                prediction_key.PredictionKey.PROBABILITIES,
                dynamic_rnn_estimator._get_state_name(0),
                dynamic_rnn_estimator._get_state_name(1)
            ]))
        predictions = prediction_dict[prediction_key.PredictionKey.CLASSES]
        probabilities = prediction_dict[
            prediction_key.PredictionKey.PROBABILITIES]
        # Single-value prediction: one class and one 2-way distribution per
        # example.
        self.assertListEqual(list(predictions.shape), [batch_size])
        self.assertListEqual(list(probabilities.shape), [batch_size, 2])
  def testLearnShiftByOne(self):
    """Tests learning a 'shift-by-one' example.

    Each label sequence consists of the input sequence 'shifted' by one place.
    The RNN must learn to 'remember' the previous input. This
    `DynamicRnnEstimator` variant exposes predictions under
    `prediction_key.PredictionKey` keys.
    """
    batch_size = 16
    sequence_length = 32
    train_steps = 200
    eval_steps = 20
    cell_size = 4
    learning_rate = 0.3
    accuracy_threshold = 0.9

    def get_shift_input_fn(batch_size, sequence_length, seed=None):

      def input_fn():
        # sequence_length + 1 random bits so labels (positions [0, L)) and
        # inputs (positions [1, L]) are offset by one time step.
        random_sequence = random_ops.random_uniform(
            [batch_size, sequence_length + 1],
            0,
            2,
            dtype=dtypes.int32,
            seed=seed)
        labels = array_ops.slice(random_sequence, [0, 0],
                                 [batch_size, sequence_length])
        inputs = array_ops.expand_dims(
            math_ops.to_float(
                array_ops.slice(random_sequence, [0, 1],
                                [batch_size, sequence_length])), 2)
        return {'inputs': inputs}, labels

      return input_fn

    seq_columns = [
        feature_column.real_valued_column(
            'inputs', dimension=cell_size)
    ]
    config = run_config.RunConfig(tf_random_seed=21212)
    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.CLASSIFICATION,
        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
        num_classes=2,
        num_units=cell_size,
        sequence_feature_columns=seq_columns,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True)

    train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    evaluation = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    self.assertListEqual(
        sorted(list(prediction_dict.keys())),
        sorted([
            prediction_key.PredictionKey.CLASSES,
            prediction_key.PredictionKey.PROBABILITIES,
            dynamic_rnn_estimator._get_state_name(0)
        ]))
    predictions = prediction_dict[prediction_key.PredictionKey.CLASSES]
    probabilities = prediction_dict[
        prediction_key.PredictionKey.PROBABILITIES]
    # Multi-value prediction: one class and one 2-way distribution per time
    # step.
    self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
    self.assertListEqual(
        list(probabilities.shape), [batch_size, sequence_length, 2])