Example #1
0
    def __init__(self, chin, chout, conv, bn=False, act='swish', fill=None,
                 dropout=0):
        """Assemble the bottleneck pipeline: squeeze, convolve, expand.

        The stack is: activation, norm, dropout, linear ``chin -> chin // 4``,
        activation, norm, ``conv``, activation, norm, dropout,
        linear ``chout // 4 -> chout``.

        Args:
            chin: Number of input channels.
            chout: Number of output channels; must be at least ``chin``.
            conv: Callable that builds the convolution. It is called as
                ``conv(chin // 4, chout // 4)`` and additionally receives
                ``fill=fill`` when ``fill`` is not None.
            bn: When truthy, a ``MaskBatchNormNd`` is placed after each
                activation; otherwise an empty ``nn.Sequential`` is used.
            act: ``'swish'`` selects ``Swish``; any other value selects
                ``nn.ReLU``.
            fill: Optional value forwarded to ``conv``.
            dropout: Dropout probability; a falsy value disables dropout.

        Raises:
            AssertionError: If ``chin > chout``.
        """
        super().__init__()
        assert chin <= chout, f"unsupported channels chin{chin}, " \
                              f"chout{chout}. No upsampling atm."
        squeezed_in, squeezed_out = chin // 4, chout // 4
        if act == 'swish':
            nonlinearity = Swish
        else:
            nonlinearity = nn.ReLU

        def activation():
            return Pass(nonlinearity(), dim=1)

        def norm(width):
            return MaskBatchNormNd(width) if bn else nn.Sequential()

        def drop():
            return Pass(nn.Dropout(dropout) if dropout else nn.Sequential())

        # Only pass ``fill`` along when the caller actually supplied one.
        conv_extras = {} if fill is None else {'fill': fill}
        self.conv = conv(squeezed_in, squeezed_out, **conv_extras)

        # Entries are created left to right, i.e. in pipeline order.
        stages = [
            activation(), norm(chin), drop(),
            Pass(nn.Linear(chin, squeezed_in), dim=1),
            activation(), norm(squeezed_in),
            self.conv,
            activation(), norm(squeezed_out), drop(),
            Pass(nn.Linear(squeezed_out, chout), dim=1),
        ]
        self.net = nn.Sequential(*stages)
        self.chin = chin
Example #2
0
def LinearBNact(chin, chout, act='swish', bn=True):
    """Build a linear layer, optional masked batch norm, and an activation.

    Assumes that the inputs to the net are shape (bs,n,mc_samples,c).

    Args:
        chin: Input width of the linear layer.
        chout: Output width of the linear layer (and of the norm layer).
        act: ``'swish'`` or ``'relu'``.
        bn: When truthy, a ``MaskBatchNormNd(chout)`` sits between the linear
            layer and the activation; otherwise an empty ``nn.Sequential``.

    Returns:
        An ``nn.Sequential`` holding the three stages in order.

    Raises:
        AssertionError: If ``act`` is neither ``'relu'`` nor ``'swish'``.
    """
    assert act in ('relu', 'swish'), f"unknown activation type {act}"
    # Allocate the norm layer only when it is going to be used. It is still
    # created ahead of the linear layer, as before.
    normlayer = MaskBatchNormNd(chout) if bn else nn.Sequential()
    return nn.Sequential(Pass(nn.Linear(chin, chout), dim=1),
                         normlayer,
                         Pass(Swish() if act == 'swish' else nn.ReLU(), dim=1))
