예제 #1
0
파일: main.py 프로젝트: jvmncs/safe-debates
def main(args):
    """Play debates between an honest and a lying agent in front of a trained judge.

    Loads a ``Judge`` from ``args.checkpoint``, runs ``args.rounds`` debate
    rounds over the MNIST test split (``train=False``) and prints the overall
    accuracy plus one accuracy line per digit class, as accumulated by
    ``track_stats_``.

    Args:
        args: parsed command-line namespace. Reads ``seed``, ``no_cuda``,
            ``batch_size``, ``rounds``, ``checkpoint`` and ``precommit``.
            Writes ``use_cuda`` and ``device`` back onto it, and fills in
            ``rounds`` when it is ``None``. ``args`` is also handed to the
            agents and the debate.
    """
    # reproducibility
    if args.seed is not None:
        # unsure if this works with SparseMNIST right now
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    # cuda
    args.use_cuda = not args.no_cuda and torch.cuda.is_available()
    args.device = torch.device("cuda" if args.use_cuda else "cpu")

    # data
    dataset = MNIST('./data/', train=False, transform=ToTensor())
    kwargs = {'num_workers': 1}
    if args.use_cuda:
        kwargs['pin_memory'] = True
    data_loader = DataLoader(dataset, args.batch_size, shuffle=True, **kwargs)
    if args.rounds is None:
        # NOTE(review): presumably one batch per round, i.e. about one pass
        # over the test split -- confirm against Debate.play
        args.rounds = len(dataset) // args.batch_size

    # load judge; map tensors onto the selected device so a checkpoint written
    # by a GPU run can still be loaded on a CPU-only machine (or with --no-cuda)
    judge_state = torch.load(args.checkpoint,
                             map_location=args.device)['state_dict']

    # debate game
    judge = Judge().to(args.device)
    judge.load_state_dict(judge_state)
    judge.eval()
    helper = Agent(honest=True, args=args)
    liar = Agent(honest=False, args=args)
    debate = Debate((helper, liar), data_loader, args)

    total_meter = AverageMeter()
    class_meters = [AverageMeter() for _ in range(10)]  # one per digit 0-9

    # TODO precommit logic
    for round_ix in range(args.rounds):
        print("starting round {}".format(round_ix))
        helper.precommit_(None, None)
        liar.precommit_(None, None)
        result = debate.play(judge, args.device)
        track_stats_(total_meter, class_meters, result['helper']['preds'],
                     result['helper']['wins'], result['labels'],
                     args.precommit)

    print('Total accuracy: {}'.format(total_meter.avg))
    print('Accuracy per class\n==============================================')
    for digit, meter in enumerate(class_meters):
        print('Digit {}: {}'.format(digit, meter.avg))
예제 #2
0
def _elapsed_min_sec(start):
    """Return the wall-clock time since ``start`` as ``(minutes, seconds)`` floats."""
    return divmod(time.time() - start, 60)


def _evaluate_judge(model, loader, device):
    """Run ``model`` in eval mode over ``loader`` with autograd disabled.

    ``loader`` yields ``(sparse, img, label)`` triples; only ``sparse`` and
    ``label`` are used.

    Returns:
        tuple ``(loss_sum, correct, count)``: the summed (not averaged)
        cross-entropy loss, the number of correct argmax predictions, and the
        number of samples actually seen.
    """
    loss_sum = 0.
    correct = 0
    count = 0
    model.eval()
    # disable autograd here (replaces volatile flag from v0.3.1 and earlier)
    with torch.no_grad():
        for sparse, _, label in loader:
            sparse, label = sparse.to(device), label.to(device)
            logits = model(sparse)
            # reduction='sum' replaces the deprecated size_average=False
            loss_sum += F.cross_entropy(logits, label, reduction='sum').item()
            pred = logits.max(1, keepdim=True)[1]  # index of the max logit
            correct += pred.eq(label.view_as(pred)).sum().item()
            count += label.size(0)
    return loss_sum, correct, count


