def __getitem__(self, idx):
    """Load, transform and pad the waveform/feature pair at position ``idx``.

    Parameters
    ----------
    idx : int
        index into ``self.wav_list`` / ``self.feat_list``

    Returns
    -------
    dict
        ``'x'`` (padded waveform tensor), ``'feat'`` (padded feature
        tensor), ``'slen'`` / ``'flen'`` (lengths before padding) and
        ``'featfile'`` (path of the feature file).  When
        ``self.wav_transform_in`` is set, ``'x_t'`` (padded LongTensor of
        the input-transformed waveform) is included as the first key.
    """
    wav_path = self.wav_list[idx]
    featfile = self.feat_list[idx]

    samples, _ = sf.read(wav_path, dtype=np.float32)

    # use the configured hdf5 path when present, otherwise the "org" one
    if check_hdf5(featfile, self.string_path):
        feat_key = self.string_path
    else:
        feat_key = self.string_path_org
    feats = read_hdf5(featfile, feat_key)

    samples, feats = validate_length(samples, feats, self.upsampling_factor)

    # the input-side transform sees the waveform before wav_transform is applied
    with_input = self.wav_transform_in is not None
    if with_input:
        x_t = self.wav_transform_in(samples)

    # LongTensor is used only when wav_transform is applied without wav_transform_out
    as_long = False
    if self.wav_transform is not None:
        samples = self.wav_transform(samples)
        if self.wav_transform_out is not None:
            samples = self.wav_transform_out(samples)
        else:
            as_long = True

    # lengths are recorded before padding
    slen = samples.shape[0]
    flen = feats.shape[0]

    feat_tensor = torch.FloatTensor(self.pad_feat_transform(feats))
    padded_wav = self.pad_wav_transform(samples)
    if as_long:
        wav_tensor = torch.LongTensor(padded_wav)
    else:
        wav_tensor = torch.FloatTensor(padded_wav)

    item = {}
    if with_input:
        item['x_t'] = torch.LongTensor(self.pad_wav_transform(x_t))
    item['x'] = wav_tensor
    item['feat'] = feat_tensor
    item['slen'] = slen
    item['flen'] = flen
    item['featfile'] = featfile
    return item
def _crop_class_batch(batch_x_class, batch_x, batch_h, h_bs, x_bs, h_ss, x_ss):
    """Crop one generator batch and reshape the network inputs.

    The features are cropped to ``[h_ss:h_ss+h_bs]`` and the waveform to
    ``[x_ss:x_ss+x_bs]`` (``h_bs == -1`` means "until the end").  The
    network input drops its last sample and the target classes drop their
    first sample, so each target is the sample following its input.

    Returns
    -------
    tuple
        ``(batch_x_class, batch_x, batch_h)`` where ``batch_x`` and
        ``batch_h`` have been transposed on their first two axes and given
        a leading axis of size 1.
    """
    batch_h = batch_h[h_ss:]
    batch_x_class = batch_x_class[x_ss:]
    batch_x = batch_x[x_ss:]
    if h_bs != -1:
        batch_h = batch_h[:h_bs]
        batch_x_class = batch_x_class[1:x_bs]
        batch_x = batch_x[:x_bs-1]
    else:
        batch_x = batch_x[:-1]
        batch_x_class = batch_x_class[1:]
    batch_h = batch_h.transpose(0, 1).unsqueeze(0)
    batch_x = batch_x.transpose(0, 1).unsqueeze(0)
    return batch_x_class, batch_x, batch_h


def _cross_entropy_batch_loss(criterion, batch_output, batch_x_class, receptive_field, h_ss):
    """Return the batch loss, skipping the first ``receptive_field`` steps
    when the batch does not start at the beginning of the utterance."""
    if h_ss > 0:
        return criterion(batch_output[receptive_field:], batch_x_class[receptive_field:])
    return criterion(batch_output, batch_x_class)


