def __init__(self, num_layers: int, atom_feature_size: int, num_channels: int,
             num_nlayers: int = 1, num_degrees: int = 4, dim_output: int = 3,
             edge_dim: int = 4, div: float = 4, pooling: str = 'avg',
             n_heads: int = 1, **kwargs):
    """Record hyper-parameters, define the fiber layout and build the network.

    Args:
        num_layers: stored as ``self.num_layers`` (presumably read by ``_build_gcn``).
        atom_feature_size: channel count passed to the input fiber ``Fiber(1, ...)``.
        num_channels: channels per degree of the hidden fiber.
        num_nlayers: stored as ``self.num_nlayers``.
        num_degrees: number of degrees of the hidden fiber; the output fiber
            is ``Fiber(1, num_degrees * num_channels)``.
        dim_output: stored and forwarded to ``_build_gcn``.
        edge_dim: stored as ``self.edge_dim``.
        div: stored as ``self.div``.
        pooling: stored as ``self.pooling``.
        n_heads: stored as ``self.n_heads``.
        **kwargs: accepted and ignored here.
    """
    super().__init__()

    # Hyper-parameters kept on the instance for _build_gcn (and, presumably, forward).
    self.num_layers = num_layers
    self.num_nlayers = num_nlayers
    self.num_channels = num_channels
    self.num_degrees = num_degrees
    self.edge_dim = edge_dim
    self.div = div
    self.pooling = pooling
    self.n_heads = n_heads
    self.dim_output = dim_output

    in_fiber = Fiber(1, atom_feature_size)
    mid_fiber = Fiber(num_degrees, self.num_channels)
    out_fiber = Fiber(1, num_degrees * self.num_channels)
    self.fibers = {'in': in_fiber, 'mid': mid_fiber, 'out': out_fiber}

    # _build_gcn is expected to return exactly two modules for this model.
    self.Gblock, self.FCblock = self._build_gcn(self.fibers, dim_output)
    for block in (self.Gblock, self.FCblock):
        print(block)
def __init__(self, num_layers: int, atom_feature_size: int, num_channels: int,
             num_nlayers: int = 1, num_degrees: int = 4, edge_dim: int = 4,
             **kwargs):
    """Record hyper-parameters, define the fiber layout and build three blocks.

    Args:
        num_layers: stored as ``self.num_layers``.
        atom_feature_size: channel count of the input fiber ``Fiber(1, ...)``.
        num_channels: channels per degree of the hidden fiber.
        num_nlayers: stored as ``self.num_nlayers``.
        num_degrees: number of degrees of the hidden fiber.
        edge_dim: stored as ``self.edge_dim``.
        **kwargs: accepted and ignored here.
    """
    super().__init__()

    self.num_layers = num_layers
    self.num_nlayers = num_nlayers
    self.num_channels = num_channels
    self.num_degrees = num_degrees
    # Output width is the product of channels and degrees.
    self.num_channels_out = num_channels * num_degrees
    self.edge_dim = edge_dim

    in_fiber = Fiber(1, atom_feature_size)
    mid_fiber = Fiber(num_degrees, self.num_channels)
    out_fiber = Fiber(1, self.num_channels_out)
    self.fibers = {'in': in_fiber, 'mid': mid_fiber, 'out': out_fiber}

    # _build_gcn is expected to return exactly three modules for this model.
    self.block0, self.block1, self.block2 = self._build_gcn(self.fibers, 1)
    for block in (self.block0, self.block1, self.block2):
        print(block)
