Ejemplo n.º 1
0
  def test_emnist_score(self):
    """Checks `emnist_score` against known values for fake and real images."""
    expectations = (
        (self.fake_images, 1.1598),
        (self.real_images, 3.9547),
    )
    for images, expected_score in expectations:
      classifier = ecm.get_trained_emnist_classifier_model()
      score = eeu.emnist_score(images, classifier)
      self.assertAllClose(score, expected_score, rtol=0.0001, atol=0.0001)
Ejemplo n.º 2
0
  def test_emnist_frechet_distance(self):
    """Frechet distance is large for fake images, exactly 0 for identical sets."""
    cases = (
        (self.fake_images, 568.6883, {'rtol': 0.0001, 'atol': 0.0001}),
        (self.real_images, 0.0, {}),
    )
    for generated_images, expected_distance, tolerances in cases:
      classifier = ecm.get_trained_emnist_classifier_model()
      distance = eeu.emnist_frechet_distance(self.real_images,
                                             generated_images, classifier)
      self.assertAllClose(distance, expected_distance, **tolerances)
Ejemplo n.º 3
0
def main(argv):
    """Selects Federated EMNIST clients by classifier accuracy, saves IDs to CSV.

    Splits the training clients into "bad" and "good" accuracy groups (per the
    cutoff flags) using a pre-trained EMNIST classifier, then writes each
    group's client-id -> invert-imagery map to its own CSV file.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If extra command-line arguments are passed.
      ValueError: If any of the probability/cutoff flags exceeds 1.0.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Flags.
    hparam_dict = collections.OrderedDict(
        (name, FLAGS[name].value) for name in hparam_flags)
    # Replace unset flags with the string 'None' so they print/serialize.
    for k, v in hparam_dict.items():
        if v is None:
            hparam_dict[k] = 'None'
    for k, v in hparam_dict.items():
        print('{} : {} '.format(k, v))

    def _validate_fraction_flag(flag_name):
        # These flags are probabilities/fractions; values above 1.0 are
        # invalid. (Error message matches the original per-flag messages.)
        if FLAGS[flag_name].value > 1.0:
            raise ValueError(
                '%s cannot be greater than 1.0' % flag_name)

    _validate_fraction_flag('invert_imagery_likelihood')
    _validate_fraction_flag('bad_accuracy_cutoff')
    _validate_fraction_flag('good_accuracy_cutoff')

    # Training datasets.
    client_real_images_train_tff_data = (
        emnist_data_utils.create_real_images_tff_client_data('train'))

    print('There are %d unique clients.' %
          len(client_real_images_train_tff_data.client_ids))

    # Trained classifier model.
    classifier_model = ecm.get_trained_emnist_classifier_model()

    # Filter down to those client IDs that fall within some accuracy cutoff.
    bad_client_ids_inversion_map, good_client_ids_inversion_map = (
        _get_client_ids_meeting_condition(client_real_images_train_tff_data,
                                          FLAGS.bad_accuracy_cutoff,
                                          FLAGS.good_accuracy_cutoff,
                                          FLAGS.invert_imagery_likelihood,
                                          classifier_model))

    print(
        'There are %d unique clients meeting bad accuracy cutoff condition.' %
        len(bad_client_ids_inversion_map))
    print(
        'There are %d unique clients meeting good accuracy cutoff condition.' %
        len(good_client_ids_inversion_map))

    def _save_map_to_csv(path, inversion_map):
        # One (client_id, invert_imagery) pair per row.
        with tf.io.gfile.GFile(path, 'w') as csvfile:
            writer = csv.writer(csvfile)
            for key, val in inversion_map.items():
                writer.writerow([key, val])

    # Save selected client id dictionaries to csv.
    _save_map_to_csv(FLAGS.path_to_save_bad_clients_csv,
                     bad_client_ids_inversion_map)
    _save_map_to_csv(FLAGS.path_to_save_good_clients_csv,
                     good_client_ids_inversion_map)

    print('CSV files with selected Federated EMNIST clients have been saved.')
Ejemplo n.º 4
0
def main(argv):
    """Trains a federated GAN on Federated EMNIST and reports training time.

    Builds generator/discriminator networks, selects (optionally filtered)
    client training data, configures Wasserstein GAN losses/optimizers and an
    eval hook, then runs federated training via `_train`.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If extra command-line arguments are passed.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    logging.set_verbosity(logging.INFO)

    # Flags.
    hparam_dict = collections.OrderedDict(
        (name, FLAGS[name].value) for name in hparam_flags)
    # Replace unset flags with the string 'None' so they print/serialize.
    for k, v in hparam_dict.items():
        if v is None:
            hparam_dict[k] = 'None'
    for k, v in hparam_dict.items():
        print('{} : {} '.format(k, v))

    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(
            num_clients=FLAGS.num_clients_per_round))

    # Trained classifier model (used by the eval hook to score images).
    classifier_model = ecm.get_trained_emnist_classifier_model()

    # GAN Models.
    disc_model_fn, gen_model_fn = _get_gan_network_models(FLAGS.noise_dim)

    # Training datasets.
    server_gen_inputs_dataset = _create_gen_inputs_dataset(
        batch_size=FLAGS.server_train_batch_size, noise_dim=FLAGS.noise_dim)
    client_gen_inputs_dataset = _create_gen_inputs_dataset(
        batch_size=CLIENT_TRAIN_BATCH_SIZE, noise_dim=FLAGS.noise_dim)

    if FLAGS.filtering == 'by_user':
        client_real_images_train_tff_data = (
            fedu.get_filtered_by_user_client_data_for_training(
                invert_imagery_probability=FLAGS.invert_imagery_probability,
                accuracy_threshold=FLAGS.accuracy_threshold,
                batch_size=CLIENT_TRAIN_BATCH_SIZE))
    elif FLAGS.filtering == 'by_example':
        client_real_images_train_tff_data = (
            fedu.get_filtered_by_example_client_data_for_training(
                invert_imagery_probability=FLAGS.invert_imagery_probability,
                min_num_examples=FLAGS.min_num_examples,
                example_class_selection=FLAGS.example_class_selection,
                batch_size=CLIENT_TRAIN_BATCH_SIZE))
    else:
        client_real_images_train_tff_data = (
            fedu.get_unfiltered_client_data_for_training(
                batch_size=CLIENT_TRAIN_BATCH_SIZE))

    print('There are %d unique clients that will be used for GAN training.' %
          len(client_real_images_train_tff_data.client_ids))

    # Training: GAN Losses and Optimizers.
    gan_loss_fns = gan_losses.WassersteinGanLossFns(
        grad_penalty_lambda=FLAGS.wass_gp_lambda)
    # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
    disc_optimizer = tf.keras.optimizers.SGD(learning_rate=0.0005)
    gen_optimizer = tf.keras.optimizers.SGD(learning_rate=0.005)

    # Eval datasets.
    gen_inputs_eval_dataset = _create_gen_inputs_dataset(
        batch_size=EVAL_BATCH_SIZE, noise_dim=FLAGS.noise_dim)
    real_images_eval_dataset = _create_real_images_dataset_for_eval()

    # Eval hook.
    path_to_output_images = _get_path_to_output_image(FLAGS.root_output_dir,
                                                      FLAGS.exp_name)
    logging.info('path_to_output_images is %s', path_to_output_images)
    eval_hook_fn = _get_emnist_eval_hook_fn(
        FLAGS.exp_name, FLAGS.root_output_dir, hparam_dict, gan_loss_fns,
        gen_inputs_eval_dataset, real_images_eval_dataset,
        FLAGS.num_rounds_per_save_images, path_to_output_images,
        classifier_model)

    # Form the GAN.
    gan = _get_gan(gen_model_fn,
                   disc_model_fn,
                   gan_loss_fns,
                   gen_optimizer,
                   disc_optimizer,
                   server_gen_inputs_dataset,
                   client_real_images_train_tff_data,
                   use_dp=FLAGS.use_dp,
                   dp_l2_norm_clip=FLAGS.dp_l2_norm_clip,
                   dp_noise_multiplier=FLAGS.dp_noise_multiplier,
                   clients_per_round=FLAGS.num_clients_per_round)

    # Training.
    _, tff_time = _train(gan,
                         server_gen_inputs_dataset,
                         client_gen_inputs_dataset,
                         client_real_images_train_tff_data,
                         FLAGS.num_client_disc_train_steps,
                         FLAGS.num_server_gen_train_steps,
                         FLAGS.num_clients_per_round,
                         FLAGS.num_rounds,
                         FLAGS.num_rounds_per_eval,
                         eval_hook_fn,
                         FLAGS.num_rounds_per_checkpoint,
                         output_dir=FLAGS.root_output_dir,
                         exp_name=FLAGS.exp_name)
    logging.info('Total training time was %4.3f seconds.', tff_time)

    print('\nTRAINING COMPLETE.')
