示例#1
0
    def __init__(self, config):
        """Create the wrapped WaveNet from the fields of ``config``.

        Each keyword argument handed to ``WaveNet`` is read from the attribute
        of the same name on ``config``.

        :param config: Object exposing every attribute named in
                       ``forwarded_fields`` below.
        """
        super().__init__()

        # NOTE(review): ``len_in_out_multiplier`` is not set here; the original
        # assignment from ``hparams`` was commented out.

        # The attribute names on the config match the WaveNet keyword names.
        forwarded_fields = (
            "out_channels",
            "layers",
            "stacks",
            "residual_channels",
            "gate_channels",
            "skip_out_channels",
            "kernel_size",
            "dropout",
            "weight_normalization",
            "cin_channels",
            "gin_channels",
            "n_speakers",
            "upsample_conditional_features",
            "upsample_scales",
            "freq_axis_kernel_size",
            "scalar_input",
            "use_speaker_embedding",
            "legacy",
        )
        wavenet_kwargs = {field: getattr(config, field)
                          for field in forwarded_fields}
        self.model = WaveNet(**wavenet_kwargs)

        # Starts out True; this method does not change it afterwards.
        self.has_weight_norm = True
示例#2
0
class WaveNetWrapper(nn.Module):
    """A wrapper around r9y9's WaveNet implementation to integrate it seamlessly into the framework."""
    IDENTIFIER = "r9y9WaveNet"

    def __init__(self, dim_in, dim_out, hparams):
        """Build the wrapped WaveNet from ``hparams``.

        :param dim_in:   Not used by this constructor; kept for the interface.
        :param dim_out:  Not used by this constructor; kept for the interface.
        :param hparams:  Hyper-parameter container. Must provide
                         ``len_in_out_multiplier``, ``input_type`` and every
                         attribute forwarded to ``WaveNet`` below.
        """
        super().__init__()

        # Multiplied with the input length to get the number of frames to
        # generate at inference time (see forward()).
        self.len_in_out_multiplier = hparams.len_in_out_multiplier

        # Use the wavenet_vocoder builder to create the model.
        self.model = WaveNet(
            out_channels=hparams.out_channels,
            layers=hparams.layers,
            stacks=hparams.stacks,
            residual_channels=hparams.residual_channels,
            gate_channels=hparams.gate_channels,
            skip_out_channels=hparams.skip_out_channels,
            kernel_size=hparams.kernel_size,
            dropout=hparams.dropout,
            weight_normalization=hparams.weight_normalization,
            cin_channels=hparams.cin_channels,
            gin_channels=hparams.gin_channels,
            n_speakers=hparams.n_speakers,
            upsample_conditional_features=hparams.upsample_conditional_features,
            upsample_scales=hparams.upsample_scales,
            freq_axis_kernel_size=hparams.freq_axis_kernel_size,
            scalar_input=is_scalar_input(hparams.input_type),
            use_speaker_embedding=hparams.use_speaker_embedding,
        )

    def forward(self, inputs, hidden, seq_lengths_inputs, max_length_inputs, target=None, seq_lengths_target=None):
        """Run the wrapped model with teacher forcing or in generation mode.

        :param inputs:               Passed to the model as local conditioning ``c``.
        :param hidden:               Not used.
        :param seq_lengths_inputs:   Input lengths; only read during inference,
                                     where it must contain exactly one entry.
        :param max_length_inputs:    Not used.
        :param target:               When given, the model is run on it with
                                     ``inputs`` as conditioning (teacher forcing).
                                     When None, output is generated incrementally.
        :param seq_lengths_target:   Not used.
        :return:                     Tuple ``(output, None)``.
        """
        if target is not None:  # During training and testing with teacher forcing.
            output = self.model(target, c=inputs, g=None, softmax=False)
            # Output shape is B x C x T. Don't permute here because CrossEntropyLoss requires the same shape.
        else:  # During inference.
            with torch.no_grad():
                self.model.make_generation_fast_()
                assert len(seq_lengths_inputs) == 1, "Batch synthesis is not supported yet."
                num_frames_to_gen = seq_lengths_inputs[0] * self.len_in_out_multiplier
                output = self.model.incremental_forward(c=inputs, T=num_frames_to_gen, softmax=True, quantize=True)
                # Output shape is B x C x T.

        return output, None

    def set_gpu_flag(self, use_gpu):
        """Store ``use_gpu`` on the instance."""
        self.use_gpu = use_gpu

    def init_hidden(self, batch_size=1):
        """Return None; this wrapper keeps no hidden state."""
        return None

    def parameters(self, recurse=True):
        """Return the parameters of the wrapped model.

        ``recurse`` is accepted and forwarded so that this override can be
        called with the same signature as ``nn.Module.parameters``; calling it
        without arguments behaves exactly as before.
        """
        return self.model.parameters(recurse=recurse)
