示例#1
0
def main(_) -> None:
    """Converts a SavedModel into a TFLite model and writes it to disk.

    The experiment config selected by `FLAGS.experiment` is overridden first
    by every entry of `FLAGS.config_file` (if set) and then by
    `FLAGS.params_override` (if non-empty). The result is validated, locked
    and handed to `export_tflite_lib.convert_tflite_model`; the converted
    model is written to `FLAGS.tflite_path`.

    Args:
        _: Unused positional argument supplied by the caller.
    """
    experiment_config = exp_factory.get_exp_config(FLAGS.experiment)

    # Treat an unset --config_file as "no template overrides".
    config_files = FLAGS.config_file
    if config_files is None:
        config_files = ()
    for path in config_files:
        experiment_config = hyperparams.override_params_dict(
            experiment_config, path, is_strict=True)

    if FLAGS.params_override:
        experiment_config = hyperparams.override_params_dict(
            experiment_config, FLAGS.params_override, is_strict=True)

    experiment_config.validate()
    experiment_config.lock()

    logging.info('Converting SavedModel from %s to TFLite model...',
                 FLAGS.saved_model_dir)
    converted = export_tflite_lib.convert_tflite_model(
        saved_model_dir=FLAGS.saved_model_dir,
        quant_type=FLAGS.quant_type,
        params=experiment_config,
        calibration_steps=FLAGS.calibration_steps)

    with tf.io.gfile.GFile(FLAGS.tflite_path, 'wb') as sink:
        sink.write(converted)

    logging.info('TFLite model converted and saved to %s.', FLAGS.tflite_path)
示例#2
0
def main(_):
    """Exports an inference graph using a `basnet.BASNetModule`.

    Builds the experiment config from `FLAGS.experiment`, applies the
    `--config_file` and `--params_override` overrides in that order,
    validates and locks it, then calls
    `export_saved_model_lib.export_inference_graph`.

    Args:
        _: Unused positional argument supplied by the caller.
    """

    def parse_image_size():
        # Returns a fresh list on every call, so the module and the export
        # call each receive their own list object.
        return [int(dim) for dim in FLAGS.input_image_size.split(',')]

    experiment_config = exp_factory.get_exp_config(FLAGS.experiment)
    for path in FLAGS.config_file or []:
        experiment_config = hyperparams.override_params_dict(
            experiment_config, path, is_strict=True)
    if FLAGS.params_override:
        experiment_config = hyperparams.override_params_dict(
            experiment_config, FLAGS.params_override, is_strict=True)

    experiment_config.validate()
    experiment_config.lock()

    export_image_size = parse_image_size()
    basnet_module = basnet.BASNetModule(
        params=experiment_config,
        batch_size=FLAGS.batch_size,
        input_image_size=parse_image_size())

    export_saved_model_lib.export_inference_graph(
        input_type=FLAGS.input_type,
        batch_size=FLAGS.batch_size,
        input_image_size=export_image_size,
        params=experiment_config,
        checkpoint_path=FLAGS.checkpoint_path,
        export_dir=FLAGS.export_dir,
        export_module=basnet_module,
        export_checkpoint_subdir='checkpoint',
        export_saved_model_subdir='saved_model')
