예제 #1
0
 def __init__(self, input_channels, output_channels, kernel_size, stride, N, activation = True, deconv = False, last_deconv = False):
     """Build a single C_N-equivariant convolution layer stored in ``self.layer``.

     Args:
         input_channels: number of regular-representation fields in the input.
         output_channels: number of output fields (regular fields, except for the
             ``deconv`` + ``last_deconv`` case, which uses ``irrep(1)`` fields).
         kernel_size: spatial size of the R2Conv filter.
         stride: convolution stride.
         N: number of discrete rotations of the symmetry group.
         activation: on the non-deconv path, append InnerBatchNorm + ReLU.
         deconv: use zero padding instead of ``(kernel_size - 1) // 2``.
         last_deconv: together with ``deconv``, build a bare conv onto ``irrep(1)`` fields.
     """
     super(rot_conv2d, self).__init__()
     gspace = gspaces.Rot2dOnR2(N = N)
     regular = gspace.regular_repr
     type_in = nn.FieldType(gspace, input_channels * [regular])

     if deconv and last_deconv:
         # Final layer: plain conv onto irrep(1) fields, no normalisation or activation.
         type_vec = nn.FieldType(gspace, output_channels * [gspace.irrep(1)])
         self.layer = nn.R2Conv(type_in, type_vec, kernel_size = kernel_size, stride = stride, padding = 0)
         return

     type_out = nn.FieldType(gspace, output_channels * [regular])
     pad = 0 if deconv else (kernel_size - 1)//2
     conv = nn.R2Conv(type_in, type_out, kernel_size = kernel_size, stride = stride, padding = pad)

     # NOTE(review): on the deconv path BN + ReLU are always appended, regardless
     # of ``activation`` — confirm this is intended.
     if deconv or activation:
         self.layer = nn.SequentialModule(conv, nn.InnerBatchNorm(type_out), nn.ReLU(type_out))
     else:
         self.layer = conv
예제 #2
0
 def __init__(self, input_channels, hidden_dim, kernel_size, N):
     """Build the sub-layers of a C_N-equivariant residual block.

     Args:
         input_channels: number of regular fields in the input.
         hidden_dim: number of regular fields produced by every sub-layer.
         kernel_size: spatial kernel size; padding is ``(kernel_size - 1) // 2``.
         N: group size (number of discrete rotations).
     """
     super(rot_resblock, self).__init__()

     # symmetry group: N discrete rotations of the plane
     gspace = gspaces.Rot2dOnR2(N = N)
     field_in = nn.FieldType(gspace, input_channels * [gspace.regular_repr])
     field_hid = nn.FieldType(gspace, hidden_dim * [gspace.regular_repr])
     same_pad = (kernel_size - 1) // 2

     def conv_bn_relu(src, dst):
         # R2Conv -> InnerBatchNorm -> ReLU between the two given field types
         return nn.SequentialModule(
             nn.R2Conv(src, dst, kernel_size = kernel_size, padding = same_pad),
             nn.InnerBatchNorm(dst),
             nn.ReLU(dst),
         )

     self.layer1 = conv_bn_relu(field_in, field_hid)
     self.layer2 = conv_bn_relu(field_hid, field_hid)
     # maps the input onto the hidden field type (same structure as layer1)
     self.upscale = conv_bn_relu(field_in, field_hid)

     self.input_channels = input_channels
     self.hidden_dim = hidden_dim
예제 #3
0
 def __init__(self, in_type, inner_type, out_type, stride=1):
     """Create BN/ReLU/conv modules for an equivariant bottleneck unit.

     Modules are created in the order BN -> ReLU -> 1x1 conv (on ``in_type``),
     then BN -> ReLU -> 3x3 conv (on ``inner_type``); the forward pass is not
     shown here. ``stride`` is accepted but not used by this constructor.
     """
     super(BottleneckBlock, self).__init__()

     # first unit: normalise and activate the input, then 1x1 conv into inner_type
     self.bn1, self.relu1 = enn.InnerBatchNorm(in_type), enn.ReLU(in_type, inplace=True)
     self.conv1 = conv1x1(in_type, inner_type)

     # second unit: normalise and activate inner_type, then 3x3 conv into out_type
     self.bn2, self.relu2 = enn.InnerBatchNorm(inner_type), enn.ReLU(inner_type, inplace=True)
     self.conv2 = conv3x3(inner_type, out_type)
예제 #4
0
    def __init__(self, n_classes=6):
        """C4-equivariant steerable CNN followed by a fully connected classifier.

        Args:
            n_classes: number of logits produced by ``fully_net``.
        """
        super(SteerCNN, self).__init__()

        # the model is equivariant under rotations by 90 degrees, modelled by C4 (N=4)
        self.r2_act = gspaces.Rot2dOnR2(N=4)

        # number of regular feature fields used by both convolution blocks
        n_fields = 24

        # the input consists of 3 scalar fields (trivial representation),
        # presumably the RGB channels — TODO confirm against the data loader
        input_type = nn_e2.FieldType(self.r2_act,
                                     3 * [self.r2_act.trivial_repr])

        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = input_type

        # convolution 1: n_fields feature fields, each transforming under the regular representation of C4
        out_type = nn_e2.FieldType(self.r2_act,
                                   n_fields * [self.r2_act.regular_repr])
        self.block1 = nn_e2.SequentialModule(
            nn_e2.R2Conv(input_type,
                         out_type,
                         kernel_size=7,
                         padding=3,
                         bias=False), nn_e2.InnerBatchNorm(out_type),
            nn_e2.ReLU(out_type, inplace=True))

        self.pool1 = nn_e2.PointwiseAvgPool(out_type, 4)

        # convolution 2: the previous output type is the input type; the output
        # keeps the same n_fields regular fields of C4
        in_type = self.block1.out_type
        self.block2 = nn_e2.SequentialModule(
            nn_e2.R2Conv(in_type,
                         out_type,
                         kernel_size=7,
                         padding=3,
                         bias=False), nn_e2.InnerBatchNorm(out_type),
            nn_e2.ReLU(out_type, inplace=True))

        # anti-aliased smoothing (stride 1), average pooling, then pooling over the group
        self.pool2 = nn_e2.SequentialModule(
            nn_e2.PointwiseAvgPoolAntialiased(out_type,
                                              sigma=0.66,
                                              stride=1,
                                              padding=0),
            nn_e2.PointwiseAvgPool(out_type, 4), nn_e2.GroupPooling(out_type))

        # number of flattened input features of the classifier: one invariant channel
        # per field after GroupPooling times the spatial size. The 13x13 spatial size
        # depends on the input resolution — TODO confirm it matches the data.
        c = n_fields * 13 * 13

        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(c, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )
