def create_decoder_block(in_channels, out_channels, kernel_size, wn=True, bn=True,
                         activation=nn.LeakyReLU, layers=2, final_layer=False):
    """Build one decoder stage as a stack of transposed-convolution layers.

    Each layer is produced by ``create_layer(..., nn.ConvTranspose2d)``:

    * the first layer reads ``2 * in_channels`` channels (presumably an
      upsampled feature map concatenated with a skip connection -- confirm
      against the caller);
    * intermediate layers map ``in_channels -> in_channels``;
    * the last layer maps to ``out_channels``.

    Args:
        in_channels: Working channel width of the block.
        out_channels: Channel width produced by the last layer.
        kernel_size: Forwarded unchanged to every ``create_layer`` call.
        wn: Weight-norm flag forwarded to ``create_layer``.
        bn: Batch-norm flag forwarded to ``create_layer``.
        activation: Activation forwarded to ``create_layer``.
        layers: Number of layers to build; ``layers <= 0`` yields an empty
            ``nn.Sequential``.
        final_layer: When true, the last layer is built with ``bn=False``
            and ``activation=None`` (this stage is the network's output).

    Returns:
        ``nn.Sequential`` of the created layers, in application order.
    """
    last_index = layers - 1
    stack = []
    for idx in range(layers):
        is_last = idx == last_index
        # Only the very last layer of a final block loses bn/activation.
        strip_last = is_last and final_layer
        stack.append(
            create_layer(
                in_channels * 2 if idx == 0 else in_channels,
                out_channels if is_last else in_channels,
                kernel_size,
                wn,
                False if strip_last else bn,
                None if strip_last else activation,
                nn.ConvTranspose2d,
            )
        )
    return nn.Sequential(*stack)
def create_encoder_block(in_channels, out_channels, kernel_size, wn, bn, layers,
                         activation=nn.ReLU):
    """Build one encoder stage as a stack of ``nn.Conv2d`` layers.

    The first layer maps ``in_channels -> out_channels``. Every following
    layer maps ``out_channels -> out_channels``. Each layer is produced by
    ``create_layer(..., nn.Conv2d)`` with the same kernel size, flags and
    activation.

    Args:
        in_channels: Channel width fed to the first layer.
        out_channels: Channel width produced by every layer.
        kernel_size: Forwarded unchanged to every ``create_layer`` call.
        wn: Weight-norm flag forwarded to ``create_layer``.
        bn: Batch-norm flag forwarded to ``create_layer``.
        layers: Number of layers; ``layers <= 0`` yields an empty
            ``nn.Sequential``.
        activation: Activation forwarded to ``create_layer``.

    Returns:
        ``nn.Sequential`` of the created layers, in application order.
    """
    # Only the first layer adapts the input width; the rest stay at out_channels.
    input_widths = [in_channels if idx == 0 else out_channels for idx in range(layers)]
    return nn.Sequential(*[
        create_layer(width, out_channels, kernel_size, wn, bn, activation, nn.Conv2d)
        for width in input_widths
    ])
def __init__(self, in_channels, out_channels, kernel_size=3, filters=(8, 16, 16, 32),
             weight_norm=True, batch_norm=True, activation=nn.ReLU, final_activation=None):
    """Build a shared encoder and one independent decoder per output channel.

    Encoder: ``len(filters)`` layers built with ``create_layer(..., nn.Conv2d)``,
    widening ``in_channels -> filters[0] -> ... -> filters[-1]``.

    Decoders: for each of the ``out_channels`` outputs, a mirror of the encoder
    built with ``create_layer(..., nn.ConvTranspose2d)``, narrowing
    ``filters[-1] -> ... -> filters[0] -> 1``. The outermost layer of every
    decoder gets ``batch_norm=False`` and ``final_activation`` instead of
    ``activation``.

    Args:
        in_channels: Channel width of the input to the first encoder layer.
        out_channels: Number of per-channel decoders; each ends in a
            single-channel layer.
        kernel_size: Forwarded unchanged to every ``create_layer`` call.
        filters: Non-empty sequence of encoder level widths, shallowest first.
            A tuple default avoids sharing one mutable list across calls.
        weight_norm: Weight-norm flag forwarded to every ``create_layer`` call.
        batch_norm: Batch-norm flag forwarded to every layer except each
            decoder's outermost one.
        activation: Activation forwarded to every layer except each decoder's
            outermost one.
        final_activation: Activation forwarded to each decoder's outermost layer.

    Raises:
        ValueError: If ``filters`` is empty. This was previously an ``assert``,
            which ``python -O`` strips, silently building an empty model.
    """
    super().__init__()
    filters = list(filters)
    if not filters:
        raise ValueError("filters must contain at least one channel width")

    encoder = []
    # decoders[c] collects the layers of the decoder for output channel c,
    # shallowest level first; they are reversed when wrapped below.
    decoders = [[] for _ in range(out_channels)]

    # Layers are created in the original order (per level: the encoder layer,
    # then one decoder layer per channel). This keeps seeded initialization
    # identical in case create_layer draws from the RNG.
    for i, width in enumerate(filters):
        outermost = i == 0
        narrower = in_channels if outermost else filters[i - 1]
        encoder.append(create_layer(narrower, width, kernel_size, weight_norm,
                                    batch_norm, activation, nn.Conv2d))
        dec_out = 1 if outermost else filters[i - 1]
        dec_bn = False if outermost else batch_norm
        dec_act = final_activation if outermost else activation
        for chain in decoders:
            chain.append(create_layer(width, dec_out, kernel_size, weight_norm,
                                      dec_bn, dec_act, nn.ConvTranspose2d))

    self.encoder = nn.Sequential(*encoder)
    # NOTE(review): self.decoder is an nn.Sequential holding the per-channel
    # decoders. The forward pass (not visible here) presumably applies each
    # one to the encoder output separately rather than chaining them -- confirm
    # before ever calling self.decoder(x) directly.
    self.decoder = nn.Sequential(
        *(nn.Sequential(*reversed(chain)) for chain in decoders)
    )