def testConstructWrapperWithExistingEmptyDumpRoot(self):
    """Build the CLI wrapper on a dump root that already exists and is empty.

    The directory is freshly created here, so it contains nothing when the
    wrapper is constructed. The test passes if construction does not raise.
    """
    dump_root = self._tmp_dir
    os.mkdir(dump_root)
    self.assertTrue(os.path.isdir(dump_root))

    local_cli_wrapper.LocalCLIDebugWrapperSession(
        session.Session(),
        dump_root=dump_root,
        log_usage=False)
def testConstructWrapperWithExistingFileDumpRoot(self):
    """Expect a ValueError when dump_root names a regular file."""
    dump_file = os.path.join(self._tmp_dir, "foo")

    # Touch the path so that it exists as a regular file.
    with open(dump_file, "a"):
        pass
    self.assertTrue(os.path.isfile(dump_file))

    with self.assertRaisesRegex(
            ValueError, "dump_root path points to a file"):
        local_cli_wrapper.LocalCLIDebugWrapperSession(
            session.Session(),
            dump_root=dump_file,
            log_usage=False)
def before_run(self, run_context):
    """Set up the debug wrapper and build the run arguments for this run.

    On first use, wraps `run_context.session` in a
    `LocalCLIDebugWrapperSession` and registers every tensor filter held in
    `self._pending_tensor_filters` with it. On every call, increments the
    wrapper's run-call counter, invokes the wrapper's `on_run_start()` and
    stores the returned action in `self._performed_action`.

    Args:
      run_context: Object exposing `session` and `original_args`; the latter
        provides the `fetches` and `feed_dict` forwarded to `on_run_start()`.

    Returns:
      A `session_run_hook.SessionRunArgs` with no fetches and no feeds. Its
      `options` is a new `config_pb2.RunOptions`, which is handed to the
      wrapper's debug or profile decoration helper when the action is
      `DEBUG_RUN` or `PROFILE_RUN` respectively.
    """
    if not self._session_wrapper:
        self._session_wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
            run_context.session,
            ui_type=self._ui_type,
            dump_root=self._dump_root,
            thread_name_filter=self._thread_name_filter)

        # Filters added before the wrapper existed were only queued; hand
        # them over to the newly built wrapper now.
        pending_filters = self._pending_tensor_filters
        for filter_name in pending_filters:
            self._session_wrapper.add_tensor_filter(
                filter_name, pending_filters[filter_name])

    wrapper = self._session_wrapper

    # Count this run before building the request, which reports the count.
    wrapper.increment_run_call_count()

    # Re-express run_context as an OnRunStartRequest so that the wrapper's
    # on_run_start() can be invoked.
    original_args = run_context.original_args
    start_request = framework.OnRunStartRequest(
        original_args.fetches,
        original_args.feed_dict,
        None,
        None,
        wrapper.run_call_count)
    start_response = wrapper.on_run_start(start_request)

    action = start_response.action
    self._performed_action = action

    run_args = session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=config_pb2.RunOptions())

    if action == framework.OnRunStartAction.DEBUG_RUN:
        # pylint: disable=protected-access
        wrapper._decorate_run_options_for_debug(
            run_args.options,
            start_response.debug_urls,
            debug_ops=start_response.debug_ops,
            node_name_regex_whitelist=(
                start_response.node_name_regex_whitelist),
            op_type_regex_whitelist=(
                start_response.op_type_regex_whitelist),
            tensor_dtype_regex_whitelist=(
                start_response.tensor_dtype_regex_whitelist),
            tolerate_debug_op_creation_failures=(
                start_response.tolerate_debug_op_creation_failures))
        # pylint: enable=protected-access
    elif action == framework.OnRunStartAction.PROFILE_RUN:
        # pylint: disable=protected-access
        wrapper._decorate_run_options_for_profile(run_args.options)
        # pylint: enable=protected-access
    elif action == framework.OnRunStartAction.INVOKE_STEPPER:
        # The graph's _finalized flag has to be cleared so that the
        # NodeStepper can insert ops for retrieving TensorHandles.
        # pylint: disable=protected-access
        run_context.session.graph._finalized = False
        # pylint: enable=protected-access

    return run_args
