def __getitem__(self, idx):
        """Load the idx-th (waveform, feature) pair as a dict of tensors.

        The waveform is read with ``sf.read`` as float32. Features are read from
        the hdf5 path ``self.string_path`` when ``check_hdf5`` reports it present,
        otherwise from ``self.string_path_org``. Both are passed through
        ``validate_length`` before any transform is applied.

        Returns a dict with keys (in this order):
            'x_t'      -- only when ``self.wav_transform_in`` is set: the input
                          transform of the raw waveform, padded, as LongTensor
            'x'        -- padded waveform; LongTensor when only ``wav_transform``
                          is set, FloatTensor otherwise
            'feat'     -- padded features as FloatTensor
            'slen'     -- waveform length before padding
            'flen'     -- number of feature frames before padding
            'featfile' -- path of the feature file
        """
        wav_path = self.wav_list[idx]
        feat_path = self.feat_list[idx]

        wav, _ = sf.read(wav_path, dtype=np.float32)
        feat_key = self.string_path if check_hdf5(feat_path, self.string_path) \
            else self.string_path_org
        feat = read_hdf5(feat_path, feat_key)

        wav, feat = validate_length(wav, feat, self.upsampling_factor)

        has_input_transform = self.wav_transform_in is not None
        if has_input_transform:
            # computed from the untransformed waveform
            wav_in = self.wav_transform_in(wav)
        if self.wav_transform is not None:
            wav = self.wav_transform(wav)
            if self.wav_transform_out is not None:
                wav = self.wav_transform_out(wav)

        n_samples = wav.shape[0]
        n_frames = feat.shape[0]

        feat = torch.FloatTensor(self.pad_feat_transform(feat))
        # a lone wav_transform yields integer classes; otherwise keep floats
        wav_is_integer = self.wav_transform is not None and self.wav_transform_out is None
        wav_tensor_type = torch.LongTensor if wav_is_integer else torch.FloatTensor
        wav = wav_tensor_type(self.pad_wav_transform(wav))

        sample = {}
        if has_input_transform:
            sample['x_t'] = torch.LongTensor(self.pad_wav_transform(wav_in))
        sample['x'] = wav
        sample['feat'] = feat
        sample['slen'] = n_samples
        sample['flen'] = n_frames
        sample['featfile'] = feat_path
        return sample
# 예제 #2 (Example #2)
# 0
def _crop_batch(batch_x_class, batch_x, batch_h, h_bs, x_bs, h_ss, x_ss):
    """Crop a generator batch to its active segment and shape it for the model.

    Drops the first ``h_ss`` feature frames and ``x_ss`` samples. When
    ``h_bs != -1`` it keeps at most ``h_bs`` frames and ``x_bs`` samples. The
    class targets are shifted one step ahead of the waveform inputs
    (``batch_x[:-1]`` pairs with ``batch_x_class[1:]``). ``batch_x`` and
    ``batch_h`` are transposed on their first two axes and given a leading
    batch axis; this presumably maps (time, dim) to (1, dim, time).

    Returns ``(batch_x_class, batch_x, batch_h)``.
    """
    batch_h = batch_h[h_ss:]
    batch_x_class = batch_x_class[x_ss:]
    batch_x = batch_x[x_ss:]
    if h_bs != -1:
        batch_h = batch_h[:h_bs]
        batch_x_class = batch_x_class[1:x_bs]
        batch_x = batch_x[:x_bs-1]
    else:
        batch_x = batch_x[:-1]
        batch_x_class = batch_x_class[1:]
    batch_h = batch_h.transpose(0, 1).unsqueeze(0)
    batch_x = batch_x.transpose(0, 1).unsqueeze(0)
    return batch_x_class, batch_x, batch_h


