Example #1
0
def main(params):
    """Train the ``VAE_Game`` on MNIST and dump images through a callback.

    ``params`` is forwarded to ``core.init``; the resulting options provide
    ``cuda``, ``batch_size``, ``vocab_size`` and ``n_epochs``.
    """
    opts = core.init(params=params)

    # Worker / pinned-memory settings are only passed when CUDA is enabled.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    to_tensor = transforms.ToTensor()

    train_set = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=opts.batch_size, shuffle=True, **loader_kwargs)

    test_set = datasets.MNIST('./data', train=False, transform=to_tensor)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=opts.batch_size, shuffle=True, **loader_kwargs)

    # Both agents share the same vocabulary size.
    game = VAE_Game(Sender(opts.vocab_size), Receiver(opts.vocab_size))
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        ImageDumpCallback(test_loader.dataset),
    ]

    # Build the trainer and run it for the requested number of epochs.
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
Example #2
0
def main(params):
    """Train the ``betaVAE_Game`` on the dSprites dataset.

    If the dSprites ``.npz`` archive is not found under ``data/``, the
    project's download script is invoked before the loaders are built.
    """
    opts = core.init(params=params)

    npz_path = os.path.join(
        'data', 'dsprites-dataset', 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
    if not os.path.exists(npz_path):
        import subprocess

        download_script = os.path.join(
            'egg', 'zoo', 'dsprites_bvae', 'data_loaders', 'download_dsprites.sh')
        print('Now download dsprites-dataset')
        subprocess.call([download_script])
        print('Finished')

    train_loader, test_loader = get_dsprites_dataloader(
        path_to_data=npz_path, batch_size=opts.batch_size, image=True)
    image_shape = (64, 64)

    game = betaVAE_Game(VisualSender(), VisualReceiver())
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        ImageDumpCallback(test_loader.dataset, image_shape=image_shape),
        TopographicSimilarity(
            sender_input_distance_fn='euclidean',
            message_distance_fn='euclidean',
            is_gumbel=False,
        ),
        PosDisent(),
    ]

    # Build the trainer and run it for the requested number of epochs.
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
Example #3
0
def main(params):
    """Train a ``SymbolGameGS`` on MNIST loaders wrapped in ``DoubleMnist``.

    The sender is wrapped in ``AlwaysRelaxedWrapper`` only when neither the
    softmax non-linearity nor the linear channel option is switched on.
    """
    opts = get_params(params)
    print(opts)

    # Worker / pinned-memory settings are only passed when CUDA is enabled.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {"num_workers": 1, "pin_memory": True}
    to_tensor = transforms.ToTensor()

    train_dataset = datasets.MNIST("./data", train=True, download=True, transform=to_tensor)
    test_dataset = datasets.MNIST("./data", train=False, download=False, transform=to_tensor)

    n_classes = 10
    # Entry i of the mapping holds i % n_classes, for i in 0..99.
    label_mapping = torch.LongTensor([idx % n_classes for idx in range(100)])

    train_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            train_dataset, batch_size=opts.batch_size, shuffle=True, **loader_kwargs),
        label_mapping,
    )
    test_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            test_dataset, batch_size=16 * 1024, shuffle=False, **loader_kwargs),
        label_mapping,
    )

    sender = Sender(
        vocab_size=opts.vocab_size,
        linear_channel=opts.linear_channel == 1,
        softmax_channel=opts.softmax_non_linearity,
    )
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)

    needs_relaxation = opts.softmax_non_linearity == 0 and opts.linear_channel == 0
    if needs_relaxation:
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
Example #4
0
def main(params):
    """Train a Gumbel-Softmax symbol game on ``DoubleMnist`` with an evaluator callback.

    A ``CallbackEvaluator`` built on the test loader is registered alongside
    the console logger and the accuracy-based early stopper.
    """
    opts = get_params(params)
    print(opts)

    # Entry i of the mapping holds i % opts.n_labels, for i in 0..99.
    label_mapping = torch.LongTensor([idx % opts.n_labels for idx in range(100)])
    print('# label mapping', label_mapping.tolist())

    # Worker / pinned-memory settings are only passed when CUDA is enabled.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    to_tensor = transforms.ToTensor()

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
    train_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            train_dataset, batch_size=opts.batch_size, shuffle=True, **loader_kwargs),
        label_mapping,
    )

    test_dataset = datasets.MNIST('./data', train=False, transform=to_tensor)
    test_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            test_dataset, batch_size=16 * 1024, shuffle=False, **loader_kwargs),
        label_mapping,
    )

    sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(
        vocab_size=opts.vocab_size, n_classes=opts.n_labels, n_hidden=opts.n_hidden)
    sender = core.GumbelSoftmaxWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    intervention = CallbackEvaluator(
        test_loader,
        device=opts.device,
        loss=game.loss,
        is_gs=True,
        var_length=False,
        input_intervention=False,
    )

    callbacks = [
        core.ConsoleLogger(as_json=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
        intervention,
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
Example #5
0
def main(params):
    """Train a one-symbol Gumbel-Softmax Sender/Receiver game on MNIST.

    ``params`` is handed to ``core.init``, which returns the common
    command-line options (batch size, vocabulary size, ...).
    """
    # Initialise the EGG library and fetch the parsed options.
    opts = core.init(params=params)

    # Dataset preparation; extra loader options are only used with CUDA.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {"num_workers": 1, "pin_memory": True}
    to_tensor = transforms.ToTensor()

    train_set = datasets.MNIST("./data", train=True, download=True, transform=to_tensor)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=opts.batch_size, shuffle=True, **loader_kwargs)

    test_set = datasets.MNIST("./data", train=False, transform=to_tensor)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=opts.batch_size, shuffle=True, **loader_kwargs)

    # The raw Sender is wrapped so that it exposes the Gumbel-Softmax interface.
    sender = core.GumbelSoftmaxWrapper(Sender(opts.vocab_size), temperature=1.0)

    receiver = core.SymbolReceiverWrapper(
        Receiver(), vocab_size=opts.vocab_size, agent_input_size=400)

    # Standard Sender/Receiver game with single-symbol communication.
    game = core.SymbolGameGS(sender, receiver, loss)

    # Callback that lowers the sender's GS sampling temperature over training.
    temperature_updater = core.TemperatureUpdater(agent=sender, decay=0.75, minimum=0.01)

    # Optimizer configured from the common command-line parameters.
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        temperature_updater,
        core.ConsoleLogger(as_json=True, print_train_loss=True),
    ]

    # Build the trainer and run it for the requested number of epochs.
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
Example #6
0
def main(params):
    """Train a recurrent receiver against a ``CircleSender`` on ``SphereData``.

    Only the receiver's parameters are handed to the optimizer; the sender
    (optionally preceded by a ``Lenses`` module) is not optimised here.

    Args:
        params: command-line style arguments forwarded to ``get_params``.
    """
    opts = get_params(params)
    print(opts)

    train_data = SphereData(n_points=int(opts.n_points))
    train_loader = DataLoader(train_data,
                              batch_size=opts.batch_size,
                              shuffle=True)

    # Validation runs on its own 1000-point sample. The loader used to wrap
    # ``train_data`` by mistake, which left ``test_data`` unused and made the
    # validation metrics a second pass over the training set.
    test_data = SphereData(n_points=int(1e3))
    test_loader = DataLoader(test_data,
                             batch_size=opts.batch_size,
                             shuffle=False)

    sender = CircleSender(opts.vocab_size)

    assert opts.lenses in [0, 1]
    if opts.lenses == 1:
        # Put the lenses transform in front of the sender.
        sender = torch.nn.Sequential(Lenses(math.pi / 4), sender)

    receiver = Receiver(n_hidden=opts.receiver_hidden,
                        n_dim=2,
                        inner_layers=opts.receiver_layers)
    receiver = core.RnnReceiverDeterministic(
        receiver,
        opts.vocab_size + 1,  # exclude eos = 0
        opts.receiver_emb,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
        num_layers=opts.cell_layers)

    game = core.SenderReceiverRnnReinforce(sender,
                                           receiver,
                                           diff_loss,
                                           receiver_entropy_coeff=0.05,
                                           sender_entropy_coeff=0.0)

    # Only the receiver is trained.
    optimizer = core.build_optimizer(receiver.parameters())

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=[core.ConsoleLogger(as_json=True, print_train_loss=True)],
        grad_norm=1.0)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
