def __init__(self,
                 x_channels=128,
                 num_classes=23,
                 depthwise_separable_convolution=True,
                 squeeze_excitation=True):
        """Build the layers of a refinement stage.

        ``self.convs`` holds, in order: three 11x11 convolutions, an
        optional ``SqueezeExcitation`` block, a 1x1 conv that keeps
        ``x_channels`` and a final 1x1 conv down to ``num_classes``.

        Args:
            x_channels: Width of the hidden convolutions.
            num_classes: Output channels of the last 1x1 conv. The first
                11x11 conv expects ``32 + num_classes`` input channels
                (presumably X's 32-channel features concatenated with the
                previous stage's maps -- confirm against ``forward``).
            depthwise_separable_convolution: If true, the 11x11 layers are
                ``DepthwiseSeparableConvolution``; otherwise ``nn.Conv2d``
                with padding from ``same_padding``.
            squeeze_excitation: If true, a ``SqueezeExcitation`` block
                (``ratio=16``) follows the three 11x11 layers.
        """
        super(StageN, self).__init__()

        kernel_size = 11
        channel_plan = (
            (32 + num_classes, x_channels),
            (x_channels, x_channels),
            (x_channels, x_channels),
        )

        def large_conv(c_in, c_out):
            # One 11x11 layer of the requested flavour.
            if depthwise_separable_convolution:
                return DepthwiseSeparableConvolution(in_channels=c_in,
                                                     out_channels=c_out,
                                                     kernel_size=kernel_size)
            return nn.Conv2d(in_channels=c_in,
                             out_channels=c_out,
                             kernel_size=kernel_size,
                             padding=same_padding(kernel_size))

        # Layers are created in the same order they appear in self.convs;
        # reordering construction could change which random initial weights
        # each layer receives under a fixed seed.
        layers = [large_conv(c_in, c_out) for c_in, c_out in channel_plan]
        if squeeze_excitation:
            layers.append(SqueezeExcitation(channels=x_channels, ratio=16))

        layers.append(nn.Conv2d(in_channels=x_channels,
                                out_channels=x_channels,
                                kernel_size=1))
        layers.append(nn.Conv2d(in_channels=x_channels,
                                out_channels=num_classes,
                                kernel_size=1))
        self.convs = nn.ModuleList(layers)

        self.relu = nn.ReLU()
    def __init__(self,
                 x_channels=128,
                 stage_channels=512,
                 num_classes=23,
                 depthwise_separable_convolution=True,
                 squeeze_excitation=True,
                 dilation=1):
        """Build the first stage: an ``X`` trunk followed by a conv head.

        ``self.convs`` holds a 9x9 conv from 32 to ``stage_channels``
        channels, a 1x1 conv that keeps ``stage_channels`` and a final 1x1
        conv down to ``num_classes``.

        Args:
            x_channels: Hidden width forwarded to ``X``.
            stage_channels: Width of the head's hidden layers.
            num_classes: Output channels of the last 1x1 conv.
            depthwise_separable_convolution: Forwarded to ``X``; also picks
                ``DepthwiseSeparableConvolution`` vs ``nn.Conv2d`` for the
                head's 9x9 layer.
            squeeze_excitation: Forwarded to ``X``.
            dilation: Forwarded to ``X``.
        """
        super(Stage1, self).__init__()
        # The trunk is created before the head so parameters are initialised
        # in registration order.
        self.X = X(x_channels,
                   depthwise_separable_convolution=depthwise_separable_convolution,
                   squeeze_excitation=squeeze_excitation,
                   dilation=dilation)

        trunk_out = 32  # X's last conv emits 32 channels
        head_kernel = 9
        if not depthwise_separable_convolution:
            head_in = nn.Conv2d(in_channels=trunk_out,
                                out_channels=stage_channels,
                                kernel_size=head_kernel,
                                padding=same_padding(head_kernel))
        else:
            head_in = DepthwiseSeparableConvolution(in_channels=trunk_out,
                                                    out_channels=stage_channels,
                                                    kernel_size=head_kernel)

        head_hidden = nn.Conv2d(in_channels=stage_channels,
                                out_channels=stage_channels,
                                kernel_size=1)
        head_out = nn.Conv2d(in_channels=stage_channels,
                             out_channels=num_classes,
                             kernel_size=1)
        self.convs = nn.ModuleList([head_in, head_hidden, head_out])

        self.relu = nn.ReLU()
    def __init__(self,
                 in_channels,
                 kernel_size,
                 dilation=1,
                 depthwise_separable_convolution=True):
        """Build a three-layer bottleneck.

        ``conv1`` (1x1) halves the channel count, ``conv2`` (``kernel_size``,
        optionally dilated) restores it to ``in_channels``, and ``conv3``
        (1x1) keeps ``in_channels``.

        Args:
            in_channels: Input and output channel count.
            kernel_size: Kernel size of the middle convolution.
            dilation: Dilation of the middle convolution.
            depthwise_separable_convolution: If true, all three layers are
                ``DepthwiseSeparableConvolution``; otherwise ``nn.Conv2d``
                (the middle one padded via ``same_padding``).
        """
        super(Bottleneck, self).__init__()

        squeezed = in_channels // 2

        if not depthwise_separable_convolution:
            self.conv1 = nn.Conv2d(in_channels, squeezed, 1)
            self.conv2 = nn.Conv2d(squeezed,
                                   in_channels,
                                   kernel_size,
                                   padding=same_padding(kernel_size, dilation),
                                   dilation=dilation)
            self.conv3 = nn.Conv2d(in_channels, in_channels, 1)
        else:
            self.conv1 = DepthwiseSeparableConvolution(in_channels=in_channels,
                                                       out_channels=squeezed,
                                                       kernel_size=1)
            self.conv2 = DepthwiseSeparableConvolution(in_channels=squeezed,
                                                       out_channels=in_channels,
                                                       kernel_size=kernel_size,
                                                       dilation=dilation)
            self.conv3 = DepthwiseSeparableConvolution(in_channels=in_channels,
                                                       out_channels=in_channels,
                                                       kernel_size=1)

        self.relu = nn.ReLU()
    def __init__(self,
                 num_stacks=3,
                 num_blocks=1,
                 num_channels=32,
                 num_classes=23,
                 kernel_size=3,
                 dilation=1,
                 depthwise_separable_convolution=True):
        """Build a stacked-hourglass network.

        Per stack ``i`` this creates an ``Hourglass``, a ``Bottleneck``
        (``intermediate_conv1``) and a 1x1 projection to ``num_classes``
        (``loss_conv``). Every stack except the last also gets a second
        ``Bottleneck`` (``intermediate_conv2``) and a 1x1 projection from
        ``num_classes`` back to ``num_channels`` (``intermediate_conv3``).

        Args:
            num_stacks: Number of hourglass stacks.
            num_blocks: Forwarded to each ``Hourglass``; must be in [1, 7].
            num_channels: Feature width used by the stem and every stack.
            num_classes: Output channels of each ``loss_conv`` projection.
            kernel_size: Kernel size of the stem conv; also forwarded to
                ``Hourglass`` and ``Bottleneck``.
            dilation: Forwarded to ``Hourglass`` and ``Bottleneck``.
            depthwise_separable_convolution: Use
                ``DepthwiseSeparableConvolution`` instead of ``nn.Conv2d``.

        Raises:
            ValueError: If ``num_blocks`` is outside [1, 7].
        """
        super(StackedHourglassNet, self).__init__()

        # Explicit check rather than `assert`: `assert (cond, msg)` tests a
        # non-empty tuple (always true), and plain asserts vanish under -O.
        if not 1 <= num_blocks <= 7:
            raise ValueError(
                "invalid number of blocks [1, 7]: got {}".format(num_blocks))

        self.num_stacks = num_stacks
        self.init_channels = num_channels
        self.channels = num_channels

        def pointwise(c_in, c_out):
            # 1x1 projection of the configured convolution flavour.
            if depthwise_separable_convolution:
                return DepthwiseSeparableConvolution(in_channels=c_in,
                                                     out_channels=c_out,
                                                     kernel_size=1)
            return nn.Conv2d(c_in, c_out, 1)

        if depthwise_separable_convolution:
            self.conv1 = DepthwiseSeparableConvolution(
                in_channels=1, out_channels=self.channels,
                kernel_size=kernel_size)
        else:
            self.conv1 = nn.Conv2d(in_channels=1,
                                   out_channels=self.channels,
                                   kernel_size=kernel_size,
                                   padding=same_padding(kernel_size))
        self.relu = nn.ReLU()

        hgs, intermediate_conv1, intermediate_conv2 = [], [], []
        loss_conv, intermediate_conv3 = [], []
        # Module construction order within each stack is kept fixed
        # (hourglass, bottleneck, loss projection, then the inter-stack
        # layers) so parameter initialisation order is stable.
        for i in range(self.num_stacks):
            hgs.append(
                Hourglass(num_blocks, self.channels, kernel_size, dilation,
                          depthwise_separable_convolution))
            intermediate_conv1.append(
                Bottleneck(self.channels, kernel_size, dilation,
                           depthwise_separable_convolution))
            loss_conv.append(pointwise(self.channels, num_classes))

            # The last stack has no successor to feed, so it skips the
            # inter-stack layers.
            if i < self.num_stacks - 1:
                intermediate_conv2.append(
                    Bottleneck(self.channels, kernel_size, dilation,
                               depthwise_separable_convolution))
                intermediate_conv3.append(pointwise(num_classes, self.channels))

        self.hgs = nn.ModuleList(hgs)
        self.intermediate_conv1 = nn.ModuleList(intermediate_conv1)
        self.intermediate_conv2 = nn.ModuleList(intermediate_conv2)
        self.loss_conv = nn.ModuleList(loss_conv)
        self.intermediate_conv3 = nn.ModuleList(intermediate_conv3)
    def __init__(self,
                 x_channels=128,
                 depthwise_separable_convolution=True,
                 squeeze_excitation=True,
                 dilation=1):
        """Build the ``X`` feature trunk.

        ``self.convs`` always starts with a plain 9x9 ``nn.Conv2d`` from 1 to
        ``x_channels`` channels and ends with a 5x5 layer down to 32
        channels. In between:

        * depthwise-separable mode: two 9x9 ``DepthwiseSeparableConvolution``
          layers;
        * plain mode: four ``nn.Conv2d`` layers with kernel size
          ``calculate_kernel_size(9, dilation)`` and the given ``dilation``.

        With ``squeeze_excitation`` a ``SqueezeExcitation`` block is inserted
        at index 3 of ``self.convs``.

        NOTE(review): ``dilation`` is not used on the depthwise-separable
        path, and that path has fewer layers than the plain one -- confirm
        this asymmetry is intended.
        """
        super(X, self).__init__()

        kernel_size = calculate_kernel_size(9, dilation)

        stem = nn.Conv2d(in_channels=1,
                         out_channels=x_channels,
                         kernel_size=9,
                         padding=same_padding(9))
        layers = [stem]

        if not depthwise_separable_convolution:
            for _ in range(4):
                layers.append(
                    nn.Conv2d(in_channels=x_channels,
                              out_channels=x_channels,
                              kernel_size=kernel_size,
                              padding=same_padding(kernel_size, dilation),
                              dilation=dilation))
            layers.append(
                nn.Conv2d(in_channels=x_channels,
                          out_channels=32,
                          kernel_size=5,
                          padding=same_padding(5)))
        else:
            for _ in range(2):
                layers.append(
                    DepthwiseSeparableConvolution(in_channels=x_channels,
                                                  out_channels=x_channels,
                                                  kernel_size=9))
            layers.append(
                DepthwiseSeparableConvolution(in_channels=x_channels,
                                              out_channels=32,
                                              kernel_size=5))

        if squeeze_excitation:
            # Built after every conv (keeps parameter-init order) and then
            # placed at index 3 of the layer sequence.
            layers.insert(3, SqueezeExcitation(channels=x_channels, ratio=16))

        self.convs = nn.ModuleList(layers)

        # kernel 3, stride 2, padding 1
        self.max_pool = nn.MaxPool2d(3, 2, 1)

        self.relu = nn.ReLU()