Esempio n. 1
0
  def testMultiRNNState(self):
    """Test that state flattening/reconstruction works for `MultiRNNCell`.

    Trains a three-layer LSTM estimator on inputs that also carry an explicit
    value for every flattened state piece, then checks that `predict` returns
    each state piece with shape `[batch_size, state_size]`.
    """
    batch_size = 11
    sequence_length = 16
    train_steps = 5
    cell_sizes = [4, 8, 7]
    learning_rate = 0.1
    # Each layer size appears twice in the flattened state list, i.e.
    # [4, 4, 8, 8, 7, 7] (presumably the `c` and `h` parts of every LSTM
    # layer). Deriving it keeps the input_fn and the assertions in sync with
    # `cell_sizes`.
    state_sizes = [size for size in cell_sizes for _ in range(2)]

    def get_shift_input_fn(batch_size, sequence_length, seed=None):
      """Returns an input_fn yielding binary sequences plus state features.

      Args:
        batch_size: Number of sequences per batch.
        sequence_length: Number of time steps in inputs and labels.
        seed: Optional op-level seed for the random ops. `None` is passed
          through unchanged to every random op.
      """

      def input_fn():
        random_sequence = random_ops.random_uniform(
            [batch_size, sequence_length + 1],
            0,
            2,
            dtype=dtypes.int32,
            seed=seed)
        # Labels are the first `sequence_length` steps; inputs are the same
        # sequence starting one step later.
        labels = array_ops.slice(random_sequence, [0, 0],
                                 [batch_size, sequence_length])
        inputs = array_ops.expand_dims(
            math_ops.cast(
                array_ops.slice(random_sequence, [0, 1],
                                [batch_size, sequence_length]),
                dtypes.float32), 2)
        input_dict = {
            dynamic_rnn_estimator._get_state_name(i): random_ops.random_uniform(
                [batch_size, state_size],
                # Distinct seed per state piece; `None` must stay `None`
                # because `(i + 1) * None` would raise a TypeError.
                seed=(None if seed is None else (i + 1) * seed))
            for i, state_size in enumerate(state_sizes)
        }
        input_dict['inputs'] = inputs
        return input_dict, labels

      return input_fn

    seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
    config = run_config.RunConfig(tf_random_seed=21212)
    cell_type = 'lstm'
    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.CLASSIFICATION,
        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
        num_classes=2,
        num_units=cell_sizes,
        sequence_feature_columns=seq_columns,
        cell_type=cell_type,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True)

    train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    # Every flattened state piece must come back with its expected shape.
    for i, state_size in enumerate(state_sizes):
      state_piece = prediction_dict[dynamic_rnn_estimator._get_state_name(i)]
      self.assertListEqual(list(state_piece.shape), [batch_size, state_size])
Esempio n. 2
0
  def testLearnSineFunction(self):
    """Checks that the estimator can fit sine curves as a regression task.

    Each example is a sampled sine curve with a random start value. Inputs
    are the first `sequence_length` samples and labels are the same curve
    shifted one sample ahead. After training, the evaluation loss must be
    below a fixed threshold.
    """
    batch_size = 8
    sequence_length = 64
    train_steps = 200
    eval_steps = 20
    cell_size = [4]
    learning_rate = 0.1
    loss_threshold = 0.02
    step_size = np.pi / 32

    def get_sin_input_fn(batch_size, sequence_length, increment, seed=None):
      """Returns an input_fn producing sine-curve inputs and shifted labels."""

      def _sin_fn(x):
        # `x` is the per-example element of the `(starts,)` tuple handed to
        # `map_fn`; its single entry is reshaped to a scalar start value.
        start = array_ops.reshape(x[0], [])
        stop = (sequence_length - 1) * increment
        num_samples = sequence_length + 1
        return math_ops.sin(math_ops.linspace(start, stop, num_samples))

      def input_fn():
        starts = random_ops.random_uniform(
            [batch_size], maxval=(2 * np.pi), seed=seed)
        sin_curves = map_fn.map_fn(_sin_fn, (starts,), dtype=dtypes.float32)
        # Samples [0, sequence_length) feed the model ...
        leading = array_ops.slice(
            sin_curves, [0, 0], [batch_size, sequence_length])
        inputs = array_ops.expand_dims(leading, 2)
        # ... and samples [1, sequence_length] are the targets.
        labels = array_ops.slice(
            sin_curves, [0, 1], [batch_size, sequence_length])
        features = {'inputs': inputs}
        return features, labels

      return input_fn

    inputs_column = feature_column.real_valued_column(
        'inputs', dimension=cell_size[0])
    seq_columns = [inputs_column]
    config = run_config.RunConfig(tf_random_seed=1234)
    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.LINEAR_REGRESSION,
        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
        num_units=cell_size,
        sequence_feature_columns=seq_columns,
        learning_rate=learning_rate,
        dropout_keep_probabilities=[0.9, 0.9],
        config=config)

    train_input_fn = get_sin_input_fn(
        batch_size, sequence_length, step_size, seed=1234)
    eval_input_fn = get_sin_input_fn(
        batch_size, sequence_length, step_size, seed=4321)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
    metrics = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    loss = metrics['loss']
    failure_message = 'Loss should be less than {}; got {}'.format(
        loss_threshold, loss)
    self.assertLess(loss, loss_threshold, failure_message)
