def test_single_worker_sampling_mode(self):
    """Test single worker sampling mode.

    A background thread starts a profiler server on a free port and runs a
    short ``model.fit``. Meanwhile the test requests a trace from that server
    and finally checks that the expected tool outputs exist in the log dir.
    """

    def on_worker(port):
        logging.info('worker starting server on {}'.format(port))
        profiler.start_server(port)
        _, steps, train_ds, model = _model_setup()
        model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)

    server_port = portpicker.pick_unused_port()
    worker_thread = threading.Thread(target=on_worker, args=(server_port,))
    worker_thread.start()

    logdir = self.get_temp_dir()
    trace_options = profiler.ProfilerOptions(
        host_tracer_level=2,
        python_tracer_level=0,
        device_tracer_level=1,
    )
    # Positional args after the duration: empty worker list, 3 attempts.
    profiler_client.trace(
        f'localhost:{server_port}',
        logdir,
        3000,  # request 3 seconds of profile
        '',
        3,
        trace_options,
    )
    worker_thread.join(30)
    self._check_tools_pb_exist(logdir)
def collect_profile(port: int, duration_in_ms: int, host: str,
                    log_dir: Optional[str], host_tracer_level: int,
                    device_tracer_level: int, python_tracer_level: int,
                    no_perfetto_link: bool):
    """Capture a remote profile and convert it into a gzipped trace JSON.

    Requests ``duration_in_ms`` of profile from ``host:port`` into ``log_dir``
    (a fresh temp dir when ``log_dir`` is None). It then locates the newest
    run directory under ``<log_dir>/plugins/profile`` and converts its
    ``*.xplane.pb`` with the ``trace_viewer^`` tool. The result is written
    to ``remote.trace.json.gz`` in that run directory. Unless
    ``no_perfetto_link`` is set, the trace is also handed to the Perfetto
    hosting helper.

    Raises:
      FileNotFoundError: if ``<log_dir>/plugins/profile`` is missing, contains
        no run directories, or the newest run has no ``*.xplane.pb`` file.
    """
    options = profiler.ProfilerOptions(
        host_tracer_level=host_tracer_level,
        device_tracer_level=device_tracer_level,
        python_tracer_level=python_tracer_level,
    )
    log_dir_ = pathlib.Path(
        log_dir if log_dir is not None else tempfile.mkdtemp())
    profiler_client.trace(
        f"{host}:{port}", str(log_dir_), duration_in_ms, options=options)
    print(f"Dumped profiling information in: {log_dir_}")

    # The profiler dumps `xplane.pb` to the logging directory. To upload it to
    # the Perfetto trace viewer, we need to convert it to a `trace.json` file.
    # We do this by first finding the `xplane.pb` file, then passing it into
    # tensorflow_profile_plugin's `xplane` conversion function.
    curr_path = log_dir_.resolve()
    root_trace_folder = curr_path / "plugins" / "profile"
    # iterdir() already yields full paths. Keep only directories so that a
    # stray file with a newer mtime cannot be mistaken for the latest run.
    trace_folders = [p for p in root_trace_folder.iterdir() if p.is_dir()]
    if not trace_folders:
        raise FileNotFoundError(
            f"No profile run directories found in {root_trace_folder}")
    latest_folder = max(trace_folders, key=os.path.getmtime)

    xplane = next(latest_folder.glob("*.xplane.pb"), None)
    if xplane is None:
        raise FileNotFoundError(f"No *.xplane.pb file found in {latest_folder}")
    result = convert.xspace_to_tool_data([xplane], "trace_viewer^", None)

    with gzip.open(str(latest_folder / "remote.trace.json.gz"), "wb") as fp:
        fp.write(result.encode("utf-8"))

    if not no_perfetto_link:
        jax._src.profiler._host_perfetto_trace_file(str(log_dir_))
def testTrace_ProfileIdleServer(self):
    """Tracing an idle in-process profiler server raises UnavailableError.

    The server starts successfully and accepts the request. Because no op
    runs during the 10 ms window, the call is expected to fail with a
    message saying no trace event was collected.
    """
    port = portpicker.pick_unused_port()
    profiler.start_server(port)
    address = 'localhost:' + str(port)
    with self.assertRaises(errors.UnavailableError) as raised:
        profiler_client.trace(address, self.get_temp_dir(), duration_ms=10)
    self.assertStartsWith(str(raised.exception),
                          'No trace event was collected')
