def forward(self, features_1, features_2=None):
    '''
    Apply the tensor product, choosing the contraction path from ``self._complete``.

    :param features_1: [..., in1] or [..., in1, in2] if features_2 is None
    :param features_2: [..., in2]
    :return: [..., out]
    '''
    n_out = dim(self.Rs_out)
    n_in1 = dim(self.Rs_in1)
    n_in2 = dim(self.Rs_in2)

    if self._complete == 'out' or features_2 is None:
        # Either the caller already supplies the outer product, or we build it here.
        outer = (
            features_1
            if features_2 is None
            else features_1[..., :, None] * features_2[..., None, :]
        )
        lead_shape = outer.shape[:-2]
        outer = outer.reshape(-1, n_in1, n_in2)  # [batch, in1, in2]
        mixing = get_sparse_buffer(self, "mixing_matrix")  # [out, in1 * in2]
        # Put batch last so the sparse matrix can be applied from the left.
        outer = torch.einsum('zij->ijz', outer)  # [in1, in2, batch]
        columns = outer.reshape(n_in1 * n_in2, outer.shape[2])  # [in1 * in2, batch]
        mixed = mixing @ columns  # [out, batch]
        return mixed.T.reshape(*lead_shape, n_out)

    if self._complete == 'in1':
        partial = self.left(features_1)  # [..., out, in2]
        return torch.einsum('...ij,...j->...i', partial, features_2)

    if self._complete == 'in2':
        partial = self.right(features_2)  # [..., out, in1]
        return torch.einsum('...ij,...j->...i', partial, features_1)

    # Any other value of self._complete falls through and yields None.
    return None
def forward(self, features):
    '''
    Apply the sparse ``mixing_matrix`` buffer along the last axis of ``features``.

    :param features: tensor whose last axis is contracted with the mixing matrix
    :return: tensor reshaped to [<leading axes of features>, self.mul, -1]
    '''
    *lead_shape, width = features.size()
    flat = features.reshape(-1, width)  # [batch, width]
    mixing = get_sparse_buffer(self, 'mixing_matrix')
    # The matrix multiplies from the left, so transpose in and back out.
    mixed = mixing @ flat.T
    mixed = mixed.T
    return mixed.reshape(*lead_shape, self.mul, -1)
def to_dense(self):
    """
    Materialise the sparse mixing matrix as a dense rank-3 tensor.

    :return: tensor of shape [dim(Rs_out), dim(Rs_in1), dim(Rs_in2)]
    """
    sparse = get_sparse_buffer(self, "mixing_matrix")  # [out, in1 * in2]
    dense = sparse.to_dense()
    target_shape = (dim(self.Rs_out), dim(self.Rs_in1), dim(self.Rs_in2))
    return dense.reshape(*target_shape)
def forward(self, features):
    '''
    Mix the outer product of ``features`` with itself through the sparse
    ``mixing_matrix`` buffer.

    :param features: [..., channels]
    :return: tensor with the leading axes of ``features`` and one trailing axis
    '''
    *lead_shape, channels = features.size()
    flat = features.reshape(-1, channels)  # [batch, channels]
    mixing = get_sparse_buffer(self, "mixing_matrix")
    # Outer product of every row with itself, with batch placed last.
    squared = torch.einsum('zi,zj->ijz', flat, flat)  # [channels, channels, batch]
    columns = squared.reshape(-1, squared.shape[2])  # [channels * channels, batch]
    mixed = mixing @ columns
    return mixed.T.reshape(*lead_shape, -1)
def forward(self, features_1, features_2):
    """
    Contract two inputs with a kernel built from the sparse buffer ``T`` and
    ``self.kernel()``.

    Both inputs must share the same leading axes.

    :param features_1: tensor whose last axis has size rs.dim(self.Rs_in1)
    :param features_2: tensor with the same leading axes as features_1
    :return: tensor [..., channel]
    """
    *lead_1, width_1 = features_1.size()
    flat_1 = features_1.reshape(-1, width_1)
    assert width_1 == rs.dim(self.Rs_in1)

    *lead_2, width_2 = features_2.size()
    flat_2 = features_2.reshape(-1, width_2)
    assert lead_1 == lead_2

    # NOTE(review): the layout of T is not visible from this block; confirm
    # against where the buffer is registered.
    T = get_sparse_buffer(self, 'T')
    projected = (T.t() @ self.kernel().T).T
    kernel = projected.reshape(
        rs.dim(self.Rs_out),
        rs.dim(self.Rs_in1),
        rs.dim(self.Rs_in2),
    )  # [out, in1, in2]

    out = torch.einsum('kij,zi,zj->zk', kernel, flat_1, flat_2)
    return out.reshape(*lead_1, -1)
