Exemplo n.º 1
0
def main(args: Namespace, cfg: AttrDict):
    """Optionally extract features on this node, then cluster them.

    When ``cfg.CLUSTERFIT.FEATURES.EXTRACT`` is enabled, features are first
    extracted through ``launch_distributed`` with the ``"extract_features"``
    engine. They are dumped under ``cfg.CLUSTERFIT.FEATURES.PATH``.
    ``cluster_features(cfg)`` then runs in every case.

    Args:
        args: Parsed script arguments; only ``args.node_id`` is read here.
        cfg: Full run configuration. When extraction is enabled, this
            function mutates ``cfg.CHECKPOINT.DIR`` and
            ``cfg.CHECKPOINT.APPEND_DISTR_RUN_ID`` in place.

    Raises:
        ValueError: If extraction is enabled and ``cfg.DISTRIBUTED.NUM_NODES``
            is not 1.
    """
    setup_logging(__name__, output_dir=get_checkpoint_folder(cfg))

    # Close the logging streams even if extraction or clustering fails.
    try:
        # Extract the features if the feature extraction is enabled
        if cfg.CLUSTERFIT.FEATURES.EXTRACT:

            # We cannot have automatic extraction with more than 1 node or
            # otherwise we would have to run this script on several nodes and
            # thus have several parallel clusterings of the features. The
            # automatic extraction is only there as a shortcut when running on
            # a single node.
            # Explicit check rather than `assert` so it still runs under -O.
            if cfg.DISTRIBUTED.NUM_NODES != 1:
                raise ValueError("Automatic extraction can only work with 1 node")

            # Make sure to dump the features at the desired path
            # (APPEND_DISTR_RUN_ID disabled -- presumably so no run id is
            # appended to that directory).
            cfg.CHECKPOINT.DIR = cfg.CLUSTERFIT.FEATURES.PATH
            cfg.CHECKPOINT.APPEND_DISTR_RUN_ID = False

            # Run the extraction of features. set_env_vars also sets up the
            # path manager for this branch.
            set_env_vars(local_rank=0, node_id=0, cfg=cfg)
            logging.info("Setting seed....")
            set_seeds(cfg, args.node_id)
            launch_distributed(
                cfg,
                args.node_id,
                engine_name="extract_features",
                hook_generator=default_hook_generator,
            )

        else:
            # Without extraction, set_env_vars is never called, so the path
            # manager has to be set up explicitly here.
            setup_path_manager()

        cluster_features(cfg)
    finally:
        shutdown_logging()
Exemplo n.º 2
0
        logging.info("Beginning extract features for query set.")

        launch_distributed(
            config,
            args.node_id,
            engine_name="extract_features",
            hook_generator=default_hook_generator,
        )

    # print the config
    print_cfg(config)

    instance_retrieval_test(args, config)
    logging.info(f"Performance time breakdow:\n{PERF_STATS.report_str()}")

    # close the logging streams including the filehandlers
    shutdown_logging()


def hydra_main(overrides: List[Any]):
    """Compose the run configuration from ``overrides`` and invoke ``main``.

    Args:
        overrides: Configuration overrides forwarded unchanged to
            ``compose_hydra_configuration`` (e.g. the command-line
            arguments supplied by the script entry point).
    """
    hydra_cfg = compose_hydra_configuration(overrides)
    # convert_to_attrdict hands back the parsed arguments and the config.
    cli_args, attr_config = convert_to_attrdict(hydra_cfg)
    main(cli_args, attr_config)


if __name__ == "__main__":
    # Everything after the script name is treated as a config override.
    cli_overrides = sys.argv[1:]

    # Set up the path manager before composing the config and running.
    setup_path_manager()
    hydra_main(overrides=cli_overrides)