def getscore2(model_path, saved=False, maxnum=float('inf')):
    model = PolygonModel(predict_delta=True).to(devices)
    if model_path is not None:
        model.load_state_dict(torch.load(model_path))
        print('Model loaded!')
    # Set to eval mode
    model.eval()
    iou_score = 0.
    nu = 0.  # intersection
    de = 0.  # union
    count = 0
    files = glob.glob('/data/duye/KITTI/image/*')  # all images
    iouss = []
    trans = transforms.Compose([
        transforms.ToTensor(),
    ])
    for idx, f in enumerate(files):
        # data = scio.loadmat(f)
        # Load the corresponding image file
        # img_f = f[:-3] + 'JPG'
        image = Image.open(f).convert('RGB')  # PNG file
        # img_gt = Image.open(f).convert('RGB')
        W = image.width
        H = image.height
        scaleH = 224.0 / float(H)
        scaleW = 224.0 / float(W)
        # Resize to 224x224
        img_new = image.resize((224, 224), Image.BILINEAR)
        img_new = trans(img_new)
        img_new = img_new.unsqueeze(0)
        color = [np.random.randint(0, 255) for _ in range(3)]
        color += [100]
        color = tuple(color)
        with torch.no_grad():
            pre_v2 = None
            pre_v1 = None
            result_dict = model(img_new.to(devices), pre_v2, pre_v1,
                                mode='test', temperature=0.0)  # (bs, seq_len)
        # Coordinates in [0, 224]; index 0: only one sample in the mini-batch here
        pred_x = result_dict['final_pred_x'].cpu().numpy()[0]
        pred_y = result_dict['final_pred_y'].cpu().numpy()[0]
        pred_lengths = result_dict['lengths'].cpu().numpy()[0]
        pred_len = np.sum(pred_lengths) - 1  # subtract the EOS token
        vertices1 = []
        # Build the predicted polygon, rescaled back to the original image size
        for i in range(pred_len):
            vert = (pred_x[i] / scaleW, pred_y[i] / scaleH)
            vertices1.append(vert)
        if saved:
            try:
                drw = ImageDraw.Draw(image, 'RGBA')
                drw.polygon(vertices1, color)
            except TypeError:
                continue
        # GT
        gt_name = '/data/duye/KITTI/label/' + f.split('/')[-1][:-4] + '.png'
        # print(gt_name)
        # Load the GT mask
        gt_mask = Image.open(gt_name)
        gt_mask = np.array(gt_mask)  # (H, W)
        gt_mask[gt_mask > 0] = 255
        gt_mask[gt_mask == 255] = 1
        if saved:
            pass
            # GT draw
            # drw_gt = ImageDraw.Draw(img_gt, 'RGBA')
            # drw_gt.polygon(vertices2, color)
        # Calculate IoU: rasterize the predicted polygon to a binary mask
        img1 = Image.new('L', (W, H), 0)
        ImageDraw.Draw(img1).polygon(vertices1, outline=1, fill=1)
        pre_mask = np.array(img1)  # (H, W)
        intersection = np.logical_and(gt_mask, pre_mask)
        union = np.logical_or(gt_mask, pre_mask)
        nu = np.sum(intersection)
        de = np.sum(union)
        iiou = nu / (de * 1.0) if de != 0 else 0.
        iouss.append(iiou)
        count += 1
        print(count)
        if saved:
            print('saving test result image...')
            save_result_dir = '/data/duye/save_dir/'
            image.save(save_result_dir + str(idx) + '_pred_rooftop.png', 'PNG')
            # img_gt.save(save_result_dir + str(idx) + '_gt_rooftop.png', 'PNG')
        if count >= maxnum:
            break
    iouss.sort()
    iouss.reverse()
    print(iouss)
    true_iou = np.mean(np.array(iouss[:741]))  # mean over the 741 highest-IoU predictions
    return iou_score, nu, de, true_iou
if __name__ == '__main__':
    print('start calculating...')
    # load_model = 'New_FPN2_DeltaModel_Epoch6-Step6000_ValIoU0.5855123247072112.pth'
    # load_model = 'New_FPN2_DeltaModel_Epoch5-Step4000_ValIoU0.6130178256915678.pth'
    # load_model = 'FPN_Epoch9-Step6000_ValIoU0.22982870834472033.pth'
    # load_model = 'Joint_FPN2_DeltaModel_Epoch7-Step3000_ValIoU0.6156470167768288.pth'
    # Best so far is 61.65; setting val_every to 500 should squeeze out a bit more
    # load_model = 'Joint_FPN2_DeltaModel_Epoch9-Step6000_ValIoU0.6177457155898146.pth'
    load_model = 'ResNext_Plus_DeltaModel_Epoch7-Step3000_ValIoU0.619344842105648.pth'
    polynet_pretrained = '/data/duye/pretrained_models/' + load_model
    net = PolygonModel(predict_delta=True).to(devices)
    net.load_state_dict(torch.load(polynet_pretrained))
    # Test mode
    net.eval()
    print('Pretrained model \'{}\' loaded!'.format(load_model))
    ious_test, less2_test, nu_test, de_test, iou_mean_class, iou_mean = get_score(net, maxnum=20, saved=True)
    print('Origin iou:', ious_test)
    print('True iou:', iou_mean_class)
    print('Mean: ', iou_mean)
    """
    iou_mean_test = 0.
    for i in ious_test:
        iou_mean_test += ious_test[i]
    ious_val, less2_val, nu_val, de_val = get_score(net, dataset='val', saved=False)
    print('PRINT, VAL:', ious_val)
    nu_total = {}
    de_total = {}
    iou_total = {}
    for cls in selected_classes:
        # init
def getscore2(model_path, dataset='Rooftop', saved=False, maxnum=float('inf')):
    model = PolygonModel(predict_delta=True).to(devices)
    if model_path is not None:
        model.load_state_dict(torch.load(model_path))
        print('Model loaded!')
    # Set to eval mode
    model.eval()
    iou_score = 0.
    nu = 0.  # intersection
    de = 0.  # union
    count = 0
    files = glob.glob('/data/duye/Aerial_Imagery/Rooftop/test/*.mat')  # all .mat files
    iouss = []
    for idx, f in enumerate(files):
        data = scio.loadmat(f)
        # Load the corresponding image file
        img_f = f[:-3] + 'JPG'
        image = Image.open(img_f).convert('RGB')
        img_gt = Image.open(img_f).convert('RGB')
        I = np.array(image)
        W = image.width
        H = image.height
        lens = data['gt'][0].shape[0]
        for instance_id in range(lens):
            polygon = data['gt'][0][instance_id]
            polygon = np.array(polygon, dtype=np.float32)
            vertex_num = len(polygon)
            if vertex_num < 3:
                continue
            # Find min/max X, Y
            minW, minH = np.min(polygon, axis=0)
            maxW, maxH = np.max(polygon, axis=0)
            curW = maxW - minW
            curH = maxH - minH
            extendrate = 0.10
            extendW = curW * extendrate
            extendH = curH * extendrate
            leftW = int(np.maximum(minW - extendW, 0))
            leftH = int(np.maximum(minH - extendH, 0))
            rightW = int(np.minimum(maxW + extendW, W))
            rightH = int(np.minimum(maxH + extendH, H))
            objectW = rightW - leftW
            objectH = rightH - leftH
            # Filter out objects that are too small or too large
            if objectW >= 150 or objectH >= 150:
                continue
            if objectW <= 20 or objectH <= 20:
                continue
            scaleH = 224.0 / float(objectH)
            scaleW = 224.0 / float(objectW)
            # Crop and resize to 224x224
            # img_new = image.crop(box=(leftW, leftH, rightW, rightH)).resize((224, 224), Image.BILINEAR)
            I_obj = I[leftH:rightH, leftW:rightW, :]
            # To PIL image
            I_obj_img = Image.fromarray(I_obj)
            # Resize
            I_obj_img = I_obj_img.resize((224, 224), Image.BILINEAR)
            I_obj_new = np.array(I_obj_img)  # (H, W, C)
            I_obj_new = I_obj_new.transpose(2, 0, 1)  # (C, H, W)
            I_obj_new = I_obj_new / 255.0
            I_obj_tensor = torch.from_numpy(I_obj_new)  # (C, H, W)
            I_obj_tensor = I_obj_tensor.unsqueeze(0).float().cuda()
            color = [np.random.randint(0, 255) for _ in range(3)]
            color += [100]
            color = tuple(color)
            with torch.no_grad():
                pre_v2 = None
                pre_v1 = None
                result_dict = model(I_obj_tensor, pre_v2, pre_v1,
                                    mode='test', temperature=0.0)  # (bs, seq_len)
            pred_x = result_dict['final_pred_x'].cpu().numpy()[0]
            pred_y = result_dict['final_pred_y'].cpu().numpy()[0]
            pred_lengths = result_dict['lengths'].cpu().numpy()[0]
            pred_len = np.sum(pred_lengths) - 1  # subtract the EOS token
            vertices1 = []
            vertices2 = []
            # Build the predicted polygon, mapped back to original image coordinates
            for i in range(pred_len):
                vert = (pred_x[i] / scaleW + leftW,
                        pred_y[i] / scaleH + leftH)
                vertices1.append(vert)
            if len(vertices1) < 3:
                continue
            if saved:
                try:
                    drw = ImageDraw.Draw(image, 'RGBA')
                    drw.polygon(vertices1, color)
                except TypeError:
                    continue
            # GT polygon
            for points in polygon:
                vertex = (points[0], points[1])
                vertices2.append(vertex)
            if saved:
                # GT draw
                drw_gt = ImageDraw.Draw(img_gt, 'RGBA')
                drw_gt.polygon(vertices2, color)
            # Calculate IoU
            tmp, nu_cur, de_cur = iou(vertices1, vertices2, H, W)
            nu += nu_cur
            de += de_cur
            iouss.append(tmp)
            count += 1
        if saved:
            print('saving test result image...')
            save_result_dir = '/data/duye/save_dir/'
            image.save(save_result_dir + str(idx) + '_pred_rooftop.png', 'PNG')
            img_gt.save(save_result_dir + str(idx) + '_gt_rooftop.png', 'PNG')
        if count >= maxnum:
            break
    iouss.sort()
    iouss.reverse()
    true_iou = np.mean(np.array(iouss))
    print(iouss)
    return iou_score, nu, de, true_iou
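# `getscore2` above (and the validation loop in train()) call an
# `iou(vertices1, vertices2, H, W)` helper that is defined elsewhere in this repo.
# The sketch below is an assumption, reconstructed from the mask-based IoU code in
# the other scorers: rasterize both polygons onto an H x W canvas with PIL and
# return (iou, intersection_area, union_area). The name `iou_sketch` is
# hypothetical and only avoids shadowing the real helper.
def iou_sketch(vertices1, vertices2, h, w):
    """Hypothetical stand-in for the repo's iou() helper (mask-based polygon IoU)."""
    mask1 = Image.new('L', (w, h), 0)
    ImageDraw.Draw(mask1).polygon(vertices1, outline=1, fill=1)
    mask1 = np.array(mask1)
    mask2 = Image.new('L', (w, h), 0)
    ImageDraw.Draw(mask2).polygon(vertices2, outline=1, fill=1)
    mask2 = np.array(mask2)
    nu_ = np.sum(np.logical_and(mask1, mask2))  # intersection area in pixels
    de_ = np.sum(np.logical_or(mask1, mask2))   # union area in pixels
    return (nu_ / float(de_) if de_ != 0 else 0.), nu_, de_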
def getscore_kitti(model_path, saved=False, maxnum=float('inf')):
    model = PolygonModel(predict_delta=True).to(devices)
    if model_path is not None:
        model.load_state_dict(torch.load(model_path))
        print('Model loaded!')
    # Set to eval mode
    model.eval()
    iou_score = 0.
    nu = 0.  # intersection
    de = 0.  # union
    count = 0
    files = glob.glob('/data/duye/KITTI/rawImage/*.png')  # all images
    bbox = '/data/duye/KITTI/bbox/'
    annotation = '/data/duye/KITTI/annotation/'
    print(len(files))
    iouss = []
    for idx, f in enumerate(files):
        print('index:', idx)
        image = Image.open(f).convert('RGB')  # raw image
        W = image.width
        H = image.height
        I = np.array(image)
        # print(I.shape)
        # Load the corresponding bounding-box file
        name = f.split('/')[-1][:-4]  # e.g. 000019
        bd = bbox + name + '.txt'
        if not os.path.exists(bd):
            continue
        # Corresponding annotation
        sss = annotation + name + '.png'
        if not os.path.exists(sss):
            continue
        anno = Image.open(annotation + name + '.png')
        anno = np.array(anno)
        # Iterate over the bounding boxes
        with open(bd, 'r') as bbd:
            all_lines = bbd.readlines()
        for number, line in enumerate(all_lines):
            line = line.replace('\n', '')
            line = line.split(' ')
            if float(line[0]) == 0.0 or \
                    float(line[1]) == 0.0 or \
                    float(line[2]) == 0.0 or \
                    float(line[3]) == 0.0:
                continue
            xx = float(line[0])
            yy = float(line[1])
            ww = float(line[2])
            hh = float(line[3])
            minW = xx
            minH = yy
            maxW = xx + ww
            maxH = yy + hh
            # Expand the box by 8%
            extendrate = 0.08
            curW = ww
            curH = hh
            extendW = int(round(curW * extendrate))
            extendH = int(round(curH * extendrate))
            leftW = int(np.maximum(minW - extendW, 0))
            leftH = int(np.maximum(minH - extendH, 0))
            rightW = int(np.minimum(maxW + extendW, W))
            rightH = int(np.minimum(maxH + extendH, H))
            # Size of the current object's bounding box, used for coordinate scaling
            objectW = rightW - leftW
            objectH = rightH - leftH
            scaleH = 224.0 / float(objectH)
            scaleW = 224.0 / float(objectW)
            # Crop and resize to 224x224
            # img_new = image.crop(box=(leftW, leftH, rightW, rightH)).resize((224, 224), Image.BILINEAR)
            I_obj = I[leftH:rightH, leftW:rightW, :]
            # To PIL image
            I_obj_img = Image.fromarray(I_obj)
            # Resize
            I_obj_img = I_obj_img.resize((224, 224), Image.BILINEAR)
            I_obj_new = np.array(I_obj_img)  # (H, W, C)
            I_obj_new = I_obj_new.transpose(2, 0, 1)  # (C, H, W)
            I_obj_new = I_obj_new / 255.0
            I_obj_tensor = torch.from_numpy(I_obj_new)  # (C, H, W)
            I_obj_tensor = I_obj_tensor.unsqueeze(0).float().to(devices)
            color = [np.random.randint(0, 255) for _ in range(3)]
            color += [100]
            color = tuple(color)
            with torch.no_grad():
                pre_v2 = None
                pre_v1 = None
                result_dict = model(I_obj_tensor, pre_v2, pre_v1,
                                    mode='test', temperature=0.0)  # (bs, seq_len)
            # Coordinates in [0, 224]; index 0: only one sample in the mini-batch here
            pred_x = result_dict['final_pred_x'].cpu().numpy()[0]
            pred_y = result_dict['final_pred_y'].cpu().numpy()[0]
            pred_lengths = result_dict['lengths'].cpu().numpy()[0]
            pred_len = np.sum(pred_lengths) - 1  # subtract the EOS token
            vertices1 = []
            # Build the predicted polygon, mapped back to original image coordinates
            for i in range(pred_len):
                vert = (pred_x[i] / scaleW + leftW,
                        pred_y[i] / scaleH + leftH)
                vertices1.append(vert)
            if saved:
                try:
                    drw = ImageDraw.Draw(image, 'RGBA')
                    drw.polygon(vertices1, color)
                except TypeError:
                    continue
            # Predicted mask
            img1 = Image.new('L', (W, H), 0)
            ImageDraw.Draw(img1).polygon(vertices1, outline=1, fill=1)
            pre_mask = np.array(img1)  # (H, W)
            # GT mask: pixels labelled with this instance's id (number + 1)
            # NOTE (original author): indexing by `number` like this is not correct!
            cur_anno = np.array(anno == number + 1, dtype=int)
            # cur_anno[cur_anno != 1] = 0
            # Compute IoU
            intersection = np.logical_and(cur_anno, pre_mask)
            union = np.logical_or(cur_anno, pre_mask)
            nu = np.sum(intersection)
            de = np.sum(union)
            iiou = nu / (de * 1.0) if de != 0 else 0.
            iouss.append(iiou)
    iouss.sort()
    iouss.reverse()
    print(iouss)
    print(np.mean(np.array(iouss)))
    true_iou = np.mean(np.array(iouss[:741]))  # mean over the 741 highest-IoU predictions
    return iou_score, nu, de, true_iou
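# A minimal usage sketch for getscore_kitti above. The checkpoint filename is a
# placeholder borrowed from the naming pattern used elsewhere in this repo; point
# it at a real .pth file before running. The wrapper name is hypothetical.
def run_kitti_eval_example():
    ckpt = '/data/duye/pretrained_models/ResNext_Plus_DeltaModel_Epoch7-Step3000_ValIoU0.619344842105648.pth'  # assumed checkpoint
    _, inter, union, true_iou = getscore_kitti(ckpt, saved=False, maxnum=float('inf'))
    print('KITTI true IoU (mean of the 741 best instances):', true_iou)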
def get_score_ADE20K(saved=False, maxnum=float('inf')):
    model = PolygonModel(predict_delta=True).to(devices)
    pre = 'ResNext_Plus_RL2_retain_Epoch1-Step4000_ValIoU0.6316584628283326.pth'
    dirs = '/data/duye/pretrained_models/FPNRLtrain/' + pre
    model.load_state_dict(torch.load(dirs))
    model.eval()
    iou = []
    print('starting.....')
    img_PATH = '/data/duye/ADE20K/validation/'
    lbl_path = '/data/duye/ADE20K/val_new/label/*.png'
    labels = glob.glob(lbl_path)
    for label in labels:
        name = label
        label = Image.open(label)
        label_index = name.split('_')[2]
        # Corresponding txt file
        txt_file = '/data/duye/ADE20K/val_new/img/img_' + label_index + '.txt'
        with open(txt_file, 'r') as f:  # open the file
            img_path = f.readline().replace('\n', '')  # read the first line
        # Extract the image path
        img_path = img_PATH + img_path[36:]
        # Raw image
        img = Image.open(img_path).convert('RGB')
        W = img.width
        H = img.height
        # Derive the bounding box from the label mask
        label = np.array(label)  # (H, W)
        Hs, Ws = np.where(label == np.max(label))
        minH = np.min(Hs)
        maxH = np.max(Hs)
        minW = np.min(Ws)
        maxW = np.max(Ws)
        curW = maxW - minW
        curH = maxH - minH
        extendrate = 0.10
        extendW = int(round(curW * extendrate))
        extendH = int(round(curH * extendrate))
        leftW = np.maximum(minW - extendW, 0)
        leftH = np.maximum(minH - extendH, 0)
        rightW = np.minimum(maxW + extendW, W)
        rightH = np.minimum(maxH + extendH, H)
        objectW = rightW - leftW
        objectH = rightH - leftH
        # print(leftH, rightH, leftW, rightW)
        # img_new = img.crop(box=(leftW, leftH, rightW, rightH)).resize((224, 224), Image.BILINEAR)
        I = np.array(img)
        I_obj = I[leftH:rightH, leftW:rightW, :]
        # To PIL image
        I_obj_img = Image.fromarray(I_obj)
        # Resize
        I_obj_img = I_obj_img.resize((224, 224), Image.BILINEAR)
        I_obj_new = np.array(I_obj_img)  # (H, W, C)
        I_obj_new = I_obj_new.transpose(2, 0, 1)  # (C, H, W)
        I_obj_new = I_obj_new / 255.0
        I_obj_tensor = torch.from_numpy(I_obj_new)  # (C, H, W)
        I_obj_tensor = I_obj_tensor.unsqueeze(0).float().cuda()
        color = [np.random.randint(0, 255) for _ in range(3)]
        color += [100]
        color = tuple(color)
        with torch.no_grad():
            pre_v2 = None
            pre_v1 = None
            result_dict = model(I_obj_tensor, pre_v2, pre_v1,
                                mode='test', temperature=0.0)  # (bs, seq_len)
        # Coordinates in [0, 224]; index 0: only one sample in the mini-batch here
        pred_x = result_dict['final_pred_x'].cpu().numpy()[0]
        pred_y = result_dict['final_pred_y'].cpu().numpy()[0]
        pred_lengths = result_dict['lengths'].cpu().numpy()[0]
        pred_len = np.sum(pred_lengths) - 1  # subtract the EOS token
        vertices1 = []
        scaleW = 224.0 / float(objectW)
        scaleH = 224.0 / float(objectH)
        # Build the predicted polygon, mapped back to original image coordinates
        for i in range(pred_len):
            vert = (pred_x[i] / scaleW + leftW,
                    pred_y[i] / scaleH + leftH)
            vertices1.append(vert)
        img1 = Image.new('L', (W, H), 0)
        ImageDraw.Draw(img1).polygon(vertices1, outline=1, fill=1)
        pre_mask = np.array(img1)  # (H, W)
        if saved:
            try:
                drw = ImageDraw.Draw(img, 'RGBA')
                drw.polygon(vertices1, color)
            except TypeError:
                continue
        gt_mask = np.array(label)
        gt_mask[gt_mask == 255] = 1
        filt = np.sum(gt_mask)
        if filt <= 20 * 20:
            continue
        intersection = np.logical_and(gt_mask, pre_mask)
        union = np.logical_or(gt_mask, pre_mask)
        nu = np.sum(intersection)
        de = np.sum(union)
        # Compute IoU
        iiou = nu / (de * 1.0) if de != 0 else 0.
        iou.append(iiou)
    iou.sort()
    iou.reverse()
    print(iou)
    print(len(iou))
    print('IoU:', np.mean(np.array(iou)))
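# The crop -> resize -> normalize -> tensor steps above are repeated almost verbatim
# in getscore2 (Rooftop), getscore_kitti and get_score_ADE20K. The sketch below shows
# how they could be shared; the function name is hypothetical, not part of the repo.
def crop_to_input_sketch(img_np, leftW, leftH, rightW, rightH, size=224):
    """Crop a box from an (H, W, C) uint8 image; return a (1, C, size, size) float tensor in [0, 1]."""
    crop = Image.fromarray(img_np[leftH:rightH, leftW:rightW, :])
    crop = crop.resize((size, size), Image.BILINEAR)
    arr = np.array(crop).transpose(2, 0, 1) / 255.0  # (C, size, size)
    return torch.from_numpy(arr).unsqueeze(0).float()
# Example use (matches the inline code above):
#   I_obj_tensor = crop_to_input_sketch(I, leftW, leftH, rightW, rightH).to(devices)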
                          drop_last=False)
    print('DataLoader complete!', dataloader)
    return dataloader


# Test score
devices = 'cuda' if torch.cuda.is_available() else 'cpu'

if __name__ == '__main__':
    parse = argparse.ArgumentParser(description='Test generalization score on ssTEM')
    parse.add_argument('-p', '--pretrained', type=str, default=None)
    args = parse.parse_args()
    pre = args.pretrained
    model = PolygonModel(predict_delta=True).to(devices)
    pre = 'ResNext_Plus_RL2_retain_Epoch1-Step4000_ValIoU0.6316584628283326.pth'
    dirs = '/data/duye/pretrained_models/FPNRLtrain/' + pre
    model.load_state_dict(torch.load(dirs))
    model.eval()
    loader = loadssTEM(batch_size=8)
    iou = []
    for index, batch in enumerate(loader):
        print('index: ', index)
        img = batch[0]
        WH = batch[-1]  # WH dict
        left_WH = WH['left_WH']
        origion_WH = WH['origion_WH']
        object_WH = WH['object_WH']
        gt = batch[1]
        bs = img.shape[0]
        with torch.no_grad():
            pre_v2 = None
def train(config, load_resnet50=False, pre_trained=None, cur_epochs=0):
    batch_size = config['batch_size']
    lr = config['lr']
    epochs = config['epoch']
    train_dataloader = loadData('train', 16, 71, batch_size)
    val_loader = loadData('val', 16, 71, batch_size, shuffle=False)
    model = PolygonModel(load_predtrained_resnet50=load_resnet50,
                         predict_delta=False).cuda()
    # Checkpoint
    if pre_trained is not None:
        model.load_state_dict(torch.load(pre_trained))
        print('loaded pretrained polygon net!')

    # Weight-decay parameter groups; the original paper does not add regularization
    no_wd = []
    wd = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            # No optimization for frozen params
            continue
        if 'bn' in name or 'convLSTM' in name or 'bias' in name:
            no_wd.append(param)
        else:
            wd.append(param)
    optimizer = optim.Adam([
        {'params': no_wd, 'weight_decay': 0.0},
        {'params': wd}
    ],
        lr=lr,
        weight_decay=config['weight_decay'],
        amsgrad=False)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=config['lr_decay'][0],
                                          gamma=config['lr_decay'][1])
    print('Total Epochs:', epochs)
    for it in range(cur_epochs, epochs):
        accum = defaultdict(float)
        # accum['loss_total'] = 0.
        # accum['loss_lstm'] = 0.
        # accum['loss_delta'] = 0.
        for index, batch in enumerate(train_dataloader):
            img = torch.tensor(batch[0], dtype=torch.float).cuda()
            bs = img.shape[0]
            pre_v2 = torch.tensor(batch[2], dtype=torch.float).cuda()
            pre_v1 = torch.tensor(batch[3], dtype=torch.float).cuda()
            outdict = model(img, pre_v2, pre_v1, mode='train_ce')  # (bs, seq_len, 28*28+1)
            out = outdict['logits']
            # An earlier training run accidentally included the line below
            # out = torch.nn.functional.log_softmax(out, dim=-1)  # logits -> log_probs
            out = out.contiguous().view(-1, 28 * 28 + 1)  # (bs*seq_len, 28*28+1)

            target = batch[4]
            # Smoothed target
            target = dt_targets_from_class(np.array(target, dtype=np.int64), 28, 2)  # (bs, seq_len, 28*28+1)
            target = torch.from_numpy(target).cuda().contiguous().view(-1, 28 * 28 + 1)  # (bs*seq_len, 28*28+1)

            # Cross-entropy loss
            mask_final = batch[6]  # EOS mask, (bs, seq_len(70)) starting from the first vertex
            mask_final = torch.tensor(mask_final).cuda().view(-1)
            mask_delta = batch[7]
            mask_delta = torch.tensor(mask_delta).cuda().view(-1)  # (bs*70)
            loss_lstm = torch.sum(-target * torch.nn.functional.log_softmax(out, dim=1), dim=1)  # (bs*seq_len)
            loss_lstm = loss_lstm * mask_final.type_as(loss_lstm)  # truncate the loss at the end point
            loss_lstm = loss_lstm.view(bs, -1)  # (bs, seq_len)
            loss_lstm = torch.sum(loss_lstm, dim=1)  # sum over seq_len -> (bs,)
            real_pointnum = torch.sum(mask_final.contiguous().view(bs, -1), dim=1)
            loss_lstm = loss_lstm / real_pointnum  # mean over seq_len
            loss_lstm = torch.mean(loss_lstm)  # mean over batch

            # loss = loss_lstm + loss_delta
            loss = loss_lstm
            # TODO: this loss is fine for train_ce; for train_rl the loss can be
            # rewritten from the conditional probabilities

            model.zero_grad()
            loss.backward()
            # Clip gradients after backward, before the optimizer step
            if 'grad_clip' in config:
                nn.utils.clip_grad_norm_(model.parameters(), config['grad_clip'])
            accum['loss_total'] += loss.item()
            optimizer.step()

            # Print the loss
            if (index + 1) % 20 == 0:
                print('Epoch {} - Step {}, loss_total {}'.format(
                    it + 1, index, accum['loss_total'] / 20))
                accum = defaultdict(float)

            # Validate every config['val_every'] steps (e.g. every 3000)
            if (index + 1) % config['val_every'] == 0:
                # Validation: the original author only puts the model in eval mode here
                model.eval()
                val_IoU = []
                less_than2 = 0
                with torch.no_grad():
                    for val_index, val_batch in enumerate(val_loader):
                        img = torch.tensor(val_batch[0], dtype=torch.float).cuda()
                        bs = img.shape[0]
                        WH = val_batch[-1]  # WH dict
                        left_WH = WH['left_WH']
                        origion_WH = WH['origion_WH']
                        object_WH = WH['object_WH']
                        val_mask_final = val_batch[6]
                        val_mask_final = torch.tensor(val_mask_final).cuda().contiguous().view(-1)
                        out_dict = model(img, mode='test')  # (N, seq_len), test-time decoding
                        pred_polys = out_dict['pred_polys']  # (bs, seq_len)
                        tmp = pred_polys
                        pred_polys = pred_polys.contiguous().view(-1)  # (bs*seq_len)
                        val_target = val_batch[4]  # (bs, seq_len)
                        # Accuracy
                        val_target = torch.tensor(val_target, dtype=torch.long).cuda().contiguous().view(-1)  # (bs*seq_len)
                        val_acc1 = (pred_polys == val_target).float()
                        val_acc1 = (val_acc1 * val_mask_final.type_as(val_acc1)).sum().item()
                        val_acc1 = val_acc1 * 1.0 / val_mask_final.sum().item()
                        # Used for IoU computation
                        val_result_index = tmp.cpu().numpy()  # (bs, seq_len)
                        val_target = val_batch[4].numpy()  # (bs, seq_len)
                        # IoU
                        for ii in range(bs):
                            vertices1 = []
                            vertices2 = []
                            scaleW = 224.0 / object_WH[0][ii]
                            scaleH = 224.0 / object_WH[1][ii]
                            leftW = left_WH[0][ii]
                            leftH = left_WH[1][ii]
                            for label in val_result_index[ii]:
                                if label == 28 * 28:
                                    break
                                vertex = (((label % 28) * 8.0 + 4) / scaleW + leftW,
                                          ((int(label / 28)) * 8.0 + 4) / scaleH + leftH)
                                vertices1.append(vertex)
                            for label in val_target[ii]:
                                if label == 28 * 28:
                                    break
                                vertex = (((label % 28) * 8.0 + 4) / scaleW + leftW,
                                          ((int(label / 28)) * 8.0 + 4) / scaleH + leftH)
                                vertices2.append(vertex)
                            if len(vertices1) < 2:
                                less_than2 += 1
                                # IoU = 0.
                                val_IoU.append(0.)
                                continue
                            _, nu_cur, de_cur = iou(vertices1, vertices2,
                                                    origion_WH[1][ii],
                                                    origion_WH[0][ii])  # (H, W)
                            iou_cur = nu_cur * 1.0 / de_cur if de_cur != 0 else 0
                            val_IoU.append(iou_cur)
                val_iou_data = np.mean(np.array(val_IoU))
                print('Validation After Epoch {} - step {}'.format(str(it + 1), str(index + 1)))
                print('IoU on validation set: ', val_iou_data)
                print('less than 2: ', less_than2)
                if it > 4:  # i.e. from epoch 6 onwards
                    print('Saving training parameters after this epoch:')
                    torch.save(model.state_dict(),
                               '/data/duye/pretrained_models/ResNext50_FPN_LSTM_Epoch{}-Step{}_ValIoU{}.pth'.format(
                                   str(it + 1), str(index + 1), str(val_iou_data)))
                # Back to training mode
                model.train()  # important
        # Learning-rate decay
        scheduler.step()
        # Print current lr
        print()
        print('Epoch {} Completed!'.format(str(it + 1)))
        print()
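# A minimal sketch of the config dict that train() reads. The keys match the lookups
# above (batch_size, lr, epoch, weight_decay, lr_decay, val_every, grad_clip); the
# numeric values are placeholders, not the values used to train the released
# checkpoints, and the wrapper name is hypothetical.
def run_train_example():
    config_example = {
        'batch_size': 8,
        'lr': 1e-4,
        'epoch': 10,
        'weight_decay': 1e-5,
        'lr_decay': [5, 0.1],   # StepLR: step_size=5 epochs, gamma=0.1
        'val_every': 3000,      # validate every 3000 training steps
        'grad_clip': 40.0,      # optional; clipping is skipped if the key is absent
    }
    train(config_example, load_resnet50=True, pre_trained=None, cur_epochs=0)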