예제 #1
0
파일: nli.py 프로젝트: tungk/cstlstm
 def collate(self, batch_data):
     """Build a labelled forest from a batch of sentence-pair examples.

     Args:
       batch_data: Re-iterable collection of mappings, each providing the
         keys 'sentence1', 'sentence2' and 'gold_label'.

     Returns:
       A tree_batch.Forest whose trees are all premises followed by all
       hypotheses, with `labels` set and `vocab_ix` set on every node.
     """
     def parse_field(field):
         # Trailing whitespace is stripped before the text is handed to NLP.
         return [NLP(example[field].rstrip()) for example in batch_data]

     premise_docs = parse_field('sentence1')
     hypothesis_docs = parse_field('sentence2')

     # Premises come first, then hypotheses, so the forest keeps that order.
     trees = [tree_batch.sent_to_tree(doc)
              for doc in premise_docs + hypothesis_docs]
     forest = tree_batch.Forest(trees)

     # One label per example, taken from the gold annotation.
     forest.labels = [LABEL_MAP[example['gold_label']]
                      for example in batch_data]

     # The encoder expects every node to carry a pre-looked-up vocab_ix.
     for tree_node in forest.node_list:
         tree_node.vocab_ix = self.vocab_dict[tree_node.token]

     return forest
예제 #2
0
    def collate_fn(self, batch):
        """Merge a list of 5-field samples into one batch dictionary.

        Args:
          batch: Sequence of samples, each unpacking to
            (tree, rgb_feats, flow_feats, vis_mask, gt).

        Returns:
          Dict with keys 'batch_tree' (a tree_batch.Forest), 'batch_rgb_feats',
          'batch_flow_feats' and 'batch_mask' (each produced by torch.stack),
          and 'batch_gt' (the per-sample ground truths as a tuple, unstacked).
        """
        # Transpose the list of samples into one tuple per field.
        trees, rgb_feats, flow_feats, vis_masks, ground_truths = zip(*batch)

        return {
            'batch_tree': tree_batch.Forest(trees),
            'batch_rgb_feats': torch.stack(rgb_feats),
            'batch_flow_feats': torch.stack(flow_feats),
            'batch_mask': torch.stack(vis_masks),
            'batch_gt': ground_truths,
        }
예제 #3
0
    def collate(self, batch_data):
        """Collate a batch of trees into a single labelled forest.

        Args:
          batch_data: List of tree_batch.Tree.

        Returns:
          tree_batch.Forest with `labels` and `annotation_ixs` set, and with
          `vocab_ix` set on every node that has a truthy token.
        """
        forest = tree_batch.Forest(batch_data)
        forest.labels = []

        # annotation_ixs is required downstream and drives label extraction.
        forest.annotation_ixs = self.annotation_ixs(forest)

        for level in range(forest.max_level + 1):
            level_nodes = forest.nodes[level]

            # Labels are gathered level by level, in annotation-index order.
            level_labels = [int(level_nodes[ix].annotation)
                            for ix in forest.annotation_ixs[level]]
            forest.labels.extend(level_labels)

            # Pre-emptive dictionary lookup; nodes without a token are skipped.
            for node in level_nodes:
                if node.token:
                    node.vocab_ix = self.vocab_dict[node.token]

        return forest