# Code example #1
def make_model(args, config):
    """Build a global-feature (image retrieval) network from *config*.

    Args:
        args: namespace providing ``directory``, the folder under which
            converted pretrained backbone weights are cached.
        config: mapping with "body", "fpn", "global" and "dataloader"
            sections, each exposing get/getint/getfloat/getboolean/getstruct.

    Returns:
        (net, net_modules, output_dim): the assembled ImageRetrievalNet,
        the list of top-level module names created here, and the output
        feature dimensionality.

    Raises:
        ValueError: if ``source_url`` is unknown, or the requested arch has
            no pretrained weights for the chosen source.
    """
    # parse params with default values
    body_config = config["body"]
    fpn_config = config["fpn"]
    ir_config = config["global"]
    data_config = config["dataloader"]

    # names of the top-level modules assembled below
    net_modules = []

    # BN + activation factories derived from the body section
    norm_act_static, norm_act_dynamic = norm_act_from_config(body_config)

    # Create backbone
    log_debug("Creating backbone model %s", body_config.get("arch"))

    body_fn = models.__dict__[body_config.get("arch")]
    body_params = body_config.getstruct("body_params") if body_config.get(
        "body_params") else {}
    body = body_fn(norm_act=norm_act_static, config=body_config, **body_params)

    net_modules.append("body")

    if body_config.getboolean("pretrained"):
        arch = body_config.get("arch")

        # vgg ships with and without batch norm: use the "_bn" weight key
        # whenever normalization is enabled
        if body_config.get("arch").startswith("vgg"):
            if body_config["normalization_mode"] != 'off':
                arch = body_config.get("arch") + '_bn'

        # Download pre trained model
        log_debug("Downloading pre - trained model weights ")

        if body_config.get("source_url") == "cvut":
            # BUG FIX: test the (possibly "_bn"-suffixed) key that is
            # actually used for the lookup below, not the bare arch name;
            # also the message previously blamed source_url = pytorch
            if arch not in model_urls_cvut:
                raise ValueError(
                    "body arch %r not found with source_url = cvut" % arch)
            # BUG FIX: the log call had no %s placeholder, so the URL was
            # never rendered into the message
            log_info("Downloading from %s", model_urls_cvut[arch])
            state_dict = load_state_dict_from_url(model_urls_cvut[arch],
                                                  progress=True)

        elif body_config.get("source_url") == "pytorch":
            if arch not in model_urls:
                raise ValueError(
                    "body arch %r not found with source_url = pytorch" % arch)
            state_dict = load_state_dict_from_url(model_urls[arch],
                                                  progress=True)

        else:
            raise ValueError(" try source_url = cvut  or pytorch  ")

        # Convert downloaded weights to the unified format and cache them
        converted_model = body.convert(state_dict)
        folder = args.directory + "/image_net"

        if not path.exists(folder):
            log_debug("Create path to save pretrained backbones: %s ", folder)
            makedirs(folder)

        body_path = folder + "/" + arch + ".pth"
        log_debug("Saving pretrained backbones in : %s ", body_path)
        torch.save(converted_model, body_path)

        # Load converted weights into the model
        body.load_state_dict(torch.load(body_path, map_location="cpu"))

        # Freeze backbone modules mod1..mod<num_frozen>
        # (config read hoisted out of the per-module loop)
        num_frozen = body_config.getint("num_frozen")
        for n, m in body.named_modules():
            for mod_id in range(1, num_frozen + 1):
                if ("mod%d" % mod_id) in n:
                    freeze_params(m)

    else:
        log_info("Initialize body to train from scratch")
        init_weights(body, body_config)

    # Optional feature pyramid on top of the backbone
    if fpn_config.getboolean("fpn"):
        # Create FPN (removed unused read of fpn "outputs")
        body_channels = body_config.getstruct("out_channels")

        fpn_inputs = fpn_config.getstruct("inputs")

        fpn = FPN([body_channels[inp] for inp in fpn_inputs],
                  fpn_config.getint("out_channels"),
                  fpn_config.getint("extra_scales"), norm_act_static,
                  fpn_config.get("interpolation"))

        body = FPNBody(body, fpn, fpn_inputs)

        output_dim = fpn_config.getint("out_channels")
    else:
        output_dim = OUTPUT_DIM[body_config.get("arch")]

    # Global (image retrieval) head: loss, algorithm and pooling head
    global_loss = globalFeatureLoss(name=ir_config.get("loss"),
                                    sigma=ir_config.getfloat("loss_margin"))

    global_algo = globalFeatureAlgo(
        loss=global_loss,
        min_level=ir_config.getint("fpn_min_level"),
        fpn_levels=ir_config.getint("fpn_levels"))

    global_head = globalHead(pooling=ir_config.getstruct("pooling"),
                             normal=ir_config.getstruct("normal"),
                             dim=output_dim)

    # Data augmentation applied inside the network
    augment = RandomAugmentation(rgb_mean=data_config.getstruct("rgb_mean"),
                                 rgb_std=data_config.getstruct("rgb_std"))

    # Create a generic image retrieval network
    net = ImageRetrievalNet(body, global_algo, global_head, augment=augment)

    return net, net_modules, output_dim
