Example #1
 def load_index(self):
     if os.path.exists(os.path.join(self.data_directory, 'ent2id.txt')):
         self.entity_dict = load_index(
             os.path.join(self.data_directory, 'ent2id.txt'))
         print("Load preprocessed entity index")
     elif os.path.exists(os.path.join(self.data_directory, 'ent2ids')):
         self.entity_dict = unserialize(os.path.join(
             self.data_directory, "ent2ids"),
                                        form='json')
         print("Load raw entity index")
     else:
         print("Entity Index not exist")
         self.entity_dict = {}
     if os.path.exists(os.path.join(self.data_directory,
                                    'relation2id.txt')):
         self.relation_dict = load_index(
             os.path.join(self.data_directory, 'relation2id.txt'))
         print("Load preprocessed relation index")
     elif os.path.exists(os.path.join(self.data_directory, 'relation2ids')):
         self.relation_dict = unserialize(os.path.join(
             self.data_directory, "relation2ids"),
                                          form='json')
         print("Load raw relation index")
     else:
         print("Relation Index not exist")
         self.relation_dict = {}
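
These examples rely on helpers that are not shown. Below is a minimal sketch of `load_index` and `unserialize` as the call sites above suggest them; the tab-separated index format and the pickle default are assumptions, not the original implementation:

import json
import pickle


def load_index(path):
    # Assumption: one "name\tid" pair per line (a sketch, not the original).
    index = {}
    with open(path) as f:
        for line in f:
            name, idx = line.strip().split('\t')
            index[name] = int(idx)
    return index


def unserialize(path, form='pickle'):
    # Assumption: dispatch on the `form` flag seen in the call sites above.
    if form == 'json':
        with open(path) as f:
            return json.load(f)
    with open(path, 'rb') as f:
        return pickle.load(f)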
Example #2
 def load_data(self):
     self.data_directory = os.path.join(self.root_directory, "data")
     self.entity_dict = load_index(
         os.path.join(self.data_directory, "ent2id.txt"))
     self.relation_dict = load_index(
         os.path.join(self.data_directory, "relation2id.txt"))
     self.facts_data = translate_facts(
         load_facts(os.path.join(self.data_directory, "train.txt")),
         self.entity_dict, self.relation_dict)
     self.test_support = translate_facts(
         load_facts(os.path.join(self.data_directory, "test_support.txt")),
         self.entity_dict, self.relation_dict)
     self.valid_support = translate_facts(
         load_facts(os.path.join(self.data_directory, "valid_support.txt")),
         self.entity_dict, self.relation_dict)
     self.test_eval = translate_facts(
         load_facts(os.path.join(self.data_directory, "test_eval.txt")),
         self.entity_dict, self.relation_dict)
     self.valid_eval = translate_facts(
         load_facts(os.path.join(self.data_directory, "valid_eval.txt")),
         self.entity_dict, self.relation_dict)
     # augment
     with open(os.path.join(self.data_directory, 'pagerank.txt')) as file:
         self.pagerank = list(
             map(lambda x: float(x.strip()), file.readlines()))
     if os.path.exists(os.path.join(self.data_directory, "fact_dist")):
         self.fact_dist = unserialize(
             os.path.join(self.data_directory, "fact_dist"))
     else:
         self.fact_dist = None
     if os.path.exists(os.path.join(self.data_directory, "train_graphs")):
         self.train_graphs = unserialize(
             os.path.join(self.data_directory, "train_graphs"))
     else:
         self.train_graphs = None
     # evaluate_graphs is required: the assert fails fast if it is missing,
     # which makes a fallback branch unreachable.
     assert os.path.exists(
         os.path.join(self.data_directory, "evaluate_graphs"))
     print("Use evaluate graphs")
     self.evaluate_graphs = unserialize(
         os.path.join(self.data_directory, "evaluate_graphs"))
     if os.path.exists(os.path.join(self.data_directory, "rel2candidates")):
         self.rel2candidate = unserialize(
             os.path.join(self.data_directory, "rel2candidates"))
     else:
         self.rel2candidate = {}
     # self.rel2candidate = {self.relation_dict[key]: value for key, value in self.rel2candidate.items() if
     #                       key in self.relation_dict}
     self.id2entity = sorted(self.entity_dict.keys(),
                             key=self.entity_dict.get)
     self.id2relation = sorted(self.relation_dict.keys(),
                               key=self.relation_dict.get)
     self.data_loaded = True
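
Example 2 also assumes `load_facts` and `translate_facts`. A plausible sketch follows, assuming tab-separated (head, relation, tail) triples per line; the file format is an assumption borrowed from common KG-completion datasets:

def load_facts(path):
    # Assumption: one "head\trelation\ttail" triple per line.
    with open(path) as f:
        return [tuple(line.strip().split('\t')) for line in f]


def translate_facts(facts, entity_dict, relation_dict):
    # Map string names to the integer ids built by load_index.
    return [(entity_dict[h], relation_dict[r], entity_dict[t])
            for h, r, t in facts]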
Example #3
def get_test_data():
    test_set = unpickle("cifar10/test_batch")
    test_data = test_set[b"data"]
    test_images = unserialize(test_data)
    labels = test_set[b"labels"]
    images = prepare_pixels(test_images)
    labels = prepare_labels(labels)
    return images, labels
Example #4
def get_train_data(batch_no):
    train_set = unpickle("cifar10/data_batch_" + str(batch_no))
    train_data = train_set[b"data"]
    train_images = unserialize(train_data)
    labels = train_set[b"labels"]
    images = prepare_pixels(train_images)
    labels = prepare_labels(labels)
    return images, labels
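
Examples 3 and 4 read CIFAR-10 batches. `unpickle` below is the loader documented on the CIFAR-10 dataset page; note that here `unserialize` is applied to the raw pixel array, so it is a different helper from the path-based one sketched earlier. `prepare_pixels` and `prepare_labels` are sketched under the usual assumptions (scale to [0, 1], reshape to NCHW, one-hot labels):

import pickle

import numpy as np


def unpickle(file):
    # Loader from the CIFAR-10 dataset page.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')


def prepare_pixels(data):
    # Assumption: scale uint8 pixels to [0, 1] and reshape to (N, 3, 32, 32).
    return np.asarray(data).reshape(-1, 3, 32, 32).astype(np.float32) / 255.0


def prepare_labels(labels, num_classes=10):
    # Assumption: one-hot encode the integer labels.
    return np.eye(num_classes, dtype=np.float32)[np.asarray(labels)]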
Example #5
 def load_model(self, path, batch_id):
     """
     remain for compatible
     """
     config_path = os.path.join(path, 'config.json')
     if os.path.exists(config_path):
         self.config = unserialize(config_path)
         self.cogKR = CogKR(graph=self.kg,
                            entity_dict=self.entity_dict,
                            relation_dict=self.relation_dict,
                            device=self.device,
                            **self.config['model']).to(self.device)
         model_state = torch.load(
             os.path.join(path,
                          str(batch_id) + ".model.dict"))
         self.cogKR.load_state_dict(model_state)
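
For context, here is a hedged sketch of the save side that would produce the files Example 5 loads (`config.json` plus `<batch_id>.model.dict`); the function name `save_model` is hypothetical and the file naming is inferred from the load path:

import json
import os

import torch


def save_model(model, config, path, batch_id):
    # Hypothetical mirror of load_model: write the config and the state dict.
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, 'config.json'), 'w') as f:
        json.dump(config, f)
    torch.save(model.state_dict(),
               os.path.join(path, str(batch_id) + ".model.dict"))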
Example #6
 def load_raw_data(self):
     self.train_facts = load_facts(
         os.path.join(self.data_directory, "train.txt"))
     self.test_facts = load_facts(
         os.path.join(
             self.data_directory, "test_support.txt")) + load_facts(
                 os.path.join(self.data_directory, "test_eval.txt"))
     self.valid_facts = load_facts(
         os.path.join(
             self.data_directory, "valid_support.txt")) + load_facts(
                 os.path.join(self.data_directory, "valid_eval.txt"))
     if os.path.exists(
             os.path.join(self.data_directory, "rel2candidates.json")):
         print("Load rel2candidates")
         self.rel2candidate = unserialize(
             os.path.join(self.data_directory, "rel2candidates.json"))
     else:
         self.rel2candidate = {}
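
`rel2candidates.json` presumably maps each relation to the entities considered as candidate answers during ranking; an illustrative, entirely hypothetical shape:

# Hypothetical contents of rel2candidates.json:
rel2candidates = {
    "place_of_birth": ["london", "paris", "berlin"],
    "works_for": ["acme_corp", "globex"],
}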
Example #7
         preprocess.save_data(save_train=args.save_train)
         preprocess.compute_pagerank()
     if args.search_evaluate_graph:
         print("Search Evaluate Graph")
         preprocess.search_evaluate_graph(wiki=args.wiki)
 else:
     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
     device = torch.device('cuda:0')
     main_body = Main(args,
                      root_directory=args.directory,
                      device=device,
                      comment=args.comment,
                      relation_encode=args.relation_encode,
                      tqdm_wrapper=tqdm)
     if args.config:
         main_body.config = unserialize(args.config)
     main_body.sparse_embed = args.sparse_embed
     main_body.load_data()
     main_body.build_env(main_body.config['graph'])
     if args.save_minerva:
         data_dir = os.path.join(args.directory, "minerva")
         if not os.path.exists(data_dir):
             os.makedirs(data_dir)
         main_body.save_to_hyper(data_dir)
     elif args.get_fact_dist:
         fact_dist = main_body.get_fact_dist(
             main_body.config['trainer']['ignore_relation'])
         serialize(fact_dist,
                   os.path.join(main_body.data_directory, "fact_dist"))
     elif args.pretrain:
         main_body.build_pretrain_model(main_body.config['model'])
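
Example 7 also calls `serialize`, the write-side counterpart. A sketch consistent with the `unserialize` sketch above (JSON for `.json` paths, pickle otherwise; the dispatch rule is an assumption):

import json
import pickle


def serialize(obj, path):
    # Assumption: counterpart of unserialize, json for .json files,
    # pickle for everything else.
    if path.endswith('.json'):
        with open(path, 'w') as f:
            json.dump(obj, f)
    else:
        with open(path, 'wb') as f:
            pickle.dump(obj, f)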
Example #8
            print("Hit@10: {:.4f}, NDCG@10: {:.4f}".format(hr[9], ndcg[9]),
                  file=f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--eval_samples', type=str, default='')
    parser.add_argument('--save_path', type=str, default='')
    parser.add_argument('--load_path', type=str, default='')
    parser.add_argument('--results_path', type=str, default='')
    args = parser.parse_args()

    config = unserialize(args.config)

    filename_raw = os.path.join(args.dataset, "totalCheckins.txt")
    filename_clean = os.path.join(args.dataset, "QuadKeyLSBNDataset.data")
    user2mc_filename = os.path.join(args.dataset, "reg_trans_pmc_model.pkl")
    loc_query_tree_path = os.path.join(args.dataset, "loc_query_tree.pkl")
    knn_wrmf_sample_prob_path = os.path.join(args.dataset,
                                             "knn_wrmf_sample_prob.pkl")

    reset_random_seed(42)

    if not os.path.isfile(filename_clean):
        dataset = QuadKeyLBSNDataset(filename_raw)
        serialize(dataset, filename_clean)
    else:
        dataset = unserialize(filename_clean)
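
Examples 8 and 9 share the same cache-or-build pattern around `serialize`/`unserialize`; factored out, it might look like this (the helper name `load_or_build` is hypothetical):

import os


def load_or_build(path, build):
    # Hypothetical helper: reuse the cached file if present,
    # otherwise build the object and cache it.
    if os.path.isfile(path):
        return unserialize(path)
    obj = build()
    serialize(obj, path)
    return obj

# Usage, mirroring Examples 8 and 9:
# dataset = load_or_build(filename_clean, lambda: QuadKeyLBSNDataset(filename_raw))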
Example #9
            batch_first=True,
            unk_token=None,
            preprocessing=str.split
        )
        self.QUADKEY.build_vocab(all_quadkeys)

        return user_seq_array, user2idx, region2idx, n_users, n_region, regidx2loc, 169

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True)
    args = parser.parse_args()

    filename_raw = os.path.join(args.dataset, "totalCheckins.txt")
    filename_clean = os.path.join(args.dataset, "QuadKeyLSBNDataset.data")

    if not os.path.isfile(filename_clean):
        dataset = QuadKeyLBSNDataset(filename_raw)
        serialize(dataset, filename_clean)
    else:
        dataset = unserialize(filename_clean)
    
    count = 0
    length = []
    for seq in dataset.user_seq:
        count += len(seq)
        length.append(len(seq))
    print("#check-ins:", count)
    print("#users:", dataset.n_user - 1)
    print("#locations:", dataset.n_loc - 1)
    print("#median seq len:", np.median(np.array(length)))
Example #10
                    filter_statedict(gradient_model),
                    os.path.join(data_directory,
                                 str(batch_id + 1) + ".dict"))
            scheduler.step(test_values[measure_keys.index('Recall_10')])
            gradient_model.train()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str)
    parser.add_argument('--root_directory', type=str)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--comment', type=str, default='init')
    args = parser.parse_args()

    config = unserialize(args.config)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = torch.device('cuda:0')
    root_directory = args.root_directory
    # load data
    user_embedding = torch.from_numpy(
        unserialize(
            os.path.join(root_directory,
                         "embeddings/user_embeddings.npy")).astype(np.float32))
    item_embedding = torch.from_numpy(
        unserialize(
            os.path.join(root_directory,
                         "embeddings/item_embeddings.npy")).astype(np.float32))
    train_data = unserialize(
        os.path.join(root_directory, "train_data/train_data"))
    user_dict = unserialize(