Ejemplo n.º 5
0
def main(argv):
    """Computes per-client classification accuracy over Federated EMNIST.

    For every client, images are intensity-inverted with probability
    `invert_imagery_likelihood`, classified by a pre-trained EMNIST model, and
    the fraction classified correctly is recorded. Prints an accuracy
    histogram, 25th/75th percentile values, and the overall accuracy.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    invert_imagery_likelihood = FLAGS.invert_imagery_likelihood
    print('invert_imagery_likelihood is %s' % invert_imagery_likelihood)
    # This flag is a probability; anything above 1.0 is invalid.
    if invert_imagery_likelihood > 1.0:
        raise ValueError(
            'invert_imagery_likelihood cannot be greater than 1.0')

    # TFF Dataset.
    client_real_images_tff_data = (
        emnist_data_utils.create_real_images_tff_client_data(split='train'))
    print('There are %d unique clients.' %
          len(client_real_images_tff_data.client_ids))

    # EMNIST Classifier.
    classifier_model = ecm.get_trained_emnist_classifier_model()

    accuracy_list = []
    per_client_counts = []
    for client_id in client_real_images_tff_data.client_ids:
        # Biased coin flip: decide whether this client's imagery is inverted.
        invert_imagery = (
            np.random.binomial(n=1, p=invert_imagery_likelihood) == 1)

        # TF Dataset for this particular client, preprocessed into the format
        # the classifier expects.
        raw_images_ds = client_real_images_tff_data.create_tf_dataset_for_client(
            client_id)
        images_ds = emnist_data_utils.preprocess_img_dataset(
            raw_images_ds,
            invert_imagery=invert_imagery,
            include_label=True,
            batch_size=None,
            shuffle=False,
            repeat=False)

        # Run classifier on all data on client, compute % classified correctly.
        total_count, correct_count = _analyze_classifier(
            images_ds, classifier_model)
        accuracy_list.append(float(correct_count) / float(total_count))
        per_client_counts.append((total_count, correct_count))

    overall_total_count = sum(total for total, _ in per_client_counts)
    overall_correct_count = sum(correct for _, correct in per_client_counts)

    # Calculate histogram.
    bin_width = 1
    histogram = _compute_histogram(accuracy_list, bin_width)
    print('\nHistogram:')
    print(histogram.numpy())
    # Sanity check (should be 3400)
    print('(Histogram sum):')
    print(sum(histogram.numpy()))

    # Calculate percentile values.
    percentile_25, percentile_75 = np.percentile(accuracy_list, q=(25, 75))
    print('\nPercentiles...')
    print('25th Percentile : %f' % percentile_25)
    print('75th Percentile : %f' % percentile_75)

    overall_accuracy = (float(overall_correct_count) /
                        float(overall_total_count))
    print('\nOverall classification success percentage: %d / %d (%f)' %
          (overall_correct_count, overall_total_count, overall_accuracy))
def main(argv):
  """Partitions EMNIST training clients by classification results, saves CSVs.

  Uses a pre-trained EMNIST classifier to find clients with enough
  correctly/incorrectly classified examples, then writes four CSV files:
  the two client-id maps and the two per-client example-index maps.

  Args:
    argv: Command-line arguments; only the program name is expected.

  Raises:
    app.UsageError: If extra command-line arguments are passed.
    ValueError: If `invert_imagery_likelihood` exceeds 1.0.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Flags.
  hparam_dict = collections.OrderedDict(
      (name, FLAGS[name].value) for name in hparam_flags)
  # Replace unset flags with the string 'None' so they print/serialize.
  for k, v in hparam_dict.items():
    if v is None:
      hparam_dict[k] = 'None'
  for k, v in hparam_dict.items():
    print('{} : {} '.format(k, v))

  # This flag is a probability; anything above 1.0 is invalid.
  if FLAGS.invert_imagery_likelihood > 1.0:
    raise ValueError('invert_imagery_likelihood cannot be greater than 1.0')

  # Training datasets.
  client_real_images_train_tff_data = (
      emnist_data_utils.create_real_images_tff_client_data('train'))

  print('There are %d unique clients.' %
        len(client_real_images_train_tff_data.client_ids))

  # Trained classifier model.
  classifier_model = ecm.get_trained_emnist_classifier_model()

  # Filter down to those client IDs that fall within some accuracy cutoff.
  (client_ids_with_correct_examples_map, client_ids_with_incorrect_examples_map,
   client_ids_correct_example_indices_map,
   client_ids_incorrect_example_indices_map) = (
       _get_client_ids_and_examples_based_on_classification(
           client_real_images_train_tff_data, FLAGS.min_num_examples,
           FLAGS.invert_imagery_likelihood, classifier_model))

  print('There are %d unique clients with at least %d correct examples.' %
        (len(client_ids_with_correct_examples_map), FLAGS.min_num_examples))
  print('There are %d unique clients with at least %d incorrect examples.' %
        (len(client_ids_with_incorrect_examples_map), FLAGS.min_num_examples))

  def _save_map_to_csv(path, mapping):
    # One (key, value) pair per CSV row.
    with tf.io.gfile.GFile(path, 'w') as csvfile:
      writer = csv.writer(csvfile)
      for key, val in mapping.items():
        writer.writerow([key, val])

  # Save client id dictionaries to csv.
  _save_map_to_csv(FLAGS.path_to_save_clients_with_correct_examples_csv,
                   client_ids_with_correct_examples_map)
  _save_map_to_csv(FLAGS.path_to_save_clients_with_incorrect_examples_csv,
                   client_ids_with_incorrect_examples_map)
  _save_map_to_csv(FLAGS.path_to_save_correct_example_indices_csv,
                   client_ids_correct_example_indices_map)
  _save_map_to_csv(FLAGS.path_to_save_incorrect_example_indices_csv,
                   client_ids_incorrect_example_indices_map)

  print('CSV files with selected Federated EMNIST clients and lists of '
        'classified/misclassified examples have been saved.')
