Ejemplo n.º 1
0
 def __init__(self):
     """Set up the linear point layer, the SATNet layer and the input mask.

     The SATNet layer is sized for ``HIDDEN_FEATURE * ACTION_SPACE`` variables
     followed by ``ACTION_SPACE`` more; ``16`` and ``64`` are passed as its
     second and third positional arguments.
     """
     super().__init__()

     n_leading = HIDDEN_FEATURE * ACTION_SPACE
     n_trailing = ACTION_SPACE

     # Layers are created in this order on purpose: construction order fixes
     # parameter registration order and any RNG draws made during init.
     self.point_linear_1 = torch.nn.Linear(
         in_features=2, out_features=HIDDEN_FEATURE)
     self.final_sat = satnet.SATNet(n_leading + n_trailing, 16, 64)

     # Mask with 1 for the first `n_leading` SATNet variables and 0 for the
     # last `n_trailing`.
     # NOTE(review): presumably 1 marks variables supplied as inputs and 0
     # marks the ones SATNet must infer -- confirm against forward().
     mask_values = [1] * n_leading + [0] * n_trailing
     mask = torch.tensor(mask_values, dtype=torch.int)
     # Concatenating a single tensor yields a tensor with the same contents.
     self.is_input = torch.cat([mask])
Ejemplo n.º 2
0
def main():
    """Train and evaluate a SATNet model on the parity data.

    Parses the command-line options, loads ``features.pt`` and ``labels.pt``
    from ``<data_dir>/<seq>/``, splits them into a leading training portion
    and a trailing test portion, then alternates ``train`` and ``test`` for
    ``nEpoch`` epochs. Per-epoch results are written to ``train.csv`` and
    ``test.csv`` in a run directory under ``logs/``; an existing directory
    with the same name is deleted first.

    Raises:
        ValueError: If the training or test split size is not an exact
            multiple of its batch size.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='parity')
    parser.add_argument('--testPct', type=float, default=0.1)
    parser.add_argument('--batchSz', type=int, default=100)
    parser.add_argument('--testBatchSz', type=int, default=500)
    parser.add_argument('--nEpoch', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-1)
    parser.add_argument('--seq', type=int, default=20)
    parser.add_argument('--save', type=str)
    parser.add_argument('--m', type=int, default=4)
    parser.add_argument('--aux', type=int, default=4)
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--adam', action='store_true')

    args = parser.parse_args()

    # For debugging: fix the random seed
    npr.seed(1)
    torch.manual_seed(7)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        print('Using', torch.cuda.get_device_name(0))
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.cuda.init()

    # The run directory name encodes the hyper-parameters, optionally
    # prefixed by --save.
    save = 'parity.aux{}-m{}-lr{}-bsz{}'.format(
            args.aux, args.m, args.lr, args.batchSz)
    if args.save:
        save = '{}-{}'.format(args.save, save)
    save = os.path.join('logs', save)
    # Results of an earlier run with the same name are discarded.
    if os.path.isdir(save):
        shutil.rmtree(save)
    os.makedirs(save)

    L = args.seq

    # NOTE(security): torch.load unpickles the file contents, which can run
    # arbitrary code. Only point --data_dir at trusted data.
    with open(os.path.join(args.data_dir, str(L), 'features.pt'), 'rb') as f:
        X = torch.load(f).float()
    with open(os.path.join(args.data_dir, str(L), 'labels.pt'), 'rb') as f:
        Y = torch.load(f).float()

    if args.cuda:
        X, Y = X.cuda(), Y.cuda()

    N = X.size(0)

    nTrain = int(N*(1-args.testPct))
    nTest = N-nTrain

    # Explicit checks instead of `assert`, which is removed under `python -O`.
    if nTrain % args.batchSz != 0:
        raise ValueError(
            'training set size {} is not a multiple of --batchSz {}'.format(
                nTrain, args.batchSz))
    if nTest % args.testBatchSz != 0:
        raise ValueError(
            'test set size {} is not a multiple of --testBatchSz {}'.format(
                nTest, args.testBatchSz))

    # One [1, 1, 0] mask row per sample.
    # NOTE(review): presumably the first two SATNet variables are inputs and
    # the third is the predicted output -- confirm against train()/test().
    train_is_input = torch.IntTensor([1,1,0]).repeat(nTrain,1)
    test_is_input = torch.IntTensor([1,1,0]).repeat(nTest,1)
    if args.cuda:
        train_is_input, test_is_input = train_is_input.cuda(), test_is_input.cuda()

    train_set = TensorDataset(X[:nTrain], train_is_input, Y[:nTrain])
    test_set =  TensorDataset(X[nTrain:], test_is_input, Y[nTrain:])

    model = satnet.SATNet(3, args.m, args.aux, prox_lam=1e-1)
    if args.cuda:
        model = model.cuda()

    if args.adam:
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr)

    train_logger = CSVLogger(os.path.join(save, 'train.csv'))
    test_logger = CSVLogger(os.path.join(save, 'test.csv'))
    # The first row logged to each CSV is the column header.
    fields = ['epoch', 'loss', 'err']
    train_logger.log(fields)
    test_logger.log(fields)

    # Epoch 0 records test performance of the untrained model.
    test(0, model, optimizer, test_logger, test_set, args.testBatchSz)
    for epoch in range(1, args.nEpoch+1):
        train(epoch, model, optimizer, train_logger, train_set, args.batchSz)
        test(epoch, model, optimizer, test_logger, test_set, args.testBatchSz)
Ejemplo n.º 3
0
 def __init__(self, boardSz, aux, m):
     """Create the solver's single SATNet layer.

     Args:
         boardSz: Board side length. The layer is built with
             ``boardSz ** 2`` variables (presumably one per board
             square -- confirm against the data encoding).
         aux: Forwarded to ``satnet.SATNet`` as its third positional
             argument.
         m: Forwarded to ``satnet.SATNet`` as its second positional
             argument.
     """
     super(NQueensSolver, self).__init__()
     # Solver settings passed explicitly: max_iter=100 and eps=1e-6.
     self.sat = satnet.SATNet(
         boardSz ** 2,
         m,
         aux,
         max_iter=100,
         eps=1e-6,
     )
Ejemplo n.º 4
0
 def __init__(self, boardSz, aux, m):
     """Create the solver's single SATNet layer.

     Args:
         boardSz: Board size parameter. The layer is built with
             ``boardSz ** 6`` variables.
         aux: Forwarded to ``satnet.SATNet`` as its third positional
             argument.
         m: Forwarded to ``satnet.SATNet`` as its second positional
             argument.
     """
     super(SudokuSolver, self).__init__()
     # NOTE(review): presumably boardSz is the box side (3 for a 9x9 grid),
     # giving one variable per (row, column, digit) -- confirm with callers.
     num_vars = boardSz ** 6
     self.sat = satnet.SATNet(num_vars, m, aux)