class OnLineTrainer:
    """Online self-critical RL fine-tuning of a ``PolygonModel``'s two decoders.

    Every parameter whose name contains ``'encoder'`` is frozen. For each
    batch the main decoder is updated first and the delta decoder second.
    Each update decodes one sampled polygon and one greedy polygon per image
    and feeds their IoU against the ground truth to
    ``losses.self_critical_loss``.
    """

    # Gradient-norm ceiling applied between backward() and optimizer.step().
    _MAX_GRAD_NORM = 40

    def __init__(self, num_workers=8, update_every=8, save_every=10, t1=0.1, t2=0.2, pre=None):
        """Build the model, freeze its encoders and create the optimizer.

        Args:
            num_workers: Passed to ``loadData`` as the batch size.
            update_every: Stored on the instance; not read by this class.
            save_every: Checkpoint period, in update steps.
            t1: Sampling temperature used when training the main decoder.
            t2: Sampling temperature used when training the delta decoder.
            pre: Optional path of a state dict to load into the model.
        """
        self.num_workers = num_workers
        self.update_every = update_every
        self.save_every = save_every
        self.t1 = t1
        self.t2 = t2
        self.model = PolygonModel(predict_delta=True).to(devices)
        if pre is not None:
            self.model.load_state_dict(torch.load(pre))
        self.dataloader = loadData(data_num=16,
                                   batch_size=self.num_workers,
                                   len_s=71,
                                   path='val',
                                   shuffle=False)
        self.model.encoder.eval()
        self.model.delta_encoder.eval()
        for name, param in self.model.named_parameters():
            if 'encoder' in name:
                param.requires_grad = False
        # Optimize only what is still trainable. The previous filter kept the
        # parameters with requires_grad == False, i.e. exactly the frozen
        # encoder weights, so no decoder weight was ever updated.
        self.train_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(self.train_params,
                                    lr=2e-6,
                                    amsgrad=False)

    def _polygon_ious(self, batch, bs, outdict_sample, outdict_greedy):
        """Return ``(sampling_iou, greedy_iou)`` as float32 arrays of length ``bs``.

        Predicted and ground-truth vertices are divided by ``224 / object_WH``
        and shifted by ``left_WH`` before ``iou`` is evaluated on a canvas of
        size ``origion_WH``. A prediction with fewer than two vertices keeps
        an IoU of 0.
        """
        wh = batch[-1]
        left_wh = wh['left_WH']
        origin_wh = wh['origion_WH']
        object_wh = wh['object_WH']
        gt_polys = batch[-2].numpy()  # original author's note: (bs, 70, 2)
        gt_mask = batch[7]  # original author's note: (bs, 70)

        decoded = [(out['final_pred_x'].cpu().numpy(),
                    out['final_pred_y'].cpu().numpy(),
                    out['lengths'].cpu().numpy())
                   for out in (outdict_sample, outdict_greedy)]
        rewards = [np.zeros(bs, dtype=np.float32) for _ in decoded]

        for ii in range(bs):
            scale_w = 224.0 / object_wh[0][ii]
            scale_h = 224.0 / object_wh[1][ii]
            left_w = left_wh[0][ii]
            left_h = left_wh[1][ii]

            # The mask sum gives the number of valid ground-truth vertices.
            n_gt = np.sum(gt_mask[ii].numpy())
            gt = [(vert[0] / scale_w + left_w, vert[1] / scale_h + left_h)
                  for vert in gt_polys[ii][:n_gt]]

            for reward, (xs, ys, lens) in zip(rewards, decoded):
                # The last predicted position is dropped (lengths - 1).
                poly = [(xs[ii][j] / scale_w + left_w,
                         ys[ii][j] / scale_h + left_h)
                        for j in range(lens[ii] - 1)]
                if len(poly) < 2:
                    continue  # reward stays at its initial 0
                value, _, _ = iou(poly, gt, origin_wh[1][ii], origin_wh[0][ii])
                reward[ii] = value
        return rewards[0], rewards[1]

    def _rl_update(self, img, batch, logprob_key, sample_kwargs, greedy_kwargs):
        """Run one self-critical update and return ``(loss, sampling_iou, greedy_iou)``.

        Args:
            img: Input image batch already on the training device.
            batch: The raw dataloader batch (ground truth and size metadata).
            logprob_key: Key of the sampled output holding the log-probs
                that enter the loss.
            sample_kwargs: Temperature keywords for the sampled forward pass.
            greedy_kwargs: Temperature keywords for the greedy baseline pass.

        Returns:
            The loss as a Python float and the two per-sample IoU arrays.
        """
        bs = img.shape[0]
        outdict_sample = self.model(img, mode='train_rl', **sample_kwargs)
        # The greedy baseline needs no gradient.
        with torch.no_grad():
            outdict_greedy = self.model(img, mode='train_rl', **greedy_kwargs)

        sampling_iou, greedy_iou = self._polygon_ious(batch, bs, outdict_sample,
                                                      outdict_greedy)
        loss = losses.self_critical_loss(outdict_sample[logprob_key],
                                         outdict_sample['lengths'],
                                         torch.from_numpy(sampling_iou).to(devices),
                                         torch.from_numpy(greedy_iou).to(devices))
        self.model.zero_grad()
        loss.backward()
        # Clip after backward(): clipping right after zero_grad() had no
        # gradients to act on.
        nn.utils.clip_grad_norm_(self.model.parameters(), self._MAX_GRAD_NORM)
        self.optimizer.step()
        # .item() detaches the value so no autograd graph is kept alive.
        return loss.item(), sampling_iou, greedy_iou

    def train(self):
        """Iterate once over the dataloader, updating both decoders per batch.

        A checkpoint is written every ``save_every`` steps.
        """
        # The accumulators are reset after every update, so the printed
        # numbers are per-step values.
        accum = defaultdict(float)
        accum2 = defaultdict(float)
        for step, batch in enumerate(self.dataloader):
            img = torch.tensor(batch[0], dtype=torch.float).cuda()

            # Step 1: train the main decoder, delta model in eval mode.
            self.model.delta_model.eval()
            self.model.decoder.train()
            loss, sampling_iou, greedy_iou = self._rl_update(
                img, batch, 'log_probs',
                sample_kwargs={'temperature': self.t1, 'temperature2': 0.0},
                greedy_kwargs={'temperature': 0.0})
            accum['loss_total'] += loss
            accum['sampling_iou'] += np.mean(sampling_iou)
            accum['greedy_iou'] += np.mean(greedy_iou)
            print('Update {}, RL training of Main decoder, loss {}, model IoU {}'.format(step + 1,
                                                                             accum['loss_total'],
                                                                             accum['greedy_iou']))
            accum = defaultdict(float)

            # Step 2: train the delta model, main decoder in eval mode.
            self.model.decoder.eval()
            self.model.delta_model.train()
            loss, sampling_iou, greedy_iou = self._rl_update(
                img, batch, 'delta_logprob',
                sample_kwargs={'temperature': 0.0, 'temperature2': self.t2},
                greedy_kwargs={'temperature': 0.0, 'temperature2': 0.0})
            accum2['loss_total'] += loss
            accum2['sampling_iou'] += np.mean(sampling_iou)
            accum2['greedy_iou'] += np.mean(greedy_iou)
            print('Update {}, RL training of Second decoder, loss {}, model IoU {}'.format(step + 1,
                                                                             accum2['loss_total'],
                                                                             accum2['greedy_iou']))
            accum2 = defaultdict(float)

            if (step + 1) % self.save_every == 0:
                print('Saving training parameters after Updating...')
                save_dir = '/data/duye/pretrained_models/OnLineTraining/'
                os.makedirs(save_dir, exist_ok=True)
                torch.save(self.model.state_dict(), save_dir + str(step+1) + '.pth')
