def test_cpu_multiple_creation(self):
    """Two TPUEmbedding objects sharing one config can both build on CPU."""
    shared_config = (
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_user, name='friends', max_sequence_length=2),)
    sgd = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    embedding_one, embedding_two = [
        tpu_embedding_v2.TPUEmbedding(feature_config=shared_config,
                                      optimizer=sgd)
        for _ in range(2)
    ]
    # Outside a TPU strategy, building a second object must not conflict
    # with the first one.
    for embedding in (embedding_one, embedding_two):
        embedding.build()
 def test_optimizer_with_slot_creation_fn_non_partial(self):
   """A slot initializer given as a concrete tensor cannot be used on TPU."""

   def make_slots(table, slot_names, _):
     # The initial value is an already-evaluated tensor rather than a
     # functools.partial, so the TPU path cannot extract the shape and
     # build() is expected to raise below.
     return {
         slot: tf_variables.Variable(
             name='{}_{}'.format(table.name, slot),
             initial_value=init_ops_v2.Zeros()(shape=table.shape,
                                               dtype=dtypes.float32),
             trainable=False)
         for slot in slot_names
     }

   adagrad = tpu_embedding_v2_utils.Adagrad(
       learning_rate=0.1, slot_variable_creation_fn=make_slots)
   with self._get_strategy().scope():
     api = tpu_embedding_v2.TPUEmbedding(
         feature_config=self.feature_config, optimizer=adagrad)
     with self.assertRaisesRegex(ValueError,
                                 'Unable to extract initializer function'):
       # Nothing is executed, so the batch size passed here is arbitrary.
       api.build(self.batch_size)
 def test_unsupported_optimizer(self):
   """TPUEmbedding rejects an optimizer object from the v1 `tpu_embedding`."""
   expected_message = 'is an unsupported optimizer class.'
   with self.assertRaisesRegex(ValueError, expected_message):
     with self._get_strategy().scope():
       legacy_optimizer = tpu_embedding.AdagradParameters(learning_rate=0.1)
       tpu_embedding_v2.TPUEmbedding(self.feature_config, legacy_optimizer)
    def _create_mid_level(self, optimizer=None):
        """Return a `TPUEmbedding` over `self.feature_config`.

        Args:
            optimizer: Optional optimizer. When None, SGD with learning rate
                0.1 is used.
        """
        chosen = (optimizer if optimizer is not None
                  else tpu_embedding_v2_utils.SGD(learning_rate=0.1))
        return tpu_embedding_v2.TPUEmbedding(
            feature_config=self.feature_config, optimizer=chosen)
Example #5
0
def create_mid_level(optimizer=None):
    """Return a `TPUEmbedding` built from enclosing-scope settings.

    `feature_config` and `batch_size` are free variables here. They must be
    defined in an enclosing or module scope.

    Args:
        optimizer: Optional optimizer. When None, SGD with learning rate 0.1
            is created.
    """
    chosen_optimizer = (tpu_embedding_v2_utils.SGD(learning_rate=0.1)
                        if optimizer is None else optimizer)
    return tpu_embedding_v2.TPUEmbedding(
        feature_config=feature_config,
        batch_size=batch_size,
        optimizer=chosen_optimizer)
Example #6
0
  def build_mid_level(self, embedding_values, optimizer,
                      initialize_tpu_embedding=True):
    """Creates an embedding api object initialized to embedding_values.

    Args:
      embedding_values: Value wrapped in `init_ops_v2.Constant` to initialize
        the single 4-wide 'table' with `self.num_rows` rows.
      optimizer: Optimizer passed to `TPUEmbedding`.
      initialize_tpu_embedding: When False, `build` runs with
        `tpu.initialize_system_for_tpu_embedding` temporarily replaced by a
        no-op. This lets a second object (with its own variables) be created
        without initializing the TPU again.

    Returns:
      The built `TPUEmbedding` object.
    """
    initializer = init_ops_v2.Constant(embedding_values)

    table = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=self.num_rows, dim=4, initializer=initializer,
        combiner='sum', name='table')
    feature_config = (tpu_embedding_v2_utils.FeatureConfig(
        table=table, name='feature'),)

    mid_level = tpu_embedding_v2.TPUEmbedding(
        feature_config, optimizer)

    # batch_size here does not matter as we aren't training in any of these
    # tests.
    if initialize_tpu_embedding:
      mid_level.build(64)
      return mid_level

    # Temporarily stub out TPU embedding initialization. Restore it in a
    # `finally` so a failing build() cannot leave the no-op stub installed
    # for every test that runs afterwards.
    saved_fn = tpu.initialize_system_for_tpu_embedding
    tpu.initialize_system_for_tpu_embedding = lambda x: None
    try:
      mid_level.build(64)
    finally:
      tpu.initialize_system_for_tpu_embedding = saved_fn

    return mid_level
