Example 1
import jax.numpy as jnp
from jax import nn


def softmax_hinge_loss(logits, targets):
    """Computes a hinge loss on softmax probabilities.

    Args:
      logits: float array; Output of model with shape `[..., num_classes]`.
      targets: int array; One-hot labels with shape `[..., num_classes]`.

    Returns:
      Loss value for each example, with shape `[...]`.
    """
    probs = nn.softmax(logits, axis=-1)
    loss = jnp.sum(jnp.maximum(0, 1. - jnp.multiply(probs, targets)), axis=-1)
    return loss
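
A minimal usage sketch with made-up inputs (not from the original source). With one-hot targets, each non-target class contributes a constant 1 to the sum, so the loss varies only with the softmax probability of the true class:

logits = jnp.array([[2.0, -1.0, 0.5],
                    [0.1, 0.2, 3.0]])
targets = jnp.array([[1, 0, 0],
                     [0, 1, 0]])  # one-hot labels
per_example_loss = softmax_hinge_loss(logits, targets)
# per_example_loss.shape: (2,)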
Example 2
        def step_single_example(hidden_states, instruction_pointer,
                                node_embeddings, true_indexes, false_indexes,
                                exit_index):
            # Execution (e.g. apply RNN)
            # leaves(hidden_states).shape: num_nodes, hidden_size
            # instruction_pointer.shape: num_nodes,
            # node_embeddings.shape: num_nodes, statement_length, hidden_size
            hidden_state_contributions = execute(hidden_states,
                                                 node_embeddings)

            # leaves(hidden_state_contributions).shape: num_nodes, hidden_size

            # Use the exit node's hidden state as its hidden state contribution
            # to avoid "executing" the exit node.
            def mask_h(h_contribution, h):
                return h_contribution.at[exit_index, :].set(h[exit_index, :])

            hidden_state_contributions = jax.tree_util.tree_map(
                mask_h, hidden_state_contributions, hidden_states)

            # Branch decisions (e.g. Dense layer)
            branch_decision_logits = branch_decide(hidden_state_contributions)
            branch_decisions = nn.softmax(branch_decision_logits, axis=-1)

            # Update state
            instruction_pointer_new = update_instruction_pointer(
                instruction_pointer, branch_decisions, true_indexes,
                false_indexes)
            hidden_states_new = aggregate(hidden_state_contributions,
                                          instruction_pointer,
                                          branch_decisions, true_indexes,
                                          false_indexes)

            to_tag = {
                'branch_decisions': branch_decisions,
                'hidden_state_contributions': hidden_state_contributions,
                'hidden_states_before': hidden_states,
                'hidden_states': hidden_states_new,
                'instruction_pointer_before': instruction_pointer,
                'instruction_pointer': instruction_pointer_new,
                'true_indexes': true_indexes,
                'false_indexes': false_indexes,
            }
            return hidden_states_new, instruction_pointer_new, to_tag
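
`execute`, `branch_decide`, `update_instruction_pointer`, and `aggregate` come from the enclosing scope and are not shown. As a rough illustration of the instruction-pointer update only, here is a hypothetical sketch, assuming `branch_decisions` has shape `[num_nodes, 2]` (true/false branch probabilities) and `num_nodes` is in scope:

def update_instruction_pointer(instruction_pointer, branch_decisions,
                               true_indexes, false_indexes):
    # Hypothetical sketch, not the original implementation.
    # Probability mass each node sends along its true and false branches.
    true_mass = instruction_pointer * branch_decisions[:, 0]
    false_mass = instruction_pointer * branch_decisions[:, 1]
    # Accumulate the incoming mass at each destination node.
    return (jax.ops.segment_sum(true_mass, true_indexes,
                                num_segments=num_nodes) +
            jax.ops.segment_sum(false_mass, false_indexes,
                                num_segments=num_nodes))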
Example 3
import jax.numpy as jnp
from jax import nn


def logit_transformer(logits,
                      temp=1.0,
                      confidence_quantile_threshold=1.0,
                      self_supervised_label_transformation='soft',
                      logit_indices=None):
    """Transforms logits into labels used as targets in a loss function.

    Args:
      logits: jnp float array; Predictions of a model.
      temp: float; Softmax temperature.
      confidence_quantile_threshold: float; Quantile of the per-example
        confidence scores used as a threshold; only examples at or above
        it receive a nonzero weight.
      self_supervised_label_transformation: str; Type of labels to produce
        ('soft' or 'sharp').
      logit_indices: list(int); Indices of the logits to use.

    Returns:
      A tuple of (new_labels, weights).
    """
    # Compute confidence for each prediction:
    confidence = jnp.amax(logits, axis=-1) - jnp.amin(logits, axis=-1)

    # Compute confidence threshold:
    alpha = jnp.quantile(confidence, confidence_quantile_threshold)
    # Only train on confident outputs:
    weights = jnp.float32(confidence >= alpha)

    if self_supervised_label_transformation == 'sharp':
        if logit_indices:
            # Index with an array; JAX arrays do not accept Python lists
            # as advanced indices.
            logits = logits[Ellipsis, jnp.asarray(logit_indices)]
        new_labels = jnp.argmax(logits, axis=-1)
    elif self_supervised_label_transformation == 'soft':
        new_labels = nn.softmax(logits / (temp or 1.0), axis=-1)
    else:
        new_labels = logits

    return new_labels, weights
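
A minimal usage sketch with made-up logits (not from the original source). With the default `confidence_quantile_threshold=1.0`, the threshold equals the maximum confidence, so only the most confident example keeps a nonzero weight:

logits = jnp.array([[2.0, 0.1, -1.0],
                    [0.3, 0.2, 0.1]])
soft_labels, weights = logit_transformer(logits, temp=0.5)
# soft_labels.shape: (2, 3); weights: [1.0, 0.0]
sharp_labels, _ = logit_transformer(
    logits, self_supervised_label_transformation='sharp')
# sharp_labels: class ids with shape (2,)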
Example 4
        def step_single_example(hidden_states, instruction_pointer,
                                node_embeddings, true_indexes, false_indexes,
                                exit_index):
            # Execution (e.g. apply RNN)
            # leaves(hidden_states).shape: num_nodes, hidden_size
            # instruction_pointer.shape: num_nodes,
            # node_embeddings.shape: num_nodes, statement_length, hidden_size
            if config.model.interpolant.apply_code_rnn:
                hidden_state_contributions = execute(hidden_states,
                                                     node_embeddings)
                # leaves(hidden_state_contributions).shape: num_nodes, hidden_size
            else:
                hidden_state_contributions = hidden_states

            if config.model.interpolant.apply_dense:
                parent_to_true_child = jax.tree_util.tree_map(
                    dense_parent_to_true_child, hidden_state_contributions)
                parent_to_false_child = jax.tree_util.tree_map(
                    dense_parent_to_false_child, hidden_state_contributions)
                true_child_to_parent = jax.tree_util.tree_map(
                    dense_true_child_to_parent, hidden_state_contributions)
                false_child_to_parent = jax.tree_util.tree_map(
                    dense_false_child_to_parent, hidden_state_contributions)
            else:
                parent_to_true_child = hidden_state_contributions
                parent_to_false_child = hidden_state_contributions
                true_child_to_parent = hidden_state_contributions
                false_child_to_parent = hidden_state_contributions

            # Use the exit node's hidden state as its hidden state contribution
            # to avoid "executing" the exit node.
            def mask_h(h_contribution, h):
                return h_contribution.at[exit_index, :].set(h[exit_index, :])

            hidden_state_contributions = jax.tree_util.tree_map(
                mask_h, hidden_state_contributions, hidden_states)

            # Branch decisions (e.g. Dense layer)
            branch_decision_logits = branch_decide(hidden_state_contributions)
            branch_decisions = nn.softmax(branch_decision_logits, axis=-1)

            # Update state
            if config.model.interpolant.use_ipa:
                instruction_pointer_new = update_instruction_pointer(
                    instruction_pointer, branch_decisions, true_indexes,
                    false_indexes)
                hidden_states_new = aggregate(hidden_state_contributions,
                                              instruction_pointer,
                                              branch_decisions, true_indexes,
                                              false_indexes)
            else:
                assert config.model.interpolant.use_parent_embeddings
                assert config.model.interpolant.use_child_embeddings
                instruction_pointer_new = instruction_pointer
                normalization = jnp.sqrt(
                    2 + (  # Each node has a true and false child.
                        # jnp.bincount(true_indexes, minlength=num_nodes)
                        jax.ops.segment_sum(jnp.ones_like(true_indexes),
                                            true_indexes,
                                            num_segments=num_nodes)
                        # + jnp.bincount(false_indexes, minlength=num_nodes)
                        + jax.ops.segment_sum(jnp.ones_like(false_indexes),
                                              false_indexes,
                                              num_segments=num_nodes)))

                # normalization.shape: num_nodes,
                def aggregate_parent_and_child_contributions(p1, p2, c3, c4):
                    return (jax.ops.segment_sum(
                        p1, true_indexes, num_segments=num_nodes) +
                            jax.ops.segment_sum(
                                p2, false_indexes, num_segments=num_nodes) +
                            c3[true_indexes] +
                            c4[false_indexes]) / normalization[:, None]

                hidden_states_new = jax.tree_util.tree_map(
                    aggregate_parent_and_child_contributions,
                    parent_to_true_child,
                    parent_to_false_child,
                    true_child_to_parent,  # true_child_to_parent[child] -> parent
                    false_child_to_parent)
            if config.model.interpolant.apply_gru:

                def apply_gru(h2, h1):
                    output, _ = gru_cell(h2, h1)
                    return output

                hidden_states_new = jax.tree_util.tree_map(
                    apply_gru, hidden_states_new, hidden_states)

            to_tag = {
                'branch_decisions': branch_decisions,
                'hidden_state_contributions': hidden_state_contributions,
                'hidden_states_before': hidden_states,
                'hidden_states': hidden_states_new,
                'instruction_pointer_before': instruction_pointer,
                'instruction_pointer': instruction_pointer_new,
                'true_indexes': true_indexes,
                'false_indexes': false_indexes,
            }
            return hidden_states_new, instruction_pointer_new, to_tag
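
The enclosing scope (`config`, `execute`, `gru_cell`, `num_nodes`, and the dense layers) is not shown. Given the name `step_single_example`, a plausible way to batch such a step, sketched here as an assumption rather than the original code:

step = jax.vmap(step_single_example)  # map over a leading batch axis
hidden_states, instruction_pointer, to_tag = step(
    hidden_states, instruction_pointer, node_embeddings,
    true_indexes, false_indexes, exit_indexes)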