def build_steer_cnn_decoder(
    context_rep_ids,
    hidden_reps_ids,
    kernel_sizes,
    mean_rep_ids,
    covariance_activation="quadratic",
    N=4,
    flip=True,
    max_frequency=30,
    activation="relu",
    padding_mode="zeros",
):
    """Build a steerable CNN decoder whose output is a mean and a pre-covariance field.

    The group acting on the plane is ``gspaces.FlipRot2dOnR2`` when ``flip``
    is true and ``gspaces.Rot2dOnR2`` otherwise. For ``N == -1`` the gspace is
    additionally given ``maximum_frequency=max_frequency``.

    Args:
        context_rep_ids: representation ids (resolved with ``reps_from_ids``)
            of the context fields; one trivial field is prepended to them to
            form the input field type.
        hidden_reps_ids: one list of representation ids per hidden layer.
        kernel_sizes: forwarded unchanged to ``build_steer_cnn_2d``.
        mean_rep_ids: representation ids of the mean output field.
        covariance_activation: forwarded to ``get_pre_covariance_field_type``.
        N: group order passed to the gspace constructor (``-1`` selects the
            variant that also receives ``maximum_frequency``).
        flip: include reflections in the symmetry group.
        max_frequency: only used when ``N == -1``.
        activation: forwarded to ``build_steer_cnn_2d``.
        padding_mode: forwarded to ``build_steer_cnn_2d``.

    Returns:
        Whatever ``build_steer_cnn_2d`` returns for the assembled field types.
    """
    gspace_cls = gspaces.FlipRot2dOnR2 if flip else gspaces.Rot2dOnR2
    if N == -1:
        gspace = gspace_cls(N=N, maximum_frequency=max_frequency)
    else:
        gspace = gspace_cls(N=N)

    # The input is one trivial (scalar) field followed by the context fields.
    in_field_type = gnn.FieldType(
        gspace, [gspace.trivial_repr, *reps_from_ids(gspace, context_rep_ids)])
    hidden_field_types = [
        gnn.FieldType(gspace, reps_from_ids(gspace, ids)) for ids in hidden_reps_ids
    ]
    mean_field_type = gnn.FieldType(gspace, reps_from_ids(gspace, mean_rep_ids))
    pre_covariance_field_type = get_pre_covariance_field_type(
        gspace, mean_field_type, covariance_activation)
    # Network output = mean field followed by the pre-covariance field.
    out_field_type = mean_field_type + pre_covariance_field_type

    # Plain float scale (a stray trailing comma used to make this a 1-tuple).
    # NOTE(review): init_modify is computed but not forwarded to
    # build_steer_cnn_2d; confirm whether that helper should receive it.
    if N == -1:
        init_modify = 1.0
    elif mean_rep_ids == [[0]] and context_rep_ids == [[0]]:
        init_modify = 0.833
    else:
        init_modify = 1.0

    return build_steer_cnn_2d(
        in_field_type,
        hidden_field_types,
        kernel_sizes,
        out_field_type,
        gspace,
        activation,
        padding_mode,
    )
def __init__(self, input_channels, output_channels, kernel_size, stride, N, activation = True, deconv = False, last_deconv = False):
    """Rotation-equivariant (C_N) convolution layer over regular fields.

    Args:
        input_channels: number of regular fields in the input.
        output_channels: number of output fields. They are regular fields,
            except when both ``deconv`` and ``last_deconv`` are set, in which
            case they are frequency-1 irrep fields.
        kernel_size: spatial kernel size.
        stride: convolution stride.
        N: order of the rotation group (``gspaces.Rot2dOnR2(N=N)``).
        activation: when ``deconv`` is false, append InnerBatchNorm + ReLU
            after the convolution. Ignored when ``deconv`` is true.
        deconv: use padding 0 instead of ``(kernel_size - 1) // 2``. Unless
            ``last_deconv`` is set, InnerBatchNorm + ReLU are always appended.
        last_deconv: together with ``deconv``, produce a bare convolution
            into irrep(1) fields.
    """
    super(rot_conv2d, self).__init__()
    rotations = gspaces.Rot2dOnR2(N = N)
    regular_in = nn.FieldType(rotations, input_channels * [rotations.regular_repr])
    regular_out = nn.FieldType(rotations, output_channels * [rotations.regular_repr])
    # NOTE(review): the "deconv" path still uses a plain R2Conv; only the padding differs.
    pad = 0 if deconv else (kernel_size - 1) // 2

    if deconv and last_deconv:
        vector_out = nn.FieldType(rotations, output_channels * [rotations.irrep(1)])
        self.layer = nn.R2Conv(regular_in, vector_out, kernel_size=kernel_size, stride=stride, padding=pad)
        return

    conv = nn.R2Conv(regular_in, regular_out, kernel_size=kernel_size, stride=stride, padding=pad)
    if deconv or activation:
        self.layer = nn.SequentialModule(conv, nn.InnerBatchNorm(regular_out), nn.ReLU(regular_out))
    else:
        self.layer = conv
def restrict(self, id: int) -> Tuple[gspaces.GSpace, Callable, Callable]:
    r"""
    Restrict this gspace to the rotation subgroup selected by ``id``.

    ``id`` is the number :math:`M` of rotations kept in the subgroup and is
    passed to ``self.fibergroup.subgroup``. When the fiber group is
    :math:`C_N` (:class:`~e2cnn.group.CyclicGroup`), :math:`M` must divide
    :math:`N`; otherwise any positive integer is accepted.

    Args:
        id (int): the number :math:`M` of rotations in the subgroup

    Returns:
        a tuple ``(gspace, back_map, subgroup_map)``:

        - **gspace**: ``Rot2dOnR2`` over the subgroup when :math:`M > 1`,
          ``TrivialOnR2`` when :math:`M = 1`
        - **back_map**: maps an element of the subgroup to the same element of
          the original fiber group
        - **subgroup_map**: maps an element of the original fiber group to the
          subgroup (``None`` if the element is not in the subgroup)

    Raises:
        ValueError: if ``id`` is neither ``1`` nor greater than ``1`` (unless
            ``self.fibergroup.subgroup`` has already rejected it).
    """
    subgroup, back_map, subgroup_map = self.fibergroup.subgroup(id)

    if id == 1:
        restricted = gspaces.TrivialOnR2(fibergroup=subgroup)
    elif id > 1:
        restricted = gspaces.Rot2dOnR2(fibergroup=subgroup)
    else:
        raise ValueError(f"id {id} not recognized!")

    return restricted, back_map, subgroup_map
def __init__(self,
             input_channels,
             hidden_dim,
             kernel_size,
             N  # Group size
             ):
    """Residual block of C_N-equivariant convolutions over regular fields.

    Args:
        input_channels: number of regular fields in the input.
        hidden_dim: number of regular fields produced by every sub-layer.
        kernel_size: spatial kernel size; every conv uses padding
            ``(kernel_size - 1) // 2``.
        N: order of the rotation group (``gspaces.Rot2dOnR2(N=N)``).
    """
    super(rot_resblock, self).__init__()
    rotations = gspaces.Rot2dOnR2(N = N)
    fields_in = nn.FieldType(rotations, input_channels * [rotations.regular_repr])
    fields_hidden = nn.FieldType(rotations, hidden_dim * [rotations.regular_repr])
    pad = (kernel_size - 1) // 2

    def conv_bn_relu(source):
        # R2Conv -> InnerBatchNorm -> ReLU, always producing fields_hidden.
        return nn.SequentialModule(
            nn.R2Conv(source, fields_hidden, kernel_size=kernel_size, padding=pad),
            nn.InnerBatchNorm(fields_hidden),
            nn.ReLU(fields_hidden),
        )

    self.layer1 = conv_bn_relu(fields_in)
    self.layer2 = conv_bn_relu(fields_hidden)
    # Separate projection of the input onto the hidden type (presumably the
    # skip path -- confirm in forward()).
    self.upscale = conv_bn_relu(fields_in)
    self.input_channels = input_channels
    self.hidden_dim = hidden_dim
def __init__(self, input_frames, output_frames, kernel_size, N):
    """C_N-equivariant U-Net built from ``rot_conv2d`` / ``rot_deconv2d`` layers.

    The input and output are ``input_frames`` / ``output_frames`` fields of
    the frequency-1 irrep of C_N; intermediate layers use regular fields.
    """
    super(Unet_Rot, self).__init__()
    r2_act = gspaces.Rot2dOnR2(N = N)
    pad = (kernel_size - 1) // 2

    def vector_fields(count):
        return nn.FieldType(r2_act, count * [r2_act.irrep(1)])

    self.feat_type_in = vector_fields(input_frames)
    self.feat_type_in_hid = nn.FieldType(r2_act, 32 * [r2_act.regular_repr])
    # 16 fields from deconv1 plus the input fields (presumably concatenated in forward()).
    self.feat_type_hid_out = vector_fields(16 + input_frames)
    self.feat_type_out = vector_fields(output_frames)

    self.conv1 = nn.SequentialModule(
        nn.R2Conv(self.feat_type_in, self.feat_type_in_hid, kernel_size=kernel_size, stride=2, padding=pad),
        nn.InnerBatchNorm(self.feat_type_in_hid),
        nn.ReLU(self.feat_type_in_hid),
    )

    # Encoder layers as (attribute, input fields, output fields, stride), built in this order.
    encoder = (
        ("conv2", 32, 64, 1),
        ("conv2_1", 64, 64, 1),
        ("conv3", 64, 128, 2),
        ("conv3_1", 128, 128, 1),
        ("conv4", 128, 256, 2),
        ("conv4_1", 256, 256, 1),
    )
    for attr, c_in, c_out, step in encoder:
        setattr(self, attr, rot_conv2d(c_in, c_out, kernel_size=kernel_size, stride=step, N=N))

    self.deconv3 = rot_deconv2d(256, 64, N)
    self.deconv2 = rot_deconv2d(192, 32, N)
    self.deconv1 = rot_deconv2d(96, 16, N, last_deconv=True)
    self.output_layer = nn.R2Conv(self.feat_type_hid_out, self.feat_type_out, kernel_size=kernel_size, padding=pad)