def main():
    """Train a DSWNV model, evaluating and checkpointing after every epoch.

    Command-line arguments configure the data lists, network structure and
    training schedule. Logs go to ``<expdir>/train.log`` and stderr. The parsed
    arguments are saved to ``<expdir>/model.conf``. A checkpoint is written at
    the end of every epoch through ``save_checkpoint``, and the final weights
    are saved to ``<expdir>/checkpoint-final.pkl``.

    Exits with status 1 when no GPU is available, or when ``--waveforms`` /
    ``--waveforms_eval`` is neither a directory nor a file.
    """
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms", required=True,
                        type=str, help="directory or list of wav files")
    parser.add_argument("--waveforms_eval", required=True,
                        type=str, help="directory or list of evaluation wav files")
    parser.add_argument("--feats", required=True,
                        type=str, help="directory or list of aux feat files")
    parser.add_argument("--feats_eval", required=True,
                        type=str, help="directory or list of evaluation aux feat files")
    parser.add_argument("--stats", required=True,
                        type=str, help="hdf5 file including statistics")
    parser.add_argument("--expdir", required=True,
                        type=str, help="directory to save the model")
    # network structure setting
    parser.add_argument("--n_quantize", default=256,
                        type=int, help="number of quantization")
    parser.add_argument("--n_aux", default=39,
                        type=int, help="number of dimension of aux feats")
    parser.add_argument("--dilation_depth", default=3,
                        type=int, help="depth of dilation")
    parser.add_argument("--dilation_repeat", default=3,
                        type=int, help="repeat of dilation depth")
    parser.add_argument("--hid_chn", default=192,
                        type=int, help="number of hidden channels")
    parser.add_argument("--skip_chn", default=256,
                        type=int, help="number of channels of skip output")
    parser.add_argument("--kernel_size", default=6,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--aux_kernel_size", default=3,
                        type=int, help="kernel size of aux feature convolution")
    parser.add_argument("--aux_dilation_size", default=2,
                        type=int, help="dilation size of aux feature convolution")
    parser.add_argument("--upsampling_factor", default=110,
                        type=int, help="upsampling factor of aux features "
                                       "(if set 0, do not apply)")
    parser.add_argument("--string_path", default="/feat_org_lf0",
                        type=str, help="hdf5 path of aux feats")
    # network training setting
    parser.add_argument("--lr", default=1e-4,
                        type=float, help="learning rate")
    parser.add_argument("--batch_size", default=1100,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--epoch_count", default=500,
                        type=int, help="number of training epochs")
    parser.add_argument("--do_prob", default=0,
                        type=float, help="dropout probability")
    parser.add_argument("--wav_conv_flag", default=False,
                        type=strtobool, help="flag to use 1d conv of wav")
    # other setting
    parser.add_argument("--audio_in", default=False,
        type=strtobool, help="flag for including previous sample in conditioning feat")
    parser.add_argument("--seed", default=1,
                        type=int, help="seed number")
    parser.add_argument("--resume", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--pretrained", default=None,
                        type=str, help="model path to initialize weights from (training starts at epoch 0)")
    parser.add_argument("--GPU_device", default=None,
                        type=int, help="selection of GPU device")
    parser.add_argument("--verbose", default=1,
                        type=int, help="log level")
    args = parser.parse_args()

    if args.GPU_device is not None:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_device)

    # make experimental directory
    os.makedirs(args.expdir, exist_ok=True)

    # set log level
    if args.verbose == 1:
        log_level = logging.INFO
    elif args.verbose > 1:
        log_level = logging.DEBUG
    else:
        log_level = logging.WARN
    logging.basicConfig(level=log_level,
                        format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                        datefmt='%m/%d/%Y %I:%M:%S',
                        filename=args.expdir + "/train.log")
    logging.getLogger().addHandler(logging.StreamHandler())
    if log_level == logging.WARN:
        # logging.warn() was removed in Python 3.13; warning() is the supported name
        logging.warning("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    torch.backends.cudnn.benchmark = True  # faster, not bit-reproducible

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    model = DSWNV(
        n_quantize=args.n_quantize,
        n_aux=args.n_aux,
        hid_chn=args.hid_chn,
        skip_chn=args.skip_chn,
        dilation_depth=args.dilation_depth,
        dilation_repeat=args.dilation_repeat,
        kernel_size=args.kernel_size,
        aux_kernel_size=args.aux_kernel_size,
        aux_dilation_size=args.aux_dilation_size,
        audio_in_flag=args.audio_in,
        do_prob=args.do_prob,
        wav_conv_flag=args.wav_conv_flag,
        upsampling_factor=args.upsampling_factor)
    logging.info(model)
    criterion = nn.CrossEntropyLoss()

    # define transforms: load mean/scale of the aux feats, trying several
    # stats-file naming conventions in turn
    # NOTE(review): raises IndexError if --string_path contains no 'feat_'
    string_path_name = args.string_path.split('feat_')[1]
    logging.info(string_path_name)
    scaler = StandardScaler()
    if check_hdf5(args.stats, "/mean_"+string_path_name):
        scaler.mean_ = read_hdf5(args.stats, "/mean_"+string_path_name)
        scaler.scale_ = read_hdf5(args.stats, "/scale_"+string_path_name)
    elif check_hdf5(args.stats, "/mean_"+args.string_path):
        scaler.mean_ = read_hdf5(args.stats, "/mean_"+args.string_path)
        scaler.scale_ = read_hdf5(args.stats, "/scale_"+args.string_path)
    else:
        scaler.mean_ = read_hdf5(args.stats, "/mean_feat_"+string_path_name)
        scaler.scale_ = read_hdf5(args.stats, "/scale_feat_"+string_path_name)
    mean_src = torch.FloatTensor(scaler.mean_)
    std_src = torch.FloatTensor(scaler.scale_)

    # send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
        mean_src = mean_src.cuda()
        std_src = std_src.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    model.train()
    model.apply(initialize)
    # scale_in: weight diag(1/std) with a trailing kernel axis and bias
    # -mean/std, i.e. (x - mean) / std per aux dim; it is frozen below
    model.scale_in.weight = torch.nn.Parameter(torch.unsqueeze(torch.diag(1.0/std_src.data), 2))
    model.scale_in.bias = torch.nn.Parameter(-(mean_src.data/std_src.data))

    for param in model.parameters():
        param.requires_grad = True
    for param in model.scale_in.parameters():
        param.requires_grad = False

    n_trainable = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad) / 1000000
    logging.info('Trainable Parameters: %.3f million' % n_trainable)

    # optimizer covers every sub-module except the frozen scale_in
    module_list = list(model.conv_aux.parameters()) + list(model.upsampling.parameters())
    if model.wav_conv_flag:
        module_list += list(model.wav_conv.parameters())
    module_list += list(model.causal.parameters())
    module_list += list(model.in_x.parameters()) + list(model.dil_h.parameters())
    module_list += list(model.out_skip.parameters())
    module_list += list(model.out_1.parameters()) + list(model.out_2.parameters())
    optimizer = torch.optim.Adam(module_list, lr=args.lr)

    # pretrained: load weights only and start from epoch 0;
    # resume: load weights + optimizer and continue from the saved epoch
    if args.pretrained is not None:
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint["model"])
        epoch_idx = checkpoint["iterations"]
        logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
        epoch_idx = 0
    elif args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % epoch_idx)
    else:
        epoch_idx = 0

    wav_transform = transforms.Compose([lambda x: encode_mu_law(x, args.n_quantize)])

    # define generator training
    if os.path.isdir(args.waveforms):
        filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [args.feats + "/" + filename.replace(".wav", ".h5") for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    generator = train_generator(
        wav_list, feat_list,
        model.receptive_field,
        string_path=args.string_path,
        batch_size=args.batch_size,
        wav_transform=wav_transform,
        training=True,
        upsampling_factor=args.upsampling_factor)

    # define generator evaluation
    if os.path.isdir(args.waveforms_eval):
        filenames_eval = sorted(find_files(args.waveforms_eval, "*.wav", use_dir_name=False))
        wav_list_eval = [args.waveforms_eval + "/" + filename for filename in filenames_eval]
        feat_list_eval = [args.feats_eval + "/" + filename.replace(".wav", ".h5")
                          for filename in filenames_eval]
    elif os.path.isfile(args.waveforms_eval):
        wav_list_eval = read_txt(args.waveforms_eval)
        feat_list_eval = read_txt(args.feats_eval)
    else:
        logging.error("--waveforms_eval should be directory or list.")
        sys.exit(1)
    logging.info("number of evaluation data = %d." % len(wav_list_eval))
    assert len(wav_list_eval) == len(feat_list_eval)
    generator_eval = train_generator(
        wav_list_eval, feat_list_eval,
        model.receptive_field,
        string_path=args.string_path,
        batch_size=args.batch_size,
        wav_transform=wav_transform,
        training=False,
        upsampling_factor=args.upsampling_factor)

    # train
    loss = []
    total = 0
    iter_idx = 0
    iter_count = 0
    min_eval_loss = 99999999.99
    min_eval_loss_std = 99999999.99
    min_idx = -1
    if args.resume is not None:
        np.random.set_state(checkpoint["numpy_random_state"])
        torch.set_rng_state(checkpoint["torch_random_state"])
    logging.info("==%d EPOCH==" % (epoch_idx+1))
    logging.info("Training data")
    while epoch_idx < args.epoch_count:
        start = time.time()
        batch_x_class, batch_x, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = next(generator)
        if c_idx < 0:  # the generator signals end of epoch with a negative c_idx
            numpy_random_state = np.random.get_state()
            torch_random_state = torch.get_rng_state()
            # save current epoch model
            save_checkpoint(args.expdir, model, optimizer, numpy_random_state, torch_random_state, epoch_idx+1)
            # report current epoch (max() guards an epoch that yielded no batches)
            logging.info("(EPOCH:%d) average training loss = %.6f (+- %.6f) (%.3f min., %.3f sec / batch)" % (
                epoch_idx + 1, np.mean(np.array(loss, dtype=np.float64)),
                np.std(np.array(loss, dtype=np.float64)), total / 60.0, total / max(iter_count, 1)))
            logging.info("estimated training required time = {0.days:02}:{0.hours:02}:{0.minutes:02}:"
                "{0.seconds:02}".format(relativedelta(seconds=int((args.epoch_count - (epoch_idx + 1)) * total))))
            # compute loss in evaluation data
            loss = []
            total = 0
            iter_count = 0
            model.eval()
            for param in model.parameters():
                param.requires_grad = False
            logging.info("Evaluation data")
            with torch.no_grad():
                while True:
                    start = time.time()
                    batch_x_class, batch_x, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = \
                        next(generator_eval)
                    if c_idx < 0:
                        break

                    tf = batch_h.shape[0]
                    ts = batch_x.shape[0]
                    batch_x_class, batch_x, batch_h = _crop_batch(
                        batch_x_class, batch_x, batch_h, h_bs, x_bs, h_ss, x_ss)

                    batch_output = model(batch_x, batch_h)[0]

                    # segments not starting at the utterance head exclude the
                    # first receptive_field outputs from the loss
                    if h_ss > 0:
                        batch_loss = criterion(batch_output[model.receptive_field:],
                                               batch_x_class[model.receptive_field:])
                    else:
                        batch_loss = criterion(batch_output, batch_x_class)

                    loss.append(batch_loss.item())
                    logging.info("batch eval loss %s [%d:%d] %d %d %d %d %d %d = %.3f (%.3f sec)" % (
                        os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1,
                            utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss.item(), time.time() - start))
                    iter_count += 1
                    total += time.time() - start
            eval_loss = np.mean(np.array(loss, dtype=np.float64))
            eval_loss_std = np.std(np.array(loss, dtype=np.float64))
            logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) (%.3f min., %.3f sec / batch)" % (
                epoch_idx + 1, eval_loss, eval_loss_std, total / 60.0, total / max(iter_count, 1)))
            # best epoch is selected on mean + std of the evaluation loss
            if (eval_loss+eval_loss_std) <= (min_eval_loss+min_eval_loss_std):
                min_eval_loss = eval_loss
                min_eval_loss_std = eval_loss_std
                min_idx = epoch_idx
            logging.info("min_eval_loss=%.6f (+- %.6f), min_idx=%d" % (
                            min_eval_loss, min_eval_loss_std, min_idx+1))
            loss = []
            total = 0
            iter_count = 0
            epoch_idx += 1
            # restore RNG state so evaluation does not perturb training randomness
            np.random.set_state(numpy_random_state)
            torch.set_rng_state(torch_random_state)
            model.train()
            for param in model.parameters():
                param.requires_grad = True
            for param in model.scale_in.parameters():
                param.requires_grad = False
            # start next epoch
            if epoch_idx < args.epoch_count:
                start = time.time()
                logging.info("==%d EPOCH==" % (epoch_idx+1))
                logging.info("Training data")
                batch_x_class, batch_x, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = \
                    next(generator)
        # feedforward and backpropagate current batch
        if epoch_idx < args.epoch_count:
            logging.info("%d iteration [%d]" % (iter_idx+1, epoch_idx+1))

            tf = batch_h.shape[0]
            ts = batch_x.shape[0]
            batch_x_class, batch_x, batch_h = _crop_batch(
                batch_x_class, batch_x, batch_h, h_bs, x_bs, h_ss, x_ss)

            batch_output = model(batch_x, batch_h, do=True)[0]

            if h_ss > 0:
                batch_loss = criterion(batch_output[model.receptive_field:],
                                       batch_x_class[model.receptive_field:])
            else:
                batch_loss = criterion(batch_output, batch_x_class)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss.append(batch_loss.item())
            logging.info("batch loss %s [%d:%d] %d %d %d %d %d %d = %.3f (%.3f sec)" % (
                os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1, utt_idx+1,
                    tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss.item(), time.time() - start))
            iter_idx += 1
            iter_count += 1
            total += time.time() - start

    # save final model
    model.cpu()
    torch.save({"model": model.state_dict()}, args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
# 예제 #3 (Example #3)
# 0
def main():
    """Compute feature statistics over a list of hdf5 files and save them.

    For every file in the ``--feats`` list this reads ``/feat_org_lf0`` and
    ``/f0_range`` (falling back to ``/f0``). It accumulates:
      * mean / scale of ``/feat_org_lf0`` (``StandardScaler.partial_fit``),
      * the mean and variance, across utterances, of the per-utterance
        variance of dimensions ``stdim:`` (``gv_range_*``),
      * mean / std of the non-zero F0 values and of their log (``*f0_range_*``).
    All results are written to ``--stats``. Logs go to
    ``<expdir>/calc_stats.log`` and stderr.
    """
    import os  # local import: needed to create the log directory

    parser = argparse.ArgumentParser()

    parser.add_argument("--feats",
                        default=None,
                        required=True,
                        help="name of the list of hdf5 files")
    parser.add_argument("--stats",
                        default=None,
                        required=True,
                        help="filename of hdf5 format")
    parser.add_argument("--expdir",
                        required=True,
                        type=str,
                        help="directory to save the log")
    parser.add_argument("--stdim",
                        default=5,
                        type=int,
                        help="first feature dimension used for gv_range statistics")
    parser.add_argument("--spkr",
                        default=None,
                        type=str,
                        help="speaker name (not used by this script)")
    parser.add_argument("--verbose",
                        default=1,
                        type=int,
                        help="log message level")

    args = parser.parse_args()

    # basicConfig cannot open the log file if the directory does not exist
    os.makedirs(args.expdir, exist_ok=True)

    # set log level
    if args.verbose == 1:
        log_level = logging.INFO
    elif args.verbose > 1:
        log_level = logging.DEBUG
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format=
        '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S',
        filename=args.expdir + "/calc_stats.log")
    logging.getLogger().addHandler(logging.StreamHandler())
    if log_level == logging.WARN:
        # logging.warn() was removed in Python 3.13
        logging.warning("logging is disabled.")

    # read list and define scaler
    filenames = read_txt(args.feats)
    scaler_feat_org_lf0 = StandardScaler()
    logging.info("number of training utterances = " + str(len(filenames)))

    var_range = []
    # Collect non-zero F0 arrays and concatenate once after the loop;
    # concatenating inside the loop re-copies the growing array (O(n^2)).
    f0s_range_parts = []
    n_f0s = 0
    # process over all of data
    for filename in filenames:
        logging.info(filename)
        feat_org_lf0 = read_hdf5(filename, "/feat_org_lf0")
        scaler_feat_org_lf0.partial_fit(feat_org_lf0)
        mcep_range = feat_org_lf0[:, args.stdim:]
        var_range.append(np.var(mcep_range, axis=0))
        logging.info(mcep_range.shape)
        if check_hdf5(filename, "/f0_range"):
            f0_range = read_hdf5(filename, "/f0_range")
        else:
            f0_range = read_hdf5(filename, "/f0")
        nonzero_f0 = f0_range[np.nonzero(f0_range)]
        f0s_range_parts.append(nonzero_f0)
        n_f0s += nonzero_f0.shape[0]
        logging.info("non-zero f0 frames: %d (accumulated %d)" % (nonzero_f0.shape[0], n_f0s))

    # The leading empty float64 array keeps the float64 result dtype and makes
    # an empty file list yield an empty array instead of raising.
    f0s_range = np.concatenate([np.empty(0)] + f0s_range_parts)

    mean_feat_org_lf0 = scaler_feat_org_lf0.mean_
    scale_feat_org_lf0 = scaler_feat_org_lf0.scale_
    gv_range_mean = np.mean(np.array(var_range), axis=0)
    gv_range_var = np.var(np.array(var_range), axis=0)
    logging.info(gv_range_mean)
    logging.info(gv_range_var)
    f0_range_mean = np.mean(f0s_range)
    f0_range_std = np.std(f0s_range)
    logging.info(f0_range_mean)
    logging.info(f0_range_std)
    lf0_range_mean = np.mean(np.log(f0s_range))
    lf0_range_std = np.std(np.log(f0s_range))
    logging.info(lf0_range_mean)
    logging.info(lf0_range_std)
    logging.info(np.array_equal(f0_range_mean, np.exp(lf0_range_mean)))
    logging.info(np.array_equal(f0_range_std, np.exp(lf0_range_std)))

    logging.info(mean_feat_org_lf0)
    logging.info(scale_feat_org_lf0)
    write_hdf5(args.stats, "/mean_feat_org_lf0", mean_feat_org_lf0)
    write_hdf5(args.stats, "/scale_feat_org_lf0", scale_feat_org_lf0)
    write_hdf5(args.stats, "/gv_range_mean", gv_range_mean)
    write_hdf5(args.stats, "/gv_range_var", gv_range_var)
    write_hdf5(args.stats, "/f0_range_mean", f0_range_mean)
    write_hdf5(args.stats, "/f0_range_std", f0_range_std)
    write_hdf5(args.stats, "/lf0_range_mean", lf0_range_mean)
    write_hdf5(args.stats, "/lf0_range_std", lf0_range_std)
