Exemplo n.º 1
0
    def __init__(self,
                 num_layers: int,
                 atom_feature_size: int,
                 num_channels: int,
                 num_nlayers: int = 1,
                 num_degrees: int = 4,
                 dim_output: int = 3,
                 edge_dim: int = 4,
                 div: float = 4,
                 pooling: str = 'avg',
                 n_heads: int = 1,
                 **kwargs):
        """Store the hyper-parameters and build the graph and FC blocks.

        Args:
            num_layers: stored as ``self.num_layers``.
            atom_feature_size: channel count of the 'in' fiber.
            num_channels: channels per degree in the 'mid' fiber.
            num_nlayers: stored as ``self.num_nlayers``.
            num_degrees: number of degrees of the 'mid' fiber; the 'out'
                fiber width is ``num_degrees * num_channels``.
            dim_output: stored and passed to ``_build_gcn``.
            edge_dim, div, pooling, n_heads: only stored on the instance
                here (presumably read by ``_build_gcn`` / forward).
            **kwargs: accepted and ignored.
        """
        super().__init__()
        # Hyper-parameters kept on the instance (presumably for _build_gcn).
        self.num_layers, self.num_nlayers = num_layers, num_nlayers
        self.num_channels = num_channels
        self.num_degrees = num_degrees
        self.edge_dim = edge_dim
        self.div = div
        self.pooling = pooling
        self.n_heads = n_heads
        self.dim_output = dim_output

        in_fiber = Fiber(1, atom_feature_size)
        mid_fiber = Fiber(num_degrees, self.num_channels)
        out_fiber = Fiber(1, num_degrees * self.num_channels)
        self.fibers = {'in': in_fiber, 'mid': mid_fiber, 'out': out_fiber}

        # _build_gcn returns exactly two blocks: the graph part and the FC head.
        self.Gblock, self.FCblock = self._build_gcn(self.fibers, dim_output)
        for block in (self.Gblock, self.FCblock):
            print(block)
