Ejemplo n.º 1
0
def build_steer_cnn_decoder(
    context_rep_ids,
    hidden_reps_ids,
    kernel_sizes,
    mean_rep_ids,
    covariance_activation="quadratic",
    N=4,
    flip=True,
    max_frequency=30,
    activation="relu",
    padding_mode="zeros",
):
    """Build a steerable 2D CNN decoder producing mean and pre-covariance fields.

    Args:
        context_rep_ids: representation ids of the context fields; an extra
            trivial (scalar) field is prepended to them to form the input type.
        hidden_reps_ids: one list of representation ids per hidden layer.
        kernel_sizes: kernel sizes forwarded to ``build_steer_cnn_2d``.
        mean_rep_ids: representation ids of the mean output fields.
        covariance_activation: forwarded to ``get_pre_covariance_field_type``
            to derive the extra output fields that parameterize the covariance.
        N: number of discrete rotations; ``-1`` selects the continuous group,
            in which case ``max_frequency`` is used as the band limit.
        flip: if True, also include reflections (``FlipRot2dOnR2``).
        max_frequency: maximum irrep frequency, only used when ``N == -1``.
        activation: forwarded to ``build_steer_cnn_2d``.
        padding_mode: forwarded to ``build_steer_cnn_2d``.

    Returns:
        Whatever ``build_steer_cnn_2d`` returns for the assembled field types.
    """
    # The continuous group (N == -1) additionally needs an explicit band limit.
    group_kwargs = {"N": N}
    if N == -1:
        group_kwargs["maximum_frequency"] = max_frequency
    gspace_cls = gspaces.FlipRot2dOnR2 if flip else gspaces.Rot2dOnR2
    gspace = gspace_cls(**group_kwargs)

    in_field_type = gnn.FieldType(
        gspace, [gspace.trivial_repr, *reps_from_ids(gspace, context_rep_ids)])

    hidden_field_types = [
        gnn.FieldType(gspace, reps_from_ids(gspace, ids))
        for ids in hidden_reps_ids
    ]

    mean_field_type = gnn.FieldType(gspace,
                                    reps_from_ids(gspace, mean_rep_ids))

    pre_covariance_field_type = get_pre_covariance_field_type(
        gspace, mean_field_type, covariance_activation)

    # The network outputs the mean fields followed by the pre-covariance fields.
    out_field_type = mean_field_type + pre_covariance_field_type

    # Initialisation scale: 0.833 only for the purely scalar (id [[0]]) setup
    # on a discrete group, 1.0 otherwise. Previously a trailing comma turned
    # this into a 1-tuple instead of a float.
    if N == -1 or not (mean_rep_ids == [[0]] and context_rep_ids == [[0]]):
        init_modify = 1.0
    else:
        init_modify = 0.833
    # NOTE(review): init_modify is computed but not passed to
    # build_steer_cnn_2d; confirm whether that helper accepts an
    # initialisation-scale argument and wire it through if so.

    return build_steer_cnn_2d(
        in_field_type,
        hidden_field_types,
        kernel_sizes,
        out_field_type,
        gspace,
        activation,
        padding_mode,
    )
Ejemplo n.º 2
0
 def __init__(self, input_channels, output_channels, kernel_size, stride, N, activation = True, deconv = False, last_deconv = False):
     """Single rotation-equivariant (C_N) convolution, optionally with BN + ReLU.

     Args:
         input_channels: number of regular-representation input fields.
         output_channels: number of output fields (regular, or irrep(1) when
             ``deconv`` and ``last_deconv`` are both set).
         kernel_size: convolution kernel size.
         stride: convolution stride.
         N: number of discrete rotations of the group.
         activation: when not ``deconv``, append InnerBatchNorm + ReLU.
             Ignored in the ``deconv`` branch.
         deconv: use zero padding; BN + ReLU are always appended unless
             ``last_deconv`` is set.
         last_deconv: with ``deconv``, emit a bare convolution to irrep(1)
             fields.
     """
     super(rot_conv2d, self).__init__()
     r2_act = gspaces.Rot2dOnR2(N = N)
     regular = r2_act.regular_repr
     type_in = nn.FieldType(r2_act, input_channels * [regular])
     type_out = nn.FieldType(r2_act, output_channels * [regular])

     if deconv and last_deconv:
         # Final layer: map to irrep(1) fields with no normalisation/activation.
         vector_out = nn.FieldType(r2_act, output_channels * [r2_act.irrep(1)])
         self.layer = nn.R2Conv(type_in, vector_out, kernel_size=kernel_size,
                                stride=stride, padding=0)
     else:
         pad = 0 if deconv else (kernel_size - 1) // 2
         conv = nn.R2Conv(type_in, type_out, kernel_size=kernel_size,
                          stride=stride, padding=pad)
         if deconv or activation:
             self.layer = nn.SequentialModule(
                 conv,
                 nn.InnerBatchNorm(type_out),
                 nn.ReLU(type_out),
             )
         else:
             self.layer = conv
Ejemplo n.º 3
0
    def restrict(self, id: int) -> Tuple[gspaces.GSpace, Callable, Callable]:
        r"""

        Build the :class:`~e2cnn.gspaces.GSpace` associated with the subgroup of the current fiber group identified by
        the input ``id``.

        ``id`` is a positive integer :math:`M` indicating the number of rotations in the subgroup.
        If the current fiber group is :math:`C_N` (:class:`~e2cnn.group.CyclicGroup`), then :math:`M` needs to divide
        :math:`N`. Otherwise, :math:`M` can be any positive integer.

        Args:
            id (int): the number :math:`M` of rotations in the subgroup

        Returns:
            a tuple containing

                - **gspace**: the restricted gspace (a trivial gspace when :math:`M = 1`)

                - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

                - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

        Raises:
            ValueError: if ``id`` is not a positive integer.

        """
        # Validate before querying the fiber group so an invalid id raises the
        # documented ValueError instead of whatever subgroup() fails with.
        if id < 1:
            raise ValueError(f"id {id} not recognized!")

        subgroup, mapping, child = self.fibergroup.subgroup(id)

        if id == 1:
            # A single rotation is the trivial group.
            return gspaces.TrivialOnR2(fibergroup=subgroup), mapping, child
        return gspaces.Rot2dOnR2(fibergroup=subgroup), mapping, child
Ejemplo n.º 4
0
 def __init__(self,
              input_channels,
              hidden_dim,
              kernel_size,
              N  # Group size
              ):
     """Residual block of two C_N-equivariant conv/BN/ReLU stages.

     Args:
         input_channels: number of regular-representation input fields.
         hidden_dim: number of regular-representation fields after each stage.
         kernel_size: kernel size of every convolution ("same" padding for odd sizes).
         N: number of discrete rotations of the group.

     ``upscale`` maps the input fields to ``hidden_dim`` fields; presumably
     it is used for the skip connection in ``forward`` (not shown).
     """
     super(rot_resblock, self).__init__()

     group_act = gspaces.Rot2dOnR2(N = N)
     in_type = nn.FieldType(group_act, input_channels * [group_act.regular_repr])
     hid_type = nn.FieldType(group_act, hidden_dim * [group_act.regular_repr])
     pad = (kernel_size - 1) // 2

     def conv_bn_relu(src_type):
         # conv -> inner batch norm -> ReLU, always producing hid_type fields.
         return nn.SequentialModule(
             nn.R2Conv(src_type, hid_type, kernel_size=kernel_size, padding=pad),
             nn.InnerBatchNorm(hid_type),
             nn.ReLU(hid_type),
         )

     self.layer1 = conv_bn_relu(in_type)
     self.layer2 = conv_bn_relu(hid_type)
     self.upscale = conv_bn_relu(in_type)

     self.input_channels = input_channels
     self.hidden_dim = hidden_dim
Ejemplo n.º 5
0
    def __init__(self, input_frames, output_frames, kernel_size, N):
        """C_N-equivariant U-Net over irrep(1) (vector) input/output fields.

        Args:
            input_frames: number of irrep(1) input fields.
            output_frames: number of irrep(1) output fields.
            kernel_size: kernel size used by the first and last convolutions
                and by the encoder blocks.
            N: number of discrete rotations of the group.
        """
        super(Unet_Rot, self).__init__()
        r2_act = gspaces.Rot2dOnR2(N = N)
        vector = r2_act.irrep(1)
        self.feat_type_in = nn.FieldType(r2_act, input_frames * [vector])
        self.feat_type_in_hid = nn.FieldType(r2_act, 32 * [r2_act.regular_repr])
        # The last decoder stage emits 16 irrep(1) fields; the output layer's
        # input type adds input_frames more (presumably a skip concat with the input).
        self.feat_type_hid_out = nn.FieldType(r2_act, (16 + input_frames) * [vector])
        self.feat_type_out = nn.FieldType(r2_act, output_frames * [vector])
        padding = (kernel_size - 1) // 2

        self.conv1 = nn.SequentialModule(
            nn.R2Conv(self.feat_type_in, self.feat_type_in_hid,
                      kernel_size=kernel_size, stride=2, padding=padding),
            nn.InnerBatchNorm(self.feat_type_in_hid),
            nn.ReLU(self.feat_type_in_hid),
        )

        # Encoder: (attribute name, input fields, output fields, stride)
        encoder = (
            ("conv2", 32, 64, 1),
            ("conv2_1", 64, 64, 1),
            ("conv3", 64, 128, 2),
            ("conv3_1", 128, 128, 1),
            ("conv4", 128, 256, 2),
            ("conv4_1", 256, 256, 1),
        )
        for attr, fields_in, fields_out, step in encoder:
            setattr(self, attr, rot_conv2d(fields_in, fields_out,
                                           kernel_size=kernel_size, stride=step, N=N))

        self.deconv3 = rot_deconv2d(256, 64, N)
        self.deconv2 = rot_deconv2d(192, 32, N)
        self.deconv1 = rot_deconv2d(96, 16, N, last_deconv=True)

        self.output_layer = nn.R2Conv(self.feat_type_hid_out, self.feat_type_out,
                                      kernel_size=kernel_size, padding=padding)
