def test_single_worker_sampling_mode(self):
        """Capture a sampling-mode profile from a single worker thread.

        A background thread starts a profiler server on a free port and runs
        a short `model.fit`; this thread then requests a trace from that
        server and calls `_check_tools_pb_exist` on the log directory.
        """
        def run_worker(server_port):
            # Start the profiler server before generating any work to trace.
            logging.info('worker starting server on {}'.format(server_port))
            profiler.start_server(server_port)
            _, step_count, dataset, keras_model = _model_setup()
            keras_model.fit(x=dataset, epochs=2, steps_per_epoch=step_count)

        worker_port = portpicker.pick_unused_port()
        worker = threading.Thread(target=run_worker, args=(worker_port, ))
        worker.start()

        # Ask for a 3 second (3000 ms) trace.
        trace_duration_ms = 3000
        output_dir = self.get_temp_dir()
        tracer_options = profiler.ProfilerOptions(
            host_tracer_level=2,
            python_tracer_level=0,
            device_tracer_level=1,
        )

        service_addr = 'localhost:{}'.format(worker_port)
        # The two positional values after the duration are '' and 3.
        profiler_client.trace(service_addr, output_dir, trace_duration_ms,
                              '', 3, tracer_options)

        # Give the worker thread up to 30 seconds to finish.
        worker.join(30)
        self._check_tools_pb_exist(output_dir)
Example #2
0
def collect_profile(port: int, duration_in_ms: int, host: str,
                    log_dir: Optional[str], host_tracer_level: int,
                    device_tracer_level: int, python_tracer_level: int,
                    no_perfetto_link: bool):
    """Trace a remote profiler service and write a gzipped trace JSON.

    Args:
        port: Port of the profiler service.
        duration_in_ms: Length of the requested trace, in milliseconds.
        host: Host name of the profiler service.
        log_dir: Directory to dump the profile into; a fresh temporary
            directory is created when this is None.
        host_tracer_level: Forwarded to `profiler.ProfilerOptions`.
        device_tracer_level: Forwarded to `profiler.ProfilerOptions`.
        python_tracer_level: Forwarded to `profiler.ProfilerOptions`.
        no_perfetto_link: When true, skip the final call to
            `jax._src.profiler._host_perfetto_trace_file`.
    """
    tracer_options = profiler.ProfilerOptions(
        host_tracer_level=host_tracer_level,
        device_tracer_level=device_tracer_level,
        python_tracer_level=python_tracer_level,
    )
    output_dir = pathlib.Path(
        tempfile.mkdtemp() if log_dir is None else log_dir)

    service_addr = f"{host}:{port}"
    profiler_client.trace(service_addr,
                          str(output_dir),
                          duration_in_ms,
                          options=tracer_options)
    print(f"Dumped profiling information in: {output_dir}")

    # The profiler leaves an `xplane.pb` under <log dir>/plugins/profile/<run>.
    # Pick the most recently modified run folder, locate its xplane file and
    # convert it to trace-viewer data so it can be stored as `trace.json`.
    profile_root = output_dir.resolve() / "plugins" / "profile"
    # `profile_root` is absolute, so iterdir() already yields the same paths
    # the original built by re-joining each entry onto the root.
    run_folders = list(profile_root.iterdir())
    newest_run = max(run_folders, key=os.path.getmtime)
    xplane_file = next(newest_run.glob("*.xplane.pb"))
    trace_json = convert.xspace_to_tool_data([xplane_file], "trace_viewer^",
                                             None)

    gz_path = newest_run / "remote.trace.json.gz"
    with gzip.open(str(gz_path), "wb") as out:
        out.write(trace_json.encode("utf-8"))

    if no_perfetto_link:
        return
    jax._src.profiler._host_perfetto_trace_file(str(output_dir))
