Esempio n. 1
0
def main():
    """Run a RetinaNet detector over a folder of images and save a video.

    Options come from the module-level ``parser``. The sorted directory
    listing of ``args.demo_path`` is truncated to ``CONVERT_FILE_LIMIT``
    entries; every ``png``/``jpg`` entry is passed through the model, boxes
    whose score reaches ``args.threshold`` are drawn on the image, and the
    annotated frames are written to
    ``<args.output_path>/<demo folder name>.avi`` at 25 fps.

    Raises:
        ValueError: if the dataset or architecture is unknown, ``--resume``
            is missing, or no image could be processed.
    """
    args = parser.parse_args()

    if args.dataset == 'thermal':
        transform = transforms.Compose([
            Normalizer(inference_mode=True),
            Resizer(min_side=60, max_side=80, inference_mode=True),
        ])
    elif args.dataset == '3s-pocket-thermal-face':
        transform = transforms.Compose([
            Normalizer(inference_mode=True),
            Resizer(height=288,
                    width=384,
                    resize_mode=1,
                    inference_mode=True),
        ])
    else:
        raise ValueError('unknow dataset.')

    net_logger = logging.getLogger('Demo Logger')
    # Attach the stream handler only once so that repeated calls do not
    # duplicate every log line.
    if not net_logger.handlers:
        streamhandler = logging.StreamHandler()
        streamhandler.setFormatter(logging.Formatter(LOGGING_FORMAT))
        net_logger.addHandler(streamhandler)
    net_logger.setLevel(logging.INFO)

    net_logger.info('Positive Threshold: {:.2f}'.format(args.threshold))

    if args.resume is None:
        raise ValueError('Must provide --resume when testing.')

    build_param = {'logger': net_logger}
    if args.architecture == 'RetinaNet':
        model = retinanet.retinanet(args.depth,
                                    num_classes=args.num_classes,
                                    **build_param)
    elif args.architecture == 'RetinaNet-Tiny':
        model = retinanet.retinanet_tiny(num_classes=args.num_classes,
                                         **build_param)
    elif args.architecture == 'RetinaNet_P45P6':
        model = retinanet.retinanet_p45p6(num_classes=args.num_classes,
                                          **build_param)
    else:
        raise ValueError('Architecture <{}> unknown.'.format(
            args.architecture))

    net_logger.info('Loading Weights from Checkpoint : {}'.format(args.resume))
    # Load onto the CPU first so a checkpoint saved from a GPU run can still
    # be restored on a machine without CUDA.
    model.load_state_dict(torch.load(args.resume, map_location='cpu'))

    if torch.cuda.is_available():
        model = model.cuda()
    model = torch.nn.DataParallel(model)

    demo_image_files = os.listdir(args.demo_path)
    demo_image_files.sort()
    if len(demo_image_files) > CONVERT_FILE_LIMIT:
        print('WARNING: Too many files...    total {} files.'.format(
            len(demo_image_files)))

    model.eval()

    frames = []
    for f in demo_image_files[:CONVERT_FILE_LIMIT]:
        print(f)
        if f[-3:] not in ['png', 'jpg']:
            continue

        img = Image.open(os.path.join(args.demo_path, f)).convert('RGB')
        a_img = np.array(img).astype(np.float32) / 255.0
        a_img = transform(a_img)
        # Add a batch axis, then move channels in front of the spatial axes.
        a_img = torch.unsqueeze(a_img, 0).permute(0, 3, 1, 2)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            scores, labels, boxes = model(a_img)

        scores = scores.cpu()
        labels = labels.cpu()
        boxes = boxes.cpu()

        # change to (x, y, w, h) (MS COCO standard)
        boxes[:, 2] -= boxes[:, 0]
        boxes[:, 3] -= boxes[:, 1]

        if args.dataset == 'thermal':
            img = img.resize((80, 60))

        draw = ImageDraw.Draw(img)
        for box_id in range(boxes.shape[0]):
            score = float(scores[box_id])
            label = int(labels[box_id])

            # Scores are expected in descending order, so the first one
            # below the threshold ends the loop.
            if score < args.threshold:
                break

            x, y, w, h = boxes[box_id, :]
            draw.rectangle(tuple([x, y, x + w, y + h]),
                           width=1,
                           outline=COLOR_LABEL[label])

        # PIL works in RGB while cv2.VideoWriter expects BGR frames.
        frames.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))

    if not frames:
        raise ValueError('No png/jpg images found in {}'.format(
            args.demo_path))

    height, width = frames[0].shape[:2]
    fps = 25

    # normpath drops a trailing separator, which would otherwise make
    # basename() return an empty string and name the file ".avi".
    video_name = os.path.basename(os.path.normpath(args.demo_path))
    if args.output_path:
        os.makedirs(args.output_path, exist_ok=True)
    out_video_file = os.path.join(args.output_path,
                                  '{}.avi'.format(video_name))
    print('Convert to video... {}'.format(out_video_file))
    out = cv2.VideoWriter(out_video_file, cv2.VideoWriter_fourcc(*'mp4v'), fps,
                          (width, height))
    try:
        for frame in frames:
            out.write(frame)
    finally:
        out.release()

    print('Done')
