def test_pr_curves(self):
    """Migrating a legacy PR-curve summary keeps its tensor and tags it.

    Wraps a ``pr_curve_summary.pb`` summary in an event and runs it
    through ``self._migrate_event``. The test then checks:

    * exactly one event with exactly one value comes back;
    * the value's tensor has shape ``(6, 10)`` and equals the tensor of
      the source summary;
    * the value's metadata now carries the tensor data class and the
      PR-curve plugin name.
    """
    source_summary = pr_curve_summary.pb(
        "foo",
        labels=np.array([True, False, True, False]),
        predictions=np.array([0.75, 0.25, 0.85, 0.15]),
        num_thresholds=10,
        display_name="bar",
        description="baz",
    )

    legacy_event = event_pb2.Event()
    legacy_event.step = 123
    legacy_event.wall_time = 456.75
    # Copy through the wire format rather than assigning the message
    # directly, so the summary lands in the event's own Summary type.
    legacy_event.summary.ParseFromString(source_summary.SerializeToString())

    migrated_events = self._migrate_event(legacy_event)
    self.assertLen(migrated_events, 1)
    migrated_values = migrated_events[0].summary.value
    self.assertLen(migrated_values, 1)

    migrated_value = migrated_values[0]
    migrated_array = tensor_util.make_ndarray(migrated_value.tensor)
    self.assertEqual(migrated_array.shape, (6, 10))
    source_array = tensor_util.make_ndarray(source_summary.value[0].tensor)
    np.testing.assert_array_equal(migrated_array, source_array)

    metadata = migrated_value.metadata
    self.assertEqual(metadata.data_class, summary_pb2.DATA_CLASS_TENSOR)
    self.assertEqual(
        metadata.plugin_data.plugin_name,
        pr_curve_metadata.PLUGIN_NAME,
    )
def compute_and_check_summary_pb(self,
                                 name,
                                 labels,
                                 predictions,
                                 num_thresholds,
                                 weights=None,
                                 display_name=None,
                                 description=None,
                                 feed_dict=None):
    """Build the same summary via `op` and via `pb` and assert they agree.

    The raw ``labels``, ``predictions`` and (if given) ``weights`` are
    passed straight to ``summary.pb``. They are also wrapped in
    ``tf.constant`` tensors for ``summary.op``. The op is evaluated with
    ``self.pb_via_op`` using ``feed_dict``. Both results are passed
    through ``self.normalize_summary_pb`` before they are compared with
    ``assertProtoEquals``.

    Returns:
      The normalized `Summary` protocol buffer produced by `summary.pb`.
    """
    # Arguments that are identical for the op-based and pb-based paths.
    shared_kwargs = dict(
        name=name,
        num_thresholds=num_thresholds,
        display_name=display_name,
        description=description,
    )

    labels_t = tf.constant(labels)
    predictions_t = tf.constant(predictions)
    if weights is None:
        weights_t = None
    else:
        weights_t = tf.constant(weights)

    summary_op = summary.op(
        labels=labels_t,
        predictions=predictions_t,
        weights=weights_t,
        **shared_kwargs
    )

    direct_pb = self.normalize_summary_pb(
        summary.pb(
            labels=labels,
            predictions=predictions,
            weights=weights,
            **shared_kwargs
        )
    )
    op_pb = self.normalize_summary_pb(
        self.pb_via_op(summary_op, feed_dict=feed_dict)
    )
    self.assertProtoEquals(direct_pb, op_pb)
    return direct_pb