Example #3
0
 def testTrace_ProfileIdleServer(self):
   """Tracing an idle profiler server raises UnavailableError."""
   idle_port = portpicker.pick_unused_port()
   profiler.start_server(idle_port)
   service_addr = 'localhost:' + str(idle_port)
   # The profilers start and connect to the service, but no op is running,
   # so the client is expected to fail with an UnavailableError whose
   # message reports that no trace event was collected.
   with self.assertRaises(errors.UnavailableError) as raised:
     profiler_client.trace(
         service_addr, self.get_temp_dir(), duration_ms=10)
   message = str(raised.exception)
   self.assertStartsWith(message, 'No trace event was collected')
 def testProfileWorker(self):
   """Tracing a freshly started worker server raises UnavailableError."""
   dispatch_server = server_lib.DispatchServer()
   worker_config = server_lib.WorkerConfig(dispatch_server._address)
   worker_server = server_lib.WorkerServer(worker_config)
   # The profilers start and connect to the worker's profiler service, but
   # no op is running, so the client is expected to fail with an
   # UnavailableError whose message reports that no trace event was
   # collected.
   with self.assertRaises(errors.UnavailableError) as raised:
     profiler_client.trace(
         worker_server._address, tempfile.mkdtemp(), duration_ms=10)
   message = str(raised.exception)
   self.assertStartsWith(message, "No trace event was collected")
Example #5
0
    def on_profile(port, logdir):
      """Request a 30 ms trace from the profiler server on `port`.

      Args:
        port: Local port the profiler server listens on.
        logdir: Directory passed to the client as the trace destination.
      """
      tracer_options = profiler.ProfilerOptions(
          host_tracer_level=2,
          python_tracer_level=0,
          device_tracer_level=1,
      )
      service_addr = 'localhost:{}'.format(port)
      trace_duration_ms = 30  # 30 milliseconds of profile.
      # The two positional values after the duration are '' and 100.
      profiler_client.trace(service_addr, logdir, trace_duration_ms,
                            '', 100, tracer_options)
Example #6
0
  def capture_route(self, request):
    """Capture a profile from the requested service and report the outcome.

    Reads `service_addr`, `duration` (default '1000'), `is_tpu_name`,
    `worker_list` and `num_retry` (default '0') from `request.args`. When
    `is_tpu_name` is 'true', `service_addr` is treated as a TPU name and
    resolved to the master address via `TPUClusterResolver`.

    Args:
      request: Request object exposing the query parameters through `args`.

    Returns:
      The result of `respond(...)` with a JSON payload: a `result` entry on
      success, or an `error` entry (still sent with code 200) on failure.
    """
    service_addr = request.args.get('service_addr')
    duration = int(request.args.get('duration', '1000'))
    is_tpu_name = request.args.get('is_tpu_name') == 'true'
    worker_list = request.args.get('worker_list')
    # One initial attempt plus the requested number of retries.
    num_tracing_attempts = int(request.args.get('num_retry', '0')) + 1

    if is_tpu_name:
      try:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            service_addr)
        master_grpc_addr = tpu_cluster_resolver.get_master()
      except (ImportError, RuntimeError) as err:
        # Python 3 exceptions have no `.message` attribute; str(err) yields
        # the message without raising AttributeError inside the handler.
        return respond({'error': str(err)}, 'application/json', code=200)
      except (ValueError, TypeError):
        return respond(
            {'error': 'no TPUs with the specified names exist.'},
            'application/json',
            code=200,
        )
      if not worker_list:
        worker_list = get_worker_list(tpu_cluster_resolver)
      # TPU cluster resolver always returns port 8470. Replace it with 8466
      # on which profiler service is running.
      master_ip = master_grpc_addr.replace('grpc://', '').replace(':8470', '')
      service_addr = master_ip + ':8466'
      # Set the master TPU for streaming trace viewer.
      self.master_tpu_unsecure_channel = master_ip
    try:
      profiler_client.trace(
          service_addr,
          self.logdir,
          duration,
          worker_list,
          num_tracing_attempts,
      )
      return respond(
          {'result': 'Capture profile successfully. Please refresh.'},
          'application/json',
      )
    except tf.errors.UnavailableError:
      return respond(
          {'error': 'empty trace result.'},
          'application/json',
          code=200,
      )
    except Exception as e:  # pylint: disable=broad-except
      # Any other failure is reported to the caller as a JSON error payload.
      return respond(
          {'error': str(e)},
          'application/json',
          code=200,
      )
