Example 1
0
class RandomGraph(InMemoryDataset):
    """Synthetic random-graph dataset for benchmarking.

    Builds a graph with ``n`` nodes and ``m = 10 * n`` uniformly random
    (possibly duplicated/self-loop) edges, constant all-ones node features
    of width 100, random labels over 10 classes, and all-true train/val/test
    masks. Exposes the usual PyG-style attributes (``x``, ``y``,
    ``edge_index``, ``edge_attr``, masks) plus a ``SparseTensor`` adjacency.

    Parameters
    ----------
    n : int
        Number of nodes; the edge count is fixed at ``10 * n``.
    """

    def __init__(self, n):
        n_edges = m = n * 10
        in_feats = 100
        n_classes = 10

        # Random COO connectivity; duplicate (row, col) pairs are possible.
        row = np.random.choice(n, m)
        col = np.random.choice(n, m)
        spm = spsp.coo_matrix((np.ones(len(row)), (row, col)), shape=(n, n))

        features = torch.ones((n, in_feats))
        labels = torch.LongTensor(np.random.choice(n_classes, n))
        # Every node participates in train/val/test (masks are all ones).
        train_mask = np.ones(shape=(n))
        val_mask = np.ones(shape=(n))
        test_mask = np.ones(shape=(n))
        if hasattr(torch, 'BoolTensor'):
            train_mask = torch.BoolTensor(train_mask)
            val_mask = torch.BoolTensor(val_mask)
            test_mask = torch.BoolTensor(test_mask)
        else:
            # Older torch releases have no BoolTensor; ByteTensor is the
            # legacy mask dtype.
            train_mask = torch.ByteTensor(train_mask)
            val_mask = torch.ByteTensor(val_mask)
            test_mask = torch.ByteTensor(test_mask)

        self.edge_attr = torch.FloatTensor(spm.data)
        indices = np.vstack((spm.row, spm.col))
        indices = torch.LongTensor(indices)
        self.edge_index = indices  # shape (2, m): [row; col]
        self.x = features
        self.y = labels
        self.train_mask = train_mask
        self.val_mask = val_mask
        self.test_mask = test_mask
        self.in_feats = in_feats
        self.n_classes = n_classes
        self.adj = SparseTensor(row=self.edge_index[0],
                                col=self.edge_index[1],
                                value=self.edge_attr)

    def to(self, device):
        """Move all tensors (and the sparse adjacency) to ``device``.

        Fix: the original body hard-coded ``.cuda()`` and silently ignored
        ``device``; it now honors the argument, so CPU or a specific GPU
        index ("cuda:1") works as callers would expect.
        """
        self.x = self.x.to(device)
        self.y = self.y.to(device)
        self.train_mask = self.train_mask.to(device)
        self.val_mask = self.val_mask.to(device)
        self.test_mask = self.test_mask.to(device)
        self.edge_index = self.edge_index.to(device)
        self.edge_attr = self.edge_attr.to(device)
        # NOTE(review): assumes SparseTensor supports .to(device) like
        # torch_sparse's implementation does — confirm against the import.
        self.adj = self.adj.to(device)
        return self
