Ejemplo n.º 1
0
def _generate_compact_vectorized_lookup(
    ids: torch.LongTensor,
    label_to_id: Mapping[str, int],
) -> Tuple[Mapping[str, int], torch.LongTensor]:
    """
    Given a tensor of IDs and a label to ID mapping, retain only occurring IDs, and compact the mapping.

    Additionally returns a vectorized translation, i.e. a tensor `translation` of shape (max_old_id + 1,) with
    `translation[old_id] = new_id` for all translated IDs and `translation[old_id] = -1` for non-occurring IDs.
    This allows to use `translation[ids]` to translate the IDs as a simple integer index based lookup.

    :param ids:
        The tensor of IDs. It may have any shape and memory layout (it is flattened internally).
    :param label_to_id:
        The label to ID mapping.

    :return:
        A tuple new_label_to_id, vectorized_translation. The translation is a long tensor placed on the same
        device as `ids` (so that `translation[ids]` works directly); it is empty if `ids` is empty.
    """
    # get existing IDs; torch.unique flattens by default, which (unlike .view(-1)) also works for
    # non-contiguous tensors
    existing_ids = set(torch.unique(ids).tolist())
    # remove non-existing IDs from label mapping
    label_to_id, old_to_new_id = compact_mapping(mapping={
        label: i
        for label, i in label_to_id.items()
        if i in existing_ids
    })
    # create translation tensor; default=-1 yields an empty translation instead of a ValueError for empty `ids`
    translation = torch.full(
        size=(max(existing_ids, default=-1) + 1,),
        fill_value=-1,
        dtype=torch.long,
        device=ids.device,
    )
    # fill all translated entries with one vectorized assignment instead of one tensor write per ID
    if old_to_new_id:
        old = torch.as_tensor(list(old_to_new_id.keys()), dtype=torch.long, device=ids.device)
        new = torch.as_tensor(list(old_to_new_id.values()), dtype=torch.long, device=ids.device)
        translation[old] = new
    return label_to_id, translation
Ejemplo n.º 2
0
    def test_compact_mapping(self):
        """Test ``compact_mapping()``: compacted value range, label preservation and remapping consistency."""
        # use only even IDs, i.e. a sparse ID range, so that compaction actually has to renumber
        mapping = {letter: 2 * i for i, letter in enumerate(string.ascii_letters)}
        compacted_mapping, id_remapping = compact_mapping(mapping=mapping)

        # check correct value range
        self.assertEqual(set(compacted_mapping.values()), set(range(len(mapping))))
        self.assertEqual(set(id_remapping.keys()), set(mapping.values()))
        self.assertEqual(set(id_remapping.values()), set(compacted_mapping.values()))

        # check that no label is lost by the compaction
        self.assertEqual(set(compacted_mapping.keys()), set(mapping.keys()))

        # check that the old->new ID remapping agrees with the compacted label mapping for every label
        for label, old_id in mapping.items():
            with self.subTest(label=label):
                self.assertEqual(id_remapping[old_id], compacted_mapping[label])