Esempio n. 3
0
 def estimator_fn():
   """Builds a two-class, multi-value `DynamicRnnEstimator`.

   Reads the cell size and feature columns from `self` and the output
   directory from `model_dir`, both taken from the enclosing scope.
   """
   estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
       problem_type=constants.ProblemType.CLASSIFICATION,
       prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
       num_classes=2,
       num_units=self.NUM_RNN_CELL_UNITS,
       sequence_feature_columns=self.sequence_feature_columns,
       context_feature_columns=self.context_feature_columns,
       predict_probabilities=True,
       model_dir=model_dir)
   return estimator
Esempio n. 4
0
  def DISABLED_testLearnMajority(self):
    """Checks that the estimator can learn the 'majority' function.

    Inputs are random binary sequences; the label is 1 when more than half of
    the entries are 1. After training, evaluation accuracy must exceed a
    threshold, and `predict` must return classes, probabilities and two state
    pieces with the expected shapes.
    """
    batch_size = 16
    sequence_length = 7
    train_steps = 500
    eval_steps = 20
    cell_type = 'lstm'
    cell_size = 4
    optimizer_type = 'Momentum'
    learning_rate = 2.0
    momentum = 0.9
    accuracy_threshold = 0.6

    def get_majority_input_fn(batch_size, sequence_length, seed=None):
      """Returns an input_fn yielding binary sequences and majority labels."""
      # The graph-level seed is set when the input_fn is created, not when it
      # is invoked.
      random_seed.set_random_seed(seed)

      def input_fn():
        bits = random_ops.random_uniform(
            [batch_size, sequence_length], 0, 2, dtype=dtypes.int32, seed=seed)
        inputs = array_ops.expand_dims(math_ops.cast(bits, dtypes.float32), 2)
        # Count the ones along the time axis and compare with half the length.
        ones_count = math_ops.reduce_sum(inputs, axis=[1])
        is_majority = ones_count > (sequence_length / 2.0)
        labels = math_ops.cast(array_ops.squeeze(is_majority), dtypes.int32)
        features = {'inputs': inputs}
        return features, labels

      return input_fn

    inputs_column = feature_column.real_valued_column(
        'inputs', dimension=cell_size)
    seq_columns = [inputs_column]
    config = run_config.RunConfig(tf_random_seed=77)
    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.CLASSIFICATION,
        prediction_type=rnn_common.PredictionType.SINGLE_VALUE,
        num_classes=2,
        num_units=cell_size,
        sequence_feature_columns=seq_columns,
        cell_type=cell_type,
        optimizer=optimizer_type,
        learning_rate=learning_rate,
        momentum=momentum,
        config=config,
        predict_probabilities=True)

    train_input_fn = get_majority_input_fn(batch_size, sequence_length, 1111)
    eval_input_fn = get_majority_input_fn(batch_size, sequence_length, 2222)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
    metrics = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = metrics['accuracy']
    failure_message = 'Accuracy should be higher than {}; got {}'.format(
        accuracy_threshold, accuracy)
    self.assertGreater(accuracy, accuracy_threshold, failure_message)

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    expected_keys = [
        prediction_key.PredictionKey.CLASSES,
        prediction_key.PredictionKey.PROBABILITIES,
        dynamic_rnn_estimator._get_state_name(0),
        dynamic_rnn_estimator._get_state_name(1)
    ]
    actual_keys = list(prediction_dict.keys())
    self.assertListEqual(sorted(actual_keys), sorted(expected_keys))

    predictions = prediction_dict[prediction_key.PredictionKey.CLASSES]
    probabilities = prediction_dict[
        prediction_key.PredictionKey.PROBABILITIES]
    # Single-value prediction: one class and one probability pair per example.
    self.assertListEqual(list(predictions.shape), [batch_size])
    self.assertListEqual(list(probabilities.shape), [batch_size, 2])
