# -- Fragment (chunk truncated at both ends): evaluation-time transform pipelines. --
# The opening `transforms.Compose([` of the first pipeline and the body of
# `rotations_transform` lie outside this chunk.
# All pipelines share ImageNet normalization stats (mean/std).
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Deterministic horizontal flip (p=1.0): produces the mirrored view of every image,
# e.g. for flip-consistency / test-time-augmentation style evaluation.
hflip_data_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.RandomHorizontalFlip(p=1.0),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Darkening jitter: brightness factor sampled from [0.5, 0.9] (always darker).
darkness_jitter_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ColorJitter(brightness=[0.5, 0.9]),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Lightening jitter: brightness factor sampled from [1.1, 1.5] (always brighter).
lightness_jitter_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ColorJitter(brightness=[1.1, 1.5]),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Continues past this chunk boundary.
rotations_transform = transforms.Compose([
#cnt_class_weights = float(len(train_img_lists)) / (cfg.num_classes * label_cnt) train_img_lists = list( map(lambda x: os.path.join(cfg.data_root, cfg.train_dir, x), train_img_lists)) val_img_lists = os.listdir(os.path.join(cfg.data_root, cfg.val_dir)) val_img_lists = list( map(lambda x: os.path.join(cfg.data_root, cfg.val_dir, x), val_img_lists)) train_transforms_warm = transforms.Compose([ transforms.Resize(size=(args.img_size + 20, args.img_size + 20)), transforms.RandomCrop(size=(args.img_size, args.img_size)), transforms.RandomHorizontalFlip(), #transforms.RandomRotation((-10, 10)), transforms.ColorJitter(0.3, 0.3, 0.3), transforms.ToTensor(), transforms.Normalize(cfg.mean, cfg.std) ]) train_transforms = transforms.Compose([ transforms.Resize(size=(args.img_size + 20, args.img_size + 20)), #transforms.RandomRotation((-10, 10)), transforms.RandomCrop(size=(args.img_size, args.img_size)), transforms.RandomHorizontalFlip(), transforms.ColorJitter(0.3, 0.3, 0.3), transforms.ToTensor(), transforms.Normalize(cfg.mean, cfg.std) ]) train_transforms_no_color_aug = transforms.Compose([ transforms.Resize(size=(args.img_size + 20, args.img_size + 20)), transforms.RandomHorizontalFlip(),
def __init__(self):
    """Build the full MIL training configuration.

    Side effects: creates the log directory, instantiates the network on the
    GPU, builds datasets/dataloaders, and constructs the trainer.
    NOTE(review): indentation was reconstructed from a flattened source;
    statement grouping assumed from context — confirm against the original file.
    """
    ## The top config
    #self.data_root = '/media/hhy/data/USdata/MergePhase1/test_0.3'
    #self.log_dir = '/media/hhy/data/code_results/MILs/MIL_H_Attention'
    self.root = '/remote-home/my/Ultrasound_CV/data/Ruijin/clean'
    self.log_dir = '/remote-home/my/hhy/Ultrasound_MIL/experiments/PLN1/weighted_sampler+res18/fold4'
    if not os.path.exists(self.log_dir):
        os.makedirs(self.log_dir)

    ## training config
    self.lr = 1e-4          # Adam learning rate
    self.epoch = 50
    self.resume = -1        # <= 0 means start from scratch
    self.batch_size = 1     # one bag per batch (MIL setting)
    self.net = Res_Attention()
    self.net.cuda()
    self.optimizer = Adam(self.net.parameters(), lr=self.lr)
    self.lrsch = torch.optim.lr_scheduler.MultiStepLR(
        self.optimizer, milestones=[10, 30, 50, 70], gamma=0.5)
    self.logger = Logger(self.log_dir)

    # Train: heavy augmentation; test: resize only.
    self.train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop((224, 224)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomVerticalFlip(0.5),
        transforms.ColorJitter(0.25, 0.25, 0.25, 0.25),
        transforms.ToTensor()
    ])
    self.test_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor()])

    self.label_name = "手术淋巴结情况(0未转移;1转移)"
    # Folds 0-3 for training, fold 4 held out for validation.
    self.trainbag = RuijinBags(self.root, [0, 1, 2, 3],
                               self.train_transform,
                               label_name=self.label_name)
    self.testbag = RuijinBags(self.root, [4],
                              self.test_transform,
                              label_name=self.label_name)

    # Inverse-frequency weights: positives weighted by (1 - pos_ratio),
    # negatives by pos_ratio, to balance the sampler.
    train_label_list = list(
        map(lambda x: int(x['label']), self.trainbag.patient_info))
    pos_ratio = sum(train_label_list) / len(train_label_list)
    print(pos_ratio)
    train_weight = [(1 - pos_ratio) if x > 0 else pos_ratio
                    for x in train_label_list]
    self.train_sampler = WeightedRandomSampler(weights=train_weight,
                                               num_samples=len(
                                                   self.trainbag))
    self.train_loader = DataLoader(self.trainbag,
                                   batch_size=self.batch_size,
                                   num_workers=8,
                                   sampler=self.train_sampler)
    self.val_loader = DataLoader(self.testbag,
                                 batch_size=self.batch_size,
                                 shuffle=False,
                                 num_workers=8)

    if self.resume > 0:
        # NOTE(review): `self.loss` is read here but never assigned earlier in
        # this method — this branch would raise AttributeError if resume > 0.
        self.net, self.optimizer, self.lrsch, self.loss, self.global_step = self.logger.load(
            self.net,
            self.optimizer, self.lrsch, self.loss, self.resume)
    else:
        self.global_step = 0

    # self.trainer = MTTrainer(self.net, self.optimizer, self.lrsch, self.loss, self.train_loader, self.val_loader, self.logger, self.global_step, mode=2)
    self.trainer = MILTrainer(self.net, self.optimizer, self.lrsch, None,
                              self.train_loader, self.val_loader,
                              self.logger, self.global_step)
