def test_record_new_flags(self):
        """Flags defined inside `record_new_flags` are reported back by name.

        Defines two flags within the context manager and checks that the
        yielded collection holds exactly those two names (in any order).
        """
        expected_names = ['exp_name', 'learning_rate']

        with utils_impl.record_new_flags() as recorded_names:
            flags.DEFINE_string(
                'exp_name', 'name', 'Unique name for the experiment.')
            flags.DEFINE_float(
                'learning_rate', 0.1, 'Optimizer learning rate.')

        # assertCountEqual ignores ordering but requires the same multiset.
        self.assertCountEqual(recorded_names, expected_names)

    def test_convert_flag_names_to_odict(self):
        """`lookup_flag_values` maps recorded flag names to their values.

        Two flags are defined inside `record_new_flags`; looking up the
        recorded names must yield an ordered mapping of each name to the
        value it was defined with.
        """
        with utils_impl.record_new_flags() as new_flag_names:
            flags.DEFINE_integer('flag1', 1, 'This is the first flag.')
            flags.DEFINE_float('flag2', 2.0, 'This is the second flag.')

        actual = utils_impl.lookup_flag_values(new_flag_names)
        # Values are the defaults passed to the DEFINE_* calls above.
        expected = collections.OrderedDict([('flag1', 1), ('flag2', 2.0)])
        self.assertEqual(actual, expected)
Esempio n. 3
0
import os.path

from absl import app
from absl import flags
import tensorflow as tf
import tensorflow_federated as tff

from compression import sparsity
from utils import training_loop
from utils import training_utils
from utils import utils_impl
from utils.datasets import emnist_dataset
from utils.models import emnist_models
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te

with utils_impl.record_new_flags():
    # NOTE(review): the names recorded by this context manager are not bound
    # (no `as ...`); bind them if they are needed later, e.g. for hparam logging.

    # ---- Training hyperparameters ----------------------------------------
    flags.DEFINE_integer(
        'clients_per_round', 2, 'How many clients to sample per round.')
    flags.DEFINE_integer(
        'client_epochs_per_round', 1,
        'Number of epochs in the client to take per round.')
    flags.DEFINE_integer(
        'client_batch_size', 20, 'Batch size used on the client.')
    flags.DEFINE_boolean(
        'only_digits', True,
        'Whether to use the digit-only EMNIST dataset (10 characters) or the '
        'extended EMNIST dataset (62 characters).')

    # ---- Optimizer configuration -----------------------------------------
    # Each call registers the flag(s) for one optimizer, keyed by its prefix:
    # first the server-side optimizer, then the client-side one.
    utils_impl.define_optimizer_flags('server')
    utils_impl.define_optimizer_flags('client')
Esempio n. 4
0
from absl import app
from absl import flags
from absl import logging
import pandas as pd
import tensorflow as tf
import tensorflow_federated as tff
import tree

from flars import flars_fedavg
from flars import flars_optimizer
from utils import checkpoint_manager
from utils import utils_impl
from utils.models import emnist_models
from tensorboard.plugins.hparams import api as hp

with utils_impl.record_new_flags() as hparam_flags:
  # Metadata
  flags.DEFINE_string(
      'exp_name', 'emnist', 'Unique name for the experiment, suitable for use '
      'in filenames.')

  # Training hyperparameters
  flags.DEFINE_boolean(
      'digit_only_emnist', True,
      'Whether to train on the digits only (10 classes) data '
      'or the full data (62 classes).')
  flags.DEFINE_integer('total_rounds', 500, 'Number of total training rounds.')
  flags.DEFINE_integer('rounds_per_eval', 1, 'How often to evaluate')
  flags.DEFINE_integer(
      'rounds_per_checkpoint', 25,
      'How often to emit a state checkpoint. Higher numbers '