示例#1
0
    def __init__(self, channel_in=3, n_classes=4, rot_n=4):
        """Build a small C_N-equivariant CNN followed by an invariant MLP head.

        Args:
            channel_in: number of scalar (trivial-representation) input channels.
            n_classes: number of logits produced by ``lin_out``.
            rot_n: order N of the planar rotation group C_N.
        """
        super(SmallE2, self).__init__()

        gspace = gspaces.Rot2dOnR2(N=rot_n)
        trivial = gspace.trivial_repr
        regular = gspace.regular_repr

        # One trivial field per image channel; hidden/final use regular fields.
        self.feat_type_in = nn.FieldType(gspace, [trivial] * channel_in)
        hidden = nn.FieldType(gspace, [regular] * 8)
        final = nn.FieldType(gspace, [regular] * 2)

        # Normalisation and non-linearity shared by the hidden field type.
        self.bn = nn.InnerBatchNorm(hidden)
        self.relu = nn.ReLU(hidden)

        self.convin = nn.R2Conv(self.feat_type_in, hidden, kernel_size=3)
        self.convhid = nn.R2Conv(hidden, hidden, kernel_size=3)
        self.convout = nn.R2Conv(hidden, final, kernel_size=3)

        self.avgpool = nn.PointwiseAvgPool(final, 3)
        # Group pooling collapses each regular field to a rotation-invariant channel.
        self.invariant_map = nn.GroupPooling(final)

        n_invariant = self.invariant_map.out_type.size
        self.lin_in = torch.nn.Linear(n_invariant, 64)
        self.elu = torch.nn.ELU()
        self.lin_out = torch.nn.Linear(64, n_classes)