Ejemplo n.º 7
0
def main(argv):
  """Trains a GAN (federated or centralized) on EMNIST/MNIST data.

  Depending on `FLAGS.status` ('<gen_status>_<disc_status>'), training is
  either centralized (a single synthetic client holding the full dataset) or
  federated across EMNIST clients. Results, images, and checkpoints go under
  `FLAGS.root_output_dir/<exp_name>`.

  Args:
    argv: Command-line arguments; only the program name is expected.

  Raises:
    app.UsageError: If extra command-line arguments are passed.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  logging.set_verbosity(logging.INFO)

  # Flags.
  hparam_dict = collections.OrderedDict(
      (name, FLAGS[name].value) for name in hparam_flags)
  # Replace unset flags with the string 'None' so they print/serialize.
  for k, v in hparam_dict.items():
    if v is None:
      hparam_dict[k] = 'None'
  for k, v in hparam_dict.items():
    print('{} : {} '.format(k, v))

  tff.backends.native.set_local_execution_context(
      num_clients=FLAGS.num_clients_per_round)

  # The experiment name encodes the hyperparameters; it is used for the
  # output and dataset-cache directory names.
  exp_name = ('client_steps={0},clients_per_round={1},status={2},'
              'model_name={3},client_batch={4},optimizer={5},lr={6},'
              'lr_factor={7}').format(FLAGS.num_client_disc_train_steps,
                                      FLAGS.num_clients_per_round,
                                      FLAGS.status, FLAGS.model,
                                      FLAGS.client_batch_size, FLAGS.optimizer,
                                      FLAGS.lr, FLAGS.lr_factor)

  cache_dir = os.path.join(FLAGS.root_output_dir, exp_name)
  cache_subdir = os.path.join(cache_dir, 'datasets')
  # exist_ok avoids the check-then-create race of os.path.exists + makedirs;
  # creating the subdir also creates cache_dir.
  os.makedirs(cache_subdir, exist_ok=True)

  # Seed the cache with pre-downloaded TFF EMNIST archives so the dataset
  # loaders do not re-download them.
  # NOTE(review): source paths are hard-coded to one user's home directory —
  # parameterize before running elsewhere.
  source_dataset_dir = '/home/houc/.keras/datasets'
  if not os.path.exists(
      os.path.join(cache_subdir, 'fed_emnist_digitsonly_train.h5')):
    for fname in ('fed_emnist.tar.bz2', 'fed_emnist_test.h5',
                  'fed_emnist_train.h5', 'fed_emnist_digitsonly.tar.bz2',
                  'fed_emnist_digitsonly_test.h5',
                  'fed_emnist_digitsonly_train.h5'):
      copyfile(os.path.join(source_dataset_dir, fname),
               os.path.join(cache_subdir, fname))

  client_batch_size = FLAGS.client_batch_size
  # Trained classifier model.
  classifier_model = ecm.get_trained_emnist_classifier_model()

  # GAN Models.
  if FLAGS.model == 'spectral':
    disc_model_fn, gen_model_fn = _get_gan_network_models(FLAGS.noise_dim, True)
  else:
    disc_model_fn, gen_model_fn = _get_gan_network_models(FLAGS.noise_dim)

  # Training datasets.
  server_gen_inputs_dataset = _create_gen_inputs_dataset(
      batch_size=FLAGS.server_train_batch_size, noise_dim=FLAGS.noise_dim)
  client_gen_inputs_dataset = _create_gen_inputs_dataset(
      batch_size=client_batch_size, noise_dim=FLAGS.noise_dim)

  is_mnist_model = FLAGS.model.split('_')[0].lower() == 'mnist'
  # MNIST test images are only loaded on the centralized-MNIST path; track
  # them explicitly so the eval-dataset choice below never hits a NameError.
  test_images = None
  test_size = None

  if FLAGS.status == 'CEN_CEN':
    if is_mnist_model:
      (train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()

      def preprocess_images(images):
        # Scale to [0, 1] float32 with an explicit channel dimension.
        return np.float32(images.reshape((images.shape[0], 28, 28, 1)) / 255)

      train_images = preprocess_images(train_images)
      test_images = preprocess_images(test_images)
      train_size = 60000
      test_size = 10000
      central_dataset = (
          tf.data.Dataset.from_tensor_slices(train_images).shuffle(
              train_size,
              reshuffle_each_iteration=True).batch(client_batch_size))
    else:
      central_dataset = _create_real_images_dataset_for_central(
          client_batch_size, cache_dir)

    print('Dataset done', flush=True)

    def create_tf_dataset_for_client(client_id):
      del client_id  # Single synthetic client holds the entire dataset.
      return central_dataset

    client_real_images_train_tff_data = (
        tff.simulation.ClientData.from_clients_and_fn(
            ['1'], create_tf_dataset_for_client))
  else:
    client_real_images_train_tff_data = (
        fedu.get_unfiltered_client_data_for_training(
            batch_size=client_batch_size, cache_dir=cache_dir))

  print('There are %d unique clients that will be used for GAN training.' %
        len(client_real_images_train_tff_data.client_ids))

  # Training: GAN Losses and Optimizers.
  gan_loss_fns = gan_losses.WassersteinGanLossFns(
      grad_penalty_lambda=FLAGS.wass_gp_lambda)
  # NOTE(review): rate 1 suggests the effective learning rates come from
  # FLAGS.lr/FLAGS.lr_factor inside _get_gan — confirm before changing.
  disc_optimizer = tf.keras.optimizers.SGD(learning_rate=1)
  gen_optimizer = tf.keras.optimizers.SGD(learning_rate=1)

  # Eval datasets.
  gen_inputs_eval_dataset = _create_gen_inputs_dataset(
      batch_size=EVAL_BATCH_SIZE, noise_dim=FLAGS.noise_dim)
  # Bug fix: the MNIST eval set only exists on the centralized-MNIST path;
  # previously an MNIST model with a non-CEN_CEN status raised NameError here.
  if is_mnist_model and test_images is not None:
    real_images_eval_dataset = (
        tf.data.Dataset.from_tensor_slices(test_images).shuffle(
            test_size, reshuffle_each_iteration=True).batch(EVAL_BATCH_SIZE))
  else:
    real_images_eval_dataset = _create_real_images_dataset_for_eval(cache_dir)

  # Eval hook.
  path_to_output_images = _get_path_to_output_image(FLAGS.root_output_dir,
                                                    exp_name)
  logging.info('path_to_output_images is %s', path_to_output_images)
  eval_hook_fn = _get_emnist_eval_hook_fn(
      exp_name, FLAGS.root_output_dir, hparam_dict, gan_loss_fns,
      gen_inputs_eval_dataset, real_images_eval_dataset,
      FLAGS.num_rounds_per_save_images, path_to_output_images, classifier_model)

  # Form the GAN. `status` is '<gen_status>_<disc_status>'.
  gen_status, disc_status = FLAGS.status.split('_')
  gan = _get_gan(
      gen_model_fn,
      disc_model_fn,
      gan_loss_fns,
      gen_optimizer,
      disc_optimizer,
      server_gen_inputs_dataset,
      client_real_images_train_tff_data,
      use_dp=FLAGS.use_dp,
      dp_l2_norm_clip=FLAGS.dp_l2_norm_clip,
      dp_noise_multiplier=FLAGS.dp_noise_multiplier,
      clients_per_round=FLAGS.num_clients_per_round,
      gen_status=gen_status,
      disc_status=disc_status,
      learning_rate=FLAGS.lr,
      optimizer=FLAGS.optimizer,
      client_disc_train_steps=FLAGS.num_client_disc_train_steps,
      lr_factor=FLAGS.lr_factor)

  # Training.
  _, tff_time = _train(
      gan,
      server_gen_inputs_dataset,
      client_gen_inputs_dataset,
      client_real_images_train_tff_data,
      FLAGS.num_client_disc_train_steps,
      FLAGS.num_server_gen_train_steps,
      FLAGS.num_clients_per_round,
      FLAGS.num_rounds,
      FLAGS.num_rounds_per_eval,
      eval_hook_fn,
      FLAGS.num_rounds_per_checkpoint,
      output_dir=FLAGS.root_output_dir,
      exp_name=exp_name)
  logging.info('Total training time was %4.3f seconds.', tff_time)

  print('\nTRAINING COMPLETE.')