Ejemplo n.º 1
0
  def _create_config_proto(self):
    """Assemble the `TPUEmbeddingConfiguration` proto for this instance.

    One `table_descriptor` entry is appended for every table found in
    `self._table_to_config_dict`; the configuration-wide fields are filled in
    afterwards from the instance attributes.

    Returns:
      The populated `elc.TPUEmbeddingConfiguration` message.
    """
    proto = elc.TPUEmbeddingConfiguration()

    for table_name, table_cfg in self._table_to_config_dict.items():
      descriptor = proto.table_descriptor.add()
      descriptor.name = table_name

      # Table shape is copied straight from the per-table config.
      descriptor.vocabulary_size = table_cfg.vocabulary_size
      descriptor.dimension = table_cfg.dimension

      # The feature count is the number of features mapped to this table.
      descriptor.num_features = len(self._table_to_features_dict[table_name])

      # Settings shared by every table: constant learning rate and the
      # gradient accumulation switch.
      optimization = descriptor.optimization_parameters
      optimization.learning_rate.constant = (
          self._optimization_parameters.learning_rate)
      if self._optimization_parameters.use_gradient_accumulation:
        accumulation_status = (
            optimization_parameters_pb2.GradientAccumulationStatus.ENABLED)
      else:
        accumulation_status = (
            optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
      optimization.gradient_accumulation_status = accumulation_status

      # The optimizer handler (defined outside this block) gets the
      # descriptor last, after the shared fields above are in place.
      self._optimizer_handler.set_optimization_parameters(descriptor)

    # Configuration-wide fields.
    proto.mode = self._mode
    proto.batch_size_per_tensor_core = self._batch_size_per_core
    proto.num_hosts = self._num_hosts
    proto.num_tensor_cores = self._num_cores
    # The sharding strategy is always DIV_DEFAULT here; it is not configurable
    # through this method.
    proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT
    proto.pipeline_execution_with_tensor_core = (
        self._pipeline_execution_with_tensor_core)

    return proto
    def test_no_truncate(self):
        # Builds a configuration whose string form is longer than the
        # loggable limit, logs it through log_tpu_embedding_configuration,
        # and checks that the last table still appears in the captured output
        # while every individual log line stays below the limit.
        max_loggable_len = 14937  # Experimentally maximum string length loggable.

        proto = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
        for index in range(500):
            descriptor = proto.table_descriptor.add()
            descriptor.name = 'table_{}'.format(index)
            descriptor.vocabulary_size = index
        proto.num_hosts = 2
        proto.num_tensor_cores = 4
        proto.batch_size_per_tensor_core = 128

        # Guard: the fixture must really be long enough to exercise the limit.
        rendered_len = len(str(proto))
        self.assertGreater(
            rendered_len,
            max_loggable_len,
            'Test sanity check: generated config should be of truncating length.'
        )

        with self.assertLogs() as captured:
            tpu_embedding_v2_utils.log_tpu_embedding_configuration(proto)

        # The final table must be present somewhere in the combined output.
        combined_output = ''.join(captured.output)
        self.assertIn('table_499', combined_output)

        # No single emitted line may reach the limit.
        for record in captured.output:
            self.assertLess(
                len(record),
                max_loggable_len,
                'Logging function lines should not be of truncating length.')