Example #1
0
 def __init__(self, input_channels, output_channels, kernel_size, stride, N, activation = True, deconv = False, last_deconv = False):
     """C_N-equivariant convolution layer stored in ``self.layer``.

     Args:
         input_channels: number of regular-representation input fields.
         output_channels: number of output fields. These are regular fields,
             except when ``deconv`` and ``last_deconv`` are both True, where
             they are ``irrep(1)`` fields.
         kernel_size: spatial size of the ``R2Conv`` kernel.
         stride: convolution stride.
         N: number of discrete rotations for ``gspaces.Rot2dOnR2``.
         activation: on the non-deconv path, follow the conv with
             InnerBatchNorm + ReLU. Ignored when ``deconv`` is True.
         deconv: use ``padding = 0`` instead of ``(kernel_size - 1)//2``.
             Note this is still a plain ``R2Conv``, not a transposed conv.
         last_deconv: together with ``deconv``, produce a bare ``R2Conv``
             (no norm/activation) with ``irrep(1)`` output fields.
     """
     super(rot_conv2d, self).__init__()
     gspace = gspaces.Rot2dOnR2(N = N)
     in_type = nn.FieldType(gspace, input_channels * [gspace.regular_repr])

     if deconv and last_deconv:
         # Final decoder layer: no normalisation/activation, irrep(1) output fields.
         vec_out_type = nn.FieldType(gspace, output_channels * [gspace.irrep(1)])
         self.layer = nn.R2Conv(in_type, vec_out_type, kernel_size = kernel_size, stride = stride, padding = 0)
         return

     out_type = nn.FieldType(gspace, output_channels * [gspace.regular_repr])
     pad = 0 if deconv else (kernel_size - 1) // 2
     conv = nn.R2Conv(in_type, out_type, kernel_size = kernel_size, stride = stride, padding = pad)
     # Deconv layers always get norm + ReLU; `activation` only gates the plain path.
     if deconv or activation:
         self.layer = nn.SequentialModule(
             conv,
             nn.InnerBatchNorm(out_type),
             nn.ReLU(out_type)
         )
     else:
         self.layer = conv
