Ejemplo n.º 1
0
class DistributionStrategyLstmModelCorrectnessTest(
        _DistributionStrategyRnnModelCorrectnessTest):
    """Correctness tests for an LSTM model, in default and mixed precision.

    Provides the LSTM layer class through `_get_layer_class` and delegates
    the actual check to `self.run_correctness_test`.
    """

    def _get_layer_class(self):
        """Return the LSTM class to test.

        Returns `rnn_v1.LSTM` when TF2 behavior is disabled and
        `rnn_v2.LSTM` otherwise. When TF2 behavior is enabled but eager
        execution is not active, the test is skipped instead.
        """
        if not tf2.enabled():
            return rnn_v1.LSTM
        if not context.executing_eagerly():
            self.skipTest(
                "LSTM v2 and legacy graph mode don't work together.")
        return rnn_v2.LSTM

    @ds_combinations.generate(
        keras_correctness_test_base.test_combinations_for_embedding_model()
        + keras_correctness_test_base.multi_worker_mirrored_eager())
    def test_lstm_model_correctness(self, distribution, use_numpy,
                                    use_validation_data):
        # Plain correctness run for every generated combination.
        self.run_correctness_test(
            distribution, use_numpy, use_validation_data)

    @ds_combinations.generate(
        keras_correctness_test_base.test_combinations_for_embedding_model()
        + keras_correctness_test_base.multi_worker_mirrored_eager())
    @testing_utils.enable_v2_dtype_behavior
    def test_lstm_model_correctness_mixed_precision(self, distribution,
                                                    use_numpy,
                                                    use_validation_data):
        # TPU strategies run under the bfloat16 policy; every other strategy
        # runs under the float16 policy.
        is_tpu = isinstance(
            distribution,
            (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1))
        policy_name = 'mixed_bfloat16' if is_tpu else 'mixed_float16'

        with policy.policy_scope(policy_name):
            self.run_correctness_test(
                distribution, use_numpy, use_validation_data)
Ejemplo n.º 2
0
class DistributionStrategyEmbeddingModelCorrectnessTest(
    keras_correctness_test_base
    .TestDistributionStrategyEmbeddingModelCorrectnessBase):
  """Correctness tests for an averaged-embedding classifier.

  Each test sets `self.use_distributed_dense` before running; `get_model`
  reads it to decide whether a `TimeDistributed(Dense(4))` layer is inserted
  after the embedding.
  """

  def get_model(self,
                max_words=10,
                initial_weights=None,
                distribution=None,
                run_distributed=None,
                input_shapes=None):
    """Build and compile the embedding classifier.

    Args:
      max_words: Length of the int32 word-id input sequence.
      initial_weights: Weights applied with `model.set_weights` before
        compiling; skipped when falsy.
      distribution: Passed to `MaybeDistributionScope`, inside which the model
        is built and compiled.
      run_distributed: Forwarded unchanged to `model.compile`.
      input_shapes: Unused; discarded immediately.

    Returns:
      The compiled `keras.Model`.
    """
    del input_shapes
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      inputs = keras.layers.Input(
          name='words', shape=(max_words,), dtype=np.int32)
      features = keras.layers.Embedding(input_dim=20, output_dim=10)(inputs)
      if self.use_distributed_dense:
        per_step_dense = keras.layers.TimeDistributed(keras.layers.Dense(4))
        features = per_step_dense(features)
      pooled = keras.layers.GlobalAveragePooling1D()(features)
      outputs = keras.layers.Dense(2, activation='softmax')(pooled)
      model = keras.Model(inputs=[inputs], outputs=[outputs])

      if initial_weights:
        model.set_weights(initial_weights)

      # TODO(b/130808953): Switch back the V1 optimizer once global_step is
      # mirrored.
      sgd = gradient_descent_keras.SGD(learning_rate=0.1)
      model.compile(
          optimizer=sgd,
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'],
          run_distributed=run_distributed)
    return model

  @combinations.generate(
      keras_correctness_test_base.test_combinations_for_embedding_model())
  def test_embedding_model_correctness(self, distribution, use_numpy,
                                       use_validation_data, run_distributed):
    # Variant without the TimeDistributed dense layer.
    self.use_distributed_dense = False
    self.run_correctness_test(
        distribution, use_numpy, use_validation_data, run_distributed)

  @combinations.generate(
      keras_correctness_test_base.test_combinations_for_embedding_model())
  def test_embedding_time_distributed_model_correctness(self, distribution,
                                                        use_numpy,
                                                        use_validation_data,
                                                        run_distributed):
    # Variant with the TimeDistributed dense layer after the embedding.
    self.use_distributed_dense = True
    self.run_correctness_test(
        distribution, use_numpy, use_validation_data, run_distributed)
