예제 #1
0
    def __init__(self, channel_in=3, n_classes=4, rot_n=4):
        """Build a small C_N-equivariant CNN followed by an invariant MLP head.

        Args:
            channel_in: number of scalar (trivial-representation) input fields.
            n_classes: width of the final linear layer (number of logits).
            rot_n: order N of the rotation group, ``Rot2dOnR2(N=rot_n)``.
        """
        super(SmallE2, self).__init__()

        gspace = gspaces.Rot2dOnR2(N=rot_n)
        scalar = gspace.trivial_repr
        regular = gspace.regular_repr

        # Scalar input fields, 8 regular hidden fields, 2 regular output fields.
        self.feat_type_in = nn.FieldType(gspace, channel_in * [scalar])
        hidden_fields = nn.FieldType(gspace, 8 * [regular])
        output_fields = nn.FieldType(gspace, 2 * [regular])

        # Normalisation and non-linearity defined on the hidden fields.
        self.bn = nn.InnerBatchNorm(hidden_fields)
        self.relu = nn.ReLU(hidden_fields)

        # Three 3x3 steerable convolutions (R2Conv's default padding).
        self.convin = nn.R2Conv(self.feat_type_in, hidden_fields, kernel_size=3)
        self.convhid = nn.R2Conv(hidden_fields, hidden_fields, kernel_size=3)
        self.convout = nn.R2Conv(hidden_fields, output_fields, kernel_size=3)

        self.avgpool = nn.PointwiseAvgPool(output_fields, 3)
        # Pooling over the group dimension yields rotation-invariant channels.
        self.invariant_map = nn.GroupPooling(output_fields)

        n_invariant = self.invariant_map.out_type.size
        self.lin_in = torch.nn.Linear(n_invariant, 64)
        self.elu = torch.nn.ELU()
        self.lin_out = torch.nn.Linear(64, n_classes)
예제 #2
0
    def __init__(self, input_frames, output_frames, kernel_size, N):
        """Rotation-equivariant U-Net built from C_N steerable convolutions.

        Args:
            input_frames: number of input frames, each typed as one ``irrep(1)``
                (2D vector) field.
            output_frames: number of output frames, also ``irrep(1)`` fields.
            kernel_size: spatial kernel size of the convolutions.
            N: order of the rotation group C_N.
        """
        super(Unet_Rot, self).__init__()
        group = gspaces.Rot2dOnR2(N = N)
        vector = group.irrep(1)
        same_pad = (kernel_size - 1)//2

        self.feat_type_in = nn.FieldType(group, input_frames*[vector])
        self.feat_type_in_hid = nn.FieldType(group, 32*[group.regular_repr])
        # 16 decoder vector fields plus input_frames vector fields; presumably
        # the input is concatenated back in forward — confirm.
        self.feat_type_hid_out = nn.FieldType(group, (16 + input_frames)*[vector])
        self.feat_type_out = nn.FieldType(group, output_frames*[vector])

        self.conv1 = nn.SequentialModule(
            nn.R2Conv(self.feat_type_in, self.feat_type_in_hid,
                      kernel_size = kernel_size, stride = 2, padding = same_pad),
            nn.InnerBatchNorm(self.feat_type_in_hid),
            nn.ReLU(self.feat_type_in_hid)
        )

        # Encoder layers: (attribute, regular fields in, regular fields out, stride).
        encoder = [
            ("conv2", 32, 64, 1),
            ("conv2_1", 64, 64, 1),
            ("conv3", 64, 128, 2),
            ("conv3_1", 128, 128, 1),
            ("conv4", 128, 256, 2),
            ("conv4_1", 256, 256, 1),
        ]
        for attr, fields_in, fields_out, step in encoder:
            setattr(self, attr, rot_conv2d(fields_in, fields_out,
                                           kernel_size = kernel_size, stride = step, N = N))

        # Decoder input widths (192 = 64 + 128, 96 = 32 + 64) suggest skip
        # concatenation in forward — NOTE(review): confirm.
        self.deconv3 = rot_deconv2d(256, 64, N)
        self.deconv2 = rot_deconv2d(192, 32, N)
        self.deconv1 = rot_deconv2d(96, 16, N, last_deconv = True)

        self.output_layer = nn.R2Conv(self.feat_type_hid_out, self.feat_type_out,
                                      kernel_size = kernel_size, padding = same_pad)
예제 #3
0
 def __init__(self, input_channels, output_channels, kernel_size, stride, N, activation = True, deconv = False, last_deconv = False):
     """One C_N-equivariant convolution, optionally followed by BN + ReLU.

     Args:
         input_channels: number of regular-representation input fields.
         output_channels: number of output fields; regular fields, or
             ``irrep(1)`` fields when both ``deconv`` and ``last_deconv``.
         kernel_size: spatial kernel size.
         stride: convolution stride.
         N: order of the rotation group C_N.
         activation: without ``deconv``, append InnerBatchNorm + ReLU.
             Ignored when ``deconv`` is set.
         deconv: use padding 0 instead of ``(kernel_size - 1)//2``. Despite the
             name this is still an ``nn.R2Conv``, not a transposed convolution.
         last_deconv: together with ``deconv``, emit a bare convolution into
             ``irrep(1)`` fields instead of conv + BN + ReLU.
     """
     super(rot_conv2d, self).__init__()
     group = gspaces.Rot2dOnR2(N = N)

     fields_in = nn.FieldType(group, input_channels*[group.regular_repr])
     vector_output = deconv and last_deconv
     out_repr = group.irrep(1) if vector_output else group.regular_repr
     fields_out = nn.FieldType(group, output_channels*[out_repr])

     pad = 0 if deconv else (kernel_size - 1)//2
     conv = nn.R2Conv(fields_in, fields_out, kernel_size = kernel_size, stride = stride, padding = pad)

     # Non-deconv layers obey `activation`; deconv layers are activated unless last.
     with_bn_relu = (not last_deconv) if deconv else activation
     if with_bn_relu:
         self.layer = nn.SequentialModule(
             conv,
             nn.InnerBatchNorm(fields_out),
             nn.ReLU(fields_out)
         )
     else:
         self.layer = conv
