コード例 #1
0
def save_graph(
    attack_name: str,
    attack_fn,
    generate_adversarial_fn,
    select_fn: Callable[[np.ndarray], np.ndarray],
    image_id_index,
    batch_size,
    class_id,
    model_dir=model_dir,
    transforms=None,
    transform_name="noop",
    graph_dir="result/test",
    dataset_mode=dataset_mode,
    images_per_class=1,
    compute_adversarial=True,
    **kwargs,
):
    """Run trace reconstruction over one batch of MNIST images of a class.

    The batch covers images ``image_id_index`` up to
    ``image_id_index + batch_size`` of class ``class_id``, with
    ``batch_size`` capped at ``images_per_class - image_id_index``.
    ``reconstruct_trace_from_tf`` is run on the original images and, when
    ``compute_adversarial`` is true, on their adversarial examples.  Each
    run uses a ``model_fn_with_fetch_hook`` partial bound to an
    ``IOBatchAction`` rooted under ``graph_dir``.  The function returns
    ``None``; the reconstruction results are not returned.

    Args:
        attack_name: Attack identifier; forwarded to ``lenet_mnist_example``
            and used in the adversarial output directory name.
        attack_fn: Forwarded to ``lenet_mnist_example``.
        generate_adversarial_fn: Forwarded to ``lenet_mnist_example``.
        select_fn: Forwarded to ``reconstruct_trace_from_tf``.
        image_id_index: Index (within the class) of the first image.
        batch_size: Requested number of images.
        class_id: Label used to filter the dataset.
        model_dir: Model directory; checkpoints are read from
            ``<model_dir>/ckpts``.
        transforms: Passed to the dataset for the trace run and to
            ``lenet_mnist_example``.
        transform_name: Used in output directory names and forwarded to
            ``lenet_mnist_example``.
        graph_dir: Root directory handed to the ``IOBatchAction`` savers.
        dataset_mode: ``"test"`` or ``"train"``.
        images_per_class: Upper bound used to cap ``batch_size``.
        compute_adversarial: Whether to also run the adversarial pass.
        **kwargs: Accepted but never read by this function.

    Raises:
        RuntimeError: If ``dataset_mode`` is neither ``"test"`` nor
            ``"train"``.
    """
    data_dir = abspath(MNIST_PATH)
    model_dir = abspath(model_dir)
    ckpt_dir = f"{model_dir}/ckpts"

    def create_model():
        return LeNet(data_format="channels_first")

    graph = LeNet.graph().load()

    # Do not request more images than remain for this class.
    batch_size = min(batch_size, images_per_class - image_id_index)

    if dataset_mode not in ("test", "train"):
        raise RuntimeError("Dataset invalid")
    dataset = mnist.test if dataset_mode == "test" else mnist.train

    def class_images(**dataset_kwargs):
        """Return the image tensor for this batch of ``class_id`` images."""
        of_class = dataset(data_dir, **dataset_kwargs).filter(
            lambda image, label: tf.equal(
                tf.convert_to_tensor(class_id, dtype=tf.int32), label))
        window = of_class.skip(image_id_index).take(batch_size)
        iterator = window.batch(batch_size).make_one_shot_iterator()
        # get_next() yields an (image, label) pair; only images are needed.
        return iterator.get_next()[0]

    # Prediction on the untransformed images.
    predicted_label = predict_batch(
        create_model=create_model,
        input_fn=lambda: class_images(),
        model_dir=ckpt_dir,
    )
    prediction_valid = predicted_label == class_id

    # Load the stored adversarial example of every image in the batch;
    # load() yields None when no example is available.
    loaded_examples = []
    for image_id in range(image_id_index, image_id_index + batch_size):
        example_action = lenet_mnist_example(
            attack_name=attack_name,
            attack_fn=attack_fn,
            generate_adversarial_fn=generate_adversarial_fn,
            class_id=class_id,
            image_id=image_id,
            # model_dir, not ckpt_dir
            model_dir=model_dir,
            transforms=transforms,
            transform_name=transform_name,
            mode=dataset_mode,
        )
        loaded_examples.append(example_action.load())

    adversarial_valid = np.array(
        [loaded is not None for loaded in loaded_examples])

    # Missing examples become all-zero (1, 1, 28, 28) placeholders so the
    # batch can still be stacked; adversarial_valid masks them out later.
    filled_examples = [
        np.zeros((1, 1, 28, 28)) if loaded is None else loaded
        for loaded in loaded_examples
    ]
    stacked_examples = np.array(filled_examples).astype(np.float32)
    # Drops the per-example leading axis of size 1
    # (presumably leaving (batch, 1, 28, 28) -- TODO confirm).
    adversarial_examples = np.squeeze(stacked_examples, axis=1)

    def adversarial_input_fn():
        return tf.data.Dataset.from_tensors(
            mnist.normalize(adversarial_examples))

    adversarial_predicted_label = predict_batch(
        create_model=create_model,
        input_fn=adversarial_input_fn,
        model_dir=ckpt_dir,
    )
    # The adversarial prediction must differ from the true class.
    adversarial_prediction_valid = adversarial_predicted_label != class_id

    # Element-wise product of the three boolean masks.
    batch_valid = (prediction_valid * adversarial_valid *
                   adversarial_prediction_valid)

    def hooked_model_fn(variant_dir):
        """Build a model_fn whose saver is rooted at ``variant_dir``."""
        saver = IOBatchAction(
            dir=os.path.join(graph_dir, variant_dir, f"{class_id}"),
            root_index=image_id_index,
        )
        return partial(
            model_fn_with_fetch_hook,
            create_model=create_model,
            graph=graph,
            graph_saver=saver,
            batch_valid=batch_valid,
        )

    # The return value is not used by this function.
    reconstruct_trace_from_tf(
        class_id=class_id,
        model_fn=hooked_model_fn(f"original_{transform_name}"),
        input_fn=lambda: class_images(transforms=transforms),
        select_fn=select_fn,
        model_dir=ckpt_dir,
    )

    if not compute_adversarial:
        return

    reconstruct_trace_from_tf(
        model_fn=hooked_model_fn(f"{attack_name}_{transform_name}"),
        input_fn=adversarial_input_fn,
        select_fn=select_fn,
        model_dir=ckpt_dir,
    )
