def _join_graph_index(graph_list, mode="node"): is_tensor = graph_list[0].is_tensor() if mode == "node": counts = [g.num_nodes for g in graph_list] elif mode == "edge": counts = [g.num_edges for g in graph_list] else: raise ValueError( "mode must be in ['node', 'edge']. But received model=%s" % mode) if is_tensor: counts = paddle.concat(counts) return op.get_index_from_counts(counts)
def segment_padding(data, segment_ids):
    """Pad rows of ``data`` into a per-segment, fixed-length layout.

    Each row ``data[i]`` is written to ``output[segment_ids[i], j]``, where
    ``j`` is the row's position inside its segment, using
    ``paddle.scatter_nd``; slots that receive no row stay zero. The second
    output dimension is the length of the longest segment.

    The in-segment position is computed as the global row number minus the
    segment's start offset, so rows are expected to be grouped by segment
    with ``segment_ids`` sorted in non-decreasing order (as in the example).

    Args:
        data (Tensor): float32/float64 tensor whose rows are scattered; its
            last dimension becomes ``dim`` of the output (a 2-D
            ``[num_rows, dim]`` layout is assumed).
        segment_ids (Tensor): 1-D int32/int64 tensor holding one segment id
            per row of ``data``.

    Returns:
        output (Tensor): padded result of shape
            ``[num_segments, max_padding, dim]``, where ``num_segments`` is
            the length of the ``segment_sum`` result (presumably
            ``max(segment_ids) + 1`` -- confirm).
        seq_len (Tensor): int32 count of rows in each segment.
        index (Tensor): ``[num_rows, 2]`` pairs of (segment id, in-segment
            position) used for the scatter; reusable with
            ``gather_nd``/``scatter_nd``.

    Examples:
        .. code-block:: python

            import paddle
            import pgl
            data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
            segment_ids = paddle.to_tensor([0, 0, 1], dtype='int64')
            output, seq_len, index = pgl.math.segment_padding(data, segment_ids)
    """
    # Rows per segment: sum a float 1 per row into its segment, then cast.
    seg_len = segment_sum(
        paddle.ones_like(segment_ids, dtype='float32'),
        segment_ids).astype('int32')
    longest = paddle.max(seg_len)

    # Start offset of every row's segment; subtracting it from the global
    # row number yields the row's position inside that segment.
    seg_start = paddle.gather(get_index_from_counts(seg_len)[:-1], segment_ids)
    row_pos = paddle.arange(paddle.shape(segment_ids)[0]) - seg_start

    scatter_index = paddle.stack([segment_ids, row_pos], axis=1)
    out_shape = [paddle.shape(seg_len)[0], longest, data.shape[-1]]
    padded = paddle.scatter_nd(scatter_index, data, out_shape)
    return padded, seg_len, scatter_index