# -- Fragment (chunk truncated at both ends): tail of a CenterRandomCrop-style
# __repr__ (its `def` header lies outside this chunk), followed by module-level
# transform pipelines and the start of a device-selection `if`. --
interpolate_str = _pil_interpolation_to_str[self.interpolation]
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', xscale={0}'.format(
    tuple(round(s, 4) for s in self.xscale))
# NOTE(review): likely copy-paste bug — the "yscale" field iterates
# `self.xscale`, not `self.yscale`, so the repr prints xscale twice.
format_string += ', yscale={0}'.format(
    tuple(round(r, 4) for r in self.xscale))
format_string += ', interpolation={0})'.format(interpolate_str)
return format_string

# Train: rotation + custom center-biased random crop + jitter + flips.
# Val: fixed center crop then resize (deterministic).
transform = {
    'train':
    transforms.Compose([
        transforms.RandomRotation(12, resample=Image.BILINEAR),
        CenterRandomCrop(size, xscale=(0.6, 1.0), aspect_ratio=(1, 1.4)),
        transforms.ColorJitter(0.2, 0.1, 0.1, 0.04),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
    'val':
    transforms.Compose([
        transforms.CenterCrop((512, 640)),
        transforms.Resize(size, interpolation=Image.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
}
# Body of this conditional continues past the chunk boundary.
if torch.cuda.is_available():
def execute(args):
    '''
    Train the Model
    ..notes:
    Standard input image size: 2448 x 2448
    Label construction size is dependent on the input image size and the kernels used
    label_size = 196 #3x3 kernel sizes for all three decode layers and input size 284
    label_size = 180 #7x7, 5x5, 3x3 kernel sizes for decode layers and input size 284
    label_size = 172 #9x9, 6x6, 3x3 kernel sizes for decode layers and input size 284
    label_size = TBD #9x9, 6x6, 3x3 kernel sizes for decode layers and input size 1224
    label_size = 500 #9x9, 6x6, 3x3 kernel sizes for decode layers and input size 612
    label_size = 420 #7x7, 3x3, 3x3, 3x3, kernel size for decode layers and input size 612
    '''
    # NOTE(review): indentation reconstructed from a flattened source; loop and
    # branch nesting assumed from context — confirm against the original file.

    # Satellite Image Transformations
    # (interpolation=4 is a PIL numeric resampling code.)
    t = transforms.Compose([
        transforms.Resize(args.input_image_size, interpolation=4),
        transforms.ColorJitter(),
        transforms.RandomHorizontalFlip(p=0.3),
        transforms.ToTensor(),
    ])
    # Mask Transformations
    t2 = transforms.Compose([
        transforms.Resize(args.label_size),
        transforms.ToTensor(),
    ])
    data_set = landpy.MyDataLoader(args.data_dir,
                                   args.label_size,
                                   image_transforms=t,
                                   mask_transforms=t2)
    train_loader, validation_loader = landpy.create_data_loaders(
        data_set, args.training_split, args.batch_size)

    # Establish the UNet Model & Training parameters
    # (3 input channels, 7 output classes).
    unet_model = landpy.UNet(3, 7)
    if args.start_new_model == 0:
        # Resume from a previously saved state dict.
        unet_path = os.path.join(args.model_paths, f"{args.model_to_load}.pt")
        unet_model.load_state_dict(torch.load(unet_path))
    # Per-class weights for NLLLoss (7 classes).
    loss_weights = torch.tensor([
        0.145719925, 0.022623007, 0.133379898, 0.098588677, 0.36688587,
        0.222802623, 0.01
    ])
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        torch.cuda.empty_cache()
        print("GPU Enabled")
        print(f"Current GPU Memory Usage: {torch.cuda.memory_allocated()}")
        print("Making Model GPU Based")
        unet_model = unet_model.cuda()
        loss_weights = loss_weights.cuda()
    # Model emits LogSoftmax outputs below, so NLLLoss is the matching criterion.
    loss = torch.nn.NLLLoss(weight=loss_weights)
    # optimizer = torch.optim.SGD(unet_model.parameters(), lr=args.learning_rate,
    #                             momentum=args.momentum
    #                             )
    optimizer = torch.optim.Adam(unet_model.parameters(),
                                 lr=args.learning_rate)
    final_path = os.path.join(args.model_paths, f"{args.final_model_name}.pt")
    print(
        f"Number of Images for Training: {int(len(data_set)*args.training_split)}"
    )
    print(
        f"Number of Images for Validation: {int(len(data_set)*(1-args.training_split))}"
    )
    print(f"Number of Epochs Used: {args.epochs}")
    print(f"Batch Size Used: {args.batch_size}")
    print(f"Learning Rate Used: {args.learning_rate}")
    print(f"Momentum for Optimizer: {args.momentum}")
    print(f"Final Model Name: {args.final_model_name}")
    print(f"Loss Weights by Class: {loss_weights}")
    print("\n")
    epoch_losses = {}   # epoch -> {metric name: value}, dumped to CSV at the end
    checkpoint_idx = 1
    print("Begin Training")
    for epoch in trange(args.epochs):
        if use_gpu:
            torch.cuda.empty_cache()
        t0 = time.time()
        total_training_loss = 0
        # ---- training pass ----
        with torch.set_grad_enabled(True):
            for i, (batch_x_images, batch_y_mask,
                    match_y_class_mask) in enumerate(train_loader):
                unet_model.train()
                if use_gpu:
                    batch_x_images = batch_x_images.cuda()
                    match_y_class_mask = match_y_class_mask.cuda()
                batch_loss = landpy.train_step(batch_x_images,
                                               match_y_class_mask, optimizer,
                                               loss, unet_model)
                total_training_loss += batch_loss
        t1 = time.time()
        print(
            f"Total Training Loss for Epoch {epoch} is: {total_training_loss}")
        # ---- validation pass (no gradients) ----
        total_validation_loss = 0
        total_mean_iou = []
        with torch.no_grad():
            if use_gpu:
                torch.cuda.empty_cache()
            for j, (batch_val_x_images, batch_val_y_mask,
                    match_val_y_class_mask) in enumerate(validation_loader):
                unet_model.eval()
                if use_gpu:
                    batch_val_x_images = batch_val_x_images.cuda()
                    match_val_y_class_mask = match_val_y_class_mask.cuda()
                outputs = unet_model(batch_val_x_images)
                soft_max_output = torch.nn.LogSoftmax(dim=1)(outputs)
                val_batch_loss = loss(soft_max_output,
                                      match_val_y_class_mask.long())
                total_validation_loss += val_batch_loss
                batch_mean_iou = landpy.mean_IOU(soft_max_output,
                                                 match_val_y_class_mask)
                total_mean_iou.append(batch_mean_iou)
        epoch_losses[epoch] = {
            "Training Loss": total_training_loss.item(),
            "Validation Loss": total_validation_loss.item(),
            "Mean IOU": np.mean(np.array(total_mean_iou)),
            "Execution Time": (t1 - t0)
        }
        print(
            f"Total Validation Loss for Epoch {epoch} is: {total_validation_loss.item()}"
        )
        # NOTE(review): 0 % args.checkpoint == 0, so epoch 0 always writes a
        # checkpoint.
        if epoch % args.checkpoint == 0:
            # Checkpoint Save
            checkpoint_path = os.path.join(
                args.model_paths,
                f"{args.final_model_name}_chp_{checkpoint_idx}.pt")
            torch.save(unet_model.state_dict(), checkpoint_path)
            checkpoint_idx += 1
        print("\n")
    print("Completed Training; Saving Model")
    torch.save(unet_model.state_dict(), final_path)
    print("Saving Epoch Losses to DF")
    epoch_losses_path = os.path.join(
        args.epoch_loss_dir, args.final_model_name + "_epoch_losses.csv")
    df = pd.DataFrame.from_dict(epoch_losses, orient='index')
    df.to_csv(epoch_losses_path)
def get_dataloaders(dataset,
                    batch,
                    dataroot,
                    split=0.15,
                    split_idx=0,
                    multinode=False,
                    target_lb=-1):
    """Build (train_sampler, trainloader, validloader, testloader) for a dataset.

    Args:
        dataset: dataset name ('cifar10', 'reduced_cifar10', 'cifar100', 'svhn',
            'reduced_svhn', 'imagenet', 'reduced_imagenet').
        batch: batch size for all three loaders.
        dataroot: root directory for dataset downloads / ImageNet folders.
        split: fraction of the train set held out for validation (0 disables).
        split_idx: which of the 5 stratified folds to use.
        multinode: wrap the train sampler in a DistributedSampler.
        target_lb: if >= 0, restrict train/valid indices to this single label.

    Augmentation policy and cutout are read from the global config object `C`.
    NOTE(review): indentation reconstructed from a flattened source — confirm
    nesting against the original file.
    """
    if 'cifar' in dataset or 'svhn' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
    elif 'imagenet' in dataset:
        input_size = 224
        sized_size = 256
        if 'efficientnet' in C.get()['model']['type']:
            # EfficientNet variants use their own canonical input resolution.
            input_size = EfficientNet.get_image_size(C.get()['model']['type'])
            sized_size = input_size + 32  # TODO
            # sized_size = int(round(input_size / 224. * 256))
            # sized_size = input_size
            logger.info('size changed to %d/%d.' % (input_size, sized_size))
        transform_train = transforms.Compose([
            EfficientNetRandomCrop(input_size),
            transforms.Resize((input_size, input_size),
                              interpolation=Image.BICUBIC),
            # transforms.RandomResizedCrop(input_size, scale=(0.1, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
            ),
            transforms.ToTensor(),
            # PCA-based lighting noise (AlexNet-style).
            Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        transform_test = transforms.Compose([
            EfficientNetCenterCrop(input_size),
            transforms.Resize((input_size, input_size),
                              interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    else:
        raise ValueError('dataset=%s' % dataset)

    # Prepend the configured augmentation policy (if any) to the train pipeline.
    total_aug = augs = None
    if isinstance(C.get()['aug'], list):
        logger.debug('augmentation provided.')
        transform_train.transforms.insert(0, Augmentation(C.get()['aug']))
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'uniformaugment':
            transform_train.transforms.insert(0, UniformAugment())
        elif C.get()['aug'] in ['default']:
            pass
        else:
            raise ValueError('not found augmentations. %s' % C.get()['aug'])
    # Cutout runs last (after ToTensor/Normalize).
    if C.get()['cutout'] > 0:
        transform_train.transforms.append(CutoutDefault(C.get()['cutout']))

    # ---- dataset construction ----
    if dataset == 'cifar10':
        total_trainset = torchvision.datasets.CIFAR10(
            root=dataroot, train=True, download=True,
            transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root=dataroot,
                                               train=False,
                                               download=True,
                                               transform=transform_test)
    elif dataset == 'reduced_cifar10':
        total_trainset = torchvision.datasets.CIFAR10(
            root=dataroot, train=True, download=True,
            transform=transform_train)
        # Keep a stratified 4000-example subset (50000 - 46000).
        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=46000,
                                     random_state=0)  # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets
        testset = torchvision.datasets.CIFAR10(root=dataroot,
                                               train=False,
                                               download=True,
                                               transform=transform_test)
    elif dataset == 'cifar100':
        total_trainset = torchvision.datasets.CIFAR100(
            root=dataroot, train=True, download=True,
            transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root=dataroot,
                                                train=False,
                                                download=True,
                                                transform=transform_test)
    elif dataset == 'svhn':
        # SVHN 'train' + 'extra' are concatenated for training.
        trainset = torchvision.datasets.SVHN(root=dataroot,
                                             split='train',
                                             download=True,
                                             transform=transform_train)
        extraset = torchvision.datasets.SVHN(root=dataroot,
                                             split='extra',
                                             download=True,
                                             transform=transform_train)
        total_trainset = ConcatDataset([trainset, extraset])
        testset = torchvision.datasets.SVHN(root=dataroot,
                                            split='test',
                                            download=True,
                                            transform=transform_test)
    elif dataset == 'reduced_svhn':
        total_trainset = torchvision.datasets.SVHN(root=dataroot,
                                                   split='train',
                                                   download=True,
                                                   transform=transform_train)
        # Keep a stratified 1000-example subset of the 73257 train images.
        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=73257 - 1000,
                                     random_state=0)  # 1000 trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets
        testset = torchvision.datasets.SVHN(root=dataroot,
                                            split='test',
                                            download=True,
                                            transform=transform_test)
    elif dataset == 'imagenet':
        total_trainset = ImageNet(root=os.path.join(dataroot,
                                                    'imagenet-pytorch'),
                                  transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'),
                           split='val',
                           transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]
    elif dataset == 'reduced_imagenet':
        # randomly chosen indices
        # idx120 = sorted(random.sample(list(range(1000)), k=120))
        idx120 = [
            16, 23, 52, 57, 76, 93, 95, 96, 99, 121, 122, 128, 148, 172, 181,
            189, 202, 210, 232, 238, 257, 258, 259, 277, 283, 289, 295, 304,
            307, 318, 322, 331, 337, 338, 345, 350, 361, 375, 376, 381, 388,
            399, 401, 408, 424, 431, 432, 440, 447, 462, 464, 472, 483, 497,
            506, 512, 530, 541, 553, 554, 557, 564, 570, 584, 612, 614, 619,
            626, 631, 632, 650, 657, 658, 660, 674, 675, 680, 682, 691, 695,
            699, 711, 734, 736, 741, 754, 757, 764, 769, 770, 780, 781, 787,
            797, 799, 811, 822, 829, 830, 835, 837, 842, 843, 845, 873, 883,
            897, 900, 902, 905, 913, 920, 925, 937, 938, 940, 941, 944, 949,
            959
        ]
        total_trainset = ImageNet(root=os.path.join(dataroot,
                                                    'imagenet-pytorch'),
                                  transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'),
                           split='val',
                           transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]

        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=len(total_trainset) - 50000,
                                     random_state=0)  # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        train_idx, valid_idx = next(sss)

        # filter out indices whose label is not one of the chosen 120 classes
        train_idx = list(
            filter(lambda x: total_trainset.labels[x] in idx120, train_idx))
        valid_idx = list(
            filter(lambda x: total_trainset.labels[x] in idx120, valid_idx))
        test_idx = list(
            filter(lambda x: testset.samples[x][1] in idx120,
                   range(len(testset))))

        # Remap original 1000-class labels to the dense 0-119 range.
        targets = [
            idx120.index(total_trainset.targets[idx]) for idx in train_idx
        ]
        for idx in range(len(total_trainset.samples)):
            if total_trainset.samples[idx][1] not in idx120:
                continue
            total_trainset.samples[idx] = (total_trainset.samples[idx][0],
                                           idx120.index(
                                               total_trainset.samples[idx][1]))
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        for idx in range(len(testset.samples)):
            if testset.samples[idx][1] not in idx120:
                continue
            testset.samples[idx] = (testset.samples[idx][0],
                                    idx120.index(testset.samples[idx][1]))
        testset = Subset(testset, test_idx)
        print('reduced_imagenet train=', len(total_trainset))
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    # NOTE(review): total_aug/augs are always None here (never reassigned
    # above), so this branch is currently dead code.
    if total_aug is not None and augs is not None:
        total_trainset.set_preaug(augs, total_aug)
        print('set_preaug-')

    # ---- train/valid split + samplers ----
    train_sampler = None
    if split > 0.0:
        sss = StratifiedShuffleSplit(n_splits=5,
                                     test_size=split,
                                     random_state=0)
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        # Advance the fold iterator to the requested split index.
        for _ in range(split_idx + 1):
            train_idx, valid_idx = next(sss)

        if target_lb >= 0:
            train_idx = [
                i for i in train_idx if total_trainset.targets[i] == target_lb
            ]
            valid_idx = [
                i for i in valid_idx if total_trainset.targets[i] == target_lb
            ]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx)

        if multinode:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                Subset(total_trainset, train_idx),
                num_replicas=dist.get_world_size(),
                rank=dist.get_rank())
    else:
        valid_sampler = SubsetSampler([])

        if multinode:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                total_trainset,
                num_replicas=dist.get_world_size(),
                rank=dist.get_rank())
            logger.info(
                f'----- dataset with DistributedSampler {dist.get_rank()}/{dist.get_world_size()}'
            )

    # shuffle must be False whenever an explicit sampler is supplied.
    trainloader = torch.utils.data.DataLoader(
        total_trainset,
        batch_size=batch,
        shuffle=True if train_sampler is None else False,
        num_workers=8,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)
    validloader = torch.utils.data.DataLoader(total_trainset,
                                              batch_size=batch,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True,
                                              sampler=valid_sampler,
                                              drop_last=False)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=batch,
                                             shuffle=False,
                                             num_workers=8,
                                             pin_memory=True,
                                             drop_last=False)
    return train_sampler, trainloader, validloader, testloader
def get_transforms(config, image_size=None):
    """Build {'train': ..., 'val': ...} torchvision pipelines from a config.

    Args:
        config: configuration object exposing ``get_dictionary()``; keys used:
            'estimator', 'aug', 'auto_aug', 'jitter', 'affine', 'random_flip'
            and the associated jitter/affine hyperparameters.
        image_size: explicit target size; when None it is looked up in
            ``resize_size_dict`` by estimator name (falling back to 32).

    Returns:
        dict with 'train' and 'val' Compose pipelines. The 'val' pipeline is
        always deterministic (resize + center crop + tensor).
    """
    config = config.get_dictionary()
    if image_size is None:
        image_size = resize_size_dict.get(config['estimator'], 32)

    # Deterministic pipeline shared by every 'val' entry.
    val_transforms = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])

    # No augmentation requested: train looks exactly like val
    # (built as a separate Compose object).
    if not parse_bool(config['aug']):
        return {
            'train':
            transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
            ]),
            'val': val_transforms,
        }

    # Auto-augment mode: fixed crop/flip recipe.
    if parse_bool(config['auto_aug']):
        # from .transforms import AutoAugment
        return {
            'train':
            transforms.Compose([
                # AutoAugment(),
                transforms.Resize(image_size),
                transforms.RandomCrop(image_size,
                                      padding=int(image_size / 8)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]),
            'val': val_transforms,
        }

    # Manual augmentation: assemble the step list from individual flags.
    steps = []
    if parse_bool(config['jitter']):
        steps.append(
            transforms.ColorJitter(brightness=config['brightness'],
                                   saturation=config['saturation'],
                                   hue=config['hue']))
    if parse_bool(config['affine']):
        steps.append(
            transforms.RandomAffine(degrees=config['degree'],
                                    shear=config['shear']))
    steps.append(transforms.RandomResizedCrop(image_size))
    steps.append(transforms.RandomCrop(image_size, padding=4))
    if parse_bool(config['random_flip']):
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    return {'train': transforms.Compose(steps), 'val': val_transforms}
def __init__(self,
             path_root,
             t_task,
             n_way,
             k_shot,
             k_query,
             x_dim,
             split,
             augment='0',
             test=None,
             shuffle=True,
             fetch_global=False):
    """Few-shot episodic dataset backed by an LMDB cache.

    Args:
        path_root: dataset root; images under <root>/images, LMDB under
            <root>/lmdb_data/<split>.lmdb (generated on first use).
        t_task: number of tasks per batch of episodes.
        n_way / k_shot / k_query: episode layout (classes, support and query
            samples per class).
        x_dim: comma-separated string, parsed to ints; first two values are
            the target H, W.
        split: dataset split name ('train', ...).
        augment: '0' = resize only; '1' = crop/flip(/jitter for train).
        test: unused here.
        shuffle / fetch_global: stored flags for use elsewhere in the class.

    NOTE(review): if augment is neither '0' nor '1', self.transform is never
    assigned — later use would raise AttributeError.
    """
    self.t_task = t_task
    self.n_way = n_way
    self.k_shot = k_shot
    self.k_query = k_query
    self.x_dim = list(map(int, x_dim.split(',')))
    self.split = split
    self.shuffle = shuffle
    self.path_root = path_root
    self.fet_global = fetch_global

    if augment == '0':
        # No augmentation: f1 loads/converts the image, then resize + normalize.
        self.transform = transforms.Compose([
            transforms.Lambda(f1),
            transforms.Resize(self.x_dim[:2]),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))
        ])
    elif augment == '1':
        if self.split == 'train':
            # Train: oversize by 20px, random crop back, flip + color jitter.
            self.transform = transforms.Compose([
                # lambda x: Image.open(x).convert('RGB'),
                transforms.Lambda(f1),
                transforms.Resize(
                    (self.x_dim[0] + 20, self.x_dim[1] + 20)),
                transforms.RandomCrop(self.x_dim[:2]),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=.1,
                                       contrast=.1,
                                       saturation=.1,
                                       hue=.1),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))
            ])
        else:
            # Eval splits: same geometry but no flip/jitter; note the crop is
            # still random here.
            self.transform = transforms.Compose([
                # lambda x: Image.open(x).convert('RGB'),
                transforms.Lambda(f1),
                transforms.Resize(
                    (self.x_dim[0] + 20, self.x_dim[1] + 20)),
                transforms.RandomCrop(self.x_dim[:2]),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))
            ])

    self.path = os.path.join(path_root, 'images')
    self.lmdb_file = os.path.join(path_root, "lmdb_data",
                                  "%s.lmdb" % self.split)
    if not os.path.exists(self.lmdb_file):
        print("lmdb_file is not found, start to generate %s" %
              self.lmdb_file)
        self._generate_lmdb()

    # read lmdb_file: metadata entries are pyarrow-serialized under magic keys.
    self.env = lmdb.open(self.lmdb_file,
                         subdir=False,
                         readonly=True,
                         lock=False,
                         readahead=False,
                         meminit=False)
    with self.env.begin(write=False) as txn:
        self.total_sample = pyarrow.deserialize(txn.get(b'__len__'))
        self.keys = pyarrow.deserialize(txn.get(b'__keys__'))
        self.label2num = pyarrow.deserialize(txn.get(b'__label2num__'))
        self.num2label = pyarrow.deserialize(txn.get(b'__num2label__'))
    self.image_labels = [i.decode() for i in self.keys]
    self.total_cls = len(self.num2label)
    # Group sample keys by their 9-char prefix — presumably the class id
    # encoded in the key; verify against the LMDB key format.
    self.dic_img_label = defaultdict(list)
    for i in self.image_labels:
        self.dic_img_label[i[:9]].append(i)

    self.support_set_size = self.n_way * self.k_shot  # num of samples per support set
    self.query_set_size = self.n_way * self.k_query
    self.episode = self.total_sample // (
        self.t_task *
        (self.support_set_size + self.query_set_size))  # how many episode

    # On Windows the env handle cannot be shared with worker processes,
    # so it is dropped here (presumably reopened lazily per worker).
    if platform.system().lower() == 'windows':
        self.platform = "win"
        del self.env
    elif platform.system().lower() == 'linux':
        self.platform = "linux"
