コード例 #1
0
 def test_load_real_metric(self, metric_name):
     """Load ``metric_name`` into a throwaway data directory.

     The DownloadConfig handed to ``load_metric`` has its ``download_mode``
     set to ``GenerateMode.FORCE_REDOWNLOAD``. Nothing is asserted; the test
     passes as long as ``load_metric`` does not raise.
     """
     with tempfile.TemporaryDirectory() as scratch_dir:
         config = DownloadConfig()
         config.download_mode = GenerateMode.FORCE_REDOWNLOAD
         load_metric(
             metric_name,
             data_dir=scratch_dir,
             download_config=config,
         )
コード例 #2
0
    def test_load_real_dataset(self, dataset_name):
        """Load ``dataset_name`` into a temporary data_dir and check every split is non-empty.

        The DownloadConfig passed to ``load_dataset`` has ``download_mode`` set
        to ``GenerateMode.FORCE_REDOWNLOAD`` (presumably bypassing any cached
        copy).
        """
        with tempfile.TemporaryDirectory() as scratch_dir:
            config = DownloadConfig()
            config.download_mode = GenerateMode.FORCE_REDOWNLOAD

            loaded = load_dataset(
                dataset_name,
                data_dir=scratch_dir,
                download_config=config,
            )
            for split_name in loaded.keys():
                self.assertTrue(len(loaded[split_name]) > 0)
コード例 #3
0
    def test_load_real_dataset_local(self, dataset_name):
        """Load the script at ``./datasets/<dataset_name>`` and check every split is non-empty.

        The DownloadConfig (``download_mode = GenerateMode.FORCE_REDOWNLOAD``)
        reaches ``load_dataset`` through ``download_and_prepare_kwargs``, and
        the data is written to a temporary data_dir.
        """
        with tempfile.TemporaryDirectory() as scratch_dir:
            config = DownloadConfig()
            config.download_mode = GenerateMode.FORCE_REDOWNLOAD
            prepare_kwargs = dict(download_config=config)

            local_path = "./datasets/" + dataset_name
            loaded = load_dataset(
                local_path,
                data_dir=scratch_dir,
                download_and_prepare_kwargs=prepare_kwargs,
            )
            for split_name in loaded.keys():
                self.assertTrue(len(loaded[split_name]) > 0)
コード例 #4
0
ファイル: test_dataset_common.py プロジェクト: xwild/nlp
    def test_load_real_dataset(self, dataset_name):
        if "/" not in dataset_name:
            logging.info("Skip {} because it is a canonical dataset")
            return

        with tempfile.TemporaryDirectory() as temp_data_dir:
            download_config = DownloadConfig()
            download_config.download_mode = GenerateMode.FORCE_REDOWNLOAD
            download_and_prepare_kwargs = {"download_config": download_config}

            dataset = load_dataset(
                dataset_name,
                data_dir=temp_data_dir,
                download_and_prepare_kwargs=download_and_prepare_kwargs)
            for split in dataset.keys():
                self.assertTrue(len(dataset[split]) > 0)
コード例 #5
0
    def test_load_real_metric(self, metric_name):
        """Load ``metric_name`` and check the signature of its ``_compute``.

        ``DownloadConfig.force_download`` is set to True. The "glue" metric is
        loaded with the "sst2" config; every other metric uses ``name=None``.
        ``metric._compute`` must accept ``predictions`` and ``references`` and
        must not declare a ``**kwargs`` catch-all.

        Args:
            metric_name: Metric identifier passed to ``load_metric``.
        """
        with tempfile.TemporaryDirectory() as temp_data_dir:
            download_config = DownloadConfig()
            download_config.force_download = True
            # "glue" is loaded with its "sst2" config (presumably a config name
            # is required for it); other metrics use the default.
            name = "sst2" if metric_name == "glue" else None
            metric = load_metric(
                metric_name,
                name=name,
                data_dir=temp_data_dir,
                download_config=download_config,
            )

            parameters = inspect.signature(metric._compute).parameters
            # assertIn names the missing parameter on failure.
            self.assertIn("predictions", parameters)
            self.assertIn("references", parameters)
            # No **kwargs: every argument of _compute must be declared explicitly.
            self.assertTrue(
                all(p.kind != p.VAR_KEYWORD for p in parameters.values())
            )
