def __init__(
    self,
    in_channels=1,
    out_channels=1,
    n_residual=15,
    residual_channels=128,
    head_channels=32,
    sample_fn=None,
):
    """Builds the PixelCNN layers: input conv, residual stack and ARMA2d head.

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        n_residual: The number of residual blocks.
        residual_channels: Sets the trunk width; the input convolution and
            every residual block are built with 2 * residual_channels
            channels.
        head_channels: The number of channels between the two ARMA2d layers
            at the head of the network.
        sample_fn: See the base class.
    """
    super().__init__(sample_fn)
    trunk_channels = 2 * residual_channels

    # Causal 7x7 masked convolution that lifts the input to the trunk width.
    self._input = pg_nn.MaskedConv2d(
        is_causal=True,
        in_channels=in_channels,
        out_channels=trunk_channels,
        kernel_size=7,
        padding=3,
    )

    residual_blocks = [
        MaskedResidualBlock(n_channels=trunk_channels)
        for _ in range(n_residual)
    ]
    self._masked_layers = nn.ModuleList(residual_blocks)

    # ARMA2d layers take the place of the pair of 1x1 nn.Conv2d layers that
    # previously formed the head.
    # NOTE(review): w_ksz2, a_ksz2 and init are not defined in this method;
    # they must be supplied by an enclosing scope -- confirm.
    self._head = nn.Sequential(
        nn.ReLU(),
        ARMA2d(
            trunk_channels,
            head_channels,
            w_stride=1,
            w_kernel_size=w_ksz2,
            w_padding=w_ksz2 // 2,
            a_init=init,
            a_kernel_size=a_ksz2,
            a_padding=a_ksz2 // 2,
        ),
        nn.ReLU(),
        ARMA2d(
            head_channels,
            out_channels,
            w_stride=1,
            w_kernel_size=w_ksz2,
            w_padding=w_ksz2 // 2,
            a_init=init,
            a_kernel_size=a_ksz2,
            a_padding=a_ksz2 // 2,
        ),
    )
def __init__(self, n_channels):
    """Builds the block's layer stack.

    The stack is ReLU -> 1x1 conv down to n_channels // 2 -> ReLU -> 3x3
    MaskedConv2d (is_causal=False) -> ReLU -> 1x1 conv back to n_channels.

    Args:
        n_channels: The number of input (and output) channels.
    """
    super().__init__()
    bottleneck = n_channels // 2
    layers = [
        nn.ReLU(),
        nn.Conv2d(
            in_channels=n_channels,
            out_channels=bottleneck,
            kernel_size=1,
        ),
        nn.ReLU(),
        pg_nn.MaskedConv2d(
            is_causal=False,
            in_channels=bottleneck,
            out_channels=bottleneck,
            kernel_size=3,
            padding=1,
        ),
        nn.ReLU(),
        nn.Conv2d(
            in_channels=bottleneck,
            out_channels=n_channels,
            kernel_size=1,
        ),
    ]
    self._net = nn.Sequential(*layers)
def __init__(
    self,
    in_channels=1,
    out_channels=1,
    in_size=28,
    n_transformer_blocks=8,
    n_attention_heads=4,
    n_embedding_channels=16,
    sample_fn=None,
):
    """Builds the ImageGPT layers.

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        in_size: Size of the input images. Used to create positional
            encodings of shape (1, in_channels, in_size, in_size).
        n_transformer_blocks: Number of TransformerBlocks to use.
        n_attention_heads: Number of attention heads to use.
        n_embedding_channels: Number of attention embedding channels to use.
        sample_fn: See the base class.
    """
    super().__init__(sample_fn)

    # Learnable positional encoding, initialised to zeros.
    pos_shape = (1, in_channels, in_size, in_size)
    self._pos = nn.Parameter(torch.zeros(pos_shape))

    self._input = pg_nn.MaskedConv2d(
        is_causal=True,
        in_channels=in_channels,
        out_channels=n_embedding_channels,
        kernel_size=3,
        padding=1,
    )

    blocks = [
        TransformerBlock(
            n_channels=n_embedding_channels,
            n_attention_heads=n_attention_heads,
        )
        for _ in range(n_transformer_blocks)
    ]
    self._transformer = nn.ModuleList(blocks)

    self._ln = pg_nn.NCHWLayerNorm(n_embedding_channels)
    self._out = nn.Conv2d(
        in_channels=n_embedding_channels,
        out_channels=out_channels,
        kernel_size=1,
    )
