Beispiel #1
0
        avg_cost[index, 16], avg_cost[index, 17], avg_cost[index,
                                                           18], avg_cost[index,
                                                                         19],
        avg_cost[index, 20], avg_cost[index, 21], avg_cost[index,
                                                           22], avg_cost[index,
                                                                         23],
        dist_loss_save[0].avg, dist_loss_save[1].avg, dist_loss_save[2].avg
    ])

    if isbest:
        best_loss = loss_index
        print_index = index
    save_checkpoint(
        {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_loss': best_loss,
            'optimizer': optimizer.state_dict(),
        }, isbest)
print(
    'Epoch: {:04d} | TRAIN: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} '
    'TEST: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} '
    .format(print_index, avg_cost[print_index,
                                  0], avg_cost[print_index,
                                               1], avg_cost[print_index, 2],
            avg_cost[print_index,
                     3], avg_cost[print_index,
                                  4], avg_cost[print_index,
                                               5], avg_cost[print_index, 6],
            avg_cost[print_index,
                     7], avg_cost[print_index,
Beispiel #2
0
def main(config):
    """Train a SegNet model or export test predictions, per ``config.stage``.

    With ``config.stage == 'train'`` the model is trained for ``config.ep``
    epochs, every epoch is appended to a log file under ``logger/``, and the
    model/optimizer state is saved each time the validation Dice score
    exceeds the best value seen so far.

    With ``config.stage == 'test'`` the weights stored in
    ``config.saved_model`` are loaded, predictions are computed for the test
    set, serialized to JSON, zipped to ``<root_dir>/<filename>.zip`` and the
    intermediate JSON file is removed.

    Besides ``config``, the function reads names that are not defined in this
    function and are expected to exist at module level: ``train_sampler``,
    ``workers``, ``start_lr``, ``finish_lr``, ``lr_mult``, ``warmup_part``,
    ``max_mom``, ``min_mom``, ``wd`` and ``NUM_MODEL_PARAMS``.

    Args:
        config: Object exposing the attributes ``channels``, ``device``,
            ``root_dir``, ``stage``, ``train_dir``, ``valid_dir``,
            ``test_dir``, ``bs``, ``gamma``, ``ep``, ``lr``, ``saved_model``
            and ``filename``.

    Raises:
        ValueError: If ``config.channels`` is neither 1 nor 3, or if
            ``config.stage`` is neither ``'train'`` nor ``'test'``.
    """
    # Fail fast with a clear message instead of an UnboundLocalError later on.
    if config.stage not in ('train', 'test'):
        raise ValueError(
            "config.stage must be 'train' or 'test', got {!r}".format(
                config.stage))

    if config.channels == 1:
        mean = [0.467]
        std = [0.271]
    elif config.channels == 3:
        mean = [0.467, 0.467, 0.467]
        std = [0.271, 0.271, 0.271]
    else:
        raise ValueError(
            'config.channels must be 1 or 3, got {!r}'.format(
                config.channels))

    # A device index of -1 selects the CPU; any other value is a CUDA index.
    if config.device == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{:d}'.format(config.device))

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # Seed every random number generator used here with the same fixed value.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Training pipeline: augmentations followed by normalization.
    train_tfms = Compose([
        ShiftScaleRotate(rotate_limit=15, interpolation=cv2.INTER_CUBIC),
        GaussianBlur(),
        GaussNoise(),
        HorizontalFlip(),
        RandomBrightnessContrast(),
        Normalize(
            mean=mean,
            std=std,
        ),
        ToTensor()
    ])

    # Validation/test pipeline: normalization only.
    val_tfms = Compose([Normalize(
        mean=mean,
        std=std,
    ), ToTensor()])

    SAVEPATH = Path(config.root_dir)
    # Depending on the stage we either create train/validation or test dataset
    if config.stage == 'train':
        train_ds = EdsDS(fldr=SAVEPATH / config.train_dir,
                         channels=config.channels,
                         transform=train_tfms)
        val_ds = EdsDS(fldr=SAVEPATH / config.valid_dir,
                       channels=config.channels,
                       transform=val_tfms)

        # Shuffle only when no explicit sampler is supplied.
        train_loader = DataLoader(train_ds,
                                  batch_size=config.bs,
                                  shuffle=(train_sampler is None),
                                  num_workers=workers,
                                  pin_memory=True,
                                  sampler=train_sampler)

        checkpoint = 'logger'
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(checkpoint, exist_ok=True)
        arch = 'segnet_'
        title = 'Eye_' + arch + 'fast_fd_g{}_'.format(config.gamma)

        logger = Logger(os.path.join(
            checkpoint, '{}e{:d}_lr{:.4f}.txt'.format(title, config.ep,
                                                      config.lr)),
                        title=title)
        logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Dice'])
    elif config.stage == 'test':
        val_ds = EdsDS(fldr=SAVEPATH / config.test_dir,
                       channels=config.channels,
                       mask=False,
                       transform=val_tfms)

    # The same loader variable serves validation (train stage) and inference
    # (test stage); it uses twice the training batch size.
    val_loader = DataLoader(val_ds,
                            batch_size=config.bs * 2,
                            shuffle=False,
                            num_workers=workers,
                            pin_memory=True)

    model = SegNet(channels=config.channels).to(device)

    criterion = DiceFocalWithLogitsLoss(gamma=config.gamma).to(device)

    optimizer = AdamW(model.parameters(),
                      lr=start_lr,
                      betas=(max_mom, 0.999),
                      weight_decay=wd)

    if config.stage == 'train':
        steps = len(train_loader) * config.ep
        warmup_steps = int(steps * warmup_part)

        # Two cosine phases: start_lr -> config.lr over the warm-up steps,
        # then config.lr -> finish_lr over the remaining steps. The momentum
        # bounds are passed in opposite order for the two phases.
        schs = [
            SchedulerCosine(optimizer, start_lr, config.lr, lr_mult,
                            warmup_steps, max_mom, min_mom),
            SchedulerCosine(optimizer, config.lr, finish_lr, lr_mult,
                            steps - warmup_steps, min_mom, max_mom),
        ]
        lr_scheduler = LR_Scheduler(schs)

        max_dice = 0

        for epoch in range(config.ep):

            print('\nEpoch: [{:d} | {:d}] LR: {:.10f}|{:.10f}'.format(
                epoch + 1, config.ep, get_lr(optimizer, -1),
                get_lr(optimizer, 0)))

            # train for one epoch
            train_loss = train(train_loader, model, criterion, optimizer,
                               lr_scheduler, device, config)

            # evaluate on validation set
            valid_loss, dice = validate(val_loader, model, criterion, device,
                                        config)

            # append logger file
            logger.append(
                [get_lr(optimizer, -1), train_loss, valid_loss, dice])

            # Keep only the checkpoint with the best Dice score; the file
            # name is constant, so each improvement overwrites the last one.
            if dice > max_dice:
                max_dice = dice
                model_state = {
                    'epoch': epoch + 1,
                    'arch': arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                save_model = '{}e{:d}_lr_{:.3f}_max_dice.pth.tar'.format(
                    title, config.ep, config.lr)
                torch.save(model_state, save_model)
    elif config.stage == 'test':
        # map_location lets a checkpoint saved on a GPU be loaded when running
        # on the CPU (or on a different GPU index).
        checkpoint = torch.load(config.saved_model, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])

        # NOTE(review): in this stage validate() is used as if it returned an
        # iterable of logit tensors - confirm against its definition.
        logits = validate(val_loader, model, criterion, device, config)
        # .cpu() is a no-op for CPU tensors and is required before .numpy()
        # when the tensors live on a CUDA device.
        preds = np.concatenate([
            torch.argmax(l, 1).cpu().numpy() for l in logits
        ]).astype(np.uint8)

        leng = len(preds)
        data = {}
        data['num_model_params'] = NUM_MODEL_PARAMS
        data['number_of_samples'] = leng
        data['labels'] = {}
        # Predictions are matched to images by position in val_ds.img_paths.
        for i, pred in enumerate(preds):
            data['labels'][val_ds.img_paths[i].stem] = np_to_base64_utf8_str(
                pred)

        # Write the JSON, pack it into a zip archive, then drop the JSON.
        json_path = SAVEPATH / '{}.json'.format(config.filename)
        with open(json_path, 'w') as f:
            json.dump(data, f)
        with zipfile.ZipFile(SAVEPATH / '{}.zip'.format(config.filename),
                             "w",
                             compression=zipfile.ZIP_DEFLATED) as zf:
            zf.write(json_path)
        os.remove(json_path)
Beispiel #3
0
def main():
    """Train a ResNet-50 based SegNet on the training split with SGD.

    Creates the snapshot and debug-visualization directories named in the
    module-level ``args``, builds the model and an SGD optimizer with four
    parameter groups (backbone/added, weights/biases) using different
    learning-rate multipliers, then runs the training loop. Progress is
    printed every 10 steps; every 10000 steps a prediction visualization and
    a snapshot of the model weights are written.

    Reads the module-level names ``args``, ``INI_LEARNING_RATE``,
    ``WEIGHT_DECAY`` and ``EPOCHES``, which are not defined in this function.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.snapshot_dir, exist_ok=True)
    os.makedirs(args.train_debug_vis_dir, exist_ok=True)

    model = SegNet(model='resnet50')

    # NOTE: train() only switches the module to training mode; it does not
    # freeze BatchNorm statistics (the original comment claimed it did).
    model.train()
    model.cuda()

    # Per-group learning rates: biases get twice the rate of the weights in
    # the same group, and the "added" layers get ten times the backbone rate.
    optimizer = torch.optim.SGD(
        params=[
            {
                "params": get_params(model, key="backbone", bias=False),
                "lr": INI_LEARNING_RATE
            },
            {
                "params": get_params(model, key="backbone", bias=True),
                "lr": 2 * INI_LEARNING_RATE
            },
            {
                "params": get_params(model, key="added", bias=False),
                "lr": 10 * INI_LEARNING_RATE
            },
            {
                "params": get_params(model, key="added", bias=True),
                "lr": 20 * INI_LEARNING_RATE
            },
        ],
        lr=INI_LEARNING_RATE,
        weight_decay=WEIGHT_DECAY)

    dataloader = DataLoader(SegDataset(mode='train'),
                            batch_size=8,
                            shuffle=True,
                            num_workers=4)

    global_step = 0

    # NOTE(review): range(1, EPOCHES) performs EPOCHES - 1 passes over the
    # data - confirm that this is intended.
    for epoch in range(1, EPOCHES):

        for i_iter, batch_data in enumerate(dataloader):

            global_step += 1

            (Input_image, vis_image, gt_mask, weight_matrix, dataset_length,
             image_name) = batch_data

            optimizer.zero_grad()

            pred_mask = model(Input_image.cuda())

            loss = loss_calc(pred_mask, gt_mask, weight_matrix)

            loss.backward()

            optimizer.step()

            if global_step % 10 == 0:
                print('epoche {} i_iter/total {}/{} loss {:.4f}'.format(
                    epoch, i_iter, int(dataset_length[0].data), loss))

            # Every 10000 steps: dump a debug visualization and a snapshot.
            # os.path.join keeps the files inside the configured directories
            # even when the directory arguments lack a trailing separator.
            if global_step % 10000 == 0:
                vis_pred_result(
                    vis_image, gt_mask, pred_mask,
                    os.path.join(args.train_debug_vis_dir,
                                 str(global_step) + '.png'))
                torch.save(
                    model.state_dict(),
                    os.path.join(args.snapshot_dir,
                                 str(global_step) + '.pth'))
Beispiel #4
0
def main():
    """Train and evaluate a SegNet model on the EndoVis 2018 dataset.

    Parses the command-line arguments into the module-level ``args``, builds
    the train/validation/test datasets and loaders, optionally restores model
    weights from ``args.resume``, and then, for every epoch from
    ``args.start_epoch`` up to ``args.epochs``, trains, validates, reports
    IoU / precision / recall / F1 (printed and sent to ``writer``) and saves
    a checkpoint into ``args.save_dir``.

    Reads the module-level names ``parser``, ``use_gpu`` and ``writer``, which
    are not defined in this function.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    # The flag arrives as a string; map the two literal spellings to booleans
    # and leave any other value untouched.
    if args.saveTest == 'True':
        args.saveTest = True
    elif args.saveTest == 'False':
        args.saveTest = False

    # Make sure the save directory exists (exist_ok avoids the
    # check-then-create race of exists() + makedirs()).
    os.makedirs(args.save_dir, exist_ok=True)

    cudnn.benchmark = True

    data_transforms = {
        # Training images are resized and expanded into ten crops that are
        # stacked into a single tensor.
        'train': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.TenCrop(args.resizedImageSize),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            #transforms.Lambda(lambda normalized: torch.stack([transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])(crop) for crop in normalized]))
            #transforms.RandomResizedCrop(224, interpolation=Image.NEAREST),
            #transforms.RandomHorizontalFlip(),
            #transforms.RandomVerticalFlip(),
            #transforms.ToTensor(),
        ]),
        'test': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.ToTensor(),
            #transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])
        ]),
    }

    # Data Loading
    data_dir = '/media/salman/DATA/General Datasets/MICCAI/EndoVis_2018'
    # json path for class definitions
    json_path = '/home/salman/pytorch/endovis18/datasets/endovisClasses.json'

    trainval_image_dataset = endovisDataset(os.path.join(data_dir, 'train_data'),
                        data_transforms['train'], json_path=json_path, training=True)

    # Split off a validation subset sized by args.validationSplit.
    val_size = int(args.validationSplit * len(trainval_image_dataset))
    train_size = len(trainval_image_dataset) - val_size
    train_image_dataset, val_image_dataset = torch.utils.data.random_split(
        trainval_image_dataset, [train_size, val_size])

    test_image_dataset = endovisDataset(os.path.join(data_dir, 'test_data'),
                        data_transforms['test'], json_path=json_path, training=False)

    train_dataloader = torch.utils.data.DataLoader(train_image_dataset,
                                                  batch_size=args.batchSize,
                                                  shuffle=True,
                                                  num_workers=args.workers)
    val_dataloader = torch.utils.data.DataLoader(val_image_dataset,
                                                batch_size=args.batchSize,
                                                shuffle=True,
                                                num_workers=args.workers)
    # NOTE(review): test_dataloader is built but not used in this function.
    test_dataloader = torch.utils.data.DataLoader(test_image_dataset,
                                                  batch_size=args.batchSize,
                                                  shuffle=True,
                                                  num_workers=args.workers)

    # Get the dictionary for the id and RGB value pairs for the dataset
    classes = trainval_image_dataset.classes
    key = utils.disentangleKey(classes)
    num_classes = len(key)

    # Initialize the model
    model = SegNet(batchNorm_momentum=args.bnMomentum , num_classes=num_classes)

    optimizer = optim.Adam(model.parameters(), lr = args.lr, weight_decay = args.wd)

    # Load the saved model. The truthiness guard keeps a missing (None)
    # resume path from raising inside os.path.isfile; map_location='cpu'
    # lets checkpoints saved on a GPU load on hosts without one. The weights
    # are moved to the GPU below when use_gpu is set.
    # NOTE: only the model weights and the start epoch are restored; the
    # optimizer state stored in the checkpoint is not loaded.
    if args.resume and os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # Use a learning rate scheduler: halve the learning rate every 10 steps
    # of the scheduler.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    if use_gpu:
        model.cuda()
        criterion.cuda()

    # Initialize an evaluation Object
    evaluator = utils.Evaluate(key, use_gpu)

    for epoch in range(args.start_epoch, args.epochs):
        #adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        print('>>>>>>>>>>>>>>>>>>>>>>>Training<<<<<<<<<<<<<<<<<<<<<<<')
        train(train_dataloader, model, criterion, optimizer, scheduler, epoch, key)

        # Evaluate on validation set
        print('>>>>>>>>>>>>>>>>>>>>>>>Testing<<<<<<<<<<<<<<<<<<<<<<<')
        validate(val_dataloader, model, criterion, epoch, key, evaluator)

        # Calculate the metrics accumulated by the evaluator during validation
        print('>>>>>>>>>>>>>>>>>> Evaluating the Metrics <<<<<<<<<<<<<<<<<')
        IoU = evaluator.getIoU()
        print('Mean IoU: {}, Class-wise IoU: {}'.format(torch.mean(IoU), IoU))
        writer.add_scalar('Epoch Mean IoU', torch.mean(IoU), epoch)
        PRF1 = evaluator.getPRF1()
        precision, recall, F1 = PRF1[0], PRF1[1], PRF1[2]
        print('Mean Precision: {}, Class-wise Precision: {}'.format(torch.mean(precision), precision))
        writer.add_scalar('Epoch Mean Precision', torch.mean(precision), epoch)
        print('Mean Recall: {}, Class-wise Recall: {}'.format(torch.mean(recall), recall))
        writer.add_scalar('Epoch Mean Recall', torch.mean(recall), epoch)
        print('Mean F1: {}, Class-wise F1: {}'.format(torch.mean(F1), F1))
        writer.add_scalar('Epoch Mean F1', torch.mean(F1), epoch)
        # Clear the accumulated statistics before the next epoch.
        evaluator.reset()

        # One checkpoint file per epoch; 'epoch' stores the epoch to resume at.
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=os.path.join(args.save_dir, 'checkpoint_{}.tar'.format(epoch)))