Example #1
    def __init__(self, config):
        super(TestProgram, self).__init__()
        self.config = config
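        # Backend priority: a TensorRT engine if 'trt_path' is set, otherwise an
        # ONNX Runtime session if 'onnx_path' is set, otherwise the native
        # PyTorch model.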
        if (config['infer']['trt_path'] is not None):
            engine = build_engine(config['infer']['onnx_path'],
                                  config['infer']['trt_path'],
                                  config['testload']['batch_size'])
            self.inputs, self.outputs, self.bindings, self.stream = allocate_buffers(
                engine)
            self.context = engine.create_execution_context()
            self.infer_type = 'trt_infer'

        elif (config['infer']['onnx_path'] is not None):
            self.model = onnxruntime.InferenceSession(
                config['infer']['onnx_path'])
            self.infer_type = 'onnx_infer'
        else:
            model = create_module(
                config['architectures']['model_function'])(config)
            model = load_model(model, config['infer']['model_path'])
            if torch.cuda.is_available():
                model = model.cuda()
            self.model = model
            self.model.eval()
            self.infer_type = 'torch_infer'

        img_process = create_module(config['postprocess']['function'])(config)
        self.img_process = img_process
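A minimal driver sketch for this constructor, assuming a hypothetical config dict: the key names below match the ones the code above reads, but the paths and "module,Class" strings are placeholders.

    config = {
        'infer': {'trt_path': None, 'onnx_path': 'det_model.onnx', 'model_path': None},
        'testload': {'batch_size': 1},
        'architectures': {'model_function': 'some_module,SomeModel'},   # hypothetical
        'postprocess': {'function': 'some_module,SomePostProcess'},     # hypothetical
    }
    program = TestProgram(config)
    print(program.infer_type)  # -> 'onnx_infer', since trt_path is None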
Example #2
    def __init__(self, config):
        super(TestProgram, self).__init__()
        self.config = config
        model = create_module(config['architectures']['model_function'])(config)
        img_process = create_module(config['postprocess']['function'])(config)
        model = load_model(model, config['infer']['model_path'])
        if torch.cuda.is_available():
            model = model.cuda()
        self.model = model
        self.img_process = img_process
        self.model.eval()
Example #3
    def __init__(self, config):
        super(TestProgram, self).__init__()

        self.converter = create_module(
            config['label_transform']['function'])(config)
        config['base']['classes'] = len(self.converter.alphabet)
        model = create_module(
            config['architectures']['model_function'])(config)
        model = load_model(model, config['infer']['model_path'])
        if torch.cuda.is_available():
            model = model.cuda()
        self.model = model
        self.config = config
        self.model.eval()
Example #4
def GetTeacherModel(args):
    config = yaml.load(open(args.t_config, 'r', encoding='utf-8'),
                       Loader=yaml.FullLoader)
    model = create_module(config['architectures']['model_function'])(config)
    model = load_model(model, args.t_model_path)
    if torch.cuda.is_available():
        model = model.cuda()
    return model
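A short usage sketch, assuming a hypothetical argparse-style namespace carrying the two attributes the function reads:

    from argparse import Namespace

    # Hypothetical paths; t_config and t_model_path are the only fields used.
    args = Namespace(t_config='configs/det_teacher.yaml',
                     t_model_path='checkpoints/teacher_best.pth.tar')
    teacher = GetTeacherModel(args)
    teacher.eval()  # teacher weights are usually frozen during distillation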
Example #5
def TrainValProgram(config):
    os.environ["CUDA_VISIBLE_DEVICES"] = config['base']['gpu_id']

    create_dir(config['base']['checkpoints'])
    checkpoints = os.path.join(
        config['base']['checkpoints'], "ag_%s_bb_%s_he_%s_bs_%d_ep_%d" %
        (config['base']['algorithm'],
         config['backbone']['function'].split(',')[-1],
         config['head']['function'].split(',')[-1],
         config['trainload']['batch_size'], config['base']['n_epoch']))
    create_dir(checkpoints)

    model = create_module(config['architectures']['model_function'])(config)
    criterion = create_module(config['architectures']['loss_function'])(config)
    train_dataset = create_module(config['trainload']['function'])(config)
    test_dataset = create_module(config['testload']['function'])(config)
    optimizer = create_module(config['optimizer']['function'])(config, model)
    optimizer_decay = create_module(config['optimizer_decay']['function'])
    img_process = create_module(config['postprocess']['function'])(config)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['trainload']['batch_size'],
        shuffle=True,
        num_workers=config['trainload']['num_workers'],
        worker_init_fn=worker_init_fn,
        drop_last=True,
        pin_memory=True)

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config['testload']['batch_size'],
        shuffle=False,
        num_workers=config['testload']['num_workers'],
        drop_last=True,
        pin_memory=True)

    loss_bin = create_loss_bin(config['base']['algorithm'])

    if torch.cuda.is_available():
        if (len(config['base']['gpu_id'].split(',')) > 1):
            model = torch.nn.DataParallel(model).cuda()
        else:
            model = model.cuda()
        criterion = criterion.cuda()

    start_epoch = 0
    recall, precision, hmean = 0, 0, 0
    best_recall, best_precision, best_hmean = 0, 0, 0

    if config['base']['restore']:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            config['base']['restore_file']), 'Error: no checkpoint file found!'
        checkpoint = torch.load(config['base']['restore_file'])
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_recall = checkpoint['recall']
        best_precision = checkpoint['precision']
        best_hmean = checkpoint['hmean']
        log_write = Logger(os.path.join(checkpoints, 'log.txt'),
                           title=config['base']['algorithm'],
                           resume=True)
    else:
        print('Training from scratch.')
        log_write = Logger(os.path.join(checkpoints, 'log.txt'),
                           title=config['base']['algorithm'])
        title = list(loss_bin.keys())
        title.extend([
            'pixel_acc', 'pixel_iou', 't_recall', 't_precision', 't_hmean',
            'b_recall', 'b_precision', 'b_hmean'
        ])
        log_write.set_names(title)

    for epoch in range(start_epoch, config['base']['n_epoch']):
        model.train()
        optimizer_decay(config, optimizer, epoch)
        loss_write = ModelTrain(train_data_loader, model, criterion, optimizer,
                                loss_bin, config, epoch)

        if (epoch >= config['base']['start_val']):
            create_dir(os.path.join(checkpoints, 'val'))
            create_dir(os.path.join(checkpoints, 'val', 'res_img'))
            create_dir(os.path.join(checkpoints, 'val', 'res_txt'))
            model.eval()
            recall, precision, hmean = ModelEval(test_dataset,
                                                 test_data_loader, model,
                                                 img_process, checkpoints,
                                                 config)
            print('recall:', recall, 'precision:', precision, 'hmean:', hmean)
            if (hmean > best_hmean):
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'lr': config['optimizer']['base_lr'],
                        'optimizer': optimizer.state_dict(),
                        'hmean': hmean,
                        'recall': recall,
                        'precision': precision
                    }, checkpoints,
                    config['base']['algorithm'] + '_best' + '.pth.tar')
                best_hmean = hmean
                best_precision = precision
                best_recall = recall

        loss_write.extend([
            recall, precision, hmean, best_recall, best_precision, best_hmean
        ])
        log_write.append(loss_write)
        for key in loss_bin.keys():
            loss_bin[key].loss_clear()
        if epoch % config['base']['save_epoch'] == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': config['optimizer']['base_lr'],
                    'optimizer': optimizer.state_dict(),
                    'hmean': 0,
                    'recall': 0,
                    'precision': 0
                }, checkpoints,
                config['base']['algorithm'] + '_' + str(epoch) + '.pth.tar')
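A minimal entry-point sketch for this training loop, assuming a hypothetical YAML file laid out with the keys read above:

    import yaml

    if __name__ == '__main__':
        with open('configs/det_train.yaml', 'r', encoding='utf-8') as f:  # hypothetical path
            config = yaml.load(f, Loader=yaml.FullLoader)
        TrainValProgram(config)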
Example #6
def prune(args):

    stream = open(args.config, 'r', encoding='utf-8')
    config = yaml.load(stream, Loader=yaml.FullLoader)

    img = cv2.imread(args.img_file)
    img = resize_image(img,
                       config['base']['algorithm'],
                       config['testload']['test_size'],
                       stride=config['testload']['stride'])
    img = Image.fromarray(img)
    img = img.convert('RGB')
    img = transforms.ToTensor()(img)
    img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])(img)
    img = img.cuda().unsqueeze(0)

    model = create_module(
        config['architectures']['model_function'])(config).cuda()
    model = load_model(model, args.checkpoint)

    model.eval()
    print(model)

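    # Pruning policy: drop the 50% of BatchNorm channels with the smallest
    # |gamma| network-wide, then round each layer's surviving channel count to
    # a multiple of base_num.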
    cut_percent = 0.5
    base_num = 4

    bn_weights = []
    for m in model.modules():
        if (isinstance(m, nn.BatchNorm2d)):
            bn_weights.append(m.weight.data.abs().clone())
    bn_weights = torch.cat(bn_weights, 0)

    sort_result, sort_index = torch.sort(bn_weights)

    thresh_index = int(cut_percent * bn_weights.shape[0])

    if (thresh_index == bn_weights.shape[0]):
        thresh_index = bn_weights.shape[0] - 1

    pruned = 0
    pruned_mask = []
    bn_index = []
    conv_index = []
    remain_channel_nums = []
    tag = 0
    for k, m in enumerate(model.modules()):
        if (tag > 187):
            break
        tag += 1
        if (isinstance(m, nn.BatchNorm2d)):
            bn_weight = m.weight.data.clone()
            mask = bn_weight.abs().gt(sort_result[thresh_index])
            remain_channel = mask.sum()

            if (remain_channel == 0):
                remain_channel = 1
                mask[int(torch.argmax(bn_weight))] = 1

            v = 0
            n = 1
            if (remain_channel % base_num != 0):
                if (remain_channel > base_num):
                    while (v < remain_channel):
                        n += 1
                        v = base_num * n
                    if (remain_channel - (v - base_num) < v - remain_channel):
                        remain_channel = v - base_num
                    else:
                        remain_channel = v
                    if (remain_channel > bn_weight.size()[0]):
                        remain_channel = bn_weight.size()[0]
                    remain_channel = torch.tensor(remain_channel)
                    result, index = torch.sort(bn_weight)
                    mask = bn_weight.abs().ge(result[-remain_channel])

            remain_channel_nums.append(int(mask.sum()))
            pruned_mask.append(mask)
            bn_index.append(k)
            pruned += mask.shape[0] - mask.sum()
        elif (isinstance(m, nn.Conv2d)):
            conv_index.append(k)

    print('remain_channel_nums', remain_channel_nums)
    print('total_prune_ratio:', float(pruned) / bn_weights.shape[0])
    print('bn_index', bn_index)
    print('conv_index', conv_index)

    new_model = create_module(
        config['architectures']['model_function'])(config).cuda()

    keys = {}
    tag = 0
    for k, m in enumerate(new_model.modules()):
        if (isinstance(m, ptocr.model.backbone.det_mobilev3.Block)):
            keys[tag] = k
            tag += 1
    print(keys)
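    # Steps 1-5 below take the union ("|") of the keep-masks of BatchNorm
    # layers sitting at fixed offsets inside the det_mobilev3 Blocks, so that
    # every layer in a merged group keeps exactly the same set of channels.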
    #### step 1
    mg_1 = np.array([-3, 7, 16])
    block_idx = keys[0]
    tag = 0
    for idx in mg_1 + block_idx:
        if (tag == 0):
            msk = pruned_mask[bn_index.index(idx)]
        else:
            msk = msk | pruned_mask[bn_index.index(idx)]
        tag += 1
        print('step1', idx)
    print(msk.sum())
    for idx in mg_1 + block_idx:
        pruned_mask[bn_index.index(idx)] = msk
    msk_1 = msk.clone()

    #### step 2
    block_idx2 = np.array([keys[1], keys[2]])
    mg_2 = 7
    tag = 0
    for idx in mg_2 + block_idx2:
        print('step2', idx)
        if (tag == 0):
            msk = pruned_mask[bn_index.index(idx)]
        else:
            msk = msk | pruned_mask[bn_index.index(idx)]
        tag += 1
    for idx in mg_2 + block_idx2:
        pruned_mask[bn_index.index(idx)] = msk
    print(msk.sum())
    msk_2 = msk.clone()

    ####step 3
    block_idx3s = [keys[3], keys[4], keys[5]]
    mg_3 = np.array([7, 16])
    tag = 0
    for block_idx3 in block_idx3s:
        for idx in block_idx3 + mg_3:
            print('step3', idx)
            if (tag == 0):
                msk = pruned_mask[bn_index.index(idx)]
            else:
                msk = msk | pruned_mask[bn_index.index(idx)]
            tag += 1
    for block_idx3 in block_idx3s:
        for idx in block_idx3 + mg_3:
            pruned_mask[bn_index.index(idx)] = msk
    print(msk.sum())
    msk_3 = msk.clone()

    ####step 4_1
    block_idx4_all = []

    block_idx4 = keys[6]

    mg_4 = np.array([7, 16])
    block_idx4_all.extend((block_idx4 + mg_4).tolist())

    ####step 4_2
    block_idx4 = keys[7]
    mg_4 = np.array([7, 16])
    block_idx4_all.extend((block_idx4 + mg_4).tolist())
    tag = 0

    for idx in block_idx4_all:
        print('step4', idx)
        if (tag == 0):
            msk = pruned_mask[bn_index.index(idx)]
        else:
            msk = msk | pruned_mask[bn_index.index(idx)]
        tag += 1

    for idx in block_idx4_all:
        pruned_mask[bn_index.index(idx)] = msk
    print(msk.sum())
    msk_4 = msk.clone()

    ####step 5
    block_idx5s = [keys[8], keys[9], keys[10]]
    mg_5 = np.array([7, 16])
    tag = 0
    for block_idx5 in block_idx5s:
        for idx in block_idx5 + mg_5:
            if (tag == 0):
                msk = pruned_mask[bn_index.index(idx)]
            else:
                msk = msk | pruned_mask[bn_index.index(idx)]
            tag += 1

    for block_idx5 in block_idx5s:
        for idx in block_idx5 + mg_5:
            pruned_mask[bn_index.index(idx)] = msk
    print(msk.sum())
    msk_5 = msk.clone()

    group_index = []
    spl_index = []
    for i in range(11):
        block_idx6 = keys[i]
        tag = 0
        mg_6 = np.array([2, 5])
        for idx in mg_6 + block_idx6:
            if (tag == 0):
                msk = pruned_mask[bn_index.index(idx)]
            else:
                msk = msk | pruned_mask[bn_index.index(idx)]
            tag += 1
        for idx in mg_6 + block_idx6:
            pruned_mask[bn_index.index(idx)] = msk
        if (i == 6):
            spl_index.extend([block_idx6 + 9, block_idx6 - 2])
        group_index.append(block_idx6 + 4)
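    # Rewrite new_model's layer shapes (conv in/out channels and groups, BN
    # num_features) according to the pruning masks collected above.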
    count_conv = 0
    count_bn = 0
    conv_in_mask = [torch.ones(3)]
    conv_out_mask = []
    bn_mask = []
    tag = 0
    for k, m in enumerate(new_model.modules()):
        if (tag > 187):
            break
        if isinstance(m, nn.Conv2d):

            if (tag in group_index):
                m.groups = int(pruned_mask[bn_index.index(tag + 1)].sum())
            m.out_channels = int(pruned_mask[count_conv].sum())
            conv_out_mask.append(pruned_mask[count_conv])
            if (count_conv > 0):
                if (tag == spl_index[0]):
                    m.in_channels = int(pruned_mask[bn_index.index(
                        spl_index[1])].sum())
                    conv_in_mask.append(pruned_mask[bn_index.index(
                        spl_index[1])])
                else:
                    m.in_channels = int(pruned_mask[count_conv - 1].sum())
                    conv_in_mask.append(pruned_mask[count_conv - 1])

            count_conv += 1
        elif isinstance(m, nn.BatchNorm2d):
            m.num_features = int(pruned_mask[count_bn].sum())
            bn_mask.append(pruned_mask[count_bn])
            count_bn += 1
        tag += 1

    bn_i = 0
    conv_i = 0
    model_i = 0
    scale = [188, 192, 196, 200]
    scale_mask = [msk_5, msk_4, msk_3, msk_2]
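    # Copy surviving weights from the original model into the pruned one.
    # Modules up to index 187 are rebuilt channel-by-channel from the masks;
    # the convs listed in `scale` remap their input channels with the merged
    # stage masks, and everything else past index 187 is copied unchanged.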
    for m0, m1 in zip(model.modules(), new_model.modules()):
        if (model_i > 187):
            if isinstance(m0, nn.Conv2d):
                if (model_i in scale):
                    index = scale.index(model_i)
                    m1.in_channels = int(scale_mask[index].sum())
                    idx0 = np.squeeze(
                        np.argwhere(np.asarray(
                            scale_mask[index].cpu().numpy())))
                    idx1 = np.squeeze(
                        np.argwhere(np.asarray(torch.ones(96).cpu().numpy())))
                    if idx0.size == 1:
                        idx0 = np.resize(idx0, (1, ))
                    if idx1.size == 1:
                        idx1 = np.resize(idx1, (1, ))
                    w = m0.weight.data[:, idx0, :, :].clone()
                    m1.weight.data = w[idx1, :, :, :].clone()
                    if m1.bias is not None:
                        m1.bias.data = m0.bias.data[idx1].clone()

                else:
                    m1.weight.data = m0.weight.data.clone()
                    if m1.bias is not None:
                        m1.bias.data = m0.bias.data.clone()

            elif isinstance(m0, nn.BatchNorm2d):
                m1.weight.data = m0.weight.data.clone()
                if m1.bias is not None:
                    m1.bias.data = m0.bias.data.clone()
                m1.running_mean = m0.running_mean.clone()
                m1.running_var = m0.running_var.clone()
        else:
            if isinstance(m0, nn.BatchNorm2d):
                idx1 = np.squeeze(
                    np.argwhere(np.asarray(bn_mask[bn_i].cpu().numpy())))
                if idx1.size == 1:
                    idx1 = np.resize(idx1, (1, ))
                m1.weight.data = m0.weight.data[idx1].clone()
                if m1.bias is not None:
                    m1.bias.data = m0.bias.data[idx1].clone()
                m1.running_mean = m0.running_mean[idx1].clone()
                m1.running_var = m0.running_var[idx1].clone()
                bn_i += 1
            elif isinstance(m0, nn.Conv2d):
                if (isinstance(conv_in_mask[conv_i], list)):
                    idx0 = np.squeeze(
                        np.argwhere(
                            np.asarray(
                                torch.cat(conv_in_mask[conv_i],
                                          0).cpu().numpy())))
                else:
                    idx0 = np.squeeze(
                        np.argwhere(
                            np.asarray(conv_in_mask[conv_i].cpu().numpy())))
                idx1 = np.squeeze(
                    np.argwhere(np.asarray(
                        conv_out_mask[conv_i].cpu().numpy())))
                if idx0.size == 1:
                    idx0 = np.resize(idx0, (1, ))
                if idx1.size == 1:
                    idx1 = np.resize(idx1, (1, ))
                if (model_i in group_index):
                    m1.weight.data = m0.weight.data[idx1, :, :, :].clone()
                    if m1.bias is not None:
                        m1.bias.data = m0.bias.clone()
                else:
                    w = m0.weight.data[:, idx0, :, :].clone()
                    m1.weight.data = w[idx1, :, :, :].clone()
                    if m1.bias is not None:
                        m1.bias.data = m0.bias.data[idx1].clone()
                conv_i += 1
        model_i += 1

    print(new_model)
    new_model.eval()
    with torch.no_grad():
        out = new_model(img)
    print(out.shape)
    cv2.imwrite('re1.jpg', out[0, 0].cpu().numpy() * 255)

    save_obj = {'pruned_mask': pruned_mask, 'bn_index': bn_index}
    torch.save(save_obj,
               os.path.join(args.save_prune_model_path, 'pruned_dict.dict'))
    torch.save(new_model.state_dict(),
               os.path.join(args.save_prune_model_path, 'pruned_dict.pth.tar'))
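The two files saved above can be reloaded later to rebuild the same pruned architecture; a minimal sketch, assuming the same paths:

    pruned = torch.load(os.path.join(args.save_prune_model_path, 'pruned_dict.dict'))
    pruned_mask, bn_index = pruned['pruned_mask'], pruned['bn_index']
    # ...reconstruct a model with these masks as above, then restore its weights:
    # new_model.load_state_dict(torch.load(
    #     os.path.join(args.save_prune_model_path, 'pruned_dict.pth.tar')))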
Example #7
def TrainValProgram(args):

    config = yaml.load(open(args.config, 'r', encoding='utf-8'),
                       Loader=yaml.FullLoader)
    config = merge_config(config, args)

    os.environ["CUDA_VISIBLE_DEVICES"] = config['base']['gpu_id']

    create_dir(config['base']['checkpoints'])
    checkpoints = os.path.join(
        config['base']['checkpoints'], "ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" %
        (config['base']['algorithm'],
         config['backbone']['function'].split(',')[-1],
         config['head']['function'].split(',')[-1],
         config['trainload']['batch_size'], config['base']['n_epoch'],
         args.log_str))
    create_dir(checkpoints)

    LabelConverter = create_module(
        config['label_transform']['function'])(config)
    config['base']['classes'] = len(LabelConverter.alphabet)
    model = create_module(config['architectures']['model_function'])(config)
    criterion = create_module(config['architectures']['loss_function'])(config)
    train_dataset = create_module(config['trainload']['function'])(config)
    test_dataset = create_module(config['testload']['function'])(config)
    optimizer = create_module(config['optimizer']['function'])(config, model)
    optimizer_decay = create_module(config['optimizer_decay']['function'])

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['trainload']['batch_size'],
        shuffle=True,
        num_workers=config['trainload']['num_workers'],
        worker_init_fn=worker_init_fn,
        collate_fn=alignCollate(),
        drop_last=True,
        pin_memory=True)

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config['testload']['batch_size'],
        shuffle=False,
        num_workers=config['testload']['num_workers'],
        collate_fn=alignCollate(),
        drop_last=True,
        pin_memory=True)

    loss_bin = create_loss_bin(config['base']['algorithm'])

    if torch.cuda.is_available():
        if (len(config['base']['gpu_id'].split(',')) > 1):
            model = torch.nn.DataParallel(model).cuda()
        else:
            model = model.cuda()
        criterion = criterion.cuda()

    start_epoch = 0
    val_acc = 0
    val_loss = 0
    best_acc = 0

    if config['base']['restore']:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            config['base']['restore_file']), 'Error: no checkpoint file found!'
        checkpoint = torch.load(config['base']['restore_file'])
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_acc = checkpoint['best_acc']
        log_write = Logger(os.path.join(checkpoints, 'log.txt'),
                           title=config['base']['algorithm'],
                           resume=True)
    else:
        print('Training from scratch.')
        log_write = Logger(os.path.join(checkpoints, 'log.txt'),
                           title=config['base']['algorithm'])
        title = list(loss_bin.keys())
        title.extend(['val_loss', 'test_acc', 'best_acc'])
        log_write.set_names(title)

    for epoch in range(start_epoch, config['base']['n_epoch']):
        model.train()
        optimizer_decay(config, optimizer, epoch)
        loss_write = ModelTrain(train_data_loader, LabelConverter, model,
                                criterion, optimizer, loss_bin, config, epoch)
        if (epoch >= config['base']['start_val']):
            model.eval()
            val_acc, val_loss = ModelEval(test_data_loader, LabelConverter,
                                          model, criterion, config)
            print('val_acc:', val_acc, 'val_loss:', val_loss)
            if (val_acc > best_acc):
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'lr': config['optimizer']['base_lr'],
                        'optimizer': optimizer.state_dict(),
                        'best_acc': val_acc
                    }, checkpoints,
                    config['base']['algorithm'] + '_best' + '.pth.tar')
                best_acc = val_acc

        loss_write.extend([val_loss, val_acc, best_acc])
        log_write.append(loss_write)
        for key in loss_bin.keys():
            loss_bin[key].loss_clear()
        if epoch % config['base']['save_epoch'] == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': config['optimizer']['base_lr'],
                    'optimizer': optimizer.state_dict(),
                    'best_acc': 0
                }, checkpoints,
                config['base']['algorithm'] + '_' + str(epoch) + '.pth.tar')
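A minimal command-line wrapper sketch, assuming hypothetical flags matching the attributes read above (merge_config may expect more):

    import argparse

    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--config', required=True)   # hypothetical flags
        parser.add_argument('--log_str', default='run1')
        args = parser.parse_args()
        TrainValProgram(args)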
Example #8
def gen_onnx(args):
    stream = open(args.config, 'r', encoding='utf-8')
    config = yaml.load(stream, Loader=yaml.FullLoader)

    model = create_module(config['architectures']['model_function'])(config)

    model = model.cuda()
    model = load_model(model, args.model_path)
    model.eval()

    print('load model ok.....')

    img = cv2.imread(args.img_path)
    img = cv2.resize(img, (1280, 768))

    img1 = Image.fromarray(img).convert('RGB')
    img1 = transforms.ToTensor()(img1)
    img1 = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224,
                                     0.225])(img1).unsqueeze(0).cuda()

    s = time.time()
    out = model(img1)
    print('cost time:', time.time() - s)
    if isinstance(out, dict):
        out = out['f_score']

    cv2.imwrite('./onnx/ori_output.jpg',
                out[0, 0].cpu().detach().numpy() * 255)

    output_onnx = args.save_path
    print("==> Exporting model to ONNX format at '{}'".format(output_onnx))
    input_names = ["input"]
    output_names = ["out"]
    inputs = torch.randn(1, 3, 768, 1280).cuda()
    torch.onnx.export(model,
                      inputs,
                      output_onnx,
                      export_params=True,
                      verbose=False,
                      do_constant_folding=False,
                      keep_initializers_as_inputs=True,
                      input_names=input_names,
                      output_names=output_names)

    onnx_path = args.save_path
    session = onnxruntime.InferenceSession(onnx_path)

    image = img / 255.0
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = (image - mean) / std
    image = np.transpose(image, [2, 0, 1])
    image = np.expand_dims(image, axis=0)
    image = image.astype(np.float32)

    s = time.time()
    preds = session.run(['out'], {'input': image})
    preds = preds[0]
    print(time.time() - s)
    cv2.imwrite('./onnx/onnx_output.jpg', preds[0, 0] * 255)

    print('error_distance:',
          np.abs(out.cpu().detach().numpy() - preds).mean())
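As a sanity check on the exported file, the graph can also be validated with the onnx checker; a short sketch, assuming the onnx package is installed:

    import onnx

    onnx_model = onnx.load(args.save_path)
    onnx.checker.check_model(onnx_model)  # raises if the graph is malformed
    print('onnx graph check passed')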