# 예제 #4 (Example #4)
# 0
def world_feature_extract(queue, wav_list, args):
    """EXTRACT WORLD FEATURE VECTOR

    For every wav file, run WORLD analysis and write the features to an hdf5
    file. The ``/world`` dataset is the column-wise concatenation of ``uv``,
    the low-pass-filtered continuous F0 (both from ``convert_continuos_f0``),
    the mel-cepstrum and the coded aperiodicity. Extra datasets (``/f0``,
    ``/ap``, ``/spc``, ``/npow``, ``/world_extend``, ``/vad_idx``) are
    written according to the ``args.save_*`` flags. Files that already contain
    ``/world`` are skipped unless ``args.overwrite`` is set. The process exits
    with status 1 if a wav file's sampling rate differs from ``args.fs``.

    Parameters
    ----------
    queue : multiprocessing.Queue()
        the queue that receives 'Finish' once all files have been handled
    wav_list : list
        list of the wav files
    args :
        feature extract arguments
    """
    # define feature extractor
    feature_extractor = FeatureExtractor(analyzer="world",
                                         fs=args.fs,
                                         shiftms=args.shiftms,
                                         minf0=args.minf0,
                                         maxf0=args.maxf0,
                                         fftl=args.fftl)
    # extraction
    for i, wav_name in enumerate(wav_list):
        # output path: without feature_dir, every "wav" in the path is replaced
        if args.feature_dir is None:
            feat_name = wav_name.replace("wav", args.feature_format)
        else:
            feat_name = rootdir_replace(wav_name,
                                        extname=args.feature_format,
                                        newdir=args.feature_dir)
        if check_hdf5(feat_name, "/world"):
            if args.overwrite:
                logging.info("overwrite %s (%d/%d)" %
                             (wav_name, i + 1, len(wav_list)))
            else:
                logging.info("skip %s (%d/%d)" %
                             (wav_name, i + 1, len(wav_list)))
                continue
        else:
            logging.info("now processing %s (%d/%d)" %
                         (wav_name, i + 1, len(wav_list)))
        # load wavfile
        fs, x = wavfile.read(wav_name)

        # check sampling frequency before any fs-dependent processing
        if not fs == args.fs:
            logging.error("sampling frequency is not matched.")
            sys.exit(1)

        # apply low cut filter
        x = np.array(x, dtype=np.float32)
        if args.highpass_cutoff != 0:
            x = low_cut_filter(x, fs, cutoff=args.highpass_cutoff)

        # extract features
        f0, spc, ap = feature_extractor.analyze(x)
        codeap = feature_extractor.codeap()
        mcep = feature_extractor.mcep(dim=args.mcep_dim, alpha=args.mcep_alpha)
        npow = feature_extractor.npow()
        uv, cont_f0 = convert_continuos_f0(f0)
        # the F0 contour is sampled once per frame shift
        lpf_fs = int(1.0 / (args.shiftms * 0.001))
        cont_f0_lpf = low_pass_filter(cont_f0, lpf_fs, cutoff=20)
        # raise the cutoff until the filtered contour is strictly positive
        next_cutoff = 70
        while not (cont_f0_lpf > [0]).all():
            logging.info("%s low-pass-filtered [%dHz]" %
                         (feat_name, next_cutoff))
            cont_f0_lpf = low_pass_filter(cont_f0, lpf_fs, cutoff=next_cutoff)
            next_cutoff *= 2

        # concatenate
        cont_f0_lpf = np.expand_dims(cont_f0_lpf, axis=-1)
        uv = np.expand_dims(uv, axis=-1)
        feats = np.concatenate([uv, cont_f0_lpf, mcep, codeap], axis=1)

        # save feature
        write_hdf5(feat_name, "/world", feats)
        if args.save_f0:
            write_hdf5(feat_name, "/f0", f0)
        if args.save_ap:
            write_hdf5(feat_name, "/ap", ap)
        if args.save_spc:
            write_hdf5(feat_name, "/spc", spc)
        if args.save_npow:
            write_hdf5(feat_name, "/npow", npow)
        if args.save_extended:
            # extend time resolution to one frame per waveform sample
            upsampling_factor = int(args.shiftms * fs * 0.001)
            feats_extended = extend_time(feats, upsampling_factor)
            feats_extended = feats_extended.astype(np.float32)
            write_hdf5(feat_name, "/world_extend", feats_extended)
        if args.save_vad:
            _, vad_idx = extfrm(mcep, npow, power_threshold=args.pow_th)
            write_hdf5(feat_name, "/vad_idx", vad_idx)
    queue.put('Finish')
