Example #1
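
These snippets appear to be excerpts from TensorFlow's internal Keras distribution-strategy correctness tests and omit the module-level imports of the original files. A minimal sketch of the imports they rely on is given below; the exact module paths are assumptions based on the tensorflow.python layout of this era and may differ between TF versions. Helpers such as SubclassedModel, is_default_strategy and all_strategy_combinations_with_graph_mode are defined elsewhere in the same files (a sketch of SubclassedModel follows the first class).

# Assumed imports; module paths may vary between TensorFlow versions.
import numpy as np

from tensorflow.python import keras
from tensorflow.python.distribute import combinations
# The later excerpts use the alias ds_combinations for the same kind of
# test-combination generator (assumption: the same module under another name).
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.eager import context
# Assumption: K exposes is_tpu_strategy in this TF version.
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.distribute import keras_correctness_test_base
# Both names refer to the Keras SGD optimizer; the excerpts come from files
# that import it under different aliases.
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
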
class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
    TestDistributionStrategyDnnCorrectness):

  def get_model(self,
                cloning,
                initial_weights=None,
                distribution=None,
                input_shapes=None):
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      model = SubclassedModel(initial_weights, input_shapes)

      model.compile(
          loss=keras.losses.mean_squared_error,
          optimizer=gradient_descent_keras.SGD(0.05),
          metrics=['mse'],
          cloning=cloning)
      return model

  @combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations())
  def test_dnn_correctness(self, distribution, use_numpy, use_validation_data,
                           cloning):
    if (context.executing_eagerly() or is_default_strategy(distribution)):
      self.run_correctness_test(distribution, use_numpy, use_validation_data,
                                cloning)
    elif K.is_tpu_strategy(distribution) and not context.executing_eagerly():
      with self.assertRaisesRegexp(
          ValueError,
          'Expected `model` argument to be a functional `Model` instance, '
          'but got a subclass model instead.'):
        self.run_correctness_test(distribution, use_numpy, use_validation_data,
                                  cloning)
    else:
      with self.assertRaisesRegexp(
          ValueError,
          'We currently do not support distribution strategy with a '
          '`Sequential` model that is created without `input_shape`/'
          '`input_dim` set in its first layer or a subclassed model.'):
        self.run_correctness_test(distribution, use_numpy, use_validation_data,
                                  cloning)

  @combinations.generate(all_strategy_combinations_with_graph_mode())
  def test_dnn_with_dynamic_learning_rate(self, distribution, cloning):
    if ((not cloning and context.executing_eagerly() and
         not K.is_tpu_strategy(distribution)) or
        is_default_strategy(distribution)):
      self.run_dynamic_lr_test(distribution, cloning)
    elif K.is_tpu_strategy(distribution):
      with self.assertRaisesRegexp(
          ValueError,
          'Expected `model` argument to be a functional `Model` instance, '
          'but got a subclass model instead.'):
        self.run_dynamic_lr_test(distribution, cloning)
    else:
      with self.assertRaisesRegexp(
          ValueError,
          'We currently do not support distribution strategy with a '
          '`Sequential` model that is created without `input_shape`/'
          '`input_dim` set in its first layer or a subclassed model.'):
        self.run_dynamic_lr_test(distribution, cloning)

  @combinations.generate(
      keras_correctness_test_base.test_combinations_with_tpu_strategies())
  def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
                                                        use_numpy,
                                                        use_validation_data):
    with self.assertRaisesRegexp(
        ValueError,
        'Expected `model` argument to be a functional `Model` instance, '
        'but got a subclass model instead.'):
      self.run_correctness_test(
          distribution,
          use_numpy,
          use_validation_data,
          partial_last_batch='eval')
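

The SubclassedModel instantiated in get_model above is not part of this excerpt. As a hypothetical sketch, a subclassed counterpart of the four-layer Sequential DNN used by TestDistributionStrategyDnnCorrectness further down could look like the following; the original definition may differ.

class SubclassedModel(keras.Model):
  """Hypothetical sketch of the subclassed model referenced above."""

  def __init__(self, initial_weights, input_shapes):
    super(SubclassedModel, self).__init__()
    # Layer sizes mirror the Sequential DNN in TestDistributionStrategyDnnCorrectness.
    self.dense1 = keras.layers.Dense(10, activation='relu', input_shape=(1,))
    self.dense2 = keras.layers.Dense(
        10, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-4))
    self.dense3 = keras.layers.Dense(10, activation='relu')
    self.dense4 = keras.layers.Dense(1)
    if input_shapes:
      self.build(input_shapes)
    else:
      # Assume scalar regression inputs when no shape is given.
      self.build((None, 1))
    if initial_weights:
      self.set_weights(initial_weights)

  def call(self, inputs):
    x = self.dense1(inputs)
    x = self.dense2(x)
    x = self.dense3(x)
    return self.dense4(x)

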
class DistributionStrategyCnnCorrectnessTest(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):

  def get_model(self,
                initial_weights=None,
                distribution=None,
                input_shapes=None):
    del input_shapes
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      image = keras.layers.Input(shape=(28, 28, 3), name='image')
      c1 = keras.layers.Conv2D(
          name='conv1',
          filters=16,
          kernel_size=(3, 3),
          strides=(4, 4),
          kernel_regularizer=keras.regularizers.l2(1e-4))(
              image)
      if self.with_batch_norm == 'regular':
        c1 = keras.layers.BatchNormalization(name='bn1')(c1)
      elif self.with_batch_norm == 'sync':
        c1 = keras.layers.SyncBatchNormalization(name='bn1')(c1)
      c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1)
      logits = keras.layers.Dense(
          10, activation='softmax', name='pred')(
              keras.layers.Flatten()(c1))
      model = keras.Model(inputs=[image], outputs=[logits])

      if initial_weights:
        model.set_weights(initial_weights)

      model.compile(
          optimizer=gradient_descent.SGD(learning_rate=0.1),
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'])

    return model

  def _get_data(self, count, shape=(28, 28, 3), num_classes=10):
    centers = np.random.randn(num_classes, *shape)

    features = []
    labels = []
    for _ in range(count):
      label = np.random.randint(0, num_classes, size=1)[0]
      offset = np.random.normal(loc=0, scale=0.1, size=np.prod(shape))
      offset = offset.reshape(shape)
      labels.append(label)
      features.append(centers[label] + offset)

    x = np.asarray(features, dtype=np.float32)
    y = np.asarray(labels, dtype=np.float32).reshape((count, 1))
    return x, y

  def get_data(self):
    x_train, y_train = self._get_data(
        count=keras_correctness_test_base._GLOBAL_BATCH_SIZE *
        keras_correctness_test_base._EVAL_STEPS)
    x_predict = x_train
    return x_train, y_train, x_predict

  def get_data_with_partial_last_batch_eval(self):
    x_train, y_train = self._get_data(count=1280)
    x_eval, y_eval = self._get_data(count=1000)
    return x_train, y_train, x_eval, y_eval, x_eval

  @ds_combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations())
  def test_cnn_correctness(self, distribution, use_numpy, use_validation_data):
    self.run_correctness_test(distribution, use_numpy, use_validation_data)

  @ds_combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations())
  def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy,
                                           use_validation_data):
    self.run_correctness_test(
        distribution,
        use_numpy,
        use_validation_data,
        with_batch_norm='regular')

  @ds_combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations())
  def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy,
                                                use_validation_data):
    if not context.executing_eagerly():
      self.skipTest('SyncBatchNorm is not enabled in graph mode.')

    self.run_correctness_test(
        distribution,
        use_numpy,
        use_validation_data,
        with_batch_norm='sync')

  @ds_combinations.generate(
      keras_correctness_test_base.test_combinations_with_tpu_strategies() +
      keras_correctness_test_base
      .strategy_minus_tpu_and_input_config_combinations_eager())
  def test_cnn_correctness_with_partial_last_batch_eval(self, distribution,
                                                        use_numpy,
                                                        use_validation_data):
    self.run_correctness_test(
        distribution,
        use_numpy,
        use_validation_data,
        partial_last_batch=True,
        training_epochs=1)

  @ds_combinations.generate(
      keras_correctness_test_base.test_combinations_with_tpu_strategies() +
      keras_correctness_test_base
      .strategy_minus_tpu_and_input_config_combinations_eager())
  def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval(
      self, distribution, use_numpy, use_validation_data):
    self.run_correctness_test(
        distribution,
        use_numpy,
        use_validation_data,
        with_batch_norm='regular',
        partial_last_batch=True)


class TestDistributionStrategyDnnCorrectness(
        keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
    def get_model(self,
                  initial_weights=None,
                  distribution=None,
                  input_shapes=None):
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            # We add a few non-linear layers to make it non-trivial.
            model = keras.Sequential()
            model.add(
                keras.layers.Dense(10, activation='relu', input_shape=(1, )))
            model.add(
                keras.layers.Dense(
                    10,
                    activation='relu',
                    kernel_regularizer=keras.regularizers.l2(1e-4)))
            model.add(keras.layers.Dense(10, activation='relu'))
            model.add(keras.layers.Dense(1))

            if initial_weights:
                model.set_weights(initial_weights)

            model.compile(loss=keras.losses.mean_squared_error,
                          optimizer=gradient_descent_keras.SGD(0.05),
                          metrics=['mse'])
            return model

    def get_data(self):
        x_train = np.random.rand(9984, 1).astype('float32')
        y_train = 3 * x_train
        x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
        return x_train, y_train, x_predict

    def get_data_with_partial_last_batch(self):
        x_train = np.random.rand(10000, 1).astype('float32')
        y_train = 3 * x_train
        x_eval = np.random.rand(10000, 1).astype('float32')
        y_eval = 3 * x_eval
        x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
        return x_train, y_train, x_eval, y_eval, x_predict

    def get_data_with_partial_last_batch_eval(self):
        x_train = np.random.rand(9984, 1).astype('float32')
        y_train = 3 * x_train
        x_eval = np.random.rand(10000, 1).astype('float32')
        y_eval = 3 * x_eval
        x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
        return x_train, y_train, x_eval, y_eval, x_predict

    @ds_combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations(
        ) + keras_correctness_test_base.multi_worker_mirrored_eager())
    def test_dnn_correctness(self, distribution, use_numpy,
                             use_validation_data):
        self.run_correctness_test(distribution, use_numpy, use_validation_data)

    @ds_combinations.generate(
        keras_correctness_test_base.
        test_combinations_with_tpu_strategies_graph() +
        keras_correctness_test_base.multi_worker_mirrored_eager())
    def test_dnn_correctness_with_partial_last_batch_eval(
            self, distribution, use_numpy, use_validation_data):
        self.run_correctness_test(distribution,
                                  use_numpy,
                                  use_validation_data,
                                  partial_last_batch='eval')

    @ds_combinations.generate(
        keras_correctness_test_base.
        strategy_minus_tpu_and_input_config_combinations_eager() +
        keras_correctness_test_base.multi_worker_mirrored_eager())
    def test_dnn_correctness_with_partial_last_batch(self, distribution,
                                                     use_numpy,
                                                     use_validation_data):
        distribution.extended.experimental_enable_get_next_as_optional = True
        self.run_correctness_test(distribution,
                                  use_numpy,
                                  use_validation_data,
                                  partial_last_batch='train_and_eval',
                                  training_epochs=1)

    @ds_combinations.generate(all_strategy_combinations_with_graph_mode())
    def test_dnn_with_dynamic_learning_rate(self, distribution):
        self.run_dynamic_lr_test(distribution)

Example #4
class DistributionStrategyCnnCorrectnessTest(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):

  def get_model(self,
                initial_weights=None,
                distribution=None,
                cloning=None,
                input_shapes=None):
    del input_shapes
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      image = keras.layers.Input(shape=(28, 28, 3), name='image')
      c1 = keras.layers.Conv2D(
          name='conv1', filters=16, kernel_size=(3, 3), strides=(4, 4),
          kernel_regularizer=keras.regularizers.l2(1e-4))(
              image)
      if self.with_batch_norm:
        c1 = keras.layers.BatchNormalization(name='bn1')(c1)
      c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1)
      logits = keras.layers.Dense(
          10, activation='softmax', name='pred')(
              keras.layers.Flatten()(c1))
      model = keras.Model(inputs=[image], outputs=[logits])

      if initial_weights:
        model.set_weights(initial_weights)

      model.compile(
          optimizer=gradient_descent.SGD(learning_rate=0.1),
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'],
          cloning=cloning)

    return model

  def get_data(self,
               count=keras_correctness_test_base._GLOBAL_BATCH_SIZE
               * keras_correctness_test_base._EVAL_STEPS,
               shape=(28, 28, 3),
               num_classes=10):
    centers = np.random.randn(num_classes, *shape)

    features = []
    labels = []
    for _ in range(count):
      label = np.random.randint(0, num_classes, size=1)[0]
      offset = np.random.normal(loc=0, scale=0.1, size=np.prod(shape))
      offset = offset.reshape(shape)
      labels.append(label)
      features.append(centers[label] + offset)

    x_train = np.asarray(features, dtype=np.float32)
    y_train = np.asarray(labels, dtype=np.float32).reshape((count, 1))
    x_predict = x_train
    return x_train, y_train, x_predict

  @combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations())
  def test_cnn_correctness(self, distribution, use_numpy, use_validation_data,
                           cloning):
    self.run_correctness_test(distribution, use_numpy, use_validation_data,
                              cloning)

  @combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations())
  def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy,
                                           use_validation_data, cloning):
    self.run_correctness_test(distribution, use_numpy, use_validation_data,
                              with_batch_norm=True, cloning=cloning)
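

The original test files close with a standard TensorFlow test entry point. A minimal sketch, assuming the tensorflow.python.platform test runner (newer versions of these files use multi_process_runner.test_main() instead):

from tensorflow.python.platform import test

if __name__ == '__main__':
  test.main()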