Esempio n. 2
0
def main():
    """Run a detector over a video file and save an annotated copy.

    Options come from ``get_args()``. Each frame of ``args.input_path`` is
    passed through the selected model; boxes whose score reaches
    ``args.threshold`` are drawn together with their score, and the result is
    written to ``trash/<video>_<model>_thr<NN>.avi`` at 30 fps.

    Raises:
        ValueError: if ``--resume`` is missing, the architecture or resize
            mode is not supported, or no frame could be read.
    """
    args = get_args()
    if args.resume is None:
        raise ValueError('Must provide --resume when testing.')

    support_architectures = [
        'ksevendet',
    ]
    support_architectures += [f'efficientdet-d{i}' for i in range(8)]
    support_architectures += [
        f'retinanet-res{i}' for i in [18, 34, 50, 101, 152]
    ]
    support_architectures.append('retinanet-p45p6')

    print(support_architectures)

    if args.architecture == 'ksevendet':
        ksevendet_cfg = args.model_cfg
        if ksevendet_cfg.get('variant'):
            network_name = f'{args.architecture}-{ksevendet_cfg["variant"]}-{ksevendet_cfg["neck"]}'
        else:
            assert 0, 'not support now.'
            assert isinstance(ksevendet_cfg, dict)
            network_name = f'{args.architecture}-{ksevendet_cfg["backbone"]}_specifical-{ksevendet_cfg["neck"]}'
    elif args.architecture in support_architectures:
        network_name = args.architecture
    else:
        raise ValueError('Architecture {} is not support.'.format(
            args.architecture))

    args.network_name = network_name
    net_logger = get_logger(name='Network Logger', args=args)
    net_logger.info('Positive Threshold: {:.2f}'.format(args.threshold))

    # "--input_shape" is a comma separated pair; its meaning depends on the
    # resize mode (min/max side for mode 0, height/width for mode 1).
    _shape_1, _shape_2 = tuple(map(int, args.input_shape.split(',')))
    _normalizer = Normalizer(inference_mode=True)
    if args.resize_mode == 0:
        _resizer = Resizer(min_side=_shape_1,
                           max_side=_shape_2,
                           resize_mode=args.resize_mode,
                           logger=net_logger,
                           inference_mode=True)
    elif args.resize_mode == 1:
        _resizer = Resizer(height=_shape_1,
                           width=_shape_2,
                           resize_mode=args.resize_mode,
                           logger=net_logger,
                           inference_mode=True)
    else:
        raise ValueError('Illegal resize mode.')

    transform = transforms.Compose([_normalizer, _resizer])

    net_logger.info('Number of Classes: {:>3}'.format(args.num_classes))

    build_param = {'logger': net_logger}
    if args.architecture == 'ksevendet':
        net_model = ksevendet.KSevenDet(ksevendet_cfg,
                                        num_classes=args.num_classes,
                                        pretrained=False,
                                        **build_param)
    elif args.architecture == 'retinanet-p45p6':
        net_model = retinanet.retinanet_p45p6(num_classes=args.num_classes,
                                              **build_param)
    elif args.architecture.split('-')[0] == 'retinanet':
        net_model = retinanet.build_retinanet(args.architecture,
                                              num_classes=args.num_classes,
                                              pretrained=False,
                                              **build_param)
    elif args.architecture.split('-')[0] == 'efficientdet':
        net_model = efficientdet.build_efficientdet(
            args.architecture,
            num_classes=args.num_classes,
            pretrained=False,
            **build_param)
    else:
        assert 0, 'architecture error'

    net_logger.info('Loading Weights from Checkpoint : {}'.format(args.resume))
    # Load onto the CPU first so a checkpoint saved from a GPU run can still
    # be restored on a machine without CUDA.
    net_model.load_state_dict(torch.load(args.resume, map_location='cpu'))

    if torch.cuda.is_available():
        net_model = net_model.cuda()
    net_model = torch.nn.DataParallel(net_model)

    net_model.module.eval()

    frames = []

    cap = cv2.VideoCapture(args.input_path)
    fontsize = 12
    score_font = ImageFont.truetype("DejaVuSans.ttf", size=fontsize)

    cap_i = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes frames as BGR. PIL (and the image-folder demos
            # of this project, which load through PIL) work in RGB, so
            # convert once here and convert back before writing.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(rgb_frame)
            a_img = rgb_frame.astype(np.float32) / 255.0
            a_img = transform(a_img)
            # Add a batch axis, then move channels in front of the spatial
            # axes.
            a_img = torch.unsqueeze(a_img, 0).permute(0, 3, 1, 2)

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                scores, labels, boxes = net_model(a_img, return_loss=False)

            scores = scores.cpu()
            labels = labels.cpu()
            boxes = boxes.cpu()

            # change to (x, y, w, h) (MS COCO standard)
            boxes[:, 2] -= boxes[:, 0]
            boxes[:, 3] -= boxes[:, 1]

            print(f'{cap_i}   inference ...', end="\r")

            draw = ImageDraw.Draw(img)
            _text_offset_x, _text_offset_y = 2, 3
            for box_id in range(boxes.shape[0]):
                score = float(scores[box_id])
                label = int(labels[box_id])

                # Scores are expected in descending order, so the first one
                # below the threshold ends the loop.
                if score < args.threshold:
                    break

                x, y, w, h = boxes[box_id, :]
                color_ = COLOR_LABEL[label]
                score_text = '{:.3f}'.format(score)
                draw.rectangle(tuple([x, y, x + w, y + h]),
                               width=1,
                               outline=color_)
                # Black copy shifted one pixel right acts as a drop shadow.
                draw.text((int(x) + _text_offset_x + 1,
                           int(y) + _text_offset_y),
                          score_text,
                          fill='#000000',
                          font=score_font)
                draw.text((int(x) + _text_offset_x,
                           int(y) + _text_offset_y),
                          score_text,
                          fill=color_,
                          font=score_font)

            frames.append(cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR))
            cap_i += 1
    finally:
        cap.release()

    if not frames:
        raise ValueError('No frames could be read from {}'.format(
            args.input_path))

    height, width = frames[0].shape[:2]
    fps = 30

    # splitext handles extensions of any length (".mp4", ".mpeg", ...).
    input_video_stem = os.path.splitext(os.path.basename(args.input_path))[0]
    output_dir = 'trash'
    # VideoWriter fails silently when the target directory is missing.
    os.makedirs(output_dir, exist_ok=True)
    out_video_path = os.path.join(
        output_dir, '{}_{}_thr{}.avi'.format(
            input_video_stem,
            network_name if not args.model_name else args.model_name,
            int(args.threshold * 100)))
    print('Convert to video... {}'.format(out_video_path))
    out = cv2.VideoWriter(out_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps,
                          (width, height))
    try:
        for frame in frames:
            out.write(frame)
    finally:
        out.release()

    print('Done')