Example #7
0
        def on_profile(port, logdir, worker_start):
            """Wait for the worker, request a 1000 ms trace, flag completion.

            Args:
                port: Local port the profiler server listens on.
                logdir: Directory passed to the client as the trace
                    destination.
                worker_start: Object whose `wait()` blocks until the worker
                    has started.
            """
            worker_start.wait()
            # `delay_ms` and `self` come from the enclosing scope.
            tracer_options = tf_profiler.ProfilerOptions(
                host_tracer_level=2,
                python_tracer_level=2,
                device_tracer_level=1,
                delay_ms=delay_ms,
            )

            service_addr = 'localhost:{}'.format(port)
            trace_duration_ms = 1000  # 1000 milliseconds of profile.
            # The two positional values after the duration are '' and 1000.
            profiler_client.trace(service_addr, logdir, trace_duration_ms,
                                  '', 1000, tracer_options)
            self.profile_done = True
Example #8
0
 def trace(self):
     """Repeatedly trace the configured service until `self.alive` is false.

     Traces are written to a `logs` directory under `env['dir']`, which is
     created if missing. Each trace request runs while holding `self._lock`.
     A KeyboardInterrupt clears `self.alive`, prints a message and exits
     the process via `sys.exit()`.
     """
     self.trace_dir = os.path.join(env['dir'], 'logs')
     os.makedirs(self.trace_dir, exist_ok=True)
     tracer_options = profiler.ProfilerOptions(
         host_tracer_level=self.monitoring_level)

     while self.alive:
         with self._lock:
             try:
                 # The positional value before the options is 5.
                 profiler_client.trace(
                     self.service_addr,
                     self.trace_dir,
                     self.duration_ms,
                     self.workers_list,
                     5,
                     tracer_options,
                 )
             except KeyboardInterrupt:
                 self.alive = False
                 print('Closing Tracer')
                 sys.exit()
Example #9
0
    def start_monitoring(self):
        """Capture a profile, retrying on UnavailableError until it succeeds.

        Sleeps `self.wait_time` seconds first, then calls
        `profiler_client.trace(**self.args)` in a loop. Each
        `UnavailableError` is logged through `self.warning` and followed by
        a 2 second pause; the first successful call is logged through
        `self.info` and ends the loop. Other exceptions propagate.
        """
        retry_delay = 2

        # Sleep for wait_time seconds to avoid the training warmup
        time.sleep(self.wait_time)

        while True:
            try:
                profiler_client.trace(**self.args)
            except UnavailableError:
                self.warning(
                    "Failed to capture TPU profile, retry in {} seconds".
                    format(retry_delay))
                time.sleep(retry_delay)
                continue
            self.info("Successfully captured TPU profile")
            return