class DistributionStrategyLstmModelCorrectnessTest(
        keras_correctness_test_base.
        TestDistributionStrategyEmbeddingModelCorrectnessBase):
    """Correctness tests for an embedding + LSTM classifier."""

    def get_model(self, max_words=10, initial_weights=None, distribution=None):
        """Build and compile the embedding + LSTM classifier.

        Args:
            max_words: Length of the int32 word-id input sequence.
            initial_weights: Weights applied with `model.set_weights` before
                compiling; skipped when falsy.
            distribution: Passed to `MaybeDistributionScope`, inside which
                the model is built and compiled.

        Returns:
            The compiled `keras.Model`.
        """
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            inputs = keras.layers.Input(
                name='words', shape=(max_words, ), dtype=np.int32)
            embedded = keras.layers.Embedding(
                input_dim=20, output_dim=10)(inputs)
            # Only the final LSTM output is kept (return_sequences=False).
            encoded = keras.layers.LSTM(
                units=4, return_sequences=False)(embedded)

            outputs = keras.layers.Dense(2, activation='softmax')(encoded)
            model = keras.Model(inputs=[inputs], outputs=[outputs])

            if initial_weights:
                model.set_weights(initial_weights)

            sgd = gradient_descent.GradientDescentOptimizer(
                learning_rate=0.1)
            model.compile(
                optimizer=sgd,
                loss='sparse_categorical_crossentropy',
                metrics=['sparse_categorical_accuracy'])
        return model

    @combinations.generate(
        keras_correctness_test_base.test_combinations_for_embedding_model())
    def test_lstm_model_correctness(self, distribution, use_numpy,
                                    use_validation_data):
        # Delegates the whole check to the base-class correctness runner.
        self.run_correctness_test(
            distribution, use_numpy, use_validation_data)
Ejemplo n.º 4
0
class DistributionStrategyLstmModelCorrectnessTest(
    keras_correctness_test_base
    .TestDistributionStrategyEmbeddingModelCorrectnessBase):
  """Correctness tests for an embedding + LSTM classifier (v1 or v2 LSTM)."""

  def get_model(self,
                max_words=10,
                initial_weights=None,
                distribution=None,
                experimental_run_tf_function=None,
                input_shapes=None):
    """Build and compile the embedding + LSTM classifier.

    Uses `rnn_v1.LSTM` when TF2 behavior is disabled and `rnn_v2.LSTM`
    otherwise. When TF2 behavior is enabled but eager execution is not
    active, the test is skipped.

    Args:
      max_words: Length of the int32 word-id input sequence.
      initial_weights: Weights applied with `model.set_weights` before
        compiling; skipped when falsy.
      distribution: Passed to `MaybeDistributionScope`, inside which the model
        is built and compiled.
      experimental_run_tf_function: Forwarded unchanged to `model.compile`.
      input_shapes: Unused; discarded immediately.

    Returns:
      The compiled `keras.Model`.
    """
    del input_shapes

    if not tf2.enabled():
      lstm_cls = rnn_v1.LSTM
    else:
      if not context.executing_eagerly():
        self.skipTest("LSTM v2 and legacy graph mode don't work together.")
      lstm_cls = rnn_v2.LSTM

    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      inputs = keras.layers.Input(
          name='words', shape=(max_words,), dtype=np.int32)
      embedded = keras.layers.Embedding(input_dim=20, output_dim=10)(inputs)
      encoded = lstm_cls(units=4, return_sequences=False)(embedded)

      outputs = keras.layers.Dense(2, activation='softmax')(encoded)
      model = keras.Model(inputs=[inputs], outputs=[outputs])

      if initial_weights:
        model.set_weights(initial_weights)

      model.compile(
          optimizer=gradient_descent_keras.SGD(learning_rate=0.1),
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'],
          experimental_run_tf_function=experimental_run_tf_function)
    return model

  @combinations.generate(
      keras_correctness_test_base.test_combinations_for_embedding_model())
  def test_lstm_model_correctness(self, distribution, use_numpy,
                                  use_validation_data,
                                  experimental_run_tf_function):
    # Delegates the whole check to the base-class correctness runner.
    self.run_correctness_test(
        distribution, use_numpy, use_validation_data,
        experimental_run_tf_function)