# Code example #2
def make_model(args, config):
    """Build a dense local-feature matching network (LoFTR-style) from *config*.

    Args:
        args: namespace providing ``directory``, the folder under which
            converted pretrained backbone weights are cached.
        config: mapping with "body", "fpn", "local", "transformer" and
            "dataloader" sections, each exposing
            get/getint/getfloat/getboolean/getstruct.

    Returns:
        (net, output_dim): the assembled DenseNet and the dimensionality of
        the features it produces.

    Raises:
        ValueError: if ``source_url`` is unknown, or the requested arch has
            no pretrained weights for the chosen source.
    """
    # parse params with default values
    body_config = config["body"]
    fpn_config = config["fpn"]
    local_config = config["local"]
    transformer_config = config["transformer"]

    # NOTE(review): read but unused below — kept so a missing "dataloader"
    # section still fails early; confirm whether it can be dropped
    data_config = config["dataloader"]

    # BN + activation factories derived from the body section
    norm_act_static, norm_act_dynamic = norm_act_from_config(body_config)

    # Create backbone
    log_debug("Creating backbone model %s", body_config.get("arch"))

    body_fn = models.__dict__[body_config.get("arch")]
    body_params = body_config.getstruct("body_params") if body_config.get(
        "body_params") else {}
    body = body_fn(norm_act=norm_act_static, config=body_config, **body_params)

    if body_config.getboolean("pretrained"):
        arch = body_config.get("arch")

        # vgg ships with and without batch norm: use the "_bn" weight key
        # whenever normalization is enabled
        if body_config.get("arch").startswith("vgg"):
            if body_config["normalization_mode"] != 'off':
                arch = body_config.get("arch") + '_bn'

        # Download pre trained model
        log_debug("Downloading pre - trained model weights ")

        if body_config.get("source_url") == "cvut":
            # BUG FIX: test the (possibly "_bn"-suffixed) key that is
            # actually used for the lookup below, not the bare arch name;
            # also the message previously blamed source_url = pytorch
            if arch not in model_urls_cvut:
                raise ValueError(
                    "body arch %r not found with source_url = cvut" % arch)
            # BUG FIX: the log call had no %s placeholder, so the URL was
            # never rendered into the message
            log_info("Downloading from %s", model_urls_cvut[arch])
            state_dict = load_state_dict_from_url(model_urls_cvut[arch],
                                                  progress=True)

        elif body_config.get("source_url") == "pytorch":
            if arch not in model_urls:
                raise ValueError(
                    "body arch %r not found with source_url = pytorch" % arch)
            state_dict = load_state_dict_from_url(model_urls[arch],
                                                  progress=True)

        else:
            raise ValueError(" try source_url = cvut  or pytorch  ")

        # Convert downloaded weights to the unified format and cache them
        converted_model = body.convert(state_dict)
        folder = args.directory + "/image_net"

        if not path.exists(folder):
            log_debug("Create path to save pretrained backbones: %s ", folder)
            makedirs(folder)

        body_path = folder + "/" + arch + ".pth"
        log_debug("Saving pretrained backbones in : %s ", body_path)
        torch.save(converted_model, body_path)

        # Load converted weights into the model
        body.load_state_dict(torch.load(body_path, map_location="cpu"))

        # Freeze the backbone modules listed in "num_frozen"
        # (config read hoisted out of the per-module loop)
        frozen_ids = body_config.getstruct("num_frozen")
        for n, m in body.named_modules():
            for mod_id in frozen_ids:
                if ("mod%d" % mod_id) in n:
                    freeze_params(m)

        # NOTE(review): the same modules frozen above are then deleted here;
        # freezing deleted modules is a no-op — confirm both steps are intended
        for mod_id in frozen_ids:
            key = "mod" + str(mod_id)
            delattr(body, key)

    else:
        log_info("Initialize body to train from scratch")
        init_weights(body, body_config)

    # Optional feature pyramid on top of the backbone
    if fpn_config.getboolean("fpn"):

        # Create FPN
        body_channels = body_config.getstruct("out_channels")

        fpn_inputs = fpn_config.getstruct("inputs")

        fpn = FPN([body_channels[inp] for inp in fpn_inputs],
                  fpn_config.getint("out_channels"),
                  fpn_config.getint("extra_scales"), norm_act_static,
                  fpn_config.get("interpolation"))

        body = FPNBody(body, fpn, fpn_inputs)

        output_dim = fpn_config.getint("out_channels")
    else:
        output_dim = OUTPUT_DIM[body_config.get("arch")]

    # Local-feature components: keypoint generator, loss, matching, algo
    local_gen = DenseFeatureGenerator(
        max_predictions=local_config.getint("max_kpt"),
        nms_threshold=None,
        score_threshold=None)

    local_loss = DenseFeatureLoss(
        name=local_config.get("loss"),
        gamma=local_config.getfloat("gamma"),
        epipolar_margin=local_config.getstruct("epipolar_margin"),
        cyclic_margin=local_config.getfloat("cyclic_margin"))

    local_pos_encoding = PositionEncodingSine(embedding_size=output_dim)

    local_matching_coarse = CoarseMatching(
        thr=local_config.getfloat("thr"),
        border_rm=local_config.getint("border_rm"),
        temperature=local_config.getfloat("temperature"))

    local_matching_fine = FineMatching()

    local_algo = DenseFeatureAlgo(
        generator=local_gen,
        loss=local_loss,
        coarse_matching=local_matching_coarse,
        fine_matching=local_matching_fine,
        pos_encoding=local_pos_encoding,
        min_level=local_config.getint("fpn_min_level"),
        fpn_levels=local_config.getint("fpn_levels"),
        win_size=local_config.getint("win_size"))

    # Coarse- and fine-level transformer heads share hyper-parameters
    # except for their layer counts
    local_head_coarse = DenseHead(
        dim=output_dim,
        embedding_size=transformer_config.getint("embedding_size"),
        num_head=transformer_config.getint("num_head"),
        layer_names=transformer_config.getstruct("layer_names"),
        layer_num=transformer_config.getint("layer_num_coarse"),
        attention=transformer_config.get("attention"))

    local_head_fine = DenseHead(
        dim=output_dim,
        embedding_size=transformer_config.getint("embedding_size"),
        num_head=transformer_config.getint("num_head"),
        layer_names=transformer_config.getstruct("layer_names"),
        layer_num=transformer_config.getint("layer_num_fine"),
        attention=transformer_config.get("attention"))

    # A non-zero embedding size overrides the reported output dimension
    if local_config.getint("embedding_size"):
        output_dim = local_config.getint("embedding_size")

    # Create a generic Local features network
    net = DenseNet(body, local_algo, local_head_coarse, local_head_fine)

    return net, output_dim
