def _check_trt_version_compatibility():
  """Check compatibility of TensorRT version.

  Compares the TensorRT version TensorFlow was linked against with the version
  loaded at runtime. A loaded major version older than the linked major
  version is fatal. Any other component-wise difference (over the common
  prefix of the two version tuples) only produces a warning.

  Raises:
    RuntimeError: if the TensorRT library version is incompatible.
  """
  compiled_version = get_linked_tensorrt_version()
  loaded_version = get_loaded_tensorrt_version()
  tf_logging.info("Linked TensorRT version: %s", str(compiled_version))
  tf_logging.info("Loaded TensorRT version: %s", str(loaded_version))

  def _dotted(version):
    """Formats a version tuple such as (5, 1, 2) as '5.1.2'."""
    return ".".join(str(x) for x in version)

  if loaded_version[0] < compiled_version[0]:
    tf_logging.error(
        "TensorRT version mismatch. Tensorflow was compiled against "
        "TensorRT %s but library loaded from environment is TensorRT %s"
        ". Please make sure that correct version of TensorRT "
        "is available in the system and added to ldconfig or LD_LIBRARY_PATH",
        _dotted(compiled_version), _dotted(loaded_version))
    raise RuntimeError("Incompatible TensorRT library version")

  # zip() stops at the shorter tuple, so only the shared components are
  # compared.
  version_mismatch = any(
      loaded != compiled
      for loaded, compiled in zip(loaded_version, compiled_version))
  if version_mismatch:
    # `warning` rather than the deprecated `warn` alias.
    tf_logging.warning(
        "TensorRT mismatch. Compiled against version %s, but loaded %s. "
        "Things may not work", _dotted(compiled_version),
        _dotted(loaded_version))
  else:
    tf_logging.info("Running against TensorRT version %s",
                    _dotted(loaded_version))
def run_simple_server(tb_app):
  """Start serving TensorBoard, and print some messages to console.

  Binds a threaded werkzeug server to FLAGS.host / FLAGS.port and blocks in
  `serve_forever()`. If the port cannot be bound, an error is logged and
  printed, and the process exits with status -1.

  Args:
    tb_app: The TensorBoard WSGI application to serve. Its `tag` attribute is
      included in the startup message.
  """
  # Mute the werkzeug logging.
  base_logging.getLogger('werkzeug').setLevel(base_logging.WARNING)

  try:
    server = serving.make_server(FLAGS.host, FLAGS.port, tb_app, threaded=True)
    server.daemon_threads = True
  except socket.error:
    if FLAGS.port == 0:
      msg = 'TensorBoard unable to find any open port'
    else:
      msg = (
          'TensorBoard attempted to bind to port %d, but it was already in use'
          % FLAGS.port)
    logging.error(msg)
    print(msg)
    # Use sys.exit: the bare `exit` builtin is injected by the `site` module
    # and is not guaranteed to exist (e.g. `python -S`, frozen executables).
    sys.exit(-1)

  # When FLAGS.port is 0 the OS picks a port; report the one actually bound.
  port = server.socket.getsockname()[1]
  msg = 'Starting TensorBoard %s at http://%s:%d' % (tb_app.tag, FLAGS.host,
                                                   port)
  print(msg)
  logging.info(msg)
  print('(Press CTRL+C to quit)')
  sys.stdout.flush()

  server.serve_forever()
def _AddOpInternal(self, op):
  """Registers a newly created op with this TPU replication context.

  Logs an error for op types in `_BLACKLISTED_OPS` and records op types in
  `_NOT_IMPLEMENTED_OPS` in `self._unsupported_ops`. Tags the op with the
  replicate attribute, plus the outside-compilation attribute when an
  outside-compilation cluster is active. Blocks feeding and fetching the op
  unless it is a non-replicated outside-compilation op.

  Args:
    op: The `Operation` being added.

  Raises:
    NotImplementedError: if any input of `op` has a reference dtype.
    ValueError: if `op` already carries the replicate attribute (nested TPU
      computation).
  """
  # pylint: disable=protected-access
  op_type = op.type
  if op_type in _BLACKLISTED_OPS:
    logging.error("Operation of type %s (%s) is not supported on the TPU. "
                  "Execution will fail if this op is used in the graph. " %
                  (op_type, op.name))

  if op_type in _NOT_IMPLEMENTED_OPS:
    self._unsupported_ops.append(op)

  has_ref_input = any(tensor.dtype._is_ref_dtype for tensor in op.inputs)
  if has_ref_input:
    raise NotImplementedError(
        "Non-resource Variables are not supported inside TPU computations "
        "(operator name: %s)" % op.name)

  if _TPU_REPLICATE_ATTR in op.node_def.attr:
    raise ValueError("TPU computations cannot be nested")

  replicate_value = attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))
  op._set_attr(_TPU_REPLICATE_ATTR, replicate_value)

  cluster = self._outside_compilation_cluster
  if cluster:
    op._set_attr(_OUTSIDE_COMPILATION_ATTR,
                 attr_value_pb2.AttrValue(s=compat.as_bytes(cluster)))

  # Prevent feeding or fetching anything that is being compiled,
  # and any replicated outside_compilation Op.
  must_isolate = self._num_replicas > 1 or not cluster
  if must_isolate:
    op.graph.prevent_feeding(op)
    op.graph.prevent_fetching(op)
  # pylint: enable=protected-access
