def __init__(self, input_size: int, depth: int = 32, act: ActFunc = None, shape: Tuple[int] = (3, 64, 64)):
    """Initializes a ConvDecoder instance.

    Args:
        input_size (int): Input size, usually the feature size output
            from an RSSM.
        depth (int): Base number of channels; the transposed conv stack
            uses multiples of this (32x, 4x, 2x, 1x).
        act (Any): Activation class for the decoder; defaults to
            nn.ReLU when None is given.
        shape (Tuple[int]): Shape (C, H, W) of the observation output.
    """
    super().__init__()
    # Fall back to ReLU if no activation class was provided.
    self.act = act
    if not act:
        self.act = nn.ReLU
    self.depth = depth
    self.shape = shape
    # Dense projection -> reshape to a 1x1 "image" (channels first) ->
    # four ConvTranspose2d layers that upsample to the target shape.
    # Note: no activation after the final layer (raw image output).
    self.layers = [
        Linear(input_size, 32 * self.depth),
        Reshape([-1, 32 * self.depth, 1, 1]),
        ConvTranspose2d(32 * self.depth, 4 * self.depth, 5, stride=2),
        self.act(),
        ConvTranspose2d(4 * self.depth, 2 * self.depth, 5, stride=2),
        self.act(),
        ConvTranspose2d(2 * self.depth, self.depth, 6, stride=2),
        self.act(),
        ConvTranspose2d(self.depth, self.shape[0], 6, stride=2),
    ]
    # Sequential registers the layers as submodules (the plain list
    # above alone would not).
    self.model = nn.Sequential(*self.layers)
def __init__(
    self,
    *,
    input_size: int,
    filters: Tuple[Tuple[int]] = (
        (1024, 5, 2),
        (128, 5, 2),
        (64, 6, 2),
        (32, 6, 2),
    ),
    initializer="default",
    bias_init=0,
    activation_fn: str = "relu",
    output_shape: Tuple[int] = (3, 64, 64)
):
    """Initializes a TransposedConv2DStack instance.

    Args:
        input_size: The size of the 1D input vector, from which to
            generate the image distribution.
        filters (Tuple[Tuple[int]]): Tuple of filter setups (1 for each
            ConvTranspose2D layer): [in_channels, kernel, stride].
        initializer (Union[str, Callable]): Weight initializer descriptor
            or callable, resolved via `get_initializer` and applied to
            each ConvTranspose2D layer's weights.
        bias_init (float): The initial bias values to use for every
            ConvTranspose2D layer.
        activation_fn (str): Activation function descriptor (str),
            resolved via `get_activation_fn`; applied after every layer
            except the last.
        output_shape (Tuple[int]): Shape (C, H, W, channels first) of
            the final output image.
    """
    super().__init__()
    self.activation = get_activation_fn(activation_fn, framework="torch")
    self.output_shape = output_shape
    initializer = get_initializer(initializer, framework="torch")

    in_channels = filters[0][0]
    self.layers = [
        # Map from 1D-input vector to correct initial size for the
        # Conv2DTransposed stack.
        nn.Linear(input_size, in_channels),
        # Reshape from the incoming 1D vector (input_size) to 1x1 image
        # format (channels first).
        Reshape([-1, in_channels, 1, 1]),
    ]
    for i, (_, kernel, stride) in enumerate(filters):
        # The layer's out_channels are the next layer's in_channels;
        # the last layer emits the image's channel count.
        out_channels = (
            filters[i + 1][0] if i < len(filters) - 1 else output_shape[0]
        )
        conv_transp = nn.ConvTranspose2d(in_channels, out_channels, kernel, stride)
        # Apply initializer.
        initializer(conv_transp.weight)
        nn.init.constant_(conv_transp.bias, bias_init)
        self.layers.append(conv_transp)
        # Apply activation function, if provided and if not last layer.
        if self.activation is not None and i < len(filters) - 1:
            self.layers.append(self.activation())

        # num-outputs == num-inputs for next layer.
        in_channels = out_channels

    self._model = nn.Sequential(*self.layers)