def measure(self, protocol: Protocol) -> float:
    disentanglement_scores = []
    non_constant_positions = 0

    for j in range(self.max_message_length):
        # Symbols emitted at position j across all messages in the protocol.
        symbols_j = [message[j] for message in protocol.values()]
        symbol_mutual_info = []
        symbol_entropy = compute_entropy(symbols_j)
        for i in range(self.num_concept_slots):
            concepts_i = [
                flatten_derivation(derivation)[i]
                for derivation in protocol.keys()
            ]
            mutual_info = compute_mutual_information(concepts_i, symbols_j)
            symbol_mutual_info.append(mutual_info)
        symbol_mutual_info.sort(reverse=True)

        # Skip constant (zero-entropy) positions; otherwise score position j by
        # the gap between its two most informative concept slots, normalised by
        # the position's entropy.
        if symbol_entropy > 0:
            disentanglement_score = (
                symbol_mutual_info[0] -
                symbol_mutual_info[1]) / symbol_entropy
            disentanglement_scores.append(disentanglement_score)
            non_constant_positions += 1

    # Average over the non-constant positions only.
    if non_constant_positions > 0:
        return sum(disentanglement_scores) / non_constant_positions
    else:
        return np.nan
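The metric above relies on compute_entropy and compute_mutual_information, which are not shown on this page. A minimal sketch of plug-compatible helpers, assuming the concept and symbol values are discrete and hashable (the bodies below are an assumption, not the library's implementation):

from collections import Counter
from math import log2
from typing import Hashable, Sequence


def compute_entropy(values: Sequence[Hashable]) -> float:
    # Shannon entropy (in bits) of the empirical distribution of `values`.
    counts = Counter(values)
    total = sum(counts.values())
    return -sum((n / total) * log2(n / total) for n in counts.values())


def compute_mutual_information(xs: Sequence[Hashable],
                               ys: Sequence[Hashable]) -> float:
    # I(X; Y) = H(X) + H(Y) - H(X, Y), estimated from the paired samples.
    return (compute_entropy(xs) + compute_entropy(ys)
            - compute_entropy(list(zip(xs, ys))))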
def measure(self, protocol: Protocol) -> float:
    # Topographic similarity: Spearman correlation between pairwise message
    # distances and pairwise input (flattened derivation) distances.
    distance_messages = self._compute_distances(
        sequence=list(protocol.values()),
        metric=self.messages_metric)
    distance_inputs = self._compute_distances(
        sequence=[flatten_derivation(derivation)
                  for derivation in protocol.keys()],
        metric=self.input_metric)
    return spearmanr(distance_messages, distance_inputs).correlation
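The _compute_distances helper is not shown. A minimal sketch under the assumption that metric is a callable taking two items and returning a distance (the original class may instead accept a metric name); the essential property is that both calls enumerate pairs in the same order, so the two lists line up element-wise for spearmanr:

from itertools import combinations
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def _compute_distances(sequence: Sequence[T],
                       metric: Callable[[T, T], float]) -> List[float]:
    # Distance for every unordered pair, enumerated in a deterministic order.
    return [metric(a, b) for a, b in combinations(sequence, 2)]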
Example #3
def _protocol_to_tensor(
        self, protocol: Protocol
) -> Dict[Tuple[torch.LongTensor, torch.LongTensor], torch.LongTensor]:
    vocab = get_vocab_from_protocol(protocol)
    concept_set = set(concept for derivation in protocol.keys()
                      for concept in flatten_derivation(derivation))
    concepts = {concept: idx for idx, concept in enumerate(concept_set)}
    tensorized_protocol = {}
    for derivation, message in protocol.items():
        derivation = derivation_to_tensor(derivation, concepts)
        message = torch.LongTensor([vocab[char] for char in message])
        # Store each message as a single flat one-hot vector of length
        # message_length * vocab_size, keyed by the tensorised derivation.
        tensorized_protocol[derivation] = torch.nn.functional.one_hot(
            message, num_classes=len(vocab)).reshape(-1)
    return tensorized_protocol
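get_vocab_from_protocol is not shown here, but Example #5 below builds the same character-to-index mapping inline; a plug-compatible sketch (the sorted call is an added assumption, to make the indices deterministic):

from typing import Dict


def get_vocab_from_protocol(protocol) -> Dict[str, int]:
    # Index every character that occurs in any message of the protocol.
    characters = sorted(set(c for message in protocol.values() for c in message))
    return {char: idx for idx, char in enumerate(characters)}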
Example #4
def _compute_concept_symbol_matrix(self,
                                   protocol: Protocol,
                                   vocab: Dict[str, int],
                                   concepts: Dict[str, int],
                                   epsilon: float = 10e-8) -> np.ndarray:
    # Collect, for every concept, the symbols of every message whose
    # derivation contains that concept.
    concept_to_message = defaultdict(list)
    for derivation, message in protocol.items():
        for concept in flatten_derivation(derivation):
            concept_to_message[concept] += list(message)
    # Co-occurrence counts, smoothed with epsilon so later normalisations
    # never divide by zero.
    concept_symbol_matrix = np.ndarray((self.num_concepts, len(vocab)))
    concept_symbol_matrix.fill(epsilon)
    for concept, symbols in concept_to_message.items():
        for symbol in symbols:
            concept_symbol_matrix[concepts[concept], vocab[symbol]] += 1
    return concept_symbol_matrix
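A standalone walk-through of the same counting logic on a toy protocol, with plain tuples standing in for derivations (so flatten_derivation becomes direct iteration) and an illustrative epsilon:

from collections import defaultdict

import numpy as np

# Toy protocol: two single-concept derivations, each mapped to a two-character message.
protocol = {("red",): "aa", ("blue",): "ab"}
vocab = {"a": 0, "b": 1}
concepts = {"red": 0, "blue": 1}

concept_to_message = defaultdict(list)
for derivation, message in protocol.items():
    for concept in derivation:
        concept_to_message[concept] += list(message)

concept_symbol_matrix = np.full((len(concepts), len(vocab)), 1e-8)
for concept, symbols in concept_to_message.items():
    for symbol in symbols:
        concept_symbol_matrix[concepts[concept], vocab[symbol]] += 1
# Result (ignoring epsilon): [[2, 0], [1, 1]] -- "red" is always voiced as "a",
# while "blue" is split evenly between "a" and "b".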
Example #5
    def measure(self, protocol: Protocol) -> float:
        character_set = set(c for message in protocol.values()
                            for c in message)
        vocab = {char: idx for idx, char in enumerate(character_set)}
        concept_set = set(concept for concepts in protocol.keys()
                          for concept in flatten_derivation(concepts))
        concepts = {concept: idx for idx, concept in enumerate(concept_set)}

        concept_symbol_matrix = self._compute_concept_symbol_matrix(
            protocol, vocab, concepts)
        # For each concept, its most frequent (preferred) symbol.
        v_cs = concept_symbol_matrix.argmax(axis=1)
        context_independence_scores = np.zeros(len(concept_set))
        for concept in range(len(concept_set)):
            v_c = v_cs[concept]
            # p(v_c | c): how consistently the concept is voiced by its
            # preferred symbol v_c.
            p_vc_c = concept_symbol_matrix[
                concept, v_c] / concept_symbol_matrix[concept, :].sum(axis=0)
            # p(c | v_c): how exclusively that symbol refers to the concept.
            p_c_vc = concept_symbol_matrix[
                concept, v_c] / concept_symbol_matrix[:, v_c].sum(axis=0)
            context_independence_scores[concept] = p_vc_c * p_c_vc
        return context_independence_scores.mean(axis=0)
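To make the two factors concrete, here is a standalone recomputation of the score on a toy 2x3 count matrix; the numbers are made up purely for illustration:

import numpy as np

# Rows are concepts, columns are symbols.
counts = np.array([[8.0, 1.0, 1.0],   # concept 0 is almost always voiced as symbol 0
                   [1.0, 1.0, 8.0]])  # concept 1 is almost always voiced as symbol 2

v_cs = counts.argmax(axis=1)  # preferred symbol for each concept
scores = []
for c in range(counts.shape[0]):
    v_c = v_cs[c]
    p_vc_given_c = counts[c, v_c] / counts[c, :].sum()    # p(v_c | c)
    p_c_given_vc = counts[c, v_c] / counts[:, v_c].sum()  # p(c | v_c)
    scores.append(p_vc_given_c * p_c_given_vc)
print(np.mean(scores))  # ~0.711 for this toy matrix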
def measure(self, protocol: Protocol) -> float:
    vocab = get_vocab_from_protocol(protocol)
    # The sender plays back the fixed protocol; only the receiver is trained.
    sender = FixedProtocolSender(protocol, vocab)
    receiver = RnnReceiverDeterministic(
        agent=Receiver(NN_CONFIG['receiver_hidden'], NN_CONFIG['num_features']),
        vocab_size=len(vocab) + 1,
        embed_dim=NN_CONFIG['receiver_emb'],
        hidden_size=NN_CONFIG['receiver_hidden'],
        cell=NN_CONFIG['receiver_cell'],
        num_layers=NN_CONFIG['cell_layers']
    )
    game = SenderReceiverRnnReinforce(
        sender, receiver, loss_nll,
        sender_entropy_coeff=0, receiver_entropy_coeff=0.05)
    if self.context_sensitive:
        # Exclude the first element of each derivation from the concept
        # inventory and tensorise only the second element.
        concept_set = set(concept for derivation in protocol.keys()
                          for concept in flatten_derivation(derivation)[1:])
    else:
        concept_set = set(concept for derivation in protocol.keys()
                          for concept in flatten_derivation(derivation))
    concepts = {concept: idx for idx, concept in enumerate(concept_set)}
    if self.context_sensitive:
        derivations = [(derivation, derivation_to_tensor(derivation[1], concepts))
                       for derivation in protocol.keys()]
    else:
        derivations = [(derivation, derivation_to_tensor(derivation, concepts))
                       for derivation in protocol.keys()]
    # Random 80/20 train/test split over derivations; the returned score is
    # the receiver's accuracy on the held-out 20%.
    shuffle(derivations)
    split_idx = int(len(derivations) * 0.8)
    train_derivations = derivations[:split_idx]
    test_derivations = derivations[split_idx:]
    test_accuracy = self._train_and_test(game, train_derivations, test_derivations)
    return test_accuracy
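The snippet reads its hyperparameters from an NN_CONFIG dictionary that is not shown on this page. The keys below are exactly the ones referenced above; the values are placeholders chosen purely for illustration:

NN_CONFIG = {
    'receiver_hidden': 128,  # hidden size of the receiver RNN and its agent
    'receiver_emb': 32,      # symbol embedding dimension
    'receiver_cell': 'gru',  # RNN cell type passed to RnnReceiverDeterministic
    'cell_layers': 1,        # number of RNN layers
    'num_features': 8,       # feature dimension passed to the Receiver agent
}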