def get_dataloader(module_name, module_args, num_label):
    """Build the train and validation DataLoaders from a config dict.

    Args:
        module_name: name of the dataset module passed through to ``get_dataset``.
        module_args: config dict with ``'dataset'`` and ``'loader'`` sections.
            NOTE: ``module_args['dataset']`` is mutated in place (paths are
            popped out, ``num_label`` is injected).
        num_label: number of labels, forwarded to each dataset via
            ``dataset_args['num_label']``.

    Returns:
        (train_loader, val_loader) — ``val_loader`` is ``None`` when the
        validation split produced no samples.

    Raises:
        Exception: if no training images were found.
    """
    # Training gets color-jitter augmentation; validation only tensor conversion.
    train_transfroms = transforms.Compose(
        [transforms.RandomColorJitter(brightness=0.5), transforms.ToTensor()])
    val_transfroms = transforms.ToTensor()
    # Create the datasets.
    dataset_args = module_args['dataset']
    train_data_path = dataset_args.pop('train_data_path')
    train_data_ratio = dataset_args.pop('train_data_ratio')
    val_data_path = dataset_args.pop('val_data_path')
    train_data_list, val_data_list = get_datalist(
        train_data_path, val_data_path,
        module_args['loader']['validation_split'])
    dataset_args['num_label'] = num_label
    train_dataset_list = [
        get_dataset(data_list=train_data, module_name=module_name,
                    phase='train', dataset_args=dataset_args)
        for train_data in train_data_list
    ]
    if len(train_dataset_list) > 1:
        # Multiple sources: sample from each according to train_data_ratio.
        train_loader = dataset.Batch_Balanced_Dataset(
            dataset_list=train_dataset_list,
            ratio_list=train_data_ratio,
            module_args=module_args,
            dataset_transfroms=train_transfroms,
            phase='train')
    elif len(train_dataset_list) == 1:
        # BUG FIX: this branch previously applied val_transfroms to the
        # training data, silently dropping the RandomColorJitter augmentation.
        train_loader = DataLoader(
            dataset=train_dataset_list[0].transform_first(train_transfroms),
            batch_size=module_args['loader']['train_batch_size'],
            shuffle=module_args['loader']['shuffle'],
            last_batch='rollover',
            num_workers=module_args['loader']['num_workers'])
        train_loader.dataset_len = len(train_dataset_list[0])
    else:
        raise Exception('no images found')
    if len(val_data_list):
        val_dataset = get_dataset(data_list=val_data_list,
                                  module_name=module_name,
                                  phase='test',
                                  dataset_args=dataset_args)
        val_loader = DataLoader(
            dataset=val_dataset.transform_first(val_transfroms),
            batch_size=module_args['loader']['val_batch_size'],
            shuffle=module_args['loader']['shuffle'],
            last_batch='rollover',
            num_workers=module_args['loader']['num_workers'])
        val_loader.dataset_len = len(val_dataset)
    else:
        val_loader = None
    return train_loader, val_loader
def get_dataloader(module_config, num_label, alphabet):
    """Build a DataLoader from a dataloader config section.

    Args:
        module_config: dict with keys ``'dataset'`` (``type`` + ``args``) and
            ``'loader'`` (kwargs forwarded to ``DataLoader``), or ``None``.
            The config is deep-copied, so the caller's dict is not mutated.
        num_label: number of labels, injected into the dataset args.
        alphabet: label alphabet, injected into the dataset args.

    Returns:
        A ``DataLoader`` (or ``Batch_Balanced_Dataset`` when several data
        paths with matching ratios are configured), or ``None`` when
        ``module_config`` is ``None``.
    """
    if module_config is None:
        return None
    config = copy.deepcopy(module_config)
    dataset_args = config['dataset']['args']
    dataset_args['num_label'] = num_label
    dataset_args['alphabet'] = alphabet
    # Optional per-dataset image transforms.
    if 'transforms' in dataset_args:
        img_transfroms = get_transforms(dataset_args.pop('transforms'))
    else:
        img_transfroms = None
    # Create one dataset per configured data path.
    dataset_name = config['dataset']['type']
    data_path_list = dataset_args.pop('data_path')
    # Default to a single ratio when none is configured.
    data_ratio = dataset_args.pop('data_ratio', [1.0])
    _dataset_list = [
        get_dataset(data_path=data_path, module_name=dataset_name,
                    dataset_args=dataset_args)
        for data_path in data_path_list
    ]
    # BUG FIX: the second condition previously read
    # dataset_args['data_ratio'], which was already popped above and thus
    # raised KeyError whenever len(data_ratio) > 1; compare the local value.
    if len(data_ratio) > 1 and len(data_ratio) == len(_dataset_list):
        from . import dataset
        loader = dataset.Batch_Balanced_Dataset(
            dataset_list=_dataset_list,
            ratio_list=data_ratio,
            loader_args=config['loader'],
            dataset_transfroms=img_transfroms,
            phase='train')
    else:
        _dataset = _dataset_list[0]
        loader = DataLoader(dataset=_dataset.transform_first(img_transfroms),
                            **config['loader'])
        loader.dataset_len = len(_dataset)
    return loader