def _train_test_split(self):
    def split(train_rate):
        edges = list(self.graph.edges())
        train_size = int(len(edges) * train_rate)
        random.shuffle(edges)
        train_edges = edges[:train_size]
        test_edges = edges[train_size:]
        return source_targets(train_edges), source_targets(test_edges)

    args = self._args
    self._test_nodes = []
    if self._hold_out:
        # Hold out (1 - tr_rate) of the edges as a test set.
        splits = split(args.tr_rate)
        self._train_sources, self._train_targets = splits[0]
        if args.output_dir != '':
            test_sources, test_targets = splits[1]
            self._test_nodes = set(test_sources) | set(test_targets)
            path = os.path.join(args.output_dir, f'test_graph_{int(args.tr_rate * 100)}.txt')
            gap_helper.log(f'Persisting {len(test_sources)} test edges to {path}')
            nx.write_edgelist(self._creator(list(zip(test_sources, test_targets))), path=path, data=False)
    else:
        # No hold-out: train on every edge of the graph.
        gap_helper.log('No test data is persisted')
        self._train_sources, self._train_targets = source_targets(self.graph.edges())
    self._train_nodes = set(self._train_sources) | set(self._train_targets)
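# A minimal sketch, assuming source_targets (defined elsewhere in the project)
# simply unzips an edge list into parallel source/target arrays; the function
# name and body below are illustrative, not the original helper.
def _sketch_source_targets(edges):
    import numpy as np
    edges = list(edges)
    if not edges:
        return np.array([]), np.array([])
    sources, targets = zip(*edges)
    return np.array(sources), np.array(targets)  # e.g. [(0, 1), (1, 2)] -> [0, 1], [1, 2]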
def relable_nodes(graph):
    gap_helper.log('Node relabeling ...')
    nodes = sorted(graph.nodes())
    node_ids = range(len(nodes))
    node_id_map = dict(zip(nodes, node_ids))
    id_node_map = dict(zip(node_ids, nodes))
    return nx.relabel_nodes(graph, node_id_map), id_node_map
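# A small illustrative example (hypothetical demo function) of what the
# relabeling above produces: arbitrary node labels are mapped to contiguous
# ids 0..n-1, and the reverse map is kept for writing results back out.
def _sketch_relable_nodes():
    import networkx as nx
    g = nx.Graph([(10, 42), (42, 7)])
    nodes = sorted(g.nodes())                           # [7, 10, 42]
    node_id_map = dict(zip(nodes, range(len(nodes))))   # {7: 0, 10: 1, 42: 2}
    id_node_map = dict(zip(range(len(nodes)), nodes))   # {0: 7, 1: 10, 2: 42}
    return nx.relabel_nodes(g, node_id_map), id_node_map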
def _create_train_dev_indices(self):
    args = self._args
    self._dev_indices = []
    self._train_indices = np.arange(self._train_sources.shape[0])
    if self._use_dev:
        dev_size = int(len(self._train_sources) * args.dev_rate)
        gap_helper.log(f'Number of dev points: {dev_size}')
        self._dev_indices = np.arange(dev_size)
        self._train_indices = np.arange(dev_size, self._train_sources.shape[0])
def _build_batches(self, idx):
    gap_helper.log('Building in memory batches')
    batches = []
    sources, targets, negatives = self._train_sources[idx], self._train_targets[idx], self._train_negatives[idx]
    size = idx.shape[0]
    for i in range(0, size, self._batch_size):
        batch = self._fetch_current_batch(start=i, size=size, sources=sources, targets=targets, negatives=negatives)
        batches.append(batch)
    return batches
def __init__(self, args):
    self._args = args
    self.data = Data(args)
    self.loss_fun = gap_model.RankingLoss
    self.model = None
    self.context_embedding = {}
    self.global_embedding = {}
    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    gap_helper.log(f'Running GAP on {self.device}')
def train(self):
    args = self._args
    self.model = gap_model.GAP(num_nodes=self.data.num_nodes, emb_dim=args.dim)
    self.model.to(self.device)
    optimizer = torch.optim.Adam(self.model.parameters(), lr=args.learning_rate)
    if isinstance(self.data.train_inputs, list):
        # In-memory batches
        train_inputs = self.data.train_inputs
        dev_inputs = self.data.dev_inputs
    else:
        # The batches come from a one-shot generator, so create one independent
        # copy of the training and dev batch iterators per epoch. Useful when
        # the training input is large (> 100000 edges).
        train_inputs = tee(self.data.train_inputs, args.epochs)
        dev_inputs = tee(self.data.dev_inputs, args.epochs)
    for epoch in range(args.epochs):
        train_batches = train_inputs if isinstance(train_inputs, list) else train_inputs[epoch]
        for batch in train_batches:
            self._infer(batch)
            criterion = self.loss_fun(self.model)
            optimizer.zero_grad()
            criterion.loss.backward()
            optimizer.step()
        if args.dev_rate > 0:
            val_loss, val_auc = self._validate(dev_inputs if isinstance(dev_inputs, list) else dev_inputs[epoch])
            gap_helper.log(
                'Epoch: {}/{} training loss: {:.5f} validation loss: {:.5f} validation AUC: {:.5f}'.format(
                    epoch + 1, args.epochs, criterion.loss.data, val_loss, val_auc))
        else:
            gap_helper.log('Epoch: {}/{} training loss: {:.5f}'.format(epoch + 1, args.epochs, criterion.loss.data))
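# A minimal sketch (hypothetical demo, not original code) of why train() uses
# itertools.tee when the batches come from a generator: a generator can be
# consumed only once, so tee provides one independent copy per epoch.
def _sketch_tee_per_epoch(epochs=3):
    from itertools import tee

    def batch_stream():
        for start in (0, 2, 4):   # pretend batch start offsets
            yield start

    copies = tee(batch_stream(), epochs)
    return [list(copies[epoch]) for epoch in range(epochs)]  # [[0, 2, 4], [0, 2, 4], [0, 2, 4]]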
def save_embeddings(self):
    args = self._args
    if args.output_dir != '':
        suffix = '' if args.tr_rate == 1 else f'_{int(args.tr_rate * 100)}'
        path = os.path.join(args.output_dir, f'gap_context{suffix}.emb')
        gap_helper.log(f'Saving context embedding to {path}')
        with open(path, 'w') as f:
            for node in self.context_embedding:
                for emb in self.context_embedding[node]:
                    output = '{} {}\n'.format(node, ' '.join(str(val) for val in emb))
                    f.write(output)
        path = os.path.join(args.output_dir, f'gap_global{suffix}.emb')
        gap_helper.log(f'Saving aggregated global embedding to {path}')
        with open(path, 'w') as f:
            for node in self.global_embedding:
                output = '{} {}\n'.format(node, ' '.join(str(val) for val in self.global_embedding[node]))
                f.write(output)
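# Hedged reader sketch (illustrative, not part of the original code): the
# files written above use one line per vector, "node v1 v2 ... vd", and a
# node may have several context vectors, so a reader would accumulate a list
# of vectors per node. The function name is hypothetical.
def _sketch_read_emb(path):
    embeddings = {}
    with open(path) as f:
        for line in f:
            parts = line.strip().split()
            node = int(parts[0])
            embeddings.setdefault(node, []).append([float(val) for val in parts[1:]])
    return embeddings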
def _negative_sample(self):
    def get_negative_node_to(u, v):
        # Draw from the unigram table until we get a node different from both endpoints.
        while True:
            node = self.node_dist_table[random.randint(0, len(self.node_dist_table) - 1)]
            if node != u and node != v:
                return node

    gap_helper.log('Sampling negative nodes')
    degree = {node: int(1 + self.graph.degree(node) ** 0.75) for node in self.graph.nodes()}
    # node_dist_table is the equivalent of the unigram distribution table in the word2vec implementation
    self.node_dist_table = [node for node, new_degree in degree.items() for _ in range(new_degree)]
    sources, targets = self._train_sources, self._train_targets
    src, trg, neg = [], [], []
    for i in range(len(sources)):
        neg_node = get_negative_node_to(sources[i], targets[i])
        src.append(sources[i])
        trg.append(targets[i])
        neg.append(neg_node)
    self._train_sources, self._train_targets, self._train_negatives = np.array(src), np.array(trg), np.array(neg)
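# A toy example (hypothetical demo function) of how the degree ** 0.75 table
# built above biases negative sampling toward high-degree nodes, in the spirit
# of word2vec's unigram table.
def _sketch_unigram_table():
    import random
    degrees = {0: 1, 1: 2, 2: 16}                                   # toy node degrees
    weights = {n: int(1 + d ** 0.75) for n, d in degrees.items()}   # {0: 2, 1: 2, 2: 9}
    table = [n for n, w in weights.items() for _ in range(w)]
    # Node 2 fills 9 of the 13 slots, so it is drawn as a negative far more often.
    return table[random.randint(0, len(table) - 1)]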
def _read_graph(self):
    args = self._args
    self._reader = nx.read_adjlist if args.fmt == 'adjlist' else nx.read_edgelist
    self._creator = nx.DiGraph if args.directed else nx.Graph
    gap_helper.log(f'Reading graph from {args.input}')
    self.graph = self._reader(path=args.input, create_using=self._creator, nodetype=int)
    self.graph, self.id_to_node = relable_nodes(self.graph)
    self.num_nodes = self.graph.number_of_nodes()
    gap_helper.log(f'Number of nodes {self.num_nodes}')
    gap_helper.log(f'Number of edges {self.graph.number_of_edges()}')
def _build_batch_iterator(self, idx):
    gap_helper.log('Building batch iterator')
    sources, targets, negatives = self._train_sources[idx], self._train_targets[idx], self._train_negatives[idx]
    size = idx.shape[0]
    for i in range(0, size, self._batch_size):
        yield self._fetch_current_batch(start=i, size=size, sources=sources, targets=targets, negatives=negatives)
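# Hedged illustration (toy code, not from the original module): _build_batches
# and _build_batch_iterator share the same chunking pattern; the first
# materializes every batch up front, while the second yields them lazily,
# which is presumably the iterator that train() replays with itertools.tee.
def _sketch_chunks_eager(items, batch_size):
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

def _sketch_chunks_lazy(items, batch_size):
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]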