Beispiel #1
0
 def ray_compute(
     dataset_split="test",
     save_adversarial=True,
     images_per_class=test_images_per_class,
 ):
     """Run ``save_logic_per_image`` for every (class, image) pair via ``ray_iter``.

     Args:
         dataset_split: Forwarded to ``save_logic_per_image`` as ``dataset_split``.
         save_adversarial: Forwarded to ``save_logic_per_image`` as
             ``save_adversarial``.
         images_per_class: Number of image ids enumerated for each of the
             10 class ids.

     Returns:
         None. The values produced by ``ray_iter`` are collected but not used.
     """
     per_image_task = partial(
         save_logic_per_image,
         dataset_split=dataset_split,
         save_adversarial=save_adversarial,
     )

     # Class-major enumeration: every image id of class 0, then class 1, ...
     work_items = []
     for class_id in range(10):
         for image_id in range(images_per_class):
             work_items.append((class_id, image_id))

     pending = ray_iter(
         per_image_task,
         work_items,
         chunksize=1,
         out_of_order=True,
         huge_task=True,
     )
     # Consume the whole iterable returned by ray_iter.
     results = list(pending)
Beispiel #2
0
def test_data_forward(
    class_trace_fn: Callable[[int], IOAction[AttrMap]],
    select_fn: Callable[[np.ndarray], np.ndarray],
    overlap_fn: Callable[[AttrMap, AttrMap, str], float],
    path: str,
    per_channel: bool = False,
    per_node: bool = False,
    images_per_class: int = 1,
    num_gpus: float = 0.2,
    model_dir="result/lenet/model_augmentation",
    transforms=None,
    **kwargs,
):
    """Run LeNet prediction on single MNIST test images under each transform.

    For class id 0 only and every image id in ``range(images_per_class)``,
    ``predict`` is called once per ``(transform, name)`` entry of
    ``transforms``; ``name`` is passed to ``predict`` as ``tag``.

    Args:
        class_trace_fn, select_fn, overlap_fn, path, per_channel, per_node,
            kwargs: Accepted but not used by this function.
        images_per_class: Number of image ids to process.
        num_gpus: Forwarded to ``ray_iter``.
        model_dir: Model directory; checkpoints are read from
            ``<model_dir>/ckpts``.
        transforms: Iterable of ``(transform, name)`` pairs. NOTE(review): the
            default ``None`` is not iterable, so callers must pass a value.

    Returns:
        None.
    """

    def forward_one_image(class_id: int, image_id: int) -> Dict[str, Any]:
        # NOTE(review): despite the annotation nothing is returned (None).
        nonlocal model_dir
        mode.check(False)
        data_dir = abspath(MNIST_PATH)
        model_dir = abspath(model_dir)
        ckpt_dir = f"{model_dir}/ckpts"

        def create_model():
            return LeNet(data_format="channels_first")

        # The loaded graph is not used below; the call is kept as in the
        # original flow.
        graph = LeNet.graph().load()

        def is_target_class(image, label):
            wanted = tf.convert_to_tensor(class_id, dtype=tf.int32)
            return tf.equal(wanted, label)

        for transform, name in transforms:

            def input_fn():
                # Select the image_id-th test image of class_id as a batch of 1.
                dataset = mnist.test(data_dir, transforms=transform)
                dataset = dataset.filter(is_target_class)
                return dataset.skip(image_id).take(1).batch(1)

            predicted_label = predict(
                create_model=create_model,
                input_fn=input_fn,
                model_dir=ckpt_dir,
                tag=name,
            )

    work_items = (
        (class_id, image_id)
        for image_id in range(images_per_class)
        for class_id in range(1)
    )
    pending = ray_iter(
        forward_one_image,
        work_items,
        chunksize=1,
        out_of_order=True,
        num_gpus=num_gpus,
    )

    # Consume the whole iterable returned by ray_iter.
    results = list(pending)
