def getscore2(model_path, saved=False, maxnum=float('inf'), top_k=741):
    """Score the polygon model on whole KITTI images against binary masks.

    Every file matched by ``/data/duye/KITTI/image/*`` is resized to
    224x224 and passed through ``PolygonModel`` in test mode. The predicted
    polygon is scaled back to the original resolution, rasterised, and
    compared with the mask of the same base name in
    ``/data/duye/KITTI/label/`` (any non-zero mask pixel is foreground).

    Args:
        model_path: Checkpoint to load into the model, or ``None`` to keep
            the freshly constructed weights.
        saved: When true, overlay each predicted polygon on its image and
            write the result to ``/data/duye/save_dir/``. An image whose
            polygon cannot be drawn (``TypeError``) is skipped entirely.
        maxnum: Stop once this many images have been scored.
        top_k: How many of the highest per-image IoUs are averaged into
            ``true_iou``. Defaults to 741, the value that used to be
            hard-coded; ``None`` averages every image.

    Returns:
        Tuple ``(iou_score, nu, de, true_iou)``. ``iou_score`` is never
        updated and is always ``0.0``. ``nu`` and ``de`` are the
        intersection and union pixel counts summed over all scored images.
        ``true_iou`` is the mean of the ``top_k`` best per-image IoUs, or
        ``0.0`` when nothing was scored.
    """
    model = PolygonModel(predict_delta=True).to(devices)
    if model_path is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(model_path, map_location=devices))
        print('Model loaded!')
    model.eval()

    iou_score = 0.
    nu = 0.  # accumulated intersection
    de = 0.  # accumulated union
    count = 0
    files = glob.glob('/data/duye/KITTI/image/*')  # every input image
    iouss = []
    trans = transforms.Compose([
        transforms.ToTensor(),
    ])
    for idx, f in enumerate(files):
        # Close the file handle as soon as the pixels have been decoded.
        with Image.open(f) as src:
            image = src.convert('RGB')
        W = image.width
        H = image.height
        scaleH = 224.0 / float(H)
        scaleW = 224.0 / float(W)

        # The network input is the whole image squeezed to 224x224.
        img_new = image.resize((224, 224), Image.BILINEAR)
        img_new = trans(img_new).unsqueeze(0)

        with torch.no_grad():
            result_dict = model(img_new.to(devices),
                                None,
                                None,
                                mode='test',
                                temperature=0.0)

        # Index 0: the mini-batch holds a single sample.
        pred_x = result_dict['final_pred_x'].cpu().numpy()[0]
        pred_y = result_dict['final_pred_y'].cpu().numpy()[0]
        pred_lengths = result_dict['lengths'].cpu().numpy()[0]
        pred_len = np.sum(pred_lengths) - 1  # drop the EOS token

        # Undo the resize so the vertices are in original pixel coordinates.
        vertices1 = [(pred_x[i] / scaleW, pred_y[i] / scaleH)
                     for i in range(pred_len)]

        if saved:
            # Random RGB colour with a fixed alpha of 100 for the overlay.
            color = tuple([np.random.randint(0, 255) for _ in range(3)] +
                          [100])
            try:
                drw = ImageDraw.Draw(image, 'RGBA')
                drw.polygon(vertices1, color)
            except TypeError:
                continue

        # Ground-truth mask shares the image's base name.
        gt_name = '/data/duye/KITTI/label/' + f.split(
            '/')[-1][:-4] + '.png'
        with Image.open(gt_name) as gt_img:
            gt_mask = np.array(gt_img)  # (H, W)
        # Any non-zero label counts as foreground.
        gt_mask = gt_mask != 0

        # Rasterise the prediction at the original resolution.
        img1 = Image.new('L', (W, H), 0)
        ImageDraw.Draw(img1).polygon(vertices1, outline=1, fill=1)
        pre_mask = np.array(img1)  # (H, W)

        inter = np.sum(np.logical_and(gt_mask, pre_mask))
        union = np.sum(np.logical_or(gt_mask, pre_mask))
        # Accumulate instead of overwriting so the returned totals cover the
        # whole run, matching the Rooftop variant of this function.
        nu += inter
        de += union
        iouss.append(inter / (union * 1.0) if union != 0 else 0.)

        count += 1
        print(count)
        if saved:
            print('saving test result image...')
            save_result_dir = '/data/duye/save_dir/'
            image.save(save_result_dir + str(idx) + '_pred_rooftop.png', 'PNG')
        if count >= maxnum:
            break

    iouss.sort(reverse=True)
    print(iouss)
    best = iouss[:top_k]
    # Guard the empty case: np.mean([]) would return NaN with a warning.
    true_iou = np.mean(np.array(best)) if best else 0.

    return iou_score, nu, de, true_iou
