Exemple #1
0
def trainer(model_name):
    """Train a VoiceFilter model on an XLA device, validating after each epoch.

    Creates the checkpoint and log directories for ``model_name`` under the
    current directory, configures logging to a file and to the console, loads
    the speaker embedder weights, optionally resumes from ``chkpt_path`` and
    then alternates ``train`` and ``validate`` once per epoch. The model is
    saved through ``model_saver`` after the last epoch.

    Args:
        model_name: Name of the run; used for the checkpoint directory, the
            log directory and the log file name.
    """
    chkpt_path = None  #@param
    device = xm.xla_device()

    # Checkpoint directory first, then log directory.
    run_dirs = []
    for dir_key in ('chkpt_dir', 'log_dir'):
        run_dir = os.path.join('.', config.log[dir_key], model_name)
        os.makedirs(run_dir, exist_ok=True)
        run_dirs.append(run_dir)
    pt_dir, log_dir = run_dirs

    # One log file per run, suffixed with the start time in whole seconds.
    log_file = os.path.join(log_dir,
                            '%s-%d.log' % (model_name, time.time()))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[logging.FileHandler(log_file),
                  logging.StreamHandler()])
    logger = logging.getLogger()
    writer = MyWriter(log_dir)

    trainloader = create_dataloader(train=True)
    testloader = create_dataloader(train=False)

    # The embedder is only used for inference, hence eval mode.
    embedder_weights = torch.load(
        '/drive/content/My Drive/ColabDisk/embedder_cpu.pt')
    embedder = SpeechEmbedder().to(device)
    embedder.load_state_dict(embedder_weights)
    embedder.eval()

    model = VoiceFilter().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.train['adam'])
    audio = Audio()

    starting_epoch = 1
    if chkpt_path is None:
        logger.info("Starting new training run")
    else:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        saved = torch.load(chkpt_path)
        model.load_state_dict(saved['model'])
        optimizer.load_state_dict(saved['optimizer'])
        starting_epoch = saved['epoch']

    def on_device(loader):
        # Wrap a host-side loader so that it feeds batches to `device`.
        return pl.ParallelLoader(loader, [device]).per_device_loader(device)

    for epoch in range(starting_epoch, config.train['epoch'] + 1):
        train(embedder, model, optimizer, on_device(trainloader), writer,
              logger, epoch, pt_dir, device)
        xm.master_print("Finished training epoch {}".format(epoch))
        logger.info("Starting to validate epoch...")
        validate(audio, model, embedder, on_device(testloader), writer, epoch,
                 device)

    model_saver(model, optimizer, pt_dir, config.train['epoch'])
Exemple #2
0
def main(args):
    """Filter a mixed recording with VoiceFilter and write ``result.wav``.

    Loads the VoiceFilter checkpoint and the speaker embedder onto the GPU,
    computes a d-vector from the reference recording, predicts a mask for the
    mixed recording's magnitude spectrogram and writes the reconstructed
    waveform to ``<args.out_dir>/result.wav``.

    Args:
        args: Object providing ``checkpoint_path``, ``embedder_path``,
            ``reference_file``, ``mixed_file`` and ``out_dir`` attributes.
    """
    sample_rate = 16000

    with torch.no_grad():
        # The checkpoint is loaded into the DataParallel-wrapped model.
        voice_filter = torch.nn.DataParallel(VoiceFilter()).cuda()
        filter_chkpt = torch.load(args.checkpoint_path)
        voice_filter.load_state_dict(filter_chkpt['model'])
        voice_filter.eval()

        speaker_embedder = SpeechEmbedder().cuda()
        speaker_embedder.load_state_dict(torch.load(args.embedder_path))
        speaker_embedder.eval()

        audio = Audio()

        # Reference recording -> mel -> d-vector with a leading batch axis.
        reference_wav, _ = librosa.load(args.reference_file, sr=sample_rate)
        reference_mel = torch.from_numpy(
            audio.get_mel(reference_wav)).float().cuda()
        dvec = speaker_embedder(reference_mel).unsqueeze(0)

        # Mixed recording -> magnitude/phase; only the magnitude is filtered.
        mixture_wav, _ = librosa.load(args.mixed_file, sr=sample_rate)
        mixture_mag, mixture_phase = audio.wav2spec(mixture_wav)
        mixture_mag = torch.from_numpy(mixture_mag).float().cuda().unsqueeze(0)

        predicted_mask = voice_filter(mixture_mag, dvec)
        filtered_mag = (mixture_mag * predicted_mask)[0].cpu().detach().numpy()
        filtered_wav = audio.spec2wav(filtered_mag, mixture_phase)

        os.makedirs(args.out_dir, exist_ok=True)
        librosa.output.write_wav(os.path.join(args.out_dir, 'result.wav'),
                                 filtered_wav,
                                 sr=sample_rate)
