示例#1
0
文件: train.py 项目: lorarjohns/EGG
def main(params):
    """Train a VAE-style sender/receiver game on MNIST images.

    Args:
        params: Command-line style parameters, forwarded to ``core.init``.
    """
    opts = core.init(params=params)

    # Worker / pinned-memory options are only handed to the loaders on CUDA.
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_kwargs = {}
    to_tensor = transforms.ToTensor()

    train_set = datasets.MNIST(
        './data', train=True, download=True, transform=to_tensor)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=opts.batch_size, shuffle=True, **loader_kwargs)

    test_set = datasets.MNIST('./data', train=False, transform=to_tensor)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=opts.batch_size, shuffle=True, **loader_kwargs)

    # Both agents are sized by the vocabulary and joined into one game.
    game = VAE_Game(Sender(opts.vocab_size), Receiver(opts.vocab_size))
    optimizer = core.build_optimizer(game.parameters())

    # The image dump callback receives the raw validation dataset,
    # not the loader wrapping it.
    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        ImageDumpCallback(test_loader.dataset),
    ]

    # Build the trainer and run it for the requested number of epochs.
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks)
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
示例#2
0
def main(params):
    """Train a beta-VAE game on the dSprites images.

    If the dataset archive is not found at its expected location, the
    project's download script is run first.

    Args:
        params: Command-line style parameters, forwarded to ``core.init``.
    """
    opts = core.init(params=params)

    dataset_file = os.path.join(
        'data', 'dsprites-dataset',
        'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
    if not os.path.exists(dataset_file):
        # Imported lazily: only needed when the data has to be fetched.
        import subprocess
        print('Now download dsprites-dataset')
        download_script = os.path.join(
            'egg', 'zoo', 'dsprites_bvae', 'data_loaders',
            'download_dsprites.sh')
        subprocess.call([download_script])
        print('Finished')

    loaders = get_dsprites_dataloader(
        path_to_data=dataset_file, batch_size=opts.batch_size, image=True)
    train_loader, test_loader = loaders
    image_shape = (64, 64)

    game = betaVAE_Game(VisualSender(), VisualReceiver())
    optimizer = core.build_optimizer(game.parameters())

    # Logging, image dumps and the two analysis callbacks, in that order.
    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        ImageDumpCallback(test_loader.dataset, image_shape=image_shape),
        TopographicSimilarity(
            sender_input_distance_fn='euclidean',
            message_distance_fn='euclidean',
            is_gumbel=False),
        PosDisent(),
    ]

    # Build the trainer and run it for the requested number of epochs.
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks)
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
示例#3
0
文件: train.py 项目: wedddy0707/EGG
def main(params):
    """Train a single-symbol Gumbel-Softmax game on MNIST via ``DoubleMnist``.

    Args:
        params: Command-line style parameters, parsed by ``get_params``.
    """
    opts = get_params(params)
    print(opts)

    # Worker / pinned-memory options are only handed to the loaders on CUDA.
    if opts.cuda:
        loader_kwargs = {"num_workers": 1, "pin_memory": True}
    else:
        loader_kwargs = {}
    to_tensor = transforms.ToTensor()

    train_dataset = datasets.MNIST(
        "./data", train=True, download=True, transform=to_tensor)
    test_dataset = datasets.MNIST(
        "./data", train=False, download=False, transform=to_tensor)

    n_classes = 10
    # Lookup table sending each index in 0..99 to its value modulo n_classes.
    label_mapping = torch.LongTensor(
        [index % n_classes for index in range(100)])

    # Each plain loader is wrapped by DoubleMnist with the label mapping.
    train_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            train_dataset,
            batch_size=opts.batch_size,
            shuffle=True,
            **loader_kwargs),
        label_mapping)
    test_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            test_dataset,
            batch_size=16 * 1024,
            shuffle=False,
            **loader_kwargs),
        label_mapping)

    uses_linear_channel = opts.linear_channel == 1
    sender = Sender(
        vocab_size=opts.vocab_size,
        linear_channel=uses_linear_channel,
        softmax_channel=opts.softmax_non_linearity)
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)

    # The relaxed wrapper is applied only when both channel options are 0.
    if opts.softmax_non_linearity == 0 and opts.linear_channel == 0:
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
示例#4
0
def main(params):
    """Train a Gumbel-Softmax symbol game on MNIST with an evaluator callback.

    Args:
        params: Command-line style parameters, parsed by ``get_params``.
    """
    opts = get_params(params)
    print(opts)

    # Lookup table sending each index in 0..99 to its value modulo n_labels.
    label_mapping = torch.LongTensor(
        [index % opts.n_labels for index in range(100)])
    print('# label mapping', label_mapping.tolist())

    # Worker / pinned-memory options are only handed to the loaders on CUDA.
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_kwargs = {}
    to_tensor = transforms.ToTensor()

    train_dataset = datasets.MNIST(
        './data', train=True, download=True, transform=to_tensor)
    train_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            train_dataset,
            batch_size=opts.batch_size,
            shuffle=True,
            **loader_kwargs),
        label_mapping)

    test_dataset = datasets.MNIST('./data', train=False, transform=to_tensor)
    test_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            test_dataset,
            batch_size=16 * 1024,
            shuffle=False,
            **loader_kwargs),
        label_mapping)

    sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(
        vocab_size=opts.vocab_size,
        n_classes=opts.n_labels,
        n_hidden=opts.n_hidden)
    sender = core.GumbelSoftmaxWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    # Evaluator callback fed with the validation loader and the game's loss.
    intervention = CallbackEvaluator(
        test_loader,
        device=opts.device,
        loss=game.loss,
        is_gs=True,
        var_length=False,
        input_intervention=False)

    callbacks = [
        core.ConsoleLogger(as_json=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
        intervention,
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
示例#5
0
def main(params):
    """Train a one-symbol Gumbel-Softmax sender/receiver game on MNIST.

    Args:
        params: Command-line style parameters, forwarded to ``core.init``,
            which also supplies the pre-defined common options
            (batch/vocab size, etc.; see egg/core/util.py).
    """
    # Initialise the EGG library and obtain the parsed options.
    opts = core.init(params=params)

    # Dataset preparation; worker / pinning options apply only on CUDA.
    if opts.cuda:
        loader_kwargs = {"num_workers": 1, "pin_memory": True}
    else:
        loader_kwargs = {}
    to_tensor = transforms.ToTensor()

    train_set = datasets.MNIST(
        "./data", train=True, download=True, transform=to_tensor)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=opts.batch_size, shuffle=True, **loader_kwargs)

    test_set = datasets.MNIST("./data", train=False, transform=to_tensor)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=opts.batch_size, shuffle=True, **loader_kwargs)

    # The bare Sender is the "data" transform part of the agent; it is
    # wrapped into a Gumbel-Softmax interface.
    sender = core.GumbelSoftmaxWrapper(Sender(opts.vocab_size), temperature=1.0)

    receiver = core.SymbolReceiverWrapper(
        Receiver(), vocab_size=opts.vocab_size, agent_input_size=400)

    # Standard Sender/Receiver game with single-symbol communication.
    game = core.SymbolGameGS(sender, receiver, loss)

    # Called by the Trainer at the end of each epoch; it reduces the
    # sampling temperature used by the Gumbel-Softmax sender.
    temperature_updater = core.TemperatureUpdater(
        agent=sender, decay=0.75, minimum=0.01)

    # Optimizer configured through the common command line parameters
    # (defaults to Adam).
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        temperature_updater,
        core.ConsoleLogger(as_json=True, print_train_loss=True),
    ]

    # Build the trainer and run it for the requested number of epochs.
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks)
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
示例#6
0
def main(params):
    """Train a receiver against a ``CircleSender`` on ``SphereData`` points.

    Only the receiver's parameters are handed to the optimizer; the sender
    is optionally preceded by a ``Lenses`` module.

    Args:
        params: Command-line style parameters, parsed by ``get_params``.

    Raises:
        AssertionError: If ``opts.lenses`` is neither 0 nor 1.
    """
    opts = get_params(params)
    print(opts)

    train_data = SphereData(n_points=int(opts.n_points))
    train_loader = DataLoader(train_data,
                              batch_size=opts.batch_size,
                              shuffle=True)

    # Validation runs on a separately sampled set of 1000 points. The loader
    # must wrap `test_data`, not `train_data`, otherwise the validation
    # metrics are simply computed on the training points again.
    test_data = SphereData(n_points=int(1e3))
    test_loader = DataLoader(test_data,
                             batch_size=opts.batch_size,
                             shuffle=False)

    sender = CircleSender(opts.vocab_size)

    assert opts.lenses in [0, 1]
    if opts.lenses == 1:
        # Put the lenses module in front of the sender.
        sender = torch.nn.Sequential(Lenses(math.pi / 4), sender)

    receiver = Receiver(n_hidden=opts.receiver_hidden,
                        n_dim=2,
                        inner_layers=opts.receiver_layers)
    receiver = core.RnnReceiverDeterministic(
        receiver,
        opts.vocab_size + 1,  # exclude eos = 0
        opts.receiver_emb,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
        num_layers=opts.cell_layers)

    game = core.SenderReceiverRnnReinforce(sender,
                                           receiver,
                                           diff_loss,
                                           receiver_entropy_coeff=0.05,
                                           sender_entropy_coeff=0.0)

    # Only the receiver is optimised.
    optimizer = core.build_optimizer(receiver.parameters())

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=[core.ConsoleLogger(as_json=True, print_train_loss=True)],
        grad_norm=1.0)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
