Esempio n. 1
0
    def build_net(self, dim_input, dim_output, ds_frac=1, k=1536, nbhd=np.inf,
                  act="swish", bn=True, num_layers=6, pool=True, liftsamples=1,
                  fill=1 / 4, group=SE3, knn=False, cache=False, dropout=0,
                  **kwargs):
        """Assemble the layer stack and return it as an ``nn.ModuleList``.

        As a side effect, ``group`` and ``liftsamples`` are stored on ``self``.

        Arguments:
            dim_input: number of input channels: 1 for MNIST, 3 for RGB images,
                other for non images
            dim_output: output width of the final linear layer
            ds_frac: total downsampling to perform throughout the layers of the
                net. In (0,1)
            k: channel width for the network. Can be int (same for all) or array
                to specify individually (indexed from 0 to num_layers).
            nbhd: number of samples to use for Monte Carlo estimation (p);
                handed to LieConv as ``mc_samples``
            act: activation identifier forwarded to LieConv and to every
                LieConvBottleBlock. The activation placed after the last block
                is always ``nn.ReLU``, independent of this value.
            bn: whether or not to use batch normalization. Recommended in all
                cases except dynamical systems.
            num_layers: number of BottleNeck Block layers in the network
            pool: if truthy, ``GlobalPool(mean=True)`` precedes the final linear
                layer; otherwise an ``Expression`` selecting index 1 of its
                input is used instead
            liftsamples: number of samples to use in lifting. 1 for all groups
                with trivial stabilizer. Otherwise 2+
            fill: specifies the fraction of the input which is included in local
                neighborhood. (can be array to specify a different value for
                each layer)
            group: group to be equivariant to
            knn: forwarded unchanged to LieConv
            cache: forwarded unchanged to LieConv
            dropout: dropout probability; forwarded to every LieConvBottleBlock
                and used for the dropout layer after the last block (a falsy
                value replaces that layer with an empty ``nn.Sequential``)
            **kwargs: accepted but not used by this method
        """
        # Scalars are broadcast: one fill per block, one width per block boundary.
        if isinstance(fill, (float, int)):
            fill = [fill for _ in range(num_layers)]
        if isinstance(k, int):
            k = [k for _ in range(num_layers + 1)]

        def conv(ki, ko, fill):
            # Parameter names are kept as-is: the blocks receiving this factory
            # may pass ``fill`` by keyword.
            return LieConv(
                ki, ko, mc_samples=nbhd, ds_frac=ds_frac, bn=bn, act=act,
                mean=True, group=group, fill=fill, cache=cache, knn=knn)

        # Modules are created in the same order in which they appear in the
        # returned list.
        embedding = Pass(nn.Linear(dim_input, k[0]), dim=1)

        blocks = []
        for i in range(num_layers):
            blocks.append(
                LieConvBottleBlock(k[i], k[i + 1], conv, bn=bn, act=act,
                                   fill=fill[i], dropout=dropout))

        final_act = Pass(nn.ReLU(), dim=1)

        if bn:
            final_norm = MaskBatchNormNd(k[-1])
        else:
            final_norm = nn.Sequential()

        if dropout:
            final_dropout = Pass(nn.Dropout(p=dropout), dim=1)
        else:
            final_dropout = nn.Sequential()

        if pool:
            readout = GlobalPool(mean=True)
        else:
            readout = Expression(lambda x: x[1])

        classifier = nn.Linear(k[-1], dim_output)

        layers = nn.ModuleList([
            embedding,
            *blocks,
            final_act,
            final_norm,
            final_dropout,
            readout,
            classifier,
        ])

        self.group = group
        self.liftsamples = liftsamples

        return layers
