Example #1
0
    hp.set_hparam_yaml(args.case)

    # model
    audio_meta_train = VoxCelebMeta(hp.train.data_path, hp.train.meta_path)
    model = ClassificationModel(num_classes=audio_meta_train.num_speaker, **hp.model)

    # data loader
    audio_meta_class = globals()[hp.embed.audio_meta_class]
    params = {'data_path': hp.embed.data_path}
    if hp.embed.meta_path:
            params['meta_path'] = hp.embed.meta_path
    audio_meta = audio_meta_class(**params)
    data_loader = DataLoader(audio_meta, hp.embed.batch_size)

    # samples
    wav, mel_spec, speaker_id = data_loader.dataflow().get_data().next()

    ckpt = args.ckpt if args.ckpt else tf.train.latest_checkpoint(hp.logdir)

    pred_conf = PredictConfig(
        model=model,
        input_names=['x'],
        output_names=['embedding/embedding', 'prediction'],
        session_init=SaverRestore(ckpt) if ckpt else None)
    embedding_pred = OfflinePredictor(pred_conf)

    embedding, pred_speaker_id = embedding_pred(mel_spec)

    # get a random audio of the predicted speaker.
    wavfile_pred_speaker = np.array(map(lambda s: audio_meta_train.get_random_audio(s), pred_speaker_id))
    length = int(hp.signal.duration * hp.signal.sr)
Example #2
0
#!/usr/bin/env python
import argparse
from tensorpack.dataflow.remote import send_dataflow_zmq
from data_load import DataLoader, AudioMeta
from hparam import hparam as hp
import multiprocessing

if __name__ == '__main__':
    # Command-line interface: experiment case plus optional overrides for the
    # data path, destination ZMQ URL, and worker-thread count.
    parser = argparse.ArgumentParser()
    parser.add_argument('case', type=str, help='experiment case name.')
    parser.add_argument('-data_path', type=str)
    parser.add_argument('-dest_url', type=str)
    parser.add_argument('-num_thread', type=int, default=1)
    args = parser.parse_args()

    # Load hyper-parameters for this experiment case from its yaml file.
    hp.set_hparam_yaml(case=args.case)

    # Allow the command line to override the configured training data path.
    if args.data_path:
        hp.train.data_path = args.data_path

    # Build the dataflow that will be streamed to the remote consumer.
    audio_meta = AudioMeta(hp.train.data_path)
    data_loader = DataLoader(audio_meta, 1)
    # Fall back to ~2/3 of the CPU count when no thread count is given.
    # NOTE(review): with default=1 this fallback only triggers for an explicit
    # `-num_thread 0`; the int() cast is required because `// 1.5` yields a float.
    num_thread = args.num_thread if args.num_thread else int(
        multiprocessing.cpu_count() // 1.5)
    # Bug fix: the computed `num_thread` was previously ignored and
    # `args.num_thread` was passed instead.
    data_loader = data_loader.dataflow(nr_prefetch=5000,
                                       nr_thread=num_thread)

    # Push the prepared dataflow to the destination over ZMQ.
    send_dataflow_zmq(data_loader, args.dest_url)