def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from "
                         + checkpoint_dir)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      # Rewrite in place via the index since the proto's repeated field does
      # not support item assignment through a plain iterator.
      for i, p in enumerate(ckpt.all_model_checkpoint_paths):
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except errors.OpError as e:
    # It's ok if the file cannot be read
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  except text_format.ParseError as e:
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  # Removed the dead `f = None ... finally: f.close()` scaffolding: no file
  # handle is ever opened directly here (file_io does its own cleanup).
  return ckpt
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None otherwise.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(
      checkpoint_dir, latest_filename)
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if pywrap_tensorflow.file_exists(coord_checkpoint_filename):
      file_content = pywrap_tensorflow.read_file_to_string(
          coord_checkpoint_filename).decode("utf-8")
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      # Bug fix: this rewrite used to be skipped whenever checkpoint_dir
      # itself was absolute, leaving relative entries in the proto
      # unresolved.  Apply it unconditionally, matching the sibling
      # file_io-based implementation.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      for i, p in enumerate(ckpt.all_model_checkpoint_paths):
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except IOError:
    # It's ok if the file cannot be read
    return None
  except text_format.ParseError as e:
    logging.warning(str(e))
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  # Removed the dead `f = None ... finally: f.close()` scaffolding: no file
  # handle is ever opened directly here.
  return ckpt
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None,
                                    all_model_checkpoint_timestamps=None,
                                    last_preserved_timestamp=None):
  """Generates a checkpoint state proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    all_model_checkpoint_timestamps: A list of floats, indicating the number
      of seconds since the Epoch when each checkpoint was generated.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).

  Returns:
    CheckpointState proto with model_checkpoint_path and
    all_model_checkpoint_paths updated to either absolute paths or
    relative paths to the current save_dir.

  Raises:
    ValueError: If `all_model_checkpoint_timestamps` was provided but its
      length does not match `all_model_checkpoint_paths`.
  """
  if all_model_checkpoint_paths is None:
    all_model_checkpoint_paths = []
  else:
    # Work on a copy: the append and relative-path rewrites below must not
    # mutate the caller's list as a side effect.
    all_model_checkpoint_paths = list(all_model_checkpoint_paths)
  if (not all_model_checkpoint_paths or
      all_model_checkpoint_paths[-1] != model_checkpoint_path):
    logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
                 model_checkpoint_path)
    all_model_checkpoint_paths.append(model_checkpoint_path)
  if (all_model_checkpoint_timestamps and
      (len(all_model_checkpoint_timestamps) !=
       len(all_model_checkpoint_paths))):
    raise ValueError(
        ("Checkpoint timestamps, if provided, must match checkpoint paths "
         "(got paths %s and timestamps %s)") %
        (all_model_checkpoint_paths, all_model_checkpoint_timestamps))
  # Relative paths need to be rewritten to be relative to the "save_dir"
  # if model_checkpoint_path already contains "save_dir".
  if not os.path.isabs(save_dir):
    if not os.path.isabs(model_checkpoint_path):
      model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
    for i, p in enumerate(all_model_checkpoint_paths):
      if not os.path.isabs(p):
        all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)
  coord_checkpoint_proto = CheckpointState(
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)
  return coord_checkpoint_proto
# NOTE(review): the statements below are the tail of a test function whose
# `def` line (and fixtures mock_experiment / mock_exporter / ckpt_to_export)
# lies outside this view; indentation was reconstructed from the collapsed
# source — confirm against the original file.
evaluator_task.evaluate(mock_experiment)
# One export and one evaluate call per exported checkpoint.
assert len(mock_exporter.export.call_args_list) == len(ckpt_to_export)
assert len(mock_experiment.estimator.evaluate.call_args_list) == len(
    ckpt_to_export)
export_path = os.path.join(mock_experiment.estimator.model_dir,
                           mock_exporter.name)
if len(ckpt_to_export) > 0:
    for ckpt in ckpt_to_export:
        # Exporter must have been invoked for this checkpoint (other
        # positional args are not pinned down).
        mock_exporter.export.assert_any_call(ANY, export_path, ckpt, ANY, ANY)
        mock_experiment.estimator.evaluate(ANY,
                                           steps=ANY,
                                           hooks=ANY,
                                           name=ANY,
                                           checkpoint_path=ckpt)


# Two scenarios: a checkpoint state listing one checkpoint, and no state at
# all (tf.train.get_checkpoint_state returning None -> empty list expected).
@pytest.mark.parametrize("checkpoint_state,checkpoints", [
    (CheckpointState(
        all_model_checkpoint_paths=["/path/to/model/dir/model.ckpt-300"]),
     ["/path/to/model/dir/model.ckpt-300"]),
    (None, []),
])
def test__get_all_checkpoints(checkpoint_state, checkpoints):
    # Patch TF's checkpoint-state lookup so _get_all_checkpoints sees the
    # parametrized state instead of touching the filesystem.
    with mock.patch(
        "tf_yarn.tensorflow.tasks.evaluator_task.tf.train.get_checkpoint_state"
    ) as get_checkpoint_state_mock:
        get_checkpoint_state_mock.side_effect = lambda *args, **kwargs: checkpoint_state
        assert evaluator_task._get_all_checkpoints("dir") == checkpoints
