def get_trainloader(self):
        """Build the training DataLoader selected by the ``train.loader`` option.

        ``train.loader`` unset or ``'default'`` -> ``DefaultLoader``;
        ``'fasterrcnn'`` -> ``FasterRCNNLoader``. Any other value is logged via
        ``Log.error`` and the process exits with status 1.

        Returns:
            A ``data.DataLoader`` whose dataset is built with
            ``dataset='train'``; shuffled, ``pin_memory=True``, ``drop_last``
            taken from ``data.drop_last``.
        """
        import functools
        import sys

        if self.configer.get('train.loader', default=None) in [None, 'default']:
            loader_cls = DefaultLoader
        elif self.configer.get('train', 'loader') == 'fasterrcnn':
            loader_cls = FasterRCNNLoader
        else:
            Log.error('{} train loader is invalid.'.format(self.configer.get('train', 'loader')))
            # sys.exit rather than the site-module ``exit`` helper, which is
            # intended for the interactive shell and absent under ``python -S``.
            sys.exit(1)

        return data.DataLoader(
            loader_cls(root_dir=self.configer.get('data', 'data_dir'), dataset='train',
                       aug_transform=self.aug_train_transform,
                       img_transform=self.img_transform,
                       configer=self.configer),
            batch_size=self.configer.get('train', 'batch_size'), shuffle=True,
            num_workers=self.configer.get('data', 'workers'), pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            # A partial of ``collate`` (assumed to be module-level) is picklable,
            # unlike a lambda, so workers started with the 'spawn' method can
            # receive it. The transformer config is read once, here.
            collate_fn=functools.partial(
                collate, trans_dict=self.configer.get('train', 'data_transformer')))
