def get(self, ids):
  """ Get batched negative samples.

  Args:
    ids (numpy.array): Source ids whose negative dst nodes will be sampled.
      The array is flattened, so any shape is accepted; the docstring of the
      original API describes it as 1d.

  Return:
    A `Nodes` object (from `self._graph.get_nodes`) of `self._dst_type`
    with shape=[ids.size, `expand_factor`].

  Raises:
    ValueError: if `ids` is not a numpy array.
  """
  if not isinstance(ids, np.ndarray):
    raise ValueError("ids must be a numpy array, got {}."
                     .format(type(ids)))
  ids = ids.flatten()
  req = self._make_req(ids)
  res = pywrap.new_nbr_res()
  # The request/response are native handles that must be freed explicitly;
  # use try/finally so a non-OK status (which raises) or a failure in
  # get_nodes does not leak them.
  try:
    raise_exception_on_not_ok_status(self._client.sample_neighbor(req, res))
    nbrs = pywrap.get_nbr_res_nbr_ids(res)
    neg_nbrs = self._graph.get_nodes(self._dst_type,
                                     nbrs,
                                     shape=(ids.shape[0], self._expand_factor))
  finally:
    pywrap.del_nbr_res(res)
    pywrap.del_nbr_req(req)
  return neg_nbrs
def get(self, ids):
  """ Get batched multi-hop neighbor samples along the meta path.

  Args:
    ids: A 1d numpy array, the input ids whose neighbors will be returned,
      type=np.int64. It is flattened before sampling.

  Return:
    A `Layers` object with one `Layer` per hop. For hop i, the nodes and
    edges are requested with shape=(batch_size_i, self._expand_factor[i]),
    where batch_size_0 = ids.size and batch_size_{i+1} is the number of
    neighbor ids sampled at hop i.

  Raises:
    ValueError: if `self._meta_path` and `self._expand_factor` differ in
      length.
  """
  if len(self._meta_path) != len(self._expand_factor):
    raise ValueError("The meta_path must have the same number "
                     "of elements as num_at_each_hop")
  ids = ids.flatten()
  src_ids = ids
  current_batch_size = ids.size
  layers = Layers()
  for i in xrange(len(self._meta_path)):
    req = self._make_req(i, src_ids)
    res = pywrap.new_nbr_res()
    # Free the native handles even if sample_neighbor raises.
    try:
      status = self._client.sample_neighbor(req, res)
      if status.ok():
        nbr_ids = pywrap.get_nbr_res_nbr_ids(res)
        edge_ids = pywrap.get_nbr_res_edge_ids(res)
    finally:
      pywrap.del_nbr_res(res)
      pywrap.del_nbr_req(req)
    # Raises on a non-OK status, so nbr_ids/edge_ids are bound below.
    raise_exception_on_not_ok_status(status)

    dense_shape = (current_batch_size, self._expand_factor[i])
    dst_type = self._dst_types[i]
    layer_nodes = self._graph.get_nodes(dst_type,
                                        nbr_ids,
                                        shape=dense_shape)

    # Each source id is paired with exactly expand_factor[i] neighbors, so
    # repeat every source id that many times to align with nbr_ids_flat.
    src_repeated = src_ids.repeat(self._expand_factor[i]).flatten()
    nbr_ids_flat = nbr_ids.flatten()
    layer_edges = self._graph.get_edges(self._meta_path[i],
                                        src_repeated,
                                        nbr_ids_flat,
                                        shape=dense_shape)
    layer_edges.edge_ids = edge_ids
    layers.append_layer(Layer(layer_nodes, layer_edges))

    # The neighbors sampled at this hop are the sources of the next hop.
    current_batch_size = nbr_ids_flat.size
    src_ids = nbr_ids
  return layers
def get(self, ids):
  """ Get batched samples of all neighbors, hop by hop along the meta path.

  Unlike fixed-width sampling, every source id may have a different number
  of neighbors; the per-source counts returned by the server (`degrees`)
  are passed as `offsets`, and the dense shape is padded to the largest
  degree in the batch.

  Args:
    ids: A numpy array of source ids, type=np.int64. It is flattened
      before sampling.

  Return:
    A `Layers` object with one `Layer` per hop. For hop i, nodes and edges
    are requested with shape=(batch_size_i, max_degree_i) and
    offsets=degrees_i.

  Raises:
    ValueError: if `self._meta_path` and `self._expand_factor` differ in
      length.
  """
  if len(self._meta_path) != len(self._expand_factor):
    raise ValueError("The meta_path must have the same number "
                     "of elements as num_at_each_hop")
  ids = ids.flatten()
  src_ids = ids
  current_batch_size = ids.size
  layers = Layers()
  for i in xrange(len(self._meta_path)):
    req = self._make_req(i, src_ids)
    res = pywrap.new_nbr_res()
    # Free the native handles even if sample_neighbor raises.
    try:
      status = self._client.sample_neighbor(req, res)
      if status.ok():
        src_degrees = pywrap.get_nbr_res_degrees(res)
        nbr_ids = pywrap.get_nbr_res_nbr_ids(res)
        edge_ids = pywrap.get_nbr_res_edge_ids(res)
    finally:
      pywrap.del_nbr_res(res)
      pywrap.del_nbr_req(req)
    # Raises on a non-OK status, so the results above are bound below.
    raise_exception_on_not_ok_status(status)

    # Pad each row to the widest neighbor list; an empty batch gets width 0
    # instead of failing inside max().
    max_degree = max(src_degrees) if len(src_degrees) > 0 else 0
    dense_shape = (current_batch_size, max_degree)

    dst_type = self._dst_types[i]
    layer_nodes = self._graph.get_nodes(dst_type,
                                        nbr_ids,
                                        offsets=src_degrees,
                                        shape=dense_shape)

    # Repeat source id k exactly src_degrees[k] times so it lines up with
    # its neighbors in nbr_ids_flat (vectorized; also valid for empty input).
    src_repeated = np.repeat(src_ids, src_degrees)
    nbr_ids_flat = nbr_ids.flatten()
    layer_edges = self._graph.get_edges(self._meta_path[i],
                                        src_repeated,
                                        nbr_ids_flat,
                                        offsets=src_degrees,
                                        shape=dense_shape)
    layer_edges.edge_ids = edge_ids
    layers.append_layer(Layer(layer_nodes, layer_edges))

    # The neighbors sampled at this hop are the sources of the next hop.
    current_batch_size = nbr_ids_flat.size
    src_ids = nbr_ids
  return layers