def download_model(self, checkpoint_dir):
    """Download the latest model checkpoint files from S3 into checkpoint_dir.

    Polls every 2 seconds while a lock file is present in the bucket or while
    the "checkpoint" index file cannot be fetched, so this can block
    indefinitely until a checkpoint becomes available.

    Args:
        checkpoint_dir: Local directory the model files are written into
            (created if missing).

    Returns:
        True once all model files for the latest checkpoint were downloaded,
        False if an unexpected error occurred.
    """
    s3_client = self.get_client()
    logger.info(
        "Downloading pretrained model from %s/%s %s" %
        (self.bucket, self.model_checkpoints_prefix, checkpoint_dir))
    filename = "None"
    try:
        filename = os.path.abspath(
            os.path.join(checkpoint_dir, "checkpoint"))
        if not os.path.exists(checkpoint_dir):
            logger.info("Model folder %s does not exist, creating" % filename)
            os.makedirs(checkpoint_dir)
        while True:
            response = s3_client.list_objects_v2(
                Bucket=self.bucket,
                Prefix=self._get_s3_key(self.lock_file))
            if "Contents" not in response:
                # No writer lock present: try fetching the checkpoint index.
                try:
                    key = self._get_s3_key("checkpoint")
                    logger.info("Downloading %s" % key)
                    s3_client.download_file(Bucket=self.bucket,
                                            Key=key,
                                            Filename=filename)
                except Exception as e:
                    logger.info(
                        "Something went wrong, will retry in 2 seconds %s" % e)
                    time.sleep(2)
                    continue
            else:
                logger.info("Found a lock file %s , waiting" %
                            self._get_s3_key(self.lock_file))
                time.sleep(2)
                continue
            ckpt = CheckpointState()
            if os.path.exists(filename):
                # Fix: the old open(...).read() leaked the file handle; a
                # context manager guarantees it is closed.
                with open(filename, 'r') as ckpt_file:
                    contents = ckpt_file.read()
                text_format.Merge(contents, ckpt)
                rel_path = ckpt.model_checkpoint_path
                # Parses the step number from "<step>_Step...": kept even
                # though the value is unused, because the int() conversion
                # deliberately aborts (via the outer except) on a malformed
                # checkpoint name.
                checkpoint = int(rel_path.split('_Step')[0])
                response = s3_client.list_objects_v2(
                    Bucket=self.bucket,
                    Prefix=self._get_s3_key(rel_path))
                if "Contents" in response:
                    num_files = 0
                    for obj in response["Contents"]:
                        filename = os.path.abspath(
                            os.path.join(
                                checkpoint_dir,
                                obj["Key"].replace(
                                    self.model_checkpoints_prefix, "")))
                        logger.info("Downloading model file %s" % filename)
                        s3_client.download_file(Bucket=self.bucket,
                                                Key=obj["Key"],
                                                Filename=filename)
                        num_files += 1
                    return True
    except Exception as e:
        util.json_format_logger(
            "{} while downloading the model {} from S3".format(e, filename),
            **util.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
    return False
def download_model(self, checkpoint_dir):
    """Download the latest model checkpoint files from S3 into checkpoint_dir.

    Polls every 2 seconds while a lock file is present in the bucket or while
    the "checkpoint" index file cannot be fetched, so this can block
    indefinitely until a checkpoint becomes available.  On S3 client errors
    the process exits gracefully via utils.simapp_exit_gracefully().

    Args:
        checkpoint_dir: Local directory the model files are written into
            (created if missing).

    Returns:
        None.  Completion of the download is signalled by returning normally.
    """
    s3_client = self.get_client()
    filename = "None"
    try:
        filename = os.path.abspath(
            os.path.join(checkpoint_dir, "checkpoint"))
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        while True:
            response = s3_client.list_objects_v2(
                Bucket=self.bucket,
                Prefix=self._get_s3_key(self.lock_file))
            if "Contents" not in response:
                # No writer lock present: try fetching the checkpoint index;
                # on any failure just back off and retry.
                try:
                    s3_client.download_file(
                        Bucket=self.bucket,
                        Key=self._get_s3_key("checkpoint"),
                        Filename=filename)
                except Exception:
                    # (was `as e` with the binding unused)
                    time.sleep(2)
                    continue
            else:
                time.sleep(2)
                continue
            ckpt = CheckpointState()
            if os.path.exists(filename):
                # Fix: the old open(...).read() leaked the file handle; a
                # context manager guarantees it is closed.
                with open(filename, 'r') as ckpt_file:
                    contents = ckpt_file.read()
                text_format.Merge(contents, ckpt)
                rel_path = ckpt.model_checkpoint_path
                # Parses the step number from "<step>_Step...": kept even
                # though the value is unused, because the int() conversion
                # deliberately aborts (via the outer except) on a malformed
                # checkpoint name.
                checkpoint = int(rel_path.split('_Step')[0])
                response = s3_client.list_objects_v2(
                    Bucket=self.bucket,
                    Prefix=self._get_s3_key(rel_path))
                if "Contents" in response:
                    num_files = 0
                    for obj in response["Contents"]:
                        filename = os.path.abspath(
                            os.path.join(
                                checkpoint_dir,
                                obj["Key"].replace(
                                    self.model_checkpoints_prefix, "")))
                        s3_client.download_file(Bucket=self.bucket,
                                                Key=obj["Key"],
                                                Filename=filename)
                        num_files += 1
                    return
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger(
            "Unable to download model {} from {}: {}".format(
                filename, self.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Unable to download model {} from {}: {}".format(
                filename, self.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()