def testConstructWrapperWithExistingNonEmptyDumpRoot(self):
    """Expect a ValueError when dump_root is a directory with contents."""
    # Creating a child directory makes self._tmp_dir non-empty.
    child_dir = os.path.join(self._tmp_dir, "foo")
    os.mkdir(child_dir)
    self.assertTrue(os.path.isdir(child_dir))

    expected_message = "dump_root path points to a non-empty directory"
    with self.assertRaisesRegex(ValueError, expected_message):
        local_cli_wrapper.LocalCLIDebugWrapperSession(
            session.Session(),
            dump_root=self._tmp_dir,
            log_usage=False)
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
                                   input_tensor_key_feed_dict, outdir,
                                   overwrite_flag, tf_debug=False):
    """Runs a SavedModel on the given inputs and prints every output.

    The MetaGraphDef selected by `tag_set` is loaded from the SavedModel, the
    inputs are fed to the SignatureDef named by `signature_def_key`, and all of
    that signature's outputs are fetched and printed. When `outdir` is set,
    each output is additionally written to `<outdir>/<output key>.npy`.

    Args:
      saved_model_dir: Directory containing the SavedModel to execute.
      tag_set: Tag(s) identifying the MetaGraphDef that holds the SignatureDef
          map, as a single string with tags separated by ','. For a tag-set
          with multiple tags, all of them must be given.
      signature_def_key: A SignatureDef key string.
      input_tensor_key_feed_dict: A dictionary mapping input keys to numpy
          ndarrays.
      outdir: Directory to save the outputs to; created if it does not exist.
          A falsy value disables saving.
      overwrite_flag: Boolean; when true, an existing output file with the same
          name may be overwritten.
      tf_debug: Boolean; when true, the session is wrapped in the TensorFlow
          Debugger (TFDBG) CLI wrapper before running the SavedModel.

    Raises:
      ValueError: When any of the input tensor keys is not valid.
      RuntimeError: When an output file already exists and overwriting is not
          enabled.
    """
    meta_graph_def = saved_model_utils.get_meta_graph_def(
        saved_model_dir, tag_set)

    # session.run addresses tensors by name, so the key-based feed dict has to
    # be translated through the signature's input TensorInfo map.
    inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
        meta_graph_def, signature_def_key)

    # Reject the first key that the signature does not declare.
    for requested_key in input_tensor_key_feed_dict:
        if requested_key in inputs_tensor_info.keys():
            continue
        valid_choices = '"' + '", "'.join(inputs_tensor_info.keys()) + '"'
        raise ValueError(
            '"%s" is not a valid input key. Please choose from %s, or use '
            '--show option.' % (requested_key, valid_choices))

    inputs_feed_dict = {}
    for input_key, input_value in input_tensor_key_feed_dict.items():
        inputs_feed_dict[inputs_tensor_info[input_key].name] = input_value

    outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
        meta_graph_def, signature_def_key)

    # A fixed (sorted) key order lets each fetched value be paired with its key.
    output_tensor_keys_sorted = sorted(outputs_tensor_info.keys())
    output_tensor_names_sorted = []
    for sorted_key in output_tensor_keys_sorted:
        output_tensor_names_sorted.append(outputs_tensor_info[sorted_key].name)

    with session.Session(graph=ops_lib.Graph()) as sess:
        loader.load(sess, tag_set.split(','), saved_model_dir)

        if tf_debug:
            sess = local_cli_wrapper.LocalCLIDebugWrapperSession(sess)

        outputs = sess.run(
            output_tensor_names_sorted, feed_dict=inputs_feed_dict)

        for output_tensor_key, output in zip(output_tensor_keys_sorted,
                                             outputs):
            print('Result for output key %s:\n%s' % (output_tensor_key, output))

            # Nothing more to do for this output unless saving was requested.
            if not outdir:
                continue

            # Make sure the destination directory exists.
            if not os.path.isdir(outdir):
                os.makedirs(outdir)
            output_full_path = os.path.join(outdir, output_tensor_key + '.npy')

            # Refuse to clobber an existing file unless overwriting is allowed.
            if os.path.exists(output_full_path) and not overwrite_flag:
                raise RuntimeError(
                    'Output file %s already exists. Add \"--overwrite\" to overwrite'
                    ' the existing output files.' % output_full_path)

            np.save(output_full_path, output)
            print('Output %s is saved to %s' % (output_tensor_key,
                                                output_full_path))
def testConstructWrapper(self):
    """Build the CLI wrapper around a fresh session with usage logging off.

    The test passes if construction does not raise.
    """
    wrapped_session = session.Session()
    local_cli_wrapper.LocalCLIDebugWrapperSession(
        wrapped_session, log_usage=False)
def __init__(self):
    """Load the SavedModel into a new session and keep it on the instance.

    A session is created on a fresh graph and the SavedModel is loaded into
    it. The session is stored in `self._sess`, wrapped in a
    `LocalCLIDebugWrapperSession` first when `tf_debug` is truthy.

    NOTE(review): `tag_set`, `saved_model_dir` and `tf_debug` are not
    parameters; they are presumably resolved from an enclosing scope that is
    not visible here. Confirm against the surrounding definition.
    """
    loaded_session = session.Session(graph=ops_lib.Graph())
    loader.load(loaded_session, tag_set.split(','), saved_model_dir)

    self._sess = (
        local_cli_wrapper.LocalCLIDebugWrapperSession(loaded_session)
        if tf_debug else loaded_session)