Beispiel #1
0
def _require_file(path, description):
    """Exit with a non-zero status unless ``path`` is an existing file.

    Args:
        path: file system path to check.
        description: short label used in the error message
            (e.g. ``"HCLG file"``).
    """
    if not os.path.isfile(path):
        sys.stderr.write('ERROR: The %s %s does not exist!\n' %
                         (description, path))
        # Non-zero status so the job launcher sees the failure.
        sys.exit(1)


def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed.

    Keys that do not start with ``prefix`` are kept unchanged, so the result
    can be loaded whether or not the checkpoint was saved from a wrapped
    (``module.``-prefixed) model.
    """
    from collections import OrderedDict
    stripped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


def main():
    """Run Horovod-distributed sequence training of the LSTM acoustic model.

    Parses the command line, builds the data loader, model, optimizer and
    lattice decoder, then calls ``run_train_epoch`` for ``-num_epochs``
    epochs. Rank 0 writes a checkpoint ``model.se.<epoch>.tar`` into
    ``-exp_dir`` after every epoch.

    The process exits with status 1 when the seed model, the decoding graph
    files, the silence phone list or the transition model is missing.
    """
    parser = argparse.ArgumentParser()
    # Arguments marked required are used unconditionally below; without them
    # the script would fail later with an unhelpful TypeError.
    parser.add_argument("-config", required=True)
    parser.add_argument("-data", required=True, help="data yaml file")
    parser.add_argument("-data_path",
                        default='',
                        type=str,
                        help="path of data files")
    parser.add_argument("-seed_model",
                        required=True,
                        help="the seed nerual network model")
    parser.add_argument("-exp_dir",
                        required=True,
                        help="the directory to save the outputs")
    parser.add_argument("-transform",
                        help="feature transformation matrix or mvn statistics")
    parser.add_argument("-criterion",
                        type=str,
                        choices=["mmi", "mpfe", "smbr"],
                        help="set the sequence training crtierion")
    parser.add_argument(
        "-trans_model",
        required=True,
        help="the HMM transistion model, used for lattice generation")
    parser.add_argument(
        "-prior_path",
        required=True,
        help="the prior for decoder, usually named as final.occs in kaldi setup"
    )
    parser.add_argument(
        "-den_dir",
        required=True,
        help="the decoding graph directory to find HCLG and words.txt files")
    parser.add_argument("-lr",
                        type=float,
                        required=True,
                        help="set the learning rate")
    parser.add_argument("-ce_ratio",
                        default=0.1,
                        type=float,
                        help="the ratio for ce regularization")
    parser.add_argument("-momentum",
                        default=0,
                        type=float,
                        help="set the momentum")
    parser.add_argument("-batch_size",
                        default=32,
                        type=int,
                        help="Override the batch size in the config")
    parser.add_argument("-data_loader_threads",
                        default=0,
                        type=int,
                        help="number of workers for data loading")
    parser.add_argument("-max_grad_norm",
                        default=5,
                        type=float,
                        help="max_grad_norm for gradient clipping")
    parser.add_argument("-sweep_size",
                        default=100,
                        type=float,
                        help="process n hours of data per sweep (default:100)")
    parser.add_argument("-num_epochs",
                        default=1,
                        type=int,
                        help="number of training epochs (default:1)")
    parser.add_argument('-print_freq',
                        default=10,
                        type=int,
                        metavar='N',
                        help='print frequency (default: 10)')
    parser.add_argument('-save_freq',
                        default=1000,
                        type=int,
                        metavar='N',
                        help='save model frequency (default: 1000)')

    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    # Command-line values override / extend the yaml configuration.
    config['data_path'] = args.data_path
    config["sweep_size"] = args.sweep_size

    print("pytorch version:{}".format(th.__version__))

    with open(args.data) as f:
        data = yaml.safe_load(f)
    config["source_paths"] = list(data['clean_source'].values())

    print("Experiment starts with config {}".format(
        json.dumps(config, sort_keys=True, indent=4)))

    # Initialize Horovod and pin this process to its local GPU.
    hvd.init()
    th.cuda.set_device(hvd.local_rank())

    print("Run experiments with world size {}".format(hvd.size()))

    dataset = SpeechDataset(config)
    if args.transform is not None and os.path.isfile(args.transform):
        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; only pass transform files from a trusted source.
        with open(args.transform, 'rb') as f:
            dataset.transform = pickle.load(f)

    train_dataloader = SeqDataloader(dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.data_loader_threads,
                                     distributed=True,
                                     test_only=False)

    print("Data loader set up successfully!")
    print("Number of minibatches: {}".format(len(train_dataloader)))

    # exist_ok avoids a race when several ranks create the directory at once.
    os.makedirs(args.exp_dir, exist_ok=True)

    # create model
    model_config = config["model_config"]
    lstm = LSTMStack(model_config["feat_dim"], model_config["hidden_size"],
                     model_config["num_layers"], model_config["dropout"], True)
    model = NnetAM(lstm, model_config["hidden_size"] * 2,
                   model_config["label_size"])

    model.cuda()

    # setup the optimizer
    optimizer = th.optim.SGD(model.parameters(),
                             lr=args.lr,
                             momentum=args.momentum)

    # Broadcast parameters and optimizer state from rank 0 to all other processes.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Add Horovod Distributed Optimizer
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())

    _require_file(args.seed_model, "model file")
    # Deserialize on the CPU so that every rank does not allocate the
    # checkpoint on the GPU it was saved from; load_state_dict copies the
    # values into the parameters already placed on this rank's device.
    checkpoint = th.load(args.seed_model, map_location="cpu")
    model.load_state_dict(_strip_module_prefix(checkpoint['model']))
    print("=> loaded checkpoint '{}' ".format(args.seed_model))

    HCLG = os.path.join(args.den_dir, "HCLG.fst")
    words_txt = os.path.join(args.den_dir, "words.txt")
    silence_phones = os.path.join(args.den_dir, "phones", "silence.csl")

    _require_file(HCLG, "HCLG file")
    _require_file(words_txt, "words.txt file")
    _require_file(silence_phones, "silence phone file")

    # The first line of the file holds colon-separated integer ids.
    with open(silence_phones) as f:
        silence_ids = [int(i) for i in f.readline().strip().split(':')]

    _require_file(args.trans_model, "trans_model")
    trans_model = kaldi_hmm.TransitionModel()
    with kaldi_util.io.xopen(args.trans_model) as ki:
        trans_model.read(ki.stream(), ki.binary)

    # now we can setup the decoder
    decoder_config = config["decoder_config"]
    decoder_opts = LatticeFasterDecoderOptions()
    decoder_opts.beam = decoder_config["beam"]
    decoder_opts.lattice_beam = decoder_config["lattice_beam"]
    decoder_opts.max_active = decoder_config["max_active"]
    # Produce the raw state-level lattice instead of a compact lattice.
    decoder_opts.determinize_lattice = False
    asr_decoder = MappedLatticeFasterRecognizer.from_files(
        args.trans_model,
        HCLG,
        words_txt,
        acoustic_scale=decoder_config["acoustic_scale"],
        decoder_opts=decoder_opts)

    # Normalize the first row of the prior matrix and take its log.
    prior = kaldi_util.io.read_matrix(args.prior_path).numpy()
    log_prior = th.tensor(np.log(prior[0] / np.sum(prior[0])), dtype=th.float)
    # Move to the GPU once instead of once per epoch.
    log_prior = log_prior.cuda()

    model.train()
    for epoch in range(args.num_epochs):

        run_train_epoch(model, optimizer, log_prior, train_dataloader,
                        epoch, asr_decoder, trans_model, silence_ids, args)

        # Only rank 0 saves, so the ranks do not overwrite each other's file.
        if hvd.rank() == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            output_file = os.path.join(args.exp_dir,
                                       'model.se.' + str(epoch) + '.tar')
            th.save(checkpoint, output_file)
Beispiel #2
0
def _require_file(path, description):
    """Exit with a non-zero status unless ``path`` is an existing file.

    Args:
        path: file system path to check.
        description: short label used in the error message
            (e.g. ``"HCLG file"``).
    """
    if not os.path.isfile(path):
        sys.stderr.write('ERROR: The %s %s does not exist!\n' %
                         (description, path))
        # Non-zero status so calling scripts see the failure.
        sys.exit(1)


def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed.

    Each key is checked individually; keys without the prefix are kept
    unchanged, and an empty ``state_dict`` yields an empty result.
    """
    from collections import OrderedDict
    stripped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