def _rl_batch_ious(batch, bs, outdict_sample, outdict_greedy):
    """Return ``(sampling_iou, greedy_iou)`` as float32 arrays of length ``bs``.

    Vertices are divided by ``224 / object_WH`` and shifted by ``left_WH``
    before ``iou`` is evaluated on a canvas of size ``origion_WH``. A
    prediction with fewer than two vertices keeps an IoU of 0.
    """
    wh = batch[-1]
    left_wh = wh['left_WH']
    origin_wh = wh['origion_WH']
    object_wh = wh['object_WH']
    gt_polys = batch[-2].numpy()  # original author's note: (bs, 70, 2)
    gt_mask = batch[7]  # original author's note: (bs, 70)

    decoded = [(out['final_pred_x'].cpu().numpy(),
                out['final_pred_y'].cpu().numpy(),
                out['lengths'].cpu().numpy())
               for out in (outdict_sample, outdict_greedy)]
    rewards = [np.zeros(bs, dtype=np.float32) for _ in decoded]

    for ii in range(bs):
        # float() is applied uniformly; previously only the first RL step
        # and the validation loop converted the object size.
        scale_w = 224.0 / float(object_wh[0][ii])
        scale_h = 224.0 / float(object_wh[1][ii])
        left_w = left_wh[0][ii]
        left_h = left_wh[1][ii]

        # The mask sum gives the number of valid ground-truth vertices.
        n_gt = np.sum(gt_mask[ii].numpy())
        gt = [(vert[0] / scale_w + left_w, vert[1] / scale_h + left_h)
              for vert in gt_polys[ii][:n_gt]]

        for reward, (xs, ys, lens) in zip(rewards, decoded):
            # The last predicted position is dropped (lengths - 1).
            poly = [(xs[ii][j] / scale_w + left_w,
                     ys[ii][j] / scale_h + left_h)
                    for j in range(lens[ii] - 1)]
            if len(poly) < 2:
                continue  # reward stays at its initial 0
            value, _, _ = iou(poly, gt, origin_wh[1][ii], origin_wh[0][ii])
            reward[ii] = value
    return rewards[0], rewards[1]


def _rl_decoder_step(model, optimizer, img, batch, config, logprob_key,
                     sample_kwargs, greedy_kwargs):
    """Run one self-critical update; return ``(loss, sampling_iou, greedy_iou)``.

    ``loss`` is returned as a Python float so callers can accumulate it
    without holding on to the autograd graph.
    """
    bs = img.shape[0]
    outdict_sample = model(img, mode='train_rl', **sample_kwargs)
    # The greedy baseline needs no gradient.
    with torch.no_grad():
        outdict_greedy = model(img, mode='train_rl', **greedy_kwargs)

    sampling_iou, greedy_iou = _rl_batch_ious(batch, bs, outdict_sample,
                                              outdict_greedy)
    loss = losses.self_critical_loss(
        outdict_sample[logprob_key], outdict_sample['lengths'],
        torch.from_numpy(sampling_iou).to(devices),
        torch.from_numpy(greedy_iou).to(devices))
    model.zero_grad()
    loss.backward()
    # Clip after backward() and test the key that is actually read: the old
    # code tested 'grid_clip' and clipped before any gradient existed.
    if 'grad_clip' in config:
        nn.utils.clip_grad_norm_(model.parameters(), config['grad_clip'])
    optimizer.step()
    return loss.item(), sampling_iou, greedy_iou


