class RandomGraph(InMemoryDataset):
    """Synthetic random-graph dataset for benchmarking/testing.

    Builds an Erdos-Renyi-like graph with ``n`` nodes and ``10 * n`` random
    (possibly duplicate/self-loop) edges, constant node features, random
    labels, and all-ones train/val/test masks.

    Attributes set in ``__init__``:
        x           -- (n, 100) all-ones float features
        y           -- (n,) random labels in [0, 10)
        edge_index  -- (2, m) LongTensor of COO edge endpoints
        edge_attr   -- (m,) FloatTensor of edge weights (all ones)
        adj         -- SparseTensor view of the same adjacency
        train_mask / val_mask / test_mask -- all-True masks over nodes
    """

    def __init__(self, n):
        m = n * 10  # number of random edges (duplicates/self-loops allowed)
        in_feats = 100
        n_classes = 10
        row = np.random.choice(n, m)
        col = np.random.choice(n, m)
        # COO adjacency with unit edge weights; duplicate (row, col) pairs
        # are kept as-is (coo_matrix does not coalesce until conversion).
        spm = spsp.coo_matrix((np.ones(len(row)), (row, col)), shape=(n, n))
        features = torch.ones((n, in_feats))
        labels = torch.LongTensor(np.random.choice(n_classes, n))
        train_mask = np.ones(shape=(n))
        val_mask = np.ones(shape=(n))
        test_mask = np.ones(shape=(n))
        # BoolTensor only exists on newer torch versions; fall back to
        # ByteTensor masks on older releases.
        if hasattr(torch, 'BoolTensor'):
            train_mask = torch.BoolTensor(train_mask)
            val_mask = torch.BoolTensor(val_mask)
            test_mask = torch.BoolTensor(test_mask)
        else:
            train_mask = torch.ByteTensor(train_mask)
            val_mask = torch.ByteTensor(val_mask)
            test_mask = torch.ByteTensor(test_mask)
        self.edge_attr = torch.FloatTensor(spm.data)
        indices = np.vstack((spm.row, spm.col))
        indices = torch.LongTensor(indices)
        self.edge_index = indices
        self.x = features
        self.y = labels
        self.train_mask = train_mask
        self.val_mask = val_mask
        self.test_mask = test_mask
        self.in_feats = in_feats
        self.n_classes = n_classes
        self.adj = SparseTensor(row=self.edge_index[0],
                                col=self.edge_index[1],
                                value=self.edge_attr)

    def to(self, device):
        """Move all tensors (and the SparseTensor adjacency) to ``device``.

        Fixed: the previous implementation ignored ``device`` and always
        called ``.cuda()``, which broke ``data.to('cpu')`` and any
        non-default CUDA device. Returns ``self`` for chaining.
        """
        self.x = self.x.to(device)
        self.y = self.y.to(device)
        self.train_mask = self.train_mask.to(device)
        self.val_mask = self.val_mask.to(device)
        self.test_mask = self.test_mask.to(device)
        self.edge_index = self.edge_index.to(device)
        self.edge_attr = self.edge_attr.to(device)
        # SparseTensor (torch_sparse) also supports .to(device) — TODO confirm
        # the installed version; older releases only exposed .cuda()/.cpu().
        self.adj = self.adj.to(device)
        return self
