Example #1
0
 def test_build_to_ids(self):
     """Checks `build_to_ids_fn` output for max sequence lengths 1 through 7."""
     oov, bos, eos, _ = dataset.get_special_tokens(len(VOCAB))
     # Full id sequence the test expects for the input sentence below.
     full_ids = [bos, 0, 1, oov, eos]
     example = {'tokens': 'A B X'}
     for max_seq_len in range(1, 8):
         ids = dataset.build_to_ids_fn(VOCAB, max_seq_len)(example)
         # The expected prefix is one element longer than the length cap.
         keep = min(max_seq_len, len(full_ids)) + 1
         self.assertAllEqual(self.evaluate(ids), full_ids[:keep])
Example #2
0
 def test_batch_and_split(self):
     """Checks `batch_and_split` output for max sequence lengths 1 through 7."""
     oov, bos, eos, pad = dataset.get_special_tokens(len(VOCAB))
     # Expected ids for the sentence below, followed by padding.
     padded_ids = [bos, 0, oov, 2, eos, pad, pad, pad]
     example = {'tokens': 'A Z C'}
     for max_seq_len in range(1, 8):
         ids = dataset.build_to_ids_fn(VOCAB, max_seq_len)(example)
         single_example_ds = tf.data.Dataset.from_tensor_slices([ids])
         split_ds = dataset.batch_and_split(single_example_ds, max_seq_len,
                                            pad)
         first_elem = self.evaluate(next(iter(split_ds)))
         # Second component is the first one shifted left by one position.
         want_first = [padded_ids[:max_seq_len]]
         want_second = [padded_ids[1:max_seq_len + 1]]
         self.assertAllEqual(first_elem, (want_first, want_second))
Example #3
0
def run_experiment():
    """Runs the centralized training experiment.

    All settings are read from `FLAGS`. The function builds the word-level
    datasets, trains an LSTM-based recurrent model on the dataset created from
    all training clients, evaluates it on the test set, and logs the final
    test loss and accuracy. Hyperparameters are written to `hparams.csv`
    under `<root_output_dir>/<exp_name>`.
    """
    training_set, validation_set, test_set = (
        dataset.construct_word_level_datasets(
            vocab_size=FLAGS.vocab_size,
            batch_size=FLAGS.batch_size,
            client_epochs_per_round=1,
            max_seq_len=FLAGS.sequence_length,
            max_training_elements_per_user=-1,
            num_validation_examples=FLAGS.num_validation_examples,
            num_test_examples=FLAGS.num_test_examples))
    centralized_train = training_set.create_tf_dataset_from_all_clients()

    def _lstm_fn():
        return tf.keras.layers.LSTM(FLAGS.latent_size, return_sequences=True)

    model = models.create_recurrent_model(
        FLAGS.vocab_size,
        FLAGS.embedding_size,
        FLAGS.num_layers,
        _lstm_fn,
        'stackoverflow-lstm',
        shared_embedding=FLAGS.shared_embedding)
    # `model.summary()` prints the summary and returns None, so logging its
    # return value would only ever log "None". Collect the lines through
    # `print_fn` instead; any extra keyword arguments a Keras version may pass
    # to `print_fn` are ignored.
    summary_lines = []
    model.summary(print_fn=lambda line, **_: summary_lines.append(line))
    logging.info('Training model: %s', '\n'.join(summary_lines))
    optimizer = utils_impl.create_optimizer_from_flags('centralized')
    model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy,
                  optimizer=optimizer,
                  weighted_metrics=['acc'])

    train_results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                      'train_results')
    test_results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                     'test_results')

    train_csv_logger = AtomicCSVLogger(train_results_path)
    test_csv_logger = AtomicCSVLogger(test_results_path)

    log_dir = os.path.join(FLAGS.root_output_dir, 'logdir', FLAGS.exp_name)
    # Directory creation stays best effort, but each directory is attempted
    # independently and failures are reported instead of silently dropped.
    for path in (log_dir, train_results_path, test_results_path):
        try:
            tf.io.gfile.makedirs(path)
        except tf.errors.OpError as e:
            logging.warning('Could not create directory %s: %s', path, e)

    train_tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        write_graph=True,
        update_freq=FLAGS.tensorboard_update_frequency)

    test_tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

    results_file = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                'results.csv.bz2')

    # Write the hyperparameters to a CSV:
    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_file
    hparams_file = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                'hparams.csv')
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    # Per-class loss weights: regular vocabulary ids count fully, while the
    # special tokens are adjusted individually below.
    oov, bos, eos, pad = dataset.get_special_tokens(FLAGS.vocab_size)
    class_weight = {x: 1.0 for x in range(FLAGS.vocab_size)}
    class_weight[oov] = 0.0  # No credit for predicting OOV.
    class_weight[bos] = 0.0  # Shouldn't matter since this is never a target.
    class_weight[eos] = 1.0  # Model should learn to predict end of sentence.
    class_weight[pad] = 0.0  # No credit for predicting pad.

    model.fit(centralized_train,
              epochs=FLAGS.epochs,
              verbose=1,
              class_weight=class_weight,
              validation_data=validation_set,
              callbacks=[train_csv_logger, train_tensorboard_callback])
    # `score` is [loss, accuracy], matching the single weighted metric
    # passed to `compile` above.
    score = model.evaluate(
        test_set,
        verbose=1,
        callbacks=[test_csv_logger, test_tensorboard_callback])
    logging.info('Final test loss: %.4f', score[0])
    logging.info('Final test accuracy: %.4f', score[1])
