Ejemplo n.º 1
0
    def __init__(self) -> None:
        """Build the CenterNet inference graph and restore trained weights.

        Copies test-time settings from ``cfg``, creates a single-image input
        placeholder, opens a TF1 session, restores a checkpoint and creates
        ``self.det`` from the model's prediction tensors.

        NOTE(review): ``cfg``, ``utils``, ``tf``, ``gpu_options``,
        ``CenterNet`` and ``decode`` are not defined in this method; they are
        presumably module-level names — confirm.
        """
        # NOTE(review): only the input height is kept here, while the
        # placeholder below uses both cfg.input_image_h and cfg.input_image_w.
        self.input_size = cfg.input_image_h
        self.classes = utils.read_class_names(cfg.classes_file)
        self.num_classes = len(self.classes)
        self.score_threshold = cfg.score_threshold
        self.iou_threshold = cfg.nms_thresh
        self.moving_ave_decay = cfg.moving_ave_decay
        self.annotation_path = cfg.test_data_file
        # NOTE(review): weight_file is stored but not used by the restore
        # below, which loads a hard-coded checkpoint path — confirm intent.
        self.weight_file = cfg.weight_file
        self.write_image = cfg.write_image
        self.write_image_path = cfg.write_image_path
        self.show_label = cfg.show_label

        # Batch dimension is fixed to 1; height/width come from cfg.
        self.input_data = tf.placeholder(
            shape=[1, cfg.input_image_h, cfg.input_image_w, 3],
            dtype=tf.float32)

        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

        # NOTE(review): this initializer runs before CenterNet is built, so it
        # presumably does not cover the model's variables; those appear to be
        # filled by saver.restore below — confirm.
        self.sess.run(tf.global_variables_initializer())

        # The second argument is presumably a training-mode flag (False for
        # inference) — confirm against the CenterNet signature.
        model = CenterNet(self.input_data, False)
        saver = tf.train.Saver()
        saver.restore(
            self.sess,
            './checkpoint/2021_02_24-centernet_test_loss=0.5797.ckpt-80')

        # Prediction tensors exposed by the model.
        self.hm = model.pred_hm
        self.wh = model.pred_wh
        self.reg = model.pred_reg

        # Decoded detections; K=cfg.max_objs presumably caps how many are
        # returned — confirm in decode().
        self.det = decode(self.hm, self.wh, self.reg, K=cfg.max_objs)
Ejemplo n.º 2
0
    def __init__(self):
        """Set up the evaluation/export context for the detector.

        Resets the Keras session and sets the learning phase to 0, collects
        evaluation, schedule and quantization settings from ``cfg``, loads the
        YOLO-format 'test' split and builds the model from
        ``cfg.TEST.WEIGHT_FILE``.
        """
        # Clean graph and learning phase 0 must be in place before any model
        # is constructed below.
        K.clear_session()
        K.set_learning_phase(0)

        self.classes = utils.read_class_names(cfg.RDT_Reader.CLASSES)
        self.num_classes = len(self.classes)
        # time.strftime formats localtime() when no struct_time is given.
        self.time = time.strftime('%Y-%m-%d-%H-%M-%S')

        train_cfg = cfg.TRAIN
        test_cfg = cfg.TEST

        self.moving_ave_decay = train_cfg.MOVING_AVE_DECAY
        self.train_logdir = "./dataset/log/train"

        # All evaluation artifacts are located under EVAL_MODEL_PATH.
        eval_dir = test_cfg.EVAL_MODEL_PATH
        self.checkpoint_name = eval_dir + "/eval.ckpt"
        self.model_path = eval_dir + "/model/"
        self.eval_tflite = eval_dir + "/OD_180x320_newarch_resnet_data_qnt.lite"
        self.initial_weight = test_cfg.WEIGHT_FILE
        self.output_node_names = ["define_loss/reshapedOutput"]

        # Learning-rate and epoch schedule (the key is spelled FISRT in cfg).
        self.learn_rate_init = train_cfg.LEARN_RATE_INIT
        self.learn_rate_end = train_cfg.LEARN_RATE_END
        self.first_stage_epochs = train_cfg.FISRT_STAGE_EPOCHS
        self.second_stage_epochs = train_cfg.SECOND_STAGE_EPOCHS
        self.warmup_periods = train_cfg.WARMUP_EPOCHS

        self.testset = data_loader.loadDataObjSSDFromYoloFormat('test')

        self.quant_delay = train_cfg.QUANT_DELAY
        self.quantizedPb = test_cfg.QUANTIZED_WEIGHT_FILE
        self.resize_dim = tuple(test_cfg.INPUT_SIZE)
        self.number_blocks = train_cfg.NUMBER_BLOCKS

        self.model = ObjectDetection(True, self.initial_weight).model
        self.anch = train_cfg.ANCHOR_ASPECTRATIO
    def __init__(self, loadWeights, weightsFile):
        """Record detector settings from ``cfg`` and build the network.

        Args:
            loadWeights: flag stored as ``self.loadWeights``; presumably read
                by the network builder — confirm.
            weightsFile: path stored as ``self.weightsFile``.
        """
        self.classes = utils.read_class_names(cfg.RDT_Reader.CLASSES)

        train_cfg = cfg.TRAIN
        self.num_class = train_cfg.NUMBER_CLASSES
        self.loadWeights, self.weightsFile = loadWeights, weightsFile
        self.resize_dim = tuple(cfg.TEST.INPUT_SIZE)
        # Count of aspect ratios in the first ANCHOR_ASPECTRATIO entry
        # (presumably every entry has the same length — confirm).
        self.number_anchors = len(train_cfg.ANCHOR_ASPECTRATIO[0])

        # Built last so every attribute above already exists on self.
        self.model = self.__build_network__imgClass()
