def main():
    """Run single-image inference with a saved SegNet snapshot and dump
    visualisations.

    Relies on module-level globals: ``args`` (``snapshot_dir``,
    ``test_debug_vis_dir`` — both assumed to end with a path separator),
    ``SegNet``, ``SegDataset`` and ``vis_pred_result``.
    """
    # Make sure both output directories exist before writing anything.
    for out_dir in (args.snapshot_dir, args.test_debug_vis_dir):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    model = SegNet(model='resnet50')
    model.load_state_dict(torch.load(args.snapshot_dir + '150000.pth'))

    # eval() freezes BatchNorm running statistics for inference.
    model.eval()
    model.cuda()

    loader = DataLoader(SegDataset(mode='test'),
                        batch_size=1,
                        shuffle=False,
                        num_workers=4)

    for step, batch in enumerate(loader):

        (Input_image, vis_image, gt_mask, weight_matrix,
         dataset_length, image_name) = batch

        pred_mask = model(Input_image.cuda())

        print('i_iter/total {}/{}'.format(step, int(dataset_length[0].data)))

        # Mirror the dataset's sub-directory layout under the vis dir.
        sub_dir = args.test_debug_vis_dir + image_name[0].split('/')[0]
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)

        vis_pred_result(vis_image, gt_mask, pred_mask,
                        args.test_debug_vis_dir + image_name[0] + '.png')
示例#2
0
def getModel(device, params):
    """Build the segmentation network selected by ``params.model`` and move
    it to *device*.

    Side effect: for ``'unet++'`` this also sets
    ``params.deepsupervision = True`` before constructing the model.

    Raises:
        ValueError: if ``params.model`` does not name a known architecture.
            (Previously the chain of independent ``if``s fell through and
            crashed with ``UnboundLocalError`` at the ``return``.)
        AssertionError: for ``'fcn8s'`` on the ``esophagus`` dataset (its
            80x80 resolution cannot survive five 2x downsamplings).
    """
    if params.model == 'UNet':
        model = UNet(3, 1).to(device)
    elif params.model == 'resnet34_unet':
        model = resnet34_unet(1, pretrained=False).to(device)
    elif params.model == 'unet++':
        params.deepsupervision = True
        model = NestedUNet(params, 3, 1).to(device)
    elif params.model == 'Attention_UNet':
        model = Attention_Gate_UNet(3, 1).to(device)
    elif params.model == 'segnet':
        model = SegNet(3, 1).to(device)
    elif params.model == 'r2unet':
        model = R2U_Net(3, 1).to(device)
    elif params.model == 'fcn32s':
        model = get_fcn32s(1).to(device)
    elif params.model == 'myChannelUnet':
        model = ChannelUnet(3, 1).to(device)
    elif params.model == 'fcn8s':
        assert params.dataset != 'esophagus', "fcn8s模型不能用于数据集esophagus,因为esophagus数据集为80x80,经过5次的2倍降采样后剩下2.5x2.5,分辨率不能为小数,建议把数据集resize成更高的分辨率再用于fcn"
        model = get_fcn8s(1).to(device)
    elif params.model == 'cenet':
        model = CE_Net_().to(device)
    elif params.model == 'smaatunet':
        model = SmaAt_UNet(3, 1).to(device)
    # elif params.model == "self_attention_unet":
    #     model = get_unet_depthwise_light_encoder_attention_with_skip_connections_decoder(3,1).to(device)
    elif params.model == "kiunet":
        model = kiunet().to(device)
    elif params.model == "Lite_RASPP":
        model = MobileNetV3Seg(nclass=1).to(device=device)
    elif params.model == "design_one":
        model = AttentionDesignOne(3, 1).to(device)
    elif params.model == "design_two":
        model = AttentionDesignTwo(3, 1).to(device)
    elif params.model == "design_three":
        model = AttentionDesignThree(3, 1).to(device)
    elif params.model == "only_attention":
        model = Design_Attention(3, 1).to(device)
    elif params.model == "only_bottleneck":
        model = Design_MRC_RMP(3, 1).to(device)
    else:
        raise ValueError("unknown model name: {!r}".format(params.model))
    return model
                     lr_init=lr_init,
                     lr_decay=lr_decay)
elif model_name == "deeplabv3p":
    model = Deeplabv3(input_shape=(256, 256, 3), classes=labels)
elif model_name == "deeplabv3":
    model = deeplabv3_plus(input_shape=(256, 256, 7), num_classes=labels)
elif model_name == "maskrcnn":
    modelt = modellib.MaskRCNN(mode='training',
                               config=config,
                               model_dir=MODEL_DIR)
elif model_name == 'refinenet':
    model = refinenet(input_shape=(256, 256, 5), num_classes=lebels)
    #model =build_network_resnet101(inputHeight=256,inputWidth=256,n_classes=len(labels))
    #model = build_network_resnet101_stack(inputHeight=256,inputWidth=256,n_classes=len(labels),nStack=2)
elif model_name == "segnet":
    model = SegNet(input_shape=(256, 256, 5), classes=labels)
elif model_name == "fcn32":
    model = get_model(input_shape=(256, 256, 7), num_classes=labels)
elif model_name == 'icnet':
    model = build_bn(input_shape=(256, 256, 7), n_classes=labels)
elif model_name == "lstm":
    model = model(shape=(256, 256, 7), num_classes=labels)


