def get_dataset(tfrecord_path, label_map='label_map.pbtxt'):
    """
    Build a tensorflow dataset from a single tf record file.

    args:
      - tfrecord_path [str]: path to a tf record file
      - label_map [str]: path to the label_map file
    returns:
      - dataset [tf.Dataset]: tensorflow dataset
    """
    # assemble an InputReader proto pointing at the record + label map
    reader_config = input_reader_pb2.InputReader()
    reader_config.label_map_path = label_map
    reader_config.tf_record_input_reader.input_path[:] = [tfrecord_path]
    return build_dataset(reader_config)
# Example #2
def main(labelmap_path, model_path, tf_record_path, config_path, output_path):
    """
    Use a model and a tf record file and create a mp4 video
    args:
    - labelmap_path [str]: path to labelmap file
    - model_path [str]: path to exported model
    - tf_record_path [str]: path to tf record file to visualize
    - config_path [str]: path to config file
    - output_path [str]: path to mp4 file

    Save the results as mp4 file
    """
    # load label map
    category_index = create_category_index_from_labelmap(labelmap_path,
                                                         use_display_name=True)

    # Load saved model and build the detection function
    logger.info(f'Loading model from {model_path}')
    # NOTE(review): tf.saved_model.load_v2 is a legacy alias of
    # tf.saved_model.load — confirm the pinned TF version still exposes it.
    detect_fn = tf.saved_model.load_v2(model_path)

    # open config file and point the eval input reader at the requested record
    logger.info(f'Loading config from {config_path}')
    configs = get_configs_from_pipeline_file(config_path)
    eval_input_config = configs['eval_input_config']
    eval_input_config.tf_record_input_reader.input_path[:] = [tf_record_path]

    # build dataset (once — the original built it twice back to back)
    dataset = build_dataset(eval_input_config)

    # here we infer on the entire dataset
    images = []
    logger.info(f'Inference on {tf_record_path}')
    for idx, batch in enumerate(dataset):
        # BUG FIX: original `if idx % 50:` logged every step EXCEPT multiples
        # of 50; intent is a progress line every 50th step.
        if idx % 50 == 0:
            logger.info(f'Step: {idx}')
        # add batch axis and feed into model
        input_tensor = batch['image']
        image_np = input_tensor.numpy().astype(np.uint8)
        input_tensor = input_tensor[tf.newaxis, ...]

        detections = detect_fn(input_tensor)

        # tensor -> numpy arr, drop the batch dimension
        num_detections = int(detections.pop('num_detections'))
        print(f'num_detections={num_detections}')
        detections = {
            key: value[0, ...].numpy()
            for key, value in detections.items()
        }
        detections['num_detections'] = num_detections

        # detection_classes should be ints.
        detections['detection_classes'] = detections[
            'detection_classes'].astype(np.int64)

        # draw boxes on a copy so the raw frame is left untouched
        image_np_with_detections = image_np.copy()
        viz_utils.visualize_boxes_and_labels_on_image_array(
            image_np_with_detections,
            detections['detection_boxes'],
            detections['detection_classes'],
            detections['detection_scores'],
            category_index,
            use_normalized_coordinates=True,
            max_boxes_to_draw=200,
            min_score_thresh=.30,
            agnostic_mode=False)
        images.append(image_np_with_detections)

    # now we can create the animation from the rendered frames
    f = plt.figure()
    f.subplots_adjust(left=0,
                      bottom=0,
                      right=1,
                      top=1,
                      wspace=None,
                      hspace=None)
    ax = plt.subplot(111)
    ax.axis('off')
    im_obj = ax.imshow(images[0])

    def animate(idx):
        # swap the displayed image for frame `idx`
        im_obj.set_data(images[idx])

    # BUG FIX: frame count was hard-coded to 100, which raises IndexError for
    # records shorter than 100 frames and truncates longer ones — use the
    # actual number of rendered frames.
    anim = animation.FuncAnimation(f, animate, frames=len(images))
    anim.save(output_path, fps=5, dpi=300)