Example #7
0
  def test_cpu_sequence_lookup_ragged(self):
    """CPU lookup of a ragged sequence feature matches a numpy reference."""
    watched = tpu_embedding_v2_utils.FeatureConfig(
        table=self.table_video, name='watched', max_sequence_length=2)
    feature_config = (watched,)
    mid_level = tpu_embedding_v2.TPUEmbedding(
        feature_config=feature_config,
        optimizer=tpu_embedding_v2_utils.SGD(learning_rate=0.1))

    ragged_features = self._get_ragged_tensors()[:1]
    result = tpu_embedding_v2.cpu_embedding_lookup(
        ragged_features, weights=None, tables=mid_level.embedding_tables,
        feature_config=feature_config)

    # The numpy reference consumes the sparse (indices, values) form.
    as_sparse = ragged_features[0].to_sparse()
    expected = self._numpy_sequence_lookup(
        mid_level.embedding_tables[self.table_video].numpy(),
        as_sparse.indices.numpy(),
        as_sparse.values.numpy(),
        self.data_batch_size,
        watched.max_sequence_length,
        self.table_video.dim)

    self.assertAllClose(result[0], expected)
    def test_multiple_creation(self):
        """Only one TPUEmbedding may initialize the TPU for embeddings."""
        friends = tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_user, name='friends', max_sequence_length=2)
        sgd = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
        with self._get_strategy().scope():
            embeddings = [
                tpu_embedding_v2.TPUEmbedding(feature_config=friends,
                                              optimizer=sgd)
                for _ in range(2)
            ]
        first, second = embeddings

        # The first build succeeds. The second must fail because the TPU
        # already has an embedding configuration.
        first.build(64)
        with self.assertRaisesRegex(
                RuntimeError, 'TPU is already initialized for embeddings.'):
            second.build(64)
  def test_check_checkpoint_variable_names_are_same_on_cpu_and_tpu(
      self, optimizer):
    """TPU and CPU mid-level checkpoints list the same variables.

    Args:
      optimizer: Optimizer class. It is instantiated with
        `learning_rate=0.1` for both the TPU and the CPU object.
    """
    # Only the SGD case runs unconditionally. Other optimizers first go
    # through `skip_if_oss` (presumably skipped in open-source builds).
    if optimizer != tpu_embedding_v2_utils.SGD:
      self.skip_if_oss()

    strategy = self._get_strategy()
    num_rows = strategy.num_replicas_in_sync

    with strategy.scope():
      first_mid_level_contents = np.ones((num_rows, 4))
      first_mid_level_optimizer = optimizer(learning_rate=0.1)
      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_rows,
          dim=4,
          initializer=init_ops_v2.Constant(first_mid_level_contents),
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)
      first_mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config, first_mid_level_optimizer)
      first_mid_level.build(64)

    # The same feature config, built outside the strategy scope.
    cpu_mid_level = tpu_embedding_v2.TPUEmbedding(
        feature_config, optimizer(learning_rate=0.1))
    cpu_mid_level.build(64)

    def saved_variable_list(model, dirname):
      # Save `model` under `dirname` and list that checkpoint's variables.
      checkpoint = util.Checkpoint(model=model)
      checkpoint.save(self._get_tmpdir(dirname, 'save'))
      return checkpoint_utils.list_variables(self._get_tmpdir(dirname))

    tpu_variables = saved_variable_list(first_mid_level, 'save-tpu')
    cpu_variables = saved_variable_list(cpu_mid_level, 'save-cpu')
    self.assertAllEqual(tpu_variables, cpu_variables)
 def test_cpu_no_optimizer(self):
   """With `optimizer=None`, a table only gets a 'parameters' variable."""
   watched = (tpu_embedding_v2_utils.FeatureConfig(
       table=self.table_video, name='watched', max_sequence_length=2),)
   mid_level = tpu_embedding_v2.TPUEmbedding(
       feature_config=watched, batch_size=self.batch_size, optimizer=None)
   table_variables = mid_level._variables[self.table_video.name]
   self.assertEqual(list(table_variables.keys()), ['parameters'])
      def create_embedding(self):
        """Store on `self.tpu_embedding` a TPUEmbedding with one 'table'."""
        feature_config = (tpu_embedding_v2_utils.FeatureConfig(
            table=tpu_embedding_v2_utils.TableConfig(
                vocabulary_size=self._rows, dim=4,
                initializer=self._initializer, combiner='sum',
                name='table'),
            name='feature'),)
        # NOTE(review): `self._rows` is also passed as the second positional
        # argument of TPUEmbedding (the batch size in some API versions);
        # confirm this is intended.
        self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
            feature_config, self._rows, tpu_embedding_v2_utils.SGD())
