Exemple #1
0
                def metric_fn(per_example_loss, label_ids, logits,
                              is_real_example):
                    """Build the evaluation metrics for the STS-B regression task.

                    Args:
                      per_example_loss: Loss value for each example.
                      label_ids: Ground-truth target values.
                      logits: Model predictions.
                      is_real_example: Weights handed to the Pearson, MSE and
                        loss metrics (presumably 0 for padding examples and 1
                        otherwise -- confirm against the caller).

                    Returns:
                      Dict mapping metric name to the result of the matching
                      streaming metric call. The "pred" and "label_ids" entries
                      are built without weights.
                    """
                    metrics = {}

                    # Raw predictions and labels, accumulated across batches so
                    # they can be inspected after evaluation.
                    metrics["pred"] = contrib_metrics.streaming_concat(logits)
                    metrics["label_ids"] = contrib_metrics.streaming_concat(
                        label_ids)

                    # Pearson correlation between predictions and labels.
                    metrics["pearson"] = (
                        contrib_metrics.streaming_pearson_correlation(
                            logits, label_ids, weights=is_real_example))

                    # Mean squared error between labels and predictions.
                    metrics["MSE"] = tf.metrics.mean_squared_error(
                        label_ids, logits, weights=is_real_example)

                    # Weighted mean of the per-example loss.
                    metrics["eval_loss"] = tf.metrics.mean(
                        values=per_example_loss, weights=is_real_example)

                    return metrics
Exemple #2
0
def streaming_concat(name, value, axis=0):
    """Register a scalar summary holding the mean of a streamed concatenation.

    Args:
      name: Tag of the scalar summary; also used to name the underlying
        stream as 'stream/<name>'.
      value: Tensor passed to `tfm.streaming_concat`.
      axis: Axis along which the values are concatenated. Defaults to 0.

    Returns:
      None. The summary op is created but not returned.
    """
    stream_ops = tfm.streaming_concat(
        value, axis=axis, name='stream/{}'.format(name))
    # Element 1 of the result is reduced (the update op, per the
    # tf.contrib.metrics documentation -- assuming `tfm` is that module).
    tf.summary.scalar(name, tf.reduce_mean(stream_ops[1]))
Exemple #3
0
def next_production_rule_info_batch_text_summary(expression_strings,
                                                 partial_sequences,
                                                 partial_sequence_lengths,
                                                 next_production_rules,
                                                 unmasked_probabilities_batch,
                                                 masked_probabilities_batch,
                                                 grammar,
                                                 target_length=None):
    """Creates a text summary for a batch of next production rule predictions.

    Args:
      expression_strings: String tensor with shape [batch_size].
      partial_sequences: Integer tensor with shape [batch_size, max_length].
      partial_sequence_lengths: Integer tensor with shape [batch_size].
      next_production_rules: Integer tensor with shape [batch_size]. The
        indices of the next production rules.
      unmasked_probabilities_batch: Float tensor with shape
        [batch_size, num_production_rules]. The probabilities from the model
        prediction without the valid production rule mask.
      masked_probabilities_batch: Tensor with shape
        [batch_size, num_production_rules]. The probabilities from the model
        prediction after the valid production rule mask is applied.
        NOTE(review): previously documented as a Boolean tensor; presumably
        float like the unmasked probabilities -- confirm.
      grammar: arithmetic_grammar.Grammar object.
      target_length: Integer. Only examples with partial sequence length equal
        to target_length will be used. If None (the default), all examples in
        the batch will be used.

    Returns:
      summary: String Tensor containing a Summary proto.
      update_op: Op that updates summary (and the underlying stream).
    """
    inputs = (expression_strings, partial_sequences,
              partial_sequence_lengths, next_production_rules,
              unmasked_probabilities_batch, masked_probabilities_batch)
    suffix = ''
    if target_length is not None:
        # Keep only the examples whose partial sequence length matches.
        inputs = mask_by_partial_sequence_length(
            tensors=inputs,
            partial_sequence_lengths=partial_sequence_lengths,
            target_length=target_length)
        suffix = '/length_%d' % target_length

    (expression_strings, partial_sequences, partial_sequence_lengths,
     next_production_rules, unmasked_probabilities_batch,
     masked_probabilities_batch) = inputs

    # Bind the grammar so the Python callback takes only the tensor inputs.
    describe_batch = functools.partial(
        next_production_rule_info_batch, grammar=grammar)
    info = tf.py_func(
        describe_batch,
        [
            expression_strings, partial_sequences, partial_sequence_lengths,
            next_production_rules, unmasked_probabilities_batch,
            masked_probabilities_batch
        ],
        tf.string,
        name='py_func-next_production_rule_info_batch_text_summary' + suffix)
    # py_func output has no static shape; set one string per example.
    info.set_shape([expression_strings.shape[0]])

    streamed_info, update_op = contrib_metrics.streaming_concat(info)
    # Shuffle so that the ten entries shown differ between summaries.
    shuffled_info = tf.random_shuffle(streamed_info)
    summary = tf.summary.text(
        'next_production_rule_info' + suffix, shuffled_info[:10])
    return summary, update_op