def _create_config_proto(self):
  """Create `TPUEmbeddingConfiguration`.

  Appends one `table_descriptor` entry per table in
  `self._table_to_config_dict`. Each entry gets:
    * the table name;
    * the vocabulary size and dimension from that table's config;
    * the number of features listed for the table in
      `self._table_to_features_dict`;
    * the shared learning rate and gradient-accumulation status from
      `self._optimization_parameters`.

  After that the descriptor is handed to
  `self._optimizer_handler.set_optimization_parameters`. Finally the top-level
  fields are populated: mode, per-core batch size, host and core counts,
  sharding strategy and the pipelining flag.

  Returns:
    The populated `elc.TPUEmbeddingConfiguration` message.
  """
  config_proto = elc.TPUEmbeddingConfiguration()

  for table_name, table_config in self._table_to_config_dict.items():
    descriptor = config_proto.table_descriptor.add()
    descriptor.name = table_name
    descriptor.vocabulary_size = table_config.vocabulary_size
    descriptor.dimension = table_config.dimension
    descriptor.num_features = len(self._table_to_features_dict[table_name])

    # Shared (not per-table) optimization settings.
    descriptor_opt = descriptor.optimization_parameters
    descriptor_opt.learning_rate.constant = (
        self._optimization_parameters.learning_rate)
    accumulation_status = (
        optimization_parameters_pb2.GradientAccumulationStatus)
    if self._optimization_parameters.use_gradient_accumulation:
      descriptor_opt.gradient_accumulation_status = (
          accumulation_status.ENABLED)
    else:
      descriptor_opt.gradient_accumulation_status = (
          accumulation_status.DISABLED)

    # Optimizer-specific parameters are filled in by the handler.
    self._optimizer_handler.set_optimization_parameters(descriptor)

  config_proto.mode = self._mode
  config_proto.batch_size_per_tensor_core = self._batch_size_per_core
  config_proto.num_hosts = self._num_hosts
  config_proto.num_tensor_cores = self._num_cores
  config_proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT
  config_proto.pipeline_execution_with_tensor_core = (
      self._pipeline_execution_with_tensor_core)
  return config_proto
def test_no_truncate(self):
  """A very large embedding config must be logged without truncated lines.

  Builds a `TPUEmbeddingConfiguration` with 500 table descriptors. It first
  checks that the config's text form is longer than the threshold length.
  It then checks two things about the output of
  `tpu_embedding_v2_utils.log_tpu_embedding_configuration`: the last table
  (`table_499`) appears in it, and every captured log line is shorter than
  that threshold.
  """
  # Empirically determined maximum length of a single loggable string.
  max_loggable_length = 14937

  config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
  config.num_hosts = 2
  config.num_tensor_cores = 4
  config.batch_size_per_tensor_core = 128
  for table_index in range(500):
    descriptor = config.table_descriptor.add()
    descriptor.name = 'table_{}'.format(table_index)
    descriptor.vocabulary_size = table_index

  # Guard against the fixture silently shrinking below the threshold.
  self.assertGreater(
      len(str(config)), max_loggable_length,
      'Test sanity check: generated config should be of truncating length.')

  with self.assertLogs() as captured:
    tpu_embedding_v2_utils.log_tpu_embedding_configuration(config)

  combined_output = ''.join(captured.output)
  self.assertIn('table_499', combined_output)
  for emitted_line in captured.output:
    self.assertLess(
        len(emitted_line), max_loggable_length,
        'Logging function lines should not be of truncating length.')