Ejemplo n.º 4
0
 def __init__(self):
     """Prepare the training context for the detector.

     Loads class names, learning-rate bounds and weight paths from ``cfg``,
     records a construction timestamp, loads the YOLO-format 'train' and
     'test' splits and builds the model from ``cfg.TRAIN.INITIAL_WEIGHT``.
     """
     self.classes = utils.read_class_names(cfg.RDT_Reader.CLASSES)
     self.num_classes = len(self.classes)

     train_cfg = cfg.TRAIN
     self.learn_rate_init = train_cfg.LEARN_RATE_INIT
     self.learn_rate_end = train_cfg.LEARN_RATE_END
     self.initial_weight = train_cfg.INITIAL_WEIGHT
     # Taken from the TEST section of cfg, not TRAIN.
     self.saveModelpath = cfg.TEST.WEIGHT_FILE

     # time.strftime formats localtime() when no struct_time is given.
     self.time = time.strftime('%Y-%m-%d-%H-%M-%S')
     self.train_logdir = "./dataset/log/"

     self.trainset = data_loader.loadDataObjSSDFromYoloFormat('train')
     self.testset = data_loader.loadDataObjSSDFromYoloFormat('test')

     # NOTE(review): False is presumably the loadWeights flag — confirm.
     self.model = ObjectDetection(False, self.initial_weight).model
     self.number_blocks = train_cfg.NUMBER_BLOCKS
Ejemplo n.º 5
0
import os

# Restrict TensorFlow to GPU 0; set before the session below is created.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

ckpt_path = './checkpoint/'
sess = tf.Session()

# Batch, height and width are all left unspecified in the input placeholder.
inputs = tf.placeholder(shape=[None, None, None, 3], dtype=tf.float32)
model = CenterNet(inputs, False, net_name=cfg.net)

# Restore the newest checkpoint found under ckpt_path.
saver = tf.train.Saver()
latest_ckpt = tf.train.latest_checkpoint(ckpt_path)
saver.restore(sess, latest_ckpt)

# Model prediction tensors and the op that decodes them (K=cfg.max_objs
# presumably caps the number of detections — confirm in decode()).
hm, wh, reg = model.pred_hm, model.pred_wh, model.pred_reg
det = decode(hm, wh, reg, K=cfg.max_objs)

class_names = read_class_names(cfg.classes_file)
img_dir = '/home/pcl/tf_work/TF_CenterNet/VOC/test/VOCdevkit/VOC2007/JPEGImages'
img_names = os.listdir(img_dir)
for img_name in img_names:
    img_path = '/home/pcl/tf_work/TF_CenterNet/VOC/test/VOCdevkit/VOC2007/JPEGImages/' + img_name
    print(img_path)
    original_image = cv2.imread(img_path)
    original_image_size = original_image.shape[:2]
    image_data = image_preporcess(np.copy(original_image),
                                  [cfg.input_image_h, cfg.input_image_w])
    image_data = image_data[np.newaxis, ...]

    t0 = time.time()
    detections = sess.run(det, feed_dict={inputs: image_data})
    detections = post_process(detections, original_image_size,
                              [cfg.input_image_h, cfg.input_image_w],
Ejemplo n.º 6
0
# -*- coding:utf-8 -*-
"""Detection configuration.

Default input/output paths, YOLO anchors, weights location, class names and
resize settings used for detection.
"""
# The docstring above must be the first statement to become the module
# docstring; placed after the import it was only a no-op string expression.
from utils.utils import parse_anchors, read_class_names, get_color_table

detect_object = 'img'  # default detection target
input_image = './data/test_img/test8.jpg'  # default input image path
input_video = './data/test_video/video_demo.mp4'  # default input video path
output_image = './data/test_img/result/result8.jpg'  # path for the saved result image
output_video = './data/test_video/result/result.mp4'  # path for the saved result video
anchor_path = './data/yolo_anchors.txt'  # anchor file path
anchors = parse_anchors(anchor_path)  # anchors parsed from anchor_path
weight_path = './data/weights_yolo/yolo_face'  # weights path

class_name_path = './data/face.names'  # class-name file path
classes = read_class_names(class_name_path)  # class names read from the file
num_class = len(classes)  # number of classes

new_size = [416, 416]  # image size after resizing
use_letterbox_resize = True  # whether to use letterbox resizing