def __init__(
        self,
        in_channels=1,
        out_channels=1,
        n_residual=15,
        residual_channels=128,
        head_channels=32,
        sample_fn=None,
    ):
        """Assembles the PixelCNN layers: input conv, residual stack and head.

        Args:
            in_channels: Channel count of the input.
            out_channels: Channel count produced by the last head layer.
            n_residual: How many MaskedResidualBlocks to stack.
            residual_channels: Base width of the residual stack; the input
                convolution and every residual block run at twice this value.
            head_channels: Channel count between the two ARMA2d head layers.
            sample_fn: Forwarded unchanged to the base-class initializer.
        """
        super().__init__(sample_fn)

        # NOTE(review): `w_ksz2`, `a_ksz2` and `init` are neither parameters nor
        # locals of this method; presumably they are module-level settings --
        # confirm they are defined wherever this class lives.
        def _arma_layer(n_in, n_out):
            """Creates one ARMA2d head layer using the shared kernel settings."""
            return ARMA2d(
                n_in,
                n_out,
                w_stride=1,
                w_kernel_size=w_ksz2,
                w_padding=w_ksz2 // 2,
                a_init=init,
                a_kernel_size=a_ksz2,
                a_padding=a_ksz2 // 2,
            )

        self._input = pg_nn.MaskedConv2d(
            is_causal=True,
            in_channels=in_channels,
            out_channels=2 * residual_channels,
            kernel_size=7,
            padding=3,
        )

        residual_stack = [
            MaskedResidualBlock(n_channels=2 * residual_channels)
            for _ in range(n_residual)
        ]
        self._masked_layers = nn.ModuleList(residual_stack)

        # The ARMA2d layers stand in for the plain kernel_size=1 nn.Conv2d
        # layers that previously formed the head (same in/out channel counts).
        head_layers = [
            nn.ReLU(),
            _arma_layer(2 * residual_channels, head_channels),
            nn.ReLU(),
            _arma_layer(head_channels, out_channels),
        ]
        self._head = nn.Sequential(*head_layers)
Exemple #2
0
    def __init__(self, n_channels):
        """Builds the ReLU/1x1 -> ReLU/masked 3x3 -> ReLU/1x1 bottleneck stack.

        Args:
            n_channels: Channel count at the block's input and output. The
                inner layers run at n_channels // 2.
        """
        super().__init__()
        narrow = n_channels // 2
        stages = [
            nn.ReLU(),
            # Squeeze to half the width with a 1x1 convolution.
            nn.Conv2d(in_channels=n_channels, out_channels=narrow, kernel_size=1),
            nn.ReLU(),
            pg_nn.MaskedConv2d(
                is_causal=False,
                in_channels=narrow,
                out_channels=narrow,
                kernel_size=3,
                padding=1,
            ),
            nn.ReLU(),
            # Expand back to the full width with a 1x1 convolution.
            nn.Conv2d(in_channels=narrow, out_channels=n_channels, kernel_size=1),
        ]
        self._net = nn.Sequential(*stages)
Exemple #3
0
 def __init__(self,
              in_channels=1,
              out_channels=1,
              in_size=28,
              n_transformer_blocks=8,
              n_attention_heads=4,
              n_embedding_channels=16,
              sample_fn=None):
     """Creates the layers of an ImageGPT model.

     Args:
       in_channels: Channel count of the input.
       out_channels: Channel count emitted by the final 1x1 convolution.
       in_size: Height and width used for the learned `_pos` parameter, which
         has shape (1, in_channels, in_size, in_size).
       n_transformer_blocks: How many TransformerBlocks to stack.
       n_attention_heads: Attention head count handed to each TransformerBlock.
       n_embedding_channels: Channel width of the embedding produced by the
         input convolution and used by every TransformerBlock.
       sample_fn: Forwarded unchanged to the base-class initializer.
     """
     super().__init__(sample_fn)

     # Zero-initialized learnable tensor, one value per input position/channel.
     initial_pos = torch.zeros(1, in_channels, in_size, in_size)
     self._pos = nn.Parameter(initial_pos)

     self._input = pg_nn.MaskedConv2d(
         is_causal=True,
         in_channels=in_channels,
         out_channels=n_embedding_channels,
         kernel_size=3,
         padding=1,
     )

     blocks = [
         TransformerBlock(
             n_channels=n_embedding_channels,
             n_attention_heads=n_attention_heads,
         )
         for _ in range(n_transformer_blocks)
     ]
     self._transformer = nn.ModuleList(blocks)

     self._ln = pg_nn.NCHWLayerNorm(n_embedding_channels)
     self._out = nn.Conv2d(
         in_channels=n_embedding_channels,
         out_channels=out_channels,
         kernel_size=1,
     )
    def __init__(self,
                 in_channels=3,
                 out_dim=1,
                 probs_fn=torch.sigmoid,
                 sample_fn=lambda x: distributions.Bernoulli(probs=x).sample(),
                 n_channels=64,
                 n_pixel_snail_blocks=8,
                 n_residual_blocks=2,
                 attention_key_channels=4,
                 attention_value_channels=32,
                 head_channels=1):
        """Creates the layers of a PixelSNAIL model.

        Args:
          in_channels: Channel count of the input image.
          out_dim: Output dimension per input channel; the last convolution
            emits out_dim * in_channels channels.
          probs_fn: Forwarded unchanged to the base-class initializer.
          sample_fn: Forwarded unchanged to the base-class initializer.
          n_channels: Channel width used by the input convolution and by every
            PixelSNAILBlock.
          n_pixel_snail_blocks: How many PixelSNAILBlocks to stack.
          n_residual_blocks: Residual block count handed to each
            PixelSNAILBlock.
          attention_key_channels: Attention key width handed to each
            PixelSNAILBlock.
          attention_value_channels: Attention value width handed to each
            PixelSNAILBlock.
          head_channels: Channel count between the two 1x1 output convolutions.
        """
        super().__init__(probs_fn, sample_fn)
        self._out_dim = out_dim

        self._input = pg_nn.MaskedConv2d(
            is_causal=True,
            in_channels=in_channels,
            out_channels=n_channels,
            kernel_size=3,
            padding=1,
        )

        snail_blocks = [
            PixelSNAILBlock(
                n_channels=n_channels,
                n_residual_blocks=n_residual_blocks,
                attention_key_channels=attention_key_channels,
                attention_value_channels=attention_value_channels,
            )
            for _ in range(n_pixel_snail_blocks)
        ]
        self._pixel_snail_blocks = nn.ModuleList(snail_blocks)

        # Two stacked 1x1 convolutions: n_channels -> head_channels -> output.
        to_head = nn.Conv2d(
            in_channels=n_channels,
            out_channels=head_channels,
            kernel_size=1,
        )
        to_output = nn.Conv2d(
            in_channels=head_channels,
            out_channels=self._out_dim * in_channels,
            kernel_size=1,
        )
        self._output = nn.Sequential(to_head, to_output)
