class InternalTypeField(Field[Nodes, Tensor]):
    def __init__(self, name: str, type: str) -> None:
        super().__init__(name, type)
        self.vocabulary = Vocabulary(unknown="<UNK>")

    def index(self, sample: Nodes) -> None:
        for node in sample.nodes:
            self.vocabulary.add_item(node.internal_type)

    def tensorize(self, sample: Nodes) -> Tensor:
        return tensor(
            self.vocabulary.get_indexes(node.internal_type for node in sample.nodes),
            dtype=torch_long,
        )

    def collate(self, tensors: Iterable[Tensor]) -> Tensor:
        return torch_cat(tensors=list(tensors), dim=0)

    def to(self, tensor: Tensor, device: torch_device) -> Tensor:
        return tensor.to(device)
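
# A minimal usage sketch for InternalTypeField (illustration only, not part of the
# classes above). It relies only on the attributes the field actually reads
# (`sample.nodes` and `node.internal_type`), so SimpleNamespace stand-ins replace
# the project's Nodes objects; the `name`/`type` arguments are placeholder values.
def _internal_type_field_example() -> Tensor:
    from types import SimpleNamespace

    sample_a = SimpleNamespace(
        nodes=[
            SimpleNamespace(internal_type="If"),
            SimpleNamespace(internal_type="Name"),
        ]
    )
    sample_b = SimpleNamespace(nodes=[SimpleNamespace(internal_type="Call")])

    field = InternalTypeField(name="internal_type", type="input")
    for sample in (sample_a, sample_b):
        field.index(sample)  # first pass: build the vocabulary
    # Second pass: tensorize each sample, then collate into one flat batch,
    # giving one internal-type index per node across both samples.
    return field.collate(field.tensorize(sample) for sample in (sample_a, sample_b))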
class RolesField(Field[Nodes, RolesFieldOutput]):
    def __init__(self, name: str, type: str) -> None:
        super().__init__(name, type)
        self.vocabulary = Vocabulary(unknown="<UNK>")

    def index(self, sample: Nodes) -> None:
        for node in sample.nodes:
            self.vocabulary.add_items(node.roles)

    def tensorize(self, sample: Nodes) -> RolesFieldOutput:
        roles_offsets = []
        roles: List[int] = []
        for node in sample.nodes:
            # Each node contributes a variable number of roles; record where its
            # slice starts in the flat `roles` list before appending its indexes.
            roles_offsets.append(len(roles))
            roles.extend(self.vocabulary.get_indexes(node.roles))
        return RolesFieldOutput(
            input=tensor(roles, dtype=torch_long),
            offsets=tensor(roles_offsets, dtype=torch_long),
        )

    def collate(self, tensors: Iterable[RolesFieldOutput]) -> RolesFieldOutput:
        tensors = list(tensors)
        offset = 0
        shifted_offsets = []
        for t in tensors:
            # Offsets are relative to each sample's own flat roles tensor, so shift
            # them by the number of role indexes accumulated from earlier samples.
            shifted_offsets.append(t.offsets + offset)
            offset += t.input.shape[0]
        return RolesFieldOutput(
            input=cat([t.input for t in tensors], dim=0),
            offsets=cat(shifted_offsets, dim=0),
        )

    def to(self, tensor: RolesFieldOutput, device: torch_device) -> RolesFieldOutput:
        return RolesFieldOutput(
            input=tensor.input.to(device), offsets=tensor.offsets.to(device)
        )
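
# Sketch of how a collated RolesFieldOutput could be consumed (illustration only).
# The flat `input` indexes plus per-node `offsets` match the layout accepted by
# torch.nn.EmbeddingBag, which pools the role embeddings of each node; whether the
# surrounding model actually uses EmbeddingBag is an assumption, the sketch only
# shows that the shapes line up. The roles, `name`/`type` values, and the fixed
# embedding table size are made up for the example.
def _roles_field_example() -> Tensor:
    from types import SimpleNamespace
    from torch.nn import EmbeddingBag

    sample = SimpleNamespace(
        nodes=[
            SimpleNamespace(roles=["DECLARATION", "FUNCTION"]),
            SimpleNamespace(roles=["IDENTIFIER"]),
        ]
    )
    field = RolesField(name="roles", type="input")
    field.index(sample)
    batch = field.collate([field.tensorize(sample)])
    # Table sized generously for this toy vocabulary; in practice the size would
    # come from the field's vocabulary.
    embed = EmbeddingBag(num_embeddings=32, embedding_dim=8, mode="mean")
    # One pooled role embedding per node: shape (number of nodes, 8).
    return embed(batch.input, batch.offsets)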
class LabelField(Field[Nodes, LabelFieldOutput]):
    def __init__(self, name: str, type: str) -> None:
        super().__init__(name, type)
        self.vocabulary = Vocabulary(unknown="<UNK>")
        self.vocabulary.add_item("<PAD>")
        self.vocabulary.add_item("<GO>")
        self.vocabulary.add_item("<STOP>")

    def index(self, sample: Nodes) -> None:
        for node in sample.nodes:
            if node.internal_type == FORMATTING_INTERNAL_TYPE:
                self.vocabulary.add_items(
                    list(node.token if node.token is not None else "")
                )

    def tensorize(self, sample: Nodes) -> LabelFieldOutput:
        node_sequences = []
        for i, node in enumerate(sample.nodes):
            if node.internal_type == FORMATTING_INTERNAL_TYPE:
                mapped = self.vocabulary.get_indexes(
                    list(node.token if node.token else "")
                )
                # Teacher-forcing pair: the decoder input is the token's character
                # sequence prefixed with <GO>, the label target is the same
                # sequence suffixed with <STOP>.
                labels = tensor(
                    mapped + [self.vocabulary.get_index("<STOP>")], dtype=torch_long
                )
                decoder_inputs = tensor(
                    [self.vocabulary.get_index("<GO>")] + mapped, dtype=torch_long
                )
                node_sequences.append((i, decoder_inputs, labels))
        # pack_sequence expects sequences sorted by decreasing length.
        node_sequences.sort(reverse=True, key=lambda s: s[1].shape[0])
        indexes, decoder_inputs_tensor, labels_tensor = map(list, zip(*node_sequences))
        assert len(indexes) == len(decoder_inputs_tensor) and len(indexes) == len(
            labels_tensor
        )
        return LabelFieldOutput(
            indexes=tensor(indexes, dtype=torch_long),
            decoder_inputs=pack_sequence(decoder_inputs_tensor),
            labels=pack_sequence(labels_tensor),
            n_nodes=len(sample.nodes),
        )

    def collate(self, tensors: Iterable[LabelFieldOutput]) -> LabelFieldOutput:
        inputs_list: List[Tuple[int, Tensor, Tensor]] = []
        offset = 0
        for t in tensors:
            # Node indexes are local to each sample, so shift them by the number of
            # nodes contributed by the samples already merged into the batch.
            for indexes, decoder_inputs, labels in zip(
                (t.indexes + offset).tolist(),
                unpack_packed_sequence(t.decoder_inputs),
                unpack_packed_sequence(t.labels),
            ):
                inputs_list.append((indexes, decoder_inputs, labels))
            offset += t.n_nodes
        # Re-sort the merged sequences by decreasing length before re-packing.
        inputs_list.sort(reverse=True, key=lambda t: t[1].shape[0])
        indexes, decoder_inputs_tensor, labels_tensor = map(list, zip(*inputs_list))
        return LabelFieldOutput(
            indexes=tensor(indexes, dtype=torch_long),
            decoder_inputs=pack_sequence(decoder_inputs_tensor),
            labels=pack_sequence(labels_tensor),
            n_nodes=offset,
        )

    def to(self, tensor: LabelFieldOutput, device: torch_device) -> LabelFieldOutput:
        return LabelFieldOutput(
            indexes=tensor.indexes.to(device),
            decoder_inputs=tensor.decoder_inputs.to(device),
            labels=tensor.labels.to(device),
            n_nodes=tensor.n_nodes,
        )
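
# Sketch of LabelField at work (illustration only). The node attributes and the
# FORMATTING_INTERNAL_TYPE constant are the ones used by the class above; the
# concrete token strings and the `name`/`type` arguments are made-up placeholders.
# The example shows that only formatting nodes produce label sequences, and that
# decoder_inputs/labels come back as PackedSequence objects sorted by length.
def _label_field_example() -> LabelFieldOutput:
    from types import SimpleNamespace

    def formatting_node(token: str) -> SimpleNamespace:
        return SimpleNamespace(internal_type=FORMATTING_INTERNAL_TYPE, token=token)

    sample = SimpleNamespace(
        nodes=[
            SimpleNamespace(internal_type="Name", token="foo"),
            formatting_node("\n    "),  # indentation characters to predict
            formatting_node(" "),       # single space to predict
        ]
    )
    field = LabelField(name="label", type="output")
    field.index(sample)  # adds the characters of formatting tokens to the vocabulary
    output = field.collate([field.tensorize(sample), field.tensorize(sample)])
    # Two copies of the sample: 6 nodes in total, 4 of them formatting nodes, so
    # `output.indexes` has 4 entries pointing into the concatenated node list.
    return output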