def main():
    """Train a DSWNV model from the command line.

    Parses the arguments, configures logging into ``<expdir>/train.log``,
    builds the model, fixes its input scaling layer from the statistics
    file, then alternates training epochs with an evaluation pass.  A
    checkpoint is saved after every epoch and ``checkpoint-final.pkl`` at
    the end.  Exits with status 1 when no GPU is available or when the
    waveform arguments are neither a directory nor a list file.
    """
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms", required=True,
                        type=str, help="directory or list of wav files")
    parser.add_argument("--waveforms_eval", required=True,
                        type=str, help="directory or list of evaluation wav files")
    parser.add_argument("--feats", required=True,
                        type=str, help="directory or list of aux feat files")
    parser.add_argument("--feats_eval", required=True,
                        type=str, help="directory or list of evaluation aux feat files")
    parser.add_argument("--stats", required=True,
                        type=str, help="hdf5 file including statistics")
    parser.add_argument("--expdir", required=True,
                        type=str, help="directory to save the model")
    # network structure setting
    parser.add_argument("--n_quantize", default=256,
                        type=int, help="number of quantization")
    parser.add_argument("--n_aux", default=39,
                        type=int, help="number of dimension of aux feats")
    parser.add_argument("--dilation_depth", default=3,
                        type=int, help="depth of dilation")
    parser.add_argument("--dilation_repeat", default=3,
                        type=int, help="repeat of dilation depth")
    parser.add_argument("--hid_chn", default=192,
                        type=int, help="number of hidden channels")
    parser.add_argument("--skip_chn", default=256,
                        type=int, help="number of channels of skip output")
    parser.add_argument("--kernel_size", default=6,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--aux_kernel_size", default=3,
                        type=int, help="kernel size for aux features")
    parser.add_argument("--aux_dilation_size", default=2,
                        type=int, help="dilation size for aux features")
    parser.add_argument("--upsampling_factor", default=110,
                        type=int, help="upsampling factor of aux features "
                                       "(if set 0, do not apply)")
    parser.add_argument("--string_path", default="/feat_org_lf0",
                        type=str, help="hdf5 path of the aux features inside the feat files")
    # network training setting
    parser.add_argument("--lr", default=1e-4,
                        type=float, help="learning rate")
    parser.add_argument("--batch_size", default=1100,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--epoch_count", default=500,
                        type=int, help="number of training epochs")
    parser.add_argument("--do_prob", default=0,
                        type=float, help="dropout probability")
    parser.add_argument("--wav_conv_flag", default=False,
                        type=strtobool, help="flag to use 1d conv of wav")
    # other setting
    parser.add_argument("--audio_in", default=False,
                        type=strtobool, help="flag for including previous sample in conditioning feat")
    parser.add_argument("--seed", default=1,
                        type=int, help="seed number")
    parser.add_argument("--resume", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--pretrained", default=None,
                        type=str, help="model path to initialize from (training restarts at epoch 0)")
    parser.add_argument("--GPU_device", default=None,
                        type=int, help="selection of GPU device")
    parser.add_argument("--verbose", default=1,
                        type=int, help="log level")
    args = parser.parse_args()

    if args.GPU_device is not None:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_device)

    # make experimental directory (exist_ok avoids the check-then-create race)
    os.makedirs(args.expdir, exist_ok=True)

    # set log level
    if args.verbose == 1:
        log_level = logging.INFO
    elif args.verbose > 1:
        log_level = logging.DEBUG
    else:
        log_level = logging.WARN
    logging.basicConfig(level=log_level,
                        format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                        datefmt='%m/%d/%Y %I:%M:%S',
                        filename=args.expdir + "/train.log")
    logging.getLogger().addHandler(logging.StreamHandler())
    if args.verbose < 1:
        logging.warning("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # NOTE: benchmark mode is faster; for reproducibility (slower) use
    # cudnn.deterministic = True and cudnn.benchmark = False instead.
    torch.backends.cudnn.benchmark = True

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    model = DSWNV(
        n_quantize=args.n_quantize,
        n_aux=args.n_aux,
        hid_chn=args.hid_chn,
        skip_chn=args.skip_chn,
        dilation_depth=args.dilation_depth,
        dilation_repeat=args.dilation_repeat,
        kernel_size=args.kernel_size,
        aux_kernel_size=args.aux_kernel_size,
        aux_dilation_size=args.aux_dilation_size,
        audio_in_flag=args.audio_in,
        do_prob=args.do_prob,
        wav_conv_flag=args.wav_conv_flag,
        upsampling_factor=args.upsampling_factor)
    logging.info(model)
    criterion = nn.CrossEntropyLoss()

    # define transforms: look the statistics up under the possible key spellings
    string_path_name = args.string_path.split('feat_')[1]
    logging.info(string_path_name)
    scaler = StandardScaler()
    if check_hdf5(args.stats, "/mean_"+string_path_name):
        scaler.mean_ = read_hdf5(args.stats, "/mean_"+string_path_name)
        scaler.scale_ = read_hdf5(args.stats, "/scale_"+string_path_name)
    elif check_hdf5(args.stats, "/mean_"+args.string_path):
        scaler.mean_ = read_hdf5(args.stats, "/mean_"+args.string_path)
        scaler.scale_ = read_hdf5(args.stats, "/scale_"+args.string_path)
    else:
        scaler.mean_ = read_hdf5(args.stats, "/mean_feat_"+string_path_name)
        scaler.scale_ = read_hdf5(args.stats, "/scale_feat_"+string_path_name)
    mean_src = torch.FloatTensor(scaler.mean_)
    std_src = torch.FloatTensor(scaler.scale_)

    # send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
        mean_src = mean_src.cuda()
        std_src = std_src.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    model.train()
    model.apply(initialize)
    # the input scaling layer computes (x - mean) / std and is kept frozen
    model.scale_in.weight = torch.nn.Parameter(torch.unsqueeze(torch.diag(1.0/std_src.data), 2))
    model.scale_in.bias = torch.nn.Parameter(-(mean_src.data/std_src.data))

    for param in model.parameters():
        param.requires_grad = True
    for param in model.scale_in.parameters():
        param.requires_grad = False
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000
    logging.info('Trainable Parameters: %.3f million' % parameters)

    # only these sub-modules are handed to the optimizer
    module_list = list(model.conv_aux.parameters()) + list(model.upsampling.parameters())
    if model.wav_conv_flag:
        module_list += list(model.wav_conv.parameters())
    module_list += list(model.causal.parameters())
    module_list += list(model.in_x.parameters()) + list(model.dil_h.parameters())
    module_list += list(model.out_skip.parameters())
    module_list += list(model.out_1.parameters()) + list(model.out_2.parameters())
    optimizer = torch.optim.Adam(module_list, lr=args.lr)

    # resume
    if args.pretrained is not None:
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint["model"])
        epoch_idx = checkpoint["iterations"]
        logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
        # pretrained weights only: training itself starts from epoch 0
        epoch_idx = 0
    elif args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % epoch_idx)
    else:
        epoch_idx = 0

    wav_transform = transforms.Compose([lambda x: encode_mu_law(x, args.n_quantize)])

    # define generator training
    if os.path.isdir(args.waveforms):
        filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [args.feats + "/" + filename.replace(".wav", ".h5") for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    generator = train_generator(
        wav_list, feat_list,
        model.receptive_field,
        string_path=args.string_path,
        batch_size=args.batch_size,
        wav_transform=wav_transform,
        training=True,
        upsampling_factor=args.upsampling_factor)

    # define generator evaluation
    if os.path.isdir(args.waveforms_eval):
        filenames_eval = sorted(find_files(args.waveforms_eval, "*.wav", use_dir_name=False))
        wav_list_eval = [args.waveforms_eval + "/" + filename for filename in filenames_eval]
        feat_list_eval = [args.feats_eval + "/" + filename.replace(".wav", ".h5")
                          for filename in filenames_eval]
    elif os.path.isfile(args.waveforms_eval):
        wav_list_eval = read_txt(args.waveforms_eval)
        feat_list_eval = read_txt(args.feats_eval)
    else:
        logging.error("--waveforms_eval should be directory or list.")
        sys.exit(1)
    logging.info("number of evaluation data = %d." % len(wav_list_eval))
    assert len(wav_list_eval) == len(feat_list_eval)
    generator_eval = train_generator(
        wav_list_eval, feat_list_eval,
        model.receptive_field,
        string_path=args.string_path,
        batch_size=args.batch_size,
        wav_transform=wav_transform,
        training=False,
        upsampling_factor=args.upsampling_factor)

    # train
    loss = []
    total = 0
    iter_idx = 0
    iter_count = 0
    min_eval_loss = 99999999.99
    min_eval_loss_std = 99999999.99
    min_idx = -1
    if args.resume is not None:
        np.random.set_state(checkpoint["numpy_random_state"])
        torch.set_rng_state(checkpoint["torch_random_state"])
    logging.info("==%d EPOCH==" % (epoch_idx+1))
    logging.info("Training data")
    while epoch_idx < args.epoch_count:
        start = time.time()
        batch_x_class, batch_x, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = next(generator)
        if c_idx < 0:  # a negative c_idx marks the end of an epoch: summarize it
            numpy_random_state = np.random.get_state()
            torch_random_state = torch.get_rng_state()
            # save current epoch model
            save_checkpoint(args.expdir, model, optimizer, numpy_random_state, torch_random_state, epoch_idx+1)
            # report current epoch (max() avoids ZeroDivisionError on an empty epoch)
            logging.info("(EPOCH:%d) average training loss = %.6f (+- %.6f) (%.3f min., %.3f sec / batch)" % (
                epoch_idx + 1, np.mean(np.array(loss, dtype=np.float64)),
                np.std(np.array(loss, dtype=np.float64)), total / 60.0, total / max(iter_count, 1)))
            logging.info(("estimated training required time = {0.days:02}:{0.hours:02}:{0.minutes:02}:"
                          "{0.seconds:02}").format(
                              relativedelta(seconds=int((args.epoch_count - (epoch_idx + 1)) * total))))
            # compute loss in evaluation data
            loss = []
            total = 0
            iter_count = 0
            model.eval()
            for param in model.parameters():
                param.requires_grad = False
            logging.info("Evaluation data")
            with torch.no_grad():
                while True:
                    start = time.time()
                    batch_x_class, batch_x, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = \
                        next(generator_eval)
                    if c_idx < 0:
                        break

                    # lengths before cropping, only used for logging
                    tf = batch_h.shape[0]
                    ts = batch_x.shape[0]
                    batch_x_class, batch_x, batch_h = _crop_class_batch(
                        batch_x_class, batch_x, batch_h, h_bs, x_bs, h_ss, x_ss)
                    batch_output = model(batch_x, batch_h)[0]
                    batch_loss = _cross_entropy_batch_loss(
                        criterion, batch_output, batch_x_class, model.receptive_field, h_ss)
                    loss.append(batch_loss.item())

                    logging.info("batch eval loss %s [%d:%d] %d %d %d %d %d %d = %.3f (%.3f sec)" % (
                        os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1,
                        utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss.item(), time.time() - start))
                    iter_count += 1
                    total += time.time() - start
            eval_loss = np.mean(np.array(loss, dtype=np.float64))
            eval_loss_std = np.std(np.array(loss, dtype=np.float64))
            logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) (%.3f min., %.3f sec / batch)" % (
                epoch_idx + 1, eval_loss, eval_loss_std, total / 60.0, total / max(iter_count, 1)))
            # the best epoch is the one minimizing mean + std of the evaluation loss
            if (eval_loss+eval_loss_std) <= (min_eval_loss+min_eval_loss_std):
                min_eval_loss = eval_loss
                min_eval_loss_std = eval_loss_std
                min_idx = epoch_idx
            logging.info("min_eval_loss=%.6f (+- %.6f), min_idx=%d" % (
                min_eval_loss, min_eval_loss_std, min_idx+1))
            loss = []
            total = 0
            iter_count = 0
            epoch_idx += 1
            # restore the RNG states captured before evaluation
            np.random.set_state(numpy_random_state)
            torch.set_rng_state(torch_random_state)
            model.train()
            for param in model.parameters():
                param.requires_grad = True
            for param in model.scale_in.parameters():
                param.requires_grad = False
            # start next epoch
            if epoch_idx < args.epoch_count:
                start = time.time()
                logging.info("==%d EPOCH==" % (epoch_idx+1))
                logging.info("Training data")
                batch_x_class, batch_x, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = \
                    next(generator)
        # feedforward and backpropagate current batch
        if epoch_idx < args.epoch_count:
            logging.info("%d iteration [%d]" % (iter_idx+1, epoch_idx+1))

            # lengths before cropping, only used for logging
            tf = batch_h.shape[0]
            ts = batch_x.shape[0]
            batch_x_class, batch_x, batch_h = _crop_class_batch(
                batch_x_class, batch_x, batch_h, h_bs, x_bs, h_ss, x_ss)
            batch_output = model(batch_x, batch_h, do=True)[0]
            batch_loss = _cross_entropy_batch_loss(
                criterion, batch_output, batch_x_class, model.receptive_field, h_ss)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss.append(batch_loss.item())
            logging.info("batch loss %s [%d:%d] %d %d %d %d %d %d = %.3f (%.3f sec)" % (
                os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1, utt_idx+1,
                tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss.item(), time.time() - start))
            iter_idx += 1
            iter_count += 1
            total += time.time() - start

    # save final model
    model.cpu()
    torch.save({"model": model.state_dict()}, args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
def main():
    """Compute feature statistics over a list of hdf5 files and store them.

    For every file listed in ``--feats`` the ``/feat_org_lf0`` features are
    accumulated into a ``StandardScaler``, the per-utterance variance of
    the dimensions from ``--stdim`` onwards is collected, and the non-zero
    values of ``/f0_range`` (or ``/f0`` when absent) are gathered.  The
    resulting mean/scale, variance statistics and (log-)F0 mean/std are
    written to the ``--stats`` hdf5 file.
    """
    import os  # local import: needed only to create the log directory

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--feats", default=None, required=True,
        help="name of the list of hdf5 files")
    parser.add_argument(
        "--stats", default=None, required=True,
        help="filename of hdf5 format")
    parser.add_argument("--expdir", required=True,
                        type=str, help="directory to save the log")
    parser.add_argument("--stdim", default=5,
                        type=int, help="first feature dimension used for the variance statistics")
    parser.add_argument("--spkr", default=None,
                        type=str, help="speaker name (not used by this script)")
    parser.add_argument(
        "--verbose", default=1,
        type=int, help="log message level")

    args = parser.parse_args()

    # the log file is opened inside expdir, so make sure the directory exists
    os.makedirs(args.expdir, exist_ok=True)

    # set log level
    if args.verbose == 1:
        log_level = logging.INFO
    elif args.verbose > 1:
        log_level = logging.DEBUG
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S',
        filename=args.expdir + "/calc_stats.log")
    logging.getLogger().addHandler(logging.StreamHandler())
    if args.verbose < 1:
        logging.warning("logging is disabled.")

    # read list and define scaler
    filenames = read_txt(args.feats)
    scaler_feat_org_lf0 = StandardScaler()
    logging.info("number of training utterances = " + str(len(filenames)))

    var_range = []
    f0s_range = np.empty((0))

    # process over all of data
    for filename in filenames:
        logging.info(filename)
        feat_org_lf0 = read_hdf5(filename, "/feat_org_lf0")
        scaler_feat_org_lf0.partial_fit(feat_org_lf0)
        # per-utterance variance of the dimensions from stdim onwards
        mcep_range = feat_org_lf0[:, args.stdim:]
        var_range.append(np.var(mcep_range, axis=0))
        logging.info(mcep_range.shape)
        if check_hdf5(filename, "/f0_range"):
            f0_range = read_hdf5(filename, "/f0_range")
        else:
            f0_range = read_hdf5(filename, "/f0")
        # only non-zero F0 values contribute to the F0 statistics
        nonzero_indices = np.nonzero(f0_range)
        logging.info(f0_range[nonzero_indices].shape)
        logging.info(f0s_range.shape)
        f0s_range = np.concatenate([f0s_range, f0_range[nonzero_indices]])
        logging.info(f0s_range.shape)

    mean_feat_org_lf0 = scaler_feat_org_lf0.mean_
    scale_feat_org_lf0 = scaler_feat_org_lf0.scale_
    gv_range_mean = np.mean(np.array(var_range), axis=0)
    gv_range_var = np.var(np.array(var_range), axis=0)
    logging.info(gv_range_mean)
    logging.info(gv_range_var)
    f0_range_mean = np.mean(f0s_range)
    f0_range_std = np.std(f0s_range)
    logging.info(f0_range_mean)
    logging.info(f0_range_std)
    lf0_range_mean = np.mean(np.log(f0s_range))
    lf0_range_std = np.std(np.log(f0s_range))
    logging.info(lf0_range_mean)
    logging.info(lf0_range_std)
    # diagnostic only: compares linear-domain stats with exp() of the log-domain stats
    logging.info(np.array_equal(f0_range_mean, np.exp(lf0_range_mean)))
    logging.info(np.array_equal(f0_range_std, np.exp(lf0_range_std)))

    logging.info(mean_feat_org_lf0)
    logging.info(scale_feat_org_lf0)
    write_hdf5(args.stats, "/mean_feat_org_lf0", mean_feat_org_lf0)
    write_hdf5(args.stats, "/scale_feat_org_lf0", scale_feat_org_lf0)
    write_hdf5(args.stats, "/gv_range_mean", gv_range_mean)
    write_hdf5(args.stats, "/gv_range_var", gv_range_var)
    write_hdf5(args.stats, "/f0_range_mean", f0_range_mean)
    write_hdf5(args.stats, "/f0_range_std", f0_range_std)
    write_hdf5(args.stats, "/lf0_range_mean", lf0_range_mean)
    write_hdf5(args.stats, "/lf0_range_std", lf0_range_std)
