Пример #1
0
 def forward(self, inputs, supervision):
     """Run the wrapped model on one batch and return its padded chain loss.

     Args:
         inputs: pair ``(mfcc, ivec)`` of input tensors; both are moved to
             the device of this module's first parameter before the forward
             pass.
         supervision: chain supervision for the batch; its ``n_pdf`` is used
             as the expected size of the prediction's second dimension, and
             it is passed unchanged to ``chain_loss``.

     Returns:
         The result of ``pad_chain_loss(loss, results, n_batch)``.

     Raises:
         AssertionError: if the LF-MMI prediction shape differs from
             ``(n_batch, n_pdf, n_time_out)`` (only when asserts are enabled).
     """
     mfcc, ivec = inputs
     # FIXME: supervisions are distributed over devices; multi-host is not
     # supported yet, so a single per-device supervision is expected here.
     target_device = next(self.parameters()).device
     n_pdf = supervision.n_pdf
     # Assumed layout (batch, freq, time); the middle axis is not needed.
     n_batch, _, frames_in = mfcc.shape
     # FIXME: do not hardcode. Looks like the output length of a 29-frame
     # window with a time stride of 3 (inferred from the formula only;
     # TODO confirm against self.model).
     frames_out = (frames_in - 29) // 3 + 1
     lf_mmi_pred, xe_pred = self.model(
         mfcc.to(target_device), ivec.to(target_device))
     expected_shape = (n_batch, n_pdf, frames_out)
     assert lf_mmi_pred.shape == expected_shape, "{} != {}".format(
         lf_mmi_pred.shape, expected_shape)
     # Point the Kaldi backend at the model's device index (torch reports
     # an index of None for a CPU device).
     io.set_kaldi_device(target_device.index)
     loss, results = chain_loss(
         lf_mmi_pred, self.den_graph, supervision,
         l2_regularize=self.args.l2_regularize,
         leaky_hmm_coefficient=self.args.leaky_hmm_coefficient,
         xent_regularize=self.args.xent_regularize,
         xent_input=xe_pred,
         kaldi_way=True,
     )
     return pad_chain_loss(loss, results, n_batch)
Пример #2
0
 def forward(data, idx):
     """Compute the chain (LF-MMI) loss for one minibatch on the GPU.

     Reads ``model``, ``den_graph``, ``args`` and ``n_pdf`` from the
     enclosing scope (not visible here).

     Args:
         data: nested pair ``((mfcc, ivec), supervision)``.
         idx: not used by this function; presumably required by the
             caller's callback signature — TODO confirm.

     Returns:
         Whatever ``chain_loss`` returns for this batch.
     """
     features, supervision = data
     mfcc, ivec = features
     # Assumed layout (batch, freq, time); the middle axis is not needed.
     batch_size, _, frames_in = mfcc.shape
     # FIXME: do not hardcode. Looks like a 29-frame window with a time
     # stride of 3 (inferred from the formula only; TODO confirm).
     frames_out = (frames_in - 29) // 3 + 1
     lf_mmi_pred, xe_pred = model(mfcc.cuda(), ivec.cuda())
     expected_shape = (batch_size, n_pdf, frames_out)
     assert lf_mmi_pred.shape == expected_shape, "{} != {}".format(lf_mmi_pred.shape, expected_shape)
     return chain_loss(
         lf_mmi_pred, den_graph, supervision,
         l2_regularize=args.l2_regularize,
         leaky_hmm_coefficient=args.leaky_hmm_coefficient,
         xent_regularize=args.xent_regularize,
         xent_input=xe_pred,
         kaldi_way=True,
     )