Example #12
0
  def _create_mid_level(self, optimizer=None):
    """Create a `TPUEmbedding` whose batch size covers every replica.

    Args:
      optimizer: Optional optimizer. Defaults to SGD with learning rate 0.1.

    Returns:
      A `TPUEmbedding` with batch size `self.batch_size` times the number of
      replicas in the current strategy.
    """
    if optimizer is None:
      optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)

    current_strategy = distribution_strategy_context.get_strategy()
    global_batch_size = self.batch_size * current_strategy.num_replicas_in_sync
    return tpu_embedding_v2.TPUEmbedding(
        feature_config=self.feature_config,
        batch_size=global_batch_size,
        optimizer=optimizer)
Example #13
0
 def test_cpu_no_optimizer(self):
   """Building with `optimizer=None` creates only the 'parameters' slot."""
   watched = (tpu_embedding_v2_utils.FeatureConfig(
       table=self.table_video, name='watched', max_sequence_length=2),)
   mid_level = tpu_embedding_v2.TPUEmbedding(
       feature_config=watched, optimizer=None)
   # enqueue would normally trigger variable creation. It is never called
   # here, so build explicitly.
   mid_level.build()
   created = mid_level._variables[self.table_video.name]
   self.assertEqual(list(created.keys()), ['parameters'])
    def test_cpu_high_dimensional_lookup_ragged(self):
        """An output_shape of [2, 2] yields a CPU lookup of shape (2, 2, 2)."""
        friends = (tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_user, name='friends', output_shape=[2, 2]),)
        mid_level = tpu_embedding_v2.TPUEmbedding(
            feature_config=friends,
            optimizer=tpu_embedding_v2_utils.SGD(learning_rate=0.1))
        # Only the third ragged tensor is looked up.
        lookups = tpu_embedding_v2.cpu_embedding_lookup(
            self._get_ragged_tensors()[2:3],
            weights=None,
            tables=mid_level.embedding_tables,
            feature_config=friends)

        self.assertAllClose(lookups[0].shape, (2, 2, 2))
