Ejemplo n.º 1
0
def main(params):
    """Set up and run MNIST training for a sender/receiver symbol game.

    Parses ``params`` with ``get_params``, prints the resulting options as
    JSON, builds MNIST train/test loaders (the training set is first passed
    through ``corrupt_labels_``), constructs the ``Sender``/``Receiver`` pair
    and trains a ``core.SymbolGameGS`` game for ``opts.n_epochs`` epochs.

    Args:
        params: Raw parameters, forwarded unchanged to ``get_params``.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    # Extra DataLoader settings are only supplied when CUDA is enabled.
    if opts.cuda:
        kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        kwargs = {}
    to_tensor = transforms.ToTensor()

    train_dataset = datasets.MNIST(
        './data', train=True, download=True, transform=to_tensor)
    test_dataset = datasets.MNIST(
        './data', train=False, download=False, transform=to_tensor)
    n_classes = 10

    # Operates on the training set only; seeded with random_seed + 1.
    corrupt_labels_(
        dataset=train_dataset,
        p_corrupt=opts.p_corrupt,
        seed=opts.random_seed + 1,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opts.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=opts.batch_size, shuffle=False, **kwargs)

    # When opts.deeper is set, exactly one agent is built with deeper=True:
    # the sender if opts.deeper_alice is set, otherwise the receiver.
    deeper_alice = opts.deeper_alice == 1 and opts.deeper == 1
    deeper_bob = opts.deeper_alice != 1 and opts.deeper == 1

    use_linear_channel = opts.linear_channel == 1
    use_softmax_channel = opts.softmax_non_linearity == 1

    sender = Sender(
        vocab_size=opts.vocab_size,
        deeper=deeper_alice,
        linear_channel=use_linear_channel,
        softmax_channel=use_softmax_channel,
    )
    receiver = Receiver(
        vocab_size=opts.vocab_size,
        n_classes=n_classes,
        deeper=deeper_bob,
    )

    # The sender is wrapped in AlwaysRelaxedWrapper only when neither the
    # softmax nor the linear channel option was selected.
    if not (use_softmax_channel or use_linear_channel):
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
Ejemplo n.º 2
0
def main(params):
    """Set up and run training for a symbol game on ``DoubleMnist`` data.

    Parses ``params`` with ``get_params``, prints the options, builds MNIST
    train/test loaders (the training set is first passed through
    ``corrupt_labels_``), wraps both loaders in ``DoubleMnist`` together with
    a label mapping, constructs the ``Sender``/``Receiver`` pair and trains a
    ``core.SymbolGameGS`` game for ``opts.n_epochs`` epochs.

    Args:
        params: Raw parameters, forwarded unchanged to ``get_params``.
    """
    opts = get_params(params)
    print(opts)

    # Extra DataLoader settings are only supplied when CUDA is enabled.
    if opts.cuda:
        kwargs = {"num_workers": 1, "pin_memory": True}
    else:
        kwargs = {}
    to_tensor = transforms.ToTensor()

    train_dataset = datasets.MNIST(
        "./data", train=True, download=True, transform=to_tensor)
    test_dataset = datasets.MNIST(
        "./data", train=False, download=False, transform=to_tensor)

    n_classes = 10

    # Operates on the training set only; seeded with random_seed + 1.
    corrupt_labels_(
        dataset=train_dataset,
        p_corrupt=opts.p_corrupt,
        seed=opts.random_seed + 1,
    )
    # Lookup table with 100 entries: index i maps to i modulo n_classes.
    label_mapping = torch.LongTensor([i % n_classes for i in range(100)])

    train_batches = torch.utils.data.DataLoader(
        train_dataset, batch_size=opts.batch_size, shuffle=True, **kwargs)
    train_loader = DoubleMnist(train_batches, label_mapping)

    # Evaluation uses a fixed batch size instead of opts.batch_size.
    test_batches = torch.utils.data.DataLoader(
        test_dataset, batch_size=16 * 1024, shuffle=False, **kwargs)
    test_loader = DoubleMnist(test_batches, label_mapping)

    # When opts.deeper is set, exactly one agent is built with deeper=True:
    # the sender if opts.deeper_alice is set, otherwise the receiver.
    deeper_alice = opts.deeper_alice == 1 and opts.deeper == 1
    deeper_bob = opts.deeper_alice != 1 and opts.deeper == 1

    use_linear_channel = opts.linear_channel == 1
    use_softmax_channel = opts.softmax_non_linearity == 1

    sender = Sender(
        vocab_size=opts.vocab_size,
        deeper=deeper_alice,
        linear_channel=use_linear_channel,
        softmax_channel=use_softmax_channel,
    )
    receiver = Receiver(
        vocab_size=opts.vocab_size,
        n_classes=n_classes,
        deeper=deeper_bob,
    )

    # force_discrete always selects the Gumbel-Softmax wrapper; otherwise the
    # sender is wrapped in AlwaysRelaxedWrapper only when neither the softmax
    # nor the linear channel option was selected. In the remaining case the
    # sender is used unwrapped.
    if opts.force_discrete == 1:
        sender = core.GumbelSoftmaxWrapper(
            sender, temperature=opts.temperature)
    elif not (use_softmax_channel or use_linear_channel):
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()