コード例 #1
0
  def test_launch_experiment(self, mock_multiprocessing):
    """Checks the command lines that launch_experiment submits to the pool.

    Two grid points that differ only in `b` are passed in. Each should turn
    into one `python run_exp.py ...` command. Its --exp_name joins the grid
    index with the `short_names` abbreviation of each hyperparameter.
    """
    # NOTE(review): mock_multiprocessing is presumably a MagicMock patched in
    # for the `multiprocessing` module by a decorator outside this view. On a
    # MagicMock, Pool(...) always returns the same return_value object, so
    # `pool` should be the instance launch_experiment itself talks to.
    pool = mock_multiprocessing.Pool(processes=10)

    grid_points = [
        collections.OrderedDict([('a_long', 1), ('b', b_value)])
        for b_value in (4.0, 5.0)
    ]

    utils.launch_experiment(
        'run_exp.py', grid_points, '/tmp_dir', short_names={'a_long': 'a'})

    # call_args[0] is the positional-args tuple of one apply_async call.
    # [1] picks its second positional argument, and [0] the first element of
    # that, which is where the command string is expected to sit.
    submitted_commands = [
        call_args[0][1][0] for call_args in pool.apply_async.call_args_list
    ]

    expected_commands = [
        'python run_exp.py --a_long=1 --b=4.0 --root_output_dir=/tmp_dir '
        '--exp_name=0-a=1,b=4.0',
        'python run_exp.py --a_long=1 --b=5.0 --root_output_dir=/tmp_dir '
        '--exp_name=1-a=1,b=5.0',
    ]
    # Order-insensitive comparison, matching the original check.
    self.assertCountEqual(submitted_commands, expected_commands)
コード例 #2
0
ファイル: run_experiments.py プロジェクト: FreJoe/tff-0.4.0
def main(argv):
    """Launches the EMNIST GAN experiment grid through `bazel run`.

    Builds a hyperparameter grid with two values each for
    `invert_imagery_probability` and `accuracy_threshold` and a single value
    for every other setting. Passes it to `utils_impl.launch_experiment` with
    outputs rooted at /tmp/exp, abbreviated experiment-name keys, and
    max_workers=1. Then prints a confirmation line.

    Args:
      argv: Command-line arguments; anything beyond argv[0] is rejected.

    Raises:
      app.UsageError: If more than one command-line argument is supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    bazel_target = (
        '//tensorflow_federated/python/research/gans/experiments/emnist:train')
    launch_command = 'bazel run {} --'.format(bazel_target)

    # Key insertion order matches the original mapping, in case iter_grid's
    # enumeration order depends on it.
    hyperparameter_grid = {
        'filtering': ['by_user'],
        'invert_imagery_probability': ['0p0', '0p5'],
        'accuracy_threshold': ['lt0p882', 'gt0p939'],
        'num_client_disc_train_steps': [6],
        'num_server_gen_train_steps': [6],
        'num_clients_per_round': [10],
        'num_rounds': [1000],
        'use_dp': [True],
        'dp_l2_norm_clip': [0.1],
        'dp_noise_multiplier': [0.01],
        'num_rounds_per_eval': [10],
        'num_rounds_per_save_images': [10],
    }

    # Abbreviations handed to launch_experiment as `short_names`.
    # NOTE(review): grid keys without an entry here are presumably named in
    # full by launch_experiment -- confirm in utils_impl.
    abbreviations = {
        'filtering': 'filt',
        'invert_imagery_probability': 'inv_lik',
        'accuracy_threshold': 'acc',
        'num_client_disc_train_steps': 'n_disc',
        'num_server_gen_train_steps': 'n_gen',
        'dp_l2_norm_clip': 'dp_clip',
        'dp_noise_multiplier': 'dp_noise',
        'num_rounds_per_eval': 'n_rds_eval',
        'num_rounds_per_save_images': 'n_rds_images',
    }

    grid_points = utils_impl.iter_grid(hyperparameter_grid)
    utils_impl.launch_experiment(
        launch_command,
        grid_points,
        root_output_dir='/tmp/exp',
        short_names=abbreviations,
        max_workers=1)
    print('Experiments launched.')