def __init__(
    self,
    sender: nn.Module,
    receiver: nn.Module,
    loss: Callable,
    train_logging_strategy: Optional[LoggingStrategy] = None,
    test_logging_strategy: Optional[LoggingStrategy] = None,
):
    """
    :param sender: Sender agent. sender.forward() has to output a continuous vector.
    :param receiver: Receiver agent. receiver.forward() has to accept two parameters:
        message and receiver_input. `message` is shaped as (batch_size, vocab_size).
    :param loss: Callable that outputs a differentiable loss, taking the following parameters:
        * sender_input: input to Sender (comes from dataset)
        * message: message sent by Sender
        * receiver_input: input to Receiver from dataset
        * receiver_output: output of Receiver
        * labels: labels that come from dataset
    :param train_logging_strategy, test_logging_strategy: specify which parts of interactions
        to persist for later analysis in the callbacks.
    """
    super(SenderReceiverContinuousCommunication, self).__init__()
    self.sender = sender
    self.receiver = receiver
    self.loss = loss
    self.train_logging_strategy = (
        LoggingStrategy()
        if train_logging_strategy is None
        else train_logging_strategy
    )
    self.test_logging_strategy = (
        LoggingStrategy()
        if test_logging_strategy is None
        else test_logging_strategy
    )
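# Hedged example (not part of the original code): a minimal loss callable with the
# signature documented above. The (loss, aux_dict) return convention is inferred from
# how game.game.loss(...) is unpacked in evaluate_test_set below; the MSE objective
# itself is purely illustrative.
import torch
import torch.nn.functional as F


def example_reconstruction_loss(
    sender_input, message, receiver_input, receiver_output, labels
):
    # Per-sample differentiable loss; EGG-style losses return a tensor of shape (batch,)
    loss = F.mse_loss(receiver_output, sender_input, reduction="none").mean(dim=-1)
    # Auxiliary diagnostics end up in the Interaction's aux dict
    aux = {"acc": torch.zeros(sender_input.size(0))}
    return loss, aux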
def build_game(opts):
    vision_encoder, visual_features_dim = build_vision_encoder(
        model_name=opts.model_name,
        shared_vision=opts.shared_vision,
        pretrain_vision=opts.pretrain_vision,
    )

    loss = get_loss(
        temperature=opts.loss_temperature,
        similarity=opts.similarity,
        loss_type=opts.loss_type,
    )

    train_logging_strategy = LoggingStrategy(False, False, True, True, True, False)
    test_logging_strategy = LoggingStrategy(False, False, True, True, True, False)

    if opts.simclr_sender:
        sender = SimCLRSender(
            input_dim=visual_features_dim,
            hidden_dim=opts.projection_hidden_dim,
            output_dim=opts.projection_output_dim,
            discrete_evaluation=opts.discrete_evaluation_simclr,
        )
        receiver = sender
    else:
        sender = EmSSLSender(
            input_dim=visual_features_dim,
            hidden_dim=opts.projection_hidden_dim,
            output_dim=opts.projection_output_dim,
            temperature=opts.gs_temperature,
            trainable_temperature=opts.train_gs_temperature,
            straight_through=opts.straight_through,
        )
        receiver = Receiver(
            input_dim=visual_features_dim,
            hidden_dim=opts.projection_hidden_dim,
            output_dim=opts.projection_output_dim,
        )

    game = EmComSSLSymbolGame(
        sender,
        receiver,
        loss,
        train_logging_strategy=train_logging_strategy,
        test_logging_strategy=test_logging_strategy,
    )

    game = VisionGameWrapper(game, vision_encoder)
    if opts.distributed_context.is_distributed:
        game = torch.nn.SyncBatchNorm.convert_sync_batchnorm(game)

    return game
def __init__(
    self,
    train_logging_strategy: Optional[LoggingStrategy] = None,
    test_logging_strategy: Optional[LoggingStrategy] = None,
):
    super(Game, self).__init__()
    self.train_logging_strategy = (
        LoggingStrategy()
        if train_logging_strategy is None
        else train_logging_strategy
    )
    self.test_logging_strategy = (
        LoggingStrategy()
        if test_logging_strategy is None
        else test_logging_strategy
    )
def build_game(
    batch_size: int = 32,
    loss_temperature: float = 0.1,
    vision_encoder_name: str = "resnet50",
    output_size: int = 128,
    is_distributed: bool = False,
):
    vision_module, visual_features_dim = get_vision_module(
        encoder_arch=vision_encoder_name
    )
    vision_encoder = VisionModule(vision_module=vision_module)

    train_logging_strategy = LoggingStrategy.minimal()

    assert (
        not batch_size % 2
    ), f"Batch size must be a multiple of 2. Found {batch_size} instead"
    loss = Loss(batch_size, loss_temperature)

    sender = Sender(visual_features_dim=visual_features_dim, output_dim=output_size)
    receiver = Receiver(visual_features_dim=visual_features_dim, output_dim=output_size)

    game = SenderReceiverContinuousCommunication(
        sender, receiver, loss, train_logging_strategy
    )
    game = VisionGameWrapper(game, vision_encoder)

    if is_distributed:
        game = torch.nn.SyncBatchNorm.convert_sync_batchnorm(game)

    return game
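# Hedged usage sketch (assumption, not from this repo): wiring build_game into EGG's
# Trainer, mirroring the main() entry points further down. `train_loader` is a
# placeholder DataLoader and the hyperparameter values are illustrative.
import egg.core as core

game = build_game(batch_size=64, loss_temperature=0.1, vision_encoder_name="resnet50")
optimizer = core.build_optimizer(game.parameters())
trainer = core.Trainer(
    game=game,
    optimizer=optimizer,
    train_data=train_loader,  # assumed loader yielding ((x_i, x_j), labels) pairs
    validation_data=None,
)
trainer.train(n_epochs=10)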
def __init__(
    self,
    sender,
    receiver,
    loss,
    train_logging_strategy: Optional[LoggingStrategy] = None,
    test_logging_strategy: Optional[LoggingStrategy] = None,
):
    super(SenderReceiverBiRnnSTGS, self).__init__()
    self.sender = sender
    self.receiver = receiver
    self.loss = loss
    self.train_logging_strategy = (
        LoggingStrategy()
        if train_logging_strategy is None
        else train_logging_strategy
    )
    self.test_logging_strategy = (
        LoggingStrategy()
        if test_logging_strategy is None
        else test_logging_strategy
    )
def evaluate_test_set(
    game: nn.Module,
    data: torch.utils.data.DataLoader,
    k_means_clusters: KMeans,
    num_clusters: int,
):
    if torch.cuda.is_available():
        game.cuda()
    game.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging_strategy = LoggingStrategy(False, False, True, True, True, False)

    mean_loss = 0.0
    interactions = []
    n_batches = 0
    soft_accuracy, game_accuracy = 0.0, 0.0

    for batch in data:
        batch = move_to(batch, device)
        (x_i, x_j), labels = batch
        if torch.cuda.is_available():
            x_i = x_i.cuda()
            x_j = x_j.cuda()

        with torch.no_grad():
            sender_encoded_input, receiver_encoded_input = game.vision_module(x_i, x_j)
            message, message_like, resnet_output_sender = game.game.sender(
                sender_encoded_input, sender=True
            )

            resnet_output_sender_to_predict = resnet_output_sender.cpu().numpy()
            k_means_labels = torch.from_numpy(
                k_means_clusters.predict(resnet_output_sender_to_predict)
            ).to(device=message_like.device, dtype=torch.int64)
            one_hot_k_means_labels = torch.zeros(
                (message_like.size()[0], num_clusters), device=message_like.device
            )
            one_hot_k_means_labels.scatter_(1, k_means_labels.view(-1, 1), 1)

            receiver_output, resnet_output_recv = game.game.receiver(
                message, receiver_encoded_input
            )

            loss, aux_info = game.game.loss(
                sender_encoded_input,
                message,
                receiver_encoded_input,
                receiver_output,
                labels,
            )

            if hasattr(game.game.sender, "temperature"):
                if isinstance(game.game.sender.temperature, torch.nn.Parameter):
                    temperature = game.game.sender.temperature.detach()
                else:
                    temperature = torch.Tensor([game.game.sender.temperature])
                aux_info["temperature"] = temperature

            aux_info["message_like"] = message_like
            aux_info["kmeans"] = one_hot_k_means_labels
            aux_info["resnet_output_sender"] = resnet_output_sender
            aux_info["resnet_output_recv"] = resnet_output_recv

            interaction = logging_strategy.filtered_interaction(
                sender_input=sender_encoded_input,
                receiver_input=receiver_encoded_input,
                labels=labels,
                receiver_output=receiver_output.detach(),
                message=message,
                message_length=torch.ones(message_like.shape[0]),
                aux=aux_info,
            )
            interaction = interaction.to("cpu")
            interactions.append(interaction)

        mean_loss += loss.mean()
        soft_accuracy += interaction.aux['acc'].mean().item()
        game_accuracy += interaction.aux['game_acc'].mean().item()
        n_batches += 1
        if n_batches % 10 == 0:
            print(f"finished batch {n_batches}")

    print(f"processed {n_batches} in total")
    mean_loss /= n_batches
    soft_accuracy /= n_batches
    game_accuracy /= n_batches
    full_interaction = Interaction.from_iterable(interactions)

    return mean_loss, soft_accuracy, game_accuracy, full_interaction
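# Hedged usage sketch for evaluate_test_set: fit a scikit-learn KMeans model on
# pre-extracted sender ResNet features and score the game on a held-out loader.
# `game`, `sender_features`, and `test_loader` are hypothetical placeholders.
from sklearn.cluster import KMeans

num_clusters = 100
k_means = KMeans(n_clusters=num_clusters, random_state=0).fit(sender_features)
mean_loss, soft_acc, game_acc, interaction = evaluate_test_set(
    game, test_loader, k_means_clusters=k_means, num_clusters=num_clusters
)
print(f"loss={mean_loss:.4f} soft_acc={soft_acc:.4f} game_acc={game_acc:.4f}")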
def main(params):
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(agent=sender,
                                                 vocab_size=opts.vocab_size,
                                                 embed_dim=opts.sender_embedding,
                                                 max_len=opts.max_len,
                                                 num_layers=opts.sender_num_layers,
                                                 num_heads=opts.sender_num_heads,
                                                 hidden_size=opts.sender_hidden,
                                                 force_eos=opts.force_eos,
                                                 generate_style=opts.sender_generate_style,
                                                 causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(sender,
                                         opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
                                         cell=opts.sender_cell, max_len=opts.max_len,
                                         num_layers=opts.sender_num_layers,
                                         force_eos=force_eos)

    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(receiver,
                                                         opts.vocab_size,
                                                         opts.max_len,
                                                         opts.receiver_embedding,
                                                         opts.receiver_num_heads,
                                                         opts.receiver_hidden,
                                                         opts.receiver_num_layers,
                                                         causal=opts.causal_receiver)
    else:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(receiver,
                                                 opts.vocab_size, opts.receiver_embedding,
                                                 opts.receiver_hidden,
                                                 cell=opts.receiver_cell,
                                                 num_layers=opts.receiver_num_layers)

    empty_logger = LoggingStrategy.minimal()
    game = core.SenderReceiverRnnReinforce(sender, receiver, loss,
                                           sender_entropy_coeff=opts.sender_entropy_coeff,
                                           receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                           train_logging_strategy=empty_logger,
                                           length_cost=opts.length_cost)

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [EarlyStopperAccuracy(opts.early_stopping_thr),
                 core.ConsoleLogger(as_json=True, print_train_loss=True)]

    if opts.checkpoint_dir:
        checkpoint_name = f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}'
        callbacks.append(core.CheckpointSaver(checkpoint_path=opts.checkpoint_dir, prefix=checkpoint_name))

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader, callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)

    game.logging_strategy = LoggingStrategy.maximal()  # now log everything
    dump(trainer.game, opts.n_features, device, False)
    core.close()
def build_game(opts):
    train_logging_strategy = LoggingStrategy(
        False, False, False, False, False, False, False
    )
    test_logging_strategy = LoggingStrategy(
        False, False, True, True, True, True, False
    )

    if opts.use_different_architectures:
        vision_module_names = opts.vision_model_names
        print(vision_module_names)
        vision_modules = [
            initialize_vision_module(name=vision_module_names[i], pretrained=True)
            for i in range(opts.n_senders)
        ]
        senders = [
            GumbelSoftmaxWrapper(
                Sender(
                    vision_module=vision_modules[i][0],
                    input_dim=vision_modules[i][1],
                    vocab_size=opts.vocab_size,
                    name=vision_module_names[i],
                ),
                temperature=opts.gs_temperature,
                trainable_temperature=opts.train_gs_temperature,
                straight_through=opts.straight_through,
            )
            for i in range(opts.n_senders)
        ]
        receivers = [
            SymbolReceiverWrapper(
                Receiver(
                    vision_module=vision_modules[i][0],
                    input_dim=vision_modules[i][1],
                    hidden_dim=opts.recv_hidden_dim,
                    output_dim=opts.recv_output_dim,
                    temperature=opts.recv_temperature,
                    name=vision_module_names[i],
                ),
                opts.vocab_size,
                opts.recv_output_dim,
            )
            for i in range(opts.n_recvs)
        ]
    else:
        vision_module, input_dim, name = initialize_vision_module(
            name=opts.vision_model_name, pretrained=True
        )
        senders = [
            GumbelSoftmaxWrapper(
                Sender(
                    vision_module=vision_module,
                    input_dim=input_dim,
                    vocab_size=opts.vocab_size,
                    name=name,
                ),
                temperature=opts.gs_temperature,
                trainable_temperature=opts.train_gs_temperature,
                straight_through=opts.straight_through,
            )
            for _ in range(opts.n_senders)
        ]
        receivers = [
            SymbolReceiverWrapper(
                Receiver(
                    vision_module=vision_module,
                    input_dim=input_dim,
                    hidden_dim=opts.recv_hidden_dim,
                    output_dim=opts.recv_output_dim,
                    temperature=opts.recv_temperature,
                    name=name,
                ),
                opts.vocab_size,
                opts.recv_output_dim,
            )
            for _ in range(opts.n_recvs)
        ]

    agents_loss_sampler = AgentSampler(
        senders,
        receivers,
        [loss],
    )
    game = Game(
        train_logging_strategy=train_logging_strategy,
        test_logging_strategy=test_logging_strategy,
    )
    game = PopulationGame(game, agents_loss_sampler)

    if opts.distributed_context.is_distributed:
        game = torch.nn.SyncBatchNorm.convert_sync_batchnorm(game)

    return game
def main(params):
    opts = get_params(params)
    print(opts, flush=True)

    # For compatibility, after https://github.com/facebookresearch/EGG/pull/130
    # the meaning of `length` changed a bit. Before, it included the EOS symbol; now
    # it doesn't. To ensure that hyperparameters/CL arguments do not change,
    # we subtract it here.
    opts.max_len -= 1

    device = opts.device

    if opts.probs == "uniform":
        probs = np.ones(opts.n_features)
    elif opts.probs == "powerlaw":
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(",")], dtype=np.float32)
    probs /= probs.sum()

    print("the probs are: ", probs, flush=True)

    train_loader = OneHotLoader(
        n_features=opts.n_features,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
        probs=probs,
    )

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == "transformer":
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender,
        )
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
        )

    if opts.receiver_cell == "transformer":
        receiver = Receiver(
            n_features=opts.n_features, n_hidden=opts.receiver_embedding
        )
        receiver = core.TransformerReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            opts.receiver_num_heads,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver,
        )
    else:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers,
        )

    empty_logger = LoggingStrategy.minimal()
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        train_logging_strategy=empty_logger,
        length_cost=opts.length_cost,
    )

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        EarlyStopperAccuracy(opts.early_stopping_thr),
        core.ConsoleLogger(as_json=True, print_train_loss=True),
    ]

    if opts.checkpoint_dir:
        checkpoint_name = f"{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}"
        callbacks.append(
            core.CheckpointSaver(
                checkpoint_path=opts.checkpoint_dir, prefix=checkpoint_name
            )
        )

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)

    game.logging_strategy = LoggingStrategy.maximal()  # now log everything
    dump(trainer.game, opts.n_features, device, False)
    core.close()
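# Typical entry-point guard for EGG zoo training scripts; shown here as a usage note,
# assuming the main() above lives in a stand-alone train module.
if __name__ == "__main__":
    import sys

    main(sys.argv[1:])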