def LoadData(MNISTsaveFolder, imsize=[54, 98], train=True, batch_size=32,
             num_works=0, DataSetName=MNIST):  # original image size is [28,28]
    """Build a shuffled DataLoader over one or more torchvision-style datasets.

    Parameters
    ----------
    MNISTsaveFolder : root directory handed to each dataset constructor.
    imsize : target [H, W] forwarded to the transform factory ``trasnFcn``.
    train : True selects the train split, False the test split.
    batch_size : samples per batch.
    num_works : number of DataLoader worker processes.
    DataSetName : a single dataset class, or a tuple of classes whose
        instances are concatenated with ``ConcatDataset``.

    Returns
    -------
    torch.utils.data.DataLoader with ``shuffle=True`` and ``drop_last=True``.
    """
    # Normalization constants shared by every dataset built here.
    datamean = 0.5
    datastd = 0.5
    Trans = trasnFcn(imsize, datamean=datamean, datastd=datastd)

    def _build(dataset_cls):
        # One constructor call for both splits; previously the test-split
        # single-dataset path omitted download=True, so it failed whenever
        # the data was not already on disk. Now every path downloads.
        return dataset_cls(root=MNISTsaveFolder, train=train,
                           transform=Trans, download=True)

    if isinstance(DataSetName, tuple):
        data_set = ConcatDataset([_build(sub) for sub in DataSetName])
    else:
        data_set = _build(DataSetName)

    dataLoader = DataLoader(dataset=data_set, batch_size=batch_size,
                            shuffle=True, num_workers=num_works,
                            drop_last=True)
    return dataLoader
def get_train_loader(conf):
    """Build the training DataLoader selected by ``conf.data_mode``.

    Supported modes: 'ms1m', 'vgg', 'concat' (ms1m+vgg with vgg labels
    offset), 'emore', one of the four race folders ('African', 'Caucasian',
    'Asian', 'Indian'), or 'ccf' (all race sub-folders concatenated with a
    class-balanced WeightedRandomSampler).

    Returns
    -------
    (loader, class_num) : the DataLoader and the total number of classes.
    """
    if conf.data_mode in ['ms1m', 'concat']:
        ms1m_ds, ms1m_class_num = get_train_dataset(conf.ms1m_folder/'imgs')
        print('ms1m loader generated')
    if conf.data_mode in ['vgg', 'concat']:
        vgg_ds, vgg_class_num = get_train_dataset(conf.vgg_folder/'imgs')
        print('vgg loader generated')

    if conf.data_mode == 'vgg':
        ds = vgg_ds
        class_num = vgg_class_num
    elif conf.data_mode == 'ms1m':
        ds = ms1m_ds
        class_num = ms1m_class_num
    elif conf.data_mode == 'concat':
        # Shift vgg labels past the ms1m label range before concatenating,
        # so the combined dataset has a single contiguous label space.
        for i, (url, label) in enumerate(vgg_ds.imgs):
            vgg_ds.imgs[i] = (url, label + ms1m_class_num)
        ds = ConcatDataset([ms1m_ds, vgg_ds])
        class_num = vgg_class_num + ms1m_class_num
    elif conf.data_mode == 'emore':
        ds, class_num = get_train_dataset(conf.emore_folder/'imgs')
    elif conf.data_mode in ('African', 'Caucasian', 'Asian', 'Indian'):
        # Was four copy-pasted branches; the folder name equals the mode.
        ds, class_num = get_train_dataset(conf.ccf_folder/conf.data_mode)
    elif conf.data_mode == 'ccf':
        ds = []
        class_num = []
        imgs_num = []
        for path in conf.ccf_folder.iterdir():
            if path.is_file():
                continue
            ds_tmp, class_num_tmp = get_train_dataset(path)
            ds.append(ds_tmp)
            class_num.append(class_num_tmp)
            imgs_num.append(len(ds_tmp))
        # Offset each sub-dataset's labels by the number of classes in all
        # preceding sub-datasets (hoisted the j>0 test and the prefix sum
        # out of the inner loop).
        for j, sub_ds in enumerate(ds):
            if j > 0:
                offset = sum(class_num[:j])
                for i, (url, label) in enumerate(sub_ds.imgs):
                    sub_ds.imgs[i] = (url, label + offset)
        ds = ConcatDataset(ds)

    if isinstance(class_num, list):
        # BUG FIX: this sampler block previously ran for EVERY data_mode and
        # crashed outside 'ccf' (imgs_num undefined, class_num an int); it
        # also hard-coded range(4) instead of the actual sub-dataset count.
        conf.race_num = class_num
        total_classes = sum(class_num)
        weights = []
        for i in range(len(class_num)):
            # Each image is weighted inversely to its sub-dataset's class
            # count, balancing sampling across the race sub-datasets.
            weights += [total_classes // class_num[i]] * imgs_num[i]
        assert len(ds) == len(weights)
        weights = torch.FloatTensor(weights)
        train_sampler = WeightedRandomSampler(weights, len(ds),
                                              replacement=True)
        loader = DataLoader(ds, batch_size=conf.batch_size,
                            sampler=train_sampler,
                            pin_memory=conf.pin_memory,
                            num_workers=conf.num_workers)
        class_num = total_classes
    else:
        # Non-'ccf' modes: plain shuffled loading (sampler and shuffle are
        # mutually exclusive in DataLoader).
        loader = DataLoader(ds, batch_size=conf.batch_size, shuffle=True,
                            pin_memory=conf.pin_memory,
                            num_workers=conf.num_workers)
    return loader, class_num
dataset = ConcatDataset(dataset) if len(dataset) > 1 else dataset[0] val_dataset = ScaperLoader(folder=params['validation_folder'], length=params['initial_length'], n_fft=params['n_fft'], hop_length=params['hop_length'], output_type=params['target_type'], group_sources=params['group_sources'], ignore_sources=params['ignore_sources'], source_labels=params['source_labels']) elif params['dataset_type'] == 'wsj': for i in range(len(params['training_folder'])): dataset.append(WSJ0(folder=params['training_folder'][i], length=params['initial_length'], n_fft=params['n_fft'], hop_length=params['hop_length'], output_type=params['target_type'], weight_method=params['weight_method'], create_cache=params['create_cache'], num_channels=1)) dataset = ConcatDataset(dataset) if len(dataset) > 1 else dataset[0] target_type = params['target_type'] if params['target_type'] != 'spatial_bootstrap' else 'psa' val_dataset = WSJ0(folder=params['validation_folder'], length='full', n_fft=params['n_fft'], hop_length=params['hop_length'], output_type=target_type, create_cache=True, #params['create_cache'], num_channels=1) if args.sample_strategy == 'sequential':