Example No. 1
    def remove_weight_norm(self):
        remove_weight_norm(self.input_conv[0])
        remove_weight_norm(self.kernel_conv)
        remove_weight_norm(self.bias_conv)
        for block in self.residual_convs:
            remove_weight_norm(block[1])
            remove_weight_norm(block[3])
Example No. 2
    def remove_weight_norm(self):
        print('Removing weight norm...')
        remove_weight_norm(self.conv_pre)
        for layer in self.conv_post:
            # skip parameter-free layers (e.g. activations), which have no weight norm to remove
            if len(layer.state_dict()) != 0:
                remove_weight_norm(layer)
        for res_block in self.res_stack:
            res_block.remove_weight_norm()
Example No. 3
    def remove_wn(self):
        print("Removing weight norm...")

        for ups in self.upsampler:
            ups.remove_wn()

        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
Example No. 4
    def remove_weight_norm(self):
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
Example No. 5
    def remove_weight_norm(self):
        print("Removing weight norm...")
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
Example No. 6
    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
Example No. 7
    def remove_weight_norm(self):
        for layer in self.convs1:
            remove_weight_norm(layer)
        for layer in self.convs2:
            remove_weight_norm(layer)
Example No. 8
    def __init__(
        self,
        in_channels,
        out_channels,
        resblock_type,
        resblock_dilation_sizes,
        resblock_kernel_sizes,
        upsample_kernel_sizes,
        upsample_initial_channel,
        upsample_factors,
        inference_padding=5,
        cond_channels=0,
        conv_pre_weight_norm=True,
        conv_post_weight_norm=True,
        conv_post_bias=True,
    ):
        r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)

        Network:
            x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
                                                 ..          -> zI ---|
                                              resblockN_kNx1 -> zN ---'

        Args:
            in_channels (int): number of input tensor channels.
            out_channels (int): number of output tensor channels.
            resblock_type (str): type of the `ResBlock`. '1' or '2'.
            resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
            resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
            upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
            upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
                for each consecutive upsampling layer.
            upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
            inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
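            cond_channels (int): number of channels of an optional conditioning input. If > 0, a 1x1
                `cond_layer` projects it to `upsample_initial_channel` channels. Defaults to 0.
            conv_pre_weight_norm (bool): whether to keep weight norm on `conv_pre`. Defaults to True.
            conv_post_weight_norm (bool): whether to keep weight norm on `conv_post`. Defaults to True.
            conv_post_bias (bool): whether `conv_post` uses a bias term. Defaults to True.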
        """
        super().__init__()
        self.inference_padding = inference_padding
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_factors)
        # initial upsampling layers
        self.conv_pre = weight_norm(
            Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if resblock_type == "1" else ResBlock2
        # upsampling layers
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_factors,
                                       upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )))
        # MRF blocks
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2**(i + 1))
            for _, (k, d) in enumerate(
                    zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d))
        # post convolution layer
        self.conv_post = weight_norm(
            Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
        if cond_channels > 0:
            self.cond_layer = nn.Conv1d(cond_channels,
                                        upsample_initial_channel, 1)

        if not conv_pre_weight_norm:
            remove_weight_norm(self.conv_pre)

        if not conv_post_weight_norm:
            remove_weight_norm(self.conv_post)
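For context, a minimal usage sketch for the constructor above. The class name `HifiganGenerator`, its `remove_weight_norm()` method, and the configuration values are assumptions (they do not appear in this snippet); the settings follow the common HiFiGAN V1 configuration.

# Hypothetical sketch: build the generator and strip weight norm before inference.
# HifiganGenerator is assumed to wrap the __init__ above and to expose a
# remove_weight_norm() method like the ones in Examples No. 4-6.
generator = HifiganGenerator(
    in_channels=80,                    # mel-spectrogram bins
    out_channels=1,                    # mono waveform
    resblock_type="1",
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    resblock_kernel_sizes=[3, 7, 11],
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    upsample_factors=[8, 8, 2, 2],
)
generator.eval()
# weight norm only helps optimization; fuse it into the plain weights for inference/export
generator.remove_weight_norm()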
Example No. 9
    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)
Example No. 10
    def remove_weight_norm(self):
        self.kernel_predictor.remove_weight_norm()
        remove_weight_norm(self.convt_pre[1])
        for block in self.conv_blocks:
            remove_weight_norm(block[1])
Example No. 11
from models.conv import GatedConv
import torch
import heapq
from torch.nn.utils import remove_weight_norm
from config import pretrained_model_path

torch.set_grad_enabled(False)

model = GatedConv.load(pretrained_model_path)
model.eval()

# fuse the weight-norm parametrization of one convolution so .weight holds the plain tensor
conv = model.cnn[10]
remove_weight_norm(conv)

w = conv.weight.squeeze().detach()
b = conv.bias.unsqueeze(1).detach()

# treat the convolution weights as per-character embedding vectors, indexed by vocabulary id
embed = w

vocab = model.vocabulary
v = dict((vocab[i], i) for i in range(len(vocab)))


def cos(c1, c2):
    # cosine similarity between the embeddings of characters c1 and c2
    e1, e2 = embed[v[c1]], embed[v[c2]]
    return (e1 * e2).sum() / (e1.norm() * e2.norm())


def nearest(c, n=5):
    def gen():
        for c_ in v:
            yield cos(c, c_).item(), c_

    # the snippet is truncated in the source; a plausible completion given the
    # heapq import: return the n vocabulary characters most similar to c
    return heapq.nlargest(n, gen())
Example No. 12
    def remove_weight_norm(self):
        remove_weight_norm(self.conv)
Example No. 13
    def remove_wn(self):
        remove_weight_norm(self.up)
        self.res_0.remove_wn()
        self.res_1.remove_wn()
        self.res_2.remove_wn()
Example No. 14
    def remove_wn(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)