Exemple #3
0
def main(args, hp):
    """Add the model's predicted "shadow" to a mixture and write ``result.wav``.

    Loads the VoiceFilter checkpoint and the speaker embedder onto the GPU,
    computes a d-vector from the reference recording, runs the model on the
    mixed recording's magnitude spectrogram, adds the model output to that
    magnitude, normalizes the sum with ``tensor_normalize`` and writes the
    reconstructed waveform to ``<args.out_dir>/result.wav``.

    Args:
        args: Object providing ``checkpoint_path``, ``embedder_path``,
            ``reference_file``, ``mixed_file`` and ``out_dir`` attributes.
        hp: Hyper-parameter object passed to ``VoiceFilter``,
            ``SpeechEmbedder`` and ``Audio``.
    """
    with torch.no_grad():
        model = VoiceFilter(hp).cuda()
        chkpt_model = torch.load(args.checkpoint_path)['model']
        model.load_state_dict(chkpt_model)
        model.eval()

        embedder = SpeechEmbedder(hp).cuda()
        chkpt_embed = torch.load(args.embedder_path)
        embedder.load_state_dict(chkpt_embed)
        embedder.eval()

        audio = Audio(hp)
        ref_wav, _ = librosa.load(args.reference_file, sr=16000)
        ref_mel = audio.get_mel(ref_wav)
        ref_mel = torch.from_numpy(ref_mel).float().cuda()
        dvec = embedder(ref_mel)
        dvec = dvec.unsqueeze(0)

        mixed_wav, _ = librosa.load(args.mixed_file, sr=16000)
        mixed_mag, mixed_phase = audio.wav2spec(mixed_wav)
        mixed_mag = torch.from_numpy(mixed_mag).float().cuda()

        mixed_mag = mixed_mag.unsqueeze(0)
        shadow_mag = model(mixed_mag, dvec)

        # Keep both operands as tensors on the same device for the sum. The
        # previous code converted `shadow_mag` to a NumPy array first, which
        # cannot be added to the CUDA tensor `mixed_mag`.
        # NOTE(review): assumes tensor_normalize accepts and returns a batched
        # torch tensor, as the `[0].cpu()` that follows implies - confirm.
        recorded_mag = tensor_normalize(mixed_mag + shadow_mag)
        recorded_mag = recorded_mag[0].cpu().detach().numpy()
        # Rebuild the waveform with the mixture's phase returned by wav2spec;
        # the magnitude tensor was previously passed here and `mixed_phase`
        # was never used.
        recorded_wav = audio.spec2wav(recorded_mag, mixed_phase)

        os.makedirs(args.out_dir, exist_ok=True)
        out_path = os.path.join(args.out_dir, 'result.wav')
        librosa.output.write_wav(out_path, recorded_wav, sr=16000)