def forward(self, bond_types_batch, type_count_batch, bond_feat):
    """Pool bond features per bond type and score each type per graph.

    For every type id in ``range(self.num_type)``:
      1. select the bonds of that type;
      2. sum them per graph with ``math.segment_pool``;
      3. project each graph's pooled feature to a scalar with ``self.fc_2``;
      4. pad the result to ``batch_size`` entries.
    The per-type vectors are stacked into ``[batch_size, num_type]``. Entries
    whose type count is zero are replaced with ``-1e9``, and the matrix goes
    through ``self.softmax``.

    Input example:
        bond_types_batch: [0,0,2,0,1,2] + [0,0,2,0,1,2] + [2]
        type_count_batch: [[3, 3, 0], [1, 1, 0], [2, 2, 1]] # [num_type, batch_size]

    Args:
        bond_types_batch: type id of each row of ``bond_feat`` after the
            reshape below (it is compared element-wise with each type id).
        type_count_batch: ``[num_type, batch_size]`` tensor with the number
            of bonds of each type in each graph.
        bond_feat: bond features; reshaped to
            ``[-1, self.num_angle * self.bond_dim]`` before ``self.fc_1``.

    Returns:
        ``self.softmax`` applied to the masked ``[batch_size, num_type]``
        score matrix.
    """
    bond_feat = self.fc_1(
        paddle.reshape(bond_feat, [-1, self.num_angle * self.bond_dim]))
    # Row ids of bond_feat; identical for every type, so built once.
    bond_ids = paddle.arange(len(bond_feat))

    inter_mat_list = []
    for type_i in range(self.num_type):
        type_counts = type_count_batch[type_i]
        batch_size = len(type_counts)
        if paddle.sum(type_counts) == 0:
            # No bond of this type anywhere in the batch. Add a placeholder
            # column; paddle.where below masks it to -1e9 anyway.
            inter_mat_list.append(paddle.zeros([batch_size], dtype='float32'))
            continue

        type_i_index = paddle.masked_select(bond_ids,
                                            bond_types_batch == type_i)
        bond_feat_type_i = paddle.gather(bond_feat, type_i_index)
        # Turn the per-graph counts into a graph id for each selected bond.
        graph_bond_index = op.get_index_from_counts(type_counts)
        graph_bond_id = generate_segment_id(graph_bond_index)
        graph_feat_type_i = math.segment_pool(bond_feat_type_i, graph_bond_id,
                                              pool_type='sum')
        mat_flat_type_i = self.fc_2(graph_feat_type_i).squeeze(1)

        # The pooled result can be shorter than batch_size (presumably when
        # trailing graphs have no bond of this type); pad the tail with -1e9.
        my_pad = nn.Pad1D(
            padding=[0, batch_size - len(mat_flat_type_i)], value=-1e9)
        inter_mat_list.append(my_pad(mat_flat_type_i))

    inter_mat_batch = paddle.stack(inter_mat_list, axis=1)  # [batch_size, num_type]
    # Mask every (graph, type) pair with zero bonds of that type.
    inter_mat_mask = paddle.ones_like(inter_mat_batch) * -1e9
    inter_mat_batch = paddle.where(
        type_count_batch.transpose([1, 0]) > 0, inter_mat_batch, inter_mat_mask)
    return self.softmax(inter_mat_batch)
def from_edges(cls, u, v, num_nodes):
    """Create a ``cls`` instance holding an edge index grouped by ``u``.

    Tensor path (when ``check_is_tensor(u, v, num_nodes)`` is truthy):
      * ``_degree`` counts the edges per ``u`` node by scatter-adding ones
        into an int64 zero vector of length ``num_nodes``;
      * ``_sorted_eid`` is ``paddle.argsort(u)``;
      * ``_sorted_u`` and ``_sorted_v`` are ``u`` and ``v`` gathered in that
        order;
      * ``_indptr`` comes from ``op.get_index_from_counts(_degree)``.
    Otherwise the same five fields come from ``graph_kernel.build_index``.

    Args:
        cls: the class to instantiate (called with no arguments).
        u: endpoint array used for grouping/sorting, one entry per edge.
        v: other endpoint array, one entry per edge.
        num_nodes: number of nodes (length of ``_degree`` on the tensor path).

    Returns:
        The new instance with ``_is_tensor``, ``_degree``, ``_sorted_u``,
        ``_sorted_v``, ``_sorted_eid`` and ``_indptr`` set.
    """
    index = cls()
    index._is_tensor = check_is_tensor(u, v, num_nodes)

    if not index._is_tensor:
        # NOTE(review): output order of build_index is (degree, sorted_v,
        # sorted_u, sorted_eid, indptr) as unpacked here -- confirm in kernel.
        (index._degree, index._sorted_v, index._sorted_u,
         index._sorted_eid, index._indptr) = graph_kernel.build_index(
             u, v, num_nodes)
        return index

    # Per-node degree: accumulate (overwrite=False) a 1 for each edge.
    edge_ones = paddle.ones_like(u, dtype="int64")
    index._degree = scatter(
        x=paddle.zeros(shape=[num_nodes], dtype="int64"),
        overwrite=False,
        index=u,
        updates=edge_ones)

    # NOTE(review): paddle.argsort is not guaranteed stable, so the order of
    # edges sharing the same u may differ from the build_index path.
    order = paddle.argsort(u)
    index._sorted_eid = order
    index._sorted_u = paddle.gather(u, order)
    index._sorted_v = paddle.gather(v, order)
    index._indptr = op.get_index_from_counts(index._degree)
    return index