Code Example #1
    def get_data_loader(self, data_file,
                        aug):  #parameters that would change on train/val set
        #         transform = self.trans_loader.get_composed_transform(aug) # TODO: maybe change here?
        pre_transform = self.trans_loader.get_crop_transform(aug)
        aug_transform = self.trans_loader.get_aug_transform(
            self.aug_type, aug_target=self.aug_target)
        post_transform = self.trans_loader.get_hdf5_transform(aug)
        dataset = AugSetDataset(data_file,
                                self.batch_size,
                                pre_transform=pre_transform,
                                post_transform=post_transform,
                                aug_target=self.aug_target)
        dataset.set_aug_transform(aug_transform)
        self.dataset = dataset
        sampler = EpisodicBatchSampler(
            len(dataset), self.n_way,
            self.n_episode)  # sample classes randomly
        collate_fn = self.get_collate(
        )  # to get different transform for every batch
        num_workers = 0 if self.aug_target == 'batch' else 12  # BUGFIX: multiprocessing breaks when aug_target == 'batch', since only that target needs set_aug_transform re-aligned for every batch; other targets can still use multiple workers
        data_loader_params = dict(batch_sampler=sampler,
                                  num_workers=num_workers,
                                  pin_memory=True,
                                  collate_fn=collate_fn)

        data_loader = torch.utils.data.DataLoader(dataset,
                                                  **data_loader_params)
        return data_loader
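
Example #1 is the only snippet that installs a custom collate_fn, and its comment says the point is to get a different transform for every batch; together with dataset.set_aug_transform and the num_workers note above, that suggests a collate hook which swaps the dataset's augmentation between episodes. The following is a purely illustrative sketch under that assumption, not the project's actual get_collate:

from torch.utils.data.dataloader import default_collate

def get_collate(self):
    # Illustrative sketch only: draw a fresh augmentation and install it on the
    # dataset before delegating to PyTorch's default collate, so the following
    # episode is assembled with a newly sampled transform.  Mutating the dataset
    # from inside collate only works in-process, which matches num_workers = 0
    # for aug_target == 'batch' above.
    def collate(batch):
        aug_transform = self.trans_loader.get_aug_transform(
            self.aug_type, aug_target=self.aug_target)
        self.dataset.set_aug_transform(aug_transform)
        return default_collate(batch)

    return collate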
Code Example #2
 def get_data_loader(self, data_file, aug):
     transform = self.trans_loader.get_composed_transform(aug)
     dataset = SetDataset( data_file , self.batch_size, transform )
     sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide )
     data_loader_params = dict(batch_sampler = sampler,  num_workers = 12, pin_memory = True)
     data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
     return data_loader
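
Every snippet on this page passes an EpisodicBatchSampler as the DataLoader's batch_sampler, so each "batch" the loader yields is one episode: n_way class indices drawn at random, which the SetDataset then expands into one per-class batch of images. As a reference point, a minimal sketch of a sampler with the three-argument signature used here (number of classes, n_way, number of episodes) could look like this:

import torch

class EpisodicBatchSampler:
    """Yield n_way randomly chosen class indices, once per episode."""

    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        # the DataLoader sees one "batch" (episode) per iteration
        return self.n_episodes

    def __iter__(self):
        for _ in range(self.n_episodes):
            # sample n_way distinct classes for this episode
            yield torch.randperm(self.n_classes)[:self.n_way]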
Code Example #3
 def get_data_loader(self, data_file, aug): #parameters that would change on train/val set
   transform = self.trans_loader.get_composed_transform(aug)
   if isinstance(data_file, list):
     dataset = MultiSetDataset( data_file , self.batch_size, transform )
     sampler = MultiEpisodicBatchSampler(dataset.lens(), self.n_way, self.n_eposide )
   else:
     dataset = SetDataset( data_file , self.batch_size, transform )
     sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide )
   data_loader_params = dict(batch_sampler = sampler,  num_workers=4)
   data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
   return data_loader
Code Example #4
 def get_data_loader(self,
                     aug):  # parameters that would change on train/val set
     transform = self.trans_loader.get_composed_transform(aug)
     dataset = SetDataset(self.data_file, self.dataset_dir, self.batch_size,
                          transform)
     sampler = EpisodicBatchSampler(self.mode, len(dataset), self.n_way,
                                    self.n_episode)
     data_loader_params = dict(batch_sampler=sampler,
                               num_workers=8,
                               pin_memory=True)
     data_loader = torch.utils.data.DataLoader(dataset,
                                               **data_loader_params)
     return data_loader
Code Example #5
 def get_data_loader(self, data_file, aug): #parameters that would change on train/val set
   transform = self.trans_loader.get_composed_transform(aug)
   if isinstance(data_file, list):
     dataset = MultiSetDataset( data_file , self.batch_size, transform )  # contains multiple sub_dataloaders (as many as the sum of each dataset's relation-class count); each class inside is its own DataLoader
     sampler = MultiEpisodicBatchSampler(dataset.lens(), self.n_way, self.n_eposide )  # dataset.lens() is a list of the number of relation classes in each dataset
   else:
     dataset = SetDataset( data_file , self.batch_size, transform )
     # the dataset holds: 1) the label list cl_list (0..N-1), 2) the source metadata meta, 3) sub_dataloader, one DataLoader per class,
     # 4) sub_meta, the samples belonging to each class; len(dataset) is the number of labels
     sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide )
   data_loader_params = dict(batch_sampler = sampler,  num_workers=4)
   data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
   return data_loader
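
The comments in this example spell out the SetDataset layout: the label list cl_list, the raw meta data, a per-class sub_meta, and one inner DataLoader per class, with len(dataset) equal to the number of classes. A hedged sketch of that structure, assuming the common few-shot filelist format (a JSON file with image_names and image_labels keys), might be:

import json
import numpy as np
import torch
from PIL import Image

class SubDataset(torch.utils.data.Dataset):
    """All images of a single class; the inner DataLoader batches these."""

    def __init__(self, sub_meta, cl, transform):
        self.sub_meta = sub_meta  # 4) image paths belonging to class cl
        self.cl = cl
        self.transform = transform

    def __getitem__(self, i):
        img = self.transform(Image.open(self.sub_meta[i]).convert('RGB'))
        return img, self.cl

    def __len__(self):
        return len(self.sub_meta)

class SetDataset(torch.utils.data.Dataset):
    """One item per class: __getitem__(i) returns a shuffled batch from class i."""

    def __init__(self, data_file, batch_size, transform):
        with open(data_file) as f:
            self.meta = json.load(f)  # 2) source metadata
        self.cl_list = np.unique(self.meta['image_labels']).tolist()  # 1) label list 0..N-1

        self.sub_meta = {cl: [] for cl in self.cl_list}
        for path, label in zip(self.meta['image_names'], self.meta['image_labels']):
            self.sub_meta[label].append(path)

        # 3) one inner DataLoader per class; batch_size = n_support + n_query
        self.sub_dataloader = [
            torch.utils.data.DataLoader(SubDataset(self.sub_meta[cl], cl, transform),
                                        batch_size=batch_size, shuffle=True,
                                        num_workers=0, pin_memory=False)
            for cl in self.cl_list
        ]

    def __getitem__(self, i):
        # a fresh shuffled batch of images from class i
        return next(iter(self.sub_dataloader[i]))

    def __len__(self):
        return len(self.cl_list)  # number of classes == number of labels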
Code Example #6
 def get_data_loader(self, data,
                     aug):  #parameters that would change on train/val set
     #         transform = lambda x:x
     dataset = VirtualSetDataset(data, self.batch_size)  #, transform )
     sampler = EpisodicBatchSampler(
         len(dataset), self.n_way,
         self.n_episode)  # sample classes randomly
     data_loader_params = dict(batch_sampler=sampler,
                               num_workers=12,
                               pin_memory=True)
     data_loader = torch.utils.data.DataLoader(dataset,
                                               **data_loader_params)
     return data_loader
Code Example #7
File: datamgr.py  Project: parsatorb/fsl_ssl
    def get_data_loader(self, data_file,
                        aug):  #parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug)

        ## Add transform for jigsaw puzzle
        self.transform_patch_jigsaw = None
        self.transform_jigsaw = None
        if self.jigsaw:
            if aug:
                self.transform_jigsaw = transforms.Compose([
                    # transforms.Resize(256),
                    # transforms.CenterCrop(225),
                    ## follow paper setting:
                    # transforms.Resize(255),
                    # transforms.CenterCrop(240),
                    ## setting of my experiment before 0515
                    transforms.RandomResizedCrop(255, scale=(0.5, 1.0)),
                    transforms.RandomHorizontalFlip()
                ])
                # transforms.ToTensor(),
                # transforms.Normalize(mean=[0.485, 0.456, 0.406],
                #                      std =[0.229, 0.224, 0.225])])
            else:
                self.transform_jigsaw = transforms.Compose([
                    # transforms.Resize(256),
                    # transforms.CenterCrop(225),])
                    transforms.RandomResizedCrop(225, scale=(0.5, 1.0))
                ])
                # transforms.ToTensor(),
                # transforms.Normalize(mean=[0.485, 0.456, 0.406],
                #                      std =[0.229, 0.224, 0.225])])
            self.transform_patch_jigsaw = transforms.Compose([
                transforms.RandomCrop(64),
                # transforms.Resize((75, 75), Image.BILINEAR),
                transforms.Lambda(self.rgb_jittering),
                transforms.ToTensor(),
                # transforms.Normalize(mean=[0.485, 0.456, 0.406],
                # std =[0.229, 0.224, 0.225])
            ])

        dataset = SetDataset(data_file , self.batch_size, transform, jigsaw=self.jigsaw, \
                            transform_jigsaw=self.transform_jigsaw, transform_patch_jigsaw=self.transform_patch_jigsaw, \
                            rotation=self.rotation, isAircraft=self.isAircraft, grey=self.grey)
        sampler = EpisodicBatchSampler(len(dataset), self.n_way,
                                       self.n_eposide)
        data_loader_params = dict(batch_sampler=sampler,
                                  num_workers=NUM_WORKERS,
                                  pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  **data_loader_params)
        return data_loader
Code Example #8
 def get_data_loader(self, data_file, aug):
     transform = self.trans_loader.get_composed_transform(aug)
     if isinstance(data_file, list):  # Multi domain
         dataset = MultiSetDataset(data_file, self.batch_size, transform)
         sampler = MultiEpisodicBatchSampler(dataset.lens(), self.n_way,
                                             self.n_eposide)
     else:  # Single domain
         dataset = SetDataset(data_file, self.batch_size, transform)
         sampler = EpisodicBatchSampler(len(dataset), self.n_way,
                                        self.n_eposide)
     data_loader = torch.utils.data.DataLoader(dataset,
                                               batch_sampler=sampler,
                                               num_workers=4)
     return data_loader
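
Examples #3, #5, #8 and #10 switch to MultiSetDataset plus MultiEpisodicBatchSampler when data_file is a list of domains, with dataset.lens() supplying the per-domain class counts. One plausible sketch, assuming the sampler picks a domain per episode and then draws n_way classes inside it, offset into the concatenated class index space of MultiSetDataset, is:

import torch

class MultiEpisodicBatchSampler:
    """Per episode: choose a domain at random, then n_way classes within it.

    `lens` is the list returned by dataset.lens(); the classes of domain d
    occupy the index range [sum(lens[:d]), sum(lens[:d+1])) in the combined
    MultiSetDataset, so sampled indices are shifted by that offset.
    """

    def __init__(self, lens, n_way, n_episodes):
        self.lens = lens
        self.offsets = [sum(lens[:d]) for d in range(len(lens))]
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for _ in range(self.n_episodes):
            d = torch.randint(len(self.lens), (1,)).item()        # pick a domain
            classes = torch.randperm(self.lens[d])[:self.n_way]   # n_way classes in it
            yield classes + self.offsets[d]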
Code Example #9
    def get_data_loader(self,
                        data_file,
                        aug,
                        lang_dir=None,
                        normalize=True,
                        vocab=None,
                        adversary=None,
                        max_class=None,
                        max_img_per_class=None,
                        max_lang_per_class=None,
                        fixed_noise=None,
                        confound_noise=None,
                        confound_noise_class_weight=None
                        ):  #parameters that would change on train/val set
        # TODO: FIXME
        if fixed_noise:
            transformers = [
                self.trans_loader.get_composed_transform(
                    aug,
                    normalize=normalize,
                    confound_noise=confound_noise,
                    confound_noise_class_weight=confound_noise_class_weight)
                for _ in range(100)
            ]
        else:
            # Use the same transformer
            transform = self.trans_loader.get_composed_transform(
                aug, normalize=normalize)
            transformers = [transform for _ in range(100)]

        dataset = SetDataset(self.name,
                             data_file,
                             self.batch_size,
                             transformers,
                             args=self.args,
                             lang_dir=lang_dir,
                             vocab=vocab,
                             adversary=adversary,
                             max_class=max_class,
                             max_img_per_class=max_img_per_class,
                             max_lang_per_class=max_lang_per_class)
        sampler = EpisodicBatchSampler(len(dataset), self.n_way,
                                       self.n_episode)
        data_loader_params = dict(batch_sampler=sampler,
                                  num_workers=self.args.n_workers,
                                  pin_memory=self.args.pin_memory)
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  **data_loader_params)
        return data_loader
Code Example #10
File: datamanager.py  Project: CRuJia/CrossDomain
 def get_data_loader(self, aug=False):
     transform = transforms.Compose([
         transforms.Resize((self.image_size, self.image_size)),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     ])
     if isinstance(self.data_file, list):
         dataset = MultiSetDataset(data_files=self.data_file, batch_size=self.batch_size, transform=transform)
         sampler = MultiEpisodicBatchSampler(dataset.lens(), self.n_way, self.n_episode)
     else:
         dataset = SetDataset(data_file=self.data_file, batch_size=self.batch_size, transform=transform)
         sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_episode)
     data_loader_params = dict(batch_sampler = sampler, num_workers=8)
     data_loader = DataLoader(dataset, **data_loader_params)
     return data_loader
Code Example #11
 def get_data_loader(self, data_file,
                     aug):  #parameters that would change on train/val set
     transform = self.trans_loader.get_composed_transform(aug)
     dataset = SetDataset(data_file, self.batch_size, transform)
     sampler = EpisodicBatchSampler(len(dataset), self.n_way,
                                    self.n_eposide)
     if sys.platform == "win32":  # Note: multi-process DataLoader workers don't work here on Windows, so num_workers stays at its default of 0
         data_loader_params = dict(batch_sampler=sampler, pin_memory=True)
     elif sys.platform == "linux":
         data_loader_params = dict(batch_sampler=sampler,
                                   num_workers=8,
                                   pin_memory=True)
     else:
         assert False, "Unknown OS!"
     data_loader = torch.utils.data.DataLoader(dataset,
                                               **data_loader_params)
     return data_loader
Code Example #12
    def get_data_loader(self, data_file, aug):
        #         pre_transform = self.trans_loader.get_simple_transform(aug)
        #         aug_transform = self.trans_loader.get_vae_transform(self.vaegan, self.lambda_zlogvar, self.fake_prob)
        #         post_transform = self.trans_loader.get_hdf5_transform(aug, inputs='tensor')
        transform = self.trans_loader.get_composed_transform(aug)
        fake_img_transform = self.trans_loader.get_hdf5_transform(aug)
        dataset = VAESetDataset(data_file, self.batch_size, transform,
                                fake_img_transform, self.vaegan_params)
        #         dataset = VAESetDataset(data_file , self.batch_size, pre_transform=pre_transform, post_transform=post_transform, aug_transform=aug_transform)
        sampler = EpisodicBatchSampler(
            len(dataset), self.n_way,
            self.n_episode)  # sample classes randomly
        data_loader_params = dict(batch_sampler=sampler,
                                  num_workers=0,
                                  pin_memory=True)  # to debug
        #         data_loader_params = dict(batch_sampler = sampler,  num_workers = 0, pin_memory = False) # to debug
        # TODO: cancel debug mode
        #         data_loader_params = dict(batch_sampler = sampler,  num_workers = 12, pin_memory = True)
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  **data_loader_params)

        return data_loader
Code Example #13
    def get_data_loader(
        self,
        data_file,
        aug,
        lang_dir=None,
        normalize=True,
        vocab=None,
        max_class=None,
        max_img_per_class=None,
        max_lang_per_class=None,
    ):
        transform = self.trans_loader.get_composed_transform(
            aug, normalize=normalize)

        dataset = SetDataset(
            self.name,
            data_file,
            self.batch_size,
            transform,
            args=self.args,
            lang_dir=lang_dir,
            vocab=vocab,
            max_class=max_class,
            max_img_per_class=max_img_per_class,
            max_lang_per_class=max_lang_per_class,
        )
        sampler = EpisodicBatchSampler(len(dataset), self.n_way,
                                       self.n_episode)
        data_loader_params = dict(
            batch_sampler=sampler,
            num_workers=self.args.n_workers,
            pin_memory=True,
        )
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  **data_loader_params)
        return data_loader
Code Example #14
File: datamgr.py  Project: snap-stanford/comet
 def get_data_loader(self, root='./filelists/tabula_muris', mode='train'): #parameters that would change on train/val set
     dataset = SetDataset(root=root, mode=mode, min_samples=self.batch_size)
     sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide )  
     data_loader_params = dict(batch_sampler = sampler,  num_workers = 4, pin_memory = True)       
     data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
     return data_loader