示例#1
0
    def get_valloader(self, dataset=None):
        """Build the validation DataLoader for the configured pose dataset.

        The dataset class is chosen by the ``dataset`` config key:
        ``'default_cpm'`` or ``'default_openpose'``. Any other value is
        logged as an error and the process exits with status 1.

        Args:
            dataset: Split name forwarded to the dataset class; ``None``
                means ``'val'``.

        Returns:
            A ``data.DataLoader`` that does not shuffle and collates batches
            with ``collate`` using the ``('val', 'data_transformer')`` config.
        """
        split = 'val' if dataset is None else dataset
        dataset_type = self.configer.get('dataset', default=None)

        if dataset_type == 'default_cpm':
            # NOTE(review): a sibling get_trainloader variant builds
            # DefaultCPMDataset for 'default_cpm'; confirm DefaultDataset is
            # the intended class for validation.
            val_dataset = DefaultDataset(
                root_dir=self.configer.get('data', 'data_dir'),
                dataset=split,
                aug_transform=self.aug_val_transform,
                img_transform=self.img_transform,
                configer=self.configer)

        elif dataset_type == 'default_openpose':
            # No trailing comma after this call: one would silently wrap the
            # dataset in a 1-tuple, which the DataLoader would then iterate.
            val_dataset = DefaultOpenPoseDataset(
                root_dir=self.configer.get('data', 'data_dir'),
                dataset=split,
                aug_transform=self.aug_val_transform,
                img_transform=self.img_transform,
                configer=self.configer)

        else:
            # Reuse the value fetched with default=None so an absent key is
            # reported here rather than possibly raising inside the error path.
            Log.error('{} dataset is invalid.'.format(dataset_type))
            exit(1)

        valloader = data.DataLoader(
            val_dataset,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=lambda *args: collate(*args,
                                             trans_dict=self.configer.get(
                                                 'val', 'data_transformer')))
        return valloader
示例#2
0
    def get_trainloader(self):
        """Create the shuffled training DataLoader for the configured dataset.

        Supports ``'default_cpm'`` (DefaultCPMDataset) and
        ``'default_openpose'`` (DefaultOpenPoseDataset) as the ``dataset``
        config value; anything else is logged and the process exits with
        status 1.

        Returns:
            A ``data.DataLoader`` over the ``'train'`` split, honouring the
            ``('data', 'drop_last')`` setting and collating batches with the
            ``('train', 'data_transformer')`` config.
        """
        def make_train_set(dataset_cls):
            # Constructor arguments are evaluated only once a class has been
            # selected, so an invalid dataset type never touches them.
            return dataset_cls(root_dir=self.configer.get('data', 'data_dir'),
                               dataset='train',
                               aug_transform=self.aug_train_transform,
                               img_transform=self.img_transform,
                               configer=self.configer)

        dataset_type = self.configer.get('dataset', default=None)
        if dataset_type == 'default_cpm':
            train_set = make_train_set(DefaultCPMDataset)
        elif dataset_type == 'default_openpose':
            train_set = make_train_set(DefaultOpenPoseDataset)
        else:
            Log.error('{} dataset is invalid.'.format(dataset_type))
            exit(1)

        def collate_batch(*batch):
            return collate(*batch,
                           trans_dict=self.configer.get('train', 'data_transformer'))

        return data.DataLoader(
            train_set,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=True,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=collate_batch)
示例#3
0
    def get_trainloader(self):
        """Create the training DataLoader, optionally with a distributed sampler.

        Only the default dataset (``dataset`` config key unset or
        ``'default'``) is supported; any other value is logged and the process
        exits with status 1. When ``network.distributed`` is truthy a
        ``DistributedSampler`` is attached and loader-level shuffling is turned
        off, since PyTorch does not allow ``shuffle=True`` together with a
        sampler.

        Returns:
            A ``data.DataLoader`` over the ``'train'`` split.
        """
        dataset_type = self.configer.get('dataset', default=None)
        if dataset_type not in (None, 'default'):
            Log.error('{} dataset is invalid.'.format(dataset_type))
            exit(1)

        train_set = DefaultDataset(root_dir=self.configer.get('data', 'data_dir'),
                                   dataset='train',
                                   aug_transform=self.aug_train_transform,
                                   img_transform=self.img_transform,
                                   configer=self.configer)

        train_sampler = (torch.utils.data.distributed.DistributedSampler(train_set)
                         if self.configer.get('network.distributed') else None)

        def collate_batch(*batch):
            return collate(*batch,
                           trans_dict=self.configer.get('train', 'data_transformer'))

        return data.DataLoader(
            train_set,
            sampler=train_sampler,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=train_sampler is None,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=collate_batch)
