import torch
from torch import nn

from pytorch_generative import nn as pg_nn  # assumed source of the pg_nn alias


class TransformerBlock(nn.Module):
    """A Transformer block: masked self-attention followed by a 1x1-conv MLP."""

    def __init__(self, n_channels, n_attention_heads):
        """Initializes a new TransformerBlock instance.

        Args:
            n_channels: The number of input and output channels.
            n_attention_heads: The number of attention heads to use.
        """
        super().__init__()
        self._attn = pg_nn.MaskedAttention(
            in_channels=n_channels,
            embed_channels=n_channels,
            out_channels=n_channels,
            n_heads=n_attention_heads,
            is_causal=False,
        )
        # Position-wise feed-forward network implemented with 1x1 convolutions
        # and the standard 4x hidden expansion.
        self._out = nn.Sequential(
            nn.Conv2d(
                in_channels=n_channels,
                out_channels=4 * n_channels,
                kernel_size=1,
            ),
            nn.GELU(),
            nn.Conv2d(
                in_channels=4 * n_channels,
                out_channels=n_channels,
                kernel_size=1,
            ),
        )
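    # A minimal sketch of a matching forward pass (hypothetical, not
    # necessarily this repo's verbatim method): the attention and
    # feed-forward sub-layers are each wrapped in a residual connection,
    # the standard Transformer layout. Layer normalization is omitted
    # since __init__ above does not create norm layers.
    def forward(self, x):
        x = x + self._attn(x)
        return x + self._out(x)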
class PixelSnailBlock(nn.Module):
    """A PixelSNAIL block: a stack of residual blocks plus a causal attention block."""

    def __init__(
        self,
        n_channels,
        input_img_channels=1,
        n_residual_blocks=2,
        attention_key_channels=4,
        attention_value_channels=32,
    ):
        """Initializes a new PixelSnailBlock instance.

        Args:
            n_channels: Number of input and output channels.
            input_img_channels: The number of channels in the original input_img.
                Used for the positional encoding channels and the extra channels
                for the key and value convolutions in the attention block.
            n_residual_blocks: Number of residual blocks.
            attention_key_channels: Number of channels (dims) for the attention key.
            attention_value_channels: Number of channels (dims) for the attention
                value.
        """
        super().__init__()

        def conv(in_channels):
            return nn.Conv2d(in_channels, out_channels=n_channels, kernel_size=1)

        self._residual = nn.Sequential(
            *[ResidualBlock(n_channels) for _ in range(n_residual_blocks)]
        )
        self._attention = pg_nn.MaskedAttention(
            in_channels=n_channels + 2 * input_img_channels,
            embed_channels=attention_key_channels,
            out_channels=attention_value_channels,
            is_causal=True,
            extra_input_channels=input_img_channels,
        )
        self._residual_out = conv(n_channels)
        self._attention_out = conv(attention_value_channels)
        self._out = conv(n_channels)
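    # A minimal sketch of a matching forward pass (hypothetical; the actual
    # method and the MaskedAttention call signature may differ).
    # `positional_encoding` is an assumed helper that emits
    # 2 * input_img_channels channels, which is why the attention block above
    # takes n_channels + 2 * input_img_channels input channels; input_img is
    # passed as the attention block's extra input. Activations between the
    # 1x1 projections are omitted for brevity.
    def forward(self, x, input_img):
        res = self._residual(x)
        pos = positional_encoding(input_img.shape).to(res.device)  # assumed helper
        attn = self._attention(torch.cat((pos, res), dim=1), input_img)
        return self._out(self._residual_out(res) + self._attention_out(attn))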