Esempio n. 3
0
def main():
    """Compare PyTorch and ONNX Runtime inference on a folder of images.

    Options come from ``get_args()``. For every ``png``/``jpg`` file in
    ``args.sample_path`` the image is pre-processed once, run through the
    PyTorch model and through the ONNX Runtime session, and the ONNX head
    outputs are post-processed (box regression, clipping, thresholding, NMS)
    with the PyTorch model's helpers. With ``args.compare_head_tensor`` the
    raw classification/regression tensors of both back ends are compared and
    the process exits as soon as the classification tensors differ.

    Raises:
        ValueError: if the architecture or resize mode is not supported,
            ``--resume`` is missing for PyTorch inference, or an image has to
            be processed while one of the two back ends is disabled.
    """
    args = get_args()

    if args.pytorch_inference:
        assert args.architecture, 'Must provide --architecture when pytorch inference.'
        assert args.resume, 'Must provide --resume when pytorch inference.'
    if args.onnx_inference:
        assert args.onnx, 'Must provide --onnx when onnx inference.'

    # NOTE(review): support_architectures is not defined in this function; it
    # is presumably a module-level list - confirm.
    if args.architecture == 'ksevendet':
        ksevendet_cfg = args.model_cfg
        if ksevendet_cfg.get('variant'):
            network_name = f'{args.architecture}-{ksevendet_cfg["variant"]}-{ksevendet_cfg["neck"]}'
        else:
            assert 0, 'not support now.'
            assert isinstance(ksevendet_cfg, dict)
            network_name = f'{args.architecture}-{ksevendet_cfg["backbone"]}_specifical-{ksevendet_cfg["neck"]}'
    elif args.architecture in support_architectures:
        network_name = args.architecture
    else:
        raise ValueError('Architecture {} is not support.'.format(args.architecture))

    net_logger = get_logger(name='ONNX Inference Logger', args=args)
    in_h, in_w = tuple(map(int, args.input_shape.split(',')))
    net_logger.info(f'Input Tensor Size: [ {in_h}, {in_w}]')
    net_logger.info('Positive Threshold: {:.2f}'.format(args.threshold))

    _normalizer = Normalizer(inference_mode=True)
    if args.resize_mode == 0:
        assert 0, 'not use...'
    elif args.resize_mode == 1:
        _resizer = Resizer(height=in_h, width=in_w, resize_mode=args.resize_mode, logger=net_logger, inference_mode=True)
    else:
        raise ValueError('Illegal resize mode.')

    transform = transforms.Compose([_normalizer, _resizer])

    # One device for the model and every tensor fed to it, with a CPU
    # fallback when CUDA is not available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    pytorch_model = None
    if args.pytorch_inference:
        net_logger.info('Build pytorch model...')
        build_param = {'logger': net_logger}
        if args.architecture == 'ksevendet':
            pytorch_model = ksevendet.KSevenDet(ksevendet_cfg, num_classes=args.num_classes, pretrained=False, **build_param)
        elif args.architecture == 'retinanet-p45p6':
            pytorch_model = retinanet.retinanet_p45p6(num_classes=args.num_classes, **build_param)
        elif args.architecture.split('-')[0] == 'retinanet':
            pytorch_model = retinanet.build_retinanet(args.architecture, num_classes=args.num_classes, pretrained=False, **build_param)
        elif args.architecture.split('-')[0] == 'efficientdet':
            pytorch_model = efficientdet.build_efficientdet(args.architecture, num_classes=args.num_classes, pretrained=False, **build_param)
        else:
            assert 0, 'architecture error'

        if args.resume is not None:
            net_logger.info('Loading Weights from Checkpoint : {}'.format(args.resume))
            try:
                # Load onto the CPU first so a checkpoint saved from a GPU
                # run can still be restored without CUDA.
                pytorch_model.load_state_dict(torch.load(args.resume, map_location='cpu'), strict=False)
            except RuntimeError as e:
                net_logger.warning(f'Ignoring {e}')
                net_logger.warning(f'Don\'t panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already.')
        else:
            raise ValueError('Must provide --resume when testing.')

        pytorch_model = pytorch_model.to(device)
        pytorch_model.eval()
        net_logger.info('Initialize Pytorch Model...  Finished')

    onnx_model = None
    ort_session = None
    if args.onnx_inference:
        net_logger.info('Build onnx model...')
        net_logger.info('Onnx loading...')
        onnx_model = onnx.load(args.onnx)
        print(type(onnx_model))

        # Check that the IR is well formed
        net_logger.info('Onnx checking...')
        onnx.checker.check_model(onnx_model)

        ort_session = onnxruntime.InferenceSession(args.onnx)
        net_logger.info('Onnx initialize... Done')

    sample_files = os.listdir(args.sample_path)
    sample_files.sort()

    tensor_comp = TensorCompare()

    for f in sample_files:
        if f.split('.')[-1] not in ['png', 'jpg']:
            continue

        # The loop below uses both back ends unconditionally; fail with a
        # clear message instead of calling None / an unbound name.
        if pytorch_model is None or ort_session is None:
            raise ValueError('Both pytorch_inference and onnx_inference must '
                             'be enabled to compare the two back ends.')

        print(f'loading image... {f}')

        img = Image.open(os.path.join(args.sample_path, f)).convert('RGB')
        a_img = np.array(img).astype(np.float32) / 255.0

        a_img = transform(a_img).numpy()
        # Add a batch axis, then move channels in front of the spatial axes.
        a_img = np.expand_dims(a_img, axis=0)
        a_img = a_img.transpose((0, 3, 1, 2))

        if args.dump_sample_npz:
            _npz_name = os.path.join(args.sample_path, f.split('.')[0]+'_preprocess.npz')
            print(f'save npz to ... {_npz_name}')
            np.savez(_npz_name, input=a_img)

        input_tensor = torch.from_numpy(a_img).to(device)

        net_logger.info('pytorch inference start ...')
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            pytorch_scores, pytorch_labels, pytorch_boxes = pytorch_model(input_tensor)
        net_logger.info('inference done.')
        if args.compare_head_tensor:
            with torch.no_grad():
                pytorch_classification, pytorch_regression = pytorch_model(input_tensor, return_head=True)

        pytorch_scores = pytorch_scores.cpu()
        pytorch_labels = pytorch_labels.cpu()
        pytorch_boxes = pytorch_boxes.cpu()

        # change to (x, y, w, h) (MS COCO standard)
        pytorch_boxes[:, 2] -= pytorch_boxes[:, 0]
        pytorch_boxes[:, 3] -= pytorch_boxes[:, 1]

        print('Onnx Inference')
        print('Input.shape = {}'.format(str(a_img.shape)))
        print('start inference')
        ort_inputs = {ort_session.get_inputs()[0].name: a_img}

        net_logger.info('onnx inference start ...')
        ort_outs = ort_session.run(None, ort_inputs)
        net_logger.info('inference done.')

        classification, regression = ort_outs

        # The ONNX graph only yields the head outputs, so anchors and box
        # decoding are borrowed from the PyTorch model.
        # NOTE(review): the original author marked the anchors call as
        # possibly wrong ("this is wrong code?") - confirm.
        with torch.no_grad():
            anchors = pytorch_model.anchors(a_img)
            transformed_anchors = pytorch_model.regressBoxes(anchors, torch.from_numpy(regression).float().to(device))
            transformed_anchors = pytorch_model.clipBoxes(transformed_anchors, a_img)
        transformed_anchors = to_numpy(transformed_anchors)

        # Best class score per anchor for the single image in the batch.
        onnx_scores = np.max(classification, axis=2, keepdims=True)[0]
        onnx_scores_over_thresh = (onnx_scores > 0.05)[:, 0]

        if onnx_scores_over_thresh.sum() == 0:
            # Nothing survived the score threshold, so there is nothing to
            # hand to NMS.
            print('No boxes to NMS')
        else:
            classification = classification[:, onnx_scores_over_thresh, :]
            transformed_anchors = transformed_anchors[:, onnx_scores_over_thresh, :]
            onnx_scores = onnx_scores[onnx_scores_over_thresh, :]

            anchors_nms_idx = nms(transformed_anchors[0, :, :], onnx_scores[:, 0], 0.5)
            onnx_nms_scores = classification[0, anchors_nms_idx, :].max(axis=1)
            onnx_nms_class = classification[0, anchors_nms_idx, :].argmax(axis=1)
            onnx_boxes = transformed_anchors[0, anchors_nms_idx, :]
            # change to (x, y, w, h) (MS COCO standard)
            onnx_boxes[:, 2] -= onnx_boxes[:, 0]
            onnx_boxes[:, 3] -= onnx_boxes[:, 1]

        if args.compare_head_tensor:
            onnx_classification, onnx_regression = ort_outs
            net_logger.info('compare classification ...')
            _result = tensor_comp.compare(to_numpy(pytorch_classification), onnx_classification)
            print(f'Pass {_result[0]}')
            for k in _result[2]:
                print('{:>20} : {:<2.8f}'.format(k, _result[2][k]))
                if k == 'close_order':
                    print(_result)
            if not _result[0]:
                print('Not similar.')
                exit(0)
            net_logger.info('compare regression ...')
            _result = tensor_comp.compare(to_numpy(pytorch_regression), onnx_regression)
            print(f'Pass {_result[0]}')
            for k in _result[2]:
                print('{:>20} : {:<2.8f}'.format(k, _result[2][k]))