예제 #4
0
 def __init__(self, input_channels, hidden_dim, kernel_size, N):
     """Residual block built from C_N-equivariant conv-BN-ReLU stages.

     Args:
         input_channels: number of regular-representation input fields.
         hidden_dim: number of regular-representation fields inside the block.
         kernel_size: spatial kernel size; padding is ``(kernel_size - 1)//2``.
         N: order of the rotation group C_N (group size).
     """
     super(rot_resblock, self).__init__()

     # Symmetry group acting on the plane.
     group = gspaces.Rot2dOnR2(N = N)
     fields_in = nn.FieldType(group, input_channels*[group.regular_repr])
     fields_hid = nn.FieldType(group, hidden_dim*[group.regular_repr])
     pad = (kernel_size - 1)//2

     def conv_bn_relu(src):
         # R2Conv from `src` to the hidden fields, then InnerBatchNorm and ReLU.
         return nn.SequentialModule(
             nn.R2Conv(src, fields_hid, kernel_size = kernel_size, padding = pad),
             nn.InnerBatchNorm(fields_hid),
             nn.ReLU(fields_hid)
         )

     self.layer1 = conv_bn_relu(fields_in)
     self.layer2 = conv_bn_relu(fields_hid)
     # Maps the input to hidden_dim fields; presumably the skip path when the
     # widths differ (see input_channels / hidden_dim) — confirm in forward.
     self.upscale = conv_bn_relu(fields_in)

     self.input_channels = input_channels
     self.hidden_dim = hidden_dim
예제 #5
0
    def __init__(self, n_classes=6):
        """C4-equivariant steerable CNN with an invariant fully connected head.

        Args:
            n_classes: number of output logits of ``fully_net``.
        """
        super(SteerCNN, self).__init__()

        # The model is equivariant under rotations by 90 degrees, modelled by
        # the cyclic group C4 (N=4).
        self.r2_act = gspaces.Rot2dOnR2(N=4)

        # The input image has 3 channels, each a scalar field transforming
        # under the trivial representation.
        input_type = nn_e2.FieldType(self.r2_act,
                                     3 * [self.r2_act.trivial_repr])

        # We store the input type for wrapping the images into a geometric
        # tensor during the forward pass.
        self.input_type = input_type

        # Convolution 1: 24 feature fields, each transforming under the
        # regular representation of C4.
        out_type = nn_e2.FieldType(self.r2_act,
                                   24 * [self.r2_act.regular_repr])
        self.block1 = nn_e2.SequentialModule(
            nn_e2.R2Conv(input_type,
                         out_type,
                         kernel_size=7,
                         padding=3,
                         bias=False), nn_e2.InnerBatchNorm(out_type),
            nn_e2.ReLU(out_type, inplace=True))

        self.pool1 = nn_e2.PointwiseAvgPool(out_type, 4)

        # Convolution 2: the previous output type is the next input type, and
        # the output type is reused, i.e. again 24 regular fields of C4.
        in_type = self.block1.out_type
        self.block2 = nn_e2.SequentialModule(
            nn_e2.R2Conv(in_type,
                         out_type,
                         kernel_size=7,
                         padding=3,
                         bias=False), nn_e2.InnerBatchNorm(out_type),
            nn_e2.ReLU(out_type, inplace=True))
        # Anti-aliased smoothing, 4x average pooling, then pooling over the
        # group to obtain rotation-invariant channels.
        self.pool2 = nn_e2.SequentialModule(
            nn_e2.PointwiseAvgPoolAntialiased(out_type,
                                              sigma=0.66,
                                              stride=1,
                                              padding=0),
            nn_e2.PointwiseAvgPool(out_type, 4), nn_e2.GroupPooling(out_type))

        # Flattened size of the group-pooled features: 24 invariant channels
        # (one per regular field) over a 13x13 spatial map.
        # NOTE(review): 13x13 is hard-coded and only holds for one input
        # resolution — confirm it matches the images fed to forward.
        c = 24 * 13 * 13

        # Fully connected classifier.
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(c, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )
    def __init__(self, input_shape, num_actions, dueling_DQN):
        """D4-equivariant convolutional Q-network.

        Args:
            input_shape: observation shape; ``input_shape[0]`` is the number of
                scalar input channels.
            num_actions: number of Q-values produced by the head(s).
            dueling_DQN: if truthy, build separate advantage and value heads
                instead of a single action-value head.
        """
        super(D4_steerable_DQN_Snake, self).__init__()
        self.input_shape = input_shape
        self.num_actions = num_actions
        self.dueling_DQN = dueling_DQN

        # FlipRot2dOnR2(N=4): 90-degree rotations plus reflections (D4).
        self.r2_act = gspaces.FlipRot2dOnR2(N=4)
        regular = self.r2_act.regular_repr

        self.input_type = nn.FieldType(
            self.r2_act, input_shape[0] * [self.r2_act.trivial_repr])
        fields_8 = nn.FieldType(self.r2_act, 8 * [regular])
        fields_12a = nn.FieldType(self.r2_act, 12 * [regular])
        fields_12b = nn.FieldType(self.r2_act, 12 * [regular])
        fields_32 = nn.FieldType(self.r2_act, 32 * [regular])

        def conv_relu(src, dst, size, stride, padding=0):
            # Bias-free R2Conv followed by an in-place ReLU on `dst`.
            return nn.SequentialModule(
                nn.R2Conv(src, dst, kernel_size=size, padding=padding,
                          stride=stride, bias=False),
                nn.ReLU(dst, inplace=True))

        self.feature_field1 = conv_relu(self.input_type, fields_8, 7, stride=2, padding=2)
        self.feature_field2 = conv_relu(fields_8, fields_12a, 5, stride=2, padding=1)
        self.feature_field3 = conv_relu(fields_12a, fields_12b, 5, stride=1, padding=1)
        self.equivariant_features = conv_relu(fields_12b, fields_32, 5, stride=1)
        self.gpool = nn.GroupPooling(fields_32)
        self.feature_shape()

        # NOTE(review): heads are sized by equivariant_features.out_type.size
        # (before group pooling, no spatial factor); confirm this matches what
        # forward feeds them.
        head_in = self.equivariant_features.out_type.size
        if self.dueling_DQN:
            print("You are using Dueling DQN")
            self.advantage = torch.nn.Linear(head_in, self.num_actions)
            self.value = torch.nn.Linear(head_in, 1)
        else:
            self.actionvalue = torch.nn.Linear(head_in, self.num_actions)
