示例#1
0
def trace_model(args, dlrm, test_ld):
    """Optimize and freeze ``dlrm`` using the first batch drawn from ``test_ld``.

    The model is switched to eval mode, then prepared according to ``args``:

    * ``args.bf16``: the embedding tables are cast to bfloat16 when
      ``args.inference_only`` is set, and the model goes through
      ``ipex.optimize`` with ``dtype=torch.bfloat16``.
    * ``args.int8`` (only consulted when ``args.bf16`` is falsy): the model is
      converted with a ``QuantConf`` built from ``args.int8_configure``.
    * otherwise: ``ipex.optimize`` with ``dtype=torch.float``.

    Afterwards an ``args.int8`` model is passed to ``freeze``; any other model
    is traced with ``torch.jit.trace`` under CPU autocast and frozen with
    ``torch.jit.freeze``. The result is invoked twice on the sample batch
    before being returned.

    Returns:
        The frozen model, or ``None`` when ``test_ld`` yields no batch (the
        loop body never runs in that case).
    """
    dlrm.eval()
    for first_batch in test_ld:
        X, lS_o, lS_i, _, _, _ = unpack_batch(first_batch)
        sample = (X, lS_o, lS_i)

        # Stage 1: precision-specific preparation.
        if args.bf16:
            # at::GradMode::is_enabled() queries a thread-local flag, but the
            # threads spawned by the throughput benchmark initialise that flag
            # to true, so the embedding weights are temporarily cast to
            # bfloat16 here.
            if args.inference_only:
                dlrm.emb_l.bfloat16()
            dlrm = ipex.optimize(dlrm, dtype=torch.bfloat16, inplace=True)
        elif args.int8:
            quant_conf = ipex.quantization.QuantConf(args.int8_configure)
            dlrm = ipex.quantization.convert(dlrm, quant_conf, sample)
        else:
            dlrm = ipex.optimize(dlrm, dtype=torch.float, inplace=True)

        # Stage 2: freezing (int8) or tracing + freezing (everything else).
        if not args.int8:
            with torch.cpu.amp.autocast(enabled=args.bf16):
                dlrm = torch.jit.trace(dlrm, sample, check_trace=True)
                dlrm = torch.jit.freeze(dlrm)
        else:
            dlrm = freeze(dlrm)

        # Two calls on the sample batch before handing the model back.
        for _ in range(2):
            dlrm(X, lS_o, lS_i)
        return dlrm
示例#2
0
def compute_on_dataset(model, data_loader, device, bbox_aug, timer=None, bf16=False, jit=False, iterations=-1, iter_warmup=-1, enable_profiling=False):
    """Run inference over ``data_loader`` and collect per-image outputs.

    The model is put in eval mode, converted to channels-last, and its
    ``backbone``, ``rpn`` and ``roi_heads`` are passed through
    ``ipex.optimize``. With ``jit=True`` the backbone is additionally traced
    and frozen using the first batch. The loader is then iterated (possibly
    several times) until ``iterations + iter_warmup`` steps have run; when
    ``iterations`` is not positive one full pass is used instead.

    Args:
        model: detection model exposing ``backbone``, ``rpn`` and ``roi_heads``.
        data_loader: yields ``(images, targets, image_ids)`` batches.
        device: only forwarded to ``im_detect_bbox_aug``.
        bbox_aug: use ``im_detect_bbox_aug`` instead of calling the model.
        timer: optional object with ``tic()``/``toc()``; used only for steps
            at or past ``iter_warmup``.
        bf16: run under CPU autocast and cast inputs to bfloat16.
        jit: trace and freeze the backbone before inference.
        iterations: number of measured steps (``<= 0`` means one pass).
        iter_warmup: number of untimed leading steps (negative means 0).
        enable_profiling: enable the autograd profiler and print its table.

    Returns:
        dict mapping each image id to its output moved to the CPU device.
    """
    model.eval()
    predictions = {}
    host = torch.device("cpu")
    batches_per_pass = len(data_loader)
    iter_warmup = max(0, iter_warmup)
    measured = iterations if iterations > 0 else batches_per_pass
    total_steps = measured + iter_warmup
    full_passes = int(total_steps / batches_per_pass)
    print('Evaluating MaskRCNN: Steps per Epoch {} total Steps {}'.format(batches_per_pass, total_steps))

    compute_dtype = torch.bfloat16 if bf16 else torch.float
    model = model.to(memory_format=torch.channels_last)
    model.backbone = ipex.optimize(model.backbone, dtype=compute_dtype, inplace=True)
    model.rpn = ipex.optimize(model.rpn, dtype=compute_dtype, inplace=True)
    model.roi_heads = ipex.optimize(model.roi_heads, dtype=compute_dtype, inplace=True)

    with torch.cpu.amp.autocast(enabled=bf16), torch.no_grad():
        if jit:
            # Trace the backbone on the first batch only.
            print("generate trace model")
            for batch in tqdm(data_loader):
                images, _, _ = batch
                example = images.tensors.to(memory_format=torch.channels_last)
                model.backbone = torch.jit.trace(model.backbone, example)
                model.backbone = torch.jit.freeze(model.backbone)
                print(model.backbone.graph_for(example))
                break

        print("runing inference step")
        with torch.autograd.profiler.profile(enable_profiling) as prof, \
                tqdm(total=total_steps, desc="Evaluating") as pbar:
            # One extra pass covers the remainder when total_steps is not a
            # multiple of the loader length.
            for epoch in range(full_passes + 1):
                for i, batch in enumerate(data_loader):
                    step = epoch * batches_per_pass + i
                    if step >= total_steps:
                        break
                    images, _, image_ids = batch
                    images = images.to(memory_format=torch.channels_last)
                    if bf16:
                        images = images.to(torch.bfloat16)

                    # Warm-up steps are executed but not timed.
                    timed = bool(timer) and step >= iter_warmup
                    if timed:
                        timer.tic()
                    if bbox_aug:
                        output = im_detect_bbox_aug(model, images, device)
                    else:
                        output = model(images)
                    if timed:
                        timer.toc()

                    output = [o.to(host) for o in output]
                    predictions.update(zip(image_ids, output))
                    pbar.update(1)
        if enable_profiling:
            print(prof.key_averages().table(sort_by="self_cpu_time_total"))
    return predictions
示例#3
0
    def initialize(self, context):
        """Load the model described by the manifest and prepare it for inference.

        When the manifest names a ``modelFile`` the model is loaded through
        ``_load_pickled_model`` (eager mode) and moved to the selected device;
        otherwise the serialized file is loaded through
        ``_load_torchscript_model``. The model is put in eval mode, optionally
        optimized with IPEX, and the ``index_to_name.json`` label mapping is
        loaded from the model directory.

        Args:
            context (context): Object exposing ``system_properties`` (read for
                ``gpu_id`` and ``model_dir``) and ``manifest`` (read for
                ``model.serializedFile`` and ``model.modelFile``).

        Raises:
            RuntimeError: When no ``modelFile`` is given and the serialized
                model file is not declared in the manifest or does not exist
                on disk.

        """
        properties = context.system_properties
        gpu_id = properties.get("gpu_id")
        # A GPU is used only when CUDA is available AND a gpu_id was assigned.
        use_gpu = torch.cuda.is_available() and gpu_id is not None
        self.map_location = "cuda" if use_gpu else "cpu"
        self.device = torch.device(
            self.map_location + ":" + str(gpu_id) if use_gpu
            else self.map_location
        )
        self.manifest = context.manifest

        model_dir = properties.get("model_dir")
        model_pt_path = None
        if "serializedFile" in self.manifest["model"]:
            serialized_file = self.manifest["model"]["serializedFile"]
            model_pt_path = os.path.join(model_dir, serialized_file)

        # model def file
        model_file = self.manifest["model"].get("modelFile", "")

        if model_file:
            logger.debug("Loading eager model")
            self.model = self._load_pickled_model(
                model_dir, model_file, model_pt_path)
            self.model.to(self.device)
        else:
            logger.debug("Loading torchscript model")
            # model_pt_path stays None when the manifest has no
            # "serializedFile"; os.path.isfile(None) would raise TypeError, so
            # check for None first and report the intended error instead.
            if model_pt_path is None or not os.path.isfile(model_pt_path):
                raise RuntimeError("Missing the model.pt file")

            self.model = self._load_torchscript_model(model_pt_path)

        self.model.eval()
        if ipex_enabled:
            self.model = self.model.to(memory_format=torch.channels_last)
            self.model = ipex.optimize(self.model)

        logger.debug('Model file %s loaded successfully', model_pt_path)

        # Load class mapping for classifiers
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
        self.mapping = load_label_mapping(mapping_file_path)

        self.initialized = True
def main():
    """Train and validate ``TestModel`` with Intel Extension for PyTorch.

    Builds the model (channels-last), an SGD optimizer and a summed MSE loss,
    wraps the train/test ``TestDataset`` instances in ``DataLoader`` objects,
    applies ``ipex.optimize`` to the model/optimizer pair, and then alternates
    one training pass and one validation pass per epoch.
    """
    # Model, optimizer (updates the parameters) and criterion (the loss).
    net = TestModel().to(memory_format=torch.channels_last)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    loss_fn = nn.MSELoss(reduction='sum')

    # Datasets and their loaders.
    train_set = TestDataset()
    train_batches = DataLoader(train_set, batch_size=BS_TRAIN)
    test_set = TestDataset(train=False)
    test_batches = DataLoader(test_set, batch_size=BS_TEST)

    # Apply the IPEX optimizations to both the model and the optimizer.
    net, optimizer = ipex.optimize(net, optimizer=optimizer)

    # NOTE(review): this runs EPOCHNUM - 1 epochs, not EPOCHNUM - confirm
    # that the shorter count is intended.
    for _ in range(EPOCHNUM - 1):
        # Training pass over the whole training set.
        net.train()
        for data, y_ans in train_batches:
            data = data.to(memory_format=torch.channels_last)
            optimizer.zero_grad()
            prediction = net(data)
            loss = loss_fn(prediction, y_ans)
            loss.backward()
            optimizer.step()

        # Validation pass; the outputs are computed but not used further.
        net.eval()
        for data in test_batches:
            data = data.to(memory_format=torch.channels_last)
            net(data)
示例#5
0
def inference(model, dataloader, datatype, args):
    """Benchmark a detection model on a COCO dataloader and report accuracy.

    The model is prepared according to ``args`` (IPEX optimization, plain
    channels-last for the JIT path, or MKLDNN conversion), its backbone is
    optionally traced and frozen with TorchScript, and two passes over
    ``dataloader`` follow: a warm-up pass in which the iteration numbered
    ``args.warmup_iterations`` is profiled, and a timed pass whose outputs are
    fed to a ``CocoEvaluator`` for the "bbox" and "segm" IoU types. The
    profiler table (when one was captured), both AP values, the latency and
    the throughput are printed.

    Args:
        model: detection model; ``model.backbone`` is accessed when
            ``args.ipex`` or ``args.jit`` is set.
        dataloader: yields batches where ``batch[0]`` holds the images and
            ``batch[1]`` the targets (each with an ``"image_id"`` entry); its
            ``dataset`` attribute is used to build the COCO API object.
        datatype: dtype given to ``ipex.optimize`` / ``to_mkldnn`` and used to
            cast the images on the non-IPEX, non-JIT path.
        args: namespace providing ``batch_size``, ``warmup_iterations``,
            ``ipex``, ``jit`` and ``precision``.

    Raises:
        SystemExit: with code 1 when ``dataloader`` is ``None``.
    """
    from contextlib import nullcontext

    # Reject a missing dataset before anything dereferences ``dataloader``;
    # previously ``dataloader.dataset`` failed with AttributeError before the
    # explanatory message could ever be printed.
    if dataloader is None:
        print(
            "Models for detection tasks need to use real dataset. You need to specify coco dataset. "
        )
        raise SystemExit(1)

    batch_time = AverageMeter('Time', ':6.3f')
    batch_size = args.batch_size
    warmup_iters = args.warmup_iterations
    max_iters = len(dataloader)
    model.eval()
    coco = get_coco_api_from_dataset(dataloader.dataset)
    coco_evaluator = CocoEvaluator(coco, ["bbox", "segm"])

    if args.ipex:
        import intel_extension_for_pytorch as ipex
        model = model.to(memory_format=torch.channels_last)
        model = ipex.optimize(model,
                              dtype=datatype,
                              level="O1",
                              conv_bn_folding=False,
                              replace_dropout_with_identity=False)
        model.backbone = ipex.optimize(model.backbone,
                                       dtype=datatype,
                                       level="O1")
    elif args.jit:
        model = model.to(memory_format=torch.channels_last)
    else:
        from torch.utils import mkldnn as mkldnn_utils
        model = mkldnn_utils.to_mkldnn(model, dtype=datatype)

    if args.jit:
        x = torch.randn(batch_size, 3, 1200,
                        1200).to(memory_format=torch.channels_last)
        # Tracing runs under autocast only for bf16; freezing happens outside
        # of both the autocast and the no_grad context.
        trace_ctx = (torch.cpu.amp.autocast()
                     if args.precision == "bf16" else nullcontext())
        with trace_ctx, torch.no_grad():
            model.backbone = torch.jit.trace(model.backbone, x, strict=False)
        model.backbone = torch.jit.freeze(model.backbone)

    # Inputs are cast to ``datatype`` only on the MKLDNN path.
    cast_inputs = not args.ipex and not args.jit
    use_autocast = args.ipex and args.precision == "bf16"

    def forward_ctx():
        """Return a fresh autocast context for IPEX bf16, else a no-op one."""
        return torch.cpu.amp.autocast() if use_autocast else nullcontext()

    # Stays None when the warm-up pass never reaches iteration
    # ``warmup_iters`` (loader too short); previously that left ``prof``
    # unbound and the final print raised NameError.
    prof = None
    with torch.no_grad():
        # Warm-up pass; one iteration is captured by the profiler.
        for i, batch in enumerate(dataloader):
            images = batch[0]
            if cast_inputs:
                images = [img.to(datatype) for img in images]
            with forward_ctx():
                if i == warmup_iters:
                    with profile(
                            activities=[ProfilerActivity.CPU],
                            record_shapes=True
                    ) as prof, record_function("model_inference"):
                        model(images)
                else:
                    model(images)
            if i > warmup_iters:
                break

        # Timed pass; every output is handed to the COCO evaluator.
        for i, batch in enumerate(dataloader):
            images = batch[0]
            start = time.time()
            if cast_inputs:
                images = [img.to(datatype) for img in images]
            with forward_ctx():
                output = model(images)
            batch_time.update(time.time() - start)
            output = [{k: v.to(torch.float32)
                       for k, v in t.items()} for t in output]
            res = {
                target["image_id"].item(): out
                for target, out in zip(batch[1], output)
            }
            coco_evaluator.update(res)
            if max_iters != -1 and i >= max_iters:
                break

    if prof is not None:
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=-1))
    latency = batch_time.avg / batch_size * 1000
    perf = batch_size / batch_time.avg
    coco_evaluator.synchronize_between_processes()
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    print("Bbox AP: {:.5f} ".format(coco_evaluator.coco_eval['bbox'].stats[0]))
    print("Segm AP: {:.5f} ".format(coco_evaluator.coco_eval['segm'].stats[0]))
    print('Latency: %.3f ms' % latency)
    print("Throughput: {:.3f} fps".format(perf))