def __init__(self,
             cache_dir,
             image_dir,
             split,
             chunk_size=(1.5, 1.5),
             chunk_thresh=0.3,
             chunk_margin=(0.2, 0.2),
             nb_pts=-1,
             num_rgbd_frames=0,
             resize=(160, 120),
             image_normalizer=None,
             k=3,
             z_rot=None,
             flip=0.0,
             color_jitter=None,
             to_tensor=False,
             ):
    """
    Args:
        cache_dir (str): path to cache of 3D point clouds, 3D semantic labels and RGB-D overlap info
        image_dir (str): path to 2D images, depth maps and poses
        split: dataset split name, resolved through self.split_map
        chunk_size (tuple): xy chunk size
        chunk_thresh (float): minimum number of labeled points within a chunk
        chunk_margin (tuple): margin to calculate ratio of labeled points within a chunk
        nb_pts (int): number of points to resample in a chunk
        num_rgbd_frames (int): number of RGB-D frames to choose
        resize (tuple): target image size
        image_normalizer (tuple, optional): (mean, std)
        k (int): k-nn unprojected neighbors of target points
        z_rot (tuple, optional): range of rotation (degree instead of rad)
        flip (float): probability to flip horizontally
        color_jitter (tuple, optional): parameters of color jitter
        to_tensor (bool): whether to convert to torch.Tensor
    """
    super(ScanNet2D3DChunks, self).__init__()
    # cache: pickle files containing point clouds, 3D labels and rgbd overlap
    self.cache_dir = cache_dir
    # includes color, depth, 2D label
    self.image_dir = image_dir

    # load split: one scan id per line in the split file.
    self.split = split
    with open(osp.join(self.split_dir, self.split_map[split]), 'r') as f:
        self.scan_ids = [line.rstrip() for line in f.readlines()]

    # ---------------------------------------------------------------------------- #
    # Build label mapping
    # ---------------------------------------------------------------------------- #
    # read tsv file to get raw to nyu40 mapping (dict)
    self.raw_to_nyu40_mapping = read_label_mapping(self.label_id_tsv_path,
                                                   label_from='id',
                                                   label_to='nyu40id',
                                                   as_int=True)
    # Dense lookup array: raw label id -> nyu40 id (0 for unmapped ids).
    self.raw_to_nyu40 = np.zeros(max(self.raw_to_nyu40_mapping.keys()) + 1,
                                 dtype=np.int64)
    for key, value in self.raw_to_nyu40_mapping.items():
        self.raw_to_nyu40[key] = value

    # scannet: 20 evaluation classes
    self.scannet_mapping = load_class_mapping(self.scannet_classes_path)
    assert len(self.scannet_mapping) == 20
    # nyu40 -> scannet (unmapped nyu40 ids get self.ignore_value)
    self.nyu40_to_scannet = np.full(shape=41,
                                    fill_value=self.ignore_value,
                                    dtype=np.int64)
    self.nyu40_to_scannet[list(self.scannet_mapping.keys())] = np.arange(
        len(self.scannet_mapping))
    # scannet -> nyu40 (trailing 0 maps the ignore slot back to nyu40 id 0)
    self.scannet_to_nyu40 = np.array(list(self.scannet_mapping.keys()) + [0],
                                     dtype=np.int64)
    # raw -> scannet, composed from the two lookups above
    self.raw_to_scannet = self.nyu40_to_scannet[self.raw_to_nyu40]
    self.class_names = tuple(self.scannet_mapping.values())

    # ---------------------------------------------------------------------------- #
    # 3D
    # ---------------------------------------------------------------------------- #
    # The height / z-axis is ignored in fact.
    self.chunk_size = np.array(chunk_size, dtype=np.float32)
    self.chunk_thresh = chunk_thresh
    self.chunk_margin = np.array(chunk_margin, dtype=np.float32)
    self.nb_pts = nb_pts

    # ---------------------------------------------------------------------------- #
    # 2D
    # ---------------------------------------------------------------------------- #
    self.num_rgbd_frames = num_rgbd_frames
    self.resize = resize
    self.image_normalizer = image_normalizer

    # ---------------------------------------------------------------------------- #
    # 2D-3D
    # ---------------------------------------------------------------------------- #
    self.k = k
    if num_rgbd_frames > 0 and resize:
        depth_size = (640, 480)  # intrinsic matrix is based on 640x480 depth maps.
        # Scale factors to map resized-image pixel coords back to depth coords.
        self.resize_scale = (depth_size[0] / resize[0],
                             depth_size[1] / resize[1])
    else:
        self.resize_scale = None

    # ---------------------------------------------------------------------------- #
    # Augmentation
    # ---------------------------------------------------------------------------- #
    self.z_rot = z_rot
    self.flip = flip
    self.color_jitter = T.ColorJitter(*color_jitter) if color_jitter else None
    self.to_tensor = to_tensor

    # ---------------------------------------------------------------------------- #
    # Load cache data
    # ---------------------------------------------------------------------------- #
    # import time
    # tic = time.time()
    self._load_dataset()
    # print(time.time() - tic)

    logger = logging.getLogger(__name__)
    logger.info(str(self))