示例#7
0
def main(params):
    """Train a Gumbel-Softmax symbol game on MNIST images via ``SplitImages``.

    Args:
        params: Command-line style parameters, parsed by ``get_params``.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    # Worker / pinned-memory options are only handed to the loaders on CUDA.
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_kwargs = {}
    to_tensor = transforms.ToTensor()

    train_set = datasets.MNIST(
        './data', train=True, download=True, transform=to_tensor)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=opts.batch_size, shuffle=True, **loader_kwargs)

    test_set = datasets.MNIST('./data', train=False, transform=to_tensor)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=16 * 1024, shuffle=False, **loader_kwargs)

    n_classes = 10
    binarize = False

    # The validation loader goes through TakeFirstLoader(n=1) before being
    # wrapped; the training loader is wrapped directly.
    test_loader = SplitImages(
        TakeFirstLoader(test_loader, n=1),
        rows_receiver=opts.receiver_rows,
        rows_sender=opts.sender_rows,
        binarize=binarize,
        receiver_bottom=True)
    train_loader = SplitImages(
        train_loader,
        rows_sender=opts.sender_rows,
        rows_receiver=opts.receiver_rows,
        binarize=binarize,
        receiver_bottom=True)

    sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)
    sender = core.GumbelSoftmaxWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    # Evaluator callback fed with the validation loader and the game's loss.
    intervention = CallbackEvaluator(
        test_loader,
        device=opts.device,
        loss=game.loss,
        is_gs=True,
        var_length=False,
        input_intervention=True)

    callbacks = [
        core.ConsoleLogger(as_json=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
        intervention,
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
示例#8
0
def main(params):
    """Train a symbol game on MNIST after calling ``corrupt_labels_`` on the training set.

    Args:
        params: Command-line style parameters, parsed by ``get_params``.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    # Worker / pinned-memory options are only handed to the loaders on CUDA.
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_kwargs = {}
    to_tensor = transforms.ToTensor()

    train_dataset = datasets.MNIST(
        './data', train=True, download=True, transform=to_tensor)
    test_dataset = datasets.MNIST(
        './data', train=False, download=False, transform=to_tensor)
    n_classes = 10

    # Applied to the training dataset only, with a seed offset by one from
    # the run's random seed.
    corrupt_labels_(
        dataset=train_dataset,
        p_corrupt=opts.p_corrupt,
        seed=opts.random_seed + 1)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opts.batch_size, shuffle=True,
        **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=opts.batch_size, shuffle=False,
        **loader_kwargs)

    # With `deeper` enabled, exactly one agent gets the deeper variant:
    # the sender when `deeper_alice` is 1, otherwise the receiver.
    deeper_alice = opts.deeper_alice == 1 and opts.deeper == 1
    deeper_bob = opts.deeper_alice != 1 and opts.deeper == 1

    linear_channel = opts.linear_channel == 1
    softmax_channel = opts.softmax_non_linearity == 1

    sender = Sender(
        vocab_size=opts.vocab_size,
        deeper=deeper_alice,
        linear_channel=linear_channel,
        softmax_channel=softmax_channel)
    receiver = Receiver(
        vocab_size=opts.vocab_size, n_classes=n_classes, deeper=deeper_bob)

    # The relaxed wrapper is applied only when neither channel option is 1.
    if not (opts.softmax_non_linearity == 1 or opts.linear_channel == 1):
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
示例#9
0
    def __init__(self, data_path: str) -> None:
        """Build the data loaders, agents, game, optimiser and trainer.

        Sizes such as ``batch_size``, ``game_size``, ``vocab_size``,
        ``emb_size``, ``hidden_size`` and ``max_len`` are read from ``self``
        and must already be available when this runs, as must ``self.loss``.

        Args:
            data_path: Path to the dSprites data, passed to
                ``get_dsprites_dataloader``.
        """
        super().__init__()

        loaders = get_dsprites_dataloader(
            batch_size=self.batch_size,
            path_to_data=data_path,
            game_size=self.game_size,
            referential=True,
        )
        self.train_loader, self.test_loader = loaders

        # Sender: CNN core wrapped in a Gumbel-Softmax LSTM message generator.
        sender_core = DspritesSenderCNN(self.hidden_size)
        self.sender = core.RnnSenderGS(
            sender_core,
            self.vocab_size,
            self.emb_size,
            self.hidden_size,
            max_len=self.max_len,
            cell="lstm",
            temperature=1.0,
        )

        # Receiver: CNN core wrapped in a Gumbel-Softmax LSTM message reader.
        receiver_core = DSpritesReceiverCNN(
            self.game_size, self.emb_size, self.hidden_size, reinforce=False)
        self.receiver = core.RnnReceiverGS(
            receiver_core,
            self.vocab_size,
            self.emb_size,
            self.hidden_size,
            cell="lstm",
        )

        self.game = core.SenderReceiverRnnGS(
            self.sender, self.receiver, self.loss)
        self.optimiser = core.build_optimizer(self.game.parameters())

        # Temperature annealing is currently switched off:
        # core.TemperatureUpdater(agent=self.sender, decay=0.9, minimum=0.1)
        self.callbacks = [
            core.ConsoleLogger(as_json=True, print_train_loss=True),
            TopographicSimilarityLatents('euclidean', 'edit'),
        ]

        self.trainer = core.Trainer(
            game=self.game,
            optimizer=self.optimiser,
            train_data=self.train_loader,
            validation_data=self.test_loader,
            callbacks=self.callbacks,
        )
示例#10
0
文件: play.py 项目: wedddy0707/EGG
def main(params):
    """Set up and train an attribute-value discrimination or reconstruction game.

    ``opts.game_type`` selects the task-specific parts (loss, input data and
    Receiver core); ``opts.mode`` selects Gumbel-Softmax ("gs") or Reinforce
    training (any other value). The Sender core is shared by both tasks.

    Args:
        params: Command-line style parameters, parsed by ``get_params``.
    """
    opts = get_params(params)
    # A validation batch size of 0 means "use the training batch size".
    if opts.validation_batch_size == 0:
        opts.validation_batch_size = opts.batch_size
    print(opts, flush=True)

    # --- Task-specific part: loss, data loaders and Receiver core ---------
    # An EGG loss takes the sender input, the message, the receiver input,
    # the receiver output, the labels and the auxiliary input (some may go
    # unused), and returns the loss together with a dictionary of auxiliary
    # statistics -- here, accuracy.
    if opts.game_type == "discri":

        def loss(
            _sender_input,
            _message,
            _receiver_input,
            receiver_output,
            labels,
            _aux_input,
        ):
            # A hit is when the highest-scoring index of the Receiver output
            # (unnormalized scores over target positions) equals the label,
            # i.e. the ground-truth position of the target.
            predicted_position = receiver_output.argmax(dim=1)
            acc = (predicted_position == labels).detach().float()
            # Per-sample cross-entropy between Receiver scores and labels.
            per_sample_loss = F.cross_entropy(
                receiver_output, labels, reduction="none")
            return per_sample_loss, {"acc": acc}

        # The dataset classes (see data_readers.py) read the input text files
        # into the form DataLoader needs. The training dataset is kept in its
        # own variable because the number of input features is read from it;
        # the validation dataset (called "test" for historical reasons) is
        # built inline.
        train_ds = AttValDiscriDataset(
            path=opts.train_data, n_values=opts.n_values)
        train_loader = DataLoader(
            train_ds, batch_size=opts.batch_size, shuffle=True, num_workers=1)
        validation_ds = AttValDiscriDataset(
            path=opts.validation_data, n_values=opts.n_values)
        test_loader = DataLoader(
            validation_ds,
            batch_size=opts.validation_batch_size,
            shuffle=False,
            num_workers=1)

        # Feature count after the inputs are converted to 1-hot vectors.
        n_features = train_ds.get_n_features()
        # Core of the discrimination Receiver (see architectures.py); it is
        # wrapped further below to form the full agent.
        receiver = DiscriReceiver(
            n_features=n_features, n_hidden=opts.receiver_hidden)

    else:  # reco game

        def loss(sender_input, _message, _receiver_input, receiver_output,
                 labels, _aux_input):
            # One cross-entropy score per attribute (Receiver distribution
            # over that attribute's values vs. ground truth), averaged across
            # attributes. A sample counts as correct only when every
            # attribute's top-scoring value is the right one. The rest is
            # tensor reshaping.
            n_attributes = opts.n_attributes
            n_values = opts.n_values
            batch_size = sender_input.size(0)

            # One row of value scores per (sample, attribute) pair.
            receiver_output = receiver_output.view(
                batch_size * n_attributes, n_values)
            receiver_guesses = receiver_output.argmax(dim=1)

            hits = receiver_guesses == labels.view(-1)
            correct_samples = hits.view(batch_size, n_attributes).detach()
            n_correct = torch.sum(correct_samples, dim=-1)
            acc = (n_correct == n_attributes).float()

            labels = labels.view(batch_size * n_attributes)
            per_attribute_loss = F.cross_entropy(
                receiver_output, labels, reduction="none")
            per_sample_loss = per_attribute_loss.view(
                batch_size, -1).mean(dim=1)
            return per_sample_loss, {"acc": acc}

        # See data_readers.py for the AttValRecoDataset reading class.
        train_ds = AttValRecoDataset(
            path=opts.train_data,
            n_attributes=opts.n_attributes,
            n_values=opts.n_values)
        train_loader = DataLoader(
            train_ds, batch_size=opts.batch_size, shuffle=True, num_workers=1)
        validation_ds = AttValRecoDataset(
            path=opts.validation_data,
            n_attributes=opts.n_attributes,
            n_values=opts.n_values)
        test_loader = DataLoader(
            validation_ds,
            batch_size=opts.validation_batch_size,
            shuffle=False,
            num_workers=1)

        # Receiver input and Sender output are 1-hot representations, hence
        # n_attributes * n_values features.
        n_features = opts.n_attributes * opts.n_values
        # Core of the reconstruction Receiver (see architectures.py); it is
        # wrapped further below to form the full agent.
        receiver = RecoReceiver(
            n_features=n_features, n_hidden=opts.receiver_hidden)

    # --- Shared part --------------------------------------------------------
    # The Sender core (see architectures.py) is the same for both games: it
    # maps an input vector to a hidden layer used to initialize the
    # message-producing RNN.
    sender = Sender(n_hidden=opts.sender_hidden, n_features=n_features)

    # Wrap the cores into full agents and tie them, with the loss, into a
    # game. The wrappers depend on the optimization mode.
    if opts.mode.lower() == "gs":
        # Gumbel-Softmax: the Sender wrapper uses the core's hidden layer to
        # initialize the RNN that generates the message.
        sender = core.RnnSenderGS(
            sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            hidden_size=opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            temperature=opts.temperature)
        # The Receiver wrapper maps what the Sender produced at each step to
        # a hidden layer through an RNN and feeds it to the Receiver core
        # (possibly along with other Receiver input) to generate the output.
        receiver = core.RnnReceiverGS(
            receiver,
            vocab_size=opts.vocab_size,
            embed_dim=opts.receiver_embedding,
            hidden_size=opts.receiver_hidden,
            cell=opts.receiver_cell)
        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        # TemperatureUpdater (callbacks.py in the core directory) updates the
        # Gumbel-Softmax temperature after each epoch.
        callbacks = [
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1)
        ]
    else:  # NB: any string other than "gs" leads to Reinforce training!
        # Same cores as above, now in wrappers suited to Reinforce-based
        # optimization.
        sender = core.RnnSenderReinforce(
            sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            hidden_size=opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            vocab_size=opts.vocab_size,
            embed_dim=opts.receiver_embedding,
            hidden_size=opts.receiver_hidden,
            cell=opts.receiver_cell)
        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=0)
        callbacks = []

    optimizer = core.build_optimizer(game.parameters())

    # Explicit equality test (rather than truthiness) so that only option
    # values equal to True enable the detailed validation log.
    log_validation_events = opts.print_validation_events == True  # noqa: E712

    # ConsoleLogger prints loss and accuracy after each training and
    # validation pass. On request, PrintValidationEvents (language_analysis.py
    # in the core directory) also prints a detailed log of the validation
    # pass after full training.
    all_callbacks = list(callbacks)
    all_callbacks.append(
        core.ConsoleLogger(print_train_loss=True, as_json=True))
    if log_validation_events:
        all_callbacks.append(
            core.PrintValidationEvents(n_epochs=opts.n_epochs))

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=all_callbacks)

    # Finally, train.
    trainer.train(n_epochs=opts.n_epochs)
