Example No. 1
import sys

import tensorflow as tf

# the data-loading and model-construction helpers used below (load_utt_to_spk,
# load_utt_to_pdfs, load_sd_batchnorm_dataset, create_sat_batchnorm_model) are
# assumed to be imported from the surrounding project

if __name__ == '__main__':
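    # command-line arguments: train/val data, utt2spk speaker map, pdf targets,
    # left/right splicing context, LDA transform, and the output model path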
    train_data = sys.argv[1]
    val_data = sys.argv[2]
    utt2spk = sys.argv[3]
    pdfs = sys.argv[4]
    left_context = int(sys.argv[5])
    right_context = int(sys.argv[6])
    lda_path = sys.argv[7]
    output_path = sys.argv[8]

    num_epochs = 400
    batch_size = 1
    learning_rate = 0.0015

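    # map each utterance to a speaker index and to its frame-level pdf targets;
    # speaker indices are dense, so max(index) + 1 gives the speaker count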
    utt_to_spk = load_utt_to_spk(utt2spk)
    utt_to_pdfs = load_utt_to_pdfs(pdfs)
    num_spks = max(utt_to_spk.values()) + 1

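    # training pipeline: 50-frame chunks with splicing context, prefetched;
    # each element is a (features, speaker id, pdf targets) tuple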
    train_dataset = load_sd_batchnorm_dataset(train_data, utt_to_spk, utt_to_pdfs, chunk_size=50, subsampling_factor=1, left_context=left_context, right_context=right_context)
    train_dataset = train_dataset.prefetch(1024)
    x, spk, y = train_dataset.make_one_shot_iterator().get_next()

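    # validation pipeline: speaker-independent features only (si_prob=1.0);
    # the first 512 elements are cached and repeated so every evaluation sees
    # the same data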
    val_dataset = load_sd_batchnorm_dataset(val_data, utt_to_spk, utt_to_pdfs, chunk_size=50, subsampling_factor=1, left_context=left_context, right_context=right_context, si_prob=1.0)
    val_dataset = val_dataset.take(512).cache().repeat()
    val_x, val_spk, val_y = val_dataset.make_one_shot_iterator().get_next()

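    # build the speaker-adaptive batch-norm model from the LDA transform and
    # the number of speakers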
    model = create_sat_batchnorm_model(850, lda_path, num_spks)
    model.compile(
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
        # the original example is truncated here; an Adam optimizer at the
        # configured learning rate is assumed to complete the call
        optimizer=tf.keras.optimizers.Adam(learning_rate))


Example No. 2

import sys

# load_utt_to_spk, load_utt_to_pdfs, and load_dataset are assumed to be
# imported from the surrounding project

if __name__ == '__main__':
    train_data = sys.argv[1]
    val_data = sys.argv[2]
    utt2spk = sys.argv[3]
    pdfs = sys.argv[4]
    left_context = int(sys.argv[5])
    right_context = int(sys.argv[6])
    lda_path = sys.argv[7]
    output_path = sys.argv[8]

    num_epochs = 400
    batch_size = 256
    learning_rate = 0.0015
    model_type = 'SAT-LHUC'

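    # speaker ids are derived from the utterance name by joining the fields
    # between the first and last "-"-separated components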
    utt_to_spk = load_utt_to_spk(utt2spk, lambda x: "".join(x.split("-")[1:-1]))
    utt_to_pdfs = load_utt_to_pdfs(pdfs)
    num_spks = max(utt_to_spk.values()) + 1
    print("num_spks = %d" % num_spks)

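    # training pipeline: 8-frame chunks, batched and prefetched;
    # speaker_independent_prob=0.5 presumably mixes speaker-independent and
    # speaker-adapted chunks evenly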
    train_dataset = load_dataset(train_data, utt_to_spk, utt_to_pdfs, chunk_size=8, subsampling_factor=1, left_context=left_context, right_context=right_context, speaker_independent_prob=0.5)
    train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
    train_dataset = train_dataset.prefetch(1024)
    x, spk, y = train_dataset.make_one_shot_iterator().get_next()

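    # validation pipeline: speaker-independent only (speaker_independent_prob=1.0),
    # fixed to the first 512 batches via take/cache/repeat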
    val_dataset = load_dataset(val_data, utt_to_spk, utt_to_pdfs, chunk_size=8, subsampling_factor=1, left_context=left_context, right_context=right_context, speaker_independent_prob=1.0)
    val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
    val_dataset = val_dataset.take(512).cache().repeat()
    val_x, val_spk, val_y = val_dataset.make_one_shot_iterator().get_next()

    if model_type == 'SAT-LHUC':