def get_data_loader(self, data_file, aug):
    """Build a DataLoader over ``data_file`` driven by an episodic batch sampler.

    Args:
        data_file: Passed unchanged to ``SetDataset`` as its data source.
        aug: Forwarded to ``self.trans_loader.get_composed_transform``.

    Returns:
        A ``torch.utils.data.DataLoader`` whose batches are chosen by an
        ``EpisodicBatchSampler`` (``num_workers=12``, ``pin_memory=True``).
    """
    image_transform = self.trans_loader.get_composed_transform(aug)
    episode_set = SetDataset(data_file, self.batch_size, image_transform)
    # ``n_eposide`` (sic) is the attribute name used by the owning object.
    episode_sampler = EpisodicBatchSampler(
        len(episode_set), self.n_way, self.n_eposide
    )
    return torch.utils.data.DataLoader(
        episode_set,
        batch_sampler=episode_sampler,
        num_workers=12,
        pin_memory=True,
    )
Exemplo n.º 2
0
 def get_data_loader(self, data_file, aug):  # parameters that would change on train/val set
     """Build a DataLoader for a single data file or a list of data files.

     Args:
         data_file: One data-file reference, or a ``list`` of them. A list
             selects ``MultiSetDataset`` / ``MultiEpisodicBatchSampler``; any
             other value selects ``SetDataset`` / ``EpisodicBatchSampler``.
         aug: Forwarded to ``self.trans_loader.get_composed_transform``.

     Returns:
         A ``torch.utils.data.DataLoader`` using the chosen batch sampler and
         ``num_workers=4``.
     """
     image_transform = self.trans_loader.get_composed_transform(aug)
     if not isinstance(data_file, list):
         source = SetDataset(data_file, self.batch_size, image_transform)
         episode_sampler = EpisodicBatchSampler(
             len(source), self.n_way, self.n_eposide
         )
     else:
         source = MultiSetDataset(data_file, self.batch_size, image_transform)
         episode_sampler = MultiEpisodicBatchSampler(
             source.lens(), self.n_way, self.n_eposide
         )
     return torch.utils.data.DataLoader(
         source, batch_sampler=episode_sampler, num_workers=4
     )
Exemplo n.º 3
0
 def get_data_loader(self, aug):  # parameters that would change on train/val set
     """Build a DataLoader from the data file configured on this object.

     The dataset location (``self.data_file``, ``self.dataset_dir``) and the
     sampler settings (``self.mode``, ``self.n_way``, ``self.n_episode``) are
     all read from instance attributes.

     Args:
         aug: Forwarded to ``self.trans_loader.get_composed_transform``.

     Returns:
         A ``torch.utils.data.DataLoader`` using an ``EpisodicBatchSampler``
         (``num_workers=8``, ``pin_memory=True``).
     """
     image_transform = self.trans_loader.get_composed_transform(aug)
     episode_set = SetDataset(
         self.data_file, self.dataset_dir, self.batch_size, image_transform
     )
     episode_sampler = EpisodicBatchSampler(
         self.mode, len(episode_set), self.n_way, self.n_episode
     )
     loader_kwargs = {
         "batch_sampler": episode_sampler,
         "num_workers": 8,
         "pin_memory": True,
     }
     return torch.utils.data.DataLoader(episode_set, **loader_kwargs)
Exemplo n.º 4
0
 def get_data_loader(self, data_file, aug):  # parameters that would change on train/val set
     """Build a DataLoader for a single data file or a list of data files.

     Args:
         data_file: One data-file reference, or a ``list`` of them (which
             selects the multi-dataset classes).
         aug: Forwarded to ``self.trans_loader.get_composed_transform``.

     Returns:
         A ``torch.utils.data.DataLoader`` using the chosen batch sampler and
         ``num_workers=4``.
     """
     image_transform = self.trans_loader.get_composed_transform(aug)
     if not isinstance(data_file, list):
         # Original author's notes (not verifiable from this method alone):
         # the dataset holds 1) a label list ``cl_list`` 0..N-1, 2) the source
         # data ``meta``, 3) ``sub_dataloader`` with one dataloader per class,
         # and 4) ``sub_meta`` with the samples of each class;
         # ``len(dataset)`` is the number of labels.
         source = SetDataset(data_file, self.batch_size, image_transform)
         episode_sampler = EpisodicBatchSampler(
             len(source), self.n_way, self.n_eposide
         )
     else:
         # Original author's note: holds multiple sub-dataloaders (the sum of
         # the relation counts over all datasets); each class is a DataLoader.
         source = MultiSetDataset(data_file, self.batch_size, image_transform)
         # Original author's note: ``lens()`` is a list with the number of
         # relation classes in each dataset.
         episode_sampler = MultiEpisodicBatchSampler(
             source.lens(), self.n_way, self.n_eposide
         )
     return torch.utils.data.DataLoader(
         source, batch_sampler=episode_sampler, num_workers=4
     )
