Beispiel #1
0
    def sample_edge_target(self, graph, generator=None):
        """Sample, for each graph, the target of a new edge from its last node.

        The graph is embedded by the model core and scored by the edge-partner
        head. One target is then drawn per graph from those scores. A sampled
        target may lie one past the end of the graph's nodes, which signals
        that no new edge should be created.

        Parameters
        ----------
        graph : GraphInfo
            The graph (or batch of graphs) in which to sample edge targets.
        generator : torch.Generator, optional
            Optional PRNG used for sampling.

        Returns
        -------
        torch.Tensor
            An integer tensor holding the sampled edge target of every graph
            in the batch.
        """
        node_emb, graph_emb = self.model_core({'graph': graph})
        logits = self.edge_partner(node_emb, graph_emb, graph)

        # One sampling segment per graph, delimited by the node offsets.
        node_scopes = graph_utils.scopes_from_offsets(graph.node_offsets)
        return sg_functional.segmented_multinomial_extended(
            logits, node_scopes, generator=generator)
Beispiel #2
0
def segment_stop_accuracy(partner_logits, segment_offsets, target_idx=None):
    """Compute per-segment accuracy of the stop / partner prediction.

    Every segment is compared against an implicit "stop" entry whose logit is
    0. The model predicts "stop" for a segment when all of its logits are
    negative. It predicts the arg-max entry when the largest logit is
    positive. A maximum logit of exactly 0 counts as incorrect in both modes.

    Parameters
    ----------
    partner_logits : torch.Tensor
        The un-normalized logits, concatenated over all segments.
    segment_offsets : torch.Tensor
        A tensor of shape `[num_segments + 1]` indicating the segment offsets.
    target_idx : torch.Tensor, optional
        If not None, a tensor of shape `[num_segments]` giving the expected
        partner for each segment. It is compared against the segment's arg-max
        index shifted by the segment start (`segment_offsets[:-1]`). This
        presumably makes it an index into the concatenated `partner_logits`,
        assuming `segment_argmax` returns segment-local indices (TODO confirm).
        If None, the expected answer in every segment is the implicit stop
        entry.

    Returns
    -------
    torch.Tensor
        A boolean tensor of shape `[num_segments]` that is True where the
        prediction for the segment is correct.
    """
    segment_scopes = scopes_from_offsets(segment_offsets)
    best_logit, best_local_index = sg_functional.segment_argmax(
        partner_logits.detach(), segment_scopes)

    if target_idx is not None:
        # Partner mode: the best entry must beat the implicit 0 stop logit
        # and sit at the requested position.
        best_global_index = best_local_index + segment_offsets[:-1]
        return (best_logit > 0) & (best_global_index == target_idx)

    # Stop mode: every real logit must lose to the implicit 0 stop logit.
    return best_logit < 0
Beispiel #3
0
def segment_stop_loss(partner_logits, segment_offsets, target_idx=None):
    """Compute the per-segment cross-entropy loss of the stop / partner prediction.

    Each segment is scored with a softmax cross-entropy over its logits,
    augmented with one extra constant logit of 0 that stands for "stop".
    When no target is given, that implicit stop entry is the target.

    Parameters
    ----------
    partner_logits : torch.Tensor
        The un-normalized logits, concatenated over all segments.
    segment_offsets : torch.Tensor
        A tensor of shape `[num_segments + 1]` indicating the segment offsets.
    target_idx : torch.Tensor, optional
        If not None, a tensor of shape `[num_segments]` whose entries index
        into `partner_logits` (via `index_select`) to pick each segment's
        target logit. If None, every segment's target is the implicit stop
        entry.

    Returns
    -------
    torch.Tensor
        A tensor of shape `[num_segments]` with the cross-entropy loss of each
        segment.
    """
    segment_scopes = scopes_from_offsets(segment_offsets)

    # softplus(logsumexp(x)) == log(1 + sum(exp(x))): the log-normalizer of
    # the segment's logits extended with one extra logit fixed at 0.
    log_normalizer = torch.nn.functional.softplus(
        sg_functional.segment_logsumexp(partner_logits, segment_scopes))

    if target_idx is not None:
        return log_normalizer - partner_logits.index_select(0, target_idx)

    # Target is the implicit stop logit (0), so the loss is the normalizer.
    return log_normalizer
Beispiel #4
0
    def forward(self, node_embedding, graph):
        """Pool node embeddings into one embedding per graph.

        Each node embedding is passed through `node_gating_net` and through
        `node_to_graph_net`, and the two outputs are multiplied elementwise.
        The product is averaged over each graph's nodes (segments delimited by
        `graph.node_offsets`) and then scaled by that graph's node count. This
        presumably amounts to a per-graph sum, assuming `segment_avg_pool1d`
        divides by the segment length (TODO confirm).

        Parameters
        ----------
        node_embedding : torch.Tensor
            Per-node embeddings, presumably `[num_nodes, dim]` with the nodes
            of all graphs concatenated (TODO confirm).
        graph : GraphInfo
            Batch description; `node_offsets` and `node_counts` are read.

        Returns
        -------
        torch.Tensor
            The pooled embedding of every graph in the batch.
        """
        node_scopes = graph_model.scopes_from_offsets(graph.node_offsets)

        # Gate is evaluated before the projection, as in the original product.
        gate = self.node_gating_net(node_embedding)
        projected = self.node_to_graph_net(node_embedding)

        mean_pooled = sg_nn.functional.segment_avg_pool1d(gate * projected, node_scopes)
        # Undo the averaging: scale each graph's mean by its number of nodes.
        return mean_pooled * graph.node_counts.unsqueeze(-1)