Exemplo n.º 2
0
    def __init__(self,
                 num_layers: int,
                 atom_feature_size: int,
                 num_channels: int,
                 num_nlayers: int = 1,
                 num_degrees: int = 4,
                 edge_dim: int = 4,
                 **kwargs):
        """Store the hyper-parameters and build three network blocks.

        Args:
            num_layers: stored as ``self.num_layers``.
            atom_feature_size: channel count of the 'in' fiber.
            num_channels: channels per degree in the 'mid' fiber.
            num_nlayers: stored as ``self.num_nlayers``.
            num_degrees: number of degrees of the 'mid' fiber.
            edge_dim: stored as ``self.edge_dim``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.num_layers = num_layers
        self.num_nlayers = num_nlayers
        self.num_channels = num_channels
        self.num_degrees = num_degrees
        # Width of the single-degree 'out' fiber.
        self.num_channels_out = num_channels * num_degrees
        self.edge_dim = edge_dim

        fiber_in = Fiber(1, atom_feature_size)
        fiber_mid = Fiber(num_degrees, self.num_channels)
        fiber_out = Fiber(1, self.num_channels_out)
        self.fibers = {'in': fiber_in, 'mid': fiber_mid, 'out': fiber_out}

        self.block0, self.block1, self.block2 = self._build_gcn(self.fibers, 1)
        for block in (self.block0, self.block1, self.block2):
            print(block)
Exemplo n.º 3
0
    def __init__(self, f_in: Fiber, f_out: Fiber, edge_dim: int=0, div: float=4,
                 n_heads: int=1):
        """Build the attention projections and the skip connection.

        Args:
            f_in: input Fiber.
            f_out: output Fiber.
            edge_dim: forwarded as ``edge_dim`` to the value/key projections.
            div: channel divisor; the value/key/query fibers use
                ``int(channels // div)`` channels per degree.
            n_heads: forwarded to GMABSE3.
        """
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.div = div
        self.n_heads = n_heads

        # Value fiber: all degrees of f_out, channels floor-divided by div.
        mid_out = {}
        for degree, channels in self.f_out.structure_dict.items():
            mid_out[degree] = int(channels // div)
        self.f_mid_out = Fiber(dictionary=mid_out)

        # Key/query fiber: same counts, limited to degrees that f_in has.
        mid_in = {}
        for degree, channels in mid_out.items():
            if degree in self.f_in.degrees:
                mid_in[degree] = channels
        self.f_mid_in = Fiber(dictionary=mid_in)

        self.edge_dim = edge_dim

        gmab = nn.ModuleDict()
        self.GMAB = gmab
        # Projections are created in v, k, q order.
        gmab['v'] = GConvSE3Partial(f_in, self.f_mid_out, edge_dim=edge_dim)
        gmab['k'] = GConvSE3Partial(f_in, self.f_mid_in, edge_dim=edge_dim)
        gmab['q'] = G1x1SE3(f_in, self.f_mid_in)
        gmab['attn'] = GMABSE3(self.f_mid_out, self.f_mid_in, n_heads=n_heads)

        # Skip path: project the attention output to f_out, then sum with input.
        self.project = G1x1SE3(self.f_mid_out, f_out)
        self.add = GSum(f_out, f_in)
Exemplo n.º 4
0
    def __init__(self, num_layers: int, num_channels: int, num_degrees: int = 4, div: float = 4,
                 n_heads: int = 1, si_m='1x1', si_e='att', x_ij='add'):
        """Store the hyper-parameters and build the graph block.

        The input fiber has one degree-1 channel, the output fiber two
        degree-1 channels, and the edge dimension is fixed at 1.

        Args:
            num_layers: number of attention layers
            num_channels: number of channels per degree
            num_degrees: number of degrees (types) in the hidden layer, counted from type-0
            div: (int >= 1) keys, queries and values get num_channels/div channels
            n_heads: (int >= 1) number of heads for multi-headed attention
            si_m: ['1x1', 'att'] self-interaction type in hidden layers
            si_e: ['1x1', 'att'] self-interaction type in the final layer
            x_ij: ['add', 'cat'] how the relative position is used as an edge feature
        """
        super().__init__()
        self.num_layers = num_layers
        self.num_channels = num_channels
        self.num_degrees = num_degrees
        self.edge_dim = 1
        self.div = div
        self.n_heads = n_heads
        self.si_m = si_m
        self.si_e = si_e
        self.x_ij = x_ij

        self.fibers = {
            'in': Fiber(dictionary={1: 1}),
            'mid': Fiber(self.num_degrees, self.num_channels),
            'out': Fiber(dictionary={1: 2}),
        }

        self.Gblock = self._build_gcn(self.fibers)
        print(self.Gblock)
Exemplo n.º 5
0
    def __init__(self, f_in, f_out, edge_dim: int=0, x_ij=None):
        """SE(3)-equivariant partial convolution.

        Takes the inner product between a kernel and every input channel
        separately, without summing the per-channel results. Keeping the
        channels unfolded this way makes the layer suitable for computing
        the value embeddings of an attention mechanism.

        Args:
            f_in: input Fiber; its ``structure`` is iterated as
                (multiplicity, degree) pairs.
            f_out: output Fiber, iterated the same way.
            edge_dim: forwarded as ``edge_dim`` to every PairwiseConv.
            x_ij: one of None, 'cat' or 'add'. With 'cat', one extra
                (multiplicity 1, degree 1) entry for the relative position is
                combined into the input fiber. 'add' leaves the input fiber
                unchanged here (presumably handled in forward).
        """
        super().__init__()
        self.f_out = f_out
        self.edge_dim = edge_dim

        assert x_ij in [None, 'cat', 'add']
        self.x_ij = x_ij
        if x_ij != 'cat':
            self.f_in = f_in
        else:
            # Room for the concatenated relative-position vector.
            self.f_in = Fiber.combine(f_in, Fiber(structure=[(1, 1)]))

        # One pairwise kernel per (input degree, output degree) combination.
        self.kernel_unary = nn.ModuleDict()
        for mult_in, deg_in in self.f_in.structure:
            for mult_out, deg_out in self.f_out.structure:
                key = f'({deg_in},{deg_out})'
                self.kernel_unary[key] = PairwiseConv(
                    deg_in, mult_in, deg_out, mult_out, edge_dim=edge_dim)
Exemplo n.º 6
0
    def __init__(self,
                 *,
                 num_layers: int,
                 num_degrees: int = 4,
                 num_channels: int,
                 div: float = 4,
                 n_heads: int = 1,
                 num_iter: int = 3,
                 si_m='1x1',
                 si_e='1x1',
                 x_ij=None,
                 compute_gradients=True,
                 k_neighbors=None,
                 **kwargs):
        """Iterative SE(3)-equivariant GCN with attention.

        All arguments are keyword-only.

        Args:
            num_layers: layers per SE3-Transformer block
            num_degrees: number of degrees (types) in the hidden layer, counted from type-0
            num_channels: channels per degree
            div: (int >= 1) keys, queries and values get num_channels/div channels
            n_heads: (int >= 1) number of heads for multi-headed attention
            num_iter: number of SE3-Transformer blocks, each with its own coordinate output
            si_m: ['1x1', 'att'] self-interaction type in hidden layers
            si_e: ['1x1', 'att'] self-interaction type in the final layer
            x_ij: ['add', 'cat'] how the relative position is used as an edge feature
            compute_gradients: [True, False] whether to backpropagate through the
                spherical-harmonics computation
            k_neighbors: attend only to the K neighbours with the strongest interaction
            kwargs: extra arguments that this method ignores
        """
        super().__init__()

        self.num_layers = num_layers
        self.num_channels = num_channels
        self.num_degrees = num_degrees
        self.div = div
        self.n_heads = n_heads
        self.si_m = si_m
        self.si_e = si_e
        self.x_ij = x_ij
        self.num_iter = num_iter
        self.compute_gradients = compute_gradients
        self.k_neighbors = k_neighbors

        self.edge_dim = 1
        fiber_in = Fiber(dictionary={0: 1})
        fiber_mid = Fiber(self.num_degrees, self.num_channels)
        fiber_out = Fiber(dictionary={1: 1})
        self.fibers = {'in': fiber_in, 'mid': fiber_mid, 'out': fiber_out}

        # Settings shared by the graph-attention layers (presumably read in _build_gcn).
        self._graph_attention_common_params = dict(
            edge_dim=self.edge_dim,
            learnable_skip=True,
            skip='cat',
            x_ij=self.x_ij,
        )

        self.blocks = self._build_gcn(self.fibers)
Exemplo n.º 7
0
    def __init__(self,
                 f_in: Fiber,
                 f_out: Fiber,
                 edge_dim: int = 0,
                 div: float = 4,
                 n_heads: int = 1,
                 learnable_skip=True):
        """Build the attention sub-layers and the residual skip connection.

        Args:
            f_in: input Fiber.
            f_out: output Fiber.
            edge_dim: forwarded as ``edge_dim`` to the value/key projections.
            div: channel divisor; the value/key/query fibers use
                ``int(channels // div)`` channels per degree.
            n_heads: forwarded to GMABSE3.
            learnable_skip: forwarded as ``learnable`` to the G1x1SE3 that
                projects the attention output back to ``f_out``.

        Raises:
            AssertionError: if the skip-connection sum would have a fiber
                structure different from ``f_out``.
        """
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.div = div
        self.n_heads = n_heads

        # Value fiber: same degrees as f_out, channels floor-divided by div.
        mid_out = {}
        for degree, channels in self.f_out.structure_dict.items():
            mid_out[degree] = int(channels // div)
        self.f_mid_out = Fiber(dictionary=mid_out)

        # Key/query fiber: the value fiber restricted to degrees present in
        # f_in. Queries are only projected from the input, so their degrees
        # must match the input's.
        mid_in = {}
        for degree, channels in mid_out.items():
            if degree in self.f_in.degrees:
                mid_in[degree] = channels
        self.f_mid_in = Fiber(dictionary=mid_in)

        self.edge_dim = edge_dim

        gmab = nn.ModuleDict()
        self.GMAB = gmab
        # Projections are created in v, k, q order.
        gmab['v'] = GConvSE3Partial(f_in, self.f_mid_out, edge_dim=edge_dim)
        gmab['k'] = GConvSE3Partial(f_in, self.f_mid_in, edge_dim=edge_dim)
        gmab['q'] = G1x1SE3(f_in, self.f_mid_in)
        gmab['attn'] = GMABSE3(self.f_mid_out, self.f_mid_in, n_heads=n_heads)

        # Skip path: project the attention output to f_out, then sum with input.
        self.project = G1x1SE3(self.f_mid_out, f_out, learnable=learnable_skip)
        self.add = GSum(f_out, f_in)
        # Fail early if the summed fiber would not match f_out (for example,
        # when the input has more channels than the output in some degree).
        # Otherwise the next layer would fail with an error that is hard to trace.
        assert self.add.f_out.structure_dict == f_out.structure_dict, \
            'skip connection would change output structure'
Exemplo n.º 8
0
    def __init__(self, num_layers: int, num_channels: int, num_degrees: int = 4, **kwargs):
        """Store the hyper-parameters and build the graph and FC blocks.

        The input fiber has one degree-0 and one degree-1 channel, the output
        fiber two degree-1 channels, and the edge dimension is fixed at 1.

        Args:
            num_layers: stored as ``self.num_layers``.
            num_channels: channels per degree in the 'mid' fiber.
            num_degrees: number of degrees of the 'mid' fiber.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.num_layers = num_layers
        self.num_channels = num_channels
        self.num_degrees = num_degrees
        self.edge_dim = 1

        self.fibers = {
            'in': Fiber(dictionary={0: 1, 1: 1}),
            'mid': Fiber(self.num_degrees, self.num_channels),
            'out': Fiber(dictionary={1: 2}),
        }

        gblock, fcblock = self._build_gcn(self.fibers)
        self.Gblock = gblock
        self.FCblock = fcblock
        print(self.Gblock)
        print(self.FCblock)
        # Aliases of the same modules, used only for parameter counting in utils_logging.py.
        self.enc = gblock
        self.dec = fcblock
Exemplo n.º 9
0
 def __init__(self, f_x: Fiber, f_y: Fiber):
     """Record both input fibers and derive the output fiber.

     The output fiber has exactly the degrees of ``f_x``. For a degree that
     ``f_y`` also has, the two channel counts are added. Degrees that exist
     only in ``f_y`` are dropped. (Presumably this is a channel-wise
     concatenation layer; confirm against forward.)
     """
     super().__init__()
     self.f_x = f_x
     self.f_y = f_y
     out_channels = {
         degree: (f_x.dict[degree] + f_y.dict[degree]
                  if degree in f_y.degrees else f_x.dict[degree])
         for degree in f_x.degrees
     }
     self.f_out = Fiber(dictionary=out_channels)
Exemplo n.º 10
0
    def __init__(self, f_x: Fiber, f_y: Fiber):
        """SE(3)-equivariant graph residual sum function.

        Args:
            f_x: Fiber of the first summand.
            f_y: Fiber of the second summand.
        """
        super().__init__()
        self.f_x, self.f_y = f_x, f_y
        # Output fiber comes from Fiber.combine_max (presumably the per-degree
        # maximum channel count of the two summands).
        self.f_out = Fiber.combine_max(f_x, f_y)
Exemplo n.º 11
0
    def __init__(self,
                 num_layers=2,
                 num_channels=32,
                 num_degrees=3,
                 n_heads=4,
                 div=4,
                 si_m='1x1',
                 si_e='att',
                 l0_in_features=32,
                 l0_out_features=32,
                 l1_in_features=3,
                 l1_out_features=3,
                 num_edge_features=32,
                 x_ij=None):
        """Store the hyper-parameters and build the graph block.

        Args:
            num_layers, num_channels, num_degrees, n_heads, div, si_m, si_e,
                x_ij: stored on the instance (presumably read by ``_build_gcn``).
            l0_in_features: degree-0 channels of the input fiber.
            l0_out_features: degree-0 channels of the output fiber.
            l1_in_features: degree-1 channels of the input fiber.
            l1_out_features: degree-1 channels of the output fiber. If it is
                not positive, the output fiber has no degree-1 entry.
            num_edge_features: stored as ``self.edge_dim``.
        """
        super().__init__()
        self.num_layers = num_layers
        self.num_channels = num_channels
        self.num_degrees = num_degrees
        self.edge_dim = num_edge_features
        self.div = div
        self.n_heads = n_heads
        self.si_m = si_m
        self.si_e = si_e
        self.x_ij = x_ij

        # Degree 1 appears in the output only when it has positive width.
        out_channels = {0: l0_out_features}
        if l1_out_features > 0:
            out_channels[1] = l1_out_features

        fibers = {
            'in': Fiber(dictionary={0: l0_in_features, 1: l1_in_features}),
            'mid': Fiber(self.num_degrees, self.num_channels),
            'out': Fiber(dictionary=out_channels),
        }

        self.Gblock = self._build_gcn(fibers)
Exemplo n.º 12
0
    def __init__(self,
                 num_layers=2,
                 num_channels=32,
                 num_nonlin_layers=1,
                 num_degrees=3,
                 l0_in_features=32,
                 l0_out_features=32,
                 l1_in_features=3,
                 l1_out_features=3,
                 num_edge_features=32,
                 use_self=True):
        """Store the hyper-parameters and build the network block.

        Args:
            num_layers, num_channels, num_degrees, use_self: stored on the
                instance (presumably read by ``_build_gcn``).
            num_nonlin_layers: stored as ``self.num_nlayers``.
            l0_in_features: degree-0 channels of the input fiber.
            l0_out_features: degree-0 channels of the output fiber.
            l1_in_features: degree-1 channels of the input fiber.
            l1_out_features: degree-1 channels of the output fiber. If it is
                not positive, the output fiber has no degree-1 entry.
            num_edge_features: stored as ``self.edge_dim``.
        """
        super().__init__()
        self.num_layers = num_layers
        self.num_nlayers = num_nonlin_layers
        self.num_channels = num_channels
        self.num_degrees = num_degrees
        self.edge_dim = num_edge_features
        self.use_self = use_self

        # Degree 1 appears in the output only when it has positive width.
        out_channels = {0: l0_out_features}
        if l1_out_features > 0:
            out_channels[1] = l1_out_features

        fibers = {
            'in': Fiber(dictionary={0: l0_in_features, 1: l1_in_features}),
            'mid': Fiber(self.num_degrees, self.num_channels),
            'out': Fiber(dictionary=out_channels),
        }
        self.block0 = self._build_gcn(fibers)