def world_feature_extract(queue, wav_list, args):
    """EXTRACT WORLD FEATURE VECTOR

    For every wav file the WORLD features are analyzed and the
    concatenation ``[uv, cont_f0_lpf, mcep, codeap]`` is written to the
    ``/world`` dataset of the corresponding feature file.  Optional
    datasets (``/f0``, ``/ap``, ``/spc``, ``/npow``, ``/world_extend``,
    ``/vad_idx``) are written according to the ``args.save_*`` flags.
    Files that already contain ``/world`` are skipped unless
    ``args.overwrite`` is set.  The process exits with status 1 when a wav
    file's sampling frequency differs from ``args.fs``.

    Parameters
    ----------
    queue : multiprocessing.Queue()
        the queue that receives 'Finish' once all files are processed
    wav_list : list
        list of the wav files
    args :
        feature extract arguments
    """
    # define feature extractor
    feature_extractor = FeatureExtractor(
        analyzer="world", fs=args.fs, shiftms=args.shiftms,
        minf0=args.minf0, maxf0=args.maxf0, fftl=args.fftl)

    # extraction
    for i, wav_name in enumerate(wav_list):
        # check exists
        if args.feature_dir is None:
            feat_name = wav_name.replace("wav", args.feature_format)
        else:
            feat_name = rootdir_replace(wav_name, extname=args.feature_format, newdir=args.feature_dir)
        if check_hdf5(feat_name, "/world"):
            if args.overwrite:
                logging.info("overwrite %s (%d/%d)" % (wav_name, i + 1, len(wav_list)))
            else:
                logging.info("skip %s (%d/%d)" % (wav_name, i + 1, len(wav_list)))
                continue
        else:
            logging.info("now processing %s (%d/%d)" % (wav_name, i + 1, len(wav_list)))

        # load wavfile
        fs, x = wavfile.read(wav_name)

        # check sampling frequency before any processing is spent on the file
        if fs != args.fs:
            logging.error("sampling frequency is not matched.")
            sys.exit(1)

        # apply low cut filter
        x = np.array(x, dtype=np.float32)
        if args.highpass_cutoff != 0:
            x = low_cut_filter(x, fs, cutoff=args.highpass_cutoff)

        # extract features (codeap/mcep/npow are read after analyze() has run)
        f0, spc, ap = feature_extractor.analyze(x)
        codeap = feature_extractor.codeap()
        mcep = feature_extractor.mcep(dim=args.mcep_dim, alpha=args.mcep_alpha)
        npow = feature_extractor.npow()
        uv, cont_f0 = convert_continuos_f0(f0)
        # sampling rate of the F0 contour: one value per frame shift
        lpf_fs = int(1.0 / (args.shiftms * 0.001))
        cont_f0_lpf = low_pass_filter(cont_f0, lpf_fs, cutoff=20)
        # retry with a doubling cutoff until every filtered value is positive
        next_cutoff = 70
        while not (cont_f0_lpf > [0]).all():
            logging.info("%s low-pass-filtered [%dHz]" % (feat_name, next_cutoff))
            cont_f0_lpf = low_pass_filter(cont_f0, lpf_fs, cutoff=next_cutoff)
            next_cutoff *= 2

        # concatenate
        cont_f0_lpf = np.expand_dims(cont_f0_lpf, axis=-1)
        uv = np.expand_dims(uv, axis=-1)
        feats = np.concatenate([uv, cont_f0_lpf, mcep, codeap], axis=1)

        # save feature
        write_hdf5(feat_name, "/world", feats)
        if args.save_f0:
            write_hdf5(feat_name, "/f0", f0)
        if args.save_ap:
            write_hdf5(feat_name, "/ap", ap)
        if args.save_spc:
            write_hdf5(feat_name, "/spc", spc)
        if args.save_npow:
            write_hdf5(feat_name, "/npow", npow)
        if args.save_extended:
            # extend time resolution
            upsampling_factor = int(args.shiftms * fs * 0.001)
            feats_extended = extend_time(feats, upsampling_factor)
            feats_extended = feats_extended.astype(np.float32)
            write_hdf5(feat_name, "/world_extend", feats_extended)
        if args.save_vad:
            _, vad_idx = extfrm(mcep, npow, power_threshold=args.pow_th)
            write_hdf5(feat_name, "/vad_idx", vad_idx)

    queue.put('Finish')
