def save_trace(
        class_id,
        image_id,
        select_fn: Callable[[np.ndarray], np.ndarray],
        class_dir,
        graph_dir,
        create_model,
        graph,
        per_node: bool = False,
        images_per_class: int = 1,
        num_gpus: float = 1,
        model_dir=resnet18_dir,
        transforms=None,
        transform_name="noop",
        save_dir="result/test",
        dataset_mode=dataset_mode,
        **kwargs,
):
    """Compute and save the trace for one CIFAR-10 image.

    Predicts the label of image `image_id` of class `class_id`; when the
    clean prediction is wrong, returns an empty sentinel.  Otherwise runs the
    model with a fetch hook so the raw graph is written to
    `{graph_dir}/original_{transform_name}/{class_id}/{image_id}.pkl`, then
    reloads that graph and reconstructs the per-example trace, optionally
    weighted by the class-average trace.

    Args:
        class_id: ground-truth class of the image.
        image_id: index of the image within its class.
        select_fn: selects which activations participate in the trace.
        class_dir: directory holding per-class average traces.
        graph_dir: root directory for saved raw graphs.
        create_model: zero-arg factory returning the model callable.
        graph: static Graph object of the network.
        per_node: if True, the empty sentinel is `[{}]` instead of `{}`.
        model_dir: model directory containing a `ckpts` subdirectory.
        dataset_mode: "train" or "test" split selector.

    Returns:
        `[{}]` / `{}` when the clean prediction is wrong, None otherwise
        (the useful work product is the files written to disk).

    NOTE(review): `use_class_trace` and `attack_name` are free variables
    resolved at module level, not parameters — confirm they are defined in
    the enclosing module.
    """
    # Removed: a dead `input_fn = lambda: ...` that was never used (and read
    # `data_dir` before its assignment below).
    mode.check(False)
    data_dir = abspath(CIFAR10_PATH)
    model_dir = abspath(model_dir)
    ckpt_dir = f"{model_dir}/ckpts"
    # time.clock() was platform-dependent and removed in Python 3.8;
    # perf_counter() is the portable monotonic replacement.
    start = time.perf_counter()
    predicted_label = predict_batch(
        create_model=create_model,
        input_fn=lambda: (input_fn_for_adversarial_examples(
            is_training=(dataset_mode == "train"),
            data_dir=data_dir,
            num_parallel_batches=1,
            is_shuffle=False,
            transform_fn=None,
        ).filter(lambda image, label: tf.equal(
            tf.convert_to_tensor(class_id, dtype=tf.int32), label)).skip(
                image_id).take(1).batch(1).make_one_shot_iterator().get_next()[
                    0]),
        model_dir=ckpt_dir,
    )
    predicted_label = predicted_label[0]
    if predicted_label != class_id:
        # Misclassified clean image: nothing to trace.
        return [{}] if per_node else {}
    print(f"prediction {time.perf_counter() - start}s")
    original_graph_dir = os.path.join(graph_dir, f"original_{transform_name}",
                                      f"{class_id}")
    original_graph_saver = IOBatchAction(
        dir=original_graph_dir,
        root_index=image_id,
    )
    original_model_fn = partial(
        model_fn_with_fetch_hook,
        create_model=create_model,
        graph=graph,
        graph_saver=original_graph_saver,
        batch_valid=[1],
    )
    # Side effect: writes the raw graph via `original_graph_saver`.
    trace = reconstruct_trace_from_tf(
        class_id=class_id,
        model_fn=original_model_fn,
        input_fn=lambda: (input_fn_for_adversarial_examples(
            is_training=(dataset_mode == "train"),
            data_dir=data_dir,
            num_parallel_batches=1,
            is_shuffle=False,
            transform_fn=None,
        ).filter(lambda image, label: tf.equal(
            tf.convert_to_tensor(class_id, dtype=tf.int32), label)).skip(
                image_id).take(1).batch(1).make_one_shot_iterator().get_next()[
                    0]),
        select_fn=select_fn,
        model_dir=ckpt_dir,
    )
    print(f"Saved graph")
    graph_path = os.path.join(graph_dir, f"original_{transform_name}",
                              f"{class_id}", f"{image_id}.pkl")
    if not os.path.exists(graph_path):
        # The fetch hook did not save a graph (e.g. batch marked invalid).
        return
    single_graph = IOObjAction(graph_path).load()
    if use_class_trace:
        class_trace_avg = ClassTraceIOAction(class_dir, predicted_label).load()
        assert class_trace_avg is not None
    else:
        class_trace_avg = None
    start = time.perf_counter()
    single_trace = reconstruct_trace_from_tf_to_trace(
        single_graph,
        class_id=(class_id if attack_name == "original" else None),
        select_fn=select_fn,
        class_trace=class_trace_avg,
        debug=False,
    )
    print(f"compute original trace {time.perf_counter() - start}s")
