예제 #1
0
def main(args):
    """Recognize a single audio file with a pre-trained Transformer and print it.

    Hyper-parameters come from the checkpoint's ``'params'`` entry when present,
    otherwise from the YAML file ``args.config``. Data options are overridden
    for deterministic, un-augmented decoding.

    Args:
        args: parsed command-line namespace. Fields read here: ``load_model``,
            ``config``, ``batch_size``, ``ngpu``, ``beam_width``, ``max_len``,
            ``penalty``, ``lamda`` and ``file``.

    Raises:
        FileNotFoundError: the checkpoint has no ``'params'`` entry and
            ``args.config`` does not name an existing file.
    """
    # Load onto the CPU first so a checkpoint saved from a GPU run can also be
    # opened on a CPU-only machine; the model is moved to the GPU below when
    # args.ngpu > 0.
    checkpoint = torch.load(args.load_model, map_location='cpu')
    if 'params' in checkpoint:
        params = checkpoint['params']
    else:
        # Explicit check rather than ``assert`` (asserts are stripped under -O);
        # the None test avoids a TypeError from os.path.isfile(None).
        if args.config is None or not os.path.isfile(args.config):
            raise FileNotFoundError('please specify a configure file.')
        with open(args.config, 'r', encoding='utf-8') as f:
            # yaml.load() without an explicit Loader is a TypeError on PyYAML >= 6.
            params = yaml.safe_load(f)

    data_params = params['data']
    data_params['shuffle'] = False
    # Disable SpecAugment for decoding. The key consumed by the data pipeline
    # is not visible here and 'spec_argument' looks like a misspelling of
    # 'spec_augment', so both spellings are set. TODO confirm and drop one.
    data_params['spec_augment'] = False
    data_params['spec_argument'] = False
    data_params['short_first'] = False
    data_params['batch_size'] = args.batch_size

    model = Transformer(params['model'])

    model.load_state_dict(checkpoint['model'])
    print('Load pre-trained model from %s' % args.load_model)

    model.eval()
    if args.ngpu > 0:
        model.cuda()

    # Invert the vocabulary: unit id -> character.
    char2unit = load_vocab(data_params['vocab'])
    unit2char = {i: c for c, i in char2unit.items()}

    recognizer = TransformerRecognizer(model,
                                       unit2char=unit2char,
                                       beam_width=args.beam_width,
                                       max_len=args.max_len,
                                       penalty=args.penalty,
                                       lamda=args.lamda,
                                       ngpu=args.ngpu)

    # inputs_length: [len]
    inputs, inputs_length = calc_fbank(args.file, data_params)
    if args.ngpu > 0:
        inputs = inputs.cuda()
        inputs_length = inputs_length.cuda()

    preds = recognizer.recognize(inputs, inputs_length)
    # Print the first hypothesis with its spaces removed.
    print('preds: {}'.format(preds[0].replace(' ', '')))
