def restore(self):
  """Restores the model parameters from the latest available data.

  Busy-waits until an exported saved model becomes available under
  self._saved_model_path, giving up after self._timeout seconds.

  Returns:
    True if an exported saved model has been loaded and False otherwise.
  """
  logging.info('Trying to restore saved model from %s',
               self._saved_model_path)
  # The assets.extra/t2r_assets.pbtxt file is materialized last during
  # export, so its existence implies the whole saved model is on disk.
  # Otherwise we would have to check for saved_model.pb instead.
  t2r_assets_dir = os.path.join(self._saved_model_path,
                                tensorspec_utils.EXTRA_ASSETS_DIRECTORY)
  t2r_assets_filename = os.path.join(t2r_assets_dir,
                                     tensorspec_utils.T2R_ASSETS_FILENAME)
  start_time = time.time()
  while time.time() - start_time < self._timeout:
    if tf.io.gfile.exists(t2r_assets_filename):
      break
    logging.info('Waiting for a saved model to become available at %s.',
                 self._saved_model_path)
    time.sleep(_BUSY_WAITING_SLEEP_TIME_IN_SECS)
  else:
    # The while loop ran to completion without break -> no model appeared.
    logging.warning('No saved_model found after %s seconds.',
                    str(self._timeout))
    return False
  # Loading assets for features and labels. Note: the original recomputed
  # the identical path here; we simply reuse t2r_assets_filename.
  t2r_assets = tensorspec_utils.load_t2r_assets_to_file(t2r_assets_filename)
  self._feature_spec = tensorspec_utils.TensorSpecStruct.from_proto(
      t2r_assets.feature_spec)  # pytype: disable=wrong-arg-types
  self._label_spec = tensorspec_utils.TensorSpecStruct.from_proto(
      t2r_assets.label_spec)  # pytype: disable=wrong-arg-types
  self._model = tf.saved_model.load(self._saved_model_path)
  return True
def restore(self):
  """Restores the model parameters from the latest available data.

  Busy-waits until an exported saved model becomes available under
  self._export_dir, then loads the newest export, retrying for up to
  self._timeout seconds since exports may be written asynchronously.

  Returns:
    True if an exported saved model has been loaded and False otherwise.
  """
  start_time = time.time()
  # Fix: initialize model_dirs so the post-loop check cannot raise a
  # NameError when self._timeout <= 0 and the loop body never executes.
  model_dirs = []
  while time.time() - start_time < self._timeout:
    # The exported saved models directory names are numbers (timestamp)
    # which monotonically increase, meaning the largest directory name will
    # contain the latest exported model. Lexicographical sorting will
    # maintain this order.
    model_dirs = sorted(tf.io.gfile.glob(os.path.join(self._export_dir, '*')))
    model_dirs = self._remove_invalid_model_dirnames(model_dirs)
    if len(model_dirs) >= 1:
      logging.info('Found latest model at %s. ', model_dirs[-1])
      break
    logging.info('Waiting for an exported model to become available at %s.',
                 self._export_dir)
    # Since a checkpoint might not be available and this is a busy waiting
    # loop, we throttle checking for checkpoints.
    time.sleep(_BUSY_WAITING_SLEEP_TIME_IN_SECS)
  if not model_dirs:
    logging.warning('No checkpoint available after %s seconds.',
                    str(self._timeout))
    return False
  if self._latest_export_dir == model_dirs[-1]:
    # The latest model has already been loaded.
    return True
  logging.info('Loading the latest model at %s. ', model_dirs[-1])
  self._latest_export_dir = model_dirs[-1]
  start_time_loading = time.time()
  # Note, loading from a saved model might require several attempts if
  # the checkpoint gets written asynchronously.
  while time.time() - start_time_loading < self._timeout:
    try:
      self._predict_fn = tf.contrib.predictor.from_saved_model(
          model_dirs[-1], config=self._tf_config)
      t2r_assets_file_path = os.path.join(
          model_dirs[-1], 'assets.extra',
          tensorspec_utils.T2R_ASSETS_FILENAME)
      if tf.io.gfile.exists(t2r_assets_file_path):
        t2r_assets = tensorspec_utils.load_t2r_assets_to_file(
            t2r_assets_file_path)
        self._feature_spec = tensorspec_utils.TensorSpecStruct.from_proto(
            t2r_assets.feature_spec)  # pytype: disable=wrong-arg-types
        self._label_spec = tensorspec_utils.TensorSpecStruct.from_proto(
            t2r_assets.label_spec)  # pytype: disable=wrong-arg-types
        if t2r_assets.HasField('global_step'):
          self._global_step = t2r_assets.global_step
        else:
          # Fix: added the missing space between the implicitly
          # concatenated string literals ('previously' + 'set').
          logging.warning(
              'Error loading the global step, therefore using the '
              'previously set global step %s.', str(self.global_step))
      else:
        # Legacy exports ship pickled specs instead of the proto assets.
        input_spec_filename = os.path.join(model_dirs[-1], 'assets.extra',
                                           'input_specs.pkl')
        logging.warning(
            'Using the legacy loading, please convert the assets '
            'using convert_pkl_assets_to_proto_assets binary for '
            'file path %s.', input_spec_filename)
        # Load input specs from file.
        self._feature_spec, self._label_spec = (
            tensorspec_utils.load_input_spec_from_file(input_spec_filename))
        # Load the global step from its legacy pickle file.
        global_step_filename = os.path.join(model_dirs[-1], 'assets.extra',
                                            'global_step.pkl')
        try:
          global_step = tensorspec_utils.load_global_step_from_file(
              global_step_filename)
          self._global_step = global_step
        except ValueError:
          # Fix: same missing-space message repair as above.
          logging.warning(
              'Error loading the global step, therefore using the '
              'previously set global step %s.', str(self.global_step))
      return True
    except ValueError as err:
      logging.warning(
          'Error loading model as %s:\n%s\nThe next attempt at loading the '
          'latest model will be in %d seconds', model_dirs[-1], err,
          _BUSY_WAITING_SLEEP_TIME_IN_SECS)
      # Since a checkpoint might be written by the tf model concurrently
      # this is a busy waiting loop.
      time.sleep(_BUSY_WAITING_SLEEP_TIME_IN_SECS)
  logging.warning(
      'The checkpoint at %s could not be loaded after '
      '%s seconds.', str(self._latest_export_dir), str(self._timeout))
  return False