# Code example #3
def make_model(args, config):
    """Assemble a local-feature extraction network from *config*.

    Args:
        args: command-line namespace (present for signature parity with
            the other builders; not used here).
        config: mapping with "body", "fpn", "local" and "dataloader"
            sections, each exposing get/getint/getfloat/getboolean/getstruct.

    Returns:
        (net, output_dim): the assembled localNet and the dimensionality
        of the features it produces.
    """
    cfg_body = config["body"]
    cfg_fpn = config["fpn"]
    cfg_local = config["local"]
    cfg_data = config["dataloader"]

    # Normalization + activation factories derived from the body section.
    norm_act_static, norm_act_dynamic = norm_act_from_config(cfg_body)

    # Backbone.
    arch_name = cfg_body.get("arch")
    log_debug("Creating backbone model %s", arch_name)

    backbone_fn = models.__dict__[arch_name]
    extra_params = (cfg_body.getstruct("body_params")
                    if cfg_body.get("body_params") else {})
    body = backbone_fn(norm_act=norm_act_static, config=cfg_body,
                       **extra_params)

    # Optional feature pyramid on top of the backbone.
    if cfg_fpn.getboolean("fpn"):
        channel_map = cfg_body.getstruct("out_channels")
        pyramid_inputs = cfg_fpn.getstruct("inputs")
        input_channels = [channel_map[name] for name in pyramid_inputs]

        pyramid = FPN(input_channels,
                      cfg_fpn.getint("out_channels"),
                      cfg_fpn.getint("extra_scales"),
                      norm_act_static,
                      cfg_fpn.get("interpolation"))
        body = FPNBody(body, pyramid, pyramid_inputs)
        output_dim = cfg_fpn.getint("out_channels")
    else:
        output_dim = OUTPUT_DIM[arch_name]

    # Local-feature components: keypoint generator, loss, algorithm, head.
    keypoint_generator = localFeatureGenerator(
        max_predictions=cfg_local.getint("max_kpt"),
        nms_threshold=None,
        score_threshold=None)

    feature_loss = localFeatureLoss(
        name=cfg_local.get("loss"),
        gamma=cfg_local.getfloat("gamma"),
        epipolar_margin=cfg_local.getstruct("epipolar_margin"),
        cyclic_margin=cfg_local.getfloat("cyclic_margin"))

    feature_algo = localFeatureAlgo(
        keypoint_generator=keypoint_generator,
        loss=feature_loss,
        min_level=cfg_local.getint("fpn_min_level"),
        fpn_levels=cfg_local.getint("fpn_levels"))

    feature_head = localHead(
        dim=output_dim, embedding_size=cfg_local.getint("embedding_size"))

    # A non-zero embedding size overrides the reported output dimension.
    if cfg_local.getint("embedding_size"):
        output_dim = cfg_local.getint("embedding_size")

    # Assemble the generic local-feature network.
    net = localNet(body, feature_algo, feature_head)

    return net, output_dim