def main():
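    """Command-line entry point: parse arguments, find the images to score,
    run the detector (checkpointing along the way if requested), and write
    the results to a JSON file.
    """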

    parser = argparse.ArgumentParser(
        description=
        'Module to run a TF animal detection model on lots of images')
    parser.add_argument('detector_file',
                        help='Path to .pb TensorFlow detector model file')
    parser.add_argument(
        'image_file',
        help=
        'Path to a single image file, a JSON file containing a list of paths to images, or a directory'
    )
    parser.add_argument(
        'output_file',
        help=
        'Path to output JSON results file, should end with a .json extension')
    parser.add_argument(
        '--recursive',
        action='store_true',
        help=
        'Recurse into directories, only meaningful if image_file points to a directory'
    )
    parser.add_argument(
        '--output_relative_filenames',
        action='store_true',
        help=
        'Output relative file names, only meaningful if image_file points to a directory'
    )
    parser.add_argument(
        '--threshold',
        type=float,
        default=TFDetector.DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD,
        help=
        "Confidence threshold between 0 and 1.0, don't include boxes below this confidence in the output file. Default is 0.1"
    )
    parser.add_argument(
        '--checkpoint_frequency',
        type=int,
        default=-1,
        help=
        'Write results to a temporary file every N images; default is -1, which disables this feature'
    )
    parser.add_argument(
        '--resume_from_checkpoint',
        help=
        'Path to a JSON checkpoint file to resume from, must be in same directory as output_file'
    )
    parser.add_argument(
        '--ncores',
        type=int,
        default=0,
        help=
        'Number of cores to use; only applies to CPU-based inference, does not support checkpointing when ncores > 1'
    )

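    # If no arguments were supplied, print usage and exit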
    if len(sys.argv[1:]) == 0:
        parser.print_help()
        parser.exit()

    args = parser.parse_args()

    assert os.path.exists(
        args.detector_file), 'Specified detector_file does not exist'
    assert 0.0 < args.threshold <= 1.0, 'Confidence threshold needs to be between 0 and 1'  # Python chained comparison
    assert args.output_file.endswith(
        '.json'), 'output_file specified needs to end with .json'
    if args.checkpoint_frequency != -1:
        assert args.checkpoint_frequency > 0, 'Checkpoint_frequency needs to be > 0 or == -1'
    if args.output_relative_filenames:
        assert os.path.isdir(
            args.image_file
        ), 'image_file must be a directory when --output_relative_filenames is set'

    if os.path.exists(args.output_file):
        print('Warning: output_file {} already exists and will be overwritten'.
              format(args.output_file))

    # Load the checkpoint if available
    #
    # Relative file names are only output at the end; all file paths in the checkpoint are
    # still full paths.
    if args.resume_from_checkpoint:
        assert os.path.exists(
            args.resume_from_checkpoint
        ), 'File at resume_from_checkpoint specified does not exist'
        with open(args.resume_from_checkpoint) as f:
            saved = json.load(f)
        assert 'images' in saved, \
            'The file saved as checkpoint does not have the correct fields; cannot be restored'
        results = saved['images']
        print('Restored {} entries from the checkpoint'.format(len(results)))
    else:
        results = []

    # Find the images to score; image_file can be a directory, so we may need to recurse
    if os.path.isdir(args.image_file):
        image_file_names = ImagePathUtils.find_images(args.image_file,
                                                      args.recursive)
        print('{} image files found in the input directory'.format(
            len(image_file_names)))
    # A json list of image paths
    elif os.path.isfile(args.image_file) and args.image_file.endswith('.json'):
        with open(args.image_file) as f:
            image_file_names = json.load(f)
        print('{} image files found in the json list'.format(
            len(image_file_names)))
    # A single image file
    elif os.path.isfile(args.image_file) and ImagePathUtils.is_image_file(
            args.image_file):
        image_file_names = [args.image_file]
        print('A single image at {} is the input file'.format(args.image_file))
    else:
        raise ValueError(
            'image_file specified is not a directory, a json list, or an image file '
            '(or does not have a recognizable extension).')

    assert len(
        image_file_names
    ) > 0, 'Specified image_file does not point to valid image files'
    assert os.path.exists(
        image_file_names[0]
    ), 'The first image to be scored does not exist at {}'.format(
        image_file_names[0])

    output_dir = os.path.dirname(args.output_file)

    # An empty dirname means the output file goes in the current directory
    if output_dir:
        assert os.path.exists(
            output_dir), 'Invalid output filename (folder does not exist)'
    assert not os.path.isdir(
        args.output_file), 'Specified output file is a directory'

    # Test that we can write to the output_file's dir if checkpointing requested
    if args.checkpoint_frequency != -1:
        checkpoint_path = os.path.join(
            output_dir, 'checkpoint_{}.json'.format(
                datetime.utcnow().strftime("%Y%m%d%H%M%S")))
        with open(checkpoint_path, 'w') as f:
            json.dump({'images': []}, f)
        print('The checkpoint file will be written to {}'.format(
            checkpoint_path))
    else:
        checkpoint_path = None

    start_time = time.time()

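    # Run the detector on every image; any entries restored from a checkpoint
    # are passed in via 'results' so work can resume where the checkpoint left off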
    results = load_and_run_detector_batch(
        model_file=args.detector_file,
        image_file_names=image_file_names,
        checkpoint_path=checkpoint_path,
        confidence_threshold=args.threshold,
        checkpoint_frequency=args.checkpoint_frequency,
        results=results,
        n_cores=args.ncores)

    elapsed = time.time() - start_time
    print('Finished inference in {}'.format(
        humanfriendly.format_timespan(elapsed)))

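    # Write the final results, converting paths to be relative to the input
    # directory if --output_relative_filenames was specified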
    relative_path_base = None
    if args.output_relative_filenames:
        relative_path_base = args.image_file
    write_results_to_file(results,
                          args.output_file,
                          relative_path_base=relative_path_base)

    if checkpoint_path:
        os.remove(checkpoint_path)
        print('Deleted checkpoint file')

    print('Done!')
Example #2
def main():
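    """Command-line entry point: parse arguments, find the images to score,
    run the detector (checkpointing along the way if requested), and write
    the results to a JSON file.
    """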

    parser = argparse.ArgumentParser(
        description=
        'Module to run a TF animal detection model on lots of images')
    parser.add_argument('detector_file',
                        help='Path to .pb TensorFlow detector model file')
    parser.add_argument(
        'image_file',
        help=
        'Can be a single image file, a json file containing a list of paths to images, or a directory'
    )
    parser.add_argument(
        'output_file',
        help='Output results file, should end with a .json extension')
    parser.add_argument(
        '--recursive',
        action='store_true',
        help=
        'Recurse into directories, only meaningful if image_file points to a directory'
    )
    parser.add_argument(
        '--output_relative_filenames',
        action='store_true',
        help=
        'Output relative file names, only meaningful if image_file points to a directory'
    )
    parser.add_argument(
        '--threshold',
        type=float,
        default=TFDetector.DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD,
        help=
        "Confidence threshold between 0 and 1.0, don't include boxes below this confidence in the output file. Default is 0.1"
    )
    parser.add_argument(
        '--checkpoint_frequency',
        type=int,
        default=-1,
        help=
        'Write results to a temporary file every N images; default is -1, which disables this feature'
    )
    parser.add_argument(
        '--resume_from_checkpoint',
        help=
        'Path to a JSON checkpoint file to resume from; must be in the same directory as output_file'
    )

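    # If no arguments were supplied, print usage and exit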
    if len(sys.argv[1:]) == 0:
        parser.print_help()
        parser.exit()

    args = parser.parse_args()

    assert os.path.exists(
        args.detector_file), 'detector_file specified does not exist'
    assert 0.0 < args.threshold <= 1.0, 'Confidence threshold needs to be between 0 and 1'  # Python chained comparison
    assert args.output_file.endswith(
        '.json'), 'output_file specified needs to end with .json'
    if args.checkpoint_frequency != -1:
        assert args.checkpoint_frequency > 0, 'Checkpoint_frequency needs to be > 0 or == -1'
    if args.output_relative_filenames:
        assert os.path.isdir(
            args.image_file
        ), 'image_file must be a directory when --output_relative_filenames is set'

    if os.path.exists(args.output_file):
        print('Warning: output_file {} already exists and will be overwritten'.
              format(args.output_file))

    # load the checkpoint if available
    # relative file names are only output at the end; all file paths in the checkpoint are still full paths
    if args.resume_from_checkpoint:
        assert os.path.exists(
            args.resume_from_checkpoint
        ), 'File at resume_from_checkpoint specified does not exist'
        with open(args.resume_from_checkpoint) as f:
            saved = json.load(f)
        assert 'images' in saved, \
            'The file saved as checkpoint does not have the correct fields; cannot be restored'
        results = saved['images']
        print('Restored {} entries from the checkpoint'.format(len(results)))
    else:
        results = []

    # Find the images to score; image_file can be a directory, so we may need to recurse
    if os.path.isdir(args.image_file):
        image_file_names = ImagePathUtils.find_images(args.image_file,
                                                      args.recursive)
        print('{} image files found in the input directory'.format(
            len(image_file_names)))
    # a json list of image paths
    elif os.path.isfile(args.image_file) and args.image_file.endswith('.json'):
        with open(args.image_file) as f:
            image_file_names = json.load(f)
        print('{} image files found in the json list'.format(
            len(image_file_names)))
    # a single image file
    elif os.path.isfile(args.image_file) and ImagePathUtils.is_image_file(
            args.image_file):
        image_file_names = [args.image_file]
        print('A single image at {} is the input file'.format(args.image_file))
    else:
        print(
            'image_file specified is not a directory, a json list, or an image file (or does not have a recognizable extension); exiting.'
        )
        sys.exit(1)

    assert len(image_file_names
               ) > 0, 'image_file provided does not point to valid image files'
    assert os.path.exists(
        image_file_names[0]
    ), 'The first image to be scored does not exist at {}'.format(
        image_file_names[0])

    # test that we can write to the output_file's dir if checkpointing requested
    if args.checkpoint_frequency != -1:
        output_dir = os.path.dirname(args.output_file)
        checkpoint_path = os.path.join(
            output_dir, 'checkpoint_{}.json'.format(
                datetime.utcnow().strftime("%Y%m%d%H%M%S")))
        with open(checkpoint_path, 'w') as f:
            json.dump({'images': []}, f)
        print('The checkpoint file will be written to {}'.format(
            checkpoint_path))
    else:
        checkpoint_path = None

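    # Run the detector on every image; any entries restored from a checkpoint
    # are passed in via 'results' so work can resume where the checkpoint left off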
    results = load_and_run_detector_batch(
        model_file=args.detector_file,
        image_file_names=image_file_names,
        checkpoint_path=checkpoint_path,
        confidence_threshold=args.threshold,
        checkpoint_frequency=args.checkpoint_frequency,
        results=results)

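    # Convert absolute paths to paths relative to the input directory if requested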
    if args.output_relative_filenames:
        for r in results:
            r['file'] = os.path.relpath(r['file'], start=args.image_file)

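    # Assemble and write the output file in the batch-results JSON format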
    final_output = {
        'images': results,
        'detection_categories': TFDetector.DEFAULT_DETECTOR_LABEL_MAP,
        'info': {
            'detection_completion_time':
            datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S'),
            'format_version':
            '1.0'
        }
    }
    with open(args.output_file, 'w') as f:
        json.dump(final_output, f, indent=1)
    print('Output file saved at {}'.format(args.output_file))

    # finally delete the checkpoint file if used
    if checkpoint_path:
        os.remove(checkpoint_path)
        print('Deleted checkpoint file')
    print('Done!')