def load_data(dataset, root, num_query, num_train, batch_size, num_workers):
    """
    Dispatch to the loader module that matches the requested dataset.

    Args
        dataset(str): Dataset name. One of 'cifar-10', 'nus-wide-tc10',
            'nus-wide-tc21', 'flickr25k' or 'imagenet'.
        root(str): Path of dataset.
        num_query(int): Number of query data points.
        num_train(int): Number of training data points.
        batch_size(int): Batch size, forwarded to the dataset loader.
        num_workers(int): Number of loading data threads.

    Returns
        query_dataloader, train_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loader.

    Raises
        ValueError: If ``dataset`` is not one of the names listed above.
    """
    # Positional arguments shared by every loader except the imagenet one,
    # which is called without the query/train sizes.
    sized_args = (root, num_query, num_train, batch_size, num_workers)

    if dataset == 'cifar-10':
        loaders = cifar10.load_data(*sized_args)
    elif dataset == 'nus-wide-tc10':
        # The leading integer selects the 10-class variant.
        loaders = nuswide.load_data(10, *sized_args)
    elif dataset == 'nus-wide-tc21':
        # The leading integer selects the 21-class variant.
        loaders = nuswide.load_data(21, *sized_args)
    elif dataset == 'flickr25k':
        loaders = flickr25k.load_data(*sized_args)
    elif dataset == 'imagenet':
        loaders = imagenet.load_data(root, batch_size, num_workers)
    else:
        raise ValueError("Invalid dataset name!")

    # Unpack explicitly so the result is always a 3-tuple in
    # (query, train, retrieval) order.
    query_loader, train_loader, retrieval_loader = loaders
    return query_loader, train_loader, retrieval_loader
def load_data(dataset, root, batch_size, num_workers):
    """
    Dispatch to the loader module that matches the requested dataset.

    Args
        dataset(str): Dataset name. One of 'cifar-10', 'nus-wide-tc21'
            or 'imagenet-tc100'.
        root(str): Path of dataset.
        batch_size(int): Batch size, forwarded to the dataset loader.
        num_workers(int): Number of loading data threads.

    Returns
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loader.

    Raises
        ValueError: If ``dataset`` is not one of the names listed above.
    """
    # Every supported dataset takes the same arguments, so only the module
    # that provides ``load_data`` differs between branches.
    if dataset == 'cifar-10':
        loader_module = cifar10
    elif dataset == 'nus-wide-tc21':
        loader_module = nuswide
    elif dataset == 'imagenet-tc100':
        loader_module = imagenet
    else:
        raise ValueError("Invalid dataset name!")

    # Unpack explicitly so the result is always a 3-tuple in
    # (train, query, retrieval) order.
    train_loader, query_loader, retrieval_loader = loader_module.load_data(
        root,
        batch_size,
        num_workers,
    )
    return train_loader, query_loader, retrieval_loader
def load_data(opt):
    """Load data for the dataset selected in the parsed options.

    Parameters
        opt: Parser
            Parsed options. ``opt.dataset`` selects the loader; the whole
            object is forwarded unchanged to that loader.

    Returns
        DataLoader
            Whatever the selected dataset module's ``load_data(opt)`` returns.

    Raises
        ValueError: If ``opt.dataset`` is neither 'cifar10' nor 'nus-wide'.
    """
    if opt.dataset == 'cifar10':
        return cifar10.load_data(opt)
    if opt.dataset == 'nus-wide':
        return nus_wide.load_data(opt)

    # Previously an unknown name fell through and returned None, deferring
    # the failure to the caller; fail here where the cause is visible.
    raise ValueError("Invalid dataset name: {!r}".format(opt.dataset))