def dump(game, n_features, device, gs_mode):
    """Evaluate `game` on the identity (one-hot) dataset and print a transcript
    per input plus mean accuracies under uniform and powerlaw input weightings."""
    # A tiny single-batch "dataset": one one-hot vector per feature.
    eval_dataset = [[torch.eye(n_features).to(device), None]]
    inputs, channel_messages, _, outputs, _ = core.dump_sender_receiver(
        game, eval_dataset, gs=gs_mode, device=device, variable_length=True)

    # Powerlaw (Zipfian) weights: probability proportional to 1/rank, normalized.
    zipf_weights = 1 / np.arange(1, n_features + 1, dtype=np.float32)
    zipf_weights /= zipf_weights.sum()

    uniform_accuracy = 0.
    powerlaw_accuracy = 0.
    for sender_input, message, receiver_output in zip(inputs, channel_messages, outputs):
        in_symbol = sender_input.argmax()
        out_symbol = receiver_output.argmax()
        hit = (in_symbol == out_symbol).float().item()
        uniform_accuracy += hit
        powerlaw_accuracy += zipf_weights[in_symbol] * hit
        rendered = ",".join(str(token.item()) for token in message)
        print(f'input: {in_symbol.item()} -> message: {rendered} -> output: {out_symbol.item()}',
              flush=True)

    uniform_accuracy /= n_features
    print(f'Mean accuracy wrt uniform distribution is {uniform_accuracy}')
    print(f'Mean accuracy wrt powerlaw distribution is {powerlaw_accuracy}')
    print(json.dumps({'powerlaw': powerlaw_accuracy, 'unif': uniform_accuracy}))
def dump(game, dataset, device, is_gs):
    """Run `game` over `dataset` and print one semicolon-separated record per
    example: sender input; message; receiver output; label."""
    dumped = core.dump_sender_receiver(
        game, dataset, gs=is_gs, device=device, variable_length=True)
    inputs, channel_messages, _, outputs, labels = dumped

    for inp, msg, out, label in zip(inputs, channel_messages, outputs, labels):
        inp_text = ' '.join(str(v) for v in inp.tolist())
        msg_text = ' '.join(str(v) for v in msg.tolist())
        if is_gs:
            # Gumbel-Softmax receivers emit a distribution; report its argmax.
            out = out.argmax()
        print(f'{inp_text};{msg_text};{out};{label.item()}')
def dump(game, dataset, device, is_gs, is_var_length):
    """Print `input -> message -> output` transcripts for `game` on `dataset`,
    thresholding each receiver-output element at 0.5."""
    interaction = core.dump_sender_receiver(
        game, dataset, gs=is_gs, device=device, variable_length=is_var_length)
    sender_inputs, channel_messages, _, receiver_outputs, _ = interaction

    for src, msg, dst in zip(sender_inputs, channel_messages, receiver_outputs):
        src_text = ''.join(str(v) for v in src.tolist())
        if is_var_length:
            # Variable-length messages are rendered symbol-by-symbol with spaces.
            msg = ' '.join(str(v) for v in msg.tolist())
        # Binarize the receiver output elementwise before printing.
        bits = (dst > 0.5).tolist()
        dst_text = ''.join(str(b) for b in bits)
        print(f'{src_text} -> {msg} -> {dst_text}')
