def conv_block(self, x, conv_layers, norm_layer, emb, res=True):
    """Embedding-conditioned block that doubles the length axis.

    Steps:
      1. add ``emb`` to ``x``, run ``conv_layers[0]`` (via ``pad_layer``), then
         leaky-ReLU;
      2. pixel-shuffle upsample by a factor of 2, add ``emb`` again, run
         ``conv_layers[1]``, then leaky-ReLU;
      3. apply ``norm_layer``;
      4. if ``res``, add ``upsample(x, scale_factor=2)`` as a skip connection.

    Args:
        x: 3-D input tensor; presumably (batch, channels, time) — TODO confirm.
        conv_layers: indexable holding at least two conv layers (only [0] and
            [1] are used).
        norm_layer: callable applied to the output of the second conv.
        emb: tensor reshaped to (emb.size(0), emb.size(1), 1) and broadcast-
            added along the last axis; assumed to be 2-D (batch, channels).
        res: add the upsampled input as a residual when True.

    Returns:
        The block output tensor.
    """
    # Reshape once; the same broadcastable conditioning term is added twice.
    cond = emb.view(emb.size(0), emb.size(1), 1)

    # First conv on the conditioned input.
    hidden = F.leaky_relu(pad_layer(x + cond, conv_layers[0]), negative_slope=self.ns)

    # Upsample via pixel shuffle, re-condition, then second conv.
    hidden = pixel_shuffle_1d(hidden, upscale_factor=2) + cond
    hidden = F.leaky_relu(pad_layer(hidden, conv_layers[1]), negative_slope=self.ns)
    hidden = norm_layer(hidden)

    if not res:
        return hidden
    # The skip path must be upsampled to match the doubled length.
    return hidden + upsample(x, scale_factor=2)
def forward(self, x):
    """Run the full stack and return ``(out, mean, log_var)``.

    Pipeline (as written):
      * every layer in ``self.conv1s`` is applied to ``x`` in parallel; the
        results and ``x`` itself are concatenated on dim 1 and activated;
      * four ``conv_block`` stages (the first without residual);
      * two ``dense_block`` stages with residuals;
      * an RNN whose output is concatenated on dim 1 with its input;
      * a linear projection;
      * ``mean`` and ``log_var`` are separate RNN heads over the *projected,
        pre-activation* output;
      * the returned ``out`` is ``gumbel_softmax`` of the projection when
        ``self.one_hot`` is set, otherwise its leaky-ReLU.

    NOTE(review): tensor layout is not established here; concatenation on
    dim 1 suggests (batch, channels, time) — confirm against callers.
    """
    # Parallel first-layer convs, plus the raw input as an extra branch.
    branches = [pad_layer(x, conv) for conv in self.conv1s]
    hidden = torch.cat(branches + [x], dim=1)
    hidden = F.leaky_relu(hidden, negative_slope=self.ns)

    # Conv stages: the first one explicitly has no residual; the rest use the
    # conv_block default for ``res``.
    hidden = self.conv_block(hidden, [self.conv2], [self.ins_norm1, self.drop1], res=False)
    for convs, post in (
        ([self.conv3, self.conv4], [self.ins_norm2, self.drop2]),
        ([self.conv5, self.conv6], [self.ins_norm3, self.drop3]),
        ([self.conv7, self.conv8], [self.ins_norm4, self.drop4]),
    ):
        hidden = self.conv_block(hidden, convs, post)

    # Dense stages, both with residuals.
    for denses, post in (
        ([self.dense1, self.dense2], [self.ins_norm5, self.drop5]),
        ([self.dense3, self.dense4], [self.ins_norm6, self.drop6]),
    ):
        hidden = self.dense_block(hidden, denses, post, res=True)

    # Recurrent features are concatenated alongside their input.
    hidden = torch.cat([hidden, RNN(hidden, self.RNN)], dim=1)
    hidden = linear(hidden, self.linear)

    # Both heads read the projection *before* the final activation.
    mean = RNN(hidden, self.mean)
    log_var = RNN(hidden, self.log_var)

    if self.one_hot:
        activated = gumbel_softmax(hidden)
    else:
        activated = F.leaky_relu(hidden, negative_slope=self.ns)
    return activated, mean, log_var
def conv_block(self, x, conv_layers, after_layers, res=True):
    """Conv stack followed by post-processing layers, with an optional skip.

    Each layer in ``conv_layers`` is applied through ``pad_layer`` and followed
    by a leaky-ReLU (slope ``self.ns``). Every callable in ``after_layers`` is
    then applied in order. When ``res`` is True, the unmodified input ``x`` is
    added to the result, so the stack must preserve ``x``'s shape (or be
    broadcast-compatible with it).

    Args:
        x: input tensor.
        conv_layers: iterable of conv layers; may be empty.
        after_layers: iterable of callables (e.g. norm/dropout); may be empty.
        res: add ``x`` to the output when True.

    Returns:
        The processed tensor.
    """
    hidden = x
    for conv in conv_layers:
        hidden = F.leaky_relu(pad_layer(hidden, conv), negative_slope=self.ns)
    for post in after_layers:
        hidden = post(hidden)
    return hidden + x if res else hidden
def conv_block(self, x, conv_layers, norm_layers, res=True):
    """Conv stack plus norm layers, with an optional length-halving skip.

    Each layer in ``conv_layers`` is applied through ``pad_layer`` and followed
    by a leaky-ReLU (slope ``self.ns``). Every callable in ``norm_layers`` is
    then applied in order. When ``res`` is True, ``x`` is downsampled along its
    last axis and added to the result:

      * odd lengths are right-padded by one frame, so the result has
        ``ceil(L / 2)`` frames;
      * ``avg_pool1d`` with kernel 2 (stride defaults to the kernel size)
        averages adjacent frame pairs.

    The stack itself is therefore expected to produce ``ceil(L / 2)`` frames;
    this is not checked here.

    Args:
        x: 3-D tensor (its ``size(2)`` is read); presumably
            (batch, channels, time) — TODO confirm.
        conv_layers: iterable of conv layers; may be empty.
        norm_layers: iterable of callables applied after the convs; may be
            empty.
        res: add the downsampled input when True.

    Returns:
        The processed tensor.
    """
    out = x
    for layer in conv_layers:
        out = pad_layer(out, layer)
        out = F.leaky_relu(out, negative_slope=self.ns)
    for layer in norm_layers:
        out = layer(out)
    if res:
        length = x.size(2)
        # 'reflect' padding must be smaller than the padded dimension, so it
        # raises for a single-frame input. Replicating that lone frame is the
        # natural degenerate case. For length >= 2 the result is unchanged.
        pad_mode = 'reflect' if length > 1 else 'replicate'
        x_pad = F.pad(x, pad=(0, length % 2), mode=pad_mode)
        x_down = F.avg_pool1d(x_pad, kernel_size=2)
        out = x_down + out
    return out