def testProfileWorker(self):
    """A worker attached to a DispatchServer serves profile requests.

    The profiler service on the worker is reachable. With no op running,
    the request is expected to fail with UnavailableError whose message
    reports that no trace event was collected.
    """
    dispatcher = server_lib.DispatchServer()
    worker_config = server_lib.WorkerConfig(dispatcher._address)
    worker = server_lib.WorkerServer(worker_config)
    with self.assertRaises(errors.UnavailableError) as raised:
        profiler_client.trace(
            worker._address, tempfile.mkdtemp(), duration_ms=10)
    self.assertStartsWith(str(raised.exception),
                          "No trace event was collected")
def on_profile(port, logdir):
    """Request a 30 ms trace from the profiler server at localhost:<port>.

    The output goes to ``logdir``; an empty worker list and up to 100
    tracing attempts are passed positionally.
    """
    tracer_options = profiler.ProfilerOptions(
        host_tracer_level=2,
        python_tracer_level=0,
        device_tracer_level=1,
    )
    profiler_client.trace(
        f'localhost:{port}', logdir, 30, '', 100, tracer_options)
def capture_route(self, request):
    """Capture a profile from the service named in the request.

    Reads these query args from ``request.args``:
      * ``service_addr``: the target service address, or a TPU name when
        ``is_tpu_name`` is ``'true'``.
      * ``duration``: trace duration in ms (default 1000).
      * ``worker_list``: the workers to trace. When tracing a TPU and this
        is empty, the list comes from the cluster resolver.
      * ``num_retry``: extra attempts on top of the first one.

    Writes the trace to ``self.logdir``. Always returns a JSON ``respond``
    payload; errors are reported in an ``'error'`` field with code 200.
    """
    service_addr = request.args.get('service_addr')
    duration = int(request.args.get('duration', '1000'))
    is_tpu_name = request.args.get('is_tpu_name') == 'true'
    worker_list = request.args.get('worker_list')
    num_tracing_attempts = int(request.args.get('num_retry', '0')) + 1

    if is_tpu_name:
        try:
            tpu_cluster_resolver = (
                tf.distribute.cluster_resolver.TPUClusterResolver(service_addr))
            master_grpc_addr = tpu_cluster_resolver.get_master()
        except (ImportError, RuntimeError) as err:
            # Python 3 exceptions have no `.message` attribute; str(err) is
            # the portable way to get the error text.
            return respond({'error': str(err)}, 'application/json', code=200)
        except (ValueError, TypeError):
            return respond(
                {'error': 'no TPUs with the specified names exist.'},
                'application/json',
                code=200,
            )
        if not worker_list:
            worker_list = get_worker_list(tpu_cluster_resolver)
        # TPU cluster resolver always returns port 8470. Replace it with 8466
        # on which profiler service is running.
        master_ip = master_grpc_addr.replace('grpc://', '').replace(':8470', '')
        service_addr = master_ip + ':8466'
        # Set the master TPU for streaming trace viewer.
        self.master_tpu_unsecure_channel = master_ip

    try:
        profiler_client.trace(
            service_addr,
            self.logdir,
            duration,
            worker_list,
            num_tracing_attempts,
        )
        return respond(
            {'result': 'Capture profile successfully. Please refresh.'},
            'application/json',
        )
    except tf.errors.UnavailableError:
        return respond(
            {'error': 'empty trace result.'},
            'application/json',
            code=200,
        )
    except Exception as e:  # pylint: disable=broad-except
        return respond(
            {'error': str(e)},
            'application/json',
            code=200,
        )
