# Run inference
        predicted_bbox, predicted_labels, predicted_scores = run_inference(detr, np.expand_dims(model_input, axis=0), config)

        frame = frame.astype(np.float32)
        frame = frame / 255
        frame = numpy_bbox_to_image(frame, predicted_bbox, labels=predicted_labels, scores=predicted_scores, class_name=COCO_CLASS_NAME)

        cv2.imshow('frame', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # When everything done, release the capture
    cap.release()
    cv2.destroyAllWindows()

if __name__ == "__main__":

    # Enable on-demand GPU memory allocation on every visible GPU.
    # TensorFlow requires the memory-growth setting to be the same across
    # all GPUs, so apply it to each device instead of only handling the
    # single-GPU case (which left multi-GPU hosts without memory growth).
    physical_devices = tf.config.list_physical_devices('GPU')
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, True)

    # Start from the default configuration and override it with the
    # values parsed from the command line.
    config = TrainingConfig()
    args = training_config_parser().parse_args()
    config.update_from_args(args)

    # Load the DETR model, including its top layers, with the "detr" weights
    detr = get_detr_model(config, include_top=True, weights="detr")
    # Set after the model is built. Presumably 91 is the background class
    # id matching the loaded weights -- TODO confirm against the helpers
    # that read config.background_class.
    config.background_class = 91

    # Run webcam inference
    run_webcam_inference(detr)
Example #2
0
        t_class = tf.squeeze(t_class, axis=-1)
        # Compute map
        cal_map(p_bbox, p_labels, p_scores,  np.zeros((138, 138, len(p_bbox))), np.array(t_bbox), np.array(t_class), np.zeros((138, 138, len(t_bbox))), ap_data, iou_thresholds)
        print(f"Computing map.....{it}", end="\r")
        it += 1
        #if it > 10:
        #    break

    # Compute the mAp over all thresholds
    calc_map(ap_data, iou_thresholds, class_names, print_result=True)

if __name__ == "__main__":

    # Enable on-demand GPU memory allocation on every visible GPU.
    # TensorFlow requires the memory-growth setting to be the same across
    # all GPUs, so apply it to each device instead of only handling the
    # single-GPU case (which left multi-GPU hosts without memory growth).
    physical_devices = tf.config.list_physical_devices('GPU')
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, True)

    # Start from the default configuration and override it with the
    # values parsed from the command line.
    config = TrainingConfig()
    args = training_config_parser().parse_args()
    config.update_from_args(args)

    # Build the model to evaluate
    detr = build_model(config)

    # Load the dataset without augmentation; the call also returns the
    # class names. The positional 1 is presumably the batch size -- TODO
    # confirm against load_coco_dataset's signature.
    valid_dt, class_names = load_coco_dataset(config, 1, augmentation=None)

    # Run evaluation
    eval_model(detr, config, class_names, valid_dt)