示例#6
0
def train300_mlperf_coco(args):
    """Train SSD300 on COCO with MLPerf logging, optionally distributed.

    The function sets up (optional) distributed training, seeds the RNGs,
    builds the COCO train/val datasets and loaders, creates the SSD300 model,
    loss and SGD optimizer, optionally restores a checkpoint and applies
    ``ipex.optimize``, and then runs the epoch loop. On the epochs selected by
    ``args.val_epochs`` / ``args.val_interval`` it saves a checkpoint (unless
    ``args.no_save``) and runs ``coco_eval``.

    With ``args.performance_only`` the loop stops once ``args.train_iteration``
    iterations have run and the measured latency/throughput are printed.

    Args:
        args: parsed command-line namespace. ``args.distributed``,
            ``args.seed`` (CUDA distributed path only) and ``args.rank`` are
            overwritten in place.

    Returns:
        True as soon as a ``coco_eval`` call returns a truthy value, otherwise
        False after the epoch loop ends.

    Raises:
        ImportError: when CUDA is in use and the ``try`` block importing APEX
            fails for any reason (see the note on the bare ``except``).
    """
    global torch
    from coco import COCO
    # Check that GPUs are actually available
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    args.distributed = False
    if use_cuda:
        try:
            from apex.parallel import DistributedDataParallel as DDP
            if 'WORLD_SIZE' in os.environ:
                args.distributed = int(os.environ['WORLD_SIZE']) > 1
        # NOTE(review): bare except - any failure inside the try block (for
        # example a non-integer WORLD_SIZE) is also reported as a missing
        # APEX installation.
        except:
            raise ImportError(
                "Please install APEX from https://github.com/nvidia/apex")

    local_seed = args.seed
    os.environ['USE_CUDA'] = str(use_cuda)
    if args.world_size > 1:
        args.distributed = True

    if args.distributed:
        # necessary pytorch imports
        import torch.utils.data.distributed
        import torch.distributed as dist
        print('Distributed training with DDP')
        if args.no_cuda:
            device = torch.device('cpu')
            # PMI_RANK / PMI_SIZE take precedence over the command-line
            # values when they are present in the environment.
            os.environ['RANK'] = str(os.environ.get('PMI_RANK', args.rank))
            os.environ['WORLD_SIZE'] = str(
                os.environ.get('PMI_SIZE', args.world_size))
            os.environ['MASTER_ADDR'] = args.master_addr
            os.environ['MASTER_PORT'] = args.port

            # Initialize the process group with ccl backend
            if args.backend == 'ccl':
                import torch_ccl
            dist.init_process_group(backend=args.backend)
        else:
            torch.cuda.set_device(args.local_rank)
            device = torch.device('cuda')
            dist.init_process_group(backend='nccl', init_method='env://')
            # set seeds properly
            args.seed = broadcast_seeds(args.seed, device)
            local_seed = (args.seed + dist.get_rank()) % 2**32
    mllogger.event(key=mllog_const.SEED, value=local_seed)
    # Refer to https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    torch.manual_seed(local_seed)  # Set PyTorch seed
    np.random.seed(seed=local_seed)  # Set Numpy seed
    random.seed(local_seed)  # Set the Python seed

    # ``dist`` is only bound when args.distributed is true (imported above).
    args.rank = dist.get_rank() if args.distributed else args.local_rank
    print("args.rank = {}".format(args.rank))
    print("local rank = {}".format(args.local_rank))
    print("distributed={}".format(args.distributed))

    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)

    input_size = 300
    train_trans = SSDTransformer(
        dboxes, (input_size, input_size),
        val=False,
        num_cropping_iterations=args.num_cropping_iterations)
    val_trans = SSDTransformer(dboxes, (input_size, input_size), val=True)

    val_annotate = os.path.join(args.data,
                                "annotations/instances_val2017.json")
    val_coco_root = os.path.join(args.data, "val2017")
    train_annotate = os.path.join(args.data,
                                  "annotations/instances_train2017.json")
    train_coco_root = os.path.join(args.data, "train2017")

    cocoGt = COCO(annotation_file=val_annotate)
    train_coco = COCODetection(train_coco_root, train_annotate, train_trans)
    val_coco = COCODetection(val_coco_root, val_annotate, val_trans)
    mllogger.event(key=mllog_const.TRAIN_SAMPLES, value=len(train_coco))
    mllogger.event(key=mllog_const.EVAL_SAMPLES, value=len(val_coco))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_coco)
    else:
        train_sampler = None
    # Shuffling is enabled only when no sampler is given.
    train_dataloader = DataLoader(train_coco,
                                  batch_size=args.batch_size,
                                  shuffle=(train_sampler is None),
                                  sampler=train_sampler,
                                  num_workers=0)
    # Leslie: here is the workaround: dist.broadcast will fail on other ranks,
    # so evaluation runs on all the ranks (hence no sampler for validation).
    val_dataloader = DataLoader(val_coco,
                                batch_size=args.val_batch_size
                                or args.batch_size,
                                shuffle=False,
                                sampler=None,
                                num_workers=0)

    ssd300 = SSD300(train_coco.labelnum, model_path=args.pretrained_backbone)

    ssd300.train()
    if use_cuda:
        ssd300.cuda()
    loss_func = Loss(dboxes)
    if use_cuda:
        loss_func.cuda()
    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    global_batch_size = N_gpu * args.batch_size
    mllogger.event(key=mllog_const.GLOBAL_BATCH_SIZE, value=global_batch_size)
    # Reference doesn't support group batch norm, so bn_span==local_batch_size
    mllogger.event(key=mllog_const.MODEL_BN_SPAN, value=args.batch_size)
    # The base learning rate is scaled by global_batch_size / 32.
    current_lr = args.lr * (global_batch_size / 32)

    assert args.batch_size % args.batch_splits == 0, "--batch-size must be divisible by --batch-splits"
    fragment_size = args.batch_size // args.batch_splits
    if args.batch_splits != 1:
        print("using gradient accumulation with fragments of size {}".format(
            fragment_size))

    # Model to NHWC
    ssd300 = ssd300.to(memory_format=torch.channels_last)

    current_momentum = 0.9
    optim = torch.optim.SGD(ssd300.parameters(),
                            lr=current_lr,
                            momentum=current_momentum,
                            weight_decay=args.weight_decay)
    ssd_print(key=mllog_const.OPT_BASE_LR, value=current_lr)
    ssd_print(key=mllog_const.OPT_WEIGHT_DECAY, value=args.weight_decay)

    iter_num = args.iteration
    avg_loss = 0.0
    inv_map = {v: k for k, v in val_coco.label_map.items()}
    # One-element tensor used as the "evaluation succeeded" flag.
    success = torch.zeros(1)
    if use_cuda:
        success = success.cuda()

    if args.warmup:
        nonempty_imgs = len(train_coco)
        wb = int(args.warmup * nonempty_imgs / (N_gpu * args.batch_size))
        ssd_print(key=mllog_const.OPT_LR_WARMUP_STEPS, value=wb)
        warmup_step = lambda iter_num, current_lr: lr_warmup(
            optim, wb, iter_num, current_lr, args)
    else:
        # No warm-up requested: the per-iteration hook does nothing.
        warmup_step = lambda iter_num, current_lr: None

    ssd_print(key=mllog_const.OPT_LR_WARMUP_FACTOR, value=args.warmup_factor)
    ssd_print(key=mllog_const.OPT_LR_DECAY_BOUNDARY_EPOCHS,
              value=args.lr_decay_schedule)
    mllogger.start(key=mllog_const.BLOCK_START,
                   metadata={
                       mllog_const.FIRST_EPOCH_NUM: 1,
                       mllog_const.EPOCH_COUNT: args.epochs
                   })

    if args.performance_only:
        train_time = AverageMeter('TrainTime', ':6.3f')
        progress = ProgressMeter(args.train_iteration, [train_time],
                                 prefix='Train: ')

    # Restore the model and optim from checkpoint
    if args.checkpoint is not None:
        print("loading model checkpoint", args.checkpoint)
        od = torch.load(args.checkpoint)
        ssd300.load_state_dict(od["model"])
        optim.load_state_dict(od['optim'])

    # Model Prepack
    # ``use_ipex`` and ``ipex`` are not defined in this function; they are
    # expected to come from the enclosing module.
    if use_ipex:
        if args.autocast:
            ssd300, optim = ipex.optimize(ssd300,
                                          dtype=torch.bfloat16,
                                          optimizer=optim)
        else:
            ssd300, optim = ipex.optimize(ssd300,
                                          dtype=torch.float32,
                                          optimizer=optim)

    # parallelize
    if args.distributed:
        device_ids = None
        ssd300 = torch.nn.parallel.DistributedDataParallel(
            ssd300, device_ids=device_ids)

    optim.zero_grad(set_to_none=True)
    for epoch in range(args.epochs):
        mllogger.start(key=mllog_const.EPOCH_START,
                       metadata={mllog_const.EPOCH_NUM: epoch})
        # set the epoch for the sampler
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # Learning-rate decay: multiply by 0.1 at each scheduled epoch.
        if epoch in args.lr_decay_schedule:
            current_lr *= 0.1
            print("")
            print("lr decay step #{num}".format(
                num=args.lr_decay_schedule.index(epoch) + 1))
            for param_group in optim.param_groups:
                param_group['lr'] = current_lr
        for nbatch, (img, img_id, img_size, bbox,
                     label) in enumerate(train_dataloader):
            # Hard-wired to True, so the fragment (gradient accumulation)
            # branch below is never taken.
            naive_train_case = True  # img.shape[0] == fragment_size
            if naive_train_case:
                # Naive train case
                fimg, gloc, glabel, mask, pos_num, neg_num, num_mask = data_preprocess(
                    img, bbox, label, loss_func, args.autocast)

                # Timing starts only once the warm-up iterations are over.
                if args.performance_only and iter_num >= args.warmup_iterations:
                    start_time = time.time()
                # Exactly one iteration (iter_num == 30) is profiled.
                if args.profile and args.performance_only and iter_num == 30:
                    # Profile Mode
                    with torch.profiler.profile(
                            on_trace_ready=trace_handler) as prof:
                        with torch.cpu.amp.autocast(enabled=args.autocast):
                            ploc, plabel = ssd300(fimg)
                            loss = loss_func(ploc, plabel, gloc, glabel, mask,
                                             pos_num, neg_num, num_mask,
                                             args.autocast)
                        loss.backward()

                        warmup_step(iter_num, current_lr)
                        optim.step()
                        optim.zero_grad(set_to_none=True)
                else:
                    # Non Profile Mode
                    with torch.cpu.amp.autocast(enabled=args.autocast):
                        ploc, plabel = ssd300(fimg)
                        loss = loss_func(ploc, plabel, gloc, glabel, mask,
                                         pos_num, neg_num, num_mask,
                                         args.autocast)
                    loss.backward()

                    warmup_step(iter_num, current_lr)
                    optim.step()
                    optim.zero_grad(set_to_none=True)
            else:
                # Train case: when split input to several fragment size
                print("Not support input with several fragment size yet.")
                exit(-1)
                # current_batch_size = img.shape[0]
                # # Split batch for gradient accumulation
                # img = torch.split(img, fragment_size)
                # bbox = torch.split(bbox, fragment_size)
                # label = torch.split(label, fragment_size)

                # if args.performance_only and iter_num >= args.warmup_iterations:
                #     start_time=time.time()
                # for (fimg, fbbox, flabel) in zip(img, bbox, label):
                #     current_fragment_size = fimg.shape[0]
                #     trans_bbox = fbbox.transpose(1,2).contiguous()
                #     if use_cuda:
                #         fimg = fimg.cuda()
                #         trans_bbox = trans_bbox.cuda()
                #         flabel = flabel.cuda()
                #     fimg = Variable(fimg, requires_grad=True)
                #     gloc, glabel = Variable(trans_bbox, requires_grad=False), \
                #                 Variable(flabel, requires_grad=False)
                #     gloc = loss_func._loc_vec(gloc)
                #     mask = glabel > 0
                #     pos_num = mask.sum(dim=1)
                #     neg_num = torch.clamp(3*pos_num, max=mask.size(1)).unsqueeze(-1)
                #     num_mask = (pos_num > 0).float()
                #     # image to NHWC
                #     fimg = fimg.contiguous(memory_format=torch.channels_last)
                #     if use_ipex:
                #         with ipex.amp.autocast(enabled=args.autocast, configure=ipex.conf.AmpConf(torch.bfloat16)):
                #             ploc, plabel = ssd300(fimg)
                #             loss = loss_func(ploc, plabel, gloc, glabel, mask, pos_num, neg_num, num_mask)
                #     else:
                #         ploc, plabel = ssd300(fimg)
                #         loss = loss_func(ploc, plabel, gloc, glabel, mask, pos_num, neg_num, num_mask)
                #     loss = loss * (current_fragment_size / current_batch_size) # weighted mean
                #     loss.backward()

                # warmup_step(iter_num, current_lr)
                # optim.step()
                # optim.zero_grad(set_to_none=True)
            # Same condition as the one that set ``start_time`` above.
            if args.performance_only and iter_num >= args.warmup_iterations:
                train_time.update(time.time() - start_time)
            if args.performance_only and iter_num % args.print_freq == 0:
                progress.display(iter_num)
            # Infinite losses are left out of the running average.
            if not np.isinf(loss.item()):
                avg_loss = 0.999 * avg_loss + 0.001 * loss.item()
            if args.log_interval and not iter_num % args.log_interval:
                print("Iteration: {:6d}, Loss function: {:5.8f}, Average Loss: {:.8f}"\
                    .format(iter_num, loss.item(), avg_loss))
            iter_num += 1
            if args.performance_only and iter_num >= args.train_iteration:
                break
        # Leave the epoch loop as well once the benchmark budget is used up.
        if args.performance_only and iter_num >= args.train_iteration:
            break

        if (args.val_epochs and (epoch+1) in args.val_epochs) or \
           (args.val_interval and not (epoch+1) % args.val_interval):
            if args.distributed:
                # Average the running_mean / running_var buffers across ranks
                # before evaluating.
                world_size = float(dist.get_world_size())
                for bn_name, bn_buf in ssd300.module.named_buffers(
                        recurse=True):
                    if ('running_mean' in bn_name) or ('running_var'
                                                       in bn_name):
                        dist.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                        bn_buf /= world_size
                        ssd_print(key=mllog_const.MODEL_BN_SPAN,
                                  value=bn_buf.cpu().detach().numpy())
            # ``or True`` makes this condition always true, so every rank
            # saves and evaluates.
            if args.rank == 0 or True:  # Leslie: here is the workaround: dist.broadcast will fail on other ranks, so evaluation runs on all the ranks
                if not args.no_save:
                    print("")
                    print("saving model...")
                    torch.save(
                        {
                            "model": ssd300.state_dict(),
                            "label_map": train_coco.label_info,
                            "optim": optim.state_dict()
                        }, "./models/iter_{}.pt".format(iter_num))

                if coco_eval(ssd300,
                             val_dataloader,
                             cocoGt,
                             encoder,
                             inv_map,
                             args.threshold,
                             epoch + 1,
                             iter_num,
                             log_interval=args.log_interval,
                             nms_valid_thresh=args.nms_valid_thresh,
                             use_autocast=args.autocast):
                    success = torch.ones(1)
                    if use_cuda:
                        success = success.cuda()
            # Leslie: same workaround: since evaluation runs on all ranks, the evaluation result does not need to be broadcast
            # if args.distributed:
            #     dist.broadcast(success, 0)
            if success[0]:
                return True
            # NOTE(review): this EPOCH_STOP event sits inside the validation
            # branch, so it is only logged for epochs that run evaluation
            # while EPOCH_START is logged for every epoch - confirm intended.
            mllogger.end(key=mllog_const.EPOCH_STOP,
                         metadata={mllog_const.EPOCH_NUM: epoch})
    mllogger.end(key=mllog_const.BLOCK_STOP,
                 metadata={
                     mllog_const.FIRST_EPOCH_NUM: 1,
                     mllog_const.EPOCH_COUNT: args.epochs
                 })

    if args.performance_only:
        # Per-sample latency in milliseconds and samples per second, both
        # derived from the average measured iteration time.
        batch_size = args.batch_size
        latency = train_time.avg / batch_size * 1000
        perf = batch_size / train_time.avg
        print('train latency %.2f ms' % latency)
        print('train performance %.2f fps' % perf)
        print("Throughput: {:.3f} fps".format(perf))

    return False
