Example #1
def get_datasets(config_data):
    # Build train/test DataLoaders from the 'experiment' section of the config.
    train_dataloader = DataLoader(
        CycleDataset(config_data, 'train'),
        batch_size=config_data['experiment']['batch_size'],
        num_workers=config_data['experiment']['num_worker'])
    test_dataloader = DataLoader(
        CycleDataset(config_data, 'test'),
        batch_size=config_data['experiment']['batch_size'],
        num_workers=config_data['experiment']['num_worker'])
    return train_dataloader, test_dataloader
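For orientation, here is a minimal sketch of the config_data layout this function reads, assuming only the keys used above (any additional settings CycleDataset expects are not shown in this example):

config_data = {
    'experiment': {
        'batch_size': 32,   # batch size shared by both loaders
        'num_worker': 4,    # DataLoader worker processes
    },
}
train_loader, test_loader = get_datasets(config_data)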
Example #2
def train(num_epochs=200):
    # Train a ConstructorOnly model on the cycle dataset, checkpointing every 5 epochs.
    num_epochs = int(num_epochs)
    sgvae = ConstructorOnly(rounds=6,
                            node_dim=5,
                            msg_dim=6,
                            edge_dim=3,
                            graph_dim=30,
                            num_node_types=2,
                            lamb=1)
    trainData = CycleDataset('cycles/train.cycles')
    valData = CycleDataset('cycles/val.cycles')

    trainLoader = utils.DataLoader(trainData,
                                   batch_size=1,
                                   shuffle=False,
                                   num_workers=0,
                                   collate_fn=trainData.collate_single)

    optimizer = optim.SGD(sgvae.parameters(), lr=0.001, momentum=0.9)

    # for g in trainLoader:
    #     print(g)
    #     break

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        print("Epoch", epoch)
        if epoch % 5 == 0 and epoch != 0:
            print("Saving to {}.params".format(epoch))
            torch.save(sgvae.state_dict(), 'params/{}.params'.format(epoch))
            eval(epoch, writeFile=True, z_value=z)  # the module's own eval(), not the builtin
        loss_sum = 0
        for g in tqdm(trainLoader, desc="[{}]".format(epoch)):
            loss, genGraph, z = sgvae.loss(g, return_graph=True)
            loss_sum += loss
        loss_sum /= len(trainLoader)
        loss_sum.backward()
        optimizer.step()
        print(loss_sum)
    print("Saving to {}.params".format(epoch))
    torch.save(sgvae.state_dict(), 'params/{}.params'.format(epoch))
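The loop above averages the loss over the whole epoch and calls backward()/step() once per epoch. A hedged alternative sketch (not the author's code) that steps the optimizer per graph, so the autograd graph for each sample is freed immediately:

for epoch in range(num_epochs):
    running_loss = 0.0
    for g in tqdm(trainLoader, desc="[{}]".format(epoch)):
        optimizer.zero_grad()
        loss, genGraph, z = sgvae.loss(g, return_graph=True)
        loss.backward()
        optimizer.step()
        running_loss += float(loss)
    print(running_loss / len(trainLoader))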
Example #3
def main():
    # Dispatch on the first command-line argument: 'train', 'vis', or anything else for eval.
    if sys.argv[1] == 'train':
        train(*sys.argv[2:])
    elif sys.argv[1] == 'vis':
        trainData = CycleDataset('cycles/five_train.cycles')
        for x in trainData:
            break
        assert is_cycle(x)
        nx.draw(x.to_networkx())
        plt.show()
    else:
        eval(sys.argv[2], calc_cycle=True)
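A hedged usage sketch of the dispatch above; the file name cycles.py and the __main__ guard are assumptions, not shown in the example:

if __name__ == '__main__':
    # python cycles.py train 200       -> train(num_epochs=200)
    # python cycles.py vis             -> draw one training cycle with networkx
    # python cycles.py <anything> <N>  -> eval(<N>, calc_cycle=True)
    main()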
Example #4
        seen.add(int(node))
        node = list(neighbors)[0]
    return g.number_of_nodes() - len(seen) == 1
    # and 0 in [int(f) for f in g.successors(node)] and len(g.successors(node)) == 2
    # return g.number_of_nodes() == 1


epoch = 60

params = 'train/cycles{}.params'.format(epoch)
sgvae = torch.load(params)
#sgvae = None
optimizer = torch.load('optimizer')
#optimizer = None
trainData = CycleDataset('datasets/cycles.pkl')
# Note: the returned optimizer is bound to the name 'optim', shadowing the optim module if it was imported.
sgvae, optim = train('cycles',
                     trainData,
                     better_is_cycle,
                     batch_size=10,
                     num_epochs=epoch + 21,
                     sgvae=sgvae,
                     start_epoch=epoch + 1,
                     optimizer=optimizer)  #torch.load('optimizer'))
torch.save(optim, 'optimizer')
'''Interpolation
sgvae = RESTORE PARAMS
import cycle_dataset
x1 = cycle_dataset.create_cycle_with_size(5)
x2 = cycle_dataset.create_cycle_with_size(15)
evaluate(sgvae, is_cycle, x1, x2, lambda g: g.number_of_nodes())
'''
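Note that torch.load(params) above only yields a model object if the entire module was pickled. If the checkpoint instead holds a state_dict, as the train() functions in the other examples save, restoring would look more like this sketch (the SGVAE constructor arguments are copied from Example #5 and are an assumption here):

sgvae = SGVAE(rounds=2, node_dim=5, msg_dim=6, edge_dim=3,
              graph_dim=30, num_node_types=2, lamb=1)
sgvae.load_state_dict(torch.load(params))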
Example #5
File: jcycles.py  Project: bjing2016/sgvae
def train(num_epochs=200):
    num_epochs = int(num_epochs)

    sgvae = SGVAE(rounds=2,
                  node_dim=5,
                  msg_dim=6,
                  edge_dim=3,
                  graph_dim=30,
                  num_node_types=2,
                  lamb=1)

    destructor = sgvae.encoder
    constructor = sgvae.decoder

    trainData = CycleDataset('cycles/train.cycles')
    trainLoader = utils.DataLoader(trainData,
                                   batch_size=1,
                                   shuffle=False,
                                   num_workers=0,
                                   collate_fn=trainData.collate_single)
    # g = trainData[0]
    # print(g.number_of_nodes())
    # print(g)
    # z, pi, __ = destructor(deepcopy(g))
    # print(pi)
    optimizer = optim.Adam(sgvae.parameters(), lr=0.01)
    for epoch in range(num_epochs):
        print("Epoch", epoch)
        t = tqdm(trainLoader)
        probs = []
        # g = trainData[0]
        # for i, g in enumerate(t):
        # z, pi, __ = destructor(deepcopy(g))
        optimizer.zero_grad()
        loss_sum = 0
        for i, g in enumerate(t):
            loss, genGraph, z, log_qzpi, log_px = sgvae.loss(g,
                                                             return_graph=True)
            loss_sum += loss
            t.set_description("{:.3f}".format(float(loss)))
            if i == 99:
                avg_prob = loss_sum / 100
                t.set_description("Avg: {:.3f}".format(float(avg_prob)))
        loss_sum /= len(trainLoader)
        loss_sum.backward()
        # g, prob = constructor(z, pi=pi, target=g)
        # (-prob).backward(retain_graph=False)
        optimizer.step()
        # print(prob)
        # if epoch % 100 == 0:
        # Decode the latent z from the last training graph and save a drawing of the result.
        new = constructor(z)[0]
        plt.clf()
        nx.draw(new.to_networkx())
        plt.savefig('outputs3/{}.png'.format(epoch))
        if epoch % 10 == 0 or epoch == (num_epochs - 1):
            print("Saving to {}.params".format(epoch))
            torch.save(sgvae.state_dict(), 'params3/{}.params'.format(epoch))
            # t.set_description("{:.3f}".format(float(prob)))
            # probs.append(float(prob))
            # if i == 99:
            #     avg_prob = sum(probs) / len(probs)
            #     t.set_description("Avg: {:.3f}".format(avg_prob))

        # print(avg_prob)

        # print(prob)

    exit()  # everything below this point is unreachable
    valData = CycleDataset('cycles/val.cycles')

    optimizer = optim.SGD(sgvae.parameters(), lr=0.01, momentum=0.9)

    # for g in trainLoader:
    #     print(g)
    #     break

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        print("Epoch", epoch)
        if epoch % 5 == 0 and epoch != 0:
            print("Saving to {}.params".format(epoch))
            torch.save(sgvae.state_dict(), 'params/{}.params'.format(epoch))
            eval(epoch, writeFile=True, z_value=z)
        loss_sum = 0
        for g in tqdm(trainLoader, desc="[{}]".format(epoch)):
            loss, genGraph, z = sgvae.loss(g, return_graph=True)
            loss_sum += loss
        loss_sum /= len(trainLoader)
        loss_sum.backward()
        optimizer.step()
        print(loss_sum)
    print("Saving to {}.params".format(epoch))
    torch.save(sgvae.state_dict(), 'params/{}.params'.format(epoch))
Example #6
File: cycles.py  Project: bjing2016/sgvae
def train(num_epochs=200):
    num_epochs = int(num_epochs)

    sgvae = SGVAE(rounds=3,
                  node_dim=3,
                  msg_dim=6,
                  edge_dim=3,
                  graph_dim=30,
                  num_node_types=2,
                  lamb=1)

    destructor = sgvae.encoder
    constructor = sgvae.decoder

    trainData = CycleDataset('cycles/train.cycles')
    g = trainData[0]

    #z, pi, __ = destructor(deepcopy(g))
    #pi = range(7)
    #print(pi)
    optimizer = optim.SGD(sgvae.parameters(), lr=0.01)
    t = trange(18000)  # iteration count is hard-coded; num_epochs is not used by this loop

    for i in t:
        optimizer.zero_grad()
        #z, pi, log_qzpi = destructor(deepcopy(g))
        #_, prob = constructor(z, pi=pi, target=g)

        loss, genGraph, z, log_qzpi, prob, unldr = sgvae.loss(
            g, return_graph=True)
        #f.write(str(pi))# (z.detach().numpy(), pi, float(log_qzpi), float(prob))))
        #f.write('\n')
        #f.flush()
        loss.backward(retain_graph=False)
        #(-log_qzpi-prob).backward(retain_graph=False)
        optimizer.step()
        # 'f' is presumed to be a log-file handle opened at module level (not shown in this snippet).
        s = '{:.4f} {:.4f}'.format(float(log_qzpi), float(prob))
        sprime = '{:.4f}'.format(float(unldr))
        f.write(s + ' ' + sprime + '\n')
        f.flush()
        t.set_description('{:.4f}'.format(float(unldr)))

        if i % 100 == 0:
            new = sgvae.generate()
            plt.clf()
            nx.draw(new.to_networkx())
            plt.savefig('outputs/{}.png'.format(i))

    exit()  # unreachable: the epoch loop below also references trainLoader, which is never defined in this function
    valData = CycleDataset('cycles/val.cycles')

    optimizer = optim.SGD(sgvae.parameters(), lr=0.01, momentum=0.9)

    # for g in trainLoader:
    #     print(g)
    #     break

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        print("Epoch", epoch)
        if epoch % 5 == 0 and epoch != 0:
            print("Saving to {}.params".format(epoch))
            torch.save(sgvae.state_dict(), 'params/{}.params'.format(epoch))
            eval(epoch, writeFile=True, z_value=z)
        loss_sum = 0
        for g in tqdm(trainLoader, desc="[{}]".format(epoch)):
            loss, genGraph, z = sgvae.loss(g, return_graph=True)
            loss_sum += loss
        loss_sum /= len(trainLoader)
        loss_sum.backward()
        optimizer.step()
        print(loss_sum)
    print("Saving to {}.params".format(epoch))
    torch.save(sgvae.state_dict(), 'params/{}.params'.format(epoch))