def to(self, device: torch.device):
    """Move all (nested) tensors of the batch to a specific device.

    Each of ``sender_input``, ``labels``, ``receiver_input`` and
    ``aux_input`` is passed through ``move_to`` and the result is stored back
    on this object. The batch is therefore updated in place.

    :param device: target device handed to ``move_to`` for every field
    :return: this same instance (``self``), so calls can be chained
    """
    # Same field order as the attribute-by-attribute version.
    for field_name in ("sender_input", "labels", "receiver_input", "aux_input"):
        moved = move_to(getattr(self, field_name), device)
        setattr(self, field_name, moved)
    return self
def evaluate(
    game: nn.Module,
    data: torch.utils.data.DataLoader,
    max_batches: int = 800,
):
    """Run ``game`` in eval mode over ``data`` and aggregate loss and accuracies.

    The game is moved to CUDA when available and switched to eval mode; it is
    not switched back afterwards. Every batch is moved to the same device with
    ``move_to`` and unpacked into the game call. The returned interaction is
    moved to the CPU before being stored.

    :param game: module called as ``game(*batch)``; must return
        ``(loss, interaction)`` where ``interaction.aux`` holds ``"acc"`` and
        ``"game_acc"`` entries
    :param data: iterable of batches
    :param max_batches: stop after this many batches; a value ``<= 0``
        disables the cap. The default of 800 keeps the previous behaviour.
    :return: ``(mean_loss, soft_accuracy, game_accuracy, full_interaction)``,
        the first three averaged over the processed batches and the last one
        built with ``Interaction.from_iterable``
    :raises ZeroDivisionError: if ``data`` yields no batch at all
    """
    if torch.cuda.is_available():
        game.cuda()
    game.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    mean_loss = 0.0
    interactions = []
    n_batches = 0
    soft_accuracy, game_accuracy = 0.0, 0.0
    with torch.no_grad():
        for batch in data:
            batch = move_to(batch, device)
            optimized_loss, interaction = game(*batch)

            # Keep the stored interactions off the accelerator.
            interaction = interaction.to("cpu")
            interactions.append(interaction)

            # The loss is accumulated as returned by the game (not converted).
            mean_loss += optimized_loss
            soft_accuracy += interaction.aux["acc"].mean().item()
            game_accuracy += interaction.aux["game_acc"].mean().item()
            n_batches += 1
            if n_batches % 10 == 0:
                print(f"finished batch {n_batches}")

            # when running kmeans, we first feed the train data.
            # given we're clustering only a subset of 100_000 elements from the
            # training data we can stop after 128 (bsz) X 800 (batches) = 102_400 samples
            # (800 is the default value of max_batches)
            if 0 < max_batches <= n_batches:
                break

    print(f"processed {n_batches} batches in total")
    mean_loss /= n_batches
    soft_accuracy /= n_batches
    game_accuracy /= n_batches
    full_interaction = Interaction.from_iterable(interactions)

    return mean_loss, soft_accuracy, game_accuracy, full_interaction
