# NOTE: besides the imports below, these snippets rely on names defined
# elsewhere in the original repository (args, path_data, csv_path,
# snapshot_path, train_batch_size, val_batch_size, workers, test_loader,
# SoftDiceLoss_binary, WarmRestart, warm_restart, epochVal,
# generate_dataset_loader_cls_seg, Unet_se50_model, Unet_se101_model).
import csv
import os
import time

import albumentations
import cv2
import numpy as np
import pandas as pd
import torch

import apex
from apex import amp


def load(filename, model):
    print('load {}'.format(filename))
    if args.model_num == 'deeplab_se50':
        from semantic_segmentation.network.deepv3 import DeepSRNX50V3PlusD_m1  # r
        model = DeepSRNX50V3PlusD_m1(1, None)
    elif args.model_num == 'unet_ef3':
        from ef_unet import EfficientNet_3_unet
        model = EfficientNet_3_unet()
    elif args.model_num == 'unet_ef5':
        from ef_unet import EfficientNet_5_unet
        model = EfficientNet_5_unet()
    elif args.model_num == 'unet_se50':
        model = Unet_se50_model()
    elif args.model_num == 'unet_se101':
        model = Unet_se101_model()
    else:
        # Fail loudly instead of letting DataParallel(None) raise a confusing
        # error further down.
        raise ValueError('unknown model: {}'.format(args.model_num))

    model = torch.nn.DataParallel(model)
    pretrained_model_path = filename
    state = torch.load(pretrained_model_path)
    data = state['state_dict']
    model.load_state_dict(data)

    return model
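
# Usage sketch (illustrative; the checkpoint path below is hypothetical, not
# from the source):
#   model = load('./snapshot/model_epoch_43_0.pth.tar', None)
#   model.eval()
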
def get_model(model_name="deep_se101",
              in_channel=6,
              num_classes=1,
              criterion=SoftDiceLoss_binary()):
    # Pass the caller-supplied criterion through; the original code silently
    # ignored the argument and built a fresh SoftDiceLoss_binary() per branch.
    if model_name == 'deep_se50':
        from semantic_segmentation.network.deepv3 import DeepSRNX50V3PlusD_m1  # r
        model = DeepSRNX50V3PlusD_m1(in_channel=in_channel,
                                     num_classes=num_classes,
                                     criterion=criterion)
    elif model_name == 'deep_se101':
        from semantic_segmentation.network.deepv3 import DeepSRNX101V3PlusD_m1  # r
        model = DeepSRNX101V3PlusD_m1(in_channel=in_channel,
                                      num_classes=num_classes,
                                      criterion=criterion)
    elif model_name == 'WideResnet38':
        from semantic_segmentation.network.deepv3 import DeepWR38V3PlusD_m1  # r
        model = DeepWR38V3PlusD_m1(in_channel=in_channel,
                                   num_classes=num_classes,
                                   criterion=criterion)
    elif model_name == 'unet_ef3':
        from ef_unet import EfficientNet_3_unet
        model = EfficientNet_3_unet()
    elif model_name == 'unet_ef5':
        from ef_unet import EfficientNet_5_unet
        model = EfficientNet_5_unet()
    else:
        print('Unknown model name: {}'.format(model_name))
        model = None
    return model
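
# Usage sketch (illustrative, not from the source): build a model and wrap it
# for multi-GPU use the same way train_one_model does below.
#   model = get_model(model_name='deep_se50', in_channel=6, num_classes=1)
#   model = torch.nn.DataParallel(model).cuda()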
def train_one_model(model_name, img_size, use_chexpert, path_data):
    RESIZE_SIZE = img_size

    train_transform = albumentations.Compose([
        albumentations.Resize(RESIZE_SIZE, RESIZE_SIZE),
        albumentations.OneOf([
            albumentations.RandomGamma(gamma_limit=(60, 120), p=0.9),
            albumentations.RandomBrightnessContrast(brightness_limit=0.2,
                                                    contrast_limit=0.2,
                                                    p=0.9),
            albumentations.CLAHE(clip_limit=4.0, tile_grid_size=(4, 4), p=0.9),
        ]),
        albumentations.OneOf([
            albumentations.Blur(blur_limit=4, p=1),
            albumentations.MotionBlur(blur_limit=4, p=1),
            albumentations.MedianBlur(blur_limit=4, p=1)
        ],
                             p=0.5),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.ShiftScaleRotate(shift_limit=0.2,
                                        scale_limit=0.2,
                                        rotate_limit=20,
                                        interpolation=cv2.INTER_LINEAR,
                                        border_mode=cv2.BORDER_CONSTANT,
                                        p=1),
        albumentations.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225),
                                 max_pixel_value=255.0,
                                 p=1.0)
    ])
    val_transform = albumentations.Compose([
        albumentations.Resize(RESIZE_SIZE, RESIZE_SIZE, p=1),
        albumentations.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225),
                                 max_pixel_value=255.0,
                                 p=1.0)
    ])
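
    # Note (illustrative, not from the source): an albumentations Compose is
    # called with keyword arguments and returns a dict, e.g.
    #   aug = train_transform(image=img)   # img: HxWx3 uint8 array
    #   img_aug = aug['image']             # augmented, normalized image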
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
    header = [
        'Epoch', 'Learning rate', 'Time', 'Train Loss', 'Val Loss',
        'best_thr_with_no_mask', 'best_dice_with_no_mask',
        'best_thr_without_no_mask', 'best_dice_without_no_mask'
    ]
    if not os.path.isfile(snapshot_path + '/log.csv'):
        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(header)
    kfold_path = path_data['k_fold_path']
    extra_data = path_data['extra_img_csv']

    for f_fold in range(5):
        num_fold = f_fold
        print(num_fold)

        with open(snapshot_path + '/log.csv', 'a', newline='') as f:

            writer = csv.writer(f)
            writer.writerow([num_fold])

        if use_chexpert:  # additionally train on the extra CheXpert images
            df1 = pd.read_csv(csv_path)
            df2 = pd.read_csv(extra_data +
                              'chexpert_mask_{}.csv'.format(num_fold + 1))
            # DataFrame.append was removed in pandas 2.0; pd.concat is the
            # equivalent call.
            df_all = pd.concat([df1, df2], ignore_index=True)

            f_train = open(kfold_path + 'fold' + str(num_fold) + '/train.txt',
                           'r')
            f_val = open(kfold_path + 'fold' + str(num_fold) + '/val.txt', 'r')

            f_fake = open(
                extra_data + '/chexpert_list_{}.txt'.format(num_fold + 1), 'r')

            c_train = f_train.readlines()
            c_val = f_val.readlines()
            c_fake = f_fake.readlines()
            c_train = c_fake + c_train

            f_train.close()
            f_val.close()
            f_fake.close()

        else:
            df_all = pd.read_csv(csv_path)
            f_train = open(kfold_path + 'fold' + str(num_fold) + '/train.txt',
                           'r')
            f_val = open(kfold_path + 'fold' + str(num_fold) + '/val.txt', 'r')
            c_train = f_train.readlines()
            c_val = f_val.readlines()
            f_train.close()
            f_val.close()

        c_train = [s.replace('\n', '') for s in c_train]
        c_val = [s.replace('\n', '') for s in c_val]

        # NB: 476 is a hard-coded size for the mask-only validation subset,
        # carried over from the original experiments.
        print('train dataset:', len(c_train),
              '  val dataset c_val_without_no_mask:', 476,
              '  val dataset c_val_with_no_mask:', len(c_val))
        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([
                'train dataset:',
                len(c_train), '  val dataset c_val_without_no_mask:', 476,
                '  val dataset c_val_with_no_mask:',
                len(c_val)
            ])
            writer.writerow([
                'train_batch_size:', train_batch_size, 'val_batch_size:',
                val_batch_size
            ])

        train_loader, val_loader = generate_dataset_loader_cls_seg(
            df_all, c_train, train_transform, train_batch_size, c_val,
            val_transform, val_batch_size, workers)
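
        # generate_dataset_loader_cls_seg above is assumed to be defined
        # elsewhere in the repository; it builds the train/val DataLoaders from
        # the dataframe, the fold lists, and the transforms defined earlier.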

        if model_name == 'deep_se50':
            from semantic_segmentation.network.deepv3 import DeepSRNX50V3PlusD_m1  # r
            model = DeepSRNX50V3PlusD_m1(1, SoftDiceLoss_binary())
        elif model_name == 'unet_ef3':
            from ef_unet import EfficientNet_3_unet
            model = EfficientNet_3_unet()
        elif model_name == 'unet_ef5':
            from ef_unet import EfficientNet_5_unet
            model = EfficientNet_5_unet()
        else:
            print('No model name in it')
            model = None

        model = apex.parallel.convert_syncbn_model(model).cuda()

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=5e-4,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=0)
        scheduler = WarmRestart(optimizer, T_max=5, T_mult=1, eta_min=1e-6)

        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        model = torch.nn.DataParallel(model)
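
        # apex amp opt_level "O1" runs selected ops in float16 while keeping
        # float32 master weights; losses therefore have to be scaled through
        # amp.scale_loss before backward (see the training loop below).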

        loss_seg = SoftDiceLoss_binary()

        trMaxEpoch = 44
        lossMIN = 100000
        val_dice_max = 0

        for epochID in range(0, trMaxEpoch):

            start_time = time.time()
            model.train()
            trainLoss = 0  # running totals; recomputed at the end of the epoch
            lossTrainNorm = 0
            trainLoss_cls = 0
            trainLoss_seg = 0

            if epochID < 40:
                if epochID != 0:
                    scheduler.step()
                    scheduler = warm_restart(scheduler, T_mult=2)
            elif epochID > 39 and epochID < 42:
                optimizer.param_groups[0]['lr'] = 1e-5
            else:
                optimizer.param_groups[0]['lr'] = 5e-6
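
            # The block above is the LR schedule: warm restarts (T_max=5) for
            # the first 40 epochs, then a fixed 1e-5 for epochs 40-41 and 5e-6
            # for the remaining epochs.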

            for batchID, (input, target_seg,
                          target_cls) in enumerate(train_loader):

                if batchID == 0:
                    ss_time = time.time()
                print(str(batchID) + '/' +
                      str(int(len(c_train) / train_batch_size)) + '     ' +
                      str((time.time() - ss_time) / (batchID + 1)),
                      end='\r')
                varInput = torch.autograd.Variable(input)
                # `async` became a reserved word in Python 3.7; the kwarg is
                # now called non_blocking.
                varTarget_seg = torch.autograd.Variable(
                    target_seg.contiguous().cuda(non_blocking=True))

                varOutput_seg = model(varInput)
                varTarget_seg = varTarget_seg.float()

                lossvalue_seg = loss_seg(varOutput_seg, varTarget_seg)
                trainLoss_seg = trainLoss_seg + lossvalue_seg.item()

                lossvalue = lossvalue_seg
                lossTrainNorm = lossTrainNorm + 1
                optimizer.zero_grad()
                with amp.scale_loss(lossvalue, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()

            trainLoss_seg = trainLoss_seg / lossTrainNorm
            trainLoss = trainLoss_seg

            best_thr_with_no_mask = -1
            best_dice_with_no_mask = -1
            best_thr_without_no_mask = -1
            best_dice_without_no_mask = -1
            valLoss_seg = -1

            if epochID % 1 == 0:  # validate every epoch
                (valLoss_seg, best_thr_with_no_mask, best_dice_with_no_mask,
                 best_thr_without_no_mask,
                 best_dice_without_no_mask) = epochVal(model, val_loader,
                                                       loss_seg, c_val,
                                                       val_batch_size)

            epoch_time = time.time() - start_time

            if epochID % 1 == 0:
                torch.save(
                    {
                        'epoch': epochID + 1,
                        'state_dict': model.state_dict(),
                        'valLoss': 0,
                        'best_thr_with_no_mask': best_thr_with_no_mask,
                        'best_dice_with_no_mask': float(best_dice_with_no_mask),
                        'best_thr_without_no_mask': best_thr_without_no_mask,
                        'best_dice_without_no_mask':
                        float(best_dice_without_no_mask),
                    }, snapshot_path + '/model_epoch_' + str(epochID) + '_' +
                    str(num_fold) + '.pth.tar')

            result = [
                epochID,
                round(optimizer.state_dict()['param_groups'][0]['lr'], 6),
                round(epoch_time, 0),
                round(trainLoss, 4),
                round(valLoss_seg, 4),
                round(best_thr_with_no_mask, 3),
                round(float(best_dice_with_no_mask), 3),
                round(best_thr_without_no_mask, 3),
                round(float(best_dice_without_no_mask), 3)
            ]
            print(result)
            with open(snapshot_path + '/log.csv', 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(result)
        del model
# ----------------------------------------------------------------
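# Average the five per-fold prediction tensors (a simple ensemble) and save
# the result. PRED_5_FOLD is assumed to have been filled by a fold loop like
# the one that follows.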
pred0 = (PRED_5_FOLD[0] + PRED_5_FOLD[1] + PRED_5_FOLD[2] + PRED_5_FOLD[3] +
         PRED_5_FOLD[4]) / 5
pred0 = pred0.numpy()
np.save('./se101_pred.npy', pred0)

del model


PRED_5_FOLD = []

for num_fold in range(5):
    print('This is {} fold processing...'.format(num_fold))

    from semantic_segmentation.network.deepv3 import DeepSRNX50V3PlusD_m1
    model = DeepSRNX50V3PlusD_m1(1, SoftDiceLoss_binary())

    model = apex.parallel.convert_syncbn_model(model).cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    model = torch.nn.DataParallel(model).cuda()
    state = torch.load('./deeplab_swa_' + str(num_fold) + '.pth.tar')

    model.load_state_dict(state['state_dict'])

    model.eval()
    outPRED = []

    for i, input in enumerate(test_loader):
        ...  # the inference loop body is missing from this extract

# ----------------------------------------------------------------
    print(args.input)
    print(args.output)
    print(args.model_num)
    print(args.epoch0)
    print(args.epoch1)
    print(args.epoch2)
    print('bs is', args.batch_size)

    df_all = pd.read_csv(csv_path)
    kfold_path = path_data['k_fold_path']
    args.input = path_data['snapshot_path'] + args.input

    if args.model_num == 'deeplab_se50':
        from semantic_segmentation.network.deepv3 import DeepSRNX50V3PlusD_m1  # r
        model = DeepSRNX50V3PlusD_m1(1, None)
    elif args.model_num == 'unet_ef3':
        from ef_unet import EfficientNet_3_unet
        model = EfficientNet_3_unet()
    elif args.model_num == 'unet_ef5':
        from ef_unet import EfficientNet_5_unet
        model = EfficientNet_5_unet()
    elif args.model_num == 'unet_se50':
        model = Unet_se50_model()
    elif args.model_num == 'unet_se101':
        model = Unet_se101_model()
    else:
        model = None
        print('Unknown model: {}'.format(args.model_num))

    for f_fold in range(5):