Exemplo n.º 1
0
             (32, 30), (32, 31), (33, 8), (33, 9), (33, 13), (33, 14),
             (33, 15), (33, 18), (33, 19), (33, 20), (33, 22), (33, 23),
             (33, 26), (33, 27), (33, 28), (33, 29), (33, 30), (33, 31),
             (33, 32)]
# Add the edges in both directions. The forward copies are labelled l_e = 1 and
# the reversed copies l_e = 0. Assuming g had no edges before this point, the
# forward copies get edge ids [0, n_edges) and the reversed ones
# [n_edges, 2 * n_edges); e_1 / e_0 below rely on that ordering.
src, dst = tuple(zip(*edge_list))
n_edges = len(src)  # derived from edge_list instead of a hard-coded count
g.add_edges(src, dst, {'l_e': torch.ones(n_edges).to(device)})
g.add_edges(dst, src, {'l_e': torch.zeros(n_edges).to(device)})

# Edge-id tensors selecting the forward (l_e = 1) and reversed (l_e = 0) copies.
e_1 = torch.arange(0, n_edges, device=device)
e_0 = torch.arange(n_edges, 2 * n_edges, device=device)

# Per-node input features, and zero-initialised node / edge hidden states.
# The sizes come from the graph so they always match what DGL expects.
n_nodes = g.number_of_nodes()
g.ndata['x_v'] = torch.rand(n_nodes, EMB_SIZE).to(device)
g.ndata['h_v'] = torch.zeros(n_nodes, EMB_SIZE).to(device)
g.edata['h_e'] = torch.zeros(g.number_of_edges(), EMB_SIZE).to(device)

# To debug on a smaller graph, build g with 4 nodes and use
# edge_list = [(1, 0), (2, 0), (2, 1), (3, 0)]; the sizes above adapt automatically.

# pass it through.
g_1 = model.forward(e_1, e_0, g)
print("done")
def main():
    """Train (or load) the Sudoku GNN, then report and visualise its test results.

    Arguments are parsed with ``parser``. Unless ``args.skip_training`` is set,
    the network is trained for ``args.n_epochs`` epochs with Adam
    (``lr=args.learning_rate``) and mean cross-entropy loss. After each epoch
    the last batch loss and the test-set fraction of solved puzzles are printed.
    The weights are then saved to ``7_gnn.pth``. When training is skipped, the
    weights are loaded from that file instead.

    Afterwards, one test mini-batch is run through the network and the
    per-puzzle "solved" flags are printed. Every graph iteration of the first
    puzzle is drawn with ``tools.draw_sudoku``, and the overall test accuracy
    is printed.
    """
    args = parser.parse_args()
    device = torch.device("cuda:0" if args.cuda else "cpu")

    data_dir = tools.select_data_dir()

    trainset = Sudoku(data_dir, train=True)
    testset = Sudoku(data_dir, train=False)

    trainloader = DataLoader(trainset, batch_size=args.batch_size, collate_fn=collate)
    testloader = DataLoader(testset, batch_size=args.batch_size, collate_fn=collate)

    # Create network
    gnn = GNN(device)
    if not args.skip_training:
        optimizer = torch.optim.Adam(gnn.parameters(), lr=args.learning_rate)
        loss_method = nn.CrossEntropyLoss(reduction="mean")

        for epoch in range(args.n_epochs):
            # Stays None if the loader yields no batches, so the progress
            # print below cannot hit an undefined name.
            loss = None
            for inputs, targets, src_ids, dst_ids in trainloader:
                inputs, targets = inputs.to(device), targets.to(device)
                src_ids, dst_ids = src_ids.to(device), dst_ids.to(device)

                # optimizer.zero_grad() already clears every gnn parameter.
                optimizer.zero_grad()
                output = gnn(inputs, src_ids, dst_ids).to(device)
                # Flatten the iteration axis so every graph iteration's
                # prediction is supervised.
                output = output.view(-1, output.shape[2])
                # One copy of the targets per graph iteration, matching the
                # flattened output (previously hard-coded as 7).
                iter_targets = targets.repeat(gnn.n_iters, 1).view(-1)
                loss = loss_method(output, iter_targets)
                loss.backward()
                optimizer.step()

            fraction = fraction_of_solved_puzzles(gnn, testloader, device)
            loss_value = loss.item() if loss is not None else float("nan")

            print("Train Epoch {}: Loss: {:.6f} Fraction: {}".format(epoch + 1, loss_value, fraction))

        tools.save_model(gnn, "7_gnn.pth")
    else:
        tools.load_model(gnn, "7_gnn.pth", device)

    # Evaluate the trained model
    # Get graph iterations for some test puzzles
    with torch.no_grad():
        # DataLoader iterators have no ``.next()`` method in current PyTorch;
        # use the builtin next() instead.
        inputs, targets, src_ids, dst_ids = next(iter(testloader))
        inputs, targets = inputs.to(device), targets.to(device)
        src_ids, dst_ids = src_ids.to(device), dst_ids.to(device)

        # Nodes of all puzzles in the batch are stacked, 81 per puzzle.
        batch_size = inputs.size(0) // 81
        outputs = gnn(inputs, src_ids, dst_ids).to(device)  # [n_iters, n_nodes, 9]

        # [n_iters, batch, 9, 9, 9]: the last axis holds the 9 class scores
        # of each cell. view/argmax keep the tensor on its current device.
        solution = outputs.view(gnn.n_iters, batch_size, 9, 9, 9)
        final_solution = solution[-1].argmax(dim=3)
        print("Solved puzzles in the current mini-batch:")
        print((final_solution.view(-1, 81) == targets.view(batch_size, 81)).all(dim=1))

    # Visualize graph iteration for one of the puzzles
    ix = 0
    for i in range(gnn.n_iters):
        tools.draw_sudoku(solution[i, ix], logits=True)

    fraction_solved = fraction_of_solved_puzzles(gnn, testloader, device)
    print(f"Accuracy {fraction_solved}")