def evaluate_test_set(
    game: nn.Module,
    data: torch.utils.data.DataLoader,
    k_means_clusters: KMeans,
    num_clusters: int,
):
    """Evaluate ``game`` on ``data`` and attach k-means cluster labels.

    For every batch the vision module, sender, receiver and loss of the game
    are called by hand so that intermediate outputs can be logged. The
    sender's third output is fed to ``k_means_clusters.predict`` and the
    predicted cluster ids are stored one-hot encoded under ``aux["kmeans"]``.

    The whole loop runs under ``torch.no_grad()``, so neither the accumulated
    loss nor the stored interactions can keep an autograd graph alive.

    :param game: module exposing ``vision_module`` and ``game.sender`` /
        ``game.receiver`` / ``game.loss``
    :param data: iterable yielding ``((x_i, x_j), labels)`` batches
    :param k_means_clusters: fitted clustering model; only ``predict`` is used
    :param num_clusters: width of the one-hot cluster encoding
    :return: ``(mean_loss, soft_accuracy, game_accuracy, full_interaction)``,
        the first three averaged over all batches
    :raises ZeroDivisionError: if ``data`` yields no batch at all
    """
    if torch.cuda.is_available():
        game.cuda()
    game.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging_strategy = LoggingStrategy(False, False, True, True, True, False)

    mean_loss = 0.0
    interactions = []
    n_batches = 0
    soft_accuracy, game_accuracy = 0.0, 0.0
    with torch.no_grad():
        for batch in data:
            batch = move_to(batch, device)
            (x_i, x_j), labels = batch
            if torch.cuda.is_available():
                x_i = x_i.cuda()
                x_j = x_j.cuda()

            sender_encoded_input, receiver_encoded_input = game.vision_module(x_i, x_j)
            message, message_like, resnet_output_sender = game.game.sender(
                sender_encoded_input, sender=True
            )

            # The clustering model works on NumPy arrays, so round-trip via
            # the CPU and bring the predicted ids back next to message_like.
            resnet_output_sender_to_predict = resnet_output_sender.cpu().numpy()
            k_means_labels = torch.from_numpy(
                k_means_clusters.predict(resnet_output_sender_to_predict)
            ).to(device=message_like.device, dtype=torch.int64)

            one_hot_k_means_labels = torch.zeros(
                (message_like.size()[0], num_clusters), device=message_like.device
            )
            one_hot_k_means_labels.scatter_(1, k_means_labels.view(-1, 1), 1)

            receiver_output, resnet_output_recv = game.game.receiver(
                message, receiver_encoded_input
            )

            loss, aux_info = game.game.loss(
                sender_encoded_input,
                message,
                receiver_encoded_input,
                receiver_output,
                labels,
            )

            # Log the sender temperature as a tensor, whether it is a learnt
            # parameter or a plain number.
            if hasattr(game.game.sender, "temperature"):
                if isinstance(game.game.sender.temperature, torch.nn.Parameter):
                    temperature = game.game.sender.temperature.detach()
                else:
                    temperature = torch.Tensor([game.game.sender.temperature])
                aux_info["temperature"] = temperature

            aux_info["message_like"] = message_like
            aux_info["kmeans"] = one_hot_k_means_labels
            aux_info["resnet_output_sender"] = resnet_output_sender
            aux_info["resnet_output_recv"] = resnet_output_recv

            interaction = logging_strategy.filtered_interaction(
                sender_input=sender_encoded_input,
                receiver_input=receiver_encoded_input,
                labels=labels,
                receiver_output=receiver_output.detach(),
                message=message,
                message_length=torch.ones(message_like.shape[0]),
                aux=aux_info,
            )

            # Keep the stored interactions off the accelerator.
            interaction = interaction.to("cpu")
            interactions.append(interaction)

            mean_loss += loss.mean()
            soft_accuracy += interaction.aux['acc'].mean().item()
            game_accuracy += interaction.aux['game_acc'].mean().item()
            n_batches += 1
            if n_batches % 10 == 0:
                print(f"finished batch {n_batches}")

    print(f"processed {n_batches} in total")
    mean_loss /= n_batches
    soft_accuracy /= n_batches
    game_accuracy /= n_batches
    full_interaction = Interaction.from_iterable(interactions)

    return mean_loss, soft_accuracy, game_accuracy, full_interaction
