Exemplo n.º 1
0
    def _build_gcn(self, fibers, out_dim):
        """Build the equivariant conv stack, the pooling stage and the FC head.

        Args:
            fibers: Mapping with keys ``'in'``, ``'mid'`` and ``'out'`` giving
                the fiber specification for the input, hidden and output
                layers (presumably ``Fiber`` instances -- confirm at caller).
            out_dim: Output size of the final ``nn.Linear`` in the head.

        Returns:
            Tuple ``(block0, block1, block2)`` of ``nn.ModuleList``:
            ``self.num_layers - 1`` (``GConvSE3``, ``GNormSE3``) pairs followed
            by a final ``GConvSE3`` to the ``'out'`` fiber; a single
            ``GMaxPooling``; and a Linear-ReLU-Linear head.
        """
        block0 = []
        fin = fibers['in']
        for _ in range(self.num_layers - 1):
            block0.append(
                GConvSE3(fin,
                         fibers['mid'],
                         self_interaction=True,
                         edge_dim=self.edge_dim))
            block0.append(GNormSE3(fibers['mid'], num_layers=self.num_nlayers))
            fin = fibers['mid']
        # Feed the last conv from `fin`, not fibers['mid']: when
        # num_layers == 1 the loop adds nothing, so this conv must accept the
        # input fiber. When the loop ran, fin is fibers['mid'] (unchanged).
        block0.append(
            GConvSE3(fin,
                     fibers['out'],
                     self_interaction=True,
                     edge_dim=self.edge_dim))

        block1 = [GMaxPooling()]

        # NOTE(review): assumes self.num_channels_out equals the feature size
        # produced by pooling the 'out' fiber -- confirm against the caller.
        block2 = [
            nn.Linear(self.num_channels_out, self.num_channels_out),
            nn.ReLU(inplace=True),
            nn.Linear(self.num_channels_out, out_dim),
        ]

        return nn.ModuleList(block0), nn.ModuleList(block1), nn.ModuleList(
            block2)
Exemplo n.º 2
0
    def _build_gcn(self, fibers):
        """Build the TFN-flavoured equivariant layer stack.

        Args:
            fibers: Mapping with keys ``'in'``, ``'mid'`` and ``'out'`` giving
                the fiber specification for the input, hidden and output
                layers (presumably ``Fiber`` instances -- confirm at caller).

        Returns:
            Tuple ``(Gblock, empty)`` of ``nn.ModuleList``. ``Gblock`` holds
            ``self.num_layers - 1`` (``GConvSE3``, ``GNormBias``) pairs and a
            final ``GConvSE3`` to the ``'out'`` fiber. The second element is
            always empty; presumably kept so callers can unpack two blocks
            uniformly -- confirm.
        """
        # Equivariant layers
        Gblock = []
        fin = fibers['in']

        for _ in range(self.num_layers - 1):
            Gblock.append(GConvSE3(fin, fibers['mid'], self_interaction=True, flavor='TFN', edge_dim=self.edge_dim))
            Gblock.append(GNormBias(fibers['mid']))
            fin = fibers['mid']
        # Use `fin` as the input fiber: with num_layers == 1 the loop adds no
        # layers, so the final conv must consume fibers['in'] directly.
        Gblock.append(
            GConvSE3(fin, fibers['out'], self_interaction=True, flavor='TFN', edge_dim=self.edge_dim))

        return nn.ModuleList(Gblock), nn.ModuleList([])