def flip_axis(x, axis):
    """Return *x* with the order of its elements reversed along *axis*.

    *x* may be any array-like; it is converted with ``np.asarray`` first.
    Negative *axis* values are supported, as with normal numpy indexing.
    """
    arr = np.asarray(x)
    # Build an index that is a full slice everywhere except a reversing
    # slice on the requested axis — equivalent to the classic
    # swapaxes / reverse / swapaxes dance, without the axis shuffling.
    index = [slice(None)] * arr.ndim
    index[axis] = slice(None, None, -1)
    return arr[tuple(index)]

示例#4
0
# Create the output directory for logs/checkpoints if needed.
if not os.path.isdir(opt.out):
    mkdir_p(opt.out)
title = 'NYUv2'
# Text logger: one row per epoch of train (T.) and validation (V.) metrics
# (segmentation loss/mIoU/pixel-acc, depth errors, normals stats, weights).
logger = Logger(os.path.join(opt.out, 'segnet_kdmtl_' + 'log.txt'),
                title=title)
logger.set_names([
    'Epoch', 'T.Ls', 'T. mIoU', 'T. Pix', 'T.Ld', 'T.abs', 'T.rel', 'T.Ln',
    'T.Mean', 'T.Med', 'T.11', 'T.22', 'T.30', 'V.Ls', 'V. mIoU', 'V. Pix',
    'V.Ld', 'V.abs', 'V.rel', 'V.Ln', 'V.Mean', 'V.Med', 'V.11', 'V.22',
    'V.30', 'ds', 'dd', 'dh'
])

# define model, optimiser and scheduler
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu
use_cuda = torch.cuda.is_available()
# Multi-task student model being trained; 13 semantic classes (NYUv2).
model = SegNet(type_=opt.type, class_nb=13).cuda()
single_model = {}
transformers = {}

# Load one pre-trained single-task teacher model, plus a feature
# transformer, per task.  NOTE(review): `tasks` and `transformer` are not
# defined in this chunk — presumably module-level; confirm.
for i, t in enumerate(tasks):
    single_model[i] = SegNet(type_=opt.type, class_nb=13).cuda()
    checkpoint = torch.load(
        '{}segnet_single_model_task_{}_model_best.pth.tar'.format(
            opt.single_dir, tasks[i]))
    single_model[i].load_state_dict(checkpoint['state_dict'])
    transformers[i] = transformer().cuda()

optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Halve the learning rate every 100 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
params = []
for i in range(len(tasks)):
示例#5
0
def main():
    """Train SegNet (ResNet-50 backbone) and periodically save snapshots
    and debug visualisations.

    Relies on module-level globals: ``args`` (``snapshot_dir``,
    ``train_debug_vis_dir``), ``INI_LEARNING_RATE``, ``WEIGHT_DECAY``,
    ``EPOCHES``, and the helpers ``get_params``, ``loss_calc`` and
    ``vis_pred_result``.
    """
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    if not os.path.exists(args.train_debug_vis_dir):
        os.makedirs(args.train_debug_vis_dir)

    model = SegNet(model='resnet50')

    # NOTE(review): the original comment said "freeze bn statics", but
    # train() keeps BatchNorm running statistics updating — confirm intent.
    model.train()
    model.cuda()

    # Per-group learning rates: backbone at base LR (2x for biases),
    # newly added layers at 10x (20x for biases).
    optimizer = torch.optim.SGD(params=[
        {
            "params": get_params(model, key="backbone", bias=False),
            "lr": INI_LEARNING_RATE
        },
        {
            "params": get_params(model, key="backbone", bias=True),
            "lr": 2 * INI_LEARNING_RATE
        },
        {
            "params": get_params(model, key="added", bias=False),
            "lr": 10 * INI_LEARNING_RATE
        },
        {
            "params": get_params(model, key="added", bias=True),
            "lr": 20 * INI_LEARNING_RATE
        },
    ],
                                lr=INI_LEARNING_RATE,
                                weight_decay=WEIGHT_DECAY)

    dataloader = DataLoader(SegDataset(mode='train'),
                            batch_size=8,
                            shuffle=True,
                            num_workers=4)

    global_step = 0

    # NOTE(review): range(1, EPOCHES) runs EPOCHES - 1 epochs — confirm.
    for epoch in range(1, EPOCHES):

        for i_iter, batch_data in enumerate(dataloader):

            global_step += 1

            Input_image, vis_image, gt_mask, weight_matrix, dataset_length, image_name = batch_data

            optimizer.zero_grad()

            pred_mask = model(Input_image.cuda())

            loss = loss_calc(pred_mask, gt_mask, weight_matrix)

            loss.backward()

            optimizer.step()

            if global_step % 10 == 0:
                print('epoche {} i_iter/total {}/{} loss {:.4f}'.format(\
                       epoch, i_iter, int(dataset_length[0].data), loss))

            if global_step % 10000 == 0:
                vis_pred_result(
                    vis_image, gt_mask, pred_mask,
                    args.train_debug_vis_dir + str(global_step) + '.png')

            # Fix: was `global_step % 1e4 == 0` (float modulo); use the same
            # integer period as the visualisation check above.
            if global_step % 10000 == 0:
                torch.save(model.state_dict(),
                           args.snapshot_dir + str(global_step) + '.pth')
