class TestDistributionStrategyValidation(test.TestCase, parameterized.TestCase):
  """Verifies that layers/models must be created inside the strategy scope."""

  @combinations.generate(
      combinations.times(
          keras_test_lib.all_strategy_combinations_minus_default(),
          combinations.combine(cloning=[True, False])))
  def test_layer_outside_scope(self, distribution, cloning):
    """Layers created outside `distribution.scope()` must fail at compile."""
    with self.cached_session():
      with self.assertRaisesRegexp(
          ValueError, 'was not created in the distribution strategy'):
        # The layers (and hence their variables) are deliberately created
        # OUTSIDE the strategy scope.
        x = keras.layers.Input(shape=(3,), name='input')
        y = keras.layers.Dense(4, name='dense')(x)
        with distribution.scope():
          # Compiling inside the scope should raise, because the layer
          # variables above were not created under this strategy.
          model = keras.Model(x, y)
          optimizer = gradient_descent.GradientDescentOptimizer(0.001)
          loss = 'mse'
          metrics = ['mae', keras.metrics.CategoricalAccuracy()]
          model.compile(optimizer, loss, metrics=metrics, cloning=cloning)

  @combinations.generate(
      combinations.times(
          keras_test_lib.all_strategy_combinations_minus_default(),
          combinations.combine(cloning=[True, False])))
  def test_model_outside_scope(self, distribution, cloning):
    """Models created outside `distribution.scope()` must fail at compile."""
    with self.cached_session():
      with self.assertRaisesRegexp(
          ValueError, 'was not created in the distribution strategy'):
        # The whole model is created OUTSIDE the strategy scope.
        x = keras.layers.Input(shape=(3,), name='input')
        y = keras.layers.Dense(4, name='dense')(x)
        model = keras.Model(x, y)
        with distribution.scope():
          # Compiling inside the scope should raise for the same reason.
          optimizer = gradient_descent.GradientDescentOptimizer(0.001)
          loss = 'mse'
          metrics = ['mae', keras.metrics.CategoricalAccuracy()]
          model.compile(optimizer, loss, metrics=metrics, cloning=cloning)
class SavedModelTFModuleTest(test_base.TestSavedModelBase):
    """Save/load tests for plain tf.Module models under tf.distribute."""

    def setUp(self):
        self._root_dir = 'saved_model_save_load'
        super(SavedModelTFModuleTest, self).setUp()

    def _train_model(self, model, x_train, y_train, batch_size):
        """tf.Module models in this suite are not trained."""
        pass

    def _predict_with_model(self, distribution, model, predict_dataset):
        """Runs one predict batch, distributing it when a strategy is given."""
        if not distribution:
            return model(next(iter(predict_dataset)))
        dist_dataset = distribution.experimental_distribute_dataset(
            predict_dataset)
        replica_batch = next(iter(dist_dataset))
        per_replica_result = distribution.experimental_run_v2(
            model, args=(replica_batch,))
        # Convert the per-replica value to a list, then concatenate them.
        local_results = distribution.experimental_local_results(
            per_replica_result)
        return array_ops.concat(local_results, 0)

    def _save_model(self, model, saved_dir):
        """Exports the module with an explicit serving signature."""
        signature = model.__call__.get_concrete_function(
            tensor_spec.TensorSpec(None))
        saved_model.save(model, saved_dir, signatures=signature)

    def _load_and_run_model(self,
                            distribution,
                            saved_dir,
                            predict_dataset,
                            output_name='output_1'):
        """Restores the SavedModel and runs prediction on it."""
        del output_name  # tf.Module outputs are not named.
        restored = saved_model.load(saved_dir)
        return self._predict_with_model(distribution, restored,
                                        predict_dataset)

    @combinations.generate(test_base.tfmodule_models_with_strategies())
    def test_save_no_strategy_restore_strategy(self, model_and_input,
                                               distribution):
        """Save without a strategy; restore and predict under one."""
        self.run_test_save_no_strategy_restore_strategy(
            model_and_input, distribution)

    @combinations.generate(
        combinations.times(test_base.tfmodule_models_with_strategies(),
                           combinations.combine(save_in_scope=[True, False])))
    def test_save_strategy_restore_no_strategy(self, model_and_input,
                                               distribution, save_in_scope):
        """Save under a strategy; restore and predict without one."""
        self.run_test_save_strategy_restore_no_strategy(
            model_and_input, distribution, save_in_scope)

    @combinations.generate(
        combinations.times(test_base.tfmodule_models_with_strategy_pairs(),
                           combinations.combine(save_in_scope=[True, False])))
    def test_save_strategy_restore_strategy(self, model_and_input,
                                            distribution_for_saving,
                                            distribution_for_restoring,
                                            save_in_scope):
        """Save under one strategy; restore and predict under another."""
        self.run_test_save_strategy_restore_strategy(
            model_and_input, distribution_for_saving,
            distribution_for_restoring, save_in_scope)
