示例#1
0
def main():
    """Entry point for (optionally distributed) detector training.

    Parses the command line, initialises an NCCL process group when the
    launcher reports more than one process through ``WORLD_SIZE``, builds the
    global ``cfg`` from the YAML file plus trailing ``KEY VALUE`` overrides and
    freezes it, logs the environment and configuration, trains the model and,
    unless ``--skip-test`` is passed, evaluates it.
    """
    arg_parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    arg_parser.add_argument(
        "--config-file", default="", metavar="FILE", help="path to config file", type=str
    )
    arg_parser.add_argument("--local_rank", type=int, default=0)
    arg_parser.add_argument(
        "--skip-test", dest="skip_test", help="Do not test the final model", action="store_true"
    )
    arg_parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    cli = arg_parser.parse_args()

    # WORLD_SIZE is exported by the distributed launcher; when absent we run
    # as a single process.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    cli.distributed = world_size > 1
    if cli.distributed:
        torch.cuda.set_device(cli.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    # YAML first, then command-line overrides; frozen afterwards.
    cfg.merge_from_file(cli.config_file)
    cfg.merge_from_list(cli.opts)
    cfg.freeze()

    out_dir = cfg.OUTPUT_DIR
    if out_dir:
        mkdir(out_dir)

    log = setup_logger("fcos_core", out_dir, get_rank())
    log.info("Using {} GPUs".format(world_size))
    log.info(cli)
    log.info("Collecting env info (might take some time)")
    log.info("\n" + collect_env_info())
    log.info("Loaded configuration file {}".format(cli.config_file))
    with open(cli.config_file, "r") as config_handle:
        log.info("\n" + config_handle.read())
    log.info("Running with config:\n{}".format(cfg))

    trained = train(cfg, cli.local_rank, cli.distributed)
    if cli.skip_test:
        return
    run_test(cfg, trained, cli.distributed)
示例#2
0
 def prepare_detector(self):
     """Build ``self.detector`` (a ``COCODemo``) from this object's settings.

     Reads ``self.config`` (config file path), ``self.weights`` (checkpoint
     path) and ``self.THRESHOLDS_FOR_CLASSES`` (per-class score thresholds).
     The shared module-level ``cfg`` is updated in place and frozen.
     NOTE(review): because ``cfg`` ends up frozen, a second call presumably
     fails inside ``merge_from_file`` -- confirm before reusing this method.
     """
     # Model configuration: file first, no command-line overrides, then the
     # checkpoint path.
     cfg.merge_from_file(self.config)
     cfg.merge_from_list([])
     cfg.MODEL.WEIGHT = self.weights
     cfg.freeze()

     self.detector = COCODemo(
         cfg,
         min_image_size=800,
         confidence_thresholds_for_classes=self.THRESHOLDS_FOR_CLASSES,
     )
示例#3
0
)
parser.add_argument(
    "opts",
    help="Modify model config options using the command-line",
    default=None,
    nargs=argparse.REMAINDER,
)

args = parser.parse_args()

# load config from file and command-line arguments
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.MODEL.WEIGHT = args.weights

cfg.freeze()