class DistributionStrategyGruModelCorrectnessTest(
        _DistributionStrategyRnnModelCorrectnessTest):
    """Correctness tests for a GRU model.

    Provides the GRU layer class through `_get_layer_class` and delegates
    the actual check to `self.run_correctness_test`.
    """

    def _get_layer_class(self):
        """Return the GRU class to test.

        Returns `rnn_v1.GRU` when TF2 behavior is disabled and `rnn_v2.GRU`
        otherwise. When TF2 behavior is enabled but eager execution is not
        active, the test is skipped instead.
        """
        if not tf2.enabled():
            return rnn_v1.GRU
        if not context.executing_eagerly():
            self.skipTest(
                "GRU v2 and legacy graph mode don't work together.")
        return rnn_v2.GRU

    @ds_combinations.generate(
        keras_correctness_test_base.test_combinations_for_embedding_model())
    def test_gru_model_correctness(self, distribution, use_numpy,
                                   use_validation_data):
        # Delegates the whole check to the base-class correctness runner.
        self.run_correctness_test(
            distribution, use_numpy, use_validation_data)
class DistributionStrategyLstmModelCorrectnessTest(
    _DistributionStrategyRnnModelCorrectnessTest):
  """Correctness tests for an LSTM model.

  Provides the LSTM layer class through `_get_layer_class` and delegates the
  actual check to `self.run_correctness_test`.
  """

  def _get_layer_class(self):
    """Return the LSTM class to test.

    Returns `rnn_v1.LSTM` when TF2 behavior is disabled and `rnn_v2.LSTM`
    otherwise. When TF2 behavior is enabled but eager execution is not
    active, the test is skipped instead.
    """
    if not tf2.enabled():
      return rnn_v1.LSTM
    if not context.executing_eagerly():
      self.skipTest("LSTM v2 and legacy graph mode don't work together.")
    return rnn_v2.LSTM

  @combinations.generate(
      keras_correctness_test_base.test_combinations_for_embedding_model())
  def test_lstm_model_correctness(self, distribution, use_numpy,
                                  use_validation_data,
                                  experimental_run_tf_function):
    # Delegates the whole check to the base-class correctness runner.
    self.run_correctness_test(
        distribution, use_numpy, use_validation_data,
        experimental_run_tf_function)
