Beispiel #1
0
  def __init__(self, n_channels, n_attention_heads):
    """Builds the attention and pointwise sub-modules of a TransformerBlock.

    Args:
      n_channels: Channel count used for both the input and the output of
        every sub-module created here.
      n_attention_heads: Head count forwarded to the attention layer.
    """
    super().__init__()

    # The attention layer keeps the channel width unchanged and is created
    # with is_causal=False.
    self._attn = pg_nn.MaskedAttention(
        in_channels=n_channels,
        embed_channels=n_channels,
        out_channels=n_channels,
        n_heads=n_attention_heads,
        is_causal=False,
    )

    # Pointwise (kernel_size=1) stack: widen to 4x the channels, apply GELU,
    # then narrow back to the original width.
    hidden_channels = 4 * n_channels
    expand = nn.Conv2d(
        in_channels=n_channels,
        out_channels=hidden_channels,
        kernel_size=1,
    )
    activation = nn.GELU()
    project = nn.Conv2d(
        in_channels=hidden_channels,
        out_channels=n_channels,
        kernel_size=1,
    )
    self._out = nn.Sequential(expand, activation, project)
    def __init__(
        self,
        n_channels,
        input_img_channels=1,
        n_residual_blocks=2,
        attention_key_channels=4,
        attention_value_channels=32,
    ):
        """Builds the residual, attention, and projection layers of a PixelSnailBlock.

        Args:
            n_channels: Channel count of the block's input and of every 1x1
                projection created here.
            input_img_channels: Channel count of the original input image. It
                widens the attention input by 2 * input_img_channels and is
                forwarded to the attention layer as extra_input_channels.
            n_residual_blocks: How many ResidualBlock modules to chain.
            attention_key_channels: Passed to the attention layer as its
                embed_channels (key dimensionality).
            attention_value_channels: Passed to the attention layer as its
                out_channels (value dimensionality).
        """
        super().__init__()

        # Sequential chain of identical-width residual blocks.
        residual_blocks = [
            ResidualBlock(n_channels) for _ in range(n_residual_blocks)
        ]
        self._residual = nn.Sequential(*residual_blocks)

        # The attention input is the feature map plus two image-sized groups
        # of additional channels.
        attention_in_channels = n_channels + 2 * input_img_channels
        self._attention = pg_nn.MaskedAttention(
            in_channels=attention_in_channels,
            embed_channels=attention_key_channels,
            out_channels=attention_value_channels,
            is_causal=True,
            extra_input_channels=input_img_channels,
        )

        # Three 1x1 convolutions, each projecting to n_channels.
        self._residual_out = nn.Conv2d(
            n_channels, out_channels=n_channels, kernel_size=1
        )
        self._attention_out = nn.Conv2d(
            attention_value_channels, out_channels=n_channels, kernel_size=1
        )
        self._out = nn.Conv2d(
            n_channels, out_channels=n_channels, kernel_size=1
        )