コード例 #2
0
        def get_row(class_id: int, image_id: int) -> Dict[str, Any]:
            nonlocal model_dir
            mode.check(False)
            data_dir = abspath(MNIST_PATH)
            model_dir = abspath(model_dir)
            ckpt_dir = f"{model_dir}/ckpts"
            create_model = lambda: LeNet(data_format="channels_first")
            graph = LeNet.graph().load()
            model_fn = partial(model_fn_with_fetch_hook,
                               create_model=create_model,
                               graph=graph)

            predicted_label = predict(
                create_model=create_model,
                input_fn=lambda: mnist.test(data_dir).filter(
                    lambda image, label: tf.equal(
                        tf.convert_to_tensor(class_id, dtype=tf.int32), label))
                .skip(image_id).take(1).batch(1),
                model_dir=ckpt_dir,
            )
            if predicted_label != class_id:
                return [{}] if per_node else {}

            adversarial_example = lenet_mnist_example(
                attack_name=attack_name,
                attack_fn=attack_fn,
                generate_adversarial_fn=generate_adversarial_fn,
                class_id=class_id,
                image_id=image_id,
                # model_dir not ckpt_dir
                model_dir=model_dir,
                transforms=transforms,
                transform_name=transform_name,
            ).load()

            if adversarial_example is None:
                return [{}] if per_node else {}

            adversarial_predicted_label = predict(
                create_model=create_model,
                input_fn=lambda: tf.data.Dataset.from_tensors(
                    mnist.normalize(adversarial_example)),
                model_dir=ckpt_dir,
            )

            if predicted_label == adversarial_predicted_label:
                return [{}] if per_node else {}

            trace = reconstruct_trace_from_tf(
                class_id=class_id,
                model_fn=model_fn,
                input_fn=lambda: mnist.test(data_dir, transforms=transforms).
                filter(lambda image, label: tf.equal(
                    tf.convert_to_tensor(class_id, dtype=tf.int32), label)
                       ).skip(image_id).take(1).batch(1),
                select_fn=select_fn,
                model_dir=ckpt_dir,
                per_channel=per_channel,
            )[0]

            if trace is None:
                return [{}] if per_node else {}

            adversarial_trace = reconstruct_trace_from_tf_brute_force(
                model_fn=model_fn,
                input_fn=lambda: tf.data.Dataset.from_tensors(
                    mnist.normalize(adversarial_example)),
                select_fn=select_fn,
                model_dir=ckpt_dir,
                per_channel=per_channel,
            )[0]

            adversarial_label = adversarial_trace.attrs[GraphAttrKey.PREDICT]

            # adversarial_pred, violation = \
            #     detect_by_reduced_edge_count_violation(
            #         class_trace_fn(adversarial_label).load(),
            #                         adversarial_trace,
            #                         reduce_mode,
            #     )
            # row = {
            #     "image_id": image_id,
            #     "class_id": class_id,
            #     "original.prediction":
            #         detect_by_reduced_edge(class_trace_fn(class_id).load(),
            #                                 trace,
            #                                 reduce_mode,
            #                                 ),
            #     "adversarial.prediction":
            #         adversarial_pred,
            #     "violation": violation,
            # }

            original_pred, violation = \
                detect_by_reduced_edge_count_violation(
                    class_trace_fn(class_id).load(),
                                    trace,
                                    reduce_mode,
                )
            row = {
                "image_id":
                image_id,
                "class_id":
                class_id,
                "original.prediction":
                original_pred,
                "adversarial.prediction":
                detect_by_reduced_edge(
                    class_trace_fn(adversarial_label).load(),
                    adversarial_trace,
                    reduce_mode,
                ),
                "violation":
                violation,
            }
            return row
