Code example #1
0
def inference(args, config, network):
    """Run single-image inference with a trained detector and save a
    visualization of the predicted boxes to ``outputs/<name>.png``.

    Args:
        args: namespace providing ``model_dir``, ``resume_weights`` and
            ``img_path``.
        config: project config with model/eval dirs, eval image sizes and
            ``class_names``.
        network: zero-arg callable that builds the detector module.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    # model_path
    misc_utils.ensure_dir('outputs')
    saveDir = os.path.join('../model', args.model_dir, config.model_dir)
    model_file = os.path.join(saveDir,
                              'dump-{}.pth'.format(args.resume_weights))
    # explicit raise instead of assert: asserts are stripped under `python -O`
    if not os.path.exists(model_file):
        raise FileNotFoundError(model_file)
    # build network and restore weights on CPU
    net = network()
    net.eval()
    check_point = torch.load(model_file, map_location=torch.device('cpu'))
    net.load_state_dict(check_point['state_dict'])
    # get data
    image, resized_img, im_info = get_data(args.img_path,
                                           config.eval_image_short_size,
                                           config.eval_image_max_size)
    # no_grad: inference only — without it the output tensor may require
    # grad and `.numpy()` raises; it also avoids building the autograd graph
    with torch.no_grad():
        pred_boxes = net(resized_img, im_info).numpy()
    pred_boxes = post_process(pred_boxes, config, im_info[0, 2])
    # column 5 holds the predicted class index (presumably — confirmed by
    # the post_process contract elsewhere in the project)
    pred_tags = pred_boxes[:, 5].astype(np.int32).flatten()
    pred_tags_name = np.array(config.class_names)[pred_tags]
    # inplace draw
    image = visual_utils.draw_boxes(image,
                                    pred_boxes[:, :4],
                                    scores=pred_boxes[:, 4],
                                    tags=pred_tags_name,
                                    line_thick=1,
                                    line_color='white')
    # portable, robust stem extraction (handles Windows separators and
    # extensionless names, unlike manual split('/')/split('.') chains)
    name = os.path.splitext(os.path.basename(args.img_path))[0]
    fpath = 'outputs/{}.png'.format(name)
    cv2.imwrite(fpath, image)
Code example #2
0
File: test.py — Project: lu1kaifeng/CrowdDet
def eval_all(args, config, network):
    """Evaluate a checkpoint over the full dataset using one inference
    worker process per device, then dump detections and JI/AP/MR metrics.

    Args:
        args: namespace providing ``model_dir``, ``resume_weights`` and
            ``devices``.
        config: project config with model/eval dirs and ``eval_source``.
        network: detector class handed to the worker processes.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    # model_path
    saveDir = os.path.join('../model', args.model_dir, config.model_dir)
    evalDir = os.path.join('../model', args.model_dir, config.eval_dir)
    misc_utils.ensure_dir(evalDir)
    model_file = os.path.join(saveDir,
                              'dump-{}.pth'.format(args.resume_weights))
    # explicit raise instead of assert: asserts are stripped under `python -O`
    if not os.path.exists(model_file):
        raise FileNotFoundError(model_file)
    # get devices
    devices = misc_utils.device_parser(args.devices)
    # load data
    crowdhuman = CrowdHuman(config, if_train=False)
    # multiprocessing: shard the dataset into contiguous [start, end) slices,
    # one worker per device; workers push per-image results onto the queue
    num_devs = len(devices)
    len_dataset = len(crowdhuman)
    num_image = math.ceil(len_dataset / num_devs)
    result_queue = Queue(500)
    procs = []
    all_results = []
    for i in range(num_devs):
        start = i * num_image
        end = min(start + num_image, len_dataset)
        proc = Process(target=inference,
                       args=(config, network, model_file, devices[i],
                             crowdhuman, start, end, result_queue))
        proc.start()
        procs.append(proc)
    # drain exactly one result per image before joining the workers
    pbar = tqdm(total=len_dataset, ncols=50)
    for _ in range(len_dataset):
        all_results.append(result_queue.get())
        pbar.update(1)
    pbar.close()
    for p in procs:
        p.join()
    fpath = os.path.join(evalDir, 'dump-{}.json'.format(args.resume_weights))
    misc_utils.save_json_lines(all_results, fpath)
    # evaluation — `with` guarantees the report file is closed even if an
    # evaluator raises (the original leaked the handle on error)
    eval_path = os.path.join(evalDir,
                             'eval-{}.json'.format(args.resume_weights))
    with open(eval_path, 'w') as eval_fid:
        res_line, JI = compute_JI.evaluation_all(fpath, 'box')
        for line in res_line:
            eval_fid.write(line + '\n')
        AP, MR = compute_APMR.compute_APMR(fpath, config.eval_source, 'box')
        line = 'AP:{:.4f}, MR:{:.4f}, JI:{:.4f}.'.format(AP, MR, JI)
        print(line)
        eval_fid.write(line + '\n')