Example 2
0
def main(parser):
    """Pretrain the relation/branch pruners on top of a frozen LSE model.

    Pipeline: parse args → load latent-space executor checkpoint → build the
    pruners → construct the KG and its sparse adjacency → build train
    samplers and eval dataloaders → (optionally) train → evaluate on test.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Base parser; extended with pruner-specific arguments before parsing.
    """
    parser = parse_additional_args(parser)
    args = parser.parse_args(None)
    set_global_seed(args.seed)
    set_train_mode(args)
    set_save_path(args)
    gpus = [int(i) for i in args.gpus.split(".")]
    assert len(gpus) == 1, "pruner pretraining only supports single GPU"
    print("Logging to", args.save_path)
    writer = set_logger(args)

    latent_space_executor = get_lse_model(args)
    logging.info("Latent space executor created")
    load_lse_checkpoint(args, latent_space_executor)
    logging.info("Latent space executor loaded")

    relation_pruner = RelationPruner(latent_space_executor.entity_dim,
                                     args.nrelation)
    # Box embeddings carry (center, offset) so the pruner input is 2x wider.
    if args.geo == 'box':
        branch_pruner = BranchPruner(latent_space_executor.entity_dim * 2)
    elif args.geo == 'rotate':
        branch_pruner = BranchPruner(latent_space_executor.entity_dim)
    else:
        # Fix: previously an unknown geo left branch_pruner unbound and the
        # code crashed later with an opaque NameError; fail fast instead.
        raise ValueError(
            "unsupported args.geo %r; expected 'box' or 'rotate'" % args.geo)
    logging.info("Pruner created")

    kg = sampler_clib.create_kg(args.nentity, args.nrelation, args.kg_dtype)
    kg.load_triplets(os.path.join(args.data_path, "train_indexified.txt"),
                     True)
    logging.info("KG constructed")

    ent_in, ent_out, adj = construct_graph(args.data_path,
                                           ['train_indexified.txt'], True)
    # (num_edges, 2) -> (2, num_edges) so row/col can be indexed directly.
    adj = torch.LongTensor(adj).transpose(0, 1)
    ent_rel_mat = SparseTensor(row=adj[0], col=adj[1])
    logging.info("Adj matrix constructed")

    relation_query_structures = parse_structures(args.relation_tasks)
    branch_query_structures = parse_structures(args.branch_tasks)
    check_valid_branch_structures(branch_query_structures)
    # Uniform sampling weight over the query structures for both samplers.
    branch_sampler = BranchSampler(
        kg,
        branch_query_structures,
        1,  # placeholder does not matter
        eval_tuple(args.online_sample_mode),
        [
            1. / len(branch_query_structures)
            for _ in range(len(branch_query_structures))
        ],
        sampler_type='naive',
        same_in_batch=False,
        share_negative=True,
        num_threads=args.cpu_num)
    relation_sampler = RelationSampler(
        kg,
        relation_query_structures,
        1,  # placeholder does not matter
        eval_tuple(args.online_sample_mode),
        [
            1. / len(relation_query_structures)
            for _ in range(len(relation_query_structures))
        ],
        sampler_type='naive',
        same_in_batch=False,
        share_negative=True,
        num_threads=args.cpu_num)
    branch_iterator = branch_sampler.batch_generator(args.batch_size)
    relation_iterator = relation_sampler.batch_generator(args.batch_size)
    logging.info("Train samplers constructed")

    if args.cuda:
        ent_rel_mat = ent_rel_mat.cuda()
        latent_space_executor = latent_space_executor.cuda()
        relation_pruner = relation_pruner.cuda()
        branch_pruner = branch_pruner.cuda()

    eval_relation_data, eval_branch_data = get_eval_data(args)
    eval_relation_dataloader = DataLoader(
        EvalRelationDataset(
            eval_relation_data,
            args.nentity,
            args.nrelation,
        ),
        batch_size=args.test_batch_size,
        num_workers=args.cpu_num,
        collate_fn=EvalRelationDataset.collate_fn)
    eval_branch_dataloader = DataLoader(
        EvalBranchDataset(
            eval_branch_data,
            args.nentity,
            args.nrelation,
        ),
        batch_size=args.test_batch_size,
        num_workers=args.cpu_num,
        collate_fn=EvalBranchDataset.collate_fn)
    logging.info("Eval dataloader constructed")

    step = 0
    if args.do_train:
        step = pretrain_pruners(args, writer, ent_rel_mat,
                                latent_space_executor, relation_pruner,
                                branch_pruner, relation_iterator,
                                branch_iterator, eval_relation_dataloader,
                                eval_branch_dataloader)

    logging.info('Evaluating on Test Dataset...')
    test_all_metrics = eval_branch_pruner(latent_space_executor, branch_pruner,
                                          args, eval_branch_dataloader,
                                          query_name_dict)
    log_and_write_metrics('Test average', step, test_all_metrics, writer)
    test_all_metrics = eval_relation_pruner(latent_space_executor,
                                            relation_pruner, args,
                                            eval_relation_dataloader,
                                            query_name_dict)
    log_and_write_metrics('Test average', step, test_all_metrics, writer)

    print('Training finished!!')
    logging.info("Training finished!!")