# 예제 #5 (Example #5)
# 0
def world_speech_synthesis(queue, wav_list, args):
    """WORLD SPEECH SYNTHESIS

    Re-synthesize each utterance from its stored WORLD features and write the
    result as a 16-bit wav file. F0 comes from ``/f0`` when present; otherwise
    it is rebuilt from the ``/world`` matrix, taking column
    ``args.f0_dim_idx`` and zeroing it where column 0 (uv) is 0. Aperiodicity
    comes from ``/ap`` when present; otherwise it is decoded from columns
    ``args.ap_dim_idx:``. Existing outputs are skipped unless
    ``args.overwrite`` is set. The process exits with status 1 if a feature
    file lacks ``/world``.

    Parameters
    ----------
    queue : multiprocessing.Queue()
        the queue that receives 'Finish' once all files have been handled
    wav_list : list
        list of the wav files
    args :
        feature extract arguments
    """
    # define synthesizer
    synthesizer = Synthesizer(fs=args.fs, fftl=args.fftl, shiftms=args.shiftms)
    # synthesis
    for i, wav_name in enumerate(wav_list):
        if args.feature_dir is None:
            # "<...>wav<...>.wav" -> "<...><fmt>_restored<...>.wav"
            restored_name = wav_name.replace("wav",
                                             args.feature_format + "_restored")
            restored_name = restored_name.replace(
                ".%s" % args.feature_format + "_restored", ".wav")
            feat_name = wav_name.replace("wav", args.feature_format)
        else:
            restored_name = rootdir_replace(wav_name,
                                            newdir=args.feature_dir +
                                            "restored")
            feat_name = rootdir_replace(wav_name,
                                        extname=args.feature_format,
                                        newdir=args.feature_dir)
        if os.path.exists(restored_name):
            if args.overwrite:
                logging.info("overwrite %s (%d/%d)" %
                             (restored_name, i + 1, len(wav_list)))
            else:
                logging.info("skip %s (%d/%d)" %
                             (restored_name, i + 1, len(wav_list)))
                continue
        else:
            logging.info("now processing %s (%d/%d)" %
                         (restored_name, i + 1, len(wav_list)))
        # load acoustic features
        if check_hdf5(feat_name, "/world"):
            h = read_hdf5(feat_name, "/world")
        else:
            logging.error("%s is not existed." % (feat_name))
            sys.exit(1)
        if check_hdf5(feat_name, "/f0"):
            f0 = read_hdf5(feat_name, "/f0")
        else:
            uv = h[:, 0].copy(order='C')
            f0 = h[:, args.f0_dim_idx].copy(order='C')  # cont_f0_lpf
            fz_idx = np.where(uv == 0.0)
            f0[fz_idx] = 0.0
        if check_hdf5(feat_name, "/ap"):
            ap = read_hdf5(feat_name, "/ap")
        else:
            codeap = h[:, args.ap_dim_idx:].copy(order='C')
            ap = pyworld.decode_aperiodicity(codeap, args.fs, args.fftl)
        mcep = h[:, args.mcep_dim_start:args.mcep_dim_end].copy(order='C')
        # waveform synthesis; clip to the int16 range before casting
        wav = synthesizer.synthesis(f0, mcep, ap, alpha=args.mcep_alpha)
        wav = np.clip(wav, -32768, 32767)
        wavfile.write(restored_name, args.fs, wav.astype(np.int16))
    queue.put('Finish')