예제 #7
0
    def __init__(self,
                 base='DNSteerableLeNet',
                 in_chan=1,
                 n_classes=2,
                 imsize=150,
                 kernel_size=5,
                 N=8,
                 quiet=True,
                 number_rotations=None):
        """D_N-equivariant LeNet-style classifier.

        Args:
            base: not used by this constructor.
            in_chan: not used by this constructor; the input field type is a
                single scalar field. NOTE(review): confirm inputs are single-channel.
            n_classes: number of output logits (width of ``fc3``).
            imsize: input height/width; sizes the mask and ``fc1``.
            kernel_size: steerable convolution kernel size (cast to int).
            N: number of rotations of the dihedral group D_N.
            quiet: not used by this constructor.
            number_rotations: if not None, overrides ``N``.
        """
        super(DNSteerableLeNet, self).__init__()
        kernel_size = int(kernel_size)
        out_chan = int(n_classes)

        if number_rotations is not None:
            N = int(number_rotations)

        # fc1 assumes a z x z feature map after the two 2x pooling stages.
        z = imsize // 2 // 2

        self.r2_act = gspaces.FlipRot2dOnR2(N)

        in_type = e2nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
        self.input_type = in_type

        out_type = e2nn.FieldType(self.r2_act, 6 * [self.r2_act.regular_repr])
        self.mask = e2nn.MaskModule(in_type, imsize, margin=1)
        self.conv1 = e2nn.R2Conv(in_type,
                                 out_type,
                                 kernel_size=kernel_size,
                                 padding=kernel_size // 2,
                                 bias=False)
        self.relu1 = e2nn.ReLU(out_type, inplace=True)
        self.pool1 = e2nn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)

        in_type = self.pool1.out_type
        out_type = e2nn.FieldType(self.r2_act, 16 * [self.r2_act.regular_repr])
        self.conv2 = e2nn.R2Conv(in_type,
                                 out_type,
                                 kernel_size=kernel_size,
                                 padding=kernel_size // 2,
                                 bias=False)
        self.relu2 = e2nn.ReLU(out_type, inplace=True)
        self.pool2 = e2nn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)

        # Group pooling leaves one invariant channel per regular field (16).
        self.gpool = e2nn.GroupPooling(out_type)

        self.fc1 = nn.Linear(16 * z * z, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_chan)

        self.drop = nn.Dropout(p=0.5)

        # dummy parameter for tracking device
        self.dummy = nn.Parameter(torch.empty(0))
예제 #8
0
    def __init__(self, in_chan, out_chan, imsize, kernel_size=5, N=8):
        """LeNet with one D_N-equivariant stage followed by ordinary layers.

        Args:
            in_chan: not used here; the input field type is one scalar field.
            out_chan: number of output logits (width of ``fc3``).
            imsize: input height/width; sizes the mask and ``fc1``.
            kernel_size: kernel size of ``conv1`` (steerable) and ``conv2`` (plain).
            N: number of rotations of the dihedral group D_N.
        """
        super(DNRestrictedLeNet, self).__init__()

        pooled = imsize // 2 // 2  # side length fc1 expects after two 2x poolings
        half_k = kernel_size // 2

        self.r2_act = gspaces.FlipRot2dOnR2(N)
        scalar_in = e2nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
        self.input_type = scalar_in
        six_regular = e2nn.FieldType(self.r2_act, 6 * [self.r2_act.regular_repr])

        self.mask = e2nn.MaskModule(scalar_in, imsize, margin=1)
        self.conv1 = e2nn.R2Conv(scalar_in, six_regular, kernel_size=kernel_size,
                                 padding=half_k, bias=False)
        self.relu1 = e2nn.ReLU(six_regular, inplace=True)
        self.pool1 = e2nn.PointwiseMaxPoolAntialiased(six_regular, kernel_size=2)

        # Group pooling yields one invariant channel per regular field (6),
        # matching the 6 input channels of the plain second convolution.
        self.gpool = e2nn.GroupPooling(six_regular)
        self.conv2 = nn.Conv2d(6, 16, kernel_size, padding=half_k)

        self.fc1 = nn.Linear(16 * pooled * pooled, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_chan)
        self.drop = nn.Dropout(p=0.5)

        # Zero-size parameter used to track which device the model lives on.
        self.dummy = nn.Parameter(torch.empty(0))
 def __init__(self, in_type, out_type, mid_type, padding=0):
     """Create the submodules of an equivariant upsampling stage.

     Args:
         in_type: input field type of ``conv1``.
         out_type: output field type of ``conv2`` and of ``conv3``.
         mid_type: output field type of ``conv1``; ``upsample`` and ``conv2``
             are defined on it.
         padding: accepted for interface compatibility; not used.
     """
     super(UpSample, self).__init__()
     self.mid_type = mid_type
     # R2Upsampling on mid_type fields with scale factor 2.
     self.upsample = enn.R2Upsampling(mid_type, 2)
     self.conv1 = Conv(in_type, mid_type)
     # Pointwise (1x1) projection mid_type -> out_type, then a Conv on out_type.
     self.conv2 = enn.R2Conv(mid_type, out_type, kernel_size=1)
     self.conv3 = Conv(out_type, out_type)
