コード例 #1
0
def main(params):
    """Set up and run Gumbel-Softmax symbol-game training on MNIST.

    Parses ``params`` with ``get_params``, builds MNIST train/test loaders
    wrapped in ``DoubleMnist`` together with a label mapping, creates a
    ``Sender``/``Receiver`` pair (the sender wrapped in
    ``core.GumbelSoftmaxWrapper``), and trains a ``core.SymbolGameGS`` game
    through ``core.Trainer`` for ``opts.n_epochs`` epochs.

    Args:
        params: Forwarded unchanged to ``get_params``; presumably a list of
            command-line style arguments -- confirm against ``get_params``.
    """
    opts = get_params(params)
    print(opts)

    # Entry i of the mapping holds i % n_labels, for i in 0..99.
    label_mapping = torch.LongTensor(
        [index % opts.n_labels for index in range(100)])
    print('# label mapping', label_mapping.tolist())

    # Worker / pinned-memory settings are only supplied when opts.cuda is set.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    to_tensor = transforms.ToTensor()

    # Training split: downloaded on demand, shuffled, batched per opts.
    mnist_train = datasets.MNIST(
        './data', train=True, download=True, transform=to_tensor)
    raw_train_loader = torch.utils.data.DataLoader(
        mnist_train,
        batch_size=opts.batch_size,
        shuffle=True,
        **loader_kwargs)
    train_loader = DoubleMnist(raw_train_loader, label_mapping)

    # Test split: fixed order, batches of up to 16 * 1024 samples.
    mnist_test = datasets.MNIST('./data', train=False, transform=to_tensor)
    raw_test_loader = torch.utils.data.DataLoader(
        mnist_test,
        batch_size=16 * 1024,
        shuffle=False,
        **loader_kwargs)
    test_loader = DoubleMnist(raw_test_loader, label_mapping)

    # Agents are created sender-first, receiver-second, then the sender is
    # wrapped for Gumbel-Softmax with the configured temperature.
    base_sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(
        vocab_size=opts.vocab_size,
        n_classes=opts.n_labels,
        n_hidden=opts.n_hidden)
    sender = core.GumbelSoftmaxWrapper(
        base_sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    # Evaluation callback over the test loader, reusing the game's loss.
    intervention = CallbackEvaluator(
        test_loader,
        device=opts.device,
        loss=game.loss,
        is_gs=True,
        var_length=False,
        input_intervention=False)

    callbacks = [
        core.ConsoleLogger(as_json=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
        intervention,
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()
コード例 #2
0
def main(params):
    """Set up and run Gumbel-Softmax symbol-game training on split MNIST images.

    Parses ``params`` with ``get_params``, prints the options as JSON, builds
    MNIST train/test loaders wrapped in ``SplitImages`` (the test loader is
    first passed through ``TakeFirstLoader`` with ``n=1``), creates a
    ``Sender``/``Receiver`` pair (the sender wrapped in
    ``core.GumbelSoftmaxWrapper``), and trains a ``core.SymbolGameGS`` game
    through ``core.Trainer`` for ``opts.n_epochs`` epochs.

    Args:
        params: Forwarded unchanged to ``get_params``; presumably a list of
            command-line style arguments -- confirm against ``get_params``.
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    # Worker / pinned-memory settings are only supplied when opts.cuda is set.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    to_tensor = transforms.ToTensor()

    # Training split: downloaded on demand, shuffled, batched per opts.
    mnist_train = datasets.MNIST(
        './data', train=True, download=True, transform=to_tensor)
    train_loader = torch.utils.data.DataLoader(
        mnist_train,
        batch_size=opts.batch_size,
        shuffle=True,
        **loader_kwargs)

    # Test split: fixed order, batches of up to 16 * 1024 samples.
    mnist_test = datasets.MNIST('./data', train=False, transform=to_tensor)
    test_loader = torch.utils.data.DataLoader(
        mnist_test,
        batch_size=16 * 1024,
        shuffle=False,
        **loader_kwargs)

    n_classes = 10

    # Both loaders are split with the same row counts, without binarization,
    # and with receiver_bottom enabled.
    split_kwargs = {
        'rows_sender': opts.sender_rows,
        'rows_receiver': opts.receiver_rows,
        'binarize': False,
        'receiver_bottom': True,
    }
    first_test_batches = TakeFirstLoader(test_loader, n=1)
    test_loader = SplitImages(first_test_batches, **split_kwargs)
    train_loader = SplitImages(train_loader, **split_kwargs)

    # Agents are created sender-first, receiver-second, then the sender is
    # wrapped for Gumbel-Softmax with the configured temperature.
    base_sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)
    sender = core.GumbelSoftmaxWrapper(
        base_sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    # Evaluation callback over the test loader, reusing the game's loss;
    # input_intervention is switched on here.
    intervention = CallbackEvaluator(
        test_loader,
        device=opts.device,
        loss=game.loss,
        is_gs=True,
        var_length=False,
        input_intervention=True)

    callbacks = [
        core.ConsoleLogger(as_json=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
        intervention,
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()