示例#3
0
def wavenet(
    layers=20,
    stacks=2,
    residual_channels=512,
    gate_channels=512,
    skip_out_channels=512,
    cin_channels=-1,
    gin_channels=-1,
    weight_normalization=True,
    dropout=1 - 0.95,
    kernel_size=3,
    n_speakers=None,
    upsample_conditional_features=False,
    upsample_scales=None,
):
    """Construct a ``wavenet_vocoder.WaveNet`` from the given hyper-parameters.

    Every argument is forwarded unchanged as the keyword argument of the same
    name.

    Args:
        upsample_scales: Forwarded to ``WaveNet``. ``None`` (the default) is
            replaced by a fresh ``[16, 16]`` list on every call, so the
            default is never shared between calls.

    Returns:
        The constructed ``WaveNet`` instance.
    """
    from wavenet_vocoder import WaveNet

    if upsample_scales is None:
        upsample_scales = [16, 16]

    model = WaveNet(
        layers=layers,
        stacks=stacks,
        residual_channels=residual_channels,
        gate_channels=gate_channels,
        skip_out_channels=skip_out_channels,
        kernel_size=kernel_size,
        dropout=dropout,
        weight_normalization=weight_normalization,
        cin_channels=cin_channels,
        gin_channels=gin_channels,
        n_speakers=n_speakers,
        upsample_conditional_features=upsample_conditional_features,
        upsample_scales=upsample_scales,
    )

    return model
示例#4
0
def get_model():
    """Build a WaveNet configured from the module-level ``hparams``.

    ``hparams.upsample_params`` is updated in place with the ``cin_channels``
    and ``cin_pad`` values before it is handed to the model.

    Returns:
        The constructed ``WaveNet`` instance.
    """
    upsample_cfg = hparams.upsample_params
    for key in ("cin_channels", "cin_pad"):
        upsample_cfg[key] = getattr(hparams, key)

    # These hparams attributes map one-to-one onto WaveNet keyword names.
    copied = (
        "out_channels",
        "layers",
        "stacks",
        "residual_channels",
        "gate_channels",
        "skip_out_channels",
        "cin_channels",
        "gin_channels",
        "n_speakers",
        "dropout",
        "kernel_size",
        "cin_pad",
        "upsample_conditional_features",
    )
    net_kwargs = {key: getattr(hparams, key) for key in copied}
    net_kwargs["upsample_params"] = upsample_cfg
    net_kwargs["scalar_input"] = is_scalar_input(hparams.input_type)
    net_kwargs["output_distribution"] = hparams.output_distribution
    return WaveNet(**net_kwargs)
示例#5
0
def build_model():
    """Validate the module-level ``hparams`` and build a WaveNet from them.

    ``hparams.upsample_params`` is updated in place with ``cin_channels`` and
    ``cin_pad`` before being passed to the model.

    Returns:
        The constructed ``WaveNet`` instance.

    Raises:
        RuntimeError: If the input type is mu-law quantized and
            ``out_channels`` differs from ``quantize_channels``.
    """
    if is_mulaw_quantize(hparams.input_type) and hparams.out_channels != hparams.quantize_channels:
        raise RuntimeError(
            "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'")

    # Warn only; the model is still built.
    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        warn("Upsample conv layers were specified while local conditioning disabled. "
             "Notice that upsample conv layers will never be used.")

    upsample_cfg = hparams.upsample_params
    for key in ("cin_channels", "cin_pad"):
        upsample_cfg[key] = getattr(hparams, key)

    # These hparams attributes map one-to-one onto WaveNet keyword names.
    copied = (
        "out_channels",
        "layers",
        "stacks",
        "residual_channels",
        "gate_channels",
        "skip_out_channels",
        "cin_channels",
        "gin_channels",
        "n_speakers",
        "dropout",
        "kernel_size",
        "cin_pad",
        "upsample_conditional_features",
    )
    net_kwargs = {key: getattr(hparams, key) for key in copied}
    net_kwargs["upsample_params"] = upsample_cfg
    net_kwargs["scalar_input"] = is_scalar_input(hparams.input_type)
    net_kwargs["output_distribution"] = hparams.output_distribution
    return WaveNet(**net_kwargs)
