def retrain_receiver(receiver_generator, sender):
    """Train a freshly generated receiver against ``sender``; return per-epoch accuracies.

    Only the new receiver's parameters are given to the optimizer, so ``sender`` is
    not updated by this optimizer. Training runs for ``opts.n_epochs // 2`` epochs
    (or less if ``EarlyStopperAccuracy`` stops it).

    Relies on module/enclosing-scope names: ``core``, ``loss``, ``opts``,
    ``train_loader``, ``validation_loader``, ``loaders``, ``EarlyStopperAccuracy``,
    ``Evaluator``.

    Args:
        receiver_generator: zero-argument callable returning a new receiver module.
        sender: the sender agent to pair the receiver with.

    Returns:
        list of floats, one mean validation accuracy per recorded validation epoch.
    """
    receiver = receiver_generator()
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=0.0,
        receiver_entropy_coeff=0.0,
    )
    optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[early_stopper, Evaluator(loaders, opts.device, freq=0)],
    )
    trainer.train(n_epochs=opts.n_epochs // 2)

    def _mean_acc(logs):
        # Interaction-style logs keep per-sample accuracies in aux["acc"] and are
        # not subscriptable; reduce them to a plain float so callers can sum and
        # json-serialise the result. Plain dict logs are returned as before.
        if hasattr(logs, "aux"):
            return float(logs.aux["acc"].mean())
        return logs["acc"]

    # Each validation_stats entry is indexed as (loss, logs).
    return [_mean_acc(stats[1]) for stats in early_stopper.validation_stats]
def main(params):
    """Train a sender/receiver pair on MNIST wrapped by ``DoubleMnist``.

    Labels are passed through a 100-entry mapping whose entry ``i`` is ``i % 10``.
    The sender is wrapped in ``AlwaysRelaxedWrapper`` only when neither the linear
    channel nor the softmax non-linearity is enabled. Evaluation uses the MNIST test
    split in batches of 16 * 1024.
    """
    opts = get_params(params)
    print(opts)

    loader_kwargs = {"num_workers": 1, "pin_memory": True} if opts.cuda else {}
    to_tensor = transforms.ToTensor()
    mnist_train = datasets.MNIST("./data", train=True, download=True, transform=to_tensor)
    mnist_test = datasets.MNIST("./data", train=False, download=False, transform=to_tensor)

    n_classes = 10
    label_mapping = torch.LongTensor([i % n_classes for i in range(100)])

    train_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            mnist_train, batch_size=opts.batch_size, shuffle=True, **loader_kwargs
        ),
        label_mapping,
    )
    test_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            mnist_test, batch_size=16 * 1024, shuffle=False, **loader_kwargs
        ),
        label_mapping,
    )

    sender = Sender(
        vocab_size=opts.vocab_size,
        linear_channel=opts.linear_channel == 1,
        softmax_channel=opts.softmax_non_linearity,
    )
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)

    # No linear channel and no softmax non-linearity -> relax the channel.
    needs_relaxation = opts.softmax_non_linearity == 0 and opts.linear_channel == 0
    if needs_relaxation:
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)
    core.close()
def main(params):
    """Train a sender/receiver pair on the bit-transmission game.

    ``opts.mode`` selects the training regime:

    * ``'gs'``       -- Gumbel-Softmax sender, ``SymbolGameGS`` with ``diff_loss``;
    * ``'rf'``       -- REINFORCE sender, deterministic receiver, ``diff_loss``;
    * ``'non_diff'`` -- REINFORCE sender and ``ReinforcedReceiver``, ``non_diff_loss``.

    Sender and receiver get separate learning rates (``opts.sender_lr`` /
    ``opts.receiver_lr``). A ``CallbackEvaluator`` with input intervention runs on
    the test batch, and training may stop early on ``opts.early_stopping_thr``.

    Raises:
        ValueError: if ``opts.mode`` is not one of the modes above.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))
    device = opts.device

    # NOTE(review): `/` yields a float here; confirm OneHotLoader expects that
    # (left unchanged so the number of batches per epoch is not altered).
    train_loader = OneHotLoader(
        n_bits=opts.n_bits,
        bits_s=opts.bits_s,
        bits_r=opts.bits_r,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.n_examples_per_epoch / opts.batch_size,
    )
    test_loader = UniformLoader(n_bits=opts.n_bits, bits_s=opts.bits_s, bits_r=opts.bits_r)
    # Move the fixed evaluation batch to the target device once, up front.
    test_loader.batch = [x.to(device) for x in test_loader.batch]

    sender = Sender(n_bits=opts.n_bits, n_hidden=opts.sender_hidden, vocab_size=opts.vocab_size)

    if opts.mode == 'gs':
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden,
                            vocab_size=opts.vocab_size)
        sender = core.GumbelSoftmaxWrapper(agent=sender, temperature=opts.temperature)
        game = core.SymbolGameGS(sender, receiver, diff_loss)
    elif opts.mode == 'rf':
        sender = core.ReinforceWrapper(agent=sender)
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden,
                            vocab_size=opts.vocab_size)
        receiver = core.ReinforceDeterministicWrapper(agent=receiver)
        game = core.SymbolGameReinforce(sender, receiver, diff_loss,
                                        sender_entropy_coeff=opts.sender_entropy_coeff)
    elif opts.mode == 'non_diff':
        sender = core.ReinforceWrapper(agent=sender)
        receiver = ReinforcedReceiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden,
                                      vocab_size=opts.vocab_size)
        game = core.SymbolGameReinforce(sender, receiver, non_diff_loss,
                                        sender_entropy_coeff=opts.sender_entropy_coeff,
                                        receiver_entropy_coeff=opts.receiver_entropy_coeff)
    else:
        # A real exception: `assert False` is stripped under `python -O`, which
        # would let execution continue with `receiver`/`game` undefined.
        raise ValueError(f'Unknown training mode: {opts.mode}')

    optimizer = torch.optim.Adam([
        dict(params=sender.parameters(), lr=opts.sender_lr),
        dict(params=receiver.parameters(), lr=opts.receiver_lr),
    ])

    loss = game.loss

    intervention = CallbackEvaluator(test_loader, device=device, is_gs=opts.mode == 'gs',
                                     loss=loss, var_length=False, input_intervention=True)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr)

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader, epoch_callback=intervention,
                           as_json=True, early_stopping=early_stopper)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
def main(params):
    """Train a Gumbel-Softmax sender/receiver on ``DoubleMnist``-wrapped MNIST.

    Labels go through a 100-entry mapping whose entry ``i`` is ``i % opts.n_labels``.
    The mapping is printed before training. A ``CallbackEvaluator`` (no input
    intervention) runs on the test split alongside console logging and early
    stopping.
    """
    opts = get_params(params)
    print(opts)

    n_labels = opts.n_labels
    label_mapping = torch.LongTensor([i % n_labels for i in range(100)])
    print('# label mapping', label_mapping.tolist())

    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if opts.cuda else {}
    to_tensor = transforms.ToTensor()

    mnist_train = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
    train_loader = DoubleMnist(
        torch.utils.data.DataLoader(mnist_train, batch_size=opts.batch_size,
                                    shuffle=True, **loader_kwargs),
        label_mapping,
    )

    mnist_test = datasets.MNIST('./data', train=False, transform=to_tensor)
    test_loader = DoubleMnist(
        torch.utils.data.DataLoader(mnist_test, batch_size=16 * 1024,
                                    shuffle=False, **loader_kwargs),
        label_mapping,
    )

    sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_labels,
                        n_hidden=opts.n_hidden)
    sender = core.GumbelSoftmaxWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    intervention = CallbackEvaluator(
        test_loader,
        device=opts.device,
        loss=game.loss,
        is_gs=True,
        var_length=False,
        input_intervention=False,
    )
    callbacks = [
        core.ConsoleLogger(as_json=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
        intervention,
    ]

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader, callbacks=callbacks)
    trainer.train(n_epochs=opts.n_epochs)
    core.close()
def main(params):
    """Train a sender/receiver pair directly on MNIST (10 classes).

    The sender is wrapped in ``AlwaysRelaxedWrapper`` only when neither the linear
    channel nor the softmax non-linearity is enabled. The ``Trainer`` is configured
    through its ``as_json`` / ``early_stopping`` / ``print_train_loss`` keyword
    arguments rather than a ``callbacks`` list.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if opts.cuda else {}
    to_tensor = transforms.ToTensor()
    mnist_train = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
    mnist_test = datasets.MNIST('./data', train=False, download=False, transform=to_tensor)
    n_classes = 10

    batch_size = opts.batch_size
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
                                               shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size,
                                              shuffle=False, **loader_kwargs)

    sender = Sender(
        vocab_size=opts.vocab_size,
        linear_channel=opts.linear_channel == 1,
        softmax_channel=opts.softmax_non_linearity,
    )
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)

    needs_relaxation = opts.softmax_non_linearity == 0 and opts.linear_channel == 0
    if needs_relaxation:
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())
    stopper = EarlyStopperAccuracy(opts.early_stopping_thr)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        as_json=True,
        early_stopping=stopper,
        print_train_loss=True,
    )
    trainer.train(n_epochs=opts.n_epochs)
    core.close()
