Ejemplo n.º 1
0
def setup_cfg(args):
    """Build and freeze a config from a config file plus command-line overrides.

    Args:
        args: parsed CLI namespace; ``args.config_file`` is merged first, then
            the ``KEY VALUE`` list in ``args.opts`` is merged on top of it.

    Returns:
        The frozen config object.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    # Applied after the file, so command-line pairs can override file values.
    config.merge_from_list(args.opts)
    config.freeze()
    return config
Ejemplo n.º 2
0
def setup(args):
    """Return a frozen config built from an optional config file and ``args.opts``.

    ``args.config_file`` is merged only when it is truthy, so an empty or
    ``None`` value yields the default config plus the command-line overrides.
    """
    config = get_cfg()
    config_path = args.config_file
    if config_path:
        config.merge_from_file(config_path)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
Ejemplo n.º 3
0
def setup(args):
    """
    Create configs and perform basic setups.

    Merges ``args.config_file`` and then ``args.opts`` into a fresh config,
    freezes it, hands it to ``default_setup`` together with ``args`` and
    returns it.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    default_setup(config, args)
    return config
Ejemplo n.º 4
0
def setup(args):
    """
    Create configs and perform basic setups.

    The config is built from ``args.config_file`` with ``args.opts`` merged on
    top, frozen, and then passed to ``default_setup``.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    # Replace this call with custom setup code if the defaults do not fit.
    default_setup(config, args)
    return config
Ejemplo n.º 5
0
def setup_cfg(args):
    """Build a frozen config whose score thresholds follow ``args.confidence_threshold``.

    The config file and ``args.opts`` are merged first; afterwards the same
    confidence threshold is written into the RetinaNet, ROI-heads, FCOS and
    panoptic-FPN threshold entries.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)

    # Set score_threshold for builtin models: (config node, key) pairs.
    threshold = args.confidence_threshold
    targets = (
        (config.MODEL.RETINANET, "SCORE_THRESH_TEST"),
        (config.MODEL.ROI_HEADS, "SCORE_THRESH_TEST"),
        (config.MODEL.FCOS, "INFERENCE_TH_TEST"),
        (config.MODEL.PANOPTIC_FPN.COMBINE, "INSTANCES_CONFIDENCE_THRESH"),
    )
    for node, key in targets:
        setattr(node, key, threshold)

    config.freeze()
    return config