Exemple #5
0
    def __init__(self, in_channels=1, out_channels=1, sample_fn=None):
        """Creates the single masked convolution that makes up TinyCNN.

        Args:
          in_channels: Channel count of the input.
          out_channels: Channel count of the output.
          sample_fn: Forwarded unchanged to the base-class initializer.
        """
        super().__init__(sample_fn)
        conv_settings = {
            "is_causal": True,
            "in_channels": in_channels,
            "out_channels": out_channels,
            "kernel_size": 3,
            "padding": 1,
        }
        self._conv = pg_nn.MaskedConv2d(**conv_settings)
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        n_channels=64,
        n_pixel_snail_blocks=8,
        n_residual_blocks=2,
        attention_key_channels=4,
        attention_value_channels=32,
        sample_fn=None,
    ):
        """Creates the layers of a PixelSNAIL model.

        Args:
            in_channels: Channel count of the input; also handed to every
                PixelSNAILBlock as `input_img_channels`.
            out_channels: Channel count emitted by the last 1x1 convolution.
            n_channels: Channel width used by the input convolution and by
                every PixelSNAILBlock.
            n_pixel_snail_blocks: How many PixelSNAILBlocks to stack.
            n_residual_blocks: Residual block count handed to each
                PixelSNAILBlock.
            attention_key_channels: Attention key width handed to each
                PixelSNAILBlock.
            attention_value_channels: Attention value width handed to each
                PixelSNAILBlock.
            sample_fn: Forwarded unchanged to the base-class initializer.
        """
        super().__init__(sample_fn)

        self._input = pg_nn.MaskedConv2d(
            is_causal=True,
            in_channels=in_channels,
            out_channels=n_channels,
            kernel_size=3,
            padding=1,
        )

        snail_blocks = [
            PixelSNAILBlock(
                n_channels=n_channels,
                input_img_channels=in_channels,
                n_residual_blocks=n_residual_blocks,
                attention_key_channels=attention_key_channels,
                attention_value_channels=attention_value_channels,
            )
            for _ in range(n_pixel_snail_blocks)
        ]
        self._pixel_snail_blocks = nn.ModuleList(snail_blocks)

        # Two stacked 1x1 convolutions: n_channels -> n_channels // 2 -> output.
        squeeze = nn.Conv2d(
            in_channels=n_channels,
            out_channels=n_channels // 2,
            kernel_size=1,
        )
        project = nn.Conv2d(
            in_channels=n_channels // 2,
            out_channels=out_channels,
            kernel_size=1,
        )
        self._output = nn.Sequential(squeeze, project)

        # Log-softmax taken over dim 1.
        self._logsoftmax = nn.LogSoftmax(dim=1)