Example #4
0
def run_experiment():
    """Runs the centralized training experiment.

    All settings are read from `FLAGS`. The function builds the word-level
    datasets, trains a recurrent model (LSTM when `FLAGS.lstm` is set, GRU
    otherwise), evaluates it on the test set, and logs the final test loss
    and the first compiled accuracy metric. Hyperparameters are written to
    `hparams.csv` under `<root_output_dir>/<exp_name>`.
    """
    output_dir = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name)
    try:
        tf.io.gfile.makedirs(output_dir)
    except tf.errors.OpError as e:
        # Best effort: report the failure rather than dropping it silently.
        logging.warning('Could not create directory %s: %s', output_dir, e)

    train_set, validation_set, test_set = (
        dataset.construct_word_level_datasets(
            vocab_size=FLAGS.vocab_size,
            client_batch_size=FLAGS.batch_size,
            client_epochs_per_round=1,
            max_seq_len=FLAGS.sequence_length,
            max_elements_per_user=FLAGS.max_elements_per_user,
            centralized_train=True,
            shuffle_buffer_size=None,
            num_validation_examples=FLAGS.num_validation_examples,
            num_test_examples=FLAGS.num_test_examples))

    recurrent_model = tf.keras.layers.LSTM if FLAGS.lstm else tf.keras.layers.GRU

    def _layer_fn():
        return recurrent_model(FLAGS.latent_size, return_sequences=True)

    pad, oov, _, eos = dataset.get_special_tokens(FLAGS.vocab_size)

    model = models.create_recurrent_model(
        FLAGS.vocab_size,
        FLAGS.embedding_size,
        FLAGS.num_layers,
        _layer_fn,
        'stackoverflow-recurrent',
        shared_embedding=FLAGS.shared_embedding)
    # `model.summary()` prints the summary and returns None, so logging its
    # return value would only ever log "None". Collect the lines through
    # `print_fn` instead; any extra keyword arguments a Keras version may pass
    # to `print_fn` are ignored.
    summary_lines = []
    model.summary(print_fn=lambda line, **_: summary_lines.append(line))
    logging.info('Training model: %s', '\n'.join(summary_lines))
    optimizer = utils_impl.create_optimizer_from_flags('centralized')
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=optimizer,
        metrics=[
            metrics.MaskedCategoricalAccuracy([pad], 'accuracy_with_oov'),
            metrics.MaskedCategoricalAccuracy([pad, oov], 'accuracy_no_oov'),
            metrics.MaskedCategoricalAccuracy([pad, oov, eos],
                                              'accuracy_no_oov_no_eos')
        ])

    train_results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                      'train_results')
    test_results_path = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                     'test_results')

    train_csv_logger = AtomicCSVLogger(train_results_path)
    test_csv_logger = AtomicCSVLogger(test_results_path)

    log_dir = os.path.join(FLAGS.root_output_dir, 'logdir', FLAGS.exp_name)
    # Directory creation stays best effort, but each directory is attempted
    # independently and failures are reported instead of silently dropped.
    for path in (log_dir, train_results_path, test_results_path):
        try:
            tf.io.gfile.makedirs(path)
        except tf.errors.OpError as e:
            logging.warning('Could not create directory %s: %s', path, e)

    train_tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        write_graph=True,
        update_freq=FLAGS.tensorboard_update_frequency)

    test_tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

    results_file = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                'results.csv.bz2')

    # Write the hyperparameters to a CSV:
    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    hparam_dict['results_file'] = results_file
    hparams_file = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name,
                                'hparams.csv')
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

    model.fit(train_set,
              epochs=FLAGS.epochs,
              verbose=1,
              steps_per_epoch=FLAGS.steps_per_epoch,
              validation_data=validation_set,
              callbacks=[train_csv_logger, train_tensorboard_callback])
    # `score[0]` is the loss; `score[1]` is the first metric passed to
    # `compile` above ('accuracy_with_oov').
    score = model.evaluate(
        test_set,
        verbose=1,
        callbacks=[test_csv_logger, test_tensorboard_callback])
    logging.info('Final test loss: %.4f', score[0])
    logging.info('Final test accuracy: %.4f', score[1])
