def test_skip_onnx_input_quantize_expected_exception():
    """A graph whose input is already quantized must be rejected.

    The graph input is declared as UINT8 and is consumed directly by a
    ``QLinearConv`` node; ``skip_onnx_input_quantize`` is expected to raise
    ``RuntimeError`` for such a model.
    """
    conv_input_names = ["input", "scale", "zp", "w", "w_scale", "w_zp", "y_scale", "y_zp"]

    uint8_graph_input = onnx.helper.make_tensor_value_info(
        "input", TensorProto.UINT8, [1, 3, None, None]
    )
    conv = onnx.helper.make_node("QLinearConv", conv_input_names, ["qconv_output"])
    uint8_graph_output = onnx.helper.make_tensor_value_info(
        "qconv_output", TensorProto.UINT8, [1, 1, None, None]
    )

    test_graph = onnx.helper.make_graph(
        [conv], "test_graph", [uint8_graph_input], [uint8_graph_output], []
    )
    quantized_model = onnx.helper.make_model(test_graph)

    with pytest.raises(RuntimeError):
        skip_onnx_input_quantize(quantized_model)


def test_skip_onnx_input_quantize():
    """FP32 input -> QuantizeLinear -> QLinearConv becomes UINT8 input -> QLinearConv.

    Builds the two-node graph, checks its initial layout, runs
    ``skip_onnx_input_quantize`` on the model in place and verifies that the
    QuantizeLinear node is gone and the (now UINT8) graph input feeds the
    QLinearConv directly.
    """
    fp32_graph_input = onnx.helper.make_tensor_value_info(
        "input", TensorProto.FLOAT, [1, 3, None, None]
    )
    uint8_graph_output = onnx.helper.make_tensor_value_info(
        "qconv_output", TensorProto.UINT8, [1, 1, None, None]
    )
    quantize = onnx.helper.make_node(
        "QuantizeLinear", ["input", "scale", "zp"], ["quant_output"]
    )
    conv = onnx.helper.make_node(
        "QLinearConv",
        ["quant_output", "scale", "zp", "w", "w_scale", "w_zp", "y_scale", "y_zp"],
        ["qconv_output"],
    )
    model = onnx.helper.make_model(
        onnx.helper.make_graph(
            [quantize, conv], "test_graph", [fp32_graph_input], [uint8_graph_output], []
        )
    )

    # before: FLOAT input -> QuantizeLinear -> QLinearConv
    before = model.graph
    assert before.input[0].type.tensor_type.elem_type == TensorProto.FLOAT
    assert [node.op_type for node in before.node] == ["QuantizeLinear", "QLinearConv"]
    assert before.node[0].input[0] == before.input[0].name
    assert before.node[1].input[0] == before.node[0].output[0]

    skip_onnx_input_quantize(model)

    # after: UINT8 input feeds QLinearConv directly; QuantizeLinear removed
    after = model.graph
    assert after.input[0].type.tensor_type.elem_type == TensorProto.UINT8
    assert [node.op_type for node in after.node] == ["QLinearConv"]
    assert after.node[0].input[0] == after.input[0].name