示例#4
0
    def get_valloader(self):
        """Create the non-shuffled validation DataLoader.

        The ``dataset`` config key selects the class: unset or ``'default'``
        uses DefaultDataset, ``'cityscapes'`` uses CityscapesDataset. Any other
        value is logged and the process exits with status 1.

        Returns:
            A ``data.DataLoader`` over the ``'val'`` split, collating batches
            with the ``('val', 'data_transformer')`` config.
        """
        def make_val_set(dataset_cls):
            # Deferred so constructor arguments are evaluated only after a
            # valid class has been chosen.
            return dataset_cls(root_dir=self.configer.get('data', 'data_dir'),
                               dataset='val',
                               aug_transform=self.aug_val_transform,
                               img_transform=self.img_transform,
                               label_transform=self.label_transform,
                               configer=self.configer)

        dataset_type = self.configer.get('dataset', default=None)
        if dataset_type in (None, 'default'):
            val_set = make_val_set(DefaultDataset)
        elif dataset_type == 'cityscapes':
            val_set = make_val_set(CityscapesDataset)
        else:
            Log.error('{} dataset is invalid.'.format(dataset_type))
            exit(1)

        def collate_batch(*batch):
            return collate(*batch,
                           trans_dict=self.configer.get('val', 'data_transformer'))

        return data.DataLoader(
            val_set,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=collate_batch)
示例#5
0
    def get_valloader(self, dataset=None):
        """Build the validation DataLoader for the configured GAN dataset.

        The ``dataset`` config key selects the class: ``'default_pix2pix'``,
        ``'default_cyclegan'`` or ``'default_facegan'`` (the last also receives
        ``('data', 'tag')``). Any other value is logged and the process exits
        with status 1.

        Args:
            dataset: Split name forwarded to the dataset class; ``None``
                means ``'val'``.

        Returns:
            A ``data.DataLoader`` that does not shuffle and collates batches
            with ``collate`` using the ``('val', 'data_transformer')`` config.
        """
        split = 'val' if dataset is None else dataset
        dataset_type = self.configer.get('dataset', default=None)

        if dataset_type == 'default_pix2pix':
            val_dataset = DefaultPix2pixDataset(
                root_dir=self.configer.get('data', 'data_dir'),
                dataset=split,
                aug_transform=self.aug_val_transform,
                img_transform=self.img_transform,
                configer=self.configer)

        elif dataset_type == 'default_cyclegan':
            val_dataset = DefaultCycleGANDataset(
                root_dir=self.configer.get('data', 'data_dir'),
                dataset=split,
                aug_transform=self.aug_val_transform,
                img_transform=self.img_transform,
                configer=self.configer)

        elif dataset_type == 'default_facegan':
            val_dataset = DefaultFaceGANDataset(
                root_dir=self.configer.get('data', 'data_dir'),
                dataset=split,
                tag=self.configer.get('data', 'tag'),
                aug_transform=self.aug_val_transform,
                img_transform=self.img_transform,
                configer=self.configer)

        else:
            # Report the value that actually failed the dispatch above (the
            # `dataset` key), not an unrelated `val.loader` setting.
            Log.error('{} dataset is invalid.'.format(dataset_type))
            exit(1)

        valloader = data.DataLoader(
            val_dataset,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=lambda *args: collate(*args,
                                             trans_dict=self.configer.get(
                                                 'val', 'data_transformer')))

        return valloader
示例#6
0
    def get_trainloader(self):
        """Build the shuffled training DataLoader for the configured GAN dataset.

        The ``dataset`` config key selects the class: ``'default_pix2pix'``,
        ``'default_cyclegan'`` or ``'default_facegan'`` (the last also receives
        ``('data', 'tag')``). Any other value is logged and the process exits
        with status 1.

        Returns:
            A ``data.DataLoader`` over the ``'train'`` split, honouring the
            ``('data', 'drop_last')`` setting and collating batches with the
            ``('train', 'data_transformer')`` config.
        """
        # Fetch once with an explicit default so every branch (including the
        # error path) sees the same value.
        dataset_type = self.configer.get('dataset', default=None)

        if dataset_type == 'default_pix2pix':
            train_dataset = DefaultPix2pixDataset(
                root_dir=self.configer.get('data', 'data_dir'),
                dataset='train',
                aug_transform=self.aug_train_transform,
                img_transform=self.img_transform,
                configer=self.configer)

        elif dataset_type == 'default_cyclegan':
            train_dataset = DefaultCycleGANDataset(
                root_dir=self.configer.get('data', 'data_dir'),
                dataset='train',
                aug_transform=self.aug_train_transform,
                img_transform=self.img_transform,
                configer=self.configer)

        elif dataset_type == 'default_facegan':
            train_dataset = DefaultFaceGANDataset(
                root_dir=self.configer.get('data', 'data_dir'),
                dataset='train',
                tag=self.configer.get('data', 'tag'),
                aug_transform=self.aug_train_transform,
                img_transform=self.img_transform,
                configer=self.configer)

        else:
            # Report the `dataset` value that failed the dispatch, not an
            # unrelated `train.loader` setting.
            Log.error('{} dataset is invalid.'.format(dataset_type))
            exit(1)

        trainloader = data.DataLoader(
            train_dataset,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=True,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=lambda *args: collate(*args,
                                             trans_dict=self.configer.get(
                                                 'train', 'data_transformer')))

        return trainloader