示例#6
0
def main(config):
    """Train (``config.stage == 'train'``) or test (``'test'``) a SegNet
    eye-segmentation model.

    NOTE(review): several names used below (``train_sampler``, ``workers``,
    ``start_lr``, ``finish_lr``, ``lr_mult``, ``warmup_part``, ``max_mom``,
    ``min_mom``, ``wd``, ``NUM_MODEL_PARAMS``, ``np_to_base64_utf8_str``)
    are not defined in this function — presumably module-level globals;
    confirm before reuse.
    """

    # Normalisation statistics: one value for grayscale, replicated for RGB.
    if config.channels == 1:
        mean = [0.467]
        std = [0.271]
    elif config.channels == 3:
        mean = [0.467, 0.467, 0.467]
        std = [0.271, 0.271, 0.271]

    # device == -1 selects CPU; otherwise the given CUDA device index.
    if config.device == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{:d}'.format(config.device))

    # Fixed seeds for reproducibility; cudnn benchmark for speed.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Training augmentations (albumentations); validation only normalises.
    train_tfms = Compose([
        ShiftScaleRotate(rotate_limit=15, interpolation=cv2.INTER_CUBIC),
        GaussianBlur(),
        GaussNoise(),
        HorizontalFlip(),
        RandomBrightnessContrast(),
        Normalize(
            mean=mean,
            std=std,
        ),
        ToTensor()
    ])

    val_tfms = Compose([Normalize(
        mean=mean,
        std=std,
    ), ToTensor()])

    SAVEPATH = Path(config.root_dir)
    #Depending on the stage we either create train/validation or test dataset
    if config.stage == 'train':
        train_ds = EdsDS(fldr=SAVEPATH / config.train_dir,
                         channels=config.channels,
                         transform=train_tfms)
        val_ds = EdsDS(fldr=SAVEPATH / config.valid_dir,
                       channels=config.channels,
                       transform=val_tfms)

        # NOTE(review): `train_sampler` and `workers` are not visible here.
        train_loader = DataLoader(train_ds,
                                  batch_size=config.bs,
                                  shuffle=(train_sampler is None),
                                  num_workers=workers,
                                  pin_memory=True,
                                  sampler=train_sampler)

        # Plain-text metrics logger under ./logger/.
        checkpoint = 'logger'
        if not os.path.exists(checkpoint):
            os.makedirs(checkpoint)
        arch = 'segnet_'
        title = 'Eye_' + arch + 'fast_fd_g{}_'.format(config.gamma)

        logger = Logger(os.path.join(
            checkpoint, '{}e{:d}_lr{:.4f}.txt'.format(title, config.ep,
                                                      config.lr)),
                        title=title)
        logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Dice'])
    elif config.stage == 'test':
        # No masks available at test time.
        val_ds = EdsDS(fldr=SAVEPATH / config.test_dir,
                       channels=config.channels,
                       mask=False,
                       transform=val_tfms)

    # Shared validation/test loader; double batch size (no gradients kept).
    val_loader = DataLoader(val_ds,
                            batch_size=config.bs * 2,
                            shuffle=False,
                            num_workers=workers,
                            pin_memory=True)

    model = SegNet(channels=config.channels).to(device)

    criterion = DiceFocalWithLogitsLoss(gamma=config.gamma).to(device)

    optimizer = AdamW(model.parameters(),
                      lr=start_lr,
                      betas=(max_mom, 0.999),
                      weight_decay=wd)

    if config.stage == 'train':
        steps = len(train_loader) * config.ep

        # Two-phase cosine schedule: warm-up to config.lr, then anneal down.
        schs = []
        schs.append(
            SchedulerCosine(optimizer, start_lr, config.lr, lr_mult,
                            int(steps * warmup_part), max_mom, min_mom))
        schs.append(
            SchedulerCosine(optimizer, config.lr, finish_lr, lr_mult,
                            steps - int(steps * warmup_part), min_mom,
                            max_mom))
        lr_scheduler = LR_Scheduler(schs)

        max_dice = 0

        for epoch in range(config.ep):

            print('\nEpoch: [{:d} | {:d}] LR: {:.10f}|{:.10f}'.format(
                epoch + 1, config.ep, get_lr(optimizer, -1),
                get_lr(optimizer, 0)))

            # train for one epoch
            train_loss = train(train_loader, model, criterion, optimizer,
                               lr_scheduler, device, config)

            # evaluate on validation set
            valid_loss, dice = validate(val_loader, model, criterion, device,
                                        config)

            # append logger file
            logger.append(
                [get_lr(optimizer, -1), train_loss, valid_loss, dice])

            # Checkpoint only on a new best validation Dice.
            if dice > max_dice:
                max_dice = dice
                model_state = {
                    'epoch': epoch + 1,
                    'arch': arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                save_model = '{}e{:d}_lr_{:.3f}_max_dice.pth.tar'.format(
                    title, config.ep, config.lr)
                torch.save(model_state, save_model)
    elif config.stage == 'test':
        checkpoint = torch.load(config.saved_model)
        model.load_state_dict(checkpoint['state_dict'])
        # NOTE(review): in test mode validate() appears to return raw logits
        # per batch — confirm against its definition.
        logits = validate(val_loader, model, criterion, device, config)
        # Argmax over the channel dimension -> per-pixel class ids.
        preds = np.concatenate([torch.argmax(l, 1).numpy()
                                for l in logits]).astype(np.uint8)
        leng = len(preds)
        # Package predictions as base64 strings in a JSON, zip it, then
        # remove the intermediate JSON file.
        data = {}
        data['num_model_params'] = NUM_MODEL_PARAMS
        data['number_of_samples'] = leng
        data['labels'] = {}
        for i in range(leng):
            data['labels'][val_ds.img_paths[i].stem] = np_to_base64_utf8_str(
                preds[i])
        with open(SAVEPATH / '{}.json'.format(config.filename), 'w') as f:
            json.dump(data, f)
        with zipfile.ZipFile(SAVEPATH / '{}.zip'.format(config.filename),
                             "w",
                             compression=zipfile.ZIP_DEFLATED) as zf:
            zf.write(SAVEPATH / '{}.json'.format(config.filename))
        os.remove(SAVEPATH / '{}.json'.format(config.filename))
示例#7
0
def main():
    """Train SegNet on the EndoVis 2018 dataset and log metrics per epoch.

    NOTE(review): `parser`, `use_gpu`, `writer`, `utils`, `endovisDataset`,
    `train`, `validate` and `save_checkpoint` come from elsewhere in the
    file — not visible in this chunk.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    # argparse delivers strings; normalise the flag to a real bool.
    if args.saveTest == 'True':
        args.saveTest = True
    elif args.saveTest == 'False':
        args.saveTest = False

    # Check if the save directory exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    cudnn.benchmark = True

    # Train transform: resize then TenCrop — each sample becomes a stack of
    # 10 crop tensors.  NEAREST interpolation keeps label values intact.
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.TenCrop(args.resizedImageSize),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            #transforms.Lambda(lambda normalized: torch.stack([transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])(crop) for crop in normalized]))
            #transforms.RandomResizedCrop(224, interpolation=Image.NEAREST),
            #transforms.RandomHorizontalFlip(),
            #transforms.RandomVerticalFlip(),
            #transforms.ToTensor(),
        ]),
        'test': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.ToTensor(),
            #transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])
        ]),
    }

    # Data Loading
    # NOTE(review): hard-coded machine-specific paths.
    data_dir = '/media/salman/DATA/General Datasets/MICCAI/EndoVis_2018'
    # json path for class definitions
    json_path = '/home/salman/pytorch/endovis18/datasets/endovisClasses.json'

    # Random train/validation split of the training data.
    trainval_image_dataset = endovisDataset(os.path.join(data_dir, 'train_data'),
                        data_transforms['train'], json_path=json_path, training=True)
    val_size = int(args.validationSplit * len(trainval_image_dataset))
    train_size = len(trainval_image_dataset) - val_size
    train_image_dataset, val_image_dataset = torch.utils.data.random_split(trainval_image_dataset, [train_size,
                                                                                                       val_size])

    test_image_dataset = endovisDataset(os.path.join(data_dir, 'test_data'),
                        data_transforms['test'], json_path=json_path, training=False)



    train_dataloader = torch.utils.data.DataLoader(train_image_dataset,
                                                  batch_size=args.batchSize,
                                                  shuffle=True,
                                                  num_workers=args.workers)
    val_dataloader = torch.utils.data.DataLoader(val_image_dataset,
                                                batch_size=args.batchSize,
                                                shuffle=True,
                                                num_workers=args.workers)
    test_dataloader = torch.utils.data.DataLoader(test_image_dataset,
                                                  batch_size=args.batchSize,
                                                  shuffle=True,
                                                  num_workers=args.workers)

    train_dataset_size = len(train_image_dataset)
    val_dataset_size = len(val_image_dataset)
    test_dataset_size = len(test_image_dataset)

    # Get the dictionary for the id and RGB value pairs for the dataset
    # print(train_image_dataset.classes)
    classes = trainval_image_dataset.classes
    key = utils.disentangleKey(classes)
    num_classes = len(key)

    # Initialize the model
    model = SegNet(batchNorm_momentum=args.bnMomentum , num_classes=num_classes)

    # # Optionally resume from a checkpoint
    # if args.resume:
    #     if os.path.isfile(args.resume):
    #         print("=> loading checkpoint '{}'".format(args.resume))
    #         checkpoint = torch.load(args.resume)
    #         #args.start_epoch = checkpoint['epoch']
    #         pretrained_dict = checkpoint['state_dict']
    #         pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model.state_dict()}
    #         model.state_dict().update(pretrained_dict)
    #         model.load_state_dict(model.state_dict())
    #         print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    #     else:
    #         print("=> no checkpoint found at '{}'".format(args.resume))
    #
    #     # # Freeze the encoder weights
    #     # for param in model.encoder.parameters():
    #     #     param.requires_grad = False
    #
    #     optimizer = optim.Adam(model.parameters(), lr = args.lr, weight_decay = args.wd)
    # else:
    optimizer = optim.Adam(model.parameters(), lr = args.lr, weight_decay = args.wd)

    # Load the saved model
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # Use a learning rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    if use_gpu:
        model.cuda()
        criterion.cuda()

    # Initialize an evaluation Object
    evaluator = utils.Evaluate(key, use_gpu)

    # Main loop: train, validate, compute metrics, checkpoint every epoch.
    for epoch in range(args.start_epoch, args.epochs):
        #adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        print('>>>>>>>>>>>>>>>>>>>>>>>Training<<<<<<<<<<<<<<<<<<<<<<<')
        train(train_dataloader, model, criterion, optimizer, scheduler, epoch, key)

        # Evaulate on validation set
        print('>>>>>>>>>>>>>>>>>>>>>>>Testing<<<<<<<<<<<<<<<<<<<<<<<')
        validate(val_dataloader, model, criterion, epoch, key, evaluator)

        # Calculate the metrics
        print('>>>>>>>>>>>>>>>>>> Evaluating the Metrics <<<<<<<<<<<<<<<<<')
        IoU = evaluator.getIoU()
        print('Mean IoU: {}, Class-wise IoU: {}'.format(torch.mean(IoU), IoU))
        writer.add_scalar('Epoch Mean IoU', torch.mean(IoU), epoch)
        PRF1 = evaluator.getPRF1()
        precision, recall, F1 = PRF1[0], PRF1[1], PRF1[2]
        print('Mean Precision: {}, Class-wise Precision: {}'.format(torch.mean(precision), precision))
        writer.add_scalar('Epoch Mean Precision', torch.mean(precision), epoch)
        print('Mean Recall: {}, Class-wise Recall: {}'.format(torch.mean(recall), recall))
        writer.add_scalar('Epoch Mean Recall', torch.mean(recall), epoch)
        print('Mean F1: {}, Class-wise F1: {}'.format(torch.mean(F1), F1))
        writer.add_scalar('Epoch Mean F1', torch.mean(F1), epoch)
        evaluator.reset()

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=os.path.join(args.save_dir, 'checkpoint_{}.tar'.format(epoch)))
def main():
    """Evaluate a saved SegNet checkpoint on the miccaiSeg organ test set
    and print IoU / precision / recall / F1.

    NOTE(review): `parser`, `use_gpu`, `utils`, `miccaiSegDataset` and
    `validate` come from elsewhere in the file — not visible in this chunk.
    """
    global args
    args = parser.parse_args()
    print(args)

    # argparse delivers strings; normalise the flag to a real bool.
    if args.saveTest == 'True':
        args.saveTest = True
    elif args.saveTest == 'False':
        args.saveTest = False

    # Check if the save directory exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    cudnn.benchmark = True

    # NEAREST interpolation keeps label values intact when resizing.
    data_transform = transforms.Compose([
        transforms.Resize((args.imageSize, args.imageSize),
                          interpolation=Image.NEAREST),
        transforms.ToTensor(),
    ])

    # Data Loading
    # NOTE(review): hard-coded machine-specific paths.
    data_dir = '/home/salman/pytorch/segmentationNetworks/datasets/miccaiSegOrgans'
    # json path for class definitions
    json_path = '/home/salman/pytorch/segmentationNetworks/datasets/miccaiSegOrganClasses.json'

    image_dataset = miccaiSegDataset(os.path.join(data_dir, 'test'),
                                     data_transform, json_path)

    dataloader = torch.utils.data.DataLoader(image_dataset,
                                             batch_size=args.batchSize,
                                             shuffle=True,
                                             num_workers=args.workers)

    # Get the dictionary for the id and RGB value pairs for the dataset
    classes = image_dataset.classes
    key = utils.disentangleKey(classes)
    num_classes = len(key)

    # Initialize the model
    model = SegNet(args.bnMomentum, num_classes)

    # Load the saved model
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.model))

    print(model)

    # Define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    if use_gpu:
        model.cuda()
        criterion.cuda()

    # Initialize an evaluation Object
    evaluator = utils.Evaluate(key, use_gpu)

    # Evaulate on validation/test set
    print('>>>>>>>>>>>>>>>>>>>>>>>Testing<<<<<<<<<<<<<<<<<<<<<<<')
    validate(dataloader, model, criterion, key, evaluator)

    # Calculate the metrics
    print('>>>>>>>>>>>>>>>>>> Evaluating the Metrics <<<<<<<<<<<<<<<<<')
    IoU = evaluator.getIoU()
    print('Mean IoU: {}, Class-wise IoU: {}'.format(torch.mean(IoU), IoU))
    PRF1 = evaluator.getPRF1()
    precision, recall, F1 = PRF1[0], PRF1[1], PRF1[2]
    print('Mean Precision: {}, Class-wise Precision: {}'.format(
        torch.mean(precision), precision))
    print('Mean Recall: {}, Class-wise Recall: {}'.format(
        torch.mean(recall), recall))
    print('Mean F1: {}, Class-wise F1: {}'.format(torch.mean(F1), F1))
elif model_name == "unet":
    model = unet(input_shape=(256, 256, 7),
                 num_classes=labels,
                 lr_init=1e-3,
                 lr_decay=5e-4)
elif model_name == "pspnet":
    model = pspnet50(input_shape=(256, 256, 5),
                     num_classes=labels,
                     lr_init=1e-3,
                     lr_decay=5e-4)
elif model_name == 'deeplabv3p':
    model = Deeplabv3(input_shape=(256, 256, 5), classes=labels)
elif model_name == "deeplabv3":
    model = deeplabv3_plus(input_shape=(256, 256, 5), num_classes=labels)
elif model_name == "segnet":
    model = SegNet(input_shape=(256, 256, 5), classes=labels)
elif model_name == "refinenet":
    model = refinenet(input_shape=(256, 256, 5), num_classes=labels)
# NOTE(review): weights path is hard-coded to the UNet file regardless of
# model_name; the commented line below suggests it was meant to be generic.
model.load_weights("h5File/unet_model_weight.h5")

#model.load_weights("h5File/"+model_name+'_model_weight.h5')
print("load model successfully")

# Load the test image and pad it into a larger zero buffer so the tiling
# loop below can cover it in fixed-size windows.
x_img = tifffile.imread("/data/test_h/15out.tif") / 255
ocr = np.zeros((x_img.shape[0] + 235, x_img.shape[1] + 277, 7), 'float16')
ocr[0:3093, 0:3051, :] = x_img
# NOTE(review): `3051` here indexes a single column (the buffer is zero-
# initialised anyway) — a slice `3051:` was probably intended; confirm.
ocr[3093:, 3051, :] = 0
tmp = np.zeros((x_img.shape[0] + 235, x_img.shape[1] + 277))
for i in range(int(ocr.shape[0] / 128) - 1):
    for j in range(int(ocr.shape[1] / 128) - 1):
        pred = model.predict(
示例#10
0
def main(args):
    """Run inference on a thyroid-ultrasound test set and save predicted
    masks.

    Builds the network named by ``args.model_name``, loads weights from
    ``args.load_path``, accumulates IoU over the test loader and writes one
    rounded binary mask image per sample under the computed ``save_dir``.

    Raises:
        NotImplementedError: for an unknown ``args.model_name``.
        ValueError: for an unsupported ``args.test_dataset`` (previously
            this crashed later with a NameError on ``test_data``).
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    if 'deeplab' in args.model_name:
        if 'resnet101' in args.model_name:
            net = Deeplabv3plus(nInputChannels=3, n_classes=args.num_classes, os=args.output_stride,
                                backbone_type='resnet101')
        elif 'resnet50' in args.model_name:
            net = Deeplabv3plus(nInputChannels=3, n_classes=args.num_classes, os=args.output_stride,
                                backbone_type='resnet50')
        elif 'resnet34' in args.model_name:
            net = Deeplabv3plus(nInputChannels=3, n_classes=args.num_classes, os=args.output_stride,
                                backbone_type='resnet34')
    elif 'unet' in args.model_name:
        net = Unet(in_ch=3, out_ch=1)
    elif 'trfe' in args.model_name:
        if args.model_name == 'trfe':
            net = TRFENet(in_ch=3, out_ch=1)
        elif args.model_name == 'trfe1':
            net = TRFENet1(in_ch=3, out_ch=1)
        elif args.model_name == 'trfe2':
            net = TRFENet2(in_ch=3, out_ch=1)
    elif 'mtnet' in args.model_name:
        net = MTNet(in_ch=3, out_ch=1)
    elif 'segnet' in args.model_name:
        net = SegNet(input_channels=3, output_channels=1)
    elif 'fcn' in args.model_name:
        net = FCN8s(1)
    else:
        raise NotImplementedError
    net.load_state_dict(torch.load(args.load_path))
    # Fix: net.cuda() used to be called twice (here and again before the
    # eval loop); once is enough.
    net.cuda()
    net.eval()

    composed_transforms_ts = transforms.Compose([
        trforms.FixedResize(size=(args.input_size, args.input_size)),
        trforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        trforms.ToTensor()])

    if args.test_dataset == 'TN3K':
        test_data = tn3k.TN3K(mode='test', transform=composed_transforms_ts, return_size=True)
    else:
        raise ValueError('unsupported test dataset: {!r}'.format(args.test_dataset))

    save_dir = args.save_dir + args.test_fold + '-' + args.test_dataset + os.sep + args.model_name + os.sep
    testloader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=0)
    num_iter_ts = len(testloader)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    start_time = time.time()
    with torch.no_grad():
        total_iou = 0
        for sample_batched in tqdm(testloader):
            inputs, labels, label_name, size = sample_batched['image'], sample_batched['label'], sample_batched[
                'label_name'], sample_batched['size']
            # Fix: dropped the deprecated torch.autograd.Variable wrappers;
            # they are no-ops on modern PyTorch.
            labels = labels.cuda()
            inputs = inputs.cuda()
            # Multi-task models return (segmentation, auxiliary) tuples.
            if 'trfe' in args.model_name or 'mtnet' in args.model_name:
                outputs, _ = net.forward(inputs)
            else:
                outputs = net.forward(inputs)
            prob_pred = torch.sigmoid(outputs)
            iou = utils.get_iou(prob_pred, labels)
            total_iou += iou

            # Resize the probability map back to the original image size,
            # threshold at 0.5 via rounding, and save an 8-bit mask.
            shape = (size[0, 0], size[0, 1])
            prob_pred = F.interpolate(prob_pred, size=shape, mode='bilinear', align_corners=True).cpu().data
            save_png = prob_pred[0][0].numpy()
            save_png = np.round(save_png)
            save_png = save_png * 255
            save_png = save_png.astype(np.uint8)
            save_path = save_dir + label_name[0]
            if not os.path.exists(save_path[:save_path.rfind('/')]):
                os.makedirs(save_path[:save_path.rfind('/')])
            cv2.imwrite(save_dir + label_name[0], save_png)

    print(args.model_name + ' iou:' + str(total_iou / len(testloader)))
    duration = time.time() - start_time
    print("-- %s contain %d images, cost time: %.4f s, speed: %.4f s." % (
        args.test_dataset, num_iter_ts, duration, duration / num_iter_ts))
    print("------------------------------------------------------------------")