Esempio n. 5
0
  def testLearnMean(self):
    """Checks that the estimator can regress the mean of a short sequence.

    Each example is a constant 'center' value plus uniform noise at every
    step; the label is the mean of the noisy sequence. After training, the
    evaluation loss must be below a fixed threshold.
    """
    batch_size = 16
    sequence_length = 3
    train_steps = 200
    eval_steps = 20
    cell_type = 'basic_rnn'
    cell_size = 8
    optimizer_type = 'Momentum'
    learning_rate = 0.1
    momentum = 0.9
    loss_threshold = 0.1

    def get_mean_input_fn(batch_size, sequence_length, seed=None):
      """Returns an input_fn yielding noisy sequences and their means."""

      def input_fn():
        # Draw one center per example and repeat it across the time axis by
        # multiplying with a row of ones.
        center_column = random_ops.random_uniform(
            [batch_size, 1], -0.75, 0.75, dtype=dtypes.float32, seed=seed)
        ones_row = array_ops.ones([1, sequence_length])
        centers = math_ops.matmul(center_column, ones_row)
        # Independent uniform noise for every step of every example.
        noise = random_ops.random_uniform(
            [batch_size, sequence_length],
            -0.25,
            0.25,
            dtype=dtypes.float32,
            seed=seed)
        sequences = centers + noise

        inputs = array_ops.expand_dims(sequences, 2)
        labels = math_ops.reduce_mean(sequences, axis=[1])
        features = {'inputs': inputs}
        return features, labels

      return input_fn

    inputs_column = feature_column.real_valued_column(
        'inputs', dimension=cell_size)
    seq_columns = [inputs_column]
    config = run_config.RunConfig(tf_random_seed=6)
    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.LINEAR_REGRESSION,
        prediction_type=rnn_common.PredictionType.SINGLE_VALUE,
        num_units=cell_size,
        sequence_feature_columns=seq_columns,
        cell_type=cell_type,
        optimizer=optimizer_type,
        learning_rate=learning_rate,
        momentum=momentum,
        config=config)

    train_input_fn = get_mean_input_fn(batch_size, sequence_length, 121)
    eval_input_fn = get_mean_input_fn(batch_size, sequence_length, 212)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
    metrics = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    loss = metrics['loss']
    failure_message = 'Loss should be less than {}; got {}'.format(
        loss_threshold, loss)
    self.assertLess(loss, loss_threshold, failure_message)
Esempio n. 6
0
  def testLearnShiftByOne(self):
    """Tests learning a 'shift-by-one' example.

    Each label sequence consists of the input sequence 'shifted' by one place,
    so the RNN must learn to 'remember' the previous input. The test checks
    evaluation accuracy and the keys and shapes returned by `predict`.
    """
    batch_size = 16
    sequence_length = 32
    train_steps = 200
    eval_steps = 20
    cell_size = 4
    learning_rate = 0.3
    accuracy_threshold = 0.9

    def get_shift_input_fn(batch_size, sequence_length, seed=None):
      """Returns an input_fn yielding binary inputs and shifted labels."""

      def input_fn():
        # One extra step is drawn so that inputs and labels can be two
        # overlapping windows of the same sequence.
        random_sequence = random_ops.random_uniform(
            [batch_size, sequence_length + 1],
            0,
            2,
            dtype=dtypes.int32,
            seed=seed)
        labels = array_ops.slice(
            random_sequence, [0, 0], [batch_size, sequence_length])
        later_window = array_ops.slice(
            random_sequence, [0, 1], [batch_size, sequence_length])
        inputs = array_ops.expand_dims(
            math_ops.cast(later_window, dtypes.float32), 2)
        features = {'inputs': inputs}
        return features, labels

      return input_fn

    inputs_column = feature_column.real_valued_column(
        'inputs', dimension=cell_size)
    seq_columns = [inputs_column]
    config = run_config.RunConfig(tf_random_seed=21212)
    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.CLASSIFICATION,
        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
        num_classes=2,
        num_units=cell_size,
        sequence_feature_columns=seq_columns,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True)

    train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    metrics = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = metrics['accuracy']
    failure_message = 'Accuracy should be higher than {}; got {}'.format(
        accuracy_threshold, accuracy)
    self.assertGreater(accuracy, accuracy_threshold, failure_message)

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    expected_keys = [
        prediction_key.PredictionKey.CLASSES,
        prediction_key.PredictionKey.PROBABILITIES,
        dynamic_rnn_estimator._get_state_name(0)
    ]
    actual_keys = list(prediction_dict.keys())
    self.assertListEqual(sorted(actual_keys), sorted(expected_keys))

    predictions = prediction_dict[prediction_key.PredictionKey.CLASSES]
    probabilities = prediction_dict[
        prediction_key.PredictionKey.PROBABILITIES]
    # Multi-value prediction: one class and one probability pair per step.
    self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
    self.assertListEqual(
        list(probabilities.shape), [batch_size, sequence_length, 2])
