コード例 #1
0
ファイル: test_profiler.py プロジェクト: pytorch/xla
    def test_trace_and_metrics(self):
        """Trace a live MNIST training run and validate the captured profile.

        A worker process trains MNIST on fake data with its profiler port set
        to a free local TCP port. After the worker signals that training has
        started (or after 60 seconds, whichever comes first), ``xp.trace``
        captures a 5000 ms trace from that port into a temporary directory.
        The worker is then stopped, and the output is checked with this test
        class's ``_check_*`` helpers.
        """
        port = xu.get_free_tcp_ports()[0]
        training_started = multiprocessing.Event()

        def train_worker():
            flags = args_parse.parse_common_options(datadir='/tmp/mnist-data',
                                                    batch_size=16,
                                                    momentum=0.5,
                                                    lr=0.01,
                                                    num_epochs=10)
            flags.fake_data = True
            flags.profiler_port = port
            test_profile_mp_mnist.train_mnist(
                flags,
                training_started=training_started,
                dynamic_graph=True,
                fetch_often=True)

        # TemporaryDirectory removes the trace output once the checks are
        # done; mkdtemp() left one directory behind per run.
        with tempfile.TemporaryDirectory() as logdir:
            p = multiprocessing.Process(target=train_worker, daemon=True)
            p.start()
            try:
                # The wait() result is not checked. If the signal does not
                # arrive in time, tracing is still attempted, presumably
                # relying on num_tracing_attempts/delay_ms to reach the server.
                training_started.wait(60)
                xp.trace(f'localhost:{port}',
                         logdir,
                         duration_ms=5000,
                         num_tracing_attempts=5,
                         delay_ms=1000)
            finally:
                # Stop the worker even when tracing raises, and reap it so it
                # does not outlive the test. The join is bounded in case the
                # worker does not exit promptly after terminate().
                p.terminate()
                p.join(timeout=30)
            path = self._check_xspace_pb_exist(logdir)
            self._check_trace_namespace_exists(path)
            self._check_metrics_warnings_exist(self.fname)
コード例 #2
0
def test_xla_profiler_prog_capture(tmpdir):
    """Profile a TPU training run and check that an ``.xplane.pb`` was written.

    A ``Trainer`` using the ``"xla"`` profiler fits ``BoringModel`` with
    ``accelerator="tpu", devices=8`` in a child process, while ``xp.trace``
    captures a 2000 ms trace into ``tmpdir``. The test passes if at least one
    file matching ``plugins/profile/*/*.xplane.pb`` exists under ``tmpdir``
    afterwards.
    """
    import glob

    port = xu.get_free_tcp_ports()[0]
    # NOTE(review): this event is never passed to ``train_worker`` and nothing
    # in this test sets it. ``wait(120)`` below therefore always times out and
    # acts as a fixed 120 s delay before tracing.
    training_started = Event()

    def train_worker():
        # NOTE(review): ``port`` is not handed to the Trainer. The "xla"
        # profiler presumably serves on its own default port rather than the
        # one traced below; confirm this (e.g. pass XLAProfiler(port=port)).
        model = BoringModel()
        trainer = Trainer(default_root_dir=tmpdir,
                          max_epochs=4,
                          profiler="xla",
                          accelerator="tpu",
                          devices=8)

        trainer.fit(model)

    logdir = str(tmpdir)
    p = Process(target=train_worker, daemon=True)
    p.start()
    try:
        training_started.wait(120)
        xp.trace(f"localhost:{port}",
                 logdir,
                 duration_ms=2000,
                 num_tracing_attempts=5,
                 delay_ms=1000)
    finally:
        # Stop the trainer even when tracing raises, and reap the child so it
        # does not outlive the test. The join is bounded in case the child
        # does not exit promptly.
        p.terminate()
        p.join(timeout=30)

    # ``os.isfile`` does not exist (AttributeError), and ``os.path.isfile``
    # treats "*" literally, so the wildcard path is expanded with glob.
    xplane_files = glob.glob(
        os.path.join(logdir, "plugins", "profile", "*", "*.xplane.pb"))
    assert xplane_files, f"no *.xplane.pb trace found under {logdir}"
コード例 #3
0
ファイル: capture_profile.py プロジェクト: pytorch/xla
 def trace():
     """Request a profile capture and report where it was written.

     Every setting comes from ``args`` at call time: the profiler
     ``service_addr``, the output ``logdir``, and ``duration_ms``. After
     ``xp.trace`` returns, the output directory is printed.
     """
     trace_kwargs = {
         "service_addr": args.service_addr,
         "logdir": args.logdir,
         "duration_ms": args.duration_ms,
     }
     xp.trace(**trace_kwargs)
     print(f"Saved profiling output to {args.logdir}")
コード例 #4
0
if __name__ == "__main__":
    # Script entry point. With XLA enabled, training runs in a child process
    # while this process captures a profiler trace from it. Otherwise,
    # ``train_bert`` runs directly in this process.

    # Hard-coded run switches. ``debug_enabled`` and ``cpu_mem_usage`` are not
    # read in this block; presumably other code reads them as module globals.
    xla_enabled = True
    amp_enabled = True
    debug_enabled = False
    dlprof_enabled = False
    cpu_mem_usage = False

    # The trailing ``and False`` makes this branch unreachable, so DLProf/NVTX
    # initialization is currently disabled regardless of the flags above.
    if dlprof_enabled and not xla_enabled and False:
        import nvidia_dlprof_pytorch_nvtx
        nvidia_dlprof_pytorch_nvtx.init()
    if xla_enabled:
        # Port that the trace below connects to. NOTE(review): presumably
        # ``_mp_fn`` starts the profiler server on this same port; confirm.
        port_number = 8192
        # NOTE(review): nothing in this block sets this event; presumably
        # ``_mp_fn`` does once training starts, so confirm this. If xmp.spawn
        # uses the "spawn" start method, the spawned process re-imports this
        # module without running this block and would not see this object.
        # The ``wait()`` below has no timeout, so if the event is never set
        # this process blocks forever.
        training_started = multiprocessing.Event()
        # Module-level global; presumably read by ``download_dataset`` and
        # the training code. Confirm this.
        dataset_path = '/pytorch/xla/test/IMDB Dataset.csv'
        download_dataset()

        def target_fn():
            # Start ``_mp_fn`` through xmp.spawn with a single process.
            xmp.spawn(_mp_fn, nprocs=1)

        # Run training in a separate process so that this one stays free to
        # request the profile trace.
        p = multiprocessing.Process(target=target_fn, args=())
        p.start()
        training_started.wait()
        # Capture a trace into the bert_tensorboard directory. duration_ms is
        # not passed, so xp.trace's default capture length applies. ``p`` is
        # not joined here; multiprocessing waits for non-daemon children at
        # interpreter exit.
        xp.trace(f'localhost:{port_number}',
                 '/pytorch/xla/test/bert_tensorboard')
        # Commented-out alternative: spawn directly, without the profiling
        # parent process.
        # xmp.spawn(_mp_fn, nprocs=1)
    else:
        # Non-XLA path: use the CSV in the current working directory and
        # train in this process.
        dataset_path = os.path.join(os.getcwd(), "IMDB Dataset.csv")
        download_dataset()
        train_bert(dataset_path, xla_enabled, amp_enabled)