Example #1
import torch
# ChainGraphBatch comes from pychain; the exact import path may differ by version
from pychain.graph import ChainGraphBatch


def _collate_fn_train(batch):
    # sort the batch by feature length in descending order
    batch = sorted(batch, key=lambda sample: sample[1], reverse=True)
    max_seqlength = batch[0][1]
    feat_dim = batch[0][0].size(1)
    minibatch_size = len(batch)
    feats = torch.zeros(minibatch_size, max_seqlength, feat_dim)
    feat_lengths = torch.zeros(minibatch_size, dtype=torch.int)
    graph_list = []
    utt_ids = []
    max_num_transitions = 0
    max_num_states = 0
    for i in range(minibatch_size):
        feat, length, utt_id, graph = batch[i]
        feats[i, :length, :].copy_(feat)
        utt_ids.append(utt_id)
        feat_lengths[i] = length
        graph_list.append(graph)
        if graph.num_transitions > max_num_transitions:
            max_num_transitions = graph.num_transitions
        if graph.num_states > max_num_states:
            max_num_states = graph.num_states
    # pack the per-utterance numerator graphs into a single batched graph object
    num_graphs = ChainGraphBatch(graph_list,
                                 max_num_transitions=max_num_transitions,
                                 max_num_states=max_num_states)
    return feats, feat_lengths, utt_ids, num_graphs
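As a rough usage sketch, a collate function like this is normally handed to a torch.utils.data.DataLoader via its collate_fn argument. The ChainTrainDataset below is a hypothetical dataset whose samples follow the (feat, length, utt_id, graph) layout this function expects; the rest is the standard DataLoader API.

from torch.utils.data import DataLoader

# hypothetical dataset whose __getitem__ returns (feat, length, utt_id, graph)
train_set = ChainTrainDataset('data/train', 'exp/chain/e2e_tree')
train_loader = DataLoader(
    train_set,
    batch_size=16,
    shuffle=True,
    collate_fn=_collate_fn_train,
)

for feats, feat_lengths, utt_ids, num_graphs in train_loader:
    pass  # feed feats/feat_lengths to the model and num_graphs to the chain loss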
Example #2
def merge(key):
    # collate the padded acoustic features ("source") or batch the
    # per-utterance numerator graphs ("target") across samples
    if key == "source":
        return speech_utils.collate_frames([s[key] for s in samples], 0.0)
    elif key == "target":
        max_num_transitions = max(s["target"].num_transitions for s in samples)
        max_num_states = max(s["target"].num_states for s in samples)
        return ChainGraphBatch(
            [s["target"] for s in samples],
            max_num_transitions=max_num_transitions,
            max_num_states=max_num_states,
        )
    else:
        raise ValueError(f"Invalid key: {key}")
Example #3
import torch
# ChainGraphBatch comes from pychain; the exact import path may differ by version
from pychain.graph import ChainGraphBatch


def _collate_fn(batch):
    # sort the batch by feature length in descending order
    batch = sorted(batch, key=lambda sample: sample[0].size(0), reverse=True)
    max_seqlength = batch[0][0].size(0)
    feat_dim = batch[0][0].size(1)
    minibatch_size = len(batch)
    feats = torch.zeros(minibatch_size, max_seqlength, feat_dim)
    graph_list = []
    max_num_transitions = 0
    max_num_states = 0
    for x in range(minibatch_size):
        sample = batch[x]
        feat, graph = sample
        feat_length = feat.size(0)
        feats[x].narrow(0, 0, feat_length).copy_(feat)
        graph_list.append(graph)
        if graph.num_transitions > max_num_transitions:
            max_num_transitions = graph.num_transitions
        if graph.num_states > max_num_states:
            max_num_states = graph.num_states
    # pack the per-utterance numerator graphs into a single batched graph object
    num_graphs = ChainGraphBatch(graph_list,
                                 max_num_transitions=max_num_transitions,
                                 max_num_states=max_num_states)
    return feats, num_graphs
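Example #4 below iterates an AudioDataLoader. One common way to bind a module-level collate function such as _collate_fn to a loader is a thin DataLoader subclass; the class sketched here is an assumption about how that wiring might look, not necessarily the repository's exact implementation.

from torch.utils.data import DataLoader


class AudioDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        # behave like a regular DataLoader, but always use the chain collate function
        super(AudioDataLoader, self).__init__(*args, **kwargs)
        self.collate_fn = _collate_fn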
Example #4
    def __len__(self):
        return len(self.samples)


if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from pychain.loss import ChainLoss

    # site-specific Kaldi WSJ data and chain end-to-end tree directories
    feat_dir = '/export/b08/yshao/kaldi/egs/wsj/s5/data/train_si284_spe2e_hires'
    tree_dir = '/export/b08/yshao/kaldi/egs/wsj/s5/exp/chain/e2e_tree'
    trainset = ChainDataset(feat_dir, tree_dir)
    trainloader = AudioDataLoader(trainset, batch_size=2, shuffle=True)

    feat, graphs = next(iter(trainloader))
    print(feat.size())
    # build the denominator graph from den.fst and replicate it across the batch
    den_fst_path = '/export/b08/yshao/kaldi/egs/wsj/s5/exp/chain/e2e_tree/den.fst'
    den_fst = simplefst.StdVectorFst.read(den_fst_path)
    den_graph = ChainGraph(den_fst, initial='recursive')
    print(den_graph.num_states)
    den_graph_batch = ChainGraphBatch(den_graph, batch_size=2)
    # chain (LF-MMI) loss computed against the shared denominator graph
    criterion = ChainLoss(den_graph_batch)
    # dummy network output standing in for an acoustic model's log-probabilities
    torch.manual_seed(1)
    nnet_output = torch.randn(2, 10, 100)  # (B, T, D)
    nnet_output.requires_grad = True

    # numerator graphs come from the data loader batch; the denominator is shared
    obj = criterion(nnet_output, graphs)
    obj.backward()
    print(obj)
    print(nnet_output.grad)
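To go from the smoke test above to something closer to a real training step, the sketch below reuses trainloader and criterion from the example and adds a stand-in model and optimizer. The 40-dimensional input features, the tiny linear model, and the SGD settings are placeholders, and some pychain versions may also expect output lengths alongside the network output.

import torch.nn as nn
import torch.optim as optim

# tiny stand-in acoustic model: 40-dim features -> 100 outputs, applied per frame;
# a real chain model would be a TDNN/RNN producing log-probabilities
model = nn.Sequential(nn.Linear(40, 100), nn.LogSoftmax(dim=-1))
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

for feats, num_graphs in trainloader:
    optimizer.zero_grad()
    nnet_output = model(feats)                 # (B, T, 100)
    loss = criterion(nnet_output, num_graphs)  # numerator graphs vs. shared denominator
    loss.backward()
    optimizer.step()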