def testCanBeCalledMultipleTimes(self):
  """Check that `stratified_sample` can be built twice on the same inputs.

  Both samplers and the op merging the graph's SUMMARIES collection are
  evaluated together in a single `sess.run` call.
  """
  batch_size = 20
  val_input_batch = [array_ops.zeros([2, 3, 4])]
  lbl_input_batch = array_ops.ones([], dtype=dtypes.int32)
  probs = np.array([0, 1, 0, 0, 0])
  batches = sampling_ops.stratified_sample(
      val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
  batches += sampling_ops.stratified_sample(
      val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
  summary_op = logging_ops.merge_summary(
      ops.get_collection(ops.GraphKeys.SUMMARIES))

  with self.cached_session() as sess:
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(coord=coord)
    try:
      sess.run(batches + (summary_op,))
    finally:
      # Stop and join the queue-runner threads even when the run fails, so
      # they never outlive the test.
      coord.request_stop()
      coord.join(threads)
def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
  """Forward to `sampling_ops.stratified_sample` with `init_probs=None`.

  All positional arguments and `enqueue_many` are passed through unchanged;
  the only fixed setting is that no initial class distribution is supplied.
  """
  fixed_options = {'init_probs': None, 'enqueue_many': enqueue_many}
  return sampling_ops.stratified_sample(val, lbls, probs, batch,
                                        **fixed_options)
def testRejectionDataListInput(self):
  """Check `stratified_sample` with a list of several data tensors.

  The sampler must return a list with one batched tensor per input tensor,
  plus a single label tensor, and all of them must be evaluable.
  """
  batch_size = 20
  val_input_batch = [
      array_ops.zeros([2, 3, 4]),
      array_ops.ones([2, 4]),
      array_ops.ones(2) * 3
  ]
  lbl_input_batch = array_ops.ones([], dtype=dtypes.int32)
  probs = np.array([0, 1, 0, 0, 0])
  val_list, lbls = sampling_ops.stratified_sample(
      val_input_batch,
      lbl_input_batch,
      probs,
      batch_size,
      init_probs=[0, 1, 0, 0, 0])

  # Check output structure: a list with one entry per data tensor, and a
  # single label tensor.
  self.assertIsInstance(val_list, list)
  self.assertEqual(len(val_list), len(val_input_batch))
  self.assertIsInstance(lbls, ops.Tensor)

  with self.cached_session() as sess:
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(coord=coord)
    try:
      out = sess.run(val_list + [lbls])
    finally:
      # Stop and join the queue-runner threads even when the run fails.
      coord.request_stop()
      coord.join(threads)

  # Check the number of evaluated outputs: every data tensor plus labels.
  self.assertEqual(len(out), len(val_input_batch) + 1)
def testRejectionDataListInput(self):
  """Check `stratified_sample` with a list of several data tensors.

  The sampler must return a list with one batched tensor per input tensor,
  plus a single label tensor, and all of them must be evaluable.
  """
  batch_size = 20
  val_input_batch = [
      array_ops.zeros([2, 3, 4]),
      array_ops.ones([2, 4]),
      array_ops.ones(2) * 3
  ]
  lbl_input_batch = array_ops.ones([], dtype=dtypes.int32)
  probs = np.array([0, 1, 0, 0, 0])
  val_list, lbls = sampling_ops.stratified_sample(
      val_input_batch,
      lbl_input_batch,
      probs,
      batch_size,
      init_probs=[0, 1, 0, 0, 0])

  # Check output structure: a list with one entry per data tensor, and a
  # single label tensor.
  self.assertIsInstance(val_list, list)
  self.assertEqual(len(val_list), len(val_input_batch))
  self.assertIsInstance(lbls, ops.Tensor)

  # `cached_session` replaces the deprecated `test_session`.
  with self.cached_session() as sess:
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(coord=coord)
    try:
      out = sess.run(val_list + [lbls])
    finally:
      # Stop and join the queue-runner threads even when the run fails.
      coord.request_stop()
      coord.join(threads)

  # Check the number of evaluated outputs: every data tensor plus labels.
  self.assertEqual(len(out), len(val_input_batch) + 1)
def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
  """Forward to `sampling_ops.stratified_sample` with `init_probs=initial_p`.

  `initial_p` is not a parameter; it is looked up in the enclosing scope at
  call time (NOTE(review): presumably defined by the surrounding test).
  """
  sampler_options = dict(init_probs=initial_p, enqueue_many=enqueue_many)
  return sampling_ops.stratified_sample(val, lbls, probs, batch,
                                        **sampler_options)
def testCanBeCalledMultipleTimes(self):
  """Check that `stratified_sample` can be built twice on the same inputs.

  Both samplers and the op merging the graph's SUMMARIES collection are
  evaluated together in a single `sess.run` call.
  """
  batch_size = 20
  val_input_batch = [array_ops.zeros([2, 3, 4])]
  lbl_input_batch = array_ops.ones([], dtype=dtypes.int32)
  probs = np.array([0, 1, 0, 0, 0])
  batches = sampling_ops.stratified_sample(
      val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
  batches += sampling_ops.stratified_sample(
      val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
  summary_op = logging_ops.merge_summary(
      ops.get_collection(ops.GraphKeys.SUMMARIES))

  # `cached_session` replaces the deprecated `test_session`.
  with self.cached_session() as sess:
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(coord=coord)
    try:
      sess.run(batches + (summary_op,))
    finally:
      # Stop and join the queue-runner threads even when the run fails, so
      # they never outlive the test.
      coord.request_stop()
      coord.join(threads)
def testRejectionBatchingBehavior(self):
  """Run rejection-based `stratified_sample` on batched input.

  The input is fed with `enqueue_many=True`. Each input batch is labelled
  entirely 1 or entirely 3, chosen at random when the graph is run.
  """
  batch_size = 20
  input_batch_size = 11
  val_input_batch = [array_ops.zeros([input_batch_size, 2, 3, 4])]
  lbl_input_batch = control_flow_ops.cond(
      math_ops.greater(.5, random_ops.random_uniform([])),
      lambda: array_ops.ones([input_batch_size], dtype=dtypes.int32) * 1,
      lambda: array_ops.ones([input_batch_size], dtype=dtypes.int32) * 3)
  probs = np.array([0, .2, 0, .8, 0])
  data_batch, labels = sampling_ops.stratified_sample(
      val_input_batch,
      lbl_input_batch,
      probs,
      batch_size,
      init_probs=[0, .3, 0, .7, 0],
      enqueue_many=True)

  with self.cached_session() as sess:
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(coord=coord)
    try:
      _, label_out = sess.run([data_batch, labels])
    finally:
      # Stop and join the queue-runner threads even when the run fails.
      coord.request_stop()
      coord.join(threads)

  # The only labels the input pipeline can produce are 1 and 3, so every
  # sampled label must be one of them.
  self.assertLessEqual(set(label_out.tolist()), {1, 3})
def testRejectionBatchingBehavior(self):
  """Run rejection-based `stratified_sample` on batched input.

  The input is fed with `enqueue_many=True`. Each input batch is labelled
  entirely 1 or entirely 3, chosen at random when the graph is run.
  """
  batch_size = 20
  input_batch_size = 11
  val_input_batch = [array_ops.zeros([input_batch_size, 2, 3, 4])]
  lbl_input_batch = control_flow_ops.cond(
      math_ops.greater(.5, random_ops.random_uniform([])),
      lambda: array_ops.ones([input_batch_size], dtype=dtypes.int32) * 1,
      lambda: array_ops.ones([input_batch_size], dtype=dtypes.int32) * 3)
  probs = np.array([0, .2, 0, .8, 0])
  data_batch, labels = sampling_ops.stratified_sample(
      val_input_batch,
      lbl_input_batch,
      probs,
      batch_size,
      init_probs=[0, .3, 0, .7, 0],
      enqueue_many=True)

  # `cached_session` replaces the deprecated `test_session`.
  with self.cached_session() as sess:
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(coord=coord)
    try:
      _, label_out = sess.run([data_batch, labels])
    finally:
      # Stop and join the queue-runner threads even when the run fails.
      coord.request_stop()
      coord.join(threads)

  # The only labels the input pipeline can produce are 1 and 3, so every
  # sampled label must be one of them.
  self.assertLessEqual(set(label_out.tolist()), {1, 3})
def testGraphBuildAssertionFailures(self):
  """Check that bad `stratified_sample` arguments fail at graph build time."""
  data = [array_ops.zeros([1, 3]), array_ops.ones([1, 5])]
  batched_label = constant_op.constant([1], shape=[1])  # has a batch dim
  target = [.2] * 5
  initial = [.1, .3, .1, .3, .2]
  batch_size = 16

  # With enqueue_many=True, a scalar label and a rank-2 label are both
  # rejected; the label may carry only a batch dimension.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, array_ops.zeros([]), target, batch_size,
        init_probs=initial, enqueue_many=True)
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, array_ops.zeros([1, 1]), target, batch_size,
        init_probs=initial, enqueue_many=True)

  # A one-hot label vector is rejected.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, constant_op.constant([0, 1, 0, 0, 0]), target, batch_size,
        init_probs=initial)

  # The data argument must be a list, not a lone tensor.
  with self.assertRaises(TypeError):
    sampling_ops.stratified_sample(
        array_ops.zeros([1, 3]), batched_label, target, batch_size,
        init_probs=initial)

  # With enqueue_many=True, a scalar label (no batch dimension) is rejected.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, constant_op.constant(1), target, batch_size,
        init_probs=initial, enqueue_many=True)

  # Data and label batch sizes (2 vs. 1 here) must agree.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        [array_ops.zeros([2, 1])], batched_label, target, batch_size,
        init_probs=initial, enqueue_many=True)

  # A plain integer is not an accepted probability container.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, batched_label, 1, batch_size, init_probs=initial)

  # Target probabilities need a fully defined shape.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, batched_label,
        array_ops.placeholder(dtypes.float32, shape=[None]),
        batch_size, init_probs=initial)

  # Rejection sampling: target and initial probabilities must have the same
  # length.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, batched_label, [.1] * 10, batch_size, init_probs=[.2] * 5)

  # Rejection sampling: a class with zero initial probability must also have
  # zero target probability.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, batched_label, [.2, .4, .4], batch_size,
        init_probs=[0, .5, .5])
