def testBatchDimensionNotRequired(self):
  """Runs `_verify_input` on placeholders whose shapes are left undefined.

  Neither placeholder declares a shape, so the batch size is only known once
  values are fed. The test passes when every feed below runs without raising.
  """
  classes = 5
  # `_verify_input` receives the probabilities directly, so they have to be a
  # tensor.
  uniform_probs = tf.constant([1.0 / classes] * classes)

  # (batch size, second dimension) of each vals array that gets fed. Labels
  # cycle through the class ids, so none of these vals/labels pairs should
  # throw a runtime exception.
  shapes = [(2, 3), (4, 15), (10, 1)]
  feeds = [
      (np.zeros([batch, width]), [i % classes for i in range(batch)])
      for batch, width in shapes
  ]

  # Graph inputs: both placeholders have a completely undefined shape.
  data_ph = tf.placeholder(tf.float32)
  class_ph = tf.placeholder(tf.int32)
  checked_vals, checked_labels, _ = sampling_ops._verify_input(  # pylint: disable=protected-access
      [data_ph], class_ph, [uniform_probs])
  fetches = [checked_vals, checked_labels]

  # Evaluate the graph once per pair to make sure there are no shape-related
  # runtime errors.
  for data, class_ids in feeds:
    with self.test_session() as sess:
      sess.run(fetches, feed_dict={data_ph: data, class_ph: class_ids})
def testRuntimeAssertionFailures(self):
  """Checks that bad labels or probabilities fail when the graph is run.

  Every illegal feed below is expected to raise
  `tf.errors.InvalidArgumentError` from `sess.run`.
  """
  good_probs = [0.2] * 5
  good_labels = [1, 2, 3]

  bad_label_sets = [
      [0, -1, 1],  # classes must be nonnegative
      [5, 1, 1],  # classes must be less than number of classes
      [2, 3],  # data and label batch size must be the same
  ]
  bad_prob_sets = [
      [0.1] * 5,  # probabilities must sum to one
      [-0.5, 0.5, 0.5, 0.4, 0.1],  # probabilities must be non-negative
  ]

  # Build the graph; the illegal values are only supplied through feeds.
  data = [tf.zeros([3, 1])]
  label_ph = tf.placeholder(tf.int32, shape=[None])
  probs_ph = tf.placeholder(tf.float32, shape=[5])  # shape must be defined
  checked_vals, checked_labels, checked_probs = sampling_ops._verify_input(  # pylint: disable=protected-access
      data, label_ph, [probs_ph])

  def expect_invalid_argument(fetches, feed):
    # Run a session that should fail.
    with self.test_session() as sess:
      with self.assertRaises(tf.errors.InvalidArgumentError):
        sess.run(fetches, feed_dict=feed)

  for bad_labels in bad_label_sets:
    expect_invalid_argument(
        [checked_vals, checked_labels],
        {label_ph: bad_labels, probs_ph: good_probs})

  for bad_probs in bad_prob_sets:
    expect_invalid_argument(
        [checked_probs],
        {label_ph: good_labels, probs_ph: bad_probs})
def testBatchDimensionNotRequired(self):
  """Runs `_verify_input` on placeholders whose shapes are left undefined.

  Neither placeholder declares a shape, so the batch size is only known once
  values are fed. The test passes when every feed below runs without raising.
  """
  classes = 5
  # `_verify_input` receives the probabilities directly, so they have to be a
  # tensor.
  uniform_probs = constant_op.constant([1.0 / classes] * classes)

  # (batch size, second dimension) of each vals array that gets fed. Labels
  # cycle through the class ids, so none of these vals/labels pairs should
  # throw a runtime exception.
  shapes = [(2, 3), (4, 15), (10, 1)]
  feeds = [
      (np.zeros([batch, width]), [i % classes for i in range(batch)])
      for batch, width in shapes
  ]

  # Graph inputs: both placeholders have a completely undefined shape.
  data_ph = array_ops.placeholder(dtypes.float32)
  class_ph = array_ops.placeholder(dtypes.int32)
  checked_vals, checked_labels, _ = sampling_ops._verify_input(  # pylint: disable=protected-access
      [data_ph], class_ph, [uniform_probs])
  fetches = [checked_vals, checked_labels]

  # Evaluate the graph once per pair to make sure there are no shape-related
  # runtime errors.
  for data, class_ids in feeds:
    with self.test_session() as sess:
      sess.run(fetches, feed_dict={data_ph: data, class_ph: class_ids})