# vessel_model = VesselNet('./vessels/') image = process(image, size=cfg.img_size, crop='normal', preprocessing='clahe', fourth=None) image = transforms.ToPILImage()(image) if self.transform: image = self.transform(image) return image, label transforms_train = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomRotation((-150, 150)), transforms.RandomVerticalFlip(), transforms.ColorJitter(brightness=0.1, contrast=0.5, saturation=0.1, hue=0.1), # transforms.RandomResizedCrop(cfg.img_size_crop), transforms.ToTensor(), transforms.Normalize([0.406, 0.456, 0.485], [0.225, 0.224, 0.229]), ]) transforms_valid = transforms.Compose([ transforms.ToTensor(), transforms.Normalize([0.406, 0.456, 0.485], [0.225, 0.224, 0.229]) ])
def __init__(self, opt, phase="train"):
    """CIFAR-10 dataset wrapper supporting two loading modes.

    'dir' mode walks a directory tree of per-class image folders; 'reader'
    mode downloads/loads CIFAR-10 via torchvision and converts it with
    prepare_datas_by_standard_data.

    NOTE(review): nesting reconstructed from a flattened source — the
    `opt.load_dataset_mode = 'reader'` assignment is assumed to be inside the
    val branch (otherwise the 'dir' branch below would be unreachable);
    confirm against the original file.
    """
    # TODO split the dataset of val and test
    if phase == "val":
        phase = "test"
        opt.load_dataset_mode = 'reader'
    super(Cifar10Dataset, self).__init__(opt, phase)
    self.data_dir = opt.cifar10_dataset_dir
    self.data_name = CIFAR10

    # Train: flip + brightness jitter; test: deterministic resize only.
    # Both normalize with CIFAR statistics.
    self.x_transforms_train = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.24705882352941178),
        transforms.Resize((opt.imsize, opt.imsize)),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761))
    ])
    self.x_transforms_test = transforms.Compose([
        transforms.Resize((opt.imsize, opt.imsize)),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761))
    ])
    self.y_transforms = None

    if self.opt.load_dataset_mode == 'dir':
        self.data = []  # image_paths,targets
        self.label2Indices = defaultdict(list)
        image_dir = os.path.join(self.data_dir, phase)
        # NOTE(review): os.listdir runs before the os.path.exists guard below,
        # so a missing directory raises here and the explicit
        # FileNotFoundError is never reached.
        self._labels = os.listdir(image_dir)
        # get label to targets, dict type
        self.label2target = dict([
            (label, target) for target, label in enumerate(self.labels)
        ])
        self.target2label = dict([
            (target, label) for target, label in enumerate(self.labels)
        ])
        if not os.path.exists(image_dir):
            raise FileNotFoundError(
                f"Image Dir {image_dir} not exists, please check it")

        # Each leaf directory name is the class label of the files inside it.
        for root, label_dirs, files in os.walk(image_dir):
            for file in files:
                label = os.path.basename(root)
                image_path = os.path.join(root, file)
                target = self.label2target[label]
                self.label2Indices[label].append(len(self.data))
                self.data.append(
                    Bunch(image_path=image_path, target=target))
    elif self.opt.load_dataset_mode == 'reader':
        dataset = datasets.CIFAR10(root=os.path.join(
            self.data_dir, 'raw_data'),
                                   train=self.isTrain,
                                   download=True)
        self.data, self._labels, self.label2Indices, self.label2target, self.target2label = prepare_datas_by_standard_data(
            dataset)
    else:
        raise ValueError(
            f"Expected load_dataset_mode in [dir,reader], but got {self.opt.load_dataset_mode}"
        )
