Example #1
    def __init_model(self):
        model = Transducer(self.config.model)

        # Restore the three RNN-T sub-networks from the checkpoint
        checkpoint = torch.load(self.config.training.load_model)
        model.encoder.load_state_dict(checkpoint['encoder'])
        model.decoder.load_state_dict(checkpoint['decoder'])
        model.joint.load_state_dict(checkpoint['joint'])

        model = model.cuda()
        model.eval()
        print("Model loaded")
        self.model = model
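
A note on portability: torch.load here assumes the current machine matches the device the checkpoint was saved from. On a CPU-only machine the same keys can be restored with map_location; a minimal sketch (load_transducer_weights is a hypothetical helper, not part of the repo):

import torch

def load_transducer_weights(model, path, device='cpu'):
    # map_location lets a GPU-saved checkpoint be restored on any device
    checkpoint = torch.load(path, map_location=device)
    model.encoder.load_state_dict(checkpoint['encoder'])
    model.decoder.load_state_dict(checkpoint['decoder'])
    model.joint.load_state_dict(checkpoint['joint'])
    return model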
Example #2
def init_model():
    # Load the YAML config and build the transducer network
    with open("config/joint_streaming.yaml", encoding='utf-8') as config_file:
        config = AttrDict(yaml.load(config_file, Loader=yaml.FullLoader))
    model = Transducer(config.model)

    # Restore the three RNN-T sub-networks from the checkpoint
    checkpoint = torch.load(config.training.load_model)
    model.encoder.load_state_dict(checkpoint['encoder'])
    model.decoder.load_state_dict(checkpoint['decoder'])
    model.joint.load_state_dict(checkpoint['joint'])
    model = model.cuda()
    model.eval()
    print("Model loaded")

    # Build an index -> token lookup from the vocabulary file
    vocab = {}
    with open(config.data.vocab, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            word = parts[0]
            index = int(parts[1])
            vocab[index] = word
    print("Vocabulary loaded")

    return model, vocab
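
A hedged usage sketch: init_model() returns the network plus an index-to-token map, so label IDs coming out of a decoder can be turned back into text (pred_ids below is a hypothetical decoding result, not produced by this snippet):

model, vocab = init_model()

# pred_ids stands in for label indices produced by RNN-T decoding
pred_ids = [12, 7, 384]
text = ''.join(vocab[i] for i in pred_ids if i in vocab)
print(text)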
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, default='config/thchs30.yaml')
    parser.add_argument('-log', type=str, default='train.log')
    parser.add_argument('-mode', type=str, default='retrain')
    opt = parser.parse_args()

    with open(opt.config, encoding='utf-8') as configfile:
        config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader))

    exp_name = os.path.join('egs', config.data.name, 'exp', config.training.save_model)
    if not os.path.isdir(exp_name):
        os.makedirs(exp_name)
    logger = init_logger(os.path.join(exp_name, opt.log))

    shutil.copyfile(opt.config, os.path.join(exp_name, 'config.yaml'))
    logger.info('Save config info.')

    num_workers = config.training.num_gpu * 2
    train_dataset = AudioDataset(config.data, 'train')
    training_data = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.data.batch_size * config.training.num_gpu,
        shuffle=config.data.shuffle, num_workers=num_workers)
    logger.info('Load Train Set!')

    dev_dataset = AudioDataset(config.data, 'dev')
    validate_data = torch.utils.data.DataLoader(
        dev_dataset, batch_size=config.data.batch_size * config.training.num_gpu,
        shuffle=False, num_workers=num_workers)
    logger.info('Load Dev Set!')

    torch.manual_seed(config.training.seed)
    if config.training.num_gpu > 0:
        torch.cuda.manual_seed(config.training.seed)
        torch.backends.cudnn.deterministic = True
    logger.info('Set random seed: %d' % config.training.seed)

    model = Transducer(config.model)

    if config.training.load_model:
        checkpoint = torch.load(config.training.load_model)
        model.encoder.load_state_dict(checkpoint['encoder'])
        model.decoder.load_state_dict(checkpoint['decoder'])
        model.joint.load_state_dict(checkpoint['joint'])
        logger.info('Loaded model from %s' % config.training.load_model)
    elif config.training.load_encoder or config.training.load_decoder:
        if config.training.load_encoder:
            checkpoint = torch.load(config.training.load_encoder)
            model.encoder.load_state_dict(checkpoint['encoder'])
            logger.info('Loaded encoder from %s' %
                        config.training.load_encoder)
        if config.training.load_decoder:
            checkpoint = torch.load(config.training.load_decoder)
            model.decoder.load_state_dict(checkpoint['decoder'])
            logger.info('Loaded decoder from %s' %
                        config.training.load_decoder)

    if config.training.num_gpu > 0:
        model = model.cuda()
        if config.training.num_gpu > 1:
            device_ids = list(range(config.training.num_gpu))
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        logger.info('Loaded the model to %d GPUs' % config.training.num_gpu)

    n_params, enc, dec = count_parameters(model)
    logger.info('# the number of parameters in the whole model: %d' % n_params)
    logger.info('# the number of parameters in the Encoder: %d' % enc)
    logger.info('# the number of parameters in the Decoder: %d' % dec)
    logger.info('# the number of parameters in the JointNet: %d' %
                (n_params - dec - enc))

    optimizer = Optimizer(model.parameters(), config.optim)
    logger.info('Created a %s optimizer.' % config.optim.type)

    if opt.mode == 'continue':
        # Resuming assumes config.training.load_model was set above, so
        # `checkpoint` holds the optimizer state and the last epoch.
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        logger.info('Load Optimizer State!')
    else:
        start_epoch = 0

    # create a visualizer
    if config.training.visualization:
        visualizer = SummaryWriter(os.path.join(exp_name, 'log'))
        logger.info('Created a visualizer.')
    else:
        visualizer = None

    for epoch in range(start_epoch, config.training.epochs):

        train(epoch, config, model, training_data,
              optimizer, logger, visualizer)

        if config.training.eval_or_not:
            _ = eval(epoch, config, model, validate_data, logger, visualizer)

        save_name = os.path.join(exp_name, '%s.epoch%d.chkpt' % (config.training.save_model, epoch))
        save_model(model, optimizer, config, save_name)
        logger.info('Epoch %d model has been saved.' % epoch)

        if epoch >= config.optim.begin_to_adjust_lr:
            optimizer.decay_lr()
            # early stop
            if optimizer.lr < 1e-6:
                logger.info('The learning rate is too low to train.')
                break
            logger.info('Epoch %d update learning rate: %.6f' %
                        (epoch, optimizer.lr))

    logger.info('The training process is OVER!')
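
save_model itself is not shown in this example, but the loading and 'continue' paths above pin down the checkpoint keys it must write: 'encoder', 'decoder', 'joint', 'optimizer', and 'epoch'. A minimal sketch consistent with those keys (the repo's real function is called as save_model(model, optimizer, config, save_name), so its exact signature differs):

def save_model_sketch(model, optimizer, epoch, save_name):
    # Unwrap DataParallel so the state dicts carry no 'module.' prefix
    net = model.module if hasattr(model, 'module') else model
    torch.save({
        'encoder': net.encoder.state_dict(),
        'decoder': net.decoder.state_dict(),
        'joint': net.joint.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }, save_name)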
Example #4
# @Contact : [email protected]
# @Time    : 2020/8/5 3:11 PM
import os
import yaml
import torch
from tt.model import Transducer
from tt.utils import AttrDict, read_wave_from_file, get_feature, concat_frame, subsampling, context_mask

os.chdir('../')

WAVE_OUTPUT_FILENAME = 'audio/5_1812_20170628135834.wav'

# Load the model
with open("config/joint_streaming.yaml", encoding='utf-8') as config_file:
    config = AttrDict(yaml.load(config_file, Loader=yaml.FullLoader))
model = Transducer(config.model)

checkpoint = torch.load(config.training.load_model)
model.encoder.load_state_dict(checkpoint['encoder'])
model.decoder.load_state_dict(checkpoint['decoder'])
model.joint.load_state_dict(checkpoint['joint'])
print('Model loaded')
model.eval()

# Extract audio features
audio, fr = read_wave_from_file(WAVE_OUTPUT_FILENAME)
feature = get_feature(audio, fr)
feature = concat_frame(feature, 3, 0)   # add context frames (3 left, 0 right)
feature = subsampling(feature, 3)       # temporal subsampling by a factor of 3
feature = torch.from_numpy(feature)
feature = torch.unsqueeze(feature, 0)   # add a batch dimension
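
read_wave_from_file, get_feature, concat_frame, and subsampling are project helpers. Assuming read_wave_from_file returns (samples, frame_rate) for 16-bit mono WAV input, a standard-library stand-in could look like:

import wave
import numpy as np

def read_wave_from_file_sketch(path):
    # Return (samples, frame_rate) for a 16-bit mono WAV file
    with wave.open(path, 'rb') as wav:
        frame_rate = wav.getframerate()
        frames = wav.readframes(wav.getnframes())
    samples = np.frombuffer(frames, dtype=np.int16)
    return samples, frame_rate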
Example #5
def eval():
    """Decode the validation set and report the character error rate (CER).

    Arguments are read from the command line via get_parser(); the
    function itself takes no parameters.
    """
    parser = get_parser()
    args = parser.parse_args()

    # load yaml config file
    with open(args.config, encoding="utf-8") as configfile:
        config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader))

    # read json data
    with open(config.data.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    # read idim and odim from json data
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["input"][0]["shape"][-1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    # model, train_args = load_trained_model(args.model)
    from tt.model import Transducer
    model = Transducer(idim, odim, args)
    checkpoint = torch.load(args.model)
    model.encoder.load_state_dict(checkpoint['encoder'])
    model.decoder.load_state_dict(checkpoint['decoder'])
    model.recog_args = args

    logging.info(
        " Total parameter of the model = "
        + str(sum(p.numel() for p in model.parameters()))
    )

    rnnlm = None

    if config.data.vocab is not None:
        with open(config.data.vocab, "rb") as f:
            dictionary = f.readlines()
        char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
        char_list.insert(0, "<blank>")
        char_list.append("<eos>")
        args.char_list = char_list

    # gpu
    # gpu
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info("gpu id: " + str(gpu_id))
        model.cuda()
        if rnnlm:
            rnnlm.cuda()

    new_js = {}

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_args={"train": False},
    )
    model.eval()  # disable dropout / batch-norm updates for decoding
    avg_cer = []
    with torch.no_grad():  # no gradients needed for inference
        for idx, name in enumerate(valid_json.keys(), 1):
            logging.info("(%d/%d) decoding " + name, idx, len(valid_json.keys()))
            batch = [(name, valid_json[name])]
            feat = load_inputs_and_targets(batch)
            feat = (
                feat[0][0]
            )
            nbest_hyps = model.recognize(
                    feat, args, args.char_list, rnnlm
                )
            new_js[name] = add_results_to_json(
                    valid_json[name], nbest_hyps, args.char_list
                )

            hyp_chars = new_js[name]['output'][0]['rec_text'].replace(" ", "")
            ref_chars = valid_json[name]['output'][0]['text'].replace(" ", "")
            char_eds = editdistance.eval(hyp_chars, ref_chars)

            cer = float(char_eds / len(ref_chars)) * 100
            avg_cer.append(cer)
            logging.info("{} cer: {}".format(name, cer))

    logging.info('avg_cer:{}'.format(np.mean(avg_cer)))

    with open('result.txt', "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
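
The CER reported above is plain character-level edit distance over reference length. A quick self-contained check of that formula, using the same editdistance call (toy strings, not from the dataset):

import editdistance

hyp = '今天天气很好'  # hypothesis
ref = '今天天气真好'  # reference
cer = editdistance.eval(hyp, ref) / len(ref) * 100
print('cer: %.2f%%' % cer)  # 1 substitution over 6 chars -> 16.67%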