def get_game(opt):
    feat_size = 4096
    sender = InformedSender(opt.game_size, feat_size,
                            opt.embedding_size, opt.hidden_size,
                            opt.vocab_size, temp=opt.tau_s)
    receiver = Receiver(opt.game_size, feat_size,
                        opt.embedding_size, opt.vocab_size,
                        reinforce=(opt.mode == 'rf'))

    if opt.mode == 'rf':
        sender = core.ReinforceWrapper(sender)
        receiver = core.ReinforceWrapper(receiver)
        game = core.SymbolGameReinforce(sender, receiver, loss,
                                        sender_entropy_coeff=0.01,
                                        receiver_entropy_coeff=0.01)
    elif opt.mode == 'gs':
        sender = core.GumbelSoftmaxWrapper(sender, temperature=opt.gs_tau)
        game = core.SymbolGameGS(sender, receiver, loss_nll)
    else:
        raise RuntimeError(f"Unknown training mode: {opt.mode}")

    return game
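# `loss` and `loss_nll` are referenced above but defined elsewhere in the module.
# The sketches below only illustrate the loss interface the symbol games expect
# (a per-sample loss plus an auxiliary dict), assuming the six-argument signature
# used by the tests in this repo; the actual definitions may differ.
def loss_example(_sender_input, _message, _receiver_input, receiver_output, labels, _aux_input):
    # REINFORCE-style loss: receiver_output is the index of the chosen candidate,
    # so the negative reward is -1 when it matches the target index.
    acc = (receiver_output == labels).float()
    return -acc, {'acc': acc}


def loss_nll_example(_sender_input, _message, _receiver_input, receiver_output, labels, _aux_input):
    # Gumbel-Softmax-style loss: receiver_output holds log-probabilities over the
    # candidates, so a differentiable NLL can be used directly.
    nll = torch.nn.functional.nll_loss(receiver_output, labels, reduction='none')
    acc = (receiver_output.argmax(dim=1) == labels).float()
    return nll, {'acc': acc}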
def main(params):
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    device = opts.device

    train_loader = OneHotLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.n_examples_per_epoch / opts.batch_size)
    test_loader = UniformLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r)
    test_loader.batch = [x.to(device) for x in test_loader.batch]

    sender = Sender(n_bits=opts.n_bits, n_hidden=opts.sender_hidden,
                    vocab_size=opts.vocab_size)

    if opts.mode == 'gs':
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden,
                            vocab_size=opts.vocab_size)
        sender = core.GumbelSoftmaxWrapper(agent=sender, temperature=opts.temperature)
        game = core.SymbolGameGS(sender, receiver, diff_loss)
    elif opts.mode == 'rf':
        sender = core.ReinforceWrapper(agent=sender)
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden,
                            vocab_size=opts.vocab_size)
        receiver = core.ReinforceDeterministicWrapper(agent=receiver)
        game = core.SymbolGameReinforce(sender, receiver, diff_loss,
                                        sender_entropy_coeff=opts.sender_entropy_coeff)
    elif opts.mode == 'non_diff':
        sender = core.ReinforceWrapper(agent=sender)
        receiver = ReinforcedReceiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden,
                                      vocab_size=opts.vocab_size)
        game = core.SymbolGameReinforce(sender, receiver, non_diff_loss,
                                        sender_entropy_coeff=opts.sender_entropy_coeff,
                                        receiver_entropy_coeff=opts.receiver_entropy_coeff)
    else:
        assert False, 'Unknown training mode'

    optimizer = torch.optim.Adam([
        dict(params=sender.parameters(), lr=opts.sender_lr),
        dict(params=receiver.parameters(), lr=opts.receiver_lr)
    ])

    loss = game.loss

    intervention = CallbackEvaluator(test_loader,
                                     device=device,
                                     is_gs=opts.mode == 'gs',
                                     loss=loss,
                                     var_length=False,
                                     input_intervention=True)

    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           epoch_callback=intervention,
                           as_json=True,
                           early_stopping=early_stopper)
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
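# `diff_loss` and `non_diff_loss` are defined elsewhere in the module.  As an
# illustration only (names, argument roles, and the use of sigmoid outputs are
# assumptions), a differentiable loss for this bit-reconstruction game could
# score the receiver's per-bit probabilities against the target bits, while the
# non-differentiable variant returns a REINFORCE reward for recovering every bit:
def diff_loss_example(_sender_input, _message, _receiver_input, receiver_output, labels, _aux_input):
    # per-bit binary cross-entropy, averaged over bits
    loss = torch.nn.functional.binary_cross_entropy(
        receiver_output, labels.float(), reduction='none').mean(dim=1)
    acc = ((receiver_output > 0.5).long() == labels).all(dim=1).float()
    return loss, {'acc': acc}


def non_diff_loss_example(_sender_input, _message, _receiver_input, receiver_output, labels, _aux_input):
    # reward of 1 only when the whole bit vector is reconstructed exactly
    acc = (receiver_output == labels).all(dim=1).float()
    return -acc, {'acc': acc}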
def test_game_reinforce():
    core.init()

    sender = core.ReinforceWrapper(ToyAgent())
    receiver = core.ReinforceDeterministicWrapper(Receiver())

    loss = lambda sender_input, message, receiver_input, receiver_output, labels, aux_input: (
        -(receiver_output == labels).float(),
        {},
    )

    game = core.SymbolGameReinforce(sender, receiver, loss,
                                    sender_entropy_coeff=1e-1,
                                    receiver_entropy_coeff=0.0)
    optimizer = torch.optim.Adagrad(game.parameters(), lr=1e-1)

    data = Dataset()
    trainer = core.Trainer(game, optimizer, train_data=data, validation_data=None)
    trainer.train(5000)

    assert (sender.agent.fc1.weight.t().argmax(dim=1).cpu() == BATCH_Y).all(), \
        str(sender.agent.fc1.weight)
def test_toy_agent_reinforce():
    core.init()

    agent = core.ReinforceWrapper(ToyAgent())
    optimizer = torch.optim.Adam(agent.parameters())

    for _ in range(1000):
        optimizer.zero_grad()
        output, log_prob, entropy = agent(BATCH_X)
        loss = -((output == BATCH_Y).float() * log_prob).mean()
        loss.backward()
        optimizer.step()

    assert (agent.agent.fc1.weight.t().argmax(dim=1).cpu() == BATCH_Y).all()
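# The fixtures used by the two tests above (BATCH_X, BATCH_Y, ToyAgent, Receiver,
# Dataset) live elsewhere in the test module.  A minimal sketch consistent with
# how the tests use them (one-hot inputs, a single bias-free linear layer whose
# per-input argmax should converge to BATCH_Y, a pass-through receiver, and a
# single-batch dataset); the real fixtures may differ in detail.
BATCH_X = torch.eye(8)
BATCH_Y = torch.arange(8).long()


class ToyAgent(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8, bias=False)

    def forward(self, x, aux_input=None):
        return self.fc1(x)  # logits over the 8 symbols


class Receiver(nn.Module):
    def forward(self, message, input=None, aux_input=None):
        return message  # simply echo the received symbol


class Dataset:
    def __iter__(self):
        return iter([(BATCH_X, BATCH_Y)])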
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost,
        )
    elif opts.mode.lower() == "rf":
        sender = core.RnnSenderReinforce(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            force_eos=False,
        )
        receiver = core.ReinforceWrapper(receiver)
        receiver = core.RnnReceiverReinforce(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            non_differentiable_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost,
        )
def main(params):
    opts = get_params(params)
    print(opts)

    device = opts.device

    train_loader = OneHotLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.n_examples_per_epoch / opts.batch_size)
    test_loader = UniformLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r)
    test_loader.batch = [x.to(device) for x in test_loader.batch]

    if not opts.variable_length:
        sender = Sender(n_bits=opts.n_bits, n_hidden=opts.sender_hidden,
                        vocab_size=opts.vocab_size)
        if opts.mode == 'gs':
            sender = core.GumbelSoftmaxWrapper(agent=sender, temperature=opts.temperature)
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(receiver,
                                                  vocab_size=opts.vocab_size,
                                                  agent_input_size=opts.receiver_hidden)
            game = core.SymbolGameGS(sender, receiver, diff_loss)
        elif opts.mode == 'rf':
            sender = core.ReinforceWrapper(agent=sender)
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(receiver,
                                                  vocab_size=opts.vocab_size,
                                                  agent_input_size=opts.receiver_hidden)
            receiver = core.ReinforceDeterministicWrapper(agent=receiver)
            game = core.SymbolGameReinforce(sender, receiver, diff_loss,
                                            sender_entropy_coeff=opts.sender_entropy_coeff)
        elif opts.mode == 'non_diff':
            sender = core.ReinforceWrapper(agent=sender)
            receiver = ReinforcedReceiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(receiver,
                                                  vocab_size=opts.vocab_size,
                                                  agent_input_size=opts.receiver_hidden)
            game = core.SymbolGameReinforce(sender, receiver, non_diff_loss,
                                            sender_entropy_coeff=opts.sender_entropy_coeff,
                                            receiver_entropy_coeff=opts.receiver_entropy_coeff)
    else:
        if opts.mode != 'rf':
            print('Only mode=rf is supported atm')
            opts.mode = 'rf'

        if opts.sender_cell == 'transformer':
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
            sender = Sender(n_bits=opts.n_bits, n_hidden=opts.sender_hidden,
                            vocab_size=opts.sender_hidden)  # TODO: not really vocab
            sender = core.TransformerSenderReinforce(agent=sender,
                                                     vocab_size=opts.vocab_size,
                                                     embed_dim=opts.sender_emb,
                                                     max_len=opts.max_len,
                                                     num_layers=1,
                                                     num_heads=1,
                                                     hidden_size=opts.sender_hidden)
        else:
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
            sender = Sender(n_bits=opts.n_bits, n_hidden=opts.sender_hidden,
                            vocab_size=opts.sender_hidden)  # TODO: not really vocab
            sender = core.RnnSenderReinforce(agent=sender,
                                             vocab_size=opts.vocab_size,
                                             embed_dim=opts.sender_emb,
                                             hidden_size=opts.sender_hidden,
                                             max_len=opts.max_len,
                                             cell=opts.sender_cell)

        if opts.receiver_cell == 'transformer':
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_emb)
            receiver = core.TransformerReceiverDeterministic(receiver,
                                                             opts.vocab_size,
                                                             opts.max_len,
                                                             opts.receiver_emb,
                                                             num_heads=1,
                                                             hidden_size=opts.receiver_hidden,
                                                             num_layers=1)
        else:
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
            receiver = core.RnnReceiverDeterministic(receiver,
                                                     opts.vocab_size,
                                                     opts.receiver_emb,
                                                     opts.receiver_hidden,
                                                     cell=opts.receiver_cell)

        game = core.SenderReceiverRnnReinforce(sender, receiver, diff_loss,
                                               sender_entropy_coeff=opts.sender_entropy_coeff,
                                               receiver_entropy_coeff=opts.receiver_entropy_coeff)

    optimizer = torch.optim.Adam([
        dict(params=sender.parameters(), lr=opts.sender_lr),
        dict(params=receiver.parameters(), lr=opts.receiver_lr)
    ])

    loss = game.loss

    intervention = CallbackEvaluator(test_loader,
                                     device=device,
                                     is_gs=opts.mode == 'gs',
                                     loss=loss,
                                     var_length=opts.variable_length,
                                     input_intervention=True)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True),
                               EarlyStopperAccuracy(opts.early_stopping_thr),
                               intervention
                           ])
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
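# A typical command-line entry point for an EGG zoo script; a sketch assuming
# `main` is meant to receive the argv tail (the original module may already
# provide its own entry point).
if __name__ == "__main__":
    import sys

    main(sys.argv[1:])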
                           cell=opts.rnn_cell) for _ in range(opts.population_size)
    ]
    receiver_ensemble_1 = GumbelSoftmaxMultiAgentEnsemble(agents=receivers_1)

    receivers_2 = [
        core.RnnReceiverGS(agent=Receiver(opts.receiver_hidden, opts.n_features,
                                          opts.n_attributes),
                           vocab_size=opts.vocab_size,
                           emb_dim=opts.receiver_embedding,
                           n_hidden=opts.receiver_hidden,
                           cell=opts.rnn_cell) for _ in range(opts.population_size)
    ]
    receiver_ensemble_2 = GumbelSoftmaxMultiAgentEnsemble(agents=receivers_2)

    executive_sender = core.ReinforceWrapper(
        ExecutiveSender(opts.population_size, opts.n_features, opts.n_attributes))

    game = GSSequentialTeamworkGame(sender_ensemble,
                                    receiver_ensemble_1,
                                    receiver_ensemble_2,
                                    executive_sender,
                                    loss_diff,
                                    opts.executive_sender_entropy_coeff)

    sender_params = [{
        'params': sender_ensemble.parameters(),
        'lr': opts.sender_lr
    }]
    executive_sender_params = [{
        'params': executive_sender.parameters(),
        'lr': opts.executive_sender_lr
    }]
    receivers_params = [{
class GumbelSoftmaxMultiAgentEnsemble(nn.Module):
    def __init__(self, agents: List[core.GumbelSoftmaxWrapper]):
        super(GumbelSoftmaxMultiAgentEnsemble, self).__init__()
        self.agents = nn.ModuleList(agents)

    def forward(self, input, agent_indices, **kwargs):
        samples = [agent(input) for agent in self.agents]
        if samples[0].dim() > 2:  # RNN agents: (batch, seq_len, vocab)
            agent_indices = agent_indices.reshape(1, agent_indices.size(0), 1, 1).expand(
                1, agent_indices.size(0), samples[0].size(1), samples[0].size(2))
        else:  # single-symbol agents: (batch, vocab)
            agent_indices = agent_indices.reshape(1, agent_indices.size(0), 1).expand(
                1, agent_indices.size(0), samples[0].size(1))
        # stack per-agent outputs and, for every batch item, keep the sample
        # produced by the agent selected in agent_indices
        samples = torch.stack(samples, dim=0).gather(dim=0, index=agent_indices)
        return samples.squeeze(dim=0)


if __name__ == "__main__":
    from teamwork.agents import Sender

    BATCH_SIZE, INPUT_SIZE, OUTPUT_SIZE = 8, 10, 5
    AGENT_INDICES = torch.LongTensor([0, 1, 0, 1, 0, 1, 0, 1])

    multi_agent = ReinforceMultiAgentEnsemble(
        agents=[core.ReinforceWrapper(Sender(OUTPUT_SIZE, INPUT_SIZE)),
                core.ReinforceWrapper(Sender(OUTPUT_SIZE, INPUT_SIZE))])
    samples, log_probs, entropies = multi_agent(torch.Tensor(BATCH_SIZE, INPUT_SIZE),
                                                agent_indices=AGENT_INDICES)
    assert samples.shape == log_probs.shape == entropies.shape == torch.Size([BATCH_SIZE])

    multi_agent = GumbelSoftmaxMultiAgentEnsemble(
        agents=[core.GumbelSoftmaxWrapper(Sender(OUTPUT_SIZE, INPUT_SIZE)),
                core.GumbelSoftmaxWrapper(Sender(OUTPUT_SIZE, INPUT_SIZE))])
    samples = multi_agent(torch.Tensor(BATCH_SIZE, INPUT_SIZE), agent_indices=AGENT_INDICES)
    assert samples.shape == torch.Size([BATCH_SIZE, OUTPUT_SIZE])