示例#1
0
def main():
    """Train the GE2E speaker-embedding model.

    Command-line flags configure data, initialization, optimization and
    LSTM hyperparameters; the parsed ``args`` are handed to ``Feeder`` and
    ``GE2E``.  Three training feeders (``"libri"``, ``"vox1"``, ``"vox2"``)
    push batches into one shared queue, and the training loop consumes one
    batch per step.

    If ``--ckpt_dir`` holds a checkpoint, variables are restored from it and
    training resumes at the step after the restored ``global_step``;
    otherwise all variables are initialized and training starts at step 1.
    A summary is written to ``--ckpt_dir`` every step and a checkpoint is
    saved every ``--checkpoint_freq`` steps.  The summary writer is always
    closed (flushing pending events), even if training raises.
    """
    import os

    # Hyperparameters

    parser = argparse.ArgumentParser()

    # in_dir = ~/wav
    parser.add_argument("--in_dir",
                        type=str,
                        required=True,
                        help="input data(pickle) dir")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        required=True,
        help="checkpoint to save/ start with for train/inference")
    parser.add_argument("--mode",
                        default="train",
                        choices=["train", "test", "infer"],
                        help="setting mode for execution")

    # Saving Checkpoints, Data... etc
    parser.add_argument("--max_step",
                        type=int,
                        default=500000,
                        help="maximum steps in training")
    parser.add_argument("--checkpoint_freq",
                        type=int,
                        default=100,
                        help="how often save checkpoint")

    # Data
    parser.add_argument("--segment_length",
                        type=float,
                        default=1.6,
                        help="segment length in seconds")
    parser.add_argument("--spectrogram_scale",
                        type=int,
                        default=40,
                        help="scale of the input spectrogram")

    # Initialization
    parser.add_argument("--init_type",
                        type=str,
                        default="uniform",
                        help="type of initializer")
    parser.add_argument("--init_weight_range",
                        type=float,
                        default=0.1,
                        help="initial weight ranges from -0.1 to 0.1")

    # Optimization
    parser.add_argument("--loss_type",
                        default="softmax",
                        choices=["softmax", "contrast"],
                        help="loss type for optimization")
    parser.add_argument("--optimizer",
                        type=str,
                        default="sgd",
                        help="type of optimizer")
    parser.add_argument("--learning_rate",
                        type=float,
                        default=0.01,
                        help="learning rate")
    parser.add_argument("--l2_norm_clip",
                        type=float,
                        default=3.0,
                        help="L2-norm of gradient is clipped at")

    # Train
    parser.add_argument("--num_spk_per_batch",
                        type=int,
                        default=64,
                        help="N speakers of batch size N*M")
    parser.add_argument("--num_utt_per_batch",
                        type=int,
                        default=10,
                        help="M utterances of batch size N*M")

    # LSTM
    parser.add_argument("--lstm_proj_clip",
                        type=float,
                        default=0.5,
                        help="Gradient scale for projection node in LSTM")
    parser.add_argument("--num_lstm_stacks",
                        type=int,
                        default=3,
                        help="number of LSTM stacks")
    parser.add_argument("--num_lstm_cells",
                        type=int,
                        default=768,
                        help="number of LSTM cells")
    parser.add_argument("--dim_lstm_projection",
                        type=int,
                        default=256,
                        help="dimension of LSTM projection")

    # Scaled Cosine similarity
    parser.add_argument(
        "--scale_clip",
        type=float,
        default=0.01,
        help="Gradient scale for scale values in scaled cosine similarity")

    # Collect hparams
    args = parser.parse_args()

    # Set up Queue: shared by all feeders; the training loop takes one
    # batch from it per step.
    global_queue = queue.Queue()

    # Set up Feeders, in the same order as before (construct, then start,
    # one dataset at a time).  References are kept in `feeders` so every
    # feeder object stays alive for the whole run.
    feeders = []
    for dataset in ("libri", "vox1", "vox2"):
        feeder = Feeder(args, "train", dataset)
        feeder.set_up_feeder(global_queue)
        feeders.append(feeder)

    # Set up Model
    model = GE2E(args)
    graph = model.set_up_model("train")

    # The Saver is created inside the model's graph so that it collects
    # that graph's variables.
    with graph.as_default():
        saver = tf.train.Saver()

    ckpt_prefix = os.path.join(args.ckpt_dir, "model.ckpt")

    with tf.Session(graph=graph) as sess:

        train_writer = tf.summary.FileWriter(args.ckpt_dir, sess.graph)
        try:
            ckpt = tf.train.get_checkpoint_state(args.ckpt_dir)
            if ckpt and tf.train.checkpoint_exists(
                    ckpt.model_checkpoint_path):
                print('Restoring Variables from {}'.format(
                    ckpt.model_checkpoint_path))
                saver.restore(sess, ckpt.model_checkpoint_path)
                # Checkpoints are tagged with global_step, and a fresh run
                # trains step 1 first.  Assuming model.optimize increments
                # global_step once per step (TODO confirm in GE2E), the
                # restored value is the last step already trained, so resume
                # at the next one instead of repeating it.
                start_step = sess.run(model.global_step) + 1
            else:
                print('start from 0')
                sess.run(tf.global_variables_initializer())
                start_step = 1

            for num_step in range(start_step, args.max_step + 1):

                print("current step: " + str(num_step) + "th step")

                # Blocks until a feeder has put a batch on the queue.
                batch = global_queue.get()

                summary, training_loss, _ = sess.run(
                    [model.sim_mat_summary, model.total_loss, model.optimize],
                    feed_dict={
                        model.input_batch: batch[0],
                        model.target_batch: batch[1]
                    })
                train_writer.add_summary(summary, num_step)
                print("batch loss:" + str(training_loss))

                if num_step % args.checkpoint_freq == 0:
                    save_path = saver.save(sess,
                                           ckpt_prefix,
                                           global_step=model.global_step)
                    print("model saved in file: %s / %d th step" %
                          (save_path, sess.run(model.global_step)))
        finally:
            # Flush pending summary events to disk and release the file.
            train_writer.close()