示例#11
0
文件: train.py 项目: wedddy0707/EGG
    )
    validation_loader = ImagenetLoader(
        dataset,
        opt=opts,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
        seed=7,
    )
    game = get_game(opts)
    optimizer = core.build_optimizer(game.parameters())
    callback = None
    if opts.mode == "gs":
        callbacks = [
            core.TemperatureUpdater(agent=game.sender, decay=0.9, minimum=0.1)
        ]
    else:
        callbacks = []

    callbacks.append(core.ConsoleLogger(as_json=True, print_train_loss=True))
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)

    core.close()
        'lr': opts.lr
    } for agent in agents])
    neptune.init(project_qualified_name='anonymous/anonymous',
                 backend=neptune.OfflineBackend())
    with neptune.create_experiment(params=vars(opts),
                                   upload_source_files=get_filepaths(),
                                   tags=['']) as experiment:
        trainer = core.Trainer(
            game=game,
            optimizer=optimizer,
            train_data=train_loader,
            validation_data=test_loader,
            callbacks=[
                CompositionalityMetricObverter(full_dataset,
                                               agents[0],
                                               opts,
                                               opts.vocab_size,
                                               prefix='1_'),
                CompositionalityMetricObverter(full_dataset,
                                               agents[1],
                                               opts,
                                               opts.vocab_size,
                                               prefix='2_'),
                NeptuneMonitor(),
                core.ConsoleLogger(print_train_loss=not opts.on_slurm),
                EarlyStopperAccuracy(threshold=0.99,
                                     field_name='accuracy',
                                     delay=5)
            ])
        trainer.train(n_epochs=opts.n_epochs)
示例#13
0
def main(params):
    """Train a receiver on attribute-value data against a chosen sender.

    Builds train/test ``AttributeValueData`` loaders, picks the sender class
    named by ``opts.language`` and runs a ``SenderReceiverRnnReinforce`` game
    in which only the receiver's parameters are handed to the optimizer.

    Args:
        params: Command-line style arguments forwarded to ``get_params``.

    Raises:
        ValueError: If ``opts.language`` is neither ``'identity'`` nor
            ``'rotated'``.
    """
    opts = get_params(params)
    print(opts)

    n_a, n_v = opts.n_a, opts.n_v
    # The vocabulary is tied to the number of values per attribute.
    opts.vocab_size = n_v

    train_data = AttributeValueData(n_attributes=n_a,
                                    n_values=n_v,
                                    mul=1,
                                    mode='train')
    train_loader = DataLoader(train_data,
                              batch_size=opts.batch_size,
                              shuffle=True)

    test_data = AttributeValueData(n_attributes=n_a,
                                   n_values=n_v,
                                   mul=1,
                                   mode='test')
    test_loader = DataLoader(test_data,
                             batch_size=opts.batch_size,
                             shuffle=False)

    print(f'# Size of train {len(train_data)} test {len(test_data)}')

    if opts.language == 'identity':
        sender = IdentitySender(n_a, n_v)
    elif opts.language == 'rotated':
        sender = RotatedSender(n_a, n_v)
    else:
        # An explicit exception instead of `assert False`: asserts are removed
        # under `python -O`, which would let an unknown language fall through
        # to a confusing NameError on `sender` further down.
        raise ValueError(f'Unknown language, {opts.language}')

    receiver = Receiver(n_hidden=opts.receiver_hidden,
                        n_dim=n_a * n_v,
                        inner_layers=opts.receiver_layers)
    receiver = core.RnnReceiverDeterministic(
        receiver,
        opts.vocab_size + 1,  # exclude eos = 0
        opts.receiver_emb,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
        num_layers=opts.cell_layers)

    diff_loss = DiffLoss(n_a, n_v, loss_type=opts.loss_type)

    game = core.SenderReceiverRnnReinforce(sender,
                                           receiver,
                                           diff_loss,
                                           receiver_entropy_coeff=0.05,
                                           sender_entropy_coeff=0.0)

    # Only the receiver is optimized; the sender's parameters are not passed.
    optimizer = core.build_optimizer(receiver.parameters())

    # Stop as soon as accuracy reaches 1.0; `validation=False` is passed so
    # the stopper is not driven by validation results.
    early_stopper = core.EarlyStopperAccuracy(1.0, validation=False)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True,
                                                  print_train_loss=True),
                               early_stopper
                           ],
                           grad_norm=1.0)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
示例#14
0
文件: train.py 项目: Slowika/EGG
def main(params):
    """Train a sender/receiver pair on one-hot inputs and dump the result.

    Parses options, builds the input prior and data loaders, wraps the
    ``Sender``/``Receiver`` agents in either transformer or RNN wrappers
    (chosen by ``opts.sender_cell`` / ``opts.receiver_cell``), trains a
    ``SenderReceiverRnnReinforce`` game and finally calls ``dump`` on it.

    Args:
        params: Command-line style arguments forwarded to ``get_params``.
    """
    opts = get_params(params)
    print(opts, flush=True)

    # Since https://github.com/facebookresearch/EGG/pull/130 `length` no longer
    # includes the EOS symbol. Subtracting one here keeps the meaning of the
    # existing hyperparameters / command-line arguments unchanged.
    opts.max_len -= 1

    device = opts.device
    n_features = opts.n_features

    # Prior over inputs: a named distribution or an explicit comma-separated
    # list of weights; normalised in place to sum to one.
    prior_spec = opts.probs
    if prior_spec == "uniform":
        prior = np.ones(n_features)
    elif prior_spec == "powerlaw":
        prior = 1 / np.arange(1, n_features + 1, dtype=np.float32)
    else:
        prior = np.array(list(map(float, prior_spec.split(","))), dtype=np.float32)
    prior /= prior.sum()

    print("the probs are: ", prior, flush=True)

    train_loader = OneHotLoader(
        n_features=n_features,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
        probs=prior,
    )

    # single batches with 1s on the diag
    test_loader = UniformLoader(n_features)

    # --- sender -----------------------------------------------------------
    if opts.sender_cell == "transformer":
        sender_agent = Sender(n_features=n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(
            agent=sender_agent,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender,
        )
    else:
        sender_agent = Sender(n_features=n_features, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(
            sender_agent,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
        )

    # --- receiver ---------------------------------------------------------
    if opts.receiver_cell == "transformer":
        receiver_agent = Receiver(
            n_features=n_features, n_hidden=opts.receiver_embedding
        )
        receiver = core.TransformerReceiverDeterministic(
            receiver_agent,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            opts.receiver_num_heads,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver,
        )
    else:
        receiver_agent = Receiver(n_features=n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(
            receiver_agent,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers,
        )

    # --- game and optimizer -------------------------------------------------
    # Training uses the minimal logging strategy.
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        train_logging_strategy=LoggingStrategy.minimal(),
        length_cost=opts.length_cost,
    )
    optimizer = core.build_optimizer(game.parameters())

    # --- callbacks ----------------------------------------------------------
    callbacks = [
        EarlyStopperAccuracy(opts.early_stopping_thr),
        core.ConsoleLogger(as_json=True, print_train_loss=True),
    ]
    if opts.checkpoint_dir:
        # The checkpoint prefix encodes the key hyperparameters of the run.
        prefix = (
            f"{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}"
            f"_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}"
            f"_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}"
            f"_max_len{opts.max_len}"
        )
        saver = core.CheckpointSaver(
            checkpoint_path=opts.checkpoint_dir, prefix=prefix
        )
        callbacks.append(saver)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    # now log everything
    game.logging_strategy = LoggingStrategy.maximal()
    dump(trainer.game, n_features, device, False)
    core.close()
示例#15
0
        game = SenderReceiverRnnGS(sender, receiver, loss)

    optimizer = torch.optim.Adam([{
        'params': game.sender.parameters(),
        'lr': opts.sender_lr
    }, {
        'params': game.receiver.parameters(),
        'lr': opts.receiver_lr
    }])

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=validation_loader,
                           callbacks=[
                               core.ConsoleLogger(print_train_loss=True,
                                                  as_json=True),
                               TopographicSimilarity(
                                   sender_input_distance_fn='cosine',
                                   message_distance_fn='edit',
                                   is_gumbel=True)
                           ])

    if dump_loader is not None:
        print("Printing vocab...")
        lines = []
        for vector, label in dump_loader:
            vector = torch.tensor(vector).cuda()
            with torch.no_grad():
                message, _ = sender(vector)
                message = message.argmax(dim=-1).cpu().numpy()[0]
                message = ' '.join([str(i) for i in message])
示例#16
0
def main(params):
    """Train a sender/receiver pair over a noisy channel and run the probes.

    Builds one-hot data loaders, REINFORCE-trained RNN sender and
    deterministic RNN receiver (both with configurable noise), a ``Channel``
    with replacement probability ``opts.channel_repl_prob``, trains the game
    and then prints the prefix/suffix/replacement tests and a dump, each
    wrapped in its own ``<div>`` section.

    Args:
        params: Command-line style arguments forwarded to ``get_params``.
    """
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    # Input prior: a named distribution or an explicit comma-separated list
    # of weights, normalised to sum to one.
    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')],
                         dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch,
                                probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    #################################
    # define sender (speaker) agent #
    #################################
    sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
    sender = RnnSenderReinforce(sender,
                                opts.vocab_size,
                                opts.sender_embedding,
                                opts.sender_hidden,
                                cell=opts.sender_cell,
                                max_len=opts.max_len,
                                num_layers=opts.sender_num_layers,
                                force_eos=force_eos,
                                noise_loc=opts.sender_noise_loc,
                                noise_scale=opts.sender_noise_scale)

    ####################################
    # define receiver (listener) agent #
    ####################################
    receiver = Receiver(n_features=opts.n_features,
                        n_hidden=opts.receiver_hidden)
    receiver = RnnReceiverDeterministic(receiver,
                                        opts.vocab_size,
                                        opts.receiver_embedding,
                                        opts.receiver_hidden,
                                        cell=opts.receiver_cell,
                                        num_layers=opts.receiver_num_layers,
                                        noise_loc=opts.receiver_noise_loc,
                                        noise_scale=opts.receiver_noise_scale)

    ###################
    # define  channel #
    ###################
    channel = Channel(vocab_size=opts.vocab_size, p=opts.channel_repl_prob)

    game = SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        length_cost=opts.length_cost,
        effective_max_len=opts.effective_max_len,
        channel=channel,
        sender_entropy_common_ratio=opts.sender_entropy_common_ratio)

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        EarlyStopperAccuracy(opts.early_stopping_thr),
        core.ConsoleLogger(as_json=True, print_train_loss=True)
    ]

    if opts.checkpoint_dir:
        # info in checkpoint_name:
        #     - n_features as f
        #     - vocab_size as vocab
        #     - random_seed as rs
        #     - lr as lr
        #     - sender_hidden as shid
        #     - receiver_hidden as rhid
        #     - sender_entropy_coeff as sentr
        #     - length_cost as reg
        #     - max_len as max_len
        #     - sender_noise_scale as sscl
        #     - receiver_noise_scale as rscl
        #     - channel_repl_prob as crp
        #     - sender_entropy_common_ratio as scr
        checkpoint_name = (
            f'{opts.name}' + '_aer' +
            ('_uniform' if opts.probs == 'uniform' else '') +
            f'_f{opts.n_features}' + f'_vocab{opts.vocab_size}' +
            f'_rs{opts.random_seed}' + f'_lr{opts.lr}' +
            f'_shid{opts.sender_hidden}' + f'_rhid{opts.receiver_hidden}' +
            f'_sentr{opts.sender_entropy_coeff}' + f'_reg{opts.length_cost}' +
            f'_max_len{opts.max_len}' + f'_sscl{opts.sender_noise_scale}' +
            f'_rscl{opts.receiver_noise_scale}' +
            f'_crp{opts.channel_repl_prob}' +
            f'_scr{opts.sender_entropy_common_ratio}')
        callbacks.append(
            core.CheckpointSaver(checkpoint_path=opts.checkpoint_dir,
                                 checkpoint_freq=opts.checkpoint_freq,
                                 prefix=checkpoint_name))

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)

    # Each probe is printed inside its own <div> so the sections can be told
    # apart in the captured output; every opened section must be closed.
    print('<div id="prefix test without eos">')
    prefix_test(trainer.game, opts.n_features, device, add_eos=False)
    print('</div>')
    print('<div id="prefix test with eos">')
    prefix_test(trainer.game, opts.n_features, device, add_eos=True)
    # This closing tag was missing, leaving the section above unterminated.
    print('</div>')
    print('<div id="suffix test">')
    suffix_test(trainer.game, opts.n_features, device)
    print('</div>')
    print('<div id="replacement test">')
    replacement_test(trainer.game, opts.n_features, opts.vocab_size, device)
    print('</div>')
    print('<div id="dump">')
    dump(trainer.game, opts.n_features, device, False)
    print('</div>')
    core.close()
