import os
import tarfile

import numpy as np

from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.python.platform import gfile

# DBPEDIA_URL is assumed to be defined at module level, pointing at the
# dbpedia_csv.tar.gz archive.


def load_dbpedia(size='small', test_with_fake_data=False):
  """Get DBpedia datasets from CSV files."""
  if not test_with_fake_data:
    data_dir = os.path.join(os.getenv('TF_EXP_BASE_DIR', ''), 'dbpedia_data')
    maybe_download_dbpedia(data_dir)

    train_path = os.path.join(data_dir, 'dbpedia_csv', 'train.csv')
    test_path = os.path.join(data_dir, 'dbpedia_csv', 'test.csv')

    if size == 'small':
      # Reduce the size of the original data by a factor of 1000.
      base.shrink_csv(train_path, 1000)
      base.shrink_csv(test_path, 1000)
      train_path = train_path.replace('train.csv', 'train_small.csv')
      test_path = test_path.replace('test.csv', 'test_small.csv')
  else:
    module_path = os.path.dirname(__file__)
    train_path = os.path.join(module_path, 'data', 'text_train.csv')
    test_path = os.path.join(module_path, 'data', 'text_test.csv')

  train = base.load_csv(train_path, np.int32, 0, has_header=False)
  test = base.load_csv(test_path, np.int32, 0, has_header=False)
  return base.Datasets(train=train, validation=None, test=test)
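# Usage sketch (assumptions: TF_EXP_BASE_DIR points at a writable directory,
# and base.load_csv returns Dataset namedtuples with .data and .target
# attributes, as in the contrib-learn datasets helpers):
#
#   dbpedia = load_dbpedia(size='small')
#   x_train, y_train = dbpedia.train.data, dbpedia.train.target
#   x_test, y_test = dbpedia.test.data, dbpedia.test.target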
def maybe_download_dbpedia(data_dir):
  """Download and extract the DBpedia archive if it is not already present."""
  train_path = os.path.join(data_dir, 'dbpedia_csv/train.csv')
  test_path = os.path.join(data_dir, 'dbpedia_csv/test.csv')
  if not (gfile.Exists(train_path) and gfile.Exists(test_path)):
    archive_path = base.maybe_download('dbpedia_csv.tar.gz', data_dir,
                                       DBPEDIA_URL)
    tfile = tarfile.open(archive_path, 'r:*')
    tfile.extractall(data_dir)
    tfile.close()


def get_dbpedia(data_dir):
  """Load the full DBpedia train/test datasets, downloading them if needed."""
  maybe_download_dbpedia(data_dir)
  train_path = os.path.join(data_dir, 'dbpedia_csv/train.csv')
  test_path = os.path.join(data_dir, 'dbpedia_csv/test.csv')
  train = base.load_csv(train_path, np.int32, 0, has_header=False)
  test = base.load_csv(test_path, np.int32, 0, has_header=False)
  return base.Datasets(train=train, validation=None, test=test)
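# Usage sketch (the directory path below is only an example): on the first
# call get_dbpedia downloads and extracts dbpedia_csv.tar.gz into data_dir;
# later calls reuse the extracted CSVs.
#
#   datasets = get_dbpedia('/tmp/dbpedia_data')
#   print(len(datasets.train.data), len(datasets.test.data))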