def main():
    """
    
    """
    args = parser.parse_args()  # parser: argparse parser defined at module level (not shown)
    if args.cuda:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    data_dir = tools.select_data_dir()

    trainset = Sudoku(data_dir, train=True)
    testset = Sudoku(data_dir, train=False)

    trainloader = DataLoader(trainset, batch_size=args.batch_size, collate_fn=collate)
    testloader = DataLoader(testset, batch_size=args.batch_size, collate_fn=collate)

    # Create network
    gnn = GNN(device)
    if not args.skip_training:
        optimizer = torch.optim.Adam(gnn.parameters(), lr=args.learning_rate)
        loss_method = nn.CrossEntropyLoss(reduction="mean")

        for epoch in range(args.n_epochs):
            for i, data in enumerate(trainloader, 0):
                inputs, targets, src_ids, dst_ids = data
                inputs, targets = inputs.to(device), targets.to(device)
                src_ids, dst_ids = src_ids.to(device), dst_ids.to(device)
                optimizer.zero_grad()
                output = gnn(inputs, src_ids, dst_ids)  # [n_iters, n_nodes, 9]
                output = output.view(-1, output.shape[2])
                # The loss is computed on the output of every graph iteration,
                # so the targets are repeated once per iteration.
                targets = targets.repeat(gnn.n_iters, 1).view(-1)
                loss = loss_method(output, targets)
                loss.backward()
                optimizer.step()

            fraction = fraction_of_solved_puzzles(gnn, testloader, device)

            print("Train Epoch {}: Loss: {:.6f} Fraction: {}".format(epoch + 1, loss.item(), fraction))

        tools.save_model(gnn, "7_gnn.pth")
    else:
        gnn = GNN(device)
        tools.load_model(gnn, "7_gnn.pth", device)

    # Evaluate the trained model
    # Get graph iterations for some test puzzles
    with torch.no_grad():
        inputs, targets, src_ids, dst_ids = next(iter(testloader))
        inputs, targets = inputs.to(device), targets.to(device)
        src_ids, dst_ids = src_ids.to(device), dst_ids.to(device)

        batch_size = inputs.size(0) // 81  # each puzzle is a graph of 81 cell nodes
        outputs = gnn(inputs, src_ids, dst_ids)  # [n_iters, n_nodes, 9]

        solution = outputs.view(gnn.n_iters, batch_size, 9, 9, 9)
        final_solution = solution[-1].argmax(dim=3)  # digits predicted at the last iteration
        print("Solved puzzles in the current mini-batch:")
        print((final_solution.view(-1, 81) == targets.view(batch_size, 81)).all(dim=1))

    # Visualize the graph iterations for one of the puzzles
    ix = 0
    for i in range(gnn.n_iters):
        tools.draw_sudoku(solution[i, ix], logits=True)

    fraction_solved = fraction_of_solved_puzzles(gnn, testloader, device)
    print(f"Accuracy: {fraction_solved}")
Example #2
class DCN:
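    # Divide-and-conquer TSP solver: Split_GNN samples a partition of the
    # cities, a shared Tsp GNN predicts a tour over each part, and a Merge GNN
    # recombines the two partial predictions into a full tour prediction.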
    def __init__(self, batch_size, num_features, num_layers, J, dim_input, clip_grad_norm, logger):
        self.logger = logger
        self.clip_grad_norm = clip_grad_norm
        self.batch_size = batch_size
        self.J = J
        self.Split = Split_GNN(batch_size, num_features, num_layers, J+2, dim_input=dim_input)
        self.Tsp = GNN(num_features, num_layers, J+2, dim_input=dim_input)
        self.Merge = GNN(num_features, num_layers, J+2, dim_input=dim_input)
        self.optimizer_split = optim.RMSprop(self.Split.parameters())
        self.optimizer_tsp = optim.Adamax(self.Tsp.parameters(), lr=1e-3)
        self.optimizer_merge = optim.Adamax(self.Merge.parameters(), lr=1e-3)
        self.test_gens = []
        self.test_gens_labels = []
   
    def load_split(self, path_load):
        self.Split = self.logger.load_model(path_load, 'split')
        self.optimizer_split = optim.RMSprop(self.Split.parameters())
    
    def load_tsp(self, path_load):
        self.Tsp = self.logger.load_model(path_load, 'tsp')
        self.optimizer_tsp = optim.Adamax(self.Tsp.parameters(), lr=1e-3)

    def load_merge(self, path_load):
        self.Merge = self.logger.load_model(path_load, 'merge')
        self.optimizer_merge = optim.Adamax(self.Merge.parameters(), lr=1e-3)

    def save_model(self, path_load, it=-1):
        self.logger.save_model(path_load, self.Split, self.Tsp, self.Merge, it=it)
    
    def set_dataset(self, path_dataset, num_examples_train, num_examples_test, N_train, N_test):
        self.gen = Generator(path_dataset, args.path_tsp)
        self.gen.num_examples_train = num_examples_train
        self.gen.num_examples_test = num_examples_test
        self.gen.N_train = N_train
        self.gen.N_test = N_test
        self.gen.load_dataset()
    
    def add_test_dataset(self, gen, label):
        self.test_gens.append(gen)
        self.test_gens_labels.append(label)

    def sample_one(self, probs, mode='train'):
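        # Bernoulli-sample a binary node partition from `probs` and return the
        # log-likelihood of the drawn sample (used as the REINFORCE term).
        # At test time a deterministic 0.5 threshold replaces the sampling.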
        probs = 1e-4 + probs*(1 - 2e-4) # to avoid log(0)
        if mode == 'train':
            rand = torch.rand(*probs.size()).type(dtype)
        else:
            rand = torch.ones(*probs.size()).type(dtype) / 2
        bin_sample = probs > Variable(rand)
        sample = bin_sample.clone().type(dtype)
        log_probs_samples = (sample*torch.log(probs) + (1-sample)*torch.log(1-probs)).sum(1)
        return bin_sample.data, sample.data, log_probs_samples
    
    def split_operator(self, W, sample, cities):
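        # Extract the sub-adjacency matrix and city coordinates of the sampled
        # subset, zero-padded up to the largest subset size in the batch.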
        bs = sample.size(0)
        Ns1 = sample.long().sum(1)
        N1 = Ns1.max(0)[0][0]
        W1 = torch.zeros(bs, N1, N1).type(dtype)
        cts = torch.zeros(bs, N1, 2).type(dtype)
        for b in range(bs):
            inds = torch.nonzero(sample[b]).squeeze()
            n = Ns1[b]
            W1[b,:n,:n] = W[b].index_select(1, inds).index_select(0, inds)
            cts[b,:n,:] = cities[b].index_select(0, inds)
        return W1, cts

    def compute_other_operators(self, W, Ns, cts, J):
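        # Build the J+2 operators fed to the GNN: identity, max-normalized
        # powers W^(2^j), the degree operator D and the rank-one averaging
        # operator U, plus node features x = (degree, city coordinates).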
        bs = W.size(0)
        N = W.size(-1)
        QQ = W.clone()
        WW = torch.zeros(bs, N, N, J + 2).type(dtype)
        eye = torch.eye(N).type(dtype).unsqueeze(0).expand(bs,N,N)
        WW[:, :, :, 0] = eye
        for j in range(J):
            WW[:, :, :, j+1] = QQ.clone()
            QQ = torch.bmm(QQ, QQ)
            mx = QQ.max(2)[0].max(1)[0].unsqueeze(1).unsqueeze(2).expand_as(QQ)
            QQ /= torch.clamp(mx, min=1e-6)
            QQ *= np.sqrt(2)
        d = W.sum(1)
        D = d.unsqueeze(1).expand_as(eye) * eye  # degree operator
        WW[:, :, :, J] = D  # note: this overwrites the last power written by the loop above
        U = Ns.float().unsqueeze(1).expand(bs,N)
        U = torch.ge(U, torch.arange(1,N+1).type(dtype).unsqueeze(0).expand(bs,N))
        U = U.float() / Ns.float().unsqueeze(1).expand_as(U)
        U = torch.bmm(U.unsqueeze(2),U.unsqueeze(1))
        WW[:, :, :, J+1] = U
        x = torch.cat((d.unsqueeze(2),cts),2)
        return Variable(WW), Variable(x), Variable(WW[:,:,:,1])


    def compute_operators(self, W, sample, cities, J):
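        # Split the graph into the sampled subset and its complement and build
        # the GNN operators for each half.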
        bs = sample.size(0)
        Ns1 = sample.long().sum(1)
        Ns2 = (1-sample.long()).sum(1)
        W1, cts1 = self.split_operator(W, sample, cities)
        W2, cts2 = self.split_operator(W, 1-sample, cities)
        op1 = self.compute_other_operators(W1, Ns1, cts1, J)
        op2 = self.compute_other_operators(W2, Ns2, cts2, J)
        return op1, op2
    
    def join_preds(self, pred1, pred2, sample):
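        # Place the two per-subset predictions as diagonal blocks of an N x N
        # matrix and permute rows/columns back to the original city ordering.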
        bs = pred1.size(0)
        N = sample.size(1)
        N1 = pred1.size(1)
        N2 = pred2.size(1)
        pred = Variable(torch.ones(bs, N, N).type(dtype) * (-999))  # cross-subset edges keep a large negative logit

        for b in range(bs):
            n1 = sample[b].long().sum(0)[0]
            n2 = (1-sample[b]).long().sum(0)[0]
            inds = torch.cat((torch.nonzero(sample[b]).type(dtype),
                              torch.nonzero(1 - sample[b]).type(dtype)), 0).squeeze()
            inds = torch.topk(-inds, N)[1]  # argsort of inds, i.e. the inverse permutation
            M = Variable(torch.zeros(N,N).type(dtype))
            M[:n1,:n1] = pred1[b,:n1,:n1]
            M[n1:,n1:] = pred2[b,:n2,:n2]
            inds = Variable(inds, requires_grad=False)
            M = M.index_select(0,inds).index_select(1,inds)
            pred[b, :, :] = M
        return pred

    def forward(self, input, W, cities):
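        # Divide-and-conquer pass: sample a partition with Split, solve each
        # part with the shared Tsp GNN, stitch the partial predictions back
        # together, and refine the result with Merge.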
        scores, probs = self.Split(input)
        #variance = compute_variance(probs)
        bin_sample, sample, log_probs_samples = self.sample_one(probs, mode='train')
        op1, op2 = self.compute_operators(W.data, bin_sample, cities, self.J)
        pred1 = self.Tsp(op1)
        pred2 = self.Tsp(op2)
        partial_pred = self.join_preds(pred1, pred2, bin_sample)
        partial_pred = torch.sigmoid(partial_pred)
        
        pred = self.Merge((input[0], input[1], partial_pred))
        return probs, log_probs_samples, pred
    
    def compute_loss(self, pred, target, logprobs):
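        # Supervised cross-entropy for Tsp/Merge plus a REINFORCE loss for
        # Split: the detached per-example CE is the reward that weights the
        # log-probability of the sampled partition.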
        loss_split = 0.0
        loss_merge = 0.0
        labels = target[1]
        for i in range(labels.size()[-1]):
            for j in range(labels.size()[0]):
                lab = labels[j, :, i].contiguous().view(-1)
                cel = CEL(pred[j], lab)  # CEL: cross-entropy criterion defined at module level
                loss_merge += cel
                loss_split += Variable(cel.data) * logprobs[j]  # detached CE acts as the REINFORCE reward
        return loss_merge/pred.size(0), loss_split/pred.size(0)
    
    def train(self, iterations, print_freq, test_freq, save_freq, path_model):
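        # Alternate a REINFORCE update of Split with supervised updates of Tsp
        # and Merge; log metrics and periodically test and checkpoint.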
        for it in range(iterations):
            start = time.time()
            batch = self.gen.sample_batch(self.batch_size, cuda=torch.cuda.is_available())
            input, W, WTSP, labels, target, cities, perms, costs = extract(batch)
            probs, log_probs_samples, pred = self.forward(input, W, cities)
            loss_merge, loss_split = self.compute_loss(pred, target, log_probs_samples)
            #loss_split -= variance*rf
            self.Split.zero_grad()
            loss_split.backward()
            nn.utils.clip_grad_norm(self.Split.parameters(), self.clip_grad_norm)
            self.optimizer_split.step()
            self.Tsp.zero_grad()
            self.Merge.zero_grad()
            loss_merge.backward()
            nn.utils.clip_grad_norm(self.Tsp.parameters(), self.clip_grad_norm)
            nn.utils.clip_grad_norm(self.Merge.parameters(), self.clip_grad_norm)
            self.optimizer_tsp.step()
            self.optimizer_merge.step()
            
            self.logger.add_train_loss(loss_split, loss_merge)
            self.logger.add_train_accuracy(pred, labels, W)
            elapsed = time.time() - start
            
            if it % print_freq == 0 and it > 0:
                loss_split = loss_split.data.cpu().numpy()[0]
                loss_merge = loss_merge.data.cpu().numpy()[0]
                out = ['---', it, loss_split, loss_merge, self.logger.cost_train[-1],
                       self.logger.accuracy_train[-1], elapsed]
                print(template_train1.format(*info_train))  # templates and headers defined at module level
                print(template_train2.format(*out))
                #print(variance)
                #print(probs[0])
                #plot_clusters(it, probs[0], cities[0])
                #os.system('eog ./plots/clustering/clustering_it_{}.png'.format(it))
            if it % test_freq == 0 and it >= 0:
                self.test()
                #self.logger.plot_test_logs()
            if it % save_freq == 0 and it > 0:
                self.save_model(path_model, it)
    
    def test(self):
        for i, gen in enumerate(self.test_gens):
            print('Test: {}'.format(self.test_gens_labels[i]))
            self.test_gen(gen)

    def test_gen(self, gen):
        iterations_test = int(gen.num_examples_test / self.batch_size)
        for it in range(iterations_test):
            start = time.time()
            batch = gen.sample_batch(self.batch_size, is_training=False, it=it,
                                     cuda=torch.cuda.is_available())
            input, W, WTSP, labels, target, cities, perms, costs = extract(batch)
            probs, log_probs_samples, pred = self.forward(input, W, cities)
            loss_merge, loss_split = self.compute_loss(pred, target, log_probs_samples)
            #loss_split -= variance*rf
            last = (it == iterations_test-1)
            self.logger.add_test_accuracy(pred, labels, perms, W, cities, costs,
                                          last=last, beam_size=beam_size)  # beam_size: module-level setting
            self.logger.add_test_loss(loss_split, loss_merge, last=last)
            elapsed = time.time() - start
            '''if not last and it % 100 == 0:
                loss = loss.data.cpu().numpy()[0]
                out = ['---', it, loss, logger.accuracy_test_aux[-1], 
                    logger.cost_test_aux[-1], beam_size, elapsed]
                print(template_test1.format(*info_test))
                print(template_test2.format(*out))'''
        print('TEST COST: {} | TEST ACCURACY {}\n'
            .format(self.logger.cost_test[-1], self.logger.accuracy_test[-1]))
Example #3
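# Fragment of a script-level training step: the enclosing loop and the
# definitions of Split, Tsp, Merge, their optimizers, WW, x, y, Phi, dtype,
# clip_grad, rf and variance come from the surrounding script (not shown).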
        WW = Variable(WW).type(dtype)
        x = Variable(x).type(dtype)
        y = Variable(y).type(dtype)
        #print(WW, x, y)
        partial_pred = Tsp((WW, x, y))
        partial_pred = partial_pred * Variable(Phi)
        #print(input[0], input[1], partial_pred)
        pred = Merge((input[0], input[1], partial_pred))
        loss_supervised, loss_reinforce = compute_loss(pred, target,
                                                       log_probs_samples)
        loss_reinforce -= variance * rf
        Split.zero_grad()
        loss_reinforce.backward()
        nn.utils.clip_grad_norm(Split.parameters(), clip_grad)
        optimizer_split.step()
        Tsp.zero_grad()
        Merge.zero_grad()
        loss_supervised.backward()
        nn.utils.clip_grad_norm(Tsp.parameters(), clip_grad)
        nn.utils.clip_grad_norm(Merge.parameters(), clip_grad)
        optimizer_tsp.step()
        optimizer_merge.step()

        if it % 50 == 0:
            acc = compute_accuracy(pred, labels)
            out = [it, loss_reinforce.data[0], loss_supervised.data[0], acc]
            print(template.format(*out))
            print(variance)
            #print(probs[0])
            #print('iteracio {}:\nloss_r={}\nloss_s={}'.format(it,loss_reinforce.data[0],loss_supervised.data[0]))
            plot_clusters(it, probs[0], cities[0])