Exemple #2
0
    def get_valloader(self, dataset=None):
        """Build the validation DataLoader selected by the ``val.loader`` option.

        Args:
            dataset: Value passed as ``dataset=`` to the dataset class;
                defaults to ``'val'`` when None.

        ``val.loader`` unset or ``'default'`` -> ``DefaultLoader``;
        ``'fasterrcnn'`` -> ``FasterRCNNLoader``. Any other value is logged via
        ``Log.error`` and the process exits with status 1.

        Returns:
            An unshuffled ``data.DataLoader`` with ``pin_memory=True``.
        """
        import functools
        import sys

        dataset = 'val' if dataset is None else dataset
        if self.configer.get('val.loader', default=None) in [None, 'default']:
            loader_cls = DefaultLoader
        elif self.configer.get('val', 'loader') == 'fasterrcnn':
            loader_cls = FasterRCNNLoader
        else:
            Log.error('{} val loader is invalid.'.format(self.configer.get('val', 'loader')))
            # sys.exit rather than the interactive-shell ``exit`` helper.
            sys.exit(1)

        return data.DataLoader(
            loader_cls(root_dir=self.configer.get('data', 'data_dir'), dataset=dataset,
                       aug_transform=self.aug_val_transform,
                       img_transform=self.img_transform,
                       configer=self.configer),
            batch_size=self.configer.get('val', 'batch_size'), shuffle=False,
            num_workers=self.configer.get('data', 'workers'), pin_memory=True,
            # Picklable (unlike a lambda) so 'spawn'-started workers work;
            # assumes ``collate`` is a module-level function.
            collate_fn=functools.partial(
                collate, trans_dict=self.configer.get('val', 'data_transformer')))
    def get_valloader(self, dataset=None):
        """Build the validation DataLoader selected by the ``val.loader`` option.

        Args:
            dataset: Value passed as ``dataset=`` to the dataset class;
                defaults to ``'val'`` when None.

        ``val.loader`` missing or ``'default'`` -> ``DefaultLoader``;
        ``'cyclegan'`` -> ``CycleGANLoader``; ``'facegan'`` -> ``FaceGANLoader``
        (which additionally receives ``tag`` from ``data.tag``). Any other value
        is logged via ``Log.error`` and the process exits with status 1.

        Returns:
            An unshuffled ``data.DataLoader`` with ``pin_memory=True``.
        """
        import functools
        import sys

        dataset = 'val' if dataset is None else dataset
        if not self.configer.exists('val', 'loader') or self.configer.get(
                'val', 'loader') == 'default':
            val_dataset = DefaultLoader(root_dir=self.configer.get('data', 'data_dir'),
                                        dataset=dataset,
                                        aug_transform=self.aug_val_transform,
                                        img_transform=self.img_transform,
                                        configer=self.configer)
        elif self.configer.get('val', 'loader') == 'cyclegan':
            val_dataset = CycleGANLoader(root_dir=self.configer.get('data', 'data_dir'),
                                         dataset=dataset,
                                         aug_transform=self.aug_val_transform,
                                         img_transform=self.img_transform,
                                         configer=self.configer)
        elif self.configer.get('val', 'loader') == 'facegan':
            val_dataset = FaceGANLoader(root_dir=self.configer.get('data', 'data_dir'),
                                        dataset=dataset,
                                        tag=self.configer.get('data', 'tag'),
                                        aug_transform=self.aug_val_transform,
                                        img_transform=self.img_transform,
                                        configer=self.configer)
        else:
            Log.error('{} val loader is invalid.'.format(
                self.configer.get('val', 'loader')))
            # sys.exit rather than the interactive-shell ``exit`` helper.
            sys.exit(1)

        return data.DataLoader(
            val_dataset,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            # Picklable (unlike a lambda) so 'spawn'-started workers work;
            # assumes ``collate`` is a module-level function.
            collate_fn=functools.partial(
                collate,
                trans_dict=self.configer.get('val', 'data_transformer')))
    def get_valloader(self, dataset=None):
        """Build the validation DataLoader, using a DistributedSampler when distributed.

        Args:
            dataset: Value passed as ``dataset=`` to ``DefaultLoader``;
                defaults to ``'val'`` when None.

        Only ``val.loader`` unset or ``'default'`` is supported; any other
        value is logged via ``Log.error`` and the process exits with status 1.

        Returns:
            An unshuffled ``data.DataLoader`` with ``pin_memory=True``. When
            ``network.distributed`` is truthy it uses a non-shuffling
            ``DistributedSampler``.
        """
        import functools
        import sys

        split = 'val' if dataset is None else dataset
        if self.configer.get('val.loader', default=None) not in [None, 'default']:
            Log.error('{} val loader is invalid.'.format(
                self.configer.get('val', 'loader')))
            # sys.exit rather than the interactive-shell ``exit`` helper.
            sys.exit(1)

        val_dataset = DefaultLoader(root_dir=self.configer.get('data', 'data_dir'),
                                    dataset=split,
                                    aug_transform=self.aug_val_transform,
                                    img_transform=self.img_transform,
                                    configer=self.configer)
        sampler = None
        if self.configer.get('network.distributed'):
            # DistributedSampler defaults to shuffle=True; disable it so the
            # sampler agrees with the DataLoader's shuffle=False for validation.
            # NOTE(review): the sampler still pads indices so every rank gets an
            # equal count, so a few samples may be evaluated twice.
            sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset, shuffle=False)

        return data.DataLoader(
            val_dataset,
            sampler=sampler,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            # Picklable (unlike a lambda) so 'spawn'-started workers work;
            # assumes ``collate`` is a module-level function.
            collate_fn=functools.partial(
                collate,
                trans_dict=self.configer.get('val', 'data_transformer')))
    def get_trainloader(self):
        """Build the training DataLoader, using a DistributedSampler when distributed.

        Only ``train.loader`` unset or ``'default'`` is supported; any other
        value is logged via ``Log.error`` and the process exits with status 1.

        Returns:
            A ``data.DataLoader`` over ``DefaultLoader(dataset='train', ...)``
            with ``pin_memory=True`` and ``drop_last`` from ``data.drop_last``.
            Without ``network.distributed`` the DataLoader shuffles itself;
            with it, ordering is left to the ``DistributedSampler`` (a
            DataLoader cannot combine ``sampler`` with ``shuffle=True``).
        """
        import functools
        import sys

        if self.configer.get('train.loader', default=None) not in [None, 'default']:
            Log.error('{} train loader is invalid.'.format(
                self.configer.get('train', 'loader')))
            # sys.exit rather than the interactive-shell ``exit`` helper.
            sys.exit(1)

        train_dataset = DefaultLoader(root_dir=self.configer.get('data', 'data_dir'),
                                      dataset='train',
                                      aug_transform=self.aug_train_transform,
                                      img_transform=self.img_transform,
                                      configer=self.configer)
        sampler = None
        if self.configer.get('network.distributed'):
            # NOTE(review): DistributedSampler only reshuffles between epochs
            # when set_epoch(epoch) is called on it; confirm the trainer does.
            sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset)

        return data.DataLoader(
            train_dataset,
            sampler=sampler,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=(sampler is None),
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            # Picklable (unlike a lambda) so 'spawn'-started workers work;
            # assumes ``collate`` is a module-level function.
            collate_fn=functools.partial(
                collate,
                trans_dict=self.configer.get('train', 'data_transformer')))
    def get_valloader(self):
        """Build the validation DataLoader (with a label transform) for ``'val'``.

        Only ``val.loader`` unset or ``'default'`` is supported; any other
        value is logged via ``Log.error`` and the process exits with status 1.
        Progress is reported through ``Log.info`` before building the dataset,
        before building the DataLoader, and when done.

        Returns:
            An unshuffled ``data.DataLoader`` with ``pin_memory=False``.
        """
        import functools
        import sys

        if self.configer.get('val.loader', default=None) not in [None, 'default']:
            Log.error('{} val loader is invalid.'.format(
                self.configer.get('val', 'loader')))
            # sys.exit rather than the interactive-shell ``exit`` helper.
            sys.exit(1)

        Log.info('Get val dataloader start')
        val_dataset = DefaultLoader(root_dir=self.configer.get('data', 'data_dir'),
                                    dataset='val',
                                    aug_transform=self.aug_val_transform,
                                    img_transform=self.img_transform,
                                    label_transform=self.label_transform,
                                    configer=self.configer)
        Log.info('Get dataloader')
        valloader = data.DataLoader(
            val_dataset,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=False,
            # Picklable (unlike a lambda) so 'spawn'-started workers work;
            # assumes ``collate`` is a module-level function.
            collate_fn=functools.partial(
                collate,
                trans_dict=self.configer.get('val', 'data_transformer')))
        Log.info('Get val dataloader end')
        return valloader
