def load_dataset(dataset_config, *args, num_batches=None, **kwargs):
    """Build the dataset generator described by ``dataset_config``.

    The config names a module (``"module"``) and a factory inside it
    (``"name"``). The factory is called with any extra positional ``args``,
    the configured ``"batch_size"``, a ``framework`` (``"framework"`` from the
    config, default ``"numpy"``), and ``kwargs``. The optional config keys
    ``"index"`` and ``"class_ids"`` are forwarded as keyword arguments unless
    the caller already passed them explicitly.

    Args:
        dataset_config: mapping with at least ``"module"``, ``"name"`` and
            ``"batch_size"``; may also hold ``"framework"``, ``"index"``,
            ``"class_ids"`` and ``"check_run"``.
        *args: extra positional arguments forwarded to the factory.
        num_batches: when truthy, wrap the result so that only this many
            batches are evaluated. If None, the returned generator iterates
            over the entire dataset.
        **kwargs: extra keyword arguments forwarded to the factory.

    Returns:
        The factory's ``ArmoryDataGenerator``; wrapped in ``EvalGenerator``
        limited to one batch when ``dataset_config["check_run"]`` is truthy,
        otherwise limited to ``num_batches`` batches when that is truthy.

    Raises:
        ValueError: if the factory does not return an ``ArmoryDataGenerator``.
    """
    factory = getattr(import_module(dataset_config["module"]), dataset_config["name"])
    batch_size = dataset_config["batch_size"]

    # Explicit caller kwargs win over the config-provided defaults.
    kwargs.update(
        {
            key: dataset_config[key]
            for key in ("index", "class_ids")
            if key not in kwargs and key in dataset_config
        }
    )
    framework = dataset_config.get("framework", "numpy")

    dataset = factory(*args, batch_size=batch_size, framework=framework, **kwargs)
    if not isinstance(dataset, ArmoryDataGenerator):
        raise ValueError(
            f"{dataset} is not an instance of {ArmoryDataGenerator}")

    # A check run always evaluates a single batch, regardless of num_batches.
    if dataset_config.get("check_run"):
        batch_limit = 1
    elif num_batches:
        batch_limit = num_batches
    else:
        return dataset
    return EvalGenerator(dataset, num_eval_batches=batch_limit)
def load_adversarial_dataset(config, num_batches=None, **kwargs):
    """Build the preloaded adversarial dataset generator described by ``config``.

    The config names a module (``"module"``) and a factory inside it
    (``"name"``). The factory is called with ``config["kwargs"]`` merged with
    the caller's ``kwargs`` (caller values win); any ``"description"`` entry is
    dropped before the call. ``config`` itself is never modified.

    Args:
        config: mapping with ``"type"`` (must be ``"preloaded"``),
            ``"module"``, ``"name"`` and ``"kwargs"``; may also hold
            ``"check_run"``.
        num_batches: when truthy, wrap the result so that only this many
            batches are evaluated.
        **kwargs: keyword arguments overriding ``config["kwargs"]``.

    Returns:
        The factory's ``ArmoryDataGenerator``; wrapped in ``EvalGenerator``
        limited to one batch when ``config["check_run"]`` is truthy,
        otherwise limited to ``num_batches`` batches when that is truthy.

    Raises:
        ValueError: if ``config["type"]`` is not ``"preloaded"`` or the
            factory does not return an ``ArmoryDataGenerator``.
    """
    if config.get("type") != "preloaded":
        raise ValueError(f"attack type must be 'preloaded', not {config.get('type')}")

    dataset_module = import_module(config["module"])
    dataset_fn = getattr(dataset_module, config["name"])

    # Work on a copy: updating config["kwargs"] in place would leak this
    # call's overrides into the caller's config and strip its "description",
    # changing the outcome of any later call made with the same config.
    dataset_kwargs = dict(config["kwargs"])
    dataset_kwargs.update(kwargs)
    # "description" is never forwarded to the dataset factory.
    dataset_kwargs.pop("description", None)

    dataset = dataset_fn(**dataset_kwargs)
    if not isinstance(dataset, ArmoryDataGenerator):
        raise ValueError(f"{dataset} is not an instance of {ArmoryDataGenerator}")

    # A check run always evaluates a single batch, regardless of num_batches.
    if config.get("check_run"):
        return EvalGenerator(dataset, num_eval_batches=1)
    if num_batches:
        return EvalGenerator(dataset, num_eval_batches=num_batches)
    return dataset