Beispiel #3
0
def count_logics(
    logic_name="unary",
    logic_fn=channel_activation,
    trace_key=TraceKey.POINT,
):
    """Accumulate saved training-split logics and print thresholded masks.

    For each of the 10 class ids the pickled logics of up to
    ``train_images_per_class`` images are loaded, their absolute values are
    summed per node, and the sum is thresholded. The per-class masks and the
    masks over all classes are passed to ``print_logic_per_class``.

    Classes without any saved logics are skipped instead of raising
    ``IndexError``, and the dataset-wide accumulation no longer assumes that
    class 0 is the first class with data.

    Args:
        logic_name: Logic identifier used for the save path and as the key
            into ``logic_filter_thred``.
        logic_fn: Not used by this function.
        trace_key: Trace key used for the save path.

    Returns:
        None. Results are only printed.
    """

    def load_logics(
        class_id,
        image_id,
        transform_name="noop",
    ):
        """Return the pickled logics of one training image, or ``{}`` if absent."""
        dataset_split = "train"
        path = logic_save_path(
            logic_dir,
            trace_key,
            logic_name,
            dataset_split,
            "original",
            transform_name,
            class_id,
            image_id,
        )
        if not os.path.exists(path):
            return {}
        # NOTE: pickle must only be used on trusted, locally produced files.
        with open(path, "rb") as f:
            logics = pickle.load(f)
        return logics

    def logic_plot_hist_save(
        logics,
        filter_count,
    ):
        """Return sparse masks for ``logics > filter_count`` and ``logics > 0``."""
        thred_filter = (logics > filter_count).astype(np.uint8)
        sparse_thred_filter = SparseMask(thred_filter)

        nonzero_filter = (logics > 0).astype(np.uint8)
        sparse_nonzero_filter = SparseMask(nonzero_filter)

        return sparse_thred_filter, sparse_nonzero_filter

    # The thresholds do not depend on the class or node, so compute them once.
    class_filter_count = (logic_filter_thred[logic_name] *
                          train_images_per_class)
    dataset_filter_count = (logic_filter_thred[logic_name] * 10 *
                            train_images_per_class)

    node_to_logics = {}
    for class_id in range(10):
        params = [(
            class_id,
            image_id,
        ) for image_id in range(train_images_per_class)]
        results = ray_iter(
            load_logics,
            params,
            chunksize=1,
            out_of_order=True,
            huge_task=True,
        )
        results = [result for result in results if len(result) > 0]
        if not results:
            # Nothing was saved for this class; results[0] would fail.
            print(f"{class_id}: no logics found, skipped")
            continue

        thred_filter_per_class = {}
        for node_name in results[0].keys():
            shape = results[0][node_name].shape
            logics_acc = np.zeros(shape)
            for result in results:
                logics_acc += abs(result[node_name].to_tensor())

            # Start the dataset-wide sum at the first class that has this
            # node, whichever class id that is.
            if node_name in node_to_logics:
                node_to_logics[node_name] += logics_acc
            else:
                node_to_logics[node_name] = logics_acc.copy()

            sparse_thred_filter, _ = logic_plot_hist_save(
                logics_acc,
                class_filter_count,
            )
            thred_filter_per_class[node_name] = sparse_thred_filter
        print(f"{class_id}")
        print_logic_per_class(thred_filter_per_class)

    thred_filter_all = {}
    nonzero_filter_all = {}
    # Iterate the accumulated nodes rather than the last class's results.
    for node_name, dataset_logics in node_to_logics.items():
        sparse_thred_filter, sparse_nonzero_filter = logic_plot_hist_save(
            dataset_logics,
            dataset_filter_count,
        )
        thred_filter_all[node_name] = sparse_thred_filter
        nonzero_filter_all[node_name] = sparse_nonzero_filter
    print("all")
    print_logic_per_class(thred_filter_all)
    def get_overlap_ratio() -> pd.DataFrame:
        def get_row(class_id: int, image_id: int) -> Dict[str, Any]:
            nonlocal model_dir
            mode.check(False)
            data_dir = abspath(MNIST_PATH)
            model_dir = abspath(model_dir)
            ckpt_dir = f"{model_dir}/ckpts"
            create_model = lambda: LeNet(data_format="channels_first")
            graph = LeNet.graph().load()
            model_fn = partial(model_fn_with_fetch_hook,
                               create_model=create_model,
                               graph=graph)

            predicted_label = predict(
                create_model=create_model,
                input_fn=lambda: mnist.test(data_dir).filter(
                    lambda image, label: tf.equal(
                        tf.convert_to_tensor(class_id, dtype=tf.int32), label))
                .skip(image_id).take(1).batch(1),
                model_dir=ckpt_dir,
            )
            if predicted_label != class_id:
                return [{}] if per_node else {}

            adversarial_example = lenet_mnist_example(
                attack_name=attack_name,
                attack_fn=attack_fn,
                generate_adversarial_fn=generate_adversarial_fn,
                class_id=class_id,
                image_id=image_id,
                # model_dir not ckpt_dir
                model_dir=model_dir,
                transforms=transforms,
                transform_name=transform_name,
            ).load()

            if adversarial_example is None:
                return [{}] if per_node else {}

            adversarial_predicted_label = predict(
                create_model=create_model,
                input_fn=lambda: tf.data.Dataset.from_tensors(
                    mnist.normalize(adversarial_example)),
                model_dir=ckpt_dir,
            )

            if predicted_label == adversarial_predicted_label:
                return [{}] if per_node else {}

            trace = reconstruct_trace_from_tf(
                class_id=class_id,
                model_fn=model_fn,
                input_fn=lambda: mnist.test(data_dir, transforms=transforms).
                filter(lambda image, label: tf.equal(
                    tf.convert_to_tensor(class_id, dtype=tf.int32), label)
                       ).skip(image_id).take(1).batch(1),
                select_fn=select_fn,
                model_dir=ckpt_dir,
                per_channel=per_channel,
            )[0]

            if trace is None:
                return [{}] if per_node else {}

            adversarial_trace = reconstruct_trace_from_tf_brute_force(
                model_fn=model_fn,
                input_fn=lambda: tf.data.Dataset.from_tensors(
                    mnist.normalize(adversarial_example)),
                select_fn=select_fn,
                model_dir=ckpt_dir,
                per_channel=per_channel,
            )[0]

            adversarial_label = adversarial_trace.attrs[GraphAttrKey.PREDICT]

            # adversarial_pred, violation = \
            #     detect_by_reduced_edge_count_violation(
            #         class_trace_fn(adversarial_label).load(),
            #                         adversarial_trace,
            #                         reduce_mode,
            #     )
            # row = {
            #     "image_id": image_id,
            #     "class_id": class_id,
            #     "original.prediction":
            #         detect_by_reduced_edge(class_trace_fn(class_id).load(),
            #                                 trace,
            #                                 reduce_mode,
            #                                 ),
            #     "adversarial.prediction":
            #         adversarial_pred,
            #     "violation": violation,
            # }

            original_pred, violation = \
                detect_by_reduced_edge_count_violation(
                    class_trace_fn(class_id).load(),
                                    trace,
                                    reduce_mode,
                )
            row = {
                "image_id":
                image_id,
                "class_id":
                class_id,
                "original.prediction":
                original_pred,
                "adversarial.prediction":
                detect_by_reduced_edge(
                    class_trace_fn(adversarial_label).load(),
                    adversarial_trace,
                    reduce_mode,
                ),
                "violation":
                violation,
            }
            return row

        results = ray_iter(
            get_row,
            ((class_id, image_id) for image_id in range(0, images_per_class)
             for class_id in range(0, 10)),
            # ((-1, image_id) for image_id in range(mnist_info.test().size)),
            chunksize=1,
            out_of_order=True,
            num_gpus=num_gpus,
        )
        results = [result for result in results if len(result) != 0]
        violations = [result['violation'] for result in results]
        for result in results:
            result.pop('violation')
        traces = results
        mode_to_violation[reduce_mode] = violations
        return pd.DataFrame(traces)
