示例#1
0
 def testLoadModelModule(self):
     """Writes a tiny model file and checks it loads as a module exposing `model`."""
     path = os.path.join(self.get_temp_dir(), "model.py")
     source = b"model = lambda: 42"
     with open(path, mode="wb") as handle:
         handle.write(source)
     module = config.load_model_module(path)
     self.assertEqual(42, module.model())
示例#2
0
def load_model(model_dir, model_file=None):
    """Loads the model.

    The model object is pickled in `model_dir` to make the model configuration
    optional for future runs.

    Args:
      model_dir: The model directory.
      model_file: An optional path to a Python model configuration file that
        defines a `model()` callable. When given, the resulting model is
        (re)serialized to `model_dir`, overwriting any previous description.

    Returns:
      A `opennmt.models.Model` object.

    Raises:
      RuntimeError: if `model_file` is not set and no serialized model
        description exists in `model_dir`.
    """
    serial_model_file = os.path.join(model_dir, "model_description.pkl")

    if model_file:
        if tf.train.latest_checkpoint(model_dir) is not None:
            # `tf.logging.warn` is a deprecated alias; `warning` is the supported name.
            tf.logging.warning(
                "You provided a model configuration but a checkpoint already exists. "
                "The model configuration must define the same model as the one used for "
                "the initial training. However, you can change non structural values like "
                "dropout.")

        model_config = load_model_module(model_file)
        model = model_config.model()

        # Persist the model so later runs can omit the configuration file.
        with open(serial_model_file, "wb") as serial_model:
            pickle.dump(model, serial_model)
    elif not os.path.isfile(serial_model_file):
        raise RuntimeError("A model configuration is required.")
    else:
        tf.logging.info("Loading serialized model description from %s",
                        serial_model_file)
        # SECURITY NOTE: unpickling can execute arbitrary code. Only load model
        # directories that come from a trusted source.
        with open(serial_model_file, "rb") as serial_model:
            model = pickle.load(serial_model)

    return model
示例#3
0
文件: main.py 项目: yhgon/OpenNMT-tf
def load_model(model_dir, model_file=None):
  """Loads the model.

  The model object is pickled in `model_dir` to make the model configuration
  optional for future runs.

  Args:
    model_dir: The model directory.
    model_file: An optional path to a Python model configuration file that
      defines a `model()` callable. When given, the resulting model is
      (re)serialized to `model_dir`, overwriting any previous description.

  Returns:
    A `opennmt.models.Model` object.

  Raises:
    RuntimeError: if `model_file` is not set and no serialized model
      description exists in `model_dir`.
  """
  serial_model_file = os.path.join(model_dir, "model_description.pkl")

  if model_file:
    if tf.train.latest_checkpoint(model_dir) is not None:
      # `tf.logging.warn` is a deprecated alias; `warning` is the supported name.
      tf.logging.warning(
          "You provided a model configuration but a checkpoint already exists. "
          "The model configuration must define the same model as the one used for "
          "the initial training. However, you can change non structural values like "
          "dropout.")

    model_config = load_model_module(model_file)
    model = model_config.model()

    # Persist the model so later runs can omit the configuration file.
    with open(serial_model_file, "wb") as serial_model:
      pickle.dump(model, serial_model)
  elif not os.path.isfile(serial_model_file):
    raise RuntimeError("A model configuration is required.")
  else:
    tf.logging.info("Loading serialized model description from %s", serial_model_file)
    # SECURITY NOTE: unpickling can execute arbitrary code. Only load model
    # directories that come from a trusted source.
    with open(serial_model_file, "rb") as serial_model:
      model = pickle.load(serial_model)

  return model
示例#4
0
 def testLoadModelModule(self):
     """Loads the custom model file and checks its `model` factory returns 42."""
     module = config.load_model_module(self._writeCustomModel())
     self.assertEqual(42, module.model())