def __init__(self, training_log=None) -> None:
        super().__init__()

        self.training_log = (training_log if training_log is not None
                             else core.get_opts().training_log_path)

        self.train_loader, self.test_loader = \
            get_symbolic_dataloader(
                n_attributes=self.n_attributes,
                n_values=self.n_values,
                batch_size=self.batch_size
            )

        self.sender = core.RnnSenderGS(
            SymbolicSenderMLP(input_dim=self.n_attributes * self.n_values,
                              hidden_dim=self.hidden_size),
            self.vocab_size,
            self.emb_size,
            self.hidden_size,
            max_len=self.max_len,
            cell="lstm",
            temperature=1.0
        )

        self.receiver = core.RnnReceiverGS(
            SymbolicReceiverMLP(self.n_attributes * self.n_values, self.hidden_size),
            self.vocab_size,
            self.emb_size,
            self.hidden_size,
            cell="lstm"
        )

        self.game = core.SenderReceiverRnnGS(self.sender, self.receiver,
                                             self.loss)

        self.optimiser = core.build_optimizer(self.game.parameters())
        self.callbacks = []
        self.callbacks.append(
            ConsoleFileLogger(as_json=True,
                              print_train_loss=True,
                              logfile_path=self.training_log))
        self.callbacks.append(
            core.TemperatureUpdater(agent=self.sender, decay=0.9, minimum=0.1))
        self.callbacks.append(
            TopographicSimilarityLatents('hamming',
                                         'edit',
                                         log_path=core.get_opts().topo_path))
        self.trainer = core.Trainer(game=self.game,
                                    optimizer=self.optimiser,
                                    train_data=self.train_loader,
                                    validation_data=self.test_loader,
                                    callbacks=self.callbacks)
def __init__(self, data_path: str) -> None:
        super().__init__()

        self.train_loader, self.test_loader = \
            get_dsprites_dataloader(
                batch_size=self.batch_size,
                path_to_data=data_path,
                game_size=self.game_size,
                referential=True
            )

        self.sender = core.RnnSenderGS(
            DspritesSenderCNN(self.hidden_size), 
            self.vocab_size,
            self.emb_size,
            self.hidden_size,
            max_len=self.max_len,
            cell="lstm", 
            temperature=1.0
        )

        self.receiver = core.RnnReceiverGS(
            DSpritesReceiverCNN(self.game_size, self.emb_size, self.hidden_size, reinforce=False),
            self.vocab_size,
            self.emb_size,
            self.hidden_size,
            cell="lstm"
        )

        self.game = core.SenderReceiverRnnGS(self.sender, self.receiver, self.loss)

        self.optimiser = core.build_optimizer(self.game.parameters())
        self.callbacks = []
        self.callbacks.append(core.ConsoleLogger(as_json=True, print_train_loss=True))
        self.callbacks.append(TopographicSimilarityLatents('euclidean', 'edit'))
        #self.callbacks.append(core.TemperatureUpdater(agent=self.sender, decay=0.9, minimum=0.1))
        self.trainer = core.Trainer(
            game=self.game, optimizer=self.optimiser, train_data=self.train_loader, validation_data=self.test_loader,
            callbacks=self.callbacks
        )
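
Neither constructor above actually starts training. A minimal usage sketch for the second wrapper follows; the class name `DspritesGame`, the dataset path and the option values are assumptions made only for illustration.

# Minimal usage sketch; DspritesGame, the dataset path and the option values
# are assumptions for illustration, not names from the original project.
import egg.core as core

if __name__ == "__main__":
    # egg.core.init parses EGG's common command-line options (vocab_size, max_len, batch_size, n_epochs, ...)
    core.init(params=["--vocab_size=10", "--max_len=5", "--batch_size=32", "--n_epochs=10"])
    experiment = DspritesGame(data_path="dsprites.npz")  # hypothetical class and path
    experiment.trainer.train(n_epochs=core.get_opts().n_epochs)
    core.close()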
Example #3
                                             opts.receiver_embedding,
                                             opts.receiver_hidden,
                                             cell=opts.receiver_cell,
                                             num_layers=opts.receiver_layers)

        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            non_differentiable_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff)
    elif opts.train_mode.lower() == 'gs':
        sender = core.RnnSenderGS(sender,
                                  opts.vocab_size,
                                  opts.sender_embedding,
                                  opts.sender_hidden,
                                  cell=opts.sender_cell,
                                  max_len=opts.max_len,
                                  temperature=opts.temperature)

        receiver = core.RnnReceiverGS(receiver,
                                      opts.vocab_size,
                                      opts.receiver_embedding,
                                      opts.receiver_hidden,
                                      cell=opts.receiver_cell)

        game = core.SenderReceiverRnnGS(sender, receiver, differentiable_loss)
    else:
        raise NotImplementedError(f'Unknown training mode, {opts.train_mode}')

    optimizer = core.build_optimizer(game.parameters())
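
Example #3 stops right after building the optimizer. The usual continuation in EGG, consistent with the other examples on this page, is to hand the game, optimizer and data loaders to `core.Trainer`; a sketch follows, assuming `train_loader`, `test_loader` and `opts.n_epochs` are defined in the omitted part of the script.

    # Sketch of the typical continuation; train_loader, test_loader and
    # opts.n_epochs are assumed to exist in the omitted part of the script.
    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[core.ConsoleLogger(as_json=True, print_train_loss=True)])
    trainer.train(n_epochs=opts.n_epochs)
    core.close()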
Example #4
        if -1 not in opts.perceptual_dimensions:
            baseline_msg += f'| "Smart" baseline with perceptual_dimensions {opts.perceptual_dimensions} = {compute_baseline_accuracy(opts.n_distractors, opts.max_len, *opts.perceptual_dimensions)}\n'
        else:
            baseline_msg += f'| Data was loaded from an external file, so no perceptual_dimensions vector was provided and the "smart baseline" cannot be computed\n'

    sender = Sender(n_features=data_loader.task.num_features,
                    n_hidden=opts.sender_hidden)
    receiver = Receiver(n_features=data_loader.task.num_features,
                        linear_units=opts.receiver_hidden)

    if opts.mode.lower() == "gs":
        sender = core.RnnSenderGS(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            temperature=opts.temperature,
            force_eos=False,
        )
        receiver = core.RnnReceiverGS(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
        game = core.SenderReceiverRnnGS(sender, receiver, loss)
    elif opts.mode.lower() == "gs-hard":
        sender = core.RnnSenderGS(
            sender,
Example #5
def main(params):
    opts = get_params(params)

    device = torch.device("cuda" if opts.cuda else "cpu")

    data_loader = VectorsLoader(
        perceptual_dimensions=opts.perceptual_dimensions,
        n_distractors=opts.n_distractors,
        batch_size=opts.batch_size,
        train_samples=opts.train_samples,
        validation_samples=opts.validation_samples,
        test_samples=opts.test_samples,
        shuffle_train_data=opts.shuffle_train_data,
        dump_data_folder=opts.dump_data_folder,
        load_data_path=opts.load_data_path,
        seed=opts.data_seed,
    )

    train_data, validation_data, test_data = data_loader.get_iterators()

    data_loader.upd_cl_options(opts)

    if opts.max_len > 1:
        baseline_msg = 'Cannot yet compute "smart" baseline value for messages of length greater than 1'
    else:
        baseline_msg = (
            f"\n| Baselines measures with {opts.n_distractors} distractors and messages of max_len = {opts.max_len}:\n"
            f"| Dummy random baseline: accuracy = {1 / (opts.n_distractors + 1)}\n"
        )
        if -1 not in opts.perceptual_dimensions:
            baseline_msg += f'| "Smart" baseline with perceptual_dimensions {opts.perceptual_dimensions} = {compute_baseline_accuracy(opts.n_distractors, opts.max_len, *opts.perceptual_dimensions)}\n'
        else:
            baseline_msg += f'| Data was loaded from an external file, so no perceptual_dimensions vector was provided and the "smart baseline" cannot be computed\n'

    print(baseline_msg)

    sender = Sender(n_features=data_loader.n_features,
                    n_hidden=opts.sender_hidden)

    receiver = Receiver(n_features=data_loader.n_features,
                        linear_units=opts.receiver_hidden)

    if opts.mode.lower() == "gs":
        sender = core.RnnSenderGS(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            temperature=opts.temperature,
        )

        receiver = core.RnnReceiverGS(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )

        game = core.SenderReceiverRnnGS(sender, receiver, loss)
    else:
        raise NotImplementedError(f"Unknown training mode, {opts.mode}")

    optimizer = torch.optim.Adam([
        {
            "params": game.sender.parameters(),
            "lr": opts.sender_lr
        },
        {
            "params": game.receiver.parameters(),
            "lr": opts.receiver_lr
        },
    ])
    callbacks = [core.ConsoleLogger(as_json=True)]
    if opts.mode.lower() == "gs":
        callbacks.append(
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1))
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_data,
        validation_data=validation_data,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    if opts.evaluate:
        is_gs = opts.mode.lower() == "gs"
        (
            sender_inputs,
            messages,
            receiver_inputs,
            receiver_outputs,
            labels,
        ) = dump_sender_receiver(game,
                                 test_data,
                                 is_gs,
                                 variable_length=True,
                                 device=device)

        receiver_outputs = move_to(receiver_outputs, device)
        labels = move_to(labels, device)

        receiver_outputs = torch.stack(receiver_outputs)
        labels = torch.stack(labels)

        tensor_accuracy = receiver_outputs.argmax(dim=1) == labels
        accuracy = torch.mean(tensor_accuracy.float()).item()

        unique_dict = {}

        for elem in sender_inputs:
            target = ""
            for dim in elem:
                target += f"{str(int(dim.item()))}-"
            target = target[:-1]
            if target not in unique_dict:
                unique_dict[target] = True

        print(f"| Accuracy on test set: {accuracy}")

        compute_mi_input_msgs(sender_inputs, messages)

        print(f"entropy sender inputs {entropy(sender_inputs)}")
        print(f"mi sender inputs msgs {mutual_info(sender_inputs, messages)}")

        if opts.dump_msg_folder:
            opts.dump_msg_folder.mkdir(exist_ok=True)
            msg_dict = {}

            output_msg = (
                f"messages_{opts.perceptual_dimensions}_vocab_{opts.vocab_size}"
                f"_maxlen_{opts.max_len}_bsize_{opts.batch_size}"
                f"_n_distractors_{opts.n_distractors}_train_size_{opts.train_samples}"
                f"_valid_size_{opts.validation_samples}_test_size_{opts.test_samples}"
                f"_slr_{opts.sender_lr}_rlr_{opts.receiver_lr}_shidden_{opts.sender_hidden}"
                f"_rhidden_{opts.receiver_hidden}_semb_{opts.sender_embedding}"
                f"_remb_{opts.receiver_embedding}_mode_{opts.mode}"
                f"_scell_{opts.sender_cell}_rcell_{opts.receiver_cell}.msg")

            output_file = opts.dump_msg_folder / output_msg
            with open(output_file, "w") as f:
                f.write(f"{opts}\n")
                for (
                        sender_input,
                        message,
                        receiver_input,
                        receiver_output,
                        label,
                ) in zip(sender_inputs, messages, receiver_inputs,
                         receiver_outputs, labels):
                    sender_input = ",".join(map(str, sender_input.tolist()))
                    message = ",".join(map(str, message.tolist()))
                    distractors_list = receiver_input.tolist()
                    receiver_input = "; ".join([
                        ",".join(map(str, elem)) for elem in distractors_list
                    ])
                    if is_gs:
                        receiver_output = receiver_output.argmax()
                    f.write(
                        f"{sender_input} -> {receiver_input} -> {message} -> {receiver_output} (label={label.item()})\n"
                    )

                    if message in msg_dict:
                        msg_dict[message] += 1
                    else:
                        msg_dict[message] = 1

                sorted_msgs = sorted(msg_dict.items(),
                                     key=operator.itemgetter(1),
                                     reverse=True)
                f.write(
                    f"\nUnique target vectors seen by sender: {len(unique_dict.keys())}\n"
                )
                f.write(
                    f"Unique messages produced by sender: {len(msg_dict.keys())}\n"
                )
                f.write(f"Messagses: 'msg' : msg_count: {str(sorted_msgs)}\n")
                f.write(f"\nAccuracy: {accuracy}")
Example #6
def main(params):
    opts = get_params(params)
    if opts.validation_batch_size == 0:
        opts.validation_batch_size = opts.batch_size
    print(opts, flush=True)

    # the following if statement controls aspects specific to the two game tasks: loss, input data and architecture of the Receiver
    # (the Sender is identical in both cases, mapping a single input attribute-value vector to a variable-length message)
    if opts.game_type == "discri":
        # the game object we will encounter below takes as one of its mandatory arguments a loss: a loss in EGG is expected to take as arguments the sender input,
        # the message, the Receiver input, the Receiver output and the labels (although some of these elements might not actually be used by a particular loss);
        # together with the actual loss computation, the loss function can return a dictionary with other auxiliary statistics: in this case, accuracy
        def loss(
            _sender_input,
            _message,
            _receiver_input,
            receiver_output,
            labels,
            _aux_input,
        ):
            # in the discriminative case, accuracy is computed by comparing the index with highest score in Receiver output (a distribution of unnormalized
            # probabilities over target positions) and the corresponding label read from input, indicating the ground-truth position of the target
            acc = (receiver_output.argmax(dim=1) == labels).detach().float()
            # similarly, the loss computes cross-entropy between the Receiver-produced target-position probability distribution and the labels
            loss = F.cross_entropy(receiver_output, labels, reduction="none")
            return loss, {"acc": acc}

        # the input data are read into DataLoader objects, which are pytorch constructs implementing standard data processing functionalities, such as shuffling
        # and batching
        # within our games, we implement dataset classes, such as AttValDiscriDataset, to read the input text files and convert the information they contain
        # into the form required by DataLoader
        # look at the definition of the AttValDiscriDataset (the class that reads discrimination game data) in data_readers.py for further details
        # note that, for the training dataset, we first instantiate the AttValDiscriDataset object and then feed it to DataLoader, whereas for the
        # validation data (confusingly called "test" data due to code heritage inertia) we directly declare the AttValDiscriDataset when instantiating
        # DataLoader: the reason for this difference is that we need the train_ds object to retrieve the number of features of the input vectors
        train_ds = AttValDiscriDataset(path=opts.train_data,
                                       n_values=opts.n_values)
        train_loader = DataLoader(train_ds,
                                  batch_size=opts.batch_size,
                                  shuffle=True,
                                  num_workers=1)
        test_loader = DataLoader(
            AttValDiscriDataset(path=opts.validation_data,
                                n_values=opts.n_values),
            batch_size=opts.validation_batch_size,
            shuffle=False,
            num_workers=1,
        )
        # note that the number of features retrieved here concerns inputs after they are converted to 1-hot vectors
        n_features = train_ds.get_n_features()
        # we define here the core of the Receiver for the discriminative game, see the architectures.py file for details:
        # note that this will be embedded in a wrapper below to define the full agent
        receiver = DiscriReceiver(n_features=n_features,
                                  n_hidden=opts.receiver_hidden)

    else:  # reco game

        def loss(sender_input, _message, _receiver_input, receiver_output,
                 labels, _aux_input):
            # in the case of the recognition game, for each attribute we compute a different cross-entropy score
            # based on comparing the probability distribution produced by the Receiver over the values of each attribute
            # with the corresponding ground truth, and then averaging across attributes
            # accuracy is instead computed by considering as a hit only cases where, for each attribute, the Receiver
            # assigned the largest probability to the correct value
            # most of this function consists of the usual pytorch madness needed to reshape tensors in order to perform these computations
            n_attributes = opts.n_attributes
            n_values = opts.n_values
            batch_size = sender_input.size(0)
            receiver_output = receiver_output.view(batch_size * n_attributes,
                                                   n_values)
            receiver_guesses = receiver_output.argmax(dim=1)
            correct_samples = ((receiver_guesses == labels.view(-1)).view(
                batch_size, n_attributes).detach())
            acc = (torch.sum(correct_samples, dim=-1) == n_attributes).float()
            labels = labels.view(batch_size * n_attributes)
            loss = F.cross_entropy(receiver_output, labels, reduction="none")
            loss = loss.view(batch_size, -1).mean(dim=1)
            return loss, {"acc": acc}

        # again, see data_readers.py in this directory for the AttValRecoDataset data reading class
        train_loader = DataLoader(
            AttValRecoDataset(
                path=opts.train_data,
                n_attributes=opts.n_attributes,
                n_values=opts.n_values,
            ),
            batch_size=opts.batch_size,
            shuffle=True,
            num_workers=1,
        )
        test_loader = DataLoader(
            AttValRecoDataset(
                path=opts.validation_data,
                n_attributes=opts.n_attributes,
                n_values=opts.n_values,
            ),
            batch_size=opts.validation_batch_size,
            shuffle=False,
            num_workers=1,
        )
        # the number of features for the Receiver (input) and the Sender (output) is given by n_attributes*n_values because
        # they are fed/produce 1-hot representations of the input vectors
        n_features = opts.n_attributes * opts.n_values
        # we define here the core of the Receiver for the reconstruction game, see the architectures.py file for details
        # this will be embedded in a wrapper below to define the full architecture
        receiver = RecoReceiver(n_features=n_features,
                                n_hidden=opts.receiver_hidden)

    # we are now outside the block that defined game-type-specific aspects of the games: note that the core Sender architecture
    # (see architectures.py for details) is shared by the two games (it maps an input vector to a hidden layer that will be used to initialize
    # the message-producing RNN): this will also be embedded in a wrapper below to define the full architecture
    sender = Sender(n_hidden=opts.sender_hidden, n_features=n_features)

    # now, we instantiate the full sender and receiver architectures, and connect them and the loss into a game object
    # the implementation differs slightly depending on whether communication is optimized via Gumbel-Softmax ('gs') or Reinforce ('rf', default)
    if opts.mode.lower() == "gs":
        # in the following lines, we embed the Sender and Receiver architectures into standard EGG wrappers that are appropriate for Gumbel-Softmax optimization
        # the Sender wrapper takes the hidden layer produced by the core agent architecture we defined above when processing input, and uses it to initialize
        # the RNN that generates the message
        sender = core.RnnSenderGS(
            sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            hidden_size=opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            temperature=opts.temperature,
        )
        # the Receiver wrapper takes the symbol produced by the Sender at each step (more precisely, in Gumbel-Softmax mode, a function of the overall probability
        # of non-eos symbols up to that step is used), maps it to a hidden layer through an RNN, and feeds this hidden layer to the
        # core Receiver architecture we defined above (possibly with other Receiver input, as determined by the core architecture) to generate the output
        receiver = core.RnnReceiverGS(
            receiver,
            vocab_size=opts.vocab_size,
            embed_dim=opts.receiver_embedding,
            hidden_size=opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        # callback functions can be passed to the trainer object (see below) to operate at certain steps of training and validation
        # for example, the TemperatureUpdater (defined in callbacks.py in the core directory) will update the Gumbel-Softmax temperature hyperparameter
        # after each epoch
        callbacks = [
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1)
        ]
    else:  # NB: any other string than gs will lead to rf training!
        # here, the interesting thing to note is that we use the same core architectures we defined above, but now we embed them in wrappers that are suited to
        # Reinforce-based optimization
        sender = core.RnnSenderReinforce(
            sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            hidden_size=opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
        )
        receiver = core.RnnReceiverDeterministic(
            receiver,
            vocab_size=opts.vocab_size,
            embed_dim=opts.receiver_embedding,
            hidden_size=opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=0,
        )
        callbacks = []

    # we are almost ready to train: we define here an optimizer calling standard pytorch functionality
    optimizer = core.build_optimizer(game.parameters())
    # in the following statement, we finally instantiate the trainer object with all the components we defined (the game, the optimizer, the data
    # and the callbacks)
    if opts.print_validation_events:
        # we add a callback that will print loss and accuracy after each training and validation pass (see ConsoleLogger in callbacks.py in core directory)
        # if requested by the user, we will also print a detailed log of the validation pass after full training: look at PrintValidationEvents in
        # language_analysis.py (core directory)
        trainer = core.Trainer(
            game=game,
            optimizer=optimizer,
            train_data=train_loader,
            validation_data=test_loader,
            callbacks=callbacks + [
                core.ConsoleLogger(print_train_loss=True, as_json=True),
                core.PrintValidationEvents(n_epochs=opts.n_epochs),
            ],
        )
    else:
        trainer = core.Trainer(
            game=game,
            optimizer=optimizer,
            train_data=train_loader,
            validation_data=test_loader,
            callbacks=callbacks +
            [core.ConsoleLogger(print_train_loss=True, as_json=True)],
        )

    # and finally we train!
    trainer.train(n_epochs=opts.n_epochs)
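
The comments in this walkthrough describe EGG's loss contract: a loss takes (sender_input, message, receiver_input, receiver_output, labels, aux_input) and returns a per-sample loss plus a dictionary of auxiliary statistics. The stand-alone snippet below exercises that contract for the discriminative loss defined above using random tensors; the batch size and number of candidate positions are arbitrary.

# Stand-alone check of the loss contract (discriminative case): batch of 4,
# 5 candidate target positions; values are random and purely illustrative.
import torch
import torch.nn.functional as F

def discri_loss(_sender_input, _message, _receiver_input, receiver_output, labels, _aux_input):
    acc = (receiver_output.argmax(dim=1) == labels).detach().float()
    loss = F.cross_entropy(receiver_output, labels, reduction="none")
    return loss, {"acc": acc}

receiver_output = torch.randn(4, 5)   # unnormalized scores over the 5 positions
labels = torch.randint(0, 5, (4,))    # ground-truth target positions
loss, aux = discri_loss(None, None, None, receiver_output, labels, None)
assert loss.shape == (4,) and aux["acc"].shape == (4,)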
Example #7
if __name__ == "__main__":
    opts = get_params()
    opts.on_slurm = os.environ.get('SLURM_JOB_NAME', False)
    core.util._set_seed(opts.seed)
    full_dataset, train, test = prepare_datasets()
    train_loader = DataLoader(train, batch_size=opts.batch_size, drop_last=False, shuffle=True)
    test_loader = DataLoader(test, batch_size=opts.batch_size, drop_last=False, shuffle=False)
    vision_module = Vision.from_pretrained('visual_data/vision_model.pth')
    pretrained_senders = [
        core.RnnSenderGS(
            agent=Sender(opts.sender_hidden, vision_module),
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            hidden_size=opts.sender_hidden,
            max_len=1,
            temperature=3.,
            trainable_temperature=True,
            cell=opts.rnn_cell,
            force_eos=False
        )
        for i in range(2)]
    sender_3 = core.RnnSenderGS(
            agent=Sender(opts.sender_hidden, vision_module),
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            hidden_size=opts.sender_hidden,
            max_len=opts.max_len,
            temperature=3.,
            trainable_temperature=True,
            force_eos=False,
            cell=opts.rnn_cell)
     receiver = RnnReceiverDeterministic(agent=Receiver(
         opts.receiver_hidden, opts.n_features, opts.n_attributes),
                                         vocab_size=opts.vocab_size,
                                         embed_dim=opts.receiver_embedding,
                                         hidden_size=opts.receiver_hidden,
                                         cell=opts.rnn_cell)
 else:
     pretrained_senders = [
         core.RnnSenderGS(agent=VisualSender(opts.sender_hidden,
                                             opts.n_features,
                                             opts.n_attributes),
                          vocab_size=opts.vocab_size,
                          embed_dim=opts.sender_embedding,
                          hidden_size=opts.sender_hidden,
                          max_len=1,
                          temperature=3.,
                          trainable_temperature=True,
                          cell=opts.rnn_cell,
                          force_eos=False) for i in range(2)
     ]
     sender_3 = core.RnnSenderGS(agent=VisualSender(opts.sender_hidden,
                                                    opts.n_features,
                                                    opts.n_attributes),
                                 vocab_size=opts.vocab_size,
                                 embed_dim=opts.sender_embedding,
                                 hidden_size=opts.sender_hidden,
                                 max_len=opts.max_len,
                                 temperature=3.,
                                 trainable_temperature=True,
Example #9
File: train.py Project: evdcush/EGG
def main(params):
    opts = get_params(params)

    device = torch.device("cuda" if opts.cuda else "cpu")
    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch)
    test_loader = OneHotLoader(n_features=opts.n_features,
                               batch_size=opts.batch_size,
                               batches_per_epoch=opts.batches_per_epoch,
                               seed=7)

    sender = Sender(n_hidden=opts.sender_hidden, n_features=opts.n_features)
    receiver = Receiver(n_features=opts.n_features,
                        n_hidden=opts.receiver_hidden)

    if opts.mode.lower() == 'rf':
        sender = core.RnnSenderReinforce(sender,
                                         opts.vocab_size,
                                         opts.sender_embedding,
                                         opts.sender_hidden,
                                         cell=opts.sender_cell,
                                         max_len=opts.max_len)
        receiver = core.RnnReceiverDeterministic(receiver,
                                                 opts.vocab_size,
                                                 opts.receiver_embedding,
                                                 opts.receiver_hidden,
                                                 cell=opts.receiver_cell)

        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff)
        callbacks = []
    elif opts.mode.lower() == 'gs':
        sender = core.RnnSenderGS(sender,
                                  opts.vocab_size,
                                  opts.sender_embedding,
                                  opts.sender_hidden,
                                  cell=opts.sender_cell,
                                  max_len=opts.max_len,
                                  temperature=opts.temperature)

        receiver = core.RnnReceiverGS(receiver,
                                      opts.vocab_size,
                                      opts.receiver_embedding,
                                      opts.receiver_hidden,
                                      cell=opts.receiver_cell)

        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        callbacks = [
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1)
        ]
    else:
        raise NotImplementedError(f'Unknown training mode, {opts.mode}')

    optimizer = torch.optim.Adam([{
        'params': game.sender.parameters(),
        'lr': opts.sender_lr
    }, {
        'params': game.receiver.parameters(),
        'lr': opts.receiver_lr
    }])

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=callbacks +
                           [core.ConsoleLogger(as_json=True)])
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
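
Scripts structured like the one above are typically launched by forwarding the command-line arguments to `main`; a sketch of the usual entry point follows (the exact flags are whatever `get_params` registers on top of EGG's common options).

if __name__ == "__main__":
    import sys
    main(sys.argv[1:])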
Example #10

if __name__ == "__main__":
    opts = get_params()
    opts.on_slurm = os.environ.get('SLURM_JOB_NAME', False)
    core.util._set_seed(opts.seed)
    full_dataset, train, test = prepare_datasets()
    train_loader = DataLoader(train, batch_size=opts.batch_size, drop_last=False, shuffle=True)
    test_loader = DataLoader(test, batch_size=opts.batch_size, drop_last=False, shuffle=False)
    pretrained_senders = [
        core.RnnSenderGS(
            agent=Sender(opts.sender_hidden, Vision.from_pretrained('vision_model.pth')),
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            hidden_size=opts.sender_hidden,
            max_len=1,
            temperature=3.,
            trainable_temperature=True,
            cell=opts.rnn_cell,
            force_eos=False
        )
        for i in range(2)]
    sender_3 = core.RnnSenderGS(
            agent=Sender(opts.sender_hidden, Vision.from_pretrained('vision_model.pth')),
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            hidden_size=opts.sender_hidden,
            max_len=2,
            temperature=3.,
            trainable_temperature=True,
            force_eos=False,
            cell=opts.rnn_cell)
    train_loader = DataLoader(train,
                              batch_size=opts.batch_size,
                              drop_last=False,
                              shuffle=True)
    test_loader = DataLoader(test,
                             batch_size=10,
                             drop_last=False,
                             shuffle=False)

    senders = [
        core.RnnSenderGS(agent=Sender(opts.sender_hidden, opts.n_features,
                                      opts.n_attributes),
                         vocab_size=opts.vocab_size,
                         embed_dim=opts.sender_embedding,
                         hidden_size=opts.sender_hidden,
                         max_len=opts.max_len,
                         temperature=1.,
                         trainable_temperature=True,
                         cell=opts.rnn_cell)
        for _ in range(opts.population_size)
    ]
    sender_ensemble = GumbelSoftmaxMultiAgentEnsemble(agents=senders)
    receivers_1 = [
        core.RnnReceiverGS(agent=Receiver(opts.receiver_hidden,
                                          opts.n_features, opts.n_attributes),
                           vocab_size=opts.vocab_size,
                           embed_dim=opts.receiver_embedding,
                           hidden_size=opts.receiver_hidden,
                           cell=opts.rnn_cell)
        for _ in range(opts.population_size)