def graph_to_trace(
    attack_name,
    image_id_index,
    batch_size,
    class_id,
    save_dir="result/test",
    graph_dir="result/test",
    chunksize=1,
    transform_name="noop",
    images_per_class=10,
    compute_adversarial=True,
    use_class_trace=False,
    class_dir=None,
):
    """Convert a batch of saved graphs of one class into traces.

    Image ids from ``image_id_index`` up to (excluding)
    ``min(image_id_index + batch_size, images_per_class)`` are processed via
    ``ray_iter``. For each image the graph stored under
    ``<graph_dir>/original_<transform_name>/<class_id>/<image_id>.pkl`` is
    turned into a trace and saved under the same relative path in
    ``save_dir``; with ``compute_adversarial`` the same is done for
    ``attack_name``.

    Args:
        attack_name: Name of the adversarial variant directory prefix.
        image_id_index: First image id of the batch.
        batch_size: Maximum number of image ids in the batch.
        class_id: Class id whose images are processed.
        save_dir: Root directory the traces are written to.
        graph_dir: Root directory the graphs are read from.
        chunksize: Forwarded to ``ray_iter``.
        transform_name: Suffix of the variant directory name.
        images_per_class: Upper bound (exclusive) for image ids.
        compute_adversarial: Whether to also process the ``attack_name`` graph.
        use_class_trace: Whether to load a class trace for the label predicted
            by the graph and pass it to the reconstruction.
        class_dir: Directory handed to ``ClassTraceIOAction``.

    Returns:
        List of dicts with ``class_id``, ``image_id``, ``trace`` and, when
        ``compute_adversarial`` is set, ``adversarial_trace`` (``None`` if the
        adversarial graph file does not exist). Images without an original
        graph file are omitted.
    """

    def compute_trace(
        class_id,
        image_id,
    ):
        def compute_trace_per_attack(attack_name):
            global debug
            # Graph and trace share the same path below their root directory.
            relative_parts = (
                f"{attack_name}_{transform_name}",
                f"{class_id}",
                f"{image_id}.pkl",
            )
            graph_path = os.path.join(graph_dir, *relative_parts)
            if not os.path.exists(graph_path):
                return None

            single_graph = IOObjAction(graph_path).load()

            class_trace_avg = None
            if use_class_trace:
                logits = single_graph.tensor(single_graph.outputs[0]).value
                predicted_label = np.argmax(logits)
                class_trace_avg = ClassTraceIOAction(class_dir,
                                                     predicted_label).load()
                assert class_trace_avg is not None

            # Only the original image passes its class id on.
            trace_class_id = class_id if attack_name == "original" else None
            single_trace = reconstruct_trace_from_tf_to_trace(
                single_graph,
                class_id=trace_class_id,
                select_fn=select_fn,
                class_trace=class_trace_avg,
                debug=debug,
            )
            trace_path = os.path.join(save_dir, *relative_parts)
            IOObjAction(trace_path).save(single_trace)
            return single_trace

        TraceKey.BUG_IMAGEID = image_id
        trace = compute_trace_per_attack("original")
        if trace is None:
            return {}

        result = {"class_id": class_id, "image_id": image_id, "trace": trace}
        if compute_adversarial:
            result["adversarial_trace"] = compute_trace_per_attack(attack_name)
        return result

    batch_end = min(image_id_index + batch_size, images_per_class)
    batch_param = [(class_id, image_id)
                   for image_id in range(image_id_index, batch_end)]
    pending = ray_iter(
        compute_trace,
        batch_param,
        chunksize=chunksize,
        out_of_order=True,
    )
    return [result for result in pending if len(result) > 0]