示例#3
0
def main(_):
    """Exports an inference graph for the configured experiment.

    The experiment config from `_EXPERIMENT` is overridden by each
    `_CONFIG_FILE` entry and then by `_PARAMS_OVERRIDE`, validated and
    locked, and passed to `export_saved_model_lib.export_inference_graph`
    together with the remaining flag values.

    Args:
        _: Unused positional argument supplied by the caller.
    """
    experiment_config = exp_factory.get_exp_config(_EXPERIMENT.value)
    for path in _CONFIG_FILE.value or []:
        experiment_config = hyperparams.override_params_dict(
            experiment_config, path, is_strict=True)

    inline_override = _PARAMS_OVERRIDE.value
    if inline_override:
        experiment_config = hyperparams.override_params_dict(
            experiment_config, inline_override, is_strict=True)

    experiment_config.validate()
    experiment_config.lock()

    # The size flag is a comma-separated string of integers.
    image_size = [int(dim) for dim in _INPUT_IMAGE_SIZE.value.split(',')]

    export_saved_model_lib.export_inference_graph(
        input_type=_IMAGE_TYPE.value,
        # `_BATCH_SIZSE` is spelled as it is referenced here; it is defined
        # outside this block.
        batch_size=_BATCH_SIZSE.value,
        input_image_size=image_size,
        params=experiment_config,
        checkpoint_path=_CHECKPOINT_PATH.value,
        export_dir=_EXPORT_DIR.value,
        export_checkpoint_subdir=_EXPORT_CHECKPOINT_SUBDIR.value,
        export_saved_model_subdir=_EXPORT_SAVED_MODEL_SUBDIR.value,
        log_model_flops_and_params=_LOG_MODEL_FLOPS_AND_PARAMS.value,
        input_name=_INPUT_NAME.value)
示例#4
0
def parse_configuration(flags_obj):
    """Builds the final, locked experiment config from flag values.

    Overrides are applied in this order, each on top of the previous one:
      1. the registered default for `flags_obj.experiment`, with
         `runtime.tpu` set from `flags_obj.tpu`;
      2. every file in `flags_obj.config_file` (typically a template with
         the common overrides for an experiment);
      3. `flags_obj.params_override` (typically finer adjustments on top of
         the template, e.g. a different learning rate).

    Args:
        flags_obj: Object exposing `experiment`, `tpu`, `config_file` and
            `params_override` attributes.

    Returns:
        The validated and locked config object.
    """
    config = exp_factory.get_exp_config(flags_obj.experiment)
    config.override({'runtime': {'tpu': flags_obj.tpu}})

    for template_path in flags_obj.config_file or []:
        config = hyperparams.override_params_dict(
            config, template_path, is_strict=True)

    inline_override = flags_obj.params_override
    if inline_override:
        config = hyperparams.override_params_dict(
            config, inline_override, is_strict=True)

    config.validate()
    config.lock()

    printer = pprint.PrettyPrinter()
    logging.info('Final experiment parameters: %s',
                 printer.pformat(config.as_dict()))

    return config
示例#5
0
def parse_configuration(flags_obj, lock_return=True, print_return=True):
    """Builds the experiment config from flag values.

    Overrides are layered in this order:
      1. the registered default for `flags_obj.experiment`;
      2. every file in `flags_obj.config_file` (a template of common
         overrides for the experiment);
      3. `runtime.tpu` from `flags_obj.tpu`, plus the tf.data service
         address for the train and validation data when the
         `tf_data_service` flag is present and set and `params.task` is a
         `config_definitions.TaskConfig`;
      4. `flags_obj.params_override` (finer adjustments on top of the
         template, e.g. a different learning rate).

    Args:
        flags_obj: Flag container; must expose `experiment`, `config_file`,
            `tpu` and `params_override`, and support the `in` operator.
        lock_return: Whether to lock the config before returning it.
        print_return: Whether to log the final config.

    Returns:
        The validated (and optionally locked) config object.

    Raises:
        ValueError: If `flags_obj.experiment` is None.
    """
    experiment_name = flags_obj.experiment
    if experiment_name is None:
        raise ValueError('The flag --experiment must be specified.')

    params = exp_factory.get_exp_config(experiment_name)

    for template_path in flags_obj.config_file or []:
        params = hyperparams.override_params_dict(
            params, template_path, is_strict=True)

    params.override({'runtime': {'tpu': flags_obj.tpu}})

    wants_data_service = (
        'tf_data_service' in flags_obj and flags_obj.tf_data_service
        and isinstance(params.task, config_definitions.TaskConfig))
    if wants_data_service:
        service_address = flags_obj.tf_data_service
        # One separate dict per dataset split.
        params.override({
            'task': {
                split: {'tf_data_service_address': service_address}
                for split in ('train_data', 'validation_data')
            }
        })

    inline_override = flags_obj.params_override
    if inline_override:
        params = hyperparams.override_params_dict(
            params, inline_override, is_strict=True)

    params.validate()
    if lock_return:
        params.lock()

    if print_return:
        printer = pprint.PrettyPrinter()
        logging.info('Final experiment parameters:\n%s',
                     printer.pformat(params.as_dict()))

    return params
