def LinearBNact(chin, chout, act='swish', bn=True):
    """Linear -> (optional) masked batch norm -> activation.

    Assumes that the inputs to the net are shape (bs, n, mc_samples, c);
    `Pass(..., dim=1)` applies the wrapped module to element 1 of the
    (coords, values, mask) tuple flowing through the Sequential.

    Args:
        chin: number of input channels.
        chout: number of output channels.
        act: activation name, one of 'relu' or 'swish'.
        bn: whether to insert MaskBatchNormNd after the linear layer.

    Returns:
        nn.Sequential implementing Linear -> norm -> activation.
    """
    assert act in ('relu', 'swish'), f"unknown activation type {act}"
    # Fix: only construct the norm layer when it is actually used. The
    # original built MaskBatchNormNd unconditionally, allocating affine
    # parameters that were discarded whenever bn=False.
    normlayer = MaskBatchNormNd(chout) if bn else nn.Sequential()
    return nn.Sequential(
        Pass(nn.Linear(chin, chout), dim=1),
        normlayer,
        Pass(Swish() if act == 'swish' else nn.ReLU(), dim=1))
def MultiheadLinearBNact(c_in, c_out, n_heads, act="swish", bn=True):
    """Multi-head linear -> (optional) masked batch norm -> activation.

    ??(from LieConv - not sure it does assume) assumes that the inputs to
    the net are shape (bs, n, mc_samples, c).

    Args:
        c_in: number of input channels per head.
        c_out: number of output channels per head.
        n_heads: number of attention heads for MultiheadLinear.
        act: activation name; one of "relu", "swish", "softplus"
            (looked up in the module-level `activation_fn` mapping).
        bn: whether to insert MaskBatchNormNd after the linear layer.

    Returns:
        nn.Sequential with named submodules "linear", "norm", "activation".
    """
    assert act in ("relu", "swish", "softplus"), f"unknown activation type {act}"
    # Fix: only construct the norm layer when it is actually used; the
    # original allocated MaskBatchNormNd parameters even with bn=False.
    normlayer = MaskBatchNormNd(c_out) if bn else nn.Sequential()
    return nn.Sequential(
        OrderedDict(
            [
                ("linear", Pass(MultiheadLinear(n_heads, c_in, c_out), dim=1)),
                ("norm", normlayer),
                ("activation", Pass(activation_fn[act](), dim=1)),
            ]
        )
    )
def LinearBNact(chin, chout, act="swish", bn=True):
    """Linear -> (optional) masked batch norm -> activation.

    Assumes that the inputs to the net are shape (bs, n, mc_samples, c).

    Args:
        chin: number of input channels.
        chout: number of output channels.
        act: activation name; one of "relu", "swish", "softplus"
            (looked up in the module-level `activation_fn` mapping).
        bn: whether to insert MaskBatchNormNd after the linear layer.

    Returns:
        nn.Sequential with named submodules "linear", "norm", "activation".
    """
    assert act in ("relu", "swish", "softplus"), f"unknown activation type {act}"
    # Fix: only construct the norm layer when it is actually used; the
    # original allocated MaskBatchNormNd parameters even with bn=False.
    normlayer = MaskBatchNormNd(chout) if bn else nn.Sequential()
    return nn.Sequential(
        OrderedDict(
            [
                ("linear", Pass(nn.Linear(chin, chout), dim=1)),
                ("norm", normlayer),
                ("activation", Pass(activation_fn[act](), dim=1)),
            ]
        )
    )
