def measure(self, protocol: Protocol) -> float:
    """Score ``protocol`` after turning every message into a bag of words.

    Each message is replaced by a list of per-symbol occurrence counts, ordered
    like the vocabulary returned by ``get_vocab_from_protocol``. Every count is
    converted to ``str`` so it acts as one "symbol" of the derived message. The
    resulting protocol is passed to the parent class's ``measure``, whose result
    is returned unchanged.

    Args:
        protocol: Mapping from derivation to message (an iterable of symbols).

    Returns:
        The value of ``super().measure`` on the bag-of-words protocol.

    Raises:
        ValueError: If a message contains a symbol missing from the vocabulary.
    """
    vocab = list(get_vocab_from_protocol(protocol))
    num_symbols = len(vocab)

    # Build symbol -> position once instead of calling ``vocab.index`` (a
    # linear scan) for every symbol. ``setdefault`` keeps the first position,
    # matching ``list.index`` even if the vocabulary contained duplicates.
    symbol_index = {}
    for position, symbol in enumerate(vocab):
        symbol_index.setdefault(symbol, position)

    bow_protocol = {}
    for derivation, message in protocol.items():
        counts = [0] * num_symbols
        for symbol in message:
            try:
                counts[symbol_index[symbol]] += 1
            except KeyError:
                # Preserve the exception type ``list.index`` used to raise.
                raise ValueError(f"{symbol!r} is not in list") from None
        bow_protocol[derivation] = [str(count) for count in counts]
    return super().measure(bow_protocol)
def _protocol_to_tensor(
    self, protocol: Protocol
) -> Dict[Tuple[torch.LongTensor, torch.LongTensor], torch.LongTensor]:
    """Tensorize every (derivation, message) pair of ``protocol``.

    Each derivation is converted with ``derivation_to_tensor``, using an index
    for every concept that appears in any derivation of the protocol. Each
    message is one-hot encoded over the vocabulary and flattened to 1-D.

    Args:
        protocol: Mapping from derivation to message (an iterable of symbols).

    Returns:
        Dict mapping each tensorized derivation to the flattened one-hot
        encoding of its message, of length ``len(message) * len(vocab)``.
    """
    # NOTE(review): ``vocab`` is indexed by symbol and its length is used as the
    # number of one-hot classes, so it presumably maps each symbol to an id in
    # ``range(len(vocab))`` -- confirm against get_vocab_from_protocol.
    vocab = get_vocab_from_protocol(protocol)
    num_classes = len(vocab)

    # ``dict.fromkeys`` de-duplicates while keeping first-seen order, so concept
    # ids are reproducible across runs. Enumerating a ``set`` follows hash order,
    # which varies between interpreter runs for string concepts.
    unique_concepts = dict.fromkeys(
        concept
        for derivation in protocol.keys()
        for concept in flatten_derivation(derivation)
    )
    concepts = {concept: idx for idx, concept in enumerate(unique_concepts)}

    tensorized_protocol = {}
    for derivation, message in protocol.items():
        derivation_tensor = derivation_to_tensor(derivation, concepts)
        symbol_ids = torch.tensor(
            [vocab[char] for char in message], dtype=torch.long
        )
        # one_hot yields shape (len(message), num_classes); flatten to 1-D.
        tensorized_protocol[derivation_tensor] = torch.nn.functional.one_hot(
            symbol_ids, num_classes=num_classes
        ).reshape(-1)
    return tensorized_protocol