import os
import torch
from glob import glob
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from kaolin.rep import TriangleMesh
from config import *

# Shared image transform: resize + color jitter + tensor conversion.
img_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ColorJitter(brightness=0.4, saturation=0.4, contrast=0.4),
    transforms.ToTensor()
])
# Total number of volumetric primitives (cuboids + spheres + cones, from config).
vp_num = CUBOID_NUM + SPHERE_NUM + CONE_NUM


class PointMixUpDataset(Dataset):
    """Dataset pairing rendered RGB/silhouette images with mesh .obj files.

    File triples are matched by sorted glob order over the dataset directory —
    presumably the rgb*/silhouette*/mesh* files share index-aligned names;
    verify against the dataset layout.
    """

    def __init__(self, dataset_name):
        self.dataset_path = os.path.join(DATASET_ROOT, dataset_name)
        self.rgb_paths = sorted(glob(self.dataset_path + '/rgb*.png'))
        self.silhouette_paths = sorted(
            glob(self.dataset_path + '/silhouette*.png'))
        self.obj_paths = sorted(glob(self.dataset_path + '/mesh*.obj'))

    def __len__(self) -> int:
        return len(self.rgb_paths)

    # NOTE(review): __getitem__ body is truncated at this chunk boundary.
    def __getitem__(self, item) -> dict:
        rgb_path = self.rgb_paths[item]
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
    """Thin wrapper around torchvision's ColorJitter.

    Args:
        brightness / contrast / saturation / hue: jitter strengths, forwarded
            unchanged to ``transforms.ColorJitter``; all default to 0 (no-op).
    """
    # Keep the configured jitter transform for later application.
    self.trans = transforms.ColorJitter(brightness=brightness,
                                        contrast=contrast,
                                        saturation=saturation,
                                        hue=hue)