def save_graph(
        attack_name: str,
        attack_fn,
        generate_adversarial_fn,
        select_fn: Callable[[np.ndarray], np.ndarray],
        image_id_index,
        batch_size,
        class_id,
        model_dir=model_dir,
        transforms=None,
        transform_name="noop",
        graph_dir="result/test",
        dataset_mode=dataset_mode,
        images_per_class=1,
        compute_adversarial=True,
        **kwargs,
):
    """Save raw computation graphs for a batch of CIFAR-10 examples.

    A batch entry is valid only when the clean image is classified
    correctly and, if `compute_adversarial` is set, its adversarial
    counterpart exists and is misclassified.  Graphs are written under
    `{graph_dir}/original_{transform_name}/{class_id}` and, for adversarial
    inputs, `{graph_dir}/{attack_name}_{transform_name}/{class_id}`.
    """
    data_dir = abspath(CIFAR10_PATH)
    model_dir = abspath(model_dir)
    ckpt_dir = f"{model_dir}/ckpts"

    def create_model():
        return partial(
            ResNet10Cifar10(),
            training=False,
        )

    graph = ResNet10Cifar10.graph().load()
    # Clamp so the batch never runs past the last image of this class.
    batch_size = min(batch_size, images_per_class - image_id_index)

    def clean_input_fn():
        # Clean examples of `class_id`, starting at `image_id_index`.
        pipeline = input_fn_for_adversarial_examples(
            is_training=(dataset_mode == "train"),
            data_dir=data_dir,
            num_parallel_batches=1,
            is_shuffle=False,
            transform_fn=None,
        ).filter(lambda image, label: tf.equal(
            tf.convert_to_tensor(class_id, dtype=tf.int32), label))
        pipeline = pipeline.skip(image_id_index).take(batch_size).batch(
            batch_size)
        return pipeline.make_one_shot_iterator().get_next()[0]

    predicted_label = predict_batch(
        create_model=create_model,
        input_fn=clean_input_fn,
        model_dir=ckpt_dir,
    )
    prediction_valid = predicted_label == class_id
    if compute_adversarial:
        # Each stored example has shape (1, 32, 32, 3); missing ones are
        # padded with zeros but flagged invalid.
        loaded = [
            resnet10_cifar10_example(
                attack_name=attack_name,
                attack_fn=attack_fn,
                generate_adversarial_fn=generate_adversarial_fn,
                class_id=class_id,
                image_id=image_id,
                # model_dir not ckpt_dir
                model_dir=model_dir,
                transforms=transforms,
                transform_name=transform_name,
                dataset_mode=dataset_mode,
            ).load()
            for image_id in range(image_id_index, image_id_index + batch_size)
        ]
        adversarial_valid = np.array(
            [example is not None for example in loaded])
        adversarial_examples = np.squeeze(
            np.array([
                example if example is not None else np.zeros((1, 32, 32, 3))
                for example in loaded
            ]).astype(np.float32),
            axis=1)
        adversarial_predicted_label = predict_batch(
            create_model=create_model,
            input_fn=lambda: tf.data.Dataset.from_tensors(
                adversarial_examples),
            model_dir=ckpt_dir,
        )
        # A successful attack flips the prediction away from the true class.
        adversarial_prediction_valid = adversarial_predicted_label != class_id
        batch_valid = (prediction_valid * adversarial_valid *
                       adversarial_prediction_valid)
    else:
        batch_valid = prediction_valid
    original_graph_saver = IOBatchAction(
        dir=os.path.join(graph_dir, f"original_{transform_name}",
                         f"{class_id}"),
        root_index=image_id_index,
    )
    # Running the model with the fetch hook saves the clean graphs.
    trace = reconstruct_trace_from_tf(
        class_id=class_id,
        model_fn=partial(
            model_fn_with_fetch_hook,
            create_model=create_model,
            graph=graph,
            graph_saver=original_graph_saver,
            batch_valid=batch_valid,
        ),
        input_fn=clean_input_fn,
        select_fn=select_fn,
        model_dir=ckpt_dir,
    )
    if compute_adversarial:
        adversarial_graph_saver = IOBatchAction(
            dir=os.path.join(graph_dir, f"{attack_name}_{transform_name}",
                             f"{class_id}"),
            root_index=image_id_index,
        )
        # Same fetch-hook pass over the adversarial batch.
        adversarial_trace = reconstruct_trace_from_tf(
            model_fn=partial(
                model_fn_with_fetch_hook,
                create_model=create_model,
                graph=graph,
                graph_saver=adversarial_graph_saver,
                batch_valid=batch_valid,
            ),
            input_fn=lambda: tf.data.Dataset.from_tensors(
                adversarial_examples),
            select_fn=select_fn,
            model_dir=ckpt_dir,
        )
