    def __init__(self,
                 root_folder,
                 folder=None,
                 num_classes=100,
                 split="train",
                 img_transforms=None):
        assert split in ["train", "val", "test"]
        # root folder, split
        self.root_folder = root_folder
        self.split = split
        self.img_transforms = img_transforms
        self.n_classes = num_classes

        # define mask transforms
        mask_transforms = []
        mask_transforms.append(transforms.Resize((256, 256)))
        mask_transforms.append(transforms.ToTensor())
        self.mask_transforms = transforms.Compose(mask_transforms)

        # load all labels
        if folder is None:
            folder = os.path.join(root_folder,
                                  "ISIC2018_Task1_Training_GroundTruth")
        if not os.path.exists(folder):
            raise ValueError(
                'Label folder {:s} does not exist!'.format(folder))

        if split == "train":
            start, end = 1, 501  # 500 training samples
        elif split == "val":
            start, end = 2495, 2595  # 100 validation samples
        elif split == "test":
            start, end = 121, 151  # 30 test samples

        masks_filename = []
        for itr in range(start, end):
            filename = "ISIC_Mask_" + str(itr) + ".png"
            mask = os.path.join(folder, filename)
            if os.path.isfile(mask):
                masks_filename.append(mask)

        # load input images
        folder = os.path.join(root_folder, "ISIC2018_Task1-2_Training_Input")
        if not os.path.exists(folder):
            raise ValueError(
                'Input folder {:s} does not exist!'.format(folder))

        images_filename = []
        for itr in range(start, end):
            filename = "ISIC_Input_" + str(itr) + ".jpg"
            img = os.path.join(folder, filename)
            if os.path.isfile(img):
                images_filename.append(img)

        self.images_filename = images_filename
        self.masks_filename = masks_filename
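
The constructor above only builds the two file lists; the matching `__len__`/`__getitem__` pair is not shown. A minimal sketch of what they might look like, assuming `PIL.Image` is imported and the masks are single-channel (this is an assumption, not the original code):

    def __len__(self):
        return len(self.images_filename)

    def __getitem__(self, idx):
        # load an image / mask pair by index
        img = Image.open(self.images_filename[idx]).convert('RGB')
        mask = Image.open(self.masks_filename[idx]).convert('L')
        if self.img_transforms is not None:
            img = self.img_transforms(img)
        mask = self.mask_transforms(mask)
        return img, mask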
Example #2
    def __getitem__(self, idx):
        # read input image
        if self.input_format == 'img':
            rgb_name = os.path.join(self.root_dir,
                                    self.rgbd_frame.iloc[idx, 0])
            with open(rgb_name, 'rb') as fRgb:
                rgb_image = Image.open(fRgb).convert('RGB')

            depth_name = os.path.join(self.root_dir,
                                      self.rgbd_frame.iloc[idx, 1])
            with open(depth_name, 'rb') as fDepth:
                depth_image = Image.open(fDepth)
                depth_image.load()  # force the lazy read before the file closes

        # read input hdf5
        elif self.input_format == 'hdf5':
            file_name = os.path.join(self.root_dir,
                                     self.rgbd_frame.iloc[idx, 0])
            rgb_h5, depth_h5 = self.load_h5(file_name)
            rgb_image = Image.fromarray(rgb_h5, mode='RGB')
            depth_image = Image.fromarray(depth_h5.astype('float32'), mode='F')
        else:
            raise ValueError(
                'Input format {:s} is not supported!'.format(self.input_format))

        _s = np.random.uniform(1.0, 1.5)
        s = int(240 * _s)  # np.int was removed in NumPy 1.24
        degree = np.random.uniform(-5.0, 5.0)
        if self.split == 'train':
            tRgb = data_transform.Compose([transforms.Resize(s),
                                           data_transform.Rotation(degree),
                                           transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                                           # data_transform.Lighting(0.1, imagenet_eigval, imagenet_eigvec),
                                           transforms.CenterCrop((228, 304)),
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                                           transforms.ToPILImage()])

            tDepth = data_transform.Compose([transforms.Resize(s),
                                             data_transform.Rotation(degree),
                                             transforms.CenterCrop((228, 304))])
            rgb_image = tRgb(rgb_image)
            depth_image = tDepth(depth_image)
            if np.random.uniform() < 0.5:
                rgb_image = rgb_image.transpose(Image.FLIP_LEFT_RIGHT)
                depth_image = depth_image.transpose(Image.FLIP_LEFT_RIGHT)

            rgb_image = transforms.ToTensor()(rgb_image)
            if self.input_format == 'img':
                depth_image = transforms.ToTensor()(depth_image)
            else:
                depth_image = data_transform.ToTensor()(depth_image)
            depth_image = depth_image.div(_s)
            sparse_image = self.createSparseDepthImage(depth_image, self.n_sample)
            rgbd_image = torch.cat((rgb_image, sparse_image), 0)
            sample = {'rgbd': rgbd_image, 'depth': depth_image}

        elif self.split == 'val':
            tRgb = data_transform.Compose([transforms.Resize(240),
                                           transforms.CenterCrop((228, 304)),
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                                           transforms.ToPILImage()])

            tDepth = data_transform.Compose([transforms.Resize(240),
                                             transforms.CenterCrop((228, 304))])

            rgb_raw = tDepth(rgb_image)
            rgb_image = tRgb(rgb_image)
            depth_image = tDepth(depth_image)
            rgb_image = transforms.ToTensor()(rgb_image)
            rgb_raw = transforms.ToTensor()(rgb_raw)
            if self.input_format == 'img':
                depth_image = transforms.ToTensor()(depth_image)
            else:
                depth_image = data_transform.ToTensor()(depth_image)
            sparse_image = self.createSparseDepthImage(depth_image, self.n_sample)
            rgbd_image = torch.cat((rgb_image, sparse_image), 0)

            sample = {'rgbd': rgbd_image, 'depth': depth_image, 'raw_rgb': rgb_raw }

        else:
            raise ValueError(
                'split {:s} is not supported in __getitem__!'.format(self.split))

        return sample
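
`createSparseDepthImage` is referenced above but not defined in this example. A plausible sketch (an assumption, not the original helper) draws a Bernoulli mask so that on average `n_sample` depth pixels survive:

    def createSparseDepthImage(self, depth_image, n_sample):
        # depth_image: 1 x H x W tensor; keep each pixel with probability
        # n_sample / (H * W), leaving ~n_sample sparse depth points
        prob = float(n_sample) / (depth_image.size(1) * depth_image.size(2))
        mask = torch.bernoulli(torch.full_like(depth_image, prob))
        return depth_image * mask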
Example #3
def main(args):
    # parse args
    if args.mode == "preTrain":
        # For MSE loss we store loss of validation set
        best_acc1 = 100000
    else:
        # For dice loss we store the dice coefficient accuracy of validation set
        best_acc1 = 0.0

    if args.gpu >= 0:
        print("Use GPU: {}".format(args.gpu))
    else:
        print('You are using the CPU for computing,',
              'yet this code assumes a GPU;',
              'you will NOT be able to switch between CPU and GPU training!')

    criterion1 = jaccard_loss()
    # Train in self supervised mode
    if args.mode == "preTrain":
        model = preTrain_model()
        criterion = nn.MSELoss()
    else:  # train in fully supervised mode
        # alternative: model = imgSeg_model(initial_param=initial_param) with
        # weights restored via load_checkpoint(initial_param)
        model = uNet_model()
        criterion = nn.BCELoss()
    model_arch = "UNet"

    # put everything on the GPU
    if args.gpu >= 0:
        model = model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)
        criterion1 = criterion1.cuda(args.gpu)

    # setup the optimizer
    if args.mode == "preTrain":
        optimizer = torch.optim.Adam(model.parameters(), args.lr)
        # momentum=args.momentum and weight_decay=args.weight_decay are not used here
    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.weight_decay)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            if args.gpu < 0:
                model = model.cpu()
            else:
                model = model.cuda(args.gpu)

            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {}, acc1 {})".format(
                args.resume, checkpoint['epoch'], best_acc1))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # set up transforms for data augmentation
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # train transforms
    print('Loading training, validation and test datasets...')
    train_transforms = []
    train_transforms.append(transforms.Resize((256, 256)))
    train_transforms.append(transforms.ToTensor())
    train_transforms.append(normalize)
    train_transforms = transforms.Compose(train_transforms)
    # val transforms
    val_transforms = []
    val_transforms.append(transforms.Resize((256, 256)))
    val_transforms.append(transforms.ToTensor())
    val_transforms.append(normalize)
    val_transforms = transforms.Compose(val_transforms)
    # test transforms (currently disabled)
    # test_transforms = []
    # test_transforms.append(transforms.Resize((512, 512)))
    # test_transforms.append(transforms.ToTensor())
    # test_transforms.append(normalize)
    # test_transforms = transforms.Compose(test_transforms)

    train_dataset = MelanomaDataLoader(args.data_folder,
                                       split="train",
                                       img_transforms=train_transforms)
    val_dataset = MelanomaDataLoader(args.data_folder,
                                     split="val",
                                     img_transforms=val_transforms)
    # test_dataset = MelanomaDataLoader(args.data_folder,
    #                                   split="test", img_transforms=test_transforms)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None,
                                               drop_last=False)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=None,
                                             drop_last=False)
    #test_loader = torch.utils.data.DataLoader(
    #  test_dataset, batch_size=1, shuffle=False,
    #  num_workers=args.workers, pin_memory=True, sampler=None, drop_last=False)

    # enable cudnn benchmark
    cudnn.enabled = True
    cudnn.benchmark = True
    print(optimizer)
    print(criterion)
    if args.mode == "supTrain":
        print(criterion1)

    # start the training
    print("Training the model ...")
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, criterion1, optimizer, epoch,
              "train", args)

        # evaluate on validation set
        #acc1 = validate(val_loader, model, epoch, args)
        if args.mode == "preTrain":
            loss = validate(val_loader, model, criterion, epoch, args)
            # remember best loss and save checkpoint
            is_best = loss < best_acc1
            best_acc1 = min(loss, best_acc1)
        else:
            acc1 = validate(val_loader, model, criterion, epoch, args)
            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model_arch': model_arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
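
`save_checkpoint` is not shown in this example. A minimal sketch following the conventional PyTorch pattern, assuming `import shutil` at module level and these filenames (both are assumptions):

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # write the latest state, and keep a separate copy of the best model so far
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')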