Ejemplo n.º 6
0
def get_model(hparams):
    """Instantiate the model selected by ``hparams['model']``.

    Supported names are ``'discrete'``, ``'standard'``, ``'harmonic'`` and
    ``'steerable'``; each reads its own set of keys from ``hparams``.

    Raises:
        KeyError: if ``hparams`` lacks ``'model'`` or a key the chosen model needs.
        ValueError: if ``hparams['model']`` is not a supported name.
    """
    def _discrete():
        return C4UNet(hparams['in_channels'], hparams['out_channels'],
                      group_type=hparams['group'],
                      N=hparams['N'],
                      n_features=hparams['n_features'],
                      loss_type=hparams['loss_func'],
                      lr=hparams['lr'])

    def _standard():
        return UNet(hparams['in_channels'], hparams['out_channels'],
                    n_features=hparams['n_features'],
                    loss_type=hparams['loss_func'],
                    lr=hparams['lr'])

    def _harmonic():
        return HarmonicUNet(hparams['in_channels'],
                            hparams['out_channels'],
                            n_features=hparams['n_features'],
                            group_type=hparams['group'],
                            max_freq=hparams['max_freq'],
                            loss_type=hparams['loss_func'],
                            lr=hparams['lr'])

    def _steerable():
        # Continuous rotation group (N=-1), band-limited by max_freq.
        space = gspaces.Rot2dOnR2(-1, maximum_frequency=hparams['max_freq'])
        return SteerableCNN(space,
                            hparams['in_channels'],
                            hparams['out_channels'],
                            n_blocks=hparams['n_blocks'],
                            n_features=hparams['n_features'],
                            irrep_type=hparams['irrep_type'],
                            loss_type=hparams['loss_func'],
                            lr=hparams['lr'])

    builders = (
        ('discrete', _discrete),
        ('standard', _standard),
        ('harmonic', _harmonic),
        ('steerable', _steerable),
    )
    requested = hparams['model']
    for name, build in builders:
        if requested == name:
            return build()
    raise ValueError(f'Unsupported model type: {hparams["model"]}')
Ejemplo n.º 7
0
    def __init__(self, channel_in=3, n_classes=4, rot_n=4):
        """Small C_{rot_n}-equivariant CNN with a group-pooled linear head.

        Args:
            channel_in: number of scalar (trivial) input channels.
            n_classes: number of output logits.
            rot_n: number of discrete rotations of the group.
        """
        super(SmallE2, self).__init__()

        act = gspaces.Rot2dOnR2(N=rot_n)

        self.feat_type_in = nn.FieldType(act, channel_in * [act.trivial_repr])
        hidden = nn.FieldType(act, 8 * [act.regular_repr])
        final = nn.FieldType(act, 2 * [act.regular_repr])

        # BN and ReLU are defined on the hidden type only.
        self.bn = nn.InnerBatchNorm(hidden)
        self.relu = nn.ReLU(hidden)

        self.convin = nn.R2Conv(self.feat_type_in, hidden, kernel_size=3)
        self.convhid = nn.R2Conv(hidden, hidden, kernel_size=3)
        self.convout = nn.R2Conv(hidden, final, kernel_size=3)

        self.avgpool = nn.PointwiseAvgPool(final, 3)
        self.invariant_map = nn.GroupPooling(final)

        # Width of the rotation-invariant features after group pooling.
        n_invariant = self.invariant_map.out_type.size

        self.lin_in = torch.nn.Linear(n_invariant, 64)
        self.elu = torch.nn.ELU()
        self.lin_out = torch.nn.Linear(64, n_classes)
Ejemplo n.º 8
0
    def __init__(self):
        """Rotation-invariant (C8) VGG-style dense feature extractor.

        Ten equivariant conv layers over regular field types, followed by
        group pooling so the output features are rotation invariant.
        """
        super(DenseFeatureExtractionModuleE2Inv, self).__init__()

        # Per-layer widths (doubled); only the first ten entries are used.
        filters = np.array([32,32, 64,64, 128,128,128, 256,256,256, 512,512,512], dtype=np.int32)*2

        # Rotation invariance is with respect to N = 8 discrete rotations.
        N = 8

        self.gspace = gspaces.Rot2dOnR2(N)
        self.input_type = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)

        self.num_channels = 64

        field_types = [self.input_type] + [
            FIELD_TYPE['regular'](self.gspace, width, fixparams=False)
            for width in filters[:10]
        ]

        # (conv builder, number of conv+ReLU pairs, pooling appended after the stage)
        stages = (
            (conv3x3, 2, lambda t: enn.PointwiseMaxPool(t, 2)),
            (conv3x3, 2, lambda t: enn.PointwiseMaxPool(t, 2)),
            (conv3x3, 3, lambda t: enn.PointwiseAvgPool(t, kernel_size=2, stride=1)),
            (conv5x5, 3, None),
        )

        layers = []
        pos = 0
        for make_conv, n_convs, make_pool in stages:
            for _ in range(n_convs):
                layers.append(make_conv(field_types[pos], field_types[pos + 1]))
                layers.append(enn.ReLU(field_types[pos + 1], inplace=True))
                pos += 1
            if make_pool is not None:
                layers.append(make_pool(field_types[pos]))

        # Pool over the rotation group to obtain invariant features.
        layers.append(enn.GroupPooling(field_types[pos]))

        self.model = enn.SequentialModule(*layers)
Ejemplo n.º 9
0
    def __init__(self, n_classes=6):
        """Build a C4-equivariant two-block CNN followed by a fully connected classifier.

        Args:
            n_classes: number of output logits of the final linear layer.
        """
        super(SteerCNN, self).__init__()

        # The model is equivariant under rotations by multiples of 90 degrees,
        # modelled by the cyclic group C4 (N=4).
        self.r2_act = gspaces.Rot2dOnR2(N=4)

        # The input image has 3 channels, each a scalar field (trivial representation).
        input_type = nn_e2.FieldType(self.r2_act,
                                     3 * [self.r2_act.trivial_repr])

        # Stored so raw images can be wrapped into a GeometricTensor
        # (presumably in forward(), which is not shown here).
        self.input_type = input_type

        # Convolution 1: 24 feature fields, each transforming under the regular
        # representation of C4.
        out_type = nn_e2.FieldType(self.r2_act,
                                   24 * [self.r2_act.regular_repr])
        self.block1 = nn_e2.SequentialModule(
            nn_e2.R2Conv(input_type,
                         out_type,
                         kernel_size=7,
                         padding=3,
                         bias=False), nn_e2.InnerBatchNorm(out_type),
            nn_e2.ReLU(out_type, inplace=True))

        self.pool1 = nn_e2.PointwiseAvgPool(out_type, 4)

        # Convolution 2: block1's output type is the input type, and the output
        # type is reused, i.e. again 24 regular fields of C4.
        in_type = self.block1.out_type
        self.block2 = nn_e2.SequentialModule(
            nn_e2.R2Conv(in_type,
                         out_type,
                         kernel_size=7,
                         padding=3,
                         bias=False), nn_e2.InnerBatchNorm(out_type),
            nn_e2.ReLU(out_type, inplace=True))
        # Anti-aliased smoothing, spatial average pooling, then pooling over the
        # group so the resulting channels are rotation invariant.
        self.pool2 = nn_e2.SequentialModule(
            nn_e2.PointwiseAvgPoolAntialiased(out_type,
                                              sigma=0.66,
                                              stride=1,
                                              padding=0),
            nn_e2.PointwiseAvgPool(out_type, 4), nn_e2.GroupPooling(out_type))

        # Classifier input size: invariant channels after group pooling times a
        # 13x13 spatial map.
        # NOTE(review): the 13x13 map is hard-coded and depends on the input
        # resolution — TODO confirm against the data pipeline.
        c = self.pool2.out_type.size * 13 * 13

        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(c, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )
Ejemplo n.º 10
0
 def get_gspace(kwargs):
     """Consume ``group_type`` and ``N`` from ``kwargs`` and build the gspace.

     Both keys are read first and then deleted from ``kwargs`` (mutated in
     place), so the remaining entries can be forwarded elsewhere.

     Raises:
         KeyError: if either key is missing (``kwargs`` is then left unchanged).
         ValueError: if ``group_type`` is neither ``'circle'`` nor ``'dihedral'``.
     """
     group_type = kwargs['group_type']
     N = kwargs['N']
     for consumed in ('group_type', 'N'):
         del kwargs[consumed]

     if group_type == 'circle':
         return gspaces.Rot2dOnR2(N)
     if group_type == 'dihedral':
         return gspaces.FlipRot2dOnR2(N)
     raise ValueError(
         'Discrete group argument "group" must be one of ["circle", "dihedral"]'
     )
