def test_gru_fewer_parameters_than_lstm(self):
        """Checks the GRU-based model has fewer parameters than the LSTM one.

        Both models come from `models.create_recurrent_model` with the same
        first argument (10); only the recurrent layer factory and the model
        name differ between them.
        """

        def _make_gru_layer(units):
            return tf.keras.layers.GRU(units, return_sequences=True)

        def _make_lstm_layer(units):
            return tf.keras.layers.LSTM(units, return_sequences=True)

        model_with_gru = models.create_recurrent_model(
            10, _make_gru_layer, 'gru')
        model_with_lstm = models.create_recurrent_model(
            10, _make_lstm_layer, 'lstm')

        num_gru_params = model_with_gru.count_params()
        num_lstm_params = model_with_lstm.count_params()
        self.assertLess(num_gru_params, num_lstm_params)
    def test_gru_constructs(self):
        """Builds a GRU-backed recurrent model and checks its type and name."""

        def _make_gru_layer(units):
            return tf.keras.layers.GRU(units, return_sequences=True)

        expected_name = 'rnn-gru'
        gru_model = models.create_recurrent_model(
            10, _make_gru_layer, expected_name)

        self.assertIsInstance(gru_model, tf.keras.Model)
        self.assertEqual(expected_name, gru_model.name)
Exemple #3
0
 def test_constructs(self):
     """Builds the model without a layer factory and checks type and name."""
     expected_name = 'rnn-lstm'
     lstm_model = models.create_recurrent_model(10, name=expected_name)

     self.assertIsInstance(lstm_model, tf.keras.Model)
     self.assertEqual(expected_name, lstm_model.name)
Exemple #4
0
def run_experiment():
    """Trains and evaluates a centralized Stack Overflow LSTM model.

    Everything is configured through the module-level `FLAGS`. The function
    builds the validation/test datasets and the centralized training dataset,
    creates and compiles an LSTM model with three masked accuracy metrics,
    writes the hyperparameters to a CSV file, fits the model, and finally
    logs the test loss and the first compiled accuracy metric.

    Output locations, all under `FLAGS.root_output_dir`:
      * `logdir/<experiment_name>`: TensorBoard log directory.
      * `train_results/<experiment_name>`: path given to the training
        `AtomicCSVLogger`.
      * `test_results/<experiment_name>`: path given to the test
        `AtomicCSVLogger`.
      * `<experiment_name>/hparams.csv`: hyperparameter values.

    Failure to create any of the output directories is logged as a warning
    rather than raised, so the run is still attempted.
    """
    # The first element (the training split) is discarded; the centralized
    # training dataset is built separately below.
    # NOTE(review): the positional `1` and `-1` are passed through unchanged;
    # confirm their meaning against `dataset.construct_word_level_datasets`.
    (_, stackoverflow_validation,
     stackoverflow_test) = dataset.construct_word_level_datasets(
         FLAGS.vocab_size, FLAGS.batch_size, 1, FLAGS.sequence_length, -1,
         FLAGS.num_validation_examples)
    centralized_train = dataset.get_centralized_train_dataset(
        FLAGS.vocab_size, FLAGS.batch_size, FLAGS.sequence_length,
        FLAGS.shuffle_buffer_size)

    def _lstm_fn(latent_size):
        return tf.keras.layers.LSTM(latent_size, return_sequences=True)

    model = models.create_recurrent_model(
        FLAGS.vocab_size,
        _lstm_fn,
        'stackoverflow-lstm',
        shared_embedding=FLAGS.shared_embedding)

    # `Model.summary()` returns None, so capture its text through `print_fn`
    # instead of logging the return value. Extra keyword arguments are
    # accepted because some Keras versions pass them to `print_fn`.
    summary_lines = []
    model.summary(print_fn=lambda line, **_: summary_lines.append(line))
    logging.info('Training model:\n%s', '\n'.join(summary_lines))

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()
    pad_token, oov_token, _, eos_token = dataset.get_special_tokens(
        FLAGS.vocab_size)
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=optimizer,
        metrics=[
            # Plus 4 for pad, oov, bos, eos
            keras_metrics.FlattenedCategoricalAccuracy(
                vocab_size=FLAGS.vocab_size + 4,
                name='accuracy_with_oov',
                masked_tokens=pad_token),
            keras_metrics.FlattenedCategoricalAccuracy(
                vocab_size=FLAGS.vocab_size + 4,
                name='accuracy_no_oov',
                masked_tokens=[pad_token, oov_token]),
            keras_metrics.FlattenedCategoricalAccuracy(
                vocab_size=FLAGS.vocab_size + 4,
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, oov_token, eos_token]),
        ])

    train_results_path = os.path.join(FLAGS.root_output_dir, 'train_results',
                                      FLAGS.experiment_name)
    test_results_path = os.path.join(FLAGS.root_output_dir, 'test_results',
                                     FLAGS.experiment_name)

    train_csv_logger = keras_callbacks.AtomicCSVLogger(train_results_path)
    test_csv_logger = keras_callbacks.AtomicCSVLogger(test_results_path)

    log_dir = os.path.join(FLAGS.root_output_dir, 'logdir',
                           FLAGS.experiment_name)
    hparams_file = os.path.join(FLAGS.root_output_dir, FLAGS.experiment_name,
                                'hparams.csv')

    # `tf.io.gfile.makedirs` already succeeds when the directory exists, so
    # an OpError here signals a real problem (e.g. permissions). Each
    # directory is attempted independently so that one failure does not skip
    # the others, and the failure is reported instead of silently dropped.
    output_dirs = (log_dir, train_results_path, test_results_path,
                   os.path.dirname(hparams_file))
    for output_dir in output_dirs:
        try:
            tf.io.gfile.makedirs(output_dir)
        except tf.errors.OpError as e:
            logging.warning('Could not create directory %s: %s', output_dir,
                            e)

    train_tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        write_graph=True,
        update_freq=FLAGS.tensorboard_update_frequency)

    test_tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

    # Write the hyperparameters to a CSV:
    hparam_dict = collections.OrderedDict(
        (name, FLAGS[name].value) for name in hparam_flags)
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    model.fit(centralized_train,
              epochs=FLAGS.epochs,
              verbose=0,
              validation_data=stackoverflow_validation,
              callbacks=[train_csv_logger, train_tensorboard_callback])
    score = model.evaluate(
        stackoverflow_test,
        verbose=0,
        callbacks=[test_csv_logger, test_tensorboard_callback])
    # `score` is [loss, *metrics] in compile order, so index 1 is the
    # 'accuracy_with_oov' metric.
    logging.info('Final test loss: %.4f', score[0])
    logging.info('Final test accuracy: %.4f', score[1])
    def test_dense_fn_raises(self):
        """A layer factory returning a Dense layer must raise ValueError."""

        def _make_dense_layer(units):
            return tf.keras.layers.Dense(units)

        expect_value_error = self.assertRaisesRegex(
            ValueError, 'tf.keras.layers.RNN')
        with expect_value_error:
            models.create_recurrent_model(10, _make_dense_layer, 'dense')