def load_model_config_file(model_config_file: str) -> Dict[str, Any]:
  """Loads a model config as a plain dict.

  The file is first applied as a strict override on a fresh
  `encoders.EncoderConfig`. If that raises `KeyError`, the file is read as
  JSON instead and mapped onto a `{'bert': {...}}` dict, accepting either
  the legacy BERT key names or the newer ones.

  Args:
    model_config_file: Path to the config file. May be empty (e.g. when
      using tf.hub), in which case an empty dict is returned.

  Returns:
    The config as a dictionary.

  Raises:
    ValueError: If the JSON config holds both the new and the legacy name
      of the same setting.
    KeyError: If the JSON config lacks a required setting.
  """
  if not model_config_file:
    return {}

  try:
    encoder_config = hyperparams.override_params_dict(
        encoders.EncoderConfig(), model_config_file, is_strict=True)
    logging.info('Load encoder_config yaml file from %s.', model_config_file)
    return encoder_config.as_dict()
  except KeyError:
    # Fall through to the JSON parsing below.
    pass

  logging.info('Load bert config json file from %s', model_config_file)
  with tf.io.gfile.GFile(model_config_file, 'r') as reader:
    config = json.loads(reader.read())

  def either(new_key, legacy_key):
    """Returns the value stored under exactly one of the two names."""
    if new_key in config and legacy_key in config:
      raise ValueError('Unexpected that both %s and %s are in config.' %
                       (new_key, legacy_key))
    if new_key in config:
      return config[new_key]
    return config[legacy_key]

  # Support both legacy bert_config attributes and the new config attributes.
  bert = {
      'attention_dropout_rate': either('attention_dropout_rate',
                                       'attention_probs_dropout_prob'),
      'dropout_rate': either('dropout_rate', 'hidden_dropout_prob'),
      'hidden_activation': either('hidden_activation', 'hidden_act'),
      'hidden_size': config['hidden_size'],
      # Optional setting: None when absent.
      'embedding_size': config.get('embedding_size'),
      'initializer_range': config['initializer_range'],
      'intermediate_size': config['intermediate_size'],
      'max_position_embeddings': config['max_position_embeddings'],
      'num_attention_heads': config['num_attention_heads'],
      'num_layers': either('num_layers', 'num_hidden_layers'),
      'type_vocab_size': config['type_vocab_size'],
      'vocab_size': config['vocab_size'],
  }
  return {'bert': bert}
示例#7
0
 def parse_config_file(self, params):
     """Applies every `--config_file` entry to `params` as a strict override.

     Args:
         params: The config object to override.

     Returns:
         The overridden config, or `params` unchanged when no config file is
         set on `self._flags_obj`.
     """
     config_files = self._flags_obj.config_file
     if not config_files:
         return params
     for path in config_files:
         params = hyperparams.override_params_dict(params,
                                                   path,
                                                   is_strict=True)
     return params