# 예제 #6 (Example #6)
# 0
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms", required=True,
                        type=str, help="directory or list of wav files")
    parser.add_argument("--waveforms_eval", required=True,
                        type=str, help="directory or list of evaluation wav files")
    parser.add_argument("--feats", required=True,
                        type=str, help="directory or list of aux feat files")
    parser.add_argument("--feats_eval", required=True,
                        type=str, help="directory or list of evaluation aux feat files")
    parser.add_argument("--stats", required=True,
                        type=str, help="hdf5 file including statistics")
    parser.add_argument("--expdir", required=True,
                        type=str, help="directory to save the model")
    # network structure setting
    parser.add_argument("--n_aux", default=54,
                        type=int, help="number of dimension of aux feats")
    parser.add_argument("--skip_chn", default=256,
                        type=int, help="number of channels of skip output")
    parser.add_argument("--seg", default=1,
                        type=int, help="segment size")
    parser.add_argument("--dilation_depth", default=3,
                        type=int, help="depth of dilation")
    parser.add_argument("--dilation_repeat", default=2,
                        type=int, help="repeat of dilation depth")
    parser.add_argument("--hid_chn", default=192,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--kernel_size", default=7,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--aux_kernel_size", default=3,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--aux_dilation_size", default=2,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--upsampling_factor", default=110,
                        type=int, help="upsampling factor of aux features"
                                       "(if set 0, do not apply)")
    parser.add_argument("--n_fft_facts", default=17,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--string_path", default="/feat_org_lf0",
                        type=str, help="directory to save the model")
    # network training setting
    parser.add_argument("--lr", default=1e-4,
                        type=float, help="learning rate")
    parser.add_argument("--batch_size", default=8800,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--epoch_count", default=4000,
                        type=int, help="number of training epochs")
    parser.add_argument("--do_prob", default=0,
                        type=float, help="dropout probability")
    parser.add_argument("--lpc", default=0,
                        type=int, help="number of linear predictive coefficients for location estimate")
    parser.add_argument("--aux_conv2d_flag", default=False,
                        type=strtobool, help="flag to use 2d conv of aux")
    parser.add_argument("--wav_conv_flag", default=False,
                        type=strtobool, help="flag to use 1d conv of wav")
    # other setting
    parser.add_argument("--seed", default=1,
                        type=int, help="seed number")
    parser.add_argument("--resume", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--pretrained", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--GPU_device", default=None,
                        type=int, help="selection of GPU device")
    parser.add_argument("--verbose", default=1,
                        type=int, help="log level")
    args = parser.parse_args()

    if args.GPU_device is not None:
        os.environ["CUDA_DEVICE_ORDER"]     = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"]  = str(args.GPU_device)

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    if args.verbose == 1:
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    elif args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    else:
        logging.basicConfig(level=logging.WARN,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
        logging.warn("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    torch.backends.cudnn.benchmark = True #faster
    #torch.backends.cudnn.deterministic = True #reproducibility_slower
    #torch.backends.cudnn.benchmark = False #reproducibility_slower

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    model = CSWNV(
        n_aux=args.n_aux,
        skip_chn=args.skip_chn,
        hid_chn=args.hid_chn,
        dilation_depth=args.dilation_depth,
        dilation_repeat=args.dilation_repeat,
        kernel_size=args.kernel_size,
        aux_kernel_size=args.aux_kernel_size,
        aux_dilation_size=args.aux_dilation_size,
        do_prob=args.do_prob,
        seg=args.seg,
        lpc=args.lpc,
        aux_conv2d_flag=args.aux_conv2d_flag,
        wav_conv_flag=args.wav_conv_flag,
        upsampling_factor=args.upsampling_factor)
    logging.info(model)
    criterion_lsd = LSDloss()
    criterion_laplace = LaplaceLoss()

    # define transforms
    string_path_name = args.string_path.split('feat_')[1]
    logging.info(string_path_name)
    scaler = StandardScaler()
    if check_hdf5(args.stats, "/mean_"+string_path_name):
        scaler.mean_ = read_hdf5(args.stats, "/mean_"+string_path_name)
        scaler.scale_ = read_hdf5(args.stats, "/scale_"+string_path_name)
    elif check_hdf5(args.stats, "/mean_"+args.string_path):
        scaler.mean_ = read_hdf5(args.stats, "/mean_"+args.string_path)
        scaler.scale_ = read_hdf5(args.stats, "/scale_"+args.string_path)
    else:
        scaler.mean_ = read_hdf5(args.stats, "/mean_feat_"+string_path_name)
        scaler.scale_ = read_hdf5(args.stats, "/scale_feat_"+string_path_name)
    mean_src = torch.FloatTensor(scaler.mean_)
    std_src = torch.FloatTensor(scaler.scale_)

    # send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion_lsd.cuda()
        criterion_laplace.cuda()
        mean_src = mean_src.cuda()
        std_src = std_src.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    model.train()
    model.apply(initialize)
    model.scale_in.weight = torch.nn.Parameter(torch.unsqueeze(torch.diag(1.0/std_src.data),2))
    model.scale_in.bias = torch.nn.Parameter(-(mean_src.data/std_src.data))

    for param in model.parameters():
        param.requires_grad = True
    for param in model.scale_in.parameters():
        param.requires_grad = False

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000
    logging.info('Trainable Parameters: %.3f million' % parameters)

    module_list = list(model.conv_aux.parameters())
    module_list += list(model.upsampling.parameters())
    if model.aux_conv2d_flag and model.seg > 1:
        module_list += list(model.aux_conv2d.parameters())
    if model.wav_conv_flag:
        module_list += list(model.wav_conv.parameters())
    module_list += list(model.causal.parameters()) + list(model.in_x.parameters())
    module_list += list(model.dil_h.parameters()) + list(model.out_skip.parameters())
    module_list += list(model.out_1.parameters()) + list(model.out_2.parameters())
    optimizer = torch.optim.Adam(module_list, lr=args.lr)

    # resume
    if args.pretrained is not None:
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint["model"])
        epoch_idx = checkpoint["iterations"]
        logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
        epoch_idx = 0
    elif args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % epoch_idx)
    else:
        epoch_idx = 0

    # define generator training
    if os.path.isdir(args.waveforms):
        filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [args.feats + "/" + filename.replace(".wav", ".h5") for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    if args.pretrained is None:
        generator = train_generator(
            wav_list, feat_list,
            model.receptive_field,
            string_path=args.string_path,
            seg=model.seg,
            batch_size=args.batch_size,
            training=True,
            upsampling_factor=args.upsampling_factor)
    else:
        generator = train_generator(
            wav_list, feat_list,
            model.receptive_field,
            string_path=args.string_path,
            seg=model.seg,
            batch_size=args.batch_size,
            training=True,
            upsampling_factor=args.upsampling_factor)

    # define generator evaluation
    if os.path.isdir(args.waveforms_eval):
        filenames_eval = sorted(find_files(args.waveforms_eval, "*.wav", use_dir_name=False))
        wav_list_eval = [args.waveforms_eval + "/" + filename for filename in filenames_eval]
        feat_list_eval = [args.feats_eval + "/" + filename.replace(".wav", ".h5") \
                            for filename in filenames_eval]
    elif os.path.isfile(args.waveforms_eval):
        wav_list_eval = read_txt(args.waveforms_eval)
        feat_list_eval = read_txt(args.feats_eval)
    else:
        logging.error("--waveforms_eval should be directory or list.")
        sys.exit(1)
    logging.info("number of evaluation data = %d." % len(wav_list_eval))
    assert len(wav_list_eval) == len(feat_list_eval)
    if args.pretrained is None:
        generator_eval = train_generator(
            wav_list_eval, feat_list_eval,
            model.receptive_field,
            string_path=args.string_path,
            seg=model.seg,
            batch_size=args.batch_size,
            training=False,
            upsampling_factor=args.upsampling_factor)
    else:
        generator_eval = train_generator(
            wav_list_eval, feat_list_eval,
            model.receptive_field,
            string_path=args.string_path,
            seg=model.seg,
            batch_size=args.batch_size,
            training=False,
            upsampling_factor=args.upsampling_factor)

    # train
    loss_laplace = []
    loss_err = []
    loss_lsd = []
    fft_facts = []
    init_fft = 64
    hann_win = [None]*args.n_fft_facts
    if args.n_fft_facts == 5:
        fft_facts = [128, 256, 512, 1024, 2048]
        for i in range(args.n_fft_facts):
            hann_win[i] = torch.hann_window(fft_facts[i]).cuda()
    elif args.n_fft_facts == 9:
        fft_facts = [128, 192, 256, 384, 512, 768, 1024, 1536, 2048]
        for i in range(args.n_fft_facts):
            hann_win[i] = torch.hann_window(fft_facts[i]).cuda()
    elif args.n_fft_facts == 17:
        fft_facts = [128, 160, 192, 224, 256, 320, 384, 448, 512, 640, 768, 896, 1024, 1280, 1536, 1792, 2048]
        for i in range(args.n_fft_facts):
            hann_win[i] = torch.hann_window(fft_facts[i]).cuda()
    else:
        for i in range(args.n_fft_facts):
            if i % 2 == 0:
                init_fft *= 2
                fft_facts.append(init_fft)
            else:
                fft_facts.append(init_fft+int(init_fft/2))
            hann_win[i] = torch.hann_window(fft_facts[i]).cuda()
    logging.info(fft_facts)
    batch_stft_loss = [None]*args.n_fft_facts
    stft_out = [None]*args.n_fft_facts
    stft_trg = [None]*args.n_fft_facts
    total = 0
    iter_idx = 0
    iter_count = 0
    min_eval_loss_lsd = 99999999.99
    min_eval_loss_laplace = 99999999.99
    min_eval_loss_err = 99999999.99
    min_eval_loss_lsd_std = 99999999.99
    min_eval_loss_laplace_std = 99999999.99
    min_eval_loss_err_std = 99999999.99
    min_idx = -1
    if args.resume is not None:
        np.random.set_state(checkpoint["numpy_random_state"])
        torch.set_rng_state(checkpoint["torch_random_state"])
    logging.info("==%d EPOCH==" % (epoch_idx+1))
    logging.info("Training data")
    #args.epoch_count = 5300
    while epoch_idx < args.epoch_count:
        start = time.time()
        batch_x_float, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = next(generator)
        if c_idx < 0: # summarize epoch
            # save current epoch model
            numpy_random_state = np.random.get_state()
            torch_random_state = torch.get_rng_state()
            save_checkpoint(args.expdir, model, optimizer, numpy_random_state, torch_random_state, epoch_idx+1)
            # report current epoch
            logging.info("(EPOCH:%d) average training loss = %.6f (+- %.6f) %.6f dB (+- %.6f dB) %.6f "\
                            "(+- %.6f) (%.3f min., %.3f sec / batch)" % (
                epoch_idx + 1, np.mean(loss_laplace), np.std(loss_laplace), np.mean(loss_lsd), \
                    np.std(loss_lsd), np.mean(loss_err), np.std(loss_err), total / 60.0, total / iter_count))
            logging.info("estimated training required time = {0.days:02}:{0.hours:02}:{0.minutes:02}:"\
            "{0.seconds:02}".format(relativedelta(seconds=int((args.epoch_count - (epoch_idx + 1)) * total))))
            # compute loss in evaluation data
            loss_lsd = []
            loss_err = []
            loss_laplace = []
            total = 0
            iter_count = 0
            model.eval()
            for param in model.parameters():
                param.requires_grad = False
            logging.info("Evaluation data")
            while True:
                with torch.no_grad():
                    start = time.time()
                    batch_x_float, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = \
                        next(generator_eval)
                    if c_idx < 0:
                        break

                    tf = batch_h.shape[0]
                    ts = batch_x_float.shape[0]

                    batch_h = batch_h[h_ss:]
                    batch_x_ = batch_x_float[x_ss:]
                    if model.lpc > 0:
                        if x_ss+model.lpc_offset >= 0:
                            batch_x_lpc = batch_x_float[x_ss+model.lpc_offset:]
                        else:
                            batch_x_lpc = batch_x_float[x_ss:]
                    if h_bs != -1:
                        batch_h = batch_h[:h_bs]
                        if model.lpc > 0:
                            if x_ss+model.lpc_offset >= 0:
                                batch_x_prob = batch_x_lpc[:x_bs-model.lpc_offset].unsqueeze(0)
                            else:
                                batch_x_prob = F.pad(batch_x_lpc[:x_bs], (-(x_ss+model.lpc_offset), 0), \
                                                        'constant', 0).unsqueeze(0)
                        batch_x = batch_x_[:x_bs-model.seg]
                        batch_x_float = batch_x_[model.seg:x_bs]
                    else:
                        if model.lpc > 0:
                            if x_ss+model.lpc_offset > 0:
                                batch_x_prob = batch_x_lpc.unsqueeze(0)
                            else:
                                batch_x_prob = F.pad(batch_x_lpc, (-(x_ss+model.lpc_offset)), \
                                                    'constant', 0).unsqueeze(0)
                        batch_x = batch_x_[:-model.seg]
                        batch_x_float = batch_x_[model.seg:]
                    batch_h = batch_h.transpose(0,1).unsqueeze(0)
                    batch_x = batch_x.unsqueeze(0).unsqueeze(1)
                    if h_ss > 0:
                        feat_len = batch_x_float[model.receptive_field:].shape[0]
                    else:
                        feat_len = batch_x_float.shape[0]

                    if model.lpc > 0:
                        mus, bs, log_bs, ass = model(batch_h, batch_x)
                        # jump off s samples as in synthesis
                        mus = mus[:,::model.seg,:]
                        bs = bs[:,::model.seg,:]
                        log_bs = log_bs[:,::model.seg,:]
                        ass = ass[:,::model.seg,:].flip(-1)
                        init_mus = mus
                        for j in range(model.seg):
                            tmp_smpls = batch_x_prob[:,j:-(model.seg-j)].unfold(1, model.lpc, model.seg)
                            lpc = torch.sum(ass*tmp_smpls,-1,keepdim=True)
                            if j > 0:
                                mus = torch.cat((mus, lpc+init_mus[:,:,j:j+1]),2)
                            else:
                                mus = lpc+init_mus[:,:,j:j+1]
                        mus = mus.reshape(mus.shape[0],-1)
                        bs = bs.reshape(bs.shape[0],-1)
                        log_bs = log_bs.reshape(log_bs.shape[0],-1)
                    else:
                        mus, bs, log_bs = model(batch_h, batch_x)

                    if h_ss > 0:
                        mus = mus[0,model.receptive_field:]
                        bs = bs[0,model.receptive_field:]
                        log_bs = log_bs[0,model.receptive_field:]
                        batch_x_float = batch_x_float[model.receptive_field:]
                    else:
                        mus = mus[0]
                        bs = bs[0]
                        log_bs = log_bs[0]

                    m_sum = 0
                    batch_loss_laplace = criterion_laplace(mus, bs, batch_x_float, log_b=log_bs)
                    eps = torch.empty(mus.shape).cuda().uniform_(-0.4999,0.5)
                    batch_output = mus-bs*eps.sign()*torch.log1p(-2*eps.abs())
                    batch_loss_err = torch.mean(torch.abs(batch_output-batch_x_float))
                    logging.info("%lf %E %lf %E" % (torch.min(batch_x_float), torch.mean(batch_x_float), \
                                                    torch.max(batch_x_float), torch.var(batch_x_float)))
                    logging.info("%lf %E %lf %E" % (torch.min(batch_output), torch.mean(batch_output), \
                                                    torch.max(batch_output), torch.var(batch_output)))
                    m = 0
                    for i in range(args.n_fft_facts):
                        if feat_len > int(fft_facts[i]/2):
                            stft_out[i] = torch.stft(batch_output, fft_facts[i], window=hann_win[i])
                            stft_trg[i] = torch.stft(batch_x_float, fft_facts[i], window=hann_win[i])
                            tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i])
                            if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss):
                                if m > 0:
                                    batch_loss_lsd = torch.cat((batch_loss_lsd, \
                                                                tmp_batch_stft_loss.unsqueeze(0)))
                                else:
                                    batch_loss_lsd = tmp_batch_stft_loss.unsqueeze(0)
                                m += 1

                    loss_err.append(batch_loss_err.item())
                    loss_laplace.append(batch_loss_laplace.item())
                    if m > 0:
                        batch_loss_lsd = torch.mean(batch_loss_lsd)
                        loss_lsd.append(batch_loss_lsd.item())
                        logging.info("batch eval loss %s [%d:%d] %d %d %d %d %d %d = %.3f %.3f dB %.6f "\
                            "(%.3f sec)" % (os.path.join(os.path.basename(os.path.dirname(wavfile)),\
                            os.path.basename(wavfile)), c_idx+1, utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, \
                            batch_loss_laplace.item(), batch_loss_lsd.item(), \
                            batch_loss_err.item(), time.time() - start))
                    else:
                        logging.info("batch eval loss %s [%d:%d] %d %d %d %d %d %d = %.3f n/a %.6f "\
                            "(%.3f sec)" % (os.path.join(os.path.basename(os.path.dirname(wavfile)),\
                            os.path.basename(wavfile)), c_idx+1, utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, \
                            batch_loss_laplace.item(), batch_loss_err.item(), time.time() - start))
                    iter_count += 1
                    total += time.time() - start
            eval_loss_lsd = np.mean(loss_lsd)
            eval_loss_lsd_std = np.std(loss_lsd)
            eval_loss_err = np.mean(loss_err)
            eval_loss_err_std = np.std(loss_err)
            eval_loss_laplace = np.mean(loss_laplace)
            eval_loss_laplace_std = np.std(loss_laplace)
            logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) %.6f dB (+- %.6f dB) %.6f "\
                "(+- %.6f) (%.3f min., %.3f sec / batch)" % (epoch_idx + 1, eval_loss_laplace, \
                eval_loss_laplace_std, eval_loss_lsd, eval_loss_lsd_std, eval_loss_err, eval_loss_err_std, \
                total / 60.0, total / iter_count))
            if (eval_loss_laplace+eval_loss_laplace_std+eval_loss_lsd+eval_loss_lsd_std+eval_loss_err\
                +eval_loss_err_std) <= (min_eval_loss_laplace+min_eval_loss_laplace_std+min_eval_loss_lsd\
                    +min_eval_loss_lsd_std+min_eval_loss_err+min_eval_loss_err_std):
                min_eval_loss_lsd = eval_loss_lsd
                min_eval_loss_lsd_std = eval_loss_lsd_std
                min_eval_loss_err = eval_loss_err
                min_eval_loss_err_std = eval_loss_err_std
                min_eval_loss_laplace = eval_loss_laplace
                min_eval_loss_laplace_std = eval_loss_laplace_std
                min_idx = epoch_idx
            logging.info("min_eval_loss = %.6f (+- %.6f) %.6f dB (+- %.6f dB) %.6f (+- %.6f) min_idx=%d" % (
                min_eval_loss_laplace, min_eval_loss_laplace_std, min_eval_loss_lsd, min_eval_loss_lsd_std, \
                min_eval_loss_err, min_eval_loss_err_std, min_idx+1))
            loss_lsd = []
            loss_laplace = []
            loss_err = []
            total = 0
            iter_count = 0
            epoch_idx += 1
            np.random.set_state(numpy_random_state)
            torch.set_rng_state(torch_random_state)
            model.train()
            for param in model.parameters():
                param.requires_grad = True
            for param in model.scale_in.parameters():
                param.requires_grad = False
            # start next epoch
            if epoch_idx < args.epoch_count:
                start = time.time()
                logging.info("==%d EPOCH==" % (epoch_idx+1))
                logging.info("Training data")
                batch_x_float, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = next(generator)
        # feedforward and backpropagate current batch
        if epoch_idx < args.epoch_count:
            logging.info("%d iteration [%d]" % (iter_idx+1, epoch_idx+1))

            tf = batch_h.shape[0]
            ts = batch_x_float.shape[0]

            batch_h = batch_h[h_ss:]
            batch_x_ = batch_x_float[x_ss:]
            if model.lpc > 0:
                if x_ss+model.lpc_offset >= 0:
                    batch_x_lpc = batch_x_float[x_ss+model.lpc_offset:]
                else:
                    batch_x_lpc = batch_x_float[x_ss:]
            if h_bs != -1:
                batch_h = batch_h[:h_bs]
                if model.lpc > 0:
                    if x_ss+model.lpc_offset >= 0:
                        batch_x_prob = batch_x_lpc[:x_bs-model.lpc_offset].unsqueeze(0)
                    else:
                        batch_x_prob = F.pad(batch_x_lpc[:x_bs], (-(x_ss+model.lpc_offset), 0), \
                                                'constant', 0).unsqueeze(0)
                batch_x = batch_x_[:x_bs-model.seg]
                batch_x_float = batch_x_[model.seg:x_bs]
            else:
                if model.lpc > 0:
                    if x_ss+model.lpc_offset > 0:
                        batch_x_prob = batch_x_lpc.unsqueeze(0)
                    else:
                        batch_x_prob = F.pad(batch_x_lpc, (-(x_ss+model.lpc_offset)), \
                                                'constant', 0).unsqueeze(0)
                batch_x = batch_x_[:-model.seg]
                batch_x_float = batch_x_[model.seg:]
            batch_h = batch_h.transpose(0,1).unsqueeze(0)
            batch_x = batch_x.unsqueeze(0).unsqueeze(1)
            if h_ss > 0:
                if model.seg > 1:
                    feat_len = batch_x_float[model.receptive_field:-(model.seg-1)].shape[0]
                else:
                    feat_len = batch_x_float[model.receptive_field:].shape[0]
            else:
                if model.seg > 1:
                    feat_len = batch_x_float[:-(model.seg-1)].shape[0]
                else:
                    feat_len = batch_x_float.shape[0]

            if model.lpc > 0:
                mus, bs_noclip, bs, log_bs, ass = model(batch_h, batch_x, do=True, clip=True)
                ass = ass.flip(-1)
                init_mus = mus
                for j in range(model.seg):
                    tmp_smpls = batch_x_prob[:,j:-(model.seg-j)].unfold(1, model.lpc, 1)
                    lpc = torch.sum(ass*tmp_smpls,-1,keepdim=True)
                    if j > 0:
                        mus = torch.cat((mus, lpc+init_mus[:,:,j:j+1]),2)
                    else:
                        mus = lpc+init_mus[:,:,j:j+1]
                if model.seg == 1:
                    mus = mus.reshape(mus.shape[0], -1)
                    bs_noclip = bs_noclip.reshape(mus.shape[0], -1)
                    bs = bs.reshape(mus.shape[0], -1)
                    log_bs = log_bs.reshape(mus.shape[0], -1)
            else:
                mus, bs_noclip, bs, log_bs = model(batch_h, batch_x, do=True, clip=True)

            if h_ss > 0:
                mus = mus[0,model.receptive_field:]
                bs_noclip = bs_noclip[0,model.receptive_field:]
                bs = bs[0,model.receptive_field:]
                log_bs = log_bs[0,model.receptive_field:]
                batch_x_float = batch_x_float[model.receptive_field:]
            else:
                mus = mus[0]
                bs_noclip = bs_noclip[0]
                bs = bs[0]
                log_bs = log_bs[0]

            m_sum = 0
            if model.seg > 1:
                n_sum = 0
                for i in range(model.seg):
                    if i > 0:
                        i_n = i+1
                        mus_i = mus[:,i:i_n].squeeze(-1)
                        bs_noclip_i = bs_noclip[:,i:i_n].squeeze(-1)
                        if i_n < model.seg:
                            batch_x_float_i = batch_x_float[i:-(model.seg-(i_n))]
                        else:
                            batch_x_float_i = batch_x_float[i:]
                        tmp_batch_loss_laplace = criterion_laplace(mus_i, bs[:,i:i_n].squeeze(-1), \
                                                batch_x_float_i, log_b=log_bs[:,i:i_n].squeeze(-1), log=False)
                        batch_loss_laplace = torch.cat((batch_loss_laplace, \
                                                        tmp_batch_loss_laplace.unsqueeze(0)))
                    else:
                        mus_i = mus[:,:1].squeeze(-1)
                        bs_noclip_i = bs_noclip[:,:1].squeeze(-1)
                        batch_x_float_i = batch_x_float[:-(model.seg-1)]
                        tmp_batch_loss_laplace = criterion_laplace(mus_i, bs[:,:1].squeeze(-1), \
                                                    batch_x_float_i, log_b=log_bs[:,:1].squeeze(-1))
                        batch_loss_laplace = tmp_batch_loss_laplace.unsqueeze(0)
                    eps = torch.empty(mus_i.shape).cuda().uniform_(-0.4999,0.5)
                    batch_output = mus_i-bs_noclip_i*eps.sign()*torch.log1p(-2*eps.abs())
                    tmp_batch_loss_err = torch.mean(torch.abs(batch_output-batch_x_float_i))
                    if i > 0:
                        batch_loss_err = torch.cat((batch_loss_err, tmp_batch_loss_err.unsqueeze(0)))
                    else:
                        batch_loss_err = tmp_batch_loss_err.unsqueeze(0)
                    if i == 0:
                        logging.info("%lf %E %lf %E" % (torch.min(batch_x_float_i), \
                        torch.mean(batch_x_float_i), torch.max(batch_x_float_i), torch.var(batch_x_float_i)))
                        logging.info("%lf %E %lf %E" % (torch.min(batch_output), torch.mean(batch_output), \
                        torch.max(batch_output), torch.var(batch_output)))
                    n = 0
                    for i in range(args.n_fft_facts):
                        if feat_len > int(fft_facts[i]/2):
                            stft_out[i] = torch.stft(batch_output, fft_facts[i], window=hann_win[i])
                            stft_trg[i] = torch.stft(batch_x_float_i, fft_facts[i], window=hann_win[i])
                            tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i], LSD=False, L2=False)
                            if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss):
                                if n > 0:
                                    tmp_batch_loss_stft_l1 = torch.cat((tmp_batch_loss_stft_l1, \
                                                                        tmp_batch_stft_loss.unsqueeze(0)))
                                else:
                                    tmp_batch_loss_stft_l1 = tmp_batch_stft_loss.unsqueeze(0)
                                n += 1
                    if n > 0:
                        if n_sum > 0:
                            batch_loss_stft_l1 = torch.cat((batch_loss_stft_l1, \
                                                            torch.mean(tmp_batch_loss_stft_l1).unsqueeze(0)))
                        else:
                            batch_loss_stft_l1 = torch.mean(tmp_batch_loss_stft_l1).unsqueeze(0)
                    n_sum += n
                    m = 0
                    for i in range(args.n_fft_facts):
                        if feat_len > int(fft_facts[i]/2):
                            tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i])
                            if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss):
                                if m > 0:
                                    tmp_batch_loss_lsd = torch.cat((tmp_batch_loss_lsd, \
                                                                    tmp_batch_stft_loss.unsqueeze(0)))
                                else:
                                    tmp_batch_loss_lsd = tmp_batch_stft_loss.unsqueeze(0)
                                m += 1
                    if m > 0:
                        if m_sum > 0:
                            batch_loss_lsd = torch.cat((batch_loss_lsd, \
                                                        torch.mean(tmp_batch_loss_lsd).unsqueeze(0)))
                        else:
                            batch_loss_lsd = torch.mean(tmp_batch_loss_lsd).unsqueeze(0)
                    m_sum += m
                batch_loss_laplace = torch.mean(batch_loss_laplace)
                batch_loss = batch_loss_laplace
                if n_sum > 0:
                    batch_loss += torch.mean(batch_loss_stft_l1)
                if m_sum > 0:
                    batch_loss_lsd = torch.mean(batch_loss_lsd)
                batch_loss_err = torch.mean(batch_loss_err)
            else:
                batch_loss_laplace = criterion_laplace(mus, bs, batch_x_float, log_b=log_bs)
                batch_loss = batch_loss_laplace
                eps = torch.empty(mus.shape).cuda().uniform_(-0.4999,0.5)
                batch_output = mus-bs_noclip*eps.sign()*torch.log1p(-2*eps.abs())
                batch_loss_err = torch.mean(torch.abs(batch_output-batch_x_float))
                logging.info("%lf %E %lf %E" % (torch.min(batch_x_float), torch.mean(batch_x_float), \
                                                torch.max(batch_x_float), torch.var(batch_x_float)))
                logging.info("%lf %E %lf %E" % (torch.min(batch_output), torch.mean(batch_output), \
                                                torch.max(batch_output), torch.var(batch_output)))
                n = 0
                for i in range(args.n_fft_facts):
                    if feat_len > int(fft_facts[i]/2):
                        stft_out[i] = torch.stft(batch_output, fft_facts[i], window=hann_win[i])
                        stft_trg[i] = torch.stft(batch_x_float, fft_facts[i], window=hann_win[i])
                        tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i], LSD=False, L2=False)
                        if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss):
                            if n > 0:
                                batch_loss_stft_l1 = torch.cat((batch_loss_stft_l1, \
                                                                tmp_batch_stft_loss.unsqueeze(0)))
                            else:
                                batch_loss_stft_l1 = tmp_batch_stft_loss.unsqueeze(0)
                            n += 1
                if n > 0:
                    batch_loss += torch.mean(batch_loss_stft_l1)
                m = 0
                for i in range(args.n_fft_facts):
                    if feat_len > int(fft_facts[i]/2):
                        tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i])
                        if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss):
                            if m > 0:
                                batch_loss_lsd = torch.cat((batch_loss_lsd, tmp_batch_stft_loss.unsqueeze(0)))
                            else:
                                batch_loss_lsd = tmp_batch_stft_loss.unsqueeze(0)
                            m += 1
                if m > 0:
                    batch_loss_lsd = torch.mean(batch_loss_lsd)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_err.append(batch_loss_err.item())
            loss_laplace.append(batch_loss_laplace.item())
            if (model.seg > 1 and m_sum > 0) or (model.seg == 1 and m > 0):
                loss_lsd.append(batch_loss_lsd.item())
                logging.info("batch loss %s [%d:%d] %d %d %d %d %d %d = %.3f %.3f dB %.6f (%.3f sec)" % (
                    os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1, \
                    utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss_laplace.item(), \
                    batch_loss_lsd.item(), batch_loss_err.item(), time.time() - start))
            else:
                logging.info("batch loss %s [%d:%d] %d %d %d %d %d %d = %.3f n/a %.6f (%.3f sec)" % (
                    os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1, \
                    utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss_laplace.item(), \
                    batch_loss_err.item(), time.time() - start))
            iter_idx += 1
            iter_count += 1
            total += time.time() - start

    # save final model
    model.cpu()
    torch.save({"model": model.state_dict()}, args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
예제 #7
0
    def __getitem__(self, idx):
        """Load one source utterance and build its multi-cycle conversion inputs.

        Reads ``/feat_org_lf0`` from ``self.file_list[idx]`` and obtains per-cycle
        target statistics and speaker/class codes from
        ``proc_multspk_data_mix_random_cls_statcvexcit``. For each of the
        ``self.n_cyc`` cycles it builds one converted feature matrix by moving
        feature column 1 from the source speaker's mean/scale to the cycle's
        target mean/scale (columns 0 and 2..stdim-1 are copied unchanged).
        Every array is padded with ``self.pad_transform`` and wrapped in a
        torch tensor.

        Args:
            idx: index into ``self.file_list``.

        Returns:
            dict with the padded source features, codes, converted features,
            ``/spcidx_range`` indices, lengths and file names. When
            ``self.pair_utt_flag`` is true it also holds the same-named
            utterance of ``pair_spk_list[0]`` (``h_src_trg`` and related
            entries). If that file does not exist, the source tensors are
            reused and ``file_src_trg_flag`` stays False.
        """

        def _read_spcidx(featfile):
            # Return "/spcidx_range"[0] of featfile. If featfile lacks that
            # dataset, read the same-named file in the sibling directory named
            # after the parent directory's prefix up to the first '_'.
            if check_hdf5(featfile, "/spcidx_range"):
                return read_hdf5(featfile, "/spcidx_range")[0]
            base_spk = os.path.basename(os.path.dirname(featfile)).split('_')[0]
            fallback = os.path.join(os.path.dirname(os.path.dirname(featfile)), base_spk,
                                    os.path.basename(featfile))
            return read_hdf5(fallback, "/spcidx_range")[0]

        featfile_src = self.file_list[idx]

        h_src = read_hdf5(featfile_src, "/feat_org_lf0")
        flen_src = h_src.shape[0]
        # Handed to the helper below before padding.
        # NOTE(review): presumably filled in there as a one-hot code - confirm.
        src_code = np.zeros((flen_src, self.n_spk))

        spk_idx, mean_trg_list, std_trg_list, src_trg_code_list, featfile_spk, pair_spk_list, src_class_code, \
            trg_class_code_list = proc_multspk_data_mix_random_cls_statcvexcit(featfile_src, self.spk_list, \
                self.n_cyc, src_code, self.n_spk, self.spk_idx_dict, self.stat_spk_list)

        # Only column 1 of the features is converted, so only its statistics are kept.
        mean_src = read_hdf5(self.stat_spk_list[spk_idx], "/mean_feat_org_lf0")[1:2]
        std_src = read_hdf5(self.stat_spk_list[spk_idx], "/scale_feat_org_lf0")[1:2]

        spcidx_src = _read_spcidx(featfile_src)
        flen_spc_src = spcidx_src.shape[0]

        src_code = torch.FloatTensor(self.pad_transform(src_code))
        src_class_code = torch.LongTensor(self.pad_transform(src_class_code))

        # One converted feature matrix per cycle. The code lists returned by
        # the helper are padded in place so the returned dict holds tensors.
        cv_src_list = [None] * self.n_cyc
        for i in range(self.n_cyc):
            cv_src_list[i] = torch.FloatTensor(self.pad_transform(np.c_[h_src[:, :1], \
                                (std_trg_list[i] / std_src) * (h_src[:, 1:2] - mean_src) + mean_trg_list[i], \
                                h_src[:, 2:self.stdim]]))
            src_trg_code_list[i] = torch.FloatTensor(self.pad_transform(src_trg_code_list[i]))
            trg_class_code_list[i] = torch.LongTensor(self.pad_transform(trg_class_code_list[i]))

        h_src = torch.FloatTensor(self.pad_transform(h_src))
        spcidx_src = torch.LongTensor(self.pad_transform(spcidx_src))

        file_src_trg_flag = False
        if self.pair_utt_flag:
            # Same-named utterance of the first paired speaker. os.path.join
            # (rather than "/"-concatenation) avoids producing a root-absolute
            # path when featfile_src has fewer than two directory levels.
            featfile_src_trg = os.path.join(os.path.dirname(os.path.dirname(featfile_src)),
                                            pair_spk_list[0], os.path.basename(featfile_src))
            if os.path.exists(featfile_src_trg):
                file_src_trg_flag = True
                h_src_trg = read_hdf5(featfile_src_trg, "/feat_org_lf0")
                flen_src_trg = h_src_trg.shape[0]
                spcidx_src_trg = _read_spcidx(featfile_src_trg)
                flen_spc_src_trg = spcidx_src_trg.shape[0]
                h_src_trg = torch.FloatTensor(self.pad_transform(h_src_trg))
                spcidx_src_trg = torch.LongTensor(self.pad_transform(spcidx_src_trg))
            else:
                # No parallel file: reuse the source tensors as placeholders.
                h_src_trg = h_src
                flen_src_trg = flen_src
                spcidx_src_trg = spcidx_src
                flen_spc_src_trg = flen_spc_src
            return {'h_src': h_src, 'flen_src': flen_src, 'src_code': src_code, \
                    'src_trg_code_list': src_trg_code_list, 'cv_src_list': cv_src_list, \
                    'spcidx_src': spcidx_src, 'flen_spc_src': flen_spc_src, 'h_src_trg': h_src_trg, \
                    'flen_src_trg': flen_src_trg, 'spcidx_src_trg': spcidx_src_trg, \
                    'flen_spc_src_trg': flen_spc_src_trg, 'featfile_src': featfile_src, \
                    'featfile_src_trg': featfile_src_trg, 'featfile_spk': featfile_spk, \
                    'pair_spk_list': pair_spk_list, 'src_class_code': src_class_code, \
                    'trg_class_code_list': trg_class_code_list, 'file_src_trg_flag': file_src_trg_flag}

        return {'h_src': h_src, 'flen_src': flen_src, 'src_code': src_code, \
                'src_trg_code_list': src_trg_code_list, 'cv_src_list': cv_src_list, \
                'spcidx_src': spcidx_src, 'flen_spc_src': flen_spc_src, 'featfile_src': featfile_src, \
                'featfile_spk': featfile_spk, 'pair_spk_list': pair_spk_list, \
                'src_class_code': src_class_code, 'trg_class_code_list': trg_class_code_list, \
                'file_src_trg_flag': file_src_trg_flag}
예제 #8
0
    def __getitem__(self, idx):
        """Return the padded source (and, when available, target) tensors for item ``idx``.

        Speaker identities are the parent-directory names of
        ``self.file_list_src[idx]`` and ``self.file_list_src_trg[idx]``, looked
        up in ``self.spk_idx_dict``. Feature column 1 of each utterance is
        re-normalized from its own speaker's mean/scale (read from
        ``self.stat_spk_list``) to the other speaker's, giving ``cv_src`` /
        ``cv_trg``. When ``self.list_src_trg_flag[idx]`` is false the target
        file is never read, and every target-side entry reuses the
        corresponding source-side tensor.

        Args:
            idx: index shared by ``file_list_src``, ``file_list_src_trg`` and
                ``list_src_trg_flag``.

        Returns:
            dict of padded FloatTensor/LongTensor entries plus frame counts,
            file paths and the ``file_src_trg_flag`` value.
        """

        def spk_of(path):
            # Speaker index of the directory that directly contains path.
            return self.spk_idx_dict[os.path.basename(os.path.dirname(path))]

        def col1_stat(spk, key):
            # Column-1 slice (kept 1-D) of a per-speaker statistics vector.
            return read_hdf5(self.stat_spk_list[spk], key)[1:2]

        def one_hot(n_frames, spk):
            code = np.zeros((n_frames, self.n_spk))
            code[:, spk] = 1
            return code

        def class_code(n_frames, spk):
            return np.ones(n_frames, dtype=np.int64) * spk

        def convert(feat, mu_from, sd_from, mu_to, sd_to):
            # Move column 1 from (mu_from, sd_from) to (mu_to, sd_to); the
            # other columns up to self.stdim are copied unchanged.
            moved = (sd_to / sd_from) * (feat[:, 1:2] - mu_from) + mu_to
            return np.c_[feat[:, :1], moved, feat[:, 2:self.stdim]]

        def load_spcidx(path):
            # Prefer path's own "/spcidx_range"; otherwise use the same-named
            # file under the directory named by the parent-dir prefix before '_'.
            if check_hdf5(path, "/spcidx_range"):
                return read_hdf5(path, "/spcidx_range")[0]
            base_spk = os.path.basename(os.path.dirname(path)).split('_')[0]
            fallback = os.path.join(os.path.dirname(os.path.dirname(path)), base_spk,
                                    os.path.basename(path))
            return read_hdf5(fallback, "/spcidx_range")[0]

        def as_float(arr):
            return torch.FloatTensor(self.pad_transform(arr))

        def as_long(arr):
            return torch.LongTensor(self.pad_transform(arr))

        src_path = self.file_list_src[idx]
        trg_path = self.file_list_src_trg[idx]
        has_trg = self.list_src_trg_flag[idx]

        src_spk = spk_of(src_path)
        trg_spk = spk_of(trg_path)

        mu_src = col1_stat(src_spk, "/mean_feat_org_lf0")
        sd_src = col1_stat(src_spk, "/scale_feat_org_lf0")
        mu_trg = col1_stat(trg_spk, "/mean_feat_org_lf0")
        sd_trg = col1_stat(trg_spk, "/scale_feat_org_lf0")

        # Source side, still as numpy arrays.
        feat_src = read_hdf5(src_path, "/feat_org_lf0")
        n_src = feat_src.shape[0]
        code_src = one_hot(n_src, src_spk)
        code_src_trg = one_hot(n_src, trg_spk)
        conv_src = convert(feat_src, mu_src, sd_src, mu_trg, sd_trg)
        spc_src = load_spcidx(src_path)
        cls_src = class_code(n_src, src_spk)
        cls_src_trg = class_code(n_src, trg_spk)
        n_spc_src = spc_src.shape[0]

        # Target side, only when a parallel target file exists.
        if has_trg:
            feat_trg = read_hdf5(trg_path, "/feat_org_lf0")
            n_trg = feat_trg.shape[0]
            code_trg = one_hot(n_trg, trg_spk)
            code_trg_src = one_hot(n_trg, src_spk)
            conv_trg = convert(feat_trg, mu_trg, sd_trg, mu_src, sd_src)
            spc_trg = load_spcidx(trg_path)
            cls_trg = class_code(n_trg, trg_spk)
            cls_trg_src = class_code(n_trg, src_spk)
            n_spc_trg = spc_trg.shape[0]

        feat_src = as_float(feat_src)
        code_src = as_float(code_src)
        code_src_trg = as_float(code_src_trg)
        conv_src = as_float(conv_src)
        spc_src = as_long(spc_src)
        cls_src = as_long(cls_src)
        cls_src_trg = as_long(cls_src_trg)

        if has_trg:
            feat_trg = as_float(feat_trg)
            code_trg = as_float(code_trg)
            code_trg_src = as_float(code_trg_src)
            conv_trg = as_float(conv_trg)
            spc_trg = as_long(spc_trg)
            cls_trg = as_long(cls_trg)
            cls_trg_src = as_long(cls_trg_src)
        else:
            # No target file: mirror the source-side tensors as placeholders.
            n_trg = n_src
            feat_trg = feat_src
            code_trg = code_src
            code_trg_src = code_src_trg
            conv_trg = conv_src
            spc_trg = spc_src
            cls_trg = cls_src
            cls_trg_src = cls_src_trg
            n_spc_trg = n_spc_src

        return {'h_src': feat_src, 'flen_src': n_src, 'src_code': code_src, 'src_trg_code': code_src_trg,
                'cv_src': conv_src, 'spcidx_src': spc_src, 'flen_spc_src': n_spc_src,
                'h_src_trg': feat_trg, 'flen_src_trg': n_trg, 'trg_code': code_trg,
                'trg_src_code': code_trg_src, 'cv_trg': conv_trg, 'spcidx_src_trg': spc_trg,
                'flen_spc_src_trg': n_spc_trg, 'featfile_src': src_path,
                'featfile_src_trg': trg_path, 'src_class_code': cls_src,
                'src_trg_class_code': cls_src_trg, 'trg_class_code': cls_trg,
                'trg_src_class_code': cls_trg_src, 'file_src_trg_flag': has_trg}