def main(params):
    """Run the symbol game on MNIST batches wrapped by ``DoubleMnist``.

    Parses ``params`` with ``get_params``, builds the train/test MNIST
    loaders, wraps the sender in ``core.GumbelSoftmaxWrapper``, trains a
    ``core.SymbolGameGS`` game with ``core.Trainer`` for ``opts.n_epochs``
    epochs and finally calls ``core.close()``.

    Args:
        params: Forwarded unchanged to ``get_params`` (presumably the
            command-line argument list -- confirm against the caller).
    """
    opts = get_params(params)
    print(opts)

    # Index i (0..99) is mapped to label i modulo opts.n_labels.
    label_mapping = torch.LongTensor(
        [index % opts.n_labels for index in range(100)]
    )
    print('# label mapping', label_mapping.tolist())

    # Worker / pinned-memory settings are only passed when CUDA is enabled.
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_kwargs = {}
    to_tensor = transforms.ToTensor()

    # Training data: shuffled, batched by opts.batch_size.
    train_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            datasets.MNIST(
                './data', train=True, download=True, transform=to_tensor
            ),
            batch_size=opts.batch_size,
            shuffle=True,
            **loader_kwargs
        ),
        label_mapping,
    )

    # Test data: fixed order, batches of 16 * 1024 samples.
    test_loader = DoubleMnist(
        torch.utils.data.DataLoader(
            datasets.MNIST('./data', train=False, transform=to_tensor),
            batch_size=16 * 1024,
            shuffle=False,
            **loader_kwargs
        ),
        label_mapping,
    )

    # Agents are created in the same order as before (sender, receiver,
    # then the wrapper around the sender).
    plain_sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(
        vocab_size=opts.vocab_size,
        n_classes=opts.n_labels,
        n_hidden=opts.n_hidden,
    )
    sender = core.GumbelSoftmaxWrapper(
        plain_sender, temperature=opts.temperature
    )

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    # Evaluation callback run on the test loader using the game's own loss.
    intervention = CallbackEvaluator(
        test_loader,
        device=opts.device,
        loss=game.loss,
        is_gs=True,
        var_length=False,
        input_intervention=False,
    )

    callbacks = [
        core.ConsoleLogger(as_json=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
        intervention,
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    core.close()
def main(params):
    """Run the symbol game on MNIST batches wrapped by ``SplitImages``.

    Parses ``params`` with ``get_params``, prints the options as JSON,
    builds the train/test MNIST loaders and wraps them in ``SplitImages``
    (the test loader additionally goes through ``TakeFirstLoader`` with
    ``n=1``). It then trains a ``core.SymbolGameGS`` game with
    ``core.Trainer`` for ``opts.n_epochs`` epochs and calls
    ``core.close()``.

    Args:
        params: Forwarded unchanged to ``get_params`` (presumably the
            command-line argument list -- confirm against the caller).
    """
    opts = get_params(params)
    print(json.dumps(vars(opts)))

    # Worker / pinned-memory settings are only passed when CUDA is enabled.
    if opts.cuda:
        loader_kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_kwargs = {}
    to_tensor = transforms.ToTensor()

    # Training data: shuffled, batched by opts.batch_size.
    raw_train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            './data', train=True, download=True, transform=to_tensor
        ),
        batch_size=opts.batch_size,
        shuffle=True,
        **loader_kwargs
    )
    # Test data: fixed order, batches of 16 * 1024 samples.
    raw_test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=to_tensor),
        batch_size=16 * 1024,
        shuffle=False,
        **loader_kwargs
    )

    n_classes = 10

    # Keyword arguments shared by both SplitImages wrappers.
    split_options = {
        'rows_sender': opts.sender_rows,
        'rows_receiver': opts.receiver_rows,
        'binarize': False,
        'receiver_bottom': True,
    }
    # The test wrapper is built before the train wrapper, as in the
    # original statement order.
    test_loader = SplitImages(
        TakeFirstLoader(raw_test_loader, n=1), **split_options
    )
    train_loader = SplitImages(raw_train_loader, **split_options)

    # Agents are created in the same order as before (sender, receiver,
    # then the wrapper around the sender).
    plain_sender = Sender(vocab_size=opts.vocab_size)
    receiver = Receiver(vocab_size=opts.vocab_size, n_classes=n_classes)
    sender = core.GumbelSoftmaxWrapper(
        plain_sender, temperature=opts.temperature
    )

    game = core.SymbolGameGS(sender, receiver, diff_loss_symbol)
    optimizer = core.build_optimizer(game.parameters())

    # Evaluation callback run on the test loader using the game's own loss.
    intervention = CallbackEvaluator(
        test_loader,
        device=opts.device,
        loss=game.loss,
        is_gs=True,
        var_length=False,
        input_intervention=True,
    )

    callbacks = [
        core.ConsoleLogger(as_json=True),
        EarlyStopperAccuracy(opts.early_stopping_thr),
        intervention,
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    core.close()