# The following per-class thresholds are computed by maximizing
# per-class f-measure in their precision-recall curve.
# Please see compute_thresholds_for_classes() in coco_eval.py for details.
thresholds_for_classes = [
    0.4923645853996277, 0.4928510785102844, 0.5040897727012634,
    0.4912887513637543, 0.5016880631446838, 0.5278812646865845,
    0.5351834893226624, 0.5003424882888794, 0.4955945909023285,
    0.43564629554748535, 0.6089804172515869, 0.666087806224823,
    0.5932040214538574, 0.48406165838241577, 0.4062422513961792,
    0.5571075081825256, 0.5671307444572449, 0.5268378257751465,
    0.5112953186035156, 0.4647842049598694, 0.5324517488479614,
    0.5795850157737732, 0.5152440071105957, 0.5280804634094238,
    0.4791383445262909, 0.5261335372924805, 0.4906163215637207,
    0.523737907409668, 0.47027698159217834, 0.5103300213813782,
示例#4
0
def main():
    """Run the FCOS demo on every image in a directory and show the results.

    Each readable image in ``--images-dir`` is passed through the detector
    with per-class score thresholds. Its inference time is printed, and the
    annotated image is shown in a window named after the file. All windows
    stay open until a key is pressed.
    """
    cli = argparse.ArgumentParser(
        description="PyTorch Object Detection Webcam Demo")
    cli.add_argument("--config-file",
                     default="configs/fcos/fcos_R_50_FPN_1x.yaml",
                     metavar="FILE",
                     help="path to config file")
    cli.add_argument("--weights",
                     default="FCOS_R_50_FPN_1x.pth",
                     metavar="FILE",
                     help="path to the trained model")
    cli.add_argument("--images-dir",
                     default="demo/images",
                     metavar="DIR",
                     help="path to demo images directory")
    cli.add_argument("--min-image-size",
                     type=int,
                     default=800,
                     help=("Smallest size of the image to feed to the model. "
                           "Model was trained with 800, which gives best results"))
    cli.add_argument("opts",
                     help="Modify model config options using the command-line",
                     default=None,
                     nargs=argparse.REMAINDER)
    options = cli.parse_args()

    # YAML config, then command-line KEY VALUE overrides, then the weights
    # path (assigned last, so it wins over anything given in opts).
    cfg.merge_from_file(options.config_file)
    cfg.merge_from_list(options.opts)
    cfg.MODEL.WEIGHT = options.weights
    cfg.freeze()

    # Per-class thresholds obtained by maximising each class's f-measure on
    # its precision-recall curve; see compute_thresholds_for_classes() in
    # coco_eval.py.
    thresholds_for_classes = [
        0.49211737513542175, 0.49340692162513733, 0.510103702545166,
        0.4707475006580353, 0.5197340250015259, 0.5007652044296265,
        0.5611110329627991, 0.4639902412891388, 0.4778415560722351,
        0.43332818150520325, 0.6180170178413391, 0.5248752236366272,
        0.5437473654747009, 0.5153843760490417, 0.4194680452346802,
        0.5640717148780823, 0.5087228417396545, 0.5021755695343018,
        0.5307778716087341, 0.4920770823955536, 0.5202335119247437,
        0.5715234279632568, 0.5089765191078186, 0.5422378778457642,
        0.45138806104660034, 0.49631351232528687, 0.4388565421104431,
        0.47193753719329834, 0.47037890553474426, 0.4791252017021179,
        0.45699411630630493, 0.48658522963523865, 0.4580649137496948,
        0.4603237509727478, 0.5243804454803467, 0.5235602855682373,
        0.48501554131507874, 0.5173789858818054, 0.4978085160255432,
        0.4626562297344208, 0.48144686222076416, 0.4889853894710541,
        0.4749937951564789, 0.42273756861686707, 0.47836390137672424,
        0.48752328753471375, 0.44069987535476685, 0.4241463541984558,
        0.5228247046470642, 0.4834112524986267, 0.4538525640964508,
        0.4730372428894043, 0.471712201833725, 0.5180512070655823,
        0.4671719968318939, 0.46602892875671387, 0.47536996006965637,
        0.487352192401886, 0.4771934747695923, 0.45533207058906555,
        0.43941256403923035, 0.5910647511482239, 0.554875910282135,
        0.49752360582351685, 0.6263655424118042, 0.4964958727359772,
        0.5542593002319336, 0.5049241185188293, 0.5306999087333679,
        0.5279538035392761, 0.5708096623420715, 0.524990975856781,
        0.5187852382659912, 0.41242220997810364, 0.5409807562828064,
        0.48504579067230225, 0.47305455803871155, 0.4814004898071289,
        0.42680642008781433, 0.4143834114074707
    ]

    image_names = os.listdir(options.images_dir)

    # Wraps inference and draws the predictions onto the input image.
    demo = COCODemo(cfg,
                    confidence_thresholds_for_classes=thresholds_for_classes,
                    min_image_size=options.min_image_size)

    for name in image_names:
        frame = cv2.imread(os.path.join(options.images_dir, name))
        if frame is None:
            continue  # OpenCV could not read this entry as an image
        tic = time.time()
        rendered = demo.run_on_opencv_image(frame)
        print("{}\tinference time: {:.2f}s".format(name, time.time() - tic))
        cv2.imshow(name, rendered)

    print("Press any keys to exit ...")
    cv2.waitKey()
    cv2.destroyAllWindows()
示例#5
0
def main():
    """Run the detector on live webcam frames and display the annotated result.

    Frames are read from camera 0, passed through ``COCODemo`` and shown in
    the "COCO detections" window with the per-frame time printed. The loop
    ends when ESC is pressed or the camera stops delivering frames. The
    camera and all windows are released even if inference raises.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Webcam Demo")
    parser.add_argument(
        "--config-file",
        default="../configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.7,
        help="Minimum score for the prediction to be shown",
    )
    parser.add_argument(
        "--min-image-size",
        type=int,
        default=224,
        help="Smallest size of the image to feed to the model. "
        "Model was trained with 800, which gives best results",
    )
    parser.add_argument(
        "--show-mask-heatmaps",
        dest="show_mask_heatmaps",
        help="Show a heatmap probability for the top masks-per-dim masks",
        action="store_true",
    )
    parser.add_argument(
        "--masks-per-dim",
        type=int,
        default=2,
        help="Number of heatmaps per dimension to show",
    )
    parser.add_argument(
        "opts",
        help="Modify model config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    # load config from file and command-line arguments
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # prepare object that handles inference plus adds predictions on top of image
    coco_demo = COCODemo(
        cfg,
        confidence_threshold=args.confidence_threshold,
        show_mask_heatmaps=args.show_mask_heatmaps,
        masks_per_dim=args.masks_per_dim,
        min_image_size=args.min_image_size,
    )

    cam = cv2.VideoCapture(0)
    try:
        while True:
            start_time = time.time()
            ret_val, img = cam.read()
            if not ret_val:
                # No frame was grabbed (camera missing, busy or disconnected).
                # ``img`` is None in that case and would crash the model.
                print("Failed to read a frame from the camera, exiting ...")
                break
            composite = coco_demo.run_on_opencv_image(img)
            print("Time: {:.2f} s / img".format(time.time() - start_time))
            cv2.imshow("COCO detections", composite)
            if cv2.waitKey(1) == 27:
                break  # esc to quit
    finally:
        # Release the device and close windows on every exit path.
        cam.release()
        cv2.destroyAllWindows()
示例#6
0
def main():
    """Time FCOS inference over every image in a directory.

    Each readable image in ``--images-dir`` is run through the detector with
    per-class score thresholds. A running count of processed images and that
    image's inference time are printed. Nothing is displayed or saved.
    """
    ap = argparse.ArgumentParser(description="PyTorch Object Detection Webcam Demo")
    ap.add_argument("--config-file",
                    default="configs/fcos/fcos_syncbn_bs32_c128_MNV2_FPN_1x.yaml",
                    metavar="FILE",
                    help="path to config file")
    ap.add_argument("--weights",
                    default="FCOS_syncbn_bs32_c128_MNV2_FPN_1x.pth",
                    metavar="FILE",
                    help="path to the trained model")
    ap.add_argument("--images-dir",
                    default="demo/images",
                    metavar="DIR",
                    help="path to demo images directory")
    ap.add_argument("--min-image-size",
                    type=int,
                    default=800,
                    help=("Smallest size of the image to feed to the model. "
                          "Model was trained with 800, which gives best results"))
    ap.add_argument("opts",
                    help="Modify model config options using the command-line",
                    default=None,
                    nargs=argparse.REMAINDER)
    parsed = ap.parse_args()

    # YAML config, then command-line overrides, then the weights path.
    cfg.merge_from_file(parsed.config_file)
    cfg.merge_from_list(parsed.opts)
    cfg.MODEL.WEIGHT = parsed.weights
    cfg.freeze()

    # Per-class thresholds obtained by maximising each class's f-measure on
    # its precision-recall curve; see compute_thresholds_for_classes() in
    # coco_eval.py.
    thresholds_for_classes = [
        0.4923645853996277, 0.4928510785102844, 0.5040897727012634,
        0.4912887513637543, 0.5016880631446838, 0.5278812646865845,
        0.5351834893226624, 0.5003424882888794, 0.4955945909023285,
        0.43564629554748535, 0.6089804172515869, 0.666087806224823,
        0.5932040214538574, 0.48406165838241577, 0.4062422513961792,
        0.5571075081825256, 0.5671307444572449, 0.5268378257751465,
        0.5112953186035156, 0.4647842049598694, 0.5324517488479614,
        0.5795850157737732, 0.5152440071105957, 0.5280804634094238,
        0.4791383445262909, 0.5261335372924805, 0.4906163215637207,
        0.523737907409668, 0.47027698159217834, 0.5103300213813782,
        0.4645252823829651, 0.5384289026260376, 0.47796186804771423,
        0.4403403103351593, 0.5101461410522461, 0.5535093545913696,
        0.48472103476524353, 0.5006796717643738, 0.5485560894012451,
        0.4863888621330261, 0.5061569809913635, 0.5235867500305176,
        0.4745445251464844, 0.4652363359928131, 0.4162440598011017,
        0.5252017974853516, 0.42710989713668823, 0.4550687372684479,
        0.4943239390850067, 0.4810051918029785, 0.47629663348197937,
        0.46629616618156433, 0.4662836790084839, 0.4854755401611328,
        0.4156557023525238, 0.4763634502887726, 0.4724511504173279,
        0.4915047585964203, 0.5006274580955505, 0.5124194622039795,
        0.47004589438438416, 0.5374764204025269, 0.5876904129981995,
        0.49395060539245605, 0.5102297067642212, 0.46571290493011475,
        0.5164387822151184, 0.540651798248291, 0.5323763489723206,
        0.5048757195472717, 0.5302401781082153, 0.48333442211151123,
        0.5109739303588867, 0.4077408015727997, 0.5764586925506592,
        0.5109297037124634, 0.4685552418231964, 0.5148998498916626,
        0.4224434792995453, 0.4998510777950287
    ]

    image_files = os.listdir(parsed.images_dir)

    # Wraps inference and draws predictions; only the timing is used here.
    demo = COCODemo(cfg,
                    confidence_thresholds_for_classes=thresholds_for_classes,
                    min_image_size=parsed.min_image_size)

    n_done = 0  # images successfully processed so far
    for file_name in image_files:
        image = cv2.imread(os.path.join(parsed.images_dir, file_name))
        if image is None:
            continue
        t0 = time.time()
        demo.run_on_opencv_image(image)
        n_done += 1
        print(n_done)
        print("{}\tinference time: {:.2f}s".format(file_name, time.time() - t0))
示例#7
0
def main():
    """Evaluate every periodic checkpoint of a finished training run.

    ``--run-dir`` must contain ``new_config.yml`` and the checkpoints to test.
    Checkpoints are files whose name contains ``model_`` and ends in
    ``00.pth``. For each checkpoint, inference on the first test dataset
    (``cfg.DATASETS.TEST[0]``) is written to
    ``<run-dir>/run_res_tidy/<dataset>_<checkpoint>/``. A summary is logged
    through the ``fcos_core`` logger rooted at ``<run-dir>/run_res_tidy``.

    Raises:
        FileNotFoundError: if ``new_config.yml`` is missing from ``--run-dir``.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument(
        "--run-dir",
        default="run/fcos_imprv_R_50_FPN_1x/Baseline_lr1en4_191209",
        metavar="DIR",
        help="path to the training run directory (must contain new_config.yml)",
        type=str,
    )
    args = parser.parse_args()

    target_dir = args.run_dir
    # Use filesystem checks rather than comparing hand-built strings against
    # glob() results. glob normalises a trailing slash in --run-dir, so the
    # string comparison used to fail for perfectly valid directories.
    cfg_file = os.path.join(target_dir, 'new_config.yml')
    if not os.path.isfile(cfg_file):
        raise FileNotFoundError(
            "Error! No cfg file found! check if the dir is right.")

    dir_files = sorted(glob.glob(os.path.join(target_dir, '*')))
    # Match on the file name only, so a 'model_' elsewhere in the path does
    # not select unrelated *.pth files.
    model_files = [
        f for f in dir_files
        if f.endswith('00.pth') and 'model_' in os.path.basename(f)
    ]

    tidy_dir = os.path.join(target_dir, 'run_res_tidy')
    if os.path.exists(tidy_dir):
        # Results from an earlier evaluation already exist. Pause in the
        # debugger so the operator can inspect them and decide whether to
        # continue (existing result folders are then reused).
        import pdb
        pdb.set_trace()
    else:
        os.makedirs(tidy_dir)

    cfg.merge_from_file(cfg_file)
    cfg.freeze()

    logger = setup_logger("fcos_core", tidy_dir, 0, filename="test_log.txt")
    logger.info(cfg)

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)
    # NOTE(review): save_dir is the results folder, not the run dir,
    # presumably so load() does not replace the requested file with the run's
    # "last checkpoint" -- confirm against Checkpointer.load.
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=tidy_dir)

    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )

    # Only the first test dataset / loader is evaluated.
    dataset_names = cfg.DATASETS.TEST
    data_loaders_val = make_data_loader(cfg,
                                        is_train=False,
                                        is_distributed=False)
    dataset_name = dataset_names[0]
    data_loader_val = data_loaders_val[0]

    for i, model_f in enumerate(model_files, start=1):
        checkpointer.load(model_f)
        ckpt_name = os.path.splitext(os.path.basename(model_f))[0]
        output_folder = os.path.join(tidy_dir, dataset_name + '_' + ckpt_name)
        # exist_ok: after continuing from the debugger above, folders from the
        # previous evaluation may already be present.
        os.makedirs(output_folder, exist_ok=True)
        logger.info('Processing {}/{}: {}'.format(i, len(model_files),
                                                  output_folder))
        inference_result = inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else
            cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        # NOTE(review): relies on the structure of inference()'s return value;
        # the meaning of index [2][0] is not visible here.
        summaryStrs = get_neat_inference_result(inference_result[2][0])
        logger.info(os.path.basename(output_folder))
        logger.info('\n'.join(summaryStrs))
