Example #1
import pytest
import torch

# NOTE: the import path below is an assumption; adjust it to wherever
# conv2d_output_shape is actually defined in your code base.
from espnet2.enh.layers.conv_utils import conv2d_output_shape


# A single illustrative parameter set; the original test presumably sweeps
# many more combinations via pytest.mark.parametrize.
@pytest.mark.parametrize(
    "input_dim, kernel_size, stride, padding, dilation",
    [((16, 32), (3, 3), (1, 1), (1, 1), (1, 1))],
)
def test_conv2d_output_shape(input_dim, kernel_size, stride, padding,
                             dilation):
    h, w = conv2d_output_shape(
        input_dim,
        kernel_size=kernel_size,
        stride=stride,
        pad=padding,
        dilation=dilation,
    )
    conv = torch.nn.Conv2d(1,
                           1,
                           kernel_size,
                           stride=stride,
                           padding=padding,
                           dilation=dilation)
    x = torch.rand(1, 1, *input_dim)
    assert conv(x).shape[2:] == (h, w)
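The helper exercised by this test presumably just applies the standard Conv2d output-size arithmetic. Below is a minimal sketch of that formula; the name conv2d_output_shape_sketch and its exact signature are assumptions for illustration, not the library's actual implementation.

def conv2d_output_shape_sketch(input_dim, kernel_size, stride=1, pad=0, dilation=1):
    # Standard Conv2d output-size formula (per axis):
    #   out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    def _pair(x):
        return (x, x) if isinstance(x, int) else tuple(x)

    h_in, w_in = input_dim
    kh, kw = _pair(kernel_size)
    sh, sw = _pair(stride)
    ph, pw = _pair(pad)
    dh, dw = _pair(dilation)
    h_out = (h_in + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    w_out = (w_in + 2 * pw - dw * (kw - 1) - 1) // sw + 1
    return h_out, w_out

For example, conv2d_output_shape_sketch((42, 127), kernel_size=(1, 4), stride=(1, 2), pad=(0, 1)) returns (42, 63): the first (time) axis is untouched while the second is roughly halved, which is exactly the behaviour the DC-CRN encoder in the next example relies on.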
Example #2
    def __init__(
        self,
        input_dim,
        input_channels: List = [2, 16, 32, 64, 128, 256],
        enc_hid_channels=8,
        enc_kernel_size=(1, 3),
        enc_padding=(0, 1),
        enc_last_kernel_size=(1, 4),
        enc_last_stride=(1, 2),
        enc_last_padding=(0, 1),
        enc_layers=5,
        skip_last_kernel_size=(1, 3),
        skip_last_stride=(1, 1),
        skip_last_padding=(0, 1),
        glstm_groups=2,
        glstm_layers=2,
        glstm_bidirectional=False,
        glstm_rearrange=False,
        output_channels=2,
    ):
        """Densely-Connected Convolutional Recurrent Network (DC-CRN).

        Reference: Fig. 3 and Section III-B in [1]

        Args:
            input_dim (int): input feature dimension
            input_channels (list): number of input channels for the stacked
                DenselyConnectedBlock layers.
                Its length should be one more than the number of stacked
                DenselyConnectedBlock layers in the encoder.
                It is recommended to use an even number of channels to avoid an
                AssertionError when `glstm_bidirectional=True`.
            enc_hid_channels (int): common number of intermediate channels for all
                DenselyConnectedBlock of the encoder
            enc_kernel_size (tuple): common kernel size for all DenselyConnectedBlock
                of the encoder
            enc_padding (tuple): common padding for all DenselyConnectedBlock
                of the encoder
            enc_last_kernel_size (tuple): common kernel size for the last Conv layer
                in all DenselyConnectedBlock of the encoder
            enc_last_stride (tuple): common stride for the last Conv layer in all
                DenselyConnectedBlock of the encoder
            enc_last_padding (tuple): common padding for the last Conv layer in all
                DenselyConnectedBlock of the encoder
            enc_layers (int): common total number of Conv layers for all
                DenselyConnectedBlock layers of the encoder
            skip_last_kernel_size (tuple): common kernel size for the last Conv layer
                in all DenselyConnectedBlock of the skip pathways
            skip_last_stride (tuple): common stride for the last Conv layer in all
                DenselyConnectedBlock of the skip pathways
            skip_last_padding (tuple): common padding for the last Conv layer in all
                DenselyConnectedBlock of the skip pathways
            glstm_groups (int): number of groups in each Grouped LSTM layer
            glstm_layers (int): number of Grouped LSTM layers
            glstm_bidirectional (bool): whether to use BLSTM or unidirectional LSTM
                in Grouped LSTM layers
            glstm_rearrange (bool): whether to apply the rearrange operation after each
                grouped LSTM layer
            output_channels (int): number of output channels (must be an even number to
                recover both real and imaginary parts)
        """
        super().__init__()

        assert output_channels % 2 == 0, output_channels
        self.conv_enc = nn.ModuleList()
        # T=42 is an arbitrary dummy time length; the Conv layers must not change
        # the time dimension (checked by the assertion below)
        T = 42
        hidden_sizes = [input_dim]
        hdim = input_dim
        for i in range(1, len(input_channels)):
            self.conv_enc.append(
                DenselyConnectedBlock(
                    in_channels=input_channels[i - 1],
                    out_channels=input_channels[i],
                    hid_channels=enc_hid_channels,
                    kernel_size=enc_kernel_size,
                    padding=enc_padding,
                    last_kernel_size=enc_last_kernel_size,
                    last_stride=enc_last_stride,
                    last_padding=enc_last_padding,
                    layers=enc_layers,
                    transposed=False,
                ))
            tdim, hdim = conv2d_output_shape(
                (T, hdim),
                kernel_size=enc_last_kernel_size,
                stride=enc_last_stride,
                pad=enc_last_padding,
            )
            hidden_sizes.append(hdim)
            assert tdim == T, (tdim, hdim)

        hs = hdim * input_channels[-1]
        assert hs >= glstm_groups, (hs, glstm_groups)
        self.glstm = GLSTM(
            hidden_size=hs,
            groups=glstm_groups,
            layers=glstm_layers,
            bidirectional=glstm_bidirectional,
            rearrange=glstm_rearrange,
        )

        self.skip_pathway = nn.ModuleList()
        self.deconv_dec = nn.ModuleList()
        for i in range(len(input_channels) - 1, 0, -1):
            self.skip_pathway.append(
                DenselyConnectedBlock(
                    in_channels=input_channels[i],
                    out_channels=input_channels[i],
                    hid_channels=enc_hid_channels,
                    kernel_size=enc_kernel_size,
                    padding=enc_padding,
                    last_kernel_size=skip_last_kernel_size,
                    last_stride=skip_last_stride,
                    last_padding=skip_last_padding,
                    layers=enc_layers,
                    transposed=False,
                ))
            # make sure the last two dimensions will not be changed after this layer
            enc_hdim = hidden_sizes[i]
            tdim, hdim = conv2d_output_shape(
                (T, enc_hdim),
                kernel_size=skip_last_kernel_size,
                stride=skip_last_stride,
                pad=skip_last_padding,
            )
            assert tdim == T and hdim == enc_hdim, (tdim, hdim, T, enc_hdim)

            if i != 1:
                out_ch = input_channels[i - 1]
            else:
                out_ch = output_channels
            # make sure the second-to-last (time) dimension will not be changed after this layer
            tdim, hdim = convtransp2d_output_shape(
                (T, enc_hdim),
                kernel_size=enc_last_kernel_size,
                stride=enc_last_stride,
                pad=enc_last_padding,
            )
            assert tdim == T, (tdim, hdim)
            hpadding = hidden_sizes[i - 1] - hdim
            assert hpadding >= 0, (hidden_sizes[i - 1], hdim)
            self.deconv_dec.append(
                DenselyConnectedBlock(
                    in_channels=input_channels[i] * 2,
                    out_channels=out_ch,
                    hid_channels=enc_hid_channels,
                    kernel_size=enc_kernel_size,
                    padding=enc_padding,
                    last_kernel_size=enc_last_kernel_size,
                    last_stride=enc_last_stride,
                    last_padding=enc_last_padding,
                    last_output_padding=(0, hpadding),
                    layers=enc_layers,
                    transposed=True,
                ))

        self.fc_real = nn.Linear(in_features=input_dim, out_features=input_dim)
        self.fc_imag = nn.Linear(in_features=input_dim, out_features=input_dim)
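To make the shape bookkeeping in this constructor concrete, here is a hedged walk-through assuming input_dim=257 (a typical STFT feature size; the value is only an assumption) together with the default enc_last_* settings. The _conv_out helper mirrors the Conv2d output-size formula that conv2d_output_shape is expected to apply.

def _conv_out(size, kernel, stride, pad):
    # floor((size + 2*pad - (kernel - 1) - 1) / stride) + 1, with dilation 1
    return (size + 2 * pad - (kernel - 1) - 1) // stride + 1


T, hdim = 42, 257
hidden_sizes = [hdim]
for _ in range(5):  # five encoder blocks with the default input_channels
    # time axis: kernel 1, stride 1, pad 0 -> unchanged (the `tdim == T` assertion)
    assert _conv_out(T, 1, 1, 0) == T
    # frequency axis: kernel 4, stride 2, pad 1 -> roughly halved per block
    hdim = _conv_out(hdim, 4, 2, 1)
    hidden_sizes.append(hdim)
print(hidden_sizes)  # [257, 128, 64, 32, 16, 8]

Under these assumptions the GLSTM input size becomes hs = hdim * input_channels[-1] = 8 * 256 = 2048, and in the decoder hpadding (passed as last_output_padding) makes up the difference whenever the transposed convolution's doubled frequency size falls one short of the stored encoder size (256 vs. 257 at the outermost layer here).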
Example #3
    def __init__(
            self,
            in_channels,
            out_channels,
            hid_channels=8,
            kernel_size=(1, 3),
            padding=(0, 1),
            last_kernel_size=(1, 4),  # use (1, 4) to alleviate the checkerboard artifacts
            last_stride=(1, 2),
            last_padding=(0, 1),
            last_output_padding=(0, 0),
            layers=5,
            transposed=False,
    ):
        """Densely-Connected Convolutional Block.

        Args:
            in_channels (int): number of input channels
            out_channels (int): number of output channels
            hid_channels (int): number of output channels in intermediate Conv layers
            kernel_size (tuple): kernel size for all but the last Conv layers
            padding (tuple): padding for all but the last Conv layers
            last_kernel_size (tuple): kernel size for the last GluConv layer
            last_stride (tuple): stride for the last GluConv layer
            last_padding (tuple): padding for the last GluConv layer
            last_output_padding (tuple): output padding for the last GluConvTranspose2d
                 (only used when `transposed=True`)
            layers (int): total number of Conv layers
            transposed (bool): True to use GluConvTranspose2d in the last layer
                               False to use GluConv2d in the last layer
        """
        super().__init__()

        assert layers > 1, layers
        self.conv = nn.ModuleList()
        in_channel = in_channels
        # T=42 and D=127 are arbitrary dummy sizes; the intermediate Conv layers
        # must not change the last two dimensions (checked by the assertion below)
        T, D = 42, 127
        hidden_sizes = [D]
        for _ in range(layers - 1):
            self.conv.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channel,
                        hid_channels,
                        kernel_size=kernel_size,
                        stride=(1, 1),
                        padding=padding,
                    ),
                    nn.BatchNorm2d(hid_channels),
                    nn.ELU(inplace=True),
                ))
            in_channel = in_channel + hid_channels
            # make sure the last two dimensions will not be changed after this layer
            tdim, hdim = conv2d_output_shape(
                (T, D),
                kernel_size=kernel_size,
                stride=(1, 1),
                pad=padding,
            )
            hidden_sizes.append(hdim)
            assert tdim == T and hdim == D, (tdim, hdim, T, D)

        if transposed:
            self.conv.append(
                GluConvTranspose2d(
                    in_channel,
                    out_channels,
                    kernel_size=last_kernel_size,
                    stride=last_stride,
                    padding=last_padding,
                    output_padding=last_output_padding,
                ))
        else:
            self.conv.append(
                GluConv2d(
                    in_channel,
                    out_channels,
                    kernel_size=last_kernel_size,
                    stride=last_stride,
                    padding=last_padding,
                ))
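The running in_channel = in_channel + hid_channels update is what makes the block densely connected: each intermediate Conv presumably consumes the concatenation of the block input and all previous intermediate outputs (the concatenation itself would happen in forward(), which is not shown here). A purely illustrative channel count, assuming in_channels=2 (the default first entry of the DC-CRN input_channels above) with the default hid_channels=8 and layers=5:

in_channels, hid_channels, layers = 2, 8, 5
in_channel = in_channels
for i in range(layers - 1):
    print(f"conv {i}: {in_channel} -> {hid_channels} channels")
    in_channel += hid_channels  # the next layer also sees this layer's output
print(f"last GluConv: {in_channel} -> out_channels")
# conv 0:  2 -> 8 channels
# conv 1: 10 -> 8 channels
# conv 2: 18 -> 8 channels
# conv 3: 26 -> 8 channels
# last GluConv: 34 -> out_channels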