示例#1
0
class DetrTask(cfg.TaskConfig):
    """Task config for the `Detr` model.

    Pairs the model config with train/validation data configs, loss settings,
    checkpoint-initialization settings and evaluation options.
    """
    # NOTE(review): on Python 3.11+, a @dataclasses.dataclass rejects
    # unhashable instance defaults (e.g. `Detr()`, `cfg.DataConfig()`) with
    # ValueError; if that applies here, use dataclasses.field(default_factory=...).
    model: Detr = Detr()
    train_data: cfg.DataConfig = cfg.DataConfig()
    validation_data: cfg.DataConfig = cfg.DataConfig()
    losses: Losses = Losses()
    # Optional checkpoint to initialize from; None by default.
    init_checkpoint: Optional[str] = None
    # Which modules to restore from `init_checkpoint`: a single name or a
    # list of names (the inline comment lists 'all' and 'backbone').
    init_checkpoint_modules: Union[str, List[str]] = 'all'  # all, backbone
    # Optional annotation file path; presumably consumed by evaluation —
    # confirm with the task implementation.
    annotation_file: Optional[str] = None
    # Whether to report metrics per category (inferred from the name).
    per_category_metrics: bool = False
示例#2
0
class TranslationConfig(cfg.TaskConfig):
  """The translation task config.

  Holds the model config, train/validation data configs, the SentencePiece
  model path used for tokenization, and evaluation options.
  """
  # NOTE(review): on Python 3.11+, a @dataclasses.dataclass rejects
  # unhashable instance defaults (e.g. `cfg.DataConfig()`) with ValueError;
  # if that applies here, use dataclasses.field(default_factory=...).
  model: ModelConfig = ModelConfig()
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
  # Tokenization: path to the SentencePiece model (empty string by default;
  # presumably must be set before use — confirm with the task code).
  sentencepiece_model_path: str = ""
  # Evaluation: whether to print translations. Defaults to None (unset); how
  # None is interpreted is up to the consuming task.
  print_translations: Optional[bool] = None
示例#3
0
class DualEncoderConfig(cfg.TaskConfig):
    """Task config for the dual-encoder task.

    Holds checkpoint/hub initialization settings, the model config, and
    train/validation data configs.
    """
    # At most one of `init_checkpoint` and `hub_module_url` can
    # be specified (an empty string means "not specified").
    init_checkpoint: str = ''
    hub_module_url: str = ''
    # Defines the concrete model config at instantiation time.
    # NOTE(review): on Python 3.11+, a @dataclasses.dataclass rejects
    # unhashable instance defaults (e.g. `ModelConfig()`) with ValueError;
    # if that applies here, use dataclasses.field(default_factory=...).
    model: ModelConfig = ModelConfig()
    train_data: cfg.DataConfig = cfg.DataConfig()
    validation_data: cfg.DataConfig = cfg.DataConfig()
示例#4
0
class MaskedLMConfig(cfg.TaskConfig):
    """Task config for the masked-LM task with a BERT pretrainer model."""
    # Default pretrainer carries one classification head named
    # 'next_sentence' (2 classes, inner_dim 768, dropout 0.1).
    # NOTE(review): on Python 3.11+, a @dataclasses.dataclass rejects
    # unhashable instance defaults (e.g. `cfg.DataConfig()`) with ValueError;
    # if that applies here, use dataclasses.field(default_factory=...).
    model: bert.PretrainerConfig = bert.PretrainerConfig(cls_heads=[
        bert.ClsHeadConfig(inner_dim=768,
                           num_classes=2,
                           dropout_rate=0.1,
                           name='next_sentence')
    ])
    train_data: cfg.DataConfig = cfg.DataConfig()
    validation_data: cfg.DataConfig = cfg.DataConfig()
