def extract_features():
    """Run the selected feature extractors over the input patches.

    All configuration is read from the module-level ``args`` namespace:

    * ``args.list`` -- when truthy, print a table describing every available
      extractor class and return without extracting anything.
    * ``args.files`` -- input file name(s) passed to ``PatchArray``; a single
      string is accepted and wrapped in a list.
    * ``args.extractor`` -- extractor class name(s) to run; ``None`` means
      every class found in ``feature_extractor`` except
      ``FeatureExtractorBase``. A single string is wrapped in a list.

    A failure inside one extractor is logged with its traceback and the
    remaining extractors still run; CTRL-C stops the whole run.
    """
    # `not args.files` also covers None, which `len()` would crash on.
    if not args.list and not args.files:
        logger.error("No input file specified.")
        return

    # Imported only once the cheap argument check above has passed.
    import tensorflow as tf
    import inspect
    import feature_extractor

    # Add before any TF calls (https://github.com/tensorflow/tensorflow/issues/29931#issuecomment-504217770)
    # Initialize the keras global outside of any tf.functions
    temp = tf.zeros([4, 32, 32, 3])
    tf.keras.applications.vgg16.preprocess_input(temp)

    # Every class member of feature_extractor (inspect.getmembers also returns
    # classes imported into it), except the base class, is selectable.
    extractor_names = [
        name
        for name, _ in inspect.getmembers(feature_extractor, inspect.isclass)
        if name != "FeatureExtractorBase"
    ]

    if args.list:
        _print_extractor_table(
            [getattr(feature_extractor, name) for name in extractor_names])
        return

    if args.extractor is None:
        args.extractor = extractor_names

    # args.extractor = filter(lambda f: "EfficientNet" in f, args.extractor)

    # `basestring` only exists on Python 2; fall back to `str` on Python 3.
    try:
        string_types = basestring
    except NameError:
        string_types = str

    # Accept a single name for either option, treated like a one-element list
    # (iterating a bare string would otherwise yield single characters).
    if isinstance(args.extractor, string_types):
        args.extractor = [args.extractor]
    if isinstance(args.files, string_types):
        args.files = [args.files]

    patches = PatchArray(args.files)

    ## WZL:
    patches = patches.training_and_validation
    # For the benchmark subset:
    # patches = patches.training_and_validation[0:10]

    ## FieldSAFE:
    # p = patches[:, 0, 0]
    # f = p.round_numbers == 1
    # patches = patches[f]

    # vis = Visualize(patches)
    # vis.show()

    dataset = patches.to_dataset()
    dataset_3D = patches.to_temporal_dataset(16)
    total = patches.shape[0]

    # Add progress bar if multiple extractors
    if len(args.extractor) > 1:
        args.extractor = tqdm(args.extractor,
                              desc="Extractors",
                              file=sys.stderr)

    for extractor_name in args.extractor:
        try:
            extractor_class = getattr(feature_extractor, extractor_name)
            temporal_batch_size = extractor_class.TEMPORAL_BATCH_SIZE
            # shape = extractor_class.OUTPUT_SHAPE
            # if np.prod(shape) > 300000:
            #     logger.warning("Skipping %s (output too big)" % extractor_name)
            #     continue

            # Get an instance
            logger.info("Instantiating %s" % extractor_name)
            extractor = extractor_class()

            # Extractors with TEMPORAL_BATCH_SIZE > 1 are fed the temporal
            # dataset; all others get the non-temporal one.
            if temporal_batch_size > 1:
                extractor.extract_dataset(dataset_3D, total)
            else:
                extractor.extract_dataset(dataset, total)
        except KeyboardInterrupt:
            logger.info("Terminated by CTRL-C")
            return
        except Exception:
            # Log and continue with the next extractor. Catching `Exception`
            # (instead of a bare except) lets SystemExit & co. propagate.
            logger.error("%s: %s" % (extractor_name, traceback.format_exc()))


def _print_extractor_table(extractor_classes):
    """Print name, output shape and receptive-field stats for each extractor class.

    The last column is RECEPTIVE_FIELD["size"][0] / IMG_SIZE; rows where that
    ratio is >= 2 are flagged with "!".
    """
    print("%-30s | %-15s | %-4s | %-8s | %-5s" %
          ("NAME", "OUTPUT SHAPE", "RF", "IMG SIZE", "RF / IMG"))
    print("-" * 80)
    for cls in extractor_classes:
        rf_size = cls.RECEPTIVE_FIELD["size"][0]
        factor = rf_size / float(cls.IMG_SIZE)
        print("%-30s | %-15s | %-4s | %-8s | %.3f %s" %
              (cls.__name__.replace("FeatureExtractor", ""), cls.OUTPUT_SHAPE,
               rf_size, cls.IMG_SIZE, factor,
               "!" if factor >= 2 else ""))
# Example #2
# 0
    RECEPTIVE_FIELD = {'stride': (4.0, 4.0), 'size': (23, 23)}


# Only for tests: extract C3D features for a hand-picked subset of patches.
if __name__ == "__main__":
    from common import PatchArray
    extractor = FeatureExtractorC3D()
    # extractor.plot_model(extractor.model)
    patches = PatchArray()

    p = patches[:, 0, 0]

    # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
    f = np.zeros(p.shape, dtype=bool)
    f[:] = np.logical_and(
        p.directions == 1,  # CCW and
        np.logical_or(
            p.labels == 2,  #   Anomaly or
            np.logical_and(
                p.round_numbers >= 7,  #     Round between 7 and 9
                p.round_numbers <= 9)))

    # Let's make contiguous blocks of at least 10, so
    # we can do some meaningful temporal smoothing afterwards.
    # Each selected index also selects the BLOCK entries before it. Only
    # earlier entries are written, so the seeds equal the original selection.
    # NOTE: seeds within the first BLOCK entries are left unextended.
    BLOCK = 10
    for i in np.flatnonzero(f):
        if i - BLOCK >= 0:
            f[i - BLOCK:i] = True

    patches = patches[f]

    extractor.extract_dataset(patches.to_temporal_dataset(), patches.shape[0])