Beispiel #6
0
def load_logic_data(
    class_thred_logics,
    dataset_split,
    images_per_class,
    data_config,
    attack_name="original",
):
    """Load the saved per-image logics of a dataset split as arrays.

    For each ``class_id`` in ``range(10)`` and ``image_id`` in
    ``range(images_per_class)`` the logics described by every entry of
    ``data_config`` are read from disk and concatenated along axis 0. Images
    for which any logic file is missing are dropped.

    Args:
        class_thred_logics: Currently unused; the filtering by class threshold
            (``filter_logic_by_class_thred``) is not called.
        dataset_split: Dataset split used for the save path and for
            ``load_raw_prediction``.
        images_per_class: Number of image ids enumerated per class.
        data_config: Iterable of ``(trace_key, logic_name, layer_names)``
            tuples.
        attack_name: Variant whose logics are loaded.

    Returns:
        Tuple ``(logic_data, label, raw_prediction)`` of NumPy arrays with one
        entry per image that was found.
    """

    def filter_logic_by_class_thred(
        target_logic,
        class_logic,
        layer_names=point_keys,
    ):
        """Keep, per layer, the entries of ``target_logic`` selected by ``class_logic``."""
        filtered_logic = {}
        for key in layer_names:
            target_mask = target_logic[key].to_tensor()
            class_mask = class_logic[key].to_tensor()
            # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is
            # the equivalent dtype on every NumPy version.
            class_mask = class_mask.astype(bool)
            filtered_logic[key] = target_mask[class_mask]
        return filtered_logic

    def load_logic_data_per_image_config(
        config,
        class_id,
        image_id,
    ):
        """Return one config's logics of an image as an array, or ``{}`` if absent."""
        trace_key, logic_name, layer_names = config
        path = logic_save_path(
            logic_dir,
            trace_key,
            logic_name,
            dataset_split,
            attack_name,
            "noop",
            class_id,
            image_id,
        )
        if not os.path.exists(path):
            return {}
        # NOTE: pickle must only be used on trusted, locally produced files.
        with open(path, "rb") as f:
            logic = pickle.load(f)

        logic_data = [data.to_tensor() for data in logic.values()]
        logic_data = np.concatenate(
            logic_data,
            axis=0,
        )
        return logic_data

    def load_logic_data_per_image(
        data_config,
        class_id,
        image_id,
    ):
        """Return ``(logic_data, class_id, raw_prediction)`` for one image.

        ``([], -1, -1)`` marks an image with a missing logic file.
        """
        image_logic = []
        for config in data_config:
            image_logic.append(
                load_logic_data_per_image_config(
                    config,
                    class_id,
                    image_id,
                ))
            if len(image_logic[-1]) == 0:
                return [], -1, -1

        logic_data = np.concatenate(
            image_logic,
            axis=0,
        )

        raw_prediction = load_raw_prediction(
            class_id,
            image_id,
            dataset_split,
            attack_name,
        )
        return logic_data, class_id, raw_prediction

    results = ray_iter(
        load_logic_data_per_image,
        [(data_config, class_id, image_id) for class_id in range(10)
         for image_id in range(images_per_class)],
        chunksize=1,
        out_of_order=True,
        huge_task=True,
    )
    # Drop the images flagged as missing.
    results = [result for result in results if len(result[0]) > 0]

    logic_data = np.array([result[0] for result in results])
    label = np.array([result[1] for result in results])
    raw_prediction = np.array([result[2] for result in results])

    return logic_data, label, raw_prediction
