def _collate_fn_train(batch):
    """Collate (feat, length, utt_id, graph) samples into one minibatch.

    Samples are ordered by descending feature length, features are
    zero-padded to the longest sequence, and the per-utterance graphs are
    packed into a single ChainGraphBatch sized to the largest graph.

    Returns:
        (feats, feat_lengths, utt_ids, num_graphs) where feats is
        (batch, max_len, feat_dim) and feat_lengths is an int tensor.
    """
    # Longest sequence first (sample[1] is the feature length).
    batch = sorted(batch, key=lambda sample: sample[1], reverse=True)
    longest = batch[0][1]
    feat_dim = batch[0][0].size(1)
    n = len(batch)

    feats = torch.zeros(n, longest, feat_dim)
    feat_lengths = torch.zeros(n, dtype=torch.int)
    utt_ids = []
    graph_list = []

    for idx, (feat, length, utt_id, graph) in enumerate(batch):
        # Copy the utterance into the left part of its padded row.
        feats[idx, :length, :].copy_(feat)
        feat_lengths[idx] = length
        utt_ids.append(utt_id)
        graph_list.append(graph)

    # Pad the graph batch to the largest transition/state counts seen.
    num_graphs = ChainGraphBatch(
        graph_list,
        max_num_transitions=max(g.num_transitions for g in graph_list),
        max_num_states=max(g.num_states for g in graph_list),
    )
    return feats, feat_lengths, utt_ids, num_graphs
def merge(key):
    """Merge the values stored under *key* across all samples in `samples`.

    "source" -> zero-padded frame tensor; "target" -> ChainGraphBatch
    padded to the largest graph in the batch. Raises ValueError otherwise.
    """
    if key == "source":
        # Pad variable-length frame sequences with 0.0.
        return speech_utils.collate_frames([s[key] for s in samples], 0.0)
    if key == "target":
        targets = [s["target"] for s in samples]
        return ChainGraphBatch(
            targets,
            max_num_transitions=max(t.num_transitions for t in targets),
            max_num_states=max(t.num_states for t in targets),
        )
    raise ValueError("Invalid key.")
def _collate_fn(batch):
    """Collate (feat, graph) samples into a padded batch.

    Sorts samples by descending number of frames, zero-pads the features
    to the longest sequence, and packs the graphs into one
    ChainGraphBatch sized to the largest graph.

    Returns:
        (feats, num_graphs) with feats of shape (batch, max_len, feat_dim).
    """
    # Longest utterance first (frame count is dim 0 of the feature tensor).
    batch = sorted(batch, key=lambda sample: sample[0].size(0), reverse=True)
    longest = batch[0][0].size(0)
    feat_dim = batch[0][0].size(1)

    feats = torch.zeros(len(batch), longest, feat_dim)
    graph_list = []
    for idx, (feat, graph) in enumerate(batch):
        # Copy the utterance into the left part of its padded row.
        feats[idx, :feat.size(0), :].copy_(feat)
        graph_list.append(graph)

    num_graphs = ChainGraphBatch(
        graph_list,
        max_num_transitions=max(g.num_transitions for g in graph_list),
        max_num_states=max(g.num_states for g in graph_list),
    )
    return feats, num_graphs
def __len__(self):
    # Dataset length: one entry per loaded sample.
    # NOTE(review): this is a method of the enclosing dataset class
    # (presumably ChainDataset) — the class header is outside this view.
    return len(self.samples)


if __name__ == '__main__':
    # Smoke test: load two WSJ utterances, build the denominator graph,
    # and back-propagate the chain loss through random network output.
    # NOTE(review): the absolute /export/... paths are machine-specific,
    # so this demo only runs on the original author's setup.
    from torch.utils.data import DataLoader
    from pychain.loss import ChainLoss

    feat_dir = '/export/b08/yshao/kaldi/egs/wsj/s5/data/train_si284_spe2e_hires'
    tree_dir = '/export/b08/yshao/kaldi/egs/wsj/s5/exp/chain/e2e_tree'
    trainset = ChainDataset(feat_dir, tree_dir)
    trainloader = AudioDataLoader(trainset, batch_size=2, shuffle=True)
    # One collated minibatch: padded features plus a graph batch.
    feat, graphs = next(iter(trainloader))
    print(feat.size())
    den_fst_path = '/export/b08/yshao/kaldi/egs/wsj/s5/exp/chain/e2e_tree/den.fst'
    den_fst = simplefst.StdVectorFst.read(den_fst_path)
    den_graph = ChainGraph(den_fst, initial='recursive')
    print(den_graph.num_states)
    # The single denominator graph is replicated across the batch.
    den_graph_batch = ChainGraphBatch(den_graph, batch_size=2)
    criterion = ChainLoss(den_graph_batch)
    torch.manual_seed(1)  # deterministic fake network output
    nnet_output = torch.randn(2, 10, 100)  # (B, T, D)
    nnet_output.requires_grad = True
    obj = criterion(nnet_output, graphs)
    obj.backward()
    print(obj)
    print(nnet_output.grad)