Example #2
0
 def __init__(self,
              input_channels,
              hidden_dim,
              kernel_size,
              N # Group size
             ):
     """C_N-equivariant residual block components.

     Builds three conv -> InnerBatchNorm -> ReLU stages over regular fields:
     ``layer1`` (input -> hidden), ``layer2`` (hidden -> hidden) and
     ``upscale`` (input -> hidden). All use "same"-style padding
     ``(kernel_size - 1) // 2``. How they are combined is decided in
     ``forward`` (not visible here).

     Args:
         input_channels: number of regular-representation input fields.
         hidden_dim: number of regular-representation output fields.
         kernel_size: spatial kernel size of every convolution.
         N: number of discrete rotations for ``gspaces.Rot2dOnR2``.
     """
     super(rot_resblock, self).__init__()

     # Symmetry group: N discrete planar rotations.
     gspace = gspaces.Rot2dOnR2(N = N)
     in_type = nn.FieldType(gspace, input_channels * [gspace.regular_repr])
     hid_type = nn.FieldType(gspace, hidden_dim * [gspace.regular_repr])
     same_pad = (kernel_size - 1) // 2

     def conv_bn_relu(src_type):
         # R2Conv -> InnerBatchNorm -> ReLU, mapping src_type onto hid_type.
         return nn.SequentialModule(
             nn.R2Conv(src_type, hid_type, kernel_size = kernel_size, padding = same_pad),
             nn.InnerBatchNorm(hid_type),
             nn.ReLU(hid_type)
         )

     self.layer1 = conv_bn_relu(in_type)
     self.layer2 = conv_bn_relu(hid_type)
     self.upscale = conv_bn_relu(in_type)

     self.input_channels = input_channels
     self.hidden_dim = hidden_dim
    def __init__(self, input_shape, num_actions, dueling_DQN):
        """D4-equivariant (``FlipRot2dOnR2(N=4)``) Q-network body for Snake.

        Args:
            input_shape: sequence whose first entry is the number of input
                channels; each channel becomes one trivial (scalar) field.
            num_actions: width of the action-value / advantage output.
            dueling_DQN: if True build separate ``advantage`` and ``value``
                heads, otherwise a single ``actionvalue`` head.
        """
        super(D4_steerable_DQN_Snake, self).__init__()
        self.input_shape = input_shape
        self.num_actions = num_actions
        self.dueling_DQN = dueling_DQN
        self.r2_act = gspaces.FlipRot2dOnR2(N=4)
        regular = self.r2_act.regular_repr
        self.input_type = nn.FieldType(
            self.r2_act, input_shape[0] * [self.r2_act.trivial_repr])
        feature1_type, feature2_type, feature3_type, feature4_type = (
            nn.FieldType(self.r2_act, n_fields * [regular])
            for n_fields in (8, 12, 12, 32))

        def conv_relu(in_type, out_type, **conv_kwargs):
            # Bias-free steerable conv followed by an in-place ReLU.
            return nn.SequentialModule(
                nn.R2Conv(in_type, out_type, bias=False, **conv_kwargs),
                nn.ReLU(out_type, inplace=True))

        self.feature_field1 = conv_relu(self.input_type, feature1_type,
                                        kernel_size=7, padding=2, stride=2)
        self.feature_field2 = conv_relu(feature1_type, feature2_type,
                                        kernel_size=5, padding=1, stride=2)
        self.feature_field3 = conv_relu(feature2_type, feature3_type,
                                        kernel_size=5, padding=1, stride=1)
        # No explicit padding here: R2Conv's default is used.
        self.equivariant_features = conv_relu(feature3_type, feature4_type,
                                              kernel_size=5, stride=1)
        self.gpool = nn.GroupPooling(feature4_type)
        self.feature_shape()

        # NOTE(review): heads are sized by the channel count of
        # equivariant_features, not gpool's output -- confirm against forward().
        n_features = self.equivariant_features.out_type.size
        if self.dueling_DQN:
            print("You are using Dueling DQN")
            self.advantage = torch.nn.Linear(n_features, self.num_actions)
            self.value = torch.nn.Linear(n_features, 1)
        else:
            self.actionvalue = torch.nn.Linear(n_features, self.num_actions)