def world_speech_synthesis(queue, wav_list, args):
    """WORLD SPEECH SYNTHESIS

    For every wav file the stored ``/world`` features are read back and a
    waveform is re-synthesized from them and written as 16-bit wav.
    ``/f0`` and ``/ap`` datasets are used when present; otherwise F0 is
    rebuilt from the uv flag and ``args.f0_dim_idx`` column, and the
    aperiodicity is decoded from the columns starting at
    ``args.ap_dim_idx``.  Existing output files are skipped unless
    ``args.overwrite`` is set.  The process exits with status 1 when a
    feature file has no ``/world`` dataset.

    Parameters
    ----------
    queue : multiprocessing.Queue()
        the queue that receives 'Finish' once all files are processed
    wav_list : list
        list of the wav files
    args :
        feature extract arguments
    """
    # define synthesizer
    synthesizer = Synthesizer(fs=args.fs, fftl=args.fftl, shiftms=args.shiftms)

    # synthesis
    for i, wav_name in enumerate(wav_list):
        if args.feature_dir is None:
            restored_name = wav_name.replace("wav", args.feature_format + "_restored")
            # the first replace also rewrote the extension; turn it back into ".wav"
            restored_name = restored_name.replace(
                ".%s" % args.feature_format + "_restored", ".wav")
            feat_name = wav_name.replace("wav", args.feature_format)
        else:
            restored_name = rootdir_replace(wav_name, newdir=args.feature_dir + "restored")
            feat_name = rootdir_replace(wav_name, extname=args.feature_format, newdir=args.feature_dir)
        if os.path.exists(restored_name):
            if args.overwrite:
                logging.info("overwrite %s (%d/%d)" % (restored_name, i + 1, len(wav_list)))
            else:
                logging.info("skip %s (%d/%d)" % (restored_name, i + 1, len(wav_list)))
                continue
        else:
            logging.info("now processing %s (%d/%d)" % (restored_name, i + 1, len(wav_list)))

        # load acoustic features
        if check_hdf5(feat_name, "/world"):
            h = read_hdf5(feat_name, "/world")
        else:
            logging.error("%s is not existed." % (feat_name))
            sys.exit(1)
        if check_hdf5(feat_name, "/f0"):
            f0 = read_hdf5(feat_name, "/f0")
        else:
            # rebuild F0 from the stored contour, zeroing frames whose uv flag is 0
            uv = h[:, 0].copy(order='C')
            f0 = h[:, args.f0_dim_idx].copy(order='C')  # cont_f0_lpf
            fz_idx = np.where(uv == 0.0)
            f0[fz_idx] = 0.0
        if check_hdf5(feat_name, "/ap"):
            ap = read_hdf5(feat_name, "/ap")
        else:
            codeap = h[:, args.ap_dim_idx:].copy(order='C')
            ap = pyworld.decode_aperiodicity(codeap, args.fs, args.fftl)
        mcep = h[:, args.mcep_dim_start:args.mcep_dim_end].copy(order='C')

        # waveform synthesis; clip to the int16 range before the cast
        wav = synthesizer.synthesis(f0, mcep, ap, alpha=args.mcep_alpha)
        wav = np.clip(wav, -32768, 32767)
        wavfile.write(restored_name, args.fs, wav.astype(np.int16))

    queue.put('Finish')