示例#17
0
def main(params):
    """Train a two-channel sender with a receiver, then probe the language.

    The function enumerates attribute-value data, splits it into train and
    hold-out sets, trains a ``SenderReceiverRnnReinforce`` game and, if the
    final validation accuracy exceeds 0.99, retrains several fresh receivers
    against a frozen copy of the sender, printing one JSON line per
    (receiver type, seed) pair.

    Args:
        params: Command-line style arguments forwarded to ``get_params``.

    Raises:
        ValueError: If ``opts.receiver_cell`` or ``opts.sender_cell`` is not
            one of ``'lstm'``, ``'rnn'`` or ``'gru'``.
        KeyError: If ``opts.baseline`` is not ``'no'``, ``'mean'`` or
            ``'builtin'``.
    """
    import copy
    opts = get_params(params)
    device = opts.device

    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    # Optionally replace the full enumeration by a sampled subset.
    if opts.density_data > 0:
        sampled_data = select_subset_V2(full_data, opts.density_data,
                                        opts.n_attributes, opts.n_values)
        full_data = copy.deepcopy(sampled_data)

    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    generalization_holdout, train, uniform_holdout, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout, full_data]
    ]

    # Both datasets wrap the same one-hot training examples (the right-hand
    # side is evaluated before `train` is rebound); they differ only in the
    # scaling argument, so "validation" here is the unscaled training set.
    train, validation = ScaledDataset(train, opts.data_scaler), ScaledDataset(
        train, 1)

    generalization_holdout, uniform_holdout, full_data = ScaledDataset(
        generalization_holdout), ScaledDataset(uniform_holdout), ScaledDataset(
            full_data)
    generalization_holdout_loader, uniform_holdout_loader, full_data_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout, full_data]
    ]

    train_loader = DataLoader(train, batch_size=opts.batch_size)
    # The whole validation set is served as a single batch.
    validation_loader = DataLoader(validation, batch_size=len(validation))

    n_dim = opts.n_attributes * opts.n_values

    if opts.receiver_cell in ['lstm', 'rnn', 'gru']:
        receiver = MMReceiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.MMRnnReceiverDeterministic(receiver,
                                                   opts.vocab_size + 1,
                                                   opts.receiver_emb,
                                                   opts.receiver_hidden,
                                                   cell=opts.receiver_cell)
    else:
        raise ValueError(f'Unknown receiver cell, {opts.receiver_cell}')

    if opts.sender_cell in ['lstm', 'rnn', 'gru']:
        # One shared MMSender feeds two RNN senders through SplitWrapper
        # (indices 0 and 1); their outputs are combined into one sender.
        sender = MMSender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        s1 = SplitWrapper(sender, 0)
        s2 = SplitWrapper(sender, 1)
        sender1 = core.RnnSenderReinforce(agent=s1,
                                          vocab_size=opts.vocab_size,
                                          embed_dim=opts.sender_emb,
                                          hidden_size=opts.sender_hidden,
                                          max_len=opts.max_len,
                                          force_eos=False,
                                          cell=opts.sender_cell)
        # NOTE(review): PlusNWrapper(…, 1) presumably shifts emitted symbols
        # by one, matching the receiver's vocab_size + 1 — confirm.
        sender1 = PlusNWrapper(sender1, 1)
        sender2 = core.RnnSenderReinforce(agent=s2,
                                          vocab_size=opts.vocab_size,
                                          embed_dim=opts.sender_emb,
                                          hidden_size=opts.sender_hidden,
                                          max_len=opts.max_len,
                                          force_eos=False,
                                          cell=opts.sender_cell)
        # sender2 = PlusNWrapper(sender2, opts.vocab_size + 1)
        sender2 = PlusNWrapper(sender2, 1)
        sender = CombineMMRnnSenderReinforce(sender1, sender2)
    else:
        raise ValueError(f'Unknown sender cell, {opts.sender_cell}')

    loss = DiffLoss(opts.n_attributes, opts.n_values)

    # Map the command-line baseline name to its class; an unknown name
    # raises KeyError here.
    baseline = {
        'no': core.baselines.NoBaseline,
        'mean': core.baselines.MeanBaseline,
        'builtin': core.baselines.BuiltInBaseline
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline)
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = MMMetrics(validation.examples,
                                  opts.device,
                                  opts.n_attributes,
                                  opts.n_values,
                                  opts.vocab_size + 1,
                                  freq=opts.stats_freq)

    # (name, loader, loss) triples evaluated by the hold-out Evaluator; the
    # names are later used as keys into `holdout_evaluator.results`.
    loaders = []
    loaders.append(("generalization hold out", generalization_holdout_loader,
                    DiffLoss(opts.n_attributes,
                             opts.n_values,
                             generalization=True)))
    loaders.append(("uniform holdout", uniform_holdout_loader,
                    DiffLoss(opts.n_attributes, opts.n_values)))

    # NOTE(review): freq=0 presumably means "evaluate only at the end of
    # training" — confirm against Evaluator.
    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr,
                                         validation=True)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=validation_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True,
                                                  print_train_loss=False),
                               early_stopper, metrics_evaluator,
                               holdout_evaluator
                           ])
    trainer.train(n_epochs=opts.n_epochs)

    # Accuracy recorded for the last validation pass.
    validation_acc = early_stopper.validation_stats[-1][1]['acc']
    # Read but not used below.
    uniformtest_acc = holdout_evaluator.results['uniform holdout']['acc']

    # Train new agents
    if validation_acc > 0.99:

        def _set_seed(seed):
            """Seed Python, NumPy and torch (including CUDA if available)."""
            import random
            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        # NOTE(review): presumably disables preemption handling and
        # checkpointing for the nested trainers below — confirm.
        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            """Train a fresh receiver against `sender`; return per-pass validation accuracies.

            Only the new receiver's parameters are optimized, and training
            runs for at most ``opts.n_epochs // 2`` epochs.
            """
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(sender,
                                                   receiver,
                                                   loss,
                                                   sender_entropy_coeff=0.0,
                                                   receiver_entropy_coeff=0.0)
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr,
                                                 validation=True)

            trainer = core.Trainer(game=game,
                                   optimizer=optimizer,
                                   train_data=train_loader,
                                   validation_data=validation_loader,
                                   callbacks=[
                                       early_stopper,
                                       Evaluator(loaders, opts.device, freq=0)
                                   ])
            trainer.train(n_epochs=opts.n_epochs // 2)

            accs = [x[1]['acc'] for x in early_stopper.validation_stats]
            return accs

        # A deep copy is frozen so the originally trained sender is untouched.
        frozen_sender = Freezer(copy.deepcopy(sender))

        # Factories for the receiver architectures to probe. Each call builds
        # a brand-new, untrained receiver.
        # GRU receiver with the configured hidden size.
        def gru_receiver_generator():            return \
core.MMRnnReceiverDeterministic(MMReceiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb, hidden_size=opts.receiver_hidden, cell='gru')

        # GRU receiver with hidden size 50.
        def small_gru_receiver_generator():            return \
core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb, hidden_size=50, cell='gru')

        # GRU receiver with hidden size 25.
        def tiny_gru_receiver_generator():            return \
core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=25, n_outputs=n_dim),
                opts.vocab_size + 1, opts.receiver_emb, hidden_size=25, cell='gru')

        # Non-recurrent receiver built from MMNonLinearReceiver.
        def nonlinear_receiver_generator():            return \
