Example #1
0
    def __init__(self, n_channels, n_attention_heads):
        """Initializes a new TransformerBlock instance.

        Args:
            n_channels: The number of input and output channels.
            n_attention_heads: The number of attention heads to use.
        """
        super().__init__()
        # Layer norms applied (in the forward pass) before attention and
        # before the feed-forward sub-block, over NCHW activations.
        self._ln1 = pg_nn.NCHWLayerNorm(n_channels)
        self._ln2 = pg_nn.NCHWLayerNorm(n_channels)
        # Causal (masked) multi-head self-attention; the channel count is
        # preserved end to end.
        self._attn = pg_nn.CausalAttention(
            in_channels=n_channels,
            embed_channels=n_channels,
            out_channels=n_channels,
            n_heads=n_attention_heads,
        )
        # Position-wise feed-forward network: expand 4x via a 1x1 conv,
        # apply GELU, then project back down to n_channels.
        hidden_channels = 4 * n_channels
        self._out = nn.Sequential(
            nn.Conv2d(n_channels, hidden_channels, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, n_channels, kernel_size=1),
        )
Example #2
0
 def __init__(self,
              in_channels=1,
              out_channels=1,
              in_size=28,
              n_transformer_blocks=8,
              n_attention_heads=4,
              n_embedding_channels=16,
              sample_fn=None):
      """Initializes a new ImageGPT instance.

      Args:
          in_channels: The number of input channels.
          out_channels: The number of output channels.
          in_size: Size of the input images. Used to create positional
              encodings.
          n_transformer_blocks: Number of TransformerBlocks to use.
          n_attention_heads: Number of attention heads to use.
          n_embedding_channels: Number of attention embedding channels to use.
          sample_fn: See the base class.
      """
      super().__init__(sample_fn)
      # Learned positional encoding, broadcast over the batch dimension.
      self._pos = nn.Parameter(torch.zeros(1, in_channels, in_size, in_size))
      # Causally masked 3x3 conv lifts the input to the embedding width.
      self._input = pg_nn.MaskedConv2d(is_causal=True,
                                       in_channels=in_channels,
                                       out_channels=n_embedding_channels,
                                       kernel_size=3,
                                       padding=1)
      # Stack of identical transformer blocks operating at the embedding
      # width.
      blocks = [
          TransformerBlock(n_channels=n_embedding_channels,
                           n_attention_heads=n_attention_heads)
          for _ in range(n_transformer_blocks)
      ]
      self._transformer = nn.ModuleList(blocks)
      self._ln = pg_nn.NCHWLayerNorm(n_embedding_channels)
      # Final 1x1 conv projects embeddings to the output channel count.
      self._out = nn.Conv2d(n_embedding_channels, out_channels, kernel_size=1)