Ejemplo n.º 11
0
 def __init__(self, nclasses=1):
     """C8-equivariant ResNet-50 layout (3/4/6/3 bottleneck blocks).

     Args:
         nclasses: number of outputs of the final linear layer.
     """
     super(ResNet50, self).__init__()
     self.gspace = gspaces.Rot2dOnR2(N=8)

     # Regular field types keyed by nominal width, created in the original order.
     fields = {
         width: FIELD_TYPE["regular"](self.gspace, width, fixparams=False)
         for width in (64, 256, 128, 512, 1024, 2048)
     }

     self.conv1 = enn.R2Conv(FIELD_TYPE["trivial"](self.gspace, 3, fixparams=False),
                             fields[64], kernel_size=7, stride=2, padding=3)
     self.bn1 = enn.InnerBatchNorm(fields[64])
     self.relu1 = enn.ELU(fields[64])
     self.maxpool1 = enn.PointwiseMaxPoolAntialiased(fields[64], kernel_size=2)

     def make_stage(in_width, inner_width, out_width, n_blocks):
         # First block strides by 2 and changes width; the rest keep out_width.
         inner, out = fields[inner_width], fields[out_width]
         blocks = [ResBlock(stride=2, in_type=fields[in_width], inner_type=inner, out_type=out)]
         for _ in range(n_blocks - 1):
             blocks.append(ResBlock(stride=1, in_type=out, inner_type=inner, out_type=out))
         return torch.nn.Sequential(*blocks)

     self.layer1 = make_stage(64, 64, 256, 3)
     self.layer2 = make_stage(256, 128, 512, 4)
     self.layer3 = make_stage(512, 256, 1024, 6)
     self.layer4 = make_stage(1024, 512, 2048, 3)

     self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
     self.fc = torch.nn.Linear(2048, nclasses)
def generate_2d_rot8(out_path, num_tasks=10000):
    """Generate a dataset of random C8-equivariant convolutions and save it.

    A fixed ``R2Conv`` (1 trivial input field -> 3 regular C8 output fields,
    3x3 kernel, no bias) is re-initialised for every task with the
    generalized He initialisation and applied to a batch of 20 random
    1x32x32 inputs.

    Args:
        out_path: destination passed to ``np.savez``.
        num_tasks: number of tasks to generate (default 10000).

    The archive holds ``x`` (inputs), ``y`` (conv outputs) and ``w`` (the
    free weights for each task), each stacked along a new leading task axis.
    """
    r2_act = gspaces.Rot2dOnR2(N=8)
    feat_type_in = gnn.FieldType(r2_act, [r2_act.trivial_repr])
    feat_type_out = gnn.FieldType(r2_act, 3 * [r2_act.regular_repr])
    conv = gnn.R2Conv(feat_type_in, feat_type_out, kernel_size=3, bias=False)
    xs, ys, ws = [], [], []
    # generalized_he_init overwrites the weight Parameter in place, which autograd
    # forbids on a grad-requiring leaf; no gradients are needed here anyway.
    with torch.no_grad():
        for task_idx in range(num_tasks):
            gnn.init.generalized_he_init(conv.weights, conv.basisexpansion)
            inp = gnn.GeometricTensor(torch.randn(20, 1, 32, 32), feat_type_in)
            result = conv(inp).tensor.detach().cpu().numpy()
            xs.append(inp.tensor.detach().cpu().numpy())
            ys.append(result)
            # Copy: on CPU, .numpy() shares storage with the Parameter, which is
            # overwritten in place next iteration; without the copy every entry
            # would end up holding the final weights.
            ws.append(conv.weights.detach().cpu().numpy().copy())
            if task_idx % 100 == 0:
                print(f"Finished generating task {task_idx}")
    xs, ys, ws = np.stack(xs), np.stack(ys), np.stack(ws)
    np.savez(out_path, x=xs, y=ys, w=ws)
Ejemplo n.º 13
0
 def __init__(self):
     """Single dilated C8-equivariant conv followed by group pooling.

     Maps 3 scalar channels to 16 regular fields with a 3x3 kernel,
     dilation 2 and padding 2, then pools over the group.
     """
     super(ModelDilated, self).__init__()
     N = 8
     space = gspaces.Rot2dOnR2(N)
     self.gspace = space
     self.in_type = enn.FieldType(space, 3 * [space.trivial_repr])
     self.out_type = enn.FieldType(space, 16 * [space.regular_repr])
     self.layer = enn.R2Conv(self.in_type, self.out_type, kernel_size=3,
                             stride=1, padding=2, dilation=2, bias=True)
     self.invariant = enn.GroupPooling(self.out_type)
Ejemplo n.º 14
0
    def __init__(self, in_chan, out_chan, imsize, kernel_size=5, N=8):
        """C_N-equivariant LeNet: two conv/ReLU/pool stages, group pooling, 3 FC layers.

        Args:
            in_chan: not used by the visible code; the input type is a single
                scalar field. NOTE(review): confirm this is intended.
            out_chan: number of outputs of the final linear layer.
            imsize: input side length; ``fc1`` assumes the side becomes
                ``imsize // 4`` after the two 2x pooling stages.
            kernel_size: convolution kernel size (cast to int).
            N: number of discrete rotations of the group.
        """
        super(CNSteerableLeNet, self).__init__()
        kernel_size = int(kernel_size)
        pad = kernel_size // 2

        # Spatial side assumed after the two 2x pooling stages.
        z = imsize // 2 // 2

        self.r2_act = gspaces.Rot2dOnR2(N)
        act = self.r2_act

        scalar_in = e2nn.FieldType(act, [act.trivial_repr])
        self.input_type = scalar_in

        stage1 = e2nn.FieldType(act, 6 * [act.regular_repr])
        self.mask = e2nn.MaskModule(scalar_in, imsize, margin=1)
        self.conv1 = e2nn.R2Conv(scalar_in, stage1, kernel_size=kernel_size,
                                 padding=pad, bias=False)
        self.relu1 = e2nn.ReLU(stage1, inplace=True)
        self.pool1 = e2nn.PointwiseMaxPoolAntialiased(stage1, kernel_size=2)

        stage2 = e2nn.FieldType(act, 16 * [act.regular_repr])
        self.conv2 = e2nn.R2Conv(self.pool1.out_type, stage2, kernel_size=kernel_size,
                                 padding=pad, bias=False)
        self.relu2 = e2nn.ReLU(stage2, inplace=True)
        self.pool2 = e2nn.PointwiseMaxPoolAntialiased(stage2, kernel_size=2)

        self.gpool = e2nn.GroupPooling(stage2)

        # 16 invariant channels remain after group pooling.
        self.fc1 = nn.Linear(16 * z * z, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_chan)

        self.drop = nn.Dropout(p=0.5)

        # Empty parameter whose .device reveals where the model lives.
        self.dummy = nn.Parameter(torch.empty(0))
Ejemplo n.º 15
0
 def __init__(self, growth_rate, list_layer, nclasses):
     """C8-equivariant DenseNet with four dense blocks and three transitions.

     Args:
         growth_rate: channels added per layer inside a dense block.
         list_layer: number of layers in each of the four dense blocks.
         nclasses: number of outputs of the final linear classifier.
     """
     super(DenseNet161, self).__init__()

     self.gspace = gspaces.Rot2dOnR2(N=8)

     channels = 2*growth_rate

     self.conv1 = conv7x7(FIELD_TYPE["trivial"](self.gspace, 3, fixparams=False),
                          FIELD_TYPE["regular"](self.gspace, channels, fixparams=False))

     self.pool1 = enn.PointwiseMaxPool(FIELD_TYPE["regular"](self.gspace, channels, fixparams=False),
                                       kernel_size=2, stride=2)

     # Blocks 1-4 each grow the width by n_layers * growth_rate; blocks 1-3
     # are followed by a transition that halves it.
     for stage in range(1, 5):
         n_layers = list_layer[stage - 1]
         setattr(self, f"block{stage}", DenseBlock(channels, growth_rate, self.gspace, n_layers))
         channels = channels + n_layers*growth_rate
         if stage < 4:
             halved = int(channels/2)
             setattr(self, f"trans{stage}", TransitionBlock(channels, halved, self.gspace))
             channels = halved

     self.bn = enn.InnerBatchNorm(FIELD_TYPE["regular"](self.gspace, channels, fixparams=False))
     self.relu = enn.ReLU(FIELD_TYPE["regular"](self.gspace, channels, fixparams=False), inplace=True)
     self.pool2 = torch.nn.AdaptiveAvgPool2d((1, 1))
     self.classifier = torch.nn.Linear(channels, nclasses)