예제 #5
0
    def __init__(self, channel_in=3, n_classes=4, rot_n=4):
        """Small C_N-equivariant CNN with a group-pooled (invariant) head.

        Args:
            channel_in: number of trivial (scalar) input fields.
            n_classes: number of logits produced by ``lin_out``.
            rot_n: number of discrete rotations ``N`` of the symmetry group.
        """
        super(SmallE2, self).__init__()

        gspace = gspaces.Rot2dOnR2(N=rot_n)
        regular = gspace.regular_repr

        self.feat_type_in = nn.FieldType(gspace, channel_in * [gspace.trivial_repr])
        hidden = nn.FieldType(gspace, 8 * [regular])
        output = nn.FieldType(gspace, 2 * [regular])

        # a single BN and a single ReLU instance, both built on the hidden type
        self.bn = nn.InnerBatchNorm(hidden)
        self.relu = nn.ReLU(hidden)

        self.convin = nn.R2Conv(self.feat_type_in, hidden, kernel_size=3)
        self.convhid = nn.R2Conv(hidden, hidden, kernel_size=3)
        self.convout = nn.R2Conv(hidden, output, kernel_size=3)

        self.avgpool = nn.PointwiseAvgPool(output, 3)
        self.invariant_map = nn.GroupPooling(output)

        # classifier width follows the size of the group-pooled field type
        self.lin_in = torch.nn.Linear(self.invariant_map.out_type.size, 64)
        self.elu = torch.nn.ELU()
        self.lin_out = torch.nn.Linear(64, n_classes)
예제 #6
0
    def __init__(self, input_frames, output_frames, kernel_size, N):
        """C_N-equivariant U-Net operating on ``irrep(1)`` fields.

        Args:
            input_frames: number of ``irrep(1)`` fields in the input.
            output_frames: number of ``irrep(1)`` fields produced by ``output_layer``.
            kernel_size: spatial kernel size used throughout.
            N: number of discrete rotations of the symmetry group.
        """
        super(Unet_Rot, self).__init__()
        gspace = gspaces.Rot2dOnR2(N = N)
        vec = gspace.irrep(1)
        same_pad = (kernel_size - 1) // 2

        self.feat_type_in = nn.FieldType(gspace, input_frames * [vec])
        self.feat_type_in_hid = nn.FieldType(gspace, 32 * [gspace.regular_repr])
        # 16 fields plus the input fields — presumably deconv1's output concatenated
        # with the network input in forward (not shown)
        self.feat_type_hid_out = nn.FieldType(gspace, (16 + input_frames) * [vec])
        self.feat_type_out = nn.FieldType(gspace, output_frames * [vec])

        # stem: stride-2 equivariant conv into 32 regular fields
        self.conv1 = nn.SequentialModule(
            nn.R2Conv(self.feat_type_in, self.feat_type_in_hid, kernel_size = kernel_size, stride = 2, padding = same_pad),
            nn.InnerBatchNorm(self.feat_type_in_hid),
            nn.ReLU(self.feat_type_in_hid)
        )

        # encoder layers: (attribute name, input fields, output fields, stride)
        encoder = (
            ("conv2", 32, 64, 1),
            ("conv2_1", 64, 64, 1),
            ("conv3", 64, 128, 2),
            ("conv3_1", 128, 128, 1),
            ("conv4", 128, 256, 2),
            ("conv4_1", 256, 256, 1),
        )
        for attr, fields_in, fields_out, step in encoder:
            setattr(self, attr, rot_conv2d(fields_in, fields_out, kernel_size = kernel_size, stride = step, N = N))

        # decoder; input widths presumably include skip connections concatenated in forward (not shown)
        self.deconv3 = rot_deconv2d(256, 64, N)
        self.deconv2 = rot_deconv2d(192, 32, N)
        self.deconv1 = rot_deconv2d(96, 16, N, last_deconv = True)

        self.output_layer = nn.R2Conv(self.feat_type_hid_out, self.feat_type_out, kernel_size = kernel_size, padding = same_pad)
예제 #7
0
    def __init__(
        self,
        in_type: enn.FieldType,
        inner_type: enn.FieldType,
        dropout_rate: float,
        stride: int = 1,
        out_type: enn.FieldType = None,
    ):
        """Create the modules of an equivariant wide residual block.

        Modules are created in the order BN -> ReLU -> conv, BN -> ReLU ->
        dropout -> conv; the forward pass is not shown here.

        Args:
            in_type: field type of the block input.
            inner_type: field type between the two convolutions.
            dropout_rate: probability for the pointwise dropout on ``inner_type``.
            stride: stride of the second convolution and of the shortcut.
            out_type: field type of the block output; defaults to ``in_type``.
        """
        super(WideBasic, self).__init__()

        if out_type is None:
            out_type = in_type

        self.in_type = in_type
        self.out_type = out_type

        # number of discrete rotations in the group; 0 for any other gspace
        if isinstance(in_type.gspace, gspaces.FlipRot2dOnR2):
            rotations = in_type.gspace.fibergroup.rotation_order
        elif isinstance(in_type.gspace, gspaces.Rot2dOnR2):
            rotations = in_type.gspace.fibergroup.order()
        else:
            rotations = 0

        # 3x3 kernels for 0, 2 or 4 rotations; 5x5 kernels otherwise
        conv = conv3x3 if rotations in (0, 2, 4) else conv5x5

        self.bn1 = enn.InnerBatchNorm(self.in_type)
        self.relu1 = enn.ReLU(self.in_type, inplace=True)
        self.conv1 = conv(self.in_type, inner_type)

        self.bn2 = enn.InnerBatchNorm(inner_type)
        self.relu2 = enn.ReLU(inner_type, inplace=True)

        self.dropout = enn.PointwiseDropout(inner_type, p=dropout_rate)

        self.conv2 = conv(inner_type, self.out_type, stride=stride)

        # 1x1 projection only when the spatial size or the field type changes
        self.shortcut = None
        if stride != 1 or self.in_type != self.out_type:
            self.shortcut = conv1x1(self.in_type,
                                    self.out_type,
                                    stride=stride,
                                    bias=False)
