def test_save_load(self):
    """Round-trip the testbed dataset through disk and compare features.

    The dataset is saved into a freshly created, uniquely named temporary
    directory rather than a fixed ``tmp`` folder in the working directory,
    so an unrelated pre-existing ``tmp`` directory is never deleted and the
    scratch directory is removed even when saving or loading raises.
    """
    # Imported locally so this block does not rely on a file-level import.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save the dataset to disk
        self.testbed.dataset.save_to_disk(path=tmp_dir)

        # Load the dataset back from disk
        dataset = Dataset.load_from_disk(path=tmp_dir)

        # The reloaded dataset must expose the same features as the original
        self.assertEqual(dataset.features, self.testbed.dataset.features)
def load(cls, path: str) -> DevBench:
    """Load a devbench from disk.

    Reads the ``slices`` sub-directory, the ``dataset`` sub-directory and the
    ``metrics.dill``, ``aggregators.dill``, ``metadata.dill`` and
    ``version.dill`` files found under ``path``.

    Args:
        path: string path to the devbench directory

    Returns:
        A devbench built from the stored dataset, with the stored metrics,
        aggregators, slices and version info restored onto it.

    Raises:
        FileNotFoundError: if one of the ``*.dill`` files is missing.
            (Slices whose load raises ``FileNotFoundError`` are skipped
            instead of propagating the error.)
    """
    # NOTE(security): dill, like pickle, can execute arbitrary code while
    # deserializing. Only load devbench directories from trusted sources.

    # Path to the save directory
    savedir = pathlib.Path(path)

    # Load all the slices; entries that cannot be found on disk are skipped
    # on a best-effort basis.
    slices = []
    for sl_path in tqdm(list((savedir / "slices").glob("*"))):
        try:
            slices.append(Slice.load_from_disk(str(sl_path)))
        except FileNotFoundError:
            continue

    # Load dataset
    dataset = Dataset.load_from_disk(str(savedir / "dataset"))

    # Load metrics (context managers make sure every file handle is closed)
    with open(str(savedir / "metrics.dill"), "rb") as f:
        metrics = dill.load(f)

    # Load aggregators
    with open(str(savedir / "aggregators.dill"), "rb") as f:
        aggregators = dill.load(f)

    # Load metadata. The content is currently unused, but the file is still
    # read so that a missing or unreadable file fails here as it always did.
    with open(str(savedir / "metadata.dill"), "rb") as f:
        dill.load(f)

    # Create the devbench
    devbench = cls(dataset=dataset)

    # Set previously stored metrics
    devbench.metrics = metrics

    # Set previously stored aggregators
    devbench.aggregators = aggregators

    # Set the slices
    devbench.add_slices(slices)

    # Load version info
    with open(str(savedir / "version.dill"), "rb") as f:
        devbench._loads_version(f.read())

    return devbench