Ejemplo n.º 1
0
  def parser(record):
    """Decodes one serialized Example into a (features, label) pair.

    Args:
      record: A single serialized Example proto tensor.

    Returns:
      A 2-tuple `(features, label)`:
      * `features` is a dict mapping 'patch' to the parsed patch after it has
        been passed through `_augment`, and `WEIGHT_COLUMN_NAME` to the
        example weight.
      * `label` is the parsed int64 scalar label tensor.
    """
    use_included_weight = FLAGS.use_included_label_weight

    spec = dict(
        patch=tf.FixedLenFeature((patch_height, patch_width), tf.float32),
        label=tf.FixedLenFeature((), tf.int64),
    )
    if use_included_weight:
      # The record additionally carries a float32 scalar multiplier.
      spec['label_weight'] = tf.FixedLenFeature((), tf.float32)
    parsed = tf.parse_single_example(record, spec)

    label = parsed['label']
    # Weight looked up from the label; no explicit weights spec is passed here.
    example_weight = label_weights.weights_from_labels(label)
    if use_included_weight:
      # Cast both factors to float32 so the product has a single dtype.
      example_weight = (
          tf.to_float(example_weight) * tf.to_float(parsed['label_weight']))

    augmented_patch = _augment(parsed['patch'])
    features = {'patch': augmented_patch, WEIGHT_COLUMN_NAME: example_weight}
    return features, label
Ejemplo n.º 2
0
 def testWeightsFromLabels(self):
   """Each label maps to its weight from the spec; unlisted labels get 1.0."""
   glyph = musicscore_pb2.Glyph
   labels = tf.constant([
       glyph.NONE, glyph.NONE, glyph.NOTEHEAD_FILLED, glyph.SHARP,
       glyph.FLAT, glyph.NATURAL,
   ])
   # FLAT is not listed in the spec, so it is expected to get weight 1.0.
   weights = 'NONE=0.1,NATURAL=2.0,SHARP=0.5,NOTEHEAD_FILLED=0.8'
   weights_tensor = label_weights.weights_from_labels(labels, weights)
   with self.test_session():
     # Weights are floats parsed from a string; compare with a tolerance so
     # the test does not depend on the exact floating-point dtype returned.
     self.assertAllClose([0.1, 0.1, 0.8, 0.5, 1.0, 2.0], weights_tensor.eval())