def test_bbox_batches_for_number_0(self):
    """Run the shared bbox-batch check for the 'number_0' target bbox.

    The expected first bbox is [87, 9, 24, 50] expressed as a percentage
    of a 190 x 75 image (x and width over 190, y and height over 75).
    """
    # Return value is unused; presumably called for its side effect of
    # preparing the test metadata -- TODO confirm.
    test_helper.get_test_metadata()

    # Original annotation: separated bboxes of 25.png(601):
    # [[60, 11, 24, 50], [87, 9, 24, 50], [113, 7, 21, 50]],
    # size of 1.png: 190 x 75
    # NOTE(review): the annotation names both 25.png and 1.png -- confirm
    # which image the 190 x 75 size belongs to.
    left, top, box_width, box_height = 87, 9, 24, 50
    image_width, image_height = 190, 75
    expected_first_bbox = [
        left * 100 / image_width,
        top * 100 / image_height,
        box_width * 100 / image_width,
        box_height * 100 / image_height,
    ]

    self._test_bbox_batches(
        'number_0',
        expected_first_bbox,
        test_helper.test_data_file_number_0)
def test_batches(self):
    """Check yolo.batches and yolo.prepare_for_loss shapes and values.

    Builds the input pipeline with batch size 2 and 416 x 416 images,
    asserts the static shape of the image batch and of every loss-feed
    entry, then pulls two batches from the queue:

    * first batch: example 0 and example 1 must match the reference
      loss feeds computed by ``self._calculate_loss_feed_batches(0)`` and
      ``(1)``;
    * second batch: example 0 must match ``(2)``.
    """
    data_file_path = test_helper.get_test_metadata()
    batch_size, size, B, H, W, C = 2, (416, 416), 5, 13, 13, 10
    first_loss_batch = self._calculate_loss_feed_batches(0)
    second_loss_batch = self._calculate_loss_feed_batches(1)
    third_loss_batch = self._calculate_loss_feed_batches(2)
    all_keys = [
        'probs', 'confs', 'coord', 'proid', 'areas', 'upleft', 'botright'
    ]

    with self.test_session() as sess:
        (data_batches, origin_image_shape_batch, image_shape_batch,
         label_batch, label_bboxes_batch) = yolo.batches(
             data_file_path, 5, batch_size, size,
             num_preprocess_threads=1, channels=3, is_training=False)
        loss_feed_batches = yolo.prepare_for_loss(
            B, batch_size, label_bboxes_batch, image_shape_batch,
            label_batch)

        # Static shape checks (no session run needed).
        self.assertEqual(data_batches.get_shape(), (2, 416, 416, 3))
        self.assertEqual(loss_feed_batches['probs'].get_shape(),
                         (2, H * W, B, C))
        self.assertEqual(loss_feed_batches['confs'].get_shape(),
                         (2, H * W, B))
        self.assertEqual(loss_feed_batches['coord'].get_shape(),
                         (2, H * W, B, 4))
        self.assertEqual(loss_feed_batches['proid'].get_shape(),
                         (2, H * W, B, C))
        self.assertEqual(loss_feed_batches['areas'].get_shape().as_list(),
                         [2, H * W, B])
        self.assertEqual(loss_feed_batches['upleft'].get_shape(),
                         (2, H * W, B, 2))
        self.assertEqual(loss_feed_batches['botright'].get_shape(),
                         (2, H * W, B, 2))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            sess.run(tf.local_variables_initializer())

            # First dequeued batch holds the first two reference examples.
            _, lfb = sess.run([data_batches, loss_feed_batches])
            for k in all_keys:
                self.assertAllClose(lfb[k][0], first_loss_batch[k])
                self.assertAllClose(lfb[k][1], second_loss_batch[k])

            # Second dequeued batch starts with the third reference example.
            _, lfb = sess.run([data_batches, loss_feed_batches])
            for k in all_keys:
                self.assertAllClose(lfb[k][0], third_loss_batch[k])
        finally:
            # Always stop and join the queue-runner threads, even when an
            # assertion above fails, so they do not outlive the test.
            coord.request_stop()
            coord.join(threads)
            sess.close()
def test_train_length_model(self):
    """Build CNNLengthTrainModel and run two training steps.

    The test has no assertions; it passes when the model builds and
    ``tf.contrib.slim.learning.train`` completes ``number_of_steps=2``
    without raising.
    """
    config = CNNNSRModelConfig(
        data_file_path=test_helper.get_test_metadata(),
        batch_size=2)

    with self.test_session():
        model = CNNLengthTrainModel(config)
        model.build()

        loss = model.total_loss
        step = model.global_step
        momentum_optimizer = tf.train.MomentumOptimizer(0.5, momentum=0.5)
        train_op = tf.contrib.layers.optimize_loss(
            loss=loss,
            global_step=step,
            learning_rate=0.1,
            optimizer=momentum_optimizer)

        # Second positional argument is passed as None, as in the original.
        tf.contrib.slim.learning.train(train_op, None, number_of_steps=2)
