def attribute(self,
               x: GraphsTuple,
               model: TransparentModel,
               task_index: Optional[int] = None,
               batch_index: Optional[int] = None) -> List[GraphsTuple]:
     """Gets integrated-gradients attributions for every node and edge.

     The reference graph from ``self.make_reference(x)`` and ``x`` are
     interpolated over ``self.num_steps`` steps with
     ``graph_utils.interpolate_graphs_tuple``. At every interpolated point the
     model gradient is dotted row-wise with the matching step tensor, and the
     per-step values are summed into one scalar per node/edge.

     Args:
       x: Graphs to explain.
       model: Model whose ``get_gradient`` is evaluated on the interpolations.
       task_index: Passed through to ``model.get_gradient``.
       batch_index: Passed through to ``model.get_gradient``.

     Returns:
       ``x`` with scalar node/edge attributions and ``globals=None``, split
       into one GraphsTuple per graph.
     """
     steps = self.num_steps
     reference = self.make_reference(x)
     total_nodes = tf.reduce_sum(x.n_node)
     total_edges = tf.reduce_sum(x.n_edge)
     path, node_deltas, edge_deltas = graph_utils.interpolate_graphs_tuple(
         reference, x, steps)
     node_grads, edge_grads = model.get_gradient(path, task_index, batch_index)

     def _sum_over_steps(grads, deltas, count):
         # [steps * count, d] -> [steps * count]: row-wise dot product.
         per_row = tf.einsum('ij,ij->i', grads, deltas)
         # The reshape treats rows as step-major: -> [steps, count] -> [count, steps].
         per_element = tf.transpose(tf.reshape(per_row, (steps, count)))
         # [count, steps] -> [count].
         return tf.reduce_sum(per_element, axis=1)

     graphs = x.replace(
         nodes=_sum_over_steps(node_grads, node_deltas, total_nodes),
         edges=_sum_over_steps(edge_grads, edge_deltas, total_edges),
         globals=None)
     return list(graph_utils.split_graphs_tuple(graphs))
 def attribute(self,
               x: GraphsTuple,
               model: TransparentModel,
               task_index: Optional[int] = None,
               batch_index: Optional[int] = None) -> List[GraphsTuple]:
     """Gets random attributions, intended as a baseline.

     Each node and each edge gets an independent draw from
     ``np.random.uniform`` (numpy's default range [0, 1)). ``model``,
     ``task_index`` and ``batch_index`` are not used.

     Returns:
       ``x`` with random node/edge scores and ``globals=None``, split into one
       GraphsTuple per graph.
     """
     node_count = x.nodes.shape[0]
     edge_count = x.edges.shape[0]
     # Draw node scores before edge scores (same RNG consumption order).
     node_scores = np.random.uniform(size=node_count)
     edge_scores = np.random.uniform(size=edge_count)
     scored = x.replace(nodes=node_scores, edges=edge_scores, globals=None)
     return list(graph_utils.split_graphs_tuple(scored))
 def attribute(self,
               x: GraphsTuple,
               model: TransparentModel,
               task_index: Optional[int] = None,
               batch_index: Optional[int] = None) -> List[GraphsTuple]:
     """Gets attention-weight attributions for edges.

     The model's per-block attention weights are stacked, reduced over heads
     with ``self.head_reducer`` and then over blocks with
     ``self.block_reducer``. Nodes receive an all-zero attribution.
     ``task_index`` and ``batch_index`` are not used.

     Returns:
       ``x`` with zero node scores, reduced edge attention scores and
       ``globals=None``, split into one GraphsTuple per graph.
     """
     stacked = tf.stack(model.get_attention_weights(x))  # [n_blocks, n_edges, n_heads]
     per_block = self.head_reducer(stacked, axis=2)  # [n_blocks, n_edges]
     edge_scores = self.block_reducer(per_block, axis=0)  # [n_edges]
     node_scores = tf.zeros(len(x.nodes))
     scored = x.replace(nodes=node_scores, edges=edge_scores, globals=None)
     return list(graph_utils.split_graphs_tuple(scored))
 def attribute(self,
               x: GraphsTuple,
               model: TransparentModel,
               task_index: Optional[int] = None,
               batch_index: Optional[int] = None) -> List[GraphsTuple]:
     """Gets gradient-times-input attributions for nodes and edges.

     Each node/edge score is the row-wise dot product of its input features
     with the model gradient with respect to those features.

     Returns:
       ``x`` with scalar node/edge scores and ``globals=None``, split into one
       GraphsTuple per graph.
     """
     grad_wrt_nodes, grad_wrt_edges = model.get_gradient(
         x, task_index, batch_index)
     row_dot = 'ij,ij->i'
     scored = x.replace(
         nodes=tf.einsum(row_dot, x.nodes, grad_wrt_nodes),
         edges=tf.einsum(row_dot, x.edges, grad_wrt_edges),
         globals=None)
     return list(graph_utils.split_graphs_tuple(scored))
 def attribute(self,
               x: GraphsTuple,
               model: TransparentModel,
               task_index: Optional[int] = None,
               batch_index: Optional[int] = None) -> List[GraphsTuple]:
     """Gets class-activation-map style attributions for nodes and edges.

     Node and edge activations from ``model.get_gap_activations(x)`` are
     projected onto ``model.get_prediction_weights()``, giving one scalar per
     node/edge.

     Args:
       x: Graphs to explain.
       model: Model exposing ``get_gap_activations`` and
         ``get_prediction_weights``.
       task_index: Unused here (the signature matches the other attribute
         methods).
       batch_index: Unused here.

     Returns:
       ``x`` with scalar node/edge scores and ``globals=None``, split into one
       GraphsTuple per graph.
     """
     node_act, edge_act = model.get_gap_activations(x)
     weights = model.get_prediction_weights()
     # NOTE(review): 'ij,j->i' requires the prediction weights to be a 1-D
     # vector. task_index is not used to pick a task column, so a multi-task
     # weight matrix would not be handled here — confirm.
     node_weights = tf.einsum('ij,j->i', node_act, weights)
     edge_weights = tf.einsum('ij,j->i', edge_act, weights)
     graphs = x.replace(
         nodes=node_weights,
         edges=edge_weights,
         globals=None)
     return list(graph_utils.split_graphs_tuple(graphs))
 def attribute(self,
               x: GraphsTuple,
               model: TransparentModel,
               task_index: Optional[int] = None,
               batch_index: Optional[int] = None) -> List[GraphsTuple]:
     """Gets SmoothGrad-style attributions by averaging over noisy copies.

     ``self.num_samples`` perturbed copies of ``x`` (noise scale
     ``self.sigma``) are built with ``graph_utils.perturb_graphs_tuple``.
     They are attributed with ``self.method``, and the node/edge scores are
     averaged across copies.

     Returns:
       ``x`` with averaged node/edge scores and ``globals=None``, split into
       one GraphsTuple per graph.
     """
     num_samples = self.num_samples
     total_nodes = int(tf.reduce_sum(x.n_node))
     total_edges = int(tf.reduce_sum(x.n_edge))
     noisy_graphs = graph_utils.perturb_graphs_tuple(x, num_samples, self.sigma)
     per_graph = self.method.attribute(
         noisy_graphs, model, task_index, batch_index)
     merged = graph_nets.utils_tf.concat(per_graph, axis=0)

     def _average(values, count):
         # The reshape treats the noisy copies as laid out sample-major:
         # [num_samples * count] -> [num_samples, count] -> [count].
         return tf.reduce_mean(tf.reshape(values, (num_samples, count)), axis=0)

     node_weights = _average(merged.nodes, total_nodes)
     edge_weights = _average(merged.edges, total_edges)
     averaged = x.replace(nodes=node_weights, edges=edge_weights, globals=None)
     return list(graph_utils.split_graphs_tuple(averaged))
    def attribute(self,
                  x: GraphsTuple,
                  model: TransparentModel,
                  task_index: Optional[int] = None,
                  batch_index: Optional[int] = None) -> List[GraphsTuple]:
        """Gets Grad-CAM style attributions from intermediate layers.

        For each selected layer, the row-wise dot product of node (and edge)
        activations with their gradients is computed. The per-layer results
        are then combined with ``self.reduce_fn`` along axis 0. Only the last
        layer is used when ``self.last_layer_only`` is set; otherwise every
        layer is used.

        Returns:
          ``x`` with combined node/edge scores and ``globals=None``, split into
          one GraphsTuple per graph.
        """
        activations, gradients, _ = model.get_intermediate_activations_gradients(
            x, task_index, batch_index)
        if self.last_layer_only:
            selected_layers = [-1]
        else:
            selected_layers = list(range(len(activations)))

        node_terms = []
        edge_terms = []
        for layer in selected_layers:
            node_act, edge_act = activations[layer]
            node_grad, edge_grad = gradients[layer]
            node_terms.append(tf.einsum('ij,ij->i', node_act, node_grad))
            edge_terms.append(tf.einsum('ij,ij->i', edge_act, edge_grad))

        combined = x.replace(
            nodes=self.reduce_fn(node_terms, axis=0),
            edges=self.reduce_fn(edge_terms, axis=0),
            globals=None)
        return list(graph_utils.split_graphs_tuple(combined))