def __init__(self, input_frames, output_frames, kernel_size, N):
    """Build a C_N rotation-equivariant U-Net.

    Input and output frames are vector fields (``irrep(1)`` of C_N); the
    hidden layers use regular-representation feature fields.

    Args:
        input_frames: number of input vector-field frames.
        output_frames: number of output vector-field frames.
        kernel_size: spatial kernel size shared by every convolution.
        N: order of the rotation group C_N.
    """
    super(Unet_Rot, self).__init__()
    r2_act = gspaces.Rot2dOnR2(N=N)
    vector_rep = r2_act.irrep(1)
    same_pad = (kernel_size - 1) // 2

    self.feat_type_in = nn.FieldType(r2_act, input_frames * [vector_rep])
    self.feat_type_in_hid = nn.FieldType(r2_act, 32 * [r2_act.regular_repr])
    # Sized for 16 decoder vector fields plus the raw input frames --
    # presumably concatenated in forward (not visible here; TODO confirm).
    self.feat_type_hid_out = nn.FieldType(r2_act, (16 + input_frames) * [vector_rep])
    self.feat_type_out = nn.FieldType(r2_act, output_frames * [vector_rep])

    # Stem: stride-2 lift from vector fields to 32 regular fields.
    self.conv1 = nn.SequentialModule(
        nn.R2Conv(self.feat_type_in, self.feat_type_in_hid,
                  kernel_size=kernel_size, stride=2, padding=same_pad),
        nn.InnerBatchNorm(self.feat_type_in_hid),
        nn.ReLU(self.feat_type_in_hid),
    )

    # Encoder stages: (attribute, in fields, out fields, stride).
    encoder_spec = (
        ("conv2", 32, 64, 1),
        ("conv2_1", 64, 64, 1),
        ("conv3", 64, 128, 2),
        ("conv3_1", 128, 128, 1),
        ("conv4", 128, 256, 2),
        ("conv4_1", 256, 256, 1),
    )
    for attr, c_in, c_out, step in encoder_spec:
        setattr(self, attr, rot_conv2d(c_in, c_out, kernel_size=kernel_size, stride=step, N=N))

    # Decoder: 192 = 64 + 128 and 96 = 32 + 64 suggest skip concatenation
    # with encoder features (assumption; forward is not visible here).
    self.deconv3 = rot_deconv2d(256, 64, N)
    self.deconv2 = rot_deconv2d(192, 32, N)
    self.deconv1 = rot_deconv2d(96, 16, N, last_deconv=True)
    self.output_layer = nn.R2Conv(self.feat_type_hid_out, self.feat_type_out,
                                  kernel_size=kernel_size, padding=same_pad)
def __init__(self, channel_in=3, n_classes=4, rot_n=4):
    """Small C_N-equivariant CNN classifier.

    Args:
        channel_in: number of input channels, each a scalar (trivial) field.
        n_classes: number of output logits.
        rot_n: order N of the rotation group C_N.
    """
    super(SmallE2, self).__init__()
    r2_act = gspaces.Rot2dOnR2(N=rot_n)
    regular = r2_act.regular_repr

    self.feat_type_in = nn.FieldType(r2_act, channel_in * [r2_act.trivial_repr])
    hidden_type = nn.FieldType(r2_act, 8 * [regular])
    final_type = nn.FieldType(r2_act, 2 * [regular])

    # One BN / ReLU instance for the hidden type; which layers reuse them is
    # decided in forward (not visible here).
    self.bn = nn.InnerBatchNorm(hidden_type)
    self.relu = nn.ReLU(hidden_type)

    self.convin = nn.R2Conv(self.feat_type_in, hidden_type, kernel_size=3)
    self.convhid = nn.R2Conv(hidden_type, hidden_type, kernel_size=3)
    self.convout = nn.R2Conv(hidden_type, final_type, kernel_size=3)

    self.avgpool = nn.PointwiseAvgPool(final_type, 3)
    # Group pooling yields one invariant channel per regular field.
    self.invariant_map = nn.GroupPooling(final_type)

    n_invariant = self.invariant_map.out_type.size
    self.lin_in = torch.nn.Linear(n_invariant, 64)
    self.elu = torch.nn.ELU()
    self.lin_out = torch.nn.Linear(64, n_classes)
def __init__(self,
             input_channels,
             hidden_dim,
             kernel_size,
             N  # Group size
             ):
    """Residual block of C_N-equivariant convolutions on regular fields.

    Args:
        input_channels: number of regular input fields.
        hidden_dim: number of regular output fields.
        kernel_size: spatial kernel size (same-padding is applied).
        N: order of the rotation group C_N.
    """
    super(rot_resblock, self).__init__()
    # Specify symmetry transformation
    r2_act = gspaces.Rot2dOnR2(N=N)
    type_in = nn.FieldType(r2_act, input_channels * [r2_act.regular_repr])
    type_hid = nn.FieldType(r2_act, hidden_dim * [r2_act.regular_repr])
    pad = (kernel_size - 1) // 2

    def conv_bn_relu(src, dst):
        # R2Conv -> InnerBatchNorm -> ReLU, built in that order.
        return nn.SequentialModule(
            nn.R2Conv(src, dst, kernel_size=kernel_size, padding=pad),
            nn.InnerBatchNorm(dst),
            nn.ReLU(dst),
        )

    self.layer1 = conv_bn_relu(type_in, type_hid)
    self.layer2 = conv_bn_relu(type_hid, type_hid)
    # Maps input_channels -> hidden_dim fields; presumably used on the skip
    # path when the widths differ (forward not visible here).
    self.upscale = conv_bn_relu(type_in, type_hid)
    self.input_channels = input_channels
    self.hidden_dim = hidden_dim
def __init__(self, in_type, num_classes=10):
    """Invariant classification head: group pooling then two 1x1 convs.

    Args:
        in_type: e2cnn FieldType of the incoming feature map.
        num_classes: number of output logits (trivial fields).
    """
    super(ClassificationHead, self).__init__()
    gspace = in_type.gspace
    self.add_module('gpool', nn.GroupPooling(in_type))

    # "Fully connected" layers realised as bias-free 1x1 PlainConvs acting
    # on trivial (scalar) fields.
    hidden_type = nn.FieldType(gspace, 64 * [gspace.trivial_repr])
    logits_type = nn.FieldType(gspace, num_classes * [gspace.trivial_repr])

    self.add_module(
        'linear1',
        sscnn.e2cnn.PlainConv(self.gpool.out_type, hidden_type,
                              kernel_size=1, padding=0, bias=False))
    self.add_module('relu1', nn.ReLU(hidden_type, inplace=True))
    self.add_module(
        'linear2',
        sscnn.e2cnn.PlainConv(hidden_type, logits_type,
                              kernel_size=1, padding=0, bias=False))
