Exemplo n.º 1
0
    if args.word_dict is not None:
        fin = open(args.word_dict, 'r')
        word_dic = {}
        for line in fin:
            tokens = line.split()
            word_dic[''.join(tokens[1:])] = tokens[0]

    device = torch.device('cuda')

    mdic = torch.load(args.model_dic)
    model = Seq2Seq(**mdic['params']).to(device)
    model.load_state_dict(mdic['state'])
    model.eval()

    reader = ScpStreamReader(args.data_scp,
                             mean_sub=args.mean_sub,
                             downsample=args.downsample)
    reader.initialize()

    space, beam_size, max_len = args.space, args.beam_size, args.max_len
    start, win, stable_time = args.start_block, args.incl_block, args.stable_time
    head, padding = args.attn_head, args.attn_padding

    since = time.time()
    fctm = open(args.output, 'w')
    total_latency = 0
    count = 0
    with torch.no_grad():
        while True:
            utt, mat = reader.read_next_utt()
            if utt is None or utt == '': break
Exemplo n.º 2
0
                          sort_src=True,
                          max_len=args.max_len,
                          max_utt=args.max_utt,
                          mean_sub=args.mean_sub,
                          zero_pad=args.zero_pad,
                          fp16=args.fp16,
                          shuffle=args.shuffle,
                          spec_drop=args.spec_drop,
                          spec_bar=args.spec_bar,
                          time_stretch=args.time_stretch,
                          time_win=args.time_win)
    cv_reader = ScpStreamReader(args.valid_scp,
                                args.valid_target,
                                downsample=args.downsample,
                                sort_src=True,
                                max_len=args.max_len,
                                max_utt=args.max_utt,
                                mean_sub=args.mean_sub,
                                zero_pad=args.zero_pad,
                                fp16=args.fp16)

    tensorboard_writer = None
    if args.tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        summary_name = '{}s2s-e{}d{}h{}'.format(args.tensorboard_dir,
                                                args.n_enc, args.n_dec,
                                                args.n_head)
        if args.use_cnn:
            summary_name += 'cnn'
        tensorboard_writer = SummaryWriter(summary_name)
Exemplo n.º 3
0
    models = []
    for dic_path in args.model_dic:
        sdic = torch.load(dic_path[0])
        m_params = sdic['params']
        if sdic['type'] == 'lstm':
            model = Seq2Seq(**m_params).to(device)
        elif sdic['type'] == 'tf':
            model = Transformer(**m_params).to(device)
        model.load_state_dict(sdic['state'])
        model.eval()
        models.append(model)
    model = Ensemble(models)

    reader = ScpStreamReader(args.data_scp,
                             mean_sub=args.mean_sub,
                             downsample=args.downsample)
    reader.initialize()

    since = time.time()
    batch_size = args.batch_size
    fout = open(args.output, 'w')
    while True:
        src_seq, src_mask, utts = reader.read_batch_utt(batch_size)
        if len(utts) == 0: break
        with torch.no_grad():
            src_seq, src_mask = src_seq.to(device), src_mask.to(device)
            hypos, scores = Decoder.beam_search(model, src_seq, src_mask,
                                                device, args.beam_size,
                                                args.max_len)
            hypos, scores = hypos.tolist(), scores.tolist()
Exemplo n.º 4
0
        dic = {}
        fin = open(args.dict, 'r')
        for line in fin:
            tokens = line.split()
            dic[int(tokens[1])] = tokens[0]

    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')

    mdic = torch.load(args.model_dic)
    model = Seq2Seq(**mdic['params']).to(device)
    model.load_state_dict(mdic['state'])
    model.eval()

    reader = ScpStreamReader(args.source,
                             mean_sub=args.mean_sub,
                             sort_src=True,
                             downsample=args.downsample)
    reader.initialize()

    space, beam_size, max_len = args.space, args.beam_size, args.max_len
    win = args.win

    with torch.no_grad():
        while True:
            #src_seq, src_mask, tgt_seq = reader.next_batch(1)
            src_seq, src_mask, utts = reader.read_batch_utt(1)
            if len(utts) == 0: break
            utt = utts[0]

            fout = open(utt + '.info', 'w')
            time_len = src_seq.size(1)
Exemplo n.º 5
0
#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2019 Thai-Son Nguyen
# Licensed under the Apache License, Version 2.0 (the "License")

import os
import argparse

from pynn.io.kaldi_seq import ScpStreamReader

# Command-line interface of the script: one mandatory input (--data-scp)
# and one optional output path (--output, defaulting to 'data.len').
parser = argparse.ArgumentParser(description='pynn')
parser.add_argument('--data-scp', required=True, help='path to data scp')
parser.add_argument(
    '--output', type=str, default='data.len', help='output file')

if __name__ == '__main__':
    # Script entry point: iterate over every utterance provided by the
    # reader built from --data-scp and write one line per utterance to
    # --output, formatted as "<utt_id> <len(utt_mat)>".
    args = parser.parse_args()

    loader = ScpStreamReader(args.data_scp)
    loader.initialize()

    # The context manager guarantees the output file is flushed and closed
    # even if reading an utterance or writing a line raises.
    with open(args.output, 'w') as fout:
        while True:
            utt_id, utt_mat = loader.read_next_utt()
            # The reader signals the end of the data with a None/empty id.
            if utt_id is None or utt_id == '':
                break
            # Write '\n' rather than os.linesep: the file is opened in text
            # mode, which already translates '\n' to the platform line
            # ending (os.linesep would become '\r\r\n' on Windows).
            fout.write(utt_id + ' ' + str(len(utt_mat)) + '\n')
Exemplo n.º 6
0
    tr_reader = ScpReader(args.train_scp,
                          args.train_target,
                          downsample=args.downsample,
                          max_len=args.max_len,
                          max_utt=args.max_utt,
                          shuffle=args.shuffle,
                          mean_sub=args.mean_sub,
                          fp16=args.fp16,
                          spec_drop=args.spec_drop,
                          spec_bar=args.spec_bar,
                          time_stretch=args.time_stretch,
                          time_win=args.time_win)
    cv_reader = ScpStreamReader(args.valid_scp,
                                args.valid_target,
                                downsample=args.downsample,
                                mean_sub=args.mean_sub,
                                max_len=args.max_len,
                                max_utt=args.max_utt,
                                fp16=args.fp16)

    cfg = {
        'model_path': args.model_path,
        'lr': args.lr,
        'label_smooth': args.label_smooth,
        'weight_decay': args.weight_decay,
        'teacher_force': args.teacher_force,
        'n_warmup': args.n_warmup,
        'n_const': args.n_const,
        'b_input': args.b_input,
        'b_update': args.b_update,
        'n_print': args.n_print