def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of latest saved checkpoint file.

  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that
      contains the list of most recent checkpoint filenames.
      See the corresponding argument to `Saver.save()`.

  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was found.
  """
  # Pick the latest checkpoint based on checkpoint state.
  state = get_checkpoint_state(checkpoint_dir, latest_filename)
  if not (state and state.model_checkpoint_path):
    return None

  prefix = state.model_checkpoint_path
  # Look for either a V2 path or a V1 path, with priority for V2.
  v2_pattern = _prefix_to_checkpoint_path(prefix, saver_pb2.SaverDef.V2)
  v1_pattern = _prefix_to_checkpoint_path(prefix, saver_pb2.SaverDef.V1)
  if (file_io.get_matching_files(v2_pattern) or
      file_io.get_matching_files(v1_pattern)):
    return prefix

  # The state file names a checkpoint whose data files are missing.
  logging.error("Couldn't match files for checkpoint %s", prefix)
  return None
def get_global_step(graph=None):
  """Get the global step tensor.

  The global step tensor must be an integer variable. It is looked up first in
  the `GLOBAL_STEP` collection and then by the name `global_step:0`.

  Args:
    graph: The graph to find the global step in. If missing, use default graph.

  Returns:
    The global step variable, or `None` if none was found (including when the
    collection holds more than one tensor, which is logged as an error).

  Raises:
    TypeError: If the global step tensor has a non-integer type, or if it is not
      a `Variable`.
  """
  if graph is None:
    graph = ops.get_default_graph()

  candidates = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
  if len(candidates) > 1:
    logging.error('Multiple tensors in global_step collection.')
    return None

  if candidates:
    global_step_tensor = candidates[0]
  else:
    try:
      global_step_tensor = graph.get_tensor_by_name('global_step:0')
    except KeyError:
      return None

  assert_global_step(global_step_tensor)
  return global_step_tensor
def get_global_counter(collection, name, graph=None):
  """Get the global counter tensor.

  The global counter tensor must be an integer variable. It is looked up first
  in `collection` and then by tensor `name`.

  Args:
    collection: the counter's collection.
    name: the counter's name.
    graph: The graph to find the global counter in. If missing, use default
      graph.

  Returns:
    The global counter variable, or `None` if none was found (including when
    the collection holds more than one tensor, which is logged as an error).

  Raises:
    TypeError: If the global counter tensor has a non-integer type,
      or if it is not a `Variable`.
  """
  graph = graph or tf.get_default_graph()

  counters = graph.get_collection(collection)
  if len(counters) > 1:
    logging.error('Multiple tensors in `{}` collection.'.format(collection))
    return None

  if counters:
    counter = counters[0]
  else:
    try:
      counter = graph.get_tensor_by_name(name)
    except KeyError:
      return None

  assert_global_counter(counter)
  return counter
def record_error(self, source, exc_info, session=None):
  """Report an exception from the given source.

  If a session is passed, a timer will be registered to close it after a few
  seconds. This is necessary to ensure the main training loop does not hang
  if an infeed/outfeed error occurs. We sleep a few seconds to allow a more
  interesting error from another thread to propagate.

  Args:
    source: string, source of the error
    exc_info: Output from `sys.exc_info` (type, value, traceback)
    session: Session to close after delay.
  """
  _, value, _ = exc_info
  self._errors[source] = exc_info
  logging.error('Error recorded from %s: %s', source, value)

  # Skip if a cancellation thread has already been started.
  if session is not None and self._session_cancel_timer is None:

    def _cancel_session():
      time.sleep(5)
      logging.error('Closing session due to error %s', value)
      try:
        session.close()
      except Exception:  # pylint: disable=broad-except
        # Best effort: we are already handling an error, so only report this.
        logging.error(
            '\n\n\nFailed to close session after error. '
            'Other threads may hang.\n\n\n')

    self._session_cancel_timer = threading.Thread(target=_cancel_session)
    # Daemon thread so it never blocks interpreter shutdown.
    self._session_cancel_timer.daemon = True
    self._session_cancel_timer.start()
def _check_dtypes(value, dtype):
  """Logs an error, including the current stack, if `value` has the wrong dtype.

  A mismatch is only reported: nothing is raised and nothing is returned.

  Args:
    value: Object exposing a `dtype` attribute.
    dtype: The expected dtype.
  """
  actual = value.dtype
  if actual != dtype:
    template = (
        "Error: Input value {} has dtype {}, but expected dtype {}. "
        "This leads to undefined behavior and will be an error "
        "in future versions of TensorFlow. Traceback:\n{}")
    stack_text = "".join(traceback.format_stack())
    logging.error(template.format(value, str(actual), str(dtype), stack_text))
def __exit__(self, exec_type, exec_value, exec_tb):
  """Leaves the session context and closes the session.

  Exits the entered context managers in reverse order, then calls
  `self.close()`.

  Args:
    exec_type: Exception type raised in the `with` body, or None.
    exec_value: The exception instance, or None.
    exec_tb: The traceback, or None.
  """
  # Specific TF errors (InvalidArgumentError, OutOfRangeError, ...) are
  # subclasses of OpError, so an identity check against the base class would
  # never fire for them.
  if exec_type is not None and issubclass(exec_type, errors.OpError):
    # Log the exception itself, not a 1-tuple wrapping it.
    logging.error('Session closing due to OpError: %s', exec_value)
  for context_manager in reversed(self._context_managers):
    context_manager.__exit__(exec_type, exec_value, exec_tb)

  self.close()
def __call__(self, path, parent, children):
  """Builds an API proto for `parent` and stores it in `self._protos`.

  Args:
    path: Dotted path of `parent` below `tensorflow` (empty for the root).
    parent: The module or class being visited.
    children: Iterable of (name, object) pairs for the members of `parent`.
  """
  # The path to the object.
  lib_path = 'tensorflow.%s' % path if path else 'tensorflow'

  # A small helper method to construct members(children) protos.
  def _AddMember(member_name, member_obj, proto):
    """Add the child object to the object being constructed."""
    _, member_obj = tf_decorator.unwrap(member_obj)
    if member_name == '__init__' or not member_name.startswith('_'):
      if tf_inspect.isroutine(member_obj):
        new_method = proto.member_method.add()
        new_method.name = member_name
        # If member_obj is a python builtin, there is no way to get its
        # argspec, because it is implemented on the C side. It also has no
        # func_code.
        # NOTE(review): `func_code` exists only on Python 2 functions; on
        # Python 3 this branch may never run. Confirm before relying on
        # recorded argspecs.
        if getattr(member_obj, 'func_code', None):
          new_method.argspec = _SanitizedArgSpec(member_obj)
      else:
        new_member = proto.member.add()
        new_member.name = member_name
        new_member.mtype = str(type(member_obj))

  parent_corner_cases = _CORNER_CASES.get(path, {})

  def _AddChildren(proto):
    """Adds every child of `parent` to `proto`, honoring corner cases."""
    for name, child in children:
      if name in parent_corner_cases:
        # If we have an empty entry, skip this object.
        if parent_corner_cases[name]:
          proto.member.add(**(parent_corner_cases[name]))
      else:
        _AddMember(name, child, proto)

  # A path listed in _CORNER_CASES with an empty mapping is skipped entirely.
  if path not in _CORNER_CASES or parent_corner_cases:
    # Decide if we have a module or a class.
    if tf_inspect.ismodule(parent):
      # Create a module object.
      module_obj = api_objects_pb2.TFAPIModule()
      _AddChildren(module_obj)
      # Store the constructed module object.
      self._protos[lib_path] = api_objects_pb2.TFAPIObject(
          path=lib_path, tf_module=module_obj)
    elif tf_inspect.isclass(parent):
      # Construct a class. Corner-case members are added to this class proto
      # (previously they were wrongly added to an undefined `module_obj`).
      class_obj = api_objects_pb2.TFAPIClass()
      class_obj.is_instance.extend(_SanitizedMRO(parent))
      _AddChildren(class_obj)
      # Store the constructed class object.
      self._protos[lib_path] = api_objects_pb2.TFAPIObject(
          path=lib_path, tf_class=class_obj)
    else:
      logging.error('Illegal call to ApiProtoDump::_py_obj_to_proto. '
                    'Object is neither a module nor a class: %s', path)
def make_simple_server(tb_app, host, port):
  """Create an HTTP server for TensorBoard.

  Args:
    tb_app: The TensorBoard WSGI application to create a server for.
    host: Indicates the interfaces to bind to ('::' or '0.0.0.0' for all
        interfaces, '::1' or '127.0.0.1' for localhost). A blank value ('')
        indicates protocol-agnostic all interfaces.
    port: The port to bind to (0 indicates an unused port selected by the
        operating system).

  Returns:
    A tuple of (server, url):
      server: An HTTP server object configured to host TensorBoard.
      url: A best guess at a URL where TensorBoard will be accessible once the
        server has been started.

  Raises:
    socket.error: If a server could not be constructed with the host and port
      specified. Also logs an error message.
  """
  # Mute the werkzeug logging.
  base_logging.getLogger('werkzeug').setLevel(base_logging.WARNING)

  try:
    if host:
      # The user gave us an explicit host
      server = serving.make_server(host, port, tb_app, threaded=True)
      if ':' in host and not host.startswith('['):
        # Display IPv6 addresses as [::1]:80 rather than ::1:80
        final_host = '[{}]'.format(host)
      else:
        final_host = host
    else:
      # We've promised to bind to all interfaces on this host. However, we're
      # not sure whether that means IPv4 or IPv6 interfaces.
      try:
        # First try passing in a blank host (meaning all interfaces). This,
        # unfortunately, defaults to IPv4 even if no IPv4 interface is
        # available (yielding a socket.error).
        server = serving.make_server(host, port, tb_app, threaded=True)
      except socket.error:
        # If a blank host didn't work, we explicitly request IPv6 interfaces.
        server = serving.make_server('::', port, tb_app, threaded=True)
      final_host = socket.gethostname()
    server.daemon_threads = True
  except socket.error:
    if port == 0:
      msg = 'TensorBoard unable to find any open port'
    else:
      # Report the port this function was asked to bind, not a global flag.
      msg = (
          'TensorBoard attempted to bind to port %d, but it was already in use'
          % port)
    logging.error(msg)
    print(msg)
    # Bare raise re-raises the active socket.error with its traceback intact.
    raise

  # When port is 0 the OS picks one; report the port actually bound.
  final_port = server.socket.getsockname()[1]
  tensorboard_url = 'http://%s:%d' % (final_host, final_port)
  return server, tensorboard_url
def _cancel_session():
  """Waits a few seconds, then closes the session after a recorded error.

  NOTE(review): `value` and `session` are free variables here, presumably bound
  by an enclosing function such as `record_error`. Confirm at the definition
  site.
  """
  # Give a more informative error from another thread time to propagate.
  time.sleep(5)
  logging.error('Closing session due to error %s', value)
  try:
    session.close()
  except Exception:  # pylint: disable=broad-except
    # Best effort: we are already handling an error, so only report this one.
    logging.error(
        '\n\n\nFailed to close session after error. '
        'Other threads may hang.\n\n\n')
def after_run(self, run_context, run_values):
  """Checks the fetched loss for NaN after each session run.

  If the loss is NaN and `self._fail_on_nan_loss` is set, logs an error and
  raises `NanLossDuringTrainingError`. Otherwise it logs a warning and asks
  `run_context` to stop without raising.
  """
  if not np.isnan(run_values.results):
    return

  failure_message = "Model diverged with loss = NaN."
  if not self._fail_on_nan_loss:
    logging.warning(failure_message)
    # We don't raise an error but we request stop without an exception.
    run_context.request_stop()
    return

  logging.error(failure_message)
  raise NanLossDuringTrainingError
def __exit__(self, exec_type, exec_value, exec_tb):
  """Leaves the default session and graph contexts, then closes the session.

  Args:
    exec_type: Exception type raised in the `with` body, or None.
    exec_value: The exception instance, or None.
    exec_tb: The traceback, or None.
  """
  # Specific TF errors (InvalidArgumentError, OutOfRangeError, ...) are
  # subclasses of OpError, so an identity check against the base class would
  # never fire for them.
  if exec_type is not None and issubclass(exec_type, errors.OpError):
    # Log the exception itself, not a 1-tuple wrapping it.
    logging.error('Session closing due to OpError: %s', exec_value)
  self._default_session_context_manager.__exit__(
      exec_type, exec_value, exec_tb)
  self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb)

  # Drop the references so the managers are not reused.
  self._default_session_context_manager = None
  self._default_graph_context_manager = None

  self.close()
def every_n_step_end(self, step, outputs):
  """Checks the monitored loss for NaN at the end of every N-th step.

  Returns:
    True (meaning "should stop") when the loss is NaN and
    `self._fail_on_nan_loss` is False; otherwise None.

  Raises:
    NanLossDuringTrainingError: if the loss is NaN and `self._fail_on_nan_loss`
      is set.
  """
  super(NanLoss, self).every_n_step_end(step, outputs)
  loss_value = _extract_output(outputs, self._loss_tensor)
  if not np.isnan(loss_value):
    return None

  failure_message = "Model diverged with loss = NaN."
  if not self._fail_on_nan_loss:
    logging.warning(failure_message)
    # We don't raise an error but we return "should stop" so we stop, but
    # without an exception.
    return True

  logging.error(failure_message)
  raise NanLossDuringTrainingError
def Reload(self):
  """Call `Reload` on every `EventAccumulator`.

  An OSError/IOError from one accumulator is logged and does not prevent the
  remaining accumulators from reloading.

  Returns:
    self, to allow chaining.
  """
  self._reload_called = True
  # Snapshot the accumulators under the lock, so the mapping may be modified
  # while we reload without breaking the iteration.
  with self._accumulators_mutex:
    snapshot = list(self._accumulators.items())

  for accumulator_name, accumulator in snapshot:
    try:
      accumulator.Reload()
    except (OSError, IOError) as err:
      logging.error("Unable to reload accumulator '%s': %s",
                    accumulator_name, err)
  return self
def test_function_alias(self, mock_warning):
  """A deprecated alias warns once, and the warning names the calling file."""
  deprecated_func = deprecation.deprecated_alias("deprecated.func",
                                                 "real.func",
                                                 logging.error)

  # Logging directly must not be counted as a deprecation warning.
  logging.error("fake error logged")
  self.assertEqual(0, mock_warning.call_count)

  deprecated_func("FAKE ERROR!")
  self.assertEqual(1, mock_warning.call_count)
  # Make sure the error points to the right file. assertRegex replaces the
  # deprecated assertRegexpMatches alias, which was removed in Python 3.12.
  self.assertRegex(mock_warning.call_args[0][1], r"deprecation_test\.py:")

  # A second call through the alias must not warn again.
  deprecated_func("ANOTHER FAKE ERROR!")
  self.assertEqual(1, mock_warning.call_count)
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
  """Convert an existing calibration graph to inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration data
    is_dynamic_op: whether to create dynamic static engines from calibration

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes,
    or None (after logging an error) if no TRTEngineOp node has empty
    `calibration_data`.

  Raises:
    RuntimeError: if the returned status message is malformed.
    Other errors built from a non-OK status via `_impl`.
  """

  def to_string(raw):
    # Python 2 strings pass through; Python 3 bytes are decoded as UTF-8.
    return raw if _six.PY2 else raw.decode("utf-8")

  # A calibration graph has at least one TRTEngineOp whose calibration data is
  # still empty.
  is_calib_graph = any(
      node.op == "TRTEngineOp" and not node.attr["calibration_data"].s
      for node in calibration_graph_def.node)
  if not is_calib_graph:
    tf_logging.error(
        "Not a calib graph. Doesn't seem to contain any calibration nodes.")
    return None

  graph_str = calibration_graph_def.SerializeToString()
  out = calib_convert(graph_str, is_dynamic_op)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del graph_str  # Save some memory

  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    # Non-OK statuses are encoded as "<code>;<message>".
    parts = status.split(";")
    if len(parts) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(parts[1:]),
                                         int(parts[0]))
    # pylint: enable=protected-access

  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
def _HasOOOWrite(self, path):
  """Returns whether the path has had an out-of-order write.

  A write is out of order when the current size of `path` differs from the
  size recorded for it in `self._finalized_sizes`. A path with no recorded
  size counts as newly created.
  """
  current_size = io_wrapper.Size(path)
  recorded_size = self._finalized_sizes.get(path, None)
  if current_size == recorded_size:
    return False

  if recorded_size is None:
    logging.error('File %s created after file %s even though it\'s '
                  'lexicographically earlier', path, self._path)
  else:
    logging.error('File %s updated even though the current file is %s',
                  path, self._path)
  return True
def _run(self, sess, enqueue_op, feed_fn, coord=None):
  """Execute the enqueue op in a loop, close the queue in case of error.

  Args:
    sess: A `Session`.
    enqueue_op: The `Operation` to run.
    feed_fn: the feed function to pass to `sess.run`.
    coord: Optional `Coordinator` object for reporting errors and checking
      for stop conditions.
  """
  # TODO(jamieas): Reduce code duplication with `QueueRunner`.
  if coord:
    coord.register_thread(threading.current_thread())
  # Tracks whether this thread already decremented the per-session count, so
  # the `finally` clause does not decrement twice.
  already_counted = False
  try:
    while not (coord and coord.should_stop()):
      try:
        feed_dict = feed_fn() if feed_fn is not None else None
        sess.run(enqueue_op, feed_dict=feed_dict)
      except (errors.OutOfRangeError, errors.CancelledError):
        # This exception indicates that a queue was closed.
        with self._lock:
          self._runs_per_session[sess] -= 1
          already_counted = True
          # The last runner thread for this session closes the queue.
          if self._runs_per_session[sess] == 0:
            try:
              sess.run(self._close_op)
            except Exception as close_err:
              # Intentionally ignore errors from close_op.
              logging.vlog(1, "Ignored exception: %s", str(close_err))
          return
  except Exception as err:
    # This catches all other exceptions.
    if coord:
      coord.request_stop(err)
    else:
      logging.error("Exception in QueueRunner: %s", str(err))
      with self._lock:
        self._exceptions_raised.append(err)
      raise
  finally:
    # Make sure we account for all terminations: normal or errors.
    if not already_counted:
      with self._lock:
        self._runs_per_session[sess] -= 1
def from_devices(session, devices):
  """Construct a heartbeat manager for the given devices.

  One heartbeat op is created per device, each placed on its device and fed
  from a shared string placeholder. An empty `devices` is logged as an error,
  but construction still proceeds.

  Args:
    session: Session passed through to `WorkerHeartbeatManager`.
    devices: Device names to create heartbeat ops on.

  Returns:
    A `WorkerHeartbeatManager`.
  """
  if not devices:
    logging.error('Trying to create heartbeat manager with no devices?')

  logging.info('Creating heartbeat manager for %s', devices)
  request_placeholder = array_ops.placeholder(
      name='worker_heartbeat_request', dtype=dtypes.string)

  def _heartbeat_op_for(device):
    with ops.device(device):
      return tpu_ops.worker_heartbeat(request_placeholder)

  heartbeat_ops = [_heartbeat_op_for(device) for device in devices]
  return WorkerHeartbeatManager(session, devices, heartbeat_ops,
                                request_placeholder)
def CheckIsSupported():
  """Raises an OSError if the system isn't set up for Google Cloud Storage.

  Support is probed by running `gsutil version`.

  Raises:
    OSError: If the system hasn't been set up so that TensorBoard can access
      Google Cloud Storage. The error's message contains installation
      instructions.
  """
  probe_command = ['gsutil', 'version']
  try:
    subprocess.check_output(probe_command)
  except OSError as launch_error:
    logging.error('Error while checking for gsutil: %s', launch_error)
    raise OSError(
        'Unable to execute the gsutil binary, which is required for Google '
        'Cloud Storage support. You can find installation instructions at '
        'https://goo.gl/sST520')
def _convert_default_signature_to_signature_def(signatures):
  """Convert default signature to object of type SignatureDef.

  Args:
    signatures: object of type manifest_pb2.Signatures()

  Returns:
    object of type SignatureDef which contains a converted version of default
    signature from input signatures object

    Returns None if signature is of generic type because it cannot be
    converted to SignatureDef.
  """
  default_signature = signatures.default_signature
  signature_def = meta_graph_pb2.SignatureDef()
  signature_type = default_signature.WhichOneof("type")

  if signature_type == legacy_constants.REGRESSION_SIGNATURE:
    regression = default_signature.regression_signature
    signature_def.method_name = signature_constants.REGRESS_METHOD_NAME
    _add_input_to_signature_def(regression.input.tensor_name,
                                signature_constants.REGRESS_INPUTS,
                                signature_def)
    _add_output_to_signature_def(regression.output.tensor_name,
                                 signature_constants.REGRESS_OUTPUTS,
                                 signature_def)
    return signature_def

  if signature_type == legacy_constants.CLASSIFICATION_SIGNATURE:
    classification = default_signature.classification_signature
    signature_def.method_name = signature_constants.CLASSIFY_METHOD_NAME
    _add_input_to_signature_def(classification.input.tensor_name,
                                signature_constants.CLASSIFY_INPUTS,
                                signature_def)
    _add_output_to_signature_def(classification.classes.tensor_name,
                                 signature_constants.CLASSIFY_OUTPUT_CLASSES,
                                 signature_def)
    _add_output_to_signature_def(classification.scores.tensor_name,
                                 signature_constants.CLASSIFY_OUTPUT_SCORES,
                                 signature_def)
    return signature_def

  # Any other signature type has no SignatureDef equivalent.
  logging.error(
      "Only classification and regression default signatures "
      "are supported for up-conversion. %s is not "
      "supported", signature_type)
  return None
def _run(self, sess, enqueue_op, coord=None):
  """Execute the enqueue op in a loop, close the queue in case of error.

  Args:
    sess: A Session.
    enqueue_op: The Operation to run.
    coord: Optional Coordinator object for reporting errors and checking
      for stop conditions.
  """
  # Tracks whether this thread already decremented the per-session count, so
  # the `finally` clause does not decrement twice.
  already_counted = False
  try:
    # Make a cached callable from the `enqueue_op` to decrease the
    # Python overhead in the queue-runner loop.
    run_enqueue = sess.make_callable(enqueue_op)
    while not (coord and coord.should_stop()):
      try:
        run_enqueue()
      except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
        # This exception indicates that a queue was closed.
        with self._lock:
          self._runs_per_session[sess] -= 1
          already_counted = True
          # The last runner thread for this session closes the queue.
          if self._runs_per_session[sess] == 0:
            try:
              sess.run(self._close_op)
            except Exception as close_err:
              # Intentionally ignore errors from close_op.
              logging.vlog(1, "Ignored exception: %s", str(close_err))
          return
  except Exception as err:
    # This catches all other exceptions.
    if coord:
      coord.request_stop(err)
    else:
      logging.error("Exception in QueueRunner: %s", str(err))
      with self._lock:
        self._exceptions_raised.append(err)
      raise
  finally:
    # Make sure we account for all terminations: normal or errors.
    if not already_counted:
      with self._lock:
        self._runs_per_session[sess] -= 1
def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
  """Restores collections that we need to keep.

  Copies each collection named in `collection_keys` from `src_meta_graph_def`
  into `dest_graph`. Entries that cannot be reconstructed in `dest_graph` are
  skipped.

  Args:
    dest_graph: Graph that receives the collection entries.
    src_meta_graph_def: MetaGraphDef whose `collection_def` is read.
    collection_keys: Names of the collections to restore.
  """
  scope = ""
  for key in collection_keys:
    collection_def = src_meta_graph_def.collection_def[key]
    kind = collection_def.WhichOneof("kind")
    if kind is None:
      tf_logging.error(
          "Cannot identify data type for collection %s. Skipping.", key)
      continue
    from_proto = ops.get_from_proto_function(key)
    if from_proto and kind == "bytes_list":
      proto_type = ops.get_collection_proto_type(key)
      # It is assumed that there are no Variables Keys in collections
      for value in collection_def.bytes_list.value:
        proto = proto_type()
        proto.ParseFromString(value)
        try:
          new_value = from_proto(proto, import_scope=scope)
        except Exception:  # pylint: disable=broad-except
          # Best effort: skip entries that cannot be rebuilt, presumably
          # because the graph elements they reference were optimized away.
          # Catching Exception (not a bare except) lets KeyboardInterrupt and
          # SystemExit propagate.
          continue
        dest_graph.add_to_collection(key, new_value)
    else:
      field = getattr(collection_def, kind)
      if kind == "node_list":
        for value in field.value:
          name = ops.prepend_name_scope(value, scope)
          # Since the graph has been optimized, the node may no longer
          # exists
          try:
            col_op = dest_graph.as_graph_element(name)
          except (TypeError, ValueError, KeyError):
            continue
          dest_graph.add_to_collection(key, col_op)
      elif kind == "int64_list":
        # NOTE(opensource): This force conversion is to work around the
        # fact that Python2 distinguishes between int and long, while
        # Python3 has only int.
        for value in field.value:
          dest_graph.add_to_collection(key, int(value))
      else:
        for value in field.value:
          dest_graph.add_to_collection(key,
                                       ops.prepend_name_scope(value, scope))
def _run(self, sess, enqueue_op, coord=None):
  """Execute the enqueue op in a loop, close the queue in case of error.

  Args:
    sess: A Session.
    enqueue_op: The Operation to run.
    coord: Optional Coordinator object for reporting errors and checking
      for stop conditions.
  """
  if coord:
    coord.register_thread(threading.current_thread())
  # Tracks whether this thread already decremented `self._runs`, so the
  # `finally` clause does not decrement twice.
  already_counted = False
  try:
    while not (coord and coord.should_stop()):
      try:
        sess.run(enqueue_op)
      except errors.OutOfRangeError:
        # This exception indicates that a queue was closed.
        with self._lock:
          self._runs -= 1
          already_counted = True
          # The last runner thread closes the queue.
          if self._runs == 0:
            try:
              sess.run(self._close_op)
            except Exception as close_err:
              # Intentionally ignore errors from close_op.
              logging.vlog(1, "Ignored exception: %s", str(close_err))
          return
  except Exception as err:
    # This catches all other exceptions.
    if coord:
      coord.request_stop(err)
    else:
      logging.error("Exception in QueueRunner: %s", str(err))
      with self._lock:
        self._exceptions_raised.append(err)
      raise
  finally:
    # Make sure we account for all terminations: normal or errors.
    if not already_counted:
      with self._lock:
        self._runs -= 1
def _run(self, sess, enqueue_op, coord=None):
  """Execute the enqueue op in a loop, close the queue in case of error.

  Args:
    sess: A Session.
    enqueue_op: The Operation to run.
    coord: Optional Coordinator object for reporting errors and checking
      for stop conditions.
  """
  # Tracks whether this thread already decremented the per-session count, so
  # the `finally` clause does not decrement twice.
  already_counted = False
  try:
    while not (coord and coord.should_stop()):
      try:
        sess.run(enqueue_op)
      except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
        # This exception indicates that a queue was closed.
        with self._lock:
          self._runs_per_session[sess] -= 1
          already_counted = True
          # The last runner thread for this session closes the queue.
          if self._runs_per_session[sess] == 0:
            try:
              sess.run(self._close_op)
            except Exception as close_err:
              # Intentionally ignore errors from close_op.
              logging.vlog(1, "Ignored exception: %s", str(close_err))
          return
  except Exception as err:
    # This catches all other exceptions.
    if coord:
      coord.request_stop(err)
    else:
      logging.error("Exception in QueueRunner: %s", str(err))
      with self._lock:
        self._exceptions_raised.append(err)
      raise
  finally:
    # Make sure we account for all terminations: normal or errors.
    if not already_counted:
      with self._lock:
        self._runs_per_session[sess] -= 1
def serve(self):
  """Starts a WSGI server that serves the TensorBoard app.

  Returns:
    None when the server returns normally; -2 if binding fails with
    FLAGS.port == 0, or -3 if binding fails on an explicit port.
  """
  tb_app = self.create_app()
  logging.info('Starting TensorBoard in directory %s', os.getcwd())

  debug = FLAGS.insecure_debug_mode
  if debug:
    logging.set_verbosity(logging.DEBUG)
    logging.warning('TensorBoard is in debug mode. This is NOT SECURE.')

  print('Starting TensorBoard %s on port %d' % (self.get_tag(), FLAGS.port))

  if FLAGS.host != '0.0.0.0':
    print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
  else:
    # Bound to all interfaces: try to suggest a concrete address instead.
    try:
      host = socket.gethostbyname(socket.gethostname())
      print('(You can navigate to http://%s:%d)' % (host, FLAGS.port))
    except socket.gaierror:
      pass

  try:
    serving.run_simple(FLAGS.host,
                       FLAGS.port,
                       tb_app,
                       threaded=True,
                       use_reloader=debug,
                       use_evalex=debug,
                       use_debugger=debug)
  except socket.error:
    if FLAGS.port == 0:
      msg, exit_code = 'Unable to find any open ports.', -2
    else:
      msg = 'Tried to connect to port %d, but address is in use.' % FLAGS.port
      exit_code = -3
    logging.error(msg)
    print(msg)
    return exit_code
def saver(self):
  """Returns the saver used for on-demand checkpointing, or None.

  An explicitly set `self._saver` wins. Otherwise the SAVERS graph collection
  is consulted. An empty collection yields None; a single entry is returned.
  More than one entry logs an error and yields None.
  """
  explicit_saver = self._saver
  if explicit_saver:
    return explicit_saver

  candidates = ops.get_collection(ops.GraphKeys.SAVERS)
  if not candidates:
    return None
  if not isinstance(candidates, list):
    return candidates
  if len(candidates) == 1:
    return candidates[0]

  logging.error(
      'Multiple savers in the SAVERS collection. On-demand checkpointing '
      'will be disabled. Pass an explicit `saver` to the constructor to '
      'override this behavior.')
  return None
def _check_health(self):
  """Periodically probes every cluster task and aborts collectives on failure.

  Loops until `self._check_health_thread_should_stop` is set. Each task in
  `self._cluster_spec` is probed:
  - Unavailable/FailedPrecondition errors are retried until
    `self._check_health_retry_limit` attempts have been made. Collective ops
    are then aborted with UNAVAILABLE.
  - Any other exception aborts them with INTERNAL.
  Either abort ends the loop. Between full sweeps the thread sleeps for
  `self._check_health_interval` seconds.
  """
  while True:
    if self._check_health_thread_should_stop.is_set():
      return
    for job in self._cluster_spec.jobs:
      for task_id in range(self._cluster_spec.num_tasks(job)):
        peer = "/job:{}/replica:0/task:{}".format(job, task_id)
        attempts = 0
        while True:
          attempts += 1
          try:
            context.context().check_collective_ops_peer_health(peer)
            # If check_collective_ops_peer_health doesn't raise an Exception,
            # the peer is healthy.
            break
          except (errors.UnavailableError, errors.FailedPreconditionError) as e:
            # TODO(b/151232436): Always raise UnavailableError when a peer
            # fails. Now there could be many kinds of errors:
            # - Unavailable: when the peer is not reachable, e.g. it's down.
            # - FailedPrecondition: when the peer has restarted.
            if attempts < self._check_health_retry_limit:
              logging.warning("%s seems down, retrying %d/%d", peer, attempts,
                              self._check_health_retry_limit)
              continue
            logging.error(
                "Cluster check alive failed, %s is down, "
                "aborting collectives: %s", peer, e)
            context.context().abort_collective_ops(
                errors.UNAVAILABLE,
                "cluster check alive failed, {} is down".format(peer))
            return
          except Exception as e:  # pylint: disable=broad-except
            logging.error("Unexpected exception in check alive: %s", e)
            # Fixed typo ("unexecpted") in the abort message.
            context.context().abort_collective_ops(
                errors.INTERNAL,
                "unexpected exception in check alive: %s" % e)
            return
    time.sleep(self._check_health_interval)
def _convert_default_signature_to_signature_def(signatures):
  """Convert default signature to object of type SignatureDef.

  Args:
    signatures: object of type manifest_pb2.Signatures()

  Returns:
    object of type SignatureDef which contains a converted version of default
    signature from input signatures object

    Returns None if signature is of generic type because it cannot be
    converted to SignatureDef.
  """
  default_signature = signatures.default_signature
  signature_def = meta_graph_pb2.SignatureDef()
  kind = default_signature.WhichOneof("type")

  if kind == "regression_signature":
    regression = default_signature.regression_signature
    signature_def.method_name = signature_constants.REGRESS_METHOD_NAME
    _add_input_to_signature_def(regression.input.tensor_name,
                                signature_constants.REGRESS_INPUTS,
                                signature_def)
    _add_output_to_signature_def(regression.output.tensor_name,
                                 signature_constants.REGRESS_OUTPUTS,
                                 signature_def)
  elif kind == "classification_signature":
    classification = default_signature.classification_signature
    signature_def.method_name = signature_constants.CLASSIFY_METHOD_NAME
    _add_input_to_signature_def(classification.input.tensor_name,
                                signature_constants.CLASSIFY_INPUTS,
                                signature_def)
    _add_output_to_signature_def(classification.classes.tensor_name,
                                 signature_constants.CLASSIFY_OUTPUT_CLASSES,
                                 signature_def)
    _add_output_to_signature_def(classification.scores.tensor_name,
                                 signature_constants.CLASSIFY_OUTPUT_SCORES,
                                 signature_def)
  else:
    # Any other signature type has no SignatureDef equivalent.
    logging.error("Only classification and regression default signatures "
                  "are supported for up-conversion. %s is not "
                  "supported" % kind)
    return None
  return signature_def
def _AddOpInternal(self, op):
  """Registers a newly created op with this TPU replication context.

  Logs an error for op types in `_BLACKLISTED_OPS` and records op types in
  `_NOT_IMPLEMENTED_OPS` in `self._unsupported_ops`. Rejects ops that take
  reference-typed inputs and ops that are already replicated. Tags the op with
  the replicate attribute and blocks feeding and fetching it.

  Args:
    op: The `Operation` being added.

  Raises:
    NotImplementedError: if any input of `op` has a reference dtype.
    ValueError: if `op` already carries the replicate attribute (nested TPU
      computation).
  """
  # pylint: disable=protected-access
  op_type = op.type
  if op_type in _BLACKLISTED_OPS:
    logging.error("Operation of type %s (%s) is not supported on the TPU. "
                  "Execution will fail if this op is used in the graph. " %
                  (op_type, op.name))

  if op_type in _NOT_IMPLEMENTED_OPS:
    self._unsupported_ops.append(op)

  if any(inp.dtype._is_ref_dtype for inp in op.inputs):
    raise NotImplementedError(
        "Non-resource Variables are not supported inside TPU computations "
        "(operator name: %s)" % op.name)
  # pylint: enable=protected-access

  if _TPU_REPLICATE_ATTR in op.node_def.attr:
    raise ValueError("TPU computations cannot be nested")
  # NOTE(review): if `op.node_def` returns a fresh copy (as with C-API-backed
  # ops), this assignment does not reach the op itself; newer code uses
  # `op._set_attr`. Confirm for the TF version in use.
  op.node_def.attr[_TPU_REPLICATE_ATTR].s = compat.as_bytes(self._name)

  op.graph.prevent_feeding(op)
  op.graph.prevent_fetching(op)
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
  """Convert an existing calibration graph to inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration data
    is_dynamic_op: whether to create dynamic static engines from calibration

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes,
    or None (after logging an error) if no TRTEngineOp node has empty
    `calibration_data`.

  Raises:
    RuntimeError: if the returned status message is malformed.
    Other errors built from a non-OK status via `_impl`.
  """
  # A calibration graph has at least one TRTEngineOp whose calibration data is
  # still empty.
  is_calib_graph = any(
      node.op == "TRTEngineOp" and not node.attr["calibration_data"].s
      for node in calibration_graph_def.node)
  if not is_calib_graph:
    tf_logging.error(
        "Not a calib graph. Doesn't seem to contain any calibration nodes.")
    return None

  graph_str = calibration_graph_def.SerializeToString()
  converted = calib_convert(graph_str, is_dynamic_op)
  status = _to_string(converted[0])
  output_graph_def_string = converted[1]
  del graph_str  # Save some memory

  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    # Non-OK statuses are encoded as "<code>;<message>".
    fields = status.split(";")
    if len(fields) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(fields[1:]),
                                         int(fields[0]))
    # pylint: enable=protected-access

  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
def wait_on_failure(self,
                    on_failure_fn=None,
                    on_recovery_fn=None,
                    worker_device_name="(unknown)"):
  """Catches worker preemption errors and waits until failed workers are back.

  Generator body meant to be used with `with` (the recovery path below uses
  `with self.wait_on_failure(...)`). The decorator that turns it into a
  context manager is not visible here; presumably `contextlib.contextmanager`.
  TODO confirm.

  An `errors.OpError` raised inside the `with` body is handled in one of two
  ways:
  - If the cluster classifies it as a transient PS failure, `on_failure_fn`
    runs and the error is swallowed.
  - Otherwise it goes through `self._validate_preemption_failure` and is
    logged, and `on_failure_fn` runs. A cluster update is then requested and
    this call waits (bounded by `_WORKER_MAXIMUM_RECOVERY_SEC`) before running
    `on_recovery_fn` under the same protection. The original error is not
    re-raised on this path.

  Args:
    on_failure_fn: an optional function to run if preemption happens.
    on_recovery_fn: an optional function to run when a worker is recovered
      from preemption.
    worker_device_name: the device name of the worker instance that is passing
      through the failure.

  Yields:
    None.
  """
  try:
    yield
  except errors.OpError as e:
    # If the error is due to temporary connectivity issues between worker and
    # ps, put back closure, ignore error and do not mark worker as failure.
    if self._cluster._record_and_ignore_transient_ps_failure(e):  # pylint: disable=protected-access
      if on_failure_fn:
        on_failure_fn()
      return

    # NOTE(review): presumably re-raises errors that are not preemptions;
    # confirm against `_validate_preemption_failure`.
    self._validate_preemption_failure(e)
    logging.error("Worker %s failed with error: %s", worker_device_name, e)
    if on_failure_fn:
      on_failure_fn()

    with self._cluster_update_lock:
      # Flag that the cluster needs updating, then wait (with a timeout) on
      # the worker-up condition. The wait's return value is not checked, so
      # the "recovered" log below is emitted even if the wait timed out.
      self._cluster_due_for_update.set()
      self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
    logging.info("Worker %s has been recovered.", worker_device_name)

    if on_recovery_fn:
      # The recovery function can itself hit a failure, so guard it the same
      # way.
      with self.wait_on_failure(
          on_recovery_fn=on_recovery_fn,
          worker_device_name=worker_device_name):
        on_recovery_fn()
def _run(self, sess, enqueue_op, coord=None):
  """Calls `self.func` in a loop to feed a queue until stopped or closed.

  Each iteration calls `self.func(sess, enqueue_op, self.data_producer)`.
  The loop ends in one of three ways:
  - `coord` requests a stop;
  - a queue-closed exception is raised;
  - any other exception escapes.
  On every exit path `self._runs_per_session[sess]` is decremented exactly
  once. The runner that brings it to 0 also runs `self._close_op`.

  Args:
    sess: session passed to `self.func` and used to run `self._close_op`.
    enqueue_op: op passed through to `self.func`.
    coord: optional coordinator. When given, this thread registers with it,
      checks `coord.should_stop()` each iteration, and reports unexpected
      exceptions through `coord.request_stop` instead of raising them.

  Raises:
    Exception: without `coord`, any unexpected exception is logged, appended
      to `self._exceptions_raised` and re-raised.
  """
  if coord:
    coord.register_thread(threading.current_thread())
  decremented = False
  try:
    while True:
      if coord and coord.should_stop():
        break
      try:
        #CUSTOM FUNCTION CALL
        self.func(sess, enqueue_op, self.data_producer) # call enqueue function
        #CUSTOM FUNCTION CALL
      except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
        # This exception indicates that a queue was closed.
        with self._lock:
          self._runs_per_session[sess] -= 1
          # Tell the `finally` below the count was already adjusted.
          decremented = True
          if self._runs_per_session[sess] == 0:
            try:
              sess.run(self._close_op)
            except Exception as e:
              # Intentionally ignore errors from close_op.
              logging.vlog(1, "Ignored exception: %s", str(e))
          return
  except Exception as e:
    # This catches all other exceptions.
    if coord:
      coord.request_stop(e)
    else:
      logging.error("Exception in QueueRunner: %s", str(e))
      with self._lock:
        self._exceptions_raised.append(e)
      raise
  finally:
    # Make sure we account for all terminations: normal or errors.
    if not decremented:
      with self._lock:
        self._runs_per_session[sess] -= 1
def check_column(df: DataFrame, name: str, fn: Callable[[float], bool]) -> bool:
  """Validates one column of `df` with `fn`, looking only at TensorRT rows.

  Rows whose "trt_model" entry is falsy (native CPU/GPU models) are skipped.
  Every failing row is logged; the scan does not stop at the first failure.

  Args:
    df: The DataFrame to check, read through `df.n_rows`, `df(row, column)`
      and `df(row)`.
    name: The name of the column to check.
    fn: Predicate that takes a value of the column and returns whether the
      value satisfies the check.

  Returns:
    Whether all checked values of the column satisfy `fn`.
  """
  all_passed = True
  for row in range(df.n_rows):
    if not df(row, "trt_model"):
      continue
    if fn(df(row, name)):
      continue
    logging.error("Unsatisfied %s found at: %s", name, df(row))
    all_passed = False
  return all_passed
def _SetPath(self, path):
  """Switches the watched event file to `path`.

  Before switching, the size of the previously watched file is stored in
  `self._finalized_sizes`, unless there was no previous file or it is a GCS
  path. A failure to stat it is logged and otherwise ignored.

  Args:
    path: The full path of the file to watch.
  """
  previous_path = self._path
  record_previous_size = (previous_path and
                          not io_wrapper.IsGCSPath(previous_path))
  if record_previous_size:
    try:
      # The old file is done, so remember its final size.
      final_size = gfile.Stat(previous_path).length
      logging.debug('Setting latest size of %s to %d', previous_path,
                    final_size)
      self._finalized_sizes[previous_path] = final_size
    except errors.OpError as stat_error:
      logging.error('Unable to get size of %s: %s', previous_path, stat_error)

  self._path = path
  self._loader = self._loader_factory(path)
def _AddOpInternal(self, op):
  """Checks `op` for TPU compatibility and marks it for TPU replication.

  Ops whose type is in `_BLACKLISTED_OPS` are only logged as errors (they are
  still added); ops whose type is in `_NOT_IMPLEMENTED_OPS` are recorded in
  `self._unsupported_ops`. The op is then tagged with this context's name
  through the `_TPU_REPLICATE_ATTR` attribute and excluded from feeding and
  fetching.

  Args:
    op: the operation being added inside this context.

  Raises:
    NotImplementedError: if any input of `op` has a reference dtype
      (non-resource variable).
    ValueError: if `op` already carries `_TPU_REPLICATE_ATTR`, i.e. TPU
      computations are being nested.
  """
  # pylint: disable=protected-access
  if op.type in _BLACKLISTED_OPS:
    # Pass the arguments to the logger instead of formatting eagerly with `%`.
    logging.error("Operation of type %s (%s) is not supported on the TPU. "
                  "Execution will fail if this op is used in the graph. ",
                  op.type, op.name)

  if op.type in _NOT_IMPLEMENTED_OPS:
    self._unsupported_ops.append(op)

  if any(x.dtype._is_ref_dtype for x in op.inputs):
    raise NotImplementedError(
        "Non-resource Variables are not supported inside TPU computations "
        "(operator name: %s)" % op.name)
  # pylint: enable=protected-access

  if _TPU_REPLICATE_ATTR in op.node_def.attr:
    raise ValueError("TPU computations cannot be nested")
  op.node_def.attr[_TPU_REPLICATE_ATTR].s = compat.as_bytes(self._name)
  op.graph.prevent_feeding(op)
  op.graph.prevent_fetching(op)
def _SetPath(self, path):
  """Switches the watched event file to `path`.

  Before switching, the size of the previously watched file is stored in
  `self._finalized_sizes`, unless there was no previous file or it is a GCS
  path. An IOError/OSError while reading the size is logged and otherwise
  ignored.

  Args:
    path: The full path of the file to watch.
  """
  previous_path = self._path
  record_previous_size = previous_path and not gcs.IsGCSPath(previous_path)
  if record_previous_size:
    try:
      # The old file is done, so remember its final size.
      final_size = io_wrapper.Size(previous_path)
      logging.debug('Setting latest size of %s to %d', previous_path,
                    final_size)
      self._finalized_sizes[previous_path] = final_size
    except (IOError, OSError) as size_error:
      logging.error('Unable to get size of %s: %s', previous_path, size_error)

  self._path = path
  self._loader = self._loader_factory(path)
def from_devices(session, devices):
  """Builds a `WorkerHeartbeatManager` for `devices`.

  A heartbeat op is built for each device with `_make_heartbeat_op`. Devices
  for which it returns None are logged as unsupported and left out.

  Args:
    session: session used to build the heartbeat ops; also handed to the
      manager.
    devices: devices to monitor. An empty value is logged as an error but
      does not stop construction.

  Returns:
    A `WorkerHeartbeatManager` over the devices that support heartbeats.
  """
  if not devices:
    logging.error('Trying to create heartbeat manager with no devices?')

  logging.info('Creating heartbeat manager for %s', devices)
  request_placeholder = array_ops.placeholder(
      name='worker_heartbeat_request', dtype=dtypes.string)

  supported = []
  for device in devices:
    op = _make_heartbeat_op(session, device, request_placeholder)
    if op is None:
      logging.warning('Heartbeat support not available for %s', device)
      continue
    supported.append((device, op))

  kept_devices = [device for device, _ in supported]
  heartbeat_ops = [op for _, op in supported]
  return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops,
                                request_placeholder)
def _serve_health_pills_helper(self, request):
  """Handles a POST request asking for the health pills of some nodes.

  The node names arrive JSON-encoded under the `_NODE_NAMES_POST_KEY` form
  key. POST is used instead of GET because URL length is limited.

  Args:
    request: The request issued by the client for health pills.

  Returns:
    A werkzeug response:
    - 405 for non-POST requests;
    - 400 if the key is missing, its value is not valid JSON, or the value is
      not a JSON list;
    - otherwise a JSON response. Per the TODO below, it currently echoes the
      decoded list of node names rather than a mapping from node names to
      HealthPillEvents.
  """
  if request.method != 'POST':
    logging.error('%s requests are forbidden by the debugger plugin.',
                  request.method)
    return wrappers.Response(status=405)

  if _NODE_NAMES_POST_KEY not in request.form:
    logging.error(
        'The %s POST key was not found in the request for health pills.',
        _NODE_NAMES_POST_KEY)
    return wrappers.Response(status=400)

  raw_node_names = request.form[_NODE_NAMES_POST_KEY]
  try:
    node_names = json.loads(raw_node_names)
  except Exception as decode_error:  # pylint: disable=broad-except
    # JSON libraries disagree on which exception they raise, and TensorBoard
    # may run in many environments, so catch everything here.
    logging.error('Could not decode node name JSON string %s: %s',
                  raw_node_names, decode_error)
    return wrappers.Response(status=400)

  if isinstance(node_names, list):
    # TODO(chizeng): Actually respond with the health pills per node name.
    return http_util.Respond(request, node_names, mimetype='application/json')

  logging.error('%s is not a JSON list of node names:', raw_node_names)
  return wrappers.Response(status=400)
def TestRandomGraph(self, sess, op_placement=None, random_seed=None):
  """Checks that a random graph gives identical results on vGPUs and CPU.

  Debug mode is enabled whenever the caller supplies `op_placement` or
  `random_seed`.

  Args:
    sess: session passed through to `_TestRandomGraphWithDevices`.
    op_placement: operation placement; generated with
      `_GenerateOperationPlacement()` when None.
    random_seed: graph seed; drawn with `random.randint(0, 1 << 31)` when
      None.

  Returns:
    True if every element of the two results is equal. Otherwise False, after
    logging the first mismatch (row-major order) and the device setup.
  """
  debug_mode = op_placement is not None or random_seed is not None
  if op_placement is None:
    op_placement = self._GenerateOperationPlacement()
  if random_seed is None:
    random_seed = random.randint(0, 1 << 31)

  logging.info('Virtual gpu functional test for random graph...')
  logging.info('operation placement: %s', str(op_placement))
  logging.info('random seed: %d', random_seed)

  def run_on(devices):
    return self._TestRandomGraphWithDevices(
        sess, random_seed, op_placement, devices, debug_mode=debug_mode)

  # Run with multiple virtual gpus, then with a single cpu.
  result_vgd = run_on(self.devices)
  result_cpu = run_on(['/cpu:0'] * self._num_devices)

  first_mismatch = next(
      ((row, col)
       for row in range(self._dim)
       for col in range(self._dim)
       if result_vgd[row][col] != result_cpu[row][col]),
      None)
  if first_mismatch is None:
    return True

  row, col = first_mismatch
  logging.error(
      'Result mismatch at row %d column %d: expected %f, actual %f', row, col,
      result_cpu[row][col], result_vgd[row][col])
  logging.error('Devices: %s', self.devices)
  logging.error('Memory limits (in MB): %s', self._mem_limits_mb)
  return False
def _process_closure(self, closure):
  """Runs a closure with preemption handling.

  The closure runs inside `self._cluster.failure_handler.wait_on_failure`:
  - On failure it is put back on the cluster's closure queue.
  - `self._set_resources_aborted` is the recovery callback.
  - On success its output remote values are fetched and the closure is
    marked finished on the queue.

  Any exception that escapes is stored on every output remote value of the
  closure, and the closure is marked failed on the queue. The exception is not
  re-raised. CancelledError is not logged.

  Args:
    closure: the closure to execute on this worker.
  """
  try:
    with self._cluster.failure_handler.wait_on_failure(
        on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),  # pylint: disable=protected-access
        on_recovery_fn=self._set_resources_aborted,
        worker_device_name=self.device_name):
      closure.execute_on(self)
      # TODO(yuefengz): we don't have to materialize results every step.
      with metric_utils.monitored_timer("remote_value_fetch"):
        closure._fetch_output_remote_values()  # pylint: disable=protected-access
      self._cluster._closure_queue.mark_finished()  # pylint: disable=protected-access
  except Exception as e:  # pylint: disable=broad-except
    # Avoid logging the derived cancellation error
    if not isinstance(e, errors.CancelledError):
      logging.error(
          "/job:worker/task:%d encountered the following error when "
          "processing closure: %r:%s", self.worker_index, e, e)
    # Propagate the error to every output so consumers observe it.
    nest.map_structure(
        lambda x: x._set_error(e),  # pylint: disable=protected-access
        closure._output_remote_values)  # pylint: disable=protected-access
    self._cluster._closure_queue.mark_failed(e)  # pylint: disable=protected-access
def Reload(self):
  """Calls `Reload` on every `EventAccumulator`.

  I/O errors are logged and the accumulator is kept. Accumulators whose
  directory was deleted are removed from `self._accumulators`.

  Returns:
    `self`.
  """
  self._reload_called = True
  # Snapshot under the mutex so reloading stays safe even if the set of
  # accumulators is modified while we reload.
  with self._accumulators_mutex:
    snapshot = list(self._accumulators.items())

  deleted_runs = set()
  for run_name, acc in snapshot:
    try:
      acc.Reload()
    except (OSError, IOError) as reload_error:
      logging.error("Unable to reload accumulator '%s': %s", run_name,
                    reload_error)
    except directory_watcher.DirectoryDeletedError:
      deleted_runs.add(run_name)

  with self._accumulators_mutex:
    for run_name in deleted_runs:
      logging.warning("Deleting accumulator '%s'", run_name)
      del self._accumulators[run_name]
  return self
def _run(self, sess, enqueue_op, coord=None):
  """Calls `self.func(sess, enqueue_op)` in a loop to feed a queue.

  The loop ends in one of three ways:
  - `coord` requests a stop;
  - a queue-closed exception is raised;
  - any other unexpected exception escapes.
  ValueErrors from `self.func` are tolerated: they are logged at vlog level 1
  and the loop continues. On every exit path `self._runs_per_session[sess]` is
  decremented exactly once, and the runner that brings it to 0 also runs
  `self._close_op`.

  Args:
    sess: session passed to `self.func` and used to run `self._close_op`.
    enqueue_op: op passed through to `self.func`.
    coord: optional coordinator. When given, this thread registers with it,
      checks `coord.should_stop()` each iteration, and reports unexpected
      exceptions through `coord.request_stop` instead of raising them.

  Raises:
    Exception: without `coord`, any unexpected exception is logged, appended
      to `self._exceptions_raised` and re-raised.
  """
  if coord:
    coord.register_thread(threading.current_thread())
  decremented = False
  try:
    while True:
      if coord and coord.should_stop():
        break
      try:
        # Call enqueue_op.
        self.func(sess, enqueue_op)
      except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
        # This exception indicates that a queue was closed.
        with self._lock:
          self._runs_per_session[sess] -= 1
          # Tell the `finally` below the count was already adjusted.
          decremented = True
          if self._runs_per_session[sess] == 0:
            try:
              sess.run(self._close_op)
            except Exception as e:  # pylint: disable=broad-except
              # Intentionally ignore errors from close_op.
              logging.vlog(1, "Ignored exception: %s", str(e))
          return
      except ValueError as e:
        # ValueErrors are deliberately tolerated and the loop keeps going, but
        # record them rather than dropping them silently.
        logging.vlog(1, "Ignored ValueError in QueueRunner: %s", str(e))
  except Exception as e:
    # This catches all other exceptions.
    if coord:
      coord.request_stop(e)
    else:
      logging.error("Exception in QueueRunner: %s", str(e))
      with self._lock:
        self._exceptions_raised.append(e)
      raise
  finally:
    # Make sure we account for all terminations: normal or errors.
    if not decremented:
      with self._lock:
        self._runs_per_session[sess] -= 1
def _check_trt_version_compatibility():
  """Check compatibility of TensorRT version.

  The loaded TensorRT library must not be older than the version TensorFlow
  was linked against, and must have the same major version. A newer
  minor/patch version is accepted and only logged.

  Raises:
    RuntimeError: if the TensorRT library version is incompatible.
  """
  linked_version = wrap_py_utils.get_linked_tensorrt_version()
  loaded_version = wrap_py_utils.get_loaded_tensorrt_version()
  # The tuple comparisons below rely on (major, minor, patch) tuples.
  assert isinstance(linked_version, tuple)
  assert isinstance(loaded_version, tuple)
  assert len(linked_version) == 3
  assert len(loaded_version) == 3
  tf_logging.info("Linked TensorRT version: %s", str(linked_version))
  tf_logging.info("Loaded TensorRT version: %s", str(loaded_version))

  loaded_str = ".".join(str(x) for x in loaded_version)
  linked_str = ".".join(str(x) for x in linked_version)

  # Lexicographic tuple comparison: any older loaded version is rejected.
  if loaded_version < linked_version:
    tf_logging.error(
        "Loaded TensorRT %s but linked TensorFlow against TensorRT %s. "
        "TensorRT does not support forward compatibility. "
        "It is also required to use the same major version of TensorRT "
        "during compilation and runtime.", loaded_str, linked_str)
    raise RuntimeError("Incompatible TensorRT versions")

  if loaded_version[0] > linked_version[0]:
    tf_logging.error(
        "Loaded TensorRT %s but linked TensorFlow against TensorRT %s. "
        "It is required to use the same major version "
        "of TensorRT during compilation and runtime.", loaded_str, linked_str)
    raise RuntimeError("Incompatible TensorRT major version")

  if loaded_version != linked_version:
    # Same major version, newer minor/patch: supported, just report it.
    tf_logging.info(
        "Loaded TensorRT %s and linked TensorFlow against TensorRT %s. "
        "This is supported because TensorRT "
        "minor/patch upgrades are backward compatible", loaded_str, linked_str)
def _check_health(self):
  """Periodically checks that every task in the cluster is reachable.

  Loops until `self._check_health_thread_should_stop` is set, sleeping
  `self._check_health_interval` seconds between rounds. Each round checks
  every task of every job in `self._cluster_spec` with
  `check_collective_ops_peer_health`:
  - On UnavailableError or FailedPreconditionError, pending collective ops are
    aborted with UNAVAILABLE and the loop continues.
  - On any other exception, they are aborted with INTERNAL and the loop
    exits.
  """
  while True:
    if self._check_health_thread_should_stop.is_set():
      return
    try:
      for job in self._cluster_spec.jobs:
        for task_id in range(self._cluster_spec.num_tasks(job)):
          context.context().check_collective_ops_peer_health(
              "/job:{}/replica:0/task:{}".format(job, task_id))
    except (errors.UnavailableError, errors.FailedPreconditionError) as e:
      # TODO(b/151232436): Always raise UnavailableError when a peer fails.
      # Now there could be many kinds of errors:
      # - Unavailable: when the peer is not reachable, e.g. it's down.
      # - FailedPrecondition: when the peer has restarted.
      logging.error("Cluster check alive failed, aborting collectives")
      context.context().abort_collective_ops(
          errors.UNAVAILABLE, "cluster check alive failed: %s" % e)
    except Exception as e:  # pylint: disable=broad-except
      logging.exception("Unexpected exception in check alive.")
      context.context().abort_collective_ops(
          errors.INTERNAL, "unexpected exception in check alive: %s" % e)
      return
    time.sleep(self._check_health_interval)
def _check_sated(self, raise_error):
  """Reports an object that was never marked as used.

  Does nothing if the object is already sated. Otherwise it builds a message
  containing up to 5 stack entries starting at `self._stack_frame`, then
  handles it in one of two ways:
  - With `raise_error`, the message is raised as a RuntimeError and the object
    is sated while the error propagates.
  - Without it, the message is logged as an error.

  Args:
    raise_error: whether to raise instead of logging.

  Raises:
    RuntimeError: if `raise_error` is true and the object was never used.
  """
  if self._sated:
    return
  # Each formatted stack entry ends with a newline; after stripping it, join
  # with '\n' so consecutive entries stay on separate lines instead of being
  # glued together.
  creation_stack = '\n'.join(
      line.rstrip()
      for line in traceback.format_stack(self._stack_frame, limit=5))
  if raise_error:
    try:
      raise RuntimeError(
          'Object was never used (type {}): {}. If you want to mark it as '
          'used call its "mark_used()" method. It was originally created '
          'here:\n{}'.format(self._type, self._repr, creation_stack))
    finally:
      self.sate()
  else:
    tf_logging.error(
        '==================================\n'
        'Object was never used (type {}):\n{}\nIf you want to mark it as '
        'used call its "mark_used()" method.\nIt was originally created '
        'here:\n{}\n'
        '=================================='
        .format(self._type, self._repr, creation_stack))
def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of the latest saved checkpoint file.

  Gets the checkpoint state for `checkpoint_dir` and looks for a matching
  TensorFlow 2 (preferred) or TensorFlow 1.x checkpoint path. The
  `latest_filename` argument only applies when saving checkpoints with
  `v1.train.Saver.save`.

  See the [Training Checkpoints
  Guide](https://www.tensorflow.org/guide/checkpoint) for more details and
  examples.

  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that contains
      the list of most recent checkpoint filenames. See the corresponding
      argument to `v1.train.Saver.save`.

  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was
    found.
  """
  # Pick the latest checkpoint based on checkpoint state.
  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
  if not (ckpt and ckpt.model_checkpoint_path):
    return None

  prefix = ckpt.model_checkpoint_path
  # Accept either a V2 or a V1 path; V2 is checked first.
  candidate_paths = [
      _prefix_to_checkpoint_path(prefix, saver_pb2.SaverDef.V2),
      _prefix_to_checkpoint_path(prefix, saver_pb2.SaverDef.V1),
  ]
  if any(file_io.get_matching_files(p) for p in candidate_paths):
    return prefix

  logging.error("Couldn't match files for checkpoint %s", prefix)
  return None
