示例#1
0
class InternalTypeField(Field[Nodes, Tensor]):
    """Field that encodes each node's ``internal_type`` as a vocabulary index.

    ``index`` grows the vocabulary from a sample's nodes, ``tensorize`` turns a
    sample into a ``torch_long`` tensor of vocabulary indexes (presumably one per
    node, assuming ``Vocabulary.get_indexes`` maps items one-to-one), and
    ``collate`` concatenates per-sample tensors along dimension 0.
    """

    def __init__(self, name: str, type: str) -> None:
        super().__init__(name, type)
        # NOTE(review): "<UNK>" is presumably the fallback entry for types that
        # were never indexed -- confirm against Vocabulary.
        self.vocabulary = Vocabulary(unknown="<UNK>")

    def index(self, sample: Nodes) -> None:
        """Add the internal type of every node in ``sample`` to the vocabulary."""
        vocabulary = self.vocabulary
        for node in sample.nodes:
            vocabulary.add_item(node.internal_type)

    def tensorize(self, sample: Nodes) -> Tensor:
        """Look up the vocabulary index of each node's internal type.

        Returns:
            A ``torch_long`` tensor built from what ``get_indexes`` returns for
            the node types (which are fed to it lazily, in node order).
        """
        internal_types = (node.internal_type for node in sample.nodes)
        indexes = self.vocabulary.get_indexes(internal_types)
        return tensor(indexes, dtype=torch_long)

    def collate(self, tensors: Iterable[Tensor]) -> Tensor:
        """Join the per-sample tensors into one batch tensor along dimension 0."""
        batch = list(tensors)
        return torch_cat(tensors=batch, dim=0)

    def to(self, tensor: Tensor, device: torch_device) -> Tensor:
        """Return ``tensor`` moved to ``device``."""
        moved = tensor.to(device)
        return moved
示例#2
0
class RolesField(Field[Nodes, RolesFieldOutput]):
    """Field that encodes the variable-length ``roles`` of every node.

    A sample is flattened into one ``input`` tensor holding the role indexes of
    all nodes back to back, plus an ``offsets`` tensor giving the position in
    ``input`` where each node's roles begin -- the same flat-input-plus-start-
    offsets layout that ``torch.nn.EmbeddingBag`` accepts.
    """

    def __init__(self, name: str, type: str) -> None:
        super().__init__(name, type)
        self.vocabulary = Vocabulary(unknown="<UNK>")

    def index(self, sample: Nodes) -> None:
        """Add every role of every node in ``sample`` to the vocabulary."""
        vocabulary = self.vocabulary
        for node in sample.nodes:
            vocabulary.add_items(node.roles)

    def tensorize(self, sample: Nodes) -> RolesFieldOutput:
        """Flatten the role indexes of all nodes and record where each node starts."""
        flat_roles: List[int] = []
        starts = []
        for node in sample.nodes:
            # Every node gets a start position, even one with no roles (its
            # start then coincides with the next node's).
            starts.append(len(flat_roles))
            flat_roles += self.vocabulary.get_indexes(node.roles)
        return RolesFieldOutput(
            input=tensor(flat_roles, dtype=torch_long),
            offsets=tensor(starts, dtype=torch_long),
        )

    def collate(self, tensors: Iterable[RolesFieldOutput]) -> RolesFieldOutput:
        """Concatenate samples, shifting each sample's offsets past earlier inputs."""
        samples = list(tensors)
        shifted_offsets = []
        consumed = 0
        for sample_output in samples:
            # Offsets are relative to the sample's own input; rebase them onto
            # the concatenated batch input.
            shifted_offsets.append(sample_output.offsets + consumed)
            consumed += sample_output.input.shape[0]
        return RolesFieldOutput(
            input=cat([item.input for item in samples], dim=0),
            offsets=cat(shifted_offsets, dim=0),
        )

    def to(self, tensor: RolesFieldOutput,
           device: torch_device) -> RolesFieldOutput:
        """Return a new output whose ``input`` and ``offsets`` are moved to ``device``."""
        return RolesFieldOutput(
            input=tensor.input.to(device),
            offsets=tensor.offsets.to(device),
        )