示例#6
0
def build_vqvae_model():
    """Validate the module-level ``hparams`` and build a VQVAE around a WaveNet.

    ``hparams.upsample_params`` is updated in place with ``cin_channels`` and
    ``cin_pad`` before being passed to the WaveNet.

    Returns:
        The constructed ``VQVAE`` instance.

    Raises:
        RuntimeError: If the input type is mu-law quantized and
            ``out_channels`` differs from ``quantize_channels``.
    """
    if is_mulaw_quantize(hparams.input_type) and hparams.out_channels != hparams.quantize_channels:
        raise RuntimeError(
            "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'"
        )

    # Warn only; the model is still built.
    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        warn("Upsample conv layers were specified while local conditioning disabled. "
             "Notice that upsample conv layers will never be used.")

    upsample_cfg = hparams.upsample_params
    for key in ("cin_channels", "cin_pad"):
        upsample_cfg[key] = getattr(hparams, key)

    # These hparams attributes map one-to-one onto WaveNet keyword names.
    copied = (
        "out_channels",
        "layers",
        "stacks",
        "residual_channels",
        "gate_channels",
        "skip_out_channels",
        "cin_channels",
        "gin_channels",
        "n_speakers",
        "dropout",
        "kernel_size",
        "cin_pad",
        "upsample_conditional_features",
    )
    net_kwargs = {key: getattr(hparams, key) for key in copied}
    net_kwargs["upsample_params"] = upsample_cfg
    net_kwargs["scalar_input"] = is_scalar_input(hparams.input_type)
    net_kwargs["output_distribution"] = hparams.output_distribution
    net_kwargs["use_speaker_embedding"] = True
    wavenet = WaveNet(**net_kwargs)

    # K1 is only passed on when it is enabled and differs from K.
    K1 = hparams.K1 if (hparams.use_K1 and hparams.K1 != hparams.K) else None

    # With post_conv the hidden size is fixed to 64, otherwise it follows cin_channels.
    hid = 64 if hparams.post_conv else hparams.cin_channels

    vqvae_kwargs = dict(
        wavenet=wavenet,
        c_in=39,
        hid=hid,
        frame_rate=hparams.frame_rate,
        use_time_jitter=hparams.time_jitter,
        K=hparams.K,
        ema=hparams.ema,
        sliced=hparams.sliced,
        ins_norm=hparams.ins_norm,
        post_conv=hparams.post_conv,
        adain=hparams.adain,
        dropout=hparams.vq_drop,
        drop_dim=hparams.drop_dim,
        K1=K1,
        num_slices=hparams.num_slices,
    )
    return VQVAE(**vqvae_kwargs)
示例#7
0
def build_model():
    """Validate the module-level ``hparams`` and build the selected autoencoder.

    ``hparams.name`` selects the wrapper placed around the WaveNet:
    ``'inae'`` -> ``INAE``, ``'inae1'`` -> ``INAE1``, ``'new_inae'`` -> ``NewINAE``.
    ``hparams.upsample_params`` is updated in place with ``cin_channels`` and
    ``cin_pad`` before being passed to the WaveNet.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: If the input type is mu-law quantized and
            ``out_channels`` differs from ``quantize_channels``.
        ValueError: If ``hparams.name`` is not one of the supported names.
            (Previously this case surfaced as an UnboundLocalError on the
            final ``return``, after the WaveNet had already been built.)
    """
    if is_mulaw_quantize(hparams.input_type):
        if hparams.out_channels != hparams.quantize_channels:
            raise RuntimeError(
                "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'"
            )
    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        s = "Upsample conv layers were specified while local conditioning disabled. "
        s += "Notice that upsample conv layers will never be used."
        warn(s)

    # Fail early on an unknown model name instead of building a WaveNet first.
    supported_names = ('inae', 'inae1', 'new_inae')
    if hparams.name not in supported_names:
        raise ValueError(
            "Unknown model name {!r}; expected one of {}".format(
                hparams.name, supported_names))

    upsample_params = hparams.upsample_params
    upsample_params["cin_channels"] = hparams.cin_channels
    upsample_params["cin_pad"] = hparams.cin_pad

    # Only the 'new_inae' variant is built without a speaker embedding.
    use_speaker_embedding = hparams.name != 'new_inae'

    wavenet = WaveNet(
        out_channels=hparams.out_channels,
        layers=hparams.layers,
        stacks=hparams.stacks,
        residual_channels=hparams.residual_channels,
        gate_channels=hparams.gate_channels,
        skip_out_channels=hparams.skip_out_channels,
        cin_channels=hparams.cin_channels,
        gin_channels=hparams.gin_channels,
        n_speakers=hparams.n_speakers,
        dropout=hparams.dropout,
        kernel_size=hparams.kernel_size,
        cin_pad=hparams.cin_pad,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_params=upsample_params,
        scalar_input=is_scalar_input(hparams.input_type),
        output_distribution=hparams.output_distribution,
        use_speaker_embedding=use_speaker_embedding,
    )
    if hparams.name == 'inae':
        model = INAE(wavenet=wavenet,
                     c_in=39,
                     hid=64,
                     frame_rate=hparams.frame_rate,
                     adain=hparams.adain)
    elif hparams.name == 'inae1':
        model = INAE1(wavenet=wavenet,
                      c_in=39,
                      hid=64,
                      frame_rate=hparams.frame_rate,
                      adain=hparams.adain)
    else:  # 'new_inae' -- the only remaining supported name.
        model = NewINAE(wavenet=wavenet,
                        c_in=39,
                        hid=64,
                        frame_rate=hparams.frame_rate)
    return model
