def testClassifierConstructor(self):
  """Smoke test: builds a multi-value RNN classifier and fits it briefly.

  The test passes as long as construction and `fit` run without raising; no
  metric is asserted.
  """
  import shutil  # Local import: only needed for the temp-dir cleanup below.

  batch_size = 16
  num_classes = 2
  num_unroll = 32
  sequence_length = 32
  num_units = 4
  learning_rate = 0.5
  steps = 100
  input_fn = self._get_input_fn(sequence_length, seed=1234)
  model_dir = tempfile.mkdtemp()
  # mkdtemp() leaves deletion to the caller; remove the model directory when
  # the test finishes so repeated runs do not accumulate directories.
  self.addCleanup(shutil.rmtree, model_dir, ignore_errors=True)
  seq_columns = [
      feature_column.real_valued_column('inputs', dimension=num_units)
  ]
  # Keyword arguments (as in the other tests here) so the feature columns
  # cannot be bound to the wrong positional parameter.
  # NOTE(review): the other tests also pass `input_key_column_name`; if this
  # API version requires it, this call and `_get_input_fn` need one too —
  # TODO confirm against the estimator's signature.
  estimator = ssre.multi_value_rnn_classifier(
      num_classes=num_classes,
      num_units=num_units,
      num_unroll=num_unroll,
      batch_size=batch_size,
      sequence_feature_columns=seq_columns,
      learning_rate=learning_rate,
      model_dir=model_dir,
      queue_capacity=batch_size + 2,
      seed=1234)
  estimator.fit(input_fn=input_fn, steps=steps)
def estimator_fn():
  """Builds a multi-value RNN classifier from enclosing-scope settings.

  Every argument value (num_classes, cell_size, num_unroll, batch_size,
  input_key_column_name, seq_columns, model_dir) is a free variable resolved
  from the enclosing scope. The classifier is built with
  predict_probabilities enabled.
  """
  classifier_kwargs = {
      'num_classes': num_classes,
      'num_units': cell_size,
      'num_unroll': num_unroll,
      'batch_size': batch_size,
      'input_key_column_name': input_key_column_name,
      'sequence_feature_columns': seq_columns,
      'predict_probabilities': True,
      'model_dir': model_dir,
      'queue_capacity': 2 + batch_size,
  }
  return ssre.multi_value_rnn_classifier(**classifier_kwargs)
