def __init__(self, in_channel, out_channel):
    """Assemble the convolutional stack stored in ``self._seq``.

    The stack consists of, in order: a 3x3 convolution from ``in_channel``
    to ``out_channel``; a 1x1 convolution down to ``out_channel // 2``
    followed by batch normalisation and ``Swish``; and a 3x3 convolution
    with stride 2 back up to ``out_channel``, again followed by batch
    normalisation and ``Swish``.

    Args:
        in_channel: Channel count of the incoming feature map.
        out_channel: Channel count of the feature map produced by the stack.
    """
    super().__init__()

    # Width of the intermediate 1x1 bottleneck.
    bottleneck = out_channel // 2

    layers = [
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
        # 1x1 convolution: only the channel count changes here.
        nn.Conv2d(out_channel, bottleneck, kernel_size=1),
        nn.BatchNorm2d(bottleneck),
        Swish(),
        # stride=2 makes this the spatially downsampling step of the stack.
        nn.Conv2d(bottleneck, out_channel, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(out_channel),
        Swish(),
    ]
    self._seq = nn.Sequential(*layers)
def __init__(self, z_dim):
    """Create every sub-module of the decoder for a given latent width.

    Registers, in this order: ``decoder_blocks``, ``decoder_residual_blocks``,
    ``condition_z``, ``condition_xz``, ``map_from_z``, ``recon`` and
    ``pos_map``. All channel counts are derived from ``z_dim`` by integer
    division.

    Args:
        z_dim: Base channel width from which all layer widths are derived.
            NOTE(review): widths go down to ``z_dim // 32``, so ``z_dim``
            presumably has to be a multiple of 32 -- confirm with callers.
    """
    super().__init__()

    # Channel widths reused throughout the constructor.
    c2, c4, c8 = z_dim // 2, z_dim // 4, z_dim // 8
    c16, c32 = z_dim // 16, z_dim // 32

    def prior_head(feat_ch, out_ch):
        # Residual block -> global average pool -> Swish -> 1x1 projection.
        return nn.Sequential(
            ResidualBlock(feat_ch),
            nn.AdaptiveAvgPool2d(1),
            Swish(),
            nn.Conv2d(feat_ch, out_ch, kernel_size=1),
        )

    def posterior_head(feat_ch, squeeze_ch):
        # Residual block -> 1x1 squeeze -> global average pool -> Swish
        # -> 1x1 projection back to the incoming width.
        return nn.Sequential(
            ResidualBlock(feat_ch),
            nn.Conv2d(feat_ch, squeeze_ch, kernel_size=1),
            nn.AdaptiveAvgPool2d(1),
            Swish(),
            nn.Conv2d(squeeze_ch, feat_ch, kernel_size=1),
        )

    # Input channels = z_channels * 2 = x_channels + z_channels
    # Output channels = z_channels
    # Upsampling factors noted per stage are the original author's
    # annotations (2x, 4x, 4x); DecoderBlock itself is defined elsewhere.
    stage_channels = (
        [z_dim * 2, c2],   # 2x upsample
        [z_dim, c4, c8],   # 4x upsample
        [c4, c16, c32],    # 4x upsample
    )
    self.decoder_blocks = nn.ModuleList(
        [DecoderBlock(channels) for channels in stage_channels]
    )

    # One residual block per decoder stage, with a growing group count.
    residual_specs = zip((c2, c8, c32), (1, 2, 4))
    self.decoder_residual_blocks = nn.ModuleList(
        [DecoderResidualBlock(width, n_group=groups) for width, groups in residual_specs]
    )

    # p(z_l | z_(l-1))
    self.condition_z = nn.ModuleList([
        prior_head(c2, z_dim),
        prior_head(c8, c4),
    ])

    # p(z_l | x, z_(l-1))
    self.condition_xz = nn.ModuleList([
        posterior_head(z_dim, c2),
        posterior_head(c4, c8),
    ])

    # Each 1x1 map takes two channels more than it emits.
    # NOTE(review): the extra two channels presumably carry the 2-d
    # positional features produced via ``pos_map`` -- confirm in forward().
    self.map_from_z = nn.ModuleList(
        [nn.Conv2d(width + 2, width, kernel_size=1) for width in (z_dim, c2, c8)]
    )

    # Final 1x1 projection to 3 output channels.
    self.recon = nn.Conv2d(c32, 3, kernel_size=1)

    self.pos_map = nn.Linear(2, 2)
def __init__(self, in_channel, out_channel):
    """Assemble the transposed-convolution stack stored in ``self._seq``.

    The stack is a 3x3 transposed convolution with stride 2 from
    ``in_channel`` to ``out_channel``, followed by batch normalisation and
    ``Swish``.

    Args:
        in_channel: Channel count of the incoming feature map.
        out_channel: Channel count of the feature map produced by the stack.
    """
    super().__init__()

    # kernel_size=3, stride=2, padding=1, output_padding=1 yields an output
    # exactly twice the input's height and width.
    # An alternative (bilinear 2x upsampling followed by a 3x3 convolution)
    # was left disabled by the original author.
    upsample = nn.ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
    )
    normalise = nn.BatchNorm2d(out_channel)
    activation = Swish()

    self._seq = nn.Sequential(upsample, normalise, activation)
def __init__(self, z_dim):
    """Create every sub-module of the encoder for a given latent width.

    Registers, in this order: ``encoder_blocks``,
    ``encoder_residual_blocks`` and ``condition_x``. All channel counts are
    derived from ``z_dim`` by integer division; the first stage takes
    3 input channels.

    Args:
        z_dim: Base channel width from which all layer widths are derived.
            NOTE(review): widths go down to ``z_dim // 16``, so ``z_dim``
            presumably has to be a multiple of 16 -- confirm with callers.
    """
    super().__init__()

    # Channel widths reused throughout the constructor.
    c2, c4 = z_dim // 2, z_dim // 4
    c8, c16 = z_dim // 8, z_dim // 16

    # The sizes in the trailing comments are the original author's notes;
    # presumably the spatial size after each stage for the expected input
    # resolution -- verify against the data pipeline.
    stage_channels = (
        [3, c16, c8],   # (16, 16)
        [c8, c4, c2],   # (4, 4)
        [c2, z_dim],    # (2, 2)
    )
    self.encoder_blocks = nn.ModuleList(
        [EncoderBlock(channels) for channels in stage_channels]
    )

    # One residual block per stage, matching that stage's output width.
    self.encoder_residual_blocks = nn.ModuleList(
        [EncoderResidualBlock(width) for width in (c8, c2, z_dim)]
    )

    # Swish followed by a 1x1 convolution that doubles the channel count.
    activation = Swish()
    widen = nn.Conv2d(z_dim, z_dim * 2, kernel_size=1)
    self.condition_x = nn.Sequential(activation, widen)