Esempio n. 1
0
    def get_valloader(self):
        """Build the validation DataLoader for the configured pose method.

        Samples are read from ``os.path.join(<data.data_dir>, 'val')``, are not
        shuffled, and are batched with ``CollateFunctions.default_collate``.
        Both supported methods use the same loader settings: ``val.batch_size``,
        ``data.workers`` worker processes and ``pin_memory=True``.

        Returns:
            A ``data.DataLoader`` when the ``method`` config value is
            ``'conv_pose_machine'`` or ``'open_pose'``; otherwise an error is
            logged via ``Log.error`` and ``None`` is returned.
        """
        if self.configer.get('method') == 'conv_pose_machine':
            valloader = data.DataLoader(
                CPMDataLoader(root_dir=os.path.join(
                    self.configer.get('data', 'data_dir'), 'val'),
                              aug_transform=self.aug_val_transform,
                              img_transform=self.img_transform,
                              configer=self.configer),
                batch_size=self.configer.get('val', 'batch_size'),
                shuffle=False,
                num_workers=self.configer.get('data', 'workers'),
                pin_memory=True,
                collate_fn=lambda *args: CollateFunctions.default_collate(
                    *args, data_keys=['img', 'heatmap']))

            return valloader

        elif self.configer.get('method') == 'open_pose':
            # Use the same worker / pinned-memory settings as the
            # conv_pose_machine branch; without them DataLoader falls back to
            # num_workers=0, i.e. all loading happens in the main process.
            valloader = data.DataLoader(
                OPDataLoader(root_dir=os.path.join(
                    self.configer.get('data', 'data_dir'), 'val'),
                             aug_transform=self.aug_val_transform,
                             img_transform=self.img_transform,
                             configer=self.configer),
                batch_size=self.configer.get('val', 'batch_size'),
                shuffle=False,
                num_workers=self.configer.get('data', 'workers'),
                pin_memory=True,
                collate_fn=lambda *args: CollateFunctions.default_collate(
                    *args, data_keys=['img', 'maskmap', 'heatmap', 'vecmap']))

            return valloader

        else:
            Log.error('Method: {} loader is invalid.'.format(
                self.configer.get('method')))
            return None
Esempio n. 2
0
    def get_trainloader(self):
        """Create the shuffled training DataLoader for the active pose method.

        The dataset class is chosen from the ``method`` config value and reads
        from ``os.path.join(<data.data_dir>, 'train')``. Batches are assembled
        with ``CollateFunctions.our_collate`` using the
        ``train.data_transformer`` settings.

        Returns:
            A ``data.DataLoader`` for ``'conv_pose_machine'`` / ``'open_pose'``;
            for any other method an error is logged and ``None`` is returned.
        """
        pose_method = self.configer.get('method')
        if pose_method == 'conv_pose_machine':
            dataset_cls, keys = CPMDataLoader, ('img', 'kpts')
        elif pose_method == 'open_pose':
            dataset_cls, keys = OPDataLoader, ('img', 'maskmap', 'kpts')
        else:
            Log.error('Method: {} loader is invalid.'.format(pose_method))
            return None

        train_dir = os.path.join(self.configer.get('data', 'data_dir'), 'train')
        train_set = dataset_cls(root_dir=train_dir,
                                aug_transform=self.aug_train_transform,
                                img_transform=self.img_transform,
                                configer=self.configer)

        def collate(*batch):
            # The transformer settings are looked up each time a batch is
            # collated, not once when the loader is built.
            return CollateFunctions.our_collate(
                *batch,
                data_keys=list(keys),
                configer=self.configer,
                trans_dict=self.configer.get('train', 'data_transformer'))

        return data.DataLoader(
            train_set,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=True,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=collate)