Exemplo n.º 3
0
    def _build_gcn(self, fibers, out_dim):
        """Build the SE(3) residual-attention stack, optional pooling and FC head.

        Args:
            fibers: Mapping with keys ``'in'``, ``'mid'`` and ``'out'`` giving
                the fiber specification for the input, hidden and output
                layers. ``fibers['out']`` must expose ``n_features``
                (presumably a ``Fiber`` -- confirm at caller).
            out_dim: Output size of the final ``nn.Linear`` in the FC head.

        Returns:
            Tuple ``(Gblock, FCblock)`` of ``nn.ModuleList``. ``Gblock`` holds
            ``self.num_layers`` (``GSE3Res``, ``GNormSE3``) pairs, a final
            ``GConvSE3`` to the ``'out'`` fiber and, depending on
            ``self.pooling``, a pooling layer. ``FCblock`` is a
            Linear-ReLU-Linear head sized from ``fibers['out'].n_features``.
        """
        # Equivariant layers
        Gblock = []
        fin = fibers['in']
        for _ in range(self.num_layers):
            Gblock.append(
                GSE3Res(fin,
                        fibers['mid'],
                        edge_dim=self.edge_dim,
                        div=self.div,
                        n_heads=self.n_heads))
            Gblock.append(GNormSE3(fibers['mid']))
            fin = fibers['mid']
        # Use `fin` so that num_layers == 0 still yields a conv whose input
        # fiber matches fibers['in']; otherwise fin is fibers['mid'].
        Gblock.append(
            GConvSE3(fin,
                     fibers['out'],
                     self_interaction=True,
                     edge_dim=self.edge_dim))

        # Pooling. NOTE(review): any other value of self.pooling adds no
        # pooling layer at all -- confirm that is intended.
        if self.pooling == 'avg':
            Gblock.append(GAvgPooling())
        elif self.pooling == 'max':
            Gblock.append(GMaxPooling())

        # FC layers, sized from the `fibers` argument (not self.fibers) so the
        # head always matches the 'out' fiber the last conv was built with.
        n_out = fibers['out'].n_features
        FCblock = [
            nn.Linear(n_out, n_out),
            nn.ReLU(inplace=True),
            nn.Linear(n_out, out_dim),
        ]

        return nn.ModuleList(Gblock), nn.ModuleList(FCblock)
Exemplo n.º 4
0
    def _build_gcn(self, fibers):
        """Build the equivariant conv stack.

        Args:
            fibers: Mapping with keys ``'in'``, ``'mid'`` and ``'out'`` giving
                the fiber specification for the input, hidden and output
                layers (presumably ``Fiber`` instances -- confirm at caller).

        Returns:
            ``nn.ModuleList`` holding ``self.num_layers - 1`` (``GConvSE3``,
            ``GNormSE3``) pairs followed by a final ``GConvSE3`` to the
            ``'out'`` fiber. Every conv uses ``self_interaction=self.use_self``.
        """
        block0 = []
        fin = fibers['in']
        for _ in range(self.num_layers - 1):
            block0.append(
                GConvSE3(fin,
                         fibers['mid'],
                         self_interaction=self.use_self,
                         edge_dim=self.edge_dim))
            block0.append(GNormSE3(fibers['mid'], num_layers=self.num_nlayers))
            fin = fibers['mid']
        # Use `fin` as input: with num_layers == 1 the loop adds no layers,
        # so the final conv must consume fibers['in'] directly.
        block0.append(
            GConvSE3(fin,
                     fibers['out'],
                     self_interaction=self.use_self,
                     edge_dim=self.edge_dim))
        return nn.ModuleList(block0)
Exemplo n.º 5
0
    def _build_gcn(self, fibers, out_dim):
        """Build the SE(3) residual-attention stack and the output encoder.

        Args:
            fibers: Mapping with keys ``'in'``, ``'mid'`` and ``'out'`` giving
                the fiber specification for the input, hidden and output
                layers (presumably ``Fiber`` instances -- confirm at caller).
            out_dim: Not used by this variant; kept for signature
                compatibility with the other ``_build_gcn`` builders.

        Returns:
            Tuple ``(Gblock, Finblock)`` of ``nn.ModuleList``. ``Gblock`` holds
            ``self.num_layers`` (``GSE3Res``, ``GNormSE3``) pairs followed by a
            final ``GConvSE3`` to the ``'out'`` fiber; ``Finblock`` holds a
            single ``OutEncoder``.
        """
        # Equivariant layers
        Gblock = []
        fin = fibers['in']
        for _ in range(self.num_layers):
            Gblock.append(
                GSE3Res(fin,
                        fibers['mid'],
                        edge_dim=self.edge_dim,
                        div=self.div,
                        n_heads=self.n_heads))
            Gblock.append(GNormSE3(fibers['mid']))
            fin = fibers['mid']
        # Use `fin` so that num_layers == 0 still yields a conv whose input
        # fiber matches fibers['in']; otherwise fin is fibers['mid'].
        Gblock.append(
            GConvSE3(fin,
                     fibers['out'],
                     self_interaction=True,
                     edge_dim=self.edge_dim))

        Finblock = [OutEncoder()]

        return nn.ModuleList(Gblock), nn.ModuleList(Finblock)