コード例 #3
0
        def get_row(class_id: int, image_id: int) -> Dict[str, Any]:
            nonlocal model_dir
            mode.check(False)
            data_dir = abspath(MNIST_PATH)
            model_dir = abspath(model_dir)
            ckpt_dir = f"{model_dir}/ckpts"
            create_model = lambda: LeNet(data_format="channels_first")
            graph = LeNet.graph().load()

            model_fn = partial(model_fn_with_fetch_hook,
                               create_model=create_model,
                               graph=graph)
            if mnist_dataset_mode == "test":
                dataset = mnist.test
            elif mnist_dataset_mode == "train":
                dataset = mnist.train
            else:
                raise RuntimeError("Dataset invalid")

            predicted_label = predict(
                create_model=create_model,
                # dataset may be train or test, should be consistent with lenet_mnist_example
                input_fn=lambda: dataset(data_dir, ).filter(
                    lambda image, label: tf.equal(
                        tf.convert_to_tensor(class_id, dtype=tf.int32), label))
                .skip(image_id).take(1).batch(1),
                model_dir=ckpt_dir,
            )
            if predicted_label != class_id:
                return [{}] if per_node else {}

            adversarial_example = lenet_mnist_example(
                attack_name=attack_name,
                attack_fn=attack_fn,
                generate_adversarial_fn=generate_adversarial_fn,
                class_id=class_id,
                image_id=image_id,
                # model_dir not ckpt_dir
                model_dir=model_dir,
                transforms=transforms,
                transform_name=transform_name,
                mode=mnist_dataset_mode,
            ).load()

            if adversarial_example is None:
                return [{}] if per_node else {}

            adversarial_predicted_label = predict(
                create_model=create_model,
                input_fn=lambda: tf.data.Dataset.from_tensors(
                    mnist.normalize(adversarial_example)),
                model_dir=ckpt_dir,
            )

            if predicted_label == adversarial_predicted_label:
                return [{}] if per_node else {}

            if use_class_trace:
                class_trace_avg = ClassTraceIOAction(predicted_label).load()
            else:
                class_trace_avg = None

            trace = reconstruct_trace_from_tf(
                class_id=class_id,
                model_fn=model_fn,
                input_fn=lambda: dataset(
                    data_dir,
                    transforms=transforms,
                ).filter(lambda image, label: tf.equal(
                    tf.convert_to_tensor(class_id, dtype=tf.int32), label)).
                skip(image_id).take(1).batch(1),
                select_fn=select_fn,
                model_dir=ckpt_dir,
                per_channel=per_channel,
                class_trace=class_trace_avg,
            )[0]

            if trace is None:
                return [{}] if per_node else {}
            path = os.path.join(save_dir, f"original_{transform_name}",
                                f"{class_id}", f"{image_id}.pkl")
            ensure_dir(path)
            with open(path, "wb") as f:
                pickle.dump(trace, f)

            adversarial_trace = reconstruct_trace_from_tf(
                model_fn=model_fn,
                input_fn=lambda: tf.data.Dataset.from_tensors(
                    mnist.normalize(adversarial_example)),
                select_fn=select_fn,
                model_dir=ckpt_dir,
                per_channel=per_channel,
                class_trace=class_trace_avg,
            )[0]

            path = os.path.join(save_dir, f"{attack_name}_{transform_name}",
                                f"{class_id}", f"{image_id}.pkl")
            ensure_dir(path)
            with open(path, "wb") as f:
                pickle.dump(adversarial_trace, f)

            row = {
                "class_id": class_id,
                "image_id": image_id,
                "trace": trace,
                "adversarial_trace": adversarial_trace,
            }
            # row = calc_all_overlap(
            #     class_trace_fn(class_id).load(), adversarial_trace, overlap_fn
            # )
            return row