def test_filelist_benchmark(self):
    """Build a filelist benchmark from the cats-vs-dogs dataset.

    Downloads and extracts ``cats_and_dogs_filtered.zip`` under ``./data``,
    writes one filelist per class directory (``cats`` -> label 0,
    ``dogs`` -> label 1) and checks that ``filelist_benchmark`` yields two
    training experiences and a single test experience
    (``complete_test_set_only=True``).
    """
    import tempfile

    data_dir = "./data"
    download_url(
        "https://storage.googleapis.com/mledu-datasets/"
        "cats_and_dogs_filtered.zip",
        data_dir,
        "cats_and_dogs_filtered.zip",
    )
    archive_name = os.path.join(data_dir, "cats_and_dogs_filtered.zip")
    extract_archive(archive_name, to_path=data_dir)
    dirpath = os.path.join(data_dir, "cats_and_dogs_filtered", "train")

    # Write the filelists into a throwaway directory instead of the CWD, so
    # the test leaves no stray files behind and parallel runs cannot clash.
    with tempfile.TemporaryDirectory() as tmp_dir:
        filelists = []
        for filelist_name, class_dir, label in zip(
            ["train_filelist_00.txt", "train_filelist_01.txt"],
            ["cats", "dogs"],
            [0, 1],
        ):
            filelist = os.path.join(tmp_dir, filelist_name)
            # Each entry is "<path relative to dirpath> <label>".
            lines = [
                "{} {}\n".format(os.path.join(class_dir, name), label)
                for name in os.listdir(os.path.join(dirpath, class_dir))
            ]
            with open(filelist, "w") as wf:
                wf.writelines(lines)
            filelists.append(filelist)

        # Assertions stay inside the `with` so the filelists still exist
        # while the benchmark is being built from them.
        generic_benchmark = filelist_benchmark(
            dirpath,
            filelists,
            filelists[:1],
            task_labels=[0, 0],
            complete_test_set_only=True,
            train_transform=ToTensor(),
            eval_transform=ToTensor(),
        )
        self.assertEqual(2, len(generic_benchmark.train_stream))
        self.assertEqual(1, len(generic_benchmark.test_stream))
def test_filelist_benchmark(self):
    """Build a filelist benchmark from the cats-vs-dogs dataset.

    Downloads and extracts ``cats_and_dogs_filtered.zip`` under
    ``~/.avalanche/data``, writes one filelist per class directory
    (``cats`` -> label 0, ``dogs`` -> label 1) and checks that
    ``filelist_benchmark`` yields two training experiences and a single
    test experience (``complete_test_set_only=True``).
    """
    import tempfile

    data_dir = os.path.join(expanduser("~"), ".avalanche", "data")
    download_url(
        "https://storage.googleapis.com/mledu-datasets/"
        "cats_and_dogs_filtered.zip",
        data_dir,
        "cats_and_dogs_filtered.zip",
    )
    archive_name = os.path.join(data_dir, "cats_and_dogs_filtered.zip")
    extract_archive(archive_name, to_path=data_dir)
    dirpath = os.path.join(data_dir, "cats_and_dogs_filtered", "train")

    # Write the filelists into a throwaway directory instead of the CWD, so
    # the test leaves no stray files behind and parallel runs cannot clash.
    with tempfile.TemporaryDirectory() as tmp_dir:
        filelists = []
        for filelist_name, class_dir, label in zip(
            ["train_filelist_00.txt", "train_filelist_01.txt"],
            ["cats", "dogs"],
            [0, 1],
        ):
            filelist = os.path.join(tmp_dir, filelist_name)
            # Each entry is "<path relative to dirpath> <label>".
            lines = [
                "{} {}\n".format(os.path.join(class_dir, name), label)
                for name in os.listdir(os.path.join(dirpath, class_dir))
            ]
            with open(filelist, "w") as wf:
                wf.writelines(lines)
            filelists.append(filelist)

        # Assertions stay inside the `with` so the filelists still exist
        # while the benchmark is being built from them.
        generic_benchmark = filelist_benchmark(
            dirpath,
            filelists,
            filelists[:1],
            task_labels=[0, 0],
            complete_test_set_only=True,
            train_transform=ToTensor(),
            eval_transform=ToTensor(),
        )
        self.assertEqual(2, len(generic_benchmark.train_stream))
        self.assertEqual(1, len(generic_benchmark.test_stream))