def save_graph(
    attack_name: str,
    attack_fn,
    generate_adversarial_fn,
    select_fn: Callable[[np.ndarray], np.ndarray],
    image_id_index,
    batch_size,
    class_id,
    model_dir=model_dir,
    transforms=None,
    transform_name="noop",
    graph_dir="result/test",
    dataset_mode=dataset_mode,
    images_per_class=1,
    compute_adversarial=True,
    **kwargs,
):
    """Run trace reconstruction on a batch of MNIST images of one class.

    The batch covers images ``image_id_index .. image_id_index + batch_size``
    (after clamping ``batch_size`` to ``images_per_class - image_id_index``)
    among the images whose label equals ``class_id``. For every image the
    function records whether

    * the model predicts ``class_id`` for the untransformed image,
    * a stored adversarial example could be loaded, and
    * the model predicts something other than ``class_id`` for that example.

    The product of these three masks is handed to ``model_fn_with_fetch_hook``
    as ``batch_valid`` together with an ``IOBatchAction`` rooted at
    ``<graph_dir>/original_<transform_name>/<class_id>`` (and, when
    ``compute_adversarial`` is true, at
    ``<graph_dir>/<attack_name>_<transform_name>/<class_id>``).

    Args:
        attack_name: Name of the attack; used to look up adversarial examples
            and to name the adversarial output directory.
        attack_fn: Forwarded to ``lenet_mnist_example``.
        generate_adversarial_fn: Forwarded to ``lenet_mnist_example``.
        select_fn: Forwarded to ``reconstruct_trace_from_tf``.
        image_id_index: Index of the first image of the batch within the class.
        batch_size: Requested batch size (clamped, see above).
        class_id: Label of the class being processed.
        model_dir: Model directory; checkpoints are read from
            ``<model_dir>/ckpts``.
        transforms: Forwarded to the dataset constructor for the trace pass
            and to ``lenet_mnist_example``.
        transform_name: Used in output directory names and forwarded to
            ``lenet_mnist_example``.
        graph_dir: Root directory for the output.
        dataset_mode: ``"test"`` or ``"train"``.
        images_per_class: Upper bound used to clamp the batch size.
        compute_adversarial: When false, the adversarial trace pass is skipped
            (the adversarial examples are still loaded to build the mask).
        **kwargs: Accepted and ignored.

    Raises:
        RuntimeError: If ``dataset_mode`` is neither ``"test"`` nor ``"train"``.
    """
    data_dir = abspath(MNIST_PATH)
    model_dir = abspath(model_dir)
    ckpt_dir = f"{model_dir}/ckpts"

    def create_model():
        return LeNet(data_format="channels_first")

    graph = LeNet.graph().load()
    batch_size = min(batch_size, images_per_class - image_id_index)

    if dataset_mode not in ("test", "train"):
        raise RuntimeError("Dataset invalid")
    dataset = mnist.test if dataset_mode == "test" else mnist.train

    def class_batch(**dataset_kwargs):
        # One batch of images labelled `class_id`, starting at `image_id_index`.
        matching = dataset(data_dir, **dataset_kwargs).filter(
            lambda image, label: tf.equal(
                tf.convert_to_tensor(class_id, dtype=tf.int32), label
            )
        )
        window = matching.skip(image_id_index).take(batch_size).batch(batch_size)
        # Element [0] of the iterator output is the image tensor.
        return window.make_one_shot_iterator().get_next()[0]

    def hooked_model_fn(subdir_name):
        # Model function wired to a saver rooted at <graph_dir>/<subdir>/<class>.
        saver = IOBatchAction(
            dir=os.path.join(graph_dir, subdir_name, f"{class_id}"),
            root_index=image_id_index,
        )
        return partial(
            model_fn_with_fetch_hook,
            create_model=create_model,
            graph=graph,
            graph_saver=saver,
            batch_valid=batch_valid,
        )

    # Mask 1: the clean image (no transforms) is classified as `class_id`.
    predicted_label = predict_batch(
        create_model=create_model,
        input_fn=lambda: class_batch(),
        model_dir=ckpt_dir,
    )
    prediction_valid = (predicted_label == class_id)

    # Mask 2: an adversarial example exists for the image.
    loaded_examples = [
        lenet_mnist_example(
            attack_name=attack_name,
            attack_fn=attack_fn,
            generate_adversarial_fn=generate_adversarial_fn,
            class_id=class_id,
            image_id=image_id,
            # model_dir, not ckpt_dir
            model_dir=model_dir,
            transforms=transforms,
            transform_name=transform_name,
            mode=dataset_mode,
        ).load()
        for image_id in range(image_id_index, image_id_index + batch_size)
    ]
    adversarial_valid = np.array(
        [example is not None for example in loaded_examples]
    )

    # Missing examples become (1, 1, 28, 28) zero placeholders so the batch
    # stays rectangular; axis 1 of the stacked array is then squeezed away.
    filled_examples = [
        np.zeros((1, 1, 28, 28)) if example is None else example
        for example in loaded_examples
    ]
    adversarial_examples = np.squeeze(
        np.array(filled_examples).astype(np.float32), axis=1
    )

    def adversarial_input():
        return tf.data.Dataset.from_tensors(
            mnist.normalize(adversarial_examples)
        )

    # Mask 3: the adversarial example is NOT classified as `class_id`.
    adversarial_predicted_label = predict_batch(
        create_model=create_model,
        input_fn=lambda: adversarial_input(),
        model_dir=ckpt_dir,
    )
    adversarial_prediction_valid = adversarial_predicted_label != class_id

    batch_valid = (
        prediction_valid * adversarial_valid * adversarial_prediction_valid
    )

    # The return values of the reconstruction calls are not used here;
    # presumably the saver passed through the model function persists the
    # output -- confirm against model_fn_with_fetch_hook.
    reconstruct_trace_from_tf(
        class_id=class_id,
        model_fn=hooked_model_fn(f"original_{transform_name}"),
        input_fn=lambda: class_batch(transforms=transforms),
        select_fn=select_fn,
        model_dir=ckpt_dir,
    )

    if compute_adversarial:
        reconstruct_trace_from_tf(
            model_fn=hooked_model_fn(f"{attack_name}_{transform_name}"),
            input_fn=lambda: adversarial_input(),
            select_fn=select_fn,
            model_dir=ckpt_dir,
        )
