Example #1
0
def dump(game, n_features, device, gs_mode):
    """Evaluate ``game`` on every one-hot input and print per-symbol results.

    Builds a single batch holding the ``n_features`` x ``n_features`` identity
    matrix, runs it through ``core.dump_sender_receiver``, and prints one
    ``input -> message -> output`` line per sample. It then prints the mean
    accuracy under a uniform and under a power-law (p(k) ~ 1/k) distribution
    over input symbols, both as text and as one JSON line.

    Args:
        game: game object forwarded to ``core.dump_sender_receiver``.
        n_features: number of input symbols (size of the identity batch).
        device: device the identity batch is moved to.
        gs_mode: forwarded as ``gs=`` to ``core.dump_sender_receiver``.
    """
    # tiny "dataset": one batch containing every one-hot input, no labels
    dataset = [[torch.eye(n_features).to(device), None]]

    sender_inputs, messages, _receiver_inputs, receiver_outputs, _ = \
        core.dump_sender_receiver(game, dataset, gs=gs_mode, device=device, variable_length=True)

    unif_acc = 0.
    powerlaw_acc = 0.
    # p(k) proportional to 1/k for k = 1..n_features, normalised to sum to 1
    powerlaw_probs = 1 / np.arange(1, n_features + 1, dtype=np.float32)
    powerlaw_probs /= powerlaw_probs.sum()

    for sender_input, message, receiver_output in zip(sender_inputs, messages,
                                                      receiver_outputs):
        input_symbol = sender_input.argmax().item()
        output_symbol = receiver_output.argmax().item()
        acc = float(input_symbol == output_symbol)

        unif_acc += acc
        # float(): keep the accumulator a Python float. Under NumPy 2 promotion
        # rules a float32 scalar would propagate here, and json.dumps rejects it.
        powerlaw_acc += float(powerlaw_probs[input_symbol]) * acc
        print(
            f'input: {input_symbol} -> message: {",".join([str(x.item()) for x in message])} -> output: {output_symbol}',
            flush=True)

    # assumes dump_sender_receiver yields one entry per row of the identity
    # batch, so this is the mean over symbols
    unif_acc /= n_features

    print(f'Mean accuracy wrt uniform distribution is {unif_acc}')
    print(f'Mean accuracy wrt powerlaw distribution is {powerlaw_acc}')
    print(json.dumps({'powerlaw': powerlaw_acc, 'unif': unif_acc}))
Example #2
0
def dump(game, dataset, device, is_gs):
    """Print one ``input;message;output;label`` line per sample in ``dataset``.

    Sender inputs and messages are rendered as space-separated values. In GS
    mode the receiver output is reduced to its argmax index. Otherwise it is
    printed as returned by ``core.dump_sender_receiver``.

    Args:
        game: game object forwarded to ``core.dump_sender_receiver``.
        dataset: data iterated by ``core.dump_sender_receiver``.
        device: device forwarded to ``core.dump_sender_receiver``.
        is_gs: forwarded as ``gs=``; also selects the argmax reduction.
    """
    sender_inputs, messages, _, receiver_outputs, labels = \
        core.dump_sender_receiver(game, dataset, gs=is_gs, device=device, variable_length=True)

    for sender_input, message, receiver_output, label \
            in zip(sender_inputs, messages, receiver_outputs, labels):
        sender_input = ' '.join(map(str, sender_input.tolist()))
        message = ' '.join(map(str, message.tolist()))
        if is_gs:
            # .item(): print the predicted index as a plain int. A bare 0-d
            # tensor would render as "tensor(k)" in the ';'-separated line.
            receiver_output = receiver_output.argmax().item()
        print(f'{sender_input};{message};{receiver_output};{label.item()}')
Example #3
0
def dump(game, dataset, device, is_gs, is_var_length):
    """Print ``sender_input -> message -> receiver_output`` for each sample.

    The sender input is rendered as its concatenated values. The message is
    space-separated when ``is_var_length`` is set (printed as-is otherwise).
    Each receiver output unit is thresholded at 0.5 and rendered as a 0/1
    digit.

    Args:
        game: game object forwarded to ``core.dump_sender_receiver``.
        dataset: data iterated by ``core.dump_sender_receiver``.
        device: device forwarded to ``core.dump_sender_receiver``.
        is_gs: forwarded as ``gs=``.
        is_var_length: forwarded as ``variable_length=``; also controls how
            messages are rendered.
    """
    sender_inputs, messages, _1, receiver_outputs, _2 = \
        core.dump_sender_receiver(
            game, dataset, gs=is_gs, device=device, variable_length=is_var_length)

    for sender_input, message, receiver_output \
            in zip(sender_inputs, messages, receiver_outputs):
        sender_input = ''.join(map(str, sender_input.tolist()))
        if is_var_length:
            message = ' '.join(map(str, message.tolist()))
        # int(): str(True) would produce "TrueFalse..." instead of a bit string
        bits = (receiver_output > 0.5).tolist()
        receiver_output = ''.join(str(int(bit)) for bit in bits)
        print(f'{sender_input} -> {message} -> {receiver_output}')
