def __init__(self, input_size, classes, routings):
        """Build the conv -> primary capsule -> dense capsule -> decoder stack.

        NOTE(review): a second ``__init__`` with the same signature follows
        this one in the file. If both belong to the same class, Python keeps
        only the later definition and this one never runs -- confirm whether
        this version is intentionally left as dead code.

        Args:
            input_size: Indexable with at least three entries. Entry 0 is the
                number of input channels of the first convolution; the
                product of entries 0, 1 and 2 is the decoder's output width.
            classes: Number of output capsules; also scales the decoder's
                input width (``8 * classes``).
            routings: Forwarded unchanged to ``DenseCapsule``.
        """
        super(CapsuleNet, self).__init__()
        self.input_size, self.classes, self.routings = input_size, classes, routings

        conv_channels = 256
        digit_dim = 8

        # Layer 1: an ordinary 2-D convolution over the raw input.
        self.conv1 = nn.Conv2d(
            input_size[0], conv_channels, kernel_size=(6, 24), stride=2, padding=0
        )

        # Layer 2: primary capsules (the original notes describe this as a
        # Conv2D with `squash` activation reshaped to
        # [None, num_caps, dim_caps]).
        self.primarycaps = PrimaryCapsule(
            conv_channels, 256, 32, kernel_size=(6, 12), stride=(2, 1), padding=0
        )

        # Layer 3: capsule layer where the routing algorithm runs.
        # NOTE(review): in_num_caps is hard-coded as 8*21*8; presumably this
        # is tied to one specific input size -- verify.
        self.digitcaps = DenseCapsule(
            in_num_caps=8 * 21 * 8,
            in_dim_caps=32,
            out_num_caps=classes,
            out_dim_caps=digit_dim,
            routings=routings,
        )

        # Decoder: three fully connected layers ending in a sigmoid, sized to
        # emit one value per element of the (first three dims of the) input.
        fc_in = nn.Linear(digit_dim * classes, 512)
        fc_mid = nn.Linear(512, 1024)
        flat_size = input_size[0] * input_size[1] * input_size[2]
        fc_out = nn.Linear(1024, flat_size)
        self.decoder = nn.Sequential(
            fc_in,
            nn.ReLU(inplace=True),
            fc_mid,
            nn.ReLU(inplace=True),
            fc_out,
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU()
    def __init__(self, input_size, classes, routings):
        """Build a dilated-conv feature stack feeding a capsule head and decoder.

        NOTE(review): another ``__init__`` with the same signature appears
        directly before this one in the file. If both belong to the same
        class, only this later definition takes effect -- confirm that the
        earlier one is intentionally dead.

        Args:
            input_size: Indexable with at least three entries; the product of
                entries 0, 1 and 2 is the decoder's output width.
            classes: Number of output capsules; also scales the decoder's
                input width (``16 * classes``).
            routings: Forwarded unchanged to ``DenseCapsule``.
        """
        super().__init__()
        self.input_size = input_size
        self.classes = classes
        self.routings = routings

        n_maps = 45
        n_layers = 13
        use_dilation = True
        digit_dim = 16

        # NOTE(review): the stem hard-codes a single input channel instead of
        # reading input_size[0] -- confirm inputs are always one-channel.
        self.conv0 = nn.Conv2d(1, n_maps, (3, 3), padding=(1, 1), bias=False)
        self.n_layers = n_layers

        # Stack of 3x3 convolutions. With dilation enabled the rate doubles
        # every three layers (1, 1, 1, 2, 2, 2, 4, ...) and padding equals the
        # rate; with it disabled every layer uses rate 1.
        conv_stack = []
        for idx in range(n_layers):
            rate = 2 ** (idx // 3) if use_dilation else 1
            conv_stack.append(
                nn.Conv2d(n_maps, n_maps, (3, 3), padding=rate, dilation=rate, bias=False)
            )
        self.convs = conv_stack

        # The convolutions live in a plain list, so each one is registered
        # explicitly (together with a non-affine batch norm) under the names
        # bn1/conv1 ... bnN/convN.
        for number, conv in enumerate(self.convs, start=1):
            self.add_module("bn{}".format(number), nn.BatchNorm2d(n_maps, affine=False))
            self.add_module("conv{}".format(number), conv)

        # Capsule head: a strided 28x28 convolution plus batch norm, then
        # primary and dense capsules.
        # NOTE(review): in_num_caps is hard-coded as 36*17; presumably it
        # matches one specific input size -- verify.
        # Disabled variants kept for reference: a linear output layer over
        # 12 labels; a 3x3, dilation-16 convF into 30 primary-capsule maps
        # with 1848 input capsules and an extra batch norm; Dropout(0.6) and
        # Dropout2d(0.1).
        self.convF = nn.Conv2d(n_maps, n_maps, kernel_size=(28, 28), stride=2)
        self.batchF = nn.BatchNorm2d(n_maps, affine=False)
        self.primarycaps = PrimaryCapsule(n_maps)
        self.digitcaps = DenseCapsule(
            in_num_caps=36 * 17,
            in_dim_caps=n_maps,
            out_num_caps=classes,
            out_dim_caps=digit_dim,
            routings=routings,
        )

        # Decoder: three fully connected layers ending in a sigmoid, sized to
        # emit one value per element of the (first three dims of the) input.
        fc_in = nn.Linear(digit_dim * classes, 1024)
        fc_mid = nn.Linear(1024, 2048)
        flat_size = input_size[0] * input_size[1] * input_size[2]
        fc_out = nn.Linear(2048, flat_size)
        self.decoder = nn.Sequential(
            fc_in,
            nn.ReLU(inplace=True),
            fc_mid,
            nn.ReLU(inplace=True),
            fc_out,
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU()