예제 #1
0
  def get_data_loader(self, data_file, aug):  # aug is the train/val-specific knob
    """Return a shuffling DataLoader over ``data_file``.

    Args:
      data_file: forwarded unchanged to ``SimpleDataset``.
      aug: forwarded to ``self.trans_loader.get_composed_transform`` to
        select the transform pipeline.

    Returns:
      A ``torch.utils.data.DataLoader`` using ``self.batch_size``, with
      shuffling on, 4 loader workers and pinned memory.
    """
    dataset = SimpleDataset(data_file,
                            self.trans_loader.get_composed_transform(aug))
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=self.batch_size,
                                       shuffle=True,
                                       num_workers=4,
                                       pin_memory=True)
예제 #2
0
 def get_data_loader(self, data_file, aug):
     """Wrap ``data_file`` in a ``SimpleDataset`` and return a DataLoader.

     The transform comes from ``self.trans_loader.get_composed_transform(aug)``.
     The loader shuffles, uses ``self.batch_size``, 4 workers and pinned memory.
     """
     composed = self.trans_loader.get_composed_transform(aug)
     loader_kwargs = {
         "batch_size": self.batch_size,
         "shuffle": True,
         "num_workers": 4,
         "pin_memory": True,
     }
     return torch.utils.data.DataLoader(SimpleDataset(data_file, composed),
                                        **loader_kwargs)
예제 #3
0
    def get_data_loader(self, data_path, load_set, aug):
        """Return a shuffling DataLoader over one split of ``data_path``.

        Args:
            data_path: root passed to ``SimpleDataset``.
            load_set: forwarded to ``SimpleDataset`` (presumably the split
                name, e.g. train/val -- confirm against SimpleDataset).
            aug: forwarded to ``self.trans_loader.get_composed_transform``.

        Returns:
            A ``torch.utils.data.DataLoader`` with ``self.batch_size``,
            shuffling on, ``self.num_workers`` workers and pinned memory.
        """
        dataset = SimpleDataset(data_path, load_set,
                                self.trans_loader.get_composed_transform(aug))
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )
예제 #4
0
 def get_data_loader_dct(self, data_file, aug, filter_size):
     """Return a non-shuffling DataLoader over the DCT variant of the dataset.

     Args:
         data_file: forwarded to ``SimpleDataset``.
         aug: forwarded to ``self.trans_loader.get_composed_transform_dct``.
         filter_size: forwarded to ``get_composed_transform_dct``.

     Returns:
         A ``torch.utils.data.DataLoader`` with ``self.batch_size``, shuffling
         off, 16 workers and pinned memory. ``SimpleDataset`` is built with
         ``dct_status=True``.
     """
     dct_transform = self.trans_loader.get_composed_transform_dct(aug,
                                                                  filter_size)
     dct_dataset = SimpleDataset(data_file, dct_transform, dct_status=True)
     return torch.utils.data.DataLoader(dct_dataset,
                                        batch_size=self.batch_size,
                                        shuffle=False,
                                        num_workers=16,
                                        pin_memory=True)
예제 #5
0
    def get_data_loader(self, data_file, aug):  # aug differs between train/val
        """Build a DataLoader, optionally with jigsaw-puzzle transforms.

        Side effects: ``self.transform_jigsaw`` and
        ``self.transform_patch_jigsaw`` are always overwritten. They are reset
        to ``None`` and, when ``self.jigsaw`` is truthy, replaced with the
        whole-image and patch transforms built below. Both are then passed to
        ``SimpleDataset``.

        Args:
            data_file: forwarded to ``SimpleDataset``.
            aug: selects the base transform via
                ``self.trans_loader.get_composed_transform`` and, for jigsaw,
                chooses random crop/flip (truthy) or a plain resize (falsy).

        Returns:
            A ``torch.utils.data.DataLoader`` configured from
            ``self.batch_size``, ``self.shuffle``, ``NUM_WORKERS`` and
            ``self.drop_last``, with pinned memory.
        """
        transform = self.trans_loader.get_composed_transform(aug)

        self.transform_patch_jigsaw = None
        self.transform_jigsaw = None
        if self.jigsaw:
            # NOTE: earlier experiments tried Resize(256)+CenterCrop(225) and
            # Resize(255)+CenterCrop(240) ("paper setting"). No ToTensor /
            # Normalize is applied to the whole image here.
            if aug:
                whole_image_ops = [
                    transforms.RandomResizedCrop(255, scale=(0.5, 1.0)),
                    transforms.RandomHorizontalFlip(),
                ]
            else:
                whole_image_ops = [transforms.Resize(255)]
            self.transform_jigsaw = transforms.Compose(whole_image_ops)
            # Presumably applied to each puzzle tile inside SimpleDataset --
            # confirm there. Normalization is intentionally left out.
            self.transform_patch_jigsaw = transforms.Compose([
                transforms.RandomCrop(64),
                transforms.Lambda(self.rgb_jittering),
                transforms.ToTensor(),
            ])

        dataset = SimpleDataset(
            data_file,
            transform,
            jigsaw=self.jigsaw,
            transform_jigsaw=self.transform_jigsaw,
            transform_patch_jigsaw=self.transform_patch_jigsaw,
            rotation=self.rotation,
            isAircraft=self.isAircraft,
            grey=self.grey,
            return_name=self.return_name,
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            drop_last=self.drop_last,
        )