def right(self, features_2):
    '''
    Contract the mixing matrix with the second input only.

    :param features_2: [..., in2]
    :return: [..., out, in1]
    '''
    n_out = dim(self.Rs_out)
    n_in1 = dim(self.Rs_in1)
    n_in2 = dim(self.Rs_in2)

    lead_shape = features_2.shape[:-1]
    flat = features_2.reshape(-1, n_in2)  # [batch, in2]

    mixing = get_sparse_buffer(self, "mixing_matrix")  # [out, in1 * in2]
    # Fold in1 into the row axis so that only in2 is contracted.
    mixing = mixing.sparse_reshape(n_out * n_in1, n_in2)  # [out * in1, in2]

    contracted = mixing @ flat.T  # [out * in1, batch]
    return contracted.T.reshape(*lead_shape, n_out, n_in1)
def forward(self, features):
    '''
    Apply a learned quadratic form to ``features``. When ``self.linear`` is
    set, a constant 1 is prepended to every row before the quadratic form.

    :param features: [..., channels]
    :return: tensor with the leading axes of ``features`` and one trailing axis
    '''
    *lead_shape, width = features.size()
    flat = features.reshape(-1, width)  # [batch, channels]
    assert width == rs.dim(self.Rs_in)

    if self.linear:
        # Prepending a column of ones adds constant and linear terms to the
        # products formed below.
        ones = flat.new_ones(flat.shape[0], 1)
        flat = torch.cat([ones, flat], dim=1)
        width += 1

    # NOTE(review): the layout of T is not visible from this block; confirm
    # against where the buffer is registered.
    T = get_sparse_buffer(self, 'T')
    projected = (T.t() @ self.kernel().T).T
    kernel = projected.reshape(rs.dim(self.Rs_out), width, width)  # [out, in, in]

    pairs = torch.einsum('zi,zj->zij', flat, flat)  # [batch, in, in]
    out = torch.einsum('kij,zij->zk', kernel, pairs)  # [batch, out]
    return out.reshape(*lead_shape, -1)
def left(self, features_1):
    '''
    Contract the mixing matrix with the first input only.

    :param features_1: [..., in1]
    :return: [..., out, in2]
    '''
    d_out = dim(self.Rs_out)
    d_in1 = dim(self.Rs_in1)
    d_in2 = dim(self.Rs_in2)

    size_1 = features_1.shape[:-1]
    features_1 = features_1.reshape(-1, d_in1)  # [batch, in1]

    mixing_matrix = get_sparse_buffer(self, "mixing_matrix")  # [out, in1 * in2]
    # Bring in1 to the trailing axis so it is the one contracted below.
    mixing_matrix = mixing_matrix.sparse_reshape(d_out * d_in1, d_in2).t()  # [in2, out * in1]
    mixing_matrix = mixing_matrix.sparse_reshape(d_in2 * d_out, d_in1)  # [in2 * out, in1]

    output = mixing_matrix @ features_1.T  # [in2 * out, batch]
    output = output.reshape(d_in2, d_out, features_1.shape[0])  # [in2, out, batch]
    # j = in2, i = out, z = batch. The previous equation 'j*z->zij' was not a
    # valid einsum string ('*' is not a subscript letter) and raised at runtime.
    output = torch.einsum('jiz->zij', output)  # [batch, out, in2]
    return output.reshape(*size_1, d_out, d_in2)
def forward(self, features_1, features_2):
    '''
    Tensor product of two inputs through the sparse ``mixing_matrix`` buffer.

    :param features_1: [..., in1]
    :param features_2: [..., in2]
    :return: [..., out]
    '''
    n_out = dim(self.Rs_out)
    n_in1 = dim(self.Rs_in1)
    n_in2 = dim(self.Rs_in2)

    # Outer product over the last axis of each input.
    outer = features_1[..., :, None] * features_2[..., None, :]  # [..., in1, in2]
    lead_shape = outer.shape[:-2]
    outer = outer.reshape(-1, n_in1, n_in2)  # [batch, in1, in2]

    mixing = get_sparse_buffer(self, "mixing_matrix")  # [out, in1 * in2]
    # Put batch last so the sparse matrix can be applied from the left.
    outer = torch.einsum('zij->ijz', outer)  # [in1, in2, batch]
    columns = outer.reshape(n_in1 * n_in2, outer.shape[2])  # [in1 * in2, batch]
    mixed = mixing @ columns  # [out, batch]
    return mixed.T.reshape(*lead_shape, n_out)