def _rl_validate(model, val_loader):
    """Decode the validation set in ``'test'`` mode.

    Returns:
        ``(ious, less_than2)``: one IoU per validation sample and the number
        of predictions with fewer than two vertices (scored as 0).
    """
    val_iou = []
    less_than2 = 0
    with torch.no_grad():
        for val_batch in val_loader:
            img = torch.tensor(val_batch[0], dtype=torch.float).cuda()
            bs = img.shape[0]
            wh = val_batch[-1]
            left_wh = wh['left_WH']
            origin_wh = wh['origion_WH']
            object_wh = wh['object_WH']
            val_target = val_batch[-2].numpy()
            val_mask_final = val_batch[7]
            out_dict = model(img, mode='test')
            pred_x = out_dict['final_pred_x'].cpu().numpy()
            pred_y = out_dict['final_pred_y'].cpu().numpy()
            pred_len = out_dict['lengths']

            for ii in range(bs):
                scale_w = 224.0 / float(object_wh[0][ii])
                scale_h = 224.0 / float(object_wh[1][ii])
                left_w = left_wh[0][ii]
                left_h = left_wh[1][ii]

                n_gt = np.sum(val_mask_final[ii].numpy())
                gt = [(vert[0] / scale_w + left_w, vert[1] / scale_h + left_h)
                      for vert in val_target[ii][:n_gt]]

                pred_len_b = pred_len[ii] - 1
                if pred_len_b < 2:
                    val_iou.append(0.)
                    less_than2 += 1
                    continue
                pred = [(pred_x[ii][j] / scale_w + left_w,
                         pred_y[ii][j] / scale_h + left_h)
                        for j in range(pred_len_b)]

                # IoU is rebuilt from the numerator/denominator iou() returns.
                _, nu_cur, de_cur = iou(pred, gt, origin_wh[1][ii],
                                        origin_wh[0][ii])
                val_iou.append(nu_cur * 1.0 / de_cur if de_cur != 0 else 0)
    return val_iou, less_than2


def train(config, load_resnet50=False, pre_trained=None, cur_epochs=0):
    """Fine-tune both decoders of a ``PolygonModel`` with self-critical RL.

    Parameters whose name contains ``'encoder'`` are frozen. For every batch
    the main decoder (parameters without ``'delta'`` in their name) and then
    the delta model are each updated once, by separate Adam optimizers.
    Every ``config['val_every']`` steps the model is validated and a
    checkpoint is saved.

    Args:
        config: Mapping read for ``batch_size``, ``lr``, ``epoch``,
            ``lr_decay`` (``(step_size, gamma)``), ``temperature``,
            ``temperature2``, ``val_every`` and, optionally, ``grad_clip``.
        load_resnet50: Forwarded to ``PolygonModel`` as
            ``load_predtrained_resnet50``.
        pre_trained: Optional path of a state dict to load.
        cur_epochs: Epoch index to start counting from.
    """
    batch_size = config['batch_size']
    lr = config['lr']
    epochs = config['epoch']
    train_dataloader = loadData('train', 16, 71, batch_size)
    val_loader = loadData('val', 16, 71, batch_size, shuffle=False)
    model = PolygonModel(load_predtrained_resnet50=load_resnet50,
                         predict_delta=True).to(devices)

    if pre_trained is not None:
        model.load_state_dict(torch.load(pre_trained))
        print('loaded pretrained polygon net!')

    # Encoders stay in eval mode and are excluded from optimization.
    model.encoder.eval()
    model.delta_encoder.eval()

    for n, p in model.named_parameters():
        if 'encoder' in n:
            print('Not train:', n)
            p.requires_grad = False

    print('No weight decay in RL training')

    # One parameter group per decoder, selected by name.
    train_params1 = []
    train_params2 = []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if 'delta' in n:
            train_params2.append(p)
        else:
            train_params1.append(p)

    optimizer1 = optim.Adam(train_params1, lr=lr, amsgrad=False)
    optimizer2 = optim.Adam(train_params2, lr=lr, amsgrad=False)

    # Decay the optimizers that are actually stepped. The scheduler used to
    # be attached to a third optimizer that never took a step, so the
    # learning rate of optimizer1/optimizer2 never changed.
    schedulers = [
        optim.lr_scheduler.StepLR(opt,
                                  step_size=config['lr_decay'][0],
                                  gamma=config['lr_decay'][1])
        for opt in (optimizer1, optimizer2)
    ]

    print('Total Epochs:', epochs)
    for it in range(cur_epochs, epochs):
        accum = defaultdict(float)
        accum2 = defaultdict(float)
        model.delta_model.train()
        model.decoder.train()
        for index, batch in enumerate(train_dataloader):
            img = torch.tensor(batch[0], dtype=torch.float).cuda()

            # Step 1: train the main decoder, delta model in eval mode.
            model.delta_model.eval()
            model.decoder.train()
            loss, sampling_iou, greedy_iou = _rl_decoder_step(
                model, optimizer1, img, batch, config, 'log_probs',
                sample_kwargs={'temperature': config['temperature'],
                               'temperature2': 0.0},
                greedy_kwargs={'temperature': 0.0})
            accum['loss_total'] += loss
            accum['sampling_iou'] += np.mean(sampling_iou)
            accum['greedy_iou'] += np.mean(greedy_iou)
            if (index + 1) % 20 == 0:
                print('Epoch {} - Step {}'.format(it + 1, index + 1))
                print(
                    '     Main Decoder: loss_total {}, sampling_iou {}, greedy_iou {}'
                    .format(accum['loss_total'] / 20,
                            accum['sampling_iou'] / 20,
                            accum['greedy_iou'] / 20))
                accum = defaultdict(float)

            # Step 2: train the delta model, main decoder in eval mode.
            model.decoder.eval()
            model.delta_model.train()
            loss, sampling_iou, greedy_iou = _rl_decoder_step(
                model, optimizer2, img, batch, config, 'delta_logprob',
                sample_kwargs={'temperature': 0.0,
                               'temperature2': config['temperature2']},
                greedy_kwargs={'temperature': 0.0, 'temperature2': 0.0})
            accum2['loss_total'] += loss
            accum2['sampling_iou'] += np.mean(sampling_iou)
            accum2['greedy_iou'] += np.mean(greedy_iou)
            if (index + 1) % 20 == 0:
                print(
                    '     Second Decoder: loss_total {}, sampling_iou {}, greedy_iou {}'
                    .format(accum2['loss_total'] / 20,
                            accum2['sampling_iou'] / 20,
                            accum2['greedy_iou'] / 20))
                accum2 = defaultdict(float)

            if (index + 1) % config['val_every'] == 0:
                model.decoder.eval()
                model.delta_model.eval()
                val_IoU, less_than2 = _rl_validate(model, val_loader)

                val_iou_data = np.mean(np.array(val_IoU))
                print('Validation After Epoch {} - step {}'.format(
                    str(it + 1), str(index + 1)))
                print('           IoU      on validation set: ', val_iou_data)
                print('less than 2: ', less_than2)
                print('Saving training parameters after this epoch:')
                torch.save(
                    model.state_dict(),
                    '/data/duye/pretrained_models/FPNRLtrain/ResNext_Plus_RL2_retain_Epoch{}-Step{}_ValIoU{}.pth'
                    .format(str(it + 1), str(index + 1), str(val_iou_data)))
                # Back to training mode for the next batch.
                model.decoder.train()
                model.delta_model.train()

        for scheduler in schedulers:
            scheduler.step()
        print('Epoch {} Completed!'.format(str(it + 1)))