示例#2
0
    def __init__(self, in_chan, out_chan, imsize, kernel_size=5, N=8):
        """LeNet variant whose first convolution is D_N-equivariant.

        Only the first layer is steerable. After group pooling, the remaining
        layers are ordinary ``torch.nn`` modules.

        Args:
            in_chan: number of input image channels; each becomes one scalar
                (trivial) field of the input type.
            out_chan: number of logits produced by ``fc3``.
            imsize: side length of the square input image in pixels.
            kernel_size: spatial size of both convolution kernels.
            N: number of rotations of the dihedral group D_N.
        """
        super(DNRestrictedLeNet, self).__init__()
        kernel_size = int(kernel_size)
        imsize = int(imsize)

        # Side length after two 2x down-samplings; sizes the input of fc1.
        # NOTE(review): only pool1 is visible here, so the second
        # down-sampling presumably happens in forward - confirm.
        z = imsize // 2 // 2

        self.r2_act = gspaces.FlipRot2dOnR2(N)

        # One trivial (scalar) field per input channel.
        in_type = e2nn.FieldType(self.r2_act,
                                 int(in_chan) * [self.r2_act.trivial_repr])
        self.input_type = in_type

        out_type = e2nn.FieldType(self.r2_act, 6 * [self.r2_act.regular_repr])
        self.mask = e2nn.MaskModule(in_type, imsize, margin=1)
        # 'Same'-padded steerable convolution (odd kernel_size keeps H x W).
        self.conv1 = e2nn.R2Conv(in_type,
                                 out_type,
                                 kernel_size=kernel_size,
                                 padding=kernel_size // 2,
                                 bias=False)
        self.relu1 = e2nn.ReLU(out_type, inplace=True)
        self.pool1 = e2nn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)

        # Collapse each regular field to one invariant channel (6 channels).
        self.gpool = e2nn.GroupPooling(out_type)

        # Plain convolution on the invariant channels; width taken from gpool
        # instead of being hard-coded.
        self.conv2 = nn.Conv2d(self.gpool.out_type.size, 16, kernel_size,
                               padding=kernel_size // 2)

        self.fc1 = nn.Linear(16 * z * z, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_chan)

        self.drop = nn.Dropout(p=0.5)

        # dummy parameter for tracking device
        self.dummy = nn.Parameter(torch.empty(0))
示例#3
0
    def __init__(self, in_type, num_classes=10):
        """Invariant classification head: group pooling then two 1x1 'linear' convs.

        Args:
            in_type: field type of the incoming feature map; its gspace is reused.
            num_classes: number of trivial output fields (one logit each).
        """
        super(ClassificationHead, self).__init__()
        gspace = in_type.gspace
        trivial = gspace.trivial_repr

        self.add_module('gpool', nn.GroupPooling(in_type))

        # 1x1 convolutions between trivial field types act as fully
        # connected layers applied at every spatial position.
        pooled = self.gpool.out_type
        hidden = nn.FieldType(gspace, [trivial] * 64)
        self.add_module('linear1', sscnn.e2cnn.PlainConv(
            pooled, hidden, kernel_size=1, padding=0, bias=False))
        self.add_module('relu1', nn.ReLU(hidden, inplace=True))

        logits = nn.FieldType(gspace, [trivial] * num_classes)
        self.add_module('linear2', sscnn.e2cnn.PlainConv(
            hidden, logits, kernel_size=1, padding=0, bias=False))
示例#4
0
    def __init__(self, n_classes=6):
        """C4-equivariant CNN with two steerable conv blocks and an MLP head.

        Args:
            n_classes: number of output logits.
        """
        super(SteerCNN, self).__init__()

        # Equivariance group: rotations by multiples of 90 degrees (C4, N=4).
        self.r2_act = gspaces.Rot2dOnR2(N=4)

        # Input: three scalar fields (one trivial field per image channel).
        # Stored so forward can wrap raw tensors as geometric tensors.
        self.input_type = nn_e2.FieldType(self.r2_act,
                                          3 * [self.r2_act.trivial_repr])

        # Both conv blocks output the same 24 regular fields of C4.
        hidden = nn_e2.FieldType(self.r2_act, 24 * [self.r2_act.regular_repr])

        def conv_block(src):
            # 7x7 'same'-padded steerable conv -> inner batch norm -> ReLU.
            return nn_e2.SequentialModule(
                nn_e2.R2Conv(src, hidden, kernel_size=7, padding=3, bias=False),
                nn_e2.InnerBatchNorm(hidden),
                nn_e2.ReLU(hidden, inplace=True))

        self.block1 = conv_block(self.input_type)
        self.pool1 = nn_e2.PointwiseAvgPool(hidden, 4)

        self.block2 = conv_block(self.block1.out_type)
        # Anti-aliasing blur, 4x4 average pooling, then group pooling to
        # obtain 24 rotation-invariant channels.
        self.pool2 = nn_e2.SequentialModule(
            nn_e2.PointwiseAvgPoolAntialiased(hidden, sigma=0.66, stride=1,
                                              padding=0),
            nn_e2.PointwiseAvgPool(hidden, 4),
            nn_e2.GroupPooling(hidden))

        # Flattened feature size: 24 invariant channels times a hard-coded
        # 13 x 13 map. NOTE(review): 13 x 13 presumably matches a single input
        # resolution only - confirm against the data pipeline.
        n_features = 24 * 13 * 13

        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(n_features, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )
示例#5
0
    def __init__(self):
        """Build a VGG-like C8-equivariant feature extractor with invariant output.

        Ten regular-field layers (widths from ``filters[:10]``) are stacked
        with ReLUs and pooling. A final group pooling makes the output
        invariant to C8 rotations.
        """
        super(DenseFeatureExtractionModuleE2Inv, self).__init__()

        # VGG-16 style channel widths, doubled.
        filters = np.array([32, 32, 64, 64, 128, 128, 128,
                            256, 256, 256, 512, 512, 512], dtype=np.int32) * 2

        # Rotation invariance is with respect to C8 (multiples of 45 degrees).
        N = 8

        self.gspace = gspaces.Rot2dOnR2(N)
        self.input_type = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)

        # Presumably the channel count after the final group pooling - confirm.
        self.num_channels = 64

        # t[0] is the RGB input type; t[k] (k >= 1) is the output type of layer k.
        t = [self.input_type] + [
            FIELD_TYPE['regular'](self.gspace, width, fixparams=False)
            for width in filters[:10]
        ]

        def conv_relu(conv, k):
            # Convolution t[k-1] -> t[k] followed by an in-place ReLU.
            return [conv(t[k - 1], t[k]), enn.ReLU(t[k], inplace=True)]

        modules = (
            conv_relu(conv3x3, 1) + conv_relu(conv3x3, 2)
            + [enn.PointwiseMaxPool(t[2], 2)]
            + conv_relu(conv3x3, 3) + conv_relu(conv3x3, 4)
            + [enn.PointwiseMaxPool(t[4], 2)]
            + conv_relu(conv3x3, 5) + conv_relu(conv3x3, 6)
            + conv_relu(conv3x3, 7)
            + [enn.PointwiseAvgPool(t[7], kernel_size=2, stride=1)]
            + conv_relu(conv5x5, 8) + conv_relu(conv5x5, 9)
            + conv_relu(conv5x5, 10)
            + [enn.GroupPooling(t[10])]
        )
        self.model = enn.SequentialModule(*modules)
    def __init__(self, input_shape, num_actions, dueling_DQN):
        """Build a D4-equivariant convolutional Q-network.

        Args:
            input_shape: input shape; ``input_shape[0]`` is used as the number
                of scalar (trivial) input channels. Presumably
                ``(channels, height, width)`` - TODO confirm.
            num_actions: number of discrete actions (width of the Q head).
            dueling_DQN: if truthy, build separate advantage and value heads
                instead of a single action-value head.
        """
        super(D4_steerable_DQN_Snake, self).__init__()
        self.input_shape = input_shape
        self.num_actions = num_actions
        self.dueling_DQN = dueling_DQN
        # Dihedral group D4: rotations by multiples of 90 degrees plus flips.
        self.r2_act = gspaces.FlipRot2dOnR2(N=4)
        self.input_type = nn.FieldType(
            self.r2_act, input_shape[0] * [self.r2_act.trivial_repr])
        # Hidden widths, in number of regular D4 fields: 8 -> 12 -> 12 -> 32.
        feature1_type = nn.FieldType(self.r2_act,
                                     8 * [self.r2_act.regular_repr])
        feature2_type = nn.FieldType(self.r2_act,
                                     12 * [self.r2_act.regular_repr])
        feature3_type = nn.FieldType(self.r2_act,
                                     12 * [self.r2_act.regular_repr])
        feature4_type = nn.FieldType(self.r2_act,
                                     32 * [self.r2_act.regular_repr])

        # Strided 7x7 steerable conv + ReLU.
        self.feature_field1 = nn.SequentialModule(
            nn.R2Conv(self.input_type,
                      feature1_type,
                      kernel_size=7,
                      padding=2,
                      stride=2,
                      bias=False), nn.ReLU(feature1_type, inplace=True))
        # Strided 5x5 steerable conv + ReLU.
        self.feature_field2 = nn.SequentialModule(
            nn.R2Conv(feature1_type,
                      feature2_type,
                      kernel_size=5,
                      padding=1,
                      stride=2,
                      bias=False), nn.ReLU(feature2_type, inplace=True))
        # Stride-1 5x5 steerable conv + ReLU.
        self.feature_field3 = nn.SequentialModule(
            nn.R2Conv(feature2_type,
                      feature3_type,
                      kernel_size=5,
                      padding=1,
                      stride=1,
                      bias=False), nn.ReLU(feature3_type, inplace=True))

        # Final unpadded 5x5 steerable conv to 32 regular fields.
        self.equivariant_features = nn.SequentialModule(
            nn.R2Conv(feature3_type,
                      feature4_type,
                      kernel_size=5,
                      stride=1,
                      bias=False), nn.ReLU(feature4_type, inplace=True))
        # Group pooling: one invariant channel per regular field.
        self.gpool = nn.GroupPooling(feature4_type)
        # Defined elsewhere in the class and called once all layers exist.
        # NOTE(review): its effect is not visible here - confirm what it sets.
        self.feature_shape()
        # NOTE(review): the linear heads below are sized by
        # equivariant_features.out_type.size (32 regular D4 fields, presumably
        # 256 channels), not by gpool.out_type.size (32). Confirm forward does
        # not apply gpool before these layers.
        if self.dueling_DQN:
            print("You are using Dueling DQN")
            self.advantage = torch.nn.Linear(
                self.equivariant_features.out_type.size, self.num_actions)
            #self.value = torch.nn.Linear(self.gpool.out_type.size, 1)
            self.value = torch.nn.Linear(
                self.equivariant_features.out_type.size, 1)
        else:
            self.actionvalue = torch.nn.Linear(
                self.equivariant_features.out_type.size, self.num_actions)
