# These snippets assume the TensorFlow-internal test imports, roughly:
#   from tensorflow.python.framework import dtypes
#   from tensorflow.python.module import module
#   from tensorflow.python.ops import init_ops_v2
#   from tensorflow.python.ops import variables as tf_variables
#   from tensorflow.python.tpu import tpu_embedding_v2
#   from tensorflow.python.tpu import tpu_embedding_v2_utils
#   from tensorflow.python.tpu import tpu_strategy_util
# plus `import numpy as np`, `import functools`, `TensorShape`, and a
# checkpoint `util` module whose exact import path varies across TF versions.


def test_checkpoint_restore_before_variable_creation(self):

  class TestModule(module.Module):

    def __init__(self, initializer, rows):
      self._initializer = initializer
      self._rows = rows

      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=self._rows,
          dim=4,
          initializer=self._initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)
      optimizer = tpu_embedding_v2_utils.SGD()

      self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
          feature_config, optimizer)

    def create_embedding(self):
      # We aren't training, so batch_size here doesn't matter.
      self.tpu_embedding.build(64)

  strategy = self._get_strategy()
  with strategy.scope():
    module1 = TestModule(init_ops_v2.Ones(),
                         strategy.num_replicas_in_sync * 2)
    module1.create_embedding()

  checkpoint = util.Checkpoint(test_module=module1)
  checkpoint.save(self._get_tmpdir('restore_before_create', 'save'))

  # Reinitialize the TPU.
  strategy = self._get_strategy()

  with strategy.scope():
    module2 = TestModule(init_ops_v2.Zeros(),
                         strategy.num_replicas_in_sync * 2)
  checkpoint = util.Checkpoint(test_module=module2)
  checkpoint.restore(self._get_tmpdir('restore_before_create', 'save-1'))

  with strategy.scope():
    module2.create_embedding()

  def get_values(mid):
    return mid._variables['table']['parameters'].variables[0].numpy()

  self.assertAllClose(
      np.ones((strategy.num_replicas_in_sync * 2, 4)),
      get_values(module2.tpu_embedding))

  # Fetch the values from the TPU to check that they are the same.
  module2.tpu_embedding._retrieve_variables()

  self.assertAllClose(
      np.ones((strategy.num_replicas_in_sync * 2, 4)),
      get_values(module2.tpu_embedding))

def test_checkpoint_restore_before_variable_creation(self):
  # This test works right now because we only have one TPU host in the unit
  # test environment. Initializing from a checkpoint does not understand how
  # to pass the sharding info to the restore op right now.

  class TestModule(module.Module):

    def __init__(self, initializer, rows):
      self._initializer = initializer
      self._rows = rows

      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=self._rows,
          dim=4,
          initializer=self._initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)
      optimizer = tpu_embedding_v2_utils.SGD()

      self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
          feature_config, optimizer)

    def create_embedding(self):
      # We aren't training, so batch_size here doesn't matter.
      self.tpu_embedding.build(64)

  # We need to clear any already loaded config provided by the setUp method.
  tpu_strategy_util.initialize_tpu_system(self.resolver)

  with self.strategy.scope():
    module1 = TestModule(init_ops_v2.Ones(),
                         self.strategy.num_replicas_in_sync * 2)
    module1.create_embedding()

  checkpoint = util.Checkpoint(test_module=module1)
  checkpoint.save(_get_tmpdir('restore_before_create', 'save'))

  tpu_strategy_util.initialize_tpu_system(self.resolver)

  with self.strategy.scope():
    module2 = TestModule(init_ops_v2.Zeros(),
                         self.strategy.num_replicas_in_sync * 2)
  checkpoint = util.Checkpoint(test_module=module2)
  checkpoint.restore(_get_tmpdir('restore_before_create', 'save-1'))

  with self.strategy.scope():
    module2.create_embedding()

  def get_values(mid):
    return mid._variables['table']['parameters'].variables[0].numpy()

  self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
                      get_values(module2.tpu_embedding))

  # Fetch the values from the TPU to check that they are the same.
  module2.tpu_embedding._retrieve_variables()

  self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
                      get_values(module2.tpu_embedding))
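
# A minimal standalone sketch (an illustration under stated assumptions, not
# part of the tests above) of the "restore before variable creation" behavior
# these tests exercise, using only public TF2 eager APIs and no TPU:
# tf.train.Checkpoint defers restoring a value until the matching variable is
# actually created. The class and path names here are hypothetical.
import tensorflow as tf


class Holder(tf.Module):

  def __init__(self, fill_value):
    super().__init__()
    self._fill_value = fill_value
    self.v = None

  def build(self):
    self.v = tf.Variable(tf.fill((2, 2), self._fill_value))


m1 = Holder(1.0)
m1.build()
path = tf.train.Checkpoint(m=m1).save('/tmp/restore_demo')

m2 = Holder(0.0)
ckpt = tf.train.Checkpoint(m=m2)
ckpt.restore(path)  # m2.v doesn't exist yet, so the restore is deferred.
m2.build()          # Variable creation triggers the deferred restore.
assert (m2.v.numpy() == 1.0).all()  # Ones from m1, not m2's zeros.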

def slot_creation_fn(table, slot_names, _):
  slots = {}
  for slot in slot_names:
    # Pass a functools.partial so the shape can still be extracted from the
    # initializer when the variable is actually created (see the variant
    # below for what happens when it can't be).
    slots[slot] = tf_variables.Variable(
        name='{}_{}'.format(table.name, slot),
        initial_value=functools.partial(
            init_ops_v2.Zeros(), shape=table.shape, dtype=dtypes.float32),
        trainable=False)
  return slots

def slot_creation_fn(table, slot_names, _):
  slots = {}
  for slot in slot_names:
    # Note that we don't pass functools.partial here, so on TPU we can't
    # extract the shape. We expect the error below.
    slots[slot] = tf_variables.Variable(
        name='{}_{}'.format(table.name, slot),
        initial_value=init_ops_v2.Zeros()(shape=table.shape,
                                          dtype=dtypes.float32),
        trainable=False)
  return slots
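
# Hedged usage sketch: how a slot-creation function like the ones above is
# plugged into a mid-level optimizer. `slot_variable_creation_fn` is the
# keyword accepted by tpu_embedding_v2_utils optimizers that have slot
# variables (e.g. Adagrad); the learning rate is an arbitrary illustrative
# value.
optimizer = tpu_embedding_v2_utils.Adagrad(
    learning_rate=0.1,
    slot_variable_creation_fn=slot_creation_fn)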

def test_checkpoint_restore_before_variable_creation(self):

  class TestModule(module.Module):

    def __init__(self, initializer, rows):
      self._initializer = initializer
      self._rows = rows

    def create_embedding(self):
      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=self._rows,
          dim=4,
          initializer=self._initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)
      optimizer = tpu_embedding_v2_utils.SGD()

      # Older TPUEmbedding API: the second positional argument is batch_size.
      self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
          feature_config, self._rows, optimizer)

  # We need to clear the already loaded config provided by the setUp method.
  tpu_strategy_util.initialize_tpu_system(self.resolver)

  with self.strategy.scope():
    module1 = TestModule(init_ops_v2.Ones(),
                         self.strategy.num_replicas_in_sync * 2)
    module1.create_embedding()

  checkpoint = util.Checkpoint(test_module=module1)
  checkpoint.save(_get_tmpdir('restore_before_create', 'save'))

  tpu_strategy_util.initialize_tpu_system(self.resolver)

  with self.strategy.scope():
    module2 = TestModule(init_ops_v2.Zeros(),
                         self.strategy.num_replicas_in_sync * 2)
  checkpoint = util.Checkpoint(test_module=module2)
  checkpoint.restore(_get_tmpdir('restore_before_create', 'save-1'))

  with self.strategy.scope():
    module2.create_embedding()

  def get_values(mid):
    return mid._variables['table']['parameters'].variables[0].numpy()

  self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
                      get_values(module2.tpu_embedding))

  # Fetch the values from the TPU to check that they are the same.
  module2.tpu_embedding._retrieve_variables()

  self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
                      get_values(module2.tpu_embedding))

def tpu_embedding_config():
  feature_configs = []
  for dim, vocab, name in table_data:
    feature_configs.append(tpu_embedding_v2_utils.FeatureConfig(
        table=tpu_embedding_v2_utils.TableConfig(
            vocabulary_size=int(vocab),
            dim=int(dim),
            initializer=init_ops_v2.Zeros(),
            name=name)))
  optimizer = tpu_embedding_v2_utils.Adagrad(learning_rate=0.1)
  with strategy.scope():
    mid_level_api = tpu_embedding_v2.TPUEmbedding(
        feature_config=feature_configs, optimizer=optimizer)
  mid_level_api._output_shapes = [TensorShape(128)] * len(feature_configs)
  return mid_level_api._create_config_proto()
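
# Hypothetical usage of the helper above; `table_data` and `strategy` are
# assumed to be defined in the enclosing scope. `_create_config_proto` is
# private API returning a TPUEmbeddingConfiguration proto, so field names
# may differ across TF versions.
config = tpu_embedding_config()
# Each row of table_data built its own TableConfig, so we expect one
# TableDescriptor per row in the serialized configuration.
assert len(config.table_descriptor) == len(table_data)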

def testZeros(self):
  self._range_test(
      init_ops_v2.Zeros(), shape=(4, 5), target_mean=0., target_max=0.)

def testZerosInvalidKwargs(self):
  init = init_ops_v2.Zeros()
  # `dtpye` is a deliberate misspelling of `dtype`; the initializer should
  # reject unknown keyword arguments.
  with self.assertRaisesWithLiteralMatch(
      TypeError, r"Unknown keyword arguments: dtpye"):
    init((2, 2), dtpye=dtypes.float32)

def testZerosPartition(self):
  init = init_ops_v2.Zeros()
  self._partition_test(init)

def testZerosInvalidKwargs(self):
  init = init_ops_v2.Zeros()
  # `dtpye` is again a deliberate misspelling; this variant of the test
  # matches a different error message, hence assertRaisesRegex instead of a
  # literal match.
  with self.assertRaisesRegex(
      TypeError, r"Keyword argument should be one of .* Received: dtpye"):
    init((2, 2), dtpye=dtypes.float32)
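
# For reference, a minimal sketch of the callable-initializer API exercised
# by the tests above: an init_ops_v2 initializer instance is called with a
# shape and an optional dtype and returns a tensor of that shape, filled with
# the initializer's values (all zeros here).
init = init_ops_v2.Zeros()
values = init(shape=(2, 3), dtype=dtypes.float32)  # 2x3 tensor of zeros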