def _ce_validate(model, val_loader):
    """Decode the validation set in ``'test'`` mode.

    Returns:
        ``(ious, less_than2)``: one IoU per validation sample and the number
        of predictions with fewer than two vertices (scored as 0).
    """
    val_iou = []
    less_than2 = 0
    with torch.no_grad():
        for val_batch in val_loader:
            img = torch.tensor(val_batch[0], dtype=torch.float).cuda()
            bs = img.shape[0]
            wh = val_batch[-1]
            left_wh = wh['left_WH']
            origin_wh = wh['origion_WH']
            object_wh = wh['object_WH']
            # Original author's note: target coordinates are in 224*224.
            val_target = val_batch[-2].numpy()
            val_mask_final = val_batch[7]
            out_dict = model(img, mode='test')
            pred_x = out_dict['final_pred_x'].cpu().numpy()
            pred_y = out_dict['final_pred_y'].cpu().numpy()
            pred_len = out_dict['lengths']

            for ii in range(bs):
                scale_w = 224.0 / object_wh[0][ii]
                scale_h = 224.0 / object_wh[1][ii]
                left_w = left_wh[0][ii]
                left_h = left_wh[1][ii]

                # The mask sum gives the number of valid target vertices.
                n_gt = np.sum(val_mask_final[ii].numpy())
                gt = [(vert[0] / scale_w + left_w, vert[1] / scale_h + left_h)
                      for vert in val_target[ii][:n_gt]]

                pred_len_b = pred_len[ii] - 1
                if pred_len_b < 2:
                    val_iou.append(0.)
                    less_than2 += 1
                    continue
                pred = [(pred_x[ii][j] / scale_w + left_w,
                         pred_y[ii][j] / scale_h + left_h)
                        for j in range(pred_len_b)]

                # iou() is called with (height, width) of the original image.
                _, nu_cur, de_cur = iou(pred, gt, origin_wh[1][ii],
                                        origin_wh[0][ii])
                val_iou.append(nu_cur * 1.0 / de_cur if de_cur != 0 else 0)
    return val_iou, less_than2


