Example #1
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

hflip_data_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.RandomHorizontalFlip(p=1.0),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

darkness_jitter_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ColorJitter(brightness=[0.5, 0.9]),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

lightness_jitter_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ColorJitter(brightness=[1.1, 1.5]),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

rotations_transform = transforms.Compose([
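
With a range such as brightness=[0.5, 0.9], ColorJitter draws the brightness factor uniformly from that interval, so the two jitter pipelines above produce systematically darkened or lightened crops. A minimal usage sketch, assuming PIL is installed and "sample.jpg" is a placeholder path:

from PIL import Image

img = Image.open("sample.jpg").convert("RGB")
dark = darkness_jitter_transform(img)    # brightness factor drawn from [0.5, 0.9]
light = lightness_jitter_transform(img)  # brightness factor drawn from [1.1, 1.5]
print(dark.shape, light.shape)           # torch.Size([3, 224, 224]) for both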
Example #2
#cnt_class_weights = float(len(train_img_lists)) / (cfg.num_classes * label_cnt)

train_img_lists = list(
    map(lambda x: os.path.join(cfg.data_root, cfg.train_dir, x),
        train_img_lists))
val_img_lists = os.listdir(os.path.join(cfg.data_root, cfg.val_dir))
val_img_lists = list(
    map(lambda x: os.path.join(cfg.data_root, cfg.val_dir, x), val_img_lists))

train_transforms_warm = transforms.Compose([
    transforms.Resize(size=(args.img_size + 20, args.img_size + 20)),
    transforms.RandomCrop(size=(args.img_size, args.img_size)),
    transforms.RandomHorizontalFlip(),
    #transforms.RandomRotation((-10, 10)),
    transforms.ColorJitter(0.3, 0.3, 0.3),
    transforms.ToTensor(),
    transforms.Normalize(cfg.mean, cfg.std)
])
train_transforms = transforms.Compose([
    transforms.Resize(size=(args.img_size + 20, args.img_size + 20)),
    #transforms.RandomRotation((-10, 10)),
    transforms.RandomCrop(size=(args.img_size, args.img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.3, 0.3, 0.3),
    transforms.ToTensor(),
    transforms.Normalize(cfg.mean, cfg.std)
])
train_transforms_no_color_aug = transforms.Compose([
    transforms.Resize(size=(args.img_size + 20, args.img_size + 20)),
    transforms.RandomHorizontalFlip(),
Example #3
    def __init__(self):
        ##The top config
        #self.data_root = '/media/hhy/data/USdata/MergePhase1/test_0.3'
        #self.log_dir = '/media/hhy/data/code_results/MILs/MIL_H_Attention'

        self.root = '/remote-home/my/Ultrasound_CV/data/Ruijin/clean'
        self.log_dir = '/remote-home/my/hhy/Ultrasound_MIL/experiments/PLN1/weighted_sampler+res18/fold4'
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)
        ##training config
        self.lr = 1e-4
        self.epoch = 50
        self.resume = -1
        self.batch_size = 1
        self.net = Res_Attention()
        self.net.cuda()

        self.optimizer = Adam(self.net.parameters(), lr=self.lr)
        self.lrsch = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[10, 30, 50, 70], gamma=0.5)

        self.logger = Logger(self.log_dir)
        self.train_transform = transforms.Compose([
            transforms.Resize((224, 224)),  # standardizes scale; RandomResizedCrop below sets the final 224x224 size
            transforms.RandomResizedCrop((224, 224)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomVerticalFlip(0.5),
            transforms.ColorJitter(0.25, 0.25, 0.25, 0.25),
            transforms.ToTensor()
        ])
        self.test_transform = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor()])

        self.label_name = "手术淋巴结情况(0未转移;1转移)"
        self.trainbag = RuijinBags(self.root, [0, 1, 2, 3],
                                   self.train_transform,
                                   label_name=self.label_name)
        self.testbag = RuijinBags(self.root, [4],
                                  self.test_transform,
                                  label_name=self.label_name)

        train_label_list = list(
            map(lambda x: int(x['label']), self.trainbag.patient_info))
        pos_ratio = sum(train_label_list) / len(train_label_list)
        print(pos_ratio)
        train_weight = [(1 - pos_ratio) if x > 0 else pos_ratio
                        for x in train_label_list]

        self.train_sampler = WeightedRandomSampler(weights=train_weight,
                                                   num_samples=len(
                                                       self.trainbag))
        self.train_loader = DataLoader(self.trainbag,
                                       batch_size=self.batch_size,
                                       num_workers=8,
                                       sampler=self.train_sampler)
        self.val_loader = DataLoader(self.testbag,
                                     batch_size=self.batch_size,
                                     shuffle=False,
                                     num_workers=8)

        self.loss = None  # placeholder; logger.load overwrites it when resuming
        if self.resume > 0:
            self.net, self.optimizer, self.lrsch, self.loss, self.global_step = self.logger.load(
                self.net, self.optimizer, self.lrsch, self.loss, self.resume)
        else:
            self.global_step = 0

        # self.trainer = MTTrainer(self.net, self.optimizer, self.lrsch, self.loss, self.train_loader, self.val_loader, self.logger, self.global_step, mode=2)
        self.trainer = MILTrainer(self.net, self.optimizer, self.lrsch, None,
                                  self.train_loader, self.val_loader,
                                  self.logger, self.global_step)
Example #4
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', xscale={0}'.format(
            tuple(round(s, 4) for s in self.xscale))
        format_string += ', yscale={0}'.format(
            tuple(round(r, 4) for r in self.yscale))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string


transform = {
    'train':
    transforms.Compose([
        transforms.RandomRotation(12, resample=Image.BILINEAR),
        CenterRandomCrop(size, xscale=(0.6, 1.0), aspect_ratio=(1, 1.4)),
        transforms.ColorJitter(0.2, 0.1, 0.1, 0.04),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
    'val':
    transforms.Compose([
        transforms.CenterCrop((512, 640)),
        transforms.Resize(size, interpolation=Image.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
}

if torch.cuda.is_available():
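
The phase-keyed dict above is a common train/val switch: each value is a composed callable that is selected by indexing with the current phase. A minimal sketch of consuming it, assuming img is a PIL image and size, mean, and std are defined as in the snippet:

for phase in ('train', 'val'):
    out = transform[phase](img)  # returns a normalized float tensor
    print(phase, tuple(out.shape))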
Example #5
def execute(args):
    '''
    Train the Model

    .. note::
    Standard input image size: 2448 x 2448

    Label construction size is dependent on the input image size
    and the kernels used
        label_size = 196 #3x3 kernel sizes for all three decode layers and input size 284
        label_size = 180 #7x7, 5x5, 3x3 kernel sizes for decode layers and input size 284
        label_size = 172 #9x9, 6x6, 3x3 kernel sizes for decode layers and input size 284
        label_size = TBD #9x9, 6x6, 3x3 kernel sizes for decode layers and input size 1224
        label_size = 500 #9x9, 6x6, 3x3 kernel sizes for decode layers and input size 612
        label_size = 420 #7x7, 3x3, 3x3, 3x3, kernel size for decode layers and input size 612
    '''

    # Satellite Image Transformations
    t = transforms.Compose([
        transforms.Resize(args.input_image_size, interpolation=4),  # 4 == PIL.Image.BOX
        transforms.ColorJitter(),
        transforms.RandomHorizontalFlip(p=0.3),
        transforms.ToTensor(),
    ])

    # Mask Transformations
    t2 = transforms.Compose([
        transforms.Resize(args.label_size),
        transforms.ToTensor(),
    ])

    data_set = landpy.MyDataLoader(args.data_dir,
                                   args.label_size,
                                   image_transforms=t,
                                   mask_transforms=t2)

    train_loader, validation_loader = landpy.create_data_loaders(
        data_set, args.training_split, args.batch_size)

    # Establish the UNet Model & Training parameters
    unet_model = landpy.UNet(3, 7)
    if args.start_new_model == 0:
        unet_path = os.path.join(args.model_paths, f"{args.model_to_load}.pt")
        unet_model.load_state_dict(torch.load(unet_path))

    loss_weights = torch.tensor([
        0.145719925, 0.022623007, 0.133379898, 0.098588677, 0.36688587,
        0.222802623, 0.01
    ])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        torch.cuda.empty_cache()
        print("GPU Enabled")
        print(f"Current GPU Memory Usage: {torch.cuda.memory_allocated()}")
        print("Making Model GPU Based")
        unet_model = unet_model.cuda()
        loss_weights = loss_weights.cuda()

    loss = torch.nn.NLLLoss(weight=loss_weights)

    # optimizer = torch.optim.SGD(unet_model.parameters(), lr=args.learning_rate,
    #                             momentum=args.momentum
    #                             )
    optimizer = torch.optim.Adam(unet_model.parameters(),
                                 lr=args.learning_rate)

    final_path = os.path.join(args.model_paths, f"{args.final_model_name}.pt")

    print(
        f"Number of Images for Training: {int(len(data_set)*args.training_split)}"
    )
    print(
        f"Number of Images for Validation: {int(len(data_set)*(1-args.training_split))}"
    )
    print(f"Number of Epochs Used: {args.epochs}")
    print(f"Batch Size Used: {args.batch_size}")
    print(f"Learning Rate Used: {args.learning_rate}")
    print(f"Momentum for Optimizer: {args.momentum}")
    print(f"Final Model Name: {args.final_model_name}")
    print(f"Loss Weights by Class: {loss_weights}")
    print("\n")

    epoch_losses = {}
    checkpoint_idx = 1
    print("Begin Training")
    for epoch in trange(args.epochs):
        if use_gpu:
            torch.cuda.empty_cache()

        t0 = time.time()
        total_training_loss = 0
        with torch.set_grad_enabled(True):
            for i, (batch_x_images, batch_y_mask,
                    match_y_class_mask) in enumerate(train_loader):
                unet_model.train()
                if use_gpu:
                    batch_x_images = batch_x_images.cuda()
                    match_y_class_mask = match_y_class_mask.cuda()

                batch_loss = landpy.train_step(batch_x_images,
                                               match_y_class_mask, optimizer,
                                               loss, unet_model)
                total_training_loss += batch_loss

        t1 = time.time()
        print(
            f"Total Training Loss for Epoch {epoch} is: {total_training_loss}")

        total_validation_loss = 0
        total_mean_iou = []
        with torch.no_grad():
            if use_gpu:
                torch.cuda.empty_cache()
            for j, (batch_val_x_images, batch_val_y_mask,
                    match_val_y_class_mask) in enumerate(validation_loader):
                unet_model.eval()
                if use_gpu:
                    batch_val_x_images = batch_val_x_images.cuda()
                    match_val_y_class_mask = match_val_y_class_mask.cuda()

                outputs = unet_model(batch_val_x_images)
                soft_max_output = torch.nn.LogSoftmax(dim=1)(outputs)
                val_batch_loss = loss(soft_max_output,
                                      match_val_y_class_mask.long())
                total_validation_loss += val_batch_loss

                batch_mean_iou = landpy.mean_IOU(soft_max_output,
                                                 match_val_y_class_mask)
                total_mean_iou.append(batch_mean_iou)

        epoch_losses[epoch] = {
            "Training Loss": total_training_loss.item(),
            "Validation Loss": total_validation_loss.item(),
            "Mean IOU": np.mean(np.array(total_mean_iou)),
            "Execution Time": (t1 - t0)
        }
        print(
            f"Total Validation Loss for Epoch {epoch} is: {total_validation_loss.item()}"
        )

        if epoch % args.checkpoint == 0:
            # Checkpoint Save
            checkpoint_path = os.path.join(
                args.model_paths,
                f"{args.final_model_name}_chp_{checkpoint_idx}.pt")
            torch.save(unet_model.state_dict(), checkpoint_path)
            checkpoint_idx += 1

    print("\n")
    print("Completed Training; Saving Model")
    torch.save(unet_model.state_dict(), final_path)

    print("Saving Epoch Losses to DF")
    epoch_losses_path = os.path.join(
        args.epoch_loss_dir, args.final_model_name + "_epoch_losses.csv")
    df = pd.DataFrame.from_dict(epoch_losses, orient='index')
    df.to_csv(epoch_losses_path)
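
Note that the validation loop pairs LogSoftmax(dim=1) with NLLLoss, which is mathematically equivalent to applying CrossEntropyLoss to the raw logits. A quick standalone check, with shapes chosen to mimic the 7-class segmentation output:

import torch

logits = torch.randn(2, 7, 8, 8)   # (batch, classes, H, W)
target = torch.randint(0, 7, (2, 8, 8))

nll = torch.nn.NLLLoss()(torch.nn.LogSoftmax(dim=1)(logits), target)
ce = torch.nn.CrossEntropyLoss()(logits, target)
assert torch.allclose(nll, ce)     # identical up to floating-point error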
Example #6
def get_dataloaders(dataset,
                    batch,
                    dataroot,
                    split=0.15,
                    split_idx=0,
                    multinode=False,
                    target_lb=-1):
    if 'cifar' in dataset or 'svhn' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
    elif 'imagenet' in dataset:
        input_size = 224
        sized_size = 256

        if 'efficientnet' in C.get()['model']['type']:
            input_size = EfficientNet.get_image_size(C.get()['model']['type'])
            sized_size = input_size + 32  # TODO
            # sized_size = int(round(input_size / 224. * 256))
            # sized_size = input_size
            logger.info('size changed to %d/%d.' % (input_size, sized_size))

        transform_train = transforms.Compose([
            EfficientNetRandomCrop(input_size),
            transforms.Resize((input_size, input_size),
                              interpolation=Image.BICUBIC),
            # transforms.RandomResizedCrop(input_size, scale=(0.1, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
            ),
            transforms.ToTensor(),
            Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        transform_test = transforms.Compose([
            EfficientNetCenterCrop(input_size),
            transforms.Resize((input_size, input_size),
                              interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    else:
        raise ValueError('dataset=%s' % dataset)

    total_aug = augs = None
    if isinstance(C.get()['aug'], list):
        logger.debug('augmentation provided.')
        transform_train.transforms.insert(0, Augmentation(C.get()['aug']))
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'uniformaugment':
            transform_train.transforms.insert(0, UniformAugment())
        elif C.get()['aug'] in ['default']:
            pass
        else:
            raise ValueError('not found augmentations. %s' % C.get()['aug'])

    if C.get()['cutout'] > 0:
        transform_train.transforms.append(CutoutDefault(C.get()['cutout']))

    if dataset == 'cifar10':
        total_trainset = torchvision.datasets.CIFAR10(
            root=dataroot,
            train=True,
            download=True,
            transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root=dataroot,
                                               train=False,
                                               download=True,
                                               transform=transform_test)
    elif dataset == 'reduced_cifar10':
        total_trainset = torchvision.datasets.CIFAR10(
            root=dataroot,
            train=True,
            download=True,
            transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=46000,
                                     random_state=0)  # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        testset = torchvision.datasets.CIFAR10(root=dataroot,
                                               train=False,
                                               download=True,
                                               transform=transform_test)
    elif dataset == 'cifar100':
        total_trainset = torchvision.datasets.CIFAR100(
            root=dataroot,
            train=True,
            download=True,
            transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root=dataroot,
                                                train=False,
                                                download=True,
                                                transform=transform_test)
    elif dataset == 'svhn':
        trainset = torchvision.datasets.SVHN(root=dataroot,
                                             split='train',
                                             download=True,
                                             transform=transform_train)
        extraset = torchvision.datasets.SVHN(root=dataroot,
                                             split='extra',
                                             download=True,
                                             transform=transform_train)
        total_trainset = ConcatDataset([trainset, extraset])
        testset = torchvision.datasets.SVHN(root=dataroot,
                                            split='test',
                                            download=True,
                                            transform=transform_test)
    elif dataset == 'reduced_svhn':
        total_trainset = torchvision.datasets.SVHN(root=dataroot,
                                                   split='train',
                                                   download=True,
                                                   transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=73257 - 1000,
                                     random_state=0)  # 1000 trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        testset = torchvision.datasets.SVHN(root=dataroot,
                                            split='test',
                                            download=True,
                                            transform=transform_test)
    elif dataset == 'imagenet':
        total_trainset = ImageNet(root=os.path.join(dataroot,
                                                    'imagenet-pytorch'),
                                  transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'),
                           split='val',
                           transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]
    elif dataset == 'reduced_imagenet':
        # randomly chosen indices
        #         idx120 = sorted(random.sample(list(range(1000)), k=120))
        idx120 = [
            16, 23, 52, 57, 76, 93, 95, 96, 99, 121, 122, 128, 148, 172, 181,
            189, 202, 210, 232, 238, 257, 258, 259, 277, 283, 289, 295, 304,
            307, 318, 322, 331, 337, 338, 345, 350, 361, 375, 376, 381, 388,
            399, 401, 408, 424, 431, 432, 440, 447, 462, 464, 472, 483, 497,
            506, 512, 530, 541, 553, 554, 557, 564, 570, 584, 612, 614, 619,
            626, 631, 632, 650, 657, 658, 660, 674, 675, 680, 682, 691, 695,
            699, 711, 734, 736, 741, 754, 757, 764, 769, 770, 780, 781, 787,
            797, 799, 811, 822, 829, 830, 835, 837, 842, 843, 845, 873, 883,
            897, 900, 902, 905, 913, 920, 925, 937, 938, 940, 941, 944, 949,
            959
        ]
        total_trainset = ImageNet(root=os.path.join(dataroot,
                                                    'imagenet-pytorch'),
                                  transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'),
                           split='val',
                           transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]

        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=len(total_trainset) - 50000,
                                     random_state=0)  # keep a 50000-sample trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        train_idx, valid_idx = next(sss)

        # filter out
        train_idx = list(
            filter(lambda x: total_trainset.labels[x] in idx120, train_idx))
        valid_idx = list(
            filter(lambda x: total_trainset.labels[x] in idx120, valid_idx))
        test_idx = list(
            filter(lambda x: testset.samples[x][1] in idx120,
                   range(len(testset))))

        targets = [
            idx120.index(total_trainset.targets[idx]) for idx in train_idx
        ]
        for idx in range(len(total_trainset.samples)):
            if total_trainset.samples[idx][1] not in idx120:
                continue
            total_trainset.samples[idx] = (total_trainset.samples[idx][0],
                                           idx120.index(
                                               total_trainset.samples[idx][1]))
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        for idx in range(len(testset.samples)):
            if testset.samples[idx][1] not in idx120:
                continue
            testset.samples[idx] = (testset.samples[idx][0],
                                    idx120.index(testset.samples[idx][1]))
        testset = Subset(testset, test_idx)
        print('reduced_imagenet train=', len(total_trainset))
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    if total_aug is not None and augs is not None:
        total_trainset.set_preaug(augs, total_aug)
        print('set_preaug-')

    train_sampler = None
    if split > 0.0:
        sss = StratifiedShuffleSplit(n_splits=5,
                                     test_size=split,
                                     random_state=0)
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        for _ in range(split_idx + 1):
            train_idx, valid_idx = next(sss)

        if target_lb >= 0:
            train_idx = [
                i for i in train_idx if total_trainset.targets[i] == target_lb
            ]
            valid_idx = [
                i for i in valid_idx if total_trainset.targets[i] == target_lb
            ]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx)

        if multinode:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                Subset(total_trainset, train_idx),
                num_replicas=dist.get_world_size(),
                rank=dist.get_rank())
    else:
        valid_sampler = SubsetSampler([])

        if multinode:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                total_trainset,
                num_replicas=dist.get_world_size(),
                rank=dist.get_rank())
            logger.info(
                f'----- dataset with DistributedSampler  {dist.get_rank()}/{dist.get_world_size()}'
            )

    trainloader = torch.utils.data.DataLoader(
        total_trainset,
        batch_size=batch,
        shuffle=(train_sampler is None),
        num_workers=8,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)
    validloader = torch.utils.data.DataLoader(total_trainset,
                                              batch_size=batch,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True,
                                              sampler=valid_sampler,
                                              drop_last=False)

    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=batch,
                                             shuffle=False,
                                             num_workers=8,
                                             pin_memory=True,
                                             drop_last=False)
    return train_sampler, trainloader, validloader, testloader
Example #7
def get_transforms(config, image_size=None):
    config = config.get_dictionary()
    if image_size is None:
        image_size = resize_size_dict.get(config['estimator'], 32)

    val_transforms = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])
    if parse_bool(config['aug']):
        if parse_bool(config['auto_aug']):
            # from .transforms import AutoAugment
            data_transforms = {
                'train':
                transforms.Compose([
                    # AutoAugment(),
                    transforms.Resize(image_size),
                    transforms.RandomCrop(image_size,
                                          padding=int(image_size / 8)),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ]),
                'val':
                val_transforms,
            }
        else:
            transform_list = []
            if parse_bool(config['jitter']):
                transform_list.append(
                    transforms.ColorJitter(brightness=config['brightness'],
                                           saturation=config['saturation'],
                                           hue=config['hue']))
            if parse_bool(config['affine']):
                transform_list.append(
                    transforms.RandomAffine(degrees=config['degree'],
                                            shear=config['shear']))

            transform_list.append(transforms.RandomResizedCrop(image_size))
            transform_list.append(transforms.RandomCrop(image_size, padding=4))

            if parse_bool(config['random_flip']):
                transform_list.append(transforms.RandomHorizontalFlip())

            transform_list.append(transforms.ToTensor())

            data_transforms = {
                'train': transforms.Compose(transform_list),
                'val': val_transforms
            }
    else:
        data_transforms = {
            'train':
            transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
            ]),
            'val':
            val_transforms,
        }
    return data_transforms
Example #8
    def __init__(self,
                 path_root,
                 t_task,
                 n_way,
                 k_shot,
                 k_query,
                 x_dim,
                 split,
                 augment='0',
                 test=None,
                 shuffle=True,
                 fetch_global=False):
        self.t_task = t_task
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query
        self.x_dim = list(map(int, x_dim.split(',')))
        self.split = split
        self.shuffle = shuffle
        self.path_root = path_root
        self.fet_global = fetch_global

        if augment == '0':
            self.transform = transforms.Compose([
                transforms.Lambda(f1),
                transforms.Resize(self.x_dim[:2]),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))
            ])
        elif augment == '1':
            if self.split == 'train':
                self.transform = transforms.Compose([
                    # lambda x: Image.open(x).convert('RGB'),
                    transforms.Lambda(f1),
                    transforms.Resize(
                        (self.x_dim[0] + 20, self.x_dim[1] + 20)),
                    transforms.RandomCrop(self.x_dim[:2]),
                    transforms.RandomHorizontalFlip(),
                    transforms.ColorJitter(brightness=.1,
                                           contrast=.1,
                                           saturation=.1,
                                           hue=.1),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406),
                                         (0.229, 0.224, 0.225))
                ])
            else:
                self.transform = transforms.Compose([
                    # lambda x: Image.open(x).convert('RGB'),
                    transforms.Lambda(f1),
                    transforms.Resize(
                        (self.x_dim[0] + 20, self.x_dim[1] + 20)),
                    transforms.RandomCrop(self.x_dim[:2]),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406),
                                         (0.229, 0.224, 0.225))
                ])

        self.path = os.path.join(path_root, 'images')

        self.lmdb_file = os.path.join(path_root, "lmdb_data",
                                      "%s.lmdb" % self.split)
        if not os.path.exists(self.lmdb_file):
            print("lmdb_file is not found, start to generate %s" %
                  self.lmdb_file)
            self._generate_lmdb()

        # read lmdb_file
        self.env = lmdb.open(self.lmdb_file,
                             subdir=False,
                             readonly=True,
                             lock=False,
                             readahead=False,
                             meminit=False)
        with self.env.begin(write=False) as txn:
            self.total_sample = pyarrow.deserialize(txn.get(b'__len__'))
            self.keys = pyarrow.deserialize(txn.get(b'__keys__'))
            self.label2num = pyarrow.deserialize(txn.get(b'__label2num__'))
            self.num2label = pyarrow.deserialize(txn.get(b'__num2label__'))

        self.image_labels = [i.decode() for i in self.keys]
        self.total_cls = len(self.num2label)
        self.dic_img_label = defaultdict(list)
        for i in self.image_labels:
            self.dic_img_label[i[:9]].append(i)

        self.support_set_size = self.n_way * self.k_shot  # num of samples per support set
        self.query_set_size = self.n_way * self.k_query

        self.episode = self.total_sample // (
            self.t_task *
            (self.support_set_size + self.query_set_size))  # number of episodes

        if platform.system().lower() == 'windows':
            self.platform = "win"
            del self.env
        elif platform.system().lower() == 'linux':
            self.platform = "linux"
Example #9
    def __init__(self,
                 cache_dir,
                 image_dir,
                 split,
                 chunk_size=(1.5, 1.5),
                 chunk_thresh=0.3,
                 chunk_margin=(0.2, 0.2),
                 nb_pts=-1,
                 num_rgbd_frames=0,
                 resize=(160, 120),
                 image_normalizer=None,
                 k=3,
                 z_rot=None,
                 flip=0.0,
                 color_jitter=None,
                 to_tensor=False,
                 ):
        """

        Args:
            cache_dir (str): path to cache of 3D point clouds, 3D semantic labels and RGB-D overlap info
            image_dir (str): path to 2D images, depth maps and poses
            split:
            chunk_size (tuple): xy chunk size
            chunk_thresh (float): minimum number of labeled points within a chunk
            chunk_margin (tuple): margin to calculate ratio of labeled points within a chunk
            nb_pts (int): number of points to resample in a chunk
            num_rgbd_frames (int): number of RGB-D frames to choose
            resize (tuple): target image size
            image_normalizer (tuple, optional): (mean, std)
            k (int): k-nn unprojected neighbors of target points
            z_rot (tuple, optional): range of rotation (degree instead of rad)
            flip (float): probability to flip horizontally
            color_jitter (tuple, optional): parameters of color jitter
            to_tensor (bool): whether to convert to torch.Tensor
        """
        super(ScanNet2D3DChunks, self).__init__()

        # cache: pickle files containing point clouds, 3D labels and rgbd overlap
        self.cache_dir = cache_dir
        # includes color, depth, 2D label
        self.image_dir = image_dir

        # load split
        self.split = split
        with open(osp.join(self.split_dir, self.split_map[split]), 'r') as f:
            self.scan_ids = [line.rstrip() for line in f.readlines()]

        # ---------------------------------------------------------------------------- #
        # Build label mapping
        # ---------------------------------------------------------------------------- #
        # read tsv file to get raw to nyu40 mapping (dict)
        self.raw_to_nyu40_mapping = read_label_mapping(self.label_id_tsv_path,
                                                       label_from='id', label_to='nyu40id', as_int=True)
        self.raw_to_nyu40 = np.zeros(max(self.raw_to_nyu40_mapping.keys()) + 1, dtype=np.int64)
        for key, value in self.raw_to_nyu40_mapping.items():
            self.raw_to_nyu40[key] = value
        # scannet
        self.scannet_mapping = load_class_mapping(self.scannet_classes_path)
        assert len(self.scannet_mapping) == 20
        # nyu40 -> scannet
        self.nyu40_to_scannet = np.full(shape=41, fill_value=self.ignore_value, dtype=np.int64)
        self.nyu40_to_scannet[list(self.scannet_mapping.keys())] = np.arange(len(self.scannet_mapping))
        # scannet -> nyu40
        self.scannet_to_nyu40 = np.array(list(self.scannet_mapping.keys()) + [0], dtype=np.int64)
        # raw -> scannet
        self.raw_to_scannet = self.nyu40_to_scannet[self.raw_to_nyu40]
        self.class_names = tuple(self.scannet_mapping.values())

        # ---------------------------------------------------------------------------- #
        # 3D
        # ---------------------------------------------------------------------------- #
        # The height / z-axis is ignored in fact.
        self.chunk_size = np.array(chunk_size, dtype=np.float32)
        self.chunk_thresh = chunk_thresh
        self.chunk_margin = np.array(chunk_margin, dtype=np.float32)
        self.nb_pts = nb_pts

        # ---------------------------------------------------------------------------- #
        # 2D
        # ---------------------------------------------------------------------------- #
        self.num_rgbd_frames = num_rgbd_frames
        self.resize = resize
        self.image_normalizer = image_normalizer

        # ---------------------------------------------------------------------------- #
        # 2D-3D
        # ---------------------------------------------------------------------------- #
        self.k = k
        if num_rgbd_frames > 0 and resize:
            depth_size = (640, 480)  # intrinsic matrix is based on 640x480 depth maps.
            self.resize_scale = (depth_size[0] / resize[0], depth_size[1] / resize[1])
        else:
            self.resize_scale = None

        # ---------------------------------------------------------------------------- #
        # Augmentation
        # ---------------------------------------------------------------------------- #
        self.z_rot = z_rot
        self.flip = flip
        self.color_jitter = T.ColorJitter(*color_jitter) if color_jitter else None
        self.to_tensor = to_tensor

        # ---------------------------------------------------------------------------- #
        # Load cache data
        # ---------------------------------------------------------------------------- #
        # import time
        # tic = time.time()
        self._load_dataset()
        # print(time.time() - tic)

        logger = logging.getLogger(__name__)
        logger.info(str(self))
Example #10
        # vessel_model = VesselNet('./vessels/')
        image = process(image,
                        size=cfg.img_size,
                        crop='normal',
                        preprocessing='clahe',
                        fourth=None)
        image = transforms.ToPILImage()(image)

        if self.transform:
            image = self.transform(image)

        return image, label


transforms_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation((-150, 150)),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.1,
                           contrast=0.5,
                           saturation=0.1,
                           hue=0.1),
    # transforms.RandomResizedCrop(cfg.img_size_crop),
    transforms.ToTensor(),
    transforms.Normalize([0.406, 0.456, 0.485], [0.225, 0.224, 0.229]),  # ImageNet stats in reversed channel order
])

transforms_valid = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.406, 0.456, 0.485], [0.225, 0.224, 0.229])
])
Example #11
    def __init__(self, opt, phase="train"):
        # TODO: split the dataset into val and test
        if phase == "val":
            phase = "test"

        opt.load_dataset_mode = 'reader'
        super(Cifar10Dataset, self).__init__(opt, phase)

        self.data_dir = opt.cifar10_dataset_dir

        self.data_name = CIFAR10

        self.x_transforms_train = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.24705882352941178),  # = 63/255
            transforms.Resize((opt.imsize, opt.imsize)),
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408),
                                 (0.2675, 0.2565, 0.2761))
        ])

        self.x_transforms_test = transforms.Compose([
            transforms.Resize((opt.imsize, opt.imsize)),
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408),
                                 (0.2675, 0.2565, 0.2761))
        ])

        self.y_transforms = None
        if self.opt.load_dataset_mode == 'dir':

            self.data = []  # image_paths,targets
            self.label2Indices = defaultdict(list)
            image_dir = os.path.join(self.data_dir, phase)
            if not os.path.exists(image_dir):
                raise FileNotFoundError(
                    f"Image dir {image_dir} does not exist, please check it")

            self._labels = os.listdir(image_dir)
            # map each label to a target index and back
            self.label2target = dict([
                (label, target) for target, label in enumerate(self.labels)
            ])
            self.target2label = dict([
                (target, label) for target, label in enumerate(self.labels)
            ])
            for root, label_dirs, files in os.walk(image_dir):
                for file in files:
                    label = os.path.basename(root)

                    image_path = os.path.join(root, file)
                    target = self.label2target[label]

                    self.label2Indices[label].append(len(self.data))

                    self.data.append(
                        Bunch(image_path=image_path, target=target))
        elif self.opt.load_dataset_mode == 'reader':
            dataset = datasets.CIFAR10(root=os.path.join(
                self.data_dir, 'raw_data'),
                                       train=self.isTrain,
                                       download=True)
            self.data, self._labels, self.label2Indices, self.label2target, self.target2label = prepare_datas_by_standard_data(
                dataset)
        else:
            raise ValueError(
                f"Expected load_dataset_mode in [dir,reader], but got {self.opt.load_dataset_mode}"
            )
Example #12
import os
import torch
from glob import glob
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from kaolin.rep import TriangleMesh
from config import *

img_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ColorJitter(brightness=0.4, saturation=0.4, contrast=0.4),
    transforms.ToTensor()
])

vp_num = CUBOID_NUM + SPHERE_NUM + CONE_NUM


class PointMixUpDataset(Dataset):
    def __init__(self, dataset_name):
        self.dataset_path = os.path.join(DATASET_ROOT, dataset_name)
        self.rgb_paths = sorted(glob(self.dataset_path + '/rgb*.png'))
        self.silhouette_paths = sorted(
            glob(self.dataset_path + '/silhouette*.png'))
        self.obj_paths = sorted(glob(self.dataset_path + '/mesh*.obj'))

    def __len__(self) -> int:
        return len(self.rgb_paths)

    def __getitem__(self, item) -> dict:
        rgb_path = self.rgb_paths[item]
Example #13
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.trans = transforms.ColorJitter(brightness, contrast, saturation,
                                            hue)
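
This fragment shows only the constructor of a thin ColorJitter wrapper; a complete version of the pattern would also forward the call. A sketch under that assumption (the class name ColorJitterWrapper is hypothetical):

from torchvision import transforms

class ColorJitterWrapper:
    """Delegates to torchvision's ColorJitter with the same defaults."""

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.trans = transforms.ColorJitter(brightness, contrast, saturation,
                                            hue)

    def __call__(self, img):
        # img: a PIL image, or a tensor on torchvision >= 0.8
        return self.trans(img)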
Example #14
    ).to(device)


    # SETUP DATA TRANSFORMS
    if args.random:
        r = args.random
        train_transforms = transforms.Compose([
            transforms.ToTensor(),
            #transforms.RandomApply([
            #    transforms.GaussianBlur(3, sigma=(0.1, 2.0))
            #], p=0.2),
            transforms.RandomApply([
                transforms.Grayscale(num_output_channels=3)
            ], p=0.2),
            transforms.RandomApply([
                transforms.ColorJitter(brightness=r, contrast=r, saturation=r, hue=r)
            ]),
            transforms.RandomApply([
                transforms.RandomAffine(r * 10, shear=r * 10)
            ]),
            transforms.RandomResizedCrop((32, 32), scale=(1 - r, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        test_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
    else:
        train_transforms = transforms.ToTensor()
        test_transforms = transforms.ToTensor()
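
RandomApply wraps a list of transforms and applies all of them with probability p (0.5 by default). Because this pipeline calls ToTensor() first, the jitter, affine, and crop steps operate on tensors, which requires torchvision >= 0.8. A standalone sketch with a fixed jitter strength standing in for args.random:

from PIL import Image
from torchvision import transforms

r = 0.2  # assumed strength, mirroring args.random above
aug = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomApply([
        transforms.ColorJitter(brightness=r, contrast=r, saturation=r, hue=r)
    ]),  # applied with the default p=0.5
    transforms.RandomResizedCrop((32, 32), scale=(1 - r, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
])

out = aug(Image.new("RGB", (40, 40)))
print(out.shape)  # torch.Size([3, 32, 32])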