# --- Example #1 ---
def get_game(opt):
    """Build a signaling game from parsed options.

    Wires an InformedSender and a Receiver into either a REINFORCE
    (``opt.mode == 'rf'``) or a Gumbel-Softmax (``opt.mode == 'gs'``)
    single-symbol game.

    Args:
        opt: options namespace providing game_size, embedding_size,
            hidden_size, vocab_size, tau_s, mode, and (for 'gs') gs_tau.

    Returns:
        A ``core.SymbolGameReinforce`` or ``core.SymbolGameGS`` instance.

    Raises:
        RuntimeError: if ``opt.mode`` is neither 'rf' nor 'gs'.
    """
    # Fixed dimensionality of the pre-extracted image features
    # (presumably VGG fc7 features — TODO confirm with the data pipeline).
    feat_size = 4096
    sender = InformedSender(opt.game_size,
                            feat_size,
                            opt.embedding_size,
                            opt.hidden_size,
                            opt.vocab_size,
                            temp=opt.tau_s)
    # BUG FIX: the original body referenced an undefined global `opts`
    # while the parameter is named `opt`; use `opt` consistently.
    receiver = Receiver(opt.game_size,
                        feat_size,
                        opt.embedding_size,
                        opt.vocab_size,
                        reinforce=(opt.mode == 'rf'))
    if opt.mode == 'rf':
        # Both agents are stochastic and trained with REINFORCE.
        sender = core.ReinforceWrapper(sender)
        receiver = core.ReinforceWrapper(receiver)
        game = core.SymbolGameReinforce(sender,
                                        receiver,
                                        loss,
                                        sender_entropy_coeff=0.01,
                                        receiver_entropy_coeff=0.01)
    elif opt.mode == 'gs':
        # Differentiable channel via the Gumbel-Softmax relaxation.
        sender = core.GumbelSoftmaxWrapper(sender, temperature=opt.gs_tau)
        game = core.SymbolGameGS(sender, receiver, loss_nll)
    else:
        raise RuntimeError(f"Unknown training mode: {opt.mode}")

    return game