示例#8
0
class BetterWaveNetDecoder(nn.Module):
    """Decoder that runs a WaveNet and shifts its output one step to the right."""

    def __init__(self, wavenet_args):
        """Create the wrapped WaveNet.

        :param wavenet_args: Mapping of keyword arguments passed to ``WaveNet``.
        """
        super().__init__()
        self.wavenet = WaveNet(**wavenet_args)

    def forward(self, one_hot_z, x):
        """Run the WaveNet on ``x`` conditioned on ``one_hot_z``.

        :param one_hot_z: Passed to the WaveNet as conditioning ``c``.
        :param x:         One-hot encoded with ``self.wavenet.out_channels``
                          classes before being fed to the WaveNet.
        :return:          The WaveNet output shifted right by one position
                          along the last dimension: the first position is
                          filled with ones and the last output step is dropped.
        """
        output = self.wavenet.forward(x=one_hot(x, self.wavenet.out_channels),
                                      c=one_hot_z)
        # new_ones() allocates with the dtype and device of `output`, so the
        # concatenation also works when the model does not run on the default
        # device (a plain torch.ones() tensor would not match in that case).
        first_step = output.new_ones(output.size(0), output.size(1), 1)
        p_x = torch.cat([first_step, output[:, :, :-1]], dim=-1)
        return p_x
示例#9
0
def save_checkpoint(device,
                    model,
                    global_step,
                    global_test_step,
                    checkpoint_dir,
                    epoch,
                    ema=None):
    """Write a training checkpoint and, if ``ema`` is given, an averaged one.

    The regular checkpoint stores the state dict of ``model.decode_model``,
    the optimizer state (only when ``hparams.save_optimizer_state`` is set,
    otherwise None) and the step/epoch counters. When ``ema`` is not None a
    second file holding the averaged model produced by
    ``clone_as_averaged_model`` is written next to it.

    :param device:            Device the averaged model is moved to.
    :param model:             Object providing ``optimizer`` and ``decode_model``.
    :param global_step:       Stored in the checkpoint and used in the file name.
    :param global_test_step:  Stored in the checkpoint.
    :param checkpoint_dir:    Directory the files are written to.
    :param epoch:             Stored in the checkpoint as ``global_epoch``.
    :param ema:               Optional; passed to ``clone_as_averaged_model``.
    """
    file_name = hparams.name + "_checkpoint_step{:09d}.pth.tar".format(global_step)
    checkpoint_path = join(checkpoint_dir, file_name)

    if hparams.save_optimizer_state:
        optimizer_state = model.optimizer.state_dict()
    else:
        optimizer_state = None

    # Entries shared by both checkpoint files.
    bookkeeping = {
        "optimizer": optimizer_state,
        "global_step": global_step,
        "global_epoch": epoch,
        "global_test_step": global_test_step,
    }

    payload = {"model": model.decode_model.state_dict(), **bookkeeping}
    torch.save(payload, checkpoint_path)
    print("Saved checkpoint:", checkpoint_path)

    if ema is None:
        return

    averaged_model = WaveNet(
        scalar_input=is_scalar_input(hparams.input_type))
    averaged_model = torch.nn.DataParallel(averaged_model).to(device)
    averaged_model = clone_as_averaged_model(averaged_model, model, ema)

    ema_file_name = "checkpoint_step{:09d}_ema.pth".format(global_step)
    checkpoint_path = join(checkpoint_dir, ema_file_name)
    ema_payload = {"model": averaged_model.state_dict(), **bookkeeping}
    torch.save(ema_payload, checkpoint_path)
    print("Saved averaged checkpoint:", checkpoint_path)