예제 #8
0
 def __init__(self, in_type, out_type):
     """Create BN, in-place ReLU (on ``in_type``) and a 3x3 conv ``in_type -> out_type``."""
     super(BasicBlock, self).__init__()
     self.in_type, self.out_type = in_type, out_type

     # normalisation and activation act on the input field type
     self.bn1 = enn.InnerBatchNorm(in_type)
     self.relu1 = enn.ReLU(in_type, inplace=True)
     self.conv1 = conv3x3(in_type, out_type)
예제 #9
0
 def __init__(self, in_type, out_type, gspace):
     """Create BN, ReLU, 1x1 conv and 2x2 average pooling on regular field types.

     Args:
         in_type: input width, converted via ``FIELD_TYPE["regular"]``.
         out_type: output width, converted the same way.
         gspace: gspace the field types are defined on.
     """
     super(TransitionBlock, self).__init__()
     self.gspace = gspace
     make_regular = FIELD_TYPE["regular"]
     self.in_type = make_regular(gspace, in_type, fixparams=False)
     self.out_type = make_regular(gspace, out_type, fixparams=False)

     self.bn1 = enn.InnerBatchNorm(self.in_type)
     self.relu1 = enn.ReLU(self.in_type, inplace=True)
     self.conv1 = conv1x1(self.in_type, self.out_type)
     self.avgpool = enn.PointwiseAvgPool(self.out_type, kernel_size=2)
예제 #10
0
 def __init__(self, in_type,inner_type, out_type, stride=1):
     """Create a 1x1 -> 3x3 -> 1x1 equivariant residual block, each conv followed by BN and ReLU.

     A 1x1 ``R2Conv`` shortcut is created when ``stride != 1`` or the input and
     output field types differ; otherwise ``self.shortcut`` is ``None``.
     """
     super(ResBlock, self).__init__()

     self.in_type = in_type
     self.inner_type = inner_type
     self.out_type = out_type

     # 1x1 conv into inner_type
     self.conv1 = conv1x1(in_type, inner_type, stride = 1, bias = False)
     self.bn1 = enn.InnerBatchNorm(inner_type)
     self.relu1 = enn.ReLU(inner_type)  # not in-place, unlike relu2 / relu3

     # 3x3 conv carrying the block's stride
     self.conv2 = conv3x3(inner_type, inner_type, padding=1, stride = stride, bias = False)
     self.bn2 = enn.InnerBatchNorm(inner_type)
     self.relu2 = enn.ReLU(inner_type, inplace=True)

     # 1x1 conv into out_type
     self.conv3 = conv1x1(inner_type, out_type, stride = 1, bias = False)
     self.bn3 = enn.InnerBatchNorm(out_type)
     self.relu3 = enn.ReLU(out_type, inplace=True)

     needs_projection = stride != 1 or in_type != out_type
     self.shortcut = (
         enn.R2Conv(in_type, out_type, kernel_size=1, stride=stride, bias=False)
         if needs_projection
         else None
     )
예제 #11
0
 def __init__(self, nclasses=1):
     """Build a C8-equivariant ResNet-50-style network.

     Args:
         nclasses: output size of the final linear layer.
     """
     super(ResNet50, self).__init__()
     self.gspace = gspaces.Rot2dOnR2(N=8)

     # regular field types keyed by width, shared by every block that uses them
     fields = {
         width: FIELD_TYPE["regular"](self.gspace, width, fixparams=False)
         for width in (64, 256, 128, 512, 1024, 2048)
     }

     stem = fields[64]
     self.conv1 = enn.R2Conv(FIELD_TYPE["trivial"](self.gspace, 3, fixparams=False),
                             stem, kernel_size=7, stride=2, padding=3)
     self.bn1 = enn.InnerBatchNorm(stem)
     self.relu1 = enn.ELU(stem)
     self.maxpool1 = enn.PointwiseMaxPoolAntialiased(stem, kernel_size=2)

     def make_stage(width_in, width_inner, width_out, n_blocks):
         # first block: stride 2 and width change; the rest keep width_out at stride 1
         blocks = [ResBlock(stride=2, in_type=fields[width_in],
                            inner_type=fields[width_inner], out_type=fields[width_out])]
         for _ in range(n_blocks - 1):
             blocks.append(ResBlock(stride=1, in_type=fields[width_out],
                                    inner_type=fields[width_inner], out_type=fields[width_out]))
         return torch.nn.Sequential(*blocks)

     self.layer1 = make_stage(64, 64, 256, 3)
     self.layer2 = make_stage(256, 128, 512, 4)
     self.layer3 = make_stage(512, 256, 1024, 6)
     self.layer4 = make_stage(1024, 512, 2048, 3)

     self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
     self.fc = torch.nn.Linear(2048, nclasses)
