Example 1
  def test_global_emnist_dataset_structure(self):
    global_train, global_test = dataset.get_centralized_emnist_datasets(
        batch_size=32, only_digits=False)

    # The train pipeline should yield batches at the requested batch size and
    # the test pipeline at the TEST_BATCH_SIZE constant defined elsewhere in
    # the test module, both as 28x28 single-channel images.
    train_batch = next(iter(global_train))
    train_batch_shape = train_batch[0].shape
    test_batch = next(iter(global_test))
    test_batch_shape = test_batch[0].shape
    self.assertEqual(train_batch_shape.as_list(), [32, 28, 28, 1])
    self.assertEqual(test_batch_shape.as_list(), [TEST_BATCH_SIZE, 28, 28, 1])
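
The test above relies on a project-level helper, dataset.get_centralized_emnist_datasets. For orientation, here is a minimal sketch of what such a helper could look like when built only on the public tff.simulation.datasets.emnist loader; the function name is taken from the test, but the signature, the test_batch_size default, and the preprocessing steps are assumptions chosen to reproduce the [batch, 28, 28, 1] shapes asserted above, not the project's actual implementation.

import tensorflow as tf
import tensorflow_federated as tff


def get_centralized_emnist_datasets(batch_size, test_batch_size=500,
                                    only_digits=False):
  """Pools all EMNIST clients into centralized train/test tf.data pipelines."""
  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
      only_digits=only_digits)

  def preprocess(client_data, bs):
    # Each element is an OrderedDict with 'pixels' (a 28x28 float image) and
    # 'label'; add a channel dimension and batch.
    def element_fn(element):
      return tf.expand_dims(element['pixels'], axis=-1), element['label']

    return client_data.create_tf_dataset_from_all_clients().map(
        element_fn).batch(bs)

  return (preprocess(emnist_train, batch_size),
          preprocess(emnist_test, test_batch_size))
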
Example 2
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.compat.v1.enable_v2_behavior()

    experiment_output_dir = FLAGS.root_output_dir
    tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                   FLAGS.experiment_name)
    results_dir = os.path.join(experiment_output_dir, 'results',
                               FLAGS.experiment_name)

    for path in [experiment_output_dir, tensorboard_dir, results_dir]:
        try:
            tf.io.gfile.makedirs(path)
        except tf.errors.OpError:
            pass  # Directory already exists.

    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_dir
    hparams_file = os.path.join(results_dir, 'hparams.csv')

    logging.info('Saving hyperparameters to: [%s]', hparams_file)
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    train_dataset, eval_dataset = dataset.get_centralized_emnist_datasets(
        batch_size=FLAGS.batch_size, only_digits=False)

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

    if FLAGS.model == 'cnn':
        model = models.create_conv_dropout_model(only_digits=False)
    elif FLAGS.model == '2nn':
        model = models.create_two_hidden_layer_model(only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  optimizer=optimizer,
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    logging.info('Training model:')
    # model.summary() prints its report rather than returning it, so route the
    # printed lines through logging instead of logging its None return value.
    model.summary(print_fn=logging.info)

    csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=tensorboard_dir)

    # Reduce the learning rate after a fixed number of epochs.
    def decay_lr(epoch, learning_rate):
        if (epoch + 1) % FLAGS.decay_epochs == 0:
            return learning_rate * FLAGS.lr_decay
        else:
            return learning_rate

    lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

    history = model.fit(
        train_dataset,
        validation_data=eval_dataset,
        epochs=FLAGS.num_epochs,
        callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

    logging.info('Final metrics:')
    for name in ['loss', 'sparse_categorical_accuracy']:
        metric = history.history['val_{}'.format(name)][-1]
        logging.info('\t%s: %.4f', name, metric)
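
The step-decay schedule above is easiest to read as a standalone function. The following self-contained check uses placeholder values (decay_epochs=20, lr_decay=0.1, a starting rate of 0.1) rather than the script's flag defaults, and prints the rate that tf.keras.callbacks.LearningRateScheduler would apply at each decay point.

decay_epochs = 20
lr_decay = 0.1


def decay_lr(epoch, learning_rate):
  # Called once per epoch with the optimizer's current rate; multiply it by
  # lr_decay on every decay_epochs-th epoch, otherwise leave it unchanged.
  if (epoch + 1) % decay_epochs == 0:
    return learning_rate * lr_decay
  return learning_rate


lr = 0.1
for epoch in range(60):
  lr = decay_lr(epoch, lr)
  if (epoch + 1) % decay_epochs == 0:
    print('epoch %d -> learning rate %.5f' % (epoch, lr))
# epoch 19 -> learning rate 0.01000
# epoch 39 -> learning rate 0.00100
# epoch 59 -> learning rate 0.00010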