示例#1
0
    def __init__(self, input_frames, output_frames, kernel_size, N):
        """Rotation-equivariant U-Net on the C_N rotation group.

        Args:
            input_frames: number of input fields, each transforming under
                ``irrep(1)`` of C_N (presumably 2-D vector fields -- TODO
                confirm).
            output_frames: number of ``irrep(1)`` output fields.
            kernel_size: kernel size of the stem, the output layer and the
                ``rot_conv2d`` encoder blocks.
            N: number of discrete rotations of the group.
        """
        super(Unet_Rot, self).__init__()
        r2_act = gspaces.Rot2dOnR2(N = N)
        self.feat_type_in = nn.FieldType(r2_act, input_frames*[r2_act.irrep(1)])
        self.feat_type_in_hid = nn.FieldType(r2_act, 32*[r2_act.regular_repr])
        # Input of the output layer: 16 irrep(1) fields plus the input_frames
        # input fields (presumably the deconv1 output concatenated with the
        # network input in forward(), which is not shown here).
        self.feat_type_hid_out = nn.FieldType(r2_act, (16 + input_frames)*[r2_act.irrep(1)])
        self.feat_type_out = nn.FieldType(r2_act, output_frames*[r2_act.irrep(1)])
        
        # Stem: stride-2 convolution lifting the irrep(1) input to 32 regular
        # fields, followed by batch norm and ReLU.
        self.conv1 = nn.SequentialModule(
            nn.R2Conv(self.feat_type_in, self.feat_type_in_hid, kernel_size = kernel_size, stride = 2, padding = (kernel_size - 1)//2),
            nn.InnerBatchNorm(self.feat_type_in_hid),
            nn.ReLU(self.feat_type_in_hid)
        )

        # Encoder. Widths are numbers of regular fields; the stride-2 stages
        # downsample the feature map.
        self.conv2 = rot_conv2d(32, 64, kernel_size = kernel_size, stride = 1, N = N)
        self.conv2_1 = rot_conv2d(64, 64, kernel_size = kernel_size, stride = 1, N = N)
        self.conv3 = rot_conv2d(64, 128, kernel_size = kernel_size, stride = 2, N = N)
        self.conv3_1 = rot_conv2d(128, 128, kernel_size = kernel_size, stride = 1, N = N)
        self.conv4 = rot_conv2d(128, 256, kernel_size = kernel_size, stride = 2, N = N)
        self.conv4_1 = rot_conv2d(256, 256, kernel_size = kernel_size, stride = 1, N = N)

        # Decoder. The input widths 192 = 64 + 128 and 96 = 32 + 64 suggest each
        # decoder output is concatenated with the conv3_1 / conv2_1 features in
        # forward() -- TODO confirm. deconv1 (last_deconv=True) presumably emits
        # the 16 irrep(1) fields counted in feat_type_hid_out.
        self.deconv3 = rot_deconv2d(256, 64, N)
        self.deconv2 = rot_deconv2d(192, 32, N)
        self.deconv1 = rot_deconv2d(96, 16, N, last_deconv = True)

    
        # Final projection to output_frames irrep(1) fields (no BN/activation).
        self.output_layer = nn.R2Conv(self.feat_type_hid_out, self.feat_type_out, kernel_size = kernel_size, padding = (kernel_size - 1)//2)
示例#2
0
    def __init__(self, channel_in=3, n_classes=4, rot_n=4):
        """Small C_N-equivariant CNN with an invariant MLP classifier.

        Defines three 3x3 ``R2Conv`` layers (input -> 8 regular fields ->
        8 regular fields -> 2 regular fields), pointwise average pooling,
        group pooling and a two-layer MLP with ELU.

        Args:
            channel_in: number of input channels, each a scalar
                (trivial-representation) field.
            n_classes: number of output logits.
            rot_n: number of discrete rotations N of the C_N group.
        """
        super(SmallE2, self).__init__()

        r2_act = gspaces.Rot2dOnR2(N=rot_n)

        self.feat_type_in = nn.FieldType(r2_act,
                                         channel_in * [r2_act.trivial_repr])
        feat_type_hid = nn.FieldType(r2_act, 8 * [r2_act.regular_repr])
        feat_type_out = nn.FieldType(r2_act, 2 * [r2_act.regular_repr])

        # A single BN / ReLU pair on the hidden type; presumably reused after
        # both convin and convhid in forward() (not shown) -- TODO confirm.
        self.bn = nn.InnerBatchNorm(feat_type_hid)
        self.relu = nn.ReLU(feat_type_hid)

        self.convin = nn.R2Conv(self.feat_type_in,
                                feat_type_hid,
                                kernel_size=3)
        self.convhid = nn.R2Conv(feat_type_hid, feat_type_hid, kernel_size=3)
        self.convout = nn.R2Conv(feat_type_hid, feat_type_out, kernel_size=3)

        self.avgpool = nn.PointwiseAvgPool(feat_type_out, 3)
        # Group pooling: one invariant channel per regular output field.
        self.invariant_map = nn.GroupPooling(feat_type_out)

        # Number of channels after group pooling. The MLP input equals this, so
        # the spatial map is presumably reduced to 1x1 before lin_in -- TODO confirm.
        c = self.invariant_map.out_type.size

        self.lin_in = torch.nn.Linear(c, 64)
        self.elu = torch.nn.ELU()
        self.lin_out = torch.nn.Linear(64, n_classes)
示例#3
0
 def __init__(self, 
              input_channels,
              hidden_dim, 
              kernel_size, 
              N # number of discrete rotations of the C_N group
             ): 
     """C_N-equivariant residual block on regular fields.

     Args:
         input_channels: number of regular input fields.
         hidden_dim: number of regular fields produced by every branch.
         kernel_size: kernel size of all convolutions; padding
             ``(kernel_size - 1)//2`` keeps the spatial size for odd kernels.
         N: number of discrete rotations of the group.
     """
     super(rot_resblock, self).__init__()
     
     # Specify symmetry transformation
     r2_act = gspaces.Rot2dOnR2(N = N)
     feat_type_in = nn.FieldType(r2_act, input_channels*[r2_act.regular_repr])
     feat_type_hid = nn.FieldType(r2_act, hidden_dim*[r2_act.regular_repr])
     
     # Main branch, stage 1: conv + BN + ReLU, input -> hidden.
     self.layer1 = nn.SequentialModule(
         nn.R2Conv(feat_type_in, feat_type_hid, kernel_size = kernel_size, padding = (kernel_size - 1)//2),
         nn.InnerBatchNorm(feat_type_hid),
         nn.ReLU(feat_type_hid)
     ) 
     
     # Main branch, stage 2: conv + BN + ReLU, hidden -> hidden.
     self.layer2 = nn.SequentialModule(
         nn.R2Conv(feat_type_hid, feat_type_hid, kernel_size = kernel_size, padding = (kernel_size - 1)//2),
         nn.InnerBatchNorm(feat_type_hid),
         nn.ReLU(feat_type_hid)
     )    
     
     # Projection of the input to hidden_dim fields; presumably the skip branch
     # used when input_channels != hidden_dim (forward() not shown) -- TODO confirm.
     self.upscale = nn.SequentialModule(
         nn.R2Conv(feat_type_in, feat_type_hid, kernel_size = kernel_size, padding = (kernel_size - 1)//2),
         nn.InnerBatchNorm(feat_type_hid),
         nn.ReLU(feat_type_hid)
     )    
     
     # Stored, presumably so forward() can decide whether to use upscale.
     self.input_channels = input_channels
     self.hidden_dim = hidden_dim
示例#4
0
    def __init__(self, in_type, num_classes=10):
        """Invariant classification head built from named submodules.

        Registers, in this order:
          * ``gpool``: ``GroupPooling`` of ``in_type``;
          * ``linear1``: 1x1 ``PlainConv`` to 64 scalar fields, no bias;
          * ``relu1``: in-place ReLU;
          * ``linear2``: 1x1 ``PlainConv`` to ``num_classes`` scalar fields,
            no bias.

        Args:
            in_type: field type of the incoming feature map.
            num_classes: number of output scalar fields (class scores).
        """
        super(ClassificationHead, self).__init__()
        gspace = in_type.gspace
        # Group pooling turns each field into invariant (trivial) channels.
        self.add_module('gpool', nn.GroupPooling(in_type))

        # "Fully connected" layers implemented as 1x1 convolutions acting on
        # the group-pooled scalar fields.
        in_type = self.gpool.out_type
        out_type = nn.FieldType(gspace, 64 * [gspace.trivial_repr])
        self.add_module(
            'linear1',
            sscnn.e2cnn.PlainConv(in_type,
                                  out_type,
                                  kernel_size=1,
                                  padding=0,
                                  bias=False))
        self.add_module('relu1', nn.ReLU(out_type, inplace=True))
        in_type = out_type
        out_type = nn.FieldType(gspace, num_classes * [gspace.trivial_repr])
        self.add_module(
            'linear2',
            sscnn.e2cnn.PlainConv(in_type,
                                  out_type,
                                  kernel_size=1,
                                  padding=0,
                                  bias=False))
示例#5
0
    def give_feat_types(self):
        """Return the field type of every layer, from embedding to output.

        Returns:
            list of ``G_CNN.FieldType`` on ``self.G_act``:

            * first: the embedding type, the direct sum of the trivial
              representation and ``self.context_rep``;
            * one entry per element ``ids`` of ``self.hidden_reps_ids``, built
              from ``self.give_reps_from_ids(ids)``. ``ids`` is a list of
              integers ``[k_1, ..., k_l]`` where ``k_i`` selects ``irrep(k_i)``
              of the rotation group, or the regular representation when
              ``k_i == -1``; the field type is the direct sum of these;
            * last: the direct sum of ``self.target_rep`` and the
              pre-covariance representation from
              ``cov_activ_func.get_pre_cov_rep``.
        """
        act = self.G_act

        embedding = G_CNN.FieldType(act, [act.trivial_repr, self.context_rep])

        hidden = [
            G_CNN.FieldType(act, self.give_reps_from_ids(layer_ids))
            for layer_ids in self.hidden_reps_ids
        ]

        # Output: target (mean) representation plus the representation of the
        # quantity the covariance is computed from.
        cov_rep = cov_activ_func.get_pre_cov_rep(act, self.dim_cov_est)
        output = G_CNN.FieldType(act, [self.target_rep, cov_rep])

        return [embedding, *hidden, output]
示例#6
0
    def __init__(self, in_chan, out_chan, imsize, kernel_size=5, N=8):
        """LeNet whose first convolution is D_N-equivariant.

        Only ``conv1`` is steerable (rotations and reflections of D_N). Its 6
        regular output fields are group-pooled to 6 invariant channels, and
        the rest of the network uses ordinary ``torch.nn`` layers.

        Args:
            in_chan: number of input channels, each treated as a scalar
                (trivial-representation) field.
            out_chan: number of outputs of ``fc3``.
            imsize: side length of the square input; used by the mask and to
                size ``fc1``.
            kernel_size: kernel size of both convolutions.
            N: number of discrete rotations of the dihedral group D_N.
        """
        super(DNRestrictedLeNet, self).__init__()

        # Side length after two 2x downsamplings: pool1 here, and a second one
        # presumably applied in forward() -- TODO confirm. fc1 assumes z x z.
        z = imsize // 2 // 2

        self.r2_act = gspaces.FlipRot2dOnR2(N)

        # One scalar field per input channel. Previously in_chan was ignored and
        # the input type always had exactly one channel.
        in_type = e2nn.FieldType(self.r2_act,
                                 int(in_chan) * [self.r2_act.trivial_repr])
        self.input_type = in_type

        out_type = e2nn.FieldType(self.r2_act, 6 * [self.r2_act.regular_repr])
        self.mask = e2nn.MaskModule(in_type, imsize, margin=1)
        self.conv1 = e2nn.R2Conv(in_type,
                                 out_type,
                                 kernel_size=kernel_size,
                                 padding=kernel_size // 2,
                                 bias=False)
        self.relu1 = e2nn.ReLU(out_type, inplace=True)
        self.pool1 = e2nn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)

        # 6 regular fields -> 6 invariant channels, matching conv2's input.
        self.gpool = e2nn.GroupPooling(out_type)

        self.conv2 = nn.Conv2d(6, 16, kernel_size, padding=kernel_size // 2)

        self.fc1 = nn.Linear(16 * z * z, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_chan)

        self.drop = nn.Dropout(p=0.5)

        # dummy parameter for tracking device
        self.dummy = nn.Parameter(torch.empty(0))
示例#7
0
    def __init__(self, in_type, conv_func):
        """Equivariant regression head ending in a single frequency-1 field.

        Registers, in this order:
          * ``gpool``: ``PointwiseAdaptiveMaxPool`` of ``in_type`` to 1x1;
          * ``block1``: 1x1 ``conv_func`` to ``2 * base`` regular fields + ReLU,
            with ``base`` = 8 for ``Rot2dOnR2`` and 4 for ``FlipRot2dOnR2``;
          * ``block2``: 1x1 ``conv_func`` to one ``irrep(1)`` field
            (``Rot2dOnR2``) or one ``irrep(1, 1)`` field (``FlipRot2dOnR2``).

        Args:
            in_type: field type of the incoming feature map; its ``gspace``
                selects the group-dependent choices above.
            conv_func: convolution constructor, called as
                ``conv_func(in_type, out_type, kernel_size=1, padding=0,
                bias=False)``.

        Raises:
            NotImplementedError: if ``in_type.gspace`` is neither a
                ``Rot2dOnR2`` nor a ``FlipRot2dOnR2``.
        """
        super(RegressionHead, self).__init__()
        gspace = in_type.gspace

        # Resolve every group-dependent choice before building any module. An
        # unsupported gspace then raises NotImplementedError here, instead of
        # failing later with an UnboundLocalError on ``base``.
        if isinstance(gspace, gspaces.Rot2dOnR2):
            base = 8
            final_repr = gspace.irrep(1)
        elif isinstance(gspace, gspaces.FlipRot2dOnR2):
            base = 4
            final_repr = gspace.irrep(1, 1)
        else:
            raise NotImplementedError(
                f"RegressionHead does not support gspace {gspace}")

        self.add_module('gpool', nn.PointwiseAdaptiveMaxPool(in_type, (1, 1)))

        # "Fully connected" layers implemented as 1x1 convolutions.
        hidden_type = nn.FieldType(gspace, 2 * base * [gspace.regular_repr])
        self.add_module(
            'block1',
            nn.SequentialModule(
                conv_func(in_type,
                          hidden_type,
                          kernel_size=1,
                          padding=0,
                          bias=False), nn.ReLU(hidden_type, inplace=True)))

        out_type = nn.FieldType(gspace, [final_repr])
        self.add_module(
            'block2',
            conv_func(hidden_type, out_type, kernel_size=1, padding=0,
                      bias=False))
示例#8
0
 def __init__(self, input_channels, output_channels, kernel_size, stride, N, activation = True, deconv = False, last_deconv = False):
     """C_N-equivariant convolution layer stored in ``self.layer``.

     Args:
         input_channels: number of regular input fields.
         output_channels: number of output fields. These are regular fields,
             except when ``deconv and last_deconv``, where they are ``irrep(1)``
             fields.
         kernel_size: convolution kernel size.
         stride: convolution stride.
         N: number of discrete rotations of the group.
         activation: when ``deconv`` is False, add InnerBatchNorm + ReLU after
             the conv. Ignored when ``deconv`` is True.
         deconv: use ``padding=0`` instead of ``(kernel_size - 1)//2``. The
             layer is still a plain ``R2Conv``; any upsampling presumably
             happens elsewhere -- TODO confirm.
         last_deconv: with ``deconv``, emit ``irrep(1)`` fields with no
             BN/activation.
     """
     super(rot_conv2d, self).__init__()       
     r2_act = gspaces.Rot2dOnR2(N = N)
     
     feat_type_in = nn.FieldType(r2_act, input_channels*[r2_act.regular_repr])
     feat_type_hid = nn.FieldType(r2_act, output_channels*[r2_act.regular_repr])
     if not deconv:
         if activation:
             self.layer = nn.SequentialModule(
                 nn.R2Conv(feat_type_in, feat_type_hid, kernel_size = kernel_size, stride = stride, padding = (kernel_size - 1)//2),
                 nn.InnerBatchNorm(feat_type_hid),
                 nn.ReLU(feat_type_hid)
             ) 
         else:
             self.layer = nn.R2Conv(feat_type_in, feat_type_hid, kernel_size = kernel_size, stride = stride,padding = (kernel_size - 1)//2)
     else:
         if last_deconv:
             # NOTE(review): feat_type_in is rebuilt identically to the one
             # above; only the irrep(1) output type actually differs.
             feat_type_in = nn.FieldType(r2_act, input_channels*[r2_act.regular_repr])
             feat_type_hid = nn.FieldType(r2_act, output_channels*[r2_act.irrep(1)])
             self.layer = nn.R2Conv(feat_type_in, feat_type_hid, kernel_size = kernel_size, stride = stride, padding = 0)
         else:
             self.layer = nn.SequentialModule(
                     nn.R2Conv(feat_type_in, feat_type_hid, kernel_size = kernel_size, stride = stride, padding = 0),
                     nn.InnerBatchNorm(feat_type_hid),
                     nn.ReLU(feat_type_hid)
                 ) 
示例#9
0
    def __init__(self, n_classes=6):
        """C4-equivariant CNN with a group-pooled fully connected classifier.

        Args:
            n_classes: number of output logits.
        """
        super(SteerCNN, self).__init__()

        # the model is equivariant under rotations by multiples of 90 degrees,
        # modelled by the cyclic group C4 (N=4)
        self.r2_act = gspaces.Rot2dOnR2(N=4)

        # the input image is 3 scalar fields, each transforming under the
        # trivial representation
        input_type = nn_e2.FieldType(self.r2_act,
                                     3 * [self.r2_act.trivial_repr])

        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = input_type
        # convolution 1
        # first specify the output type of the convolutional layer
        # we choose 24 feature fields, each transforming under the regular representation of C4
        out_type = nn_e2.FieldType(self.r2_act,
                                   24 * [self.r2_act.regular_repr])
        self.block1 = nn_e2.SequentialModule(
            nn_e2.R2Conv(input_type,
                         out_type,
                         kernel_size=7,
                         padding=3,
                         bias=False), nn_e2.InnerBatchNorm(out_type),
            nn_e2.ReLU(out_type, inplace=True))

        # pointwise average pooling with kernel size 4
        self.pool1 = nn_e2.PointwiseAvgPool(out_type, 4)

        # convolution 2
        # the old output type is the input type to the next layer
        in_type = self.block1.out_type
        # the second convolution keeps the same 24 regular feature fields of C4
        # (the 48-field variant below is disabled)
        #out_type = nn_e2.FieldType(self.r2_act, 48 * [self.r2_act.regular_repr])
        self.block2 = nn_e2.SequentialModule(
            nn_e2.R2Conv(in_type,
                         out_type,
                         kernel_size=7,
                         padding=3,
                         bias=False), nn_e2.InnerBatchNorm(out_type),
            nn_e2.ReLU(out_type, inplace=True))
        # anti-aliased blur (stride 1, no padding), average pooling with kernel
        # size 4, then group pooling to one invariant channel per field (24)
        self.pool2 = nn_e2.SequentialModule(
            nn_e2.PointwiseAvgPoolAntialiased(out_type,
                                              sigma=0.66,
                                              stride=1,
                                              padding=0),
            nn_e2.PointwiseAvgPool(out_type, 4), nn_e2.GroupPooling(out_type))
        # PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=7)

        # number of input features of the MLP: 24 invariant channels times the
        # spatial size after pool2.
        # NOTE(review): 13 x 13 is hard-coded and only holds for the input
        # resolution this model was written for; recompute it if that changes.
        c = 24 * 13 * 13  #self.gpool.out_type.size

        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(c, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )
    def __init__(self, input_shape, num_actions, dueling_DQN):
        """D4-equivariant convolutional Q-network.

        Args:
            input_shape: observation shape; ``input_shape[0]`` is the number of
                channels, each treated as a scalar (trivial) field.
            num_actions: number of discrete actions (size of the Q output).
            dueling_DQN: if true, build separate ``advantage`` and ``value``
                heads; otherwise a single ``actionvalue`` head.
        """
        super(D4_steerable_DQN_Snake, self).__init__()
        self.input_shape = input_shape
        self.num_actions = num_actions
        self.dueling_DQN = dueling_DQN
        # Dihedral group D4: 4 rotations plus reflections.
        self.r2_act = gspaces.FlipRot2dOnR2(N=4)
        self.input_type = nn.FieldType(
            self.r2_act, input_shape[0] * [self.r2_act.trivial_repr])
        # Hidden widths, counted in regular fields of D4.
        feature1_type = nn.FieldType(self.r2_act,
                                     8 * [self.r2_act.regular_repr])
        feature2_type = nn.FieldType(self.r2_act,
                                     12 * [self.r2_act.regular_repr])
        feature3_type = nn.FieldType(self.r2_act,
                                     12 * [self.r2_act.regular_repr])
        feature4_type = nn.FieldType(self.r2_act,
                                     32 * [self.r2_act.regular_repr])

        # Four bias-free conv + ReLU stages with strides 2, 2, 1, 1.
        self.feature_field1 = nn.SequentialModule(
            nn.R2Conv(self.input_type,
                      feature1_type,
                      kernel_size=7,
                      padding=2,
                      stride=2,
                      bias=False), nn.ReLU(feature1_type, inplace=True))
        self.feature_field2 = nn.SequentialModule(
            nn.R2Conv(feature1_type,
                      feature2_type,
                      kernel_size=5,
                      padding=1,
                      stride=2,
                      bias=False), nn.ReLU(feature2_type, inplace=True))
        self.feature_field3 = nn.SequentialModule(
            nn.R2Conv(feature2_type,
                      feature3_type,
                      kernel_size=5,
                      padding=1,
                      stride=1,
                      bias=False), nn.ReLU(feature3_type, inplace=True))

        self.equivariant_features = nn.SequentialModule(
            nn.R2Conv(feature3_type,
                      feature4_type,
                      kernel_size=5,
                      stride=1,
                      bias=False), nn.ReLU(feature4_type, inplace=True))
        # Group pooling over the 32 regular fields of feature4_type.
        self.gpool = nn.GroupPooling(feature4_type)
        # Defined elsewhere in the class (not visible here); presumably records
        # the shape of the conv features -- TODO confirm.
        self.feature_shape()
        if self.dueling_DQN:
            print("You are using Dueling DQN")
            # NOTE(review): the heads take in_features =
            # equivariant_features.out_type.size (the channel count of
            # feature4_type). That is not gpool.out_type.size and is not
            # multiplied by any spatial size; confirm it matches how forward()
            # flattens the features.
            self.advantage = torch.nn.Linear(
                self.equivariant_features.out_type.size, self.num_actions)
            #self.value = torch.nn.Linear(self.gpool.out_type.size, 1)
            self.value = torch.nn.Linear(
                self.equivariant_features.out_type.size, 1)
        else:
            self.actionvalue = torch.nn.Linear(
                self.equivariant_features.out_type.size, self.num_actions)
示例#11
0
    def __init__(self,
                 base='DNSteerableLeNet',
                 in_chan=1,
                 n_classes=2,
                 imsize=150,
                 kernel_size=5,
                 N=8,
                 quiet=True,
                 number_rotations=None):
        """D_N-equivariant LeNet: two steerable conv stages, then an MLP.

        Args:
            base: not used by this constructor.
            in_chan: number of input channels, each treated as a scalar
                (trivial-representation) field.
            n_classes: number of outputs of ``fc3``.
            imsize: side length of the square input; used by the mask and to
                size ``fc1``.
            kernel_size: kernel size of both steerable convolutions.
            N: number of discrete rotations of the dihedral group D_N.
            quiet: not used by this constructor.
            number_rotations: if not None, overrides ``N``.
        """
        super(DNSteerableLeNet, self).__init__()
        kernel_size = int(kernel_size)
        out_chan = int(n_classes)

        if number_rotations is not None:
            N = int(number_rotations)

        # Side length after pool1 and pool2 (each presumably a 2x
        # downsampling); fc1 assumes a square z x z map.
        z = imsize // 2 // 2

        self.r2_act = gspaces.FlipRot2dOnR2(N)

        # One scalar field per input channel. Previously in_chan was ignored and
        # the input type always had exactly one channel.
        in_type = e2nn.FieldType(self.r2_act,
                                 int(in_chan) * [self.r2_act.trivial_repr])
        self.input_type = in_type

        out_type = e2nn.FieldType(self.r2_act, 6 * [self.r2_act.regular_repr])
        self.mask = e2nn.MaskModule(in_type, imsize, margin=1)
        self.conv1 = e2nn.R2Conv(in_type,
                                 out_type,
                                 kernel_size=kernel_size,
                                 padding=kernel_size // 2,
                                 bias=False)
        self.relu1 = e2nn.ReLU(out_type, inplace=True)
        self.pool1 = e2nn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)

        in_type = self.pool1.out_type
        out_type = e2nn.FieldType(self.r2_act, 16 * [self.r2_act.regular_repr])
        self.conv2 = e2nn.R2Conv(in_type,
                                 out_type,
                                 kernel_size=kernel_size,
                                 padding=kernel_size // 2,
                                 bias=False)
        self.relu2 = e2nn.ReLU(out_type, inplace=True)
        self.pool2 = e2nn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)

        # 16 regular fields -> 16 invariant channels, hence 16 * z * z below.
        self.gpool = e2nn.GroupPooling(out_type)

        self.fc1 = nn.Linear(16 * z * z, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_chan)

        self.drop = nn.Dropout(p=0.5)

        # dummy parameter for tracking device
        self.dummy = nn.Parameter(torch.empty(0))
示例#12
0
def build_steer_cnn_decoder(
    context_rep_ids,
    hidden_reps_ids,
    kernel_sizes,
    mean_rep_ids,
    covariance_activation="quadratic",
    N=4,
    flip=True,
    max_frequency=30,
    activation="relu",
    padding_mode="zeros",
):
    """Build a steerable 2-D CNN decoder that outputs a mean and a covariance.

    The input field type is one trivial (scalar) field followed by the context
    representations. The output field type is the mean representation followed
    by the pre-covariance representation from
    ``get_pre_covariance_field_type``.

    Args:
        context_rep_ids: representation ids of the context fields, converted
            with ``reps_from_ids``.
        hidden_reps_ids: one list of representation ids per hidden layer.
        kernel_sizes: forwarded to ``build_steer_cnn_2d``.
        mean_rep_ids: representation ids of the predicted mean.
        covariance_activation: forwarded to ``get_pre_covariance_field_type``.
        N: number of discrete rotations. ``-1`` selects the continuous group,
            which additionally receives ``maximum_frequency=max_frequency``.
        flip: use ``FlipRot2dOnR2`` (rotations and reflections) if true,
            otherwise ``Rot2dOnR2``.
        max_frequency: frequency cut-off; only used when ``N == -1``.
        activation, padding_mode: forwarded to ``build_steer_cnn_2d``.

    Returns:
        The result of ``build_steer_cnn_2d`` for the assembled field types.
    """
    gspace_cls = gspaces.FlipRot2dOnR2 if flip else gspaces.Rot2dOnR2
    if N == -1:
        gspace = gspace_cls(N=N, maximum_frequency=max_frequency)
    else:
        gspace = gspace_cls(N=N)

    in_field_type = gnn.FieldType(
        gspace, [gspace.trivial_repr, *reps_from_ids(gspace, context_rep_ids)])

    hidden_field_types = [
        gnn.FieldType(gspace, reps_from_ids(gspace, ids))
        for ids in hidden_reps_ids
    ]

    mean_field_type = gnn.FieldType(gspace,
                                    reps_from_ids(gspace, mean_rep_ids))

    pre_covariance_field_type = get_pre_covariance_field_type(
        gspace, mean_field_type, covariance_activation)

    out_field_type = mean_field_type + pre_covariance_field_type

    # Initialisation scale factor (a float; a stray trailing comma previously
    # turned it into a 1-tuple).
    # NOTE(review): this value is never passed to build_steer_cnn_2d; confirm
    # whether it should be forwarded or removed.
    scalar_only = mean_rep_ids == [[0]] and context_rep_ids == [[0]]
    init_modify = 0.833 if (scalar_only and N != -1) else 1.0

    return build_steer_cnn_2d(
        in_field_type,
        hidden_field_types,
        kernel_sizes,
        out_field_type,
        gspace,
        activation,
        padding_mode,
    )
示例#13
0
def small_wrn(N=4):
    """Build a small D_N-equivariant wide ResNet followed by a ReLU.

    The backbone maps 3 scalar input fields, through 3 regular hidden fields,
    to 256 scalar output fields, using dropout rate 0.3.

    Args:
        N: number of discrete rotations of the dihedral group D_N.

    Returns:
        ``torch.nn.Sequential`` with children ``wrn`` (the backbone) and
        ``fc`` (a ``torch.nn.ReLU``).
    """
    from collections import OrderedDict
    import torch.nn as nn

    dihedral = gspaces.FlipRot2dOnR2(N)
    scalar_in = enn.FieldType(dihedral, 3 * [dihedral.trivial_repr])
    regular_hidden = enn.FieldType(dihedral, 3 * [dihedral.regular_repr])
    scalar_out = enn.FieldType(dihedral, 256 * [dihedral.trivial_repr])

    backbone = Small_Standalone(
        in_type=scalar_in,
        out_type=scalar_out,
        inner_type=regular_hidden,
        dropout_rate=0.3,
    )

    named_children = OrderedDict()
    named_children['wrn'] = backbone
    named_children['fc'] = nn.ReLU()
    return nn.Sequential(named_children)
    def build_nnet(self, dims, activation_fn=enn.ReLU):
        """Build the equivariant conv network used inside one residual block.

        For each consecutive pair in ``dims`` (zipped with the per-layer
        domains/codomains), appends one ``base_layers.get_equivar_conv2d``
        layer and one ``activation_fn`` layer.

        Args:
            dims: layer widths. ``dims[1]`` divided by ``self.group_card``
                gives the number of regular fields of the first layer.
            activation_fn: equivariant activation constructor, called as
                ``activation_fn(field_type, inplace=True)``.

        Returns:
            ``torch.nn.Sequential`` of alternating conv and activation modules
            (an activation also follows the last conv).
        """
        nnet = []
        # Per-layer domain/codomain norm settings for get_equivar_conv2d
        # (semantics defined by parse_vnorms, not shown here).
        domains, codomains = self.parse_vnorms()
        if self.args.learn_p:
            # Replace them with learnable scalars: one Parameter per layer when
            # args.mixed, otherwise a single Parameter object shared by all
            # layers (list multiplication repeats the same object).
            if self.args.mixed:
                domains = [
                    torch.nn.Parameter(torch.tensor(0.)) for _ in domains
                ]
            else:
                domains = [torch.nn.Parameter(torch.tensor(0.))] * len(domains)
            # Each layer's codomain is the next layer's domain (wrapping around).
            codomains = domains[1:] + [domains[0]]

        # The first layer reads a single scalar (trivial) field.
        in_type = enn.FieldType(self.group_action_type,
                                [self.group_action_type.trivial_repr])
        # First-layer width in regular fields: dims[1] divided by group_card.
        out_dims = int(dims[1:][0] / self.group_card)
        out_type = enn.FieldType(
            self.group_action_type,
            out_dims * [self.group_action_type.regular_repr])
        total_layers = len(domains)
        for i, (in_dim, out_dim, domain, codomain) in enumerate(
                zip(dims[:-1], dims[1:], domains, codomains)):
            nnet.append(
                base_layers.get_equivar_conv2d(
                    in_type,
                    out_type,
                    self.group_action_type,
                    kernel_size=self.args.kernel_size,
                    stride=1,
                    padding=1,
                    coeff=self.args.coeff,
                    n_iterations=self.args.n_lipschitz_iters,
                    atol=self.args.atol,
                    rtol=self.args.rtol,
                    domain=domain,
                    codomain=codomain,
                    # zero-initialise every layer whose nominal width is 2
                    zero_init=(out_dim == 2),
                ))
            nnet.append(activation_fn(nnet[-1].out_type, inplace=True))
            in_type = nnet[-1].out_type
            # Output type for the NEXT iteration: a single scalar field for the
            # layer at index total_layers - 1, regular fields otherwise.
            # NOTE(review): the regular width uses out_dim (this layer's width,
            # dims[i+1]) without dividing by group_card, unlike the first layer,
            # and not dims[i+2]; confirm this is intended.
            if i == total_layers - 2:
                out_type = enn.FieldType(self.group_action_type,
                                         [self.group_action_type.trivial_repr])
            else:
                out_type = enn.FieldType(
                    self.group_action_type,
                    out_dim * [self.group_action_type.regular_repr])

        return torch.nn.Sequential(*nnet)
 def __init__(self,
              args,
              n_blocks,
              input_size,
              hidden_size,
              n_hidden,
              group_action_type=None):
     """Equivariant RealNVP flow with alternating channel masks.

     Args:
         args: config namespace; this constructor reads ``group``,
             ``out_fiber``, ``field_type``, ``act``, ``kernel_size`` and
             ``realnvp_padding``.
         n_blocks: number of coupling blocks.
         input_size: ``(batch, c, h, w)``; requires ``c > 1``.
         hidden_size: forwarded to ``create_equivariant_real_nvp_blocks``.
         n_hidden: stored and forwarded to
             ``create_equivariant_real_nvp_blocks``.
         group_action_type: unused; the gspace comes from
             ``GROUPS[args.group]``.
     """
     super(FiberRealNVP, self).__init__()
     _, self.c, self.h, self.w = input_size[:]
     assert self.c > 1
     # Alternating channel mask [0, 1, 0, 1, ...].
     mask = torch.arange(self.c).float() % 2
     self.n_blocks = int(n_blocks)
     self.n_hidden = n_hidden
     self.group_action_type = GROUPS[args.group]
     self.out_fiber = args.out_fiber
     self.field_type = args.field_type
     # Number of group elements, counted via the gspace's testing elements.
     self.group_card = len(list(self.group_action_type.testing_elements))
     self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     i_mask = 1 - mask
     # Rows alternate mask / complementary mask: 2 * (n_blocks // 2 + 1) rows,
     # i.e. at least one row per block.
     mask = torch.stack([mask,
                         i_mask]).repeat(int(self.n_blocks / 2) + 1, 1)
     # Base distribution class (stored, not instantiated).
     self.p_z = StandardNormal
     # Input: c scalar (trivial-representation) fields.
     self.input_type = enn.FieldType(
         self.group_action_type,
         self.c * [self.group_action_type.trivial_repr])
     self.activation_fn = ACT_FNS[args.act]
     # Presumably the scale (s) and translation (t) networks of the couplings.
     self.s, self.t = create_equivariant_real_nvp_blocks(
         input_size, self.input_type, self.field_type, self.out_fiber,
         self.activation_fn, hidden_size, self.n_blocks, n_hidden,
         self.group_action_type, args.kernel_size, args.realnvp_padding)
     # Non-trainable Parameter, so the mask moves with .to() and is saved in
     # the state_dict.
     self.mask = nn.Parameter(mask, requires_grad=False)
    def check_invariance(self, r2_act, out_type, data, func, data_type=None):
        """Assert that ``func`` is invariant under each test element of ``r2_act``.

        For every ``g`` in ``r2_act.testing_elements``, compares
        ``func(g . data)`` with ``func(data)`` and prints the transformed
        output.

        Args:
            r2_act: gspace whose ``testing_elements`` are iterated.
            out_type: unused.
            data: ``GeometricTensor`` of shape ``(batch, c, h, w)``.
            func: callable mapping a ``GeometricTensor`` to a
                ``GeometricTensor``.
            data_type: unused.

        Raises:
            AssertionError: if the outputs differ by more than ``atol=1e-5``
                for some ``g`` (the element is the assertion message).
        """
        _, c, h, w = data.shape
        # Run transformed inputs on the original input's device (previously
        # hard-coded to .cuda(), which failed on CPU-only machines).
        device = data.tensor.device
        # Use the data's own channel count. The original self.c had to equal
        # c anyway, since GeometricTensor checks the channel count.
        input_type = enn.FieldType(r2_act, c * [r2_act.trivial_repr])
        y = func(data)
        for g in r2_act.testing_elements:
            # The transform is applied to a CPU copy of the input.
            data = enn.GeometricTensor(
                data.tensor.view(-1, c, h, w).cpu(), input_type)
            x_transformed = enn.GeometricTensor(
                data.transform(g).tensor.view(-1, c, h, w).to(device),
                input_type)

            y_from_x_transformed = func(x_transformed)

            # Invariance condition: func(g . x) must equal func(x).
            y_transformed_from_x = y
            data = enn.GeometricTensor(
                data.tensor.squeeze().view(-1, c, h, w).to(device), input_type)
            print_y = y_from_x_transformed.tensor.detach().to(
                'cpu').numpy().squeeze()

            # g is an int for Rot2dOnR2 but a tuple for FlipRot2dOnR2, and
            # "{:4d}" raises TypeError on tuples. Right-aligning str(g) gives
            # the same output for ints.
            print("{:>4} : {}".format(str(g), print_y))

            assert torch.allclose(y_from_x_transformed.tensor.squeeze(),
                                  y_transformed_from_x.tensor.squeeze(),
                                  atol=1e-5), g
        print("Passed Invariance Test")
示例#17
0
def check_equivariance(r2_act, out_type, data, func, data_type=None):
    """Assert ``func(g . x) == g . func(x)`` for each test element of ``r2_act``.

    Inputs and outputs are viewed as ``(-1, 1, 1, 2)`` tensors with a single
    trivial input field. Outputs are wrapped with ``out_type`` before being
    transformed.

    Args:
        r2_act: gspace whose ``testing_elements`` are iterated.
        out_type: field type of ``func``'s output.
        data: input tensor (``.view(-1, 1, 1, 2)`` is applied to it).
        func: function under test. With ``data_type == 'GeomTensor'`` it
            receives and returns ``GeometricTensor``s; otherwise it receives
            and returns plain tensors.
        data_type: ``'GeomTensor'`` to wrap inputs as ``GeometricTensor``.

    Raises:
        AssertionError: if the two sides differ by more than ``atol=1e-5``
            (the failing element is the assertion message).
    """
    input_type = enn.FieldType(r2_act, [r2_act.trivial_repr])
    if data_type == 'GeomTensor':
        data = enn.GeometricTensor(data.view(-1, 1, 1, 2), input_type)
    for g in r2_act.testing_elements:
        # Recomputed every iteration. NOTE(review): from the second element
        # onward ``data`` has been rebuilt from .cpu() copies below, so func
        # then receives CPU inputs.
        output = func(data)
        if data_type == 'GeomTensor':
            # g . func(x)
            rg_output = enn.GeometricTensor(
                output.tensor.view(-1, 1, 1, 2).cpu(), out_type).transform(g)
            data = enn.GeometricTensor(
                data.tensor.view(-1, 1, 1, 2).cpu(), input_type)
            # g . x
            x_transformed = enn.GeometricTensor(
                data.transform(g).tensor.view(-1, 1, 1, 2), input_type)
        else:
            rg_output = enn.GeometricTensor(
                output.view(-1, 1, 1, 2).cpu(), out_type).transform(g)
            data = enn.GeometricTensor(
                data.view(-1, 1, 1, 2).cpu(), input_type)
            x_transformed = data.transform(g).tensor.view(-1, 1, 1, 2)

        # func(g . x)
        output_rg = func(x_transformed)
        # Equivariance Condition
        if data_type == 'GeomTensor':
            output_rg = enn.GeometricTensor(output_rg.tensor.cpu(), out_type)
            data = enn.GeometricTensor(data.tensor.squeeze().view(-1, 1, 1, 2),
                                       input_type)
        else:
            output_rg = enn.GeometricTensor(
                output_rg.view(-1, 1, 1, 2).cpu(), out_type)
            # Restore ``data`` to a plain tensor for the next iteration.
            data = data.tensor.squeeze()
        assert torch.allclose(rg_output.tensor.cpu().squeeze(),
                              output_rg.tensor.squeeze(),
                              atol=1e-5), g
 def __init__(self,
              args,
              n_blocks,
              input_size,
              hidden_size,
              n_hidden,
              group_action_type=None):
     """Equivariant invertible-ResNet flow on 2-D data.

     Args:
         args: config namespace; this constructor reads ``beta``, ``act``,
             ``group``, ``dims``, ``actnorm``, ``batchnorm``, ``n_dist``,
             ``n_power_series``, ``exact_trace``, ``brute_force`` and
             ``batch_size`` (``build_nnet`` reads more).
         n_blocks: number of ``Equivar_iResBlock`` blocks.
         input_size, hidden_size, n_hidden: unused by this constructor.
         group_action_type: unused; the gspace comes from
             ``GROUPS[args.group]``.
     """
     super(EquivariantToyResFlow, self).__init__()
     self.args = args
     self.beta = args.beta
     self.n_blocks = n_blocks
     self.activation_fn = ACT_FNS[args.act]
     self.group_action_type = GROUPS[args.group]
     # self.group_action_type = gspaces.FlipRot2dOnR2(N=4)
     # Number of group elements, counted via the gspace's testing elements.
     self.group_card = len(list(self.group_action_type.testing_elements))
     self.input_type = enn.FieldType(self.group_action_type,
                                     [self.group_action_type.trivial_repr])
     # args.dims is a '-'-separated list of hidden widths, framed by the
     # 2-dimensional input/output.
     dims = [2] + list(map(int, args.dims.split('-'))) + [2]
     blocks = []
     # Optional ActNorm before the first block and after every block; optional
     # moving batch norm after every block.
     if self.args.actnorm: blocks.append(layers.EquivariantActNorm1d(2))
     for _ in range(n_blocks):
         blocks.append(
             layers.Equivar_iResBlock(
                 self.build_nnet(dims, self.activation_fn),
                 n_dist=self.args.n_dist,
                 n_power_series=self.args.n_power_series,
                 exact_trace=self.args.exact_trace,
                 brute_force=self.args.brute_force,
                 n_samples=self.args.batch_size,
                 neumann_grad=True,
                 grad_in_forward=True,
             ))
         if self.args.actnorm: blocks.append(layers.EquivariantActNorm1d(2))
         if self.args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
     self.flow_model = layers.SequentialFlow(blocks)
示例#19
0
def mixed_fiber(gspace: gspaces.GeneralOnR2,
                planes: int,
                ratio: float,
                field_type: int = 0,
                fixparams: bool = True):
    """Build a sorted field type mixing regular and quotient representations.

    ``planes`` channels are converted to a count of regular-sized fields,
    optionally rescaled by ``sqrt(|G| * CHANNELS_CONSTANT)``. A fraction
    ``ratio`` of that count becomes regular fields, and ``2 * (1 - ratio)``
    times the count becomes quotient fields (both truncated to int).

    Args:
        gspace: a ``FlipRot2dOnR2`` (quotient by subgroup ``(0, 1)``) or a
            ``Flip2dOnR2`` (quotient by subgroup ``1``).
        planes: nominal number of channels.
        ratio: share of regular fields.
        field_type: unused.
        fixparams: apply the square-root rescaling.

    Raises:
        ValueError: for any other kind of gspace.
    """
    order = gspace.fibergroup.order()
    assert order > 0

    if isinstance(gspace, gspaces.FlipRot2dOnR2):
        subgroup_id = (0, 1)
    elif isinstance(gspace, gspaces.Flip2dOnR2):
        subgroup_id = 1
    else:
        raise ValueError(f"Space {gspace} not supported")

    quotient = gspace.quotient_repr(subgroup_id)
    regular = gspace.regular_repr

    n_fields = planes / regular.size
    if fixparams:
        n_fields *= math.sqrt(order * CHANNELS_CONSTANT)

    n_regular = int(n_fields * ratio)
    n_quotient = int(2 * n_fields * (1 - ratio))

    representations = [regular] * n_regular + [quotient] * n_quotient
    return enn.FieldType(gspace, representations).sorted()
示例#20
0
def quotient_fiber(gspace: gspaces.GeneralOnR2,
                   planes: int,
                   field_type: int = 0,
                   fixparams: bool = True):
    """ build a quotient fiber with the specified number of channels

    The fiber repeats a fixed block of quotient representations:
      * ``FlipRot2dOnR2``: subgroups ``(axis, 1)`` for
        ``axis in (0, round(n/4), round(n/2))``, where ``n = |G| / 2``;
      * ``Rot2dOnR2``: subgroups ``2`` and ``4`` (requires ``|G| % 4 == 0``);
      * ``Flip2dOnR2``: subgroup ``2``.

    The number of copies is ``planes / block_size``, optionally rescaled by
    ``sqrt(|G| * CHANNELS_CONSTANT)`` and truncated to int. ``field_type`` is
    unused.

    Raises:
        ValueError: for any other kind of gspace.
    """
    order = gspace.fibergroup.order()
    assert order > 0

    if isinstance(gspace, gspaces.FlipRot2dOnR2):
        n_rot = order / 2
        axes = (0, round(n_rot / 4), round(n_rot / 2))
        subgroup_ids = [(int(axis), 1) for axis in axes]
    elif isinstance(gspace, gspaces.Rot2dOnR2):
        assert order % 4 == 0
        subgroup_ids = [2, 4]
    elif isinstance(gspace, gspaces.Flip2dOnR2):
        subgroup_ids = [2]
    else:
        raise ValueError(f"Space {gspace} not supported")

    block = [gspace.quotient_repr(sid) for sid in subgroup_ids]
    block_size = sum(rep.size for rep in block)

    copies = planes / block_size
    if fixparams:
        copies *= math.sqrt(order * CHANNELS_CONSTANT)
    copies = int(copies)
    return enn.FieldType(gspace, block * copies).sorted()
示例#21
0
def create_equivariant_convexp_blocks(input_size,
                                      in_type,
                                      field_type,
                                      out_fiber,
                                      activation_fn,
                                      hidden_size,
                                      n_blocks,
                                      n_hidden,
                                      group_action_type,
                                      kernel_size=3,
                                      padding=1):
    """Build ``n_blocks`` single-convolution equivariant blocks.

    Each block is a ``MultiInputSequential`` wrapping one biased ``enn.R2Conv``
    from ``in_type`` to ``c`` scalar (trivial-representation) fields, where
    ``c`` is the channel count from ``input_size``.

    Args:
        input_size: 4-tuple ``(batch, c, h, w)``; only ``c`` is used.
        in_type: input field type of every block's convolution.
        field_type, out_fiber, activation_fn, hidden_size, n_hidden: unused;
            presumably kept for signature parity with the other block
            builders.
        n_blocks: number of blocks.
        group_action_type: gspace on which the output field type is defined.
        kernel_size, padding: forwarded to ``enn.R2Conv``.

    Returns:
        A ``MultiInputSequential`` of the ``n_blocks`` blocks.
    """
    _, c, _, _ = input_size
    out_type = enn.FieldType(group_action_type,
                             c * [group_action_type.trivial_repr])
    # NOTE(review): every block maps in_type -> out_type. If the outer
    # MultiInputSequential chains its children and in_type != out_type,
    # blocks after the first would receive out_type; confirm callers pass
    # in_type == out_type.
    blocks = [
        MultiInputSequential(
            enn.R2Conv(in_type,
                       out_type,
                       kernel_size=kernel_size,
                       padding=padding,
                       bias=True))
        for _ in range(n_blocks)
    ]
    return MultiInputSequential(*blocks)
示例#22
0
def trivial_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = True):
    """Return a FieldType of scalar (trivial-representation) fields on ``gspace``.

    The number of fields is ``planes``, or ``planes * sqrt(|G|)`` (where ``|G|``
    is the fiber group's order) when ``fixparams`` is true, truncated to int.
    """
    n_fields = planes * math.sqrt(gspace.fibergroup.order()) if fixparams else planes
    return enn.FieldType(gspace, int(n_fields) * [gspace.trivial_repr])
def generate_2d_rot8(out_path, *, n_tasks=10000, batch_size=20, image_size=32,
                     log_every=100):
    """Generate a dataset of random C8-equivariant convolutions and save it.

    A fixed ``R2Conv`` on ``Rot2dOnR2(N=8)`` is built: 1 trivial input field,
    3 regular output fields, 3x3 kernel, no bias. For each task its weights
    are re-drawn with ``generalized_he_init`` and the layer is applied to a
    batch of standard-normal images. Inputs, outputs and weights are stacked
    across tasks and written with ``np.savez`` under keys ``x``, ``y`` and
    ``w``.

    Args:
        out_path: destination passed to ``np.savez``.
        n_tasks: number of weight draws (tasks). Defaults to 10000.
        batch_size: number of images per task. Defaults to 20.
        image_size: height and width of each square input image. Defaults to 32.
        log_every: print progress every ``log_every`` tasks. 0 disables it.
    """
    r2_act = gspaces.Rot2dOnR2(N=8)
    feat_type_in = gnn.FieldType(r2_act, [r2_act.trivial_repr])
    feat_type_out = gnn.FieldType(r2_act, 3 * [r2_act.regular_repr])
    conv = gnn.R2Conv(feat_type_in, feat_type_out, kernel_size=3, bias=False)
    xs, ys, ws = [], [], []
    # generalized_he_init re-draws the weights in place. In-place writes to a
    # Parameter that requires grad are only allowed with autograd disabled,
    # and no gradients are needed here anyway.
    with torch.no_grad():
        for task_idx in range(n_tasks):
            gnn.init.generalized_he_init(conv.weights, conv.basisexpansion)
            inp = gnn.GeometricTensor(
                torch.randn(batch_size, 1, image_size, image_size), feat_type_in)
            result = conv(inp).tensor.detach().cpu().numpy()
            xs.append(inp.tensor.detach().cpu().numpy())
            ys.append(result)
            # .numpy() on a CPU tensor shares its storage. Without .copy(),
            # every entry would alias the (in-place re-initialised) weight
            # buffer and end up holding only the last draw.
            ws.append(conv.weights.detach().cpu().numpy().copy())
            if log_every and task_idx % log_every == 0:
                print(f"Finished generating task {task_idx}")
    xs, ys, ws = np.stack(xs), np.stack(ys), np.stack(ws)
    np.savez(out_path, x=xs, y=ys, w=ws)
示例#24
0
    def __init__(self):
        """Build an 8-fold rotation-invariant dense feature extractor.

        Feature types live on ``gspaces.Rot2dOnR2(8)``. The input is 3 trivial
        (scalar) fields, and every hidden layer uses a type produced by
        ``FIELD_TYPE['regular']``. Equivariant convolutions (``conv3x3`` /
        ``conv5x5``, defined elsewhere) are stacked with ReLUs and pointwise
        pooling. The stack ends with ``enn.GroupPooling``, which pools each
        field over the rotation group to produce rotation-invariant features.

        Sets ``self.gspace``, ``self.input_type``, ``self.num_channels`` and
        ``self.model`` (an ``enn.SequentialModule``).
        """
        super(DenseFeatureExtractionModuleE2Inv, self).__init__()

        # Per-layer widths, doubled by the trailing ``*2``. Only the first 10
        # entries are used below; the last three (512*2) are unused.
        filters = np.array([32,32, 64,64, 128,128,128, 256,256,256, 512,512,512], dtype=np.int32)*2
        
        # number of rotations to consider for rotation invariance
        N = 8
        
        self.gspace = gspaces.Rot2dOnR2(N)
        # Input: 3 trivial (scalar) fields, e.g. the image channels.
        self.input_type = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)
        # ip_op_types[0] is the input type; ip_op_types[1..10] are the output
        # types of the ten conv layers below.
        ip_op_types = [
            self.input_type,
        ]
        
        # NOTE(review): not used inside __init__; presumably read by other
        # methods of the class -- confirm.
        self.num_channels = 64

        # fixparams=False: widths are passed through without the sqrt(|G|)
        # rescaling (assuming FIELD_TYPE['regular'] behaves like
        # regular_feature_type -- TODO confirm).
        for filter_ in filters[:10]:
            ip_op_types.append(FIELD_TYPE['regular'](self.gspace, filter_, fixparams=False))

        # Module order matters: it fixes the SequentialModule indices
        # (state_dict keys) and the order of random weight initialisation.
        self.model = enn.SequentialModule(*[
            # Stage 1: two 3x3 convs, then 2x2 max pooling.
            conv3x3(ip_op_types[0], ip_op_types[1]),
            enn.ReLU(ip_op_types[1], inplace=True),
            conv3x3(ip_op_types[1], ip_op_types[2]),
            enn.ReLU(ip_op_types[2], inplace=True),
            enn.PointwiseMaxPool(ip_op_types[2], 2),

            # Stage 2: two 3x3 convs, then 2x2 max pooling.
            conv3x3(ip_op_types[2], ip_op_types[3]),
            enn.ReLU(ip_op_types[3], inplace=True),
            conv3x3(ip_op_types[3], ip_op_types[4]),
            enn.ReLU(ip_op_types[4], inplace=True),
            enn.PointwiseMaxPool(ip_op_types[4], 2),

            # Stage 3: three 3x3 convs, then a 2x2 average pool with stride 1
            # (smoothing without downsampling).
            conv3x3(ip_op_types[4], ip_op_types[5]),
            enn.ReLU(ip_op_types[5], inplace=True),
            conv3x3(ip_op_types[5], ip_op_types[6]),
            enn.ReLU(ip_op_types[6], inplace=True),
            conv3x3(ip_op_types[6], ip_op_types[7]),
            enn.ReLU(ip_op_types[7], inplace=True),
            enn.PointwiseAvgPool(ip_op_types[7], kernel_size=2, stride=1),

            # Stage 4: three 5x5 convs at the same resolution.
            conv5x5(ip_op_types[7], ip_op_types[8]),
            enn.ReLU(ip_op_types[8], inplace=True),
            conv5x5(ip_op_types[8], ip_op_types[9]),
            enn.ReLU(ip_op_types[9], inplace=True),
            conv5x5(ip_op_types[9], ip_op_types[10]),
            enn.ReLU(ip_op_types[10], inplace=True),
            
            # Disabled alternative to stage 4 (kept for reference): another
            # max-pool followed by 3x3 convs.
            # enn.PointwiseMaxPool(ip_op_types[7], 2),

            # conv3x3(ip_op_types[7], ip_op_types[8]),
            # enn.ReLU(ip_op_types[8], inplace=True),
            # conv3x3(ip_op_types[8], ip_op_types[9]),
            # enn.ReLU(ip_op_types[9], inplace=True),
            # conv3x3(ip_op_types[9], ip_op_types[10]),
            # enn.ReLU(ip_op_types[10], inplace=True),
            # Pool each field over the 8 rotations -> rotation-invariant output.
            enn.GroupPooling(ip_op_types[10])
        ])
示例#25
0
def regular_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = True):
    """Build a regular feature map with roughly ``planes`` channels.

    A regular field has N = ``gspace.fibergroup.order()`` channels, so the
    number of fields is ``planes / N``, truncated to an int. When
    ``fixparams`` is true, ``planes`` is first multiplied by sqrt(N). This is
    presumably meant to keep the parameter count comparable to a
    non-equivariant layer of width ``planes``.

    Args:
        gspace: gspace whose fiber group must be finite (order > 0).
        planes: target number of channels.
        fixparams: apply the sqrt(N) rescaling. Defaults to True, matching
            ``trivial_feature_type``.

    Raises:
        ValueError: if the fiber group is not finite.
    """
    N = gspace.fibergroup.order()
    if N <= 0:
        # Regular representations only exist for finite groups.
        raise ValueError(
            f"regular_feature_type requires a finite group, got order {N}")
    if fixparams:
        planes *= math.sqrt(N)
    planes = int(planes / N)
    return enn.FieldType(gspace, [gspace.regular_repr] * planes)
示例#26
0
 def __init__(self):
     """One dilated C8-equivariant convolution followed by group pooling.

     Maps 3 trivial (scalar) input fields to 16 regular fields of
     ``Rot2dOnR2(8)``. The conv has a 3x3 kernel with dilation 2 (5x5
     receptive field), stride 1 and padding 2, so the spatial size is
     preserved. ``self.invariant`` pools each regular field over the group.
     """
     super(ModelDilated, self).__init__()
     n_rotations = 8
     self.gspace = gspaces.Rot2dOnR2(n_rotations)
     scalar_repr = self.gspace.trivial_repr
     regular_repr = self.gspace.regular_repr
     self.in_type = enn.FieldType(self.gspace, 3 * [scalar_repr])
     self.out_type = enn.FieldType(self.gspace, 16 * [regular_repr])
     self.layer = enn.R2Conv(self.in_type, self.out_type,
                             kernel_size=3,
                             padding=2,
                             stride=1,
                             dilation=2,
                             bias=True)
     self.invariant = enn.GroupPooling(self.out_type)
示例#27
0
def trivial_fiber(gspace: gspaces.GeneralOnR2,
                  planes: int,
                  field_type: int = 0,
                  fixparams: bool = True):
    """Return a fiber of trivial (scalar) fields.

    The field count is ``planes``. When ``fixparams`` is true it is first
    scaled by sqrt(|G| * CHANNELS_CONSTANT). The result is truncated to an
    int. ``field_type`` is unused (presumably kept so all fiber builders
    share one call signature).
    """
    scale_sq = gspace.fibergroup.order() * CHANNELS_CONSTANT if fixparams else None
    width = planes * math.sqrt(scale_sq) if fixparams else planes
    return enn.FieldType(gspace, int(width) * [gspace.trivial_repr])
示例#28
0
def irrep_fiber(gspace: gspaces.GeneralOnR2,
                planes: int,
                field_type: int = 0,
                fixparams: bool = True):
    """Return a fiber made of ``gspace.irrep(0)`` fields.

    Only gspaces whose ``fibergroup.order()`` is negative are accepted
    (checked with an ``assert``); presumably these are continuous groups.
    ``planes`` is truncated to an int and rounded up to an even number. That
    many copies of ``irrep(0)`` form the returned field type.
    ``field_type`` and ``fixparams`` are unused.
    """
    assert gspace.fibergroup.order() < 0
    n_fields = int(planes)
    # Python's n % 2 is 0 or 1 (also for negative n), so this adds one
    # exactly when n_fields is odd.
    n_fields += n_fields % 2
    # NOTE(review): for SO(2), irrep(0) is the trivial irrep, so this fiber
    # holds scalar fields despite its name -- confirm this is intended.
    base_irrep = gspace.irrep(0)
    return enn.FieldType(gspace, n_fields * [base_irrep])
 def _build_net(self, input_size):
     """Build one ``StackediResBlocks`` stage per scale.

     Args:
         input_size: 4-tuple unpacked as ``(_, c, h, w)``, presumably
             ``(batch, channels, height, width)``.

     Returns:
         ``nn.ModuleList`` holding ``self.n_scale`` stacked residual stages.
     """
     _, c, h, w = input_size
     transforms = []
     _stacked_blocks = StackediResBlocks
     # NOTE(review): in_type is never reassigned, so every scale is built
     # with self.input_type as its input type.
     in_type = self.input_type
     my_i_dims = self.intermediate_dim
     # Fiber type with my_i_dims channels, built by the configured fiber
     # builder. Passed as the blocks' second type argument; presumably the
     # intermediate (hidden) fiber -- confirm against StackediResBlocks.
     out_type = FIBERS[self.out_fiber](self.group_action_type,
                                       my_i_dims,
                                       self.field_type,
                                       fixparams=True)
     for i in range(self.n_scale):
         transforms.append(
             _stacked_blocks(
                 in_type,
                 out_type,
                 self.group_action_type,
                 initial_size=(c, h, w),
                 idim=my_i_dims,
                 squeeze=False,  #Can't change channels/fibers
                 # Only the first stage gets the init layer.
                 init_layer=self.init_layer if i == 0 else None,
                 n_blocks=self.n_blocks[i],
                 quadratic=self.quadratic,
                 actnorm=self.actnorm,
                 fc_actnorm=self.fc_actnorm,
                 batchnorm=self.batchnorm,
                 dropout=self.dropout,
                 fc=self.fc,
                 coeff=self.coeff,
                 vnorms=self.vnorms,
                 n_lipschitz_iters=self.n_lipschitz_iters,
                 sn_atol=self.sn_atol,
                 sn_rtol=self.sn_rtol,
                 n_power_series=self.n_power_series,
                 n_dist=self.n_dist,
                 n_samples=self.n_samples,
                 kernels=self.kernels,
                 activation_fn=self.activation_fn,
                 fc_end=self.fc_end,
                 fc_idim=self.fc_idim,
                 n_exact_terms=self.n_exact_terms,
                 preact=self.preact,
                 neumann_grad=self.neumann_grad,
                 grad_in_forward=self.grad_in_forward,
                 first_resblock=self.first_resblock and (i == 0),
                 learn_p=self.learn_p,
             ))
         # Shape bookkeeping as if each scale squeezed (channels x2 or x4,
         # spatial /2). The result feeds the next stage's initial_size and
         # the debug print.
         # NOTE(review): squeeze=False above, so this may not match the real
         # tensor shape -- confirm.
         c, h, w = c * 2 if self.factor_out else c * 4, h // 2, w // 2
         print("C: %d H: %d W: %d" % (c, h, w))
         # NOTE(review): this runs after the last stage has been appended,
         # and the loop then ends, so the new out_type is never used (dead
         # assignment). If the last stage should output self.c trivial
         # fields, this check must happen before the append.
         if i == self.n_scale - 1:
             out_type = enn.FieldType(
                 self.group_action_type,
                 self.c * [self.group_action_type.trivial_repr])
     return nn.ModuleList(transforms)
示例#30
0
 def __init__(self, input_frames, output_frames, kernel_size, N):
     """Build a rotation-equivariant ResNet on ``Rot2dOnR2(N)``.

     Args:
         input_frames: number of input frames; each is one ``irrep(1)`` field
             (the velocity-field representation noted below).
         output_frames: number of output frames, also ``irrep(1)`` fields.
         kernel_size: conv kernel size. Padding ``(kernel_size - 1)//2``
             preserves the spatial size for odd kernel sizes at stride 1.
         N: number of discrete rotations in the symmetry group.

     ``nn`` here appears to be the e2cnn ``nn`` module (FieldType, R2Conv,
     InnerBatchNorm); the layers are chained in ``self.model``, a
     ``torch.nn.Sequential``.
     """
     super(ResNet_Rot, self).__init__()
     r2_act = gspaces.Rot2dOnR2(N = N)
     # we use rho_1 representation since the input is velocity fields 
     self.feat_type_in = nn.FieldType(r2_act, input_frames*[r2_act.irrep(1)])
     # we use regular representation for middle layers
     self.feat_type_in_hid = nn.FieldType(r2_act, 16*[r2_act.regular_repr])
     # 192 regular fields: presumably the output width of the last
     # rot_resblock(192, 192, ...) below -- TODO confirm.
     self.feat_type_hid_out = nn.FieldType(r2_act, 192*[r2_act.regular_repr])
     self.feat_type_out = nn.FieldType(r2_act, output_frames*[r2_act.irrep(1)])
     
     # Lift rho_1 input fields to 16 regular fields: conv -> BN -> ReLU.
     self.input_layer = nn.SequentialModule(
         nn.R2Conv(self.feat_type_in, self.feat_type_in_hid, kernel_size = kernel_size, padding = (kernel_size - 1)//2),
         nn.InnerBatchNorm(self.feat_type_in_hid),
         nn.ReLU(self.feat_type_in_hid)
     )
     # The same input_layer object is also registered inside self.model, so
     # its parameters are shared, not duplicated.
     layers = [self.input_layer]
     # Residual stages widening 16 -> 32 -> 64 -> 128 -> 192; arguments are
     # presumably (in, out) regular-field counts, kernel size and N.
     layers += [rot_resblock(16, 32, kernel_size, N), rot_resblock(32, 32, kernel_size, N)]
     layers += [rot_resblock(32, 64, kernel_size, N), rot_resblock(64, 64, kernel_size, N)]
     layers += [rot_resblock(64, 128, kernel_size, N), rot_resblock(128, 128, kernel_size, N)]
     layers += [rot_resblock(128, 192, kernel_size, N), rot_resblock(192, 192, kernel_size, N)]
     # Project back to output_frames rho_1 (vector) fields.
     layers += [nn.R2Conv(self.feat_type_hid_out, self.feat_type_out, kernel_size = kernel_size, padding = (kernel_size - 1)//2)]
     self.model = torch.nn.Sequential(*layers)