Code example #1 (score: 0)
File: glyph_patches.py — Project: tensorflow/moonlight
def input_fn(input_patches):
    """Builds the estimator input pipeline for glyph patches.

  Args:
    input_patches: File pattern matching the input patch TFRecords.

  Returns:
    Batched tensors from `batches.get_batched_tensor`, built from a dataset
    that yields, per example:
    * A dict with the 'patch' tensor and a per-example weight.
    * The patch label as an int64 scalar.
  """
    height, width = read_patch_dimensions()
    record_files = file_io.get_matching_files(input_patches)
    dataset = tf.data.TFRecordDataset(record_files)

    def _parse(serialized_example):
        """Parses one serialized Example proto.

    Args:
      serialized_example: A single serialized Example proto tensor.

    Returns:
      A tuple of:
      * A dict of features ('patch' and the weight column)
      * A label tensor (int64 scalar).
    """
        include_weight = FLAGS.use_included_label_weight
        schema = {
            'patch': tf.FixedLenFeature((height, width), tf.float32),
            'label': tf.FixedLenFeature((), tf.int64),
        }
        if include_weight:
            schema['label_weight'] = tf.FixedLenFeature((), tf.float32)
        example = tf.parse_single_example(serialized_example, schema)

        label = example['label']
        weight = label_weights.weights_from_labels(label)
        if include_weight:
            # Cast both factors to float32 so the multiply type-checks.
            included = tf.to_float(example['label_weight'])
            weight = tf.to_float(weight) * included
        features = {
            'patch': _augment(example['patch']),
            WEIGHT_COLUMN_NAME: weight,
        }
        return features, label

    return batches.get_batched_tensor(dataset.map(_parse))
Code example #2 (score: 0)
File: batches_test.py — Project: zrohdes/moonlight
    def testBatching(self):
        """Checks that batching yields the input, unshuffled, in order."""
        a_values = np.random.rand(1000, 2, 3)
        b_values = np.random.randint(0, 100, [1000], np.int32)
        label_values = np.random.randint(0, 5, [1000], np.int32)
        dataset = tf.data.Dataset.from_tensor_slices(
            ({
                'a': tf.constant(a_values),
                'b': tf.constant(b_values),
            }, tf.constant(label_values)))

        # Disable shuffling so batches come out in input order.
        flags.FLAGS.dataset_shuffle_buffer_size = 0
        batch_tensors = batches.get_batched_tensor(dataset)
        with self.test_session() as sess:
            batch_size = flags.FLAGS.dataset_batch_size
            # Draw two consecutive batches and compare each against the
            # corresponding slice of the input arrays.
            for i in range(2):
                batch = sess.run(batch_tensors)
                start = i * batch_size
                stop = start + batch_size

                self.assertEqual(len(batch), 2)
                features, labels = batch
                self.assertEqual(sorted(features.keys()), ['a', 'b'])
                self.assertAllEqual(features['a'], a_values[start:stop])
                self.assertAllEqual(features['b'], b_values[start:stop])
                self.assertAllEqual(labels, label_values[start:stop])