示例#7
0
文件: main.py 项目: IntelAI/models
def main_worker(gpu, ngpus_per_node, args):
    """Build the model and run evaluation or training for one worker.

    The function performs, in order: optional distributed initialisation,
    model creation (torch.hub or ``models.__dict__``), device placement,
    loss/optimizer creation, optional checkpoint resume, argument sanity
    checks, data-loader creation, and finally either a single validation
    pass (``args.evaluate``) or the epoch training loop.

    Args:
        gpu: GPU index assigned to this worker, or ``None``; stored on
            ``args.gpu``.
        ngpus_per_node: Number of GPUs per node; used to derive the global
            rank and to split batch size / loader workers per process.
        args: Parsed command-line namespace. It is mutated in place
            (``gpu``, ``rank``, ``batch_size``, ``workers``, ``start_epoch``).

    Returns:
        None. Updates the module-level ``best_acc1`` during training.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        # Prefer the PMI_* variables when present, falling back to the
        # command-line rank / world size.
        os.environ['RANK'] = str(os.environ.get('PMI_RANK', args.rank))
        os.environ['WORLD_SIZE'] = str(
            os.environ.get('PMI_SIZE', args.world_size))
        os.environ['MASTER_ADDR'] = args.master_addr
        os.environ['MASTER_PORT'] = args.port
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu

        # Initialize the process group with ccl backend
        if args.dist_backend == 'ccl':
            # Imported for its side effect only (presumably registers the
            # 'ccl' backend with torch.distributed).
            import torch_ccl
        dist.init_process_group(backend=args.dist_backend)
        #dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
        #                        world_size=args.world_size, rank=args.rank)
    if args.hub:
        torch.set_flush_denormal(True)
        model = torch.hub.load('facebookresearch/WSL-Images', args.arch)
    else:
        # create model
        if args.pretrained:
            print("=> using pre-trained model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=True)
        else:
            print("=> creating model '{}'".format(args.arch))
            model = models.__dict__[args.arch]()

    if args.ipex:
        import intel_extension_for_pytorch as ipex
    # for ipex path, always convert model to channels_last for bf16, fp32.
    # TODO: int8 path: https://jira.devtools.intel.com/browse/MFDNN-6103
    if args.ipex and not args.int8:
        model = model.to(memory_format=torch.channels_last)

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None and args.cuda:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            if args.cuda:
                model.cuda()
                print("create DistributedDataParallel in GPU")
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                model = torch.nn.parallel.DistributedDataParallel(model)
            else:
                # The CPU DDP wrapper is actually created later, after
                # ipex.optimize, in the "parallelize" step below.
                print("create DistributedDataParallel in CPU")
    elif args.gpu is not None and args.cuda:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            if args.cuda:
                model.cuda()
        else:
            model = torch.nn.DataParallel(model)
            # args.cuda is a flag, not a callable (it is tested as a plain
            # attribute everywhere else in this function).
            if args.cuda:
                model.cuda()

    # define loss function (criterion) and optimizer

    criterion = nn.CrossEntropyLoss()
    if args.cuda:
        criterion = criterion.cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.cuda and args.gpu is not None:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            elif args.cuda:
                checkpoint = torch.load(args.resume)
            else:
                # CPU run: never build a 'cuda:<gpu>' location (it would be
                # the invalid 'cuda:None' when no GPU index was given).
                checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None and args.cuda:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.cuda:
        cudnn.benchmark = True

    # Sanity checks on mutually dependent command-line options.
    if args.weight_sharing:
        assert args.dummy and args.batch_size, \
                "please using dummy data and set batch_size to 1 if you want run weight sharing case for latency case"
    if args.jit and args.int8:
        assert False, "jit path is not available for int8 path using ipex"
    if args.calibration:
        assert args.int8, "please enable int8 path if you want to do int8 calibration path"
    if args.dummy:
        assert args.evaluate, "please using real dataset if you want run training path"
    if not args.ipex:
        # for offical pytorch, int8 and jit path is not enabled.
        assert not args.int8, "int8 path is not enabled for offical pytorch"
        assert not args.jit, "jit path is not enabled for offical pytorch"

    if not args.dummy:
        # Data loading code
        assert args.data is not None, "please set dataset path if you want to using real data"
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        if not args.evaluate:
            traindir = os.path.join(args.data, 'train')
            train_dataset = datasets.ImageFolder(
                traindir,
                transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]))

            if args.distributed:
                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset)
            else:
                train_sampler = None

            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                num_workers=args.workers,
                pin_memory=True,
                sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
    else:
        # Dummy-data mode: no loaders are built here; validate() receives
        # None for the loader.
        train_loader = None
        val_loader = None

    if args.evaluate:
        if args.ipex:
            print("using ipex model to do inference\n")
        else:
            print("using offical pytorch model to do inference\n")

        if args.ipex:
            model.eval()
            if args.int8:
                if not args.calibration:
                    model = optimization.fuse(model, inplace=True)
                    conf = ipex.quantization.QuantConf(args.configure_dir)
                    x = torch.randn(
                        args.batch_size, 3, 224,
                        224).contiguous(memory_format=torch.channels_last)
                    model = ipex.quantization.convert(model, conf, x)
                    with torch.no_grad():
                        y = model(x)
                        print(model.graph_for(x))
                    print("running int8 evalation step\n")
            else:
                if args.bf16:
                    model = ipex.optimize(model,
                                          dtype=torch.bfloat16,
                                          inplace=True)
                    print("running bfloat16 evalation step\n")
                else:
                    model = ipex.optimize(model,
                                          dtype=torch.float32,
                                          inplace=True)
                    print("running fp32 evalation step\n")
                if args.jit:
                    # Trace with a random channels_last input of the
                    # evaluation batch shape, then freeze the traced module.
                    x = torch.randn(
                        args.batch_size, 3, 224,
                        224).contiguous(memory_format=torch.channels_last)
                    if args.bf16:
                        x = x.to(torch.bfloat16)
                        with torch.cpu.amp.autocast(), torch.no_grad():
                            model = torch.jit.trace(model, x).eval()
                    else:
                        with torch.no_grad():
                            model = torch.jit.trace(model, x).eval()
                    model = torch.jit.freeze(model)
        validate(val_loader, model, criterion, args)
        return

    if args.ipex:
        if args.bf16:
            model, optimizer = ipex.optimize(model,
                                             dtype=torch.bfloat16,
                                             optimizer=optimizer)
        else:
            model, optimizer = ipex.optimize(model,
                                             dtype=torch.float32,
                                             optimizer=optimizer)

    # parallelize
    if args.distributed and not args.cuda and args.gpu is None:
        print("create DistributedDataParallel in CPU")
        device_ids = None
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=device_ids)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # Only one process per node writes the checkpoint when running
        # multiprocessing-distributed.
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
示例#8
0
def run():
    """Parse CLI arguments, build the DLRM model and train or evaluate it.

    Steps, in order: argument parsing, distributed initialisation, seeding,
    data-loader creation, layer-size derivation, model construction,
    optional checkpoint load, optional bf16/IPEX optimisation, the training
    loop (or a single inference pass when ``--inference-only`` is given),
    a throughput report and optional profiler dumps.

    Side effects: sets the module-level globals ``args``, ``nbatches``,
    ``nbatches_test`` and ``dlrm``; may write the model to
    ``args.save_model`` and profiler output files to the working directory.
    """
    ### parse arguments ###
    parser = argparse.ArgumentParser(
        description="Train Deep Learning Recommendation Model (DLRM)")
    # model related parameters
    parser.add_argument("--arch-sparse-feature-size", type=int, default=2)
    parser.add_argument("--arch-embedding-size",
                        type=dash_separated_ints,
                        default="4-3-2")
    # j will be replaced with the table number
    parser.add_argument("--arch-mlp-bot",
                        type=dash_separated_ints,
                        default="4-3-2")
    parser.add_argument("--arch-mlp-top",
                        type=dash_separated_ints,
                        default="4-2-1")
    # activations and loss
    parser.add_argument("--activation-function", type=str, default="relu")
    parser.add_argument("--loss-threshold", type=float, default=0.0)  # 1.0e-7
    # NOTE(review): type=bool turns any non-empty string (including "False")
    # into True; kept as-is for CLI compatibility.
    parser.add_argument("--round-targets", type=bool, default=False)
    # data
    parser.add_argument("--num-batches", type=int, default=0)
    parser.add_argument("--data-set", type=str,
                        default="kaggle")  # or terabyte
    parser.add_argument("--raw-data-file", type=str, default="")
    parser.add_argument("--processed-data-file", type=str, default="")
    parser.add_argument("--max-ind-range", type=int, default=-1)
    parser.add_argument("--memory-map", action="store_true", default=False)
    parser.add_argument("--data-sub-sample-rate", type=float,
                        default=0.0)  # in [0, 1]
    parser.add_argument("--data-randomize", type=str,
                        default="total")  # or day or none
    parser.add_argument(
        "--dataset-multiprocessing",
        action="store_true",
        default=False,
        help="The Kaggle dataset can be multiprocessed in an environment \
                        with more than 7 CPU cores and more than 20 GB of memory. \n \
                        The Terabyte dataset can be multiprocessed in an environment \
                        with more than 24 CPU cores and at least 1 TB of memory.",
    )
    # training
    parser.add_argument("--mini-batch-size", type=int, default=1)
    parser.add_argument("--nepochs", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=0.01)
    parser.add_argument("--print-precision", type=int, default=5)
    parser.add_argument("--numpy-rand-seed", type=int, default=123)
    # inference
    parser.add_argument("--inference-only", action="store_true", default=False)
    # store/load model
    parser.add_argument("--save-model", type=str, default="")
    parser.add_argument("--load-model", type=str, default="")
    # debugging and profiling
    parser.add_argument("--print-freq", type=int, default=1)
    parser.add_argument("--test-freq", type=int, default=-1)
    parser.add_argument("--test-mini-batch-size", type=int, default=-1)
    parser.add_argument("--print-time", action="store_true", default=False)
    parser.add_argument("--print-wall-time",
                        action="store_true",
                        default=False)
    parser.add_argument("--enable-profiling",
                        action="store_true",
                        default=False)
    # stop at target AUC Terabyte (no subsampling) 0.8025
    parser.add_argument("--mlperf-auc-threshold", type=float, default=0.0)
    parser.add_argument("--mlperf-bin-loader",
                        action="store_true",
                        default=False)
    parser.add_argument("--mlperf-bin-shuffle",
                        action="store_true",
                        default=False)
    # LR policy
    parser.add_argument("--lr-num-warmup-steps", type=int, default=0)
    parser.add_argument("--lr-decay-start-step", type=int, default=0)
    parser.add_argument("--lr-num-decay-steps", type=int, default=0)
    # intel
    parser.add_argument("--print-auc", action="store_true", default=False)
    parser.add_argument("--should-test", action="store_true", default=False)
    parser.add_argument("--bf16", action="store_true", default=False)
    parser.add_argument("--share-weight-instance", type=int, default=0)
    parser.add_argument("--ipex-interaction",
                        action="store_true",
                        default=False)
    parser.add_argument("--ipex-merged-emb",
                        action="store_true",
                        default=False)
    parser.add_argument("--num-warmup-iters", type=int, default=1000)
    parser.add_argument("--int8", action="store_true", default=False)
    parser.add_argument("--int8-configure",
                        type=str,
                        default="./int8_configure.json")
    parser.add_argument("--dist-backend", type=str, default="ccl")

    global args
    global nbatches
    global nbatches_test
    args = parser.parse_args()
    ext_dist.init_distributed(backend=args.dist_backend)

    ### some basic setup ###
    np.random.seed(args.numpy_rand_seed)
    np.set_printoptions(precision=args.print_precision)
    torch.set_printoptions(precision=args.print_precision)
    torch.manual_seed(args.numpy_rand_seed)

    if args.test_mini_batch_size < 0:
        # if the parameter is not set, use the training batch size
        args.test_mini_batch_size = args.mini_batch_size

    device = torch.device("cpu")
    print("Using CPU...")

    ### prepare training data ###
    ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
    # input data

    train_data, train_ld, test_data, test_ld = dp.make_criteo_data_and_loaders(
        args)
    nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
    nbatches_test = len(test_ld)

    ln_emb = train_data.counts
    # enforce maximum limit on number of vectors per embedding
    if args.max_ind_range > 0:
        ln_emb = np.array(
            list(
                map(
                    lambda x: x
                    if x < args.max_ind_range else args.max_ind_range,
                    ln_emb,
                )))
    else:
        ln_emb = np.array(ln_emb)
    # The first bottom-MLP width is overridden by the value reported by the
    # dataset (train_data.m_den), whatever --arch-mlp-bot specified.
    m_den = train_data.m_den
    ln_bot[0] = m_den

    args.ln_emb = ln_emb.tolist()

    ### parse command line arguments ###
    m_spa = args.arch_sparse_feature_size
    ln_emb = np.asarray(ln_emb)
    num_fea = ln_emb.size + 1  # num sparse + num dense features

    m_den_out = ln_bot[ln_bot.size - 1]
    # approach 1: all
    # num_int = num_fea * num_fea + m_den_out
    # approach 2: unique
    num_int = (num_fea * (num_fea - 1)) // 2 + m_den_out

    # Prepend the computed interaction width to the user-specified top MLP.
    arch_mlp_top_adjusted = str(num_int) + "-" + args.arch_mlp_top
    ln_top = np.fromstring(arch_mlp_top_adjusted, dtype=int, sep="-")

    ### construct the neural network specified above ###
    # WARNING: to obtain exactly the same initialization for
    # the weights we need to start from the same random seed.
    # np.random.seed(args.numpy_rand_seed)
    global dlrm
    dlrm = DLRM_Net(
        m_spa,
        ln_emb,
        ln_bot,
        ln_top,
        sigmoid_bot=-1,
        sigmoid_top=ln_top.size - 2,
        loss_threshold=args.loss_threshold,
    )
    if args.ipex_merged_emb:
        dlrm.emb_l = ipex.nn.modules.MergedEmbeddingBagWithSGD.from_embeddingbag_list(
            dlrm.emb_l, lr=args.learning_rate)
        dlrm.need_linearize_indices_and_offsets = torch.BoolTensor([False])

    if not args.inference_only:
        optimizer = torch.optim.SGD(dlrm.parameters(), lr=args.learning_rate)
        lr_scheduler = LRPolicyScheduler(
            optimizer,
            args.lr_num_warmup_steps,
            args.lr_decay_start_step,
            args.lr_num_decay_steps,
        )

    ### main loop ###

    # training or inference
    best_acc_test = 0
    best_auc_test = 0
    skip_upto_epoch = 0
    skip_upto_batch = 0
    total_time = 0
    total_loss = 0
    total_iter = 0
    total_samp = 0

    # Load model is specified
    if not (args.load_model == ""):
        print("Loading saved model {}".format(args.load_model))
        ld_model = torch.load(args.load_model,
                              map_location=torch.device("cpu"))
        dlrm.load_state_dict(ld_model["state_dict"])
        ld_j = ld_model["iter"]
        ld_k = ld_model["epoch"]
        ld_nepochs = ld_model["nepochs"]
        ld_nbatches = ld_model["nbatches"]
        ld_nbatches_test = ld_model["nbatches_test"]
        ld_train_loss = ld_model["train_loss"]
        ld_total_loss = ld_model["total_loss"]
        ld_acc_test = ld_model["test_acc"]
        if not args.inference_only:
            optimizer.load_state_dict(ld_model["opt_state_dict"])
            best_acc_test = ld_acc_test
            total_loss = ld_total_loss
            skip_upto_epoch = ld_k  # epochs
            skip_upto_batch = ld_j  # batches
        else:
            args.print_freq = ld_nbatches
            args.test_freq = 0

        print("Saved at: epoch = {:d}/{:d}, batch = {:d}/{:d}, ntbatch = {:d}".
              format(ld_k, ld_nepochs, ld_j, ld_nbatches, ld_nbatches_test))
        print("Training state: loss = {:.6f}".format(ld_train_loss, ))
        print("Testing state: accuracy = {:3.3f} %".format(ld_acc_test * 100))

    ext_dist.barrier()
    print("time/loss/accuracy (if enabled):")

    if args.bf16 and not args.inference_only:
        # Pull a single batch from the training loader to use as the
        # sample input for ipex.optimize, then stop iterating.
        for j, inputBatch in enumerate(train_ld):
            X, lS_o, lS_i, T, W, CBPP = unpack_batch(inputBatch)
            if ext_dist.my_size > 1:
                local_bs = X.size()[0] // ext_dist.my_size
                rank_id = dlrm.rank
                X = X[rank_id * local_bs:(rank_id + 1) * local_bs]
                T = T[rank_id * local_bs:(rank_id + 1) * local_bs]
                global_bs = local_bs * ext_dist.my_size
                lS_o = lS_o[:, :global_bs]
                lS_i = lS_i[:, :global_bs]

            if isinstance(dlrm.emb_l,
                          ipex.nn.modules.MergedEmbeddingBagWithSGD):
                if ext_dist.my_size > 1:
                    batch_size = X.size()[0]
                    g_i = lS_i[dlrm.local_ln_emb]
                    g_o = lS_o[dlrm.local_ln_emb]
                    n_tables = g_i.shape[0]
                    idx = [g_i[i] for i in range(n_tables)]
                    offset = [g_o[i] for i in range(n_tables)]
                    include_last = [False for i in range(n_tables)]
                    indices, offsets, indices_with_row_offsets = dlrm.emb_l.linearize_indices_and_offsets(
                        idx, offset, include_last)
                else:
                    n_tables = lS_i.shape[0]
                    idx = [lS_i[i] for i in range(n_tables)]
                    offset = [lS_o[i] for i in range(n_tables)]
                    include_last = [False for i in range(n_tables)]
                    indices, offsets, indices_with_row_offsets = dlrm.emb_l.linearize_indices_and_offsets(
                        idx, offset, include_last)
            if isinstance(dlrm.emb_l,
                          ipex.nn.modules.MergedEmbeddingBagWithSGD):
                sample_input = (X, indices, offsets, indices_with_row_offsets)
            else:
                sample_input = (X, lS_o, lS_i)
            break
        dlrm, optimizer = ipex.optimize(dlrm,
                                        dtype=torch.bfloat16,
                                        optimizer=optimizer,
                                        inplace=True,
                                        sample_input=sample_input)

        if args.ipex_merged_emb:
            dlrm.emb_l.to_bfloat16_train()
        # Replace each prepacked linear layer with a fused linear+eltwise
        # module and turn the activation that followed it into Identity.
        for i in range(len(dlrm.top_l)):
            if isinstance(dlrm.top_l[i],
                          ipex.nn.utils._weight_prepack._IPEXLinear):
                if isinstance(dlrm.top_l[i + 1], torch.nn.ReLU):
                    dlrm.top_l[i] = ipex.nn.modules.IPEXLinearEltwise(
                        dlrm.top_l[i], 'relu')
                else:
                    dlrm.top_l[i] = ipex.nn.modules.IPEXLinearEltwise(
                        dlrm.top_l[i], 'sigmoid')
                dlrm.top_l[i + 1] = torch.nn.Identity()
        for i in range(len(dlrm.bot_l)):
            if isinstance(dlrm.bot_l[i],
                          ipex.nn.utils._weight_prepack._IPEXLinear):
                if isinstance(dlrm.bot_l[i + 1], torch.nn.ReLU):
                    dlrm.bot_l[i] = ipex.nn.modules.IPEXLinearEltwise(
                        dlrm.bot_l[i], 'relu')
                else:
                    dlrm.bot_l[i] = ipex.nn.modules.IPEXLinearEltwise(
                        dlrm.bot_l[i], 'sigmoid')
                dlrm.bot_l[i + 1] = torch.nn.Identity()

        if ext_dist.my_size > 1:
            dlrm.bot_l = ext_dist.DDP(dlrm.bot_l)
            dlrm.top_l = ext_dist.DDP(dlrm.top_l)
    # [accumulated time, number of recorded reports]
    training_record = [0, 0]

    def update_training_performance(time,
                                    iters,
                                    training_record=training_record):
        # Only accumulate once the iteration index is past the warm-up.
        if iters > args.num_warmup_iters:
            training_record[0] += time
            training_record[1] += 1

    def print_training_performance(training_record=training_record):
        # Nothing was recorded past warm-up: report and terminate.
        if training_record[0] == 0:
            print(
                "num-batches larger than warm up iters, please increase num-batches or decrease warmup iters"
            )
            exit()
        total_samples = training_record[1] * args.mini_batch_size
        throughput = total_samples / training_record[0] * 1000
        print("Throughput: {:.3f} fps".format(throughput))

    test_freq = args.test_freq if args.test_freq != -1 else nbatches // 20
    with torch.autograd.profiler.profile(enabled=args.enable_profiling,
                                         use_cuda=False,
                                         record_shapes=False) as prof:
        if not args.inference_only:
            k = 0
            while k < args.nepochs:

                if k < skip_upto_epoch:
                    # Advance the epoch counter before skipping; a bare
                    # `continue` here would loop forever when resuming from
                    # a checkpoint saved at epoch > 0.
                    k += 1
                    continue

                for j, inputBatch in enumerate(train_ld):

                    if j < skip_upto_batch:
                        continue

                    X, lS_o, lS_i, T, W, CBPP = unpack_batch(inputBatch)
                    if ext_dist.my_size > 1:
                        local_bs = X.size()[0] // ext_dist.my_size
                        rank_id = dlrm.rank
                        X = X[rank_id * local_bs:(rank_id + 1) * local_bs]
                        T = T[rank_id * local_bs:(rank_id + 1) * local_bs]
                        global_bs = local_bs * ext_dist.my_size
                        lS_o = lS_o[:, :global_bs]
                        lS_i = lS_i[:, :global_bs]

                    if isinstance(dlrm.emb_l,
                                  ipex.nn.modules.MergedEmbeddingBagWithSGD):
                        if ext_dist.my_size > 1:
                            batch_size = X.size()[0]
                            g_i = lS_i[dlrm.local_ln_emb]
                            g_o = lS_o[dlrm.local_ln_emb]
                            n_tables = g_i.shape[0]
                            idx = [g_i[i] for i in range(n_tables)]
                            offset = [g_o[i] for i in range(n_tables)]
                            include_last = [False for i in range(n_tables)]
                            indices, offsets, indices_with_row_offsets = dlrm.emb_l.linearize_indices_and_offsets(
                                idx, offset, include_last)
                        else:
                            n_tables = lS_i.shape[0]
                            idx = [lS_i[i] for i in range(n_tables)]
                            offset = [lS_o[i] for i in range(n_tables)]
                            include_last = [False for i in range(n_tables)]
                            indices, offsets, indices_with_row_offsets = dlrm.emb_l.linearize_indices_and_offsets(
                                idx, offset, include_last)

                    t1 = time_wrap()

                    # early exit if nbatches was set by the user and has been exceeded
                    if nbatches > 0 and j >= nbatches:
                        break

                    mbs = T.shape[
                        0]  # = args.mini_batch_size except maybe for last

                    # forward pass
                    with torch.cpu.amp.autocast(enabled=args.bf16):
                        if isinstance(
                                dlrm.emb_l,
                                ipex.nn.modules.MergedEmbeddingBagWithSGD):
                            Z = dlrm_wrap(X, indices, offsets,
                                          indices_with_row_offsets).float()
                        else:
                            Z = dlrm_wrap(
                                X,
                                lS_o,
                                lS_i,
                            ).float()

                    # loss
                    E = loss_fn_wrap(Z, T)

                    # compute loss and accuracy
                    L = E.detach().cpu().numpy()  # numpy array

                    with record_function("DLRM backward"):
                        # scaled error gradient propagation
                        # (where we do not accumulate gradients across mini-batches)
                        optimizer.zero_grad(set_to_none=True)
                        # backward pass
                        E.backward()

                    with record_function("DLRM update"):
                        # optimizer
                        optimizer.step()
                    lr_scheduler.step()
                    if isinstance(dlrm.emb_l,
                                  ipex.nn.modules.MergedEmbeddingBagWithSGD):
                        # Keep the merged embedding's own SGD learning rate
                        # in step with the scheduler.
                        dlrm.emb_l.sgd_args = dlrm.emb_l.sgd_args._replace(
                            lr=lr_scheduler.get_last_lr()[0])

                    t2 = time_wrap()
                    total_time += t2 - t1

                    total_loss += L * mbs
                    total_iter += 1
                    total_samp += mbs

                    should_print = ((j + 1) % args.print_freq
                                    == 0) or (j + 1 == nbatches)
                    should_test = ((args.should_test)
                                   and (((j + 1) % test_freq == 0) or
                                        (j + 1 == nbatches)))

                    # print time, loss and accuracy
                    if should_print or should_test:
                        gT = 1000.0 * total_time / total_iter if args.print_time else -1
                        total_time = 0

                        train_loss = total_loss / total_samp
                        total_loss = 0

                        str_run_type = ("inference"
                                        if args.inference_only else "training")

                        wall_time = ""
                        if args.print_wall_time:
                            wall_time = " ({})".format(time.strftime("%H:%M"))

                        print(
                            "Finished {} it {}/{} of epoch {}, {:.2f} ms/it,".
                            format(str_run_type, j + 1, nbatches, k, gT) +
                            " loss {:.6f}".format(train_loss) + wall_time,
                            flush=True,
                        )
                        update_training_performance(gT, j)

                        total_iter = 0
                        total_samp = 0

                    # testing
                    if should_test:
                        model_metrics_dict, is_best = inference(
                            args,
                            dlrm,
                            best_acc_test,
                            best_auc_test,
                            test_ld,
                        )

                        if (is_best and not (args.save_model == "")
                                and not args.inference_only):
                            model_metrics_dict["epoch"] = k
                            model_metrics_dict["iter"] = j + 1
                            model_metrics_dict["train_loss"] = train_loss
                            model_metrics_dict["total_loss"] = total_loss
                            model_metrics_dict[
                                "opt_state_dict"] = optimizer.state_dict()
                            print("Saving model to {}".format(args.save_model))
                            torch.save(model_metrics_dict, args.save_model)

                        if ((args.mlperf_auc_threshold > 0) and
                            (best_auc_test > args.mlperf_auc_threshold)):
                            print("MLPerf testing auc threshold " +
                                  str(args.mlperf_auc_threshold) +
                                  " reached, stop training")
                k += 1  # nepochs
        else:
            print("Testing for inference only")
            with torch.no_grad():
                inference(args, dlrm, best_acc_test, best_auc_test, test_ld)

    # profiling
    print_training_performance()

    if args.enable_profiling:
        time_stamp = str(datetime.datetime.now()).replace(" ", "_")
        with open("dlrm_s_pytorch" + time_stamp + "_shape.prof",
                  "w") as prof_f:
            prof_f.write(
                prof.key_averages(group_by_input_shape=True).table(
                    sort_by="self_cpu_time_total"))
        with open("dlrm_s_pytorch" + time_stamp + "_total.prof",
                  "w") as prof_f:
            prof_f.write(
                prof.key_averages().table(sort_by="self_cpu_time_total"))
        prof.export_chrome_trace("dlrm_s_pytorch" + time_stamp + ".json")
示例#9
0
def main(args):
    """Evaluate an RNN-T speech model on CPU, optionally accelerated by IPEX.

    Seeds the RNGs, loads the model definition from ``args.model_toml``,
    builds the evaluation data layer (skipped when ``args.wav`` is given),
    constructs the ``RNNT`` model, optionally restores ``args.ckpt``, applies
    the IPEX / TorchScript optimizations requested through ``args.ipex``,
    ``args.mix_precision`` and ``args.jit``, and finally hands everything to
    the module-level ``eval`` routine.

    Args:
        args: parsed command-line namespace. Fields read here: ``seed``,
            ``cudnn_benchmark``, ``local_rank``, ``fp16``, ``model_toml``,
            ``val_manifest``, ``max_duration``, ``pad_to``, ``wav``,
            ``dataset_dir``, ``sort_by_duration``, ``batch_size``, ``ckpt``,
            ``ipex``, ``mix_precision``, ``jit``, ``steps`` and ``warm_up``.
    """
    # Seed every RNG in use, in the same order as before.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    multi_gpu = args.local_rank is not None
    if multi_gpu:
        print("DISTRIBUTED with ", torch.distributed.get_world_size())

    opt_level = Optimization.mxprO3 if args.fp16 else Optimization.mxprO0

    model_cfg = toml.load(args.model_toml)
    vocab = model_cfg['labels']['labels']
    ctc_vocab = add_blank_label(vocab)

    manifest = args.val_manifest
    feat_cfg = model_cfg['input_eval']
    feat_cfg["optimization_level"] = opt_level

    # Command-line overrides for the featurizer configuration.
    if args.max_duration is not None:
        feat_cfg['max_duration'] = args.max_duration
    if args.pad_to is not None:
        # A negative value selects padding to the longest sample.
        feat_cfg['pad_to'] = "max" if args.pad_to < 0 else args.pad_to

    print('model_config')
    print_dict(model_cfg)
    print('feature_config')
    print_dict(feat_cfg)

    # The dataset is only needed when no single wav file was supplied.
    data_layer = None
    if args.wav is None:
        data_layer = AudioToTextDataLayer(
            dataset_dir=args.dataset_dir,
            featurizer_config=feat_cfg,
            manifest_filepath=manifest,
            sort_by_duration=args.sort_by_duration,
            labels=vocab,
            batch_size=args.batch_size,
            pad_to_max=feat_cfg['pad_to'] == "max",
            shuffle=False,
            multi_gpu=multi_gpu)
    audio_preprocessor = AudioPreprocessing(**feat_cfg)

    model = RNNT(
        feature_config=feat_cfg,
        rnnt=model_cfg['rnnt'],
        num_classes=len(ctc_vocab)
    )

    if args.ckpt is not None:
        print("loading model from ", args.ckpt)
        state = torch.load(args.ckpt, map_location="cpu")
        model.load_state_dict(state['state_dict'], strict=False)

    if not args.ipex:
        model = model.to("cpu")
    else:
        import intel_extension_for_pytorch as ipex
        from rnn import IPEXStackTime
        model.joint_net.eval()
        data_type = torch.bfloat16 if args.mix_precision else torch.float32
        stack_factor = model.encoder["stack_time"].factor
        if stack_factor == 2:
            model.encoder["stack_time"] = IPEXStackTime(stack_factor)
        model.joint_net = ipex.optimize(model.joint_net, dtype=data_type, auto_kernel_selection=True)
        model.prediction["embed"] = model.prediction["embed"].to(data_type)
        if args.jit:
            print("running jit path")
            model.joint_net.eval()
            rnnt_cfg = model_cfg['rnnt']
            with torch.no_grad():
                # Example input used only to trace the joint network.
                example = torch.randn(
                    args.batch_size, 1, 1,
                    rnnt_cfg['encoder_n_hidden'] + rnnt_cfg['pred_n_hidden'])
                if args.mix_precision:
                    with torch.cpu.amp.autocast():
                        traced = torch.jit.trace(model.joint_net, example, check_trace=False)
                else:
                    traced = torch.jit.trace(model.joint_net, example, check_trace=False)
            model.joint_net = traced
            model.joint_net = torch.jit.freeze(model.joint_net)

    if args.wav is None:
        num_examples = len(data_layer)
        # Without an initialized process group the run counts as one worker.
        if torch.distributed.is_initialized():
            world = torch.distributed.get_world_size()
        else:
            world = 1
        step_per_epoch = math.ceil(num_examples / (args.batch_size * world))

        if args.steps is not None:
            shown_examples = args.steps * args.batch_size * world
            shown_steps = args.steps
        else:
            shown_examples = num_examples
            shown_steps = step_per_epoch
        print('-----------------')
        print('Have {0} examples to eval on.'.format(shown_examples))
        print('Have {0} warm up steps / (gpu * epoch).'.format(args.warm_up))
        print('Have {0} measure steps / (gpu * epoch).'.format(shown_steps))
        print('-----------------')
    else:
        audio_preprocessor.featurizer.normalize = "per_feature"

    print ("audio_preprocessor.normalize: ", audio_preprocessor.featurizer.normalize)
    audio_preprocessor.eval()

    # Three-stage input pipeline applied to every batch before the model.
    def _to_cpu(xs):
        return [x.cpu() for x in xs]

    def _featurize(xs):
        # The first two entries go through the audio preprocessor; the rest
        # are passed along unchanged.
        return [*audio_preprocessor(xs[0:2]), *xs[2:]]

    def _permute_features(xs):
        return [xs[0].permute(2, 0, 1), *xs[1:]]

    eval_transforms = torchvision.transforms.Compose(
        [_to_cpu, _featurize, _permute_features])

    model.eval()
    if args.ipex:
        ipex.nn.utils._model_convert.replace_lstm_with_ipex_lstm(model)

    decoder_model = model.module if multi_gpu else model
    greedy_decoder = RNNTGreedyDecoder(len(ctc_vocab) - 1, decoder_model)

    # NOTE: ``eval`` here is the project's evaluation routine, which shadows
    # the Python builtin of the same name.
    eval(
        data_layer=data_layer,
        audio_processor=eval_transforms,
        encoderdecoder=model,
        greedy_decoder=greedy_decoder,
        labels=ctc_vocab,
        args=args,
        multi_gpu=multi_gpu)
示例#10
0
文件: trainer.py 项目: IntelAI/models
def do_train(cfg,
             model,
             data_loader,
             data_loader_val,
             optimizer,
             scheduler,
             checkpointer,
             device,
             checkpoint_period,
             test_period,
             arguments,
             bf16=False,
             iterations=-1,
             iter_warmup=-1):
    """Run the training loop with IPEX optimization and report throughput.

    The model and optimizer are passed through ``ipex.optimize`` (bfloat16
    when ``bf16`` is set, float otherwise) before the loop starts. Batches
    in which any target is empty are logged and skipped.

    Args:
        cfg: configuration node; ``MODEL.*``, ``DATASETS.TEST``, ``TEST.*``
            and ``SOLVER.IMS_PER_BATCH`` are read here.
        model: network to train; called as ``model(images, targets)`` and
            expected to return a dict of losses.
        data_loader: training loader yielding ``(images, targets, ids)``.
        data_loader_val: optional validation loader; ``None`` disables the
            periodic validation pass.
        optimizer: optimizer stepped once per iteration.
        scheduler: LR scheduler stepped once per iteration.
        checkpointer: object whose ``save(name, **arguments)`` writes
            checkpoints.
        device: device the images and targets are moved to.
        checkpoint_period: save a checkpoint every this many iterations.
        test_period: run validation every this many iterations (``<= 0``
            disables it).
        arguments: mutable dict; ``arguments["iteration"]`` gives the start
            iteration and is updated as training progresses.
        bf16: enable bfloat16 autocast for the forward pass.
        iterations: number of measured iterations to run; ``<= 0`` means run
            through the whole loader and save ``model_final`` at the end.
        iter_warmup: iterations excluded from ``training_timer``; when
            ``> 0`` and ``iterations > 0`` they are run in addition to
            ``iterations``.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    training_timer = Timer()
    start_training_time = time.time()
    end = time.time()

    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )
    dataset_names = cfg.DATASETS.TEST

    model, optimizer = ipex.optimize(
        model,
        dtype=torch.bfloat16 if bf16 else torch.float,
        optimizer=optimizer,
        inplace=True)

    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):

        # Skip batches containing an image without any target; ``_`` holds
        # the third element yielded by the loader and is logged for triage.
        if any(len(target) < 1 for target in targets):
            logger.error(
                f"Iteration={iteration + 1} || Image Ids used for training {_} || targets Length={[len(target) for target in targets]}"
            )
            continue
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        # Only iterations past the warm-up window count toward throughput.
        if iteration > iter_warmup:
            training_timer.tic()

        with torch.cpu.amp.autocast(enabled=bf16):
            loss_dict = model(images.to(memory_format=torch.channels_last),
                              targets)

        # Sum in float32 so a bfloat16 forward pass does not reduce the
        # precision of the total loss.
        losses = sum(loss.to(torch.float32) for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        scheduler.step()

        if iteration > iter_warmup:
            training_timer.toc()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        # ``% 1`` is always true, so every iteration is logged.
        if iteration % 1 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if data_loader_val is not None and test_period > 0 and iteration % test_period == 0:
            meters_val = MetricLogger(delimiter="  ")
            synchronize()
            _ = inference(  # The result can be used for additional logging, e. g. for TensorBoard
                model,
                # The method changes the segmentation mask format in a data loader,
                # so every time a new data loader is created:
                make_data_loader(cfg,
                                 is_train=False,
                                 is_distributed=(get_world_size() > 1),
                                 is_for_period=True),
                dataset_name="[Validation]",
                iou_types=iou_types,
                box_only=False
                if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
                device=cfg.MODEL.DEVICE,
                expected_results=cfg.TEST.EXPECTED_RESULTS,
                expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                output_folder=None,
            )
            synchronize()
            # Back to train mode so the forward pass returns a loss dict.
            model.train()
            # NOTE(review): unlike the training step, this pass uses neither
            # autocast nor channels_last -- confirm that is intended when
            # bf16 is enabled.
            with torch.no_grad():
                # Should be one image for each GPU:
                for iteration_val, (images_val, targets_val,
                                    _) in enumerate(tqdm(data_loader_val)):
                    images_val = images_val.to(device)
                    targets_val = [target.to(device) for target in targets_val]
                    loss_dict = model(images_val, targets_val)
                    losses = sum(loss for loss in loss_dict.values())
                    loss_dict_reduced = reduce_loss_dict(loss_dict)
                    losses_reduced = sum(
                        loss for loss in loss_dict_reduced.values())
                    meters_val.update(loss=losses_reduced, **loss_dict_reduced)
            synchronize()
            logger.info(
                meters_val.delimiter.join([
                    "[Validation]: ",
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters_val),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iterations <= 0:
            if iteration == max_iter:
                checkpointer.save("model_final", **arguments)
        elif iter_warmup > 0:
            # ``>=`` rather than ``==``: the exact target iteration may have
            # been skipped by the empty-target ``continue`` above, in which
            # case an equality test would never fire.
            if iteration >= iterations + iter_warmup:
                break
        else:
            if iteration >= iterations:
                break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    if iterations <= 0:
        iterations = max_iter
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / iterations))

    total_train_time = get_time_str(training_timer.total_time)
    logger.info("Model training time: {} ({} s / iter per device)".format(
        total_train_time, training_timer.total_time / iterations))
    print("Training throughput: {:.3f} fps".format(
        (iterations * cfg.SOLVER.IMS_PER_BATCH) / (training_timer.total_time)))
示例#11
0
def main():
    """Run BERT pre-training on CPU with IPEX and return the final metrics.

    Parses the command line, prepares model/optimizer/scheduler, wraps the
    model with ``ipex.optimize`` (and ``DistributedDataParallel`` when a
    process group is initialized), then iterates over the training shards
    found in ``args.input_dir``. Evaluation runs at sample counts derived
    from ``args.eval_iter_start_samples`` / ``args.eval_iter_samples``, and
    checkpoints are written according to ``args.num_samples_per_checkpoint``.
    With ``args.benchmark`` set, the process prints a throughput figure and
    exits after 61 completed steps.

    NOTE(review): the training loop reads variables (``pool``, ``epoch``,
    ``end_training``, ``files`` ...) that are only assigned when
    ``args.do_train`` is true, so the function appears to assume
    ``--do_train`` -- confirm against the launcher.

    Returns:
        tuple: ``(args, final_loss, train_time_raw)``. ``final_loss`` and
        ``train_time_raw`` stay at ``float("inf")`` unless the termination
        branch inside the loop is reached.
    """
    args = parse_args()
    status = 'aborted'  # later set to 'success' if termination criteria met
    device, args = setup_training(args)
    total_batch_size = global_batch_size(args)
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    if args.local_rank == 0 or args.local_rank == -1:
        print("parsed args:")
        print(args)
    # Prepare optimizer
    model, optimizer, lr_scheduler, checkpoint, global_step = prepare_model_and_optimizer(args, device)
    model.train()
    model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16 if args.bf16 else torch.float32)
    worker_seeds, shuffling_seeds = utils.setup_seeds(args.seed, args.num_epochs_to_generate_seeds_for, device)
    worker_seed = worker_seeds[args.local_rank]

    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)
    worker_init = WorkerInitObj(worker_seed)
    samples_trained = global_step * args.train_batch_size * args.gradient_accumulation_steps * args.world_size
    final_loss = float("inf")
    train_time_raw = float("inf")
    raw_train_start = time.time()
    # The windowed MLM accuracy is only maintained when both a target and a
    # positive window size are configured; every later use of
    # ``accuracy_scores`` / ``avg_mlm_accuracy`` is gated on this same flag.
    track_mlm_window = bool(args.target_mlm_accuracy) and args.train_mlm_accuracy_window_size > 0
    if args.do_train:
        model.train()
        most_recent_ckpts_paths = []
        average_loss = 0.0  # averaged loss every args.log_freq steps
        epoch = 1
        training_steps = 0
        end_training, converged = False, False
        samples_trained_prev = 0

        # pre-compute eval boundaries
        samples_trained_per_step = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
        start, stop, step = args.eval_iter_start_samples, args.max_samples_termination, args.eval_iter_samples
        eval_steps = [math.ceil(i/samples_trained_per_step) for i in np.arange(start, stop, step)]
        eval_count = 0
        # ``None`` means "no further evaluation scheduled"; it never compares
        # equal to ``global_step``.
        next_eval_step = eval_steps[eval_count] if eval_steps else None
        pool = ProcessPoolExecutor(1)

        if track_mlm_window:
            accuracy_scores = []
            avg_mlm_accuracy = torch.Tensor([0])

        first_epoch = True
        if found_resume_checkpoint(args):
            f_start_id = checkpoint['files'][0]
            files = checkpoint['files'][1:]
            num_files = len(files)
        else:
            files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if
                     os.path.isfile(os.path.join(args.input_dir, f)) and 'part' in f]
            files.sort()
            num_files = len(files)
            random.Random(shuffling_seeds[epoch%len(shuffling_seeds)]).shuffle(files)
            f_start_id = 0
    global skipped_steps
    if torch.distributed.is_initialized():
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          find_unused_parameters=True,
                                                          bucket_cap_mb=8192,
                                                          gradient_as_bucket_view=args.use_gradient_as_bucket_view)

    now_step, now_skipped, skip_interval = 0, 0, 0
    # Start prefetching eval dataset
    if args.eval_dir:
        eval_dataset_future = pool.submit(create_eval_dataset, args, worker_init_fn=worker_init)
    # comparing to number of samples in a shard. There are ~38k samples in 4096-way shard, comparing to 10k to be safe
    need_next_training_shard = args.train_batch_size * args.gradient_accumulation_steps *  args.max_steps > 10000

    while global_step < args.max_steps and not end_training:
        if args.local_rank == 0 or args.local_rank == -1:
            now_time = time.time()
            print("epoch:", epoch)

        thread = None

        # Reshuffle file list on subsequent epochs
        if not first_epoch:
            files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if
                     os.path.isfile(os.path.join(args.input_dir, f)) and 'part' in f]
            files.sort()
            num_files = len(files)
            random.Random(shuffling_seeds[epoch%len(shuffling_seeds)]).shuffle(files)
            f_start_id = 0

        first_epoch = False

        shared_file_list = {}

        if torch.distributed.is_initialized() and args.world_size > num_files:
            remainder = args.world_size % num_files
            data_file = files[(f_start_id*args.world_size + args.local_rank +
                               remainder * f_start_id) % num_files]
        else:
            data_file = files[(f_start_id*args.world_size + args.local_rank) % num_files]

        previous_file = data_file

        train_data = pretraining_dataset(data_file, args.max_predictions_per_seq)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        send_lr_in_parallel = False
        lr_cpu = torch.tensor([0.0], dtype=torch.float32, device='cpu')
        completed_steps=0
        bench_total_time=0
        for f_id in range(f_start_id, len(files)):
            if args.world_size > num_files:
                data_file = files[(f_id*args.world_size + args.local_rank +
                                   remainder * f_id) % num_files]
            else:
                data_file = files[(f_id*args.world_size + args.local_rank)%num_files]

            previous_file = data_file
            if need_next_training_shard:
                dataset_future = pool.submit(create_pretraining_dataset, data_file, args.max_predictions_per_seq, shared_file_list, args, worker_init_fn=worker_init)
            t0 = time.time()
            for step, batch in enumerate(train_dataloader):
                training_steps += 1
                t_beg = time.time()
                t1 = time.time()
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                t2 = time.time()
                outputs = None
                if args.bf16:
                    with torch.cpu.amp.autocast():
                        outputs = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask,
                             labels=masked_lm_labels, next_sentence_label=next_sentence_labels)
                else:
                    outputs = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask,
                             labels=masked_lm_labels, next_sentence_label=next_sentence_labels)
                t3 = time.time()
                loss = outputs.loss
                loss = loss / args.gradient_accumulation_steps
                loss.backward()
                t4 = time.time()
                if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                t5 = time.time()
                t_end = time.time()
                completed_steps += 1
                # Benchmark mode: ignore the first 10 steps, then report and
                # exit once more than 60 steps have completed.
                # NOTE(review): steps 11..61 (51 intervals) are accumulated
                # while the throughput formula assumes 50 -- confirm intent.
                if args.benchmark and completed_steps > 10:
                    bench_total_time = bench_total_time + (t_end -t_beg)
                if args.benchmark and completed_steps > 60:
                    throughput = 50 * args.train_batch_size / bench_total_time
                    print("Throughput: {:.3f} sentence/s".format(throughput), flush=True)
                    exit()

                gloss, lm_acc, num_masked, seq_acc, seq_tot = calc_accuracy(outputs, masked_lm_labels, next_sentence_labels, args)
                print(f"Step {training_steps:5d}: loss: {gloss:6.3f} lm_acc: {lm_acc:.3f} seq_acc: {seq_acc:.3f} lbs: {args.train_batch_size} gbs: {total_batch_size} DT: {(t1-t0)*1000.0:.1f} XT: {(t2-t1)*1000.0:.1f} FT: {(t3-t2)*1000.0:.1f} BT: {(t4-t3)*1000.0:.1f} OT: {(t5-t4)*1000.0:.1f} TT: {(t5-t0)*1000.0:.1f}")

                update_step = training_steps % args.gradient_accumulation_steps == 0
                divisor = args.gradient_accumulation_steps
                if args.log_freq>0:
                    average_loss += loss.item()
                if update_step:
                    now_lr = optimizer.param_groups[0]['lr']
                    global_step += 1
                    if (args.eval_dir and args.eval_iter_samples > 0 and global_step == next_eval_step):
                        # on first eval, get eval_dataloader
                        if eval_count == 0:
                            eval_dataloader = create_eval_dataset(args, worker_init_fn=worker_init) #eval_dataset_future.result(timeout=None)
                        samples_trained = global_step * args.train_batch_size * args.gradient_accumulation_steps * args.world_size
                        samples_trained_prev = samples_trained
                        eval_avg_loss, eval_avg_mlm_accuracy = run_eval(model, eval_dataloader, device, args.num_eval_examples, args,
                                                                                    first_eval=(eval_count == 0))
                        if args.local_rank == 0 or args.local_rank == -1:
                            print({"global_steps": global_step, "eval_loss": eval_avg_loss, "eval_mlm_accuracy":eval_avg_mlm_accuracy})

                            if args.target_mlm_accuracy:
                                if eval_avg_mlm_accuracy >= args.target_mlm_accuracy:
                                    end_training, converged = True, True
                                    if utils.is_main_process():
                                        print("%f > %f, Target MLM Accuracy reached at %d"%(eval_avg_mlm_accuracy, args.target_mlm_accuracy, global_step))

                        eval_count += 1
                        # Guard the lookup: after the last scheduled
                        # evaluation there is no next boundary, and indexing
                        # past the end would raise IndexError.
                        next_eval_step = eval_steps[eval_count] if eval_count < len(eval_steps) else None
                if track_mlm_window:
                    # ``lm_acc`` is the MLM accuracy returned by
                    # calc_accuracy above (the previous name ``mlm_acc`` was
                    # never defined in this function).
                    accuracy_scores.append(lm_acc)
                    if update_step:
                        accuracy_scores = accuracy_scores[-args.train_mlm_accuracy_window_size * args.gradient_accumulation_steps:]
                        avg_mlm_accuracy[0] = sum(accuracy_scores) / len(accuracy_scores)
                        # Average across ranks only when a process group
                        # exists; all_reduce fails in single-process runs.
                        if torch.distributed.is_initialized():
                            torch.distributed.all_reduce(avg_mlm_accuracy, op=torch.distributed.ReduceOp.SUM)
                            avg_mlm_accuracy /= args.world_size

                if args.log_freq > 0 and training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
                    samples_trained = global_step * args.train_batch_size * args.gradient_accumulation_steps * args.world_size
                    if args.local_rank == 0 or args.local_rank == -1:
                        time_interval = time.time() - now_time
                        step_interval = global_step - now_step
                        now_time = time.time()
                        now_step = global_step
                        training_perf = args.train_batch_size * args.gradient_accumulation_steps * args.world_size \
                                        * (step_interval + skip_interval) / time_interval
                        skip_interval = 0

                        if track_mlm_window:
                            print({"training_steps": training_steps,
                                  "average_loss": average_loss / (args.log_freq * divisor),
                                  "step_loss": loss.item() * args.gradient_accumulation_steps / divisor,
                                  "learning_rate": now_lr,
                                  "seq/s": training_perf,
                                  "global_steps": now_step,
                                  "samples_trained": samples_trained,
                                  "skipped_steps": now_skipped,
                                  "timestamp": now_time,
                                  "mlm_accuracy": avg_mlm_accuracy[0].item()})
                        else:
                            print({"training_steps": training_steps,
                                  "average_loss": average_loss / (args.log_freq * divisor),
                                  "step_loss": loss.item() * args.gradient_accumulation_steps / divisor,
                                  "learning_rate": now_lr,
                                  "seq/s": training_perf,
                                  "global_steps": now_step,
                                  "samples_trained": samples_trained,
                                  "skipped_steps": now_skipped,
                                  "timestamp": now_time})

                    average_loss = 0

                if global_step >= args.max_steps or end_training:
                    status = 'success' if converged else 'aborted'
                    end_training = True
                    train_time_raw = time.time() - raw_train_start
                    average_loss = torch.tensor(average_loss, dtype=torch.float32)
                    if args.log_freq > 0:
                        last_num_steps = int(training_steps / args.gradient_accumulation_steps) % args.log_freq
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = average_loss / (last_num_steps * divisor)
                    if (torch.distributed.is_initialized()):
                        average_loss /= args.world_size
                        torch.distributed.all_reduce(average_loss)
                    final_loss = average_loss.item()
                    if utils.is_main_process():
                        if track_mlm_window:
                            print((epoch, training_steps / args.gradient_accumulation_steps, ), {"final_loss": final_loss,
                                "final_mlm_accuracy": avg_mlm_accuracy[0].item()})
                        else:
                            print((epoch, training_steps / args.gradient_accumulation_steps, ), {"final_loss": final_loss})

                if end_training or (samples_trained - samples_trained_prev >= args.num_samples_per_checkpoint and samples_trained >= args.min_samples_to_start_checkpoints):
                    samples_trained_prev = samples_trained
                    if utils.is_main_process() and not args.skip_checkpoint:
                        # Save a trained model
                        model_to_save = model.module if hasattr(model,
                                                                'module') else model  # Only save the model it-self
                        if args.phase2:
                            output_save_file = os.path.join(args.output_dir, "phase2_ckpt_{}.pt".format(samples_trained))
                        else:
                            output_save_file = os.path.join(args.output_dir, "phase1_ckpt_{}.pt".format(samples_trained))
                        if args.do_train:
                            # NOTE(review): ``amp`` is not defined in this
                            # function; presumably apex.amp imported at file
                            # level -- verify it is available in this
                            # CPU/IPEX setup.
                            torch.save({'model': model_to_save.state_dict(),
                                        'optimizer': optimizer.state_dict(),
                                        'master params': list(amp.master_params(optimizer)),
                                        'files': [f_id] + files}, output_save_file)

                            # Keep only the newest N checkpoints on disk.
                            most_recent_ckpts_paths.append(output_save_file)
                            if len(most_recent_ckpts_paths) > args.keep_n_most_recent_checkpoints:
                                ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                                os.remove(ckpt_to_be_removed)

                    if samples_trained >= args.max_samples_termination or end_training:
                        status = 'success' if converged else 'aborted'
                        end_training = True
                        break
                t0 = time.time()

            del train_dataloader

            if samples_trained >= args.max_samples_termination or end_training:
                status = 'success' if converged else 'aborted'
                end_training = True
                break

            if not need_next_training_shard:
                dataset_future = pool.submit(create_pretraining_dataset, data_file, args.max_predictions_per_seq, shared_file_list, args, worker_init_fn=worker_init)
            train_dataloader, data_file = dataset_future.result(timeout=None)
        epoch += 1

    return args, final_loss, train_time_raw
