Esempio n. 1
0
def setup(args):
    """
    Create configs and setup logger from arguments and the given config file.

    Builds the ShapeNet config from ``args.config_file`` plus the
    ``args.opts`` overrides, registers the dataset named by
    ``cfg.DATASETS.NAME``, records the two pretrained-checkpoint paths and
    freezes the config. It then configures the "shapenet" logger, logs the
    environment and configuration, and (on the main process only) writes the
    full config to ``<cfg.OUTPUT_DIR>/config.yaml``.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``config_file``, ``opts``, ``trained_model_from_Pix_2_Vox``,
            ``trained_model_from_Mesh_RCNN``, ``copy_data``, ``data_dir``,
            ``no_color`` and ``num_gpus``.

    Returns:
        The frozen config object.
    """
    cfg = get_shapenet_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # register dataset
    data_dir, splits_file = register_shapenet(cfg.DATASETS.NAME)
    cfg.DATASETS.DATA_DIR = data_dir
    cfg.DATASETS.SPLITS_FILE = splits_file
    # Checkpoint paths of the two pretrained models, taken from the CLI.
    cfg.PRETRAINED_MODEL = args.trained_model_from_Pix_2_Vox
    cfg.PRETRAINED_MODEL2 = args.trained_model_from_Mesh_RCNN
    # if data was copied the data dir has changed
    if args.copy_data:
        cfg.DATASETS.DATA_DIR = args.data_dir
    cfg.freeze()

    colorful_logging = not args.no_color
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): presumably a cross-process barrier so other ranks do not
    # use output_dir before the main process has created it — confirm
    # against comm.synchronize.
    comm.synchronize()

    logger = setup_logger(
        output_dir, color=colorful_logging, name="shapenet", distributed_rank=comm.get_rank()
    )
    logger.info(
        "Using {} GPUs per machine. Rank of current process: {}".format(
            args.num_gpus, comm.get_rank()
        )
    )
    logger.info(args)

    logger.info("Environment info:\n" + collect_env_info())
    # Read the raw config text inside a context manager so the file handle is
    # closed deterministically instead of being left to the garbage collector.
    with open(args.config_file, "r") as config_fh:
        config_text = config_fh.read()
    logger.info("Loaded config file {}:\n{}".format(args.config_file, config_text))
    logger.info("Running with full config:\n{}".format(cfg))
    if comm.is_main_process() and output_dir:
        path = os.path.join(output_dir, "config.yaml")
        with open(path, "w") as f:
            f.write(cfg.dump())
        logger.info("Full config saved to {}".format(os.path.abspath(path)))
    return cfg


if __name__ == "__main__":
    args = get_parser().parse_args()
    logger = setup_logger(name="demo")
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)
    data_dir, splits_file = register_shapenet(cfg.DATASETS.NAME)
    cfg.DATASETS.DATA_DIR = data_dir
    cfg.DATASETS.SPLITS_FILE = splits_file
    split = args.split

    data_loader = build_data_loader(cfg, 'MeshVoxMulti', split, multigpu=False, num_workers=8)
    print('Steps in dataloader: {}'.format(len(data_loader)))

    data_dir = args.data_dir
    saved_weights_dir = args.saved_weights_dir
    saved_weights = sorted(glob.glob(saved_weights_dir+'/*.pth'))

    #wandb.init(project='MeshRCNN', config=cfg, name='eval_'+split+'_'+saved_weights_dir)
    for checkpoint_lp_model in saved_weights:
        running_correct = 0
        running_total   = 0