示例#3
0
class LabelField(Field[Nodes, LabelFieldOutput]):
    """Field building character-level decoder sequences for formatting nodes.

    Only nodes whose ``internal_type`` equals ``FORMATTING_INTERNAL_TYPE``
    contribute. For each such node the token (assumed to be an optional
    ``str``; a missing token is treated as empty) is split into characters and
    mapped through the vocabulary, giving:

    * a decoder input: ``<GO>`` followed by the character indexes, and
    * a label: the character indexes followed by ``<STOP>``.

    The output also records, for every packed sequence, the position of its
    node in ``sample.nodes`` (``indexes``) and the sample's node count
    (``n_nodes``), so that ``collate`` can rebase positions when batching.
    """

    def __init__(self, name: str, type: str) -> None:
        super().__init__(name, type)
        self.vocabulary = Vocabulary(unknown="<UNK>")
        # Special symbols are registered before any token characters are indexed.
        # NOTE(review): "<PAD>" is not used in this class; presumably consumed
        # downstream -- confirm.
        self.vocabulary.add_item("<PAD>")
        self.vocabulary.add_item("<GO>")
        self.vocabulary.add_item("<STOP>")

    @staticmethod
    def _token_characters(node) -> List[str]:
        """Split a node's token into characters, treating a missing token as empty."""
        return list(node.token if node.token is not None else "")

    @staticmethod
    def _pack(
        sequences: List[Tuple[int, Tensor, Tensor]], n_nodes: int
    ) -> LabelFieldOutput:
        """Pack (node position, decoder input, label) triples into an output.

        Args:
            sequences: triples to pack; sorted in place.
            n_nodes: node count stored on the returned output.

        Raises:
            ValueError: if ``sequences`` is empty -- a packed sequence cannot be
                built from zero sequences.
        """
        if not sequences:
            raise ValueError(
                "LabelField: no formatting nodes to pack; at least one node "
                "must have the formatting internal type"
            )
        # pack_sequence (presumably torch.nn.utils.rnn.pack_sequence, whose
        # enforce_sorted defaults to True) needs decreasing lengths. Decoder
        # inputs and labels have equal lengths, so sorting on one is enough.
        sequences.sort(reverse=True, key=lambda s: s[1].shape[0])
        indexes, decoder_inputs, labels = map(list, zip(*sequences))
        return LabelFieldOutput(
            indexes=tensor(indexes, dtype=torch_long),
            decoder_inputs=pack_sequence(decoder_inputs),
            labels=pack_sequence(labels),
            n_nodes=n_nodes,
        )

    def index(self, sample: Nodes) -> None:
        """Add the token characters of every formatting node to the vocabulary."""
        for node in sample.nodes:
            if node.internal_type == FORMATTING_INTERNAL_TYPE:
                self.vocabulary.add_items(self._token_characters(node))

    def tensorize(self, sample: Nodes) -> LabelFieldOutput:
        """Build packed decoder inputs and labels for the formatting nodes of ``sample``.

        Raises:
            ValueError: if ``sample`` contains no formatting nodes.
        """
        # Looked up once instead of once per formatting node.
        go_index = self.vocabulary.get_index("<GO>")
        stop_index = self.vocabulary.get_index("<STOP>")
        node_sequences: List[Tuple[int, Tensor, Tensor]] = []
        for i, node in enumerate(sample.nodes):
            if node.internal_type != FORMATTING_INTERNAL_TYPE:
                continue
            mapped = self.vocabulary.get_indexes(self._token_characters(node))
            decoder_inputs = tensor([go_index] + mapped, dtype=torch_long)
            labels = tensor(mapped + [stop_index], dtype=torch_long)
            node_sequences.append((i, decoder_inputs, labels))
        return self._pack(node_sequences, n_nodes=len(sample.nodes))

    def collate(self, tensors: Iterable[LabelFieldOutput]) -> LabelFieldOutput:
        """Merge per-sample outputs into a single batch output.

        Each sample's node positions are shifted by the number of nodes in the
        preceding samples, so ``indexes`` address nodes of the concatenated
        batch, and ``n_nodes`` becomes the batch-wide total.

        Raises:
            ValueError: if ``tensors`` yields no sequences at all (for instance
                when it is empty).
        """
        inputs_list: List[Tuple[int, Tensor, Tensor]] = []
        offset = 0
        for t in tensors:
            # NOTE(review): assumes unpack_packed_sequence returns the sequences
            # in the same (length-sorted) order as ``t.indexes`` -- confirm.
            inputs_list.extend(
                zip(
                    (t.indexes + offset).tolist(),
                    unpack_packed_sequence(t.decoder_inputs),
                    unpack_packed_sequence(t.labels),
                )
            )
            offset += t.n_nodes
        return self._pack(inputs_list, n_nodes=offset)

    def to(self, tensor: LabelFieldOutput, device: torch_device) -> LabelFieldOutput:
        """Return a new output with all tensor components moved to ``device``."""
        return LabelFieldOutput(
            indexes=tensor.indexes.to(device),
            decoder_inputs=tensor.decoder_inputs.to(device),
            labels=tensor.labels.to(device),
            n_nodes=tensor.n_nodes,
        )