Example #4
0
    def __init__(self, input_frames, output_frames, kernel_size, N):
        """C_N-equivariant U-Net whose input and output are ``irrep(1)`` fields.

        Args:
            input_frames: number of ``irrep(1)`` input fields.
            output_frames: number of ``irrep(1)`` output fields.
            kernel_size: spatial kernel size used by every convolution.
            N: number of discrete rotations for ``gspaces.Rot2dOnR2``.
        """
        super(Unet_Rot, self).__init__()
        gspace = gspaces.Rot2dOnR2(N = N)
        vec = gspace.irrep(1)
        same_pad = (kernel_size - 1) // 2
        self.feat_type_in = nn.FieldType(gspace, input_frames * [vec])
        self.feat_type_in_hid = nn.FieldType(gspace, 32 * [gspace.regular_repr])
        # 16 fields from deconv1 plus input_frames fields; presumably the input
        # is concatenated back in forward() -- confirm there.
        self.feat_type_hid_out = nn.FieldType(gspace, (16 + input_frames) * [vec])
        self.feat_type_out = nn.FieldType(gspace, output_frames * [vec])

        self.conv1 = nn.SequentialModule(
            nn.R2Conv(self.feat_type_in, self.feat_type_in_hid, kernel_size = kernel_size, stride = 2, padding = same_pad),
            nn.InnerBatchNorm(self.feat_type_in_hid),
            nn.ReLU(self.feat_type_in_hid)
        )

        def enc(c_in, c_out, stride):
            # Encoder stage over regular fields.
            return rot_conv2d(c_in, c_out, kernel_size = kernel_size, stride = stride, N = N)

        self.conv2 = enc(32, 64, 1)
        self.conv2_1 = enc(64, 64, 1)
        self.conv3 = enc(64, 128, 2)
        self.conv3_1 = enc(128, 128, 1)
        self.conv4 = enc(128, 256, 2)
        self.conv4_1 = enc(256, 256, 1)

        # Decoder input widths match skip concatenations (192 = 128 + 64,
        # 96 = 64 + 32) -- inferred from channel counts; verify in forward().
        self.deconv3 = rot_deconv2d(256, 64, N)
        self.deconv2 = rot_deconv2d(192, 32, N)
        self.deconv1 = rot_deconv2d(96, 16, N, last_deconv = True)

        self.output_layer = nn.R2Conv(self.feat_type_hid_out, self.feat_type_out, kernel_size = kernel_size, padding = same_pad)
 def restrict_layer(self, subgroup_id, feature_type):
     """Restrict ``feature_type`` to a subgroup and disentangle the result.

     Side effects: sets ``self.input_feature_type`` to the disentangled output
     type and ``self.r2_act`` to that type's (restricted) gspace, so layers
     built afterwards live on the subgroup.

     Args:
         subgroup_id: subgroup identifier passed to ``nn.RestrictionModule``.
         feature_type: field type to restrict.

     Returns:
         ``nn.SequentialModule(RestrictionModule, DisentangleModule)``.
     """
     restriction = nn.RestrictionModule(feature_type, subgroup_id)
     disentangle = nn.DisentangleModule(restriction.out_type)
     self.input_feature_type = disentangle.out_type
     self.r2_act = self.input_feature_type.gspace
     return nn.SequentialModule(restriction, disentangle)
Example #6
0
    def __init__(self,
                 block,
                 num_blocks,
                 in_channels,
                 out_channels,
                 expansion=None,
                 stride=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        """Sequence of ``num_blocks`` residual ``block`` instances.

        Only the first block may change resolution (``stride``) and width. When
        ``stride != 1`` or the width changes, it receives a ``downsample``
        branch: an optional ``ennAvgPool`` (if ``avg_down``), a 1x1 conv and a
        norm layer. Every following block maps ``out_channels`` to
        ``out_channels`` with stride 1.

        Extra ``kwargs`` are forwarded to every block.
        """
        self.block = block
        self.expansion = get_expansion(block, expansion)

        downsample = None
        if stride != 1 or in_channels != out_channels:
            # With avg_down, pooling does the striding and the 1x1 conv runs at stride 1.
            pool_first = avg_down and stride != 1
            modules = []
            if pool_first:
                modules.append(
                    ennAvgPool(in_channels,
                               kernel_size=stride,
                               stride=stride,
                               ceil_mode=True))
            modules.append(
                build_conv_layer(conv_cfg,
                                 in_channels,
                                 out_channels,
                                 kernel_size=1,
                                 stride=1 if pool_first else stride,
                                 bias=False))
            modules.append(build_norm_layer(norm_cfg, out_channels)[1])
            downsample = enn.SequentialModule(*modules)

        shared = dict(out_channels=out_channels,
                      expansion=self.expansion,
                      conv_cfg=conv_cfg,
                      norm_cfg=norm_cfg)
        layers = [
            block(in_channels=in_channels,
                  stride=stride,
                  downsample=downsample,
                  **shared,
                  **kwargs)
        ]
        layers.extend(
            block(in_channels=out_channels, stride=1, **shared, **kwargs)
            for _ in range(1, num_blocks))
        super(ResLayer, self).__init__(*layers)
Example #7
0
    def __init__(self):
        """Rotation-invariant dense feature extractor.

        Uses C_8-equivariant (``Rot2dOnR2(8)``) convolutions on a 3-channel
        input (three trivial fields) and ends with ``GroupPooling``.
        Architecture:
        2x conv3x3 -> maxpool, 2x conv3x3 -> maxpool,
        3x conv3x3 -> avgpool(k=2, s=1), 3x conv5x5 -> group pooling.
        Every conv is followed by an in-place ReLU.
        """
        super(DenseFeatureExtractionModuleE2Inv, self).__init__()

        # Per-layer field counts (doubled); only the first 10 are used below.
        filters = np.array([32,32, 64,64, 128,128,128, 256,256,256, 512,512,512], dtype=np.int32)*2

        # number of rotations to consider for rotation invariance
        N = 8

        self.gspace = gspaces.Rot2dOnR2(N)
        self.input_type = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)
        self.num_channels = 64

        # types[0] is the input; types[k] is the output type of the k-th conv.
        types = [self.input_type] + [
            FIELD_TYPE['regular'](self.gspace, width, fixparams=False)
            for width in filters[:10]
        ]

        def conv_relu_stack(conv, first, last):
            # convs types[first] -> ... -> types[last], each followed by ReLU
            stack = []
            for k in range(first, last):
                stack.append(conv(types[k], types[k + 1]))
                stack.append(enn.ReLU(types[k + 1], inplace=True))
            return stack

        modules = (
            conv_relu_stack(conv3x3, 0, 2)
            + [enn.PointwiseMaxPool(types[2], 2)]
            + conv_relu_stack(conv3x3, 2, 4)
            + [enn.PointwiseMaxPool(types[4], 2)]
            + conv_relu_stack(conv3x3, 4, 7)
            + [enn.PointwiseAvgPool(types[7], kernel_size=2, stride=1)]
            + conv_relu_stack(conv5x5, 7, 10)
            + [enn.GroupPooling(types[10])]
        )
        self.model = enn.SequentialModule(*modules)