Exemple #4
0
def main(args):
    """Run VoiceFilter with built-in settings and return the filtered waveform.

    The ``args`` argument is not read: the function works from the hard-coded
    settings below. Nothing is written to disk.

    Args:
        args: Ignored.

    Returns:
        Whatever ``Audio.spec2wav`` returns for the masked magnitude and the
        phase of the mixed recording.
    """
    settings = {
        "config": 'config/config.yaml',
        "embedder_path": 'model/embedder.pt',
        "checkpoint_path": 'enhance_my_voice/chkpt_201000.pt',
        "mixed_file": 'utils/speakerA.wav',
        "reference_file": 'utils/speakerA.wav',
        "out_dir": 'output',
    }
    sample_rate = 16000

    hp = HParam(settings['config'])

    with torch.no_grad():
        voice_filter = VoiceFilter(hp).cuda()
        filter_state = torch.load(settings['checkpoint_path'])['model']
        voice_filter.load_state_dict(filter_state)
        voice_filter.eval()

        speaker_embedder = SpeechEmbedder(hp).cuda()
        speaker_embedder.load_state_dict(torch.load(settings['embedder_path']))
        speaker_embedder.eval()

        audio = Audio(hp)

        # Reference recording -> mel -> d-vector with a leading batch axis.
        reference_wav, _ = librosa.load(settings['reference_file'],
                                        sr=sample_rate)
        reference_mel = torch.from_numpy(
            audio.get_mel(reference_wav)).float().cuda()
        dvec = speaker_embedder(reference_mel).unsqueeze(0)

        # Mixed recording -> magnitude/phase; only the magnitude is masked.
        mixture_wav, _ = librosa.load(settings['mixed_file'], sr=sample_rate)
        mixture_mag, mixture_phase = audio.wav2spec(mixture_wav)
        mixture_mag = torch.from_numpy(mixture_mag).float().cuda().unsqueeze(0)

        predicted_mask = voice_filter(mixture_mag, dvec)
        filtered_mag = (mixture_mag * predicted_mask)[0].cpu().detach().numpy()

        return audio.spec2wav(filtered_mag, mixture_phase)