def record_error(self, source, exc_info, session=None):
  """Report an exception from the given source.

  If a session is passed, a daemon thread is started (at most once) that
  closes it after a few seconds. This keeps the main training loop from
  hanging when an infeed/outfeed error occurs. The delay lets a more
  informative error from another thread propagate first.

  Args:
    source: string, source of the error
    exc_info: Output from `sys.exc_info` (type, value, traceback)
    session: Session to close after delay.
  """
  _, value, _ = exc_info
  # Ignore errors already handled by MonitoredSession
  if isinstance(value, _IGNORED_ERRORS):
    return

  self._errors[source] = exc_info
  logging.error('Error recorded from %s: %s', source, value)

  if session is not None and self._session_cancel_timer is None:

    def _cancel_session():
      time.sleep(5)
      logging.error('Closing session due to error %s', value)
      try:
        session.close()
      except Exception:  # pylint: disable=broad-except
        # Best effort: closing may fail; report it but keep this thread quiet.
        logging.error('\n\n\nFailed to close session after error. '
                      'Other threads may hang.\n\n\n')

    self._session_cancel_timer = threading.Thread(
        target=_cancel_session,
    )
    self._session_cancel_timer.daemon = True
    self._session_cancel_timer.start()
def __init__(self, test_generator=None, out_file=None, precision=4, verbose=0):
  """Creates the accuracy monitor.

  Args:
    test_generator: data source for the monitor. If None, an error is logged,
      `self.test_generator` is set to None and initialization stops early.
    out_file: optional output path. A warning is logged if it already exists,
      since it will be overwritten.
    precision: stored as `self.precision`.
    verbose: stored as `self.verbose`.
  """
  super(STAccuracyMonitor, self).__init__()
  self.out_file = out_file
  self.precision = precision
  self.verbose = verbose
  self.acc_df = None
  # Always define the attribute so later code does not hit an AttributeError.
  self.test_generator = test_generator
  if test_generator is None:
    # No stray positional args: logging would try `msg % args` and fail.
    logging.error("\nSTAccuracyMonitor: Test generator is empty. Skipping...")
    return
  # out_file defaults to None, and os.path.exists(None) raises TypeError.
  if self.out_file is not None and os.path.exists(self.out_file):
    logging.warning(
        "\nSTAccuracyMonitor: Output file already exists. "
        "File will be overwritten.")