MMNonLinearReceiver(n_outputs=n_dim, vocab_size=opts.vocab_size + 1,
            max_length=opts.max_len, n_hidden=opts.receiver_hidden)

        for name, receiver_generator in [
            ('gru', gru_receiver_generator),
            ('nonlinear', nonlinear_receiver_generator),
            ('tiny_gru', tiny_gru_receiver_generator),
            ('small_gru', small_gru_receiver_generator),
        ]:

            # Three runs per architecture, with seeds 17, 18 and 19.
            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                # Pad to n_epochs // 2 entries with 1.0 — presumably epochs
                # skipped by early stopping count as perfect accuracy; the
                # "auc" is then the plain sum of the padded curve.
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(
                    json.dumps({
                        "mode": "reset",
                        "seed": seed,
                        "receiver_name": name,
                        "auc": auc
                    }))

    print('---End--')

    core.close()
示例#18
0
def main(params):
    """Train a Gumbel-Softmax sender/receiver game and optionally evaluate it.

    Builds data iterators from ``VectorsLoader``, prints baseline accuracy
    information, trains a ``SenderReceiverRnnGS`` game and, when
    ``opts.evaluate`` is set, reports test accuracy and information measures
    and optionally writes all messages to a file in ``opts.dump_msg_folder``.

    Args:
        params: Command-line style arguments forwarded to ``get_params``.

    Raises:
        NotImplementedError: If ``opts.mode`` is not ``"gs"``
            (case-insensitive).
    """
    opts = get_params(params)

    device = torch.device("cuda" if opts.cuda else "cpu")

    data_loader = VectorsLoader(
        perceptual_dimensions=opts.perceptual_dimensions,
        n_distractors=opts.n_distractors,
        batch_size=opts.batch_size,
        train_samples=opts.train_samples,
        validation_samples=opts.validation_samples,
        test_samples=opts.test_samples,
        shuffle_train_data=opts.shuffle_train_data,
        dump_data_folder=opts.dump_data_folder,
        load_data_path=opts.load_data_path,
        seed=opts.data_seed,
    )

    train_data, validation_data, test_data = data_loader.get_iterators()

    data_loader.upd_cl_options(opts)

    if opts.max_len > 1:
        baseline_msg = 'Cannot yet compute "smart" baseline value for messages of length greater than 1'
    else:
        baseline_msg = (
            f"\n| Baselines measures with {opts.n_distractors} distractors and messages of max_len = {opts.max_len}:\n"
            f"| Dummy random baseline: accuracy = {1 / (opts.n_distractors + 1)}\n"
        )
        if -1 not in opts.perceptual_dimensions:
            baseline_msg += f'| "Smart" baseline with perceptual_dimensions {opts.perceptual_dimensions} = {compute_baseline_accuracy(opts.n_distractors, opts.max_len, *opts.perceptual_dimensions)}\n'
        else:
            baseline_msg += '| Data was loaded froman external file, thus no perceptual_dimension vector was provided, "smart baseline" cannot be computed\n'

    print(baseline_msg)

    sender = Sender(n_features=data_loader.n_features,
                    n_hidden=opts.sender_hidden)

    receiver = Receiver(n_features=data_loader.n_features,
                        linear_units=opts.receiver_hidden)

    # Decide the mode once, case-insensitively, and reuse the flag everywhere.
    # Previously the evaluation path compared `opts.mode == "gs"` exactly, so
    # e.g. `--mode GS` trained a GS game but dumped it as a non-GS one.
    is_gs = opts.mode.lower() == "gs"

    if is_gs:
        sender = core.RnnSenderGS(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            temperature=opts.temperature,
        )

        receiver = core.RnnReceiverGS(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )

        game = core.SenderReceiverRnnGS(sender, receiver, loss)
    else:
        raise NotImplementedError(f"Unknown training mode, {opts.mode}")

    # Sender and receiver get separate learning rates.
    optimizer = torch.optim.Adam([
        {
            "params": game.sender.parameters(),
            "lr": opts.sender_lr
        },
        {
            "params": game.receiver.parameters(),
            "lr": opts.receiver_lr
        },
    ])
    callbacks = [core.ConsoleLogger(as_json=True)]
    if is_gs:
        callbacks.append(
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1))
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_data,
        validation_data=validation_data,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    if opts.evaluate:
        (
            sender_inputs,
            messages,
            receiver_inputs,
            receiver_outputs,
            labels,
        ) = dump_sender_receiver(game,
                                 test_data,
                                 is_gs,
                                 variable_length=True,
                                 device=device)

        receiver_outputs = move_to(receiver_outputs, device)
        labels = move_to(labels, device)

        receiver_outputs = torch.stack(receiver_outputs)
        labels = torch.stack(labels)

        tensor_accuracy = receiver_outputs.argmax(dim=1) == labels
        accuracy = torch.mean(tensor_accuracy.float()).item()

        # Distinct sender inputs, keyed by a "v1-v2-..." rendering of the
        # integer-cast components.
        unique_targets = {
            "-".join(str(int(dim.item())) for dim in elem)
            for elem in sender_inputs
        }

        print(f"| Accuracy on test set: {accuracy}")

        compute_mi_input_msgs(sender_inputs, messages)

        print(f"entropy sender inputs {entropy(sender_inputs)}")
        print(f"mi sender inputs msgs {mutual_info(sender_inputs, messages)}")

        if opts.dump_msg_folder:
            opts.dump_msg_folder.mkdir(exist_ok=True)
            # message string -> number of times the sender produced it
            msg_dict = {}

            output_msg = (
                f"messages_{opts.perceptual_dimensions}_vocab_{opts.vocab_size}"
                f"_maxlen_{opts.max_len}_bsize_{opts.batch_size}"
                f"_n_distractors_{opts.n_distractors}_train_size_{opts.train_samples}"
                f"_valid_size_{opts.validation_samples}_test_size_{opts.test_samples}"
                f"_slr_{opts.sender_lr}_rlr_{opts.receiver_lr}_shidden_{opts.sender_hidden}"
                f"_rhidden_{opts.receiver_hidden}_semb_{opts.sender_embedding}"
                f"_remb_{opts.receiver_embedding}_mode_{opts.mode}"
                f"_scell_{opts.sender_cell}_rcell_{opts.receiver_cell}.msg")

            output_file = opts.dump_msg_folder / output_msg
            with open(output_file, "w") as f:
                f.write(f"{opts}\n")
                for (
                        sender_input,
                        message,
                        receiver_input,
                        receiver_output,
                        label,
                ) in zip(sender_inputs, messages, receiver_inputs,
                         receiver_outputs, labels):
                    sender_input = ",".join(map(str, sender_input.tolist()))
                    message = ",".join(map(str, message.tolist()))
                    distractors_list = receiver_input.tolist()
                    receiver_input = "; ".join([
                        ",".join(map(str, elem)) for elem in distractors_list
                    ])
                    if is_gs:
                        # Reduce the receiver's output vector to the index of
                        # its largest entry before writing it out.
                        receiver_output = receiver_output.argmax()
                    f.write(
                        f"{sender_input} -> {receiver_input} -> {message} -> {receiver_output} (label={label.item()})\n"
                    )

                    msg_dict[message] = msg_dict.get(message, 0) + 1

                # Most frequent messages first.
                sorted_msgs = sorted(msg_dict.items(),
                                     key=operator.itemgetter(1),
                                     reverse=True)
                f.write(
                    f"\nUnique target vectors seen by sender: {len(unique_targets)}\n"
                )
                f.write(
                    f"Unique messages produced by sender: {len(msg_dict.keys())}\n"
                )
                f.write(f"Messagses: 'msg' : msg_count: {str(sorted_msgs)}\n")
                f.write(f"\nAccuracy: {accuracy}")
示例#19
0
文件: train.py 项目: evdcush/EGG
def main(params):
    """Train a sender/receiver pair on data produced by ``OneHotLoader``.

    Depending on ``opts.mode`` (compared case-insensitively) the agents are
    wrapped for Reinforce (``rf``) or Gumbel-Softmax (``gs``) training. The
    resulting game is optimised with Adam, using a separate learning rate for
    the sender and for the receiver.

    Args:
        params: Raw arguments, forwarded unchanged to ``get_params``.

    Raises:
        NotImplementedError: If ``opts.mode`` is neither ``rf`` nor ``gs``.
    """
    opts = get_params(params)

    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch)
    # Unlike the training loader, the validation loader gets a fixed seed.
    test_loader = OneHotLoader(n_features=opts.n_features,
                               batch_size=opts.batch_size,
                               batches_per_epoch=opts.batches_per_epoch,
                               seed=7)

    sender = Sender(n_hidden=opts.sender_hidden, n_features=opts.n_features)
    receiver = Receiver(n_features=opts.n_features,
                        n_hidden=opts.receiver_hidden)

    # Normalise the mode once so every branch test sees the same value.
    mode = opts.mode.lower()
    if mode == 'rf':
        sender = core.RnnSenderReinforce(sender,
                                         opts.vocab_size,
                                         opts.sender_embedding,
                                         opts.sender_hidden,
                                         cell=opts.sender_cell,
                                         max_len=opts.max_len)
        receiver = core.RnnReceiverDeterministic(receiver,
                                                 opts.vocab_size,
                                                 opts.receiver_embedding,
                                                 opts.receiver_hidden,
                                                 cell=opts.receiver_cell)

        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff)
        callbacks = []
    elif mode == 'gs':
        sender = core.RnnSenderGS(sender,
                                  opts.vocab_size,
                                  opts.sender_embedding,
                                  opts.sender_hidden,
                                  cell=opts.sender_cell,
                                  max_len=opts.max_len,
                                  temperature=opts.temperature)

        receiver = core.RnnReceiverGS(receiver,
                                      opts.vocab_size,
                                      opts.receiver_embedding,
                                      opts.receiver_hidden,
                                      cell=opts.receiver_cell)

        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        # Only the Gumbel-Softmax sender has a temperature to update.
        callbacks = [
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1)
        ]
    else:
        raise NotImplementedError(f'Unknown training mode, {opts.mode}')

    # Two parameter groups so that each agent keeps its own learning rate.
    optimizer = torch.optim.Adam([{
        'params': game.sender.parameters(),
        'lr': opts.sender_lr
    }, {
        'params': game.receiver.parameters(),
        'lr': opts.receiver_lr
    }])

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=callbacks +
                           [core.ConsoleLogger(as_json=True)])
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
                                   tags=['buffled_berkeley']) as experiment:
        print(os.environ)
        compositional_game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost)
        # compositional_game = CompositionalGameReinforce(sender, receiver, loss)
        sender_params = [{'params': sender.parameters(), 'lr': opts.sender_lr}]
        receiver_params = [{
            'params': receiver.parameters(),
            'lr': opts.receiver_lr
        }]
        optimizer = torch.optim.Adam(sender_params + receiver_params)
        trainer = core.Trainer(
            game=compositional_game,
            optimizer=optimizer,
            train_data=train_loader,
            validation_data=test_loader,
            callbacks=[
                CompositionalityMetric(full_dataset,
                                       opts,
                                       test_targets=test_targets),
                NeptuneMonitor(train_freq=100, test_freq=1),
                core.ConsoleLogger(
                    print_train_loss=not os.environ.get('SLURM_JOB_NAME')),
            ])
        trainer.train(n_epochs=opts.n_epochs)
