def testCanBeCalledMultipleTimes(self):
    """Checks that `stratified_sample` can be called twice in one graph.

    Builds two samplers from the same inputs, merges every summary found in
    the graph's SUMMARIES collection, and evaluates both sampled batches plus
    the merged summary in a single `sess.run`.
    """
    size = 20
    values_in = [array_ops.zeros([2, 3, 4])]
    labels_in = array_ops.ones([], dtype=dtypes.int32)
    target_probs = np.array([0, 1, 0, 0, 0])

    def _sample():
        # Same positional/keyword arguments for both sampler instances.
        return sampling_ops.stratified_sample(
            values_in, labels_in, target_probs, size, init_probs=target_probs)

    fetches = _sample()
    fetches += _sample()
    merged = logging_ops.merge_summary(
        ops.get_collection(ops.GraphKeys.SUMMARIES))

    with self.cached_session() as sess:
        coord = coordinator.Coordinator()
        runners = queue_runner_impl.start_queue_runners(coord=coord)
        sess.run(fetches + (merged,))
        coord.request_stop()
        coord.join(runners)
 def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
     """Call `stratified_sample` with `init_probs` pinned to None.

     All other arguments are forwarded unchanged.
     """
     extra = dict(init_probs=None, enqueue_many=enqueue_many)
     return sampling_ops.stratified_sample(val, lbls, probs, batch, **extra)
    def testRejectionDataListInput(self):
        """Rejection sampling accepts a list of data tensors.

        Checks that the sampled data is returned as a Python list with one
        entry per input tensor, that the labels come back as a Tensor, and
        that evaluating them yields one value per fetched tensor.
        """
        batch_size = 20
        val_input_batch = [
            array_ops.zeros([2, 3, 4]),
            array_ops.ones([2, 4]),
            array_ops.ones(2) * 3
        ]
        lbl_input_batch = array_ops.ones([], dtype=dtypes.int32)
        probs = np.array([0, 1, 0, 0, 0])
        val_list, lbls = sampling_ops.stratified_sample(
            val_input_batch,
            lbl_input_batch,
            probs,
            batch_size,
            init_probs=[0, 1, 0, 0, 0])

        # Check output structure: one sampled tensor per input tensor.
        self.assertIsInstance(val_list, list)
        self.assertEqual(len(val_list), len(val_input_batch))
        self.assertIsInstance(lbls, ops.Tensor)

        with self.cached_session() as sess:
            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(coord=coord)
            try:
                out = sess.run(val_list + [lbls])
            finally:
                # Always stop and join the queue-runner threads, even if
                # sess.run raises.
                coord.request_stop()
                coord.join(threads)

        # One fetched value per data tensor, plus one for the labels.
        self.assertEqual(len(out), len(val_input_batch) + 1)
  def testRejectionDataListInput(self):
    """Rejection sampling accepts a list of data tensors.

    Checks that the sampled data is returned as a Python list with one entry
    per input tensor, that the labels come back as a Tensor, and that
    evaluating them yields one value per fetched tensor.
    """
    batch_size = 20
    val_input_batch = [
        array_ops.zeros([2, 3, 4]), array_ops.ones([2, 4]), array_ops.ones(2) *
        3
    ]
    lbl_input_batch = array_ops.ones([], dtype=dtypes.int32)
    probs = np.array([0, 1, 0, 0, 0])
    val_list, lbls = sampling_ops.stratified_sample(
        val_input_batch,
        lbl_input_batch,
        probs,
        batch_size,
        init_probs=[0, 1, 0, 0, 0])

    # Check output structure: one sampled tensor per input tensor.
    self.assertIsInstance(val_list, list)
    self.assertEqual(len(val_list), len(val_input_batch))
    self.assertIsInstance(lbls, ops.Tensor)

    with self.cached_session() as sess:
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(coord=coord)
      try:
        out = sess.run(val_list + [lbls])
      finally:
        # Always stop and join the queue-runner threads, even if sess.run
        # raises.
        coord.request_stop()
        coord.join(threads)

    # One fetched value per data tensor, plus one for the labels.
    self.assertEqual(len(out), len(val_input_batch) + 1)
 def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
   """Forward to `stratified_sample` with `init_probs` bound to `initial_p`.

   `initial_p` is not a parameter; it is read from an enclosing scope that is
   not visible here (NOTE(review): presumably the surrounding test — confirm).
   """
   sampled = sampling_ops.stratified_sample(
       val, lbls, probs, batch,
       enqueue_many=enqueue_many, init_probs=initial_p)
   return sampled
  def testCanBeCalledMultipleTimes(self):
    """Checks that `stratified_sample` can be called twice in one graph.

    Builds two samplers from the same inputs, merges every summary in the
    SUMMARIES collection, and evaluates both sampled batches plus the merged
    summary in a single `sess.run`.
    """
    batch_size = 20
    val_input_batch = [array_ops.zeros([2, 3, 4])]
    lbl_input_batch = array_ops.ones([], dtype=dtypes.int32)
    probs = np.array([0, 1, 0, 0, 0])
    batches = sampling_ops.stratified_sample(
        val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
    batches += sampling_ops.stratified_sample(
        val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
    summary_op = logging_ops.merge_summary(
        ops.get_collection(ops.GraphKeys.SUMMARIES))

    with self.cached_session() as sess:
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(coord=coord)
      try:
        sess.run(batches + (summary_op,))
      finally:
        # Always stop and join the queue-runner threads, even if sess.run
        # raises.
        coord.request_stop()
        coord.join(threads)
    def testRejectionBatchingBehavior(self):
        """Sample with `enqueue_many=True` from an 11-row input batch.

        The input labels are all 1s or all 3s depending on a random draw, and
        both the target and initial distributions put mass only on classes 1
        and 3. The test passes if one sampled batch evaluates without error.
        """
        out_size = 20
        rows = 11
        data_in = [array_ops.zeros([rows, 2, 3, 4])]

        def _constant_labels(cls):
            # A length-`rows` int32 vector whose entries all equal `cls`.
            return array_ops.ones([rows], dtype=dtypes.int32) * cls

        labels_in = control_flow_ops.cond(
            math_ops.greater(.5, random_ops.random_uniform([])),
            lambda: _constant_labels(1),
            lambda: _constant_labels(3))
        target = np.array([0, .2, 0, .8, 0])
        sampled_data, sampled_labels = sampling_ops.stratified_sample(
            data_in,
            labels_in,
            target,
            out_size,
            init_probs=[0, .3, 0, .7, 0],
            enqueue_many=True)

        with self.cached_session() as sess:
            coord = coordinator.Coordinator()
            runners = queue_runner_impl.start_queue_runners(coord=coord)
            sess.run([sampled_data, sampled_labels])
            coord.request_stop()
            coord.join(runners)
  def testRejectionBatchingBehavior(self):
    """Sample with `enqueue_many=True` from an 11-row input batch.

    The input labels are all 1s or all 3s depending on a random draw, and both
    the target and initial distributions put mass only on classes 1 and 3.
    The test passes if one sampled batch evaluates without error.
    """
    batch_size = 20
    input_batch_size = 11
    val_input_batch = [array_ops.zeros([input_batch_size, 2, 3, 4])]
    lbl_input_batch = control_flow_ops.cond(
        math_ops.greater(.5, random_ops.random_uniform([])),
        lambda: array_ops.ones([input_batch_size], dtype=dtypes.int32) * 1,
        lambda: array_ops.ones([input_batch_size], dtype=dtypes.int32) * 3)
    probs = np.array([0, .2, 0, .8, 0])
    data_batch, labels = sampling_ops.stratified_sample(
        val_input_batch,
        lbl_input_batch,
        probs,
        batch_size,
        init_probs=[0, .3, 0, .7, 0],
        enqueue_many=True)
    with self.cached_session() as sess:
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(coord=coord)
      try:
        sess.run([data_batch, labels])
      finally:
        # Always stop and join the queue-runner threads, even if sess.run
        # raises.
        coord.request_stop()
        coord.join(threads)
    def testGraphBuildAssertionFailures(self):
        """Invalid `stratified_sample` arguments are rejected at graph build.

        Each check wraps the call in a zero-argument callable, so every
        argument tensor is still created inside `assertRaises`.
        """
        data = [array_ops.zeros([1, 3]), array_ops.ones([1, 5])]
        lbl = constant_op.constant([1], shape=[1])  # has a batch dimension
        target = [.2] * 5
        initial = [.1, .3, .1, .3, .2]
        size = 16

        def expect_failure(error_type, build):
            # Run `build` and require that it raises `error_type`.
            with self.assertRaises(error_type):
                build()

        # With enqueue_many=True the label may only have a batch dimension:
        # both a scalar and a rank-2 label are rejected.
        expect_failure(
            ValueError, lambda: sampling_ops.stratified_sample(
                data, array_ops.zeros([]), target, size, initial,
                enqueue_many=True))
        expect_failure(
            ValueError, lambda: sampling_ops.stratified_sample(
                data, array_ops.zeros([1, 1]), target, size, initial,
                enqueue_many=True))

        # A one-hot label vector is rejected.
        expect_failure(
            ValueError, lambda: sampling_ops.stratified_sample(
                data, constant_op.constant([0, 1, 0, 0, 0]), target, size,
                initial))

        # The data argument must be a list, not a lone tensor.
        expect_failure(
            TypeError, lambda: sampling_ops.stratified_sample(
                array_ops.zeros([1, 3]), lbl, target, size, initial))

        # With enqueue_many=True a scalar label (no batch dimension) fails.
        expect_failure(
            ValueError, lambda: sampling_ops.stratified_sample(
                data, constant_op.constant(1), target, size, initial,
                enqueue_many=True))

        # Data and label batch dimensions must agree (2 vs. 1 here).
        expect_failure(
            ValueError, lambda: sampling_ops.stratified_sample(
                [array_ops.zeros([2, 1])], lbl, target, size, initial,
                enqueue_many=True))

        # A bare Python scalar is not accepted as the target probabilities.
        expect_failure(
            ValueError, lambda: sampling_ops.stratified_sample(
                data, lbl, 1, size, initial))

        # Target probabilities need a fully defined shape.
        expect_failure(
            ValueError, lambda: sampling_ops.stratified_sample(
                data, lbl,
                array_ops.placeholder(dtypes.float32, shape=[None]),
                size, initial))

        # Rejection sampling: target and initial lengths must match.
        expect_failure(
            ValueError, lambda: sampling_ops.stratified_sample(
                data, lbl, [.1] * 10, size, init_probs=[.2] * 5))

        # Rejection sampling: a class with zero initial probability must also
        # have zero target probability.
        expect_failure(
            ValueError, lambda: sampling_ops.stratified_sample(
                data, lbl, [.2, .4, .4], size, init_probs=[0, .5, .5]))
Example #10
0
  def testGraphBuildAssertionFailures(self):
    """Invalid `stratified_sample` arguments are rejected at graph build.

    Each case pairs the expected exception type with a zero-argument callable.
    The callable builds its argument tensors lazily, so construction happens
    inside `assertRaises` exactly as if each call were written inline. Cases
    run in the listed order.
    """
    tensors = [array_ops.zeros([1, 3]), array_ops.ones([1, 5])]
    label = constant_op.constant([1], shape=[1])  # has a batch dimension
    target_p = [.2] * 5
    init_p = [.1, .3, .1, .3, .2]
    n = 16

    bad_calls = [
        # enqueue_many=True: the label may only have a batch dimension.
        (ValueError, lambda: sampling_ops.stratified_sample(
            tensors, array_ops.zeros([]), target_p, n, init_p,
            enqueue_many=True)),
        (ValueError, lambda: sampling_ops.stratified_sample(
            tensors, array_ops.zeros([1, 1]), target_p, n, init_p,
            enqueue_many=True)),
        # A one-hot label vector is rejected.
        (ValueError, lambda: sampling_ops.stratified_sample(
            tensors, constant_op.constant([0, 1, 0, 0, 0]), target_p, n,
            init_p)),
        # The data argument must be a list, not a lone tensor.
        (TypeError, lambda: sampling_ops.stratified_sample(
            array_ops.zeros([1, 3]), label, target_p, n, init_p)),
        # enqueue_many=True: a scalar label (no batch dimension) fails.
        (ValueError, lambda: sampling_ops.stratified_sample(
            tensors, constant_op.constant(1), target_p, n, init_p,
            enqueue_many=True)),
        # Data and label batch dimensions must agree (2 vs. 1 here).
        (ValueError, lambda: sampling_ops.stratified_sample(
            [array_ops.zeros([2, 1])], label, target_p, n, init_p,
            enqueue_many=True)),
        # A bare Python scalar is not accepted as the target probabilities.
        (ValueError, lambda: sampling_ops.stratified_sample(
            tensors, label, 1, n, init_p)),
        # Target probabilities need a fully defined shape.
        (ValueError, lambda: sampling_ops.stratified_sample(
            tensors, label,
            array_ops.placeholder(dtypes.float32, shape=[None]), n, init_p)),
        # Rejection sampling: target and initial lengths must match.
        (ValueError, lambda: sampling_ops.stratified_sample(
            tensors, label, [.1] * 10, n, init_probs=[.2] * 5)),
        # Rejection sampling: a class with zero initial probability must also
        # have zero target probability.
        (ValueError, lambda: sampling_ops.stratified_sample(
            tensors, label, [.2, .4, .4], n, init_probs=[0, .5, .5])),
    ]

    for expected, build in bad_calls:
      with self.assertRaises(expected):
        build()