def main(unused_argv=None):
  """Starts TensorBoard, or only inspects event files when --inspect is set.

  Returns:
    0 after inspecting event files, -1 if no --logdir was given, -2 if no
    open port could be found (--port=0), -3 if the requested port is in use.
    Otherwise it serves until the server stops and returns None.
  """
  if FLAGS.debug:
    logging.set_verbosity(logging.DEBUG)
    logging.info('TensorBoard is in debug mode.')

  if FLAGS.inspect:
    logging.info('Not bringing up TensorBoard, but inspecting event files.')
    efi.inspect(logdir=FLAGS.logdir, event_file=FLAGS.event_file,
                tag=FLAGS.tag)
    return 0

  def fail(message, exit_code):
    # Fatal startup problems go to both the log and stdout.
    logging.error(message)
    print(message)
    return exit_code

  if not FLAGS.logdir:
    return fail('A logdir must be specified. Run `tensorboard --help` for '
                'details and examples.', -1)

  logging.info('Starting TensorBoard in directory %s', os.getcwd())
  path_to_run = server.ParseEventFilesSpec(FLAGS.logdir)
  logging.info('TensorBoard path_to_run is: %s', path_to_run)

  multiplexer = event_multiplexer.EventMultiplexer(
      size_guidance=server.TENSORBOARD_SIZE_GUIDANCE,
      purge_orphaned_data=FLAGS.purge_orphaned_data)
  server.StartMultiplexerReloadingThread(multiplexer, path_to_run,
                                         FLAGS.reload_interval)
  try:
    tb_server = server.BuildServer(multiplexer, FLAGS.host, FLAGS.port)
  except socket.error:
    if FLAGS.port == 0:
      return fail('Unable to find any open ports.', -2)
    return fail(
        'Tried to connect to port %d, but address is in use.' % FLAGS.port,
        -3)

  try:
    tag = resource_loader.load_resource('tensorboard/TAG').strip()
    logging.info('TensorBoard is tag: %s', tag)
  except IOError:
    logging.warning('Unable to read TensorBoard tag')
    tag = ''

  status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
  print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
  print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
  tb_server.serve_forever()
