# Example 1
    def __init__(self, n_channels):
        """Initializes a new CausalResidualBlock instance.

        Args:
            n_channels: The number of input (and output) channels.
        """
        super().__init__()
        # Bottleneck pattern: a 1x1 conv halves the channel count, a masked
        # (causal) 3x3 conv mixes spatially, and a final 1x1 conv restores the
        # original channel count so the block can be used residually.
        bottleneck = n_channels // 2
        layers = [
            nn.ReLU(),
            nn.Conv2d(in_channels=n_channels,
                      out_channels=bottleneck,
                      kernel_size=1),
            nn.ReLU(),
            pg_nn.CausalConv2d(mask_center=False,
                               in_channels=bottleneck,
                               out_channels=bottleneck,
                               kernel_size=3,
                               padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=bottleneck,
                      out_channels=n_channels,
                      kernel_size=1),
        ]
        self._net = nn.Sequential(*layers)
# Example 2
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        n_channels=64,
        n_pixel_snail_blocks=8,
        n_residual_blocks=2,
        attention_key_channels=4,
        attention_value_channels=32,
        sample_fn=None,
    ):
        """Initializes a new PixelSNAIL instance.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            n_channels: Number of channels to use for convolutions.
            n_pixel_snail_blocks: Number of PixelSNAILBlocks.
            n_residual_blocks: Number of residual blocks per PixelSNAILBlock.
            attention_key_channels: Number of channels (dims) for the attention key.
            attention_value_channels: Number of channels (dims) for the attention value.
            sample_fn: See the base class.
        """
        super().__init__(sample_fn)
        # Input layer masks the center pixel so each output only depends on
        # previously generated pixels.
        self._input = pg_nn.CausalConv2d(
            mask_center=True,
            in_channels=in_channels,
            out_channels=n_channels,
            kernel_size=3,
            padding=1,
        )

        def make_block():
            # All PixelSNAIL blocks share the same hyperparameters.
            return PixelSNAILBlock(
                n_channels=n_channels,
                input_img_channels=in_channels,
                n_residual_blocks=n_residual_blocks,
                attention_key_channels=attention_key_channels,
                attention_value_channels=attention_value_channels,
            )

        self._pixel_snail_blocks = nn.ModuleList(
            [make_block() for _ in range(n_pixel_snail_blocks)])
        # Head: project down to half the channels, then to the requested
        # number of output channels, with 1x1 convolutions.
        self._output = nn.Sequential(
            nn.Conv2d(in_channels=n_channels,
                      out_channels=n_channels // 2,
                      kernel_size=1),
            nn.Conv2d(in_channels=n_channels // 2,
                      out_channels=out_channels,
                      kernel_size=1),
        )
# Example 3
    def __init__(self, in_channels=1, out_channels=1, sample_fn=None):
        """Initializes a new TinyCNN instance.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            sample_fn: See the base class.
        """
        super().__init__(sample_fn)
        # A single masked (causal) 3x3 convolution maps inputs directly to
        # outputs; masking the center pixel keeps the model autoregressive.
        self._conv = pg_nn.CausalConv2d(mask_center=True,
                                        in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=3,
                                        padding=1)
# Example 4
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        n_residual=15,
        residual_channels=128,
        head_channels=32,
        sample_fn=None,
    ):
        """Initializes a new PixelCNN instance.

        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            n_residual: The number of residual blocks.
            residual_channels: The number of channels to use in the residual layers.
            head_channels: The number of channels to use in the two 1x1 convolutional
                layers at the head of the network.
            sample_fn: See the base class.
        """
        super().__init__(sample_fn)
        # The trunk of the network runs at twice the residual width.
        hidden_channels = 2 * residual_channels
        # Masked 7x7 input convolution; masking the center keeps causality.
        self._input = pg_nn.CausalConv2d(mask_center=True,
                                         in_channels=in_channels,
                                         out_channels=hidden_channels,
                                         kernel_size=7,
                                         padding=3)
        self._causal_layers = nn.ModuleList([
            CausalResidualBlock(n_channels=hidden_channels)
            for _ in range(n_residual)
        ])
        # Head: two 1x1 convolutions with ReLU nonlinearities project the
        # trunk features down to the requested number of output channels.
        self._head = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_channels,
                      out_channels=head_channels,
                      kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=head_channels,
                      out_channels=out_channels,
                      kernel_size=1),
        )
# Example 5
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        in_size=28,
        n_transformer_blocks=8,
        n_attention_heads=4,
        n_embedding_channels=16,
        sample_fn=None,
    ):
        """Initializes a new ImageGPT instance.

        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            in_size: Size of the input images. Used to create positional encodings.
            n_transformer_blocks: Number of TransformerBlocks to use.
            n_attention_heads: Number of attention heads to use.
            n_embedding_channels: Number of attention embedding channels to use.
            sample_fn: See the base class.
        """
        super().__init__(sample_fn)
        # Learnable positional encoding, zero-initialized, one value per
        # (channel, row, col) position of an input image.
        self._pos = nn.Parameter(torch.zeros(1, in_channels, in_size, in_size))
        # Masked 3x3 input convolution lifts pixels into the embedding space
        # while keeping the model autoregressive.
        self._input = pg_nn.CausalConv2d(mask_center=True,
                                         in_channels=in_channels,
                                         out_channels=n_embedding_channels,
                                         kernel_size=3,
                                         padding=1)
        blocks = [
            TransformerBlock(n_channels=n_embedding_channels,
                             n_attention_heads=n_attention_heads)
            for _ in range(n_transformer_blocks)
        ]
        self._transformer = nn.ModuleList(blocks)
        self._ln = pg_nn.NCHWLayerNorm(n_embedding_channels)
        # Final 1x1 convolution projects embeddings to output channels.
        self._out = nn.Conv2d(in_channels=n_embedding_channels,
                              out_channels=out_channels,
                              kernel_size=1)