def main(params):
    """Train a Gumbel-Softmax sender/receiver on MNIST images split by rows.

    ``SplitImages`` divides each image between sender (``opts.sender_rows``) and
    receiver (``opts.receiver_rows``), with ``receiver_bottom=True`` and no
    binarisation. Evaluation uses only the first test batch (``TakeFirstLoader``
    with ``n=1``). A ``CallbackEvaluator`` with input intervention runs during
    training.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if opts.cuda else {}
    to_tensor = transforms.ToTensor()

    raw_train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True, transform=to_tensor),
        batch_size=opts.batch_size, shuffle=True, **loader_kwargs)
    raw_test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=to_tensor),
        batch_size=16 * 1024, shuffle=False, **loader_kwargs)

    n_classes = 10
    split_kwargs = dict(
        rows_sender=opts.sender_rows,
        rows_receiver=opts.receiver_rows,
        binarize=False,
        receiver_bottom=True,
    )
    test_loader = SplitImages(TakeFirstLoader(raw_test_loader, n=1), **split_kwargs)
    train_loader = SplitImages(raw_train_loader, **split_kwargs)

    sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)
    sender = core.GumbelSoftmaxWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    intervention = CallbackEvaluator(test_loader, device=opts.device, loss=game.loss,
                                     is_gs=True, var_length=False, input_intervention=True)
    callbacks = [
        core.ConsoleLogger(as_json=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
        intervention,
    ]

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader, callbacks=callbacks)
    trainer.train(n_epochs=opts.n_epochs)
    core.close()
def main(params):
    """Train a sender/receiver pair on MNIST with a fraction of training labels corrupted.

    ``corrupt_labels_`` is applied in place to the training set with probability
    ``opts.p_corrupt`` and seed ``opts.random_seed + 1``. When ``opts.deeper == 1``,
    the deeper architecture goes to the sender if ``opts.deeper_alice == 1`` and to
    the receiver otherwise. The sender is relaxed with ``AlwaysRelaxedWrapper``
    unless the softmax non-linearity or the linear channel is enabled.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if opts.cuda else {}
    to_tensor = transforms.ToTensor()
    mnist_train = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
    mnist_test = datasets.MNIST('./data', train=False, download=False, transform=to_tensor)
    n_classes = 10

    corrupt_labels_(dataset=mnist_train, p_corrupt=opts.p_corrupt, seed=opts.random_seed + 1)

    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=opts.batch_size,
                                               shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=opts.batch_size,
                                              shuffle=False, **loader_kwargs)

    use_deeper = opts.deeper == 1
    alice_gets_depth = opts.deeper_alice == 1
    deeper_alice = alice_gets_depth and use_deeper
    deeper_bob = not alice_gets_depth and use_deeper

    linear_channel = opts.linear_channel == 1
    softmax_channel = opts.softmax_non_linearity == 1

    sender = Sender(vocab_size=opts.vocab_size, deeper=deeper_alice,
                    linear_channel=linear_channel, softmax_channel=softmax_channel)
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes, deeper=deeper_bob)

    if not softmax_channel and not linear_channel:
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
    ]
    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader, callbacks=callbacks)
    trainer.train(n_epochs=opts.n_epochs)
    core.close()
def main(params):
    """Train a sender with a standard or impatient receiver, snapshotting every epoch.

    Training runs one epoch at a time for ``int(opts.n_epochs)`` epochs. At every
    epoch index divisible by 100 the learning rate is halved. After each epoch:

    * a checkpoint is saved if ``opts.checkpoint_dir`` is set;
    * messages and per-input accuracies from ``dump`` / ``dump_impatient`` are
      saved to ``<dir_save>/messages/`` and ``<dir_save>/accuracy/``;
    * every 50 epochs, sender/receiver weights go to ``<dir_save>/sender/`` and
      ``<dir_save>/receiver/``.
    """
    print(torch.cuda.is_available())
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(agent=sender, vocab_size=opts.vocab_size,
                                                 embed_dim=opts.sender_embedding,
                                                 max_len=opts.max_len,
                                                 num_layers=opts.sender_num_layers,
                                                 num_heads=opts.sender_num_heads,
                                                 hidden_size=opts.sender_hidden,
                                                 force_eos=opts.force_eos,
                                                 generate_style=opts.sender_generate_style,
                                                 causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(sender, opts.vocab_size, opts.sender_embedding,
                                         opts.sender_hidden, cell=opts.sender_cell,
                                         max_len=opts.max_len,
                                         num_layers=opts.sender_num_layers,
                                         force_eos=force_eos)

    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(receiver, opts.vocab_size, opts.max_len,
                                                         opts.receiver_embedding,
                                                         opts.receiver_num_heads,
                                                         opts.receiver_hidden,
                                                         opts.receiver_num_layers,
                                                         causal=opts.causal_receiver)
    else:
        # NOTE(review): this instance is replaced in both branches below. It is kept
        # because module construction presumably draws from the RNG, so removing it
        # would change seeded results.
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        if not opts.impatient:
            receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
            receiver = core.RnnReceiverDeterministic(receiver, opts.vocab_size,
                                                     opts.receiver_embedding,
                                                     opts.receiver_hidden,
                                                     cell=opts.receiver_cell,
                                                     num_layers=opts.receiver_num_layers)
        else:
            receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
            receiver = RnnReceiverImpatient(receiver, opts.vocab_size, opts.receiver_embedding,
                                            opts.receiver_hidden, cell=opts.receiver_cell,
                                            num_layers=opts.receiver_num_layers,
                                            max_len=opts.max_len, n_features=opts.n_features)

    game_kwargs = dict(sender_entropy_coeff=opts.sender_entropy_coeff,
                       receiver_entropy_coeff=opts.receiver_entropy_coeff,
                       length_cost=opts.length_cost,
                       unigram_penalty=opts.unigram_pen,
                       reg=opts.reg)
    if not opts.impatient:
        game = core.SenderReceiverRnnReinforce(sender, receiver, loss, **game_kwargs)
    else:
        game = SenderImpatientReceiverRnnReinforce(sender, receiver, loss_impatient,
                                                   **game_kwargs)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    for epoch in range(int(opts.n_epochs)):
        print("Epoch: " + str(epoch))

        if epoch % 100 == 0:
            # `optimizer.defaults` only seeds param groups added later; the groups
            # that already exist must be updated for the decay to take effect.
            # NOTE(review): this condition also fires at epoch 0, so training
            # effectively starts at half the configured learning rate.
            trainer.optimizer.defaults["lr"] /= 2
            for group in trainer.optimizer.param_groups:
                group["lr"] /= 2

        trainer.train(n_epochs=1)

        if opts.checkpoint_dir:
            trainer.save_checkpoint(name=f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}')

        if not opts.impatient:
            acc_vec, messages = dump(trainer.game, opts.n_features, device, False, epoch)
        else:
            acc_vec, messages = dump_impatient(trainer.game, opts.n_features, device, False,
                                               epoch)

        all_messages = [x.cpu().numpy() for x in messages]
        try:
            all_messages = np.asarray(all_messages)
        except ValueError:
            # Messages may differ in length. NumPy >= 1.24 refuses to build ragged
            # arrays implicitly, so store an object array as older NumPy did.
            all_messages = np.asarray(all_messages, dtype=object)

        if epoch % 50 == 0:
            torch.save(sender.state_dict(),
                       opts.dir_save + "/sender/sender_weights" + str(epoch) + ".pth")
            torch.save(receiver.state_dict(),
                       opts.dir_save + "/receiver/receiver_weights" + str(epoch) + ".pth")

        np.save(opts.dir_save + '/messages/messages_' + str(epoch) + '.npy', all_messages)
        np.save(opts.dir_save + '/accuracy/accuracy_' + str(epoch) + '.npy', acc_vec)

    core.close()
def main(params):
    """Load trained sender/receiver weights and save their messages for analysis.

    Weights are read from ``opts.sender_weights`` / ``opts.receiver_weights`` (mapped
    to CPU). Messages come from ``dump`` (or ``dump_impatient`` when
    ``opts.impatient``). They are written as an ``(n_features, max_len)`` array
    padded with ``-1`` to ``opts.save_dir + "messages_analysis.npy"``.

    Raises:
        IndexError / ValueError: if there are more messages than ``opts.n_features``
            or a message is longer than ``opts.max_len``.
    """
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(agent=sender, vocab_size=opts.vocab_size,
                                                 embed_dim=opts.sender_embedding,
                                                 max_len=opts.max_len,
                                                 num_layers=opts.sender_num_layers,
                                                 num_heads=opts.sender_num_heads,
                                                 hidden_size=opts.sender_hidden,
                                                 force_eos=opts.force_eos,
                                                 generate_style=opts.sender_generate_style,
                                                 causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(sender, opts.vocab_size, opts.sender_embedding,
                                         opts.sender_hidden, cell=opts.sender_cell,
                                         max_len=opts.max_len,
                                         num_layers=opts.sender_num_layers,
                                         force_eos=force_eos)

    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(receiver, opts.vocab_size, opts.max_len,
                                                         opts.receiver_embedding,
                                                         opts.receiver_num_heads,
                                                         opts.receiver_hidden,
                                                         opts.receiver_num_layers,
                                                         causal=opts.causal_receiver)
    elif not opts.impatient:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(receiver, opts.vocab_size,
                                                 opts.receiver_embedding,
                                                 opts.receiver_hidden,
                                                 cell=opts.receiver_cell,
                                                 num_layers=opts.receiver_num_layers)
    else:
        receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
        receiver = RnnReceiverImpatient(receiver, opts.vocab_size, opts.receiver_embedding,
                                        opts.receiver_hidden, cell=opts.receiver_cell,
                                        num_layers=opts.receiver_num_layers,
                                        max_len=opts.max_len, n_features=opts.n_features)

    sender.load_state_dict(torch.load(opts.sender_weights, map_location=torch.device('cpu')))
    receiver.load_state_dict(torch.load(opts.receiver_weights, map_location=torch.device('cpu')))

    game_kwargs = dict(sender_entropy_coeff=opts.sender_entropy_coeff,
                       receiver_entropy_coeff=opts.receiver_entropy_coeff,
                       length_cost=opts.length_cost,
                       unigram_penalty=opts.unigram_pen)
    if not opts.impatient:
        game = core.SenderReceiverRnnReinforce(sender, receiver, loss, **game_kwargs)
    else:
        # NOTE(review): the impatient training script pairs this game with
        # `loss_impatient`; confirm plain `loss` is intended here.
        game = SenderImpatientReceiverRnnReinforce(sender, receiver, loss, **game_kwargs)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    # Test impose message
    if not opts.impatient:
        acc_vec, messages = dump(trainer.game, opts.n_features, device, False)
    else:
        acc_vec, messages = dump_impatient(trainer.game, opts.n_features, device, False,
                                           save_dir=opts.save_dir)

    # Messages can have different lengths, so they are padded directly into a
    # fixed-size array. Going through np.asarray(list_of_messages) first would raise
    # on ragged input with NumPy >= 1.24.
    padded = np.full((opts.n_features, opts.max_len), -1.0)
    for i, message in enumerate(messages):
        message = message.cpu().numpy()
        padded[i, :message.shape[0]] = message

    # NOTE(review): save_dir is used as a plain string prefix; it must end with a
    # path separator to write inside the directory.
    np.save(opts.save_dir + "messages_analysis.npy", padded)
    core.close()
def main(params):
    """Train a sender/receiver pair on one-hot inputs, then print the dumped messages.

    ``opts.probs`` selects the input distribution:

    * ``'uniform'``  -- equal weights;
    * ``'powerlaw'`` -- weights ``1/k``;
    * ``'perso'``    -- linearly decreasing weights ``n, n-1, ..., 1``;
    * otherwise      -- a comma-separated list of weights.

    Weights are normalised to sum to one.
    """
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    elif opts.probs == 'perso':
        probs = opts.n_features + 1 - np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(agent=sender, vocab_size=opts.vocab_size,
                                                 embed_dim=opts.sender_embedding,
                                                 max_len=opts.max_len,
                                                 num_layers=opts.sender_num_layers,
                                                 num_heads=opts.sender_num_heads,
                                                 hidden_size=opts.sender_hidden,
                                                 force_eos=opts.force_eos,
                                                 generate_style=opts.sender_generate_style,
                                                 causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(sender, opts.vocab_size, opts.sender_embedding,
                                         opts.sender_hidden, cell=opts.sender_cell,
                                         max_len=opts.max_len,
                                         num_layers=opts.sender_num_layers,
                                         force_eos=force_eos)

    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(receiver, opts.vocab_size, opts.max_len,
                                                         opts.receiver_embedding,
                                                         opts.receiver_num_heads,
                                                         opts.receiver_hidden,
                                                         opts.receiver_num_layers,
                                                         causal=opts.causal_receiver)
    else:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(receiver, opts.vocab_size,
                                                 opts.receiver_embedding,
                                                 opts.receiver_hidden,
                                                 cell=opts.receiver_cell,
                                                 num_layers=opts.receiver_num_layers)

    game = core.SenderReceiverRnnReinforce(sender, receiver, loss,
                                           sender_entropy_coeff=opts.sender_entropy_coeff,
                                           receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                           length_cost=opts.length_cost)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr),
                                      core.ConsoleLogger(as_json=True, print_train_loss=True)])

    trainer.train(n_epochs=opts.n_epochs)

    all_messages = dump(trainer.game, opts.n_features, device, False)
    print(all_messages)

    core.close()