Beispiel #5
0
def segment_stop_accuracy(partner_logits, segment_offsets, target_idx, stop_partner_index_index):
    """Compute batch accuracies of the stop and partner predictions.

    A segment is predicted to "stop" when all of its logits are negative.
    It is predicted to pick its arg-max entry as partner when the largest
    logit is positive. A maximum of exactly 0 counts as neither.

    Parameters
    ----------
    partner_logits : torch.Tensor
        Un-normalized partner logits, concatenated over all segments.
    segment_offsets : torch.Tensor
        Segment offsets; `segment_offsets[:-1]` are the segment start positions.
    target_idx
        Partner targets. `target_idx.indices` selects the segments that have a
        partner target. `target_idx.values` is compared against each selected
        segment's arg-max shifted by the segment start.
        NOTE(review): `segment_stop_loss` reads `.index` / `.value` from its
        partner-index argument, while this function reads `.indices` /
        `.values`; confirm which type each caller passes.
    stop_partner_index_index : torch.Tensor
        Indices of the segments whose expected answer is "stop".

    Returns
    -------
    tuple of torch.Tensor
        `(stop_accuracy, partner_accuracy)`: the mean correctness over the
        stop segments and over the partner segments.
        NOTE(review): an empty selection makes the corresponding `.mean()`
        NaN.
    """
    segment_scopes = scopes_from_offsets(segment_offsets)
    best_logit, best_local_index = sg_nn.functional.segment_argmax(
        partner_logits.detach(), segment_scopes)

    # Stop is right when every real logit loses to the implicit 0 stop logit.
    stop_hits = (best_logit < 0).index_select(0, stop_partner_index_index)

    partner_segments = target_idx.indices
    beats_stop = (best_logit > 0).index_select(0, partner_segments)
    best_global_index = (best_local_index + segment_offsets[:-1]).index_select(0, partner_segments)
    partner_hits = beats_stop & (best_global_index == target_idx.values)

    return stop_hits.float().mean(), partner_hits.float().mean()
Beispiel #6
0
def segment_stop_loss(partner_logits, segment_offsets, partner_index, stop_partner_index_index, reduction='sum'):
    """Compute the stop and partner losses over segmented partner logits.

    Every segment is scored with a softmax cross-entropy over its logits,
    augmented with one implicit constant logit of 0 standing for "stop".
    `softplus(logsumexp(logits))` equals `log(1 + sum(exp(logits)))`, the
    log-normalizer of that augmented distribution.

    Parameters
    ----------
    partner_logits : torch.Tensor
        Un-normalized partner logits, concatenated over all segments.
    segment_offsets : torch.Tensor
        Segment offsets, presumably of shape `[num_segments + 1]` (TODO confirm).
    partner_index
        Partner targets. `partner_index.index` selects the segments that have a
        partner target. `partner_index.value` indexes into `partner_logits` to
        pick the target logit of each of them.
    stop_partner_index_index : torch.Tensor
        Indices of the segments whose target is the implicit stop entry.
    reduction : str, optional
        How each loss vector is reduced: 'sum' (default), 'mean' or 'none'.

    Returns
    -------
    tuple of torch.Tensor
        `(stop_loss, partner_loss)`, reduced according to `reduction`.
        NOTE(review): with 'mean', an empty selection yields NaN.

    Raises
    ------
    ValueError
        If `reduction` is not one of 'sum', 'mean' or 'none'. The check runs
        before any tensor work.
    """
    # Validate first so a bad argument does not cost a full loss computation.
    # A tuple (not a set/dict) keeps the `==` semantics of the original check.
    if reduction not in ('sum', 'mean', 'none'):
        raise ValueError('Reduction must be one of sum, mean or none.')

    scopes = scopes_from_offsets(segment_offsets)

    log_weight_other = sg_nn.functional.segment_logsumexp(partner_logits, scopes)
    # Log-normalizer of each segment including the implicit 0 stop logit.
    total_weight = torch.nn.functional.softplus(log_weight_other)

    # Stop target: the implicit logit is 0, so the loss is the normalizer itself.
    stop_loss = total_weight.index_select(0, stop_partner_index_index)
    partner_loss = (total_weight.index_select(0, partner_index.index)
                    - partner_logits.index_select(0, partner_index.value))

    if reduction == 'sum':
        return stop_loss.sum(), partner_loss.sum()
    if reduction == 'mean':
        return stop_loss.mean(), partner_loss.mean()
    return stop_loss, partner_loss