Code example #1
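A train() entry point that builds a LabelEncStep1Network, optionally converts it to SyncBatchNorm, wraps it in DistributedDataParallel, resumes from the latest checkpoint, and hands off to do_train().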
def train(cfg, local_rank, distributed):
    model = LabelEncStep1Network(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
Code example #2
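A run_test() variant for a model stored as a dict of sub-networks: it unwraps the DDP .module handles for the backbone and FCOS head, then runs inference on every dataset in cfg.DATASETS.TEST.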
def run_test(cfg, model, distributed):
    if distributed:
        model["backbone"] = model["backbone"].module
        model["fcos"] = model["fcos"].module
        #if cfg.MODEL.ADV.USE_DIS_P7:
        #    model["dis_P7"] = model["dis_P7"].module
        #if cfg.MODEL.ADV.USE_DIS_P6:
        #    model["dis_P6"] = model["dis_P6"].module
        #if cfg.MODEL.ADV.USE_DIS_P5:
        #    model["dis_P5"] = model["dis_P5"].module
        #if cfg.MODEL.ADV.USE_DIS_P4:
        #    model["dis_P4"] = model["dis_P4"].module
        #if cfg.MODEL.ADV.USE_DIS_P3:
        #    model["dis_P3"] = model["dis_P3"].module
    torch.cuda.empty_cache()  # TODO check if it helps
    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )
    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                         dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder
    data_loaders_val = make_data_loader(cfg,
                                        is_train=False,
                                        is_distributed=distributed)
    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else
            cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
Code example #3
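A run_test() variant that copies the unwrapped sub-networks into a separate model_test dict, evaluates only the first test dataset, and all-gathers the results across ranks before returning them.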
def run_test(cfg, model, distributed):
    model_test = {}
    if distributed:
        model_test["backbone"] = model["backbone"].module
        model_test["fcos"] = model["fcos"].module
        #if cfg.MODEL.ADV.USE_DIS_P7:
        #    model["dis_P7"] = model["dis_P7"].module
        #if cfg.MODEL.ADV.USE_DIS_P6:
        #    model["dis_P6"] = model["dis_P6"].module
        #if cfg.MODEL.ADV.USE_DIS_P5:
        #    model["dis_P5"] = model["dis_P5"].module
        #if cfg.MODEL.ADV.USE_DIS_P4:
        #    model["dis_P4"] = model["dis_P4"].module
        #if cfg.MODEL.ADV.USE_DIS_P3:
        #    model["dis_P3"] = model["dis_P3"].module
    torch.cuda.empty_cache()  # TODO check if it helps
    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )
    dataset_name = cfg.DATASETS.TEST[0]
    if cfg.OUTPUT_DIR:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
        mkdir(output_folder)
    data_loaders_val = make_data_loader(cfg,
                                        is_train=False,
                                        is_distributed=distributed)
    results = inference(
        model_test,
        data_loaders_val[0],
        dataset_name=dataset_name,
        iou_types=iou_types,
        box_only=False
        if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
        device=cfg.MODEL.DEVICE,
        expected_results=cfg.TEST.EXPECTED_RESULTS,
        expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
        output_folder=output_folder,
    )
    synchronize()
    results = all_gather(results)
    # import pdb; pdb.set_trace()
    return results
Code example #4
File: grow_res50_combo.py  Project: xianpf/Grow-Model
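A run_test() that unwraps a single DDP-wrapped model and, after each dataset is evaluated, writes a condensed summary (via get_neat_inference_result) to summaryStrs.txt in the corresponding output folder.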
def run_test(cfg, model, distributed):
    if distributed:
        model = model.module
    torch.cuda.empty_cache()  # TODO check if it helps
    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )
    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                         dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder
    data_loaders_val = make_data_loader(cfg,
                                        is_train=False,
                                        is_distributed=distributed)
    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, dataset_names, data_loaders_val):
        inference_result = inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else
            cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
        # import pdb; pdb.set_trace()
        summaryStrs = get_neat_inference_result(inference_result[2][0])
        # print('\n'.join(summaryStrs))
        with open(output_folder + '/summaryStrs.txt', 'w') as f_summaryStrs:
            f_summaryStrs.write('\n'.join(summaryStrs))
Code example #5
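A main() for evaluating an exported ONNX FCOS model (wrapped in ONNX_FCOS); note that it forces DATALOADER.NUM_WORKERS to 0, which the inline comment says the ONNX model requires.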
def main():
    parser = argparse.ArgumentParser(description="Test onnx models of FCOS")
    parser.add_argument(
        "--config-file",
        default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--onnx-model",
        default="fcos_imprv_R_50_FPN_1x.onnx",
        metavar="FILE",
        help="path to the onnx model",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # The onnx model can only be used with DATALOADER.NUM_WORKERS = 0
    cfg.DATALOADER.NUM_WORKERS = 0

    cfg.freeze()

    save_dir = ""
    logger = setup_logger("fcos_core", save_dir, get_rank())
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = ONNX_FCOS(args.onnx_model, cfg)
    model.to(cfg.MODEL.DEVICE)

    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints",)
    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder
    data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=False)
    for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
Code example #6
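A train() entry point like example #1 but built on build_detection_model(); it also carries a commented-out matplotlib snippet for visually inspecting a batch of loader images.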
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    # import matplotlib.pyplot as plt
    # import numpy as np
    #
    # def imshow(img):
    #     #img = img / 2 + 0.5  # unnormalize
    #     img = img + 115
    #     img = img[[2, 1, 0]]
    #     npimg = img.numpy().astype(np.int)
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # import torchvision
    # dataiter = iter(data_loader)
    # images, target, _ = next(dataiter)  # target and pixel values are in the hundreds
    #
    # imshow(torchvision.utils.make_grid(images.tensors))

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
Code example #7
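A standalone inference main(): it initializes an NCCL process group when WORLD_SIZE indicates more than one GPU (WORLD_SIZE and --local_rank are both supplied by torch.distributed.launch), loads cfg.MODEL.WEIGHT through DetectronCheckpointer, and evaluates every test dataset.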
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Inference")
    parser.add_argument(
        "--config-file",
        default=
        "/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    save_dir = ""
    logger = setup_logger("fcos_core", save_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
    _ = checkpointer.load(cfg.MODEL.WEIGHT)

    iou_types = ("bbox", ) + ("segm", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )
    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                         dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder
    data_loaders_val = make_data_loader(cfg,
                                        is_train=False,
                                        is_distributed=distributed)
    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.SIPMASK_ON
            or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
Code example #8
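The step-2 counterpart of example #1: on a fresh run (no checkpoint yet) it seeds the label-encoding function, RPN, ROI heads, and the backbone FPN from the step-1 weights in labelenc_fpath before training.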
def train(cfg, local_rank, distributed, labelenc_fpath):
    model = LabelEncStep2Network(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    # Load LabelEncodingFunction
    # Initialize FPN and Head from Step1 weights
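    # NOTE: the .module accesses below assume the model has been wrapped in
    # DistributedDataParallel, i.e. this initialization path expects distributed=True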
    if not checkpointer.has_checkpoint():
        labelenc_weights = torch.load(labelenc_fpath, map_location=torch.device('cpu'))
        # load LabelEncodingFunction
        model.module.label_encoding_function.load_state_dict(
                labelenc_weights['label_encoding_function'], strict=True)
        # Initialize Head
        model.module.rpn.load_state_dict(
                labelenc_weights['rpn'], strict=True)
        if model.module.roi_heads:
            model.module.roi_heads.load_state_dict(
                labelenc_weights['roi_heads'], strict=True)
        # Initialize FPN
        fpn_weight = model.module.label_encoding_function.fpn.state_dict()
        model.module.backbone.fpn.load_state_dict(fpn_weight, strict=True)


    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
Code example #9
File: run_res_tidy.py  Project: xianpf/Grow-Model
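A batch-evaluation script that scans a run directory for model_*00.pth checkpoints, loads each one in turn, and logs a per-checkpoint inference summary under run_res_tidy/.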
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument(
        "--run-dir",
        default="run/fcos_imprv_R_50_FPN_1x/Baseline_lr1en4_191209",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    args = parser.parse_args()

    # import pdb; pdb.set_trace()
    target_dir = args.run_dir
    dir_files = sorted(glob.glob(target_dir + '/*'))
    assert (
        target_dir + '/new_config.yml'
    ) in dir_files, "Error! No cfg file found! check if the dir is right."
    cfg_file = target_dir + '/new_config.yml' if (
        target_dir + '/new_config.yml') in dir_files else None
    model_files = [
        f for f in dir_files if f.endswith('00.pth') and 'model_' in f
    ]
    tidyed_before = (target_dir + '/run_res_tidy') in dir_files
    if tidyed_before:
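        # drop into the debugger rather than silently overwrite an existing run_res_tidy dir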
        import pdb
        pdb.set_trace()
        pass
    else:
        os.makedirs(target_dir + '/run_res_tidy')

    cfg.merge_from_file(cfg_file)
    cfg.freeze()

    logger = setup_logger("fcos_core",
                          target_dir + '/run_res_tidy',
                          0,
                          filename="test_log.txt")
    logger.info(cfg)

    # test_str = ''

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)
    checkpointer = DetectronCheckpointer(cfg,
                                         model,
                                         save_dir=target_dir +
                                         '/run_res_tidy/')

    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )
    # output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    # if cfg.OUTPUT_DIR:
    #     for idx, dataset_name in enumerate(dataset_names):
    #         output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
    #         mkdir(output_folder)
    #         output_folders[idx] = output_folder
    data_loaders_val = make_data_loader(cfg,
                                        is_train=False,
                                        is_distributed=False)
    dataset_name = dataset_names[0]
    data_loader_val = data_loaders_val[0]

    for i, model_f in enumerate(model_files):
        # import pdb; pdb.set_trace()
        _ = checkpointer.load(model_f)
        output_folder = target_dir + '/run_res_tidy/' + dataset_name + '_' + (
            model_f.split('/')[-1][:-4])
        os.makedirs(output_folder)
        logger.info('Processing {}/{}: {}'.format(i, len(model_files),
                                                  output_folder))
        # print('Processing {}/{}: {}'.format(i, len(model_f), output_folder))
        inference_result = inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else
            cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        summaryStrs = get_neat_inference_result(inference_result[2][0])
        # test_str += '\n'+ output_folder.split('/')[-1]+   \
        #     '\n'.join(summaryStrs)
        logger.info(output_folder.split('/')[-1])
        logger.info('\n'.join(summaryStrs))
Code example #10
File: search_net.py  Project: zyg11/research-fad
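A differentiable architecture search training loop: it builds a SearchRCNNController, pairs the usual weight optimizer with a separate Adam optimizer over the architecture parameters (model.alphas()), and feeds distinct train/val loaders plus an Architect helper to do_train().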
def train(cfg, local_rank, distributed, device_ids, use_tensorboard=False):

    # ------------------------------- more configs
    half_data = [0, 0]  # do not split
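    # half_data[0] / half_data[1] are passed as half= to the train/val loaders
    # below; 0 appears to mean "use the full dataset" (no train/val split)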

    first_order = cfg.SOLVER.SEARCH.FIRST_ORDER
    alpha_lr = cfg.SOLVER.SEARCH.BASE_LR_ALPHA
    alpha_weight_decay = 1e-3

    device_ids = [int(x) for x in device_ids]

    if cfg.MODEL.FAD.CLSTOWER or cfg.MODEL.FAD.BOXTOWER:
        n_cells = cfg.MODEL.FAD.NUM_CELLS_CLS
        if cfg.MODEL.FAD.CLSTOWER and cfg.MODEL.FAD.BOXTOWER:
            n_nodes = cfg.MODEL.FAD.NUM_NODES_CLS
            n_module = 2
        elif cfg.MODEL.FAD.CLSTOWER:
            n_nodes = cfg.MODEL.FAD.NUM_NODES_CLS
            n_module = 1
        else:
            n_nodes = cfg.MODEL.FAD.NUM_NODES_BOX
            n_module = 1
    else:
        pdb.set_trace()

    # build model
    model = SearchRCNNController(n_cells,
                                 n_nodes=n_nodes,
                                 device_ids=device_ids,
                                 cfg_det=cfg,
                                 n_module=n_module)

    device = torch.device(cfg.MODEL.DEVICE)
    model = model.to(device)
    torch.cuda.set_device(0)
    distributed = False
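    # NOTE: distributed is forced off here; SearchRCNNController appears to manage
    # multi-GPU placement itself via device_ids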

    if first_order: print('Using 1st order approximation for the search')

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    # ---------------------- optimize alpha
    arch = Architect(model, cfg.SOLVER.MOMENTUM, cfg.SOLVER.WEIGHT_DECAY)
    alpha_optim = torch.optim.Adam(model.alphas(),
                                   alpha_lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=alpha_weight_decay)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    # ------ tensorboard
    tb_info = {"tb_logger": None}
    if use_tensorboard:
        tb_logger = get_tensorboard_writer(output_dir)
        tb_info['tb_logger'] = tb_logger
        tb_info['prefix'] = cfg.TENSORBOARD.PREFIX

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)

    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(cfg,
                                   is_train=True,
                                   is_distributed=distributed,
                                   start_iter=arguments["iteration"],
                                   half=half_data[0])

    val_loader = make_data_loader(cfg,
                                  is_train=True,
                                  is_distributed=distributed,
                                  start_iter=arguments["iteration"],
                                  half=half_data[1])

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        arch,
        data_loader,
        val_loader,
        optimizer,
        alpha_optim,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        cfg,
        tb_info=tb_info,
        first_order=first_order,
    )

    return model
Code example #11
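A train() that converts eligible Conv2d/Linear layers to DeepShift operators via convert_to_shift (after loading cfg.MODEL.WEIGHT when one is given, otherwise before the checkpoint load), then saves a rounded-weight copy as model_final_round.pth after training.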
def train(cfg, local_rank, distributed, iter_clear, ignore_head):
    model = build_detection_model(cfg)
    # model, conversion_count = convert_to_shift_dbg(
    #         model,
    #         cfg.DEEPSHIFT_DEPTH,
    #         cfg.DEEPSHIFT_TYPE,
    #         convert_weights=True,
    #         use_kernel=cfg.DEEPSHIFT_USEKERNEL,
    #         rounding=cfg.DEEPSHIFT_ROUNDING,
    #         shift_range=cfg.DEEPSHIFT_RANGE)

    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    if iter_clear:
        load_opt = False
        load_sch = False
    else:
        load_opt = True
        load_sch = True
    if ignore_head:
        load_body = True
        load_fpn = True
        load_head = False
    else:
        load_body = True
        load_fpn = True
        load_head = True
    # Preload the weights: either an ordinary model or a DeepShift model
    if cfg.MODEL.WEIGHT:
        checkpointer = DetectronCheckpointer(
            cfg, model, None, None, output_dir, save_to_disk
        )

        extra_checkpoint_data = checkpointer.load(
            cfg.MODEL.WEIGHT, load_opt=False, load_sch=False,
            load_body=load_body, load_fpn=load_fpn, load_head=load_head)
        
        model, conversion_count = convert_to_shift(
            model,
            cfg.DEEPSHIFT_DEPTH,
            cfg.DEEPSHIFT_TYPE,
            convert_weights=True,
            use_kernel=cfg.DEEPSHIFT_USEKERNEL,
            rounding=cfg.DEEPSHIFT_ROUNDING,
            shift_range=cfg.DEEPSHIFT_RANGE)
        
        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)

        checkpointer = DetectronCheckpointer(
            cfg, model, optimizer, scheduler, output_dir, save_to_disk
        )
    else:
        model, conversion_count = convert_to_shift(
            model,
            cfg.DEEPSHIFT_DEPTH,
            cfg.DEEPSHIFT_TYPE,
            convert_weights=True,
            use_kernel=cfg.DEEPSHIFT_USEKERNEL,
            rounding=cfg.DEEPSHIFT_ROUNDING,
            shift_range=cfg.DEEPSHIFT_RANGE)
        
        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)

        checkpointer = DetectronCheckpointer(
            cfg, model, optimizer, scheduler, output_dir, save_to_disk
        )

        extra_checkpoint_data = checkpointer.load(
            cfg.MODEL.WEIGHT, load_opt=False, load_sch=False,
            load_body=load_body, load_fpn=load_fpn, load_head=load_head)
    
    conv2d_layers_count = count_layer_type(model, torch.nn.Conv2d)
    linear_layers_count = count_layer_type(model, torch.nn.Linear)
    print("###### conversion_count: {}, not convert conv2d layer: {}, linear layer: {}".format(
        conversion_count, conv2d_layers_count, linear_layers_count))

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    arguments.update(extra_checkpoint_data)

    if iter_clear:
        arguments["iteration"] = 0

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    model = round_shift_weights(model)
    torch.save({"model": model.state_dict()}, os.path.join(output_dir, "model_final_round.pth"))

    return model