def give_feat_types(self):
    """Build the e2cnn field type of every layer of the network.

    Returns:
        list of ``G_CNN.FieldType`` over ``self.G_act``:

        * element 0 -- the embedding type, the direct sum of the trivial
          representation and ``self.context_rep``;
        * one element per entry of ``self.hidden_reps_ids`` -- each entry is
          turned into a list of representations by
          ``self.give_reps_from_ids`` (per the original notes, an id ``k``
          stands for ``irrep(k)`` and ``-1`` for the regular representation)
          and the field type is their direct sum;
        * last element -- the output type, the direct sum of
          ``self.target_rep`` and the pre-covariance representation returned
          by ``cov_activ_func.get_pre_cov_rep``.
    """
    # Embedding: trivial representation plus the context fiber representation.
    feat_types = [
        G_CNN.FieldType(self.G_act, [self.G_act.trivial_repr, self.context_rep])
    ]
    # Hidden layers: one field type per list of representation ids.
    feat_types.extend(
        G_CNN.FieldType(self.G_act, self.give_reps_from_ids(ids))
        for ids in self.hidden_reps_ids
    )
    # Output: target (mean) representation plus the representation of the
    # tensor that is later mapped to a covariance matrix.
    pre_cov_rep = cov_activ_func.get_pre_cov_rep(self.G_act, self.dim_cov_est)
    feat_types.append(G_CNN.FieldType(self.G_act, [self.target_rep, pre_cov_rep]))
    return feat_types
def __init__(self, in_chan, out_chan, imsize, kernel_size=5, N=8):
    """D_N-equivariant first stage followed by a plain LeNet-style tail.

    Args:
        in_chan: number of input channels, each treated as a scalar field.
        out_chan: number of output logits.
        imsize: input side length; fc1 assumes two 2x down-samplings.
        kernel_size: conv kernel size (same-padding via kernel_size // 2).
        N: number of rotations of the dihedral group D_N.
    """
    super(DNRestrictedLeNet, self).__init__()
    # Spatial side after two 2x down-samplings (one is presumably done in
    # forward, since only pool1 is defined here -- TODO confirm).
    z = imsize // 2 // 2

    self.r2_act = gspaces.FlipRot2dOnR2(N)
    # One scalar field per input channel. Previously this was always a single
    # field, so ``in_chan`` was silently ignored.
    in_type = e2nn.FieldType(self.r2_act, int(in_chan) * [self.r2_act.trivial_repr])
    self.input_type = in_type

    out_type = e2nn.FieldType(self.r2_act, 6 * [self.r2_act.regular_repr])
    self.mask = e2nn.MaskModule(in_type, imsize, margin=1)
    self.conv1 = e2nn.R2Conv(in_type, out_type, kernel_size=kernel_size,
                             padding=kernel_size // 2, bias=False)
    self.relu1 = e2nn.ReLU(out_type, inplace=True)
    self.pool1 = e2nn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)
    # Group pooling turns the 6 regular fields into 6 invariant channels for
    # the ordinary torch Conv2d below.
    self.gpool = e2nn.GroupPooling(out_type)

    self.conv2 = nn.Conv2d(6, 16, kernel_size, padding=kernel_size // 2)
    self.fc1 = nn.Linear(16 * z * z, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, out_chan)
    self.drop = nn.Dropout(p=0.5)

    # dummy parameter for tracking device
    self.dummy = nn.Parameter(torch.empty(0))
def __init__(self, in_type, conv_func):
    """Equivariant regression head producing a single 2D vector field.

    Args:
        in_type: e2cnn FieldType of the incoming features; its gspace must be
            ``Rot2dOnR2`` or ``FlipRot2dOnR2``.
        conv_func: factory ``(in_type, out_type, kernel_size, padding, bias)``
            returning an equivariant conv module.

    Raises:
        NotImplementedError: for any other gspace. Previously such gspaces
            crashed earlier with UnboundLocalError on ``base``.
    """
    super(RegressionHead, self).__init__()
    gspace = in_type.gspace

    # Choose the hidden width multiplier and the vector-valued output irrep
    # together, so unsupported gspaces fail before any module is built.
    if isinstance(gspace, gspaces.Rot2dOnR2):
        base = 8
        vector_irrep = gspace.irrep(1)
    elif isinstance(gspace, gspaces.FlipRot2dOnR2):
        base = 4
        vector_irrep = gspace.irrep(1, 1)
    else:
        raise NotImplementedError(
            f"RegressionHead does not support gspace {type(gspace).__name__}")

    self.add_module('gpool', nn.PointwiseAdaptiveMaxPool(in_type, (1, 1)))

    # "Fully connected" layers realised as 1x1 convolutions.
    hidden_type = nn.FieldType(gspace, 2 * base * [gspace.regular_repr])
    self.add_module(
        'block1',
        nn.SequentialModule(
            conv_func(in_type, hidden_type, kernel_size=1, padding=0, bias=False),
            nn.ReLU(hidden_type, inplace=True)))

    out_type = nn.FieldType(gspace, [vector_irrep])
    self.add_module(
        'block2',
        conv_func(hidden_type, out_type, kernel_size=1, padding=0, bias=False))
def __init__(self, input_channels, output_channels, kernel_size, stride, N,
             activation=True, deconv=False, last_deconv=False):
    """C_N-equivariant convolution on regular fields, optionally with BN+ReLU.

    Args:
        input_channels: number of regular input fields.
        output_channels: number of output fields (regular, or ``irrep(1)``
            vector fields when ``deconv and last_deconv``).
        kernel_size: spatial kernel size.
        stride: convolution stride.
        N: order of the rotation group C_N.
        activation: conv mode only -- append InnerBatchNorm + ReLU.
        deconv: "deconv" mode, which uses padding 0 and always appends
            BN+ReLU unless ``last_deconv``.
        last_deconv: in deconv mode, emit a bare conv to vector fields.
    """
    super(rot_conv2d, self).__init__()
    r2_act = gspaces.Rot2dOnR2(N=N)
    src_type = nn.FieldType(r2_act, input_channels * [r2_act.regular_repr])

    final_deconv = deconv and last_deconv
    out_rep = r2_act.irrep(1) if final_deconv else r2_act.regular_repr
    dst_type = nn.FieldType(r2_act, output_channels * [out_rep])

    pad = 0 if deconv else (kernel_size - 1) // 2
    conv = nn.R2Conv(src_type, dst_type, kernel_size=kernel_size,
                     stride=stride, padding=pad)

    bare_conv = final_deconv or (not deconv and not activation)
    if bare_conv:
        self.layer = conv
    else:
        self.layer = nn.SequentialModule(
            conv,
            nn.InnerBatchNorm(dst_type),
            nn.ReLU(dst_type),
        )
def __init__(self, n_classes=6):
    """C4-equivariant CNN classifier for 3-channel images.

    Args:
        n_classes: number of output logits.
    """
    super(SteerCNN, self).__init__()
    # Equivariant under rotations by 90 degrees: the cyclic group C4.
    self.r2_act = gspaces.Rot2dOnR2(N=4)

    # Input: 3 channels, each a scalar field (trivial representation).
    input_type = nn_e2.FieldType(self.r2_act, 3 * [self.r2_act.trivial_repr])
    # Stored for wrapping images into a GeometricTensor in forward.
    self.input_type = input_type

    # Convolution 1: 24 feature fields, each transforming under the regular
    # representation of C4.
    out_type = nn_e2.FieldType(self.r2_act, 24 * [self.r2_act.regular_repr])
    self.block1 = nn_e2.SequentialModule(
        nn_e2.R2Conv(input_type, out_type, kernel_size=7, padding=3, bias=False),
        nn_e2.InnerBatchNorm(out_type),
        nn_e2.ReLU(out_type, inplace=True))
    self.pool1 = nn_e2.PointwiseAvgPool(out_type, 4)

    # Convolution 2: keeps the same 24 regular C4 fields (24 -> 24).
    in_type = self.block1.out_type
    self.block2 = nn_e2.SequentialModule(
        nn_e2.R2Conv(in_type, out_type, kernel_size=7, padding=3, bias=False),
        nn_e2.InnerBatchNorm(out_type),
        nn_e2.ReLU(out_type, inplace=True))
    # Anti-aliased smoothing, 4x average pooling, then group pooling to get
    # rotation-invariant channels (one per regular field).
    self.pool2 = nn_e2.SequentialModule(
        nn_e2.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=1, padding=0),
        nn_e2.PointwiseAvgPool(out_type, 4),
        nn_e2.GroupPooling(out_type))

    # Flattened feature size: 24 invariant channels times a 13x13 map.
    # NOTE(review): 13x13 is hard-coded and only valid for one particular
    # input resolution (not visible here) -- confirm or derive it.
    c = 24 * 13 * 13

    # Fully connected classifier on the invariant features.
    self.fully_net = torch.nn.Sequential(
        torch.nn.Linear(c, 64),
        torch.nn.BatchNorm1d(64),
        torch.nn.ELU(inplace=True),
        torch.nn.Linear(64, n_classes),
    )
