Exemplo n.º 1
0
def test_snapshoting():
    """Train for two epochs with a CheckpointSaver and check the saved files.

    Asserts that ``1.tar``, ``2.tar`` and ``final.tar`` exist under
    ``./test_checkpoints`` after ``trainer.train(2)``. The directory is removed
    in a ``finally`` clause so that a failing assertion does not leave stale
    checkpoints behind for later runs.
    """
    CHECKPOINT_PATH = Path("./test_checkpoints")

    core.init()
    sender = core.GumbelSoftmaxWrapper(ToyAgent(), temperature=1)
    receiver = Receiver()

    def loss(sender_input, message, receiver_input, receiver_output, labels, aux_input):
        # Returns (loss value, empty dict), the pair shape passed to the game.
        return F.cross_entropy(receiver_output, labels), {}

    game = core.SymbolGameGS(sender, receiver, loss)
    optimizer = torch.optim.Adam(game.parameters())

    data = Dataset()
    trainer = core.Trainer(
        game,
        optimizer,
        train_data=data,
        validation_data=None,
        callbacks=[core.CheckpointSaver(checkpoint_path=CHECKPOINT_PATH)],
    )
    try:
        trainer.train(2)
        assert (CHECKPOINT_PATH / Path("1.tar")).exists()
        assert (CHECKPOINT_PATH / Path("2.tar")).exists()
        assert (CHECKPOINT_PATH / Path("final.tar")).exists()
    finally:
        # Clean-up; guarded so a missing directory cannot mask the real failure.
        if CHECKPOINT_PATH.exists():
            shutil.rmtree(CHECKPOINT_PATH)
Exemplo n.º 2
0
def get_game(opt):
    """Build the single-symbol game selected by ``opt.mode``.

    Args:
        opt: options object. Reads ``mode``, ``game_size``, ``embedding_size``,
            ``hidden_size``, ``vocab_size`` and ``tau_s``; ``gs_tau`` is read
            only when ``mode == 'gs'``.

    Returns:
        A ``core.SymbolGameReinforce`` when ``mode == 'rf'`` or a
        ``core.SymbolGameGS`` when ``mode == 'gs'``.

    Raises:
        RuntimeError: if ``opt.mode`` is neither ``'rf'`` nor ``'gs'``.
    """
    # Use the argument rather than a module-level ``opts`` global, and reject
    # an unsupported mode before any network is constructed.
    mode = opt.mode
    if mode not in ('rf', 'gs'):
        raise RuntimeError(f"Unknown training mode: {mode}")

    feat_size = 4096
    sender = InformedSender(opt.game_size,
                            feat_size,
                            opt.embedding_size,
                            opt.hidden_size,
                            opt.vocab_size,
                            temp=opt.tau_s)
    receiver = Receiver(opt.game_size,
                        feat_size,
                        opt.embedding_size,
                        opt.vocab_size,
                        reinforce=(mode == 'rf'))
    if mode == 'rf':
        sender = core.ReinforceWrapper(sender)
        receiver = core.ReinforceWrapper(receiver)
        game = core.SymbolGameReinforce(sender,
                                        receiver,
                                        loss,
                                        sender_entropy_coeff=0.01,
                                        receiver_entropy_coeff=0.01)
    else:  # mode == 'gs'
        sender = core.GumbelSoftmaxWrapper(sender, temperature=opt.gs_tau)
        game = core.SymbolGameGS(sender, receiver, loss_nll)

    return game