Beispiel #7
0
    def get_overlap_ratio() -> pd.DataFrame:
        def get_row(class_id: int, image_id: int) -> Dict[str, Any]:
            nonlocal model_dir
            mode.check(False)
            data_dir = abspath(MNIST_PATH)
            model_dir = abspath(model_dir)
            ckpt_dir = f"{model_dir}/ckpts"
            create_model = lambda: LeNet(data_format="channels_first")
            graph = LeNet.graph().load()

            model_fn = partial(model_fn_with_fetch_hook,
                               create_model=create_model,
                               graph=graph)
            if mnist_dataset_mode == "test":
                dataset = mnist.test
            elif mnist_dataset_mode == "train":
                dataset = mnist.train
            else:
                raise RuntimeError("Dataset invalid")

            predicted_label = predict(
                create_model=create_model,
                # dataset may be train or test, should be consistent with lenet_mnist_example
                input_fn=lambda: dataset(data_dir, ).filter(
                    lambda image, label: tf.equal(
                        tf.convert_to_tensor(class_id, dtype=tf.int32), label))
                .skip(image_id).take(1).batch(1),
                model_dir=ckpt_dir,
            )
            if predicted_label != class_id:
                return [{}] if per_node else {}

            adversarial_example = lenet_mnist_example(
                attack_name=attack_name,
                attack_fn=attack_fn,
                generate_adversarial_fn=generate_adversarial_fn,
                class_id=class_id,
                image_id=image_id,
                # model_dir not ckpt_dir
                model_dir=model_dir,
                transforms=transforms,
                transform_name=transform_name,
                mode=mnist_dataset_mode,
            ).load()

            if adversarial_example is None:
                return [{}] if per_node else {}

            adversarial_predicted_label = predict(
                create_model=create_model,
                input_fn=lambda: tf.data.Dataset.from_tensors(
                    mnist.normalize(adversarial_example)),
                model_dir=ckpt_dir,
            )

            if predicted_label == adversarial_predicted_label:
                return [{}] if per_node else {}

            if use_class_trace:
                class_trace_avg = ClassTraceIOAction(predicted_label).load()
            else:
                class_trace_avg = None

            trace = reconstruct_trace_from_tf(
                class_id=class_id,
                model_fn=model_fn,
                input_fn=lambda: dataset(
                    data_dir,
                    transforms=transforms,
                ).filter(lambda image, label: tf.equal(
                    tf.convert_to_tensor(class_id, dtype=tf.int32), label)).
                skip(image_id).take(1).batch(1),
                select_fn=select_fn,
                model_dir=ckpt_dir,
                per_channel=per_channel,
                class_trace=class_trace_avg,
            )[0]

            if trace is None:
                return [{}] if per_node else {}
            path = os.path.join(save_dir, f"original_{transform_name}",
                                f"{class_id}", f"{image_id}.pkl")
            ensure_dir(path)
            with open(path, "wb") as f:
                pickle.dump(trace, f)

            adversarial_trace = reconstruct_trace_from_tf(
                model_fn=model_fn,
                input_fn=lambda: tf.data.Dataset.from_tensors(
                    mnist.normalize(adversarial_example)),
                select_fn=select_fn,
                model_dir=ckpt_dir,
                per_channel=per_channel,
                class_trace=class_trace_avg,
            )[0]

            path = os.path.join(save_dir, f"{attack_name}_{transform_name}",
                                f"{class_id}", f"{image_id}.pkl")
            ensure_dir(path)
            with open(path, "wb") as f:
                pickle.dump(adversarial_trace, f)

            row = {
                "class_id": class_id,
                "image_id": image_id,
                "trace": trace,
                "adversarial_trace": adversarial_trace,
            }
            # row = calc_all_overlap(
            #     class_trace_fn(class_id).load(), adversarial_trace, overlap_fn
            # )
            return row

        traces = ray_iter(
            get_row,
            ((class_id, image_id) for image_id in range(images_per_class)
             for class_id in range(0, 10)),
            # ((-1, image_id) for image_id in range(mnist_info.test().size)),
            chunksize=1,
            out_of_order=True,
            num_gpus=num_gpus,
        )
        traces = [trace for trace in traces if len(trace) != 0]

        return traces