示例#10
0
    def __init__(self, dim_in, dim_out, hparams):
        """Build the wrapped WaveNet from ``hparams``.

        :param dim_in:   Not used by this constructor.
        :param dim_out:  Not used by this constructor.
        :param hparams:  Hyper-parameter container providing
                         ``len_in_out_multiplier``, ``input_type`` and the
                         attributes forwarded to ``WaveNet`` below.
        """
        super().__init__()

        self.len_in_out_multiplier = hparams.len_in_out_multiplier

        # Use the wavenet_vocoder builder to create the model.
        # These attribute names are identical to the WaveNet keyword names.
        copied = (
            "out_channels",
            "layers",
            "stacks",
            "residual_channels",
            "gate_channels",
            "skip_out_channels",
            "kernel_size",
            "dropout",
            "weight_normalization",
            "cin_channels",
            "gin_channels",
            "n_speakers",
            "upsample_conditional_features",
            "upsample_scales",
            "freq_axis_kernel_size",
        )
        wavenet_kwargs = {name: getattr(hparams, name) for name in copied}
        # scalar_input is derived from the input type rather than copied.
        wavenet_kwargs["scalar_input"] = is_scalar_input(hparams.input_type)
        wavenet_kwargs["use_speaker_embedding"] = hparams.use_speaker_embedding
        self.model = WaveNet(**wavenet_kwargs)
示例#11
0
def wavenet(
    out_channels=256,
    layers=20,
    stacks=2,
    residual_channels=512,
    gate_channels=512,
    skip_out_channels=512,
    cin_channels=-1,
    gin_channels=-1,
    weight_normalization=True,
    dropout=1 - 0.95,
    kernel_size=3,
    n_speakers=None,
    upsample_conditional_features=False,
    upsample_scales=None,
    freq_axis_kernel_size=3,
    scalar_input=False,
    use_speaker_embedding=True,
    legacy=True,
    use_gaussian=False,
):
    """Construct a ``wavenet_vocoder.WaveNet`` from the given hyper-parameters.

    Every argument is forwarded unchanged as the keyword argument of the same
    name.

    Args:
        upsample_scales: Forwarded to ``WaveNet``. ``None`` (the default) is
            replaced by a fresh ``[16, 16]`` list on every call, so the
            default is never shared between calls.

    Returns:
        The constructed ``WaveNet`` instance.
    """
    from wavenet_vocoder import WaveNet

    if upsample_scales is None:
        upsample_scales = [16, 16]

    model = WaveNet(
        out_channels=out_channels,
        layers=layers,
        stacks=stacks,
        residual_channels=residual_channels,
        gate_channels=gate_channels,
        skip_out_channels=skip_out_channels,
        kernel_size=kernel_size,
        dropout=dropout,
        weight_normalization=weight_normalization,
        cin_channels=cin_channels,
        gin_channels=gin_channels,
        n_speakers=n_speakers,
        upsample_conditional_features=upsample_conditional_features,
        upsample_scales=upsample_scales,
        freq_axis_kernel_size=freq_axis_kernel_size,
        scalar_input=scalar_input,
        use_speaker_embedding=use_speaker_embedding,
        legacy=legacy,
        use_gaussian=use_gaussian,
    )

    return model