def save_graph(
        attack_name: str,
        attack_fn,
        generate_adversarial_fn,
        select_fn: Callable[[np.ndarray], np.ndarray],
        image_id_index,
        batch_size,
        class_id,
        # model_dir = "result/lenet/model_augmentation",
        model_dir=model_dir,
        transforms=None,
        transform_name="noop",
        graph_dir="result/test",
        dataset_mode=dataset_mode,
        images_per_class=1,
        compute_adversarial=True,
        **kwargs,
):
    """Save raw computation graphs for a batch of MNIST examples.

    A batch entry is valid only when the clean image is classified
    correctly and, if `compute_adversarial` is set, its adversarial
    counterpart exists and is misclassified.  Graphs are written under
    `{graph_dir}/original_{transform_name}/{class_id}` and, for adversarial
    inputs, `{graph_dir}/{attack_name}_{transform_name}/{class_id}`.

    Fixes vs. previous version (consistent with the CIFAR-10 twin of this
    function): adversarial examples are now only loaded/predicted when
    `compute_adversarial` is True, and `batch_valid` no longer depends on
    adversarial validity in that case; the stale shape comment claiming
    (1, 32, 32, 3) has been corrected to the MNIST shape.

    Raises:
        RuntimeError: if `dataset_mode` is neither "train" nor "test".
    """
    data_dir = abspath(MNIST_PATH)
    model_dir = abspath(model_dir)
    ckpt_dir = f"{model_dir}/ckpts"
    create_model = lambda: LeNet(data_format="channels_first")
    graph = LeNet.graph().load()
    # Clamp so the batch never runs past the last image of this class.
    batch_size = min(batch_size, images_per_class - image_id_index)
    if dataset_mode == "test":
        dataset = mnist.test
    elif dataset_mode == "train":
        dataset = mnist.train
    else:
        raise RuntimeError("Dataset invalid")
    predicted_label = predict_batch(
        create_model=create_model,
        input_fn=lambda: (dataset(data_dir).filter(
            lambda image, label: tf.equal(
                tf.convert_to_tensor(class_id, dtype=tf.int32), label)).skip(
                    image_id_index).take(batch_size).batch(batch_size).
                          make_one_shot_iterator().get_next()[0]),
        model_dir=ckpt_dir,
    )
    prediction_valid = (predicted_label == class_id)
    if compute_adversarial:
        # Each stored example has shape (1, 1, 28, 28); missing ones are
        # padded with zeros but flagged invalid.
        adversarial_examples = [
            lenet_mnist_example(
                attack_name=attack_name,
                attack_fn=attack_fn,
                generate_adversarial_fn=generate_adversarial_fn,
                class_id=class_id,
                image_id=image_id,
                # model_dir not ckpt_dir
                model_dir=model_dir,
                transforms=transforms,
                transform_name=transform_name,
                mode=dataset_mode,
            ).load()
            for image_id in range(image_id_index, image_id_index + batch_size)
        ]
        adversarial_valid = np.array(
            [example is not None for example in adversarial_examples])
        adversarial_examples = [
            example if example is not None else np.zeros((1, 1, 28, 28))
            for example in adversarial_examples
        ]
        adversarial_examples = np.squeeze(
            np.array(adversarial_examples).astype(np.float32), axis=1)
        # adversarial_example is [0, 1] of shape (1, 28, 28)
        adversarial_predicted_label = predict_batch(
            create_model=create_model,
            input_fn=lambda: tf.data.Dataset.from_tensors(
                mnist.normalize(adversarial_examples)),
            model_dir=ckpt_dir,
        )
        # A successful attack flips the prediction away from the true class.
        adversarial_prediction_valid = adversarial_predicted_label != class_id
        batch_valid = (prediction_valid * adversarial_valid *
                       adversarial_prediction_valid)
    else:
        batch_valid = prediction_valid
    original_graph_dir = os.path.join(graph_dir, f"original_{transform_name}",
                                      f"{class_id}")
    original_graph_saver = IOBatchAction(
        dir=original_graph_dir,
        root_index=image_id_index,
    )
    original_model_fn = partial(
        model_fn_with_fetch_hook,
        create_model=create_model,
        graph=graph,
        graph_saver=original_graph_saver,
        batch_valid=batch_valid,
    )
    # Running the model with the fetch hook saves the clean graphs.
    trace = reconstruct_trace_from_tf(
        class_id=class_id,
        model_fn=original_model_fn,
        input_fn=lambda: (dataset(
            data_dir,
            transforms=transforms,
        ).filter(lambda image, label: tf.equal(
            tf.convert_to_tensor(class_id, dtype=tf.int32), label)).skip(
                image_id_index).take(batch_size).batch(batch_size).
                          make_one_shot_iterator().get_next()[0]),
        select_fn=select_fn,
        model_dir=ckpt_dir,
    )
    if compute_adversarial:
        adversarial_graph_dir = os.path.join(
            graph_dir, f"{attack_name}_{transform_name}", f"{class_id}")
        adversarial_graph_saver = IOBatchAction(
            dir=adversarial_graph_dir,
            root_index=image_id_index,
        )
        adversarial_model_fn = partial(
            model_fn_with_fetch_hook,
            create_model=create_model,
            graph=graph,
            graph_saver=adversarial_graph_saver,
            batch_valid=batch_valid,
        )
        # Same fetch-hook pass over the adversarial batch.
        adversarial_trace = reconstruct_trace_from_tf(
            model_fn=adversarial_model_fn,
            input_fn=lambda: tf.data.Dataset.from_tensors(
                mnist.normalize(adversarial_examples)),
            select_fn=select_fn,
            model_dir=ckpt_dir,
        )
