예제 #1
0
 def test_pooling_method(self):
   # The test expects 'avg' and 'max' pooling to produce models with
   # identical parameter counts; only the pooling mode differs.
   avg_model, max_model = [
       mobilenet_v2.create_mobilenet_v2(
           input_shape=(224, 224, 3), pooling=mode, num_classes=1000)
       for mode in ('avg', 'max')
   ]
   for model in (avg_model, max_model):
     self.assertIsInstance(model, tf.keras.Model)
   self.assertEqual(avg_model.count_params(), max_model.count_params())
예제 #2
0
 def test_num_groups(self):
   # The test expects the number of normalization groups to leave the total
   # parameter count unchanged (default groups vs. num_groups=4).
   shape = (224, 224, 3)
   default_model = mobilenet_v2.create_mobilenet_v2(
       input_shape=shape, num_classes=1000)
   grouped_model = mobilenet_v2.create_mobilenet_v2(
       input_shape=shape, num_groups=4, num_classes=1000)
   for model in (default_model, grouped_model):
     self.assertIsInstance(model, tf.keras.Model)
   self.assertEqual(default_model.count_params(),
                    grouped_model.count_params())
예제 #3
0
 def test_alpha_changes_number_parameters(self):
   # The test expects the parameter count to grow with `alpha`: 0.5 gives
   # fewer parameters than the default model, and 2.0 gives more.
   shape = (224, 224, 3)
   default_model = mobilenet_v2.create_mobilenet_v2(
       input_shape=shape, num_classes=1000)
   narrow_model, wide_model = [
       mobilenet_v2.create_mobilenet_v2(
           input_shape=shape, alpha=alpha, num_classes=1000)
       for alpha in (0.5, 2.0)
   ]
   for model in (default_model, narrow_model, wide_model):
     self.assertIsInstance(model, tf.keras.Model)
   self.assertLess(narrow_model.count_params(), default_model.count_params())
   self.assertLess(default_model.count_params(), wide_model.count_params())
예제 #4
0
  def test_dropout(self):
    # For both the full and the small builder, any dropout probability should
    # yield the same layer count, and disabling dropout (None) should yield
    # fewer layers.
    cases = (
        (mobilenet_v2.create_mobilenet_v2, (224, 224, 3), 0.5),
        (mobilenet_v2.create_small_mobilenet_v2, (64, 64, 3), 0.9),
    )
    for builder, shape, first_prob in cases:
      # Build in the same order as before: first_prob, 0.2, then None.
      built = [
          builder(input_shape=shape, dropout_prob=prob)
          for prob in (first_prob, 0.2, None)
      ]
      first_count, second_count, no_dropout_count = (
          len(model.layers) for model in built)
      self.assertEqual(first_count, second_count)
      self.assertGreater(first_count, no_dropout_count)
예제 #5
0
 def model_builder() -> tf.keras.Model:
     """Returns a new MobileNetV2 built from the surrounding configuration.

     `image_size`, `num_groups`, `num_classes` and `dropout_prob` are free
     variables resolved from the enclosing scope (not visible here).
     """
     square_input_shape = (image_size, image_size, 3)
     return mobilenet_v2.create_mobilenet_v2(
         input_shape=square_input_shape,
         num_groups=num_groups,
         num_classes=num_classes,
         dropout_prob=dropout_prob)
예제 #6
0
def run_centralized(
        optimizer: tf.keras.optimizers.Optimizer,
        image_size: int,
        num_epochs: int,
        batch_size: int,
        num_groups: int = 8,
        dataset_type: dataset.DatasetType = dataset.DatasetType.GLD23K,
        experiment_name: str = 'centralized_gld23k',
        root_output_dir: str = '/tmp/fedopt_guide',
        dropout_prob: Optional[float] = None,
        hparams_dict: Optional[Mapping[str, Any]] = None,
        max_batches: Optional[int] = None):
    """Trains a MobileNetV2 on the Google Landmark datasets.

    Args:
      optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
      image_size: The height and width of images after preprocessing.
      num_epochs: The number of training epochs.
      batch_size: The batch size, used for train and test.
      num_groups: The number of groups in the GroupNorm layers of MobilenetV2.
      dataset_type: A `dataset.DatasetType` specifying which dataset is used
        for experiments.
      experiment_name: The name of the experiment. Part of the output
        directory.
      root_output_dir: The top-level output directory for experiment runs. The
        `experiment_name` argument will be appended, and the directory will
        contain tensorboard logs, metrics written as CSVs, and a CSV of
        hyperparameter choices (if `hparams_dict` is used).
      dropout_prob: Dropout rate used by the dropout layer of MobilenetV2.
        Must be in the range [0, 1). Setting it to None (default) or zero
        means no dropout.
      hparams_dict: A mapping with string keys representing the
        hyperparameters and their values. If not None, this is written to CSV.
      max_batches: If set to a positive integer, datasets are capped to at most
        that many batches. If set to None or a nonpositive integer, the full
        datasets are used.

    Raises:
      ValueError: If `dropout_prob` is not None and not in the range [0, 1).
    """
    # Validate arguments before any data is loaded, so a bad `dropout_prob`
    # fails immediately rather than after (potentially slow) dataset
    # construction. The chained comparison also rejects NaN, which the
    # separate `< 0` / `>= 1` tests would let through.
    if dropout_prob is not None and not 0 <= dropout_prob < 1:
        raise ValueError(
            f'Expected a value in [0, 1) for `dropout_prob`, found {dropout_prob}.'
        )

    train_data, test_data = dataset.get_centralized_datasets(
        image_size=image_size,
        batch_size=batch_size,
        dataset_type=dataset_type)
    num_classes, _ = dataset.get_dataset_stats(dataset_type)

    # Optionally truncate both splits; None or a non-positive value keeps the
    # full datasets.
    if max_batches and max_batches >= 1:
        train_data = train_data.take(max_batches)
        test_data = test_data.take(max_batches)

    model = mobilenet_v2.create_mobilenet_v2(
        input_shape=(image_size, image_size, 3),
        num_groups=num_groups,
        num_classes=num_classes,
        dropout_prob=dropout_prob)

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  optimizer=optimizer,
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    centralized_training_loop.run(keras_model=model,
                                  train_dataset=train_data,
                                  validation_dataset=test_data,
                                  experiment_name=experiment_name,
                                  root_output_dir=root_output_dir,
                                  num_epochs=num_epochs,
                                  hparams_dict=hparams_dict)