def get_row(class_id: int, image_id: int) -> Dict[str, Any]: nonlocal model_dir mode.check(False) data_dir = abspath(MNIST_PATH) model_dir = abspath(model_dir) ckpt_dir = f"{model_dir}/ckpts" create_model = lambda: LeNet(data_format="channels_first") graph = LeNet.graph().load() model_fn = partial(model_fn_with_fetch_hook, create_model=create_model, graph=graph) predicted_label = predict( create_model=create_model, input_fn=lambda: mnist.test(data_dir).filter( lambda image, label: tf.equal( tf.convert_to_tensor(class_id, dtype=tf.int32), label)) .skip(image_id).take(1).batch(1), model_dir=ckpt_dir, ) if predicted_label != class_id: return [{}] if per_node else {} adversarial_example = lenet_mnist_example( attack_name=attack_name, attack_fn=attack_fn, generate_adversarial_fn=generate_adversarial_fn, class_id=class_id, image_id=image_id, # model_dir not ckpt_dir model_dir=model_dir, transforms=transforms, transform_name=transform_name, ).load() if adversarial_example is None: return [{}] if per_node else {} adversarial_predicted_label = predict( create_model=create_model, input_fn=lambda: tf.data.Dataset.from_tensors( mnist.normalize(adversarial_example)), model_dir=ckpt_dir, ) if predicted_label == adversarial_predicted_label: return [{}] if per_node else {} trace = reconstruct_trace_from_tf( class_id=class_id, model_fn=model_fn, input_fn=lambda: mnist.test(data_dir, transforms=transforms). 
filter(lambda image, label: tf.equal( tf.convert_to_tensor(class_id, dtype=tf.int32), label) ).skip(image_id).take(1).batch(1), select_fn=select_fn, model_dir=ckpt_dir, per_channel=per_channel, )[0] if trace is None: return [{}] if per_node else {} adversarial_trace = reconstruct_trace_from_tf_brute_force( model_fn=model_fn, input_fn=lambda: tf.data.Dataset.from_tensors( mnist.normalize(adversarial_example)), select_fn=select_fn, model_dir=ckpt_dir, per_channel=per_channel, )[0] adversarial_label = adversarial_trace.attrs[GraphAttrKey.PREDICT] # adversarial_pred, violation = \ # detect_by_reduced_edge_count_violation( # class_trace_fn(adversarial_label).load(), # adversarial_trace, # reduce_mode, # ) # row = { # "image_id": image_id, # "class_id": class_id, # "original.prediction": # detect_by_reduced_edge(class_trace_fn(class_id).load(), # trace, # reduce_mode, # ), # "adversarial.prediction": # adversarial_pred, # "violation": violation, # } original_pred, violation = \ detect_by_reduced_edge_count_violation( class_trace_fn(class_id).load(), trace, reduce_mode, ) row = { "image_id": image_id, "class_id": class_id, "original.prediction": original_pred, "adversarial.prediction": detect_by_reduced_edge( class_trace_fn(adversarial_label).load(), adversarial_trace, reduce_mode, ), "violation": violation, } return row
def get_row(class_id: int, image_id: int) -> Dict[str, Any]: nonlocal model_dir mode.check(False) data_dir = abspath(MNIST_PATH) model_dir = abspath(model_dir) ckpt_dir = f"{model_dir}/ckpts" create_model = lambda: LeNet(data_format="channels_first") graph = LeNet.graph().load() model_fn = partial(model_fn_with_fetch_hook, create_model=create_model, graph=graph) if mnist_dataset_mode == "test": dataset = mnist.test elif mnist_dataset_mode == "train": dataset = mnist.train else: raise RuntimeError("Dataset invalid") predicted_label = predict( create_model=create_model, # dataset may be train or test, should be consistent with lenet_mnist_example input_fn=lambda: dataset(data_dir, ).filter( lambda image, label: tf.equal( tf.convert_to_tensor(class_id, dtype=tf.int32), label)) .skip(image_id).take(1).batch(1), model_dir=ckpt_dir, ) if predicted_label != class_id: return [{}] if per_node else {} adversarial_example = lenet_mnist_example( attack_name=attack_name, attack_fn=attack_fn, generate_adversarial_fn=generate_adversarial_fn, class_id=class_id, image_id=image_id, # model_dir not ckpt_dir model_dir=model_dir, transforms=transforms, transform_name=transform_name, mode=mnist_dataset_mode, ).load() if adversarial_example is None: return [{}] if per_node else {} adversarial_predicted_label = predict( create_model=create_model, input_fn=lambda: tf.data.Dataset.from_tensors( mnist.normalize(adversarial_example)), model_dir=ckpt_dir, ) if predicted_label == adversarial_predicted_label: return [{}] if per_node else {} if use_class_trace: class_trace_avg = ClassTraceIOAction(predicted_label).load() else: class_trace_avg = None trace = reconstruct_trace_from_tf( class_id=class_id, model_fn=model_fn, input_fn=lambda: dataset( data_dir, transforms=transforms, ).filter(lambda image, label: tf.equal( tf.convert_to_tensor(class_id, dtype=tf.int32), label)). 
skip(image_id).take(1).batch(1), select_fn=select_fn, model_dir=ckpt_dir, per_channel=per_channel, class_trace=class_trace_avg, )[0] if trace is None: return [{}] if per_node else {} path = os.path.join(save_dir, f"original_{transform_name}", f"{class_id}", f"{image_id}.pkl") ensure_dir(path) with open(path, "wb") as f: pickle.dump(trace, f) adversarial_trace = reconstruct_trace_from_tf( model_fn=model_fn, input_fn=lambda: tf.data.Dataset.from_tensors( mnist.normalize(adversarial_example)), select_fn=select_fn, model_dir=ckpt_dir, per_channel=per_channel, class_trace=class_trace_avg, )[0] path = os.path.join(save_dir, f"{attack_name}_{transform_name}", f"{class_id}", f"{image_id}.pkl") ensure_dir(path) with open(path, "wb") as f: pickle.dump(adversarial_trace, f) row = { "class_id": class_id, "image_id": image_id, "trace": trace, "adversarial_trace": adversarial_trace, } # row = calc_all_overlap( # class_trace_fn(class_id).load(), adversarial_trace, overlap_fn # ) return row