def __init__(
    self,
    in_channels=3,
    out_dim=1,
    probs_fn=torch.sigmoid,
    sample_fn=lambda x: distributions.Bernoulli(probs=x).sample(),
    n_channels=64,
    n_pixel_snail_blocks=8,
    n_residual_blocks=2,
    attention_key_channels=4,
    attention_value_channels=32,
    head_channels=1,
):
    """Builds the PixelSNAIL layers.

    Args:
        in_channels: The number of channels in the input image (typically
            either 1 or 3 for black and white or color images respectively).
        out_dim: The dimension of the output per input channel; the final
            convolution emits out_dim * in_channels channels.
        probs_fn: See the base class.
        sample_fn: See the base class.
        n_channels: The number of channels to use for convolutions.
        n_pixel_snail_blocks: The number of PixelSNAILBlocks.
        n_residual_blocks: The number of ResidualBlock to use in each
            PixelSnailBlock.
        attention_key_channels: Number of channels (dimension) for the
            attention key.
        attention_value_channels: Number of channels (dimension) for the
            attention value.
        head_channels: Number of channels between the two 1x1 convolutions
            of the output head.
    """
    super().__init__(probs_fn, sample_fn)
    self._out_dim = out_dim
    total_out_channels = out_dim * in_channels

    self._input = pg_nn.MaskedConv2d(
        is_causal=True,
        in_channels=in_channels,
        out_channels=n_channels,
        kernel_size=3,
        padding=1,
    )

    snail_blocks = [
        PixelSNAILBlock(
            n_channels=n_channels,
            n_residual_blocks=n_residual_blocks,
            attention_key_channels=attention_key_channels,
            attention_value_channels=attention_value_channels,
        )
        for _ in range(n_pixel_snail_blocks)
    ]
    self._pixel_snail_blocks = nn.ModuleList(snail_blocks)

    # Two stacked 1x1 convolutions with no activation in between.
    self._output = nn.Sequential(
        nn.Conv2d(
            in_channels=n_channels,
            out_channels=head_channels,
            kernel_size=1,
        ),
        nn.Conv2d(
            in_channels=head_channels,
            out_channels=total_out_channels,
            kernel_size=1,
        ),
    )
def __init__(self, in_channels=1, out_channels=1, sample_fn=None):
    """Builds the TinyCNN, which consists of a single causal masked conv.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        sample_fn: See the base class.
    """
    super().__init__(sample_fn)
    conv_kwargs = {
        "is_causal": True,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "kernel_size": 3,
        "padding": 1,
    }
    self._conv = pg_nn.MaskedConv2d(**conv_kwargs)
def __init__(
    self,
    in_channels=1,
    out_channels=1,
    n_channels=64,
    n_pixel_snail_blocks=8,
    n_residual_blocks=2,
    attention_key_channels=4,
    attention_value_channels=32,
    sample_fn=None,
):
    """Builds the PixelSNAIL layers.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output_channels.
        n_channels: Number of channels to use for convolutions.
        n_pixel_snail_blocks: Number of PixelSNAILBlocks.
        n_residual_blocks: Number of ResidualBlock to use in each
            PixelSnailBlock.
        attention_key_channels: Number of channels (dims) for the attention
            key.
        attention_value_channels: Number of channels (dims) for the attention
            value.
        sample_fn: See the base class.
    """
    super().__init__(sample_fn)
    hidden_channels = n_channels // 2

    self._input = pg_nn.MaskedConv2d(
        is_causal=True,
        in_channels=in_channels,
        out_channels=n_channels,
        kernel_size=3,
        padding=1,
    )

    snail_blocks = [
        PixelSNAILBlock(
            n_channels=n_channels,
            input_img_channels=in_channels,
            n_residual_blocks=n_residual_blocks,
            attention_key_channels=attention_key_channels,
            attention_value_channels=attention_value_channels,
        )
        for _ in range(n_pixel_snail_blocks)
    ]
    self._pixel_snail_blocks = nn.ModuleList(snail_blocks)

    # Output head: two stacked 1x1 convolutions, halving the width first.
    reduce_conv = nn.Conv2d(
        in_channels=n_channels,
        out_channels=hidden_channels,
        kernel_size=1,
    )
    project_conv = nn.Conv2d(
        in_channels=hidden_channels,
        out_channels=out_channels,
        kernel_size=1,
    )
    self._output = nn.Sequential(reduce_conv, project_conv)

    # Log-softmax taken over dim 1.
    self._logsoftmax = nn.LogSoftmax(dim=1)
