def test_trace_and_metrics(self):
    """Profile a background MNIST training run and validate the outputs.

    A worker process runs ``test_profile_mp_mnist.train_mnist`` with fake
    data and the profiler port set to a free TCP port. Once the worker has
    signalled that training started (or 60 seconds have elapsed), a trace
    is captured into a fresh temporary directory and the resulting
    artifacts are checked with the ``_check_*`` helpers of this test class.
    """
    port = xu.get_free_tcp_ports()[0]
    training_started = multiprocessing.Event()

    def train_worker():
        flags = args_parse.parse_common_options(
            datadir='/tmp/mnist-data',
            batch_size=16,
            momentum=0.5,
            lr=0.01,
            num_epochs=10)
        flags.fake_data = True
        flags.profiler_port = port
        test_profile_mp_mnist.train_mnist(
            flags,
            training_started=training_started,
            dynamic_graph=True,
            fetch_often=True)

    p = multiprocessing.Process(target=train_worker, daemon=True)
    p.start()
    # Best-effort wait: the trace call below retries on its own, so a
    # timeout here is not treated as a failure.
    training_started.wait(60)

    logdir = tempfile.mkdtemp()
    try:
        xp.trace(
            f'localhost:{port}',
            logdir,
            duration_ms=5000,
            num_tracing_attempts=5,
            delay_ms=1000)
    finally:
        # Always stop the worker, even when tracing raises; otherwise the
        # training process keeps running for the rest of the test session.
        p.terminate()

    path = self._check_xspace_pb_exist(logdir)
    self._check_trace_namespace_exists(path)
    self._check_metrics_warnings_exist(self.fname)
def test_xla_profiler_prog_capture(tmpdir):
    """Capture an XLA profiler trace while a Trainer fits in a subprocess.

    A worker process fits a ``BoringModel`` with ``profiler="xla"``; the
    parent then requests a trace and asserts that at least one
    ``*.xplane.pb`` file was written under ``<tmpdir>/plugins/profile/*/``.

    Args:
        tmpdir: Directory used both as the Trainer root dir and as the
            trace output directory.
    """
    import glob  # local import: only needed for the wildcard lookup below

    # NOTE(review): `port` is never handed to the worker, so the profiler
    # server presumably listens on its own default port rather than this
    # one -- confirm against the XLA profiler configuration.
    port = xu.get_free_tcp_ports()[0]
    # NOTE(review): nothing visible here ever sets this event, so the wait
    # below looks like it always runs to its full timeout -- confirm.
    training_started = Event()

    def train_worker():
        model = BoringModel()
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=4,
            profiler="xla",
            accelerator="tpu",
            devices=8,
        )
        trainer.fit(model)

    p = Process(target=train_worker, daemon=True)
    p.start()
    training_started.wait(120)

    logdir = str(tmpdir)
    try:
        xp.trace(
            f"localhost:{port}",
            logdir,
            duration_ms=2000,
            num_tracing_attempts=5,
            delay_ms=1000,
        )
    finally:
        # Stop the worker even if tracing fails so it does not outlive
        # this test.
        p.terminate()

    # `os.isfile` does not exist, and `os.path.isfile` does not expand
    # wildcards; resolve the pattern with glob and require a real file.
    pattern = os.path.join(logdir, "plugins", "profile", "*", "*.xplane.pb")
    assert any(os.path.isfile(match) for match in glob.glob(pattern))
def trace():
    """Capture a profiler trace and report where it was saved.

    The service address, output directory and duration are read from the
    module-level ``args`` object and forwarded to ``xp.trace``.
    """
    trace_options = dict(
        service_addr=args.service_addr,
        logdir=args.logdir,
        duration_ms=args.duration_ms,
    )
    xp.trace(**trace_options)
    print(f"Saved profiling output to {args.logdir}")
if __name__ == "__main__":
    # Run configuration switches. They stay module-level names because
    # helpers defined elsewhere in this file may read them as globals.
    xla_enabled = True
    amp_enabled = True
    debug_enabled = False
    dlprof_enabled = False
    cpu_mem_usage = False

    # NOTE(review): the trailing `and False` makes this branch unreachable;
    # it looks like a deliberate way to keep DLProf switched off -- confirm
    # before removing.
    if dlprof_enabled and not xla_enabled and False:
        import nvidia_dlprof_pytorch_nvtx
        nvidia_dlprof_pytorch_nvtx.init()

    if xla_enabled:
        port_number = 8192
        training_started = multiprocessing.Event()
        dataset_path = '/pytorch/xla/test/IMDB Dataset.csv'
        download_dataset()

        def target_fn():
            xmp.spawn(_mp_fn, nprocs=1)

        p = multiprocessing.Process(target=target_fn, args=())
        p.start()

        # Wait for the worker to signal that training began, but give up if
        # the worker exits first: an unbounded wait() would hang forever.
        while not training_started.wait(timeout=1.0):
            if not p.is_alive() and not training_started.is_set():
                raise RuntimeError(
                    "training worker exited before signalling start "
                    f"(exit code {p.exitcode})")

        xp.trace(f'localhost:{port_number}', '/pytorch/xla/test/bert_tensorboard')
        # xmp.spawn(_mp_fn, nprocs=1)

        # Explicitly wait for the worker (it is non-daemonic, so the
        # interpreter would otherwise join it implicitly at exit).
        p.join()
    else:
        dataset_path = os.path.join(os.getcwd(), "IMDB Dataset.csv")
        download_dataset()
        train_bert(dataset_path, xla_enabled, amp_enabled)