def forward(self, x, data):
    """
    Unpools features from the pooled nodes back onto every node of the finer level.

    The unpooling information for the current level is read from the top (last
    entry) of the stacks ``data.unpool_nodes``, ``data.unpool_edges`` and
    ``data.unpool_connection``. That entry is then removed from all three stacks,
    except when it is the only one left, in which case the stacks are kept as is.

    :param x: features of the pooled nodes. The indexing below implies
        [n_pooled_nodes, n_streams, channels, 2] with (real, imag) in the last
        axis. TODO confirm against callers.
    :param data: graph object carrying ``num_nodes`` and the three unpool stacks.
        It is modified in place (the stacks are popped).
    :return: features for all ``data.num_nodes`` nodes, with the same trailing
        shape, dtype and device as ``x``. Every node that is the target of an
        unpool edge receives the features of the pooled node on the other end of
        that edge; nodes that are not targeted stay zero. If ``x`` has more than
        one stream, stream 1 is multiplied by the complex conjugate of the stored
        connection.
    """
    # Take the current level from the top of the stacks.
    unpool_nodes = data.unpool_nodes[-1]
    unpool_edges = data.unpool_edges[-1]
    unpool_connection = data.unpool_connection[-1]
    # Pop the level, but never empty the stacks completely.
    if len(data.unpool_nodes) > 1:
        data.unpool_nodes = data.unpool_nodes[:-1]
        data.unpool_edges = data.unpool_edges[:-1]
        data.unpool_connection = data.unpool_connection[:-1]

    # Map node indices of the finer level to row indices of x (pooled nodes).
    # Allocate on x's device so the map can be indexed with (and used to index)
    # tensors that live on the GPU.
    device = x.device
    unpool_map = torch.zeros(data.num_nodes, dtype=torch.long, device=device)
    unpool_map[unpool_nodes] = torch.arange(unpool_nodes.size(0), device=device)

    # Copy each pooled node's features to the nodes it is linked to:
    # unpool_edges[0] holds pooled (source) nodes, unpool_edges[1] receiving nodes.
    # new_zeros keeps x's dtype and device.
    new_x = x.new_zeros([data.num_nodes] + list(x.size()[1:]))
    new_x[unpool_edges[1]] = x[unpool_map[unpool_edges[0]]]
    x = new_x

    # Apply parallel transport to correct the rotation of stream 1.
    if x.size(1) > 1:
        # Reorder per-edge connections into per-node order. This assumes every
        # node is the target of exactly one unpool edge. TODO confirm.
        connection = unpool_connection[unpool_edges[1].argsort()]
        # Negated imaginary part: multiply by the conjugate of the connection.
        x[:, 1, :, 0], x[:, 1, :, 1] = complex_product(
            x[:, 1, :, 0], x[:, 1, :, 1],
            connection[:, None, 0], -connection[:, None, 1])
    return x
def message(self, x_j, precomp, connection):
    """
    Locally aligns features with parallel transport (using connection) and
    applies the precomputed component of the circular harmonic filter to each
    neighbouring node (the target nodes).

    :param x_j: the feature vector of the target neighbours
        [n_edges, prev_order + 1, in_channels, 2]. It is not modified.
    :param precomp: the precomputed part of harmonic networks
        [n_edges, max_order + 1, n_rings, 2].
    :param connection: the connection encoding parallel transport for each
        edge [n_edges, 2], or None to skip parallel transport.
    :return: the message from each target to the source nodes
        [n_edges, n_rings, in_channels, prev_order + 1, max_order + 1, 2]
    """
    (N, M, F, C), R = x_j.size(), self.n_rings
    n_out = self.max_order + 1

    # Allocate with x_j's device and dtype so the layer runs on CPU and GPU alike.
    res = x_j.new_zeros(N, R, F, M, n_out, C)

    # Compute the convolutions per stream
    for input_order in range(M):
        # Real / imaginary parts of this input stream with a singleton ring
        # axis for broadcasting against precomp: [N, 1, in_channels]
        x_re = x_j[:, input_order, None, :, 0]
        x_im = x_j[:, input_order, None, :, 1]

        # First apply parallel transport (out of place, so x_j is left intact).
        # NOTE(review): every order > 0 receives the same single rotation; if
        # streams of order >= 2 are ever used they may need connection**order.
        # Confirm before raising prev_order above 1.
        if connection is not None and input_order > 0:
            x_re, x_im = complex_product(
                x_re, x_im,
                connection[:, None, None, 0], connection[:, None, None, 1])

        # Next, apply the precomputed filter component for each output order
        for output_order in range(n_out):
            m = output_order - input_order
            sign = (m > 0) - (m < 0)
            order = abs(m)
            # Negative order differences use the conjugate filter (sign = -1
            # flips the imaginary part); for m == 0, sign = 0 zeroes it.
            res[:, :, :, input_order, output_order, 0], res[:, :, :, input_order, output_order, 1] = complex_product(
                x_re, x_im,
                precomp[:, order, :, 0, None], sign * precomp[:, order, :, 1, None])
    return res