class DistributionStrategySiameseEmbeddingModelCorrectnessTest(
        keras_correctness_test_base.
        TestDistributionStrategyEmbeddingModelCorrectnessBase):
    """Correctness tests for a two-input model sharing one embedding layer.

    Both inputs go through the same `Embedding` layer and are average-pooled;
    the model output is the normalized dot product of the two pooled vectors.
    """

    def get_model(self,
                  max_words=10,
                  initial_weights=None,
                  distribution=None,
                  input_shapes=None):
        """Build and compile the siamese embedding model.

        Args:
            max_words: Length of each int32 word-id input sequence.
            initial_weights: Weights applied with `model.set_weights` before
                compiling; skipped when falsy.
            distribution: Passed to `MaybeDistributionScope`, inside which
                the model is built and compiled.
            input_shapes: Unused; discarded immediately.

        Returns:
            The compiled `keras.Model` with inputs named 'words_a' and
            'words_b'.
        """
        del input_shapes
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            # Inputs are created in a fixed order: 'words_a', then 'words_b'.
            word_ids_a, word_ids_b = [
                keras.layers.Input(
                    shape=(max_words, ), dtype=np.int32, name=input_name)
                for input_name in ('words_a', 'words_b')
            ]

            def build_tower(embedding, word_ids):
                # Embed the ids, average-pool them, wrap the result in a Model.
                pooled = keras.layers.GlobalAveragePooling1D()(
                    embedding(word_ids))
                return keras.Model(inputs=[word_ids], outputs=[pooled])

            # A single embedding layer instance is shared by both towers.
            shared_embedding = keras.layers.Embedding(
                input_dim=20,
                output_dim=10,
                input_length=max_words,
                embeddings_initializer=keras.initializers.RandomUniform(0, 1))

            a_rep, b_rep = [
                build_tower(shared_embedding, ids).outputs[0]
                for ids in (word_ids_a, word_ids_b)
            ]
            similarity = keras.layers.Dot(axes=1, normalize=True)(
                [a_rep, b_rep])

            model = keras.Model(
                inputs=[word_ids_a, word_ids_b], outputs=[similarity])

            if initial_weights:
                model.set_weights(initial_weights)

            # TODO(b/130808953): Switch back to the V1 optimizer after
            # global_step is made mirrored.
            sgd = gradient_descent_keras.SGD(learning_rate=0.1)
            model.compile(optimizer=sgd, loss='mse', metrics=['mse'])
        return model

    def get_data(self,
                 count=(keras_correctness_test_base._GLOBAL_BATCH_SIZE *
                        keras_correctness_test_base._EVAL_STEPS),
                 min_words=5,
                 max_words=10,
                 max_word_id=19,
                 num_classes=2):
        """Build paired inputs and similarity targets for the siamese model.

        Calls the parent `get_data` twice with the same arguments to obtain
        two (features, labels) sets. The target is 1.0 where the two label
        sets agree and -1.0 where they differ.

        Returns:
            A tuple `(x_train, y_train, x_predict)`. `x_train` is a dict
            keyed by 'words_a' and 'words_b', `y_train` is a float32 array
            of shape `(count, 1)`, and `x_predict` is the same object as
            `x_train`.
        """
        parent_get_data = super(
            DistributionStrategySiameseEmbeddingModelCorrectnessTest,
            self).get_data
        parent_args = (count, min_words, max_words, max_word_id, num_classes)

        # Two separate draws, first for side A and then for side B.
        features_a, labels_a, _ = parent_get_data(*parent_args)
        features_b, labels_b, _ = parent_get_data(*parent_args)

        y_train = np.zeros((count, 1), dtype=np.float32)
        y_train[labels_a == labels_b] = 1.0
        y_train[labels_a != labels_b] = -1.0
        # TODO(b/123360757): Add tests for using list as inputs for
        # multi-input models.
        x_train = {'words_a': features_a, 'words_b': features_b}
        x_predict = x_train

        return x_train, y_train, x_predict

    @ds_combinations.generate(
        keras_correctness_test_base.test_combinations_for_embedding_model())
    def test_siamese_embedding_model_correctness(self, distribution, use_numpy,
                                                 use_validation_data):
        # Delegates the whole check to the base-class correctness runner.
        self.run_correctness_test(
            distribution, use_numpy, use_validation_data)