class VideoAudioGenerator(nn.Module):
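    """Joint audio-video generator.

    Encodes a mel-spectrogram into downsampled content codes, then decodes
    them back to a mel-spectrogram (Decoder + Postnet), to video frames
    (VideoGenerator or STAGE2_G), and to a waveform via a WaveRNN vocoder.
    """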
    def __init__(self,
                 dim_neck,
                 dim_emb,
                 dim_pre,
                 freq,
                 dim_spec=80,
                 is_train=False,
                 lr=0.001,
                 multigpu=False,
                 lambda_wavenet=0.001,
                 args=None,
                 residual=False,
                 attention_map=None,
                 use_256=False,
                 loss_content=False,
                 test_path=None):
        super(VideoAudioGenerator, self).__init__()

        self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec)

        self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec)
        self.postnet = Postnet(num_mel=dim_spec)
        if use_256:
            self.video_decoder = VideoGenerator(use_256=True)
        else:
            self.video_decoder = STAGE2_G(residual=residual)
        self.use_256 = use_256
        self.lambda_wavenet = lambda_wavenet
        self.loss_content = loss_content
        self.multigpu = multigpu
        self.test_path = test_path

        self.vocoder = WaveRNN(rnn_dims=hparams.voc_rnn_dims,
                               fc_dims=hparams.voc_fc_dims,
                               bits=hparams.bits,
                               pad=hparams.voc_pad,
                               upsample_factors=hparams.voc_upsample_factors,
                               feat_dims=hparams.num_mels,
                               compute_dims=hparams.voc_compute_dims,
                               res_out_dims=hparams.voc_res_out_dims,
                               res_blocks=hparams.voc_res_blocks,
                               hop_length=hparams.hop_size,
                               sample_rate=hparams.sample_rate,
                               mode=hparams.voc_mode)

        if is_train:
            self.criterionIdt = torch.nn.L1Loss(reduction='mean')
            self.opt_encoder = torch.optim.Adam(self.encoder.parameters(),
                                                lr=lr)
            self.opt_decoder = torch.optim.Adam(itertools.chain(
                self.decoder.parameters(), self.postnet.parameters()),
                                                lr=lr)
            self.opt_video_decoder = torch.optim.Adam(
                self.video_decoder.parameters(), lr=lr)

            self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(),
                                                lr=hparams.voc_lr)
            self.vocoder_loss_func = F.cross_entropy  # Only for RAW

        if multigpu:
            self.encoder = nn.DataParallel(self.encoder)
            self.decoder = nn.DataParallel(self.decoder)
            self.video_decoder = nn.DataParallel(self.video_decoder)
            self.postnet = nn.DataParallel(self.postnet)
            self.vocoder = nn.DataParallel(self.vocoder)

    def optimize_parameters_video(self,
                                  dataloader,
                                  epochs,
                                  device,
                                  display_freq=10,
                                  save_freq=1000,
                                  save_dir="./",
                                  experimentName="Train",
                                  initial_niter=0,
                                  load_model=None):
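        """Jointly train the encoder, audio decoder, video decoder, and vocoder.

        Each batch contributes an L1 video loss (one- or two-stage depending on
        use_256), L1 mel-reconstruction losses before and after the postnet, a
        WaveRNN cross-entropy loss weighted by lambda_wavenet, and optionally
        an L1 content loss between the codes of the input and those of the
        reconstruction. Checkpoints and TensorBoard test samples are written
        every save_freq iterations.
        """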
        writer = SummaryWriter(log_dir="logs/" + experimentName)
        if load_model is not None:
            print("Loading from %s..." % load_model)
            # self.load_state_dict(torch.load(load_model))
            d = torch.load(load_model)
            newdict = d.copy()
            for key, value in d.items():
                newkey = key
                # Older checkpoints name the vocoder 'wavenet'; rename those keys.
                if 'wavenet' in key:
                    newdict[key.replace('wavenet',
                                        'vocoder')] = newdict.pop(key)
                    newkey = key.replace('wavenet', 'vocoder')
                # DataParallel prefixes each submodule's parameters with 'module.'.
                if self.multigpu and 'module' not in key:
                    newdict[newkey.replace('.', '.module.',
                                           1)] = newdict.pop(newkey)
                    newkey = newkey.replace('.', '.module.', 1)
                # Drop keys that have no counterpart in the current model.
                if newkey not in self.state_dict():
                    newdict.pop(newkey)
            print("Load " + str(len(newdict)) + " parameters!")
            self.load_state_dict(newdict, strict=False)
            print("AutoVC Model Loaded")
        niter = initial_niter
        for epoch in range(epochs):
            self.train()
            for i, data in enumerate(dataloader):
                # print("Processing ..." + str(name))
                speaker, mel, prev, wav, video, video_large = data
                speaker, mel, prev, wav, video, video_large = speaker.to(
                    device), mel.to(device), prev.to(device), wav.to(
                        device), video.to(device), video_large.to(device)
                codes, code_unsample = self.encoder(mel,
                                                    speaker,
                                                    return_unsample=True)

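                # Upsample each downsampled content code to span an equal
                # share of the mel time axis before decoding.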
                tmp = []
                for code in codes:
                    tmp.append(
                        code.unsqueeze(1).expand(-1,
                                                 int(mel.size(1) / len(codes)),
                                                 -1))
                code_exp = torch.cat(tmp, dim=1)

                if not self.use_256:
                    v_stage1, v_stage2 = self.video_decoder(code_unsample,
                                                            train=True)
                else:
                    v_stage2 = self.video_decoder(code_unsample)
                mel_outputs = self.decoder(code_exp)
                mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
                mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(
                    2, 1)

                if self.loss_content:
                    _, recons_codes = self.encoder(mel_outputs_postnet,
                                                   speaker,
                                                   return_unsample=True)
                    loss_content = self.criterionIdt(code_unsample,
                                                     recons_codes)
                else:
                    loss_content = torch.tensor(0.0, device=device)

                if not self.use_256:
                    loss_video = self.criterionIdt(v_stage1,
                                                   video) + self.criterionIdt(
                                                       v_stage2, video_large)
                else:
                    loss_video = self.criterionIdt(v_stage2, video_large)

                loss_recon = self.criterionIdt(mel, mel_outputs)
                loss_recon0 = self.criterionIdt(mel, mel_outputs_postnet)

                if not self.multigpu:
                    y_hat = self.vocoder(
                        prev,
                        self.vocoder.pad_tensor(mel_outputs_postnet,
                                                hparams.voc_pad).transpose(
                                                    1, 2))
                else:
                    y_hat = self.vocoder(
                        prev,
                        self.vocoder.module.pad_tensor(
                            mel_outputs_postnet,
                            hparams.voc_pad).transpose(1, 2))
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
                # assert (0 <= wav < 2 ** 9).all()
                loss_vocoder = self.vocoder_loss_func(
                    y_hat,
                    wav.unsqueeze(-1).to(device))
                self.opt_vocoder.zero_grad()

                loss = (loss_video + loss_recon + loss_recon0 +
                        self.lambda_wavenet * loss_vocoder + loss_content)

                self.opt_encoder.zero_grad()
                self.opt_decoder.zero_grad()
                self.opt_video_decoder.zero_grad()
                loss.backward()
                self.opt_encoder.step()
                self.opt_decoder.step()
                self.opt_video_decoder.step()
                self.opt_vocoder.step()

                if niter % display_freq == 0:
                    print("Epoch[%d] Iter[%d] Niter[%d] %s" %
                          (epoch, i, niter, loss.data.item()))
                    writer.add_scalars(
                        'data/Loss', {
                            'loss':
                            loss.data.item(),
                            'loss_video':
                            loss_video.data.item(),
                            'loss_audio':
                            loss_recon0.data.item() + loss_recon.data.item()
                        }, niter)

                if niter % save_freq == 0:
                    torch.cuda.empty_cache()  # Prevent Out of Memory
                    print("Saving and Testing...", end='\t')
                    torch.save(
                        self.state_dict(),
                        save_dir + '/Epoch' + str(epoch).zfill(3) + '_Iter' +
                        str(niter).zfill(8) + ".pkl")
                    # self.load_state_dict(torch.load('params.pkl'))
                    self.test_audiovideo(device, writer, niter)
                    print("Done")
                    self.train()
                torch.cuda.empty_cache()  # Prevent Out of Memory
                niter += 1

    def generate(self, mel, speaker, device='cuda:0'):
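        """Generate a postnet-refined mel-spectrogram and video frames.

        Returns a tuple (mel_outputs_postnet, v_stage1, v_stage2); when
        use_256 is set, the single-stage video output fills both video slots.
        """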
        mel, speaker = mel.to(device), speaker.to(device)
        if not self.multigpu:
            codes, code_unsample = self.encoder(mel,
                                                speaker,
                                                return_unsample=True)
        else:
            codes, code_unsample = self.encoder.module(mel,
                                                       speaker,
                                                       return_unsample=True)

        tmp = []
        for code in codes:
            tmp.append(
                code.unsqueeze(1).expand(-1, int(mel.size(1) / len(codes)),
                                         -1))
        code_exp = torch.cat(tmp, dim=1)

        if not self.multigpu:
            if not self.use_256:
                v_stage1, v_stage2 = self.video_decoder(code_unsample,
                                                        train=True)
            else:
                v_stage2 = self.video_decoder(code_unsample)
                v_stage1 = v_stage2
            mel_outputs = self.decoder(code_exp)
            mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
        else:
            if not self.use_256:
                v_stage1, v_stage2 = self.video_decoder.module(code_unsample,
                                                               train=True)
            else:
                v_stage2 = self.video_decoder.module(code_unsample)
                v_stage1 = v_stage2
            mel_outputs = self.decoder.module(code_exp)
            mel_outputs_postnet = self.postnet.module(
                mel_outputs.transpose(2, 1))

        mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1)

        return mel_outputs_postnet, v_stage1, v_stage2

    def test_video(self, device):
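        """Encode a fixed test clip and decode video frames from its codes.

        Returns the ground-truth frames (normalized to [-1, 1]) together with
        the first coarse (v_mid) and refined (v_hat) generated clips.
        """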
        # Keyword arguments keep this compatible with newer librosa releases.
        wav, sr = librosa.load(
            "/mnt/lustre/dengkangle/cmu/datasets/video/obama_test.mp4",
            sr=hparams.sample_rate)
        mel_basis = librosa.filters.mel(sr=hparams.sample_rate,
                                        n_fft=hparams.n_fft,
                                        n_mels=hparams.num_mels)
        linear_spec = np.abs(
            librosa.stft(wav,
                         n_fft=hparams.n_fft,
                         hop_length=hparams.hop_size,
                         win_length=hparams.win_size))
        mel_spec = mel_basis.dot(linear_spec)
        # Floor the magnitudes to avoid log10(0) -> -inf warnings.
        mel_db = 20 * np.log10(np.maximum(mel_spec, 1e-10))

        test_data = np.clip((mel_db + 120) / 125, 0, 1)
        test_data = torch.Tensor(pad_seq(test_data.T,
                                         hparams.freq)).unsqueeze(0).to(device)
        with torch.no_grad():
            # Unwrap DataParallel only when it is actually in use; the original
            # code accessed .module unconditionally and broke without multigpu.
            encoder = self.encoder.module if self.multigpu else self.encoder
            video_decoder = (self.video_decoder.module
                             if self.multigpu else self.video_decoder)
            codes, code_exp = encoder(test_data, return_unsample=True)
            v_mid, v_hat = video_decoder(code_exp, train=True)

        reader = imageio.get_reader(
            "/mnt/lustre/dengkangle/cmu/datasets/video/obama_test.mp4",
            'ffmpeg',
            fps=20)
        frames = []
        for im in reader:
            frames.append(np.array(im).transpose(2, 0, 1))
        # Normalize uint8 frames from [0, 255] to [-1, 1] to match the
        # generator output range.
        frames = (np.array(frames) / 255 - 0.5) / 0.5
        return frames, v_mid[0:1], v_hat[0:1]

    def test_audiovideo(self, device, writer, niter):
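        """Generate audio and video from self.test_path and log them.

        Writes the intermediate and refined videos plus the ground-truth and
        vocoded waveforms to TensorBoard at global step niter.
        """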
        source_path = self.test_path

        mel_basis80 = librosa.filters.mel(sr=hparams.sample_rate,
                                          n_fft=hparams.n_fft,
                                          n_mels=80)

        wav, sr = librosa.load(source_path, sr=hparams.sample_rate)
        wav = preemphasis(wav, hparams.preemphasis, hparams.preemphasize)

        linear_spec = np.abs(
            librosa.stft(wav,
                         n_fft=hparams.n_fft,
                         hop_length=hparams.hop_size,
                         win_length=hparams.win_size))
        mel_spec = mel_basis80.dot(linear_spec)
        # Floor the magnitudes to avoid log10(0) -> -inf warnings.
        mel_db = 20 * np.log10(np.maximum(mel_spec, 1e-10))
        source_spec = np.clip((mel_db + 120) / 125, 0, 1)

        source_embed = torch.from_numpy(np.array([0, 1])).float().unsqueeze(0)
        source_wav = wav

        source_spec = torch.Tensor(pad_seq(source_spec.T,
                                           hparams.freq)).unsqueeze(0)
        # print(source_spec.shape)

        with torch.no_grad():
            generated_spec, v_mid, v_hat = self.generate(
                source_spec, source_embed, device)

        generated_spec, v_mid, v_hat = (generated_spec.cpu(), v_mid.cpu(),
                                        v_hat.cpu())

        print("Generating Wavfile...")
        with torch.no_grad():
            if not self.multigpu:
                generated_wav = inv_preemphasis(
                    self.vocoder.generate(generated_spec.to(device).transpose(
                        2, 1),
                                          False,
                                          None,
                                          None,
                                          mu_law=True), hparams.preemphasis,
                    hparams.preemphasize)

            else:
                generated_wav = inv_preemphasis(
                    self.vocoder.module.generate(
                        generated_spec.to(device).transpose(2, 1),
                        False,
                        None,
                        None,
                        mu_law=True), hparams.preemphasis,
                    hparams.preemphasize)

        writer.add_video('generated', (v_hat.numpy() + 1) / 2,
                         global_step=niter)
        writer.add_video('mid', (v_mid.numpy() + 1) / 2, global_step=niter)
        writer.add_audio('ground_truth',
                         source_wav,
                         niter,
                         sample_rate=hparams.sample_rate)
        writer.add_audio('generated_wav',
                         generated_wav,
                         niter,
                         sample_rate=hparams.sample_rate)
class Generator(nn.Module):
    """Generator network."""
    def __init__(self,
                 dim_neck,
                 dim_emb,
                 dim_pre,
                 freq,
                 dim_spec=80,
                 is_train=False,
                 lr=0.001,
                 loss_content=True,
                 discriminator=False,
                 multigpu=False,
                 lambda_gan=0.0001,
                 lambda_wavenet=0.001,
                 args=None,
                 test_path_source=None,
                 test_path_target=None):
        super(Generator, self).__init__()

        self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec)
        self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec)
        self.postnet = Postnet(num_mel=dim_spec)

        if discriminator:
            # num_speakers and use_lsgan are assumed to be module-level
            # globals defined alongside this class.
            self.dis = PatchDiscriminator(n_class=num_speakers)
            self.dis_criterion = GANLoss(use_lsgan=use_lsgan,
                                         tensor=torch.cuda.FloatTensor)
        else:
            self.dis = None

        self.loss_content = loss_content
        self.lambda_gan = lambda_gan
        self.lambda_wavenet = lambda_wavenet

        self.multigpu = multigpu
        self.prepare_test(dim_spec, test_path_source, test_path_target)

        self.vocoder = WaveRNN(rnn_dims=hparams.voc_rnn_dims,
                               fc_dims=hparams.voc_fc_dims,
                               bits=hparams.bits,
                               pad=hparams.voc_pad,
                               upsample_factors=hparams.voc_upsample_factors,
                               feat_dims=hparams.num_mels,
                               compute_dims=hparams.voc_compute_dims,
                               res_out_dims=hparams.voc_res_out_dims,
                               res_blocks=hparams.voc_res_blocks,
                               hop_length=hparams.hop_size,
                               sample_rate=hparams.sample_rate,
                               mode=hparams.voc_mode)

        if is_train:
            self.criterionIdt = torch.nn.L1Loss(reduction='mean')
            self.opt_encoder = torch.optim.Adam(self.encoder.parameters(),
                                                lr=lr)
            self.opt_decoder = torch.optim.Adam(itertools.chain(
                self.decoder.parameters(), self.postnet.parameters()),
                                                lr=lr)
            if discriminator:
                self.opt_dis = torch.optim.Adam(self.dis.parameters(), lr=lr)
            self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(),
                                                lr=hparams.voc_lr)
            self.vocoder_loss_func = F.cross_entropy  # Only for RAW

        if multigpu:
            self.encoder = nn.DataParallel(self.encoder)
            self.decoder = nn.DataParallel(self.decoder)
            self.postnet = nn.DataParallel(self.postnet)
            self.vocoder = nn.DataParallel(self.vocoder)
            if self.dis is not None:
                self.dis = nn.DataParallel(self.dis)

    def prepare_test(self, dim_spec, source_path=None, target_path=None):
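        """Precompute fixed test inputs for source and target speakers.

        Caches normalized mel-spectrograms, two-class one-hot speaker
        embeddings, and preemphasized waveforms on self. dim_spec is
        currently unused.
        """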
        if source_path is None:
            source_path = "/mnt/lustre/dengkangle/cmu/datasets/audio/test/trump_02.wav"
        if target_path is None:
            target_path = "/mnt/lustre/dengkangle/cmu/datasets/audio/test/female.wav"
        # source_path = "/home/kangled/datasets/audio/Chaplin_01.wav"
        # target_path = "/home/kangled/datasets/audio/Obama_01.wav"

        mel_basis80 = librosa.filters.mel(sr=hparams.sample_rate,
                                          n_fft=hparams.n_fft,
                                          n_mels=80)

        wav, sr = librosa.load(source_path, sr=hparams.sample_rate)
        wav = preemphasis(wav, hparams.preemphasis, hparams.preemphasize)
        linear_spec = np.abs(
            librosa.stft(wav,
                         n_fft=hparams.n_fft,
                         hop_length=hparams.hop_size,
                         win_length=hparams.win_size))
        mel_spec = mel_basis80.dot(linear_spec)
        # Floor the magnitudes to avoid log10(0) -> -inf warnings.
        mel_db = 20 * np.log10(np.maximum(mel_spec, 1e-10))
        source_spec = np.clip((mel_db + 120) / 125, 0, 1)
        # source_spec = mel_spec

        self.source_embed = torch.from_numpy(np.array([0, 1
                                                       ])).float().unsqueeze(0)
        self.source_wav = wav

        wav, sr = librosa.load(target_path, sr=hparams.sample_rate)
        wav = preemphasis(wav, hparams.preemphasis, hparams.preemphasize)
        linear_spec = np.abs(
            librosa.stft(wav,
                         n_fft=hparams.n_fft,
                         hop_length=hparams.hop_size,
                         win_length=hparams.win_size))
        mel_spec = mel_basis80.dot(linear_spec)
        mel_db = 20 * np.log10(np.maximum(mel_spec, 1e-10))
        target_spec = np.clip((mel_db + 120) / 125, 0, 1)
        # target_spec = mel_spec

        self.target_embed = torch.from_numpy(np.array([1, 0
                                                       ])).float().unsqueeze(0)
        self.target_wav = wav

        self.source_spec = torch.Tensor(pad_seq(source_spec.T,
                                                hparams.freq)).unsqueeze(0)
        self.target_spec = torch.Tensor(pad_seq(target_spec.T,
                                                hparams.freq)).unsqueeze(0)

    def test_fixed(self, device):
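        """Convert the fixed test pair in both directions.

        Returns a dict with the real waveforms, Griffin-Lim reconstructions of
        the converted spectrograms, and WaveRNN-vocoded conversions, plus the
        sample rate.
        """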
        with torch.no_grad():
            t2s_spec = self.conversion(self.target_embed, self.source_embed,
                                       self.target_spec, device).cpu()
            s2s_spec = self.conversion(self.source_embed, self.source_embed,
                                       self.source_spec, device).cpu()
            s2t_spec = self.conversion(self.source_embed, self.target_embed,
                                       self.source_spec, device).cpu()
            t2t_spec = self.conversion(self.target_embed, self.target_embed,
                                       self.target_spec, device).cpu()

        ret_dic = {}
        ret_dic['A_fake_griffin'], sr = mel2wav(s2t_spec.numpy().squeeze(0).T)
        ret_dic['B_fake_griffin'], sr = mel2wav(t2s_spec.numpy().squeeze(0).T)
        ret_dic['A'] = self.source_wav
        ret_dic['B'] = self.target_wav

        with torch.no_grad():
            if not self.multigpu:
                ret_dic['A_fake_w'] = inv_preemphasis(
                    self.vocoder.generate(s2t_spec.to(device).transpose(2, 1),
                                          False,
                                          None,
                                          None,
                                          mu_law=True), hparams.preemphasis,
                    hparams.preemphasize)
                ret_dic['B_fake_w'] = inv_preemphasis(
                    self.vocoder.generate(t2s_spec.to(device).transpose(2, 1),
                                          False,
                                          None,
                                          None,
                                          mu_law=True), hparams.preemphasis,
                    hparams.preemphasize)
            else:
                ret_dic['A_fake_w'] = inv_preemphasis(
                    self.vocoder.module.generate(s2t_spec.to(device).transpose(
                        2, 1),
                                                 False,
                                                 None,
                                                 None,
                                                 mu_law=True),
                    hparams.preemphasis, hparams.preemphasize)
                ret_dic['B_fake_w'] = inv_preemphasis(
                    self.vocoder.module.generate(t2s_spec.to(device).transpose(
                        2, 1),
                                                 False,
                                                 None,
                                                 None,
                                                 mu_law=True),
                    hparams.preemphasis, hparams.preemphasize)
        return ret_dic, sr

    def conversion(self, speaker_org, speaker_trg, spec, device, speed=1):
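        """Re-synthesize spec through the content bottleneck.

        The encoder is conditioned on speaker_org; because Decoder was built
        with dim_emb=0, speaker_trg does not enter the decoding and is kept
        only for API symmetry. speed rescales the upsampled code length to
        time-stretch the output.
        """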
        speaker_org, speaker_trg, spec = speaker_org.to(
            device), speaker_trg.to(device), spec.to(device)
        if not self.multigpu:
            codes = self.encoder(spec, speaker_org)
        else:
            codes = self.encoder.module(spec, speaker_org)
        tmp = []
        for code in codes:
            tmp.append(
                code.unsqueeze(1).expand(
                    -1, int(speed * spec.size(1) / len(codes)), -1))
        code_exp = torch.cat(tmp, dim=1)
        # encoder_outputs (code_exp concatenated with speaker_trg) was
        # computed here but never used: Decoder was built with dim_emb=0 and
        # decodes from the content codes alone.
        mel_outputs = self.decoder(
            code_exp) if not self.multigpu else self.decoder.module(code_exp)

        mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1)
        return mel_outputs_postnet

    def optimize_parameters(self,
                            dataloader,
                            epochs,
                            device,
                            display_freq=10,
                            save_freq=1000,
                            save_dir="./",
                            experimentName="Train",
                            load_model=None,
                            initial_niter=0):
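        """Train the voice-conversion model end to end.

        Runs train_step over the dataloader, logs the per-term loss
        dictionaries to TensorBoard every display_freq iterations, and saves a
        checkpoint plus fixed-pair conversion samples every save_freq
        iterations.
        """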
        writer = SummaryWriter(log_dir="logs/" + experimentName)
        if load_model is not None:
            print("Loading from %s..." % load_model)
            # self.load_state_dict(torch.load(load_model))
            d = torch.load(load_model)
            newdict = d.copy()
            for key, value in d.items():
                newkey = key
                # Older checkpoints name the vocoder 'wavenet'; rename those keys.
                if 'wavenet' in key:
                    newdict[key.replace('wavenet',
                                        'vocoder')] = newdict.pop(key)
                    newkey = key.replace('wavenet', 'vocoder')
                # DataParallel prefixes each submodule's parameters with 'module.'.
                if self.multigpu and 'module' not in key:
                    newdict[newkey.replace('.', '.module.',
                                           1)] = newdict.pop(newkey)
                    newkey = newkey.replace('.', '.module.', 1)
                # Drop keys that have no counterpart in the current model.
                if newkey not in self.state_dict():
                    newdict.pop(newkey)
            print("Load " + str(len(newdict)) + " parameters!")
            # strict=False: the pruned checkpoint may not cover every parameter.
            self.load_state_dict(newdict, strict=False)
            print("AutoVC Model Loaded")
        niter = initial_niter
        for epoch in range(epochs):
            self.train()
            for i, data in enumerate(dataloader):
                speaker_org, spec, prev, wav = data
                loss_dict, loss_dict_discriminator, loss_dict_wavenet = \
                    self.train_step(spec.to(device), speaker_org.to(device), prev=prev.to(device), wav=wav.to(device), device=device)
                if niter % display_freq == 0:
                    print("Epoch[%d] Iter[%d] Niter[%d] %s %s %s" %
                          (epoch, i, niter, loss_dict, loss_dict_discriminator,
                           loss_dict_wavenet))
                    writer.add_scalars('data/Loss', loss_dict, niter)
                    if loss_dict_discriminator != {}:
                        writer.add_scalars('data/discriminator',
                                           loss_dict_discriminator, niter)
                    if loss_dict_wavenet != {}:
                        writer.add_scalars('data/wavenet', loss_dict_wavenet,
                                           niter)
                if niter % save_freq == 0:
                    print("Saving and Testing...", end='\t')
                    torch.save(
                        self.state_dict(),
                        save_dir + '/Epoch' + str(epoch).zfill(3) + '_Iter' +
                        str(niter).zfill(8) + ".pkl")
                    # self.load_state_dict(torch.load('params.pkl'))
                    if len(dataloader) >= 2:
                        wav_dic, sr = self.test_fixed(device)
                        for key, wav in wav_dic.items():
                            # print(wav.shape)
                            writer.add_audio(key, wav, niter, sample_rate=sr)
                    print("Done")
                    self.train()
                torch.cuda.empty_cache()  # Prevent Out of Memory
                niter += 1

    def train_step(self,
                   x,
                   c_org,
                   mask=None,
                   mask_code=None,
                   prev=None,
                   wav=None,
                   ret_content=False,
                   retain_graph=False,
                   device='cuda:0'):
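        """One optimization step on a batch.

        Computes mel-reconstruction losses before and after the postnet, an
        optional content loss, optional discriminator/generator losses on a
        speaker-flipped conversion, and the WaveRNN loss; steps the relevant
        optimizers and returns the per-term loss dictionaries.
        """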
        codes = self.encoder(x, c_org)
        # print(codes[0].shape)
        content = torch.cat([code.unsqueeze(1) for code in codes], dim=1)
        # print("content shape", content.shape)
        tmp = []
        for code in codes:
            tmp.append(
                code.unsqueeze(1).expand(-1, int(x.size(1) / len(codes)), -1))
        code_exp = torch.cat(tmp, dim=1)

        # encoder_outputs (code_exp concatenated with c_org) was computed here
        # but never used: Decoder was built with dim_emb=0 and decodes from
        # the content codes alone.
        mel_outputs = self.decoder(code_exp)

        mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1))
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1)

        loss_dict, loss_dict_discriminator, loss_dict_wavenet = {}, {}, {}

        loss_recon = self.criterionIdt(x, mel_outputs)
        loss_recon0 = self.criterionIdt(x, mel_outputs_postnet)
        loss_dict['recon'], loss_dict['recon0'] = loss_recon.data.item(
        ), loss_recon0.data.item()

        if self.loss_content:
            recons_codes = self.encoder(mel_outputs_postnet, c_org)
            recons_content = torch.cat(
                [code.unsqueeze(1) for code in recons_codes], dim=1)
            if mask is not None:
                # masked_select requires a bool mask in current PyTorch.
                loss_content = self.criterionIdt(
                    content.masked_select(mask_code.bool()),
                    recons_content.masked_select(mask_code.bool()))
            else:
                loss_content = self.criterionIdt(content, recons_content)
            loss_dict['content'] = loss_content.data.item()
        else:
            loss_content = torch.tensor(0.0, device=x.device)

        loss_gen = loss_dis = loss_vocoder = torch.tensor(0.0, device=x.device)
        fake_mel = None
        if self.dis is not None:
            # true_label = torch.from_numpy(np.ones(shape=(x.shape[0]))).to('cuda:0').long()
            # false_label = torch.from_numpy(np.zeros(shape=(x.shape[0]))).to('cuda:0').long()

            flip_speaker = 1 - c_org
            fake_mel = self.conversion(c_org, flip_speaker, x, device)

            loss_dis = self.dis_criterion(self.dis(x),
                                          True) + self.dis_criterion(
                                              self.dis(fake_mel), False)
            # +  self.dis_criterion(self.dis(mel_outputs_postnet), False)

            self.opt_dis.zero_grad()
            loss_dis.backward(retain_graph=True)
            self.opt_dis.step()
            loss_gen = self.dis_criterion(self.dis(fake_mel), True)
            # + self.dis_criterion(self.dis(mel_outputs_postnet), True)
            loss_dict_discriminator['dis'], loss_dict_discriminator[
                'gen'] = loss_dis.data.item(), loss_gen.data.item()

        if not self.multigpu:
            y_hat = self.vocoder(
                prev,
                self.vocoder.pad_tensor(mel_outputs_postnet,
                                        hparams.voc_pad).transpose(1, 2))
        else:
            y_hat = self.vocoder(
                prev,
                self.vocoder.module.pad_tensor(mel_outputs_postnet,
                                               hparams.voc_pad).transpose(
                                                   1, 2))
        y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
        # assert (0 <= wav < 2 ** 9).all()
        loss_vocoder = self.vocoder_loss_func(y_hat,
                                              wav.unsqueeze(-1).to(device))
        loss_dict_wavenet['vocoder'] = loss_vocoder.data.item()
        self.opt_vocoder.zero_grad()

        Loss = loss_recon + loss_recon0 + loss_content + \
               self.lambda_gan * loss_gen + self.lambda_wavenet * loss_vocoder
        loss_dict['total'] = Loss.data.item()
        self.opt_encoder.zero_grad()
        self.opt_decoder.zero_grad()
        Loss.backward(retain_graph=retain_graph)
        self.opt_encoder.step()
        self.opt_decoder.step()
        # Clip vocoder gradients (bound chosen as the fp16 max, 65504).
        torch.nn.utils.clip_grad_norm_(self.vocoder.parameters(), 65504.0)
        self.opt_vocoder.step()

        if ret_content:
            return loss_recon, loss_recon0, loss_content, Loss, content
        return loss_dict, loss_dict_discriminator, loss_dict_wavenet