示例#7
0
    def __init__(self,
                 base='DNSteerableLeNet',
                 in_chan=1,
                 n_classes=2,
                 imsize=150,
                 kernel_size=5,
                 N=8,
                 quiet=True,
                 number_rotations=None):
        """LeNet-style network with two D_N-equivariant convolution stages.

        Args:
            base: unused in this constructor; kept for interface compatibility.
            in_chan: number of input image channels; each becomes one scalar
                (trivial) field of the input type.
            n_classes: number of logits produced by ``fc3``.
            imsize: side length of the square input image in pixels.
            kernel_size: spatial size of both steerable kernels.
            N: number of rotations of the dihedral group D_N.
            quiet: unused in this constructor; kept for interface compatibility.
            number_rotations: if given, overrides ``N``.
        """
        super(DNSteerableLeNet, self).__init__()
        kernel_size = int(kernel_size)
        imsize = int(imsize)
        out_chan = int(n_classes)

        if number_rotations is not None:
            N = int(number_rotations)

        # Side length after pool1 and pool2 (2x each); sizes the input of fc1.
        z = imsize // 2 // 2

        self.r2_act = gspaces.FlipRot2dOnR2(N)

        # One trivial (scalar) field per input channel.
        in_type = e2nn.FieldType(self.r2_act,
                                 int(in_chan) * [self.r2_act.trivial_repr])
        self.input_type = in_type

        # Stage 1: input -> 6 regular fields.
        out_type = e2nn.FieldType(self.r2_act, 6 * [self.r2_act.regular_repr])
        self.mask = e2nn.MaskModule(in_type, imsize, margin=1)
        self.conv1 = e2nn.R2Conv(in_type,
                                 out_type,
                                 kernel_size=kernel_size,
                                 padding=kernel_size // 2,
                                 bias=False)
        self.relu1 = e2nn.ReLU(out_type, inplace=True)
        self.pool1 = e2nn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)

        # Stage 2: 6 -> 16 regular fields.
        in_type = self.pool1.out_type
        out_type = e2nn.FieldType(self.r2_act, 16 * [self.r2_act.regular_repr])
        self.conv2 = e2nn.R2Conv(in_type,
                                 out_type,
                                 kernel_size=kernel_size,
                                 padding=kernel_size // 2,
                                 bias=False)
        self.relu2 = e2nn.ReLU(out_type, inplace=True)
        self.pool2 = e2nn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)

        # One invariant channel per regular field -> 16 channels for fc1.
        self.gpool = e2nn.GroupPooling(out_type)

        self.fc1 = nn.Linear(16 * z * z, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_chan)

        self.drop = nn.Dropout(p=0.5)

        # dummy parameter for tracking device
        self.dummy = nn.Parameter(torch.empty(0))
