def sample_edge_target(self, graph, generator=None):
    """Samples edge targets.

    This function samples potential edges between the last node and a given
    target. The sampled target may potentially be one past the end of the
    nodes, indicating that no new edges should be created.

    Parameters
    ----------
    graph : GraphInfo
        The graph (or batch of graphs) for which to sample edge targets.
    generator : torch.Generator, optional
        Optional PRNG to use for sampling.

    Returns
    -------
    torch.Tensor
        An integer tensor containing the sampled edge target for each graph
        in the batch.
    """
    data = {'graph': graph}
    node_embedding, graph_embedding = self.model_core(data)
    # One partner logit per node; the nodes of each graph form one segment.
    partner_logits = self.edge_partner(node_embedding, graph_embedding, graph)
    # The "extended" sampler may return a position one past the end of a
    # segment (the "no new edge" outcome described above).
    partner_target_samples = sg_functional.segmented_multinomial_extended(
        partner_logits, graph_utils.scopes_from_offsets(graph.node_offsets),
        generator=generator)
    return partner_target_samples
def segment_stop_accuracy(partner_logits, segment_offsets, target_idx=None):
    """Computes the accuracy for stop prediction for partner logits.

    The "stop" outcome is compared against an implicit logit of 0: stop is
    predicted when every real logit in a segment is negative, and a partner
    is predicted when the best real logit is positive.

    Parameters
    ----------
    partner_logits : torch.Tensor
        The un-normalized logits for the segments.
    segment_offsets : torch.Tensor
        A tensor of shape `[num_segments + 1]` indicating the segment offsets.
    target_idx : torch.Tensor, optional
        If not None, a tensor of shape `[num_segments]` representing the
        target offset to predict in each segment. Otherwise, this is assumed
        to be the last (implicit) entry in each segment.

    Returns
    -------
    torch.Tensor
        A boolean tensor of shape `[num_segments]` representing the accuracy
        at each segment.
    """
    segment_scopes = scopes_from_offsets(segment_offsets)
    # Gradients are not needed for an accuracy metric.
    best_logit, best_local_index = sg_functional.segment_argmax(
        partner_logits.detach(), segment_scopes)

    if target_idx is not None:
        beats_stop = best_logit > 0
        # Shift by each segment's start offset so the argmax position is
        # comparable with target_idx.
        best_global_index = best_local_index + segment_offsets[:-1]
        return beats_stop & (best_global_index == target_idx)

    # No explicit target: stop is correct when nothing beats the 0 logit.
    return best_logit < 0
def segment_stop_loss(partner_logits, segment_offsets, target_idx=None):
    """Computes the loss for stop prediction for partner logits.

    Each segment is scored with a softmax cross-entropy whose logits are the
    segment's entries plus one extra constant logit (the implicit "stop"
    entry). When no target is given, the implicit entry is the target.

    Parameters
    ----------
    partner_logits : torch.Tensor
        The un-normalized logits for the segments.
    segment_offsets : torch.Tensor
        A tensor of shape `[num_segments + 1]` indicating the segment offsets.
    target_idx : torch.Tensor, optional
        If not None, a tensor of shape `[num_segments]` representing the
        target offset to predict in each segment. Otherwise, this is assumed
        to be the last (implicit) entry in each segment.

    Returns
    -------
    torch.Tensor
        A tensor of shape `[num_segments]` representing the cross-entropy
        loss at each segment.
    """
    segment_scopes = scopes_from_offsets(segment_offsets)
    # softplus(logsumexp(x)) == log(1 + sum(exp(x))): the log-normalizer of the
    # segment augmented with an extra logit fixed at 0.
    log_normalizer = torch.nn.functional.softplus(
        sg_functional.segment_logsumexp(partner_logits, segment_scopes))

    if target_idx is None:
        # Target is the implicit 0 logit, so the loss is just the normalizer.
        return log_normalizer

    target_logits = partner_logits.index_select(0, target_idx)
    return log_normalizer - target_logits