Example #8
0
def build_steer_cnn_2d(
    in_field_type,
    hidden_field_types,
    kernel_sizes,
    out_field_type,
    gspace,
    activation="relu",
    padding_mode="zeros",
    modify_init=1.0,
):
    """
    Build a steerable 2D CNN wrapped to take and return plain tensors.

    Input:
        in_field_type - field type of the input data
        hidden_field_types - the field types to use in the hidden layers
        kernel_sizes - kernel size of each convolution: either a single int
            shared by all layers, or a sequence with one entry per
            convolution (len(hidden_field_types) + 1 entries)
        out_field_type - the field type to use in the output layer
        gspace - the gspace that data lives in (not used by the body)
        activation - key into ``activations`` selecting the non-linearity
            placed between layers (none follows the output layer)
        padding_mode - padding mode forwarded to every ``R2Conv``
        modify_init - factor every parameter is multiplied by after
            initialisation

    Returns:
        ``nn.Sequential`` that wraps its input tensor into a
        ``GeometricTensor`` of ``in_field_type``, applies the steerable
        layers, and unwraps the result back to a raw tensor.
    """
    layer_field_types = [in_field_type, *hidden_field_types, out_field_type]
    n_convs = len(layer_field_types) - 1

    if isinstance(kernel_sizes, int):
        # One kernel size per convolution (len(hidden_field_types) + 1).
        kernel_sizes = [kernel_sizes] * n_convs

    layers = []

    for i in range(n_convs):
        k = kernel_sizes[i]
        layers.append(
            gnn.R2Conv(
                layer_field_types[i],
                layer_field_types[i + 1],
                k,
                padding=int((k - 1) / 2),
                padding_mode=padding_mode,
                initialize=True,
            ))
        # No activation after the final (output) convolution.
        if i != n_convs - 1:
            layers.append(activations[activation](layer_field_types[i + 1]))

    cnn = gnn.SequentialModule(*layers)

    # TODO: dirty fix to alleviate weird initialisations -- rescale every
    # parameter by `modify_init` (0-dim parameters are rebound, others
    # overwritten in place).
    for p in cnn.parameters():
        if p.dim() == 0:
            p.data = p.data * modify_init
        else:
            p.data[:] = p.data * modify_init

    return nn.Sequential(
        Expression(lambda X: gnn.GeometricTensor(X, in_field_type)),
        cnn,
        Expression(lambda X: X.tensor),
    )