Example #15
0
 def test_cpu_sequence_lookup(self):
   """CPU lookup of a sparse sequence feature raises a ValueError."""
   watched = (tpu_embedding_v2_utils.FeatureConfig(
       table=self.table_video, name='watched', max_sequence_length=2),)
   mid_level = tpu_embedding_v2.TPUEmbedding(
       feature_config=watched,
       optimizer=tpu_embedding_v2_utils.SGD(learning_rate=0.1))
   sparse_features = tuple(self._get_sparse_tensors()[:1])
   with self.assertRaisesRegex(
       ValueError, 'Sequence features unsupported at this time.'):
     tpu_embedding_v2.cpu_embedding_lookup(
         sparse_features, weights=None, tables=mid_level.embedding_tables,
         feature_config=watched)
 def test_cpu_high_dimensional_sequence_lookup_ragged(self):
   """An output_shape of [2, 4] gives a CPU lookup of shape (2, 4, 2).

   The product of the output shape is a factor of the data batch size, and
   the quotient becomes the sequence length.
   """
   friends = (tpu_embedding_v2_utils.FeatureConfig(
       table=self.table_user, name='friends', output_shape=[2, 4]),)
   mid_level = tpu_embedding_v2.TPUEmbedding(
       feature_config=friends,
       optimizer=tpu_embedding_v2_utils.SGD(learning_rate=0.1))
   lookups = tpu_embedding_v2.cpu_embedding_lookup(
       self._get_ragged_tensors()[2:3],
       weights=None,
       tables=mid_level.embedding_tables,
       feature_config=friends)
   self.assertAllClose(lookups[0].shape, (2, 4, 2))
 def tpu_embedding_config():
   """Return the config proto for Adagrad tables described by `table_data`.

   `table_data`, `strategy` and `TensorShape` are free names from the
   enclosing scope. Each `table_data` entry is unpacked as (dim, vocab, name),
   and dim and vocab are converted with `int`.
   """
   feature_configs = [
       tpu_embedding_v2_utils.FeatureConfig(
           table=tpu_embedding_v2_utils.TableConfig(
               vocabulary_size=int(vocab), dim=int(dim),
               initializer=init_ops_v2.Zeros(), name=name))
       for dim, vocab, name in table_data
   ]
   adagrad = tpu_embedding_v2_utils.Adagrad(learning_rate=0.1)
   with strategy.scope():
     api = tpu_embedding_v2.TPUEmbedding(
         feature_config=feature_configs, optimizer=adagrad)
   # Output shapes are assigned directly (128 per feature) instead of being
   # inferred from input. This is presumably required before the config
   # proto can be generated.
   api._output_shapes = [TensorShape(128)] * len(feature_configs)
   return api._create_config_proto()
  def build_mid_level(self, embedding_values, optimizer,
                      initialize_tpu_embedding=True):
    """Creates an embedding api object initialized to embedding_values.

    Args:
      embedding_values: Initial value for the single 4-wide 'table' with
        `self.num_rows` rows.
      optimizer: Optimizer forwarded to `TPUEmbedding`.
      initialize_tpu_embedding: Forwarded unchanged to `TPUEmbedding`.
    """
    single_table = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=self.num_rows,
        dim=4,
        initializer=init_ops_v2.Constant(embedding_values),
        combiner='sum',
        name='table')
    features = (tpu_embedding_v2_utils.FeatureConfig(
        table=single_table, name='feature'),)
    # The positional 64 is the batch size. Its value is irrelevant because
    # none of these tests train.
    return tpu_embedding_v2.TPUEmbedding(
        features, 64, optimizer,
        initialize_tpu_embedding=initialize_tpu_embedding)
 def test_cpu_high_dimensional_invalid_lookup_ragged(self):
   """An output_shape whose product does not divide the batch is rejected."""
   friends = (tpu_embedding_v2_utils.FeatureConfig(
       table=self.table_user, name='friends', output_shape=[3]),)
   mid_level = tpu_embedding_v2.TPUEmbedding(
       feature_config=friends,
       optimizer=tpu_embedding_v2_utils.SGD(learning_rate=0.1))
   ragged_features = self._get_ragged_tensors()[2:3]
   expected_message = (
       'Output shape set in the FeatureConfig should be the factor')
   with self.assertRaisesRegex(ValueError, expected_message):
     tpu_embedding_v2.cpu_embedding_lookup(
         ragged_features, weights=None, tables=mid_level.embedding_tables,
         feature_config=friends)
Example #20
0
 def test_tables_with_same_name(self):
   """Two distinct tables both named 'table' are rejected."""

   def feature_on_table_named_table(feature_name):
     # Each call builds a new TableConfig that reuses the name 'table'.
     return tpu_embedding_v2_utils.FeatureConfig(
         table=tpu_embedding_v2_utils.TableConfig(
             name='table',
             vocabulary_size=4,
             dim=2,
             initializer=self.initializer),
         name=feature_name)

   with self.assertRaisesRegex(
       ValueError, 'Multiple tables with name table found.'):
     with self._get_strategy().scope():
       tpu_embedding_v2.TPUEmbedding(
           (feature_on_table_named_table('watched'),
            feature_on_table_named_table('favorited')),
           tpu_embedding_v2_utils.SGD(learning_rate=0.1))
  def test_optimizer_with_slot_creation_fn(self, use_tpu):
    """A custom slot creation fn sets Adagrad accumulators to zero.

    Args:
      use_tpu: If True, build under the TPU strategy. Otherwise build under
        the current default strategy.
    """

    def zero_slots(table, slot_names, _):
      # Wrapping the initializer in functools.partial keeps it inspectable,
      # so the TPU path can recover its shape.
      return {
          slot: tf_variables.Variable(
              name='{}_{}'.format(table.name, slot),
              initial_value=functools.partial(
                  init_ops_v2.Zeros(), shape=table.shape,
                  dtype=dtypes.float32),
              trainable=False)
          for slot in slot_names
      }

    adagrad = tpu_embedding_v2_utils.Adagrad(
        learning_rate=0.1, slot_variable_creation_fn=zero_slots)
    strategy = (self._get_strategy() if use_tpu
                else distribution_strategy_context.get_strategy())
    with strategy.scope():
      mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config=self.feature_config, optimizer=adagrad)
      # Nothing is executed, so the batch size passed to build is arbitrary.
      mid_level.build(self.batch_size)

    accumulators = [mid_level._variables[name]['accumulators']
                    for name in ('video', 'user')]
    if use_tpu:
      # The accumulators must be zero (from zero_slots) rather than the
      # optimizer's usual initial accumulator value. On TPU the contents are
      # read from the underlying table variable, and only one shard exists
      # in this test environment.
      accumulators = [accumulator.variables[0] for accumulator in accumulators]

    for accumulator, table in zip(accumulators,
                                  (self.table_video, self.table_user)):
      self.assertAllClose(accumulator.numpy(),
                          np.zeros((table.vocabulary_size, table.dim)))
