def build_net(self, dim_input, dim_output, ds_frac=1, k=1536, nbhd=np.inf,
              act="swish", bn=True, num_layers=6, pool=True, liftsamples=1,
              fill=1 / 4, group=SE3, knn=False, cache=False, dropout=0, **kwargs):
    """
    Arguments:
        dim_input: number of input channels: 1 for MNIST, 3 for RGB images, other values for non-image data
        dim_output: number of output channels (e.g. the number of classes)
        ds_frac: total downsampling to perform throughout the layers of the net, in (0, 1]
        k: channel width for the network. Can be int (same for all) or array to specify individually.
        nbhd: number of samples to use for Monte Carlo estimation (p)
        act: activation function to use throughout the network, e.g. "swish" or "relu"
        bn: whether or not to use batch normalization. Recommended in all cases except dynamical systems.
        num_layers: number of BottleNeck Block layers in the network
        pool: whether to apply global pooling over points before the final linear layer
        liftsamples: number of samples to use in lifting. 1 for all groups with trivial stabilizer. Otherwise 2+
        fill: specifies the fraction of the input which is included in the local neighborhood
            (can be array to specify a different value for each layer)
        group: group to be equivariant to
        knn: whether to use k-nearest-neighbour neighborhoods instead of distance-based sampling
        cache: whether to cache the group lifting computation
        dropout: dropout probability for fully connected layers
    """
    if isinstance(fill, (float, int)):
        fill = [fill] * num_layers
    if isinstance(k, int):
        k = [k] * (num_layers + 1)
    conv = lambda ki, ko, fill: LieConv(
        ki, ko, mc_samples=nbhd, ds_frac=ds_frac, bn=bn, act=act, mean=True,
        group=group, fill=fill, cache=cache, knn=knn)
    layers = nn.ModuleList([
        Pass(nn.Linear(dim_input, k[0]), dim=1),
        *[LieConvBottleBlock(k[i], k[i + 1], conv, bn=bn, act=act,
                             fill=fill[i], dropout=dropout)
          for i in range(num_layers)],
        Pass(nn.ReLU(), dim=1),
        MaskBatchNormNd(k[-1]) if bn else nn.Sequential(),
        Pass(nn.Dropout(p=dropout), dim=1) if dropout else nn.Sequential(),
        GlobalPool(mean=True) if pool else Expression(lambda x: x[1]),
        nn.Linear(k[-1], dim_output),
    ])
    self.group = group
    self.liftsamples = liftsamples
    return layers
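# Usage sketch (illustrative only, not taken from this codebase). Assuming build_net is a
# method of an nn.Module subclass whose inputs follow the usual LieConv (coords, values, mask)
# tuple convention, and that the group object exposes a `lift` method, the returned
# nn.ModuleList could be consumed roughly as follows. The host class name and the forward-pass
# details are assumptions.
#
#     class HypotheticalLieConvClassifier(nn.Module):
#         def __init__(self, **net_kwargs):
#             super().__init__()
#             self.layers = self.build_net(dim_input=1, dim_output=10, **net_kwargs)
#
#         def forward(self, x):
#             # lift (coords, values, mask) to group elements, then apply each layer in order;
#             # GlobalPool and the final nn.Linear reduce the tuple to output logits
#             out = self.group.lift(x, self.liftsamples)
#             for layer in self.layers:
#                 out = layer(out)
#             return out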
def __init__(self, chin, ds_frac=1, num_outputs=1, k=1536, nbhd=np.inf,
             act="swish", bn=True, num_layers=6, mean=True, per_point=True,
             pool=True, liftsamples=1, fill=1 / 32, group=SE3, knn=False,
             cache=False, **kwargs):
    super().__init__()
    if isinstance(fill, (float, int)):
        fill = [fill] * num_layers
    if isinstance(k, int):
        k = [k] * (num_layers + 1)
    conv = lambda ki, ko, fill: LieConv(ki, ko, nbhd=nbhd, ds_frac=ds_frac,
                                        bn=bn, act=act, mean=mean, group=group,
                                        fill=fill, cache=cache, knn=knn, **kwargs)
    self.net = nn.Sequential(
        Pass(nn.Linear(chin, k[0]), dim=1),  # embedding layer
        *[BottleBlock(k[i], k[i + 1], conv, bn=bn, act=act, fill=fill[i])
          for i in range(num_layers)],
        # Pass(nn.Linear(k[-1], k[-1] // 2), dim=1),
        MaskBatchNormNd(k[-1]) if bn else nn.Sequential(),
        Pass(Swish() if act == "swish" else nn.ReLU(), dim=1),
        Pass(nn.Linear(k[-1], num_outputs), dim=1),
        GlobalPool(mean=mean) if pool else Expression(lambda x: x[1]),
    )
    self.liftsamples = liftsamples
    self.per_point = per_point
    self.group = group
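# Configuration sketch (illustrative only). Because `k` and `fill` accept either a scalar or a
# per-layer list, the channel width and neighbourhood fraction can be scheduled per block.
# `LieResNet` is a hypothetical name for the class that owns this __init__; it is not confirmed
# by this section.
#
#     model = LieResNet(chin=3, num_outputs=1, num_layers=4,
#                       k=[256, 256, 512, 512, 512],    # num_layers + 1 widths (incl. embedding)
#                       fill=[1/32, 1/32, 1/16, 1/16],  # one fill fraction per BottleBlock
#                       group=SE3, liftsamples=1)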
def Swish():
    # Swish activation: f(x) = x * sigmoid(x)
    return Expression(lambda x: x * torch.sigmoid(x))
def __init__(
    self,
    chin,
    ds_frac=1,
    num_outputs=1,
    k=1536,
    nbhd=np.inf,
    act="swish",
    bn=True,
    num_layers=6,
    mean=True,
    per_point=True,
    pool=True,
    liftsamples=1,
    fill=1 / 4,
    group=SE3,
    knn=False,
    cache=False,
    lie_algebra_nonlinearity=None,
    **kwargs,
):
    super().__init__()
    if isinstance(fill, (float, int)):
        fill = [fill] * num_layers
    if isinstance(k, int):
        k = [k] * (num_layers + 1)
    conv = lambda ki, ko, fill: LieConv(
        ki, ko, mc_samples=nbhd, ds_frac=ds_frac, bn=bn, act=act, mean=mean,
        group=group, fill=fill, cache=cache, knn=knn, **kwargs,
    )
    self.net = nn.Sequential(
        Pass(nn.Linear(chin, k[0]), dim=1),  # embedding layer
        *[
            BottleBlock(k[i], k[i + 1], conv, bn=bn, act=act, fill=fill[i])
            for i in range(num_layers)
        ],
        MaskBatchNormNd(k[-1]) if bn else nn.Sequential(),
        Pass(Swish() if act == "swish" else nn.ReLU(), dim=1),
        Pass(nn.Linear(k[-1], num_outputs), dim=1),
        GlobalPool(mean=mean) if pool else Expression(lambda x: x[1]),
    )
    self.liftsamples = liftsamples
    self.per_point = per_point
    self.group = group
    self.lie_algebra_nonlinearity = lie_algebra_nonlinearity
    if lie_algebra_nonlinearity is not None:
        if lie_algebra_nonlinearity == "tanh":
            self.lie_algebra_nonlinearity = nn.Tanh()
        else:
            raise ValueError(
                f"{lie_algebra_nonlinearity} is not a supported nonlinearity"
            )
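# Behaviour sketch (illustrative only). This constructor mirrors the previous one but also
# accepts `lie_algebra_nonlinearity`; only "tanh" is supported and is stored as an nn.Tanh()
# module. It is presumably applied to the lifted Lie-algebra coordinates in the forward pass,
# which is not shown in this section.
#
#     lie_algebra_nonlinearity=None    ->  attribute stays None (no squashing)
#     lie_algebra_nonlinearity="tanh"  ->  self.lie_algebra_nonlinearity = nn.Tanh()
#     lie_algebra_nonlinearity="relu"  ->  raises ValueError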
def __init__(
    self,
    dim_input,
    dim_output,
    dim_hidden,
    num_layers,
    num_heads,
    global_pool=True,
    global_pool_mean=True,
    group=SE3(0.2),
    liftsamples=1,
    block_norm="layer_pre",
    output_norm="none",
    kernel_norm="none",
    kernel_type="mlp",
    kernel_dim=16,
    kernel_act="swish",
    mc_samples=0,
    fill=1.0,
    architecture="model_1",
    attention_fn="softmax",  # "softmax" or "dot product"
    # XXX/TODO: "dot product" is used to describe both non-softmax attention weights
    # (non-local attention paper) and the feature kernel; the terminology should be fixed.
    feature_embed_dim=None,
    max_sample_norm=None,
    lie_algebra_nonlinearity=None,
):
    super().__init__()
    if isinstance(dim_hidden, int):
        dim_hidden = [dim_hidden] * (num_layers + 1)
    if isinstance(num_heads, int):
        num_heads = [num_heads] * num_layers
    attention_block = lambda dim, n_head: EquivariantTransformerBlock(
        dim,
        n_head,
        group,
        block_norm=block_norm,
        kernel_norm=kernel_norm,
        kernel_type=kernel_type,
        kernel_dim=kernel_dim,
        kernel_act=kernel_act,
        mc_samples=mc_samples,
        fill=fill,
        attention_fn=attention_fn,
        feature_embed_dim=feature_embed_dim,
    )
    activation_fn = {
        "swish": Swish,
        "relu": nn.ReLU,
        "softplus": nn.Softplus,
    }
    if architecture == "model_1":
        if output_norm == "batch":
            norm1 = nn.BatchNorm1d(dim_hidden[-1])
            norm2 = nn.BatchNorm1d(dim_hidden[-1])
            norm3 = nn.BatchNorm1d(dim_hidden[-1])
        elif output_norm == "layer":
            norm1 = nn.LayerNorm(dim_hidden[-1])
            norm2 = nn.LayerNorm(dim_hidden[-1])
            norm3 = nn.LayerNorm(dim_hidden[-1])
        elif output_norm == "none":
            norm1 = nn.Sequential()
            norm2 = nn.Sequential()
            norm3 = nn.Sequential()
        else:
            raise ValueError(f"{output_norm} is not a valid norm type.")
        self.net = nn.Sequential(
            Pass(nn.Linear(dim_input, dim_hidden[0]), dim=1),
            *[
                attention_block(dim_hidden[i], num_heads[i])
                for i in range(num_layers)
            ],
            GlobalPool(mean=global_pool_mean)
            if global_pool
            else Expression(lambda x: x[1]),
            nn.Sequential(
                norm1,
                activation_fn[kernel_act](),
                nn.Linear(dim_hidden[-1], dim_hidden[-1]),
                norm2,
                activation_fn[kernel_act](),
                nn.Linear(dim_hidden[-1], dim_hidden[-1]),
                norm3,
                activation_fn[kernel_act](),
                nn.Linear(dim_hidden[-1], dim_output),
            ),
        )
    elif architecture == "lieconv":
        if output_norm == "batch":
            norm = nn.BatchNorm1d(dim_hidden[-1])
        elif output_norm == "none":
            norm = nn.Sequential()
        else:
            raise ValueError(f"{output_norm} is not a valid norm type.")
        self.net = nn.Sequential(
            Pass(nn.Linear(dim_input, dim_hidden[0]), dim=1),
            *[
                attention_block(dim_hidden[i], num_heads[i])
                for i in range(num_layers)
            ],
            nn.Sequential(
                OrderedDict([
                    # ("norm", Pass(norm, dim=1)),
                    ("activation", Pass(activation_fn[kernel_act](), dim=1)),
                    ("linear", Pass(nn.Linear(dim_hidden[-1], dim_output), dim=1)),
                ])
            ),
            GlobalPool(mean=global_pool_mean)
            if global_pool
            else Expression(lambda x: x[1]),
        )
    else:
        raise ValueError(f"{architecture} is not a valid architecture.")
    self.group = group
    self.liftsamples = liftsamples
    self.max_sample_norm = max_sample_norm
    self.lie_algebra_nonlinearity = lie_algebra_nonlinearity
    if lie_algebra_nonlinearity is not None:
        if lie_algebra_nonlinearity == "tanh":
            self.lie_algebra_nonlinearity = nn.Tanh()
        else:
            raise ValueError(
                f"{lie_algebra_nonlinearity} is not a supported nonlinearity"
            )
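# Instantiation sketch (illustrative only). `EquivariantTransformer` is used here as a
# hypothetical name for the class that owns this __init__. Scalar values for dim_hidden and
# num_heads are broadcast to per-layer lists exactly as in the isinstance checks above; every
# keyword below appears in the signature.
#
#     model = EquivariantTransformer(dim_input=6, dim_output=1, dim_hidden=128,
#                                    num_layers=4, num_heads=8,
#                                    group=SE3(0.2), liftsamples=1,
#                                    block_norm="layer_pre", output_norm="none",
#                                    kernel_type="mlp", kernel_dim=16, kernel_act="swish",
#                                    mc_samples=0, fill=1.0, architecture="model_1",
#                                    attention_fn="softmax")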
def __init__(
    self,
    feature_dim,
    location_dim,
    n_heads,
    feature_featurisation="dot_product",
    location_featurisation="mlp",
    location_feature_combination="sum",
    normalisation="none",
    hidden_dim=16,
    feature_embed_dim=None,
    activation="swish",
):
    super().__init__()
    if feature_embed_dim is None:
        feature_embed_dim = int(feature_dim / (4 * n_heads))
    print(feature_embed_dim)

    if feature_featurisation == "dot_product":
        self.feature_featurisation = DotProductKernel(
            feature_dim, feature_dim, feature_dim, n_heads)
        featurised_feature_dim = 1
    elif feature_featurisation == "linear_concat":
        self.feature_featurisation = LinearConcatEmbedding(
            int(feature_embed_dim * n_heads / 2), feature_dim, feature_dim, n_heads)
        featurised_feature_dim = feature_embed_dim
    elif feature_featurisation == "linear_concat_linear":
        featurised_feature_dim = feature_embed_dim
        self.feature_featurisation = LinearConcatLinearEmbedding(
            feature_embed_dim * n_heads, feature_dim, feature_dim, n_heads)
    else:
        raise ValueError(
            f"{feature_featurisation} is not a valid feature featurisation"
        )

    if location_featurisation == "mlp":
        self.location_featurisation = nn.Sequential(
            Expression(lambda x: (
                x[0],
                x[1].unsqueeze(-2).repeat(1, 1, 1, n_heads, 1),
                x[2],
            )),
            MultiheadWeightNet(
                location_dim,
                1,
                n_heads,
                hid_dim=hidden_dim,
                act=activation,
                bn=False,
            ),
            Expression(lambda x: x[1].squeeze(-1)),
        )
        featurised_location_dim = 1
    elif location_featurisation == "none":
        self.location_featurisation = Expression(
            lambda x: x[1].unsqueeze(-2).repeat(1, 1, 1, n_heads, 1))
        featurised_location_dim = location_dim
    else:
        raise ValueError(
            f"{location_featurisation} is not a valid location featurisation"
        )

    if location_feature_combination == "sum":
        self.location_feature_combination = Expression(lambda x: x[0] + x[1])
    elif location_feature_combination == "mlp":
        self.location_feature_combination = nn.Sequential(
            Expression(lambda x: (None, torch.cat(x, dim=-1), None)),
            MultiheadMLP(
                featurised_feature_dim + featurised_location_dim,
                hidden_dim,
                1,
                n_heads,
                3,
                activation,
                False,
            ),
            Expression(lambda x: x[1].squeeze(-1)),
        )
    elif location_feature_combination == "multiply":
        self.location_feature_combination = Expression(lambda x: x[0] * x[1])
    else:
        raise ValueError(
            f"{location_feature_combination} is not a valid combination method"
        )

    if normalisation == "none":
        self.normalisation = lambda attention_coeffs, mask: attention_coeffs
    elif normalisation == "softmax":

        def attention_func(attention_coeffs, mask):
            # mask out invalid positions with a large negative value before the softmax
            attention_coeffs = torch.where(
                mask.unsqueeze(-1),
                attention_coeffs,
                torch.tensor(
                    -1e38,
                    dtype=attention_coeffs.dtype,
                    device=attention_coeffs.device,
                ) * torch.ones_like(attention_coeffs),
            )
            return F.softmax(attention_coeffs, dim=2)

        self.normalisation = attention_func
    elif normalisation == "dot_product":

        def attention_func(attention_coeffs, mask):
            # zero out invalid positions and divide by the number of valid neighbours
            attention_coeffs = torch.where(
                mask.unsqueeze(-1),
                attention_coeffs,
                torch.tensor(
                    0.0,
                    dtype=attention_coeffs.dtype,
                    device=attention_coeffs.device,
                ) * torch.ones_like(attention_coeffs),
            )
            normalization = mask.unsqueeze(-1).sum(-2, keepdim=True)
            normalization = torch.clamp(normalization, min=1)
            return attention_coeffs / normalization

        self.normalisation = attention_func
    else:
        raise ValueError(f"{normalisation} is not a valid normalisation")
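# Normalisation sketch (illustrative only, plain torch). The "dot_product" branch above zeroes
# out masked positions and divides by the number of valid neighbours (clamped to at least 1),
# i.e. a mask-aware mean over the key dimension rather than a softmax. The shapes below are
# made up for illustration.
#
#     coeffs = torch.randn(2, 5, 7, 3)      # (batch, query, key, head)
#     mask = torch.rand(2, 5, 7) > 0.3      # valid-neighbour mask
#     coeffs = torch.where(mask.unsqueeze(-1), coeffs, torch.zeros_like(coeffs))
#     counts = mask.unsqueeze(-1).sum(-2, keepdim=True).clamp(min=1)
#     attn = coeffs / counts                # masked mean over the key dimension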