Esempio n. 1
0
    return iou_score, less2, nu, de, iou_mean_class, iou_mean


if __name__ == '__main__':
    # Script entry point: load one pretrained PolygonModel checkpoint, run
    # get_score on it and print the IoU figures it returns.
    print('start calculating...')

    # Checkpoints that were evaluated earlier, kept for reference.
    # load_model = 'New_FPN2_DeltaModel_Epoch6-Step6000_ValIoU0.5855123247072112.pth'
    # load_model = 'New_FPN2_DeltaModel_Epoch5-Step4000_ValIoU0.6130178256915678.pth'
    # load_model = 'FPN_Epoch9-Step6000_ValIoU0.22982870834472033.pth'
    # load_model = 'Joint_FPN2_DeltaModel_Epoch7-Step3000_ValIoU0.6156470167768288.pth'   # best result so far: 61.65
                                                                                    # setting val_every to 500 should improve it a little more
    # load_model = 'Joint_FPN2_DeltaModel_Epoch9-Step6000_ValIoU0.6177457155898146.pth'
    load_model = 'ResNext_Plus_DeltaModel_Epoch7-Step3000_ValIoU0.619344842105648.pth'
    polynet_pretrained = '/data/duye/pretrained_models/' + load_model
    net = PolygonModel(predict_delta=True).to(device)
    net.load_state_dict(torch.load(polynet_pretrained))
    # Test mode
    net.eval()
    print('Pretrained model \'{}\' loaded!'.format(load_model))
    # get_score is unpacked into six values here; only three of them are printed.
    ious_test, less2_test, nu_test, de_test, iou_mean_class, iou_mean = get_score(net, maxnum=20, saved=True)
    print('Origin iou:', ious_test)
    print('True   iou:', iou_mean_class)
    print('Mean: ', iou_mean)
    """
    iou_mean_test = 0.
    for i in ious_test:
        iou_mean_test += ious_test[i]
    ious_val, less2_val, nu_val, de_val = get_score(net, dataset='val', saved=False)
    print('PRINT, VAL:', ious_val)
    nu_total = {}
    de_total = {}
class OnLineTrainer:
    """Online self-critical RL fine-tuning of a ``PolygonModel``.

    Every batch drives two optimizer updates: first the main decoder is
    trained while the delta model is in eval mode, then the delta model is
    trained while the main decoder is in eval mode. Parameters whose name
    contains ``'encoder'`` are frozen and the encoders stay in eval mode.

    Args:
        num_workers: Passed to ``loadData`` as the batch size.
        update_every: Stored on the instance; not used by this class.
        save_every: A checkpoint is written every ``save_every`` batches.
        t1: Sampling ``temperature`` used when training the main decoder.
        t2: Sampling ``temperature2`` used when training the delta model.
        pre: Optional path of a state dict to load into the model.
    """

    def __init__(self, num_workers=8, update_every=8, save_every=10, t1=0.1, t2=0.2, pre=None):
        self.num_workers = num_workers
        self.update_every = update_every
        self.save_every = save_every
        self.t1 = t1
        self.t2 = t2
        self.model = PolygonModel(predict_delta=True).to(devices)
        if pre is not None:
            self.model.load_state_dict(torch.load(pre))
        self.dataloader = loadData(data_num=16,
                                   batch_size=self.num_workers,
                                   len_s=71,
                                   path='val',
                                   shuffle=False)
        self.model.encoder.eval()
        self.model.delta_encoder.eval()
        for name, param in self.model.named_parameters():
            if 'encoder' in name:
                param.requires_grad = False
        # Only the parameters that are still trainable go to the optimizer.
        # (Selecting the frozen ones instead would make every step a no-op.)
        self.train_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(self.train_params,
                                    lr=2e-6,
                                    amsgrad=False)

    @staticmethod
    def _polygon_iou(pred, gt, height, width):
        """Return the IoU of ``pred`` against ``gt``; 0 when ``pred`` has < 2 vertices."""
        if len(pred) < 2:
            return 0.
        score, _, _ = iou(pred, gt, height, width)
        return score

    def _batch_ious(self, batch, bs, sample_out, greedy_out):
        """Compute per-sample IoU of the sampled and the greedy polygons.

        Predictions and ground truth are mapped from the 224x224 crop back to
        image coordinates before being handed to ``iou``.

        Returns:
            Tuple ``(sampling_iou, greedy_iou)`` of float32 arrays of length ``bs``.
        """
        wh = batch[-1]  # WH_dict
        left_wh = wh['left_WH']
        origin_wh = wh['origion_WH']
        object_wh = wh['object_WH']
        gt_polys = batch[-2].numpy()  # presumably (bs, 70, 2) - TODO confirm
        gt_mask = batch[7]  # presumably (bs, 70) - TODO confirm

        sample_x = sample_out['final_pred_x'].cpu().numpy()
        sample_y = sample_out['final_pred_y'].cpu().numpy()
        sample_len = sample_out['lengths'].cpu().numpy()
        greedy_x = greedy_out['final_pred_x'].cpu().numpy()
        greedy_y = greedy_out['final_pred_y'].cpu().numpy()
        greedy_len = greedy_out['lengths'].cpu().numpy()

        sampling_iou = np.zeros(bs, dtype=np.float32)
        greedy_iou = np.zeros(bs, dtype=np.float32)
        for ii in range(bs):
            scale_w = 224.0 / object_wh[0][ii]
            scale_h = 224.0 / object_wh[1][ii]
            left_w = left_wh[0][ii]
            left_h = left_wh[1][ii]

            # The mask marks the valid ground-truth vertices.
            n_gt = np.sum(gt_mask[ii].numpy())
            gt = [(vert[0] / scale_w + left_w, vert[1] / scale_h + left_h)
                  for vert in gt_polys[ii][:n_gt]]
            # The last predicted step is dropped (length - 1).
            sam = [(sample_x[ii][j] / scale_w + left_w,
                    sample_y[ii][j] / scale_h + left_h)
                   for j in range(sample_len[ii] - 1)]
            gre = [(greedy_x[ii][j] / scale_w + left_w,
                    greedy_y[ii][j] / scale_h + left_h)
                   for j in range(greedy_len[ii] - 1)]

            height = origin_wh[1][ii]
            width = origin_wh[0][ii]
            sampling_iou[ii] = self._polygon_iou(sam, gt, height, width)
            greedy_iou[ii] = self._polygon_iou(gre, gt, height, width)
        return sampling_iou, greedy_iou

    def _rl_update(self, img, batch, sample_kwargs, greedy_kwargs, logprob_key):
        """Run one self-critical update and return ``(loss, mean greedy IoU)``.

        ``sample_kwargs`` / ``greedy_kwargs`` are the temperature keyword
        arguments for the sampled (with gradient) and greedy (no gradient)
        forward passes; ``logprob_key`` names the log-probability entry of the
        sampled output that feeds the loss.
        """
        sample_out = self.model(img, mode='train_rl', **sample_kwargs)
        # greedy
        with torch.no_grad():
            greedy_out = self.model(img, mode='train_rl', **greedy_kwargs)

        sampling_iou, greedy_iou = self._batch_ious(batch, img.shape[0],
                                                    sample_out, greedy_out)
        loss = losses.self_critical_loss(sample_out[logprob_key], sample_out['lengths'],
                                         torch.from_numpy(sampling_iou).to(devices),
                                         torch.from_numpy(greedy_iou).to(devices))
        self.model.zero_grad()
        loss.backward()
        # Clip after backward: before it there are no gradients to clip.
        nn.utils.clip_grad_norm_(self.model.parameters(), 40)
        self.optimizer.step()
        # .item() keeps the logged value from holding on to the autograd graph.
        return loss.item(), np.mean(greedy_iou)

    def train(self):
        """Make one pass over the dataloader, updating both decoders per batch.

        Prints the loss and mean greedy IoU after each update and writes a
        checkpoint every ``save_every`` batches.
        """
        for step, batch in enumerate(self.dataloader):
            img = torch.tensor(batch[0], dtype=torch.float).to(devices)

            # Step 1: train the main decoder.
            self.model.delta_model.eval()
            self.model.decoder.train()
            loss_value, model_iou = self._rl_update(
                img, batch,
                sample_kwargs={'temperature': self.t1, 'temperature2': 0.0},
                greedy_kwargs={'temperature': 0.0},
                logprob_key='log_probs')
            print('Update {}, RL training of Main decoder, loss {}, model IoU {}'.format(step + 1,
                                                                             loss_value,
                                                                             model_iou))

            # Step 2: train the delta model.
            self.model.decoder.eval()
            self.model.delta_model.train()
            loss_value, model_iou = self._rl_update(
                img, batch,
                sample_kwargs={'temperature': 0.0, 'temperature2': self.t2},
                greedy_kwargs={'temperature': 0.0, 'temperature2': 0.0},
                logprob_key='delta_logprob')
            print('Update {}, RL training of Second decoder, loss {}, model IoU {}'.format(step + 1,
                                                                             loss_value,
                                                                             model_iou))

            if (step + 1) % self.save_every == 0:
                print('Saving training parameters after Updating...')
                save_dir = '/data/duye/pretrained_models/OnLineTraining/'
                os.makedirs(save_dir, exist_ok=True)
                torch.save(self.model.state_dict(), save_dir + str(step+1) + '.pth')
def getscore2(model_path, saved=False, maxnum=float('inf'), top_k=741):
    """Evaluate a ``PolygonModel`` on the images in ``/data/duye/KITTI/image``.

    Each image is resized to 224x224 and passed through the model in test
    mode; the predicted polygon is scaled back to the image size, rasterised
    and compared with the binarised label image of the same name from
    ``/data/duye/KITTI/label``.

    Predictions with fewer than two vertices cannot be rasterised and are
    skipped (they do not count towards ``maxnum``), whatever ``saved`` is.

    Args:
        model_path: State dict to load, or ``None`` to keep the initial weights.
        saved: When true, draw the prediction on the image and save it under
            ``/data/duye/save_dir/``.
        maxnum: Stop after this many evaluated images.
        top_k: Number of best per-image IoUs averaged into the returned score;
            ``None`` averages all of them.

    Returns:
        Tuple ``(iou_score, nu, de, true_iou)``: ``iou_score`` is always 0,
        ``nu`` / ``de`` are the intersection / union pixel counts summed over
        the evaluated images, ``true_iou`` is the mean of the ``top_k`` highest
        per-image IoUs.
    """
    model = PolygonModel(predict_delta=True).to(devices)
    if model_path is not None:
        model.load_state_dict(torch.load(model_path))
        print('Model loaded!')
    # set to eval
    model.eval()
    iou_score = 0.  # never updated; kept so the return tuple keeps its shape
    nu = 0.  # Intersection, summed over images
    de = 0.  # Union, summed over images
    count = 0
    files = glob.glob('/data/duye/KITTI/image/*')  # every image
    iouss = []
    trans = transforms.Compose([
        transforms.ToTensor(),
    ])
    for idx, f in enumerate(files):
        # convert() returns a loaded copy, so the file can be closed right away.
        with Image.open(f) as raw:
            image = raw.convert('RGB')
        W = image.width
        H = image.height
        scaleH = 224.0 / float(H)
        scaleW = 224.0 / float(W)
        # Resize the whole image to the 224x224 network input.
        img_new = trans(image.resize((224, 224), Image.BILINEAR)).unsqueeze(0)
        color = tuple([np.random.randint(0, 255) for _ in range(3)] + [100])

        with torch.no_grad():
            pre_v2 = None
            pre_v1 = None
            result_dict = model(img_new.to(devices),
                                pre_v2,
                                pre_v1,
                                mode='test',
                                temperature=0.0)

        # index 0: only one sample in the mini-batch here
        pred_x = result_dict['final_pred_x'].cpu().numpy()[0]
        pred_y = result_dict['final_pred_y'].cpu().numpy()[0]
        pred_lengths = result_dict['lengths'].cpu().numpy()[0]
        pred_len = np.sum(pred_lengths) - 1  # sub EOS

        # Predicted polygon, scaled back to the original image size.
        vertices1 = [(pred_x[i] / scaleW, pred_y[i] / scaleH)
                     for i in range(pred_len)]
        if len(vertices1) < 2:
            # PIL raises TypeError for fewer than two coordinates.
            continue

        if saved:
            ImageDraw.Draw(image, 'RGBA').polygon(vertices1, color)

        # Ground truth: any non-zero label pixel is foreground.
        gt_name = '/data/duye/KITTI/label/' + f.split(
            '/')[-1][:-4] + '.png'
        with Image.open(gt_name) as gt_img:
            gt_mask = np.array(gt_img) > 0  # (H, W)

        # Rasterise the prediction and compute the IoU.
        img1 = Image.new('L', (W, H), 0)
        ImageDraw.Draw(img1).polygon(vertices1, outline=1, fill=1)
        pre_mask = np.array(img1)  # (H, W)
        inter = np.sum(np.logical_and(gt_mask, pre_mask))
        union = np.sum(np.logical_or(gt_mask, pre_mask))
        nu += inter
        de += union
        iouss.append(inter / (union * 1.0) if union != 0 else 0.)

        count += 1
        print(count)
        if saved:
            print('saving test result image...')
            save_result_dir = '/data/duye/save_dir/'
            image.save(save_result_dir + str(idx) + '_pred_rooftop.png', 'PNG')
        if count >= maxnum:
            break

    iouss.sort(reverse=True)
    print(iouss)
    true_iou = np.mean(np.array(iouss[:top_k]))

    return iou_score, nu, de, true_iou
Esempio n. 4
0
def getscore_kitti(model_path, saved=False, maxnum=float('inf'), top_k=741):
    """Evaluate a ``PolygonModel`` per bounding box on the KITTI raw images.

    For every image in ``/data/duye/KITTI/rawImage`` that has both a bbox
    text file and an annotation PNG, each bbox line (``x y w h``) is enlarged
    by 8 %, cropped, resized to 224x224 and passed through the model. The
    predicted polygon is mapped back to image coordinates, rasterised and
    compared with the annotation pixels equal to ``line_index + 1``.

    Lines that are blank, have a zero in any of the four fields, or produce
    an empty crop are skipped, as are predictions with fewer than two
    vertices (PIL cannot rasterise them).

    Args:
        model_path: State dict to load, or ``None`` to keep the initial weights.
        saved: When true, draw each prediction onto the in-memory image
            (this function does not write the image to disk).
        maxnum: Stop after this many images have been evaluated.
        top_k: Number of best IoUs averaged into the returned score;
            ``None`` averages all of them.

    Returns:
        Tuple ``(iou_score, nu, de, true_iou)``: ``iou_score`` is always 0,
        ``nu`` / ``de`` are the summed intersection / union pixel counts,
        ``true_iou`` is the mean of the ``top_k`` highest IoUs.
    """
    model = PolygonModel(predict_delta=True).to(devices)
    if model_path is not None:
        model.load_state_dict(torch.load(model_path))
        print('Model loaded!')
    # set to eval
    model.eval()
    iou_score = 0.  # never updated; kept so the return tuple keeps its shape
    nu = 0.  # Intersection, summed over objects
    de = 0.  # Union, summed over objects
    count = 0
    files = glob.glob('/data/duye/KITTI/rawImage/*.png')  # every image
    bbox = '/data/duye/KITTI/bbox/'
    annotation = '/data/duye/KITTI/annotation/'
    print(len(files))
    iouss = []
    for idx, f in enumerate(files):
        print('index:', idx)
        name = f.split('/')[-1][:-4]  # e.g. 000019
        bd = bbox + name + '.txt'
        anno_path = annotation + name + '.png'
        # Both the bbox file and the annotation are required.
        if not os.path.exists(bd) or not os.path.exists(anno_path):
            continue

        with Image.open(f) as raw:
            image = raw.convert('RGB')  # raw image
        W = image.width
        H = image.height
        I = np.array(image)
        with Image.open(anno_path) as anno_img:
            anno = np.array(anno_img)

        with open(bd, 'r') as bbd:
            for number, line in enumerate(bbd):
                fields = line.split()
                if len(fields) < 4:
                    # blank or malformed line
                    continue
                xx, yy, ww, hh = (float(v) for v in fields[:4])
                if 0.0 in (xx, yy, ww, hh):
                    continue
                # Enlarge the box by 8 % of its size on every side.
                extendrate = 0.08
                extendW = int(round(ww * extendrate))
                extendH = int(round(hh * extendrate))
                leftW = int(np.maximum(xx - extendW, 0))
                leftH = int(np.maximum(yy - extendH, 0))
                rightW = int(np.minimum(xx + ww + extendW, W))
                rightH = int(np.minimum(yy + hh + extendH, H))
                # Size of the crop, used to scale coordinates.
                objectW = rightW - leftW
                objectH = rightH - leftH
                if objectW <= 0 or objectH <= 0:
                    # empty crop: nothing to resize and the scale would divide by zero
                    continue
                scaleH = 224.0 / float(objectH)
                scaleW = 224.0 / float(objectW)

                # Crop, resize to 224x224 and build a (1, C, H, W) float tensor in [0, 1].
                I_obj_img = Image.fromarray(I[leftH:rightH, leftW:rightW, :])
                I_obj_img = I_obj_img.resize((224, 224), Image.BILINEAR)
                I_obj_new = np.array(I_obj_img).transpose(2, 0, 1) / 255.0
                I_obj_tensor = torch.from_numpy(I_obj_new).unsqueeze(0).float().to(devices)

                color = tuple([np.random.randint(0, 255) for _ in range(3)] + [100])

                with torch.no_grad():
                    pre_v2 = None
                    pre_v1 = None
                    result_dict = model(I_obj_tensor,
                                        pre_v2,
                                        pre_v1,
                                        mode='test',
                                        temperature=0.0)
                # index 0: only one sample in the mini-batch here
                pred_x = result_dict['final_pred_x'].cpu().numpy()[0]
                pred_y = result_dict['final_pred_y'].cpu().numpy()[0]
                pred_lengths = result_dict['lengths'].cpu().numpy()[0]
                pred_len = np.sum(pred_lengths) - 1  # sub EOS

                # Predicted polygon in image coordinates.
                vertices1 = [(pred_x[i] / scaleW + leftW,
                              pred_y[i] / scaleH + leftH)
                             for i in range(pred_len)]
                if len(vertices1) < 2:
                    # PIL raises TypeError for fewer than two coordinates.
                    continue
                if saved:
                    ImageDraw.Draw(image, 'RGBA').polygon(vertices1, color)

                # pred-mask
                img1 = Image.new('L', (W, H), 0)
                ImageDraw.Draw(img1).polygon(vertices1, outline=1, fill=1)
                pre_mask = np.array(img1)  # (H, W)

                # gt mask
                # NOTE: the original author flagged using the line number as
                # the instance id as incorrect; the mapping is kept unchanged.
                cur_anno = anno == number + 1

                # IoU
                inter = np.sum(np.logical_and(cur_anno, pre_mask))
                union = np.sum(np.logical_or(cur_anno, pre_mask))
                nu += inter
                de += union
                iouss.append(inter / (union * 1.0) if union != 0 else 0.)

        count += 1
        if count >= maxnum:
            break

    iouss.sort(reverse=True)
    print(iouss)

    print(np.mean(np.array(iouss)))

    true_iou = np.mean(np.array(iouss[:top_k]))

    return iou_score, nu, de, true_iou
Esempio n. 5
0
def getscore2(model_path, dataset='Rooftop', saved=False, maxnum=float('inf')):
    """Evaluate a ``PolygonModel`` on the Aerial Imagery Rooftop test split.

    Each ``.mat`` file in ``/data/duye/Aerial_Imagery/Rooftop/test`` supplies
    the ground-truth polygons (``data['gt'][0]``) for the ``JPG`` of the same
    name. Every polygon's bounding box is enlarged by 10 %, cropped, resized
    to 224x224 and passed through the model; the prediction is mapped back to
    image coordinates and scored with ``iou``.

    Instances are skipped when the ground truth has fewer than 3 vertices,
    when either crop side is <= 20 or >= 150 pixels, or when the prediction
    has fewer than 3 vertices.

    Args:
        model_path: State dict to load, or ``None`` to keep the initial weights.
        dataset: Accepted for compatibility; not used by this function.
        saved: When true, draw predictions and ground truth and save both
            images under ``/data/duye/save_dir/``.
        maxnum: Stop after this many ``.mat`` files.

    Returns:
        Tuple ``(iou_score, nu, de, true_iou)``: ``iou_score`` is always 0,
        ``nu`` / ``de`` are the second and third values returned by ``iou``
        summed over all instances, ``true_iou`` is the mean per-instance IoU.
    """
    model = PolygonModel(predict_delta=True).to(devices)
    if model_path is not None:
        model.load_state_dict(torch.load(model_path))
        print('Model loaded!')
    # set to eval
    model.eval()
    iou_score = 0.  # never updated; kept so the return tuple keeps its shape
    nu = 0.  # Intersection
    de = 0.  # Union
    count = 0
    files = glob.glob('/data/duye/Aerial_Imagery/Rooftop/test/*.mat')  # every .mat file
    iouss = []
    for idx, f in enumerate(files):
        data = scio.loadmat(f)
        # The image shares the .mat file's name.
        img_f = f[:-3] + 'JPG'
        with Image.open(img_f) as raw:
            image = raw.convert('RGB')
        img_gt = image.copy()  # separate canvas for the ground-truth overlay
        I = np.array(image)
        W = image.width
        H = image.height
        instances = data['gt'][0]
        for instance_id in range(instances.shape[0]):
            # np.float was removed from NumPy; the builtin float is the same dtype.
            polygon = np.array(instances[instance_id], dtype=float)
            if len(polygon) < 3:
                continue
            # Bounding box of the polygon, enlarged by 10 % and clipped to the image.
            minW, minH = np.min(polygon, axis=0)
            maxW, maxH = np.max(polygon, axis=0)
            extendrate = 0.10
            extendW = (maxW - minW) * extendrate
            extendH = (maxH - minH) * extendrate
            leftW = int(np.maximum(minW - extendW, 0))
            leftH = int(np.maximum(minH - extendH, 0))
            rightW = int(np.minimum(maxW + extendW, W))
            rightH = int(np.minimum(maxH + extendH, H))
            objectW = rightW - leftW
            objectH = rightH - leftH

            # Filter out objects that are too large or too small.
            if objectW >= 150 or objectH >= 150:
                continue
            if objectW <= 20 or objectH <= 20:
                continue

            scaleH = 224.0 / float(objectH)
            scaleW = 224.0 / float(objectW)
            # Crop, resize to 224x224 and build a (1, C, H, W) float tensor in [0, 1].
            I_obj_img = Image.fromarray(I[leftH:rightH, leftW:rightW, :])
            I_obj_img = I_obj_img.resize((224, 224), Image.BILINEAR)
            I_obj_new = np.array(I_obj_img).transpose(2, 0, 1) / 255.0
            # Same device as the model instead of an unconditional .cuda().
            I_obj_tensor = torch.from_numpy(I_obj_new).unsqueeze(0).float().to(devices)

            color = tuple([np.random.randint(0, 255) for _ in range(3)] + [100])

            with torch.no_grad():
                pre_v2 = None
                pre_v1 = None
                result_dict = model(I_obj_tensor, pre_v2, pre_v1,
                                    mode='test',
                                    temperature=0.0)
            pred_x = result_dict['final_pred_x'].cpu().numpy()[0]
            pred_y = result_dict['final_pred_y'].cpu().numpy()[0]
            pred_lengths = result_dict['lengths'].cpu().numpy()[0]
            pred_len = np.sum(pred_lengths) - 1  # sub EOS

            # Predicted polygon in image coordinates.
            vertices1 = [(pred_x[i] / scaleW + leftW,
                          pred_y[i] / scaleH + leftH)
                         for i in range(pred_len)]
            if len(vertices1) < 3:
                continue

            # Ground-truth polygon as a list of (x, y) tuples.
            vertices2 = [(points[0], points[1]) for points in polygon]

            if saved:
                ImageDraw.Draw(image, 'RGBA').polygon(vertices1, color)
                ImageDraw.Draw(img_gt, 'RGBA').polygon(vertices2, color)

            # calculate IoU
            tmp, nu_cur, de_cur = iou(vertices1, vertices2, H, W)
            nu += nu_cur
            de += de_cur
            iouss.append(tmp)
        count += 1
        if saved:
            print('saving test result image...')
            save_result_dir = '/data/duye/save_dir/'
            image.save(save_result_dir + str(idx) + '_pred_rooftop.png', 'PNG')
            img_gt.save(save_result_dir + str(idx) + '_gt_rooftop.png', 'PNG')
        if count >= maxnum:
            break

    iouss.sort(reverse=True)
    true_iou = np.mean(np.array(iouss))
    print(iouss)
    return iou_score, nu, de, true_iou
def get_score_ADE20K(saved=False, maxnum=float('inf')):
    """Evaluate a fixed ``PolygonModel`` checkpoint on ADE20K validation masks.

    For every label PNG in ``/data/duye/ADE20K/val_new/label`` the matching
    ``img_<index>.txt`` gives the source image. The bounding box of the
    pixels equal to the label's maximum value is enlarged by 10 %, cropped,
    resized to 224x224 and passed through the model; the predicted polygon
    is rasterised and compared with the label mask.

    Objects whose mask sums to 400 or less are skipped before inference, as
    are empty crops and predictions with fewer than two vertices (PIL cannot
    rasterise them).

    Args:
        saved: When true, draw the prediction onto the in-memory image
            (this function does not write the image to disk).
        maxnum: Accepted for compatibility; not used by this function.

    Returns:
        None. The sorted IoU list, its length and its mean are printed.
    """
    model = PolygonModel(predict_delta=True).to(devices)
    pre = 'ResNext_Plus_RL2_retain_Epoch1-Step4000_ValIoU0.6316584628283326.pth'
    dirs = '/data/duye/pretrained_models/FPNRLtrain/' + pre
    model.load_state_dict(torch.load(dirs))
    model.eval()

    # Named ``ious`` so the list does not hide a module-level ``iou`` function.
    ious = []
    print('starting.....')
    img_PATH = '/data/duye/ADE20K/validation/'
    lbl_path = '/data/duye/ADE20K/val_new/label/*.png'
    for label_path in glob.glob(lbl_path):
        with Image.open(label_path) as label_img:
            label = np.array(label_img)  # (H, W)

        # Ground-truth mask: 255 is mapped to 1, other values are kept.
        gt_mask = label.copy()
        gt_mask[gt_mask == 255] = 1
        if np.sum(gt_mask) <= 20 * 20:
            # too small to evaluate; skip before running the model
            continue

        # The index is the third '_'-separated piece of the full path.
        label_index = label_path.split('_')[2]
        txt_file = '/data/duye/ADE20K/val_new/img/img_' + label_index + '.txt'
        with open(txt_file, "r") as f:
            img_path = f.readline().replace('\n', '')
        # Drops the first 36 characters of the stored path - presumably an
        # absolute prefix; verify against the txt files.
        img_path = img_PATH + img_path[36:]
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        W = img.width
        H = img.height

        # Bounding box of the object, enlarged by 10 % and clipped to the image.
        Hs, Ws = np.where(label == np.max(label))
        minH = np.min(Hs)
        maxH = np.max(Hs)
        minW = np.min(Ws)
        maxW = np.max(Ws)
        extendrate = 0.10
        extendW = int(round((maxW - minW) * extendrate))
        extendH = int(round((maxH - minH) * extendrate))
        leftW = np.maximum(minW - extendW, 0)
        leftH = np.maximum(minH - extendH, 0)
        rightW = np.minimum(maxW + extendW, W)
        rightH = np.minimum(maxH + extendH, H)
        objectW = rightW - leftW
        objectH = rightH - leftH
        if objectW <= 0 or objectH <= 0:
            # empty crop: nothing to resize and the scale would divide by zero
            continue
        scaleW = 224.0 / float(objectW)
        scaleH = 224.0 / float(objectH)

        # Crop, resize to 224x224 and build a (1, C, H, W) float tensor in [0, 1].
        I = np.array(img)
        I_obj_img = Image.fromarray(I[leftH:rightH, leftW:rightW, :])
        I_obj_img = I_obj_img.resize((224, 224), Image.BILINEAR)
        I_obj_new = np.array(I_obj_img).transpose(2, 0, 1) / 255.0
        # Same device as the model instead of an unconditional .cuda().
        I_obj_tensor = torch.from_numpy(I_obj_new).unsqueeze(0).float().to(devices)

        color = tuple([np.random.randint(0, 255) for _ in range(3)] + [100])

        with torch.no_grad():
            pre_v2 = None
            pre_v1 = None
            result_dict = model(I_obj_tensor,
                                pre_v2,
                                pre_v1,
                                mode='test',
                                temperature=0.0)

        # index 0: only one sample in the mini-batch here
        pred_x = result_dict['final_pred_x'].cpu().numpy()[0]
        pred_y = result_dict['final_pred_y'].cpu().numpy()[0]
        pred_lengths = result_dict['lengths'].cpu().numpy()[0]
        pred_len = np.sum(pred_lengths) - 1  # sub EOS

        # Predicted polygon in image coordinates.
        vertices1 = [(pred_x[i] / scaleW + leftW, pred_y[i] / scaleH + leftH)
                     for i in range(pred_len)]
        if len(vertices1) < 2:
            # PIL raises TypeError for fewer than two coordinates.
            continue

        img1 = Image.new('L', (W, H), 0)
        ImageDraw.Draw(img1).polygon(vertices1, outline=1, fill=1)
        pre_mask = np.array(img1)  # (H, W)

        if saved:
            ImageDraw.Draw(img, 'RGBA').polygon(vertices1, color)

        # IoU between the label mask and the rasterised prediction.
        nu = np.sum(np.logical_and(gt_mask, pre_mask))
        de = np.sum(np.logical_or(gt_mask, pre_mask))
        ious.append(nu / (de * 1.0) if de != 0 else 0.)

    ious.sort(reverse=True)

    print(ious)
    print(len(ious))

    print('IoU:', np.mean(np.array(ious)))
def _rl_vertices(pred_x, pred_y, count, scaleW, scaleH, leftW, leftH):
    """Rescale the first ``count`` predicted points into image coordinates.

    Each point is divided by the 224/object-size scale factors and shifted by
    the ``left`` offsets, mirroring how the ground-truth polygon is mapped.
    """
    return [(pred_x[j] / scaleW + leftW, pred_y[j] / scaleH + leftH)
            for j in range(count)]


def _rl_batch_ious(outdict_sample, outdict_greedy, batch, bs):
    """Compute per-sample IoU rewards for the sampled and greedy roll-outs.

    Args:
        outdict_sample: model output of the stochastic (sampled) roll-out.
        outdict_greedy: model output of the greedy baseline roll-out.
        batch: the training batch; ``batch[-1]`` is the WH dict,
            ``batch[-2]`` the ground-truth polygons and ``batch[7]`` the
            ground-truth vertex mask.
        bs: number of samples in the batch.

    Returns:
        ``(sampling_iou, greedy_iou)``: two float32 arrays of length ``bs``.
        Predictions with fewer than two vertices score 0.
    """
    WH = batch[-1]  # WH_dict
    left_WH = WH['left_WH']
    origion_WH = WH['origion_WH']
    object_WH = WH['object_WH']
    GT_polys = batch[-2].numpy()  # presumably (bs, 70, 2) -- TODO confirm
    GT_mask = batch[7]  # presumably (bs, 70) -- TODO confirm

    sampling_pred_x = outdict_sample['final_pred_x'].cpu().numpy()
    sampling_pred_y = outdict_sample['final_pred_y'].cpu().numpy()
    sampling_pred_len = outdict_sample['lengths'].cpu().numpy()
    greedy_pred_x = outdict_greedy['final_pred_x'].cpu().numpy()
    greedy_pred_y = outdict_greedy['final_pred_y'].cpu().numpy()
    greedy_pred_len = outdict_greedy['lengths'].cpu().numpy()

    sampling_iou = np.zeros(bs, dtype=np.float32)
    greedy_iou = np.zeros(bs, dtype=np.float32)

    for ii in range(bs):
        scaleW = 224.0 / float(object_WH[0][ii])
        scaleH = 224.0 / float(object_WH[1][ii])
        leftW = left_WH[0][ii]
        leftH = left_WH[1][ii]

        # Only the first `all_len` ground-truth vertices are valid.
        all_len = np.sum(GT_mask[ii].numpy())
        gt = [(vert[0] / scaleW + leftW, vert[1] / scaleH + leftH)
              for vert in GT_polys[ii][:all_len]]

        # `lengths` counts the EOS token, hence the `- 1`.
        sam = _rl_vertices(sampling_pred_x[ii], sampling_pred_y[ii],
                           sampling_pred_len[ii] - 1,
                           scaleW, scaleH, leftW, leftH)
        gre = _rl_vertices(greedy_pred_x[ii], greedy_pred_y[ii],
                           greedy_pred_len[ii] - 1,
                           scaleW, scaleH, leftW, leftH)

        # A polygon needs at least two vertices; otherwise the reward stays 0.
        if len(sam) >= 2:
            iou_sam, _, _ = iou(sam, gt, origion_WH[1][ii], origion_WH[0][ii])
            sampling_iou[ii] = iou_sam
        if len(gre) >= 2:
            iou_gre, _, _ = iou(gre, gt, origion_WH[1][ii], origion_WH[0][ii])
            greedy_iou[ii] = iou_gre

    return sampling_iou, greedy_iou


def _rl_update(model, optimizer, img, batch, logprob_key, sample_kwargs,
               greedy_kwargs, config):
    """Run one self-critical update and return scalar statistics.

    Args:
        model: the polygon model (sub-module train/eval modes must already
            be set by the caller).
        optimizer: optimizer holding the parameters trained in this stage.
        img: input image batch, already on the target device.
        batch: the raw training batch (used for ground truth and WH info).
        logprob_key: key of the log-probabilities in the sampled output.
        sample_kwargs: temperature kwargs for the sampled roll-out.
        greedy_kwargs: temperature kwargs for the greedy baseline roll-out.
        config: training config; optional ``'grad_clip'`` enables clipping.

    Returns:
        ``(loss_value, mean_sampling_iou, mean_greedy_iou)``.
    """
    outdict_sample = model(img, mode='train_rl', **sample_kwargs)
    # The greedy roll-out is only a reward baseline; no gradient is needed.
    with torch.no_grad():
        outdict_greedy = model(img, mode='train_rl', **greedy_kwargs)

    sampling_iou, greedy_iou = _rl_batch_ious(outdict_sample, outdict_greedy,
                                              batch, img.shape[0])

    loss = losses.self_critical_loss(
        outdict_sample[logprob_key], outdict_sample['lengths'],
        torch.from_numpy(sampling_iou).to(devices),
        torch.from_numpy(greedy_iou).to(devices))
    model.zero_grad()
    loss.backward()
    # Clipping must run after backward() (gradients exist only then) and
    # before the optimizer step.
    if 'grad_clip' in config:
        nn.utils.clip_grad_norm_(model.parameters(), config['grad_clip'])
    optimizer.step()

    # .item() detaches the value so the logging accumulators never hold on
    # to the autograd graph.
    return loss.item(), np.mean(sampling_iou), np.mean(greedy_iou)


def _rl_validate(model, val_loader):
    """Evaluate ``model`` in test mode over ``val_loader``.

    Returns:
        ``(val_IoU, less_than2)``: the list of per-instance IoUs and the
        number of predictions that had fewer than two vertices (scored 0).
    """
    val_IoU = []
    less_than2 = 0
    with torch.no_grad():
        for val_batch in val_loader:
            img = torch.as_tensor(val_batch[0], dtype=torch.float).to(devices)
            bs = img.shape[0]
            WH = val_batch[-1]  # WH_dict
            left_WH = WH['left_WH']
            origion_WH = WH['origion_WH']
            object_WH = WH['object_WH']
            val_target = val_batch[-2].numpy()
            val_mask_final = val_batch[7]
            out_dict = model(img, mode='test')
            pred_x = out_dict['final_pred_x'].cpu().numpy()
            pred_y = out_dict['final_pred_y'].cpu().numpy()
            pred_len = out_dict['lengths'].cpu().numpy()  # predicted lengths

            for ii in range(bs):
                scaleW = 224.0 / float(object_WH[0][ii])
                scaleH = 224.0 / float(object_WH[1][ii])
                leftW = left_WH[0][ii]
                leftH = left_WH[1][ii]

                all_len = np.sum(val_mask_final[ii].numpy())
                vertices2 = [(vert[0] / scaleW + leftW,
                              vert[1] / scaleH + leftH)
                             for vert in val_target[ii][:all_len]]

                pred_len_b = pred_len[ii] - 1  # drop EOS
                if pred_len_b < 2:
                    val_IoU.append(0.)
                    less_than2 += 1
                    continue
                vertices1 = _rl_vertices(pred_x[ii], pred_y[ii], pred_len_b,
                                         scaleW, scaleH, leftW, leftH)

                _, nu_cur, de_cur = iou(vertices1, vertices2,
                                        origion_WH[1][ii], origion_WH[0][ii])
                iou_cur = nu_cur * 1.0 / de_cur if de_cur != 0 else 0
                val_IoU.append(iou_cur)
    return val_IoU, less_than2


def train(config, load_resnet50=False, pre_trained=None, cur_epochs=0):
    """Fine-tune both decoders of a ``PolygonModel`` with self-critical RL.

    Every training batch triggers two updates: first the main decoder
    (``delta_model`` in eval mode, sampled with ``config['temperature']``),
    then the delta decoder (``decoder`` in eval mode, sampled with
    ``config['temperature2']``). Parameters whose name contains ``'encoder'``
    are frozen. Every ``config['val_every']`` steps the model is validated
    and a checkpoint is written.

    Args:
        config: dict with keys ``'batch_size'``, ``'lr'``, ``'epoch'``,
            ``'lr_decay'`` (``(step_size, gamma)`` for ``StepLR``),
            ``'temperature'``, ``'temperature2'``, ``'val_every'`` and,
            optionally, ``'grad_clip'`` (max gradient norm).
        load_resnet50: forwarded to ``PolygonModel`` as
            ``load_predtrained_resnet50``.
        pre_trained: optional path of a state dict to load before training.
        cur_epochs: index of the first epoch to run (for resuming).
    """
    batch_size = config['batch_size']
    lr = config['lr']
    epochs = config['epoch']
    train_dataloader = loadData('train', 16, 71, batch_size)
    val_loader = loadData('val', 16, 71, batch_size, shuffle=False)
    model = PolygonModel(load_predtrained_resnet50=load_resnet50,
                         predict_delta=True).to(devices)

    if pre_trained is not None:
        model.load_state_dict(torch.load(pre_trained))
        print('loaded pretrained polygon net!')

    # The encoders are frozen: keep them in eval mode and stop their gradients.
    model.encoder.eval()
    model.delta_encoder.eval()

    for n, p in model.named_parameters():
        if 'encoder' in n:
            print('Not train:', n)
            p.requires_grad = False

    print('No weight decay in RL training')

    # Split the trainable parameters between the two stages by name.
    train_params1 = []  # main decoder stage
    train_params2 = []  # delta decoder stage
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if 'delta' in n:
            train_params2.append(p)
        else:
            train_params1.append(p)

    # Adam optimizers, one per stage.
    optimizer1 = optim.Adam(train_params1, lr=lr, amsgrad=False)
    optimizer2 = optim.Adam(train_params2, lr=lr, amsgrad=False)

    # The schedulers must wrap the optimizers that are actually stepped;
    # otherwise the configured learning-rate decay never takes effect.
    scheduler1 = optim.lr_scheduler.StepLR(optimizer1,
                                           step_size=config['lr_decay'][0],
                                           gamma=config['lr_decay'][1])
    scheduler2 = optim.lr_scheduler.StepLR(optimizer2,
                                           step_size=config['lr_decay'][0],
                                           gamma=config['lr_decay'][1])

    print('Total Epochs:', epochs)
    for it in range(cur_epochs, epochs):
        # Running sums for the periodic log lines (reset every 20 steps).
        accum = defaultdict(float)
        accum2 = defaultdict(float)
        model.delta_model.train()
        model.decoder.train()
        for index, batch in enumerate(train_dataloader):
            img = torch.as_tensor(batch[0], dtype=torch.float).to(devices)

            # Step 1: train the main decoder, delta model fixed.
            model.delta_model.eval()
            model.decoder.train()
            loss_value, sam_iou, gre_iou = _rl_update(
                model, optimizer1, img, batch, 'log_probs',
                {'temperature': config['temperature'], 'temperature2': 0.0},
                {'temperature': 0.0},
                config)

            accum['loss_total'] += loss_value
            accum['sampling_iou'] += sam_iou
            accum['greedy_iou'] += gre_iou
            # Log the averaged statistics.
            if (index + 1) % 20 == 0:
                print('Epoch {} - Step {}'.format(it + 1, index + 1))
                print(
                    '     Main Decoder: loss_total {}, sampling_iou {}, greedy_iou {}'
                    .format(accum['loss_total'] / 20,
                            accum['sampling_iou'] / 20,
                            accum['greedy_iou'] / 20))
                accum = defaultdict(float)

            # Step 2: train the delta decoder, main decoder fixed.
            model.decoder.eval()
            model.delta_model.train()
            loss_value, sam_iou, gre_iou = _rl_update(
                model, optimizer2, img, batch, 'delta_logprob',
                {'temperature': 0.0, 'temperature2': config['temperature2']},
                {'temperature': 0.0, 'temperature2': 0.0},
                config)

            accum2['loss_total'] += loss_value
            accum2['sampling_iou'] += sam_iou
            accum2['greedy_iou'] += gre_iou
            # Log the averaged statistics.
            if (index + 1) % 20 == 0:
                print(
                    '     Second Decoder: loss_total {}, sampling_iou {}, greedy_iou {}'
                    .format(accum2['loss_total'] / 20,
                            accum2['sampling_iou'] / 20,
                            accum2['greedy_iou'] / 20))
                accum2 = defaultdict(float)

            if (index + 1) % config['val_every'] == 0:
                # Validation
                model.decoder.eval()
                model.delta_model.eval()
                val_IoU, less_than2 = _rl_validate(model, val_loader)

                val_iou_data = np.mean(np.array(val_IoU))
                print('Validation After Epoch {} - step {}'.format(
                    str(it + 1), str(index + 1)))
                print('           IoU      on validation set: ', val_iou_data)
                print('less than 2: ', less_than2)
                print('Saving training parameters after this epoch:')
                torch.save(
                    model.state_dict(),
                    '/data/duye/pretrained_models/FPNRLtrain/ResNext_Plus_RL2_retain_Epoch{}-Step{}_ValIoU{}.pth'
                    .format(str(it + 1), str(index + 1), str(val_iou_data)))
                # Back to training mode (important).
                model.decoder.train()
                model.delta_model.train()

        scheduler1.step()
        scheduler2.step()
        print('Epoch {} Completed!'.format(str(it + 1)))
Esempio n. 8
0
class OnLineTrainer:
    """Online self-critical RL fine-tuning of a ``PolygonModel``.

    Batches come from ``loadAde20K``; each batch carries four samples under
    the keys ``'s1'`` .. ``'s4'``. For every batch the main decoder is
    updated first and the delta decoder second, both with a self-critical
    loss whose reward is the mask IoU against the ground truth. Parameters
    whose name contains ``'encoder'`` are frozen.

    Args:
        num_workers: stored on the instance; not used by this class.
        update_every: stored on the instance; not used by this class.
        save_every: a checkpoint is written every ``save_every`` batches.
        t1: sampling temperature of the main decoder (stage 1).
        t2: sampling temperature of the delta decoder (stage 2).
        pre: optional path of a state dict to load into the model.
    """

    def __init__(self,
                 num_workers=8,
                 update_every=8,
                 save_every=100,
                 t1=0.1,
                 t2=0.1,
                 pre=None):
        self.num_workers = num_workers
        self.update_every = update_every
        self.save_every = save_every
        self.t1 = t1
        self.t2 = t2
        self.max_epoch = 2
        self.model = PolygonModel(predict_delta=True).to(devices)
        if pre is not None:
            self.model.load_state_dict(torch.load(pre))
        self.dataloader = loadAde20K(batch_size=4)
        # The encoders are frozen: eval mode and no gradients.
        self.model.encoder.eval()
        self.model.delta_encoder.eval()
        for n, p in self.model.named_parameters():
            if 'encoder' in n:
                p.requires_grad = False
        self.train_params = [
            p for p in self.model.parameters() if p.requires_grad
        ]
        self.optimizer = optim.Adam(self.train_params, lr=2e-5, amsgrad=False)

    # TODO: 1. Train longer on Rooftop to raise the accuracy (currently a
    #          checkpoint is saved every 50 updates).
    #  2. The ADE20K train split is loaded incorrectly: gt dimensions do
    #     not match.
    #  3. Cityscapes: save every 20 updates and check whether anything is off.

    @staticmethod
    def _to_image_coords(xs, ys, count, scale_w, scale_h, left_w, left_h):
        """Rescale the first ``count`` points and shift them by the offsets."""
        return [(xs[j] / scale_w + left_w, ys[j] / scale_h + left_h)
                for j in range(count)]

    @staticmethod
    def _mask_iou(vertices, gt, size_wh):
        """Rasterize ``vertices`` and return its IoU with the mask ``gt``.

        ``size_wh`` is ``(width, height)`` of the canvas. Polygons with
        fewer than two vertices, and empty unions, score 0.
        """
        if len(vertices) < 2:
            return 0.
        canvas = Image.new('L', (size_wh[0], size_wh[1]), 0)
        ImageDraw.Draw(canvas).polygon(vertices, outline=1, fill=1)
        mask = np.array(canvas)  # (h, w)
        nu = np.sum(np.logical_and(mask, gt))  # pixels set in both
        de = np.sum(np.logical_or(mask, gt))  # pixels set in either
        return nu * 1.0 / de if de != 0 else 0.

    def _reward_ious(self, outdict_sample, outdict_greedy, samples):
        """Return ``(sampling_iou, greedy_iou)`` float32 arrays for a batch.

        ``samples`` is the list of per-sample records; ``sample[1]`` is the
        ground-truth mask and ``sample[-1]`` the WH dict.
        """
        sampling_pred_x = outdict_sample['final_pred_x'].cpu().numpy()
        sampling_pred_y = outdict_sample['final_pred_y'].cpu().numpy()
        sampling_pred_len = outdict_sample['lengths'].cpu().numpy()
        greedy_pred_x = outdict_greedy['final_pred_x'].cpu().numpy()
        greedy_pred_y = outdict_greedy['final_pred_y'].cpu().numpy()
        greedy_pred_len = outdict_greedy['lengths'].cpu().numpy()

        sampling_iou = np.zeros(len(samples), dtype=np.float32)
        greedy_iou = np.zeros(len(samples), dtype=np.float32)

        for ii, sample in enumerate(samples):
            gt = sample[1]  # (H, W)
            # WH = {'left_WH': ..., 'object_WH': ..., 'origion_WH': ...}
            WH = sample[-1]
            object_WH = WH['object_WH']
            left_WH = WH['left_WH']
            origion_WH = WH['origion_WH']
            scaleW = 224.0 / float(object_WH[0])
            scaleH = 224.0 / float(object_WH[1])
            leftW = left_WH[0]
            leftH = left_WH[1]

            # `lengths` counts the EOS token, hence the `- 1`.
            sam = self._to_image_coords(sampling_pred_x[ii],
                                        sampling_pred_y[ii],
                                        sampling_pred_len[ii] - 1,
                                        scaleW, scaleH, leftW, leftH)
            gre = self._to_image_coords(greedy_pred_x[ii],
                                        greedy_pred_y[ii],
                                        greedy_pred_len[ii] - 1,
                                        scaleW, scaleH, leftW, leftH)

            sampling_iou[ii] = self._mask_iou(sam, gt, origion_WH)
            greedy_iou[ii] = self._mask_iou(gre, gt, origion_WH)

        return sampling_iou, greedy_iou

    def _rl_update(self, img, samples, sample_kwargs, greedy_kwargs):
        """Run one self-critical update; return ``(loss, mean greedy IoU)``.

        ``sample_kwargs`` / ``greedy_kwargs`` are the temperature keyword
        arguments of the sampled and the greedy (baseline) roll-outs.
        """
        outdict_sample = self.model(img, mode='train_rl', **sample_kwargs)
        # The greedy roll-out is only the reward baseline.
        with torch.no_grad():
            outdict_greedy = self.model(img, mode='train_rl', **greedy_kwargs)

        sampling_iou, greedy_iou = self._reward_ious(outdict_sample,
                                                     outdict_greedy, samples)

        # NOTE(review): both stages read 'log_probs' here; if the model also
        # exposes a separate 'delta_logprob' entry for the delta decoder,
        # confirm which one stage 2 is supposed to use.
        loss = losses.self_critical_loss(
            outdict_sample['log_probs'], outdict_sample['lengths'],
            torch.from_numpy(sampling_iou).to(devices),
            torch.from_numpy(greedy_iou).to(devices))
        self.model.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 40)
        self.optimizer.step()
        # .item() so that logging never keeps the autograd graph alive.
        return loss.item(), np.mean(greedy_iou)

    def train(self):
        """Run ``max_epoch`` passes over the loader, two updates per batch."""
        # NOTE(review): kept as a float so checkpoint names stay unchanged
        # (e.g. '100.0.pth'); switch to int if nothing depends on them.
        global_step = 0.
        for _epoch in range(self.max_epoch):
            for step, batch in enumerate(self.dataloader):
                global_step += 1
                b = [batch['s1'], batch['s2'], batch['s3'], batch['s4']]

                # Each sample[0] is one image; stack them into a mini-batch.
                img = torch.stack([sample[0] for sample in b],
                                  dim=0).to(devices)

                # Step 1: train the main decoder, delta model fixed.
                self.model.delta_model.eval()
                self.model.decoder.train()
                loss_value, model_iou = self._rl_update(
                    img, b,
                    {'temperature': self.t1, 'temperature2': 0.0},
                    {'temperature': 0.0})
                print(
                    'Update {}, RL training of Main decoder, loss {}, model IoU {}'
                    .format(step + 1, loss_value, model_iou))

                # Step 2: train the delta decoder, main decoder fixed.
                self.model.decoder.eval()
                self.model.delta_model.train()
                loss_value, model_iou = self._rl_update(
                    img, b,
                    {'temperature': 0.0, 'temperature2': self.t2},
                    {'temperature': 0.0, 'temperature2': 0.0})
                print(
                    'Update {}, RL training of Second decoder, loss {}, model IoU {}'
                    .format(step + 1, loss_value, model_iou))

                if global_step % self.save_every == 0:
                    print('Saving training parameters after Updating...')
                    save_dir = '/data/duye/pretrained_models/OnLineTraining_ADE20K/'
                    os.makedirs(save_dir, exist_ok=True)
                    torch.save(self.model.state_dict(),
                               save_dir + str(global_step) + '.pth')
    dataloader = DataLoader(ss, batch_size=batch_size, shuffle=shuffle,
                            drop_last=False)
    print('DataLoader complete!', dataloader)
    return dataloader

# Evaluation script: score a trained model on the loader from loadssTEM.
# (The argparse description below reads "test the generalization score on
# ssTEM".)
# Device string used for the model below; falls back to CPU without CUDA.
devices = 'cuda' if torch.cuda.is_available() else 'cpu'
if __name__ == '__main__':
    parse = argparse.ArgumentParser(description='测试在ssTEM上的泛化得分')
    parse.add_argument('-p', '--pretrained', type=str, default=None)
    args = parse.parse_args()
    pre = args.pretrained
    model = PolygonModel(predict_delta=True).to(devices)
    # NOTE(review): this assignment overwrites the -p/--pretrained value read
    # above, so the command-line argument currently has no effect.
    pre = 'ResNext_Plus_RL2_retain_Epoch1-Step4000_ValIoU0.6316584628283326.pth'
    dirs = '/data/duye/pretrained_models/FPNRLtrain/' + pre
    model.load_state_dict(torch.load(dirs))
    # Inference only: switch the whole model to eval mode.
    model.eval()
    loader = loadssTEM(batch_size=8)

    # NOTE(review): if this module also defines or imports a function named
    # `iou`, this list shadows it from here on -- confirm that is intended.
    iou = []
    for index, batch in enumerate(loader):
        print('index: ', index)
        # Batch layout as read here: [0] images, [1] ground truth,
        # [-1] dict with the 'left_WH' / 'origion_WH' / 'object_WH' entries.
        img = batch[0]
        WH = batch[-1]  # WH_dict
        left_WH = WH['left_WH']
        origion_WH = WH['origion_WH']
        object_WH = WH['object_WH']
        gt = batch[1]

        # Number of samples in this batch (first dimension of the images).
        bs = img.shape[0]
        with torch.no_grad():
def _labels_to_vertices(labels, scale_w, scale_h, left_w, left_h,
                        grid_size=28, cell_size=8.0):
    """Decode a sequence of grid class indices into (x, y) vertices.

    A label ``k`` addresses column ``k % grid_size`` and row
    ``k // grid_size``.  The centre of that cell (``cell_size`` units wide)
    is divided by ``scale_w`` / ``scale_h`` and shifted by ``left_w`` /
    ``left_h``.  Decoding stops at the first end-of-sequence label, which is
    ``grid_size * grid_size``.

    Returns:
        list of ``(x, y)`` tuples, possibly empty.
    """
    eos_label = grid_size * grid_size
    half_cell = cell_size / 2
    vertices = []
    for label in labels:
        label = int(label)
        if label == eos_label:
            break
        row, col = divmod(label, grid_size)
        vertices.append(((col * cell_size + half_cell) / scale_w + left_w,
                         (row * cell_size + half_cell) / scale_h + left_h))
    return vertices


def _evaluate_iou(model, val_loader):
    """Decode every validation batch in test mode and score it by IoU.

    The caller is responsible for switching ``model`` to eval mode and back.

    Returns:
        ``(ious, less_than2)`` where ``ious`` holds one IoU per sample and
        ``less_than2`` counts predictions with fewer than two vertices
        (those are scored as 0).
    """
    ious = []
    less_than2 = 0
    with torch.no_grad():
        for val_batch in val_loader:
            img = torch.as_tensor(val_batch[0], dtype=torch.float).cuda()

            wh = val_batch[-1]  # WH_dict
            left_wh = wh['left_WH']
            origin_wh = wh['origion_WH']
            object_wh = wh['object_WH']

            out_dict = model(img, mode='test')
            pred_labels = out_dict['pred_polys'].cpu().numpy()  # (bs, seq_len)
            target_labels = val_batch[4].numpy()  # (bs, seq_len)

            for ii in range(img.shape[0]):
                # Factors mapping the 224-unit network frame back to the
                # object's own width/height, plus the object's offset.
                scale_w = 224.0 / object_wh[0][ii]
                scale_h = 224.0 / object_wh[1][ii]
                left_w = left_wh[0][ii]
                left_h = left_wh[1][ii]

                pred_vertices = _labels_to_vertices(
                    pred_labels[ii], scale_w, scale_h, left_w, left_h)
                if len(pred_vertices) < 2:
                    less_than2 += 1
                    ious.append(0.)
                    continue
                target_vertices = _labels_to_vertices(
                    target_labels[ii], scale_w, scale_h, left_w, left_h)

                _, nu_cur, de_cur = iou(pred_vertices, target_vertices,
                                        origin_wh[1][ii],
                                        origin_wh[0][ii])  # (H, W)
                ious.append(nu_cur * 1.0 / de_cur if de_cur != 0 else 0)
    return ious, less_than2


def train(config, load_resnet50=False, pre_trained=None, cur_epochs=0):
    """Train PolygonModel with a masked, smoothed cross-entropy loss.

    Args:
        config: mapping with the keys ``'batch_size'``, ``'lr'``,
            ``'epoch'``, ``'weight_decay'``, ``'lr_decay'`` (indexable as
            ``(step_size, gamma)`` for ``StepLR``) and ``'val_every'``
            (validate every N training steps).  If ``'grad_clip'`` is
            present, gradients are clipped to that max norm before each
            optimizer step.
        load_resnet50: forwarded to ``PolygonModel`` as
            ``load_predtrained_resnet50``.
        pre_trained: optional path of a state dict to load into the model.
        cur_epochs: index of the first epoch to run (epochs run over
            ``range(cur_epochs, config['epoch'])``).

    Side effects:
        Prints progress, and from the 6th epoch on (``it > 4``) saves the
        model state dict after every validation pass.

    Batch layout as used here (assumed from the indexing -- confirm against
    ``loadData``): ``[0]`` image, ``[2]``/``[3]`` previous vertices,
    ``[4]`` target class indices, ``[6]`` end-of-polygon mask, ``[-1]``
    width/height dict.
    """
    batch_size = config['batch_size']
    lr = config['lr']
    epochs = config['epoch']
    n_classes = 28 * 28 + 1  # 28x28 grid cells plus one end-of-sequence class
    log_every = 20

    train_dataloader = loadData('train', 16, 71, batch_size)
    val_loader = loadData('val', 16, 71, batch_size, shuffle=False)
    model = PolygonModel(load_predtrained_resnet50=load_resnet50,
                         predict_delta=False).cuda()
    # checkpoint
    if pre_trained is not None:
        model.load_state_dict(torch.load(pre_trained))
        print('loaded pretrained polygon net!')

    # Regularisation (the original paper used none): weight decay is applied
    # to everything except batch-norm, convLSTM and bias parameters.
    no_wd = []
    wd = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            # No optimization for frozen params
            continue
        if 'bn' in name or 'convLSTM' in name or 'bias' in name:
            no_wd.append(param)
        else:
            wd.append(param)

    optimizer = optim.Adam(
        [{'params': no_wd, 'weight_decay': 0.0}, {'params': wd}],
        lr=lr,
        weight_decay=config['weight_decay'],
        amsgrad=False)

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=config['lr_decay'][0],
                                          gamma=config['lr_decay'][1])

    print('Total Epochs:', epochs)
    for it in range(cur_epochs, epochs):
        running_loss = 0.0
        for index, batch in enumerate(train_dataloader):
            img = torch.as_tensor(batch[0], dtype=torch.float).cuda()
            bs = img.shape[0]
            pre_v2 = torch.as_tensor(batch[2], dtype=torch.float).cuda()
            pre_v1 = torch.as_tensor(batch[3], dtype=torch.float).cuda()
            outdict = model(img, pre_v2, pre_v1, mode='train_ce')

            # NOTE: the raw logits go straight into log_softmax below; an
            # earlier training run applied log_softmax twice by mistake.
            logits = outdict['logits'].reshape(-1, n_classes)

            # Smoothed targets; `int` replaces the removed `np.int` alias.
            target = dt_targets_from_class(
                np.array(batch[4], dtype=int), 28, 2)
            target = torch.from_numpy(target).cuda().reshape(-1, n_classes)

            # Cross-entropy per time step, shape (bs*seq_len,).
            step_loss = torch.sum(
                -target * torch.nn.functional.log_softmax(logits, dim=1),
                dim=1)

            # Mask out every step after the end point of each polygon.
            mask_final = torch.as_tensor(batch[6]).cuda().reshape(-1)
            mask_final = mask_final.type_as(step_loss)
            seq_loss = (step_loss * mask_final).reshape(bs, -1).sum(dim=1)
            # clamp avoids 0/0 = NaN for a sequence with an all-zero mask.
            real_pointnum = mask_final.reshape(bs, -1).sum(dim=1).clamp(min=1)
            loss = torch.mean(seq_loss / real_pointnum)  # mean over batch

            # TODO: this loss suits train_ce; for train_rl the loss could be
            # rewritten from the conditional probabilities.
            model.zero_grad()
            loss.backward()

            # Clipping has to run after backward() (so gradients exist) and
            # before step(); the key checked must match the key read.
            if 'grad_clip' in config:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         config['grad_clip'])

            optimizer.step()
            # .item() keeps a plain float instead of a graph-attached tensor.
            running_loss += loss.item()

            # Print the running mean loss.
            if (index + 1) % log_every == 0:
                print('Epoch {} - Step {}, loss_total {}'.format(
                    it + 1, index, running_loss / log_every))
                running_loss = 0.0

            # Validate every config['val_every'] steps.
            if (index + 1) % config['val_every'] == 0:
                model.eval()
                val_IoU, less_than2 = _evaluate_iou(model, val_loader)

                val_iou_data = np.mean(np.array(val_IoU)) if val_IoU else 0.0
                print('Validation After Epoch {} - step {}'.format(
                    str(it + 1), str(index + 1)))
                print('           IoU      on validation set: ', val_iou_data)
                print('less than 2: ', less_than2)
                if it > 4:  # checkpoints are only kept from the 6th epoch on
                    print('Saving training parameters after this epoch:')
                    torch.save(
                        model.state_dict(),
                        '/data/duye/pretrained_models/ResNext50_FPN_LSTM_Epoch{}-Step{}_ValIoU{}.pth'
                        .format(str(it + 1), str(index + 1),
                                str(val_iou_data)))
                # Back to training mode -- important.
                model.train()

        # Learning-rate decay, once per epoch.
        scheduler.step()
        print()
        print('Epoch {} Completed!'.format(str(it + 1)))
        print()