def testLearnLyrics(self):
  """Tests learning to predict the next word of a short lyric.

  Each example is the lyric rotated to a random start word. The input at each
  step is a word and the label is the vocabulary index of the word after it.
  """
  lyrics = 'if I go there will be trouble and if I stay it will be double'
  lyrics_list = lyrics.split()
  sequence_length = len(lyrics_list)
  # Sorted so the word -> index assignment is identical in every process; the
  # iteration order of a set of strings depends on the interpreter hash seed.
  vocab = sorted(set(lyrics_list))
  batch_size = 16
  num_classes = len(vocab)
  num_unroll = 7  # not a divisor of sequence_length
  # Guard the property stated above so it cannot silently drift again
  # (sequence_length is 15 here, which a num_unroll of 5 would divide evenly).
  self.assertNotEqual(0, sequence_length % num_unroll)
  train_steps = 300
  eval_steps = 30
  num_units = 4
  learning_rate = 0.4
  accuracy_threshold = 0.70
  input_key_column_name = 'input_key_column'

  def get_lyrics_input_fn(seed):

    def input_fn():
      start = random_ops.random_uniform(
          (), minval=0, maxval=sequence_length, dtype=dtypes.int32, seed=seed)
      # Concatenate lyrics_list so inputs and labels wrap when start > 0.
      lyrics_list_concat = lyrics_list + lyrics_list
      inputs_dense = array_ops.slice(lyrics_list_concat, [start],
                                     [sequence_length])
      indices = array_ops.constant(
          [[i, 0] for i in range(sequence_length)], dtype=dtypes.int64)
      dense_shape = [sequence_length, 1]
      inputs = sparse_tensor.SparseTensor(
          indices=indices, values=inputs_dense, dense_shape=dense_shape)
      table = lookup.string_to_index_table_from_tensor(
          mapping=vocab, default_value=-1, name='lookup')
      # Labels are the words one position after the inputs.
      labels = table.lookup(
          array_ops.slice(lyrics_list_concat, [start + 1], [sequence_length]))
      input_key = string_ops.string_join([
          'key_',
          string_ops.as_string(
              random_ops.random_uniform(
                  (),
                  minval=0,
                  maxval=10000000,
                  dtype=dtypes.int32,
                  seed=seed))
      ])
      return {'lyrics': inputs, input_key_column_name: input_key}, labels

    return input_fn

  sequence_feature_columns = [
      feature_column.embedding_column(
          feature_column.sparse_column_with_keys('lyrics', vocab),
          dimension=8)
  ]
  config = run_config.RunConfig(tf_random_seed=21212)
  sequence_estimator = ssre.multi_value_rnn_classifier(
      num_classes=num_classes,
      num_units=num_units,
      num_unroll=num_unroll,
      batch_size=batch_size,
      input_key_column_name=input_key_column_name,
      sequence_feature_columns=sequence_feature_columns,
      learning_rate=learning_rate,
      config=config,
      predict_probabilities=True,
      queue_capacity=2 + batch_size)

  train_input_fn = get_lyrics_input_fn(seed=12321)
  eval_input_fn = get_lyrics_input_fn(seed=32123)

  sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

  evaluation = sequence_estimator.evaluate(
      input_fn=eval_input_fn, steps=eval_steps)
  accuracy = evaluation['accuracy']
  self.assertGreater(accuracy, accuracy_threshold,
                     'Accuracy should be higher than {}; got {}'.format(
                         accuracy_threshold, accuracy))