def main(params):
    """Train a sender/receiver pair on dSprites latents, then probe the learned language.

    After training, if the final validation accuracy exceeds 0.99, the sender is
    frozen (a deep copy wrapped in ``Freezer``). Several fresh receiver architectures
    are then trained against it for 3 seeds each (17, 18, 19). For each run, a JSON
    line with the area under its validation-accuracy curve is printed. Missing
    epochs after early stopping are counted as accuracy 1.0.
    """
    import copy

    opts = get_params(params)

    train_loader, validation_loader = get_dsprites_dataloader(
        path_to_data='egg/zoo/data_loaders/data/dsprites.npz',
        batch_size=opts.batch_size, subsample=opts.subsample, image=False)

    n_dim = opts.n_attributes

    if opts.receiver_cell in ['lstm', 'rnn', 'gru']:
        receiver = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.RnnReceiverDeterministic(receiver, opts.vocab_size + 1,
                                                 opts.receiver_emb, opts.receiver_hidden,
                                                 cell=opts.receiver_cell)
    else:
        raise ValueError(f'Unknown receiver cell, {opts.receiver_cell}')

    if opts.sender_cell in ['lstm', 'rnn', 'gru']:
        sender = Sender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(agent=sender, vocab_size=opts.vocab_size,
                                         embed_dim=opts.sender_emb,
                                         hidden_size=opts.sender_hidden,
                                         max_len=opts.max_len, force_eos=False,
                                         cell=opts.sender_cell)
    else:
        raise ValueError(f'Unknown sender cell, {opts.sender_cell}')

    sender = PlusOneWrapper(sender)
    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        'no': core.baselines.NoBaseline,
        'mean': core.baselines.MeanBaseline,
        'builtin': core.baselines.BuiltInBaseline,
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(sender, receiver, loss,
                                           sender_entropy_coeff=opts.sender_entropy_coeff,
                                           receiver_entropy_coeff=0.0, length_cost=0.0,
                                           baseline_type=baseline)
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    # Only the first element of the first validation batch is used, so there is no
    # need to materialise the whole validation set.
    latent_values = next(iter(validation_loader))[0]
    metrics_evaluator = Metrics(latent_values, opts.device, opts.n_attributes, opts.n_values,
                                opts.vocab_size + 1, freq=opts.stats_freq)

    loaders = [("uniform holdout", validation_loader,
                DiffLoss(opts.n_attributes, opts.n_values))]

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=validation_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True, print_train_loss=False),
                               early_stopper,
                               metrics_evaluator,
                               holdout_evaluator,
                           ])
    trainer.train(n_epochs=opts.n_epochs)

    last_epoch_interaction = early_stopper.validation_stats[-1][1]
    validation_acc = last_epoch_interaction.aux['acc'].mean()

    # Train new agents
    if validation_acc > 0.99:
        def _set_seed(seed):
            import random
            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(sender, receiver, loss,
                                                   sender_entropy_coeff=0.0,
                                                   receiver_entropy_coeff=0.0)
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

            trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                                   validation_data=validation_loader,
                                   callbacks=[early_stopper,
                                              Evaluator(loaders, opts.device, freq=0)])
            trainer.train(n_epochs=opts.n_epochs // 2)

            def _mean_acc(logs):
                # The stored logs are Interactions (see `last_epoch_interaction.aux`
                # above) and cannot be subscripted. Reduce aux["acc"] to a float so
                # the accuracies can be summed and json-serialised. Dict logs are
                # still accepted.
                if hasattr(logs, 'aux'):
                    return float(logs.aux['acc'].mean())
                return logs['acc']

            return [_mean_acc(x[1]) for x in early_stopper.validation_stats]

        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb,
                hidden_size=opts.receiver_hidden, cell='gru')

        def small_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=100, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb, hidden_size=100, cell='gru')

        def tiny_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb, hidden_size=50, cell='gru')

        def nonlinear_receiver_generator():
            return NonLinearReceiver(n_outputs=n_dim, vocab_size=opts.vocab_size + 1,
                                     max_length=opts.max_len, n_hidden=opts.receiver_hidden)

        for name, receiver_generator in [
            ('gru', gru_receiver_generator),
            ('nonlinear', nonlinear_receiver_generator),
            ('tiny_gru', tiny_gru_receiver_generator),
            ('small_gru', small_gru_receiver_generator),
        ]:
            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                # Early-stopped runs count the remaining epochs as perfect accuracy.
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(json.dumps({"mode": "reset", "seed": seed,
                                  "receiver_name": name, "auc": auc}))

    print('---End--')

    core.close()