Exemplo n.º 3
0
def test_snapshoting():
    """Check checkpoint saving and resuming from the latest checkpoint.

    Trains for two epochs with a CheckpointSaver, asserts that ``1.tar``,
    ``2.tar`` and ``final.tar`` exist, then builds a fresh Trainer, loads the
    latest checkpoint, asserts ``start_epoch == 2`` and trains up to epoch 3.
    The checkpoint directory is removed in a ``finally`` clause so that a
    failing assertion does not leave stale checkpoints for the next run.
    """
    CHECKPOINT_PATH = Path('./test_checkpoints')

    core.init()
    sender = core.GumbelSoftmaxWrapper(ToyAgent(), temperature=1)
    receiver = Receiver()

    def loss(sender_input, message, receiver_input, receiver_output, labels):
        # Returns (loss value, empty dict), the pair shape passed to the game.
        return F.cross_entropy(receiver_output, labels), {}

    game = core.SymbolGameGS(sender, receiver, loss)
    optimizer = torch.optim.Adam(game.parameters())

    data = Dataset()
    trainer = core.Trainer(
        game,
        optimizer,
        train_data=data,
        validation_data=None,
        callbacks=[core.CheckpointSaver(checkpoint_path=CHECKPOINT_PATH)])
    try:
        trainer.train(2)
        assert (CHECKPOINT_PATH / Path('1.tar')).exists()
        assert (CHECKPOINT_PATH / Path('2.tar')).exists()
        assert (CHECKPOINT_PATH / Path('final.tar')).exists()
        del trainer
        trainer = core.Trainer(game, optimizer,
                               train_data=data)  # Re-instantiate trainer
        trainer.load_from_latest(CHECKPOINT_PATH)
        assert trainer.start_epoch == 2
        trainer.train(3)
    finally:
        # Clean-up; guarded so a missing directory cannot mask the real failure.
        if CHECKPOINT_PATH.exists():
            shutil.rmtree(CHECKPOINT_PATH)
