Example #1
0
    def restore(self):
        """Restores the model parameters from the latest available data.

        Busy-waits up to `self._timeout` seconds for the exported saved model
        at `self._saved_model_path` to become complete, then loads its
        feature/label tensor specs and the saved model itself.

        Returns:
          True if an exported saved model has been loaded and False otherwise.
        """
        logging.info('Trying to restore saved model from %s',
                     self._saved_model_path)
        # Get the expected assets filename. The assets.extra/t2r_assets.pbtxt
        # file is materialized last during export, so its existence implies
        # the whole saved model is on disk; otherwise we would have to check
        # for saved_model.pb.
        t2r_assets_dir = os.path.join(self._saved_model_path,
                                      tensorspec_utils.EXTRA_ASSETS_DIRECTORY)
        t2r_assets_filename = os.path.join(
            t2r_assets_dir, tensorspec_utils.T2R_ASSETS_FILENAME)

        start_time = time.time()
        while time.time() - start_time < self._timeout:
            if tf.io.gfile.exists(t2r_assets_filename):
                break

            logging.info(
                'Waiting for a saved model to become available at %s.',
                self._saved_model_path)
            # Busy waiting loop; throttle the existence checks.
            time.sleep(_BUSY_WAITING_SLEEP_TIME_IN_SECS)
        else:
            # The loop condition expired without hitting `break`.
            logging.warning('No saved_model found after %s seconds.',
                            str(self._timeout))
            return False

        # Loading assets for features and labels. Reuse the path computed
        # above instead of rebuilding the identical join expression.
        t2r_assets = tensorspec_utils.load_t2r_assets_to_file(
            t2r_assets_filename)

        self._feature_spec = tensorspec_utils.TensorSpecStruct.from_proto(
            t2r_assets.feature_spec)  # pytype: disable=wrong-arg-types
        self._label_spec = tensorspec_utils.TensorSpecStruct.from_proto(
            t2r_assets.label_spec)  # pytype: disable=wrong-arg-types

        self._model = tf.saved_model.load(self._saved_model_path)
        return True
  def restore(self):
    """Restores the model parameters from the latest available data.

    Busy-waits up to `self._timeout` seconds for an exported saved model to
    appear under `self._export_dir`, then loads the newest export together
    with its feature/label tensor specs and global step. Loading is retried
    on ValueError since exports may be written concurrently.

    Returns:
      True if an exported saved model has been loaded and False otherwise.
    """
    start_time = time.time()
    # Initialize before the loop: with a non-positive timeout the loop body
    # never executes and the post-loop check would otherwise raise
    # UnboundLocalError on `model_dirs`.
    model_dirs = []

    while time.time() - start_time < self._timeout:
      # The exported saved models directory names are numbers (timestamp)
      # which monotonically increase, meaning the largest directory name will
      # contain the latest exported model. Lexicographical sorting will
      # maintain this order.
      model_dirs = sorted(tf.io.gfile.glob(os.path.join(self._export_dir, '*')))
      model_dirs = self._remove_invalid_model_dirnames(model_dirs)

      if model_dirs:
        logging.info('Found latest model at %s. ', model_dirs[-1])
        break

      logging.info('Waiting for an exported model to become available at %s.',
                   self._export_dir)
      # Since a checkpoint might not be available and this is a busy waiting
      # loop, we throttle checking for checkpoints.
      time.sleep(_BUSY_WAITING_SLEEP_TIME_IN_SECS)

    if not model_dirs:
      logging.warning('No checkpoint available after %s seconds.',
                      str(self._timeout))
      return False

    if self._latest_export_dir == model_dirs[-1]:
      # The latest model has already been loaded.
      return True

    logging.info('Loading the latest model at %s. ', model_dirs[-1])
    self._latest_export_dir = model_dirs[-1]
    start_time_loading = time.time()

    # Note, loading from a saved model might require several attempts if
    # the checkpoint gets written asynchronously.
    while time.time() - start_time_loading < self._timeout:
      try:
        self._predict_fn = tf.contrib.predictor.from_saved_model(
            model_dirs[-1], config=self._tf_config)
        t2r_assets_file_path = os.path.join(
            model_dirs[-1], 'assets.extra',
            tensorspec_utils.T2R_ASSETS_FILENAME)
        if tf.io.gfile.exists(t2r_assets_file_path):
          t2r_assets = tensorspec_utils.load_t2r_assets_to_file(
              t2r_assets_file_path)
          self._feature_spec = tensorspec_utils.TensorSpecStruct.from_proto(
              t2r_assets.feature_spec)  # pytype: disable=wrong-arg-types
          self._label_spec = tensorspec_utils.TensorSpecStruct.from_proto(
              t2r_assets.label_spec)  # pytype: disable=wrong-arg-types
          if t2r_assets.HasField('global_step'):
            self._global_step = t2r_assets.global_step
          else:
            logging.warning(
                'Error loading the global step, therefore using the previously'
                'set global step %s.', str(self.global_step))
        else:
          # Legacy export layout: specs and global step live in pickle files.
          input_spec_filename = os.path.join(model_dirs[-1], 'assets.extra',
                                             'input_specs.pkl')
          logging.warning(
              'Using the legacy loading, please convert the assets '
              'using convert_pkl_assets_to_proto_assets binary for '
              'file path %s.', input_spec_filename)
          # Load input specs from file.
          self._feature_spec, self._label_spec = (
              tensorspec_utils.load_input_spec_from_file(input_spec_filename))

          # Load the global step from file.
          global_step_filename = os.path.join(model_dirs[-1], 'assets.extra',
                                              'global_step.pkl')
          try:
            global_step = tensorspec_utils.load_global_step_from_file(
                global_step_filename)
            self._global_step = global_step
          except ValueError:
            logging.warning(
                'Error loading the global step, therefore using the previously'
                'set global step %s.', str(self.global_step))
        return True
      except ValueError as err:
        logging.warning(
            'Error loading model as %s:\n%s\nThe next attempt at loading the '
            'latest model will be in %d seconds', model_dirs[-1], err,
            _BUSY_WAITING_SLEEP_TIME_IN_SECS)
      # Since a checkpoint might be written by the tf model concurrently
      # this is a busy waiting loop.
      time.sleep(_BUSY_WAITING_SLEEP_TIME_IN_SECS)
    logging.warning(
        'The checkpoint at %s could not be loaded after '
        '%s seconds.', str(self._latest_export_dir), str(self._timeout))
    return False