Beispiel #8
0
def evaluate_by_NOT(attack_name):
    """Compare saved logics of original and adversarial test images.

    For each ``class_id`` in ``range(10)`` and ``image_id`` in
    ``range(test_images_per_class)`` the logics and raw predictions of the
    original image and of its ``attack_name`` counterpart are loaded. Pairs
    with complete data yield one record with both predictions, both
    ``predict_by_nonzero_filter`` labels and the number of differing entries
    per logic and key. The records are collected into a ``pd.DataFrame``.

    Args:
        attack_name: Name of the adversarial variant to evaluate.

    Returns:
        None. NOTE(review): the frame is built and ``st()`` is called, but
        nothing is returned.
    """
    class_nonzero_logics = load_class_filter_logics("nonzero_filter")

    def load_per_image_logics(
        class_id,
        image_id,
        attack_name,
    ):
        # All available logics of one image, or {} as soon as one file is
        # missing.
        loaded = {}
        for logic_name in available_logic_names:
            path = logic_save_path(
                logic_dir,
                logic_name,
                "test",
                attack_name,
                "noop",
                class_id,
                image_id,
            )
            if not os.path.exists(path):
                return {}
            with open(path, "rb") as f:
                loaded[logic_name] = pickle.load(f)
        return loaded

    def count_adversarial_logic_difference(
        original_logics,
        adversarial_logics,
    ):
        # Number of differing entries for each "<logic_name>.<key>".
        differences = {}
        for logic_name in available_logic_names:
            original_per_logic = original_logics[logic_name]
            adversarial_per_logic = adversarial_logics[logic_name]
            for key in original_per_logic:
                lhs = original_per_logic[key].to_tensor()
                rhs = adversarial_per_logic[key].to_tensor()
                differences[f"{logic_name}.{key}"] = (lhs != rhs).sum()
        return differences

    def eval_per_image(
        class_id,
        image_id,
    ):
        original_logics = load_per_image_logics(class_id, image_id, "original")
        original_pred = load_raw_prediction(class_id, image_id, "test",
                                            "original")
        adversarial_logics = load_per_image_logics(class_id, image_id,
                                                   attack_name)
        adversarial_pred = load_raw_prediction(class_id, image_id, "test",
                                               attack_name)

        # Skip pairs with a missing logic file or a -1 prediction.
        if (len(original_logics) == 0 or original_pred == -1
                or len(adversarial_logics) == 0 or adversarial_pred == -1):
            return {}

        logic_difference = count_adversarial_logic_difference(
            original_logics,
            adversarial_logics,
        )
        original_detection_label = predict_by_nonzero_filter(
            original_logics,
            class_nonzero_logics[original_pred],
        )
        adversarial_detection_label = predict_by_nonzero_filter(
            adversarial_logics,
            class_nonzero_logics[adversarial_pred],
        )
        return {
            "class_id": class_id,
            "imageid": image_id,
            "original.prediction": original_pred,
            "adversarial.prediction": adversarial_pred,
            "original.detection": original_detection_label,
            "adversarial.detection": adversarial_detection_label,
            **logic_difference,
        }

    work_items = []
    for class_id in range(10):
        for image_id in range(test_images_per_class):
            work_items.append((class_id, image_id))

    pending = ray_iter(
        eval_per_image,
        work_items,
        chunksize=1,
        out_of_order=True,
        huge_task=True,
    )
    records = [record for record in pending if len(record) > 0]
    results = pd.DataFrame(records)
    # NOTE(review): `st` is defined outside this block; presumably a debugger
    # entry point left in for interactive inspection of `results` - confirm.
    st()