Ejemplo n.º 6
0
def setup(args):
    """
    Create configs and perform basic setups.

    After the standard ``default_setup`` call, an additional logger named
    ``"adet"`` is configured with ``cfg.OUTPUT_DIR`` and this process's
    distributed rank.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    default_setup(config, args)

    setup_logger(config.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="adet")
    return config
Ejemplo n.º 7
0
def setup(args):
    """
    Create configs and perform basic setups.

    Merges ``args.config_file`` and ``args.opts`` into a fresh config and then
    applies project-specific adjustments:

    - removes ``SOLVER.STEPS`` / ``SOLVER.MAX_ITER`` if present;
    - if ``SOLVER.OPTIMIZER_CFG`` is a non-empty string, evaluates it and copies
      ``type``/``lr`` (required) and ``momentum``/``weight_decay`` (defaults
      0.9 / 1e-4) into the SOLVER section;
    - in DEBUG mode forces single-GPU, single-machine, ``NUM_WORKERS = 0`` and
      ``PRINT_FREQ = 1``; ``TRAIN.VERBOSE`` also forces ``PRINT_FREQ = 1``;
    - registers all ``DATASETS.TRAIN`` + ``DATASETS.TEST`` names;
    - stores ``args.resume`` as ``cfg.RESUME``.

    The config is then frozen, ``default_setup`` runs, distributed printing is
    set up and the ``"adet"`` and ``"core"`` loggers are configured.

    Args:
        args: parsed CLI namespace. Reads ``config_file``, ``opts`` and
            ``resume``; in DEBUG mode ``num_gpus`` and ``num_machines`` are
            overwritten with 1 (the caller's namespace is mutated).

    Returns:
        The frozen config.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    ##############################
    # NOTE: pop some unwanted configs in detectron2
    cfg.SOLVER.pop("STEPS", None)
    cfg.SOLVER.pop("MAX_ITER", None)
    # NOTE: get optimizer from string cfg dict
    if cfg.SOLVER.OPTIMIZER_CFG != "":
        # SECURITY(review): eval() executes arbitrary Python taken from the
        # config file / --opts. Only acceptable for trusted configs.
        # ast.literal_eval would be safer but rejects "dict(...)"-style strings;
        # confirm the expected format before switching.
        optim_cfg = eval(cfg.SOLVER.OPTIMIZER_CFG)
        iprint("optimizer_cfg:", optim_cfg)
        cfg.SOLVER.OPTIMIZER_NAME = optim_cfg['type']
        cfg.SOLVER.BASE_LR = optim_cfg['lr']
        cfg.SOLVER.MOMENTUM = optim_cfg.get("momentum", 0.9)
        cfg.SOLVER.WEIGHT_DECAY = optim_cfg.get("weight_decay", 1e-4)
    if cfg.get("DEBUG", False):
        iprint("DEBUG")
        # Debug mode overrides the launch arguments in-place.
        args.num_gpus = 1
        args.num_machines = 1
        cfg.DATALOADER.NUM_WORKERS = 0
        cfg.TRAIN.PRINT_FREQ = 1
    if cfg.TRAIN.get("VERBOSE", False):
        cfg.TRAIN.PRINT_FREQ = 1

    # register datasets
    dataset_names = cfg.DATASETS.TRAIN + cfg.DATASETS.TEST
    register_datasets(dataset_names)

    cfg.RESUME = args.resume
    ##########################################
    # All config mutation must happen above this point.
    cfg.freeze()
    default_setup(cfg, args)

    setup_for_distributed(is_master=comm.is_main_process())

    rank = comm.get_rank()
    setup_my_logger(cfg.OUTPUT_DIR, distributed_rank=rank, name="adet")
    setup_my_logger(cfg.OUTPUT_DIR, distributed_rank=rank, name="core")

    return cfg
Ejemplo n.º 8
0
def main():
    """Check an exported FCOS ONNX model against the PyTorch implementation.

    Builds the model from ``--config-file``/``--opts`` with backbone norms
    switched from FrozenBN to BN, and loads ``cfg.MODEL.WEIGHTS``. The FCOS
    head is then rewritten so that every FPN level has its own copy of the
    cls/bbox/share towers. Finally, the same all-zero input is fed through the
    PyTorch model and through onnxruntime on the file at ``--output``, and
    per-output agreement is reported.

    ``--width``/``--height`` override the default 1088x800 input size;
    ``--level`` is accepted but not used.
    """
    parser = argparse.ArgumentParser(
        description="Export model to the onnx format")
    parser.add_argument(
        "--config-file",
        default="configs/FCOS-Detection/R_50_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument('--width', default=0, type=int)
    parser.add_argument('--height', default=0, type=int)
    parser.add_argument('--level', default=0, type=int)
    parser.add_argument(
        "--output",
        default="output/fcos.onnx",
        metavar="FILE",
        help="path to the output onnx file",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )

    cfg = get_cfg()
    args = parser.parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # norm for ONNX: change FrozenBN back to BN
    cfg.MODEL.BACKBONE.FREEZE_AT = 0
    cfg.MODEL.RESNETS.NORM = "BN"
    # turn on the following configuration according to your own convenience
    #cfg.MODEL.FCOS.NORM = "BN"
    #cfg.MODEL.FCOS.NORM = "NaiveGN"

    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    logger = setup_logger(output=output_dir)
    logger.info(cfg)

    model = build_model(cfg)
    model.eval()
    model.to(cfg.MODEL.DEVICE)
    logger.info("Model:\n{}".format(model))

    checkpointer = DetectionCheckpointer(model)
    _ = checkpointer.load(cfg.MODEL.WEIGHTS)
    logger.info("load Model:\n{}".format(cfg.MODEL.WEIGHTS))

    # patch fcos_head
    # step 1. config
    fcos_head = model.proposal_generator.fcos_head
    norm = None if cfg.MODEL.FCOS.NORM == "none" else cfg.MODEL.FCOS.NORM
    head_configs = {
        "cls": (cfg.MODEL.FCOS.NUM_CLS_CONVS, cfg.MODEL.FCOS.USE_DEFORMABLE),
        "bbox": (cfg.MODEL.FCOS.NUM_BOX_CONVS, cfg.MODEL.FCOS.USE_DEFORMABLE),
        "share": (cfg.MODEL.FCOS.NUM_SHARE_CONVS, False)
    }
    # Each conv of an original tower is assumed to be laid out as
    # [conv, (norm), ReLU]. Without a norm layer a group is only 2 modules
    # wide; a hard-coded stride of 3 would pick the wrong modules (or run off
    # the end of the tower) when FCOS.NORM == "none".
    group = 3 if norm in ("GN", "NaiveGN", "BN", "SyncBN") else 2

    # step 2. separate modules: build a per-level copy of every tower
    for l in range(fcos_head.num_levels):
        for head, (num_convs, _use_deformable) in head_configs.items():
            source = getattr(fcos_head, '{}_tower'.format(head))
            tower = []
            for i in range(num_convs):
                base = i * group
                tower.append(deepcopy(source[base]))
                if norm in ["GN", "NaiveGN"]:
                    tower.append(deepcopy(source[base + 1]))
                elif norm in ["BN", "SyncBN"]:
                    # BN/SyncBN slots are indexed by level; take this level's layer.
                    tower.append(deepcopy(source[base + 1][l]))
                tower.append(deepcopy(source[base + group - 1]))
            fcos_head.add_module('{}_tower{}'.format(head, l),
                                 torch.nn.Sequential(*tower))

    # step 3. override forward
    def fcos_head_forward(self, x):
        logits = []
        bbox_reg = []
        ctrness = []
        # Never filled here; returned to keep the original 4-tuple shape.
        bbox_towers = []
        for l, feature in enumerate(x):
            feature = self.share_tower(feature)
            cls_tower = getattr(self, 'cls_tower{}'.format(l))(feature)
            bbox_tower = getattr(self, 'bbox_tower{}'.format(l))(feature)

            logits.append(self.cls_logits(cls_tower))
            ctrness.append(self.ctrness(bbox_tower))
            reg = self.bbox_pred(bbox_tower)
            if self.scales is not None:
                reg = self.scales[l](reg)
            # Note that we use relu, as in the improved FCOS, instead of exp.
            bbox_reg.append(F.relu(reg))

        return logits, bbox_reg, ctrness, bbox_towers

    fcos_head.forward = types.MethodType(fcos_head_forward, fcos_head)

    proposal_generator = FCOS(cfg)
    onnx_model = torch.nn.Sequential(
        OrderedDict([
            ('backbone', model.backbone),
            ('proposal_generator', proposal_generator),
            ('heads', model.proposal_generator.fcos_head),
        ]))

    height, width = 800, 1088
    if args.width > 0:
        width = args.width
    if args.height > 0:
        height = args.height
    input_names = ["input_image"]
    dummy_input = torch.zeros((1, 3, height, width)).to(cfg.MODEL.DEVICE)
    # Item-major order, matching the flattened (logits, bbox_reg, ctrness) lists.
    output_names = []
    for item in ["logits", "bbox_reg", "centerness"]:
        for l in range(len(cfg.MODEL.FCOS.FPN_STRIDES)):
            fpn_name = "P{}".format(3 + l)
            output_names.extend([fpn_name + item])

    logger.info("Load onnx model from {}.".format(args.output))
    sess = rt.InferenceSession(args.output)

    for in_blob in sess.get_inputs():
        if in_blob.name not in input_names:
            print("Input blob name not match that in the model")
        else:
            print("Input {}, shape {} and type {}".format(
                in_blob.name, in_blob.shape, in_blob.type))
    for out_blob in sess.get_outputs():
        if out_blob.name not in output_names:
            print("Output blob name not match that in the model")
        else:
            print("Output {}, shape {} and type {}".format(
                out_blob.name, out_blob.shape, out_blob.type))

    with torch.no_grad():
        torch_output = onnx_model(dummy_input)
        logits, bbox_reg, ctrness, bbox_towers = torch_output
        lists = logits + bbox_reg + ctrness + bbox_towers

    onnx_output = sess.run(None, {input_names[0]: dummy_input.cpu().numpy()})
    for i, out in enumerate(onnx_output):
        # Guard the lookups: the ONNX file may expose more outputs than we named
        # or than the PyTorch model produced.
        name = output_names[i] if i < len(output_names) else "#{}".format(i)
        if i >= len(lists):
            print("output {} has no PyTorch counterpart".format(name))
            continue
        try:
            np.testing.assert_allclose(lists[i].cpu().detach().numpy(),
                                       out,
                                       rtol=1e-03,
                                       atol=2e-04)
        except AssertionError as e:
            print("output {} mismatch {}".format(name, e))
            continue
        print("output {} match".format(name))
Ejemplo n.º 9
0
from detectron2.modeling import GeneralizedRCNNWithTTA
from detectron2.utils.logger import setup_logger

from adet.data.dataset_mapper import DatasetMapperWithBasis
from adet.data.dataset_mapper_depth import DatasetMapperWithDepth
from adet.config import get_cfg
from adet.checkpoint import AdetCheckpointer
from adet.evaluation import TextEvaluator
import matplotlib.pyplot as plt
import numpy as np
from detectron2.engine import DefaultPredictor
import time

from detectron2.data import (
    build_detection_test_loader, )
# Build the config for the first predictor (the "old" checkpoint).
cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
# NOTE(review): absolute, user-specific path; this script only runs on the author's machine.
cfg.merge_from_file(
    r"/home/zqzhou/disk/Code/python/AdelaiDet/configs/BoxData/CondInst/CondInst_R_50_3x_sem.yaml"
)
# Switch to the "M" variants of the proposal generator and meta-architecture;
# presumably these names are registered elsewhere in the project — confirm.
cfg.MODEL.PROPOSAL_GENERATOR.NAME = 'FCOSM'
cfg.MODEL.META_ARCHITECTURE = 'CondInstM'
cfg.MODEL.FCOSM.USE_BOX_NMS = False
# cfg.MODEL.FCOSM.USE_SINGLE_NMS = False
# NOTE(review): absolute, user-specific weights path.
cfg.MODEL.WEIGHTS = "/home/zqzhou/disk/Code/python/AdelaiDet/output/BoxData/CondInst/condinst_MS_R_50_3x_sem/old_0.0050/model_final.pth"
predictor_old = DefaultPredictor(cfg)

cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(
    r"/home/zqzhou/disk/Code/python/AdelaiDet/configs/BoxData/CondInst/CondInst_R_50_3x_sem.yaml"
Ejemplo n.º 10
0
def main():
    """Export a detectron2/AdelaiDet model to ONNX.

    The config comes from ``--config-file`` plus ``--opts``. For export, the
    backbone/basis-module norms are switched to BN and the backbone is fully
    unfrozen. After ``cfg.MODEL.WEIGHTS`` is loaded, model-specific patches
    (BlendMask, ProposalNetwork, FCOS proposal generator and head) are applied;
    the BlendMask/ProposalNetwork patches fill ``output_names``. The model is
    then traced with an all-zero input of size ``--height`` x ``--width``
    (default 800 x 1088) and written to ``--output``. ``--level`` is accepted
    but unused.
    """
    parser = argparse.ArgumentParser(description="Export model to the onnx format")
    parser.add_argument(
        "--config-file",
        default="configs/FCOS-Detection/R_50_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    for int_flag in ('--width', '--height', '--level'):
        parser.add_argument(int_flag, default=0, type=int)
    parser.add_argument(
        "--output",
        default="output/fcos.onnx",
        metavar="FILE",
        help="path to the output onnx file",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )

    config = get_cfg()
    cli = parser.parse_args()
    config.merge_from_file(cli.config_file)
    config.merge_from_list(cli.opts)

    # Export plain BN instead of FrozenBN, and freeze no backbone stage.
    config.MODEL.BACKBONE.FREEZE_AT = 0
    config.MODEL.RESNETS.NORM = "BN"
    config.MODEL.BASIS_MODULE.NORM = "BN"
    # Optionally switch the FCOS head norm too, e.g.:
    #   config.MODEL.FCOS.NORM = "BN"   or   config.MODEL.FCOS.NORM = "NaiveGN"
    config.freeze()

    logger = setup_logger(output=config.OUTPUT_DIR)
    logger.info(config)

    model = build_model(config)
    model.eval()
    model.to(config.MODEL.DEVICE)
    logger.info("Model:\n{}".format(model))

    DetectionCheckpointer(model).load(config.MODEL.WEIGHTS)
    logger.info("load Model:\n{}".format(config.MODEL.WEIGHTS))

    height = cli.height if cli.height > 0 else 800
    width = cli.width if cli.width > 0 else 1088
    input_names = ["input_image"]
    dummy_input = torch.zeros((1, 3, height, width)).to(config.MODEL.DEVICE)

    # The patch helpers append the names of the tensors they expose.
    output_names = []
    if isinstance(model, BlendMask):
        patch_blendmask(config, model, output_names)
    if isinstance(model, ProposalNetwork):
        patch_ProposalNetwork(config, model, output_names)

    generator = getattr(model, 'proposal_generator', None)
    if isinstance(generator, FCOS):
        patch_fcos(config, generator)
        patch_fcos_head(config, generator.fcos_head)

    torch.onnx.export(
        model,
        dummy_input,
        cli.output,
        verbose=True,
        input_names=input_names,
        output_names=output_names,
        keep_initializers_as_inputs=True,
    )

    logger.info("Done. The onnx model is saved into {}.".format(cli.output))
Ejemplo n.º 11
0
def main():
    """Export an FCOS detector (backbone + proposal generator + head) to ONNX.

    The config comes from ``--config-file`` plus ``--opts``, with the backbone
    switched from FrozenBN to BN. After ``cfg.MODEL.WEIGHTS`` is loaded, the
    ``Sequential(backbone, FCOS, fcos_head)`` graph is traced with an all-zero
    input and written to ``--output``.

    ``--height``/``--width`` override the default 512 x 640 input size (0 keeps
    the default).

    Output tensors are named ``P<level>/<item>``. The head is assumed to return
    per-item lists (all logits, then all bbox_reg, then all centerness), which
    ``torch.onnx.export`` flattens in that order. Names are therefore generated
    item-major; a level-major order would attach names to the wrong tensors.
    """
    parser = argparse.ArgumentParser(
        description="Export model to the onnx format")
    parser.add_argument(
        "--config-file",
        default="configs/FCOS-Detection/R_50_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument('--width', default=0, type=int,
                        help="input width (0 keeps the default of 640)")
    parser.add_argument('--height', default=0, type=int,
                        help="input height (0 keeps the default of 512)")
    parser.add_argument(
        "--output",
        default="output/fcos.onnx",
        metavar="FILE",
        help="path to the output onnx file",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )

    cfg = get_cfg()
    args = parser.parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # norm for ONNX: change FrozenBN back to BN
    cfg.MODEL.BACKBONE.FREEZE_AT = 0
    cfg.MODEL.RESNETS.NORM = "BN"
    # turn on the following configuration according to your own convenience
    #cfg.MODEL.FCOS.NORM = "BN"
    #cfg.MODEL.FCOS.NORM = "NaiveGN"

    # The onnx model can only be used with DATALOADER.NUM_WORKERS = 0
    cfg.DATALOADER.NUM_WORKERS = 0

    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    logger = setup_logger(output=output_dir)
    logger.info(cfg)

    model = build_model(cfg)
    model.eval()
    model.to(cfg.MODEL.DEVICE)
    logger.info("Model:\n{}".format(model))

    checkpointer = DetectionCheckpointer(model)
    _ = checkpointer.load(cfg.MODEL.WEIGHTS)

    proposal_generator = FCOS(cfg)
    onnx_model = torch.nn.Sequential(
        OrderedDict([
            ('backbone', model.backbone),
            ('proposal_generator', proposal_generator),
            ('heads', model.proposal_generator.fcos_head),
        ]))

    height = args.height if args.height > 0 else 512
    width = args.width if args.width > 0 else 640
    input_names = ["input_image"]
    dummy_input = torch.zeros((1, 3, height, width)).to(cfg.MODEL.DEVICE)
    # Item-major so names line up with the flattened per-item output lists.
    num_levels = len(cfg.MODEL.FCOS.FPN_STRIDES)
    output_names = [
        "P{}/{}".format(3 + level, item)
        for item in ("logits", "bbox_reg", "centerness")
        for level in range(num_levels)
    ]

    torch.onnx.export(onnx_model,
                      dummy_input,
                      args.output,
                      verbose=True,
                      input_names=input_names,
                      output_names=output_names,
                      keep_initializers_as_inputs=True)

    logger.info("Done. The onnx model is saved into {}.".format(args.output))
Ejemplo n.º 12
0
def main():
    """Verify an exported ONNX model against PyTorch, onnxruntime and TensorRT.

    The model is rebuilt exactly as for export: the config comes from
    ``--config-file`` plus ``--opts`` with BN norms, ``cfg.MODEL.WEIGHTS`` is
    loaded, and the BlendMask / ProposalNetwork / FCOS patches are applied.
    The ONNX file at ``--output`` is then loaded and its input/output blob
    names are checked. An all-zero input is run through PyTorch, onnxruntime
    and the TensorRT ONNX backend, and each backend's outputs are compared
    against PyTorch with ``rtol=1e-3, atol=2e-4``.

    ``--width``/``--height`` override the default 1088 x 800 input size;
    ``--level`` is accepted but unused.
    """
    parser = argparse.ArgumentParser(description="Export model to the onnx format")
    parser.add_argument(
        "--config-file",
        default="configs/FCOS-Detection/R_50_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument('--width', default=0, type=int)
    parser.add_argument('--height', default=0, type=int)
    parser.add_argument('--level', default=0, type=int)
    parser.add_argument(
        "--output",
        default="output/fcos.onnx",
        metavar="FILE",
        help="path to the output onnx file",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )

    cfg = get_cfg()
    args = parser.parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # norm for ONNX: change FrozenBN back to BN
    cfg.MODEL.BACKBONE.FREEZE_AT = 0
    cfg.MODEL.RESNETS.NORM = "BN"
    cfg.MODEL.BASIS_MODULE.NORM = "BN"

    # turn on the following configuration according to your own convenience
    #cfg.MODEL.FCOS.NORM = "BN"
    #cfg.MODEL.FCOS.NORM = "NaiveGN"

    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    logger = setup_logger(output=output_dir)
    logger.info(cfg)

    model = build_model(cfg)
    model.eval()
    model.to(cfg.MODEL.DEVICE)
    logger.info("Model:\n{}".format(model))

    checkpointer = DetectionCheckpointer(model)
    _ = checkpointer.load(cfg.MODEL.WEIGHTS)
    logger.info("load Model:\n{}".format(cfg.MODEL.WEIGHTS))

    height, width = 800, 1088
    if args.width > 0:
        width = args.width
    if args.height > 0:
        height = args.height
    input_names = ["input_image"]
    dummy_input = torch.zeros((1, 3, height, width)).to(cfg.MODEL.DEVICE)
    output_names = []
    if isinstance(model, BlendMask):
        patch_blendmask(cfg, model, output_names)

    if isinstance(model, ProposalNetwork):
        patch_ProposalNetwork(cfg, model, output_names)

    if hasattr(model, 'proposal_generator'):
        if isinstance(model.proposal_generator, FCOS):
            patch_fcos(cfg, model.proposal_generator)
            patch_fcos_head(cfg, model.proposal_generator.fcos_head)

    logger.info("Load onnx model from {}.".format(args.output))
    sess = rt.InferenceSession(args.output)

    # check input and output
    for in_blob in sess.get_inputs():
        if in_blob.name not in input_names:
            print("Input blob name not match that in the model")
        else:
            print("Input {}, shape {} and type {}".format(in_blob.name, in_blob.shape, in_blob.type))
    for out_blob in sess.get_outputs():
        if out_blob.name not in output_names:
            print("Output blob name not match that in the model")
        else:
            print("Output {}, shape {} and type {}".format(out_blob.name, out_blob.shape, out_blob.type))

    # run pytorch model
    with torch.no_grad():
        torch_output = model(dummy_input)
        torch_output = to_list(torch_output)

    # Host copy of the input for the non-PyTorch backends. `.cpu()` is required:
    # `.numpy()` on a CUDA tensor raises, so the old `dummy_input.data.numpy()`
    # broke whenever MODEL.DEVICE was a GPU.
    input_array = dummy_input.cpu().numpy()

    # run onnx by onnxruntime
    onnx_output = sess.run(None, {input_names[0]: input_array})

    # run onnx by tensorrt
    logger.info("Load onnx model for TensorRT from {}.".format(args.output))
    load_model = onnx.load(args.output)
    onnx.checker.check_model(load_model)
    trt_model = backend.prepare(load_model)
    tensorrt_output = trt_model.run(input_array)

    def compare(tag, outputs):
        """Print match/mismatch of each backend output vs. the PyTorch output."""
        for i, out in enumerate(outputs):
            # Guard lookups: a backend may expose more outputs than were named
            # or than PyTorch produced.
            name = output_names[i] if i < len(output_names) else "#{}".format(i)
            if i >= len(torch_output):
                print("[{}] output {} has no PyTorch counterpart".format(tag, name))
                continue
            try:
                np.testing.assert_allclose(torch_output[i].cpu().detach().numpy(), out, rtol=1e-03, atol=2e-04)
            except AssertionError as e:
                print("[{}] output {} mismatch {}".format(tag, name, e))
                continue
            print("[{}] output {} match\n".format(tag, name))

    compare("onnxruntime", onnx_output)
    compare("tensorrt", tensorrt_output)
Ejemplo n.º 13
0
def main():
    """Export an FCOS detector to ONNX with per-level copies of the head towers.

    The config comes from ``--config-file`` plus ``--opts``, with the backbone
    switched from FrozenBN to BN, and ``cfg.MODEL.WEIGHTS`` is loaded. The FCOS
    head is then rewritten so each FPN level gets its own ``cls``/``bbox``/
    ``share`` tower copy (needed when the norm layers are per-level). The
    ``Sequential(backbone, FCOS, fcos_head)`` graph is traced with a
    1 x 3 x 800 x 1088 zero input and written to ``--output``.
    """
    parser = argparse.ArgumentParser(
        description="Export model to the onnx format")
    parser.add_argument(
        "--config-file",
        default="configs/FCOS-Detection/R_50_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--output",
        default="output/fcos.onnx",
        metavar="FILE",
        help="path to the output onnx file",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )

    cfg = get_cfg()
    args = parser.parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # norm for ONNX: change FrozenBN back to BN
    cfg.MODEL.BACKBONE.FREEZE_AT = 0
    cfg.MODEL.RESNETS.NORM = "BN"
    # turn on the following configuration according to your own convenience
    #cfg.MODEL.FCOS.NORM = "BN"
    #cfg.MODEL.FCOS.NORM = "NaiveGN"

    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    logger = setup_logger(output=output_dir)
    logger.info(cfg)

    model = build_model(cfg)
    model.eval()
    model.to(cfg.MODEL.DEVICE)
    logger.info("Model:\n{}".format(model))

    checkpointer = DetectionCheckpointer(model)
    _ = checkpointer.load(cfg.MODEL.WEIGHTS)
    logger.info("load Model:\n{}".format(cfg.MODEL.WEIGHTS))

    # patch fcos_head
    # step 1. config
    fcos_head = model.proposal_generator.fcos_head
    norm = None if cfg.MODEL.FCOS.NORM == "none" else cfg.MODEL.FCOS.NORM
    head_configs = {
        "cls": (cfg.MODEL.FCOS.NUM_CLS_CONVS, cfg.MODEL.FCOS.USE_DEFORMABLE),
        "bbox": (cfg.MODEL.FCOS.NUM_BOX_CONVS, cfg.MODEL.FCOS.USE_DEFORMABLE),
        "share": (cfg.MODEL.FCOS.NUM_SHARE_CONVS, False)
    }
    # Each conv of an original tower is assumed to be laid out as
    # [conv, (norm), ReLU]. Without a norm layer a group is only 2 modules
    # wide; a hard-coded stride of 3 would pick the wrong modules (or run off
    # the end of the tower) when FCOS.NORM == "none".
    group = 3 if norm in ("GN", "NaiveGN", "BN", "SyncBN") else 2

    # step 2. separate modules: build a per-level copy of every tower
    for l in range(fcos_head.num_levels):
        for head, (num_convs, _use_deformable) in head_configs.items():
            source = getattr(fcos_head, '{}_tower'.format(head))
            tower = []
            for i in range(num_convs):
                base = i * group
                tower.append(deepcopy(source[base]))
                if norm in ["GN", "NaiveGN"]:
                    tower.append(deepcopy(source[base + 1]))
                elif norm in ["BN", "SyncBN"]:
                    # BN/SyncBN slots are indexed by level; take this level's layer.
                    tower.append(deepcopy(source[base + 1][l]))
                tower.append(deepcopy(source[base + group - 1]))
            fcos_head.add_module('{}_tower{}'.format(head, l),
                                 torch.nn.Sequential(*tower))

    # step 3. override forward
    def fcos_head_forward(self, x):
        logits = []
        bbox_reg = []
        ctrness = []
        # Never filled here; returned to keep the original 4-tuple shape.
        bbox_towers = []
        for l, feature in enumerate(x):
            feature = self.share_tower(feature)
            cls_tower = getattr(self, 'cls_tower{}'.format(l))(feature)
            bbox_tower = getattr(self, 'bbox_tower{}'.format(l))(feature)

            logits.append(self.cls_logits(cls_tower))
            ctrness.append(self.ctrness(bbox_tower))
            reg = self.bbox_pred(bbox_tower)
            if self.scales is not None:
                reg = self.scales[l](reg)
            # Note that we use relu, as in the improved FCOS, instead of exp.
            bbox_reg.append(F.relu(reg))

        return logits, bbox_reg, ctrness, bbox_towers

    fcos_head.forward = types.MethodType(fcos_head_forward, fcos_head)

    proposal_generator = FCOS(cfg)
    onnx_model = torch.nn.Sequential(
        OrderedDict([
            ('backbone', model.backbone),
            ('proposal_generator', proposal_generator),
            ('heads', model.proposal_generator.fcos_head),
        ]))

    height, width = 800, 1088
    input_names = ["input_image"]
    dummy_input = torch.zeros((1, 3, height, width)).to(cfg.MODEL.DEVICE)
    # Item-major order, matching the flattened (logits, bbox_reg, ctrness) lists.
    output_names = []
    for item in ["logits", "bbox_reg", "centerness"]:
        for l in range(len(cfg.MODEL.FCOS.FPN_STRIDES)):
            fpn_name = "P{}".format(3 + l)
            output_names.extend([fpn_name + item])

    torch.onnx.export(onnx_model,
                      dummy_input,
                      args.output,
                      verbose=True,
                      input_names=input_names,
                      output_names=output_names,
                      keep_initializers_as_inputs=True)

    logger.info("Done. The onnx model is saved into {}.".format(args.output))