Example #1
File: module_v2.py Project: syash5/hub
def resolve(handle):
    """Resolves a module handle into a path.

  This function works both for plain TF2 SavedModels and the legacy TF1 Hub
  format.

  Resolves a module handle into a path by downloading and caching it in the
  location specified by TF_HUB_CACHE_DIR if needed.

  Currently, three types of module handles are supported:
    1) Smart URL resolvers such as tfhub.dev, e.g.:
       https://tfhub.dev/google/nnlm-en-dim128/1.
    2) A directory on a file system supported by TensorFlow containing module
       files. This may include a local directory (e.g. /usr/local/mymodule) or a
       Google Cloud Storage bucket (gs://mymodule).
    3) A URL pointing to a TGZ archive of a module, e.g.
       https://example.com/mymodule.tar.gz.

  Args:
    handle: (string) the Module handle to resolve.

  Returns:
    A string representing the Module path.
  """
    return registry.resolver(handle)
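For context, a minimal usage sketch of this function as exposed publicly through tensorflow_hub (the handle is the tfhub.dev example from the docstring; the first call needs network access and populates the cache):

import os
import tensorflow_hub as hub

# hub.resolve() wraps the registry.resolver() call shown above.
path = hub.resolve("https://tfhub.dev/google/nnlm-en-dim128/1")
print(path)              # local cache directory holding the downloaded module
print(os.listdir(path))  # e.g. tfhub_module.pb, saved_model.pb, variables/, assets/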
Example #2
def load_module_spec(path):
    """Loads a ModuleSpec from a TF Hub service or the filesystem.

  Warning: Deprecated. This belongs to the hub.Module API and TF1 Hub format.
  For TF2, switch to plain SavedModels and hub.load(); see also hub.resolve().

  THIS FUNCTION IS DEPRECATED.

  Args:
    path: string describing the location of a module. There are several
      supported path encoding schemes:
        a) A URL like "https://tfhub.dev/the/module/1" referring to tfhub.dev or
           another service implementing https://www.tensorflow.org/hub/hosting.
        b) A URL like "https://example.com/module.tar.gz" that points to a
           compressed tarball directly, as long as that web server ignores
           the query parameters added by https://www.tensorflow.org/hub/hosting.
        c) Any filesystem location of a module directory (e.g. /module_dir
           for a local filesystem). All filesystem implementations provided
           by TensorFlow are supported.
        d) Private name resolution schemes added by the maintainer of your
           local installation of the tensorflow_hub library (usually none).

  Returns:
    A ModuleSpec.

  Raises:
    ValueError: on unexpected values in the module spec.
    tf.errors.OpError: on file handling exceptions.
  """
    path = registry.resolver(path)
    return registry.loader(path)
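Since the docstring points TF2 users to hub.load() instead, here is a hedged sketch of that recommended route (the handle is illustrative; /2 is assumed to be the TF2 SavedModel release of the module used above):

import tensorflow_hub as hub

# Load the plain TF2 SavedModel directly instead of building a ModuleSpec.
embed = hub.load("https://tfhub.dev/google/nnlm-en-dim128/2")
embeddings = embed(["hello world"])  # a batch of sentence embeddings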
Example #3
def build(self):
  """Builds the class. Used for lazy initialization."""
  if self.is_built:
    return
  self.vocab_file = os.path.join(
      registry.resolver(self.uri), 'assets', 'vocab.txt')
  self.tokenizer = tokenization.FullTokenizer(self.vocab_file,
                                              self.do_lower_case)
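A hedged stand-alone sketch of the same steps build() performs: resolve the hub module, locate its vocabulary file under assets/, and construct a wordpiece tokenizer. The handle is illustrative (it is the default BERT uri used later in example #9), hub.resolve() is the public wrapper of registry.resolver(), and the import path of the BERT tokenization module is an assumption:

import os
import tensorflow_hub as hub
from official.nlp.bert import tokenization  # assumed import path for FullTokenizer

uri = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1'
vocab_file = os.path.join(hub.resolve(uri), 'assets', 'vocab.txt')
tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case=True)
print(tokenizer.tokenize('hello world'))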
Example #4
def resolve(handle):
    """Resolves a module handle into a path.

   Resolves a module handle into a path by downloading and caching it in the
   location specified by TF_HUB_CACHE_DIR if needed.

  Args:
    handle: (string) the Module handle to resolve.

  Returns:
    A string representing the Module path.
  """
    return registry.resolver(handle)
Example #5
def load_module_spec(path):
    """Loads a ModuleSpec from the filesystem.

  Args:
    path: string describing the location of a module. There are several
          supported path encoding schemes:
          a) URL location specifying an archived module
            (e.g. http://domain/module.tgz)
          b) Any filesystem location of a module directory (e.g. /module_dir
             for a local filesystem). All filesystem implementations provided
             by TensorFlow are supported.

  Returns:
    A ModuleSpec.

  Raises:
    ValueError: on unexpected values in the module spec.
    tf.OpError: on file handling exceptions.
  """
    path = registry.resolver(path)
    return registry.loader(path)
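A hedged TF1-style usage sketch for load_module_spec (assumes TF1 graph-mode behavior via tf.compat.v1 and a module published in the TF1 Hub format; the handle is illustrative):

import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

tf.disable_v2_behavior()

spec = hub.load_module_spec("https://tfhub.dev/google/nnlm-en-dim128/1")
embed = hub.Module(spec)             # hub.Module also accepts a ModuleSpec directly
embeddings = embed(["hello world"])  # symbolic tensor in the default graph

with tf.Session() as sess:
  sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
  print(sess.run(embeddings).shape)  # (1, 128)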
Example #6
File: module.py Project: jankim/hub
def load_module_spec(path):
  """Loads a ModuleSpec from the filesystem.

  Args:
    path: string describing the location of a module. There are several
          supported path encoding schemes:
          a) URL location specifying an archived module
            (e.g. http://domain/module.tgz)
          b) Any filesystem location of a module directory (e.g. /module_dir
             for a local filesystem). All filesystem implementations provided
             by TensorFlow are supported.

  Returns:
    A ModuleSpec.

  Raises:
    ValueError: on unexpected values in the module spec.
    tf.OpError: on file handling exceptions.
  """
  path = registry.resolver(path)
  return registry.loader(path)
Example #7
def load_module_spec(path):
  """Loads a ModuleSpec from the filesystem.

  Args:
    path: string describing the location of a module. There are several
          supported path encoding schemes:
          a) URL location specifying an archived module
            (e.g. http://domain/module.tgz)
          b) Any filesystem location of a module directory (e.g. /module_dir
             for a local filesystem). All filesystem implementations provided
             by TensorFlow are supported.

  Returns:
    A ModuleSpec.

  Raises:
    ValueError: on unexpected values in the module spec.
    tf.OpError: on file handling exceptions.
  """
  path = registry.resolver(path)
  module_def_path = _get_module_proto_path(path)
  module_def_proto = module_def_pb2.ModuleDef()
  with tf.gfile.Open(module_def_path, "rb") as f:
    module_def_proto.ParseFromString(f.read())

  if module_def_proto.format != module_def_pb2.ModuleDef.FORMAT_V3:
    raise ValueError("Unsupported module def format: %r" %
                     module_def_proto.format)

  required_features = set(module_def_proto.required_features)
  unsupported_features = (required_features - _MODULE_V3_SUPPORTED_FEATURES)

  if unsupported_features:
    raise ValueError("Unsupported features: %r" % list(unsupported_features))

  saved_model_handler = saved_model_lib.load(path)
  checkpoint_filename = saved_model_lib.get_variables_path(path)
  return _ModuleSpec(saved_model_handler, checkpoint_filename)
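For context, a hedged sketch of the on-disk layout this implementation reads for a TF1 Hub format module (/module_dir is the placeholder path from the docstring):

#   /module_dir/
#     tfhub_module.pb    <- ModuleDef proto located via _get_module_proto_path()
#     saved_model.pb     <- parsed by saved_model_lib.load()
#     variables/         <- checkpoint path returned by saved_model_lib.get_variables_path()
#     assets/            <- auxiliary files such as vocab.txt
#
# Illustrative call; the path above is a placeholder.
spec = load_module_spec("/module_dir")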
Example #8
def load_module_spec(path):
    """Loads a ModuleSpec from the filesystem.

  DEPRECATION NOTE: This belongs to the hub.Module API and file format for TF1.
  For TF2, switch to plain SavedModels and hub.load().

  Args:
    path: string describing the location of a module. There are several
          supported path encoding schemes:
          a) URL location specifying an archived module
            (e.g. http://domain/module.tgz)
          b) Any filesystem location of a module directory (e.g. /module_dir
             for a local filesystem). All filesystem implementations provided
             by TensorFlow are supported.

  Returns:
    A ModuleSpec.

  Raises:
    ValueError: on unexpected values in the module spec.
    tf.errors.OpError: on file handling exceptions.
  """
    path = registry.resolver(path)
    return registry.loader(path)
Example #9
    def __init__(
            self,
            uri='https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1',
            model_dir=None,
            seq_len=128,
            dropout_rate=0.1,
            initializer_range=0.02,
            learning_rate=3e-5,
            distribution_strategy='mirrored',
            num_gpus=-1,
            tpu='',
            trainable=True,
            do_lower_case=True,
            is_tf2=True):
        """Initialze an instance with model paramaters.

    Args:
      uri: TF-Hub path/url to Bert module.
      model_dir: The location of the model checkpoint files.
      seq_len: Length of the sequence to feed into the model.
      dropout_rate: The rate for dropout.
      initializer_range: The stdev of the truncated_normal_initializer for
        initializing all weight matrices.
      learning_rate: The initial learning rate for Adam.
      distribution_strategy:  A string specifying which distribution strategy to
        use. Accepted values are 'off', 'one_device', 'mirrored',
        'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
        insensitive. 'off' means not to use Distribution Strategy; 'tpu' means
        to use TPUStrategy using `tpu_address`.
      num_gpus: How many GPUs to use at each worker with the
        DistributionStrategies API. The default is -1, which means utilize all
        available GPUs.
      tpu: TPU address to connect to.
      trainable: boolean, whether the pretrained layer is trainable.
      do_lower_case: boolean, whether to lower case the input text. Should be
        True for uncased models and False for cased models.
      is_tf2: boolean, whether the hub module is in TensorFlow 2.x format.
    """
        if compat.get_tf_behavior() not in self.compat_tf_versions:
            raise ValueError(
                'Incompatible versions. Expect {}, but got {}.'.format(
                    self.compat_tf_versions, compat.get_tf_behavior()))
        self.seq_len = seq_len
        self.dropout_rate = dropout_rate
        self.initializer_range = initializer_range
        self.learning_rate = learning_rate
        self.trainable = trainable

        self.model_dir = model_dir
        if self.model_dir is None:
            self.model_dir = tempfile.mkdtemp()

        num_gpus = get_num_gpus(num_gpus)
        self.strategy = distribution_utils.get_distribution_strategy(
            distribution_strategy=distribution_strategy,
            num_gpus=num_gpus,
            tpu_address=tpu)
        self.tpu = tpu

        self.uri = uri

        self.is_tf2 = is_tf2
        self.vocab_file = os.path.join(registry.resolver(uri), 'assets',
                                       'vocab.txt')
        self.do_lower_case = do_lower_case

        self.tokenizer = tokenization.FullTokenizer(self.vocab_file,
                                                    self.do_lower_case)

        self.bert_config = bert_configs.BertConfig(
            0,
            initializer_range=self.initializer_range,
            hidden_dropout_prob=self.dropout_rate)
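Finally, a hedged instantiation sketch for the constructor above. "BertModelSpec" is a hypothetical stand-in for the class name (the excerpt does not show it); the argument names and defaults come from the signature above:

# Hypothetical class name; arguments mirror the documented defaults above.
spec = BertModelSpec(
    uri='https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1',
    seq_len=128,
    do_lower_case=True,
    distribution_strategy='off',  # 'off' skips tf.distribute for a single-device run
)
print(spec.vocab_file)                          # .../assets/vocab.txt from the hub cache
tokens = spec.tokenizer.tokenize('hello world')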