def main(params):
    """Train a sender/receiver pair on attribute-value data, then probe the learned language.

    Data is the full attribute/value grid, optionally subsampled
    (``opts.density_data > 0``). It is split into train, generalization hold-out and
    a 10% uniform hold-out, then one-hot encoded. If the final validation accuracy
    exceeds 0.99, the sender is frozen and several fresh receiver architectures are
    trained against it for 3 seeds each (17, 18, 19). For each run, a JSON line with
    the area under its validation-accuracy curve is printed.
    """
    import copy

    opts = get_params(params)

    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    if opts.density_data > 0:
        sampled_data = select_subset_V2(
            full_data, opts.density_data, opts.n_attributes, opts.n_values
        )
        full_data = copy.deepcopy(sampled_data)

    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    generalization_holdout, train, uniform_holdout, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout, full_data]
    ]

    train, validation = ScaledDataset(train, opts.data_scaler), ScaledDataset(train, 1)

    generalization_holdout, uniform_holdout, full_data = (
        ScaledDataset(generalization_holdout),
        ScaledDataset(uniform_holdout),
        ScaledDataset(full_data),
    )
    generalization_holdout_loader, uniform_holdout_loader, full_data_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout, full_data]
    ]

    train_loader = DataLoader(train, batch_size=opts.batch_size)
    validation_loader = DataLoader(validation, batch_size=len(validation))

    n_dim = opts.n_attributes * opts.n_values

    if opts.receiver_cell in ["lstm", "rnn", "gru"]:
        receiver = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size + 1,
            opts.receiver_emb,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
    else:
        raise ValueError(f"Unknown receiver cell, {opts.receiver_cell}")

    if opts.sender_cell in ["lstm", "rnn", "gru"]:
        sender = Sender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_emb,
            hidden_size=opts.sender_hidden,
            max_len=opts.max_len,
            cell=opts.sender_cell,
        )
    else:
        raise ValueError(f"Unknown sender cell, {opts.sender_cell}")

    sender = PlusOneWrapper(sender)
    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        "no": core.baselines.NoBaseline,
        "mean": core.baselines.MeanBaseline,
        "builtin": core.baselines.BuiltInBaseline,
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = Metrics(
        validation.examples,
        opts.device,
        opts.n_attributes,
        opts.n_values,
        opts.vocab_size + 1,
        freq=opts.stats_freq,
    )

    loaders = []
    loaders.append(
        (
            "generalization hold out",
            generalization_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values, generalization=True),
        )
    )
    loaders.append(
        (
            "uniform holdout",
            uniform_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values),
        )
    )

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=False),
            early_stopper,
            metrics_evaluator,
            holdout_evaluator,
        ],
    )
    trainer.train(n_epochs=opts.n_epochs)

    last_epoch_interaction = early_stopper.validation_stats[-1][1]
    validation_acc = last_epoch_interaction.aux["acc"].mean()

    # Train new agents
    if validation_acc > 0.99:

        def _set_seed(seed):
            import random
            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(
                sender,
                receiver,
                loss,
                sender_entropy_coeff=0.0,
                receiver_entropy_coeff=0.0,
            )
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(
                opts.early_stopping_thr, validation=True
            )

            trainer = core.Trainer(
                game=game,
                optimizer=optimizer,
                train_data=train_loader,
                validation_data=validation_loader,
                callbacks=[early_stopper, Evaluator(loaders, opts.device, freq=0)],
            )
            trainer.train(n_epochs=opts.n_epochs // 2)

            def _mean_acc(logs):
                # The stored logs are Interactions (see `last_epoch_interaction.aux`
                # above) and cannot be subscripted. Reduce aux["acc"] to a float so
                # the accuracies can be summed and json-serialised. Dict logs are
                # still accepted.
                if hasattr(logs, "aux"):
                    return float(logs.aux["acc"].mean())
                return logs["acc"]

            return [_mean_acc(x[1]) for x in early_stopper.validation_stats]

        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=opts.receiver_hidden,
                cell="gru",
            )

        def small_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=100, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=100,
                cell="gru",
            )

        def tiny_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=50,
                cell="gru",
            )

        def nonlinear_receiver_generator():
            return NonLinearReceiver(
                n_outputs=n_dim,
                vocab_size=opts.vocab_size + 1,
                max_length=opts.max_len,
                n_hidden=opts.receiver_hidden,
            )

        for name, receiver_generator in [
            ("gru", gru_receiver_generator),
            ("nonlinear", nonlinear_receiver_generator),
            ("tiny_gru", tiny_gru_receiver_generator),
            ("small_gru", small_gru_receiver_generator),
        ]:

            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                # Early-stopped runs count the remaining epochs as perfect accuracy.
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(
                    json.dumps(
                        {
                            "mode": "reset",
                            "seed": seed,
                            "receiver_name": name,
                            "auc": auc,
                        }
                    )
                )

    print("---End--")

    core.close()
def main(params):
    """Train a REINFORCE sender/receiver pair on the one-hot reconstruction game.

    ``params`` is parsed with ``get_params``. Sender and receiver each use either
    a transformer or an RNN cell (``opts.sender_cell`` / ``opts.receiver_cell``).
    Training uses accuracy-based early stopping and, when
    ``opts.checkpoint_dir`` is set, a ``CheckpointSaver``. The trained game is
    passed to ``dump`` at the end.
    """
    opts = get_params(params)
    print(opts, flush=True)

    # Compatibility shim: after https://github.com/facebookresearch/EGG/pull/130
    # `max_len` no longer counts the EOS symbol. Decrementing it here keeps old
    # hyperparameters / command-line arguments meaning the same thing.
    opts.max_len -= 1

    device = opts.device

    # Sampling weights over the one-hot inputs, normalised to sum to 1.
    if opts.probs == "powerlaw":
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    elif opts.probs == "uniform":
        probs = np.ones(opts.n_features)
    else:
        # Explicit comma-separated weights, e.g. "3,2,1".
        weights = [float(weight) for weight in opts.probs.split(",")]
        probs = np.array(weights, dtype=np.float32)
    probs /= probs.sum()
    print("the probs are: ", probs, flush=True)

    train_loader = OneHotLoader(
        n_features=opts.n_features,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
        probs=probs,
    )
    # Evaluation loader: single batches with 1s on the diagonal.
    test_loader = UniformLoader(opts.n_features)

    # --- sender ---
    if opts.sender_cell != "transformer":
        sender_agent = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(
            sender_agent,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
        )
    else:
        sender_agent = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(
            agent=sender_agent,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender,
        )

    # --- receiver ---
    if opts.receiver_cell != "transformer":
        receiver_agent = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(
            receiver_agent,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers,
        )
    else:
        receiver_agent = Receiver(
            n_features=opts.n_features, n_hidden=opts.receiver_embedding
        )
        receiver = core.TransformerReceiverDeterministic(
            receiver_agent,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            opts.receiver_num_heads,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver,
        )

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        # Keep the training-time interaction logs minimal.
        train_logging_strategy=LoggingStrategy.minimal(),
        length_cost=opts.length_cost,
    )
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        EarlyStopperAccuracy(opts.early_stopping_thr),
        core.ConsoleLogger(as_json=True, print_train_loss=True),
    ]
    if opts.checkpoint_dir:
        checkpoint_name = (
            f"{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}"
            f"_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}"
            f"_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}"
            f"_max_len{opts.max_len}"
        )
        callbacks.append(
            core.CheckpointSaver(checkpoint_path=opts.checkpoint_dir, prefix=checkpoint_name)
        )

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    # Intended to switch to full logging before dumping. NOTE(review): the game
    # was configured through `train_logging_strategy`; confirm the plain
    # `logging_strategy` attribute set here is actually read by the game.
    game.logging_strategy = LoggingStrategy.maximal()
    dump(trainer.game, opts.n_features, device, False)
    core.close()