def __init__(self, input_shape, num_actions, dueling_DQN):
    """D4-equivariant convolutional Q-network.

    Args:
        input_shape: input shape; ``input_shape[0]`` is the channel count,
            each channel treated as a scalar field.
        num_actions: number of discrete actions.
        dueling_DQN: if True, build separate advantage and value heads;
            otherwise a single action-value head.
    """
    super(D4_steerable_DQN_Snake, self).__init__()
    self.input_shape = input_shape
    self.num_actions = num_actions
    self.dueling_DQN = dueling_DQN

    self.r2_act = gspaces.FlipRot2dOnR2(N=4)
    act = self.r2_act
    self.input_type = nn.FieldType(act, input_shape[0] * [act.trivial_repr])

    def regular_type(n_fields):
        return nn.FieldType(act, n_fields * [act.regular_repr])

    def conv_relu(src, dst, **conv_kwargs):
        # Bias-free equivariant conv followed by an in-place ReLU.
        return nn.SequentialModule(
            nn.R2Conv(src, dst, bias=False, **conv_kwargs),
            nn.ReLU(dst, inplace=True))

    feature1_type = regular_type(8)
    feature2_type = regular_type(12)
    feature3_type = regular_type(12)
    feature4_type = regular_type(32)

    self.feature_field1 = conv_relu(self.input_type, feature1_type,
                                    kernel_size=7, padding=2, stride=2)
    self.feature_field2 = conv_relu(feature1_type, feature2_type,
                                    kernel_size=5, padding=1, stride=2)
    self.feature_field3 = conv_relu(feature2_type, feature3_type,
                                    kernel_size=5, padding=1, stride=1)
    self.equivariant_features = conv_relu(feature3_type, feature4_type,
                                          kernel_size=5, stride=1)
    self.gpool = nn.GroupPooling(feature4_type)
    self.feature_shape()

    # NOTE(review): the heads take equivariant_features.out_type.size (all
    # regular-field channels), not gpool.out_type.size -- confirm against
    # forward.
    n_features = self.equivariant_features.out_type.size
    if self.dueling_DQN:
        print("You are using Dueling DQN")
        self.advantage = torch.nn.Linear(n_features, self.num_actions)
        self.value = torch.nn.Linear(n_features, 1)
    else:
        self.actionvalue = torch.nn.Linear(n_features, self.num_actions)
def __init__(self, base='DNSteerableLeNet', in_chan=1, n_classes=2, imsize=150,
             kernel_size=5, N=8, quiet=True, number_rotations=None):
    """D_N-equivariant LeNet: two steerable conv stages, group pooling, 3 FC layers.

    Args:
        base: accepted for config compatibility; not used here.
        in_chan: number of input channels, each treated as a scalar field.
        n_classes: number of output logits (cast to int).
        imsize: input side length; fc1 assumes two 2x down-samplings.
        kernel_size: conv kernel size (cast to int).
        N: number of rotations of the dihedral group D_N.
        quiet: accepted for config compatibility; not used here.
        number_rotations: if given, overrides ``N``.
    """
    super(DNSteerableLeNet, self).__init__()
    kernel_size = int(kernel_size)
    out_chan = int(n_classes)
    if number_rotations is not None:
        N = int(number_rotations)

    # Spatial side after pool1 and pool2 (each 2x).
    z = imsize // 2 // 2

    self.r2_act = gspaces.FlipRot2dOnR2(N)
    # One scalar field per input channel. Previously this was always a single
    # field, so ``in_chan`` was silently ignored.
    in_type = e2nn.FieldType(self.r2_act, int(in_chan) * [self.r2_act.trivial_repr])
    self.input_type = in_type

    out_type = e2nn.FieldType(self.r2_act, 6 * [self.r2_act.regular_repr])
    self.mask = e2nn.MaskModule(in_type, imsize, margin=1)
    self.conv1 = e2nn.R2Conv(in_type, out_type, kernel_size=kernel_size,
                             padding=kernel_size // 2, bias=False)
    self.relu1 = e2nn.ReLU(out_type, inplace=True)
    self.pool1 = e2nn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)

    in_type = self.pool1.out_type
    out_type = e2nn.FieldType(self.r2_act, 16 * [self.r2_act.regular_repr])
    self.conv2 = e2nn.R2Conv(in_type, out_type, kernel_size=kernel_size,
                             padding=kernel_size // 2, bias=False)
    self.relu2 = e2nn.ReLU(out_type, inplace=True)
    self.pool2 = e2nn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)

    # Group pooling turns the 16 regular fields into 16 invariant channels.
    self.gpool = e2nn.GroupPooling(out_type)

    self.fc1 = nn.Linear(16 * z * z, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, out_chan)
    self.drop = nn.Dropout(p=0.5)

    # dummy parameter for tracking device
    self.dummy = nn.Parameter(torch.empty(0))
