def _CreateLayerVariables(self):
  super()._CreateLayerVariables()
  p = self.params

  load_op_list = []
  retrieve_op_list = []

  # At the feature level, track which are associated
  # with "sequence embeddings".
  self._sequence_features = {}

  if py_utils.use_tpu():
    num_cores = self.cluster.params.worker.tpus_per_replica
    global_batch_size = (
        self.params.batch_size * self.cluster.num_splits_per_client)
    table_to_config_dict = {}
    feature_to_config_dict = {}
    for table in self.tables:
      table_to_config_dict[table.table_name] = table.table_config
      load_op_list += table.load_op_list
      retrieve_op_list += table.retrieve_op_list
      for feature in table.input_keys:
        if table.max_sequence_length > 0:
          self._sequence_features[feature] = True
        feature_to_config_dict[feature] = tpu_embedding_lib.FeatureConfig(
            table.table_name, max_sequence_length=table.max_sequence_length)

    tpu_embedding = self._tpu_embedding_collection.tpu_embedding
    if tpu_embedding:
      self._CheckTPUEmbeddingConfig(tpu_embedding, table_to_config_dict,
                                    feature_to_config_dict, global_batch_size)
      tf.logging.info('TPUEmbedding API singleton already exists, reusing')
      self._tpu_embedding = tpu_embedding
    else:
      tf.logging.info('adding load and retrieve ops to collection.')
      self._tpu_embedding_collection.AddLoadOps(load_op_list)
      self._tpu_embedding_collection.AddRetrieveOps(retrieve_op_list)

      mode = tpu_embedding_lib.TRAINING
      device_config = tpu_embedding_lib.DeviceConfig(
          num_cores=num_cores,
          num_hosts=self.params.tables[0].num_tpu_hosts,
          job_name=self.cluster.params.worker.name)
      self._tpu_embedding = tpu_embedding_lib.TPUEmbedding(
          table_to_config_dict,
          feature_to_config_dict,
          global_batch_size,
          mode,
          master=None,
          pipeline_execution_with_tensor_core=(
              self.params.pipeline_execution_with_tensor_core),
          partition_strategy=p.partition_strategy,
          device_config=device_config)
      self._tpu_embedding_collection.tpu_embedding = self._tpu_embedding
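# NOTE: `_CheckTPUEmbeddingConfig` is called above (and in the later variant
# below) but is not defined in this excerpt. The following is only a minimal
# sketch of the contract implied by the call site: when an existing
# TPUEmbedding singleton is reused, it should have been built from the same
# tables and features as the current layer. The attribute names
# `table_to_config_dict` / `feature_to_config_dict` on the singleton are an
# assumption; the real check may compare more (e.g. batch size, per-table
# optimizer settings).
def _CheckTPUEmbeddingConfig(self, tpu_embedding, table_to_config_dict,
                             feature_to_config_dict, global_batch_size):
  """Sketch: sanity-check a reused TPUEmbedding against this layer's config."""
  del global_batch_size  # A fuller check would also compare batch sizes.
  assert set(tpu_embedding.table_to_config_dict) == set(
      table_to_config_dict), (
          'Existing TPUEmbedding singleton was built with different tables.')
  assert set(tpu_embedding.feature_to_config_dict) == set(
      feature_to_config_dict), (
          'Existing TPUEmbedding singleton was built with different features.')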
def _CreateLayerVariables(self):
  super()._CreateLayerVariables()

  load_op_list = []
  retrieve_op_list = []

  # At the feature level, track which are associated
  # with "sequence embeddings".
  self._sequence_features = {}

  if py_utils.use_tpu():
    num_cores = self.cluster.params.worker.tpus_per_replica
    global_batch_size = (
        self.params.batch_size * self.cluster.num_splits_per_client)
    table_to_config_dict = {}
    feature_to_config_dict = {}
    for table in self.tables:
      table_to_config_dict[table.table_name] = table.table_config
      load_op_list += table.load_op_list
      retrieve_op_list += table.retrieve_op_list
      for feature in table.input_keys:
        if table.max_sequence_length > 0:
          self._sequence_features[feature] = True
        feature_to_config_dict[feature] = tpu_embedding_lib.FeatureConfig(
            table.table_name, max_sequence_length=table.max_sequence_length)

    tf.logging.info('adding load and retrieve ops to collection.')
    tf.add_to_collection(py_utils.TPU_EMBEDDING_LOAD_OPS, load_op_list)
    tf.add_to_collection(py_utils.TPU_EMBEDDING_RETRIEVE_OPS, retrieve_op_list)

    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    assert len(tpu_embedding_collection) <= 1
    if len(tpu_embedding_collection) == 1:
      tf.logging.info('TPUEmbedding API singleton already exists, reusing')
      self._tpu_embedding = tpu_embedding_collection[0]
    else:
      mode = tpu_embedding_lib.TRAINING
      device_config = tpu_embedding_lib.DeviceConfig(
          num_cores=num_cores,
          num_hosts=self.params.tables[0].num_tpu_hosts,
          job_name=self.cluster.params.worker.name)
      self._tpu_embedding = tpu_embedding_lib.TPUEmbedding(
          table_to_config_dict,
          feature_to_config_dict,
          global_batch_size,
          mode,
          master=None,
          pipeline_execution_with_tensor_core=(
              self.params.pipeline_execution_with_tensor_core),
          device_config=device_config)
      tf.add_to_collection(py_utils.TPU_EMBEDDING, self._tpu_embedding)
def _BuildTpuEmbeddingApi(self):
  p = self.params

  load_op_list = []
  retrieve_op_list = []

  num_cores = self.cluster.params.worker.tpus_per_replica
  global_batch_size = (
      self.params.batch_size * self.cluster.num_splits_per_client)
  table_to_config_dict = {}
  feature_to_config_dict = {}
  for table in self.tables:
    table_to_config_dict[table.table_name] = table.table_config
    load_op_list += table.load_op_list
    retrieve_op_list += table.retrieve_op_list
    for feature in table.input_keys:
      feature_to_config_dict[feature] = tpu_embedding_lib.FeatureConfig(
          table.table_name, max_sequence_length=table.max_sequence_length)

  mode = tpu_embedding_lib.TRAINING
  device_config = tpu_embedding_lib.DeviceConfig(
      num_cores=num_cores,
      num_hosts=self.params.tables[0].num_tpu_hosts,
      job_name=self.cluster.params.worker.name)
  tpu_embedding = tpu_embedding_lib.TPUEmbedding(
      table_to_config_dict,
      feature_to_config_dict,
      global_batch_size,
      mode,
      master=None,
      pipeline_execution_with_tensor_core=(
          self.params.pipeline_execution_with_tensor_core),
      partition_strategy=p.partition_strategy,
      device_config=device_config)

  # Create dummy table variables outside any function-building scope so the
  # embedding gradients can be routed through them.
  with tf.init_scope():
    dummy_variables, dummy_variables_init = (
        tpu_embedding_gradient.create_dummy_table_variables(tpu_embedding))
  load_op_list += [dummy_variables_init]

  tf.add_to_collection(py_utils.TPU_EMBEDDING, tpu_embedding)
  tf.add_to_collection(py_utils.TPU_EMBEDDING_DUMMY_VARS, dummy_variables)
  tf.add_to_collection(py_utils.TPU_EMBEDDING_LOAD_OPS, load_op_list)
  tf.add_to_collection(py_utils.TPU_EMBEDDING_RETRIEVE_OPS, retrieve_op_list)
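# Usage sketch (an assumption, not part of the original excerpt): the builder
# above only populates TF graph collections, so a caller would invoke it once
# per graph and then read the singleton and the dummy gradient variables back
# out of those collections. The helper name `_GetOrBuildTpuEmbeddingApi` is
# hypothetical.
def _GetOrBuildTpuEmbeddingApi(self):
  """Sketch: build (if needed) and fetch the TPUEmbedding singleton."""
  if not tf.get_collection(py_utils.TPU_EMBEDDING):
    self._BuildTpuEmbeddingApi()
  self._tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
  self._dummy_variables = tf.get_collection(
      py_utils.TPU_EMBEDDING_DUMMY_VARS)[0]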
def _CreateLayerVariables(self):
  super()._CreateLayerVariables()
  p = self.params

  # At the feature level, track which are associated
  # with "sequence embeddings".
  self._sequence_features = {}

  if _ShouldUseTpu(p):
    num_cores = self.cluster.params.worker.tpus_per_replica
    global_batch_size = (
        self.params.batch_size * self.cluster.num_splits_per_client)
    table_to_config_dict = {}
    feature_to_config_dict = {}
    for table in self.tables:
      table_to_config_dict[table.table_name] = table.table_config
      for feature in table.input_keys:
        if table.max_sequence_length > 0:
          self._sequence_features[feature] = True
        feature_to_config_dict[feature] = tpu_embedding_lib.FeatureConfig(
            table.table_name, max_sequence_length=table.max_sequence_length)

    tpu_embedding = self._tpu_embedding_collection.tpu_embedding
    if tpu_embedding:
      self._CheckTPUEmbeddingConfig(tpu_embedding, table_to_config_dict,
                                    feature_to_config_dict, global_batch_size)
      tf.logging.info('TPUEmbedding API singleton already exists, reusing')
      self._tpu_embedding = tpu_embedding
    else:
      mode = tpu_embedding_lib.TRAINING
      device_config = tpu_embedding_lib.DeviceConfig(
          num_cores=num_cores,
          num_hosts=self.params.tables[0].num_tpu_hosts,
          job_name=self.cluster.params.worker.name)
      self._tpu_embedding = tpu_embedding_lib.TPUEmbedding(
          table_to_config_dict,
          feature_to_config_dict,
          global_batch_size,
          mode,
          master=None,
          pipeline_execution_with_tensor_core=(
              self.params.pipeline_execution_with_tensor_core),
          partition_strategy=p.partition_strategy,
          device_config=device_config)
      self._tpu_embedding_collection.tpu_embedding = self._tpu_embedding
      self._tpu_embedding_collection.SetGradientMultiplierSchedule(
          self.gradient_multiplier_schedule)
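# NOTE: `_ShouldUseTpu` is referenced above but not defined in this excerpt.
# A minimal sketch under the assumption that it simply wraps the same
# py_utils.use_tpu() check used by the earlier variants; a real predicate may
# also consult layer params (hence the unused `p`).
def _ShouldUseTpu(p):
  """Sketch: whether this layer should go through the TPUEmbedding API."""
  del p  # Unused in this minimal sketch.
  return py_utils.use_tpu()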