class SavedModelKerasModelTest(test_base.TestSavedModelBase):
  """Save/load tests for Keras models via the tf.saved_model API."""

  def setUp(self):
    self._root_dir = 'saved_model_save_load'
    super(SavedModelKerasModelTest, self).setUp()

  def _save_model(self, model, saved_dir):
    """Saves the Keras model with the generic SavedModel saver."""
    saved_model.save(model, saved_dir)

  def _load_and_run_model(self,
                          distribution,
                          saved_dir,
                          predict_dataset,
                          output_name='output_1'):
    """Loads the SavedModel and predicts through the shared helper."""
    return test_base.load_and_run_with_saved_model_api(
        distribution, saved_dir, predict_dataset, output_name)

  @combinations.generate(test_base.simple_models_with_strategies())
  def test_save_no_strategy_restore_strategy(self, model_and_input,
                                             distribution):
    """Save without a strategy; restore and predict under one."""
    self.run_test_save_no_strategy_restore_strategy(model_and_input,
                                                    distribution)

  @combinations.generate(
      combinations.times(test_base.simple_models_with_strategies(),
                         combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_no_strategy(self, model_and_input,
                                             distribution, save_in_scope):
    """Save under a strategy; restore and predict without one."""
    self.run_test_save_strategy_restore_no_strategy(model_and_input,
                                                    distribution,
                                                    save_in_scope)

  @combinations.generate(
      combinations.times(test_base.simple_models_with_strategy_pairs(),
                         combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_strategy(self, model_and_input,
                                          distribution_for_saving,
                                          distribution_for_restoring,
                                          save_in_scope):
    """Save under one strategy; restore and predict under another."""
    self.run_test_save_strategy_restore_strategy(
        model_and_input, distribution_for_saving, distribution_for_restoring,
        save_in_scope)

  @combinations.generate(
      combinations.times(test_base.simple_models_with_strategies(),
                         combinations.combine(save_in_scope=[True, False])))
  def test_no_variable_device_placement(self, model_and_input, distribution,
                                        save_in_scope):
    """Restored variable reads must not carry a hard-coded device."""
    saved_dir = self.run_test_save_strategy(model_and_input, distribution,
                                            save_in_scope)
    loaded = saved_model.load(saved_dir)
    serving_fn = loaded.signatures[test_base._DEFAULT_FUNCTION_KEY]
    graph_def = serving_fn.graph.as_graph_def()
    for function in graph_def.library.function:
      for node in function.node_def:
        if node.op == 'ReadVariableOp':
          self.assertEmpty(node.device)
# Example #4 (snippet boundary from the original scrape)
class TestDistributionStrategySaveLoadWeights(test.TestCase,
                                              parameterized.TestCase):
    """Verifies weights saved under a strategy scope can be restored and used."""

    @combinations.generate(
        combinations.times(
            keras_test_lib.all_strategy_combinations_minus_default(),
            combinations.combine(
                cloning=[True, False],
                optimizer=strategy_combinations.rmsprop_optimizer_keras_v2_fn))
    )
    def test_save_load_h5(self, distribution, optimizer, cloning):
        """Round-trips model weights through the HDF5 format."""
        import os  # Local import; this chunk does not show module imports.
        with self.cached_session():
            dataset = keras_test_lib.get_dataset(distribution)
            with distribution.scope():
                model = keras_test_lib.get_model()
                model.compile(optimizer(), 'mse', cloning=cloning)
                model.fit(dataset, epochs=1, steps_per_epoch=1)

                # tempfile.mktemp() is deprecated and race-prone; use a path
                # inside the test's managed temp directory instead.
                weights_file = os.path.join(self.get_temp_dir(), 'weights.h5')
                model.save_weights(weights_file)

                model_2 = keras_test_lib.get_model()
                model_2.compile(optimizer(), 'mse', cloning=cloning)
                model_2.load_weights(weights_file)
                # The restored model must both predict and keep training.
                model_2.predict(
                    keras_test_lib.get_predict_dataset(distribution), steps=2)
                model_2.fit(dataset, epochs=1, steps_per_epoch=1)

    @combinations.generate(
        combinations.times(
            keras_test_lib.all_strategy_combinations_minus_default(),
            combinations.combine(
                cloning=[True, False],
                optimizer=strategy_combinations.rmsprop_optimizer_keras_v2_fn))
    )
    def test_save_load_trackable(self, distribution, optimizer, cloning):
        """Round-trips model weights through the TF checkpoint format."""
        import os  # Local import; this chunk does not show module imports.
        # TODO(b/123533246): Enable the test for TPU once bug is fixed
        if (isinstance(distribution,
                       (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1))
                and distribution.extended.steps_per_run > 1):
            self.skipTest(
                'MultiStep TPU Strategy deadlocks with optimizer restore.')
        with self.cached_session():
            dataset = keras_test_lib.get_dataset(distribution)
            with distribution.scope():
                model = keras_test_lib.get_model()
                model.compile(optimizer(), 'mse', cloning=cloning)
                model.fit(dataset, epochs=1, steps_per_epoch=1)

                # See test_save_load_h5: avoid deprecated tempfile.mktemp().
                weights_file = os.path.join(self.get_temp_dir(), 'weights')
                model.save_weights(weights_file)

                model_2 = keras_test_lib.get_model()
                model_2.compile(optimizer(), 'mse', cloning=cloning)
                model_2.load_weights(weights_file)
                model_2.predict(
                    keras_test_lib.get_predict_dataset(distribution), steps=2)
                model_2.fit(dataset, epochs=1, steps_per_epoch=1)
# Example #5 (snippet boundary from the original scrape)
class SavedModelKerasModelTest(test_base.TestSavedModelBase):
  """Keras-model save/load tests, parameterized on tf.function execution."""

  def setUp(self):
    self._root_dir = 'saved_model_save_load'
    super(SavedModelKerasModelTest, self).setUp()

  def _save_model(self, model, saved_dir):
    """Saves the Keras model with the generic SavedModel saver."""
    saved_model.save(model, saved_dir)

  def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
                          output_name, experimental_run_tf_function):
    """Loads the SavedModel and predicts through the shared helper."""
    return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
                                                       predict_dataset,
                                                       output_name)

  @combinations.generate(test_base.simple_models_with_strategies())
  def test_save_no_strategy_restore_strategy(self, model_and_input,
                                             distribution,
                                             experimental_run_tf_function):
    """Save without a strategy; restore and predict under one."""
    self.run_test_save_no_strategy_restore_strategy(
        model_and_input, distribution, experimental_run_tf_function)

  @combinations.generate(
      combinations.times(test_base.simple_models_with_strategies(),
                         combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_no_strategy(self, model_and_input,
                                             distribution, save_in_scope,
                                             experimental_run_tf_function):
    """Save under a strategy; restore and predict without one."""
    if save_in_scope:
      # TODO(b/134703272): Unskip this test when fixed.
      # Bug fix: skipTest was given a tuple (stray comma between the string
      # fragments); it expects a single reason string.
      self.skipTest('Saving model within tf.distribute.Strategy scope is not '
                    'supported.')
    self.run_test_save_strategy_restore_no_strategy(
        model_and_input, distribution, save_in_scope,
        experimental_run_tf_function)

  @combinations.generate(
      combinations.times(test_base.simple_models_with_strategy_pairs(),
                         combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_strategy(self, model_and_input,
                                          distribution_for_saving,
                                          distribution_for_restoring,
                                          save_in_scope,
                                          experimental_run_tf_function):
    """Save under one strategy; restore and predict under another."""
    if save_in_scope:
      # TODO(b/134703272): Unskip this test when fixed.
      self.skipTest('Saving model within tf.distribute.Strategy scope is not '
                    'supported.')
    self.run_test_save_strategy_restore_strategy(model_and_input,
                                                 distribution_for_saving,
                                                 distribution_for_restoring,
                                                 save_in_scope,
                                                 experimental_run_tf_function)
# Example #6 (snippet boundary from the original scrape)
class KerasSaveLoadTest(test_base.TestSavedModelBase):
  """Save/load tests via model.save(save_format='tf') and keras load_model."""

  def setUp(self):
    self._root_dir = 'keras_save_load'
    super(KerasSaveLoadTest, self).setUp()

  def _save_model(self, model, saved_dir):
    """Saves with the Keras-native SavedModel exporter."""
    model.save(saved_dir, save_format='tf')

  def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
                          output_name, experimental_run_tf_function):
    """Restores the Keras model, sets the execution mode, then predicts."""
    restored_keras_model = save.load_model(saved_dir)
    restored_keras_model._experimental_run_tf_function = (
        experimental_run_tf_function)
    return restored_keras_model.predict(
        predict_dataset, steps=test_base.PREDICT_STEPS)

  @combinations.generate(test_base.simple_models_with_strategies())
  def test_save_no_strategy_restore_strategy(self, model_and_input,
                                             distribution,
                                             experimental_run_tf_function):
    """Save without a strategy; restore and predict under one."""
    self.run_test_save_no_strategy_restore_strategy(
        model_and_input, distribution, experimental_run_tf_function)

  @combinations.generate(
      combinations.times(test_base.simple_models_with_strategies(),
                         combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_no_strategy(self, model_and_input,
                                             distribution, save_in_scope,
                                             experimental_run_tf_function):
    """Save under a strategy; restore and predict without one."""
    if save_in_scope:
      # Bug fix: skipTest was given a tuple (stray comma between the string
      # fragments); it expects a single reason string.
      self.skipTest('b/134703272 - Saving model in tf.distribute.Strategy '
                    'scope is not supported.')
    self.run_test_save_strategy_restore_no_strategy(
        model_and_input, distribution, save_in_scope,
        experimental_run_tf_function)

  @combinations.generate(
      combinations.times(test_base.simple_models_with_strategy_pairs(),
                         combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_strategy(self, model_and_input,
                                          distribution_for_saving,
                                          distribution_for_restoring,
                                          save_in_scope,
                                          experimental_run_tf_function):
    """Save under one strategy; restore and predict under another."""
    if save_in_scope:
      self.skipTest('b/134703272 - Saving model in tf.distribute.Strategy '
                    'scope is not supported.')
    self.run_test_save_strategy_restore_strategy(model_and_input,
                                                 distribution_for_saving,
                                                 distribution_for_restoring,
                                                 save_in_scope,
                                                 experimental_run_tf_function)
# Example #7 (snippet boundary from the original scrape)
def test_combinations_for_embedding_model():
    """Builds the test combinations used for embedding models."""
    # TODO(sourabhbajaj): Enable tests for eager mode
    non_tpu_strategies = [
        strategy for strategy in strategies_for_embedding_models()
        if not strategy.required_tpu
    ]

    graph_cases = combinations.times(
        combinations.combine(
            distribution=strategies_for_embedding_models(),
            experimental_run_tf_function=[True, False]),
        graph_mode_test_configuration())
    eager_cases = combinations.times(
        combinations.combine(
            distribution=non_tpu_strategies,
            experimental_run_tf_function=[False]),
        eager_mode_test_configuration())
    return graph_cases + eager_cases
def strategy_and_input_combinations():
    """Combinations of strategies with numpy/validation-data options."""
    non_tpu_graph = combinations.combine(mode=['graph'],
                                         use_numpy=[True, False],
                                         use_validation_data=[True, False])
    non_tpu_eager = combinations.combine(mode=['eager'],
                                         use_numpy=[False],
                                         use_validation_data=[False])
    non_tpu_cases = combinations.times(
        combinations.combine(distribution=strategies_minus_tpu),
        non_tpu_graph + non_tpu_eager)
    # TPU strategies are only exercised in graph mode.
    tpu_cases = combinations.times(
        combinations.combine(distribution=tpu_strategies),
        combinations.combine(mode=['graph'],
                             use_numpy=[True, False],
                             use_validation_data=[True, False]))
    return non_tpu_cases + tpu_cases
def test_combinations_for_embedding_model():
  """Graph-mode cases for all embedding strategies, eager for non-TPU ones."""
  # TODO(sourabhbajaj): Enable tests for eager mode
  non_tpu_strategies = [
      strategy for strategy in strategies_for_embedding_models()
      if not strategy.required_tpu
  ]
  graph_cases = combinations.times(
      combinations.combine(distribution=strategies_for_embedding_models()),
      graph_mode_test_configuration())
  eager_cases = combinations.times(
      combinations.combine(distribution=non_tpu_strategies),
      eager_mode_test_configuration())
  return graph_cases + eager_cases
class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
  """Save/load tests using the serving-only Keras SavedModel exporter."""

  def setUp(self):
    self._root_dir = 'saved_model_save_load'
    super(SavedModelSaveAndLoadTest, self).setUp()

  def _save_model(self, model, saved_dir):
    """Exports a serving-only SavedModel."""
    keras_saved_model.export_saved_model(model, saved_dir, serving_only=True)

  def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
                          output_name, run_distributed):
    """Loads the SavedModel and predicts through the shared helper."""
    return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
                                                       predict_dataset,
                                                       output_name)

  @combinations.generate(test_base.simple_models_with_strategies())
  def test_save_no_strategy_restore_strategy(self, model_and_input,
                                             distribution, run_distributed):
    """Save without a strategy; restore and predict under one."""
    self.run_test_save_no_strategy_restore_strategy(model_and_input,
                                                    distribution,
                                                    run_distributed)

  @combinations.generate(
      combinations.times(test_base.simple_models_with_strategies(),
                         combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_no_strategy(self, model_and_input,
                                             distribution, save_in_scope,
                                             run_distributed):
    """Save under a strategy; restore and predict without one."""
    if save_in_scope:
      # Bug fix: skipTest was given a tuple (stray comma between the string
      # fragments); it expects a single reason string.
      self.skipTest('Saving model within tf.distribute.Strategy scope is not '
                    'supported.')
    self.run_test_save_strategy_restore_no_strategy(model_and_input,
                                                    distribution, save_in_scope,
                                                    run_distributed)

  @combinations.generate(
      combinations.times(test_base.simple_models_with_strategy_pairs(),
                         combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_strategy(self, model_and_input,
                                          distribution_for_saving,
                                          distribution_for_restoring,
                                          save_in_scope, run_distributed):
    """Save under one strategy; restore and predict under another."""
    if save_in_scope:
      self.skipTest('Saving model within tf.distribute.Strategy scope is not '
                    'supported.')
    self.run_test_save_strategy_restore_strategy(model_and_input,
                                                 distribution_for_saving,
                                                 distribution_for_restoring,
                                                 save_in_scope, run_distributed)
class SavedModelKerasModelTest(test_base.TestSavedModelBase):
  """Keras-model save/load tests driven through the tf.saved_model API."""

  def setUp(self):
    self._root_dir = 'saved_model_save_load'
    super(SavedModelKerasModelTest, self).setUp()

  def _save_model(self, model, saved_dir):
    """Persists the model with the generic SavedModel saver."""
    saved_model.save(model, saved_dir)

  def _load_and_run_model(self,
                          distribution,
                          saved_dir,
                          predict_dataset,
                          experimental_run_tf_function,
                          output_name='output_1'):
    """Loads the SavedModel and predicts via the shared test helper."""
    return test_base.load_and_run_with_saved_model_api(
        distribution, saved_dir, predict_dataset, output_name)

  @combinations.generate(test_base.simple_models_with_strategies())
  def test_save_no_strategy_restore_strategy(self, model_and_input,
                                             distribution,
                                             experimental_run_tf_function):
    """Save without a strategy; restore and predict under one."""
    self.run_test_save_no_strategy_restore_strategy(model_and_input,
                                                    distribution,
                                                    experimental_run_tf_function)

  @combinations.generate(
      combinations.times(test_base.simple_models_with_strategies(),
                         combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_no_strategy(self, model_and_input,
                                             distribution, save_in_scope,
                                             experimental_run_tf_function):
    """Save under a strategy; restore and predict without one."""
    self.run_test_save_strategy_restore_no_strategy(model_and_input,
                                                    distribution,
                                                    save_in_scope,
                                                    experimental_run_tf_function)

  @combinations.generate(
      combinations.times(test_base.simple_models_with_strategy_pairs(),
                         combinations.combine(save_in_scope=[True, False])))
  def test_save_strategy_restore_strategy(self, model_and_input,
                                          distribution_for_saving,
                                          distribution_for_restoring,
                                          save_in_scope,
                                          experimental_run_tf_function):
    """Save under one strategy; restore and predict under another."""
    self.run_test_save_strategy_restore_strategy(
        model_and_input, distribution_for_saving, distribution_for_restoring,
        save_in_scope, experimental_run_tf_function)
def strategy_and_input_combinations():
  """All (strategy, input-mode) combinations used by these tests."""
  non_tpu_graph = combinations.combine(mode=['graph'],
                                       use_numpy=[True, False],
                                       use_validation_data=[True, False])
  non_tpu_eager = combinations.combine(mode=['eager'],
                                       use_numpy=[False],
                                       use_validation_data=[False])
  non_tpu_cases = combinations.times(
      combinations.combine(distribution=strategies_minus_tpu),
      non_tpu_graph + non_tpu_eager)
  # TPU strategies are only exercised in graph mode.
  tpu_cases = combinations.times(
      combinations.combine(distribution=tpu_strategies),
      combinations.combine(mode=['graph'],
                           use_numpy=[True, False],
                           use_validation_data=[True, False]))
  return non_tpu_cases + tpu_cases
def test_combinations_for_embedding_model():
  """Embedding-model strategies crossed with graph and eager configurations."""
  modes = graph_mode_test_configuration() + eager_mode_test_configuration()
  return combinations.times(
      combinations.combine(distribution=strategies_for_embedding_models()),
      modes)
def test_combinations_for_embedding_model():
  """Crosses embedding-model strategies with both test-mode configurations."""
  distributions = combinations.combine(
      distribution=strategies_for_embedding_models())
  return combinations.times(
      distributions,
      graph_mode_test_configuration() + eager_mode_test_configuration())
 def test_times_variable_arguments(self):
   """times() over several factor lists yields the full cross product."""
   c1 = combinations.combine(mode=["graph", "eager"])
   c2 = combinations.combine(optimizer=["adam", "gd"])
   c3 = combinations.combine(distribution=["d1", "d2"])
   c4 = combinations.times(c3, c1, c2)
   # The first factor passed to times() varies slowest in the product.
   self.assertEqual([
       OrderedDict([("distribution", "d1"), ("mode", "graph"),
                    ("optimizer", "adam")]),
       OrderedDict([("distribution", "d1"), ("mode", "graph"),
                    ("optimizer", "gd")]),
       OrderedDict([("distribution", "d1"), ("mode", "eager"),
                    ("optimizer", "adam")]),
       OrderedDict([("distribution", "d1"), ("mode", "eager"),
                    ("optimizer", "gd")]),
       OrderedDict([("distribution", "d2"), ("mode", "graph"),
                    ("optimizer", "adam")]),
       OrderedDict([("distribution", "d2"), ("mode", "graph"),
                    ("optimizer", "gd")]),
       OrderedDict([("distribution", "d2"), ("mode", "eager"),
                    ("optimizer", "adam")]),
       OrderedDict([("distribution", "d2"), ("mode", "eager"),
                    ("optimizer", "gd")])
   ], c4)
   # combine() with several keyword lists is equivalent to times() of
   # single-keyword combines.
   self.assertEqual(
       combinations.combine(
           mode=["graph", "eager"],
           optimizer=["adam", "gd"],
           distribution=["d1", "d2"]), c4)
# Example #16 (snippet boundary from the original scrape)
class SingleLossStepTest(test.TestCase, parameterized.TestCase):
  """Trains a single-loss step under several strategies and checks convergence."""

  @combinations.generate(
      combinations.times(
          strategy_combinations.distributions_and_v1_optimizers(),
          combinations.combine(
              mode=strategy_combinations.graph_and_eager_modes),
          combinations.combine(is_tpu=[False])) + combinations.combine(
              distribution=[strategy_combinations.tpu_strategy],
              optimizer_fn=strategy_combinations.optimizers_v1,
              mode=["graph"],
              is_tpu=[True]))
  def testTrainNetwork(self, distribution, optimizer_fn, is_tpu):
    """Training error should be non-increasing across steps."""
    with distribution.scope():
      single_loss_step, layer = single_loss_example(
          optimizer_fn, distribution, use_bias=True, iterations_per_step=2)

      if context.executing_eagerly():
        # Eager mode: the step object is directly callable.
        single_loss_step.initialize()
        run_step = single_loss_step
      else:
        # Graph mode: build the step ops once and reuse them via a
        # session callable.
        with self.cached_session() as sess:
          sess.run(single_loss_step.initialize())
          run_step = sess.make_callable(single_loss_step())
      self.evaluate(variables.global_variables_initializer())

      weights, biases = [], []
      for _ in range(5):
        run_step()
        weights.append(self.evaluate(layer.kernel))
        biases.append(self.evaluate(layer.bias))

      # |kernel + bias - 1| is the per-step error; it must never grow
      # from one step to the next.
      error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
      is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
      self.assertTrue(is_not_increasing)
class CommunicationHintTest(test.TestCase, parameterized.TestCase):
    """Tests for the communication_hint argument of collective ops."""

    def setUp(self):
        _setup_context()
        super().setUp()

    @combinations.generate(
        combinations.times(collective_op_combinations,
                           combinations.combine(required_gpus=[0, 1])))
    def testNCCLFallbackOnCPU(self, collective_op):
        # communication_hint=NCCL should work for CPU by falling back to RING. The
        # test doesn't actually require GPU, only GPU builds. We specify
        # required_gpus=1 so that it's tested with GPU builds.
        devices = ['/device:CPU:0', '/device:CPU:1']
        tensor = constant_op.constant([1., 2., 3., 4.])

        @def_function.function
        def run_collectives():
            for device_name in devices:
                with ops.device(device_name):
                    collective_op(tensor,
                                  group_size=2,
                                  group_key=20,
                                  instance_key=30,
                                  communication_hint='NCCL')

        run_collectives()
# Example #18 (snippet boundary from the original scrape)
 def test_times_variable_arguments(self):
     """times() over several factor lists yields the full cross product."""
     c1 = combinations.combine(mode=["graph", "eager"])
     c2 = combinations.combine(optimizer=["adam", "gd"])
     c3 = combinations.combine(distribution=["d1", "d2"])
     c4 = combinations.times(c3, c1, c2)
     # The first factor passed to times() varies slowest in the product.
     self.assertEqual([
         OrderedDict([("distribution", "d1"), ("mode", "graph"),
                      ("optimizer", "adam")]),
         OrderedDict([("distribution", "d1"), ("mode", "graph"),
                      ("optimizer", "gd")]),
         OrderedDict([("distribution", "d1"), ("mode", "eager"),
                      ("optimizer", "adam")]),
         OrderedDict([("distribution", "d1"), ("mode", "eager"),
                      ("optimizer", "gd")]),
         OrderedDict([("distribution", "d2"), ("mode", "graph"),
                      ("optimizer", "adam")]),
         OrderedDict([("distribution", "d2"), ("mode", "graph"),
                      ("optimizer", "gd")]),
         OrderedDict([("distribution", "d2"), ("mode", "eager"),
                      ("optimizer", "adam")]),
         OrderedDict([("distribution", "d2"), ("mode", "eager"),
                      ("optimizer", "gd")])
     ], c4)
     # combine() with several keyword lists is equivalent to times() of
     # single-keyword combines.
     self.assertEqual(
         combinations.combine(mode=["graph", "eager"],
                              optimizer=["adam", "gd"],
                              distribution=["d1", "d2"]), c4)
class DistributionStrategyStatefulLstmModelCorrectnessTest(
        keras_correctness_test_base
        .TestDistributionStrategyEmbeddingModelCorrectnessBase):
    """Correctness tests for a stateful-LSTM embedding classifier."""

    def get_model(self,
                  max_words=10,
                  initial_weights=None,
                  distribution=None,
                  input_shapes=None):
        """Builds and compiles the stateful-LSTM model used by the tests."""
        del input_shapes
        batch_size = keras_correctness_test_base._GLOBAL_BATCH_SIZE

        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            # Stateful LSTMs need a fixed batch size on the input layer.
            token_ids = keras.layers.Input(shape=(max_words, ),
                                           batch_size=batch_size,
                                           dtype=np.int32,
                                           name='words')
            embedded = keras.layers.Embedding(input_dim=20,
                                              output_dim=10)(token_ids)
            lstm_out = keras.layers.LSTM(units=4,
                                         return_sequences=False,
                                         stateful=True)(embedded)

            preds = keras.layers.Dense(2, activation='softmax')(lstm_out)
            model = keras.Model(inputs=[token_ids], outputs=[preds])

            if initial_weights:
                model.set_weights(initial_weights)

            model.compile(
                optimizer=gradient_descent_keras.SGD(learning_rate=0.1),
                loss='sparse_categorical_crossentropy',
                metrics=['sparse_categorical_accuracy'])
        return model

    # TODO(jhseu): Disabled to fix b/130808953. Need to investigate why it
    # doesn't work and enable for DistributionStrategy more generally.
    @combinations.generate(test_combinations_for_stateful_embedding_model())
    def disabled_test_stateful_lstm_model_correctness(
            self, distribution, use_numpy, use_validation_data):
        self.run_correctness_test(distribution,
                                  use_numpy,
                                  use_validation_data,
                                  is_stateful_model=True)

    @combinations.generate(
        combinations.times(keras_correctness_test_base.
                           test_combinations_with_tpu_strategies()))
    def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
            self, distribution, use_numpy, use_validation_data):
        # Stateful RNNs are expected to be rejected under tf.distribute.
        with self.assertRaisesRegexp(
                ValueError, 'RNNs with stateful=True not yet supported with '
                'tf.distribute.Strategy.'):
            self.run_correctness_test(distribution,
                                      use_numpy,
                                      use_validation_data,
                                      is_stateful_model=True)
class KerasExperimentalSaveLoadTest(test_base.TestSavedModelBase):
    """Save/load round-trip tests using the experimental Keras SavedModel APIs."""

    def setUp(self):
        self._root_dir = 'keras_experimental_save_load'
        super(KerasExperimentalSaveLoadTest, self).setUp()

    def _save_model(self, model, saved_dir):
        """Exports `model` via the experimental Keras SavedModel writer."""
        saved_model.export_saved_model(model, saved_dir)

    def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
                            output_name, experimental_run_tf_function):
        """Restores the model, forces the requested execution mode, predicts."""
        restored = saved_model.load_from_saved_model(saved_dir)
        restored._experimental_run_tf_function = experimental_run_tf_function
        return restored.predict(
            predict_dataset, steps=test_base.PREDICT_STEPS)

    @combinations.generate(test_base.simple_models_with_strategies())
    def test_save_no_strategy_restore_strategy(self, model_and_input,
                                               distribution,
                                               experimental_run_tf_function):
        self.run_test_save_no_strategy_restore_strategy(
            model_and_input, distribution, experimental_run_tf_function)

    @combinations.generate(
        combinations.times(test_base.simple_models_with_strategies(),
                           combinations.combine(save_in_scope=[True, False])))
    def test_save_strategy_restore_no_strategy(self, model_and_input,
                                               distribution, save_in_scope,
                                               experimental_run_tf_function):
        self.run_test_save_strategy_restore_no_strategy(
            model_and_input, distribution, save_in_scope,
            experimental_run_tf_function)

    @combinations.generate(
        combinations.times(test_base.simple_models_with_strategy_pairs(),
                           combinations.combine(save_in_scope=[True, False])))
    def test_save_strategy_restore_strategy(self, model_and_input,
                                            distribution_for_saving,
                                            distribution_for_restoring,
                                            save_in_scope,
                                            experimental_run_tf_function):
        self.run_test_save_strategy_restore_strategy(
            model_and_input, distribution_for_saving,
            distribution_for_restoring, save_in_scope,
            experimental_run_tf_function)
def test_combinations_with_tpu_strategies():
  """Returns TPU strategy combinations crossed with graph-mode configs."""
  return combinations.times(
      combinations.combine(distribution=[
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_one_step,
      ]),
      graph_mode_test_configuration())
def test_combinations_with_tpu_strategies():
  """TPU strategy combinations crossed with graph-mode test configurations.

  NOTE(review): this redefines the identically-named function defined just
  above; at import time this later definition is the one that takes effect.
  """
  tpu = [
      strategy_combinations.tpu_strategy,
      strategy_combinations.tpu_strategy_one_step,
  ]
  return combinations.times(combinations.combine(distribution=tpu),
                            graph_mode_test_configuration())
def strategy_and_optimizer_combinations():
  """Crosses every strategy combination with v1 and Keras-v2 optimizers."""
  optimizers = [
      strategy_combinations.adagrad_optimizer_v1_fn,
      strategy_combinations.adagrad_optimizer_keras_v2_fn,
      strategy_combinations.adam_optimizer_v1_fn,
      strategy_combinations.adam_optimizer_keras_v2_fn,
      strategy_combinations.gradient_descent_optimizer_v1_fn,
      strategy_combinations.gradient_descent_optimizer_keras_v2_fn,
      strategy_combinations.rmsprop_optimizer_v1_fn,
      strategy_combinations.rmsprop_optimizer_keras_v2_fn,
  ]
  return combinations.times(all_strategy_combinations(),
                            combinations.combine(optimizer=optimizers))
def strategy_and_optimizer_combinations():
    """All strategies crossed with the v1/Keras-v2 optimizer pairs.

    NOTE(review): duplicate of the identically-named function defined above;
    the later definition wins at import time.
    """
    sc = strategy_combinations
    return combinations.times(
        all_strategy_combinations(),
        combinations.combine(optimizer=[
            sc.adagrad_optimizer_v1_fn,
            sc.adagrad_optimizer_keras_v2_fn,
            sc.adam_optimizer_v1_fn,
            sc.adam_optimizer_keras_v2_fn,
            sc.gradient_descent_optimizer_v1_fn,
            sc.gradient_descent_optimizer_keras_v2_fn,
            sc.rmsprop_optimizer_v1_fn,
            sc.rmsprop_optimizer_keras_v2_fn,
        ]))
class MonitorTest(test.TestCase, parameterized.TestCase):
  """Tests for `monitor_lib.Monitor` driving distributed training steps."""

  @combinations.generate(
      combinations.times(
          strategy_combinations.distributions_and_v1_optimizers(),
          combinations.combine(
              mode=strategy_combinations.graph_and_eager_modes)))
  def testTrainNetwork(self, distribution, optimizer_fn):
    """Runs steps via Monitor and checks the weight error does not grow."""
    with distribution.scope():
      single_loss_step, layer = single_loss_example(optimizer_fn, distribution)

      # Monitor takes a session in graph mode and None when executing eagerly.
      if context.executing_eagerly():
        monitor = monitor_lib.Monitor(single_loss_step, None)
      else:
        with self.cached_session() as sess:
          monitor = monitor_lib.Monitor(single_loss_step, sess)

      monitor.run_steps(1)

      self.assertEqual(1, len(layer.trainable_variables))
      mirrored_weight_variable = layer.trainable_variables[0]
      start_error = self.evaluate(mirrored_weight_variable)
      # Error is the distance of the weight from 1.
      start_error = abs(numpy.array(start_error) - 1)

      monitor.run_steps(9)
      end_error = self.evaluate(mirrored_weight_variable)
      end_error = abs(numpy.array(end_error) - 1)
      # After 9 more steps the weight should be no farther from 1 than before.
      self.assertGreaterEqual(start_error, end_error)

  def testPassingASessionInEager(self):
    """Monitor must reject an explicit session while executing eagerly."""
    distribution = one_device_strategy.OneDeviceStrategy(
        "/device:CPU:0")
    step_function, _ = single_loss_example(
        lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution)

    with session.Session() as sess, context.eager_mode():
      with self.assertRaisesRegexp(ValueError, "Should not provide"):
        _ = monitor_lib.Monitor(step_function, sess)

  def testNotPassingASessionInGraph(self):
    """Monitor must require a session when building a graph."""
    distribution = one_device_strategy.OneDeviceStrategy(
        "/device:CPU:0")
    step_function, _ = single_loss_example(
        lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution)

    with context.graph_mode(), ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "Should provide"):
        _ = monitor_lib.Monitor(step_function, session=None)
# Example #26
# 0
class TestDistributionStrategyWithNormalizationLayer(test.TestCase,
                                                     parameterized.TestCase):
    """Checks BatchNormalization statistics are learned under tf.distribute."""

    @combinations.generate(
        combinations.times(
            keras_test_lib.all_strategy_combinations(),
            combinations.combine(fused=[True, False]),
            combinations.combine(cloning=True,
                                 optimizer=strategy_combinations.
                                 gradient_descent_optimizer_v1_fn) +
            combinations.combine(cloning=False,
                                 optimizer=strategy_combinations.
                                 gradient_descent_optimizer_keras_v2_fn)))
    def test_batchnorm_correctness(self, distribution, fused, optimizer,
                                   cloning):
        """Trains BN on shifted data and checks the output is normalized."""
        with self.cached_session():
            with distribution.scope():
                model = keras.models.Sequential()
                norm = keras.layers.BatchNormalization(input_shape=(
                    10,
                    20,
                    30,
                ),
                                                       momentum=0.8,
                                                       fused=fused)
                model.add(norm)
                model.compile(loss='mse',
                              optimizer=optimizer(),
                              cloning=cloning)

            # Samples centered on 5.0 with standard deviation 10.0
            # (np.random.normal's `scale` is the std-dev, not the variance).
            x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 20, 30))
            x = x.astype('float32')
            # Identity targets: the model learns to normalize x onto itself.
            dataset = dataset_ops.Dataset.from_tensor_slices((x, x))
            dataset = dataset.repeat(100)
            dataset = keras_test_lib.batch_wrapper(dataset, 32, distribution)

            predict_dataset = dataset_ops.Dataset.from_tensor_slices(x)
            predict_dataset = predict_dataset.repeat(100)
            predict_dataset = keras_test_lib.batch_wrapper(
                predict_dataset, 32, distribution)

            model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10)
            out = model.predict(predict_dataset, steps=2)
            # Undo the learned affine transform so the output should be a
            # standard normal: mean ~0, std ~1.
            out -= keras.backend.eval(norm.beta)
            out /= keras.backend.eval(norm.gamma)
            np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
            np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
# Example #27
# 0
 def test_times(self):
     """`times` crosses the distributions into the concatenated combos."""
     c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"])
     c2 = combinations.combine(mode=["eager"], loss=["callable"])
     c3 = combinations.combine(distribution=["d1", "d2"])
     result = combinations.times(c3, c1 + c2)
     # Expected: each distribution paired with every (loss, mode) combo,
     # in the order produced by c1 + c2.
     expected = [
         OrderedDict([("distribution", dist), ("loss", loss),
                      ("mode", mode)])
         for dist in ["d1", "d2"]
         for loss, mode in [("callable", "graph"), ("tensor", "graph"),
                            ("callable", "eager")]
     ]
     self.assertEqual(expected, result)
 def test_times(self):
   c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"])
   c2 = combinations.combine(mode=["eager"], loss=["callable"])
   c3 = combinations.combine(distribution=["d1", "d2"])
   c4 = combinations.times(c3, c1 + c2)
   self.assertEqual([
       OrderedDict([("distribution", "d1"), ("loss", "callable"),
                    ("mode", "graph")]),
       OrderedDict([("distribution", "d1"), ("loss", "tensor"),
                    ("mode", "graph")]),
       OrderedDict([("distribution", "d1"), ("loss", "callable"),
                    ("mode", "eager")]),
       OrderedDict([("distribution", "d2"), ("loss", "callable"),
                    ("mode", "graph")]),
       OrderedDict([("distribution", "d2"), ("loss", "tensor"),
                    ("mode", "graph")]),
       OrderedDict([("distribution", "d2"), ("loss", "callable"),
                    ("mode", "eager")])
   ], c4)
# Example #29
# 0
class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
    """Verifies v2 optimizers reduce loss when run per-replica."""

    @combinations.generate(
        combinations.times(
            distributions_and_v2_optimizers(),
            combinations.combine(mode=["graph"],
                                 use_callable_loss=[True, False]) +
            combinations.combine(mode=["eager"], use_callable_loss=[True])))
    def testTrainNetwork(self,
                         distribution,
                         optimizer_fn,
                         use_callable_loss=True):
        """Trains for 10 steps and checks the error is non-increasing."""
        with distribution.scope():
            model_fn, dataset_fn, layer = minimize_loss_example(
                optimizer_fn,
                use_bias=True,
                use_callable_loss=use_callable_loss)
            iterator = distribution.make_input_fn_iterator(
                lambda _: dataset_fn())

            def run_step():
                # Run model_fn on every replica and group the per-replica ops.
                return control_flow_ops.group(
                    distribution.experimental_local_results(
                        distribution.extended.call_for_each_replica(
                            model_fn, args=(iterator.get_next(), ))))

            if not context.executing_eagerly():
                with self.cached_session() as sess:
                    sess.run(iterator.initialize())
                    # Freeze the graph-mode step into a session callable so
                    # the loop below works in both graph and eager modes.
                    run_step = sess.make_callable(run_step())
                self.evaluate(variables.global_variables_initializer())

            weights, biases = [], []
            for _ in range(10):
                run_step()

                weights.append(self.evaluate(layer.kernel))
                biases.append(self.evaluate(layer.bias))

            # Error of each step is |w + b - 1|; training should never make
            # it larger from one step to the next.
            error = abs(
                numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
            is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
            self.assertTrue(is_not_increasing)
# Example #30
# 0
class TestDistributionStrategyWithLossMasking(test.TestCase,
                                              parameterized.TestCase):
    """Checks masked timesteps are excluded from the loss under tf.distribute."""

    # TODO(priyag): Enable all strategies for this test. Currently it does not
    # work for TPU due to some invalid datatype.
    @combinations.generate(
        combinations.times(
            combinations.combine(distribution=[
                strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            ],
                                 mode=['graph', 'eager']),
            combinations.combine(cloning=True,
                                 optimizer=strategy_combinations.
                                 gradient_descent_optimizer_v1_fn) +
            combinations.combine(cloning=False,
                                 optimizer=strategy_combinations.
                                 gradient_descent_optimizer_keras_v2_fn)))
    def test_masking(self, distribution, cloning, optimizer):
        """Masked inputs with ones-initialized weights must give zero loss."""
        with self.cached_session():
            np.random.seed(1337)
            # Second sample is all zeros, i.e. fully masked out
            # (mask_value=0 below).
            x = np.array([[[1], [1]], [[0], [0]]])
            with distribution.scope():
                model = keras.models.Sequential()
                model.add(
                    keras.layers.Masking(mask_value=0, input_shape=(2, 1)))
                model.add(
                    keras.layers.TimeDistributed(
                        keras.layers.Dense(1, kernel_initializer='one')))
                model.compile(loss='mse',
                              optimizer=optimizer(),
                              cloning=cloning)
            # With kernel initialized to one, unmasked inputs of 1 predict 1
            # exactly; masked timesteps contribute nothing, so loss is 0.
            y = np.array([[[1], [1]], [[1], [1]]])
            dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
            dataset = dataset.repeat(100)
            dataset = dataset.batch(10)
            hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2)
            self.assertEqual(hist.history['loss'][0], 0)
def all_strategy_and_eager_plus_graph():
  """Crosses the contrib mirrored strategies with eager and graph modes."""
  modes = combinations.combine(mode=["eager", "graph"])
  dists = combinations.combine(distribution=contrib_mirrored_strategies)
  return combinations.times(dists, modes)
# Example #32
# 0
def strategy_minus_tpu_and_input_config_combinations_eager():
  """Non-TPU strategies crossed with the eager-mode test configuration."""
  return combinations.times(
      combinations.combine(
          distribution=strategy_combinations.strategies_minus_tpu),
      eager_mode_test_configuration())
# Example #33
# 0
def all_strategy_and_input_config_combinations():
  """All strategies x cloning flag x (eager + graph) test configurations."""
  configs = eager_mode_test_configuration() + graph_mode_test_configuration()
  return combinations.times(
      combinations.combine(distribution=all_strategies,
                           cloning=[True, False]),
      configs)
class TestDistributionStrategyWithCallbacks(test.TestCase,
                                            parameterized.TestCase):
  """Verifies Keras callback hooks fire the expected number of times.

  `Counter` (defined elsewhere in this file) is a callback that records how
  often each hook method is invoked in its `method_counts` dict.
  """

  @combinations.generate(
      combinations.times(keras_test_lib.all_strategy_combinations(),
                         combinations.combine(cloning=[True, False])))
  def test_callbacks_in_fit(self, distribution, cloning):
    """Counts every callback invocation during `fit` with validation."""
    with distribution.scope():
      model = keras_test_lib.get_model()
      model.compile(
          optimizer='sgd', loss='mse', metrics=['mae'], cloning=cloning)

    dataset = keras_test_lib.get_dataset(distribution)
    counter = Counter()

    epochs = 2
    steps_per_epoch = 5
    validation_steps = 3

    model.fit(
        dataset,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        verbose=0,
        validation_data=dataset,
        validation_steps=validation_steps,
        callbacks=[counter])

    if (isinstance(distribution, tpu_strategy.TPUStrategyV1) and
        not context.executing_eagerly()):
      # TPU Strategy can have multi step training, from extended.steps_per_run
      # if steps_per_run = 1, then num_batch_call_per_epoch = steps_per_epoch.
      # Each multi-step run fires the batch callbacks only once, so divide
      # (rounding up for a partial final run).
      steps_per_run = distribution.extended.steps_per_run
      num_batch_call_per_epoch = steps_per_epoch // steps_per_run
      if steps_per_epoch % steps_per_run:
        num_batch_call_per_epoch += 1
    else:
      num_batch_call_per_epoch = steps_per_epoch

    self.assertDictEqual(
        counter.method_counts, {
            'on_batch_begin': epochs * num_batch_call_per_epoch,
            'on_batch_end': epochs * num_batch_call_per_epoch,
            'on_epoch_begin': epochs,
            'on_epoch_end': epochs,
            'on_test_batch_begin': epochs * validation_steps,
            'on_test_batch_end': epochs * validation_steps,
            'on_test_begin': epochs,
            'on_test_end': epochs,
            'on_train_batch_begin': epochs * num_batch_call_per_epoch,
            'on_train_batch_end': epochs * num_batch_call_per_epoch,
            'on_train_begin': 1,
            'on_train_end': 1
        })

  @combinations.generate(
      combinations.times(keras_test_lib.all_strategy_combinations(),
                         combinations.combine(cloning=[True, False])))
  def test_callbacks_in_eval(self, distribution, cloning):
    """Counts the test-phase callback invocations during `evaluate`."""
    with distribution.scope():
      model = keras_test_lib.get_model()
      model.compile(
          optimizer='sgd', loss='mse', metrics=['mae'], cloning=cloning)

    dataset = keras_test_lib.get_dataset(distribution)
    counter = Counter()

    model.evaluate(dataset, steps=5, callbacks=[counter])

    self.assertDictEqual(
        counter.method_counts, {
            'on_test_batch_begin': 5,
            'on_test_batch_end': 5,
            'on_test_begin': 1,
            'on_test_end': 1
        })

  @combinations.generate(
      combinations.times(keras_test_lib.all_strategy_combinations(),
                         combinations.combine(cloning=[True, False])))
  def test_callbacks_in_predict(self, distribution, cloning):
    """Counts the predict-phase callback invocations during `predict`."""
    with distribution.scope():
      model = keras_test_lib.get_model()
      model.compile(
          optimizer='sgd', loss='mse', metrics=['mae'], cloning=cloning)

    dataset = keras_test_lib.get_dataset(distribution)
    counter = Counter()

    model.predict(
        keras_test_lib.get_predict_dataset(dataset),
        steps=5,
        callbacks=[counter])

    self.assertDictEqual(
        counter.method_counts, {
            'on_predict_batch_begin': 5,
            'on_predict_batch_end': 5,
            'on_predict_begin': 1,
            'on_predict_end': 1
        })
def all_strategies_excluding_tpu_and_input_config_combinations():
  """Non-TPU strategies crossed with both eager and graph configurations."""
  configs = eager_mode_test_configuration() + graph_mode_test_configuration()
  return combinations.times(
      combinations.combine(
          distribution=strategy_combinations.strategies_minus_tpu),
      configs)
class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
    def _get_iterator(self, strategy, input_fn):
        """Builds and initializes an input-fn iterator for `strategy`."""
        it = strategy.make_input_fn_iterator(lambda _: input_fn())
        self.evaluate(it.initializer)
        return it

    @combinations.generate(
        combinations.times(
            strategy_combinations.distributions_and_v1_optimizers(),
            combinations.combine(mode=["graph"],
                                 use_callable_loss=[True, False]) +
            combinations.combine(mode=["eager"], use_callable_loss=[True])) +
        combinations.times(
            strategy_combinations.distributions_and_v2_optimizers(),
            combinations.combine(mode=["graph", "eager"],
                                 use_callable_loss=[True])) +
        combinations.combine(distribution=[strategy_combinations.tpu_strategy],
                             optimizer_fn=strategy_combinations.optimizers_v2,
                             mode=["graph"],
                             use_callable_loss=[True]) +
        combinations.combine(distribution=[strategy_combinations.tpu_strategy],
                             optimizer_fn=strategy_combinations.optimizers_v1,
                             mode=["graph"],
                             use_callable_loss=[True, False]))
    def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss):
        """Trains via `experimental_run_steps_on_iterator`; error must not grow."""
        with distribution.scope():
            optimizer = optimizer_fn()
            model_fn, dataset_fn, layer = minimize_loss_example(
                optimizer, use_bias=True, use_callable_loss=use_callable_loss)

            def step_fn(ctx, inputs):
                del ctx  # Unused
                return distribution.group(
                    distribution.extended.call_for_each_replica(
                        model_fn, args=(inputs, )))

            iterator = self._get_iterator(distribution, dataset_fn)

            def run_step():
                # Two training iterations per call.
                return distribution.extended.experimental_run_steps_on_iterator(
                    step_fn, iterator, iterations=2).run_op

            if not context.executing_eagerly():
                with self.cached_session() as sess:
                    # Freeze the graph-mode op into a session callable so the
                    # loop below is mode-agnostic.
                    run_step = sess.make_callable(run_step())
            self.evaluate(variables_lib.global_variables_initializer())

            weights, biases = [], []
            for _ in range(5):
                run_step()
                weights.append(self.evaluate(layer.kernel))
                biases.append(self.evaluate(layer.bias))

            # Error per step is |w + b - 1|; it should be non-increasing.
            error = abs(
                numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
            is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
            self.assertTrue(is_not_increasing)

    @combinations.generate(
        combinations.times(
            strategy_combinations.distributions_and_v1_optimizers(),
            combinations.combine(mode=["graph"],
                                 use_callable_loss=[True, False]) +
            combinations.combine(mode=["eager"], use_callable_loss=[True])) +
        combinations.times(
            strategy_combinations.distributions_and_v2_optimizers(),
            combinations.combine(mode=["graph", "eager"],
                                 use_callable_loss=[True])))
    def testTrainNetworkByCallForEachReplica(self, distribution, optimizer_fn,
                                             use_callable_loss):
        """Like testTrainNetwork but drives `call_for_each_replica` directly."""
        with distribution.scope():
            optimizer = optimizer_fn()
            model_fn, dataset_fn, layer = minimize_loss_example(
                optimizer, use_bias=True, use_callable_loss=use_callable_loss)

            iterator = self._get_iterator(distribution, dataset_fn)

            def run_step():
                return distribution.group(
                    distribution.extended.call_for_each_replica(
                        model_fn, args=(iterator.get_next(), )))

            if not context.executing_eagerly():
                with self.cached_session() as sess:
                    run_step = sess.make_callable(run_step())
                # NOTE(review): unlike testTrainNetwork, the initializer is
                # only evaluated in graph mode here — confirm intentional.
                self.evaluate(variables_lib.global_variables_initializer())

            weights, biases = [], []
            for _ in range(10):
                run_step()

                weights.append(self.evaluate(layer.kernel))
                biases.append(self.evaluate(layer.bias))

            # Error per step is |w + b - 1|; it should be non-increasing.
            error = abs(
                numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
            is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
            self.assertTrue(is_not_increasing)

    @combinations.generate(
        combinations.times(
            strategy_combinations.distributions_and_v1_and_v2_optimizers(),
            combinations.combine(mode=["graph", "eager"])) +
        combinations.combine(
            distribution=[strategy_combinations.tpu_strategy],
            optimizer_fn=strategy_combinations.optimizers_v1_and_v2,
            mode=["graph"]))
    def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
        """Checks variables created inside a model_fn match the expected set."""
        if (not context.executing_eagerly()
                and control_flow_v2_toggles.control_flow_v2_enabled()):
            self.skipTest("b/138751864")
        created_variables = []
        trainable_variables = []

        def appending_creator(next_creator, **kwargs):
            # Variable-creator hook: records every created variable's name.
            v = next_creator(**kwargs)
            created_variables.append(v.name)
            if "trainable" in kwargs and kwargs["trainable"]:
                trainable_variables.append(v.name)
            return v

        # Creator scope needs to be set before it's used inside
        # `distribution.scope`.
        with variable_scope.variable_creator_scope(
                appending_creator), distribution.scope():
            optimizer = optimizer_fn()
            model_fn, dataset_fn, _ = minimize_loss_example(
                optimizer, use_bias=True, use_callable_loss=True)

            def step_fn(ctx, inputs):
                del ctx  # Unused
                return distribution.group(
                    distribution.extended.call_for_each_replica(
                        model_fn, args=(inputs, )))

            iterator = self._get_iterator(distribution, dataset_fn)

            def run_step():
                return distribution.extended.experimental_run_steps_on_iterator(
                    step_fn, iterator, iterations=1).run_op

            if not context.executing_eagerly():
                with self.cached_session() as sess:
                    run_step = sess.make_callable(run_step())
            self.evaluate(variables_lib.global_variables_initializer())
            run_step()

            def get_expected_variables(num_parameter_devices):
                """Expected variable names, including per-replica copies."""
                name = optimizer._name

                # VAR_MAP_V1/V2 map optimizer names to their slot-variable
                # base names.
                if isinstance(optimizer, optimizer_v2.OptimizerV2):
                    variables = VAR_MAP_V2[name]
                else:
                    variables = VAR_MAP_V1[name]

                # Replica 0 keeps the base name; replicas 1..n-1 get a
                # "/replica_i" suffix.
                extended_variables = [
                    v + "/replica_{}".format(replica) for v in variables
                    for replica in range(1, num_parameter_devices)
                ]
                variables = list(variables) + extended_variables
                return set(v + ":0" for v in variables)

            self.assertEqual(
                get_expected_variables(
                    len(distribution.extended.parameter_devices)),
                set(created_variables))

    @combinations.generate(
        combinations.times(
            combinations.combine(momentum=[0.8, 0.9, 0.99],
                                 renorm=[False, True]),
            combinations.times(
                strategy_combinations.distributions_and_v1_and_v2_optimizers(),
                combinations.combine(
                    mode=["graph", "eager"],
                    # TODO(isaprykin):  Allow False here.  Currently subsequent
                    # replicas will re-execute UPDATE_OPS of previous replicas.
                    update_ops_in_cross_replica_mode=[True])) +
            combinations.combine(
                distribution=[strategy_combinations.tpu_strategy],
                optimizer_fn=strategy_combinations.optimizers_v1_and_v2,
                mode=["graph"],
                update_ops_in_cross_replica_mode=[False])))
    def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn,
                                      momentum, renorm,
                                      update_ops_in_cross_replica_mode):
        """Verifies that moving mean updates are reduced across replicas."""
        with distribution.scope():
            num_replicas = distribution.num_replicas_in_sync
            model_fn, dataset_fn, batchnorm = batchnorm_example(
                optimizer_fn,
                batch_per_epoch=num_replicas,
                momentum=momentum,
                renorm=renorm,
                update_ops_in_replica_mode=not update_ops_in_cross_replica_mode
            )

            def step_fn(ctx, inputs):
                del ctx  # Unused
                fetches = distribution.experimental_local_results(
                    distribution.extended.call_for_each_replica(
                        model_fn, args=(inputs, )))
                # In cross-replica mode the UPDATE_OPS (batch-norm moving
                # statistics) must be fetched explicitly.
                if update_ops_in_cross_replica_mode:
                    fetches += tuple(
                        ops.get_collection(ops.GraphKeys.UPDATE_OPS))
                return control_flow_ops.group(fetches)

            iterator = self._get_iterator(distribution, dataset_fn)

            def run_step():
                return distribution.extended.experimental_run_steps_on_iterator(
                    step_fn, iterator, iterations=1).run_op

            if not context.executing_eagerly():
                with self.cached_session() as sess:
                    run_step = sess.make_callable(run_step())
            self.evaluate(variables_lib.global_variables_initializer())

            # One expected moving mean per batch-norm feature (8 features).
            expected_moving_means = [0.] * 8

            def averaged_batch_mean(i):
                # Each batch has shape [16, 8] where the ith element in jth list is
                # (8 * j + i + replica_id * 100). So the batch mean in each replica is
                # (60 + i + replica_id * 100). So here comes its batch mean over all
                # replicas:
                return 60. + i + (num_replicas - 1.) / 2. * 100.

            for _ in range(10):
                run_step()
                moving_means = self.evaluate(batchnorm.moving_mean)

                # We make sure that the moving_mean is updated as if the sample mean is
                # calculated over all replicas.
                for i, expected_moving_mean in enumerate(
                        expected_moving_means):
                    # Standard EMA update: m -= (m - batch_mean) * (1 - momentum).
                    expected_moving_means[i] -= (
                        (expected_moving_mean - averaged_batch_mean(i)) *
                        (1.0 - momentum))
                    self.assertNear(expected_moving_means[i], moving_means[i],
                                    0.0001)

    @combinations.generate(
        combinations.times(
            combinations.combine(loss_reduction=[
                losses_impl.Reduction.SUM, losses_impl.Reduction.MEAN,
                losses_impl.Reduction.SUM_OVER_BATCH_SIZE,
                losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS
            ]),
            combinations.times(
                combinations.combine(distribution=[
                    strategy_combinations.one_device_strategy,
                    strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
                    strategy_combinations.mirrored_strategy_with_two_gpus
                ]),
                combinations.times(
                    combinations.combine(optimizer_fn=strategy_combinations.
                                         gradient_descent_optimizer_v1_fn),
                    combinations.combine(mode=["graph"],
                                         use_callable_loss=[True, False]) +
                    combinations.combine(mode=["eager"],
                                         use_callable_loss=[True])) +
                combinations.times(
                    combinations.combine(
                        optimizer_fn=strategy_combinations.
                        gradient_descent_optimizer_keras_v2_fn),
                    combinations.combine(mode=["graph", "eager"],
                                         use_callable_loss=[True]))) +
            combinations.combine(
                distribution=[strategy_combinations.tpu_strategy],
                optimizer_fn=strategy_combinations.
                gradient_descent_optimizer_v1_fn,
                mode=["graph"],
                use_callable_loss=[True, False]) + combinations.combine(
                    distribution=[strategy_combinations.tpu_strategy],
                    optimizer_fn=strategy_combinations.
                    gradient_descent_optimizer_keras_v2_fn,
                    mode=["graph"],
                    use_callable_loss=[True])))
    def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction,
                      use_callable_loss):
        """Checks the trained weight under each loss `Reduction` mode.

        A one-weight linear model (predict = x * w) is trained for a single
        step on a fixed two-example batch, and the resulting weight is
        compared against a hand-computed update for SUM vs. the mean-style
        reductions (see the derivation in the comments near the assertions).
        """
        with distribution.scope():
            all_vars = []

            def model_fn(inputs):
                x, y = inputs
                w = variable_scope.get_variable("w", initializer=[[2.]])
                # Collected so the test can assert every replica captured the
                # same underlying variable (see `all(v is vi ...)` below).
                all_vars.append(w)

                def loss_fn():
                    # Use fixed initialization to make the steps deterministic.
                    predict = math_ops.matmul(x, w)
                    loss = losses_impl.mean_squared_error(
                        y, predict, reduction=loss_reduction)
                    if loss_reduction == losses_impl.Reduction.SUM:
                        return loss
                    # For mean-style reductions, divide by the replica count
                    # so the summed cross-replica gradient matches one mean.
                    return loss / distribution.num_replicas_in_sync

                optimizer = optimizer_fn(
                )  # NOTE(review): the expected updates below (0.106, 0.053)
                # imply a 0.001 learning rate — confirm against the factory in
                # strategy_combinations.

                if isinstance(optimizer, optimizer_v2.OptimizerV2):
                    return optimizer.minimize(loss_fn, [w])
                else:
                    # V1 optimizers accept either a callable or a loss tensor.
                    if use_callable_loss:
                        return optimizer.minimize(loss_fn)
                    else:
                        return optimizer.minimize(loss_fn())

            def dataset_fn():
                # One fixed batch of two examples, repeated forever.
                features = dataset_ops.Dataset.from_tensors([[2.], [7.]])
                labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
                return dataset_ops.Dataset.zip((features, labels)).repeat()

            def step_fn(ctx, inputs):
                del ctx  # Unused
                return distribution.group(
                    distribution.extended.call_for_each_replica(
                        model_fn, args=(inputs, )))

            iterator = self._get_iterator(distribution, dataset_fn)

            def run_step():
                # A single training iteration per call.
                return distribution.extended.experimental_run_steps_on_iterator(
                    step_fn, iterator, iterations=1).run_op

            if not context.executing_eagerly():
                with self.cached_session() as sess:
                    run_step = sess.make_callable(run_step())
            self.evaluate(variables_lib.global_variables_initializer())

            run_step()

            v = all_vars[0]
            # Every replica must have reused the same variable instance.
            self.assertTrue(all(v is vi for vi in all_vars[1:]))
            weight = numpy.squeeze(self.evaluate(v))
            # Our model is:
            #   predict = x * w
            #   loss = (predict - y)^2
            #   dloss/dpredict = 2*(predict - y)
            #   dloss/dw = 2 * x^T @ (predict - y)
            # For our batch size of 2, assuming sum loss reduction:
            #   x = [2, 7]
            #   y = [6, 21]
            #   w_initial = 2
            #   predict = [4, 14]
            #   predict - y = [-2, -7]
            #   dloss/dw = 2 <[2, 7], [-2, -7]> = - 2(4 + 49) = -106
            # So unreplicated the update to w with lr=0.001 is
            # -0.001 * -106 = 0.106 with sum loss reduction, or 0.053 with
            # mean.
            if loss_reduction == losses_impl.Reduction.SUM:
                # Note that the "distribution.num_replicas_in_sync" factor will go away
                # once we split the input across replicas, instead of pulling a complete
                # batch of input per replica.
                self.assertNear(weight,
                                2 + 0.106 * distribution.num_replicas_in_sync,
                                0.0001)
            else:
                # One of the mean loss reductions.
                self.assertNear(weight, 2 + 0.053, 0.0001)
    @combinations.generate(
        combinations.times(
            strategy_combinations.distributions_and_v1_and_v2_optimizers(),
            combinations.combine(mode=["graph", "eager"]),
            combinations.combine(is_tpu=[False])) + combinations.combine(
                distribution=[strategy_combinations.tpu_strategy],
                optimizer_fn=strategy_combinations.optimizers_v1_and_v2,
                mode=["graph"],
                is_tpu=[True]))
    def testRunStepsWithOutputContext(self, distribution, optimizer_fn,
                                      is_tpu):
        """Trains a 1-unit dense layer and verifies the step-output plumbing.

        Exercises `experimental_run_steps_on_iterator` outputs: a per-replica
        reduced loss, cross-replica reduced and non-reduced losses, and a
        non-tensor output; then checks that the loss and the weight error are
        non-increasing over five runs.
        """
        with distribution.scope():

            def dataset_fn():
                dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
                # TODO(priyag): batch with drop_remainder=True causes shapes to be
                # fully defined for TPU. Remove this when XLA supports dynamic shapes.
                return dataset.batch(batch_size=1, drop_remainder=True)

            optimizer = optimizer_fn()
            layer = core.Dense(1, use_bias=True)

            # Arbitrary key/value pair to exercise non-tensor outputs.
            key1 = "foo"
            value1 = "bar"

            def model_fn(output_context, x):
                """A very simple model written by the user."""
                def loss_fn():
                    # Squared error of the layer output against the constant 1.
                    y = array_ops.reshape(layer(x),
                                          []) - constant_op.constant(1.)
                    return y * y

                if isinstance(optimizer, optimizer_v2.OptimizerV2):
                    train_op = optimizer.minimize(
                        loss_fn, lambda: layer.trainable_variables)
                else:
                    train_op = optimizer.minimize(loss_fn)
                loss = loss_fn()
                # Per-replica output, reduced across replicas with MEAN.
                output_context.set_last_step_output(
                    name="replica_loss_reduced",
                    output=loss,
                    reduce_op=reduce_util.ReduceOp.MEAN)
                output_context.set_non_tensor_output(key1, value1)
                return (train_op, loss)

            def step_fn(output_context, inputs):
                (train_op, loss) = distribution.extended.call_for_each_replica(
                    model_fn, args=(output_context, inputs))
                # The same loss recorded from cross-replica context, once
                # reduced and once left un-reduced (no reduce_op given).
                output_context.set_last_step_output(
                    name="cross_replica_loss_reduced",
                    output=loss,
                    reduce_op=reduce_util.ReduceOp.MEAN)
                output_context.set_last_step_output(
                    name="cross_replica_loss_not_reduced", output=loss)
                return distribution.group(train_op)

            iterator = self._get_iterator(distribution, dataset_fn)

            def run_step():
                initial_loss = lambda: constant_op.constant(1e7)
                # Initial values corresponding to reduced losses are just single
                # tensors. But for non reduced losses, we need to have initial
                # values that are of the same structure as non reduced losses. In
                # MirroredStrategy, this will be a list of losses, in TPUStrategy
                # it will be single tensor. Using `call_for_each_replica` followed
                # by `experimental_local_results` gives us the desired initial
                # value structure.
                not_reduced = distribution.experimental_local_results(
                    distribution.extended.call_for_each_replica(initial_loss))
                initial_loop_values = {
                    "replica_loss_reduced": initial_loss(),
                    "cross_replica_loss_reduced": initial_loss(),
                    "cross_replica_loss_not_reduced": not_reduced,
                }
                ctx = distribution.extended.experimental_run_steps_on_iterator(
                    step_fn,
                    iterator,
                    iterations=2,
                    initial_loop_values=initial_loop_values)

                # Non-tensor outputs are collected into a tuple per key.
                self.assertEqual({key1: (value1, )}, ctx.non_tensor_outputs)
                self._verify_loss_output(
                    initial_loss(),
                    loss_output=ctx.last_step_outputs["replica_loss_reduced"],
                    reduced=True,
                    distribution=distribution)
                self._verify_loss_output(
                    initial_loss(),
                    loss_output=ctx.
                    last_step_outputs["cross_replica_loss_reduced"],
                    reduced=True,
                    distribution=distribution)
                self._verify_loss_output(
                    initial_loss(),
                    loss_output=ctx.
                    last_step_outputs["cross_replica_loss_not_reduced"],
                    reduced=False,
                    distribution=distribution)
                return (ctx.run_op,
                        ctx.last_step_outputs["replica_loss_reduced"])

            if not context.executing_eagerly():
                with self.cached_session() as sess:
                    run_step = sess.make_callable(run_step())
            self.evaluate(variables_lib.global_variables_initializer())

            weights, biases, losses = [], [], []
            for _ in range(5):
                _, loss = run_step()
                losses.append(loss)
                weights.append(self.evaluate(layer.kernel))
                biases.append(self.evaluate(layer.bias))

            # Training should (weakly) decrease the loss on every run ...
            loss_is_not_increasing = all(y <= x
                                         for x, y in zip(losses, losses[1:]))
            self.assertTrue(loss_is_not_increasing)

            # ... and move kernel + bias toward the target sum of 1
            # (with input x = 1 the loss is (kernel*1 + bias - 1)^2).
            error = abs(
                numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
            error_is_not_increasing = all(y <= x
                                          for x, y in zip(error, error[1:]))
            self.assertTrue(error_is_not_increasing)

    def _verify_loss_output(self, initial_loss, loss_output, reduced,
                            distribution):
        """Checks structure, dtype and shape of a recorded loss output.

        For a reduced output there must be exactly one local value; for a
        non-reduced output there must be one value per replica, which is then
        reduced here with MEAN before comparing against `initial_loss`.
        """
        if reduced:
            local_values = distribution.experimental_local_results(loss_output)
            self.assertLen(local_values, 1)
            loss_tensor = local_values[0]
        else:
            self.assertLen(
                distribution.experimental_local_results(loss_output),
                distribution.num_replicas_in_sync)
            loss_tensor = distribution.reduce(
                reduce_util.ReduceOp.MEAN, loss_output, axis=None)
        self.assertEqual(initial_loss.dtype, loss_tensor.dtype)
        self.assertEqual(initial_loss.shape, loss_tensor.shape)

    @combinations.generate(
        strategy_combinations.distributions_and_v2_optimizers())
    def test_empty_var_list(self, distribution, optimizer_fn):
        """Minimize/apply_gradients over an empty variable list must not fail."""
        optimizer = optimizer_fn()
        with distribution.scope():

            def step():
                # Neither call has any variables to update.
                optimizer.minimize(lambda: constant_op.constant(1.), [])
                optimizer.apply_gradients([])

            distribution.run(step)
def all_strategies_excluding_tpu_and_input_config_combinations():
    """Returns non-TPU strategies crossed with eager and graph test configs."""
    test_configs = (eager_mode_test_configuration() +
                    graph_mode_test_configuration())
    strategy_axes = combinations.combine(
        distribution=strategy_combinations.strategies_minus_tpu)
    return combinations.times(strategy_axes, test_configs)
def all_strategy_and_input_config_combinations():
  """Returns all strategies x cloning flags x eager/graph test configs."""
  test_configs = (eager_mode_test_configuration() +
                  graph_mode_test_configuration())
  strategy_axes = combinations.combine(
      distribution=all_strategies, cloning=[True, False])
  return combinations.times(strategy_axes, test_configs)
class DistributionStrategyStatefulLstmModelCorrectnessTest(
    keras_correctness_test_base.
    TestDistributionStrategyEmbeddingModelCorrectnessBase):
  """Correctness tests for a small stateful-LSTM classifier."""

  def get_model(self,
                max_words=10,
                initial_weights=None,
                distribution=None,
                cloning=None,
                input_shapes=None):
    """Builds and compiles the stateful LSTM model under `distribution`."""
    del input_shapes  # Unused; the input shape is fixed below.
    batch_size = keras_correctness_test_base._GLOBAL_BATCH_SIZE

    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      # Stateful RNNs require a fully specified (fixed) batch size.
      word_input = keras.layers.Input(
          shape=(max_words,),
          batch_size=batch_size,
          dtype=np.int32,
          name='words')
      embedded = keras.layers.Embedding(
          input_dim=20, output_dim=10)(word_input)
      lstm_out = keras.layers.LSTM(
          units=4, return_sequences=False, stateful=True)(embedded)
      preds = keras.layers.Dense(2, activation='softmax')(lstm_out)
      model = keras.Model(inputs=[word_input], outputs=[preds])

      if initial_weights:
        model.set_weights(initial_weights)

      # TODO(b/130808953): Re-enable the V1 optimizer after iterations
      # is mirrored.
      if cloning:
        optimizer_fn = gradient_descent.GradientDescentOptimizer
      else:
        optimizer_fn = gradient_descent_keras.SGD

      model.compile(
          optimizer=optimizer_fn(learning_rate=0.1),
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'])
    return model

  @combinations.generate(test_combinations_for_stateful_embedding_model())
  def test_stateful_lstm_model_correctness(self, distribution, use_numpy,
                                           use_validation_data, cloning):
    """The stateful LSTM trains correctly on a single core."""
    self.run_correctness_test(
        distribution,
        use_numpy,
        use_validation_data,
        is_stateful_model=True,
        cloning=cloning)

  @combinations.generate(
      combinations.times(
          keras_correctness_test_base.test_combinations_with_tpu_strategies(),
          combinations.combine(cloning=[True, False])))
  def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
      self, distribution, use_numpy, use_validation_data, cloning):
    """Running a stateful model on multiple TPU cores must raise."""
    with self.assertRaisesRegexp(
        ValueError,
        'Single core must be used for computation on stateful models. Consider '
        'adding `device_assignment` parameter to TPUStrategy'):
      self.run_correctness_test(
          distribution,
          use_numpy,
          use_validation_data,
          is_stateful_model=True,
          cloning=cloning)
 def test_overlapping_keys(self):
   """`combinations.times` must reject inputs whose keys collide."""
   graph_combos = combinations.combine(mode=["graph"],
                                       loss=["callable", "tensor"])
   eager_combos = combinations.combine(mode=["eager"], loss=["callable"])
   # Both inputs define `mode` and `loss`, so the product is ambiguous.
   with self.assertRaisesRegexp(ValueError, ".*Keys.+overlap.+"):
     _ = combinations.times(graph_combos, eager_combos)