def predict_original_adversarial(
        attack_name: str,
        attack_fn,
        generate_adversarial_fn,
        image_id_index,
        batch_size,
        class_id,
        model_dir=model_dir,
        transforms=None,
        transform_name="noop",
        graph_dir="result/test",
        dataset_mode=dataset_mode,
        images_per_class=1,
        **kwargs,
):
    """Count correct clean and adversarial predictions for one class batch.

    Predicts a batch of clean CIFAR-10 images of `class_id` and their
    pre-generated adversarial counterparts, restricted to entries whose
    adversarial example exists on disk.

    Returns:
        Tuple `(original_correct, adversarial_correct, valid_count)`:
        clean images predicted as `class_id`, adversarial images still
        predicted as `class_id` (i.e. the attack failed), and the number
        of entries with an existing adversarial example.
    """
    # mode.check(False)
    data_dir = abspath(CIFAR10_PATH)
    model_dir = abspath(model_dir)
    ckpt_dir = f"{model_dir}/ckpts"
    create_model = lambda: partial(
        ResNet10Cifar10(),
        training=False,
    )
    # Removed: unused `graph = ResNet10Cifar10.graph().load()` (wasted I/O).
    # Clamp so the batch never runs past the last image of this class.
    batch_size = min(batch_size, images_per_class - image_id_index)
    predicted_label = predict_batch(
        create_model=create_model,
        input_fn=lambda: (input_fn_for_adversarial_examples(
            is_training=(dataset_mode == "train"),
            data_dir=data_dir,
            num_parallel_batches=1,
            is_shuffle=False,
            transform_fn=None,
        ).filter(lambda image, label: tf.equal(
            tf.convert_to_tensor(class_id, dtype=tf.int32), label)).skip(
                image_id_index).take(batch_size).batch(batch_size).
                          make_one_shot_iterator().get_next()[0]),
        model_dir=ckpt_dir,
    )
    # The shape of each stored example is (1, 32, 32, 3); missing ones are
    # padded with zeros but flagged invalid.
    adversarial_examples = [
        resnet10_cifar10_example(
            attack_name=attack_name,
            attack_fn=attack_fn,
            generate_adversarial_fn=generate_adversarial_fn,
            class_id=class_id,
            image_id=image_id,
            # model_dir not ckpt_dir
            model_dir=model_dir,
            transforms=transforms,
            transform_name=transform_name,
            dataset_mode=dataset_mode,
        ).load()
        for image_id in range(image_id_index, image_id_index + batch_size)
    ]
    adversarial_valid = np.array(
        [example is not None for example in adversarial_examples])
    adversarial_examples = [
        example if example is not None else np.zeros((1, 32, 32, 3))
        for example in adversarial_examples
    ]
    adversarial_examples = np.squeeze(
        np.array(adversarial_examples).astype(np.float32), axis=1)
    # adversarial_example is [0, 1] of shape (1, 32, 32, 3)
    adversarial_predicted_label = predict_batch(
        create_model=create_model,
        input_fn=lambda: tf.data.Dataset.from_tensors(adversarial_examples),
        model_dir=ckpt_dir,
    )
    assert predicted_label.shape == adversarial_predicted_label.shape
    original_correct = (predicted_label[adversarial_valid] == class_id).sum()
    # BUG FIX: was compared against `class_num` (presumably the number of
    # classes), which a label can never equal, so this count was always 0.
    # Compare against `class_id`, matching `original_correct` above.
    adversarial_correct = (
        adversarial_predicted_label[adversarial_valid] == class_id).sum()
    valid_count = adversarial_valid.sum()
    return original_correct, adversarial_correct, valid_count