Example #9
0
 def __init__(self, input_frames, output_frames, kernel_size, N):
     """C_N-equivariant ResNet mapping ``irrep(1)`` input fields to ``irrep(1)`` outputs.

     Pipeline:
     - input conv (+BN, ReLU) to 16 regular fields;
     - four residual stages widening to 32, 64, 128 and 192 fields, each made
       of two ``rot_resblock``s;
     - a final ``R2Conv`` to ``output_frames`` ``irrep(1)`` fields.

     All convolutions use padding ``(kernel_size - 1) // 2``.
     """
     super(ResNet_Rot, self).__init__()
     gspace = gspaces.Rot2dOnR2(N = N)
     same_pad = (kernel_size - 1) // 2
     # we use rho_1 representation since the input is velocity fields
     self.feat_type_in = nn.FieldType(gspace, input_frames * [gspace.irrep(1)])
     # we use regular representation for middle layers
     self.feat_type_in_hid = nn.FieldType(gspace, 16 * [gspace.regular_repr])
     self.feat_type_hid_out = nn.FieldType(gspace, 192 * [gspace.regular_repr])
     self.feat_type_out = nn.FieldType(gspace, output_frames * [gspace.irrep(1)])

     self.input_layer = nn.SequentialModule(
         nn.R2Conv(self.feat_type_in, self.feat_type_in_hid, kernel_size = kernel_size, padding = same_pad),
         nn.InnerBatchNorm(self.feat_type_in_hid),
         nn.ReLU(self.feat_type_in_hid)
     )

     layers = [self.input_layer]
     # Each stage: one block widening width_in -> width, then width -> width.
     width_in = 16
     for width in (32, 64, 128, 192):
         layers.append(rot_resblock(width_in, width, kernel_size, N))
         layers.append(rot_resblock(width, width, kernel_size, N))
         width_in = width
     layers.append(nn.R2Conv(self.feat_type_hid_out, self.feat_type_out, kernel_size = kernel_size, padding = same_pad))
     self.model = torch.nn.Sequential(*layers)