def main(parser):
    """Pretrain the relation/branch pruners on top of a frozen latent-space
    executor (single-GPU only).

    Steps: parse args and set up logging/seed/save-path; build and load the
    latent-space executor; construct the pruners, the KG and its adjacency;
    build train samplers and eval dataloaders; optionally train; finally
    evaluate both pruners on the test set.

    Args:
        parser: a pre-built argparse.ArgumentParser; extended with
            additional args before parsing.
    """
    parser = parse_additional_args(parser)
    args = parser.parse_args(None)
    set_global_seed(args.seed)
    set_train_mode(args)
    set_save_path(args)
    # NOTE(review): GPUs are split on "." — unusual (commas are more common);
    # presumably matches how --gpus is documented elsewhere. Verify.
    gpus = [int(i) for i in args.gpus.split(".")]
    assert len(gpus) == 1, "pruner pretraining only supports single GPU"
    print("Logging to", args.save_path)
    writer = set_logger(args)

    latent_space_executor = get_lse_model(args)
    logging.info("Latent space executor created")
    load_lse_checkpoint(args, latent_space_executor)
    logging.info("Latent space executor loaded")

    relation_pruner = RelationPruner(latent_space_executor.entity_dim,
                                     args.nrelation)
    # Box embeddings carry (center, offset) so the pruner input is 2x wider.
    if args.geo == 'box':
        branch_pruner = BranchPruner(latent_space_executor.entity_dim * 2)
    elif args.geo == 'rotate':
        branch_pruner = BranchPruner(latent_space_executor.entity_dim)
    else:
        # Fixed: previously an unsupported geo left branch_pruner unbound and
        # the function later crashed with a confusing UnboundLocalError.
        raise ValueError("unsupported geo for pruner pretraining: %s"
                         % args.geo)
    logging.info("Pruner created")

    kg = sampler_clib.create_kg(args.nentity, args.nrelation, args.kg_dtype)
    kg.load_triplets(os.path.join(args.data_path, "train_indexified.txt"),
                     True)
    logging.info("KG constructed")

    # Build the (entity, relation) adjacency as a SparseTensor.
    ent_in, ent_out, adj = construct_graph(args.data_path,
                                           ['train_indexified.txt'], True)
    adj = torch.LongTensor(adj).transpose(0, 1)
    ent_rel_mat = SparseTensor(row=adj[0], col=adj[1])
    logging.info("Adj matrix constructed")

    relation_query_structures = parse_structures(args.relation_tasks)
    branch_query_structures = parse_structures(args.branch_tasks)
    check_valid_branch_structures(branch_query_structures)

    # Uniform sampling weight over query structures for both samplers.
    branch_sampler = BranchSampler(
        kg,
        branch_query_structures,
        1,  # placeholder does not matter
        eval_tuple(args.online_sample_mode),
        [1. / len(branch_query_structures)
         for _ in range(len(branch_query_structures))],
        sampler_type='naive',
        same_in_batch=False,
        share_negative=True,
        num_threads=args.cpu_num)
    relation_sampler = RelationSampler(
        kg,
        relation_query_structures,
        1,  # placeholder does not matter
        eval_tuple(args.online_sample_mode),
        [1. / len(relation_query_structures)
         for _ in range(len(relation_query_structures))],
        sampler_type='naive',
        same_in_batch=False,
        share_negative=True,
        num_threads=args.cpu_num)
    branch_iterator = branch_sampler.batch_generator(args.batch_size)
    relation_iterator = relation_sampler.batch_generator(args.batch_size)
    logging.info("Train samplers constructed")

    if args.cuda:
        ent_rel_mat = ent_rel_mat.cuda()
        latent_space_executor = latent_space_executor.cuda()
        relation_pruner = relation_pruner.cuda()
        branch_pruner = branch_pruner.cuda()

    eval_relation_data, eval_branch_data = get_eval_data(args)
    eval_relation_dataloader = DataLoader(
        EvalRelationDataset(
            eval_relation_data,
            args.nentity,
            args.nrelation,
        ),
        batch_size=args.test_batch_size,
        num_workers=args.cpu_num,
        collate_fn=EvalRelationDataset.collate_fn)
    eval_branch_dataloader = DataLoader(
        EvalBranchDataset(
            eval_branch_data,
            args.nentity,
            args.nrelation,
        ),
        batch_size=args.test_batch_size,
        num_workers=args.cpu_num,
        collate_fn=EvalBranchDataset.collate_fn)
    logging.info("Eval dataloader constructed")

    step = 0
    if args.do_train:
        step = pretrain_pruners(args, writer, ent_rel_mat,
                                latent_space_executor, relation_pruner,
                                branch_pruner, relation_iterator,
                                branch_iterator, eval_relation_dataloader,
                                eval_branch_dataloader)

    logging.info('Evaluating on Test Dataset...')
    test_all_metrics = eval_branch_pruner(latent_space_executor,
                                          branch_pruner, args,
                                          eval_branch_dataloader,
                                          query_name_dict)
    log_and_write_metrics('Test average', step, test_all_metrics, writer)
    test_all_metrics = eval_relation_pruner(latent_space_executor,
                                            relation_pruner, args,
                                            eval_relation_dataloader,
                                            query_name_dict)
    log_and_write_metrics('Test average', step, test_all_metrics, writer)
    print('Training finished!!')
    logging.info("Training finished!!")