コード例 #6
0
 def load_builder_class(self, dataset_name, is_local=False):
     """Return the builder class defined by a dataset script.

     Args:
         dataset_name: Dataset identifier. When ``is_local`` is truthy it is
             resolved as ``./datasets/<dataset_name>``.
         is_local: If truthy, prepare the script from the local ``./datasets/``
             directory; otherwise prepare ``dataset_name`` with
             ``DownloadConfig(force_download=True)``.

     Returns:
         The class returned by ``import_main_class`` -- it is NOT instantiated.
     """
     # Plain truthiness (not ``is True``) so flags like 1 or numpy.bool_(True)
     # also select the local path.
     if is_local:
         module_path = prepare_module("./datasets/" + dataset_name)
     else:
         module_path = prepare_module(
             dataset_name, download_config=DownloadConfig(force_download=True)
         )
     # Import the script's main builder class and return it uninstantiated.
     return import_main_class(module_path)
コード例 #7
0
 def test_load_real_dataset(self, dataset_name):
     """Load ``dataset_name`` with its first builder config and check every split is non-empty.

     The dataset script is prepared with ``DownloadConfig(force_download=True)``.
     The first entry of ``BUILDER_CONFIGS`` is used, or ``name=None`` when the
     builder declares none. Data is loaded into a temporary cache directory
     with ``GenerateMode.FORCE_REDOWNLOAD``.

     Args:
         dataset_name: Dataset identifier passed to ``prepare_module`` and
             ``load_dataset``.
     """
     path = dataset_name
     # Only the module path is needed. The second value is bound to ``_``
     # rather than shadowing the builtin ``hash``.
     module_path, _ = prepare_module(
         path,
         download_config=DownloadConfig(force_download=True),
         dataset=True,
     )
     builder_cls = import_main_class(module_path, dataset=True)
     name = builder_cls.BUILDER_CONFIGS[0].name if builder_cls.BUILDER_CONFIGS else None
     with tempfile.TemporaryDirectory() as temp_cache_dir:
         dataset = load_dataset(
             path,
             name=name,
             cache_dir=temp_cache_dir,
             download_mode=GenerateMode.FORCE_REDOWNLOAD,
         )
         for split in dataset.keys():
             # assertGreater reports the actual length on failure.
             self.assertGreater(len(dataset[split]), 0)
コード例 #8
0
 def test_load_real_dataset_all_configs(self, dataset_name):
     """Load every builder config of a local dataset script and check every split is non-empty.

     The script at ``./datasets/<dataset_name>`` is prepared with
     ``DownloadConfig(local_files_only=True)``. Each config in
     ``BUILDER_CONFIGS`` is loaded in turn; a builder with no declared configs
     is loaded once with ``name=None``. Each load uses its own temporary cache
     directory and ``GenerateMode.FORCE_REDOWNLOAD``.

     Args:
         dataset_name: Name of the script directory under ``./datasets/``.
     """
     path = "./datasets/" + dataset_name
     # Only the module path is needed. The second value is bound to ``_``
     # rather than shadowing the builtin ``hash``.
     module_path, _ = prepare_module(
         path,
         download_config=DownloadConfig(local_files_only=True),
         dataset=True,
     )
     builder_cls = import_main_class(module_path, dataset=True)
     config_names = (
         [config.name for config in builder_cls.BUILDER_CONFIGS]
         if builder_cls.BUILDER_CONFIGS
         else [None]
     )
     for name in config_names:
         # Each config gets its own temporary cache directory.
         with tempfile.TemporaryDirectory() as temp_cache_dir:
             dataset = load_dataset(
                 path,
                 name=name,
                 cache_dir=temp_cache_dir,
                 download_mode=GenerateMode.FORCE_REDOWNLOAD,
             )
             for split in dataset.keys():
                 # assertGreater reports the actual length on failure.
                 self.assertGreater(len(dataset[split]), 0)