def on_profile(port, logdir, worker_start):
    """Wait for the worker signal, then request a 1000 ms profile.

    Blocks on ``worker_start.wait()``, traces localhost:<port> into ``logdir``
    with up to 1000 attempts, and then sets ``self.profile_done``. Both
    ``delay_ms`` and ``self`` are resolved from an enclosing or module scope
    that is not visible here.
    """
    worker_start.wait()
    tracer_options = tf_profiler.ProfilerOptions(
        host_tracer_level=2,
        python_tracer_level=2,
        device_tracer_level=1,
        delay_ms=delay_ms,
    )
    profiler_client.trace(
        f'localhost:{port}', logdir, 1000, '', 1000, tracer_options)
    self.profile_done = True
def trace(self):
    """Capture traces repeatedly into ``<env['dir']>/logs`` while alive.

    Each capture runs under ``self._lock`` with up to 5 attempts. A
    KeyboardInterrupt during a capture clears ``self.alive``, prints a notice
    and exits via ``sys.exit()``.
    """
    self.trace_dir = os.path.join(env['dir'], 'logs')
    os.makedirs(self.trace_dir, exist_ok=True)
    tracer_options = profiler.ProfilerOptions(
        host_tracer_level=self.monitoring_level)

    while self.alive:
        with self._lock:
            try:
                profiler_client.trace(
                    self.service_addr,
                    self.trace_dir,
                    self.duration_ms,
                    self.workers_list,
                    5,
                    tracer_options,
                )
            except KeyboardInterrupt:
                self.alive = False
                print('Closing Tracer')
                sys.exit()
def start_monitoring(self):
    """Block until one TPU profile capture succeeds.

    First sleeps ``self.wait_time`` seconds to skip the training warmup.
    After that it calls ``profiler_client.trace(**self.args)`` and retries
    every 2 seconds on UnavailableError, logging each failure. On success it
    logs a confirmation.
    """
    retry_delay = 2
    time.sleep(self.wait_time)
    while True:
        try:
            profiler_client.trace(**self.args)
        except UnavailableError:
            self.warning(
                "Failed to capture TPU profile, retry in {} seconds".format(
                    retry_delay))
            time.sleep(retry_delay)
            continue
        break
    self.info("Successfully captured TPU profile")
def test_profiler_service_with_valid_trace_request(self):
    """Test integration with profiler service by sending tracing requests.

    Starts a model server and issues REST predict requests in a background
    shell, one per second for three seconds. Meanwhile it sends a 1000 ms
    tracing request to the server's gRPC address. The helper shell's
    stdout/stderr are always collected and printed for debugging, even if
    tracing fails.
    """
    # Start model server
    model_path = self._GetSavedModelBundlePath()
    _, grpc_addr, rest_addr = TensorflowModelServerTest.RunServer(
        'default', model_path)

    # Prepare predict request
    url = 'http://{}/v1/models/default:predict'.format(rest_addr)
    json_req = '{"instances": [2.0, 3.0, 4.0]}'

    # In a subprocess, send a REST predict request every second for 3 seconds.
    exec_command = ("wget {} --content-on-error=on -O- --post-data '{}' "
                    "--header='Content-Type:application/json'").format(
                        url, json_req)
    # `shell=True` runs /bin/sh, which need not support bash brace expansion
    # (e.g. dash treats `{1..3}` literally and loops once), so spell out the
    # iterations as a plain POSIX word list.
    repeat_command = 'for n in 1 2 3; do {} & sleep 1; done;'.format(
        exec_command)
    proc = subprocess.Popen(
        repeat_command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)

    try:
        # Prepare args to ProfilerClient
        logdir = os.path.join(self.temp_dir, 'logs')
        worker_list = ''
        duration_ms = 1000
        num_tracing_attempts = 3
        os.makedirs(logdir)

        # Send a tracing request
        profiler_client.trace(grpc_addr, logdir, duration_ms, worker_list,
                              num_tracing_attempts)
    finally:
        # Reap the helper shell even if tracing raised, so it does not linger.
        # Log its stdout & stderr for debugging.
        out, err = proc.communicate()
        print("stdout: '{}' | stderr: '{}'".format(out, err))
def testStartTracing_ProcessInvalidAddress(self):
    """Tracing an address with no profiler server raises UnavailableError."""
    unreachable = 'localhost:6006'
    with self.assertRaises(errors.UnavailableError):
        profiler_client.trace(unreachable, tempfile.mkdtemp(), 2000)