def get_model(hparams):
    """Instantiate the model selected by ``hparams['model']``.

    Supported values are ``'discrete'`` (C4UNet), ``'standard'`` (UNet),
    ``'harmonic'`` (HarmonicUNet) and ``'steerable'`` (SteerableCNN on a
    ``gspaces.Rot2dOnR2(-1, maximum_frequency=hparams['max_freq'])`` gspace).
    Every model receives ``in_channels``, ``out_channels``, ``n_features``,
    ``loss_func`` (as ``loss_type``) and ``lr`` from ``hparams``.

    Raises:
        KeyError: if ``'model'`` or a hyperparameter the chosen model needs is
            missing from a dict ``hparams``.
        ValueError: if the model type is not one of the above.
    """
    kind = hparams['model']

    def build_steerable():
        gspace = gspaces.Rot2dOnR2(-1, maximum_frequency=hparams['max_freq'])
        return SteerableCNN(gspace, hparams['in_channels'], hparams['out_channels'],
                            n_blocks=hparams['n_blocks'],
                            n_features=hparams['n_features'],
                            irrep_type=hparams['irrep_type'],
                            loss_type=hparams['loss_func'],
                            lr=hparams['lr'])

    # Builders are only invoked for the matching name, so unrelated keys are never read.
    builders = (
        ('discrete', lambda: C4UNet(hparams['in_channels'], hparams['out_channels'],
                                    group_type=hparams['group'],
                                    N=hparams['N'],
                                    n_features=hparams['n_features'],
                                    loss_type=hparams['loss_func'],
                                    lr=hparams['lr'])),
        ('standard', lambda: UNet(hparams['in_channels'], hparams['out_channels'],
                                  n_features=hparams['n_features'],
                                  loss_type=hparams['loss_func'],
                                  lr=hparams['lr'])),
        ('harmonic', lambda: HarmonicUNet(hparams['in_channels'], hparams['out_channels'],
                                          n_features=hparams['n_features'],
                                          group_type=hparams['group'],
                                          max_freq=hparams['max_freq'],
                                          loss_type=hparams['loss_func'],
                                          lr=hparams['lr'])),
        ('steerable', build_steerable),
    )
    for name, build in builders:
        if kind == name:
            return build()
    raise ValueError(f'Unsupported model type: {hparams["model"]}')
def __init__(self, channel_in=3, n_classes=4, rot_n=4):
    """Small C_{rot_n}-equivariant classifier.

    It has three 3x3 R2Convs, average pooling, group pooling and a
    two-layer MLP head.

    Args:
        channel_in: number of scalar (trivial) input channels.
        n_classes: output size of the final linear layer.
        rot_n: order of the rotation group.
    """
    super(SmallE2, self).__init__()
    gspace = gspaces.Rot2dOnR2(N=rot_n)

    def field_type(count, rep):
        return nn.FieldType(gspace, count * [rep])

    self.feat_type_in = field_type(channel_in, gspace.trivial_repr)
    hidden = field_type(8, gspace.regular_repr)
    output = field_type(2, gspace.regular_repr)

    # A single BatchNorm / ReLU pair on the hidden type; how it is reused is decided in forward().
    self.bn = nn.InnerBatchNorm(hidden)
    self.relu = nn.ReLU(hidden)

    self.convin = nn.R2Conv(self.feat_type_in, hidden, kernel_size=3)
    self.convhid = nn.R2Conv(hidden, hidden, kernel_size=3)
    self.convout = nn.R2Conv(hidden, output, kernel_size=3)

    self.avgpool = nn.PointwiseAvgPool(output, 3)
    self.invariant_map = nn.GroupPooling(output)

    # Width of the invariant features that feed the MLP head.
    invariant_width = self.invariant_map.out_type.size
    self.lin_in = torch.nn.Linear(invariant_width, 64)
    self.elu = torch.nn.ELU()
    self.lin_out = torch.nn.Linear(64, n_classes)
def __init__(self):
    """VGG-style C8-equivariant dense feature extractor ending in group pooling.

    It stacks ten conv + ReLU stages (seven 3x3, then three 5x5) with two 2x
    max-pools and one stride-1 average pool, followed by ``GroupPooling``.
    """
    super(DenseFeatureExtractionModuleE2Inv, self).__init__()
    # Base VGG widths doubled; only the first ten entries are used below.
    filters = np.array([32,32, 64,64, 128,128,128, 256,256,256, 512,512,512], dtype=np.int32)*2

    # Number of discrete rotations considered for rotation invariance (C8).
    N = 8
    self.gspace = gspaces.Rot2dOnR2(N)
    self.input_type = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)
    # NOTE(review): presumably the channel count after GroupPooling; confirm
    # against how FIELD_TYPE['regular'] maps widths to fields.
    self.num_channels = 64

    # field_types[0] is the input; field_types[i] is the output of conv stage i.
    field_types = [self.input_type]
    field_types.extend(
        FIELD_TYPE['regular'](self.gspace, width, fixparams=False) for width in filters[:10]
    )

    layers = []

    def conv_relu(conv, idx):
        # Conv from stage idx-1's type to stage idx's type, then an in-place ReLU.
        layers.append(conv(field_types[idx - 1], field_types[idx]))
        layers.append(enn.ReLU(field_types[idx], inplace=True))

    conv_relu(conv3x3, 1)
    conv_relu(conv3x3, 2)
    layers.append(enn.PointwiseMaxPool(field_types[2], 2))
    conv_relu(conv3x3, 3)
    conv_relu(conv3x3, 4)
    layers.append(enn.PointwiseMaxPool(field_types[4], 2))
    for idx in (5, 6, 7):
        conv_relu(conv3x3, idx)
    layers.append(enn.PointwiseAvgPool(field_types[7], kernel_size=2, stride=1))
    for idx in (8, 9, 10):
        conv_relu(conv5x5, idx)
    # Pool over the rotation group to obtain rotation-invariant features.
    layers.append(enn.GroupPooling(field_types[10]))

    self.model = enn.SequentialModule(*layers)
def __init__(self, n_classes=6):
    """C4-equivariant CNN with two conv blocks, pooling to invariant features and an MLP head.

    Args:
        n_classes: output size of the final linear layer.
    """
    super(SteerCNN, self).__init__()
    # N=4: equivariance to rotations by multiples of 90 degrees (group C4).
    self.r2_act = gspaces.Rot2dOnR2(N=4)
    act = self.r2_act

    # Three input channels, each treated as a scalar field (trivial representation).
    input_type = nn_e2.FieldType(act, 3 * [act.trivial_repr])
    # Kept so forward() can wrap raw tensors into GeometricTensors.
    self.input_type = input_type

    # Both conv blocks produce 24 regular-representation fields of C4.
    hidden_type = nn_e2.FieldType(act, 24 * [act.regular_repr])

    def conv_block(src, dst):
        # 7x7 R2Conv (padding 3, no bias) -> InnerBatchNorm -> in-place ReLU.
        return nn_e2.SequentialModule(
            nn_e2.R2Conv(src, dst, kernel_size=7, padding=3, bias=False),
            nn_e2.InnerBatchNorm(dst),
            nn_e2.ReLU(dst, inplace=True))

    self.block1 = conv_block(input_type, hidden_type)
    self.pool1 = nn_e2.PointwiseAvgPool(hidden_type, 4)

    # The second block consumes the first block's output type.
    self.block2 = conv_block(self.block1.out_type, hidden_type)
    self.pool2 = nn_e2.SequentialModule(
        nn_e2.PointwiseAvgPoolAntialiased(hidden_type, sigma=0.66, stride=1, padding=0),
        nn_e2.PointwiseAvgPool(hidden_type, 4),
        nn_e2.GroupPooling(hidden_type))

    # Flattened width fed to the head: 24 channels (one per field after
    # GroupPooling) times a spatial size hard-coded as 13x13.
    # NOTE(review): 13x13 depends on the input resolution; confirm it matches the data.
    flat_width = 24 * 13 * 13
    self.fully_net = torch.nn.Sequential(
        torch.nn.Linear(flat_width, 64),
        torch.nn.BatchNorm1d(64),
        torch.nn.ELU(inplace=True),
        torch.nn.Linear(64, n_classes),
    )