# Example no. 2
# 0
if __name__ == '__main__':
    print('start calculating...')

    # load_model = 'New_FPN2_DeltaModel_Epoch6-Step6000_ValIoU0.5855123247072112.pth'
    # load_model = 'New_FPN2_DeltaModel_Epoch5-Step4000_ValIoU0.6130178256915678.pth'
    # load_model = 'FPN_Epoch9-Step6000_ValIoU0.22982870834472033.pth'
    # load_model = 'Joint_FPN2_DeltaModel_Epoch7-Step3000_ValIoU0.6156470167768288.pth'   # 现在最好的是61.65
                                                                                    # 把val_every设置为500应该还可以提一下
    # load_model = 'Joint_FPN2_DeltaModel_Epoch9-Step6000_ValIoU0.6177457155898146.pth'
    load_model = 'ResNext_Plus_DeltaModel_Epoch7-Step3000_ValIoU0.619344842105648.pth'
    polynet_pretrained = '/data/duye/pretrained_models/' + load_model
    net = PolygonModel(predict_delta=True).to(device)
    net.load_state_dict(torch.load(polynet_pretrained))
    # Test mode
    net.eval()
    print('Pretrained model \'{}\' loaded!'.format(load_model))
    ious_test, less2_test, nu_test, de_test, iou_mean_class, iou_mean = get_score(net, maxnum=20, saved=True)
    print('Origin iou:', ious_test)
    print('True   iou:', iou_mean_class)
    print('Mean: ', iou_mean)
    """
    iou_mean_test = 0.
    for i in ious_test:
        iou_mean_test += ious_test[i]
    ious_val, less2_val, nu_val, de_val = get_score(net, dataset='val', saved=False)
    print('PRINT, VAL:', ious_val)
    nu_total = {}
    de_total = {}
    iou_total = {}
    for cls in selected_classes:  # init
# Example no. 3
# 0
def getscore2(model_path, dataset='Rooftop', saved=False, maxnum=float('inf')):
    """Score the polygon model on the Rooftop aerial-imagery test split.

    Each ``.mat`` file under ``/data/duye/Aerial_Imagery/Rooftop/test/``
    supplies ground-truth polygons (``data['gt'][0]``) for the ``.JPG`` of
    the same base name. Every polygon is cropped with a 10% margin, resized
    to 224x224 and passed through ``PolygonModel`` in test mode; the
    prediction is mapped back to image coordinates and compared with the
    ground truth through the module-level ``iou`` helper.

    Args:
        model_path: Checkpoint to load into the model, or ``None`` to keep
            the freshly constructed weights.
        dataset: Unused; kept for interface compatibility.
        saved: When true, overlay predictions and ground truth on two copies
            of the image and write them to ``/data/duye/save_dir/``. An
            instance whose prediction cannot be drawn (``TypeError``) is
            skipped.
        maxnum: Stop once this many ``.mat`` files have been processed.

    Returns:
        Tuple ``(iou_score, nu, de, true_iou)``. ``iou_score`` is never
        updated and is always ``0.0``. ``nu`` and ``de`` are the summed
        second and third values returned by ``iou``. ``true_iou`` is the
        mean per-instance IoU, or ``0.0`` when no instance was scored.
    """
    model = PolygonModel(predict_delta=True).to(devices)
    if model_path is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(model_path, map_location=devices))
        print('Model loaded!')
    model.eval()

    iou_score = 0.
    nu = 0.  # accumulated intersection
    de = 0.  # accumulated union
    count = 0
    files = glob.glob('/data/duye/Aerial_Imagery/Rooftop/test/*.mat')
    iouss = []
    for idx, f in enumerate(files):
        data = scio.loadmat(f)
        # The image sits next to the .mat file with a JPG extension.
        img_f = f[:-3] + 'JPG'
        with Image.open(img_f) as src:
            image = src.convert('RGB')
        img_gt = image.copy()  # separate canvas for the ground-truth overlay
        # Pixel snapshot taken before any overlay is drawn on `image`.
        I = np.array(image)
        W = image.width
        H = image.height

        instances = data['gt'][0]
        for instance_id in range(instances.shape[0]):
            # np.float was removed in NumPy 1.24; the builtin is equivalent.
            polygon = np.array(instances[instance_id], dtype=float)
            if len(polygon) < 3:
                continue

            # Bounding box of the polygon, grown by 10% and clipped to the
            # image.
            minW, minH = np.min(polygon, axis=0)
            maxW, maxH = np.max(polygon, axis=0)
            extendrate = 0.10
            extendW = (maxW - minW) * extendrate
            extendH = (maxH - minH) * extendrate
            leftW = int(np.maximum(minW - extendW, 0))
            leftH = int(np.maximum(minH - extendH, 0))
            rightW = int(np.minimum(maxW + extendW, W))
            rightH = int(np.minimum(maxH + extendH, H))
            objectW = rightW - leftW
            objectH = rightH - leftH

            # Drop objects that are too large or too small.
            if objectW >= 150 or objectH >= 150:
                continue
            if objectW <= 20 or objectH <= 20:
                continue

            scaleH = 224.0 / float(objectH)
            scaleW = 224.0 / float(objectW)

            # Crop, resize to 224x224, and convert to a (1, C, H, W) float
            # tensor in [0, 1].
            I_obj_img = Image.fromarray(I[leftH:rightH, leftW:rightW, :])
            I_obj_img = I_obj_img.resize((224, 224), Image.BILINEAR)
            I_obj_new = np.array(I_obj_img).transpose(2, 0, 1) / 255.0
            # Follow the model's device rather than assuming CUDA.
            I_obj_tensor = torch.from_numpy(I_obj_new).unsqueeze(
                0).float().to(devices)

            with torch.no_grad():
                result_dict = model(I_obj_tensor, None, None,
                                    mode='test',
                                    temperature=0.0)
            pred_x = result_dict['final_pred_x'].cpu().numpy()[0]
            pred_y = result_dict['final_pred_y'].cpu().numpy()[0]
            pred_lengths = result_dict['lengths'].cpu().numpy()[0]
            pred_len = np.sum(pred_lengths) - 1  # drop the EOS token

            # Map the prediction from crop space back to image space.
            vertices1 = [(pred_x[i] / scaleW + leftW,
                          pred_y[i] / scaleH + leftH)
                         for i in range(pred_len)]
            if len(vertices1) < 3:
                continue

            vertices2 = [(points[0], points[1]) for points in polygon]

            if saved:
                # Random RGB colour with a fixed alpha of 100, shared by the
                # prediction and ground-truth overlays of this instance.
                color = tuple([np.random.randint(0, 255) for _ in range(3)] +
                              [100])
                try:
                    drw = ImageDraw.Draw(image, 'RGBA')
                    drw.polygon(vertices1, color)
                except TypeError:
                    continue
                drw_gt = ImageDraw.Draw(img_gt, 'RGBA')
                drw_gt.polygon(vertices2, color)

            tmp, nu_cur, de_cur = iou(vertices1, vertices2, H, W)
            nu += nu_cur
            de += de_cur
            iouss.append(tmp)

        count += 1
        if saved:
            print('saving test result image...')
            save_result_dir = '/data/duye/save_dir/'
            image.save(save_result_dir + str(idx) + '_pred_rooftop.png', 'PNG')
            img_gt.save(save_result_dir + str(idx) + '_gt_rooftop.png', 'PNG')
        if count >= maxnum:
            break

    iouss.sort(reverse=True)
    # Guard the empty case: np.mean([]) would return NaN with a warning.
    true_iou = np.mean(np.array(iouss)) if iouss else 0.
    print(iouss)
    return iou_score, nu, de, true_iou
# Example no. 4
# 0
def getscore_kitti(model_path, saved=False, maxnum=float('inf'), top_k=741):
    """Score the polygon model on KITTI instances cropped from bbox files.

    For every ``/data/duye/KITTI/rawImage/*.png`` that has both a bbox text
    file and an annotation PNG, each bbox line (``x y w h`` separated by
    single spaces) is grown by 8%, cropped, resized to 224x224 and passed
    through ``PolygonModel`` in test mode. The prediction is rasterised at
    full resolution and compared with the annotation pixels whose value
    equals the 1-based line number of the bbox.

    Args:
        model_path: Checkpoint to load into the model, or ``None`` to keep
            the freshly constructed weights.
        saved: When true, overlay each prediction on the in-memory image.
            An instance whose polygon cannot be drawn (``TypeError``) is
            skipped. NOTE(review): this function never writes the overlaid
            image to disk.
        maxnum: Stop once this many images have been processed (images
            without a bbox or annotation file do not count).
        top_k: How many of the highest per-instance IoUs are averaged into
            ``true_iou``. Defaults to 741, the value that used to be
            hard-coded; ``None`` averages every instance.

    Returns:
        Tuple ``(iou_score, nu, de, true_iou)``. ``iou_score`` is never
        updated and is always ``0.0``. ``nu`` and ``de`` are the
        intersection and union pixel counts summed over all instances.
        ``true_iou`` is the mean of the ``top_k`` best IoUs, or ``0.0`` when
        nothing was scored.
    """
    model = PolygonModel(predict_delta=True).to(devices)
    if model_path is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(model_path, map_location=devices))
        print('Model loaded!')
    model.eval()

    iou_score = 0.
    nu = 0.  # accumulated intersection
    de = 0.  # accumulated union
    count = 0
    files = glob.glob('/data/duye/KITTI/rawImage/*.png')  # every raw image
    bbox = '/data/duye/KITTI/bbox/'
    annotation = '/data/duye/KITTI/annotation/'
    print(len(files))
    iouss = []
    for idx, f in enumerate(files):
        print('index:', idx)
        name = f.split('/')[-1][:-4]  # e.g. 000019
        bd = bbox + name + '.txt'
        if not os.path.exists(bd):
            continue
        sss = annotation + name + '.png'
        if not os.path.exists(sss):
            continue

        with Image.open(f) as src:
            image = src.convert('RGB')  # raw image
        W = image.width
        H = image.height
        # Pixel snapshot taken before any overlay is drawn on `image`.
        I = np.array(image)

        with Image.open(sss) as anno_img:
            anno = np.array(anno_img)

        # Read the boxes up front so the file is not held open during
        # inference.
        with open(bd, 'r') as bbd:
            bbox_lines = bbd.readlines()

        for number, line in enumerate(bbox_lines):
            fields = line.replace('\n', '').split(' ')
            xx, yy, ww, hh = (float(fields[k]) for k in range(4))
            # Boxes with any zero component are skipped.
            if 0.0 in (xx, yy, ww, hh):
                continue

            # Grow the box by 8% on every side and clip it to the image.
            extendrate = 0.08
            extendW = int(round(ww * extendrate))
            extendH = int(round(hh * extendrate))
            leftW = int(np.maximum(xx - extendW, 0))
            leftH = int(np.maximum(yy - extendH, 0))
            rightW = int(np.minimum(xx + ww + extendW, W))
            rightH = int(np.minimum(yy + hh + extendH, H))
            # Size of the crop, used to rescale predicted coordinates.
            objectW = rightW - leftW
            objectH = rightH - leftH
            if objectW <= 0 or objectH <= 0:
                # A box outside the image gives an empty crop and would
                # otherwise divide by zero below.
                continue
            scaleH = 224.0 / float(objectH)
            scaleW = 224.0 / float(objectW)

            # Crop, resize to 224x224, and convert to a (1, C, H, W) float
            # tensor in [0, 1].
            I_obj_img = Image.fromarray(I[leftH:rightH, leftW:rightW, :])
            I_obj_img = I_obj_img.resize((224, 224), Image.BILINEAR)
            I_obj_new = np.array(I_obj_img).transpose(2, 0, 1) / 255.0
            I_obj_tensor = torch.from_numpy(I_obj_new).unsqueeze(
                0).float().to(devices)

            with torch.no_grad():
                result_dict = model(I_obj_tensor,
                                    None,
                                    None,
                                    mode='test',
                                    temperature=0.0)
            # Index 0: the mini-batch holds a single sample.
            pred_x = result_dict['final_pred_x'].cpu().numpy()[0]
            pred_y = result_dict['final_pred_y'].cpu().numpy()[0]
            pred_lengths = result_dict['lengths'].cpu().numpy()[0]
            pred_len = np.sum(pred_lengths) - 1  # drop the EOS token

            # Map the prediction from crop space back to image space.
            vertices1 = [(pred_x[i] / scaleW + leftW,
                          pred_y[i] / scaleH + leftH)
                         for i in range(pred_len)]
            if saved:
                # Random RGB colour with a fixed alpha of 100.
                color = tuple([np.random.randint(0, 255) for _ in range(3)] +
                              [100])
                try:
                    drw = ImageDraw.Draw(image, 'RGBA')
                    drw.polygon(vertices1, color)
                except TypeError:
                    continue

            # Predicted mask at full image resolution.
            img1 = Image.new('L', (W, H), 0)
            ImageDraw.Draw(img1).polygon(vertices1, outline=1, fill=1)
            pre_mask = np.array(img1)  # (H, W)

            # Ground-truth mask for this instance.
            # NOTE (original author): using the bbox line number as the
            # instance id like this is not correct.
            cur_anno = np.array(anno == number + 1, dtype=int)

            inter = np.sum(np.logical_and(cur_anno, pre_mask))
            union = np.sum(np.logical_or(cur_anno, pre_mask))
            # Accumulate so the returned totals cover the whole run.
            nu += inter
            de += union
            iouss.append(inter / (union * 1.0) if union != 0 else 0.)

        # Honour `maxnum` the same way the sibling scorers do.
        count += 1
        if count >= maxnum:
            break

    iouss.sort(reverse=True)
    print(iouss)

    # Guard the empty case: np.mean([]) would return NaN with a warning.
    print(np.mean(np.array(iouss)) if iouss else 0.)

    best = iouss[:top_k]
    true_iou = np.mean(np.array(best)) if best else 0.

    return iou_score, nu, de, true_iou
def get_score_ADE20K(saved=False, maxnum=float('inf')):
    """Print the mean IoU of a fixed checkpoint on ADE20K validation masks.

    Every label PNG under ``/data/duye/ADE20K/val_new/label/`` is paired
    with a text file ``img_<index>.txt`` whose first line holds the source
    image path. The object (pixels equal to the label's maximum value) is
    cropped with a 10% margin, resized to 224x224 and passed through
    ``PolygonModel`` in test mode; the rasterised prediction is compared
    with the label mask. Labels whose mask sum is at most ``20 * 20`` are
    skipped.

    Args:
        saved: When true, overlay the prediction on the in-memory image. A
            sample whose polygon cannot be drawn (``TypeError``) is skipped.
            NOTE(review): the overlaid image is never written to disk.
        maxnum: Unused; kept for interface compatibility.

    Returns:
        None. The sorted IoU list, its length and its mean are printed.
    """
    model = PolygonModel(predict_delta=True).to(devices)
    pre = 'ResNext_Plus_RL2_retain_Epoch1-Step4000_ValIoU0.6316584628283326.pth'
    dirs = '/data/duye/pretrained_models/FPNRLtrain/' + pre
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    model.load_state_dict(torch.load(dirs, map_location=devices))
    model.eval()

    # Named `ious` so the module-level `iou` helper is not shadowed.
    ious = []
    print('starting.....')
    img_PATH = '/data/duye/ADE20K/validation/'
    lbl_path = '/data/duye/ADE20K/val_new/label/*.png'
    for name in glob.glob(lbl_path):
        # The split runs over the full path, so the index depends on the
        # underscore in 'val_new' as well as on the file name itself.
        label_index = name.split('_')[2]
        # Matching text file; its first line is the source image path.
        txt_file = '/data/duye/ADE20K/val_new/img/img_' + label_index + '.txt'
        with open(txt_file, "r") as f:
            img_path = f.readline().replace('\n', '')
        # Drop the first 36 characters of the stored path and re-root it
        # (presumably the original dataset prefix — confirm).
        img_path = img_PATH + img_path[36:]

        with Image.open(img_path) as src:
            img = src.convert('RGB')
        W = img.width
        H = img.height

        with Image.open(name) as label_img:
            label = np.array(label_img)  # (H, W)

        # Filter small objects before doing any expensive work.
        gt_mask = np.array(label)
        gt_mask[gt_mask == 255] = 1
        if np.sum(gt_mask) <= 20 * 20:
            continue

        # Bounding box of the pixels carrying the maximum label value,
        # grown by 10% and clipped to the image.
        Hs, Ws = np.where(label == np.max(label))
        minH = np.min(Hs)
        maxH = np.max(Hs)
        minW = np.min(Ws)
        maxW = np.max(Ws)
        extendrate = 0.10
        extendW = int(round((maxW - minW) * extendrate))
        extendH = int(round((maxH - minH) * extendrate))
        leftW = np.maximum(minW - extendW, 0)
        leftH = np.maximum(minH - extendH, 0)
        rightW = np.minimum(maxW + extendW, W)
        rightH = np.minimum(maxH + extendH, H)
        objectW = rightW - leftW
        objectH = rightH - leftH
        if objectW <= 0 or objectH <= 0:
            # A one-pixel-wide object gives an empty crop; skip it rather
            # than crash on the resize or divide by zero.
            continue

        # Crop, resize to 224x224, and convert to a (1, C, H, W) float
        # tensor in [0, 1].
        I = np.array(img)
        I_obj_img = Image.fromarray(I[leftH:rightH, leftW:rightW, :])
        I_obj_img = I_obj_img.resize((224, 224), Image.BILINEAR)
        I_obj_new = np.array(I_obj_img).transpose(2, 0, 1) / 255.0
        # Follow the model's device rather than assuming CUDA.
        I_obj_tensor = torch.from_numpy(I_obj_new).unsqueeze(
            0).float().to(devices)

        with torch.no_grad():
            result_dict = model(I_obj_tensor,
                                None,
                                None,
                                mode='test',
                                temperature=0.0)

        # Index 0: the mini-batch holds a single sample.
        pred_x = result_dict['final_pred_x'].cpu().numpy()[0]
        pred_y = result_dict['final_pred_y'].cpu().numpy()[0]
        pred_lengths = result_dict['lengths'].cpu().numpy()[0]
        pred_len = np.sum(pred_lengths) - 1  # drop the EOS token

        # Map the prediction from crop space back to image space.
        scaleW = 224.0 / float(objectW)
        scaleH = 224.0 / float(objectH)
        vertices1 = [(pred_x[i] / scaleW + leftW, pred_y[i] / scaleH + leftH)
                     for i in range(pred_len)]
        img1 = Image.new('L', (W, H), 0)
        ImageDraw.Draw(img1).polygon(vertices1, outline=1, fill=1)
        pre_mask = np.array(img1)  # (H, W)

        if saved:
            # Random RGB colour with a fixed alpha of 100.
            color = tuple([np.random.randint(0, 255) for _ in range(3)] +
                          [100])
            try:
                drw = ImageDraw.Draw(img, 'RGBA')
                drw.polygon(vertices1, color)
            except TypeError:
                continue

        nu = np.sum(np.logical_and(gt_mask, pre_mask))
        de = np.sum(np.logical_or(gt_mask, pre_mask))
        ious.append(nu / (de * 1.0) if de != 0 else 0.)

    ious.sort(reverse=True)

    print(ious)
    print(len(ious))

    # Guard the empty case: np.mean([]) would return NaN with a warning.
    print('IoU:', np.mean(np.array(ious)) if ious else 0.)
                            drop_last=False)
    print('DataLoader complete!', dataloader)
    return dataloader

# Evaluate the test score
devices = 'cuda' if torch.cuda.is_available() else 'cpu'
if __name__ == '__main__':
    parse = argparse.ArgumentParser(description='测试在ssTEM上的泛化得分')
    parse.add_argument('-p', '--pretrained', type=str, default=None)
    args = parse.parse_args()
    pre = args.pretrained
    model = PolygonModel(predict_delta=True).to(devices)
    pre = 'ResNext_Plus_RL2_retain_Epoch1-Step4000_ValIoU0.6316584628283326.pth'
    dirs = '/data/duye/pretrained_models/FPNRLtrain/' + pre
    model.load_state_dict(torch.load(dirs))
    model.eval()
    loader = loadssTEM(batch_size=8)

    iou = []
    for index, batch in enumerate(loader):
        print('index: ', index)
        img = batch[0]
        WH = batch[-1]  # WH_dict
        left_WH = WH['left_WH']
        origion_WH = WH['origion_WH']
        object_WH = WH['object_WH']
        gt = batch[1]

        bs = img.shape[0]
        with torch.no_grad():
            pre_v2 = None
def _labels_to_vertices(labels, scaleW, scaleH, leftW, leftH):
    """Decode 28x28 grid class indices into image-space (x, y) vertices.

    Each label addresses one cell of a 28x28 grid laid over the 224x224
    crop; the cell centre (``index * 8 + 4``) is rescaled and shifted back
    into the original image. Decoding stops at the first end-of-sequence
    label (``28 * 28``).
    """
    vertices = []
    for label in labels:
        if label == 28 * 28:
            break
        vertices.append((((label % 28) * 8.0 + 4) / scaleW + leftW,
                         ((int(label / 28)) * 8.0 + 4) / scaleH + leftH))
    return vertices


def train(config, load_resnet50=False, pre_trained=None, cur_epochs=0):
    """Train ``PolygonModel`` with smoothed cross-entropy and validate by IoU.

    Args:
        config: Mapping read for the keys ``'batch_size'``, ``'lr'``,
            ``'epoch'``, ``'weight_decay'``, ``'lr_decay'`` (a
            ``(step_size, gamma)`` pair for ``StepLR``) and ``'val_every'``
            (validation interval in steps). An optional ``'grad_clip'`` key
            gives the maximum gradient norm.
        load_resnet50: Forwarded to ``PolygonModel`` as
            ``load_predtrained_resnet50``.
        pre_trained: Optional checkpoint path to resume from.
        cur_epochs: Epoch index to start from; training runs for epochs
            ``cur_epochs`` .. ``config['epoch'] - 1``.

    Returns:
        None. Progress is printed, and from the sixth epoch on a checkpoint
        is written to ``/data/duye/pretrained_models/`` after every
        validation pass.
    """
    batch_size = config['batch_size']
    lr = config['lr']
    epochs = config['epoch']

    train_dataloader = loadData('train', 16, 71, batch_size)
    val_loader = loadData('val', 16, 71, batch_size, shuffle=False)
    model = PolygonModel(load_predtrained_resnet50=load_resnet50,
                         predict_delta=False).cuda()
    # Resume from a checkpoint if one is given.
    if pre_trained is not None:
        model.load_state_dict(torch.load(pre_trained))
        print('loaded pretrained polygon net!')

    # Weight decay (not used in the original paper) is applied to everything
    # except batch-norm, convLSTM and bias parameters.
    no_wd = []
    wd = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            # Frozen parameters are not optimised.
            continue
        if 'bn' in name or 'convLSTM' in name or 'bias' in name:
            no_wd.append(param)
        else:
            wd.append(param)

    optimizer = optim.Adam([{
        'params': no_wd,
        'weight_decay': 0.0
    }, {
        'params': wd
    }],
                           lr=lr,
                           weight_decay=config['weight_decay'],
                           amsgrad=False)

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=config['lr_decay'][0],
                                          gamma=config['lr_decay'][1])

    print('Total Epochs:', epochs)
    for it in range(cur_epochs, epochs):
        # Sum of the loss over the current 20-step reporting window.
        running_loss = 0.0
        for index, batch in enumerate(train_dataloader):
            # as_tensor avoids the copy-construct warning when the loader
            # already yields tensors.
            img = torch.as_tensor(batch[0], dtype=torch.float).cuda()
            bs = img.shape[0]
            pre_v2 = torch.as_tensor(batch[2], dtype=torch.float).cuda()
            pre_v1 = torch.as_tensor(batch[3], dtype=torch.float).cuda()
            outdict = model(img, pre_v2, pre_v1, mode='train_ce')

            # (bs, seq_len, 28*28+1) -> (bs*seq_len, 28*28+1)
            out = outdict['logits'].contiguous().view(-1, 28 * 28 + 1)

            # Smooth the class targets. np.int was removed in NumPy 1.24;
            # the builtin int is the same dtype.
            target = dt_targets_from_class(np.array(batch[4], dtype=int), 28,
                                           2)  # (bs, seq_len, 28*28+1)
            target = torch.from_numpy(target).cuda().contiguous().view(
                -1, 28 * 28 + 1)  # (bs*seq_len, 28*28+1)

            # Mask that is non-zero up to the end-of-sequence point.
            mask_final = torch.as_tensor(batch[6]).cuda().reshape(-1)

            # Cross-entropy against the smoothed target distribution.
            loss_lstm = torch.sum(-target *
                                  torch.nn.functional.log_softmax(out, dim=1),
                                  dim=1)  # (bs*seq_len)
            # Ignore positions after the end point.
            loss_lstm = loss_lstm * mask_final.type_as(loss_lstm)
            loss_lstm = loss_lstm.view(bs, -1).sum(dim=1)  # (bs,)
            # Clamp so an all-zero mask row cannot divide by zero.
            real_pointnum = mask_final.reshape(bs, -1).sum(dim=1).type_as(
                loss_lstm).clamp(min=1.0)
            loss_lstm = loss_lstm / real_pointnum  # mean over seq_len
            loss = torch.mean(loss_lstm)  # mean over batch
            # TODO: this loss suits train_ce; train_rl could rewrite the
            # loss from the conditional probabilities.

            model.zero_grad()
            loss.backward()

            # Clip only after backward(), when gradients exist. The check
            # previously used the misspelled key 'grid_clip' and ran before
            # backward(), so it never clipped anything.
            if 'grad_clip' in config:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config['grad_clip'])

            optimizer.step()

            # Store a Python float rather than the loss tensor.
            running_loss += loss.item()

            # Report the mean loss every 20 steps.
            if (index + 1) % 20 == 0:
                print('Epoch {} - Step {}, loss_total {}'.format(
                    it + 1, index, running_loss / 20))
                running_loss = 0.0

            # Validate every `val_every` steps.
            if (index + 1) % config['val_every'] == 0:
                model.eval()
                val_IoU = []
                less_than2 = 0
                with torch.no_grad():
                    for val_index, val_batch in enumerate(val_loader):
                        img = torch.as_tensor(val_batch[0],
                                              dtype=torch.float).cuda()
                        bs = img.shape[0]

                        # Crop geometry needed to map grid cells back into
                        # the original image.
                        WH = val_batch[-1]
                        left_WH = WH['left_WH']
                        origion_WH = WH['origion_WH']
                        object_WH = WH['object_WH']

                        out_dict = model(img, mode='test')
                        # (bs, seq_len) predicted class indices
                        val_result_index = out_dict['pred_polys'].cpu(
                        ).numpy()
                        val_target = val_batch[4].numpy()  # (bs, seq_len)

                        for ii in range(bs):
                            scaleW = 224.0 / object_WH[0][ii]
                            scaleH = 224.0 / object_WH[1][ii]
                            leftW = left_WH[0][ii]
                            leftH = left_WH[1][ii]
                            vertices1 = _labels_to_vertices(
                                val_result_index[ii], scaleW, scaleH, leftW,
                                leftH)
                            vertices2 = _labels_to_vertices(
                                val_target[ii], scaleW, scaleH, leftW, leftH)
                            if len(vertices1) < 2:
                                # Too few points for a polygon: score as 0.
                                less_than2 += 1
                                val_IoU.append(0.)
                                continue
                            _, nu_cur, de_cur = iou(
                                vertices1, vertices2, origion_WH[1][ii],
                                origion_WH[0][ii])  # (H, W)
                            iou_cur = nu_cur * 1.0 / de_cur if de_cur != 0 else 0
                            val_IoU.append(iou_cur)

                # Guard the empty case: np.mean([]) would return NaN.
                val_iou_data = np.mean(np.array(val_IoU)) if val_IoU else 0.
                print('Validation After Epoch {} - step {}'.format(
                    str(it + 1), str(index + 1)))
                print('           IoU      on validation set: ', val_iou_data)
                print('less than 2: ', less_than2)
                if it > 4:  # checkpoints are written from the sixth epoch on
                    print('Saving training parameters after this epoch:')
                    torch.save(
                        model.state_dict(),
                        '/data/duye/pretrained_models/ResNext50_FPN_LSTM_Epoch{}-Step{}_ValIoU{}.pth'
                        .format(str(it + 1), str(index + 1),
                                str(val_iou_data)))
                # Back to training mode for the next step.
                model.train()

        # Learning-rate decay once per epoch.
        scheduler.step()
        print()
        print('Epoch {} Completed!'.format(str(it + 1)))
        print()