Beispiel #1
0
def collate(batch):
    """Collate a list of graph samples into one batch with packed node features.

    Samples are first reordered by decreasing number of node feature rows,
    because ``torch.nn.utils.rnn.pack_sequence`` (with its default
    ``enforce_sorted=True``) requires sequences sorted by decreasing length.
    Every per-sample position in the result refers to this sorted order.

    Args:
        batch: list of sample dicts. Each sample is read for the keys
            ``'node_features'``, ``'graph'``, ``'target_edge_label'``,
            ``'sparse_node_features'`` and ``'partner_index'``.

    Returns:
        dict with the keys:
            ``'graph'``: the samples' graphs merged via ``GraphInfo.merge``.
            ``'edge_label'``: int64 tensor of the ``target_edge_label`` values
                that are not ``-1``.
            ``'partner_index'``: ``SparseFeatureBatch`` pairing the position
                of each sample whose ``partner_index`` is not ``-1`` with that
                partner index shifted by ``graph.node_offsets`` of the sample.
            ``'stop_partner_index_index'``: int64 tensor of the positions of
                the samples whose ``partner_index`` is ``-1``.
            ``'node_features'``: the packed node features.
            ``'node_features_graph_index'``: for each graph in turn, the packed
                row of each of its nodes.
            ``'sparse_node_features'``: per key, the samples' sparse features
                reindexed with the packing offsets and merged.
            ``'last_graph_node_index'``: packed row of each graph's last node.
            ``'sorted_indices'``: tensor whose entry ``j`` is the index in the
                input ``batch`` of the sample placed at position ``j``.
    """
    # Order samples from most to fewest node rows so they can be packed.
    lengths = [len(sample['node_features']) for sample in batch]
    order = np.argsort(lengths)[::-1].copy()
    batch = [batch[idx] for idx in order]

    graph = graph_utils.GraphInfo.merge(*[sample['graph'] for sample in batch])

    # A target_edge_label of -1 marks a sample that contributes no edge label.
    kept_labels = [sample['target_edge_label'] for sample in batch if sample['target_edge_label'] != -1]
    edge_label = torch.tensor(kept_labels, dtype=torch.int64)

    node_features = torch.nn.utils.rnn.pack_sequence([sample['node_features'] for sample in batch])
    # NOTE(review): assumes offsets_from_counts returns the start row of each
    # time step in the packed data (exclusive cumulative sum of batch_sizes).
    offsets = graph_utils.offsets_from_counts(node_features.batch_sizes)

    # With descending lengths, node t of graph g sits at packed row offsets[t] + g.
    per_graph_rows = [offsets[:graph.node_counts[g]] + g for g in range(len(batch))]
    node_features_graph_index = torch.cat(per_graph_rows, dim=0)

    sparse_node_features = {
        key: graph_utils.SparseFeatureBatch.merge(
            [_reindex_sparse_batch(sample['sparse_node_features'][key], offsets) for sample in batch],
            range(len(batch)))
        for key in batch[0]['sparse_node_features']
    }

    graph_positions = torch.arange(len(graph.node_counts), dtype=torch.int64)
    last_graph_node_index = offsets[graph.node_counts - 1] + graph_positions

    # Split samples into those with a partner (-1 means none) and those without.
    with_partner = []
    partner_targets = []
    without_partner = []
    for g, sample in enumerate(batch):
        target = sample['partner_index']
        if target == -1:
            without_partner.append(g)
        else:
            with_partner.append(g)
            partner_targets.append(target + graph.node_offsets[g])

    partner_batch = graph_utils.SparseFeatureBatch(
        torch.tensor(with_partner, dtype=torch.int64),
        torch.tensor(partner_targets, dtype=torch.int64),
    )
    stop_positions = torch.tensor(without_partner, dtype=torch.int64)

    return {
        'graph': graph,
        'edge_label': edge_label,
        'partner_index': partner_batch,
        'stop_partner_index_index': stop_positions,
        'node_features': node_features,
        'node_features_graph_index': node_features_graph_index,
        'sparse_node_features': sparse_node_features,
        'last_graph_node_index': last_graph_node_index,
        'sorted_indices': torch.as_tensor(order),
    }
Beispiel #2
0
def _reindex_sparse_batch(sparse_batch, pack_batch_offsets):
    """Return a copy of ``sparse_batch`` whose index is mapped through offsets.

    Each entry of ``sparse_batch.index`` is replaced by
    ``pack_batch_offsets[entry]``; the values are passed through unchanged.

    Args:
        sparse_batch: object exposing ``index`` and ``value`` attributes
            (presumably a ``graph_utils.SparseFeatureBatch``).
        pack_batch_offsets: indexable lookup (e.g. a tensor) used to remap
            the index entries.

    Returns:
        A new ``graph_utils.SparseFeatureBatch`` built from the remapped
        index and the original values.
    """
    remapped_index = pack_batch_offsets[sparse_batch.index]
    return graph_utils.SparseFeatureBatch(remapped_index, sparse_batch.value)