def get_game(opt):
    """Build the single sender/receiver game selected by ``opt.mode``.

    All settings are read from the ``opt`` argument (not from a module-level
    options object), so the function behaves the same regardless of which
    globals happen to exist.

    Args:
        opt: Options object. Reads ``mode``, ``game_size``,
            ``embedding_size``, ``hidden_size`` and ``tau_s``; additionally
            ``vocab_size`` when ``mode == 'rf'`` and ``gs_tau`` when
            ``mode == 'gs'``.

    Returns:
        A ``core.SenderReceiverRnnReinforce`` game for mode ``'rf'`` or a
        ``core.SymbolGameGS`` game for mode ``'gs'``.

    Raises:
        RuntimeError: If ``opt.mode`` is neither ``'rf'`` nor ``'gs'``.
    """
    mode = opt.mode
    # Reject unsupported modes before any agent is constructed.
    if mode not in ('rf', 'gs'):
        raise RuntimeError(f"Unknown training mode: {mode}")

    # Hard-coded sizes shared by the agents and their wrappers.
    feat_size = 4096
    out_hidden_size = 20
    emb_size = 10

    sender = InformedSender(opt.game_size, feat_size,
                            opt.embedding_size, opt.hidden_size,
                            out_hidden_size, temp=opt.tau_s)
    receiver = Receiver(opt.game_size, feat_size,
                        opt.embedding_size, out_hidden_size,
                        reinforce=(mode == 'rf'))

    if mode == 'rf':
        sender = core.RnnSenderReinforce(sender, opt.vocab_size, emb_size,
                                         out_hidden_size, cell="gru",
                                         max_len=2)
        receiver = core.RnnReceiverReinforce(receiver, opt.vocab_size,
                                             emb_size, out_hidden_size,
                                             cell="gru")
        return core.SenderReceiverRnnReinforce(sender, receiver, loss,
                                               sender_entropy_coeff=0.01,
                                               receiver_entropy_coeff=0.01)

    # mode == 'gs': only the sender is wrapped; the receiver is used as built.
    sender = core.GumbelSoftmaxWrapper(sender, temperature=opt.gs_tau)
    return core.SymbolGameGS(sender, receiver, loss_nll)
num_workers=1) assert train_loader or dump_loader, 'Either training or dump data must be specified' sender, receiver, loss = build_model(opts, train_loader, dump_loader) if opts.train_mode.lower() == 'rf': sender = core.RnnSenderReinforce(sender, opts.vocab_size, opts.sender_embedding, opts.sender_hidden, cell=opts.sender_cell, max_len=opts.max_len, num_layers=opts.sender_layers) receiver = core.RnnReceiverReinforce(receiver, opts.vocab_size, opts.receiver_embedding, opts.receiver_hidden, cell=opts.receiver_cell, num_layers=opts.receiver_layers) game = core.SenderReceiverRnnReinforce( sender, receiver, non_differentiable_loss, sender_entropy_coeff=opts.sender_entropy_coeff, receiver_entropy_coeff=opts.receiver_entropy_coeff) elif opts.train_mode.lower() == 'gs': sender = core.RnnSenderGS(sender, opts.vocab_size, opts.sender_embedding, opts.sender_hidden, cell=opts.sender_cell,
def get_my_game(opt):
    """Build a population game with ``opt.pop_size`` sender/receiver pairs.

    All settings are read from the ``opt`` argument (not from a module-level
    options object).

    Args:
        opt: Options object. Reads ``mode``, ``pop_size``, ``multi_head``,
            ``game_size``, ``embedding_size``, ``hidden_size`` and ``tau_s``;
            additionally ``vocab_size``, ``max_len`` and ``pop_mode`` when
            ``mode == 'rf'``, and ``gs_tau`` when ``mode == 'gs'``.

    Returns:
        For mode ``'rf'``: ``core.PopSenderReceiverRnnReinforce`` when
        ``opt.pop_mode == 0``, otherwise
        ``core.PopUncSenderReceiverRnnReinforce`` (with
        ``use_critic_baseline=False`` for ``pop_mode == 1`` and ``True`` for
        any other value). For mode ``'gs'``: ``core.PopSymbolGameGS``.

    Raises:
        RuntimeError: If ``opt.mode`` is neither ``'rf'`` nor ``'gs'``.
    """
    mode = opt.mode
    # Reject unsupported modes before any agent is constructed.
    if mode not in ('rf', 'gs'):
        raise RuntimeError(f"Unknown training mode: {mode}")

    # Hard-coded sizes shared by the agents and their wrappers.
    feat_size = 4096
    out_hidden_size = 20
    emb_size = 10
    pop = opt.pop_size
    use_reinforce = mode == 'rf'

    sender_list = []
    receiver_list = []
    for _ in range(pop):
        # The multi-head flag picks the sender class; both take the same
        # constructor arguments here.
        sender_cls = InformedSenderMultiHead if opt.multi_head else InformedSender
        sender = sender_cls(opt.game_size, feat_size,
                            opt.embedding_size, opt.hidden_size,
                            out_hidden_size, temp=opt.tau_s)
        receiver = MyReceiver(opt.game_size, feat_size,
                              opt.embedding_size, out_hidden_size,
                              reinforce=use_reinforce)

        if use_reinforce:
            sender = core.MyRnnSenderReinforce(sender, opt.vocab_size,
                                               emb_size, out_hidden_size,
                                               multi_head=opt.multi_head,
                                               cell="gru",
                                               max_len=opt.max_len)
            receiver = core.RnnReceiverReinforce(receiver, opt.vocab_size,
                                                 emb_size, out_hidden_size,
                                                 cell="gru")
        else:
            # mode == 'gs': only the sender is wrapped.
            sender = core.GumbelSoftmaxWrapper(sender, temperature=opt.gs_tau)

        sender_list.append(sender)
        receiver_list.append(receiver)

    if not use_reinforce:
        return core.PopSymbolGameGS(sender_list, receiver_list, pop, loss_nll)

    if opt.pop_mode == 0:
        return core.PopSenderReceiverRnnReinforce(
            sender_list, receiver_list, pop, loss,
            sender_entropy_coeff=0.01,
            receiver_entropy_coeff=0.01)

    # pop_mode 1 disables the critic baseline; every other non-zero value
    # enables it.
    return core.PopUncSenderReceiverRnnReinforce(
        sender_list, receiver_list, pop, loss,
        use_critic_baseline=(opt.pop_mode != 1),
        sender_entropy_coeff=0.01,
        receiver_entropy_coeff=0.01)