예제 #2
0
def main(args):
    """Decode the ``args.decode_set`` split and write predictions to disk.

    Output goes to ``egs/<data.name>/exp/<train.save_name>/decode_<set>``
    (plus ``_<suffix>`` when ``args.suffix`` is given) as ``predict.txt``, one
    ``"<utt_id> <prediction>"`` line per utterance. Each prediction and its
    reference are also printed.

    Args:
        args: parsed command-line namespace. Fields read here: ``load_model``,
            ``config``, ``batch_size``, ``suffix``, ``decode_set``, ``ngpu``,
            ``beam_width``, ``max_len``, ``penalty`` and ``lamda``.

    Raises:
        FileNotFoundError: the checkpoint has no ``'params'`` entry and
            ``args.config`` does not name an existing file.
    """
    # Load onto the CPU first so GPU-saved checkpoints also open on CPU-only
    # machines; the model is moved to the GPU below when args.ngpu > 0.
    checkpoint = torch.load(args.load_model, map_location='cpu')
    if 'params' in checkpoint:
        params = checkpoint['params']
    else:
        # Explicit check rather than ``assert`` (asserts are stripped under -O).
        if args.config is None or not os.path.isfile(args.config):
            raise FileNotFoundError('please specify a configure file.')
        with open(args.config, 'r', encoding='utf-8') as f:
            # yaml.load() without an explicit Loader is a TypeError on PyYAML >= 6.
            params = yaml.safe_load(f)

    data_params = params['data']
    data_params['shuffle'] = False
    # Disable SpecAugment for decoding. The key consumed by the data pipeline
    # is not visible here and 'spec_argument' looks like a misspelling of
    # 'spec_augment', so both spellings are set. TODO confirm and drop one.
    data_params['spec_augment'] = False
    data_params['spec_argument'] = False
    data_params['short_first'] = False
    data_params['batch_size'] = args.batch_size

    expdir = os.path.join('egs', data_params['name'], 'exp',
                          params['train']['save_name'])
    if args.suffix is None:
        decode_dir = os.path.join(expdir, 'decode_%s' % args.decode_set)
    else:
        decode_dir = os.path.join(
            expdir, 'decode_%s_%s' % (args.decode_set, args.suffix))
    os.makedirs(decode_dir, exist_ok=True)

    model = Transformer(params['model'])

    model.load_state_dict(checkpoint['model'])
    print('Load pre-trained model from %s' % args.load_model)

    model.eval()
    if args.ngpu > 0:
        model.cuda()

    # Invert the vocabulary: unit id -> character.
    char2unit = load_vocab(data_params['vocab'])
    unit2char = {i: c for c, i in char2unit.items()}

    dataset = AudioDataset(data_params, args.decode_set)
    data_loader = FeatureLoader(dataset)

    recognizer = TransformerRecognizer(model,
                                       unit2char=unit2char,
                                       beam_width=args.beam_width,
                                       max_len=args.max_len,
                                       penalty=args.penalty,
                                       lamda=args.lamda,
                                       ngpu=args.ngpu)

    totals = len(dataset)
    batch_size = data_params['batch_size']
    predict_path = os.path.join(decode_dir, 'predict.txt')
    # ``with`` closes the file even if decoding raises midway.
    with open(predict_path, 'w', encoding='utf-8') as writer:
        for step, (utt_id, batch) in enumerate(data_loader.loader):
            # Bind the inputs on every path; the GPU branch only moves them.
            inputs = batch['inputs']
            inputs_length = batch['inputs_length']
            if args.ngpu > 0:
                inputs = inputs.cuda()
                inputs_length = inputs_length.cuda()

            preds = recognizer.recognize(inputs, inputs_length)

            targets = batch['targets']
            targets_length = batch['targets_length']

            for b, pred in enumerate(preds):
                # Running utterance index across batches.
                n = step * batch_size + b
                # NOTE(review): this slice skips targets[b][0] and stops before
                # index targets_length[b]; whether it also drops the last real
                # unit depends on how the dataset counts targets_length - confirm.
                truth = ' '.join(
                    [unit2char[i.item()] for i in targets[b][1:targets_length[b]]])
                print('[%d / %d ] %s - pred : %s' %
                      (n, totals, utt_id[b], pred))
                print('[%d / %d ] %s - truth: %s' % (n, totals, utt_id[b], truth))
                writer.write(utt_id[b] + ' ' + pred + '\n')