Exemplo n.º 4
0
def main(params):
    """Parse options, build the one-symbol game for ``opts.mode`` and train it.

    The parsed options are echoed to stdout as JSON. Supported modes are
    'gs', 'rf' and 'non_diff'; any other value trips the trailing assertion.
    Sender and receiver get separate Adam learning rates.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    device = opts.device

    train_loader = OneHotLoader(
        n_bits=opts.n_bits,
        bits_s=opts.bits_s,
        bits_r=opts.bits_r,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.n_examples_per_epoch / opts.batch_size,
    )

    test_loader = UniformLoader(
        n_bits=opts.n_bits,
        bits_s=opts.bits_s,
        bits_r=opts.bits_r,
    )
    # Move every tensor of the test loader's batch to the target device.
    test_loader.batch = [tensor.to(device) for tensor in test_loader.batch]

    sender = Sender(
        n_bits=opts.n_bits,
        n_hidden=opts.sender_hidden,
        vocab_size=opts.vocab_size,
    )

    if opts.mode == 'gs':
        receiver = Receiver(
            n_bits=opts.n_bits,
            n_hidden=opts.receiver_hidden,
            vocab_size=opts.vocab_size,
        )
        sender = core.GumbelSoftmaxWrapper(
            agent=sender,
            temperature=opts.temperature,
        )
        game = core.SymbolGameGS(sender, receiver, diff_loss)
    elif opts.mode == 'rf':
        sender = core.ReinforceWrapper(agent=sender)
        receiver = Receiver(
            n_bits=opts.n_bits,
            n_hidden=opts.receiver_hidden,
            vocab_size=opts.vocab_size,
        )
        receiver = core.ReinforceDeterministicWrapper(agent=receiver)
        game = core.SymbolGameReinforce(
            sender,
            receiver,
            diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
        )
    elif opts.mode == 'non_diff':
        sender = core.ReinforceWrapper(agent=sender)
        receiver = ReinforcedReceiver(
            n_bits=opts.n_bits,
            n_hidden=opts.receiver_hidden,
            vocab_size=opts.vocab_size,
        )
        game = core.SymbolGameReinforce(
            sender,
            receiver,
            non_diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
        )
    else:
        assert False, 'Unknown training mode'

    # One parameter group per agent so each keeps its own learning rate.
    param_groups = [
        dict(params=sender.parameters(), lr=opts.sender_lr),
        dict(params=receiver.parameters(), lr=opts.receiver_lr),
    ]
    optimizer = torch.optim.Adam(param_groups)

    intervention = CallbackEvaluator(
        test_loader,
        device=device,
        is_gs=opts.mode == 'gs',
        loss=game.loss,
        var_length=False,
        input_intervention=True,
    )

    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        epoch_callback=intervention,
        as_json=True,
        early_stopping=early_stopper,
    )

    trainer.train(n_epochs=opts.n_epochs)

    core.close()
Exemplo n.º 5
0
def main(params):
    """Train a one-symbol Gumbel-Softmax game on MNIST wrapped in DoubleMnist.

    Parses options, builds train/test MNIST loaders wrapped in ``DoubleMnist``
    with a 100-entry label mapping (``x % 10``), optionally wraps the sender in
    ``AlwaysRelaxedWrapper``, and trains with console logging and
    accuracy-based early stopping.
    """
    opts = get_params(params)
    print(opts)

    # Extra DataLoader arguments are only supplied when CUDA is enabled.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {"num_workers": 1, "pin_memory": True}
    to_tensor = transforms.ToTensor()

    mnist_train = datasets.MNIST(
        "./data", train=True, download=True, transform=to_tensor
    )
    mnist_test = datasets.MNIST(
        "./data", train=False, download=False, transform=to_tensor
    )
    n_classes = 10
    label_mapping = torch.LongTensor([idx % n_classes for idx in range(100)])

    train_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            mnist_train,
            batch_size=opts.batch_size,
            shuffle=True,
            **loader_kwargs,
        ),
        label_mapping,
    )

    test_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            mnist_test,
            batch_size=16 * 1024,
            shuffle=False,
            **loader_kwargs,
        ),
        label_mapping,
    )

    sender = Sender(
        vocab_size=opts.vocab_size,
        linear_channel=opts.linear_channel == 1,
        softmax_channel=opts.softmax_non_linearity,
    )
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)

    # The relaxed wrapper is applied only when both channel flags are 0.
    if opts.softmax_non_linearity == 0 and opts.linear_channel == 0:
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
Exemplo n.º 6
0
def main(params):
    """Train a Gumbel-Softmax symbol game on DoubleMnist with an evaluator callback.

    Parses options, prints them and the label mapping (``x % opts.n_labels``
    for 100 entries), builds MNIST loaders wrapped in ``DoubleMnist``, and
    trains with console logging, early stopping and a ``CallbackEvaluator``.
    """
    opts = get_params(params)
    print(opts)

    label_mapping = torch.LongTensor(
        [idx % opts.n_labels for idx in range(100)]
    )
    print('# label mapping', label_mapping.tolist())

    # Extra DataLoader arguments are only supplied when CUDA is enabled.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    to_tensor = transforms.ToTensor()

    mnist_train = datasets.MNIST(
        './data', train=True, download=True, transform=to_tensor
    )
    train_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            mnist_train,
            batch_size=opts.batch_size,
            shuffle=True,
            **loader_kwargs,
        ),
        label_mapping,
    )

    mnist_test = datasets.MNIST('./data', train=False, transform=to_tensor)
    test_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            mnist_test,
            batch_size=16 * 1024,
            shuffle=False,
            **loader_kwargs,
        ),
        label_mapping,
    )

    sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(
        vocab_size=opts.vocab_size,
        n_classes=opts.n_labels,
        n_hidden=opts.n_hidden,
    )
    sender = core.GumbelSoftmaxWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)

    optimizer = core.build_optimizer(game.parameters())

    intervention = CallbackEvaluator(
        test_loader,
        device=opts.device,
        loss=game.loss,
        is_gs=True,
        var_length=False,
        input_intervention=False,
    )

    callbacks = [
        core.ConsoleLogger(as_json=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
        intervention,
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
Exemplo n.º 7
0
def main(params):
    """Train a one-symbol Gumbel-Softmax sender/receiver game on MNIST.

    ``core.init`` parses ``params`` and returns the options. The sender is
    wrapped in ``GumbelSoftmaxWrapper`` (temperature 1.0), the receiver in
    ``SymbolReceiverWrapper``, and a ``TemperatureUpdater`` callback
    (decay 0.75, minimum 0.01) is registered alongside a console logger.
    """
    # Library initialisation also yields the parsed command-line options.
    opts = core.init(params=params)

    # Data: extra DataLoader arguments are only supplied when CUDA is enabled.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {"num_workers": 1, "pin_memory": True}
    to_tensor = transforms.ToTensor()

    mnist_train = datasets.MNIST(
        "./data", train=True, download=True, transform=to_tensor
    )
    train_loader = torch.utils.data.DataLoader(
        mnist_train,
        batch_size=opts.batch_size,
        shuffle=True,
        **loader_kwargs,
    )
    mnist_test = datasets.MNIST("./data", train=False, transform=to_tensor)
    test_loader = torch.utils.data.DataLoader(
        mnist_test,
        batch_size=opts.batch_size,
        shuffle=True,
        **loader_kwargs,
    )

    # Agents: raw sender module first, then its Gumbel-Softmax interface.
    sender = Sender(opts.vocab_size)
    sender = core.GumbelSoftmaxWrapper(sender, temperature=1.0)

    receiver = Receiver()
    receiver = core.SymbolReceiverWrapper(
        receiver,
        vocab_size=opts.vocab_size,
        agent_input_size=400,
    )

    # Sender/Receiver game communicating with a single symbol.
    game = core.SymbolGameGS(sender, receiver, loss)

    # Callback that lowers the sender's sampling temperature during training.
    temperature_updater = core.TemperatureUpdater(
        agent=sender,
        decay=0.75,
        minimum=0.01,
    )

    # Optimizer configured from the common command-line parameters.
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        temperature_updater,
        core.ConsoleLogger(as_json=True, print_train_loss=True),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
Exemplo n.º 8
0
def main(params):
    """Train a one-symbol Gumbel-Softmax game on plain MNIST loaders.

    Parses options and prints them as JSON, builds train/test MNIST loaders,
    optionally wraps the sender in ``AlwaysRelaxedWrapper``, and trains with
    an accuracy-based early stopper passed directly to the Trainer.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    # Extra DataLoader arguments are only supplied when CUDA is enabled.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    to_tensor = transforms.ToTensor()

    mnist_train = datasets.MNIST(
        './data', train=True, download=True, transform=to_tensor
    )
    mnist_test = datasets.MNIST(
        './data', train=False, download=False, transform=to_tensor
    )
    n_classes = 10

    train_loader = torch.utils.data.DataLoader(
        mnist_train,
        batch_size=opts.batch_size,
        shuffle=True,
        **loader_kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        mnist_test,
        batch_size=opts.batch_size,
        shuffle=False,
        **loader_kwargs,
    )

    sender = Sender(
        vocab_size=opts.vocab_size,
        linear_channel=opts.linear_channel == 1,
        softmax_channel=opts.softmax_non_linearity,
    )
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)

    # The relaxed wrapper is applied only when both channel flags are 0.
    if opts.softmax_non_linearity == 0 and opts.linear_channel == 0:
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)

    optimizer = core.build_optimizer(game.parameters())

    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        as_json=True,
        early_stopping=early_stopper,
        print_train_loss=True,
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
def test_game_gs():
    """Train a Gumbel-Softmax symbol game for 1000 epochs and check the sender.

    After training, the argmax over the transposed ``fc1`` weight of the
    wrapped sender agent must equal ``BATCH_Y`` element-wise.
    """
    core.init()
    sender = core.GumbelSoftmaxWrapper(ToyAgent())
    receiver = Receiver()

    def loss(sender_input, message, receiver_input, receiver_output, labels):
        # Returns (loss value, empty dict), the pair shape passed to the game.
        return F.cross_entropy(receiver_output, labels), {}

    game = core.SymbolGameGS(sender, receiver, loss)
    optimizer = torch.optim.Adam(game.parameters())

    data = Dataset()
    trainer = core.Trainer(
        game,
        optimizer,
        train_data=data,
        validation_data=None,
    )
    trainer.train(1000)

    learned_symbols = sender.agent.fc1.weight.t().argmax(dim=1).cpu()
    assert (learned_symbols == BATCH_Y).all()