Esempio n. 7
0
  def testMultipleRuns(self):
    """Tests resuming prediction by feeding state.

    Trains a two-layer LSTM estimator, then predicts the same alternating
    sequence twice: once in a single call, and once in several consecutive
    calls where the state returned by one call is fed as a feature to the
    next. The final predictions and final states of both runs must match.
    """
    import shutil  # Local import: only needed for temp-dir cleanup here.

    cell_sizes = [4, 7]
    batch_size = 11
    learning_rate = 0.1
    train_sequence_length = 21
    train_steps = 121
    dropout_keep_probabilities = [0.5, 0.5, 0.5]
    prediction_steps = [3, 2, 5, 11, 6]

    def get_input_fn(batch_size, sequence_length, state_dict, starting_step=0):
      """Returns an input_fn for an alternating 0/1 sequence plus state.

      Args:
        batch_size: Number of sequences per batch.
        sequence_length: Number of time steps in inputs and labels.
        state_dict: Mapping of state feature names to values, merged into the
          returned features. It is not modified.
        starting_step: Offset of the first step within the overall sequence.
      """

      def input_fn():
        sequence = constant_op.constant(
            [[(starting_step + i + j) % 2 for j in range(sequence_length + 1)]
             for i in range(batch_size)],
            dtype=dtypes.int32)
        labels = array_ops.slice(sequence, [0, 0],
                                 [batch_size, sequence_length])
        inputs = array_ops.expand_dims(
            math_ops.cast(
                array_ops.slice(sequence, [0, 1], [batch_size, sequence_length
                                                  ]),
                dtypes.float32), 2)
        # Copy so the caller's dict is not mutated by adding 'inputs'.
        input_dict = dict(state_dict)
        input_dict['inputs'] = inputs
        return input_dict, labels

      return input_fn

    seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
    config = run_config.RunConfig(tf_random_seed=21212)

    model_dir = tempfile.mkdtemp()
    # `mkdtemp` never removes the directory itself; clean it up after the
    # test so repeated runs do not accumulate model directories.
    self.addCleanup(shutil.rmtree, model_dir, ignore_errors=True)
    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.CLASSIFICATION,
        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
        num_classes=2,
        sequence_feature_columns=seq_columns,
        num_units=cell_sizes,
        cell_type='lstm',
        dropout_keep_probabilities=dropout_keep_probabilities,
        learning_rate=learning_rate,
        config=config,
        model_dir=model_dir)

    train_input_fn = get_input_fn(
        batch_size, train_sequence_length, state_dict={})

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    def incremental_predict(estimator, increments):
      """Run `estimator.predict` for `i` steps for `i` in `increments`.

      The state entries of each prediction are fed into the next call.
      Returns the prediction dict of the last call.
      """
      step = 0
      incremental_state_dict = {}
      for increment in increments:
        input_fn = get_input_fn(
            batch_size,
            increment,
            state_dict=incremental_state_dict,
            starting_step=step)
        prediction_dict = estimator.predict(
            input_fn=input_fn, as_iterable=False)
        step += increment
        # Keep only the state entries to seed the next prediction call.
        incremental_state_dict = {
            k: v
            for (k, v) in prediction_dict.items()
            if k.startswith(rnn_common.RNNKeys.STATE_PREFIX)
        }
      return prediction_dict

    pred_all_at_once = incremental_predict(sequence_estimator,
                                           [sum(prediction_steps)])
    pred_step_by_step = incremental_predict(sequence_estimator,
                                            prediction_steps)

    # Check that the last `prediction_steps[-1]` steps give the same
    # predictions.
    np.testing.assert_array_equal(
        pred_all_at_once[prediction_key.PredictionKey.CLASSES]
        [:, -1 * prediction_steps[-1]:],
        pred_step_by_step[prediction_key.PredictionKey.CLASSES],
        err_msg='Mismatch on last {} predictions.'.format(prediction_steps[-1]))
    # Check that final states are identical.
    for k, v in pred_all_at_once.items():
      if k.startswith(rnn_common.RNNKeys.STATE_PREFIX):
        np.testing.assert_array_equal(
            v, pred_step_by_step[k], err_msg='Mismatch on state {}.'.format(k))