def main():
    """Run the trained SegNet over the test split and save prediction visualisations.

    Reads directories from the module-level ``args``, loads the hard-coded
    ``150000.pth`` snapshot from ``args.snapshot_dir``, and for every test sample
    writes ``args.test_debug_vis_dir + image_name + '.png'`` via ``vis_pred_result``.
    """
    os.makedirs(args.snapshot_dir, exist_ok=True)
    os.makedirs(args.test_debug_vis_dir, exist_ok=True)

    model = SegNet(model='resnet50')
    # Plain string concatenation (not os.path.join) so the path matches how the
    # training loop names snapshots: args.snapshot_dir + str(global_step) + '.pth'.
    model.load_state_dict(torch.load(args.snapshot_dir + '150000.pth'))

    # eval(): batch-norm layers use their stored running statistics instead of
    # per-batch statistics.
    model.eval()
    model.cuda()

    dataloader = DataLoader(SegDataset(mode='test'),
                            batch_size=1,
                            shuffle=False,
                            num_workers=4)

    # Pure inference: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for i_iter, batch_data in enumerate(dataloader):
            (input_image, vis_image, gt_mask, _weight_matrix,
             dataset_length, image_name) = batch_data

            pred_mask = model(input_image.cuda())

            print('i_iter/total {}/{}'.format(
                i_iter, int(dataset_length[0].data)))

            out_path = args.test_debug_vis_dir + image_name[0] + '.png'
            # Create exactly the directory that will contain the output file.
            # image_name may or may not contain '/' sub-directories.
            out_dir = os.path.dirname(out_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

            vis_pred_result(vis_image, gt_mask, pred_mask, out_path)
# Example #2
# 0
def main():
    """Train SegNet (ResNet-50 backbone) with SGD on the ``'train'`` split.

    Uses the module-level ``args`` for output directories.
    - Every 10 global steps, the current loss is printed.
    - Every 10000 global steps, a prediction visualisation is written to
      ``args.train_debug_vis_dir``.
    - At the same interval, the model weights are saved to ``args.snapshot_dir``
      as ``<global_step>.pth``.
    """
    os.makedirs(args.snapshot_dir, exist_ok=True)
    os.makedirs(args.train_debug_vis_dir, exist_ok=True)

    model = SegNet(model='resnet50')

    # train() puts batch-norm layers in training mode: their running statistics
    # ARE updated by every forward pass below (nothing is frozen here).
    model.train()
    model.cuda()

    # Four parameter groups with per-group learning rates:
    # - parameters selected by key="backbone": 1x (weights) / 2x (biases) the base rate;
    # - parameters selected by key="added": 10x / 20x.
    optimizer = torch.optim.SGD(
        params=[
            {
                "params": get_params(model, key="backbone", bias=False),
                "lr": INI_LEARNING_RATE,
            },
            {
                "params": get_params(model, key="backbone", bias=True),
                "lr": 2 * INI_LEARNING_RATE,
            },
            {
                "params": get_params(model, key="added", bias=False),
                "lr": 10 * INI_LEARNING_RATE,
            },
            {
                "params": get_params(model, key="added", bias=True),
                "lr": 20 * INI_LEARNING_RATE,
            },
        ],
        lr=INI_LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )

    dataloader = DataLoader(SegDataset(mode='train'),
                            batch_size=8,
                            shuffle=True,
                            num_workers=4)

    log_every = 10
    # One integer interval for both visualisation and checkpointing.
    snapshot_every = 10000

    global_step = 0

    # NOTE(review): range(1, EPOCHES) runs EPOCHES - 1 epochs; confirm that
    # EPOCHES is meant as an exclusive upper bound.
    for epoch in range(1, EPOCHES):
        for i_iter, batch_data in enumerate(dataloader):
            global_step += 1

            (input_image, vis_image, gt_mask, weight_matrix,
             dataset_length, _image_name) = batch_data

            optimizer.zero_grad()
            pred_mask = model(input_image.cuda())
            # gt_mask / weight_matrix are passed as yielded by the loader; any
            # device transfer is left to loss_calc.
            loss = loss_calc(pred_mask, gt_mask, weight_matrix)
            loss.backward()
            optimizer.step()

            if global_step % log_every == 0:
                print('epoche {} i_iter/total {}/{} loss {:.4f}'.format(
                    epoch, i_iter, int(dataset_length[0].data), loss.item()))

            if global_step % snapshot_every == 0:
                vis_pred_result(
                    vis_image, gt_mask, pred_mask,
                    args.train_debug_vis_dir + str(global_step) + '.png')
                torch.save(model.state_dict(),
                           args.snapshot_dir + str(global_step) + '.pth')
# Example #3
# 0
def main():
    """Train SegNet on EndoVis 2018 with per-epoch validation and checkpointing.

    - Parses command-line arguments into the module-level ``args`` global.
    - Builds train/validation loaders from a random split of ``train_data``.
    - Optionally resumes model and optimizer state from ``args.resume``.
    - After every epoch, logs mean IoU / precision / recall / F1 (stdout and
      ``writer``) and saves ``checkpoint_<epoch>.tar`` to ``args.save_dir``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    # saveTest arrives as a string; normalise the two recognised spellings to bool.
    if args.saveTest == 'True':
        args.saveTest = True
    elif args.saveTest == 'False':
        args.saveTest = False

    os.makedirs(args.save_dir, exist_ok=True)

    cudnn.benchmark = True

    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.TenCrop(args.resizedImageSize),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            #transforms.Lambda(lambda normalized: torch.stack([transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])(crop) for crop in normalized]))
        ]),
        'test': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.ToTensor(),
            #transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])
        ]),
    }

    # Data Loading
    data_dir = '/media/salman/DATA/General Datasets/MICCAI/EndoVis_2018'
    # json path for class definitions
    json_path = '/home/salman/pytorch/endovis18/datasets/endovisClasses.json'

    trainval_image_dataset = endovisDataset(os.path.join(data_dir, 'train_data'),
                                            data_transforms['train'], json_path=json_path, training=True)
    val_size = int(args.validationSplit * len(trainval_image_dataset))
    train_size = len(trainval_image_dataset) - val_size
    # NOTE(review): random_split returns Subsets of trainval_image_dataset, so the
    # validation subset is also served through data_transforms['train'] (TenCrop).
    # Confirm that this is intended.
    train_image_dataset, val_image_dataset = torch.utils.data.random_split(
        trainval_image_dataset, [train_size, val_size])

    test_image_dataset = endovisDataset(os.path.join(data_dir, 'test_data'),
                                        data_transforms['test'], json_path=json_path, training=False)

    train_dataloader = torch.utils.data.DataLoader(train_image_dataset,
                                                   batch_size=args.batchSize,
                                                   shuffle=True,
                                                   num_workers=args.workers)
    val_dataloader = torch.utils.data.DataLoader(val_image_dataset,
                                                 batch_size=args.batchSize,
                                                 shuffle=True,
                                                 num_workers=args.workers)
    # NOTE(review): test_dataloader is built but not consumed in this function.
    test_dataloader = torch.utils.data.DataLoader(test_image_dataset,
                                                  batch_size=args.batchSize,
                                                  shuffle=True,
                                                  num_workers=args.workers)

    # Get the dictionary for the id and RGB value pairs for the dataset
    classes = trainval_image_dataset.classes
    key = utils.disentangleKey(classes)
    num_classes = len(key)

    # Initialize the model and loss
    model = SegNet(batchNorm_momentum=args.bnMomentum, num_classes=num_classes)
    criterion = nn.CrossEntropyLoss()

    # Move to the GPU *before* constructing the optimizer, so that optimizer
    # state restored below is placed on the same device as the parameters.
    if use_gpu:
        model.cuda()
        criterion.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # Optionally resume from a saved checkpoint
    if args.resume and os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        # Load onto the CPU so GPU-saved checkpoints also work when use_gpu is False.
        # load_state_dict copies the tensors onto the model's/optimizer's device.
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        # save_checkpoint below stores the optimizer state; restore it so Adam's
        # moment estimates (and the current learning rate) survive a resume.
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Halve the learning rate every 10 scheduler steps
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # Initialize an evaluation Object
    evaluator = utils.Evaluate(key, use_gpu)

    for epoch in range(args.start_epoch, args.epochs):
        # Train for one epoch
        print('>>>>>>>>>>>>>>>>>>>>>>>Training<<<<<<<<<<<<<<<<<<<<<<<')
        train(train_dataloader, model, criterion, optimizer, scheduler, epoch, key)

        # Evaluate on validation set
        print('>>>>>>>>>>>>>>>>>>>>>>>Testing<<<<<<<<<<<<<<<<<<<<<<<')
        validate(val_dataloader, model, criterion, epoch, key, evaluator)

        # Calculate the metrics accumulated by the evaluator during validate()
        print('>>>>>>>>>>>>>>>>>> Evaluating the Metrics <<<<<<<<<<<<<<<<<')
        IoU = evaluator.getIoU()
        print('Mean IoU: {}, Class-wise IoU: {}'.format(torch.mean(IoU), IoU))
        writer.add_scalar('Epoch Mean IoU', torch.mean(IoU), epoch)
        PRF1 = evaluator.getPRF1()
        precision, recall, F1 = PRF1[0], PRF1[1], PRF1[2]
        print('Mean Precision: {}, Class-wise Precision: {}'.format(torch.mean(precision), precision))
        writer.add_scalar('Epoch Mean Precision', torch.mean(precision), epoch)
        print('Mean Recall: {}, Class-wise Recall: {}'.format(torch.mean(recall), recall))
        writer.add_scalar('Epoch Mean Recall', torch.mean(recall), epoch)
        print('Mean F1: {}, Class-wise F1: {}'.format(torch.mean(F1), F1))
        writer.add_scalar('Epoch Mean F1', torch.mean(F1), epoch)
        evaluator.reset()

        # 'epoch' stores the next epoch to run; it becomes args.start_epoch on resume.
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=os.path.join(args.save_dir, 'checkpoint_{}.tar'.format(epoch)))
def main():
    """Evaluate a saved SegNet checkpoint on the MICCAI organ-segmentation test folder.

    - Parses command-line arguments into the module-level ``args`` global.
    - Loads weights from ``args.model``.
    - Runs ``validate`` over the ``test`` folder.
    - Prints mean and class-wise IoU / precision / recall / F1 from the evaluator.
    """
    global args
    args = parser.parse_args()
    print(args)

    # saveTest arrives as a string; normalise the two recognised spellings to bool.
    if args.saveTest == 'True':
        args.saveTest = True
    elif args.saveTest == 'False':
        args.saveTest = False

    os.makedirs(args.save_dir, exist_ok=True)

    cudnn.benchmark = True

    data_transform = transforms.Compose([
        transforms.Resize((args.imageSize, args.imageSize),
                          interpolation=Image.NEAREST),
        transforms.ToTensor(),
    ])

    # Data Loading
    data_dir = '/home/salman/pytorch/segmentationNetworks/datasets/miccaiSegOrgans'
    # json path for class definitions
    json_path = '/home/salman/pytorch/segmentationNetworks/datasets/miccaiSegOrganClasses.json'

    image_dataset = miccaiSegDataset(os.path.join(data_dir, 'test'),
                                     data_transform, json_path)

    dataloader = torch.utils.data.DataLoader(image_dataset,
                                             batch_size=args.batchSize,
                                             shuffle=True,
                                             num_workers=args.workers)

    # Get the dictionary for the id and RGB value pairs for the dataset
    classes = image_dataset.classes
    key = utils.disentangleKey(classes)
    num_classes = len(key)

    # Initialize the model
    model = SegNet(args.bnMomentum, num_classes)

    # Load the saved model
    if args.model and os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        # Load onto the CPU so GPU-saved checkpoints can also be evaluated when
        # use_gpu is False. The model is moved to the GPU further below if requested.
        checkpoint = torch.load(args.model, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        # NOTE(review): evaluation still proceeds, using the freshly initialised
        # (untrained) weights.
        print("=> no checkpoint found at '{}'".format(args.model))

    print(model)

    # Define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    if use_gpu:
        model.cuda()
        criterion.cuda()

    # Initialize an evaluation Object
    evaluator = utils.Evaluate(key, use_gpu)

    # Evaluate on validation/test set
    print('>>>>>>>>>>>>>>>>>>>>>>>Testing<<<<<<<<<<<<<<<<<<<<<<<')
    validate(dataloader, model, criterion, key, evaluator)

    # Calculate the metrics accumulated by the evaluator during validate()
    print('>>>>>>>>>>>>>>>>>> Evaluating the Metrics <<<<<<<<<<<<<<<<<')
    IoU = evaluator.getIoU()
    print('Mean IoU: {}, Class-wise IoU: {}'.format(torch.mean(IoU), IoU))
    PRF1 = evaluator.getPRF1()
    precision, recall, F1 = PRF1[0], PRF1[1], PRF1[2]
    print('Mean Precision: {}, Class-wise Precision: {}'.format(
        torch.mean(precision), precision))
    print('Mean Recall: {}, Class-wise Recall: {}'.format(
        torch.mean(recall), recall))
    print('Mean F1: {}, Class-wise F1: {}'.format(torch.mean(F1), F1))