示例#8
0
def create_export_module(*, task_name: Text, config_file: Text,
                         serving_params: Dict[Text, Any]):
    """Creates an ExportModule for the registered task matching `task_name`.

    The task is found by scanning `task_factory._REGISTERED_TASK_CLS` for the
    first config class whose name contains `task_name`. A throwaway
    experiment config holding only a `task` field is overridden (non-strict)
    with `config_file`; its validation data is cleared, the task and its
    model are built, and the export module class returned by
    `lookup_export_module` is instantiated with `serving_params`.

    Args:
        task_name: Substring of the registered task config class name.
        config_file: Override applied to the temporary experiment config.
        serving_params: Keyword arguments for the export module's `Params`.

    Returns:
        The export module instance wrapping the built model.

    Raises:
        ValueError: If no registered task config class name contains
            `task_name`.
    """
    task_config_cls = None
    task_cls = None
    # pylint: disable=protected-access
    for config_cls, registered_task_cls in (
            task_factory._REGISTERED_TASK_CLS.items()):
        if task_name in config_cls.__name__:
            task_config_cls, task_cls = config_cls, registered_task_cls
            break
    # pylint: enable=protected-access
    if task_cls is None:
        raise ValueError(
            "Failed to identify the task class. The provided task "
            f"name is {task_name}")
    # TODO(hongkuny): Figure out how to separate the task config from experiments.

    @dataclasses.dataclass
    class Dummy(base_config.Config):
        # Use a factory rather than a shared default instance: a config
        # instance as a default is shared between all `Dummy` objects and is
        # rejected by `dataclasses` on Python 3.11+ when it is unhashable.
        task: task_config_cls = dataclasses.field(
            default_factory=task_config_cls)

    dummy_exp = Dummy()
    dummy_exp = hyperparams.override_params_dict(dummy_exp,
                                                 config_file,
                                                 is_strict=False)
    # Clear the validation data config before the task is built.
    dummy_exp.task.validation_data = None
    task = task_cls(dummy_exp.task)
    model = task.build_model()
    export_module_cls = lookup_export_module(task)
    params = export_module_cls.Params(**serving_params)
    return export_module_cls(params=params, model=model)
示例#9
0
def main(_):
    """Exports a classification or segmentation inference graph.

    After building, overriding, validating and locking the experiment
    config, the export module class is chosen from the type of
    `params.task`, instantiated with three input channels, and passed to
    `export_saved_model_lib.export_inference_graph`.

    Args:
        _: Unused positional argument supplied by the caller.

    Raises:
        TypeError: If `params.task` is neither an image classification nor a
            semantic segmentation task config.
    """
    params = exp_factory.get_exp_config(_EXPERIMENT.value)
    for path in _CONFIG_FILE.value or []:
        params = hyperparams.override_params_dict(
            params, path, is_strict=True)
    if _PARAMS_OVERRIDE.value:
        params = hyperparams.override_params_dict(
            params, _PARAMS_OVERRIDE.value, is_strict=True)

    params.validate()
    params.lock()

    # The size flag is a comma-separated string of integers.
    input_image_size = [
        int(dim) for dim in _INPUT_IMAGE_SIZE.value.split(',')
    ]

    task_config = params.task
    if isinstance(task_config,
                  configs.image_classification.ImageClassificationTask):
        export_module_cls = export_module.ClassificationModule
    elif isinstance(task_config,
                    configs.semantic_segmentation.SemanticSegmentationTask):
        export_module_cls = export_module.SegmentationModule
    else:
        raise TypeError(
            f'Export module for {type(params.task)} is not supported.')

    module = export_module_cls(
        params=params,
        batch_size=_BATCH_SIZSE.value,
        input_image_size=input_image_size,
        input_type=_IMAGE_TYPE.value,
        num_channels=3)

    export_saved_model_lib.export_inference_graph(
        input_type=_IMAGE_TYPE.value,
        batch_size=_BATCH_SIZSE.value,
        input_image_size=input_image_size,
        params=params,
        checkpoint_path=_CHECKPOINT_PATH.value,
        export_dir=_EXPORT_DIR.value,
        export_checkpoint_subdir=_EXPORT_CHECKPOINT_SUBDIR.value,
        export_saved_model_subdir=_EXPORT_SAVED_MODEL_SUBDIR.value,
        export_module=module,
        log_model_flops_and_params=_LOG_MODEL_FLOPS_AND_PARAMS.value,
        input_name=_INPUT_NAME.value)
