def test_yolo_infer(self):
    """Run the exported YOLO graph on training image ``1.png`` and print it.

    Builds the YOLO test graph, wraps it in ``Inferrable`` with ``None`` in
    the initializer-node slot, resizes the image to
    ``simple_yolo.image_width`` x ``simple_yolo.image_height`` and prints the
    shape and contents of the inference result. Nothing is asserted.
    """
    from nsrec.nets import simple_yolo

    self._create_yolo_test_graph()
    # (initializer node, input node, output node) -- YOLO has no initializer.
    node_names = (
        None,
        YOLOToExportModel.INPUT_NODE_NAME,
        YOLOToExportModel.OUTPUT_NODE_NAME,
    )
    model = Inferrable(test_helper.test_graph_file, *node_names)

    image_path = os.path.join(test_helper.train_data_dir_path, '1.png')
    raw_image = inputs.read_img(image_path)
    target_size = [simple_yolo.image_width, simple_yolo.image_height]
    resized = inputs.normalize_img(raw_image, target_size)

    # infer() receives a batch holding a single image.
    prediction = model.infer(np.array([resized]))
    print(prediction.shape)
    print(prediction)
def test_infer(self):
    """Run the exported CNN-bbox graph on training image ``1.png`` and print it.

    Builds the test graph, wraps it in ``Inferrable`` using the initializer,
    input and output node names from ``CNNBBoxToExportModel``, resizes the
    image to ``lenet_v2.image_width`` x ``lenet_v2.image_height`` and prints
    the inference result. Nothing is asserted.
    """
    from nsrec.nets import lenet_v2

    self._create_test_graph()
    node_names = (
        CNNBBoxToExportModel.INITIALIZER_NODE_NAME,
        CNNBBoxToExportModel.INPUT_NODE_NAME,
        CNNBBoxToExportModel.OUTPUT_NODE_NAME,
    )
    model = Inferrable(test_helper.test_graph_file, *node_names)

    image_path = os.path.join(test_helper.train_data_dir_path, '1.png')
    raw_image = inputs.read_img(image_path)
    target_size = [lenet_v2.image_width, lenet_v2.image_height]
    resized = inputs.normalize_img(raw_image, target_size)

    # infer() receives a batch holding a single image.
    prediction = model.infer(np.array([resized]))
    print(prediction)
def test_img_data_generator(new_size, crop_bbox=False):
    """Yield preprocessed samples for training images ``1.png`` .. ``10.png``.

    Args:
        new_size: indexable whose first two items are passed to
            ``inputs.normalize_img`` as the target size.
        crop_bbox: when true, the image is re-read with its metadata bbox
            (``inputs.read_img(path, bbox)``, presumably a crop -- confirm)
            before resizing.

    Yields:
        ``(image, (width, height), bbox, label)`` where ``width``/``height``
        are taken from the image read without a bbox, and ``bbox``/``label``
        come from the module-level ``metadata`` entry for that filename.
    """
    for number in range(1, 11):
        name = '%s.png' % number
        position = metadata['filenames'].index(name)
        bbox = metadata['bboxes'][position]
        label = metadata['labels'][position]

        path = os.path.join(test_helper.train_data_dir_path, name)
        image = inputs.read_img(path)
        # Record the full image size before any cropping happens.
        height, width = image.shape[0], image.shape[1]
        if crop_bbox:
            image = inputs.read_img(path, bbox)
        image = inputs.normalize_img(image, [new_size[0], new_size[1]])
        yield image, (width, height), bbox, label
def infer(self, sess, data):
    """Run the network on a batch of images and decode its predictions.

    Args:
        sess: object whose ``run`` method evaluates ``self.net_out`` with
            ``self.inputs`` fed (presumably a TensorFlow session).
        data: indexable sequence of images. Each image is resized with
            ``inputs.normalize_img(image, self.config.size)`` before being
            fed. Its own ``shape[0]`` / ``shape[1]`` are used as height / width
            when scaling boxes back.

    Returns:
        A list with one ``(digits, bboxes)`` tuple per image:

        - ``digits`` concatenates ``str(entry['label'])`` for every decoded
          entry.
        - ``bboxes`` holds each ``entry['bbox']`` scaled by the image width
          (indices 0 and 2) and height (indices 1 and 3), truncated to ints.

        An empty ``data`` yields ``[]`` without running the network.
    """
    # Avoid feeding an empty batch to sess.run; there is nothing to decode.
    if len(data) == 0:
        return []

    def to_pixel_bbox(bbox, image_shape):
        # NOTE(review): assumes bbox values are relative to the image size
        # (e.g. in [0, 1]) as produced by extract_label -- confirm.
        height, width = image_shape[0], image_shape[1]
        return [int(bbox[0] * width), int(bbox[1] * height),
                int(bbox[2] * width), int(bbox[3] * height)]

    def join_digits(label):
        return ''.join(str(entry['label']) for entry in label)

    batch = [inputs.normalize_img(image, self.config.size) for image in data]
    net_out = sess.run(self.net_out, feed_dict={self.inputs: batch})
    labels = [
        extract_label(out, self.max_number_length, self.config.num_classes,
                      self.config.threshold)
        for out in net_out
    ]
    # Output rows line up with input images. The original (un-resized) shape
    # is used so the boxes map back onto the caller's image.
    return [
        (join_digits(label),
         [to_pixel_bbox(entry['bbox'], image.shape) for entry in label])
        for image, label in zip(data, labels)
    ]