Ejemplo n.º 1
0
def build_model(config, from_style, to_style):
    """Create the four CycleGAN networks and restore or initialize their weights.

    Checkpoints are looked up in ``<config.checkpoint_dir>/<from_style>2<to_style>/``
    under the names ``generator_ab_<E>.pth``, ``generator_ba_<E>.pth``,
    ``discriminator_a_<E>.pth`` and ``discriminator_b_<E>.pth`` with
    ``E = config.epoch - 1``. The subdirectory is created if it is missing.
    If none of the four files exist, every network is initialized with
    ``weights_init`` instead.

    Args:
        config: Configuration object providing ``image_size``,
            ``num_residual_blocks``, ``checkpoint_dir`` and ``epoch``.
        from_style: Source style name (part of the checkpoint subdirectory).
        to_style: Target style name (part of the checkpoint subdirectory).

    Returns:
        Tuple ``(generator_ab, generator_ba, discriminator_a, discriminator_b)``,
        placed on ``cuda:0`` when CUDA is available, otherwise on the CPU.

    Raises:
        FileNotFoundError: If only some of the four checkpoint files for
            epoch ``config.epoch - 1`` exist, so a consistent restore is
            impossible.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    generator_ab = ResidualGenerator(config.image_size,
                                     config.num_residual_blocks).to(device)
    generator_ba = ResidualGenerator(config.image_size,
                                     config.num_residual_blocks).to(device)
    discriminator_a = Discriminator(config.image_size).to(device)
    discriminator_b = Discriminator(config.image_size).to(device)

    # Insertion order matches the original init/load order (keeps RNG use
    # by weights_init in the same sequence).
    networks = {
        "generator_ab": generator_ab,
        "generator_ba": generator_ba,
        "discriminator_a": discriminator_a,
        "discriminator_b": discriminator_b,
    }

    ckpt_dir = os.path.join(config.checkpoint_dir, f"{from_style}2{to_style}")
    print(f"[*] Load checkpoint in {config.checkpoint_dir}")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(ckpt_dir, exist_ok=True)

    # Check the exact files instead of "is the directory non-empty": a
    # directory holding other epochs' checkpoints used to crash with
    # IndexError. os.path.isfile also avoids glob metacharacter issues.
    ckpt_paths = {
        name: os.path.join(ckpt_dir, f"{name}_{config.epoch - 1}.pth")
        for name in networks
    }
    missing = [path for path in ckpt_paths.values() if not os.path.isfile(path)]

    if len(missing) == len(ckpt_paths):
        print(f"[!] No checkpoint in {config.checkpoint_dir}")
        for network in networks.values():
            network.apply(weights_init)
    elif missing:
        raise FileNotFoundError(
            f"Incomplete checkpoint in {ckpt_dir}; missing: {', '.join(missing)}")
    else:
        for name, network in networks.items():
            network.load_state_dict(
                torch.load(ckpt_paths[name], map_location=device))

    return generator_ab, generator_ba, discriminator_a, discriminator_b
Ejemplo n.º 2
0
def train(args,
          generator: Generator,
          discriminator: Discriminator,
          feature_extractor: FeatureExtractor,
          photo_dataloader,
          edge_smooth_dataloader,
          animation_dataloader,
          checkpoint_dir=None):
    """Run the two-phase adversarial training loop.

    Phase 1 ("initialization") updates only ``generator``. It minimizes a
    weighted L1 distance between the ``feature_extractor`` features of each
    photo and those of ``generator(photo)``.

    Phase 2 alternates two updates. The discriminator update uses BCE with
    animation images as target 1 and edge-smoothed and generated images as
    target 0. The generator update uses adversarial BCE towards 1 plus the
    same weighted content loss.

    Args:
        args: Namespace providing ``device``, ``lr``, ``adam_beta``,
            ``n_epochs``, ``n_init_epoch``, ``batch_size``,
            ``content_loss_weight``, ``logging_steps`` and ``save_steps``.
        generator: Network being trained; applied to photo batches.
        discriminator: Network whose outputs are fed to ``nn.BCELoss``
            (presumably sigmoid-activated -- TODO confirm).
        feature_extractor: Network used for the content loss. Images are
            rescaled with ``(x + 1) / 2`` before being passed to it
            (assumes inputs lie in [-1, 1] -- TODO confirm).
        photo_dataloader, edge_smooth_dataloader, animation_dataloader:
            Sized iterables yielding ``(images, _)`` pairs.
        checkpoint_dir: Location handed to ``load``/``save``. When truthy,
            training resumes from it if loading succeeds.

    On ``KeyboardInterrupt`` the current state is saved and the function
    returns normally. The TensorBoard writer is always closed on exit.
    """
    tb_writer = SummaryWriter()

    gen_criterion = nn.BCELoss().to(args.device)
    disc_criterion = nn.BCELoss().to(args.device)
    content_criterion = nn.L1Loss().to(args.device)

    gen_optimizer = torch.optim.Adam(generator.parameters(),
                                     lr=args.lr,
                                     betas=(args.adam_beta, 0.999))
    disc_optimizer = torch.optim.Adam(discriminator.parameters(),
                                      lr=args.lr,
                                      betas=(args.adam_beta, 0.999))

    global_step = 0
    global_init_step = 0

    # Number of leading batches to skip in the first resumed epoch of each
    # phase (they were already trained before the checkpoint was written).
    skipped_step = 0
    skipped_init_step = 0

    cur_epoch = 0
    cur_init_epoch = 0

    data_len = min(len(photo_dataloader), len(edge_smooth_dataloader),
                   len(animation_dataloader))

    if checkpoint_dir:
        try:
            checkpoint_dict = load(checkpoint_dir)
            generator.load_state_dict(checkpoint_dict['generator'])
            discriminator.load_state_dict(checkpoint_dict['discriminator'])
            gen_optimizer.load_state_dict(checkpoint_dict['gen_optimizer'])
            disc_optimizer.load_state_dict(checkpoint_dict['disc_optimizer'])
            global_step = checkpoint_dict['global_step']
            global_init_step = checkpoint_dict['global_init_step']

            cur_epoch = global_step // data_len
            cur_init_epoch = global_init_step // len(photo_dataloader)

            skipped_step = global_step % data_len
            skipped_init_step = global_init_step % len(photo_dataloader)

            logger.info("Start training with,")
            logger.info("In initialization step, epoch: %d, step: %d",
                        cur_init_epoch, skipped_init_step)
            logger.info("In main train step, epoch: %d, step: %d", cur_epoch,
                        skipped_step)
        except Exception as err:
            # Best effort: continue from scratch, but report why (the former
            # bare ``except:`` also swallowed KeyboardInterrupt/SystemExit).
            logger.warning("Wrong checkpoint path: %s", err)

    t_total = data_len * args.n_epochs
    t_init_total = len(photo_dataloader) * args.n_init_epoch

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num photo examples = %d", len(photo_dataloader))
    logger.info("  Num edge_smooth examples = %d", len(edge_smooth_dataloader))
    logger.info("  Num animation examples = %d", len(animation_dataloader))

    logger.info("  Num Epochs = %d", args.n_epochs)
    logger.info("  Total train batch size = %d", args.batch_size)
    logger.info("  Total optimization steps = %d", t_total)

    logger.info("  Num Init Epochs = %d", args.n_init_epoch)
    logger.info("  Total Init optimization steps = %d", t_init_total)

    logger.info("  Logging steps = %d", args.logging_steps)
    logger.info("  Save steps = %d", args.save_steps)

    # Pre-bind the loop variables so the KeyboardInterrupt handler can log
    # them even if the interrupt arrives before (or instead of) a loop.
    init_epoch = cur_init_epoch
    epoch = cur_epoch
    init_phase = True
    try:
        generator.train()
        discriminator.train()

        # Running sums cover only the steps executed in this run, so they are
        # averaged over this run's step count (not the resumed global step).
        init_loss_sum = 0.0
        init_steps_run = 0
        # --- Initialization Content loss
        mb = master_bar(range(cur_init_epoch, args.n_init_epoch))
        for init_epoch in mb:
            epoch_iter = progress_bar(photo_dataloader, parent=mb)
            for photo, _ in epoch_iter:
                if skipped_init_step > 0:
                    skipped_init_step -= 1
                    continue

                photo = photo.to(args.device)

                gen_optimizer.zero_grad()
                x_features = feature_extractor((photo + 1) / 2).detach()
                Gx = generator(photo)
                Gx_features = feature_extractor((Gx + 1) / 2)

                content_loss = args.content_loss_weight * content_criterion(
                    Gx_features, x_features)
                content_loss.backward()
                gen_optimizer.step()

                init_loss_sum += content_loss.item()
                init_steps_run += 1
                global_init_step += 1

                if args.save_steps > 0 and global_init_step % args.save_steps == 0:
                    logger.info(
                        "Save Initialization Phase, init_epoch: %d, init_step: %d",
                        init_epoch, global_init_step)
                    save(checkpoint_dir, global_step, global_init_step,
                         generator, discriminator, gen_optimizer,
                         disc_optimizer)

                if args.logging_steps > 0 and global_init_step % args.logging_steps == 0:
                    tb_writer.add_scalar('Initialization Phase/Content Loss',
                                         content_loss.item(),
                                         global_init_step)
                    tb_writer.add_scalar(
                        'Initialization Phase/Global Generator Loss',
                        init_loss_sum / init_steps_run, global_init_step)

                    logger.info(
                        "Initialization Phase, Epoch: %d, Global Step: %d, Content Loss: %.4f",
                        init_epoch, global_init_step,
                        init_loss_sum / init_steps_run)

        # -----------------------------------------------------
        logger.info("Finish Initialization Phase, save model...")
        save(checkpoint_dir, global_step, global_init_step, generator,
             discriminator, gen_optimizer, disc_optimizer)

        init_phase = False
        loss_d_sum = 0.0
        loss_g_sum = 0.0
        loss_content_sum = 0.0
        train_steps_run = 0

        mb = master_bar(range(cur_epoch, args.n_epochs))
        for epoch in mb:
            # zip stops after data_len batches (the shortest loader); passing
            # ``total`` gives the bar its length without materializing every
            # batch of all three loaders in a list.
            epoch_iter = progress_bar(zip(animation_dataloader,
                                          edge_smooth_dataloader,
                                          photo_dataloader),
                                      total=data_len,
                                      parent=mb)
            for (animation, _), (edge_smoothed, _), (photo, _) in epoch_iter:
                if skipped_step > 0:
                    skipped_step -= 1
                    continue

                animation = animation.to(args.device)
                edge_smoothed = edge_smoothed.to(args.device)
                photo = photo.to(args.device)

                disc_optimizer.zero_grad()
                # --- Train discriminator
                # ------ Train Discriminator with animation image
                animation_disc = discriminator(animation)
                animation_target = torch.ones_like(animation_disc)
                loss_animation_disc = disc_criterion(animation_disc,
                                                     animation_target)

                # ------ Train Discriminator with edge image
                edge_smoothed_disc = discriminator(edge_smoothed)
                edge_smoothed_target = torch.zeros_like(edge_smoothed_disc)
                loss_edge_disc = disc_criterion(edge_smoothed_disc,
                                                edge_smoothed_target)

                # ------ Train Discriminator with generated image
                # detach(): this step must not push gradients into the generator.
                generated_image = generator(photo).detach()

                generated_image_disc = discriminator(generated_image)
                generated_image_target = torch.zeros_like(generated_image_disc)
                loss_generated_disc = disc_criterion(generated_image_disc,
                                                     generated_image_target)

                loss_disc = loss_animation_disc + loss_edge_disc + loss_generated_disc

                loss_disc.backward()
                disc_optimizer.step()

                loss_d_sum += loss_disc.item()

                # --- Train Generator
                gen_optimizer.zero_grad()

                generated_image = generator(photo)

                generated_image_disc = discriminator(generated_image)
                generated_image_target = torch.ones_like(generated_image_disc)
                loss_adv = gen_criterion(generated_image_disc,
                                         generated_image_target)

                # ------ Train Generator with content loss
                x_features = feature_extractor((photo + 1) / 2).detach()
                Gx_features = feature_extractor((generated_image + 1) / 2)

                loss_content = args.content_loss_weight * content_criterion(
                    Gx_features, x_features)

                loss_gen = loss_adv + loss_content
                loss_gen.backward()
                gen_optimizer.step()

                loss_g_sum += loss_adv.item()
                loss_content_sum += loss_content.item()
                train_steps_run += 1
                global_step += 1

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    logger.info("Save Training Phase, epoch: %d, step: %d",
                                epoch, global_step)
                    save(checkpoint_dir, global_step, global_init_step,
                         generator, discriminator, gen_optimizer,
                         disc_optimizer)

                # Gate on the main-phase step counter (global_init_step is
                # frozen during this phase).
                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    tb_writer.add_scalar('Train Phase/Generator Loss',
                                         loss_adv.item(), global_step)
                    tb_writer.add_scalar('Train Phase/Discriminator Loss',
                                         loss_disc.item(), global_step)
                    tb_writer.add_scalar('Train Phase/Content Loss',
                                         loss_content.item(), global_step)
                    tb_writer.add_scalar('Train Phase/Global Generator Loss',
                                         loss_g_sum / train_steps_run,
                                         global_step)
                    tb_writer.add_scalar(
                        'Train Phase/Global Discriminator Loss',
                        loss_d_sum / train_steps_run, global_step)
                    tb_writer.add_scalar('Train Phase/Global Content Loss',
                                         loss_content_sum / train_steps_run,
                                         global_step)

                    logger.info(
                        "Training Phase, Epoch: %d, Global Step: %d, Disc Loss %.4f, Gen Loss %.4f, Content Loss: %.4f",
                        epoch, global_step, loss_d_sum / train_steps_run,
                        loss_g_sum / train_steps_run,
                        loss_content_sum / train_steps_run)

    except KeyboardInterrupt:

        if init_phase:
            logger.info("KeyboardInterrupt in Initialization Phase!")
            logger.info("Save models, init_epoch: %d, init_step: %d",
                        init_epoch, global_init_step)
        else:
            logger.info("KeyboardInterrupt in Training Phase!")
            logger.info("Save models, epoch: %d, step: %d", epoch, global_step)

        save(checkpoint_dir, global_step, global_init_step, generator,
             discriminator, gen_optimizer, disc_optimizer)
    finally:
        tb_writer.close()
Ejemplo n.º 3
0
class MaskCycleGANVCTrainer(object):
    """Trainer for MaskCycleGAN-VC
    """
    def __init__(self, args):
        """Load both speakers' data and build the networks and optimizers.

        Args:
            args (Namespace): Program arguments from argparser. Besides the
                hyper-parameters copied onto ``self`` below, it must provide
                ``origin_data_dir``, ``output_data_dir``,
                ``preprocessed_data_dir``, ``speaker_A_id``, ``speaker_B_id``,
                ``num_frames`` and ``max_mask_len``.
        """
        # Store args
        self.num_epochs = args.num_epochs
        self.start_epoch = args.start_epoch
        self.generator_lr = args.generator_lr
        self.discriminator_lr = args.discriminator_lr
        self.decay_after = args.decay_after
        self.mini_batch_size = args.batch_size
        self.cycle_loss_lambda = args.cycle_loss_lambda
        self.identity_loss_lambda = args.identity_loss_lambda
        self.device = args.device
        self.epochs_per_save = args.epochs_per_save
        self.sample_rate = args.sample_rate
        self.validation_A_dir = os.path.join(args.origin_data_dir,
                                             args.speaker_A_id)
        self.output_A_dir = os.path.join(args.output_data_dir,
                                         args.speaker_A_id)
        self.validation_B_dir = os.path.join(args.origin_data_dir,
                                             args.speaker_B_id)
        self.output_B_dir = os.path.join(args.output_data_dir,
                                         args.speaker_B_id)
        self.infer_data_dir = args.infer_data_dir
        self.pretrain_models = args.pretrain_models

        # Initialize speakerA's dataset
        self.dataset_A = self.loadPickleFile(
            os.path.join(
                args.preprocessed_data_dir,
                args.speaker_A_id,
                f"{args.speaker_A_id}_normalized.pickle",
            ))
        # np.load on an .npz returns an NpzFile that keeps the file open;
        # the context manager closes it once the arrays have been read.
        with np.load(
                os.path.join(
                    args.preprocessed_data_dir,
                    args.speaker_A_id,
                    f"{args.speaker_A_id}_norm_stat.npz",
                )) as dataset_A_norm_stats:
            self.dataset_A_mean = dataset_A_norm_stats["mean"]
            self.dataset_A_std = dataset_A_norm_stats["std"]

        # Initialize speakerB's dataset
        self.dataset_B = self.loadPickleFile(
            os.path.join(
                args.preprocessed_data_dir,
                args.speaker_B_id,
                f"{args.speaker_B_id}_normalized.pickle",
            ))
        with np.load(
                os.path.join(
                    args.preprocessed_data_dir,
                    args.speaker_B_id,
                    f"{args.speaker_B_id}_norm_stat.npz",
                )) as dataset_B_norm_stats:
            self.dataset_B_mean = dataset_B_norm_stats["mean"]
            self.dataset_B_std = dataset_B_norm_stats["std"]

        # Compute lr decay rate: the per-iteration decrement applied by
        # adjust_lr_rate is lr / (num_epochs * iterations_per_epoch).
        self.n_samples = len(self.dataset_A)
        print(f"n_samples = {self.n_samples}")
        # max(1, ...) guards against ZeroDivisionError when the dataset is
        # smaller than one batch or num_epochs is 0.
        decay_steps = max(
            1, self.num_epochs * (self.n_samples // self.mini_batch_size))
        self.generator_lr_decay = self.generator_lr / float(decay_steps)
        self.discriminator_lr_decay = self.discriminator_lr / float(
            decay_steps)
        print(f"generator_lr_decay = {self.generator_lr_decay}")
        print(f"discriminator_lr_decay = {self.discriminator_lr_decay}")

        # Initialize Train Dataloader
        self.num_frames = args.num_frames

        self.dataset = VCDataset(
            datasetA=self.dataset_A,
            datasetB=self.dataset_B,
            n_frames=args.num_frames,
            max_mask_len=args.max_mask_len,
        )
        self.train_dataloader = flow.utils.data.DataLoader(
            dataset=self.dataset,
            batch_size=self.mini_batch_size,
            shuffle=True,
            drop_last=False,
        )

        # Initialize Generators and Discriminators
        self.generator_A2B = Generator().to(self.device)
        self.generator_B2A = Generator().to(self.device)
        self.discriminator_A = Discriminator().to(self.device)
        self.discriminator_B = Discriminator().to(self.device)
        # Discriminator to compute 2 step adversarial loss
        self.discriminator_A2 = Discriminator().to(self.device)
        # Discriminator to compute 2 step adversarial loss
        self.discriminator_B2 = Discriminator().to(self.device)

        # Initialize Optimizers: one Adam over both generators, one over all
        # four discriminators.
        g_params = list(self.generator_A2B.parameters()) + list(
            self.generator_B2A.parameters())
        d_params = (list(self.discriminator_A.parameters()) +
                    list(self.discriminator_B.parameters()) +
                    list(self.discriminator_A2.parameters()) +
                    list(self.discriminator_B2.parameters()))
        self.generator_optimizer = flow.optim.Adam(g_params,
                                                   lr=self.generator_lr,
                                                   betas=(0.5, 0.999))
        self.discriminator_optimizer = flow.optim.Adam(
            d_params, lr=self.discriminator_lr, betas=(0.5, 0.999))

    def adjust_lr_rate(self, optimizer, generator):
        """Decays learning rate.

        Args:
            optimizer (torch.optim): torch optimizer
            generator (bool): Whether to adjust generator lr.
        """
        if generator:
            self.generator_lr = max(
                0.0, self.generator_lr - self.generator_lr_decay)
            for param_groups in optimizer.param_groups:
                param_groups["lr"] = self.generator_lr
        else:
            self.discriminator_lr = max(
                0.0, self.discriminator_lr - self.discriminator_lr_decay)
            for param_groups in optimizer.param_groups:
                param_groups["lr"] = self.discriminator_lr

    def reset_grad(self):
        """Sets gradients of the generators and discriminators to zero before backpropagation.
        """
        self.generator_optimizer.zero_grad()
        self.discriminator_optimizer.zero_grad()

    def loadPickleFile(self, fileName):
        """Loads a Pickle file.

        Args:
            fileName (str): pickle file path

        Returns:
            file object: The loaded pickle file object
        """
        with open(fileName, "rb") as f:
            return pickle.load(f)

    def train(self):
        """Implements the training loop for MaskCycleGAN-VC.

        Runs epochs ``self.start_epoch .. self.num_epochs`` (inclusive). Each
        mini-batch ``(real_A, mask_A, real_B, mask_B)`` gets two updates:

        * one generator update, using adversarial, two-step adversarial,
          cycle and identity losses;
        * one discriminator update, using least-squares real/fake losses and
          the two-step adversarial terms.

        Every ``self.epochs_per_save`` epochs (epoch 0 excluded) it saves a
        checkpoint and converts both validation directories.
        """
        iters_per_epoch = self.n_samples // self.mini_batch_size
        for epoch in range(self.start_epoch, self.num_epochs + 1):

            for i, (real_A, mask_A, real_B,
                    mask_B) in enumerate(self.train_dataloader):
                num_iterations = iters_per_epoch * epoch + i
                # The identity loss is switched off after 10000 iterations.
                if num_iterations > 10000:
                    self.identity_loss_lambda = 0
                if num_iterations > self.decay_after:
                    # Each optimizer must receive its own decayed rate.
                    self.adjust_lr_rate(self.generator_optimizer,
                                        generator=True)
                    self.adjust_lr_rate(self.discriminator_optimizer,
                                        generator=False)

                real_A = real_A.to(self.device, dtype=flow.float)
                mask_A = mask_A.to(self.device, dtype=flow.float)
                real_B = real_B.to(self.device, dtype=flow.float)
                mask_B = mask_B.to(self.device, dtype=flow.float)

                # Train Generator
                self.generator_A2B.train()
                self.generator_B2A.train()
                self.discriminator_A.eval()
                self.discriminator_B.eval()
                self.discriminator_A2.eval()
                self.discriminator_B2.eval()

                # Generator Feed Forward (an all-ones mask is passed for
                # inputs that are not masked)
                fake_B = self.generator_A2B(real_A, mask_A)
                cycle_A = self.generator_B2A(fake_B, flow.ones_like(fake_B))
                fake_A = self.generator_B2A(real_B, mask_B)
                cycle_B = self.generator_A2B(fake_A, flow.ones_like(fake_A))
                identity_A = self.generator_B2A(real_A, flow.ones_like(real_A))
                identity_B = self.generator_A2B(real_B, flow.ones_like(real_B))
                d_fake_A = self.discriminator_A(fake_A)
                d_fake_B = self.discriminator_B(fake_B)

                # For Two Step Adverserial Loss
                d_fake_cycle_A = self.discriminator_A2(cycle_A)
                d_fake_cycle_B = self.discriminator_B2(cycle_B)

                # Generator Cycle Loss
                cycleLoss = flow.mean(flow.abs(real_A - cycle_A)) + flow.mean(
                    flow.abs(real_B - cycle_B))

                # Generator Identity Loss
                identityLoss = flow.mean(
                    flow.abs(real_A - identity_A)) + flow.mean(
                        flow.abs(real_B - identity_B))

                # Generator Loss
                g_loss_A2B = flow.mean((1 - d_fake_B)**2)
                g_loss_B2A = flow.mean((1 - d_fake_A)**2)

                # Generator Two Step Adverserial Loss
                generator_loss_A2B_2nd = flow.mean((1 - d_fake_cycle_B)**2)
                generator_loss_B2A_2nd = flow.mean((1 - d_fake_cycle_A)**2)

                # Total Generator Loss
                g_loss = (g_loss_A2B + g_loss_B2A + generator_loss_A2B_2nd +
                          generator_loss_B2A_2nd +
                          self.cycle_loss_lambda * cycleLoss +
                          self.identity_loss_lambda * identityLoss)

                # Backprop for Generator
                self.reset_grad()
                g_loss.backward()
                self.generator_optimizer.step()

                # Train Discriminator
                self.generator_A2B.eval()
                self.generator_B2A.eval()
                self.discriminator_A.train()
                self.discriminator_B.train()
                self.discriminator_A2.train()
                self.discriminator_B2.train()

                # The generators only supply inputs to the discriminators in
                # this step and only discriminator parameters are updated, so
                # skip building an autograd graph through them.
                with flow.no_grad():
                    generated_A = self.generator_B2A(real_B, mask_B)
                    # For Two Step Adverserial Loss A->B
                    cycled_B = self.generator_A2B(generated_A,
                                                  flow.ones_like(generated_A))
                    generated_B = self.generator_A2B(real_A, mask_A)
                    # For Two Step Adverserial Loss B->A
                    cycled_A = self.generator_B2A(generated_B,
                                                  flow.ones_like(generated_B))

                # Discriminator Feed Forward (same per-discriminator call order
                # as before, so any running statistics update identically)
                d_real_A = self.discriminator_A(real_A)
                d_real_B = self.discriminator_B(real_B)
                d_real_A2 = self.discriminator_A2(real_A)
                d_real_B2 = self.discriminator_B2(real_B)
                d_fake_A = self.discriminator_A(generated_A)
                d_cycled_B = self.discriminator_B2(cycled_B)
                d_fake_B = self.discriminator_B(generated_B)
                d_cycled_A = self.discriminator_A2(cycled_A)

                # Loss Functions
                d_loss_A_real = flow.mean((1 - d_real_A)**2)
                d_loss_A_fake = flow.mean((0 - d_fake_A)**2)
                d_loss_A = (d_loss_A_real + d_loss_A_fake) / 2.0

                d_loss_B_real = flow.mean((1 - d_real_B)**2)
                d_loss_B_fake = flow.mean((0 - d_fake_B)**2)
                d_loss_B = (d_loss_B_real + d_loss_B_fake) / 2.0

                # Two Step Adverserial Loss
                d_loss_A_cycled = flow.mean((0 - d_cycled_A)**2)
                d_loss_B_cycled = flow.mean((0 - d_cycled_B)**2)
                d_loss_A2_real = flow.mean((1 - d_real_A2)**2)
                d_loss_B2_real = flow.mean((1 - d_real_B2)**2)
                d_loss_A_2nd = (d_loss_A2_real + d_loss_A_cycled) / 2.0
                d_loss_B_2nd = (d_loss_B2_real + d_loss_B_cycled) / 2.0

                # Final Loss for discriminator with the Two Step Adverserial Loss
                d_loss = (d_loss_A + d_loss_B) / 2.0 + (d_loss_A_2nd +
                                                        d_loss_B_2nd) / 2.0

                # Backprop for Discriminator
                self.reset_grad()
                d_loss.backward()
                self.discriminator_optimizer.step()

                if (i + 1) % 2 == 0:
                    # .item() converts the scalar tensors to floats so the
                    # {:.4f} specs do not depend on tensor __format__ support.
                    print(
                        "Iter:{} Generator Loss:{:.4f} Discrimator Loss:{:.4f} GA2B:{:.4f} GB2A:{:.4f} G_id:{:.4f} G_cyc:{:.4f} D_A:{:.4f} D_B:{:.4f}"
                        .format(
                            num_iterations,
                            g_loss.item(),
                            d_loss.item(),
                            g_loss_A2B.item(),
                            g_loss_B2A.item(),
                            identityLoss.item(),
                            cycleLoss.item(),
                            d_loss_A.item(),
                            d_loss_B.item(),
                        ))

            # Save each model checkpoint and validation
            if epoch % self.epochs_per_save == 0 and epoch != 0:
                self.saveModelCheckPoint(epoch, PATH="model_checkpoint")
                self.validation_for_A_dir()
                self.validation_for_B_dir()

    def infer(self):
        """Implements the infering loop for MaskCycleGAN-VC.

        Loads ``self.pretrain_models`` via ``self.loadModel``. It then
        converts every file in ``self.infer_data_dir`` from speaker A to
        speaker B at ``self.sample_rate``, in four stages:

        * WORLD-decompose the audio;
        * convert F0 with ``preprocess.pitch_conversion``;
        * map the normalized coded spectral envelope through
          ``generator_A2B``;
        * resynthesize the waveform.

        Each result is written next to its input as ``convert_<name>``.
        """
        # load pretrain models
        self.loadModel(self.pretrain_models)
        # Pure inference: eval mode so any mode-dependent layers use their
        # inference behaviour; the forward pass below runs under no_grad.
        self.generator_A2B.eval()

        num_mcep = 80
        sampling_rate = self.sample_rate
        frame_period = 5.0
        infer_A_dir = self.infer_data_dir

        print("Generating Validation Data B from A...")
        # NOTE(review): outputs land in the input directory, so a later run
        # will also convert earlier ``convert_*`` outputs.
        for file in os.listdir(infer_A_dir):
            filePath = os.path.join(infer_A_dir, file)
            wav, _ = librosa.load(filePath, sr=sampling_rate, mono=True)
            wav = preprocess.wav_padding(wav=wav,
                                         sr=sampling_rate,
                                         frame_period=frame_period,
                                         multiple=4)
            f0, _, sp, ap = preprocess.world_decompose(
                wav=wav, fs=sampling_rate, frame_period=frame_period)
            # NOTE(review): the statistics used below to normalize the coded
            # spectral envelope are also passed as log-F0 statistics here;
            # confirm that is intended.
            f0_converted = preprocess.pitch_conversion(
                f0=f0,
                mean_log_src=self.dataset_A_mean,
                std_log_src=self.dataset_A_std,
                mean_log_target=self.dataset_B_mean,
                std_log_target=self.dataset_B_std,
            )
            coded_sp = preprocess.world_encode_spectral_envelop(
                sp=sp, fs=sampling_rate, dim=num_mcep)
            coded_sp_transposed = coded_sp.T
            coded_sp_norm = (coded_sp_transposed -
                             self.dataset_A_mean) / self.dataset_A_std
            # Add a leading batch dimension of size 1.
            coded_sp_norm = np.array([coded_sp_norm])

            if flow.cuda.is_available():
                coded_sp_norm = flow.tensor(coded_sp_norm).cuda().float()
            else:
                coded_sp_norm = flow.tensor(coded_sp_norm).float()

            with flow.no_grad():
                coded_sp_converted_norm = self.generator_A2B(
                    coded_sp_norm, flow.ones_like(coded_sp_norm))
            coded_sp_converted_norm = coded_sp_converted_norm.cpu().detach(
            ).numpy()
            coded_sp_converted_norm = np.squeeze(coded_sp_converted_norm)
            # De-normalize with speaker B's statistics.
            coded_sp_converted = (
                coded_sp_converted_norm * self.dataset_B_std +
                self.dataset_B_mean)
            coded_sp_converted = coded_sp_converted.T
            coded_sp_converted = np.ascontiguousarray(
                coded_sp_converted).astype(np.double)
            decoded_sp_converted = preprocess.world_decode_spectral_envelop(
                coded_sp=coded_sp_converted, fs=sampling_rate)

            wav_transformed = preprocess.world_speech_synthesis(
                f0=f0_converted[0],
                decoded_sp=decoded_sp_converted,
                ap=ap,
                fs=sampling_rate,
                frame_period=frame_period,
            )

            sf.write(
                os.path.join(infer_A_dir, "convert_" + os.path.basename(file)),
                wav_transformed,
                sampling_rate,
            )

    def validation_for_A_dir(self):
        """Convert every file in ``self.validation_A_dir`` from speaker A to B.

        Audio is loaded and resynthesized at ``self.sample_rate``, the
        trainer's configured rate, which ``infer`` also uses. Each converted
        waveform is written to ``self.output_A_dir`` as
        ``convert_<original name>``. The output directory is created if
        needed.
        """
        num_mcep = 80
        sampling_rate = self.sample_rate
        frame_period = 5.0
        validation_A_dir = self.validation_A_dir
        output_A_dir = self.output_A_dir

        os.makedirs(output_A_dir, exist_ok=True)
        # Inference-time mode for the generator; train() re-selects modes on
        # every iteration, so this does not leak into training.
        self.generator_A2B.eval()

        print("Generating Validation Data B from A...")
        for file in os.listdir(validation_A_dir):
            filePath = os.path.join(validation_A_dir, file)
            wav, _ = librosa.load(filePath, sr=sampling_rate, mono=True)
            wav = preprocess.wav_padding(wav=wav,
                                         sr=sampling_rate,
                                         frame_period=frame_period,
                                         multiple=4)
            f0, _, sp, ap = preprocess.world_decompose(
                wav=wav, fs=sampling_rate, frame_period=frame_period)
            # NOTE(review): the statistics used below to normalize the coded
            # spectral envelope are also passed as log-F0 statistics here;
            # confirm that is intended.
            f0_converted = preprocess.pitch_conversion(
                f0=f0,
                mean_log_src=self.dataset_A_mean,
                std_log_src=self.dataset_A_std,
                mean_log_target=self.dataset_B_mean,
                std_log_target=self.dataset_B_std,
            )
            coded_sp = preprocess.world_encode_spectral_envelop(
                sp=sp, fs=sampling_rate, dim=num_mcep)
            coded_sp_transposed = coded_sp.T
            coded_sp_norm = (coded_sp_transposed -
                             self.dataset_A_mean) / self.dataset_A_std
            # Add a leading batch dimension of size 1.
            coded_sp_norm = np.array([coded_sp_norm])

            if flow.cuda.is_available():
                coded_sp_norm = flow.tensor(coded_sp_norm).cuda().float()
            else:
                coded_sp_norm = flow.tensor(coded_sp_norm).float()

            # No gradients are needed for validation conversion.
            with flow.no_grad():
                coded_sp_converted_norm = self.generator_A2B(
                    coded_sp_norm, flow.ones_like(coded_sp_norm))
            coded_sp_converted_norm = coded_sp_converted_norm.cpu().detach(
            ).numpy()
            coded_sp_converted_norm = np.squeeze(coded_sp_converted_norm)
            # De-normalize with speaker B's statistics.
            coded_sp_converted = (
                coded_sp_converted_norm * self.dataset_B_std +
                self.dataset_B_mean)
            coded_sp_converted = coded_sp_converted.T
            coded_sp_converted = np.ascontiguousarray(
                coded_sp_converted).astype(np.double)
            decoded_sp_converted = preprocess.world_decode_spectral_envelop(
                coded_sp=coded_sp_converted, fs=sampling_rate)

            wav_transformed = preprocess.world_speech_synthesis(
                f0=f0_converted[0],
                decoded_sp=decoded_sp_converted,
                ap=ap,
                fs=sampling_rate,
                frame_period=frame_period,
            )

            sf.write(
                os.path.join(output_A_dir,
                             "convert_" + os.path.basename(file)),
                wav_transformed,
                sampling_rate,
            )

    def validation_for_B_dir(self):
        """Convert every file in ``self.validation_B_dir`` from domain B to A.

        Each directory entry is loaded at 22050 Hz, padded, decomposed with
        ``preprocess.world_decompose``, passed through ``self.generator_B2A``
        (with an all-ones second argument, presumably a mask), re-synthesized,
        and written to ``self.output_B_dir`` as ``convert_<original name>``.
        """
        num_mcep = 80
        sampling_rate = 22050
        frame_period = 5.0
        validation_B_dir = self.validation_B_dir
        output_B_dir = self.output_B_dir

        os.makedirs(output_B_dir, exist_ok=True)

        print("Generating Validation Data A from B...")
        for file in os.listdir(validation_B_dir):
            filePath = os.path.join(validation_B_dir, file)
            wav, _ = librosa.load(filePath, sr=sampling_rate, mono=True)
            wav = preprocess.wav_padding(wav=wav,
                                         sr=sampling_rate,
                                         frame_period=frame_period,
                                         multiple=4)
            f0, _, sp, ap = preprocess.world_decompose(
                wav=wav, fs=sampling_rate, frame_period=frame_period)
            # NOTE(review): the same dataset_*_mean/std statistics are used
            # here for pitch conversion and below for normalizing the coded
            # spectral envelope, and only row 0 of the result is passed to
            # synthesis. Confirm these are the intended log-F0 statistics.
            f0_converted = preprocess.pitch_conversion(
                f0=f0,
                mean_log_src=self.dataset_B_mean,
                std_log_src=self.dataset_B_std,
                mean_log_target=self.dataset_A_mean,
                std_log_target=self.dataset_A_std,
            )
            coded_sp = preprocess.world_encode_spectral_envelop(
                sp=sp, fs=sampling_rate, dim=num_mcep)
            coded_sp_transposed = coded_sp.T
            coded_sp_norm = (coded_sp_transposed -
                             self.dataset_B_mean) / self.dataset_B_std
            # Add a leading batch axis of length 1.
            coded_sp_norm = np.array([coded_sp_norm])

            if flow.cuda.is_available():
                coded_sp_norm = flow.tensor(coded_sp_norm).cuda().float()
            else:
                coded_sp_norm = flow.tensor(coded_sp_norm).float()

            # Pure inference: do not record an autograd graph.
            with flow.no_grad():
                coded_sp_converted_norm = self.generator_B2A(
                    coded_sp_norm, flow.ones_like(coded_sp_norm))
            coded_sp_converted_norm = coded_sp_converted_norm.cpu().detach(
            ).numpy()
            coded_sp_converted_norm = np.squeeze(coded_sp_converted_norm)
            # Undo the normalization with the target-domain statistics.
            coded_sp_converted = (
                coded_sp_converted_norm * self.dataset_A_std +
                self.dataset_A_mean)
            coded_sp_converted = coded_sp_converted.T
            coded_sp_converted = np.ascontiguousarray(
                coded_sp_converted).astype(np.double)
            decoded_sp_converted = preprocess.world_decode_spectral_envelop(
                coded_sp=coded_sp_converted, fs=sampling_rate)

            wav_transformed = preprocess.world_speech_synthesis(
                f0=f0_converted[0],
                decoded_sp=decoded_sp_converted,
                ap=ap,
                fs=sampling_rate,
                frame_period=frame_period,
            )

            sf.write(
                os.path.join(output_B_dir,
                             "convert_" + os.path.basename(file)),
                wav_transformed,
                sampling_rate,
            )

    def saveModelCheckPoint(self, epoch, PATH):
        """Write the state dict of each of the four networks into ``PATH``.

        Files are named ``<network>_<epoch>`` (no extension) and are written
        in the order A2B generator, B2A generator, discriminator A,
        discriminator B.
        """
        checkpoints = [
            (self.generator_A2B, "generator_A2B_%d"),
            (self.generator_B2A, "generator_B2A_%d"),
            (self.discriminator_A, "discriminator_A_%d"),
            (self.discriminator_B, "discriminator_B_%d"),
        ]
        for network, name_template in checkpoints:
            weights = network.state_dict()
            target = os.path.join(PATH, name_template % epoch)
            flow.save(weights, target)

    def loadModel(self, PATH, epoch=None):
        """Load the state dicts of all four networks from ``PATH``.

        Args:
            PATH: checkpoint directory.
            epoch: when given, load the ``<network>_<epoch>`` files that
                ``saveModelCheckPoint`` writes. When ``None`` (the default,
                original behavior), load files named exactly ``<network>``.
        """
        suffix = "" if epoch is None else "_%d" % epoch
        self.generator_A2B.load_state_dict(
            flow.load(os.path.join(PATH, "generator_A2B" + suffix)))
        self.generator_B2A.load_state_dict(
            flow.load(os.path.join(PATH, "generator_B2A" + suffix)))
        self.discriminator_A.load_state_dict(
            flow.load(os.path.join(PATH, "discriminator_A" + suffix)))
        self.discriminator_B.load_state_dict(
            flow.load(os.path.join(PATH, "discriminator_B" + suffix)))
Ejemplo n.º 4
0
class CycleGANTrainr(object):
    """Trainer for a CycleGAN voice-conversion model between domains A and B.

    Builds two generators (``generator_A2B``, ``generator_B2A``) and two
    discriminators. Trains them with least-squares adversarial losses plus
    L1 cycle-consistency and identity losses, periodically saves checkpoints,
    and converts the validation directories with the ``preprocess.world_*``
    analysis/synthesis helpers.
    """

    def __init__(
        self,
        logf0s_normalization,
        mcep_normalization,
        coded_sps_A_norm,
        coded_sps_B_norm,
        model_checkpoint,
        validation_A_dir,
        output_A_dir,
        validation_B_dir,
        output_B_dir,
        restart_training_at=None,
    ):
        """Load training data and statistics, then build models and optimizers.

        Args:
            logf0s_normalization: path passed to ``np.load``. The result must
                provide ``mean_A``, ``std_A``, ``mean_B`` and ``std_B``
                (log-F0 statistics used for pitch conversion).
            mcep_normalization: path passed to ``np.load`` with the same keys,
                holding the coded spectral envelope statistics.
            coded_sps_A_norm: pickle file with the domain-A training features.
            coded_sps_B_norm: pickle file with the domain-B training features.
            model_checkpoint: checkpoint directory (created if missing).
            validation_A_dir: domain-A files converted to B during validation.
            output_A_dir: destination of A->B validation audio (created).
            validation_B_dir: domain-B files converted to A during validation.
            output_B_dir: destination of B->A validation audio (created).
            restart_training_at: currently unused; kept for compatibility.
        """
        self.start_epoch = 0
        self.num_epochs = 200000
        self.mini_batch_size = 10
        self.dataset_A = self.loadPickleFile(coded_sps_A_norm)
        self.dataset_B = self.loadPickleFile(coded_sps_B_norm)
        self.device = flow.device(
            "cuda" if flow.cuda.is_available() else "cpu")

        # Speech parameters: log-F0 statistics for pitch conversion.
        logf0s_normalization = np.load(logf0s_normalization)
        self.log_f0s_mean_A = logf0s_normalization["mean_A"]
        self.log_f0s_std_A = logf0s_normalization["std_A"]
        self.log_f0s_mean_B = logf0s_normalization["mean_B"]
        self.log_f0s_std_B = logf0s_normalization["std_B"]

        # Coded spectral envelope statistics for (de)normalization.
        mcep_normalization = np.load(mcep_normalization)
        self.coded_sps_A_mean = mcep_normalization["mean_A"]
        self.coded_sps_A_std = mcep_normalization["std_A"]
        self.coded_sps_B_mean = mcep_normalization["mean_B"]
        self.coded_sps_B_std = mcep_normalization["std_B"]

        # Generators and discriminators.
        self.generator_A2B = Generator().to(self.device)
        self.generator_B2A = Generator().to(self.device)
        self.discriminator_A = Discriminator().to(self.device)
        self.discriminator_B = Discriminator().to(self.device)

        # Both generators share one optimizer; both discriminators share another.
        g_params = list(self.generator_A2B.parameters()) + list(
            self.generator_B2A.parameters())
        d_params = list(self.discriminator_A.parameters()) + list(
            self.discriminator_B.parameters())

        # Initial learning rates.
        self.generator_lr = 2e-4
        self.discriminator_lr = 1e-4

        # Size of one linear learning-rate decay step (see adjust_lr_rate).
        self.generator_lr_decay = self.generator_lr / 200000
        self.discriminator_lr_decay = self.discriminator_lr / 200000

        # Learning-rate decay starts after this many iterations.
        self.start_decay = 10000

        self.generator_optimizer = flow.optim.Adam(g_params,
                                                   lr=self.generator_lr,
                                                   betas=(0.5, 0.999))
        self.discriminator_optimizer = flow.optim.Adam(
            d_params, lr=self.discriminator_lr, betas=(0.5, 0.999))

        # Directory where checkpoints are written.
        self.modelCheckpoint = model_checkpoint
        os.makedirs(self.modelCheckpoint, exist_ok=True)

        # Validation set parameters.
        self.validation_A_dir = validation_A_dir
        self.output_A_dir = output_A_dir
        os.makedirs(self.output_A_dir, exist_ok=True)
        self.validation_B_dir = validation_B_dir
        self.output_B_dir = output_B_dir
        os.makedirs(self.output_B_dir, exist_ok=True)

        # Per-iteration loss history.
        self.generator_loss_store = []
        self.discriminator_loss_store = []

        # Log file appended to by store_to_file().
        self.file_name = "log_store_non_sigmoid.txt"

    def adjust_lr_rate(self, optimizer, name="generator"):
        """Apply one linear decay step to a learning rate and push it to ``optimizer``.

        ``name == "generator"`` decays ``self.generator_lr``. Any other value
        decays ``self.discriminator_lr``. The new rate is floored at 0 and
        written into every param group of ``optimizer``.
        """
        if name == "generator":
            self.generator_lr = max(
                0.0, self.generator_lr - self.generator_lr_decay)
            for param_groups in optimizer.param_groups:
                param_groups["lr"] = self.generator_lr
        else:
            self.discriminator_lr = max(
                0.0, self.discriminator_lr - self.discriminator_lr_decay)
            for param_groups in optimizer.param_groups:
                param_groups["lr"] = self.discriminator_lr

    def reset_grad(self):
        """Zero the gradients held by both optimizers."""
        self.generator_optimizer.zero_grad()
        self.discriminator_optimizer.zero_grad()

    def train(self):
        """Run training from ``self.start_epoch`` to ``self.num_epochs``.

        Each epoch builds a fresh ``trainingDataset``/``DataLoader``. For every
        mini-batch it performs one generator update and then one discriminator
        update. After ``self.start_decay`` iterations both learning rates decay
        linearly, and after 10000 iterations the identity loss is disabled.
        Every 2000 epochs (excluding epoch 0) it logs the losses, saves a
        checkpoint and converts both validation directories.
        """
        for epoch in range(self.start_epoch, self.num_epochs):
            start_time_epoch = time.time()

            # Loss weights. The identity weight is reset every epoch and
            # zeroed below once past 10000 iterations.
            cycle_loss_lambda = 10
            identity_loss_lambda = 5

            n_samples = len(self.dataset_A)

            dataset = trainingDataset(datasetA=self.dataset_A,
                                      datasetB=self.dataset_B,
                                      n_frames=128)

            train_loader = flow.utils.data.DataLoader(
                dataset=dataset,
                batch_size=self.mini_batch_size,
                shuffle=True,
                drop_last=False,
            )

            # Iterate through the progress bar itself so that it advances.
            pbar = tqdm(enumerate(train_loader), total=len(train_loader))

            for i, (real_A, real_B) in pbar:

                num_iterations = (n_samples //
                                  self.mini_batch_size) * epoch + i

                if num_iterations > 10000:
                    identity_loss_lambda = 0
                if num_iterations > self.start_decay:
                    self.adjust_lr_rate(self.generator_optimizer,
                                        name="generator")
                    # The discriminator's own optimizer must be decayed here;
                    # passing the generator optimizer would overwrite the
                    # generator LR with the discriminator LR.
                    self.adjust_lr_rate(self.discriminator_optimizer,
                                        name="discriminator")

                real_A = real_A.to(self.device).float()
                real_B = real_B.to(self.device).float()

                # ---- Generator forward ----
                fake_B = self.generator_A2B(real_A)
                cycle_A = self.generator_B2A(fake_B)

                fake_A = self.generator_B2A(real_B)
                cycle_B = self.generator_A2B(fake_A)

                identity_A = self.generator_B2A(real_A)
                identity_B = self.generator_A2B(real_B)

                d_fake_A = self.discriminator_A(fake_A)
                d_fake_B = self.discriminator_B(fake_B)

                # NOTE(review): these second-step adversarial outputs are not
                # used in generator_loss below. The forward passes are kept so
                # the training step is unchanged.
                d_fake_cycle_A = self.discriminator_A(cycle_A)
                d_fake_cycle_B = self.discriminator_B(cycle_B)

                # Generator cycle-consistency loss (L1).
                cycle_loss = flow.mean(flow.abs(real_A - cycle_A)) + flow.mean(
                    flow.abs(real_B - cycle_B))

                # Generator identity-mapping loss (L1).
                identity_loss = flow.mean(
                    flow.abs(real_A - identity_A)) + flow.mean(
                        flow.abs(real_B - identity_B))

                # Least-squares adversarial loss: push D(fake) towards 1.
                generator_loss_A2B = flow.mean((1 - d_fake_B)**2)
                generator_loss_B2A = flow.mean((1 - d_fake_A)**2)

                generator_loss = (generator_loss_A2B + generator_loss_B2A +
                                  cycle_loss_lambda * cycle_loss +
                                  identity_loss_lambda * identity_loss)
                self.generator_loss_store.append(generator_loss.item())

                # Backprop for generator.
                self.reset_grad()
                generator_loss.backward()

                self.generator_optimizer.step()

                # ---- Discriminator forward ----
                d_real_A = self.discriminator_A(real_A)
                d_real_B = self.discriminator_B(real_B)

                generated_A = self.generator_B2A(real_B)
                d_fake_A = self.discriminator_A(generated_A)

                # For the second-step adversarial loss.
                cycled_B = self.generator_A2B(generated_A)
                d_cycled_B = self.discriminator_B(cycled_B)

                generated_B = self.generator_A2B(real_A)
                d_fake_B = self.discriminator_B(generated_B)

                # For the second-step adversarial loss.
                cycled_A = self.generator_B2A(generated_B)
                d_cycled_A = self.discriminator_A(cycled_A)

                # Least-squares discriminator losses: real -> 1, fake -> 0.
                d_loss_A_real = flow.mean((1 - d_real_A)**2)
                d_loss_A_fake = flow.mean((0 - d_fake_A)**2)
                d_loss_A = (d_loss_A_real + d_loss_A_fake) / 2.0

                d_loss_B_real = flow.mean((1 - d_real_B)**2)
                d_loss_B_fake = flow.mean((0 - d_fake_B)**2)
                d_loss_B = (d_loss_B_real + d_loss_B_fake) / 2.0

                # Second-step adversarial loss on cycled samples.
                d_loss_A_cycled = flow.mean((0 - d_cycled_A)**2)
                d_loss_B_cycled = flow.mean((0 - d_cycled_B)**2)
                d_loss_A_2nd = (d_loss_A_real + d_loss_A_cycled) / 2.0
                d_loss_B_2nd = (d_loss_B_real + d_loss_B_cycled) / 2.0

                d_loss = (d_loss_A + d_loss_B) / 2.0 + (d_loss_A_2nd +
                                                        d_loss_B_2nd) / 2.0
                self.discriminator_loss_store.append(d_loss.item())

                # Backprop for discriminator.
                self.reset_grad()
                d_loss.backward()

                self.discriminator_optimizer.step()

                if (i + 1) % 2 == 0:
                    # Convert every tensor to a Python float before formatting.
                    pbar.set_description(
                        "Iter:{} Generator Loss:{:.4f} Discrimator Loss:{:.4f} GA2B:{:.4f} GB2A:{:.4f} G_id:{:.4f} G_cyc:{:.4f} D_A:{:.4f} D_B:{:.4f}"
                        .format(
                            num_iterations,
                            generator_loss.item(),
                            d_loss.item(),
                            generator_loss_A2B.item(),
                            generator_loss_B2A.item(),
                            identity_loss.item(),
                            cycle_loss.item(),
                            d_loss_A.item(),
                            d_loss_B.item(),
                        ))

            if epoch % 2000 == 0 and epoch != 0:
                end_time = time.time()
                epoch_msg = "Epoch: {} Generator Loss: {:.4f} Discriminator Loss: {}, Time: {:.2f}\n\n".format(
                    epoch,
                    generator_loss.item(),
                    d_loss.item(),
                    end_time - start_time_epoch,
                )
                self.store_to_file(epoch_msg)
                print(epoch_msg)

                # Save the entire model.
                print("Saving model Checkpoint  ......")
                self.store_to_file("Saving model Checkpoint  ......")
                self.saveModelCheckPoint(epoch, self.modelCheckpoint)
                print("Model Saved!")

                # Validation set.
                validation_start_time = time.time()
                self.validation_for_A_dir()
                self.validation_for_B_dir()
                validation_end_time = time.time()
                validation_msg = "Time taken for validation Set: {}".format(
                    validation_end_time - validation_start_time)
                self.store_to_file(validation_msg)
                print(validation_msg)

    def _convert_directory(self, input_dir, output_dir, a_to_b, prefix=""):
        """Convert every entry of ``input_dir`` and write it into ``output_dir``.

        Shared pipeline of ``infer`` and the validation methods: load at
        16 kHz, pad, WORLD-decompose via ``preprocess``, convert pitch and the
        normalized coded spectral envelope, re-synthesize, and write.

        Args:
            input_dir: directory whose entries are loaded with ``librosa``.
            output_dir: destination directory for the converted audio.
            a_to_b: ``True`` converts A->B with ``generator_A2B``; ``False``
                converts B->A with ``generator_B2A``.
            prefix: string prepended to each output file name.
        """
        num_mcep = 36
        sampling_rate = 16000
        frame_period = 5.0

        if a_to_b:
            generator = self.generator_A2B
            f0_mean_src, f0_std_src = self.log_f0s_mean_A, self.log_f0s_std_A
            f0_mean_trg, f0_std_trg = self.log_f0s_mean_B, self.log_f0s_std_B
            sp_mean_src, sp_std_src = self.coded_sps_A_mean, self.coded_sps_A_std
            sp_mean_trg, sp_std_trg = self.coded_sps_B_mean, self.coded_sps_B_std
        else:
            generator = self.generator_B2A
            f0_mean_src, f0_std_src = self.log_f0s_mean_B, self.log_f0s_std_B
            f0_mean_trg, f0_std_trg = self.log_f0s_mean_A, self.log_f0s_std_A
            sp_mean_src, sp_std_src = self.coded_sps_B_mean, self.coded_sps_B_std
            sp_mean_trg, sp_std_trg = self.coded_sps_A_mean, self.coded_sps_A_std

        for file in os.listdir(input_dir):
            file_path = os.path.join(input_dir, file)
            wav, _ = librosa.load(file_path, sr=sampling_rate, mono=True)
            wav = preprocess.wav_padding(wav=wav,
                                         sr=sampling_rate,
                                         frame_period=frame_period,
                                         multiple=4)
            f0, _, sp, ap = preprocess.world_decompose(
                wav=wav, fs=sampling_rate, frame_period=frame_period)
            f0_converted = preprocess.pitch_conversion(
                f0=f0,
                mean_log_src=f0_mean_src,
                std_log_src=f0_std_src,
                mean_log_target=f0_mean_trg,
                std_log_target=f0_std_trg,
            )
            coded_sp = preprocess.world_encode_spectral_envelop(
                sp=sp, fs=sampling_rate, dim=num_mcep)
            coded_sp_norm = (coded_sp.T - sp_mean_src) / sp_std_src
            # Add a leading batch axis of length 1 and move to the model device.
            coded_sp_norm = flow.tensor(np.array([coded_sp_norm])).float().to(
                self.device)

            # Pure inference: do not record an autograd graph.
            with flow.no_grad():
                converted_norm = generator(coded_sp_norm)
            converted_norm = np.squeeze(converted_norm.cpu().detach().numpy())

            # Undo the normalization with the target-domain statistics.
            coded_sp_converted = converted_norm * sp_std_trg + sp_mean_trg
            coded_sp_converted = np.ascontiguousarray(coded_sp_converted.T)
            decoded_sp_converted = preprocess.world_decode_spectral_envelop(
                coded_sp=coded_sp_converted, fs=sampling_rate)
            wav_transformed = preprocess.world_speech_synthesis(
                f0=f0_converted,
                decoded_sp=decoded_sp_converted,
                ap=ap,
                fs=sampling_rate,
                frame_period=frame_period,
            )

            sf.write(
                os.path.join(output_dir, prefix + os.path.basename(file)),
                wav_transformed,
                sampling_rate,
            )

    def infer(self, PATH="sample"):
        """Convert every file in ``PATH`` from A to B, writing ``convert_<name>`` into ``PATH``.

        NOTE(review): input and output share a directory, so a later call will
        also convert previously written ``convert_*`` files.
        """
        self._convert_directory(PATH, PATH, a_to_b=True, prefix="convert_")

    def validation_for_A_dir(self):
        """Convert ``self.validation_A_dir`` (A->B) into ``self.output_A_dir``."""
        print("Generating Validation Data B from A...")
        self._convert_directory(self.validation_A_dir,
                                self.output_A_dir,
                                a_to_b=True)

    def validation_for_B_dir(self):
        """Convert ``self.validation_B_dir`` (B->A) into ``self.output_B_dir``."""
        print("Generating Validation Data A from B...")
        self._convert_directory(self.validation_B_dir,
                                self.output_B_dir,
                                a_to_b=False)

    def savePickle(self, variable, fileName):
        """Serialize ``variable`` to ``fileName`` with pickle."""
        with open(fileName, "wb") as f:
            pickle.dump(variable, f)

    def loadPickleFile(self, fileName):
        """Deserialize and return the object stored in ``fileName``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only pass
        trusted files.
        """
        with open(fileName, "rb") as f:
            return pickle.load(f)

    def store_to_file(self, doc):
        """Append ``doc`` plus a newline to the log file ``self.file_name``."""
        doc = doc + "\n"
        with open(self.file_name, "a") as myfile:
            myfile.write(doc)

    @staticmethod
    def _checkpoint_file(path, name, epoch=None):
        """Return the checkpoint path for network ``name``.

        The path is ``<path>/<name>_<epoch>``, or ``<path>/<name>`` when
        ``epoch`` is None.
        """
        if epoch is None:
            return os.path.join(path, name)
        return os.path.join(path, "%s_%d" % (name, epoch))

    def _checkpoint_modules(self):
        """Return the ``(name, network)`` pairs that make up a checkpoint, in a fixed order."""
        return (
            ("generator_A2B", self.generator_A2B),
            ("generator_B2A", self.generator_B2A),
            ("discriminator_A", self.discriminator_A),
            ("discriminator_B", self.discriminator_B),
        )

    def saveModelCheckPoint(self, epoch, PATH):
        """Save each network's state dict to ``PATH/<name>_<epoch>``."""
        for name, module in self._checkpoint_modules():
            flow.save(module.state_dict(),
                      self._checkpoint_file(PATH, name, epoch))

    def loadModel(self, PATH, epoch=None):
        """Load every network's state dict from ``PATH``.

        Args:
            PATH: checkpoint directory.
            epoch: when given, load the ``<name>_<epoch>`` files written by
                ``saveModelCheckPoint``. When ``None`` (the default, original
                behavior), load files named exactly ``<name>``.
        """
        for name, module in self._checkpoint_modules():
            module.load_state_dict(
                flow.load(self._checkpoint_file(PATH, name, epoch)))
