def __init__(self):
    """Build the point projection layer, the SATNet head and its input mask.

    Sizes come from the module-level constants ``HIDDEN_FEATURE`` and
    ``ACTION_SPACE``.
    """
    super().__init__()
    # Linear projection from 2 input features to HIDDEN_FEATURE outputs.
    self.point_linear_1 = torch.nn.Linear(
        in_features=2, out_features=HIDDEN_FEATURE)

    num_hidden_vars = HIDDEN_FEATURE * ACTION_SPACE
    # SATNet over (num_hidden_vars + ACTION_SPACE) variables; 16 and 64 are
    # passed positionally as SATNet's ``m`` and ``aux`` arguments.
    self.final_sat = satnet.SATNet(num_hidden_vars + ACTION_SPACE, 16, 64)

    # int32 mask: 1 for the leading num_hidden_vars variables (given as
    # input), 0 for the trailing ACTION_SPACE variables.
    # NOTE(review): this is a plain tensor attribute, not a registered
    # buffer, so Module.to()/.cuda() will not move it -- confirm that
    # forward() places it on the right device.
    input_mask = [1] * num_hidden_vars + [0] * ACTION_SPACE
    self.is_input = torch.tensor(input_mask, dtype=torch.int)
def main():
    """Train and evaluate a SATNet model on the parity dataset.

    Command-line options (all optional):
      --data_dir     root directory holding ``<seq>/features.pt`` and
                     ``<seq>/labels.pt`` (default: ``parity``)
      --testPct      fraction of samples held out for testing (default 0.1)
      --batchSz      training batch size (default 100)
      --testBatchSz  test batch size (default 500)
      --nEpoch       number of training epochs (default 100)
      --lr           learning rate (default 0.1)
      --seq          selects the ``<data_dir>/<seq>`` sub-directory (default 20)
      --save         optional prefix for the log directory name
      --m, --aux     forwarded to ``satnet.SATNet`` as its ``m`` and ``aux``
                     arguments (default 4 each)
      --no_cuda      stay on CPU even when CUDA is available
      --adam         use Adam instead of SGD

    Results are logged to ``logs/<name>/train.csv`` and ``logs/<name>/test.csv``.
    An existing directory with the same name is deleted, but only after the
    data has been loaded and the train/test split validated.

    Exits via ``parser.error`` (status 2) when a batch size is not positive
    or does not evenly divide its split.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='parity')
    parser.add_argument('--testPct', type=float, default=0.1)
    parser.add_argument('--batchSz', type=int, default=100)
    parser.add_argument('--testBatchSz', type=int, default=500)
    parser.add_argument('--nEpoch', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-1)
    parser.add_argument('--seq', type=int, default=20)
    parser.add_argument('--save', type=str)
    parser.add_argument('--m', type=int, default=4)
    parser.add_argument('--aux', type=int, default=4)
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--adam', action='store_true')
    args = parser.parse_args()

    # Reject non-positive batch sizes up front; they would otherwise fail
    # later with ZeroDivisionError in the divisibility checks below.
    if args.batchSz <= 0:
        parser.error('--batchSz must be positive, got {}'.format(args.batchSz))
    if args.testBatchSz <= 0:
        parser.error(
            '--testBatchSz must be positive, got {}'.format(args.testBatchSz))

    # For debugging: fix the random seed
    npr.seed(1)
    torch.manual_seed(7)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        print('Using', torch.cuda.get_device_name(0))
        # Favour reproducibility over cuDNN autotuning speed.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.cuda.init()

    L = args.seq
    data_path = os.path.join(args.data_dir, str(L))
    # NOTE(review): torch.load unpickles its input -- only point --data_dir
    # at trusted files.
    with open(os.path.join(data_path, 'features.pt'), 'rb') as f:
        X = torch.load(f).float()
    with open(os.path.join(data_path, 'labels.pt'), 'rb') as f:
        Y = torch.load(f).float()
    if args.cuda:
        X, Y = X.cuda(), Y.cuda()

    N = X.size(0)
    nTrain = int(N*(1-args.testPct))
    nTest = N-nTrain
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if nTrain % args.batchSz != 0:
        parser.error(
            'number of training samples ({}) must be divisible by '
            '--batchSz ({})'.format(nTrain, args.batchSz))
    if nTest % args.testBatchSz != 0:
        parser.error(
            'number of test samples ({}) must be divisible by '
            '--testBatchSz ({})'.format(nTest, args.testBatchSz))

    save = 'parity.aux{}-m{}-lr{}-bsz{}'.format(
        args.aux, args.m, args.lr, args.batchSz)
    if args.save:
        save = '{}-{}'.format(args.save, save)
    save = os.path.join('logs', save)
    # Start from a clean log directory; this runs only after the inputs have
    # been validated, so a bad invocation does not wipe earlier results.
    if os.path.isdir(save):
        shutil.rmtree(save)
    os.makedirs(save)

    # Per-sample mask over the 3 SATNet variables: the first two are marked
    # as inputs, the third is not.
    train_is_input = torch.IntTensor([1, 1, 0]).repeat(nTrain, 1)
    test_is_input = torch.IntTensor([1, 1, 0]).repeat(nTest, 1)
    if args.cuda:
        train_is_input, test_is_input = train_is_input.cuda(), test_is_input.cuda()

    train_set = TensorDataset(X[:nTrain], train_is_input, Y[:nTrain])
    test_set = TensorDataset(X[nTrain:], test_is_input, Y[nTrain:])

    model = satnet.SATNet(3, args.m, args.aux, prox_lam=1e-1)
    if args.cuda:
        model = model.cuda()

    if args.adam:
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr)

    train_logger = CSVLogger(os.path.join(save, 'train.csv'))
    test_logger = CSVLogger(os.path.join(save, 'test.csv'))
    fields = ['epoch', 'loss', 'err']
    train_logger.log(fields)
    test_logger.log(fields)

    # Baseline evaluation before any training (epoch 0).
    test(0, model, optimizer, test_logger, test_set, args.testBatchSz)
    for epoch in range(1, args.nEpoch+1):
        train(epoch, model, optimizer, train_logger, train_set, args.batchSz)
        test(epoch, model, optimizer, test_logger, test_set, args.testBatchSz)
def __init__(self, boardSz, aux, m):
    """Create the SATNet layer used for the N-Queens task.

    Args:
        boardSz: board side length; the layer is built over boardSz**2
            variables (presumably one per board cell -- confirm).
        aux: forwarded to ``satnet.SATNet`` as its ``aux`` argument.
        m: forwarded to ``satnet.SATNet`` as its ``m`` argument.
    """
    super().__init__()
    num_vars = boardSz ** 2
    # SATNet's positional order is (n, m, aux), unlike this constructor's
    # (boardSz, aux, m). The solver runs with a tighter tolerance and a
    # larger iteration budget than the plain constructor call.
    self.sat = satnet.SATNet(num_vars, m, aux, eps=1e-6, max_iter=100)
def __init__(self, boardSz, aux, m):
    """Create the SATNet layer used for the Sudoku task.

    Args:
        boardSz: box size of the puzzle; the layer is built over
            boardSz**6 variables (presumably one per (row, column, digit)
            triple of a boardSz**2 x boardSz**2 grid -- confirm).
        aux: forwarded to ``satnet.SATNet`` as its ``aux`` argument.
        m: forwarded to ``satnet.SATNet`` as its ``m`` argument.
    """
    super().__init__()
    num_vars = boardSz ** 6
    # SATNet's positional order is (n, m, aux), unlike this constructor's
    # (boardSz, aux, m).
    self.sat = satnet.SATNet(num_vars, m, aux)