示例#10
0
def main(_):
  """Exports a model to TF-Hub format from flag values.

  Builds the experiment config from `FLAGS.experiment`, applies the
  `--config_file` and `--params_override` overrides in that order, validates
  and locks it, then calls `export_model_to_tfhub`.

  Args:
    _: Unused positional argument supplied by the caller.
  """
  experiment_config = exp_factory.get_exp_config(FLAGS.experiment)
  for path in FLAGS.config_file or []:
    experiment_config = hyperparams.override_params_dict(
        experiment_config, path, is_strict=True)

  inline_override = FLAGS.params_override
  if inline_override:
    experiment_config = hyperparams.override_params_dict(
        experiment_config, inline_override, is_strict=True)

  experiment_config.validate()
  experiment_config.lock()

  # The size flag is a comma-separated string of integers.
  image_size = [int(dim) for dim in FLAGS.input_image_size.split(',')]

  export_model_to_tfhub(
      params=experiment_config,
      batch_size=FLAGS.batch_size,
      input_image_size=image_size,
      skip_logits_layer=FLAGS.skip_logits_layer,
      checkpoint_path=FLAGS.checkpoint_path,
      export_path=FLAGS.export_path)
示例#11
0
def main(_):
  """Exports a panoptic segmentation inference graph.

  The experiment config is built, overridden, validated and locked; the
  model builder and export module class are chosen from `FLAGS.model`
  ('panoptic_deeplab' or 'panoptic_maskrcnn'); the model is built for a
  three-channel input of the requested size and exported through
  `export_saved_model_lib.export_inference_graph`.

  Args:
    _: Unused positional argument supplied by the caller.

  Raises:
    ValueError: If `FLAGS.model` is not one of the supported model types.
  """

  def parse_image_size():
    # Returns a fresh list on every call.
    return [int(dim) for dim in FLAGS.input_image_size.split(',')]

  experiment_config = exp_factory.get_exp_config(FLAGS.experiment)
  for path in FLAGS.config_file or []:
    experiment_config = hyperparams.override_params_dict(
        experiment_config, path, is_strict=True)
  if FLAGS.params_override:
    experiment_config = hyperparams.override_params_dict(
        experiment_config, FLAGS.params_override, is_strict=True)

  experiment_config.validate()
  experiment_config.lock()

  image_size = parse_image_size()
  input_specs = tf.keras.layers.InputSpec(
      shape=[FLAGS.batch_size, *image_size, 3])

  if FLAGS.model == 'panoptic_deeplab':
    build_model, panoptic_module = (
        factory.build_panoptic_deeplab,
        panoptic_deeplab.PanopticSegmentationModule)
  elif FLAGS.model == 'panoptic_maskrcnn':
    build_model, panoptic_module = (
        factory.build_panoptic_maskrcnn,
        panoptic_maskrcnn.PanopticSegmentationModule)
  else:
    raise ValueError('Unsupported model type: %s' % FLAGS.model)

  model = build_model(
      input_specs=input_specs, model_config=experiment_config.task.model)
  export_module = panoptic_module(
      params=experiment_config,
      model=model,
      batch_size=FLAGS.batch_size,
      input_image_size=parse_image_size(),
      num_channels=3)

  export_saved_model_lib.export_inference_graph(
      input_type=FLAGS.input_type,
      batch_size=FLAGS.batch_size,
      input_image_size=image_size,
      params=experiment_config,
      checkpoint_path=FLAGS.checkpoint_path,
      export_dir=FLAGS.export_dir,
      export_module=export_module,
      export_checkpoint_subdir='checkpoint',
      export_saved_model_subdir='saved_model')
示例#12
0
 def parse_params_override(self, params):
   """Applies `--params_override` to `params` as a strict override.

   `--params_override` is the second level of override, typically used on top
   of the template given via `--config_file`. For example, a config file may
   define a template for training ResNet50 on ImageNet while
   `--params_override` supplies different learning rates.

   Args:
     params: The config object to override.

   Returns:
     The overridden config, or `params` unchanged when the flag is empty.
   """
   inline_override = self._flags_obj.params_override
   if not inline_override:
     return params
   return hyperparams.override_params_dict(
       params, inline_override, is_strict=True)