示例#21
0
文件: train.py 项目: KIKOU2016/EGG
                                  force_eos=opts.force_eos)

        receiver = core.RnnReceiverGS(receiver,
                                      opts.vocab_size,
                                      opts.receiver_embedding,
                                      opts.receiver_hidden,
                                      cell=opts.receiver_cell)

        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        callback = sender.update_temp(0.9, 0.1)
    else:
        raise NotImplementedError(f'Unknown training mode, {opts.mode}')

    optimizer = torch.optim.Adam([{
        'params': game.sender.parameters(),
        'lr': opts.sender_lr
    }, {
        'params': game.receiver.parameters(),
        'lr': opts.receiver_lr
    }])

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           epoch_callback=callback,
                           callbacks=[core.ConsoleLogger(as_json=True)])
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
示例#22
0
        receiver = core.RnnReceiverGS(receiver,
                                      opts.vocab_size,
                                      opts.receiver_embedding,
                                      opts.receiver_hidden,
                                      cell=opts.receiver_cell)

        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        callbacks = [
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1)
        ]
    else:
        raise NotImplementedError(f'Unknown training mode, {opts.mode}')

    optimizer = torch.optim.Adam([{
        'params': game.sender.parameters(),
        'lr': opts.sender_lr
    }, {
        'params': game.receiver.parameters(),
        'lr': opts.receiver_lr
    }])

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=callbacks +
                           [core.ConsoleLogger(as_json=True)])
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
示例#23
0
def main(params):
    """Train a Reinforce sender/receiver game and dump the trained game.

    Training inputs are drawn by ``OneHotLoader`` according to ``opts.probs``
    (``'uniform'``, ``'powerlaw'`` or a comma-separated list of weights);
    validation uses ``UniformLoader``. Each agent is either a transformer or
    an RNN, selected by ``opts.sender_cell`` / ``opts.receiver_cell``.

    Args:
        params: Raw arguments, forwarded unchanged to ``get_params``.
    """
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    # The option is compared against 1; every sender built below receives
    # this normalised boolean rather than the raw option value.
    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features+1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    # Rescale so the weights sum to one.
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(agent=sender, vocab_size=opts.vocab_size,
                                                 embed_dim=opts.sender_embedding, max_len=opts.max_len,
                                                 num_layers=opts.sender_num_layers, num_heads=opts.sender_num_heads,
                                                 hidden_size=opts.sender_hidden,
                                                 force_eos=force_eos,
                                                 generate_style=opts.sender_generate_style,
                                                 causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(sender,
                                         opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
                                         cell=opts.sender_cell, max_len=opts.max_len,
                                         num_layers=opts.sender_num_layers,
                                         force_eos=force_eos)

    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(receiver, opts.vocab_size, opts.max_len,
                                                         opts.receiver_embedding, opts.receiver_num_heads,
                                                         opts.receiver_hidden, opts.receiver_num_layers,
                                                         causal=opts.causal_receiver)
    else:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(receiver, opts.vocab_size, opts.receiver_embedding,
                                                 opts.receiver_hidden, cell=opts.receiver_cell,
                                                 num_layers=opts.receiver_num_layers)

    empty_logger = LoggingStrategy.minimal()
    game = core.SenderReceiverRnnReinforce(sender, receiver, loss,
                                           sender_entropy_coeff=opts.sender_entropy_coeff,
                                           receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                           train_logging_strategy=empty_logger,
                                           length_cost=opts.length_cost)

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [EarlyStopperAccuracy(opts.early_stopping_thr),
                 core.ConsoleLogger(as_json=True, print_train_loss=True)]

    # Checkpoints are only written when a target directory was configured.
    if opts.checkpoint_dir:
        checkpoint_name = f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}'
        callbacks.append(core.CheckpointSaver(checkpoint_path=opts.checkpoint_dir, prefix=checkpoint_name))

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader, callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)

    # NOTE(review): the game was constructed with `train_logging_strategy=`;
    # confirm that `logging_strategy` is the attribute `dump` actually reads.
    game.logging_strategy = LoggingStrategy.maximal()  # now log everything
    dump(trainer.game, opts.n_features, device, False)
    core.close()
示例#24
0
文件: train.py 项目: wedddy0707/EGG
def main(params):
    """Build the data loader, game and LARC-wrapped SGD optimizer, then train.

    Args:
        params: Raw arguments, forwarded to ``get_opts`` as ``params=``.
    """
    opts = get_opts(params=params)
    print(opts)
    print(
        f"Running a distruted training is set to: {opts.distributed_context.is_distributed}. "
        f"World size is {opts.distributed_context.world_size}\n"
        f"Using imagenet with image size: {opts.image_size}. "
        f"Using batch of size {opts.batch_size} on {opts.distributed_context.world_size} device(s)"
    )

    is_distributed = opts.distributed_context.is_distributed

    # The debugger hook is only honoured for non-distributed runs.
    if not is_distributed and opts.pdb:
        breakpoint()

    loader_config = dict(
        dataset_dir=opts.dataset_dir,
        dataset_name=opts.dataset_name,
        image_size=opts.image_size,
        batch_size=opts.batch_size,
        num_workers=opts.num_workers,
        is_distributed=is_distributed,
        seed=opts.random_seed,
    )
    train_loader = get_dataloader(**loader_config)

    game_config = dict(
        batch_size=opts.batch_size,
        loss_temperature=opts.ntxent_tau,
        vision_encoder_name=opts.model_name,
        output_size=opts.output_size,
        is_distributed=is_distributed,
    )
    game = build_game(**game_config)

    # Parameter groups come from add_weight_decay, called with skip_name="bn".
    param_groups = add_weight_decay(game, opts.weight_decay, skip_name="bn")

    base_optimizer = torch.optim.SGD(param_groups, lr=opts.lr, momentum=0.9)
    larc_optimizer = LARC(
        base_optimizer, trust_coefficient=0.001, clip=False, eps=1e-8
    )
    # The scheduler is attached to the underlying SGD optimizer, not to the
    # LARC wrapper that the trainer steps.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        base_optimizer, T_max=opts.n_epochs
    )

    console_logger = core.ConsoleLogger(as_json=True, print_train_loss=True)
    callbacks = [console_logger, BestStatsTracker(), VisionModelSaver()]
    if is_distributed:
        callbacks += [DistributedSamplerEpochSetter()]

    core.Trainer(
        game=game,
        optimizer=larc_optimizer,
        optimizer_scheduler=scheduler,
        train_data=train_loader,
        callbacks=callbacks,
    ).train(n_epochs=opts.n_epochs)
