Exemple #1
0
  def testLearnShiftByOne(self):
    """Tests that the estimator can learn a 'shift-by-one' task.

    Each label sequence is the input sequence 'shifted' by one place, i.e. the
    label at step `t` equals the input at step `t - 1`. The RNN must therefore
    learn to 'remember' the previous input.

    After training, the test checks that evaluation accuracy exceeds
    `accuracy_threshold` and that `predict` returns the class, probability
    and state entries, with classes and probabilities of the expected shapes.
    """
    batch_size = 16
    num_classes = 2
    num_unroll = 32
    sequence_length = 32
    train_steps = 200
    eval_steps = 20
    num_units = [4]
    learning_rate = 0.5
    accuracy_threshold = 0.9

    def get_shift_input_fn(sequence_length, seed=None):
      """Returns an `input_fn` yielding one random shift-by-one sequence."""

      def input_fn():
        # Draw one extra element so that two windows, offset by one step, can
        # be cut from the same random sequence.
        random_sequence = random_ops.random_uniform(
            [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
        # Labels are the leading window; inputs are the window one step later.
        labels = array_ops.slice(random_sequence, [0], [sequence_length])
        inputs = math_ops.to_float(
            array_ops.slice(random_sequence, [1], [sequence_length]))
        return {'inputs': inputs}, labels

      return input_fn

    seq_columns = [
        feature_column.real_valued_column(
            'inputs', dimension=1)
    ]
    config = run_config.RunConfig(tf_random_seed=21212)
    sequence_estimator = ssre.StateSavingRnnEstimator(
        constants.ProblemType.CLASSIFICATION,
        num_units=num_units,
        cell_type='lstm',
        num_unroll=num_unroll,
        batch_size=batch_size,
        sequence_feature_columns=seq_columns,
        num_classes=num_classes,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True,
        queue_capacity=2 + batch_size,
        seed=1234)

    # Train and evaluate on differently seeded input functions.
    train_input_fn = get_shift_input_fn(sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    evaluation = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    self.assertListEqual(
        sorted(prediction_dict.keys()),
        sorted([
            prediction_key.PredictionKey.CLASSES,
            prediction_key.PredictionKey.PROBABILITIES,
            ssre._get_state_name(0),
        ]))
    predictions = prediction_dict[prediction_key.PredictionKey.CLASSES]
    probabilities = prediction_dict[prediction_key.PredictionKey.PROBABILITIES]
    self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
    # Tie the last axis to `num_classes` rather than repeating the literal 2.
    self.assertListEqual(
        list(probabilities.shape), [batch_size, sequence_length, num_classes])
  def testLearnShiftByOne(self):
    """Tests that the classifier can learn a 'shift-by-one' task.

    Each label sequence is the input sequence 'shifted' by one place, i.e. the
    label at step `t` equals the input at step `t - 1`. The RNN must therefore
    learn to 'remember' the previous input.

    After training, the test checks that evaluation accuracy exceeds
    `accuracy_threshold` and that `predict` returns the prediction,
    probability and state entries, with predictions and probabilities of the
    expected shapes.
    """
    batch_size = 16
    num_classes = 2
    num_unroll = 32
    sequence_length = 32
    train_steps = 200
    eval_steps = 20
    num_units = 4
    learning_rate = 0.5
    accuracy_threshold = 0.9
    input_key_column_name = 'input_key_column'

    def get_shift_input_fn(sequence_length, seed=None):
      """Returns an `input_fn` yielding one keyed shift-by-one sequence."""

      def input_fn():
        # Draw one extra element so that two windows, offset by one step, can
        # be cut from the same random sequence.
        random_sequence = random_ops.random_uniform(
            [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
        # Labels are the leading window; inputs are the window one step later.
        labels = array_ops.slice(random_sequence, [0], [sequence_length])
        inputs = math_ops.to_float(
            array_ops.slice(random_sequence, [1], [sequence_length]))
        # String key of the form 'key_<random int>' supplied to the estimator
        # under `input_key_column_name`.
        input_key = string_ops.string_join([
            'key_', string_ops.as_string(
                random_ops.random_uniform(
                    (),
                    minval=0,
                    maxval=10000000,
                    dtype=dtypes.int32,
                    seed=seed))
        ])
        return {'inputs': inputs, input_key_column_name: input_key}, labels

      return input_fn

    # NOTE(review): `input_fn` emits one float per time step, yet the column is
    # declared with dimension=num_units -- confirm this is intended.
    seq_columns = [
        feature_column.real_valued_column(
            'inputs', dimension=num_units)
    ]
    config = run_config.RunConfig(tf_random_seed=21212)
    sequence_estimator = ssre.multi_value_rnn_classifier(
        num_classes=num_classes,
        num_units=num_units,
        num_unroll=num_unroll,
        batch_size=batch_size,
        input_key_column_name=input_key_column_name,
        sequence_feature_columns=seq_columns,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True,
        queue_capacity=2 + batch_size)

    # Train and evaluate on differently seeded input functions.
    train_input_fn = get_shift_input_fn(sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    evaluation = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    self.assertListEqual(
        sorted(prediction_dict.keys()),
        sorted([
            ssre.RNNKeys.PREDICTIONS_KEY,
            ssre.RNNKeys.PROBABILITIES_KEY,
            ssre._get_state_name(0),
        ]))
    predictions = prediction_dict[ssre.RNNKeys.PREDICTIONS_KEY]
    probabilities = prediction_dict[ssre.RNNKeys.PROBABILITIES_KEY]
    self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
    # Tie the last axis to `num_classes` rather than repeating the literal 2.
    self.assertListEqual(
        list(probabilities.shape), [batch_size, sequence_length, num_classes])
  def testLearnShiftByOne(self):
    """Tests that the estimator can learn a 'shift-by-one' task.

    Each label sequence is the input sequence 'shifted' by one place, i.e. the
    label at step `t` equals the input at step `t - 1`. The RNN must therefore
    learn to 'remember' the previous input.

    After training, the test checks that evaluation accuracy exceeds
    `accuracy_threshold` and that `predict` returns the class, probability
    and state entries, with classes and probabilities of the expected shapes.
    """
    batch_size = 16
    num_classes = 2
    num_unroll = 32
    sequence_length = 32
    train_steps = 200
    eval_steps = 20
    num_units = [4]
    learning_rate = 0.5
    accuracy_threshold = 0.9

    def get_shift_input_fn(sequence_length, seed=None):
      """Returns an `input_fn` yielding one random shift-by-one sequence."""

      def input_fn():
        # Draw one extra element so that two windows, offset by one step, can
        # be cut from the same random sequence.
        random_sequence = random_ops.random_uniform(
            [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
        # Labels are the leading window; inputs are the window one step later.
        labels = array_ops.slice(random_sequence, [0], [sequence_length])
        inputs = math_ops.to_float(
            array_ops.slice(random_sequence, [1], [sequence_length]))
        return {'inputs': inputs}, labels

      return input_fn

    seq_columns = [
        feature_column.real_valued_column(
            'inputs', dimension=1)
    ]
    config = run_config.RunConfig(tf_random_seed=21212)
    sequence_estimator = ssre.StateSavingRnnEstimator(
        constants.ProblemType.CLASSIFICATION,
        num_units=num_units,
        cell_type='lstm',
        num_unroll=num_unroll,
        batch_size=batch_size,
        sequence_feature_columns=seq_columns,
        num_classes=num_classes,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True,
        queue_capacity=2 + batch_size,
        seed=1234)

    # Train and evaluate on differently seeded input functions.
    train_input_fn = get_shift_input_fn(sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    evaluation = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    self.assertListEqual(
        sorted(prediction_dict.keys()),
        sorted([
            prediction_key.PredictionKey.CLASSES,
            prediction_key.PredictionKey.PROBABILITIES,
            ssre._get_state_name(0),
        ]))
    predictions = prediction_dict[prediction_key.PredictionKey.CLASSES]
    probabilities = prediction_dict[prediction_key.PredictionKey.PROBABILITIES]
    self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
    # Tie the last axis to `num_classes` rather than repeating the literal 2.
    self.assertListEqual(
        list(probabilities.shape), [batch_size, sequence_length, num_classes])
Exemple #4
0
    def testLearnShiftByOne(self):
        """Tests that the classifier can learn a 'shift-by-one' task.

        Each label sequence is the input sequence 'shifted' by one place,
        i.e. the label at step `t` equals the input at step `t - 1`. The RNN
        must therefore learn to 'remember' the previous input.

        After training, the test checks that evaluation accuracy exceeds
        `accuracy_threshold` and that `predict` returns the prediction,
        probability and state entries, with predictions and probabilities of
        the expected shapes.
        """
        batch_size = 16
        num_classes = 2
        num_unroll = 32
        sequence_length = 32
        train_steps = 200
        eval_steps = 20
        cell_size = 4
        learning_rate = 0.5
        accuracy_threshold = 0.9
        input_key_column_name = 'input_key_column'

        def get_shift_input_fn(sequence_length, seed=None):
            """Returns an `input_fn` yielding one keyed shift-by-one sequence."""

            def input_fn():
                # Draw one extra element so that two windows, offset by one
                # step, can be cut from the same random sequence.
                random_sequence = random_ops.random_uniform(
                    [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
                # Labels are the leading window; inputs are the window one
                # step later.
                labels = array_ops.slice(random_sequence, [0],
                                         [sequence_length])
                inputs = math_ops.to_float(
                    array_ops.slice(random_sequence, [1], [sequence_length]))
                # String key of the form 'key_<random int>' supplied to the
                # estimator under `input_key_column_name`.
                input_key = string_ops.string_join([
                    'key_',
                    string_ops.as_string(
                        random_ops.random_uniform((),
                                                  minval=0,
                                                  maxval=10000000,
                                                  dtype=dtypes.int32,
                                                  seed=seed))
                ])
                return {
                    'inputs': inputs,
                    input_key_column_name: input_key
                }, labels

            return input_fn

        # NOTE(review): `input_fn` emits one float per time step, yet the
        # column is declared with dimension=cell_size -- confirm this is
        # intended.
        seq_columns = [
            feature_column.real_valued_column('inputs', dimension=cell_size)
        ]
        config = run_config.RunConfig(tf_random_seed=21212)
        sequence_estimator = ssre.multi_value_rnn_classifier(
            num_classes=num_classes,
            num_units=cell_size,
            num_unroll=num_unroll,
            batch_size=batch_size,
            input_key_column_name=input_key_column_name,
            sequence_feature_columns=seq_columns,
            learning_rate=learning_rate,
            config=config,
            predict_probabilities=True,
            queue_capacity=2 + batch_size)

        # Train and evaluate on differently seeded input functions.
        train_input_fn = get_shift_input_fn(sequence_length, seed=12321)
        eval_input_fn = get_shift_input_fn(sequence_length, seed=32123)

        sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

        evaluation = sequence_estimator.evaluate(input_fn=eval_input_fn,
                                                 steps=eval_steps)
        accuracy = evaluation['accuracy']
        self.assertGreater(
            accuracy, accuracy_threshold,
            'Accuracy should be higher than {}; got {}'.format(
                accuracy_threshold, accuracy))

        # Testing `predict` when `predict_probabilities=True`.
        prediction_dict = sequence_estimator.predict(input_fn=eval_input_fn,
                                                     as_iterable=False)
        self.assertListEqual(
            sorted(prediction_dict.keys()),
            sorted([
                ssre.RNNKeys.PREDICTIONS_KEY,
                ssre.RNNKeys.PROBABILITIES_KEY,
                ssre._get_state_name(0),
            ]))
        predictions = prediction_dict[ssre.RNNKeys.PREDICTIONS_KEY]
        probabilities = prediction_dict[ssre.RNNKeys.PROBABILITIES_KEY]
        self.assertListEqual(list(predictions.shape),
                             [batch_size, sequence_length])
        # Tie the last axis to `num_classes` rather than repeating the
        # literal 2.
        self.assertListEqual(list(probabilities.shape),
                             [batch_size, sequence_length, num_classes])