def train(config, load_resnet50=False, pre_trained=None, cur_epochs=0):
    """Train a ``PolygonModel`` with cross-entropy on vertices and deltas.

    The main encoder is frozen (names containing ``'encoder'`` but not
    ``'delta'``). The loss is ``loss_lstm + config['lambda'] * loss_delta``.
    Every ``config['val_every']`` steps the model is validated; a checkpoint
    is written only when the zero-based epoch index exceeds 4.

    Args:
        config: Mapping read for ``batch_size``, ``lr``, ``epoch``,
            ``weight_decay``, ``lr_decay`` (``(step_size, gamma)``),
            ``lambda``, ``val_every`` and, optionally, ``grad_clip``.
        load_resnet50: Forwarded to ``PolygonModel`` as
            ``load_predtrained_resnet50``.
        pre_trained: Optional checkpoint path; only entries whose names
            exist in the model are copied.
        cur_epochs: Epoch index to start counting from.
    """
    batch_size = config['batch_size']
    lr = config['lr']
    epochs = config['epoch']

    train_dataloader = loadData('train', 16, 71, batch_size)
    val_loader = loadData('val', 16, 71, batch_size * 2, shuffle=False)
    model = PolygonModel(load_predtrained_resnet50=load_resnet50,
                         predict_delta=True).to(devices)

    if pre_trained is not None:
        # Partial load: copy each checkpoint tensor whose name the model
        # also has, ignoring everything else.
        pretrained_state = torch.load(pre_trained)
        model_state = model.state_dict()
        for name in model_state:
            if name in pretrained_state:
                model_state[name].data.copy_(pretrained_state[name])
        print('loaded pretrained polygon net!')

    # Freeze the main encoder only; the delta encoder keeps training.
    model.encoder.eval()
    for n, p in model.named_parameters():
        if 'encoder' in n and 'delta' not in n:
            print('Not train:', n)
            p.requires_grad = False

    # Parameters matched by name are exempt from weight decay.
    no_wd = []
    wd = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'bn' in name or 'convLSTM' in name or 'bias' in name:
            no_wd.append(param)
        else:
            wd.append(param)

    optimizer = optim.Adam([{
        'params': no_wd,
        'weight_decay': 0.0
    }, {
        'params': wd
    }],
                           lr=lr,
                           weight_decay=config['weight_decay'],
                           amsgrad=False)

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=config['lr_decay'][0],
                                          gamma=config['lr_decay'][1])

    print('Total Epochs:', epochs)
    for it in range(cur_epochs, epochs):
        accum = defaultdict(float)
        for index, batch in enumerate(train_dataloader):
            img = torch.tensor(batch[0], dtype=torch.float).cuda()
            bs = img.shape[0]
            pre_v2 = torch.tensor(batch[2], dtype=torch.float).cuda()
            pre_v1 = torch.tensor(batch[3], dtype=torch.float).cuda()
            outdict = model(img, pre_v2, pre_v1, mode='train_ce')
            out = outdict['logits']
            out = out.contiguous().view(-1, 28 * 28 + 1)
            target = batch[4]
            mask_delta = batch[7]
            mask_delta = torch.tensor(mask_delta).cuda().view(-1)
            # Smoothed vertex targets. The builtin int replaces np.int,
            # which was an alias for it and was removed in NumPy 1.24.
            target = dt_targets_from_class(np.array(target, dtype=int), 28,
                                           2)
            target = torch.from_numpy(target).cuda().contiguous().view(
                -1, 28 * 28 + 1)

            # Vertex cross-entropy, masked by batch[6] (original author's
            # note: end-of-sequence flag mask) and averaged over the
            # unmasked positions of each sample.
            mask_final = batch[6]
            mask_final = torch.tensor(mask_final).cuda().view(-1)
            loss_lstm = torch.sum(-target *
                                  torch.nn.functional.log_softmax(out, dim=1),
                                  dim=1)
            loss_lstm = loss_lstm * mask_final.type_as(loss_lstm)
            loss_lstm = loss_lstm.view(bs, -1)
            loss_lstm = torch.sum(loss_lstm, dim=1)
            real_pointnum = torch.sum(mask_final.contiguous().view(bs, -1),
                                      dim=1)
            loss_lstm = loss_lstm / real_pointnum
            loss_lstm = torch.mean(loss_lstm)

            # Delta cross-entropy over the 15*15 offset classes; the last
            # class of the target and the last step of the logits are cut.
            delta_target = prepare_delta_target(outdict['pred_polys'],
                                                torch.tensor(batch[-2]).cuda())
            delta_target = dt_targets_from_class(
                np.array(delta_target.cpu().numpy(), dtype=int), 15,
                1)
            delta_target = torch.from_numpy(
                delta_target[:, :, :-1]).cuda().contiguous().view(-1, 15 * 15)
            delta_logits = outdict['delta_logits'][:, :-1, :]
            delta_logits = delta_logits.contiguous().view(-1, 15 * 15)
            tmp = torch.sum(
                -delta_target *
                torch.nn.functional.log_softmax(delta_logits, dim=1),
                dim=1)
            tmp = tmp * mask_delta.type_as(tmp)
            tmp = tmp.view(bs, -1)
            tmp = torch.sum(tmp, dim=1)
            real_pointnum2 = torch.sum(mask_delta.contiguous().view(bs, -1),
                                       dim=1)
            tmp = tmp / real_pointnum2
            loss_delta = torch.mean(tmp)

            loss = config['lambda'] * loss_delta + loss_lstm
            model.zero_grad()
            loss.backward()
            # Clip after backward() and test the key that is actually read:
            # the old code tested 'grid_clip' and clipped before any
            # gradient existed.
            if 'grad_clip' in config:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         config['grad_clip'])
            # .item() keeps plain floats instead of graph-attached tensors.
            accum['loss_total'] += loss.item()
            accum['loss_lstm'] += loss_lstm.item()
            accum['loss_delta'] += loss_delta.item()
            optimizer.step()

            if (index + 1) % 20 == 0:
                print(
                    'Epoch {} - Step {}, loss_total: {} [Loss lstm: {}, loss delta: {}]'
                    .format(it + 1, index + 1, accum['loss_total'] / 20,
                            accum['loss_lstm'] / 20, accum['loss_delta'] / 20))
                accum = defaultdict(float)

            if (index + 1) % config['val_every'] == 0:
                model.delta_encoder.eval()
                model.delta_model.eval()
                model.decoder.eval()
                val_IoU, less_than2 = _ce_validate(model, val_loader)

                val_iou_data = np.mean(np.array(val_IoU))
                print('Validation After Epoch {} - step {}'.format(
                    str(it + 1), str(index + 1)))
                print('           IoU      on validation set: ', val_iou_data)
                print('less than 2: ', less_than2)
                if it > 4:
                    print('Saving training parameters after this epoch:')
                    torch.save(
                        model.state_dict(),
                        '/data/duye/pretrained_models/ResNext_Plus_DeltaModel_Epoch{}-Step{}_ValIoU{}.pth'
                        .format(str(it + 1), str(index + 1),
                                str(val_iou_data)))
                # Back to training mode for the next batch.
                model.delta_encoder.train()
                model.delta_model.train()
                model.decoder.train()

        # Learning-rate decay once per epoch.
        scheduler.step()
        print()
        print('Epoch {} Completed!'.format(str(it + 1)))
        print()