Example #22
0
    def test_missing_feature(self, is_sparse):
        """Enqueue and dequeue succeed when the last sample has no ids.

        Args:
            is_sparse: If True, feed a `SparseTensor`. Otherwise feed an
                equivalent `RaggedTensor` whose final row is empty.
        """
        strategy = self._get_strategy()
        with strategy.scope():
            sgd = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
            watched = tpu_embedding_v2_utils.FeatureConfig(
                table=self.table_video, name='watched')
            mid_level_api = tpu_embedding_v2.TPUEmbedding(
                feature_config=watched, optimizer=sgd)

        # Build a sparse or ragged feature whose last sample is missing.
        kept_values = self.feature_watched_values[:-1]
        if is_sparse:
            features = sparse_tensor.SparseTensor(
                indices=self.feature_watched_indices[:-1],
                values=kept_values,
                dense_shape=[self.data_batch_size, 2])
        else:
            features = ragged_tensor.RaggedTensor.from_row_lengths(
                row_lengths=[1, 2, 2, 0], values=kept_values)

        global_batch_size = self.batch_size * strategy.num_replicas_in_sync
        dataset = (
            dataset_ops.DatasetV2.from_tensors(features)
            .unbatch()
            .repeat()
            .batch(global_batch_size, drop_remainder=True))
        input_options = distribute_lib.InputOptions(
            experimental_fetch_to_device=False)
        dataset_iter = iter(
            strategy.experimental_distribute_dataset(
                dataset, options=input_options))

        @def_function.function
        def test_fn():
            def get_activations():
                return mid_level_api.dequeue()

            mid_level_api.enqueue(next(dataset_iter), training=False)
            return strategy.run(get_activations)

        # The test passes if tracing and running raise no error.
        test_fn()
 def test_optimizer_with_slot_creation_fn_non_partial(self):
   """A concrete-tensor slot initializer makes TPUEmbedding creation fail."""

   def eager_zero_slots(table, slot_names):
     # The initial value is an evaluated tensor, not a functools.partial, so
     # on TPU the shape cannot be extracted and construction raises below.
     return {
         slot: tf_variables.Variable(
             name='{}_{}'.format(table.name, slot),
             initial_value=init_ops_v2.Zeros()(shape=table.shape,
                                               dtype=dtypes.float32),
             trainable=False)
         for slot in slot_names
     }

   adagrad = tpu_embedding_v2_utils.Adagrad(
       learning_rate=0.1, slot_variable_creation_fn=eager_zero_slots)
   strategy = self._get_strategy()
   global_batch_size = self.batch_size * strategy.num_replicas_in_sync
   with strategy.scope():
     with self.assertRaisesRegex(ValueError,
                                 'Unable to extract initializer function'):
       tpu_embedding_v2.TPUEmbedding(
           feature_config=self.feature_config,
           batch_size=global_batch_size,
           optimizer=adagrad)