def main():
    """Print a compute summary of a RetinaNet model for one dataset sample.

    Options come from ``get_args()``. The model selected by
    ``args.architecture`` is built, restored from ``args.resume`` and handed
    to ``model_summaries.model_summary`` together with the input shape of the
    first sample of the dataset under ``args.dataset_root``.

    Raises:
        ValueError: if the dataset or architecture is unknown, or
            ``--resume`` is missing.
    """
    args = get_args()
    if args.dataset == 'thermal':
        input_height, input_width = 60, 80
    elif args.dataset == '3s-pocket-thermal-face':
        input_height, input_width = 288, 384
    else:
        raise ValueError('unknow dataset.')

    if args.resume is None:
        raise ValueError('Must provide --resume when testing.')

    # Built for every supported dataset: the sample shape below needs it, and
    # creating it only in one branch left it undefined for 'thermal'.
    dataset_valid = CVIDataset(args.dataset_root,
                               set_name='train',
                               annotation_name='annotations.json',
                               transform=transforms.Compose([
                                   Normalizer(),
                                   Resizer(height=input_height,
                                           width=input_width)
                               ]))

    my_logger = get_logger(description='Distiller Summary Logger')

    my_logger.info('Build pytorch model...')
    build_param = {'logger': my_logger}
    if args.architecture == 'RetinaNet':
        model = retinanet.retinanet(args.depth,
                                    num_classes=args.num_classes,
                                    **build_param)
    elif args.architecture == 'RetinaNet-Tiny':
        model = retinanet.retinanet_tiny(num_classes=args.num_classes,
                                         **build_param)
    elif args.architecture == 'RetinaNet_P45P6':
        model = retinanet.retinanet_p45p6(num_classes=args.num_classes,
                                          **build_param)
    else:
        raise ValueError('Architecture <{}> unknown.'.format(
            args.architecture))

    my_logger.info('Loading Weights from Checkpoint : {}'.format(args.resume))
    # Load onto the CPU first so a checkpoint saved from a GPU run can still
    # be restored on a machine without CUDA.
    model.load_state_dict(torch.load(args.resume, map_location='cpu'))

    if torch.cuda.is_available():
        model = model.cuda()
    model = torch.nn.DataParallel(model)

    # Only the shape is needed, so the sample stays on the CPU; permute moves
    # the last axis to the front and unsqueeze adds a leading batch axis.
    sample_input_shape = dataset_valid[0]['img'].permute(
        2, 0, 1).unsqueeze(dim=0).shape

    print(sample_input_shape)

    model.eval()
    model_summaries.model_summary(model,
                                  'compute',
                                  input_shape=sample_input_shape)
