def input_fn(input_patches):
  """Defines the estimator input function.

  Reads patches from TFRecords matching `input_patches`, parses each
  serialized Example, computes a per-example weight, augments the patch,
  and batches the result.

  Args:
    input_patches: The input patches TFRecords pattern.

  Returns:
    A tuple of:
    * A dict with keys 'patch' (the batched patch tensor) and the weight
      column name (the batched per-example weight tensor).
    * A tensor with the batched patch labels, as integers.
  """
  patch_height, patch_width = read_patch_dimensions()
  dataset = tf.data.TFRecordDataset(
      file_io.get_matching_files(input_patches))

  def parser(record):
    """Dataset parser function.

    Args:
      record: A single serialized Example proto tensor.

    Returns:
      A tuple of:
      * A dict of features ('patch' and the weight column name).
      * A label tensor (int64 scalar).
    """
    feature_types = {
        'patch':
            tf.FixedLenFeature((patch_height, patch_width), tf.float32),
        'label':
            tf.FixedLenFeature((), tf.int64),
    }
    if FLAGS.use_included_label_weight:
      feature_types['label_weight'] = tf.FixedLenFeature((), tf.float32)
    features = tf.parse_single_example(record, feature_types)
    label = features['label']
    weight = label_weights.weights_from_labels(label)
    if FLAGS.use_included_label_weight:
      # Both operands must be the same type (float32) before multiplying.
      weight = tf.to_float(weight) * tf.to_float(features['label_weight'])
    patch = _augment(features['patch'])
    return {'patch': patch, WEIGHT_COLUMN_NAME: weight}, label

  return batches.get_batched_tensor(dataset.map(parser))
def testBatching(self):
  """Checks that, with shuffling disabled, get_batched_tensor yields
  consecutive, in-order batches of the input dataset."""
  all_as = np.random.rand(1000, 2, 3)
  all_bs = np.random.randint(0, 100, [1000], np.int32)
  all_labels = np.random.randint(0, 5, [1000], np.int32)
  random_dataset = tf.data.Dataset.from_tensor_slices(({
      'a': tf.constant(all_as),
      'b': tf.constant(all_bs)
  }, tf.constant(all_labels)))
  # Disable shuffling so batches come back in dataset order.
  flags.FLAGS.dataset_shuffle_buffer_size = 0
  batch_tensors = batches.get_batched_tensor(random_dataset)

  with self.test_session() as sess:
    batch_size = flags.FLAGS.dataset_batch_size
    # Verify the first two batches are consecutive slices of the inputs.
    for i in range(2):
      batch = sess.run(batch_tensors)
      start = i * batch_size
      end = start + batch_size
      self.assertEqual(len(batch), 2)
      self.assertEqual(sorted(batch[0].keys()), ['a', 'b'])
      self.assertAllEqual(batch[0]['a'], all_as[start:end])
      self.assertAllEqual(batch[0]['b'], all_bs[start:end])
      self.assertAllEqual(batch[1], all_labels[start:end])