def build_net(self, dim_input, dim_output, ds_frac=1, k=1536, nbhd=np.inf,
              act="swish", bn=True, num_layers=6, pool=True, liftsamples=1,
              fill=1 / 4, group=SE3, knn=False, cache=False, dropout=0,
              **kwargs):
    """Assemble the LieConv network and return it as an nn.ModuleList.

    Arguments:
        dim_input: number of input channels: 1 for MNIST, 3 for RGB images,
            other for non images
        ds_frac: total downsampling to perform throughout the layers of the
            net. In (0,1)
        k: channel width for the network. Can be int (same for all) or array
            to specify individually.
        nbhd: number of samples to use for Monte Carlo estimation (p)
        act:
        bn: whether or not to use batch normalization. Recommended in al
            cases except dynamical systems.
        num_layers: number of BottleNeck Block layers in the network
        pool:
        liftsamples: number of samples to use in lifting. 1 for all groups
            with trivial stabilizer. Otherwise 2+
        fill: specifies the fraction of the input which is included in local
            neighborhood. (can be array to specify a different value for
            each layer)
        group: group to be equivariant to
        knn:
        cache:
        dropout: dropout probability for fully connected layers
    """
    # Broadcast scalar hyperparameters to one value per layer.
    if isinstance(fill, (float, int)):
        fill = [fill] * num_layers
    if isinstance(k, int):
        k = [k] * (num_layers + 1)

    def conv(ki, ko, fill):
        # Factory passed into each bottleneck block.
        return LieConv(ki, ko, mc_samples=nbhd, ds_frac=ds_frac, bn=bn,
                       act=act, mean=True, group=group, fill=fill,
                       cache=cache, knn=knn)

    embedding = Pass(nn.Linear(dim_input, k[0]), dim=1)
    blocks = [
        LieConvBottleBlock(k[i], k[i + 1], conv, bn=bn, act=act,
                           fill=fill[i], dropout=dropout)
        for i in range(num_layers)
    ]
    head = [
        Pass(nn.ReLU(), dim=1),
        MaskBatchNormNd(k[-1]) if bn else nn.Sequential(),
        Pass(nn.Dropout(p=dropout), dim=1) if dropout else nn.Sequential(),
        GlobalPool(mean=True) if pool else Expression(lambda x: x[1]),
        nn.Linear(k[-1], dim_output),
    ]
    self.group = group
    self.liftsamples = liftsamples
    return nn.ModuleList([embedding, *blocks, *head])
def __init__(self, chin, ds_frac=1, num_outputs=1, k=1536, nbhd=np.inf,
             act="swish", bn=True, num_layers=6, mean=True, per_point=True,
             pool=True, liftsamples=1, fill=1 / 32, group=SE3, knn=False,
             cache=False, **kwargs):
    """LieConv network: linear embedding, bottleneck blocks, linear readout,
    optional global pooling. Inputs are (coords, values, mask) tuples."""
    super().__init__()
    # Broadcast scalar hyperparameters to one value per layer.
    if isinstance(fill, (float, int)):
        fill = [fill] * num_layers
    if isinstance(k, int):
        k = [k] * (num_layers + 1)

    def conv(ki, ko, fill):
        # Factory handed to each BottleBlock.
        return LieConv(ki, ko, nbhd=nbhd, ds_frac=ds_frac, bn=bn, act=act,
                       mean=mean, group=group, fill=fill, cache=cache,
                       knn=knn, **kwargs)

    blocks = [BottleBlock(k[i], k[i + 1], conv, bn=bn, act=act, fill=fill[i])
              for i in range(num_layers)]
    self.net = nn.Sequential(
        Pass(nn.Linear(chin, k[0]), dim=1),  # embedding layer
        *blocks,
        #Pass(nn.Linear(k[-1],k[-1]//2),dim=1),
        MaskBatchNormNd(k[-1]) if bn else nn.Sequential(),
        Pass(Swish() if act == 'swish' else nn.ReLU(), dim=1),
        Pass(nn.Linear(k[-1], num_outputs), dim=1),
        GlobalPool(mean=mean) if pool else Expression(lambda x: x[1]),
    )
    self.liftsamples = liftsamples
    self.per_point = per_point
    self.group = group
def build_net(self, dim_input, dim_output=1, k=12, nbhd=0, dropout=0.0,
              num_layers=6, fourier_features=16, norm_coords=True,
              norm_feats=False, thin_mlps=False, **kwargs):
    """Build an EGNN: linear embedding, `num_layers` EGNN layers wrapped in
    EGNNPass, mean global pool, and a linear readout to `dim_output`."""
    m_dim = 12
    # Thin MLPs swap in a lighter per-layer implementation.
    layer_class = ThinEGNNLayer if thin_mlps else EGNNLayer

    def make_layer():
        return layer_class(dim=k, m_dim=m_dim, norm_coors=norm_coords,
                           norm_feats=norm_feats, dropout=dropout,
                           fourier_features=fourier_features,
                           num_nearest_neighbors=nbhd, init_eps=1e-2)

    modules = [Pass(nn.Linear(dim_input, k), dim=1)]
    modules.extend(EGNNPass(make_layer()) for _ in range(num_layers))
    modules.append(GlobalPool(mean=True))
    modules.append(nn.Linear(k, dim_output))
    return nn.Sequential(*modules)