Ejemplo n.º 16
0
 def __init__(self, input_frames, output_frames, kernel_size, N):
     """C_N-equivariant ResNet mapping irrep(1) (vector) fields to irrep(1) fields.

     Args:
         input_frames: number of irrep(1) input fields.
         output_frames: number of irrep(1) output fields.
         kernel_size: kernel size of every convolution ("same" padding for odd sizes).
         N: number of discrete rotations of the group.
     """
     super(ResNet_Rot, self).__init__()
     r2_act = gspaces.Rot2dOnR2(N = N)
     vector = r2_act.irrep(1)
     regular = r2_act.regular_repr
     pad = (kernel_size - 1) // 2

     # rho_1 (irrep(1)) at the boundaries since the data are vector fields;
     # regular representation for the hidden layers.
     self.feat_type_in = nn.FieldType(r2_act, input_frames * [vector])
     self.feat_type_in_hid = nn.FieldType(r2_act, 16 * [regular])
     self.feat_type_hid_out = nn.FieldType(r2_act, 192 * [regular])
     self.feat_type_out = nn.FieldType(r2_act, output_frames * [vector])

     self.input_layer = nn.SequentialModule(
         nn.R2Conv(self.feat_type_in, self.feat_type_in_hid, kernel_size=kernel_size, padding=pad),
         nn.InnerBatchNorm(self.feat_type_in_hid),
         nn.ReLU(self.feat_type_in_hid),
     )

     # Each width step is one widening block followed by one same-width block.
     widths = (16, 32, 64, 128, 192)
     layers = [self.input_layer]
     for narrow, wide in zip(widths[:-1], widths[1:]):
         layers.append(rot_resblock(narrow, wide, kernel_size, N))
         layers.append(rot_resblock(wide, wide, kernel_size, N))
     layers.append(nn.R2Conv(self.feat_type_hid_out, self.feat_type_out,
                             kernel_size=kernel_size, padding=pad))
     self.model = torch.nn.Sequential(*layers)
Ejemplo n.º 17
0
        # pool over the group
        x = self.gpool(x)

        # unwrap the output GeometricTensor
        # (take the Pytorch tensor and discard the associated representation)
        x = x.tensor

        # classify with the final fully connected layers)
        x = self.fully_net(x.reshape(x.shape[0], -1))

        return x


if __name__ == "__main__":
    # Wrap a random batch of 3-channel images as a GeometricTensor over C8.
    c8 = gspaces.Rot2dOnR2(N=8)
    # Each input channel is a scalar field (trivial representation).
    scalar_type = nn.FieldType(c8, 3 * [c8.trivial_repr])

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch = torch.rand(64, 3, 32, 32).to(device)

    y = nn.GeometricTensor(batch, scalar_type)

    # TEST 2: build a C16 conv from 3 scalar fields to 10 regular fields.
    c16 = gspaces.Rot2dOnR2(N=16)
    print(c16.trivial_repr)
    feat_type_in = nn.FieldType(c16, 3 * [c16.trivial_repr])
    feat_type_out = nn.FieldType(c16, 10 * [c16.regular_repr])
    conv = nn.R2Conv(feat_type_in, feat_type_out, kernel_size=5)