def build_steer_cnn_decoder(
    context_rep_ids,
    hidden_reps_ids,
    kernel_sizes,
    mean_rep_ids,
    covariance_activation="quadratic",
    N=4,
    flip=True,
    max_frequency=30,
    activation="relu",
    padding_mode="zeros",
):
    """Build a steerable CNN decoder from context fields to mean + covariance fields.

    Args:
        context_rep_ids: representation ids of the context fields; a trivial
            (density) field is prepended to them.
        hidden_reps_ids: one list of representation ids per hidden layer.
        kernel_sizes: kernel sizes forwarded to ``build_steer_cnn_2d``.
        mean_rep_ids: representation ids of the predicted mean.
        covariance_activation: forwarded to ``get_pre_covariance_field_type``.
        N: group order; ``-1`` selects the continuous group, which is
            band-limited at ``max_frequency``.
        flip: use the flip-rotation gspace instead of rotations only.
        max_frequency: band-limit used only when ``N == -1``.
        activation: forwarded to ``build_steer_cnn_2d``.
        padding_mode: forwarded to ``build_steer_cnn_2d``.

    Returns:
        Whatever ``build_steer_cnn_2d`` returns for the assembled field types.
    """
    group_kwargs = {"maximum_frequency": max_frequency} if N == -1 else {}
    gspace_cls = gspaces.FlipRot2dOnR2 if flip else gspaces.Rot2dOnR2
    gspace = gspace_cls(N=N, **group_kwargs)

    in_field_type = gnn.FieldType(
        gspace, [gspace.trivial_repr, *reps_from_ids(gspace, context_rep_ids)])
    hidden_field_types = [
        gnn.FieldType(gspace, reps_from_ids(gspace, ids)) for ids in hidden_reps_ids
    ]
    mean_field_type = gnn.FieldType(gspace, reps_from_ids(gspace, mean_rep_ids))
    pre_covariance_field_type = get_pre_covariance_field_type(
        gspace, mean_field_type, covariance_activation)
    out_field_type = mean_field_type + pre_covariance_field_type

    # NOTE(review): an ``init_modify`` scale (0.833 when both context and mean
    # ids were [[0]], else 1.0) used to be computed here. It was never passed
    # on and was accidentally a 1-tuple, so the dead code was removed. If
    # build_steer_cnn_2d supports an init scale, wire it through explicitly.
    return build_steer_cnn_2d(
        in_field_type,
        hidden_field_types,
        kernel_sizes,
        out_field_type,
        gspace,
        activation,
        padding_mode,
    )
def small_wrn(N=4):
    """Build a small D_N-equivariant wide-ResNet followed by a ReLU.

    Args:
        N: number of rotations of the dihedral group D_N.

    Returns:
        torch.nn.Sequential with submodules ``wrn`` (Small_Standalone:
        3 scalar input fields, 3 regular inner fields, 256 scalar outputs,
        dropout 0.3) and ``fc`` (a ReLU).
    """
    from collections import OrderedDict
    import torch.nn as nn

    gspace = gspaces.FlipRot2dOnR2(N)
    rgb_type = enn.FieldType(gspace, 3 * [gspace.trivial_repr])
    inner_type = enn.FieldType(gspace, 3 * [gspace.regular_repr])
    head_type = enn.FieldType(gspace, 256 * [gspace.trivial_repr])

    stages = OrderedDict()
    stages['wrn'] = Small_Standalone(in_type=rgb_type, out_type=head_type,
                                     inner_type=inner_type, dropout_rate=0.3)
    # The stage called 'fc' is a ReLU, not a linear layer; the name is kept
    # so module names (e.g. ``model.fc``) stay unchanged.
    stages['fc'] = nn.ReLU()
    return nn.Sequential(stages)
def build_nnet(self, dims, activation_fn=enn.ReLU):
    """Build the equivariant conv stack used as a residual branch.

    Layer ``i`` maps ``dims[i]`` to ``dims[i + 1]``. Each conv comes from
    ``base_layers.get_equivar_conv2d`` (3x3 by default via args, stride 1,
    padding 1) and is followed by ``activation_fn``.

    Args:
        dims: layer widths, input first and output last.
        activation_fn: e2cnn activation class, called as
            ``activation_fn(field_type, inplace=True)``.

    Returns:
        torch.nn.Sequential of alternating conv / activation modules.
    """
    domains, codomains = self.parse_vnorms()
    if self.args.learn_p:
        if self.args.mixed:
            # An independent learnable parameter per layer.
            domains = [torch.nn.Parameter(torch.tensor(0.)) for _ in domains]
        else:
            # One learnable parameter object shared by every layer.
            domains = [torch.nn.Parameter(torch.tensor(0.))] * len(domains)
        # Each layer's codomain is the next layer's domain, cyclically.
        codomains = domains[1:] + [domains[0]]

    gat = self.group_action_type
    in_type = enn.FieldType(gat, [gat.trivial_repr])
    # First layer: dims[1] channels expressed as regular fields of size
    # group_card.
    out_type = enn.FieldType(gat, int(dims[1] / self.group_card) * [gat.regular_repr])

    last_hidden = len(domains) - 2
    modules = []
    layer_iter = zip(dims[:-1], dims[1:], domains, codomains)
    for idx, (_in_dim, out_dim, domain, codomain) in enumerate(layer_iter):
        modules.append(
            base_layers.get_equivar_conv2d(
                in_type,
                out_type,
                gat,
                kernel_size=self.args.kernel_size,
                stride=1,
                padding=1,
                coeff=self.args.coeff,
                n_iterations=self.args.n_lipschitz_iters,
                atol=self.args.atol,
                rtol=self.args.rtol,
                domain=domain,
                codomain=codomain,
                zero_init=(out_dim == 2),
            ))
        # NOTE(review): an activation also follows the final conv -- confirm
        # this is intended for the residual branch.
        modules.append(activation_fn(modules[-1].out_type, inplace=True))
        in_type = modules[-1].out_type
        if idx == last_hidden:
            # The next (final) layer outputs a single scalar field.
            out_type = enn.FieldType(gat, [gat.trivial_repr])
        else:
            # NOTE(review): uses out_dim (= dims[idx + 1]) regular fields.
            # This is neither divided by group_card as for the first layer nor
            # taken from dims[idx + 2]; confirm the intended hidden widths.
            out_type = enn.FieldType(gat, out_dim * [gat.regular_repr])
    return torch.nn.Sequential(*modules)