示例#25
0
文件: train.py 项目: cjlovering/EGG
            length_cost=opts.length_cost,
        )
    else:
        raise NotImplementedError(f"Unknown training mode, {opts.mode}")

    optimizer = torch.optim.RMSprop([
        {
            "params": game.sender.parameters(),
            "lr": opts.sender_lr
        },
        {
            "params": game.receiver.parameters(),
            "lr": opts.receiver_lr
        },
    ])
    callbacks = [core.ConsoleLogger(print_train_loss=True, as_json=True)]
    if opts.mode.lower() == "gs":
        callbacks.append(
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1))
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_data,
        validation_data=validation_data,
        callbacks=callbacks,
    )  # validation_data=validation_data,
    trainer.train(n_epochs=opts.n_epochs)

    if opts.evaluate:
        is_gs = "gs" in opts.mode
        sender_inputs, messages, receiver_inputs, receiver_outputs, labels = core.dump_sender_receiver(
示例#26
0
def main(params):
    """Train a Reinforce sender/receiver game, then probe the frozen sender.

    After the main training run, if the last validation accuracy exceeds 0.99,
    a deep copy of the sender is frozen and several fresh receivers (GRUs of
    different hidden sizes and a ``NonLinearReceiver``) are retrained against
    it under three seeds each. One JSON line with the summed per-epoch
    accuracy is printed for every (receiver, seed) pair.

    Args:
        params: Raw arguments, forwarded unchanged to ``get_params``.

    Raises:
        ValueError: If ``opts.receiver_cell`` or ``opts.sender_cell`` is not
            one of ``lstm``, ``rnn`` or ``gru``.
    """
    import copy
    opts = get_params(params)

    train_loader, validation_loader = get_dsprites_dataloader(
        path_to_data='egg/zoo/data_loaders/data/dsprites.npz',
        batch_size=opts.batch_size,
        subsample=opts.subsample,
        image=False)

    n_dim = opts.n_attributes

    # The receiver is constructed before the sender; keep this order, since
    # module construction presumably consumes the global RNG stream.
    if opts.receiver_cell in ['lstm', 'rnn', 'gru']:
        receiver = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.RnnReceiverDeterministic(receiver,
                                                 opts.vocab_size + 1,
                                                 opts.receiver_emb,
                                                 opts.receiver_hidden,
                                                 cell=opts.receiver_cell)
    else:
        raise ValueError(f'Unknown receiver cell, {opts.receiver_cell}')

    if opts.sender_cell in ['lstm', 'rnn', 'gru']:
        sender = Sender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(agent=sender,
                                         vocab_size=opts.vocab_size,
                                         embed_dim=opts.sender_emb,
                                         hidden_size=opts.sender_hidden,
                                         max_len=opts.max_len,
                                         force_eos=False,
                                         cell=opts.sender_cell)
    else:
        raise ValueError(f'Unknown sender cell, {opts.sender_cell}')

    # The sender uses vocab_size symbols while every receiver is sized for
    # vocab_size + 1 — presumably PlusOneWrapper shifts symbols by one; confirm.
    sender = PlusOneWrapper(sender)
    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        'no': core.baselines.NoBaseline,
        'mean': core.baselines.MeanBaseline,
        'builtin': core.baselines.BuiltInBaseline
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline)
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    # Transpose the batches; only the first element of the first batch
    # is handed to the metrics callback.
    latent_values, _ = zip(*validation_loader)
    latent_values = latent_values[0]

    metrics_evaluator = Metrics(latent_values,
                                opts.device,
                                opts.n_attributes,
                                opts.n_values,
                                opts.vocab_size + 1,
                                freq=opts.stats_freq)

    loaders = []
    loaders.append(("uniform holdout", validation_loader,
                    DiffLoss(opts.n_attributes, opts.n_values)))

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr,
                                         validation=True)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=validation_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True,
                                                  print_train_loss=False),
                               early_stopper, metrics_evaluator,
                               holdout_evaluator
                           ])
    trainer.train(n_epochs=opts.n_epochs)

    last_epoch_interaction = early_stopper.validation_stats[-1][1]
    validation_acc = last_epoch_interaction.aux['acc'].mean()

    # NOTE(review): this value is never used below; the lookup is kept so the
    # function still fails in the same way if the evaluator has no results.
    uniformtest_acc = holdout_evaluator.results['uniform holdout']['acc']

    # Train new agents
    if validation_acc > 0.99:

        def _set_seed(seed):
            """Seed the Python, NumPy and torch (CPU and CUDA) generators."""
            import random
            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            """Train a fresh receiver against ``sender``; return per-epoch accuracies."""
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(sender,
                                                   receiver,
                                                   loss,
                                                   sender_entropy_coeff=0.0,
                                                   receiver_entropy_coeff=0.0)
            # Only the receiver's parameters are handed to the optimizer.
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr,
                                                 validation=True)

            trainer = core.Trainer(game=game,
                                   optimizer=optimizer,
                                   train_data=train_loader,
                                   validation_data=validation_loader,
                                   callbacks=[
                                       early_stopper,
                                       Evaluator(loaders, opts.device, freq=0)
                                   ])
            trainer.train(n_epochs=opts.n_epochs // 2)

            # NOTE(review): the outer code reads the same stats entries via
            # `.aux['acc']`; confirm that `x[1]['acc']` is valid here.
            accs = [x[1]['acc'] for x in early_stopper.validation_stats]
            return accs

        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            """Build a GRU receiver with the configured hidden size."""
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=opts.receiver_hidden,
                cell='gru')

        def small_gru_receiver_generator():
            """Build a GRU receiver with hidden size 100."""
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=100, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=100,
                cell='gru')

        def tiny_gru_receiver_generator():
            """Build a GRU receiver with hidden size 50."""
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=50,
                cell='gru')

        def nonlinear_receiver_generator():
            """Build a ``NonLinearReceiver`` with the configured hidden size."""
            return NonLinearReceiver(n_outputs=n_dim,
                                     vocab_size=opts.vocab_size + 1,
                                     max_length=opts.max_len,
                                     n_hidden=opts.receiver_hidden)

        for name, receiver_generator in [
            ('gru', gru_receiver_generator),
            ('nonlinear', nonlinear_receiver_generator),
            ('tiny_gru', tiny_gru_receiver_generator),
            ('small_gru', small_gru_receiver_generator),
        ]:

            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                # Pad with 1.0 up to n_epochs // 2 entries so that runs of
                # different length sum over the same number of epochs.
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(
                    json.dumps({
                        "mode": "reset",
                        "seed": seed,
                        "receiver_name": name,
                        "auc": auc
                    }))

    print('---End--')

    core.close()
示例#27
0
文件: train.py 项目: wedddy0707/EGG
def main(params):
    """Train a Reinforce sender/receiver game on attribute-value data, then probe the sender.

    The data is split into a generalization holdout, a uniform holdout and a
    training set. After the main run, if the last validation accuracy exceeds
    0.99, a deep copy of the sender is frozen and several fresh receivers are
    retrained against it under three seeds each; one JSON line with the summed
    per-epoch accuracy is printed for every (receiver, seed) pair.

    Args:
        params: Raw arguments, forwarded unchanged to ``get_params``.

    Raises:
        ValueError: If ``opts.receiver_cell`` or ``opts.sender_cell`` is not
            one of ``lstm``, ``rnn`` or ``gru``.
    """
    import copy

    opts = get_params(params)
    device = opts.device

    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    # A positive density replaces the full data with a selected subset.
    if opts.density_data > 0:
        sampled_data = select_subset_V2(
            full_data, opts.density_data, opts.n_attributes, opts.n_values
        )
        full_data = copy.deepcopy(sampled_data)

    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    generalization_holdout, train, uniform_holdout, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout, full_data]
    ]

    # Both datasets wrap the same `train` split; they differ only in the
    # scaling factor passed to ScaledDataset.
    train, validation = ScaledDataset(train, opts.data_scaler), ScaledDataset(train, 1)

    generalization_holdout, uniform_holdout, full_data = (
        ScaledDataset(generalization_holdout),
        ScaledDataset(uniform_holdout),
        ScaledDataset(full_data),
    )
    generalization_holdout_loader, uniform_holdout_loader, full_data_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout, full_data]
    ]

    train_loader = DataLoader(train, batch_size=opts.batch_size)
    # Validation is served as one batch covering the whole dataset.
    validation_loader = DataLoader(validation, batch_size=len(validation))

    n_dim = opts.n_attributes * opts.n_values

    # The receiver is constructed before the sender; module construction
    # presumably consumes the global RNG stream, so the order may matter.
    if opts.receiver_cell in ["lstm", "rnn", "gru"]:
        receiver = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size + 1,
            opts.receiver_emb,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
    else:
        raise ValueError(f"Unknown receiver cell, {opts.receiver_cell}")

    if opts.sender_cell in ["lstm", "rnn", "gru"]:
        sender = Sender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_emb,
            hidden_size=opts.sender_hidden,
            max_len=opts.max_len,
            cell=opts.sender_cell,
        )
    else:
        raise ValueError(f"Unknown sender cell, {opts.sender_cell}")

    # The sender uses vocab_size symbols while every receiver is sized for
    # vocab_size + 1 — presumably PlusOneWrapper shifts symbols by one; confirm.
    sender = PlusOneWrapper(sender)
    loss = DiffLoss(opts.n_attributes, opts.n_values)

    # Map the option string to a baseline class; an unknown key raises KeyError.
    baseline = {
        "no": core.baselines.NoBaseline,
        "mean": core.baselines.MeanBaseline,
        "builtin": core.baselines.BuiltInBaseline,
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = Metrics(
        validation.examples,
        opts.device,
        opts.n_attributes,
        opts.n_values,
        opts.vocab_size + 1,
        freq=opts.stats_freq,
    )

    # Each entry is (name, loader, loss); the generalization entry uses a
    # DiffLoss built with generalization=True.
    loaders = []
    loaders.append(
        (
            "generalization hold out",
            generalization_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values, generalization=True),
        )
    )
    loaders.append(
        (
            "uniform holdout",
            uniform_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values),
        )
    )

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=False),
            early_stopper,
            metrics_evaluator,
            holdout_evaluator,
        ],
    )
    trainer.train(n_epochs=opts.n_epochs)

    # The second element of the last validation_stats entry carries the
    # accuracy values under `.aux["acc"]`.
    last_epoch_interaction = early_stopper.validation_stats[-1][1]
    validation_acc = last_epoch_interaction.aux["acc"].mean()

    # NOTE(review): `uniformtest_acc` (like `device` above) is never used
    # later in this function.
    uniformtest_acc = holdout_evaluator.results["uniform holdout"]["acc"]

    # Train new agents
    if validation_acc > 0.99:

        def _set_seed(seed):
            """Seed the Python, NumPy and torch (CPU and CUDA) generators."""
            import random

            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            """Train a fresh receiver against ``sender``; return per-epoch accuracies."""
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(
                sender,
                receiver,
                loss,
                sender_entropy_coeff=0.0,
                receiver_entropy_coeff=0.0,
            )
            # Only the receiver's parameters are handed to the optimizer.
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(
                opts.early_stopping_thr, validation=True
            )

            trainer = core.Trainer(
                game=game,
                optimizer=optimizer,
                train_data=train_loader,
                validation_data=validation_loader,
                callbacks=[early_stopper, Evaluator(loaders, opts.device, freq=0)],
            )
            # The probe runs for half as many epochs as the main training.
            trainer.train(n_epochs=opts.n_epochs // 2)

            # NOTE(review): the outer code reads the same stats entries via
            # `.aux["acc"]`; confirm that `x[1]["acc"]` is valid here.
            accs = [x[1]["acc"] for x in early_stopper.validation_stats]
            return accs

        # The probes run against a frozen deep copy, not the trained sender itself.
        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            """Build a GRU receiver with the configured hidden size."""
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=opts.receiver_hidden,
                cell="gru",
            )

        def small_gru_receiver_generator():
            """Build a GRU receiver with hidden size 100."""
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=100, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=100,
                cell="gru",
            )

        def tiny_gru_receiver_generator():
            """Build a GRU receiver with hidden size 50."""
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=50,
                cell="gru",
            )

        def nonlinear_receiver_generator():
            """Build a ``NonLinearReceiver`` with the configured hidden size."""
            return NonLinearReceiver(
                n_outputs=n_dim,
                vocab_size=opts.vocab_size + 1,
                max_length=opts.max_len,
                n_hidden=opts.receiver_hidden,
            )

        for name, receiver_generator in [
            ("gru", gru_receiver_generator),
            ("nonlinear", nonlinear_receiver_generator),
            ("tiny_gru", tiny_gru_receiver_generator),
            ("small_gru", small_gru_receiver_generator),
        ]:

            # Three seeds (17, 18, 19) per receiver type.
            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                # Pad with 1.0 up to n_epochs // 2 entries so that runs of
                # different length sum over the same number of epochs.
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(
                    json.dumps(
                        {
                            "mode": "reset",
                            "seed": seed,
                            "receiver_name": name,
                            "auc": auc,
                        }
                    )
                )

    print("---End--")

    core.close()