def get_gspace(kwargs):
    """Consume ``group_type`` and ``N`` from ``kwargs`` and build the matching gspace.

    ``kwargs`` is modified in place: both keys are deleted, even when the
    group type turns out to be invalid.

    Returns:
        ``gspaces.Rot2dOnR2(N)`` for ``'circle'`` or
        ``gspaces.FlipRot2dOnR2(N)`` for ``'dihedral'``.

    Raises:
        KeyError: if ``group_type`` or ``N`` is missing (``kwargs`` is then left untouched).
        ValueError: if ``group_type`` is not ``'circle'`` or ``'dihedral'``.
    """
    group_type = kwargs['group_type']
    N = kwargs['N']
    for consumed in ('group_type', 'N'):
        del kwargs[consumed]

    if group_type != 'circle' and group_type != 'dihedral':
        raise ValueError(
            'Discrete group argument "group" must be one of ["circle", "dihedral"]'
        )
    if group_type == 'circle':
        return gspaces.Rot2dOnR2(N)
    return gspaces.FlipRot2dOnR2(N)
def __init__(self, nclasses=1):
    """C8-equivariant ResNet-50-style network (stages of 3, 4, 6 and 3 ResBlocks).

    Args:
        nclasses: output size of the final linear layer.
    """
    super(ResNet50, self).__init__()
    self.gspace = gspaces.Rot2dOnR2(N=8)

    def regular(width):
        return FIELD_TYPE["regular"](self.gspace, width, fixparams=False)

    reg_field64 = regular(64)
    reg_field256 = regular(256)
    reg_field128 = regular(128)
    reg_field512 = regular(512)
    reg_field1024 = regular(1024)
    reg_field2048 = regular(2048)

    # Stem: 7x7 stride-2 conv from 3 scalar channels, BatchNorm, ELU, anti-aliased max-pool.
    self.conv1 = enn.R2Conv(FIELD_TYPE["trivial"](self.gspace, 3, fixparams=False), reg_field64,
                            kernel_size=7, stride=2, padding=3)
    self.bn1 = enn.InnerBatchNorm(reg_field64)
    self.relu1 = enn.ELU(reg_field64)
    self.maxpool1 = enn.PointwiseMaxPoolAntialiased(reg_field64, kernel_size=2)

    def stage(in_field, inner_field, out_field, depth):
        # The first block uses stride 2 and maps in_field -> out_field;
        # the remaining depth-1 blocks keep out_field at stride 1.
        blocks = [ResBlock(stride=2, in_type=in_field, inner_type=inner_field, out_type=out_field)]
        for _ in range(depth - 1):
            blocks.append(ResBlock(stride=1, in_type=out_field, inner_type=inner_field, out_type=out_field))
        return torch.nn.Sequential(*blocks)

    self.layer1 = stage(reg_field64, reg_field64, reg_field256, 3)
    self.layer2 = stage(reg_field256, reg_field128, reg_field512, 4)
    self.layer3 = stage(reg_field512, reg_field256, reg_field1024, 6)
    self.layer4 = stage(reg_field1024, reg_field512, reg_field2048, 3)

    self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
    self.fc = torch.nn.Linear(2048, nclasses)
def generate_2d_rot8(out_path, num_tasks=10000, batch_size=20, image_size=32):
    """Generate random C8-equivariant convolution tasks and save them with ``np.savez``.

    Each task works as follows:

    1. The weights of a single ``R2Conv`` (1 scalar field -> 3 regular
       fields of C8, 3x3 kernel, no bias) are re-drawn with
       ``gnn.init.generalized_he_init``.
    2. The conv is applied to a batch of standard-normal images.

    Args:
        out_path: destination passed to ``np.savez``.
        num_tasks: number of tasks (independent weight draws); default 10000.
        batch_size: images per task; default 20.
        image_size: side length of the square input images; default 32.

    Saved arrays:
        x: inputs, shape ``(num_tasks, batch_size, 1, image_size, image_size)``.
        y: conv outputs, one entry per task stacked along axis 0.
        w: the conv's weight vector for each task, stacked along axis 0.
    """
    r2_act = gspaces.Rot2dOnR2(N=8)
    feat_type_in = gnn.FieldType(r2_act, [r2_act.trivial_repr])
    feat_type_out = gnn.FieldType(r2_act, 3 * [r2_act.regular_repr])
    conv = gnn.R2Conv(feat_type_in, feat_type_out, kernel_size=3, bias=False)

    xs, ys, ws = [], [], []
    # Pure data generation, so gradients are never needed. no_grad avoids
    # recording an autograd graph for every forward pass, and keeps the
    # in-place re-initialisation of the weight Parameter outside autograd.
    with torch.no_grad():
        for task_idx in range(num_tasks):
            gnn.init.generalized_he_init(conv.weights, conv.basisexpansion)
            inp = gnn.GeometricTensor(
                torch.randn(batch_size, 1, image_size, image_size), feat_type_in)
            result = conv(inp).tensor.detach().cpu().numpy()
            xs.append(inp.tensor.detach().cpu().numpy())
            ys.append(result)
            # conv.weights is still a Parameter with requires_grad=True, so detach before numpy().
            ws.append(conv.weights.detach().cpu().numpy())
            if task_idx % 100 == 0:
                print(f"Finished generating task {task_idx}")

    xs, ys, ws = np.stack(xs), np.stack(ys), np.stack(ws)
    np.savez(out_path, x=xs, y=ys, w=ws)
def __init__(self):
    """Build a single dilated C8-equivariant convolution followed by group pooling."""
    super(ModelDilated, self).__init__()
    rotations = 8
    self.gspace = gspaces.Rot2dOnR2(rotations)
    # 3 scalar input channels -> 16 regular fields of C8.
    self.in_type = enn.FieldType(self.gspace, 3 * [self.gspace.trivial_repr])
    self.out_type = enn.FieldType(self.gspace, 16 * [self.gspace.regular_repr])
    # 3x3 kernel with dilation 2 (5x5 receptive field); padding 2 keeps the spatial size at stride 1.
    self.layer = enn.R2Conv(
        self.in_type,
        self.out_type,
        kernel_size=3,
        stride=1,
        padding=2,
        dilation=2,
        bias=True,
    )
    self.invariant = enn.GroupPooling(self.out_type)