Exemplo n.º 10
0
def main(params):
    """Train a Gumbel-Softmax symbol game on MNIST images split between agents.

    Parses options and prints them as JSON, builds MNIST loaders, wraps them in
    ``SplitImages`` (the test loader first going through
    ``TakeFirstLoader(n=1)``), and trains with console logging, early stopping
    and a ``CallbackEvaluator`` with input intervention enabled.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    # Extra DataLoader arguments are only supplied when CUDA is enabled.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    to_tensor = transforms.ToTensor()

    mnist_train = datasets.MNIST(
        './data', train=True, download=True, transform=to_tensor
    )
    train_loader = torch.utils.data.DataLoader(
        mnist_train,
        batch_size=opts.batch_size,
        shuffle=True,
        **loader_kwargs,
    )

    mnist_test = datasets.MNIST('./data', train=False, transform=to_tensor)
    test_loader = torch.utils.data.DataLoader(
        mnist_test,
        batch_size=16 * 1024,
        shuffle=False,
        **loader_kwargs,
    )

    n_classes = 10

    binarize = False

    test_loader = SplitImages(
        TakeFirstLoader(test_loader, n=1),
        rows_receiver=opts.receiver_rows,
        rows_sender=opts.sender_rows,
        binarize=binarize,
        receiver_bottom=True,
    )
    train_loader = SplitImages(
        train_loader,
        rows_sender=opts.sender_rows,
        rows_receiver=opts.receiver_rows,
        binarize=binarize,
        receiver_bottom=True,
    )

    sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)
    sender = core.GumbelSoftmaxWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)

    optimizer = core.build_optimizer(game.parameters())

    intervention = CallbackEvaluator(
        test_loader,
        device=opts.device,
        loss=game.loss,
        is_gs=True,
        var_length=False,
        input_intervention=True,
    )

    callbacks = [
        core.ConsoleLogger(as_json=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
        intervention,
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
Exemplo n.º 11
0
def test_temperature_updater_callback():
    """Check that TemperatureUpdater changes the sender temperature.

    The sender starts at temperature 1 with a decay of 0.9; after training for
    one epoch the sender's temperature must equal 0.9.
    """
    core.init()
    sender = core.GumbelSoftmaxWrapper(ToyAgent(), temperature=1)
    receiver = Receiver()

    def loss(sender_input, message, receiver_input, receiver_output, labels):
        # Returns (loss value, empty dict), the pair shape passed to the game.
        return F.cross_entropy(receiver_output, labels), {}

    game = core.SymbolGameGS(sender, receiver, loss)
    optimizer = torch.optim.Adam(game.parameters())

    data = Dataset()
    updater = core.TemperatureUpdater(agent=sender, decay=0.9)
    trainer = core.Trainer(
        game,
        optimizer,
        train_data=data,
        validation_data=None,
        callbacks=[updater],
    )
    trainer.train(1)
    assert sender.temperature == 0.9
Exemplo n.º 12
0
def main(params):
    """Train a Gumbel-Softmax symbol game on MNIST with corrupted train labels.

    Parses options and prints them as JSON, corrupts the training labels in
    place via ``corrupt_labels_`` (seeded with ``opts.random_seed + 1``), picks
    which agent is "deeper" from ``opts.deeper``/``opts.deeper_alice``, and
    trains with console logging and accuracy-based early stopping.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    # Extra DataLoader arguments are only supplied when CUDA is enabled.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    to_tensor = transforms.ToTensor()

    train_dataset = datasets.MNIST(
        './data', train=True, download=True, transform=to_tensor
    )
    test_dataset = datasets.MNIST(
        './data', train=False, download=False, transform=to_tensor
    )
    n_classes = 10

    corrupt_labels_(
        dataset=train_dataset,
        p_corrupt=opts.p_corrupt,
        seed=opts.random_seed + 1,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opts.batch_size,
        shuffle=True,
        **loader_kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opts.batch_size,
        shuffle=False,
        **loader_kwargs,
    )

    # With opts.deeper == 1 exactly one agent is deeper; deeper_alice picks it.
    deeper_alice = opts.deeper_alice == 1 and opts.deeper == 1
    deeper_bob = opts.deeper_alice != 1 and opts.deeper == 1

    sender = Sender(
        vocab_size=opts.vocab_size,
        deeper=deeper_alice,
        linear_channel=opts.linear_channel == 1,
        softmax_channel=opts.softmax_non_linearity == 1,
    )
    receiver = Receiver(
        vocab_size=opts.vocab_size,
        n_classes=n_classes,
        deeper=deeper_bob,
    )

    # The relaxed wrapper is applied only when neither channel flag equals 1.
    if opts.softmax_non_linearity != 1 and opts.linear_channel != 1:
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)

    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