Exemplo n.º 5
0
    def get_data_loader(self, data_file, aug):  # parameters that would change on train/val set
        """Build a DataLoader, optionally preparing jigsaw-puzzle transforms.

        Side effects: ``self.transform_jigsaw`` and
        ``self.transform_patch_jigsaw`` are reset to ``None`` on every call and
        only rebuilt when ``self.jigsaw`` is truthy.

        Args:
            data_file: Passed unchanged to ``SetDataset`` as its data source.
            aug: Forwarded to ``self.trans_loader.get_composed_transform`` and
                also selects which whole-image jigsaw transform is built.

        Returns:
            A ``torch.utils.data.DataLoader`` using an ``EpisodicBatchSampler``
            (``num_workers=NUM_WORKERS``, ``pin_memory=True``).
        """
        image_transform = self.trans_loader.get_composed_transform(aug)

        # Jigsaw transforms stay None unless the jigsaw task is enabled.
        self.transform_patch_jigsaw = None
        self.transform_jigsaw = None
        if self.jigsaw:
            # Whole-image stage. Alternatives left disabled by the original
            # author: Resize(256) + CenterCrop(225), and the "paper setting"
            # Resize(255) + CenterCrop(240). ToTensor/Normalize are not
            # applied at this stage.
            if aug:
                # Original note: "setting of my experiment before 0515".
                whole_image_steps = [
                    transforms.RandomResizedCrop(255, scale=(0.5, 1.0)),
                    transforms.RandomHorizontalFlip(),
                ]
            else:
                whole_image_steps = [
                    transforms.RandomResizedCrop(225, scale=(0.5, 1.0)),
                ]
            self.transform_jigsaw = transforms.Compose(whole_image_steps)

            # Per-patch stage. A bilinear Resize((75, 75)) and a Normalize
            # step were left disabled by the original author.
            patch_steps = [
                transforms.RandomCrop(64),
                transforms.Lambda(self.rgb_jittering),
                transforms.ToTensor(),
            ]
            self.transform_patch_jigsaw = transforms.Compose(patch_steps)

        episode_set = SetDataset(
            data_file,
            self.batch_size,
            image_transform,
            jigsaw=self.jigsaw,
            transform_jigsaw=self.transform_jigsaw,
            transform_patch_jigsaw=self.transform_patch_jigsaw,
            rotation=self.rotation,
            isAircraft=self.isAircraft,
            grey=self.grey,
        )
        # ``n_eposide`` (sic) is the attribute name used by the owning object.
        episode_sampler = EpisodicBatchSampler(
            len(episode_set), self.n_way, self.n_eposide
        )
        return torch.utils.data.DataLoader(
            episode_set,
            batch_sampler=episode_sampler,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )
Exemplo n.º 6
0
 def get_data_loader(self, data_file, aug):
     """Build a DataLoader for one domain or for several domains at once.

     Args:
         data_file: One data-file reference (single domain), or a ``list`` of
             them (multi domain).
         aug: Forwarded to ``self.trans_loader.get_composed_transform``.

     Returns:
         A ``torch.utils.data.DataLoader`` using the chosen batch sampler and
         ``num_workers=4``.
     """
     image_transform = self.trans_loader.get_composed_transform(aug)
     multi_domain = isinstance(data_file, list)
     if multi_domain:
         source = MultiSetDataset(data_file, self.batch_size, image_transform)
         episode_sampler = MultiEpisodicBatchSampler(
             source.lens(), self.n_way, self.n_eposide
         )
     else:
         source = SetDataset(data_file, self.batch_size, image_transform)
         episode_sampler = EpisodicBatchSampler(
             len(source), self.n_way, self.n_eposide
         )
     loader_kwargs = {"batch_sampler": episode_sampler, "num_workers": 4}
     return torch.utils.data.DataLoader(source, **loader_kwargs)