예제 #10
0
def create_equivariant_convexp_blocks(input_size,
                                      in_type,
                                      field_type,
                                      out_fiber,
                                      activation_fn,
                                      hidden_size,
                                      n_blocks,
                                      n_hidden,
                                      group_action_type,
                                      kernel_size=3,
                                      padding=1):
    """Stack ``n_blocks`` equivariant convolutions into one MultiInputSequential.

    Each block wraps a single biased ``enn.R2Conv`` whose output type is ``c``
    trivial (scalar) fields, with ``c = input_size[1]``. The first block reads
    ``in_type``; every later block reads the previous block's output type, so
    the stack composes even when ``in_type`` differs from the scalar output type.

    Args:
        input_size: 4-tuple ``(_, c, h, w)``; only ``c`` is used.
        in_type: field type of the tensor fed to the first block.
        field_type, out_fiber, activation_fn, hidden_size, n_hidden: accepted
            for interface compatibility; not used here.
        n_blocks: number of convolution blocks.
        group_action_type: gspace used to build the scalar output field type.
        kernel_size: convolution kernel size.
        padding: convolution padding.

    Returns:
        A ``MultiInputSequential`` holding ``n_blocks`` single-conv
        ``MultiInputSequential`` blocks.
    """
    _, c, _, _ = input_size
    out_type = enn.FieldType(group_action_type,
                             c * [group_action_type.trivial_repr])
    nets = []
    current_type = in_type
    for _ in range(n_blocks):
        # Normalisation / activation (enn.InnerBatchNorm, activation_fn) are
        # intentionally disabled: each block is a bare convolution.
        conv = enn.R2Conv(current_type,
                          out_type,
                          kernel_size=kernel_size,
                          padding=padding,
                          bias=True)
        nets.append(MultiInputSequential(conv))
        # The next block consumes this block's output.
        current_type = out_type
    return MultiInputSequential(*nets)
 def __init__(self, in_type, out_type, mid_type, padding=0):
     """Create a ``Conv`` into mid_type and a 1x1 R2Conv from mid_type to out_type.

     Args:
         in_type: input field type of ``conv1``.
         out_type: output field type of ``conv2``.
         mid_type: field type between the two convolutions (also stored).
         padding: accepted for interface compatibility; ``conv2`` always uses 0.
     """
     super(LastConcat, self).__init__()
     self.mid_type = mid_type
     self.conv1 = Conv(in_type, mid_type)
     # Pointwise projection: 1x1 kernel, stride 1, no padding.
     self.conv2 = enn.R2Conv(mid_type, out_type, kernel_size=1, stride=1, padding=0)
    def build_multiscale_classifier(self, input_size):
        """Create one equivariant classification head per scale and a logit layer.

        ``hidden_shapes`` tracks the multiscale layout (channels x2 when
        ``self.factor_out`` else x4, spatial /2, except at the last scale);
        only its length, ``self.n_scale``, determines how many heads are
        built. Every head reads ``self.input_type`` and ends in GroupPooling.

        Args:
            input_size: 4-tuple ``(n, c, h, w)``.

        Sets ``self.classification_heads`` (``nn.ModuleList``) and
        ``self.logit_layer`` (``nn.Linear`` to ``self.n_classes``).
        """
        n, c, h, w = input_size
        hidden_shapes = []
        for scale in range(self.n_scale):
            if scale < self.n_scale - 1:
                c *= 2 if self.factor_out else 4
                h //= 2
                w //= 2
            hidden_shapes.append((n, c, h, w))

        def regular_fibre(width):
            return FIBERS['regular'](self.group_action_type, width,
                                     self.field_type, fixparams=True)

        wide = regular_fibre(self.classification_hdim)
        mid = regular_fibre(int(self.classification_hdim // 2))
        narrow = regular_fibre(int(self.classification_hdim // 4))

        def stage(src, dst, kernel, pool_stride, conv_stride=1):
            # conv -> act-norm -> ReLU -> anti-aliased average pooling.
            return [
                enn.R2Conv(src, dst, kernel, stride=conv_stride),
                layers.EquivariantActNorm2d(dst.size),
                enn.ReLU(dst, inplace=True),
                enn.PointwiseAvgPoolAntialiased(dst, sigma=0.66, stride=pool_stride),
            ]

        def head():
            return nn.Sequential(
                *stage(self.input_type, wide, 5, 2, conv_stride=2),
                *stage(wide, mid, 3, 1),
                *stage(mid, narrow, 3, 2),
                enn.GroupPooling(narrow),
            )

        heads = [head() for _ in hidden_shapes]
        self.classification_heads = nn.ModuleList(heads)
        self.logit_layer = nn.Linear(heads[-1][-1].out_type.size, self.n_classes)
예제 #13
0
def build_steer_cnn_2d(
    in_field_type,
    hidden_field_types,
    kernel_sizes,
    out_field_type,
    gspace,
    activation="relu",
    padding_mode="zeros",
    modify_init=1.0,
):
    """Build a steerable CNN that consumes and returns plain tensors.

    Input:
        in_field_type - field type of the input data
        hidden_field_types - field types of the hidden layers
        kernel_sizes - kernel size per convolution, or one int used for all
        out_field_type - field type of the output layer
        gspace - the gspace the data lives in (not used directly here)
        activation - key into ``activations``; applied after every
            convolution except the last
        padding_mode - forwarded to every ``gnn.R2Conv``
        modify_init - every parameter is multiplied by this after init

    Returns:
        ``nn.Sequential`` that wraps its input in a ``gnn.GeometricTensor`` of
        ``in_field_type``, runs the steerable layers and returns ``.tensor``.
    """
    layer_field_types = [in_field_type, *hidden_field_types, out_field_type]
    n_convs = len(layer_field_types) - 1

    if isinstance(kernel_sizes, int):
        # One kernel size per convolution (hidden layers + output layer).
        kernel_sizes = [kernel_sizes] * n_convs

    modules = []
    for i in range(n_convs):
        kernel_size = kernel_sizes[i]
        modules.append(
            gnn.R2Conv(
                layer_field_types[i],
                layer_field_types[i + 1],
                kernel_size,
                padding=int((kernel_size - 1) / 2),  # "same" size for odd kernels
                padding_mode=padding_mode,
                initialize=True,
            ))
        if i != n_convs - 1:
            modules.append(activations[activation](layer_field_types[i + 1]))

    cnn = gnn.SequentialModule(*modules)

    # TODO: dirty fix to alleviate weird initialisations — rescale every
    # parameter in place by modify_init.
    for p in cnn.parameters():
        if p.dim() == 0:
            p.data = p.data * modify_init
        else:
            p.data[:] = p.data * modify_init

    return nn.Sequential(
        Expression(lambda X: gnn.GeometricTensor(X, in_field_type)),
        cnn,
        Expression(lambda X: X.tensor),
    )
예제 #14
0
def conv1x1(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=0,
            dilation=1, bias=False):
    """Return a 1x1 ``enn.R2Conv`` from ``in_type`` to ``out_type``.

    Padding defaults to 0 and bias to off; ``sigma`` is passed as ``None`` and
    the frequency cutoff function is ``r -> 3 * r``.
    """
    def cutoff(r):
        return 3*r

    options = dict(stride=stride, padding=padding, dilation=dilation,
                   bias=bias, sigma=None, frequencies_cutoff=cutoff)
    return enn.R2Conv(in_type, out_type, 1, **options)
예제 #15
0
def conv3x3(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=1,
            dilation=1, bias=True):
    """Return a 3x3 ``enn.R2Conv`` (padding 1 keeps the size at stride 1).

    Bias is on by default; ``sigma`` is ``None`` and no ``frequencies_cutoff``
    is passed, so e2cnn's default cutoff applies.
    """
    kwargs = {
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "bias": bias,
        "sigma": None,
    }
    return enn.R2Conv(in_type, out_type, 3, **kwargs)
예제 #16
0
def conv7x7(in_type: enn.FieldType, out_type: enn.FieldType, stride=2, padding=1,
            dilation=1, bias=False, kernel_size=3):
    """Strided equivariant stem convolution.

    NOTE: despite the function name, the default kernel is 3x3 (matching the
    default ``padding=1``). That default is kept so existing callers and saved
    weights are unaffected. Pass ``kernel_size=7, padding=3`` for a true 7x7
    convolution.

    Args:
        in_type: input field type.
        out_type: output field type.
        stride: convolution stride (default 2).
        padding: spatial padding (default 1).
        dilation: kernel dilation.
        bias: whether the convolution has a bias (default off).
        kernel_size: spatial kernel size (default 3).

    Returns:
        An ``enn.R2Conv`` with ``sigma=None`` and frequency cutoff ``r -> 3*r``.
    """
    return enn.R2Conv(in_type, out_type, kernel_size,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3*r,
                      )
예제 #17
0
def conv1x1(inplanes, out_planes, stride=1):
    """1x1 equivariant convolution between regular field types.

    Both field types are built with ``FIELD_TYPE['regular']`` on the
    module-level ``gspace``. No bias, frequency cutoff ``r -> 3 * r``, and
    ``initialize=False`` skips R2Conv's own weight initialisation.
    """
    make_regular = FIELD_TYPE['regular']
    return enn.R2Conv(
        make_regular(gspace, inplanes),
        make_regular(gspace, out_planes),
        1,
        stride=stride,
        bias=False,
        sigma=None,
        frequencies_cutoff=lambda r: 3 * r,
        initialize=False,
    )
예제 #18
0
 def __init__(self, input_frames, output_frames, kernel_size, N):
     """Rotation-equivariant ResNet over C_N, wrapped in ``self.model``.

     Args:
         input_frames: number of input ``irrep(1)`` (2D vector) fields.
         output_frames: number of output ``irrep(1)`` fields.
         kernel_size: spatial kernel size; padding is ``(kernel_size - 1)//2``.
         N: order of the rotation group C_N.
     """
     super(ResNet_Rot, self).__init__()
     group = gspaces.Rot2dOnR2(N = N)
     same_pad = (kernel_size - 1)//2

     # rho_1 (irrep(1)) fields for inputs/outputs, since they are velocity fields.
     self.feat_type_in = nn.FieldType(group, input_frames*[group.irrep(1)])
     # Regular representation for the intermediate layers.
     self.feat_type_in_hid = nn.FieldType(group, 16*[group.regular_repr])
     self.feat_type_hid_out = nn.FieldType(group, 192*[group.regular_repr])
     self.feat_type_out = nn.FieldType(group, output_frames*[group.irrep(1)])

     self.input_layer = nn.SequentialModule(
         nn.R2Conv(self.feat_type_in, self.feat_type_in_hid,
                   kernel_size = kernel_size, padding = same_pad),
         nn.InnerBatchNorm(self.feat_type_in_hid),
         nn.ReLU(self.feat_type_in_hid)
     )

     # Each stage widens once, then adds a second block at the new width.
     stages = [(16, 32), (32, 64), (64, 128), (128, 192)]
     blocks = [self.input_layer]
     for width_in, width_out in stages:
         blocks.append(rot_resblock(width_in, width_out, kernel_size, N))
         blocks.append(rot_resblock(width_out, width_out, kernel_size, N))
     blocks.append(nn.R2Conv(self.feat_type_hid_out, self.feat_type_out,
                             kernel_size = kernel_size, padding = same_pad))
     self.model = torch.nn.Sequential(*blocks)
예제 #19
0
def conv7x7(inplanes, out_planes, stride=2, padding=3, bias=False):
    """7x7 equivariant stem convolution from scalar channels to regular fields.

    ``inplanes`` trivial (scalar) fields on the module-level ``gspace`` map to
    ``FIELD_TYPE['regular'](gspace, out_planes)``; frequency cutoff ``r -> 3 * r``.
    """
    scalar_fields = enn.FieldType(gspace, [gspace.trivial_repr] * inplanes)
    regular_fields = FIELD_TYPE['regular'](gspace, out_planes)
    conv = enn.R2Conv(scalar_fields, regular_fields, 7,
                      stride=stride, padding=padding, bias=bias,
                      sigma=None, frequencies_cutoff=lambda r: 3 * r)
    return conv
예제 #20
0
def conv3x3(inplanes, out_planes, stride=1, padding=1, groups=1, dilation=1):
    """3x3 equivariant convolution between regular field types.

    Field types come from ``FIELD_TYPE['regular']`` on the module-level
    ``gspace``. No bias, frequency cutoff ``r -> 3 * r``, ``groups`` and
    ``dilation`` forwarded, and ``initialize=False`` (no R2Conv init).
    """
    regular = FIELD_TYPE['regular']
    source = regular(gspace, inplanes)
    target = regular(gspace, out_planes)
    return enn.R2Conv(source, target, 3,
                      stride=stride, padding=padding,
                      groups=groups, bias=False, dilation=dilation,
                      sigma=None, frequencies_cutoff=lambda r: 3 * r,
                      initialize=False)
예제 #21
0
 def __init__(self,
              in_type,
              out_type,
              group_action_type,
              kernel_size,
              stride,
              padding,
              bias=True,
              coeff=0.97,
              domain=2,
              codomain=2,
              n_iterations=None,
              atol=None,
              rtol=None,
              **unused_kwargs):
     """Wrap an ``enn.R2Conv`` and expose its expanded dense filter.

     The steerable convolution is built from ``in_type``/``out_type``; its
     basis-expanded filter and bias are read once via ``expand_parameters()``
     and stored as ``self.weight`` / ``self.bias``. The remaining arguments and
     the ``u``/``v``/``scale``/``initialized``/``spatial_dims`` buffers
     presumably drive an iterative induced-norm estimate (domain -> codomain)
     used to bound the layer by ``coeff`` — NOTE(review): that logic is not in
     this constructor; confirm against forward.

     Args:
         in_type: input field type (``in_type.size`` channels).
         out_type: output field type (``out_type.size`` channels).
         group_action_type: the gspace; stored only.
         kernel_size, stride, padding: passed to ``enn.R2Conv`` and stored as pairs.
         bias: build the convolution with a bias and keep its expanded bias.
         coeff: stored; presumably the target norm bound.
         domain, codomain: stored; presumably the norm orders.
         n_iterations, atol, rtol: stored iteration controls (may be None).
         **unused_kwargs: accepted and discarded.
     """
     del unused_kwargs
     super(InducedNormEquivarConv2d, self).__init__()
     self.in_channels = in_type.size
     self.out_channels = out_type.size
     self.group_action_type = group_action_type
     self.kernel_size = _pair(kernel_size)
     self.stride = _pair(stride)
     self.padding = _pair(padding)
     self.coeff = coeff
     self.n_iterations = n_iterations
     self.domain = domain
     self.codomain = codomain
     self.atol = atol
     self.rtol = rtol
     self.equivar_conv = enn.R2Conv(in_type,
                                    out_type,
                                    kernel_size=kernel_size,
                                    stride=stride,
                                    padding=padding,
                                    bias=bias)
     # Dense filter / bias expanded from the steerable basis at construction
     # time. NOTE(review): this is a snapshot, not a Parameter; it will not
     # follow later updates of equivar_conv's weights unless re-expanded
     # elsewhere (e.g. in forward) — confirm.
     self.weight, expanded_bias = self.equivar_conv.expand_parameters()
     # Field types of the wrapped equivariant convolution.
     self.in_type = in_type
     self.out_type = out_type
     if bias:
         self.bias = expanded_bias
     else:
         self.register_parameter('bias', None)
     self.register_buffer('initialized', torch.tensor(0))
     self.register_buffer('spatial_dims', torch.tensor([1., 1.]))
     self.register_buffer('scale', torch.tensor(0.))
     # Uninitialised vectors sized by the output / input channel counts.
     self.register_buffer('u', self.weight.new_empty(self.out_channels))
     self.register_buffer('v', self.weight.new_empty(self.in_channels))
예제 #22
0
 def __init__(self, nclasses=1):
     """C8-equivariant ResNet-50-style network (3/4/6/3 bottleneck blocks).

     Args:
         nclasses: number of output units of ``fc``.
     """
     super(ResNet50, self).__init__()
     self.gspace = gspaces.Rot2dOnR2(N=8)

     def regular(width):
         return FIELD_TYPE["regular"](self.gspace, width, fixparams=False)

     reg_field64 = regular(64)
     reg_field256 = regular(256)
     reg_field128 = regular(128)
     reg_field512 = regular(512)
     reg_field1024 = regular(1024)
     reg_field2048 = regular(2048)

     # Stem: 7x7 stride-2 conv from 3 scalar channels, BN, ELU, 2x2 max-pool.
     self.conv1 = enn.R2Conv(FIELD_TYPE["trivial"](self.gspace, 3, fixparams=False),
                             reg_field64, kernel_size=7, stride=2, padding=3)
     self.bn1 = enn.InnerBatchNorm(reg_field64)
     self.relu1 = enn.ELU(reg_field64)  # attribute is named relu1 but holds an ELU
     self.maxpool1 = enn.PointwiseMaxPoolAntialiased(reg_field64, kernel_size=2)

     def stage(in_field, inner_field, out_field, depth):
         # First block strides by 2 and changes width; the rest keep out_field.
         blocks = [ResBlock(stride=2, in_type = in_field, inner_type = inner_field, out_type = out_field)]
         for _ in range(depth - 1):
             blocks.append(ResBlock(stride=1, in_type = out_field, inner_type = inner_field, out_type = out_field))
         return torch.nn.Sequential(*blocks)

     self.layer1 = stage(reg_field64, reg_field64, reg_field256, 3)
     self.layer2 = stage(reg_field256, reg_field128, reg_field512, 4)
     self.layer3 = stage(reg_field512, reg_field256, reg_field1024, 6)
     self.layer4 = stage(reg_field1024, reg_field512, reg_field2048, 3)

     self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
     # NOTE(review): assumes reg_field2048.size == 2048 — confirm FIELD_TYPE semantics.
     self.fc = torch.nn.Linear(2048, nclasses)
def generate_2d_rot8(out_path, n_tasks=10000):
    """Generate random C8-equivariant convolution tasks and save them with ``np.savez``.

    A bias-free 3x3 ``R2Conv`` maps one scalar field to 3 regular C8 fields.
    For each task its basis coefficients are re-drawn with
    ``gnn.init.generalized_he_init`` and the layer is applied to a batch of 20
    random ``1x32x32`` inputs.

    Args:
        out_path: destination passed to ``np.savez``.
        n_tasks: number of tasks to generate (default 10000, must be >= 1).

    The archive holds ``x`` (inputs, ``(n_tasks, 20, 1, 32, 32)``), ``y``
    (convolution outputs) and ``w`` (basis coefficients per task).

    Raises:
        ValueError: if ``n_tasks < 1`` (``np.stack`` cannot stack nothing).
    """
    if n_tasks < 1:
        raise ValueError("n_tasks must be at least 1")
    r2_act = gspaces.Rot2dOnR2(N=8)
    feat_type_in = gnn.FieldType(r2_act, [r2_act.trivial_repr])
    feat_type_out = gnn.FieldType(r2_act, 3 * [r2_act.regular_repr])
    conv = gnn.R2Conv(feat_type_in, feat_type_out, kernel_size=3, bias=False)
    xs, ys, ws = [], [], []
    # No gradients are needed: weights are re-sampled in place and outputs are
    # only exported. The module stays in training mode, so the filter is
    # re-expanded from the new coefficients on every call.
    with torch.no_grad():
        for task_idx in range(n_tasks):
            # Write into the raw tensor: in-place writes into a leaf Parameter
            # that requires grad are rejected by autograd.
            gnn.init.generalized_he_init(conv.weights.data, conv.basisexpansion)
            inp = gnn.GeometricTensor(torch.randn(20, 1, 32, 32), feat_type_in)
            result = conv(inp).tensor.detach().cpu().numpy()
            xs.append(inp.tensor.detach().cpu().numpy())
            ys.append(result)
            ws.append(conv.weights.detach().cpu().numpy())
            if task_idx % 100 == 0:
                print(f"Finished generating task {task_idx}")
    xs, ys, ws = np.stack(xs), np.stack(ys), np.stack(ws)
    np.savez(out_path, x=xs, y=ys, w=ws)
예제 #24
0
 def __init__(self):
     """One dilated C8-equivariant convolution followed by group pooling."""
     super(ModelDilated, self).__init__()
     self.gspace = gspaces.Rot2dOnR2(N=8)
     # 3 scalar input fields -> 16 regular output fields.
     self.in_type = enn.FieldType(self.gspace, 3 * [self.gspace.trivial_repr])
     self.out_type = enn.FieldType(self.gspace, 16 * [self.gspace.regular_repr])
     # A 3x3 kernel with dilation 2 spans 5x5; padding 2 keeps the spatial size.
     self.layer = enn.R2Conv(self.in_type, self.out_type, 3,
                             stride=1, padding=2, dilation=2, bias=True)
     self.invariant = enn.GroupPooling(self.out_type)
예제 #25
0
def conv5x5(in_type: enn.FieldType,
            out_type: enn.FieldType,
            stride=1,
            padding=2,
            dilation=1,
            bias=True):
    """Build a 5x5 equivariant convolution (padding 2 keeps size at stride 1).

    ``sigma`` is ``None``, e2cnn's default frequency cutoff is used, and
    ``initialize=False`` skips R2Conv's own weight initialisation.
    """
    spatial = dict(stride=stride, padding=padding, dilation=dilation)
    return enn.R2Conv(in_type, out_type, 5,
                      bias=bias, sigma=None, initialize=False, **spatial)
예제 #26
0
def conv3x3(in_type: enn.FieldType,
            out_type: enn.FieldType,
            stride=1,
            padding=1,
            dilation=1,
            bias=True,
            frequencies_cutoff=None):
    """Build an uninitialised 3x3 ``enn.R2Conv`` (padding 1 by default).

    Every keyword is forwarded to ``enn.R2Conv``, with ``sigma=None`` and
    ``initialize=False``. The caller is therefore responsible for initialising
    the weights.

    Returns:
        The configured ``enn.R2Conv`` module.
    """
    conv_kwargs = dict(
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        sigma=None,
        frequencies_cutoff=frequencies_cutoff,
        initialize=False,
    )
    return enn.R2Conv(in_type, out_type, kernel_size=3, **conv_kwargs)
예제 #27
0
 def __init__(self, in_type,inner_type, out_type, stride=1):
     """Build a 1x1 -> 3x3 -> 1x1 equivariant conv block with an optional shortcut.

     Three conv / InnerBatchNorm / ReLU triples are created:
       * 1x1 from ``in_type`` to ``inner_type``;
       * 3x3 with padding 1 and the given ``stride``, on ``inner_type``;
       * 1x1 from ``inner_type`` to ``out_type``.
     All convs are bias-free. ``self.shortcut`` is a bias-free strided 1x1
     ``enn.R2Conv`` when the stride or the field type changes, else ``None``.

     Args:
         in_type: Field type of the block input.
         inner_type: Field type used between the convolutions.
         out_type: Field type of the block output.
         stride: Stride of the 3x3 conv and of the projection shortcut.
     """
     super(ResBlock, self).__init__()

     self.in_type = in_type
     self.inner_type = inner_type
     self.out_type = out_type

     # Stage 1: 1x1, in -> inner.
     self.conv1 = conv1x1(in_type, inner_type, stride=1, bias=False)
     self.bn1 = enn.InnerBatchNorm(inner_type)
     self.relu1 = enn.ReLU(inner_type)

     # Stage 2: 3x3, inner -> inner; carries the block's stride.
     self.conv2 = conv3x3(inner_type, inner_type, padding=1, stride=stride, bias=False)
     self.bn2 = enn.InnerBatchNorm(inner_type)
     self.relu2 = enn.ReLU(inner_type, inplace=True)

     # Stage 3: 1x1, inner -> out.
     self.conv3 = conv1x1(inner_type, out_type, stride=1, bias=False)
     self.bn3 = enn.InnerBatchNorm(out_type)
     self.relu3 = enn.ReLU(out_type, inplace=True)

     needs_projection = stride != 1 or in_type != out_type
     self.shortcut = (
         enn.R2Conv(in_type, out_type, kernel_size=1, stride=stride, bias=False)
         if needs_projection
         else None
     )
예제 #28
0
def convnxn(inplanes,
            outplanes,
            kernel_size=3,
            stride=1,
            padding=0,
            groups=1,
            bias=False,
            dilation=1):
    """Build an ``enn.R2Conv`` between two regular field types on ``gspace``.

    Both field types come from ``FIELD_TYPE['regular']`` applied to the
    module-level ``gspace``, with ``inplanes`` and ``outplanes`` respectively.
    The conv uses ``sigma=None`` and a frequency cutoff of ``3 * r`` as a
    function of ``r``. All other arguments are forwarded unchanged.

    Returns:
        The configured ``enn.R2Conv`` module.
    """
    make_regular = FIELD_TYPE['regular']
    source_type = make_regular(gspace, inplanes)
    target_type = make_regular(gspace, outplanes)

    def cutoff(r):
        # Passed as R2Conv's frequencies_cutoff callable.
        return 3 * r

    return enn.R2Conv(source_type,
                      target_type,
                      kernel_size,
                      stride=stride,
                      padding=padding,
                      groups=groups,
                      bias=bias,
                      dilation=dilation,
                      sigma=None,
                      frequencies_cutoff=cutoff)
예제 #29
0
    def __init__(self,
                 base = 'DNSteerableAGRadGalNet',
                 attention_module='SelfAttention',
                 attention_gates=3,
                 attention_aggregation='ft',
                 n_classes=2,
                 attention_normalisation='sigmoid',
                 quiet=True,
                 number_rotations=8,
                 imsize=150,
                 kernel_size=3,
                 group="D"
                ):
        """Build a D_N- or C_N-steerable CNN with up to three attention gates.

        The backbone has four stages of e2cnn ``R2Conv`` / ``ReLU`` /
        ``InnerBatchNorm`` modules with 6, 16, 32 and 64 regular fields. Each
        stage has a 2x2, stride-2 ``PointwiseMaxPool`` and a ``GroupPooling``
        module.

        Args:
            base: Not used by this constructor.
            attention_module: Not used by this constructor.
            attention_gates: Number of ``GridAttentionBlock2D`` gates to build,
                an integer in [0, 3].
            attention_aggregation: One of 'concat', 'mean', 'deep_sup', 'ft'.
                Selects ``self.aggregate`` and the classifier head(s).
            n_classes: Number of output classes.
            attention_normalisation: Normalisation passed to every attention
                gate. One of 'sigmoid', 'range_norm', 'std_mean_norm', 'tanh',
                'softmax'.
            quiet: Not used by this constructor.
            number_rotations: Number N of discrete rotations in the group.
            imsize: Side length of the input image. It sizes the input mask,
                the attention gates and the no-attention classifier.
            kernel_size: Kernel size of every ``R2Conv``. Padding is
                ``kernel_size // 2``.
            group: 'D' for the dihedral group (``FlipRot2dOnR2``) or 'C' for the
                cyclic group (``Rot2dOnR2``); case-insensitive.

        Raises:
            AssertionError: If ``attention_aggregation``,
                ``attention_normalisation``, ``attention_gates`` or ``group`` is
                not one of the accepted values.
        """
        super(DNSteerableAGRadGalNet, self).__init__()
        aggregation_mode = attention_aggregation
        normalisation = attention_normalisation
        AG = int(attention_gates)
        N = int(number_rotations)
        kernel_size = int(kernel_size)
        imsize = int(imsize)
        n_classes = int(n_classes)
        assert aggregation_mode in ['concat', 'mean', 'deep_sup', 'ft'], 'Aggregation mode not recognised. Valid inputs include concat, mean, deep_sup or ft.'
        assert normalisation in ['sigmoid','range_norm','std_mean_norm','tanh','softmax'], 'Normalisation not implemented. Can be any of: sigmoid, range_norm, std_mean_norm, tanh, softmax'
        assert AG in [0,1,2,3], f'Number of Attention Gates applied (AG) must be an integer in range [0,3]. Currently AG={AG}'
        assert group.lower() in ["d","c"], f"group parameter must either be 'D' for DN, or 'C' for CN, steerable networks. (currently {group})."
        filters = [6,16,32,64,128]

        self.attention_out_sizes = []
        self.ag = AG
        self.n_classes = n_classes
        self.filters = filters
        self.aggregation_mode = aggregation_mode

        # Setting up e2: dihedral group D_N (rotations + reflections) or
        # cyclic group C_N (rotations only).
        if group.lower() == "d":
            self.r2_act = gspaces.FlipRot2dOnR2(N=N)
        else:
            self.r2_act = gspaces.Rot2dOnR2(N=N)
        in_type = e2nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
        out_type = e2nn.FieldType(self.r2_act, 6*[self.r2_act.regular_repr])
        self.in_type = in_type

        self.mask = e2nn.MaskModule(in_type, imsize, margin=0)
        self.conv1a = e2nn.R2Conv(in_type,  out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu1a = e2nn.ReLU(out_type); self.bnorm1a= e2nn.InnerBatchNorm(out_type)
        self.conv1b = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu1b = e2nn.ReLU(out_type); self.bnorm1b= e2nn.InnerBatchNorm(out_type)
        self.conv1c = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu1c = e2nn.ReLU(out_type); self.bnorm1c= e2nn.InnerBatchNorm(out_type)
        self.mpool1 = e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2)
        self.gpool1 = e2nn.GroupPooling(out_type)


        in_type = out_type
        out_type = e2nn.FieldType(self.r2_act, 16*[self.r2_act.regular_repr])
        self.conv2a = e2nn.R2Conv(in_type,  out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu2a = e2nn.ReLU(out_type); self.bnorm2a= e2nn.InnerBatchNorm(out_type)
        self.conv2b = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu2b = e2nn.ReLU(out_type); self.bnorm2b= e2nn.InnerBatchNorm(out_type)
        self.conv2c = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu2c = e2nn.ReLU(out_type); self.bnorm2c= e2nn.InnerBatchNorm(out_type)
        self.mpool2 = e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2)
        self.gpool2 = e2nn.GroupPooling(out_type)

        in_type = out_type
        out_type = e2nn.FieldType(self.r2_act, 32*[self.r2_act.regular_repr])
        self.conv3a = e2nn.R2Conv(in_type,  out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu3a = e2nn.ReLU(out_type); self.bnorm3a= e2nn.InnerBatchNorm(out_type)
        self.conv3b = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu3b = e2nn.ReLU(out_type); self.bnorm3b= e2nn.InnerBatchNorm(out_type)
        self.conv3c = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu3c = e2nn.ReLU(out_type); self.bnorm3c= e2nn.InnerBatchNorm(out_type)
        self.mpool3 = e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2)
        self.gpool3 = e2nn.GroupPooling(out_type)

        in_type = out_type
        out_type = e2nn.FieldType(self.r2_act, 64*[self.r2_act.regular_repr])
        self.conv4a = e2nn.R2Conv(in_type,  out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu4a = e2nn.ReLU(out_type); self.bnorm4a= e2nn.InnerBatchNorm(out_type)
        self.conv4b = e2nn.R2Conv(out_type, out_type, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False); self.relu4b = e2nn.ReLU(out_type); self.bnorm4b= e2nn.InnerBatchNorm(out_type)
        self.mpool4 = e2nn.PointwiseMaxPool(out_type, kernel_size=(2,2), stride=2)
        self.gpool4 = e2nn.GroupPooling(out_type)

        self.flatten = nn.Flatten(1)
        self.dropout = nn.Dropout(p=0.5)

        # Attention gates, built only for the requested count. in_channels and
        # input_size presumably match the group-pooled features of stages 3, 2
        # and 1 (32 ch @ imsize//4, 16 ch @ imsize//2, 6 ch @ imsize).
        if self.ag >= 1:
            self.attention1 = GridAttentionBlock2D(in_channels=32, gating_channels=64, inter_channels=64, input_size=[imsize//4,imsize//4], normalisation=normalisation)
        if self.ag >= 2:
            self.attention2 = GridAttentionBlock2D(in_channels=16, gating_channels=64, inter_channels=64, input_size=[imsize//2,imsize//2], normalisation=normalisation)
        if self.ag >= 3:
            self.attention3 = GridAttentionBlock2D(in_channels=6, gating_channels=64, inter_channels=64, input_size=[imsize,imsize], normalisation=normalisation)

        # NOTE(review): input size is hard-coded and not derived from imsize;
        # confirm whether fc1-fc3 are used by forward().
        self.fc1 = nn.Linear(16*5*5,256) #channel_size * width * height
        self.fc2 = nn.Linear(256,256)
        self.fc3 = nn.Linear(256, self.n_classes)
        self.dummy = nn.Parameter(torch.empty(0))

        self.module_order = ['conv1a', 'relu1a', 'bnorm1a', #1->6
                             'conv1b', 'relu1b', 'bnorm1b', #6->6
                             'conv1c', 'relu1c', 'bnorm1c', #6->6
                             'mpool1',
                             'conv2a', 'relu2a', 'bnorm2a', #6->16
                             'conv2b', 'relu2b', 'bnorm2b', #16->16
                             'conv2c', 'relu2c', 'bnorm2c', #16->16
                             'mpool2',
                             'conv3a', 'relu3a', 'bnorm3a', #16->32
                             'conv3b', 'relu3b', 'bnorm3b', #32->32
                             'conv3c', 'relu3c', 'bnorm3c', #32->32
                             'mpool3',
                             'conv4a', 'relu4a', 'bnorm4a', #32->64
                             'conv4b', 'relu4b', 'bnorm4b', #64->64
                             'compatibility_score1',
                             'compatibility_score2']


        #########################
        # Aggregation Strategies
        if self.ag != 0:
            self.attention_filter_sizes = [32, 16, 6]
            # Feature length when the first `ag` attention outputs are concatenated.
            concat_length = sum(self.attention_filter_sizes[:self.ag])
            if aggregation_mode == 'concat':
                self.classifier = nn.Linear(concat_length, self.n_classes)
                self.aggregate = self.aggregation_concat
            else:
                # One linear head per attention gate. nn.ModuleList registers
                # each head as a submodule, so they move with the model.
                self.classifiers = nn.ModuleList(
                    nn.Linear(self.attention_filter_sizes[i], self.n_classes)
                    for i in range(self.ag)
                )
                if aggregation_mode == 'mean':
                    self.aggregate = self.aggregation_sep
                elif aggregation_mode == 'deep_sup':
                    self.classifier = nn.Linear(concat_length, self.n_classes)
                    self.aggregate = self.aggregation_ds
                elif aggregation_mode == 'ft':
                    self.classifier = nn.Linear(self.n_classes*self.ag, self.n_classes)
                    self.aggregate = self.aggregation_ft
                else:
                    raise NotImplementedError
        else:
            # No attention: classify the flattened final feature map directly.
            # Four stride-2 max pools reduce each side to imsize//16, assuming an
            # odd kernel_size so every conv preserves spatial size. The final
            # stage has 64 channels, presumably after group pooling. The size is
            # derived from imsize so that non-default image sizes work.
            self.classifier = nn.Linear((imsize//16)**2*64, self.n_classes)
            self.aggregate = lambda x: self.classifier(self.flatten(x))
    def __init__(self, in_type, out_type):
        """Build ``self.conv``: 3x3 ``enn.R2Conv`` -> ``enn.InnerBatchNorm`` -> ``enn.ReLU``.

        Args:
            in_type: Field type of the input feature map.
            out_type: Field type produced by the convolution. It is also the
                type used by the batch norm and ReLU.
        """
        super(Conv, self).__init__()

        # Unpadded (padding=0), stride-1 3x3 conv: each spatial side shrinks by 2.
        self.conv = enn.SequentialModule(
            enn.R2Conv(in_type, out_type, kernel_size=3, stride=1, padding=0),
            enn.InnerBatchNorm(out_type), enn.ReLU(out_type))