def __init__(self, chin, chout, conv, bn=False, act='swish', fill=None, dropout=0):
    """Bottleneck block: act/norm/dropout, 4x channel reduction, conv,
    act/norm/dropout, 4x channel expansion.

    Args:
        chin: input channel count; must satisfy chin <= chout.
        chout: output channel count.
        conv: factory `conv(ki, ko[, fill=...])` producing the middle conv.
        bn: whether to insert MaskBatchNormNd layers.
        act: 'swish' for Swish, anything else gives nn.ReLU.
        fill: optional neighborhood fill fraction forwarded to `conv`.
        dropout: dropout probability; 0 disables dropout entirely.
    """
    super().__init__()
    assert chin <= chout, f"unsupported channels chin{chin}, " \
                          f"chout{chout}. No upsampling atm."
    nonlinearity = Swish if act == 'swish' else nn.ReLU
    if fill is not None:
        self.conv = conv(chin // 4, chout // 4, fill=fill)
    else:
        self.conv = conv(chin // 4, chout // 4)
    self.net = nn.Sequential(
        Pass(nonlinearity(), dim=1),
        MaskBatchNormNd(chin) if bn else nn.Sequential(),
        # Fix: dim=1 was previously left to Pass's default on the two
        # dropout entries; made explicit for consistency with every other
        # Pass(...) in this file (assumes Pass defaults to dim=1, as in
        # the LieConv reference implementation — confirm).
        Pass(nn.Dropout(dropout) if dropout else nn.Sequential(), dim=1),
        Pass(nn.Linear(chin, chin // 4), dim=1),
        Pass(nonlinearity(), dim=1),
        MaskBatchNormNd(chin // 4) if bn else nn.Sequential(),
        self.conv,
        Pass(nonlinearity(), dim=1),
        MaskBatchNormNd(chout // 4) if bn else nn.Sequential(),
        Pass(nn.Dropout(dropout) if dropout else nn.Sequential(), dim=1),
        Pass(nn.Linear(chout // 4, chout), dim=1),
    )
    self.chin = chin
def __init__(self, chin, chout, conv, bn=False, act='swish', fill=None):
    """Pre-activation bottleneck: norm/act, linear down to a quarter of the
    channels, conv, norm/act, linear back up to `chout`."""
    super().__init__()
    assert chin <= chout, f"unsupported channels chin{chin}, chout{chout}"
    nonlinearity = Swish if act == 'swish' else nn.ReLU
    mid_in, mid_out = chin // 4, chout // 4
    # Only forward `fill` when the caller supplied one.
    if fill is not None:
        self.conv = conv(mid_in, mid_out, fill=fill)
    else:
        self.conv = conv(mid_in, mid_out)

    def norm(channels):
        return MaskBatchNormNd(channels) if bn else nn.Sequential()

    self.net = nn.Sequential(
        norm(chin),
        Pass(nonlinearity(), dim=1),
        Pass(nn.Linear(chin, mid_in), dim=1),
        norm(mid_in),
        Pass(nonlinearity(), dim=1),
        self.conv,
        norm(mid_out),
        Pass(nonlinearity(), dim=1),
        Pass(nn.Linear(mid_out, chout), dim=1),
    )
    self.chin = chin
def LieConvBNrelu(in_channels, out_channels, bn=True, act='swish', **kwargs):
    """LieConv followed by optional masked batch norm and an activation."""
    activation = Swish() if act == 'swish' else nn.ReLU()
    norm = MaskBatchNormNd(out_channels) if bn else nn.Sequential()
    return nn.Sequential(
        LieConv(in_channels, out_channels, bn=bn, **kwargs),
        norm,
        Pass(activation, dim=1),
    )
def __init__(self, chin, ds_frac=1, num_outputs=1, k=1536, nbhd=np.inf,
             act="swish", bn=True, num_layers=6, mean=True, per_point=True,
             pool=True, liftsamples=1, fill=1 / 4, group=SE3, knn=False,
             cache=False, lie_algebra_nonlinearity=None, **kwargs):
    """LieConv network with an optional nonlinearity applied to the lifted
    Lie-algebra coordinates (only 'tanh' is supported)."""
    super().__init__()
    # Broadcast scalar hyperparameters to one value per layer.
    if isinstance(fill, (float, int)):
        fill = [fill] * num_layers
    if isinstance(k, int):
        k = [k] * (num_layers + 1)

    def conv(ki, ko, fill):
        # Factory handed to each BottleBlock.
        return LieConv(ki, ko, mc_samples=nbhd, ds_frac=ds_frac, bn=bn,
                       act=act, mean=mean, group=group, fill=fill,
                       cache=cache, knn=knn, **kwargs)

    blocks = [BottleBlock(k[i], k[i + 1], conv, bn=bn, act=act, fill=fill[i])
              for i in range(num_layers)]
    self.net = nn.Sequential(
        Pass(nn.Linear(chin, k[0]), dim=1),  # embedding layer
        *blocks,
        MaskBatchNormNd(k[-1]) if bn else nn.Sequential(),
        Pass(Swish() if act == "swish" else nn.ReLU(), dim=1),
        Pass(nn.Linear(k[-1], num_outputs), dim=1),
        GlobalPool(mean=mean) if pool else Expression(lambda x: x[1]),
    )
    self.liftsamples = liftsamples
    self.per_point = per_point
    self.group = group

    self.lie_algebra_nonlinearity = lie_algebra_nonlinearity
    if lie_algebra_nonlinearity is not None:
        if lie_algebra_nonlinearity == "tanh":
            self.lie_algebra_nonlinearity = nn.Tanh()
        else:
            raise ValueError(
                f"{lie_algebra_nonlinearity} is not a supported nonlinearity"
            )
def pConvBNrelu(in_channels, out_channels, bn=True, act="swish", **kwargs):
    """PointConv followed by optional masked batch norm and an activation."""
    activation = Swish() if act == "swish" else nn.ReLU()
    norm = MaskBatchNormNd(out_channels) if bn else nn.Sequential()
    return nn.Sequential(
        PointConv(in_channels, out_channels, bn=bn, **kwargs),
        norm,
        Pass(activation, dim=1),
    )
def __init__(
    self,
    dim_input,          # channels of the raw input features
    dim_output,         # size of the final prediction
    dim_hidden,         # int or per-layer list of hidden widths (len num_layers+1)
    num_layers,         # number of EquivariantTransformerBlock layers
    num_heads,          # int or per-layer list of attention head counts
    global_pool=True,
    global_pool_mean=True,
    group=SE3(0.2),
    liftsamples=1,
    block_norm="layer_pre",
    output_norm="none",
    kernel_norm="none",
    kernel_type="mlp",
    kernel_dim=16,
    kernel_act="swish",
    mc_samples=0,
    fill=1.0,
    architecture="model_1",
    attention_fn="softmax",  # softmax or dot product? XXX: TODO: "dot product" is used to describe both the attention weights being non-softmax (non-local attention paper) and the feature kernel. should fix terminology
    feature_embed_dim=None,
    max_sample_norm=None,
    lie_algebra_nonlinearity=None,
):
    """Equivariant transformer: linear embedding, `num_layers` attention
    blocks, then an output head whose shape depends on `architecture`
    ("model_1": pooled 3-layer MLP head; "lieconv": activation + linear
    readout before pooling). Only 'tanh' is accepted for
    `lie_algebra_nonlinearity`."""
    super().__init__()
    # Broadcast scalar width/head settings to one value per layer.
    if isinstance(dim_hidden, int):
        dim_hidden = [dim_hidden] * (num_layers + 1)
    if isinstance(num_heads, int):
        num_heads = [num_heads] * num_layers
    # Factory for one attention block; all kernel/attention settings are
    # captured from the constructor arguments.
    attention_block = lambda dim, n_head: EquivariantTransformerBlock(
        dim,
        n_head,
        group,
        block_norm=block_norm,
        kernel_norm=kernel_norm,
        kernel_type=kernel_type,
        kernel_dim=kernel_dim,
        kernel_act=kernel_act,
        mc_samples=mc_samples,
        fill=fill,
        attention_fn=attention_fn,
        feature_embed_dim=feature_embed_dim,
    )
    activation_fn = {
        "swish": Swish,
        "relu": nn.ReLU,
        "softplus": nn.Softplus,
    }
    if architecture == "model_1":
        # Three separate norm instances: one before each linear in the
        # post-pooling MLP head.
        if output_norm == "batch":
            norm1 = nn.BatchNorm1d(dim_hidden[-1])
            norm2 = nn.BatchNorm1d(dim_hidden[-1])
            norm3 = nn.BatchNorm1d(dim_hidden[-1])
        elif output_norm == "layer":
            norm1 = nn.LayerNorm(dim_hidden[-1])
            norm2 = nn.LayerNorm(dim_hidden[-1])
            norm3 = nn.LayerNorm(dim_hidden[-1])
        elif output_norm == "none":
            # nn.Sequential() is the idiomatic no-op module here.
            norm1 = nn.Sequential()
            norm2 = nn.Sequential()
            norm3 = nn.Sequential()
        else:
            raise ValueError(f"{output_norm} is not a valid norm type.")
        self.net = nn.Sequential(
            Pass(nn.Linear(dim_input, dim_hidden[0]), dim=1),
            *[
                attention_block(dim_hidden[i], num_heads[i])
                for i in range(num_layers)
            ],
            # Pool over points (or just take the features when pooling is off).
            GlobalPool(mean=global_pool_mean)
            if global_pool
            else Expression(lambda x: x[1]),
            # Post-pooling MLP head: norm -> act -> linear, three times.
            nn.Sequential(
                norm1,
                activation_fn[kernel_act](),
                nn.Linear(dim_hidden[-1], dim_hidden[-1]),
                norm2,
                activation_fn[kernel_act](),
                nn.Linear(dim_hidden[-1], dim_hidden[-1]),
                norm3,
                activation_fn[kernel_act](),
                nn.Linear(dim_hidden[-1], dim_output),
            ),
        )
    elif architecture == "lieconv":
        # LieConv-style head: per-point readout, then pooling last.
        if output_norm == "batch":
            norm = nn.BatchNorm1d(dim_hidden[-1])
        elif output_norm == "none":
            norm = nn.Sequential()
        else:
            raise ValueError(f"{output_norm} is not a valid norm type.")
        self.net = nn.Sequential(
            Pass(nn.Linear(dim_input, dim_hidden[0]), dim=1),
            *[
                attention_block(dim_hidden[i], num_heads[i])
                for i in range(num_layers)
            ],
            nn.Sequential(
                OrderedDict([
                    # ("norm", Pass(norm, dim=1)),
                    (
                        "activation",
                        Pass(
                            activation_fn[kernel_act](),
                            dim=1,
                        ),
                    ),
                    (
                        "linear",
                        Pass(nn.Linear(dim_hidden[-1], dim_output), dim=1),
                    ),
                ])),
            GlobalPool(mean=global_pool_mean)
            if global_pool
            else Expression(lambda x: x[1]),
        )
    else:
        raise ValueError(f"{architecture} is not a valid architecture.")
    self.group = group
    self.liftsamples = liftsamples
    self.max_sample_norm = max_sample_norm
    self.lie_algebra_nonlinearity = lie_algebra_nonlinearity
    if lie_algebra_nonlinearity is not None:
        if lie_algebra_nonlinearity == "tanh":
            # Replace the string flag with the actual module.
            self.lie_algebra_nonlinearity = nn.Tanh()
        else:
            raise ValueError(
                f"{lie_algebra_nonlinearity} is not a supported nonlinearity"
            )
def build_net(self, dim_input, dim_output, dim_hidden, num_layers, num_heads,
              global_pool_mean=True, group=SE3(0.2), liftsamples=1,
              block_norm="layer_pre", kernel_norm="none", kernel_type="mlp",
              kernel_dim=16, kernel_act="swish", nbhd=0, fill=1.0,
              attention_fn="norm_exp", feature_embed_dim=None,
              max_sample_norm=None, lie_algebra_nonlinearity=None, **kwargs):
    """Build an equivariant-transformer: embedding, attention blocks,
    global pool, linear readout. Records group/lift settings on self."""
    # Broadcast scalar width/head settings to one value per layer.
    if isinstance(dim_hidden, int):
        dim_hidden = [dim_hidden] * (num_layers + 1)
    if isinstance(num_heads, int):
        num_heads = [num_heads] * num_layers

    def attention_block(dim, n_head):
        return EquivariantTransformerBlock(
            dim, n_head, group, block_norm=block_norm,
            kernel_norm=kernel_norm, kernel_type=kernel_type,
            kernel_dim=kernel_dim, kernel_act=kernel_act, mc_samples=nbhd,
            fill=fill, attention_fn=attention_fn,
            feature_embed_dim=feature_embed_dim,
        )

    blocks = [attention_block(dim_hidden[i], num_heads[i])
              for i in range(num_layers)]
    layers = nn.Sequential(
        Pass(nn.Linear(dim_input, dim_hidden[0]), dim=1),
        *blocks,
        GlobalPool(mean=global_pool_mean),
        nn.Linear(dim_hidden[-1], dim_output),
    )

    self.group = group
    self.liftsamples = liftsamples
    self.max_sample_norm = max_sample_norm
    self.lie_algebra_nonlinearity = lie_algebra_nonlinearity
    if lie_algebra_nonlinearity is not None:
        if lie_algebra_nonlinearity == 'tanh':
            self.lie_algebra_nonlinearity = nn.Tanh()
        else:
            raise ValueError('{} is not a supported nonlinearity'.format(
                lie_algebra_nonlinearity))
    return layers