Code example #1
def run_profiler(tensor_creation_fn):
    # collecting allocs / deallocs
    with _profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof:
        x = None
        with record_function("test_user_scope_alloc"):
            x = tensor_creation_fn()
        with record_function("test_user_scope_dealloc"):
            del x
    return prof.key_averages(group_by_input_shape=True)
Code example #2
File: test_profiler.py  Project: tongxin/pytorch
def _record_function_with_param(self):
    u = torch.randn(3, 4, 5, requires_grad=True)
    with _profile(with_stack=True, use_kineto=kineto_available(), record_shapes=True) as prof:
        with record_function("## TEST 1 ##", "1, 2, 3"):
            rf_handle = _record_function_with_args_enter("## TEST 2 ##", 1, False, 2.5, [u, u], "hello", u)
            _record_function_with_args_exit(rf_handle)
        with record_function("## TEST 3 ##"):
            rf_handle = _record_function_with_args_enter("## TEST 4 ##")
            _record_function_with_args_exit(rf_handle)
    return prof
Code example #3
File: test_profiler.py  Project: yanboliang/pytorch
    def test_execution_graph_start_stop(self):
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        # Create a temp file to save execution graph data.
        fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
        fp.close()
        expected_loop_events = 0
        eg = ExecutionGraphObserver()
        eg.register_callback(fp.name)
        for idx in range(10):
            if idx == 3:
                eg.start()
            elif idx == 5:
                eg.stop()
            elif idx == 8:
                eg.start()
            elif idx == 9:
                eg.stop()
                eg.unregister_callback()
            if eg._execution_graph_running:
                expected_loop_events += 1
            with record_function(f"## LOOP {idx} ##"):
                self.payload(use_cuda=use_cuda)

        assert fp.name == eg.get_output_file_path()
        nodes = self.get_execution_graph_root(fp.name)
        loop_count = 0
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_graph|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
        assert found_root_node
        assert loop_count == expected_loop_events
Code example #4
def profile_cuda_kernels(fn, args, string_id="Model time"):
    print("################################################")
    print(f"#### Profiling for {string_id} starts #########")
    print("################################################")
    warmup = 50
    old_args = args[:]
    n_repeats = 1
    n_layers = 1
    ref = fn(*old_args)
    gO = torch.rand_like(ref)
    for _ in range(0, warmup // n_layers):
        args = list(old_args[:])
        ref = fn(*args)
        ref.backward(gO)

    torch.cuda.synchronize()

    # Forward profile
    def fwd_run():
        for _ in range(0, n_repeats // n_layers):
            args = list(old_args[:])
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    arg.grad = None
            ref = fn(*args)

    print(f"###### Forward profile for {string_id} starts #####")
    with profile(activities=[ProfilerActivity.CUDA],
                 record_shapes=True) as prof:
        with record_function("baseline"):
            fwd_run()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
    print(f"###### Forward profile for {string_id} ends #####")

    # Backward profile
    def bwd_run():
        for _ in range(0, n_repeats // n_layers):
            args = list(old_args[:])
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    arg.grad = None
            ref = fn(*args)

            print(f"###### Backward profile for {string_id} starts #####")
            torch.cuda.synchronize()
            with profile(activities=[ProfilerActivity.CUDA],
                         record_shapes=True) as prof:
                with record_function("baseline"):
                    ref.backward(gO)
            print(prof.key_averages().table(sort_by="cuda_time_total",
                                            row_limit=30))
            torch.cuda.synchronize()
            print(f"###### Backward profile for {string_id} ends #####")

    bwd_run()
    print("################################################")
    print(f"#### Profiling for {string_id} ends #########")
    print("################################################\n\n\n\n")
Code example #5
def train_func():
    twp = TorchWorkerProfiler()
    with profile(
            activities=[],
            schedule=schedule(wait=0, warmup=0, active=1),
            on_trace_ready=twp.trace_handler,
    ) as p:

        # Setup model.
        model = torch.nn.Linear(1, 1)
        model = train.torch.prepare_model(model)
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

        # Setup data.
        input = torch.randn(1000, 1)
        labels = input * 2
        dataset = torch.utils.data.TensorDataset(input, labels)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
        dataloader = train.torch.prepare_data_loader(dataloader)

        # Train.
        for epoch in range(5):
            with record_function("train_epoch"):
                for X, y in dataloader:
                    pred = model(X)
                    loss = loss_fn(pred, y)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            with record_function("train_checkpoint"):
                state_dict = model.state_dict()
                consume_prefix_in_state_dict_if_present(state_dict, "module.")
                train.save_checkpoint(epoch=epoch, model_weights=state_dict)

            p.step()

            with record_function("train_report"):
                profile_results = twp.get_and_clear_profile_traces()
                train.report(epoch=epoch, **profile_results)
Code example #6
File: profiler.py  Project: necla-ml/ML
def __init__(self,
             record_func_name='inference',
             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=False,
             profile_memory=True,
             scheduler=schedule(wait=1, warmup=1, active=2),
             trace_handler=tensorboard_trace_handler('./log')):
    self.activities = activities
    self.profile = profile(activities=activities,
                           record_shapes=record_shapes,
                           profile_memory=profile_memory,
                           with_flops=True,
                           schedule=scheduler,
                           on_trace_ready=trace_handler)
    self.record_function = record_function(record_func_name)
Code example #7
File: test_profiler.py  Project: tongxin/pytorch
def payload(self, use_cuda=False):
    u = torch.randn(3, 4, 5, requires_grad=True)
    with record_function("## TEST 1 ##", "1, 2, 3"):
        rf_handle = _record_function_with_args_enter("## TEST 2 ##", 1, False, 2.5, [u, u], (u, u), "hello", u)
        x = torch.randn(10, 10, requires_grad=True)
        if use_cuda:
            x = x.cuda()
        y = torch.randn(10, 10, requires_grad=True)
        if use_cuda:
            y = y.cuda()
        z = x + y + x * y + x * y
        z.backward(z)
        if use_cuda:
            z = z.cpu()
        _record_function_with_args_exit(rf_handle)
Code example #8
File: test_profiler.py  Project: yanboliang/pytorch
    def test_execution_graph_with_kineto(self):
        trace_called_num = 0

        def trace_handler(p):
            nonlocal trace_called_num
            trace_called_num += 1

        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        # Create a temp file to save execution graph data.
        fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
        fp.close()
        expected_loop_events = 0
        eg = ExecutionGraphObserver()
        eg.register_callback(fp.name)
        with profile(
                activities=supported_activities(),
                schedule=torch.profiler.schedule(skip_first=3,
                                                 wait=1,
                                                 warmup=1,
                                                 active=2),
                on_trace_ready=trace_handler,
        ) as p:
            eg.start()
            for idx in range(10):
                expected_loop_events += 1
                with record_function(f"## LOOP {idx} ##"):
                    self.payload(use_cuda=use_cuda)
                p.step()
            eg.stop()

        eg.unregister_callback()

        assert trace_called_num == 2
        assert fp.name == eg.get_output_file_path()
        nodes = self.get_execution_graph_root(fp.name)
        loop_count = 0
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_graph|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
        assert found_root_node
        assert loop_count == expected_loop_events
Code example #9
def profile_conv_runtimes(model, filename):
    model = model.cuda()
    inputs = torch.randn(32, 3, 224, 224).cuda()
    with profile(activities=[ProfilerActivity.CUDA],
                 profile_memory=True,
                 record_shapes=True) as prof:
        with record_function("model_inference"):
            model(inputs)
    print(
        prof.key_averages(group_by_input_shape=True).table(
            sort_by="cpu_time_total", row_limit=10))
    print(
        prof.key_averages(group_by_input_shape=True).table(
            sort_by="cuda_time_total", row_limit=10))
    print(
        prof.key_averages(group_by_input_shape=True).table(
            sort_by="cuda_memory_usage", row_limit=10))
    prof.export_chrome_trace(filename + '.json')
Code example #10
from contextlib import contextmanager


@contextmanager
def record_function(name: str, with_tag: str = "##"):
    """
    Context manager to annotate a scope with metadata used for
    profiling. The tag is used to surround the name.
    """
    import torch.autograd.profiler as profiler

    if with_tag:
        name = " ".join([with_tag, name, with_tag])

    if is_nvtx_available():
        import nvtx

        nvtx_context = nvtx.annotate(message=name)
    else:
        nvtx_context = null_context()
    with profiler.record_function(name), nvtx_context:
        yield
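
A minimal usage sketch for this wrapper (model and batch are hypothetical placeholders; is_nvtx_available and null_context are helpers from the surrounding project):

with record_function("inference"):  # profiled, and NVTX-annotated when available, as "## inference ##"
    output = model(batch)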
Code example #11
    def train_func():
        from ray.train.torch import TorchWorkerProfiler
        from torch.profiler import profile, record_function, schedule

        twp = TorchWorkerProfiler()
        with profile(
                activities=[],
                schedule=schedule(wait=0, warmup=0, active=1),
                on_trace_ready=twp.trace_handler,
        ) as p:

            for epoch in range(num_epochs):
                with record_function("test_function"):
                    pass

                p.step()

                profile_results = twp.get_and_clear_profile_traces()
                train.report(epoch=epoch, **profile_results)
Code example #12
    def test_memory_profiler(self):
        def run_profiler(tensor_creation_fn):
            # collecting allocs / deallocs
            with _profile(profile_memory=True,
                          record_shapes=True,
                          use_kineto=kineto_available()) as prof:
                x = None
                with record_function("test_user_scope_alloc"):
                    x = tensor_creation_fn()
                with record_function("test_user_scope_dealloc"):
                    del x
            return prof.key_averages(group_by_input_shape=True)

        def check_metrics(stats, metric, allocs=None, deallocs=None):
            stat_metrics = {}
            for stat in stats:
                stat_metrics[stat.key] = getattr(stat, metric)
            if allocs is not None:
                for alloc_fn in allocs:
                    self.assertTrue(alloc_fn in stat_metrics)
                    self.assertTrue(stat_metrics[alloc_fn] > 0)
            if deallocs is not None:
                for dealloc_fn in deallocs:
                    self.assertTrue(dealloc_fn in stat_metrics)
                    self.assertTrue(stat_metrics[dealloc_fn] < 0)

        def create_cpu_tensor():
            return torch.rand(10, 10)

        def create_cuda_tensor():
            return torch.rand(10, 10).cuda()

        def create_mkldnn_tensor():
            return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()

        stats = run_profiler(create_cpu_tensor)
        check_metrics(stats,
                      "cpu_memory_usage",
                      allocs=[
                          "aten::empty",
                          "aten::rand",
                          "test_user_scope_alloc",
                      ],
                      deallocs=[
                          "test_user_scope_dealloc",
                      ])

        if kineto_available():
            with TemporaryFileName(mode="w+") as fname:
                with profile(profile_memory=True) as prof:
                    x = None
                    with record_function("test_user_scope_alloc"):
                        x = create_cpu_tensor()
                    with record_function("test_user_scope_dealloc"):
                        del x
                prof.export_chrome_trace(fname)
                with io.open(fname, 'r') as f:
                    trace = json.load(f)
                    assert "traceEvents" in trace
                    events = trace["traceEvents"]
                    found_memory_events = False
                    for evt in events:
                        assert "name" in evt
                        if evt["name"] == "[memory]":
                            found_memory_events = True
                            assert "args" in evt
                            assert "Device Type" in evt["args"]
                            assert "Device Id" in evt["args"]
                            assert "Bytes" in evt["args"]
                    assert found_memory_events

        if torch.cuda.is_available():
            create_cuda_tensor()
            stats = run_profiler(create_cuda_tensor)
            check_metrics(stats,
                          "cuda_memory_usage",
                          allocs=[
                              "test_user_scope_alloc",
                              "aten::to",
                              "aten::empty_strided",
                          ],
                          deallocs=[
                              "test_user_scope_dealloc",
                          ])
            check_metrics(stats,
                          "cpu_memory_usage",
                          allocs=[
                              "aten::rand",
                              "aten::empty",
                          ])

        if torch._C.has_mkldnn:
            create_mkldnn_tensor()
            stats = run_profiler(create_mkldnn_tensor)
            check_metrics(stats,
                          "cpu_memory_usage",
                          allocs=[
                              "test_user_scope_alloc",
                              "aten::rand",
                              "aten::empty",
                              "aten::to_mkldnn",
                          ],
                          deallocs=[
                              "test_user_scope_dealloc",
                          ])

        # check top-level memory events
        with _profile(profile_memory=True,
                      use_kineto=kineto_available()) as prof:
            x = torch.rand(10, 10)
            del x
            if torch.cuda.is_available():
                y = torch.rand(10, 10).cuda()
                del y
            gc.collect()
        stats = prof.key_averages(group_by_input_shape=True)
        check_metrics(stats,
                      "cpu_memory_usage",
                      allocs=["aten::rand", "aten::empty"],
                      deallocs=["[memory]"])
        if torch.cuda.is_available():
            check_metrics(stats, "cuda_memory_usage", deallocs=["[memory]"])
Code example #13
def inference(model, dataloader, datatype, args):
    batch_time = AverageMeter('Time', ':6.3f')
    batch_size = args.batch_size
    warmup_iters = args.warmup_iterations
    max_iters = args.max_iterations if dataloader is None else len(dataloader)
    model.eval()
    coco = get_coco_api_from_dataset(dataloader.dataset)
    iou_types = ["bbox"]
    iou_types.append("segm")
    coco_evaluator = CocoEvaluator(coco, iou_types)
    if args.ipex:
        import intel_extension_for_pytorch as ipex
        model = model.to(memory_format=torch.channels_last)
        model = ipex.optimize(model,
                              dtype=datatype,
                              level="O1",
                              conv_bn_folding=False,
                              replace_dropout_with_identity=False)
        model.backbone = ipex.optimize(model.backbone,
                                       dtype=datatype,
                                       level="O1")
    else:
        if args.jit:
            model = model.to(memory_format=torch.channels_last)
        else:
            from torch.utils import mkldnn as mkldnn_utils
            model = mkldnn_utils.to_mkldnn(model, dtype=datatype)
    if args.jit:
        x = torch.randn(batch_size, 3, 1200,
                        1200).to(memory_format=torch.channels_last)
        if args.precision == "bf16":
            with torch.cpu.amp.autocast(), torch.no_grad():
                model.backbone = torch.jit.trace(model.backbone,
                                                 x,
                                                 strict=False)
            model.backbone = torch.jit.freeze(model.backbone)
        else:
            with torch.no_grad():
                model.backbone = torch.jit.trace(model.backbone,
                                                 x,
                                                 strict=False)
            model.backbone = torch.jit.freeze(model.backbone)
    with torch.no_grad():
        if dataloader is None:
            print(
                "Models for detection tasks need a real dataset. You need to specify the COCO dataset."
            )
            exit(1)
        else:
            for i, batch in enumerate(dataloader):
                images = batch[0]
                if not args.ipex and not args.jit:
                    images = list(img.to(datatype) for img in images)
                if args.ipex and args.precision == "bf16":
                    with torch.cpu.amp.autocast():
                        if i == warmup_iters:
                            with profile(
                                    activities=[ProfilerActivity.CPU],
                                    record_shapes=True
                            ) as prof, record_function("model_inference"):
                                output = model(images)
                        else:
                            output = model(images)
                else:
                    if i == warmup_iters:
                        with profile(
                                activities=[ProfilerActivity.CPU],
                                record_shapes=True) as prof, record_function(
                                    "model_inference"):
                            output = model(images)
                    else:
                        output = model(images)
                if i > warmup_iters:
                    break
            for i, batch in enumerate(dataloader):
                images = batch[0]
                end = time.time()
                if not args.ipex and not args.jit:
                    images = list(img.to(datatype) for img in images)
                if args.ipex and args.precision == "bf16":
                    with torch.cpu.amp.autocast():
                        output = model(images)
                else:
                    output = model(images)
                batch_time.update(time.time() - end)
                output = [{k: v.to(torch.float32)
                           for k, v in t.items()} for t in output]
                res = {
                    target["image_id"].item(): output
                    for target, output in zip(batch[1], output)
                }
                coco_evaluator.update(res)
                if max_iters != -1 and i >= max_iters:
                    break
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=-1))
    latency = batch_time.avg / batch_size * 1000
    perf = batch_size / batch_time.avg
    coco_evaluator.synchronize_between_processes()
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    print("Bbox AP: {:.5f} ".format(coco_evaluator.coco_eval['bbox'].stats[0]))
    print("Segm AP: {:.5f} ".format(coco_evaluator.coco_eval['segm'].stats[0]))
    print('Latency: %.3f ms' % latency)
    print("Throughput: {:.3f} fps".format(perf))
Code example #14
File: inference.py  Project: IntelAI/models
def inference(model, dataloader, datatype, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    batch_size = args.batch_size
    warmup_iters = args.warmup_iterations
    max_iters = args.max_iterations if dataloader is None else len(dataloader)
    progress = ProgressMeter(max_iters, [batch_time, losses, top1, top5],
                             prefix='Test: ')
    model.eval()
    if args.ipex:
        import intel_extension_for_pytorch as ipex
        model = model.to(memory_format=torch.channels_last)
        model = ipex.optimize(model, dtype=datatype, level="O1")
    else:
        if args.jit:
            model = model.to(memory_format=torch.channels_last)
        else:
            from torch.utils import mkldnn as mkldnn_utils
            model = mkldnn_utils.to_mkldnn(model, dtype=datatype)
    if args.jit:
        if dataloader is None:
            x = torch.randn(batch_size, 3, args.height, args.width)
        else:
            for i, batch in enumerate(dataloader):
                x = torch.randn(batch[0].shape)
                break
        x = x.to(memory_format=torch.channels_last)
        if args.precision == "bf16":
            with torch.cpu.amp.autocast(), torch.no_grad():
                model = torch.jit.trace(model, x, strict=False)
            model = torch.jit.freeze(model)
        else:
            with torch.no_grad():
                model = torch.jit.trace(model, x, strict=False)
            model = torch.jit.freeze(model)
    with torch.no_grad():
        if dataloader is None:
            for i in range(max_iters):
                images = torch.randn(batch_size, 3, args.height, args.width)
                if i > warmup_iters:
                    end = time.time()
                if not args.ipex and not args.jit:
                    images = images.to(datatype)
                else:
                    images = images.to(memory_format=torch.channels_last)
                if args.ipex and args.precision == "bf16" and not args.jit:
                    with torch.cpu.amp.autocast():
                        if i == warmup_iters:
                            with profile(
                                    activities=[ProfilerActivity.CPU],
                                    record_shapes=True
                            ) as prof, record_function("model_inference"):
                                output = model(images)
                        else:
                            output = model(images)
                else:
                    if i == warmup_iters:
                        with profile(
                                activities=[ProfilerActivity.CPU],
                                record_shapes=True) as prof, record_function(
                                    "model_inference"):
                            output = model(images)
                    else:
                        output = model(images)
                if i > warmup_iters:
                    batch_time.update(time.time() - end)
                if i % args.print_freq == 0:
                    progress.display(i)
        else:
            # warm up
            for i, (images, target) in enumerate(dataloader):
                if i > warmup_iters:
                    break
                if not args.ipex and not args.jit:
                    images = images.to(datatype).to(
                        memory_format=torch.channels_last)
                if args.ipex and args.precision == "bf16" and not args.jit:
                    with torch.cpu.amp.autocast():
                        if i == warmup_iters:
                            with profile(
                                    activities=[ProfilerActivity.CPU],
                                    record_shapes=True
                            ) as prof, record_function("model_inference"):
                                output = model(images)
                        else:
                            output = model(images)
                else:
                    if i == warmup_iters:
                        with profile(
                                activities=[ProfilerActivity.CPU],
                                record_shapes=True) as prof, record_function(
                                    "model_inference"):
                            output = model(images)
                    else:
                        output = model(images)

            criterion = nn.CrossEntropyLoss()
            for i, (images, target) in enumerate(dataloader):
                end = time.time()
                if not args.ipex and not args.jit:
                    images = images.to(datatype).to(
                        memory_format=torch.channels_last)
                if args.ipex and args.precision == "bf16" and not args.jit:
                    output = model(images)
                else:
                    output = model(images)
                batch_time.update(time.time() - end)
                if args.precision == "bf16":
                    output = output.to(torch.float32)
                loss = criterion(output, target)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                if max_iters != -1 and i >= max_iters:
                    break
                if i % args.print_freq == 0:
                    progress.display(i)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=-1))
    latency = batch_time.avg / batch_size * 1000
    perf = batch_size / batch_time.avg
    print('Latency: %.3f ms' % latency)
    print("Throughput: {:.3f} fps".format(perf))
    print("Accuracy: {top1.avg:.3f} ".format(top1=top1))
Code example #15
# - ``activities`` - a list of activities to profile:
#    - ``ProfilerActivity.CPU`` - PyTorch operators, TorchScript functions and
#      user-defined code labels (see ``record_function`` below);
#    - ``ProfilerActivity.CUDA`` - on-device CUDA kernels;
# - ``record_shapes`` - whether to record shapes of the operator inputs;
# - ``profile_memory`` - whether to report amount of memory consumed by
#   model's Tensors;
# - ``use_cuda`` - whether to measure execution time of CUDA kernels.
#
# Note: when using CUDA, profiler also shows the runtime CUDA events
# occurring on the host.

######################################################################
# Let's see how we can use profiler to analyze the execution time:

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        model(inputs)
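
######################################################################
# The remaining options compose in the same way; for example, a variant
# of the call above that also tracks tensor memory (a sketch, reusing
# the same ``model`` and ``inputs``):

with profile(activities=[ProfilerActivity.CPU],
             profile_memory=True, record_shapes=True) as prof:
    with record_function("model_inference"):
        model(inputs)

print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))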

######################################################################
# Note that we can use ``record_function`` context manager to label
# arbitrary code ranges with user provided names
# (``model_inference`` is used as a label in the example above).
#
# Profiler allows one to check which operators were called during the
# execution of a code range wrapped with a profiler context manager.
# If multiple profiler ranges are active at the same time (e.g. in
# parallel PyTorch threads), each profiling context manager tracks only
# the operators of its corresponding range.
# Profiler also automatically profiles async tasks launched
# with ``torch.jit._fork`` and, in the case of a backward pass,
# the operators launched by the ``backward()`` call.
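
######################################################################
# As a quick illustration of the last note, a labeled range that
# launches an async task via ``torch.jit._fork`` (a sketch, reusing the
# same ``model`` and ``inputs``; the forked work is profiled alongside
# the parent range):

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("forked_inference"):
        fut = torch.jit._fork(model, inputs)  # async task, profiled automatically
        out = torch.jit._wait(fut)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))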
Code example #16
for epoch in range(num_epochs):  # num_epochs: assumed epoch count (snippet begins mid-loop)
    with torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=2,
                                             warmup=2,
                                             active=6,
                                             repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                'profiler'),
            with_stack=True) as profiler:
        overall_loss = 0
        for batch_idx, (x, _) in enumerate(train_loader):
            x = x.view(batch_size, x_dim)
            x = x.to(DEVICE)

            optimizer.zero_grad()

            x_hat, mean, log_var = model(x)
            with record_function("model_loss"):
                loss = loss_function(x, x_hat, mean, log_var)

            overall_loss += loss.item()

            with record_function("backward"):
                loss.backward()
                optimizer.step()
            profiler.step()

        print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ",
              overall_loss / (batch_idx * batch_size))
print("Finish!!")

# Generate reconstructions
model.eval()