# --- Example #2 ---
def main(params):
    """Entry point: build loaders and agents, assemble the game for
    ``opts.mode`` ('gs', 'rf' or 'non_diff'), and train it.

    Args:
        params: raw command-line argument list forwarded to ``get_params``.

    Raises:
        RuntimeError: if ``opts.mode`` is not one of the supported modes.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    device = opts.device

    train_loader = OneHotLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.n_examples_per_epoch/opts.batch_size)

    # The evaluation loader caches a single batch; move it to the device once.
    test_loader = UniformLoader(n_bits=opts.n_bits, bits_s=opts.bits_s, bits_r=opts.bits_r)
    test_loader.batch = [x.to(device) for x in test_loader.batch]

    sender = Sender(n_bits=opts.n_bits, n_hidden=opts.sender_hidden,
                    vocab_size=opts.vocab_size)

    if opts.mode == 'gs':
        # Fully differentiable channel via the Gumbel-Softmax relaxation.
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden, vocab_size=opts.vocab_size)
        sender = core.GumbelSoftmaxWrapper(agent=sender, temperature=opts.temperature)
        game = core.SymbolGameGS(sender, receiver, diff_loss)
    elif opts.mode == 'rf':
        # Sender trained with REINFORCE; receiver stays deterministic.
        sender = core.ReinforceWrapper(agent=sender)
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden, vocab_size=opts.vocab_size)
        receiver = core.ReinforceDeterministicWrapper(agent=receiver)
        game = core.SymbolGameReinforce(sender, receiver, diff_loss, sender_entropy_coeff=opts.sender_entropy_coeff)
    elif opts.mode == 'non_diff':
        # Both agents are stochastic; a non-differentiable loss is used.
        sender = core.ReinforceWrapper(agent=sender)
        receiver = ReinforcedReceiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden, vocab_size=opts.vocab_size)
        game = core.SymbolGameReinforce(sender, receiver, non_diff_loss,
                                        sender_entropy_coeff=opts.sender_entropy_coeff,
                                        receiver_entropy_coeff=opts.receiver_entropy_coeff)

    else:
        # BUG FIX: was `assert False, 'Unknown training mode'`, which is
        # silently stripped under `python -O`; raise explicitly instead
        # (consistent with the other example's error handling).
        raise RuntimeError(f'Unknown training mode: {opts.mode}')

    # Separate learning rates for the sender and receiver parameter groups.
    optimizer = torch.optim.Adam(
        [
            dict(params=sender.parameters(), lr=opts.sender_lr),
            dict(params=receiver.parameters(), lr=opts.receiver_lr)
        ])

    loss = game.loss

    intervention = CallbackEvaluator(test_loader, device=device, is_gs=opts.mode == 'gs', loss=loss, var_length=False,
                                     input_intervention=True)

    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr)

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader, epoch_callback=intervention, as_json=True,
                           early_stopping=early_stopper)

    trainer.train(n_epochs=opts.n_epochs)

    core.close()
def test_game_reinforce():
    """End-to-end check: a REINFORCE symbol game learns the toy mapping."""
    core.init()

    sender = core.ReinforceWrapper(ToyAgent())
    receiver = core.ReinforceDeterministicWrapper(Receiver())

    def loss(sender_input, message, receiver_input, receiver_output, labels, aux_input):
        # Correct outputs get reward 1, expressed as a negative per-example loss.
        return -(receiver_output == labels).float(), {}

    game = core.SymbolGameReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=1e-1,
        receiver_entropy_coeff=0.0)
    optimizer = torch.optim.Adagrad(game.parameters(), lr=1e-1)

    trainer = core.Trainer(game,
                           optimizer,
                           train_data=Dataset(),
                           validation_data=None)
    trainer.train(5000)

    # After training, the sender's first layer should map each input to its label.
    learned = sender.agent.fc1.weight.t().argmax(dim=1).cpu()
    assert (learned == BATCH_Y).all(), str(sender.agent.fc1.weight)
def test_toy_agent_reinforce():
    """A REINFORCE-wrapped toy agent learns the mapping via policy gradient."""
    core.init()
    agent = core.ReinforceWrapper(ToyAgent())

    optimizer = torch.optim.Adam(agent.parameters())

    for _step in range(1000):
        optimizer.zero_grad()
        output, log_prob, entropy = agent(BATCH_X)
        # REINFORCE objective: reward (correct prediction) times log-probability,
        # negated so that gradient descent maximizes expected reward.
        reward = (output == BATCH_Y).float()
        loss = -(reward * log_prob).mean()
        loss.backward()

        optimizer.step()

    assert (agent.agent.fc1.weight.t().argmax(dim=1).cpu() == BATCH_Y).all()
# --- Example #5 ---
         loss,
         sender_entropy_coeff=opts.sender_entropy_coeff,
         receiver_entropy_coeff=opts.receiver_entropy_coeff,
         length_cost=opts.length_cost,
     )
 elif opts.mode.lower() == "rf":
     sender = core.RnnSenderReinforce(
         sender,
         opts.vocab_size,
         opts.sender_embedding,
         opts.sender_hidden,
         cell=opts.sender_cell,
         max_len=opts.max_len,
         force_eos=False,
     )
     receiver = core.ReinforceWrapper(receiver)
     receiver = core.RnnReceiverReinforce(
         receiver,
         opts.vocab_size,
         opts.receiver_embedding,
         opts.receiver_hidden,
         cell=opts.receiver_cell,
     )
     game = core.SenderReceiverRnnReinforce(
         sender,
         receiver,
         non_differentiable_loss,
         sender_entropy_coeff=opts.sender_entropy_coeff,
         receiver_entropy_coeff=opts.receiver_entropy_coeff,
         length_cost=opts.length_cost,
     )
# --- Example #6 ---
def main(params):
    """Entry point: parse options, build the agents/game for the selected
    mode and message regime, then train.

    Two regimes are supported:
      * fixed-length single-symbol messages (``not opts.variable_length``),
        with 'gs', 'rf' and 'non_diff' modes;
      * variable-length messages, which only support REINFORCE ('rf').

    Args:
        params: raw command-line argument list forwarded to ``get_params``.

    Raises:
        RuntimeError: if a fixed-length mode other than the three supported
            ones is requested.
    """
    opts = get_params(params)
    print(opts)

    device = opts.device

    train_loader = OneHotLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.n_examples_per_epoch /
                                opts.batch_size)

    # The evaluation loader caches a single batch; move it to the device once.
    test_loader = UniformLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r)
    test_loader.batch = [x.to(device) for x in test_loader.batch]

    if not opts.variable_length:
        sender = Sender(n_bits=opts.n_bits,
                        n_hidden=opts.sender_hidden,
                        vocab_size=opts.vocab_size)
        if opts.mode == 'gs':
            # Fully differentiable channel via the Gumbel-Softmax relaxation.
            sender = core.GumbelSoftmaxWrapper(agent=sender,
                                               temperature=opts.temperature)
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver,
                vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)
            game = core.SymbolGameGS(sender, receiver, diff_loss)
        elif opts.mode == 'rf':
            # Sender trained with REINFORCE; receiver stays deterministic.
            sender = core.ReinforceWrapper(agent=sender)
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver,
                vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)
            receiver = core.ReinforceDeterministicWrapper(agent=receiver)
            game = core.SymbolGameReinforce(
                sender,
                receiver,
                diff_loss,
                sender_entropy_coeff=opts.sender_entropy_coeff)
        elif opts.mode == 'non_diff':
            # Both agents stochastic; a non-differentiable loss is used.
            sender = core.ReinforceWrapper(agent=sender)
            receiver = ReinforcedReceiver(n_bits=opts.n_bits,
                                          n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver,
                vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)

            game = core.SymbolGameReinforce(
                sender,
                receiver,
                non_diff_loss,
                sender_entropy_coeff=opts.sender_entropy_coeff,
                receiver_entropy_coeff=opts.receiver_entropy_coeff)
        else:
            # BUG FIX: previously an unknown mode fell through and left
            # `game` undefined, surfacing later as a confusing NameError.
            raise RuntimeError(f'Unknown training mode: {opts.mode}')
    else:
        if opts.mode != 'rf':
            print('Only mode=rf is supported atm')
            opts.mode = 'rf'

        if opts.sender_cell == 'transformer':
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            sender = Sender(
                n_bits=opts.n_bits,
                n_hidden=opts.sender_hidden,
                vocab_size=opts.sender_hidden)  # TODO: not really vocab
            sender = core.TransformerSenderReinforce(
                agent=sender,
                vocab_size=opts.vocab_size,
                embed_dim=opts.sender_emb,
                max_len=opts.max_len,
                num_layers=1,
                num_heads=1,
                hidden_size=opts.sender_hidden)
        else:
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            sender = Sender(
                n_bits=opts.n_bits,
                n_hidden=opts.sender_hidden,
                vocab_size=opts.sender_hidden)  # TODO: not really vocab
            sender = core.RnnSenderReinforce(agent=sender,
                                             vocab_size=opts.vocab_size,
                                             embed_dim=opts.sender_emb,
                                             hidden_size=opts.sender_hidden,
                                             max_len=opts.max_len,
                                             cell=opts.sender_cell)

        if opts.receiver_cell == 'transformer':
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_emb)
            receiver = core.TransformerReceiverDeterministic(
                receiver,
                opts.vocab_size,
                opts.max_len,
                opts.receiver_emb,
                num_heads=1,
                hidden_size=opts.receiver_hidden,
                num_layers=1)
        else:
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            receiver = core.RnnReceiverDeterministic(receiver,
                                                     opts.vocab_size,
                                                     opts.receiver_emb,
                                                     opts.receiver_hidden,
                                                     cell=opts.receiver_cell)
            # BUG FIX: removed a dead `game = core.SenderReceiverRnnGS(...)`
            # that sat here — it was unconditionally overwritten by the
            # REINFORCE game below and only confused the control flow.

        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff)

    # Separate learning rates for the sender and receiver parameter groups.
    optimizer = torch.optim.Adam([
        dict(params=sender.parameters(), lr=opts.sender_lr),
        dict(params=receiver.parameters(), lr=opts.receiver_lr)
    ])

    loss = game.loss

    intervention = CallbackEvaluator(test_loader,
                                     device=device,
                                     is_gs=opts.mode == 'gs',
                                     loss=loss,
                                     var_length=opts.variable_length,
                                     input_intervention=True)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True),
                               EarlyStopperAccuracy(opts.early_stopping_thr),
                               intervention
                           ])

    trainer.train(n_epochs=opts.n_epochs)

    core.close()
                           cell=opts.rnn_cell)
        for _ in range(opts.population_size)
    ]
    receiver_ensemble_1 = GumbelSoftmaxMultiAgentEnsemble(agents=receivers_1)
    receivers_2 = [
        core.RnnReceiverGS(agent=Receiver(opts.receiver_hidden,
                                          opts.n_features, opts.n_attributes),
                           vocab_size=opts.vocab_size,
                           emb_dim=opts.receiver_embedding,
                           n_hidden=opts.receiver_hidden,
                           cell=opts.rnn_cell)
        for _ in range(opts.population_size)
    ]
    receiver_ensemble_2 = GumbelSoftmaxMultiAgentEnsemble(agents=receivers_2)
    executive_sender = core.ReinforceWrapper(
        ExecutiveSender(opts.population_size, opts.n_features,
                        opts.n_attributes))

    game = GSSequentialTeamworkGame(sender_ensemble, receiver_ensemble_1,
                                    receiver_ensemble_2, executive_sender,
                                    loss_diff,
                                    opts.executive_sender_entropy_coeff)
    sender_params = [{
        'params': sender_ensemble.parameters(),
        'lr': opts.sender_lr
    }]
    executive_sender_params = [{
        'params': executive_sender.parameters(),
        'lr': opts.executive_sender_lr
    }]
    receivers_params = [{
class GumbelSoftmaxMultiAgentEnsemble(nn.Module):
    """Ensemble of Gumbel-Softmax-style agents.

    Every agent is run on the full batch; then, per batch element, the
    output of exactly one agent (chosen by ``agent_indices``) is returned.
    """

    def __init__(self, agents: List[nn.Module]):
        # BUG FIX: annotation said List[core.ReinforceWrapper], but this
        # ensemble holds Gumbel-Softmax agents returning a single tensor
        # (see the __main__ smoke test, which wraps with GumbelSoftmaxWrapper).
        super(GumbelSoftmaxMultiAgentEnsemble, self).__init__()
        self.agents = nn.ModuleList(agents)

    def forward(self, input, agent_indices, **kwargs):
        """Select, per example, the output of the agent named in ``agent_indices``.

        Args:
            input: batch tensor fed unchanged to every agent.
            agent_indices: LongTensor of shape (batch,) with one agent index
                per example.

        Returns:
            Tensor with the same shape as a single agent's output.
        """
        samples = [agent(input) for agent in self.agents]
        if samples[0].dim() > 2:  # sequence output: (batch, len, vocab) — TODO confirm
            agent_indices = agent_indices.reshape(1, agent_indices.size(0), 1, 1).expand(
                1, agent_indices.size(0), samples[0].size(1), samples[0].size(2))
        else:  # single-symbol output: (batch, vocab)
            # BUG FIX: original used samples[0].size(2), which is out of range
            # for a 2-D tensor; the last dimension here is size(1).
            agent_indices = agent_indices.reshape(1, agent_indices.size(0), 1).expand(
                1, agent_indices.size(0), samples[0].size(1))
        # Stack all agents' outputs along a new leading dim and gather the
        # chosen agent's row for each batch element.
        samples = torch.stack(samples, dim=0).gather(dim=0, index=agent_indices)
        return samples.squeeze(dim=0)


if __name__ == "__main__":
    # Smoke test for the two multi-agent ensemble wrappers using tiny Senders.
    from teamwork.agents import Sender
    BATCH_SIZE, INPUT_SIZE, OUTPUT_SIZE = 8, 10, 5
    # Route even batch positions to agent 0 and odd ones to agent 1.
    AGENT_INDICES = torch.LongTensor([0, 1, 0, 1, 0, 1, 0, 1])
    # REINFORCE ensemble: expected to return (samples, log_probs, entropies),
    # one scalar per batch element (see the shape assertion below).
    multi_agent = ReinforceMultiAgentEnsemble(agents=[core.ReinforceWrapper(Sender(OUTPUT_SIZE, INPUT_SIZE)),
                                                      core.ReinforceWrapper(Sender(OUTPUT_SIZE, INPUT_SIZE))])
    # NOTE(review): torch.Tensor(B, I) allocates uninitialized memory —
    # presumably acceptable for a shape-only check, but confirm no
    # value-level determinism is expected here.
    samples, log_probs, entropies = multi_agent(torch.Tensor(BATCH_SIZE, INPUT_SIZE),
                                                agent_indices=AGENT_INDICES)
    assert samples.shape == log_probs.shape == entropies.shape == torch.Size([BATCH_SIZE])

    # Gumbel-Softmax ensemble: one (batch, vocab)-shaped relaxed sample batch.
    multi_agent = GumbelSoftmaxMultiAgentEnsemble(agents=[core.GumbelSoftmaxWrapper(Sender(OUTPUT_SIZE, INPUT_SIZE)),
                                                         core.GumbelSoftmaxWrapper(Sender(OUTPUT_SIZE, INPUT_SIZE))])
    samples = multi_agent(torch.Tensor(BATCH_SIZE, INPUT_SIZE), agent_indices=AGENT_INDICES)
    assert samples.shape == torch.Size([BATCH_SIZE, OUTPUT_SIZE])