def compute_birelation_class(
    class_id,
    images_per_class=10,
    relation_op="and",  # "and" or "or"
    attack_name="FGSM",
    transform_name="noop",
    saved_trace_dir=f"{model_dir}/training_trace",
    result_dir="result/test",
):
    """Compute pairwise channel relations for the traces of one class.

    For every image id in ``range(images_per_class)`` whose original and
    ``attack_name`` trace files both exist below ``saved_trace_dir``, the
    original trace is loaded and the masks of ``conv2d/Relu:0`` and
    ``conv2d_1/Relu:0`` are compared channel by channel. The list of
    per-image results is pickled to ``<result_dir>/<relation_op>/<class_id>.pkl``.

    Args:
        class_id: Class whose traces are processed.
        images_per_class: Number of image ids enumerated.
        relation_op: ``"and"`` (two channels have identical, non-empty dilated
            masks) or ``"or"`` (two dilated masks overlap).
        attack_name: Adversarial variant whose trace file must exist.
        transform_name: Suffix of the variant directory name.
        saved_trace_dir: Root directory of the saved traces.
        result_dir: Root directory for the output file.

    Returns:
        None. The results are written to disk.

    Raises:
        ValueError: If ``relation_op`` is neither ``"and"`` nor ``"or"``.
            Previously this surfaced as an unbound ``cmp_result`` inside the
            worker.
    """
    if relation_op not in ("and", "or"):
        raise ValueError(
            f"relation_op must be 'and' or 'or', got {relation_op!r}")

    conv_node_names = ("conv2d/Relu:0", "conv2d_1/Relu:0")

    def load_original_trace(image_id):
        """Return the original trace, or ``None`` unless both trace files exist."""
        file_name = f"{image_id}.pkl"
        original_path = os.path.join(
            saved_trace_dir,
            f"original_{transform_name}",
            f"{class_id}",
            file_name,
        )
        adversarial_path = os.path.join(
            saved_trace_dir,
            f"{attack_name}_{transform_name}",
            f"{class_id}",
            file_name,
        )
        # The adversarial file only has to exist; its content is not needed.
        if not (os.path.exists(original_path)
                and os.path.exists(adversarial_path)):
            return None
        # NOTE: pickle must only be used on trusted, locally produced files.
        with open(original_path, "rb") as f:
            return pickle.load(f)

    def reconstruct_masks(original_trace, reconstruct_point_fn):
        """Map each node of the trace that contains ``key`` to its mask."""
        return {
            node_name: reconstruct_point_fn(node_name=node_name,
                                            trace=original_trace)
            for node_name in sorted(original_trace.nodes)
            if key in original_trace.nodes[node_name]
        }

    def compute_relation_per_pair(
        image_id,
        reconstruct_point_fn,
        relation_op,
    ):
        """Return the channel relations of one image, or ``{}`` if files are missing."""
        original_trace = load_original_trace(image_id)
        if original_trace is None:
            return {}
        op_to_trace = reconstruct_masks(original_trace, reconstruct_point_fn)

        intro_conv_cmp = {}
        channel_info = {}
        for node_name in conv_node_names:
            raw_mask = op_to_trace[node_name]
            mask = ndimage.binary_dilation(
                raw_mask,
                iterations=dilation_iter,
            ).astype(raw_mask.dtype)
            filter_number = mask.shape[0]
            if relation_op == "and":
                # Per-channel activity is pair-independent; compute it once.
                is_active = [mask[i].sum() > 0 for i in range(filter_number)]
                cmp_result = [
                    (mask[i] == mask[j]).all() and is_active[i]
                    and is_active[j]
                    for i in range(filter_number)
                    for j in range(i + 1, filter_number)
                ]
            elif relation_op == "or":
                cmp_result = [(mask[i] * mask[j] > 0).any()
                              for i in range(filter_number)
                              for j in range(i + 1, filter_number)]
            else:
                raise ValueError(
                    f"relation_op must be 'and' or 'or', got {relation_op!r}")
            intro_conv_cmp[node_name] = np.array(cmp_result).astype(np.uint8)

            # Channel activity uses the mask before dilation.
            channel = raw_mask.sum(-1).sum(-1)
            channel[channel > 0] = 1
            channel_info[node_name] = channel

        # Relations between channels of the two convolution layers.
        mask1 = op_to_trace["conv2d/Relu:0"]
        mask2 = op_to_trace["conv2d_1/Relu:0"]
        h, w = mask1.shape[-2:]
        c = mask2.shape[0]
        # np.resize returns a new array of the requested shape, unlike the
        # in-place ndarray.resize.
        mask2 = np.resize(mask2, (c, h, w))
        inter_conv_relation = [(mask1[i] * mask2[j] > 0).any()
                               for i in range(mask1.shape[0])
                               for j in range(mask2.shape[0])]
        inter_conv_relation = np.array(inter_conv_relation).astype(np.uint8)

        return {
            "Channel": channel_info,
            "IntroConv": intro_conv_cmp,
            "image_id": image_id,
            "InterConv": inter_conv_relation,
        }

    def filter_valid_filter(
        image_id,
        reconstruct_point_fn,
    ):
        """Return, per conv node, which channels have a non-empty mask.

        Not called within this function.
        """
        original_trace = load_original_trace(image_id)
        if original_trace is None:
            return {}
        op_to_mask = reconstruct_masks(original_trace, reconstruct_point_fn)
        return {
            node_name: op_to_mask[node_name].sum(-1).sum(-1) > 0
            for node_name in conv_node_names
        }

    ray_params = [(
        image_id,
        reconstruct_point_fn,
        relation_op,
    ) for image_id in range(images_per_class)]

    results = ray_iter(
        compute_relation_per_pair,
        ray_params,
        chunksize=1,
        out_of_order=True,
        huge_task=True,
    )
    results = [result for result in results if len(result) > 0]
    print(
        f"Class {class_id} op {relation_op}: {len(results)}/{images_per_class} valid samples"
    )
    relation_dir = os.path.join(result_dir, relation_op)
    os.makedirs(relation_dir, exist_ok=True)
    relation_path = os.path.join(relation_dir, f"{class_id}.pkl")
    with open(relation_path, "wb") as f:
        pickle.dump(results, f)