def __init__(self, f_in: Fiber, f_out: Fiber, edge_dim: int=0, div: float=4, n_heads: int=1):
    """Build value/key/query projections, attention and the skip connection.

    Args:
        f_in: input fiber.
        f_out: output fiber.
        edge_dim: forwarded to the value and key partial convolutions.
        div: channel divisor applied to every degree of ``f_out`` to size the
            value/key/query fibers.
        n_heads: number of attention heads passed to ``GMABSE3``.
    """
    super().__init__()
    self.f_in = f_in
    self.f_out = f_out
    self.div = div
    self.n_heads = n_heads

    # Value fiber: degrees of f_out, multiplicity floor-divided by div.
    value_mults = {deg: int(mult // div)
                   for deg, mult in self.f_out.structure_dict.items()}
    self.f_mid_out = Fiber(dictionary=value_mults)

    # Key/query fiber: the value degrees that also appear in the input,
    # since queries are only a 1x1 projection of the input.
    input_degrees = self.f_in.degrees
    key_mults = {deg: mult for deg, mult in value_mults.items()
                 if deg in input_degrees}
    self.f_mid_in = Fiber(dictionary=key_mults)

    self.edge_dim = edge_dim

    # Construction order (v, k, q, attn, project, add) is kept fixed: it
    # determines parameter-initialisation RNG draws and registration order.
    self.GMAB = nn.ModuleDict()
    gmab = self.GMAB
    gmab['v'] = GConvSE3Partial(f_in, self.f_mid_out, edge_dim=edge_dim)
    gmab['k'] = GConvSE3Partial(f_in, self.f_mid_in, edge_dim=edge_dim)
    gmab['q'] = G1x1SE3(f_in, self.f_mid_in)
    gmab['attn'] = GMABSE3(self.f_mid_out, self.f_mid_in, n_heads=n_heads)

    # Skip connection: project attention output back to f_out, then merge
    # with the input. The merged fiber is not validated against f_out here.
    self.project = G1x1SE3(self.f_mid_out, f_out)
    self.add = GSum(f_out, f_in)
def __init__(self, num_layers: int, num_channels: int, num_degrees: int = 4,
             div: float = 4, n_heads: int = 1, si_m='1x1', si_e='att', x_ij='add'):
    """Configure and build an SE(3) attention network.

    Args:
        num_layers: number of attention layers.
        num_channels: number of channels per degree.
        num_degrees: number of degrees (aka types) in the hidden layer,
            counting from type-0.
        div: (int >= 1) keys, queries and values get num_channels/div channels.
        n_heads: (int >= 1) number of heads for multi-headed attention.
        si_m: one of '1x1', 'att'; self-interaction type in hidden layers.
        si_e: one of '1x1', 'att'; self-interaction type in the final layer.
        x_ij: one of 'add', 'cat'; how the relative position is used as an
            edge feature.
    """
    super().__init__()

    self.num_layers = num_layers
    self.num_channels = num_channels
    self.num_degrees = num_degrees
    # Edge feature dimension is fixed to 1 for this model.
    self.edge_dim = 1
    self.div = div
    self.n_heads = n_heads
    self.si_m = si_m
    self.si_e = si_e
    self.x_ij = x_ij

    in_fiber = Fiber(dictionary={1: 1})
    mid_fiber = Fiber(self.num_degrees, self.num_channels)
    out_fiber = Fiber(dictionary={1: 2})
    self.fibers = {'in': in_fiber, 'mid': mid_fiber, 'out': out_fiber}

    self.Gblock = self._build_gcn(self.fibers)
    print(self.Gblock)
def __init__(self, f_in, f_out, edge_dim: int=0, x_ij=None):
    """SE(3)-equivariant partial convolution.

    A partial convolution computes the inner product between a kernel and
    each input channel, without summing over the result from each input
    channel. This unfolded structure makes it amenable to be used for
    computing the value-embeddings of the attention mechanism.

    Args:
        f_in: Fiber describing the input features; its ``structure`` is
            iterated as (multiplicity, degree) pairs.
        f_out: Fiber describing the output features (same convention).
        edge_dim: edge feature dimension, forwarded to every ``PairwiseConv``.
        x_ij: how the relative position is combined with the features:
            None (not used), 'cat' (an extra multiplicity-1, degree-1 entry
            is combined into the input fiber via ``Fiber.combine``) or 'add'.

    Raises:
        ValueError: if ``x_ij`` is not one of None, 'cat', 'add'.
    """
    super().__init__()
    self.f_out = f_out
    self.edge_dim = edge_dim

    # Adding/concatenating the relative position to the feature vectors:
    # 'cat' concatenates relative position & existing feature vector;
    # 'add' adds it (per the original note, only if multiplicity > 1 --
    # that logic lives outside __init__).
    # Explicit raise instead of assert so the check survives `python -O`.
    if x_ij not in (None, 'cat', 'add'):
        raise ValueError(f"x_ij must be None, 'cat' or 'add', got {x_ij!r}")
    self.x_ij = x_ij
    if x_ij == 'cat':
        self.f_in = Fiber.combine(f_in, Fiber(structure=[(1, 1)]))
    else:
        self.f_in = f_in

    # Node -> edge weights: one PairwiseConv per (input degree, output degree).
    self.kernel_unary = nn.ModuleDict()
    for (mi, di) in self.f_in.structure:
        for (mo, do) in self.f_out.structure:
            self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim=edge_dim)
