def _PadLabels3d(logits, labels):
    """Resizes 3-d labels so their height and width agree with the logits.

    Intended for 2-d softmax output, where labels is laid out as
    [batch, height, width] and logits as [batch, height, width, onehot].
    The adjustment is done in two passes through _PadLabels2d: first on the
    width, then on the flattened height * width.

    Args:
      logits: 4-d pre-softmax fully-connected output.
      labels: 3-d labels whose height/width need not match those of logits.

    Returns:
      The labels, padded or clipped and reshaped to
      [labels batch, logits height, logits width].
    """
    target_dims = shapes.tensor_shape(logits)
    source_dims = shapes.tensor_shape(labels)
    batch = source_dims[0]
    source_width = source_dims[2]
    target_height = target_dims[1]
    target_width = target_dims[2]

    # Pass 1: collapse batch and height into one axis so that each row is a
    # single line of labels, then bring every row to the target width.
    # NOTE(review): assumes _PadLabels2d pads/clips the last dimension of a
    # 2-d tensor to the requested length -- confirm against its definition.
    rows = tf.reshape(labels, [-1, source_width])
    rows = _PadLabels2d(target_width, rows)

    # Pass 2: flatten each batch element to one row of height * width values
    # and bring it to the target area; with the width already fixed, this
    # amounts to adjusting the height.
    flat = tf.reshape(rows, [batch, -1])
    flat = _PadLabels2d(target_height * target_width, flat)

    return tf.reshape(flat, [batch, target_height, target_width])