def __init__(self, in_chan, out_chan, imsize, kernel_size=5, N=8):
    """C_N-steerable LeNet with two R2Conv stages, group pooling and three linear layers.

    Args:
        in_chan: not read by this constructor (the input type is always a
            single scalar field).
        out_chan: output size of ``fc3``.
        imsize: input side length; used for the mask and to size ``fc1``.
        kernel_size: conv kernel size (cast to ``int``); padding is
            ``kernel_size // 2``.
        N: number of discrete rotations.
    """
    super(CNSteerableLeNet, self).__init__()
    kernel_size = int(kernel_size)
    # Side length presumably left after the two 2x pooling steps.
    z = imsize // 2 // 2

    self.r2_act = gspaces.Rot2dOnR2(N)
    act = self.r2_act

    scalar_type = e2nn.FieldType(act, [act.trivial_repr])
    self.input_type = scalar_type
    six_fields = e2nn.FieldType(act, 6 * [act.regular_repr])

    self.mask = e2nn.MaskModule(scalar_type, imsize, margin=1)

    self.conv1 = e2nn.R2Conv(scalar_type, six_fields, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
    self.relu1 = e2nn.ReLU(six_fields, inplace=True)
    self.pool1 = e2nn.PointwiseMaxPoolAntialiased(six_fields, kernel_size=2)

    sixteen_fields = e2nn.FieldType(act, 16 * [act.regular_repr])
    self.conv2 = e2nn.R2Conv(self.pool1.out_type, sixteen_fields, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
    self.relu2 = e2nn.ReLU(sixteen_fields, inplace=True)
    self.pool2 = e2nn.PointwiseMaxPoolAntialiased(sixteen_fields, kernel_size=2)
    self.gpool = e2nn.GroupPooling(sixteen_fields)

    # 16 invariant channels on a z x z grid feed the MLP head.
    self.fc1 = nn.Linear(16 * z * z, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, out_chan)
    self.drop = nn.Dropout(p=0.5)

    # Parameter-less placeholder used to track the module's device.
    self.dummy = nn.Parameter(torch.empty(0))
def __init__(self, growth_rate, list_layer, nclasses):
    """C8-equivariant DenseNet with a stem, four dense blocks, three transitions and a classifier.

    Args:
        growth_rate: fields added per dense layer; the stem emits
            ``2 * growth_rate`` regular fields.
        list_layer: layer counts of the four dense blocks (indices 0-3 are read).
        nclasses: output size of the linear classifier.
    """
    super(DenseNet161, self).__init__()
    self.gspace = gspaces.Rot2dOnR2(N=8)

    def regular(width):
        return FIELD_TYPE["regular"](self.gspace, width, fixparams=False)

    width = 2*growth_rate
    self.conv1 = conv7x7(FIELD_TYPE["trivial"](self.gspace, 3, fixparams=False), regular(width))
    self.pool1 = enn.PointwiseMaxPool(regular(width), kernel_size=2, stride=2)

    # Dense blocks 1-4, registered as block1..block4. Each of the first three
    # is followed by a transition (trans1..trans3) that halves the width.
    for stage in range(1, 5):
        depth = list_layer[stage - 1]
        setattr(self, f"block{stage}", DenseBlock(width, growth_rate, self.gspace, depth))
        width = width + depth*growth_rate
        if stage < 4:
            halved = int(width/2)
            setattr(self, f"trans{stage}", TransitionBlock(width, halved, self.gspace))
            width = halved

    self.bn = enn.InnerBatchNorm(regular(width))
    self.relu = enn.ReLU(regular(width), inplace=True)
    self.pool2 = torch.nn.AdaptiveAvgPool2d((1, 1))
    self.classifier = torch.nn.Linear(width, nclasses)
def __init__(self, input_frames, output_frames, kernel_size, N):
    """C_N-equivariant ResNet mapping frequency-1 irrep frames to frequency-1 irrep frames.

    Args:
        input_frames: number of irrep(1) fields in the input.
        output_frames: number of irrep(1) fields in the output.
        kernel_size: kernel size of every convolution (padding ``(kernel_size - 1) // 2``).
        N: order of the rotation group.
    """
    super(ResNet_Rot, self).__init__()
    r2_act = gspaces.Rot2dOnR2(N = N)
    pad = (kernel_size - 1) // 2

    # The inputs are velocity fields, hence the rho_1 (irrep(1)) type at both
    # ends; the hidden layers use regular fields.
    self.feat_type_in = nn.FieldType(r2_act, input_frames * [r2_act.irrep(1)])
    self.feat_type_in_hid = nn.FieldType(r2_act, 16 * [r2_act.regular_repr])
    self.feat_type_hid_out = nn.FieldType(r2_act, 192 * [r2_act.regular_repr])
    self.feat_type_out = nn.FieldType(r2_act, output_frames * [r2_act.irrep(1)])

    self.input_layer = nn.SequentialModule(
        nn.R2Conv(self.feat_type_in, self.feat_type_in_hid, kernel_size=kernel_size, padding=pad),
        nn.InnerBatchNorm(self.feat_type_in_hid),
        nn.ReLU(self.feat_type_in_hid),
    )

    # Residual trunk: every width step is a widening block followed by a same-width block.
    widths = (16, 32, 64, 128, 192)
    layers = [self.input_layer]
    for narrow, wide in zip(widths, widths[1:]):
        layers.append(rot_resblock(narrow, wide, kernel_size, N))
        layers.append(rot_resblock(wide, wide, kernel_size, N))
    layers.append(nn.R2Conv(self.feat_type_hid_out, self.feat_type_out, kernel_size=kernel_size, padding=pad))

    # self.input_layer is also an attribute, so the same module is registered
    # twice in the module tree (shared weights).
    self.model = torch.nn.Sequential(*layers)
# pool over the group x = self.gpool(x) # unwrap the output GeometricTensor # (take the Pytorch tensor and discard the associated representation) x = x.tensor # classify with the final fully connected layers) x = self.fully_net(x.reshape(x.shape[0], -1)) return x if __name__ == "__main__": r2_act = gspaces.Rot2dOnR2(N=8) # the input image is a scalar field, corresponding to the trivial representation in_type = nn.FieldType(r2_act, 3 * [r2_act.trivial_repr]) device = 'cuda' if torch.cuda.is_available() else 'cpu' dummy = torch.rand(64, 3, 32, 32).to(device) y = nn.GeometricTensor(dummy, in_type) # TEST 2 r2_act = gspaces.Rot2dOnR2(N=16) # 5 print(r2_act.trivial_repr) feat_type_in = nn.FieldType(r2_act, 3 * [r2_act.trivial_repr]) # 6 feat_type_out = nn.FieldType(r2_act, 10 * [r2_act.regular_repr]) # 7 # 8 conv = nn.R2Conv(feat_type_in, feat_type_out, kernel_size=5) # 9
def __init__(self,
             base = 'DNSteerableAGRadGalNet',
             attention_module='SelfAttention',
             attention_gates=3,
             attention_aggregation='ft',
             n_classes=2,
             attention_normalisation='sigmoid',
             quiet=True,
             number_rotations=8,
             imsize=150,
             kernel_size=3,
             group="D"
             ):
    """Steerable (D_N or C_N) CNN with up to three grid attention gates.

    Args:
        base, attention_module, quiet: accepted for interface compatibility;
            not read by this constructor.
        attention_gates: number of attention gates to build (0-3).
        attention_aggregation: ``'concat'``, ``'mean'``, ``'deep_sup'`` or
            ``'ft'``; selects ``self.aggregate`` and the classifier head(s).
        n_classes: number of output classes.
        attention_normalisation: normalisation passed to every
            ``GridAttentionBlock2D``.
        number_rotations: N, the number of discrete rotations.
        imsize: side length of the square input (mask and attention sizes
            are built as ``[imsize, imsize]``-style squares).
        kernel_size: kernel size of every R2Conv (padding ``kernel_size // 2``).
        group: ``'D'`` for D_N (rotations and reflections) or ``'C'`` for
            C_N (rotations only); case-insensitive.

    Raises:
        AssertionError: on an unknown aggregation mode, normalisation,
            number of gates or group.
        NotImplementedError: defensive fallback for an unhandled aggregation mode.
    """
    super(DNSteerableAGRadGalNet, self).__init__()
    aggregation_mode = attention_aggregation
    normalisation = attention_normalisation
    AG = int(attention_gates)
    N = int(number_rotations)
    kernel_size = int(kernel_size)
    imsize = int(imsize)
    n_classes = int(n_classes)

    assert aggregation_mode in ['concat', 'mean', 'deep_sup', 'ft'], 'Aggregation mode not recognised. Valid inputs include concat, mean, deep_sup or ft.'
    assert normalisation in ['sigmoid','range_norm','std_mean_norm','tanh','softmax'], f'Nomralisation not implemented. Can be any of: sigmoid, range_norm, std_mean_norm, tanh, softmax'
    assert AG in [0,1,2,3], f'Number of Attention Gates applied (AG) must be an integer in range [0,3]. Currently AG={AG}'
    assert group.lower() in ["d","c"], f"group parameter must either be 'D' for DN, or 'C' for CN, steerable networks. (currently {group})."
    filters = [6,16,32,64,128]

    self.attention_out_sizes = []
    self.ag = AG
    self.n_classes = n_classes
    self.filters = filters
    self.aggregation_mode = aggregation_mode

    # Symmetry group: D_N (rotations + reflections) or C_N (rotations only).
    if group.lower() == "d":
        self.r2_act = gspaces.FlipRot2dOnR2(N=N)
    else:
        self.r2_act = gspaces.Rot2dOnR2(N=N)

    # Input: a single scalar (trivial) field, masked to the imsize x imsize image.
    in_type = e2nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
    self.in_type = in_type
    self.mask = e2nn.MaskModule(in_type, imsize, margin=0)

    def add_conv_relu_bnorm(suffix, src_type, dst_type):
        # Registers conv<suffix>, relu<suffix> and bnorm<suffix>, constructed
        # in that order (same parameter-initialisation order as inline code).
        setattr(self, f"conv{suffix}", e2nn.R2Conv(src_type, dst_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False))
        setattr(self, f"relu{suffix}", e2nn.ReLU(dst_type))
        setattr(self, f"bnorm{suffix}", e2nn.InnerBatchNorm(dst_type))

    # Four stages of regular fields with widths filters[0..3] = 6, 16, 32, 64.
    # Each stage has conv<k>a/b/c (only a/b in stage 4), then
    # mpool<k> (2x2 max-pool, stride 2) and gpool<k> (group pooling).
    src_type = in_type
    for stage, (width, sublayers) in enumerate(zip(filters, ("abc", "abc", "abc", "ab")), start=1):
        out_type = e2nn.FieldType(self.r2_act, width*[self.r2_act.regular_repr])
        for sub in sublayers:
            add_conv_relu_bnorm(f"{stage}{sub}", src_type, out_type)
            src_type = out_type
        setattr(self, f"mpool{stage}", e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2))
        setattr(self, f"gpool{stage}", e2nn.GroupPooling(out_type))

    self.flatten = nn.Flatten(1)
    self.dropout = nn.Dropout(p=0.5)

    # Attention gates, built up to self.ag. Their in_channels are 32 / 16 / 6
    # at input sizes imsize//4 / imsize//2 / imsize, all with 64 gating channels.
    if self.ag >= 1:
        self.attention1 = GridAttentionBlock2D(in_channels=32, gating_channels=64, inter_channels=64, input_size=[imsize//4,imsize//4], normalisation=normalisation)
    if self.ag >= 2:
        self.attention2 = GridAttentionBlock2D(in_channels=16, gating_channels=64, inter_channels=64, input_size=[imsize//2,imsize//2], normalisation=normalisation)
    if self.ag >= 3:
        self.attention3 = GridAttentionBlock2D(in_channels=6, gating_channels=64, inter_channels=64, input_size=[imsize,imsize], normalisation=normalisation)

    # NOTE(review): fc1 has a hard-coded 16*5*5 input (channel_size * width * height),
    # and fc1-fc3 are not referenced elsewhere in this constructor -- confirm they are still used.
    self.fc1 = nn.Linear(16*5*5,256)
    self.fc2 = nn.Linear(256,256)
    self.fc3 = nn.Linear(256, self.n_classes)
    # Empty parameter, presumably used to look up the module's device.
    self.dummy = nn.Parameter(torch.empty(0))

    self.module_order = ['conv1a', 'relu1a', 'bnorm1a', #1->6
                         'conv1b', 'relu1b', 'bnorm1b', #6->6
                         'conv1c', 'relu1c', 'bnorm1c', #6->6
                         'mpool1',
                         'conv2a', 'relu2a', 'bnorm2a', #6->16
                         'conv2b', 'relu2b', 'bnorm2b', #16->16
                         'conv2c', 'relu2c', 'bnorm2c', #16->16
                         'mpool2',
                         'conv3a', 'relu3a', 'bnorm3a', #16->32
                         'conv3b', 'relu3b', 'bnorm3b', #32->32
                         'conv3c', 'relu3c', 'bnorm3c', #32->32
                         'mpool3',
                         'conv4a', 'relu4a', 'bnorm4a', #32->64
                         'conv4b', 'relu4b', 'bnorm4b', #64->64
                         'compatibility_score1',
                         'compatibility_score2']

    #########################
    # Aggregation strategies
    if self.ag != 0:
        self.attention_filter_sizes = [32, 16, 6]
        # Total width of the gate outputs that are actually built.
        concat_length = sum(self.attention_filter_sizes[:self.ag])
        if aggregation_mode == 'concat':
            self.classifier = nn.Linear(concat_length, self.n_classes)
            self.aggregate = self.aggregation_concat
        else:
            # One linear head per gate. Appending to an nn.ModuleList in a loop
            # registers every head, so they follow the model across devices.
            self.classifiers = nn.ModuleList()
            for gate_width in self.attention_filter_sizes[:self.ag]:
                self.classifiers.append(nn.Linear(gate_width, self.n_classes))
            if aggregation_mode == 'mean':
                self.aggregate = self.aggregation_sep
            elif aggregation_mode == 'deep_sup':
                self.classifier = nn.Linear(concat_length, self.n_classes)
                self.aggregate = self.aggregation_ds
            elif aggregation_mode == 'ft':
                self.classifier = nn.Linear(self.n_classes*self.ag, self.n_classes)
                self.aggregate = self.aggregation_ft
            else:
                raise NotImplementedError
    else:
        # Four 2x2 / stride-2 max-pools reduce each side to imsize // 16, and
        # gpool4 leaves 64 channels. This used to be hard-coded as 150 // 16,
        # which is only correct for the default imsize=150.
        # NOTE(review): assumes forward() feeds gpool4(mpool4(...)) here -- confirm.
        self.classifier = nn.Linear((imsize//16)**2*64, self.n_classes)
        self.aggregate = lambda x: self.classifier(self.flatten(x))
# Top-level experiment setup: fixed seeds, an equivariant backbone with a
# regression head, and an MNIST dataset factory.
# NOTE(review): torch, np, gspaces, nn, torchvision and transforms are used
# below but imported outside this view -- confirm.
from mnist_rot_dataset import RotatedMNISTDataset, TransformedDataset
from models import C8Backbone3x3, Backbone5x5, ClassificationHead, RegressionHead
from functools import partial
import matplotlib.pyplot as plt

# Seed torch (CPU and CUDA) and NumPy with the same value for reproducibility.
torch.manual_seed(23)
torch.cuda.manual_seed(23)
np.random.seed(23)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 128
# Symmetry-group toggle: the second assignment overwrites the first, so the
# dihedral group FlipRot2dOnR2(N=8) is the one actually used.
group=gspaces.Rot2dOnR2(N=8)
group=gspaces.FlipRot2dOnR2(N=8)
# Convolution implementation; the alternatives are left commented out.
conv_func = nn.R2Conv
# conv_func = sscnn.e2cnn.SSConv
# conv_func = sscnn.e2cnn.PlainConv
backbone = Backbone5x5(out_channels=2, conv_func=conv_func, group=group)
# The head is built from the backbone's output field type.
head = RegressionHead(backbone.out_type, conv_func)
# Dataset constructor with the transform pre-bound; remaining arguments are supplied by the caller.
dataset_func = partial(torchvision.datasets.MNIST, transform=transforms.ToTensor())
# Alternative (commented out): grayscale, normalised CIFAR10.
# dataset_func = partial(torchvision.datasets.CIFAR10,
#                        transform=transforms.Compose([
#                            transforms.Grayscale(num_output_channels=1),
#                            transforms.ToTensor(),
#                            transforms.Normalize(mean=[0.5], std=[0.5])
def _wrap_splits(args, make_split, **extra):
    """Wrap the train and test splits of a dataset in ``TransformedDataset``.

    ``make_split(True)`` must return the training split and ``make_split(False)``
    the test split. The train wrapper uses ``args.rotate_train`` /
    ``args.reflect_train``; the test wrapper uses ``args.rotate_test`` /
    ``args.reflect_test``. Any ``extra`` keyword arguments (e.g. ``disk_masked``)
    are forwarded to both wrappers. When there are none, nothing extra is passed,
    so TransformedDataset's own defaults apply.
    """
    train_dataset = TransformedDataset(
        make_split(True),
        random_rotate=args.rotate_train,
        random_reflect=args.reflect_train,
        **extra,
    )
    test_dataset = TransformedDataset(
        make_split(False),
        random_rotate=args.rotate_test,
        random_reflect=args.reflect_test,
        **extra,
    )
    return train_dataset, test_dataset


def _build_datasets(args):
    """Build the datasets selected by ``args.dataset``.

    Returns:
        ``(train_dataset, test_dataset, in_channels, num_classes, batch_size)``.

    Raises:
        NotImplementedError: for an unsupported ``args.dataset``.
    """
    name = args.dataset
    default_batch_size = 128
    # Colour datasets are disk-masked only for the regression task.
    disk_masked = (args.task == 'regression')

    # Single-channel datasets: constructor takes ``train=<bool>`` and only ToTensor.
    # Value: (dataset class, extra constructor kwargs, number of output logits).
    grayscale = {
        'MNIST': (torchvision.datasets.MNIST, {}, 10),
        'KMNIST': (torchvision.datasets.KMNIST, {}, 10),
        # NOTE(review): EMNIST 'balanced' has 47 classes; 50 logits are used here.
        'EMNIST': (torchvision.datasets.EMNIST, {'split': 'balanced'}, 50),
        'FMNIST': (torchvision.datasets.FashionMNIST, {}, 10),
    }
    if name in grayscale:
        dataset_cls, extra_kwargs, num_classes = grayscale[name]

        def make_split(train):
            return dataset_cls('.', train=train, download=True,
                               transform=transforms.ToTensor(), **extra_kwargs)

        train_dataset, test_dataset = _wrap_splits(args, make_split)
        return train_dataset, test_dataset, 1, num_classes, default_batch_size

    if name in ('CIFAR10', 'CIFAR100'):
        if name == 'CIFAR10':
            dataset_cls, num_classes = torchvision.datasets.CIFAR10, 10
            mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        else:
            dataset_cls, num_classes = torchvision.datasets.CIFAR100, 100
            mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)

        def make_split(train):
            # Random crop + horizontal flip augmentation only on the training split.
            augment = [transforms.RandomCrop(32, padding=4),
                       transforms.RandomHorizontalFlip()] if train else []
            return dataset_cls('.', train=train, download=True,
                               transform=transforms.Compose(augment + [
                                   transforms.ToTensor(),
                                   transforms.Normalize(mean, std),
                               ]))

        train_dataset, test_dataset = _wrap_splits(args, make_split,
                                                   disk_masked=disk_masked)
        # Both CIFAR variants are RGB; CIFAR100 previously left in_channels unset.
        return train_dataset, test_dataset, 3, num_classes, default_batch_size

    if name == 'STL10':
        def make_split(train):
            augment = [transforms.RandomCrop(96, padding=12),
                       transforms.RandomHorizontalFlip()] if train else []
            return torchvision.datasets.STL10(
                '.', split='train' if train else 'test', download=True,
                transform=transforms.Compose(augment + [transforms.ToTensor()]))

        train_dataset, test_dataset = _wrap_splits(args, make_split,
                                                   disk_masked=disk_masked)
        # NOTE(review): STL10 has 10 classes; the original 50 logits are kept as-is.
        # Smaller batches for the larger 96x96 images.
        return train_dataset, test_dataset, 3, 50, 12

    raise NotImplementedError


def _evaluate(model, loader, input_type, eval_function, device):
    """Return the mean of ``eval_function`` over all batches of ``loader``.

    The model is put in eval mode and run without gradient tracking.
    ``eval_function(outputs, labels, targets)`` must return an iterable of values;
    they are pooled across batches before averaging.
    """
    errors = []
    with torch.no_grad():
        model.eval()
        for x, labels, targets in loader:
            x = x.to(device)
            labels = labels.to(device)
            targets = targets.to(device)
            outputs = model(nn.GeometricTensor(x, input_type)).tensor.flatten(1, -1)
            errors.extend(eval_function(outputs, labels, targets))
    return np.array(errors).mean()


def train(args):
    """Train an equivariant model described by ``args`` and log its test metric.

    ``args`` is read as an attribute namespace (presumably ``argparse.Namespace``):

    - ``reflection``: 1 -> ``Rot2dOnR2``, 2 -> ``FlipRot2dOnR2``.
    - ``rotation``: the N passed to the gspace.
    - ``conv``: 'R2Conv' | 'SSConv' | 'PlainConv'.
    - ``dataset``.
    - ``rotate_train`` / ``reflect_train`` / ``rotate_test`` / ``reflect_test``.
    - ``backbone``: 'B5' | 'WRN'.
    - ``task``: 'classification' | 'regression'.

    Every 5 epochs the mean test metric and current learning rate are appended to
    ``<all arg values joined by '_'>.txt``.

    Raises:
        NotImplementedError: for any unsupported option value.
    """
    torch.manual_seed(23)
    torch.cuda.manual_seed(23)
    np.random.seed(23)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if args.reflection == 1:
        group = gspaces.Rot2dOnR2(N=args.rotation)
    elif args.reflection == 2:
        group = gspaces.FlipRot2dOnR2(N=args.rotation)
    else:
        raise NotImplementedError

    if args.conv == 'R2Conv':
        conv_func = nn.R2Conv
    elif args.conv == 'SSConv':
        conv_func = sscnn.e2cnn.SSConv
    elif args.conv == 'PlainConv':
        conv_func = sscnn.e2cnn.PlainConv
    else:
        raise NotImplementedError

    train_dataset, test_dataset, in_channels, num_classes, batch_size = \
        _build_datasets(args)

    if args.backbone == 'B5':
        backbone = Backbone5x5(conv_func=conv_func, group=group,
                               in_channels=in_channels)
    elif args.backbone == 'WRN':
        backbone = Wide_ResNet(28, 10, 0.3, initial_stride=2, N=args.rotation,
                               f=(args.reflection == 2), r=0, conv_func=conv_func,
                               fixparams=False, in_channels=in_channels)
    else:
        raise NotImplementedError

    if args.task == 'classification':
        head = ClassificationHead(backbone.out_type, num_classes=num_classes)
        cross_entropy_loss = torch.nn.CrossEntropyLoss()

        def loss_function(outputs, labels, targets):
            return cross_entropy_loss(outputs, labels)

        eval_function = mask_of_success
    elif args.task == 'regression':
        head = RegressionHead(backbone.out_type, conv_func)
        mse_loss = torch.nn.MSELoss()

        def loss_function(outputs, labels, targets):
            return mse_loss(outputs, targets)

        eval_function = abs_included_angles
    else:
        raise NotImplementedError

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

    model = nn.SequentialModule(OrderedDict([('backbone', backbone), ('head', head)]))
    model = model.to(device)

    if args.backbone == 'B5':
        optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-5)
        scheduler = None
        max_epochs = 60
    else:  # 'WRN' -- any other backbone was rejected above
        base_lr = 2e-3 if args.dataset == 'STL10' else 1e-2
        optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9,
                                    weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=45, gamma=0.2)
        max_epochs = 260

    file_name = '_'.join(str(value) for value in args.__dict__.values())
    log_name = file_name + '.txt'

    # ``with`` guarantees the log is closed even if training raises.
    with open(log_name, 'w') as log_file:
        for epoch in range(max_epochs + 2):
            model.train()
            if device == 'cuda':
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
            else:
                start = time.time()

            for x, labels, targets in train_loader:
                optimizer.zero_grad()
                x = x.to(device)
                labels = labels.to(device)
                targets = targets.to(device)
                outputs = model(nn.GeometricTensor(x, backbone.input_type))
                outputs = outputs.tensor.flatten(1, -1)
                loss = loss_function(outputs, labels, targets)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
                optimizer.step()

            if scheduler:
                scheduler.step()
            # Always read the LR so it is defined for logging even without a scheduler.
            for param_group in optimizer.param_groups:
                lr = param_group['lr']
                print('lr = ', param_group['lr'])

            if device == 'cuda':
                end.record()
                torch.cuda.synchronize()
                print('Epoch', start.elapsed_time(end))
            else:
                end = time.time()
                print('Epoch', end - start)

            if epoch % 5 == 0:
                error = _evaluate(model, test_loader, backbone.input_type,
                                  eval_function, device)
                print(f"epoch {epoch} | tes : {error}")
                log_file.write(f"epoch {epoch} | acc : {error} | lr : {lr}\n")
                log_file.flush()
def __init__(self):
    """Build a C8-equivariant U-Net with one encoder path and three decoder branches.

    The input is 3 trivial (scalar) fields, presumably RGB. The encoder
    (``conv1``..``conv32``) uses regular C8 fields. Each decoder branch
    ``k`` in {1, 2, 3} consists of ``up4k``, ``up3k``, ``up2k``, ``up1k`` and ends in a
    single regular field, which ``gpoolk`` group-pools. How the branches are
    combined happens in ``forward`` (not visible here).
    """
    super(UNet, self).__init__()
    act = gspaces.Rot2dOnR2(N=8)
    self.r2_act = act

    def regular(count):
        # ``count`` fields, each transforming under the regular representation of C8
        return enn.FieldType(act, count * [act.regular_repr])

    self.field_type_1 = regular(1)
    self.field_type_3 = enn.FieldType(act, 3 * [act.trivial_repr])
    self.field_type_8 = regular(8)
    self.field_type_16 = regular(16)
    self.field_type_32 = regular(32)
    self.field_type_64 = regular(64)
    self.field_type_128 = regular(128)

    ft1, ft3, ft8 = self.field_type_1, self.field_type_3, self.field_type_8
    ft16, ft32, ft64 = self.field_type_16, self.field_type_32, self.field_type_64

    # Encoder.
    self.conv1 = Conv(in_type=ft3, out_type=ft8)
    self.conv2 = Conv(in_type=ft8, out_type=ft8)
    self.down1 = DownSample(in_type=ft8, out_type=ft16)
    self.conv12 = Conv(in_type=ft16, out_type=ft16)
    self.down2 = DownSample(in_type=ft16, out_type=ft32)
    self.conv22 = Conv(in_type=ft32, out_type=ft32)
    self.down3 = DownSample(in_type=ft32, out_type=ft32)
    self.conv32 = Conv(in_type=ft32, out_type=ft32)

    # Three structurally identical decoder branches. They are built branch by branch
    # (up4k, up3k, up2k, up1k), which keeps the submodule creation/registration order.
    for branch in (1, 2, 3):
        setattr(self, f'up4{branch}', UpSample(in_type=ft64, out_type=ft32, mid_type=ft32))
        setattr(self, f'up3{branch}', UpSample(in_type=ft64, out_type=ft16, mid_type=ft32))
        setattr(self, f'up2{branch}', UpSample(in_type=ft32, out_type=ft8, mid_type=ft16))
        setattr(self, f'up1{branch}', LastConcat(in_type=ft16, out_type=ft1, mid_type=ft8))

    for branch in (1, 2, 3):
        setattr(self, f'gpool{branch}', enn.GroupPooling(ft1))
# Set default Orientation=8, .i.e, the group C8 # One can change it by passing the env Orientation=xx Orientation = 8 # keep similar computation or similar params # One can change it by passing the env fixparams=True fixparams = False if 'Orientation' in os.environ: Orientation = int(os.environ['Orientation']) if 'fixparams' in os.environ: fixparams = True print('ReResNet Orientation: {}\tFix Params: {}'.format( Orientation, fixparams)) # define the equivariant group. We use C8 group by default. gspace = gspaces.Rot2dOnR2(N=Orientation) def regular_feature_type(gspace: gspaces.GSpace, planes: int): """ build a regular feature map with the specified number of channels""" assert gspace.fibergroup.order() > 0 N = gspace.fibergroup.order() if fixparams: planes *= math.sqrt(N) planes = planes / N planes = int(planes) return enn.FieldType(gspace, [gspace.regular_repr] * planes) def trivial_feature_type(gspace: gspaces.GSpace, planes: int,
else: sys.exit("Unknown architecture type.") if ARGS['DIV_FREE']: print("Used div free kernel in the output") CNP=steercnp.SteerCNP(encoder,decoder,ARGS['DIM_COV_EST'],dim_context_feat=2,l_scale=ARGS['LENGTH_SCALE_OUT'], kernel_dict_out={'kernel_type':"div_free"},normalize_output=False) else: CNP=steercnp.SteerCNP(encoder,decoder,ARGS['DIM_COV_EST'],dim_context_feat=2,l_scale=ARGS['LENGTH_SCALE_OUT']) #If equivariance is wanted, create the group and the fieldtype for the equivariance: if ARGS['TESTING_GROUP']=='D4': G_act=gspaces.FlipRot2dOnR2(N=4) feature_in=G_CNN.FieldType(G_act,[G_act.irrep(1,1)]) elif ARGS['TESTING_GROUP']=='C16': G_act=gspaces.Rot2dOnR2(N=16) feature_in=G_CNN.FieldType(G_act,[G_act.irrep(1)]) else: G_act=None feature_in=None print("Number of parameters: ", my_utils.count_parameters(CNP,print_table=False)) CNP,_,_=training.train_cnp(CNP, train_dataset=train_dataset, val_dataset=val_dataset, data_identifier=data_IDENTIFIER, device=DEVICE, minibatch_size=ARGS['BATCH_SIZE'], n_epochs=ARGS['N_EPOCHS'],
# Name -> non-linearity class (not an instance); callers presumably instantiate it
# with the field type it acts on.
ACT_FNS = {
    'relu': enn.ReLU,
    'elu': enn.ELU,
    'gated': enn.GatedNonLinearity1,
    'swish': base_layers.GeomSwish,
}

# Name -> G-space instance. All of these are constructed eagerly at import time.
# fliprotN: dihedral group D_N (N rotations + reflections); rotN: cyclic group C_N;
# flip: reflections only.
# so2 / o2: N=-1 selects the continuous group; maximum_frequency caps the irrep
# frequency e2cnn uses for it.
GROUPS = {
    'fliprot16': gspaces.FlipRot2dOnR2(N=16),
    'fliprot12': gspaces.FlipRot2dOnR2(N=12),
    'fliprot8': gspaces.FlipRot2dOnR2(N=8),
    'fliprot4': gspaces.FlipRot2dOnR2(N=4),
    'fliprot2': gspaces.FlipRot2dOnR2(N=2),
    'flip': gspaces.Flip2dOnR2(),
    'rot16': gspaces.Rot2dOnR2(N=16),
    'rot12': gspaces.Rot2dOnR2(N=12),
    'rot8': gspaces.Rot2dOnR2(N=8),
    'rot4': gspaces.Rot2dOnR2(N=4),
    'rot2': gspaces.Rot2dOnR2(N=2),
    'so2': gspaces.Rot2dOnR2(N=-1, maximum_frequency=10),
    'o2': gspaces.FlipRot2dOnR2(N=-1, maximum_frequency=10),
}

# Name -> fiber builder (functions defined elsewhere; presumably each returns the
# field type for a layer -- verify against their definitions).
FIBERS = {
    "trivial": trivial_fiber,
    "quotient": quotient_fiber,
    "regular": regular_fiber,
    "irrep": irrep_fiber,
    "mixed1": mixed1_fiber,
    "mixed2": mixed2_fiber,
def __init__(self, input_channels, output_channels, N, last_deconv = False):
    """Wrap a 4x4, stride-1, padding-0 C_N-equivariant convolution used on the decoder path.

    Args:
        input_channels: number of regular C_N input fields.
        output_channels: number of output fields. These are regular fields, or
            ``irrep(1)`` fields when ``last_deconv`` is true (see ``rot_conv2d``).
        N: number of discrete rotations of the group C_N.
        last_deconv: if true, the wrapped conv has no batch norm / ReLU and outputs
            frequency-1 irrep fields.

    ``self.feat_type`` holds the input field type. It is presumably used in
    ``forward`` to wrap tensors; ``forward`` is not visible here.
    """
    super(rot_deconv2d, self).__init__()
    # deconv=True makes rot_conv2d use padding 0 instead of "same" padding.
    self.conv2d = rot_conv2d(
        input_channels=input_channels,
        output_channels=output_channels,
        kernel_size=4,
        stride=1,
        N=N,
        activation=True,
        deconv=True,
        last_deconv=last_deconv,
    )
    rotations = gspaces.Rot2dOnR2(N=N)
    self.feat_type = nn.FieldType(rotations, input_channels * [rotations.regular_repr])
def __init__(self, depth, widen_factor, dropout_rate, num_classes=100,
             N: int = 8,
             r: int = 1,
             f: bool = True,
             deltaorth: bool = False,
             fixparams: bool = True,
             initial_stride: int = 1,
             in_channels: int = 3,
             ):
    r"""
    Build an equivariant Wide ResNet.

    The parameter ``N`` controls rotation equivariance and the parameter ``f`` reflection equivariance.

    More precisely, ``N`` is the number of discrete rotations the model is initially equivariant to.
    ``N = 1`` means the model has no rotational equivariance: it is only reflection equivariant
    if ``f`` is ``True``, and not equivariant at all otherwise.

    ``f`` is a boolean flag specifying whether the model should be reflection equivariant or not.
    If it is ``False``, the model is not reflection equivariant.

    ``r`` is the restriction level:

    - ``0``: no restriction. The model is equivariant to ``N`` rotations from the input to the output.
    - ``1``: restriction before the last block. The model is equivariant to ``N`` rotations in the
      first 2 blocks. Then it is restricted to ``N/2`` rotations until the output.
    - ``2``: restriction after the first block. The model is equivariant to ``N`` rotations in the
      first block. Then it is restricted to ``N/2`` rotations until the output (i.e. in the last 2 blocks).
    - ``3``: restriction after the first and the second block. The model is equivariant to ``N``
      rotations in the first block. It is restricted to ``N/2`` rotations before the second block
      and to ``1`` rotation before the last block.

    NOTICE: if restriction to ``N/2`` is performed, ``N`` needs to be even!

    Args:
        depth: total depth; must be of the form ``6n + 4`` (``n`` blocks per stage).
        widen_factor: channel multiplier ``k`` of the three stages.
        dropout_rate: forwarded to every ``WideBasic`` block.
        num_classes: output size of the final linear layer.
        deltaorth: if ``True``, R2Conv weights get delta-orthonormal initialisation.
        fixparams: stored on the model and read by the layer builders.
        initial_stride: stride of the first stage.
        in_channels: number of input channels, each treated as a trivial (invariant)
            field. Defaults to 3 (RGB).
    """
    super(Wide_ResNet, self).__init__()

    assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
    n = int((depth - 4) / 6)
    k = widen_factor

    print(f'| Wide-Resnet {depth}x{k}')

    nStages = [16, 16 * k, 32 * k, 64 * k]

    self._fixparams = fixparams

    self._layer = 0

    # number of discrete rotations to be equivariant to
    self._N = N

    # if the model is [F]lip equivariant
    self._f = f
    if self._f:
        if N != 1:
            self.gspace = gspaces.FlipRot2dOnR2(N)
        else:
            self.gspace = gspaces.Flip2dOnR2()
    else:
        if N != 1:
            self.gspace = gspaces.Rot2dOnR2(N)
        else:
            self.gspace = gspaces.TrivialOnR2()

    # level of [R]estriction:
    #   r = 0: never do restriction, i.e. initial group (either DN or CN) preserved for the whole network
    #   r = 1: restrict before the last block, i.e. initial group (either DN or CN) preserved for the first
    #          2 blocks, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the last block
    #   r = 2: restrict after the first block, i.e. initial group (either DN or CN) preserved for the first
    #          block, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the last 2 blocks
    #   r = 3: restrict after each block. Initial group (either DN or CN) preserved for the first
    #          block, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the second block and to 1 rotation
    #          in the last one (D1 or C1)
    assert r in [0, 1, 2, 3]
    self._r = r

    # The input channels (RGB by default) are trivial fields and don't transform
    # when the input is rotated or flipped.
    r1 = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * in_channels)

    # input field type of the model
    self.in_type = r1

    # in the first layer we always scale up the output channels to allow for enough independent filters
    r2 = FIELD_TYPE["regular"](self.gspace, nStages[0], fixparams=True)

    # dummy attribute keeping track of the output field type of the last submodule built,
    # i.e. the input field type of the next submodule to build
    self._in_type = r2

    self.conv1 = conv5x5(r1, r2)
    self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=initial_stride)
    # Subgroup ids are (flip axis, rotations) when reflections are kept, else just rotations.
    if self._r >= 2:
        N_new = N // 2
        subgroup_id = (0, N_new) if self._f else N_new
        self.restrict1 = self._restrict_layer(subgroup_id)
    else:
        self.restrict1 = lambda x: x

    self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
    if self._r == 3:
        subgroup_id = (0, 1) if self._f else 1
        self.restrict2 = self._restrict_layer(subgroup_id)
    elif self._r == 1:
        N_new = N // 2
        subgroup_id = (0, N_new) if self._f else N_new
        self.restrict2 = self._restrict_layer(subgroup_id)
    else:
        self.restrict2 = lambda x: x

    # last layer maps to a trivial (invariant) feature map
    self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2, totrivial=True)

    self.bn = enn.InnerBatchNorm(self.layer3.out_type, momentum=0.9)
    self.relu = enn.ReLU(self.bn.out_type, inplace=True)
    self.linear = torch.nn.Linear(self.bn.out_type.size, num_classes)

    for name, module in self.named_modules():
        if isinstance(module, enn.R2Conv):
            if deltaorth:
                init.deltaorthonormal_init(module.weights, module.basisexpansion)
        elif isinstance(module, torch.nn.BatchNorm2d):
            module.weight.data.fill_(1)
            module.bias.data.zero_()
        elif isinstance(module, torch.nn.Linear):
            module.bias.data.zero_()

    print("MODEL TOPOLOGY:")
    for i, (name, mod) in enumerate(self.named_modules()):
        print(f"\t{i} - {name}")
def __init__(self, n_classes=10):
    """Build a C8-equivariant CNN for single-channel images, followed by an MLP classifier.

    The network has six ``R2Conv -> InnerBatchNorm -> ReLU`` blocks on regular C8 fields,
    with anti-aliased average pooling after blocks 2 and 4. These are followed by average
    pooling, group pooling (invariant to the C8 action), and a fully connected head.

    Args:
        n_classes: number of logits produced by the final linear layer.
    """
    super(C8SteerableCNN, self).__init__()

    # the model is equivariant under rotations by 45 degrees, modelled by C8
    self.r2_act = gspaces.Rot2dOnR2(N=8)

    def regular_fields(count):
        # ``count`` feature fields, each transforming under the regular representation of C8
        return nn.FieldType(self.r2_act, count * [self.r2_act.regular_repr])

    def conv_block(in_type, out_type, kernel_size, padding):
        # bias-free equivariant convolution -> batch norm per field -> in-place ReLU
        return nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=kernel_size, padding=padding, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True))

    # the input image is a scalar field, corresponding to the trivial representation
    in_type = nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
    # stored for wrapping the images into a geometric tensor during the forward pass
    self.input_type = in_type

    # block 1: 1 scalar field -> 24 regular fields. A 7x7 kernel with padding 1 shrinks
    # each spatial side by 4 pixels.
    self.block1 = conv_block(in_type, regular_fields(24), kernel_size=7, padding=1)
    # block 2: 24 -> 48 regular fields (size-preserving 5x5)
    self.block2 = conv_block(self.block1.out_type, regular_fields(48), kernel_size=5, padding=2)
    self.pool1 = nn.SequentialModule(
        nn.PointwiseAvgPoolAntialiased(self.block2.out_type, sigma=0.66, stride=2))

    # block 3: 48 -> 48 regular fields
    self.block3 = conv_block(self.block2.out_type, regular_fields(48), kernel_size=5, padding=2)
    # block 4: 48 -> 96 regular fields
    self.block4 = conv_block(self.block3.out_type, regular_fields(96), kernel_size=5, padding=2)
    self.pool2 = nn.SequentialModule(
        nn.PointwiseAvgPoolAntialiased(self.block4.out_type, sigma=0.66, stride=2))

    # block 5: 96 -> 96 regular fields
    self.block5 = conv_block(self.block4.out_type, regular_fields(96), kernel_size=5, padding=2)
    # block 6: 96 -> 64 regular fields. Padding 1 shrinks each spatial side by 2 pixels.
    self.block6 = conv_block(self.block5.out_type, regular_fields(64), kernel_size=5, padding=1)

    self.pool3 = nn.PointwiseAvgPool(self.block6.out_type, kernel_size=4)
    # max over the 8 channels of each regular field -> one invariant channel per field
    self.gpool = nn.GroupPooling(self.block6.out_type)

    # number of output channels
    c = self.gpool.out_type.size

    # Fully Connected
    self.fully_net = torch.nn.Sequential(
        torch.nn.Linear(c, 64),
        torch.nn.BatchNorm1d(64),
        torch.nn.ELU(inplace=True),
        torch.nn.Linear(64, n_classes),
    )