def test_bbox_batches(self):
    """Check inputs.bbox_batches output shapes and the first bbox value.

    With batch size 2 and 28 x 28 images, the image batch must have
    static shape (2, 28, 28, 3) and the bbox batch (2, 4). The first bbox
    of the first dequeued batch must equal the annotated bbox divided by
    the image width/height.
    """
    batch_size, size = 2, (28, 28)
    with self.test_session() as sess:
        data_file_path = test_helper.get_test_metadata()
        data_batches, bbox_batches = inputs.bbox_batches(
            data_file_path, batch_size, size,
            num_preprocess_threads=1, channels=3)

        self.assertEqual(data_batches.get_shape(), (2, 28, 28, 3))
        self.assertEqual(bbox_batches.get_shape(), (2, 4))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # bbox of 1.png: 246, 77, 173, 223, size of 1.png: 741 x 350
            _, bb = sess.run([data_batches, bbox_batches])
            self.assertAllClose(
                bb[0], [246 / 741, 77 / 350, 173 / 741, 223 / 350])
        finally:
            # Stop and join the queue-runner threads even if the run or the
            # assertion above fails, so they do not outlive the test.
            coord.request_stop()
            coord.join(threads)
            sess.close()
def test_batches(self):
    """Check inputs.batches output shapes and the first batch's labels.

    With batch size 2 and 28 x 28 images the pipeline must yield an image
    batch of shape (2, 28, 28, 3), length labels of shape
    (2, max_number_length) and number labels of shape
    (2, max_number_length, 11). Five batches are pulled from the queue;
    only the first one is compared against the expected labels.
    """
    max_number_length = 5

    def numbers_labels(numbers):
        """Return the expected label rows for one example.

        One ``one_hot(digit + 1, 11)`` row per entry of ``numbers``,
        followed by ``one_hot(11, 11)`` rows up to ``max_number_length``.
        """
        digit_rows = one_hot(np.array(numbers) + 1, 11)
        padding_rows = np.array([
            one_hot(11, 11)
            for _ in range(max_number_length - len(numbers))
        ])
        return np.concatenate([digit_rows, padding_rows])

    expected_length_labels = one_hot(np.array([2, 2]), 5)
    expected_numbers_labels = numbers_labels([1, 9])
    expected_numbers_labels_1 = numbers_labels([2, 3])

    data_file_path = test_helper.get_test_metadata()
    batch_size, size = 2, (28, 28)
    with self.test_session() as sess:
        data_batches, length_label_batches, numbers_label_batches = \
            inputs.batches(data_file_path, max_number_length, batch_size,
                           size, num_preprocess_threads=1, channels=3)

        self.assertEqual(data_batches.get_shape(), (2, 28, 28, 3))
        self.assertEqual(length_label_batches.get_shape(),
                         (2, max_number_length))
        self.assertEqual(numbers_label_batches.get_shape(),
                         (2, max_number_length, 11))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            fetches = [
                data_batches, length_label_batches, numbers_label_batches
            ]
            # Pull five batches; only the first is checked below.
            batches = [sess.run(fetches) for _ in range(5)]

            _, llb, nlb = batches[0]
            self.assertAllEqual(llb, expected_length_labels)
            self.assertNDArrayNear(nlb[0], expected_numbers_labels, 1e-5)
            self.assertNDArrayNear(nlb[1], expected_numbers_labels_1, 1e-5)
        finally:
            # Stop and join the queue-runner threads even if an assertion
            # above fails, so they do not outlive the test.
            coord.request_stop()
            coord.join(threads)
            sess.close()
def test_evaluation_correct_count(self):
    """Build CNNNSREvalModel and print its correct count for ten batches.

    The test has no assertions; it passes when the model builds and
    ``model.correct_count(sess)`` can be evaluated ten times without
    raising.
    """
    data_file_path = test_helper.get_test_metadata()
    config = CNNNSRModelConfig(data_file_path=data_file_path, batch_size=2)

    with self.test_session() as sess:
        model = CNNNSREvalModel(config)
        model.build()

        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        coord = tf.train.Coordinator()
        # No explicit session is passed here, matching the original call.
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for i in range(10):
                print('batch %s correct count: %s' %
                      (i, model.correct_count(sess)))
        finally:
            # Stop and join the queue-runner threads even if evaluation
            # raises, so they do not outlive the test.
            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=10)
def _test_bbox_batches(self,
                       target_bbox,
                       first_expected_bbox,
                       data_file_path=None):
    """Shared check for inputs.bbox_batches with a given ``target_bbox``.

    Args:
        target_bbox: value forwarded as ``target_bbox=`` to
            ``inputs.bbox_batches``.
        first_expected_bbox: four values the first bbox of the first
            dequeued batch must be close to.
        data_file_path: data file to read; when None, the path returned by
            ``test_helper.get_test_metadata()`` is used.
    """
    if data_file_path is None:
        data_file_path = test_helper.get_test_metadata()

    batch_size, size = 2, (28, 28)
    with self.test_session() as sess:
        data_batches, bbox_batches = inputs.bbox_batches(
            data_file_path, batch_size, size, 5,
            num_preprocess_threads=1, channels=3, target_bbox=target_bbox)

        self.assertEqual(data_batches.get_shape(), (2, 28, 28, 3))
        self.assertEqual(bbox_batches.get_shape(), (2, 4))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            _, bb = sess.run([data_batches, bbox_batches])
            self.assertAllClose(bb[0], first_expected_bbox)
        finally:
            # Stop and join the queue-runner threads even if the run or the
            # assertion above fails, so they do not outlive the test.
            coord.request_stop()
            coord.join(threads)
            sess.close()
def create_test_config(self):
    """Return a YOLOModelConfig for the test metadata file.

    The config uses ``net_type='yolo'`` and ``batch_size=2``.
    """
    return YOLOModelConfig(
        data_file_path=test_helper.get_test_metadata(),
        net_type='yolo',
        batch_size=2)