コード例 #1
0
  def test_checkpoint_restore_before_variable_creation(self):
    """Restores a checkpoint into a TPUEmbedding before its variables are built.

    module1 is built with a Ones initializer and saved. module2 is constructed
    with a Zeros initializer, the checkpoint restore is requested, and only
    afterwards is module2.create_embedding() (which calls build()) run. The
    test then asserts that the table holds ones -- the checkpointed values --
    both before and after the values are fetched back from the TPU with
    `_retrieve_variables()`.
    """

    class TestModule(module.Module):
      """Holds a TPUEmbedding with a single table and a single feature."""

      def __init__(self, initializer, rows):
        self._initializer = initializer
        self._rows = rows

        # A single table with `rows` entries of width 4, summed by the
        # combiner, read by a single feature.
        table = tpu_embedding_v2_utils.TableConfig(
            vocabulary_size=self._rows,
            dim=4,
            initializer=self._initializer,
            combiner='sum',
            name='table')
        feature_config = (tpu_embedding_v2_utils.FeatureConfig(
            table=table, name='feature'),)
        optimizer = tpu_embedding_v2_utils.SGD()

        # The TPUEmbedding object exists from construction on; its variables
        # are only created later, by create_embedding().
        self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
            feature_config, optimizer)

      def create_embedding(self):
        """Builds the embedding, creating its variables."""
        # We aren't training so batch_size here doesn't matter.
        self.tpu_embedding.build(64)

    strategy = self._get_strategy()
    # Source module: table initialized to ones, built and saved below.
    with strategy.scope():
      module1 = TestModule(init_ops_v2.Ones(),
                           strategy.num_replicas_in_sync * 2)
      module1.create_embedding()

    checkpoint = util.Checkpoint(test_module=module1)
    checkpoint.save(self._get_tmpdir('restore_before_create', 'save'))

    # Reinitialize the tpu (per this comment, obtaining a fresh strategy does
    # this; _get_strategy itself is not shown here).
    strategy = self._get_strategy()

    # Target module: initializer is Zeros, so seeing ones later proves the
    # values came from the checkpoint. Its variables are NOT built yet.
    with strategy.scope():
      module2 = TestModule(init_ops_v2.Zeros(),
                           strategy.num_replicas_in_sync * 2)

    # Same `test_module` key as the save, so the restore can match module1's
    # saved objects. 'save-1' assumes the first checkpoint.save() above wrote
    # the prefix 'save' with suffix '-1' -- confirm if the naming changes.
    checkpoint = util.Checkpoint(test_module=module2)
    checkpoint.restore(self._get_tmpdir('restore_before_create', 'save-1'))

    # Variables are created only now, after the restore was requested.
    with strategy.scope():
      module2.create_embedding()

    def get_values(mid):
      """Returns the table's 'parameters' value (first entry of .variables)."""
      # NOTE(review): `.variables[0]` presumably selects the first/only shard;
      # this relies on private TPUEmbedding internals.
      return mid._variables['table']['parameters'].variables[0].numpy()

    self.assertAllClose(
        np.ones((strategy.num_replicas_in_sync * 2, 4)),
        get_values(module2.tpu_embedding))

    # Fetch the values from the TPU to check that they are the same.
    module2.tpu_embedding._retrieve_variables()

    self.assertAllClose(
        np.ones((strategy.num_replicas_in_sync * 2, 4)),
        get_values(module2.tpu_embedding))