Example #24
0
 def _create_mid_level(self):
   """Return a `TPUEmbedding` over `self.feature_config` using SGD(0.1)."""
   sgd = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
   mid_level = tpu_embedding_v2.TPUEmbedding(
       feature_config=self.feature_config, optimizer=sgd)
   return mid_level
  def test_model_export_cpu(self):
    """A TPU checkpoint restored on CPU can be exported as a SavedModel.

    The TPU mid-level object is checkpointed, and the checkpoint is restored
    into a CPU mid-level object. That object is exported with a serving
    signature that performs `cpu_embedding_lookup`. The reloaded signature
    must return the TPU object's initial table rows.
    """
    strategy = self._get_strategy()
    num_rows = strategy.num_replicas_in_sync

    with strategy.scope():
      first_mid_level_contents = np.ones((num_rows, 4))
      first_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      initializer = init_ops_v2.Constant(first_mid_level_contents)

      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_rows,
          dim=4,
          initializer=initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)

      first_mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config, first_mid_level_optimizer)

      first_mid_level.build(64)

    # Same config, built outside the strategy scope for CPU serving.
    cpu_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    cpu_mid_level = tpu_embedding_v2.TPUEmbedding(feature_config,
                                                  cpu_mid_level_optimizer)

    cpu_mid_level.build(64)

    first_mid_level._load_variables()

    tpu_checkpoint = util.Checkpoint(model=first_mid_level)
    # Use the path returned by save() instead of hard-coding the 'save-1'
    # suffix, which depends on the checkpoint's internal save counter.
    tpu_save_path = tpu_checkpoint.save(self._get_tmpdir('export_cpu', 'save'))

    # We restore the checkpoint of our tpu mid level onto our cpu mid level.
    cpu_checkpoint = util.Checkpoint(model=cpu_mid_level)
    cpu_checkpoint.restore(tpu_save_path)

    @def_function.function
    def serve_tensors(features):
      features = tpu_embedding_v2.cpu_embedding_lookup(
          features, None, cpu_mid_level.embedding_tables,
          cpu_mid_level._feature_config)
      return features[0]

    signatures = {
        'serving_default':
            serve_tensors.get_concrete_function((tensor_spec.TensorSpec(
                shape=(2,), dtype=dtypes.int32, name='feature'),))
    }
    save.save(
        cpu_mid_level,
        export_dir=self._get_tmpdir('export_cpu', 'exported_model'),
        signatures=signatures)

    imported = load.load(self._get_tmpdir('export_cpu', 'exported_model'))
    predict_fn = imported.signatures['serving_default']

    # Look up rows 1 and 0. The served values must equal those rows of the
    # TPU object's initial contents.
    input_feature_value = np.array([1, 0])
    input_batch = (constant_op.constant(
        input_feature_value, dtype=dtypes.int32),)
    prediction = predict_fn(*input_batch)['output_0']
    self.assertAllClose(prediction.numpy(),
                        first_mid_level_contents[input_feature_value])
  def test_sequence_embeddings(self, sparse):
    """Sequence features are updated by SGD only at non-padding positions.

    A first pass dequeues activations and applies a gradient of ones. A
    second pass only dequeues. The difference on replica 0 must match the
    expected per-row SGD updates, masked to the non-padding positions.

    Args:
      sparse: If True, data comes from `_create_sparse_dataset`. Otherwise
        it comes from `_create_ragged_dataset`.
    """
    feature_config = (
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_video, name='watched',
            max_sequence_length=2),
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_video, name='favorited',
            max_sequence_length=2),
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_user, name='friends',
            max_sequence_length=3))
    optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    strategy = self._get_strategy()
    num_replicas = strategy.num_replicas_in_sync
    with strategy.scope():
      mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config=feature_config,
          optimizer=optimizer)
    # Call build here. We call 'next' outside of the tf.function and this
    # results in data where the shape of the sparse tensor is a tensor which we
    # can't tell the shape of at tracing time.
    mid_level.build(self.batch_size)
    if sparse:
      dataset = self._create_sparse_dataset(strategy)
    else:
      dataset = self._create_ragged_dataset(strategy)
    # A single batch is fetched eagerly and reused by both passes below.
    data = next(
        iter(
            strategy.experimental_distribute_dataset(
                dataset,
                options=distribute_lib.InputOptions(
                    experimental_fetch_to_device=False))))

    # First pass: look up, then apply a gradient of ones to every activation.
    @def_function.function
    def embedding_and_set_gradients(data):
      def tpu_fn():
        activations = mid_level.dequeue()
        mid_level.apply_gradients(nest.map_structure(array_ops.ones_like,
                                                     activations))
        return activations
      mid_level.enqueue(data)
      return strategy.run(tpu_fn)

    # Second pass: look up only, to observe the updated tables.
    @def_function.function
    def embedding_only(data):
      def tpu_fn():
        return mid_level.dequeue()
      mid_level.enqueue(data)
      return strategy.run(tpu_fn)

    # Only check core 0.
    before_update = self._get_replica_numpy(
        embedding_and_set_gradients(data), strategy, 0)
    after_update = self._get_replica_numpy(embedding_only(data), strategy, 0)

    # For videos table, row 0 and row 1 are looked up 3*num_replicas times as
    # they occur 3 times per replica (considering the features 0 and 1 which are
    # both looked up in the videos table).
    # Feature 0 has ids [0, 0, 1], [0, 1, 1], ... repeated over num_replicas
    # Feature 1 has ids [0, 1, 1], [0, 0, 1], ... repeated over num_replicas
    # This means that both rows 0 and 1 get a -0.1*3*num_replicas update
    # For users table, each row is looked up twice:
    # Feature 2 has ids [3, 0, 1, 2], .. repeated over num_replicas
    # This means that we get a -0.1*num_replicas update to the third feature.
    # NOTE(review): the ids listed for feature 2 occur once each per replica,
    # which matches the 0.1*num_replicas update used below; "looked up twice"
    # looks stale. Confirm against the dataset helpers.

    # In general this means that after the update, if we lookup feature 0 and 1
    # the values will be 0.3*num_replicas lower per entry and for feature 2 they
    # will be 0.1*num_replicas lower.
    # The one issue is that these lookups contain padding values.
    # For core 0, we get the first 2 elements of the 4 element batch.
    # For feature 0, the indices are [[0, 0], [1, 0], [1, 1]] with max sequence
    # length of 2, which means that [0, 1] will be 0s.
    # For feature 1, the indices are [[0, 0], [0, 1], [1, 0]] with max sequence
    # length of 2, which means that [1, 1] will be 0s.
    # For feature 2, the indices are [[0, 0], [1, 0], [1, 1], [1, 2]] with max
    # sequence length of 3, which means that [0, 1], [0, 2] will be 0s.
    # The following masks represent that so that we only apply the above updates
    # to the non-padding rows:
    masks = (
        np.array([[[1], [0]], [[1], [1]]]),
        np.array([[[1], [1]], [[1], [0]]]),
        np.array([[[1], [0], [0]], [[1], [1], [1]]]))

    per_row_update = (0.3 * num_replicas,
                      0.3 * num_replicas,
                      0.1 * num_replicas)
    # Expected values: pre-update lookups minus the masked per-row update.
    golden = tuple([before - update * mask for before, update, mask in
                    zip(before_update, per_row_update, masks)])
    self.assertAllClose(golden, after_update)
  def test_variable_learning_rate(self):
    """A callable learning rate is evaluated anew on every training step.

    The learning rate decays linearly from `starting_lr` to `ending_lr`,
    driven by `step_counter`, which is incremented after each step. The
    looked-up row receives a gradient of 1 per step, so its value must equal
    the negative running sum of the learning rates.
    """
    num_steps = 10
    num_steps_float = float(num_steps)
    starting_lr = 1.0
    ending_lr = 0.5

    strategy = self._get_strategy()
    num_replicas = strategy.num_replicas_in_sync

    # Create the mid-level API object under the TPU strategy.
    with strategy.scope():
      # `dtype` must be passed by keyword. The second positional parameter of
      # `Variable` is `trainable`, so the original positional dtype landed
      # there instead.
      step_counter = tf_variables.Variable(0.0, dtype=dtypes.float32)

      def lr_function():
        return gen_math_ops.maximum(
            ending_lr,
            starting_lr + ((ending_lr - starting_lr) * step_counter) /
            num_steps_float)

      optimizer = tpu_embedding_v2_utils.SGD(learning_rate=lr_function)
      table_config = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_replicas,
          dim=4,
          initializer=init_ops_v2.Constant(np.zeros((num_replicas, 4))),
          combiner='sum', name='table')
      mid_level_api = tpu_embedding_v2.TPUEmbedding(
          feature_config={
              'feature': tpu_embedding_v2_utils.FeatureConfig(
                  table=table_config, name='feature')},
          optimizer=optimizer)

    feature = {
        'feature': constant_op.constant([0], shape=(1, 1), dtype=dtypes.int32)
    }

    def input_fn(ctx):
      del ctx
      return dataset_ops.DatasetV2.from_tensors(feature).repeat()

    dist = strategy.distribute_datasets_from_function(
        input_fn,
        options=distribute_lib.InputOptions(experimental_fetch_to_device=False))
    dist_iter = iter(dist)

    @def_function.function
    def test_fn():
      def step():
        with backprop.GradientTape() as tape:
          activations = mid_level_api.dequeue()
          tape.watch(activations)
          result = math_ops.reduce_sum(activations['feature'])
          loss = result / num_replicas
        grads = tape.gradient(loss, activations)
        mid_level_api.apply_gradients(grads)
        return activations['feature']

      mid_level_api.enqueue(next(dist_iter), training=True)
      return strategy.run(step)

    # Run model.
    results = []
    for _ in range(num_steps):
      result = test_fn()
      results.append(self._unpack(strategy, result))
      step_counter.assign_add(1.0)

    # The table is `table_config.dim` elements wide, the per-replica batch
    # size is 1, and every lookup uses id 0.
    # The loss is the sum of the entries divided by the number of replicas.
    # Each replica's gradient for row 0 is therefore 1/num_replicas, and no
    # other row is updated. The reduced gradient is 1.
    # Learning rate schedule over num_steps steps:
    # 1.0 0.95 0.9 0.85 0.8 ...
    # With SGD and a gradient of one, row 0 of the table is
    # [0, ...] [-1.0, ...] [-1.95, ...] [-2.85, ...] ..., i.e. the negative
    # partial sums of the schedule above. Each activation reflects the table
    # before that step's update.

    learning_rates = [starting_lr - (starting_lr - ending_lr) / num_steps * j
                      for j in range(num_steps)]
    cumsum = [sum(learning_rates[0:j]) for j in range(num_steps)]
    goldens = [[[-cumsum[i]] * table_config.dim] * num_replicas
               for i in range(num_steps)]
    self.assertAllClose(results, goldens)
  def test_checkpoint_save_retrieves(self):
    """Checkpointing an object captures the values currently on the TPU.

    After a second mid-level object loads different values onto the TPU,
    checkpointing the first object must return the second object's values.
    """
    strategy = self._get_strategy()
    num_rows = strategy.num_replicas_in_sync

    def make_built_mid_level(contents):
      # One 4-wide 'table' initialized to `contents` and trained with SGD.
      sgd = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_rows,
          dim=4,
          initializer=init_ops_v2.Constant(contents),
          combiner='sum',
          name='table')
      mid_level = tpu_embedding_v2.TPUEmbedding(
          (tpu_embedding_v2_utils.FeatureConfig(table=table, name='feature'),),
          sgd)
      mid_level.build(64)
      return mid_level

    with strategy.scope():
      first_mid_level_contents = np.ones((num_rows, 4))
      first_mid_level = make_built_mid_level(first_mid_level_contents)

    # Ensure that the variables from the first model are loaded.
    first_mid_level._load_variables()

    self.assertAllClose(
        first_mid_level_contents,
        self.make_checkpoint_and_get_embedding('before_load', first_mid_level,
                                               num_rows),
        msg='Checkpoint should contain values from the first api object.')

    # Reinitialize the TPU before configuring the second object.
    tpu_strategy_util.initialize_tpu_system(self.resolver)

    with strategy.scope():
      second_mid_level_contents = np.ones((num_rows, 4)) * 2
      second_mid_level = make_built_mid_level(second_mid_level_contents)

    second_mid_level._load_variables()

    # The TPU now holds the second object's values. Checkpointing the first
    # object must retrieve those rather than its own stale variables.
    self.assertAllClose(
        second_mid_level_contents,
        self.make_checkpoint_and_get_embedding('after_load', first_mid_level,
                                               num_rows),
        msg='Checkpoint should contain values from the second api object.')
  def test_checkpoint_restore_loads(self):
    """Restoring a checkpoint loads its values onto the TPU.

    The first object's checkpoint is restored into a second object that was
    initialized with different values. Retrieving the second object's
    variables afterwards must yield the first object's values.
    """
    strategy = self._get_strategy()
    num_rows = strategy.num_replicas_in_sync

    def get_values(mid):
      # Value of the first underlying variable of the 'table' parameters.
      return ops.convert_to_tensor(
          mid._variables['table']['parameters'].variables[0])

    def build_mid_level(contents):
      # One 4-wide 'table' initialized to `contents`, trained with SGD and
      # built with an (irrelevant) batch size of 64.
      optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_rows,
          dim=4,
          initializer=init_ops_v2.Constant(contents),
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)
      mid_level = tpu_embedding_v2.TPUEmbedding(feature_config, optimizer)
      mid_level.build(64)
      return mid_level

    with strategy.scope():
      first_mid_level_contents = np.ones((num_rows, 4))
      first_mid_level = build_mid_level(first_mid_level_contents)

    first_mid_level._load_variables()

    first_checkpoint = util.Checkpoint(model=first_mid_level)
    # Keep the path returned by save() rather than assuming the 'save-1'
    # suffix, which depends on the checkpoint's internal save counter.
    first_save_path = first_checkpoint.save(self._get_tmpdir('restore', 'save'))

    tpu_strategy_util.initialize_tpu_system(self.resolver)

    with strategy.scope():
      second_mid_level_contents = np.ones((num_rows, 4)) * 2
      second_mid_level = build_mid_level(second_mid_level_contents)

    second_mid_level._load_variables()

    self.assertAllClose(
        second_mid_level_contents,
        get_values(second_mid_level),
        msg='Second mid level api should contain its initial values.',
    )
    # We restore the checkpoint of our first model into our second model.
    # This should load the first mid level API object onto the TPU.
    second_checkpoint = util.Checkpoint(model=second_mid_level)
    second_checkpoint.restore(first_save_path)

    # Call retrieve here as a way to check what the TPU contains.
    # Calling the retrieve ops directly might make for a cleaner separation of
    # test and module, though.
    second_mid_level._retrieve_variables()

    self.assertAllClose(
        first_mid_level_contents,
        get_values(second_mid_level),
        msg='Second mid level api should have retrieved the first model values.'
    )