Ejemplo n.º 18
0
    def __init__(self,
                 base = 'DNSteerableAGRadGalNet',
                 attention_module='SelfAttention',
                 attention_gates=3,
                 attention_aggregation='ft',
                 n_classes=2,
                 attention_normalisation='sigmoid',
                 quiet=True,
                 number_rotations=8,
                 imsize=150,
                 kernel_size=3,
                 group="D"
                ):
        """Steerable (D_N or C_N) CNN with up to three attention gates.

        Args:
            base, attention_module, quiet: accepted but not used by the visible code.
            attention_gates: number of attention gates (0-3).
            attention_aggregation: one of 'concat', 'mean', 'deep_sup', 'ft'.
            n_classes: number of output classes.
            attention_normalisation: one of 'sigmoid', 'range_norm',
                'std_mean_norm', 'tanh', 'softmax'.
            number_rotations: N, the number of discrete rotations.
            imsize: input side length.
            kernel_size: kernel size of every R2Conv (padding kernel_size//2).
            group: 'D' for a dihedral D_N, 'C' for a cyclic C_N network.
        """
        super(DNSteerableAGRadGalNet, self).__init__()
        aggregation_mode = attention_aggregation
        normalisation = attention_normalisation
        AG = int(attention_gates)
        N = int(number_rotations)
        kernel_size = int(kernel_size)
        imsize = int(imsize)
        n_classes = int(n_classes)
        # NOTE(review): assert-based validation is stripped under `python -O`;
        # kept as asserts so callers expecting AssertionError are unaffected.
        assert aggregation_mode in ['concat', 'mean', 'deep_sup', 'ft'], 'Aggregation mode not recognised. Valid inputs include concat, mean, deep_sup or ft.'
        assert normalisation in ['sigmoid','range_norm','std_mean_norm','tanh','softmax'], 'Normalisation not implemented. Can be any of: sigmoid, range_norm, std_mean_norm, tanh, softmax'
        assert AG in [0,1,2,3], f'Number of Attention Gates applied (AG) must be an integer in range [0,3]. Currently AG={AG}'
        assert group.lower() in ["d","c"], f"group parameter must either be 'D' for DN, or 'C' for CN, steerable networks. (currently {group})."
        filters = [6,16,32,64,128]

        self.attention_out_sizes = []
        self.ag = AG
        self.n_classes = n_classes
        self.filters = filters
        self.aggregation_mode = aggregation_mode

        # Fiber group: dihedral D_N (rotations + reflections) or cyclic C_N.
        if group.lower() == "d":
            self.r2_act = gspaces.FlipRot2dOnR2(N=N)
        else:
            self.r2_act = gspaces.Rot2dOnR2(N=N)
        in_type = e2nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
        self.in_type = in_type

        self.mask = e2nn.MaskModule(in_type, imsize, margin=0)

        def add_conv_unit(tag, src_type, dst_type):
            # Registers conv<tag>, relu<tag>, bnorm<tag> on self, in that order.
            setattr(self, f'conv{tag}', e2nn.R2Conv(src_type, dst_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False))
            setattr(self, f'relu{tag}', e2nn.ReLU(dst_type))
            setattr(self, f'bnorm{tag}', e2nn.InnerBatchNorm(dst_type))

        # Four stages of regular-field convolutions (6, 16, 32, 64 fields);
        # each ends with a 2x2/stride-2 max-pool and a group-pooling module.
        stages = ((1, filters[0], 'abc'),
                  (2, filters[1], 'abc'),
                  (3, filters[2], 'abc'),
                  (4, filters[3], 'ab'))
        for stage, n_fields, units in stages:
            out_type = e2nn.FieldType(self.r2_act, n_fields*[self.r2_act.regular_repr])
            for unit in units:
                add_conv_unit(f'{stage}{unit}', in_type, out_type)
                in_type = out_type
            setattr(self, f'mpool{stage}', e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2))
            setattr(self, f'gpool{stage}', e2nn.GroupPooling(out_type))

        self.flatten = nn.Flatten(1)
        self.dropout = nn.Dropout(p=0.5)

        # Attention gates on progressively shallower stages (32, 16, 6 channels).
        if self.ag >= 1:
            self.attention1 = GridAttentionBlock2D(in_channels=32, gating_channels=64, inter_channels=64, input_size=[imsize//4,imsize//4], normalisation=normalisation)
        if self.ag >= 2:
            self.attention2 = GridAttentionBlock2D(in_channels=16, gating_channels=64, inter_channels=64, input_size=[imsize//2,imsize//2], normalisation=normalisation)
        if self.ag >= 3:
            self.attention3 = GridAttentionBlock2D(in_channels=6, gating_channels=64, inter_channels=64, input_size=[imsize,imsize], normalisation=normalisation)

        # NOTE(review): fc1 hard-codes a 16x5x5 input (channel_size * width * height);
        # confirm how it is used in forward().
        self.fc1 = nn.Linear(16*5*5,256)
        self.fc2 = nn.Linear(256,256)
        self.fc3 = nn.Linear(256, self.n_classes)
        # Empty parameter whose .device reveals where the model lives.
        self.dummy = nn.Parameter(torch.empty(0))

        self.module_order = ['conv1a', 'relu1a', 'bnorm1a', #1->6
                             'conv1b', 'relu1b', 'bnorm1b', #6->6
                             'conv1c', 'relu1c', 'bnorm1c', #6->6
                             'mpool1',
                             'conv2a', 'relu2a', 'bnorm2a', #6->16
                             'conv2b', 'relu2b', 'bnorm2b', #16->16
                             'conv2c', 'relu2c', 'bnorm2c', #16->16
                             'mpool2',
                             'conv3a', 'relu3a', 'bnorm3a', #16->32
                             'conv3b', 'relu3b', 'bnorm3b', #32->32
                             'conv3c', 'relu3c', 'bnorm3c', #32->32
                             'mpool3',
                             'conv4a', 'relu4a', 'bnorm4a', #32->64
                             'conv4b', 'relu4b', 'bnorm4b', #64->64
                             'compatibility_score1',
                             'compatibility_score2']

        #########################
        # Aggregation strategies
        if self.ag != 0:
            self.attention_filter_sizes = [32, 16, 6]
            used_sizes = self.attention_filter_sizes[:self.ag]
            concat_length = sum(used_sizes)
            if aggregation_mode == 'concat':
                self.classifier = nn.Linear(concat_length, self.n_classes)
                self.aggregate = self.aggregation_concat
            else:
                # nn.ModuleList registers each Linear, so they move with the model.
                self.classifiers = nn.ModuleList(
                    nn.Linear(size, self.n_classes) for size in used_sizes)
                if aggregation_mode == 'mean':
                    self.aggregate = self.aggregation_sep
                elif aggregation_mode == 'deep_sup':
                    self.classifier = nn.Linear(concat_length, self.n_classes)
                    self.aggregate = self.aggregation_ds
                elif aggregation_mode == 'ft':
                    self.classifier = nn.Linear(self.n_classes*self.ag, self.n_classes)
                    self.aggregate = self.aggregation_ft
                else:
                    raise NotImplementedError
        else:
            # Four stride-2 max-pools shrink the side to imsize//16 (assuming an
            # odd kernel_size so the padded convs preserve size); presumably the
            # input here is the group-pooled stage-4 output with 64 channels.
            # Previously this hard-coded 150 instead of imsize.
            self.classifier = nn.Linear((imsize//16)**2*filters[3], self.n_classes)
            self.aggregate = lambda x: self.classifier(self.flatten(x))
Ejemplo n.º 19
0
from mnist_rot_dataset import RotatedMNISTDataset, TransformedDataset
from models import C8Backbone3x3, Backbone5x5, ClassificationHead, RegressionHead

from functools import partial

import matplotlib.pyplot as plt

# Seed the torch, CUDA and NumPy RNGs so runs are reproducible.
torch.manual_seed(23)
torch.cuda.manual_seed(23)
np.random.seed(23)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 128

# Symmetry group of the model: D8 (8 rotations + reflections).
# The rotation-only C8 alternative is kept commented out, in the same way as
# the conv_func alternatives below, instead of being built and discarded.
# group = gspaces.Rot2dOnR2(N=8)
group = gspaces.FlipRot2dOnR2(N=8)

conv_func = nn.R2Conv
# conv_func = sscnn.e2cnn.SSConv
# conv_func = sscnn.e2cnn.PlainConv
backbone = Backbone5x5(out_channels=2, conv_func=conv_func, group=group)
head = RegressionHead(backbone.out_type, conv_func)

dataset_func = partial(torchvision.datasets.MNIST,
        transform=transforms.ToTensor())
# dataset_func = partial(torchvision.datasets.CIFAR10,
#         transform=transforms.Compose([
#             transforms.Grayscale(num_output_channels=1),
#             transforms.ToTensor(),
#             transforms.Normalize(mean=[0.5], std=[0.5])
Ejemplo n.º 20
0
def _mnist_style_datasets(dataset_cls, args, **dataset_kwargs):
    """Return ``(train_dataset, test_dataset)`` for an MNIST-style torchvision dataset.

    Both splits are downloaded to ``'.'``, converted with ``ToTensor()`` and
    wrapped in ``TransformedDataset`` using the train/test rotation and
    reflection flags from ``args``. Extra keyword arguments (e.g. EMNIST's
    ``split``) are forwarded to ``dataset_cls``.
    """
    def wrap(is_train, rotate, reflect):
        base = dataset_cls('.',
                           train=is_train,
                           download=True,
                           transform=transforms.ToTensor(),
                           **dataset_kwargs)
        return TransformedDataset(base,
                                  random_rotate=rotate,
                                  random_reflect=reflect)

    return (wrap(True, args.rotate_train, args.reflect_train),
            wrap(False, args.rotate_test, args.reflect_test))


def _cifar_datasets(dataset_cls, args, mean, std):
    """Return ``(train_dataset, test_dataset)`` for CIFAR10 / CIFAR100.

    The train split is augmented with ``RandomCrop(32, padding=4)`` and a random
    horizontal flip. Both splits are normalized with the per-channel ``mean`` and
    ``std``. ``disk_masked`` is enabled when ``args.task == 'regression'``.
    """
    disk_masked = (args.task == 'regression')
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    train_dataset = TransformedDataset(
        dataset_cls('.', train=True, download=True, transform=train_transform),
        random_rotate=args.rotate_train,
        random_reflect=args.reflect_train,
        disk_masked=disk_masked,
    )
    test_dataset = TransformedDataset(
        dataset_cls('.', train=False, download=True, transform=test_transform),
        random_rotate=args.rotate_test,
        random_reflect=args.reflect_test,
        disk_masked=disk_masked,
    )
    return train_dataset, test_dataset


def _stl10_datasets(args):
    """Return ``(train_dataset, test_dataset)`` for STL10.

    The train split is augmented with ``RandomCrop(96, padding=12)`` and a
    random horizontal flip. Neither split is normalized. ``disk_masked`` is
    enabled when ``args.task == 'regression'``.
    """
    disk_masked = (args.task == 'regression')
    train_transform = transforms.Compose([
        transforms.RandomCrop(96, padding=12),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_dataset = TransformedDataset(
        torchvision.datasets.STL10('.', split='train', download=True,
                                   transform=train_transform),
        random_rotate=args.rotate_train,
        random_reflect=args.reflect_train,
        disk_masked=disk_masked,
    )
    test_dataset = TransformedDataset(
        torchvision.datasets.STL10('.', split='test', download=True,
                                   transform=test_transform),
        random_rotate=args.rotate_test,
        random_reflect=args.reflect_test,
        disk_masked=disk_masked,
    )
    return train_dataset, test_dataset


def _build_datasets(args, batch_size):
    """Return ``(train_dataset, test_dataset, in_channels, num_classes, batch_size)``.

    ``batch_size`` is returned unchanged, except for STL10, which uses 12.

    Raises:
        NotImplementedError: if ``args.dataset`` is not supported.
    """
    dataset = args.dataset
    if dataset == 'MNIST':
        train_ds, test_ds = _mnist_style_datasets(torchvision.datasets.MNIST, args)
        return train_ds, test_ds, 1, 10, batch_size
    if dataset == 'KMNIST':
        train_ds, test_ds = _mnist_style_datasets(torchvision.datasets.KMNIST, args)
        return train_ds, test_ds, 1, 10, batch_size
    if dataset == 'EMNIST':
        train_ds, test_ds = _mnist_style_datasets(torchvision.datasets.EMNIST, args,
                                                  split='balanced')
        # NOTE(review): the EMNIST 'balanced' split has 47 classes; 50 outputs
        # leave 3 logits that never see a positive label. Kept as-is so
        # existing checkpoints stay compatible — confirm intent.
        return train_ds, test_ds, 1, 50, batch_size
    if dataset == 'FMNIST':
        train_ds, test_ds = _mnist_style_datasets(torchvision.datasets.FashionMNIST, args)
        return train_ds, test_ds, 1, 10, batch_size
    if dataset == 'CIFAR10':
        train_ds, test_ds = _cifar_datasets(torchvision.datasets.CIFAR10, args,
                                            (0.4914, 0.4822, 0.4465),
                                            (0.2023, 0.1994, 0.2010))
        return train_ds, test_ds, 3, 10, batch_size
    if dataset == 'CIFAR100':
        train_ds, test_ds = _cifar_datasets(torchvision.datasets.CIFAR100, args,
                                            (0.5071, 0.4867, 0.4408),
                                            (0.2675, 0.2565, 0.2761))
        # CIFAR100 images are RGB like CIFAR10, so in_channels must be 3 here;
        # the backbone construction reads it.
        return train_ds, test_ds, 3, 100, batch_size
    if dataset == 'STL10':
        train_ds, test_ds = _stl10_datasets(args)
        # NOTE(review): STL10 has 10 labelled classes; 50 outputs look
        # oversized. Kept as-is — confirm intent.
        return train_ds, test_ds, 3, 50, 12
    raise NotImplementedError


def train(args):
    """Train an equivariant backbone + task head and log periodic test metrics.

    ``args`` is read for: ``reflection`` (1 = rotations only, 2 = rotations and
    reflections), ``rotation`` (number of rotations N), ``conv``, ``dataset``,
    ``backbone``, ``task`` (``'classification'`` or ``'regression'``) and
    ``rotate_train`` / ``reflect_train`` / ``rotate_test`` / ``reflect_test``.

    Every 5 epochs the mean of ``eval_function`` over the test set is printed
    and appended to ``<arg values joined by '_'>.txt``.

    Raises:
        NotImplementedError: for an unsupported ``reflection``, ``conv``,
            ``dataset``, ``backbone`` or ``task`` value.
    """
    torch.manual_seed(23)
    torch.cuda.manual_seed(23)
    np.random.seed(23)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if args.reflection == 1:
        group = gspaces.Rot2dOnR2(N=args.rotation)
    elif args.reflection == 2:
        group = gspaces.FlipRot2dOnR2(N=args.rotation)
    else:
        raise NotImplementedError

    if args.conv == 'R2Conv':
        conv_func = nn.R2Conv
    elif args.conv == 'SSConv':
        conv_func = sscnn.e2cnn.SSConv
    elif args.conv == 'PlainConv':
        conv_func = sscnn.e2cnn.PlainConv
    else:
        raise NotImplementedError

    train_dataset, test_dataset, in_channels, num_classes, batch_size = \
        _build_datasets(args, batch_size=128)

    if args.backbone == 'B5':
        backbone = Backbone5x5(conv_func=conv_func,
                               group=group,
                               in_channels=in_channels)
    elif args.backbone == 'WRN':
        backbone = Wide_ResNet(28,
                               10,
                               0.3,
                               initial_stride=2,
                               N=args.rotation,
                               f=(args.reflection == 2),
                               r=0,
                               conv_func=conv_func,
                               fixparams=False,
                               in_channels=in_channels)
    else:
        raise NotImplementedError

    # loss_function / eval_function take (prediction, label, vector target);
    # classification uses the label, regression the vector target.
    if args.task == 'classification':
        head = ClassificationHead(backbone.out_type, num_classes=num_classes)
        cross_entropy_loss = torch.nn.CrossEntropyLoss()
        loss_function = lambda y, l, v: cross_entropy_loss(y, l)
        eval_function = mask_of_success
    elif args.task == 'regression':
        head = RegressionHead(backbone.out_type, conv_func)
        mse_loss = torch.nn.MSELoss()
        loss_function = lambda y, l, v: mse_loss(y, v)
        eval_function = abs_included_angles
    else:
        raise NotImplementedError

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size)

    model = nn.SequentialModule(
        OrderedDict([('backbone', backbone), ('head', head)]))
    model = model.to(device)

    if args.backbone == 'B5':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=5e-5,
                                     weight_decay=1e-5)
        scheduler = None
        max_epochs = 60
    elif args.backbone == 'WRN':
        base_lr = 2e-3 if args.dataset == 'STL10' else 1e-2
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=base_lr,
                                    momentum=0.9,
                                    weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=45,
                                                    gamma=0.2)
        max_epochs = 260
    else:
        raise NotImplementedError

    # The log file name encodes every argument value of this run.
    file_name = '_'.join(str(value) for value in vars(args).values())
    log_name = file_name + '.txt'

    with open(log_name, 'w') as log_file:
        # NOTE(review): runs max_epochs + 2 epochs; the reason for the extra
        # two is not visible here.
        for epoch in range(max_epochs + 2):
            model.train()

            # CUDA events time the GPU work (elapsed_time() reports ms); on CPU
            # wall-clock seconds from time.time() are printed instead.
            if device == 'cuda':
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
            else:
                start = time.time()

            for x, label, vec in train_loader:
                optimizer.zero_grad()

                x = x.to(device)
                label = label.to(device)
                vec = vec.to(device)

                y = model(nn.GeometricTensor(x, backbone.input_type))
                y = y.tensor.flatten(1, -1)

                loss = loss_function(y, label, vec)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
                optimizer.step()

            if scheduler:
                scheduler.step()
            for param_group in optimizer.param_groups:
                lr = param_group['lr']
                print('lr = ', param_group['lr'])

            if device == 'cuda':
                end.record()
                torch.cuda.synchronize()
                print('Epoch', start.elapsed_time(end))
            else:
                end = time.time()
                print('Epoch', end - start)

            if epoch % 5 == 0:
                errors = []
                with torch.no_grad():
                    model.eval()
                    for x, label, vec in test_loader:
                        x = x.to(device)
                        label = label.to(device)
                        vec = vec.to(device)

                        y = model(nn.GeometricTensor(x, backbone.input_type))
                        y = y.tensor.flatten(1, -1)

                        errors.extend(eval_function(y, label, vec))

                error = np.array(errors).mean()
                print(f"epoch {epoch} | tes : {error}")
                log_file.write(f"epoch {epoch} | acc : {error} | lr : {lr}\n")
                log_file.flush()
    def __init__(self):
        super().__init__()

        # All feature types live on the C8 rotation group acting on the plane.
        self.r2_act = gspaces.Rot2dOnR2(N=8)
        act = self.r2_act

        def regular(multiplicity):
            # FieldType made of `multiplicity` copies of the regular representation.
            return enn.FieldType(act, multiplicity * [act.regular_repr])

        self.field_type_1 = regular(1)
        # The input type: 3 trivial (scalar) fields.
        self.field_type_3 = enn.FieldType(act, 3 * [act.trivial_repr])
        self.field_type_8 = regular(8)
        self.field_type_16 = regular(16)
        self.field_type_32 = regular(32)
        self.field_type_64 = regular(64)
        self.field_type_128 = regular(128)

        # Conv / DownSample chain.
        self.conv1 = Conv(in_type=self.field_type_3, out_type=self.field_type_8)
        self.conv2 = Conv(in_type=self.field_type_8, out_type=self.field_type_8)

        self.down1 = DownSample(in_type=self.field_type_8, out_type=self.field_type_16)
        self.conv12 = Conv(in_type=self.field_type_16, out_type=self.field_type_16)
        self.down2 = DownSample(in_type=self.field_type_16, out_type=self.field_type_32)
        self.conv22 = Conv(in_type=self.field_type_32, out_type=self.field_type_32)
        self.down3 = DownSample(in_type=self.field_type_32, out_type=self.field_type_32)
        self.conv32 = Conv(in_type=self.field_type_32, out_type=self.field_type_32)

        # Three structurally identical UpSample/LastConcat groups, suffixed 1..3.
        # They are created in the order up41, up31, up21, up11, up42, ..., and
        # setattr registers them as submodules exactly like plain assignment,
        # so any RNG use during construction happens in the same sequence.
        for branch in (1, 2, 3):
            setattr(self, f'up4{branch}',
                    UpSample(in_type=self.field_type_64,
                             out_type=self.field_type_32,
                             mid_type=self.field_type_32))
            setattr(self, f'up3{branch}',
                    UpSample(in_type=self.field_type_64,
                             out_type=self.field_type_16,
                             mid_type=self.field_type_32))
            setattr(self, f'up2{branch}',
                    UpSample(in_type=self.field_type_32,
                             out_type=self.field_type_8,
                             mid_type=self.field_type_16))
            setattr(self, f'up1{branch}',
                    LastConcat(in_type=self.field_type_16,
                               out_type=self.field_type_1,
                               mid_type=self.field_type_8))

        for branch in (1, 2, 3):
            setattr(self, f'gpool{branch}', enn.GroupPooling(self.field_type_1))
Ejemplo n.º 22
0
# Set default Orientation=8, i.e., the group C8.
# One can change it by passing the env Orientation=xx
Orientation = 8
# keep similar computation or similar params
# One can change it by passing the env fixparams=True
fixparams = False
if 'Orientation' in os.environ:
    Orientation = int(os.environ['Orientation'])
if 'fixparams' in os.environ:
    # Setting the variable still enables the option (e.g. fixparams=1 or
    # fixparams=True). An explicit false-like value no longer turns it on;
    # before, fixparams=False also enabled it.
    fixparams = os.environ['fixparams'].strip().lower() not in (
        '0', 'false', 'no', 'off')
print('ReResNet Orientation: {}\tFix Params: {}'.format(
    Orientation, fixparams))

# define the equivariant group. We use C8 group by default.
gspace = gspaces.Rot2dOnR2(N=Orientation)


def regular_feature_type(gspace: gspaces.GSpace, planes: int):
    """Return a FieldType of regular fields sized from ``planes`` channels.

    With ``N`` the order of ``gspace``'s fiber group, the number of regular
    fields is ``int(planes / N)``. When the module-level ``fixparams`` flag is
    set, ``planes`` is first scaled by ``sqrt(N)``.
    """
    order = gspace.fibergroup.order()
    # A positive order is required; presumably this rules out continuous
    # groups, whose order() is non-positive — confirm against e2cnn.
    assert order > 0
    channels = planes * math.sqrt(order) if fixparams else planes
    n_fields = int(channels / order)
    return enn.FieldType(gspace, n_fields * [gspace.regular_repr])


def trivial_feature_type(gspace: gspaces.GSpace,
                         planes: int,
    else:
        sys.exit("Unknown architecture type.")
    
    if ARGS['DIV_FREE']:
        print("Used div free kernel in the output")
        CNP=steercnp.SteerCNP(encoder,decoder,ARGS['DIM_COV_EST'],dim_context_feat=2,l_scale=ARGS['LENGTH_SCALE_OUT'],
                            kernel_dict_out={'kernel_type':"div_free"},normalize_output=False)
    else:
        CNP=steercnp.SteerCNP(encoder,decoder,ARGS['DIM_COV_EST'],dim_context_feat=2,l_scale=ARGS['LENGTH_SCALE_OUT'])

# Optional equivariance test: choose the symmetry group G_act and the input
# field type acted on by it. Both stay None when no known group is requested.
G_act, feature_in = None, None
if ARGS['TESTING_GROUP'] == 'D4':
    # Dihedral group with 4 rotations; input transforms under irrep (1, 1).
    G_act = gspaces.FlipRot2dOnR2(N=4)
    feature_in = G_CNN.FieldType(G_act, [G_act.irrep(1, 1)])
elif ARGS['TESTING_GROUP'] == 'C16':
    # Cyclic group with 16 rotations; input transforms under irrep 1.
    G_act = gspaces.Rot2dOnR2(N=16)
    feature_in = G_CNN.FieldType(G_act, [G_act.irrep(1)])


print("Number of parameters: ", my_utils.count_parameters(CNP,print_table=False))

CNP,_,_=training.train_cnp(CNP,
                           train_dataset=train_dataset,
                           val_dataset=val_dataset,
                           data_identifier=data_IDENTIFIER,
                           device=DEVICE,
                           minibatch_size=ARGS['BATCH_SIZE'],
                           n_epochs=ARGS['N_EPOCHS'],
# Activation name -> equivariant activation module class.
ACT_FNS = dict(
    relu=enn.ReLU,
    elu=enn.ELU,
    gated=enn.GatedNonLinearity1,
    swish=base_layers.GeomSwish,
)

# Group name -> gspace. Discrete groups use N rotations; 'so2' / 'o2' are the
# continuous groups (N=-1) truncated at maximum_frequency=10.
# Keys, and the order in which the gspaces are constructed, follow the
# sequence fliprot16..2, flip, rot16..2, so2, o2.
GROUPS = {
    **{f'fliprot{n}': gspaces.FlipRot2dOnR2(N=n) for n in (16, 12, 8, 4, 2)},
    'flip': gspaces.Flip2dOnR2(),
    **{f'rot{n}': gspaces.Rot2dOnR2(N=n) for n in (16, 12, 8, 4, 2)},
    'so2': gspaces.Rot2dOnR2(N=-1, maximum_frequency=10),
    'o2': gspaces.FlipRot2dOnR2(N=-1, maximum_frequency=10),
}

FIBERS = {
    "trivial": trivial_fiber,
    "quotient": quotient_fiber,
    "regular": regular_fiber,
    "irrep": irrep_fiber,
    "mixed1": mixed1_fiber,
    "mixed2": mixed2_fiber,
Ejemplo n.º 25
0
 def __init__(self, input_channels, output_channels, N, last_deconv=False):
     super().__init__()
     # Wrapped rotation-equivariant conv in deconv mode: kernel size 4,
     # stride 1, with activation.
     self.conv2d = rot_conv2d(
         input_channels=input_channels,
         output_channels=output_channels,
         kernel_size=4,
         stride=1,
         N=N,
         activation=True,
         deconv=True,
         last_deconv=last_deconv,
     )
     # Field type of `input_channels` regular fields of the C_N rotation group.
     rotations = gspaces.Rot2dOnR2(N=N)
     self.feat_type = nn.FieldType(rotations, input_channels * [rotations.regular_repr])
Ejemplo n.º 26
0
    def __init__(self, depth, widen_factor, dropout_rate, num_classes=100,
                 N: int = 8,
                 r: int = 1,
                 f: bool = True,
                 deltaorth: bool = False,
                 fixparams: bool = True,
                 initial_stride: int = 1,
                 ):
        r"""
        
        Build an equivariant Wide ResNet.
        
        The parameter ``N`` controls rotation equivariance and the parameter ``f`` reflection equivariance.
        
        More precisely, ``N`` is the number of discrete rotations the model is initially equivariant to.
        ``N = 1`` means the model is only reflection equivariant from the beginning.
        
        ``f`` is a boolean flag specifying whether the model should be reflection equivariant or not.
        If it is ``False``, the model is not reflection equivariant.
        
        ``r`` is the restriction level:
        
        - ``0``: no restriction. The model is equivariant to ``N`` rotations from the input to the output

        - ``1``: restriction before the last block. The model is equivariant to ``N`` rotations before the last block
               (i.e. in the first 2 blocks). Then it is restricted to ``N/2`` rotations until the output.
        
        - ``2``: restriction after the first block. The model is equivariant to ``N`` rotations in the first block.
               Then it is restricted to ``N/2`` rotations until the output (i.e. in the last 3 blocks).
               
        - ``3``: restriction after the first and the second block. The model is equivariant to ``N`` rotations in the first
               block. It is restricted to ``N/2`` rotations before the second block and to ``1`` rotations before the last
               block.
        
        NOTICE: if restriction to ``N/2`` is performed, ``N`` needs to be even!
        
        Other parameters: ``depth`` must be of the form ``6n + 4``; ``widen_factor`` scales the channel
        count of the three wide blocks; ``deltaorth`` applies delta-orthonormal init to every ``R2Conv``;
        ``initial_stride`` is the stride of the first wide block.
        
        """
        super(Wide_ResNet, self).__init__()
        
        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6
        k = widen_factor
        
        print(f'| Wide-Resnet {depth}x{k}')
        
        nStages = [16, 16 * k, 32 * k, 64 * k]
        
        self._fixparams = fixparams
        
        self._layer = 0
        
        # number of discrete rotations to be equivariant to
        self._N = N
        
        # if the model is [F]lip equivariant
        self._f = f
        if self._f:
            self.gspace = gspaces.FlipRot2dOnR2(N) if N != 1 else gspaces.Flip2dOnR2()
        else:
            self.gspace = gspaces.Rot2dOnR2(N) if N != 1 else gspaces.TrivialOnR2()

        # level of [R]estriction:
        #   r = 0: never do restriction, i.e. initial group (either DN or CN) preserved for the whole network
        #   r = 1: restrict before the last block, i.e. initial group (either DN or CN) preserved for the first
        #          2 blocks, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the last block
        #   r = 2: restrict after the first block, i.e. initial group (either DN or CN) preserved for the first
        #          block, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the last 2 blocks
        #   r = 3: restrict after each block. Initial group (either DN or CN) preserved for the first
        #          block, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the second block and to 1 rotation
        #          in the last one (D1 or C1)
        assert r in [0, 1, 2, 3]
        self._r = r

        # Subgroup ids passed to _restrict_layer: for dihedral groups a
        # (reflection axis, rotations) pair, for cyclic groups just the rotations.
        half_rotations_id = (0, N // 2) if self._f else N // 2
        single_rotation_id = (0, 1) if self._f else 1
        
        # the input has 3 color channels (RGB).
        # Color channels are trivial fields and don't transform when the input is rotated or flipped
        r1 = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)
        
        # input field type of the model
        self.in_type = r1
        
        # in the first layer we always scale up the output channels to allow for enough independent filters
        r2 = FIELD_TYPE["regular"](self.gspace, nStages[0], fixparams=True)
        
        # dummy attribute keeping track of the output field type of the last submodule built, i.e. the input field type of
        # the next submodule to build
        self._in_type = r2
        
        # The layers below are built strictly in sequence: each _wide_layer /
        # _restrict_layer call presumably reads and updates self._in_type.
        self.conv1 = conv5x5(r1, r2)
        self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=initial_stride)
        if self._r >= 2:
            self.restrict1 = self._restrict_layer(half_rotations_id)
        else:
            self.restrict1 = lambda x: x
        
        self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
        if self._r == 3:
            self.restrict2 = self._restrict_layer(single_rotation_id)
        elif self._r == 1:
            self.restrict2 = self._restrict_layer(half_rotations_id)
        else:
            self.restrict2 = lambda x: x
        
        # last layer maps to a trivial (invariant) feature map
        self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2, totrivial=True)
        
        self.bn = enn.InnerBatchNorm(self.layer3.out_type, momentum=0.9)
        self.relu = enn.ReLU(self.bn.out_type, inplace=True)
        self.linear = torch.nn.Linear(self.bn.out_type.size, num_classes)
        
        for name, module in self.named_modules():
            if isinstance(module, enn.R2Conv):
                if deltaorth:
                    init.deltaorthonormal_init(module.weights, module.basisexpansion)
            elif isinstance(module, torch.nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, torch.nn.Linear):
                module.bias.data.zero_()
        
        print("MODEL TOPOLOGY:")
        for i, (name, mod) in enumerate(self.named_modules()):
            print(f"\t{i} - {name}")
Ejemplo n.º 27
0
    def __init__(self, n_classes=10):
        """Build a C8-equivariant CNN classifier with ``n_classes`` outputs.

        Six R2Conv/InnerBatchNorm/ReLU blocks over regular C8 fields, with
        anti-aliased average pooling after blocks 2 and 4 and an average pool
        after block 6. Group pooling over the 8 rotations follows, then a
        two-layer fully connected classifier.
        """
        super(C8SteerableCNN, self).__init__()

        # the model is equivariant under rotations by 45 degrees, modelled by C8
        self.r2_act = gspaces.Rot2dOnR2(N=8)

        # the input image is a scalar field, corresponding to the trivial representation
        in_type = nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])

        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = in_type

        # convolution 1
        # first specify the output type of the convolutional layer
        # we choose 24 feature fields, each transforming under the regular representation of C8
        out_type = nn.FieldType(self.r2_act, 24 * [self.r2_act.regular_repr])
        self.block1 = nn.SequentialModule(
            # nn.MaskModule(in_type, 29, margin=1),
            # NOTE(review): kernel 7 with padding 1 shrinks the spatial size by 4;
            # presumably intentional for the expected input size — confirm.
            nn.R2Conv(in_type, out_type, kernel_size=7, padding=1, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True))

        # convolution 2
        # the old output type is the input type to the next layer
        in_type = self.block1.out_type
        # the output type of the second convolution layer are 48 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 48 * [self.r2_act.regular_repr])
        self.block2 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type), nn.ReLU(out_type, inplace=True))
        self.pool1 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2))

        # convolution 3
        # the old output type is the input type to the next layer
        in_type = self.block2.out_type
        # the output type of the third convolution layer are 48 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 48 * [self.r2_act.regular_repr])
        self.block3 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type), nn.ReLU(out_type, inplace=True))

        # convolution 4
        # the old output type is the input type to the next layer
        in_type = self.block3.out_type
        # the output type of the fourth convolution layer are 96 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 96 * [self.r2_act.regular_repr])
        self.block4 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type), nn.ReLU(out_type, inplace=True))
        self.pool2 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2))

        # convolution 5
        # the old output type is the input type to the next layer
        in_type = self.block4.out_type
        # the output type of the fifth convolution layer are 96 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 96 * [self.r2_act.regular_repr])
        self.block5 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.InnerBatchNorm(out_type), nn.ReLU(out_type, inplace=True))

        # convolution 6
        # the old output type is the input type to the next layer
        in_type = self.block5.out_type
        # the output type of the sixth convolution layer are 64 regular feature fields of C8
        out_type = nn.FieldType(self.r2_act, 64 * [self.r2_act.regular_repr])
        self.block6 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=1, bias=False),
            nn.InnerBatchNorm(out_type), nn.ReLU(out_type, inplace=True))
        self.pool3 = nn.PointwiseAvgPool(out_type, kernel_size=4)

        # pool over the 8 rotations of each regular field, leaving one channel per field
        self.gpool = nn.GroupPooling(out_type)

        # number of output channels
        c = self.gpool.out_type.size

        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(c, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )
Ejemplo n.º 28
0
    def restrict(self, id: Tuple[Union[None, float, int], int]) -> Tuple[gspaces.GSpace, Callable, Callable]:
        r"""

        Restrict this G-space to the subgroup of its fiber group selected by ``id``, a pair :math:`(k, M)`.

        :math:`M` is the number of discrete rotations kept in the subgroup (``-1`` meaning all continuous
        rotations). :math:`k` is ``None`` when the subgroup contains no reflections; otherwise it selects the
        reflection axis.

        If the fiber group is :math:`D_N` (:class:`~e2cnn.group.DihedralGroup`), :math:`M` must divide :math:`N`
        and :math:`k` must be an integer in :math:`\{0, \dots, \frac{N}{M}-1\}`. The kept reflection is
        :math:`r_{k\frac{2\pi}{N}} f`, whose axis is the current axis rotated by :math:`k\frac{\pi}{N}`.

        If the fiber group is :math:`O(2)` (:class:`~e2cnn.group.O2`), :math:`M` can be any positive integer (or
        ``-1`` together with :math:`k=` ``None``). In this case :math:`k = \theta` is a real angle in
        :math:`[0, \frac{2\pi}{M}]`. The kept reflection is :math:`r_{\theta} f`, whose axis is the current axis
        rotated by :math:`\frac{\theta}{2}`.

        Resulting G-space, depending on ``id``:

        - (``None``, :math:`1`): trivial subgroup (neither reflections nor rotations) -> ``TrivialOnR2``

        - (``None``, :math:`M`) with :math:`M > 1` or :math:`M = -1`: rotations generated by
          :math:`r_{2\pi/M}` (all rotations for :math:`-1`) -> ``Rot2dOnR2``

        - (:math:`k`, :math:`1`): a single reflection :math:`\langle r f \rangle` -> ``Flip2dOnR2``

        - (:math:`k`, :math:`M`) with any other :math:`M`: the reflection plus the :math:`M` rotations
          -> ``FlipRot2dOnR2``

        For (:math:`0`, ...) the reflection axis is the same as in the current group.

        Args:
            id (tuple): the id of the subgroup

        Returns:
            a tuple containing

                - **gspace**: the restricted gspace

                - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

                - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

        Raises:
            ValueError: if ``id[0]`` is ``None`` and ``id[1]`` is neither ``1``, greater than ``1``, nor ``-1``
            (further validation, if any, is up to ``self.fibergroup.subgroup``).

        """
        subgroup, mapping, child = self.fibergroup.subgroup(id)
        flip_index = id[0]
        n_rotations = id[1]

        # Subgroups without reflections need no flip axis.
        if flip_index is None:
            if n_rotations == 1:
                return gspaces.TrivialOnR2(fibergroup=subgroup), mapping, child
            if n_rotations > 1 or n_rotations == -1:
                return gspaces.Rot2dOnR2(fibergroup=subgroup), mapping, child
            raise ValueError(f"id {id} not recognized!")

        # Subgroups with reflections: the element generating the reflection is r_phi f, which flips around the
        # current axis rotated by phi / 2 (the generator angle is twice the axis offset).
        if self.fibergroup.order() > 1:
            # finite group: flip_index counts multiples of the elementary rotation 2*pi / rotation_order
            phi = flip_index * 2.0 * np.pi / self.fibergroup.rotation_order
        else:
            # otherwise flip_index is used directly as an angle
            phi = flip_index
        new_axis = (self.axis + 0.5 * phi) % (2 * np.pi)

        space_cls = gspaces.Flip2dOnR2 if n_rotations == 1 else gspaces.FlipRot2dOnR2
        return space_cls(fibergroup=subgroup, axis=new_axis), mapping, child