コード例 #2
0
  def test_checkpoint_restore_before_variable_creation(self):
    """Restores a checkpoint into a TPUEmbedding before its variables are built.

    module1 (Ones initializer) is built and saved. The TPU system is then
    re-initialized, and module2 (Zeros initializer) is constructed. The
    checkpoint restore is requested before module2.create_embedding() builds
    the variables. The assertions check that module2's table holds ones, both
    before and after the values are fetched back from the TPU.
    """
    # This test works right now because we only have one TPU host in the unit
    # environment. Initializing from checkpoint does not understand how to
    # pass the sharding info to the restore op right now.

    class TestModule(module.Module):
      """Holds a TPUEmbedding with a single table and a single feature."""

      def __init__(self, initializer, rows):
        self._initializer = initializer
        self._rows = rows

        # A single table with `rows` entries of width 4, read by one feature.
        table = tpu_embedding_v2_utils.TableConfig(
            vocabulary_size=self._rows, dim=4, initializer=self._initializer,
            combiner='sum', name='table')
        feature_config = (tpu_embedding_v2_utils.FeatureConfig(
            table=table, name='feature'),)
        optimizer = tpu_embedding_v2_utils.SGD()

        # Constructed eagerly; variables are created later by build().
        self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
            feature_config, optimizer)

      def create_embedding(self):
        """Builds the embedding, creating its variables."""
        # We aren't training so batch_size here doesn't matter.
        self.tpu_embedding.build(64)

    # We need to clear any already-loaded config provided by the setUp method.
    tpu_strategy_util.initialize_tpu_system(self.resolver)

    # Source module: table initialized to ones, built and saved below.
    with self.strategy.scope():
      module1 = TestModule(init_ops_v2.Ones(),
                           self.strategy.num_replicas_in_sync * 2)
      module1.create_embedding()

    checkpoint = util.Checkpoint(test_module=module1)
    checkpoint.save(_get_tmpdir('restore_before_create', 'save'))

    # Re-initialize again -- presumably to clear the config loaded by module1
    # so that module2 starts from a clean TPU system.
    tpu_strategy_util.initialize_tpu_system(self.resolver)

    # Target module: Zeros initializer, variables NOT built yet.
    with self.strategy.scope():
      module2 = TestModule(init_ops_v2.Zeros(),
                           self.strategy.num_replicas_in_sync * 2)

    # Same `test_module` key as the save. 'save-1' assumes the save above
    # produced the prefix 'save' with suffix '-1' -- confirm if naming changes.
    checkpoint = util.Checkpoint(test_module=module2)
    checkpoint.restore(_get_tmpdir('restore_before_create', 'save-1'))

    # Variables are created only now, after the restore was requested.
    with self.strategy.scope():
      module2.create_embedding()

    def get_values(mid):
      """Returns the table's 'parameters' value (first entry of .variables)."""
      # NOTE(review): relies on private TPUEmbedding internals; `.variables[0]`
      # presumably selects the first/only shard.
      return mid._variables['table']['parameters'].variables[0].numpy()

    self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
                        get_values(module2.tpu_embedding))

    # Fetch the values from the TPU to check that they are the same.
    module2.tpu_embedding._retrieve_variables()

    self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
                        get_values(module2.tpu_embedding))
コード例 #3
0
 def slot_creation_fn(table, slot_names, _):
   """Creates one zero-initialized, non-trainable variable per slot name.

   The initial value is passed as a `functools.partial` over `Zeros()`, so the
   variable receives a callable carrying `shape`/`dtype` instead of an
   already-computed tensor.

   Args:
     table: the table whose `name` prefixes each variable name and whose
       `shape` is the shape of every slot variable.
     slot_names: the slot names to create variables for.
     _: unused.

   Returns:
     A dict mapping each slot name to its newly created variable.
   """
   return {
       slot_name: tf_variables.Variable(
           name='{}_{}'.format(table.name, slot_name),
           initial_value=functools.partial(
               init_ops_v2.Zeros(), shape=table.shape, dtype=dtypes.float32),
           trainable=False)
       for slot_name in slot_names
   }
コード例 #4
0
 def slot_creation_fn(table, slot_names, _):
   """Creates one zero-valued, non-trainable variable per slot name.

   Unlike a `functools.partial` initializer, the zeros tensor here is computed
   eagerly. Per the original note, on TPU the shape then cannot be extracted,
   and the surrounding test expects the resulting error.

   Args:
     table: the table whose `name` prefixes each variable name and whose
       `shape` is used to compute the zeros.
     slot_names: the slot names to create variables for.
     _: unused.

   Returns:
     A dict mapping each slot name to its newly created variable.
   """
   # Deliberately NOT wrapped in functools.partial -- the error this causes on
   # TPU is what the caller is testing for.
   return {
       slot_name: tf_variables.Variable(
           name='{}_{}'.format(table.name, slot_name),
           initial_value=init_ops_v2.Zeros()(
               shape=table.shape, dtype=dtypes.float32),
           trainable=False)
       for slot_name in slot_names
   }
