Exemplo n.º 1
0
def main():
    """Train a GatedConv model configured by the module-level ``args``.

    Steps, in order:
      1. print the parsed arguments;
      2. ensure ``args.save_model_path`` exists;
      3. load the vocabulary literal from ``args.vocab_path`` and join it into
         a single string;
      4. build a ``GatedConv`` from that vocabulary, or replace it with the
         object loaded from ``args.restore_model`` when that is set;
      5. move the model to the GPU and run ``train`` with the paths and
         hyper-parameters taken from ``args``.
    """
    import ast

    print_arguments(args)
    # Create the folder that will hold saved models. exist_ok=True avoids the
    # check-then-create race of os.path.exists() followed by os.makedirs().
    os.makedirs(args.save_model_path, exist_ok=True)
    # Load the vocabulary. The file holds a Python literal (presumably a
    # list/tuple of label strings, since it is joined with "".join below).
    # ast.literal_eval parses literals only, so unlike eval() it cannot
    # execute code smuggled into the vocabulary file.
    with open(args.vocab_path, 'r', encoding='utf-8') as f:
        vocabulary = ast.literal_eval(f.read())
    vocabulary = "".join(vocabulary)
    model = GatedConv(vocabulary)
    # Optionally replace the fresh model with a previously saved one.
    if args.restore_model:
        # NOTE(review): torch.load unpickles the whole model object, so only
        # point restore_model at trusted checkpoints. Since PyTorch 2.6,
        # torch.load defaults to weights_only=True, which rejects a pickled
        # full model; confirm the torch version in use.
        model = torch.load(args.restore_model)
    model = model.cuda()
    train(model=model,
          train_manifest_path=args.train_manifest_path,
          dev_manifest_path=args.dev_manifest_path,
          vocab_path=args.vocab_path,
          epochs=args.epochs,
          batch_size=args.batch_size,
          learning_rate=args.learning_rate)
Exemplo n.º 2
0
    # Evaluation loop: decode every batch of `dataloader` with `model` and
    # return the sum of the per-sample CERs, each normalized by its reference
    # length, divided by the number of items in the dataset. The enclosing
    # `def` line is not part of this excerpt.
    decoder = GreedyDecoder(dataloader.dataset.labels_str)
    cer = 0
    print("decoding")
    # Gradients are not needed while evaluating.
    with torch.no_grad():
        for i, (x, y, x_lens, y_lens) in tqdm(enumerate(dataloader)):
            # Only the features are moved to the GPU; x_lens is passed as the
            # loader produced it. NOTE(review): confirm the model accepts that.
            x = x.cuda()
            outs, out_lens = model(x, x_lens)
            # Softmax over dim 1, then swap dims 1 and 2 before decoding.
            # Presumably outs is (batch, classes, time) and the decoder expects
            # (batch, time, classes) -- TODO confirm against GatedConv/GreedyDecoder.
            outs = F.softmax(outs, 1)
            outs = outs.transpose(1, 2)
            # `y` holds the batch's targets back to back; cut it into one
            # consecutive slice per sample using the lengths in y_lens.
            ys = []
            offset = 0
            for y_len in y_lens:
                ys.append(y[offset:offset + y_len])
                offset += y_len
            out_strings, out_offsets = decoder.decode(outs, out_lens)
            y_strings = decoder.convert_to_strings(ys)
            for pred, truth in zip(out_strings, y_strings):
                # Each entry appears to be a list of candidate strings; only the
                # first one is scored (assumed to be the best hypothesis -- verify).
                trans, ref = pred[0], truth[0]
                # NOTE(review): raises ZeroDivisionError if a reference
                # transcript is empty.
                cer += decoder.cer(trans, ref) / float(len(ref))
        # Divide by the dataset length to get the mean per-sample CER.
        # NOTE(review): raises ZeroDivisionError on an empty dataset.
        cer /= len(dataloader.dataset)
    # Put the model back into training mode before returning.
    model.train()
    return cer


if __name__ == "__main__":
    # Script entry point: build the label alphabet, construct the model,
    # move it to the GPU and train it with train()'s default settings.
    # The stored labels are joined into one string before use.
    vocabulary = "".join(joblib.load(LABEL_PATH))
    model = GatedConv(vocabulary)
    # The return value is ignored, so this relies on .cuda() acting in place
    # (true for torch.nn.Module; presumably GatedConv is one -- confirm).
    model.cuda()
    train(model)