示例#12
0
def wavenet(out_channels=256,
            layers=20,
            stacks=2,
            residual_channels=512,
            gate_channels=512,
            skip_out_channels=512,
            cin_channels=-1,
            gin_channels=-1,
            weight_normalization=True,
            dropout=1 - 0.95,
            kernel_size=3,
            n_speakers=None,
            upsample_conditional_features=False,
            upsample_scales=None,
            freq_axis_kernel_size=3,
            scalar_input=False,
            modal="se",
            modal_N=8,
            modal_stride=0,
            body_hidden_size=64,
            body_out_channels=32,
            ):
    """Construct a ``wavenet_vocoder.WaveNet`` from the given hyper-parameters.

    Every argument is forwarded unchanged as the keyword argument of the same
    name.

    Args:
        upsample_scales: Forwarded to ``WaveNet``. ``None`` (the default) is
            replaced by a fresh ``[16, 16]`` list on every call, so the
            default is never shared between calls.

    Returns:
        The constructed ``WaveNet`` instance.
    """
    from wavenet_vocoder import WaveNet

    if upsample_scales is None:
        upsample_scales = [16, 16]

    model = WaveNet(out_channels=out_channels, layers=layers, stacks=stacks,
                    residual_channels=residual_channels,
                    gate_channels=gate_channels,
                    skip_out_channels=skip_out_channels,
                    kernel_size=kernel_size, dropout=dropout,
                    weight_normalization=weight_normalization,
                    cin_channels=cin_channels, gin_channels=gin_channels,
                    n_speakers=n_speakers,
                    upsample_conditional_features=upsample_conditional_features,
                    upsample_scales=upsample_scales,
                    freq_axis_kernel_size=freq_axis_kernel_size,
                    scalar_input=scalar_input,
                    modal=modal,
                    modal_N=modal_N,
                    modal_stride=modal_stride,
                    body_hidden_size=body_hidden_size,
                    body_out_channels=body_out_channels,
                    )

    return model
def build_catae_model():
    """Validate the module-level ``hparams`` and build a CatWavAE around a WaveNet.

    ``hparams.upsample_params`` is updated in place with ``cin_channels`` and
    ``cin_pad`` before being passed to the WaveNet.

    Returns:
        The constructed ``CatWavAE`` instance.

    Raises:
        RuntimeError: If the input type is mu-law quantized and
            ``out_channels`` differs from ``quantize_channels``.
    """
    if is_mulaw_quantize(hparams.input_type) and hparams.out_channels != hparams.quantize_channels:
        raise RuntimeError(
            "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'"
        )

    # Warn only; the model is still built.
    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        warn("Upsample conv layers were specified while local conditioning disabled. "
             "Notice that upsample conv layers will never be used.")

    upsample_cfg = hparams.upsample_params
    for key in ("cin_channels", "cin_pad"):
        upsample_cfg[key] = getattr(hparams, key)

    # These hparams attributes map one-to-one onto WaveNet keyword names.
    copied = (
        "out_channels",
        "layers",
        "stacks",
        "residual_channels",
        "gate_channels",
        "skip_out_channels",
        "cin_channels",
        "gin_channels",
        "n_speakers",
        "dropout",
        "kernel_size",
        "cin_pad",
        "upsample_conditional_features",
    )
    net_kwargs = {key: getattr(hparams, key) for key in copied}
    net_kwargs["upsample_params"] = upsample_cfg
    net_kwargs["scalar_input"] = is_scalar_input(hparams.input_type)
    net_kwargs["output_distribution"] = hparams.output_distribution
    net_kwargs["use_speaker_embedding"] = True
    wavenet = WaveNet(**net_kwargs)

    autoencoder_kwargs = dict(
        wavenet=wavenet,
        c_in=39,
        hid=hparams.cin_channels,
        tau=0.1,
        k=hparams.K,
        frame_rate=hparams.frame_rate,
        hard=hparams.hard,
        slices=hparams.num_slices,
    )
    return CatWavAE(**autoencoder_kwargs)
示例#14
0
def build_model(hparams_json=None):
    """Build a WaveNet from hyper-parameters.

    Args:
        hparams_json: Optional path to a JSON file. When given, its content
            is loaded into a new ``HParams`` object that is used for this
            call. When None, the module-level ``hparams`` is used
            (NOTE(review): assumes ``hparams`` is available at file level,
            as in the sibling builders -- confirm).

    Returns:
        The constructed ``WaveNet`` instance.

    Raises:
        RuntimeError: If the input type is mu-law quantized and
            ``out_channels`` differs from ``quantize_channels``.
    """
    # Use a separate local name: assigning to `hparams` here would make it
    # local to the whole function, so every call without `hparams_json`
    # failed with UnboundLocalError.
    if hparams_json is None:
        hp = hparams
    else:
        with open(hparams_json, 'r') as jf:
            hp = HParams(**json.load(jf))

    if is_mulaw_quantize(hp.input_type):
        if hp.out_channels != hp.quantize_channels:
            raise RuntimeError(
                "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'")
    if hp.upsample_conditional_features and hp.cin_channels < 0:
        s = "Upsample conv layers were specified while local conditioning disabled. "
        s += "Notice that upsample conv layers will never be used."
        warn(s)

    # The upsample parameter dict of the selected hparams is updated in place.
    upsample_params = hp.upsample_params
    upsample_params["cin_channels"] = hp.cin_channels
    upsample_params["cin_pad"] = hp.cin_pad
    # A speaker embedding is requested only when global conditioning is enabled.
    use_speaker_embedding = hp.gin_channels > 0
    model = WaveNet(
        out_channels=hp.out_channels,
        layers=hp.layers,
        stacks=hp.stacks,
        residual_channels=hp.residual_channels,
        gate_channels=hp.gate_channels,
        skip_out_channels=hp.skip_out_channels,
        cin_channels=hp.cin_channels,
        gin_channels=hp.gin_channels,
        n_speakers=hp.n_speakers,
        dropout=hp.dropout,
        kernel_size=hp.kernel_size,
        cin_pad=hp.cin_pad,
        upsample_conditional_features=hp.upsample_conditional_features,
        upsample_net=hp.upsample_net,
        upsample_params=upsample_params,
        scalar_input=is_scalar_input(hp.input_type),
        use_speaker_embedding=use_speaker_embedding,
        output_distribution=hp.output_distribution,
    )
    return model
