Example #1
0
def collate(batch):
    """Collate a list of graph samples into one packed training batch.

    Samples are reordered by descending node count so their per-node feature
    sequences can be packed with ``torch.nn.utils.rnn.pack_sequence``.
    Returns a dict of merged graph info, labels, packed features and index
    tensors mapping graph nodes into the packed layout.
    """
    # Sort longest-first, as required by pack_sequence.
    lengths = [len(sample['node_features']) for sample in batch]
    order = np.argsort(lengths)[::-1].copy()
    batch = [batch[idx] for idx in order]

    graph = graph_utils.GraphInfo.merge(*(sample['graph'] for sample in batch))

    # A target label of -1 marks samples without a target edge; drop those.
    edge_label = torch.tensor(
        [sample['target_edge_label'] for sample in batch if sample['target_edge_label'] != -1],
        dtype=torch.int64)

    node_features = torch.nn.utils.rnn.pack_sequence(
        [sample['node_features'] for sample in batch])
    batch_offsets = graph_utils.offsets_from_counts(node_features.batch_sizes)

    # In a packed sequence sorted longest-first, node t of sample i sits at
    # batch_offsets[t] + i, for every time step t < node_counts[i].
    node_features_graph_index = torch.cat(
        [idx + batch_offsets[:graph.node_counts[idx]] for idx in range(len(batch))],
        dim=0)

    # Re-index each sparse feature group into packed-sequence coordinates.
    sparse_node_features = {
        key: graph_utils.SparseFeatureBatch.merge(
            [_reindex_sparse_batch(sample['sparse_node_features'][key], batch_offsets)
             for sample in batch],
            range(len(batch)))
        for key in batch[0]['sparse_node_features']
    }

    # Packed position of the final node of every graph.
    last_graph_node_index = batch_offsets[graph.node_counts - 1] + torch.arange(
        len(graph.node_counts), dtype=torch.int64)

    partner_index_index = []
    partner_index = []
    stop_partner_index_index = []
    for i, sample in enumerate(batch):
        # A partner index of -1 means "stop": this sample has no partner node.
        if sample['partner_index'] == -1:
            stop_partner_index_index.append(i)
        else:
            partner_index_index.append(i)
            partner_index.append(sample['partner_index'] + graph.node_offsets[i])

    partner_index = graph_utils.SparseFeatureBatch(
        torch.tensor(partner_index_index, dtype=torch.int64),
        torch.tensor(partner_index, dtype=torch.int64))

    return {
        'graph': graph,
        'edge_label': edge_label,
        'partner_index': partner_index,
        'stop_partner_index_index': torch.tensor(stop_partner_index_index, dtype=torch.int64),
        'node_features': node_features,
        'node_features_graph_index': node_features_graph_index,
        'sparse_node_features': sparse_node_features,
        'last_graph_node_index': last_graph_node_index,
        'sorted_indices': torch.as_tensor(order)
    }
Example #2
0
    def forward(self, data):
        """Run the core model and all readout heads on a collated batch.

        Returns a dict with the graph/node embeddings plus logits for edge
        partner, edge label and entity type; numerical feature logits are
        included only when the corresponding readout modules are configured.
        """
        graph = data['graph']
        graph_counts = data['graph_counts']

        node_post_embedding, graph_embedding = self.model_core(data)

        with torch.autograd.profiler.record_function('edge_partner_readout'):
            partner_logits = self.edge_partner(
                node_post_embedding, graph_embedding, graph)

        # Graphs for edge-type targets come first in the batch layout.
        total_constraint_counts = sum(
            graph_counts[t] for t in target.TargetType.edge_types())

        # Computation specific to constraint prediction
        with torch.autograd.profiler.record_function('edge_post_embedding'):
            # Last node of each constraint graph (offset of next graph, minus one).
            current_node_post_embedding = node_post_embedding.index_select(
                0, graph.node_offsets[1:total_constraint_counts + 1] - 1)
            partner_node_post_embedding = node_post_embedding.index_select(
                0, data['edge_partner'])
            edge_post_embedding = self.edge_post_embedding(
                current_node_post_embedding, partner_node_post_embedding)

        with torch.autograd.profiler.record_function('edge_label_readout'):
            edge_label_logits = self.edge_label(
                edge_post_embedding, graph_embedding[:total_constraint_counts])

        graph_offsets = graph_model.offsets_from_counts(graph_counts)

        edge_feature_logits = None
        if self.edge_feature_readout:
            with torch.autograd.profiler.record_function(
                    'edge_feature_readout'):
                # One readout head per numerical edge type present in the batch.
                edge_feature_logits = {
                    t: self.edge_feature_readout[t.name](
                        data['edge_numerical'][t],
                        torch.narrow(edge_post_embedding, 0,
                                     graph_offsets[t], graph_counts[t]),
                        torch.narrow(graph_embedding, 0,
                                     graph_offsets[t], graph_counts[t]))
                    for t in target.TargetType.numerical_edge_types()
                    if graph_counts[t] > 0
                }

        # Computation specific to entity type prediction
        total_entity_counts = sum(
            graph_counts[t] for t in target.TargetType.node_types())

        with torch.autograd.profiler.record_function('node_label_readout'):
            # Entity-target graphs follow the constraint-target graphs.
            entity_embedding = graph_embedding[
                total_constraint_counts:total_constraint_counts + total_entity_counts]
            entity_logits = self.entity_label(entity_embedding)

        node_feature_logits = None
        if self.entity_feature_readout:
            with torch.autograd.profiler.record_function(
                    'node_feature_readout'):
                node_feature_logits = {
                    t: self.entity_feature_readout[t.name](
                        data['node_numerical'][t],
                        torch.narrow(graph_embedding, 0,
                                     graph_offsets[t], graph_counts[t]))
                    for t in target.TargetType.numerical_node_types()
                    if graph_counts[t] > 0
                }

        result = {
            'graph_embedding': graph_embedding,
            'node_embedding': node_post_embedding,
            'partner_logits': partner_logits,
            'entity_logits': entity_logits,
            'edge_label_logits': edge_label_logits,
        }
        if edge_feature_logits is not None:
            result['edge_feature_logits'] = edge_feature_logits
        if node_feature_logits is not None:
            result['entity_feature_logits'] = node_feature_logits
        return result