    def __init__(
        self,
        sender: nn.Module,
        receiver: nn.Module,
        loss: Callable,
        train_logging_strategy: Optional[LoggingStrategy] = None,
        test_logging_strategy: Optional[LoggingStrategy] = None,
    ):
        """
        :param sender: Sender agent. sender.forward() has to output a continuous vector.
        :param receiver: Receiver agent. receiver.forward() has to accept two parameters:
            message and receiver_input.
            `message` is shaped as (batch_size, vocab_size).
        :param loss: Callable that outputs differentiable loss, takes the following parameters:
          * sender_input: input to Sender (comes from dataset)
          * message: message sent from Sender
          * receiver_input: input to Receiver from dataset
          * receiver_output: output of Receiver
          * labels: labels that come from dataset
        :param train_logging_strategy: specifies which parts of train interactions to persist
            for later analysis in the callbacks.
        :param test_logging_strategy: specifies which parts of test interactions to persist
            for later analysis in the callbacks.
        """
        super(SenderReceiverContinuousCommunication, self).__init__()
        self.sender = sender
        self.receiver = receiver
        self.loss = loss

        self.train_logging_strategy = (LoggingStrategy()
                                       if train_logging_strategy is None else
                                       train_logging_strategy)
        self.test_logging_strategy = (LoggingStrategy()
                                      if test_logging_strategy is None else
                                      test_logging_strategy)
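
The loss callable described in the docstring above is not defined in this excerpt. The sketch below is a hedged illustration only: it assumes the receiver outputs class logits and that labels are integer class indices; the function name and the cross-entropy computation are not taken from the original code, only the parameter order and the (loss, aux) return convention visible in the other snippets.

# Hedged sketch of a loss matching the documented signature; illustrative only.
import torch.nn.functional as F


def example_loss(sender_input, message, receiver_input, receiver_output, labels):
    # receiver_output: (batch_size, n_classes) logits; labels: (batch_size,) class indices
    nll = F.cross_entropy(receiver_output, labels, reduction="none")
    acc = (receiver_output.argmax(dim=1) == labels).float()
    # Return a per-sample loss plus an aux dict, as the other examples unpack: loss, aux = game.loss(...)
    return nll, {"acc": acc}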
Example #2
def build_game(opts):
    vision_encoder, visual_features_dim = build_vision_encoder(
        model_name=opts.model_name,
        shared_vision=opts.shared_vision,
        pretrain_vision=opts.pretrain_vision,
    )

    loss = get_loss(
        temperature=opts.loss_temperature,
        similarity=opts.similarity,
        loss_type=opts.loss_type,
    )

    train_logging_strategy = LoggingStrategy(False, False, True, True, True,
                                             False)
    test_logging_strategy = LoggingStrategy(False, False, True, True, True,
                                            False)

    if opts.simclr_sender:
        sender = SimCLRSender(
            input_dim=visual_features_dim,
            hidden_dim=opts.projection_hidden_dim,
            output_dim=opts.projection_output_dim,
            discrete_evaluation=opts.discrete_evaluation_simclr,
        )
        receiver = sender
    else:
        sender = EmSSLSender(
            input_dim=visual_features_dim,
            hidden_dim=opts.projection_hidden_dim,
            output_dim=opts.projection_output_dim,
            temperature=opts.gs_temperature,
            trainable_temperature=opts.train_gs_temperature,
            straight_through=opts.straight_through,
        )
        receiver = Receiver(
            input_dim=visual_features_dim,
            hidden_dim=opts.projection_hidden_dim,
            output_dim=opts.projection_output_dim,
        )

    game = EmComSSLSymbolGame(
        sender,
        receiver,
        loss,
        train_logging_strategy=train_logging_strategy,
        test_logging_strategy=test_logging_strategy,
    )

    game = VisionGameWrapper(game, vision_encoder)
    if opts.distributed_context.is_distributed:
        game = torch.nn.SyncBatchNorm.convert_sync_batchnorm(game)

    return game
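
A hedged sketch of how this build_game might be invoked: the attribute names on opts are read directly from the function body above, while the concrete values (and the use of SimpleNamespace instead of the project's own argument parser) are purely illustrative.

from types import SimpleNamespace

opts = SimpleNamespace(
    model_name="resnet50",
    shared_vision=False,
    pretrain_vision=True,
    loss_temperature=0.1,
    similarity="cosine",        # illustrative value
    loss_type="xent",           # illustrative value
    simclr_sender=False,
    discrete_evaluation_simclr=False,
    projection_hidden_dim=2048,
    projection_output_dim=2048,
    gs_temperature=1.0,
    train_gs_temperature=False,
    straight_through=False,
    distributed_context=SimpleNamespace(is_distributed=False),
)
game = build_game(opts)  # takes the EmSSLSender/Receiver branch, since simclr_sender is False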
Example #3
    def __init__(
        self,
        train_logging_strategy: Optional[LoggingStrategy] = None,
        test_logging_strategy: Optional[LoggingStrategy] = None,
    ):
        super(Game, self).__init__()

        self.train_logging_strategy = (LoggingStrategy()
                                       if train_logging_strategy is None else
                                       train_logging_strategy)
        self.test_logging_strategy = (LoggingStrategy()
                                      if test_logging_strategy is None else
                                      test_logging_strategy)
Example #4
def build_game(
    batch_size: int = 32,
    loss_temperature: float = 0.1,
    vision_encoder_name: str = "resnet50",
    output_size: int = 128,
    is_distributed: bool = False,
):
    vision_module, visual_features_dim = get_vision_module(
        encoder_arch=vision_encoder_name)
    vision_encoder = VisionModule(vision_module=vision_module)

    train_logging_strategy = LoggingStrategy.minimal()
    assert (not batch_size %
            2), f"Batch size must be multiple of 2. Found {batch_size} instead"

    loss = Loss(batch_size, loss_temperature)

    sender = Sender(visual_features_dim=visual_features_dim,
                    output_dim=output_size)
    receiver = Receiver(visual_features_dim=visual_features_dim,
                        output_dim=output_size)

    game = SenderReceiverContinuousCommunication(sender, receiver, loss,
                                                 train_logging_strategy)

    game = VisionGameWrapper(game, vision_encoder)
    if is_distributed:
        game = torch.nn.SyncBatchNorm.convert_sync_batchnorm(game)

    return game
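
Since this build_game variant takes plain keyword arguments, calling it is straightforward; the values below simply restate its defaults (a usage sketch, not taken from the original file).

game = build_game(
    batch_size=32,               # must be even, per the assert above
    loss_temperature=0.1,
    vision_encoder_name="resnet50",
    output_size=128,
    is_distributed=False,
)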
Example #5
    def __init__(
        self,
        sender,
        receiver,
        loss,
        train_logging_strategy: Optional[LoggingStrategy] = None,
        test_logging_strategy: Optional[LoggingStrategy] = None,
    ):
        super(SenderReceiverBiRnnSTGS, self).__init__()
        self.sender = sender
        self.receiver = receiver
        self.loss = loss
        self.train_logging_strategy = (LoggingStrategy()
                                       if train_logging_strategy is None else
                                       train_logging_strategy)
        self.test_logging_strategy = (LoggingStrategy()
                                      if test_logging_strategy is None else
                                      test_logging_strategy)
Example #6
def evaluate_test_set(
    game: nn.Module,
    data: torch.utils.data.DataLoader,
    k_means_clusters: KMeans,
    num_clusters: int
):
    if torch.cuda.is_available():
        game.cuda()
    game.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logging_strategy = LoggingStrategy(False, False, True, True, True, False)

    mean_loss = 0.0
    interactions = []
    n_batches = 0
    soft_accuracy, game_accuracy = 0.0, 0.0
    for batch in data:
        batch = move_to(batch, device)
        (x_i, x_j), labels = batch
        if torch.cuda.is_available():
            x_i = x_i.cuda()
            x_j = x_j.cuda()

        with torch.no_grad():
            sender_encoded_input, receiver_encoded_input = game.vision_module(x_i, x_j)
            message, message_like, resnet_output_sender = game.game.sender(sender_encoded_input, sender=True)

            # Assign each sender-side ResNet feature vector to its nearest k-means cluster,
            # then one-hot encode the cluster id so it can be stored in the interaction aux dict.
            resnet_output_sender_to_predict = resnet_output_sender.cpu().numpy()
            k_means_labels = torch.from_numpy(
                k_means_clusters.predict(resnet_output_sender_to_predict)
            ).to(device=message_like.device, dtype=torch.int64)

            one_hot_k_means_labels = torch.zeros((message_like.size()[0], num_clusters), device=message_like.device)
            one_hot_k_means_labels.scatter_(1, k_means_labels.view(-1, 1), 1)

            receiver_output, resnet_output_recv = game.game.receiver(message, receiver_encoded_input)

            loss, aux_info = game.game.loss(
                sender_encoded_input, message, receiver_encoded_input, receiver_output, labels
            )

            if hasattr(game.game.sender, "temperature"):
                if isinstance(game.game.sender.temperature, torch.nn.Parameter):
                    temperature = game.game.sender.temperature.detach()
                else:
                    temperature = torch.Tensor([game.game.sender.temperature])
                aux_info["temperature"] = temperature

            aux_info["message_like"] = message_like
            aux_info["kmeans"] = one_hot_k_means_labels
            aux_info["resnet_output_sender"] = resnet_output_sender
            aux_info["resnet_output_recv"] = resnet_output_recv

            interaction = logging_strategy.filtered_interaction(
                sender_input=sender_encoded_input,
                receiver_input=receiver_encoded_input,
                labels=labels,
                receiver_output=receiver_output.detach(),
                message=message,
                message_length=torch.ones(message_like.shape[0]),
                aux=aux_info,
            )

            interaction = interaction.to("cpu")
            interactions.append(interaction)

            mean_loss += loss.mean()
            soft_accuracy += interaction.aux['acc'].mean().item()
            game_accuracy += interaction.aux['game_acc'].mean().item()
            n_batches += 1
            if n_batches % 10 == 0:
                print(f"finished batch {n_batches}")

    print(f"processed {n_batches} in total")
    mean_loss /= n_batches
    soft_accuracy /= n_batches
    game_accuracy /= n_batches
    full_interaction = Interaction.from_iterable(interactions)

    return mean_loss, soft_accuracy, game_accuracy, full_interaction
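
evaluate_test_set above expects an already-fitted sklearn KMeans model over the sender-side ResNet features. A minimal sketch of how such a model might be fitted is given below; the helper name, the device handling, and the loop are assumptions, only the game.vision_module / game.game.sender call pattern mirrors the function above.

import torch
from sklearn.cluster import KMeans


def fit_sender_kmeans(game, data, num_clusters, device="cpu"):
    # Collect sender ResNet features over a data loader, then fit k-means on them.
    features = []
    game.eval()
    with torch.no_grad():
        for (x_i, x_j), _labels in data:
            sender_encoded_input, _ = game.vision_module(x_i.to(device), x_j.to(device))
            _, _, resnet_output_sender = game.game.sender(sender_encoded_input, sender=True)
            features.append(resnet_output_sender.cpu())
    return KMeans(n_clusters=num_clusters).fit(torch.cat(features).numpy())

# e.g. k_means = fit_sender_kmeans(game, train_loader, num_clusters)
#      evaluate_test_set(game, test_loader, k_means, num_clusters)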
Example #7
def main(params):
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features+1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(agent=sender, vocab_size=opts.vocab_size,
                                                 embed_dim=opts.sender_embedding, max_len=opts.max_len,
                                                 num_layers=opts.sender_num_layers, num_heads=opts.sender_num_heads,
                                                 hidden_size=opts.sender_hidden,
                                                 force_eos=force_eos,
                                                 generate_style=opts.sender_generate_style,
                                                 causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)

        sender = core.RnnSenderReinforce(sender,
                                   opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
                                   cell=opts.sender_cell, max_len=opts.max_len, num_layers=opts.sender_num_layers,
                                   force_eos=force_eos)
    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(receiver, opts.vocab_size, opts.max_len,
                                                         opts.receiver_embedding, opts.receiver_num_heads, opts.receiver_hidden,
                                                         opts.receiver_num_layers, causal=opts.causal_receiver)
    else:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(receiver, opts.vocab_size, opts.receiver_embedding,
                                             opts.receiver_hidden, cell=opts.receiver_cell,
                                             num_layers=opts.receiver_num_layers)

    empty_logger = LoggingStrategy.minimal()
    game = core.SenderReceiverRnnReinforce(sender, receiver, loss, sender_entropy_coeff=opts.sender_entropy_coeff,
                                           receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                           train_logging_strategy=empty_logger,
                                           length_cost=opts.length_cost)

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [EarlyStopperAccuracy(opts.early_stopping_thr),
               core.ConsoleLogger(as_json=True, print_train_loss=True)]

    if opts.checkpoint_dir:
        checkpoint_name = f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}'
        callbacks.append(core.CheckpointSaver(checkpoint_path=opts.checkpoint_dir, prefix=checkpoint_name))

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader, validation_data=test_loader, callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)

    game.logging_strategy = LoggingStrategy.maximal() # now log everything
    dump(trainer.game, opts.n_features, device, False)
    core.close()
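
The module-level loss passed to core.SenderReceiverRnnReinforce in this main() (and in the later variant in Example #9) is not included in the excerpt. Below is a minimal sketch under the assumption that the receiver emits logits over n_features and that sender_input is the one-hot vector to be reconstructed; it illustrates the expected (loss, aux) interface, not the file's actual loss.

import torch.nn.functional as F


def loss(sender_input, _message, _receiver_input, receiver_output, _labels):
    # Recover the class index from the one-hot sender input and score the reconstruction.
    target = sender_input.argmax(dim=1)
    acc = (receiver_output.argmax(dim=1) == target).float()
    nll = F.cross_entropy(receiver_output, target, reduction="none")
    return nll, {"acc": acc}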
Example #8
def build_game(opts):

    train_logging_strategy = LoggingStrategy(False, False, False, False, False,
                                             False, False)
    test_logging_strategy = LoggingStrategy(False, False, True, True, True,
                                            True, False)

    if opts.use_different_architectures:
        vision_module_names = opts.vision_model_names

        print(vision_module_names)

        vision_modules = [
            initialize_vision_module(name=vision_module_names[i],
                                     pretrained=True)
            for i in range(opts.n_senders)
        ]

        senders = [
            GumbelSoftmaxWrapper(
                Sender(vision_module=vision_modules[i][0],
                       input_dim=vision_modules[i][1],
                       vocab_size=opts.vocab_size,
                       name=vision_module_names[i]),
                temperature=opts.gs_temperature,
                trainable_temperature=opts.train_gs_temperature,
                straight_through=opts.straight_through,
            ) for i in range(opts.n_senders)
        ]
        receivers = [
            SymbolReceiverWrapper(
                Receiver(vision_module=vision_modules[i][0],
                         input_dim=vision_modules[i][1],
                         hidden_dim=opts.recv_hidden_dim,
                         output_dim=opts.recv_output_dim,
                         temperature=opts.recv_temperature,
                         name=vision_module_names[i]),
                opts.vocab_size,
                opts.recv_output_dim,
            ) for i in range(opts.n_recvs)
        ]

    else:
        vision_module, input_dim, name = initialize_vision_module(
            name=opts.vision_model_name, pretrained=True)
        senders = [
            GumbelSoftmaxWrapper(
                Sender(vision_module=vision_module,
                       input_dim=input_dim,
                       vocab_size=opts.vocab_size,
                       name=name),
                temperature=opts.gs_temperature,
                trainable_temperature=opts.train_gs_temperature,
                straight_through=opts.straight_through,
            ) for _ in range(opts.n_senders)
        ]
        receivers = [
            SymbolReceiverWrapper(
                Receiver(vision_module=vision_module,
                         input_dim=input_dim,
                         hidden_dim=opts.recv_hidden_dim,
                         output_dim=opts.recv_output_dim,
                         temperature=opts.recv_temperature,
                         name=name),
                opts.vocab_size,
                opts.recv_output_dim,
            ) for _ in range(opts.n_recvs)
        ]

    agents_loss_sampler = AgentSampler(
        senders,
        receivers,
        [loss],
    )

    game = Game(
        train_logging_strategy=train_logging_strategy,
        test_logging_strategy=test_logging_strategy,
    )

    game = PopulationGame(game, agents_loss_sampler)

    if opts.distributed_context.is_distributed:
        game = torch.nn.SyncBatchNorm.convert_sync_batchnorm(game)

    return game
Example #9
def main(params):
    opts = get_params(params)
    print(opts, flush=True)

    # For compatibility, after https://github.com/facebookresearch/EGG/pull/130
    # the meaning of `length` changed a bit. Before, it included the EOS symbol; now
    # it doesn't. To keep existing hyperparameters/CL arguments valid, we subtract 1 here.
    opts.max_len -= 1

    device = opts.device

    if opts.probs == "uniform":
        probs = np.ones(opts.n_features)
    elif opts.probs == "powerlaw":
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(",")], dtype=np.float32)
    probs /= probs.sum()

    print("the probs are: ", probs, flush=True)

    train_loader = OneHotLoader(
        n_features=opts.n_features,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
        probs=probs,
    )

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == "transformer":
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender,
        )
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)

        sender = core.RnnSenderReinforce(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
        )
    if opts.receiver_cell == "transformer":
        receiver = Receiver(
            n_features=opts.n_features, n_hidden=opts.receiver_embedding
        )
        receiver = core.TransformerReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            opts.receiver_num_heads,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver,
        )
    else:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers,
        )

    empty_logger = LoggingStrategy.minimal()
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        train_logging_strategy=empty_logger,
        length_cost=opts.length_cost,
    )

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        EarlyStopperAccuracy(opts.early_stopping_thr),
        core.ConsoleLogger(as_json=True, print_train_loss=True),
    ]

    if opts.checkpoint_dir:
        checkpoint_name = f"{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}"
        callbacks.append(
            core.CheckpointSaver(
                checkpoint_path=opts.checkpoint_dir, prefix=checkpoint_name
            )
        )

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)

    game.logging_strategy = LoggingStrategy.maximal()  # now log everything
    dump(trainer.game, opts.n_features, device, False)
    core.close()