Exemple #5
0
def main(args, hp):
    """Run VoiceFilter on CPU for every entry listed in ``out1.txt``.

    Each non-blank line of ``out1.txt`` is split on tabs: field 0 is the path
    of the mixed recording, field 1 the path of the reference recording and
    field 2 the output file name written under ``/root/voicefilter/res``.

    Args:
        args: Object providing ``checkpoint_path`` and ``embedder_path``.
        hp: Hyper-parameter object passed to ``VoiceFilter``,
            ``SpeechEmbedder`` and ``Audio``.
    """
    out_dir = '/root/voicefilter/res'

    with open('out1.txt') as f, torch.no_grad():
        # Models, the audio helper and the output directory do not depend on
        # the current line, so they are set up once instead of once per line.
        model = VoiceFilter(hp)
        chkpt_model = torch.load(args.checkpoint_path,
                                 map_location='cpu')['model']
        model.load_state_dict(chkpt_model)
        model.eval()

        embedder = SpeechEmbedder(hp)
        chkpt_embed = torch.load(args.embedder_path, map_location='cpu')
        embedder.load_state_dict(chkpt_embed)
        embedder.eval()

        audio = Audio(hp)
        os.makedirs(out_dir, exist_ok=True)

        for line in f:
            # Drop the line terminator so it does not end up in the output
            # file name, and skip blank lines instead of failing on them.
            line = line.rstrip('\r\n')
            if not line.strip():
                continue
            res = line.split('\t')

            dvec_wav, _ = librosa.load(res[1], sr=16000)
            dvec_mel = audio.get_mel(dvec_wav)
            dvec_mel = torch.from_numpy(dvec_mel).float()
            dvec = embedder(dvec_mel)
            dvec = dvec.unsqueeze(0)

            mixed_wav, _ = librosa.load(res[0], sr=16000)
            mag, phase = audio.wav2spec(mixed_wav)
            mag = torch.from_numpy(mag).float()

            mag = mag.unsqueeze(0)
            mask = model(mag, dvec)
            est_mag = mag * mask

            est_mag = est_mag[0].cpu().detach().numpy()
            est_wav = audio.spec2wav(est_mag, phase)

            out_path = os.path.join(out_dir, res[2])
            librosa.output.write_wav(out_path, est_wav, sr=16000)
    # NOTE(review): from here on the code reads like a fragment of a separate
    # evaluation routine: `dirs` and `speaker_count` are not defined anywhere
    # in the visible code, and the section stops right after the d-vector is
    # computed - confirm where it belongs before relying on it.
    # dir = '/data/our_dataset/test/3/joint'
    for dir in dirs:
        # Progress counter; the total of 56 is hard-coded in the message.
        speaker_count = speaker_count + 1
        print("Speaker : {}/56\n".format(speaker_count))
        # The second-to-last path component is taken as the speaker id.
        tree = dir.split('/')
        speaker_id = tree[-2]
        # Point the hyper-parameters at this directory before building the
        # test loader for it.
        hp.data.test_dir = dir
        testloader = create_dataloader(hp, args, train=False)
        for batch in testloader:
            # length of batch is 1, set in dataloader
            ref_mel, eliminated_wav, mixed_wav, expected_hidden_wav, eliminated_mag, expected_hidden_mag, mixed_mag, mixed_phase, dvec_path, eliminated_wav_path, mixed_wav_path = \
                batch[0]
            # print("expected_focused: {}".format(expected_focused_wav_path))
            print("Mixed: {}".format(mixed_wav_path))
            # NOTE(review): the model and the embedder are re-created and
            # re-loaded from disk for every batch.
            model = VoiceFilter(hp).cuda()
            chkpt_model = torch.load(args.checkpoint_path,
                                     map_location='cuda:0')['model']
            model.load_state_dict(chkpt_model)
            model.eval()

            embedder = SpeechEmbedder(hp).cuda()
            chkpt_embed = torch.load(args.embedder_path)
            embedder.load_state_dict(chkpt_embed)
            embedder.eval()

            # The `ref_mel` unpacked from the batch is overwritten here by a
            # mel computed from the audio file at `dvec_path`.
            audio = Audio(hp)
            dvec_wav, _ = librosa.load(dvec_path, sr=16000)
            ref_mel = audio.get_mel(dvec_wav)
            ref_mel = torch.from_numpy(ref_mel).float().cuda()
            dvec = embedder(ref_mel)