示例#15
0
        s = "Upsample conv layers were specified while local conditioning disabled. "
        s += "Notice that upsample conv layers will never be used."
        warn(s)

    upsample_params = hparams.upsample_params
    upsample_params["cin_channels"] = hparams.cin_channels
    upsample_params["cin_pad"] = hparams.cin_pad
    model = WaveNet(
        out_channels=hparams.out_channels,
        layers=hparams.layers,
        stacks=hparams.stacks,
        residual_channels=hparams.residual_channels,
        gate_channels=hparams.gate_channels,
        skip_out_channels=hparams.skip_out_channels,
        cin_channels=hparams.cin_channels,
        gin_channels=hparams.gin_channels,
        n_speakers=hparams.n_speakers,
        dropout=hparams.dropout,
        kernel_size=hparams.kernel_size,
        cin_pad=hparams.cin_pad,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_params=upsample_params,
        scalar_input=is_scalar_input(hparams.input_type),
        output_distribution=hparams.output_distribution,
    )
    loss_net = NetWithLossClass(model, hparams)
    lr = get_lr(hparams.optimizer_params["lr"], hparams.nepochs,
                step_size_per_epoch)
    lr = Tensor(lr)

    if args.checkpoint != '':
        param_dict = load_checkpoint(args.pre_trained_model_path)
示例#16
0
 def __init__(self, wavenet_args):
     """Create the wrapped WaveNet.

     :param wavenet_args: Mapping of keyword arguments passed to ``WaveNet``.
     """
     super().__init__()
     decoder_net = WaveNet(**wavenet_args)
     self.wavenet = decoder_net
