def _generate_compact_vectorized_lookup(
    ids: torch.LongTensor,
    label_to_id: Mapping[str, int],
) -> Tuple[Mapping[str, int], torch.LongTensor]:
    """
    Given a tensor of IDs and a label to ID mapping, retain only occurring IDs, and compact the mapping.

    Additionally returns a vectorized translation, i.e. a tensor `translation` of shape (max_old_id + 1,), where
    max_old_id is the largest ID occurring in `ids`, with `translation[old_id] = new_id` for all translated IDs and
    `translation[old_id] = -1` for non-occurring IDs. This allows to use `translation[ids]` to translate the IDs as a
    simple integer index based lookup.

    :param ids: The tensor of IDs (may have any shape; it does not need to be contiguous).
    :param label_to_id: The label to ID mapping.
    :return: A tuple new_label_to_id, vectorized_translation. The translation tensor has dtype long and is placed
        on the same device as `ids`, so that `translation[ids]` works directly. For empty `ids`, it has shape (0,).
    """
    # get existing IDs; `unique()` already flattens its input, and unlike `view(-1)` it also accepts
    # non-contiguous tensors
    existing_ids = set(ids.unique().tolist())

    # remove non-existing IDs from the label mapping, then compact the remaining IDs
    label_to_id, old_to_new_id = compact_mapping(mapping={
        label: i
        for label, i in label_to_id.items()
        if i in existing_ids
    })

    # create translation tensor: its size covers every ID occurring in `ids` (not only labelled ones), so
    # `translation[ids]` never indexes out of bounds. `default=-1` avoids `max()` failing on an empty set.
    # The explicit dtype/device make the result a long index tensor living next to `ids`.
    translation = torch.full(
        size=(max(existing_ids, default=-1) + 1,),
        fill_value=-1,
        dtype=torch.long,
        device=ids.device,
    )
    if old_to_new_id:
        # one vectorized scatter instead of a Python-level tensor write per ID;
        # zip over items() keeps old and new IDs paired
        old_ids, new_ids = zip(*old_to_new_id.items())
        translation[torch.as_tensor(old_ids, dtype=torch.long, device=ids.device)] = torch.as_tensor(
            new_ids, dtype=torch.long, device=ids.device,
        )

    return label_to_id, translation
def test_compact_mapping(self):
    """Test ``compact_mapping()``."""
    # old IDs are 0, 2, 4, ..., i.e. deliberately non-contiguous so that compaction has work to do
    mapping = {letter: 2 * i for i, letter in enumerate(string.ascii_letters)}
    compacted_mapping, id_remapping = compact_mapping(mapping=mapping)

    # check that every label is retained
    self.assertEqual(set(compacted_mapping.keys()), set(mapping.keys()))

    # check correct value range
    self.assertEqual(set(compacted_mapping.values()), set(range(len(mapping))))
    self.assertEqual(set(id_remapping.keys()), set(mapping.values()))
    self.assertEqual(set(id_remapping.values()), set(compacted_mapping.values()))

    # check consistency: the ID remapping must send each label's old ID to exactly that label's new ID
    # (the set-based checks above would also pass for an inconsistent permutation)
    for label, old_id in mapping.items():
        self.assertEqual(id_remapping[old_id], compacted_mapping[label])