Exemple #7
0
def main():
    """Parse the command line and run epoch-based VoiceFilter training.

    Creates the checkpoint and log directories for the run, configures logging
    to a file and to the console, loads the speaker embedder, optionally
    restores model weights and the epoch/step counters from a checkpoint, then
    trains and validates once per epoch while a ``StepLR`` scheduler decays
    the learning rate. The model is saved through ``model_saver`` at the end.
    """
    # Training settings
    cli = argparse.ArgumentParser(description='PyTorch Voice Filter')
    cli.add_argument('-b', '--base_dir', type=str, default='.',
                     help="Root directory of run.")
    cli.add_argument('--checkpoint_path', type=str, default=None,
                     help='Path to last checkpoint')
    cli.add_argument('-e', '--embedder_path', type=str, required=True,
                     help="path of embedder model pt file")
    cli.add_argument(
        '-m', '--model', type=str, required=True,
        help="Name of the model. Used for both logging and saving checkpoints."
    )
    args = cli.parse_args()

    chkpt_path = args.checkpoint_path

    # Checkpoint directory first, then log directory.
    run_dirs = []
    for dir_key in ('chkpt_dir', 'log_dir'):
        run_dir = os.path.join(args.base_dir, config.log[dir_key], args.model)
        os.makedirs(run_dir, exist_ok=True)
        run_dirs.append(run_dir)
    pt_dir, log_dir = run_dirs

    # One log file per run, suffixed with the start time in whole seconds.
    log_file = os.path.join(log_dir,
                            '%s-%d.log' % (args.model, time.time()))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[logging.FileHandler(log_file),
                  logging.StreamHandler()])
    logger = logging.getLogger()
    writer = MyWriter(log_dir)

    trainloader = create_dataloader(train=True)
    testloader = create_dataloader(train=False)

    # The embedder is only used for inference, hence eval mode.
    embedder_weights = torch.load(args.embedder_path)
    embedder = SpeechEmbedder().cuda()
    embedder.load_state_dict(embedder_weights)
    embedder.eval()

    model = nn.DataParallel(VoiceFilter())
    optimizer = torch.optim.Adam(model.parameters(), lr=config.train['adam'])
    audio = Audio()

    starting_step, starting_epoch = 0, 1
    if chkpt_path is None:
        logger.info("Starting new training run")
    else:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        saved = torch.load(chkpt_path)
        # Only the model weights and the counters are restored here.
        model.load_state_dict(saved['model'])
        starting_epoch = saved['epoch']
        starting_step = saved['step']

    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    last_epoch = config.train['epoch']
    for epoch in range(starting_epoch, last_epoch + 1):
        train(embedder, model, optimizer, trainloader, writer, logger, epoch,
              pt_dir, starting_step)
        validate(audio, model, embedder, testloader, writer, epoch)
        scheduler.step()
        # Only the first (resumed) epoch starts from a non-zero step.
        starting_step = 0

    model_saver(model, pt_dir, config.train['epoch'],
                config.train['train_step_pre_epoch'])
Exemple #8
0
    # NOTE(review): this is the body of a function whose header is not visible
    # here; `hp`, `args`, `logger`, `log_dir` and `chkpt_path` are expected to
    # be defined before this point.

    # Refuse to run without both data directories configured.
    if hp.data.train_dir == '' or hp.data.test_dir == '':
        logger.error("train_dir, test_dir cannot be empty.")
        raise Exception("Please specify directories of data in %s" %
                        args.config)

    writer = MyWriter(hp, log_dir)

    # Select the GPU given on the command line, then load the embedder in
    # eval mode.
    torch.cuda.set_device(args.gpu)
    embedder_pt = torch.load(args.embedder_path)
    embedder = SpeechEmbedder(hp).cuda()
    embedder.load_state_dict(embedder_pt)
    embedder.eval()

    audio = Audio(hp)
    model = VoiceFilter(hp).cuda()
    # The checkpoint is always mapped to 'cuda:0', whatever `args.gpu` is.
    checkpoint = torch.load(chkpt_path, map_location='cuda:0')
    model.load_state_dict(checkpoint['model'])
    step = 1
    # `args.hide` selects both the test loader and the validation routine.
    if args.hide:
        testloader = create_hide_dataloader(hp, args, train=False)
    else:
        testloader = create_focus_dataloader(hp, args, train=False)
    # Validate repeatedly: the loop has no exit condition of its own, and
    # `step` is incremented after every validation pass and then printed.
    while True:
        if args.hide:
            validate_hide(audio, model, embedder, testloader, writer, step)
            step = step + 1
        else:
            validate_focus(audio, model, embedder, testloader, writer, step)
            step = step + 1
        print(step)