예제 #12
0
 def __init__(self, growth_rate, list_layer, nclasses):
     """Build a C8-equivariant DenseNet with four dense blocks.

     Args:
         growth_rate: width added by each layer of a dense block; the stem width is ``2 * growth_rate``.
         list_layer: number of layers in each of the four dense blocks.
         nclasses: output size of the classifier.
     """
     super(DenseNet161, self).__init__()

     self.gspace = gspaces.Rot2dOnR2(N=8)

     def regular(width):
         # regular field type of the given width on this model's gspace
         return FIELD_TYPE["regular"](self.gspace, width, fixparams=False)

     width = 2*growth_rate

     self.conv1 = conv7x7(FIELD_TYPE["trivial"](self.gspace, 3, fixparams=False), regular(width))
     self.pool1 = enn.PointwiseMaxPool(regular(width), kernel_size=2, stride=2)

     # dense blocks 1..4; a transition block halves the width after each of the first three
     for stage in range(1, 5):
         n_layers = list_layer[stage - 1]
         setattr(self, f"block{stage}", DenseBlock(width, growth_rate, self.gspace, n_layers))
         width = width + n_layers*growth_rate
         if stage < 4:
             halved = int(width/2)
             setattr(self, f"trans{stage}", TransitionBlock(width, halved, self.gspace))
             width = halved

     head = regular(width)
     self.bn = enn.InnerBatchNorm(head)
     self.relu = enn.ReLU(head, inplace=True)
     self.pool2 = torch.nn.AdaptiveAvgPool2d((1, 1))
     self.classifier = torch.nn.Linear(width, nclasses)
예제 #13
0
 def __init__(self, input_frames, output_frames, kernel_size, N):
     """Build a C_N-equivariant residual network mapping irrep(1) fields to irrep(1) fields.

     Args:
         input_frames: number of ``irrep(1)`` input fields.
         output_frames: number of ``irrep(1)`` output fields.
         kernel_size: spatial kernel size; padding is ``(kernel_size - 1) // 2``.
         N: number of discrete rotations of the symmetry group.
     """
     super(ResNet_Rot, self).__init__()
     gspace = gspaces.Rot2dOnR2(N = N)
     same_pad = (kernel_size - 1) // 2

     # rho_1 fields at the input/output (original note: the input is velocity fields),
     # regular fields in the hidden layers
     self.feat_type_in = nn.FieldType(gspace, input_frames * [gspace.irrep(1)])
     self.feat_type_in_hid = nn.FieldType(gspace, 16 * [gspace.regular_repr])
     self.feat_type_hid_out = nn.FieldType(gspace, 192 * [gspace.regular_repr])
     self.feat_type_out = nn.FieldType(gspace, output_frames * [gspace.irrep(1)])

     self.input_layer = nn.SequentialModule(
         nn.R2Conv(self.feat_type_in, self.feat_type_in_hid, kernel_size = kernel_size, padding = same_pad),
         nn.InnerBatchNorm(self.feat_type_in_hid),
         nn.ReLU(self.feat_type_in_hid)
     )

     # each stage: one resblock changing the width, then one keeping it
     widths = (16, 32, 64, 128, 192)
     layers = [self.input_layer]
     for w_in, w_out in zip(widths, widths[1:]):
         layers.append(rot_resblock(w_in, w_out, kernel_size, N))
         layers.append(rot_resblock(w_out, w_out, kernel_size, N))
     layers.append(nn.R2Conv(self.feat_type_hid_out, self.feat_type_out, kernel_size = kernel_size, padding = same_pad))
     self.model = torch.nn.Sequential(*layers)