def main(params):
    """Train a split multi-message sender on attribute-value data, then probe ease of learning.

    Data:
        All attribute-value combinations are enumerated. When
        ``opts.density_data > 0`` they are sub-sampled with ``select_subset_V2``.
        The data is split into train, generalization-holdout and uniform-holdout
        sets (10% of the remaining data goes to the uniform holdout) and one-hot
        encoded. Training data is scaled by ``opts.data_scaler``; an unscaled copy
        is used for validation.

    Agents:
        A single ``MMSender`` is viewed through two ``SplitWrapper``s (indices 0
        and 1). Each view drives its own ``core.RnnSenderReinforce`` wrapped in
        ``PlusNWrapper(..., 1)``, and the two are joined by
        ``CombineMMRnnSenderReinforce``. The receiver is
        ``core.MMRnnReceiverDeterministic`` over ``opts.vocab_size + 1`` symbols.

    Probe:
        If the final validation accuracy exceeds 0.99, the sender is frozen. Four
        receiver architectures are then retrained from scratch for
        ``opts.n_epochs // 2`` epochs with seeds 17, 18 and 19. For each run, the
        sum of per-epoch validation accuracies (epochs skipped by early stopping
        count as 1.0) is printed as JSON with ``"mode": "reset"``.

    Raises:
        ValueError: if ``opts.receiver_cell`` or ``opts.sender_cell`` is not one
            of 'lstm', 'rnn', 'gru'.
    """
    import copy

    opts = get_params(params)

    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    if opts.density_data > 0:
        sampled_data = select_subset_V2(full_data, opts.density_data,
                                        opts.n_attributes, opts.n_values)
        full_data = copy.deepcopy(sampled_data)

    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    generalization_holdout, train, uniform_holdout, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout, full_data]
    ]

    # Up-sampled training set; validation is the same data at scale 1.
    train, validation = ScaledDataset(train, opts.data_scaler), ScaledDataset(train, 1)
    generalization_holdout, uniform_holdout, full_data = (
        ScaledDataset(generalization_holdout),
        ScaledDataset(uniform_holdout),
        ScaledDataset(full_data),
    )

    generalization_holdout_loader, uniform_holdout_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout]
    ]
    train_loader = DataLoader(train, batch_size=opts.batch_size)
    # The whole validation set is evaluated as one batch.
    validation_loader = DataLoader(validation, batch_size=len(validation))

    n_dim = opts.n_attributes * opts.n_values

    if opts.receiver_cell in ['lstm', 'rnn', 'gru']:
        receiver = MMReceiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.MMRnnReceiverDeterministic(receiver, opts.vocab_size + 1,
                                                   opts.receiver_emb,
                                                   opts.receiver_hidden,
                                                   cell=opts.receiver_cell)
    else:
        raise ValueError(f'Unknown receiver cell, {opts.receiver_cell}')

    if opts.sender_cell in ['lstm', 'rnn', 'gru']:
        sender = MMSender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        s1 = SplitWrapper(sender, 0)
        s2 = SplitWrapper(sender, 1)
        sender1 = core.RnnSenderReinforce(agent=s1, vocab_size=opts.vocab_size,
                                          embed_dim=opts.sender_emb,
                                          hidden_size=opts.sender_hidden,
                                          max_len=opts.max_len, force_eos=False,
                                          cell=opts.sender_cell)
        # NOTE(review): PlusNWrapper(..., 1) presumably shifts emitted symbols by
        # one, matching the receivers built over `opts.vocab_size + 1` symbols.
        sender1 = PlusNWrapper(sender1, 1)
        sender2 = core.RnnSenderReinforce(agent=s2, vocab_size=opts.vocab_size,
                                          embed_dim=opts.sender_emb,
                                          hidden_size=opts.sender_hidden,
                                          max_len=opts.max_len, force_eos=False,
                                          cell=opts.sender_cell)
        sender2 = PlusNWrapper(sender2, 1)
        sender = CombineMMRnnSenderReinforce(sender1, sender2)
    else:
        raise ValueError(f'Unknown sender cell, {opts.sender_cell}')

    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        'no': core.baselines.NoBaseline,
        'mean': core.baselines.MeanBaseline,
        'builtin': core.baselines.BuiltInBaseline
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(sender, receiver, loss,
                                           sender_entropy_coeff=opts.sender_entropy_coeff,
                                           receiver_entropy_coeff=0.0,
                                           length_cost=0.0,
                                           baseline_type=baseline)
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = MMMetrics(validation.examples, opts.device,
                                  opts.n_attributes, opts.n_values,
                                  opts.vocab_size + 1, freq=opts.stats_freq)

    loaders = [
        ("generalization hold out", generalization_holdout_loader,
         DiffLoss(opts.n_attributes, opts.n_values, generalization=True)),
        ("uniform holdout", uniform_holdout_loader,
         DiffLoss(opts.n_attributes, opts.n_values)),
    ]

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    trainer = core.Trainer(game=game, optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=validation_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True, print_train_loss=False),
                               early_stopper,
                               metrics_evaluator,
                               holdout_evaluator
                           ])
    trainer.train(n_epochs=opts.n_epochs)

    # An empty history (e.g. n_epochs == 0) used to raise IndexError; treat it
    # as "not solved" so the retraining probe is skipped.
    validation_stats = early_stopper.validation_stats
    validation_acc = validation_stats[-1][1]['acc'] if validation_stats else 0.0

    # Train new agents only if the original pair solved the task.
    if validation_acc > 0.99:
        def _set_seed(seed):
            """Seed python, torch (CPU and, if available, CUDA) and numpy RNGs."""
            import random
            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        # Retraining runs must not be preempted or write checkpoints.
        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # Freeze the sender and probe how fast a simple receiver learns its language.
        def retrain_receiver(receiver_generator, sender):
            """Train a fresh receiver against `sender`; return per-epoch validation accuracies."""
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(sender, receiver, loss,
                                                   sender_entropy_coeff=0.0,
                                                   receiver_entropy_coeff=0.0)
            # Only the receiver's parameters are optimised.
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

            trainer = core.Trainer(game=game, optimizer=optimizer,
                                   train_data=train_loader,
                                   validation_data=validation_loader,
                                   callbacks=[
                                       early_stopper,
                                       Evaluator(loaders, opts.device, freq=0)
                                   ])
            trainer.train(n_epochs=opts.n_epochs // 2)

            accs = [x[1]['acc'] for x in early_stopper.validation_stats]
            return accs

        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            return core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb,
                hidden_size=opts.receiver_hidden, cell='gru')

        def small_gru_receiver_generator():
            return core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb, hidden_size=50, cell='gru')

        def tiny_gru_receiver_generator():
            return core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=25, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb, hidden_size=25, cell='gru')

        def nonlinear_receiver_generator():
            return MMNonLinearReceiver(n_outputs=n_dim,
                                       vocab_size=opts.vocab_size + 1,
                                       max_length=opts.max_len,
                                       n_hidden=opts.receiver_hidden)

        for name, receiver_generator in [
            ('gru', gru_receiver_generator),
            ('nonlinear', nonlinear_receiver_generator),
            ('tiny_gru', tiny_gru_receiver_generator),
            ('small_gru', small_gru_receiver_generator),
        ]:
            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                # Epochs skipped by early stopping count as perfect accuracy.
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(json.dumps({
                    "mode": "reset",
                    "seed": seed,
                    "receiver_name": name,
                    "auc": auc
                }))

    print('---End--')

    core.close()
