Esempio n. 1
0
def convert(assets_filepath):
    """Converts existing asset pickle based files to t2r proto based assets."""

    t2r_assets = t2r_pb2.T2RAssets()
    input_spec_filepath = os.path.join(assets_filepath, 'input_specs.pkl')
    if not tf.io.gfile.exists(input_spec_filepath):
        raise ValueError('No file exists for {}.'.format(input_spec_filepath))
    feature_spec, label_spec = tensorspec_utils.load_input_spec_from_file(
        input_spec_filepath)

    t2r_assets.feature_spec.CopyFrom(feature_spec.to_proto())
    t2r_assets.label_spec.CopyFrom(label_spec.to_proto())

    global_step_filepath = os.path.join(assets_filepath, 'global_step.pkl')
    if tf.io.gfile.exists(global_step_filepath):
        global_step = tensorspec_utils.load_input_spec_from_file(
            global_step_filepath)
        t2r_assets.global_step = global_step

    t2r_assets_filepath = os.path.join(assets_filepath,
                                       tensorspec_utils.T2R_ASSETS_FILENAME)
    tensorspec_utils.write_t2r_assets_to_file(t2r_assets, t2r_assets_filepath)
  def restore(self):
    """Restores the model parameters from the latest available data.

    Raises:
      ValueError: If no checkpoint can be found or loaded within the user
        defined timeout.

    Returns:
      True if a exported saved model has been loaded and False otherwise.
    """
    start_time = time.time()

    while time.time() - start_time < self._timeout:
      # The exported saved models directory names are numbers (timestamp) which
      # monotonically increase, meaning the largest directory name will contain
      # the latest exported model. Lexicographical sorting will maintain this
      # order.
      model_dirs = sorted(tf.io.gfile.glob(os.path.join(self._export_dir, '*')))
      model_dirs = self._remove_invalid_model_dirnames(model_dirs)

      if len(model_dirs) >= 1:
        logging.info('Found latest model at %s. ', model_dirs[-1])
        break

      logging.info('Waiting for an exported model to become available at %s.',
                   self._export_dir)
      # Since a checkpoint might not be available and this is a busy waiting
      # loop, we throttle checking for checkpoints.
      time.sleep(_BUSY_WAITING_SLEEP_TIME_IN_SECS)

    if model_dirs is None or not model_dirs:
      logging.warning('No checkpoint available after %s seconds.',
                      str(self._timeout))
      return False

    if self._latest_export_dir == model_dirs[-1]:
      # The latest model has already been loaded.
      return True

    logging.info('Loading the latest model at %s. ', model_dirs[-1])
    self._latest_export_dir = model_dirs[-1]
    start_time_loading = time.time()

    # Note, loading from a saved model might require several attempts if
    # the checkpoint gets written asynchronously.
    while time.time() - start_time_loading < self._timeout:
      try:
        self._predict_fn = tf.contrib.predictor.from_saved_model(
            model_dirs[-1], config=self._tf_config)
        t2r_assets_file_path = os.path.join(
            model_dirs[-1], 'assets.extra',
            tensorspec_utils.T2R_ASSETS_FILENAME)
        if tf.io.gfile.exists(t2r_assets_file_path):
          t2r_assets = tensorspec_utils.load_t2r_assets_to_file(
              t2r_assets_file_path)
          self._feature_spec = tensorspec_utils.TensorSpecStruct.from_proto(
              t2r_assets.feature_spec)  # pytype: disable=wrong-arg-types
          self._label_spec = tensorspec_utils.TensorSpecStruct.from_proto(
              t2r_assets.label_spec)  # pytype: disable=wrong-arg-types
          if t2r_assets.HasField('global_step'):
            self._global_step = t2r_assets.global_step
          else:
            logging.warning(
                'Error loading the global step, therefore using the previously'
                'set global step %s.', str(self.global_step))
        else:
          input_spec_filename = os.path.join(model_dirs[-1], 'assets.extra',
                                             'input_specs.pkl')
          logging.warning(
              'Using the legacy loading, please convert the assets '
              'using convert_pkl_assets_to_proto_assets binary for '
              'file path %s.', input_spec_filename)
          # Load input specs from file.
          self._feature_spec, self._label_spec = (
              tensorspec_utils.load_input_spec_from_file(input_spec_filename))

          # Load input specs from file.
          global_step_filename = os.path.join(model_dirs[-1], 'assets.extra',
                                              'global_step.pkl')
          try:
            global_step = tensorspec_utils.load_global_step_from_file(
                global_step_filename)
            self._global_step = global_step
          except ValueError:
            logging.warning(
                'Error loading the global step, therefore using the previously'
                'set global step %s.', str(self.global_step))
        return True
      except ValueError as err:
        logging.warning(
            'Error loading model as %s:\n%s\nThe next attempt at loading the '
            'latest model will be in %d seconds', model_dirs[-1], err,
            _BUSY_WAITING_SLEEP_TIME_IN_SECS)
      # Since a checkpoint might be written by the tf model concurrently
      # this is a busy waiting loop.
      time.sleep(_BUSY_WAITING_SLEEP_TIME_IN_SECS)
    logging.warning(
        'The checkpoint at %s could not be loaded after '
        '%s seconds.', str(self._latest_export_dir), str(self._timeout))
    return False