예제 #1
0
def train(**kwargs):
    """Train the configured model on SineData, evaluating after every epoch.

    ``kwargs`` are forwarded to ``config.parse`` to override the module-level
    configuration. The SineData object built from ``config.filename`` and
    ``config.split_ratio`` is cached as a pickle next to the input file and
    reused on later runs. After each epoch the summed training loss is
    printed and ``task.link_sign_prediction_split`` is run on the embeddings
    returned by ``model.get_embedding()``.
    """
    config.parse(kwargs)

    # Cache path for the preprocessed dataset, computed once.
    cache_path = (config.filename + '_' + str(config.split_ratio) +
                  'SineData.pkl')
    if os.path.exists(cache_path):
        # NOTE(review): unpickling can execute arbitrary code -- only point
        # this at cache files produced by this script.
        with open(cache_path, 'rb') as f:
            train_data = pickle.load(f)
        print('exists SineData.pkl, load it!')
    else:
        train_data = SineData(config.filename, split_ratio=config.split_ratio)
        # Binary mode: pickle output is bytes, and text mode may translate
        # newlines on some platforms. `with` guarantees the file is flushed
        # and closed before training starts.
        with open(cache_path, 'wb') as f:
            pickle.dump(train_data, f)

    # One more than the graph's node count; presumably reserves an extra
    # index for the model -- TODO confirm against the model definition.
    config.N = train_data.G.g.number_of_nodes() + 1
    model = getattr(models, config.model)(config)
    if torch.cuda.is_available():
        model.cuda()
        config.CUDA = True

    train_dataloader = DataLoader(train_data,
                                  config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    optimizer = torch.optim.Adadelta(model.parameters(),
                                     lr=1.0,
                                     rho=0.95,
                                     weight_decay=config.weight_decay)
    task = Task(train_data.G)

    for epoch in range(config.epochs):
        total_loss = 0.0
        for data in train_dataloader:
            # The batch is handed to the model unchanged; moving inputs to
            # the GPU is presumably done inside the model -- confirm.
            optimizer.zero_grad()
            loss = model(data)
            loss.backward()
            optimizer.step()
            # `.cpu()` is a no-op for tensors already on the CPU, so this one
            # expression covers both the CUDA and the CPU case.
            total_loss += loss.cpu().data.numpy()
        print('epoch {0}, loss: {1}'.format(epoch, total_loss))
        task.link_sign_prediction_split(model.get_embedding())
예제 #2
0
def test(**kwargs):
    """Evaluate a trained model snapshot on link sign prediction.

    ``kwargs['snap_root']`` names a snapshot directory expected to contain
    ``config.pkl`` (pickled config), ``data.pkl`` (pickled dataset object)
    and ``<config.model>.model`` (the model's saved state dict). The model is
    rebuilt from the config, its weights are restored, and
    ``Task.link_sign_prediction_split`` is run on the output of
    ``utils.cat_neighbor`` applied to the model embeddings.

    Raises:
        KeyError: if ``snap_root`` is not passed.
        Exception: if the snapshot's ``data.pkl`` does not exist.
    """
    snap_root = kwargs['snap_root']
    # NOTE(review): pickle.load / torch.load can execute arbitrary code;
    # only load snapshots from a trusted source.
    with open(os.path.join(snap_root, 'config.pkl'), 'rb') as f:
        config = pickle.load(f)
    model_file = os.path.join(snap_root, '{}.model'.format(config.model))

    # Guard the file that is actually loaded below, so a missing snapshot
    # dataset yields the explicit error rather than an IOError.
    data_file = os.path.join(snap_root, 'data.pkl')
    if not os.path.exists(data_file):
        raise Exception('Data Module not exists!')
    with open(data_file, 'rb') as f:
        train_data = pickle.load(f)
    print('exists {}, load it!'.format(data_file))
    # Format explicitly: under Python 2, print(a, b) would print a tuple.
    print('nodes: {0}, edges: {1}'.format(train_data.G.g.number_of_nodes(),
                                         train_data.G.g.number_of_edges()))

    model = getattr(models, config.model)(config)
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()
    # config.pkl may have been saved on a GPU machine with CUDA=True; make
    # the flag reflect the current machine.
    config.CUDA = use_cuda
    # Without CUDA, remap any GPU-saved tensors to CPU storage so the state
    # dict can still be loaded.
    map_location = None if use_cuda else (lambda storage, loc: storage)
    model.load_state_dict(torch.load(model_file, map_location=map_location))

    task = Task(train_data.G, config)
    task.link_sign_prediction_split(utils.cat_neighbor(train_data.G.g,
                                                       model.get_embedding(),
                                                       method='null'),
                                    method='concatenate')