示例#1
0
def _make_config_asset_file(config, asset_path):
    """Writes a YAML asset file describing ``config``.

    String values that name an existing file are replaced by their basename,
    as those files should also be registered as assets.

    Args:
      config: A dictionary of configuration values. It is not modified.
      asset_path: The path of the YAML file to write.
    """
    asset_config = copy.deepcopy(config)
    # Iterate over the original mapping so the dictionary being updated is
    # never the one being iterated.
    for key, value in six.iteritems(config):
        # Only keep the basename for files (that should also be registered as assets).
        if isinstance(value, six.string_types) and compat.gfile_exists(value):
            asset_config[key] = os.path.basename(value)
    # Write through compat.gfile_open, consistent with the compat.gfile_exists
    # check above, instead of the builtin open() which only handles local paths.
    with compat.gfile_open(asset_path, mode="w") as asset_file:
        yaml.dump(asset_config, stream=asset_file, default_flow_style=False)
示例#2
0
 def initialize(self, metadata, asset_dir=None, asset_prefix=""):
   """Initializes the vocabulary and, if not already set, the tokenizer.

   Args:
     metadata: The metadata dictionary. It must contain the key named by
       ``self.vocabulary_file_key``.
     asset_dir: Forwarded to the parent class ``initialize``.
     asset_prefix: Prefix used to look up the ``tokenization`` field; also
       forwarded to the parent class ``initialize``.

   Returns:
     The value returned by the parent class ``initialize``.
   """
   self.vocabulary_file = metadata[self.vocabulary_file_key]
   # Size = line count of the vocabulary file plus the OOV buckets.
   self.vocabulary_size = count_lines(self.vocabulary_file) + self.num_oov_buckets
   if self.tokenizer is None:
     tokenizer_config = _get_field(metadata, "tokenization", prefix=asset_prefix)
     if tokenizer_config:
       # The field may be a path to a YAML file holding the tokenizer options.
       if isinstance(tokenizer_config, six.string_types) and compat.gfile_exists(tokenizer_config):
         with compat.gfile_open(tokenizer_config, mode="rb") as config_file:
           # safe_load builds plain YAML types only (no arbitrary Python
           # objects) and, unlike load() without a Loader, works on PyYAML >= 6.
           tokenizer_config = yaml.safe_load(config_file)
       self.tokenizer = tokenizers.OpenNMTTokenizer(params=tokenizer_config)
     else:
       self.tokenizer = tokenizers.SpaceTokenizer()
   self.tokenizer.initialize(metadata)
   return super(TextInputter, self).initialize(
       metadata, asset_dir=asset_dir, asset_prefix=asset_prefix)
示例#3
0
    def __init__(self, configuration_file_or_key=None, params=None):
        """Initializes the tokenizer.

        Args:
          configuration_file_or_key: The YAML configuration file or the key to
            the YAML configuration file. Ignored when :obj:`params` is set.
          params: Optional tokenization parameters. When set, they are used
            directly as the tokenizer configuration.
        """
        self._configuration_key = None
        if params is not None:
            self._config = params
            return

        self._config = {}
        if configuration_file_or_key is not None and compat.gfile_exists(
                configuration_file_or_key):
            configuration_file = configuration_file_or_key
            with compat.gfile_open(configuration_file,
                                   mode="rb") as conf_file:
                # safe_load builds plain YAML types only (no arbitrary Python
                # objects) and, unlike load() without a Loader, works on
                # PyYAML >= 6.
                self._config = yaml.safe_load(conf_file)
        else:
            # Not an existing file: keep the value as a configuration key.
            self._configuration_key = configuration_file_or_key
示例#4
0
def load_model(model_dir,
               model_file=None,
               model_name=None,
               serialize_model=True):
    """Loads the model from the catalog or a file.

    The model definition is saved as ``model_description.py`` in
    :obj:`model_dir` to make the model configuration optional for future runs.
    A legacy pickled ``model_description.pkl`` is still loaded if present.

    Args:
      model_dir: The model directory.
      model_file: An optional model configuration.
        Mutually exclusive with :obj:`model_name`.
      model_name: An optional model name from the catalog.
        Mutually exclusive with :obj:`model_file`.
      serialize_model: Serialize the model definition in the model directory.

    Returns:
      A :class:`opennmt.models.model.Model` instance.

    Raises:
      ValueError: if both :obj:`model_file` and :obj:`model_name` are set.
      RuntimeError: if neither is set and no model description is found in
        :obj:`model_dir`.
    """
    if model_file and model_name:
        raise ValueError("only one of model_file and model_name should be set")
    model_name_or_path = model_file or model_name
    model_description_path = os.path.join(model_dir, "model_description.py")

    # Also try to load the pickled model for backward compatibility.
    serial_model_file = os.path.join(model_dir, "model_description.pkl")

    if model_name_or_path:
        if tf.train.latest_checkpoint(model_dir) is not None:
            compat.logging.warn(
                "You provided a model configuration but a checkpoint already exists. "
                "The model configuration must define the same model as the one used for "
                "the initial training. However, you can change non structural values like "
                "dropout.")

        if model_file:
            model = load_model_from_file(model_file)
            # When the user passes the already serialized description itself,
            # copying it onto itself is useless and, depending on the
            # filesystem implementation, may truncate the file.
            is_description_file = (
                os.path.abspath(model_file) == os.path.abspath(model_description_path))
            if serialize_model and not is_description_file:
                compat.gfile_copy(model_file,
                                  model_description_path,
                                  overwrite=True)
        elif model_name:
            model = load_model_from_catalog(model_name)
            if serialize_model:
                with compat.gfile_open(model_description_path,
                                       mode="w") as model_description_file:
                    model_description_file.write(
                        "from opennmt.models import catalog\n")
                    model_description_file.write("model = catalog.%s\n" %
                                                 model_name)
    elif compat.gfile_exists(model_description_path):
        compat.logging.info("Loading model description from %s",
                            model_description_path)
        model = load_model_from_file(model_description_path)
    elif compat.gfile_exists(serial_model_file):
        compat.logging.info("Loading serialized model description from %s",
                            serial_model_file)
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # model directories from trusted sources.
        with compat.gfile_open(serial_model_file, mode="rb") as serial_model:
            model = pickle.load(serial_model)
    else:
        raise RuntimeError(
            "A model configuration is required: you probably need to "
            "set --model or --model_type on the command line.")

    return model