示例#17
0
class WaveNetWrapper(nn.Module):
    """A wrapper around r9y9's WaveNet implementation to integrate it seamlessly into the framework."""
    IDENTIFIER = "r9y9WaveNet"

    class Config:
        """Hyper-parameters from which a :class:`WaveNetWrapper` is built."""
        INPUT_TYPE_MULAW = "mulaw-quantize"
        INPUT_TYPE_RAW = "raw"

        def __init__(
                self,
                cin_channels=80,
                dropout=0.05,
                freq_axis_kernel_size=3,
                gate_channels=512,
                gin_channels=-1,
                hinge_regularizer=True,  # Only used in MoL prediction (INPUT_TYPE_RAW).
                kernel_size=3,
                layers=24,
                log_scale_min=float(np.log(1e-14)),  # Only used in INPUT_TYPE_RAW.
                n_speakers=1,
                out_channels=256,  # Use num_mixtures * 3 (pi, mean, log_scale) for INPUT_TYPE_RAW.
                residual_channels=512,
                scalar_input=is_scalar_input(INPUT_TYPE_MULAW),
                skip_out_channels=256,
                stacks=4,
                upsample_conditional_features=False,
                upsample_scales=None,
                use_speaker_embedding=False,
                weight_normalization=True,
                legacy=False,
                len_in_out_multiplier=1):
            """Store all hyper-parameters as attributes of the same name.

            :param upsample_scales:        ``None`` (the default) is replaced by a
                                           fresh ``[5, 4, 2]`` list so the default
                                           is not shared between Config instances.
            :param len_in_out_multiplier:  Factor applied to the input length to
                                           obtain the number of frames generated
                                           at inference time.
                                           NOTE(review): the default of 1 assumes
                                           the conditioning already has the output
                                           resolution -- confirm for setups that
                                           upsample inside the model.
            """
            self.cin_channels = cin_channels
            self.dropout = dropout
            self.freq_axis_kernel_size = freq_axis_kernel_size
            self.gate_channels = gate_channels
            self.gin_channels = gin_channels
            self.hinge_regularizer = hinge_regularizer
            self.kernel_size = kernel_size
            self.layers = layers
            self.log_scale_min = log_scale_min
            self.n_speakers = n_speakers
            self.out_channels = out_channels
            self.residual_channels = residual_channels
            self.scalar_input = scalar_input
            self.skip_out_channels = skip_out_channels
            self.stacks = stacks
            self.upsample_conditional_features = upsample_conditional_features
            self.upsample_scales = [5, 4, 2] if upsample_scales is None else upsample_scales
            self.use_speaker_embedding = use_speaker_embedding
            self.weight_normalization = weight_normalization
            self.legacy = legacy
            self.len_in_out_multiplier = len_in_out_multiplier

        def create_model(self):
            """Return a new :class:`WaveNetWrapper` built from this config."""
            return WaveNetWrapper(self)

    def __init__(self, config):
        """Build the wrapped WaveNet from ``config``.

        :param config: Object exposing the attributes forwarded to ``WaveNet``
                       below. ``len_in_out_multiplier`` is optional and falls
                       back to 1 when the object does not define it.
        """
        super().__init__()

        # forward() reads this attribute during inference; it was never set
        # before, which made every inference call fail with AttributeError.
        self.len_in_out_multiplier = getattr(config, "len_in_out_multiplier", 1)

        # Use the wavenet_vocoder builder to create the model.
        self.model = WaveNet(
            out_channels=config.out_channels,
            layers=config.layers,
            stacks=config.stacks,
            residual_channels=config.residual_channels,
            gate_channels=config.gate_channels,
            skip_out_channels=config.skip_out_channels,
            kernel_size=config.kernel_size,
            dropout=config.dropout,
            weight_normalization=config.weight_normalization,
            cin_channels=config.cin_channels,
            gin_channels=config.gin_channels,
            n_speakers=config.n_speakers,
            upsample_conditional_features=config.upsample_conditional_features,
            upsample_scales=config.upsample_scales,
            freq_axis_kernel_size=config.freq_axis_kernel_size,
            scalar_input=config.scalar_input,
            use_speaker_embedding=config.use_speaker_embedding,
            legacy=config.legacy
        )

        # Cleared by forward() once the model has been switched to generation mode.
        self.has_weight_norm = True

    def forward(self, input_, target, seq_lengths, *_):
        """Run the wrapped model with teacher forcing or in generation mode.

        :param input_:       Passed to the model as local conditioning ``c``.
        :param target:       When not None, the model is run on it with
                             ``input_`` as conditioning (teacher forcing).
                             When None, output is generated incrementally.
        :param seq_lengths:  Only read during inference, where it must contain
                             exactly one entry.
        :param _:            Additional positional arguments are ignored.
        :return:             Tuple ``(output, None)``.
        """
        if target is not None:  # During training and testing with teacher forcing.
            assert self.has_weight_norm, "Model has been used for generation " \
                "and weight norm was removed, cannot continue training. Remove"\
                " the make_generation_fast_() call to continue training after" \
                " generation."
            output = self.model(target, c=input_, g=None, softmax=False)
            # Output shape is B x C x T. Don't permute here because CrossEntropyLoss requires the same shape.
        else:  # During inference.
            with torch.no_grad():
                self.model.make_generation_fast_()  # After calling this the training cannot be continued.
                self.has_weight_norm = False
                assert len(seq_lengths) == 1, "Batch synth is not supported."
                num_frames_to_gen = seq_lengths[0] * self.len_in_out_multiplier
                output = self.model.incremental_forward(
                    c=input_, T=num_frames_to_gen, softmax=True, quantize=True)

        # Output shape is B x C x T.
        return output, None

    def set_gpu_flag(self, use_gpu):
        """Store ``use_gpu`` on the instance."""
        self.use_gpu = use_gpu

    def init_hidden(self, batch_size=1):
        """Return None; this wrapper keeps no hidden state."""
        return None

    def parameters(self, recurse=True):
        """Return the parameters of the wrapped model.

        ``recurse`` is accepted and forwarded so that this override can be
        called with the same signature as ``nn.Module.parameters``; calling it
        without arguments behaves exactly as before.
        """
        return self.model.parameters(recurse=recurse)