def __init__(self, *, num_layers: int, num_degrees: int = 4, num_channels: int,
             div: float = 4, n_heads: int = 1, num_iter: int = 3, si_m='1x1',
             si_e='1x1', x_ij=None, compute_gradients=True, k_neighbors=None,
             **kwargs):
    """Iterative SE(3)-equivariant GCN with attention.

    Args:
        num_layers: number of layers per SE3-Transformer block.
        num_degrees: number of degrees (aka types) in the hidden layer,
            counting from type-0.
        num_channels: number of channels per degree.
        div: (int >= 1) keys, queries and values get num_channels/div channels.
        n_heads: (int >= 1) number of heads for multi-headed attention.
        num_iter: number of SE3-Transformer blocks with individual
            coordinate outputs.
        si_m: one of '1x1', 'att'; self-interaction type in hidden layers.
        si_e: one of '1x1', 'att'; self-interaction type in the final layer.
        x_ij: one of 'add', 'cat' (or None); relative position as edge feature.
        compute_gradients: whether to backpropagate through the spherical
            harmonics computation.
        k_neighbors: attend to the K neighbours with strongest interaction.
        **kwargs: unused arguments, accepted so callers may pass extras.
    """
    super().__init__()

    self.num_layers = num_layers
    self.num_channels = num_channels
    self.num_degrees = num_degrees
    self.div = div
    self.n_heads = n_heads
    self.si_m = si_m
    self.si_e = si_e
    self.x_ij = x_ij
    self.num_iter = num_iter
    self.compute_gradients = compute_gradients
    self.k_neighbors = k_neighbors
    self.edge_dim = 1

    in_fiber = Fiber(dictionary={0: 1})
    mid_fiber = Fiber(self.num_degrees, self.num_channels)
    out_fiber = Fiber(dictionary={1: 1})
    self.fibers = {'in': in_fiber, 'mid': mid_fiber, 'out': out_fiber}

    # Keyword arguments shared by the attention layers built in _build_gcn.
    common = {}
    common['edge_dim'] = self.edge_dim
    common['learnable_skip'] = True
    common['skip'] = 'cat'
    common['x_ij'] = self.x_ij
    self._graph_attention_common_params = common

    self.blocks = self._build_gcn(self.fibers)