Example #3
0
 def __init__(self, chin, chout, conv, bn=False, act='swish', fill=None):
     """Assemble the bottleneck pipeline: squeeze, convolve, expand.

     The stack is: norm, activation, linear ``chin -> chin // 4``, norm,
     activation, ``conv``, norm, activation, linear ``chout // 4 -> chout``.

     Args:
         chin: Number of input channels.
         chout: Number of output channels; must be at least ``chin``.
         conv: Callable that builds the convolution. It is called as
             ``conv(chin // 4, chout // 4)`` and additionally receives
             ``fill=fill`` when ``fill`` is not None.
         bn: When truthy, a ``MaskBatchNormNd`` precedes each activation;
             otherwise an empty ``nn.Sequential`` is used.
         act: ``'swish'`` selects ``Swish``; any other value selects
             ``nn.ReLU``.
         fill: Optional value forwarded to ``conv``.

     Raises:
         AssertionError: If ``chin > chout``.
     """
     super().__init__()
     assert chin <= chout, f"unsupported channels chin{chin}, chout{chout}"
     squeezed_in = chin // 4
     squeezed_out = chout // 4
     if act == 'swish':
         nonlinearity = Swish
     else:
         nonlinearity = nn.ReLU

     if fill is not None:
         self.conv = conv(squeezed_in, squeezed_out, fill=fill)
     else:
         self.conv = conv(squeezed_in, squeezed_out)

     def norm_then_act(width):
         # Norm is created first, then the activation wrapper.
         normalizer = MaskBatchNormNd(width) if bn else nn.Sequential()
         return normalizer, Pass(nonlinearity(), dim=1)

     self.net = nn.Sequential(
         *norm_then_act(chin),
         Pass(nn.Linear(chin, squeezed_in), dim=1),
         *norm_then_act(squeezed_in),
         self.conv,
         *norm_then_act(squeezed_out),
         Pass(nn.Linear(squeezed_out, chout), dim=1),
     )
     self.chin = chin
Example #4
0
def MultiheadLinearBNact(c_in, c_out, n_heads, act="swish", bn=True):
    """Build a multi-head linear layer, optional masked batch norm, activation.

    NOTE(review): the LieConv helper this was derived from states that inputs
    have shape (bs,n,mc_samples,c); whether that still holds for the
    multi-head variant is unconfirmed.

    Args:
        c_in: Input width handed to ``MultiheadLinear``.
        c_out: Output width handed to ``MultiheadLinear`` and the norm layer.
        n_heads: Number of heads handed to ``MultiheadLinear``.
        act: One of ``"relu"``, ``"swish"``, ``"softplus"``; used as the key
            into ``activation_fn``.
        bn: When truthy, a ``MaskBatchNormNd(c_out)`` is used as the ``norm``
            stage; otherwise an empty ``nn.Sequential``.

    Returns:
        An ``nn.Sequential`` with named stages ``linear``, ``norm`` and
        ``activation``.

    Raises:
        AssertionError: If ``act`` is not one of the supported names.
    """
    assert act in ("relu", "swish", "softplus"), f"unknown activation type {act}"
    # Allocate the norm layer only when it is going to be used. It is still
    # created ahead of the linear layer, as before.
    normlayer = MaskBatchNormNd(c_out) if bn else nn.Sequential()
    return nn.Sequential(
        OrderedDict(
            [
                ("linear", Pass(MultiheadLinear(n_heads, c_in, c_out), dim=1)),
                ("norm", normlayer),
                ("activation", Pass(activation_fn[act](), dim=1)),
            ]
        )
    )
Example #5
0
def LinearBNact(chin, chout, act="swish", bn=True):
    """Build a linear layer, optional masked batch norm, and an activation.

    Assumes that the inputs to the net are shape (bs,n,mc_samples,c).

    Args:
        chin: Input width of the linear layer.
        chout: Output width of the linear layer (and of the norm layer).
        act: One of ``"relu"``, ``"swish"``, ``"softplus"``; used as the key
            into ``activation_fn``.
        bn: When truthy, a ``MaskBatchNormNd(chout)`` is used as the ``norm``
            stage; otherwise an empty ``nn.Sequential``.

    Returns:
        An ``nn.Sequential`` with named stages ``linear``, ``norm`` and
        ``activation``.

    Raises:
        AssertionError: If ``act`` is not one of the supported names.
    """
    assert act in ("relu", "swish", "softplus"), f"unknown activation type {act}"
    # Allocate the norm layer only when it is going to be used. It is still
    # created ahead of the linear layer, as before.
    normlayer = MaskBatchNormNd(chout) if bn else nn.Sequential()
    return nn.Sequential(
        OrderedDict(
            [
                ("linear", Pass(nn.Linear(chin, chout), dim=1)),
                ("norm", normlayer),
                ("activation", Pass(activation_fn[act](), dim=1)),
            ]
        )
    )