def main():
    """Evaluate a GE2E checkpoint on (wav1, wav2, match) test pairs.

    Pairs come from a "test" ``Feeder`` through a queue.  Each wav's batch is
    run through ``model.norm_out`` and averaged along axis 0 into a d-vector,
    and the two d-vectors are compared by cosine similarity.

    Total score = (sum of scores of same-speaker pairs)
                - (sum of scores of different-speaker pairs).
    A perfect model scores close to the number of true pairs; a very bad one
    scores close to minus the number of false pairs.  The total and both pair
    counts are printed after the session closes.
    """
    # Hyperparameters
    parser = argparse.ArgumentParser()

    # Path -- wav name formatting: id_clip_uttnum.wav
    parser.add_argument("--test_dir", type=str, required=True,
                        help="input test dir")
    parser.add_argument("--ckpt_file", type=str, required=True,
                        help="checkpoint to start with for inference")

    # Data, Enrol and LSTM hyperparameters, as (flag, type, default, help),
    # registered in their original order.
    for flag, kind, default, text in (
            ("--segment_length", float, 1.6, "segment length in seconds"),
            ("--overlap_ratio", float, 0.5, "overlaping percentage"),
            ("--spectrogram_scale", int, 40, "scale of the input spectrogram"),
            ("--num_spk_per_batch", int, 5, "N speakers of batch size N*M"),
            ("--num_utt_per_batch", int, 10, "M utterances of batch size N*M"),
            ("--num_lstm_stacks", int, 3, "number of LSTM stacks"),
            ("--num_lstm_cells", int, 768, "number of LSTM cells"),
            ("--dim_lstm_projection", int, 256, "dimension of LSTM projection"),
    ):
        parser.add_argument(flag, type=kind, default=default, help=text)

    # Collect hparams
    args = parser.parse_args()

    global_queue = queue.Queue()
    feeder = Feeder(args, "test")
    feeder.set_up_feeder(global_queue)

    model = GE2E(args)
    graph = model.set_up_model("test")

    # Build the saver inside the model's graph so it sees its variables.
    with graph.as_default():
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        saver.restore(sess, args.ckpt_file)

        def dvector(batch):
            # Mean of the model's normalized outputs along axis 0.
            outputs = sess.run(model.norm_out,
                               feed_dict={model.input_batch: batch})
            return np.mean(outputs, axis=0)

        total_score = 0
        num_true_pairs = num_false_pairs = 0

        while len(feeder.wav_pairs):
            wav1_data, wav2_data, match = global_queue.get()
            vec1 = dvector(wav1_data)
            vec2 = dvector(wav2_data)
            final_score = np.dot(vec1, vec2) / (
                np.linalg.norm(vec1) * np.linalg.norm(vec2))
            print("final score:%s" % (final_score,))
            print("same? :%s" % (match,))
            # `==` (not `is`) so values that merely compare equal to
            # True/False are counted too; anything else is ignored.
            if match == True:  # noqa: E712
                total_score += final_score
                num_true_pairs += 1
            elif match == False:  # noqa: E712
                total_score -= final_score
                num_false_pairs += 1

    print("in total: %s" % (total_score,))
    print("num true pairs: %s" % (num_true_pairs,))
    print("num false pairs: %s" % (num_false_pairs,))