Esempio n. 2
0
 def __init__(self,
              chin,
              ds_frac=1,
              num_outputs=1,
              k=1536,
              nbhd=np.inf,
              act="swish",
              bn=True,
              num_layers=6,
              mean=True,
              per_point=True,
              pool=True,
              liftsamples=1,
              fill=1 / 32,
              group=SE3,
              knn=False,
              cache=False,
              **kwargs):
     """Build ``self.net``: embedding, bottleneck blocks and output head.

     Arguments:
         chin: input width of the embedding linear layer
         ds_frac: forwarded to every LieConv
         num_outputs: output width of the final linear layer
         k: channel widths; an int is repeated ``num_layers + 1`` times,
             otherwise it is indexed from 0 to ``num_layers``
         nbhd: forwarded to every LieConv as ``nbhd``
         act: forwarded to LieConv and BottleBlock; the head uses ``Swish()``
             when this equals ``'swish'`` and ``nn.ReLU()`` otherwise
         bn: forwarded to LieConv and BottleBlock; when truthy a
             ``MaskBatchNormNd`` is placed after the last block
         num_layers: number of BottleBlock layers
         mean: forwarded to LieConv and to ``GlobalPool``
         per_point: stored on ``self``; not otherwise used here
         pool: if truthy the net ends with ``GlobalPool``; otherwise with an
             ``Expression`` selecting index 1 of its input
         liftsamples: stored on ``self``
         fill: per-block fill value; a float or int is repeated
             ``num_layers`` times
         group: forwarded to LieConv and stored on ``self``
         knn: forwarded to every LieConv
         cache: forwarded to every LieConv
         **kwargs: forwarded to every LieConv
     """
     super().__init__()

     # Broadcast scalar settings to one entry per block / block boundary.
     if isinstance(fill, (float, int)):
         fill = [fill for _ in range(num_layers)]
     if isinstance(k, int):
         k = [k for _ in range(num_layers + 1)]

     def conv(ki, ko, fill):
         # Parameter names are kept as-is: the blocks receiving this factory
         # may pass ``fill`` by keyword.
         return LieConv(ki, ko, nbhd=nbhd, ds_frac=ds_frac, bn=bn, act=act,
                        mean=mean, group=group, fill=fill, cache=cache,
                        knn=knn, **kwargs)

     # Modules are created in the order they appear in the network.
     embedding = Pass(nn.Linear(chin, k[0]), dim=1)

     blocks = []
     for i in range(num_layers):
         blocks.append(
             BottleBlock(k[i], k[i + 1], conv, bn=bn, act=act, fill=fill[i]))

     if bn:
         head_norm = MaskBatchNormNd(k[-1])
     else:
         head_norm = nn.Sequential()

     if act == 'swish':
         head_act = Pass(Swish(), dim=1)
     else:
         head_act = Pass(nn.ReLU(), dim=1)

     head_linear = Pass(nn.Linear(k[-1], num_outputs), dim=1)

     if pool:
         readout = GlobalPool(mean=mean)
     else:
         readout = Expression(lambda x: x[1])

     self.net = nn.Sequential(embedding, *blocks, head_norm, head_act,
                              head_linear, readout)

     self.liftsamples = liftsamples
     self.per_point = per_point
     self.group = group
Esempio n. 3
0
def Swish():
    """Return an ``Expression`` wrapping the swish function ``x * sigmoid(x)``."""

    def _swish(x):
        return x * torch.sigmoid(x)

    return Expression(_swish)