def testRuntimeAssertionFailures(self):
  """Checks that bad labels or probabilities fail when the graph is run.

  Every illegal feed below is expected to raise
  `tf.errors.InvalidArgumentError` from `sess.run`.
  """
  good_probs = [.2] * 5
  good_labels = [1, 2, 3]

  bad_label_sets = [
      [0, -1, 1],  # classes must be nonnegative
      [5, 1, 1],  # classes must be less than number of classes
      [2, 3],  # data and label batch size must be the same
  ]
  bad_prob_sets = [
      [.1] * 5,  # probabilities must sum to one
      [-.5, .5, .5, .4, .1],  # probabilities must be non-negative
  ]

  # Build the graph; the illegal values are only supplied through feeds.
  data = [tf.zeros([3, 1])]
  label_ph = tf.placeholder(tf.int32, shape=[None])
  probs_ph = tf.placeholder(tf.float32, shape=[5])  # shape must be defined
  checked_vals, checked_labels, checked_probs = sampling_ops._verify_input(  # pylint: disable=protected-access
      data, label_ph, [probs_ph])

  def expect_invalid_argument(fetches, feed):
    # Run a session that should fail.
    with self.test_session() as sess:
      with self.assertRaises(tf.errors.InvalidArgumentError):
        sess.run(fetches, feed_dict=feed)

  for bad_labels in bad_label_sets:
    expect_invalid_argument(
        [checked_vals, checked_labels],
        {label_ph: bad_labels, probs_ph: good_probs})

  for bad_probs in bad_prob_sets:
    expect_invalid_argument(
        [checked_probs],
        {label_ph: good_labels, probs_ph: bad_probs})
def testMultiThreadedEstimateDataDistribution(self):
  """Fetches the data-distribution estimate through queue-runner threads.

  The estimate is passed through `_verify_input` and enqueued by a queue
  runner; the test passes when all dequeues complete without raising.
  """
  num_classes = 10
  # The same count is used for the queue capacity, the number of enqueue ops
  # handed to the queue runner, and the number of dequeues performed.
  fanout = 25

  # Set up graph.
  random_seed.set_random_seed(1234)
  uniform_sample = random_ops.random_uniform([1])
  rounded_sample = math_ops.round(uniform_sample * num_classes)
  label = math_ops.cast(rounded_sample, dtypes_lib.int32)

  raw_estimate = sampling_ops._estimate_data_distribution(  # pylint: disable=protected-access
      label, num_classes)
  # Check that the estimate is well-behaved in a multithreaded context.
  _, _, [checked_estimate] = sampling_ops._verify_input(  # pylint: disable=protected-access
      [], label, [raw_estimate])

  # Use queues to run multiple threads over the graph, each of which fetches
  # the checked estimate.
  estimate_queue = data_flow_ops.FIFOQueue(
      capacity=fanout,
      dtypes=[checked_estimate.dtype],
      shapes=[checked_estimate.get_shape()])
  enqueue_op = estimate_queue.enqueue([checked_estimate])
  runner = queue_runner_impl.QueueRunner(estimate_queue, [enqueue_op] * fanout)
  queue_runner_impl.add_queue_runner(runner)
  dequeued = estimate_queue.dequeue()

  # Run the multi-threaded session.
  with self.cached_session() as sess:
    # Need to initialize variables that keep running total of classes seen.
    init_op = variables.global_variables_initializer()
    init_op.run()

    coord = coordinator.Coordinator()
    workers = queue_runner_impl.start_queue_runners(coord=coord)

    for _ in range(fanout):
      sess.run([dequeued])

    coord.request_stop()
    coord.join(workers)
def testMultiThreadedEstimateDataDistribution(self):
  """Fetches the data-distribution estimate through queue-runner threads.

  The estimate is passed through `_verify_input` and enqueued by a queue
  runner; the test passes when all dequeues complete without raising.
  """
  num_classes = 10
  # The same count is used for the queue capacity, the number of enqueue ops
  # handed to the queue runner, and the number of dequeues performed.
  fanout = 25

  # Set up graph.
  random_seed.set_random_seed(1234)
  uniform_sample = random_ops.random_uniform([1])
  rounded_sample = math_ops.round(uniform_sample * num_classes)
  label = math_ops.cast(rounded_sample, dtypes_lib.int32)

  raw_estimate = sampling_ops._estimate_data_distribution(  # pylint: disable=protected-access
      label, num_classes)
  # Check that the estimate is well-behaved in a multithreaded context.
  _, _, [checked_estimate] = sampling_ops._verify_input(  # pylint: disable=protected-access
      [], label, [raw_estimate])

  # Use queues to run multiple threads over the graph, each of which fetches
  # the checked estimate.
  estimate_queue = data_flow_ops.FIFOQueue(
      capacity=fanout,
      dtypes=[checked_estimate.dtype],
      shapes=[checked_estimate.get_shape()])
  enqueue_op = estimate_queue.enqueue([checked_estimate])
  runner = queue_runner_impl.QueueRunner(estimate_queue, [enqueue_op] * fanout)
  queue_runner_impl.add_queue_runner(runner)
  dequeued = estimate_queue.dequeue()

  # Run the multi-threaded session.
  with self.test_session() as sess:
    # Need to initialize variables that keep running total of classes seen.
    init_op = variables.global_variables_initializer()
    init_op.run()

    coord = coordinator.Coordinator()
    workers = queue_runner_impl.start_queue_runners(coord=coord)

    for _ in range(fanout):
      sess.run([dequeued])

    coord.request_stop()
    coord.join(workers)