def testBatchDimensionNotRequired(self):
    """Inputs whose batch dimension is unknown must still verify cleanly."""
    num_classes = 5
    # Probs must be a tensor, since we pass it directly to _verify_input.
    uniform_probs = tf.constant([1.0 / num_classes] * num_classes)

    # (data, labels) pairs that should run without any runtime shape errors.
    ok_pairs = [
        (np.zeros([batch, width]), [i % num_classes for i in range(batch)])
        for batch, width in ((2, 3), (4, 15), (10, 1))
    ]

    # Placeholders with completely undefined shapes.
    data_ph = tf.placeholder(tf.float32)
    target_ph = tf.placeholder(tf.int32)
    data_out, targets_out, _ = sampling_ops._verify_input(  # pylint: disable=protected-access
        [data_ph], target_ph, [uniform_probs])

    # Running the graph must not raise shape-related errors for any pair.
    for data, targets in ok_pairs:
        with self.test_session() as sess:
            sess.run([data_out, targets_out],
                     feed_dict={data_ph: data, target_ph: targets})
# Example #2
    def testRuntimeAssertionFailures(self):
        """Each malformed label/prob feed must raise InvalidArgumentError."""
        good_probs = [0.2] * 5
        good_labels = [1, 2, 3]
        data = [tf.zeros([3, 1])]

        bad_label_batches = [
            [0, -1, 1],  # classes must be nonnegative
            [5, 1, 1],  # classes must be less than number of classes
            [2, 3],  # data and label batch size must be the same
        ]
        bad_prob_vectors = [
            [0.1] * 5,  # probabilities must sum to one
            [-0.5, 0.5, 0.5, 0.4, 0.1],  # probabilities must be non-negative
        ]

        # Build the graph once; illegal values are injected through feeds.
        label_ph = tf.placeholder(tf.int32, shape=[None])
        probs_ph = tf.placeholder(tf.float32, shape=[5])  # shape must be defined
        data_out, label_out, prob_out = sampling_ops._verify_input(  # pylint: disable=protected-access
            data, label_ph, [probs_ph])

        for bad_labels in bad_label_batches:
            # Each illegal label vector should make the run fail.
            with self.test_session() as sess:
                with self.assertRaises(tf.errors.InvalidArgumentError):
                    sess.run([data_out, label_out],
                             feed_dict={label_ph: bad_labels,
                                        probs_ph: good_probs})

        for bad_probs in bad_prob_vectors:
            # Each illegal probability vector should make the run fail.
            with self.test_session() as sess:
                with self.assertRaises(tf.errors.InvalidArgumentError):
                    sess.run([prob_out],
                             feed_dict={label_ph: good_labels,
                                        probs_ph: bad_probs})
  def testBatchDimensionNotRequired(self):
    """Verifies inputs with fully undefined placeholder shapes run cleanly."""
    num_classes = 5
    # Probs must be a tensor, since we pass it directly to _verify_input.
    probs_tensor = constant_op.constant([1.0 / num_classes] * num_classes)

    # (values, labels) pairs that should not trigger runtime shape errors.
    ok_inputs = []
    for batch, width in ((2, 3), (4, 15), (10, 1)):
      ok_inputs.append((np.zeros([batch, width]),
                        [i % num_classes for i in range(batch)]))

    # Placeholders carrying no shape information at all.
    values_ph = array_ops.placeholder(dtypes.float32)
    targets_ph = array_ops.placeholder(dtypes.int32)
    values_out, targets_out, _ = sampling_ops._verify_input(  # pylint: disable=protected-access
        [values_ph], targets_ph, [probs_tensor])

    # No shape-related runtime errors are expected for any pair.
    for values, targets in ok_inputs:
      with self.test_session() as sess:
        sess.run([values_out, targets_out],
                 feed_dict={values_ph: values,
                            targets_ph: targets})
    def testRuntimeAssertionFailures(self):
        """Illegal labels or probabilities fed at run time must fail loudly."""
        ok_probs = [0.2] * 5
        ok_labels = [1, 2, 3]
        data = [tf.zeros([3, 1])]

        label_cases = [
            [0, -1, 1],  # classes must be nonnegative
            [5, 1, 1],  # classes must be less than number of classes
            [2, 3],  # data and label batch size must be the same
        ]
        prob_cases = [
            [0.1] * 5,  # probabilities must sum to one
            [-0.5, 0.5, 0.5, 0.4, 0.1],  # probabilities must be non-negative
        ]

        # One graph serves every case; bad values arrive via placeholders.
        label_ph = tf.placeholder(tf.int32, shape=[None])
        probs_ph = tf.placeholder(tf.float32,
                                  shape=[5])  # shape must be defined
        data_out, label_out, prob_out = sampling_ops._verify_input(  # pylint: disable=protected-access
            data, label_ph, [probs_ph])

        for case in label_cases:
            # Running with an illegal label vector should raise.
            with self.test_session() as sess:
                with self.assertRaises(tf.errors.InvalidArgumentError):
                    sess.run([data_out, label_out],
                             feed_dict={label_ph: case,
                                        probs_ph: ok_probs})

        for case in prob_cases:
            # Running with an illegal probability vector should raise.
            with self.test_session() as sess:
                with self.assertRaises(tf.errors.InvalidArgumentError):
                    sess.run([prob_out],
                             feed_dict={label_ph: ok_labels,
                                        probs_ph: case})
  def testMultiThreadedEstimateDataDistribution(self):
    """prob_estimate must stay well-behaved when fetched from many threads."""
    num_classes = 10

    # Build a graph producing a random (but seeded) label stream.
    random_seed.set_random_seed(1234)
    label = math_ops.cast(
        math_ops.round(random_ops.random_uniform([1]) * num_classes),
        dtypes_lib.int32)

    prob_estimate = sampling_ops._estimate_data_distribution(  # pylint: disable=protected-access
        label, num_classes)
    # Check that prob_estimate is well-behaved in a multithreaded context.
    _, _, [prob_estimate] = sampling_ops._verify_input(  # pylint: disable=protected-access
        [], label, [prob_estimate])

    # A queue with 25 parallel enqueue ops drives multi-threaded fetches
    # of `prob_estimate`.
    fifo = data_flow_ops.FIFOQueue(
        capacity=25,
        dtypes=[prob_estimate.dtype],
        shapes=[prob_estimate.get_shape()])
    enqueue = fifo.enqueue([prob_estimate])
    queue_runner_impl.add_queue_runner(
        queue_runner_impl.QueueRunner(fifo, [enqueue] * 25))
    dequeued = fifo.dequeue()

    # Spin up the runner threads and drain the queue.
    with self.cached_session() as sess:
      # Variables hold the running per-class totals; initialize them first.
      variables.global_variables_initializer().run()

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(coord=coord)

      for _ in range(25):
        sess.run([dequeued])

      coord.request_stop()
      coord.join(threads)
