def main(params):
    """Train a Sender/Receiver symbol game on MNIST with partly corrupted training labels.

    Parses ``params`` with ``get_params`` and echoes the resulting options as
    JSON. It then builds the MNIST train/test loaders and runs
    ``corrupt_labels_`` over the training set. Next it builds the two agents,
    wrapping the Sender in ``AlwaysRelaxedWrapper`` when neither a softmax nor
    a linear channel is requested. Finally it trains a ``core.SymbolGameGS``
    game with JSON console logging and ``EarlyStopperAccuracy``, then calls
    ``core.close()``.

    Args:
        params: arguments forwarded to ``get_params`` (presumably a list of
            command-line strings — confirm against caller).
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    # Extra DataLoader settings only apply when running on CUDA.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}

    to_tensor = transforms.ToTensor()
    train_set = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
    # NOTE(review): download=False assumes the test files are already on disk
    # (presumably fetched by the train-split download above) — confirm.
    test_set = datasets.MNIST('./data', train=False, download=False, transform=to_tensor)
    num_classes = 10

    # Only the training split is corrupted; its seed is offset from the run seed.
    # NOTE(review): trailing underscore suggests corrupt_labels_ mutates in place — confirm.
    corrupt_labels_(dataset=train_set, p_corrupt=opts.p_corrupt, seed=opts.random_seed + 1)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=opts.batch_size, shuffle=True, **loader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=opts.batch_size, shuffle=False, **loader_kwargs
    )

    # opts.deeper turns the deeper variant on; opts.deeper_alice selects which
    # agent receives it (Sender when == 1, Receiver otherwise).
    wants_depth = opts.deeper == 1
    depth_on_sender = opts.deeper_alice == 1
    deeper_alice = depth_on_sender and wants_depth
    deeper_bob = (not depth_on_sender) and wants_depth

    # Sender must be built before Receiver to keep parameter-init order unchanged.
    sender = Sender(
        vocab_size=opts.vocab_size,
        deeper=deeper_alice,
        linear_channel=opts.linear_channel == 1,
        softmax_channel=opts.softmax_non_linearity == 1,
    )
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=num_classes, deeper=deeper_bob)

    # No softmax and no linear channel requested: use the always-relaxed wrapper.
    if not (opts.softmax_non_linearity == 1 or opts.linear_channel == 1):
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
def main(params):
    """Train a Sender/Receiver symbol game on DoubleMnist-wrapped MNIST loaders.

    Parses ``params`` with ``get_params`` and prints the options. It builds the
    MNIST train/test sets, corrupts the training labels, and wraps both loaders
    in ``DoubleMnist`` together with a 100-entry label mapping. Next it builds
    the agents and picks a Sender wrapper:

    - ``core.GumbelSoftmaxWrapper`` when ``force_discrete == 1``;
    - otherwise ``AlwaysRelaxedWrapper`` when neither a softmax nor a linear
      channel is requested.

    Finally it trains a ``core.SymbolGameGS`` game with early stopping on
    accuracy, then calls ``core.close()``.

    Args:
        params: arguments forwarded to ``get_params`` (presumably a list of
            command-line strings — confirm against caller).
    """
    opts = get_params(params)
    print(opts)

    # Extra DataLoader settings only apply when running on CUDA.
    loader_kwargs = {}
    if opts.cuda:
        loader_kwargs = {"num_workers": 1, "pin_memory": True}

    to_tensor = transforms.ToTensor()
    train_set = datasets.MNIST("./data", train=True, download=True, transform=to_tensor)
    # NOTE(review): download=False assumes the test files are already on disk
    # (presumably fetched by the train-split download above) — confirm.
    test_set = datasets.MNIST("./data", train=False, download=False, transform=to_tensor)
    num_classes = 10

    # Only the training split is corrupted; its seed is offset from the run seed.
    corrupt_labels_(dataset=train_set, p_corrupt=opts.p_corrupt, seed=opts.random_seed + 1)

    # Maps each of the 100 indices 0..99 to (index mod 10). Presumably DoubleMnist
    # produces two-digit labels in [0, 100) and reduces them via this table — confirm.
    label_mapping = torch.LongTensor([value % num_classes for value in range(100)])

    train_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            train_set, batch_size=opts.batch_size, shuffle=True, **loader_kwargs
        ),
        label_mapping,
    )
    # Evaluation uses a fixed large batch rather than opts.batch_size.
    eval_batch_size = 16 * 1024
    test_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            test_set, batch_size=eval_batch_size, shuffle=False, **loader_kwargs
        ),
        label_mapping,
    )

    # opts.deeper turns the deeper variant on; opts.deeper_alice selects which
    # agent receives it (Sender when == 1, Receiver otherwise).
    wants_depth = opts.deeper == 1
    depth_on_sender = opts.deeper_alice == 1
    deeper_alice = depth_on_sender and wants_depth
    deeper_bob = (not depth_on_sender) and wants_depth

    # Sender must be built before Receiver to keep parameter-init order unchanged.
    sender = Sender(
        vocab_size=opts.vocab_size,
        deeper=deeper_alice,
        linear_channel=opts.linear_channel == 1,
        softmax_channel=opts.softmax_non_linearity == 1,
    )
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=num_classes, deeper=deeper_bob)

    # force_discrete takes precedence over the channel flags.
    relaxed_channel = opts.softmax_non_linearity != 1 and opts.linear_channel != 1
    if opts.force_discrete == 1:
        sender = core.GumbelSoftmaxWrapper(sender, temperature=opts.temperature)
    elif relaxed_channel:
        sender = AlwaysRelaxedWrapper(sender, temperature=opts.temperature)

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    callbacks = [
        core.ConsoleLogger(as_json=True, print_train_loss=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    core.close()