def estimator_fn():
  """Build a classification `StateSavingRnnEstimator` for the current test.

  `num_units`, `num_unroll`, `batch_size`, `seq_columns`, `num_classes` and
  `model_dir` are free names here: they are resolved from the enclosing (or
  module) scope, which is outside this snippet.

  Returns:
    A fresh `ssre.StateSavingRnnEstimator` that has probability predictions
    enabled, a queue capacity of `batch_size + 2` and a fixed seed of 1234.
  """
  problem_type = constants.ProblemType.CLASSIFICATION
  estimator = ssre.StateSavingRnnEstimator(
      problem_type,
      num_units=num_units,
      num_unroll=num_unroll,
      batch_size=batch_size,
      sequence_feature_columns=seq_columns,
      num_classes=num_classes,
      predict_probabilities=True,
      model_dir=model_dir,
      queue_capacity=batch_size + 2,
      seed=1234)
  return estimator
Beispiel #2
0
  def testLearnLyrics(self):
    """Tests learning to predict the next word of a short lyric.

    Each training/eval example is the lyric rotated by a random offset. The
    label at every step is the word that follows that step's input word,
    wrapping from the last word back to the first. A `basic_rnn` estimator
    trained on rotations drawn with one seed must exceed `accuracy_threshold`
    on rotations drawn with a different seed.
    """
    lyrics = 'if I go there will be trouble and if I stay it will be double'
    lyrics_list = lyrics.split()
    sequence_length = len(lyrics_list)
    # Sort the vocabulary. Iterating a `set` of strings yields an order that
    # differs between processes under hash randomization (PYTHONHASHSEED).
    # That order defines both the word -> label-id mapping of the lookup table
    # and the key ids of the sparse column below. Without sorting, the seeded
    # test would not be reproducible from run to run.
    vocab = sorted(set(lyrics_list))
    batch_size = 16
    num_classes = len(vocab)
    num_unroll = 7  # not a divisor of sequence_length
    train_steps = 350
    eval_steps = 30
    num_units = [4]
    learning_rate = 0.4
    accuracy_threshold = 0.65

    def get_lyrics_input_fn(seed):
      """Returns an input_fn yielding one randomly rotated copy of the lyric."""

      def input_fn():
        start = random_ops.random_uniform(
            (), minval=0, maxval=sequence_length, dtype=dtypes.int32, seed=seed)
        # Concatenate lyrics_list so inputs and labels wrap when start > 0.
        # The label slice starts one word later and ends at most at
        # 2 * sequence_length, so it always stays in range.
        lyrics_list_concat = lyrics_list + lyrics_list
        inputs_dense = array_ops.slice(lyrics_list_concat, [start],
                                       [sequence_length])
        # One token per time step: row i of the sparse input holds word i.
        indices = array_ops.constant(
            [[i, 0] for i in range(sequence_length)], dtype=dtypes.int64)
        dense_shape = [sequence_length, 1]
        inputs = sparse_tensor.SparseTensor(
            indices=indices, values=inputs_dense, dense_shape=dense_shape)
        table = lookup.string_to_index_table_from_tensor(
            mapping=vocab, default_value=-1, name='lookup')
        labels = table.lookup(
            array_ops.slice(lyrics_list_concat, [start + 1], [sequence_length]))
        return {'lyrics': inputs}, labels

      return input_fn

    sequence_feature_columns = [
        feature_column.embedding_column(
            feature_column.sparse_column_with_keys('lyrics', vocab),
            dimension=8)
    ]
    config = run_config.RunConfig(tf_random_seed=21212)
    sequence_estimator = ssre.StateSavingRnnEstimator(
        constants.ProblemType.CLASSIFICATION,
        num_units=num_units,
        cell_type='basic_rnn',
        num_unroll=num_unroll,
        batch_size=batch_size,
        sequence_feature_columns=sequence_feature_columns,
        num_classes=num_classes,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True,
        queue_capacity=2 + batch_size,
        seed=1234)

    train_input_fn = get_lyrics_input_fn(seed=12321)
    eval_input_fn = get_lyrics_input_fn(seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    evaluation = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))
Beispiel #3
0
  def testLearnShiftByOne(self):
    """Tests learning a 'shift-by-one' example.

    Each label sequence consists of the input sequence 'shifted' by one place:
    the label at step t equals the input at step t - 1. The label at step 0
    has no matching input. The RNN must learn to 'remember' the previous
    input. The test also checks the keys and shapes returned by `predict`
    when `predict_probabilities=True`.
    """
    batch_size = 16
    num_classes = 2
    num_unroll = 32
    sequence_length = 32
    train_steps = 200
    eval_steps = 20
    num_units = [4]
    learning_rate = 0.5
    accuracy_threshold = 0.9

    def get_shift_input_fn(sequence_length, seed=None):
      """Returns an input_fn producing a random 0/1 sequence and its shift."""

      def input_fn():
        # Draw one extra element so that inputs and labels can be two
        # overlapping windows offset by one position.
        random_sequence = random_ops.random_uniform(
            [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
        labels = array_ops.slice(random_sequence, [0], [sequence_length])
        inputs = math_ops.to_float(
            array_ops.slice(random_sequence, [1], [sequence_length]))
        return {'inputs': inputs}, labels

      return input_fn

    seq_columns = [
        feature_column.real_valued_column(
            'inputs', dimension=1)
    ]
    config = run_config.RunConfig(tf_random_seed=21212)
    sequence_estimator = ssre.StateSavingRnnEstimator(
        constants.ProblemType.CLASSIFICATION,
        num_units=num_units,
        cell_type='lstm',
        num_unroll=num_unroll,
        batch_size=batch_size,
        sequence_feature_columns=seq_columns,
        num_classes=num_classes,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True,
        queue_capacity=2 + batch_size,
        seed=1234)

    train_input_fn = get_shift_input_fn(sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    evaluation = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    self.assertListEqual(
        sorted(prediction_dict),
        sorted([
            prediction_key.PredictionKey.CLASSES,
            prediction_key.PredictionKey.PROBABILITIES, ssre._get_state_name(0)
        ]))
    predictions = prediction_dict[prediction_key.PredictionKey.CLASSES]
    probabilities = prediction_dict[prediction_key.PredictionKey.PROBABILITIES]
    self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
    # The last probabilities axis has one entry per class.
    self.assertListEqual(
        list(probabilities.shape), [batch_size, sequence_length, num_classes])