def main(params):
    """Compute a "position sieve" for pretrained sender/receiver weights.

    Builds the agents, then loads their weights from ``opts.sender_weights`` and
    ``opts.receiver_weights``.

    For every position ``p`` in ``range(opts.max_len)``, all ``opts.n_features``
    one-hot inputs are evaluated. The evaluation uses ``dump_test_position`` (or
    ``dump_test_position_impatient`` when ``opts.impatient``) with
    ``position=p``. The entry is 1.0 when the receiver's argmax matches the
    input's argmax, and 0.0 otherwise.

    Afterwards, entries after the first 0 symbol of each message returned by
    ``dump`` are set to -1. The ``(n_features, max_len)`` array is saved to
    ``analysis/position_sieve.npy``; the directory is created if needed.
    """
    import os

    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(
            agent=sender, vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding, max_len=opts.max_len,
            num_layers=opts.sender_num_layers, num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            # Same normalised flag as the RNN branch (was the raw opts.force_eos).
            force_eos=force_eos,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(sender, opts.vocab_size, opts.sender_embedding,
                                         opts.sender_hidden, cell=opts.sender_cell,
                                         max_len=opts.max_len,
                                         num_layers=opts.sender_num_layers,
                                         force_eos=force_eos)

    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(
            receiver, opts.vocab_size, opts.max_len, opts.receiver_embedding,
            opts.receiver_num_heads, opts.receiver_hidden, opts.receiver_num_layers,
            causal=opts.causal_receiver)
    else:
        # NOTE(review): overwritten in both branches below; left in place because
        # removing it could shift seeded initialisation of later modules.
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        if not opts.impatient:
            receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
            receiver = core.RnnReceiverDeterministic(
                receiver, opts.vocab_size, opts.receiver_embedding,
                opts.receiver_hidden, cell=opts.receiver_cell,
                num_layers=opts.receiver_num_layers)
        else:
            # Impatient receiver: inner agent maps receiver_hidden -> vocab_size.
            receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
            receiver = RnnReceiverImpatient(
                receiver, opts.vocab_size, opts.receiver_embedding,
                opts.receiver_hidden, cell=opts.receiver_cell,
                num_layers=opts.receiver_num_layers, max_len=opts.max_len,
                n_features=opts.n_features)

    sender.load_state_dict(
        torch.load(opts.sender_weights, map_location=torch.device('cpu')))
    receiver.load_state_dict(
        torch.load(opts.receiver_weights, map_location=torch.device('cpu')))

    if not opts.impatient:
        game = core.SenderReceiverRnnReinforce(
            sender, receiver, loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost, unigram_penalty=opts.unigram_pen)
    else:
        game = SenderImpatientReceiverRnnReinforce(
            sender, receiver, loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost, unigram_penalty=opts.unigram_pen)

    optimizer = core.build_optimizer(game.parameters())

    # No training happens here; the trainer is only used for `trainer.game`.
    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    # Position test: one row per one-hot input, one column per message position.
    position_sieve = np.zeros((opts.n_features, opts.max_len))

    for position in range(opts.max_len):
        dataset = [[torch.eye(opts.n_features).to(device), None]]

        if opts.impatient:
            sender_inputs, messages, receiver_inputs, receiver_outputs, _ = \
                dump_test_position_impatient(trainer.game, dataset, position=position,
                                             voc_size=opts.vocab_size, gs=False,
                                             device=device, variable_length=True)
        else:
            sender_inputs, messages, receiver_inputs, receiver_outputs, _ = \
                dump_test_position(trainer.game, dataset, position=position,
                                   voc_size=opts.vocab_size, gs=False,
                                   device=device, variable_length=True)

        acc_pos = [
            (sender_input.argmax() == receiver_output.argmax()).float().item()
            for sender_input, _message, receiver_output
            in zip(sender_inputs, messages, receiver_outputs)
        ]
        position_sieve[:, position] = np.array(acc_pos)

    # Put -1 for positions after the first 0 symbol of each message.
    _, messages = dump(trainer.game, opts.n_features, device, False)
    for i, message in enumerate(messages):
        message_np = message.cpu().numpy()
        zero_positions = np.where(message_np == 0)[0]
        if zero_positions.shape[0] > 0:
            # Slice to the end == range(first_zero + 1, opts.max_len).
            position_sieve[i, zero_positions[0] + 1:] = -1

    os.makedirs("analysis", exist_ok=True)
    np.save("analysis/position_sieve.npy", position_sieve)

    core.close()
def main(params):
    """Train the compositional attribute/value game epoch by epoch, saving snapshots.

    Inputs are ``opts.n_attributes`` one-hot blocks of ``opts.n_values`` each.

    Options:
        ``opts.probs`` ("uniform" or "entropy_test") sets the per-attribute value
        distribution. ``opts.probs_attributes`` ("uniform", "uniform_indep" or
        "echelon") sets the per-attribute weights. Both are passed to
        ``OneHotLoaderCompositionality``.

    Loop:
        Each epoch trains once, optionally saves a trainer checkpoint, and
        evaluates with ``dump_compositionality`` (or the impatient variant).
        Every 50 epochs the sender/receiver weights, messages and accuracy vector
        are written under ``opts.dir_save``; the subdirectories are created if
        missing.

    Raises:
        ValueError: for an unknown ``opts.probs`` or ``opts.probs_attributes``.
            Previously this surfaced later as a NameError.
    """
    import os

    print(torch.cuda.is_available())
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    # Distribution of the inputs: one value distribution per attribute.
    if opts.probs == "uniform":
        probs_by_att = np.ones(opts.n_values)
        probs_by_att /= probs_by_att.sum()
        # The same array object is shared by every attribute.
        probs = [probs_by_att] * opts.n_attributes
    elif opts.probs == "entropy_test":
        probs = []
        for i in range(opts.n_attributes):
            probs_by_att = np.ones(opts.n_values)
            # Attribute i up-weights value 0, giving decreasing entropy with i.
            probs_by_att[0] = 1 + i
            probs_by_att /= probs_by_att.sum()
            probs.append(probs_by_att)
    else:
        raise ValueError(f"Unknown probs option: {opts.probs!r}")

    if opts.probs_attributes == "uniform":
        probs_attributes = [1] * opts.n_attributes
    elif opts.probs_attributes == "uniform_indep":
        probs_attributes = [0.2] * opts.n_attributes
    elif opts.probs_attributes == "echelon":
        # NOTE(review): hard-coded for 4 attributes; earlier commented variants were
        # 1 - 0.2 * i and 0.7 + 0.3 / (i + 1).
        probs_attributes = [1., 0.95, 0.9, 0.85]
    else:
        raise ValueError(f"Unknown probs_attributes option: {opts.probs_attributes!r}")

    print("Probability by attribute is:", probs_attributes)

    train_loader = OneHotLoaderCompositionality(
        n_values=opts.n_values, n_attributes=opts.n_attributes,
        batch_size=opts.batch_size * opts.n_attributes,
        batches_per_epoch=opts.batches_per_epoch, probs=probs,
        probs_attributes=probs_attributes)

    # single batches with 1s on the diag
    test_loader = TestLoaderCompositionality(n_values=opts.n_values,
                                             n_attributes=opts.n_attributes)

    ### SENDER ###
    sender = Sender(n_features=opts.n_attributes * opts.n_values,
                    n_hidden=opts.sender_hidden)
    sender = core.RnnSenderReinforce(sender, opts.vocab_size, opts.sender_embedding,
                                     opts.sender_hidden, cell=opts.sender_cell,
                                     max_len=opts.max_len,
                                     num_layers=opts.sender_num_layers,
                                     force_eos=force_eos)

    ### RECEIVER ###
    # NOTE(review): overwritten in both branches below; left in place because
    # removing it could shift seeded initialisation of later modules.
    receiver = Receiver(n_features=opts.n_values, n_hidden=opts.receiver_hidden)
    if not opts.impatient:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = RnnReceiverCompositionality(
            receiver, opts.vocab_size, opts.receiver_embedding, opts.receiver_hidden,
            cell=opts.receiver_cell, num_layers=opts.receiver_num_layers,
            max_len=opts.max_len, n_attributes=opts.n_attributes,
            n_values=opts.n_values)
    else:
        receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
        # If impatient 1
        receiver = RnnReceiverImpatientCompositionality(
            receiver, opts.vocab_size, opts.receiver_embedding, opts.receiver_hidden,
            cell=opts.receiver_cell, num_layers=opts.receiver_num_layers,
            max_len=opts.max_len, n_attributes=opts.n_attributes,
            n_values=opts.n_values)

    if not opts.impatient:
        game = CompositionalitySenderReceiverRnnReinforce(
            sender, receiver, loss_compositionality,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            n_attributes=opts.n_attributes, n_values=opts.n_values,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost, unigram_penalty=opts.unigram_pen,
            reg=opts.reg)
    else:
        game = CompositionalitySenderImpatientReceiverRnnReinforce(
            sender, receiver, loss_impatient_compositionality,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            n_attributes=opts.n_attributes, n_values=opts.n_values,
            att_weights=opts.att_weights,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost, unigram_penalty=opts.unigram_pen,
            reg=opts.reg)

    optimizer = core.build_optimizer(game.parameters())

    trainer = CompoTrainer(n_attributes=opts.n_attributes, n_values=opts.n_values,
                           game=game, optimizer=optimizer,
                           train_data=train_loader, validation_data=test_loader,
                           callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    # Uniform attribute weights.
    game.att_weights = [1] * (game.n_attributes)

    # Loop-invariant: depends only on opts.
    checkpoint_name = (
        f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}'
        f'_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}'
        f'_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}'
        f'_max_len{opts.max_len}')

    acc_vec = None
    for epoch in range(int(opts.n_epochs)):
        print("Epoch: " + str(epoch))

        trainer.train(n_epochs=1)
        if opts.checkpoint_dir:
            trainer.save_checkpoint(name=checkpoint_name)

        if not opts.impatient:
            acc_vec, messages = dump_compositionality(
                trainer.game, opts.n_attributes, opts.n_values, device, False, epoch)
        else:
            acc_vec, messages = dump_impatient_compositionality(
                trainer.game, opts.n_attributes, opts.n_values, device, False, epoch)
        print(acc_vec.mean(0))

        # ADDITION TO SAVE MESSAGES
        all_messages = np.asarray([x.cpu().numpy() for x in messages])

        if epoch % 50 == 0:
            for subdir in ("sender", "receiver", "messages", "accuracy"):
                os.makedirs(opts.dir_save + "/" + subdir, exist_ok=True)
            torch.save(sender.state_dict(),
                       opts.dir_save + "/sender/sender_weights" + str(epoch) + ".pth")
            torch.save(receiver.state_dict(),
                       opts.dir_save + "/receiver/receiver_weights" + str(epoch) + ".pth")
            np.save(opts.dir_save + '/messages/messages_' + str((epoch)) + '.npy',
                    all_messages)
            np.save(opts.dir_save + '/accuracy/accuracy_' + str((epoch)) + '.npy',
                    acc_vec)

    # Final accuracy report; skipped when no epoch ran (previously a NameError).
    if acc_vec is not None:
        print(acc_vec.T)

    core.close()
def main(params):
    """Train a REINFORCE sender/receiver pair on the one-hot game and dump the result.

    Sender and receiver each use either a transformer or an RNN cell
    (``opts.sender_cell`` / ``opts.receiver_cell``). ``opts.force_eos == 1`` is
    normalised to a bool and passed to both sender types. Training uses
    accuracy-based early stopping and optional checkpointing. The trained game is
    passed to ``dump`` at the end.
    """
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(
            agent=sender, vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding, max_len=opts.max_len,
            num_layers=opts.sender_num_layers, num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            # Same normalised flag as the RNN branch (was the raw opts.force_eos).
            force_eos=force_eos,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(sender, opts.vocab_size, opts.sender_embedding,
                                         opts.sender_hidden, cell=opts.sender_cell,
                                         max_len=opts.max_len,
                                         num_layers=opts.sender_num_layers,
                                         force_eos=force_eos)

    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(
            receiver, opts.vocab_size, opts.max_len, opts.receiver_embedding,
            opts.receiver_num_heads, opts.receiver_hidden, opts.receiver_num_layers,
            causal=opts.causal_receiver)
    else:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(
            receiver, opts.vocab_size, opts.receiver_embedding, opts.receiver_hidden,
            cell=opts.receiver_cell, num_layers=opts.receiver_num_layers)

    # Keep training-time interaction logs minimal.
    empty_logger = LoggingStrategy.minimal()
    game = core.SenderReceiverRnnReinforce(
        sender, receiver, loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        train_logging_strategy=empty_logger,
        length_cost=opts.length_cost)

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [EarlyStopperAccuracy(opts.early_stopping_thr),
                 core.ConsoleLogger(as_json=True, print_train_loss=True)]

    if opts.checkpoint_dir:
        checkpoint_name = (
            f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}'
            f'_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}'
            f'_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}'
            f'_max_len{opts.max_len}')
        callbacks.append(core.CheckpointSaver(checkpoint_path=opts.checkpoint_dir,
                                              prefix=checkpoint_name))

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader, callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)

    # Intended to switch to full logging before dumping. NOTE(review): confirm the
    # `logging_strategy` attribute is read by the game (it was built with
    # `train_logging_strategy`).
    game.logging_strategy = LoggingStrategy.maximal()
    dump(trainer.game, opts.n_features, device, False)
    core.close()