def main():
    """Run a detector over a folder of images and save an annotated video.

    Options come from ``get_args()``. Every ``png``/``jpg`` entry of
    ``args.demo_path`` (in sorted order) is passed through the selected
    model; boxes whose score reaches ``args.threshold`` are drawn together
    with their score, and the frames are written to
    ``<args.output_path>/<name>.avi`` at 30 fps, where ``<name>`` is
    ``args.output_name`` or, when that is empty, the demo folder name.

    Raises:
        ValueError: if ``--resume`` is missing, the architecture or resize
            mode is not supported, or no image could be processed.
    """
    args = get_args()
    assert args.dataset, 'dataset must provide'
    if args.resume is None:
        raise ValueError('Must provide --resume when testing.')

    support_architectures = [
        'ksevendet',
    ]
    support_architectures += [f'efficientdet-d{i}' for i in range(8)]
    support_architectures += [
        f'retinanet-res{i}' for i in [18, 34, 50, 101, 152]
    ]
    support_architectures.append('retinanet-p45p6')

    print(support_architectures)

    if args.architecture == 'ksevendet':
        ksevendet_cfg = args.model_cfg
        if ksevendet_cfg.get('variant'):
            network_name = f'{args.architecture}-{ksevendet_cfg["variant"]}-{ksevendet_cfg["neck"]}'
        else:
            assert 0, 'not support now.'
            assert isinstance(ksevendet_cfg, dict)
            network_name = f'{args.architecture}-{ksevendet_cfg["backbone"]}_specifical-{ksevendet_cfg["neck"]}'
    elif args.architecture in support_architectures:
        network_name = args.architecture
    else:
        raise ValueError('Architecture {} is not support.'.format(
            args.architecture))

    args.network_name = network_name
    net_logger = get_logger(name='Network Logger', args=args)
    net_logger.info('Positive Threshold: {:.2f}'.format(args.threshold))

    # "--input_shape" is a comma separated pair; its meaning depends on the
    # resize mode (min/max side for mode 0, height/width for mode 1).
    _shape_1, _shape_2 = tuple(map(int, args.input_shape.split(',')))
    _normalizer = Normalizer(inference_mode=True)
    if args.resize_mode == 0:
        _resizer = Resizer(min_side=_shape_1,
                           max_side=_shape_2,
                           resize_mode=args.resize_mode,
                           logger=net_logger,
                           inference_mode=True)
    elif args.resize_mode == 1:
        _resizer = Resizer(height=_shape_1,
                           width=_shape_2,
                           resize_mode=args.resize_mode,
                           logger=net_logger,
                           inference_mode=True)
    else:
        raise ValueError('Illegal resize mode.')

    transform = transforms.Compose([_normalizer, _resizer])

    net_logger.info('Number of Classes: {:>3}'.format(args.num_classes))

    build_param = {'logger': net_logger}
    if args.architecture == 'ksevendet':
        net_model = ksevendet.KSevenDet(ksevendet_cfg,
                                        num_classes=args.num_classes,
                                        pretrained=False,
                                        **build_param)
    elif args.architecture == 'retinanet-p45p6':
        net_model = retinanet.retinanet_p45p6(num_classes=args.num_classes,
                                              **build_param)
    elif args.architecture.split('-')[0] == 'retinanet':
        net_model = retinanet.build_retinanet(args.architecture,
                                              num_classes=args.num_classes,
                                              pretrained=False,
                                              **build_param)
    elif args.architecture.split('-')[0] == 'efficientdet':
        net_model = efficientdet.build_efficientdet(
            args.architecture,
            num_classes=args.num_classes,
            pretrained=False,
            **build_param)
    else:
        assert 0, 'architecture error'

    net_logger.info('Loading Weights from Checkpoint : {}'.format(args.resume))
    # Load onto the CPU first so a checkpoint saved from a GPU run can still
    # be restored on a machine without CUDA.
    net_model.load_state_dict(torch.load(args.resume, map_location='cpu'))

    if torch.cuda.is_available():
        net_model = net_model.cuda()
    net_model = torch.nn.DataParallel(net_model)

    demo_image_files = os.listdir(args.demo_path)
    demo_image_files.sort()
    fontsize = 12
    score_font = ImageFont.truetype("DejaVuSans.ttf", size=fontsize)

    net_model.eval()

    frames = []
    _text_offset_x, _text_offset_y = 2, 3

    for f in demo_image_files:
        print(f'inference {f}', end="\r")
        if f[-3:] not in ['png', 'jpg']:
            continue

        img = Image.open(os.path.join(args.demo_path, f)).convert('RGB')
        a_img = np.array(img).astype(np.float32) / 255.0
        a_img = transform(a_img)
        # Add a batch axis, then move channels in front of the spatial axes.
        a_img = torch.unsqueeze(a_img, 0).permute(0, 3, 1, 2)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            scores, labels, boxes = net_model(a_img, return_loss=False)

        scores = scores.cpu()
        labels = labels.cpu()
        boxes = boxes.cpu()

        # change to (x, y, w, h) (MS COCO standard)
        boxes[:, 2] -= boxes[:, 0]
        boxes[:, 3] -= boxes[:, 1]

        draw = ImageDraw.Draw(img)
        for box_id in range(boxes.shape[0]):
            score = float(scores[box_id])
            label = int(labels[box_id])

            # Scores are expected in descending order, so the first one
            # below the threshold ends the loop.
            if score < args.threshold:
                break

            x, y, w, h = boxes[box_id, :]
            color_ = COLOR_LABEL[label]
            score_text = '{:.3f}'.format(score)
            draw.rectangle(tuple([x, y, x + w, y + h]),
                           width=1,
                           outline=color_)
            # Black copy shifted by one pixel acts as a drop shadow.
            draw.text((int(x) + _text_offset_x + 1,
                       int(y) + _text_offset_y + 1),
                      score_text,
                      fill='#000000',
                      font=score_font)
            draw.text((int(x) + _text_offset_x,
                       int(y) + _text_offset_y),
                      score_text,
                      fill=color_,
                      font=score_font)

        # PIL works in RGB while cv2.VideoWriter expects BGR frames.
        frames.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))

    if not frames:
        raise ValueError('No png/jpg images found in {}'.format(
            args.demo_path))

    height, width = frames[0].shape[:2]
    fps = 30

    if args.output_name:
        video_name = args.output_name
    else:
        # normpath drops a trailing separator, which would otherwise make
        # basename() return an empty string and name the file ".avi".
        video_name = os.path.basename(os.path.normpath(args.demo_path))
    if args.output_path:
        os.makedirs(args.output_path, exist_ok=True)
    out_video_file = os.path.join(args.output_path,
                                  '{}.avi'.format(video_name))
    print('Convert to video... {}'.format(out_video_file))
    out = cv2.VideoWriter(out_video_file, cv2.VideoWriter_fourcc(*'mp4v'), fps,
                          (width, height))
    try:
        for frame in frames:
            out.write(frame)
    finally:
        out.release()

    print('Done')