def __init__(self, args, n_blocks, input_size, hidden_size, n_hidden,
             group_action_type=None):
    """Equivariant RealNVP flow with alternating channel masks.

    Args:
        args: config namespace (reads group, out_fiber, field_type, act,
            kernel_size, realnvp_padding).
        n_blocks: number of coupling blocks.
        input_size: (batch, c, h, w); ``c`` must be > 1.
        hidden_size: forwarded to ``create_equivariant_real_nvp_blocks``.
        n_hidden: number of hidden layers per coupling net.
        group_action_type: ignored; the group is taken from ``args.group``.
    """
    super(FiberRealNVP, self).__init__()
    _, self.c, self.h, self.w = input_size[:]
    assert self.c > 1
    self.n_blocks = int(n_blocks)
    self.n_hidden = n_hidden
    self.group_action_type = GROUPS[args.group]
    self.out_fiber = args.out_fiber
    self.field_type = args.field_type
    self.group_card = len(list(self.group_action_type.testing_elements))
    self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Even/odd channel mask and its complement, stacked and tiled so there
    # are at least n_blocks rows.
    odd_channels = torch.arange(self.c).float() % 2
    pair = torch.stack([odd_channels, 1 - odd_channels])
    mask = pair.repeat(int(self.n_blocks / 2) + 1, 1)

    self.p_z = StandardNormal
    self.input_type = enn.FieldType(
        self.group_action_type, self.c * [self.group_action_type.trivial_repr])
    self.activation_fn = ACT_FNS[args.act]
    self.s, self.t = create_equivariant_real_nvp_blocks(
        input_size, self.input_type, self.field_type, self.out_fiber,
        self.activation_fn, hidden_size, self.n_blocks, n_hidden,
        self.group_action_type, args.kernel_size, args.realnvp_padding)
    # Registered as a frozen parameter so it moves with the module.
    self.mask = nn.Parameter(mask, requires_grad=False)
def check_invariance(self, r2_act, out_type, data, func, data_type=None):
    """Assert that ``func`` is invariant under every test element of ``r2_act``.

    For each ``g`` in ``r2_act.testing_elements`` the input is transformed by
    ``g`` and ``func(g . data)`` is compared with ``func(data)``
    (``torch.allclose``, ``atol=1e-5``). Every transformed output is printed.

    Args:
        r2_act: e2cnn gspace whose ``testing_elements`` are checked.
        out_type: unused; kept for signature parity with check_equivariance.
        data: GeometricTensor of shape (batch, c, h, w) whose c channels are
            scalar (trivial) fields.
        func: callable mapping a GeometricTensor to a GeometricTensor.
        data_type: unused.

    Raises:
        AssertionError: carrying the failing group element.
    """
    _, c, h, w = data.shape
    device = data.tensor.device
    # Input type from the data's own channel count. It used to read
    # ``self.c``, which broke whenever that differed from the passed tensor.
    input_type = enn.FieldType(r2_act, c * [r2_act.trivial_repr])
    y = func(data)
    for g in r2_act.testing_elements:
        # Transform on CPU, then move back to the input's device. This was a
        # hard-coded ``.cuda()``, which failed for CPU-only runs.
        data = enn.GeometricTensor(data.tensor.view(-1, c, h, w).cpu(), input_type)
        x_transformed = enn.GeometricTensor(
            data.transform(g).tensor.view(-1, c, h, w).to(device), input_type)
        y_from_x_transformed = func(x_transformed)
        data = enn.GeometricTensor(
            data.tensor.squeeze().view(-1, c, h, w).to(device), input_type)

        print_y = y_from_x_transformed.tensor.detach().to('cpu').numpy().squeeze()
        # ``!s:>4`` prints integer elements exactly like the old ``:4d``, but
        # also handles tuple elements (flip-rotation groups), which crashed.
        print("{!s:>4} : {}".format(g, print_y))
        # Invariance condition: f(g . x) == f(x).
        assert torch.allclose(y_from_x_transformed.tensor.squeeze(),
                              y.tensor.squeeze(), atol=1e-5), g
    print("Passed Invariance Test")
def check_equivariance(r2_act, out_type, data, func, data_type=None):
    """Assert that ``func`` is equivariant under every test element of ``r2_act``.

    For each ``g`` in ``r2_act.testing_elements`` this compares
    ``g . func(x)`` (``rg_output``) with ``func(g . x)`` (``output_rg``) using
    ``torch.allclose(..., atol=1e-5)``.

    Args:
        r2_act: e2cnn gspace providing ``trivial_repr`` and ``testing_elements``.
        out_type: field type used to wrap ``func``'s output before transforming it.
        data: input samples, reshaped with ``view(-1, 1, 1, 2)``, i.e. one
            scalar channel on a 1x2 grid per sample (presumably 2D points --
            confirm against callers).
        func: the map under test.
        data_type: ``'GeomTensor'`` if ``func`` consumes and returns
            GeometricTensors; otherwise it takes and returns plain tensors.

    Raises:
        AssertionError: carrying the failing group element.
    """
    # A single scalar (trivial) field per sample.
    input_type = enn.FieldType(r2_act, [r2_act.trivial_repr])
    if data_type == 'GeomTensor':
        data = enn.GeometricTensor(data.view(-1, 1, 1, 2), input_type)
    for g in r2_act.testing_elements:
        # NOTE: ``data`` is moved to CPU below, so from the second iteration
        # on func receives CPU inputs.
        output = func(data)
        if data_type == 'GeomTensor':
            # g . f(x), computed on CPU.
            rg_output = enn.GeometricTensor(
                output.tensor.view(-1, 1, 1, 2).cpu(), out_type).transform(g)
            data = enn.GeometricTensor(
                data.tensor.view(-1, 1, 1, 2).cpu(), input_type)
            # g . x, kept wrapped as a GeometricTensor.
            x_transformed = enn.GeometricTensor(
                data.transform(g).tensor.view(-1, 1, 1, 2), input_type)
        else:
            # g . f(x), computed on CPU.
            rg_output = enn.GeometricTensor(
                output.view(-1, 1, 1, 2).cpu(), out_type).transform(g)
            # data is temporarily wrapped as a GeometricTensor so it can be
            # transformed; it is unwrapped again at the end of the iteration.
            data = enn.GeometricTensor(
                data.view(-1, 1, 1, 2).cpu(), input_type)
            # g . x as a plain tensor.
            x_transformed = data.transform(g).tensor.view(-1, 1, 1, 2)
        # f(g . x)
        output_rg = func(x_transformed)
        # Equivariance Condition
        if data_type == 'GeomTensor':
            output_rg = enn.GeometricTensor(output_rg.tensor.cpu(), out_type)
            data = enn.GeometricTensor(data.tensor.squeeze().view(-1, 1, 1, 2), input_type)
        else:
            output_rg = enn.GeometricTensor(
                output_rg.view(-1, 1, 1, 2).cpu(), out_type)
            data = data.tensor.squeeze()
        assert torch.allclose(rg_output.tensor.cpu().squeeze(), output_rg.tensor.squeeze(), atol=1e-5), g