Exemplo n.º 7
0
    def get_data_loader(self,
                        data_file,
                        aug,
                        lang_dir=None,
                        normalize=True,
                        vocab=None,
                        adversary=None,
                        max_class=None,
                        max_img_per_class=None,
                        max_lang_per_class=None,
                        fixed_noise=None,
                        confound_noise=None,
                        confound_noise_class_weight=None
                        ):  # parameters that would change on train/val set
        """Build a DataLoader whose dataset receives a list of 100 transforms.

        When ``fixed_noise`` is truthy, 100 separate transforms are created,
        each configured with the confound-noise settings. Otherwise one
        transform is created (without the confound-noise settings) and the
        list holds 100 references to it.

        Args:
            data_file: Passed to ``SetDataset`` as its data source.
            aug: Forwarded to ``self.trans_loader.get_composed_transform``.
            lang_dir, vocab, adversary, max_class, max_img_per_class,
                max_lang_per_class: Forwarded unchanged to ``SetDataset``.
            normalize: Forwarded to ``get_composed_transform``.
            fixed_noise: Selects per-slot transforms (truthy) or one shared
                transform (falsy).
            confound_noise, confound_noise_class_weight: Forwarded to
                ``get_composed_transform`` only when ``fixed_noise`` is truthy.

        Returns:
            A ``torch.utils.data.DataLoader`` using an ``EpisodicBatchSampler``,
            with ``num_workers`` and ``pin_memory`` taken from ``self.args``.
        """
        # TODO: FIXME (marker carried over from the original; the intended
        # fix is not stated there).
        # NOTE(review): the meaning of the count 100 is not visible here;
        # presumably an upper bound on what SetDataset indexes. Confirm.
        n_transformers = 100
        if fixed_noise:
            transformers = []
            for _ in range(n_transformers):
                transformers.append(
                    self.trans_loader.get_composed_transform(
                        aug,
                        normalize=normalize,
                        confound_noise=confound_noise,
                        confound_noise_class_weight=confound_noise_class_weight,
                    )
                )
        else:
            # Every slot refers to the same transform object.
            shared_transform = self.trans_loader.get_composed_transform(
                aug, normalize=normalize
            )
            transformers = [shared_transform] * n_transformers

        episode_set = SetDataset(
            self.name,
            data_file,
            self.batch_size,
            transformers,
            args=self.args,
            lang_dir=lang_dir,
            vocab=vocab,
            adversary=adversary,
            max_class=max_class,
            max_img_per_class=max_img_per_class,
            max_lang_per_class=max_lang_per_class,
        )
        episode_sampler = EpisodicBatchSampler(
            len(episode_set), self.n_way, self.n_episode
        )
        return torch.utils.data.DataLoader(
            episode_set,
            batch_sampler=episode_sampler,
            num_workers=self.args.n_workers,
            pin_memory=self.args.pin_memory,
        )
Exemplo n.º 8
0
 def get_data_loader(self, aug=False):
     """Build a DataLoader over ``self.data_file`` with a fixed preprocessing chain.

     Images are resized to ``self.image_size`` square, converted to tensors
     and normalized with hard-coded per-channel mean/std values.

     Args:
         aug: Accepted for interface compatibility; not read by this method.

     Returns:
         A ``DataLoader`` using the batch sampler that matches the dataset
         kind, with ``num_workers=8``.
     """
     # NOTE(review): the mean/std constants look like the commonly used
     # CIFAR-10 channel statistics; confirm against the data actually loaded.
     pipeline_steps = [
         transforms.Resize((self.image_size, self.image_size)),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ]
     image_transform = transforms.Compose(pipeline_steps)

     if not isinstance(self.data_file, list):
         source = SetDataset(
             data_file=self.data_file,
             batch_size=self.batch_size,
             transform=image_transform,
         )
         episode_sampler = EpisodicBatchSampler(
             len(source), self.n_way, self.n_episode
         )
     else:
         source = MultiSetDataset(
             data_files=self.data_file,
             batch_size=self.batch_size,
             transform=image_transform,
         )
         episode_sampler = MultiEpisodicBatchSampler(
             source.lens(), self.n_way, self.n_episode
         )
     return DataLoader(source, batch_sampler=episode_sampler, num_workers=8)