Exemple #7
0
    def get_trainloader(self):
        """Build the training DataLoader selected by the ``train.loader`` option.

        ``train.loader`` missing or ``'default'`` -> ``DefaultLoader``;
        ``'openpose'`` -> ``OpenPoseLoader``. Any other value is logged via
        ``Log.error`` and the process exits with status 1.

        Returns:
            A shuffled ``data.DataLoader`` whose dataset is built with
            ``dataset='train'``; ``pin_memory=True``, ``drop_last`` from
            ``data.drop_last``.
        """
        import functools
        import sys

        if not self.configer.exists('train', 'loader') or self.configer.get(
                'train', 'loader') == 'default':
            loader_cls = DefaultLoader
        elif self.configer.get('train', 'loader') == 'openpose':
            loader_cls = OpenPoseLoader
        else:
            Log.error('{} train loader is invalid.'.format(
                self.configer.get('train', 'loader')))
            # sys.exit rather than the interactive-shell ``exit`` helper.
            sys.exit(1)

        return data.DataLoader(
            loader_cls(root_dir=self.configer.get('data', 'data_dir'),
                       dataset='train',
                       aug_transform=self.aug_train_transform,
                       img_transform=self.img_transform,
                       configer=self.configer),
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=True,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            # Picklable (unlike a lambda) so 'spawn'-started workers work;
            # assumes ``collate`` is a module-level function.
            collate_fn=functools.partial(
                collate,
                trans_dict=self.configer.get('train', 'data_transformer')))
