Example #1
0
def LinearBNact(chin, chout, act='swish', bn=True):
    """assumes that the inputs to the net are shape (bs,n,mc_samples,c)"""
    assert act in ('relu', 'swish'), f"unknown activation type {act}"
    normlayer = MaskBatchNormNd(chout)
    return nn.Sequential(Pass(nn.Linear(chin, chout), dim=1),
                         normlayer if bn else nn.Sequential(),
                         Pass(Swish() if act == 'swish' else nn.ReLU(), dim=1))
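A minimal usage sketch for the block above (an illustration added here, not part of the original source). It assumes the LieConv-style Pass wrapper, which applies its module to element dim of an input tuple such as (coords, values, mask), and it sets bn=False so the masked batch-norm contract is not exercised; the shapes and tuple layout are assumptions.

import torch

bs, n, mc_samples, c_in, c_out = 2, 16, 8, 12, 24
values = torch.randn(bs, n, mc_samples, c_in)            # features of shape (bs, n, mc_samples, c)
mask = torch.ones(bs, n, mc_samples, dtype=torch.bool)   # validity mask for padded points

block = LinearBNact(c_in, c_out, act='swish', bn=False)
# Pass(..., dim=1) transforms only the values entry of the tuple; coords and mask pass through.
_, out_values, _ = block((None, values, mask))
assert out_values.shape == (bs, n, mc_samples, c_out)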
Example #2
0
def MultiheadLinearBNact(c_in, c_out, n_heads, act="swish", bn=True):
    """??(from LieConv - not sure it does assume) assumes that the inputs to the net are shape (bs,n,mc_samples,c)"""
    assert act in ("relu", "swish", "softplus"), f"unknown activation type {act}"
    normlayer = MaskBatchNormNd(c_out)
    return nn.Sequential(
        OrderedDict(
            [
                ("linear", Pass(MultiheadLinear(n_heads, c_in, c_out), dim=1)),
                ("norm", normlayer if bn else nn.Sequential()),
                ("activation", Pass(activation_fn[act](), dim=1)),
            ]
        )
    )
Example #3
0
def LinearBNact(chin, chout, act="swish", bn=True):
    """assumes that the inputs to the net are shape (bs,n,mc_samples,c)"""
    assert act in ("relu", "swish", "softplus"), f"unknown activation type {act}"
    normlayer = MaskBatchNormNd(chout)
    return nn.Sequential(
        OrderedDict(
            [
                ("linear", Pass(nn.Linear(chin, chout), dim=1)),
                ("norm", normlayer if bn else nn.Sequential()),
                ("activation", Pass(activation_fn[act](), dim=1)),
            ]
        )
    )
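Both Example #2 and Example #3 reference an activation_fn mapping that is not defined in these snippets; presumably it is a module-level dict from activation name to module class, like the one defined locally in Example #12:

# Assumed module-level mapping (mirrors the local dict in Example #12).
activation_fn = {
    "swish": Swish,
    "relu": nn.ReLU,
    "softplus": nn.Softplus,
}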
Example #4
0
    def build_net(self, dim_input, dim_output, ds_frac=1, k=1536, nbhd=np.inf,
                  act="swish", bn=True, num_layers=6, pool=True, liftsamples=1,
                  fill=1 / 4, group=SE3, knn=False, cache=False, dropout=0,
                  **kwargs):
        """
        Arguments:
            dim_input: number of input channels: 1 for MNIST, 3 for RGB images, other
                for non images
            ds_frac: total downsampling to perform throughout the layers of the
                net. In (0,1)
            k: channel width for the network. Can be int (same for all) or array
                to specify individually.
            nbhd: number of samples to use for Monte Carlo estimation (p)
            act:
            bn: whether or not to use batch normalization. Recommended in al
                cases except dynamical systems.
            num_layers: number of BottleNeck Block layers in the network
            pool:
            liftsamples: number of samples to use in lifting. 1 for all groups
                with trivial stabilizer. Otherwise 2+
            fill: specifies the fraction of the input which is included in local
                neighborhood. (can be array to specify a different value for
                each layer)
            group: group to be equivariant to
            knn:
            cache:
            dropout: dropout probability for fully connected layers
        """
        if isinstance(fill, (float, int)):
            fill = [fill] * num_layers
        if isinstance(k, int):
            k = [k] * (num_layers + 1)
        conv = lambda ki, ko, fill: LieConv(
            ki, ko, mc_samples=nbhd, ds_frac=ds_frac, bn=bn, act=act, mean=True,
            group=group, fill=fill, cache=cache, knn=knn)
        layers = nn.ModuleList([
            Pass(nn.Linear(dim_input, k[0]), dim=1),
            *[LieConvBottleBlock(k[i], k[i + 1], conv, bn=bn, act=act,
                                 fill=fill[i],
                                 dropout=dropout) for i in range(num_layers)],
            Pass(nn.ReLU(), dim=1),
            MaskBatchNormNd(k[-1]) if bn else nn.Sequential(),
            Pass(nn.Dropout(p=dropout), dim=1) if dropout else nn.Sequential(),
            GlobalPool(mean=True) if pool else Expression(lambda x: x[1]),
            nn.Linear(k[-1], dim_output)
        ])
        self.group = group
        self.liftsamples = liftsamples

        return layers
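A hypothetical invocation of the method above, purely to illustrate the per-layer overrides described in the docstring; model stands for an instance of the enclosing class, which is not shown in this snippet.

# Hypothetical call: k takes num_layers + 1 widths, fill one fraction per block.
layers = model.build_net(
    dim_input=3, dim_output=10, num_layers=6, group=SE3, act="swish", bn=True,
    k=[512, 512, 768, 768, 1024, 1024, 1536],
    fill=[1 / 2, 1 / 2, 1 / 4, 1 / 4, 1 / 8, 1 / 8],
    dropout=0.1,
)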
Example #5
0
 def __init__(self,
              chin,
              ds_frac=1,
              num_outputs=1,
              k=1536,
              nbhd=np.inf,
              act="swish",
              bn=True,
              num_layers=6,
              mean=True,
              per_point=True,
              pool=True,
              liftsamples=1,
              fill=1 / 32,
              group=SE3,
              knn=False,
              cache=False,
              **kwargs):
     super().__init__()
     if isinstance(fill, (float, int)):
         fill = [fill] * num_layers
     if isinstance(k, int):
         k = [k] * (num_layers + 1)
     conv = lambda ki, ko, fill: LieConv(ki,
                                         ko,
                                         nbhd=nbhd,
                                         ds_frac=ds_frac,
                                         bn=bn,
                                         act=act,
                                         mean=mean,
                                         group=group,
                                         fill=fill,
                                         cache=cache,
                                         knn=knn,
                                         **kwargs)
     self.net = nn.Sequential(
         Pass(nn.Linear(chin, k[0]), dim=1),  #embedding layer
         *[
             BottleBlock(k[i], k[i + 1], conv, bn=bn, act=act, fill=fill[i])
             for i in range(num_layers)
         ],
         #Pass(nn.Linear(k[-1],k[-1]//2),dim=1),
         MaskBatchNormNd(k[-1]) if bn else nn.Sequential(),
         Pass(Swish() if act == 'swish' else nn.ReLU(), dim=1),
         Pass(nn.Linear(k[-1], num_outputs), dim=1),
         GlobalPool(mean=mean) if pool else Expression(lambda x: x[1]),
     )
     self.liftsamples = liftsamples
     self.per_point = per_point
     self.group = group
Example #6
0
    def build_net(self,
                  dim_input,
                  dim_output=1,
                  k=12,
                  nbhd=0,
                  dropout=0.0,
                  num_layers=6,
                  fourier_features=16,
                  norm_coords=True,
                  norm_feats=False,
                  thin_mlps=False,
                  **kwargs):
        m_dim = 12
        layer_class = ThinEGNNLayer if thin_mlps else EGNNLayer
        egnn = lambda: layer_class(dim=k,
                                   m_dim=m_dim,
                                   norm_coors=norm_coords,
                                   norm_feats=norm_feats,
                                   dropout=dropout,
                                   fourier_features=fourier_features,
                                   num_nearest_neighbors=nbhd,
                                   init_eps=1e-2)

        return nn.Sequential(Pass(nn.Linear(dim_input, k), dim=1),
                             *[EGNNPass(egnn()) for _ in range(num_layers)],
                             GlobalPool(mean=True), nn.Linear(k, dim_output))
Example #7
0
    def __init__(self, chin, chout, conv, bn=False, act='swish', fill=None,
                 dropout=0):
        super().__init__()
        assert chin <= chout, f"unsupported channels chin={chin}, " \
                              f"chout={chout}. No upsampling atm."
        nonlinearity = Swish if act == 'swish' else nn.ReLU
        if fill is not None:
            self.conv = conv(chin // 4, chout // 4, fill=fill)
        else:
            self.conv = conv(chin // 4, chout // 4)

        self.net = nn.Sequential(
            Pass(nonlinearity(), dim=1),
            MaskBatchNormNd(chin) if bn else nn.Sequential(),
            Pass(nn.Dropout(dropout) if dropout else nn.Sequential()),
            Pass(nn.Linear(chin, chin // 4), dim=1),
            Pass(nonlinearity(), dim=1),
            MaskBatchNormNd(chin // 4) if bn else nn.Sequential(),
            self.conv,
            Pass(nonlinearity(), dim=1),
            MaskBatchNormNd(chout // 4) if bn else nn.Sequential(),
            Pass(nn.Dropout(dropout) if dropout else nn.Sequential()),
            Pass(nn.Linear(chout // 4, chout), dim=1),
        )
        self.chin = chin
Example #8
0
 def __init__(self, chin, chout, conv, bn=False, act='swish', fill=None):
     super().__init__()
     assert chin <= chout, f"unsupported channels chin={chin}, chout={chout}"
     nonlinearity = Swish if act == 'swish' else nn.ReLU
     self.conv = (conv(chin // 4, chout // 4, fill=fill)
                  if fill is not None else conv(chin // 4, chout // 4))
     self.net = nn.Sequential(
         MaskBatchNormNd(chin) if bn else nn.Sequential(),
         Pass(nonlinearity(), dim=1),
         Pass(nn.Linear(chin, chin // 4), dim=1),
         MaskBatchNormNd(chin // 4) if bn else nn.Sequential(),
         Pass(nonlinearity(), dim=1),
         self.conv,
         MaskBatchNormNd(chout // 4) if bn else nn.Sequential(),
         Pass(nonlinearity(), dim=1),
         Pass(nn.Linear(chout // 4, chout), dim=1),
     )
     self.chin = chin
Example #9
0
def LieConvBNrelu(in_channels, out_channels, bn=True, act='swish', **kwargs):
    return nn.Sequential(
        LieConv(in_channels, out_channels, bn=bn, **kwargs),
        MaskBatchNormNd(out_channels) if bn else nn.Sequential(),
        Pass(Swish() if act == 'swish' else nn.ReLU(), dim=1))
Example #10
0
    def __init__(
        self,
        chin,
        ds_frac=1,
        num_outputs=1,
        k=1536,
        nbhd=np.inf,
        act="swish",
        bn=True,
        num_layers=6,
        mean=True,
        per_point=True,
        pool=True,
        liftsamples=1,
        fill=1 / 4,
        group=SE3,
        knn=False,
        cache=False,
        lie_algebra_nonlinearity=None,
        **kwargs,
    ):
        super().__init__()
        if isinstance(fill, (float, int)):
            fill = [fill] * num_layers
        if isinstance(k, int):
            k = [k] * (num_layers + 1)
        conv = lambda ki, ko, fill: LieConv(
            ki,
            ko,
            mc_samples=nbhd,
            ds_frac=ds_frac,
            bn=bn,
            act=act,
            mean=mean,
            group=group,
            fill=fill,
            cache=cache,
            knn=knn,
            **kwargs,
        )
        self.net = nn.Sequential(
            Pass(nn.Linear(chin, k[0]), dim=1),  # embedding layer
            *[
                BottleBlock(k[i], k[i + 1], conv, bn=bn, act=act, fill=fill[i])
                for i in range(num_layers)
            ],
            MaskBatchNormNd(k[-1]) if bn else nn.Sequential(),
            Pass(Swish() if act == "swish" else nn.ReLU(), dim=1),
            Pass(nn.Linear(k[-1], num_outputs), dim=1),
            GlobalPool(mean=mean) if pool else Expression(lambda x: x[1]),
        )
        self.liftsamples = liftsamples
        self.per_point = per_point
        self.group = group

        self.lie_algebra_nonlinearity = lie_algebra_nonlinearity
        if lie_algebra_nonlinearity is not None:
            if lie_algebra_nonlinearity == "tanh":
                self.lie_algebra_nonlinearity = nn.Tanh()
            else:
                raise ValueError(
                    f"{lie_algebra_nonlinearity} is not a supported nonlinearity"
                )
Example #11
0
def pConvBNrelu(in_channels, out_channels, bn=True, act="swish", **kwargs):
    return nn.Sequential(
        PointConv(in_channels, out_channels, bn=bn, **kwargs),
        MaskBatchNormNd(out_channels) if bn else nn.Sequential(),
        Pass(Swish() if act == "swish" else nn.ReLU(), dim=1),
    )
Example #12
0
    def __init__(
        self,
        dim_input,
        dim_output,
        dim_hidden,
        num_layers,
        num_heads,
        global_pool=True,
        global_pool_mean=True,
        group=SE3(0.2),
        liftsamples=1,
        block_norm="layer_pre",
        output_norm="none",
        kernel_norm="none",
        kernel_type="mlp",
        kernel_dim=16,
        kernel_act="swish",
        mc_samples=0,
        fill=1.0,
        architecture="model_1",
        attention_fn="softmax",  # softmax or dot product? XXX: TODO: "dot product" is used to describe both the attention weights being non-softmax (non-local attention paper) and the feature kernel. should fix terminology
        feature_embed_dim=None,
        max_sample_norm=None,
        lie_algebra_nonlinearity=None,
    ):
        super().__init__()

        if isinstance(dim_hidden, int):
            dim_hidden = [dim_hidden] * (num_layers + 1)

        if isinstance(num_heads, int):
            num_heads = [num_heads] * num_layers

        attention_block = lambda dim, n_head: EquivariantTransformerBlock(
            dim,
            n_head,
            group,
            block_norm=block_norm,
            kernel_norm=kernel_norm,
            kernel_type=kernel_type,
            kernel_dim=kernel_dim,
            kernel_act=kernel_act,
            mc_samples=mc_samples,
            fill=fill,
            attention_fn=attention_fn,
            feature_embed_dim=feature_embed_dim,
        )

        activation_fn = {
            "swish": Swish,
            "relu": nn.ReLU,
            "softplus": nn.Softplus,
        }

        if architecture == "model_1":
            if output_norm == "batch":
                norm1 = nn.BatchNorm1d(dim_hidden[-1])
                norm2 = nn.BatchNorm1d(dim_hidden[-1])
                norm3 = nn.BatchNorm1d(dim_hidden[-1])
            elif output_norm == "layer":
                norm1 = nn.LayerNorm(dim_hidden[-1])
                norm2 = nn.LayerNorm(dim_hidden[-1])
                norm3 = nn.LayerNorm(dim_hidden[-1])
            elif output_norm == "none":
                norm1 = nn.Sequential()
                norm2 = nn.Sequential()
                norm3 = nn.Sequential()
            else:
                raise ValueError(f"{output_norm} is not a valid norm type.")

            self.net = nn.Sequential(
                Pass(nn.Linear(dim_input, dim_hidden[0]), dim=1),
                *[
                    attention_block(dim_hidden[i], num_heads[i])
                    for i in range(num_layers)
                ],
                GlobalPool(mean=global_pool_mean)
                if global_pool else Expression(lambda x: x[1]),
                nn.Sequential(
                    norm1,
                    activation_fn[kernel_act](),
                    nn.Linear(dim_hidden[-1], dim_hidden[-1]),
                    norm2,
                    activation_fn[kernel_act](),
                    nn.Linear(dim_hidden[-1], dim_hidden[-1]),
                    norm3,
                    activation_fn[kernel_act](),
                    nn.Linear(dim_hidden[-1], dim_output),
                ),
            )
        elif architecture == "lieconv":
            if output_norm == "batch":
                norm = nn.BatchNorm1d(dim_hidden[-1])
            elif output_norm == "none":
                norm = nn.Sequential()
            else:
                raise ValueError(f"{output_norm} is not a valid norm type.")

            self.net = nn.Sequential(
                Pass(nn.Linear(dim_input, dim_hidden[0]), dim=1),
                *[
                    attention_block(dim_hidden[i], num_heads[i])
                    for i in range(num_layers)
                ],
                nn.Sequential(
                    OrderedDict([
                        # ("norm", Pass(norm, dim=1)),
                        (
                            "activation",
                            Pass(
                                activation_fn[kernel_act](),
                                dim=1,
                            ),
                        ),
                        (
                            "linear",
                            Pass(nn.Linear(dim_hidden[-1], dim_output), dim=1),
                        ),
                    ])),
                GlobalPool(mean=global_pool_mean)
                if global_pool else Expression(lambda x: x[1]),
            )
        else:
            raise ValueError(f"{architecture} is not a valid architecture.")

        self.group = group
        self.liftsamples = liftsamples
        self.max_sample_norm = max_sample_norm

        self.lie_algebra_nonlinearity = lie_algebra_nonlinearity
        if lie_algebra_nonlinearity is not None:
            if lie_algebra_nonlinearity == "tanh":
                self.lie_algebra_nonlinearity = nn.Tanh()
            else:
                raise ValueError(
                    f"{lie_algebra_nonlinearity} is not a supported nonlinearity"
                )
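A hypothetical instantiation of the constructor above, only to show how its arguments fit together; the class name EquivariantTransformer is an assumption, since the enclosing class is not shown in this snippet.

# Assumed class name; every keyword below comes from the signature above.
model = EquivariantTransformer(
    dim_input=6, dim_output=1, dim_hidden=128, num_layers=4, num_heads=8,
    group=SE3(0.2), block_norm="layer_pre", output_norm="none",
    kernel_type="mlp", kernel_dim=16, kernel_act="swish",
    mc_samples=64, fill=1.0, architecture="model_1", attention_fn="softmax",
)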
Example #13
0
    def build_net(self,
                  dim_input,
                  dim_output,
                  dim_hidden,
                  num_layers,
                  num_heads,
                  global_pool_mean=True,
                  group=SE3(0.2),
                  liftsamples=1,
                  block_norm="layer_pre",
                  kernel_norm="none",
                  kernel_type="mlp",
                  kernel_dim=16,
                  kernel_act="swish",
                  nbhd=0,
                  fill=1.0,
                  attention_fn="norm_exp",
                  feature_embed_dim=None,
                  max_sample_norm=None,
                  lie_algebra_nonlinearity=None,
                  **kwargs):

        if isinstance(dim_hidden, int):
            dim_hidden = [dim_hidden] * (num_layers + 1)

        if isinstance(num_heads, int):
            num_heads = [num_heads] * num_layers

        attention_block = lambda dim, n_head: EquivariantTransformerBlock(
            dim,
            n_head,
            group,
            block_norm=block_norm,
            kernel_norm=kernel_norm,
            kernel_type=kernel_type,
            kernel_dim=kernel_dim,
            kernel_act=kernel_act,
            mc_samples=nbhd,
            fill=fill,
            attention_fn=attention_fn,
            feature_embed_dim=feature_embed_dim,
        )

        layers = nn.Sequential(
            Pass(nn.Linear(dim_input, dim_hidden[0]), dim=1), *[
                attention_block(dim_hidden[i], num_heads[i])
                for i in range(num_layers)
            ], GlobalPool(mean=global_pool_mean),
            nn.Linear(dim_hidden[-1], dim_output))

        self.group = group
        self.liftsamples = liftsamples
        self.max_sample_norm = max_sample_norm

        self.lie_algebra_nonlinearity = lie_algebra_nonlinearity
        if lie_algebra_nonlinearity is not None:
            if lie_algebra_nonlinearity == 'tanh':
                self.lie_algebra_nonlinearity = nn.Tanh()
            else:
                raise ValueError(
                    f'{lie_algebra_nonlinearity} is not a supported nonlinearity')

        return layers