Code example #1
    def __init__(self, tier: int, n_layers: int, hidden_size: int,
                 gmm_size: int, freq: int):
        """
        Args:
            tier (int): the tier that this module represents.
            n_layers (int): number of layers this tier is composed of.
            hidden_size (int): size of the hidden state of the Delayed Stack Layers.
            gmm_size (int): number of mixture components of the GMM.
            freq (int): size of the frequency axis of the spectrogram to generate. See note in the
                        documentation of the file.
        """
        super(Tier, self).__init__()

        self.tier = tier
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.k = gmm_size
        self.freq = freq

        # Only the first tier uses a centralized stack according to the MelNet paper (Table 1),
        # so this tier (an upsampling tier) does not have one
        self.has_central_stack = False

        # Define layers of the tier
        self.layers = nn.ModuleList([
            ModuleWrapperDelayedStackLayer0(
                DelayedStackLayer0(hidden_size=hidden_size,
                                   has_central_stack=self.has_central_stack,
                                   freq=freq,
                                   is_conditioned=True,
                                   hidden_size_condition=hidden_size * 4))
        ] + [
            ModuleWrapperDelayedStackLayer(
                DelayedStackLayer(layer=layer_idx,
                                  hidden_size=hidden_size,
                                  has_central_stack=self.has_central_stack,
                                  freq=freq))
            for layer_idx in range(1, n_layers)
        ])
        # Layer 0 of this tier (every tier after the first) is conditioned on the output of the
        # feature extraction network. These conditioning features are the concatenation of the
        # hidden states of four one-dimensional RNNs, which is why the hidden size of the
        # condition in DelayedStackLayer0 is hidden_size * 4 (see the sketch after this example)

        # Define feature extraction network
        self.feature_extraction = ModuleWrapperFeatureExtraction(
            FeatureExtractionLayer(hidden_size))

        # Define dummy tensor to trick checkpointing: it gives every checkpointed call at least
        # one input with requires_grad=True so that gradients flow through the recomputed segment
        self.dummy_tensor = torch.ones(1,
                                       dtype=torch.float32,
                                       requires_grad=True)

        # Linear transformation from the final layer of the frequency-delayed stack to produce
        # the unconstrained GMM parameters (3 values per mixture component: mean, scale, weight)
        self.W_theta = nn.Linear(in_features=hidden_size,
                                 out_features=3 * self.k)
Code example #2
    def __init__(self, tier: int, n_layers: int, hidden_size: int,
                 gmm_size: int, freq: int):
        """
        Args:
            tier (int): the tier that this module represents.
            n_layers (int): number of layers this tier is composed of.
            hidden_size (int): size of the hidden state of the Delayed Stack Layers.
            gmm_size (int): number of mixture components of the GMM.
            freq (int): size of the frequency axis of the spectrogram to generate. See note in the
                        documentation of the file.
        """
        super(Tier1, self).__init__()

        self.tier = tier
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.k = gmm_size

        # Only the first tier uses a centralized stack according to the MelNet paper (Table 1),
        # and this is the first tier, so it has one
        self.has_central_stack = True

        # Define layers of the tier
        self.layers = nn.ModuleList([
            ModuleWrapperDelayedStackLayer0(
                DelayedStackLayer0(hidden_size=hidden_size,
                                   has_central_stack=self.has_central_stack,
                                   freq=freq))
        ] + [
            ModuleWrapperDelayedStackLayer(
                DelayedStackLayer(layer=layer_idx,
                                  hidden_size=hidden_size,
                                  has_central_stack=self.has_central_stack,
                                  freq=freq))
            for layer_idx in range(1, n_layers)
        ])

        # Define dummy tensor to trick checkpointing: it gives every checkpointed call at least
        # one input with requires_grad=True so that gradients flow through the recomputed segment
        # (see the sketch after this example)
        self.dummy_tensor = torch.ones(1,
                                       dtype=torch.float32,
                                       requires_grad=True)

        # Linear transformation from the final layer of the frequency-delayed stack to produce
        # the unconstrained GMM parameters (3 values per mixture component: mean, scale, weight;
        # see the sketch after this example)
        self.W_theta = nn.Linear(in_features=hidden_size,
                                 out_features=3 * self.k)
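In both tiers, W_theta maps each hidden state to 3 * self.k unconstrained values. Following the MelNet formulation, these are split into means, scales, and mixture weights, with the scales made positive via an exponential and the weights normalized via a softmax. Below is a minimal sketch of that split; the function name and tensor layout are assumptions for illustration, not the repository's API.

import torch
import torch.nn.functional as F

def constrain_gmm_params(theta_hat: torch.Tensor, k: int):
    """Split the unconstrained output of W_theta into valid GMM parameters.

    theta_hat: [..., 3 * k] unconstrained parameters for one time-frequency bin.
    Returns (mu, sigma, pi), each of shape [..., k].
    """
    mu, sigma_hat, pi_hat = torch.split(theta_hat, k, dim=-1)
    sigma = torch.exp(sigma_hat)      # scales must be strictly positive
    pi = F.softmax(pi_hat, dim=-1)    # mixture weights must sum to 1
    return mu, sigma, pi              # means stay unconstrained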
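The dummy_tensor defined in both tiers works around a limitation of PyTorch's reentrant gradient checkpointing: if none of the inputs to a checkpointed function requires gradients, the recomputed segment is detached from the graph and the module's parameters receive no gradients. Passing a trivial tensor with requires_grad=True into every wrapped call keeps the graph alive. A minimal sketch of the wrapper pattern follows, where ModuleWrapperSketch is a simplified stand-in for the ModuleWrapper* classes above, not their actual implementation.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class ModuleWrapperSketch(nn.Module):
    """Wraps a module so it can be checkpointed even when its real inputs
    do not require gradients (e.g. an input spectrogram)."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor, dummy: torch.Tensor) -> torch.Tensor:
        # `dummy` exists only so checkpoint() sees an input with
        # requires_grad=True; it plays no part in the computation
        return self.module(x)

wrapper = ModuleWrapperSketch(nn.Linear(8, 8))
dummy = torch.ones(1, dtype=torch.float32, requires_grad=True)
x = torch.randn(4, 8)  # does not require grad
y = checkpoint(wrapper, x, dummy, use_reentrant=True)
y.sum().backward()     # gradients now reach wrapper.module's parameters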