Example #6
0
    def build_net(self, dim_input, dim_output, ds_frac=1, k=1536, nbhd=np.inf,
                  act="swish", bn=True, num_layers=6, pool=True, liftsamples=1,
                  fill=1 / 4, group=SE3, knn=False, cache=False, dropout=0,
                  **kwargs):
        """Create the layer list of the network and record group settings.

        Arguments:
            dim_input: number of input channels: 1 for MNIST, 3 for RGB
                images, other for non images
            dim_output: output width of the final linear layer
            ds_frac: total downsampling to perform throughout the layers of
                the net. In (0,1)
            k: channel width for the network. Can be int (same for all) or
                array to specify individually.
            nbhd: number of samples to use for Monte Carlo estimation (p);
                handed to ``LieConv`` as ``mc_samples``
            act: activation name forwarded to ``LieConv`` and to each
                ``LieConvBottleBlock``
            bn: whether or not to use batch normalization. Recommended in
                all cases except dynamical systems.
            num_layers: number of BottleNeck Block layers in the network
            pool: when truthy a ``GlobalPool(mean=True)`` layer is used;
                otherwise element 1 of the layer input is passed through
            liftsamples: number of samples to use in lifting. 1 for all
                groups with trivial stabilizer. Otherwise 2+
            fill: specifies the fraction of the input which is included in
                local neighborhood. (can be array to specify a different
                value for each layer)
            group: group to be equivariant to
            knn: forwarded unchanged to ``LieConv``
            cache: forwarded unchanged to ``LieConv``
            dropout: dropout probability for fully connected layers
            **kwargs: accepted but not used by this method

        Returns:
            An ``nn.ModuleList`` with the embedding layer, the bottleneck
            blocks, and the output head.
        """
        # Broadcast scalar settings to one entry per layer.
        fill = [fill] * num_layers if isinstance(fill, (float, int)) else fill
        k = [k] * (num_layers + 1) if isinstance(k, int) else k

        def conv(ki, ko, fill):
            return LieConv(
                ki, ko, mc_samples=nbhd, ds_frac=ds_frac, bn=bn, act=act,
                mean=True, group=group, fill=fill, cache=cache, knn=knn)

        # Modules are appended in network order.
        stack = [Pass(nn.Linear(dim_input, k[0]), dim=1)]
        for idx in range(num_layers):
            stack.append(
                LieConvBottleBlock(k[idx], k[idx + 1], conv, bn=bn, act=act,
                                   fill=fill[idx], dropout=dropout))
        stack.append(Pass(nn.ReLU(), dim=1))
        stack.append(MaskBatchNormNd(k[-1]) if bn else nn.Sequential())
        if dropout:
            stack.append(Pass(nn.Dropout(p=dropout), dim=1))
        else:
            stack.append(nn.Sequential())
        if pool:
            stack.append(GlobalPool(mean=True))
        else:
            stack.append(Expression(lambda x: x[1]))
        stack.append(nn.Linear(k[-1], dim_output))
        layers = nn.ModuleList(stack)

        self.group = group
        self.liftsamples = liftsamples

        return layers