def restore(self):
  """Restores the model parameters from the latest available data.

  Busy-waits until an exported saved model becomes available under
  self._export_dir, then loads the newest export, retrying for up to
  self._timeout seconds since exports may be written asynchronously.

  Returns:
    True if an exported saved model has been loaded and False otherwise.
  """
  start_time = time.time()
  # Fix: initialize model_dir so the post-loop check cannot raise a
  # NameError when self._timeout <= 0 and the loop body never executes.
  model_dir = None
  while time.time() - start_time < self._timeout:
    model_dir = self._latest_valid_model_dirs(
        tf.io.gfile.glob(os.path.join(self._export_dir, '*')))
    if model_dir is not None:
      logging.info('Found latest model at %s. ', model_dir)
      break
    logging.info('Waiting for an exported model to become available at %s.',
                 self._export_dir)
    # Since a checkpoint might not be available and this is a busy waiting
    # loop, we throttle checking for checkpoints.
    time.sleep(_BUSY_WAITING_SLEEP_TIME_IN_SECS)
  if model_dir is None:
    logging.warning('No checkpoint available after %s seconds.',
                    str(self._timeout))
    return False
  if self._latest_export_dir == model_dir:
    # The latest model has already been loaded.
    return True
  logging.info('Loading the latest model at %s. ', model_dir)
  self._latest_export_dir = model_dir
  start_time_loading = time.time()
  # Note, loading from a saved model might require several attempts if
  # the checkpoint gets written asynchronously.
  while time.time() - start_time_loading < self._timeout:
    try:
      t2r_assets_file_path = os.path.join(
          model_dir, tensorspec_utils.EXTRA_ASSETS_DIRECTORY,
          tensorspec_utils.T2R_ASSETS_FILENAME)
      t2r_assets = tensorspec_utils.load_t2r_assets_to_file(
          t2r_assets_file_path)
      self._feature_spec = tensorspec_utils.TensorSpecStruct.from_proto(
          t2r_assets.feature_spec)  # pytype: disable=wrong-arg-types
      self._label_spec = tensorspec_utils.TensorSpecStruct.from_proto(
          t2r_assets.label_spec)  # pytype: disable=wrong-arg-types
      if t2r_assets.HasField('global_step'):
        self._global_step = t2r_assets.global_step
      else:
        # Fix: added the missing space between the implicitly
        # concatenated string literals ('previously' + 'set').
        logging.warning(
            'Error loading the global step, therefore using the '
            'previously set global step %s.', str(self.global_step))
      self._predict_fn = contrib_predictor.from_saved_model(
          model_dir, config=self._tf_config)
      # Prefer the global step recorded in the model graph over the one
      # from the assets file when the two disagree.
      model_global_step = self._predict_fn.session.run(
          self._predict_fn.graph.get_collection(tf.GraphKeys.GLOBAL_STEP))[0]
      if (model_global_step is not None and
          model_global_step != self._global_step):
        logging.warning(
            'Using the global step loaded from the model %s and not the '
            'one from the assets file %s.', str(model_global_step),
            str(self._global_step))
        self._global_step = model_global_step
      return True
    except ValueError as err:
      logging.warning(
          'Error loading model as %s:\n%s\nThe next attempt at loading the '
          'latest model will be in %d seconds', model_dir, err,
          _BUSY_WAITING_SLEEP_TIME_IN_SECS)
      # Since a checkpoint might be written by the tf model concurrently
      # this is a busy waiting loop.
      time.sleep(_BUSY_WAITING_SLEEP_TIME_IN_SECS)
  logging.warning(
      'The checkpoint at %s could not be loaded after '
      '%s seconds.', str(self._latest_export_dir), str(self._timeout))
  return False