def main(args):
    """Train the debate judge on sparse MNIST, optionally validate, then test.

    Trains a ``Judge`` for a single pass over the training loader with Adam.
    When ``args.val_batches > 0`` it also evaluates on the validation split.
    It always saves a checkpoint through ``save_checkpoint`` (marked best).
    Finally it reloads ``<checkpoint_file>-best.pth.tar`` and reports test
    loss and accuracy.

    Args:
        args: parsed command-line namespace. Reads ``seed``, ``checkpoint``,
            ``checkpoint_filename``, ``no_cuda``, ``val_batches``, ``lr``,
            ``batches``, ``batch_size`` and ``log_interval``. Sets
            ``use_cuda`` and ``evaluate`` on it. Also passed to
            ``prepare_data``.
    """
    # reproducibility
    if args.seed is not None:
        # don't think this works with SparseMNIST right now
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)
    if args.checkpoint_filename is None:
        # Timestamp to the minute. strftime replaces str(datetime)[:-10],
        # which truncated the wrong characters whenever microsecond == 0.
        checkpoint_name = datetime.now().strftime('%Y-%m-%d %H:%M')
    else:
        checkpoint_name = args.checkpoint_filename
    # join (not +) so a checkpoint dir without a trailing '/' still works
    checkpoint_file = os.path.join(args.checkpoint, checkpoint_name)
    best_file = checkpoint_file + '-best.pth.tar'

    # cuda
    args.use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.use_cuda else "cpu")

    # eval?
    args.evaluate = args.val_batches > 0

    # prep sparse mnist (val_loader is only consumed when evaluating)
    train_loader, val_loader, test_loader = prepare_data(args)

    # machinery
    model = Judge().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print('\n================== TRAINING ==================')
    model.train()  # set model to training mode

    # training metrics; `seen` counts real samples (the last batch may be short)
    correct = 0
    seen = 0
    train_num = args.batches * args.batch_size  # expected total, for display
    num_batches = len(train_loader)

    time0 = time.time()

    for ix, (sparse, _, label) in enumerate(train_loader):
        sparse, label = sparse.to(device), label.to(device)
        optimizer.zero_grad()  # clear gradients from the previous update
        logits = model(sparse)  # forward pass
        loss = F.cross_entropy(logits, label)
        loss.backward()
        optimizer.step()

        pred = logits.max(1, keepdim=True)[1]  # index of the max logit
        correct += pred.eq(label.view_as(pred)).sum().item()
        seen += label.size(0)

        if ix % args.log_interval == 0:  # maybe log current metrics
            minutes, seconds = _elapsed_min_sec(time0)
            # implicit concatenation: a backslash continuation inside the
            # literal used to embed the next line's indentation in the output
            print('Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t'
                  'Accuracy: {:.2f}%\tTime: {:.0f} min, {:.2f} s'.format(
                      seen, train_num, 100. * (ix + 1) / num_batches,
                      loss.item(), 100. * correct / seen, minutes, seconds))

    minutes, seconds = _elapsed_min_sec(time0)
    print('Train Accuracy: {}/{} ({:.2f}%)\t'
          'Train Time: {:.0f} minutes, {:.2f} seconds\n'.format(
              correct, seen, 100. * correct / max(seen, 1), minutes, seconds))

    state = {
        'state_dict': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }
    if args.evaluate:
        print('\n================== VALIDATION ==================')
        val_loss, val_correct, val_num = _evaluate_judge(model, val_loader,
                                                         device)
        val_loss /= max(val_num, 1)
        val_acc = 100. * val_correct / max(val_num, 1)
        print(
            '\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'
            .format(val_loss, val_correct, val_num, val_acc))
        state['val_loss'] = val_loss
        state['val_acc'] = val_acc

    # Exactly one model is trained per run, so it is always the best one.
    # Saving unconditionally with is_best=True guarantees the -best file the
    # test phase loads exists. Previously nothing was saved without
    # validation, and a 0% validation accuracy also skipped it.
    # (Per the original comment, save_checkpoint copies to the best file when
    # is_best is set.)
    save_checkpoint(state, True, checkpoint_file)

    print('\n================== TESTING ==================')
    check = torch.load(best_file, map_location=device)
    model.load_state_dict(check['state_dict'])

    test_loss, test_correct, test_num = _evaluate_judge(model, test_loader,
                                                        device)
    test_loss /= max(test_num, 1)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, test_correct, test_num,
        100. * test_correct / max(test_num, 1)))

    print('Final model stored at "{}".'.format(best_file))