Example #7
0
 def __init__(self, chin, ds_frac=1, num_outputs=1, k=1536, nbhd=np.inf,
              act="swish", bn=True, num_layers=6, mean=True, per_point=True,
              pool=True, liftsamples=1, fill=1 / 32, group=SE3, knn=False,
              cache=False, **kwargs):
     """Build the network: embedding, bottleneck blocks, and output head.

     Args:
         chin: Input width of the embedding linear layer.
         ds_frac: Forwarded to every ``LieConv``.
         num_outputs: Output width of the final linear layer.
         k: Channel widths. An int is repeated ``num_layers + 1`` times;
             otherwise it is indexed per layer.
         nbhd: Forwarded to ``LieConv`` as ``nbhd``.
         act: Forwarded to ``LieConv`` and ``BottleBlock``. For the head,
             ``'swish'`` selects ``Swish`` and anything else ``nn.ReLU``.
         bn: Forwarded to ``LieConv`` and ``BottleBlock``; also toggles the
             ``MaskBatchNormNd`` in the head.
         num_layers: Number of ``BottleBlock`` layers.
         mean: Forwarded to ``LieConv`` and to ``GlobalPool``.
         per_point: Stored on the instance; not otherwise used here.
         pool: When truthy the net ends with ``GlobalPool(mean=mean)``;
             otherwise element 1 of the layer input is passed through.
         liftsamples: Stored on the instance.
         fill: Per-layer fill value. A float or int is repeated
             ``num_layers`` times.
         group: Forwarded to ``LieConv`` and stored on the instance.
         knn: Forwarded to ``LieConv``.
         cache: Forwarded to ``LieConv``.
         **kwargs: Extra keyword arguments forwarded to ``LieConv``.
     """
     super().__init__()
     # Broadcast scalar settings to one entry per layer.
     fill = [fill] * num_layers if isinstance(fill, (float, int)) else fill
     k = [k] * (num_layers + 1) if isinstance(k, int) else k

     def conv(ki, ko, fill):
         return LieConv(ki, ko, nbhd=nbhd, ds_frac=ds_frac, bn=bn, act=act,
                        mean=mean, group=group, fill=fill, cache=cache,
                        knn=knn, **kwargs)

     # Modules are appended in network order, starting with the embedding.
     stack = [Pass(nn.Linear(chin, k[0]), dim=1)]
     for idx in range(num_layers):
         stack.append(
             BottleBlock(k[idx], k[idx + 1], conv, bn=bn, act=act,
                         fill=fill[idx]))
     stack.append(MaskBatchNormNd(k[-1]) if bn else nn.Sequential())
     if act == 'swish':
         head_activation = Swish()
     else:
         head_activation = nn.ReLU()
     stack.append(Pass(head_activation, dim=1))
     stack.append(Pass(nn.Linear(k[-1], num_outputs), dim=1))
     if pool:
         stack.append(GlobalPool(mean=mean))
     else:
         stack.append(Expression(lambda x: x[1]))
     self.net = nn.Sequential(*stack)

     self.liftsamples = liftsamples
     self.per_point = per_point
     self.group = group
Example #8
0
def LieConvBNrelu(in_channels, out_channels, bn=True, act='swish', **kwargs):
    """Chain a ``LieConv`` with optional masked batch norm and an activation.

    Args:
        in_channels: Input channel count for ``LieConv``.
        out_channels: Output channel count for ``LieConv`` and the norm layer.
        bn: Forwarded to ``LieConv``; when truthy a
            ``MaskBatchNormNd(out_channels)`` follows the convolution,
            otherwise an empty ``nn.Sequential``.
        act: ``'swish'`` selects ``Swish``; any other value selects
            ``nn.ReLU``.
        **kwargs: Extra keyword arguments forwarded to ``LieConv``.

    Returns:
        An ``nn.Sequential`` of the convolution, norm and activation stages.
    """
    conv_stage = LieConv(in_channels, out_channels, bn=bn, **kwargs)
    if bn:
        norm_stage = MaskBatchNormNd(out_channels)
    else:
        norm_stage = nn.Sequential()
    if act == 'swish':
        activation = Swish()
    else:
        activation = nn.ReLU()
    return nn.Sequential(conv_stage, norm_stage, Pass(activation, dim=1))