Ejemplo n.º 5
0
class Solver(object):
    def __init__(self, data_loader, config):
        """Copy hyper-parameters and paths from ``config`` and build the networks.

        Args:
            data_loader: iterable of training batches; ``train`` unpacks each
                batch as ``(x_real, speaker_idx, label)``.
            config: object exposing every attribute read below (for example
                an ``argparse.Namespace``).
        """
        cfg = config
        self.config = cfg
        self.data_loader = data_loader

        # Loss weights.
        self.lambda_cycle = cfg.lambda_cycle
        self.lambda_cls = cfg.lambda_cls
        self.lambda_identity = cfg.lambda_identity

        # Data locations, schedule and optimizer hyper-parameters.
        self.data_dir = cfg.data_dir
        self.test_dir = cfg.test_dir
        self.batch_size = cfg.batch_size
        self.num_iters = cfg.num_iters
        self.num_iters_decay = cfg.num_iters_decay
        self.g_lr = cfg.g_lr
        self.d_lr = cfg.d_lr
        self.c_lr = cfg.c_lr
        self.n_critic = cfg.n_critic
        self.beta1 = cfg.beta1
        self.beta2 = cfg.beta2
        self.resume_iters = cfg.resume_iters

        # Test-time settings; trg_speaker is given as a Python literal string.
        self.pretrain_models = cfg.pretrain_models
        self.sample_dir = cfg.sample_dir
        self.trg_speaker = ast.literal_eval(cfg.trg_speaker)
        self.src_speaker = cfg.src_speaker

        # Device selection and the speaker one-hot encoder.
        use_cuda = flow.cuda.is_available()
        self.device = flow.device("cuda:0" if use_cuda else "cpu")
        self.spk_enc = LabelBinarizer().fit(speakers)

        # Output directories and gradient-penalty switch.
        self.model_save_dir = cfg.model_save_dir
        self.result_dir = cfg.result_dir
        self.use_gradient_penalty = cfg.use_gradient_penalty

        # Logging / sampling / checkpoint / LR-update intervals.
        self.log_step = cfg.log_step
        self.sample_step = cfg.sample_step
        self.model_save_step = cfg.model_save_step
        self.lr_update_step = cfg.lr_update_step

        self.build_model()

    def build_model(self):
        """Create G, D and C on ``self.device`` together with their Adam optimizers.

        Each network is moved to the target device *before* its optimizer is
        constructed. This is the ordering the ``torch.optim`` documentation
        recommends, and the ``flow.optim`` API used here mirrors it. Each
        network is then printed along with its parameter count.
        """
        self.G = Generator().to(self.device)
        self.D = Discriminator().to(self.device)
        self.C = DomainClassifier().to(self.device)

        self.g_optimizer = flow.optim.Adam(self.G.parameters(), self.g_lr,
                                           [self.beta1, self.beta2])
        self.d_optimizer = flow.optim.Adam(self.D.parameters(), self.d_lr,
                                           [self.beta1, self.beta2])
        self.c_optimizer = flow.optim.Adam(self.C.parameters(), self.c_lr,
                                           [self.beta1, self.beta2])

        self.print_network(self.G, "G")
        self.print_network(self.D, "D")
        self.print_network(self.C, "C")

    def print_network(self, model, name):
        """Print ``model``, then ``name``, then its total parameter count."""
        total = sum(param.numel() for param in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(total))

    def update_lr(self, g_lr, d_lr, c_lr):
        """Set the learning rate of every param group of the G, D and C optimizers."""
        schedule = (
            (self.g_optimizer, g_lr),
            (self.d_optimizer, d_lr),
            (self.c_optimizer, c_lr),
        )
        for optimizer, new_lr in schedule:
            for group in optimizer.param_groups:
                group["lr"] = new_lr

    def train(self):
        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr
        c_lr = self.c_lr

        start_iters = 0
        if self.resume_iters:
            pass

        norm = Normalizer()
        data_iter = iter(self.data_loader)

        print("Start training......")
        start_time = datetime.now()

        for i in range(start_iters, self.num_iters):
            # Preprocess input data
            # Fetch real images and labels.
            try:
                x_real, speaker_idx_org, label_org = next(data_iter)
            except:
                data_iter = iter(self.data_loader)
                x_real, speaker_idx_org, label_org = next(data_iter)

            # Generate target domain labels randomly.
            rand_idx = flow.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]
            speaker_idx_trg = speaker_idx_org[rand_idx]

            x_real = x_real.to(self.device)
            # Original domain one-hot labels.
            label_org = label_org.to(self.device)
            # Target domain one-hot labels.
            label_trg = label_trg.to(self.device)
            speaker_idx_org = speaker_idx_org.to(self.device)
            speaker_idx_trg = speaker_idx_trg.to(self.device)

            # Train the discriminator
            # Compute loss with real audio frame.
            CELoss = nn.CrossEntropyLoss()
            cls_real = self.C(x_real)
            cls_loss_real = CELoss(input=cls_real, target=speaker_idx_org)

            self.reset_grad()
            cls_loss_real.backward()
            self.c_optimizer.step()
            # Logging.
            loss = {}
            loss["C/C_loss"] = cls_loss_real.item()

            out_r = self.D(x_real, label_org)
            # Compute loss with fake audio frame.
            x_fake = self.G(x_real, label_trg)
            out_f = self.D(x_fake.detach(), label_trg)
            d_loss_t = nn.BCEWithLogitsLoss()(
                input=out_f, target=flow.zeros_like(
                    out_f).float()) + nn.BCEWithLogitsLoss()(
                        input=out_r, target=flow.ones_like(out_r).float())

            out_cls = self.C(x_fake)
            d_loss_cls = CELoss(input=out_cls, target=speaker_idx_trg)

            # Compute loss for gradient penalty.
            alpha = flow.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = ((alpha * x_real +
                      (1 - alpha) * x_fake).detach().requires_grad_(True))
            out_src = self.D(x_hat, label_trg)

            # TODO: Second-order derivation is not currently supported in oneflow, so gradient penalty cannot be used temporarily.
            if self.use_gradient_penalty:
                d_loss_gp = self.gradient_penalty(out_src, x_hat)
                d_loss = d_loss_t + self.lambda_cls * d_loss_cls + 5 * d_loss_gp
            else:
                d_loss = d_loss_t + self.lambda_cls * d_loss_cls

            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            loss["D/D_loss"] = d_loss.item()

            # Train the generator
            if (i + 1) % self.n_critic == 0:
                # Original-to-target domain.
                x_fake = self.G(x_real, label_trg)
                g_out_src = self.D(x_fake, label_trg)
                g_loss_fake = nn.BCEWithLogitsLoss()(
                    input=g_out_src, target=flow.ones_like(g_out_src).float())

                out_cls = self.C(x_real)
                g_loss_cls = CELoss(input=out_cls, target=speaker_idx_org)

                # Target-to-original domain.
                x_reconst = self.G(x_fake, label_org)
                g_loss_rec = nn.L1Loss()(x_reconst, x_real)

                # Original-to-Original domain(identity).
                x_fake_iden = self.G(x_real, label_org)
                id_loss = nn.L1Loss()(x_fake_iden, x_real)

                # Backward and optimize.
                g_loss = (g_loss_fake + self.lambda_cycle * g_loss_rec +
                          self.lambda_cls * g_loss_cls +
                          self.lambda_identity * id_loss)

                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss["G/loss_fake"] = g_loss_fake.item()
                loss["G/loss_rec"] = g_loss_rec.item()
                loss["G/loss_cls"] = g_loss_cls.item()
                loss["G/loss_id"] = id_loss.item()
                loss["G/g_loss"] = g_loss.item()

            # Miscellaneous
            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = datetime.now() - start_time
                et = str(et)[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

            # Translate fixed images for debugging.
            if (i + 1) % self.sample_step == 0:
                with flow.no_grad():
                    d, speaker = TestSet(self.test_dir).test_data()
                    target = random.choice(
                        [x for x in speakers if x != speaker])
                    label_t = self.spk_enc.transform([target])[0]
                    label_t = np.asarray([label_t])

                    for filename, content in d.items():
                        f0 = content["f0"]
                        ap = content["ap"]
                        sp_norm_pad = self.pad_coded_sp(
                            content["coded_sp_norm"])

                        convert_result = []
                        for start_idx in range(
                                0, sp_norm_pad.shape[1] - FRAMES + 1, FRAMES):
                            one_seg = sp_norm_pad[:,
                                                  start_idx:start_idx + FRAMES]

                            one_seg = flow.Tensor(one_seg).to(self.device)
                            one_seg = one_seg.view(1, 1, one_seg.size(0),
                                                   one_seg.size(1))
                            l = flow.Tensor(label_t)
                            one_seg = one_seg.to(self.device)
                            l = l.to(self.device)
                            one_set_return = self.G(one_seg,
                                                    l).detach().cpu().numpy()
                            one_set_return = np.squeeze(one_set_return)
                            one_set_return = norm.backward_process(
                                one_set_return, target)
                            convert_result.append(one_set_return)

                        convert_con = np.concatenate(convert_result, axis=1)
                        convert_con = convert_con[:,
                                                  0:content["coded_sp_norm"].
                                                  shape[1]]
                        contigu = np.ascontiguousarray(convert_con.T,
                                                       dtype=np.float64)
                        decoded_sp = decode_spectral_envelope(contigu,
                                                              SAMPLE_RATE,
                                                              fft_size=FFTSIZE)
                        f0_converted = norm.pitch_conversion(
                            f0, speaker, target)
                        wav = synthesize(f0_converted, decoded_sp, ap,
                                         SAMPLE_RATE)

                        name = f"{speaker}-{target}_iter{i+1}_{filename}"
                        path = os.path.join(self.sample_dir, name)
                        print(f"[save]:{path}")
                        sf.write(path, wav, SAMPLE_RATE)

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      "{}-G".format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      "{}-D".format(i + 1))
                C_path = os.path.join(self.model_save_dir,
                                      "{}-C".format(i + 1))
                flow.save(self.G.state_dict(), G_path)
                flow.save(self.D.state_dict(), D_path)
                flow.save(self.C.state_dict(), C_path)
                print("Saved model checkpoints into {}...".format(
                    self.model_save_dir))

            # Decay learning rates.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= self.g_lr / float(self.num_iters_decay)
                d_lr -= self.d_lr / float(self.num_iters_decay)
                c_lr -= self.c_lr / float(self.num_iters_decay)
                self.update_lr(g_lr, d_lr, c_lr)
                print("Decayed learning rates, g_lr: {}, d_lr: {}.".format(
                    g_lr, d_lr))

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.

        Differentiates ``y`` with respect to ``x`` using an all-ones upstream
        gradient shaped like ``y``. The gradient of each sample is flattened
        over every axis after the first, and the batch mean of
        ``(||grad||_2 - 1) ** 2`` is returned. ``create_graph=True`` and
        ``retain_graph=True`` are passed so the penalty can itself be
        back-propagated.
        """
        upstream = flow.ones(y.size()).to(self.device)
        grad_x = flow.autograd.grad(
            outputs=y,
            inputs=x,
            out_grads=upstream,
            retain_graph=True,
            create_graph=True,
        )[0]

        # One row per sample, all remaining axes collapsed.
        per_sample = grad_x.view(grad_x.size(0), -1)
        l2_norm = flow.sqrt(flow.sum(per_sample**2, dim=1))

        return flow.mean((l2_norm - 1)**2)

    def reset_grad(self):
        """Zero the gradient buffers of the G, D and C optimizers, in that order."""
        for attr_name in ("g_optimizer", "d_optimizer", "c_optimizer"):
            getattr(self, attr_name).zero_grad()

    def restore_model(self, model_save_dir, resume_iters=200000):
        """Restore the trained generator, discriminator and classifier.

        Loads the state dicts stored at ``<model_save_dir>/<resume_iters>-G``,
        ``<resume_iters>-D`` and ``<resume_iters>-C`` into ``self.G``,
        ``self.D`` and ``self.C``. This is the same ``"{}-G"`` / ``"{}-D"`` /
        ``"{}-C"`` naming that the checkpoint step in ``train`` writes.

        Args:
            model_save_dir: Directory containing the checkpoint files.
            resume_iters: Iteration number embedded in the checkpoint file
                names. Defaults to 200000, so existing callers keep loading
                the 200000-iteration checkpoint.
        """
        print("Loading the pretrain models...")
        for suffix, module in (("G", self.G), ("D", self.D), ("C", self.C)):
            ckpt_path = os.path.join(model_save_dir,
                                     "{}-{}".format(resume_iters, suffix))
            module.load_state_dict(flow.load(ckpt_path))

    @staticmethod
    def pad_coded_sp(coded_sp_norm, frames=None):
        """Right-pad a 2-D coded spectral envelope with zero columns.

        Axis 1 of ``coded_sp_norm`` is zero-padded so its length becomes a
        multiple of ``frames``. The caller can then split it into
        non-overlapping ``frames``-wide segments.

        Args:
            coded_sp_norm: 2-D array whose axis 1 is padded (presumably the
                time/frame axis — TODO confirm against the feature extractor).
            frames: Segment width. ``None`` (the default) uses the
                module-level ``FRAMES`` constant.

        Returns:
            ``np.hstack((coded_sp_norm, zeros))``. Note that ``np.zeros`` is
            float64, so the result is promoted accordingly. A length that is
            already a positive multiple of ``frames`` gets no padding. An
            empty input (length 0) is padded to one full segment, so segment
            loops over the result still run at least once.
        """
        if frames is None:
            frames = FRAMES
        f_len = coded_sp_norm.shape[1]

        if f_len == 0:
            pad_length = frames
        else:
            # Not ``frames - f_len % frames``: for exact multiples that form
            # appends a whole redundant all-zero segment.
            pad_length = (-f_len) % frames

        # Always go through hstack (even for pad_length == 0) so the dtype
        # promotion of the result is the same in every case.
        sp_norm_pad = np.hstack(
            (coded_sp_norm, np.zeros((coded_sp_norm.shape[0], pad_length))))
        return sp_norm_pad

    def test(self):
        """Translate speech using StarGAN .

        Restores the pretrained networks from ``self.pretrain_models`` and
        loads the source speaker's test utterances. For every speaker in
        ``self.trg_speaker``, each utterance is converted with ``self.G``,
        resynthesized, and written to ``self.result_dir`` as
        ``"{speaker}-{target}_{filename}"``.
        """
        # Load the pretrained weights. restore_model loads G, D and C; only G
        # is used below.
        self.restore_model(self.pretrain_models)
        norm = Normalizer()

        # Set data loader.
        # `d` is iterated as (filename, content) pairs. The content keys used
        # below are "f0", "ap" and "coded_sp_norm". `speaker` is the source
        # speaker name reported by the test set.
        d, speaker = TestSet(self.test_dir).test_data(self.src_speaker)
        targets = self.trg_speaker

        for target in targets:
            print(target)
            # NOTE(review): `assert` is stripped under `python -O`, so this
            # target-speaker validation silently disappears in optimized runs.
            assert target in speakers
            # Encode the target speaker label and wrap it as a batch of one.
            label_t = self.spk_enc.transform([target])[0]
            label_t = np.asarray([label_t])

            # Inference only: no autograd graph is built.
            with flow.no_grad():

                for filename, content in d.items():
                    f0 = content["f0"]
                    ap = content["ap"]
                    # Zero-pad axis 1 so it splits into FRAMES-wide segments.
                    sp_norm_pad = self.pad_coded_sp(content["coded_sp_norm"])

                    convert_result = []
                    # Walk the padded features in non-overlapping FRAMES-wide
                    # windows along axis 1.
                    for start_idx in range(0,
                                           sp_norm_pad.shape[1] - FRAMES + 1,
                                           FRAMES):
                        one_seg = sp_norm_pad[:, start_idx:start_idx + FRAMES]

                        one_seg = flow.Tensor(one_seg).to(self.device)
                        # Reshape to a 4-D batch: (1, 1, rows, FRAMES).
                        one_seg = one_seg.view(1, 1, one_seg.size(0),
                                               one_seg.size(1))
                        l = flow.Tensor(label_t)
                        # Redundant: one_seg was already moved to
                        # self.device a few lines above.
                        one_seg = one_seg.to(self.device)
                        l = l.to(self.device)
                        one_set_return = self.G(one_seg,
                                                l).detach().cpu().numpy()
                        # Drop the singleton batch/channel dimensions.
                        one_set_return = np.squeeze(one_set_return)
                        # Presumably undoes the feature normalization using
                        # the target speaker's statistics — TODO confirm in
                        # Normalizer.
                        one_set_return = norm.backward_process(
                            one_set_return, target)
                        convert_result.append(one_set_return)

                    # Stitch the segments back together along axis 1 and
                    # trim off the padding added by pad_coded_sp.
                    convert_con = np.concatenate(convert_result, axis=1)
                    convert_con = convert_con[:, 0:content["coded_sp_norm"].
                                              shape[1]]
                    # Transpose and make a C-contiguous float64 copy for the
                    # spectral-envelope decoder.
                    contigu = np.ascontiguousarray(convert_con.T,
                                                   dtype=np.float64)
                    decoded_sp = decode_spectral_envelope(contigu,
                                                          SAMPLE_RATE,
                                                          fft_size=FFTSIZE)
                    # Map the source F0 contour toward the target speaker.
                    f0_converted = norm.pitch_conversion(f0, speaker, target)
                    # Resynthesize audio from converted F0, decoded envelope
                    # and the original aperiodicity `ap`.
                    wav = synthesize(f0_converted, decoded_sp, ap, SAMPLE_RATE)

                    name = f"{speaker}-{target}_{filename}"
                    path = os.path.join(self.result_dir, name)
                    print(f"[save]:{path}")
                    sf.write(path, wav, SAMPLE_RATE)