def main(params):
    """Train noisy agents through a symbol-replacing channel, then run message probes.

    Setup:
        Sender and receiver are the module's own ``RnnSenderReinforce`` /
        ``RnnReceiverDeterministic``, built with ``noise_loc`` / ``noise_scale``
        options. Messages pass through a ``Channel`` built with
        ``p=opts.channel_repl_prob``.

    Output:
        After training, prefix tests (without and with EOS), a suffix test, a
        replacement test and a full ``dump`` are printed. Each one is wrapped in
        its own ``<div id="...">`` ... ``</div>`` section.
    """
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch,
                                probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    # Sender (speaker) agent.
    sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
    sender = RnnSenderReinforce(sender,
                                opts.vocab_size,
                                opts.sender_embedding,
                                opts.sender_hidden,
                                cell=opts.sender_cell,
                                max_len=opts.max_len,
                                num_layers=opts.sender_num_layers,
                                force_eos=force_eos,
                                noise_loc=opts.sender_noise_loc,
                                noise_scale=opts.sender_noise_scale)

    # Receiver (listener) agent.
    receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
    receiver = RnnReceiverDeterministic(receiver,
                                        opts.vocab_size,
                                        opts.receiver_embedding,
                                        opts.receiver_hidden,
                                        cell=opts.receiver_cell,
                                        num_layers=opts.receiver_num_layers,
                                        noise_loc=opts.receiver_noise_loc,
                                        noise_scale=opts.receiver_noise_scale)

    # Channel between the agents.
    channel = Channel(vocab_size=opts.vocab_size, p=opts.channel_repl_prob)

    game = SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        length_cost=opts.length_cost,
        effective_max_len=opts.effective_max_len,
        channel=channel,
        sender_entropy_common_ratio=opts.sender_entropy_common_ratio)

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        EarlyStopperAccuracy(opts.early_stopping_thr),
        core.ConsoleLogger(as_json=True, print_train_loss=True)
    ]

    if opts.checkpoint_dir:
        # Abbreviations used in checkpoint_name:
        #   n_features -> f, vocab_size -> vocab, random_seed -> rs, lr -> lr,
        #   sender_hidden -> shid, receiver_hidden -> rhid,
        #   sender_entropy_coeff -> sentr, length_cost -> reg, max_len -> max_len,
        #   sender_noise_scale -> sscl, receiver_noise_scale -> rscl,
        #   channel_repl_prob -> crp, sender_entropy_common_ratio -> scr
        checkpoint_name = (
            f'{opts.name}' + '_aer' +
            ('_uniform' if opts.probs == 'uniform' else '') +
            f'_f{opts.n_features}' +
            f'_vocab{opts.vocab_size}' +
            f'_rs{opts.random_seed}' +
            f'_lr{opts.lr}' +
            f'_shid{opts.sender_hidden}' +
            f'_rhid{opts.receiver_hidden}' +
            f'_sentr{opts.sender_entropy_coeff}' +
            f'_reg{opts.length_cost}' +
            f'_max_len{opts.max_len}' +
            f'_sscl{opts.sender_noise_scale}' +
            f'_rscl{opts.receiver_noise_scale}' +
            f'_crp{opts.channel_repl_prob}' +
            f'_scr{opts.sender_entropy_common_ratio}')
        callbacks.append(
            core.CheckpointSaver(checkpoint_path=opts.checkpoint_dir,
                                 checkpoint_freq=opts.checkpoint_freq,
                                 prefix=checkpoint_name))

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)

    print('<div id="prefix test without eos">')
    prefix_test(trainer.game, opts.n_features, device, add_eos=False)
    print('</div>')

    print('<div id="prefix test with eos">')
    prefix_test(trainer.game, opts.n_features, device, add_eos=True)
    # This section was previously left unclosed, nesting the following sections inside it.
    print('</div>')

    print('<div id="suffix test">')
    suffix_test(trainer.game, opts.n_features, device)
    print('</div>')

    print('<div id="replacement test">')
    replacement_test(trainer.game, opts.n_features, opts.vocab_size, device)
    print('</div>')

    print('<div id="dump">')
    dump(trainer.game, opts.n_features, device, False)
    print('</div>')
    core.close()
def main(params):
    """Train on attribute-value data and report in- and out-of-distribution metrics.

    Sender and receiver architectures are looked up by name
    (``opts.sender`` / ``opts.receiver``) in
    ``egg.zoo.compo_vs_generalization_ood.archs`` and called with ``opts``.

    Callbacks:
        - console logging of validation loss (validation is the unscaled
          training data);
        - early stopping on validation accuracy;
        - compositionality metrics on validation and generalization-holdout data;
        - per-epoch holdout accuracies.
    """
    opts = get_params(params)
    print(opts)

    full_data, train, uniform_holdout, generalization_holdout = get_data(opts)

    def _one_hot(split):
        return one_hotify(split, opts.n_attributes, opts.n_values)

    generalization_holdout = _one_hot(generalization_holdout)
    train = _one_hot(train)
    uniform_holdout = _one_hot(uniform_holdout)
    full_data = _one_hot(full_data)

    # Training set is up-sampled by data_scaler; validation sees the same
    # examples at scale 1.
    unscaled_train = train
    train = ScaledDataset(unscaled_train, opts.data_scaler)
    validation = ScaledDataset(unscaled_train, 1)

    generalization_holdout = ScaledDataset(generalization_holdout)
    uniform_holdout = ScaledDataset(uniform_holdout)
    full_data = ScaledDataset(full_data)

    generalization_holdout_loader = DataLoader(generalization_holdout,
                                               batch_size=opts.batch_size)
    uniform_holdout_loader = DataLoader(uniform_holdout, batch_size=opts.batch_size)
    full_data_loader = DataLoader(full_data, batch_size=opts.batch_size)
    train_loader = DataLoader(train, batch_size=opts.batch_size, shuffle=True)
    # Whole validation set in a single batch.
    validation_loader = DataLoader(validation, batch_size=len(validation))

    loss = DiffLoss(opts.n_attributes, opts.n_values)

    archs = egg.zoo.compo_vs_generalization_ood.archs
    sender = getattr(archs, opts.sender)(opts)
    receiver = getattr(archs, opts.receiver)(opts)

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=core.baselines.MeanBaseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    def _metrics_on(examples):
        return Metrics(
            examples,
            opts.device,
            opts.n_attributes,
            opts.n_values,
            opts.vocab_size + 1,
            freq=opts.stats_freq,
        )

    metrics_evaluator = _metrics_on(validation.examples)
    metrics_evaluator_generalization_holdout = _metrics_on(generalization_holdout.examples)

    loaders = [
        (
            "generalization hold out",
            generalization_holdout_loader,
            # generalization=False on purpose: zeros must not be ignored here.
            DiffLoss(opts.n_attributes, opts.n_values, generalization=False),
        ),
        (
            "uniform holdout",
            uniform_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values),
        ),
    ]
    holdout_evaluator = Evaluator(loaders, opts.device, freq=1)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    callbacks = [
        # Validation (= unscaled training data) loss.
        core.ConsoleLogger(as_json=True, print_train_loss=False),
        early_stopper,
        # Compositionality metrics at the end of training, validation data.
        metrics_evaluator,
        # Compositionality metrics at the end of training, generalization holdout.
        metrics_evaluator_generalization_holdout,
        # Generalization and uniform holdout accuracies every epoch.
        holdout_evaluator,
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
def main(params):
    """Train the bit-guessing game with an input-intervention evaluation callback.

    Fixed-length messages (``not opts.variable_length``):
        ``opts.mode`` selects the training setup:
        - "gs": Gumbel-Softmax sender;
        - "rf": REINFORCE sender with a deterministic receiver;
        - "non_diff": REINFORCE for both agents with ``non_diff_loss``.

    Variable-length messages:
        Only REINFORCE is supported; any other mode is reset to "rf". The sender
        and receiver each use a transformer or an RNN cell.

    Sender and receiver get separate Adam learning rates
    (``opts.sender_lr`` / ``opts.receiver_lr``).

    Raises:
        ValueError: if ``opts.mode`` is not "gs", "rf" or "non_diff" for
            fixed-length messages. Previously this surfaced as a NameError on
            ``game``.
    """
    opts = get_params(params)
    print(opts)

    device = opts.device

    train_loader = OneHotLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.n_examples_per_epoch / opts.batch_size)

    test_loader = UniformLoader(n_bits=opts.n_bits, bits_s=opts.bits_s, bits_r=opts.bits_r)
    # Move the fixed test batch to the device once, up-front.
    test_loader.batch = [x.to(device) for x in test_loader.batch]

    if not opts.variable_length:
        sender = Sender(n_bits=opts.n_bits, n_hidden=opts.sender_hidden,
                        vocab_size=opts.vocab_size)
        if opts.mode == 'gs':
            sender = core.GumbelSoftmaxWrapper(agent=sender, temperature=opts.temperature)
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver, vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)
            game = core.SymbolGameGS(sender, receiver, diff_loss)
        elif opts.mode == 'rf':
            sender = core.ReinforceWrapper(agent=sender)
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver, vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)
            receiver = core.ReinforceDeterministicWrapper(agent=receiver)
            game = core.SymbolGameReinforce(
                sender, receiver, diff_loss,
                sender_entropy_coeff=opts.sender_entropy_coeff)
        elif opts.mode == 'non_diff':
            sender = core.ReinforceWrapper(agent=sender)
            receiver = ReinforcedReceiver(n_bits=opts.n_bits,
                                          n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver, vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)
            game = core.SymbolGameReinforce(
                sender, receiver, non_diff_loss,
                sender_entropy_coeff=opts.sender_entropy_coeff,
                receiver_entropy_coeff=opts.receiver_entropy_coeff)
        else:
            raise ValueError(
                f"Unknown mode {opts.mode!r}; expected 'gs', 'rf' or 'non_diff'")
    else:
        if opts.mode != 'rf':
            print('Only mode=rf is supported atm')
            opts.mode = 'rf'

        # Both sender-cell branches built these identically, in this order.
        # NOTE(review): this Receiver is replaced below; it is kept because
        # removing it could shift seeded initialisation of later modules.
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
        sender = Sender(
            n_bits=opts.n_bits, n_hidden=opts.sender_hidden,
            vocab_size=opts.sender_hidden)  # TODO: not really vocab

        if opts.sender_cell == 'transformer':
            sender = core.TransformerSenderReinforce(
                agent=sender, vocab_size=opts.vocab_size,
                embed_dim=opts.sender_emb, max_len=opts.max_len,
                num_layers=1, num_heads=1, hidden_size=opts.sender_hidden)
        else:
            sender = core.RnnSenderReinforce(agent=sender,
                                             vocab_size=opts.vocab_size,
                                             embed_dim=opts.sender_emb,
                                             hidden_size=opts.sender_hidden,
                                             max_len=opts.max_len,
                                             cell=opts.sender_cell)

        if opts.receiver_cell == 'transformer':
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_emb)
            receiver = core.TransformerReceiverDeterministic(
                receiver, opts.vocab_size, opts.max_len, opts.receiver_emb,
                num_heads=1, hidden_size=opts.receiver_hidden, num_layers=1)
        else:
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
            receiver = core.RnnReceiverDeterministic(receiver, opts.vocab_size,
                                                     opts.receiver_emb,
                                                     opts.receiver_hidden,
                                                     cell=opts.receiver_cell)

        # (A SenderReceiverRnnGS game used to be built here and immediately
        # overwritten; it was dead code.)
        game = core.SenderReceiverRnnReinforce(
            sender, receiver, diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff)

    # Separate learning rates per agent via Adam parameter groups.
    optimizer = torch.optim.Adam([
        dict(params=sender.parameters(), lr=opts.sender_lr),
        dict(params=receiver.parameters(), lr=opts.receiver_lr)
    ])

    loss = game.loss

    intervention = CallbackEvaluator(test_loader, device=device,
                                     is_gs=opts.mode == 'gs', loss=loss,
                                     var_length=opts.variable_length,
                                     input_intervention=True)

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[core.ConsoleLogger(as_json=True),
                                      EarlyStopperAccuracy(opts.early_stopping_thr),
                                      intervention])

    trainer.train(n_epochs=opts.n_epochs)

    core.close()