def forward(self, node_embedding, graph):
    """Pools per-node embeddings into one embedding per graph.

    Each node embedding is passed through two sub-networks whose outputs are
    multiplied element-wise; the products are then averaged over each graph's
    nodes and rescaled by the graph's node count.

    Parameters
    ----------
    node_embedding : torch.Tensor
        Per-node features for all graphs in the batch (stacked along dim 0 —
        presumably; TODO confirm).
    graph : GraphInfo
        Batch description providing `node_offsets` and `node_counts`.

    Returns
    -------
    torch.Tensor
        One pooled embedding per graph.
    """
    node_scopes = graph_model.scopes_from_offsets(graph.node_offsets)
    gate_out = self.node_gating_net(node_embedding)
    projected = self.node_to_graph_net(node_embedding)
    pooled_mean = sg_nn.functional.segment_avg_pool1d(gate_out * projected, node_scopes)
    # NOTE(review): assuming node_counts matches the segment lengths, this
    # rescaling turns the segment average into a per-graph sum.
    return pooled_mean * graph.node_counts.unsqueeze(-1)
def segment_stop_accuracy(partner_logits, segment_offsets, target_idx, stop_partner_index_index):
    """Computes the accuracy for stop prediction for partner logits.

    Segments are split into "stop" segments (selected by
    `stop_partner_index_index`) and "partner" segments (selected by
    `target_idx.indices`). Predictions are compared against an implicit
    stop logit of 0.

    Parameters
    ----------
    partner_logits : torch.Tensor
        The un-normalized logits for the segments.
    segment_offsets : torch.Tensor
        Segment offsets; `segment_offsets[:-1]` gives each segment's start.
    target_idx
        Presumably a sparse/indexed structure: `indices` selects the partner
        segments and `values` holds the flat target position for each of
        them — TODO confirm against the producing code.
    stop_partner_index_index : torch.Tensor
        Indices of the segments whose correct outcome is "stop".

    Returns
    -------
    tuple of torch.Tensor
        (stop accuracy, partner accuracy), each the mean over its selected
        segments. NOTE(review): the mean of an empty selection is NaN.
    """
    segment_scopes = scopes_from_offsets(segment_offsets)
    best_logit, best_local_index = sg_nn.functional.segment_argmax(
        partner_logits.detach(), segment_scopes)

    # Stop segments: correct when no real logit exceeds the implicit 0 logit.
    stop_hits = (best_logit < 0).index_select(0, stop_partner_index_index)

    # Partner segments: correct when the best real logit beats the stop logit
    # and sits at the target position.
    partner_segments = target_idx.indices
    beats_stop = (best_logit > 0).index_select(0, partner_segments)
    best_global_index = best_local_index + segment_offsets[:-1]
    on_target = best_global_index.index_select(0, partner_segments) == target_idx.values
    partner_hits = beats_stop & on_target

    return stop_hits.float().mean(), partner_hits.float().mean()
def segment_stop_loss(partner_logits, segment_offsets, partner_index, stop_partner_index_index, reduction='sum'):
    """Computes stop and partner cross-entropy losses over segmented logits.

    Each segment's log-normalizer includes an extra logit fixed at 0 (the
    implicit "stop" entry). For stop segments the loss is the normalizer
    itself; for partner segments it is the normalizer minus the target logit.

    Parameters
    ----------
    partner_logits : torch.Tensor
        The un-normalized logits for the segments.
    segment_offsets : torch.Tensor
        Segment offsets used to build the segment scopes.
    partner_index
        Presumably a sparse/indexed structure: `index` selects partner
        segments and `value` holds the flat logit position of each target —
        TODO confirm against the producing code.
    stop_partner_index_index : torch.Tensor
        Indices of the segments whose target is "stop".
    reduction : str
        One of 'sum', 'mean' or 'none'.

    Returns
    -------
    tuple of torch.Tensor
        (stop loss, partner loss), reduced according to `reduction`.

    Raises
    ------
    ValueError
        If `reduction` is not one of the supported values.
    """
    segment_scopes = scopes_from_offsets(segment_offsets)
    # softplus(logsumexp(x)) == log(1 + sum(exp(x))).
    log_normalizer = torch.nn.functional.softplus(
        sg_nn.functional.segment_logsumexp(partner_logits, segment_scopes))

    stop_loss = log_normalizer.index_select(0, stop_partner_index_index)
    partner_loss = (log_normalizer.index_select(0, partner_index.index)
                    - partner_logits.index_select(0, partner_index.value))

    if reduction == 'sum':
        return stop_loss.sum(), partner_loss.sum()
    if reduction == 'mean':
        return stop_loss.mean(), partner_loss.mean()
    if reduction == 'none':
        return stop_loss, partner_loss
    raise ValueError('Reduction must be one of sum, mean or none.')