def __init__(self, n_channels):
    """Initializes a new ResidualBlock.

    Args:
        n_channels: The number of input and output channels.
    """
    super().__init__()
    # Both convolutions use the same 2x2 kernel with padding 1.
    conv_kwargs = dict(kernel_size=2, padding=1)
    self._input_conv = nn.Conv2d(n_channels, n_channels, **conv_kwargs)
    self._output_conv = nn.Conv2d(n_channels, 2 * n_channels, **conv_kwargs)
    # NOTE(review): activation_fn=nn.Identity() presumably skips the extra
    # nonlinearity inside the gate — confirm against pg_nn.GatedActivation.
    self._activation = pg_nn.GatedActivation(activation_fn=nn.Identity())
def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False):
    """Initializes a new GatedPixelCNNLayer instance.

    Args:
        in_channels: The number of channels in the input.
        out_channels: The number of output channels.
        kernel_size: The size of the (causal) convolutional kernel to use.
            Must be odd.
        mask_center: Whether the 'GatedPixelCNNLayer' is causal. If 'True', the
            center pixel is masked out so the computation only depends on pixels
            to the left and above. The residual connection in the horizontal
            stack is also removed.

    Raises:
        ValueError: If 'kernel_size' is even.
    """
    super().__init__()
    # Validate with a real exception: an 'assert' is stripped under
    # 'python -O', which would silently let an even kernel_size through and
    # break the padding arithmetic below.
    if kernel_size % 2 == 0:
        raise ValueError("kernel_size cannot be even")
    self._in_channels = in_channels
    self._out_channels = out_channels
    self._activation = pg_nn.GatedActivation()
    self._kernel_size = kernel_size
    self._padding = (kernel_size - 1) // 2  # (kernel_size - stride) / 2
    self._mask_center = mask_center

    # Vertical stack convolutions.
    self._vstack_1xN = nn.Conv2d(
        in_channels=self._in_channels,
        out_channels=self._out_channels,
        kernel_size=(1, self._kernel_size),
        padding=(0, self._padding),
    )
    # TODO(eugenhotaj): Is it better to shift down the vstack_Nx1 output
    # instead of adding extra padding to the convolution? When we add extra
    # padding, the cropped output rows will no longer line up with the rows of
    # the vstack_1x1 output.
    self._vstack_Nx1 = nn.Conv2d(
        in_channels=self._out_channels,
        out_channels=2 * self._out_channels,
        kernel_size=(self._kernel_size // 2 + 1, 1),
        padding=(self._padding + 1, 0),
    )
    self._vstack_1x1 = nn.Conv2d(
        in_channels=in_channels, out_channels=2 * out_channels, kernel_size=1
    )

    # 1x1 connection feeding the vertical stack's features into the
    # horizontal stack.
    self._link = nn.Conv2d(
        in_channels=2 * out_channels, out_channels=2 * out_channels, kernel_size=1
    )

    # Horizontal stack convolutions. When mask_center is True, one extra
    # column of padding shifts the receptive field so the center pixel is
    # excluded.
    self._hstack_1xN = nn.Conv2d(
        in_channels=self._in_channels,
        out_channels=2 * self._out_channels,
        kernel_size=(1, self._kernel_size // 2 + 1),
        padding=(0, self._padding + int(self._mask_center)),
    )
    self._hstack_residual = nn.Conv2d(
        in_channels=out_channels, out_channels=out_channels, kernel_size=1
    )
    self._hstack_skip = nn.Conv2d(
        in_channels=out_channels, out_channels=out_channels, kernel_size=1
    )