コード例 #1
0
ファイル: train.py プロジェクト: Slowika/GameBias-EmeCom2020
def dump(game, n_features, device, gs_mode):
    """Play the game once on every one-hot input and report its accuracy.

    A single batch made of the ``n_features`` x ``n_features`` identity
    matrix is passed through ``core.dump_interactions``. For every dumped
    interaction the arg-max of the sender input is compared with the
    arg-max of the receiver output, and one line describing the exchange is
    printed. Two summary accuracies are then printed, followed by a JSON
    line containing both.

    Args:
        game: game object forwarded to ``core.dump_interactions``.
        n_features: number of distinct one-hot inputs.
        device: device the input batch is moved to; also forwarded to
            ``core.dump_interactions``.
        gs_mode: forwarded as the ``gs`` argument of
            ``core.dump_interactions``.

    Returns:
        None. All results are written to stdout.
    """
    # tiny "dataset": one batch holding every one-hot vector, with no labels
    dataset = [[torch.eye(n_features).to(device), None]]

    interaction = core.dump_interactions(game, dataset, gs=gs_mode, device=device, variable_length=True)

    unif_acc = 0.
    powerlaw_acc = 0.
    # Unnormalised weights 1, 1/2, ..., 1/n_features, then scaled to sum to 1.
    powerlaw_probs = 1 / np.arange(1, n_features+1, dtype=np.float32)
    powerlaw_probs /= powerlaw_probs.sum()

    for i in range(interaction.size):
        sender_input = interaction.sender_input[i]
        message = interaction.message[i]
        receiver_output = interaction.receiver_output[i]

        input_symbol = sender_input.argmax()
        output_symbol = receiver_output.argmax()
        acc = (input_symbol == output_symbol).float().item()

        unif_acc += acc
        # Convert the float32 probability to a built-in float before
        # accumulating: with NumPy >= 2 scalar promotion the running sum
        # would otherwise become np.float32, which json.dumps below cannot
        # serialise. Indexing with a plain int also avoids relying on
        # tensor-to-index conversion.
        powerlaw_acc += float(powerlaw_probs[input_symbol.item()]) * acc
        print(f'input: {input_symbol.item()} -> message: {",".join([str(x.item()) for x in message])} -> output: {output_symbol.item()}', flush=True)

    # Every input is equally likely under the uniform distribution.
    unif_acc /= n_features

    print(f'Mean accuracy wrt uniform distribution is {unif_acc}')
    print(f'Mean accuracy wrt powerlaw distribution is {powerlaw_acc}')
    print(json.dumps({'powerlaw': powerlaw_acc, 'unif': unif_acc}))
コード例 #2
0
ファイル: game.py プロジェクト: stringguardian15/EGG
def dump(game, dataset, device, is_gs):
    """Print one ``input;message;output;label`` line per dumped interaction.

    The interactions are obtained from ``core.dump_interactions`` with
    ``variable_length=True``. Each message is cut to its recorded length
    before printing. When ``is_gs`` is true the receiver output is replaced
    by its arg-max; otherwise it is printed as is.

    Args:
        game: game object forwarded to ``core.dump_interactions``.
        dataset: data forwarded to ``core.dump_interactions``.
        device: device forwarded to ``core.dump_interactions``.
        is_gs: forwarded as ``gs``; also selects arg-max output printing.
    """
    dumped = core.dump_interactions(
        game, dataset, gs=is_gs, device=device, variable_length=True)

    for idx in range(dumped.size):
        # Only the first `message_length` symbols of a message are printed.
        n_symbols = dumped.message_length[idx].long().item()

        input_text = ' '.join(str(v) for v in dumped.sender_input[idx].tolist())
        message_text = ' '.join(
            str(s) for s in dumped.message[idx][:n_symbols].tolist())

        prediction = dumped.receiver_output[idx]
        if is_gs:
            prediction = prediction.argmax()

        target = dumped.labels[idx].item()
        print(f'{input_text};{message_text};{prediction};{target}')
コード例 #3
0
ファイル: intervention.py プロジェクト: wedddy0707/EGG
    def validation(self, game):
        """Score how well messages identify labels on ``self.dataset``.

        Interactions are dumped with ``core.dump_interactions``. Messages
        and labels are turned into dictionary keys with
        ``_hashable_tensor``, and for every distinct message the number of
        occurrences of each label is counted. The majority accuracy is the
        share of interactions whose label is the most frequent label of
        their message.

        Args:
            game: game object forwarded to ``core.dump_interactions``.

        Returns:
            dict with ``codewords_entropy`` (the result of ``entropy`` on
            the list of messages) and ``majority_acc``.

        Raises:
            ZeroDivisionError: if no interactions were dumped.
        """
        dumped = core.dump_interactions(
            game,
            self.dataset,
            gs=self.is_gs,
            device=self.device,
            variable_length=self.var_length,
        )

        n_items = dumped.size
        messages = [dumped.message[idx] for idx in range(n_items)]
        entropy_messages = entropy(messages)
        labels = [dumped.labels[idx] for idx in range(n_items)]

        # message key -> {label key -> number of co-occurrences}
        label_counts_by_message = {}
        for raw_message, raw_label in zip(messages, labels):
            message_key = _hashable_tensor(raw_message)
            label_key = _hashable_tensor(raw_label)

            counts = label_counts_by_message.setdefault(message_key, {})
            counts[label_key] = counts.get(label_key, 0) + 1

        # majority vote per message: the most frequent label counts as
        # correct, every occurrence counts towards the total
        correct = 0.0
        total = 0.0
        for counts in label_counts_by_message.values():
            correct += max(counts.values())
            total += sum(counts.values())

        return dict(codewords_entropy=entropy_messages,
                    majority_acc=correct / total)
コード例 #4
0
ファイル: train.py プロジェクト: wedddy0707/EGG
    # Build the base receiver, then wrap it with core.RnnReceiverGS using the
    # vocabulary size, embedding size, hidden size and cell type from opts.
    receiver = Receiver(opts.receiver_hidden)
    receiver = core.RnnReceiverGS(
        receiver,
        opts.vocab_size,
        opts.receiver_embedding,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
    )

    # Combine sender, receiver and loss into a single game object.
    # (sender and loss are defined earlier, outside this fragment.)
    game = core.SenderReceiverRnnGS(sender, receiver, loss)

    # The optimizer is built over all parameters of the game.
    optimizer = core.build_optimizer(game.parameters())

    # Train for opts.n_epochs epochs; test_loader is supplied as validation data.
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
    )
    trainer.train(n_epochs=opts.n_epochs)

    # Dump the interactions on test_loader. The result is unpacked into five
    # values, the third of which is discarded.
    # NOTE(review): this assumes dump_interactions returns a 5-tuple - confirm
    # against the installed version of core.
    sender_inputs, messages, _, receiver_outputs, labels = core.dump_interactions(
        game, test_loader, gs=True, device=device, variable_length=True)

    # Each sender input is unpacked as a (seq, l) pair and only seq[:l] is
    # printed - presumably l is the unpadded sequence length; TODO confirm.
    for (seq, l), message, output, label in zip(sender_inputs, messages,
                                                receiver_outputs, labels):
        print(f"{seq[:l]} -> {message} -> {output.argmax()} (label = {label})")

    # Final library-level call; presumably releases resources held by core -
    # verify against the core documentation.
    core.close()