Exemplo n.º 13
0
def get_my_game(opt):
    """Build the game selected by ``opt.mode`` around InformedSender/MyReceiver.

    Args:
        opt: options object. Reads ``mode``, ``game_size``, ``embedding_size``,
            ``hidden_size`` and ``tau_s``; ``vocab_size`` and ``max_len`` are
            read only when ``mode == 'rf'`` and ``gs_tau`` only when
            ``mode == 'gs'``.

    Returns:
        A ``core.SenderReceiverRnnReinforce`` when ``mode == 'rf'`` or a
        ``core.SymbolGameGS`` when ``mode == 'gs'``.

    Raises:
        RuntimeError: if ``opt.mode`` is neither ``'rf'`` nor ``'gs'``.
    """
    # Use the argument rather than a module-level ``opts`` global, and reject
    # an unsupported mode before any network is constructed.
    mode = opt.mode
    if mode not in ('rf', 'gs'):
        raise RuntimeError(f"Unknown training mode: {mode}")

    feat_size = 4096
    out_hidden_size = 20
    emb_size = 10
    sender = InformedSender(opt.game_size, feat_size,
                            opt.embedding_size, opt.hidden_size, out_hidden_size,
                            temp=opt.tau_s)
    receiver = MyReceiver(opt.game_size, feat_size,
                          opt.embedding_size, out_hidden_size,
                          reinforce=(mode == 'rf'))
    if mode == 'rf':
        sender = core.MyRnnSenderReinforce(sender, opt.vocab_size, emb_size,
                                           out_hidden_size,
                                           cell="gru", max_len=opt.max_len)
        receiver = core.RnnReceiverReinforce(receiver, opt.vocab_size, emb_size,
                                             out_hidden_size, cell="gru")
        game = core.SenderReceiverRnnReinforce(
            sender, receiver, loss,
            sender_entropy_coeff=0.01, receiver_entropy_coeff=0.01)
    else:  # mode == 'gs'
        sender = core.GumbelSoftmaxWrapper(sender, temperature=opt.gs_tau)
        game = core.SymbolGameGS(sender, receiver, loss_nll)

    return game