Example #10
0
    def test_profiler_service_with_valid_trace_request(self):
        """Send a tracing request to a model server that is serving traffic.

        Starts a model server, issues REST predict requests from a shell
        subprocess, then sends one trace request to the server's gRPC
        address and prints the subprocess output for debugging.
        """
        # Bring up the model server.
        saved_model_path = self._GetSavedModelBundlePath()
        _, grpc_addr, rest_addr = TensorflowModelServerTest.RunServer(
            'default', saved_model_path)

        # Build the predict request.
        predict_url = 'http://{}/v1/models/default:predict'.format(rest_addr)
        payload = '{"instances": [2.0, 3.0, 4.0]}'

        # A subprocess sends one REST predict request per second, 3 times.
        wget_template = ("wget {} --content-on-error=on -O- --post-data  '{}' "
                         "--header='Content-Type:application/json'")
        single_request = wget_template.format(predict_url, payload)
        shell_loop = 'for n in {{1..3}}; do {} & sleep 1; done;'.format(
            single_request)
        requester = subprocess.Popen(
            shell_loop,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

        # Arguments for the profiler client.
        trace_dir = os.path.join(self.temp_dir, 'logs')
        workers = ''
        trace_duration_ms = 1000
        attempts = 3
        os.makedirs(trace_dir)

        # Issue the tracing request.
        profiler_client.trace(grpc_addr, trace_dir, trace_duration_ms,
                              workers, attempts)

        # Surface the subprocess stdout/stderr to help debugging.
        captured_out, captured_err = requester.communicate()
        print("stdout: '{}' | stderr: '{}'".format(captured_out,
                                                   captured_err))
Example #11
0
 def testStartTracing_ProcessInvalidAddress(self):
     """Tracing 'localhost:6006' is expected to raise UnavailableError."""
     with self.assertRaises(errors.UnavailableError):
         trace_dir = tempfile.mkdtemp()
         # Request a 2000 ms trace.
         profiler_client.trace('localhost:6006', trace_dir, 2000)
Example #12
0
def main(unused_argv=None):
    """Entry point of the Cloud TPU profiler command-line tool.

    Resolves the profiler service address from `--service_addr` or `--tpu`,
    then either runs `monitoring_helper` (when `--monitoring_level` > 0) or
    captures a trace into `--logdir`.

    Args:
        unused_argv: Ignored.

    Raises:
        SystemExit: On an unsupported TensorFlow version, missing required
            flags, a TPU that cannot be resolved, or (with exit code 0) when
            the trace call raises `errors.UnavailableError`.
    """
    logging.set_verbosity(logging.INFO)
    tf_version = versions.__version__
    print('TensorFlow version %s detected' % tf_version)
    print('Welcome to the Cloud TPU Profiler v%s' %
          profiler_version.__version__)

    if LooseVersion(tf_version) < LooseVersion('2.2.0'):
        sys.exit('You must install tensorflow >= 2.2.0 to use this plugin.')

    if not FLAGS.service_addr and not FLAGS.tpu:
        sys.exit('You must specify either --service_addr or --tpu.')

    tpu_cluster_resolver = None
    if FLAGS.service_addr:
        if FLAGS.tpu:
            logging.warn('Both --service_addr and --tpu are set. Ignoring '
                         '--tpu and using --service_addr.')
        service_addr = FLAGS.service_addr
    else:
        try:
            tpu_cluster_resolver = (resolver.TPUClusterResolver(
                [FLAGS.tpu], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
            service_addr = tpu_cluster_resolver.get_master()
        except (ValueError, TypeError):
            sys.exit(
                'Failed to find TPU %s in zone %s project %s. You may use '
                '--tpu_zone and --gcp_project to specify the zone and project of'
                ' your TPU.' % (FLAGS.tpu, FLAGS.tpu_zone, FLAGS.gcp_project))
    # Strip the gRPC scheme and swap port 8470 for 8466 in the address.
    service_addr = service_addr.replace('grpc://',
                                        '').replace(':8470', ':8466')

    workers_list = ''
    if FLAGS.workers_list is not None:
        workers_list = FLAGS.workers_list
    elif tpu_cluster_resolver is not None:
        workers_list = get_workers_list(tpu_cluster_resolver)

    # If profiling duration was not set by user or set to a non-positive value,
    # we set it to a default value of 1000ms.
    duration_ms = FLAGS.duration_ms if FLAGS.duration_ms > 0 else 1000

    if FLAGS.monitoring_level > 0:
        # Report the effective duration (after the default is applied), which
        # is the value actually handed to monitoring_helper.
        print('Since monitoring level is provided, profile', service_addr,
              ' for ', duration_ms, ' ms and show metrics for ',
              FLAGS.num_queries, ' time(s).')
        monitoring_helper(service_addr, duration_ms, FLAGS.monitoring_level,
                          FLAGS.num_queries)
    else:
        if not FLAGS.logdir:
            sys.exit('You must specify either --logdir or --monitoring_level.')

        # Expand '~' once so the directory that is checked/created is the
        # same one the trace is written to.
        logdir = os.path.expanduser(FLAGS.logdir)
        if not gfile.Exists(logdir):
            gfile.MakeDirs(logdir)

        try:
            if LooseVersion(tf_version) < LooseVersion('2.3.0'):
                # Before 2.3.0 the trace call is made without options.
                profiler_client.trace(service_addr, logdir, duration_ms,
                                      workers_list,
                                      FLAGS.num_tracing_attempts)
            else:
                options = profiler.ProfilerOptions(
                    host_tracer_level=FLAGS.host_tracer_level)
                profiler_client.trace(service_addr, logdir, duration_ms,
                                      workers_list,
                                      FLAGS.num_tracing_attempts, options)
        except errors.UnavailableError:
            sys.exit(0)
Example #13
0
    def capture_route_impl(self, request):
        """Runs the client trace for capturing profiling information.

        Reads `service_addr`, `duration` (default '1000'), `is_tpu_name`,
        `worker_list`, `num_retry` (default '0'), the three tracer levels
        and `delay` from `request.args`. When `is_tpu_name` is 'true',
        `service_addr` is treated as a TPU name and resolved to the master
        address via `TPUClusterResolver`.

        Args:
            request: Request object exposing the query parameters through
                `args`.

        Returns:
            The result of `respond(...)` with a JSON payload: a `result`
            entry on success, or an `error` entry (still sent with code
            200) on failure.
        """
        service_addr = request.args.get('service_addr')
        duration = int(request.args.get('duration', '1000'))
        is_tpu_name = request.args.get('is_tpu_name') == 'true'
        worker_list = request.args.get('worker_list')
        # One initial attempt plus the requested number of retries.
        num_tracing_attempts = int(request.args.get('num_retry', '0')) + 1
        options = None
        try:
            options = profiler.ProfilerOptions(
                host_tracer_level=int(
                    request.args.get('host_tracer_level', '2')),
                device_tracer_level=int(
                    request.args.get('device_tracer_level', '1')),
                python_tracer_level=int(
                    request.args.get('python_tracer_level', '0')),
            )
            # For preserving backwards compatibility with TensorFlow 2.3 and older.
            if 'delay_ms' in options._fields:
                # The options object is namedtuple-like (it exposes
                # `_fields`), so its fields cannot be assigned in place;
                # build an updated copy instead.
                options = options._replace(
                    delay_ms=int(request.args.get('delay', '0')))
        except AttributeError:
            logger.warning(
                'ProfilerOptions are available after tensorflow 2.3')

        if is_tpu_name:
            try:
                tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
                    service_addr)
                master_grpc_addr = tpu_cluster_resolver.get_master()
            except (ImportError, RuntimeError) as err:
                # Python 3 exceptions have no `.message` attribute; str(err)
                # yields the message without raising inside the handler.
                return respond({'error': str(err)},
                               'application/json',
                               code=200)
            except (ValueError, TypeError):
                return respond(
                    {'error': 'no TPUs with the specified names exist.'},
                    'application/json',
                    code=200,
                )
            if not worker_list:
                worker_list = get_worker_list(tpu_cluster_resolver)
            # TPU cluster resolver always returns port 8470. Replace it with 8466
            # on which profiler service is running.
            master_ip = master_grpc_addr.replace('grpc://',
                                                 '').replace(':8470', '')
            service_addr = master_ip + ':8466'
            # Set the master TPU for streaming trace viewer.
            self.master_tpu_unsecure_channel = master_ip
        try:
            if options:
                profiler_client.trace(service_addr,
                                      self.logdir,
                                      duration,
                                      worker_list,
                                      num_tracing_attempts,
                                      options=options)
            else:
                # No options object could be built; call without it.
                profiler_client.trace(
                    service_addr,
                    self.logdir,
                    duration,
                    worker_list,
                    num_tracing_attempts,
                )
            return respond(
                {'result': 'Capture profile successfully. Please refresh.'},
                'application/json',
            )
        except tf.errors.UnavailableError:
            return respond(
                {'error': 'empty trace result.'},
                'application/json',
                code=200,
            )
        except Exception as e:  # pylint: disable=broad-except
            # Any other failure is reported as a JSON error payload.
            return respond(
                {'error': str(e)},
                'application/json',
                code=200,
            )
 def testStartTracing_ProcessInvalidAddressWithOptions(self):
   """Tracing 'localhost:6006' with options raises UnavailableError."""
   with self.assertRaises(errors.UnavailableError):
     tracer_options = profiler.ProfilerOptions(
         host_tracer_level=3, device_tracer_level=0)
     trace_dir = tempfile.mkdtemp()
     # Request a 2000 ms trace with the custom tracer levels.
     profiler_client.trace(
         'localhost:6006', trace_dir, 2000, options=tracer_options)
Example #15
0
def main(argv):
  """Request a trace from a profiler service using command-line arguments.

  Args:
    argv: Argument list; `argv[1]` is the server address (default
      'localhost:8500'), `argv[2]` the log directory (default '/tmp') and
      `argv[3]` the trace duration in milliseconds (default 2000).
  """
  server = argv[1] if len(argv) > 1 else 'localhost:8500'
  logdir = argv[2] if len(argv) > 2 else '/tmp'
  # Command-line arguments arrive as strings; convert the duration so it has
  # the same integer type as the default.
  duration_ms = int(argv[3]) if len(argv) > 3 else 2000
  profiler_client.trace(server, logdir, duration_ms)