Пример #3
0
def test_io():
    """Smoke-test chain-example I/O and a few LF-MMI training steps.

    The first pass reads one merged minibatch and checks that the
    supervision frame count equals the second dimension of
    ``example.indexes``. The second pass builds the denominator graph and a
    ``Model``, then runs up to 11 SGD steps with ``chain_loss``, printing the
    per-step results and the elapsed time.

    Requires the hard-coded experiment directory and egs archive below, the
    Kaldi tools invoked by ``cmd``, and a CUDA device.
    """
    exp_root = "/data/work70/skarita/exp/chime5/kaldi-22fbdd/egs/chime5/s5/"
    den_fst_rs = exp_root + "exp/chain_train_worn_u100k_cleaned/tdnn1a_sp/den.fst"
    # Pipeline: copy the egs, shuffle them, then merge them into minibatches.
    cmd = "nnet3-chain-copy-egs --frame-shift=1  ark:/data/work49/skarita/repos/torchain/cegs.1.ark ark:- | nnet3-chain-shuffle-egs --buffer-size=5000 --srand=0 ark:- ark:- | nnet3-chain-merge-egs --minibatch-size=128,64,32 ark:- ark:-"
    with io.open_example(cmd) as example:
        idx = example.indexes
        print(idx.shape)
        print(idx[0])
        # Only the features are inspected here; the supervision is re-read
        # from example.supervision just below.
        (mfcc, _), _ = example.value()
        print(mfcc.shape)
        sup = example.supervision
        assert sup.n_frame == idx.shape[1]

    for use_xent in [True]:
        for use_kaldi_way in [True]:
            print("xent: ", use_xent, "kaldi: ", use_kaldi_way)
            with io.open_example(cmd) as example:
                n_pdf = example.supervision.n_pdf
                print(n_pdf)
                # n_pdf = 2928
                den_graph = io.DenominatorGraph(den_fst_rs, n_pdf)
                model = Model(n_pdf)
                model.cuda()
                print(model)
                opt = torch.optim.SGD(model.parameters(), lr=1e-6)
                count = 0
                # perf_counter is monotonic, so the measured duration cannot
                # be skewed by system clock adjustments (unlike time.time()).
                start = time.perf_counter()
                for (mfcc, ivec), supervision in example:
                    x = Variable(mfcc).cuda()
                    print("input:", x.shape)
                    pred, xent = model(x)
                    if not use_xent:
                        # Pass no cross-entropy input to chain_loss.
                        xent = None
                    loss, results = chain_loss(pred,
                                               den_graph,
                                               supervision,
                                               l2_regularize=0.01,
                                               xent_regularize=0.01,
                                               xent_input=xent,
                                               kaldi_way=use_kaldi_way)
                    opt.zero_grad()
                    loss.backward()
                    opt.step()
                    print(count, results)
                    count += 1
                    # Stop after 11 optimizer steps.
                    if count > 10:
                        break
                elapsed = time.perf_counter() - start
                print("took: %f" % elapsed)
Пример #4
0
def test_rand_io():
    """Smoke-test random-access example loading and a few LF-MMI steps.

    Loads examples through ``io.RandExample`` from a hard-coded scp list,
    builds the denominator graph and a ``Model``, then runs up to 11 SGD
    steps with ``chain_loss``, printing the per-step results and the
    elapsed time.

    Requires the hard-coded scp file and experiment directory below and a
    CUDA device.
    """
    # Alternative scp list:
    # scp_path = "/data/work49/skarita/repos/torch-backup/100.scp" # example/chime5/exp/scp/egs.scp"
    scp_path = "/data/work49/skarita/repos/torch-backup/example/chime5/exp/scp/train.scp"
    seed = 1
    exp_root = "/data/work70/skarita/exp/chime5/kaldi-22fbdd/egs/chime5/s5/"
    den_fst_rs = exp_root + "exp/chain_train_worn_u100k_cleaned/tdnn1a_sp/den.fst"

    for use_xent in [True]:
        for use_kaldi_way in [True]:
            print("xent: ", use_xent, "kaldi: ", use_kaldi_way)
            print("shuffling")
            # Called without an argument; presumably selects a default
            # device — TODO confirm in io.set_kaldi_device.
            io.set_kaldi_device()
            # The third argument is presumably the minibatch size — TODO confirm.
            example = io.RandExample(scp_path, seed, 128)
            n_pdf = example.supervision.n_pdf
            print(n_pdf)
            # n_pdf = 2928
            den_graph = io.DenominatorGraph(den_fst_rs, n_pdf)
            model = Model(n_pdf)
            model.cuda()
            print(model)
            opt = torch.optim.SGD(model.parameters(), lr=1e-6)
            count = 0
            # perf_counter is monotonic, so the measured duration cannot be
            # skewed by system clock adjustments (unlike time.time()).
            start = time.perf_counter()
            for (mfcc, ivec), supervision in example:
                x = Variable(mfcc).cuda()
                print("input:", x.shape)
                pred, xent = model(x)
                if not use_xent:
                    # Pass no cross-entropy input to chain_loss.
                    xent = None
                loss, results = chain_loss(pred,
                                           den_graph,
                                           supervision,
                                           l2_regularize=0.01,
                                           xent_regularize=0.01,
                                           xent_input=xent,
                                           kaldi_way=use_kaldi_way)
                opt.zero_grad()
                loss.backward()
                opt.step()
                print(count, results)
                count += 1
                # Stop after 11 optimizer steps.
                if count > 10:
                    break
            elapsed = time.perf_counter() - start
            print("took: %f" % elapsed)