def _execute_schedule(experiment, schedule):
  """Execute the method named `schedule` of `experiment`.

  Args:
    experiment: object whose method named `schedule` is called.
    schedule: name of the method to call.

  Returns:
    Whatever the called method returns.

  Raises:
    ValueError: if `experiment` has no attribute named `schedule`.
    TypeError: if the attribute named `schedule` is not callable.
  """

  def _log_valid_tasks():
    # List the public callables so the user can pick a valid schedule.
    valid_tasks = [
        x for x in dir(experiment)
        if not x.startswith('_') and callable(getattr(experiment, x))
    ]
    logging.error('Allowed values for this experiment are: %s', valid_tasks)

  if not hasattr(experiment, schedule):
    logging.error('Schedule references non-existent task %s', schedule)
    _log_valid_tasks()
    raise ValueError(
        'Schedule references non-existent task {}'.format(schedule))

  task = getattr(experiment, schedule)
  if not callable(task):
    logging.error('Schedule references non-callable member %s', schedule)
    _log_valid_tasks()
    raise TypeError(
        'Schedule references non-callable member {}'.format(schedule))
  return task()
def run(experiment_fn):
  """Make and run an experiment.

  Builds the experiment with `experiment_fn(output_dir=FLAGS.output_dir)`
  and calls its method named by `FLAGS.schedule`.

  Args:
    experiment_fn: callable that takes `output_dir` and returns an
      `Experiment`.

  Returns:
    The return value of the scheduled method.

  Raises:
    RuntimeError: if --output_dir or --schedule is not set.
    TypeError: if `experiment_fn` is not callable, does not return an
      `Experiment`, or the schedule names a non-callable member.
    ValueError: if the schedule names a non-existent attribute.
  """
  if not FLAGS.output_dir:
    raise RuntimeError(
        'Must specify an output directory (use --output_dir).')
  if not FLAGS.schedule:
    raise RuntimeError('Must specify a schedule (use --schedule).')

  if not callable(experiment_fn):
    raise TypeError('Experiment builder "%s" is not callable.' %
                    experiment_fn)

  # Call the builder
  experiment = experiment_fn(output_dir=FLAGS.output_dir)
  if not isinstance(experiment, Experiment):
    raise TypeError('Experiment builder did not return an Experiment '
                    'instance, got %s instead.' % type(experiment))

  def _log_valid_tasks():
    # Methods live on the class, not in the instance __dict__, so use dir()
    # to list every public callable that could be scheduled.
    valid_tasks = [
        x for x in dir(experiment)
        if not x.startswith('_') and callable(getattr(experiment, x))
    ]
    logging.error('Allowed values for this experiment are: %s', valid_tasks)

  # Execute the schedule
  taskname = FLAGS.schedule
  if not hasattr(experiment, taskname):
    logging.error('Schedule references non-existent task %s', taskname)
    _log_valid_tasks()
    # Format the message here; exception constructors do not apply `%`.
    raise ValueError('Schedule references non-existent task %s' % taskname)

  task = getattr(experiment, taskname)
  if not callable(task):
    logging.error('Schedule references non-callable member %s', taskname)
    _log_valid_tasks()
    raise TypeError('Schedule references non-callable member %s' % taskname)

  return task()