Example #7
0
def main(params):
    """Train a ``SymbolGameGS`` on plain MNIST with accuracy-based early stopping.

    The parsed options are printed as JSON. Logging and early stopping are
    configured through keyword arguments passed directly to ``core.Trainer``.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    # Worker / pinned-memory settings are only passed when CUDA is enabled.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    to_tensor = transforms.ToTensor()

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
    test_dataset = datasets.MNIST('./data', train=False, download=False, transform=to_tensor)
    n_classes = 10

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opts.batch_size, shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=opts.batch_size, shuffle=False, **loader_kwargs)

    sender = Sender(
        vocab_size=opts.vocab_size,
        linear_channel=opts.linear_channel == 1,
        softmax_channel=opts.softmax_non_linearity,
    )
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)

    # The relaxed wrapper is applied only when both channel options are 0.
    needs_relaxation = opts.softmax_non_linearity == 0 and opts.linear_channel == 0
    if needs_relaxation:
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        as_json=True,
        early_stopping=early_stopper,
        print_train_loss=True,
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
    def __init__(self, training_log=None) -> None:
        """Assemble loaders, agents, game, callbacks and trainer for the symbolic game.

        Args:
            training_log: path given to ``ConsoleFileLogger``; when ``None``
                the value of ``core.get_opts().training_log_path`` is used.
        """
        super().__init__()

        if training_log is None:
            training_log = core.get_opts().training_log_path
        self.training_log = training_log

        input_dim = self.n_attributes * self.n_values

        loaders = get_symbolic_dataloader(
            n_attributes=self.n_attributes,
            n_values=self.n_values,
            batch_size=self.batch_size,
        )
        self.train_loader, self.test_loader = loaders

        sender_mlp = SymbolicSenderMLP(input_dim=input_dim, hidden_dim=self.hidden_size)
        self.sender = core.RnnSenderGS(
            sender_mlp,
            self.vocab_size,
            self.emb_size,
            self.hidden_size,
            max_len=self.max_len,
            cell="lstm",
            temperature=1.0,
        )

        receiver_mlp = SymbolicReceiverMLP(input_dim, self.hidden_size)
        self.receiver = core.RnnReceiverGS(
            receiver_mlp,
            self.vocab_size,
            self.emb_size,
            self.hidden_size,
            cell="lstm",
        )

        self.game = core.SenderReceiverRnnGS(self.sender, self.receiver, self.loss)
        self.optimiser = core.build_optimizer(self.game.parameters())

        # Callbacks, in order: file/console logging, sender temperature decay,
        # topographic similarity.
        self.callbacks = [
            ConsoleFileLogger(
                as_json=True, print_train_loss=True, logfile_path=self.training_log),
            core.TemperatureUpdater(agent=self.sender, decay=0.9, minimum=0.1),
            TopographicSimilarityLatents(
                'hamming', 'edit', log_path=core.get_opts().topo_path),
        ]

        self.trainer = core.Trainer(
            game=self.game,
            optimizer=self.optimiser,
            train_data=self.train_loader,
            validation_data=self.test_loader,
            callbacks=self.callbacks,
        )
def test_toy_counting_gradient():
    """Train the toy game for 10000 epochs and check ``fc1.weight`` ends up close to all ones."""
    core.init()

    toy_agent = ToyAgent()
    toy_game = ToyGame(toy_agent)
    optimizer = core.build_optimizer(toy_agent.parameters())

    trainer = core.Trainer(toy_game, optimizer, train_data=ToyDataset(), validation_data=None)
    trainer.train(10000)

    learned = toy_agent.fc1.weight
    expected = torch.ones_like(learned)
    # On failure the learned weights are reported as the assertion message.
    assert torch.allclose(learned, expected, rtol=0.05), learned
Example #10
0
def main(params):
    """Train a Gumbel-Softmax symbol game on MNIST images wrapped by ``SplitImages``.

    The test loader is first restricted with ``TakeFirstLoader(..., n=1)``;
    a ``CallbackEvaluator`` with ``input_intervention=True`` is registered
    as a callback.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    # Worker / pinned-memory settings are only passed when CUDA is enabled.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    to_tensor = transforms.ToTensor()

    train_set = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=opts.batch_size, shuffle=True, **loader_kwargs)

    test_set = datasets.MNIST('./data', train=False, transform=to_tensor)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=16 * 1024, shuffle=False, **loader_kwargs)

    n_classes = 10
    binarize = False

    test_loader = SplitImages(
        TakeFirstLoader(test_loader, n=1),
        rows_receiver=opts.receiver_rows,
        rows_sender=opts.sender_rows,
        binarize=binarize,
        receiver_bottom=True,
    )
    train_loader = SplitImages(
        train_loader,
        rows_sender=opts.sender_rows,
        rows_receiver=opts.receiver_rows,
        binarize=binarize,
        receiver_bottom=True,
    )

    sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)
    sender = core.GumbelSoftmaxWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    intervention = CallbackEvaluator(
        test_loader,
        device=opts.device,
        loss=game.loss,
        is_gs=True,
        var_length=False,
        input_intervention=True,
    )

    callbacks = [
        core.ConsoleLogger(as_json=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
        intervention,
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
Example #11
0
def main(params):
    """Train a ``SymbolGameGS`` on MNIST after passing the train set to ``corrupt_labels_``.

    ``opts.deeper`` together with ``opts.deeper_alice`` decides whether the
    ``deeper`` flag is given to the sender or to the receiver.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    # Worker / pinned-memory settings are only passed when CUDA is enabled.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    to_tensor = transforms.ToTensor()

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
    test_dataset = datasets.MNIST('./data', train=False, download=False, transform=to_tensor)
    n_classes = 10

    # Called on the training dataset before the loaders are built.
    corrupt_labels_(
        dataset=train_dataset, p_corrupt=opts.p_corrupt, seed=opts.random_seed + 1)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opts.batch_size, shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=opts.batch_size, shuffle=False, **loader_kwargs)

    go_deeper = opts.deeper == 1
    alice_selected = opts.deeper_alice == 1
    deeper_alice = alice_selected and go_deeper
    deeper_bob = (not alice_selected) and go_deeper

    linear_channel = opts.linear_channel == 1
    softmax_channel = opts.softmax_non_linearity == 1

    sender = Sender(
        vocab_size=opts.vocab_size,
        deeper=deeper_alice,
        linear_channel=linear_channel,
        softmax_channel=softmax_channel,
    )
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes, deeper=deeper_bob)

    # The relaxed wrapper is used only when neither channel option equals 1.
    if not (softmax_channel or linear_channel):
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
Example #12
0
    def __init__(self, data_path:str) -> None:
        """Assemble loaders, agents, game, callbacks and trainer for the dSprites game.

        Args:
            data_path: passed to ``get_dsprites_dataloader`` as ``path_to_data``.
        """
        super().__init__()

        loaders = get_dsprites_dataloader(
            batch_size=self.batch_size,
            path_to_data=data_path,
            game_size=self.game_size,
            referential=True,
        )
        self.train_loader, self.test_loader = loaders

        sender_cnn = DspritesSenderCNN(self.hidden_size)
        self.sender = core.RnnSenderGS(
            sender_cnn,
            self.vocab_size,
            self.emb_size,
            self.hidden_size,
            max_len=self.max_len,
            cell="lstm",
            temperature=1.0,
        )

        receiver_cnn = DSpritesReceiverCNN(
            self.game_size, self.emb_size, self.hidden_size, reinforce=False)
        self.receiver = core.RnnReceiverGS(
            receiver_cnn,
            self.vocab_size,
            self.emb_size,
            self.hidden_size,
            cell="lstm",
        )

        self.game = core.SenderReceiverRnnGS(self.sender, self.receiver, self.loss)
        self.optimiser = core.build_optimizer(self.game.parameters())

        # No temperature-updater callback is registered for this game.
        self.callbacks = [
            core.ConsoleLogger(as_json=True, print_train_loss=True),
            TopographicSimilarityLatents('euclidean', 'edit'),
        ]

        self.trainer = core.Trainer(
            game=self.game,
            optimizer=self.optimiser,
            train_data=self.train_loader,
            validation_data=self.test_loader,
            callbacks=self.callbacks,
        )
def main(params):
    """Build the "position sieve" for trained compositional agents and save it.

    Sender and receiver weights are loaded from ``opts.sender_weights`` and
    ``opts.receiver_weights``. Every combination of attribute values is
    encoded as concatenated one-hot vectors; for each message position the
    position-test dump helper is called, and ``position_sieve[i, p, j]`` is
    set to 1 when the receiver output for input ``i`` matches the value of
    attribute ``j``. Positions after the first ``0`` symbol of each message
    are then set to -1, and the array is written to
    ``analysis/position_sieve.npy``.

    Args:
        params: command-line style arguments forwarded to ``get_params``.
    """
    import os  # local import: only needed to create the output directory

    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    train_loader = OneHotLoader(n_features=opts.n_values,
                                batch_size=opts.batch_size * opts.n_attributes,
                                batches_per_epoch=opts.batches_per_epoch,
                                probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_values)

    ### SENDER ###

    sender = Sender(n_features=opts.n_attributes * opts.n_values, n_hidden=opts.sender_hidden)
    sender = core.RnnSenderReinforce(sender, opts.vocab_size, opts.sender_embedding,
                                     opts.sender_hidden, cell=opts.sender_cell,
                                     max_len=opts.max_len,
                                     num_layers=opts.sender_num_layers,
                                     force_eos=force_eos)

    ### RECEIVER ###

    # NOTE(review): this instance is replaced in both branches below; it is
    # kept so the order of module construction is unchanged — TODO confirm
    # it can be dropped.
    receiver = Receiver(n_features=opts.n_values, n_hidden=opts.receiver_hidden)

    if not opts.impatient:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = RnnReceiverCompositionality(
            receiver, opts.vocab_size, opts.receiver_embedding,
            opts.receiver_hidden, cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers, max_len=opts.max_len,
            n_attributes=opts.n_attributes, n_values=opts.n_values)
    else:
        receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
        # If impatient 1
        receiver = RnnReceiverImpatientCompositionality(
            receiver, opts.vocab_size, opts.receiver_embedding,
            opts.receiver_hidden, cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers, max_len=opts.max_len,
            n_attributes=opts.n_attributes, n_values=opts.n_values)

    # Weights are always mapped to CPU on load.
    sender.load_state_dict(torch.load(opts.sender_weights, map_location=torch.device('cpu')))
    receiver.load_state_dict(torch.load(opts.receiver_weights, map_location=torch.device('cpu')))

    if not opts.impatient:
        game = CompositionalitySenderReceiverRnnReinforce(
            sender, receiver, loss_compositionality,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            n_attributes=opts.n_attributes, n_values=opts.n_values, att_weights=[1],
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost, unigram_penalty=opts.unigram_pen, reg=opts.reg)
    else:
        game = CompositionalitySenderImpatientReceiverRnnReinforce(
            sender, receiver, loss_impatient_compositionality,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            n_attributes=opts.n_attributes, n_values=opts.n_values, att_weights=[1],
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost, unigram_penalty=opts.unigram_pen, reg=opts.reg)

    optimizer = core.build_optimizer(game.parameters())

    trainer = CompoTrainer(n_attributes=opts.n_attributes, n_values=opts.n_values,
                           game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    # Position test.
    # The evaluated inputs do not depend on the position, so they are built
    # once: every combination of attribute values, encoded as concatenated
    # one-hot vectors.
    one_hots = torch.eye(opts.n_values)
    combinations = list(itertools.product(np.arange(opts.n_values),
                                          repeat=opts.n_attributes))
    inputs = torch.stack([
        torch.cat([one_hots[value] for value in combo]) for combo in combinations
    ])

    # One row per combination, i.e. n_values ** n_attributes rows. The array
    # used to be sized n_attributes ** n_values, which only matches when the
    # two options happen to give the same count.
    position_sieve = np.zeros((len(combinations), opts.max_len, opts.n_attributes))

    if opts.impatient:
        dump_position = dump_test_position_impatient_compositionality
    else:
        dump_position = dump_test_position_compositionality

    for position in range(opts.max_len):
        _, _, _, receiver_outputs, _ = dump_position(trainer.game,
                                                     [[inputs, None]],
                                                     position=position,
                                                     voc_size=opts.vocab_size,
                                                     gs=False,
                                                     device=device,
                                                     variable_length=True)

        for i, output in enumerate(receiver_outputs):
            for j, value in enumerate(combinations[i]):
                if output[j] == value:
                    position_sieve[i, position, j] = 1

    # Put -1 for position after message_length
    if not opts.impatient:
        _, messages = dump_compositionality(
            trainer.game, opts.n_attributes, opts.n_values, device, False, 0)
    else:
        _, messages = dump_impatient_compositionality(
            trainer.game, opts.n_attributes, opts.n_values, device, False, 0)

    for i, message in enumerate(messages):
        symbols = message.cpu().numpy()
        eos_positions = np.where(symbols == 0)[0]
        if eos_positions.shape[0] > 0:
            # Everything strictly after the first 0 symbol is marked -1.
            position_sieve[i, eos_positions[0] + 1:opts.max_len] = -1

    # np.save does not create missing directories.
    os.makedirs("analysis", exist_ok=True)
    np.save("analysis/position_sieve.npy", position_sieve)

    core.close()
Example #14
0
def main(params):
    """Train a sender/receiver game one epoch at a time, dumping state each epoch.

    After every epoch the messages and accuracy vector returned by ``dump``
    (or ``dump_impatient`` when ``opts.impatient`` is set) are saved under
    ``opts.dir_save``; agent weights are saved every 50 epochs. The learning
    rate is halved whenever ``epoch % 100 == 0``, which includes epoch 0.

    Args:
        params: command-line style arguments forwarded to ``get_params``.
    """
    import os  # local import: only needed to create the output sub-directories

    print(torch.cuda.is_available())
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)

    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        sender = core.TransformerSenderReinforce(
            agent=sender, vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding, max_len=opts.max_len,
            num_layers=opts.sender_num_layers, num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            force_eos=opts.force_eos,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(
            sender,
            opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
            cell=opts.sender_cell, max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            force_eos=force_eos)

    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(
            receiver, opts.vocab_size, opts.max_len,
            opts.receiver_embedding, opts.receiver_num_heads, opts.receiver_hidden,
            opts.receiver_num_layers, causal=opts.causal_receiver)
    else:
        # NOTE(review): this instance is replaced in both branches below; it
        # is kept so the order of module construction is unchanged — TODO
        # confirm it can be dropped.
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)

        if not opts.impatient:
            receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
            receiver = core.RnnReceiverDeterministic(
                receiver, opts.vocab_size, opts.receiver_embedding,
                opts.receiver_hidden, cell=opts.receiver_cell,
                num_layers=opts.receiver_num_layers)
        else:
            receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
            # If impatient 1
            receiver = RnnReceiverImpatient(
                receiver, opts.vocab_size, opts.receiver_embedding,
                opts.receiver_hidden, cell=opts.receiver_cell,
                num_layers=opts.receiver_num_layers, max_len=opts.max_len,
                n_features=opts.n_features)

    if not opts.impatient:
        game = core.SenderReceiverRnnReinforce(
            sender, receiver, loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost, unigram_penalty=opts.unigram_pen, reg=opts.reg)
    else:
        game = SenderImpatientReceiverRnnReinforce(
            sender, receiver, loss_impatient,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost, unigram_penalty=opts.unigram_pen, reg=opts.reg)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    # torch.save / np.save do not create missing directories, so make sure
    # every output sub-directory exists before training starts.
    for subdir in ("sender", "receiver", "messages", "accuracy"):
        os.makedirs(os.path.join(opts.dir_save, subdir), exist_ok=True)

    for epoch in range(int(opts.n_epochs)):

        print("Epoch: "+str(epoch))

        if epoch % 100 == 0:
            # The optimizer reads the learning rate from its param_groups.
            # Changing only ``defaults`` (as was done before) affects groups
            # added later, not the existing ones, so the decay never took
            # effect. ``defaults`` is still updated to keep both in sync.
            for group in trainer.optimizer.param_groups:
                group["lr"] /= 2
            trainer.optimizer.defaults["lr"] /= 2

        trainer.train(n_epochs=1)
        if opts.checkpoint_dir:
            trainer.save_checkpoint(name=f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}')

        if not opts.impatient:
            acc_vec, messages = dump(trainer.game, opts.n_features, device, False, epoch)
        else:
            acc_vec, messages = dump_impatient(trainer.game, opts.n_features, device, False, epoch)

        # Convert every message to a numpy array so the batch can be saved.
        all_messages = np.asarray([message.cpu().numpy() for message in messages])

        if epoch % 50 == 0:
            torch.save(sender.state_dict(),
                       f"{opts.dir_save}/sender/sender_weights{epoch}.pth")
            torch.save(receiver.state_dict(),
                       f"{opts.dir_save}/receiver/receiver_weights{epoch}.pth")

        np.save(f"{opts.dir_save}/messages/messages_{epoch}.npy", all_messages)
        np.save(f"{opts.dir_save}/accuracy/accuracy_{epoch}.npy", acc_vec)

    core.close()
Example #15
0
File: train.py Project: Slowika/EGG
def main(params):
    """Set up and train the one-hot channel game, then dump the agents' output.

    The function parses the options, builds the training and validation
    loaders, wraps the core ``Sender``/``Receiver`` in either the transformer
    or the RNN wrappers (selected by ``opts.sender_cell`` and
    ``opts.receiver_cell``), trains for ``opts.n_epochs`` epochs and finally
    calls ``dump`` on the trained game.

    Args:
        params: argument list forwarded to ``get_params``.
    """
    opts = get_params(params)
    print(opts, flush=True)

    # For compatibility, after https://github.com/facebookresearch/EGG/pull/130
    # the meaning of `length` changed a bit. Before it included the EOS symbol; now
    # it doesn't. To ensure that hyperparameters/CL arguments do not change,
    # we subtract it here.
    opts.max_len = opts.max_len - 1

    device = opts.device
    n_features = opts.n_features

    # Unnormalised weights over the inputs; they are normalised right below.
    if opts.probs == "powerlaw":
        probs = 1 / np.arange(1, n_features + 1, dtype=np.float32)
    elif opts.probs == "uniform":
        probs = np.ones(n_features)
    else:
        weights = [float(token) for token in opts.probs.split(",")]
        probs = np.array(weights, dtype=np.float32)
    probs = probs / probs.sum()

    print("the probs are: ", probs, flush=True)

    train_loader = OneHotLoader(
        n_features=n_features,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
        probs=probs,
    )

    # single batches with 1s on the diag
    test_loader = UniformLoader(n_features)

    # --- sender: the core agent is sized by the embedding for the transformer
    # wrapper and by the hidden size for the RNN wrapper
    sender_is_transformer = opts.sender_cell == "transformer"
    sender_width = opts.sender_embedding if sender_is_transformer else opts.sender_hidden
    sender = Sender(n_features=n_features, n_hidden=sender_width)
    if sender_is_transformer:
        transformer_sender_kwargs = dict(
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender,
        )
        sender = core.TransformerSenderReinforce(agent=sender, **transformer_sender_kwargs)
    else:
        sender = core.RnnSenderReinforce(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
        )

    # --- receiver: same sizing rule as for the sender
    receiver_is_transformer = opts.receiver_cell == "transformer"
    receiver_width = (
        opts.receiver_embedding if receiver_is_transformer else opts.receiver_hidden
    )
    receiver = Receiver(n_features=n_features, n_hidden=receiver_width)
    if receiver_is_transformer:
        receiver = core.TransformerReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            opts.receiver_num_heads,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver,
        )
    else:
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers,
        )

    # The game is created with the minimal logging strategy for training.
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        train_logging_strategy=LoggingStrategy.minimal(),
        length_cost=opts.length_cost,
    )

    optimizer = core.build_optimizer(game.parameters())

    callbacks = []
    callbacks.append(EarlyStopperAccuracy(opts.early_stopping_thr))
    callbacks.append(core.ConsoleLogger(as_json=True, print_train_loss=True))

    # Checkpoints are only written when a checkpoint directory was given.
    if opts.checkpoint_dir:
        checkpoint_name = f"{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}_lr{opts.lr}_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}_max_len{opts.max_len}"
        saver = core.CheckpointSaver(
            checkpoint_path=opts.checkpoint_dir, prefix=checkpoint_name
        )
        callbacks.append(saver)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    game.logging_strategy = LoggingStrategy.maximal()  # now log everything
    dump(trainer.game, n_features, device, False)
    core.close()
Example #16
0
def main(params):
    """Train the discrimination ("discri") or recognition ("reco") game.

    Depending on ``opts.game_type`` the function picks the loss, the data
    loaders and the core Receiver; the core Sender is shared. The agents are
    then wrapped for Gumbel-Softmax (``opts.mode == "gs"``) or Reinforce (any
    other value) optimisation, and the game is trained for ``opts.n_epochs``
    epochs.

    Args:
        params: argument list forwarded to ``get_params``.
    """
    opts = get_params(params)
    # a validation batch size of 0 means "use the training batch size"
    if opts.validation_batch_size == 0:
        opts.validation_batch_size = opts.batch_size
    print(opts, flush=True)

    # the following if statement controls aspects specific to the two game tasks: loss, input data and architecture of the Receiver
    # (the Sender is identical in both cases, mapping a single input attribute-value vector to a variable-length message)
    if opts.game_type == "discri":
        # the game object we will encounter below takes as one of its mandatory arguments a loss: a loss in EGG is expected to take as arguments the sender input,
        # the message, the Receiver input, the Receiver output and the labels (although some of these elements might not actually be used by a particular loss);
        # together with the actual loss computation, the loss function can return a dictionary with other auxiliary statistics: in this case, accuracy
        def loss(
            _sender_input,
            _message,
            _receiver_input,
            receiver_output,
            labels,
            _aux_input,
        ):
            # in the discriminative case, accuracy is computed by comparing the index with highest score in Receiver output (a distribution of unnormalized
            # probabilities over target positions) and the corresponding label read from input, indicating the ground-truth position of the target
            acc = (receiver_output.argmax(dim=1) == labels).detach().float()
            # similarly, the loss computes cross-entropy between the Receiver-produced target-position probability distribution and the labels
            loss = F.cross_entropy(receiver_output, labels, reduction="none")
            return loss, {"acc": acc}

        # the input data are read into DataLoader objects, which are pytorch constructs implementing standard data processing functionalities, such as shuffling
        # and batching
        # within our games, we implement dataset classes, such as AttValDiscriDataset, to read the input text files and convert the information they contain
        # into the form required by DataLoader
        # look at the definition of the AttValDiscriDataset (the class to read discrimination game data) in data_readers.py for further details
        # note that, for the training dataset, we first instantiate the AttValDiscriDataset object and then feed it to DataLoader, whereas for the
        # validation data (confusingly called "test" data due to code heritage inertia) we directly declare the AttValDiscriDataset when instantiating
        # DataLoader: the reason for this difference is that we need the train_ds object to retrieve the number of features of the input vectors
        train_ds = AttValDiscriDataset(path=opts.train_data,
                                       n_values=opts.n_values)
        train_loader = DataLoader(train_ds,
                                  batch_size=opts.batch_size,
                                  shuffle=True,
                                  num_workers=1)
        test_loader = DataLoader(
            AttValDiscriDataset(path=opts.validation_data,
                                n_values=opts.n_values),
            batch_size=opts.validation_batch_size,
            shuffle=False,
            num_workers=1,
        )
        # note that the number of features retrieved here concerns inputs after they are converted to 1-hot vectors
        n_features = train_ds.get_n_features()
        # we define here the core of the Receiver for the discriminative game, see the architectures.py file for details:
        # note that this will be embedded in a wrapper below to define the full agent
        receiver = DiscriReceiver(n_features=n_features,
                                  n_hidden=opts.receiver_hidden)

    else:  # reco game

        def loss(sender_input, _message, _receiver_input, receiver_output,
                 labels, _aux_input):
            # in the case of the recognition game, for each attribute we compute a different cross-entropy score
            # based on comparing the probability distribution produced by the Receiver over the values of each attribute
            # with the corresponding ground truth, and then averaging across attributes
            # accuracy is instead computed by considering as a hit only cases where, for each attribute, the Receiver
            # assigned the largest probability to the correct value
            # most of this function consists of the usual pytorch madness needed to reshape tensors in order to perform these computations
            n_attributes = opts.n_attributes
            n_values = opts.n_values
            batch_size = sender_input.size(0)
            receiver_output = receiver_output.view(batch_size * n_attributes,
                                                   n_values)
            receiver_guesses = receiver_output.argmax(dim=1)
            correct_samples = ((receiver_guesses == labels.view(-1)).view(
                batch_size, n_attributes).detach())
            acc = (torch.sum(correct_samples, dim=-1) == n_attributes).float()
            labels = labels.view(batch_size * n_attributes)
            loss = F.cross_entropy(receiver_output, labels, reduction="none")
            loss = loss.view(batch_size, -1).mean(dim=1)
            return loss, {"acc": acc}

        # again, see data_readers.py in this directory for the AttValRecoDataset data reading class
        train_loader = DataLoader(
            AttValRecoDataset(
                path=opts.train_data,
                n_attributes=opts.n_attributes,
                n_values=opts.n_values,
            ),
            batch_size=opts.batch_size,
            shuffle=True,
            num_workers=1,
        )
        test_loader = DataLoader(
            AttValRecoDataset(
                path=opts.validation_data,
                n_attributes=opts.n_attributes,
                n_values=opts.n_values,
            ),
            batch_size=opts.validation_batch_size,
            shuffle=False,
            num_workers=1,
        )
        # the number of features for the Receiver (input) and the Sender (output) is given by n_attributes*n_values because
        # they are fed/produce 1-hot representations of the input vectors
        n_features = opts.n_attributes * opts.n_values
        # we define here the core of the receiver for the recognition game, see the architectures.py file for details
        # this will be embedded in a wrapper below to define the full architecture
        receiver = RecoReceiver(n_features=n_features,
                                n_hidden=opts.receiver_hidden)

    # we are now outside the block that defined game-type-specific aspects of the games: note that the core Sender architecture
    # (see architectures.py for details) is shared by the two games (it maps an input vector to a hidden layer that will be used to initialize
    # the message-producing RNN): this will also be embedded in a wrapper below to define the full architecture
    sender = Sender(n_hidden=opts.sender_hidden, n_features=n_features)

    # now, we instantiate the full sender and receiver architectures, and connect them and the loss into a game object
    # the implementation differs slightly depending on whether communication is optimized via Gumbel-Softmax ('gs') or Reinforce ('rf', default)
    if opts.mode.lower() == "gs":
        # in the following lines, we embed the Sender and Receiver architectures into standard EGG wrappers that are appropriate for Gumbel-Softmax optimization
        # the Sender wrapper takes the hidden layer produced by the core agent architecture we defined above when processing input, and uses it to initialize
        # the RNN that generates the message
        sender = core.RnnSenderGS(
            sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            hidden_size=opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            temperature=opts.temperature,
        )
        # the Receiver wrapper takes the symbol produced by the Sender at each step (more precisely, in Gumbel-Softmax mode, a function of the overall probability
        # of non-eos symbols up to the step is used), maps it to a hidden layer through a RNN, and feeds this hidden layer to the
        # core Receiver architecture we defined above (possibly with other Receiver input, as determined by the core architecture) to generate the output
        receiver = core.RnnReceiverGS(
            receiver,
            vocab_size=opts.vocab_size,
            embed_dim=opts.receiver_embedding,
            hidden_size=opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        # callback functions can be passed to the trainer object (see below) to operate at certain steps of training and validation
        # for example, the TemperatureUpdater (defined in callbacks.py in the core directory) will update the Gumbel-Softmax temperature hyperparameter
        # after each epoch
        callbacks = [
            core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1)
        ]
    else:  # NB: any other string than gs will lead to rf training!
        # here, the interesting thing to note is that we use the same core architectures we defined above, but now we embed them in wrappers that are suited to
        # Reinforce-based optimization
        sender = core.RnnSenderReinforce(
            sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            hidden_size=opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
        )
        receiver = core.RnnReceiverDeterministic(
            receiver,
            vocab_size=opts.vocab_size,
            embed_dim=opts.receiver_embedding,
            hidden_size=opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=0,
        )
        callbacks = []

    # we are almost ready to train: we define here an optimizer calling standard pytorch functionality
    optimizer = core.build_optimizer(game.parameters())

    # we always add a callback that will print loss and accuracy after each training and validation pass (see ConsoleLogger in callbacks.py in core directory)
    callbacks.append(core.ConsoleLogger(print_train_loss=True, as_json=True))
    # if requested by the user, we will also print a detailed log of the validation pass after full training: look at PrintValidationEvents in
    # language_analysis.py (core directory)
    # NOTE(review): the flag is tested for truthiness; presumably it is a boolean command-line switch -- confirm in get_params
    if opts.print_validation_events:
        callbacks.append(core.PrintValidationEvents(n_epochs=opts.n_epochs))

    # in the following statement, we finally instantiate the trainer object with all the components we defined (the game, the optimizer, the data
    # and the callbacks)
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    # and finally we train!
    trainer.train(n_epochs=opts.n_epochs)
from recon_game import SymbolicReconGame
from refer_game import SymbolicReferGame
import torch
import egg.core as core
from egg.core.util import move_to
import numpy as np
from utils import create_dir_for_file

# ================ First play the refer game, get the sender ==================
ref_game = SymbolicReferGame()
ref_game.train(10000)  # the argument is the number of training epochs.

# The refer game is switched to eval mode: from here on its sender is only
# used to produce messages for the recon receiver.
ref_game.game.eval()
# NOTE(review): the '~' in the log path is not expanded here; confirm that
# SymbolicReconGame expands it (e.g. via os.path.expanduser).
recon_game = SymbolicReconGame(
    training_log='~/GitWS/GameBias/log/recon_train_temp.txt')
# Only the recon receiver's parameters are handed to the optimizer.
optimizer = core.build_optimizer(recon_game.game.receiver.parameters())
train_loss = []
test_loss = []

for i in range(5000):
    # per-batch mean losses of the current epoch (the name is historical)
    acc_list = []
    for batch_idx, (target, label) in enumerate(recon_game.train_loader):
        optimizer.zero_grad()
        target = move_to(target, recon_game.trainer.device)
        label = move_to(label, recon_game.trainer.device)

        msg, _ = ref_game.sender(target)
        # the message is detached so no gradient flows back into the sender
        rec_out = recon_game.receiver(msg.detach())
        loss, _ = recon_game.loss(target, msg, msg, rec_out, label)
        acc_list.append(loss.mean().item())
        loss.sum().backward()
        # apply the gradients; without this step the receiver is never updated
        optimizer.step()
Example #18
0
def main(params):
    """Train a Receiver to decode messages from a fixed, hand-coded Sender.

    The Sender is chosen by ``opts.language`` (``'identity'`` or
    ``'rotated'``); only the Receiver's parameters are given to the
    optimizer. Training runs for ``opts.n_epochs`` epochs with an early
    stopper on accuracy.

    Args:
        params: argument list forwarded to ``get_params``.

    Raises:
        ValueError: if ``opts.language`` is neither ``'identity'`` nor
            ``'rotated'``.
    """
    opts = get_params(params)
    print(opts)

    n_a, n_v = opts.n_a, opts.n_v
    # the vocabulary size is tied to the number of values per attribute
    opts.vocab_size = n_v

    train_data = AttributeValueData(n_attributes=n_a,
                                    n_values=n_v,
                                    mul=1,
                                    mode='train')
    train_loader = DataLoader(train_data,
                              batch_size=opts.batch_size,
                              shuffle=True)

    test_data = AttributeValueData(n_attributes=n_a,
                                   n_values=n_v,
                                   mul=1,
                                   mode='test')
    test_loader = DataLoader(test_data,
                             batch_size=opts.batch_size,
                             shuffle=False)

    print(f'# Size of train {len(train_data)} test {len(test_data)}')

    if opts.language == 'identity':
        sender = IdentitySender(n_a, n_v)
    elif opts.language == 'rotated':
        sender = RotatedSender(n_a, n_v)
    else:
        # An explicit exception instead of `assert False`: asserts are removed
        # under `python -O`, which would turn this into a NameError further down.
        raise ValueError(
            f"unknown language {opts.language!r}; expected 'identity' or 'rotated'")

    receiver = Receiver(n_hidden=opts.receiver_hidden,
                        n_dim=n_a * n_v,
                        inner_layers=opts.receiver_layers)
    receiver = core.RnnReceiverDeterministic(
        receiver,
        opts.vocab_size + 1,  # exclude eos = 0
        opts.receiver_emb,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
        num_layers=opts.cell_layers)

    diff_loss = DiffLoss(n_a, n_v, loss_type=opts.loss_type)

    game = core.SenderReceiverRnnReinforce(sender,
                                           receiver,
                                           diff_loss,
                                           receiver_entropy_coeff=0.05,
                                           sender_entropy_coeff=0.0)

    # only the receiver is optimised
    optimizer = core.build_optimizer(receiver.parameters())

    # stop once the accuracy threshold of 1.0 is reached (validation=False)
    early_stopper = core.EarlyStopperAccuracy(1.0, validation=False)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True,
                                                  print_train_loss=True),
                               early_stopper
                           ],
                           grad_norm=1.0)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
Example #19
0
def main(params):
    """Load saved sender/receiver weights, dump their messages and save them.

    The agents are rebuilt with the architecture selected by the options,
    their weights are loaded from ``opts.sender_weights`` and
    ``opts.receiver_weights``, and the messages returned by ``dump`` (or
    ``dump_impatient`` when ``opts.impatient`` is set) are written, padded
    with -1, to ``opts.save_dir + "messages_analysis.npy"``. No training is
    performed.

    Args:
        params: argument list forwarded to ``get_params``.
    """
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features+1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')], dtype=np.float32)
    probs /= probs.sum()

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_embedding)
        # force_eos is passed as the same boolean the RNN branch uses
        sender = core.TransformerSenderReinforce(agent=sender, vocab_size=opts.vocab_size,
                                                 embed_dim=opts.sender_embedding, max_len=opts.max_len,
                                                 num_layers=opts.sender_num_layers, num_heads=opts.sender_num_heads,
                                                 hidden_size=opts.sender_hidden,
                                                 force_eos=force_eos,
                                                 generate_style=opts.sender_generate_style,
                                                 causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)

        sender = core.RnnSenderReinforce(sender,
                                   opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
                                   cell=opts.sender_cell, max_len=opts.max_len, num_layers=opts.sender_num_layers,
                                   force_eos=force_eos)

    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(receiver, opts.vocab_size, opts.max_len,
                                                         opts.receiver_embedding, opts.receiver_num_heads, opts.receiver_hidden,
                                                         opts.receiver_num_layers, causal=opts.causal_receiver)
    elif not opts.impatient:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(receiver, opts.vocab_size, opts.receiver_embedding,
                                                 opts.receiver_hidden, cell=opts.receiver_cell,
                                                 num_layers=opts.receiver_num_layers)
    else:
        # The impatient receiver wraps a core agent sized (receiver_hidden -> vocab_size).
        # NOTE(review): the original kept a commented-out RnnReceiverImpatient2
        # call with the same arguments as an alternative ("impatient 2").
        receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
        receiver = RnnReceiverImpatient(receiver, opts.vocab_size, opts.receiver_embedding,
                                        opts.receiver_hidden, cell=opts.receiver_cell,
                                        num_layers=opts.receiver_num_layers, max_len=opts.max_len,
                                        n_features=opts.n_features)

    # Weights are always mapped to CPU when loaded.
    sender.load_state_dict(torch.load(opts.sender_weights,map_location=torch.device('cpu')))
    receiver.load_state_dict(torch.load(opts.receiver_weights,map_location=torch.device('cpu')))

    if not opts.impatient:
        game = core.SenderReceiverRnnReinforce(sender, receiver, loss, sender_entropy_coeff=opts.sender_entropy_coeff,
                                           receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                           length_cost=opts.length_cost,unigram_penalty=opts.unigram_pen)
    else:
        game = SenderImpatientReceiverRnnReinforce(sender, receiver, loss, sender_entropy_coeff=opts.sender_entropy_coeff,
                                           receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                           length_cost=opts.length_cost,unigram_penalty=opts.unigram_pen)

    optimizer = core.build_optimizer(game.parameters())

    # The trainer is only built to obtain `trainer.game`; train() is never called.
    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader, callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    # Test impose message

    if not opts.impatient:
        _, messages = dump(trainer.game, opts.n_features, device, False)
    else:
        _, messages = dump_impatient(trainer.game, opts.n_features, device, False,save_dir=opts.save_dir)

    # Pad every message with -1 into one dense (n_features, max_len) array.
    # Each message is copied on its own rather than via np.asarray on the whole
    # list, which fails on recent NumPy when the messages differ in length.
    padded = np.full((opts.n_features, opts.max_len), -1.0)
    for row, message in enumerate(messages):
        symbols = message.cpu().numpy()
        padded[row, :symbols.shape[0]] = symbols

    np.save(opts.save_dir+"messages_analysis.npy", padded)

    core.close()
Example #20
0
def main(params):
    """Train the channel game with noisy agents and a noisy channel, then run the tests.

    After training for ``opts.n_epochs`` epochs the function runs the prefix,
    suffix and replacement tests and ``dump``, printing each result inside
    its own ``<div id="...">`` ... ``</div>`` section on stdout.

    Args:
        params: argument list forwarded to ``get_params``.
    """
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')],
                         dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch,
                                probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    #################################
    # define sender (speaker) agent #
    #################################
    sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
    sender = RnnSenderReinforce(sender,
                                opts.vocab_size,
                                opts.sender_embedding,
                                opts.sender_hidden,
                                cell=opts.sender_cell,
                                max_len=opts.max_len,
                                num_layers=opts.sender_num_layers,
                                force_eos=force_eos,
                                noise_loc=opts.sender_noise_loc,
                                noise_scale=opts.sender_noise_scale)

    ####################################
    # define receiver (listener) agent #
    ####################################
    receiver = Receiver(n_features=opts.n_features,
                        n_hidden=opts.receiver_hidden)
    receiver = RnnReceiverDeterministic(receiver,
                                        opts.vocab_size,
                                        opts.receiver_embedding,
                                        opts.receiver_hidden,
                                        cell=opts.receiver_cell,
                                        num_layers=opts.receiver_num_layers,
                                        noise_loc=opts.receiver_noise_loc,
                                        noise_scale=opts.receiver_noise_scale)

    ###################
    # define  channel #
    ###################
    channel = Channel(vocab_size=opts.vocab_size, p=opts.channel_repl_prob)

    game = SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        length_cost=opts.length_cost,
        effective_max_len=opts.effective_max_len,
        channel=channel,
        sender_entropy_common_ratio=opts.sender_entropy_common_ratio)

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        EarlyStopperAccuracy(opts.early_stopping_thr),
        core.ConsoleLogger(as_json=True, print_train_loss=True)
    ]

    if opts.checkpoint_dir:
        # info in checkpoint_name:
        #     - n_features as f
        #     - vocab_size as vocab
        #     - random_seed as rs
        #     - lr as lr
        #     - sender_hidden as shid
        #     - receiver_hidden as rhid
        #     - sender_entropy_coeff as sentr
        #     - length_cost as reg
        #     - max_len as max_len
        #     - sender_noise_scale as sscl
        #     - receiver_noise_scale as rscl
        #     - channel_repl_prob as crp
        #     - sender_entropy_common_ratio as scr
        checkpoint_name = (
            f'{opts.name}' + '_aer' +
            ('_uniform' if opts.probs == 'uniform' else '') +
            f'_f{opts.n_features}' + f'_vocab{opts.vocab_size}' +
            f'_rs{opts.random_seed}' + f'_lr{opts.lr}' +
            f'_shid{opts.sender_hidden}' + f'_rhid{opts.receiver_hidden}' +
            f'_sentr{opts.sender_entropy_coeff}' + f'_reg{opts.length_cost}' +
            f'_max_len{opts.max_len}' + f'_sscl{opts.sender_noise_scale}' +
            f'_rscl{opts.receiver_noise_scale}' +
            f'_crp{opts.channel_repl_prob}' +
            f'_scr{opts.sender_entropy_common_ratio}')
        callbacks.append(
            core.CheckpointSaver(checkpoint_path=opts.checkpoint_dir,
                                 checkpoint_freq=opts.checkpoint_freq,
                                 prefix=checkpoint_name))

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)

    # Each post-training test prints inside its own <div> ... </div> section.
    print('<div id="prefix test without eos">')
    prefix_test(trainer.game, opts.n_features, device, add_eos=False)
    print('</div>')
    print('<div id="prefix test with eos">')
    prefix_test(trainer.game, opts.n_features, device, add_eos=True)
    # this closing tag was missing, leaving the section unterminated
    print('</div>')
    print('<div id="suffix test">')
    suffix_test(trainer.game, opts.n_features, device)
    print('</div>')
    print('<div id="replacement test">')
    replacement_test(trainer.game, opts.n_features, opts.vocab_size, device)
    print('</div>')
    print('<div id="dump">')
    dump(trainer.game, opts.n_features, device, False)
    print('</div>')
    core.close()
def main(params):
    """Train the compositionality game epoch by epoch, dumping results as it goes.

    After every epoch the messages and the accuracy array returned by the dump
    helper are written under ``opts.dir_save``; the sender and receiver
    weights are written every 50 epochs (starting with epoch 0).

    Args:
        params: Argument list forwarded to ``get_params``.

    Raises:
        ValueError: If ``opts.probs`` or ``opts.probs_attributes`` names an
            unsupported distribution.
    """
    import os  # local import: only needed to prepare the output directories

    print(torch.cuda.is_available())
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    # Distribution over values: one probability vector per attribute.
    if opts.probs == "uniform":
        uniform = np.ones(opts.n_values)
        uniform /= uniform.sum()
        # Every attribute shares the very same array object.
        probs = [uniform] * opts.n_attributes
    elif opts.probs == "entropy_test":
        probs = []
        for i in range(opts.n_attributes):
            probs_by_att = np.ones(opts.n_values)
            # The first value gets weight 1 + i, so it dominates more and more
            # as the attribute index grows.
            probs_by_att[0] = 1 + i
            probs_by_att /= probs_by_att.sum()
            probs.append(probs_by_att)
    else:
        # Previously this fell through and crashed later with a NameError.
        raise ValueError(
            f"Unsupported opts.probs {opts.probs!r}; "
            "expected 'uniform' or 'entropy_test'")

    # Per-attribute probabilities handed to the training loader.
    if opts.probs_attributes == "uniform":
        probs_attributes = [1] * opts.n_attributes
    elif opts.probs_attributes == "uniform_indep":
        probs_attributes = [0.2] * opts.n_attributes
    elif opts.probs_attributes == "echelon":
        # Hard-coded schedule: always four entries, whatever n_attributes is.
        probs_attributes = [1., 0.95, 0.9, 0.85]
    else:
        raise ValueError(
            f"Unsupported opts.probs_attributes {opts.probs_attributes!r}; "
            "expected 'uniform', 'uniform_indep' or 'echelon'")

    print("Probability by attribute is:", probs_attributes)

    train_loader = OneHotLoaderCompositionality(
        n_values=opts.n_values,
        n_attributes=opts.n_attributes,
        batch_size=opts.batch_size * opts.n_attributes,
        batches_per_epoch=opts.batches_per_epoch,
        probs=probs,
        probs_attributes=probs_attributes)

    # single batches with 1s on the diag
    test_loader = TestLoaderCompositionality(n_values=opts.n_values,
                                             n_attributes=opts.n_attributes)

    ### SENDER ###

    sender = Sender(n_features=opts.n_attributes * opts.n_values,
                    n_hidden=opts.sender_hidden)
    sender = core.RnnSenderReinforce(sender, opts.vocab_size,
                                     opts.sender_embedding, opts.sender_hidden,
                                     cell=opts.sender_cell,
                                     max_len=opts.max_len,
                                     num_layers=opts.sender_num_layers,
                                     force_eos=force_eos)

    ### RECEIVER AND GAME ###

    # NOTE(review): this first Receiver is replaced in both branches below and
    # never used. It is kept because building it may consume random numbers,
    # so dropping it could change seeded runs -- confirm, then remove.
    receiver = Receiver(n_features=opts.n_values,
                        n_hidden=opts.receiver_hidden)

    if opts.impatient:
        receiver = Receiver(n_features=opts.receiver_hidden,
                            n_hidden=opts.vocab_size)
        receiver = RnnReceiverImpatientCompositionality(
            receiver, opts.vocab_size, opts.receiver_embedding,
            opts.receiver_hidden, cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers, max_len=opts.max_len,
            n_attributes=opts.n_attributes, n_values=opts.n_values)
        game = CompositionalitySenderImpatientReceiverRnnReinforce(
            sender, receiver, loss_impatient_compositionality,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            n_attributes=opts.n_attributes, n_values=opts.n_values,
            att_weights=opts.att_weights,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost, unigram_penalty=opts.unigram_pen,
            reg=opts.reg)
        dump_fn = dump_impatient_compositionality
    else:
        receiver = Receiver(n_features=opts.n_features,
                            n_hidden=opts.receiver_hidden)
        receiver = RnnReceiverCompositionality(
            receiver, opts.vocab_size, opts.receiver_embedding,
            opts.receiver_hidden, cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers, max_len=opts.max_len,
            n_attributes=opts.n_attributes, n_values=opts.n_values)
        game = CompositionalitySenderReceiverRnnReinforce(
            sender, receiver, loss_compositionality,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            n_attributes=opts.n_attributes, n_values=opts.n_values,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
            length_cost=opts.length_cost, unigram_penalty=opts.unigram_pen,
            reg=opts.reg)
        dump_fn = dump_compositionality

    optimizer = core.build_optimizer(game.parameters())

    trainer = CompoTrainer(
        n_attributes=opts.n_attributes, n_values=opts.n_values, game=game,
        optimizer=optimizer, train_data=train_loader,
        validation_data=test_loader,
        callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    # Overrides whatever attribute weights the game was constructed with.
    game.att_weights = [1] * game.n_attributes

    # Create the output folders up front so the first save cannot fail on a
    # missing directory after an epoch of training has already been spent.
    for sub_dir in ("sender", "receiver", "messages", "accuracy"):
        os.makedirs(f"{opts.dir_save}/{sub_dir}", exist_ok=True)

    for epoch in range(int(opts.n_epochs)):
        print(f"Epoch: {epoch}")

        # Train a single epoch at a time so results can be dumped in between.
        trainer.train(n_epochs=1)
        if opts.checkpoint_dir:
            trainer.save_checkpoint(
                name=f'{opts.name}_vocab{opts.vocab_size}'
                     f'_rs{opts.random_seed}_lr{opts.lr}'
                     f'_shid{opts.sender_hidden}_rhid{opts.receiver_hidden}'
                     f'_sentr{opts.sender_entropy_coeff}'
                     f'_reg{opts.length_cost}_max_len{opts.max_len}')

        acc_vec, messages = dump_fn(trainer.game, opts.n_attributes,
                                    opts.n_values, device, False, epoch)
        print(acc_vec.mean(0))

        # NOTE(review): stacking assumes all messages have the same length --
        # confirm against the dump helper.
        all_messages = np.asarray([msg.cpu().numpy() for msg in messages])

        if epoch % 50 == 0:
            torch.save(
                sender.state_dict(),
                f"{opts.dir_save}/sender/sender_weights{epoch}.pth")
            torch.save(
                receiver.state_dict(),
                f"{opts.dir_save}/receiver/receiver_weights{epoch}.pth")

        np.save(f"{opts.dir_save}/messages/messages_{epoch}.npy", all_messages)
        np.save(f"{opts.dir_save}/accuracy/accuracy_{epoch}.npy", acc_vec)
        print(acc_vec.T)

    core.close()
Example #22
0
def main(params):
    """Train the channel game with REINFORCE, then dump the trained game.

    The sender and the receiver are each built either as a transformer or as
    an RNN, depending on ``opts.sender_cell`` / ``opts.receiver_cell``.

    Args:
        params: Argument list forwarded to ``get_params``.
    """
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    # Input distribution: a named preset or an explicit comma-separated list.
    if opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    elif opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    else:
        weights = [float(token) for token in opts.probs.split(',')]
        probs = np.array(weights, dtype=np.float32)
    probs = probs / probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(
        n_features=opts.n_features,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
        probs=probs,
    )

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    # --- sender: the base agent's hidden size depends on the wrapper used ---
    transformer_sender = opts.sender_cell == 'transformer'
    sender = Sender(
        n_features=opts.n_features,
        n_hidden=opts.sender_embedding if transformer_sender else opts.sender_hidden,
    )
    if transformer_sender:
        sender = core.TransformerSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            force_eos=opts.force_eos,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender,
        )
    else:
        sender = core.RnnSenderReinforce(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            force_eos=force_eos,
        )

    # --- receiver: same pattern as the sender ---
    transformer_receiver = opts.receiver_cell == 'transformer'
    receiver = Receiver(
        n_features=opts.n_features,
        n_hidden=opts.receiver_embedding if transformer_receiver else opts.receiver_hidden,
    )
    if transformer_receiver:
        receiver = core.TransformerReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            opts.receiver_num_heads,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver,
        )
    else:
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers,
        )

    # Training uses the minimal logging strategy.
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        train_logging_strategy=LoggingStrategy.minimal(),
        length_cost=opts.length_cost,
    )

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        EarlyStopperAccuracy(opts.early_stopping_thr),
        core.ConsoleLogger(as_json=True, print_train_loss=True),
    ]
    if opts.checkpoint_dir:
        # Checkpoint prefix encodes the main hyper-parameters of the run.
        checkpoint_name = (
            f'{opts.name}_vocab{opts.vocab_size}_rs{opts.random_seed}'
            f'_lr{opts.lr}_shid{opts.sender_hidden}'
            f'_rhid{opts.receiver_hidden}'
            f'_sentr{opts.sender_entropy_coeff}_reg{opts.length_cost}'
            f'_max_len{opts.max_len}'
        )
        saver = core.CheckpointSaver(checkpoint_path=opts.checkpoint_dir,
                                     prefix=checkpoint_name)
        callbacks.append(saver)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    # Switch to the maximal logging strategy before the final dump.
    game.logging_strategy = LoggingStrategy.maximal()
    dump(trainer.game, opts.n_features, device, False)
    core.close()
Example #23
0
def main(params):
    """Run the per-position test on pre-trained agents and save the result.

    Sender and receiver weights are loaded from ``opts.sender_weights`` and
    ``opts.receiver_weights``. For every message position the position-test
    helper is run on the identity matrix of inputs, and the per-input accuracy
    is stored in a ``(n_features, max_len)`` array. Positions after the first
    0 symbol of an input's message (presumably end-of-sequence -- TODO
    confirm) are set to -1. The array is written to
    ``analysis/position_sieve.npy``.

    Args:
        params: Argument list forwarded to ``get_params``.
    """
    import os  # local import: only needed to create the output directory

    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')],
                         dtype=np.float32)
    probs /= probs.sum()

    # No training happens in this function; the loaders and the optimizer are
    # built only because the Trainer is constructed with them.
    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch,
                                probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features,
                        n_hidden=opts.sender_embedding)
        # NOTE(review): this branch passes the raw opts.force_eos, while the
        # RNN branch passes the boolean computed above -- confirm intent.
        sender = core.TransformerSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            force_eos=opts.force_eos,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features,
                        n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(sender,
                                         opts.vocab_size,
                                         opts.sender_embedding,
                                         opts.sender_hidden,
                                         cell=opts.sender_cell,
                                         max_len=opts.max_len,
                                         num_layers=opts.sender_num_layers,
                                         force_eos=force_eos)

    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features,
                            n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            opts.receiver_num_heads,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver)
    elif opts.impatient:
        # The impatient receiver wraps a base agent with different sizes.
        receiver = Receiver(n_features=opts.receiver_hidden,
                            n_hidden=opts.vocab_size)
        receiver = RnnReceiverImpatient(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers,
            max_len=opts.max_len,
            n_features=opts.n_features)
    else:
        receiver = Receiver(n_features=opts.n_features,
                            n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers)

    # Replace the fresh initialisation with the stored weights (mapped to CPU).
    sender.load_state_dict(
        torch.load(opts.sender_weights, map_location=torch.device('cpu')))
    receiver.load_state_dict(
        torch.load(opts.receiver_weights, map_location=torch.device('cpu')))

    # Both game variants take exactly the same arguments.
    game_cls = (SenderImpatientReceiverRnnReinforce if opts.impatient
                else core.SenderReceiverRnnReinforce)
    game = game_cls(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        length_cost=opts.length_cost,
        unigram_penalty=opts.unigram_pen)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    # --- position test ---
    # Rows are inputs, columns are message positions.
    position_sieve = np.zeros((opts.n_features, opts.max_len))

    # The two helpers share one call signature; pick the matching one once.
    run_position_test = (dump_test_position_impatient if opts.impatient
                         else dump_test_position)

    for position in range(opts.max_len):
        # One batch holding every one-hot input, with no labels.
        dataset = [[torch.eye(opts.n_features).to(device), None]]

        sender_inputs, messages, _, receiver_outputs, _ = run_position_test(
            trainer.game,
            dataset,
            position=position,
            voc_size=opts.vocab_size,
            gs=False,
            device=device,
            variable_length=True)

        # 1.0 where the receiver's argmax matches the sender input's argmax.
        acc_pos = [
            (sender_input.argmax() == receiver_output.argmax()).float().item()
            for sender_input, _message, receiver_output in zip(
                sender_inputs, messages, receiver_outputs)
        ]
        position_sieve[:, position] = np.array(acc_pos)

    # Put -1 for positions after the first 0 symbol of each message.
    _, messages = dump(trainer.game, opts.n_features, device, False)

    for i, message in enumerate(messages):
        zero_positions = np.where(message.cpu().numpy() == 0)[0]
        if zero_positions.size > 0:
            position_sieve[i, zero_positions[0] + 1:opts.max_len] = -1

    # Make sure the target folder exists so the result is not lost at the end.
    os.makedirs("analysis", exist_ok=True)
    np.save("analysis/position_sieve.npy", position_sieve)

    core.close()
Example #24
0
    # Dataset rooted at `data_folder` (defined outside this excerpt).
    dataset = ImageNetFeat(root=data_folder)

    # Training batches: shuffled, with no fixed seed.
    train_loader = ImagenetLoader(dataset,
                                  batch_size=opts.batch_size,
                                  shuffle=True,
                                  opt=opts,
                                  batches_per_epoch=opts.batches_per_epoch,
                                  seed=None)
    # Validation batches come from the same dataset object, with a fixed
    # seed (7) and without passing `shuffle`.
    validation_loader = ImagenetLoader(
        dataset,
        opt=opts,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
        seed=7)
    game = get_game(opts)
    optimizer = core.build_optimizer(game.parameters())
    # NOTE(review): `callback` (singular) is never read in the visible code; it
    # looks like a leftover next to `callbacks` below -- confirm and remove.
    callback = None
    # A TemperatureUpdater for the sender (decay 0.9, minimum 0.1) is attached
    # only in 'gs' mode (presumably Gumbel-Softmax -- TODO confirm); any other
    # mode passes callbacks=None to the Trainer.
    if opts.mode == 'gs':
        callbacks = [
            core.TemperatureUpdater(agent=game.sender, decay=0.9, minimum=0.1)
        ]
    else:
        callbacks = None
    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=validation_loader,
                           callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)
Example #25
0
def main(params):
    """Train the channel game with REINFORCE and print the final dump.

    Supports the 'uniform', 'powerlaw' and 'perso' input distributions, or an
    explicit comma-separated list of weights in ``opts.probs``.

    Args:
        params: Argument list forwarded to ``get_params``.
    """
    opts = get_params(params)
    print(opts, flush=True)
    device = opts.device

    force_eos = opts.force_eos == 1

    if opts.probs == 'uniform':
        probs = np.ones(opts.n_features)
    elif opts.probs == 'powerlaw':
        probs = 1 / np.arange(1, opts.n_features + 1, dtype=np.float32)
    elif opts.probs == 'perso':
        # Linearly decreasing weights: n_features, n_features - 1, ..., 1.
        probs = opts.n_features + 1 - np.arange(1, opts.n_features + 1,
                                                dtype=np.float32)
    else:
        probs = np.array([float(x) for x in opts.probs.split(',')],
                         dtype=np.float32)
    probs /= probs.sum()

    print('the probs are: ', probs, flush=True)

    train_loader = OneHotLoader(n_features=opts.n_features,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch,
                                probs=probs)

    # single batches with 1s on the diag
    test_loader = UniformLoader(opts.n_features)

    if opts.sender_cell == 'transformer':
        sender = Sender(n_features=opts.n_features,
                        n_hidden=opts.sender_embedding)
        # NOTE(review): this branch passes the raw opts.force_eos, while the
        # RNN branch passes the boolean computed above -- confirm intent.
        sender = core.TransformerSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_embedding,
            max_len=opts.max_len,
            num_layers=opts.sender_num_layers,
            num_heads=opts.sender_num_heads,
            hidden_size=opts.sender_hidden,
            force_eos=opts.force_eos,
            generate_style=opts.sender_generate_style,
            causal=opts.causal_sender)
    else:
        sender = Sender(n_features=opts.n_features,
                        n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(sender,
                                         opts.vocab_size,
                                         opts.sender_embedding,
                                         opts.sender_hidden,
                                         cell=opts.sender_cell,
                                         max_len=opts.max_len,
                                         num_layers=opts.sender_num_layers,
                                         force_eos=force_eos)

    if opts.receiver_cell == 'transformer':
        receiver = Receiver(n_features=opts.n_features,
                            n_hidden=opts.receiver_embedding)
        receiver = core.TransformerReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_embedding,
            opts.receiver_num_heads,
            opts.receiver_hidden,
            opts.receiver_num_layers,
            causal=opts.causal_receiver)
    else:
        receiver = Receiver(n_features=opts.n_features,
                            n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
            num_layers=opts.receiver_num_layers)

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
        length_cost=opts.length_cost)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr),
                   core.ConsoleLogger(as_json=True, print_train_loss=True)])

    # An earlier variant of this script trained one epoch at a time and saved
    # the dumped messages and the accuracy after each epoch; that dead code
    # (and other disabled experiments) was removed. This version trains in a
    # single call and dumps once at the end.
    trainer.train(n_epochs=opts.n_epochs)

    all_messages = dump(trainer.game, opts.n_features, device, False)

    print(all_messages)

    core.close()