Example #9
0
    def __init__(self, chin, ds_frac=1, num_outputs=1, k=1536, nbhd=np.inf,
                 act="swish", bn=True, num_layers=6, mean=True,
                 per_point=True, pool=True, liftsamples=1, fill=1 / 4,
                 group=SE3, knn=False, cache=False,
                 lie_algebra_nonlinearity=None, **kwargs):
        """Build the network: embedding, bottleneck blocks, and output head.

        Args:
            chin: Input width of the embedding linear layer.
            ds_frac: Forwarded to every ``LieConv``.
            num_outputs: Output width of the final linear layer.
            k: Channel widths. An int is repeated ``num_layers + 1`` times;
                otherwise it is indexed per layer.
            nbhd: Forwarded to ``LieConv`` as ``mc_samples``.
            act: Forwarded to ``LieConv`` and ``BottleBlock``. For the head,
                ``"swish"`` selects ``Swish`` and anything else ``nn.ReLU``.
            bn: Forwarded to ``LieConv`` and ``BottleBlock``; also toggles
                the ``MaskBatchNormNd`` in the head.
            num_layers: Number of ``BottleBlock`` layers.
            mean: Forwarded to ``LieConv`` and to ``GlobalPool``.
            per_point: Stored on the instance; not otherwise used here.
            pool: When truthy the net ends with ``GlobalPool(mean=mean)``;
                otherwise element 1 of the layer input is passed through.
            liftsamples: Stored on the instance.
            fill: Per-layer fill value. A float or int is repeated
                ``num_layers`` times.
            group: Forwarded to ``LieConv`` and stored on the instance.
            knn: Forwarded to ``LieConv``.
            cache: Forwarded to ``LieConv``.
            lie_algebra_nonlinearity: ``None`` or ``"tanh"``. ``"tanh"`` is
                replaced by an ``nn.Tanh`` module on the instance.
            **kwargs: Extra keyword arguments forwarded to ``LieConv``.

        Raises:
            ValueError: If ``lie_algebra_nonlinearity`` is neither ``None``
                nor ``"tanh"``.
        """
        super().__init__()
        # Broadcast scalar settings to one entry per layer.
        fill = [fill] * num_layers if isinstance(fill, (float, int)) else fill
        k = [k] * (num_layers + 1) if isinstance(k, int) else k

        def conv(ki, ko, fill):
            return LieConv(ki, ko, mc_samples=nbhd, ds_frac=ds_frac, bn=bn,
                           act=act, mean=mean, group=group, fill=fill,
                           cache=cache, knn=knn, **kwargs)

        # Modules are appended in network order, starting with the embedding.
        stack = [Pass(nn.Linear(chin, k[0]), dim=1)]
        for idx in range(num_layers):
            stack.append(
                BottleBlock(k[idx], k[idx + 1], conv, bn=bn, act=act,
                            fill=fill[idx]))
        stack.append(MaskBatchNormNd(k[-1]) if bn else nn.Sequential())
        if act == "swish":
            head_activation = Swish()
        else:
            head_activation = nn.ReLU()
        stack.append(Pass(head_activation, dim=1))
        stack.append(Pass(nn.Linear(k[-1], num_outputs), dim=1))
        if pool:
            stack.append(GlobalPool(mean=mean))
        else:
            stack.append(Expression(lambda x: x[1]))
        self.net = nn.Sequential(*stack)

        self.liftsamples = liftsamples
        self.per_point = per_point
        self.group = group

        # Keep the raw setting first; it is swapped for a module below.
        self.lie_algebra_nonlinearity = lie_algebra_nonlinearity
        if lie_algebra_nonlinearity is None:
            return
        if lie_algebra_nonlinearity == "tanh":
            self.lie_algebra_nonlinearity = nn.Tanh()
            return
        raise ValueError(
            f"{lie_algebra_nonlinearity} is not a supported nonlinearity"
        )
Example #10
0
def pConvBNrelu(in_channels, out_channels, bn=True, act="swish", **kwargs):
    """Chain a ``PointConv`` with optional masked batch norm and an activation.

    Args:
        in_channels: Input channel count for ``PointConv``.
        out_channels: Output channel count for ``PointConv`` and the norm
            layer.
        bn: Forwarded to ``PointConv``; when truthy a
            ``MaskBatchNormNd(out_channels)`` follows the convolution,
            otherwise an empty ``nn.Sequential``.
        act: ``"swish"`` selects ``Swish``; any other value selects
            ``nn.ReLU``.
        **kwargs: Extra keyword arguments forwarded to ``PointConv``.

    Returns:
        An ``nn.Sequential`` of the convolution, norm and activation stages.
    """
    conv_stage = PointConv(in_channels, out_channels, bn=bn, **kwargs)
    if bn:
        norm_stage = MaskBatchNormNd(out_channels)
    else:
        norm_stage = nn.Sequential()
    if act == "swish":
        activation = Swish()
    else:
        activation = nn.ReLU()
    return nn.Sequential(conv_stage, norm_stage, Pass(activation, dim=1))
