def main(params):
    print(torch.cuda.is_available())
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    # Distribution over the n_features one-hot inputs
    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    #elif opts.probs == "creneau":
    #    ones = np.ones(int(opts.n_features / 2))
    #    tens = 10 * np.ones(opts.n_features - int(opts.n_features / 2))
    #    probs = np.concatenate((tens, ones), axis=0)
    #elif opts.probs == "toy":
    #    fives = 5 * np.ones(int(opts.n_features / 10))
    #    ones = np.ones(opts.n_features - int(opts.n_features / 10))
    #    probs = np.concatenate((fives, ones), axis=0)
    #elif opts.probs == "escalier":
    #    ones = np.ones(int(opts.n_features / 4))
    #    tens = 10 * np.ones(int(opts.n_features / 4))
    #    huns = 100 * np.ones(int(opts.n_features / 4))
    #    thous = 1000 * np.ones(opts.n_features - 3 * int(opts.n_features / 4))
    #    probs = np.concatenate((thous, huns, tens, ones), axis=0)
    else:
        # Comma-separated custom probabilities
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)

    probs /= probs.sum()
    print('the probs are:', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    ### SENDER ###
    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(agent=sender, vocab_size=opts.vocab_size,
                                                 embed_dim=opts.sender_embedding, max_len=opts.max_len,
                                                 num_layers=opts.sender_num_layers,
                                                 num_heads=opts.sender_num_heads,
                                                 hidden_size=opts.sender_hidden,
                                                 force_eos=force_eos,
                                                 generate_style=opts.sender_generate_style,
                                                 causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(sender, opts.vocab_size, opts.sender_embedding,
                                         opts.sender_hidden, cell=opts.sender_cell,
                                         max_len=opts.max_len, num_layers=opts.sender_num_layers,
                                         force_eos=force_eos)

    ### RECEIVER ###
    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(receiver, opts.vocab_size, opts.max_len,
                                                         opts.receiver_embedding, opts.receiver_num_heads,
                                                         opts.receiver_hidden, opts.receiver_num_layers,
                                                         causal=opts.causal_receiver)
    else:
        if not opts.impatient:
            receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
            receiver = core.RnnReceiverDeterministic(receiver, opts.vocab_size, opts.receiver_embedding,
                                                     opts.receiver_hidden, cell=opts.receiver_cell,
                                                     num_layers=opts.receiver_num_layers)
        else:
            receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
            # If impatient 1
            receiver = RnnReceiverImpatient(receiver, opts.vocab_size, opts.receiver_embedding,
                                            opts.receiver_hidden, cell=opts.receiver_cell,
                                            num_layers=opts.receiver_num_layers, max_len=opts.max_len,
                                            n_features=opts.n_features)
            # If impatient 2
            #receiver = RnnReceiverImpatient2(receiver, opts.vocab_size, opts.receiver_embedding,
            #                                 opts.receiver_hidden, cell=opts.receiver_cell,
            #                                 num_layers=opts.receiver_num_layers, max_len=opts.max_len,
            #                                 n_features=opts.n_features)

    ### GAME ###
    if not opts.impatient:
        game = core.SenderReceiverRnnReinforce(sender, receiver, loss,
                                               sender_entropy_coeff=opts.sender_entropy_coeff,
                                               receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                               length_cost=opts.length_cost,
                                               unigram_penalty=opts.unigram_pen, reg=opts.reg)
    else:
        game = SenderImpatientReceiverRnnReinforce(sender, receiver, loss_impatient,
                                                   sender_entropy_coeff=opts.sender_entropy_coeff,
                                                   receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                                   length_cost=opts.length_cost,
                                                   unigram_penalty=opts.unigram_pen, reg=opts.reg)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    for epoch in range(int(opts.n_epochs)):
        print("Epoch: " + str(epoch))

        # Halve the learning rate every 100 epochs (epoch 0 included)
        if epoch % 100 == 0:
            for param_group in trainer.optimizer.param_groups:
                param_group["lr"] /= 2

        trainer.train(n_epochs=1)

        if opts.checkpoint_dir:
            trainer.save_checkpoint(name=f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}')

        if not opts.impatient:
            acc_vec, messages = dump(trainer.game, opts.n_features, device, False, epoch)
        else:
            acc_vec, messages = dump_impatient(trainer.game, opts.n_features, device, False, epoch)

        # ADDITION TO SAVE MESSAGES
        all_messages = np.asarray([m.cpu().numpy() for m in messages])

        # Periodically save agent weights; save messages and per-input accuracy every epoch
        if epoch % 50 == 0:
            torch.save(sender.state_dict(), opts.dir_save + "/sender/sender_weights" + str(epoch) + ".pth")
            torch.save(receiver.state_dict(), opts.dir_save + "/receiver/receiver_weights" + str(epoch) + ".pth")

        #print(acc_vec)
        np.save(opts.dir_save + '/messages/messages_' + str(epoch) + '.npy', all_messages)
        np.save(opts.dir_save + '/accuracy/accuracy_' + str(epoch) + '.npy', acc_vec)

    core.close()
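
# A minimal, self-contained sketch (not part of the original training script) of how
# main() above builds its input prior: the 'uniform' and 'powerlaw' branches, plus the
# comma-separated fallback, followed by normalisation. The helper name is hypothetical
# and only meant for inspecting the distribution in isolation; it assumes numpy is
# available as np, as in main().
def _sketch_input_prior(n_features, kind="powerlaw"):
    if kind == "uniform":
        probs = np.ones(n_features, dtype=np.float32)
    elif kind == "powerlaw":
        # p(i) proportional to 1/i, as in the 'powerlaw' branch of main()
        probs = 1 / np.arange(1, n_features + 1, dtype=np.float32)
    else:
        # comma-separated custom probabilities, as in the fallback branch of main()
        probs = np.array([float(x) for x in kind.split(',')], dtype=np.float32)
    return probs / probs.sum()

# Example: _sketch_input_prior(5) -> approximately [0.438, 0.219, 0.146, 0.109, 0.088]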
def main(params):
    print(torch.cuda.is_available())
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    # Distribution of the inputs: one probability vector over values per attribute
    # (only 'uniform' and 'entropy_test' are handled here)
    if opts.probs == "uniform":
        probs = []
        probs_by_att = np.ones(opts.n_values)
        probs_by_att /= probs_by_att.sum()
        for i in range(opts.n_attributes):
            probs.append(probs_by_att)

    if opts.probs == "entropy_test":
        probs = []
        for i in range(opts.n_attributes):
            probs_by_att = np.ones(opts.n_values)
            probs_by_att[0] = 1 + (1 * i)
            probs_by_att /= probs_by_att.sum()
            probs.append(probs_by_att)

    # Probability that each attribute is present in a sampled input
    if opts.probs_attributes == "uniform":
        probs_attributes = [1] * opts.n_attributes

    if opts.probs_attributes == "uniform_indep":
        probs_attributes = [0.2] * opts.n_attributes

    if opts.probs_attributes == "echelon":
        #probs_attributes = [1. - 0.2 * i for i in range(opts.n_attributes)]
        #probs_attributes = [0.7 + 0.3 / (i + 1) for i in range(opts.n_attributes)]
        probs_attributes = [1., 0.95, 0.9, 0.85]  # hard-coded for 4 attributes

    print("Probability by attribute is:", probs_attributes)

    train_loader = OneHotLoaderCompositionality(n_values=opts.n_values, n_attributes=opts.n_attributes,
                                                batch_size=opts.batch_size * opts.n_attributes,
                                                batches_per_epoch=opts.batches_per_epoch,
                                                probs=probs, probs_attributes=probs_attributes)

    # single batches with 1s on the diag
    test_loader = TestLoaderCompositionality(n_values=opts.n_values, n_attributes=opts.n_attributes)

    ### SENDER ###
    sender = Sender(n_features=opts.n_attributes * opts.n_values, n_hidden=opts.sender_hidden)
    sender = core.RnnSenderReinforce(sender, opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
                                     cell=opts.sender_cell, max_len=opts.max_len,
                                     num_layers=opts.sender_num_layers, force_eos=force_eos)

    ### RECEIVER ###
    receiver = Receiver(n_features=opts.n_values, n_hidden=opts.receiver_hidden)

    if not opts.impatient:
        receiver = RnnReceiverCompositionality(receiver, opts.vocab_size, opts.receiver_embedding,
                                               opts.receiver_hidden, cell=opts.receiver_cell,
                                               num_layers=opts.receiver_num_layers, max_len=opts.max_len,
                                               n_attributes=opts.n_attributes, n_values=opts.n_values)
    else:
        receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
        # If impatient 1
        receiver = RnnReceiverImpatientCompositionality(receiver, opts.vocab_size, opts.receiver_embedding,
                                                        opts.receiver_hidden, cell=opts.receiver_cell,
                                                        num_layers=opts.receiver_num_layers,
                                                        max_len=opts.max_len,
                                                        n_attributes=opts.n_attributes, n_values=opts.n_values)

    ### GAME ###
    if not opts.impatient:
        game = CompositionalitySenderReceiverRnnReinforce(sender, receiver, loss_compositionality,
                                                          sender_entropy_coeff=opts.sender_entropy_coeff,
                                                          n_attributes=opts.n_attributes, n_values=opts.n_values,
                                                          receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                                          length_cost=opts.length_cost,
                                                          unigram_penalty=opts.unigram_pen, reg=opts.reg)
    else:
        game = CompositionalitySenderImpatientReceiverRnnReinforce(sender, receiver, loss_impatient_compositionality,
                                                                   sender_entropy_coeff=opts.sender_entropy_coeff,
                                                                   n_attributes=opts.n_attributes, n_values=opts.n_values,
                                                                   att_weights=opts.att_weights,
                                                                   receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                                                   length_cost=opts.length_cost,
                                                                   unigram_penalty=opts.unigram_pen, reg=opts.reg)

    optimizer = core.build_optimizer(game.parameters())

    trainer = CompoTrainer(n_attributes=opts.n_attributes, n_values=opts.n_values, game=game,
                           optimizer=optimizer, train_data=train_loader, validation_data=test_loader,
                           callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    curr_accs = [0] * 7

    game.att_weights = [1] * game.n_attributes

    for epoch in range(int(opts.n_epochs)):
        print("Epoch: " + str(epoch))

        #if epoch % 100 == 0:
        #    trainer.optimizer.defaults["lr"] /= 2

        trainer.train(n_epochs=1)

        if opts.checkpoint_dir:
            trainer.save_checkpoint(name=f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}')

        if not opts.impatient:
            acc_vec, messages = dump_compositionality(trainer.game, opts.n_attributes, opts.n_values,
                                                      device, False, epoch)
        else:
            acc_vec, messages = dump_impatient_compositionality(trainer.game, opts.n_attributes, opts.n_values,
                                                                device, False, epoch)

        print(acc_vec.mean(0))
        #print(trainer.optimizer.defaults["lr"])

        # ADDITION TO SAVE MESSAGES
        all_messages = np.asarray([m.cpu().numpy() for m in messages])

        # Periodically save agent weights; save messages and per-input accuracy every epoch
        if epoch % 50 == 0:
            torch.save(sender.state_dict(), opts.dir_save + "/sender/sender_weights" + str(epoch) + ".pth")
            torch.save(receiver.state_dict(), opts.dir_save + "/receiver/receiver_weights" + str(epoch) + ".pth")

        np.save(opts.dir_save + '/messages/messages_' + str(epoch) + '.npy', all_messages)
        np.save(opts.dir_save + '/accuracy/accuracy_' + str(epoch) + '.npy', acc_vec)

        print(acc_vec.T)

    core.close()
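
# Hedged sketch of a command-line entry point: this excerpt does not show how main()
# is invoked, so the snippet below is an assumption following the common convention of
# passing sys.argv[1:] to main() (get_params(params) is expected to parse that list).
# Adjust to the project's actual launcher if it differs.
if __name__ == "__main__":
    import sys
    main(sys.argv[1:])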