Code Example #1
def main(_):
  data_dir = FLAGS.data_dir
  label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

  logging.info('Reading from dataset.')
  image_dir = os.path.join(data_dir, 'images')
  annotations_dir = os.path.join(data_dir, 'annotations')
  examples_path = os.path.join(annotations_dir, 'trainval.txt')
  examples_list = dataset_util.read_examples_list(examples_path)

  # Test images are not included in the downloaded data set, so we shall perform
  # our own split.
  random.seed(42)
  random.shuffle(examples_list)
  num_examples = len(examples_list)
  num_train = int(0.7 * num_examples)
  train_examples = examples_list[:num_train]
  val_examples = examples_list[num_train:]
  logging.info('%d training and %d validation examples.',
               len(train_examples), len(val_examples))

  train_output_path = os.path.join(FLAGS.output_dir, 'train.record')
  val_output_path = os.path.join(FLAGS.output_dir, 'val.record')
  create_tf_record(train_output_path, label_map_dict, annotations_dir,
                   image_dir, train_examples)
  create_tf_record(val_output_path, label_map_dict, annotations_dir,
                   image_dir, val_examples)
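
Every example on this page ultimately parses a label map in the Object Detection API's pbtxt format. As a minimal, self-contained sketch (assuming the API is installed so that object_detection.utils is importable; the file contents, class names, and ids are purely illustrative), get_label_map_dict maps class names to integer ids:

import tempfile

from object_detection.utils import label_map_util

LABEL_MAP_PBTXT = """
item {
  id: 1
  name: 'cat'
}
item {
  id: 2
  name: 'dog'
}
"""

# write the illustrative label map to a temporary file and parse it
with tempfile.NamedTemporaryFile('w', suffix='.pbtxt', delete=False) as f:
    f.write(LABEL_MAP_PBTXT)
    label_map_path = f.name

label_map_dict = label_map_util.get_label_map_dict(label_map_path)
print(label_map_dict)  # expected: {'cat': 1, 'dog': 2}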
Code Example #2
def main(_):
  if FLAGS.set not in SETS:
    raise ValueError('set must be in : {}'.format(SETS))

  data_dir = FLAGS.data_dir
  label_map_path = os.path.join(data_dir, "labels.pbtxt")
  examples_path = os.path.join(data_dir, "files.txt")
  annotations_dir = os.path.join(data_dir, "annotations")

  label_map_dict = label_map_util.get_label_map_dict(label_map_path)
  examples_list = dataset_util.read_examples_list(examples_path)

  writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
  for idx, example in enumerate(examples_list):
    if idx % 100 == 0:
      logging.info('On image %d of %d', idx, len(examples_list))
    path = os.path.join(annotations_dir, example + '.xml')
    with tf.gfile.GFile(path, 'r') as fid:
      xml_str = fid.read()
    xml = etree.fromstring(xml_str)
    data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
    tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                    FLAGS.ignore_difficult_instances)
    writer.write(tf_example.SerializeToString())

  writer.close()
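
As a quick sanity check on the output (a hedged sketch in the same TF 1.x style as the tf.python_io calls above; the record path is illustrative), the file can be read back and the serialized examples counted:

import tensorflow as tf

# count the serialized tf.train.Example records that were just written
count = 0
for record in tf.python_io.tf_record_iterator('output/train.record'):  # illustrative path
    example = tf.train.Example()
    example.ParseFromString(record)
    count += 1
print('records written:', count)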
Code Example #3
File: example_utils.py  Project: fmiusov/camera-api
def get_class_names(label_map):
    """
    get the class names - as a dict of byte arrays
    input:  full absolute path o the label map protobuf
    output: dict of byte arrays {class id:  b'name'}
    """
    label_map = get_label_map_dict(label_map)
    label_map_reverse = {}
    for (k, v) in label_map.items():
        label_map_reverse[v] = k
    for (k, v) in label_map_reverse.items():
        label_map_reverse[k] = str.encode(v)
    return label_map_reverse
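
A quick illustration of the returned shape, assuming a hypothetical two-class label map (the path and names below are made up):

# with a label map of {'cat': 1, 'dog': 2}, get_class_names returns
# the ids mapped to byte-encoded names
names = get_class_names('model/label_map.pbtxt')  # hypothetical path
# names == {1: b'cat', 2: b'dog'}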
Code Example #4
def configure_tensorflow_model(model_config):
    log.info(f'tensorflow model config: {model_config}')
    framework = model_config['model_framework']
    model_path = os.path.join('model', model_config['model_path'])
    
    # get a frozen graph
    detection_graph = get_detection_graph(model_path)
    sess, tensor_dict, image_tensor = get_tf_session(detection_graph)

    label_map = model_config['label_map']
    label_dict = label_map_util.get_label_map_dict(label_map, 'id')

    # Model Input Dimensions
    # - TF Lite reports these, but a TensorFlow frozen graph does not,
    #   so they come from the config; this overrides whatever TF Lite reported - beware
    model_input_dim = model_config['model_input_dim']
    model_image_dim = (model_config['model_input_dim'][1], model_config['model_input_dim'][2])
    # print ("Model Framework: {}   model input dim: {}   image dim: {}".format(framework, model_input_dim, model_image_dim))
    # print ("      Label Map: {}".format(label_map))
    log.info(f'Model Framework: {framework}   model input dim: {model_input_dim}   image dim: {model_image_dim}')
    return sess, tensor_dict, image_tensor, model_input_dim, label_map, label_dict
Code Example #5
def main(_):
    if FLAGS.set not in SETS:
        raise ValueError('set must be in : {}'.format(SETS))

    data_dir = FLAGS.data_dir

    label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

    examples_path = os.path.join(data_dir, FLAGS.set + 'YTO.txt')

    annotations_dir = os.path.join(data_dir, FLAGS.annotations_dir)
    examples_list = dataset_util.read_examples_list(examples_path)

    random.seed(42)

    random.shuffle(examples_list)
    num_examples = len(examples_list)
    print("Total Records in Dataset: {}".format(num_examples))
    num_train = int(0.8 * num_examples)
    train_examples = examples_list[:num_train]
    test_examples = examples_list[num_train:]
    train_output_path = os.path.join(FLAGS.output_dir, 'train.record')
    test_output_path = os.path.join(FLAGS.output_dir, 'test.record')
    print("Training Set conversion")
    create_tf_record(
        train_output_path,
        FLAGS.num_shards,
        label_map_dict,
        annotations_dir,
        train_examples,
        image_dir=FLAGS.data_dir)

    print("Test Set conversion")
    create_tf_record(
        test_output_path,
        FLAGS.num_shards,
        label_map_dict,
        annotations_dir,
        test_examples,
        image_dir=FLAGS.data_dir)
Code Example #6
File: voc_to_tfrecord.py  Project: dicroce/deep_tag
def main():
    parser = argparse.ArgumentParser(
        description='Convert VOC annotations to tensorflow tfrecord files.')
    parser.add_argument('-a',
                        '--annotations_dir',
                        help='Directory containing VOC XML annotations.',
                        required=True)
    parser.add_argument('-i',
                        '--images_dir',
                        help='Directory containing JPEG images.',
                        required=True)
    parser.add_argument(
        '-l',
        '--label_map',
        help='PBTXT format mapping from text labels to integers.',
        required=True)
    args = vars(parser.parse_args())

    annotations_path = args["annotations_dir"]
    images_path = args["images_dir"]
    label_map_path = args["label_map"]

    annotations = os.listdir(annotations_path)
    random.shuffle(annotations)

    lm = get_label_map_dict(label_map_path)

    train_writer = tf.python_io.TFRecordWriter('train.record')
    test_writer = tf.python_io.TFRecordWriter('test.record')

    num_for_test = int(len(annotations) * 0.15)

    for i in range(len(annotations)):
        apath = annotations_path + '/' + annotations[i]
        parse_and_write_annotation(
            test_writer if i < num_for_test else train_writer, apath,
            images_path, lm)

    train_writer.close()
    test_writer.close()
Code Example #7
def configure_tflite_model(model_config):
    print (model_config)
    framework = model_config['model_framework']
    model_path = os.path.join('model', model_config['model_path'])
    #
    # S S D   M O D E L   F R A M E W O R K
    # TF Lite
    if framework == 'tflite':
        interpreter = tensorflow_util.get_tflite_interpreter(model_path)
        model_image_dim, model_input_dim, output_details = get_tflite_attributes(interpreter)

    label_map = model_config['label_map']
    label_dict = label_map_util.get_label_map_dict(label_map, 'id')

    # Model Input Dimensions
    # - TF Lite reports these, but a TensorFlow frozen graph does not,
    #   so they come from the config; this overrides whatever TF Lite reported - beware
    model_input_dim = model_config['model_input_dim']
    model_image_dim = (model_config['model_input_dim'][1], model_config['model_input_dim'][2])
    print ("Model Framework: {}   model input dim: {}   image dim: {}".format(framework, model_input_dim, model_image_dim))
    print ("      Label Map: {}".format(label_map))
    return framework, interpreter, model_input_dim, output_details, label_map, label_dict
Code Example #8
def main(year, data_dir, annotations_dir, set, output_path, label_map_path,
         ignore_difficult_instances):
  if set not in SETS:
    raise ValueError('set must be in : {}'.format(SETS))
  if year not in YEARS:
    raise ValueError('year must be in : {}'.format(YEARS))

  years = ['VOC2007', 'VOC2012']
  if year != 'merged':
    years = [year]

  writer = tf.python_io.TFRecordWriter(output_path)

  label_map_dict = label_map_util.get_label_map_dict(label_map_path)

  for year in years:
    logging.info('Reading from PASCAL %s dataset.', year)
    examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main', set + '.txt')
    examples_list = dataset_util.read_examples_list(examples_path)
    for idx, example in enumerate(examples_list):
      if idx % 100 == 0:
        logging.info('On image %d of %d', idx, len(examples_list))
      path = os.path.join(annotations_dir, example + '.xml')
      with tf.gfile.GFile(path, 'r') as fid:
        xml_str = fid.read()
      # xml = etree.fromstring(xml_str)
      xml = etree.fromstring(xml_str.encode('utf-8'))
      data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']

      tf_example = dict_to_tf_example(year, data, data_dir,
                                      label_map_dict, example,
                                      ignore_difficult_instances)
      writer.write(tf_example.SerializeToString())

  writer.close()
Code Example #9
File: example_utils.py  Project: fmiusov/camera-api
def voc_to_tfrecord_file(image_dir,
                         annotation_dir,
                         label_map_file,
                         tfrecord_dir,
                         training_split_tuple,
                         include_classes="all",
                         exclude_truncated=False,
                         exclude_difficult=False):
    # this uses only TensorFlow libraries
    # - no P Ferrari classes

    label_map = get_label_map_dict(label_map_file, 'value')
    # label_map_dict = invert_dict(origin_label_map_dict)    # we need the id, not the name as the key

    train_list, val_list, test_list = gen_imageset_list(
        annotation_dir, training_split_tuple)

    print(label_map)

    # iterate through each image_id (file name w/o extension) in the image list
    # this list will give you the variables needed to iterate through train/val/test
    imageset_list = [(train_list, 'train'), (val_list, 'val'),
                     (test_list, 'test')]
    j = 0
    for (image_list, imageset_name) in imageset_list:

        # you can create/open the tfrecord writer
        output_path = os.path.join(tfrecord_dir, imageset_name,
                                   imageset_name + ".tfrecord")
        tf_writer = tf.python_io.TFRecordWriter(output_path)
        print(" -- images", len(image_list), " writing to:", output_path)

        image_count = 0  # simple image counter
        class_dict = {}  # dict to keep class count

        # loop through each image in the image list

        for image_id in image_list:
            if image_id.startswith('.'):
                continue
            # get annotation information
            annotation_path = os.path.join(annotation_dir, image_id + '.xml')
            with open(annotation_path) as f:
                soup = BeautifulSoup(f, 'xml')

                folder = soup.folder.text
                filename = soup.filename.text
                # size = soup.size.text
                sizeWidth = float(
                    soup.size.width.text)  # you need everything as floats
                sizeHeight = float(soup.size.height.text)
                sizeDepth = float(soup.size.depth.text)

                boxes = []  # We'll store all boxes for this image here
                objects = soup.find_all(
                    'object')  # Get a list of all objects in this image

                # Parse the data for each object
                for obj in objects:
                    class_name = obj.find('name').text
                    try:
                        class_id = label_map[class_name]
                        class_dict = incr_class_count(class_dict, class_id)
                    except KeyError:
                        # class name missing from the label map - skip this object
                        print("!!! label map error:", image_id, class_name,
                              " skipped")
                        continue
                    # Check if this class is supposed to be included in the dataset
                    if include_classes != 'all' and class_id not in include_classes:
                        continue
                    pose = obj.pose.text
                    truncated = int(obj.truncated.text)
                    if exclude_truncated and (truncated == 1): continue
                    difficult = int(obj.difficult.text)
                    if exclude_difficult and (difficult == 1): continue
                    # print (image_id, image_count, "xmin:", obj.bndbox.xmin.text)
                    xmin = int(
                        obj.bndbox.xmin.text.split('.')[0]
                    )  # encountered a few decimals - that will throw an error
                    ymin = int(obj.bndbox.ymin.text.split('.')[0])
                    xmax = int(obj.bndbox.xmax.text.split('.')[0])
                    ymax = int(obj.bndbox.ymax.text.split('.')[0])
                    item_dict = {
                        'class_name': class_name,
                        'class_id': class_id,
                        'pose': pose,
                        'truncated': truncated,
                        'difficult': difficult,
                        'xmin': xmin,
                        'ymin': ymin,
                        'xmax': xmax,
                        'ymax': ymax
                    }
                    boxes.append(item_dict)

            # get the encoded image
            img_path = os.path.join(image_dir, image_id + ".jpg")
            with tf.io.gfile.GFile(img_path, 'rb') as fid:
                encoded_jpg = fid.read()

            # now you have everything necessary to create a tf.example
            # tf.Example proto
            # print ("    ", filename)
            # print ("     ", class_name, class_id)
            # print ("     ", sizeHeight, sizeWidth, sizeDepth)
            # print ("     ", len(boxes))

            xmins = []
            xmaxs = []
            ymins = []
            ymaxs = []
            class_names = []
            class_ids = []
            obj_areas = []
            obj_is_crowds = []
            obj_difficults = []
            obj_group_ofs = []
            obj_weights = []

            # loop through each bbox to make a list of each
            for box in boxes:
                # print ("       ", box)
                # create lists of bbox dimensions
                xmins.append(box['xmin'] / sizeWidth)
                xmaxs.append(box['xmax'] / sizeWidth)

                ymins.append(box['ymin'] / sizeHeight)
                ymaxs.append(box['ymax'] / sizeHeight)

                class_names.append(str.encode(box['class_name']))
                class_ids.append(int(box['class_id']))

                obj_areas.append(OBJ_AREA)
                obj_is_crowds.append(OBJ_IS_CROWD)
                obj_difficults.append(OBJ_DIFFICULT)
                obj_group_ofs.append(OBJ_GROUP_OF)
                obj_weights.append(OBJ_WEIGHT)

            # use the commonly defined feature dictionary
            feature = feature_obj_detect.copy()
            # thus you have a common structure for writing & reading
            # these image features

            # per image attributes
            feature['image/encoded'] = bytes_feature(encoded_jpg)
            feature['image/format'] = bytes_feature(IMG_FORMAT)
            feature['image/filename'] = bytes_feature(str.encode(filename))
            feature['image/key/sha256'] = bytes_feature(IMG_SHA256)
            feature['image/source_id'] = bytes_feature(str.encode(image_id))

            feature['image/height'] = int64_feature(int(sizeHeight))
            feature['image/width'] = int64_feature(int(sizeWidth))

            feature['image/class/text'] = bytes_list_feature(IMG_CLASS_NAMES)
            feature['image/class/label'] = int64_list_feature(IMG_CLASS_IDS)

            # per image/object attributes
            feature['image/object/bbox/xmin'] = float_list_feature(xmins)
            feature['image/object/bbox/xmax'] = float_list_feature(xmaxs)
            feature['image/object/bbox/ymin'] = float_list_feature(ymins)
            feature['image/object/bbox/ymax'] = float_list_feature(ymaxs)
            feature['image/object/class/text'] = bytes_list_feature(
                class_names)
            feature['image/object/class/label'] = int64_list_feature(class_ids)

            # these are all taken from default values
            feature['image/object/area'] = float_list_feature(obj_areas)
            feature['image/object/is_crowd'] = int64_list_feature(
                obj_is_crowds)
            feature['image/object/difficult'] = int64_list_feature(
                obj_difficults)
            feature['image/object/group_of'] = int64_list_feature(
                obj_group_ofs)
            feature['image/object/weight'] = float_list_feature(obj_weights)

            features = tf.train.Features(feature=feature)

            tf_example = tf.train.Example(features=features)
            # write to the tfrecords writer
            tf_writer.write(tf_example.SerializeToString())
            image_count = image_count + 1

        # end of loop
        # TODO - shard on larger sets
        #        https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/using_your_own_dataset.md
        tf_writer.close()  # close the writer
        print('     image count:', image_count, "  class_count:", class_dict)
    return 1
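
A hypothetical call, for orientation only; the paths and the (train, val, test) fractions in training_split_tuple are assumptions (see gen_imageset_list for the exact tuple semantics), and tfrecord_dir is expected to already contain train/, val/, and test/ subdirectories, since the writer above joins the imageset name into the output path:

voc_to_tfrecord_file(image_dir='data/JPEGImages',            # hypothetical paths
                     annotation_dir='data/Annotations',
                     label_map_file='data/label_map.pbtxt',
                     tfrecord_dir='data/tfrecords',
                     training_split_tuple=(0.7, 0.2, 0.1))   # assumed (train, val, test) fractions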
Code Example #10
File: test_tf.py  Project: fmiusov/camera-api
# add some paths that will pull in other software
# -- don't add to the path over and over
cwd = os.getcwd()
models = os.path.abspath(os.path.join(cwd, '..', 'models/research/'))
slim = os.path.abspath(os.path.join(cwd, '..', 'models/research/slim'))
sys.path.append(models)
sys.path.append(slim)

import tensorflow_util
import label_map_util  # this came from tensorflow/models

# if you made it this far, your python path is good - and you successfully imported

# test the label map util
d = label_map_util.get_label_map_dict('model/mscoco_label_map.pbtxt', 'id')
print(d)

# test getting a model
interpreter = tensorflow_util.get_tflite_interpreter(
    'model/output_tflite_graph.tflite')
model_image_dim, model_input_dim, output_details = tensorflow_util.get_tflite_attributes(
    interpreter)
print("Model Image Dim:", model_image_dim)
print("Model Image Dim:", model_input_dim)

# load an image from a file & preprocess it
image = tensorflow_util.load_image_into_numpy_array(
    'jpeg_images/111-1122_IMG.JPG')
cv2.namedWindow('raw_image', cv2.WINDOW_NORMAL)
cv2.imshow('raw_image', image)
Code Example #11
def main():
    # args
    camera_number = int(sys.argv[1])  # 0 based

    # get the app config - including passwords
    config = gen_util.read_app_config('app_config.json')

    # set some flags based on the config
    run_inferences = config["run_inferences"]
    save_inference = config["save_inference"]
    annotation_dir = config["annotation_dir"]
    snapshot_dir = config["snapshot_dir"]

    # set up tflite model
    global label_dict
    label_dict = label_map_util.get_label_map_dict(config['label_map'], 'id')

    global interpreter
    interpreter = tensorflow_util.get_tflite_interpreter(
        'model/output_tflite_graph.tflite')

    global model_image_dim, model_input_dim, output_details
    model_image_dim, model_input_dim, output_details = tensorflow_util.get_tflite_attributes(
        interpreter)

    # define your paths here - just once (not in the loop)
    global image_path, annotation_path
    image_path = os.path.abspath(os.path.join(cwd, snapshot_dir))
    annotation_path = os.path.abspath(os.path.join(cwd, annotation_dir))

    # Set up Camera
    # TODO - should be a list
    #   - but it's just one camera now

    # for name, capture, flip in camera_list:
    camera_config = camera_util.get_camera_config(config, camera_number)
    camera_name = camera_config['name']
    url = camera_util.get_reolink_url(
        'http', camera_config['ip'])  # pass the url base - not just the ip
    print("Camera Config:", camera_config)

    # based on the config, config all camera regions
    # - includes building the bbox stacks
    regions, bbox_stack_list, bbox_push_list = camera_util.config_camera_regions(
        camera_config)

    snapshot_count = 0
    while True:

        start_time = time.time()
        base_name = "{}_{}".format(str(int(start_time)), camera_number)
        # frame returned as a numpy array ready for cv2
        # not resized
        angle = camera_config['rotation_angle']
        frame = camera_util.get_reolink_snapshot(url,
                                                 camera_config['username'],
                                                 camera_config['password'])

        if frame is not None:
            frame = imutils.rotate(frame, angle)  # rotate frame
            orig_image_dim = (frame.shape[0], frame.shape[1]
                              )  #  dim = (height, width),
            orig_image = frame.copy(
            )  # preserve the original - full resolution
            # corner is top left

            print(
                '\n-- {} snap captured: {}'.format(snapshot_count,
                                                   frame.shape),
                '{0:.2f} seconds'.format(time.time() - start_time))

            # True == run it through the model
            if run_inferences:
                inference_start_time = time.time()
                # loop through 0:n sub-regions of the frame
                # last one is the full resolution
                for i, region in enumerate(regions):
                    crop_start_time = time.time()

                    inference_image, detected_objects, bbox_array = run_inference(
                        orig_image, base_name, region, i, bbox_stack_list,
                        bbox_push_list, True)
                    print(
                        '     crop {}'.format(i),
                        ' inference: {0:.2f} seconds'.format(time.time() -
                                                             crop_start_time))
                    # enlarged_inference = cv2.resize(inference_image, (1440, 1440), interpolation = cv2.INTER_AREA)
                    window_name = "{} crop {}".format(camera_name, i)
                    cv2.imshow(window_name,
                               inference_image)  # show the inference

                print('   TOTAL inference: {0:.2f} seconds'.format(
                    time.time() - inference_start_time))

            else:
                cv2.imshow(camera_name, frame)
            snapshot_count = snapshot_count + 1
        else:
            print("-- no frame returned -- ")

        # time.sleep(3)

        # Use key 'q' to close window
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cv2.destroyAllWindows()
Code Example #12
def main():
    # get the app config - including passwords
    config = gen_util.read_app_config('app_config.json')

    # set some flags based on the config
    run_inference = config["run_inference"]
    save_inference = config["save_inference"]
    annotation_dir = config["annotation_dir"]
    snapshot_dir = config["snapshot_dir"]

    # set up cameras
    camera_list = camera_util.configure_cameras(config)

    # set up tflite model
    label_dict = label_map_util.get_label_map_dict(config['label_map'], 'id')
    interpreter = tensorflow_util.get_tflite_interpreter(
        'model/output_tflite_graph.tflite')
    model_image_dim, model_input_dim, output_details = tensorflow_util.get_tflite_attributes(
        interpreter)

    # define your paths here - just once (not in the loop)
    image_path = os.path.abspath(os.path.join(cwd, snapshot_dir))
    annotation_path = os.path.abspath(os.path.join(cwd, annotation_dir))

    run_with_camera_number = 0  # 0 based

    snapshot_count = 0
    while True:

        # for name, capture, flip in camera_list:
        name, capture, flip = camera_list[
            run_with_camera_number]  # running with 1 camera only
        start_time = time.time()
        print(name, snapshot_count)

        ret, frame = capture.read()  #  frame.shape (height, width, depth)

        if frame is not None:
            orig_image_dim = (frame.shape[0], frame.shape[1]
                              )  #  dim = (height, width),
            orig_image = frame.copy()
            snapshot_count = snapshot_count + 1

            print('captured:', frame.shape, time.time() - start_time)

            if flip == "vert":
                frame = cv2.flip(frame, 0)

            # True == run it through the model
            if run_inference:
                # pre-process the frame -> a compatible numpy array for the model
                preprocessed_image = tensorflow_util.preprocess_image(
                    frame, interpreter, model_image_dim, model_input_dim)
                bbox_array, class_id_array, prob_array = tensorflow_util.send_image_to_model(
                    preprocessed_image, interpreter)
                print('inference:', frame.shape, time.time() - start_time)

                inference_image, orig_image_dim, detected_objects = display.inference_to_image(
                    frame, bbox_array, class_id_array, prob_array,
                    model_input_dim, label_dict, PROBABILITY_THRESHOLD)

                # testing the format
                # convert detected_objects to XML
                # detected_objects = list [ (class_id, class_name, probability, xmin, ymin, xmax, ymax)]
                if len(detected_objects) > 0:
                    print(detected_objects)
                    if save_inference:
                        image_base_name = str(int(start_time))
                        image_name = os.path.join(image_path,
                                                  image_base_name + '.jpg')
                        annotation_name = os.path.join(
                            annotation_path, image_base_name + '.xml')
                        print("saving:", image_name, frame.shape,
                              annotation_name)
                        # original image - h: 480  w: 640
                        cv2.imwrite(image_name, orig_image)
                        # this function generates & saves the XML annotation
                        annotation_xml = annotation.inference_to_xml(
                            name, image_name, orig_image_dim, detected_objects,
                            annotation_dir)

                # enlarged_inference = cv2.resize(inference_image, (1440, 1440), interpolation = cv2.INTER_AREA)
                cv2.imshow(name, inference_image)  # show the inference
                # cv2.imshow(name, orig_image)     # show the raw image from the camera
            else:
                cv2.imshow(name, frame)
        else:
            print("-- no frame returned -- ")

        # time.sleep(3)

        # Use key 'q' to close window
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    capture.release()
    cv2.destroyAllWindows()
Code Example #13
def convert_kitti_to_tfrecords(data_dir, output_path, classes_to_use,
                               label_map_path, validation_set_size):
    """Convert the KITTI detection dataset to TFRecords.

  Args:
    data_dir: The full path to the unzipped folder containing the unzipped data
      from data_object_image_2 and data_object_label_2.zip.
      Folder structure is assumed to be: data_dir/training/label_2 (annotations)
      and data_dir/data_object_image_2/training/image_2 (images).
    output_path: The directory to which TFRecord files will be written. The
      training set TFRecord will be written to <output_path>/kitti_train.record
      and the validation set TFRecord to <output_path>/kitti_val.record.
    classes_to_use: List of strings naming the classes for which data should be
      converted. Use the same names as presented in the KITTI README file.
      Adding dontcare class will remove all other bounding boxes that overlap
      with areas marked as dontcare regions.
    label_map_path: Path to label map proto
    validation_set_size: How many images should be left as the validation set.
      (The first `validation_set_size` examples are selected to be in the
      validation set).
  """
    _label_map_dict = label_map_util.get_label_map_dict(label_map_path)
    label_map_dict = {
        x: (_label_map_dict[x] - 1)
        for x in _label_map_dict.keys()
    }
    print(label_map_dict)
    train_count = 0
    val_count = 0

    annotation_dir = os.path.join(data_dir, 'training', 'label_2')

    image_dir = os.path.join(data_dir, 'training', 'image_2')

    train_writer = tf.python_io.TFRecordWriter(
        os.path.join(output_path, 'kitti_train.record'))
    val_writer = tf.python_io.TFRecordWriter(
        os.path.join(output_path, 'kitti_val.record'))

    images = sorted(tf.gfile.ListDirectory(image_dir))
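    # NOTE: the next line overrides the validation_set_size argument documented above
    # and instead reserves half of the images for validation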
    validation_set_size = int(len(images) / 2)
    validation_images_idx = np.random.choice(range(len(images)),
                                             size=validation_set_size,
                                             replace=False)
    for img_idx, img_name in enumerate(images):
        img_num = int(img_name.split('.')[0])
        is_validation_img = (img_idx in validation_images_idx)
        img_anno = read_annotation_file(
            os.path.join(annotation_dir,
                         str(img_num).zfill(6) + '.txt'))

        image_path = os.path.join(image_dir, img_name)

        # Filter all bounding boxes of this frame that are of a legal class, and
        # don't overlap with a dontcare region.
        # TODO(talremez) filter out targets that are truncated or heavily occluded.
        annotation_for_image = filter_annotations(img_anno, classes_to_use)
        if (len(annotation_for_image['2d_bbox_left']) > 0):
            example = prepare_example(image_path, annotation_for_image,
                                      label_map_dict)
            if is_validation_img:
                val_writer.write(example.SerializeToString())
                val_count += 1
            else:
                train_writer.write(example.SerializeToString())
                train_count += 1

    train_writer.close()
    val_writer.close()
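
For reference, the dict comprehension at the top of this function shifts the 1-based ids from get_label_map_dict down to the 0-based ids this converter writes; a small sketch with hypothetical class names:

_label_map_dict = {'car': 1, 'pedestrian': 2}  # hypothetical 1-based label map
label_map_dict = {name: class_id - 1 for name, class_id in _label_map_dict.items()}
print(label_map_dict)  # {'car': 0, 'pedestrian': 1}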