def main(): parser = argparse.ArgumentParser() # path setting parser.add_argument("--waveforms", required=True, type=str, help="directory or list of wav files") parser.add_argument("--waveforms_eval", required=True, type=str, help="directory or list of evaluation wav files") parser.add_argument("--feats", required=True, type=str, help="directory or list of aux feat files") parser.add_argument("--feats_eval", required=True, type=str, help="directory or list of evaluation aux feat files") parser.add_argument("--stats", required=True, type=str, help="hdf5 file including statistics") parser.add_argument("--expdir", required=True, type=str, help="directory to save the model") # network structure setting parser.add_argument("--n_aux", default=54, type=int, help="number of dimension of aux feats") parser.add_argument("--skip_chn", default=256, type=int, help="number of channels of skip output") parser.add_argument("--seg", default=1, type=int, help="segment size") parser.add_argument("--dilation_depth", default=3, type=int, help="depth of dilation") parser.add_argument("--dilation_repeat", default=2, type=int, help="repeat of dilation depth") parser.add_argument("--hid_chn", default=192, type=int, help="kernel size of dilated causal convolution") parser.add_argument("--kernel_size", default=7, type=int, help="kernel size of dilated causal convolution") parser.add_argument("--aux_kernel_size", default=3, type=int, help="kernel size of dilated causal convolution") parser.add_argument("--aux_dilation_size", default=2, type=int, help="kernel size of dilated causal convolution") parser.add_argument("--upsampling_factor", default=110, type=int, help="upsampling factor of aux features" "(if set 0, do not apply)") parser.add_argument("--n_fft_facts", default=17, type=int, help="kernel size of dilated causal convolution") parser.add_argument("--string_path", default="/feat_org_lf0", type=str, help="directory to save the model") # network training setting 
parser.add_argument("--lr", default=1e-4, type=float, help="learning rate") parser.add_argument("--batch_size", default=8800, type=int, help="batch size (if set 0, utterance batch will be used)") parser.add_argument("--epoch_count", default=4000, type=int, help="number of training epochs") parser.add_argument("--do_prob", default=0, type=float, help="dropout probability") parser.add_argument("--lpc", default=0, type=int, help="number of linear predictive coefficients for location estimate") parser.add_argument("--aux_conv2d_flag", default=False, type=strtobool, help="flag to use 2d conv of aux") parser.add_argument("--wav_conv_flag", default=False, type=strtobool, help="flag to use 1d conv of wav") # other setting parser.add_argument("--seed", default=1, type=int, help="seed number") parser.add_argument("--resume", default=None, type=str, help="model path to restart training") parser.add_argument("--pretrained", default=None, type=str, help="model path to restart training") parser.add_argument("--GPU_device", default=None, type=int, help="selection of GPU device") parser.add_argument("--verbose", default=1, type=int, help="log level") args = parser.parse_args() if args.GPU_device is not None: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_device) # make experimental directory if not os.path.exists(args.expdir): os.makedirs(args.expdir) # set log level if args.verbose == 1: logging.basicConfig(level=logging.INFO, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', datefmt='%m/%d/%Y %I:%M:%S', filename=args.expdir + "/train.log") logging.getLogger().addHandler(logging.StreamHandler()) elif args.verbose > 1: logging.basicConfig(level=logging.DEBUG, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', datefmt='%m/%d/%Y %I:%M:%S', filename=args.expdir + "/train.log") logging.getLogger().addHandler(logging.StreamHandler()) else: logging.basicConfig(level=logging.WARN, 
format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', datefmt='%m/%d/%Y %I:%M:%S', filename=args.expdir + "/train.log") logging.getLogger().addHandler(logging.StreamHandler()) logging.warn("logging is disabled.") # fix seed os.environ['PYTHONHASHSEED'] = str(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.backends.cudnn.benchmark = True #faster #torch.backends.cudnn.deterministic = True #reproducibility_slower #torch.backends.cudnn.benchmark = False #reproducibility_slower # save args as conf torch.save(args, args.expdir + "/model.conf") # define network model = CSWNV( n_aux=args.n_aux, skip_chn=args.skip_chn, hid_chn=args.hid_chn, dilation_depth=args.dilation_depth, dilation_repeat=args.dilation_repeat, kernel_size=args.kernel_size, aux_kernel_size=args.aux_kernel_size, aux_dilation_size=args.aux_dilation_size, do_prob=args.do_prob, seg=args.seg, lpc=args.lpc, aux_conv2d_flag=args.aux_conv2d_flag, wav_conv_flag=args.wav_conv_flag, upsampling_factor=args.upsampling_factor) logging.info(model) criterion_lsd = LSDloss() criterion_laplace = LaplaceLoss() # define transforms string_path_name = args.string_path.split('feat_')[1] logging.info(string_path_name) scaler = StandardScaler() if check_hdf5(args.stats, "/mean_"+string_path_name): scaler.mean_ = read_hdf5(args.stats, "/mean_"+string_path_name) scaler.scale_ = read_hdf5(args.stats, "/scale_"+string_path_name) elif check_hdf5(args.stats, "/mean_"+args.string_path): scaler.mean_ = read_hdf5(args.stats, "/mean_"+args.string_path) scaler.scale_ = read_hdf5(args.stats, "/scale_"+args.string_path) else: scaler.mean_ = read_hdf5(args.stats, "/mean_feat_"+string_path_name) scaler.scale_ = read_hdf5(args.stats, "/scale_feat_"+string_path_name) mean_src = torch.FloatTensor(scaler.mean_) std_src = torch.FloatTensor(scaler.scale_) # send to gpu if torch.cuda.is_available(): model.cuda() criterion_lsd.cuda() criterion_laplace.cuda() mean_src = mean_src.cuda() std_src = std_src.cuda() 
else: logging.error("gpu is not available. please check the setting.") sys.exit(1) model.train() model.apply(initialize) model.scale_in.weight = torch.nn.Parameter(torch.unsqueeze(torch.diag(1.0/std_src.data),2)) model.scale_in.bias = torch.nn.Parameter(-(mean_src.data/std_src.data)) for param in model.parameters(): param.requires_grad = True for param in model.scale_in.parameters(): param.requires_grad = False parameters = filter(lambda p: p.requires_grad, model.parameters()) parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000 logging.info('Trainable Parameters: %.3f million' % parameters) module_list = list(model.conv_aux.parameters()) module_list += list(model.upsampling.parameters()) if model.aux_conv2d_flag and model.seg > 1: module_list += list(model.aux_conv2d.parameters()) if model.wav_conv_flag: module_list += list(model.wav_conv.parameters()) module_list += list(model.causal.parameters()) + list(model.in_x.parameters()) module_list += list(model.dil_h.parameters()) + list(model.out_skip.parameters()) module_list += list(model.out_1.parameters()) + list(model.out_2.parameters()) optimizer = torch.optim.Adam(module_list, lr=args.lr) # resume if args.pretrained is not None: checkpoint = torch.load(args.pretrained) model.load_state_dict(checkpoint["model"]) epoch_idx = checkpoint["iterations"] logging.info("pretrained from %d-iter checkpoint." % epoch_idx) epoch_idx = 0 elif args.resume is not None: checkpoint = torch.load(args.resume) model.load_state_dict(checkpoint["model"]) optimizer.load_state_dict(checkpoint["optimizer"]) epoch_idx = checkpoint["iterations"] logging.info("restored from %d-iter checkpoint." 
% epoch_idx) else: epoch_idx = 0 # define generator training if os.path.isdir(args.waveforms): filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False)) wav_list = [args.waveforms + "/" + filename for filename in filenames] feat_list = [args.feats + "/" + filename.replace(".wav", ".h5") for filename in filenames] elif os.path.isfile(args.waveforms): wav_list = read_txt(args.waveforms) feat_list = read_txt(args.feats) else: logging.error("--waveforms should be directory or list.") sys.exit(1) assert len(wav_list) == len(feat_list) logging.info("number of training data = %d." % len(wav_list)) if args.pretrained is None: generator = train_generator( wav_list, feat_list, model.receptive_field, string_path=args.string_path, seg=model.seg, batch_size=args.batch_size, training=True, upsampling_factor=args.upsampling_factor) else: generator = train_generator( wav_list, feat_list, model.receptive_field, string_path=args.string_path, seg=model.seg, batch_size=args.batch_size, training=True, upsampling_factor=args.upsampling_factor) # define generator evaluation if os.path.isdir(args.waveforms_eval): filenames_eval = sorted(find_files(args.waveforms_eval, "*.wav", use_dir_name=False)) wav_list_eval = [args.waveforms_eval + "/" + filename for filename in filenames_eval] feat_list_eval = [args.feats_eval + "/" + filename.replace(".wav", ".h5") \ for filename in filenames_eval] elif os.path.isfile(args.waveforms_eval): wav_list_eval = read_txt(args.waveforms_eval) feat_list_eval = read_txt(args.feats_eval) else: logging.error("--waveforms_eval should be directory or list.") sys.exit(1) logging.info("number of evaluation data = %d." 
% len(wav_list_eval)) assert len(wav_list_eval) == len(feat_list_eval) if args.pretrained is None: generator_eval = train_generator( wav_list_eval, feat_list_eval, model.receptive_field, string_path=args.string_path, seg=model.seg, batch_size=args.batch_size, training=False, upsampling_factor=args.upsampling_factor) else: generator_eval = train_generator( wav_list_eval, feat_list_eval, model.receptive_field, string_path=args.string_path, seg=model.seg, batch_size=args.batch_size, training=False, upsampling_factor=args.upsampling_factor) # train loss_laplace = [] loss_err = [] loss_lsd = [] fft_facts = [] init_fft = 64 hann_win = [None]*args.n_fft_facts if args.n_fft_facts == 5: fft_facts = [128, 256, 512, 1024, 2048] for i in range(args.n_fft_facts): hann_win[i] = torch.hann_window(fft_facts[i]).cuda() elif args.n_fft_facts == 9: fft_facts = [128, 192, 256, 384, 512, 768, 1024, 1536, 2048] for i in range(args.n_fft_facts): hann_win[i] = torch.hann_window(fft_facts[i]).cuda() elif args.n_fft_facts == 17: fft_facts = [128, 160, 192, 224, 256, 320, 384, 448, 512, 640, 768, 896, 1024, 1280, 1536, 1792, 2048] for i in range(args.n_fft_facts): hann_win[i] = torch.hann_window(fft_facts[i]).cuda() else: for i in range(args.n_fft_facts): if i % 2 == 0: init_fft *= 2 fft_facts.append(init_fft) else: fft_facts.append(init_fft+int(init_fft/2)) hann_win[i] = torch.hann_window(fft_facts[i]).cuda() logging.info(fft_facts) batch_stft_loss = [None]*args.n_fft_facts stft_out = [None]*args.n_fft_facts stft_trg = [None]*args.n_fft_facts total = 0 iter_idx = 0 iter_count = 0 min_eval_loss_lsd = 99999999.99 min_eval_loss_laplace = 99999999.99 min_eval_loss_err = 99999999.99 min_eval_loss_lsd_std = 99999999.99 min_eval_loss_laplace_std = 99999999.99 min_eval_loss_err_std = 99999999.99 min_idx = -1 if args.resume is not None: np.random.set_state(checkpoint["numpy_random_state"]) torch.set_rng_state(checkpoint["torch_random_state"]) logging.info("==%d EPOCH==" % (epoch_idx+1)) 
logging.info("Training data") #args.epoch_count = 5300 while epoch_idx < args.epoch_count: start = time.time() batch_x_float, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = next(generator) if c_idx < 0: # summarize epoch # save current epoch model numpy_random_state = np.random.get_state() torch_random_state = torch.get_rng_state() save_checkpoint(args.expdir, model, optimizer, numpy_random_state, torch_random_state, epoch_idx+1) # report current epoch logging.info("(EPOCH:%d) average training loss = %.6f (+- %.6f) %.6f dB (+- %.6f dB) %.6f "\ "(+- %.6f) (%.3f min., %.3f sec / batch)" % ( epoch_idx + 1, np.mean(loss_laplace), np.std(loss_laplace), np.mean(loss_lsd), \ np.std(loss_lsd), np.mean(loss_err), np.std(loss_err), total / 60.0, total / iter_count)) logging.info("estimated training required time = {0.days:02}:{0.hours:02}:{0.minutes:02}:"\ "{0.seconds:02}".format(relativedelta(seconds=int((args.epoch_count - (epoch_idx + 1)) * total)))) # compute loss in evaluation data loss_lsd = [] loss_err = [] loss_laplace = [] total = 0 iter_count = 0 model.eval() for param in model.parameters(): param.requires_grad = False logging.info("Evaluation data") while True: with torch.no_grad(): start = time.time() batch_x_float, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = \ next(generator_eval) if c_idx < 0: break tf = batch_h.shape[0] ts = batch_x_float.shape[0] batch_h = batch_h[h_ss:] batch_x_ = batch_x_float[x_ss:] if model.lpc > 0: if x_ss+model.lpc_offset >= 0: batch_x_lpc = batch_x_float[x_ss+model.lpc_offset:] else: batch_x_lpc = batch_x_float[x_ss:] if h_bs != -1: batch_h = batch_h[:h_bs] if model.lpc > 0: if x_ss+model.lpc_offset >= 0: batch_x_prob = batch_x_lpc[:x_bs-model.lpc_offset].unsqueeze(0) else: batch_x_prob = F.pad(batch_x_lpc[:x_bs], (-(x_ss+model.lpc_offset), 0), \ 'constant', 0).unsqueeze(0) batch_x = batch_x_[:x_bs-model.seg] batch_x_float = batch_x_[model.seg:x_bs] else: if model.lpc > 0: if x_ss+model.lpc_offset > 0: 
batch_x_prob = batch_x_lpc.unsqueeze(0) else: batch_x_prob = F.pad(batch_x_lpc, (-(x_ss+model.lpc_offset)), \ 'constant', 0).unsqueeze(0) batch_x = batch_x_[:-model.seg] batch_x_float = batch_x_[model.seg:] batch_h = batch_h.transpose(0,1).unsqueeze(0) batch_x = batch_x.unsqueeze(0).unsqueeze(1) if h_ss > 0: feat_len = batch_x_float[model.receptive_field:].shape[0] else: feat_len = batch_x_float.shape[0] if model.lpc > 0: mus, bs, log_bs, ass = model(batch_h, batch_x) # jump off s samples as in synthesis mus = mus[:,::model.seg,:] bs = bs[:,::model.seg,:] log_bs = log_bs[:,::model.seg,:] ass = ass[:,::model.seg,:].flip(-1) init_mus = mus for j in range(model.seg): tmp_smpls = batch_x_prob[:,j:-(model.seg-j)].unfold(1, model.lpc, model.seg) lpc = torch.sum(ass*tmp_smpls,-1,keepdim=True) if j > 0: mus = torch.cat((mus, lpc+init_mus[:,:,j:j+1]),2) else: mus = lpc+init_mus[:,:,j:j+1] mus = mus.reshape(mus.shape[0],-1) bs = bs.reshape(bs.shape[0],-1) log_bs = log_bs.reshape(log_bs.shape[0],-1) else: mus, bs, log_bs = model(batch_h, batch_x) if h_ss > 0: mus = mus[0,model.receptive_field:] bs = bs[0,model.receptive_field:] log_bs = log_bs[0,model.receptive_field:] batch_x_float = batch_x_float[model.receptive_field:] else: mus = mus[0] bs = bs[0] log_bs = log_bs[0] m_sum = 0 batch_loss_laplace = criterion_laplace(mus, bs, batch_x_float, log_b=log_bs) eps = torch.empty(mus.shape).cuda().uniform_(-0.4999,0.5) batch_output = mus-bs*eps.sign()*torch.log1p(-2*eps.abs()) batch_loss_err = torch.mean(torch.abs(batch_output-batch_x_float)) logging.info("%lf %E %lf %E" % (torch.min(batch_x_float), torch.mean(batch_x_float), \ torch.max(batch_x_float), torch.var(batch_x_float))) logging.info("%lf %E %lf %E" % (torch.min(batch_output), torch.mean(batch_output), \ torch.max(batch_output), torch.var(batch_output))) m = 0 for i in range(args.n_fft_facts): if feat_len > int(fft_facts[i]/2): stft_out[i] = torch.stft(batch_output, fft_facts[i], window=hann_win[i]) stft_trg[i] = 
torch.stft(batch_x_float, fft_facts[i], window=hann_win[i]) tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i]) if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss): if m > 0: batch_loss_lsd = torch.cat((batch_loss_lsd, \ tmp_batch_stft_loss.unsqueeze(0))) else: batch_loss_lsd = tmp_batch_stft_loss.unsqueeze(0) m += 1 loss_err.append(batch_loss_err.item()) loss_laplace.append(batch_loss_laplace.item()) if m > 0: batch_loss_lsd = torch.mean(batch_loss_lsd) loss_lsd.append(batch_loss_lsd.item()) logging.info("batch eval loss %s [%d:%d] %d %d %d %d %d %d = %.3f %.3f dB %.6f "\ "(%.3f sec)" % (os.path.join(os.path.basename(os.path.dirname(wavfile)),\ os.path.basename(wavfile)), c_idx+1, utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, \ batch_loss_laplace.item(), batch_loss_lsd.item(), \ batch_loss_err.item(), time.time() - start)) else: logging.info("batch eval loss %s [%d:%d] %d %d %d %d %d %d = %.3f n/a %.6f "\ "(%.3f sec)" % (os.path.join(os.path.basename(os.path.dirname(wavfile)),\ os.path.basename(wavfile)), c_idx+1, utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, \ batch_loss_laplace.item(), batch_loss_err.item(), time.time() - start)) iter_count += 1 total += time.time() - start eval_loss_lsd = np.mean(loss_lsd) eval_loss_lsd_std = np.std(loss_lsd) eval_loss_err = np.mean(loss_err) eval_loss_err_std = np.std(loss_err) eval_loss_laplace = np.mean(loss_laplace) eval_loss_laplace_std = np.std(loss_laplace) logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) %.6f dB (+- %.6f dB) %.6f "\ "(+- %.6f) (%.3f min., %.3f sec / batch)" % (epoch_idx + 1, eval_loss_laplace, \ eval_loss_laplace_std, eval_loss_lsd, eval_loss_lsd_std, eval_loss_err, eval_loss_err_std, \ total / 60.0, total / iter_count)) if (eval_loss_laplace+eval_loss_laplace_std+eval_loss_lsd+eval_loss_lsd_std+eval_loss_err\ +eval_loss_err_std) <= (min_eval_loss_laplace+min_eval_loss_laplace_std+min_eval_loss_lsd\ 
+min_eval_loss_lsd_std+min_eval_loss_err+min_eval_loss_err_std): min_eval_loss_lsd = eval_loss_lsd min_eval_loss_lsd_std = eval_loss_lsd_std min_eval_loss_err = eval_loss_err min_eval_loss_err_std = eval_loss_err_std min_eval_loss_laplace = eval_loss_laplace min_eval_loss_laplace_std = eval_loss_laplace_std min_idx = epoch_idx logging.info("min_eval_loss = %.6f (+- %.6f) %.6f dB (+- %.6f dB) %.6f (+- %.6f) min_idx=%d" % ( min_eval_loss_laplace, min_eval_loss_laplace_std, min_eval_loss_lsd, min_eval_loss_lsd_std, \ min_eval_loss_err, min_eval_loss_err_std, min_idx+1)) loss_lsd = [] loss_laplace = [] loss_err = [] total = 0 iter_count = 0 epoch_idx += 1 np.random.set_state(numpy_random_state) torch.set_rng_state(torch_random_state) model.train() for param in model.parameters(): param.requires_grad = True for param in model.scale_in.parameters(): param.requires_grad = False # start next epoch if epoch_idx < args.epoch_count: start = time.time() logging.info("==%d EPOCH==" % (epoch_idx+1)) logging.info("Training data") batch_x_float, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = next(generator) # feedforward and backpropagate current batch if epoch_idx < args.epoch_count: logging.info("%d iteration [%d]" % (iter_idx+1, epoch_idx+1)) tf = batch_h.shape[0] ts = batch_x_float.shape[0] batch_h = batch_h[h_ss:] batch_x_ = batch_x_float[x_ss:] if model.lpc > 0: if x_ss+model.lpc_offset >= 0: batch_x_lpc = batch_x_float[x_ss+model.lpc_offset:] else: batch_x_lpc = batch_x_float[x_ss:] if h_bs != -1: batch_h = batch_h[:h_bs] if model.lpc > 0: if x_ss+model.lpc_offset >= 0: batch_x_prob = batch_x_lpc[:x_bs-model.lpc_offset].unsqueeze(0) else: batch_x_prob = F.pad(batch_x_lpc[:x_bs], (-(x_ss+model.lpc_offset), 0), \ 'constant', 0).unsqueeze(0) batch_x = batch_x_[:x_bs-model.seg] batch_x_float = batch_x_[model.seg:x_bs] else: if model.lpc > 0: if x_ss+model.lpc_offset > 0: batch_x_prob = batch_x_lpc.unsqueeze(0) else: batch_x_prob = F.pad(batch_x_lpc, 
(-(x_ss+model.lpc_offset)), \ 'constant', 0).unsqueeze(0) batch_x = batch_x_[:-model.seg] batch_x_float = batch_x_[model.seg:] batch_h = batch_h.transpose(0,1).unsqueeze(0) batch_x = batch_x.unsqueeze(0).unsqueeze(1) if h_ss > 0: if model.seg > 1: feat_len = batch_x_float[model.receptive_field:-(model.seg-1)].shape[0] else: feat_len = batch_x_float[model.receptive_field:].shape[0] else: if model.seg > 1: feat_len = batch_x_float[:-(model.seg-1)].shape[0] else: feat_len = batch_x_float.shape[0] if model.lpc > 0: mus, bs_noclip, bs, log_bs, ass = model(batch_h, batch_x, do=True, clip=True) ass = ass.flip(-1) init_mus = mus for j in range(model.seg): tmp_smpls = batch_x_prob[:,j:-(model.seg-j)].unfold(1, model.lpc, 1) lpc = torch.sum(ass*tmp_smpls,-1,keepdim=True) if j > 0: mus = torch.cat((mus, lpc+init_mus[:,:,j:j+1]),2) else: mus = lpc+init_mus[:,:,j:j+1] if model.seg == 1: mus = mus.reshape(mus.shape[0], -1) bs_noclip = bs_noclip.reshape(mus.shape[0], -1) bs = bs.reshape(mus.shape[0], -1) log_bs = log_bs.reshape(mus.shape[0], -1) else: mus, bs_noclip, bs, log_bs = model(batch_h, batch_x, do=True, clip=True) if h_ss > 0: mus = mus[0,model.receptive_field:] bs_noclip = bs_noclip[0,model.receptive_field:] bs = bs[0,model.receptive_field:] log_bs = log_bs[0,model.receptive_field:] batch_x_float = batch_x_float[model.receptive_field:] else: mus = mus[0] bs_noclip = bs_noclip[0] bs = bs[0] log_bs = log_bs[0] m_sum = 0 if model.seg > 1: n_sum = 0 for i in range(model.seg): if i > 0: i_n = i+1 mus_i = mus[:,i:i_n].squeeze(-1) bs_noclip_i = bs_noclip[:,i:i_n].squeeze(-1) if i_n < model.seg: batch_x_float_i = batch_x_float[i:-(model.seg-(i_n))] else: batch_x_float_i = batch_x_float[i:] tmp_batch_loss_laplace = criterion_laplace(mus_i, bs[:,i:i_n].squeeze(-1), \ batch_x_float_i, log_b=log_bs[:,i:i_n].squeeze(-1), log=False) batch_loss_laplace = torch.cat((batch_loss_laplace, \ tmp_batch_loss_laplace.unsqueeze(0))) else: mus_i = mus[:,:1].squeeze(-1) bs_noclip_i = 
bs_noclip[:,:1].squeeze(-1) batch_x_float_i = batch_x_float[:-(model.seg-1)] tmp_batch_loss_laplace = criterion_laplace(mus_i, bs[:,:1].squeeze(-1), \ batch_x_float_i, log_b=log_bs[:,:1].squeeze(-1)) batch_loss_laplace = tmp_batch_loss_laplace.unsqueeze(0) eps = torch.empty(mus_i.shape).cuda().uniform_(-0.4999,0.5) batch_output = mus_i-bs_noclip_i*eps.sign()*torch.log1p(-2*eps.abs()) tmp_batch_loss_err = torch.mean(torch.abs(batch_output-batch_x_float_i)) if i > 0: batch_loss_err = torch.cat((batch_loss_err, tmp_batch_loss_err.unsqueeze(0))) else: batch_loss_err = tmp_batch_loss_err.unsqueeze(0) if i == 0: logging.info("%lf %E %lf %E" % (torch.min(batch_x_float_i), \ torch.mean(batch_x_float_i), torch.max(batch_x_float_i), torch.var(batch_x_float_i))) logging.info("%lf %E %lf %E" % (torch.min(batch_output), torch.mean(batch_output), \ torch.max(batch_output), torch.var(batch_output))) n = 0 for i in range(args.n_fft_facts): if feat_len > int(fft_facts[i]/2): stft_out[i] = torch.stft(batch_output, fft_facts[i], window=hann_win[i]) stft_trg[i] = torch.stft(batch_x_float_i, fft_facts[i], window=hann_win[i]) tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i], LSD=False, L2=False) if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss): if n > 0: tmp_batch_loss_stft_l1 = torch.cat((tmp_batch_loss_stft_l1, \ tmp_batch_stft_loss.unsqueeze(0))) else: tmp_batch_loss_stft_l1 = tmp_batch_stft_loss.unsqueeze(0) n += 1 if n > 0: if n_sum > 0: batch_loss_stft_l1 = torch.cat((batch_loss_stft_l1, \ torch.mean(tmp_batch_loss_stft_l1).unsqueeze(0))) else: batch_loss_stft_l1 = torch.mean(tmp_batch_loss_stft_l1).unsqueeze(0) n_sum += n m = 0 for i in range(args.n_fft_facts): if feat_len > int(fft_facts[i]/2): tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i]) if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss): if m > 0: tmp_batch_loss_lsd = torch.cat((tmp_batch_loss_lsd, \ tmp_batch_stft_loss.unsqueeze(0))) 
else: tmp_batch_loss_lsd = tmp_batch_stft_loss.unsqueeze(0) m += 1 if m > 0: if m_sum > 0: batch_loss_lsd = torch.cat((batch_loss_lsd, \ torch.mean(tmp_batch_loss_lsd).unsqueeze(0))) else: batch_loss_lsd = torch.mean(tmp_batch_loss_lsd).unsqueeze(0) m_sum += m batch_loss_laplace = torch.mean(batch_loss_laplace) batch_loss = batch_loss_laplace if n_sum > 0: batch_loss += torch.mean(batch_loss_stft_l1) if m_sum > 0: batch_loss_lsd = torch.mean(batch_loss_lsd) batch_loss_err = torch.mean(batch_loss_err) else: batch_loss_laplace = criterion_laplace(mus, bs, batch_x_float, log_b=log_bs) batch_loss = batch_loss_laplace eps = torch.empty(mus.shape).cuda().uniform_(-0.4999,0.5) batch_output = mus-bs_noclip*eps.sign()*torch.log1p(-2*eps.abs()) batch_loss_err = torch.mean(torch.abs(batch_output-batch_x_float)) logging.info("%lf %E %lf %E" % (torch.min(batch_x_float), torch.mean(batch_x_float), \ torch.max(batch_x_float), torch.var(batch_x_float))) logging.info("%lf %E %lf %E" % (torch.min(batch_output), torch.mean(batch_output), \ torch.max(batch_output), torch.var(batch_output))) n = 0 for i in range(args.n_fft_facts): if feat_len > int(fft_facts[i]/2): stft_out[i] = torch.stft(batch_output, fft_facts[i], window=hann_win[i]) stft_trg[i] = torch.stft(batch_x_float, fft_facts[i], window=hann_win[i]) tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i], LSD=False, L2=False) if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss): if n > 0: batch_loss_stft_l1 = torch.cat((batch_loss_stft_l1, \ tmp_batch_stft_loss.unsqueeze(0))) else: batch_loss_stft_l1 = tmp_batch_stft_loss.unsqueeze(0) n += 1 if n > 0: batch_loss += torch.mean(batch_loss_stft_l1) m = 0 for i in range(args.n_fft_facts): if feat_len > int(fft_facts[i]/2): tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i]) if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss): if m > 0: batch_loss_lsd = torch.cat((batch_loss_lsd, 
tmp_batch_stft_loss.unsqueeze(0))) else: batch_loss_lsd = tmp_batch_stft_loss.unsqueeze(0) m += 1 if m > 0: batch_loss_lsd = torch.mean(batch_loss_lsd) optimizer.zero_grad() batch_loss.backward() optimizer.step() loss_err.append(batch_loss_err.item()) loss_laplace.append(batch_loss_laplace.item()) if (model.seg > 1 and m_sum > 0) or (model.seg == 1 and m > 0): loss_lsd.append(batch_loss_lsd.item()) logging.info("batch loss %s [%d:%d] %d %d %d %d %d %d = %.3f %.3f dB %.6f (%.3f sec)" % ( os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1, \ utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss_laplace.item(), \ batch_loss_lsd.item(), batch_loss_err.item(), time.time() - start)) else: logging.info("batch loss %s [%d:%d] %d %d %d %d %d %d = %.3f n/a %.6f (%.3f sec)" % ( os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1, \ utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss_laplace.item(), \ batch_loss_err.item(), time.time() - start)) iter_idx += 1 iter_count += 1 total += time.time() - start # save final model model.cpu() torch.save({"model": model.state_dict()}, args.expdir + "/checkpoint-final.pkl") logging.info("final checkpoint created.")
def __getitem__(self, idx):
    """Return one source utterance with its per-cycle converted features.

    The utterance ``self.file_list[idx]`` is read from its ``/feat_org_lf0``
    dataset.  ``proc_multspk_data_mix_random_cls_statcvexcit`` supplies the
    source speaker index together with, for every cycle, target statistics,
    speaker codes and class codes.  For each of the ``self.n_cyc`` cycles,
    column 1 of the features is re-normalised from the source speaker's
    mean/scale (element ``[1:2]`` of the speaker statistics file) to that
    cycle's target mean/scale; columns ``0`` and ``2:self.stdim`` are copied.

    When ``self.pair_utt_flag`` is set, a file with the same base name is
    looked up under the directory of ``pair_spk_list[0]``.  If it exists its
    features and ``/spcidx_range`` indices are loaded, otherwise the source
    tensors are reused and ``file_src_trg_flag`` stays ``False``.

    Args:
        idx: Position in ``self.file_list``.

    Returns:
        dict: Tensors produced via ``self.pad_transform``, lengths and file
        names for the source utterance.  The ``h_src_trg``, ``flen_src_trg``,
        ``spcidx_src_trg``, ``flen_spc_src_trg`` and ``featfile_src_trg``
        entries are present only when ``self.pair_utt_flag`` is set.
    """

    def _load_spcidx(featfile):
        # Use the index range stored in the feature file when present;
        # otherwise read it from the same file name under the sibling
        # directory named by the part of the speaker directory name that
        # precedes the first '_'.
        if check_hdf5(featfile, "/spcidx_range"):
            return read_hdf5(featfile, "/spcidx_range")[0]
        spk_dir = os.path.dirname(featfile)
        base_spk = os.path.basename(spk_dir).split('_')[0]
        fallback = os.path.join(os.path.dirname(spk_dir), base_spk,
                                os.path.basename(featfile))
        return read_hdf5(fallback, "/spcidx_range")[0]

    featfile_src = self.file_list[idx]
    h_src = read_hdf5(featfile_src, "/feat_org_lf0")
    flen_src = h_src.shape[0]
    src_code = np.zeros((flen_src, self.n_spk))

    (spk_idx, mean_trg_list, std_trg_list, src_trg_code_list, featfile_spk,
     pair_spk_list, src_class_code, trg_class_code_list) = \
        proc_multspk_data_mix_random_cls_statcvexcit(
            featfile_src, self.spk_list, self.n_cyc, src_code, self.n_spk,
            self.spk_idx_dict, self.stat_spk_list)

    # Source-speaker statistics of feature column 1
    # (presumably log-F0, judging by the dataset name -- TODO confirm).
    mean_src = read_hdf5(self.stat_spk_list[spk_idx], "/mean_feat_org_lf0")[1:2]
    std_src = read_hdf5(self.stat_spk_list[spk_idx], "/scale_feat_org_lf0")[1:2]

    spcidx_src = _load_spcidx(featfile_src)
    flen_spc_src = spcidx_src.shape[0]

    src_code = torch.FloatTensor(self.pad_transform(src_code))
    src_class_code = torch.LongTensor(self.pad_transform(src_class_code))

    cv_src_list = []
    for cyc in range(self.n_cyc):
        # Linear mean/scale conversion of column 1 towards this cycle's target.
        col1_cv = (std_trg_list[cyc] / std_src) * (h_src[:, 1:2] - mean_src) \
            + mean_trg_list[cyc]
        feat_cv = np.c_[h_src[:, :1], col1_cv, h_src[:, 2:self.stdim]]
        cv_src_list.append(torch.FloatTensor(self.pad_transform(feat_cv)))
        # The code lists are converted in place, entry by entry.
        src_trg_code_list[cyc] = torch.FloatTensor(
            self.pad_transform(src_trg_code_list[cyc]))
        trg_class_code_list[cyc] = torch.LongTensor(
            self.pad_transform(trg_class_code_list[cyc]))

    h_src = torch.FloatTensor(self.pad_transform(h_src))
    spcidx_src = torch.LongTensor(self.pad_transform(spcidx_src))

    file_src_trg_flag = False
    if not self.pair_utt_flag:
        return {
            'h_src': h_src,
            'flen_src': flen_src,
            'src_code': src_code,
            'src_trg_code_list': src_trg_code_list,
            'cv_src_list': cv_src_list,
            'spcidx_src': spcidx_src,
            'flen_spc_src': flen_spc_src,
            'featfile_src': featfile_src,
            'featfile_spk': featfile_spk,
            'pair_spk_list': pair_spk_list,
            'src_class_code': src_class_code,
            'trg_class_code_list': trg_class_code_list,
            'file_src_trg_flag': file_src_trg_flag,
        }

    # Same utterance name under the first paired speaker's directory.
    featfile_src_trg = "/".join((
        os.path.dirname(os.path.dirname(featfile_src)),
        pair_spk_list[0],
        os.path.basename(featfile_src),
    ))
    file_src_trg_flag = os.path.exists(featfile_src_trg)
    if file_src_trg_flag:
        h_src_trg = read_hdf5(featfile_src_trg, "/feat_org_lf0")
        flen_src_trg = h_src_trg.shape[0]
        spcidx_src_trg = _load_spcidx(featfile_src_trg)
        flen_spc_src_trg = spcidx_src_trg.shape[0]
        h_src_trg = torch.FloatTensor(self.pad_transform(h_src_trg))
        spcidx_src_trg = torch.LongTensor(self.pad_transform(spcidx_src_trg))
    else:
        # No paired file on disk: fall back to the source tensors.
        h_src_trg = h_src
        flen_src_trg = flen_src
        spcidx_src_trg = spcidx_src
        flen_spc_src_trg = flen_spc_src

    return {
        'h_src': h_src,
        'flen_src': flen_src,
        'src_code': src_code,
        'src_trg_code_list': src_trg_code_list,
        'cv_src_list': cv_src_list,
        'spcidx_src': spcidx_src,
        'flen_spc_src': flen_spc_src,
        'h_src_trg': h_src_trg,
        'flen_src_trg': flen_src_trg,
        'spcidx_src_trg': spcidx_src_trg,
        'flen_spc_src_trg': flen_spc_src_trg,
        'featfile_src': featfile_src,
        'featfile_src_trg': featfile_src_trg,
        'featfile_spk': featfile_spk,
        'pair_spk_list': pair_spk_list,
        'src_class_code': src_class_code,
        'trg_class_code_list': trg_class_code_list,
        'file_src_trg_flag': file_src_trg_flag,
    }