# Example #4
# 0
class OnLineTrainer:
    """Self-critical RL fine-tuning of the two PolygonModel decoders.

    Every parameter whose name contains ``'encoder'`` is frozen.  For each
    batch from ``loadAde20K`` two optimizer updates are performed: the first
    samples from the main decoder (temperature ``t1``), the second samples
    from the delta model (temperature ``t2``).  In both updates the reward is
    the mask IoU of the sampled polygon and the baseline is the mask IoU of
    the greedy polygon.
    """

    def __init__(self,
                 num_workers=8,
                 update_every=8,
                 save_every=100,
                 t1=0.1,
                 t2=0.1,
                 pre=None):
        """Build the model, the dataloader and the optimizer.

        Args:
            num_workers: stored on the instance; not read by this class.
            update_every: stored on the instance; not read by this class.
            save_every: a checkpoint is written whenever the running batch
                counter is a multiple of this value.
            t1: sampling temperature for the main-decoder update.
            t2: sampling temperature (``temperature2``) for the delta-model
                update.
            pre: optional checkpoint path handed to ``torch.load``; skipped
                when ``None``.
        """
        self.num_workers = num_workers
        self.update_every = update_every
        self.save_every = save_every
        self.t1 = t1
        self.t2 = t2
        self.max_epoch = 2
        self.model = PolygonModel(predict_delta=True).to(devices)
        if pre is not None:
            self.model.load_state_dict(torch.load(pre))
        self.dataloader = loadAde20K(batch_size=4)
        # Both encoders stay in eval mode and receive no gradient.
        self.model.encoder.eval()
        self.model.delta_encoder.eval()
        for n, p in self.model.named_parameters():
            if 'encoder' in n:
                p.requires_grad = False
        self.train_params = [
            p for p in self.model.parameters() if p.requires_grad
        ]
        self.optimizer = optim.Adam(self.train_params, lr=2e-5, amsgrad=False)

    # TODO: 1. Train longer on Rooftop to raise accuracy (currently a
    #          checkpoint is saved every 50 updates).
    #  2. The ADE20K train split is loaded incorrectly; the gt dimensions
    #     do not match.
    #  3. Save every 20 updates on Cityscapes and check whether anything
    #     looks wrong.

    @staticmethod
    def _to_image_coords(pred_x, pred_y, pred_len, wh):
        """Map predicted vertices of one sample back to image coordinates.

        Only the first ``pred_len - 1`` vertices are used.  ``wh`` is the
        per-sample dict with keys ``'object_WH'`` and ``'left_WH'``; the
        prediction is rescaled from a 224-pixel frame to the object size and
        shifted by the object's left/top offset.
        """
        object_wh = wh['object_WH']
        left_wh = wh['left_WH']
        scale_w = 224.0 / float(object_wh[0])
        scale_h = 224.0 / float(object_wh[1])
        left_w = left_wh[0]
        left_h = left_wh[1]
        return [(pred_x[j] / scale_w + left_w, pred_y[j] / scale_h + left_h)
                for j in range(pred_len - 1)]

    @staticmethod
    def _polygon_mask_iou(vertices, gt, origion_wh):
        """Return the IoU between the rasterised polygon and the ``gt`` mask.

        Polygons with fewer than two vertices score 0.  ``origion_wh`` is
        ``(W, H)`` of the canvas the polygon is drawn on.
        """
        if len(vertices) < 2:
            return 0.
        canvas = Image.new('L', (origion_wh[0], origion_wh[1]), 0)
        ImageDraw.Draw(canvas).polygon(vertices, outline=1, fill=1)
        mask = np.array(canvas)  # (h, w)
        nu = np.sum(np.logical_and(mask, gt))  # pixels set in both masks
        de = np.sum(np.logical_or(mask, gt))  # pixels set in either mask
        return nu * 1.0 / de if de != 0 else 0.

    def _batch_iou(self, outdict, samples, bs):
        """Compute the per-sample mask IoU for one model output dict.

        ``samples[ii][1]`` is used as the ground-truth mask and
        ``samples[ii][-1]`` as the WH dict of sample ``ii``.
        """
        pred_x = outdict['final_pred_x'].cpu().numpy()
        pred_y = outdict['final_pred_y'].cpu().numpy()
        pred_len = outdict['lengths'].cpu().numpy()
        ious = np.zeros(bs, dtype=np.float32)
        for ii in range(bs):
            wh = samples[ii][-1]
            vertices = self._to_image_coords(pred_x[ii], pred_y[ii],
                                             pred_len[ii], wh)
            ious[ii] = self._polygon_mask_iou(vertices, samples[ii][1],
                                              wh['origion_WH'])
        return ious

    def _rl_update(self, img, samples, sample_kwargs, greedy_kwargs):
        """Run one self-critical update and return ``(loss, mean greedy IoU)``.

        ``sample_kwargs`` / ``greedy_kwargs`` hold the temperature arguments
        for the sampled (with gradient) and greedy (no gradient) forward
        passes.  The caller is responsible for putting the sub-modules in
        the desired train/eval mode beforehand.
        """
        bs = img.shape[0]
        outdict_sample = self.model(img, mode='train_rl', **sample_kwargs)
        # Greedy roll-out serves as the reward baseline.
        with torch.no_grad():
            outdict_greedy = self.model(img, mode='train_rl', **greedy_kwargs)

        sampling_iou = self._batch_iou(outdict_sample, samples, bs)
        greedy_iou = self._batch_iou(outdict_greedy, samples, bs)

        # RL loss; per the original note, log_probs is the sum of the two
        # decoders' log-probabilities.
        loss = losses.self_critical_loss(
            outdict_sample['log_probs'], outdict_sample['lengths'],
            torch.from_numpy(sampling_iou).to(devices),
            torch.from_numpy(greedy_iou).to(devices))
        self.model.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 40)
        self.optimizer.step()
        # Return a plain scalar so the caller never holds the autograd graph.
        return loss.item(), np.mean(greedy_iou)

    def train(self):
        """Alternate main-decoder and delta-model RL updates over the data.

        Runs ``max_epoch`` passes over the dataloader.  Each batch is a dict
        with entries ``'s1'`` .. ``'s4'``; element 0 of each entry is stacked
        into the image batch.  A checkpoint of the whole model is saved every
        ``save_every`` batches.
        """
        # Float counter on purpose: checkpoint files are named '<step>.pth'
        # from str(global_step), e.g. '100.0.pth'.
        global_step = 0.
        for _ in range(self.max_epoch):
            for step, batch in enumerate(self.dataloader):
                global_step += 1
                samples = [batch[key] for key in ('s1', 's2', 's3', 's4')]
                img = torch.cat(
                    [sample[0].unsqueeze(0) for sample in samples],
                    dim=0).to(devices)

                # Step 1: train the main decoder, delta model frozen in eval.
                self.model.delta_model.eval()
                self.model.decoder.train()
                loss_value, greedy_mean = self._rl_update(
                    img,
                    samples,
                    sample_kwargs={
                        'temperature': self.t1,
                        'temperature2': 0.0
                    },
                    greedy_kwargs={'temperature': 0.0})
                print(
                    'Update {}, RL training of Main decoder, loss {}, model IoU {}'
                    .format(step + 1, loss_value, greedy_mean))

                # Step 2: train the delta model, main decoder in eval.
                self.model.decoder.eval()
                self.model.delta_model.train()
                loss_value, greedy_mean = self._rl_update(
                    img,
                    samples,
                    sample_kwargs={
                        'temperature': 0.0,
                        'temperature2': self.t2
                    },
                    greedy_kwargs={
                        'temperature': 0.0,
                        'temperature2': 0.0
                    })
                print(
                    'Update {}, RL training of Second decoder, loss {}, model IoU {}'
                    .format(step + 1, loss_value, greedy_mean))

                if global_step % self.save_every == 0:
                    print('Saving training parameters after Updating...')
                    save_dir = '/data/duye/pretrained_models/OnLineTraining_ADE20K/'
                    os.makedirs(save_dir, exist_ok=True)
                    torch.save(self.model.state_dict(),
                               save_dir + str(global_step) + '.pth')