예제 #6
0
 def get_data_loader(self, data_file, aug):
     """Return a shuffling, single-process DataLoader over ``data_file``.

     Uses ``self.trans_loader.get_simple_transform(aug=aug)``. Data is loaded
     in the main process (``num_workers=0``), and memory pinning is off
     because fast host-to-GPU transfer was not needed here.
     """
     simple_transform = self.trans_loader.get_simple_transform(aug=aug)
     dataset = SimpleDataset(data_file, simple_transform)
     return torch.utils.data.DataLoader(dataset,
                                        batch_size=self.batch_size,
                                        shuffle=True,
                                        num_workers=0,
                                        pin_memory=False)
예제 #7
0
    def get_data_loader(self, data_file):
        """Return a shuffling DataLoader whose dataset is built with paths.

        The augmentation flag is read from ``self.aug``, not passed in.
        ``return_path=True`` is forwarded to ``SimpleDataset``; presumably each
        item then includes its source path -- confirm in SimpleDataset.
        """
        dataset = SimpleDataset(
            data_file,
            transform=self.trans_loader.get_simple_transform(aug=self.aug),
            return_path=True,
        )
        # NOTE(review): pin_memory=True was left with an open question about
        # whether it helps when batches are fed to a TensorFlow model.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=12,
            pin_memory=True,
        )
예제 #8
0
    def get_data_loader(self, data_file,
                        aug):  # aug differs between train/val sets
        """Return a shuffling DataLoader over ``data_file``.

        Worker processes (8) are used only on Linux. Every other platform --
        Windows, macOS or anything unrecognised -- loads in the main process,
        matching what the Windows path always did. The previous version
        raised via ``assert False`` on unknown platforms. Under ``python -O``
        that assert was stripped, which left ``data_loader_params`` unbound.

        Args:
            data_file: forwarded to ``SimpleDataset``.
            aug: forwarded to ``self.trans_loader.get_composed_transform``.

        Returns:
            A ``torch.utils.data.DataLoader`` with ``self.batch_size``,
            shuffling on and pinned memory.
        """
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SimpleDataset(data_file, transform)
        data_loader_params = dict(batch_size=self.batch_size,
                                  shuffle=True,
                                  pin_memory=True)
        # Python 3 reports "linux"; Python 2 reported "linux2".
        if sys.platform.startswith("linux"):
            data_loader_params["num_workers"] = 8
        # Otherwise DataLoader's default num_workers=0 applies (main-process
        # loading). The original note said Windows can't use worker
        # processes; in practice they need a `if __name__ == "__main__"`
        # guard on spawn-based platforms, so single-process is the safe
        # fallback.
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  **data_loader_params)

        return data_loader
예제 #9
0
    def get_data_loader(self,
                        data_file,
                        aug,
                        lang_dir=None,
                        normalize=True,
                        to_pil=False):  # aug differs between train/val sets
        """Return a shuffling DataLoader over ``data_file``.

        Args:
            data_file: forwarded to ``SimpleDataset``.
            aug, normalize, to_pil: forwarded to
                ``self.trans_loader.get_composed_transform``.
            lang_dir: must be ``None``; any other value is unsupported.

        Returns:
            A ``torch.utils.data.DataLoader`` configured from
            ``self.batch_size``, ``self.num_workers`` and
            ``self.pin_memory``, with shuffling on.

        Raises:
            NotImplementedError: if ``lang_dir`` is given.
        """
        if lang_dir is None:
            composed = self.trans_loader.get_composed_transform(
                aug, normalize=normalize, to_pil=to_pil)
            return torch.utils.data.DataLoader(
                SimpleDataset(data_file, composed),
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
            )
        raise NotImplementedError
예제 #10
0
    def get_data_loader(self,
                        data_file,
                        aug,
                        shuffle=True,
                        num_workers=12,
                        return_path=False):  # aug differs between train/val
        """Return a DataLoader over ``data_file`` with caller-tunable options.

        Args:
            data_file: forwarded to ``SimpleDataset``.
            aug: forwarded to ``self.trans_loader.get_composed_transform``.
            shuffle: passed to ``DataLoader``.
            num_workers: passed to ``DataLoader``. Pass 0 while debugging to
                keep loading in the main process.
            return_path: forwarded to ``SimpleDataset``.

        Returns:
            A ``torch.utils.data.DataLoader`` with ``self.batch_size`` and
            pinned memory.
        """
        composed = self.trans_loader.get_composed_transform(aug)
        dataset = SimpleDataset(data_file, composed, return_path=return_path)
        loader_kwargs = {
            "batch_size": self.batch_size,
            "shuffle": shuffle,
            "num_workers": num_workers,
            "pin_memory": True,
        }
        return torch.utils.data.DataLoader(dataset, **loader_kwargs)
예제 #11
0
    def get_data_loader(self, root='./filelists/tabula_muris', mode='train'):
        """Return a shuffling DataLoader over the dataset at ``root``.

        Args:
            root: forwarded to ``SimpleDataset`` as its root directory.
            mode: forwarded to ``SimpleDataset`` (the train/val-specific knob).

        Returns:
            A ``torch.utils.data.DataLoader`` with ``self.batch_size``,
            shuffling on, 4 workers and pinned memory. ``self.batch_size`` is
            also passed to ``SimpleDataset`` as ``min_samples``.
        """
        dataset = SimpleDataset(root=root, mode=mode,
                                min_samples=self.batch_size)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=4,
                                           pin_memory=True)