def forward(self, inputs, supervision):
    """Run the wrapped model on one minibatch and return the padded chain loss.

    Args:
        inputs: A ``(mfcc, ivec)`` pair. ``mfcc.shape`` must unpack into three
            values, read here as ``(n_batch, n_freq, n_time_in)``.
        supervision: Chain supervision for this minibatch; its ``n_pdf`` sets
            the expected second dimension of the LF-MMI output.

    Returns:
        The value of ``pad_chain_loss(loss, results, n_batch)``.

    Raises:
        AssertionError: if the LF-MMI output shape is not
            ``(n_batch, n_pdf, n_time_out)``.
    """
    mfcc, ivec = inputs
    # Inputs are moved to whatever device this module's parameters live on.
    # FIXME supervisions are distributed over devices. support multi host
    device = next(self.parameters()).device
    kaldi_device_id = device.index
    expected_pdfs = supervision.n_pdf
    n_batch, _, n_time_in = mfcc.shape

    # Output length of a 29-frame window slid with stride 3 (presumably the
    # model's total context and frame subsampling factor -- TODO confirm).
    # For integer frame counts, ``//`` gives the same value as math.floor.
    # FIXME do not hardcode here
    n_time_out = (n_time_in - 29) // 3 + 1

    lf_mmi_pred, xe_pred = self.model(mfcc.to(device), ivec.to(device))
    expected_shape = (n_batch, expected_pdfs, n_time_out)
    assert lf_mmi_pred.shape == expected_shape, "{} != {}".format(
        lf_mmi_pred.shape, expected_shape)

    # Hand this module's device index to io.set_kaldi_device before the loss
    # call (presumably selects the device used by the Kaldi-side computation).
    io.set_kaldi_device(kaldi_device_id)
    opts = self.args
    loss, results = chain_loss(
        lf_mmi_pred,
        self.den_graph,
        supervision,
        l2_regularize=opts.l2_regularize,
        leaky_hmm_coefficient=opts.leaky_hmm_coefficient,
        xent_regularize=opts.xent_regularize,
        xent_input=xe_pred,
        kaldi_way=True,
    )
    return pad_chain_loss(loss, results, n_batch)
def forward(data, idx):
    """Compute the chain loss for one ``((mfcc, ivec), supervision)`` item.

    ``model``, ``n_pdf``, ``den_graph``, ``args`` and ``chain_loss`` are free
    names resolved from the enclosing/module scope. ``idx`` is accepted for
    interface compatibility but is not used.

    Returns:
        The result of ``chain_loss(...)`` unchanged.

    Raises:
        AssertionError: if the LF-MMI output shape is not
            ``(n_batch, n_pdf, n_time_out)``.
    """
    features, supervision = data
    mfcc, ivec = features
    n_batch, _, n_time_in = mfcc.shape
    # Output length of a 29-frame window slid with stride 3 (presumably the
    # model's context and frame subsampling factor -- TODO confirm).
    # FIXME do not hardcode here
    n_time_out = (n_time_in - 29) // 3 + 1

    lf_mmi_pred, xe_pred = model(mfcc.cuda(), ivec.cuda())
    expected_shape = (n_batch, n_pdf, n_time_out)
    assert lf_mmi_pred.shape == expected_shape, "{} != {}".format(
        lf_mmi_pred.shape, expected_shape)

    return chain_loss(
        lf_mmi_pred,
        den_graph,
        supervision,
        l2_regularize=args.l2_regularize,
        leaky_hmm_coefficient=args.leaky_hmm_coefficient,
        xent_regularize=args.xent_regularize,
        xent_input=xe_pred,
        kaldi_way=True,
    )