def __init__(self, f_in: Fiber, f_out: Fiber, edge_dim: int = 0, div: float = 4,
             n_heads: int = 1, learnable_skip=True):
    """Build value/key/query projections, attention and a projected skip connection.

    Args:
        f_in: input fiber.
        f_out: output fiber.
        edge_dim: forwarded to the value and key partial convolutions.
        div: channel divisor applied to every degree of ``f_out`` to size the
            value/key/query fibers.
        n_heads: number of attention heads passed to ``GMABSE3``.
        learnable_skip: forwarded as ``learnable`` to the skip projection.

    Raises:
        ValueError: if merging the skip connection (``GSum(f_out, f_in)``)
            would yield a fiber whose structure differs from ``f_out``.
    """
    super().__init__()
    self.f_in = f_in
    self.f_out = f_out
    self.div = div
    self.n_heads = n_heads

    # f_mid_out has the same structure as f_out but #channels divided by
    # div; it is used for the values.
    f_mid_out = {
        k: int(v // div) for k, v in self.f_out.structure_dict.items()
    }
    self.f_mid_out = Fiber(dictionary=f_mid_out)

    # f_mid_in has the same structure as f_mid_out, restricted to degrees
    # present in f_in; it is used for keys and queries (queries are merely
    # projected, hence their degrees have to match the input).
    f_mid_in = {
        d: m for d, m in f_mid_out.items() if d in self.f_in.degrees
    }
    self.f_mid_in = Fiber(dictionary=f_mid_in)

    self.edge_dim = edge_dim

    self.GMAB = nn.ModuleDict()

    # Projections
    self.GMAB['v'] = GConvSE3Partial(f_in, self.f_mid_out, edge_dim=edge_dim)
    self.GMAB['k'] = GConvSE3Partial(f_in, self.f_mid_in, edge_dim=edge_dim)
    self.GMAB['q'] = G1x1SE3(f_in, self.f_mid_in)

    # Attention
    self.GMAB['attn'] = GMABSE3(self.f_mid_out, self.f_mid_in, n_heads=n_heads)

    # Skip connections
    self.project = G1x1SE3(self.f_mid_out, f_out, learnable=learnable_skip)
    self.add = GSum(f_out, f_in)

    # Check whether the skip connection would change the output fiber
    # structure, e.g. because the input has more channels than the output
    # for at least one degree; that would otherwise surface as a hard-to-debug
    # error in the next layer. Raised explicitly so it survives `python -O`.
    if self.add.f_out.structure_dict != f_out.structure_dict:
        raise ValueError('skip connection would change output structure')
def __init__(self, num_layers: int, num_channels: int, num_degrees: int = 4, **kwargs):
    """Define the fiber layout and build the graph block and the FC block.

    Args:
        num_layers: stored as ``self.num_layers``.
        num_channels: channels per degree of the hidden fiber.
        num_degrees: number of degrees of the hidden fiber.
        **kwargs: accepted and ignored here.
    """
    super().__init__()

    self.num_layers = num_layers
    self.num_channels = num_channels
    self.num_degrees = num_degrees
    # Edge feature dimension is fixed to 1 for this model.
    self.edge_dim = 1

    in_fiber = Fiber(dictionary={0: 1, 1: 1})
    mid_fiber = Fiber(self.num_degrees, self.num_channels)
    out_fiber = Fiber(dictionary={1: 2})
    self.fibers = {'in': in_fiber, 'mid': mid_fiber, 'out': out_fiber}

    # _build_gcn is expected to return exactly two modules for this model.
    self.Gblock, self.FCblock = self._build_gcn(self.fibers)
    for block in (self.Gblock, self.FCblock):
        print(block)

    # Aliases used purely for counting parameters in utils_logging.py.
    self.enc = self.Gblock
    self.dec = self.FCblock
def __init__(self, f_x: Fiber, f_y: Fiber):
    """Compute the output fiber obtained by stacking ``f_y`` onto ``f_x``.

    For every degree of ``f_x``, the output multiplicity is the multiplicity
    in ``f_x`` plus the one in ``f_y`` when ``f_y`` also has that degree.
    Degrees present only in ``f_y`` do not appear in the output fiber.

    Args:
        f_x: fiber of the first operand (defines the output degrees).
        f_y: fiber of the second operand.
    """
    super().__init__()
    self.f_x = f_x
    self.f_y = f_y

    stacked = {}
    for degree in f_x.degrees:
        multiplicity = f_x.dict[degree]
        if degree in f_y.degrees:
            multiplicity += f_y.dict[degree]
        stacked[degree] = multiplicity
    self.f_out = Fiber(dictionary=stacked)
def __init__(self, f_x: Fiber, f_y: Fiber):
    """SE(3)-equivariant graph residual sum function.

    The output fiber is ``Fiber.combine_max(f_x, f_y)`` (presumably the
    per-degree maximum multiplicity of the two summands -- confirm against
    ``Fiber.combine_max``).

    Args:
        f_x: Fiber() object describing the first summand.
        f_y: Fiber() object describing the second summand.
    """
    super().__init__()
    self.f_x, self.f_y = f_x, f_y
    self.f_out = Fiber.combine_max(f_x, f_y)
def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
             si_m='1x1', si_e='att', l0_in_features=32, l0_out_features=32,
             l1_in_features=3, l1_out_features=3, num_edge_features=32, x_ij=None):
    """Configure type-0/type-1 input and output fibers and build the network.

    Args:
        num_layers: stored as ``self.num_layers``.
        num_channels: channels per degree of the hidden fiber.
        num_degrees: number of degrees of the hidden fiber.
        n_heads: stored as ``self.n_heads``.
        div: stored as ``self.div``.
        si_m: stored as ``self.si_m``.
        si_e: stored as ``self.si_e``.
        l0_in_features: degree-0 multiplicity of the input fiber.
        l0_out_features: degree-0 multiplicity of the output fiber.
        l1_in_features: degree-1 multiplicity of the input fiber.
        l1_out_features: degree-1 multiplicity of the output fiber; when it is
            not positive the output fiber has only a degree-0 part.
        num_edge_features: stored as ``self.edge_dim``.
        x_ij: stored as ``self.x_ij``.
    """
    super().__init__()

    self.num_layers = num_layers
    self.num_channels = num_channels
    self.num_degrees = num_degrees
    self.edge_dim = num_edge_features
    self.div = div
    self.n_heads = n_heads
    self.si_m = si_m
    self.si_e = si_e
    self.x_ij = x_ij

    in_fiber = Fiber(dictionary={0: l0_in_features, 1: l1_in_features})
    mid_fiber = Fiber(self.num_degrees, self.num_channels)
    out_mults = {0: l0_out_features}
    if l1_out_features > 0:
        out_mults[1] = l1_out_features
    out_fiber = Fiber(dictionary=out_mults)

    self.Gblock = self._build_gcn({'in': in_fiber, 'mid': mid_fiber, 'out': out_fiber})
