def _make_config_asset_file(config, asset_path):
  asset_config = copy.deepcopy(config)
  for key, value in six.iteritems(asset_config):
    # Only keep the basename for files (that should also be registered as assets).
    if isinstance(value, six.string_types) and compat.gfile_exists(value):
      asset_config[key] = os.path.basename(value)
  with open(asset_path, "w") as asset_file:
    yaml.dump(asset_config, stream=asset_file, default_flow_style=False)
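# Usage sketch for the helper above (the configuration keys and paths are
# illustrative assumptions, not taken from the original module):
def _example_config_asset():
  config = {
      "bpe_model_path": "/data/tokenization/en.bpe",  # existing file: written as "en.bpe"
      "joiner_annotate": True,  # plain option: written unchanged
  }
  # The dumped YAML only references file basenames, so the configuration
  # remains valid when shipped next to the registered asset files.
  _make_config_asset_file(config, "assets.extra/tokenizer_config.yml")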
def initialize(self, metadata, asset_dir=None, asset_prefix=""): self.vocabulary_file = metadata[self.vocabulary_file_key] self.vocabulary_size = count_lines(self.vocabulary_file) + self.num_oov_buckets if self.tokenizer is None: tokenizer_config = _get_field(metadata, "tokenization", prefix=asset_prefix) if tokenizer_config: if isinstance(tokenizer_config, six.string_types) and compat.gfile_exists(tokenizer_config): with compat.gfile_open(tokenizer_config, mode="rb") as config_file: tokenizer_config = yaml.load(config_file) self.tokenizer = tokenizers.OpenNMTTokenizer(params=tokenizer_config) else: self.tokenizer = tokenizers.SpaceTokenizer() self.tokenizer.initialize(metadata) return super(TextInputter, self).initialize( metadata, asset_dir=asset_dir, asset_prefix=asset_prefix)
def __init__(self, configuration_file_or_key=None, params=None):
  """Initializes the tokenizer.

  Args:
    configuration_file_or_key: The YAML configuration file or the key to the
      YAML configuration file.
    params: A dictionary of tokenization options.
  """
  self._configuration_key = None
  if params is not None:
    self._config = params
  else:
    self._config = {}
    if configuration_file_or_key is not None and compat.gfile_exists(
        configuration_file_or_key):
      configuration_file = configuration_file_or_key
      with compat.gfile_open(configuration_file, mode="rb") as conf_file:
        self._config = yaml.load(conf_file)
    else:
      self._configuration_key = configuration_file_or_key
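# Construction sketch for the two configuration paths handled above. The
# option name in `params` is an illustrative pyonmttok-style setting, and it
# is assumed here that OpenNMTTokenizer forwards these arguments to this base
# initializer.
def _example_tokenizer_init():
  from opennmt import tokenizers
  # Inline parameters: stored directly as the tokenizer configuration.
  inline = tokenizers.OpenNMTTokenizer(params={"mode": "conservative"})
  # A path (or configuration key): if the path exists the YAML file is loaded
  # now, otherwise the value is remembered as a key to resolve later.
  from_file = tokenizers.OpenNMTTokenizer(
      configuration_file_or_key="config/tokenization.yml")
  return inline, from_file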
def load_model(model_dir,
               model_file=None,
               model_name=None,
               serialize_model=True):
  """Loads the model from the catalog or a file.

  The model description is saved in :obj:`model_dir` to make the model
  configuration optional for future runs.

  Args:
    model_dir: The model directory.
    model_file: An optional model configuration.
      Mutually exclusive with :obj:`model_name`.
    model_name: An optional model name from the catalog.
      Mutually exclusive with :obj:`model_file`.
    serialize_model: Serialize the model definition in the model directory.

  Returns:
    A :class:`opennmt.models.model.Model` instance.

  Raises:
    ValueError: if both :obj:`model_file` and :obj:`model_name` are set.
  """
  if model_file and model_name:
    raise ValueError("only one of model_file and model_name should be set")
  model_name_or_path = model_file or model_name
  model_description_path = os.path.join(model_dir, "model_description.py")

  # Also try to load the pickled model for backward compatibility.
  serial_model_file = os.path.join(model_dir, "model_description.pkl")

  if model_name_or_path:
    if tf.train.latest_checkpoint(model_dir) is not None:
      compat.logging.warn(
          "You provided a model configuration but a checkpoint already exists. "
          "The model configuration must define the same model as the one used for "
          "the initial training. However, you can change non-structural values like "
          "dropout.")

    if model_file:
      model = load_model_from_file(model_file)
      if serialize_model:
        compat.gfile_copy(model_file, model_description_path, overwrite=True)
    elif model_name:
      model = load_model_from_catalog(model_name)
      if serialize_model:
        with compat.gfile_open(model_description_path, mode="w") as model_description_file:
          model_description_file.write("from opennmt.models import catalog\n")
          model_description_file.write("model = catalog.%s\n" % model_name)
  elif compat.gfile_exists(model_description_path):
    compat.logging.info("Loading model description from %s", model_description_path)
    model = load_model_from_file(model_description_path)
  elif compat.gfile_exists(serial_model_file):
    compat.logging.info("Loading serialized model description from %s", serial_model_file)
    with compat.gfile_open(serial_model_file, mode="rb") as serial_model:
      model = pickle.load(serial_model)
  else:
    raise RuntimeError(
        "A model configuration is required: you probably need to "
        "set --model or --model_type on the command line.")

  return model
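# Usage sketch for load_model above (the run directory and catalog name are
# illustrative assumptions, not taken from this excerpt):
def _example_load_model():
  # First run: resolve the model from the catalog and serialize
  # model_description.py into the model directory.
  model = load_model("runs/baseline", model_name="Transformer")
  # Subsequent runs: the description saved in the directory is enough, so
  # neither model_file nor model_name has to be passed again.
  model = load_model("runs/baseline")
  return model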