Example #10
0
    def __init__(self,
                 hidden_reps_ids,
                 kernel_sizes,
                 dim_cov_est,
                 context_rep_ids=[1],
                 N=4,
                 flip=False,
                 non_linearity=["NormReLU"],
                 max_frequency=30):
        '''
        Steerable CNN decoder whose fibre representations are chosen by id.

        Input:  hidden_reps_ids - list: encoding the hidden fiber representation (see give_fib_reps_from_ids)
                kernel_sizes - list of odd ints - one kernel size per convolutional layer,
                               i.e. len(hidden_reps_ids) + 1 entries; conv i uses kernel_sizes[i]
                dim_cov_est - dimension of covariance estimation, either 1,2,3 or 4
                context_rep_ids - list: gives the input fiber representation (see give_fib_reps_from_ids)
                non_linearity - list of strings - names of the non-linearities ("ReLU" or "NormReLU")
                                    Either length 1 (then same non-linearity for all)
                                    or length n_layers - 2 == len(hidden_reps_ids) (a custom
                                    non-linearity after every convolution except the last)
                N - int - gives the group order, -1 is infinite
                flip - Bool - indicates whether we have a flip in the rotation group (i.e.O(2) vs SO(2), D_N vs C_N)
                max_frequency - int - maximum irrep frequency to computed, only relevant if N=-1

        Calls sys.exit on an invalid configuration: wrong number of
        non-linearities or kernel sizes, an even kernel size, or an unknown
        non-linearity name.
        '''

        super(SteerDecoder, self).__init__()
        #Save the rotation group, if flip is true, then include all corresponding reflections:
        self.flip = flip
        self.max_frequency = max_frequency

        if self.flip:
            if N != -1:
                self.G_act = gspaces.FlipRot2dOnR2(N=N)
            else:
                self.G_act = gspaces.FlipRot2dOnR2(
                    N=N, maximum_frequency=self.max_frequency)
            #The output fiber representation is the identity:
            self.target_rep = self.G_act.irrep(1, 1)
        else:
            if N != -1:
                self.G_act = gspaces.Rot2dOnR2(N=N)
            else:
                self.G_act = gspaces.Rot2dOnR2(
                    N=N, maximum_frequency=self.max_frequency)
            #The output fiber representation is the identity:
            self.target_rep = self.G_act.irrep(1)

        #Save the N defining D_N or C_N (if N=-1 it is infinity):
        self.polygon_corners = N

        #Save the id's for the context representation and extract the context fiber representation:
        self.context_rep_ids = context_rep_ids
        self.context_rep = group.directsum(
            self.give_reps_from_ids(self.context_rep_ids))

        #Save the parameters:
        self.kernel_sizes = kernel_sizes
        self.n_layers = len(hidden_reps_ids) + 2
        self.hidden_reps_ids = hidden_reps_ids
        self.dim_cov_est = dim_cov_est
        n_convs = self.n_layers - 1

        #-----CREATE LIST OF NON-LINEARITIES----
        if len(non_linearity) == 1:
            self.non_linearity = (self.n_layers - 2) * non_linearity
        elif len(non_linearity) != (self.n_layers - 2):
            sys.exit(
                "List of non-linearities invalid: must have either length 1 or n_layers-2"
            )
        else:
            self.non_linearity = non_linearity
        #-----ENDE LIST OF NON-LINEARITIES----

        #-----------CONTROL INPUTS------------------
        # Validate before building: the layers index kernel_sizes directly.
        # Odd kernels with padding (k - 1) // 2 keep h and w unchanged.
        if any(k % 2 == 0 for k in kernel_sizes):
            sys.exit("All kernels need to have odd sizes")
        if len(kernel_sizes) != n_convs:
            sys.exit("Number of layers and number kernels do not match.")
        if any(name not in ("ReLU", "NormReLU") for name in self.non_linearity):
            sys.exit("Unknown non-linearity.")

        #-----------CREATE DECODER-----------------
        '''
        Create a list of layers based on the kernel sizes. Compute the padding such
        that the height h and width w of a tensor with shape (batch_size,n_channels,h,w) does not change
        while being passed through the decoder
        '''
        #Create list of feature types:
        feat_types = self.give_feat_types()
        self.feature_emb = feat_types[0]
        self.feature_out = feat_types[-1]

        # n_layers feature types -> n_convs convolutions; a non-linearity
        # sits between consecutive convolutions. Conv i uses kernel_sizes[i].
        layers_list = []
        for i, k in enumerate(kernel_sizes):
            if i > 0:
                if self.non_linearity[i - 1] == "ReLU":
                    layers_list.append(G_CNN.ReLU(feat_types[i], inplace=True))
                else:  # "NormReLU" (validated above)
                    layers_list.append(G_CNN.NormNonLinearity(feat_types[i]))
            layers_list.append(
                G_CNN.R2Conv(feat_types[i],
                             feat_types[i + 1],
                             kernel_size=k,
                             padding=(k - 1) // 2))
        #Create a steerable decoder out of the layers list:
        self.decoder = G_CNN.SequentialModule(*layers_list)
        #-----------END CREATE DECODER---------------
    def __init__(self, input_size, in_type, field_type, out_fiber,
                 activation_fn, hidden_size, group_action_type):
        """Invariant CNN encoder followed by a plain transposed-conv generator.

        Args:
            input_size: 4-element sequence unpacked as ``(_, c, h, w)``;
                presumably (batch, channels, height, width) -- confirm.
            in_type: not referenced in this constructor.
            field_type: forwarded to ``FIBERS[out_fiber]``.
            out_fiber: key into ``FIBERS`` selecting the hidden field type.
            activation_fn: callable ``(field_type, inplace=True)`` returning
                an equivariant activation module.
            hidden_size: size argument forwarded to ``FIBERS[out_fiber]``.
            group_action_type: gspace used to build every field type.

        The equivariant path is three conv/BN/activation blocks with
        antialiased average pooling, then ``GroupPooling``. ``self.gen``
        upsamples ``self.gc`` channels back to ``self.c`` channels through
        a final Tanh.
        """
        super(InvariantCNNBlock, self).__init__()
        _, self.c, self.h, self.w = input_size
        ngf = 16
        half_ngf = ngf // 2
        self.group_action_type = group_action_type
        gspace = self.group_action_type
        feat_type_in = enn.FieldType(gspace, self.c * [gspace.trivial_repr])
        feat_type_hid = FIBERS[out_fiber](group_action_type,
                                          hidden_size,
                                          field_type,
                                          fixparams=True)
        feat_type_out = enn.FieldType(gspace, 128 * [gspace.regular_repr])

        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = feat_type_in

        def antialiased_pool(ftype, **pool_kwargs):
            return enn.PointwiseAvgPoolAntialiased(ftype, sigma=0.66, **pool_kwargs)

        self.block1 = enn.SequentialModule(
            enn.R2Conv(feat_type_in, feat_type_hid, kernel_size=5, padding=0),
            enn.InnerBatchNorm(feat_type_hid),
            activation_fn(feat_type_hid, inplace=True),
        )
        self.pool1 = enn.SequentialModule(antialiased_pool(feat_type_hid, stride=2))

        self.block2 = enn.SequentialModule(
            enn.R2Conv(feat_type_hid, feat_type_hid, kernel_size=5),
            enn.InnerBatchNorm(feat_type_hid),
            activation_fn(feat_type_hid, inplace=True),
        )
        self.pool2 = enn.SequentialModule(antialiased_pool(feat_type_hid, stride=2))

        self.block3 = enn.SequentialModule(
            enn.R2Conv(feat_type_hid, feat_type_out, kernel_size=3, padding=1),
            enn.InnerBatchNorm(feat_type_out),
            activation_fn(feat_type_out, inplace=True),
        )
        self.pool3 = antialiased_pool(feat_type_out, stride=1, padding=0)

        self.gpool = enn.GroupPooling(feat_type_out)
        self.gc = self.gpool.out_type.size

        def up(c_in, c_out, stride, padding):
            # 4x4 transposed convolution used by every generator stage.
            return torch.nn.ConvTranspose2d(c_in, c_out, kernel_size=4,
                                            stride=stride, padding=padding)

        self.gen = torch.nn.Sequential(
            up(self.gc, ngf, 1, 0), torch.nn.BatchNorm2d(ngf),
            torch.nn.LeakyReLU(0.2),
            up(ngf, ngf, 2, 1), torch.nn.BatchNorm2d(ngf),
            torch.nn.LeakyReLU(0.2),
            up(ngf, half_ngf, 2, 1), torch.nn.BatchNorm2d(half_ngf),
            torch.nn.LeakyReLU(0.2),
            up(half_ngf, self.c, 2, 1), torch.nn.Tanh())
    def __init__(self, input_shape, num_actions, dueling_DQN, rest_lev):
        """D4-equivariant Q-network for Pacman with optional symmetry restriction.

        Args:
            input_shape: sequence whose first entry is the number of input
                channels (each becomes one trivial field).
            num_actions: width of the action-value / advantage output.
            dueling_DQN: build ``advantage``/``value`` heads if True, else a
                single ``actionvalue`` head.
            rest_lev: stage at which the group is restricted:
                3 -> to subgroup (0, 2) before conv 2 (field counts x2);
                2 -> to (0, 1) before conv 3 (x4);
                1 -> to (0, 1) before conv 4 (x4);
                any other value -> no restriction.
        """
        super(steerable_DQN_Pacman, self).__init__()
        self.input_shape = input_shape
        self.num_actions = num_actions
        self.dueling_DQN = dueling_DQN
        self.rest_lev = rest_lev
        self.feature_factor = 1
        # Scales up the num of fields
        self.num_feature_fields = [8, 16, 16, 128]
        # Symmetry group
        self.r2_act = gspaces.FlipRot2dOnR2(N=4)

        def regular_fields(n_fields):
            # Reads self.r2_act at call time: restrict_layer may have replaced it.
            return nn.FieldType(self.r2_act, n_fields * [self.r2_act.regular_repr])

        def maybe_restrict(level, subgroup_id, feature_type, factor):
            # `feature_type` feeds the next conv. If rest_lev selects this stage,
            # restrict it (restrict_layer updates input_feature_type and r2_act)
            # and scale later field counts by `factor`; otherwise return identity.
            self.input_feature_type = feature_type
            if self.rest_lev != level:
                return lambda x: x
            restriction = self.restrict_layer(subgroup_id, feature_type)
            self.feature_factor = factor
            return restriction

        # 1st E-Conv
        self.input_type = nn.FieldType(
            self.r2_act, input_shape[0] * [self.r2_act.trivial_repr])
        feature1_type = regular_fields(self.num_feature_fields[0])
        self.feature_field1 = nn.SequentialModule(
            nn.R2Conv(self.input_type, feature1_type,
                      kernel_size=7, padding=2, stride=2, bias=False),
            nn.ReLU(feature1_type, inplace=True),
            nn.PointwiseAvgPoolAntialiased(feature1_type, sigma=0.66,
                                           stride=2))

        # 2nd E-Conv
        self.restrict1 = maybe_restrict(3, (0, 2), feature1_type, 2)
        feature2_type = regular_fields(
            self.num_feature_fields[1] * self.feature_factor)
        self.feature_field2 = nn.SequentialModule(
            nn.R2Conv(self.input_feature_type, feature2_type,
                      kernel_size=5, padding=2, stride=2, bias=False),
            nn.ReLU(feature2_type, inplace=True))

        # 3rd E-Conv
        self.restrict2 = maybe_restrict(2, (0, 1), feature2_type, 4)
        feature3_type = regular_fields(
            self.num_feature_fields[2] * self.feature_factor)
        self.feature_field3 = nn.SequentialModule(
            nn.R2Conv(self.input_feature_type, feature3_type,
                      kernel_size=5, padding=1, stride=2, bias=False),
            nn.ReLU(feature3_type, inplace=True))

        # 4th E-Conv
        self.restrict_extra = maybe_restrict(1, (0, 1), feature3_type, 4)
        feature4_type = regular_fields(
            self.num_feature_fields[3] * self.feature_factor)
        self.feature_field_extra = nn.SequentialModule(
            nn.R2Conv(self.input_feature_type, feature4_type,
                      kernel_size=5, padding=0, bias=True),
            nn.ReLU(feature4_type, inplace=True))
        _, _ = self.feature_shape()
        self.out_size = self.feature_field_extra.out_type.size

        # Final linear layer
        if self.dueling_DQN:
            print("You are using Dueling DQN")
            self.advantage = torch.nn.Linear(self.out_size, self.num_actions)
            self.value = torch.nn.Linear(self.out_size, 1)
        else:
            self.actionvalue = torch.nn.Linear(self.out_size, self.num_actions)