def __init__(self, args, n_blocks, input_size, hidden_size, n_hidden,
             group_action_type=None):
    """Equivariant invertible-residual flow on 2D toy data.

    Args:
        args: config namespace (reads beta, act, group, dims, actnorm,
            batchnorm, n_dist, n_power_series, exact_trace, brute_force,
            batch_size, ...).
        n_blocks: number of Equivar_iResBlocks.
        input_size, hidden_size, n_hidden: accepted but unused here.
        group_action_type: ignored; the group is taken from ``args.group``.
    """
    super(EquivariantToyResFlow, self).__init__()
    self.args = args
    self.beta = args.beta
    self.n_blocks = n_blocks
    self.activation_fn = ACT_FNS[args.act]
    self.group_action_type = GROUPS[args.group]
    self.group_card = len(list(self.group_action_type.testing_elements))
    self.input_type = enn.FieldType(self.group_action_type,
                                    [self.group_action_type.trivial_repr])

    # Residual branch widths: 2 -> hidden dims from "a-b-c" -> 2.
    hidden_dims = [int(d) for d in args.dims.split('-')]
    dims = [2, *hidden_dims, 2]

    use_actnorm = self.args.actnorm
    blocks = [layers.EquivariantActNorm1d(2)] if use_actnorm else []
    for _ in range(n_blocks):
        residual_net = self.build_nnet(dims, self.activation_fn)
        blocks.append(
            layers.Equivar_iResBlock(
                residual_net,
                n_dist=self.args.n_dist,
                n_power_series=self.args.n_power_series,
                exact_trace=self.args.exact_trace,
                brute_force=self.args.brute_force,
                n_samples=self.args.batch_size,
                neumann_grad=True,
                grad_in_forward=True,
            ))
        if use_actnorm:
            blocks.append(layers.EquivariantActNorm1d(2))
        if self.args.batchnorm:
            blocks.append(layers.MovingBatchNorm1d(2))
    self.flow_model = layers.SequentialFlow(blocks)
def mixed_fiber(gspace: gspaces.GeneralOnR2, planes: int, ratio: float,
                field_type: int = 0, fixparams: bool = True):
    """Build a sorted fiber mixing regular and quotient representations.

    Args:
        gspace: a finite FlipRot2dOnR2 or Flip2dOnR2 gspace.
        planes: target channel budget, divided by the regular repr size.
        ratio: fraction of the budget spent on regular fields; the rest goes
            to quotient fields (counted twice).
        field_type: unused.
        fixparams: scale by sqrt(N * CHANNELS_CONSTANT) before rounding.

    Raises:
        ValueError: for any other gspace.
    """
    N = gspace.fibergroup.order()
    assert N > 0

    # Quotient-subgroup id per supported gspace class, checked in order.
    supported = (
        (gspaces.FlipRot2dOnR2, (0, 1)),
        (gspaces.Flip2dOnR2, 1),
    )
    subgroup = next((sg for cls, sg in supported if isinstance(gspace, cls)), None)
    if subgroup is None:
        raise ValueError(f"Space {gspace} not supported")

    quotient = gspace.quotient_repr(subgroup)
    regular = gspace.regular_repr

    budget = planes / regular.size
    if fixparams:
        budget *= math.sqrt(N * CHANNELS_CONSTANT)
    n_regular = int(budget * ratio)
    n_quotient = int(2 * budget * (1 - ratio))
    return enn.FieldType(gspace, [regular] * n_regular + [quotient] * n_quotient).sorted()
def quotient_fiber(gspace: gspaces.GeneralOnR2, planes: int, field_type: int = 0,
                   fixparams: bool = True):
    """ build a quotient fiber with the specified number of channels

    The fiber is made of copies of a fixed bundle of quotient representations
    chosen per gspace type; ``planes`` is divided by the bundle size and
    optionally scaled by sqrt(N * CHANNELS_CONSTANT).

    Raises:
        ValueError: for unsupported gspaces.
    """
    N = gspace.fibergroup.order()
    assert N > 0

    if isinstance(gspace, gspaces.FlipRot2dOnR2):
        half = N / 2
        # Reflection subgroups with axis indices 0, round(half/4), round(half/2).
        subgroups = [(int(axis), 1) for axis in (0, round(half / 4), round(half / 2))]
    elif isinstance(gspace, gspaces.Rot2dOnR2):
        assert N % 4 == 0
        subgroups = [2, 4]
    elif isinstance(gspace, gspaces.Flip2dOnR2):
        subgroups = [2]
    else:
        raise ValueError(f"Space {gspace} not supported")

    bundle = [gspace.quotient_repr(sg) for sg in subgroups]
    bundle_size = sum(rep.size for rep in bundle)

    copies = planes / bundle_size
    if fixparams:
        copies *= math.sqrt(N * CHANNELS_CONSTANT)
    return enn.FieldType(gspace, bundle * int(copies)).sorted()
def create_equivariant_convexp_blocks(input_size, in_type, field_type, out_fiber,
                                      activation_fn, hidden_size, n_blocks,
                                      n_hidden, group_action_type,
                                      kernel_size=3, padding=1):
    """Build ``n_blocks`` single-convolution equivariant blocks.

    Each block is one biased ``enn.R2Conv`` from ``in_type`` to ``c`` trivial
    (scalar) fields, where ``c`` is the channel count of ``input_size``.

    Args:
        input_size: (batch, c, h, w).
        in_type: e2cnn FieldType consumed by every block. Chaining blocks
            presumably requires it to equal the ``c``-scalar output type.
        field_type, out_fiber, activation_fn, hidden_size, n_hidden: accepted
            for signature parity with the other block factories; unused.
        n_blocks: number of blocks.
        group_action_type: gspace the output field type is built over.
        kernel_size: convolution kernel size.
        padding: convolution padding.

    Returns:
        MultiInputSequential wrapping one MultiInputSequential per block.
    """
    _, c, _, _ = input_size
    out_type = enn.FieldType(group_action_type, c * [group_action_type.trivial_repr])
    blocks = [
        MultiInputSequential(
            enn.R2Conv(in_type, out_type, kernel_size=kernel_size,
                       padding=padding, bias=True))
        for _ in range(n_blocks)
    ]
    return MultiInputSequential(*blocks)