Example #4
0
def dump(game, n_features, device, gs_mode, pos_m=-2, pos_M=-2):
    """Evaluate ``game`` on every one-hot input and return its messages.

    Builds a single batch holding the ``n_features`` x ``n_features`` identity
    matrix and runs it through ``core.dump_sender_receiver``, forwarding
    ``pos_m``/``pos_M``. It prints the mean accuracy under a uniform and under
    a power-law (p(k) ~ 1/k) distribution over input symbols, both as text and
    as one JSON line.

    Args:
        game: game object forwarded to ``core.dump_sender_receiver``.
        n_features: number of input symbols (size of the identity batch).
        device: device the identity batch is moved to.
        gs_mode: forwarded as ``gs=``.
        pos_m, pos_M: forwarded unchanged to ``core.dump_sender_receiver``.

    Returns:
        ``(all_messages, powerlaw_acc)``. ``all_messages`` is a NumPy array of
        the messages (moved to CPU). When the messages differ in shape, it is
        a 1-D object array with one message per slot. ``powerlaw_acc`` is a
        Python float.
    """
    # tiny "dataset": one batch containing every one-hot input, no labels
    dataset = [[torch.eye(n_features).to(device), None]]

    sender_inputs, messages, _receiver_inputs, receiver_outputs, _ = \
        core.dump_sender_receiver(game, dataset, gs=gs_mode, device=device, variable_length=True,
                                  pos_m=pos_m, pos_M=pos_M)

    unif_acc = 0.
    powerlaw_acc = 0.
    # p(k) proportional to 1/k for k = 1..n_features, normalised to sum to 1
    powerlaw_probs = 1 / np.arange(1, n_features + 1, dtype=np.float32)
    powerlaw_probs /= powerlaw_probs.sum()

    host_messages = [x.cpu().numpy() for x in messages]
    if len({m.shape for m in host_messages}) <= 1:
        all_messages = np.asarray(host_messages)
    else:
        # Variable-length messages cannot form a rectangular array, and
        # NumPy >= 1.24 raises ValueError for such input. Build the 1-D object
        # array (one message per slot) that older NumPy produced instead.
        all_messages = np.empty(len(host_messages), dtype=object)
        for i, m in enumerate(host_messages):
            all_messages[i] = m

    # messages stays in the zip so the iteration count matches the original
    for sender_input, _message, receiver_output in zip(sender_inputs, messages, receiver_outputs):
        input_symbol = sender_input.argmax().item()
        output_symbol = receiver_output.argmax().item()
        acc = float(input_symbol == output_symbol)

        unif_acc += acc
        # float(): keep the accumulator a Python float. Under NumPy 2 promotion
        # rules a float32 scalar would propagate here, and json.dumps rejects it.
        powerlaw_acc += float(powerlaw_probs[input_symbol]) * acc

    # assumes dump_sender_receiver yields one entry per row of the identity
    # batch, so this is the mean over symbols
    unif_acc /= n_features

    print(pos_m, pos_M)
    print(f'Mean accuracy wrt uniform distribution is {unif_acc}')
    print(f'Mean accuracy wrt powerlaw distribution is {powerlaw_acc}')
    print(json.dumps({'powerlaw': powerlaw_acc, 'unif': unif_acc}))
    return all_messages, powerlaw_acc
Example #5
0
    def validation(self, game):
        """Summarise how informative the game's messages are on ``self.dataset``.

        Every sample is grouped by its (scalar) message. Each message then
        "predicts" the label it co-occurs with most often.

        Returns:
            dict with ``codewords_entropy`` (``entropy(messages)``) and
            ``majority_acc`` (fraction of samples whose label equals the
            majority label of their message).
        """
        _sender_inputs, messages, _receiver_inputs, _receiver_outputs, labels = \
            core.dump_sender_receiver(game, self.dataset, gs=self.is_gs, device=self.device,
                                      variable_length=self.var_length)

        codewords_entropy = entropy(messages)

        # message value -> {hashable label -> co-occurrence count}
        counts_per_message = {}
        for msg, lbl in zip(messages, labels):
            label_counts = counts_per_message.setdefault(msg.item(), {})
            key = _hashable_tensor(lbl)
            label_counts[key] = label_counts.get(key, 0) + 1

        # majority vote: the most frequent label of each message counts as hits
        hits = 0.0
        seen = 0.0
        for label_counts in counts_per_message.values():
            hits += max(label_counts.values())
            seen += sum(label_counts.values())

        return {
            'codewords_entropy': codewords_entropy,
            'majority_acc': hits / seen,
        }