Exemplo n.º 9
0
 def get_data_loader(self, data_file,
                     aug):  # parameters that would change on train/val set
     """Build a DataLoader with platform-dependent worker settings.

     Args:
         data_file: Passed unchanged to ``SetDataset`` as its data source.
         aug: Forwarded to ``self.trans_loader.get_composed_transform``.

     Returns:
         A ``torch.utils.data.DataLoader`` using an ``EpisodicBatchSampler``
         and ``pin_memory=True``. On Linux ``num_workers=8`` is passed as
         well; on Windows ``num_workers`` is left at the DataLoader default.

     Raises:
         AssertionError: If ``sys.platform`` is neither ``"win32"`` nor
             ``"linux"``.
     """
     transform = self.trans_loader.get_composed_transform(aug)
     dataset = SetDataset(data_file, self.batch_size, transform)
     # ``n_eposide`` (sic) is the attribute name used by the owning object.
     sampler = EpisodicBatchSampler(len(dataset), self.n_way,
                                    self.n_eposide)
     if sys.platform == "win32":
         # Original author's note: multi-worker loading was considered
         # unsupported on Windows, so no ``num_workers`` is passed there.
         worker_options = {}
     elif sys.platform == "linux":
         worker_options = {"num_workers": 8}
     else:
         # Raised explicitly instead of ``assert False``: asserts are removed
         # under ``python -O``, which would leave the loader parameters
         # undefined and surface as a confusing UnboundLocalError below.
         raise AssertionError("Unknown OS!")
     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=sampler,
                                        pin_memory=True,
                                        **worker_options)
Exemplo n.º 10
0
    def get_data_loader(
        self,
        data_file,
        aug,
        lang_dir=None,
        normalize=True,
        vocab=None,
        max_class=None,
        max_img_per_class=None,
        max_lang_per_class=None,
    ):
        """Build a DataLoader over ``data_file`` with optional language data.

        Args:
            data_file: Passed to ``SetDataset`` as its data source.
            aug: Forwarded to ``self.trans_loader.get_composed_transform``.
            lang_dir, vocab, max_class, max_img_per_class, max_lang_per_class:
                Forwarded unchanged to ``SetDataset``.
            normalize: Forwarded to ``get_composed_transform``.

        Returns:
            A ``torch.utils.data.DataLoader`` using an ``EpisodicBatchSampler``,
            with ``num_workers=self.args.n_workers`` and ``pin_memory=True``.
        """
        image_transform = self.trans_loader.get_composed_transform(
            aug, normalize=normalize
        )
        dataset_options = dict(
            args=self.args,
            lang_dir=lang_dir,
            vocab=vocab,
            max_class=max_class,
            max_img_per_class=max_img_per_class,
            max_lang_per_class=max_lang_per_class,
        )
        episode_set = SetDataset(
            self.name, data_file, self.batch_size, image_transform,
            **dataset_options
        )
        episode_sampler = EpisodicBatchSampler(
            len(episode_set), self.n_way, self.n_episode
        )
        return torch.utils.data.DataLoader(
            episode_set,
            batch_sampler=episode_sampler,
            num_workers=self.args.n_workers,
            pin_memory=True,
        )
Exemplo n.º 11
0
 def get_data_loader(self, data_file, aug, ifshuffle):  # parameters that would change on train/val set
     """Build a plain fixed-batch-size DataLoader over ``data_file``.

     No batch sampler is used here; batching is controlled by
     ``self.batch_size`` and the ``ifshuffle`` flag.

     Args:
         data_file: Passed unchanged to ``SetDataset`` as its data source.
         aug: Forwarded to ``self.trans_loader.get_composed_transform``.
         ifshuffle: Passed to the DataLoader as ``shuffle``.

     Returns:
         A ``torch.utils.data.DataLoader`` with ``num_workers=12`` and
         ``pin_memory=True``.
     """
     image_transform = self.trans_loader.get_composed_transform(aug)
     source = SetDataset(data_file, image_transform)
     return torch.utils.data.DataLoader(
         source,
         batch_size=self.batch_size,
         num_workers=12,
         shuffle=ifshuffle,
         pin_memory=True,
     )
Exemplo n.º 12
0
 def get_data_loader(self, root='./filelists/tabula_muris', mode='train'):  # parameters that would change on train/val set
     """Build a DataLoader for the dataset stored under ``root``.

     Args:
         root: Passed to ``SetDataset`` as ``root``.
         mode: Passed to ``SetDataset`` as ``mode``.

     Returns:
         A ``torch.utils.data.DataLoader`` using an ``EpisodicBatchSampler``
         (``num_workers=4``, ``pin_memory=True``).
     """
     # ``self.batch_size`` is handed to the dataset as ``min_samples``.
     episode_set = SetDataset(root=root, mode=mode, min_samples=self.batch_size)
     # ``n_eposide`` (sic) is the attribute name used by the owning object.
     episode_sampler = EpisodicBatchSampler(
         len(episode_set), self.n_way, self.n_eposide
     )
     return torch.utils.data.DataLoader(
         episode_set,
         batch_sampler=episode_sampler,
         num_workers=4,
         pin_memory=True,
     )