示例#8
0
def main():
    """Command-line entry point for training an FCOS detector.

    Parses the CLI and sets up distributed training when the launcher
    exported ``WORLD_SIZE`` > 1. Builds and freezes the global ``cfg``, logs
    the run environment and configuration, trains, and evaluates unless
    ``--skip-test`` is passed.
    """
    # Parse command-line arguments, e.g.
    #   --config-file configs/fcos/fcos_imprv_R_50_FPN_1x.yaml
    opts_parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    opts_parser.add_argument("--config-file",
                             default="",
                             metavar="FILE",
                             help="path to config file",
                             type=str)
    # Passed in by torch.distributed.launch: index of the GPU this process uses.
    opts_parser.add_argument("--local_rank", type=int, default=0)
    opts_parser.add_argument("--skip-test",
                             dest="skip_test",
                             help="Do not test the final model",
                             action="store_true")
    opts_parser.add_argument("opts",
                             help="Modify config options using the command-line",
                             default=None,
                             nargs=argparse.REMAINDER)
    parsed = opts_parser.parse_args()

    # WORLD_SIZE is set by torch.distributed.launch (nproc_per_node * number
    # of nodes). More than one process switches on distributed training.
    if "WORLD_SIZE" in os.environ:
        n_gpus = int(os.environ["WORLD_SIZE"])
    else:
        n_gpus = 1
    parsed.distributed = n_gpus > 1

    if parsed.distributed:
        torch.cuda.set_device(parsed.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    # Defaults live in fcos_core/config/defaults.py. The YAML file and then the
    # command-line opts override them. Freezing guards against accidental later
    # edits; cfg is then handed to train().
    cfg.merge_from_file(parsed.config_file)
    cfg.merge_from_list(parsed.opts)
    cfg.freeze()

    # Output folder for logs and other run artifacts.
    run_dir = cfg.OUTPUT_DIR
    if run_dir:
        mkdir(run_dir)

    # Log GPU count, parsed arguments, system environment and configuration.
    run_logger = setup_logger("fcos_core", run_dir, get_rank())
    run_logger.info("Using {} GPUs".format(n_gpus))
    run_logger.info(parsed)
    run_logger.info("Collecting env info (might take some time)")
    run_logger.info("\n" + collect_env_info())
    run_logger.info("Loaded configuration file {}".format(parsed.config_file))
    with open(parsed.config_file, "r") as yaml_file:
        yaml_text = "\n" + yaml_file.read()
    run_logger.info(yaml_text)
    run_logger.info("Running with config:\n{}".format(cfg))

    # train() builds the model and runs training.
    net = train(cfg, parsed.local_rank, parsed.distributed)

    if not parsed.skip_test:
        run_test(cfg, net, parsed.distributed)
示例#9
0
def main():
    """Train a detector and, unless ``--skip-test`` is given, evaluate it.

    Trailing ``KEY VALUE`` pairs on the command line override the YAML config.
    When ``torch.distributed.launch`` exports ``WORLD_SIZE`` > 1, an NCCL
    process group is initialised and this process binds to ``--local_rank``.
    """
    # (flags, keyword arguments) for each CLI option, registered in this order.
    argument_specs = (
        (("--config-file",), dict(default="",
                                  metavar="FILE",
                                  help="path to config file",
                                  type=str)),
        # Supplied by torch.distributed.launch: GPU index for this process.
        (("--local_rank",), dict(type=int, default=0)),
        (("--skip-test",), dict(dest="skip_test",
                                help="Do not test the final model",
                                action="store_true")),
        # Everything left on the command line is collected into one list.
        (("opts",), dict(help="Modify config options using the command-line",
                         default=None,
                         nargs=argparse.REMAINDER)),
    )
    arg_parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    for flags, options in argument_specs:
        arg_parser.add_argument(*flags, **options)
    ns = arg_parser.parse_args()

    # WORLD_SIZE (nproc_per_node * number of nodes) comes from the launcher;
    # without it we run as a single process.
    world_size = os.environ.get("WORLD_SIZE")
    num_gpus = 1 if world_size is None else int(world_size)
    ns.distributed = num_gpus > 1

    if ns.distributed:
        torch.cuda.set_device(ns.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    # Package defaults <- YAML file <- command-line opts. Freeze so nothing
    # downstream (e.g. train()) mutates cfg by accident.
    cfg.merge_from_file(ns.config_file)
    cfg.merge_from_list(ns.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)  # holds logs and other run artifacts

    # Log GPU count, system environment and configuration.
    logger = setup_logger("fcos_core", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(ns)
    logger.info("Collecting env info (might take some time)")
    env_report = collect_env_info()
    logger.info("\n" + env_report)
    logger.info("Loaded configuration file {}".format(ns.config_file))
    with open(ns.config_file, "r") as config_fp:
        raw_yaml = config_fp.read()
    logger.info("\n" + raw_yaml)
    logger.info("Running with config:\n{}".format(cfg))

    trained_model = train(cfg, ns.local_rank, ns.distributed)
    if ns.skip_test:
        return
    run_test(cfg, trained_model, ns.distributed)
示例#10
0
def main():
    """Export the FCOS backbone and head to an ONNX file.

    Loads the config (YAML plus command-line overrides), builds the model and
    loads ``cfg.MODEL.WEIGHT``. It then traces ``backbone -> rpn.head`` with a
    dummy 1x3x800x1216 input and writes the graph to ``--output``.
    Post-processing (box decoding, NMS) is not part of the exported graph.
    """
    parser = argparse.ArgumentParser(
        description="Export model to the onnx format")
    parser.add_argument(
        "--config-file",
        default="configs/fcos/fcos_imprv_R_50_FPN_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--output",
        default="fcos.onnx",
        metavar="FILE",
        help="path to the output onnx file",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    assert cfg.MODEL.FCOS_ON, "This script is only tested for the detector FCOS."

    save_dir = ""
    logger = setup_logger("fcos_core", save_dir, get_rank())
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
    _ = checkpointer.load(cfg.MODEL.WEIGHT)

    # Trace in inference mode so normalisation/dropout layers behave as at
    # test time, independent of the exporter's default mode.
    model.eval()

    onnx_model = torch.nn.Sequential(
        OrderedDict([
            ('backbone', model.backbone),
            ('heads', model.rpn.head),
        ]))

    input_names = ["input_image"]
    dummy_input = torch.zeros((1, 3, 800, 1216)).to(cfg.MODEL.DEVICE)

    # NOTE(review): assumes the FCOS head returns (logits, bbox_reg,
    # centerness), each a list with one tensor per FPN level (P3, P4, ...).
    # torch.onnx flattens nested outputs in order, so the exported outputs are
    # grouped by kind (all logits, then all bbox_reg, then all centerness),
    # not interleaved per level. Name them in that same order.
    num_levels = len(cfg.MODEL.FCOS.FPN_STRIDES)
    output_names = [
        "P{}/{}".format(3 + level, kind)
        for kind in ("logits", "bbox_reg", "centerness")
        for level in range(num_levels)
    ]

    torch.onnx.export(onnx_model,
                      dummy_input,
                      args.output,
                      verbose=True,
                      input_names=input_names,
                      output_names=output_names,
                      keep_initializers_as_inputs=True)

    logger.info("Done. The onnx model is saved into {}.".format(args.output))
示例#11
0
def main():
    """Train a detector using one process over all visible GPUs (no torch.distributed).

    The number of GPUs is taken from ``CUDA_VISIBLE_DEVICES``, falling back
    to ``torch.cuda.device_count()`` when that variable is unset. The model
    is trained via ``train`` with the resulting device ids and, unless
    ``--skip-test`` is passed, evaluated.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    # NOTE: the parsed value is always overwritten below from the visible-GPU
    # count. (type=list would split a CLI string into single characters.)
    parser.add_argument("--device_ids", type=list, default=[0])
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "--use-tensorboard",
        dest="use_tensorboard",
        help="Use tensorboardX logger (Requires tensorboardX installed)",
        action="store_true",
        default=False)

    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    # One device id per visible GPU. CUDA_VISIBLE_DEVICES is optional; when it
    # is unset, count the devices torch can see instead of raising KeyError.
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        num_gpus = torch.cuda.device_count()
    else:
        num_gpus = len(visible_devices.split(","))
    args.device_ids = list(map(str, range(num_gpus)))

    # do not use torch.distributed; the branch below is kept so it can be
    # re-enabled by flipping this flag.
    args.distributed = False

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("fad_core", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg,
                  args.local_rank,
                  args.distributed,
                  args.device_ids,
                  use_tensorboard=args.use_tensorboard)

    if not args.skip_test:
        run_test(cfg, model, args.distributed)
示例#12
0
def main():
    """Run the EmbedMask demo over a directory of images and save the renders.

    Every readable image in ``--images-dir`` is run through the detector with
    the per-class thresholds below. The annotated result is written under the
    same file name into ``--out-dir``, which is created if it does not exist.
    A message is printed for any image that could not be written.
    """
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Webcam Demo")
    parser.add_argument(
        "--config-file",
        default="configs/embed_mask/embed_mask_R50_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--weights",
        default="models/embed_mask_R50_1x.pth",
        metavar="FILE",
        help="path to the trained model",
    )
    parser.add_argument(
        "--images-dir",
        default="demo/images",
        metavar="DIR",
        help="path to demo images directory",
    )
    parser.add_argument(
        "--out-dir",
        default="demo/output",
        metavar="DIR",
        help="path to the directory where rendered images are written",
    )
    parser.add_argument(
        "--min-image-size",
        type=int,
        default=800,
        help="Smallest size of the image to feed to the model. "
            "Model was trained with 800, which gives best results",
    )
    parser.add_argument(
        "opts",
        help="Modify model config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    # load config from file and command-line arguments
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.MODEL.WEIGHT = args.weights

    cfg.freeze()

    # The following per-class thresholds are computed by maximizing
    # per-class f-measure in their precision-recall curve.
    # Please see compute_thresholds_for_classes() in coco_eval.py for details.
    thresholds_for_classes = [
        0.24445024132728577, 0.2556260824203491, 0.2336651235818863, 0.26643890142440796, 0.22829005122184753,
        0.27605465054512024, 0.29680299758911133, 0.24539557099342346, 0.22566702961921692, 0.21125544607639313,
        0.3632965385913849, 0.42116600275039673, 0.29700127243995667, 0.2278410643339157, 0.2317150980234146,
        0.30244436860084534, 0.32276564836502075, 0.25707629323005676, 0.24852260947227478, 0.24491029977798462,
        0.2518414556980133, 0.35320255160331726, 0.2866332232952118, 0.2207552194595337, 0.2568267285823822,
        0.24461865425109863, 0.20570527017116547, 0.2656995356082916, 0.21232444047927856, 0.2799481451511383,
        0.18180416524410248, 0.2654014825820923, 0.262266606092453, 0.19924932718276978, 0.22213412821292877,
        0.3075449764728546, 0.2290934920310974, 0.2963321805000305, 0.23535756766796112, 0.2430417388677597,
        0.22808006405830383, 0.2716907560825348, 0.21096138656139374, 0.18565504252910614, 0.17213594913482666,
        0.2755044996738434, 0.22538238763809204, 0.22792285680770874, 0.24877801537513733, 0.23092558979988098,
        0.23993775248527527, 0.21917308866977692, 0.2535002529621124, 0.30203622579574585, 0.19476301968097687,
        0.24782243371009827, 0.22699865698814392, 0.25022363662719727, 0.23006463050842285, 0.22317998111248016,
        0.20648975670337677, 0.28253015875816345, 0.35304051637649536, 0.2882220447063446, 0.2875506281852722,
        0.21613512933254242, 0.308322936296463, 0.29409125447273254, 0.3021804690361023, 0.273112416267395,
        0.23458659648895264, 0.2998719811439514, 0.2715963125228882, 0.1898047924041748, 0.32565683126449585,
        0.25560101866722107, 0.265905499458313, 0.3087238669395447, 0.2053961306810379, 0.20331673324108124
    ]

    demo_im_names = os.listdir(args.images_dir)

    # cv2.imwrite does not create missing directories; it just returns False.
    # Create the output folder up front so results are actually saved.
    os.makedirs(args.out_dir, exist_ok=True)

    # prepare object that handles inference plus adds predictions on top of image
    coco_demo = COCODemo(
        cfg,
        confidence_thresholds_for_classes=thresholds_for_classes,
        min_image_size=args.min_image_size
    )

    for im_name in demo_im_names:
        img = cv2.imread(os.path.join(args.images_dir, im_name))
        if img is None:
            continue
        start_time = time.time()
        composite = coco_demo.run_on_opencv_image(img)
        print("{}\tinference time: {:.2f}s".format(im_name, time.time() - start_time))
        out_path = os.path.join(args.out_dir, im_name)
        if not cv2.imwrite(out_path, composite):
            print("Failed to write {}".format(out_path))
    # Nothing waits for a key here, so report where the results went instead.
    print("Done. Results are saved in {}".format(args.out_dir))
示例#13
0
def _recreate_dir(path):
    """Delete ``path`` if it exists, then create it again, empty.

    ``os.makedirs`` is used so that missing parent directories (for example a
    fresh ``save_base_path``) are created as well.
    """
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def _load_gt_boxes(anno_path, w, h):
    """Parse a VOC-style XML annotation into integer ``[x1, y1, x2, y2]`` boxes.

    Boxes are rescaled from the annotation's ``<size>`` to an image of size
    ``(w, h)``. When ``<size>`` is missing, unparsable or reports a
    non-positive dimension, the boxes are assumed to already be in image
    coordinates. Objects without a ``<bndbox>`` are skipped. If an object has
    several, the last one is used. Coordinates are parsed as floats, so
    fractional annotation values are accepted.
    """
    root = ET.parse(anno_path).getroot()

    ref_w, ref_h = w, h
    sizes = root.findall('size')
    if sizes:
        try:
            ann_w = float(sizes[-1].findtext('width'))
            ann_h = float(sizes[-1].findtext('height'))
        except (TypeError, ValueError):
            # missing or empty <width>/<height>: fall back to the image size
            ann_w = ann_h = 0.0
        if ann_w > 0 and ann_h > 0:
            ref_w, ref_h = ann_w, ann_h

    gt_boxes = []
    for obj in root.findall('object'):
        bndboxes = obj.findall('bndbox')
        if not bndboxes:
            continue
        bndbox = bndboxes[-1]
        xmin, ymin, xmax, ymax = (float(bndbox.findtext(tag))
                                  for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        gt_boxes.append([int(xmin * w / ref_w),
                         int(ymin * h / ref_h),
                         int(xmax * w / ref_w),
                         int(ymax * h / ref_h)])
    return gt_boxes


def _match_gt_to_predictions(gt_boxes, pred_boxes, iou_thresh=0.3):
    """Greedily match ground-truth boxes to predicted boxes by IoU.

    Each ground-truth box, in order, consumes the first still-unmatched
    prediction whose IoU with it exceeds ``iou_thresh``. Class labels are not
    compared.

    Returns:
        ``(missed_gt, remaining)``. ``missed_gt`` lists the ground-truth boxes
        that found no match. ``remaining`` lists the indices into
        ``pred_boxes`` of predictions that were never matched, in their
        original order.
    """
    remaining = list(range(len(pred_boxes)))
    missed_gt = []
    for gt_box in gt_boxes:
        for pos, pred_idx in enumerate(remaining):
            _, _, _, iou_score = get_iou(gt_box, pred_boxes[pred_idx])
            if iou_score > iou_thresh:
                del remaining[pos]  # safe: we break right after mutating
                break
        else:
            missed_gt.append(gt_box)
    return missed_gt, remaining


def run_accCal(model_path,
               test_base_path,
               save_base_path,
               labels_dict,
               config_file,
               input_size=640,
               confidence_thresholds=(0.3, )):
    """Evaluate a detector image-by-image on a VOC-style test set.

    For every ``*.jpg`` in ``<test_base_path>/VOC2007/JPEGImages`` the model
    predicts boxes. For each confidence threshold, predictions are kept when
    their score exceeds the threshold and their area is non-zero, and the
    survivors are passed through ``soft_nms``. Ground-truth boxes from the
    matching XML in ``VOC2007/Annotations`` are then greedily matched to the
    predictions at IoU > 0.3 (labels are not compared).

    Per threshold, each image counts as:

    * a recall miss if any ground-truth box is left unmatched;
    * a precision miss if any prediction is left unmatched;
    * accurate only if neither happens.

    These per-image rates (not mAP) are printed together with the mean FPS.
    For the first threshold only, visual dumps are written below
    ``save_base_path``:

    * ``recall/`` holds frames with missed GT boxes drawn;
    * ``ero/`` holds crops of false positives plus the annotated frame;
    * ``ori/`` holds copies of offending images and their XML.

    ``all/`` is recreated but nothing is written to it.

    Args:
        model_path: checkpoint path, stored in ``cfg.MODEL.WEIGHT``.
        test_base_path: directory that contains ``VOC2007/``.
        save_base_path: output root. Its ``all``, ``recall``, ``ero`` and
            ``ori`` subdirectories are deleted and recreated.
        labels_dict: unused; kept for backward compatibility.
        config_file: YAML file merged into the global ``cfg``.
        input_size: forwarded to ``T.Resize``.
        confidence_thresholds: score thresholds to evaluate.

    Note:
        Mutates and freezes the global ``cfg``. Presumably a second call in the
        same process fails because ``cfg`` is frozen; confirm.
    """
    save_res_path = os.path.join(save_base_path, 'all')
    save_recall_path = os.path.join(save_base_path, 'recall')
    save_ero_path = os.path.join(save_base_path, 'ero')
    save_ori_path = os.path.join(save_base_path, 'ori')
    for out_dir in (save_res_path, save_recall_path, save_ero_path, save_ori_path):
        _recreate_dir(out_dir)

    test_img_path = os.path.join(test_base_path, 'VOC2007/JPEGImages')
    test_ano_path = os.path.join(test_base_path, 'VOC2007/Annotations')
    # sorted for a reproducible processing order
    img_list = sorted(glob.glob(os.path.join(test_img_path, '*.jpg')))
    if not img_list:
        # every rate below is a ratio over len(img_list)
        print("no *.jpg images found in {}".format(test_img_path))
        return

    cfg.merge_from_file(config_file)
    cfg.MODEL.WEIGHT = model_path
    cfg.TEST.IMS_PER_BATCH = 1  # only test single image
    cfg.freeze()

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.load(cfg.MODEL.WEIGHT)
    model.eval()

    normalize_transform = T.Normalize(
        mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
    )
    transform = T.Compose(
        [
            T.ToPILImage(),
            T.Resize(input_size),
            T.ToTensor(),
            T.Lambda(lambda x: x * 255),
            normalize_transform,
        ]
    )

    num_images = len(img_list)
    sad_accuracy = [0] * len(confidence_thresholds)
    sad_precision = [0] * len(confidence_thresholds)
    sad_recall = [0] * len(confidence_thresholds)
    spend_time = []
    for img_idx, img_name in enumerate(img_list):
        progress(int(img_idx / num_images * 100))
        base_img_name = os.path.basename(img_name)
        stem = os.path.splitext(base_img_name)[0]
        frame = cv2.imread(img_name)
        ori_frame = frame.copy()  # untouched copy used for false-positive crops

        h, w = frame.shape[:2]
        image = transform(frame)
        image_list = to_image_list(image, cfg.DATALOADER.SIZE_DIVISIBILITY)
        image_list = image_list.to(cfg.MODEL.DEVICE)

        # timed span: forward pass plus moving the result to the CPU
        start_time = time.time()
        with torch.no_grad():
            predictions = model(image_list)
        prediction = predictions[0].to("cpu")
        end_time = time.time()
        spend_time.append(end_time - start_time)

        # back to original-image coordinates, sorted by descending score
        prediction = prediction.resize((w, h)).convert("xyxy")
        _, order = prediction.get_field("scores").sort(0, descending=True)
        prediction = prediction[order]
        scores = prediction.get_field("scores").numpy()
        labels = prediction.get_field("labels").numpy()
        bboxes = prediction.bbox.numpy().astype(np.int32)
        bboxes_area = (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1])

        # the annotation does not depend on the threshold: parse it once
        xml_name = stem + '.xml'
        anno_path = os.path.join(test_ano_path, xml_name)
        gt_boxes = _load_gt_boxes(anno_path, w, h)

        for ii, confidence_threshold in enumerate(confidence_thresholds):
            keep = (scores > confidence_threshold) & (bboxes_area > 0)
            _labels, _bboxes, _scores = soft_nms(labels[keep].tolist(),
                                                 bboxes[keep].tolist(),
                                                 scores[keep].tolist(),
                                                 confidence_threshold)
            draw = ii == 0  # visual dumps are produced for the first threshold only

            if draw:
                for b, label, score in zip(_bboxes, _labels, _scores):
                    frame = cv2.rectangle(frame,
                                          (b[0], b[1]), (b[2], b[3]),
                                          (100, 220, 200), 2)
                    frame = cv2.putText(frame,
                                        str(label) + '-' + str(int(score * 100)),
                                        (b[0], b[1]), 1, 1,
                                        (0, 0, 255), 1)

            missed_gt, unmatched = _match_gt_to_predictions(gt_boxes, _bboxes)
            unmatched_boxes = [_bboxes[i] for i in unmatched]

            # an unmatched ground-truth box is a miss: counts against recall
            if missed_gt:
                sad_recall[ii] += 1
                if draw:
                    for x1, y1, x2, y2 in missed_gt:
                        frame = cv2.rectangle(frame,
                                              (int(x1), int(y1)), (int(x2), int(y2)),
                                              (255, 0, 0), 4)
                    cv2.imwrite(os.path.join(save_recall_path, base_img_name), frame)
                    shutil.copy(img_name, os.path.join(save_ori_path, base_img_name))
                    shutil.copy(anno_path, os.path.join(save_ori_path, xml_name))

            # a leftover prediction has no ground truth: counts against precision
            if unmatched_boxes:
                sad_precision[ii] += 1
                if draw:
                    for box_idx, (x1, y1, x2, y2) in enumerate(unmatched_boxes):
                        frame = cv2.rectangle(frame,
                                              (int(x1), int(y1)), (int(x2), int(y2)),
                                              (0, 0, 255), 4)
                        # clamp at 0: negative slice starts would wrap around
                        crop = ori_frame[max(y1, 0): max(y2, 0), max(x1, 0): max(x2, 0), :]
                        if crop.size:  # cv2.imwrite cannot write an empty image
                            err_rect_name = stem + '_' + str(box_idx) + '.jpg'
                            cv2.imwrite(os.path.join(save_ero_path, err_rect_name), crop)
                    cv2.imwrite(os.path.join(save_ero_path, base_img_name), frame)
                    shutil.copy(img_name, os.path.join(save_ori_path, base_img_name))
                    shutil.copy(anno_path, os.path.join(save_ori_path, xml_name))

            if not missed_gt and not unmatched_boxes:
                sad_accuracy[ii] += 1

    # An image is correct only if every box is detected. One missing box is a
    # recall miss and one extra box is a precision miss; mAP is not computed.
    print('\nfps is : ', 1 / np.average(spend_time))
    for ii, confidence_threshold in enumerate(confidence_thresholds):
        print("confidence th is : {}".format(confidence_threshold))
        accuracy = sad_accuracy[ii] / num_images
        print("accuracy is : {}".format(accuracy))
        precision = 1 - sad_precision[ii] / num_images
        print("precision is : {}".format(precision))
        recall = 1 - sad_recall[ii] / num_images
        print("recall is : {}\n".format(recall))
示例#14
0
def main():
    """Evaluate a trained detector on every dataset in ``cfg.DATASETS.TEST``.

    The config is read from ``--config-file``, followed by trailing
    ``KEY VALUE`` overrides. When the ``WORLD_SIZE`` environment variable is
    greater than 1, this process binds to the GPU given by ``--local_rank`` and
    joins an NCCL process group. The checkpoint named by ``cfg.MODEL.WEIGHT``
    is loaded, then ``inference`` runs once per test dataset. When
    ``cfg.OUTPUT_DIR`` is set, results go to
    ``<OUTPUT_DIR>/inference/<dataset>``.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Inference")
    parser.add_argument(
        "--config-file",
        default=
        "/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    save_dir = ""
    logger = setup_logger("fcos_core", save_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
    _ = checkpointer.load(cfg.MODEL.WEIGHT)

    # "segm" is always evaluated. Presumably this is for mask heads such as
    # SipMask that do not set MODEL.MASK_ON; confirm. It is listed exactly
    # once, so MASK_ON configs no longer evaluate segmentation twice.
    iou_types = ("bbox", "segm")
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )

    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                         dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder

    # cfg is frozen, so box_only is loop-invariant. It is RPN_ONLY unless an
    # FCOS, SipMask or RetinaNet head is enabled.
    box_only = False if (cfg.MODEL.FCOS_ON or cfg.MODEL.SIPMASK_ON
                         or cfg.MODEL.RETINANET_ON) else cfg.MODEL.RPN_ONLY

    data_loaders_val = make_data_loader(cfg,
                                        is_train=False,
                                        is_distributed=distributed)
    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=box_only,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
示例#15
0
def main():
    """Run the detector on every image in a directory and save annotated copies.

    The config is read from ``--config-file``, followed by trailing
    ``KEY VALUE`` overrides. ``cfg.MODEL.WEIGHT`` is then set from
    ``--weights``. Each file in ``--images-dir`` that OpenCV can decode is
    rendered with ``COCODemo`` using the per-class score thresholds below. The
    result is written under the same file name to ``--out-dir``, which is
    created if needed. The inference time of each image is printed.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Webcam Demo")
    parser.add_argument(
        "--config-file",
        default="configs/embed_mask/embed_mask_R50_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--weights",
        default="models/embed_mask_R50_1x.pth",
        metavar="FILE",
        help="path to the trained model",
    )
    parser.add_argument(
        "--images-dir",
        default="demo/images",
        metavar="DIR",
        help="path to demo images directory",
    )
    parser.add_argument(
        "--out-dir",
        default="demo/output",
        metavar="DIR",
        help="directory where the rendered images are written",
    )
    parser.add_argument(
        "--min-image-size",
        type=int,
        default=800,
        help="Smallest size of the image to feed to the model. "
        "Model was trained with 800, which gives best results",
    )
    parser.add_argument(
        "opts",
        help="Modify model config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    # load config from file and command-line arguments
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.MODEL.WEIGHT = args.weights

    cfg.freeze()

    # The following per-class thresholds are computed by maximizing
    # per-class f-measure in their precision-recall curve.
    # Please see compute_thresholds_for_classes() in coco_eval.py for details.
    thresholds_for_classes = [
        0.24721044301986694, 0.2316334992647171, 0.23782534897327423,
        0.2447730302810669, 0.26833730936050415, 0.2909756898880005,
        0.22202278673648834, 0.23603129386901855, 0.19448654353618622,
        0.2009030282497406, 0.2205723077058792, 0.4426179826259613,
        0.2812938094139099, 0.23200270533561707, 0.22222928702831268,
        0.34396135807037354, 0.29865574836730957, 0.2620207965373993,
        0.23538640141487122, 0.21343813836574554, 0.23408174514770508,
        0.3619556427001953, 0.25181055068969727, 0.2753196656703949,
        0.20989173650741577, 0.256824254989624, 0.24953776597976685,
        0.2482326775789261, 0.23516853153705597, 0.3231242001056671,
        0.1875445693731308, 0.22903329133987427, 0.220603808760643,
        0.1938045769929886, 0.2102973908185959, 0.30885136127471924,
        0.21589471399784088, 0.2611836791038513, 0.27154257893562317,
        0.2536311149597168, 0.21989859640598297, 0.2741137146949768,
        0.24886088073253632, 0.20183633267879486, 0.17529579997062683,
        0.2467200607061386, 0.2103690654039383, 0.23187917470932007,
        0.28766655921936035, 0.21596665680408478, 0.24378667771816254,
        0.2806374728679657, 0.23764009773731232, 0.2884339392185211,
        0.19776469469070435, 0.29654744267463684, 0.23793953657150269,
        0.2753768265247345, 0.24718035757541656, 0.2166261523962021,
        0.22458019852638245, 0.36707887053489685, 0.29586368799209595,
        0.24396133422851562, 0.3916597068309784, 0.2478819191455841,
        0.3140171468257904, 0.23574240505695343, 0.30935078859329224,
        0.2633970379829407, 0.22616524994373322, 0.22482863068580627,
        0.25680482387542725, 0.184458926320076, 0.31002628803253174,
        0.2936173677444458, 0.2688758671283722, 0.2438362091779709,
        0.17232654988765717, 0.1869594156742096
    ]

    # sorted so images are processed in a stable, reproducible order
    demo_im_names = sorted(os.listdir(args.images_dir))
    # cv2.imwrite does not create missing directories; it just returns False
    os.makedirs(args.out_dir, exist_ok=True)

    # prepare object that handles inference plus adds predictions on top of image
    coco_demo = COCODemo(
        cfg,
        confidence_thresholds_for_classes=thresholds_for_classes,
        min_image_size=args.min_image_size)

    for im_name in demo_im_names:
        img = cv2.imread(os.path.join(args.images_dir, im_name))
        if img is None:
            continue  # not something OpenCV can decode as an image
        start_time = time.time()
        composite = coco_demo.run_on_opencv_image(img)
        print("{}\tinference time: {:.2f}s".format(im_name,
                                                   time.time() - start_time))
        out_path = os.path.join(args.out_dir, im_name)
        if not cv2.imwrite(out_path, composite):
            print("failed to write {}".format(out_path))
    print("Press any keys to exit ...")
示例#16
0
def main():
    """Evaluate an exported FCOS ONNX model on the configured test datasets.

    Command-line inputs are ``--config-file`` (YAML config), ``--onnx-model``
    (the exported model) and optional trailing ``KEY VALUE`` config overrides.
    ``DATALOADER.NUM_WORKERS`` is forced to 0 after the overrides, because the
    ONNX model only works with that setting. ``inference`` runs once per
    dataset in ``cfg.DATASETS.TEST``. When ``cfg.OUTPUT_DIR`` is set, results
    go to ``<OUTPUT_DIR>/inference/<dataset>``.
    """
    arg_parser = argparse.ArgumentParser(description="Test onnx models of FCOS")
    arg_parser.add_argument(
        "--config-file",
        default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    arg_parser.add_argument(
        "--onnx-model",
        default="fcos_imprv_R_50_FPN_1x.onnx",
        metavar="FILE",
        help="path to the onnx model",
    )
    arg_parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    cli = arg_parser.parse_args()

    # config: file, then command-line overrides, then the ONNX-specific constraint
    cfg.merge_from_file(cli.config_file)
    cfg.merge_from_list(cli.opts)
    cfg.DATALOADER.NUM_WORKERS = 0  # the ONNX model requires a worker-less loader
    cfg.freeze()

    logger = setup_logger("fcos_core", "", get_rank())
    logger.info(cfg)
    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    onnx_model = ONNX_FCOS(cli.onnx_model, cfg)
    onnx_model.to(cfg.MODEL.DEVICE)

    iou_types = ("bbox",)
    iou_types += ("segm",) if cfg.MODEL.MASK_ON else ()
    iou_types += ("keypoints",) if cfg.MODEL.KEYPOINT_ON else ()

    dataset_names = cfg.DATASETS.TEST
    output_folders = [None] * len(dataset_names)
    if cfg.OUTPUT_DIR:
        for slot, name in enumerate(dataset_names):
            folder = os.path.join(cfg.OUTPUT_DIR, "inference", name)
            mkdir(folder)
            output_folders[slot] = folder

    loaders = make_data_loader(cfg, is_train=False, is_distributed=False)
    for folder, name, loader in zip(output_folders, dataset_names, loaders):
        if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON:
            box_only = False
        else:
            box_only = cfg.MODEL.RPN_ONLY
        inference(
            onnx_model,
            loader,
            dataset_name=name,
            iou_types=iou_types,
            box_only=box_only,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=folder,
        )
        synchronize()
def main():
    """Train a detector with a label-encoding warm-up, then save its head weights.

    The config is read from ``--config-file``, followed by trailing
    ``KEY VALUE`` overrides. ``SOLVER.MAX_ITER`` and every ``SOLVER.STEPS``
    entry are then shifted by ``MODEL.LABELENC.DISTANCE_LOSS_WARMUP_ITERS``.
    When ``WORLD_SIZE`` > 1, an NCCL process group is initialised. After
    ``train`` (and ``run_test`` unless ``--skip-test`` is given), the main
    process saves the state dicts of ``label_encoding_function``, ``rpn`` and,
    if present, ``roi_heads`` to
    ``<OUTPUT_DIR>/label_encoding_function.pth``.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # add distance loss warmup iters: extend the run and shift every LR step
    warmup_iters = cfg.MODEL.LABELENC.DISTANCE_LOSS_WARMUP_ITERS
    cfg.SOLVER.MAX_ITER += warmup_iters
    cfg.SOLVER.STEPS = tuple(step + warmup_iters for step in cfg.SOLVER.STEPS)

    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("fcos_core", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed)

    if not args.skip_test:
        run_test(cfg, model, args.distributed)

    if args.distributed:
        # presumably unwraps the DistributedDataParallel wrapper returned by train(); confirm
        model = model.module
    # Only the main process writes the file. This uses the same get_rank()
    # helper as the logger setup above, rather than a separately imported
    # torch.distributed alias.
    if not args.distributed or get_rank() == 0:
        weights_path = os.path.join(cfg.OUTPUT_DIR, "label_encoding_function.pth")
        saved_weights = {
            'label_encoding_function': model.label_encoding_function.state_dict(),
            'rpn': model.rpn.state_dict(),
        }
        # roi_heads may be empty/falsy for heads without a second stage
        if model.roi_heads:
            saved_weights['roi_heads'] = model.roi_heads.state_dict()
        torch.save(saved_weights, weights_path)
        logger.info("Successfully save label encoding function weights to " +
                    weights_path)
    synchronize()