def parser(record):
  """Decodes one serialized Example into a `(features, label)` pair.

  Args:
    record: A single serialized Example proto tensor.

  Returns:
    A 2-tuple `(features, label)`:
      * `features`: dict mapping 'patch' to the augmented patch tensor and
        WEIGHT_COLUMN_NAME to the per-example weight tensor.
      * `label`: the int64 scalar parsed from the 'label' feature.
  """
  include_label_weight = FLAGS.use_included_label_weight

  # Key order matches the original spec; 'label_weight' is only requested
  # from the Example when the flag asks for it.
  spec = {
      'patch': tf.FixedLenFeature((patch_height, patch_width), tf.float32),
      'label': tf.FixedLenFeature((), tf.int64),
  }
  if include_label_weight:
    spec['label_weight'] = tf.FixedLenFeature((), tf.float32)
  parsed = tf.parse_single_example(record, spec)

  label = parsed['label']
  example_weight = label_weights.weights_from_labels(label)
  if include_label_weight:
    # Cast both operands to float32 so the multiplication is well-typed.
    example_weight = (
        tf.to_float(example_weight) * tf.to_float(parsed['label_weight']))

  # Augmentation is built last, after the weight ops, as in the original.
  features = {
      'patch': _augment(parsed['patch']),
      WEIGHT_COLUMN_NAME: example_weight,
  }
  return features, label
def testWeightsFromLabels(self):
  """Checks that weights_from_labels maps each glyph label to its weight.

  FLAT does not appear in the weight spec; the test expects 1.0 for it.
  """
  glyph = musicscore_pb2.Glyph
  label_values = [
      glyph.NONE,
      glyph.NONE,
      glyph.NOTEHEAD_FILLED,
      glyph.SHARP,
      glyph.FLAT,
      glyph.NATURAL,
  ]
  label_tensor = tf.constant(label_values)
  weight_spec = 'NONE=0.1,NATURAL=2.0,SHARP=0.5,NOTEHEAD_FILLED=0.8'
  result = label_weights.weights_from_labels(label_tensor, weight_spec)

  expected = [0.1, 0.1, 0.8, 0.5, 1.0, 2.0]
  with self.test_session():
    self.assertAllEqual(expected, result.eval())