def __init__(
    self,
    in_channels=1,
    out_dim=1,
    probs_fn=torch.sigmoid,
    sample_fn=lambda x: distributions.Bernoulli(probs=x).sample(),
    n_residual=15,
    residual_channels=128,
    head_channels=32,
):
    """Builds the PixelCNN layers: input conv, residual stack and 1x1 head.

    Args:
        in_channels: The number of channels in the input image (typically
            either 1 or 3 for black and white or color images respectively).
        out_dim: The dimension of the output per input channel; the final
            convolution emits out_dim * in_channels channels.
        probs_fn: See the base class.
        sample_fn: See the base class.
        n_residual: The number of residual blocks.
        residual_channels: Sets the trunk width; the input convolution and
            every residual block are built with 2 * residual_channels
            channels.
        head_channels: The number of channels to use in the two 1x1
            convolutional layers at the head of the network.
    """
    super().__init__(probs_fn, sample_fn)
    self._out_dim = out_dim
    trunk_channels = 2 * residual_channels
    total_out_channels = out_dim * in_channels

    self._input = pg_nn.MaskedConv2d(
        is_causal=True,
        in_channels=in_channels,
        out_channels=trunk_channels,
        kernel_size=7,
        padding=3,
    )

    residual_blocks = [
        MaskedResidualBlock(n_channels=trunk_channels)
        for _ in range(n_residual)
    ]
    self._masked_layers = nn.ModuleList(residual_blocks)

    head_layers = [
        nn.ReLU(),
        nn.Conv2d(
            in_channels=trunk_channels,
            out_channels=head_channels,
            kernel_size=1,
        ),
        nn.ReLU(),
        nn.Conv2d(
            in_channels=head_channels,
            out_channels=total_out_channels,
            kernel_size=1,
        ),
    ]
    self._head = nn.Sequential(*head_layers)
def __init__(
    self,
    in_channels=1,
    out_dim=1,
    probs_fn=torch.sigmoid,
    sample_fn=lambda x: distributions.Bernoulli(probs=x).sample(),
):
    """Builds the TinyCNN, which consists of a single causal masked conv.

    Args:
        in_channels: Number of input channels.
        out_dim: Dimension of the output per channel; the convolution emits
            out_dim * in_channels channels.
        probs_fn: See the base class.
        sample_fn: See the base class.
    """
    super().__init__(probs_fn, sample_fn)
    self._out_dim = out_dim
    total_out_channels = out_dim * in_channels
    self._conv = pg_nn.MaskedConv2d(
        is_causal=True,
        in_channels=in_channels,
        out_channels=total_out_channels,
        kernel_size=3,
        padding=1,
    )
def __init__(
    self,
    in_channels,
    in_size,
    out_dim=1,
    probs_fn=torch.sigmoid,
    sample_fn=lambda x: distributions.Bernoulli(probs=x).sample(),
    n_transformer_blocks=8,
    n_attention_heads=4,
    n_embedding_channels=16,
):
    """Builds the ImageGPT layers.

    Args:
        in_channels: The number of input channels.
        in_size: Size of the input images. Used to create positional
            encodings of shape (1, in_channels, in_size, in_size).
        out_dim: The dimension of the output per input channel; the final
            convolution emits out_dim * in_channels channels.
        probs_fn: See the base class.
        sample_fn: See the base class.
        n_transformer_blocks: Number of TransformerBlocks to use.
        n_attention_heads: Number of attention heads to use.
        n_embedding_channels: Number of attention embedding channels to use.
    """
    super().__init__(probs_fn, sample_fn)
    self._out_dim = out_dim
    total_out_channels = out_dim * in_channels

    # Learnable positional encoding, initialised to zeros.
    pos_shape = (1, in_channels, in_size, in_size)
    self._pos = nn.Parameter(torch.zeros(pos_shape))

    self._input = pg_nn.MaskedConv2d(
        is_causal=True,
        in_channels=in_channels,
        out_channels=n_embedding_channels,
        kernel_size=3,
        padding=1,
    )

    blocks = [
        TransformerBlock(
            n_channels=n_embedding_channels,
            n_attention_heads=n_attention_heads,
        )
        for _ in range(n_transformer_blocks)
    ]
    self._transformer = nn.ModuleList(blocks)

    self._out = nn.Conv2d(
        in_channels=n_embedding_channels,
        out_channels=total_out_channels,
        kernel_size=1,
    )
def __init__(self, n_channels):
    """Builds the block's layer stack.

    The stack is 1x1 conv down to n_channels // 2 -> ReLU -> 3x3
    MaskedConv2d (is_causal=False) -> ReLU -> 1x1 conv back to n_channels
    -> ReLU.

    Args:
        n_channels: The number of input (and output) channels.
    """
    super().__init__()
    bottleneck = n_channels // 2
    # NOTE(eugenhotaj): The PixelCNN paper uses ReLU->Conv2d since they do
    # not use a ReLU in the first layer.
    layers = [
        nn.Conv2d(
            in_channels=n_channels,
            out_channels=bottleneck,
            kernel_size=1,
        ),
        nn.ReLU(),
        pg_nn.MaskedConv2d(
            is_causal=False,
            in_channels=bottleneck,
            out_channels=bottleneck,
            kernel_size=3,
            padding=1,
        ),
        nn.ReLU(),
        nn.Conv2d(
            in_channels=bottleneck,
            out_channels=n_channels,
            kernel_size=1,
        ),
        nn.ReLU(),
    ]
    self._net = nn.Sequential(*layers)