Example #5
0
def run_experiment():
    """Runs federated averaging training and prints per-round metrics.

    All settings are read from `FLAGS`. Each round samples
    `FLAGS.clients_per_round` training clients without replacement, runs one
    step of the federated averaging process on their datasets, and prints the
    resulting server metrics.
    """
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=10))

    def _make_recurrent_layer():
        # LSTM when --lstm is set, GRU otherwise.
        if FLAGS.lstm:
            layer_cls = tf.keras.layers.LSTM
        else:
            layer_cls = tf.keras.layers.GRU
        return layer_cls(FLAGS.latent_size, return_sequences=True)

    build_keras_model = functools.partial(
        models.create_recurrent_model,
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        num_layers=FLAGS.num_layers,
        recurrent_layer_fn=_make_recurrent_layer,
        name='stackoverflow-recurrent',
        shared_embedding=FLAGS.shared_embedding)

    pad, oov, _, eos = dataset.get_special_tokens(FLAGS.vocab_size)

    train_set, validation_set, _ = dataset.construct_word_level_datasets(
        FLAGS.vocab_size,
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        FLAGS.sequence_length,
        FLAGS.max_elements_per_user,
        False,
        FLAGS.shuffle_buffer_size,
        FLAGS.num_validation_examples,
        FLAGS.num_test_examples)

    # One validation batch, converted to numpy, is handed to TFF alongside
    # the compiled Keras model.
    first_validation_batch = next(iter(validation_set))
    sample_batch = tf.nest.map_structure(lambda tensor: tensor.numpy(),
                                         first_validation_batch)

    def _build_tff_model():
        """Builds and compiles the Keras model and wraps it for TFF."""
        keras_model = build_keras_model()
        token_counters = [
            metrics.NumTokensCounter(name='num_tokens', masked_tokens=[pad]),
            metrics.NumTokensCounter(name='num_tokens_no_oov',
                                     masked_tokens=[pad, oov]),
        ]
        size_counters = [
            metrics.NumBatchesCounter(),
            metrics.NumExamplesCounter(),
        ]
        accuracies = [
            metrics.MaskedCategoricalAccuracy(name='accuracy',
                                              masked_tokens=[pad]),
            metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                              masked_tokens=[pad, oov]),
            metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov_no_eos',
                                              masked_tokens=[pad, oov, eos]),
        ]
        client_optimizer = utils_impl.create_optimizer_from_flags('client')
        keras_model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True),
            optimizer=client_optimizer,
            metrics=token_counters + size_counters + accuracies)
        return tff.learning.from_compiled_keras_model(keras_model,
                                                      sample_batch)

    def _build_server_optimizer():
        return utils_impl.create_optimizer_from_flags('server')

    def _client_weight(local_outputs):
        # Weight each client by its token count unless uniform weighting is
        # requested via the flag.
        token_count = tf.squeeze(local_outputs['num_tokens'])
        num_tokens = tf.cast(token_count, tf.float32)
        if FLAGS.uniform_weighting:
            return 1.0
        return num_tokens

    iterative_process = (
        tff.learning.federated_averaging.build_federated_averaging_process(
            model_fn=_build_tff_model,
            server_optimizer_fn=_build_server_optimizer,
            client_weight_fn=_client_weight))

    server_state = iterative_process.initialize()
    for round_num in range(1, FLAGS.total_rounds + 1):
        sampled_clients = np.random.choice(train_set.client_ids,
                                           size=FLAGS.clients_per_round,
                                           replace=False)
        client_data = [
            train_set.create_tf_dataset_for_client(client_id)
            for client_id in sampled_clients
        ]
        server_state, server_metrics = iterative_process.next(
            server_state, client_data)
        # (format template, value) pairs, printed in this order.
        round_report = (
            ('Round: {}', round_num),
            ('   Loss: {:.8f}', server_metrics.loss),
            ('   num_batches: {}', server_metrics.num_batches),
            ('   num_examples: {}', server_metrics.num_examples),
            ('   num_tokens: {}', server_metrics.num_tokens),
            ('   num_tokens_no_oov: {}', server_metrics.num_tokens_no_oov),
            ('   accuracy: {:.5f}', server_metrics.accuracy),
            ('   accuracy_no_oov: {:.5f}', server_metrics.accuracy_no_oov),
        )
        for template, value in round_report:
            print(template.format(value))