示例#11
0
def main(args):
    """Train a thyroid-ultrasound segmentation network and validate per epoch.

    Builds the model selected by ``args.model_name``, trains it with SGD and
    a soft-dice criterion on the dataset selected by ``args.dataset``, logs
    losses/metrics to TensorBoard under ``run/run_<id>``, and checkpoints
    both the best-IoU model and periodic epoch snapshots.  The multi-task
    TRFE/MTNet models are fed interleaved nodule/gland sample pairs and
    return two prediction heads.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    # Resuming reuses the most recent run directory; a fresh training run
    # allocates the next free run id.
    if args.resume_epoch != 0:
        runs = sorted(glob.glob(os.path.join(save_dir_root, 'run', 'run_*')))
        run_id = int(runs[-1].split('_')[-1]) if runs else 0
    else:
        runs = sorted(glob.glob(os.path.join(save_dir_root, 'run', 'run_*')))
        run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0

    # An explicitly requested run id always overrides the auto-detected one.
    if args.run_id >= 0:
        run_id = args.run_id

    save_dir = os.path.join(save_dir_root, 'run', 'run_' + str(run_id))
    log_dir = os.path.join(
        save_dir,
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir)
    batch_size = args.batch_size

    # ---- model selection -------------------------------------------------
    if 'deeplab' in args.model_name:
        if 'resnet101' in args.model_name:
            net = Deeplabv3plus(nInputChannels=3,
                                n_classes=args.num_classes,
                                os=args.output_stride,
                                backbone_type='resnet101')
        elif 'resnet50' in args.model_name:
            net = Deeplabv3plus(nInputChannels=3,
                                n_classes=args.num_classes,
                                os=args.output_stride,
                                backbone_type='resnet50')
        elif 'resnet34' in args.model_name:
            net = Deeplabv3plus(nInputChannels=3,
                                n_classes=args.num_classes,
                                os=args.output_stride,
                                backbone_type='resnet34')
        else:
            raise NotImplementedError
    elif 'unet' in args.model_name:
        net = Unet(in_ch=3, out_ch=1)
    elif 'trfe' in args.model_name:
        if args.model_name == 'trfe1':
            net = TRFENet1(in_ch=3, out_ch=1)
        elif args.model_name == 'trfe2':
            net = TRFENet2(in_ch=3, out_ch=1)
        elif args.model_name == 'trfe':
            net = TRFENet(in_ch=3, out_ch=1)
        else:
            # Fix: an unrecognized 'trfe*' name previously fell through and
            # caused a NameError on the first use of `net` far below.
            raise NotImplementedError
        batch_size = 4  # multi-task models use a fixed, smaller batch
    elif 'mtnet' in args.model_name:
        net = MTNet(in_ch=3, out_ch=1)
        batch_size = 4  # same fixed batch as the TRFE variants
    elif 'segnet' in args.model_name:
        net = SegNet(input_channels=3, output_channels=1)
    elif 'fcn' in args.model_name:
        net = FCN8s(1)
    else:
        raise NotImplementedError

    # ---- weight initialization ------------------------------------------
    if args.resume_epoch == 0:
        print('Training ' + args.model_name + ' from scratch...')
    else:
        load_path = os.path.join(
            save_dir,
            args.model_name + '_epoch-' + str(args.resume_epoch) + '.pth')
        print('Initializing weights from: {}...'.format(load_path))
        net.load_state_dict(torch.load(load_path))

    if args.pretrain == 'THYROID':
        # CPU map_location so the checkpoint loads regardless of saving device.
        net.load_state_dict(
            torch.load('./pre_train/thyroid-pretrain.pth',
                       map_location=lambda storage, loc: storage))
        print('loading pretrain model......')

    torch.cuda.set_device(device=0)
    net.cuda()

    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)

    if args.criterion == 'Dice':
        criterion = soft_dice
    else:
        raise NotImplementedError

    # Train-time transforms add a random horizontal flip; eval-time do not.
    composed_transforms_tr = transforms.Compose([
        trforms.FixedResize(size=(args.input_size, args.input_size)),
        trforms.RandomHorizontalFlip(),
        trforms.Normalize(mean=(0.485, 0.456, 0.406),
                          std=(0.229, 0.224, 0.225)),
        trforms.ToTensor()
    ])

    composed_transforms_ts = transforms.Compose([
        trforms.FixedResize(size=(args.input_size, args.input_size)),
        trforms.Normalize(mean=(0.485, 0.456, 0.406),
                          std=(0.229, 0.224, 0.225)),
        trforms.ToTensor()
    ])

    # ---- dataset selection -----------------------------------------------
    if args.dataset == 'TN3K':
        train_data = tn3k.TN3K(mode='train',
                               transform=composed_transforms_tr,
                               fold=args.fold)
        val_data = tn3k.TN3K(mode='val',
                             transform=composed_transforms_ts,
                             fold=args.fold)
    elif args.dataset == 'TG3K':
        train_data = tg3k.TG3K(mode='train', transform=composed_transforms_tr)
        val_data = tg3k.TG3K(mode='val', transform=composed_transforms_ts)
    elif args.dataset == 'TATN':
        train_data = tatn.TATN(mode='train',
                               transform=composed_transforms_tr,
                               fold=args.fold)
        val_data = tatn.TATN(mode='val',
                             transform=composed_transforms_ts,
                             fold=args.fold)
    else:
        # Fix: an unknown dataset previously left train_data/val_data unbound
        # and crashed with a NameError at DataLoader construction.
        raise NotImplementedError

    trainloader = DataLoader(train_data,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=0)
    testloader = DataLoader(val_data,
                            batch_size=1,
                            shuffle=False,
                            num_workers=0)

    num_iter_tr = len(trainloader)
    num_iter_ts = len(testloader)
    nitrs = args.resume_epoch * num_iter_tr
    nsamples = args.resume_epoch * len(train_data)
    print('nitrs: %d num_iter_tr: %d' % (nitrs, num_iter_tr))
    print('nsamples: %d tot_num_samples: %d' % (nsamples, len(train_data)))

    aveGrad = 0
    global_step = 0
    recent_losses = []  # fixed-size ring buffer of the last log_every losses
    start_t = time.time()

    best_f, cur_f = 0.0, 0.0
    for epoch in range(args.resume_epoch, args.nepochs):
        net.train()
        epoch_losses = []
        for ii, sample_batched in enumerate(trainloader):
            # Consistency fix: match the test-time check, which accepts any
            # model name containing 'mtnet' (previously exact equality here).
            if 'trfe' in args.model_name or 'mtnet' in args.model_name:
                # Multi-task path: interleave nodule and gland samples so the
                # batch alternates [n0, g0, n1, g1, ...].
                nodules, glands = sample_batched
                inputs_n, labels_n = nodules['image'].cuda(
                ), nodules['label'].cuda()
                inputs_g, labels_g = glands['image'].cuda(
                ), glands['label'].cuda()
                inputs = torch.cat(
                    [inputs_n[0].unsqueeze(0), inputs_g[0].unsqueeze(0)],
                    dim=0)

                for i in range(1, inputs_n.size()[0]):
                    inputs = torch.cat([inputs, inputs_n[i].unsqueeze(0)],
                                       dim=0)
                    inputs = torch.cat([inputs, inputs_g[i].unsqueeze(0)],
                                       dim=0)

                global_step += inputs.data.shape[0]
                nodule, thyroid = net(inputs)
                loss = 0
                # Even positions are nodule samples scored on the nodule head;
                # odd positions are gland samples, down-weighted by 0.5.
                for i in range(inputs.size()[0]):
                    if i % 2 == 0:
                        loss += criterion(nodule[i],
                                          labels_n[int(i / 2)],
                                          size_average=False,
                                          batch_average=True)
                    else:
                        loss += 0.5 * criterion(thyroid[i],
                                                labels_g[int((i - 1) / 2)],
                                                size_average=False,
                                                batch_average=True)

            else:
                inputs, labels = sample_batched['image'].cuda(
                ), sample_batched['label'].cuda()
                global_step += inputs.data.shape[0]

                outputs = net(inputs)
                loss = criterion(outputs,
                                 labels,
                                 size_average=False,
                                 batch_average=True)

            trainloss = loss.item()
            epoch_losses.append(trainloss)
            if len(recent_losses) < args.log_every:
                recent_losses.append(trainloss)
            else:
                recent_losses[nitrs % len(recent_losses)] = trainloss

            # Backward the averaged gradient
            loss.backward()
            aveGrad += 1
            nitrs += 1
            # Fix: count the batch size actually in use — `batch_size` is
            # overridden to 4 for trfe/mtnet, so args.batch_size was wrong.
            nsamples += batch_size

            # Update the weights once in p['nAveGrad'] forward passes
            # (gradient accumulation across naver_grad mini-batches).
            if aveGrad % args.naver_grad == 0:
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0

            if nitrs % args.log_every == 0:
                meanloss = sum(recent_losses) / len(recent_losses)
                print('epoch: %d ii: %d trainloss: %.2f timecost:%.2f secs' %
                      (epoch, ii, meanloss, time.time() - start_t))
                writer.add_scalar('data/trainloss', meanloss, nsamples)

        meanloss = sum(epoch_losses) / len(epoch_losses)
        print('epoch: %d meanloss: %.2f' % (epoch, meanloss))
        writer.add_scalar('data/epochloss', meanloss, nsamples)

        # ---- per-epoch validation ---------------------------------------
        if args.use_test == 1:
            prec_lists = []
            recall_lists = []
            sum_testloss = 0.0
            total_mae = 0.0
            cnt = 0
            count = 0
            iou = 0
            if args.use_eval == 1:
                net.eval()
            for ii, sample_batched in enumerate(testloader):
                inputs, labels = sample_batched['image'].cuda(
                ), sample_batched['label'].cuda()
                with torch.no_grad():
                    if 'trfe' in args.model_name or 'mtnet' in args.model_name:
                        outputs, _ = net(inputs)
                    else:
                        outputs = net(inputs)

                loss = criterion(outputs,
                                 labels,
                                 size_average=False,
                                 batch_average=True)
                sum_testloss += loss.item()

                predictions = torch.sigmoid(outputs)

                iou += utils.get_iou(predictions, labels)
                count += 1

                total_mae += utils.get_mae(predictions,
                                           labels) * predictions.size(0)
                prec_list, recall_list = utils.get_prec_recall(
                    predictions, labels)
                prec_lists.extend(prec_list)
                recall_lists.extend(recall_list)
                cnt += predictions.size(0)

                # Report aggregate metrics after the last validation batch.
                if ii % num_iter_ts == num_iter_ts - 1:
                    mmae = total_mae / cnt
                    mean_testloss = sum_testloss / num_iter_ts
                    mean_prec = sum(prec_lists) / len(prec_lists)
                    mean_recall = sum(recall_lists) / len(recall_lists)
                    # F-beta with beta^2 = 0.3; guard against a 0/0 when both
                    # precision and recall are zero.
                    fbeta_denom = 0.3 * mean_prec + mean_recall
                    fbeta = (1.3 * mean_prec * mean_recall / fbeta_denom
                             if fbeta_denom > 0 else 0.0)
                    iou = iou / count

                    print('Validation:')
                    print(
                        'epoch: %d, numImages: %d testloss: %.2f mmae: %.4f fbeta: %.4f iou: %.4f'
                        % (epoch, cnt, mean_testloss, mmae, fbeta, iou))
                    writer.add_scalar('data/validloss', mean_testloss,
                                      nsamples)
                    writer.add_scalar('data/validmae', mmae, nsamples)
                    writer.add_scalar('data/validfbeta', fbeta, nsamples)
                    writer.add_scalar('data/validiou', iou, epoch)

                    # Track the best validation IoU and checkpoint it.
                    cur_f = iou
                    if cur_f > best_f:
                        save_path = os.path.join(
                            save_dir, args.model_name + '_best' + '.pth')
                        torch.save(net.state_dict(), save_path)
                        print("Save model at {}\n".format(save_path))
                        best_f = cur_f

        # Periodic epoch checkpoint, independent of validation quality.
        if epoch % args.save_every == args.save_every - 1:
            save_path = os.path.join(
                save_dir, args.model_name + '_epoch-' + str(epoch) + '.pth')
            torch.save(net.state_dict(), save_path)
            print("Save model at {}\n".format(save_path))