def __init__(self, n_channels):
    """Initializes a new CausalResidualBlock instance.

    Args:
        n_channels: The number of input (and output) channels.
    """
    super().__init__()
    self._net = nn.Sequential(
        nn.ReLU(),
        nn.Conv2d(in_channels=n_channels, out_channels=n_channels // 2, kernel_size=1),
        nn.ReLU(),
        pg_nn.CausalConv2d(
            mask_center=False,
            in_channels=n_channels // 2,
            out_channels=n_channels // 2,
            kernel_size=3,
            padding=1,
        ),
        nn.ReLU(),
        nn.Conv2d(in_channels=n_channels // 2, out_channels=n_channels, kernel_size=1),
    )
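# Hedged usage sketch (not part of the original source): the 1x1 -> masked 3x3 -> 1x1
# bottleneck maps n_channels back to n_channels, so the block can be stacked residually.
# The import path and the conventional residual forward (x + self._net(x)) are assumptions.
import torch
from pytorch_generative.models.pixel_cnn import CausalResidualBlock  # assumed import path

block = CausalResidualBlock(n_channels=128)
x = torch.rand(4, 128, 28, 28)
out = block(x)                 # assumed forward: x + self._net(x)
assert out.shape == x.shape    # channel count and spatial size are preserved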
def __init__(
    self,
    in_channels=1,
    out_channels=1,
    n_channels=64,
    n_pixel_snail_blocks=8,
    n_residual_blocks=2,
    attention_key_channels=4,
    attention_value_channels=32,
    sample_fn=None,
):
    """Initializes a new PixelSNAIL instance.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        n_channels: Number of channels to use for convolutions.
        n_pixel_snail_blocks: Number of PixelSNAILBlocks.
        n_residual_blocks: Number of ResidualBlocks to use in each PixelSNAILBlock.
        attention_key_channels: Number of channels (dims) for the attention key.
        attention_value_channels: Number of channels (dims) for the attention value.
        sample_fn: See the base class.
    """
    super().__init__(sample_fn)
    self._input = pg_nn.CausalConv2d(
        mask_center=True,
        in_channels=in_channels,
        out_channels=n_channels,
        kernel_size=3,
        padding=1,
    )
    self._pixel_snail_blocks = nn.ModuleList(
        [
            PixelSNAILBlock(
                n_channels=n_channels,
                input_img_channels=in_channels,
                n_residual_blocks=n_residual_blocks,
                attention_key_channels=attention_key_channels,
                attention_value_channels=attention_value_channels,
            )
            for _ in range(n_pixel_snail_blocks)
        ]
    )
    self._output = nn.Sequential(
        nn.Conv2d(in_channels=n_channels, out_channels=n_channels // 2, kernel_size=1),
        nn.Conv2d(in_channels=n_channels // 2, out_channels=out_channels, kernel_size=1),
    )
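# Hedged usage sketch (not part of the original source): instantiates the model with
# its defaults and runs a dummy forward pass. The import path and the standard
# nn.Module forward(x) -> per-pixel logits convention are assumptions.
import torch
from pytorch_generative.models import PixelSNAIL  # assumed import path

model = PixelSNAIL(in_channels=1, out_channels=1)
x = torch.rand(2, 1, 28, 28)   # dummy batch of 28x28 single-channel images
logits = model(x)              # assumed to return one logit per input pixel
assert logits.shape == (2, 1, 28, 28)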
def __init__(self, in_channels=1, out_channels=1, sample_fn=None):
    """Initializes a new TinyCNN instance.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        sample_fn: See the base class.
    """
    super().__init__(sample_fn)
    self._conv = pg_nn.CausalConv2d(
        mask_center=True,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        padding=1,
    )
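# Hedged usage sketch (not part of the original source): TinyCNN is a single masked
# 3x3 convolution, so each output pixel depends only on already-generated pixels.
# The import path and the standard forward(x) convention are assumptions.
import torch
from pytorch_generative.models import TinyCNN  # assumed import path

model = TinyCNN(in_channels=1, out_channels=1)
x = torch.rand(2, 1, 28, 28)
logits = model(x)              # kernel_size=3, padding=1 preserve the spatial size
assert logits.shape == x.shape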
def __init__(
    self,
    in_channels=1,
    out_channels=1,
    n_residual=15,
    residual_channels=128,
    head_channels=32,
    sample_fn=None,
):
    """Initializes a new PixelCNN instance.

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        n_residual: The number of residual blocks.
        residual_channels: The number of channels to use in the residual layers.
        head_channels: The number of channels to use in the two 1x1 convolutional
            layers at the head of the network.
        sample_fn: See the base class.
    """
    super().__init__(sample_fn)
    self._input = pg_nn.CausalConv2d(
        mask_center=True,
        in_channels=in_channels,
        out_channels=2 * residual_channels,
        kernel_size=7,
        padding=3,
    )
    self._causal_layers = nn.ModuleList(
        [
            CausalResidualBlock(n_channels=2 * residual_channels)
            for _ in range(n_residual)
        ]
    )
    self._head = nn.Sequential(
        nn.ReLU(),
        nn.Conv2d(
            in_channels=2 * residual_channels,
            out_channels=head_channels,
            kernel_size=1,
        ),
        nn.ReLU(),
        nn.Conv2d(in_channels=head_channels, out_channels=out_channels, kernel_size=1),
    )
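# Hedged usage sketch (not part of the original source): wires a smaller-than-default
# PixelCNN and checks the output shape. With the defaults the network stacks a 7x7
# masked input convolution, 15 residual blocks at 2 * 128 channels, and a two-layer
# 1x1 head. The import path and forward(x) convention are assumptions.
import torch
from pytorch_generative.models import PixelCNN  # assumed import path

model = PixelCNN(n_residual=3, residual_channels=16, head_channels=8)
x = torch.rand(2, 1, 28, 28)
logits = model(x)
assert logits.shape == (2, 1, 28, 28)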
def __init__(
    self,
    in_channels=1,
    out_channels=1,
    in_size=28,
    n_transformer_blocks=8,
    n_attention_heads=4,
    n_embedding_channels=16,
    sample_fn=None,
):
    """Initializes a new ImageGPT instance.

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        in_size: Size of the input images. Used to create positional encodings.
        n_transformer_blocks: Number of TransformerBlocks to use.
        n_attention_heads: Number of attention heads to use.
        n_embedding_channels: Number of attention embedding channels to use.
        sample_fn: See the base class.
    """
    super().__init__(sample_fn)
    self._pos = nn.Parameter(torch.zeros(1, in_channels, in_size, in_size))
    self._input = pg_nn.CausalConv2d(
        mask_center=True,
        in_channels=in_channels,
        out_channels=n_embedding_channels,
        kernel_size=3,
        padding=1,
    )
    self._transformer = nn.ModuleList(
        [
            TransformerBlock(
                n_channels=n_embedding_channels,
                n_attention_heads=n_attention_heads,
            )
            for _ in range(n_transformer_blocks)
        ]
    )
    self._ln = pg_nn.NCHWLayerNorm(n_embedding_channels)
    self._out = nn.Conv2d(
        in_channels=n_embedding_channels, out_channels=out_channels, kernel_size=1
    )
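# Hedged usage sketch (not part of the original source): because the positional
# encoding self._pos is allocated as (1, in_channels, in_size, in_size), inputs must
# match in_size (28 by default). The import path and forward(x) convention are
# assumptions.
import torch
from pytorch_generative.models import ImageGPT  # assumed import path

model = ImageGPT(in_size=28, n_transformer_blocks=2)  # fewer blocks for a quick check
x = torch.rand(2, 1, 28, 28)   # spatial size must equal in_size
logits = model(x)
assert logits.shape == (2, 1, 28, 28)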