Exemple #9
0
def train(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger,
          hp, hp_str):
    """Train VoiceFilter indefinitely, checkpointing and validating periodically.

    The loop only ends when an exception is raised inside it (for example the
    "Loss exploded" check); the exception is logged, its traceback printed,
    and the function returns.

    Args:
        args: Object providing ``embedder_path``.
        pt_dir: Directory where ``chkpt_<step>.pt`` files are written.
        chkpt_path: Checkpoint to resume from, or ``None`` for a new run.
        trainloader: Iterable yielding ``(dvec_mels, target_mag, mixed_mag)``.
        testloader: Loader handed to ``validate``.
        writer: Summary writer exposing ``log_training(loss, step)``.
        logger: Logger used for progress and error messages.
        hp: Hyper-parameter object.
        hp_str: Hyper-parameters as a string; stored in every checkpoint and
            compared with the stored value when resuming.
    """
    # load embedder
    embedder_pt = torch.load(args.embedder_path)
    embedder = SpeechEmbedder(hp).cuda()
    embedder.load_state_dict(embedder_pt)
    # embedder = nn.DataParallel(embedder)
    embedder.eval()

    audio = Audio(hp)
    model = VoiceFilter(hp).cuda()
    # model = nn.DataParallel(model)
    if hp.train.optimizer == 'adabound':
        optimizer = AdaBound(model.parameters(),
                             lr=hp.train.adabound.initial,
                             final_lr=hp.train.adabound.final)
    elif hp.train.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=hp.train.adam)
    else:
        raise Exception("%s optimizer not supported" % hp.train.optimizer)

    step = 0

    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        step = checkpoint['step']

        # will use new given hparams.
        if hp_str != checkpoint['hp_str']:
            logger.warning("New hparams is different from checkpoint.")
    else:
        logger.info("Starting new training run")

    try:
        criterion = nn.MSELoss()
        while True:
            model.train()
            for dvec_mels, target_mag, mixed_mag in trainloader:
                target_mag = target_mag.cuda()
                mixed_mag = mixed_mag.cuda()

                # The embedder is not trained here (it is not part of the
                # optimizer and its output used to be detached), so run it
                # without autograd: same values, no graph to build and keep.
                with torch.no_grad():
                    dvec = torch.stack(
                        [embedder(mel.cuda()) for mel in dvec_mels], dim=0)

                mask = model(mixed_mag, dvec)
                output = mixed_mag * mask

                # output = torch.pow(torch.clamp(output, min=0.0), hp.audio.power)
                # target_mag = torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power)
                loss = criterion(output, target_mag)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                step += 1

                # Abort the run on a diverged loss; the raise is caught by the
                # handler at the bottom of this function.
                loss = loss.item()
                if loss > 1e8 or math.isnan(loss):
                    logger.error("Loss exploded to %.02f at step %d!" %
                                 (loss, step))
                    raise Exception("Loss exploded")

                # write loss to tensorboard
                if step % hp.train.summary_interval == 0:
                    writer.log_training(loss, step)
                    logger.info("Wrote summary at step %d" % step)

                # 1. save checkpoint file to resume training
                # 2. evaluate and save sample to tensorboard
                if step % hp.train.checkpoint_interval == 0:
                    save_path = os.path.join(pt_dir, 'chkpt_%d.pt' % step)
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'step': step,
                            'hp_str': hp_str,
                        }, save_path)
                    logger.info("Saved checkpoint to: %s" % save_path)
                    validate(audio, model, embedder, testloader, writer, step)
    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()
Exemple #10
0
def _build_scheduler(hp, optimizer, trainloader):
    """Return the LR scheduler selected by ``hp.scheduler.type``, else ``None``."""
    if hp.scheduler.type == 'oneCycle':
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=hp.scheduler.oneCycle.max_lr,
            epochs=hp.train.epoch,
            steps_per_epoch=len(trainloader))
    if hp.scheduler.type == 'Plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=hp.scheduler.Plateau.mode,
            patience=hp.scheduler.Plateau.patience,
            factor=hp.scheduler.Plateau.factor,
            verbose=True)
    return None