Exemplo n.º 14
0
def test_max_snapshoting():
    """Check CheckpointSaver with ``max_checkpoints=2`` and resuming afterwards.

    Trains for six epochs, asserts that ``5.tar``, ``6.tar`` and ``final.tar``
    exist and that the checkpoint directory holds exactly three files, then
    builds a fresh Trainer, loads the latest checkpoint, asserts
    ``start_epoch == 6`` and calls ``train(3)``. The checkpoint directory is
    removed in a ``finally`` clause so that a failing assertion does not leave
    stale checkpoints for the next run.
    """
    CHECKPOINT_PATH = Path("./test_checkpoints")

    core.init()
    sender = core.GumbelSoftmaxWrapper(ToyAgent(), temperature=1)
    receiver = Receiver()

    def loss(sender_input, message, receiver_input, receiver_output, labels):
        # Returns (loss value, empty dict), the pair shape passed to the game.
        return F.cross_entropy(receiver_output, labels), {}

    game = core.SymbolGameGS(sender, receiver, loss)
    optimizer = torch.optim.Adam(game.parameters())

    data = Dataset()
    trainer = core.Trainer(
        game,
        optimizer,
        train_data=data,
        validation_data=None,
        callbacks=[
            core.CheckpointSaver(checkpoint_path=CHECKPOINT_PATH,
                                 max_checkpoints=2)
        ],
    )
    try:
        trainer.train(n_epochs=6)
        assert (CHECKPOINT_PATH / Path("5.tar")).exists()
        assert (CHECKPOINT_PATH / Path("6.tar")).exists()
        assert (CHECKPOINT_PATH / Path("final.tar")).exists()
        saved_files = [x for x in CHECKPOINT_PATH.glob("**/*") if x.is_file()]
        assert len(saved_files) == 3
        del trainer
        trainer = core.Trainer(game, optimizer,
                               train_data=data)  # Re-instantiate trainer
        trainer.load_from_latest(CHECKPOINT_PATH)
        assert trainer.start_epoch == 6
        trainer.train(3)
    finally:
        # Clean-up; guarded so a missing directory cannot mask the real failure.
        if CHECKPOINT_PATH.exists():
            shutil.rmtree(CHECKPOINT_PATH)
