def main(args): """main man""" # reproducibility if args.seed is not None: torch.manual_seed( args.seed) # unsure if this works with SparseMNIST right now np.random.seed(args.seed) # cuda args.use_cuda = not args.no_cuda and torch.cuda.is_available() args.device = torch.device("cuda" if args.use_cuda else "cpu") # data dataset = MNIST('./data/', train=False, transform=ToTensor()) kwargs = {'num_workers': 1} if args.use_cuda: kwargs['pin_memory'] = True data_loader = DataLoader(dataset, args.batch_size, shuffle=True, **kwargs) if args.rounds is None: args.rounds = len(dataset) // args.batch_size # load judge judge_state = torch.load(args.checkpoint)['state_dict'] # debate game judge = Judge().to(args.device) judge.load_state_dict(judge_state) judge.eval() helper = Agent(honest=True, args=args) liar = Agent(honest=False, args=args) debate = Debate((helper, liar), data_loader, args) total_meter = AverageMeter() class_meters = [AverageMeter() for i in range(10)] # TODO precommit logic for _ in range(args.rounds): print("starting round {}".format(_)) helper.precommit_(None, None) liar.precommit_(None, None) result = debate.play(judge, args.device) track_stats_(total_meter, class_meters, result['helper']['preds'], result['helper']['wins'], result['labels'], args.precommit) print('Total accuracy: {}'.format(total_meter.avg)) print('Accuracy per class\n==============================================') for i in range(10): print('Digit {}: {}'.format(i, class_meters[i].avg))
def main(args):
    """Train the judge on SparseMNIST, with optional validation and checkpointing."""
    # reproducibility
    if args.seed is not None:
        torch.manual_seed(args.seed)  # don't think this works with SparseMNIST right now
        np.random.seed(args.seed)

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)
    if args.checkpoint_filename is None:
        checkpoint_file = os.path.join(args.checkpoint, str(datetime.now())[:-10])
    else:
        checkpoint_file = os.path.join(args.checkpoint, args.checkpoint_filename)

    # cuda
    args.use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.use_cuda else "cpu")

    # validate during training only if validation batches were requested
    args.evaluate = args.val_batches > 0

    # prep sparse mnist
    if not args.evaluate:
        train_loader, _, test_loader = prepare_data(args)
    else:
        train_loader, val_loader, test_loader = prepare_data(args)

    # machinery
    model = Judge().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # validation metrics used to track the best model over the training run
    best_val_loss = float('inf')
    best_val_acc = 0

    print('\n================== TRAINING ==================')
    model.train()  # set model to training mode

    # set up training metrics we want to track
    correct = 0
    train_num = args.batches * args.batch_size

    # timer
    time0 = time.time()

    for ix, (sparse, img, label) in enumerate(train_loader):  # iterate over training batches
        sparse, label = sparse.to(device), label.to(device)  # get data, send to gpu if needed
        optimizer.zero_grad()  # clear parameter gradients from previous training update
        logits = model(sparse)  # forward pass
        loss = F.cross_entropy(logits, label)  # calculate network loss
        loss.backward()  # backward pass
        optimizer.step()  # take an optimization step to update model's parameters
        pred = logits.max(1, keepdim=True)[1]  # get the index of the max logit
        correct += pred.eq(label.view_as(pred)).sum().item()  # add to running total of hits

        if ix % args.log_interval == 0:  # maybe log current metrics to terminal
            print('Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t'
                  'Accuracy: {:.2f}%\tTime: {:.0f} min, {:.2f} s'.format(
                      (ix + 1) * len(sparse), train_num,
                      100. * ix / len(train_loader), loss.item(),
                      100. * correct / ((ix + 1) * len(sparse)),
                      (time.time() - time0) // 60,
                      (time.time() - time0) % 60))

    print('Train Accuracy: {}/{} ({:.2f}%)\t'
          'Train Time: {:.0f} minutes, {:.2f} seconds\n'.format(
              correct, train_num, 100. * correct / train_num,
              (time.time() - time0) // 60, (time.time() - time0) % 60))

    if args.evaluate:
        print('\n================== VALIDATION ==================')
        model.eval()

        # set up validation metrics we want to track
        val_loss = 0.
        val_correct = 0
        val_num = args.eval_batch_size * args.val_batches

        # disable autograd here (replaces volatile flag from v0.3.1 and earlier)
        with torch.no_grad():
            for sparse, img, label in val_loader:
                sparse, label = sparse.to(device), label.to(device)
                logits = model(sparse)
                # reduction='sum' replaces the deprecated size_average=False
                val_loss += F.cross_entropy(logits, label, reduction='sum').item()
                pred = logits.max(1, keepdim=True)[1]
                val_correct += pred.eq(label.view_as(pred)).sum().item()

        # update current evaluation metrics
        val_loss /= val_num
        val_acc = 100. * val_correct / val_num
        print('\nValidation set: Average loss: {:.4f}, '
              'Accuracy: {}/{} ({:.2f}%)\n'.format(
                  val_loss, val_correct, val_num, val_acc))

        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc
            # note this is the val_loss of the best model w.r.t. accuracy,
            # not the best val_loss throughout training
            best_val_loss = val_loss

        # create checkpoint dictionary and save it;
        # if is_best, copy the file over to the file containing best model for this run
        state = {
            'state_dict': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_acc,
        }
        save_checkpoint(state, is_best, checkpoint_file)

    print('\n================== TESTING ==================')
    check = torch.load(checkpoint_file + '-best.pth.tar')
    model.load_state_dict(check['state_dict'])
    model.eval()

    test_loss = 0.
    test_correct = 0
    test_num = args.eval_batch_size * args.test_batches

    # disable autograd here (replaces volatile flag from v0.3.1 and earlier)
    with torch.no_grad():
        for sparse, img, label in test_loader:
            sparse, label = sparse.to(device), label.to(device)
            logits = model(sparse)
            test_loss += F.cross_entropy(logits, label, reduction='sum').item()
            pred = logits.max(1, keepdim=True)[1]  # get the index of the max logit
            test_correct += pred.eq(label.view_as(pred)).sum().item()

    test_loss /= test_num
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, test_correct, test_num, 100. * test_correct / test_num))
    print('Final model stored at "{}".'.format(checkpoint_file + '-best.pth.tar'))
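
# NOTE: `save_checkpoint` is called above, and the test phase loads
# `checkpoint_file + '-best.pth.tar'`, but the helper is not defined in this
# section. Here is a plausible sketch matching that usage; the '.pth.tar'
# suffix for the latest checkpoint is an assumption inferred from the
# '-best.pth.tar' filename loaded above.
import shutil

def save_checkpoint(state, is_best, checkpoint_file):
    """Save the latest checkpoint; if it is the best model so far, copy it
    to the '-best.pth.tar' path that the test phase loads."""
    latest = checkpoint_file + '.pth.tar'
    torch.save(state, latest)
    if is_best:
        shutil.copyfile(latest, checkpoint_file + '-best.pth.tar')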