def dump(game, n_features, device, gs_mode, pos_m=-2, pos_M=-2):
    """Evaluate `game` on the identity (one-hot) dataset, print accuracies, and
    return the raw emitted messages plus the powerlaw-weighted accuracy.

    Args:
        game: the EGG game to evaluate.
        n_features: number of one-hot input features.
        device: torch device for the evaluation dataset.
        gs_mode: passed through as `gs=` to `core.dump_sender_receiver`.
        pos_m, pos_M: forwarded to `core.dump_sender_receiver`
            (semantics defined by the project's patched dump — TODO confirm).

    Returns:
        (all_messages, powerlaw_acc): array of per-input message arrays and the
        powerlaw-weighted accuracy.
    """
    # tiny "dataset": one one-hot vector per feature
    dataset = [[torch.eye(n_features).to(device), None]]
    sender_inputs, messages, _, receiver_outputs, _ = core.dump_sender_receiver(
        game, dataset, gs=gs_mode, device=device, variable_length=True,
        pos_m=pos_m, pos_M=pos_M)

    # Powerlaw (Zipfian) weights over inputs, normalized to a distribution.
    powerlaw_probs = 1 / np.arange(1, n_features + 1, dtype=np.float32)
    powerlaw_probs /= powerlaw_probs.sum()

    # Collect the raw messages for the caller.
    # (Replaces a manual append loop and a large block of commented-out padding code.)
    all_messages = np.asarray([m.cpu().numpy() for m in messages])

    unif_acc = 0.
    powerlaw_acc = 0.
    for sender_input, receiver_output in zip(sender_inputs, receiver_outputs):
        input_symbol = sender_input.argmax()
        output_symbol = receiver_output.argmax()
        acc = (input_symbol == output_symbol).float().item()
        unif_acc += acc
        powerlaw_acc += powerlaw_probs[input_symbol] * acc

    unif_acc /= n_features
    print(pos_m, pos_M)
    print(f'Mean accuracy wrt uniform distribution is {unif_acc}')
    print(f'Mean accuracy wrt powerlaw distribution is {powerlaw_acc}')
    print(json.dumps({'powerlaw': powerlaw_acc, 'unif': unif_acc}))
    return all_messages, powerlaw_acc
def validation(self, game):
    """Evaluate `game` on `self.dataset`: compute the entropy of the emitted
    messages and the majority-vote accuracy (each message is scored as correct
    for its most frequent associated label).

    Returns:
        dict with keys `codewords_entropy` and `majority_acc`.
    """
    _, messages, _, _, labels = core.dump_sender_receiver(
        game, self.dataset, gs=self.is_gs, device=self.device,
        variable_length=self.var_length)

    entropy_messages = entropy(messages)

    # message value -> {label: occurrence count}
    message_mapping = {}
    for message, label in zip(messages, labels):
        message = message.item()
        label = _hashable_tensor(label)
        if message not in message_mapping:
            message_mapping[message] = {}
        message_mapping[message][label] = message_mapping[message].get(label, 0) + 1

    # Majority vote per message. (Renamed the loop variable: the original
    # reused `labels` and shadowed the dump output above.)
    correct = 0.0
    total = 0.0
    for label_counts in message_mapping.values():
        total += sum(label_counts.values())
        correct += max(label_counts.values())
    majority_accuracy = correct / total

    return dict(
        codewords_entropy=entropy_messages,
        majority_acc=majority_accuracy,
    )
def train_epoch(self,epoch):
    """Run one training epoch over self.train_data and return
    (mean loss, per-key mean of the auxiliary `rest` dicts)."""
    mean_loss = 0
    mean_rest = {}
    n_batches = 0
    self.game.train()
    for batch in self.train_data:
        self.optimizer.zero_grad()
        batch = move_to(batch, self.device)
        # The game returns the loss to optimize plus a dict of auxiliary stats.
        optimized_loss, rest = self.game(*batch)
        mean_rest = _add_dicts(mean_rest, rest)
        optimized_loss.backward()
        self.optimizer.step()
        n_batches += 1
        # Accumulated as a tensor; .item() is taken once at the end.
        mean_loss += optimized_loss
        ### ADDITION TO CONTROL THE MESSAGES
        # Debug instrumentation: dump the sender's messages for the 20 one-hot
        # inputs and save them to disk after every batch.
        # NOTE(review): the save filename embeds n_batches, which implies this
        # block runs per-batch (inside the loop) — confirm against the original
        # formatting; it is expensive (a full dump per training batch).
        # NOTE(review): the hard-coded 20 presumably matches n_features — verify.
        import egg.core as core
        dataset_m = [[torch.eye(20).to(self.device), None]]
        sender_inputs, messages, receiver_inputs, receiver_outputs, _ = \
            core.dump_sender_receiver(self.game, dataset_m, gs=False,
                                      device=self.device, variable_length=True)
        all_messages=[]
        for x in messages:
            x = x.cpu().numpy()
            all_messages.append(x)
        all_messages = np.asarray(all_messages)
        np.save('messages'+str(epoch)+'_'+str(n_batches)+'.npy',all_messages)
        ####
    mean_loss /= n_batches
    mean_rest = _div_dict(mean_rest, n_batches)
    return mean_loss.item(), mean_rest