def train(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger,
          hp, hp_str):
    """Train VoiceFilter for ``hp.train.epoch`` epochs with an LR scheduler.

    Any exception raised inside the training loop (for example the
    "Loss exploded" check) is logged, its traceback printed, and the function
    returns.

    Args:
        args: Object providing ``embedder_path``.
        pt_dir: Directory where ``chkpt_<step>.pt`` files are written.
        chkpt_path: Checkpoint to resume from, or ``None`` for a new run.
        trainloader: Iterable yielding ``(dvec_mels, target_mag, mixed_mag)``;
            its length is needed when the ``oneCycle`` scheduler is selected.
        testloader: Loader handed to ``validate``.
        writer: Summary writer exposing ``log_training(loss, step)``.
        logger: Logger used for progress and error messages.
        hp: Hyper-parameter object.
        hp_str: Hyper-parameters as a string; stored in every checkpoint and
            compared with the stored value when resuming.
    """
    # load embedder

    # NOTE(review): the GPU index is hard-coded to the second device.
    torch.cuda.set_device("cuda:1")
    print(torch.cuda.current_device())

    embedder_pt = torch.load(args.embedder_path)
    embedder = SpeechEmbedder(hp).cuda()
    embedder.load_state_dict(embedder_pt)
    embedder.eval()

    audio = Audio(hp)
    model = VoiceFilter(hp).cuda()

    ### Multi-GPU
    #    model = VoiceFilter(hp)
    #    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #    model = nn.DataParallel(model)
    #    model.to(device)

    if hp.train.optimizer == 'adabound':
        optimizer = AdaBound(model.parameters(),
                             lr=hp.train.adabound.initial,
                             final_lr=hp.train.adabound.final)
    elif hp.train.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=hp.train.adam)
    else:
        raise Exception("%s optimizer not supported" % hp.train.optimizer)

    step = 0

    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        step = checkpoint['step']

        # will use new given hparams.
        if hp_str != checkpoint['hp_str']:
            logger.warning("New hparams is different from checkpoint.")
    else:
        logger.info("Starting new training run")

    # Build the scheduler once, after any optimizer state has been restored,
    # and according to the configured type. Resuming used to install a
    # OneCycleLR unconditionally, even when 'Plateau' was configured.
    scheduler = _build_scheduler(hp, optimizer, trainloader)

    try:
        criterion = nn.MSELoss()
        for i_epoch in range(hp.train.epoch):
            model.train()
            for dvec_mels, target_mag, mixed_mag in trainloader:
                target_mag = target_mag.cuda()
                mixed_mag = mixed_mag.cuda()

                # The embedder is not trained here (it is not part of the
                # optimizer and its output used to be detached), so run it
                # without autograd: same values, no graph to build and keep.
                with torch.no_grad():
                    dvec = torch.stack(
                        [embedder(mel.cuda()) for mel in dvec_mels], dim=0)

                mask = model(mixed_mag, dvec)
                output = mixed_mag * mask

                # output = torch.pow(torch.clamp(output, min=0.0), hp.audio.power)
                # target_mag = torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power)
                loss = criterion(output, target_mag)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Both schedulers are stepped once per batch; the plateau
                # scheduler is fed the training loss of the current batch.
                if hp.scheduler.type == 'oneCycle':
                    scheduler.step()
                elif hp.scheduler.type == 'Plateau':
                    scheduler.step(loss)

                step += 1

                # Abort the run on a diverged loss; the raise is caught by the
                # handler at the bottom of this function.
                loss = loss.item()
                if loss > 1e8 or math.isnan(loss):
                    logger.error("Loss exploded to %.02f at step %d!" %
                                 (loss, step))
                    raise Exception("Loss exploded")

                # write loss to tensorboard
                if step % hp.train.summary_interval == 0:
                    writer.log_training(loss, step)
                    logger.info("Wrote summary at step %d in epoch %d" %
                                (step, i_epoch))

                if step % hp.train.validation_interval == 0:
                    validate(audio, model, embedder, testloader, writer, step)

                # save checkpoint file to resume training
                if step % hp.train.checkpoint_interval == 0:
                    save_path = os.path.join(pt_dir, 'chkpt_%d.pt' % step)
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'step': step,
                            'hp_str': hp_str,
                        }, save_path)

                    logger.info("Saved checkpoint to: %s" % save_path)
    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()