def message(self, x_j, connection):
    """
    Applies connection to each neighbour, before aggregating for pooling.

    :param x_j: neighbour features. The indexing below implies
        [n_edges, n_streams, channels, 2] with (real, imag) in the last axis.
        TODO confirm against callers.
    :param connection: per-edge connection [n_edges, 2] as (real, imag).
    :return: ``x_j`` itself. When it has more than one stream, stream 1 has been
        multiplied by the connection in place; other streams are untouched.
    """
    # Apply parallel transport to features from stream 1 only.
    # NOTE(review): streams of order >= 2 (if any) are not transported here.
    if x_j.size(1) > 1:
        rot_re = connection[:, None, 0]
        rot_im = connection[:, None, 1]
        x_j[:, 1, :, 0], x_j[:, 1, :, 1] = complex_product(
            x_j[:, 1, :, 0], x_j[:, 1, :, 1], rot_re, rot_im)
    return x_j
def update(self, aggr_out):
    """
    Updates node embeddings with circular harmonic filters. This is done
    separately for each rotation order stream.

    :param aggr_out: the result of the aggregation operation
        [n_nodes, n_rings, in_channels, prev_order + 1, max_order + 1, complex]
    :return: the new feature vector for x
        [n_nodes, max_order + 1, out_channels, complex]
    """
    (N, _, _, M, _, _), O = aggr_out.size(), self.out_channels
    n_out = self.max_order + 1

    # Allocate with aggr_out's device and dtype so the layer runs on CPU and GPU alike.
    res = aggr_out.new_zeros(N, M, n_out, O, 2)
    for input_order in range(M):
        for output_order in range(n_out):
            m = abs(output_order - input_order)
            # With separate streams each (input, output) order pair has its own
            # radial profile / phase offset; otherwise they are shared per |m|.
            m_idx = input_order * n_out + output_order if self.separate_streams else m

            aggr_re = aggr_out[:, :, None, :, input_order, output_order, 0]  # [N, n_rings, 1, in_channels]
            aggr_im = aggr_out[:, :, None, :, input_order, output_order, 1]  # [N, n_rings, 1, in_channels]

            # Apply the radial profile, summing over rings
            profile = self.radial_profile[m_idx]
            aggr_re = (profile * aggr_re).sum(dim=1)  # [N, out_channels, in_channels]
            aggr_im = (profile * aggr_im).sum(dim=1)  # [N, out_channels, in_channels]

            # Apply phase offset as a rotation by exp(i * phase)
            if self.offset:
                phase = self.phase_offset[m_idx]  # [out_channels, in_channels]
                aggr_re, aggr_im = complex_product(
                    aggr_re, aggr_im, torch.cos(phase), torch.sin(phase))

            # Store per rotation stream, summing over input channels
            res[:, input_order, output_order, :, 0] = aggr_re.sum(dim=-1)
            res[:, input_order, output_order, :, 1] = aggr_im.sum(dim=-1)

    # The input streams are summed together to retrieve one value per output stream
    return res.sum(dim=1)