示例#3
0
def main():
    """Restore a GE2E checkpoint and extract d-vectors for a directory.

    Parses inference flags and exposes only the GPU(s) named by ``--gpu`` to
    CUDA.  It then builds the inference graph, restores ``--ckpt_file`` and
    starts a daemon ``worker`` thread.  Finally it calls
    ``save_dvector_of_dir_parallel(sess, feeder, model, args)``.

    NOTE(review): presumably ``worker`` consumes work produced by
    ``save_dvector_of_dir_parallel`` and results are written under
    ``--out_dir`` -- confirm against those helpers.
    """
    import os

    # Hyperparameters

    parser = argparse.ArgumentParser()

    # Path

    # wav name formatting: id_clip_uttnum.wav
    parser.add_argument("--in_dir", type=str, required=True, help="input dir")
    parser.add_argument("--out_dir", type=str, required=True, help="out dir")
    parser.add_argument("--batch_inference",
                        action="store_true",
                        help="set whether to use the batch inference")
    parser.add_argument("--dataset",
                        type=str,
                        default="libri",
                        help="dataset name (e.g. libri)")
    parser.add_argument("--in_wav1", type=str, help="input wav1 dir")
    parser.add_argument("--in_wav2",
                        default="temp.wav",
                        type=str,
                        help="input wav2 dir")
    parser.add_argument("--mode",
                        default="infer",
                        choices=["train", "test", "infer"],
                        help="setting mode for execution")

    parser.add_argument("--ckpt_file",
                        type=str,
                        default='./xckpt/model.ckpt-58100',
                        help="checkpoint to start with for inference")

    # Data
    parser.add_argument("--segment_length",
                        type=float,
                        default=1.6,
                        help="segment length in seconds")
    parser.add_argument("--overlap_ratio",
                        type=float,
                        default=0.5,
                        help="overlaping percentage")
    parser.add_argument("--spectrogram_scale",
                        type=int,
                        default=40,
                        help="scale of the input spectrogram")
    # Enrol
    parser.add_argument("--num_spk_per_batch",
                        type=int,
                        default=5,
                        help="N speakers of batch size N*M")
    parser.add_argument("--num_utt_per_batch",
                        type=int,
                        default=10,
                        help="M utterances of batch size N*M")

    # LSTM
    parser.add_argument("--num_lstm_stacks",
                        type=int,
                        default=3,
                        help="number of LSTM stacks")
    parser.add_argument("--num_lstm_cells",
                        type=int,
                        default=768,
                        help="number of LSTM cells")
    parser.add_argument("--dim_lstm_projection",
                        type=int,
                        default=256,
                        help="dimension of LSTM projection")

    # GPU selection
    parser.add_argument('--gpu',
                        default='0',
                        help='GPU id(s) to expose via CUDA_VISIBLE_DEVICES')
    # type=int so a value given on the command line has the same type as
    # the integer default.
    parser.add_argument('--gpu_num',
                        type=int,
                        default=4,
                        help='number of GPUs')

    # Collect hparams
    args = parser.parse_args()

    # Number devices in PCI bus order and hide every GPU not listed in
    # --gpu.  This is done before the TF session below is created.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    feeder = Feeder(args)
    feeder.set_up_feeder()

    model = GE2E(args)
    graph = model.set_up_model()

    # Build the saver inside the model's graph so it sees its variables.
    with graph.as_default():
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        # restore from checkpoints
        saver.restore(sess, args.ckpt_file)

        # Daemon thread: it will not keep the process alive once the main
        # thread returns from save_dvector_of_dir_parallel.
        t = Thread(target=worker)
        t.daemon = True
        t.start()
        save_dvector_of_dir_parallel(sess, feeder, model, args)
