def test_splice_signatures(self):
  """Checks that splice_signatures flattens mixed signatures in order.

  The inputs mix a bare ShapeDtype, a tuple of several ShapeDtypes, an empty
  tuple, and another bare ShapeDtype. The spliced result must be one flat
  tuple listing every ShapeDtype in the original order.
  """
  # ShapeDtypes with shapes (1,), (2,), ..., (5,), created in that order.
  sd1, sd2, sd3, sd4, sd5 = [ShapeDtype((size,)) for size in range(1, 6)]

  # A signature may be a single ShapeDtype, a tuple of 2+ ShapeDtypes, or an
  # empty tuple; one of each kind is exercised below.
  single_a = sd1
  several = (sd2, sd3, sd4)
  empty = ()
  single_b = sd5

  expected = (sd1, sd2, sd3, sd4, sd5)
  actual = shapes.splice_signatures(single_a, several, empty, single_b)
  self.assertEqual(actual, expected)
def _model_with_metrics(model, eval_task):
  """Wraps an already-initialized model with the eval task's metrics.

  The metrics in `eval_task.metrics` are grouped into one `tl.Branch` layer.
  That layer is initialized against the model's output signature spliced
  with the label signature. `model` itself is not re-initialized.

  Args:
    model: Layer with initialized weights and state.
    eval_task: EvalTask instance; its `metrics` and `sample_batch` are read.

  Returns:
    An initialized `tl.Serial(model, metrics)` layer that preserves the
    initialization of `model` and carries over `model.rng`.
  """
  # TODO(jonni): Redo this function as part of an initialization refactor?
  metrics = tl.Branch(*eval_task.metrics)

  # The last element of the sample batch is treated as the labels; every
  # element before it is fed to the model.
  inputs_sig = shapes.signature(eval_task.sample_batch[:-1])
  labels_sig = shapes.signature(eval_task.sample_batch[-1])

  # The metrics layer sees the model's outputs followed by the labels.
  model_output_sig = model.output_signature(inputs_sig)
  metrics_sig = shapes.splice_signatures(model_output_sig, labels_sig)
  _, _ = metrics.init(metrics_sig)

  combined = tl.Serial(model, metrics)
  # Copy the model's rng onto the combined layer.
  combined._rng = model.rng  # pylint: disable=protected-access
  return combined