Example #3
0
    def restore(self):
        """Restores the model parameters from the latest available data.

        Busy-waits up to `self._timeout` seconds for a valid exported saved
        model to appear under `self._export_dir`, then loads it together with
        its feature/label tensor specs and global step. Loading is retried on
        ValueError since exports may be written concurrently.

        Returns:
          True if an exported saved model has been loaded and False otherwise.
        """
        start_time = time.time()
        # Initialize before the loop: with a non-positive timeout the loop
        # body never executes and the post-loop check would otherwise raise
        # UnboundLocalError on `model_dir`.
        model_dir = None
        while time.time() - start_time < self._timeout:
            model_dir = self._latest_valid_model_dirs(
                tf.io.gfile.glob(os.path.join(self._export_dir, '*')))

            if model_dir is not None:
                logging.info('Found latest model at %s. ', model_dir)
                break

            logging.info(
                'Waiting for an exported model to become available at %s.',
                self._export_dir)
            # Since a checkpoint might not be available and this is a busy waiting
            # loop, we throttle checking for checkpoints.
            time.sleep(_BUSY_WAITING_SLEEP_TIME_IN_SECS)

        if model_dir is None:
            logging.warning('No checkpoint available after %s seconds.',
                            str(self._timeout))
            return False

        if self._latest_export_dir == model_dir:
            # The latest model has already been loaded.
            return True

        logging.info('Loading the latest model at %s. ', model_dir)
        self._latest_export_dir = model_dir
        start_time_loading = time.time()

        # Note, loading from a saved model might require several attempts if
        # the checkpoint gets written asynchronously.
        while time.time() - start_time_loading < self._timeout:
            try:
                t2r_assets_file_path = os.path.join(
                    model_dir, tensorspec_utils.EXTRA_ASSETS_DIRECTORY,
                    tensorspec_utils.T2R_ASSETS_FILENAME)
                t2r_assets = tensorspec_utils.load_t2r_assets_to_file(
                    t2r_assets_file_path)
                self._feature_spec = tensorspec_utils.TensorSpecStruct.from_proto(
                    t2r_assets.feature_spec)  # pytype: disable=wrong-arg-types
                self._label_spec = tensorspec_utils.TensorSpecStruct.from_proto(
                    t2r_assets.label_spec)  # pytype: disable=wrong-arg-types

                if t2r_assets.HasField('global_step'):
                    self._global_step = t2r_assets.global_step
                else:
                    logging.warning(
                        'Error loading the global step, therefore using the previously'
                        'set global step %s.', str(self.global_step))

                self._predict_fn = contrib_predictor.from_saved_model(
                    model_dir, config=self._tf_config)
                # Prefer the global step recorded in the model graph over the
                # one from the assets file when they disagree.
                model_global_step = self._predict_fn.session.run(
                    self._predict_fn.graph.get_collection(
                        tf.GraphKeys.GLOBAL_STEP))[0]
                if (model_global_step is not None
                        and model_global_step != self._global_step):
                    logging.warning(
                        'Using the global step loaded from the model %s and not the '
                        'one from the assets file %s.', str(model_global_step),
                        str(self._global_step))
                    self._global_step = model_global_step
                return True
            except ValueError as err:
                logging.warning(
                    'Error loading model as %s:\n%s\nThe next attempt at loading the '
                    'latest model will be in %d seconds', model_dir, err,
                    _BUSY_WAITING_SLEEP_TIME_IN_SECS)
            # Since a checkpoint might be written by the tf model concurrently
            # this is a busy waiting loop.
            time.sleep(_BUSY_WAITING_SLEEP_TIME_IN_SECS)
        logging.warning(
            'The checkpoint at %s could not be loaded after '
            '%s seconds.', str(self._latest_export_dir), str(self._timeout))
        return False