# 예제 #1 (Example 1)
def train_val(config):
    """Train a 15-class semantic-segmentation model and validate every epoch.

    Builds train/val dataloaders, instantiates the network selected by
    ``config.model_type``, then runs ``config.num_epochs`` epochs of training
    with an auxiliary-head loss.  After each epoch the model is evaluated on
    the validation set, per-class IoU and frequency-weighted mIoU (FWmIoU)
    are logged to TensorBoard, and a checkpoint is written whenever FWmIoU
    improves.

    NOTE(review): depends on module-level names defined elsewhere in this
    project (get_dataloader, UNet, UNetPP, SEDANet, rf101, DANet,
    deeplabv3_plus, seg_hrnet_ocr, scSEUNet, LabelSmoothSoftmaxCE,
    LabelSmoothCE, adamw, semantic_to_mask, get_confusion_matrix, get_miou,
    SummaryWriter, tqdm).

    Args:
        config: namespace-like object carrying data paths, hyper-parameters
            and the model/optimizer/smoothing selections (see the attribute
            accesses below for the expected fields).
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Full-resolution loaders; rebuilt inside the epoch loop (possibly at
    # half batch size when multi-scale training is enabled).
    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  smooth=config.smooth)
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    # ---- model selection; unknown types fall back to plain UNet ----
    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "SEDANet":
        model = SEDANet()
    elif config.model_type == "RefineNet":
        model = rf101()
    elif config.model_type == "DANet":
        model = DANet(backbone='resnext101',
                      nclass=config.output_ch,
                      pretrained=True,
                      norm_layer=nn.BatchNorm2d)
    elif config.model_type == "Deeplabv3+":
        model = deeplabv3_plus.DeepLabv3_plus(in_channels=3,
                                              num_classes=config.output_ch,
                                              backend='resnet101',
                                              os=16,
                                              pretrained=True,
                                              norm_layer=nn.BatchNorm2d)
    elif config.model_type == "HRNet_OCR":
        model = seg_hrnet_ocr.get_seg_model()
    elif config.model_type == "scSEUNet":
        model = scSEUNet(pretrained=True, norm_layer=nn.BatchNorm2d)
    else:
        model = UNet()

    # Optionally resume from a hard-coded stage-1 checkpoint (saved wrapped
    # in DataParallel, hence the ``.module`` unwrap).
    if config.iscontinue:
        model = torch.load("./exp/13_Deeplabv3+_0.7619.pth",
                           map_location='cpu').module

    for _, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatability

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)

    # Dataset label ids, human-readable class names, and per-class pixel
    # frequencies used to frequency-weight the mIoU.
    labels = [1, 2, 3, 4, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    objects = [
        '水体', '道路', '建筑物', '机场', '停车场', '操场', '普通耕地', '农业大棚', '自然草地', '绿地绿化',
        '自然林', '人工林', '自然裸土', '人为裸土', '其它'
    ]
    frequency = np.array([
        0.0279, 0.0797, 0.1241, 0.00001, 0.0616, 0.0029, 0.2298, 0.0107,
        0.1207, 0.0249, 0.1470, 0.0777, 0.0617, 0.0118, 0.0187
    ])

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-4,
                        momentum=0.9)
    elif config.optimizer == "adamw":
        optimizer = adamw.AdamW(model.parameters(), lr=config.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # Label-smoothing criteria when requested; plain cross-entropy otherwise.
    if config.smooth == "all":
        criterion = LabelSmoothSoftmaxCE()
    elif config.smooth == "edge":
        criterion = LabelSmoothCE()
    else:
        criterion = nn.CrossEntropyLoss()

    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                         T_0=15,
                                                         eta_min=1e-4)

    global_step = 0
    max_fwiou = 0
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        # Multi-scale switch: seed == 1 trains at 384x384 with half batch
        # size.  Fixed to 0 (full resolution) here — the original drew a
        # random value and then unconditionally overwrote it with 0, so the
        # dead random draw was removed.
        seed = 0
        print("seed is ", seed)
        if seed == 1:
            train_loader = get_dataloader(img_dir=config.train_img_dir,
                                          mask_dir=config.train_mask_dir,
                                          mode="train",
                                          batch_size=config.batch_size // 2,
                                          num_workers=config.num_workers,
                                          smooth=config.smooth)
            val_loader = get_dataloader(img_dir=config.val_img_dir,
                                        mask_dir=config.val_mask_dir,
                                        mode="val",
                                        batch_size=config.batch_size // 2,
                                        num_workers=config.num_workers)
        else:
            train_loader = get_dataloader(img_dir=config.train_img_dir,
                                          mask_dir=config.train_mask_dir,
                                          mode="train",
                                          batch_size=config.batch_size,
                                          num_workers=config.num_workers,
                                          smooth=config.smooth)
            val_loader = get_dataloader(img_dir=config.val_img_dir,
                                        mask_dir=config.val_mask_dir,
                                        mode="val",
                                        batch_size=config.batch_size,
                                        num_workers=config.num_workers)

        # 15x15 confusion matrix, accumulated over the validation set.
        cm = np.zeros([15, 15])
        print(optimizer.param_groups[0]['lr'])
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img',
                  ncols=100) as train_pbar:
            model.train()

            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float32)

                if seed == 1:
                    # Downscale the batch for the multi-scale branch; the
                    # mask uses nearest-neighbour to keep labels discrete.
                    image = F.interpolate(image,
                                          size=(384, 384),
                                          mode='bilinear',
                                          align_corners=True)
                    mask = F.interpolate(mask.float(),
                                         size=(384, 384),
                                         mode='nearest')

                if config.smooth == "edge":
                    # LabelSmoothCE consumes the soft one-hot mask directly.
                    mask = mask.to(device, dtype=torch.float32)
                else:
                    # CrossEntropy wants class indices (argmax over channel).
                    mask = mask.to(device, dtype=torch.long).argmax(dim=1)

                # The model returns (auxiliary logits, main logits); both
                # heads contribute equally to the loss.
                aux_out, out = model(image)
                aux_loss = criterion(aux_out, mask)
                seg_loss = criterion(out, mask)
                loss = aux_loss + seg_loss

                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1

            print("\ntraining epoch loss: " +
                  str(epoch_loss / (float(config.num_train) /
                                    (float(config.batch_size)))))
        torch.cuda.empty_cache()
        val_loss = 0
        with torch.no_grad():
            with tqdm(total=config.num_val,
                      desc="Epoch %d / %d validation round" %
                      (epoch + 1, config.num_epochs),
                      unit='img',
                      ncols=100) as val_pbar:
                model.eval()
                locker = 0
                for image, mask in val_loader:
                    image = image.to(device, dtype=torch.float32)
                    target = mask.to(device, dtype=torch.long).argmax(dim=1)
                    mask = mask.cpu().numpy()
                    _, pred = model(image)
                    val_loss += F.cross_entropy(pred, target).item()
                    pred = pred.cpu().detach().numpy()
                    mask = semantic_to_mask(mask, labels)
                    pred = semantic_to_mask(pred, labels)
                    cm += get_confusion_matrix(mask, pred, labels)
                    val_pbar.update(image.shape[0])
                    # Log a fixed sample pair (6th batch) as images once per
                    # epoch for visual inspection.
                    if locker == 5:
                        writer.add_images('mask_a/true',
                                          mask[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_a/pred',
                                          pred[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/true',
                                          mask[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/pred',
                                          pred[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                    locker += 1

                miou = get_miou(cm)
                fw_miou = (miou * frequency).sum()
                scheduler.step()

                # Checkpoint only on FWmIoU improvement.  The original used
                # ``if True:`` which saved every epoch while still clobbering
                # max_fwiou, making the best-score tracking meaningless.
                if fw_miou > max_fwiou:
                    if torch.__version__ == "1.6.0":
                        # Legacy (non-zip) serialization so older torch can
                        # still load the checkpoint.
                        torch.save(model,
                                   config.result_path + "/%d_%s_%.4f.pth" %
                                   (epoch + 1, config.model_type, fw_miou),
                                   _use_new_zipfile_serialization=False)
                    else:
                        torch.save(
                            model, config.result_path + "/%d_%s_%.4f.pth" %
                            (epoch + 1, config.model_type, fw_miou))
                    max_fwiou = fw_miou
                print("\n")
                print(miou)
                print("testing epoch loss: " + str(val_loss),
                      "FWmIoU = %.4f" % fw_miou)
                writer.add_scalar('FWIoU/val', fw_miou, epoch + 1)
                writer.add_scalar('loss/val', val_loss, epoch + 1)
                for idx, name in enumerate(objects):
                    writer.add_scalar('iou/val' + name, miou[idx], epoch + 1)
            torch.cuda.empty_cache()
    writer.close()
    print("Training finished")
# 예제 #2 (Example 2)
def train_val(config):
    """Train a BASNet-style segmentation model on an 8-class dataset.

    Builds train/val dataloaders, instantiates the network selected by
    ``config.model_type``, then trains for ``config.num_epochs`` epochs with
    ``BasLoss``.  After each epoch the model is validated, per-class IoU and
    frequency-weighted mIoU (FWmIoU) are logged to TensorBoard, and a
    checkpoint is saved whenever FWmIoU improves.

    NOTE(review): depends on module-level names defined elsewhere in this
    project (get_dataloader, UNet, BASNet, DANet, deeplabv3_plus,
    seg_hrnet_ocr, scSEUNet, BasLoss, adamw, semantic_to_mask,
    get_confusion_matrix, get_miou, SummaryWriter, tqdm).  The validation
    unpack ``pred, _, _, _, _, _, _, _`` implies the model returns eight
    outputs (BASNet-style multi-head) — confirm for non-BASNet model types.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  smooth=config.smooth)
    # Validation batch size is hard-coded to 4, independent of config.
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=4,
                                num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    # Model selection; unknown types fall back to plain UNet.
    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "SEDANet":
        model = SEDANet()
    elif config.model_type == "RefineNet":
        model = rf101()
    elif config.model_type == "BASNet":
        model = BASNet(n_classes=8)
    elif config.model_type == "DANet":
        model = DANet(backbone='resnet101',
                      nclass=config.output_ch,
                      pretrained=True,
                      norm_layer=nn.BatchNorm2d)
    elif config.model_type == "Deeplabv3+":
        model = deeplabv3_plus.DeepLabv3_plus(in_channels=3,
                                              num_classes=8,
                                              backend='resnet101',
                                              os=16,
                                              pretrained=True,
                                              norm_layer=nn.BatchNorm2d)
    elif config.model_type == "HRNet_OCR":
        model = seg_hrnet_ocr.get_seg_model()
    elif config.model_type == "scSEUNet":
        model = scSEUNet(pretrained=True, norm_layer=nn.BatchNorm2d)
    else:
        model = UNet()

    # Optionally resume from a hard-coded checkpoint (saved wrapped in
    # DataParallel, hence the ``.module`` unwrap).
    if config.iscontinue:
        model = torch.load("./exp/24_Deeplabv3+_0.7825757691389714.pth").module

    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatability

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)

    # Dataset label ids and human-readable class names (8 classes).
    labels = [100, 200, 300, 400, 500, 600, 700, 800]
    objects = ['水体', '交通建筑', '建筑', '耕地', '草地', '林地', '裸土', '其他']

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-4,
                        momentum=0.9)
    elif config.optimizer == "adamw":
        optimizer = adamw.AdamW(model.parameters(), lr=config.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # weight = torch.tensor([1, 1.5, 1, 2, 1.5, 2, 2, 1.2]).to(device)
    # criterion = nn.CrossEntropyLoss(weight=weight)

    criterion = BasLoss()

    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[25, 30, 35, 40], gamma=0.5)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=5, verbose=True)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                         T_0=15,
                                                         eta_min=1e-4)

    global_step = 0
    max_fwiou = 0
    # Per-class pixel frequencies used to frequency-weight the mIoU.
    frequency = np.array(
        [0.1051, 0.0607, 0.1842, 0.1715, 0.0869, 0.1572, 0.0512, 0.1832])
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        # 8x8 confusion matrix, accumulated over the validation set.
        cm = np.zeros([8, 8])
        print(optimizer.param_groups[0]['lr'])
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img',
                  ncols=100) as train_pbar:
            model.train()

            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float32)
                # NOTE(review): mask is cast to float16 on the way to the
                # GPU — presumably what BasLoss expects; confirm.
                mask = mask.to(device, dtype=torch.float16)

                pred = model(image)
                loss = criterion(pred, mask)
                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1
                # if global_step > 10:
                #     break

            # scheduler.step()
            print("\ntraining epoch loss: " +
                  str(epoch_loss / (float(config.num_train) /
                                    (float(config.batch_size)))))
            torch.cuda.empty_cache()

        val_loss = 0
        with torch.no_grad():
            with tqdm(total=config.num_val,
                      desc="Epoch %d / %d validation round" %
                      (epoch + 1, config.num_epochs),
                      unit='img',
                      ncols=100) as val_pbar:
                model.eval()
                locker = 0
                for image, mask in val_loader:
                    image = image.to(device, dtype=torch.float32)
                    # Class-index target for the cross-entropy val loss.
                    target = mask.to(device, dtype=torch.long).argmax(dim=1)
                    mask = mask.cpu().numpy()
                    # Only the first of the model's eight outputs is scored.
                    pred, _, _, _, _, _, _, _ = model(image)
                    val_loss += F.cross_entropy(pred, target).item()
                    pred = pred.cpu().detach().numpy()
                    mask = semantic_to_mask(mask, labels)
                    pred = semantic_to_mask(pred, labels)
                    cm += get_confusion_matrix(mask, pred, labels)
                    val_pbar.update(image.shape[0])
                    # Log a fixed sample pair (26th batch) as images once per
                    # epoch for visual inspection.
                    if locker == 25:
                        writer.add_images('mask_a/true',
                                          mask[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_a/pred',
                                          pred[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/true',
                                          mask[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/pred',
                                          pred[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                    locker += 1

                    # break
                miou = get_miou(cm)
                fw_miou = (miou * frequency).sum()
                scheduler.step()

                # Checkpoint only when FWmIoU improves.
                if fw_miou > max_fwiou:
                    if torch.__version__ == "1.6.0":
                        # Legacy (non-zip) serialization for older torch.
                        torch.save(model,
                                   config.result_path + "/%d_%s_%.4f.pth" %
                                   (epoch + 1, config.model_type, fw_miou),
                                   _use_new_zipfile_serialization=False)
                    else:
                        torch.save(
                            model, config.result_path + "/%d_%s_%.4f.pth" %
                            (epoch + 1, config.model_type, fw_miou))
                    max_fwiou = fw_miou
                print("\n")
                print(miou)
                print("testing epoch loss: " + str(val_loss),
                      "FWmIoU = %.4f" % fw_miou)
                writer.add_scalar('mIoU/val', miou.mean(), epoch + 1)
                writer.add_scalar('FWIoU/val', fw_miou, epoch + 1)
                writer.add_scalar('loss/val', val_loss, epoch + 1)
                for idx, name in enumerate(objects):
                    writer.add_scalar('iou/val' + name, miou[idx], epoch + 1)
                torch.cuda.empty_cache()
    writer.close()
    print("Training finished")
def main():
    """Entry point: build RITE dataloaders, pick the model named by ``args``,
    then train/validate for ``args.epochs`` epochs, tracking best dice/acc.

    NOTE(review): relies on module-level globals defined elsewhere in this
    file (args, PATH, Gray_Flag, batch_size, inch, class_num, start_epoch,
    state, best_dice, best_acc, LossSelector, train, validate, camvid,
    joint_transforms and the various model modules).
    """
    # ------------------------------------------------------------------ data
    CAMVID_PATH = Path('./Data/dataprocess/RITE-RGB-40-True/')
    transform_train = transforms.Compose([transforms.ToTensor()])
    train_joint_transformer = transforms.Compose([joint_transforms.JointRandomHorizontalFlip()])
    # Train split: samples 0-19 with random horizontal flips.
    train_dset = camvid.CamVid(CAMVID_PATH, Gray = Gray_Flag, index_start = 0, index_end = 20, joint_transform=train_joint_transformer, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, shuffle=True)
    # Val split: samples 20-39, no joint augmentation.
    val_dset = camvid.CamVid(CAMVID_PATH, Gray = Gray_Flag, index_start = 20, index_end = 40, joint_transform=None, transform=transform_train)
    val_loader = torch.utils.data.DataLoader(val_dset, batch_size=batch_size, shuffle=True)
    # Output directories for per-epoch visualisations (train / val).
    dirs = ['./result/' + PATH + '/visulizeT0/', './result/' + PATH + '/visulizeT1/']
    for d in dirs:
        os.makedirs(d, exist_ok=True)
    ########################
    global best_dice
    global best_acc
    print(args.model)
    # ----------------------------------------------------------- model choice
    if args.model == 'unet':
        model = Unet3.UNet(in_channels = inch, num_classes = class_num, filter_scale=1)
    elif args.model == 'deeplabv3':
        model = deeplabv3.DeepLabV3(class_num = class_num)
    elif args.model == 'deeplabv3_plus':
        model = deeplabv3_plus.DeepLabv3_plus(in_channels = inch, num_classes = class_num, backend = 'mobilenet_v2', os = 16, pretrained = 'imagenet')
    elif args.model == 'pspnet':
        model = pspnet.PSPNet(in_channels = inch, num_classes = class_num, backend = 'resnet101', pool_scales = (1, 2, 3, 6), pretrained = 'imagenet')
    elif args.model == 'unet_ae':
        model = unet_ae.UnetResnetAE(in_channels = inch, num_classes = class_num, backend = 'resnet101', pretrained = 'imagenet')
    elif args.model == 'U_Net':
        model = Unet2.U_Net(img_ch= inch, output_ch = class_num)
    elif args.model == 'R2U_Net':
        model = R2Unet.R2U_Net(img_ch = inch, output_ch = class_num, t = args.t)
    elif args.model == 'Attunet':
        model = AttUnet.AttU_Net(img_ch = inch, output_ch = class_num)
    elif args.model == 'R2Attunet':
        model = R2AttUnet.R2AttU_Net(img_ch = inch, output_ch = class_num, t = args.t)
    elif args.model == 'segnet':
        model = segnet.SegNet(num_classes = class_num, in_channels= inch)
    elif args.model == 'cenet':
        model = cenet.CE_Net(num_classes = class_num, num_channels= inch)
    elif args.model == 'unet_nested':
        model = UNet_Nested.UNet_Nested(in_channels = inch, n_classes= class_num)
    elif args.model == 'denseaspp':
        model = denseaspp.DenseASPP(class_num = class_num)
    elif args.model == 'refinenet':
        model = RefineNet.get_refinenet(input_size = 256, num_classes = class_num)
    # elif args.model == 'rdfnet':#out = net(left, right)
        # model = rdfnet.RDF(input_size = 256, num_classes = class_num)
    else:
        # Fail fast instead of hitting a NameError on the unbound `model`.
        raise ValueError('unknown model: %s' % args.model)
    device = torch.device("cuda")
    # NOTE(review): `.cuda()` is redundant after `.to(device)` but harmless.
    model = torch.nn.DataParallel(model).to(device).cuda()
    print('model', model)
    print('loss', args.model, args.loss)
    criterion = LossSelector[args.loss](**args.loss_params[args.loss])
    cudnn.benchmark = True
    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))
    optimizer = optim.Adam(model.parameters(), lr = args.lr)
    model_fold = "./PKL/" + PATH + '/'
    os.makedirs(model_fold, exist_ok = True)
    for epoch in range(start_epoch, args.epochs):
        print('--------------------------------------')
        # NOTE(review): `state['lr']` is a module-level dict — confirm it is
        # kept in sync with the optimizer's actual learning rate.
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))
        # (removed: `print('val', va)` — `va` was undefined and raised NameError)
        train_loss, train_dice_x1, train_dice_x2, train_dice_x3, train_acc = train(train_loader, model, optimizer, epoch, dirs[0], criterion)
        val_loss, val_dice1, val_dice2, val_dice3, val_acc = validate(val_loader, model, epoch, dirs[1])
        # Checkpoint bookkeeping: track best summed dice and best accuracy.
        monitor_dice = val_dice1 + val_dice2 + val_dice3
        monitor_acc = val_acc
        is_best_dice = monitor_dice > best_dice
        is_best_acc = monitor_acc > best_acc
        best_dice = max(monitor_dice, best_dice)
        best_acc = max(monitor_acc, best_acc)
        if is_best_dice or is_best_acc:
            name = model_fold + str(epoch) + '_D1-' + str(round(val_dice1, 4)) + '_D2-' + str(round(val_dice2, 4)) + '_D3-' + str(round(val_dice3, 4)) + '_A-' + str(round(val_acc, 4)) + '.pkl'
            print(name)
            # torch.save(model, name)
    print('Best dice:', best_dice)
    print('Best acc:', best_acc)