def collate(batch):
    """Collate a list of per-example dicts into one packed batch.

    The examples are first reordered by decreasing length of their
    ``'node_features'`` so that the features can be packed with
    ``torch.nn.utils.rnn.pack_sequence`` (which, with its default
    ``enforce_sorted=True``, requires that ordering). All outputs are
    expressed in this reordered example order.

    Args:
        batch: list of example dicts. Each is read for the keys
            ``'node_features'``, ``'graph'``, ``'target_edge_label'``,
            ``'sparse_node_features'`` and ``'partner_index'``.
            A ``'target_edge_label'`` or ``'partner_index'`` of ``-1`` is
            treated as "absent" for that example.

    Returns:
        dict with keys:
          ``'graph'``: the examples' graphs merged via ``GraphInfo.merge``.
          ``'edge_label'``: int64 tensor of the target edge labels that are
              not ``-1``.
          ``'partner_index'``: ``SparseFeatureBatch`` pairing the (sorted)
              example index with its partner index shifted by that
              example's ``node_offsets`` entry in the merged graph.
          ``'stop_partner_index_index'``: int64 tensor of (sorted) example
              indices whose partner index is ``-1``.
          ``'node_features'``: the packed node feature sequences.
          ``'node_features_graph_index'``: for each example ``i``, the values
              ``i + offsets[:node_counts[i]]`` concatenated, where
              ``offsets`` is computed from the packed ``batch_sizes``.
          ``'sparse_node_features'``: per key, the examples' sparse features
              reindexed through the packed offsets and merged.
          ``'last_graph_node_index'``: ``offsets[node_counts - 1]`` plus the
              example index, one entry per graph.
          ``'sorted_indices'``: tensor of original indices in sorted order.
    """
    # Descending order by sequence length, as required for packing.
    lengths = [len(example['node_features']) for example in batch]
    # .copy() gives a contiguous array (the reversed view has a negative
    # stride), which torch.as_tensor below can consume.
    order = np.argsort(lengths)[::-1].copy()
    examples = [batch[j] for j in order]
    num_examples = len(examples)

    merged_graph = graph_utils.GraphInfo.merge(*(example['graph'] for example in examples))

    # Drop examples with no target edge label (encoded as -1).
    labels = [example['target_edge_label'] for example in examples]
    edge_label = torch.tensor([label for label in labels if label != -1], dtype=torch.int64)

    packed = torch.nn.utils.rnn.pack_sequence([example['node_features'] for example in examples])
    # NOTE(review): presumably the exclusive prefix sum of batch_sizes, i.e. the
    # start of each time step in packed.data -- confirm in graph_utils.
    step_offsets = graph_utils.offsets_from_counts(packed.batch_sizes)

    # Under that assumption, step t of example j lives at step_offsets[t] + j.
    graph_index_parts = []
    for j in range(num_examples):
        graph_index_parts.append(j + step_offsets[:merged_graph.node_counts[j]])
    node_features_graph_index = torch.cat(graph_index_parts, dim=0)

    # Key set (and iteration order) comes from the first sorted example.
    sparse_node_features = {
        key: graph_utils.SparseFeatureBatch.merge(
            [_reindex_sparse_batch(example['sparse_node_features'][key], step_offsets)
             for example in examples],
            range(num_examples))
        for key in examples[0]['sparse_node_features']
    }

    counts = merged_graph.node_counts
    last_graph_node_index = step_offsets[counts - 1] + torch.arange(len(counts), dtype=torch.int64)

    # Split examples into those with a partner (index shifted into the merged
    # graph's node numbering) and those without one (-1).
    partner_rows = []
    partner_targets = []
    stop_rows = []
    for j, example in enumerate(examples):
        if example['partner_index'] == -1:
            stop_rows.append(j)
        else:
            partner_rows.append(j)
            partner_targets.append(example['partner_index'] + merged_graph.node_offsets[j])

    partner_batch = graph_utils.SparseFeatureBatch(
        torch.tensor(partner_rows, dtype=torch.int64),
        torch.tensor(partner_targets, dtype=torch.int64),
    )

    return {
        'graph': merged_graph,
        'edge_label': edge_label,
        'partner_index': partner_batch,
        'stop_partner_index_index': torch.tensor(stop_rows, dtype=torch.int64),
        'node_features': packed,
        'node_features_graph_index': node_features_graph_index,
        'sparse_node_features': sparse_node_features,
        'last_graph_node_index': last_graph_node_index,
        'sorted_indices': torch.as_tensor(order),
    }
def _reindex_sparse_batch(sparse_batch, pack_batch_offsets):
    """Map a sparse batch's indices through packed-sequence offsets.

    Args:
        sparse_batch: object exposing ``.index`` and ``.value``.
        pack_batch_offsets: indexable (e.g. a tensor) gathered at
            ``sparse_batch.index``.

    Returns:
        A new ``graph_utils.SparseFeatureBatch`` whose index is
        ``pack_batch_offsets[sparse_batch.index]`` and whose value is
        ``sparse_batch.value`` unchanged.
    """
    remapped_index = pack_batch_offsets[sparse_batch.index]
    return graph_utils.SparseFeatureBatch(remapped_index, sparse_batch.value)