def create_metric_ops(self, inputs, labels, predictions):
  """Connect our `metric_fn` to the specified members of the given dicts.

  This function will call the `metric_fn` given in our constructor as follows:

  ```
    metric_fn(labels=labels[self.label_key],
              predictions=predictions[self.prediction_key],
              weights=inputs[self.weight_key])
  ```

  and return the result. `weights` is always passed: it is
  `inputs[self.weight_key]` when `self.weight_key` is set (truthy) and None
  otherwise.

  `predictions` and `labels` may be single tensors as well as dicts:
  - If `predictions` is a single tensor, `self.prediction_key` must be unset.
    If it is a single-element dict, `self.prediction_key` may be unset.
  - The same rules apply to `labels` and `self.label_key`.

  Args:
    inputs: A dict of inputs produced by the `input_fn`.
    labels: A dict of labels or a single label tensor produced by the
      `input_fn`.
    predictions: A dict of predictions or a single tensor produced by the
      `model_fn`.

  Returns:
    The result of calling `metric_fn`.

  Raises:
    ValueError: raised in any of these cases:
      - `predictions` or `labels` is not a dict while `self.prediction_key`
        or `self.label_key` is set;
      - `self.label_key` is unset but `labels` is a dict with more than one
        element;
      - `self.prediction_key` is unset but `predictions` is a dict with more
        than one element.
    KeyError: if a set key is missing from the corresponding dict.
  """

  def _get_dict(name, dict_or_tensor, key):
    """Get a single tensor or an element of a dict or raise ValueError."""
    if key:
      if not isinstance(dict_or_tensor, dict):
        # Wrap in a 1-tuple so a tuple argument is printed rather than being
        # unpacked by `%` (which would raise TypeError instead).
        raise ValueError(
            'MetricSpec with ' + name + '_key specified'
            ' requires ' + name + 's dict, got %s.\n' % (dict_or_tensor,) +
            'You must not provide a %s_key if you ' % name +
            'only have a single Tensor as %ss.' % name)
      if key not in dict_or_tensor:
        raise KeyError('Key \'%s\' missing from %s.' %
                       (key, dict_or_tensor.keys()))
      return dict_or_tensor[key]
    if isinstance(dict_or_tensor, dict):
      if len(dict_or_tensor) != 1:
        raise ValueError('MetricSpec without specified ' + name + '_key'
                         ' requires ' + name + 's tensor or single element'
                         ' dict, got %s' % (dict_or_tensor,))
      # The only value of a single-element dict.
      return next(iter(dict_or_tensor.values()))
    return dict_or_tensor

  # Get the predictions.
  prediction = _get_dict('prediction', predictions, self.prediction_key)

  # Get the labels.
  label = _get_dict('label', labels, self.label_key)

  try:
    return self.metric_fn(
        labels=label,
        predictions=prediction,
        weights=inputs[self.weight_key] if self.weight_key else None)
  except Exception as ex:
    logging.error('Could not create metric ops for %s, %s.', self, ex)
    raise