示例#5
0
class BertDistillationTaskConfig(cfg.TaskConfig):
  """Defines the teacher/student model architecture and training data.

  Both the teacher and the student default to a pretrainer with a
  'mobilebert' encoder.
  """
  # NOTE(review): on Python 3.11+, a @dataclasses.dataclass rejects
  # unhashable instance defaults (e.g. `cfg.DataConfig()`) with ValueError;
  # if that applies here, use dataclasses.field(default_factory=...).
  teacher_model: bert.PretrainerConfig = bert.PretrainerConfig(
      encoder=encoders.EncoderConfig(type='mobilebert'))

  student_model: bert.PretrainerConfig = bert.PretrainerConfig(
      encoder=encoders.EncoderConfig(type='mobilebert'))
  # The path to the teacher model checkpoint or its directory.
  teacher_model_init_checkpoint: str = ''
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
class QuestionAnsweringConfig(cfg.TaskConfig):
  """Task config for the question-answering task.

  Holds checkpoint/hub initialization settings, answer-postprocessing
  options, the model config, and train/validation data configs.
  """
  # At most one of `init_checkpoint` and `hub_module_url` can be specified.
  init_checkpoint: str = ''
  hub_module_url: str = ''
  # Answer postprocessing knobs. Meanings below are inferred from the names
  # (SQuAD-style n-best / max span length / null-answer threshold) — confirm
  # with the task implementation.
  n_best_size: int = 20
  max_answer_length: int = 30
  null_score_diff_threshold: float = 0.0
  # NOTE(review): on Python 3.11+, a @dataclasses.dataclass rejects
  # unhashable instance defaults (e.g. `ModelConfig()`) with ValueError;
  # if that applies here, use dataclasses.field(default_factory=...).
  model: ModelConfig = ModelConfig()
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
class MaskedLMConfig(cfg.TaskConfig):
  """Task config for the masked-LM task with a BERT pretrainer model."""
  # Default pretrainer carries one classification head named
  # 'next_sentence' (2 classes, inner_dim 768, dropout 0.1).
  # NOTE(review): on Python 3.11+, a @dataclasses.dataclass rejects
  # unhashable instance defaults (e.g. `cfg.DataConfig()`) with ValueError;
  # if that applies here, use dataclasses.field(default_factory=...).
  model: bert.PretrainerConfig = bert.PretrainerConfig(cls_heads=[
      bert.ClsHeadConfig(
          inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
  ])
  # TODO(b/154564893): Mathematically, scale_loss should be True.
  # However, it works better with scale_loss being False.
  scale_loss: bool = False
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
示例#8
0
class SentencePredictionConfig(cfg.TaskConfig):
  """Task config for the sentence-prediction task.

  Holds checkpoint/hub initialization settings, the evaluation metric type,
  the model config, and train/validation data configs.
  """
  # At most one of `init_checkpoint` and `hub_module_url` can
  # be specified.
  init_checkpoint: str = ''
  # Whether to also initialize the CLS pooler from the checkpoint (inferred
  # from the name — confirm with the task implementation).
  init_cls_pooler: bool = False
  hub_module_url: str = ''
  # Name of the evaluation metric; defaults to 'accuracy'.
  metric_type: str = 'accuracy'
  # Defines the concrete model config at instantiation time.
  # NOTE(review): on Python 3.11+, a @dataclasses.dataclass rejects
  # unhashable instance defaults (e.g. `ModelConfig()`) with ValueError;
  # if that applies here, use dataclasses.field(default_factory=...).
  model: ModelConfig = ModelConfig()
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
class TaggingConfig(cfg.TaskConfig):
    """Task config for tagging (per-token labels; see `class_names`).

    Holds checkpoint/hub initialization settings, the model config, the label
    class names, and train/validation data configs.
    """
    # At most one of `init_checkpoint` and `hub_module_url` can be specified.
    init_checkpoint: str = ''
    hub_module_url: str = ''
    # NOTE(review): on Python 3.11+, a @dataclasses.dataclass rejects
    # unhashable instance defaults (e.g. `ModelConfig()`) with ValueError;
    # if that applies here, use dataclasses.field(default_factory=...).
    model: ModelConfig = ModelConfig()

    # The real class names, the order of which should match real label id.
    # Note that a word may be tokenized into multiple word_pieces tokens, and
    # we assume the real label id (non-negative) is assigned to the first token
    # of the word, and a negative label id is assigned to the remaining tokens.
    # The negative label id will not contribute to loss and metrics.
    class_names: Optional[List[str]] = None
    train_data: cfg.DataConfig = cfg.DataConfig()
    validation_data: cfg.DataConfig = cfg.DataConfig()
示例#10
0
    def testSegmentationInputReader(self, input_size, num_classes,
                                    num_channels):
        """Reads one batch through InputReader and checks tensor shapes.

        The image must be [batch, d0, d1, d2, num_channels] and the labels
        [batch, d0, d1, d2, num_classes], where d0..d2 come from `input_size`.
        """
        batch = 2
        data_config = cfg.DataConfig(
            input_path=self._data_path,
            global_batch_size=batch,
            is_training=False,
        )

        seg_decoder = segmentation_input_3d.Decoder()
        seg_parser = segmentation_input_3d.Parser(
            input_size=input_size,
            num_classes=num_classes,
            num_channels=num_channels,
        )

        # TFRecord-backed reader wired through the 3D segmentation
        # decoder/parser pair.
        reader_under_test = input_reader.InputReader(
            data_config,
            dataset_fn=dataset_fn.pick_dataset_fn('tfrecord'),
            decoder_fn=seg_decoder.decode,
            parser_fn=seg_parser.parse_fn(data_config.is_training))

        image, labels = next(iter(reader_under_test.read()))

        # Image and labels share the leading [batch, d0, d1, d2] dims; only
        # the trailing dim differs (channels vs. classes).
        spatial = [input_size[0], input_size[1], input_size[2]]
        self.assertEqual(list(image.numpy().shape),
                         [batch] + spatial + [num_channels])
        self.assertEqual(list(labels.numpy().shape),
                         [batch] + spatial + [num_classes])
示例#11
0
class DetectionConfig(cfg.TaskConfig):
    """The detection task config.

    Besides train/validation data configs, this holds the loss-term weights,
    the checkpoint to initialize from, class settings, model-size settings
    (encoder/decoder layers, queries, hidden size) used to make the DETR
    config, and evaluation options.
    """
    # NOTE(review): on Python 3.11+, a @dataclasses.dataclass rejects
    # unhashable instance defaults (e.g. `cfg.DataConfig()`) with ValueError;
    # if that applies here, use dataclasses.field(default_factory=...).
    train_data: cfg.DataConfig = cfg.DataConfig()
    validation_data: cfg.DataConfig = cfg.DataConfig()
    # Loss-term weights; the names suggest classification, box and GIoU
    # terms respectively — confirm with the loss implementation.
    lambda_cls: float = 1.0
    lambda_box: float = 5.0
    lambda_giou: float = 2.0

    # Checkpoint to initialize from ('' by default).
    init_ckpt: str = ''
    num_classes: int = 81  # 0: background
    # Weight for the background class (index 0 above); presumably applied in
    # the classification loss — confirm.
    background_cls_weight: float = 0.1
    num_encoder_layers: int = 6
    num_decoder_layers: int = 6

    # Make DETRConfig.
    num_queries: int = 100
    num_hidden: int = 256
    # Whether to report metrics per category (inferred from the name).
    per_category_metrics: bool = False
示例#12
0
def unified_detector() -> cfg.ExperimentConfig:
    """Build the experiment config for training the unified detector.

    Trains for 100k steps with AdamW (weight decay + gradient clipping),
    a linear warmup over the first 1k steps, then cosine decay over the
    remaining steps. Input pipeline and model are configured through Gin.
    """
    total_train_steps = 100000
    warmup_steps = 1000
    steps_per_loop = 200
    summary_interval = steps_per_loop
    checkpoint_interval = 2000

    # AdamW: weight decay applies to non-depthwise `kernel`/`weights`
    # variables; variables whose names lack 'kernel', and depthwise kernels,
    # are excluded (per the regexes below).
    adamw_params = {
        'weight_decay_rate': 0.05,
        'include_in_weight_decay': [
            '^((?!depthwise).)*(kernel|weights):0$',
        ],
        'exclude_from_weight_decay': [
            '(^((?!kernel).)*:0)|(depthwise_kernel)',
        ],
        'gradient_clip_norm': 10.,
    }
    # Cosine decay is offset by the warmup, so it spans the remaining steps.
    cosine_params = {
        'initial_learning_rate': 1e-3,
        'decay_steps': total_train_steps - warmup_steps,
        'alpha': 1e-2,
        'offset': warmup_steps,
    }
    warmup_params = {
        'warmup_learning_rate': 1e-5,
        'warmup_steps': warmup_steps,
    }
    optimizer_spec = {
        'optimizer': {'type': 'adamw', 'adamw': adamw_params},
        'learning_rate': {'type': 'cosine', 'cosine': cosine_params},
        'warmup': {'type': 'linear', 'linear': warmup_params},
    }

    return cfg.ExperimentConfig(
        # Input pipeline and model are configured through Gin.
        task=OcrTaskConfig(train_data=cfg.DataConfig(is_training=True)),
        trainer=cfg.TrainerConfig(
            train_steps=total_train_steps,
            steps_per_loop=steps_per_loop,
            summary_interval=summary_interval,
            checkpoint_interval=checkpoint_interval,
            max_to_keep=1,
            optimizer_config=optimization.OptimizationConfig(optimizer_spec),
        ),
    )
示例#13
0
class TeamsPretrainTaskConfig(cfg.TaskConfig):
    """Task config for TEAMS pretraining (`teams.TeamsPretrainerConfig`)."""
    # NOTE(review): on Python 3.11+, a @dataclasses.dataclass rejects
    # unhashable instance defaults (e.g. `cfg.DataConfig()`) with ValueError;
    # if that applies here, use dataclasses.field(default_factory=...).
    model: teams.TeamsPretrainerConfig = teams.TeamsPretrainerConfig()
    train_data: cfg.DataConfig = cfg.DataConfig()
    validation_data: cfg.DataConfig = cfg.DataConfig()
示例#14
0
class OcrTaskConfig(cfg.TaskConfig):
    """Task config for the OCR task (used by the unified-detector experiment).

    Input pipeline and model details are configured elsewhere (via Gin in
    the unified-detector experiment); this config carries the training data.
    """
    # NOTE(review): on Python 3.11+, a @dataclasses.dataclass rejects
    # unhashable instance defaults (e.g. `cfg.DataConfig()`) with ValueError;
    # if that applies here, use dataclasses.field(default_factory=...).
    train_data: cfg.DataConfig = cfg.DataConfig()
    # Whether labels must be passed into the model call (inferred from the
    # name; False by default).
    model_call_needs_labels: bool = False