def register(self, candidate, name=None):
  """Registers a Python object "candidate" under the given "name".

  The registry entry stores the candidate together with a location tag for
  the call site that performed the registration. A later duplicate
  registration uses that tag to report where the first one came from.

  Args:
    candidate: The candidate object to add to the registry.
    name: An optional string specifying the registry key for the candidate.
      If None (or otherwise falsy), candidate.__name__ will be used.

  Raises:
    KeyError: If same name is used twice.
  """
  if not name:
    name = candidate.__name__
  if name in self._registry:
    previous_location = self._registry[name][_LOCATION_TAG]
    if isinstance(previous_location, str):
      # The earlier registration could not capture a call site and stored a
      # placeholder string. Unpacking it into four fields would raise
      # ValueError and mask the intended KeyError.
      location_desc = previous_location
    else:
      filename, line_number, function_name, _ = previous_location
      location_desc = "%s %s:%d" % (function_name, filename, line_number)
    raise KeyError("Registering two %s with name '%s'! "
                   "(Previous registration was in %s)" %
                   (self._name, name, location_desc))

  logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
  # stack trace is [this_function, Register(), user_function,...]
  # so the user function is #2.
  stack = tf_stack.extract_stack()
  # Guard against a shorter-than-expected stack instead of raising IndexError.
  # Use the highest index that exists, or a placeholder when the stack is empty.
  stack_index = min(2, len(stack) - 1)
  if stack_index >= 0:
    location_tag = tf_stack.convert_stack([stack[stack_index]])[0]
  else:
    location_tag = "UNKNOWN"
  self._registry[name] = {_TYPE_TAG: candidate, _LOCATION_TAG: location_tag}
def _find_index_of_defining_frame_for_op(op):
  """Returns the index in op._traceback of the innermost non-TensorFlow frame.

  Frames are examined from the innermost (last index) towards the outermost
  (index 0). A frame is rejected when its filename contains any of the
  substrings in _BAD_FILE_SUBSTRINGS, which are meant to identify
  TensorFlow-internal files. The first frame that is not rejected is taken to
  be the caller's defining location.

  Args:
    op: the Operation object for which we would like to find the defining
      location.

  Returns:
    Integer index into op._traceback of the innermost frame whose filename
    contains none of the rejected substrings. Returns 0 (the outermost frame)
    when every frame is rejected.
  """
  # pylint: disable=protected-access
  frames = tf_stack.convert_stack(op._traceback)
  # pylint: enable=protected-access

  # Index 0 is the outermost frame, so walk indices from high to low.
  for index in reversed(range(len(frames))):
    path = frames[index][tf_stack.TB_FILENAME]
    looks_internal = any(marker in path for marker in _BAD_FILE_SUBSTRINGS)
    if not looks_internal:
      return index
  return 0
def traceback_files_common_prefix(all_ops):
  """Computes the directory prefix shared by all traceback file paths.

  Every frame of every op's stored traceback contributes its filename. Names
  containing "<embedded" are skipped. The character-wise common prefix of the
  remaining names is then trimmed back to its directory part. For example,
  the paths '/foo/bar/baz/' and '/foo/car' yield '/foo'.

  Args:
    all_ops: All the input nodes in the form of a list of lists of ops. Inner
      entries that are None are ignored.

  Returns:
    The common prefix. This is '' when no filenames were collected.
  """

  def _all_filenames():
    # Yields the filename of every traceback frame of every op, in order.
    for op_group in all_ops:
      if op_group is None:
        continue
      for op in op_group:
        # pylint: disable=protected-access
        frames = tf_stack.convert_stack(op._traceback)
        # pylint: enable=protected-access
        for frame in frames:
          yield frame[tf_stack.TB_FILENAME]

  unique_paths = {path for path in _all_filenames() if "<embedded" not in path}
  shared = os.path.commonprefix(list(unique_paths))
  head, _ = os.path.split(shared)
  return head
def register(self, candidate, name=None):
  """Registers a Python object "candidate" under the given "name".

  The registry entry stores the candidate together with a location tag for
  the call site that performed the registration. A later duplicate
  registration uses that tag to report where the first one came from.

  Args:
    candidate: The candidate object to add to the registry.
    name: An optional string specifying the registry key for the candidate.
      If None (or otherwise falsy), candidate.__name__ will be used.

  Raises:
    KeyError: If same name is used twice.
  """
  if not name:
    name = candidate.__name__
  if name in self._registry:
    previous_location = self._registry[name][_LOCATION_TAG]
    if isinstance(previous_location, str):
      # The earlier registration stored the "UNKNOWN" placeholder (see the
      # fallback below). Unpacking that string into four fields would raise
      # ValueError and mask the intended KeyError.
      location_desc = previous_location
    else:
      filename, line_number, function_name, _ = previous_location
      location_desc = "%s %s:%d" % (function_name, filename, line_number)
    raise KeyError("Registering two %s with name '%s'! "
                   "(Previous registration was in %s)" %
                   (self._name, name, location_desc))

  logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
  # stack trace is [this_function, Register(), user_function,...]
  # so the user function is #2.
  stack = tf_stack.extract_stack()
  # Guard against a shorter-than-expected stack instead of raising IndexError.
  # Use the highest index that exists, or a placeholder when the stack is empty.
  stack_index = min(2, len(stack) - 1)
  if stack_index >= 0:
    user_function = stack[stack_index]
    location_tag = tf_stack.convert_stack([user_function])[0]
  else:
    location_tag = "UNKNOWN"
  self._registry[name] = {_TYPE_TAG: candidate, _LOCATION_TAG: location_tag}
def _compact_stack_trace(op):
  """Returns op's traceback frames with the shared file prefix removed.

  The prefix is obtained from
  error_interpolation.traceback_files_common_prefix for this single op. Each
  frame is returned as a tuple. Its filename field has that prefix sliced off
  when the filename starts with it, and is left unchanged otherwise.
  """
  shared_prefix = error_interpolation.traceback_files_common_prefix([[op]])

  def _strip_prefix(frame):
    # Copies the frame into a tuple, trimming the prefix from its filename.
    fields = list(frame)
    path = fields[tf_stack.TB_FILENAME]
    if path.startswith(shared_prefix):
      path = path[len(shared_prefix):]
    fields[tf_stack.TB_FILENAME] = path
    return tuple(fields)

  # pylint: disable=protected-access
  frames = tf_stack.convert_stack(op._traceback)
  # pylint: enable=protected-access
  return [_strip_prefix(frame) for frame in frames]