def _AssertProtoDictEquals(self,
                           expected_dict,
                           actual_dict,
                           verbose=False,
                           update_goldens=False,
                           additional_missing_object_message='',
                           api_version=2):
  """Diff given dicts of protobufs and report differences in a readable way.

  When differences exist, the outcome depends on `update_goldens`:
  - If set, golden files of removed objects are deleted, and files for added
    or changed objects are rewritten from `actual_dict`.
  - Otherwise every difference is logged and the test fails.

  Args:
    expected_dict: a dict of TFAPIObject protos constructed from golden files.
    actual_dict: a dict of TFAPIObject protos constructed by reading from the
      TF package linked to the test.
    verbose: Whether to log the full diffs, or simply report which files were
      different. NOTE(review): this parameter is not read in the body below;
      both the short and the verbose diff are always logged when goldens are
      not updated.
    update_goldens: Whether to update goldens when there are diffs found.
    additional_missing_object_message: Message to print when a symbol is
      missing.
    api_version: TensorFlow API version to test.
  """
  diffs = []
  verbose_diffs = []

  expected_keys = set(expected_dict.keys())
  actual_keys = set(actual_dict.keys())
  only_in_expected = expected_keys - actual_keys
  only_in_actual = actual_keys - expected_keys
  all_keys = expected_keys | actual_keys

  # This will be populated below.
  updated_keys = []

  for key in all_keys:
    diff_message = ''
    verbose_diff_message = ''
    # First check if the key is not found in one or the other.
    if key in only_in_expected:
      diff_message = 'Object %s expected but not found (removed). %s' % (
          key, additional_missing_object_message)
      verbose_diff_message = diff_message
    elif key in only_in_actual:
      diff_message = 'New object %s found (added).' % key
      verbose_diff_message = diff_message
    else:
      # Do not truncate diff
      self.maxDiff = None  # pylint: disable=invalid-name
      # Now we can run an actual proto diff.
      try:
        self.assertProtoEquals(expected_dict[key], actual_dict[key])
      except AssertionError as e:
        updated_keys.append(key)
        diff_message = 'Change detected in python object: %s.' % key
        verbose_diff_message = str(e)

    # All difference cases covered above. If any difference found, add to the
    # list.
    if diff_message:
      diffs.append(diff_message)
      verbose_diffs.append(verbose_diff_message)

  # If diffs are found, handle them based on flags.
  if diffs:
    diff_count = len(diffs)
    logging.error(self._test_readme_message)
    logging.error('%d differences found between API and golden.', diff_count)

    if update_goldens:
      # Write files if requested.
      logging.warning(self._update_golden_warning)

      # If the keys are only in expected, some objects are deleted.
      # Remove files.
      for key in only_in_expected:
        filepath = _KeyToFilePath(key, api_version)
        tf.compat.v1.gfile.Remove(filepath)

      # If the files are only in actual (current library), these are new
      # modules. Write them to files. Also record all updates in files.
      for key in only_in_actual | set(updated_keys):
        filepath = _KeyToFilePath(key, api_version)
        file_io.write_string_to_file(
            filepath, text_format.MessageToString(actual_dict[key]))
    else:
      # Include the actual differences to help debugging.
      for d, verbose_d in zip(diffs, verbose_diffs):
        logging.error(' %s', d)
        logging.error(' %s', verbose_d)
      # Fail if we cannot fix the test by updating goldens.
      self.fail('%d differences found between API and golden.' % diff_count)

  else:
    logging.info('No differences found between API and golden.')