def trivial_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = True):
    """ build a trivial feature map with the specified number of channels

    Args:
        gspace: gspace providing ``fibergroup.order()`` and ``trivial_repr``.
        planes: base number of channels.
        fixparams: if True, multiply ``planes`` by sqrt(group order) first.

    Returns:
        enn.FieldType made of ``int(planes [* sqrt(order)])`` trivial fields.
    """
    scale = math.sqrt(gspace.fibergroup.order()) if fixparams else 1
    n_fields = int(planes * scale)
    return enn.FieldType(gspace, n_fields * [gspace.trivial_repr])
def generate_2d_rot8(out_path, n_tasks=10000, batch_size=20, image_size=32,
                     log_every=100):
    """Generate a dataset of random C8-equivariant convolutions and their I/O.

    For each task, a 3x3 bias-free ``R2Conv`` is re-initialised with e2cnn's
    generalized He init. It maps 1 scalar field to 3 regular C8 fields. The
    conv is applied to a batch of standard-normal images, and the inputs,
    outputs and weights are recorded. Everything is written with ``np.savez``
    under keys ``x``, ``y`` and ``w``.

    Args:
        out_path: destination passed to ``np.savez``.
        n_tasks: number of random convolutions to sample (default 10000).
        batch_size: images per task (default 20).
        image_size: side length of the square input images (default 32).
        log_every: print progress every this many tasks; 0 disables it.

    Raises:
        ValueError: if ``n_tasks`` < 1 (``np.stack`` needs at least one task).
    """
    if n_tasks < 1:
        raise ValueError(f"n_tasks must be >= 1, got {n_tasks}")

    r2_act = gspaces.Rot2dOnR2(N=8)
    feat_type_in = gnn.FieldType(r2_act, [r2_act.trivial_repr])
    feat_type_out = gnn.FieldType(r2_act, 3 * [r2_act.regular_repr])
    conv = gnn.R2Conv(feat_type_in, feat_type_out, kernel_size=3, bias=False)

    xs, ys, ws = [], [], []
    # Pure data generation: no autograd graph is needed. It also avoids
    # autograd errors if generalized_he_init writes in place into
    # ``conv.weights``, which is passed as the Parameter itself, not ``.data``.
    with torch.no_grad():
        for task_idx in range(n_tasks):
            gnn.init.generalized_he_init(conv.weights, conv.basisexpansion)
            inp = gnn.GeometricTensor(
                torch.randn(batch_size, 1, image_size, image_size), feat_type_in)
            result = conv(inp).tensor.detach().cpu().numpy()
            xs.append(inp.tensor.detach().cpu().numpy())
            ys.append(result)
            ws.append(conv.weights.detach().cpu().numpy())
            if log_every and task_idx % log_every == 0:
                print(f"Finished generating task {task_idx}")

    np.savez(out_path, x=np.stack(xs), y=np.stack(ys), w=np.stack(ws))
def __init__(self):
    """Rotation-invariant dense feature extractor with a VGG-like layout.

    Stacks C8-equivariant 3x3 / 5x5 convolutions on regular fields (widths
    taken from the first ten entries of ``filters``) and ends with group
    pooling, which makes the output invariant to the sampled rotations.
    """
    super(DenseFeatureExtractionModuleE2Inv, self).__init__()
    filters = np.array([32, 32, 64, 64, 128, 128, 128, 256, 256, 256, 512, 512, 512],
                       dtype=np.int32) * 2
    # number of rotations to consider for rotation invariance
    N = 8
    self.gspace = gspaces.Rot2dOnR2(N)
    self.input_type = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)
    self.num_channels = 64

    # types[0] is the RGB input; types[1..10] are regular field types built
    # without parameter-count rescaling.
    types = [self.input_type] + [
        FIELD_TYPE['regular'](self.gspace, width, fixparams=False)
        for width in filters[:10]
    ]

    def conv_relu(conv_fn, idx):
        # Conv types[idx - 1] -> types[idx], then an in-place ReLU.
        return [conv_fn(types[idx - 1], types[idx]),
                enn.ReLU(types[idx], inplace=True)]

    stack = (
        conv_relu(conv3x3, 1) + conv_relu(conv3x3, 2)
        + [enn.PointwiseMaxPool(types[2], 2)]
        + conv_relu(conv3x3, 3) + conv_relu(conv3x3, 4)
        + [enn.PointwiseMaxPool(types[4], 2)]
        + conv_relu(conv3x3, 5) + conv_relu(conv3x3, 6) + conv_relu(conv3x3, 7)
        + [enn.PointwiseAvgPool(types[7], kernel_size=2, stride=1)]
        + conv_relu(conv5x5, 8) + conv_relu(conv5x5, 9) + conv_relu(conv5x5, 10)
        + [enn.GroupPooling(types[10])]
    )
    self.model = enn.SequentialModule(*stack)
def regular_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = True):
    """Build a regular feature map with roughly ``planes`` channels.

    ``fixparams`` used to be read without being a parameter, which raised
    NameError and made ``fixparams=...`` keyword calls fail with TypeError.
    It is now an explicit argument, defaulting to True like
    ``trivial_feature_type``.

    Args:
        gspace: gspace of a finite group (order > 0).
        planes: target number of channels.
        fixparams: if True, multiply ``planes`` by sqrt(N) before dividing by
            the regular representation size N.

    Returns:
        enn.FieldType with ``int(planes [* sqrt(N)] / N)`` regular fields.
    """
    N = gspace.fibergroup.order()
    assert N > 0
    if fixparams:
        planes *= math.sqrt(N)
    # Each regular field occupies N channels.
    planes = int(planes / N)
    return enn.FieldType(gspace, [gspace.regular_repr] * planes)