Ejemplo n.º 29
0
    def __init__(self,
                 hidden_reps_ids,
                 kernel_sizes,
                 dim_cov_est,
                 context_rep_ids=None,
                 N=4,
                 flip=False,
                 non_linearity=None,
                 max_frequency=30):
        '''
        Build a steerable CNN decoder that is equivariant under C_N / D_N (SO(2) / O(2) if N=-1).

        Input:  hidden_reps_ids - list: encoding the hidden fiber representation of each hidden layer
                                  (see give_fib_reps_from_ids)
                kernel_sizes - list of odd ints - kernel size of each convolutional layer, in order;
                               must have length len(hidden_reps_ids) + 1 (one entry per convolution)
                dim_cov_est - dimension of covariance estimation, either 1,2,3 or 4
                context_rep_ids - list: gives the input fiber representation (see give_fib_reps_from_ids);
                                  None (default) means [1]
                non_linearity - list of strings - names of the non-linearities ("ReLU" or "NormReLU").
                                    Either length 1 (then same non-linearity for all)
                                    or length len(hidden_reps_ids) (a custom non-linearity for every
                                    hidden layer). None (default) means ["NormReLU"]
                N - int - gives the group order, -1 is infinite
                flip - Bool - indicates whether we have a flip in the rotation group (i.e. O(2) vs SO(2), D_N vs C_N)
                max_frequency - int - maximum irrep frequency to be computed, only relevant if N=-1

        Invalid arguments terminate via sys.exit(...) (i.e. raise SystemExit) with an explanatory message.
        '''
        # Resolve list defaults here rather than in the signature: a list default is created once and
        # shared by every call, and context_rep_ids is stored on self by reference below.
        if context_rep_ids is None:
            context_rep_ids = [1]
        if non_linearity is None:
            non_linearity = ["NormReLU"]

        super(SteerDecoder, self).__init__()
        #Save the rotation group, if flip is true, then include all corresponding reflections:
        self.flip = flip
        self.max_frequency = max_frequency

        if self.flip:
            self.G_act = gspaces.FlipRot2dOnR2(
                N=N) if N != -1 else gspaces.FlipRot2dOnR2(
                    N=N, maximum_frequency=self.max_frequency)
            #The output fiber representation is the group's own action on R^2 (frequency-1 irrep):
            self.target_rep = self.G_act.irrep(1, 1)
        else:
            self.G_act = gspaces.Rot2dOnR2(
                N=N) if N != -1 else gspaces.Rot2dOnR2(
                    N=N, maximum_frequency=self.max_frequency)
            #The output fiber representation is the group's own action on R^2 (frequency-1 irrep):
            self.target_rep = self.G_act.irrep(1)

        #Save the N defining D_N or C_N (if N=-1 it is infinity):
        self.polygon_corners = N

        #Save the id's for the context representation and extract the context fiber representation:
        self.context_rep_ids = context_rep_ids
        self.context_rep = group.directsum(
            self.give_reps_from_ids(self.context_rep_ids))

        #Save the parameters:
        self.kernel_sizes = kernel_sizes
        # input layer + hidden layers + output layer; there are n_layers - 1 convolutions
        # and n_layers - 2 non-linearities (one after every convolution except the last)
        self.n_layers = len(hidden_reps_ids) + 2
        self.hidden_reps_ids = hidden_reps_ids
        self.dim_cov_est = dim_cov_est

        #-----CREATE LIST OF NON-LINEARITIES----
        if len(non_linearity) == 1:
            self.non_linearity = (self.n_layers - 2) * non_linearity
        elif len(non_linearity) != (self.n_layers - 2):
            sys.exit(
                "List of non-linearities invalid: must have either length 1 or n_layers-2"
            )
        else:
            self.non_linearity = non_linearity
        #-----END LIST OF NON-LINEARITIES----

        #-----------CONTROL INPUTS------------------
        # Validate the kernel sizes BEFORE building any layer, so a too-short list produces the
        # message below instead of an IndexError during construction.
        #Control that all kernel sizes are odd (otherwise the padding below does not preserve h and w):
        if any(k % 2 != 1 for k in kernel_sizes):
            sys.exit("All kernels need to have odd sizes")
        if len(kernel_sizes) != (self.n_layers - 1):
            sys.exit("Number of layers and number kernels do not match.")
        # (The non-linearity list length is guaranteed by the block above.)
        #-----------END CONTROL INPUTS--------------

        #-----------CREATE DECODER-----------------
        '''
        Create a list of layers based on the kernel sizes. Compute the padding such
        that the height h and width w of a tensor with shape (batch_size,n_channels,h,w) does not change
        while being passed through the decoder
        '''
        #Create list of feature types:
        feat_types = self.give_feat_types()
        self.feature_emb = feat_types[0]
        self.feature_out = feat_types[-1]
        #Create layers list and append it:
        layers_list = [
            G_CNN.R2Conv(feat_types[0],
                         feat_types[1],
                         kernel_size=kernel_sizes[0],
                         padding=(kernel_sizes[0] - 1) // 2)
        ]
        for it in range(self.n_layers - 2):
            if self.non_linearity[it] == "ReLU":
                layers_list.append(G_CNN.ReLU(feat_types[it + 1],
                                              inplace=True))
            elif self.non_linearity[it] == "NormReLU":
                layers_list.append(G_CNN.NormNonLinearity(feat_types[it + 1]))
            else:
                sys.exit("Unknown non-linearity.")
            # This is convolution number it + 1 (convolution 0 was created above), so it takes
            # kernel_sizes[it + 1]. Indexing with kernel_sizes[it] would reuse kernel_sizes[0] for
            # the second convolution and silently ignore the last entry.
            kernel_size = kernel_sizes[it + 1]
            layers_list.append(
                G_CNN.R2Conv(feat_types[it + 1],
                             feat_types[it + 2],
                             kernel_size=kernel_size,
                             padding=(kernel_size - 1) // 2))
        #Create a steerable decoder out of the layers list:
        self.decoder = G_CNN.SequentialModule(*layers_list)
        #-----------END CREATE DECODER---------------