Example #3
0
def train(hyp, opt, device, tb_writer=None, wandb=None):
    """Train a YOLOv5 model under a SparseML recipe and export the result to ONNX.

    :param hyp: hyperparameter dict (``lr0``, ``momentum``, ``weight_decay``, ...).
        Several entries (``weight_decay``, ``box``, ``cls``, ``obj``) are rescaled
        in place.
    :param opt: parsed run options (``save_dir``, ``epochs``, ``weights``,
        ``sparseml_recipe``, ...)
    :param device: torch device to train on
    :param tb_writer: optional TensorBoard writer used for scalar/histogram logging
    :param wandb: optional ``wandb`` module; when given, runs, images and metrics
        are logged to Weights & Biases
    :return: the last computed results tuple
        (P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls))
    """
    logger.info(colorstr('hyperparameters: ') +
                ', '.join(f'{k}={v}' for k, v in hyp.items()))
    save_dir, epochs, batch_size, total_batch_size, weights, rank = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / 'last.pt'
    best = wdir / 'best.pt'
    results_file = save_dir / 'results.txt'

    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    plots = not opt.evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check

    # Model
    pretrained = weights.endswith(('.pt', '.pth'))  # SparseML integration
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        if hyp.get('anchors'):
            ckpt['model'].yaml['anchors'] = round(hyp['anchors'])  # force autoanchor
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device)  # create
        exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else []  # exclude keys
        state_dict = _load_checkpoint_model_state_dict(ckpt)  # SparseML integration
        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(state_dict, strict=False)  # load
        logger.info('Transferred %g/%g items from %s' %
                    (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Model(opt.cfg, ch=3, nc=nc).to(device)  # create

    # Freeze
    freeze = []  # parameter names to freeze (full or partial)
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            print('freezing %s' % k)
            v.requires_grad = False

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
    logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            pg1.append(v.weight)  # apply decay

    if opt.adam:
        optimizer = optim.Adam(pg0, lr=hyp['lr0'],
                               betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' %
                (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    if opt.linear_lr:
        lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
    else:
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # plot_lr_scheduler(optimizer, scheduler, epochs)

    # Logging
    # Start from any W&B run that is already active so the checkpoint below can
    # always record a run id; a new run is created only when none exists yet.
    wandb_run = wandb.run if wandb else None
    if rank in [-1, 0] and wandb and wandb.run is None:
        opt.hyp = hyp  # add hyperparameters
        wandb_run = wandb.init(
            config=opt,
            resume="allow",
            project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
            name=save_dir.stem,
            # `ckpt` is only bound when resuming from a pretrained checkpoint
            id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
    loggers = {'wandb': wandb}  # loggers dict

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # Results
        if ckpt.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(ckpt['training_results'])  # write results.txt

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if opt.resume:
            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (
                weights, epochs)
        if epochs < start_epoch:
            logger.info(
                '%s has been trained for %g epochs. Fine-tuning for %g additional epochs.'
                % (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, state_dict

    # Image sizes
    gs = int(model.stride.max())  # grid size (max stride)
    nl = model.model[-1].nl  # number of detection layers (used for scaling hyp['obj'])
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info('Using SyncBatchNorm()')

    # EMA
    ####################################################################################
    # Start SparseML Integration - optional EMA
    ####################################################################################
    ema = ModelEMA(model) if rank in [-1, 0] and opt.use_ema else None
    ####################################################################################
    # End SparseML Integration - optional EMA
    ####################################################################################

    # DDP mode
    if cuda and rank != -1:
        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)

    # Trainloader
    dataloader, dataset = create_dataloader(train_path,
                                            imgsz,
                                            batch_size,
                                            gs,
                                            opt,
                                            hyp=hyp,
                                            augment=True,
                                            cache=opt.cache_images,
                                            rect=opt.rect,
                                            rank=rank,
                                            world_size=opt.world_size,
                                            workers=opt.workers,
                                            image_weights=opt.image_weights,
                                            quad=opt.quad,
                                            prefix=colorstr('train: '))
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (
        mlc, nc, opt.data, nc - 1)

    # Process 0
    if rank in [-1, 0]:
        if ema:
            ema.updates = start_epoch * nb // accumulate  # set EMA updates
        testloader = create_dataloader(
            test_path,
            imgsz_test,
            batch_size * 2,
            gs,
            opt,  # testloader
            hyp=hyp,
            cache=opt.cache_images and not opt.notest,
            rect=True,
            rank=-1,
            world_size=opt.world_size,
            workers=opt.workers,
            pad=0.5,
            prefix=colorstr('val: '))[0]

        if not opt.resume:
            labels = np.concatenate(dataset.labels, 0)
            c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                plot_labels(labels, save_dir, loggers)
                if tb_writer:
                    tb_writer.add_histogram('classes', c, 0)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

    # Model parameters
    hyp['box'] *= 3. / nl  # scale to layers
    hyp['cls'] *= nc / 80. * 3. / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl  # scale to image size and layers
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    ####################################################################################
    # Start SparseML Integration
    ####################################################################################
    from sparseml.pytorch.nn import replace_activations
    from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
    from sparseml.pytorch.utils import is_parallel_model, PythonLogger, TensorBoardLogger

    if not opt.no_leaky_relu_override:  # use LeakyReLU activations
        model = replace_activations(model, "lrelu", inplace=True)

    manager = ScheduledModifierManager.from_yaml(opt.sparseml_recipe)
    optimizer = ScheduledOptimizer(
        optimizer,
        model if not is_parallel_model(model) else model.module,
        manager,
        steps_per_epoch=len(dataloader),
        loggers=[PythonLogger(), TensorBoardLogger(writer=tb_writer)])
    # override lr scheduler if recipe makes any LR updates
    if any("LearningRate" in str(modifier) for modifier in manager.modifiers):
        logger.info("Disabling yolo LR scheduler, managing LR using SparseML recipe")
        scheduler = None

    # disable model pickling if QAT is set
    qat = any("Quantization" in str(modifier) for modifier in manager.modifiers)
    if qat:
        logger.info("Disabling pickling for Yolo model, QAT modifiers present")

    if manager.max_epochs:
        epochs = manager.max_epochs  # override num_epochs
        logger.info(
            f"overriding number of epochs from SparseML manager to {manager.max_epochs}")
    ####################################################################################
    # End SparseML Integration
    ####################################################################################

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    if scheduler:  # SparseML integration
        scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=(cuda and opt.use_amp))
    compute_loss = ComputeLoss(model)  # init loss class
    logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
                f'Using {dataloader.num_workers} dataloader workers\n'
                f'Logging results to {save_dir}\n'
                f'Starting training for {epochs} epochs...')
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
                iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
                dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls',
                                           'total', 'targets', 'img_size'))
        if rank in [-1, 0]:
            pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    if scheduler:  # SparseML integration, do not force warmup lr when overriding
                        x['lr'] = np.interp(ni, xi, [
                            hyp['warmup_bias_lr'] if j == 2 else 0.0,
                            x['initial_lr'] * lf(epoch)
                        ])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            # Forward
            with amp.autocast(enabled=(cuda and opt.use_amp)):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9
                                 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 + '%10.4g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem,
                                                   *mloss, targets.shape[0], imgs.shape[-1])
                pbar.set_description(s)

                # Plot
                if plots and ni < 3:
                    f = save_dir / f'train_batch{ni}.jpg'  # filename
                    Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
                    # if tb_writer:
                    #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                    #     tb_writer.add_graph(model, imgs)  # add model to tensorboard
                elif plots and ni == 10 and wandb:
                    wandb.log(
                        {
                            "Mosaics": [
                                wandb.Image(str(x), caption=x.name)
                                for x in save_dir.glob('train*.jpg')
                                if x.exists()
                            ]
                        },
                        commit=False)

            # end batch ------------------------------------------------------------------------------------------------
        # end epoch ----------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        if scheduler:  # SparseML integration
            scheduler.step()

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            if ema:
                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])

            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                results, maps, times = test.test(
                    opt.data,
                    batch_size=batch_size * 2,
                    imgsz=imgsz_test,
                    model=ema.ema if ema else model,
                    single_cls=opt.single_cls,
                    dataloader=testloader,
                    save_dir=save_dir,
                    verbose=nc < 50 and final_epoch,
                    plots=plots and final_epoch,
                    log_imgs=opt.log_imgs if wandb else 0,
                    compute_loss=compute_loss,
                    half_precision=opt.use_amp)  # SparseML integration

            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp %s gs://%s/results/results%s.txt' %
                          (results_file, opt.bucket, opt.name))

            # Log
            tags = [
                'train/box_loss',
                'train/obj_loss',
                'train/cls_loss',  # train loss
                'metrics/precision',
                'metrics/recall',
                'metrics/mAP_0.5',
                'metrics/mAP_0.5:0.95',
                'val/box_loss',
                'val/obj_loss',
                'val/cls_loss',  # val loss
                'x/lr0',
                'x/lr1',
                'x/lr2'
            ]  # params
            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                if tb_writer:
                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                if wandb:
                    wandb.log({tag: x}, step=epoch, commit=tag == tags[-1])  # W&B

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            if fi > best_fitness:
                best_fitness = fi

            # Save model
            save = (not opt.nosave) or (final_epoch and not opt.evolve)
            if save:
                with open(results_file, 'r') as f:  # create checkpoint
                    # SparseML integration: without EMA, a QAT model is stored as a
                    # state dict instead of being pickled.
                    # NOTE(review): when EMA is enabled, the EMA model is pickled even
                    # if QAT is active - confirm this is intended.
                    ckpt_model = ema.ema if ema else (model.state_dict() if qat else model)
                    ckpt = {
                        'epoch': epoch,
                        'best_fitness': best_fitness,
                        'training_results': f.read(),
                        'model': ckpt_model,  # SparseML integration
                        'optimizer': None if final_epoch else optimizer.state_dict(),
                        'wandb_id': wandb_run.id if wandb_run else None,
                    }

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    if rank in [-1, 0]:
        # Strip optimizers
        final = best if best.exists() else last  # final model
        for f in [last, best]:
            if f.exists() and not qat:  # SparseML integration - qat state dict incompatible
                strip_optimizer(f)  # strip optimizers
        if opt.bucket:
            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload

        # Plots
        if plots:
            plot_results(save_dir=save_dir)  # save as results.png
            if wandb:
                files = [
                    'results.png', 'confusion_matrix.png',
                    *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]
                ]
                wandb.log({
                    "Results": [
                        wandb.Image(str(save_dir / f), caption=f)
                        for f in files if (save_dir / f).exists()
                    ]
                })
                if opt.log_artifacts:
                    wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)

        # Test best.pt
        # `epochs - start_epoch` is the number of completed epochs; unlike the loop
        # variable `epoch` it is defined even when the training loop never ran.
        logger.info('%g epochs completed in %.3f hours.\n' %
                    (max(epochs - start_epoch, 0), (time.time() - t0) / 3600))
        if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
            for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]):  # speed, mAP tests
                # SparseML integration - load test model
                test_model = model if qat else attempt_load(final, device)
                if opt.use_amp:
                    test_model = test_model.half()
                results, _, _ = test.test(
                    opt.data,
                    batch_size=batch_size * 2,
                    imgsz=imgsz_test,
                    conf_thres=conf,
                    iou_thres=iou,
                    model=test_model,
                    single_cls=opt.single_cls,
                    dataloader=testloader,
                    save_dir=save_dir,
                    save_json=save_json,
                    plots=False,
                    half_precision=opt.use_amp)  # SparseML integration

        #################################################################################
        # Start SparseML ONNX Export
        #################################################################################
        # Runs for every dataset, not only COCO. It must stay after the tests above
        # because it sets `export = True` on the live model's last layer.
        # NOTE(review): with QAT and `opt.use_amp`, the COCO tests above call `.half()`
        # on `model` itself; confirm the exporter handles that model with an fp32 input.
        from sparseml.pytorch.utils import ModuleExporter
        from sparseml.pytorch.utils.quantization import skip_onnx_input_quantize

        onnx_path = f"{save_dir}/model.onnx"
        logger.info(f"training complete, exporting ONNX to {onnx_path}")
        export_model = model.module if is_parallel_model(model) else model
        export_model.model[-1].export = True  # do not export grid post-processing
        exporter = ModuleExporter(export_model, save_dir)
        exporter.export_onnx(torch.randn(1, 3, imgsz, imgsz), convert_qat=True)
        if qat:
            skip_onnx_input_quantize(onnx_path, onnx_path)
        #################################################################################
        # End SparseML ONNX Export
        #################################################################################

    else:
        dist.destroy_process_group()

    if wandb and wandb.run:
        wandb.run.finish()
    torch.cuda.empty_cache()
    return results
