def __init__(self, n_channels):
    """Builds the convolutions and gated activation of a ResidualBlock.

    Args:
      n_channels: Channel count of the block's input (and of its output).
    """
    super().__init__()
    # Both convolutions use a 2x2 kernel with 1 pixel of padding on every
    # side, so their outputs are spatially larger than their inputs
    # (presumably cropped back by the forward pass, which is not shown here).
    shared_conv_args = dict(kernel_size=2, padding=1)
    # Construction order is kept as-is: it fixes parameter registration
    # order and the order in which random initialisation consumes the RNG.
    self._input_conv = nn.Conv2d(n_channels, n_channels, **shared_conv_args)
    # Emits twice as many channels as it receives.
    self._output_conv = nn.Conv2d(
        n_channels, 2 * n_channels, **shared_conv_args
    )
    self._activation = pg_nn.GatedActivation(activation_fn=nn.Identity())
# Example #2
# 0
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 mask_center=False):
        """Initializes a new GatedPixelCNNLayer instance.

        Args:
            in_channels: The number of channels in the input.
            out_channels: The number of output channels.
            kernel_size: The size of the (causal) convolutional kernel to use.
                Must be odd.
            mask_center: Whether the 'GatedPixelCNNLayer' is causal. If 'True', the
                center pixel is masked out so the computation only depends on pixels to
                the left and above. The residual connection in the horizontal stack is
                also removed.

        Raises:
            ValueError: If 'kernel_size' is not odd.
        """
        super().__init__()

        # Explicit check instead of 'assert': assertions are stripped when Python
        # runs with -O, which would silently accept an even kernel size.
        if kernel_size % 2 != 1:
            raise ValueError("kernel_size cannot be even")

        self._in_channels = in_channels
        self._out_channels = out_channels
        self._activation = pg_nn.GatedActivation()
        self._kernel_size = kernel_size
        self._padding = (kernel_size - 1) // 2  # (kernel_size - stride) / 2
        self._mask_center = mask_center

        # Vertical stack convolutions.
        # 1xN: mixes along the width with "same" padding.
        self._vstack_1xN = nn.Conv2d(
            in_channels=self._in_channels,
            out_channels=self._out_channels,
            kernel_size=(1, self._kernel_size),
            padding=(0, self._padding),
        )
        # TODO(eugenhotaj): Is it better to shift down the vstack_Nx1 output
        # instead of adding extra padding to the convolution? When we add extra
        # padding, the cropped output rows will no longer line up with the rows of
        # the vstack_1x1 output.
        # Nx1: a half-height kernel (kernel_size // 2 + 1 rows) with one extra row
        # of padding; presumably forward() crops the output so each row only sees
        # rows above it -- TODO confirm against forward().
        self._vstack_Nx1 = nn.Conv2d(
            in_channels=self._out_channels,
            out_channels=2 * self._out_channels,
            kernel_size=(self._kernel_size // 2 + 1, 1),
            padding=(self._padding + 1, 0),
        )
        self._vstack_1x1 = nn.Conv2d(in_channels=in_channels,
                                     out_channels=2 * out_channels,
                                     kernel_size=1)

        # 1x1 convolution over the 2 * out_channels vertical-stack features.
        self._link = nn.Conv2d(in_channels=2 * out_channels,
                               out_channels=2 * out_channels,
                               kernel_size=1)

        # Horizontal stack convolutions.
        # Half-width kernel; when mask_center is True one extra column of padding
        # is added, which (assuming forward() keeps the leftmost W output columns)
        # shifts the receptive field one pixel left so the center is excluded.
        self._hstack_1xN = nn.Conv2d(
            in_channels=self._in_channels,
            out_channels=2 * self._out_channels,
            kernel_size=(1, self._kernel_size // 2 + 1),
            padding=(0, self._padding + int(self._mask_center)),
        )
        # NOTE(review): created unconditionally; presumably forward() skips the
        # residual when mask_center is True, as the docstring states -- confirm.
        self._hstack_residual = nn.Conv2d(in_channels=out_channels,
                                          out_channels=out_channels,
                                          kernel_size=1)
        self._hstack_skip = nn.Conv2d(in_channels=out_channels,
                                      out_channels=out_channels,
                                      kernel_size=1)