# TPU training entry point. Assumes the project-level imports (os, time,
# logging, torch, config, MyWriter, create_dataloader, SpeechEmbedder,
# VoiceFilter, Audio, train, validate, model_saver) plus:
#   import torch_xla.core.xla_model as xm
#   import torch_xla.distributed.parallel_loader as pl
def trainer(model_name):
    chkpt_path = None  #@param
    device = xm.xla_device()

    pt_dir = os.path.join('.', config.log['chkpt_dir'], model_name)
    os.makedirs(pt_dir, exist_ok=True)
    log_dir = os.path.join('.', config.log['log_dir'], model_name)
    os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(
                os.path.join(log_dir, '%s-%d.log' % (model_name, time.time()))),
            logging.StreamHandler()
        ])
    logger = logging.getLogger()
    writer = MyWriter(log_dir)

    trainloader = create_dataloader(train=True)
    testloader = create_dataloader(train=False)

    # Load the pre-trained speaker embedder (kept frozen during training).
    # NOTE: on Colab, Google Drive is normally mounted at /content/drive/My Drive.
    embedder_pt = torch.load('/drive/content/My Drive/ColabDisk/embedder_cpu.pt')
    embedder = SpeechEmbedder().to(device)
    embedder.load_state_dict(embedder_pt)
    embedder.eval()

    model = VoiceFilter().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.train['adam'])
    audio = Audio()

    starting_epoch = 1
    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint_file = torch.load(chkpt_path)
        model.load_state_dict(checkpoint_file['model'])
        optimizer.load_state_dict(checkpoint_file['optimizer'])
        starting_epoch = checkpoint_file['epoch']
    else:
        logger.info("Starting new training run")

    for epoch in range(starting_epoch, config.train['epoch'] + 1):
        para_loader = pl.ParallelLoader(trainloader, [device]).per_device_loader(device)
        train(embedder, model, optimizer, para_loader, writer, logger, epoch,
              pt_dir, device)
        xm.master_print("Finished training epoch {}".format(epoch))

        logger.info("Starting to validate epoch...")
        para_loader = pl.ParallelLoader(testloader, [device]).per_device_loader(device)
        validate(audio, model, embedder, para_loader, writer, epoch, device)

    model_saver(model, optimizer, pt_dir, config.train['epoch'])
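
# A minimal sketch of launching the TPU trainer above. trainer() acquires its
# own xm.xla_device(), so a single-process run can simply call it; the
# multi-core form below only illustrates the torch_xla launch API. The model
# name 'voicefilter-tpu' and the _mp_fn wrapper are illustrative, not from
# the original code.
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, model_name):
    # Each spawned process picks up its own XLA device inside trainer().
    trainer(model_name)

if __name__ == '__main__':
    # Single core:
    #   trainer('voicefilter-tpu')
    # All 8 cores of a Colab v2/v3 TPU (note: trainer() as written drives a
    # single device, so true multi-core use would need per-process samplers):
    xmp.spawn(_mp_fn, args=('voicefilter-tpu',), nprocs=8, start_method='fork')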
def main(args):
    with torch.no_grad():
        model = torch.nn.DataParallel(VoiceFilter()).cuda()
        chkpt_model = torch.load(args.checkpoint_path)
        model.load_state_dict(chkpt_model['model'])
        model.eval()

        embedder = SpeechEmbedder().cuda()
        chkpt_embed = torch.load(args.embedder_path)
        embedder.load_state_dict(chkpt_embed)
        embedder.eval()

        audio = Audio()

        # Speaker reference -> d-vector
        dvec_wav, _ = librosa.load(args.reference_file, sr=16000)
        dvec_mel = audio.get_mel(dvec_wav)
        dvec_mel = torch.from_numpy(dvec_mel).float().cuda()
        dvec = embedder(dvec_mel)
        dvec = dvec.unsqueeze(0)

        # Mixture -> magnitude/phase spectrograms
        mixed_wav, _ = librosa.load(args.mixed_file, sr=16000)
        mag, phase = audio.wav2spec(mixed_wav)
        mag = torch.from_numpy(mag).float().cuda()
        mag = mag.unsqueeze(0)

        # Predict a soft mask and apply it to the mixture magnitude.
        mask = model(mag, dvec)
        est_mag = mag * mask

        est_mag = est_mag[0].cpu().detach().numpy()
        est_wav = audio.spec2wav(est_mag, phase)

        os.makedirs(args.out_dir, exist_ok=True)
        out_path = os.path.join(args.out_dir, 'result.wav')
        # librosa.output.write_wav only exists in librosa < 0.8
        librosa.output.write_wav(out_path, est_wav, sr=16000)
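
# A sketch of the CLI wiring this main() expects; the flag names mirror the
# attributes used above, and the defaults are illustrative.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='VoiceFilter inference')
    parser.add_argument('--checkpoint_path', type=str, required=True)
    parser.add_argument('--embedder_path', type=str, required=True)
    parser.add_argument('--mixed_file', type=str, required=True)
    parser.add_argument('--reference_file', type=str, required=True)
    parser.add_argument('--out_dir', type=str, default='output')
    main(parser.parse_args())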
def main(args, hp):
    with torch.no_grad():
        model = VoiceFilter(hp).cuda()
        chkpt_model = torch.load(args.checkpoint_path)['model']
        model.load_state_dict(chkpt_model)
        model.eval()

        embedder = SpeechEmbedder(hp).cuda()
        chkpt_embed = torch.load(args.embedder_path)
        embedder.load_state_dict(chkpt_embed)
        embedder.eval()

        audio = Audio(hp)

        ref_wav, _ = librosa.load(args.reference_file, sr=16000)
        ref_mel = audio.get_mel(ref_wav)
        ref_mel = torch.from_numpy(ref_mel).float().cuda()
        dvec = embedder(ref_mel)
        dvec = dvec.unsqueeze(0)

        mixed_wav, _ = librosa.load(args.mixed_file, sr=16000)
        mixed_mag, mixed_phase = audio.wav2spec(mixed_wav)
        mixed_mag = torch.from_numpy(mixed_mag).float().cuda()
        mixed_mag = mixed_mag.unsqueeze(0)

        # The model predicts a "shadow" magnitude that is added to the mixture;
        # keep it as a tensor so the addition and normalization stay on device.
        shadow_mag = model(mixed_mag, dvec)
        recorded_mag = tensor_normalize(mixed_mag + shadow_mag)
        recorded_mag = recorded_mag[0].cpu().detach().numpy()
        # Resynthesize with the mixture phase (not the magnitude tensor).
        recorded_wav = audio.spec2wav(recorded_mag, mixed_phase)

        os.makedirs(args.out_dir, exist_ok=True)
        out_path = os.path.join(args.out_dir, 'result.wav')
        librosa.output.write_wav(out_path, recorded_wav, sr=16000)
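
# tensor_normalize() is used above but not defined in these snippets. A
# plausible minimal version is sketched here, assuming it rescales the summed
# magnitude into [0, 1]; this is purely an assumption, the real helper may
# differ.
def tensor_normalize(mag):
    # Hypothetical min-max rescaling of a (batched) magnitude tensor.
    mag_min, mag_max = mag.min(), mag.max()
    return (mag - mag_min) / (mag_max - mag_min + 1e-8)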
def main(args):
    # Notebook-style variant: the incoming args are replaced by hard-coded paths.
    args = {
        "config": 'config/config.yaml',
        "embedder_path": 'model/embedder.pt',
        "checkpoint_path": 'enhance_my_voice/chkpt_201000.pt',
        "mixed_file": 'utils/speakerA.wav',
        "reference_file": 'utils/speakerA.wav',
        "out_dir": 'output',
    }
    hp = HParam(args['config'])

    with torch.no_grad():
        model = VoiceFilter(hp).cuda()
        chkpt_model = torch.load(args['checkpoint_path'])['model']
        model.load_state_dict(chkpt_model)
        model.eval()

        embedder = SpeechEmbedder(hp).cuda()
        chkpt_embed = torch.load(args['embedder_path'])
        embedder.load_state_dict(chkpt_embed)
        embedder.eval()

        audio = Audio(hp)

        dvec_wav, _ = librosa.load(args['reference_file'], sr=16000)
        dvec_mel = audio.get_mel(dvec_wav)
        dvec_mel = torch.from_numpy(dvec_mel).float().cuda()
        dvec = embedder(dvec_mel)
        dvec = dvec.unsqueeze(0)

        mixed_wav, _ = librosa.load(args['mixed_file'], sr=16000)
        mag, phase = audio.wav2spec(mixed_wav)
        mag = torch.from_numpy(mag).float().cuda()
        mag = mag.unsqueeze(0)

        mask = model(mag, dvec)
        est_mag = mag * mask
        est_mag = est_mag[0].cpu().detach().numpy()
        # est_wav = audio.spec2wav(est_mag, phase)
        # os.makedirs(args['out_dir'], exist_ok=True)
        # out_path = os.path.join(args['out_dir'], 'result.wav')
        # librosa.output.write_wav(out_path, est_wav, sr=16000)
        return audio.spec2wav(est_mag, phase)
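
# Example of consuming the waveform returned by the variant above. Note that
# librosa.output.write_wav() was removed in librosa 0.8; soundfile is the
# usual replacement. The output path is illustrative.
import soundfile as sf

est_wav = main(None)  # the argument is unused: main() overwrites args internally
os.makedirs('output', exist_ok=True)
sf.write(os.path.join('output', 'result.wav'), est_wav, 16000)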
def main(args, hp):
    # CPU batch inference over a tab-separated manifest; the models are loaded
    # once, then each line of out1.txt is processed.
    with torch.no_grad():
        model = VoiceFilter(hp)
        chkpt_model = torch.load(args.checkpoint_path, map_location='cpu')['model']
        model.load_state_dict(chkpt_model)
        model.eval()

        embedder = SpeechEmbedder(hp)
        chkpt_embed = torch.load(args.embedder_path, map_location='cpu')
        embedder.load_state_dict(chkpt_embed)
        embedder.eval()

        audio = Audio(hp)

        with open('out1.txt') as f:
            for line in f:
                # strip the trailing newline so res[2] is a clean filename
                res = line.strip().split('\t')

                dvec_wav, _ = librosa.load(res[1], sr=16000)
                dvec_mel = audio.get_mel(dvec_wav)
                dvec_mel = torch.from_numpy(dvec_mel).float()
                dvec = embedder(dvec_mel)
                dvec = dvec.unsqueeze(0)

                mixed_wav, _ = librosa.load(res[0], sr=16000)
                mag, phase = audio.wav2spec(mixed_wav)
                mag = torch.from_numpy(mag).float()
                mag = mag.unsqueeze(0)

                mask = model(mag, dvec)
                est_mag = mag * mask
                est_mag = est_mag[0].cpu().detach().numpy()
                est_wav = audio.spec2wav(est_mag, phase)

                os.makedirs('/root/voicefilter/res', exist_ok=True)
                out_path = os.path.join('/root/voicefilter/res', f'{res[2]}')
                librosa.output.write_wav(out_path, est_wav, sr=16000)
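
# Assumed layout of out1.txt, recovered from how res[0..2] are used above
# (tab-separated: mixed path, reference path, output file name; the paths
# shown are illustrative):
#
#   mixed/0001-mixed.wav<TAB>ref/0001-ref.wav<TAB>0001-result.wav
#   mixed/0002-mixed.wav<TAB>ref/0002-ref.wav<TAB>0002-result.wav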
# dir = '/data/our_dataset/test/3/joint'

# Load the models once, outside the per-speaker loop.
model = VoiceFilter(hp).cuda()
chkpt_model = torch.load(args.checkpoint_path, map_location='cuda:0')['model']
model.load_state_dict(chkpt_model)
model.eval()

embedder = SpeechEmbedder(hp).cuda()
chkpt_embed = torch.load(args.embedder_path)
embedder.load_state_dict(chkpt_embed)
embedder.eval()

audio = Audio(hp)

speaker_count = 0
for dir in dirs:
    speaker_count = speaker_count + 1
    print("Speaker : {}/56\n".format(speaker_count))

    tree = dir.split('/')
    speaker_id = tree[-2]
    hp.data.test_dir = dir
    testloader = create_dataloader(hp, args, train=False)

    for batch in testloader:
        # length of batch is 1, set in dataloader
        (ref_mel, eliminated_wav, mixed_wav, expected_hidden_wav,
         eliminated_mag, expected_hidden_mag, mixed_mag, mixed_phase,
         dvec_path, eliminated_wav_path, mixed_wav_path) = batch[0]
        # print("expected_focused: {}".format(expected_focused_wav_path))
        print("Mixed: {}".format(mixed_wav_path))

        dvec_wav, _ = librosa.load(dvec_path, sr=16000)
        ref_mel = audio.get_mel(dvec_wav)
        ref_mel = torch.from_numpy(ref_mel).float().cuda()
        dvec = embedder(ref_mel)
def main():
    # Training settings. Assumes argparse, StepLR (from torch.optim.lr_scheduler)
    # and the project modules used below are imported.
    parser = argparse.ArgumentParser(description='PyTorch Voice Filter')
    parser.add_argument('-b', '--base_dir', type=str, default='.',
                        help="Root directory of run.")
    parser.add_argument('--checkpoint_path', type=str, default=None,
                        help='Path to last checkpoint')
    parser.add_argument('-e', '--embedder_path', type=str, required=True,
                        help="path of embedder model pt file")
    parser.add_argument('-m', '--model', type=str, required=True,
                        help="Name of the model. Used for both logging and saving checkpoints.")
    args = parser.parse_args()

    chkpt_path = args.checkpoint_path

    pt_dir = os.path.join(args.base_dir, config.log['chkpt_dir'], args.model)
    os.makedirs(pt_dir, exist_ok=True)
    log_dir = os.path.join(args.base_dir, config.log['log_dir'], args.model)
    os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(
                os.path.join(log_dir, '%s-%d.log' % (args.model, time.time()))),
            logging.StreamHandler()
        ])
    logger = logging.getLogger()
    writer = MyWriter(log_dir)

    trainloader = create_dataloader(train=True)
    testloader = create_dataloader(train=False)

    embedder_pt = torch.load(args.embedder_path)
    embedder = SpeechEmbedder().cuda()
    embedder.load_state_dict(embedder_pt)
    embedder.eval()

    # .cuda() so DataParallel finds the parameters on the primary GPU
    model = nn.DataParallel(VoiceFilter()).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.train['adam'])
    audio = Audio()

    starting_step = 0
    starting_epoch = 1
    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint_file = torch.load(chkpt_path)
        model.load_state_dict(checkpoint_file['model'])
        starting_epoch = checkpoint_file['epoch']
        starting_step = checkpoint_file['step']
    else:
        logger.info("Starting new training run")

    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    for epoch in range(starting_epoch, config.train['epoch'] + 1):
        train(embedder, model, optimizer, trainloader, writer, logger, epoch,
              pt_dir, starting_step)
        validate(audio, model, embedder, testloader, writer, epoch)
        scheduler.step()
        starting_step = 0

    model_saver(model, pt_dir, config.train['epoch'],
                config.train['train_step_pre_epoch'])
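
# model_saver() is called above but not shown in these snippets. A plausible
# minimal version, matching the checkpoint keys the resume path reads
# ('model', 'epoch', 'step'); the filename pattern is an assumption.
def model_saver(model, pt_dir, epoch, step):
    save_path = os.path.join(pt_dir, 'chkpt_%d.pt' % epoch)  # hypothetical naming
    torch.save({
        'model': model.state_dict(),
        'epoch': epoch,
        'step': step,
    }, save_path)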
if hp.data.train_dir == '' or hp.data.test_dir == '':
    logger.error("train_dir, test_dir cannot be empty.")
    raise Exception("Please specify directories of data in %s" % args.config)

writer = MyWriter(hp, log_dir)

torch.cuda.set_device(args.gpu)
embedder_pt = torch.load(args.embedder_path)
embedder = SpeechEmbedder(hp).cuda()
embedder.load_state_dict(embedder_pt)
embedder.eval()

audio = Audio(hp)
model = VoiceFilter(hp).cuda()
checkpoint = torch.load(chkpt_path, map_location='cuda:0')
model.load_state_dict(checkpoint['model'])

step = 1
if args.hide:
    testloader = create_hide_dataloader(hp, args, train=False)
else:
    testloader = create_focus_dataloader(hp, args, train=False)

while True:
    if args.hide:
        validate_hide(audio, model, embedder, testloader, writer, step)
    else:
        validate_focus(audio, model, embedder, testloader, writer, step)
    step = step + 1
    print(step)
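
# A sketch of the setup this evaluation loop assumes; the flag names mirror
# the attributes used above (args.config, args.gpu, args.hide, ...), HParam
# usage follows the earlier snippets, and the module path and defaults are
# illustrative.
import argparse
from utils.hparams import HParam  # assumed module path

parser = argparse.ArgumentParser(description='VoiceFilter evaluation')
parser.add_argument('-c', '--config', type=str, required=True)
parser.add_argument('-e', '--embedder_path', type=str, required=True)
parser.add_argument('--checkpoint_path', type=str, required=True)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--hide', action='store_true')
args = parser.parse_args()

hp = HParam(args.config)
chkpt_path = args.checkpoint_path
log_dir = os.path.join('.', 'logs')  # illustrative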
def train(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger,
          hp, hp_str):
    # load embedder (kept frozen during training)
    embedder_pt = torch.load(args.embedder_path)
    embedder = SpeechEmbedder(hp).cuda()
    embedder.load_state_dict(embedder_pt)
    # embedder = nn.DataParallel(embedder)
    embedder.eval()

    audio = Audio(hp)
    model = VoiceFilter(hp).cuda()
    # model = nn.DataParallel(model)

    if hp.train.optimizer == 'adabound':
        optimizer = AdaBound(model.parameters(),
                             lr=hp.train.adabound.initial,
                             final_lr=hp.train.adabound.final)
    elif hp.train.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=hp.train.adam)
    else:
        raise Exception("%s optimizer not supported" % hp.train.optimizer)

    step = 0
    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        step = checkpoint['step']
        # will use new given hparams.
        if hp_str != checkpoint['hp_str']:
            logger.warning("New hparams is different from checkpoint.")
    else:
        logger.info("Starting new training run")

    try:
        criterion = nn.MSELoss()
        while True:  # runs until interrupted or the loss explodes
            model.train()
            for dvec_mels, target_mag, mixed_mag in trainloader:
                target_mag = target_mag.cuda()
                mixed_mag = mixed_mag.cuda()

                # Embed each reference utterance; gradients do not flow into
                # the embedder.
                dvec_list = list()
                for mel in dvec_mels:
                    mel = mel.cuda()
                    dvec = embedder(mel)
                    dvec_list.append(dvec)
                dvec = torch.stack(dvec_list, dim=0)
                dvec = dvec.detach()

                mask = model(mixed_mag, dvec)
                output = mixed_mag * mask

                # output = torch.pow(torch.clamp(output, min=0.0), hp.audio.power)
                # target_mag = torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power)
                loss = criterion(output, target_mag)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                step += 1

                loss = loss.item()
                if loss > 1e8 or math.isnan(loss):
                    logger.error("Loss exploded to %.02f at step %d!" % (loss, step))
                    raise Exception("Loss exploded")

                # write loss to tensorboard
                if step % hp.train.summary_interval == 0:
                    writer.log_training(loss, step)
                    logger.info("Wrote summary at step %d" % step)

                # 1. save checkpoint file to resume training
                # 2. evaluate and save sample to tensorboard
                if step % hp.train.checkpoint_interval == 0:
                    save_path = os.path.join(pt_dir, 'chkpt_%d.pt' % step)
                    torch.save({
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'step': step,
                        'hp_str': hp_str,
                    }, save_path)
                    logger.info("Saved checkpoint to: %s" % save_path)
                    validate(audio, model, embedder, testloader, writer, step)
    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()
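
# validate() is called above but not defined in these snippets. A minimal
# sketch consistent with the call site, assuming the test loader yields the
# same (dvec_mels, target_mag, mixed_mag) batches as the train loader and
# that MyWriter exposes a log_evaluation(loss, step) method -- both are
# assumptions about code not shown here.
def validate(audio, model, embedder, testloader, writer, step):
    model.eval()
    criterion = nn.MSELoss()
    with torch.no_grad():
        for dvec_mels, target_mag, mixed_mag in testloader:
            target_mag = target_mag.cuda()
            mixed_mag = mixed_mag.cuda()
            dvec = torch.stack([embedder(mel.cuda()) for mel in dvec_mels], dim=0)
            mask = model(mixed_mag, dvec)
            output = mixed_mag * mask
            test_loss = criterion(output, target_mag).item()
            writer.log_evaluation(test_loss, step)  # assumed writer method
            break  # a single batch is enough for a periodic check
    model.train()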
def train(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger,
          hp, hp_str):
    # load embedder
    torch.cuda.set_device("cuda:1")
    print(torch.cuda.current_device())
    embedder_pt = torch.load(args.embedder_path)
    embedder = SpeechEmbedder(hp).cuda()
    embedder.load_state_dict(embedder_pt)
    embedder.eval()

    audio = Audio(hp)
    model = VoiceFilter(hp).cuda()
    ### Multi-GPU
    # model = VoiceFilter(hp)
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model = nn.DataParallel(model)
    # model.to(device)

    if hp.train.optimizer == 'adabound':
        optimizer = AdaBound(model.parameters(),
                             lr=hp.train.adabound.initial,
                             final_lr=hp.train.adabound.final)
    elif hp.train.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=hp.train.adam)
    else:
        raise Exception("%s optimizer not supported" % hp.train.optimizer)

    if hp.scheduler.type == 'oneCycle':
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=hp.scheduler.oneCycle.max_lr,
            epochs=hp.train.epoch,
            steps_per_epoch=len(trainloader))
    elif hp.scheduler.type == 'Plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=hp.scheduler.Plateau.mode,
            patience=hp.scheduler.Plateau.patience,
            factor=hp.scheduler.Plateau.factor,
            verbose=True)

    step = 0
    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # for oneCycleLR: re-create the scheduler on top of the restored
        # optimizer (guarded by type so a Plateau run keeps its scheduler)
        if hp.scheduler.type == 'oneCycle':
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=hp.scheduler.oneCycle.max_lr,
                epochs=hp.train.epoch,
                steps_per_epoch=len(trainloader))
        step = checkpoint['step']
        # will use new given hparams.
        if hp_str != checkpoint['hp_str']:
            logger.warning("New hparams is different from checkpoint.")
    else:
        logger.info("Starting new training run")

    try:
        criterion = nn.MSELoss()
        for i_epoch in range(hp.train.epoch):
            model.train()
            for dvec_mels, target_mag, mixed_mag in trainloader:
                target_mag = target_mag.cuda()
                mixed_mag = mixed_mag.cuda()

                dvec_list = list()
                for mel in dvec_mels:
                    mel = mel.cuda()
                    dvec = embedder(mel)
                    dvec_list.append(dvec)
                dvec = torch.stack(dvec_list, dim=0)
                dvec = dvec.detach()

                mask = model(mixed_mag, dvec)
                output = mixed_mag * mask

                # output = torch.pow(torch.clamp(output, min=0.0), hp.audio.power)
                # target_mag = torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power)
                loss = criterion(output, target_mag)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if hp.scheduler.type == 'oneCycle':
                    scheduler.step()
                elif hp.scheduler.type == 'Plateau':
                    scheduler.step(loss)
                step += 1

                loss = loss.item()
                if loss > 1e8 or math.isnan(loss):
                    logger.error("Loss exploded to %.02f at step %d!" % (loss, step))
                    raise Exception("Loss exploded")

                # write loss to tensorboard
                if step % hp.train.summary_interval == 0:
                    writer.log_training(loss, step)
                    logger.info("Wrote summary at step %d in epoch %d" % (step, i_epoch))

                if step % hp.train.validation_interval == 0:
                    validate(audio, model, embedder, testloader, writer, step)

                # 1. save checkpoint file to resume training
                # 2. evaluate and save sample to tensorboard
                if step % hp.train.checkpoint_interval == 0:
                    save_path = os.path.join(pt_dir, 'chkpt_%d.pt' % step)
                    # save_dict_path = os.path.join(pt_dir, 'chkpt_%d_dict.pt' % step)
                    torch.save({
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'step': step,
                        'hp_str': hp_str,
                    }, save_path)
                    # torch.save(model.module.state_dict(), save_dict_path)
                    logger.info("Saved checkpoint to: %s" % save_path)
    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()
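
# Quick sanity check of a saved checkpoint; the path is illustrative, but the
# keys match exactly what torch.save() writes above.
ckpt = torch.load('chkpt/mymodel/chkpt_100000.pt', map_location='cpu')
print(sorted(ckpt.keys()))   # ['hp_str', 'model', 'optimizer', 'step']
print(ckpt['step'])          # global step at which the checkpoint was written
print(ckpt['hp_str'][:200])  # start of the hyperparameter string stored with it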