Example #4
0
                            3: 'width'
                        },  # size(1,3,640,640)
                        'output': {
                            0: 'batch',
                            2: 'y',
                            3: 'x'
                        }
                    } if opt.dynamic else None)
            else:
                # export through SparseML so quantized and pruned graphs can be corrected
                save_dir = '/'.join(f.split('/')[:-1])
                save_name = f.split('/')[-1]
                exporter = ModuleExporter(model, save_dir)
                exporter.export_onnx(img, name=save_name, convert_qat=True)
                try:
                    skip_onnx_input_quantize(f, f)
                except:
                    pass

            # Checks
            model_onnx = onnx.load(f)  # load onnx model
            onnx.checker.check_model(model_onnx)  # check onnx model
            # print(onnx.helper.printable_graph(model_onnx.graph))  # print

            # Simplify
            if opt.simplify:
                try:
                    check_requirements(['onnx-simplifier'])
                    import onnxsim

                    print(
Example #5
0
def export_onnx(
    module: Module,
    sample_batch: Any,
    file_path: str,
    opset: int = DEFAULT_ONNX_OPSET,
    disable_bn_fusing: bool = True,
    convert_qat: bool = False,
    dynamic_axes: Union[str, Dict[str, List[int]]] = None,
    skip_input_quantize: bool = False,
    **export_kwargs,
):
    """
    Export an onnx file for the current module and for a sample batch.
    Sample batch used to feed through the model to freeze the graph for a
    particular execution.

    :param module: torch Module object to export. A deep copy is moved to cpu,
        put in eval mode and exported, so the passed module is not modified
    :param sample_batch: the batch to export an onnx for, handles creating the
        static graph for onnx as well as setting dimensions. If it is a dict and
        no input_names are given, its keys become the input names and its
        values are passed positionally in key order
    :param file_path: path to the onnx file to save; parent directories are
        created if needed
    :param opset: onnx opset to use for exported model. Default is
        DEFAULT_ONNX_OPSET
    :param disable_bn_fusing: torch >= 1.7.0 only. Set True to disable batch norm
        fusing during torch export. Default and suggested setting is True. Batch
        norm fusing will change the exported parameter names as well as affect
        sensitivity analyses of the exported graph.  Additionally, the DeepSparse
        inference engine, and other engines, perform batch norm fusing at model
        compilation. Not applied to modules that carry a qconfig
    :param convert_qat: if True and quantization aware training is detected in
        the module being exported, the resulting QAT ONNX model will be converted
        to a fully quantized ONNX model using `quantize_torch_qat_export`. Default
        is False.
    :param dynamic_axes: dictionary of input or output names to list of dimensions
        of those tensors that should be exported as dynamic. May input 'batch'
        to set the first dimension of all inputs and outputs to dynamic. Default
        is None (no dynamic axes)
    :param skip_input_quantize: if True, the export flow will attempt to delete
        the first Quantize Linear Nodes(s) immediately after model input and set
        the model input type to UINT8. If this fails, a warning is logged and the
        exported file is left as is. Default is False
    :param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
        call. Useful to pass in input_names, output_names, etc. Do not pass
        dynamic_axes, opset_version, verbose or strip_doc_string here: they are
        already set by this function and would be passed twice.
        See more on the torch.onnx.export api spec in the PyTorch docs:
        https://pytorch.org/docs/stable/onnx.html
    """
    import re  # local import: only needed to parse the torch version below

    if isinstance(sample_batch, Dict) and not isinstance(
            sample_batch, collections.OrderedDict):
        warnings.warn(
            "Sample inputs passed into the ONNX exporter should be in "
            "the same order defined in the model forward function. "
            "Consider using OrderedDict for this purpose.",
            UserWarning,
        )

    sample_batch = tensors_to_device(sample_batch, "cpu")
    create_parent_dirs(file_path)

    # export a cpu copy so the caller's module (device, mode, observers) is untouched
    module = deepcopy(module).cpu()
    module.eval()

    # run one forward pass to discover the output structure for naming
    with torch.no_grad():
        out = tensors_module_forward(sample_batch,
                                     module,
                                     check_feat_lab_inp=False)

    if "input_names" not in export_kwargs:
        if isinstance(sample_batch, Tensor):
            export_kwargs["input_names"] = ["input"]
        elif isinstance(sample_batch, Dict):
            export_kwargs["input_names"] = list(sample_batch.keys())
            sample_batch = tuple(
                [sample_batch[f] for f in export_kwargs["input_names"]])
        elif isinstance(sample_batch, Iterable):
            export_kwargs["input_names"] = [
                "input_{}".format(index)
                for index, _ in enumerate(iter(sample_batch))
            ]
            if isinstance(sample_batch, List):
                sample_batch = tuple(
                    sample_batch)  # torch.onnx.export requires tuple

    if "output_names" not in export_kwargs:
        export_kwargs["output_names"] = _get_output_names(out)

    if dynamic_axes == "batch":
        # input_names may be absent for unusual sample batch types, and
        # user-supplied names may be tuples, so build a fresh list of both
        all_tensor_names = [
            *export_kwargs.get("input_names", []),
            *export_kwargs["output_names"],
        ]
        dynamic_axes = {
            tensor_name: {
                0: "batch"
            }
            for tensor_name in all_tensor_names
        }

    # disable active quantization observers because they cannot be exported
    disabled_observers = []
    for submodule in module.modules():
        if (hasattr(submodule, "observer_enabled")
                and submodule.observer_enabled[0] == 1):
            submodule.observer_enabled[0] = 0
            disabled_observers.append(submodule)

    is_quant_module = any(
        hasattr(submodule, "qconfig") and submodule.qconfig
        for submodule in module.modules())

    # compare versions numerically: a plain string comparison ranks "1.10"
    # below "1.7"; fall back to the string comparison if parsing fails
    version_match = re.match(r"(\d+)\.(\d+)", str(torch.__version__))
    if version_match:
        torch_at_least_1_7 = tuple(
            int(part) for part in version_match.groups()) >= (1, 7)
    else:
        torch_at_least_1_7 = torch.__version__ >= "1.7"

    batch_norms_wrapped = False
    if torch_at_least_1_7 and not is_quant_module and disable_bn_fusing:
        # prevent batch norm fusing by adding a trivial operation before every
        # batch norm layer
        batch_norms_wrapped = _wrap_batch_norms(module)

    torch.onnx.export(
        module,
        sample_batch,
        file_path,
        strip_doc_string=True,
        verbose=False,
        opset_version=opset,
        dynamic_axes=dynamic_axes,
        **export_kwargs,
    )

    # re-enable disabled quantization observers (on the exported copy)
    for submodule in disabled_observers:
        submodule.observer_enabled[0] = 1

    # onnx file fixes
    onnx_model = onnx.load(file_path)
    # fix changed batch norm names
    _fix_batch_norm_names(onnx_model)
    if batch_norms_wrapped:
        # clean up graph from any injected / wrapped operations
        _delete_trivial_onnx_adds(onnx_model)
    onnx.save(onnx_model, file_path)

    if convert_qat and is_quant_module:
        # overwrite exported model with fully quantized version
        quantize_torch_qat_export(model=file_path, output_file_path=file_path)

    if skip_input_quantize:
        # best effort: a graph without a leading QuantizeLinear is kept as is
        try:
            skip_onnx_input_quantize(file_path, file_path)
        except Exception as e:
            _LOGGER.warning(
                f"Unable to skip input QuantizeLinear op with exception {e}")