Example #1
0
    def _get_model(self, model_config):
        model, _ = load_model(model_config)
        if not isinstance(
            model,
            (ClassifierMixin, SpeechRecognizerMixin, ObjectDetectorMixin),
        ):
            raise TypeError(f"Unsupported model type: {type(model)}")
        return model
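
This helper only verifies that whatever load_model returned is an instance of one of the supported ART estimator mixins. A minimal, self-contained sketch of the same check, with stand-in classes (SupportedMixin, GoodModel, BadModel are hypothetical, not ART types):

class SupportedMixin:  # stand-in for ClassifierMixin and friends
    pass

class GoodModel(SupportedMixin):
    pass

class BadModel:
    pass

def check(model):
    if not isinstance(model, (SupportedMixin,)):
        raise TypeError(f"Unsupported model type: {type(model)}")
    return model

check(GoodModel())      # returns the model unchanged
try:
    check(BadModel())
except TypeError as err:
    print(err)          # Unsupported model type: <class '__main__.BadModel'>
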
Example #2
0
    def _evaluate(self, config: dict, num_eval_batches: Optional[int],
                  skip_benign: Optional[bool]) -> dict:
        """
        Evaluate the config and return a results dict
        """

        model_config = config["model"]
        classifier, preprocessing_fn = load_model(model_config)

        defense_config = config.get("defense") or {}
        defense_type = defense_config.get("type")

        if defense_type in ["Preprocessor", "Postprocessor"]:
            logger.info(
                f"Applying internal {defense_type} defense to classifier")
            classifier = load_defense_internal(config["defense"], classifier)

        if model_config["fit"]:
            classifier.set_learning_phase(True)
            logger.info(
                f"Fitting model {model_config['module']}.{model_config['name']}..."
            )
            train_epochs = config["model"]["fit_kwargs"]["nb_epochs"]
            batch_size = config["dataset"]["batch_size"]

            logger.info(
                f"Loading train dataset {config['dataset']['name']}...")
            train_data = load_dataset(
                config["dataset"],
                epochs=train_epochs,
                split_type="train",
                preprocessing_fn=preprocessing_fn,
                shuffle_files=True,
            )

            if defense_type == "Trainer":
                logger.info(f"Training with {defense_type} defense...")
                defense = load_defense_wrapper(config["defense"], classifier)
            else:
                logger.info("Fitting classifier on clean train dataset...")

            for epoch in range(train_epochs):
                classifier.set_learning_phase(True)

                for _ in tqdm(
                        range(train_data.batches_per_epoch),
                        desc=f"Epoch: {epoch}/{train_epochs}",
                ):
                    x, y = train_data.get_batch()
                    # x consists of one or more videos, each represented as an
                    # ndarray of shape (n_stacks, 3, 16, 112, 112).
                    # To train, randomly sample one stack per video.
                    x = np.stack(
                        [x_i[np.random.randint(x_i.shape[0])] for x_i in x])
                    if defense_type == "Trainer":
                        defense.fit(x, y, batch_size=batch_size, nb_epochs=1)
                    else:
                        classifier.fit(x,
                                       y,
                                       batch_size=batch_size,
                                       nb_epochs=1)

        if defense_type == "Transform":
            # NOTE: Transform currently not supported
            logger.info(
                f"Transforming classifier with {defense_type} defense...")
            defense = load_defense_wrapper(config["defense"], classifier)
            classifier = defense()

        classifier.set_learning_phase(False)

        metrics_logger = metrics.MetricsLogger.from_config(
            config["metric"], skip_benign=skip_benign)
        if skip_benign:
            logger.info("Skipping benign classification...")
        else:
            # Evaluate the ART classifier on benign test examples
            logger.info(f"Loading test dataset {config['dataset']['name']}...")
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split_type="test",
                preprocessing_fn=preprocessing_fn,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )

            logger.info("Running inference on benign examples...")

            for x_batch, y_batch in tqdm(test_data, desc="Benign"):
                for x, y in zip(x_batch, y_batch):
                    # combine predictions across all stacks
                    with metrics.resource_context(
                            name="Inference",
                            profiler=config["metric"].get("profiler_type"),
                            computational_resource_dict=metrics_logger.
                            computational_resource_dict,
                    ):
                        y_pred = np.mean(classifier.predict(x, batch_size=1),
                                         axis=0)
                    metrics_logger.update_task(y, y_pred)
            metrics_logger.log_task()

        # Evaluate the ART classifier on adversarial test examples
        logger.info("Generating or loading / testing adversarial examples...")

        attack_config = config["attack"]
        attack_type = attack_config.get("type")
        targeted = bool(attack_config.get("kwargs", {}).get("targeted"))
        if targeted and attack_config.get("use_label"):
            raise ValueError("Targeted attacks cannot have 'use_label'")
        if attack_type == "preloaded":
            test_data = load_adversarial_dataset(
                attack_config,
                epochs=1,
                split_type="adversarial",
                preprocessing_fn=preprocessing_fn,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
        else:
            attack = load_attack(attack_config, classifier)
            if targeted != getattr(attack, "targeted", False):
                logger.warning(
                    f"targeted config {targeted} != attack field {getattr(attack, 'targeted', False)}"
                )
            attack.set_params(batch_size=1)
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split_type="test",
                preprocessing_fn=preprocessing_fn,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
            if targeted:
                label_targeter = load_label_targeter(
                    attack_config["targeted_labels"])
        for x_batch, y_batch in tqdm(test_data, desc="Attack"):
            if attack_type == "preloaded":
                x_batch = list(zip(*x_batch))
                if targeted:
                    y_batch = list(zip(*y_batch))
            for x, y in zip(x_batch, y_batch):
                with metrics.resource_context(
                        name="Attack",
                        profiler=config["metric"].get("profiler_type"),
                        computational_resource_dict=metrics_logger.
                        computational_resource_dict,
                ):
                    if attack_type == "preloaded":
                        x, x_adv = x
                        if targeted:
                            y, y_target = y
                    else:
                        # each x is of shape (n_stack, 3, 16, 112, 112)
                        #    n_stack varies
                        if attack_config.get("use_label"):
                            # expansion required due to preprocessing
                            y_input = np.repeat(y, x.shape[0])
                            x_adv = attack.generate(x=x, y=y_input)
                        elif targeted:
                            y_target = label_targeter.generate(y)
                            y_input = np.repeat(y_target, x.shape[0])
                            x_adv = attack.generate(x=x, y=y_input)
                        else:
                            x_adv = attack.generate(x=x)
                # combine predictions across all stacks
                y_pred_adv = np.mean(classifier.predict(x_adv, batch_size=1),
                                     axis=0)
                if targeted:
                    metrics_logger.update_task(y_target,
                                               y_pred_adv,
                                               adversarial=True)
                else:
                    metrics_logger.update_task(y, y_pred_adv, adversarial=True)
                metrics_logger.update_perturbation([x], [x_adv])
        metrics_logger.log_task(adversarial=True, targeted=targeted)
        return metrics_logger.results()
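
The training loop above builds a fixed-shape batch from variable-length videos by sampling one random stack per video before calling fit. A self-contained sketch of just that sampling step, with the shapes from the comment (the data itself is random and purely illustrative):

import numpy as np

# Each "video" is an ndarray of shape (n_stacks, 3, 16, 112, 112); n_stacks varies.
rng = np.random.default_rng(0)
x = [rng.normal(size=(n, 3, 16, 112, 112)).astype(np.float32) for n in (2, 5, 3)]

# Pick one random stack per video and stack them into a fixed-shape training batch.
batch = np.stack([x_i[rng.integers(x_i.shape[0])] for x_i in x])
print(batch.shape)  # (3, 3, 16, 112, 112)
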
Example #3
0
    def _evaluate(self, config: dict) -> dict:
        """
        Evaluate a config file for classification robustness against attack.
        """

        model_config = config["model"]
        # Scenario assumes preprocessing_fn makes all images the same size
        classifier, preprocessing_fn = load_model(model_config)

        config_adhoc = config.get("adhoc") or {}
        train_epochs = config_adhoc["train_epochs"]
        src_class = config_adhoc["source_class"]
        tgt_class = config_adhoc["target_class"]
        fit_batch_size = config_adhoc.get("fit_batch_size",
                                          config["dataset"]["batch_size"])

        # Set random seed due to large variance in attack and defense success
        np.random.seed(config_adhoc["np_seed"])
        set_random_seed(config_adhoc["tf_seed"])
        use_poison_filtering_defense = config_adhoc.get(
            "use_poison_filtering_defense", True)
        if self.check_run:
            # filtering defense requires more than a single batch to run properly
            use_poison_filtering_defense = False

        logger.info(f"Loading dataset {config['dataset']['name']}...")

        clean_data = load_dataset(
            config["dataset"],
            epochs=1,
            split_type="train",
            preprocessing_fn=preprocessing_fn,
            shuffle_files=False,
        )

        attack_config = config["attack"]
        attack_type = attack_config.get("type")

        if attack_type == "preloaded":
            num_images_tgt_class = config_adhoc["num_images_target_class"]
            logger.info(
                f"Loading poison dataset {config_adhoc['poison_samples']['name']}..."
            )
            num_poisoned = int(config_adhoc["fraction_poisoned"] *
                               num_images_tgt_class)
            if num_poisoned == 0:
                raise ValueError(
                    "For the preloaded attack, fraction_poisoned must be set so that at least on data point is poisoned."
                )
            config_adhoc["poison_samples"]["batch_size"] = num_poisoned
            poison_data = load_dataset(
                config["adhoc"]["poison_samples"],
                epochs=1,
                split_type="poison",
                preprocessing_fn=None,
            )
        else:
            attack = load(attack_config)
        logger.info(
            "Building in-memory dataset for poisoning detection and training")
        fraction_poisoned = config["adhoc"]["fraction_poisoned"]
        poison_dataset_flag = config["adhoc"]["poison_dataset"]

        # detect_poison does not currently support data generators
        #     therefore, make in memory dataset
        x_train_all, y_train_all = [], []
        if attack_type == "preloaded":
            for x_clean, y_clean in clean_data:
                x_poison, y_poison = poison_data.get_batch()
                x_poison = np.array([xp for xp in x_poison], dtype=np.float32)
                x_train_all.append(x_clean)
                y_train_all.append(y_clean)
                x_train_all.append(x_poison)
                y_train_all.append(y_poison)
            x_train_all = np.concatenate(x_train_all, axis=0)
            y_train_all = np.concatenate(y_train_all, axis=0)
        else:
            for x_train, y_train in clean_data:
                x_train_all.append(x_train)
                y_train_all.append(y_train)
            x_train_all = np.concatenate(x_train_all, axis=0)
            y_train_all = np.concatenate(y_train_all, axis=0)
            if poison_dataset_flag:
                total_count = np.bincount(y_train_all)[src_class]
                poison_count = int(fraction_poisoned * total_count)
                if poison_count == 0:
                    logger.warning(
                        f"No poisons generated with fraction_poisoned {fraction_poisoned} for class {src_class}."
                    )
                src_indices = np.where(y_train_all == src_class)[0]
                poisoned_indices = np.random.choice(src_indices,
                                                    size=poison_count,
                                                    replace=False)
                x_train_all, y_train_all = poison_dataset(
                    x_train_all,
                    y_train_all,
                    src_class,
                    tgt_class,
                    y_train_all.shape[0],
                    attack,
                    poisoned_indices,
                )

        y_train_all_categorical = to_categorical(y_train_all)

        if use_poison_filtering_defense:
            defense_config = config["defense"]

            defense_model_config = config_adhoc.get("defense_model",
                                                    model_config)
            defense_train_epochs = config_adhoc.get("defense_train_epochs",
                                                    train_epochs)
            classifier_for_defense, _ = load_model(defense_model_config)
            logger.info(
                f"Fitting model {defense_model_config['module']}.{defense_model_config['name']} "
                f"for defense {defense_config['name']}...")
            classifier_for_defense.fit(
                x_train_all,
                y_train_all_categorical,
                batch_size=fit_batch_size,
                nb_epochs=defense_train_epochs,
                verbose=False,
            )
            defense_fn = load_fn(defense_config)
            defense = defense_fn(classifier_for_defense, x_train_all,
                                 y_train_all_categorical)

            _, is_clean = defense.detect_poison(nb_clusters=2,
                                                nb_dims=43,
                                                reduce="PCA")
            is_clean = np.array(is_clean)
            logger.info(f"Total clean data points: {np.sum(is_clean)}")

            logger.info("Filtering out detected poisoned samples")
            indices_to_keep = is_clean == 1
            x_train_final = x_train_all[indices_to_keep]
            y_train_final = y_train_all_categorical[indices_to_keep]
        else:
            logger.info(
                "Defense does not require filtering. Model fitting will use all data."
            )
            x_train_final = x_train_all
            y_train_final = y_train_all_categorical
        if len(x_train_final):
            logger.info(
                f"Fitting model of {model_config['module']}.{model_config['name']}..."
            )
            classifier.fit(
                x_train_final,
                y_train_final,
                batch_size=fit_batch_size,
                nb_epochs=train_epochs,
                verbose=False,
            )
        else:
            logger.warning(
                "All data points filtered by defense. Skipping training")

        logger.info("Validating on clean test data")
        config["dataset"]["batch_size"] = fit_batch_size
        test_data = load_dataset(
            config["dataset"],
            epochs=1,
            split_type="test",
            preprocessing_fn=preprocessing_fn,
            shuffle_files=False,
        )
        validation_metric = metrics.MetricList("categorical_accuracy")
        target_class_benign_metric = metrics.MetricList("categorical_accuracy")
        for x, y in tqdm(test_data, desc="Testing"):
            y_pred = classifier.predict(x)
            validation_metric.append(y, y_pred)
            y_pred_tgt_class = y_pred[y == src_class]
            if len(y_pred_tgt_class):
                target_class_benign_metric.append(
                    [src_class] * len(y_pred_tgt_class), y_pred_tgt_class)
        logger.info(
            f"Unpoisoned validation accuracy: {validation_metric.mean():.2%}")
        logger.info(
            f"Unpoisoned validation accuracy on targeted class: {target_class_benign_metric.mean():.2%}"
        )
        results = {
            "validation_accuracy":
            validation_metric.mean(),
            "validation_accuracy_targeted_class":
            target_class_benign_metric.mean(),
        }

        test_metric = metrics.MetricList("categorical_accuracy")
        targeted_test_metric = metrics.MetricList("categorical_accuracy")

        logger.info("Testing on poisoned test data")
        if attack_type == "preloaded":
            test_data_poison = load_dataset(
                config_adhoc["poison_samples"],
                epochs=1,
                split_type="poison_test",
                preprocessing_fn=None,
            )
            for x_poison_test, y_poison_test in tqdm(test_data_poison,
                                                     desc="Testing poison"):
                x_poison_test = np.array([xp for xp in x_poison_test],
                                         dtype=np.float32)
                y_pred = classifier.predict(x_poison_test)
                y_true = [src_class] * len(y_pred)
                targeted_test_metric.append(y_poison_test, y_pred)
                test_metric.append(y_true, y_pred)
            test_data_clean = load_dataset(
                config["dataset"],
                epochs=1,
                split_type="test",
                preprocessing_fn=preprocessing_fn,
                shuffle_files=False,
            )
            for x_clean_test, y_clean_test in tqdm(test_data_clean,
                                                   desc="Testing clean"):
                x_clean_test = np.array([xp for xp in x_clean_test],
                                        dtype=np.float32)
                y_pred = classifier.predict(x_clean_test)
                test_metric.append(y_clean_test, y_pred)

        elif poison_dataset_flag:
            logger.info("Testing on poisoned test data")
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split_type="test",
                preprocessing_fn=preprocessing_fn,
                shuffle_files=False,
            )
            for x_test, y_test in tqdm(test_data, desc="Testing"):
                src_indices = np.where(y_test == src_class)[0]
                poisoned_indices = src_indices  # Poison entire class
                x_test, _ = poison_dataset(
                    x_test,
                    y_test,
                    src_class,
                    tgt_class,
                    len(y_test),
                    attack,
                    poisoned_indices,
                )
                y_pred = classifier.predict(x_test)
                test_metric.append(y_test, y_pred)

                y_pred_targeted = y_pred[y_test == src_class]
                if not len(y_pred_targeted):
                    continue
                targeted_test_metric.append([tgt_class] * len(y_pred_targeted),
                                            y_pred_targeted)

        if poison_dataset_flag or attack_type == "preloaded":
            results["test_accuracy"] = test_metric.mean()
            results[
                "targeted_misclassification_accuracy"] = targeted_test_metric.mean(
                )
            logger.info(f"Test accuracy: {test_metric.mean():.2%}")
            logger.info(
                f"Test targeted misclassification accuracy: {targeted_test_metric.mean():.2%}"
            )

        return results
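
In the non-preloaded branch, fraction_poisoned is converted into a concrete set of source-class indices before poison_dataset is called. A toy sketch of that selection step (poison_dataset itself is the scenario's helper and is not reproduced here):

import numpy as np

y_train_all = np.array([0, 1, 1, 2, 1, 0, 1, 1])       # toy labels
src_class, fraction_poisoned = 1, 0.5

total_count = np.bincount(y_train_all)[src_class]       # 5 source-class samples
poison_count = int(fraction_poisoned * total_count)     # poison 2 of them
src_indices = np.where(y_train_all == src_class)[0]     # [1 2 4 6 7]
poisoned_indices = np.random.choice(src_indices, size=poison_count, replace=False)
print(sorted(poisoned_indices))                          # e.g. [2, 6]
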
Example #4
0
    def _evaluate(
        self,
        config: dict,
        num_eval_batches: Optional[int],
        skip_benign: Optional[bool],
        skip_attack: Optional[bool],
        skip_misclassified: Optional[bool],
    ) -> dict:
        """
        Evaluate a config file for classification robustness against attack.

        Note: num_eval_batches shouldn't be set for poisoning scenario and will raise an
        error if it is
        """
        if config["sysconfig"].get("use_gpu"):
            os.environ["TF_CUDNN_DETERMINISM"] = "1"
        if num_eval_batches:
            raise ValueError(
                "num_eval_batches shouldn't be set for poisoning scenario")
        if skip_benign:
            raise ValueError(
                "skip_benign shouldn't be set for poisoning scenario")
        if skip_attack:
            raise ValueError(
                "skip_attack shouldn't be set for poisoning scenario")
        if skip_misclassified:
            raise ValueError(
                "skip_misclassified shouldn't be set for poisoning scenario")

        model_config = config["model"]
        # Scenario assumes the canonical preprocessing_fn makes all images the same size
        classifier, _ = load_model(model_config)

        config_adhoc = config.get("adhoc") or {}
        train_epochs = config_adhoc["train_epochs"]
        src_class = config_adhoc["source_class"]
        tgt_class = config_adhoc["target_class"]
        fit_batch_size = config_adhoc.get("fit_batch_size",
                                          config["dataset"]["batch_size"])

        if not config["sysconfig"].get("use_gpu"):
            conf = ConfigProto(intra_op_parallelism_threads=1)
            set_session(Session(config=conf))

        # Set random seed due to large variance in attack and defense success
        np.random.seed(config_adhoc["split_id"])
        set_random_seed(config_adhoc["split_id"])
        random.seed(config_adhoc["split_id"])
        use_poison_filtering_defense = config_adhoc.get(
            "use_poison_filtering_defense", True)
        if self.check_run:
            # filtering defense requires more than a single batch to run properly
            use_poison_filtering_defense = False

        logger.info(f"Loading dataset {config['dataset']['name']}...")

        clean_data = load_dataset(
            config["dataset"],
            epochs=1,
            split=config["dataset"].get("train_split", "train"),
            preprocessing_fn=poison_scenario_preprocessing,
            shuffle_files=False,
        )

        attack_config = config["attack"]
        attack_type = attack_config.get("type")

        fraction_poisoned = config["adhoc"]["fraction_poisoned"]
        # Flag for whether to poison dataset -- used to evaluate
        #     performance of defense on clean data
        poison_dataset_flag = config["adhoc"]["poison_dataset"]
        # detect_poison does not currently support data generators
        #     therefore, make in memory dataset
        x_train_all, y_train_all = [], []

        if attack_type == "preloaded":
            # Number of data points in the train split of the target class
            num_images_tgt_class = config_adhoc["num_images_target_class"]
            logger.info(
                f"Loading poison dataset {config_adhoc['poison_samples']['name']}..."
            )
            num_poisoned = int(config_adhoc["fraction_poisoned"] *
                               num_images_tgt_class)
            if num_poisoned == 0:
                raise ValueError(
                    "For the preloaded attack, fraction_poisoned must be set so that at least on data point is poisoned."
                )
            # Set batch size to number of poisons -- read only one batch of preloaded poisons
            config_adhoc["poison_samples"]["batch_size"] = num_poisoned
            poison_data = load_dataset(
                config["adhoc"]["poison_samples"],
                epochs=1,
                split="poison",
                preprocessing_fn=None,
            )

            logger.info(
                "Building in-memory dataset for poisoning detection and training"
            )
            for x_clean, y_clean in clean_data:
                x_train_all.append(x_clean)
                y_train_all.append(y_clean)
            x_poison, y_poison = poison_data.get_batch()
            x_poison = np.array([xp for xp in x_poison], dtype=np.float32)
            x_train_all.append(x_poison)
            y_train_all.append(y_poison)
            x_train_all = np.concatenate(x_train_all, axis=0)
            y_train_all = np.concatenate(y_train_all, axis=0)
        else:
            attack = load(attack_config)
            logger.info(
                "Building in-memory dataset for poisoning detection and training"
            )
            for x_train, y_train in clean_data:
                x_train_all.append(x_train)
                y_train_all.append(y_train)
            x_train_all = np.concatenate(x_train_all, axis=0)
            y_train_all = np.concatenate(y_train_all, axis=0)
            if poison_dataset_flag:
                total_count = np.bincount(y_train_all)[src_class]
                poison_count = int(fraction_poisoned * total_count)
                if poison_count == 0:
                    logger.warning(
                        f"No poisons generated with fraction_poisoned {fraction_poisoned} for class {src_class}."
                    )
                src_indices = np.where(y_train_all == src_class)[0]
                poisoned_indices = np.sort(
                    np.random.choice(src_indices,
                                     size=poison_count,
                                     replace=False))
                x_train_all, y_train_all = poison_dataset(
                    x_train_all,
                    y_train_all,
                    src_class,
                    tgt_class,
                    y_train_all.shape[0],
                    attack,
                    poisoned_indices,
                )

        y_train_all_categorical = to_categorical(y_train_all)

        # Flag to determine whether defense_classifier is trained directly
        #     (default API) or is trained as part of detect_poisons method
        fit_defense_classifier_outside_defense = config_adhoc.get(
            "fit_defense_classifier_outside_defense", True)
        # Flag to determine whether defense_classifier uses sparse
        #     or categorical labels
        defense_categorical_labels = config_adhoc.get(
            "defense_categorical_labels", True)
        if use_poison_filtering_defense:
            if defense_categorical_labels:
                y_train_defense = y_train_all_categorical
            else:
                y_train_defense = y_train_all

            defense_config = config["defense"]
            detection_kwargs = config_adhoc.get("detection_kwargs", dict())

            defense_model_config = config_adhoc.get("defense_model",
                                                    model_config)
            defense_train_epochs = config_adhoc.get("defense_train_epochs",
                                                    train_epochs)

            # Assumes classifier_for_defense and classifier use same preprocessing function
            classifier_for_defense, _ = load_model(defense_model_config)
            logger.info(
                f"Fitting model {defense_model_config['module']}.{defense_model_config['name']} "
                f"for defense {defense_config['name']}...")
            if fit_defense_classifier_outside_defense:
                classifier_for_defense.fit(
                    x_train_all,
                    y_train_defense,
                    batch_size=fit_batch_size,
                    nb_epochs=defense_train_epochs,
                    verbose=False,
                    shuffle=True,
                )
            defense_fn = load_fn(defense_config)
            defense = defense_fn(classifier_for_defense, x_train_all,
                                 y_train_defense)

            _, is_clean = defense.detect_poison(**detection_kwargs)
            is_clean = np.array(is_clean)
            logger.info(f"Total clean data points: {np.sum(is_clean)}")

            logger.info("Filtering out detected poisoned samples")
            indices_to_keep = is_clean == 1
            x_train_final = x_train_all[indices_to_keep]
            y_train_final = y_train_all_categorical[indices_to_keep]
        else:
            logger.info(
                "Defense does not require filtering. Model fitting will use all data."
            )
            x_train_final = x_train_all
            y_train_final = y_train_all_categorical
        if len(x_train_final):
            logger.info(
                f"Fitting model of {model_config['module']}.{model_config['name']}..."
            )
            classifier.fit(
                x_train_final,
                y_train_final,
                batch_size=fit_batch_size,
                nb_epochs=train_epochs,
                verbose=False,
                shuffle=True,
            )
        else:
            logger.warning(
                "All data points filtered by defense. Skipping training")

        logger.info("Validating on clean test data")
        test_data = load_dataset(
            config["dataset"],
            epochs=1,
            split=config["dataset"].get("eval_split", "test"),
            preprocessing_fn=poison_scenario_preprocessing,
            shuffle_files=False,
        )
        benign_validation_metric = metrics.MetricList("categorical_accuracy")
        target_class_benign_metric = metrics.MetricList("categorical_accuracy")
        for x, y in tqdm(test_data, desc="Testing"):
            # Ensure that input sample isn't overwritten by classifier
            x.flags.writeable = False
            y_pred = classifier.predict(x)
            benign_validation_metric.add_results(y, y_pred)
            y_pred_tgt_class = y_pred[y == src_class]
            if len(y_pred_tgt_class):
                target_class_benign_metric.add_results(
                    [src_class] * len(y_pred_tgt_class), y_pred_tgt_class)
        logger.info(
            f"Unpoisoned validation accuracy: {benign_validation_metric.mean():.2%}"
        )
        logger.info(
            f"Unpoisoned validation accuracy on targeted class: {target_class_benign_metric.mean():.2%}"
        )
        results = {
            "benign_validation_accuracy":
            benign_validation_metric.mean(),
            "benign_validation_accuracy_targeted_class":
            target_class_benign_metric.mean(),
        }

        poisoned_test_metric = metrics.MetricList("categorical_accuracy")
        poisoned_targeted_test_metric = metrics.MetricList(
            "categorical_accuracy")

        logger.info("Testing on poisoned test data")
        if attack_type == "preloaded":
            test_data_poison = load_dataset(
                config_adhoc["poison_samples"],
                epochs=1,
                split="poison_test",
                preprocessing_fn=None,
            )
            for x_poison_test, y_poison_test in tqdm(test_data_poison,
                                                     desc="Testing poison"):
                x_poison_test = np.array([xp for xp in x_poison_test],
                                         dtype=np.float32)
                y_pred = classifier.predict(x_poison_test)
                y_true = [src_class] * len(y_pred)
                poisoned_targeted_test_metric.add_results(
                    y_poison_test, y_pred)
                poisoned_test_metric.add_results(y_true, y_pred)
            test_data_clean = load_dataset(
                config["dataset"],
                epochs=1,
                split=config["dataset"].get("eval_split", "test"),
                preprocessing_fn=poison_scenario_preprocessing,
                shuffle_files=False,
            )
            for x_clean_test, y_clean_test in tqdm(test_data_clean,
                                                   desc="Testing clean"):
                x_clean_test = np.array([xp for xp in x_clean_test],
                                        dtype=np.float32)
                y_pred = classifier.predict(x_clean_test)
                poisoned_test_metric.add_results(y_clean_test, y_pred)

        elif poison_dataset_flag:
            logger.info("Testing on poisoned test data")
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split=config["dataset"].get("eval_split", "test"),
                preprocessing_fn=poison_scenario_preprocessing,
                shuffle_files=False,
            )
            for x_test, y_test in tqdm(test_data, desc="Testing"):
                src_indices = np.where(y_test == src_class)[0]
                poisoned_indices = src_indices  # Poison entire class
                x_test, _ = poison_dataset(
                    x_test,
                    y_test,
                    src_class,
                    tgt_class,
                    len(y_test),
                    attack,
                    poisoned_indices,
                )
                y_pred = classifier.predict(x_test)
                poisoned_test_metric.add_results(y_test, y_pred)

                y_pred_targeted = y_pred[y_test == src_class]
                if not len(y_pred_targeted):
                    continue
                poisoned_targeted_test_metric.add_results(
                    [tgt_class] * len(y_pred_targeted), y_pred_targeted)

        if poison_dataset_flag or attack_type == "preloaded":
            results["poisoned_test_accuracy"] = poisoned_test_metric.mean()
            results[
                "poisoned_targeted_misclassification_accuracy"] = poisoned_targeted_test_metric.mean(
                )
            logger.info(f"Test accuracy: {poisoned_test_metric.mean():.2%}")
            logger.info(
                f"Test targeted misclassification accuracy: {poisoned_targeted_test_metric.mean():.2%}"
            )

        return results
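
The x.flags.writeable = False lines guard against the classifier silently modifying an evaluation batch in place: once the flag is cleared, any in-place write raises immediately. A minimal sketch:

import numpy as np

x = np.zeros((1, 32, 32, 3), dtype=np.float32)
x.flags.writeable = False   # freeze the batch before handing it to the model

try:
    x[...] = 1.0            # any in-place modification now fails loudly
except ValueError as err:
    print(err)              # assignment destination is read-only
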
Example #5
0
    def _evaluate(self, config: dict) -> dict:
        """
        Evaluate the config and return a results dict
        """

        model_config = config["model"]
        classifier, preprocessing_fn = load_model(model_config)

        defense_config = config.get("defense") or {}
        defense_type = defense_config.get("type")

        if defense_type in ["Preprocessor", "Postprocessor"]:
            logger.info(f"Applying internal {defense_type} defense to classifier")
            classifier = load_defense_internal(config["defense"], classifier)

        if model_config["fit"]:
            classifier.set_learning_phase(True)
            logger.info(
                f"Fitting model {model_config['module']}.{model_config['name']}..."
            )
            fit_kwargs = model_config["fit_kwargs"]

            logger.info(f"Loading train dataset {config['dataset']['name']}...")
            train_data = load_dataset(
                config["dataset"],
                epochs=fit_kwargs["nb_epochs"],
                split_type="train",
                preprocessing_fn=preprocessing_fn,
            )
            if defense_type == "Trainer":
                logger.info(f"Training with {defense_type} defense...")
                defense = load_defense_wrapper(config["defense"], classifier)
                defense.fit_generator(train_data, **fit_kwargs)
            else:
                logger.info(f"Fitting classifier on clean train dataset...")
                classifier.fit_generator(train_data, **fit_kwargs)

        if defense_type == "Transform":
            # NOTE: Transform currently not supported
            logger.info(f"Transforming classifier with {defense_type} defense...")
            defense = load_defense_wrapper(config["defense"], classifier)
            classifier = defense()

        classifier.set_learning_phase(False)

        # Evaluate the ART classifier on benign test examples
        logger.info(f"Loading test dataset {config['dataset']['name']}...")
        test_data_generator = load_dataset(
            config["dataset"],
            epochs=1,
            split_type="test",
            preprocessing_fn=preprocessing_fn,
        )
        logger.info("Running inference on benign examples...")
        metrics_logger = metrics.MetricsLogger.from_config(config["metric"])

        for x, y in tqdm(test_data_generator, desc="Benign"):
            y_pred = classifier.predict(x)
            metrics_logger.update_task(y, y_pred)
        metrics_logger.log_task()

        # Evaluate the ART classifier on adversarial test examples
        logger.info("Generating / testing adversarial examples...")

        attack = load_attack(config["attack"], classifier)
        test_data_generator = load_dataset(
            config["dataset"],
            epochs=1,
            split_type="test",
            preprocessing_fn=preprocessing_fn,
        )
        for x, y in tqdm(test_data_generator, desc="Attack"):
            x_adv = attack.generate(x=x)
            y_pred_adv = classifier.predict(x_adv)
            metrics_logger.update_task(y, y_pred_adv, adversarial=True)
            metrics_logger.update_perturbation(x, x_adv)
        metrics_logger.log_task(adversarial=True)
        return metrics_logger.results()
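
metrics_logger.update_perturbation(x, x_adv) accumulates distances between clean and adversarial batches, with the actual metrics chosen by config["metric"]. As a purely illustrative stand-in, here is one such measurement (a per-sample L-infinity norm) written outside of MetricsLogger:

import numpy as np

def linf_perturbation(x, x_adv):
    """Per-sample L-infinity distance between clean and adversarial inputs."""
    diff = (x_adv - x).reshape(len(x), -1)
    return np.abs(diff).max(axis=1)

x = np.random.rand(4, 28, 28, 1).astype(np.float32)
x_adv = np.clip(x + np.random.uniform(-0.1, 0.1, x.shape).astype(np.float32), 0, 1)
print(linf_perturbation(x, x_adv))  # one value per sample, each <= 0.1
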
Example #6
0
    def _evaluate(
        self,
        config: dict,
        num_eval_batches: Optional[int],
        skip_benign: Optional[bool],
        skip_attack: Optional[bool],
    ) -> dict:
        """
        Evaluate the config and return a results dict
        """
        if config["dataset"]["batch_size"] != 1:
            raise ValueError(
                "batch_size must be 1 for evaluation, due to variable length inputs.\n"
                "    If training, set config['model']['fit_kwargs']['fit_batch_size']"
            )

        model_config = config["model"]
        classifier, fit_preprocessing_fn = load_model(model_config)

        defense_config = config.get("defense") or {}
        defense_type = defense_config.get("type")

        if defense_type in ["Preprocessor", "Postprocessor"]:
            logger.info(f"Applying internal {defense_type} defense to classifier")
            classifier = load_defense_internal(config["defense"], classifier)

        if model_config["fit"]:
            classifier.set_learning_phase(True)
            logger.info(
                f"Fitting model {model_config['module']}.{model_config['name']}..."
            )
            fit_kwargs = model_config["fit_kwargs"]

            logger.info(f"Loading train dataset {config['dataset']['name']}...")
            batch_size = config["dataset"].pop("batch_size")
            config["dataset"]["batch_size"] = fit_kwargs.get(
                "fit_batch_size", batch_size
            )
            train_data = load_dataset(
                config["dataset"],
                epochs=fit_kwargs["nb_epochs"],
                split=config["dataset"].get("train_split", "train"),
                preprocessing_fn=fit_preprocessing_fn,
                shuffle_files=True,
            )
            config["dataset"]["batch_size"] = batch_size
            if defense_type == "Trainer":
                logger.info(f"Training with {defense_type} defense...")
                defense = load_defense_wrapper(config["defense"], classifier)
                defense.fit_generator(train_data, **fit_kwargs)
            else:
                logger.info("Fitting classifier on clean train dataset...")
                classifier.fit_generator(train_data, **fit_kwargs)

        if defense_type == "Transform":
            # NOTE: Transform currently not supported
            logger.info(f"Transforming classifier with {defense_type} defense...")
            defense = load_defense_wrapper(config["defense"], classifier)
            classifier = defense()

        classifier.set_learning_phase(False)

        attack_config = config["attack"]
        attack_type = attack_config.get("type")

        targeted = bool(attack_config.get("kwargs", {}).get("targeted"))
        metrics_logger = metrics.MetricsLogger.from_config(
            config["metric"],
            skip_benign=skip_benign,
            skip_attack=skip_attack,
            targeted=targeted,
        )

        if config["dataset"]["batch_size"] != 1:
            logger.warning("Evaluation batch_size != 1 may not be supported.")

        eval_split = config["dataset"].get("eval_split", "test")
        if skip_benign:
            logger.info("Skipping benign classification...")
        else:
            # Evaluate the ART classifier on benign test examples
            logger.info(f"Loading test dataset {config['dataset']['name']}...")
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split=eval_split,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )

            logger.info("Running inference on benign examples...")
            for x, y in tqdm(test_data, desc="Benign"):
                # Ensure that input sample isn't overwritten by classifier
                x.flags.writeable = False
                with metrics.resource_context(
                    name="Inference",
                    profiler=config["metric"].get("profiler_type"),
                    computational_resource_dict=metrics_logger.computational_resource_dict,
                ):
                    y_pred = classifier.predict(x)
                metrics_logger.update_task(y, y_pred)
            metrics_logger.log_task()

        if skip_attack:
            logger.info("Skipping attack generation...")
            return metrics_logger.results()

        # Evaluate the ART classifier on adversarial test examples
        logger.info("Generating or loading / testing adversarial examples...")

        if targeted and attack_config.get("use_label"):
            raise ValueError("Targeted attacks cannot have 'use_label'")
        if attack_type == "preloaded":
            test_data = load_adversarial_dataset(
                attack_config,
                epochs=1,
                split="adversarial",
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
        else:
            attack = load_attack(attack_config, classifier)
            if targeted != getattr(attack, "targeted", False):
                logger.warning(
                    f"targeted config {targeted} != attack field {getattr(attack, 'targeted', False)}"
                )
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split=eval_split,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
            if targeted:
                label_targeter = load_label_targeter(attack_config["targeted_labels"])

        export_samples = config["scenario"].get("export_samples")
        if export_samples is not None and export_samples > 0:
            sample_exporter = SampleExporter(
                self.scenario_output_dir, test_data.context, export_samples
            )
        else:
            sample_exporter = None

        for x, y in tqdm(test_data, desc="Attack"):
            with metrics.resource_context(
                name="Attack",
                profiler=config["metric"].get("profiler_type"),
                computational_resource_dict=metrics_logger.computational_resource_dict,
            ):
                if attack_type == "preloaded":
                    x, x_adv = x
                    if targeted:
                        y, y_target = y
                else:
                    generate_kwargs = deepcopy(attack_config.get("generate_kwargs", {}))
                    if attack_config.get("use_label"):
                        generate_kwargs["y"] = y
                    elif targeted:
                        y_target = label_targeter.generate(y)
                        generate_kwargs["y"] = y_target
                    x_adv = attack.generate(x=x, **generate_kwargs)

            # Ensure that input sample isn't overwritten by classifier
            x_adv.flags.writeable = False
            y_pred_adv = classifier.predict(x_adv)
            metrics_logger.update_task(y, y_pred_adv, adversarial=True)
            if targeted:
                metrics_logger.update_task(
                    y_target, y_pred_adv, adversarial=True, targeted=True
                )
            metrics_logger.update_perturbation(x, x_adv)
            if sample_exporter is not None:
                sample_exporter.export(x, x_adv, y, y_pred_adv)
        metrics_logger.log_task(adversarial=True)
        if targeted:
            metrics_logger.log_task(adversarial=True, targeted=True)
        return metrics_logger.results()
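
The branch that decides which labels attack.generate receives (use_label vs. targeted vs. untargeted) can be read in isolation. build_generate_kwargs and the label_targeter argument below are illustrative stand-ins, not part of the scenario API:

import numpy as np
from copy import deepcopy

def build_generate_kwargs(y, attack_config, label_targeter=None):
    """Mirror the label-selection branch above for attack.generate."""
    generate_kwargs = deepcopy(attack_config.get("generate_kwargs", {}))
    if attack_config.get("use_label"):
        generate_kwargs["y"] = y                           # steer with ground-truth labels
    elif attack_config.get("kwargs", {}).get("targeted"):
        generate_kwargs["y"] = label_targeter.generate(y)  # steer toward chosen target labels
    return generate_kwargs                                  # otherwise the attack infers labels

print(build_generate_kwargs(np.array([3]), {"use_label": True}))  # {'y': array([3])}
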
Example #7
0
    def _evaluate(
        self,
        config: dict,
        num_eval_batches: Optional[int],
        skip_benign: Optional[bool],
        skip_attack: Optional[bool],
    ) -> dict:
        """
        Evaluate the config and return a results dict
        """

        model_config = config["model"]
        estimator, _ = load_model(model_config)

        defense_config = config.get("defense") or {}
        defense_type = defense_config.get("type")

        if defense_type in ["Preprocessor", "Postprocessor"]:
            logger.info(f"Applying internal {defense_type} defense to estimator")
            estimator = load_defense_internal(config["defense"], estimator)

        attack_config = config["attack"]
        attack_channels = attack_config.get("generate_kwargs", {}).get("channels")

        if attack_channels is None:
            if self.attack_modality == "sar":
                logger.info("No mask configured. Attacking all SAR channels")
                attack_channels = range(4)
            elif self.attack_modality == "eo":
                logger.info("No mask configured. Attacking all EO channels")
                attack_channels = range(4, 14)
            elif self.attack_modality == "both":
                logger.info("No mask configured. Attacking all SAR and EO channels")
                attack_channels = range(14)

        else:
            assert isinstance(
                attack_channels, list
            ), "Mask is specified, but incorrect format. Expected list"
            attack_channels = np.array(attack_channels)
            if self.attack_modality == "sar":
                assert np.all(
                    np.logical_and(attack_channels >= 0, attack_channels < 4)
                ), "Selected SAR-only attack modality, but specify non-SAR channels"
            elif self.attack_modality == "eo":
                assert np.all(
                    np.logical_and(attack_channels >= 4, attack_channels < 14)
                ), "Selected EO-only attack modality, but specify non-EO channels"
            elif self.attack_modality == "both":
                assert np.all(
                    np.logical_and(attack_channels >= 0, attack_channels < 14)
                ), "Selected channels are out-of-bounds"

        if model_config["fit"]:
            try:
                estimator.set_learning_phase(True)
                logger.info(
                    f"Fitting model {model_config['module']}.{model_config['name']}..."
                )
                fit_kwargs = model_config["fit_kwargs"]

                logger.info(f"Loading train dataset {config['dataset']['name']}...")
                train_data = load_dataset(
                    config["dataset"],
                    epochs=fit_kwargs["nb_epochs"],
                    split=config["dataset"].get("train_split", "train"),
                    shuffle_files=True,
                )
                if defense_type == "Trainer":
                    logger.info(f"Training with {defense_type} defense...")
                    defense = load_defense_wrapper(config["defense"], estimator)
                    defense.fit_generator(train_data, **fit_kwargs)
                else:
                    logger.info("Fitting estimator on clean train dataset...")
                    estimator.fit_generator(train_data, **fit_kwargs)
            except NotImplementedError:
                raise NotImplementedError(
                    "Training has not yet been implemented for object detectors"
                )

        if defense_type == "Transform":
            # NOTE: Transform currently not supported
            logger.info(f"Transforming estimator with {defense_type} defense...")
            defense = load_defense_wrapper(config["defense"], estimator)
            estimator = defense()

        try:
            estimator.set_learning_phase(False)
        except NotImplementedError:
            logger.warning(
                "Unable to set estimator's learning phase. As of ART 1.4.1, "
                "this is not yet supported for object detectors."
            )

        attack_type = attack_config.get("type")
        targeted = bool(attack_config.get("kwargs", {}).get("targeted"))

        performance_metrics = deepcopy(config["metric"])
        performance_metrics.pop("perturbation")
        performance_logger = metrics.MetricsLogger.from_config(
            performance_metrics,
            skip_benign=skip_benign,
            skip_attack=skip_attack,
            targeted=targeted,
        )

        eval_split = config["dataset"].get("eval_split", "test")
        if skip_benign:
            logger.info("Skipping benign classification...")
        else:
            # Evaluate the ART estimator on benign test examples
            logger.info(f"Loading test dataset {config['dataset']['name']}...")
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split=eval_split,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )

            logger.info("Running inference on benign examples...")
            for x, y in tqdm(test_data, desc="Benign"):
                # Ensure that input sample isn't overwritten by estimator
                x.flags.writeable = False
                with metrics.resource_context(
                    name="Inference",
                    profiler=config["metric"].get("profiler_type"),
                    computational_resource_dict=performance_logger.computational_resource_dict,
                ):
                    y_pred = estimator.predict(x)
                performance_logger.update_task(y, y_pred)
            performance_logger.log_task()

        if skip_attack:
            logger.info("Skipping attack generation...")
            return performance_logger.results()

        # Evaluate the ART estimator on adversarial test examples
        logger.info("Generating or loading / testing adversarial examples...")

        perturbation_metrics = deepcopy(config["metric"])
        perturbation_metrics.pop("task")
        if self.attack_modality in ("sar", "both"):
            sar_perturbation_logger = metrics.MetricsLogger.from_config(
                perturbation_metrics,
                skip_benign=True,
                skip_attack=False,
                targeted=targeted,
            )
        else:
            sar_perturbation_logger = None

        if self.attack_modality in ("eo", "both"):
            eo_perturbation_logger = metrics.MetricsLogger.from_config(
                perturbation_metrics,
                skip_benign=True,
                skip_attack=False,
                targeted=targeted,
            )
        else:
            eo_perturbation_logger = None

        if targeted and attack_config.get("use_label"):
            raise ValueError("Targeted attacks cannot have 'use_label'")
        if attack_type == "preloaded":
            test_data = load_adversarial_dataset(
                attack_config,
                epochs=1,
                split="adversarial",
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
        else:
            attack = load_attack(attack_config, estimator)
            if targeted != getattr(attack, "targeted", False):
                logger.warning(
                    f"targeted config {targeted} != attack field {getattr(attack, 'targeted', False)}"
                )
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split=eval_split,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
            if targeted:
                label_targeter = load_label_targeter(attack_config["targeted_labels"])

        export_samples = config["scenario"].get("export_samples")
        if export_samples is not None and export_samples > 0:
            sample_exporter = SampleExporter(
                self.scenario_output_dir, test_data.context, export_samples
            )
        else:
            sample_exporter = None

        for x, y in tqdm(test_data, desc="Attack"):
            with metrics.resource_context(
                name="Attack",
                profiler=config["metric"].get("profiler_type"),
                computational_resource_dict=performance_logger.computational_resource_dict,
            ):
                if attack_type == "preloaded":
                    logger.warning(
                        "Specified preloaded attack. Ignoring `attack_modality` parameter"
                    )
                    if len(x) == 2:
                        x, x_adv = x
                    else:
                        x_adv = x
                    if targeted:
                        y, y_target = y
                else:
                    generate_kwargs = deepcopy(attack_config.get("generate_kwargs", {}))
                    generate_kwargs["mask"] = attack_channels
                    if attack_config.get("use_label"):
                        generate_kwargs["y"] = y
                    elif targeted:
                        y_target = label_targeter.generate(y)
                        generate_kwargs["y"] = y_target
                    x_adv = attack.generate(x=x, **generate_kwargs)

            # Ensure that input sample isn't overwritten by estimator
            x_adv.flags.writeable = False
            y_pred_adv = estimator.predict(x_adv)
            performance_logger.update_task(y, y_pred_adv, adversarial=True)
            if targeted:
                performance_logger.update_task(
                    y_target, y_pred_adv, adversarial=True, targeted=True
                )

            # Update perturbation metrics for SAR/EO separately
            x_sar = np.stack(
                (x[..., 0] + 1j * x[..., 1], x[..., 2] + 1j * x[..., 3]), axis=3
            )
            x_adv_sar = np.stack(
                (
                    x_adv[..., 0] + 1j * x_adv[..., 1],
                    x_adv[..., 2] + 1j * x_adv[..., 3],
                ),
                axis=3,
            )
            x_eo = x[..., 4:]
            x_adv_eo = x_adv[..., 4:]
            if sar_perturbation_logger is not None:
                sar_perturbation_logger.update_perturbation(x_sar, x_adv_sar)
            if eo_perturbation_logger is not None:
                eo_perturbation_logger.update_perturbation(x_eo, x_adv_eo)

            if sample_exporter is not None:
                sample_exporter.export(x, x_adv, y, y_pred_adv)

        performance_logger.log_task(adversarial=True)
        if targeted:
            performance_logger.log_task(adversarial=True, targeted=True)

        # Merge performance, SAR, EO results
        combined_results = performance_logger.results()
        if sar_perturbation_logger is not None:
            combined_results.update(
                {f"sar_{k}": v for k, v in sar_perturbation_logger.results().items()}
            )
        if eo_perturbation_logger is not None:
            combined_results.update(
                {f"eo_{k}": v for k, v in eo_perturbation_logger.results().items()}
            )
        return combined_results
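
The perturbation bookkeeping above splits each 14-channel input into SAR and EO parts, recombining the first four real-valued channels into two complex SAR channels. A toy sketch of that reshaping with random data (shapes only, for illustration):

import numpy as np

x = np.random.rand(2, 8, 8, 14).astype(np.float32)   # (batch, H, W, 14 channels)

# Channels 0-3 are SAR stored as interleaved real/imaginary parts; 4-13 are EO.
x_sar = np.stack(
    (x[..., 0] + 1j * x[..., 1], x[..., 2] + 1j * x[..., 3]), axis=3
)
x_eo = x[..., 4:]
print(x_sar.shape, x_eo.shape)  # (2, 8, 8, 2) (2, 8, 8, 10)
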
Example #8
0
    def _evaluate(self, config: dict) -> dict:
        """
        Evaluate a config file for classification robustness against attack.
        """
        model_config = config["model"]
        classifier, preprocessing_fn = load_model(model_config)

        n_tbins = 100  # number of time bins in spectrogram input to model

        task_metric = metrics.categorical_accuracy

        # Train ART classifier
        if not model_config["weights_file"]:
            classifier.set_learning_phase(True)
            logger.info(
                f"Fitting model {model_config['module']}.{model_config['name']}..."
            )
            fit_kwargs = model_config["fit_kwargs"]
            train_data_generator = load_dataset(
                config["dataset"],
                epochs=fit_kwargs["nb_epochs"],
                split_type="train",
                preprocessing_fn=preprocessing_fn,
            )

            for cnt, (x, y) in tqdm(enumerate(train_data_generator)):
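                # segment() is a scenario helper assumed to split each example into
                # fixed-length windows of n_tbins time bins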
                x_seg, y_seg = segment(x, y, n_tbins)
                classifier.fit(
                    x_seg,
                    y_seg,
                    batch_size=config["dataset"]["batch_size"],
                    nb_epochs=1,
                    verbose=True,
                )

                if (cnt + 1) % train_data_generator.batches_per_epoch == 0:
                    # evaluate on validation examples
                    val_data_generator = load_dataset(
                        config["dataset"],
                        epochs=1,
                        split_type="validation",
                        preprocessing_fn=preprocessing_fn,
                    )

                    # Use a separate counter so we don't clobber the training batch
                    # counter from enumerate() above
                    val_cnt = 0
                    validation_accuracies = []
                    for x_val, y_val in tqdm(val_data_generator):
                        x_val_seg, y_val_seg = segment(x_val, y_val, n_tbins)
                        y_pred = classifier.predict(x_val_seg)
                        validation_accuracies.extend(
                            task_metric(y_val_seg, y_pred))
                        val_cnt += len(y_val_seg)
                    validation_accuracy = sum(validation_accuracies) / val_cnt
                    logger.info(
                        "Validation accuracy: {}".format(validation_accuracy))

        classifier.set_learning_phase(False)
        # Evaluate ART classifier on test examples
        logger.info(f"Loading testing dataset {config['dataset']['name']}...")
        test_data_generator = load_dataset(
            config["dataset"],
            epochs=1,
            split_type="test",
            preprocessing_fn=preprocessing_fn,
        )

        logger.info("Running inference on benign test examples...")

        cnt = 0
        benign_accuracies = []
        for x, y in tqdm(test_data_generator, desc="Benign"):
            x_seg, y_seg = segment(x, y, n_tbins)
            y_pred = classifier.predict(x_seg)
            benign_accuracies.extend(task_metric(y_seg, y_pred))
            cnt += len(y_seg)

        benign_accuracy = sum(benign_accuracies) / cnt
        logger.info(f"Accuracy on benign test examples: {benign_accuracy:.2%}")

        # Evaluate the ART classifier on adversarial test examples
        logger.info("Generating / testing adversarial examples...")
        attack = load_attack(config["attack"], classifier)

        test_data_generator = load_dataset(
            config["dataset"],
            epochs=1,
            split_type="test",
            preprocessing_fn=preprocessing_fn,
        )

        cnt = 0
        adversarial_accuracies = []
        for x, y in tqdm(test_data_generator, desc="Attack"):
            x_seg, y_seg = segment(x, y, n_tbins)
            x_adv = attack.generate(x=x_seg)
            y_pred = classifier.predict(x_adv)
            adversarial_accuracies.extend(task_metric(y_seg, y_pred))
            cnt += len(y_seg)
        adversarial_accuracy = sum(adversarial_accuracies) / cnt
        logger.info(
            f"Accuracy on adversarial test examples: {adversarial_accuracy:.2%}"
        )

        results = {
            "mean_benign_accuracy": benign_accuracy,
            "mean_adversarial_accuracy": adversarial_accuracy,
        }
        return results
Example #9
    def _evaluate(
        self,
        config: dict,
        num_eval_batches: Optional[int],
        skip_benign: Optional[bool],
        skip_attack: Optional[bool],
        skip_misclassified: Optional[bool],
    ) -> dict:
        """
        Evaluate the config and return a results dict
        """
        if skip_misclassified:
            raise ValueError(
                "skip_misclassified shouldn't be set for ASR scenario")
        model_config = config["model"]
        estimator, fit_preprocessing_fn = load_model(model_config)

        audio_channel_config = config.get("adhoc", {}).get("audio_channel")
        if audio_channel_config is not None:
            logger.info("loading audio channel")
            for k in "delay", "attenuation":
                if k not in audio_channel_config:
                    raise ValueError(f"audio_channel must have key {k}")
            audio_channel = load_audio_channel(**audio_channel_config)
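            # Insert the simulated audio channel ahead of any existing preprocessing
            # defences, then refresh ART's internal preprocessing operations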
            if estimator.preprocessing_defences:
                estimator.preprocessing_defences.insert(0, audio_channel)
            else:
                estimator.preprocessing_defences = [audio_channel]
            estimator._update_preprocessing_operations()

        defense_config = config.get("defense") or {}
        defense_type = defense_config.get("type")

        if defense_type in ["Preprocessor", "Postprocessor"]:
            logger.info(
                f"Applying internal {defense_type} defense to estimator")
            estimator = load_defense_internal(config["defense"], estimator)

        if model_config["fit"]:
            logger.info(
                f"Fitting model {model_config['module']}.{model_config['name']}..."
            )
            fit_kwargs = model_config["fit_kwargs"]

            logger.info(
                f"Loading train dataset {config['dataset']['name']}...")
            batch_size = config["dataset"].pop("batch_size")
            config["dataset"]["batch_size"] = fit_kwargs.get(
                "fit_batch_size", batch_size)
            train_data = load_dataset(
                config["dataset"],
                epochs=fit_kwargs["nb_epochs"],
                split=config["dataset"].get("train_split", "train_clean100"),
                preprocessing_fn=fit_preprocessing_fn,
                shuffle_files=True,
            )
            config["dataset"]["batch_size"] = batch_size
            if defense_type == "Trainer":
                logger.info(f"Training with {defense_type} defense...")
                defense = load_defense_wrapper(config["defense"], estimator)
                defense.fit_generator(train_data, **fit_kwargs)
            else:
                logger.info("Fitting estimator on clean train dataset...")
                estimator.fit_generator(train_data, **fit_kwargs)

        if defense_type == "Transform":
            # NOTE: Transform currently not supported
            logger.info(
                f"Transforming estimator with {defense_type} defense...")
            defense = load_defense_wrapper(config["defense"], estimator)
            estimator = defense()

        attack_config = config["attack"]
        attack_type = attack_config.get("type")

        targeted = bool(attack_config.get("targeted"))
        metrics_logger = metrics.MetricsLogger.from_config(
            config["metric"],
            skip_benign=skip_benign,
            skip_attack=skip_attack,
            targeted=targeted,
        )

        if config["dataset"]["batch_size"] != 1:
            logger.warning("Evaluation batch_size != 1 may not be supported.")

        predict_kwargs = config["model"].get("predict_kwargs", {})
        eval_split = config["dataset"].get("eval_split", "test_clean")
        if skip_benign:
            logger.info("Skipping benign classification...")
        else:
            # Evaluate the ART estimator on benign test examples
            logger.info(f"Loading test dataset {config['dataset']['name']}...")
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split=eval_split,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
            logger.info("Running inference on benign examples...")
            for x, y in tqdm(test_data, desc="Benign"):
                # Ensure that input sample isn't overwritten by estimator
                x.flags.writeable = False
                with metrics.resource_context(
                        name="Inference",
                        profiler=config["metric"].get("profiler_type"),
                        computational_resource_dict=metrics_logger.
                        computational_resource_dict,
                ):
                    y_pred = estimator.predict(x, **predict_kwargs)
                metrics_logger.update_task(y, y_pred)
            metrics_logger.log_task()

        if skip_attack:
            logger.info("Skipping attack generation...")
            return metrics_logger.results()

        # Imperceptible attack still WIP
        if (config.get("adhoc") or {}).get("skip_adversarial"):
            logger.info("Skipping adversarial classification...")
            return metrics_logger.results()

        # Evaluate the ART estimator on adversarial test examples
        logger.info("Generating or loading / testing adversarial examples...")

        if attack_type == "preloaded":
            test_data = load_adversarial_dataset(
                attack_config,
                epochs=1,
                split="adversarial",
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
        else:
            attack = load_attack(attack_config, estimator)
            if targeted != attack.targeted:
                logger.warning(
                    f"targeted config {targeted} != attack field {attack.targeted}"
                )
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split=eval_split,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
            if targeted:
                label_targeter = load_label_targeter(
                    attack_config["targeted_labels"])

        export_samples = config["scenario"].get("export_samples")
        if export_samples is not None and export_samples > 0:
            sample_exporter = SampleExporter(self.scenario_output_dir,
                                             test_data.context, export_samples)
        else:
            sample_exporter = None

        for x, y in tqdm(test_data, desc="Attack"):
            with metrics.resource_context(
                    name="Attack",
                    profiler=config["metric"].get("profiler_type"),
                    computational_resource_dict=metrics_logger.
                    computational_resource_dict,
            ):
                if attack_type == "preloaded":
                    x, x_adv = x
                    if targeted:
                        y, y_target = y
                elif attack_config.get("use_label"):
                    x_adv = attack.generate(x=x, y=y)
                elif targeted:
                    y_target = label_targeter.generate(y)
                    x_adv = attack.generate(x=x, y=y_target)
                else:
                    x_adv = attack.generate(x=x)

            # Ensure that input sample isn't overwritten by estimator
            x_adv.flags.writeable = False
            y_pred_adv = estimator.predict(x_adv, **predict_kwargs)
            metrics_logger.update_task(y, y_pred_adv, adversarial=True)
            if targeted:
                metrics_logger.update_task(
                    y_target,
                    y_pred_adv,
                    adversarial=True,
                    targeted=True,
                )
            metrics_logger.update_perturbation(x, x_adv)
            if sample_exporter is not None:
                sample_exporter.export(x, x_adv, y, y_pred_adv)
        metrics_logger.log_task(adversarial=True)
        if targeted:
            metrics_logger.log_task(adversarial=True, targeted=True)
        return metrics_logger.results()
Example #10
    def _evaluate(
        self,
        config: dict,
        num_eval_batches: Optional[int],
        skip_benign: Optional[bool],
        skip_attack: Optional[bool],
    ) -> dict:
        """
        Evaluate a config file for classification robustness against attack.

        Note: num_eval_batches shouldn't be set for poisoning scenario and will raise an
        error if it is
        """
        if config["sysconfig"].get("use_gpu"):
            os.environ["TF_CUDNN_DETERMINISM"] = "1"
        if num_eval_batches:
            raise ValueError(
                "num_eval_batches shouldn't be set for poisoning scenario")
        if skip_benign:
            raise ValueError(
                "skip_benign shouldn't be set for poisoning scenario")
        if skip_attack:
            raise ValueError(
                "skip_attack shouldn't be set for poisoning scenario")

        model_config = config["model"]
        # Scenario assumes the canonical preprocessing_fn is used,
        # which makes all images the same size
        classifier, _ = load_model(model_config)
        proxy_classifier, _ = load_model(model_config)

        config_adhoc = config.get("adhoc") or {}
        train_epochs = config_adhoc["train_epochs"]
        src_class = config_adhoc["source_class"]
        tgt_class = config_adhoc["target_class"]
        fit_batch_size = config_adhoc.get("fit_batch_size",
                                          config["dataset"]["batch_size"])

        if not config["sysconfig"].get("use_gpu"):
            conf = ConfigProto(intra_op_parallelism_threads=1)
            set_session(Session(config=conf))

        # Set random seed due to large variance in attack and defense success
        np.random.seed(config_adhoc["split_id"])
        set_random_seed(config_adhoc["split_id"])
        random.seed(config_adhoc["split_id"])
        use_poison_filtering_defense = config_adhoc.get(
            "use_poison_filtering_defense", True)
        if self.check_run:
            # filtering defense requires more than a single batch to run properly
            use_poison_filtering_defense = False

        logger.info(f"Loading dataset {config['dataset']['name']}...")

        clean_data = load_dataset(
            config["dataset"],
            epochs=1,
            split=config["dataset"].get("train_split", "train"),
            preprocessing_fn=poison_scenario_preprocessing,
            shuffle_files=False,
        )
        # Flag for whether to poison dataset -- used to evaluate
        #     performance of defense on clean data
        poison_dataset_flag = config["adhoc"]["poison_dataset"]
        # detect_poison does not currently support data generators
        #     therefore, make in memory dataset
        x_train_all, y_train_all = [], []

        logger.info(
            "Building in-memory dataset for poisoning detection and training")
        for x_train, y_train in clean_data:
            x_train_all.append(x_train)
            y_train_all.append(y_train)
        x_train_all = np.concatenate(x_train_all, axis=0)
        y_train_all = np.concatenate(y_train_all, axis=0)

        if poison_dataset_flag:
            y_train_all_categorical = to_categorical(y_train_all)
            attack_train_epochs = train_epochs
            attack_config = deepcopy(config["attack"])
            use_adversarial_trainer_flag = attack_config.get(
                "use_adversarial_trainer", False)

            proxy_classifier_fit_kwargs = {
                "batch_size": fit_batch_size,
                "nb_epochs": attack_train_epochs,
            }
            logger.info("Fitting proxy classifier...")
            if use_adversarial_trainer_flag:
                logger.info("Using adversarial trainer...")
                adversarial_trainer_kwargs = attack_config.pop(
                    "adversarial_trainer_kwargs", {})
                for k, v in proxy_classifier_fit_kwargs.items():
                    adversarial_trainer_kwargs[k] = v
                proxy_classifier = AdversarialTrainerMadryPGD(
                    proxy_classifier, **adversarial_trainer_kwargs)
                proxy_classifier.fit(x_train_all, y_train_all)
                attack_config["kwargs"][
                    "proxy_classifier"] = proxy_classifier.get_classifier()
            else:
                proxy_classifier_fit_kwargs["verbose"] = False
                proxy_classifier_fit_kwargs["shuffle"] = True
                proxy_classifier.fit(x_train_all, y_train_all,
                                     **proxy_classifier_fit_kwargs)
                attack_config["kwargs"]["proxy_classifier"] = proxy_classifier

            attack, backdoor = load(attack_config)
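            # `backdoor` is reused further down to poison the test set when
            # measuring targeted misclassification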

            x_train_all, y_train_all_categorical = attack.poison(
                x_train_all, y_train_all_categorical)
            y_train_all = np.argmax(y_train_all_categorical, axis=1)

        if use_poison_filtering_defense:
            y_train_defense = to_categorical(y_train_all)

            defense_config = config["defense"]
            detection_kwargs = config_adhoc.get("detection_kwargs", dict())

            defense_model_config = config_adhoc.get("defense_model",
                                                    model_config)

            # Assumes classifier_for_defense and classifier use the same preprocessing function
            classifier_for_defense, _ = load_model(defense_model_config)
            # ART/Armory API requires that classifier_for_defense trains inside defense_fn
            defense_fn = load_fn(defense_config)
            defense = defense_fn(classifier_for_defense, x_train_all,
                                 y_train_defense)

            _, is_clean = defense.detect_poison(**detection_kwargs)
            is_clean = np.array(is_clean)
            logger.info(f"Total clean data points: {np.sum(is_clean)}")

            logger.info("Filtering out detected poisoned samples")
            indices_to_keep = is_clean == 1
            x_train_final = x_train_all[indices_to_keep]
            y_train_final = y_train_all[indices_to_keep]
        else:
            logger.info(
                "Defense does not require filtering. Model fitting will use all data."
            )
            x_train_final = x_train_all
            y_train_final = y_train_all
        if len(x_train_final):
            logger.info(
                f"Fitting model of {model_config['module']}.{model_config['name']}..."
            )
            classifier.fit(
                x_train_final,
                y_train_final,
                batch_size=fit_batch_size,
                nb_epochs=train_epochs,
                verbose=False,
                shuffle=True,
            )
        else:
            logger.warning(
                "All data points filtered by defense. Skipping training")

        logger.info("Validating on clean test data")
        test_data = load_dataset(
            config["dataset"],
            epochs=1,
            split=config["dataset"].get("eval_split", "test"),
            preprocessing_fn=poison_scenario_preprocessing,
            shuffle_files=False,
        )
        benign_validation_metric = metrics.MetricList("categorical_accuracy")
        target_class_benign_metric = metrics.MetricList("categorical_accuracy")
        for x, y in tqdm(test_data, desc="Testing"):
            # Ensure that input sample isn't overwritten by classifier
            x.flags.writeable = False
            y_pred = classifier.predict(x)
            benign_validation_metric.append(y, y_pred)
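            # Also track benign accuracy restricted to the source class, i.e. the
            # class the backdoor attack tries to flip to tgt_class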
            y_pred_tgt_class = y_pred[y == src_class]
            if len(y_pred_tgt_class):
                target_class_benign_metric.append(
                    [src_class] * len(y_pred_tgt_class), y_pred_tgt_class)
        logger.info(
            f"Unpoisoned validation accuracy: {benign_validation_metric.mean():.2%}"
        )
        logger.info(
            f"Unpoisoned validation accuracy on targeted class: {target_class_benign_metric.mean():.2%}"
        )
        results = {
            "benign_validation_accuracy":
            benign_validation_metric.mean(),
            "benign_validation_accuracy_targeted_class":
            target_class_benign_metric.mean(),
        }

        poisoned_test_metric = metrics.MetricList("categorical_accuracy")
        poisoned_targeted_test_metric = metrics.MetricList(
            "categorical_accuracy")

        if poison_dataset_flag:
            logger.info("Testing on poisoned test data")
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split=config["dataset"].get("eval_split", "test"),
                preprocessing_fn=poison_scenario_preprocessing,
                shuffle_files=False,
            )
            for x_test, y_test in tqdm(test_data, desc="Testing"):
                src_indices = np.where(y_test == src_class)[0]
                poisoned_indices = src_indices  # Poison entire class
                x_test, _ = poison_dataset(
                    x_test,
                    y_test,
                    src_class,
                    tgt_class,
                    len(y_test),
                    backdoor,
                    poisoned_indices,
                )
                y_pred = classifier.predict(x_test)
                poisoned_test_metric.append(y_test, y_pred)

                y_pred_targeted = y_pred[y_test == src_class]
                if len(y_pred_targeted):
                    poisoned_targeted_test_metric.append(
                        [tgt_class] * len(y_pred_targeted), y_pred_targeted)
            results["poisoned_test_accuracy"] = poisoned_test_metric.mean()
            results[
                "poisoned_targeted_misclassification_accuracy"] = poisoned_targeted_test_metric.mean(
                )
            logger.info(f"Test accuracy: {poisoned_test_metric.mean():.2%}")
            logger.info(
                f"Test targeted misclassification accuracy: {poisoned_targeted_test_metric.mean():.2%}"
            )

        return results
Example #11
    def _evaluate(
        self,
        config: dict,
        num_eval_batches: Optional[int],
        skip_benign: Optional[bool],
        skip_attack: Optional[bool],
        skip_misclassified: Optional[bool],
    ) -> dict:
        """
        Evaluate the config and return a results dict
        """
        if skip_misclassified:
            raise ValueError(
                "skip_misclassified shouldn't be set for D-APRICOT scenario")
        if skip_attack:
            raise ValueError(
                "--skip-attack should not be set for D-APRICOT scenario.")
        if skip_benign:
            logger.warning("--skip-benign is being ignored since the D-APRICOT"
                           " scenario doesn't include benign evaluation.")
        attack_config = config["attack"]
        attack_type = attack_config.get("type")
        if attack_type == "preloaded":
            raise ValueError(
                "D-APRICOT scenario should not have preloaded set to True in attack config"
            )
        elif "targeted_labels" not in attack_config:
            raise ValueError(
                "Attack config must have 'targeted_labels' key, as the "
                "D-APRICOT threat model is targeted.")
        elif attack_config.get("use_label"):
            raise ValueError(
                "The D-APRICOT scenario threat model is targeted, and"
                " thus 'use_label' should be set to false.")

        if config["dataset"].get("batch_size") != 1:
            raise ValueError(
                "batch_size of 1 is required for D-APRICOT scenario")

        model_config = config["model"]
        estimator, _ = load_model(model_config)

        defense_config = config.get("defense") or {}
        defense_type = defense_config.get("type")

        label_targeter = load_label_targeter(attack_config["targeted_labels"])

        if defense_type in ["Preprocessor", "Postprocessor"]:
            logger.info(
                f"Applying internal {defense_type} defense to estimator")
            estimator = load_defense_internal(config["defense"], estimator)

        if model_config["fit"]:
            try:
                logger.info(
                    f"Fitting model {model_config['module']}.{model_config['name']}..."
                )
                fit_kwargs = model_config["fit_kwargs"]

                logger.info(
                    f"Loading train dataset {config['dataset']['name']}...")
                train_data = load_dataset(
                    config["dataset"],
                    epochs=fit_kwargs["nb_epochs"],
                    split=config["dataset"].get("train_split", "train"),
                    shuffle_files=True,
                )
                if defense_type == "Trainer":
                    logger.info(f"Training with {defense_type} defense...")
                    defense = load_defense_wrapper(config["defense"],
                                                   estimator)
                    defense.fit_generator(train_data, **fit_kwargs)
                else:
                    logger.info("Fitting estimator on clean train dataset...")
                    estimator.fit_generator(train_data, **fit_kwargs)
            except NotImplementedError:
                raise NotImplementedError(
                    "Training has not yet been implemented for object detectors"
                )

        if defense_type == "Transform":
            # NOTE: Transform currently not supported
            logger.info(
                f"Transforming estimator with {defense_type} defense...")
            defense = load_defense_wrapper(config["defense"], estimator)
            estimator = defense()

        metrics_logger = metrics.MetricsLogger.from_config(
            config["metric"],
            skip_benign=True,
            skip_attack=False,
            targeted=True,
        )

        eval_split = config["dataset"].get("eval_split", "test")

        # Evaluate the ART estimator on adversarial test examples
        logger.info("Generating or loading / testing adversarial examples...")

        attack = load_attack(attack_config, estimator)
        test_data = load_dataset(
            config["dataset"],
            epochs=1,
            split=eval_split,
            num_batches=num_eval_batches,
            shuffle_files=False,
        )

        export_samples = config["scenario"].get("export_samples")
        if export_samples is not None and export_samples > 0:
            sample_exporter = SampleExporter(self.scenario_output_dir,
                                             test_data.context, export_samples)
        else:
            sample_exporter = None

        for x, y in tqdm(test_data, desc="Attack"):
            with metrics.resource_context(
                    name="Attack",
                    profiler=config["metric"].get("profiler_type"),
                    computational_resource_dict=metrics_logger.
                    computational_resource_dict,
            ):

                if x.shape[0] != 1:
                    raise ValueError("D-APRICOT batch size must be set to 1")
                # (nb=1, num_cameras, h, w, c) --> (num_cameras, h, w, c)
                x = x[0]
                y_object, y_patch_metadata = y
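                # y_object holds the detection targets; y_patch_metadata is passed
                # through to the attack's generate() call below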

                generate_kwargs = deepcopy(
                    attack_config.get("generate_kwargs", {}))
                generate_kwargs["y_patch_metadata"] = y_patch_metadata
                y_target = label_targeter.generate(y_object)
                generate_kwargs["y_object"] = y_target

                x_adv = attack.generate(x=x, **generate_kwargs)

            # Ensure that input sample isn't overwritten by estimator
            x_adv.flags.writeable = False
            y_pred_adv = estimator.predict(x_adv)
            for img_idx in range(len(y_object)):
                y_i_target = y_target[img_idx]
                y_i_pred = y_pred_adv[img_idx]
                metrics_logger.update_task([y_i_target], [y_i_pred],
                                           adversarial=True,
                                           targeted=True)

            metrics_logger.update_perturbation(x, x_adv)
            if sample_exporter is not None:
                sample_exporter.export(x, x_adv, y_object, y_pred_adv)

        metrics_logger.log_task(adversarial=True, targeted=True)
        return metrics_logger.results()
Example #12
    def _evaluate(self, config: dict) -> dict:
        """
        Evaluate the config and return a results dict
        """

        model_config = config["model"]
        classifier, preprocessing_fn = load_model(model_config)

        defense_config = config.get("defense") or {}
        defense_type = defense_config.get("type")

        if defense_type in ["Preprocessor", "Postprocessor"]:
            logger.info(
                f"Applying internal {defense_type} defense to classifier")
            classifier = load_defense_internal(config["defense"], classifier)

        if model_config["fit"]:
            classifier.set_learning_phase(True)
            logger.info(
                f"Fitting model {model_config['module']}.{model_config['name']}..."
            )
            fit_kwargs = model_config["fit_kwargs"]

            logger.info(
                f"Loading train dataset {config['dataset']['name']}...")
            train_data = load_dataset(
                config["dataset"],
                epochs=fit_kwargs["nb_epochs"],
                split_type="train",
                preprocessing_fn=preprocessing_fn,
            )
            if defense_type == "Trainer":
                logger.info(f"Training with {defense_type} defense...")
                defense = load_defense_wrapper(config["defense"], classifier)
                defense.fit_generator(train_data, **fit_kwargs)
            else:
                logger.info("Fitting classifier on clean train dataset...")
                classifier.fit_generator(train_data, **fit_kwargs)

            ################################################################
            #### Save weights at the end of training
            ################################################################
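            # Checkpoint name encodes the model module, whether pretrained weights
            # were used, and the number of training epochs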
            ckpt_name = model_config['module'].replace('.', '_')
            ckpt_name += '_pretrained' if model_config['model_kwargs'][
                'pretrained'] else ''
            ckpt_name += '_epochs%d.pth' % model_config['fit_kwargs'][
                'nb_epochs']
            classifier.save(
                osp.join(paths.runtime_paths().saved_model_dir, ckpt_name))
            logger.info(f"Saved classifier {ckpt_name} ...")

        if defense_type == "Transform":
            # NOTE: Transform currently not supported
            logger.info(
                f"Transforming classifier with {defense_type} defense...")
            defense = load_defense_wrapper(config["defense"], classifier)
            classifier = defense()

        classifier.set_learning_phase(False)

        # Evaluate the ART classifier on benign test examples
        logger.info(f"Loading test dataset {config['dataset']['name']}...")
        test_data = load_dataset(
            config["dataset"],
            epochs=1,
            split_type="test",
            preprocessing_fn=preprocessing_fn,
        )
        logger.info("Running inference on benign examples...")
        metrics_logger = metrics.MetricsLogger.from_config(config["metric"])

        for x, y in tqdm(test_data, desc="Benign"):
            y_pred = classifier.predict(x)
            metrics_logger.update_task(y, y_pred)
        metrics_logger.log_task()

        # Evaluate the ART classifier on adversarial test examples
        logger.info("Generating or loading / testing adversarial examples...")

        attack_config = config["attack"]
        attack_type = attack_config.get("type")
        targeted = bool(attack_config.get("kwargs", {}).get("targeted"))
        if targeted and attack_config.get("use_label"):
            raise ValueError("Targeted attacks cannot have 'use_label'")
        if attack_type == "preloaded":
            test_data = load_adversarial_dataset(
                attack_config,
                epochs=1,
                split_type="adversarial",
                preprocessing_fn=preprocessing_fn,
            )
        else:
            attack = load_attack(attack_config, classifier)
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split_type="test",
                preprocessing_fn=preprocessing_fn,
            )
        for x, y in tqdm(test_data, desc="Attack"):
            if attack_type == "preloaded":
                x, x_adv = x
                if targeted:
                    y, y_target = y
            elif attack_config.get("use_label"):
                x_adv = attack.generate(x=x, y=y)
            elif targeted:
                raise NotImplementedError(
                    "Requires generation of target labels")
                # x_adv = attack.generate(x=x, y=y_target)
            else:
                x_adv = attack.generate(x=x)

            y_pred_adv = classifier.predict(x_adv)
            if targeted:
                # NOTE: does not remove data points where y == y_target
                metrics_logger.update_task(y_target,
                                           y_pred_adv,
                                           adversarial=True)
            else:
                metrics_logger.update_task(y, y_pred_adv, adversarial=True)
            metrics_logger.update_perturbation(x, x_adv)
        metrics_logger.log_task(adversarial=True, targeted=targeted)
        return metrics_logger.results()
Example #13
    def _evaluate(self, config: dict) -> dict:
        """
        Evaluate a config file for classification robustness against attack.
        """

        model_config = config["model"]
        classifier, preprocessing_fn = load_model(model_config)
        classifier_for_defense, _ = load_model(model_config)

        train_epochs = config["adhoc"]["train_epochs"]
        src_class = config["adhoc"]["source_class"]
        tgt_class = config["adhoc"]["target_class"]

        # Set random seed due to large variance in attack and defense success
        np.random.seed(config["adhoc"]["np_seed"])
        set_random_seed(config["adhoc"]["tf_seed"])

        logger.info(f"Loading dataset {config['dataset']['name']}...")
        batch_size = config["dataset"]["batch_size"]
        train_data = load_dataset(
            config["dataset"],
            epochs=1,
            split_type="train",
            preprocessing_fn=preprocessing_fn,
        )

        logger.info(
            "Building in-memory dataset for poisoning detection and training")
        attack_config = config["attack"]
        attack = load(attack_config)
        fraction_poisoned = config["adhoc"]["fraction_poisoned"]
        poison_dataset_flag = config["adhoc"]["poison_dataset"]
        # detect_poison does not currently support data generators
        #     therefore, make in memory dataset
        x_train_all, y_train_all = [], []
        for x_train, y_train in train_data:
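            # Poison whole batches at random so that roughly `fraction_poisoned`
            # of the training data carries the backdoor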
            if poison_dataset_flag and np.random.rand() < fraction_poisoned:
                x_train, y_train = poison_batch(x_train, y_train,
                                                src_class, tgt_class,
                                                len(y_train), attack)
            x_train_all.append(x_train)
            y_train_all.append(y_train)
        x_train_all = np.concatenate(x_train_all, axis=0)
        y_train_all = np.concatenate(y_train_all, axis=0)
        y_train_all_categorical = to_categorical(y_train_all)

        defense_config = config["defense"]
        logger.info(
            f"Fitting model {model_config['module']}.{model_config['name']} "
            f"for defense {defense_config['name']}...")
        classifier_for_defense.fit(
            x_train_all,
            y_train_all_categorical,
            batch_size=batch_size,
            nb_epochs=train_epochs,
            verbose=False,
        )
        defense_fn = load_fn(defense_config)
        defense = defense_fn(classifier_for_defense, x_train_all,
                             y_train_all_categorical)
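        # detect_poison kwargs suggest an activation-clustering style defense:
        # cluster into 2 groups after reducing activations to 43 PCA dimensions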
        _, is_clean = defense.detect_poison(nb_clusters=2,
                                            nb_dims=43,
                                            reduce="PCA")
        is_clean = np.array(is_clean)
        logger.info(f"Total clean data points: {np.sum(is_clean)}")

        logger.info("Filtering out detected poisoned samples")
        indices_to_keep = is_clean == 1
        x_train_filter = x_train_all[indices_to_keep]
        y_train_filter = y_train_all_categorical[indices_to_keep]
        if len(x_train_filter):
            logger.info(
                f"Fitting model of {model_config['module']}.{model_config['name']}..."
            )
            classifier.fit(
                x_train_filter,
                y_train_filter,
                batch_size=batch_size,
                nb_epochs=train_epochs,
                verbose=False,
            )
        else:
            logger.warning(
                "All data points filtered by defense. Skipping training")

        logger.info(f"Validating on clean test data")
        test_data = load_dataset(
            config["dataset"],
            epochs=1,
            split_type="test",
            preprocessing_fn=preprocessing_fn,
        )
        validation_metric = metrics.MetricList("categorical_accuracy")
        for x, y in tqdm(test_data, desc="Testing"):
            y_pred = classifier.predict(x)
            validation_metric.append(y, y_pred)
        logger.info(
            f"Unpoisoned validation accuracy: {validation_metric.mean():.2%}")
        results = {"validation_accuracy": validation_metric.mean()}

        if poison_dataset_flag:
            logger.info(f"Testing on poisoned test data")
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split_type="test",
                preprocessing_fn=preprocessing_fn,
            )
            test_metric = metrics.MetricList("categorical_accuracy")
            targeted_test_metric = metrics.MetricList("categorical_accuracy")
            for x_test, y_test in tqdm(test_data, desc="Testing"):
                x_test, _ = poison_batch(x_test, y_test, src_class, tgt_class,
                                         len(y_test), attack)
                y_pred = classifier.predict(x_test)
                test_metric.append(y_test, y_pred)

                y_pred_targeted = y_pred[y_test == src_class]
                if not len(y_pred_targeted):
                    continue
                targeted_test_metric.append([tgt_class] * len(y_pred_targeted),
                                            y_pred_targeted)
            results["test_accuracy"] = test_metric.mean()
            results[
                "targeted_misclassification_accuracy"] = targeted_test_metric.mean(
                )
            logger.info(f"Test accuracy: {test_metric.mean():.2%}")
            logger.info(
                f"Test targeted misclassification accuracy: {targeted_test_metric.mean():.2%}"
            )
        return results
Example #14
    def _evaluate(self, config: dict, num_eval_batches: Optional[int],
                  skip_benign: Optional[bool]) -> dict:
        """
        Evaluate the config and return a results dict
        """

        model_config = config["model"]
        classifier, preprocessing_fn = load_model(model_config)

        defense_config = config.get("defense") or {}
        defense_type = defense_config.get("type")

        if defense_type in ["Preprocessor", "Postprocessor"]:
            logger.info(
                f"Applying internal {defense_type} defense to classifier")
            classifier = load_defense_internal(config["defense"], classifier)

        if model_config["fit"]:
            classifier.set_learning_phase(True)
            logger.info(
                f"Fitting model {model_config['module']}.{model_config['name']}..."
            )
            fit_kwargs = model_config["fit_kwargs"]

            logger.info(
                f"Loading train dataset {config['dataset']['name']}...")
            train_data = load_dataset(
                config["dataset"],
                epochs=fit_kwargs["nb_epochs"],
                split_type="train",
                preprocessing_fn=preprocessing_fn,
                shuffle_files=True,
            )
            if defense_type == "Trainer":
                logger.info(f"Training with {defense_type} defense...")
                defense = load_defense_wrapper(config["defense"], classifier)
                defense.fit_generator(train_data, **fit_kwargs)
            else:
                logger.info("Fitting classifier on clean train dataset...")
                classifier.fit_generator(train_data, **fit_kwargs)

        if defense_type == "Transform":
            # NOTE: Transform currently not supported
            logger.info(
                f"Transforming classifier with {defense_type} defense...")
            defense = load_defense_wrapper(config["defense"], classifier)
            classifier = defense()

        classifier.set_learning_phase(False)

        metrics_logger = metrics.MetricsLogger.from_config(
            config["metric"], skip_benign=skip_benign)
        if skip_benign:
            logger.info("Skipping benign classification...")
        else:
            # Evaluate the ART classifier on benign test examples
            logger.info(f"Loading test dataset {config['dataset']['name']}...")
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split_type="test",
                preprocessing_fn=preprocessing_fn,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )

            logger.info("Running inference on benign examples...")
            for x, y in tqdm(test_data, desc="Benign"):
                with metrics.resource_context(
                        name="Inference",
                        profiler=config["metric"].get("profiler_type"),
                        computational_resource_dict=metrics_logger.
                        computational_resource_dict,
                ):
                    y_pred = classifier.predict(x)
                metrics_logger.update_task(y, y_pred)
            metrics_logger.log_task()

        # Evaluate the ART classifier on adversarial test examples
        logger.info("Generating or loading / testing adversarial examples...")

        attack_config = config["attack"]
        attack_type = attack_config.get("type")
        targeted = bool(attack_config.get("kwargs", {}).get("targeted"))
        if targeted and attack_config.get("use_label"):
            raise ValueError("Targeted attacks cannot have 'use_label'")
        if attack_type == "preloaded":
            test_data = load_adversarial_dataset(
                attack_config,
                epochs=1,
                split_type="adversarial",
                preprocessing_fn=preprocessing_fn,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
        else:
            attack = load_attack(attack_config, classifier)
            if targeted != getattr(attack, "targeted", False):
                logger.warning(
                    f"targeted config {targeted} != attack field {getattr(attack, 'targeted', False)}"
                )
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split_type="test",
                preprocessing_fn=preprocessing_fn,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
            if targeted:
                label_targeter = load_label_targeter(
                    attack_config["targeted_labels"])
        for x, y in tqdm(test_data, desc="Attack"):
            with metrics.resource_context(
                    name="Attack",
                    profiler=config["metric"].get("profiler_type"),
                    computational_resource_dict=metrics_logger.
                    computational_resource_dict,
            ):
                if attack_type == "preloaded":
                    x, x_adv = x
                    if targeted:
                        y, y_target = y
                elif attack_config.get("use_label"):
                    x_adv = attack.generate(x=x, y=y)
                elif targeted:
                    y_target = label_targeter.generate(y)
                    x_adv = attack.generate(x=x, y=y_target)
                else:
                    x_adv = attack.generate(x=x)

            y_pred_adv = classifier.predict(x_adv)
            if targeted:
                metrics_logger.update_task(y_target,
                                           y_pred_adv,
                                           adversarial=True)
            else:
                metrics_logger.update_task(y, y_pred_adv, adversarial=True)
            metrics_logger.update_perturbation(x, x_adv)
        metrics_logger.log_task(adversarial=True, targeted=targeted)
        return metrics_logger.results()
Example #15
    def _evaluate(
        self,
        config: dict,
        num_eval_batches: Optional[int],
        skip_benign: Optional[bool],
        skip_attack: Optional[bool],
        skip_misclassified: Optional[bool],
    ) -> dict:
        """
        Evaluate the config and return a results dict
        """
        model_config = config["model"]
        estimator, _ = load_model(model_config)

        defense_config = config.get("defense") or {}
        defense_type = defense_config.get("type")

        if defense_type in ["Preprocessor", "Postprocessor"]:
            logger.info(
                f"Applying internal {defense_type} defense to estimator")
            estimator = load_defense_internal(config["defense"], estimator)

        if model_config["fit"]:
            try:
                logger.info(
                    f"Fitting model {model_config['module']}.{model_config['name']}..."
                )
                fit_kwargs = model_config["fit_kwargs"]

                logger.info(
                    f"Loading train dataset {config['dataset']['name']}...")
                train_data = load_dataset(
                    config["dataset"],
                    epochs=fit_kwargs["nb_epochs"],
                    split=config["dataset"].get("train_split", "train"),
                    shuffle_files=True,
                )
                if defense_type == "Trainer":
                    logger.info(f"Training with {defense_type} defense...")
                    defense = load_defense_wrapper(config["defense"],
                                                   estimator)
                    defense.fit_generator(train_data, **fit_kwargs)
                else:
                    logger.info("Fitting estimator on clean train dataset...")
                    estimator.fit_generator(train_data, **fit_kwargs)
            except NotImplementedError:
                raise NotImplementedError(
                    "Training has not yet been implemented for object detectors"
                )

        if defense_type == "Transform":
            # NOTE: Transform currently not supported
            logger.info(
                f"Transforming estimator with {defense_type} defense...")
            defense = load_defense_wrapper(config["defense"], estimator)
            estimator = defense()

        attack_config = config["attack"]
        attack_type = attack_config.get("type")

        targeted = bool(attack_config.get("kwargs", {}).get("targeted"))
        metrics_logger = metrics.MetricsLogger.from_config(
            config["metric"],
            skip_benign=skip_benign,
            skip_attack=skip_attack,
            targeted=targeted,
        )

        eval_split = config["dataset"].get("eval_split", "test")
        if skip_benign:
            logger.info("Skipping benign classification...")
        else:
            # Evaluate the ART estimator on benign test examples
            logger.info(f"Loading test dataset {config['dataset']['name']}...")
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split=eval_split,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )

            logger.info("Running inference on benign examples...")
            for x, y in tqdm(test_data, desc="Benign"):
                # Ensure that input sample isn't overwritten by estimator
                x.flags.writeable = False
                with metrics.resource_context(
                        name="Inference",
                        profiler=config["metric"].get("profiler_type"),
                        computational_resource_dict=metrics_logger.
                        computational_resource_dict,
                ):
                    y_pred = estimator.predict(x)
                metrics_logger.update_task(y, y_pred)
            metrics_logger.log_task()

        if skip_attack:
            logger.info("Skipping attack generation...")
            return metrics_logger.results()

        # Evaluate the ART estimator on adversarial test examples
        logger.info("Generating or loading / testing adversarial examples...")

        if skip_misclassified:
            acc_task_idx = [i.name for i in metrics_logger.tasks
                            ].index("categorical_accuracy")
            benign_acc = metrics_logger.tasks[acc_task_idx].values()
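            # benign_acc holds per-batch benign accuracy; batches the model already
            # misclassifies are not re-attacked (x_adv = x in the loop below)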

        if targeted and attack_config.get("use_label"):
            raise ValueError("Targeted attacks cannot have 'use_label'")
        if attack_type == "preloaded":
            preloaded_split = attack_config.get("kwargs", {}).get(
                "split", "adversarial")
            test_data = load_adversarial_dataset(
                attack_config,
                epochs=1,
                split=preloaded_split,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
        else:
            attack = load_attack(attack_config, estimator)
            if targeted != getattr(attack, "targeted", False):
                logger.warning(
                    f"targeted config {targeted} != attack field {getattr(attack, 'targeted', False)}"
                )
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split=eval_split,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
            if targeted:
                label_targeter = load_label_targeter(
                    attack_config["targeted_labels"])

        export_samples = config["scenario"].get("export_samples")
        if export_samples is not None and export_samples > 0:
            sample_exporter = SampleExporter(self.scenario_output_dir,
                                             test_data.context, export_samples)
        else:
            sample_exporter = None

        for batch_idx, (x, y) in enumerate(tqdm(test_data, desc="Attack")):
            with metrics.resource_context(
                    name="Attack",
                    profiler=config["metric"].get("profiler_type"),
                    computational_resource_dict=metrics_logger.
                    computational_resource_dict,
            ):
                if attack_type == "preloaded":
                    if len(x) == 2:
                        x, x_adv = x
                    else:
                        x_adv = x
                    if targeted:
                        y, y_target = y
                else:
                    generate_kwargs = deepcopy(
                        attack_config.get("generate_kwargs", {}))
                    # Temporary workaround for ART code requirement of ndarray mask
                    if "mask" in generate_kwargs:
                        generate_kwargs["mask"] = np.array(
                            generate_kwargs["mask"])
                    if attack_config.get("use_label"):
                        generate_kwargs["y"] = y
                    elif targeted:
                        y_target = label_targeter.generate(y)
                        generate_kwargs["y"] = y_target

                    if skip_misclassified and benign_acc[batch_idx] == 0:
                        x_adv = x
                    else:
                        x_adv = attack.generate(x=x, **generate_kwargs)

            # Ensure that input sample isn't overwritten by estimator
            x_adv.flags.writeable = False
            y_pred_adv = estimator.predict(x_adv)
            metrics_logger.update_task(y, y_pred_adv, adversarial=True)
            if targeted:
                metrics_logger.update_task(y_target,
                                           y_pred_adv,
                                           adversarial=True,
                                           targeted=True)
            metrics_logger.update_perturbation(x, x_adv)
            if sample_exporter is not None:
                sample_exporter.export(x, x_adv, y, y_pred_adv)
        metrics_logger.log_task(adversarial=True)
        if targeted:
            metrics_logger.log_task(adversarial=True, targeted=True)
        return metrics_logger.results()
Example #16
    def _evaluate(self, config: dict, num_eval_batches: Optional[int],
                  skip_benign: Optional[bool]) -> dict:
        """
        Evaluate the config and return a results dict
        """
        model_config = config["model"]
        estimator, fit_preprocessing_fn = load_model(model_config)

        defense_config = config.get("defense") or {}
        defense_type = defense_config.get("type")

        if defense_type in ["Preprocessor", "Postprocessor"]:
            logger.info(
                f"Applying internal {defense_type} defense to estimator")
            estimator = load_defense_internal(config["defense"], estimator)

        if model_config["fit"]:
            try:
                estimator.set_learning_phase(True)
            except NotImplementedError:
                logger.exception(
                    "set_learning_phase error; training may not work.")

            logger.info(
                f"Fitting model {model_config['module']}.{model_config['name']}..."
            )
            fit_kwargs = model_config["fit_kwargs"]

            logger.info(
                f"Loading train dataset {config['dataset']['name']}...")
            batch_size = config["dataset"].pop("batch_size")
            config["dataset"]["batch_size"] = fit_kwargs.get(
                "fit_batch_size", batch_size)
            train_data = load_dataset(
                config["dataset"],
                epochs=fit_kwargs["nb_epochs"],
                split_type=config["dataset"].get("train_split",
                                                 "train_clean100"),
                preprocessing_fn=fit_preprocessing_fn,
                shuffle_files=True,
            )
            config["dataset"]["batch_size"] = batch_size
            if defense_type == "Trainer":
                logger.info(f"Training with {defense_type} defense...")
                defense = load_defense_wrapper(config["defense"], estimator)
                defense.fit_generator(train_data, **fit_kwargs)
            else:
                logger.info("Fitting estimator on clean train dataset...")
                estimator.fit_generator(train_data, **fit_kwargs)

        if defense_type == "Transform":
            # NOTE: Transform currently not supported
            logger.info(
                f"Transforming estimator with {defense_type} defense...")
            defense = load_defense_wrapper(config["defense"], estimator)
            estimator = defense()

        try:
            estimator.set_learning_phase(False)
        except NotImplementedError:
            logger.warning(
                "Unable to set estimator's learning phase. As of ART 1.4.1, "
                "this is not yet supported for speech recognition models.")

        metrics_logger = metrics.MetricsLogger.from_config(
            config["metric"], skip_benign=skip_benign)
        if config["dataset"]["batch_size"] != 1:
            logger.warning("Evaluation batch_size != 1 may not be supported.")

        predict_kwargs = config["model"].get("predict_kwargs", {})
        eval_split = config["dataset"].get("eval_split", "test_clean")
        if skip_benign:
            logger.info("Skipping benign classification...")
        else:
            # Evaluate the ART estimator on benign test examples
            logger.info(f"Loading test dataset {config['dataset']['name']}...")
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split_type=eval_split,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
            logger.info("Running inference on benign examples...")
            for x, y in tqdm(test_data, desc="Benign"):
                # Ensure that input sample isn't overwritten by estimator
                x.flags.writeable = False
                with metrics.resource_context(
                        name="Inference",
                        profiler=config["metric"].get("profiler_type"),
                        computational_resource_dict=metrics_logger.computational_resource_dict,
                ):
                    y_pred = estimator.predict(x, **predict_kwargs)
                metrics_logger.update_task(y, y_pred)
            metrics_logger.log_task()

        # Imperceptible attack still WIP
        if (config.get("adhoc") or {}).get("skip_adversarial"):
            logger.info("Skipping adversarial classification...")
            return metrics_logger.results()

        # Evaluate the ART estimator on adversarial test examples
        logger.info("Generating or loading / testing adversarial examples...")

        attack_config = config["attack"]
        attack_type = attack_config.get("type")

        targeted = bool(attack_config.get("targeted"))
        if attack_type == "preloaded":
            test_data = load_adversarial_dataset(
                attack_config,
                epochs=1,
                split_type="adversarial",
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
        else:
            attack = load_attack(attack_config, estimator)
            if targeted != attack.targeted:
                logger.warning(
                    f"targeted config {targeted} != attack field {attack.targeted}"
                )
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split_type=eval_split,
                num_batches=num_eval_batches,
                shuffle_files=False,
            )
            if targeted:
                label_targeter = load_label_targeter(
                    attack_config["targeted_labels"])
        for x, y in tqdm(test_data, desc="Attack"):
            with metrics.resource_context(
                    name="Attack",
                    profiler=config["metric"].get("profiler_type"),
                    computational_resource_dict=metrics_logger.computational_resource_dict,
            ):
                if attack_type == "preloaded":
                    x, x_adv = x
                    if targeted:
                        y, y_target = y
                elif attack_config.get("use_label"):
                    x_adv = attack.generate(x=x, y=y)
                elif targeted:
                    y_target = label_targeter.generate(y)
                    x_adv = attack.generate(x=x, y=y_target)
                else:
                    x_adv = attack.generate(x=x)

            # Ensure that input sample isn't overwritten by estimator
            x_adv.flags.writeable = False
            y_pred_adv = estimator.predict(x_adv, **predict_kwargs)
            metrics_logger.update_task(y, y_pred_adv, adversarial=True)
            metrics_logger.update_perturbation(x, x_adv)
        metrics_logger.log_task(adversarial=True, targeted=targeted)
        return metrics_logger.results()
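The x.flags.writeable = False guard used before each predict call above makes NumPy raise if the estimator tries to modify its input in place. A minimal sketch of the effect, using a made-up input array that is not tied to any armory dataset:

import numpy as np

x = np.zeros((1, 16000), dtype=np.float32)  # e.g. one second of 16 kHz audio
x.flags.writeable = False  # same guard as in the scenario above

try:
    x[0, 0] = 1.0  # any in-place modification would now fail loudly
except ValueError as err:
    print(f"blocked in-place write: {err}")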
Example #17
    def _evaluate(self, config: dict) -> dict:
        """
        Evaluate the config and return a results dict
        """

        model_config = config["model"]
        classifier, preprocessing_fn = load_model(model_config)

        defense_config = config.get("defense") or {}
        defense_type = defense_config.get("type")

        if defense_type in ["Preprocessor", "Postprocessor"]:
            logger.info(f"Applying internal {defense_type} defense to classifier")
            classifier = load_defense_internal(config["defense"], classifier)

        if model_config["fit"]:
            classifier.set_learning_phase(True)
            logger.info(
                f"Fitting model {model_config['module']}.{model_config['name']}..."
            )
            train_epochs = config["model"]["fit_kwargs"]["nb_epochs"]
            batch_size = config["dataset"]["batch_size"]

            logger.info(f"Loading train dataset {config['dataset']['name']}...")
            train_data = load_dataset(
                config["dataset"],
                epochs=train_epochs,
                split_type="train",
                preprocessing_fn=preprocessing_fn,
            )

            if defense_type == "Trainer":
                logger.info(f"Training with {defense_type} defense...")
                defense = load_defense_wrapper(config["defense"], classifier)
            else:
                logger.info(f"Fitting classifier on clean train dataset...")

            for epoch in range(train_epochs):
                classifier.set_learning_phase(True)

                for _ in tqdm(
                    range(train_data.batches_per_epoch),
                    desc=f"Epoch: {epoch}/{train_epochs}",
                ):
                    x, y = train_data.get_batch()
                    # x_trains consists of one or more videos, each represented as an
                    # ndarray of shape (n_stacks, 3, 16, 112, 112).
                    # To train, randomly sample a batch of stacks
                    x = np.stack([x_i[np.random.randint(x_i.shape[0])] for x_i in x])
                    if defense_type == "Trainer":
                        defense.fit(x, y, batch_size=batch_size, nb_epochs=1)
                    else:
                        classifier.fit(x, y, batch_size=batch_size, nb_epochs=1)

        if defense_type == "Transform":
            # NOTE: Transform currently not supported
            logger.info(f"Transforming classifier with {defense_type} defense...")
            defense = load_defense_wrapper(config["defense"], classifier)
            classifier = defense()

        classifier.set_learning_phase(False)

        # Evaluate the ART classifier on benign test examples
        logger.info(f"Loading test dataset {config['dataset']['name']}...")
        test_data_generator = load_dataset(
            config["dataset"],
            epochs=1,
            split_type="test",
            preprocessing_fn=preprocessing_fn,
        )

        logger.info("Running inference on benign examples...")
        metrics_logger = metrics.MetricsLogger.from_config(config["metric"])

        for x_batch, y_batch in tqdm(test_data_generator, desc="Benign"):
            for x, y in zip(x_batch, y_batch):
                # combine predictions across all stacks
                y_pred = np.mean(classifier.predict(x), axis=0)
                metrics_logger.update_task(y, y_pred)
        metrics_logger.log_task()

        # Evaluate the ART classifier on adversarial test examples
        logger.info("Generating / testing adversarial examples...")

        attack = load_attack(config["attack"], classifier)
        test_data_generator = load_dataset(
            config["dataset"],
            epochs=1,
            split_type="test",
            preprocessing_fn=preprocessing_fn,
        )
        for x_batch, y_batch in tqdm(test_data_generator, desc="Attack"):
            for x, y in zip(x_batch, y_batch):
                # each x is of shape (n_stack, 3, 16, 112, 112)
                #    n_stack varies
                attack.set_params(batch_size=x.shape[0])
                x_adv = attack.generate(x=x)
                # combine predictions across all stacks
                y_pred = np.mean(classifier.predict(x_adv), axis=0)
                metrics_logger.update_task(y, y_pred, adversarial=True)
                metrics_logger.update_perturbation([x], [x_adv])
        metrics_logger.log_task(adversarial=True)
        return metrics_logger.results()
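A standalone sketch of the per-video stack handling in the example above, with synthetic arrays in the (n_stacks, 3, 16, 112, 112) layout noted in the comments; the class count of 101 is an assumption (UCF101-style), not something the example specifies:

import numpy as np

rng = np.random.default_rng(0)

# Three videos with varying numbers of stacks, as in the training loop above
videos = [rng.standard_normal((n, 3, 16, 112, 112)).astype(np.float32)
          for n in (2, 5, 3)]

# Training: sample one random stack per video to get a fixed-shape batch
batch = np.stack([v[rng.integers(v.shape[0])] for v in videos])
print(batch.shape)  # (3, 3, 16, 112, 112)

# Inference: predict on every stack of one video and average the outputs,
# mirroring np.mean(classifier.predict(...), axis=0); logits here are random
logits_per_stack = rng.standard_normal((videos[1].shape[0], 101))  # (5, 101)
y_pred = logits_per_stack.mean(axis=0)                             # (101,)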
Example #18
    def _evaluate(self, config: dict) -> dict:
        """
        Evaluate the config and return a results dict
        """

        model_config = config["model"]
        classifier, preprocessing_fn = load_model(model_config)
        if isinstance(preprocessing_fn, tuple):
            fit_preprocessing_fn, predict_preprocessing_fn = preprocessing_fn
        else:
            fit_preprocessing_fn = predict_preprocessing_fn = preprocessing_fn

        defense_config = config.get("defense") or {}
        defense_type = defense_config.get("type")

        if defense_type in ["Preprocessor", "Postprocessor"]:
            logger.info(
                f"Applying internal {defense_type} defense to classifier")
            classifier = load_defense_internal(config["defense"], classifier)

        if model_config["fit"]:
            classifier.set_learning_phase(True)
            logger.info(
                f"Fitting model {model_config['module']}.{model_config['name']}..."
            )
            fit_kwargs = model_config["fit_kwargs"]

            logger.info(
                f"Loading train dataset {config['dataset']['name']}...")
            batch_size = config["dataset"].pop("batch_size")
            config["dataset"]["batch_size"] = config.get("adhoc", {}).get(
                "fit_batch_size", batch_size)
            train_data = load_dataset(
                config["dataset"],
                epochs=fit_kwargs["nb_epochs"],
                split_type="train",
                preprocessing_fn=fit_preprocessing_fn,
            )
            config["dataset"]["batch_size"] = batch_size
            if defense_type == "Trainer":
                logger.info(f"Training with {defense_type} defense...")
                defense = load_defense_wrapper(config["defense"], classifier)
                defense.fit_generator(train_data, **fit_kwargs)
            else:
                logger.info("Fitting classifier on clean train dataset...")
                classifier.fit_generator(train_data, **fit_kwargs)

        if defense_type == "Transform":
            # NOTE: Transform currently not supported
            logger.info(
                f"Transforming classifier with {defense_type} defense...")
            defense = load_defense_wrapper(config["defense"], classifier)
            classifier = defense()

        # HACK: optional model checkpointing, currently commented out
        # (SAIL-JATI) ----------------------------------
        # ts = time.time()
        # st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d-%H-%M-%S')
        # model_save_dir_ = os.path.join("/nas/home/ajati/work/codes/SAIL_ALR_models/", st + "/")
        # os.system("mkdir -p " + model_save_dir_)
        # torch.save(classifier._model._model.state_dict(), model_save_dir_ + "/sail_alr_model_state_dict.pt")
        # torch.save(classifier._model._model, model_save_dir_ + "/sail_alr_model.pt")
        # torch.save(classifier._optimizer.state_dict(), model_save_dir_ + "/sail_alr_optim_state_dict.pt")
        # torch.save(classifier._optimizer, model_save_dir_ + "/sail_alr_optim.pt")
        # ---------------------------------------------

        classifier.set_learning_phase(False)

        # Evaluate the ART classifier on benign test examples
        logger.info(f"Loading test dataset {config['dataset']['name']}...")
        test_data = load_dataset(
            config["dataset"],
            epochs=1,
            split_type="test",
            preprocessing_fn=predict_preprocessing_fn,
        )
        logger.info("Running inference on benign examples...")
        metrics_logger = metrics.MetricsLogger.from_config(config["metric"])

        for x, y in tqdm(test_data, desc="Benign"):
            y_pred = classifier.predict(x)
            metrics_logger.update_task(y, y_pred)
        metrics_logger.log_task()

        # Evaluate the ART classifier on adversarial test examples
        logger.info("Generating or loading / testing adversarial examples...")
        attack_config = config["attack"]
        attack_type = attack_config.get("type")
        targeted = bool(attack_config.get("kwargs", {}).get("targeted"))
        if targeted and attack_config.get("use_label"):
            raise ValueError("Targeted attacks cannot have 'use_label'")
        if attack_type == "preloaded":
            test_data = load_adversarial_dataset(
                attack_config,
                epochs=1,
                split_type="adversarial",
                preprocessing_fn=predict_preprocessing_fn,
            )
        else:
            attack = load_attack(attack_config, classifier)
            test_data = load_dataset(
                config["dataset"],
                epochs=1,
                split_type="test",
                preprocessing_fn=predict_preprocessing_fn,
            )

        # Track the signal-to-noise ratio (SNR) of each adversarial perturbation
        snrs = []
        for x, y in tqdm(test_data, desc="Attack"):
            if attack_type == "preloaded":
                x, x_adv = x
                if targeted:
                    y, y_target = y
            elif attack_config.get("use_label"):
                x_adv = attack.generate(x=x, y=y)
            elif targeted:
                raise NotImplementedError(
                    "Requires generation of target labels")
                # x_adv = attack.generate(x=x, y=y_target)
            else:
                x_adv = attack.generate(x=x)

            # SNR in dB: 10 * log10(signal power / noise power)
            noise = x_adv - x
            snr = 10 * np.log10(np.mean(x**2) / np.mean(noise**2))
            snrs.append(snr)

            y_pred_adv = classifier.predict(x_adv)
            if targeted:
                # NOTE: does not remove data points where y == y_target
                metrics_logger.update_task(y_target,
                                           y_pred_adv,
                                           adversarial=True)
            else:
                metrics_logger.update_task(y, y_pred_adv, adversarial=True)
            metrics_logger.update_perturbation(x, x_adv)
        metrics_logger.log_task(adversarial=True, targeted=targeted)

        mean_snr = np.mean(snrs)
        logging.info(f"MEAN SNR of adversarial samples = {mean_snr}")

        return metrics_logger.results()
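The SNR bookkeeping in this last example is the standard decibel power ratio, 10 * log10(P_signal / P_noise). A self-contained sketch with a synthetic tone and perturbation, values chosen only for illustration:

import numpy as np

rng = np.random.default_rng(0)
x = np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)  # clean 440 Hz tone, 1 s at 16 kHz
x_adv = x + 0.01 * rng.standard_normal(x.shape)          # perturbed signal

noise = x_adv - x
snr_db = 10 * np.log10(np.mean(x ** 2) / np.mean(noise ** 2))
print(f"SNR = {snr_db:.1f} dB")  # roughly 37 dB for this noise level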