def sample_blocks(self, g, seed_nodes, exclude_eids=None):
    """Sample a list of message-flow-graph (MFG) blocks for the given seeds.

    First runs the random-walk sampler to obtain training tuples, labels and
    subsampling weights (which also expands ``seed_nodes`` to every node used
    in this round of sampling), then builds one block per layer from the
    innermost layer outward, optionally excluding the given edge IDs.

    Parameters
    ----------
    g : DGLGraph
        The graph to sample from.
    seed_nodes : tensor or dict
        Initial seed node IDs; replaced by the full set of sampled nodes
        after the random walk.
    exclude_eids : tensor or dict, optional
        Edge IDs to remove from every sampled frontier (e.g. positive edges
        during link-prediction training).

    Returns
    -------
    (tuples, labels, subsampling_ws, blocks)
        ``tuples`` is a ``(num_pairs, 2)`` long tensor of (src, dst) pairs
        re-indexed into block-local node positions; ``blocks`` is the list of
        per-layer MFGs, outermost first.
    """
    blocks = []
    exclude_eids = (
        _tensor_or_dict_to_numpy(exclude_eids) if exclude_eids is not None else None)

    # seed_nodes starts as the initial seeds; after the random walk it holds
    # every node touched by this round of sampling.
    tuples, labels, subsampling_ws, seed_nodes = self.random_walk_sampler.sampler(g, seed_nodes)

    for block_id in reversed(range(self.num_layers)):
        frontier = self.sample_frontier(block_id, g, seed_nodes)

        # Removing edges from the frontier for link prediction training falls
        # into the category of frontier postprocessing.
        if exclude_eids is not None:
            parent_eids = frontier.edata[EID]
            parent_eids_np = _tensor_or_dict_to_numpy(parent_eids)
            located_eids = _locate_eids_to_exclude(parent_eids_np, exclude_eids)
            if not isinstance(located_eids, Mapping):
                # (BarclayII) If frontier already has a EID field and located_eids is empty,
                # the returned graph will keep EID intact. Otherwise, EID will change
                # to the mapping from the new graph to the old frontier.
                # So we need to test if located_eids is empty, and do the remapping ourselves.
                if len(located_eids) > 0:
                    frontier = transform.remove_edges(frontier, located_eids)
                    frontier.edata[EID] = F.gather_row(parent_eids, frontier.edata[EID])
            else:
                # (BarclayII) remove_edges only accepts removing one type of edges,
                # so I need to keep track of the edge IDs left one by one.
                new_eids = parent_eids.copy()
                for k, v in located_eids.items():
                    if len(v) > 0:
                        frontier = transform.remove_edges(frontier, v, etype=k)
                        new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID])
                frontier.edata[EID] = new_eids

        block = transform.to_block(frontier, seed_nodes)
        if self.return_eids:
            assign_block_eids(block, frontier)

        # The source nodes of this block become the seeds of the next
        # (outer) layer.
        seed_nodes = {ntype: block.srcnodes[ntype].data[NID] for ntype in block.srctypes}

        # Pre-generate CSR format so that it can be used in training directly.
        block.create_formats_()
        blocks.insert(0, block)

    # Remap the (src, dst) node IDs in `tuples` from graph-level IDs to the
    # row positions of the outermost block's source nodes.
    # NOTE(review): assumes the graph is homogeneous with node type '_N' —
    # confirm against the sampler's usage.
    nid = blocks[-1].ndata[NID]['_N']
    id_to_index = {node_id.item(): idx for idx, node_id in enumerate(nid)}
    tuples = torch.as_tensor(
        [[id_to_index[row[0].item()], id_to_index[row[1].item()]] for row in tuples],
        dtype=torch.long)
    return tuples, labels, subsampling_ws, blocks
def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
    """Extend ``eids`` with their reverse edges from ``reverse_eid_map``.

    For a plain tensor of edge IDs, returns the concatenation of the IDs and
    their mapped reverse IDs.  For a mapping of edge type -> IDs, first
    canonicalizes each edge type via ``g.to_canonical_etype`` and returns a
    dict with the same per-type concatenation.
    """
    if not isinstance(eids, Mapping):
        return F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)

    canonical = {g.to_canonical_etype(etype): ids for etype, ids in eids.items()}
    return {
        etype: F.cat([ids, F.gather_row(reverse_eid_map[etype], ids)], 0)
        for etype, ids in canonical.items()
    }
def _pull_handler(self, name, ID):
    """Default handler for PULL operation.

    Performs a ``gather_row()`` on the locally stored tensor registered
    under ``name``, selecting the rows given by ``ID``.

    Parameters
    ----------
    name : str
        Data name identifying the tensor in the local data store.
    ID : tensor (mx.ndarray or torch.tensor)
        A vector of IDs already re-mapped to local row indices.

    Return
    ------
    tensor
        A tensor with one row per entry of ``ID``.
    """
    return F.gather_row(self._data_store[name], ID)