Exemple #1
0
                         batch_size=200,
                         num_workers=0,
                         drop_last=True)

    G1 = nx.path_graph(6).to_directed()
    G_target = nx.path_graph(6).to_directed()
    nx.draw(G1)
    plt.show()
    node_feat_size = 6
    edge_feat_size = 3
    graph_feat_size = 10
    gn = FFGN(graph_feat_size, node_feat_size, edge_feat_size).cuda()
    if opt.model != '':
        gn.load_state_dict(torch.load(opt.model))

    optimizer = optim.Adam(gn.parameters(), lr=1e-4)
    schedular = optim.lr_scheduler.StepLR(optimizer, 5e4, gamma=0.975)
    savedir = os.path.join('./logs', 'runs',
                           datetime.now().strftime('%B%d_%H:%M:%S'))
    writer = SummaryWriter(savedir)
    step = 0

    normalizers = torch.load('normalize.pth')
    in_normalizer = normalizers['in_normalizer']
    out_normalizer = normalizers['out_normalizer']
    std = in_normalizer.get_std()
    for epoch in range(300):
        for i, data in tqdm(enumerate(dl), total=len(dset) / 200 + 1):
            optimizer.zero_grad()
            action, delta_state, last_state = data
            action, delta_state, last_state = action.float(),\
from tqdm import tqdm
from dataset import SwimmerDataset
from utils import *

if __name__ == "__main__":
    # Script entry point: loads the swimmer dataset, builds the FFGN model,
    # its optimizer and a SummaryWriter, creates two Normalizer objects and
    # then iterates once over the data in batches of 200.
    dset = SwimmerDataset('swimmer.npy')
    use_cuda = True
    # drop_last=True discards a trailing batch smaller than batch_size.
    dl = DataLoader(dset, batch_size=200, num_workers=0, drop_last=True)

    # Directed path graph on 6 nodes. It is not consumed in this part of the
    # script -- NOTE(review): presumably the graph template passed to the
    # network further down; confirm against the rest of the file.
    G1 = nx.path_graph(6).to_directed()
    # Debug aid (disabled): nx.draw(G1); plt.show()

    node_feat_size = 6
    edge_feat_size = 3
    graph_feat_size = 10
    gn = FFGN(graph_feat_size, node_feat_size, edge_feat_size)
    # Move the model only when requested, so the model and the batches below
    # always live on the same device. This happens before the optimizer is
    # created, matching the original ordering.
    if use_cuda:
        gn = gn.cuda()
    optimizer = optim.Adam(gn.parameters(), lr=1e-3)

    # One log directory per run, named after the start time.
    savedir = os.path.join('./logs', 'runs',
                           datetime.now().strftime('%B%d_%H:%M:%S'))
    writer = SummaryWriter(savedir)
    step = 0

    # NOTE(review): the normalizers are only constructed here; how they are
    # fitted/used is not visible in this part of the script.
    in_normalizer = Normalizer()
    out_normalizer = Normalizer()

    for epoch in range(1):
        # Wrap the loader itself (not the enumerate object) so tqdm can read
        # the number of batches from it and display a total / ETA.
        for i, data in enumerate(tqdm(dl)):
            action, delta_state, last_state = data
            action = action.float()
            delta_state = delta_state.float()
            last_state = last_state.float()
            if use_cuda:
                action = action.cuda()
                delta_state = delta_state.cuda()
                last_state = last_state.cuda()