# -- Fragment (chunk truncated at the start): end of a model construction call,
# then CIFAR-style (32x32) train/test transform setup. --
).to(device)

# SETUP DATA TRANSFORMS
# With --random r > 0, each augmentation's strength scales with r; ToTensor runs
# first, so the subsequent jitter/affine/crop transforms operate on tensors.
if args.random:
    r = args.random
    train_transforms = transforms.Compose([
        transforms.ToTensor(),
        #transforms.RandomApply([
        #    transforms.GaussianBlur(3, sigma=(0.1, 2.0))
        #], p=0.2),
        transforms.RandomApply([
            transforms.Grayscale(num_output_channels=3)
        ], p=0.2),
        # RandomApply without p uses its default probability.
        transforms.RandomApply([
            transforms.ColorJitter(brightness=r,
                                   contrast=r,
                                   saturation=r,
                                   hue=r)
        ]),
        transforms.RandomApply([
            transforms.RandomAffine(r*10, shear=r*10)
        ]),
        transforms.RandomResizedCrop((32,32), scale=(1-r, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
else:
    # NOTE(review): the no-augmentation path skips Normalize entirely, so
    # tensor ranges differ between the two configurations — confirm intended.
    train_transforms = transforms.ToTensor()
    test_transforms = transforms.ToTensor()