train_data=train_data, validation_data=validation_data, callbacks=callbacks, ) # validation_data=validation_data, trainer.train(n_epochs=opts.n_epochs) if opts.evaluate: is_gs = "gs" in opts.mode sender_inputs, messages, receiver_inputs, receiver_outputs, labels = core.dump_sender_receiver( game, test_data, is_gs, variable_length=True, device=device) _, _, _, train_receiver_outputs, train_labels = core.dump_sender_receiver( game, train_data, is_gs, variable_length=True, device=device) # Test receiver_outputs = move_to(receiver_outputs, device) labels = move_to(labels, device) receiver_outputs = torch.stack(receiver_outputs) labels = torch.stack(labels) output_is_vector = opts.mode.lower() in set( ["gs-hard", "gs", "rf-deterministic"]) if output_is_vector: tensor_accuracy = receiver_outputs.squeeze().argmax( dim=1) == labels else: tensor_accuracy = receiver_outputs == labels accuracy = torch.mean(tensor_accuracy.float()).item() # Train
def main(params):
    """Train a Gumbel-Softmax sender/receiver game and optionally evaluate it.

    Steps, in order: parse ``params`` with ``get_params``; build the data
    iterators with ``VectorsLoader``; print the baseline message; build the
    agents and the game (only the ``gs`` mode is implemented); train with
    ``core.Trainer``; and, when ``opts.evaluate`` is set, dump the test-set
    interaction, print accuracy / entropy / mutual information and, if
    ``opts.dump_msg_folder`` is set, write a per-sample message log there.

    The training mode is compared case-insensitively everywhere, so the
    evaluation path always agrees with the way the game was built.

    :param params: command-line style parameters forwarded to ``get_params``
    :raises NotImplementedError: if ``opts.mode`` is not ``gs`` (any casing)
    """
    opts = get_params(params)

    device = torch.device("cuda" if opts.cuda else "cpu")

    # Single source of truth for the mode check used below.
    is_gs = opts.mode.lower() == "gs"

    data_loader = VectorsLoader(
        perceptual_dimensions=opts.perceptual_dimensions,
        n_distractors=opts.n_distractors,
        batch_size=opts.batch_size,
        train_samples=opts.train_samples,
        validation_samples=opts.validation_samples,
        test_samples=opts.test_samples,
        shuffle_train_data=opts.shuffle_train_data,
        dump_data_folder=opts.dump_data_folder,
        load_data_path=opts.load_data_path,
        seed=opts.data_seed,
    )

    train_data, validation_data, test_data = data_loader.get_iterators()

    data_loader.upd_cl_options(opts)

    if opts.max_len > 1:
        baseline_msg = 'Cannot yet compute "smart" baseline value for messages of length greater than 1'
    else:
        baseline_msg = (
            f"\n| Baselines measures with {opts.n_distractors} distractors and messages of max_len = {opts.max_len}:\n"
            f"| Dummy random baseline: accuracy = {1 / (opts.n_distractors + 1)}\n"
        )
        # -1 among the perceptual dimensions signals data loaded from a file.
        if -1 not in opts.perceptual_dimensions:
            baseline_msg += f'| "Smart" baseline with perceptual_dimensions {opts.perceptual_dimensions} = {compute_baseline_accuracy(opts.n_distractors, opts.max_len, *opts.perceptual_dimensions)}\n'
        else:
            baseline_msg += '| Data was loaded froman external file, thus no perceptual_dimension vector was provided, "smart baseline" cannot be computed\n'

    print(baseline_msg)

    sender = Sender(n_features=data_loader.n_features, n_hidden=opts.sender_hidden)

    receiver = Receiver(n_features=data_loader.n_features, linear_units=opts.receiver_hidden)

    if is_gs:
        sender = core.RnnSenderGS(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            temperature=opts.temperature,
        )

        receiver = core.RnnReceiverGS(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )

        game = core.SenderReceiverRnnGS(sender, receiver, loss)
    else:
        raise NotImplementedError(f"Unknown training mode, {opts.mode}")

    # Separate learning rates for the two agents.
    optimizer = torch.optim.Adam([
        {"params": game.sender.parameters(), "lr": opts.sender_lr},
        {"params": game.receiver.parameters(), "lr": opts.receiver_lr},
    ])

    callbacks = [core.ConsoleLogger(as_json=True)]
    if is_gs:
        callbacks.append(
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1))

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_data,
        validation_data=validation_data,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    if not opts.evaluate:
        return

    (
        sender_inputs,
        messages,
        receiver_inputs,
        receiver_outputs,
        labels,
    ) = dump_sender_receiver(game, test_data, is_gs, variable_length=True, device=device)

    receiver_outputs = move_to(receiver_outputs, device)
    labels = move_to(labels, device)

    receiver_outputs = torch.stack(receiver_outputs)
    labels = torch.stack(labels)

    tensor_accuracy = receiver_outputs.argmax(dim=1) == labels
    accuracy = torch.mean(tensor_accuracy.float()).item()

    # Distinct sender targets, each keyed by its "-"-joined integer values.
    unique_targets = {
        "-".join(str(int(dim.item())) for dim in elem)
        for elem in sender_inputs
    }

    print(f"| Accuracy on test set: {accuracy}")

    compute_mi_input_msgs(sender_inputs, messages)

    print(f"entropy sender inputs {entropy(sender_inputs)}")
    print(f"mi sender inputs msgs {mutual_info(sender_inputs, messages)}")

    if opts.dump_msg_folder:
        opts.dump_msg_folder.mkdir(exist_ok=True)
        msg_dict = {}

        output_msg = (
            f"messages_{opts.perceptual_dimensions}_vocab_{opts.vocab_size}"
            f"_maxlen_{opts.max_len}_bsize_{opts.batch_size}"
            f"_n_distractors_{opts.n_distractors}_train_size_{opts.train_samples}"
            f"_valid_size_{opts.validation_samples}_test_size_{opts.test_samples}"
            f"_slr_{opts.sender_lr}_rlr_{opts.receiver_lr}_shidden_{opts.sender_hidden}"
            f"_rhidden_{opts.receiver_hidden}_semb_{opts.sender_embedding}"
            f"_remb_{opts.receiver_embedding}_mode_{opts.mode}"
            f"_scell_{opts.sender_cell}_rcell_{opts.receiver_cell}.msg")

        output_file = opts.dump_msg_folder / output_msg
        with open(output_file, "w") as f:
            f.write(f"{opts}\n")
            for (
                    sender_input,
                    message,
                    receiver_input,
                    receiver_output,
                    label,
            ) in zip(sender_inputs, messages, receiver_inputs,
                     receiver_outputs, labels):
                sender_text = ",".join(map(str, sender_input.tolist()))
                message_text = ",".join(map(str, message.tolist()))
                receiver_text = "; ".join(
                    ",".join(map(str, elem))
                    for elem in receiver_input.tolist()
                )
                if is_gs:
                    receiver_output = receiver_output.argmax()
                f.write(
                    f"{sender_text} -> {receiver_text} -> {message_text} -> {receiver_output} (label={label.item()})\n"
                )

                msg_dict[message_text] = msg_dict.get(message_text, 0) + 1

            # Most frequent messages first; ties keep first-seen order.
            sorted_msgs = sorted(msg_dict.items(),
                                 key=operator.itemgetter(1),
                                 reverse=True)

            f.write(
                f"\nUnique target vectors seen by sender: {len(unique_targets)}\n"
            )
            f.write(
                f"Unique messages produced by sender: {len(msg_dict)}\n"
            )
            f.write(f"Messagses: 'msg' : msg_count: {str(sorted_msgs)}\n")
            f.write(f"\nAccuracy: {accuracy}")