示例#8
0
 def __init__(self):
     """Single dilated C8-equivariant convolution followed by group pooling."""
     super(ModelDilated, self).__init__()
     N = 8
     self.gspace = gspaces.Rot2dOnR2(N)
     # Three scalar input channels -> 16 regular fields of C8.
     trivial = self.gspace.trivial_repr
     regular = self.gspace.regular_repr
     self.in_type = enn.FieldType(self.gspace, 3 * [trivial])
     self.out_type = enn.FieldType(self.gspace, 16 * [regular])
     # A 3x3 kernel with dilation 2 spans 5x5; padding=2 at stride 1 keeps H x W.
     self.layer = enn.R2Conv(self.in_type, self.out_type,
                             kernel_size=3,
                             stride=1,
                             padding=2,
                             dilation=2,
                             bias=True)
     self.invariant = enn.GroupPooling(self.out_type)
    def build_multiscale_classifier(self, input_size):
        """Create one equivariant classification head per scale plus a logit layer.

        Args:
            input_size: ``(n, c, h, w)`` shape of the input batch.

        Sets ``self.classification_heads`` (an ``nn.ModuleList`` with one
        identically-structured head per scale) and ``self.logit_layer``.
        """
        n, c, h, w = input_size
        hidden_shapes = []
        for scale in range(self.n_scale):
            # Every scale except the last halves H and W and grows channels.
            if scale < self.n_scale - 1:
                c *= 2 if self.factor_out else 4
                h //= 2
                w //= 2
            hidden_shapes.append((n, c, h, w))

        def fiber(width):
            # Regular fiber with `width` channels on this model's group action.
            return FIBERS['regular'](self.group_action_type, width,
                                     self.field_type, fixparams=True)

        feat_out = fiber(self.classification_hdim)
        feat_mid = fiber(int(self.classification_hdim // 2))
        feat_last = fiber(int(self.classification_hdim // 4))

        def make_head():
            # Three conv/act-norm/ReLU/blur-pool stages, then group pooling.
            return nn.Sequential(
                enn.R2Conv(self.input_type, feat_out, 5, stride=2),
                layers.EquivariantActNorm2d(feat_out.size),
                enn.ReLU(feat_out, inplace=True),
                enn.PointwiseAvgPoolAntialiased(feat_out, sigma=0.66, stride=2),
                enn.R2Conv(feat_out, feat_mid, kernel_size=3),
                layers.EquivariantActNorm2d(feat_mid.size),
                enn.ReLU(feat_mid, inplace=True),
                enn.PointwiseAvgPoolAntialiased(feat_mid, sigma=0.66, stride=1),
                enn.R2Conv(feat_mid, feat_last, kernel_size=3),
                layers.EquivariantActNorm2d(feat_last.size),
                enn.ReLU(feat_last, inplace=True),
                enn.PointwiseAvgPoolAntialiased(feat_last, sigma=0.66, stride=2),
                enn.GroupPooling(feat_last),
            )

        heads = [make_head() for _ in hidden_shapes]
        self.classification_heads = nn.ModuleList(heads)
        # The logit layer reads the invariant output of the last head's pooling.
        self.logit_layer = nn.Linear(heads[-1][-1].out_type.size, self.n_classes)
示例#10
0
    def __init__(self,
                 base = 'DNSteerableAGRadGalNet',
                 attention_module='SelfAttention',
                 attention_gates=3,
                 attention_aggregation='ft',
                 n_classes=2,
                 attention_normalisation='sigmoid',
                 quiet=True,
                 number_rotations=8,
                 imsize=150,
                 kernel_size=3,
                 group="D"
                ):
        """Steerable (D_N or C_N) CNN with optional grid attention gates.

        Args:
            base: unused in this constructor; kept for interface compatibility.
            attention_module: unused in this constructor; kept for interface
                compatibility.
            attention_gates: number of attention gates, 0-3.
            attention_aggregation: one of 'concat', 'mean', 'deep_sup', 'ft'.
            n_classes: number of output classes.
            attention_normalisation: normalisation passed to each
                GridAttentionBlock2D; one of 'sigmoid', 'range_norm',
                'std_mean_norm', 'tanh', 'softmax'.
            quiet: unused in this constructor; kept for interface compatibility.
            number_rotations: N, the number of discrete rotations.
            imsize: side length of the square input image in pixels.
            kernel_size: spatial size of every steerable kernel.
            group: 'D' for the dihedral group D_N, 'C' for the cyclic group C_N.

        Raises:
            AssertionError: if an option is outside its allowed values.
            NotImplementedError: for an unknown aggregation mode when the
                assertions are disabled (python -O).
        """
        super(DNSteerableAGRadGalNet, self).__init__()
        aggregation_mode = attention_aggregation
        normalisation = attention_normalisation
        AG = int(attention_gates)
        N = int(number_rotations)
        kernel_size = int(kernel_size)
        imsize = int(imsize)
        n_classes = int(n_classes)
        assert aggregation_mode in ['concat', 'mean', 'deep_sup', 'ft'], 'Aggregation mode not recognised. Valid inputs include concat, mean, deep_sup or ft.'
        assert normalisation in ['sigmoid','range_norm','std_mean_norm','tanh','softmax'], 'Normalisation not implemented. Can be any of: sigmoid, range_norm, std_mean_norm, tanh, softmax'
        assert AG in [0,1,2,3], f'Number of Attention Gates applied (AG) must be an integer in range [0,3]. Currently AG={AG}'
        assert group.lower() in ["d","c"], f"group parameter must either be 'D' for DN, or 'C' for CN, steerable networks. (currently {group})."
        filters = [6,16,32,64,128]

        self.attention_out_sizes = []
        self.ag = AG
        self.n_classes = n_classes
        self.filters = filters
        self.aggregation_mode = aggregation_mode

        # Symmetry group: dihedral D_N ('D') or cyclic C_N ('C').
        if group.lower() == "d":
            self.r2_act = gspaces.FlipRot2dOnR2(N=N)
        else:
            self.r2_act = gspaces.Rot2dOnR2(N=N)
        in_type = e2nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
        out_type = e2nn.FieldType(self.r2_act, 6*[self.r2_act.regular_repr])
        self.in_type = in_type

        self.mask = e2nn.MaskModule(in_type, imsize, margin=0)

        def conv(src, dst):
            # Stride-1, bias-free steerable conv; padding kernel_size//2 keeps
            # H x W for odd kernel sizes.
            return e2nn.R2Conv(src, dst, kernel_size=kernel_size,
                               padding=kernel_size//2, stride=1, bias=False)

        # Stage 1: 1 -> 6 regular fields.
        self.conv1a = conv(in_type, out_type)
        self.relu1a = e2nn.ReLU(out_type)
        self.bnorm1a = e2nn.InnerBatchNorm(out_type)
        self.conv1b = conv(out_type, out_type)
        self.relu1b = e2nn.ReLU(out_type)
        self.bnorm1b = e2nn.InnerBatchNorm(out_type)
        self.conv1c = conv(out_type, out_type)
        self.relu1c = e2nn.ReLU(out_type)
        self.bnorm1c = e2nn.InnerBatchNorm(out_type)
        self.mpool1 = e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2)
        self.gpool1 = e2nn.GroupPooling(out_type)

        # Stage 2: 6 -> 16 regular fields.
        in_type = out_type
        out_type = e2nn.FieldType(self.r2_act, 16*[self.r2_act.regular_repr])
        self.conv2a = conv(in_type, out_type)
        self.relu2a = e2nn.ReLU(out_type)
        self.bnorm2a = e2nn.InnerBatchNorm(out_type)
        self.conv2b = conv(out_type, out_type)
        self.relu2b = e2nn.ReLU(out_type)
        self.bnorm2b = e2nn.InnerBatchNorm(out_type)
        self.conv2c = conv(out_type, out_type)
        self.relu2c = e2nn.ReLU(out_type)
        self.bnorm2c = e2nn.InnerBatchNorm(out_type)
        self.mpool2 = e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2)
        self.gpool2 = e2nn.GroupPooling(out_type)

        # Stage 3: 16 -> 32 regular fields.
        in_type = out_type
        out_type = e2nn.FieldType(self.r2_act, 32*[self.r2_act.regular_repr])
        self.conv3a = conv(in_type, out_type)
        self.relu3a = e2nn.ReLU(out_type)
        self.bnorm3a = e2nn.InnerBatchNorm(out_type)
        self.conv3b = conv(out_type, out_type)
        self.relu3b = e2nn.ReLU(out_type)
        self.bnorm3b = e2nn.InnerBatchNorm(out_type)
        self.conv3c = conv(out_type, out_type)
        self.relu3c = e2nn.ReLU(out_type)
        self.bnorm3c = e2nn.InnerBatchNorm(out_type)
        self.mpool3 = e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2)
        self.gpool3 = e2nn.GroupPooling(out_type)

        # Stage 4: 32 -> 64 regular fields (two conv triplets).
        in_type = out_type
        out_type = e2nn.FieldType(self.r2_act, 64*[self.r2_act.regular_repr])
        self.conv4a = conv(in_type, out_type)
        self.relu4a = e2nn.ReLU(out_type)
        self.bnorm4a = e2nn.InnerBatchNorm(out_type)
        self.conv4b = conv(out_type, out_type)
        self.relu4b = e2nn.ReLU(out_type)
        self.bnorm4b = e2nn.InnerBatchNorm(out_type)
        self.mpool4 = e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2)
        self.gpool4 = e2nn.GroupPooling(out_type)

        self.flatten = nn.Flatten(1)
        self.dropout = nn.Dropout(p=0.5)

        # Attention gates on the stage-3, stage-2 and stage-1 features, at
        # the matching spatial resolutions.
        if self.ag >= 1:
            self.attention1 = GridAttentionBlock2D(in_channels=32, gating_channels=64, inter_channels=64, input_size=[imsize//4,imsize//4], normalisation=normalisation)
        if self.ag >= 2:
            self.attention2 = GridAttentionBlock2D(in_channels=16, gating_channels=64, inter_channels=64, input_size=[imsize//2,imsize//2], normalisation=normalisation)
        if self.ag >= 3:
            self.attention3 = GridAttentionBlock2D(in_channels=6, gating_channels=64, inter_channels=64, input_size=[imsize,imsize], normalisation=normalisation)

        # NOTE(review): 16*5*5 is hard-coded independently of imsize - confirm
        # how fc1 is used in forward.
        self.fc1 = nn.Linear(16*5*5,256) #channel_size * width * height
        self.fc2 = nn.Linear(256,256)
        self.fc3 = nn.Linear(256, self.n_classes)
        self.dummy = nn.Parameter(torch.empty(0))

        self.module_order = ['conv1a', 'relu1a', 'bnorm1a', #1->6
                             'conv1b', 'relu1b', 'bnorm1b', #6->6
                             'conv1c', 'relu1c', 'bnorm1c', #6->6
                             'mpool1',
                             'conv2a', 'relu2a', 'bnorm2a', #6->16
                             'conv2b', 'relu2b', 'bnorm2b', #16->16
                             'conv2c', 'relu2c', 'bnorm2c', #16->16
                             'mpool2',
                             'conv3a', 'relu3a', 'bnorm3a', #16->32
                             'conv3b', 'relu3b', 'bnorm3b', #32->32
                             'conv3c', 'relu3c', 'bnorm3c', #32->32
                             'mpool3',
                             'conv4a', 'relu4a', 'bnorm4a', #32->64
                             'conv4b', 'relu4b', 'bnorm4b', #64->64
                             'compatibility_score1',
                             'compatibility_score2']

        #########################
        # Aggregation strategies
        if self.ag != 0:
            self.attention_filter_sizes = [32, 16, 6]
            # Channel widths of the gates actually in use.
            active_sizes = self.attention_filter_sizes[:self.ag]
            concat_length = sum(active_sizes)
            if aggregation_mode == 'concat':
                self.classifier = nn.Linear(concat_length, self.n_classes)
                self.aggregate = self.aggregation_concat
            else:
                # One classifier per active gate. nn.ModuleList registers them
                # as submodules so they follow the model across devices.
                self.classifiers = nn.ModuleList(
                    nn.Linear(size, self.n_classes) for size in active_sizes)
                if aggregation_mode == 'mean':
                    self.aggregate = self.aggregation_sep
                elif aggregation_mode == 'deep_sup':
                    self.classifier = nn.Linear(concat_length, self.n_classes)
                    self.aggregate = self.aggregation_ds
                elif aggregation_mode == 'ft':
                    self.classifier = nn.Linear(self.n_classes*self.ag, self.n_classes)
                    self.aggregate = self.aggregation_ft
                else:
                    raise NotImplementedError
        else:
            # Four 2x2/stride-2 max-pools reduce the side length to
            # imsize // 16, and stage 4 has 64 regular fields. Previously this
            # was hard-coded for imsize=150.
            # NOTE(review): presumes forward feeds gpool4(mpool4(...)) here.
            self.classifier = nn.Linear((imsize//16)**2*64, self.n_classes)
            # NOTE(review): a lambda attribute prevents pickling the whole
            # module (e.g. torch.save(model)); saving state_dict still works.
            self.aggregate = lambda x: self.classifier(self.flatten(x))
    def __init__(self):
        """Build a C8-equivariant U-Net with one encoder and three decoders.

        The encoder is conv1/conv2 plus down1-3 with conv12/conv22/conv32 and
        is shared. Decoder k (k = 1, 2, 3) consists of up4k, up3k, up2k, up1k,
        followed by gpoolk.
        """
        super(UNet, self).__init__()

        self.r2_act = gspaces.Rot2dOnR2(N=8)
        regular = self.r2_act.regular_repr

        # field_type_3 is the RGB input (three trivial fields); every other
        # field_type_<n> is n regular fields of C8.
        self.field_type_1 = enn.FieldType(self.r2_act, 1 * [regular])
        self.field_type_3 = enn.FieldType(self.r2_act,
                                          3 * [self.r2_act.trivial_repr])
        for width in (8, 16, 32, 64, 128):
            setattr(self, f'field_type_{width}',
                    enn.FieldType(self.r2_act, width * [regular]))

        t1 = self.field_type_1
        t8 = self.field_type_8
        t16 = self.field_type_16
        t32 = self.field_type_32
        t64 = self.field_type_64

        # Shared encoder.
        self.conv1 = Conv(in_type=self.field_type_3, out_type=t8)
        self.conv2 = Conv(in_type=t8, out_type=t8)
        self.down1 = DownSample(in_type=t8, out_type=t16)
        self.conv12 = Conv(in_type=t16, out_type=t16)
        self.down2 = DownSample(in_type=t16, out_type=t32)
        self.conv22 = Conv(in_type=t32, out_type=t32)
        self.down3 = DownSample(in_type=t32, out_type=t32)
        self.conv32 = Conv(in_type=t32, out_type=t32)

        # Three structurally identical decoders, registered in the order
        # up4k, up3k, up2k, up1k for k = 1, 2, 3. The doubled in_type widths
        # presumably hold skip-connection concatenations - confirm in forward.
        for k in (1, 2, 3):
            setattr(self, f'up4{k}',
                    UpSample(in_type=t64, out_type=t32, mid_type=t32))
            setattr(self, f'up3{k}',
                    UpSample(in_type=t64, out_type=t16, mid_type=t32))
            setattr(self, f'up2{k}',
                    UpSample(in_type=t32, out_type=t8, mid_type=t16))
            setattr(self, f'up1{k}',
                    LastConcat(in_type=t16, out_type=t1, mid_type=t8))

        # One group pooling per decoder output (single regular field in).
        for k in (1, 2, 3):
            setattr(self, f'gpool{k}', enn.GroupPooling(t1))
示例#12
0
    def __init__(self, n_classes=10):
        """C8-equivariant six-block CNN with an invariant fully connected head.

        Args:
            n_classes: number of output logits.
        """
        super(C8SteerableCNN, self).__init__()

        # Equivariance to rotations by multiples of 45 degrees (C8).
        self.r2_act = gspaces.Rot2dOnR2(N=8)

        def regular_fields(count):
            # `count` copies of the regular representation of C8.
            return nn.FieldType(self.r2_act, count * [self.r2_act.regular_repr])

        def conv_block(src, dst, kernel_size, padding):
            # Steerable convolution -> inner batch norm -> ReLU.
            return nn.SequentialModule(
                nn.R2Conv(src, dst, kernel_size=kernel_size, padding=padding,
                          bias=False),
                nn.InnerBatchNorm(dst),
                nn.ReLU(dst, inplace=True))

        # Input: a single scalar (trivial) field. Stored so forward can wrap
        # raw tensors as geometric tensors.
        self.input_type = nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])

        # Block 1: 1 -> 24 regular fields, 7x7 kernel.
        self.block1 = conv_block(self.input_type, regular_fields(24), 7, 1)

        # Block 2: 24 -> 48 regular fields, then anti-aliased 2x down-sampling.
        fields = regular_fields(48)
        self.block2 = conv_block(self.block1.out_type, fields, 5, 2)
        self.pool1 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(fields, sigma=0.66, stride=2))

        # Block 3: 48 -> 48 regular fields.
        self.block3 = conv_block(self.block2.out_type, regular_fields(48), 5, 2)

        # Block 4: 48 -> 96 regular fields, then anti-aliased 2x down-sampling.
        fields = regular_fields(96)
        self.block4 = conv_block(self.block3.out_type, fields, 5, 2)
        self.pool2 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(fields, sigma=0.66, stride=2))

        # Block 5: 96 -> 96 regular fields.
        self.block5 = conv_block(self.block4.out_type, regular_fields(96), 5, 2)

        # Block 6: 96 -> 64 regular fields, 4x4 average pooling, then group
        # pooling to 64 rotation-invariant channels.
        fields = regular_fields(64)
        self.block6 = conv_block(self.block5.out_type, fields, 5, 1)
        self.pool3 = nn.PointwiseAvgPool(fields, kernel_size=4)
        self.gpool = nn.GroupPooling(fields)

        n_invariant = self.gpool.out_type.size

        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(n_invariant, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )
    def __init__(self, input_size, in_type, field_type, out_fiber,
                 activation_fn, hidden_size, group_action_type, *,
                 ngf=16, n_out_fields=128):
        """Build an equivariant conv encoder plus a transposed-conv decoder.

        The encoder is three equivariant blocks (``R2Conv`` ->
        ``InnerBatchNorm`` -> ``activation_fn``) interleaved with anti-aliased
        average pooling, followed by ``GroupPooling`` over the
        ``n_out_fields`` regular output fields. ``self.gen`` is a plain
        (non-equivariant) stack of ``torch.nn.ConvTranspose2d`` layers mapping
        ``self.gc`` channels back to ``self.c`` channels, finished by ``Tanh``.

        Args:
            input_size: 4-element sequence unpacked as ``(_, c, h, w)``;
                presumably ``(batch, channels, height, width)`` -- confirm with
                callers. ``c``, ``h`` and ``w`` are stored on ``self``.
            in_type: NOTE(review): not used anywhere in this constructor; kept
                for interface compatibility.
            field_type: forwarded unchanged to ``FIBERS[out_fiber]``; its
                meaning is defined by that factory.
            out_fiber: key into ``FIBERS`` selecting the factory that builds
                the hidden field type.
            activation_fn: callable ``(field_type, inplace=True) -> module``
                applied after every batch-norm layer of the encoder.
            hidden_size: forwarded to ``FIBERS[out_fiber]`` (second argument).
            group_action_type: the gspace; must expose ``trivial_repr`` and
                ``regular_repr``.
            ngf: base channel width of the decoder ``self.gen`` (default 16,
                the previously hard-coded value). Must be >= 2.
            n_out_fields: number of regular fields produced by ``block3``
                (default 128, the previously hard-coded value).

        Raises:
            ValueError: if ``ngf < 2`` (the decoder would get a layer with
                ``ngf // 2 == 0`` channels).
        """
        super(InvariantCNNBlock, self).__init__()
        if ngf < 2:
            raise ValueError(f"ngf must be >= 2, got {ngf}")
        _, self.c, self.h, self.w = input_size
        self.group_action_type = group_action_type

        # Input: one trivial (scalar) field per input channel.
        feat_type_in = enn.FieldType(
            self.group_action_type,
            self.c * [self.group_action_type.trivial_repr])
        # Hidden field type is delegated to the selected fiber factory.
        feat_type_hid = FIBERS[out_fiber](group_action_type,
                                          hidden_size,
                                          field_type,
                                          fixparams=True)
        feat_type_out = enn.FieldType(
            self.group_action_type,
            n_out_fields * [self.group_action_type.regular_repr])

        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = feat_type_in

        # 5x5 conv with no padding: spatial size shrinks by 4.
        self.block1 = enn.SequentialModule(
            enn.R2Conv(feat_type_in, feat_type_hid, kernel_size=5, padding=0),
            enn.InnerBatchNorm(feat_type_hid),
            activation_fn(feat_type_hid, inplace=True),
        )

        self.pool1 = enn.SequentialModule(
            enn.PointwiseAvgPoolAntialiased(feat_type_hid,
                                            sigma=0.66,
                                            stride=2))

        # No explicit padding here either (R2Conv's default is used).
        self.block2 = enn.SequentialModule(
            enn.R2Conv(feat_type_hid, feat_type_hid, kernel_size=5),
            enn.InnerBatchNorm(feat_type_hid),
            activation_fn(feat_type_hid, inplace=True),
        )

        self.pool2 = enn.SequentialModule(
            enn.PointwiseAvgPoolAntialiased(feat_type_hid,
                                            sigma=0.66,
                                            stride=2))

        # 3x3 conv with padding 1 keeps the spatial size.
        self.block3 = enn.SequentialModule(
            enn.R2Conv(feat_type_hid, feat_type_out, kernel_size=3, padding=1),
            enn.InnerBatchNorm(feat_type_out),
            activation_fn(feat_type_out, inplace=True),
        )

        self.pool3 = enn.PointwiseAvgPoolAntialiased(feat_type_out,
                                                     sigma=0.66,
                                                     stride=1,
                                                     padding=0)

        # Pool over the group dimension of every output field.
        self.gpool = enn.GroupPooling(feat_type_out)

        # Number of plain channels left after group pooling; decoder input width.
        self.gc = self.gpool.out_type.size

        half_ngf = ngf // 2
        self.gen = torch.nn.Sequential(
            # kernel 4, stride 1, padding 0: height/width grow by 3
            # (e.g. a 1x1 map becomes 4x4).
            torch.nn.ConvTranspose2d(self.gc, ngf,
                                     kernel_size=4, stride=1, padding=0),
            torch.nn.BatchNorm2d(ngf),
            torch.nn.LeakyReLU(0.2),
            # kernel 4, stride 2, padding 1: height/width exactly double.
            torch.nn.ConvTranspose2d(ngf, ngf,
                                     kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(ngf),
            torch.nn.LeakyReLU(0.2),
            torch.nn.ConvTranspose2d(ngf, half_ngf,
                                     kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(half_ngf),
            torch.nn.LeakyReLU(0.2),
            # Final layer maps back to the input channel count; Tanh bounds
            # outputs to (-1, 1).
            torch.nn.ConvTranspose2d(half_ngf, self.c,
                                     kernel_size=4, stride=2, padding=1),
            torch.nn.Tanh(),
        )