def forward(self, x, data):
        """
        Scatters pooled node features back to all ``data.num_nodes`` nodes.

        The most recent level is read from the top of the ``data.unpool_nodes`` /
        ``data.unpool_edges`` / ``data.unpool_connection`` stacks. It is popped
        from those stacks unless it is the only level left. Every node in
        ``unpool_edges[1]`` receives the features of the pooled node
        ``unpool_edges[0]`` it is linked to. Nodes that are not targets of any
        unpool edge stay zero. Afterwards, stream 1 (if present) is rotated back
        with the stored connection; its imaginary part is negated, presumably to
        undo the rotation applied during pooling.

        :param x: features of the pooled nodes. The indexing below implies rank 4,
            presumably [n_pooled_nodes, n_streams, channels, 2 (re/im)] - confirm.
        :param data: graph object providing ``num_nodes`` and the ``unpool_*`` stacks.
        :return: features for all ``data.num_nodes`` nodes, with the same trailing
            dimensions and dtype as ``x``.
        """
        # Take the most recent pooling level; pop it unless it is the last one.
        unpool_nodes = data.unpool_nodes[-1]
        unpool_edges = data.unpool_edges[-1]
        unpool_connection = data.unpool_connection[-1]
        if len(data.unpool_nodes) > 1:
            data.unpool_nodes = data.unpool_nodes[:-1]
            data.unpool_edges = data.unpool_edges[:-1]
            data.unpool_connection = data.unpool_connection[:-1]

        # Map original node indices to their row in the pooled feature tensor.
        # Allocate on the index device so CUDA graphs do not hit a device mismatch.
        index_device = unpool_nodes.device
        unpool_map = torch.zeros(data.num_nodes, dtype=torch.long, device=index_device)
        unpool_map[unpool_nodes] = torch.arange(unpool_nodes.size(0), device=index_device)

        # Copy each pooled node's features to the nodes it is linked to.
        x_sh = [data.num_nodes] + list(x.size()[1:])
        new_x = x.new_zeros(x_sh)
        new_x[unpool_edges[1]] = x[unpool_map[unpool_edges[0]]]
        x = new_x

        # Reorder connections so that row i belongs to node i.
        # NOTE(review): this assumes every node appears exactly once in
        # unpool_edges[1] - confirm with the pooling code.
        connection = unpool_connection[unpool_edges[1].argsort()]
        if x.size(1) > 1:
            x[:, 1, :, 0], x[:, 1, :, 1] = complex_product(
                x[:, 1, :, 0], x[:, 1, :, 1],
                connection[:, None, 0], -connection[:, None, 1])
        return x
    def message(self, x_j, precomp, connection):
        """
        Locally aligns features with parallel transport (using connection) and
        applies the precomputed component of the circular harmonic filter to each
        neighbouring node (the target nodes).

        The caller's ``x_j`` is not modified; the transported values are kept in
        local tensors.

        :param x_j: the feature vector of the target neighbours [n_edges, prev_order + 1, in_channels, 2]
        :param precomp: the precomputed part of harmonic networks [n_edges, max_order + 1, n_rings, 2].
        :param connection: the connection encoding parallel transport for each edge [n_edges, 2],
            or None to skip parallel transport.
        :return: the message from each target to the source nodes [n_edges, n_rings, in_channels, prev_order + 1, max_order + 1, 2]
        """
        (N, M, F, C), R = x_j.size(), self.n_rings
        n_out = self.max_order + 1

        # Allocate with the input's device and dtype instead of hard-coding CUDA.
        res = x_j.new_zeros(N, R, F, M, n_out, C)

        for input_order in range(M):
            # Real/imaginary parts of this stream with a singleton ring axis: [N, 1, in_channels]
            x_re = x_j[:, input_order, None, :, 0]
            x_im = x_j[:, input_order, None, :, 1]

            # Parallel transport for non-zero orders, computed out of place.
            if connection is not None and input_order > 0:
                rot_re = connection[:, None, None, 0]
                rot_im = connection[:, None, None, 1]
                x_re, x_im = complex_product(x_re, x_im, rot_re, rot_im)

            for output_order in range(n_out):
                m = output_order - input_order
                sign = np.sign(m)
                m = np.abs(m)

                # Kernel for |m|, conjugated when the order difference is negative.
                # A zero difference zeroes the imaginary part (sign == 0). Shape [N, n_rings, 1].
                kernel_re = precomp[:, m, :, 0, None]
                kernel_im = sign * precomp[:, m, :, 1, None]

                out_re, out_im = complex_product(x_re, x_im, kernel_re, kernel_im)
                res[:, :, :, input_order, output_order, 0] = out_re
                res[:, :, :, input_order, output_order, 1] = out_im

        return res
    def message(self, x_j, connection):
        """
        Applies connection to each neighbour, before aggregating for pooling.

        Only stream 1 (the second entry of dimension 1) is rotated, and only when
        it exists. ``x_j`` is modified in place and returned.

        :param x_j: neighbour features. The indexing implies rank 4, presumably
            [n_edges, n_streams, channels, 2 (re/im)] - confirm.
        :param connection: per-edge connection; columns 0 and 1 are passed to
            ``complex_product`` as the real and imaginary rotation parts.
        :return: ``x_j`` after parallel transport of stream 1.
        """
        if x_j.size(1) > 1:
            rot_re, rot_im = complex_product(
                x_j[:, 1, :, 0], x_j[:, 1, :, 1],
                connection[:, None, 0], connection[:, None, 1])
            x_j[:, 1, :, 0] = rot_re
            x_j[:, 1, :, 1] = rot_im

        return x_j
    def update(self, aggr_out):
        """
        Updates node embeddings with circular harmonic filters.
        This is done separately for each rotation order stream.

        For every (input order, output order) pair, the radial profile is applied
        across rings. An optional learned phase offset is then applied, and the
        input channels are summed. The input streams are finally summed per output
        stream.

        :param aggr_out: the result of the aggregation operation [n_nodes, n_rings, in_channels, prev_order + 1, max_order + 1, complex]
        :return: the new feature vector for x [n_nodes, max_order + 1, out_channels, complex]
        """
        N, M = aggr_out.size(0), aggr_out.size(3)
        O = self.out_channels
        n_out = self.max_order + 1

        # Allocate with the input's device and dtype instead of hard-coding CUDA.
        res = aggr_out.new_zeros(N, M, n_out, O, 2)

        for input_order in range(M):
            for output_order in range(n_out):
                m = np.abs(output_order - input_order)
                # Separate streams: one weight set per (input, output) pair.
                # Otherwise weights are shared by the absolute order difference.
                m_idx = input_order * n_out + output_order if self.separate_streams else m

                aggr_re = aggr_out[:, :, None, :, input_order, output_order, 0]  # [N, n_rings, 1, in_channels]
                aggr_im = aggr_out[:, :, None, :, input_order, output_order, 1]  # [N, n_rings, 1, in_channels]

                # Apply the radial profile and sum over rings
                aggr_re = (self.radial_profile[m_idx] * aggr_re).sum(dim=1)  # [N, out_channels, in_channels]
                aggr_im = (self.radial_profile[m_idx] * aggr_im).sum(dim=1)  # [N, out_channels, in_channels]

                # Apply phase offset
                if self.offset:
                    cos = torch.cos(self.phase_offset[m_idx])  # [out_channels, in_channels]
                    sin = torch.sin(self.phase_offset[m_idx])  # [out_channels, in_channels]
                    aggr_re, aggr_im = complex_product(aggr_re, aggr_im, cos, sin)

                # Store per rotation stream, summing over input channels
                res[:, input_order, output_order, :, 0] = aggr_re.sum(dim=-1)
                res[:, input_order, output_order, :, 1] = aggr_im.sum(dim=-1)

        # The input streams are summed together to retrieve one value per output stream
        return res.sum(dim=1)