Esempio n. 3
0
    def get_trainloader(self):
        """Create the shuffled training DataLoader for segmentation methods.

        ``'fcn_segmentor'`` uses ``FSDataLoader`` (with a label transform) and
        drops the last incomplete batch. ``'mask_rcnn'`` uses ``MRDataLoader``
        and keeps it. Both read from ``os.path.join(<data.data_dir>, 'train')``
        and collate with ``CollateFunctions.our_collate`` using
        ``train.data_transformer``.

        Returns:
            A ``data.DataLoader`` for a supported method; otherwise an error is
            logged and ``None`` is returned.
        """
        seg_method = self.configer.get('method')
        if seg_method == 'fcn_segmentor':
            train_set = FSDataLoader(
                root_dir=os.path.join(self.configer.get('data', 'data_dir'), 'train'),
                aug_transform=self.aug_train_transform,
                img_transform=self.img_transform,
                label_transform=self.label_transform,
                configer=self.configer)
            keys = ('img', 'labelmap')
            drop_last = True
        elif seg_method == 'mask_rcnn':
            train_set = MRDataLoader(
                root_dir=os.path.join(self.configer.get('data', 'data_dir'), 'train'),
                aug_transform=self.aug_train_transform,
                img_transform=self.img_transform,
                configer=self.configer)
            keys = ('img', 'bboxes', 'labels', 'polygons')
            drop_last = False  # DataLoader's default
        else:
            Log.error('Method: {} loader is invalid.'.format(seg_method))
            return None

        def collate(*batch):
            # Looks up train.data_transformer on every batch, not once up front.
            return CollateFunctions.our_collate(
                *batch,
                data_keys=list(keys),
                configer=self.configer,
                trans_dict=self.configer.get('train', 'data_transformer'))

        # NOTE(review): unlike other loaders in this project, num_workers and
        # pin_memory are not passed here, so DataLoader defaults apply
        # (num_workers=0). Confirm whether that is intended.
        return data.DataLoader(
            train_set,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=True,
            drop_last=drop_last,
            collate_fn=collate)
Esempio n. 4
0
    def get_valloader(self):
        """Build the validation DataLoader for the ``fc_classifier`` method.

        Reads from ``os.path.join(<data.data_dir>, 'val')`` without shuffling
        and collates ``img``/``label`` samples with
        ``CollateFunctions.our_collate`` using ``val.data_transformer``.

        Returns:
            A ``data.DataLoader``, or ``None`` (after ``Log.error``) when the
            configured method is not ``'fc_classifier'``.
        """
        if self.configer.get('method') != 'fc_classifier':
            Log.error('Method: {} loader is invalid.'.format(self.configer.get('method')))
            return None

        val_dir = os.path.join(self.configer.get('data', 'data_dir'), 'val')
        val_set = FCDataLoader(root_dir=val_dir,
                               aug_transform=self.aug_val_transform,
                               img_transform=self.img_transform,
                               configer=self.configer)

        def collate(*batch):
            # NOTE(review): no configer= is passed to our_collate here, unlike
            # other call sites; confirm the collate function does not need it.
            return CollateFunctions.our_collate(
                *batch,
                data_keys=['img', 'label'],
                trans_dict=self.configer.get('val', 'data_transformer'))

        return data.DataLoader(val_set,
                               batch_size=self.configer.get('val', 'batch_size'),
                               shuffle=False,
                               num_workers=self.configer.get('data', 'workers'),
                               pin_memory=True,
                               collate_fn=collate)
Esempio n. 5
0
    def get_trainloader(self):
        """Create the shuffled training DataLoader for detection methods.

        Supported ``method`` values and their dataset / collated keys:

        * ``'single_shot_detector'`` -> ``SSDDataLoader``; img, bboxes, labels
        * ``'faster_rcnn'`` -> ``FRDataLoader``; img, imgscale, bboxes, labels
        * ``'yolov3'`` -> ``YOLODataLoader``; img, bboxes, labels

        All read from ``os.path.join(<data.data_dir>, 'train')`` and share the
        same loader settings (``train.batch_size``, ``data.workers``,
        ``pin_memory=True``, ``our_collate`` with ``train.data_transformer``).

        Returns:
            A ``data.DataLoader`` for a supported method; otherwise an error is
            logged and ``None`` is returned.
        """
        det_method = self.configer.get('method')
        if det_method == 'single_shot_detector':
            dataset_cls, keys = SSDDataLoader, ('img', 'bboxes', 'labels')
        elif det_method == 'faster_rcnn':
            dataset_cls, keys = FRDataLoader, ('img', 'imgscale', 'bboxes', 'labels')
        elif det_method == 'yolov3':
            dataset_cls, keys = YOLODataLoader, ('img', 'bboxes', 'labels')
        else:
            Log.error('Method: {} loader is invalid.'.format(det_method))
            return None

        train_set = dataset_cls(
            root_dir=os.path.join(self.configer.get('data', 'data_dir'), 'train'),
            aug_transform=self.aug_train_transform,
            img_transform=self.img_transform,
            configer=self.configer)

        def collate(*batch):
            # Transformer settings are resolved per batch, as the lambda did.
            return CollateFunctions.our_collate(
                *batch,
                data_keys=list(keys),
                configer=self.configer,
                trans_dict=self.configer.get('train', 'data_transformer'))

        return data.DataLoader(
            train_set,
            batch_size=self.configer.get('train', 'batch_size'),
            shuffle=True,
            num_workers=self.configer.get('data', 'workers'),
            pin_memory=True,
            collate_fn=collate)