def test_io():
    """Smoke-test Kaldi chain-example I/O and a few SGD steps on the GPU.

    First reads one minibatch from the ``nnet3-chain-*`` shell pipeline and
    checks that ``supervision.n_frame`` matches the second dimension of the
    example indexes. Then it trains ``Model`` for at most 11 minibatches and
    prints the elapsed time. Requires the hard-coded experiment paths below
    and a CUDA device.
    """
    exp_root = "/data/work70/skarita/exp/chime5/kaldi-22fbdd/egs/chime5/s5/"
    den_fst_rs = exp_root + "exp/chain_train_worn_u100k_cleaned/tdnn1a_sp/den.fst"
    cmd = (
        "nnet3-chain-copy-egs --frame-shift=1 ark:/data/work49/skarita/repos/torchain/cegs.1.ark ark:- "
        "| nnet3-chain-shuffle-egs --buffer-size=5000 --srand=0 ark:- ark:- "
        "| nnet3-chain-merge-egs --minibatch-size=128,64,32 ark:- ark:-"
    )

    with io.open_example(cmd) as example:
        idx = example.indexes
        print(idx.shape)
        print(idx[0])
        (mfcc, _ivec), _sup = example.value()
        print(mfcc.shape)
        sup = example.supervision
        assert sup.n_frame == idx.shape[1]

    for use_xent in [True]:
        for use_kaldi_way in [True]:
            print("xent: ", use_xent, "kaldi: ", use_kaldi_way)
            with io.open_example(cmd) as example:
                n_pdf = example.supervision.n_pdf
                print(n_pdf)
                den_graph = io.DenominatorGraph(den_fst_rs, n_pdf)
                model = Model(n_pdf)
                model.cuda()
                print(model)
                opt = torch.optim.SGD(model.parameters(), lr=1e-6)
                # perf_counter is monotonic, unlike time.time().
                start = time.perf_counter()
                for count, ((mfcc, _ivec), supervision) in enumerate(example):
                    # Tensors no longer need the deprecated Variable wrapper.
                    x = mfcc.cuda()
                    print("input:", x.shape)
                    # NOTE(review): forward() elsewhere calls the model with
                    # (mfcc, ivec); here only features are passed -- confirm
                    # that this Model's signature expects that.
                    pred, xent = model(x)
                    loss, results = chain_loss(
                        pred,
                        den_graph,
                        supervision,
                        l2_regularize=0.01,
                        xent_regularize=0.01,
                        xent_input=xent if use_xent else None,
                        kaldi_way=use_kaldi_way,
                    )
                    opt.zero_grad()
                    loss.backward()
                    opt.step()
                    print(count, results)
                    # Stop after 11 minibatches (count 0..10).
                    if count >= 10:
                        break
                elapsed = time.perf_counter() - start
                print("took: %f" % (elapsed))
def test_rand_io():
    """Smoke-test shuffled ``io.RandExample`` reading plus a few SGD steps.

    Builds an ``io.RandExample`` over a hard-coded scp list with seed 1. The
    third argument is 128, presumably the minibatch size (TODO confirm). It
    then trains ``Model`` for at most 11 minibatches on the GPU and prints the
    elapsed time. Requires the hard-coded paths below and a CUDA device.
    """
    scp_path = "/data/work49/skarita/repos/torch-backup/example/chime5/exp/scp/train.scp"
    seed = 1
    exp_root = "/data/work70/skarita/exp/chime5/kaldi-22fbdd/egs/chime5/s5/"
    den_fst_rs = exp_root + "exp/chain_train_worn_u100k_cleaned/tdnn1a_sp/den.fst"

    for use_xent in [True]:
        for use_kaldi_way in [True]:
            print("xent: ", use_xent, "kaldi: ", use_kaldi_way)
            print("shuffling")
            io.set_kaldi_device()
            example = io.RandExample(scp_path, seed, 128)
            n_pdf = example.supervision.n_pdf
            print(n_pdf)
            den_graph = io.DenominatorGraph(den_fst_rs, n_pdf)
            model = Model(n_pdf)
            model.cuda()
            print(model)
            opt = torch.optim.SGD(model.parameters(), lr=1e-6)
            # perf_counter is monotonic, unlike time.time().
            start = time.perf_counter()
            for count, ((mfcc, _ivec), supervision) in enumerate(example):
                # Tensors no longer need the deprecated Variable wrapper.
                x = mfcc.cuda()
                print("input:", x.shape)
                # NOTE(review): forward() elsewhere calls the model with
                # (mfcc, ivec); here only features are passed -- confirm
                # that this Model's signature expects that.
                pred, xent = model(x)
                loss, results = chain_loss(
                    pred,
                    den_graph,
                    supervision,
                    l2_regularize=0.01,
                    xent_regularize=0.01,
                    xent_input=xent if use_xent else None,
                    kaldi_way=use_kaldi_way,
                )
                opt.zero_grad()
                loss.backward()
                opt.step()
                print(count, results)
                # Stop after 11 minibatches (count 0..10).
                if count >= 10:
                    break
            elapsed = time.perf_counter() - start
            print("took: %f" % (elapsed))