Exemple #7
0
    def __init__(self,
                 in_channels=1,
                 out_dim=1,
                 probs_fn=torch.sigmoid,
                 sample_fn=lambda x: distributions.Bernoulli(probs=x).sample(),
                 n_residual=15,
                 residual_channels=128,
                 head_channels=32):
        """Creates the layers of a PixelCNN model.

        Args:
          in_channels: Channel count of the input image.
          out_dim: Output dimension per input channel; the last convolution
            emits out_dim * in_channels channels.
          probs_fn: Forwarded unchanged to the base-class initializer.
          sample_fn: Forwarded unchanged to the base-class initializer.
          n_residual: How many MaskedResidualBlocks to stack.
          residual_channels: Base width of the residual stack; the input
            convolution and every residual block run at twice this value.
          head_channels: Channel count between the two 1x1 convolutions that
            form the head.
        """
        super().__init__(probs_fn, sample_fn)
        self._out_dim = out_dim

        self._input = pg_nn.MaskedConv2d(
            is_causal=True,
            in_channels=in_channels,
            out_channels=2 * residual_channels,
            kernel_size=7,
            padding=3,
        )

        residual_stack = [
            MaskedResidualBlock(n_channels=2 * residual_channels)
            for _ in range(n_residual)
        ]
        self._masked_layers = nn.ModuleList(residual_stack)

        # Head: ReLU -> 1x1 conv -> ReLU -> 1x1 conv.
        head_layers = [
            nn.ReLU(),
            nn.Conv2d(
                in_channels=2 * residual_channels,
                out_channels=head_channels,
                kernel_size=1,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=head_channels,
                out_channels=self._out_dim * in_channels,
                kernel_size=1,
            ),
        ]
        self._head = nn.Sequential(*head_layers)
Exemple #8
0
    def __init__(
        self,
        in_channels=1,
        out_dim=1,
        probs_fn=torch.sigmoid,
        sample_fn=lambda x: distributions.Bernoulli(probs=x).sample()):
        """Creates the single masked convolution that makes up TinyCNN.

        Args:
          in_channels: Channel count of the input.
          out_dim: Output dimension per input channel; the convolution emits
            out_dim * in_channels channels.
          probs_fn: Forwarded unchanged to the base-class initializer.
          sample_fn: Forwarded unchanged to the base-class initializer.
        """
        super().__init__(probs_fn, sample_fn)
        self._out_dim = out_dim
        conv_settings = {
            "is_causal": True,
            "in_channels": in_channels,
            "out_channels": out_dim * in_channels,
            "kernel_size": 3,
            "padding": 1,
        }
        self._conv = pg_nn.MaskedConv2d(**conv_settings)
Exemple #9
0
 def __init__(self,
              in_channels,
              in_size,
              out_dim=1,
              probs_fn=torch.sigmoid,
              sample_fn=lambda x: distributions.Bernoulli(probs=x).sample(),
              n_transformer_blocks=8,
              n_attention_heads=4,
              n_embedding_channels=16):
   """Creates the layers of an ImageGPT model.

   Args:
     in_channels: Channel count of the input.
     in_size: Height and width used for the learned `_pos` parameter, which
       has shape (1, in_channels, in_size, in_size).
     out_dim: Output dimension per input channel; the final convolution emits
       out_dim * in_channels channels.
     probs_fn: Forwarded unchanged to the base-class initializer.
     sample_fn: Forwarded unchanged to the base-class initializer.
     n_transformer_blocks: How many TransformerBlocks to stack.
     n_attention_heads: Attention head count handed to each TransformerBlock.
     n_embedding_channels: Channel width of the embedding produced by the
       input convolution and used by every TransformerBlock.
   """
   super().__init__(probs_fn, sample_fn)
   self._out_dim = out_dim

   # Zero-initialized learnable tensor, one value per input position/channel.
   initial_pos = torch.zeros(1, in_channels, in_size, in_size)
   self._pos = nn.Parameter(initial_pos)

   self._input = pg_nn.MaskedConv2d(is_causal=True,
                                    in_channels=in_channels,
                                    out_channels=n_embedding_channels,
                                    kernel_size=3,
                                    padding=1)

   blocks = [
       TransformerBlock(n_channels=n_embedding_channels,
                        n_attention_heads=n_attention_heads)
       for _ in range(n_transformer_blocks)
   ]
   self._transformer = nn.ModuleList(blocks)

   self._out = nn.Conv2d(in_channels=n_embedding_channels,
                         out_channels=self._out_dim * in_channels,
                         kernel_size=1)
Exemple #10
0
    def __init__(self, n_channels):
        """Builds the 1x1 -> masked 3x3 -> 1x1 bottleneck stack.

        Args:
          n_channels: Channel count at the block's input and output. The inner
            layers run at n_channels // 2.
        """
        super().__init__()
        narrow = n_channels // 2
        stages = [
            # NOTE(eugenhotaj): The PixelCNN paper uses ReLU->Conv2d since they
            # do not use a ReLU in the first layer. Here each convolution is
            # followed by its ReLU instead.
            nn.Conv2d(in_channels=n_channels, out_channels=narrow, kernel_size=1),
            nn.ReLU(),
            pg_nn.MaskedConv2d(
                is_causal=False,
                in_channels=narrow,
                out_channels=narrow,
                kernel_size=3,
                padding=1,
            ),
            nn.ReLU(),
            nn.Conv2d(in_channels=narrow, out_channels=n_channels, kernel_size=1),
            nn.ReLU(),
        ]
        self._net = nn.Sequential(*stages)