Example #11
0
    def __init__(self, dim, n_heads, group, block_norm="layer_pre",
                 kernel_norm="none", kernel_type="mlp", kernel_dim=16,
                 kernel_act="swish", hidden_dim_factor=1, mc_samples=0,
                 fill=1.0, attention_fn="softmax", feature_embed_dim=None):
        """Create the attention and MLP sub-layers and their residual wiring.

        Two callables are stored on the instance, ``attention_function`` and
        ``mlp_function``. Each takes the layer input ``inpt`` (indexed at
        0, 1 and 2) and returns ``inpt[1]`` plus element 1 of the sub-layer
        output, with normalisation placed according to ``block_norm``.

        Args:
            dim: Feature width used for attention, the MLP and norm layers.
            n_heads: Number of attention heads.
            group: Forwarded to ``EquivairantMultiheadAttention``.
            block_norm: One of ``"none"``, ``"layer_pre"``, ``"layer_post"``,
                ``"batch_pre"``, ``"batch_post"``.
            kernel_norm: ``"batch"`` turns on the ``bn`` flag of the
                attention layer and the last argument of ``MLP``.
            kernel_type: Forwarded to the attention layer.
            kernel_dim: Forwarded to the attention layer.
            kernel_act: Forwarded to the attention layer as ``act`` and to
                ``MLP``.
            hidden_dim_factor: Accepted but not used in this constructor.
            mc_samples: Forwarded to the attention layer.
            fill: Forwarded to the attention layer.
            attention_fn: Forwarded to the attention layer.
            feature_embed_dim: Forwarded to the attention layer.

        Raises:
            ValueError: If ``block_norm`` is not one of the supported names.
        """
        super().__init__()
        batch_kernel_norm = kernel_norm == "batch"
        self.ema = EquivairantMultiheadAttention(
            dim, dim, n_heads, group,
            kernel_type=kernel_type,
            kernel_dim=kernel_dim,
            act=kernel_act,
            bn=batch_kernel_norm,
            mc_samples=mc_samples,
            fill=fill,
            attention_fn=attention_fn,
            feature_embed_dim=feature_embed_dim,
        )

        self.mlp = MLP(dim, dim, dim, 2, kernel_act, batch_kernel_norm)

        if block_norm == "none":
            # Plain residual connections without any normalisation.
            def attend(inpt):
                return inpt[1] + self.ema(inpt)[1]

            def feed_forward(inpt):
                return inpt[1] + self.mlp(inpt)[1]
        elif block_norm == "layer_pre":
            self.ln_ema = nn.LayerNorm(dim)
            self.ln_mlp = nn.LayerNorm(dim)

            # LayerNorm is applied to element 1 before the sub-layer.
            def attend(inpt):
                residual = inpt[1]
                normed = (inpt[0], self.ln_ema(inpt[1]), inpt[2])
                return residual + self.ema(normed)[1]

            def feed_forward(inpt):
                residual = inpt[1]
                normed = (inpt[0], self.ln_mlp(inpt[1]), inpt[2])
                return residual + self.mlp(normed)[1]
        elif block_norm == "layer_post":
            self.ln_ema = nn.LayerNorm(dim)
            self.ln_mlp = nn.LayerNorm(dim)

            # LayerNorm is applied to element 1 of the sub-layer output.
            def attend(inpt):
                residual = inpt[1]
                return residual + self.ln_ema(self.ema(inpt)[1])

            def feed_forward(inpt):
                residual = inpt[1]
                return residual + self.ln_mlp(self.mlp(inpt)[1])
        elif block_norm == "batch_pre":
            self.bn_ema = MaskBatchNormNd(dim)
            self.bn_mlp = MaskBatchNormNd(dim)

            # The masked batch norm receives the whole input before the
            # sub-layer.
            def attend(inpt):
                residual = inpt[1]
                return residual + self.ema(self.bn_ema(inpt))[1]

            def feed_forward(inpt):
                residual = inpt[1]
                return residual + self.mlp(self.bn_mlp(inpt))[1]
        elif block_norm == "batch_post":
            self.bn_ema = MaskBatchNormNd(dim)
            self.bn_mlp = MaskBatchNormNd(dim)

            # The masked batch norm receives the whole sub-layer output.
            def attend(inpt):
                residual = inpt[1]
                return residual + self.bn_ema(self.ema(inpt))[1]

            def feed_forward(inpt):
                residual = inpt[1]
                return residual + self.bn_mlp(self.mlp(inpt))[1]
        else:
            raise ValueError(f"{block_norm} is invalid block norm type.")

        self.attention_function = attend
        self.mlp_function = feed_forward