Esempio n. 4
0
    def __init__(
        self,
        chin,
        ds_frac=1,
        num_outputs=1,
        k=1536,
        nbhd=np.inf,
        act="swish",
        bn=True,
        num_layers=6,
        mean=True,
        per_point=True,
        pool=True,
        liftsamples=1,
        fill=1 / 4,
        group=SE3,
        knn=False,
        cache=False,
        lie_algebra_nonlinearity=None,
        **kwargs,
    ):
        """Build ``self.net``: embedding, bottleneck blocks and output head.

        Arguments:
            chin: input width of the embedding linear layer
            ds_frac: forwarded to every LieConv
            num_outputs: output width of the final linear layer
            k: channel widths; an int is repeated ``num_layers + 1`` times,
                otherwise it is indexed from 0 to ``num_layers``
            nbhd: forwarded to every LieConv as ``mc_samples``
            act: forwarded to LieConv and BottleBlock; the head uses
                ``Swish()`` when this equals ``"swish"`` and ``nn.ReLU()``
                otherwise
            bn: forwarded to LieConv and BottleBlock; when truthy a
                ``MaskBatchNormNd`` is placed after the last block
            num_layers: number of BottleBlock layers
            mean: forwarded to LieConv and to ``GlobalPool``
            per_point: stored on ``self``; not otherwise used here
            pool: if truthy the net ends with ``GlobalPool``; otherwise with
                an ``Expression`` selecting index 1 of its input
            liftsamples: stored on ``self``
            fill: per-block fill value; a float or int is repeated
                ``num_layers`` times
            group: forwarded to LieConv and stored on ``self``
            knn: forwarded to every LieConv
            cache: forwarded to every LieConv
            lie_algebra_nonlinearity: ``None`` (stored as is) or ``"tanh"``
                (stored as an ``nn.Tanh`` module)
            **kwargs: forwarded to every LieConv

        Raises:
            ValueError: if ``lie_algebra_nonlinearity`` is neither ``None``
                nor ``"tanh"``.
        """
        super().__init__()

        # Broadcast scalar settings to one entry per block / block boundary.
        if isinstance(fill, (float, int)):
            fill = [fill for _ in range(num_layers)]
        if isinstance(k, int):
            k = [k for _ in range(num_layers + 1)]

        def conv(ki, ko, fill):
            # Parameter names are kept as-is: the blocks receiving this
            # factory may pass ``fill`` by keyword.
            return LieConv(
                ki, ko, mc_samples=nbhd, ds_frac=ds_frac, bn=bn, act=act,
                mean=mean, group=group, fill=fill, cache=cache, knn=knn,
                **kwargs)

        # Modules are created in the order they appear in the network.
        embedding = Pass(nn.Linear(chin, k[0]), dim=1)

        blocks = []
        for i in range(num_layers):
            blocks.append(
                BottleBlock(k[i], k[i + 1], conv, bn=bn, act=act, fill=fill[i]))

        if bn:
            head_norm = MaskBatchNormNd(k[-1])
        else:
            head_norm = nn.Sequential()

        if act == "swish":
            head_act = Pass(Swish(), dim=1)
        else:
            head_act = Pass(nn.ReLU(), dim=1)

        head_linear = Pass(nn.Linear(k[-1], num_outputs), dim=1)

        if pool:
            readout = GlobalPool(mean=mean)
        else:
            readout = Expression(lambda x: x[1])

        self.net = nn.Sequential(embedding, *blocks, head_norm, head_act,
                                 head_linear, readout)

        self.liftsamples = liftsamples
        self.per_point = per_point
        self.group = group

        # The raw value is stored first; "tanh" is then swapped for a module
        # and anything else that is not None is rejected.
        self.lie_algebra_nonlinearity = lie_algebra_nonlinearity
        if lie_algebra_nonlinearity == "tanh":
            self.lie_algebra_nonlinearity = nn.Tanh()
        elif lie_algebra_nonlinearity is not None:
            raise ValueError(
                f"{lie_algebra_nonlinearity} is not a supported nonlinearity"
            )
