示例#1
0
    def test_conv_bn_fusion(self):
        """Conv/BN fusion of a traced ResNet-18 must remove every BatchNorm2d
        layer while leaving the network's output unchanged."""
        reference = resnet18().eval()
        fused_module = optimization.fuse(symbolic_trace(reference))

        # No BatchNorm2d submodule may survive fusion anywhere in the module tree.
        no_batchnorm_left = not any(
            isinstance(submodule, torch.nn.BatchNorm2d)
            for submodule in fused_module.modules()
        )
        self.assertTrue(no_batchnorm_left)

        # Compare fused vs. reference outputs on one random batch of twenty
        # 3-channel 224x224 inputs.
        batch_shape = (20, 3, 224, 224)
        sample = torch.randn(*batch_shape)
        fused_out = fused_module(sample)
        self.assertEqual(fused_out, reference(sample))
示例#2
0
文件: main.py 项目: IntelAI/models
def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` (or run INT8 calibration) and report timing/accuracy.

    The mode is chosen from ``args``:

    * INT8 calibration (``args.ipex and args.int8 and args.calibration``):
      the model is Conv/BN-fused with ``optimization.fuse``, batches from
      ``val_loader`` are run under ``ipex.quantization.calibrate`` (stopping
      after batch index 120, i.e. at most 121 batches) and the collected
      quantization configuration is saved to ``args.configure_dir``.
    * Dummy data (``args.dummy``): one random batch of shape
      ``(args.batch_size, 3, 224, 224)`` is fed ``number_iter`` times; the
      first ``args.warmup_iterations`` iterations are not timed. With
      ``args.weight_sharing`` the work is handed to
      ``run_weights_sharing_model`` threads (one per instance) and the
      process exits once they all finish.
    * Real data: every ``(images, target)`` batch of ``val_loader`` is
      evaluated and timed.

    Args:
        val_loader: iterable of ``(images, target)`` batches; not used (may be
            ``None``) in dummy mode.
        model: network to evaluate; it is switched to eval mode here.
        criterion: loss callable, invoked as ``criterion(output, target)``.
        args: parsed command-line namespace.

    Returns:
        ``top1.avg`` of the top-1 accuracy meter. In calibration mode the
        meter is never updated.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    # Iteration count reported by the progress meter (and, for dummy data,
    # the number of iterations actually run). int8 shares the same default.
    if args.dummy:
        number_iter = args.steps if args.steps > 0 else 200
    else:
        number_iter = args.steps if args.steps > 0 else len(val_loader)
    if args.calibration:
        # NOTE(review): the calibration loop below runs up to 121 batches;
        # this value only feeds the progress meter and the message below.
        number_iter = 100

    progress = ProgressMeter(number_iter, [batch_time, losses, top1, top5],
                             prefix='Test: ')
    print('Evaluating RESNET: total Steps: {}'.format(number_iter))

    # switch to evaluate mode
    model.eval()

    if args.ipex and args.int8 and args.calibration:
        model = optimization.fuse(model)
        print("runing int8 calibration step\n")
        # Same QuantConf entry point as the ipex.quantization.QuantConf
        # calls elsewhere in this file.
        conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_symmetric)
        with torch.no_grad():
            for i, (images, target) in enumerate(val_loader):
                with ipex.quantization.calibrate(conf):
                    images = images.contiguous(
                        memory_format=torch.channels_last)
                    model(images)
                    if i == 120:
                        break

            conf.save(args.configure_dir)
            print(".........calibration step done..........")
        return top1.avg

    if args.dummy and args.weight_sharing:
        import sys

        # torch.no_grad() is thread-local, so it does not carry over into the
        # worker threads; presumably run_weights_sharing_model manages its own
        # grad mode — confirm.
        with torch.no_grad():
            threads = []
            for instance_id in range(1, args.number_instance + 1):
                thread = threading.Thread(
                    target=run_weights_sharing_model,
                    args=(model, instance_id, args))
                threads.append(thread)
                thread.start()
            for thread in threads:
                thread.join()
            sys.exit()

    with torch.no_grad():
        if args.dummy:
            # With IPEX the input is channels-last for fp32, bf16 and int8 alike.
            images = torch.randn(args.batch_size, 3, 224, 224)
            if args.ipex:
                images = images.contiguous(memory_format=torch.channels_last)
            target = torch.arange(1, args.batch_size + 1).long()
            if args.bf16:
                images = images.to(torch.bfloat16)

            for i in range(number_iter):
                # Warm-up iterations are excluded from the latency statistics.
                timed = i >= args.warmup_iterations
                if timed:
                    end = time.time()
                if not args.jit and args.bf16:
                    with torch.cpu.amp.autocast():
                        output = model(images)
                else:
                    output = model(images)
                if timed:
                    batch_time.update(time.time() - end)

                _record_eval_batch(output, target, images.size(0), criterion,
                                   args.bf16, losses, top1, top5)
                if i % args.print_freq == 0:
                    progress.display(i)
        else:
            for i, (images, target) in enumerate(val_loader):
                # Timing includes the layout/dtype conversion of the batch.
                end = time.time()
                if args.ipex:
                    images = images.contiguous(
                        memory_format=torch.channels_last)
                if args.bf16:
                    images = images.to(torch.bfloat16)
                if not args.jit and args.bf16:
                    with torch.cpu.amp.autocast():
                        output = model(images)
                else:
                    output = model(images)
                batch_time.update(time.time() - end)

                _record_eval_batch(output, target, images.size(0), criterion,
                                   args.bf16, losses, top1, top5)
                if i % args.print_freq == 0:
                    progress.display(i)

    if args.weight_sharing:
        # NOTE(review): `stats` is not defined in this function. The dummy
        # weight-sharing path exits above, so this is only reached for weight
        # sharing with real data — confirm `stats` exists at module scope.
        latency = stats.latency_avg_ms
        perf = stats.iters_per_second
    else:
        batch_size = args.batch_size
        # Assumes AverageMeter.avg stays 0 until the first update — if no
        # iteration was timed (e.g. steps <= warmup_iterations) the
        # throughput division below would raise ZeroDivisionError.
        if batch_time.avg > 0:
            latency = batch_time.avg / batch_size * 1000
            perf = batch_size / batch_time.avg
        else:
            print('no timed iterations; latency and throughput not measured')
            latency = 0.0
            perf = 0.0

    print('inference latency %.3f ms' % latency)
    print("Throughput: {:.3f} fps".format(perf))
    print("Accuracy: {top1.avg:.3f} ".format(top1=top1))

    # TODO: this should also be done with the ProgressMeter
    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1,
                                                                top5=top5))

    return top1.avg


def _record_eval_batch(output, target, batch_size, criterion, bf16,
                       losses, top1, top5):
    """Compute loss and top-1/top-5 accuracy for one batch and update meters.

    When ``bf16`` is set, ``output`` is cast to float32 first so that loss and
    accuracy are computed in fp32.
    """
    if bf16:
        output = output.to(torch.float32)
    loss = criterion(output, target)
    # measure accuracy and record loss
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    losses.update(loss.item(), batch_size)
    top1.update(acc1[0], batch_size)
    top5.update(acc5[0], batch_size)
示例#3
0
文件: main.py 项目: IntelAI/models
def main_worker(gpu, ngpus_per_node, args):
    """Set up and run training or evaluation for one worker process.

    Steps, in order:
    1. Optional distributed initialisation.
    2. Model creation (``torch.hub`` WSL model, or ``models.__dict__[args.arch]``).
    3. Device placement / (Distributed)DataParallel wrapping.
    4. Loss and SGD optimizer.
    5. Optional checkpoint resume.
    6. Argument sanity checks.
    7. Data loaders.
    8. Either a single evaluation pass (``args.evaluate``, with optional IPEX
       int8 / bf16 / fp32 / JIT preparation) or the training loop with
       per-epoch validation and checkpointing.

    Args:
        gpu: GPU index for this worker, or ``None``.
        ngpus_per_node: GPUs per node; used to derive the global rank and the
            per-process batch size / worker count.
        args: parsed command-line namespace. ``gpu``, ``rank``, ``batch_size``,
            ``workers`` and ``start_epoch`` may be updated in place.

    Side effects:
        Updates the module-level ``best_acc1``, sets distributed-related
        environment variables, and may write checkpoints via
        ``save_checkpoint``.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        # PMI_RANK / PMI_SIZE (presumably exported by an MPI launcher) take
        # precedence over the command-line values.
        os.environ['RANK'] = str(os.environ.get('PMI_RANK', args.rank))
        os.environ['WORLD_SIZE'] = str(
            os.environ.get('PMI_SIZE', args.world_size))
        # os.environ only accepts str values; str() guards against e.g. an
        # int port and is a no-op for strings.
        os.environ['MASTER_ADDR'] = str(args.master_addr)
        os.environ['MASTER_PORT'] = str(args.port)
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu

        # Initialize the process group; torch_ccl is imported first,
        # presumably so the 'ccl' backend is available.
        if args.dist_backend == 'ccl':
            import torch_ccl
        dist.init_process_group(backend=args.dist_backend)

    if args.hub:
        torch.set_flush_denormal(True)
        model = torch.hub.load('facebookresearch/WSL-Images', args.arch)
    elif args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.ipex:
        import intel_extension_for_pytorch as ipex
    # for ipex path, always convert model to channels_last for bf16, fp32.
    # TODO: int8 path: https://jira.devtools.intel.com/browse/MFDNN-6103
    if args.ipex and not args.int8:
        model = model.to(memory_format=torch.channels_last)

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None and args.cuda:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        elif args.cuda:
            model.cuda()
            print("create DistributedDataParallel in GPU")
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
        else:
            # The CPU wrapper itself is created further below (training path
            # only, when args.gpu is None), after ipex.optimize.
            print("create DistributedDataParallel in CPU")
    elif args.gpu is not None and args.cuda:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith(('alexnet', 'vgg')):
            model.features = torch.nn.DataParallel(model.features)
            if args.cuda:
                model.cuda()
        else:
            model = torch.nn.DataParallel(model)
            # args.cuda is a boolean flag; the previous `args.cuda()` call
            # raised TypeError.
            if args.cuda:
                model.cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if args.cuda:
        criterion = criterion.cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if not args.cuda:
                # CPU run: keep tensors on CPU instead of mapping to a CUDA
                # device (args.gpu may be None, giving 'cuda:None').
                checkpoint = torch.load(args.resume, map_location='cpu')
            elif args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None and args.cuda:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.cuda:
        cudnn.benchmark = True

    if args.weight_sharing:
        # NOTE(review): the message asks for batch_size 1, but only a truthy
        # batch_size is enforced here.
        assert args.dummy and args.batch_size, \
                "please using dummy data and set batch_size to 1 if you want run weight sharing case for latency case"
    assert not (args.jit and args.int8), \
        "jit path is not available for int8 path using ipex"
    if args.calibration:
        assert args.int8, "please enable int8 path if you want to do int8 calibration path"
    if args.dummy:
        assert args.evaluate, "please using real dataset if you want run training path"
    if not args.ipex:
        # for official pytorch, int8 and jit path is not enabled.
        assert not args.int8, "int8 path is not enabled for offical pytorch"
        assert not args.jit, "jit path is not enabled for offical pytorch"

    if not args.dummy:
        # Data loading code
        assert args.data is not None, "please set dataset path if you want to using real data"
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        if not args.evaluate:
            traindir = os.path.join(args.data, 'train')
            train_dataset = datasets.ImageFolder(
                traindir,
                transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]))

            if args.distributed:
                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset)
            else:
                train_sampler = None

            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                num_workers=args.workers,
                pin_memory=True,
                sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)
    else:
        train_loader = None
        val_loader = None

    if args.evaluate:
        if args.ipex:
            print("using ipex model to do inference\n")
            model.eval()
            if args.int8:
                # int8 calibration itself runs inside validate(); here a
                # previously saved configuration is applied.
                if not args.calibration:
                    model = optimization.fuse(model, inplace=True)
                    conf = ipex.quantization.QuantConf(args.configure_dir)
                    x = torch.randn(
                        args.batch_size, 3, 224,
                        224).contiguous(memory_format=torch.channels_last)
                    model = ipex.quantization.convert(model, conf, x)
                    with torch.no_grad():
                        # Run once before printing the graph (presumably so
                        # the executor has specialized it for x — confirm).
                        model(x)
                        print(model.graph_for(x))
                    print("running int8 evalation step\n")
            else:
                if args.bf16:
                    model = ipex.optimize(model,
                                          dtype=torch.bfloat16,
                                          inplace=True)
                    print("running bfloat16 evalation step\n")
                else:
                    model = ipex.optimize(model,
                                          dtype=torch.float32,
                                          inplace=True)
                    print("running fp32 evalation step\n")
                if args.jit:
                    x = torch.randn(
                        args.batch_size, 3, 224,
                        224).contiguous(memory_format=torch.channels_last)
                    if args.bf16:
                        x = x.to(torch.bfloat16)
                        with torch.cpu.amp.autocast(), torch.no_grad():
                            model = torch.jit.trace(model, x).eval()
                    else:
                        with torch.no_grad():
                            model = torch.jit.trace(model, x).eval()
                    model = torch.jit.freeze(model)
        else:
            print("using offical pytorch model to do inference\n")
        validate(val_loader, model, criterion, args)
        return

    if args.ipex:
        train_dtype = torch.bfloat16 if args.bf16 else torch.float32
        model, optimizer = ipex.optimize(model,
                                         dtype=train_dtype,
                                         optimizer=optimizer)

    # parallelize (CPU-only distributed training)
    if args.distributed and not args.cuda and args.gpu is None:
        print("create DistributedDataParallel in CPU")
        model = torch.nn.parallel.DistributedDataParallel(model)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # Only one process per node writes checkpoints when multiprocessing.
        if (not args.multiprocessing_distributed
                or args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
示例#4
0
def coco_eval(model, val_dataloader, cocoGt, encoder, inv_map, args):
    from pycocotools.cocoeval import COCOeval
    device = args.device
    threshold = args.threshold
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    model.eval()

    ret = []

    inference_time = AverageMeter('InferenceTime', ':6.3f')
    decoding_time = AverageMeter('DecodingTime', ':6.3f')

    if (args.calibration or args.accuracy_mode or args.throughput_mode
            or args.latency_mode) is False:
        print(
            "one of --calibration, --accuracy-mode, --throughput-mode, --latency-mode must be input."
        )
        exit(-1)

    if args.accuracy_mode:
        if args.iteration is not None:
            print("accuracy mode should not input iteration")
        progress_meter_iteration = len(val_dataloader)
        epoch_number = 1
    else:
        if args.iteration is None:
            print("None accuracy mode must input --iteration")
            exit(-1)
        progress_meter_iteration = args.iteration
        epoch_number = (args.iteration // len(val_dataloader) +
                        1) if (args.iteration > len(val_dataloader)) else 1

    progress = ProgressMeter(progress_meter_iteration,
                             [inference_time, decoding_time],
                             prefix='Test: ')

    Profilling_iterator = 99
    start = time.time()
    if args.int8:
        model = model.eval()
        print('int8 conv_bn_fusion enabled')
        with torch.no_grad():
            model.model = optimization.fuse(model.model, inplace=False)

            if args.calibration:
                print("runing int8 LLGA calibration step\n")
                conf = ipex.quantization.QuantConf(
                    qscheme=torch.per_tensor_affine
                )  # qscheme can is torch.per_tensor_affine, torch.per_tensor_symmetric
                with torch.no_grad():
                    for nbatch, (img, img_id, img_size, bbox,
                                 label) in enumerate(val_dataloader):
                        print("nbatch:{}".format(nbatch))
                        with ipex.quantization.calibrate(conf):
                            ploc, plabel = model(img)
                        if nbatch == args.iteration:
                            break
                conf.save(args.configure)
                return

            else:
                print("INT8 LLGA start trace")
                # insert quant/dequant based on configure.json
                conf = ipex.quantization.QuantConf(
                    configure_file=args.configure)
                model = ipex.quantization.convert(
                    model, conf,
                    torch.randn(args.batch_size, 3, 1200,
                                1200).to(memory_format=torch.channels_last))
                print("done ipex default recipe.......................")
                # freeze the module
                # model = torch.jit._recursive.wrap_cpp_module(torch._C._freeze_module(model._c, preserveParameters=True))

                # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
                # At the 2nd run, the llga pass will be triggered and the model is turned into an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
                with torch.no_grad():
                    for i in range(2):
                        #_, _ = model(torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last))
                        _, _ = model(
                            torch.randn(
                                args.batch_size, 3, 1200,
                                1200).to(memory_format=torch.channels_last))

                print('runing int8 real inputs inference path')
                with torch.no_grad():
                    total_iteration = 0
                    for epoch in range(epoch_number):
                        for nbatch, (img, img_id, img_size, bbox,
                                     label) in enumerate(val_dataloader):
                            img = img.to(memory_format=torch.channels_last)
                            if total_iteration >= args.warmup_iterations:
                                start_time = time.time()

                            if args.profile and total_iteration == Profilling_iterator:
                                print("Profilling")
                                with torch.profiler.profile(
                                        on_trace_ready=torch.profiler.
                                        tensorboard_trace_handler(
                                            "./int8_log")) as prof:
                                    # ploc, plabel = model(img.to(memory_format=torch.channels_last))
                                    ploc, plabel = model(img)
                                print(prof.key_averages().table(
                                    sort_by="self_cpu_time_total"))
                            else:
                                #ploc, plabel = model(img.to(memory_format=torch.channels_last))
                                ploc, plabel = model(img)

                            if total_iteration >= args.warmup_iterations:
                                inference_time.update(time.time() - start_time)
                                end_time = time.time()
                            try:
                                results_raw = encoder.decode_batch(
                                    ploc, plabel, 0.50, 200, device=device)
                            except:
                                print(
                                    "No object detected in nbatch: {}".format(
                                        total_iteration))
                                continue
                            if total_iteration >= args.warmup_iterations:
                                decoding_time.update(time.time() - end_time)

                            # Re-assembly the result
                            results = []
                            idx = 0
                            for i in range(results_raw[3].size(0)):
                                results.append(
                                    (results_raw[0][idx:idx +
                                                    results_raw[3][i]],
                                     results_raw[1][idx:idx +
                                                    results_raw[3][i]],
                                     results_raw[2][idx:idx +
                                                    results_raw[3][i]]))
                                idx += results_raw[3][i]

                            (htot, wtot) = [d.cpu().numpy() for d in img_size]
                            img_id = img_id.cpu().numpy()
                            # Iterate over batch elements
                            for img_id_, wtot_, htot_, result in zip(
                                    img_id, wtot, htot, results):
                                loc, label, prob = [
                                    r.cpu().numpy() for r in result
                                ]
                                # Iterate over image detections
                                for loc_, label_, prob_ in zip(
                                        loc, label, prob):
                                    ret.append([img_id_, loc_[0]*wtot_, \
                                                loc_[1]*htot_,
                                                (loc_[2] - loc_[0])*wtot_,
                                                (loc_[3] - loc_[1])*htot_,
                                                prob_,
                                                inv_map[label_]])

                            if total_iteration % args.print_freq == 0:
                                progress.display(total_iteration)
                            if total_iteration == args.iteration:
                                break
                            total_iteration += 1
    else:
        if args.dummy:
            print('dummy inputs inference path is not supported')
        else:
            print('runing real inputs path')
            if args.autocast:
                print('bf16 autocast enabled')
                print('enable nhwc')
                model = model.to(memory_format=torch.channels_last)
                if use_ipex:
                    print('bf16 block format weights cache enabled')
                    model.model = ipex.optimize(model.model,
                                                dtype=torch.bfloat16,
                                                inplace=False)
                    # model = ipex.utils._convert_module_data_type(model, torch.bfloat16)
                else:
                    from oob_utils import conv_bn_fuse
                    print('OOB bf16 conv_bn_fusion enabled')
                    model.model = conv_bn_fuse(model.model)

                if args.jit:
                    if use_ipex:
                        print('enable IPEX jit path')
                        with torch.cpu.amp.autocast(), torch.no_grad():
                            model = torch.jit.trace(
                                model,
                                torch.randn(args.batch_size, 3, 1200, 1200).to(
                                    memory_format=torch.channels_last)).eval()
                    else:
                        print('enable OOB jit path')
                        with torch.cpu.amp.autocast(
                                cache_enabled=False), torch.no_grad():
                            model = torch.jit.trace(
                                model,
                                torch.randn(args.batch_size, 3, 1200, 1200).to(
                                    memory_format=torch.channels_last)).eval()

                    model = torch.jit.freeze(model)
                    with torch.no_grad():
                        total_iteration = 0
                        for epoch in range(epoch_number):
                            for nbatch, (img, img_id, img_size, bbox,
                                         label) in enumerate(val_dataloader):
                                with torch.no_grad():
                                    if total_iteration >= args.warmup_iterations:
                                        start_time = time.time()

                                    img = img.to(
                                        memory_format=torch.channels_last)
                                    if args.profile and total_iteration == Profilling_iterator:
                                        print("Profilling")
                                        with torch.profiler.profile(
                                                on_trace_ready=torch.profiler.
                                                tensorboard_trace_handler(
                                                    "./log")) as prof:
                                            ploc, plabel = model(img)
                                        print(prof.key_averages().table(
                                            sort_by="self_cpu_time_total"))
                                    else:
                                        ploc, plabel = model(img)
                                    if total_iteration >= args.warmup_iterations:
                                        inference_time.update(time.time() -
                                                              start_time)
                                        end_time = time.time()

                                    try:
                                        if args.profile and total_iteration == Profilling_iterator:
                                            with torch.profiler.profile(
                                                    on_trace_ready=torch.
                                                    profiler.
                                                    tensorboard_trace_handler(
                                                        "./decode_log"
                                                    )) as prof:
                                                results_raw = encoder.decode_batch(
                                                    ploc,
                                                    plabel,
                                                    0.50,
                                                    200,
                                                    device=device)
                                            print(prof.key_averages().table(
                                                sort_by="self_cpu_time_total"))
                                        else:
                                            results_raw = encoder.decode_batch(
                                                ploc,
                                                plabel,
                                                0.50,
                                                200,
                                                device=device)
                                    except:
                                        print(
                                            "No object detected in nbatch: {}".
                                            format(total_iteration))
                                        continue
                                    if total_iteration >= args.warmup_iterations:
                                        decoding_time.update(time.time() -
                                                             end_time)

                                    # Re-assembly the result
                                    results = []
                                    idx = 0
                                    for i in range(results_raw[3].size(0)):
                                        results.append((
                                            results_raw[0][idx:idx +
                                                           results_raw[3][i]],
                                            results_raw[1][idx:idx +
                                                           results_raw[3][i]],
                                            results_raw[2][idx:idx +
                                                           results_raw[3][i]]))
                                        idx += results_raw[3][i]

                                    (htot, wtot) = [
                                        d.cpu().numpy() for d in img_size
                                    ]
                                    img_id = img_id.cpu().numpy()

                                    for img_id_, wtot_, htot_, result in zip(
                                            img_id, wtot, htot, results):
                                        loc, label, prob = [
                                            r.cpu().numpy() for r in result
                                        ]
                                        # Iterate over image detections
                                        for loc_, label_, prob_ in zip(
                                                loc, label, prob):
                                            ret.append([img_id_, loc_[0]*wtot_, \
                                                        loc_[1]*htot_,
                                                        (loc_[2] - loc_[0])*wtot_,
                                                        (loc_[3] - loc_[1])*htot_,
                                                        prob_,
                                                        inv_map[label_]])

                                    if total_iteration % args.print_freq == 0:
                                        progress.display(total_iteration)
                                    if total_iteration == args.iteration:
                                        break
                                    total_iteration += 1
                else:
                    if use_ipex:
                        print('Ipex Autocast imperative path')
                        with torch.cpu.amp.autocast(), torch.no_grad():
                            total_iteration = 0
                            for epoch in range(epoch_number):
                                for nbatch, (
                                        img, img_id, img_size, bbox,
                                        label) in enumerate(val_dataloader):
                                    with torch.no_grad():
                                        if total_iteration >= args.warmup_iterations:
                                            start_time = time.time()
                                        img = img.contiguous(
                                            memory_format=torch.channels_last)
                                        if args.profile and total_iteration == Profilling_iterator:
                                            print("Profilling")
                                            with torch.profiler.profile(
                                                    on_trace_ready=torch.
                                                    profiler.
                                                    tensorboard_trace_handler(
                                                        "./bf16_imperative_log"
                                                    )) as prof:
                                                ploc, plabel = model(img)
                                            print(prof.key_averages().table(
                                                sort_by="self_cpu_time_total"))
                                        else:
                                            ploc, plabel = model(img)

                                        if total_iteration >= args.warmup_iterations:
                                            inference_time.update(time.time() -
                                                                  start_time)
                                            end_time = time.time()

                                        try:
                                            results_raw = encoder.decode_batch(
                                                ploc,
                                                plabel,
                                                0.50,
                                                200,
                                                device=device)
                                        except:
                                            print(
                                                "No object detected in total_iteration: {}"
                                                .format(total_iteration))
                                            continue
                                        if total_iteration >= args.warmup_iterations:
                                            decoding_time.update(time.time() -
                                                                 end_time)

                                        # Re-assembly the result
                                        results = []
                                        idx = 0
                                        for i in range(results_raw[3].size(0)):
                                            results.append(
                                                (results_raw[0]
                                                 [idx:idx + results_raw[3][i]],
                                                 results_raw[1]
                                                 [idx:idx + results_raw[3][i]],
                                                 results_raw[2]
                                                 [idx:idx +
                                                  results_raw[3][i]]))
                                            idx += results_raw[3][i]

                                        (htot, wtot) = [
                                            d.cpu().numpy() for d in img_size
                                        ]
                                        img_id = img_id.cpu().numpy()

                                        for img_id_, wtot_, htot_, result in zip(
                                                img_id, wtot, htot, results):
                                            loc, label, prob = [
                                                r.cpu().numpy() for r in result
                                            ]
                                            # Iterate over image detections
                                            for loc_, label_, prob_ in zip(
                                                    loc, label, prob):
                                                ret.append([img_id_, loc_[0]*wtot_, \
                                                            loc_[1]*htot_,
                                                            (loc_[2] - loc_[0])*wtot_,
                                                            (loc_[3] - loc_[1])*htot_,
                                                            prob_,
                                                            inv_map[label_]])

                                        if total_iteration % args.print_freq == 0:
                                            progress.display(total_iteration)
                                        if total_iteration == args.iteration:
                                            break
                                        total_iteration += 1
                    else:
                        print("OOB Autocast imperative path")
                        with torch.cpu.amp.autocast(), torch.no_grad():
                            total_iteration = 0
                            for epoch in range(epoch_number):
                                for nbatch, (
                                        img, img_id, img_size, bbox,
                                        label) in enumerate(val_dataloader):
                                    if total_iteration >= args.warmup_iterations:
                                        start_time = time.time()
                                    img = img.contiguous(
                                        memory_format=torch.channels_last)
                                    if args.profile and total_iteration == Profilling_iterator:
                                        print("Profilling")
                                        with torch.profiler.profile(
                                                on_trace_ready=torch.profiler.
                                                tensorboard_trace_handler(
                                                    "./bf16_oob_log")) as prof:
                                            ploc, plabel = model(img)
                                        print(prof.key_averages().table(
                                            sort_by="self_cpu_time_total"))
                                    else:
                                        ploc, plabel = model(img)

                                    if total_iteration >= args.warmup_iterations:
                                        inference_time.update(time.time() -
                                                              start_time)
                                        end_time = time.time()

                                    with torch.cpu.amp.autocast(enabled=False):
                                        try:
                                            results_raw = encoder.decode_batch(
                                                ploc,
                                                plabel,
                                                0.50,
                                                200,
                                                device=device)
                                        except:
                                            print(
                                                "No object detected in total_iteration: {}"
                                                .format(total_iteration))
                                            continue
                                        if total_iteration >= args.warmup_iterations:
                                            decoding_time.update(time.time() -
                                                                 end_time)

                                        # Re-assembly the result
                                        results = []
                                        idx = 0
                                        for i in range(results_raw[3].size(0)):
                                            results.append(
                                                (results_raw[0]
                                                 [idx:idx + results_raw[3][i]],
                                                 results_raw[1]
                                                 [idx:idx + results_raw[3][i]],
                                                 results_raw[2]
                                                 [idx:idx +
                                                  results_raw[3][i]]))
                                            idx += results_raw[3][i]

                                        (htot, wtot) = [
                                            d.cpu().numpy() for d in img_size
                                        ]
                                        img_id = img_id.cpu().numpy()

                                        for img_id_, wtot_, htot_, result in zip(
                                                img_id, wtot, htot, results):
                                            loc, label, prob = [
                                                r.cpu().numpy() for r in result
                                            ]
                                            # Iterate over image detections
                                            for loc_, label_, prob_ in zip(
                                                    loc, label, prob):
                                                ret.append([img_id_, loc_[0]*wtot_, \
                                                            loc_[1]*htot_,
                                                            (loc_[2] - loc_[0])*wtot_,
                                                            (loc_[3] - loc_[1])*htot_,
                                                            prob_,
                                                            inv_map[label_]])

                                        if total_iteration % args.print_freq == 0:
                                            progress.display(total_iteration)
                                        if total_iteration == args.iteration:
                                            break
                                        total_iteration += 1
            else:
                print('autocast disabled, fp32 is used')
                print('enable nhwc')
                model = model.to(memory_format=torch.channels_last)
                if use_ipex:
                    print('fp32 block format weights cache enabled')
                    model.model = ipex.optimize(model.model,
                                                dtype=torch.float32,
                                                inplace=False)
                if args.jit:
                    print("enable jit")
                    with torch.no_grad():
                        model = torch.jit.trace(
                            model,
                            torch.randn(
                                args.batch_size, 3, 1200,
                                1200).to(memory_format=torch.channels_last))
                    model = torch.jit.freeze(model)
                with torch.no_grad():
                    total_iteration = 0
                    for epoch in range(epoch_number):
                        for nbatch, (img, img_id, img_size, bbox,
                                     label) in enumerate(val_dataloader):
                            if total_iteration >= args.warmup_iterations:
                                start_time = time.time()

                            img = img.contiguous(
                                memory_format=torch.channels_last)
                            if args.profile and total_iteration == Profilling_iterator:
                                print("Profilling")
                                with torch.profiler.profile(
                                        on_trace_ready=torch.profiler.
                                        tensorboard_trace_handler(
                                            "./fp32_log")) as prof:
                                    ploc, plabel = model(img)
                                print(prof.key_averages().table(
                                    sort_by="self_cpu_time_total"))
                            else:
                                ploc, plabel = model(img)
                            if total_iteration >= args.warmup_iterations:
                                inference_time.update(time.time() - start_time)
                                end_time = time.time()
                            try:
                                if args.profile and total_iteration == Profilling_iterator:
                                    with torch.profiler.profile(
                                            on_trace_ready=torch.profiler.
                                            tensorboard_trace_handler(
                                                "./fp32_decode_log")) as prof:
                                        results_raw = encoder.decode_batch(
                                            ploc,
                                            plabel,
                                            0.50,
                                            200,
                                            device=device)
                                    print(prof.key_averages().table(
                                        sort_by="self_cpu_time_total"))
                                else:
                                    results_raw = encoder.decode_batch(
                                        ploc, plabel, 0.50, 200, device=device)
                            except:
                                print(
                                    "No object detected in total_iteration: {}"
                                    .format(total_iteration))
                                continue
                            if total_iteration >= args.warmup_iterations:
                                decoding_time.update(time.time() - end_time)

                            # Re-assembly the result
                            results = []
                            idx = 0
                            for i in range(results_raw[3].size(0)):
                                results.append(
                                    (results_raw[0][idx:idx +
                                                    results_raw[3][i]],
                                     results_raw[1][idx:idx +
                                                    results_raw[3][i]],
                                     results_raw[2][idx:idx +
                                                    results_raw[3][i]]))
                                idx += results_raw[3][i]

                            (htot, wtot) = [d.cpu().numpy() for d in img_size]
                            img_id = img_id.cpu().numpy()
                            # Iterate over batch elements
                            for img_id_, wtot_, htot_, result in zip(
                                    img_id, wtot, htot, results):
                                loc, label, prob = [
                                    r.cpu().numpy() for r in result
                                ]
                                # Iterate over image detections
                                for loc_, label_, prob_ in zip(
                                        loc, label, prob):
                                    ret.append([img_id_, loc_[0]*wtot_, \
                                                loc_[1]*htot_,
                                                (loc_[2] - loc_[0])*wtot_,
                                                (loc_[3] - loc_[1])*htot_,
                                                prob_,
                                                inv_map[label_]])

                            if total_iteration % args.print_freq == 0:
                                progress.display(total_iteration)
                            if total_iteration == args.iteration:
                                break
                            total_iteration += 1
    print("Predicting Ended, total time: {:.2f} s".format(time.time() - start))

    batch_size = args.batch_size
    latency = inference_time.avg / batch_size * 1000
    perf = batch_size / inference_time.avg
    print('inference latency %.2f ms' % latency)
    print('inference performance %.2f fps' % perf)

    if not args.dummy:
        latency = decoding_time.avg / batch_size * 1000
        perf = batch_size / decoding_time.avg
        print('decoding latency %.2f ms' % latency)
        print('decodingperformance %.2f fps' % perf)

        total_time_avg = inference_time.avg + decoding_time.avg
        throughput = batch_size / total_time_avg
        print("Throughput: {:.3f} fps".format(throughput))

        if not args.accuracy_mode:
            return True

        cocoDt = cocoGt.loadRes(np.array(ret))

        E = COCOeval(cocoGt, cocoDt, iouType='bbox')
        E.evaluate()
        E.accumulate()
        E.summarize()
        print("Current AP: {:.5f} AP goal: {:.5f}".format(
            E.stats[0], threshold))
        print("Accuracy: {:.5f} ".format(E.stats[0]))

        return (
            E.stats[0] >= threshold
        )  #Average Precision  (AP) @[ IoU=050:0.95 | area=   all | maxDets=100 ]
    else:
        total_time_avg = inference_time.avg
        throughput = batch_size / total_time_avg
        print("Throughput: {:.3f} fps".format(throughput))
        return False