コード例 #5
0
  def test_checkpoint_restore_before_variable_creation(self):
    """Restores a checkpoint before the TPUEmbedding object even exists.

    In this variant, TestModule only creates its TPUEmbedding inside
    create_embedding(). When checkpoint.restore() is called, module2 has no
    `tpu_embedding` attribute yet. The assertions check that once
    create_embedding() runs, the table holds the checkpointed ones rather
    than module2's Zeros initial values.
    """

    class TestModule(module.Module):
      """Lazily creates a single-table, single-feature TPUEmbedding."""

      def __init__(self, initializer, rows):
        self._initializer = initializer
        self._rows = rows

      def create_embedding(self):
        """Constructs the TPUEmbedding (and its variables) on demand."""
        table = tpu_embedding_v2_utils.TableConfig(
            vocabulary_size=self._rows, dim=4, initializer=self._initializer,
            combiner='sum', name='table')
        feature_config = (tpu_embedding_v2_utils.FeatureConfig(
            table=table, name='feature'),)
        optimizer = tpu_embedding_v2_utils.SGD()

        # NOTE(review): the second positional argument is self._rows --
        # presumably the batch size in this version of the TPUEmbedding API
        # (not exercised since nothing is trained); confirm.
        self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
            feature_config, self._rows, optimizer)

    # We need to clear the already loaded config provided by the setUp method.
    tpu_strategy_util.initialize_tpu_system(self.resolver)

    # Source module: table initialized to ones, created and saved below.
    with self.strategy.scope():
      module1 = TestModule(init_ops_v2.Ones(),
                           self.strategy.num_replicas_in_sync * 2)
      module1.create_embedding()

    checkpoint = util.Checkpoint(test_module=module1)
    checkpoint.save(_get_tmpdir('restore_before_create', 'save'))

    # Re-initialize again -- presumably to clear the config loaded by module1.
    tpu_strategy_util.initialize_tpu_system(self.resolver)

    # Target module: Zeros initializer; no TPUEmbedding attribute exists yet.
    with self.strategy.scope():
      module2 = TestModule(init_ops_v2.Zeros(),
                           self.strategy.num_replicas_in_sync * 2)

    # Same `test_module` key as the save. 'save-1' assumes the save above
    # produced the prefix 'save' with suffix '-1' -- confirm if naming changes.
    checkpoint = util.Checkpoint(test_module=module2)
    checkpoint.restore(_get_tmpdir('restore_before_create', 'save-1'))

    # The TPUEmbedding and its variables are created only now.
    with self.strategy.scope():
      module2.create_embedding()

    def get_values(mid):
      """Returns the table's 'parameters' value (first entry of .variables)."""
      # NOTE(review): relies on private TPUEmbedding internals; `.variables[0]`
      # presumably selects the first/only shard.
      return mid._variables['table']['parameters'].variables[0].numpy()

    self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
                        get_values(module2.tpu_embedding))

    # Fetch the values from the TPU to check that they are the same.
    module2.tpu_embedding._retrieve_variables()

    self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
                        get_values(module2.tpu_embedding))
コード例 #6
0
 def tpu_embedding_config():
   """Returns the TPU embedding config proto for the tables in `table_data`.

   Each `(dim, vocab, name)` entry of the enclosing `table_data` becomes a
   zero-initialized table with its own feature. The embedding uses Adagrad
   (learning rate 0.1) and is constructed under the enclosing `strategy`.
   """
   feature_configs = [
       tpu_embedding_v2_utils.FeatureConfig(
           table=tpu_embedding_v2_utils.TableConfig(
               vocabulary_size=int(vocab),
               dim=int(dim),
               initializer=init_ops_v2.Zeros(),
               name=name))
       for dim, vocab, name in table_data
   ]
   adagrad = tpu_embedding_v2_utils.Adagrad(learning_rate=0.1)
   with strategy.scope():
     embedding = tpu_embedding_v2.TPUEmbedding(
         feature_config=feature_configs,
         optimizer=adagrad)
   # NOTE(review): sets a private attribute to one shared TensorShape(128) per
   # feature -- presumably standing in for the output shapes that build() would
   # set, so the config proto can be produced without building; confirm.
   shared_shape = TensorShape(128)
   embedding._output_shapes = [shared_shape] * len(feature_configs)
   return embedding._create_config_proto()
コード例 #7
0
ファイル: init_ops_v2_test.py プロジェクト: chrisvon62/AiBot
 def testZeros(self):
   """Runs the shared range check on Zeros() with target mean and max of 0."""
   zeros_init = init_ops_v2.Zeros()
   self._range_test(
       zeros_init,
       shape=(4, 5),
       target_mean=0.,
       target_max=0.)
コード例 #8
0
ファイル: init_ops_v2_test.py プロジェクト: chrisvon62/AiBot
 def testZerosInvalidKwargs(self):
   """Calling Zeros() with an unknown keyword raises a literal TypeError."""
   zeros_init = init_ops_v2.Zeros()
   expected_message = r"Unknown keyword arguments: dtpye"
   with self.assertRaisesWithLiteralMatch(TypeError, expected_message):
     # `dtpye` is a deliberate misspelling of `dtype`.
     zeros_init((2, 2), dtpye=dtypes.float32)
コード例 #9
0
ファイル: init_ops_v2_test.py プロジェクト: chrisvon62/AiBot
 def testZerosPartition(self):
   """Runs the shared partition test against a fresh Zeros() initializer."""
   self._partition_test(init_ops_v2.Zeros())
コード例 #10
0
 def testZerosInvalidKwargs(self):
     """Calling Zeros() with an unknown keyword raises a matching TypeError."""
     zeros_init = init_ops_v2.Zeros()
     message_pattern = r"Keyword argument should be one of .* Received: dtpye"
     with self.assertRaisesRegex(TypeError, message_pattern):
         # `dtpye` is a deliberate misspelling of `dtype`.
         zeros_init((2, 2), dtpye=dtypes.float32)