def __getitem__(self, idx):
    """Return one source/target utterance pair with speaker codes.

    The source and target feature files are ``self.file_list_src[idx]`` and
    ``self.file_list_src_trg[idx]``; each speaker is identified by the name
    of the directory containing the file, mapped through
    ``self.spk_idx_dict``.  Features come from the ``/feat_org_lf0`` dataset.
    Column 1 of the source features is re-normalised from the source
    speaker's mean/scale to the target speaker's (and vice versa for the
    target features); columns ``0`` and ``2:self.stdim`` are copied.

    Target-side data is loaded only when ``self.list_src_trg_flag[idx]`` is
    truthy; otherwise every target entry aliases the corresponding source
    tensor (with ``trg_src_*`` aliasing ``src_trg_*``).

    Args:
        idx: Position in the source/target file lists.

    Returns:
        dict: Tensors produced via ``self.pad_transform``, frame counts,
        file names and the ``file_src_trg_flag`` for this pair.
    """

    def _stats(spk_index):
        # Mean and scale of feature column 1 for one speaker
        # (presumably log-F0, judging by the dataset name -- TODO confirm).
        mean = read_hdf5(self.stat_spk_list[spk_index], "/mean_feat_org_lf0")[1:2]
        scale = read_hdf5(self.stat_spk_list[spk_index], "/scale_feat_org_lf0")[1:2]
        return mean, scale

    def _load_spcidx(featfile):
        # Use the index range stored in the feature file when present;
        # otherwise read it from the same file name under the sibling
        # directory named by the part of the speaker directory name that
        # precedes the first '_'.
        if check_hdf5(featfile, "/spcidx_range"):
            return read_hdf5(featfile, "/spcidx_range")[0]
        spk_dir = os.path.dirname(featfile)
        base_spk = os.path.basename(spk_dir).split('_')[0]
        fallback = os.path.join(os.path.dirname(spk_dir), base_spk,
                                os.path.basename(featfile))
        return read_hdf5(fallback, "/spcidx_range")[0]

    def _one_hot(n_frames, spk_index):
        # Per-frame one-hot speaker code over self.n_spk speakers.
        code = np.zeros((n_frames, self.n_spk))
        code[:, spk_index] = 1
        return code

    def _class_code(n_frames, spk_index):
        # Per-frame integer speaker label.
        return np.ones(n_frames, dtype=np.int64) * spk_index

    def _convert(feat, mean_from, std_from, mean_to, std_to):
        # Linear mean/scale conversion of column 1; other columns are copied.
        col1_cv = (std_to / std_from) * (feat[:, 1:2] - mean_from) + mean_to
        return np.c_[feat[:, :1], col1_cv, feat[:, 2:self.stdim]]

    def _as_float(arr):
        return torch.FloatTensor(self.pad_transform(arr))

    def _as_long(arr):
        return torch.LongTensor(self.pad_transform(arr))

    featfile_src = self.file_list_src[idx]
    featfile_src_trg = self.file_list_src_trg[idx]
    file_src_trg_flag = self.list_src_trg_flag[idx]

    idx_src = self.spk_idx_dict[os.path.basename(os.path.dirname(featfile_src))]
    idx_trg = self.spk_idx_dict[os.path.basename(os.path.dirname(featfile_src_trg))]
    mean_src, std_src = _stats(idx_src)
    mean_trg, std_trg = _stats(idx_trg)

    # ---- source side (numpy) ----
    h_src = read_hdf5(featfile_src, "/feat_org_lf0")
    flen_src = h_src.shape[0]
    src_code = _one_hot(flen_src, idx_src)
    src_trg_code = _one_hot(flen_src, idx_trg)
    cv_src = _convert(h_src, mean_src, std_src, mean_trg, std_trg)
    spcidx_src = _load_spcidx(featfile_src)
    src_class_code = _class_code(h_src.shape[0], idx_src)
    src_trg_class_code = _class_code(h_src.shape[0], idx_trg)
    flen_spc_src = spcidx_src.shape[0]

    # ---- target side (numpy), only when flagged as available ----
    if file_src_trg_flag:
        h_src_trg = read_hdf5(featfile_src_trg, "/feat_org_lf0")
        flen_src_trg = h_src_trg.shape[0]
        trg_code = _one_hot(flen_src_trg, idx_trg)
        trg_src_code = _one_hot(flen_src_trg, idx_src)
        cv_trg = _convert(h_src_trg, mean_trg, std_trg, mean_src, std_src)
        spcidx_src_trg = _load_spcidx(featfile_src_trg)
        trg_class_code = _class_code(h_src_trg.shape[0], idx_trg)
        trg_src_class_code = _class_code(h_src_trg.shape[0], idx_src)
        flen_spc_src_trg = spcidx_src_trg.shape[0]

    # ---- pad and convert to tensors ----
    h_src = _as_float(h_src)
    src_code = _as_float(src_code)
    src_trg_code = _as_float(src_trg_code)
    cv_src = _as_float(cv_src)
    spcidx_src = _as_long(spcidx_src)
    src_class_code = _as_long(src_class_code)
    src_trg_class_code = _as_long(src_trg_class_code)

    if file_src_trg_flag:
        h_src_trg = _as_float(h_src_trg)
        trg_code = _as_float(trg_code)
        trg_src_code = _as_float(trg_src_code)
        cv_trg = _as_float(cv_trg)
        spcidx_src_trg = _as_long(spcidx_src_trg)
        trg_class_code = _as_long(trg_class_code)
        trg_src_class_code = _as_long(trg_src_class_code)
    else:
        # No target data: every target entry aliases its source tensor.
        flen_src_trg = flen_src
        h_src_trg = h_src
        trg_code = src_code
        trg_src_code = src_trg_code
        cv_trg = cv_src
        spcidx_src_trg = spcidx_src
        trg_class_code = src_class_code
        trg_src_class_code = src_trg_class_code
        flen_spc_src_trg = flen_spc_src

    return {
        'h_src': h_src,
        'flen_src': flen_src,
        'src_code': src_code,
        'src_trg_code': src_trg_code,
        'cv_src': cv_src,
        'spcidx_src': spcidx_src,
        'flen_spc_src': flen_spc_src,
        'h_src_trg': h_src_trg,
        'flen_src_trg': flen_src_trg,
        'trg_code': trg_code,
        'trg_src_code': trg_src_code,
        'cv_trg': cv_trg,
        'spcidx_src_trg': spcidx_src_trg,
        'flen_spc_src_trg': flen_spc_src_trg,
        'featfile_src': featfile_src,
        'featfile_src_trg': featfile_src_trg,
        'src_class_code': src_class_code,
        'src_trg_class_code': src_trg_class_code,
        'trg_class_code': trg_class_code,
        'trg_src_class_code': trg_src_class_code,
        'file_src_trg_flag': file_src_trg_flag,
    }