    def __init__(self, in_size, out_size, is_batchnorm):
        super(UnetGatingSignal3, self).__init__()
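        # Gating signal: a 1x1x1 conv halves the channel count (with optional
        # BatchNorm), followed by ReLU and adaptive average pooling to a fixed
        # 4x4x4 map; a fully connected layer then maps the flattened features
        # to out_size.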
        self.fmap_size = (4, 4, 4)

        if is_batchnorm:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, in_size // 2, (1, 1, 1), (1, 1, 1),
                          (0, 0, 0)),
                nn.BatchNorm3d(in_size // 2),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool3d(output_size=self.fmap_size),
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, in_size // 2, (1, 1, 1), (1, 1, 1),
                          (0, 0, 0)),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool3d(output_size=self.fmap_size),
            )

        # Fully connected layer mapping the pooled features to out_size.
        self.fc1 = nn.Linear(in_features=(in_size // 2) *
                             self.fmap_size[0] * self.fmap_size[1] *
                             self.fmap_size[2],
                             out_features=out_size,
                             bias=True)

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
    def __init__(self,
                 in_size,
                 out_size,
                 is_batchnorm,
                 kernel_size=(3, 3, 1),
                 padding_size=(1, 1, 0),
                 init_stride=(1, 1, 1)):
        super(UnetConv3, self).__init__()
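        # Two conv -> (BatchNorm) -> ReLU blocks. The default (3, 3, 1) kernel
        # with (1, 1, 0) padding convolves in-plane only, leaving the third
        # (slice) dimension unchanged; init_stride applies to the first conv only.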

        if is_batchnorm:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, out_size, kernel_size, init_stride,
                          padding_size),
                nn.BatchNorm3d(out_size),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),
                nn.BatchNorm3d(out_size),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, out_size, kernel_size, init_stride,
                          padding_size),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),
                nn.ReLU(inplace=True),
            )

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
    def __init__(self, in_size, out_size, is_batchnorm=True):
        super(UnetUp3_CT, self).__init__()
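        # Decoder block: trilinear upsampling by a factor of 2 in all three
        # dimensions, then a 3x3x3 UnetConv3 that expects in_size + out_size
        # input channels (the skip features are presumably concatenated with the
        # upsampled map in the forward pass, which is not shown here).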
        self.conv = UnetConv3(in_size + out_size,
                              out_size,
                              is_batchnorm,
                              kernel_size=(3, 3, 3),
                              padding_size=(1, 1, 1))
        self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear')

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('UnetConv3') != -1: continue
            init_weights(m, init_type='kaiming')
    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
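        # 2D decoder block: a learned ConvTranspose2d upsampling (is_deconv=True)
        # or fixed bilinear upsampling, followed by a two-conv unetConv2 block
        # without batch normalisation.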
        self.conv = unetConv2(in_size, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size,
                                         out_size,
                                         kernel_size=4,
                                         stride=2,
                                         padding=1)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1: continue
            init_weights(m, init_type='kaiming')
    def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True):
        super(UnetUp3, self).__init__()
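        # 3D decoder block: with is_deconv, a (4, 4, 1) ConvTranspose3d with
        # stride (2, 2, 1) upsamples in-plane only and the conv takes in_size
        # channels; otherwise a fixed (2, 2, 1) trilinear upsample is used and
        # the conv takes in_size + out_size channels (sized for concatenated
        # skip features).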
        if is_deconv:
            self.conv = UnetConv3(in_size, out_size, is_batchnorm)
            self.up = nn.ConvTranspose3d(in_size,
                                         out_size,
                                         kernel_size=(4, 4, 1),
                                         stride=(2, 2, 1),
                                         padding=(1, 1, 0))
        else:
            self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm)
            self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear')

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('UnetConv3') != -1: continue
            init_weights(m, init_type='kaiming')
    def __init__(self,
                 in_size,
                 out_size,
                 is_batchnorm,
                 n=2,
                 ks=3,
                 stride=1,
                 padding=1):
        super(unetConv2, self).__init__()
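        # Builds n sequential conv -> (BatchNorm) -> ReLU blocks registered as
        # conv1 .. convn; the first maps in_size -> out_size, the rest keep
        # out_size channels.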
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        if is_batchnorm:
            for i in range(1, n + 1):
                conv = nn.Sequential(
                    nn.Conv2d(in_size, out_size, ks, s, p),
                    nn.BatchNorm2d(out_size),
                    nn.ReLU(inplace=True),
                )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        else:
            for i in range(1, n + 1):
                conv = nn.Sequential(
                    nn.Conv2d(in_size, out_size, ks, s, p),
                    nn.ReLU(inplace=True),
                )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
    def __init__(self,
                 in_size,
                 out_size,
                 kernel_size=(1, 1, 1),
                 is_batchnorm=True):
        super(UnetGridGatingSignal3, self).__init__()
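        # Grid gating signal: a kernel_size (default 1x1x1) convolution with
        # optional BatchNorm and ReLU, projecting in_size to out_size channels
        # to form the gating signal for the attention blocks.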

        if is_batchnorm:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, out_size, kernel_size, (1, 1, 1),
                          (0, 0, 0)),
                nn.BatchNorm3d(out_size),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, out_size, kernel_size, (1, 1, 1),
                          (0, 0, 0)),
                nn.ReLU(inplace=True),
            )

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
    def __init__(self,
                 in_channels,
                 gating_channels,
                 inter_channels=None,
                 dimension=3,
                 mode='concatenation',
                 sub_sample_factor=(1, 1, 1),
                 bn_layer=True,
                 use_W=True,
                 use_phi=True,
                 use_theta=True,
                 use_psi=True,
                 nonlinearity1='relu'):
        super(_GridAttentionBlockND_TORR, self).__init__()
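        # Grid attention block: theta projects the input features and phi the
        # gating signal to inter_channels (strided by sub_sample_factor), psi
        # maps the combined response to a single-channel attention coefficient,
        # and W is a 1x1 output projection with optional BatchNorm. Each
        # projection can be switched off via its use_* flag, in which case it
        # defaults to the identity; the attention computation itself is
        # delegated to the Concatenation operator constructed below.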

        assert dimension in [2, 3]
        assert mode in [
            'concatenation', 'concatenation_softmax', 'concatenation_sigmoid',
            'concatenation_mean', 'concatenation_range_normalise',
            'concatenation_mean_flow'
        ]

        # Default parameter set
        self.mode = mode
        self.dimension = dimension
        self.sub_sample_factor = sub_sample_factor if isinstance(
            sub_sample_factor,
            tuple) else tuple([sub_sample_factor]) * dimension
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = 'trilinear'
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = 'bilinear'
        else:
            raise NotImplementedError

        # initialise id functions
        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.W = lambda x: x
        self.theta = lambda x: x
        self.psi = lambda x: x
        self.phi = lambda x: x
        self.nl1 = lambda x: x

        if use_W:
            if bn_layer:
                self.W = nn.Sequential(
                    conv_nd(in_channels=self.in_channels,
                            out_channels=self.in_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0),
                    bn(self.in_channels),
                )
            else:
                self.W = conv_nd(in_channels=self.in_channels,
                                 out_channels=self.in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        if use_theta:
            self.theta = conv_nd(in_channels=self.in_channels,
                                 out_channels=self.inter_channels,
                                 kernel_size=self.sub_sample_kernel_size,
                                 stride=self.sub_sample_factor,
                                 padding=0,
                                 bias=False)

        if use_phi:
            self.phi = conv_nd(in_channels=self.gating_channels,
                               out_channels=self.inter_channels,
                               kernel_size=self.sub_sample_kernel_size,
                               stride=self.sub_sample_factor,
                               padding=0,
                               bias=False)

        if use_psi:
            self.psi = conv_nd(in_channels=self.inter_channels,
                               out_channels=1,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=True)

        if nonlinearity1 == 'relu':
            self.nl1 = lambda x: F.relu(x, inplace=True)

        if 'concatenation' in mode:
            # self.operation_function = self._concatenation
            self.operation_function = Concatenation(
                W=self.W,
                phi=self.phi,
                psi=self.psi,
                theta=self.theta,
                nl1=self.nl1,
                mode=mode,
                upsample_mode=self.upsample_mode)
        else:
            raise NotImplementedError('Unknown operation function.')

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')

        if use_psi and self.mode == 'concatenation_sigmoid':
            nn.init.constant_(self.psi.bias.data, 3.0)

        if use_psi and self.mode == 'concatenation_softmax':
            nn.init.constant_(self.psi.bias.data, 10.0)

        # if use_psi and self.mode == 'concatenation_mean':
        #     nn.init.constant(self.psi.bias.data, 3.0)

        # if use_psi and self.mode == 'concatenation_range_normalise':
        #     nn.init.constant(self.psi.bias.data, 3.0)

        # Optional DataParallel wrapping of the projections (disabled by default).
        parallel = False
        if parallel:
            if use_W: self.W = nn.DataParallel(self.W)
            if use_phi: self.phi = nn.DataParallel(self.phi)
            if use_psi: self.psi = nn.DataParallel(self.psi)
            if use_theta: self.theta = nn.DataParallel(self.theta)
    def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian',
                 sub_sample_factor=4, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()
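        # Non-local block: g projects the input to inter_channels with a 1x1 conv
        # (theta and phi are added for all modes except 'gaussian'), the pairwise
        # function is selected by `mode`, and W projects back to in_channels.
        # W is zero-initialised so the block's learned contribution starts at
        # zero; g and phi (and theta for 'concat_proper_down') are wrapped with
        # max pooling when any sub_sample_factor > 1.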

        assert dimension in [1, 2, 3]
        assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']

        # print('Dimension: %d, mode: %s' % (dimension, mode))

        self.mode = mode
        self.dimension = dimension
        self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, list) else [sub_sample_factor]

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool = nn.MaxPool3d
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool = nn.MaxPool2d
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool = nn.MaxPool1d
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = None
        self.phi = None

        if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']:
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                                 kernel_size=1, stride=1, padding=0)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)

            if mode in ['concatenation']:
                self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False)
                self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False)
            elif mode in ['concat_proper', 'concat_proper_down']:
                self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1,
                                     padding=0, bias=True)

        if mode == 'embedded_gaussian':
            self.operation_function = self._embedded_gaussian
        elif mode == 'dot_product':
            self.operation_function = self._dot_product
        elif mode == 'gaussian':
            self.operation_function = self._gaussian
        elif mode == 'concatenation':
            self.operation_function = self._concatenation
        elif mode == 'concat_proper':
            self.operation_function = self._concatenation_proper
        elif mode == 'concat_proper_down':
            self.operation_function = self._concatenation_proper_down
        else:
            raise NotImplementedError('Unknown operation function.')

        if any(ss > 1 for ss in self.sub_sample_factor):
            self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor))
            if self.phi is None:
                self.phi = max_pool(kernel_size=sub_sample_factor)
            else:
                self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor))
            if mode == 'concat_proper_down':
                self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor))

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')
    def __init__(self, feature_scale=4, n_classes=21, in_channels=3, is_batchnorm=True, n_convs=None,
                 nonlocal_mode='concatenation', aggregation_mode='concat'):
        super(resnet_grid_attention, self).__init__()
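        # Attention-gated classifier: a pretrained ResNet-18 backbone (adapted
        # below to 4-channel input) supplies the encoder features, two
        # AttentionBlock2D modules gate the layer2 and layer3 features with the
        # layer4 output, and an aggregation head selected by aggregation_mode
        # (concat / mean / deep_sup / ft) produces the class scores.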
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale
        self.n_classes = n_classes
        self.aggregation_mode = aggregation_mode
        self.deep_supervised = True

        # ResNet-18 backbone pretrained on ImageNet; conv1 is rebuilt for
        # 4-channel input by appending a copy of the pretrained weights of the
        # first input channel.
        self.resnet = torchvision.models.resnet18(pretrained=True)
        conv1_weight = self.resnet.conv1.weight.data
        self.resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.resnet.conv1.weight = torch.nn.Parameter(torch.cat((conv1_weight, conv1_weight[:, :1, :, :]), dim=1))
        filters = [64, 64, 128, 256, 512]
        # filters = [int(x / self.feature_scale) for x in filters]

        # Feature extraction: reuse the ResNet stages as encoder blocks
        self.conv1 = self.resnet.conv1
        self.conv2 = self.resnet.layer1
        self.conv3 = self.resnet.layer2
        self.conv4 = self.resnet.layer3
        self.conv5 = self.resnet.layer4


        # Alternative encoder (disabled): the original unetConv2 + MaxPool2d
        # stages, using n_convs convolutions per block and
        # filters = [64, 128, 256, 512] (optionally divided by feature_scale).

        ################
        # Attention Maps
        self.compatibility_score1 = AttentionBlock2D(in_channels=filters[2], gating_channels=filters[4],
                                                     inter_channels=filters[4], sub_sample_factor=(1,1),
                                                     mode=nonlocal_mode, use_W=False, use_phi=True,
                                                     use_theta=True, use_psi=True, nonlinearity1='relu')

        self.compatibility_score2 = AttentionBlock2D(in_channels=filters[3], gating_channels=filters[4],
                                                     inter_channels=filters[4], sub_sample_factor=(1,1),
                                                     mode=nonlocal_mode, use_W=False, use_phi=True,
                                                     use_theta=True, use_psi=True, nonlinearity1='relu')

        #########################
        # Aggregation Strategies
        self.attention_filter_sizes = [filters[2], filters[3]]

        if aggregation_mode == 'concat':
            self.classifier = nn.Linear(filters[2] + filters[3] + filters[4], n_classes)
            self.aggregate = self.aggreagation_concat

        else:
            self.classifier1 = nn.Linear(filters[2], n_classes)
            self.classifier2 = nn.Linear(filters[3], n_classes)
            self.classifier3 = nn.Linear(filters[4], n_classes)
            self.classifiers = [self.classifier1, self.classifier2, self.classifier3]

            if aggregation_mode == 'mean':
                # self.aggregate = self.aggregation_sep
                self.aggregate = Aggregation_sep(self.classifiers)

            elif aggregation_mode == 'deep_sup':
                self.classifier = nn.Linear(filters[2] + filters[3] + filters[4], n_classes)
                # self.aggregate = self.aggregation_ds
                self.aggregate = Aggregation_ds(self.classifiers, self.classifier)

            elif aggregation_mode == 'ft':
                self.classifier = nn.Linear(n_classes*3, n_classes)
                # self.aggregate = self.aggregation_ft
                self.aggregate = Aggregation_ft(self.classifiers, self.classifier)
            else:
                raise NotImplementedError

        ####################
        # Re-initialise every Conv2d and BatchNorm2d module via init_weights.
        # Note: self.modules() also walks the pretrained ResNet layers, so their
        # ImageNet weights are overwritten by this loop as well.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')