示例#13
0
文件: utils.py 项目: xiangww00/models
def config_override(experiment_params, flags_obj):
    """Overrides an experiment config according to flags.

    Sets `runtime.tpu_address` from `flags_obj.tpu`, then applies every
    `--config_file` entry (the template-level override) followed by
    `--params_override` (finer adjustments on top of the template, e.g.
    different learning rates). The result is validated, locked and logged.
    When the mode contains 'train' the config is also serialized into the
    model directory.

    Args:
        experiment_params: The config object to override.
        flags_obj: Object exposing `tpu`, `config_file`, `params_override`
            and `mode` attributes.

    Returns:
        The validated and locked config object.

    Raises:
        ModuleNotFoundError: If `flags_obj` has no `tpu` attribute.
    """
    if not hasattr(flags_obj, 'tpu'):
        raise ModuleNotFoundError(
            '`tpu` is not found in FLAGS. Need to load flags.py first.')

    # Change runtime.tpu to the real tpu.
    experiment_params.override({'runtime': {'tpu_address': flags_obj.tpu}})

    for template_path in flags_obj.config_file or []:
        experiment_params = hyperparams.override_params_dict(
            experiment_params, template_path, is_strict=True)

    inline_override = flags_obj.params_override
    if inline_override:
        experiment_params = hyperparams.override_params_dict(
            experiment_params, inline_override, is_strict=True)

    experiment_params.validate()
    experiment_params.lock()

    printer = pprint.PrettyPrinter()
    logging.info('Final experiment parameters: %s',
                 printer.pformat(experiment_params.as_dict()))

    model_dir = get_model_dir(experiment_params, flags_obj)
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    if flags_obj.mode is not None and 'train' in flags_obj.mode:
        serialize_config(experiment_params, model_dir)

    return experiment_params
示例#14
0
def main(argv):
    """Exports a BERT-style model or its preprocessing to TF-Hub format.

    Parses the gin files and bindings, checks that exactly one of
    `vocab_file` / `sp_model_file` is given, and dispatches on
    `FLAGS.export_type`:
      * "model" / "model_with_mlm": exports the model built from exactly one
        of `bert_config_file` / `encoder_config_file`;
      * "preprocessing": exports the preprocessing.

    Args:
        argv: Command-line arguments left over after flag parsing.

    Raises:
        app.UsageError: If extra command-line arguments are given or
            `--export_type` has an unknown value.
        ValueError: If one of the "exactly one of" flag pairs has both or
            neither flag set.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)

    # Equal truthiness means both or neither flag was provided.
    if bool(FLAGS.vocab_file) == bool(FLAGS.sp_model_file):
        raise ValueError(
            "Exactly one of `vocab_file` and `sp_model_file` "
            "can be specified, but got %s and %s." %
            (FLAGS.vocab_file, FLAGS.sp_model_file))
    do_lower_case = export_tfhub_lib.get_do_lower_case(
        FLAGS.do_lower_case, FLAGS.vocab_file, FLAGS.sp_model_file)

    export_type = FLAGS.export_type

    if export_type == "preprocessing":
        export_tfhub_lib.export_preprocessing(
            FLAGS.export_path,
            vocab_file=FLAGS.vocab_file,
            sp_model_file=FLAGS.sp_model_file,
            do_lower_case=do_lower_case,
            default_seq_length=FLAGS.default_seq_length,
            tokenize_with_offsets=FLAGS.tokenize_with_offsets,
            experimental_disable_assert=(
                FLAGS.experimental_disable_assert_in_preprocessing))
        return

    if export_type not in ("model", "model_with_mlm"):
        raise app.UsageError("Unknown value '%s' for flag --export_type" %
                             FLAGS.export_type)

    if bool(FLAGS.bert_config_file) == bool(FLAGS.encoder_config_file):
        raise ValueError(
            "Exactly one of `bert_config_file` and "
            "`encoder_config_file` can be specified, but got "
            "%s and %s." %
            (FLAGS.bert_config_file, FLAGS.encoder_config_file))

    bert_config = None
    encoder_config = None
    if FLAGS.bert_config_file:
        bert_config = configs.BertConfig.from_json_file(
            FLAGS.bert_config_file)
    else:
        encoder_config = hyperparams.override_params_dict(
            encoders.EncoderConfig(),
            FLAGS.encoder_config_file,
            is_strict=True)

    export_tfhub_lib.export_model(
        FLAGS.export_path,
        bert_config=bert_config,
        encoder_config=encoder_config,
        model_checkpoint_path=FLAGS.model_checkpoint_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=do_lower_case,
        with_mlm=FLAGS.export_type == "model_with_mlm",
        copy_pooler_dense_to_encoder=FLAGS.copy_pooler_dense_to_encoder)
示例#15
0
def config_override(params, flags_obj):
    """Overrides an experiment config according to flags.

    Sets `runtime.tpu` from `flags_obj.tpu`, then applies every
    `--config_file` entry (the template-level override) followed by
    `--params_override` (finer adjustments on top of the template, e.g.
    different learning rates). The result is validated, locked and logged.
    When the mode is set and contains 'train', the config is also serialized
    into `flags_obj.model_dir`.

    Args:
        params: The config object to override.
        flags_obj: Object exposing `tpu`, `config_file`, `params_override`,
            `model_dir` and `mode` attributes.

    Returns:
        The validated and locked config object.
    """
    # Change runtime.tpu to the real tpu.
    params.override({'runtime': {
        'tpu': flags_obj.tpu,
    }})

    # First level of override: the `--config_file` templates.
    for config_file in flags_obj.config_file or []:
        params = hyperparams.override_params_dict(params,
                                                  config_file,
                                                  is_strict=True)

    # Second level of override: `--params_override` on top of the templates.
    if flags_obj.params_override:
        params = hyperparams.override_params_dict(params,
                                                  flags_obj.params_override,
                                                  is_strict=True)

    params.validate()
    params.lock()

    pp = pprint.PrettyPrinter()
    logging.info('Final experiment parameters: %s',
                 pp.pformat(params.as_dict()))

    model_dir = flags_obj.model_dir
    mode = flags_obj.mode
    # Guard against an unset mode: `'train' in None` raises TypeError.
    if mode is not None and 'train' in mode:
        # Pure eval modes do not output yaml files. Otherwise continuous eval job
        # may race against the train job for writing the same file.
        train_utils.serialize_config(params, model_dir)

    return params
示例#16
0
def main(_):
    """Exports a 3D semantic segmentation inference graph.

    Marks `export_dir` and `checkpoint_path` as required flags, builds the
    experiment config with the `--config_file` and `--params_override`
    overrides, validates and locks it, constructs a
    `semantic_segmentation_3d.SegmentationModule` and exports it through
    `export_saved_model_lib.export_inference_graph`.

    Args:
        _: Unused positional argument supplied by the caller.
    """
    # NOTE(review): these are registered inside main; if flags have already
    # been parsed by this point, confirm the requirement is still enforced.
    for required_flag in ('export_dir', 'checkpoint_path'):
        flags.mark_flag_as_required(required_flag)

    experiment_config = exp_factory.get_exp_config(FLAGS.experiment)
    for path in FLAGS.config_file or []:
        experiment_config = hyperparams.override_params_dict(
            experiment_config, path, is_strict=True)
    if FLAGS.params_override:
        experiment_config = hyperparams.override_params_dict(
            experiment_config, FLAGS.params_override, is_strict=True)

    experiment_config.validate()
    experiment_config.lock()

    # Unlike the 2D exporters, the size flag value is passed through as-is.
    image_size = FLAGS.input_image_size

    # NOTE(review): the module is built with batch_size=1 while the export
    # call below uses --batch_size; confirm this is intended.
    segmentation_module = semantic_segmentation_3d.SegmentationModule(
        params=experiment_config,
        batch_size=1,
        input_image_size=image_size,
        num_channels=FLAGS.num_channels)

    export_saved_model_lib.export_inference_graph(
        input_type=FLAGS.input_type,
        batch_size=FLAGS.batch_size,
        input_image_size=image_size,
        params=experiment_config,
        checkpoint_path=FLAGS.checkpoint_path,
        export_dir=FLAGS.export_dir,
        num_channels=FLAGS.num_channels,
        export_module=segmentation_module,
        export_checkpoint_subdir='checkpoint',
        export_saved_model_subdir='saved_model')
def _get_params_from_flags(flags_obj: flags.FlagValues):
    """Builds the final, locked params from flag values.

    Starts from `configs.get_config` for the lower-cased model type and
    dataset, then applies three strict overrides in order: the config file,
    the params override, and a dict of values taken directly from individual
    flags (so the per-flag values are applied last).

    Args:
        flags_obj: Flag values exposing `model_type`, `dataset`, `model_dir`,
            `mode`, `run_eagerly`, `tpu`, `data_dir`, `log_steps`,
            `config_file` and `params_override`.

    Returns:
        The validated and locked params object.
    """
    model_name = flags_obj.model_type.lower()
    dataset_name = flags_obj.dataset.lower()
    params = configs.get_config(model=model_name, dataset=dataset_name)

    # Values taken straight from individual flags.
    flags_overrides = {
        'model_dir': flags_obj.model_dir,
        'mode': flags_obj.mode,
        'model': {'name': model_name},
        'runtime': {
            'run_eagerly': flags_obj.run_eagerly,
            'tpu': flags_obj.tpu,
        },
        'train_dataset': {'data_dir': flags_obj.data_dir},
        'validation_dataset': {'data_dir': flags_obj.data_dir},
        'train': {
            'time_history': {'log_steps': flags_obj.log_steps},
        },
    }

    printer = pprint.PrettyPrinter()
    logging.info('Base params: %s', printer.pformat(params.as_dict()))

    override_layers = (
        flags_obj.config_file,
        flags_obj.params_override,
        flags_overrides,
    )
    for layer in override_layers:
        logging.info('Overriding params: %s', layer)
        params = hyperparams.override_params_dict(
            params, layer, is_strict=True)

    params.validate()
    params.lock()

    logging.info('Final model parameters: %s',
                 printer.pformat(params.as_dict()))
    return params
示例#18
0
def get_export_module(experiment: str,
                      batch_size: int,
                      config_files: Optional[str] = None):
    """Gets the export module matching an experiment config.

    Args:
        experiment: `str`, look up for ExperimentConfig factory methods.
        batch_size: `int`, batch size of inference.
        config_files: Path to a yaml file that overrides the experiment
            config, or an iterable of such paths applied in order. `None` or
            empty means no overrides.

    Returns:
        The export module for the experiment's task, built with the first two
        entries of `params.task.model.input_size` as the image size and three
        input channels.

    Raises:
        ValueError: If no export module is implemented for the task type.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(config_files, str):
        config_files = [config_files]

    params = exp_factory.get_exp_config(experiment)
    for config_file in config_files or []:
        params = hyperparams.override_params_dict(params,
                                                  config_file,
                                                  is_strict=True)
    params.validate()
    params.lock()

    # Obtain relevant serving object
    kwargs = dict(params=params,
                  batch_size=batch_size,
                  input_image_size=params.task.model.input_size[:2],
                  num_channels=3)

    task_config = params.task
    if isinstance(task_config,
                  configs.image_classification.ImageClassificationTask):
        export_module = image_classification.ClassificationModule(**kwargs)
    elif isinstance(task_config, (configs.retinanet.RetinaNetTask,
                                  configs.maskrcnn.MaskRCNNTask)):
        export_module = detection.DetectionModule(**kwargs)
    elif isinstance(task_config,
                    configs.semantic_segmentation.SemanticSegmentationTask):
        export_module = semantic_segmentation.SegmentationModule(**kwargs)
    elif isinstance(task_config, configs.yolo.YoloTask):
        export_module = yolo.YoloModule(**kwargs)
    elif isinstance(task_config, multi_cfg.MultiTaskConfig):
        export_module = multitask.MultitaskModule(**kwargs)
    else:
        raise ValueError('Export module not implemented for {} task.'.format(
            type(params.task)))

    return export_module