def __init__(self, num_layers=2, num_channels=32, num_nonlin_layers=1, num_degrees=3,
             l0_in_features=32, l0_out_features=32, l1_in_features=3,
             l1_out_features=3, num_edge_features=32, use_self=True):
    """Configure type-0/type-1 input and output fibers and build the network.

    Args:
        num_layers: stored as ``self.num_layers``.
        num_channels: channels per degree of the hidden fiber.
        num_nonlin_layers: stored as ``self.num_nlayers``.
        num_degrees: number of degrees of the hidden fiber.
        l0_in_features: degree-0 multiplicity of the input fiber.
        l0_out_features: degree-0 multiplicity of the output fiber.
        l1_in_features: degree-1 multiplicity of the input fiber.
        l1_out_features: degree-1 multiplicity of the output fiber; when it is
            not positive the output fiber has only a degree-0 part.
        num_edge_features: stored as ``self.edge_dim``.
        use_self: stored as ``self.use_self``.
    """
    super().__init__()

    self.num_layers = num_layers
    self.num_nlayers = num_nonlin_layers
    self.num_channels = num_channels
    self.num_degrees = num_degrees
    self.edge_dim = num_edge_features
    self.use_self = use_self

    in_fiber = Fiber(dictionary={0: l0_in_features, 1: l1_in_features})
    mid_fiber = Fiber(self.num_degrees, self.num_channels)
    out_mults = {0: l0_out_features}
    if l1_out_features > 0:
        out_mults[1] = l1_out_features
    out_fiber = Fiber(dictionary=out_mults)

    self.block0 = self._build_gcn({'in': in_fiber, 'mid': mid_fiber, 'out': out_fiber})