def _labels_to_vertices(labels, scaleW, scaleH, leftW, leftH):
    """Decode 28x28 grid class labels into image-space polygon vertices.

    Decoding stops at the first label equal to ``28 * 28`` (the
    end-of-sequence class).  Each label is mapped to the centre of its
    8-pixel grid cell, divided by the crop scale and shifted by the crop
    offset.
    """
    vertices = []
    for label in labels:
        if label == 28 * 28:
            break
        vertices.append((((label % 28) * 8.0 + 4) / scaleW + leftW,
                         ((int(label / 28)) * 8.0 + 4) / scaleH + leftH))
    return vertices


def train(config, load_resnet50=False, pre_trained=None, cur_epochs=0):
    """Train PolygonModel (no delta branch) with smoothed cross-entropy.

    Args:
        config: dict read for the keys ``'batch_size'``, ``'lr'``,
            ``'epoch'``, ``'weight_decay'``, ``'lr_decay'`` (step size,
            gamma), ``'val_every'`` and, optionally, ``'grad_clip'`` (max
            gradient norm; no clipping when the key is absent).
        load_resnet50: forwarded to ``PolygonModel`` as
            ``load_predtrained_resnet50``.
        pre_trained: optional checkpoint path to resume from.
        cur_epochs: index of the first epoch to run.

    Every ``config['val_every']`` steps the model is evaluated on the
    validation loader (mean polygon IoU); from the sixth epoch on
    (``it > 4``) a checkpoint is saved after each validation.
    """
    batch_size = config['batch_size']
    lr = config['lr']
    epochs = config['epoch']

    train_dataloader = loadData('train', 16, 71, batch_size)
    val_loader = loadData('val', 16, 71, batch_size, shuffle=False)
    model = PolygonModel(load_predtrained_resnet50=load_resnet50,
                         predict_delta=False).cuda()
    # checkpoint
    if pre_trained is not None:
        model.load_state_dict(torch.load(pre_trained))
        print('loaded pretrained polygon net!')

    # Regularisation: per the original note, the paper itself did not use
    # weight decay.  BN, convLSTM and bias parameters are exempt from it.
    no_wd = []
    wd = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            # No optimization for frozen params
            continue
        if 'bn' in name or 'convLSTM' in name or 'bias' in name:
            no_wd.append(param)
        else:
            wd.append(param)

    optimizer = optim.Adam([{
        'params': no_wd,
        'weight_decay': 0.0
    }, {
        'params': wd
    }],
                           lr=lr,
                           weight_decay=config['weight_decay'],
                           amsgrad=False)

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=config['lr_decay'][0],
                                          gamma=config['lr_decay'][1])

    print('Total Epochs:', epochs)
    for it in range(cur_epochs, epochs):
        accum = defaultdict(float)
        for index, batch in enumerate(train_dataloader):
            img = torch.tensor(batch[0], dtype=torch.float).cuda()
            bs = img.shape[0]
            pre_v2 = torch.tensor(batch[2], dtype=torch.float).cuda()
            pre_v1 = torch.tensor(batch[3], dtype=torch.float).cuda()
            outdict = model(img, pre_v2, pre_v1,
                            mode='train_ce')  # (bs, seq_len, 28*28+1)

            out = outdict['logits']
            # Note from the original author: an earlier training run
            # accidentally applied log_softmax to the logits at this point.
            out = out.contiguous().view(-1,
                                        28 * 28 + 1)  # (bs*seq_len, 28*28+1)
            target = batch[4]

            # Smooth the one-hot targets.  The builtin ``int`` dtype is what
            # the removed ``np.int`` alias used to mean.
            target = dt_targets_from_class(np.array(target, dtype=int), 28,
                                           2)  # (bs, seq_len, 28*28+1)
            target = torch.from_numpy(target).cuda().contiguous().view(
                -1, 28 * 28 + 1)  # (bs*seq_len, 28*28+1)
            # Cross-entropy loss.
            # End-of-sequence mask, (bs, seq_len) per the original note.
            mask_final = batch[6]
            mask_final = torch.tensor(mask_final).cuda().view(-1)
            loss_lstm = torch.sum(-target *
                                  torch.nn.functional.log_softmax(out, dim=1),
                                  dim=1)  # (bs*seq_len)
            # Drop the loss contribution after the end point.
            loss_lstm = loss_lstm * mask_final.type_as(loss_lstm)
            loss_lstm = loss_lstm.view(bs, -1)  # (bs, seq_len)
            loss_lstm = torch.sum(loss_lstm, dim=1)  # sum over seq_len  (bs,)
            real_pointnum = torch.sum(mask_final.contiguous().view(bs, -1),
                                      dim=1)
            loss_lstm = loss_lstm / real_pointnum  # mean over seq_len
            loss_lstm = torch.mean(loss_lstm)  # mean over batch

            loss = loss_lstm
            # TODO: this loss is fine for train_ce; for train_rl the loss
            #  could be rewritten from the conditional probabilities.
            model.zero_grad()
            loss.backward()

            # Clip only once the gradients exist, i.e. after backward().
            if 'grad_clip' in config:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         config['grad_clip'])

            # Accumulate a plain float so the autograd graph of every step
            # is released instead of being kept alive until the next log.
            accum['loss_total'] += loss.item()
            optimizer.step()

            # Log the running loss.
            if (index + 1) % 20 == 0:
                print('Epoch {} - Step {}, loss_total {}'.format(
                    it + 1, index, accum['loss_total'] / 20))
                accum = defaultdict(float)
            # Periodic validation (every 3000 steps per the original note).
            if (index + 1) % config['val_every'] == 0:
                model.eval()
                val_IoU = []
                less_than2 = 0
                with torch.no_grad():
                    for val_batch in val_loader:
                        img = torch.tensor(val_batch[0],
                                           dtype=torch.float).cuda()
                        bs = img.shape[0]

                        WH = val_batch[-1]  # WH_dict
                        left_WH = WH['left_WH']
                        origion_WH = WH['origion_WH']
                        object_WH = WH['object_WH']

                        out_dict = model(img, mode='test')
                        # (bs, seq_len) predicted / target class indices
                        val_result_index = out_dict['pred_polys'].cpu().numpy()
                        val_target = val_batch[4].numpy()

                        # Per-sample IoU.
                        for ii in range(bs):
                            scaleW = 224.0 / object_WH[0][ii]
                            scaleH = 224.0 / object_WH[1][ii]
                            leftW = left_WH[0][ii]
                            leftH = left_WH[1][ii]
                            vertices1 = _labels_to_vertices(
                                val_result_index[ii], scaleW, scaleH, leftW,
                                leftH)
                            vertices2 = _labels_to_vertices(
                                val_target[ii], scaleW, scaleH, leftW, leftH)
                            if len(vertices1) < 2:
                                # Degenerate prediction counts as IoU 0.
                                less_than2 += 1
                                val_IoU.append(0.)
                                continue
                            _, nu_cur, de_cur = iou(
                                vertices1, vertices2, origion_WH[1][ii],
                                origion_WH[0][ii])  # (H, W)
                            iou_cur = nu_cur * 1.0 / de_cur if de_cur != 0 else 0
                            val_IoU.append(iou_cur)

                val_iou_data = np.mean(np.array(val_IoU))
                print('Validation After Epoch {} - step {}'.format(
                    str(it + 1), str(index + 1)))
                print('           IoU      on validation set: ', val_iou_data)
                print('less than 2: ', less_than2)
                if it > 4:  # it = 5
                    print('Saving training parameters after this epoch:')
                    torch.save(
                        model.state_dict(),
                        '/data/duye/pretrained_models/ResNext50_FPN_LSTM_Epoch{}-Step{}_ValIoU{}.pth'
                        .format(str(it + 1), str(index + 1),
                                str(val_iou_data)))
                # Back to training mode after validation.
                model.train()  # important

        # Learning-rate decay.
        scheduler.step()
        print()
        print('Epoch {} Completed!'.format(str(it + 1)))
        print()