def _enumerate_attribute_inputs(n_attributes, n_values):
    """Return every attribute-value combination and its concatenated one-hot encoding.

    Args:
        n_attributes: number of attributes per input.
        n_values: number of values each attribute can take.

    Returns:
        A pair ``(combinations, inputs)``.

        * ``combinations`` is the list of ``itertools.product`` tuples. It
          has ``n_values ** n_attributes`` entries.
        * ``inputs`` stacks, one row per combination, the concatenation of
          one ``n_values``-sized one-hot vector per attribute. Its shape is
          ``(n_values ** n_attributes, n_attributes * n_values)``.
    """
    one_hots = torch.eye(n_values)
    combinations = list(itertools.product(np.arange(n_values), repeat=n_attributes))
    inputs = torch.stack([
        torch.cat([one_hots[value] for value in combination])
        for combination in combinations
    ])
    return combinations, inputs


def main(params):
    """Record, per message position, which attributes the receiver gets right.

    Builds the sender/receiver agents and loads pretrained weights from
    ``opts.sender_weights`` / ``opts.receiver_weights``.

    For every attribute-value combination ``i``, every position
    ``p < opts.max_len`` and every attribute ``a``, it sets
    ``position_sieve[i, p, a]`` to 1 when the receiver output returned by the
    ``dump_test_position_*`` helper (called with ``position=p``) matches the
    true value of attribute ``a``. Otherwise the entry stays 0.

    Positions after the first 0 symbol (presumably EOS) of the sender's
    message for input ``i`` are then set to -1.

    The array is saved to ``analysis/position_sieve.npy``.
    """
    import os

    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    train_loader = OneHotLoader(
        n_features=opts.n_values,
        batch_size=opts.batch_size * opts.n_attributes,
        batches_per_epoch=opts.batches_per_epoch,
        probs=probs,
    )

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_values)

    ### SENDER ###
    sender = Sender(n_features=opts.n_attributes * opts.n_values, n_hidden=opts.sender_hidden)
    sender = core.RnnSenderReinforce(
        sender, opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
        cell=opts.sender_cell, max_len=opts.max_len,
        num_layers=opts.sender_num_layers, force_eos=force_eos)

    ### RECEIVER ###
    # (A Receiver previously built here was always overwritten below; since
    # weights are loaded from disk afterwards, dropping it changes nothing.)
    if not opts.impatient:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = RnnReceiverCompositionality(
            receiver, opts.vocab_size, opts.receiver_embedding, opts.receiver_hidden,
            cell=opts.receiver_cell, num_layers=opts.receiver_num_layers,
            max_len=opts.max_len, n_attributes=opts.n_attributes, n_values=opts.n_values)
    else:
        # If impatient 1
        receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
        receiver = RnnReceiverImpatientCompositionality(
            receiver, opts.vocab_size, opts.receiver_embedding, opts.receiver_hidden,
            cell=opts.receiver_cell, num_layers=opts.receiver_num_layers,
            max_len=opts.max_len, n_attributes=opts.n_attributes, n_values=opts.n_values)

    sender.load_state_dict(torch.load(opts.sender_weights, map_location=torch.device('cpu')))
    receiver.load_state_dict(torch.load(opts.receiver_weights, map_location=torch.device('cpu')))

    # Both game variants take exactly the same keyword arguments.
    game_kwargs = dict(
        sender_entropy_coeff=opts.sender_entropy_coeff,
        n_attributes=opts.n_attributes,
        n_values=opts.n_values,
        att_weights=[1],
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        length_cost=opts.length_cost,
        unigram_penalty=opts.unigram_pen,
        reg=opts.reg,
    )
    if not opts.impatient:
        game = CompositionalitySenderReceiverRnnReinforce(
            sender, receiver, loss_compositionality, **game_kwargs)
    else:
        game = CompositionalitySenderImpatientReceiverRnnReinforce(
            sender, receiver, loss_impatient_compositionality, **game_kwargs)

    optimizer = core.build_optimizer(game.parameters())

    # The trainer is only used to hold the game (trainer.game) for the dumps below.
    trainer = CompoTrainer(
        n_attributes=opts.n_attributes, n_values=opts.n_values, game=game,
        optimizer=optimizer, train_data=train_loader, validation_data=test_loader,
        callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    # Debut test position
    # The full input set does not depend on the position, so build it once.
    combinations, inputs = _enumerate_attribute_inputs(opts.n_attributes, opts.n_values)
    dataset = [[inputs, None]]
    dump_position = (dump_test_position_impatient_compositionality if opts.impatient
                     else dump_test_position_compositionality)

    # One row per input combination, i.e. n_values ** n_attributes rows. The
    # previous n_attributes ** n_values was wrong whenever the two differ.
    position_sieve = np.zeros((len(combinations), opts.max_len, opts.n_attributes))

    for position in range(opts.max_len):
        _, _, _, receiver_outputs, _ = dump_position(
            trainer.game, dataset, position=position, voc_size=opts.vocab_size,
            gs=False, device=device, variable_length=True)
        for i, receiver_output in enumerate(receiver_outputs):
            for j, true_value in enumerate(combinations[i]):
                if receiver_output[j] == true_value:
                    position_sieve[i, position, j] = 1

    # Put -1 for position after message_length
    if not opts.impatient:
        _, messages = dump_compositionality(
            trainer.game, opts.n_attributes, opts.n_values, device, False, 0)
    else:
        _, messages = dump_impatient_compositionality(
            trainer.game, opts.n_attributes, opts.n_values, device, False, 0)

    for i, message in enumerate(messages):
        zero_positions = np.where(message.cpu().numpy() == 0)[0]
        if zero_positions.shape[0] > 0:
            # Every attribute at each position after the first 0 symbol.
            position_sieve[i, zero_positions[0] + 1:opts.max_len] = -1

    # np.save does not create missing directories.
    os.makedirs("analysis", exist_ok=True)
    np.save("analysis/position_sieve.npy", position_sieve)

    core.close()