# Example #6
    def testMultiThreadedEstimateDataDistribution(self):
        """Checks prob_estimate is well-behaved under multithreaded fetches."""
        num_classes = 10

        # Set up graph.
        random_seed.set_random_seed(1234)
        label = math_ops.cast(
            math_ops.round(random_ops.random_uniform([1]) * num_classes),
            dtypes_lib.int32)

        prob_estimate = sampling_ops._estimate_data_distribution(  # pylint: disable=protected-access
            label, num_classes)
        # Check that prob_estimate is well-behaved in a multithreaded context.
        _, _, [prob_estimate] = sampling_ops._verify_input(  # pylint: disable=protected-access
            [], label, [prob_estimate])

        # Use queues to run multiple threads over the graph, each of which
        # fetches `prob_estimate`.
        queue = data_flow_ops.FIFOQueue(capacity=25,
                                        dtypes=[prob_estimate.dtype],
                                        shapes=[prob_estimate.get_shape()])
        enqueue_op = queue.enqueue([prob_estimate])
        queue_runner_impl.add_queue_runner(
            queue_runner_impl.QueueRunner(queue, [enqueue_op] * 25))
        out_tensor = queue.dequeue()

        # Run the multi-threaded session. Use `cached_session` rather than
        # the deprecated `test_session`; with no arguments they behave the
        # same (both return the cached graph session).
        with self.cached_session() as sess:
            # Need to initialize variables that keep running total of classes seen.
            variables.global_variables_initializer().run()

            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(coord=coord)

            for _ in range(25):
                sess.run([out_tensor])

            coord.request_stop()
            coord.join(threads)