# ================ First play the recon game, get the sender ================== ref_game = SymbolicReferGame() ref_game.train(10000) # the argument is the number of training epochs. ref_game.game.eval() recon_game = SymbolicReconGame( training_log='~/GitWS/GameBias/log/recon_train_temp.txt') optimizer = core.build_optimizer(recon_game.game.receiver.parameters()) train_loss = [] test_loss = [] for i in range(5000): acc_list = [] for batch_idx, (target, label) in enumerate(recon_game.train_loader): optimizer.zero_grad() target = move_to(target, recon_game.trainer.device) label = move_to(label, recon_game.trainer.device) msg, _ = ref_game.sender(target) rec_out = recon_game.receiver(msg.detach()) loss, _ = recon_game.loss(target, msg, msg, rec_out, label) acc_list.append(loss.mean().item()) loss.sum().backward() optimizer.step() print('train loss:', np.mean(acc_list)) train_loss.append(np.mean(acc_list)) acc_list = [] for batch_idx, (target, label) in enumerate(recon_game.test_loader): recon_game.game.eval() target = move_to(target, recon_game.trainer.device)
def dump_sender_receiver(
    game: torch.nn.Module,
    dataset: "torch.utils.data.DataLoader",
    gs: bool,
    variable_length: bool,
    device: Optional[torch.device] = None,
):
    """
    A tool to dump the interaction between Sender and Receiver.

    The game is put in eval mode for the duration of the dump and its previous
    train/eval mode is restored afterwards, even if an exception is raised
    while iterating over the dataset.

    :param game: A Game instance exposing ``sender`` and ``receiver``
    :param dataset: Dataset of inputs to be used when analyzing the communication;
        by agreement, each batch is (sender_input, labels) plus optional (receiver_input)
    :param gs: whether Gumbel-Softmax relaxation was used during training
    :param variable_length: whether variable-length communication is used
    :param device: device (e.g. 'cuda') to be used; falls back to
        ``common_opts.device`` when None
    :return: a tuple of five lists
        ``(sender_inputs, messages, receiver_inputs, receiver_outputs, labels)``
        with one entry per sample (``receiver_inputs`` and ``labels`` stay empty
        when the batches do not carry them)
    """
    train_state = game.training  # persist so we restore it back
    game.eval()

    sender_inputs, messages, receiver_inputs, receiver_outputs = [], [], [], []
    labels = []

    try:
        device = device if device is not None else common_opts.device

        with torch.no_grad():
            for batch in dataset:
                # by agreement, each batch is (sender_input, labels) plus optional (receiver_input)
                sender_input = move_to(batch[0], device)
                receiver_input = None if len(batch) == 2 else move_to(batch[2], device)

                message = game.sender(sender_input)

                # Under GS, the only output is a message; under Reinforce, two additional tensors are returned.
                # We don't need them.
                if not gs:
                    message = message[0]

                output = game.receiver(message, receiver_input)
                if not gs:
                    output = output[0]

                if batch[1] is not None:
                    labels.extend(batch[1])

                if isinstance(sender_input, (list, tuple)):
                    # Multi-part sender input: regroup the parts per sample.
                    sender_inputs.extend(zip(*sender_input))
                else:
                    sender_inputs.extend(sender_input)

                if receiver_input is not None:
                    receiver_inputs.extend(receiver_input)

                if gs:
                    message = message.argmax(dim=-1)  # actual symbols instead of one-hot encoded

                if not variable_length:
                    messages.extend(message)
                    receiver_outputs.extend(output)
                    continue

                # A trickier part is to handle EOS in the messages. It also might happen that not every message has EOS.
                # We cut messages at EOS if it is present or return the entire message otherwise. Note, EOS id is always
                # set to 0.
                for i in range(message.size(0)):
                    eos_positions = (message[i, :] == 0).nonzero()
                    message_end = eos_positions[0].item() if eos_positions.size(0) > 0 else -1
                    assert message_end == -1 or message[i, message_end] == 0

                    if message_end < 0:
                        messages.append(message[i, :])
                    else:
                        messages.append(message[i, :message_end + 1])

                    if gs:
                        # With no EOS, message_end is -1, which selects the last step.
                        receiver_outputs.append(output[i, message_end, ...])
                    else:
                        receiver_outputs.append(output[i, ...])
    finally:
        # Restore the caller's train/eval mode on success and on failure alike.
        game.train(mode=train_state)

    return sender_inputs, messages, receiver_inputs, receiver_outputs, labels
raise NotImplementedError(f'Unknown training mode, {opts.mode}') optimizer = torch.optim.Adam([ {'params': game.sender.parameters(), 'lr': opts.sender_lr}, {'params': game.receiver.parameters(), 'lr': opts.receiver_lr} ]) trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_data, validation_data=validation_data, epoch_callback=callback, as_json=opts.output_json) trainer.train(n_epochs=opts.n_epochs) if opts.evaluate: is_gs = opts.mode == 'gs' sender_inputs, messages, receiver_inputs, receiver_outputs, labels = core.dump_sender_receiver(game, test_data, is_gs, variable_length=True, device=device) receiver_outputs = move_to(receiver_outputs, device) labels = move_to(labels, device) receiver_outputs = torch.stack(receiver_outputs) labels = torch.stack(labels) tensor_accuracy = receiver_outputs.argmax(dim=1) == labels accuracy = torch.mean(tensor_accuracy.float()).item() unique_dict = {} for elem in sender_inputs: target = "" for dim in elem: target += f'{str(int(dim.item()))}-' target = target[:-1]