Esempio n. 5
0
    def __init__(
        self,
        dim_input,
        dim_output,
        dim_hidden,
        num_layers,
        num_heads,
        global_pool=True,
        global_pool_mean=True,
        group=SE3(0.2),
        liftsamples=1,
        block_norm="layer_pre",
        output_norm="none",
        kernel_norm="none",
        kernel_type="mlp",
        kernel_dim=16,
        kernel_act="swish",
        mc_samples=0,
        fill=1.0,
        architecture="model_1",
        # TODO: "dot product" currently names both non-softmax attention
        # weights (non-local attention paper) and the feature kernel; the
        # terminology should be fixed.
        attention_fn="softmax",
        feature_embed_dim=None,
        max_sample_norm=None,
        lie_algebra_nonlinearity=None,
    ):
        """Build ``self.net`` from attention blocks plus an output head.

        Arguments:
            dim_input: input width of the embedding linear layer
            dim_output: output width of the last linear layer
            dim_hidden: hidden widths; an int is repeated ``num_layers + 1``
                times, otherwise it is indexed directly
            num_layers: number of EquivariantTransformerBlock layers
            num_heads: heads per block; an int is repeated ``num_layers``
                times
            global_pool: if truthy a ``GlobalPool`` is used; otherwise an
                ``Expression`` selecting index 1 of its input
            global_pool_mean: passed to ``GlobalPool`` as ``mean``
            group: forwarded to every block and stored on ``self``. The
                default ``SE3(0.2)`` is evaluated once, at definition time.
            liftsamples: stored on ``self``
            block_norm, kernel_norm, kernel_type, kernel_dim, mc_samples,
            fill, attention_fn, feature_embed_dim: forwarded to every
                EquivariantTransformerBlock
            output_norm: ``"batch"``, ``"layer"`` or ``"none"`` for
                ``"model_1"``; ``"batch"`` or ``"none"`` for ``"lieconv"``
                (where the norm is validated and built but not inserted)
            kernel_act: forwarded to every block and also selects the head
                activation: ``"swish"``, ``"relu"`` or ``"softplus"``
            architecture: ``"model_1"`` (pool, then a three-layer head) or
                ``"lieconv"`` (activation + linear head, then pool)
            max_sample_norm: stored on ``self``
            lie_algebra_nonlinearity: ``None`` (stored as is) or ``"tanh"``
                (stored as an ``nn.Tanh`` module)

        Raises:
            ValueError: for an unsupported ``output_norm``, ``architecture``
                or ``lie_algebra_nonlinearity``.
            KeyError: from the activation lookup when ``kernel_act`` is not
                one of the supported keys.
        """
        super().__init__()

        # Broadcast scalar settings to per-layer lists.
        if isinstance(dim_hidden, int):
            dim_hidden = [dim_hidden for _ in range(num_layers + 1)]

        if isinstance(num_heads, int):
            num_heads = [num_heads for _ in range(num_layers)]

        def attention_block(dim, n_head):
            return EquivariantTransformerBlock(
                dim,
                n_head,
                group,
                block_norm=block_norm,
                kernel_norm=kernel_norm,
                kernel_type=kernel_type,
                kernel_dim=kernel_dim,
                kernel_act=kernel_act,
                mc_samples=mc_samples,
                fill=fill,
                attention_fn=attention_fn,
                feature_embed_dim=feature_embed_dim,
            )

        activation_fn = {
            "swish": Swish,
            "relu": nn.ReLU,
            "softplus": nn.Softplus,
        }

        if architecture == "model_1":
            # Choose a factory so the head gets three independent norm layers.
            if output_norm == "batch":
                make_norm = lambda: nn.BatchNorm1d(dim_hidden[-1])
            elif output_norm == "layer":
                make_norm = lambda: nn.LayerNorm(dim_hidden[-1])
            elif output_norm == "none":
                make_norm = nn.Sequential
            else:
                raise ValueError(f"{output_norm} is not a valid norm type.")
            norm1, norm2, norm3 = make_norm(), make_norm(), make_norm()

            # Modules are created in the order they appear in the network.
            stem = Pass(nn.Linear(dim_input, dim_hidden[0]), dim=1)

            blocks = []
            for i in range(num_layers):
                blocks.append(attention_block(dim_hidden[i], num_heads[i]))

            if global_pool:
                pooling = GlobalPool(mean=global_pool_mean)
            else:
                pooling = Expression(lambda x: x[1])

            # Head: three (norm, activation, linear) stages; only the last
            # linear layer maps to dim_output.
            head_layers = []
            stages = (
                (norm1, dim_hidden[-1]),
                (norm2, dim_hidden[-1]),
                (norm3, dim_output),
            )
            for stage_norm, stage_out in stages:
                head_layers.append(stage_norm)
                head_layers.append(activation_fn[kernel_act]())
                head_layers.append(nn.Linear(dim_hidden[-1], stage_out))

            self.net = nn.Sequential(
                stem,
                *blocks,
                pooling,
                nn.Sequential(*head_layers),
            )
        elif architecture == "lieconv":
            # NOTE: `norm` is validated and constructed here but is not
            # inserted into the network below.
            if output_norm == "batch":
                norm = nn.BatchNorm1d(dim_hidden[-1])
            elif output_norm == "none":
                norm = nn.Sequential()
            else:
                raise ValueError(f"{output_norm} is not a valid norm type.")

            stem = Pass(nn.Linear(dim_input, dim_hidden[0]), dim=1)

            blocks = []
            for i in range(num_layers):
                blocks.append(attention_block(dim_hidden[i], num_heads[i]))

            # Named sub-modules; insertion order defines execution order.
            head_parts = OrderedDict()
            head_parts["activation"] = Pass(activation_fn[kernel_act](), dim=1)
            head_parts["linear"] = Pass(
                nn.Linear(dim_hidden[-1], dim_output), dim=1)
            head = nn.Sequential(head_parts)

            if global_pool:
                pooling = GlobalPool(mean=global_pool_mean)
            else:
                pooling = Expression(lambda x: x[1])

            self.net = nn.Sequential(stem, *blocks, head, pooling)
        else:
            raise ValueError(f"{architecture} is not a valid architecture.")

        self.group = group
        self.liftsamples = liftsamples
        self.max_sample_norm = max_sample_norm

        # The raw value is stored first; "tanh" is then swapped for a module
        # and anything else that is not None is rejected.
        self.lie_algebra_nonlinearity = lie_algebra_nonlinearity
        if lie_algebra_nonlinearity == "tanh":
            self.lie_algebra_nonlinearity = nn.Tanh()
        elif lie_algebra_nonlinearity is not None:
            raise ValueError(
                f"{lie_algebra_nonlinearity} is not a supported nonlinearity"
            )
    def __init__(
        self,
        feature_dim,
        location_dim,
        n_heads,
        feature_featurisation="dot_product",
        location_featurisation="mlp",
        location_feature_combination="sum",
        normalisation="none",
        hidden_dim=16,
        feature_embed_dim=None,
        activation="swish",
    ):

        super().__init__()

        if feature_embed_dim is None:
            feature_embed_dim = int(feature_dim / (4 * n_heads))
            print(feature_embed_dim)

        if feature_featurisation == "dot_product":
            self.feature_featurisation = DotProductKernel(
                feature_dim, feature_dim, feature_dim, n_heads)
            featurised_feature_dim = 1
        elif feature_featurisation == "linear_concat":
            self.feature_featurisation = LinearConcatEmbedding(
                int(feature_embed_dim * n_heads / 2), feature_dim, feature_dim,
                n_heads)
            featurised_feature_dim = feature_embed_dim
        elif feature_featurisation == "linear_concat_linear":
            featurised_feature_dim = feature_embed_dim
            self.feature_featurisation = LinearConcatLinearEmbedding(
                feature_embed_dim * n_heads, feature_dim, feature_dim, n_heads)
        else:
            raise ValueError(
                f"{feature_featurisation} is not a valid feature featurisation"
            )

        if location_featurisation == "mlp":
            self.location_featurisation = nn.Sequential(
                Expression(lambda x: (
                    x[0],
                    x[1].unsqueeze(-2).repeat(1, 1, 1, n_heads, 1),
                    x[2],
                )),
                MultiheadWeightNet(
                    location_dim,
                    1,
                    n_heads,
                    hid_dim=hidden_dim,
                    act=activation,
                    bn=False,
                ),
                Expression(lambda x: x[1].squeeze(-1)),
            )
            featurised_location_dim = 1
        elif location_featurisation == "none":
            self.location_featurisation = Expression(
                lambda x: x[1].unsqueeze(-2).repeat(1, 1, 1, n_heads, 1))
            featurised_location_dim = location_dim
        else:
            raise ValueError(
                f"{location_featurisation} is not a valid location featurisation"
            )

        if location_feature_combination == "sum":
            self.location_feature_combination = Expression(
                lambda x: x[0] + x[1])
        elif location_feature_combination == "mlp":
            self.location_feature_combination = nn.Sequential(
                Expression(lambda x: (None, torch.cat(x, dim=-1), None)),
                MultiheadMLP(
                    featurised_feature_dim + featurised_location_dim,
                    hidden_dim,
                    1,
                    n_heads,
                    3,
                    activation,
                    False,
                ),
                Expression(lambda x: x[1].squeeze(-1)),
            )
        elif location_feature_combination == "multiply":
            self.location_feature_combination = Expression(
                lambda x: x[0] * x[1])
        else:
            raise ValueError(
                f"{location_feature_combination} is not a valid combination method"
            )

        if normalisation == "none":
            self.normalisation = lambda attention_coeffs, mask: attention_coeffs
        elif normalisation == "softmax":

            def attention_func(attention_coeffs, mask):
                attention_coeffs = torch.where(
                    mask.unsqueeze(-1),
                    attention_coeffs,
                    torch.tensor(
                        -1e38,
                        dtype=attention_coeffs.dtype,
                        device=attention_coeffs.device,
                    ) * torch.ones_like(attention_coeffs),
                )
                return F.softmax(attention_coeffs, dim=2)

            self.normalisation = attention_func
        elif normalisation == "dot_product":

            def attention_func(attention_coeffs, mask):
                attention_coeffs = torch.where(
                    mask.unsqueeze(-1),
                    attention_coeffs,
                    torch.tensor(
                        0.0,
                        dtype=attention_coeffs.dtype,
                        device=attention_coeffs.device,
                    ) * torch.ones_like(attention_coeffs),
                )

                normalization = mask.unsqueeze(-1).sum(-2, keepdim=True)
                normalization = torch.clamp(normalization, min=1)
                return attention_coeffs / normalization

            self.normalisation = attention_func