示例#12
0
文件: train.py 项目: IntelAI/models
def main(args):
    """Build the RNN-T training pipeline from ``args`` and hand it to ``train``.

    Seeds the random number generators, optionally initialises
    ``torch.distributed`` when launched with ``PMI_SIZE > 1``, loads the model
    TOML, builds the audio preprocessing and augmentation modules, the
    training and evaluation data layers, the ``RNNT`` model, the loss, the
    learning-rate policy and the optimizer, optionally restores a checkpoint
    and applies IPEX optimisation, and finally calls ``train``.

    Args:
        args: Parsed command-line namespace. It is mutated in place:
            ``local_rank`` and ``start_epoch`` are always written,
            ``step_per_epoch`` is written for the ``'default'`` and
            ``'bucket'`` samplers, and ``rank`` / ``world_size`` are written
            on the distributed path.

    Raises:
        ValueError: If ``gradient_accumulation_steps`` is below 1 or does not
            divide ``batch_size``, or if ``optimizer_kind`` is unknown.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Environment values are strings; normalise to int so that the
    # ``args.local_rank == 0`` check further down can actually match.
    local_rank = os.environ.get('LOCAL_RANK', args.local_rank)
    args.local_rank = int(local_rank) if local_rank is not None else None

    # set up distributed training
    cpu_distributed_training = False
    if torch.distributed.is_available() and int(os.environ.get('PMI_SIZE', '0')) > 1:
        print('Distributed training with DDP')
        # Translate the PMI_* launcher variables into the names read by
        # torch.distributed's env:// initialisation.
        os.environ['RANK'] = os.environ.get('PMI_RANK', '0')
        os.environ['WORLD_SIZE'] = os.environ.get('PMI_SIZE', '1')
        if 'MASTER_ADDR' not in os.environ:
            os.environ['MASTER_ADDR'] = args.master_addr
        if 'MASTER_PORT' not in os.environ:
            # os.environ only accepts str values; the port may be parsed as int.
            os.environ['MASTER_PORT'] = str(args.port)

        # Initialize the process group with ccl backend
        if args.backend == 'ccl':
            # Imported only for its side effect (presumably registering the
            # 'ccl' backend with torch.distributed -- confirm).
            import torch_ccl  # noqa: F401
        dist.init_process_group(
                backend=args.backend
        )
        cpu_distributed_training = True
        if torch.distributed.is_initialized():
            print("Torch distributed is initialized.")
            args.rank = torch.distributed.get_rank()
            args.world_size = torch.distributed.get_world_size()
        else:
            print("Torch distributed is not initialized.")
            args.rank = 0
            args.world_size = 1

    # Hard-coded off: every ``multi_gpu`` branch below is therefore inactive.
    multi_gpu = False
    if multi_gpu:
        print_once("DISTRIBUTED TRAINING with {} gpus".format(torch.distributed.get_world_size()))

    optim_level = Optimization.mxprO0

    model_definition = toml.load(args.model_toml)
    dataset_vocab = model_definition['labels']['labels']
    ctc_vocab = add_blank_label(dataset_vocab)

    train_manifest = args.train_manifest
    val_manifest = args.val_manifest
    tst_manifest = args.tst_manifest
    featurizer_config = model_definition['input']
    featurizer_config_eval = model_definition['input_eval']
    featurizer_config["optimization_level"] = optim_level
    featurizer_config_eval["optimization_level"] = optim_level

    sampler_type = featurizer_config.get("sampler", 'default')
    perturb_config = model_definition.get('perturb', None)
    if args.pad_to_max:
        assert(args.max_duration > 0)
        featurizer_config['max_duration'] = args.max_duration
        featurizer_config_eval['max_duration'] = args.max_duration
        featurizer_config['pad_to'] = "max"
        featurizer_config_eval['pad_to'] = "max"
    print_once('model_config')
    print_dict(model_definition)

    if args.gradient_accumulation_steps < 1:
        raise ValueError('Invalid gradient accumulation steps parameter {}'.format(args.gradient_accumulation_steps))
    if args.batch_size % args.gradient_accumulation_steps != 0:
        # The per-step batch is batch_size // gradient_accumulation_steps, so
        # it is the batch size that must be a multiple of the step count.
        raise ValueError('batch size {} is not divisible by gradient accumulation steps {}'.format(args.batch_size, args.gradient_accumulation_steps))

    preprocessor = preprocessing.AudioPreprocessing(**featurizer_config)
    if args.cuda:
        preprocessor.cuda()
    else:
        preprocessor.cpu()

    augmentations = preprocessing.SpectrogramAugmentation(**featurizer_config)
    if args.cuda:
        augmentations.cuda()
    else:
        augmentations.cpu()

    # Each stage receives and returns a list; the first two entries are fed
    # to the preprocessor and the remaining entries are passed through.
    train_transforms = torchvision.transforms.Compose([
        lambda xs: [x.cpu() for x in xs],
        lambda xs: [*preprocessor(xs[0:2]), *xs[2:]],
        lambda xs: [augmentations(xs[0]),   *xs[1:]],
        lambda xs: [xs[0].permute(2, 0, 1), *xs[1:]],
    ])

    # Same as the training pipeline minus the augmentation stage.
    eval_transforms = torchvision.transforms.Compose([
        lambda xs: [x.cpu() for x in xs],
        lambda xs: [*preprocessor(xs[0:2]), *xs[2:]],
        lambda xs: [xs[0].permute(2, 0, 1), *xs[1:]],
    ])

    data_layer = AudioToTextDataLayer(
                                    dataset_dir=args.dataset_dir,
                                    featurizer_config=featurizer_config,
                                    perturb_config=perturb_config,
                                    manifest_filepath=train_manifest,
                                    labels=dataset_vocab,
                                    batch_size=args.batch_size // args.gradient_accumulation_steps,
                                    multi_gpu=multi_gpu,
                                    pad_to_max=args.pad_to_max,
                                    sampler=sampler_type,
                                    cpu_distributed_training=cpu_distributed_training)

    # Each entry is (data layer, evaluation frequency, display name).
    eval_datasets = [(
        AudioToTextDataLayer(
            dataset_dir=args.dataset_dir,
            featurizer_config=featurizer_config_eval,
            manifest_filepath=val_manifest,
            labels=dataset_vocab,
            batch_size=args.eval_batch_size,
            multi_gpu=multi_gpu,
            pad_to_max=args.pad_to_max
        ),
        args.eval_frequency,
        'Eval clean',
    )]

    if tst_manifest:
        eval_datasets.append((
            AudioToTextDataLayer(
                dataset_dir=args.dataset_dir,
                featurizer_config=featurizer_config_eval,
                manifest_filepath=tst_manifest,
                labels=dataset_vocab,
                batch_size=args.eval_batch_size,
                multi_gpu=multi_gpu,
                pad_to_max=args.pad_to_max
            ),
            args.test_frequency,
            'Test other',
        ))

    model = RNNT(
        feature_config=featurizer_config,
        rnnt=model_definition['rnnt'],
        num_classes=len(ctc_vocab)
    )

    if args.ckpt is not None:
        print_once("loading model from {}".format(args.ckpt))
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        args.start_epoch = checkpoint['epoch']
    else:
        args.start_epoch = 0

    # The blank symbol is the last index of the extended vocabulary.
    loss_fn = RNNTLoss(blank=len(ctc_vocab) - 1)

    N = len(data_layer)
    if sampler_type == 'default':
        args.step_per_epoch = math.ceil(N / (args.batch_size * (1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size())))
    elif sampler_type == 'bucket':
        args.step_per_epoch = int(len(data_layer.sampler) / args.batch_size )

    print_once('-----------------')
    print_once('Have {0} examples to train on.'.format(N))
    print_once('Have {0} steps / (gpu * epoch).'.format(args.step_per_epoch))
    print_once('-----------------')

    # Learning-rate policy: start from a constant rate and wrap it with
    # decay and/or warmup. Each wrapper calls the policy that was current
    # when it was installed.
    constant_lr_policy = lambda _: args.lr
    fn_lr_policy = constant_lr_policy
    if args.lr_decay:
        pre_decay_policy = fn_lr_policy
        fn_lr_policy = lambda s: lr_decay(args.num_epochs * args.step_per_epoch, s, pre_decay_policy(s))
    if args.lr_warmup:
        pre_warmup_policy = fn_lr_policy
        fn_lr_policy = lambda s: lr_warmup(args.lr_warmup, s, pre_warmup_policy(s) )

    if args.optimizer_kind == "novograd":
        optimizer = Novograd(model.parameters(),
                        lr=args.lr,
                        weight_decay=args.weight_decay)
    elif args.optimizer_kind == "adam":
        optimizer = AdamW(model.parameters(),
                        lr=args.lr,
                        weight_decay=args.weight_decay)
    else:
        raise ValueError("invalid optimizer choice: {}".format(args.optimizer_kind))

    if args.cuda and optim_level in AmpOptimizations:
        assert False, "not supported in ipex"

    if args.ckpt is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    if args.ipex:
        if args.bf16:
            model, optimizer = ipex.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
            ipex.nn.utils._model_convert.replace_lstm_with_ipex_lstm(model)
        else:
            model, optimizer = ipex.optimize(model, dtype=torch.float32, optimizer=optimizer)
            ipex.nn.utils._model_convert.replace_lstm_with_ipex_lstm(model)

    # NOTE(review): args.world_size is only assigned above on the distributed
    # path; presumably the argument parser supplies a default -- confirm.
    if args.world_size > 1:
        device_ids = None
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)

    print_once(model)
    print_once("# parameters: {}".format(sum(p.numel() for p in model.parameters())))
    # NOTE(review): multi_gpu is hard-coded False, so when the model was
    # wrapped in DistributedDataParallel above the wrapper itself is passed
    # to the decoder and logger -- confirm they accept it.
    greedy_decoder = RNNTGreedyDecoder(len(ctc_vocab) - 1, model.module if multi_gpu else model)

    if args.tb_path and args.local_rank == 0:
        logger = TensorBoardLogger(args.tb_path, model.module if multi_gpu else model, args.histogram)
    else:
        logger = DummyLogger()

    train(
        data_layer=data_layer,
        model=model,
        loss_fn=loss_fn,
        greedy_decoder=greedy_decoder,
        optimizer=optimizer,
        data_transforms=train_transforms,
        labels=ctc_vocab,
        optim_level=optim_level,
        multi_gpu=multi_gpu,
        fn_lr_policy=fn_lr_policy,
        # ``evalutaion`` (sic) is the keyword name used by this call to train().
        evalutaion=evaluator(model, eval_transforms, loss_fn, greedy_decoder, ctc_vocab, eval_datasets, logger),
        logger=logger,
        args=args)
示例#13
0
def inference(model, dataloader, datatype, args):
    """Benchmark (and, with a dataloader, score) an image model on CPU.

    The model is prepared according to ``args.ipex`` / ``args.jit``
    (IPEX-optimised, channels-last, or converted with ``to_mkldnn``),
    optionally traced and frozen with TorchScript, then run either on random
    synthetic batches (``dataloader is None``) or on the given dataloader.
    The iteration with index ``args.warmup_iterations`` is executed under the
    CPU profiler. Profiler output (when a profiled iteration ran), latency,
    throughput and top-1 accuracy are printed; nothing is returned.

    Args:
        model: Module to evaluate; it is put in eval mode.
        dataloader: Iterable of ``(images, target)`` batches, or ``None`` to
            use random input of size ``(batch_size, 3, height, width)``.
        datatype: dtype passed to ``ipex.optimize`` / ``to_mkldnn`` and used
            to cast inputs on the non-IPEX, non-JIT path.
        args: Namespace providing ``batch_size``, ``warmup_iterations``,
            ``max_iterations``, ``ipex``, ``jit``, ``precision``, ``height``,
            ``width`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    batch_size = args.batch_size
    warmup_iters = args.warmup_iterations
    max_iters = args.max_iterations if dataloader is None else len(dataloader)
    progress = ProgressMeter(max_iters, [batch_time, losses, top1, top5],
                             prefix='Test: ')
    model.eval()
    if args.ipex:
        import intel_extension_for_pytorch as ipex
        model = model.to(memory_format=torch.channels_last)
        model = ipex.optimize(model, dtype=datatype, level="O1")
    elif args.jit:
        model = model.to(memory_format=torch.channels_last)
    else:
        from torch.utils import mkldnn as mkldnn_utils
        model = mkldnn_utils.to_mkldnn(model, dtype=datatype)

    if args.jit:
        if dataloader is None:
            x = torch.randn(batch_size, 3, args.height, args.width)
        else:
            # Only the shape of the first real batch is used for tracing.
            x = torch.randn(next(iter(dataloader))[0].shape)
        x = x.to(memory_format=torch.channels_last)
        # Trace under autocast only for bf16; freezing is common to both.
        with torch.cpu.amp.autocast(enabled=args.precision == "bf16"), torch.no_grad():
            model = torch.jit.trace(model, x, strict=False)
        model = torch.jit.freeze(model)

    # Eager bf16 with IPEX needs autocast around every forward pass.
    use_autocast = bool(args.ipex) and args.precision == "bf16" and not args.jit
    # Stays None when the profiled iteration is never reached (for example a
    # dataloader shorter than the warm-up), so the report below can skip it
    # instead of failing on an unbound name.
    prof = None

    def run_model(images, profiled=False):
        """Run one forward pass, optionally under the CPU profiler."""
        nonlocal prof
        with torch.cpu.amp.autocast(enabled=use_autocast):
            if not profiled:
                return model(images)
            with profile(
                    activities=[ProfilerActivity.CPU],
                    record_shapes=True
            ) as prof, record_function("model_inference"):
                return model(images)

    with torch.no_grad():
        if dataloader is None:
            for i in range(max_iters):
                images = torch.randn(batch_size, 3, args.height, args.width)
                # Only iterations after the profiled one are timed.
                if i > warmup_iters:
                    end = time.time()
                if not args.ipex and not args.jit:
                    images = images.to(datatype)
                else:
                    images = images.to(memory_format=torch.channels_last)
                run_model(images, profiled=(i == warmup_iters))
                if i > warmup_iters:
                    batch_time.update(time.time() - end)
                if i % args.print_freq == 0:
                    progress.display(i)
        else:
            # warm up
            for i, (images, target) in enumerate(dataloader):
                if i > warmup_iters:
                    break
                if not args.ipex and not args.jit:
                    images = images.to(datatype).to(
                        memory_format=torch.channels_last)
                run_model(images, profiled=(i == warmup_iters))

            criterion = nn.CrossEntropyLoss()
            for i, (images, target) in enumerate(dataloader):
                end = time.time()
                if not args.ipex and not args.jit:
                    images = images.to(datatype).to(
                        memory_format=torch.channels_last)
                # Uses the same autocast setting as the warm-up pass, so the
                # timed bf16 run matches what was warmed up and profiled.
                output = run_model(images)
                batch_time.update(time.time() - end)
                if args.precision == "bf16":
                    output = output.to(torch.float32)
                loss = criterion(output, target)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                # NOTE(review): max_iters equals len(dataloader) on this
                # path, so this limit cannot trigger -- confirm whether
                # args.max_iterations was intended here.
                if max_iters != -1 and i >= max_iters:
                    break
                if i % args.print_freq == 0:
                    progress.display(i)
    if prof is not None:
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=-1))
    latency = batch_time.avg / batch_size * 1000
    perf = batch_size / batch_time.avg
    print('Latency: %.3f ms' % latency)
    print("Throughput: {:.3f} fps".format(perf))
    print("Accuracy: {top1.avg:.3f} ".format(top1=top1))
示例#14
0
def inference(model, dataloader, args):
    """Benchmark ``model`` on random synthetic input and print latency/throughput.

    The model is prepared according to ``args.ipex`` / ``args.jit``
    (IPEX-optimised, channels-last-3d, or converted with ``to_mkldnn``) and
    optionally traced and frozen. ``args.warmup_iterations + 1`` untimed
    passes follow, the last of them under the CPU profiler, and then
    ``args.max_iterations`` timed passes. Inputs are random integers in
    ``[0, 255)`` of size ``(batch_size, 100, 27, 48, 3)``.

    Args:
        model: Module to benchmark; it is put in eval mode.
        dataloader: Accepted but not read by this function.
        args: Namespace providing ``batch_size``, ``warmup_iterations``,
            ``max_iterations``, ``precision``, ``ipex`` and ``jit``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    batch_size = args.batch_size
    # The next two locals are not read anywhere in this function.
    total_stats = {"tp": 0, "fp": 0, "fn": 0}
    n_warmup = args.warmup_iterations
    n_timed = args.max_iterations
    logit_fc = torch.sigmoid
    bf16 = args.precision == 'bf16'
    datatype = torch.bfloat16 if bf16 else torch.float32
    model.eval()

    # Prepare the model for the selected execution path.
    if args.ipex:
        import intel_extension_for_pytorch as ipex
        model = model.to(memory_format=torch.channels_last_3d)
        model = ipex.optimize(model, dtype=datatype, inplace=True)
    elif args.jit:
        model = model.to(memory_format=torch.channels_last_3d)
    else:
        from torch.utils import mkldnn as mkldnn_utils
        model = mkldnn_utils.to_mkldnn(model, dtype=datatype)

    clip_shape = (args.batch_size, 100, 27, 48, 3)

    if args.jit:
        example = torch.randint(0, 255, clip_shape, dtype=datatype)
        example = example.to(memory_format=torch.channels_last_3d)
        with torch.no_grad():
            if bf16:
                with torch.cpu.amp.autocast():
                    traced = torch.jit.trace(model, example, strict=False).eval()
            else:
                traced = torch.jit.trace(model, example, strict=False).eval()
        model = torch.jit.freeze(traced)

    # bf16 with IPEX runs every forward pass under autocast.
    use_autocast = args.ipex and bf16

    with torch.no_grad():
        # Warm-up: the final pass (step == n_warmup) is the profiled one.
        for step in range(n_warmup + 1):
            batch = torch.randint(0, 255, clip_shape, dtype=datatype)
            batch = batch.to(memory_format=torch.channels_last_3d)
            profiled = step == n_warmup
            if use_autocast:
                with torch.cpu.amp.autocast():
                    if profiled:
                        with profile(activities=[ProfilerActivity.CPU],
                                     record_shapes=True) as prof:
                            output = model(batch)
                    else:
                        output = model(batch)
            elif profiled:
                with profile(activities=[ProfilerActivity.CPU],
                             record_shapes=True) as prof:
                    output = model(batch)
            else:
                output = model(batch)

        # Timed passes: the clock covers the dtype/layout conversion and the
        # forward pass, but not the random input generation.
        for _ in range(n_timed):
            batch = torch.randint(0, 255, clip_shape)
            tic = time.time()
            batch = batch.to(datatype).to(
                memory_format=torch.channels_last_3d)
            if use_autocast:
                with torch.cpu.amp.autocast():
                    output = model(batch)
            else:
                output = model(batch)
            batch_time.update(time.time() - tic)
            if isinstance(output, tuple):
                output = output[0]
            output = output.to(torch.float32)

    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=-1))
    latency = batch_time.avg / batch_size * 1000
    perf = batch_size / batch_time.avg
    print('Latency: %.3f ms' % latency)
    print("Throughput: {:.3f} fps".format(perf))
示例#15
0
def coco_eval(model, val_dataloader, cocoGt, encoder, inv_map, args):
    from pycocotools.cocoeval import COCOeval
    device = args.device
    threshold = args.threshold
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    model.eval()

    ret = []

    inference_time = AverageMeter('InferenceTime', ':6.3f')
    decoding_time = AverageMeter('DecodingTime', ':6.3f')

    if (args.calibration or args.accuracy_mode or args.throughput_mode
            or args.latency_mode) is False:
        print(
            "one of --calibration, --accuracy-mode, --throughput-mode, --latency-mode must be input."
        )
        exit(-1)

    if args.accuracy_mode:
        if args.iteration is not None:
            print("accuracy mode should not input iteration")
        progress_meter_iteration = len(val_dataloader)
        epoch_number = 1
    else:
        if args.iteration is None:
            print("None accuracy mode must input --iteration")
            exit(-1)
        progress_meter_iteration = args.iteration
        epoch_number = (args.iteration // len(val_dataloader) +
                        1) if (args.iteration > len(val_dataloader)) else 1

    progress = ProgressMeter(progress_meter_iteration,
                             [inference_time, decoding_time],
                             prefix='Test: ')

    Profilling_iterator = 99
    start = time.time()
    if args.int8:
        model = model.eval()
        print('int8 conv_bn_fusion enabled')
        with torch.no_grad():
            model.model = optimization.fuse(model.model, inplace=False)

            if args.calibration:
                print("runing int8 LLGA calibration step\n")
                conf = ipex.quantization.QuantConf(
                    qscheme=torch.per_tensor_affine
                )  # qscheme can is torch.per_tensor_affine, torch.per_tensor_symmetric
                with torch.no_grad():
                    for nbatch, (img, img_id, img_size, bbox,
                                 label) in enumerate(val_dataloader):
                        print("nbatch:{}".format(nbatch))
                        with ipex.quantization.calibrate(conf):
                            ploc, plabel = model(img)
                        if nbatch == args.iteration:
                            break
                conf.save(args.configure)
                return

            else:
                print("INT8 LLGA start trace")
                # insert quant/dequant based on configure.json
                conf = ipex.quantization.QuantConf(
                    configure_file=args.configure)
                model = ipex.quantization.convert(
                    model, conf,
                    torch.randn(args.batch_size, 3, 1200,
                                1200).to(memory_format=torch.channels_last))
                print("done ipex default recipe.......................")
                # freeze the module
                # model = torch.jit._recursive.wrap_cpp_module(torch._C._freeze_module(model._c, preserveParameters=True))

                # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
                # At the 2nd run, the llga pass will be triggered and the model is turned into an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
                with torch.no_grad():
                    for i in range(2):
                        #_, _ = model(torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last))
                        _, _ = model(
                            torch.randn(
                                args.batch_size, 3, 1200,
                                1200).to(memory_format=torch.channels_last))

                print('runing int8 real inputs inference path')
                with torch.no_grad():
                    total_iteration = 0
                    for epoch in range(epoch_number):
                        for nbatch, (img, img_id, img_size, bbox,
                                     label) in enumerate(val_dataloader):
                            img = img.to(memory_format=torch.channels_last)
                            if total_iteration >= args.warmup_iterations:
                                start_time = time.time()

                            if args.profile and total_iteration == Profilling_iterator:
                                print("Profilling")
                                with torch.profiler.profile(
                                        on_trace_ready=torch.profiler.
                                        tensorboard_trace_handler(
                                            "./int8_log")) as prof:
                                    # ploc, plabel = model(img.to(memory_format=torch.channels_last))
                                    ploc, plabel = model(img)
                                print(prof.key_averages().table(
                                    sort_by="self_cpu_time_total"))
                            else:
                                #ploc, plabel = model(img.to(memory_format=torch.channels_last))
                                ploc, plabel = model(img)

                            if total_iteration >= args.warmup_iterations:
                                inference_time.update(time.time() - start_time)
                                end_time = time.time()
                            try:
                                results_raw = encoder.decode_batch(
                                    ploc, plabel, 0.50, 200, device=device)
                            except:
                                print(
                                    "No object detected in nbatch: {}".format(
                                        total_iteration))
                                continue
                            if total_iteration >= args.warmup_iterations:
                                decoding_time.update(time.time() - end_time)

                            # Re-assembly the result
                            results = []
                            idx = 0
                            for i in range(results_raw[3].size(0)):
                                results.append(
                                    (results_raw[0][idx:idx +
                                                    results_raw[3][i]],
                                     results_raw[1][idx:idx +
                                                    results_raw[3][i]],
                                     results_raw[2][idx:idx +
                                                    results_raw[3][i]]))
                                idx += results_raw[3][i]

                            (htot, wtot) = [d.cpu().numpy() for d in img_size]
                            img_id = img_id.cpu().numpy()
                            # Iterate over batch elements
                            for img_id_, wtot_, htot_, result in zip(
                                    img_id, wtot, htot, results):
                                loc, label, prob = [
                                    r.cpu().numpy() for r in result
                                ]
                                # Iterate over image detections
                                for loc_, label_, prob_ in zip(
                                        loc, label, prob):
                                    ret.append([img_id_, loc_[0]*wtot_, \
                                                loc_[1]*htot_,
                                                (loc_[2] - loc_[0])*wtot_,
                                                (loc_[3] - loc_[1])*htot_,
                                                prob_,
                                                inv_map[label_]])

                            if total_iteration % args.print_freq == 0:
                                progress.display(total_iteration)
                            if total_iteration == args.iteration:
                                break
                            total_iteration += 1
    else:
        if args.dummy:
            print('dummy inputs inference path is not supported')
        else:
            print('runing real inputs path')
            if args.autocast:
                print('bf16 autocast enabled')
                print('enable nhwc')
                model = model.to(memory_format=torch.channels_last)
                if use_ipex:
                    print('bf16 block format weights cache enabled')
                    model.model = ipex.optimize(model.model,
                                                dtype=torch.bfloat16,
                                                inplace=False)
                    # model = ipex.utils._convert_module_data_type(model, torch.bfloat16)
                else:
                    from oob_utils import conv_bn_fuse
                    print('OOB bf16 conv_bn_fusion enabled')
                    model.model = conv_bn_fuse(model.model)

                if args.jit:
                    if use_ipex:
                        print('enable IPEX jit path')
                        with torch.cpu.amp.autocast(), torch.no_grad():
                            model = torch.jit.trace(
                                model,
                                torch.randn(args.batch_size, 3, 1200, 1200).to(
                                    memory_format=torch.channels_last)).eval()
                    else:
                        print('enable OOB jit path')
                        with torch.cpu.amp.autocast(
                                cache_enabled=False), torch.no_grad():
                            model = torch.jit.trace(
                                model,
                                torch.randn(args.batch_size, 3, 1200, 1200).to(
                                    memory_format=torch.channels_last)).eval()

                    model = torch.jit.freeze(model)
                    with torch.no_grad():
                        total_iteration = 0
                        for epoch in range(epoch_number):
                            for nbatch, (img, img_id, img_size, bbox,
                                         label) in enumerate(val_dataloader):
                                with torch.no_grad():
                                    if total_iteration >= args.warmup_iterations:
                                        start_time = time.time()

                                    img = img.to(
                                        memory_format=torch.channels_last)
                                    if args.profile and total_iteration == Profilling_iterator:
                                        print("Profilling")
                                        with torch.profiler.profile(
                                                on_trace_ready=torch.profiler.
                                                tensorboard_trace_handler(
                                                    "./log")) as prof:
                                            ploc, plabel = model(img)
                                        print(prof.key_averages().table(
                                            sort_by="self_cpu_time_total"))
                                    else:
                                        ploc, plabel = model(img)
                                    if total_iteration >= args.warmup_iterations:
                                        inference_time.update(time.time() -
                                                              start_time)
                                        end_time = time.time()

                                    try:
                                        if args.profile and total_iteration == Profilling_iterator:
                                            with torch.profiler.profile(
                                                    on_trace_ready=torch.
                                                    profiler.
                                                    tensorboard_trace_handler(
                                                        "./decode_log"
                                                    )) as prof:
                                                results_raw = encoder.decode_batch(
                                                    ploc,
                                                    plabel,
                                                    0.50,
                                                    200,
                                                    device=device)
                                            print(prof.key_averages().table(
                                                sort_by="self_cpu_time_total"))
                                        else:
                                            results_raw = encoder.decode_batch(
                                                ploc,
                                                plabel,
                                                0.50,
                                                200,
                                                device=device)
                                    except:
                                        print(
                                            "No object detected in nbatch: {}".
                                            format(total_iteration))
                                        continue
                                    if total_iteration >= args.warmup_iterations:
                                        decoding_time.update(time.time() -
                                                             end_time)

                                    # Re-assembly the result
                                    results = []
                                    idx = 0
                                    for i in range(results_raw[3].size(0)):
                                        results.append((
                                            results_raw[0][idx:idx +
                                                           results_raw[3][i]],
                                            results_raw[1][idx:idx +
                                                           results_raw[3][i]],
                                            results_raw[2][idx:idx +
                                                           results_raw[3][i]]))
                                        idx += results_raw[3][i]

                                    (htot, wtot) = [
                                        d.cpu().numpy() for d in img_size
                                    ]
                                    img_id = img_id.cpu().numpy()

                                    for img_id_, wtot_, htot_, result in zip(
                                            img_id, wtot, htot, results):
                                        loc, label, prob = [
                                            r.cpu().numpy() for r in result
                                        ]
                                        # Iterate over image detections
                                        for loc_, label_, prob_ in zip(
                                                loc, label, prob):
                                            ret.append([img_id_, loc_[0]*wtot_, \
                                                        loc_[1]*htot_,
                                                        (loc_[2] - loc_[0])*wtot_,
                                                        (loc_[3] - loc_[1])*htot_,
                                                        prob_,
                                                        inv_map[label_]])

                                    if total_iteration % args.print_freq == 0:
                                        progress.display(total_iteration)
                                    if total_iteration == args.iteration:
                                        break
                                    total_iteration += 1
                else:
                    if use_ipex:
                        print('Ipex Autocast imperative path')
                        with torch.cpu.amp.autocast(), torch.no_grad():
                            total_iteration = 0
                            for epoch in range(epoch_number):
                                for nbatch, (
                                        img, img_id, img_size, bbox,
                                        label) in enumerate(val_dataloader):
                                    with torch.no_grad():
                                        if total_iteration >= args.warmup_iterations:
                                            start_time = time.time()
                                        img = img.contiguous(
                                            memory_format=torch.channels_last)
                                        if args.profile and total_iteration == Profilling_iterator:
                                            print("Profilling")
                                            with torch.profiler.profile(
                                                    on_trace_ready=torch.
                                                    profiler.
                                                    tensorboard_trace_handler(
                                                        "./bf16_imperative_log"
                                                    )) as prof:
                                                ploc, plabel = model(img)
                                            print(prof.key_averages().table(
                                                sort_by="self_cpu_time_total"))
                                        else:
                                            ploc, plabel = model(img)

                                        if total_iteration >= args.warmup_iterations:
                                            inference_time.update(time.time() -
                                                                  start_time)
                                            end_time = time.time()

                                        try:
                                            results_raw = encoder.decode_batch(
                                                ploc,
                                                plabel,
                                                0.50,
                                                200,
                                                device=device)
                                        except:
                                            print(
                                                "No object detected in total_iteration: {}"
                                                .format(total_iteration))
                                            continue
                                        if total_iteration >= args.warmup_iterations:
                                            decoding_time.update(time.time() -
                                                                 end_time)

                                        # Reassemble the results per image
                                        results = []
                                        idx = 0
                                        for i in range(results_raw[3].size(0)):
                                            results.append(
                                                (results_raw[0]
                                                 [idx:idx + results_raw[3][i]],
                                                 results_raw[1]
                                                 [idx:idx + results_raw[3][i]],
                                                 results_raw[2]
                                                 [idx:idx +
                                                  results_raw[3][i]]))
                                            idx += results_raw[3][i]

                                        (htot, wtot) = [
                                            d.cpu().numpy() for d in img_size
                                        ]
                                        img_id = img_id.cpu().numpy()

                                        for img_id_, wtot_, htot_, result in zip(
                                                img_id, wtot, htot, results):
                                            loc, label, prob = [
                                                r.cpu().numpy() for r in result
                                            ]
                                            # Iterate over image detections
                                            for loc_, label_, prob_ in zip(
                                                    loc, label, prob):
                                                ret.append([img_id_, loc_[0]*wtot_, \
                                                            loc_[1]*htot_,
                                                            (loc_[2] - loc_[0])*wtot_,
                                                            (loc_[3] - loc_[1])*htot_,
                                                            prob_,
                                                            inv_map[label_]])

                                        if total_iteration % args.print_freq == 0:
                                            progress.display(total_iteration)
                                        if total_iteration == args.iteration:
                                            break
                                        total_iteration += 1
                    else:
                        print("OOB Autocast imperative path")
                        with torch.cpu.amp.autocast(), torch.no_grad():
                            total_iteration = 0
                            for epoch in range(epoch_number):
                                for nbatch, (
                                        img, img_id, img_size, bbox,
                                        label) in enumerate(val_dataloader):
                                    if total_iteration >= args.warmup_iterations:
                                        start_time = time.time()
                                    img = img.contiguous(
                                        memory_format=torch.channels_last)
                                    if args.profile and total_iteration == Profilling_iterator:
                                        print("Profilling")
                                        with torch.profiler.profile(
                                                on_trace_ready=torch.profiler.
                                                tensorboard_trace_handler(
                                                    "./bf16_oob_log")) as prof:
                                            ploc, plabel = model(img)
                                        print(prof.key_averages().table(
                                            sort_by="self_cpu_time_total"))
                                    else:
                                        ploc, plabel = model(img)

                                    if total_iteration >= args.warmup_iterations:
                                        inference_time.update(time.time() -
                                                              start_time)
                                        end_time = time.time()

                                    with torch.cpu.amp.autocast(enabled=False):
                                        try:
                                            results_raw = encoder.decode_batch(
                                                ploc,
                                                plabel,
                                                0.50,
                                                200,
                                                device=device)
                                        except:
                                            print(
                                                "No object detected in total_iteration: {}"
                                                .format(total_iteration))
                                            continue
                                        if total_iteration >= args.warmup_iterations:
                                            decoding_time.update(time.time() -
                                                                 end_time)

                                        # Reassemble the results per image
                                        results = []
                                        idx = 0
                                        for i in range(results_raw[3].size(0)):
                                            results.append(
                                                (results_raw[0]
                                                 [idx:idx + results_raw[3][i]],
                                                 results_raw[1]
                                                 [idx:idx + results_raw[3][i]],
                                                 results_raw[2]
                                                 [idx:idx +
                                                  results_raw[3][i]]))
                                            idx += results_raw[3][i]

                                        (htot, wtot) = [
                                            d.cpu().numpy() for d in img_size
                                        ]
                                        img_id = img_id.cpu().numpy()

                                        for img_id_, wtot_, htot_, result in zip(
                                                img_id, wtot, htot, results):
                                            loc, label, prob = [
                                                r.cpu().numpy() for r in result
                                            ]
                                            # Iterate over image detections
                                            for loc_, label_, prob_ in zip(
                                                    loc, label, prob):
                                                ret.append([img_id_, loc_[0]*wtot_, \
                                                            loc_[1]*htot_,
                                                            (loc_[2] - loc_[0])*wtot_,
                                                            (loc_[3] - loc_[1])*htot_,
                                                            prob_,
                                                            inv_map[label_]])

                                        if total_iteration % args.print_freq == 0:
                                            progress.display(total_iteration)
                                        if total_iteration == args.iteration:
                                            break
                                        total_iteration += 1
            else:
                print('autocast disabled, fp32 is used')
                print('enable nhwc')
                model = model.to(memory_format=torch.channels_last)
                if use_ipex:
                    print('fp32 block format weights cache enabled')
                    model.model = ipex.optimize(model.model,
                                                dtype=torch.float32,
                                                inplace=False)
                if args.jit:
                    print("enable jit")
                    with torch.no_grad():
                        model = torch.jit.trace(
                            model,
                            torch.randn(
                                args.batch_size, 3, 1200,
                                1200).to(memory_format=torch.channels_last))
                    model = torch.jit.freeze(model)
                with torch.no_grad():
                    total_iteration = 0
                    for epoch in range(epoch_number):
                        for nbatch, (img, img_id, img_size, bbox,
                                     label) in enumerate(val_dataloader):
                            if total_iteration >= args.warmup_iterations:
                                start_time = time.time()

                            img = img.contiguous(
                                memory_format=torch.channels_last)
                            if args.profile and total_iteration == Profilling_iterator:
                                print("Profilling")
                                with torch.profiler.profile(
                                        on_trace_ready=torch.profiler.
                                        tensorboard_trace_handler(
                                            "./fp32_log")) as prof:
                                    ploc, plabel = model(img)
                                print(prof.key_averages().table(
                                    sort_by="self_cpu_time_total"))
                            else:
                                ploc, plabel = model(img)
                            if total_iteration >= args.warmup_iterations:
                                inference_time.update(time.time() - start_time)
                                end_time = time.time()
                            try:
                                if args.profile and total_iteration == Profilling_iterator:
                                    with torch.profiler.profile(
                                            on_trace_ready=torch.profiler.
                                            tensorboard_trace_handler(
                                                "./fp32_decode_log")) as prof:
                                        results_raw = encoder.decode_batch(
                                            ploc,
                                            plabel,
                                            0.50,
                                            200,
                                            device=device)
                                    print(prof.key_averages().table(
                                        sort_by="self_cpu_time_total"))
                                else:
                                    results_raw = encoder.decode_batch(
                                        ploc, plabel, 0.50, 200, device=device)
                            except:
                                print(
                                    "No object detected in total_iteration: {}"
                                    .format(total_iteration))
                                continue
                            if total_iteration >= args.warmup_iterations:
                                decoding_time.update(time.time() - end_time)

                            # Reassemble the results per image
                            results = []
                            idx = 0
                            for i in range(results_raw[3].size(0)):
                                results.append(
                                    (results_raw[0][idx:idx +
                                                    results_raw[3][i]],
                                     results_raw[1][idx:idx +
                                                    results_raw[3][i]],
                                     results_raw[2][idx:idx +
                                                    results_raw[3][i]]))
                                idx += results_raw[3][i]

                            (htot, wtot) = [d.cpu().numpy() for d in img_size]
                            img_id = img_id.cpu().numpy()
                            # Iterate over batch elements
                            for img_id_, wtot_, htot_, result in zip(
                                    img_id, wtot, htot, results):
                                loc, label, prob = [
                                    r.cpu().numpy() for r in result
                                ]
                                # Iterate over image detections
                                for loc_, label_, prob_ in zip(
                                        loc, label, prob):
                                    ret.append([img_id_, loc_[0]*wtot_, \
                                                loc_[1]*htot_,
                                                (loc_[2] - loc_[0])*wtot_,
                                                (loc_[3] - loc_[1])*htot_,
                                                prob_,
                                                inv_map[label_]])

                            if total_iteration % args.print_freq == 0:
                                progress.display(total_iteration)
                            if total_iteration == args.iteration:
                                break
                            total_iteration += 1
    print("Predicting Ended, total time: {:.2f} s".format(time.time() - start))

    batch_size = args.batch_size
    latency = inference_time.avg / batch_size * 1000
    perf = batch_size / inference_time.avg
    print('inference latency %.2f ms' % latency)
    print('inference performance %.2f fps' % perf)

    if not args.dummy:
        latency = decoding_time.avg / batch_size * 1000
        perf = batch_size / decoding_time.avg
        print('decoding latency %.2f ms' % latency)
        print('decodingperformance %.2f fps' % perf)

        total_time_avg = inference_time.avg + decoding_time.avg
        throughput = batch_size / total_time_avg
        print("Throughput: {:.3f} fps".format(throughput))

        if not args.accuracy_mode:
            return True

        cocoDt = cocoGt.loadRes(np.array(ret))

        E = COCOeval(cocoGt, cocoDt, iouType='bbox')
        E.evaluate()
        E.accumulate()
        E.summarize()
        print("Current AP: {:.5f} AP goal: {:.5f}".format(
            E.stats[0], threshold))
        print("Accuracy: {:.5f} ".format(E.stats[0]))

        return (
            E.stats[0] >= threshold
        )  # Average Precision (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ]
    else:
        total_time_avg = inference_time.avg
        throughput = batch_size / total_time_avg
        print("Throughput: {:.3f} fps".format(throughput))
        return False