예제 #14
0
    def __init__(self, depth, widen_factor, dropout_rate, num_classes=100,
                 N: int = 8,
                 r: int = 1,
                 f: bool = True,
                 deltaorth: bool = False,
                 fixparams: bool = True,
                 initial_stride: int = 1,
                 ):
        r"""Build an equivariant Wide ResNet.

        The parameter ``N`` controls rotation equivariance and the parameter ``f`` reflection equivariance.

        More precisely, ``N`` is the number of discrete rotations the model is initially equivariant to.
        ``N = 1`` means the model is only reflection equivariant from the beginning (or not equivariant
        at all when ``f`` is ``False``).

        ``f`` is a boolean flag specifying whether the model should be reflection equivariant or not.
        If it is ``False``, the model is not reflection equivariant.

        ``r`` is the restriction level:

        - ``0``: no restriction. The model is equivariant to ``N`` rotations from the input to the output

        - ``1``: restriction before the last block. The model is equivariant to ``N`` rotations before the last block
               (i.e. in the first 2 blocks). Then it is restricted to ``N/2`` rotations until the output.

        - ``2``: restriction after the first block. The model is equivariant to ``N`` rotations in the first block.
               Then it is restricted to ``N/2`` rotations until the output (i.e. in the last 3 blocks).

        - ``3``: restriction after the first and the second block. The model is equivariant to ``N`` rotations in the first
               block. It is restricted to ``N/2`` rotations before the second block and to ``1`` rotations before the last
               block.

        NOTICE: if restriction to ``N/2`` is performed, ``N`` needs to be even!

        Other args:
            depth: network depth; must satisfy ``depth = 6n + 4``.
            widen_factor: width multiplier ``k`` for the stage widths ``[16, 16k, 32k, 64k]``.
            dropout_rate: dropout probability passed to ``_wide_layer``.
            num_classes: output size of the final linear layer.
            deltaorth: if ``True``, re-initialise every ``R2Conv`` with ``init.deltaorthonormal_init``.
            fixparams: stored as ``self._fixparams``.
            initial_stride: stride of the first wide layer.
        """
        super(Wide_ResNet, self).__init__()

        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        n = int((depth - 4) / 6)
        k = widen_factor

        print(f'| Wide-Resnet {depth}x{k}')

        nStages = [16, 16 * k, 32 * k, 64 * k]

        self._fixparams = fixparams

        self._layer = 0

        # number of discrete rotations to be equivariant to
        self._N = N

        # if the model is [F]lip equivariant
        self._f = f
        if self._f:
            if N != 1:
                self.gspace = gspaces.FlipRot2dOnR2(N)
            else:
                self.gspace = gspaces.Flip2dOnR2()
        else:
            if N != 1:
                self.gspace = gspaces.Rot2dOnR2(N)
            else:
                self.gspace = gspaces.TrivialOnR2()

        # level of [R]estriction:
        #   r = 0: never do restriction, i.e. initial group (either DN or CN) preserved for the whole network
        #   r = 1: restrict before the last block, i.e. initial group (either DN or CN) preserved for the first
        #          2 blocks, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the last block
        #   r = 2: restrict after the first block, i.e. initial group (either DN or CN) preserved for the first
        #          block, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the last 2 blocks
        #   r = 3: restrict after each block. Initial group (either DN or CN) preserved for the first
        #          block, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the second block and to 1 rotation
        #          in the last one (D1 or C1)
        assert r in [0, 1, 2, 3]
        self._r = r

        # subgroup identifiers passed to `_restrict_layer` (only used when r > 0):
        #   half_rotations  -> N/2 rotations (with the flip when f is True)
        #   single_rotation -> 1 rotation (with the flip when f is True)
        half_rotations = (0, N // 2) if self._f else N // 2
        single_rotation = (0, 1) if self._f else 1

        # the input has 3 color channels (RGB).
        # Color channels are trivial fields and don't transform when the input is rotated or flipped
        r1 = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)

        # input field type of the model
        self.in_type = r1

        # in the first layer we always scale up the output channels to allow for enough independent filters
        r2 = FIELD_TYPE["regular"](self.gspace, nStages[0], fixparams=True)

        # dummy attribute keeping track of the output field type of the last submodule built, i.e. the input field type of
        # the next submodule to build
        self._in_type = r2

        self.conv1 = conv5x5(r1, r2)
        self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=initial_stride)
        if self._r >= 2:
            self.restrict1 = self._restrict_layer(half_rotations)
        else:
            # an Identity module (instead of a lambda) keeps the model picklable, e.g. by torch.save(model)
            self.restrict1 = torch.nn.Identity()

        self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
        if self._r == 3:
            self.restrict2 = self._restrict_layer(single_rotation)
        elif self._r == 1:
            self.restrict2 = self._restrict_layer(half_rotations)
        else:
            self.restrict2 = torch.nn.Identity()

        # last layer maps to a trivial (invariant) feature map
        self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2, totrivial=True)

        # NOTE(review): if this momentum reaches torch BatchNorm, 0.9 means running
        # statistics are 90% the current batch — confirm this is intended.
        self.bn = enn.InnerBatchNorm(self.layer3.out_type, momentum=0.9)
        self.relu = enn.ReLU(self.bn.out_type, inplace=True)
        self.linear = torch.nn.Linear(self.bn.out_type.size, num_classes)

        for name, module in self.named_modules():
            if isinstance(module, enn.R2Conv):
                if deltaorth:
                    init.deltaorthonormal_init(module.weights, module.basisexpansion)
            elif isinstance(module, torch.nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, torch.nn.Linear):
                module.bias.data.zero_()

        print("MODEL TOPOLOGY:")
        for i, (name, mod) in enumerate(self.named_modules()):
            print(f"\t{i} - {name}")
예제 #15
0
    def __init__(self, n_classes=10):
        """C8-equivariant steerable CNN with six conv blocks and a fully connected head.

        Args:
            n_classes: number of logits produced by ``fully_net``.
        """
        super(C8SteerableCNN, self).__init__()

        # the model is equivariant under rotations by 45 degrees, modelled by C8
        self.r2_act = gspaces.Rot2dOnR2(N=8)

        # the input image is a scalar field, corresponding to the trivial representation
        in_type = nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])

        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = in_type

        def conv_block(src_type, n_fields, kernel_size, padding):
            # R2Conv (no bias) -> InnerBatchNorm -> in-place ReLU onto `n_fields` regular fields of C8
            dst_type = nn.FieldType(self.r2_act, n_fields * [self.r2_act.regular_repr])
            return nn.SequentialModule(
                nn.R2Conv(src_type, dst_type, kernel_size=kernel_size, padding=padding, bias=False),
                nn.InnerBatchNorm(dst_type),
                nn.ReLU(dst_type, inplace=True))

        # convolution 1: 24 regular feature fields of C8
        # (an input mask, nn.MaskModule(in_type, 29, margin=1), was disabled here)
        self.block1 = conv_block(in_type, 24, kernel_size=7, padding=1)

        # convolution 2: 48 regular feature fields of C8, then anti-aliased downsampling
        self.block2 = conv_block(self.block1.out_type, 48, kernel_size=5, padding=2)
        self.pool1 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(self.block2.out_type, sigma=0.66, stride=2))

        # convolution 3: 48 regular feature fields of C8
        self.block3 = conv_block(self.block2.out_type, 48, kernel_size=5, padding=2)

        # convolution 4: 96 regular feature fields of C8, then anti-aliased downsampling
        self.block4 = conv_block(self.block3.out_type, 96, kernel_size=5, padding=2)
        self.pool2 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(self.block4.out_type, sigma=0.66, stride=2))

        # convolution 5: 96 regular feature fields of C8
        self.block5 = conv_block(self.block4.out_type, 96, kernel_size=5, padding=2)

        # convolution 6: 64 regular feature fields of C8
        self.block6 = conv_block(self.block5.out_type, 64, kernel_size=5, padding=1)
        out_type = self.block6.out_type
        self.pool3 = nn.PointwiseAvgPool(out_type, kernel_size=4)

        # pooling over the group makes the features rotation invariant
        self.gpool = nn.GroupPooling(out_type)

        # number of output channels
        c = self.gpool.out_type.size

        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(c, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )
예제 #16
0
    def __init__(self,
                 base = 'DNSteerableAGRadGalNet',
                 attention_module='SelfAttention',
                 attention_gates=3,
                 attention_aggregation='ft',
                 n_classes=2,
                 attention_normalisation='sigmoid',
                 quiet=True,
                 number_rotations=8,
                 imsize=150,
                 kernel_size=3,
                 group="D"
                ):
        super(DNSteerableAGRadGalNet, self).__init__()
        aggregation_mode = attention_aggregation
        normalisation = attention_normalisation
        AG = int(attention_gates)
        N = int(number_rotations)
        kernel_size = int(kernel_size)
        imsize = int(imsize)
        n_classes = int(n_classes)
        assert aggregation_mode in ['concat', 'mean', 'deep_sup', 'ft'], 'Aggregation mode not recognised. Valid inputs include concat, mean, deep_sup or ft.'
        assert normalisation in ['sigmoid','range_norm','std_mean_norm','tanh','softmax'], f'Nomralisation not implemented. Can be any of: sigmoid, range_norm, std_mean_norm, tanh, softmax'
        assert AG in [0,1,2,3], f'Number of Attention Gates applied (AG) must be an integer in range [0,3]. Currently AG={AG}'
        assert group.lower() in ["d","c"], f"group parameter must either be 'D' for DN, or 'C' for CN, steerable networks. (currently {group})."
        filters = [6,16,32,64,128]

        self.attention_out_sizes = []
        self.ag = AG
        self.n_classes = n_classes
        self.filters = filters
        self.aggregation_mode = aggregation_mode

        # Setting up e2
        if group.lower() == "d":
            self.r2_act = gspaces.FlipRot2dOnR2(N=int(number_rotations))
        else:
            self.r2_act = gspaces.Rot2dOnR2(N=int(number_rotations))
        in_type = e2nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
        out_type = e2nn.FieldType(self.r2_act, 6*[self.r2_act.regular_repr])
        self.in_type = in_type

        self.mask = e2nn.MaskModule(in_type, imsize, margin=0)
        self.conv1a = e2nn.R2Conv(in_type,  out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu1a = e2nn.ReLU(out_type); self.bnorm1a= e2nn.InnerBatchNorm(out_type)
        self.conv1b = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu1b = e2nn.ReLU(out_type); self.bnorm1b= e2nn.InnerBatchNorm(out_type)
        self.conv1c = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu1c = e2nn.ReLU(out_type); self.bnorm1c= e2nn.InnerBatchNorm(out_type)
        self.mpool1 = e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2)
        self.gpool1 = e2nn.GroupPooling(out_type)


        in_type = out_type
        out_type = e2nn.FieldType(self.r2_act, 16*[self.r2_act.regular_repr])
        self.conv2a = e2nn.R2Conv(in_type,  out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu2a = e2nn.ReLU(out_type); self.bnorm2a= e2nn.InnerBatchNorm(out_type)
        self.conv2b = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu2b = e2nn.ReLU(out_type); self.bnorm2b= e2nn.InnerBatchNorm(out_type)
        self.conv2c = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu2c = e2nn.ReLU(out_type); self.bnorm2c= e2nn.InnerBatchNorm(out_type)
        self.mpool2 = e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2)
        self.gpool2 = e2nn.GroupPooling(out_type)

        in_type = out_type
        out_type = e2nn.FieldType(self.r2_act, 32*[self.r2_act.regular_repr])
        self.conv3a = e2nn.R2Conv(in_type,  out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu3a = e2nn.ReLU(out_type); self.bnorm3a= e2nn.InnerBatchNorm(out_type)
        self.conv3b = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu3b = e2nn.ReLU(out_type); self.bnorm3b= e2nn.InnerBatchNorm(out_type)
        self.conv3c = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu3c = e2nn.ReLU(out_type); self.bnorm3c= e2nn.InnerBatchNorm(out_type)
        self.mpool3 = e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2)
        self.gpool3 = e2nn.GroupPooling(out_type)

        in_type = out_type
        out_type = e2nn.FieldType(self.r2_act, 64*[self.r2_act.regular_repr])
        self.conv4a = e2nn.R2Conv(in_type,  out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu4a = e2nn.ReLU(out_type); self.bnorm4a= e2nn.InnerBatchNorm(out_type)
        self.conv4b = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu4b = e2nn.ReLU(out_type); self.bnorm4b= e2nn.InnerBatchNorm(out_type)
        self.mpool4 = e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2)
        self.gpool4 = e2nn.GroupPooling(out_type)

        self.flatten = nn.Flatten(1)
        self.dropout = nn.Dropout(p=0.5)

        if self.ag == 0:
            pass
        if self.ag >= 1:
            self.attention1 = GridAttentionBlock2D(in_channels=32, gating_channels=64, inter_channels=64, input_size=[imsize//4,imsize//4], normalisation=normalisation)
        if self.ag >= 2:
            self.attention2 = GridAttentionBlock2D(in_channels=16, gating_channels=64, inter_channels=64, input_size=[imsize//2,imsize//2], normalisation=normalisation)
        if self.ag >= 3:
            self.attention3 = GridAttentionBlock2D(in_channels=6, gating_channels=64, inter_channels=64, input_size=[imsize,imsize], normalisation=normalisation)

        self.fc1 = nn.Linear(16*5*5,256) #channel_size * width * height
        self.fc2 = nn.Linear(256,256)
        self.fc3 = nn.Linear(256, self.n_classes)
        self.dummy = nn.Parameter(torch.empty(0))

        self.module_order = ['conv1a', 'relu1a', 'bnorm1a', #1->6
                             'conv1b', 'relu1b', 'bnorm1b', #6->6
                             'conv1c', 'relu1c', 'bnorm1c', #6->6
                             'mpool1',
                             'conv2a', 'relu2a', 'bnorm2a', #6->16
                             'conv2b', 'relu2b', 'bnorm2b', #16->16
                             'conv2c', 'relu2c', 'bnorm2c', #16->16
                             'mpool2',
                             'conv3a', 'relu3a', 'bnorm3a', #16->32
                             'conv3b', 'relu3b', 'bnorm3b', #32->32
                             'conv3c', 'relu3c', 'bnorm3c', #32->32
                             'mpool3',
                             'conv4a', 'relu4a', 'bnorm4a', #32->64
                             'conv4b', 'relu4b', 'bnorm4b', #64->64
                             'compatibility_score1',
                             'compatibility_score2']


        #########################
        # Aggreagation Strategies
        if self.ag != 0:
            self.attention_filter_sizes = [32, 16, 6]
            concat_length = 0
            for i in range(self.ag):
                concat_length += self.attention_filter_sizes[i]
            if aggregation_mode == 'concat':
                self.classifier = nn.Linear(concat_length, self.n_classes)
                self.aggregate = self.aggregation_concat
            else:
                # Not able to initialise in a loop as the modules will not change device with remaining model.
                self.classifiers = nn.ModuleList()
                if self.ag>=1:
                    self.classifiers.append(nn.Linear(self.attention_filter_sizes[0], self.n_classes))
                if self.ag>=2:
                    self.classifiers.append(nn.Linear(self.attention_filter_sizes[1], self.n_classes))
                if self.ag>=3:
                    self.classifiers.append(nn.Linear(self.attention_filter_sizes[2], self.n_classes))
                if aggregation_mode == 'mean':
                    self.aggregate = self.aggregation_sep
                elif aggregation_mode == 'deep_sup':
                    self.classifier = nn.Linear(concat_length, self.n_classes)
                    self.aggregate = self.aggregation_ds
                elif aggregation_mode == 'ft':
                    self.classifier = nn.Linear(self.n_classes*self.ag, self.n_classes)
                    self.aggregate = self.aggregation_ft
                else:
                    raise NotImplementedError
        else:
            self.classifier = nn.Linear((150//16)**2*64, self.n_classes)
            self.aggregate = lambda x: self.classifier(self.flatten(x))
    def __init__(self, in_type, out_type):
        """Build a single equivariant convolution stage.

        The stage is ``R2Conv`` (3x3 kernel, stride 1, no padding) followed by
        ``InnerBatchNorm`` and ``ReLU``, stored as ``self.conv``.

        Args:
            in_type: e2cnn ``FieldType`` of the stage input.
            out_type: e2cnn ``FieldType`` emitted by the convolution and kept
                through the batch norm and ReLU.
        """
        super().__init__()

        # Modules are created in the same order as they run: conv, norm, act.
        equivariant_conv = enn.R2Conv(in_type, out_type,
                                      kernel_size=3, stride=1, padding=0)
        field_norm = enn.InnerBatchNorm(out_type)
        nonlinearity = enn.ReLU(out_type)

        self.conv = enn.SequentialModule(equivariant_conv, field_norm, nonlinearity)
예제 #18
0
def build_norm_layer(cfg, num_features, postfix=''):
    """Return a ``(name, module)`` pair for an equivariant batch-norm layer.

    The module is ``enn.InnerBatchNorm`` over the field type produced by
    ``FIELD_TYPE['regular'](gspace, num_features)``, where ``gspace`` is the
    module-level gspace. The name is ``'bn'`` followed by ``str(postfix)``.
    ``cfg`` is accepted for interface compatibility but is not read.
    """
    field_type = FIELD_TYPE['regular'](gspace, num_features)
    layer_name = 'bn' + str(postfix)
    norm_layer = enn.InnerBatchNorm(field_type)
    return layer_name, norm_layer
예제 #19
0
def create_equivariant_real_nvp_blocks(input_size,
                                       in_type,
                                       field_type,
                                       out_fiber,
                                       activation_fn,
                                       hidden_size,
                                       n_blocks,
                                       n_hidden,
                                       group_action_type,
                                       kernel_size=3,
                                       padding=1,
                                       only_t=False):
    """Build the scale (``s``) and translation (``t``) nets of equivariant RealNVP couplings.

    Each of the ``n_blocks`` per-coupling networks is a stack of
    ``R2Conv -> InnerBatchNorm -> activation`` stages:

    * one stage from ``in_type`` to the hidden fiber,
    * ``n_hidden`` hidden-to-hidden stages,
    * a final stage to ``c`` trivial fields, where ``c`` comes from ``input_size``.

    Args:
        input_size: shape tuple unpacked as ``(_, c, h, w)``; only ``c`` is used.
        in_type: e2cnn ``FieldType`` of the tensors fed to every network.
        field_type: forwarded to the ``FIBERS[out_fiber]`` factory.
        out_fiber: key into ``FIBERS`` selecting the hidden-fiber factory.
        activation_fn: callable ``(field_type, inplace=True)`` returning an
            equivariant activation module.
        hidden_size: forwarded to the ``FIBERS[out_fiber]`` factory.
        n_blocks: number of coupling networks to build (per s and t).
        n_hidden: number of hidden-to-hidden stages per network.
        group_action_type: gspace on which every field type is defined.
        kernel_size: kernel size of every ``R2Conv``.
        padding: padding of every ``R2Conv``.
        only_t: when True, build and return only the translation networks.

    Returns:
        ``(s, t)``, each a ``MultiInputSequential`` of the per-block
        networks; only ``t`` when ``only_t`` is True.
    """
    _, c, _, _ = input_size
    # Final stage emits c trivial (scalar) fields; these are pointwise, so
    # InnerBatchNorm and pointwise activations accept them.
    out_type = enn.FieldType(group_action_type,
                             c * [group_action_type.trivial_repr])
    hidden_type = FIBERS[out_fiber](group_action_type,
                                    hidden_size,
                                    field_type,
                                    fixparams=True)

    def _stage(stage_in, stage_out):
        # One R2Conv -> InnerBatchNorm -> activation stage, all on stage_out.
        return enn.SequentialModule(
            enn.R2Conv(stage_in,
                       stage_out,
                       kernel_size=kernel_size,
                       padding=padding,
                       bias=True),
            enn.InnerBatchNorm(stage_out),
            activation_fn(stage_out, inplace=True))

    nets, nett = [], []
    for _ in range(n_blocks):
        # Stages are created in the same interleaved s/t order as the
        # original implementation so that seeded weight initialisation of the
        # R2Conv layers is unchanged.
        s_block = None if only_t else [_stage(in_type, hidden_type)]
        t_block = [_stage(in_type, hidden_type)]
        for _ in range(n_hidden):
            if s_block is not None:
                s_block.append(_stage(s_block[-1].out_type, hidden_type))
            t_block.append(_stage(t_block[-1].out_type, hidden_type))

        # The last conv must emit out_type, because the batch norm and
        # activation that follow it are built on out_type. Emitting in_type
        # here tripped SequentialModule's type check whenever
        # in_type != out_type.
        if s_block is not None:
            s_block.append(_stage(s_block[-1].out_type, out_type))
            nets.append(MultiInputSequential(*s_block))
        t_block.append(_stage(t_block[-1].out_type, out_type))
        nett.append(MultiInputSequential(*t_block))

    t = MultiInputSequential(*nett)
    if only_t:
        return t
    s = MultiInputSequential(*nets)
    return s, t
    def __init__(self, input_size, in_type, field_type, out_fiber,
                 activation_fn, hidden_size, group_action_type):
        """Build the layers of an InvariantCNNBlock.

        Architecture:

        * Three equivariant ``R2Conv -> InnerBatchNorm -> activation`` stages
          (``block1``..``block3``), each followed by antialiased average
          pooling (``pool1``..``pool3``).
        * Group pooling (``gpool``).
        * A plain, non-equivariant ``torch.nn`` transposed-convolution
          generator (``gen``) that maps back to ``c`` channels.

        Args:
            input_size: shape tuple unpacked as ``(_, c, h, w)``; the last
                three entries are stored as ``self.c``, ``self.h``, ``self.w``.
            in_type: NOTE(review): not used by the layer construction below;
                confirm whether it is needed.
            field_type: forwarded to the ``FIBERS[out_fiber]`` factory.
            out_fiber: key into ``FIBERS`` selecting the hidden-fiber factory.
            activation_fn: callable ``(field_type, inplace=True)`` returning
                an equivariant activation module.
            hidden_size: forwarded to the ``FIBERS[out_fiber]`` factory.
            group_action_type: gspace on which all field types are defined.
        """

        super(InvariantCNNBlock, self).__init__()
        _, self.c, self.h, self.w = input_size
        # Base channel width of the transposed-conv generator `gen`.
        ngf = 16
        self.group_action_type = group_action_type
        # Input: c trivial (scalar) fields, i.e. one channel per input channel.
        feat_type_in = enn.FieldType(
            self.group_action_type,
            self.c * [self.group_action_type.trivial_repr])
        # Hidden fiber built by the FIBERS[out_fiber] factory.
        feat_type_hid = FIBERS[out_fiber](group_action_type,
                                          hidden_size,
                                          field_type,
                                          fixparams=True)
        # Output of the last equivariant stage: 128 regular-representation fields.
        feat_type_out = enn.FieldType(
            self.group_action_type,
            128 * [self.group_action_type.regular_repr])

        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = feat_type_in

        # 5x5 equivariant conv without padding.
        self.block1 = enn.SequentialModule(
            enn.R2Conv(feat_type_in, feat_type_hid, kernel_size=5, padding=0),
            enn.InnerBatchNorm(feat_type_hid),
            activation_fn(feat_type_hid, inplace=True),
        )

        # Antialiased average pooling, stride 2.
        self.pool1 = enn.SequentialModule(
            enn.PointwiseAvgPoolAntialiased(feat_type_hid,
                                            sigma=0.66,
                                            stride=2))

        # 5x5 equivariant conv; padding is left at R2Conv's default.
        self.block2 = enn.SequentialModule(
            enn.R2Conv(feat_type_hid, feat_type_hid, kernel_size=5),
            enn.InnerBatchNorm(feat_type_hid),
            activation_fn(feat_type_hid, inplace=True),
        )

        # Antialiased average pooling, stride 2.
        self.pool2 = enn.SequentialModule(
            enn.PointwiseAvgPoolAntialiased(feat_type_hid,
                                            sigma=0.66,
                                            stride=2))

        # 3x3 equivariant conv (padding 1) widening to feat_type_out.
        self.block3 = enn.SequentialModule(
            enn.R2Conv(feat_type_hid, feat_type_out, kernel_size=3, padding=1),
            enn.InnerBatchNorm(feat_type_out),
            activation_fn(feat_type_out, inplace=True),
        )

        # Antialiased average pooling with stride 1 and no padding.
        self.pool3 = enn.PointwiseAvgPoolAntialiased(feat_type_out,
                                                     sigma=0.66,
                                                     stride=1,
                                                     padding=0)

        # Group pooling over the feat_type_out fields (presumably yields
        # rotation-invariant channels; confirm against e2cnn GroupPooling docs).
        self.gpool = enn.GroupPooling(feat_type_out)

        # Channel count after group pooling; this is the input width of `gen`.
        self.gc = self.gpool.out_type.size
        # Non-equivariant decoder:
        # - maps gc channels down to self.c channels and ends in Tanh;
        # - the first layer (k=4, s=1, p=0) grows each spatial side by 3;
        # - each of the three (k=4, s=2, p=1) layers doubles the spatial size.
        self.gen = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(self.gc,
                                     ngf,
                                     kernel_size=4,
                                     stride=1,
                                     padding=0), torch.nn.BatchNorm2d(ngf),
            torch.nn.LeakyReLU(0.2),
            torch.nn.ConvTranspose2d(ngf,
                                     ngf,
                                     kernel_size=4,
                                     stride=2,
                                     padding=1), torch.nn.BatchNorm2d(ngf),
            torch.nn.LeakyReLU(0.2),
            torch.nn.ConvTranspose2d(ngf,
                                     int(ngf / 2),
                                     kernel_size=4,
                                     stride=2,
                                     padding=1),
            torch.nn.BatchNorm2d(int(ngf / 2)), torch.nn.LeakyReLU(0.2),
            torch.nn.ConvTranspose2d(int(ngf / 2),
                                     self.c,
                                     kernel_size=4,
                                     stride=2,
                                     padding=1), torch.nn.Tanh())