def restrict(self, id: Tuple[Union[None, float, int], int]) -> Tuple[gspaces.GSpace, Callable, Callable]:
    r"""
    Build the :class:`~e2cnn.group.GSpace` of the subgroup of the fiber group selected by ``id``.

    ``id`` is a pair :math:`(k, M)`. :math:`M` is the number of discrete rotations kept in the
    subgroup. :math:`k` is either ``None``, which drops all reflections, or selects the reflection
    axis of the subgroup.

    If the fiber group is :math:`D_N` (:class:`~e2cnn.group.DihedralGroup`), :math:`M` must divide
    :math:`N` and :math:`k` must be an integer in :math:`\{0, \dots, \frac{N}{M}-1\}`.
    Otherwise, :math:`M` can be any positive integer and :math:`k` must be a real number in
    :math:`[0, \frac{2\pi}{M}]`.

    Accepted combinations:

    - (``None``, :math:`1`): keep neither reflections nor rotations
    - (``None``, :math:`M`): keep only the :math:`M` rotations generated by :math:`r_{2\pi/M}`
    - (``None``, :math:`-1`): keep all continuous rotations (fiber group :math:`O(2)`)
    - (:math:`0`, :math:`1`): keep only the reflection :math:`\langle f \rangle` about the current axis
    - (:math:`0`, :math:`M`): keep reflections and :math:`M` rotations, generated by
      :math:`r_{2\pi/M}` and :math:`f`
    - (:math:`k`, :math:`M`) for :math:`D_N`: reflections :math:`\langle r_{k\frac{2\pi}{N}} f \rangle`
      about the current axis rotated by :math:`k\frac{\pi}{N}`, plus :math:`M` rotations generated
      by :math:`r_{2\pi/M}`
    - (:math:`\theta`, :math:`M`) for :math:`O(2)`: reflections :math:`\langle r_{\theta} f \rangle`
      about the current axis rotated by :math:`\frac{\theta}{2}`, plus :math:`M` rotations
      generated by :math:`r_{2\pi/M}`

    Args:
        id (tuple): the id of the subgroup

    Returns:
        a tuple ``(gspace, back_map, subgroup_map)``:

        - **gspace**: the restricted G-space
        - **back_map**: maps an element of the subgroup to the same element of the original fiber group
        - **subgroup_map**: maps an element of the original fiber group to the same element of
          the subgroup (``None`` if the element is not in the subgroup)

    Raises:
        ValueError: if ``id`` matches none of the combinations above.
    """
    subgroup, mapping, child = self.fibergroup.subgroup(id)
    flip_id = id[0]
    rotations = id[1]

    if flip_id is None:
        # No reflection survives: trivial group or a pure rotation group.
        if rotations == 1:
            return gspaces.TrivialOnR2(fibergroup=subgroup), mapping, child
        if rotations > 1 or rotations == -1:
            return gspaces.Rot2dOnR2(fibergroup=subgroup), mapping, child
        raise ValueError(f"id {id} not recognized!")

    # The subgroup is generated by r_theta f, which is a flip about the current axis
    # rotated by theta / 2. For finite groups flip_id indexes multiples of 2*pi/N.
    # Otherwise (O(2)) flip_id is the angle theta itself.
    if self.fibergroup.order() > 1:
        theta = flip_id * 2.0 * np.pi / self.fibergroup.rotation_order
    else:
        theta = flip_id
    new_axis = (self.axis + 0.5 * theta) % (2 * np.pi)

    if rotations == 1:
        return gspaces.Flip2dOnR2(fibergroup=subgroup, axis=new_axis), mapping, child
    return gspaces.FlipRot2dOnR2(fibergroup=subgroup, axis=new_axis), mapping, child
