def main():
    """Train the graph neural network Sudoku solver and evaluate it on the test set."""
    args = parser.parse_args()

    if args.cuda:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    data_dir = tools.select_data_dir()

    trainset = Sudoku(data_dir, train=True)
    testset = Sudoku(data_dir, train=False)

    trainloader = DataLoader(trainset, batch_size=args.batch_size, collate_fn=collate)
    testloader = DataLoader(testset, batch_size=args.batch_size, collate_fn=collate)

    # Create network
    gnn = GNN(device)

    if not args.skip_training:
        optimizer = torch.optim.Adam(gnn.parameters(), lr=args.learning_rate)
        loss_method = nn.CrossEntropyLoss(reduction="mean")

        for epoch in range(args.n_epochs):
            for i, data in enumerate(trainloader, 0):
                inputs, targets, src_ids, dst_ids = data
                inputs, targets = inputs.to(device), targets.to(device)
                src_ids, dst_ids = src_ids.to(device), dst_ids.to(device)

                optimizer.zero_grad()

                # The network returns the per-cell digit logits after every
                # message-passing iteration: [n_iters, n_nodes, 9].
                output = gnn(inputs, src_ids, dst_ids)
                output = output.view(-1, output.shape[2])

                # Repeat the targets once per iteration so that every
                # intermediate prediction is supervised.
                targets = targets.repeat(gnn.n_iters, 1).view(-1)

                loss = loss_method(output, targets)
                loss.backward()
                optimizer.step()

            fraction = fraction_of_solved_puzzles(gnn, testloader, device)
            print("Train Epoch {}: Loss: {:.6f} Fraction: {}".format(
                epoch + 1, loss.item(), fraction))

        tools.save_model(gnn, "7_gnn.pth")
    else:
        tools.load_model(gnn, "7_gnn.pth", device)

    # Evaluate the trained model: get the graph iterations for some test puzzles.
    with torch.no_grad():
        inputs, targets, src_ids, dst_ids = next(iter(testloader))
        inputs, targets = inputs.to(device), targets.to(device)
        src_ids, dst_ids = src_ids.to(device), dst_ids.to(device)

        batch_size = inputs.size(0) // 81  # 81 cells (nodes) per puzzle

        outputs = gnn(inputs, src_ids, dst_ids)  # [n_iters, n_nodes, 9]
        solution = outputs.view(gnn.n_iters, batch_size, 9, 9, 9)
        final_solution = solution[-1].argmax(dim=3)

        print("Solved puzzles in the current mini-batch:")
        print((final_solution.view(-1, 81) == targets.view(batch_size, 81)).all(dim=1))

    # Visualize the graph iterations for one of the puzzles.
    ix = 0
    for i in range(gnn.n_iters):
        tools.draw_sudoku(solution[i, ix], logits=True)

    fraction_solved = fraction_of_solved_puzzles(gnn, testloader, device)
    print(f"Accuracy: {fraction_solved}")
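# A minimal sketch of what the `collate` function used above might look like; it is
# an assumption, not this repo's implementation. It assumes each dataset item is a
# pair (input, target) with input of shape (81, 9) (one-hot digits per cell) and
# target of shape (81,), and that src_ids/dst_ids index pairs of cells that share a
# row, column, or 3x3 box, offset by 81 per puzzle when batching.
import torch

def sudoku_edges():
    """Return (src, dst) index tensors for the 81-node Sudoku constraint graph."""
    def neighbours(cell):
        r, c = divmod(cell, 9)
        same_row = {r * 9 + j for j in range(9)}
        same_col = {i * 9 + c for i in range(9)}
        br, bc = 3 * (r // 3), 3 * (c // 3)
        same_box = {(br + i) * 9 + (bc + j) for i in range(3) for j in range(3)}
        return (same_row | same_col | same_box) - {cell}

    src, dst = [], []
    for cell in range(81):
        for n in sorted(neighbours(cell)):
            src.append(cell)
            dst.append(n)
    return torch.tensor(src), torch.tensor(dst)

def collate(samples):
    """Stack puzzles into one big graph, offsetting node ids by 81 per puzzle."""
    inputs, targets = zip(*samples)
    src, dst = sudoku_edges()
    src_ids = torch.cat([src + 81 * i for i in range(len(samples))])
    dst_ids = torch.cat([dst + 81 * i for i in range(len(samples))])
    return torch.cat(inputs), torch.cat(targets), src_ids, dst_ids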
class DCN:
    def __init__(self, batch_size, num_features, num_layers, J, dim_input,
                 clip_grad_norm, logger):
        self.logger = logger
        self.clip_grad_norm = clip_grad_norm
        self.batch_size = batch_size
        self.J = J
        # Three networks: Split proposes a partition of the cities, Tsp solves
        # each sub-problem, and Merge combines the two partial predictions.
        self.Split = Split_GNN(batch_size, num_features, num_layers, J + 2,
                               dim_input=dim_input)
        self.Tsp = GNN(num_features, num_layers, J + 2, dim_input=dim_input)
        self.Merge = GNN(num_features, num_layers, J + 2, dim_input=dim_input)
        self.optimizer_split = optim.RMSprop(self.Split.parameters())
        self.optimizer_tsp = optim.Adamax(self.Tsp.parameters(), lr=1e-3)
        self.optimizer_merge = optim.Adamax(self.Merge.parameters(), lr=1e-3)
        self.test_gens = []
        self.test_gens_labels = []

    def load_split(self, path_load):
        self.Split = self.logger.load_model(path_load, 'split')
        self.optimizer_split = optim.RMSprop(self.Split.parameters())

    def load_tsp(self, path_load):
        self.Tsp = self.logger.load_model(path_load, 'tsp')
        self.optimizer_tsp = optim.Adamax(self.Tsp.parameters(), lr=1e-3)

    def load_merge(self, path_load):
        self.Merge = self.logger.load_model(path_load, 'merge')
        self.optimizer_merge = optim.Adamax(self.Merge.parameters(), lr=1e-3)

    def save_model(self, path_load, it=-1):
        self.logger.save_model(path_load, self.Split, self.Tsp, self.Merge, it=it)

    def set_dataset(self, path_dataset, num_examples_train, num_examples_test,
                    N_train, N_test):
        self.gen = Generator(path_dataset, args.path_tsp)
        self.gen.num_examples_train = num_examples_train
        self.gen.num_examples_test = num_examples_test
        self.gen.N_train = N_train
        self.gen.N_test = N_test
        self.gen.load_dataset()

    def add_test_dataset(self, gen, label):
        self.test_gens.append(gen)
        self.test_gens_labels.append(label)

    def sample_one(self, probs, mode='train'):
        # Draw a hard binary split from the Split network's probabilities and
        # keep the log-probability of the sample for the REINFORCE loss.
        probs = 1e-4 + probs * (1 - 2e-4)  # avoid log(0)
        if mode == 'train':
            rand = torch.zeros(*probs.size()).type(dtype)
            nn.init.uniform(rand)
        else:
            # At test time, threshold deterministically at 0.5.
            rand = torch.ones(*probs.size()).type(dtype) / 2
        bin_sample = probs > Variable(rand)
        sample = bin_sample.clone().type(dtype)
        log_probs_samples = (sample * torch.log(probs)
                             + (1 - sample) * torch.log(1 - probs)).sum(1)
        return bin_sample.data, sample.data, log_probs_samples

    def split_operator(self, W, sample, cities):
        # Extract the sub-graph (adjacency and city coordinates) of the sampled cities.
        bs = sample.size(0)
        Ns1 = sample.long().sum(1)
        N1 = Ns1.max(0)[0][0]
        W1 = torch.zeros(bs, N1, N1).type(dtype)
        cts = torch.zeros(bs, N1, 2).type(dtype)
        for b in range(bs):
            inds = torch.nonzero(sample[b]).squeeze()
            n = Ns1[b]
            W1[b, :n, :n] = W[b].index_select(1, inds).index_select(0, inds)
            cts[b, :n, :] = cities[b].index_select(0, inds)
        return W1, cts

    def compute_other_operators(self, W, Ns, cts, J):
        # Build the family of graph operators (identity, normalized powers of W,
        # degree operator, average pooling) that the GNN takes as input.
        bs = W.size(0)
        N = W.size(-1)
        QQ = W.clone()
        WW = torch.zeros(bs, N, N, J + 2).type(dtype)
        eye = torch.eye(N).type(dtype).unsqueeze(0).expand(bs, N, N)
        WW[:, :, :, 0] = eye
        for j in range(J):
            WW[:, :, :, j + 1] = QQ.clone()
            QQ = torch.bmm(QQ, QQ)
            mx = QQ.max(2)[0].max(1)[0].unsqueeze(1).unsqueeze(2).expand_as(QQ)
            QQ /= torch.clamp(mx, min=1e-6)
            QQ *= np.sqrt(2)
        d = W.sum(1)
        D = d.unsqueeze(1).expand_as(eye) * eye
        WW[:, :, :, J] = D
        U = Ns.float().unsqueeze(1).expand(bs, N)
        U = torch.ge(U, torch.arange(1, N + 1).type(dtype).unsqueeze(0).expand(bs, N))
        U = U.float() / Ns.float().unsqueeze(1).expand_as(U)
        U = torch.bmm(U.unsqueeze(2), U.unsqueeze(1))
        WW[:, :, :, J + 1] = U
        x = torch.cat((d.unsqueeze(2), cts), 2)
        return Variable(WW), Variable(x), Variable(WW[:, :, :, 1])

    def compute_operators(self, W, sample, cities, J):
        # Split the graph into the sampled and complementary sub-graphs and
        # compute the operator families for both.
        bs = sample.size(0)
        Ns1 = sample.long().sum(1)
        Ns2 = (1 - sample.long()).sum(1)
        W1, cts1 = self.split_operator(W, sample, cities)
        W2, cts2 = self.split_operator(W, 1 - sample, cities)
        op1 = self.compute_other_operators(W1, Ns1, cts1, J)
        op2 = self.compute_other_operators(W2, Ns2, cts2, J)
        return op1, op2

    def join_preds(self, pred1, pred2, sample):
        # Re-assemble the two partial predictions into a single N x N score
        # matrix, permuting rows/columns back to the original city ordering.
        bs = pred1.size(0)
        N = sample.size(1)
        N1 = pred1.size(1)
        N2 = pred2.size(1)
        pred = Variable(torch.ones(bs, N, N).type(dtype) * (-999))
        for b in range(bs):
            n1 = sample[b].long().sum(0)[0]
            n2 = (1 - sample[b]).long().sum(0)[0]
            inds = torch.cat((torch.nonzero(sample[b]).type(dtype),
                              torch.nonzero(1 - sample[b]).type(dtype)), 0).squeeze()
            inds = torch.topk(-inds, N)[1]
            M = Variable(torch.zeros(N, N).type(dtype))
            M[:n1, :n1] = pred1[b, :n1, :n1]
            M[n1:, n1:] = pred2[b, :n2, :n2]
            inds = Variable(inds, requires_grad=False)
            M = M.index_select(0, inds).index_select(1, inds)
            pred[b, :, :] = M
        return pred

    def forward(self, input, W, cities):
        # 1) sample a split, 2) solve both halves with the Tsp network,
        # 3) merge the partial predictions into a full one.
        scores, probs = self.Split(input)
        # variance = compute_variance(probs)
        bin_sample, sample, log_probs_samples = self.sample_one(probs, mode='train')
        op1, op2 = self.compute_operators(W.data, bin_sample, cities, self.J)
        pred1 = self.Tsp(op1)
        pred2 = self.Tsp(op2)
        partial_pred = self.join_preds(pred1, pred2, bin_sample)
        partial_pred = F.sigmoid(partial_pred)
        pred = self.Merge((input[0], input[1], partial_pred))
        return probs, log_probs_samples, pred

    def compute_loss(self, pred, target, logprobs):
        # The Tsp/Merge networks get a supervised cross-entropy loss; the Split
        # network gets that loss as a reward, weighted by the log-probability of
        # the sampled split (score-function / REINFORCE estimator).
        loss_split = 0.0
        loss_merge = 0.0
        labels = target[1]
        for i in range(labels.size()[-1]):
            for j in range(labels.size()[0]):
                lab = labels[j, :, i].contiguous().view(-1)
                cel = CEL(pred[j], lab)
                loss_merge += cel
                loss_split += Variable(cel.data) * logprobs[j]
        return loss_merge / pred.size(0), loss_split / pred.size(0)

    def train(self, iterations, print_freq, test_freq, save_freq, path_model):
        for it in range(iterations):
            start = time.time()
            batch = self.gen.sample_batch(self.batch_size,
                                          cuda=torch.cuda.is_available())
            input, W, WTSP, labels, target, cities, perms, costs = extract(batch)
            probs, log_probs_samples, pred = self.forward(input, W, cities)
            loss_merge, loss_split = self.compute_loss(pred, target, log_probs_samples)
            # loss_split -= variance*rf

            # Update the Split network with the REINFORCE loss only.
            self.Split.zero_grad()
            loss_split.backward()
            nn.utils.clip_grad_norm(self.Split.parameters(), self.clip_grad_norm)
            self.optimizer_split.step()

            # Update the Tsp and Merge networks with the supervised loss.
            self.Tsp.zero_grad()
            self.Merge.zero_grad()
            loss_merge.backward()
            nn.utils.clip_grad_norm(self.Tsp.parameters(), self.clip_grad_norm)
            nn.utils.clip_grad_norm(self.Merge.parameters(), self.clip_grad_norm)
            self.optimizer_tsp.step()
            self.optimizer_merge.step()

            self.logger.add_train_loss(loss_split, loss_merge)
            self.logger.add_train_accuracy(pred, labels, W)
            elapsed = time.time() - start

            if it % print_freq == 0 and it > 0:
                loss_split = loss_split.data.cpu().numpy()[0]
                loss_merge = loss_merge.data.cpu().numpy()[0]
                out = ['---', it, loss_split, loss_merge,
                       self.logger.cost_train[-1],
                       self.logger.accuracy_train[-1], elapsed]
                print(template_train1.format(*info_train))
                print(template_train2.format(*out))
                # print(variance)
                # print(probs[0])
                # plot_clusters(it, probs[0], cities[0])
                # os.system('eog ./plots/clustering/clustering_it_{}.png'.format(it))
            if it % test_freq == 0 and it >= 0:
                self.test()
                # self.logger.plot_test_logs()
            if it % save_freq == 0 and it > 0:
                self.save_model(path_model, it)

    def test(self):
        for i, gen in enumerate(self.test_gens):
            print('Test: {}'.format(self.test_gens_labels[i]))
            self.test_gen(gen)

    def test_gen(self, gen):
        iterations_test = int(gen.num_examples_test / self.batch_size)
        for it in range(iterations_test):
            start = time.time()
            batch = gen.sample_batch(self.batch_size, is_training=False, it=it,
                                     cuda=torch.cuda.is_available())
            input, W, WTSP, labels, target, cities, perms, costs = extract(batch)
            probs, log_probs_samples, pred = self.forward(input, W, cities)
            loss_merge, loss_split = self.compute_loss(pred, target, log_probs_samples)
            # loss_split -= variance*rf
            last = (it == iterations_test - 1)
            self.logger.add_test_accuracy(pred, labels, perms, W, cities, costs,
                                          last=last, beam_size=beam_size)
            self.logger.add_test_loss(loss_split, loss_merge, last=last)
            elapsed = time.time() - start
            '''if not last and it % 100 == 0:
                loss = loss.data.cpu().numpy()[0]
                out = ['---', it, loss, logger.accuracy_test_aux[-1],
                       logger.cost_test_aux[-1], beam_size, elapsed]
                print(template_test1.format(*info_test))
                print(template_test2.format(*out))'''
        print('TEST COST: {} | TEST ACCURACY {}\n'
              .format(self.logger.cost_test[-1], self.logger.accuracy_test[-1]))
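# A self-contained sketch (not the repo's code) of the sampling trick used in
# DCN.sample_one and DCN.compute_loss above: draw a hard Bernoulli split from the
# Split network's probabilities, keep the log-probability of the drawn sample, and
# train the (non-differentiable) split with a score-function / REINFORCE loss.
# All tensors here are stand-ins written for current PyTorch.
import torch

logits = torch.randn(4, 10, requires_grad=True)    # stand-in for the Split network's output
probs = torch.sigmoid(logits)
probs = 1e-4 + probs * (1 - 2e-4)                   # clamp away from 0/1, as in sample_one
sample = (torch.rand_like(probs) < probs).float()   # hard, non-differentiable split
log_probs = (sample * probs.log() + (1 - sample) * (1 - probs).log()).sum(1)

reward = torch.randn(4)                             # stand-in for the per-example merge loss
loss_split = (reward * log_probs).mean()            # score-function (REINFORCE) estimator
loss_split.backward()                               # gradients reach logits through log_probs only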
WW = Variable(WW).type(dtype)
x = Variable(x).type(dtype)
y = Variable(y).type(dtype)
# print(WW, x, y)
partial_pred = Tsp((WW, x, y))
partial_pred = partial_pred * Variable(Phi)
# print(input[0], input[1], partial_pred)
pred = Merge((input[0], input[1], partial_pred))
loss_supervised, loss_reinforce = compute_loss(pred, target, log_probs_samples)
loss_reinforce -= variance * rf

# Update the Split network with the REINFORCE loss only.
Split.zero_grad()
loss_reinforce.backward()
nn.utils.clip_grad_norm(Split.parameters(), clip_grad)
optimizer_split.step()

# Update the Tsp and Merge networks with the supervised loss.
Tsp.zero_grad()
Merge.zero_grad()
loss_supervised.backward()
nn.utils.clip_grad_norm(Tsp.parameters(), clip_grad)
nn.utils.clip_grad_norm(Merge.parameters(), clip_grad)
optimizer_tsp.step()
optimizer_merge.step()

if it % 50 == 0:
    acc = compute_accuracy(pred, labels)
    out = [it, loss_reinforce.data[0], loss_supervised.data[0], acc]
    print(template.format(*out))
    print(variance)
    # print(probs[0])
    # print('iteration {}:\nloss_r={}\nloss_s={}'.format(it, loss_reinforce.data[0], loss_supervised.data[0]))
    plot_clusters(it, probs[0], cities[0])