예제 #3
0
def main(args):
    """Decode ``args.decode_set``, optionally with a Transformer LM, and write results.

    Output goes to ``egs/<data.name>/exp/<train.save_name>/<decode name>``. The
    decode name is ``decode_<set>``, plus ``_lm_lmw<weight>`` when an LM is
    loaded, plus ``_<suffix>`` when ``args.suffix`` is given. That directory
    receives two files, one line per utterance:

    * ``predict.txt``: ``"<prediction> (<speaker>-<utt_id>)"``
    * ``reference.txt``: ``"<reference> (<speaker>-<utt_id>)"``

    Args:
        args: parsed command-line namespace. Fields read here: ``load_model``,
            ``config``, ``batch_size``, ``decode_set``, ``load_language_model``,
            ``lm_weight``, ``suffix``, ``ngpu``, ``beam_width``, ``max_len``,
            ``penalty`` and ``lamda``.

    Raises:
        FileNotFoundError: the checkpoint has no ``'params'`` entry and
            ``args.config`` does not name an existing file.
    """
    # Load onto the CPU first so GPU-saved checkpoints also open on CPU-only
    # machines; the models are moved to the GPU below when args.ngpu > 0.
    checkpoint = torch.load(args.load_model, map_location='cpu')
    if 'params' in checkpoint:
        params = checkpoint['params']
    else:
        # Explicit check rather than ``assert`` (asserts are stripped under -O).
        if args.config is None or not os.path.isfile(args.config):
            raise FileNotFoundError('please specify a configure file.')
        with open(args.config, 'r', encoding='utf-8') as f:
            # yaml.load() without an explicit Loader is a TypeError on PyYAML >= 6.
            params = yaml.safe_load(f)

    data_params = params['data']
    data_params['shuffle'] = False
    data_params['spec_augment'] = False
    data_params['short_first'] = False
    data_params['batch_size'] = args.batch_size

    expdir = os.path.join('egs', data_params['name'], 'exp',
                          params['train']['save_name'])
    decoder_set_name = 'decode_%s' % args.decode_set
    if args.load_language_model is not None:
        decoder_set_name += '_lm_lmw%.2f' % args.lm_weight
    if args.suffix is not None:
        decoder_set_name += '_%s' % args.suffix

    decode_dir = os.path.join(expdir, decoder_set_name)
    os.makedirs(decode_dir, exist_ok=True)

    model = Transformer(params['model'])

    model.load_state_dict(checkpoint['model'])
    print('Load pre-trained model from %s' % args.load_model)

    model.eval()
    if args.ngpu > 0:
        model.cuda()

    lm = None
    if args.load_language_model is not None:
        lm_chkpt = torch.load(args.load_language_model, map_location='cpu')
        lm = TransformerLanguageModel(lm_chkpt['params']['model'])
        lm.load_state_dict(lm_chkpt['model'])
        lm.eval()
        if args.ngpu > 0:
            lm.cuda()
        print('Load pre-trained transformer language model from %s' %
              args.load_language_model)

    # Invert the vocabulary: unit id -> character.
    char2unit = load_vocab(data_params['vocab'])
    unit2char = {i: c for c, i in char2unit.items()}

    data_loader = FeatureLoader(params, args.decode_set, is_eval=True)

    recognizer = TransformerRecognizer(model,
                                       lm=lm,
                                       lm_weight=args.lm_weight,
                                       unit2char=unit2char,
                                       beam_width=args.beam_width,
                                       max_len=args.max_len,
                                       penalty=args.penalty,
                                       lamda=args.lamda,
                                       ngpu=args.ngpu)

    totals = len(data_loader.dataset)
    batch_size = data_params['batch_size']
    pred_path = os.path.join(decode_dir, 'predict.txt')
    ref_path = os.path.join(decode_dir, 'reference.txt')
    # ``with`` closes both files even if decoding raises midway.
    with open(pred_path, 'w', encoding='utf-8') as writer, \
            open(ref_path, 'w', encoding='utf-8') as ref_writer:
        for step, (utt_id, batch) in enumerate(data_loader.loader):
            inputs = batch['inputs']
            inputs_length = batch['inputs_length']
            if args.ngpu > 0:
                inputs = inputs.cuda()
                inputs_length = inputs_length.cuda()

            preds = recognizer.recognize(inputs, inputs_length)

            targets = batch['targets']
            targets_length = batch['targets_length']

            for b, pred in enumerate(preds):
                # Running utterance index across batches.
                n = step * batch_size + b
                truth = ' '.join([
                    unit2char[i.item()]
                    for i in targets[b][1:targets_length[b] + 1]
                ])
                print('[%d / %d ] %s - pred : %s' %
                      (n, totals, utt_id[b], pred))
                print('[%d / %d ] %s - truth: %s' % (n, totals, utt_id[b], truth))
                # Both lines share the same " (<speaker>-<utt_id>)" suffix.
                tag = ' (' + _speaker_tag(utt_id[b]) + '-' + utt_id[b] + ')'
                writer.write(pred + tag + '\n')
                ref_writer.write(truth + tag + '\n')


def _speaker_tag(utt_id):
    """Return the speaker label written before ``utt_id`` in the output files.

    The rules are tried in this order:

    * 8th character is ``'1'``: ``'S1000'``
    * 8th character is ``'2'``: ``'S2000'``
    * id starts with ``'O'``: ``'O'``
    * otherwise: characters ``[6:11]`` of the id

    Ids shorter than 8 characters skip the first two rules instead of raising
    IndexError.
    """
    eighth = utt_id[7] if len(utt_id) > 7 else None
    if eighth == '1':
        return 'S1000'
    if eighth == '2':
        return 'S2000'
    if utt_id.startswith('O'):
        return 'O'
    return utt_id[6:11]