def testGraphBuildAssertionFailures(self):
  """Check that bad `stratified_sample` arguments fail at graph build time."""
  data = [array_ops.zeros([1, 3]), array_ops.ones([1, 5])]
  batched_label = constant_op.constant([1], shape=[1])  # has a batch dim
  target = [.2] * 5
  initial = [.1, .3, .1, .3, .2]
  batch_size = 16

  # With enqueue_many=True, a scalar label and a rank-2 label are both
  # rejected; the label may carry only a batch dimension.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, array_ops.zeros([]), target, batch_size,
        init_probs=initial, enqueue_many=True)
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, array_ops.zeros([1, 1]), target, batch_size,
        init_probs=initial, enqueue_many=True)

  # A one-hot label vector is rejected.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, constant_op.constant([0, 1, 0, 0, 0]), target, batch_size,
        init_probs=initial)

  # The data argument must be a list, not a lone tensor.
  with self.assertRaises(TypeError):
    sampling_ops.stratified_sample(
        array_ops.zeros([1, 3]), batched_label, target, batch_size,
        init_probs=initial)

  # With enqueue_many=True, a scalar label (no batch dimension) is rejected.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, constant_op.constant(1), target, batch_size,
        init_probs=initial, enqueue_many=True)

  # Data and label batch sizes (2 vs. 1 here) must agree.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        [array_ops.zeros([2, 1])], batched_label, target, batch_size,
        init_probs=initial, enqueue_many=True)

  # A plain integer is not an accepted probability container.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, batched_label, 1, batch_size, init_probs=initial)

  # Target probabilities need a fully defined shape.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, batched_label,
        array_ops.placeholder(dtypes.float32, shape=[None]),
        batch_size, init_probs=initial)

  # Rejection sampling: target and initial probabilities must have the same
  # length.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, batched_label, [.1] * 10, batch_size, init_probs=[.2] * 5)

  # Rejection sampling: a class with zero initial probability must also have
  # zero target probability.
  with self.assertRaises(ValueError):
    sampling_ops.stratified_sample(
        data, batched_label, [.2, .4, .4], batch_size,
        init_probs=[0, .5, .5])