示例#28
0
def main(params):
    """Build the game selected by the parsed options and train it.

    ``params`` is passed unchanged to ``get_params``. The resulting options
    select one of two setups:

    * ``variable_length`` false: a single-symbol game, with ``opts.mode`` one
      of ``'gs'`` (Gumbel-Softmax sender), ``'rf'`` (REINFORCE sender with a
      deterministic receiver) or ``'non_diff'`` (REINFORCE sender with a
      ``ReinforcedReceiver`` and ``non_diff_loss``).
    * ``variable_length`` true: a ``SenderReceiverRnnReinforce`` game whose
      sender and receiver are each either an RNN cell or a transformer. Any
      mode other than ``'rf'`` is reported and replaced by ``'rf'``.

    Sender and receiver are optimised with Adam using separate learning rates
    (``opts.sender_lr`` and ``opts.receiver_lr``).

    Raises:
        ValueError: if ``variable_length`` is false and ``opts.mode`` is not
            one of ``'gs'``, ``'rf'`` or ``'non_diff'``.
    """
    opts = get_params(params)
    print(opts)

    device = opts.device

    train_loader = OneHotLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.n_examples_per_epoch /
                                opts.batch_size)

    test_loader = UniformLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r)
    # Move every element of the test loader's batch to the configured device.
    test_loader.batch = [x.to(device) for x in test_loader.batch]

    if not opts.variable_length:
        sender = Sender(n_bits=opts.n_bits,
                        n_hidden=opts.sender_hidden,
                        vocab_size=opts.vocab_size)
        if opts.mode == 'gs':
            sender = core.GumbelSoftmaxWrapper(agent=sender,
                                               temperature=opts.temperature)
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver,
                vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)
            game = core.SymbolGameGS(sender, receiver, diff_loss)
        elif opts.mode == 'rf':
            sender = core.ReinforceWrapper(agent=sender)
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver,
                vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)
            receiver = core.ReinforceDeterministicWrapper(agent=receiver)
            game = core.SymbolGameReinforce(
                sender,
                receiver,
                diff_loss,
                sender_entropy_coeff=opts.sender_entropy_coeff)
        elif opts.mode == 'non_diff':
            sender = core.ReinforceWrapper(agent=sender)
            receiver = ReinforcedReceiver(n_bits=opts.n_bits,
                                          n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver,
                vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)

            game = core.SymbolGameReinforce(
                sender,
                receiver,
                non_diff_loss,
                sender_entropy_coeff=opts.sender_entropy_coeff,
                receiver_entropy_coeff=opts.receiver_entropy_coeff)
        else:
            # Fail here with a clear message instead of hitting an unbound
            # `receiver`/`game` further down.
            raise ValueError(
                'Unknown training mode: {!r}'.format(opts.mode))
    else:
        if opts.mode != 'rf':
            print('Only mode=rf is supported atm')
            opts.mode = 'rf'

        # NOTE(review): the Receiver built in either sender branch below is
        # always replaced by the receiver_cell branches before it is used.
        # It is kept because constructing it presumably consumes RNG state,
        # so dropping it could change seeded initialisation -- confirm before
        # removing.
        if opts.sender_cell == 'transformer':
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            sender = Sender(
                n_bits=opts.n_bits,
                n_hidden=opts.sender_hidden,
                vocab_size=opts.sender_hidden)  # TODO: not really vocab
            sender = core.TransformerSenderReinforce(
                agent=sender,
                vocab_size=opts.vocab_size,
                embed_dim=opts.sender_emb,
                max_len=opts.max_len,
                num_layers=1,
                num_heads=1,
                hidden_size=opts.sender_hidden)
        else:
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            sender = Sender(
                n_bits=opts.n_bits,
                n_hidden=opts.sender_hidden,
                vocab_size=opts.sender_hidden)  # TODO: not really vocab
            sender = core.RnnSenderReinforce(agent=sender,
                                             vocab_size=opts.vocab_size,
                                             embed_dim=opts.sender_emb,
                                             hidden_size=opts.sender_hidden,
                                             max_len=opts.max_len,
                                             cell=opts.sender_cell)

        if opts.receiver_cell == 'transformer':
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_emb)
            receiver = core.TransformerReceiverDeterministic(
                receiver,
                opts.vocab_size,
                opts.max_len,
                opts.receiver_emb,
                num_heads=1,
                hidden_size=opts.receiver_hidden,
                num_layers=1)
        else:
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            receiver = core.RnnReceiverDeterministic(receiver,
                                                     opts.vocab_size,
                                                     opts.receiver_emb,
                                                     opts.receiver_hidden,
                                                     cell=opts.receiver_cell)

        # The mode was forced to 'rf' above, so the REINFORCE game is the
        # only game built on the variable-length path.
        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff)

    # Separate parameter groups so each agent gets its own learning rate.
    optimizer = torch.optim.Adam([
        dict(params=sender.parameters(), lr=opts.sender_lr),
        dict(params=receiver.parameters(), lr=opts.receiver_lr)
    ])

    loss = game.loss

    intervention = CallbackEvaluator(test_loader,
                                     device=device,
                                     is_gs=opts.mode == 'gs',
                                     loss=loss,
                                     var_length=opts.variable_length,
                                     input_intervention=True)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True),
                               EarlyStopperAccuracy(opts.early_stopping_thr),
                               intervention
                           ])

    trainer.train(n_epochs=opts.n_epochs)

    core.close()
示例#29
0
def main(params):
    """Build the single-symbol game selected by ``opts.mode`` and train it.

    ``params`` is passed unchanged to ``get_params``; the parsed options are
    printed as JSON. ``opts.mode`` selects the setup:

    * ``'gs'``: Gumbel-Softmax sender, ``SymbolGameGS`` with ``diff_loss``.
    * ``'rf'``: REINFORCE sender, deterministic receiver wrapper,
      ``SymbolGameReinforce`` with ``diff_loss``.
    * ``'non_diff'``: REINFORCE sender, ``ReinforcedReceiver``,
      ``SymbolGameReinforce`` with ``non_diff_loss``.

    Sender and receiver are optimised with Adam using separate learning rates
    (``opts.sender_lr`` and ``opts.receiver_lr``).

    Raises:
        AssertionError: if ``opts.mode`` is none of the modes above.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    device = opts.device

    train_loader = OneHotLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.n_examples_per_epoch /
                                opts.batch_size)

    test_loader = UniformLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r)
    # Move every element of the test loader's batch to the configured device.
    test_loader.batch = [x.to(device) for x in test_loader.batch]

    sender = Sender(n_bits=opts.n_bits,
                    n_hidden=opts.sender_hidden,
                    vocab_size=opts.vocab_size)

    if opts.mode == 'gs':
        receiver = Receiver(n_bits=opts.n_bits,
                            n_hidden=opts.receiver_hidden,
                            vocab_size=opts.vocab_size)
        sender = core.GumbelSoftmaxWrapper(agent=sender,
                                           temperature=opts.temperature)
        game = core.SymbolGameGS(sender, receiver, diff_loss)
    elif opts.mode == 'rf':
        sender = core.ReinforceWrapper(agent=sender)
        receiver = Receiver(n_bits=opts.n_bits,
                            n_hidden=opts.receiver_hidden,
                            vocab_size=opts.vocab_size)
        receiver = core.ReinforceDeterministicWrapper(agent=receiver)
        game = core.SymbolGameReinforce(
            sender,
            receiver,
            diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff)
    elif opts.mode == 'non_diff':
        sender = core.ReinforceWrapper(agent=sender)
        receiver = ReinforcedReceiver(n_bits=opts.n_bits,
                                      n_hidden=opts.receiver_hidden,
                                      vocab_size=opts.vocab_size)
        game = core.SymbolGameReinforce(
            sender,
            receiver,
            non_diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff)

    else:
        # Raised explicitly rather than via `assert False`, which is stripped
        # under `python -O` and would let execution continue with `receiver`
        # and `game` unbound.
        raise AssertionError('Unknown training mode')

    # Separate parameter groups so each agent gets its own learning rate.
    optimizer = torch.optim.Adam([
        dict(params=sender.parameters(), lr=opts.sender_lr),
        dict(params=receiver.parameters(), lr=opts.receiver_lr)
    ])

    loss = game.loss

    intervention = CallbackEvaluator(test_loader,
                                     device=device,
                                     is_gs=opts.mode == 'gs',
                                     loss=loss,
                                     var_length=False,
                                     input_intervention=True)

    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr)

    # NOTE(review): `epoch_callback`/`early_stopping` are passed together with
    # `callbacks` -- confirm the installed `core.Trainer` accepts all three.
    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           epoch_callback=intervention,
                           early_stopping=early_stopper,
                           callbacks=[core.ConsoleLogger(as_json=True)])

    trainer.train(n_epochs=opts.n_epochs)

    core.close()
示例#30
0
def main(params):
    """Train a sender/receiver pair with REINFORCE and print the dumped messages.

    ``params`` is passed unchanged to ``get_params``. Training inputs are drawn
    by ``OneHotLoader`` using a probability vector chosen by ``opts.probs``:

    * ``'uniform'``: all ones.
    * ``'powerlaw'``: ``1 / k`` for ``k = 1 .. n_features``.
    * ``'perso'``: ``n_features + 1 - k`` for ``k = 1 .. n_features``.
    * anything else: parsed as comma-separated floats.

    The vector is normalised to sum to one. Sender and receiver are each
    either a transformer or an RNN cell, chosen by ``opts.sender_cell`` and
    ``opts.receiver_cell``. After training, the result of ``dump`` on the
    trained game is printed.
    """
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    # Normalised once so that both sender implementations below receive the
    # same boolean flag.
    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features+1, dtype=np.float32)
    elif opts.probs == 'perso':
        probs = opts.n_features+1 - np.arange(1, opts.n_features+1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(agent=sender, vocab_size=opts.vocab_size,
                                                 embed_dim=opts.sender_embedding, max_len=opts.max_len,
                                                 num_layers=opts.sender_num_layers, num_heads=opts.sender_num_heads,
                                                 hidden_size=opts.sender_hidden,
                                                 force_eos=force_eos,
                                                 generate_style=opts.sender_generate_style,
                                                 causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)

        sender = core.RnnSenderReinforce(sender,
                                   opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
                                   cell=opts.sender_cell, max_len=opts.max_len, num_layers=opts.sender_num_layers,
                                   force_eos=force_eos)
    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(receiver, opts.vocab_size, opts.max_len,
                                                         opts.receiver_embedding, opts.receiver_num_heads, opts.receiver_hidden,
                                                         opts.receiver_num_layers, causal=opts.causal_receiver)
    else:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(receiver, opts.vocab_size, opts.receiver_embedding,
                                             opts.receiver_hidden, cell=opts.receiver_cell,
                                             num_layers=opts.receiver_num_layers)

    game = core.SenderReceiverRnnReinforce(sender, receiver, loss, sender_entropy_coeff=opts.sender_entropy_coeff,
                                           receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                           length_cost=opts.length_cost)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr),
                                      core.ConsoleLogger(as_json=True, print_train_loss=True)])

    trainer.train(n_epochs=opts.n_epochs)

    all_messages = dump(trainer.game, opts.n_features, device, False)

    print(all_messages)

    core.close()