from collections import OrderedDict

import torch.nn as nn
# imports below assume PyTorch and the e2cnn library (e2cnn.gspaces / e2cnn.nn)
from e2cnn.gspaces import Rot2dOnR2
from e2cnn.nn import (FieldType, GeometricTensor, R2Conv, R2ConvTransposed,
                      InnerBatchNorm, GNormBatchNorm, NormBatchNorm, ReLU,
                      NormNonLinearity, GatedNonLinearity1, MultipleModule,
                      GroupPooling, NormPool)


def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # C4-equivariant variant: 4 discrete rotations, regular representations,
    # pointwise batch norm and ReLU.
    self.gspace = Rot2dOnR2(4)
    self.input_type = FieldType(self.gspace, 3 * [self.gspace.trivial_repr])
    self.small_type = FieldType(self.gspace, 4 * [self.gspace.regular_repr])
    self.mid_type = FieldType(self.gspace, 16 * [self.gspace.regular_repr])
    self.model = nn.Sequential(
        R2Conv(self.input_type, self.small_type, kernel_size=3, padding=1, bias=False),
        InnerBatchNorm(self.small_type),
        ReLU(self.small_type),
        R2Conv(self.small_type, self.small_type, kernel_size=3, padding=1, bias=False),
        InnerBatchNorm(self.small_type),
        ReLU(self.small_type),
        R2Conv(self.small_type, self.small_type, kernel_size=3, padding=1, bias=False),
        InnerBatchNorm(self.small_type),
        ReLU(self.small_type),
        R2Conv(self.small_type, self.mid_type, kernel_size=3, padding=1, bias=False),
        InnerBatchNorm(self.mid_type),
        ReLU(self.mid_type),
        R2Conv(self.mid_type, self.small_type, kernel_size=3, padding=1, bias=False),
        InnerBatchNorm(self.small_type),
        ReLU(self.small_type),
    )
    # Pool over the group dimension, then map to a single output channel.
    self.pool = GroupPooling(self.small_type)
    pool_out = self.pool.out_type.size
    self.final = nn.Conv2d(pool_out, 1, kernel_size=3, padding=1)
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # SO(2)-equivariant variant: continuous rotations, irrep fields with
    # norm-based batch norm and nonlinearities.
    self.gspace = Rot2dOnR2(-1, maximum_frequency=2)
    self.input_type = FieldType(self.gspace, 3 * [self.gspace.trivial_repr])
    self.small_type = FieldType(self.gspace, 4 * list(self.gspace.irreps.values()))
    self.mid_type = FieldType(self.gspace, 16 * list(self.gspace.irreps.values()))
    self.model = nn.Sequential(
        R2Conv(self.input_type, self.small_type, kernel_size=3, padding=1, bias=False),
        GNormBatchNorm(self.small_type),
        NormNonLinearity(self.small_type),
        R2Conv(self.small_type, self.small_type, kernel_size=3, padding=1, bias=False),
        GNormBatchNorm(self.small_type),
        NormNonLinearity(self.small_type),
        R2Conv(self.small_type, self.small_type, kernel_size=3, padding=1, bias=False),
        GNormBatchNorm(self.small_type),
        NormNonLinearity(self.small_type),
        R2Conv(self.small_type, self.mid_type, kernel_size=3, padding=1, bias=False),
        GNormBatchNorm(self.mid_type),
        NormNonLinearity(self.mid_type),
        R2Conv(self.mid_type, self.small_type, kernel_size=3, padding=1, bias=False),
        GNormBatchNorm(self.small_type),
        NormNonLinearity(self.small_type),
    )
    self.pool = NormPool(self.small_type)
    pool_out = self.pool.out_type.size
    self.final = nn.Conv2d(pool_out, 1, kernel_size=3, padding=1)
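# Usage sketch (added for illustration, not part of the original code): the two
# constructors above, and the gated variant at the end of this section, are consumed
# the same way, assuming a forward() along these lines. `x` is assumed to be a plain
# (B, 3, H, W) tensor.
def forward(self, x):
    x = GeometricTensor(x, self.input_type)   # wrap the raw tensor with its field type
    x = self.model(x)                         # equivariant feature extraction
    x = self.pool(x)                          # pool to rotation-invariant channels per pixel
    return self.final(x.tensor)               # unwrap and apply the ordinary 1-channel head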
def get_bottleneck(self, name='bottleneck'):
    feat_type, _ = self.features[name]
    return nn.Sequential(OrderedDict({
        f'{name}-conv1': R2Conv(feat_type, feat_type, kernel_size=3, padding=1, bias=False),
        f'{name}-bn1': GNormBatchNorm(feat_type),
        f'{name}-relu1': NormNonLinearity(feat_type, bias=False),
        f'{name}-conv2': R2Conv(feat_type, feat_type, kernel_size=3, padding=1, bias=False),
        f'{name}-bn2': GNormBatchNorm(feat_type),
        f'{name}-relu2': NormNonLinearity(feat_type, bias=False),
    }))
def get_encoder(self, name):
    feat_type_in, feat_type_out = self.features[name]
    return nn.Sequential(OrderedDict({
        f'{name}-conv1': R2Conv(feat_type_in, feat_type_out, kernel_size=3, stride=2, padding=1, bias=False),
        f'{name}-bn1': GNormBatchNorm(feat_type_out),
        f'{name}-relu1': NormNonLinearity(feat_type_out, bias=False),
        f'{name}-conv2': R2Conv(feat_type_out, feat_type_out, kernel_size=3, padding=1, bias=False),
        f'{name}-bn2': GNormBatchNorm(feat_type_out),
        f'{name}-relu2': NormNonLinearity(feat_type_out, bias=False),
    }))
def get_encoder(self, name):
    feat_type_in, feat_type_out = self.features[name]
    return nn.Sequential(OrderedDict({
        f'{name}-conv1': R2Conv(feat_type_in, feat_type_out, kernel_size=3, stride=2, padding=1, bias=False),
        f'{name}-bn1': InnerBatchNorm(feat_type_out),
        f'{name}-relu1': ReLU(feat_type_out, inplace=True),
        # f'{name}-maxpool': PointwiseMaxPool(feat_type_out, kernel_size=3, stride=2, padding=1),
        f'{name}-conv2': R2Conv(feat_type_out, feat_type_out, kernel_size=3, padding=1, bias=False),
        f'{name}-bn2': InnerBatchNorm(feat_type_out),
        f'{name}-relu2': ReLU(feat_type_out, inplace=True),
    }))
def get_bottleneck(self, name='bottleneck'):
    feat_type, _ = self.features[name]
    return nn.Sequential(OrderedDict({
        f'{name}-conv1': R2Conv(feat_type, feat_type, kernel_size=3, padding=1, bias=False),
        f'{name}-bn1': InnerBatchNorm(feat_type),
        f'{name}-relu1': ReLU(feat_type, inplace=True),
        f'{name}-conv2': R2Conv(feat_type, feat_type, kernel_size=3, padding=1, bias=False),
        f'{name}-bn2': InnerBatchNorm(feat_type),
        f'{name}-relu2': ReLU(feat_type, inplace=True),
    }))
def get_decoder(self, name):
    feat_type_in, feat_type_out = self.features[name]
    return nn.Sequential(OrderedDict({
        f'{name}-deconv1': R2ConvTransposed(feat_type_in, feat_type_out, kernel_size=3, stride=2, output_padding=1, bias=False),
        # f'{name}-upsample': R2Upsampling(feat_type_in, scale_factor=2),
        # f'{name}-conv1': R2Conv(feat_type_in, feat_type_out, kernel_size=3, padding=1, bias=False),
        f'{name}-bn1': InnerBatchNorm(feat_type_out),
        f'{name}-relu1': ReLU(feat_type_out, inplace=True),
        f'{name}-conv2': R2Conv(feat_type_out, feat_type_out, kernel_size=3, bias=False),
        # f'{name}-conv2': R2Conv(feat_type_out, feat_type_out, kernel_size=3, padding=1, bias=False),
        f'{name}-bn2': InnerBatchNorm(feat_type_out),
        f'{name}-relu2': ReLU(feat_type_out, inplace=True),
    }))
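# Illustration only (an assumption, not taken from the original code): get_encoder,
# get_bottleneck and get_decoder above read their field types from a `self.features`
# mapping of the form {stage_name: (in_FieldType, out_FieldType)}. A hypothetical
# definition with made-up stage names and field sizes might look like this:
def _example_features(self):
    gspace = Rot2dOnR2(4)
    regular = lambda n: FieldType(gspace, n * [gspace.regular_repr])
    in_type = FieldType(gspace, 3 * [gspace.trivial_repr])
    self.features = {
        'enc1': (in_type, regular(4)),           # used by get_encoder('enc1')
        'enc2': (regular(4), regular(8)),        # used by get_encoder('enc2')
        'bottleneck': (regular(8), regular(8)),  # get_bottleneck() only uses the first entry
        'dec1': (regular(8), regular(4)),        # used by get_decoder('dec1')
    }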
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # SO(2)-equivariant variant with gated nonlinearities: each non-trivial irrep
    # field is paired with one trivial "gate" field produced by the same convolution.
    self.gspace = Rot2dOnR2(-1, maximum_frequency=2)
    self.input_type = FieldType(self.gspace, 3 * [self.gspace.trivial_repr])

    layers = []
    irreps = [v for k, v in self.gspace.irreps.items()
              if k != self.gspace.trivial_repr.name]
    trivials = FieldType(self.gspace, [self.gspace.trivial_repr] * 10)
    gates = FieldType(self.gspace, len(irreps) * [self.gspace.trivial_repr] * 10)
    gated = FieldType(self.gspace, irreps * 10).sorted()
    gate = gates + gated                    # one trivial gate per gated field
    self.small_type = trivials + gate

    layers.append(
        R2Conv(self.input_type, self.small_type, kernel_size=3, padding=1, bias=False)
    )
    # Trivial fields and gates use pointwise batch norm; gated irrep fields use norm batch norm.
    layers.append(
        MultipleModule(
            layers[-1].out_type,
            labels=["trivial"] * (len(trivials) + len(gates)) + ["gated"] * len(gated),
            modules=[
                (InnerBatchNorm(trivials + gates), 'trivial'),
                (NormBatchNorm(gated), 'gated'),
            ])
    )
    # Trivial fields get a pointwise ReLU; (gate, gated) pairs go through the gated nonlinearity.
    layers.append(
        MultipleModule(
            layers[-1].out_type,
            labels=["trivial"] * len(trivials) + ["gate"] * len(gate),
            modules=[
                (ReLU(trivials), 'trivial'),
                (GatedNonLinearity1(gate), 'gate'),
            ])
    )
    self.model = nn.Sequential(*layers)
    self.pool = NormPool(layers[-1].out_type)
    pool_out = self.pool.out_type.size
    self.final = nn.Conv2d(pool_out, 1, kernel_size=3, padding=1)