def testLearnShiftByOne(self):
  """Tests learning a 'shift-by-one' example.

  Each label sequence consists of the input sequence 'shifted' by one place.
  The RNN must learn to 'remember' the previous input.
  """
  batch_size = 16
  num_classes = 2
  num_unroll = 32
  sequence_length = 32
  train_steps = 200
  eval_steps = 20
  num_units = 4
  learning_rate = 0.5
  accuracy_threshold = 0.9
  input_key_column_name = 'input_key_column'

  def get_shift_input_fn(sequence_length, seed=None):

    def input_fn():
      # sequence_length + 1 random values in {0, 1}; labels[t] equals
      # inputs[t - 1] for t >= 1, so each label is the previous input.
      random_sequence = random_ops.random_uniform(
          [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
      labels = array_ops.slice(random_sequence, [0], [sequence_length])
      inputs = math_ops.to_float(
          array_ops.slice(random_sequence, [1], [sequence_length]))
      input_key = string_ops.string_join([
          'key_',
          string_ops.as_string(
              random_ops.random_uniform(
                  (),
                  minval=0,
                  maxval=10000000,
                  dtype=dtypes.int32,
                  seed=seed))
      ])
      return {'inputs': inputs, input_key_column_name: input_key}, labels

    return input_fn

  seq_columns = [
      feature_column.real_valued_column('inputs', dimension=num_units)
  ]
  config = run_config.RunConfig(tf_random_seed=21212)
  sequence_estimator = ssre.multi_value_rnn_classifier(
      num_classes=num_classes,
      num_units=num_units,
      num_unroll=num_unroll,
      batch_size=batch_size,
      input_key_column_name=input_key_column_name,
      sequence_feature_columns=seq_columns,
      learning_rate=learning_rate,
      config=config,
      predict_probabilities=True,
      queue_capacity=2 + batch_size)

  train_input_fn = get_shift_input_fn(sequence_length, seed=12321)
  eval_input_fn = get_shift_input_fn(sequence_length, seed=32123)

  sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

  evaluation = sequence_estimator.evaluate(
      input_fn=eval_input_fn, steps=eval_steps)
  accuracy = evaluation['accuracy']
  self.assertGreater(accuracy, accuracy_threshold,
                     'Accuracy should be higher than {}; got {}'.format(
                         accuracy_threshold, accuracy))

  # Testing `predict` when `predict_probabilities=True`.
  prediction_dict = sequence_estimator.predict(
      input_fn=eval_input_fn, as_iterable=False)
  self.assertListEqual(
      sorted(prediction_dict),
      sorted([
          ssre.RNNKeys.PREDICTIONS_KEY, ssre.RNNKeys.PROBABILITIES_KEY,
          ssre._get_state_name(0)
      ]))
  predictions = prediction_dict[ssre.RNNKeys.PREDICTIONS_KEY]
  probabilities = prediction_dict[ssre.RNNKeys.PROBABILITIES_KEY]
  self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
  # One probability per class, tied to num_classes rather than a literal.
  self.assertListEqual(
      list(probabilities.shape), [batch_size, sequence_length, num_classes])
def testLearnShiftByOne(self):
  """Tests learning a 'shift-by-one' example.

  Each label sequence consists of the input sequence 'shifted' by one place.
  The RNN must learn to 'remember' the previous input.
  """
  batch_size = 16
  num_classes = 2
  num_unroll = 32
  sequence_length = 32
  train_steps = 200
  eval_steps = 20
  cell_size = 4
  learning_rate = 0.5
  accuracy_threshold = 0.9
  input_key_column_name = 'input_key_column'

  def get_shift_input_fn(sequence_length, seed=None):

    def input_fn():
      # sequence_length + 1 random values in {0, 1}; labels[t] equals
      # inputs[t - 1] for t >= 1, so each label is the previous input.
      random_sequence = random_ops.random_uniform(
          [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
      labels = array_ops.slice(random_sequence, [0], [sequence_length])
      inputs = math_ops.to_float(
          array_ops.slice(random_sequence, [1], [sequence_length]))
      input_key = string_ops.string_join([
          'key_',
          string_ops.as_string(
              random_ops.random_uniform((),
                                        minval=0,
                                        maxval=10000000,
                                        dtype=dtypes.int32,
                                        seed=seed))
      ])
      return {
          'inputs': inputs,
          input_key_column_name: input_key
      }, labels

    return input_fn

  seq_columns = [
      feature_column.real_valued_column('inputs', dimension=cell_size)
  ]
  config = run_config.RunConfig(tf_random_seed=21212)
  sequence_estimator = ssre.multi_value_rnn_classifier(
      num_classes=num_classes,
      num_units=cell_size,
      num_unroll=num_unroll,
      batch_size=batch_size,
      input_key_column_name=input_key_column_name,
      sequence_feature_columns=seq_columns,
      learning_rate=learning_rate,
      config=config,
      predict_probabilities=True,
      queue_capacity=2 + batch_size)

  train_input_fn = get_shift_input_fn(sequence_length, seed=12321)
  eval_input_fn = get_shift_input_fn(sequence_length, seed=32123)

  sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

  evaluation = sequence_estimator.evaluate(input_fn=eval_input_fn,
                                           steps=eval_steps)
  accuracy = evaluation['accuracy']
  self.assertGreater(
      accuracy, accuracy_threshold,
      'Accuracy should be higher than {}; got {}'.format(
          accuracy_threshold, accuracy))

  # Testing `predict` when `predict_probabilities=True`.
  prediction_dict = sequence_estimator.predict(input_fn=eval_input_fn,
                                               as_iterable=False)
  self.assertListEqual(
      sorted(prediction_dict),
      sorted([
          ssre.RNNKeys.PREDICTIONS_KEY, ssre.RNNKeys.PROBABILITIES_KEY,
          ssre._get_state_name(0)
      ]))
  predictions = prediction_dict[ssre.RNNKeys.PREDICTIONS_KEY]
  probabilities = prediction_dict[ssre.RNNKeys.PROBABILITIES_KEY]
  self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
  # One probability per class, tied to num_classes rather than a literal.
  self.assertListEqual(list(probabilities.shape),
                       [batch_size, sequence_length, num_classes])
def testLearnLyrics(self):
  """Tests learning to predict the next word of a short lyric.

  Each example is the lyric rotated to a random start word. The input at each
  step is a word and the label is the vocabulary index of the word after it.
  """
  lyrics = 'if I go there will be trouble and if I stay it will be double'
  lyrics_list = lyrics.split()
  sequence_length = len(lyrics_list)
  # Sorted so the word -> index assignment is identical in every process; the
  # iteration order of a set of strings depends on the interpreter hash seed.
  vocab = sorted(set(lyrics_list))
  batch_size = 16
  num_classes = len(vocab)
  num_unroll = 7  # not a divisor of sequence_length
  # Guard the property stated above so it cannot silently drift again
  # (sequence_length is 15 here, which a num_unroll of 5 would divide evenly).
  self.assertNotEqual(0, sequence_length % num_unroll)
  train_steps = 300
  eval_steps = 30
  num_units = 4
  learning_rate = 0.4
  accuracy_threshold = 0.70
  input_key_column_name = 'input_key_column'

  def get_lyrics_input_fn(seed):

    def input_fn():
      start = random_ops.random_uniform((),
                                        minval=0,
                                        maxval=sequence_length,
                                        dtype=dtypes.int32,
                                        seed=seed)
      # Concatenate lyrics_list so inputs and labels wrap when start > 0.
      lyrics_list_concat = lyrics_list + lyrics_list
      inputs_dense = array_ops.slice(lyrics_list_concat, [start],
                                     [sequence_length])
      indices = array_ops.constant([[i, 0] for i in range(sequence_length)],
                                   dtype=dtypes.int64)
      dense_shape = [sequence_length, 1]
      inputs = sparse_tensor.SparseTensor(indices=indices,
                                          values=inputs_dense,
                                          dense_shape=dense_shape)
      table = lookup.string_to_index_table_from_tensor(
          mapping=vocab, default_value=-1, name='lookup')
      # Labels are the words one position after the inputs.
      labels = table.lookup(
          array_ops.slice(lyrics_list_concat, [start + 1], [sequence_length]))
      input_key = string_ops.string_join([
          'key_',
          string_ops.as_string(
              random_ops.random_uniform((),
                                        minval=0,
                                        maxval=10000000,
                                        dtype=dtypes.int32,
                                        seed=seed))
      ])
      return {
          'lyrics': inputs,
          input_key_column_name: input_key
      }, labels

    return input_fn

  sequence_feature_columns = [
      feature_column.embedding_column(
          feature_column.sparse_column_with_keys('lyrics', vocab),
          dimension=8)
  ]
  config = run_config.RunConfig(tf_random_seed=21212)
  sequence_estimator = ssre.multi_value_rnn_classifier(
      num_classes=num_classes,
      num_units=num_units,
      num_unroll=num_unroll,
      batch_size=batch_size,
      input_key_column_name=input_key_column_name,
      sequence_feature_columns=sequence_feature_columns,
      learning_rate=learning_rate,
      config=config,
      predict_probabilities=True,
      queue_capacity=2 + batch_size)

  train_input_fn = get_lyrics_input_fn(seed=12321)
  eval_input_fn = get_lyrics_input_fn(seed=32123)

  sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

  evaluation = sequence_estimator.evaluate(input_fn=eval_input_fn,
                                           steps=eval_steps)
  accuracy = evaluation['accuracy']
  self.assertGreater(
      accuracy, accuracy_threshold,
      'Accuracy should be higher than {}; got {}'.format(
          accuracy_threshold, accuracy))