def main():
    """Score whether two wav files come from the same speaker.

    Restores ``--ckpt_file`` into an "infer" GE2E graph and builds one batch
    per wav with ``feeder.create_infer_batch()``.  Each batch is run through
    ``model.norm_out`` and averaged along axis 0 into a d-vector.  The cosine
    similarity of the two d-vectors is printed together with the feeder's
    ``match`` value.
    """
    # Hyperparameters
    parser = argparse.ArgumentParser()

    # Path -- wav name formatting: id_clip_uttnum.wav
    for flag, text in (("--in_wav1", "input wav1 dir"),
                       ("--in_wav2", "input wav2 dir")):
        parser.add_argument(flag, type=str, required=True, help=text)
    parser.add_argument("--mode",
                        default="infer",
                        choices=["train", "test", "infer"],
                        help="setting mode for execution")
    parser.add_argument("--ckpt_file", type=str, required=True,
                        help="checkpoint to start with for inference")

    # Data, Enrol and LSTM hyperparameters, as (flag, type, default, help),
    # registered in their original order.
    numeric_options = (
        ("--segment_length", float, 1.6, "segment length in seconds"),
        ("--overlap_ratio", float, 0.5, "overlaping percentage"),
        ("--spectrogram_scale", int, 40, "scale of the input spectrogram"),
        ("--num_spk_per_batch", int, 5, "N speakers of batch size N*M"),
        ("--num_utt_per_batch", int, 10, "M utterances of batch size N*M"),
        ("--num_lstm_stacks", int, 3, "number of LSTM stacks"),
        ("--num_lstm_cells", int, 768, "number of LSTM cells"),
        ("--dim_lstm_projection", int, 256, "dimension of LSTM projection"),
    )
    for flag, kind, default, text in numeric_options:
        parser.add_argument(flag, type=kind, default=default, help=text)

    # Collect hparams
    args = parser.parse_args()

    feeder = Feeder(args, "infer")
    feeder.set_up_feeder()

    model = GE2E(args)
    graph = model.set_up_model("infer")

    # Inference only: build the saver inside the model's graph so it can
    # restore that graph's variables.
    with graph.as_default():
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        saver.restore(sess, args.ckpt_file)

        wav1_data, wav2_data, match = feeder.create_infer_batch()

        def embed(batch):
            # Mean of the model's normalized outputs along axis 0.
            outputs = sess.run(model.norm_out,
                               feed_dict={model.input_batch: batch})
            return np.mean(outputs, axis=0)

        first, second = embed(wav1_data), embed(wav2_data)

        # Cosine similarity between the two d-vectors.
        final_score = np.dot(first, second) / (
            np.linalg.norm(first) * np.linalg.norm(second))

        print("final score:%s" % (final_score,))
        print("same? :%s" % (match,))