Example #6
0
    def train_epoch(self, epoch):
        """Run one training epoch, saving the game's messages after every batch.

        After each optimisation step the game is run on a fixed probe batch of
        one-hot inputs (``torch.eye(20)``) via ``core.dump_sender_receiver``.
        The resulting messages are saved to ``messages<epoch>_<batch>.npy`` in
        the current working directory.

        Args:
            epoch: epoch index, used only in the saved file names.

        Returns:
            ``(mean_loss, mean_rest)``: the mean optimised loss over the epoch
            (via ``.item()``) and the per-batch ``rest`` dicts accumulated with
            ``_add_dicts`` and divided by the batch count with ``_div_dict``.
        """
        import egg.core as core

        # Fixed probe batch for the message dump, built once since it never
        # changes. NOTE(review): 20 presumably equals this game's number of
        # input features; confirm before reusing elsewhere.
        dataset_m = [[torch.eye(20).to(self.device), None]]

        mean_loss = 0
        mean_rest = {}
        n_batches = 0
        self.game.train()
        for batch in self.train_data:
            self.optimizer.zero_grad()
            batch = move_to(batch, self.device)
            optimized_loss, rest = self.game(*batch)
            mean_rest = _add_dicts(mean_rest, rest)
            optimized_loss.backward()
            self.optimizer.step()

            n_batches += 1
            # detach(): summing graph-attached losses would chain every batch
            # into one autograd graph that stays referenced until epoch end
            mean_loss += optimized_loss.detach()

            ### ADDITION TO CONTROL THE MESSAGES
            # NOTE(review): confirm dump_sender_receiver restores the game's
            # training mode; otherwise later batches would run in eval mode.
            _, messages, _, _, _ = \
                core.dump_sender_receiver(self.game, dataset_m, gs=False, device=self.device, variable_length=True)

            host_messages = [x.cpu().numpy() for x in messages]
            if len({m.shape for m in host_messages}) <= 1:
                all_messages = np.asarray(host_messages)
            else:
                # Variable-length messages are ragged, and NumPy >= 1.24
                # raises ValueError on them. Store a 1-D object array instead,
                # as older NumPy did.
                all_messages = np.empty(len(host_messages), dtype=object)
                for i, m in enumerate(host_messages):
                    all_messages[i] = m
            np.save(f'messages{epoch}_{n_batches}.npy', all_messages)

            ####

        mean_loss /= n_batches
        mean_rest = _div_dict(mean_rest, n_batches)
        return mean_loss.item(), mean_rest
Example #7
0
    train_loader = SequenceLoader(max_n=opts.max_n, batch_size=opts.batch_size,
                                  batches_per_epoch=opts.batches_per_epoch)
    test_loader = SequenceLoader(max_n=opts.max_n, batch_size=opts.batch_size,
                                 batches_per_epoch=opts.batches_per_epoch, seed=7)

    encoder = Encoder(n_hidden=opts.sender_hidden, emb_dim=opts.sender_embedding,
                     cell=opts.sender_cell, vocab_size=3)  # only 3 symbols in the incoming data
    sender = core.RnnSenderGS(encoder, opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
                              cell=opts.sender_cell, max_len=opts.max_len, temperature=opts.temperature)

    receiver = Receiver(opts.receiver_hidden)
    receiver = core.RnnReceiverGS(receiver, opts.vocab_size, opts.receiver_embedding,
                                  opts.receiver_hidden, cell=opts.receiver_cell)

    game = core.SenderReceiverRnnGS(sender, receiver, loss)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader)
    trainer.train(n_epochs=opts.n_epochs)

    sender_inputs, messages, _, receiver_outputs, labels = \
        core.dump_sender_receiver(game, test_loader, gs=True, device=device, variable_length=True)

    for (seq, l), message, output, label in zip(sender_inputs, messages, receiver_outputs, labels):
        print(f'{seq[:l]} -> {message} -> {output.argmax()} (label = {label})')

    core.close()

Example #8
0
    callbacks = [core.ConsoleLogger(print_train_loss=True, as_json=True)]
    if opts.mode.lower() == "gs":
        callbacks.append(
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1))
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_data,
        validation_data=validation_data,
        callbacks=callbacks,
    )  # validation_data=validation_data,
    trainer.train(n_epochs=opts.n_epochs)

    if opts.evaluate:
        is_gs = "gs" in opts.mode
        sender_inputs, messages, receiver_inputs, receiver_outputs, labels = core.dump_sender_receiver(
            game, test_data, is_gs, variable_length=True, device=device)

        _, _, _, train_receiver_outputs, train_labels = core.dump_sender_receiver(
            game, train_data, is_gs, variable_length=True, device=device)

        # Test
        receiver_outputs = move_to(receiver_outputs, device)
        labels = move_to(labels, device)
        receiver_outputs = torch.stack(receiver_outputs)
        labels = torch.stack(labels)

        output_is_vector = opts.mode.lower() in set(
            ["gs-hard", "gs", "rf-deterministic"])

        if output_is_vector:
            tensor_accuracy = receiver_outputs.squeeze().argmax(