Esempio n. 6
0
def train_pruning_model(pruning_tensor_cfg,
                        cfg_name='',
                        tensor_pruning_dependency=None):
    """Prune a detection network, then fine-tune and validate it.

    The run configuration is read from ``get_args()``. The function builds
    the train/valid datasets and the requested model, optionally restores a
    checkpoint, prunes the model with ``KSevenPruner`` and then trains the
    pruned model, calling ``test`` for validation.

    Args:
        pruning_tensor_cfg: Pruning configuration, forwarded unchanged to
            ``KSevenPruner.prune``.
        cfg_name: Label for this configuration; stored on ``args.cfg_name``
            and used as part of the logger name.
        tensor_pruning_dependency: Dependency description forwarded to
            ``KSevenPruner``. Required, despite the ``None`` default.

    Raises:
        AssertionError: If ``tensor_pruning_dependency`` is ``None``, if
            ``args.dataset`` is empty, or if no model builder matches the
            architecture.
        ValueError: For an unsupported architecture, resize mode, dataset
            type or optimizer.
        SystemExit: After validation when ``args.validation_only`` is set.
    """
    assert tensor_pruning_dependency is not None, 'tensor_pruning_dependency is None'

    args = get_args()
    args.cfg_name = cfg_name

    assert args.dataset, 'dataset must provide'

    support_architectures = [
        'ksevendet',
    ]
    support_architectures += [f'efficientdet-d{i}' for i in range(8)]
    support_architectures += [
        f'retinanet-res{i}' for i in [18, 34, 50, 101, 152]
    ]
    support_architectures.append('retinanet-p45p6')

    print(support_architectures)

    if args.architecture == 'ksevendet':
        ksevendet_cfg = args.model_cfg
        if ksevendet_cfg.get('variant'):
            network_name = f'{args.architecture}-{ksevendet_cfg["variant"]}-{ksevendet_cfg["neck"]}'
        else:
            assert isinstance(ksevendet_cfg, dict)
            network_name = f'{args.architecture}-{ksevendet_cfg["backbone"]}_specifical-{ksevendet_cfg["neck"]}'
    elif args.architecture in support_architectures:
        network_name = args.architecture
    else:
        raise ValueError('Architecture {} is not support.'.format(
            args.architecture))

    args.network_name = network_name

    net_logger = get_logger(name='pruning_{}_{}'.format(
        args.network_name, cfg_name),
                            args=args)
    net_logger.info('Network Name: {}'.format(network_name))
    net_logger.info('Dataset Name: {}'.format(args.dataset))
    net_logger.info('Dataset Root: {}'.format(args.dataset_root))
    net_logger.info('Dataset Type: {}'.format(args.dataset_type))
    net_logger.info('Training Epochs : {}'.format(args.epochs))
    net_logger.info('Batch Size      : {}'.format(args.batch_size))
    net_logger.info('Weight Decay    : {}'.format(args.weight_decay))
    net_logger.info('Learning Rate   : {}'.format(args.lr))

    # ``args.input_shape`` is a "<height>,<width>" string.
    height, width = map(int, args.input_shape.split(','))
    _normalizer = Normalizer()
    _augmenter = Augmenter(use_scale=False, scale_min=0.9, logger=net_logger)
    if args.resize_mode == 0:
        _resizer = Resizer(min_side=height,
                           max_side=width,
                           resize_mode=args.resize_mode,
                           logger=net_logger)
    elif args.resize_mode == 1:
        _resizer = Resizer(height=height,
                           width=width,
                           resize_mode=args.resize_mode,
                           logger=net_logger)
    else:
        raise ValueError('Illegal resize mode.')
    transfrom_funcs_train = [
        _augmenter,
        _normalizer,
        _resizer,
    ]
    transfrom_funcs_valid = [
        _normalizer,
        _resizer,
    ]

    # Both dataset classes are constructed with the same arguments, so only
    # the class itself depends on the dataset type.
    if args.dataset_type == 'kseven':
        dataset_cls = KSevenDataset
    elif args.dataset_type == 'coco':
        dataset_cls = CocoDataset
    else:
        raise ValueError(
            'Dataset type not understood (must be kseven or coco), exiting.')
    dataset_train = dataset_cls(
        args.dataset_root,
        set_name='train',
        transform=transforms.Compose(transfrom_funcs_train))
    dataset_valid = dataset_cls(
        args.dataset_root,
        set_name='valid',
        transform=transforms.Compose(transfrom_funcs_valid))

    # Validation goes through ``test(dataset_valid, ...)``, so only the
    # training split needs a DataLoader here.
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  shuffle=True,
                                  collate_fn=collater,
                                  pin_memory=True)

    num_classes = dataset_train.num_classes()
    net_logger.info('Number of Classes: {:>3}'.format(num_classes))

    build_param = {'logger': net_logger}
    if args.architecture == 'ksevendet':
        net_model = ksevendet.KSevenDet(ksevendet_cfg,
                                        num_classes=num_classes,
                                        pretrained=False,
                                        **build_param)
    elif args.architecture == 'retinanet-p45p6':
        net_model = retinanet.retinanet_p45p6(num_classes=num_classes,
                                              **build_param)
    elif args.architecture.split('-')[0] == 'retinanet':
        net_model = retinanet.build_retinanet(args.architecture,
                                              num_classes=num_classes,
                                              pretrained=False,
                                              **build_param)
    elif args.architecture.split('-')[0] == 'efficientdet':
        net_model = efficientdet.build_efficientdet(args.architecture,
                                                    num_classes=num_classes,
                                                    pretrained=False,
                                                    **build_param)
    else:
        # Raised explicitly (instead of ``assert 0``) so the guard is not
        # stripped under ``python -O``.
        raise AssertionError('architecture error')

    # load last weights
    if args.resume is not None:
        net_logger.info('Loading Weights from Checkpoint : {}'.format(
            args.resume))
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # The model is still on the CPU here, so map the checkpoint there as
        # well; this also lets GPU-saved checkpoints load on CPU-only hosts.
        # Loading happens outside the ``try`` so a failed load is not
        # mistaken for the tolerated class-count mismatch below.
        state_dict = torch.load(args.resume, map_location='cpu')
        try:
            net_model.load_state_dict(state_dict, strict=False)
        except RuntimeError as e:
            net_logger.warning(f'Ignoring {e}')
            net_logger.warning(
                "Don't panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already."
            )

        # The epoch number is the text between the last '_' and the last '.'
        # of the checkpoint path.
        s_b = args.resume.rindex('_')
        s_e = args.resume.rindex('.')
        start_epoch = int(args.resume[s_b + 1:s_e]) + 1
        net_logger.info('Continue on {} Epoch'.format(start_epoch))
    else:
        start_epoch = 1

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net_model = net_model.cuda()

    net_pruner = KSevenPruner(
        net_model,
        input_shape=(height, width, 3),
        tensor_pruning_dependency=tensor_pruning_dependency,
        **build_param)

    net_logger.info('Start Pruning.')
    net_pruner.prune(pruning_tensor_cfg)
    net_logger.info('Pruning Complete.')

    net_model = torch.nn.DataParallel(net_model)
    if use_cuda:
        net_model = net_model.cuda()

    if args.optim == 'adamw':
        optimizer = optim.AdamW(net_model.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    elif args.optim == 'adam':
        optimizer = optim.Adam(net_model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(net_model.parameters(),
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    elif args.optim == 'sgd':
        optimizer = torch.optim.SGD(net_model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.weight_decay,
                                    momentum=0.9,
                                    nesterov=True)
    else:
        raise ValueError(f'Unknown optimizer type {args.optim}')

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    # Window of the most recent losses, used for the "Running loss" log.
    loss_hist = collections.deque(maxlen=500)

    def _enter_train_mode():
        # ``net_model`` is always a DataParallel wrapper at this point, so
        # freeze_bn() is called on the wrapped module after every train().
        net_model.train()
        net_model.module.freeze_bn()

    _enter_train_mode()

    net_logger.info('Num Training Images: {}'.format(len(dataset_train)))

    if args.validation_only:
        net_model.eval()
        test(dataset_valid, net_model, start_epoch - 1, args, net_logger)
        print('Validation Done.')
        # ``exit`` is a helper injected by the ``site`` module; raising
        # SystemExit directly has the same effect without relying on it.
        raise SystemExit(0)

    # Keeps the final ``test`` call well-defined when ``args.epochs`` is 0.
    epoch_num = start_epoch - 1
    for epoch_num in range(start_epoch, start_epoch + args.epochs):
        _enter_train_mode()

        # Log every 10 iterations during epoch 1, every 100 afterwards.
        log_period = 10 if epoch_num == 1 else 100

        epoch_loss = []
        for iter_num, data in enumerate(dataloader_train):
            optimizer.zero_grad()
            imgs = data['img']
            annot = data['annot']
            if args.num_gpus == 1 and use_cuda:
                # With a single GPU the batch is moved to the default CUDA
                # device here; otherwise placement is left to DataParallel.
                imgs = imgs.cuda()
                annot = annot.cuda()
            classification_loss, regression_loss = net_model(
                imgs, annot, return_loss=True)

            classification_loss = classification_loss.mean()
            regression_loss = regression_loss.mean()
            loss = classification_loss + regression_loss

            # A loss of exactly zero yields no update; skip the batch.
            if bool(loss == 0):
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(net_model.parameters(), 0.1)
            optimizer.step()

            loss_hist.append(float(loss))
            epoch_loss.append(float(loss))

            if iter_num % log_period == 0:
                _log = 'Epoch: {:>3} | Iter: {:>4} | Class loss: {:1.5f} | BBox loss: {:1.5f} | Running loss: {:1.5f}'.format(
                    epoch_num, iter_num, float(classification_loss),
                    float(regression_loss), np.mean(loss_hist))
                net_logger.info(_log)

            del classification_loss
            del regression_loss

        if epoch_num % args.valid_period == 0:
            # Switch to eval mode like the other two ``test`` call sites;
            # the next epoch re-enters train mode at the top of the loop.
            net_model.eval()
            test(dataset_valid, net_model, epoch_num, args, net_logger)

        if epoch_loss:
            scheduler.step(np.mean(epoch_loss))
        else:
            # Every batch was skipped (or the loader was empty); the mean of
            # an empty list would feed NaN into the scheduler.
            net_logger.warning(
                'Epoch {} produced no usable loss; skipping LR scheduler step.'
                .format(epoch_num))
        print('Learning Rate:',
              str([group['lr'] for group in optimizer.param_groups]))

    net_logger.info('Training Complete.')

    net_model.eval()
    test(dataset_valid, net_model, epoch_num, args, net_logger)
Esempio n. 7
0
def main():
    args = parser.parse_args()

    if args.pytorch_inference:
        assert args.architecture, 'Must provide --architecture when pytorch inference.'
        assert args.resume, 'Must provide --resume when pytorch inference.'
    if args.onnx_inference:
        assert args.onnx, 'Must provide --onnx when onnx inference.'

    if args.dataset == 'thermal':
        input_height, input_width = 60, 80
    elif args.dataset == '3s-pocket-thermal-face':
        input_height, input_width = 288, 384
    else:
        raise ValueError('unknow dataset.')
    transform = transforms.Compose([
        Normalizer(inference_mode=True),
        Resizer(height=input_height, width=input_width, inference_mode=True)
    ])
    my_img_preprocessor = K7ImagePreprocessor(
        mean=np.array([[[0.485, 0.456, 0.406]]]),
        std=np.array([[[0.229, 0.224, 0.225]]]),
        resize_height=input_height,
        resize_width=input_width)

    net_logger = logging.getLogger('ONNX Inference Logger')
    formatter = logging.Formatter(LOGGING_FORMAT)
    streamhandler = logging.StreamHandler()
    streamhandler.setFormatter(formatter)
    net_logger.addHandler(streamhandler)
    net_logger.setLevel(logging.INFO)

    net_logger.info('Positive Threshold: {:.2f}'.format(args.threshold))

    pytorch_model = None
    if args.pytorch_inference:
        net_logger.info('Build pytorch model...')
        build_param = {'logger': net_logger}
        if args.architecture == 'RetinaNet':
            pytorch_model = retinanet.retinanet(args.depth,
                                                num_classes=args.num_classes,
                                                **build_param)
        elif args.architecture == 'RetinaNet-Tiny':
            pytorch_model = retinanet.retinanet_tiny(
                num_classes=args.num_classes, **build_param)
        elif args.architecture == 'RetinaNet_P45P6':
            pytorch_model = retinanet.retinanet_p45p6(
                num_classes=args.num_classes, **build_param)
        else:
            raise ValueError('Architecture <{}> unknown.'.format(
                args.architecture))

        net_logger.info('Loading Weights from Checkpoint : {}'.format(
            args.resume))
        pytorch_model.load_state_dict(torch.load(args.resume))

        use_gpu = True

        if use_gpu:
            if torch.cuda.is_available():
                pytorch_model = pytorch_model.cuda()

        if torch.cuda.is_available():
            pytorch_model = torch.nn.DataParallel(pytorch_model).cuda()
        else:
            pytorch_model = torch.nn.DataParallel(pytorch_model)

        pytorch_model.eval()
        net_logger.info('Initialize Pytorch Model...  Finished')

    onnx_model = None
    if args.onnx_inference:
        net_logger.info('Build onnx model...')
        net_logger.info('Onnx loading...')
        # Load the ONNX model
        #model = onnx.load('./retinanet-tiny.onnx')
        onnx_model = onnx.load(args.onnx)
        print(type(onnx_model))

        # Check that the IR is well formed
        net_logger.info('Onnx checking...')
        onnx.checker.check_model(onnx_model)

        ort_session = onnxruntime.InferenceSession(args.onnx)
        net_logger.info('Onnx initialize... Done')

        ## Print a human readable representation of the graph
        #onnx.helper.printable_graph(onnx_model.graph)

        #rep = backend.prepare(onnx_model, device="CUDA:0") # or "CPU"
        ## For the Caffe2 backend:
        ##     rep.predict_net is the Caffe2 protobuf for the network
        ##     rep.workspace is the Caffe2 workspace for the network
        ##       (see the class caffe2.python.onnx.backend.Workspace)

    sample_files = os.listdir(args.sample_path)
    sample_files.sort()

    tensor_comp = TensorCompare()

    # for f in sample_files[:1]:
    for f in sample_files:
        if f.split('.')[-1] not in ['png', 'jpg']:
            continue
        print(f'loading image... {f}')

        img = Image.open(os.path.join(args.sample_path, f)).convert('RGB')
        a_img = np.array(img)
        a_img = a_img.astype(np.float32) / 255.0

        #t_img = transform(a_img)
        #t_img = torch.unsqueeze(t_img, 0)
        #t_img = t_img.permute(0, 3, 1, 2)

        a_img = my_img_preprocessor(a_img)
        a_img = np.expand_dims(a_img, axis=0)
        a_img = a_img.transpose((0, 3, 1, 2))

        if args.dump_sample_npz:
            _npz_name = os.path.join(args.sample_path,
                                     f.split('.')[0] + '_preprocess.npz')
            print(f'save npz to ... {_npz_name}')
            np.savez(_npz_name, input=a_img)

        #tt = tensor_comp.compare(to_numpy(t_img), a_img)
        #pdb.set_trace()

        net_logger.info('pytorch inference start ...')
        # scores, labels, boxes = pytorch_model(aa_img)
        pytorch_scores, pytorch_labels, pytorch_boxes = pytorch_model(
            torch.from_numpy(a_img))
        net_logger.info('inference done.')
        if args.compare_head_tensor:
            pytorch_classification, pytorch_regression = pytorch_model(
                torch.from_numpy(a_img), return_head=True)

        pytorch_scores = pytorch_scores.cpu()
        pytorch_labels = pytorch_labels.cpu()
        pytorch_boxes = pytorch_boxes.cpu()

        # change to (x, y, w, h) (MS COCO standard)
        pytorch_boxes[:, 2] -= pytorch_boxes[:, 0]
        pytorch_boxes[:, 3] -= pytorch_boxes[:, 1]

        if args.dataset == 'thermal':
            img = img.resize((80, 60))

        draw = ImageDraw.Draw(img)
        for box_id in range(pytorch_boxes.shape[0]):
            score = float(pytorch_scores[box_id])
            label = int(pytorch_labels[box_id])
            box = pytorch_boxes[box_id, :]

            # scores are sorted, so we can break
            if score < args.threshold:
                break

            x, y, w, h = box
            # draw.rectangle(tuple([x, y, x+w, y+h]), width = 2, outline =COLOR_LABEL[label])

        _img_save_path = os.path.join(
            args.sample_path,
            os.path.basename(f)[:-4] + '_inference_visual.jpg')

        #onnx_input = a_img
        #onnx_input = to_tensor(a_img)
        #onnx_input.unsqueeze_(0)
        print('Onnx Inference')
        print('Input.shape = {}'.format(str(a_img.shape)))
        print('start inference')
        # outputs = rep.run(a_img)
        # ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(onnx_input)}
        ort_inputs = {ort_session.get_inputs()[0].name: a_img}

        net_logger.info('pytorch inference start ...')
        ort_outs = ort_session.run(None, ort_inputs)
        net_logger.info('inference done.')

        # onnx_out = ort_outs[0]
        # To run networks with more than one input, pass a tuple
        # rather than a single numpy ndarray.

        #print(f'len(ort_outs) = {len(ort_outs)}')
        #for i, t in enumerate(ort_outs):
        #    print(f'out[{i}] type is {type(t)}')
        #    print(f'out[{i}] shape is {t.shape}')

        classification, regression = ort_outs

        anchors = pytorch_model.module.anchors(a_img)
        # print(anchors)
        # print(anchors.shape)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        transformed_anchors = pytorch_model.module.regressBoxes(
            anchors,
            torch.from_numpy(regression).float().to(device))
        transformed_anchors = pytorch_model.module.clipBoxes(
            transformed_anchors, a_img)
        transformed_anchors = to_numpy(transformed_anchors)

        #scores = torch.max(classification, dim=2, keepdim=True)[0]
        onnx_scores = np.max(classification, axis=2, keepdims=True)[0]

        # scores_over_thresh = (onnx_scores > 0.05)[0, :, 0]
        onnx_scores_over_thresh = (onnx_scores > 0.05)[:, 0]

        # print(f'onnx_scores_over_thresh.sum() = {onnx_scores_over_thresh.sum()}')
        if onnx_scores_over_thresh.sum() == 0:
            print('No boxes to NMS')
            # no boxes to NMS, just return
            # return [torch.zeros(0), torch.zeros(0), torch.zeros(0, 4)]

        classification = classification[:, onnx_scores_over_thresh, :]
        transformed_anchors = transformed_anchors[:,
                                                  onnx_scores_over_thresh, :]
        #onnx_scores = onnx_scores[:, onnx_scores_over_thresh, :]
        onnx_scores = onnx_scores[onnx_scores_over_thresh, :]

        anchors_nms_idx = nms(transformed_anchors[0, :, :], onnx_scores[:, 0],
                              0.5)
        # pdb.set_trace()
        onnx_nms_scores = classification[0, anchors_nms_idx, :].max(axis=1)
        onnx_nms_class = classification[0, anchors_nms_idx, :].argmax(axis=1)
        onnx_boxes = transformed_anchors[0, anchors_nms_idx, :]
        onnx_boxes[:, 2] -= onnx_boxes[:, 0]
        onnx_boxes[:, 3] -= onnx_boxes[:, 1]

        if args.compare_head_tensor:
            onnx_classification, onnx_regression = ort_outs
            net_logger.info('compare classification ...')
            _result = tensor_comp.compare(to_numpy(pytorch_classification),
                                          onnx_classification)
            print(f'Pass {_result[0]}')
            for k in _result[2]:
                print('{:>20} : {:<2.8f}'.format(k, _result[2][k]))
                if k == 'close_order':
                    print(_result)
            if not _result[0]:
                print('Not similar.')
                exit(0)
            net_logger.info('compare regression ...')
            _result = tensor_comp.compare(to_numpy(pytorch_regression),
                                          onnx_regression)
            print(f'Pass {_result[0]}')
            for k in _result[2]:
                print('{:>20} : {:<2.8f}'.format(k, _result[2][k]))