def get_valloader(self):
    """Build the validation ``DataLoader`` for the configured detection method.

    The dataset class is chosen from ``configer.get('method')``:
    ``'single_shot_detector'`` -> ``SSDDataLoader`` and
    ``'faster_rcnn'`` -> ``FRDataLoader``. The dataset is rooted at
    ``<data.data_dir>/val`` and uses the validation augmentation. The loader
    does not shuffle, and it reads its batch size from ``data.val_batch_size``
    and its worker count from ``data.workers``.

    Returns:
        The validation ``DataLoader``, or ``None`` after logging an error when
        the configured method has no validation loader here.
    """
    # NOTE(review): unlike get_trainloader, there is no 'yolov3' entry and no
    # custom collate_fn. Confirm whether default collation copes with
    # per-image variable-length bboxes for these datasets.
    dataset_by_method = {
        'single_shot_detector': SSDDataLoader,
        'faster_rcnn': FRDataLoader,
    }
    method = self.configer.get('method')
    if method not in dataset_by_method:
        Log.error('Method: {} loader is invalid.'.format(method))
        return None

    dataset_cls = dataset_by_method[method]
    val_set = dataset_cls(
        root_dir=os.path.join(self.configer.get('data', 'data_dir'), 'val'),
        aug_transform=self.aug_val_transform,
        img_transform=self.img_transform,
        configer=self.configer)
    return data.DataLoader(
        val_set,
        batch_size=self.configer.get('data', 'val_batch_size'),
        shuffle=False,
        num_workers=self.configer.get('data', 'workers'),
        pin_memory=True)
def get_trainloader(self):
    """Build the training ``DataLoader`` for the configured detection method.

    The dataset class and the keys handed to ``CollateFunctions.our_collate``
    are chosen from ``configer.get('method')``:

    * ``'single_shot_detector'`` -> ``SSDDataLoader``, keys
      ``['img', 'bboxes', 'labels']``
    * ``'faster_rcnn'`` -> ``FRDataLoader``, keys
      ``['img', 'imgscale', 'bboxes', 'labels']``
    * ``'yolov3'`` -> ``YOLODataLoader``, keys ``['img', 'bboxes', 'labels']``

    The dataset is rooted at ``<data.data_dir>/train`` and uses the training
    augmentation. The loader shuffles, and it reads its batch size from
    ``train.batch_size`` and its worker count from ``data.workers``.

    Returns:
        The training ``DataLoader``, or ``None`` after logging an error when
        the configured method is not one of the above.
    """
    # Local import: the file's top-level import block is outside this view.
    import functools

    # method -> (dataset class, data_keys forwarded to the collate function)
    loader_specs = {
        'single_shot_detector': (SSDDataLoader, ['img', 'bboxes', 'labels']),
        'faster_rcnn': (FRDataLoader,
                        ['img', 'imgscale', 'bboxes', 'labels']),
        'yolov3': (YOLODataLoader, ['img', 'bboxes', 'labels']),
    }
    method = self.configer.get('method')
    if method not in loader_specs:
        Log.error('Method: {} loader is invalid.'.format(method))
        return None

    dataset_cls, data_keys = loader_specs[method]
    train_set = dataset_cls(
        root_dir=os.path.join(self.configer.get('data', 'data_dir'), 'train'),
        aug_transform=self.aug_train_transform,
        img_transform=self.img_transform,
        configer=self.configer)

    # functools.partial instead of a lambda. With num_workers > 0 and the
    # 'spawn' start method (the Windows/macOS default), collate_fn must be
    # pickled into each worker, and lambdas cannot be pickled. This assumes
    # CollateFunctions.our_collate and the configer are picklable; confirm.
    # The batch is still passed positionally, as the lambda did. trans_dict is
    # now read once here instead of on every batch.
    collate_fn = functools.partial(
        CollateFunctions.our_collate,
        data_keys=data_keys,
        configer=self.configer,
        trans_dict=self.configer.get('train', 'data_transformer'))

    return data.DataLoader(
        train_set,
        batch_size=self.configer.get('train', 'batch_size'),
        shuffle=True,
        num_workers=self.configer.get('data', 'workers'),
        pin_memory=True,
        collate_fn=collate_fn)