class WaveGlowNM(TrainableNM):
    """
    WaveGlowNM implements the Waveglow model in whole. This NM is meant to
    be used during training

    Args:
        n_mel_channels (int): Size of input mel spectrogram
            Defaults to 80.
        n_flows (int): Number of normalizing flows/layers of waveglow.
            Defaults to 12
        n_group (int): Each audio/spec pair is split in n_group number of
            groups. It must be divisible by 2 as halves are split this way.
            Defaults to 8
        n_early_every (int): After n_early_every layers, n_early_size number of
            groups are skipped to the output of the Neural Module.
            Defaults to 4
        n_early_size (int): The number of groups to skip to the output at every
            n_early_every layers.
            Defaults to 2
        n_wn_layers (int): The number of layers of the wavenet submodule.
            Defaults to 8
        n_wn_channels (int): The number of channels of the wavenet submodule.
            Defaults to 512
        wn_kernel_size (int): The kernel size of the wavenet submodule.
            Defaults to 3
    """

    @property
    @add_port_docs()
    def input_ports(self):
        """Returns definitions of module input ports.
        """
        return {
            # "mel_spectrogram": NeuralType(
            #     {0: AxisType(BatchTag), 1: AxisType(MelSpectrogramSignalTag), 2: AxisType(TimeTag),}
            # ),
            # "audio": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
            "mel_spectrogram": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "audio": NeuralType(('B', 'T'), AudioSignal(self.sample_rate)),
        }
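        # Shape note (illustrative, assuming a hop length of 256 in the
        # upstream featurizer): a batch of 4 one-second clips at
        # sample_rate=22050 would give mel_spectrogram of shape (4, 80, ~86)
        # and audio of shape (4, 22050).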

    @property
    @add_port_docs()
    def output_ports(self):
        """Returns definitions of module output ports.
        """
        # TODO @blisc: please take a look at those definitions
        return {
            # "audio": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
            # "log_s_list": NeuralType(),
            # "log_det_W_list": NeuralType(),
            "audio": NeuralType(('B', 'T'), AudioSignal(self.sample_rate)),
            "log_s_list": NeuralType(elements_type=ChannelType()),
            "log_det_W_list": NeuralType(elements_type=ChannelType()),
        }

    def __init__(
        self,
        sample_rate: int,
        n_mel_channels: int = 80,
        n_flows: int = 12,
        n_group: int = 8,
        n_early_every: int = 4,
        n_early_size: int = 2,
        n_wn_layers: int = 8,
        n_wn_channels: int = 512,
        wn_kernel_size: int = 3,
    ):
        # Set sample_rate before super().__init__(), as the input/output port
        # definitions above reference it.
        self.sample_rate = sample_rate
        super().__init__()
        wavenet_config = {
            "n_layers": n_wn_layers,
            "n_channels": n_wn_channels,
            "kernel_size": wn_kernel_size,
        }
        self.waveglow = WaveGlow(
            n_mel_channels=n_mel_channels,
            n_flows=n_flows,
            n_group=n_group,
            n_early_every=n_early_every,
            n_early_size=n_early_size,
            WN_config=wavenet_config,
        )
        self.to(self._device)
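
    # Example instantiation (hypothetical values): a 22.05 kHz model with the
    # default flow configuration:
    #     waveglow_nm = WaveGlowNM(sample_rate=22050)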

    def forward(self, mel_spectrogram, audio):
        # TODO: This function should probably be split into separate training
        # and inference paths.
        # In training mode, the returned "audio" is the latent z produced by
        # the flow (which should follow the prior normal distribution), along
        # with the per-layer log-determinant terms needed for the
        # normalizing-flow likelihood. In inference mode, it is the predicted
        # audio, and the two lists are empty.
        if self.training:
            audio, log_s_list, log_det_W_list = self.waveglow((mel_spectrogram, audio))
        else:
            audio = self.waveglow.infer(mel_spectrogram)
            log_s_list = []
            log_det_W_list = []
        return audio, log_s_list, log_det_W_list
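

# ---------------------------------------------------------------------------
# Illustrative only, not part of this module: a minimal sketch of the standard
# WaveGlow negative log-likelihood (Prenger et al., 2019), showing how the
# latent z ("audio" in training mode), log_s_list, and log_det_W_list returned
# by WaveGlowNM.forward are typically consumed. The function name and the
# sigma hyperparameter are assumptions for this example. The import may
# duplicate a top-of-file import, which is harmless.
# ---------------------------------------------------------------------------
import torch


def waveglow_nll_loss(z, log_s_list, log_det_W_list, sigma=1.0):
    # Spherical Gaussian prior term on the latent z.
    loss = torch.sum(z * z) / (2 * sigma * sigma)
    # Subtract the log-determinant contributions of the affine coupling
    # layers (log_s) and the invertible 1x1 convolutions (log_det_W).
    for log_s in log_s_list:
        loss = loss - torch.sum(log_s)
    for log_det_W in log_det_W_list:
        loss = loss - torch.sum(log_det_W)
    # Normalize by the total number of elements in z.
    return loss / z.numel()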