Esempio n. 1
0
    def forward_one_image(class_id: int, image_id: int) -> Dict[str, Any]:
        """Run the LeNet checkpoint on one MNIST test image under each transform.

        Selects the ``image_id``-th test image whose label equals ``class_id``.
        For every ``(transform, name)`` pair in ``transforms``, calls ``predict``
        on that single image (batch of one) with the transform applied to the
        test set and ``tag=name``.

        Args:
            class_id: Label used to filter the MNIST test set.
            image_id: Index of the wanted image among test images with that label.

        NOTE(review): as visible here the function never returns a value, so it
        returns None despite the ``Dict[str, Any]`` annotation. Both ``graph``
        and ``predicted_label`` are also unused. This snippet looks truncated;
        confirm against the full file.
        """
        nonlocal model_dir
        mode.check(False)
        data_dir = abspath(MNIST_PATH)
        # Rebinds the enclosing function's model_dir (hence `nonlocal` above)
        # to its absolute path; checkpoints are read from its "ckpts" subdir.
        model_dir = abspath(model_dir)
        ckpt_dir = f"{model_dir}/ckpts"
        create_model = lambda: LeNet(data_format="channels_first")
        graph = LeNet.graph().load()
        # model_fn = partial(
        #     model_fn_with_fetch_hook,
        #     create_model=create_model, graph=graph
        # )

        for transform, name in transforms:
            # NOTE(review): the input_fn lambda reads `transform` when it is
            # called, not when it is created. This is only correct if predict()
            # consumes input_fn before the next loop iteration. Presumably it
            # does; verify.
            predicted_label = predict(
                create_model=create_model,
                input_fn=lambda: mnist.test(
                                            data_dir,
                                            transforms = transform
                                            )
                .filter(
                    lambda image, label: tf.equal(
                        tf.convert_to_tensor(class_id, dtype=tf.int32), label
                    )
                )
                .skip(image_id)
                .take(1)
                .batch(1),
                model_dir=ckpt_dir,
                tag = name,
            )
Esempio n. 2
0
        def get_row(class_id: int, image_id: int) -> Dict[str, Any]:
            nonlocal model_dir
            mode.check(False)
            data_dir = abspath(MNIST_PATH)
            model_dir = abspath(model_dir)
            ckpt_dir = f"{model_dir}/ckpts"
            create_model = lambda: LeNet(data_format="channels_first")
            graph = LeNet.graph().load()
            model_fn = partial(model_fn_with_fetch_hook,
                               create_model=create_model,
                               graph=graph)

            predicted_label = predict(
                create_model=create_model,
                input_fn=lambda: mnist.test(data_dir).filter(
                    lambda image, label: tf.equal(
                        tf.convert_to_tensor(class_id, dtype=tf.int32), label))
                .skip(image_id).take(1).batch(1),
                model_dir=ckpt_dir,
            )
            if predicted_label != class_id:
                return [{}] if per_node else {}

            adversarial_example = lenet_mnist_example(
                attack_name=attack_name,
                attack_fn=attack_fn,
                generate_adversarial_fn=generate_adversarial_fn,
                class_id=class_id,
                image_id=image_id,
                # model_dir not ckpt_dir
                model_dir=model_dir,
                transforms=transforms,
                transform_name=transform_name,
            ).load()

            if adversarial_example is None:
                return [{}] if per_node else {}

            adversarial_predicted_label = predict(
                create_model=create_model,
                input_fn=lambda: tf.data.Dataset.from_tensors(
                    mnist.normalize(adversarial_example)),
                model_dir=ckpt_dir,
            )

            if predicted_label == adversarial_predicted_label:
                return [{}] if per_node else {}

            trace = reconstruct_trace_from_tf(
                class_id=class_id,
                model_fn=model_fn,
                input_fn=lambda: mnist.test(data_dir, transforms=transforms).
                filter(lambda image, label: tf.equal(
                    tf.convert_to_tensor(class_id, dtype=tf.int32), label)
                       ).skip(image_id).take(1).batch(1),
                select_fn=select_fn,
                model_dir=ckpt_dir,
                per_channel=per_channel,
            )[0]

            if trace is None:
                return [{}] if per_node else {}

            adversarial_trace = reconstruct_trace_from_tf_brute_force(
                model_fn=model_fn,
                input_fn=lambda: tf.data.Dataset.from_tensors(
                    mnist.normalize(adversarial_example)),
                select_fn=select_fn,
                model_dir=ckpt_dir,
                per_channel=per_channel,
            )[0]

            adversarial_label = adversarial_trace.attrs[GraphAttrKey.PREDICT]

            # adversarial_pred, violation = \
            #     detect_by_reduced_edge_count_violation(
            #         class_trace_fn(adversarial_label).load(),
            #                         adversarial_trace,
            #                         reduce_mode,
            #     )
            # row = {
            #     "image_id": image_id,
            #     "class_id": class_id,
            #     "original.prediction":
            #         detect_by_reduced_edge(class_trace_fn(class_id).load(),
            #                                 trace,
            #                                 reduce_mode,
            #                                 ),
            #     "adversarial.prediction":
            #         adversarial_pred,
            #     "violation": violation,
            # }

            original_pred, violation = \
                detect_by_reduced_edge_count_violation(
                    class_trace_fn(class_id).load(),
                                    trace,
                                    reduce_mode,
                )
            row = {
                "image_id":
                image_id,
                "class_id":
                class_id,
                "original.prediction":
                original_pred,
                "adversarial.prediction":
                detect_by_reduced_edge(
                    class_trace_fn(adversarial_label).load(),
                    adversarial_trace,
                    reduce_mode,
                ),
                "violation":
                violation,
            }
            return row
Esempio n. 3
0
 def eval_input_fn():
     """Return the next batch of the MNIST test set from a one-shot iterator.

     Uses the free names ``data_dir`` and ``batch_size``, which are defined
     outside this function.
     """
     test_dataset = mnist.test(data_dir)
     batched_dataset = test_dataset.batch(batch_size)
     iterator = batched_dataset.make_one_shot_iterator()
     return iterator.get_next()