Code example #3
0
def eval_all(args):
    """Score an existing detection dump (``args.json_file``) against the
    CrowdHuman validation ground truth and append JI/AP/MR results to
    ``outputs/result_eval.md``.

    Args:
        args: namespace providing ``json_file``, the detections to score.

    Raises:
        FileNotFoundError: if the hard-coded ground-truth file is missing.
    """
    # ground truth file (fixed path on the evaluation machine)
    gt_path = '/data/CrowdHuman/annotation_val.odgt'
    # explicit raise instead of assert: asserts are stripped under `python -O`
    if not os.path.exists(gt_path):
        raise FileNotFoundError("Wrong ground truth path!")
    misc_utils.ensure_dir('outputs')
    # output file — `with` closes the handle even if an evaluator raises
    # (the original leaked the handle on error); 'a' accumulates runs
    eval_path = os.path.join('outputs', 'result_eval.md')
    with open(eval_path, 'a') as eval_fid:
        eval_fid.write((args.json_file + '\n'))
        # eval JI
        res_line, JI = compute_JI.evaluation_all(args.json_file, 'box')
        for line in res_line:
            eval_fid.write(line + '\n')
        # eval AP, MR
        AP, MR = compute_APMR.compute_APMR(args.json_file, gt_path, 'box')
        line = 'AP:{:.4f}, MR:{:.4f}, JI:{:.4f}.'.format(AP, MR, JI)
        print(line)
        eval_fid.write(line + '\n\n')
Code example #4
0
def multi_train(params, config, network):
    """Launch one training worker per available GPU via
    ``torch.multiprocessing.spawn``.

    Args:
        params: run-level namespace providing ``model_dir`` and
            ``resume_weights``.
        config: project config supplying epochs, batch size, LR schedule
            and output directories.
        network: detector class handed to each worker.
    """
    # guard clause: bail out early when no CUDA device is present
    if not torch.cuda.is_available():
        print('No GPU exists!')
        return
    gpu_count = torch.cuda.device_count()
    torch.set_default_tensor_type('torch.FloatTensor')
    # assemble the per-run training configuration
    run_cfg = Train_config()
    run_cfg.world_size = gpu_count
    run_cfg.total_epoch = config.max_epoch
    run_cfg.iter_per_epoch = config.nr_images_epoch // (
        gpu_count * config.train_batch_per_gpu)
    run_cfg.mini_batch_size = config.train_batch_per_gpu
    run_cfg.warm_iter = config.warm_iter
    # linear LR scaling with the global batch size
    run_cfg.learning_rate = (config.base_lr
                             * config.train_batch_per_gpu * gpu_count)
    run_cfg.momentum = config.momentum
    run_cfg.weight_decay = config.weight_decay
    run_cfg.lr_decay = config.lr_decay
    run_cfg.model_dir = os.path.join('../model/', params.model_dir,
                                     config.model_dir)
    run_tag = 'network.lr.{}.train.{}'.format(run_cfg.learning_rate,
                                              run_cfg.total_epoch)
    run_cfg.log_path = os.path.join('../model/', params.model_dir,
                                    config.output_dir, run_tag + '.log')
    run_cfg.resume_weights = params.resume_weights
    run_cfg.init_weights = config.init_weights
    run_cfg.log_dump_interval = config.log_dump_interval
    misc_utils.ensure_dir(run_cfg.model_dir)
    # echo the effective settings before spawning workers
    summary = 'Num of GPUs:{}, learning rate:{:.5f}, mini batch size:{}, \
            \ntrain_epoch:{}, iter_per_epoch:{}, decay_epoch:{}'.format(
        gpu_count, run_cfg.learning_rate, run_cfg.mini_batch_size,
        run_cfg.total_epoch, run_cfg.iter_per_epoch,
        run_cfg.lr_decay)
    print(summary)
    print("Init multi-processing training...")
    torch.multiprocessing.spawn(train_worker,
                                nprocs=gpu_count,
                                args=(run_cfg, network, config))
Code example #5
0
File: visulize_json.py — Project: zzq96/CrowdDet
def eval_all(args):
    """Visualize detections vs. ground truth for the first ``args.number``
    records of a dump-json file; annotated images are written to
    ``outputs/<ID>.png``.

    Args:
        args: namespace providing ``json_file``, ``number`` and
            ``visual_thresh``.
    """
    # json file
    assert os.path.exists(args.json_file), "Wrong json path!"
    misc_utils.ensure_dir('outputs')
    for record in misc_utils.load_json_lines(args.json_file)[:args.number]:
        det_boxes = misc_utils.load_bboxes(
                record, key_name='dtboxes', key_box='box', key_score='score', key_tag='tag')
        gt_boxes = misc_utils.load_bboxes(record, 'gtboxes', 'box')
        det_boxes = misc_utils.xywh_to_xyxy(det_boxes)
        gt_boxes = misc_utils.xywh_to_xyxy(gt_boxes)
        # drop low-confidence detections (score is the second-to-last column)
        det_boxes = det_boxes[det_boxes[:, -2] > args.visual_thresh]
        print("{}: dt:{}, gt:{}.".format(record['ID'],
                                         len(det_boxes), len(gt_boxes)))
        # NOTE(review): `img_root` is a module-level name defined elsewhere
        img = misc_utils.load_img(img_root + record['ID'] + '.png')
        visual_utils.draw_boxes(img, det_boxes, line_thick=1,
                                line_color='blue')
        visual_utils.draw_boxes(img, gt_boxes, line_thick=1,
                                line_color='white')
        cv2.imwrite('outputs/{}.png'.format(record['ID']), img)