# End-to-end script: build data loaders and Gumbel-Softmax RNN agents, train
# the game, then print a transcript of the test set.

# Separate train/test streams; the test loader is pinned to a fixed seed.
train_loader = SequenceLoader(max_n=opts.max_n, batch_size=opts.batch_size,
                              batches_per_epoch=opts.batches_per_epoch)
test_loader = SequenceLoader(max_n=opts.max_n, batch_size=opts.batch_size,
                             batches_per_epoch=opts.batches_per_epoch, seed=7)

encoder = Encoder(n_hidden=opts.sender_hidden, emb_dim=opts.sender_embedding,
                  cell=opts.sender_cell,
                  vocab_size=3)  # only 3 symbols in the incoming data

# Wrap the encoder/receiver cores into EGG's Gumbel-Softmax RNN agents.
sender = core.RnnSenderGS(encoder, opts.vocab_size, opts.sender_embedding,
                          opts.sender_hidden, cell=opts.sender_cell,
                          max_len=opts.max_len, temperature=opts.temperature)

receiver = Receiver(opts.receiver_hidden)
receiver = core.RnnReceiverGS(receiver, opts.vocab_size, opts.receiver_embedding,
                              opts.receiver_hidden, cell=opts.receiver_cell)

game = core.SenderReceiverRnnGS(sender, receiver, loss)
optimizer = core.build_optimizer(game.parameters())
trainer = core.Trainer(game=game, optimizer=optimizer,
                       train_data=train_loader, validation_data=test_loader)
trainer.train(n_epochs=opts.n_epochs)

# Post-training transcript: each sender input is a (sequence, length) pair;
# only the first `l` elements of the padded sequence are meaningful.
sender_inputs, messages, _, receiver_outputs, labels = \
    core.dump_sender_receiver(game, test_loader, gs=True, device=device,
                              variable_length=True)
for (seq, l), message, output, label in zip(sender_inputs, messages,
                                            receiver_outputs, labels):
    print(f'{seq[:l]} -> {message} -> {output.argmax()} (label = {label})')

core.close()
callbacks = [core.ConsoleLogger(print_train_loss=True, as_json=True)] if opts.mode.lower() == "gs": callbacks.append( core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1)) trainer = core.Trainer( game=game, optimizer=optimizer, train_data=train_data, validation_data=validation_data, callbacks=callbacks, ) # validation_data=validation_data, trainer.train(n_epochs=opts.n_epochs) if opts.evaluate: is_gs = "gs" in opts.mode sender_inputs, messages, receiver_inputs, receiver_outputs, labels = core.dump_sender_receiver( game, test_data, is_gs, variable_length=True, device=device) _, _, _, train_receiver_outputs, train_labels = core.dump_sender_receiver( game, train_data, is_gs, variable_length=True, device=device) # Test receiver_outputs = move_to(receiver_outputs, device) labels = move_to(labels, device) receiver_outputs = torch.stack(receiver_outputs) labels = torch.stack(labels) output_is_vector = opts.mode.lower() in set( ["gs-hard", "gs", "rf-deterministic"]) if output_is_vector: tensor_accuracy = receiver_outputs.squeeze().argmax(