def main(unused_argv=None):
    """Command-line entry point for the Cloud TPU profiler.

    1. Checks the TensorFlow version (>= 2.2.0 required).
    2. Resolves the profiler service address from --service_addr, or from
       --tpu via the cluster resolver. Port 8470 is rewritten to 8466.
    3. In monitoring mode (--monitoring_level > 0), shows metrics.
    4. Otherwise, captures a trace into --logdir.

    Exits the process with an error message on invalid flags.
    """
    logging.set_verbosity(logging.INFO)
    tf_version = versions.__version__
    print('TensorFlow version %s detected' % tf_version)
    print('Welcome to the Cloud TPU Profiler v%s' % profiler_version.__version__)

    if LooseVersion(tf_version) < LooseVersion('2.2.0'):
        sys.exit('You must install tensorflow >= 2.2.0 to use this plugin.')

    if not FLAGS.service_addr and not FLAGS.tpu:
        sys.exit('You must specify either --service_addr or --tpu.')

    tpu_cluster_resolver = None
    if FLAGS.service_addr:
        if FLAGS.tpu:
            # `warn` is a deprecated alias; `warning` is the supported name.
            logging.warning('Both --service_addr and --tpu are set. Ignoring '
                            '--tpu and using --service_addr.')
        service_addr = FLAGS.service_addr
    else:
        try:
            tpu_cluster_resolver = (resolver.TPUClusterResolver(
                [FLAGS.tpu], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
            service_addr = tpu_cluster_resolver.get_master()
        except (ValueError, TypeError):
            sys.exit(
                'Failed to find TPU %s in zone %s project %s. You may use '
                '--tpu_zone and --gcp_project to specify the zone and project of'
                ' your TPU.' % (FLAGS.tpu, FLAGS.tpu_zone, FLAGS.gcp_project))
    service_addr = service_addr.replace('grpc://', '').replace(':8470', ':8466')

    workers_list = ''
    if FLAGS.workers_list is not None:
        workers_list = FLAGS.workers_list
    elif tpu_cluster_resolver is not None:
        workers_list = get_workers_list(tpu_cluster_resolver)

    # If profiling duration was not set by user or set to a non-positive value,
    # we set it to a default value of 1000ms.
    duration_ms = FLAGS.duration_ms if FLAGS.duration_ms > 0 else 1000

    if FLAGS.monitoring_level > 0:
        # Report the effective duration (after defaulting), not the raw flag.
        print('Since monitoring level is provided, profile', service_addr,
              ' for ', duration_ms, ' ms and show metrics for ',
              FLAGS.num_queries, ' time(s).')
        monitoring_helper(service_addr, duration_ms, FLAGS.monitoring_level,
                          FLAGS.num_queries)
    else:
        if not FLAGS.logdir:
            sys.exit('You must specify either --logdir or --monitoring_level.')

        # Expand `~` once so that the directory we create is the same one the
        # trace is written to. Previously `~/x` created a literal `~` dir.
        logdir = os.path.expanduser(FLAGS.logdir)
        if not gfile.Exists(logdir):
            gfile.MakeDirs(logdir)

        try:
            if LooseVersion(tf_version) < LooseVersion('2.3.0'):
                profiler_client.trace(service_addr, logdir, duration_ms,
                                      workers_list, FLAGS.num_tracing_attempts)
            else:
                options = profiler.ProfilerOptions(
                    host_tracer_level=FLAGS.host_tracer_level)
                profiler_client.trace(service_addr, logdir, duration_ms,
                                      workers_list, FLAGS.num_tracing_attempts,
                                      options)
        except errors.UnavailableError:
            # Deliberate: an empty/unavailable trace is not treated as a
            # process failure.
            sys.exit(0)
def capture_route_impl(self, request):
    """Runs the client trace for capturing profiling information.

    Reads these query args from ``request.args``:
      * ``service_addr``: the target address, or a TPU name when
        ``is_tpu_name`` is ``'true'``.
      * ``duration`` (default 1000 ms), ``worker_list``, and ``num_retry``
        (extra attempts beyond the first).
      * Tracer levels: ``host_tracer_level``, ``device_tracer_level`` and
        ``python_tracer_level``.
      * ``delay``: applied only when ProfilerOptions has a ``delay_ms`` field.

    Writes the trace to ``self.logdir``. Always returns a JSON ``respond``
    payload; errors are reported in an ``'error'`` field with code 200.
    """
    service_addr = request.args.get('service_addr')
    duration = int(request.args.get('duration', '1000'))
    is_tpu_name = request.args.get('is_tpu_name') == 'true'
    worker_list = request.args.get('worker_list')
    num_tracing_attempts = int(request.args.get('num_retry', '0')) + 1

    options = None
    try:
        options = profiler.ProfilerOptions(
            host_tracer_level=int(request.args.get('host_tracer_level', '2')),
            device_tracer_level=int(
                request.args.get('device_tracer_level', '1')),
            python_tracer_level=int(
                request.args.get('python_tracer_level', '0')),
        )
    except AttributeError:
        logger.warning('ProfilerOptions are available after tensorflow 2.3')
    else:
        # For preserving backwards compatibility with TensorFlow 2.3 and older:
        # only newer ProfilerOptions define a `delay_ms` field. The `_fields`
        # attribute indicates a namedtuple, whose fields cannot be assigned in
        # place (that raised AttributeError and silently dropped the delay).
        # Build an updated copy with `_replace` instead.
        if 'delay_ms' in getattr(options, '_fields', ()):
            options = options._replace(
                delay_ms=int(request.args.get('delay', '0')))

    if is_tpu_name:
        try:
            tpu_cluster_resolver = (
                tf.distribute.cluster_resolver.TPUClusterResolver(service_addr))
            master_grpc_addr = tpu_cluster_resolver.get_master()
        except (ImportError, RuntimeError) as err:
            # Python 3 exceptions have no `.message` attribute.
            return respond({'error': str(err)}, 'application/json', code=200)
        except (ValueError, TypeError):
            return respond(
                {'error': 'no TPUs with the specified names exist.'},
                'application/json',
                code=200,
            )
        if not worker_list:
            worker_list = get_worker_list(tpu_cluster_resolver)
        # TPU cluster resolver always returns port 8470. Replace it with 8466
        # on which profiler service is running.
        master_ip = master_grpc_addr.replace('grpc://', '').replace(':8470', '')
        service_addr = master_ip + ':8466'
        # Set the master TPU for streaming trace viewer.
        self.master_tpu_unsecure_channel = master_ip

    try:
        if options is not None:
            profiler_client.trace(
                service_addr,
                self.logdir,
                duration,
                worker_list,
                num_tracing_attempts,
                options=options)
        else:
            profiler_client.trace(
                service_addr,
                self.logdir,
                duration,
                worker_list,
                num_tracing_attempts,
            )
        return respond(
            {'result': 'Capture profile successfully. Please refresh.'},
            'application/json',
        )
    except tf.errors.UnavailableError:
        return respond(
            {'error': 'empty trace result.'},
            'application/json',
            code=200,
        )
    except Exception as e:  # pylint: disable=broad-except
        return respond(
            {'error': str(e)},
            'application/json',
            code=200,
        )
def testStartTracing_ProcessInvalidAddressWithOptions(self):
    """Tracing an unreachable address with explicit options raises UnavailableError."""
    with self.assertRaises(errors.UnavailableError):
        tracer_options = profiler.ProfilerOptions(
            host_tracer_level=3,
            device_tracer_level=0,
        )
        target_dir = tempfile.mkdtemp()
        profiler_client.trace(
            'localhost:6006', target_dir, 2000, options=tracer_options)
def main(argv):
    """Request a profile from a model server's gRPC endpoint.

    Usage: ``<prog> [server [logdir [duration_ms]]]``. The defaults are
    ``localhost:8500``, ``/tmp`` and 2000 ms.
    """
    server = argv[1] if len(argv) > 1 else 'localhost:8500'
    logdir = argv[2] if len(argv) > 2 else '/tmp'
    # Command-line arguments arrive as strings; the trace duration must be an
    # integer number of milliseconds.
    duration_ms = int(argv[3]) if len(argv) > 3 else 2000
    profiler_client.trace(server, logdir, duration_ms)