def main(unused_argv=None):
  """Starts TensorBoard, or only inspects event files when --inspect is set.

  Returns:
    0 after inspecting event files, -1 if no --logdir was given, -2 if no
    open port could be found (--port=0), -3 if the requested port is in use,
    and None once the server stops normally.
  """
  debug = FLAGS.insecure_debug_mode
  logdir = os.path.expanduser(FLAGS.logdir)

  if debug:
    logging.set_verbosity(logging.DEBUG)
    logging.warning('TensorBoard is in debug mode. This is NOT SECURE.')

  if FLAGS.inspect:
    logging.info('Not bringing up TensorBoard, but inspecting event files.')
    efi.inspect(logdir, os.path.expanduser(FLAGS.event_file), FLAGS.tag)
    return 0

  def report(message, exit_code):
    # Fatal startup problems go to both the log and stdout.
    logging.error(message)
    print(message)
    return exit_code

  if not logdir:
    return report('A logdir must be specified. Run `tensorboard --help` for '
                  'details and examples.', -1)

  logging.info('Starting TensorBoard in directory %s', os.getcwd())
  plugins = {'projector': projector_plugin.ProjectorPlugin()}
  tb_app = application.TensorBoardWSGIApp(
      logdir, plugins,
      purge_orphaned_data=FLAGS.purge_orphaned_data,
      reload_interval=FLAGS.reload_interval)

  try:
    tag = resource_loader.load_resource('tensorboard/TAG').strip()
    logging.info('TensorBoard is tag: %s', tag)
  except IOError:
    logging.info('Unable to read TensorBoard tag')
    tag = ''

  status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
  print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))

  display_host = FLAGS.host
  if display_host == "0.0.0.0":
    # Bound to all interfaces: show this machine's address instead, or skip
    # the hint entirely if it cannot be resolved.
    try:
      display_host = socket.gethostbyname(socket.gethostname())
    except socket.gaierror:
      display_host = None
  if display_host is not None:
    print('(You can navigate to http://%s:%d)' % (display_host, FLAGS.port))

  try:
    serving.run_simple(FLAGS.host, FLAGS.port, tb_app, threaded=True,
                       use_reloader=debug, use_evalex=debug,
                       use_debugger=debug)
  except socket.error:
    if FLAGS.port == 0:
      return report('Unable to find any open ports.', -2)
    return report(
        'Tried to connect to port %d, but address is in use.' % FLAGS.port,
        -3)
def fill(self):
  """Intelligently sets any non-specific parameters.

  Requires `num_classes` and `num_features` to be set (the `getattr` calls
  raise AttributeError otherwise). Derives the remaining tree parameters from
  them and from the `*_name` settings.

  Returns:
    `self`.
  """
  # Fail fast if num_classes or num_features isn't set.
  _ = getattr(self, 'num_classes')
  _ = getattr(self, 'num_features')

  self.bagged_num_features = int(self.feature_bagging_fraction *
                                 self.num_features)

  self.bagged_features = None
  if self.feature_bagging_fraction < 1.0:
    self.bagged_features = [
        random.sample(range(self.num_features), self.bagged_num_features)
        for _ in range(self.num_trees)
    ]

  self.regression = getattr(self, 'regression', False)

  # Num_outputs is the actual number of outputs (a single prediction for
  # classification, a N-dimensional point for regression).
  self.num_outputs = self.num_classes if self.regression else 1

  # Add an extra column to classes for storing counts, which is needed for
  # regression and avoids having to recompute sums for classification.
  self.num_output_columns = self.num_classes + 1

  # Unless set explicitly, consider floor(sqrt(num_features)) splits,
  # clamped to the range [10, 1000].
  self.num_splits_to_consider = self.num_splits_to_consider or min(
      max(10, math.floor(math.sqrt(self.num_features))), 1000)

  # If base_random_seed is 0, the current time will be used to seed the
  # random number generators for each tree. If non-zero, the i-th tree
  # will be seeded with base_random_seed + i.
  self.base_random_seed = getattr(self, 'base_random_seed', 0)

  # How to store leaf models.
  self.leaf_model_type = (
      REGRESSION_MODEL_TYPE[0] if self.regression else
      CLASSIFICATION_LEAF_MODEL_TYPES[self.model_name][0])

  # How to store stats objects.
  self.stats_model_type = (
      REGRESSION_MODEL_TYPE[1] if self.regression else
      CLASSIFICATION_LEAF_MODEL_TYPES[self.model_name][1])

  self.finish_type = (
      _params_proto.SPLIT_FINISH_BASIC if self.regression else
      FINISH_TYPES[self.split_finish_name])

  self.pruning_type = PRUNING_TYPES[self.split_pruning_name]

  def _split_after_is_constant():
    # A number (or all-digit string) can be used to derive defaults; anything
    # else is treated as depth-dependent. Previously this called
    # `isinstance(numbers.Number)` with a missing argument, which raised
    # TypeError.
    return (isinstance(self.split_after_samples, numbers.Number) or
            self.split_after_samples.isdigit())

  if self.pruning_type == _params_proto.SPLIT_PRUNE_NONE:
    self.prune_every_samples = 0
  else:
    if not self.prune_every_samples and not _split_after_is_constant():
      logging.error(
          'Must specify prune_every_samples if using a depth-dependent '
          'split_after_samples')
    # Pruning half-way through split_after_samples seems like a decent
    # default, making it easy to select the number being pruned with
    # pruning_type while not paying the cost of pruning too often. Note that
    # this only holds if not using a depth-dependent split_after_samples.
    # NOTE(review): `/` yields a float under true division; confirm whether
    # consumers expect an integer sample count.
    self.prune_every_samples = (self.prune_every_samples or
                                int(self.split_after_samples) / 2)

  if self.finish_type == _params_proto.SPLIT_FINISH_BASIC:
    self.early_finish_check_every_samples = 0
  else:
    if (not self.early_finish_check_every_samples and
        not _split_after_is_constant()):
      logging.error(
          'Must specify early_finish_check_every_samples if using a '
          'depth-dependent split_after_samples')
    # Checking for early finish every quarter through split_after_samples
    # seems like a decent default. We don't want to incur the checking cost
    # too often, but (at least for hoeffding) it's lower than the cost of
    # pruning so we can do it a little more frequently.
    self.early_finish_check_every_samples = (
        self.early_finish_check_every_samples or
        int(self.split_after_samples) / 4)

  self.split_type = SPLIT_TYPES[self.split_name]

  return self