Exemplo n.º 15
0
def main(params):
    """Train a sender/receiver channel game on bit vectors.

    With ``opts.variable_length`` unset, a one-symbol game is built for
    ``opts.mode`` ('gs', 'rf' or 'non_diff'). Otherwise the mode is forced to
    'rf' and a ``core.SenderReceiverRnnReinforce`` game is built, with the
    sender and receiver architectures chosen by ``opts.sender_cell`` and
    ``opts.receiver_cell`` ('transformer' or an RNN cell name).

    Raises:
        ValueError: if ``opts.variable_length`` is unset and ``opts.mode`` is
            not one of the supported modes.
    """
    opts = get_params(params)
    print(opts)

    device = opts.device

    train_loader = OneHotLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r,
                                batch_size=opts.batch_size,
                                batches_per_epoch=opts.n_examples_per_epoch /
                                opts.batch_size)

    test_loader = UniformLoader(n_bits=opts.n_bits,
                                bits_s=opts.bits_s,
                                bits_r=opts.bits_r)
    # Move every tensor of the test loader's batch to the target device.
    test_loader.batch = [x.to(device) for x in test_loader.batch]

    if not opts.variable_length:
        sender = Sender(n_bits=opts.n_bits,
                        n_hidden=opts.sender_hidden,
                        vocab_size=opts.vocab_size)
        if opts.mode == 'gs':
            sender = core.GumbelSoftmaxWrapper(agent=sender,
                                               temperature=opts.temperature)
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver,
                vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)
            game = core.SymbolGameGS(sender, receiver, diff_loss)
        elif opts.mode == 'rf':
            sender = core.ReinforceWrapper(agent=sender)
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver,
                vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)
            receiver = core.ReinforceDeterministicWrapper(agent=receiver)
            game = core.SymbolGameReinforce(
                sender,
                receiver,
                diff_loss,
                sender_entropy_coeff=opts.sender_entropy_coeff)
        elif opts.mode == 'non_diff':
            sender = core.ReinforceWrapper(agent=sender)
            receiver = ReinforcedReceiver(n_bits=opts.n_bits,
                                          n_hidden=opts.receiver_hidden)
            receiver = core.SymbolReceiverWrapper(
                receiver,
                vocab_size=opts.vocab_size,
                agent_input_size=opts.receiver_hidden)

            game = core.SymbolGameReinforce(
                sender,
                receiver,
                non_diff_loss,
                sender_entropy_coeff=opts.sender_entropy_coeff,
                receiver_entropy_coeff=opts.receiver_entropy_coeff)
        else:
            # Without this branch an unknown mode only surfaced later as a
            # NameError on the undefined receiver/game.
            raise ValueError(f'Unknown training mode: {opts.mode}')
    else:
        if opts.mode != 'rf':
            print('Only mode=rf is supported atm')
            opts.mode = 'rf'

        # NOTE(review): the Receiver built in either sender branch below is
        # replaced in the receiver section that follows. It is kept because
        # dropping it would presumably change the order in which modules are
        # randomly initialised - confirm before removing.
        if opts.sender_cell == 'transformer':
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            sender = Sender(
                n_bits=opts.n_bits,
                n_hidden=opts.sender_hidden,
                vocab_size=opts.sender_hidden)  # TODO: not really vocab
            sender = core.TransformerSenderReinforce(
                agent=sender,
                vocab_size=opts.vocab_size,
                embed_dim=opts.sender_emb,
                max_len=opts.max_len,
                num_layers=1,
                num_heads=1,
                hidden_size=opts.sender_hidden)
        else:
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            sender = Sender(
                n_bits=opts.n_bits,
                n_hidden=opts.sender_hidden,
                vocab_size=opts.sender_hidden)  # TODO: not really vocab
            sender = core.RnnSenderReinforce(agent=sender,
                                             vocab_size=opts.vocab_size,
                                             embed_dim=opts.sender_emb,
                                             hidden_size=opts.sender_hidden,
                                             max_len=opts.max_len,
                                             cell=opts.sender_cell)

        if opts.receiver_cell == 'transformer':
            receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_emb)
            receiver = core.TransformerReceiverDeterministic(
                receiver,
                opts.vocab_size,
                opts.max_len,
                opts.receiver_emb,
                num_heads=1,
                hidden_size=opts.receiver_hidden,
                num_layers=1)
        else:
            receiver = Receiver(n_bits=opts.n_bits,
                                n_hidden=opts.receiver_hidden)
            receiver = core.RnnReceiverDeterministic(receiver,
                                                     opts.vocab_size,
                                                     opts.receiver_emb,
                                                     opts.receiver_hidden,
                                                     cell=opts.receiver_cell)

        # The SenderReceiverRnnGS game formerly built in the RNN-receiver
        # branch was always overwritten by this assignment, so it is gone.
        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff)

    # One parameter group per agent so each keeps its own learning rate.
    optimizer = torch.optim.Adam([
        dict(params=sender.parameters(), lr=opts.sender_lr),
        dict(params=receiver.parameters(), lr=opts.receiver_lr)
    ])

    loss = game.loss

    intervention = CallbackEvaluator(test_loader,
                                     device=device,
                                     is_gs=opts.mode == 'gs',
                                     loss=loss,
                                     var_length=opts.variable_length,
                                     input_intervention=True)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True),
                               EarlyStopperAccuracy(opts.early_stopping_thr),
                               intervention
                           ])

    trainer.train(n_epochs=opts.n_epochs)

    core.close()