def init_model(cfg):
    """Build the HRNet model for training together with its model config.

    Args:
        cfg: experiment configuration. This function reads ``cfg.device``
            and ``cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W32``.

    Returns:
        A ``(model, model_cfg)`` tuple. ``model_cfg`` carries the crop size,
        the input normalization statistics, the maximum number of points
        and the input transform built from those statistics.
    """
    max_points = 10

    model_cfg = edict()
    model_cfg.crop_size = (320, 480)
    model_cfg.input_normalization = {
        'mean': [.485, .456, .406],
        'std': [.229, .224, .225]
    }
    model_cfg.num_max_points = max_points

    # Read the statistics back from the config so the transform uses exactly
    # the objects stored on ``model_cfg``.
    norm_stats = model_cfg.input_normalization
    normalize = transforms.Normalize(norm_stats['mean'], norm_stats['std'])
    model_cfg.input_transform = transforms.Compose(
        [transforms.ToTensor(), normalize])

    model = get_hrnet_model(
        width=32,
        ocr_width=128,
        max_interactive_points=model_cfg.num_max_points,
        with_aux_output=True,
    )

    # Order matters: the initializer is applied to the whole model first,
    # then the feature extractor's weights are loaded on top of it.
    model.to(cfg.device)
    weight_init = initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)
    model.apply(weight_init)

    pretrained_path = cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W32
    model.feature_extractor.load_pretrained_weights(pretrained_path)

    return model, model_cfg
def _single_state_entry(state_dict, key_fragment):
    """Return the only value in ``state_dict`` whose key contains ``key_fragment``.

    Raises:
        ValueError: if no key, or more than one key, contains the fragment.
    """
    matches = [value for key, value in state_dict.items()
               if key_fragment in key]
    if len(matches) != 1:
        raise ValueError(
            'expected exactly one state_dict key containing {!r}, '
            'found {}'.format(key_fragment, len(matches)))
    return matches[0]


def load_hrnet_is_model(state_dict,
                        device,
                        backbone='auto',
                        width=48,
                        ocr_width=256,
                        small=False,
                        cpu_dist_maps=False,
                        norm_radius=260):
    """Build an HRNet model, load ``state_dict`` into it and freeze it.

    Args:
        state_dict: mapping from parameter names to tensors.
        device: target passed to ``model.to``.
        backbone: when ``'auto'``, ``width``, ``ocr_width`` and ``small`` are
            derived from ``state_dict`` and the passed values are ignored.
            Any other value uses the arguments as given.
        width: forwarded to ``get_hrnet_model`` (non-auto mode).
        ocr_width: forwarded to ``get_hrnet_model`` (non-auto mode).
        small: forwarded to ``get_hrnet_model`` (non-auto mode).
        cpu_dist_maps: forwarded to ``get_hrnet_model``.
        norm_radius: forwarded to ``get_hrnet_model``.

    Returns:
        The model in eval mode on ``device``, with ``requires_grad`` disabled
        on all parameters.

    Raises:
        ValueError: in auto mode, if ``state_dict`` does not contain exactly
            one key for each of the entries used to infer the architecture.
    """
    if backbone == 'auto':
        # Heuristic: fewer than 1800 feature-extractor entries is treated as
        # the "small" variant.
        num_fe_weights = sum(
            1 for key in state_dict if 'feature_extractor.' in key)
        small = num_fe_weights < 1800

        # The widths are taken from the first dimension of two specific
        # tensors. Explicit errors are raised instead of ``assert`` so the
        # validation is not stripped under ``python -O``.
        ocr_f_down = _single_state_entry(
            state_dict, 'object_context_block.f_down.1.0.bias')
        ocr_width = ocr_f_down.shape[0]

        s2_conv1_w = _single_state_entry(
            state_dict, 'stage2.0.branches.0.0.conv1.weight')
        width = s2_conv1_w.shape[0]

    model = get_hrnet_model(width=width,
                            ocr_width=ocr_width,
                            small=small,
                            with_aux_output=False,
                            cpu_dist_maps=cpu_dist_maps,
                            norm_radius=norm_radius)

    # NOTE(review): strict=False presumably tolerates key mismatches (e.g.
    # auxiliary-head weights absent from this model) - confirm this is
    # intended, since such mismatches are not reported here.
    model.load_state_dict(state_dict, strict=False)
    for param in model.parameters():
        param.requires_grad = False
    model.to(device)
    model.eval()

    return model