Example #1
0
    def test_splice_signatures(self):
        """Checks that splice_signatures flattens its arguments in order.

        Signatures can be ShapeDtype instances, tuples of 2+ ShapeDtype
        instances, or empty tuples; the spliced result is expected to be one
        flat tuple of all the ShapeDtypes, in argument order.
        """
        sd1, sd2, sd3, sd4, sd5 = [ShapeDtype((size,)) for size in range(1, 6)]

        # Exercise every supported signature form in one call: a bare
        # ShapeDtype, a tuple of several ShapeDtypes, an empty tuple, and
        # another bare ShapeDtype.
        spliced = shapes.splice_signatures(
            sd1,
            (sd2, sd3, sd4),
            (),
            sd5,
        )
        self.assertEqual(spliced, (sd1, sd2, sd3, sd4, sd5))
Example #2
0
def _model_with_metrics(model, eval_task):
    """Builds a combined model+metrics layer on top of an initialized model.

    Only the metrics sub-layer is initialized here; `model` is placed into the
    combined layer as-is, so its existing initialization is kept.

    Args:
      model: Layer with initialized weights and state.
      eval_task: EvalTask instance.

    Returns:
      An initialized, combined model+metrics layer, preserving the
      initialization of `model`.
    """
    # TODO(jonni): Redo this function as part of an initialization refactor?
    metrics_layer = tl.Branch(*eval_task.metrics)

    # The last element of the sample batch feeds the metrics directly. All
    # preceding elements are the model's inputs.
    inputs_sig = shapes.signature(eval_task.sample_batch[:-1])
    labels_sig = shapes.signature(eval_task.sample_batch[-1])

    # The metrics layer consumes the model's outputs followed by the labels.
    model_out_sig = model.output_signature(inputs_sig)
    metrics_in_sig = shapes.splice_signatures(model_out_sig, labels_sig)

    # init is called only for its effect on metrics_layer; the returned
    # pair is discarded.
    _, _ = metrics_layer.init(metrics_in_sig)

    combined = tl.Serial(model, metrics_layer)
    # NOTE(review): copies the model's rng onto the combined layer through
    # the protected attribute; presumably so both share the same rng — confirm.
    combined._rng = model.rng  # pylint: disable=protected-access
    return combined