Exemple #8
0
    def get_trainloader(self):
        """Build the training DataLoader, using ``ADE20KLoader`` when requested.

        ``ADE20KLoader`` is used when ``train.loader`` exists and equals
        ``'ade20k'``; in every other case (including unrecognised values) the
        dataset falls back to ``DefaultLoader``. Both receive the label
        transform.

        Returns:
            A shuffled ``data.DataLoader`` whose dataset is built with
            ``dataset='train'``; ``pin_memory=True``, ``drop_last`` from
            ``data.drop_last``.
        """
        import functools

        if self.configer.exists('train', 'loader') and self.configer.get(
                'train', 'loader') == 'ade20k':
            loader_cls = ADE20KLoader
        else:
            loader_cls = DefaultLoader

        return data.DataLoader(
            loader_cls(root_dir=self.configer.get('data', 'data_dir'),
                       dataset='train',
                       aug_transform=self.aug_train_transform,
                       img_transform=self.img_transform,
                       label_transform=self.label_transform,
                       configer=self.configer),
            batch_size=self.configer.get('train', 'batch_size'),
            pin_memory=True,
            num_workers=self.configer.get('data', 'workers'),
            shuffle=True,
            drop_last=self.configer.get('data', 'drop_last'),
            # Picklable (unlike a lambda) so 'spawn'-started workers work;
            # assumes ``collate`` is a module-level function.
            collate_fn=functools.partial(
                collate,
                trans_dict=self.configer.get('train', 'data_transformer')))
Exemple #9
0
    def get_testloader(self, test_dir=None, list_path=None):
        """Build the test DataLoader.

        Args:
            test_dir: Passed to ``DefaultLoader``; falls back to the
                ``test.data_dir`` config value when None.
            list_path: Passed through to ``DefaultLoader`` unchanged.

        Only ``test.loader`` missing or ``'default'`` is supported; any other
        value is logged via ``Log.error`` and the process exits with status 1.

        Returns:
            An unshuffled ``data.DataLoader`` with ``pin_memory=True``.
        """
        import functools
        import sys

        test_dir = test_dir if test_dir is not None else self.configer.get('test', 'data_dir')
        if self.configer.exists('test', 'loader') and self.configer.get('test', 'loader') != 'default':
            # Report the offending *test* loader setting, not ``train.loader``.
            Log.error('{} test loader is invalid.'.format(self.configer.get('test', 'loader')))
            # sys.exit rather than the interactive-shell ``exit`` helper.
            sys.exit(1)

        testloader = data.DataLoader(
            DefaultLoader(test_dir=test_dir,
                          list_path=list_path,
                          img_transform=self.img_transform,
                          configer=self.configer),
            batch_size=self.configer.get('test', 'batch_size'), shuffle=False,
            num_workers=self.configer.get('data', 'workers'), pin_memory=True,
            # Picklable (unlike a lambda) so 'spawn'-started workers work;
            # assumes ``collate`` is a module-level function.
            collate_fn=functools.partial(
                collate, trans_dict=self.configer.get('test', 'data_transformer')))

        return testloader