def __init__(self, hidden_reps_ids, kernel_sizes, dim_cov_est, context_rep_ids=[1], N=4, flip=False, non_linearity=["NormReLU"], max_frequency=30): ''' Input: hidden_reps_ids - list: encoding the hidden fiber representation (see give_fib_reps_from_ids) kernel_sizes - list of ints - sizes of kernels for convolutional layers dim_cov_est - dimension of covariance estimation, either 1,2,3 or 4 context_rep_ids - list: gives the input fiber representation (see give_fib_reps_from_ids) non_linearity - list of strings - gives names of non-linearity to be used Either length 1 (then same non-linearity for all) or length is the number of layers (giving a custom non-linearity for every layer) N - int - gives the group order, -1 is infinite flip - Bool - indicates whether we have a flip in the rotation group (i.e.O(2) vs SO(2), D_N vs C_N) max_frequency - int - maximum irrep frequency to computed, only relevant if N=-1 ''' super(SteerDecoder, self).__init__() #Save the rotation group, if flip is true, then include all corresponding reflections: self.flip = flip self.max_frequency = max_frequency if self.flip: self.G_act = gspaces.FlipRot2dOnR2( N=N) if N != -1 else gspaces.FlipRot2dOnR2( N=N, maximum_frequency=self.max_frequency) #The output fiber representation is the identity: self.target_rep = self.G_act.irrep(1, 1) else: self.G_act = gspaces.Rot2dOnR2( N=N) if N != -1 else gspaces.Rot2dOnR2( N=N, maximum_frequency=self.max_frequency) #The output fiber representation is the identity: self.target_rep = self.G_act.irrep(1) #Save the N defining D_N or C_N (if N=-1 it is infinity): self.polygon_corners = N #Save the id's for the context representation and extract the context fiber representation: self.context_rep_ids = context_rep_ids self.context_rep = group.directsum( self.give_reps_from_ids(self.context_rep_ids)) #Save the parameters: self.kernel_sizes = kernel_sizes self.n_layers = len(hidden_reps_ids) + 2 self.hidden_reps_ids = hidden_reps_ids self.dim_cov_est = dim_cov_est 
#-----CREATE LIST OF NON-LINEARITIES---- if len(non_linearity) == 1: self.non_linearity = (self.n_layers - 2) * non_linearity elif len(non_linearity) != (self.n_layers - 2): sys.exit( "List of non-linearities invalid: must have either length 1 or n_layers-2" ) else: self.non_linearity = non_linearity #-----ENDE LIST OF NON-LINEARITIES---- #-----------CREATE DECODER----------------- ''' Create a list of layers based on the kernel sizes. Compute the padding such that the height h and width w of a tensor with shape (batch_size,n_channels,h,w) does not change while being passed through the decoder ''' #Create list of feature types: feat_types = self.give_feat_types() self.feature_emb = feat_types[0] self.feature_out = feat_types[-1] #Create layers list and append it: layers_list = [ G_CNN.R2Conv(feat_types[0], feat_types[1], kernel_size=kernel_sizes[0], padding=(kernel_sizes[0] - 1) // 2) ] for it in range(self.n_layers - 2): if self.non_linearity[it] == "ReLU": layers_list.append(G_CNN.ReLU(feat_types[it + 1], inplace=True)) elif self.non_linearity[it] == "NormReLU": layers_list.append(G_CNN.NormNonLinearity(feat_types[it + 1])) else: sys.exit("Unknown non-linearity.") layers_list.append( G_CNN.R2Conv(feat_types[it + 1], feat_types[it + 2], kernel_size=kernel_sizes[it], padding=(kernel_sizes[it] - 1) // 2)) #Create a steerable decoder out of the layers list: self.decoder = G_CNN.SequentialModule(*layers_list) #-----------END CREATE DECODER--------------- #-----------CONTROL INPUTS------------------ #Control that all kernel sizes are odd (otherwise output shape is not correct): if any([j % 2 - 1 for j in kernel_sizes]): sys.exit("All kernels need to have odd sizes") if len(kernel_sizes) != (self.n_layers - 1): sys.exit("Number of layers and number kernels do not match.") if len(self.non_linearity) != (self.n_layers - 2): sys.exit( "Number of layers and number of non-linearities do not match.")