def __init__(self):
    """A single dilated C8-equivariant convolution followed by group pooling."""
    super(ModelDilated, self).__init__()
    N = 8
    self.gspace = gspaces.Rot2dOnR2(N)
    scalar = self.gspace.trivial_repr
    regular = self.gspace.regular_repr
    self.in_type = enn.FieldType(self.gspace, 3 * [scalar])
    self.out_type = enn.FieldType(self.gspace, 16 * [regular])
    # A 3x3 kernel with dilation 2 spans 5x5, so padding 2 keeps the spatial
    # size at stride 1.
    self.layer = enn.R2Conv(self.in_type, self.out_type, 3,
                            stride=1, padding=2, dilation=2, bias=True)
    self.invariant = enn.GroupPooling(self.out_type)
def trivial_fiber(gspace: gspaces.GeneralOnR2, planes: int, field_type: int = 0,
                  fixparams: bool = True):
    """ build a trivial fiber with the specified number of channels

    Args:
        gspace: gspace providing ``fibergroup.order()`` and ``trivial_repr``.
        planes: base number of channels.
        field_type: unused.
        fixparams: if True, multiply ``planes`` by
            sqrt(order * CHANNELS_CONSTANT) first.
    """
    scale = math.sqrt(gspace.fibergroup.order() * CHANNELS_CONSTANT) if fixparams else 1
    n_fields = int(planes * scale)
    return enn.FieldType(gspace, n_fields * [gspace.trivial_repr])
def irrep_fiber(gspace: gspaces.GeneralOnR2, planes: int, field_type: int = 0,
                fixparams: bool = True):
    """ build a irrep fiber with the specified number of channels

    Only continuous groups are accepted (``fibergroup.order() < 0``). The
    fiber consists of ``int(planes)`` copies of ``gspace.irrep(0)``, rounded
    up to an even count. ``field_type`` and ``fixparams`` are unused.
    """
    assert gspace.fibergroup.order() < 0
    n_fields = int(planes)
    # Round odd counts up to the next even number.
    n_fields += n_fields % 2
    # NOTE(review): for SO(2) gspaces irrep(0) is the trivial irrep, so this is
    # just scalar fields and the even rounding has no obvious purpose --
    # confirm whether a frequency-1 irrep was intended.
    return enn.FieldType(gspace, n_fields * [gspace.irrep(0)])
def _build_net(self, input_size):
    """Build one StackediResBlocks module per scale.

    Every scale uses ``self.input_type`` as input type and a fiber from
    ``FIBERS[self.out_fiber]`` as hidden type. Only the first scale receives
    ``self.init_layer`` and may be a ``first_resblock``.

    Args:
        input_size: (batch, c, h, w) of the network input.

    Returns:
        nn.ModuleList of the per-scale block stacks.
    """
    _, c, h, w = input_size
    in_type = self.input_type
    my_i_dims = self.intermediate_dim
    out_type = FIBERS[self.out_fiber](self.group_action_type, my_i_dims,
                                      self.field_type, fixparams=True)

    stacks = []
    for scale in range(self.n_scale):
        is_first = scale == 0
        stacks.append(
            StackediResBlocks(
                in_type,
                out_type,
                self.group_action_type,
                initial_size=(c, h, w),
                idim=my_i_dims,
                squeeze=False,  # Can't change channels/fibers
                init_layer=self.init_layer if is_first else None,
                n_blocks=self.n_blocks[scale],
                quadratic=self.quadratic,
                actnorm=self.actnorm,
                fc_actnorm=self.fc_actnorm,
                batchnorm=self.batchnorm,
                dropout=self.dropout,
                fc=self.fc,
                coeff=self.coeff,
                vnorms=self.vnorms,
                n_lipschitz_iters=self.n_lipschitz_iters,
                sn_atol=self.sn_atol,
                sn_rtol=self.sn_rtol,
                n_power_series=self.n_power_series,
                n_dist=self.n_dist,
                n_samples=self.n_samples,
                kernels=self.kernels,
                activation_fn=self.activation_fn,
                fc_end=self.fc_end,
                fc_idim=self.fc_idim,
                n_exact_terms=self.n_exact_terms,
                preact=self.preact,
                neumann_grad=self.neumann_grad,
                grad_in_forward=self.grad_in_forward,
                first_resblock=self.first_resblock and is_first,
                learn_p=self.learn_p,
            ))
        # Shape bookkeeping for the next scale, printed for debugging.
        # NOTE(review): squeeze=False above, so these sizes may not match the
        # actual tensors -- confirm.
        c = c * 2 if self.factor_out else c * 4
        h, w = h // 2, w // 2
        print("C: %d H: %d W: %d" % (c, h, w))
        if scale == self.n_scale - 1:
            # NOTE(review): assigned after the last stack is built, so this
            # value is never used -- confirm intent.
            out_type = enn.FieldType(
                self.group_action_type,
                self.c * [self.group_action_type.trivial_repr])
    return nn.ModuleList(stacks)
def __init__(self, input_frames, output_frames, kernel_size, N):
    """C_N rotation-equivariant ResNet mapping vector-field frames to frames.

    Args:
        input_frames: number of input velocity-field frames.
        output_frames: number of output velocity-field frames.
        kernel_size: spatial kernel size (same-padding is applied).
        N: order of the rotation group C_N.
    """
    super(ResNet_Rot, self).__init__()
    r2_act = gspaces.Rot2dOnR2(N=N)
    pad = (kernel_size - 1) // 2
    vector_rep = r2_act.irrep(1)

    # Inputs/outputs are velocity fields, so they use the rho_1 (irrep(1))
    # representation; middle layers use the regular representation.
    self.feat_type_in = nn.FieldType(r2_act, input_frames * [vector_rep])
    self.feat_type_in_hid = nn.FieldType(r2_act, 16 * [r2_act.regular_repr])
    self.feat_type_hid_out = nn.FieldType(r2_act, 192 * [r2_act.regular_repr])
    self.feat_type_out = nn.FieldType(r2_act, output_frames * [vector_rep])

    self.input_layer = nn.SequentialModule(
        nn.R2Conv(self.feat_type_in, self.feat_type_in_hid,
                  kernel_size=kernel_size, padding=pad),
        nn.InnerBatchNorm(self.feat_type_in_hid),
        nn.ReLU(self.feat_type_in_hid),
    )

    # Residual stages as (in fields, out fields): each widening block is
    # followed by a same-width block.
    widths = [(16, 32), (32, 32), (32, 64), (64, 64),
              (64, 128), (128, 128), (128, 192), (192, 192)]
    res_blocks = [rot_resblock(c_in, c_out, kernel_size, N) for c_in, c_out in widths]
    head = nn.R2Conv(self.feat_type_hid_out, self.feat_type_out,
                     kernel_size=kernel_size, padding=pad)

    # input_layer is registered both as an attribute and inside self.model.
    self.model = torch.nn.Sequential(self.input_layer, *res_blocks, head)