Exemple #10
0
    def get_valloader(self, dataset=None):
        """Build the validation DataLoader for the configured ``method``.

        Args:
            dataset: Value passed as ``dataset=`` to ``DefaultLoader``;
                defaults to ``'val'`` when None.

        Returns:
            An unshuffled ``data.DataLoader`` (with label transform,
            ``pin_memory=True``) when ``method`` is ``'fcn_segmentor'``.
            For any other method an error is logged via ``Log.error`` and
            None is returned.
        """
        import functools

        dataset = 'val' if dataset is None else dataset
        if self.configer.get('method') != 'fcn_segmentor':
            Log.error('Method: {} loader is invalid.'.format(
                self.configer.get('method')))
            return None

        return data.DataLoader(
            DefaultLoader(root_dir=self.configer.get('data', 'data_dir'),
                          dataset=dataset,
                          aug_transform=self.aug_val_transform,
                          img_transform=self.img_transform,
                          label_transform=self.label_transform,
                          configer=self.configer),
            batch_size=self.configer.get('val', 'batch_size'),
            pin_memory=True,
            num_workers=self.configer.get('data', 'workers'),
            shuffle=False,
            # Picklable (unlike a lambda) so 'spawn'-started workers work;
            # assumes ``collate`` is a module-level function.
            collate_fn=functools.partial(
                collate,
                trans_dict=self.configer.get('val', 'data_transformer')))
Exemple #11
0
    def get_testloader(self, test_dir=None, list_path=None, json_path=None):
        """Build the test DataLoader selected by the ``test.loader`` option.

        Args:
            test_dir: Used by the default loader; falls back to the
                ``test.test_dir`` config value when None.
            list_path: Used by the ``'list'`` loader; falls back to
                ``test.list_path`` when None.
            json_path: Used by the ``'json'`` and ``'facegan'`` loaders; falls
                back to ``test.json_path`` when None.

        ``test.loader`` missing or ``'default'`` -> ``DefaultLoader``;
        ``'list'`` -> ``ListLoader``; ``'json'`` -> ``JsonLoader``;
        ``'facegan'`` -> ``FaceGANLoader``. The non-default loaders take
        ``root_dir`` from ``test.root_dir``. Any other value is logged via
        ``Log.error`` and the process exits with status 1.

        Returns:
            An unshuffled ``data.DataLoader`` with ``pin_memory=True``.
        """
        import functools
        import sys

        if not self.configer.exists('test', 'loader') or self.configer.get(
                'test', 'loader') == 'default':
            test_dir = test_dir if test_dir is not None else self.configer.get(
                'test', 'test_dir')
            test_dataset = DefaultLoader(test_dir=test_dir,
                                         aug_transform=self.aug_test_transform,
                                         img_transform=self.img_transform,
                                         configer=self.configer)
        elif self.configer.get('test', 'loader') == 'list':
            list_path = list_path if list_path is not None else self.configer.get(
                'test', 'list_path')
            test_dataset = ListLoader(root_dir=self.configer.get('test', 'root_dir'),
                                      list_path=list_path,
                                      aug_transform=self.aug_test_transform,
                                      img_transform=self.img_transform,
                                      configer=self.configer)
        elif self.configer.get('test', 'loader') == 'json':
            json_path = json_path if json_path is not None else self.configer.get(
                'test', 'json_path')
            test_dataset = JsonLoader(root_dir=self.configer.get('test', 'root_dir'),
                                      json_path=json_path,
                                      aug_transform=self.aug_test_transform,
                                      img_transform=self.img_transform,
                                      configer=self.configer)
        elif self.configer.get('test', 'loader') == 'facegan':
            json_path = json_path if json_path is not None else self.configer.get(
                'test', 'json_path')
            test_dataset = FaceGANLoader(root_dir=self.configer.get('test', 'root_dir'),
                                         json_path=json_path,
                                         aug_transform=self.aug_test_transform,
                                         img_transform=self.img_transform,
                                         configer=self.configer)
        else:
            Log.error('{} test loader is invalid.'.format(
                self.configer.get('test', 'loader')))
            # sys.exit rather than the interactive-shell ``exit`` helper.
            sys.exit(1)

        return data.DataLoader(
            test_dataset,
            batch_size=self.configer.get('test', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            # Picklable (unlike a lambda) so 'spawn'-started workers work;
            # assumes ``collate`` is a module-level function.
            collate_fn=functools.partial(
                collate,
                trans_dict=self.configer.get('test', 'data_transformer')))