示例#1
0
    def get_valloader(self, dataset=None):
        """Build the validation ``DataLoader`` selected by the ``val.loader`` config.

        Args:
            dataset: Value forwarded to the dataset class as its ``dataset``
                argument. ``None`` selects ``'val'``.

        Returns:
            A ``data.DataLoader`` wrapping ``DefaultLoader`` (config value
            missing or ``'default'``) or ``FasterRCNNLoader``
            (``'fasterrcnn'``), created with ``shuffle=False`` and
            ``pin_memory=True``.

        Raises:
            SystemExit: With status 1, after logging an error, when the
                configured loader name is not one of the supported values.
        """
        from functools import partial

        dataset = 'val' if dataset is None else dataset

        # Read the loader name once so that branch selection and the error
        # message are guaranteed to refer to the same value.
        loader_name = self.configer.get('val.loader', default=None)
        if loader_name in (None, 'default'):
            loader_cls = DefaultLoader
        elif loader_name == 'fasterrcnn':
            loader_cls = FasterRCNNLoader
        else:
            Log.error('{} val loader is invalid.'.format(loader_name))
            # `exit` is injected by the `site` module for interactive use;
            # raising SystemExit directly has the same effect for callers.
            raise SystemExit(1)

        val_dataset = loader_cls(
            root_dir=self.configer.get('data', 'data_dir'),
            dataset=dataset,
            aug_transform=self.aug_val_transform,
            img_transform=self.img_transform,
            configer=self.configer,
        )

        # A lambda defined inside a method cannot be pickled, which breaks
        # DataLoader workers started with the 'spawn' method. A partial stays
        # picklable as long as `collate` is importable by name (assumed to be
        # a module-level function -- confirm). The transformer config is
        # consequently read once here instead of on every batch.
        collate_fn = partial(
            collate, trans_dict=self.configer.get('val', 'data_transformer')
        )

        return data.DataLoader(
            val_dataset,
            batch_size=self.configer.get('val', 'batch_size'),
            shuffle=False,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=collate_fn,
        )
    def get_trainloader(self):
        """Build the training ``DataLoader`` selected by the ``train.loader`` config.

        Returns:
            A ``data.DataLoader`` wrapping ``DefaultLoader`` (config value
            missing or ``'default'``) or ``FasterRCNNLoader``
            (``'fasterrcnn'``) over the ``'train'`` dataset, created with
            ``shuffle=True``, ``pin_memory=True`` and ``drop_last`` taken
            from the ``data.drop_last`` config entry.

        Raises:
            SystemExit: With status 1, after logging an error, when the
                configured loader name is not one of the supported values.
        """
        from functools import partial

        # Read the loader name once so that branch selection and the error
        # message are guaranteed to refer to the same value.
        loader_name = self.configer.get('train.loader', default=None)
        if loader_name in (None, 'default'):
            loader_cls = DefaultLoader
        elif loader_name == 'fasterrcnn':
            loader_cls = FasterRCNNLoader
        else:
            Log.error('{} train loader is invalid.'.format(loader_name))
            # `exit` is injected by the `site` module for interactive use;
            # raising SystemExit directly has the same effect for callers.
            raise SystemExit(1)

        train_dataset = loader_cls(
            root_dir=self.configer.get('data', 'data_dir'),
            dataset='train',
            aug_transform=self.aug_train_transform,
            img_transform=self.img_transform,
            configer=self.configer,
        )

        # A lambda defined inside a method cannot be pickled, which breaks
        # DataLoader workers started with the 'spawn' method. A partial stays
        # picklable as long as `collate` is importable by name (assumed to be
        # a module-level function -- confirm). The transformer config is
        # consequently read once here instead of on every batch.
        collate_fn = partial(
            collate, trans_dict=self.configer.get('train', 'data_transformer')
        )

        return data.DataLoader(
            train_dataset,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=True,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            drop_last=self.configer.get('data', 'drop_last'),
            collate_fn=collate_fn,
        )