def main():
    """Decode the evaluation data with a trained model and write lattices.

    Parses the command line, builds the data loader, model and lattice
    decoder, runs the model over every minibatch, subtracts the log prior
    from the model output and decodes each utterance. The decoded text and
    the per-frame log-likelihood are printed, and the compact lattices are
    written to the ``ark:`` table named by ``-out_file``.

    The process exits with status 1 when the model file, the decoding graph
    files or the transition model is missing.
    """
    parser = argparse.ArgumentParser()
    # Arguments marked required are used unconditionally below; without them
    # the script would fail later with an unhelpful TypeError.
    parser.add_argument("-config", required=True)
    parser.add_argument("-model_path", required=True)
    parser.add_argument("-data_path")
    parser.add_argument("-prior_path",
                        required=True,
                        help="the path to load the final.occs file")
    parser.add_argument("-out_file",
                        required=True,
                        help="write out the log-probs to this file")
    parser.add_argument("-transform",
                        help="feature transformation matrix or mvn statistics")
    parser.add_argument(
        "-trans_model",
        required=True,
        help="the HMM transistion model, used for lattice generation")
    parser.add_argument("-graph_dir",
                        required=True,
                        help="the decoding graph directory")
    parser.add_argument("-batch_size",
                        default=32,
                        type=int,
                        help="Override the batch size in the config")
    parser.add_argument("-sweep_size",
                        default=200,
                        type=float,
                        help="process n hours of data per sweep (default:200)")
    parser.add_argument("-data_loader_threads",
                        default=4,
                        type=int,
                        help="number of workers for data loading")

    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    config["sweep_size"] = args.sweep_size
    # A single "Eval" source pointing at the data given on the command line.
    config["source_paths"] = [{"type": "Eval", "wav": args.data_path}]

    print("job starts with config {}".format(
        json.dumps(config, sort_keys=True, indent=4)))

    transform = None
    if args.transform is not None and os.path.isfile(args.transform):
        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; only pass transform files from a trusted source.
        with open(args.transform, 'rb') as f:
            transform = pickle.load(f)

    dataset = SpeechDataset(config)
    test_dataloader = SeqDataloader(dataset,
                                    batch_size=args.batch_size,
                                    test_only=True,
                                    global_mvn=True,
                                    transform=transform)

    print("Data loader set up successfully!")
    print("Number of minibatches: {}".format(len(test_dataloader)))

    # create model
    model_config = config["model_config"]
    lstm = LSTMStack(model_config["feat_dim"], model_config["hidden_size"],
                     model_config["num_layers"], model_config["dropout"], True)
    model = NnetAM(lstm, model_config["hidden_size"] * 2,
                   model_config["label_size"])

    # Use the GPU when one is available and fall back to the CPU otherwise.
    device = th.device("cuda" if th.cuda.is_available() else "cpu")
    model.to(device)

    _require_file(args.model_path, "model file")
    checkpoint = th.load(args.model_path, map_location=device)
    model.load_state_dict(_strip_module_prefix(checkpoint['model']))
    print("=> loaded checkpoint '{}' ".format(args.model_path))

    HCLG = os.path.join(args.graph_dir, "HCLG.fst")
    words_txt = os.path.join(args.graph_dir, "words.txt")

    _require_file(HCLG, "HCLG file")
    _require_file(words_txt, "words.txt file")
    # Only the path is needed below: the recognizer reads the transition
    # model itself in from_files().
    _require_file(args.trans_model, "trans_model")

    # Normalize the first row of the prior matrix and take its log.
    prior = read_matrix(args.prior_path).numpy()
    log_prior = th.tensor(np.log(prior[0] / np.sum(prior[0])), dtype=th.float)

    # now we can setup the decoder
    decoder_config = config["decoder_config"]
    decoder_opts = LatticeFasterDecoderOptions()
    decoder_opts.beam = decoder_config["beam"]
    decoder_opts.lattice_beam = decoder_config["lattice_beam"]
    decoder_opts.max_active = decoder_config["max_active"]
    # Produce a compact (determinized) lattice.
    decoder_opts.determinize_lattice = True
    asr_decoder = MappedLatticeFasterRecognizer.from_files(
        args.trans_model,
        HCLG,
        words_txt,
        acoustic_scale=decoder_config["acoustic_scale"],
        decoder_opts=decoder_opts)

    model.eval()
    with th.no_grad():
        with kaldi_util.table.CompactLatticeWriter("ark:" +
                                                   args.out_file) as lat_out:
            for data in test_dataloader:
                num_frs = data["num_frs"]
                utt_ids = data["utt_ids"]

                x = data["x"].to(th.float32).to(device)
                prediction = model(x)

                for j, num_fr in enumerate(num_frs):
                    # Keep only the first num_fr frames of utterance j and
                    # subtract the log prior from the model output.
                    loglikes_j = prediction[j, :num_fr, :].detach().cpu()
                    loglikes_j = loglikes_j - log_prior

                    decoder_out = asr_decoder.decode(
                        kaldi_matrix.Matrix(loglikes_j.numpy()))

                    key = utt_ids[j][0]
                    print(key, decoder_out["text"])

                    print("Log-like per-frame for utterance {} is {}".format(
                        key, decoder_out["likelihood"] / num_fr))

                    # save lattice
                    lat_out[key] = decoder_out["lattice"]