예제 #1
0
def create_model(config: SampleConfig, resuming_checkpoint: dict = None):
    """Build the SSD network and wrap it into a compressed model.

    Steps, in order:
      1. Build the network with ``build_ssd``; the image size is read from
         the last dimension of the first input info's shape.
      2. If ``config.get('weights')`` is truthy, load that file on CPU and
         apply its ``"state_dict"`` entry to the network.
      3. Move the network to ``config.device``.
      4. Create the compressed model from the network, ``config.nncf_config``
         and the compression state extracted from ``resuming_checkpoint``.
      5. If the checkpoint yielded a model state dict, load it into the
         compressed model with ``is_resume=True``.
      6. Pass the model through ``prepare_model_for_execution`` and call
         ``train()`` on the result.

    :param config: sample configuration; read via attribute access
        (``nncf_config``, ``model``, ``ssd_params``, ``num_classes``,
        ``device``) and via ``get('weights')``.
    :param resuming_checkpoint: checkpoint handed to
        ``extract_model_and_compression_states``; defaults to ``None``.
    :return: tuple ``(compression_ctrl, compressed_model)``.
    """
    # Image size comes from the last dimension of the first input description.
    image_size = create_input_infos(config.nncf_config)[0].shape[-1]
    ssd_net = build_ssd(
        config.model,
        config.ssd_params,
        image_size,
        config.num_classes,
        config,
    )

    weights_path = config.get('weights')
    if weights_path:
        # Deserialize on CPU with the restricted pickle module; the model is
        # moved to the target device only afterwards.
        checkpoint = torch.load(
            weights_path,
            map_location='cpu',
            pickle_module=restricted_pickle_module,
        )
        load_state(ssd_net, checkpoint["state_dict"])

    ssd_net.to(config.device)

    resumed_state = extract_model_and_compression_states(resuming_checkpoint)
    model_state_dict, compression_state = resumed_state

    compression_ctrl, compressed_model = create_compressed_model(
        ssd_net,
        config.nncf_config,
        compression_state,
    )

    # Restore the model weights only when the checkpoint provided them.
    if model_state_dict is not None:
        load_state(compressed_model, model_state_dict, is_resume=True)

    # The second element of the returned pair is not needed here.
    compressed_model, _unused = prepare_model_for_execution(compressed_model, config)
    compressed_model.train()

    return compression_ctrl, compressed_model
예제 #2
0
def is_pretrained_model_requested(config: SampleConfig) -> bool:
    """Tell whether a pretrained model should be used.

    :param config: configuration object queried via ``get``.
    :return: ``False`` whenever ``config.get('weights')`` is not ``None``;
        otherwise the value stored under ``'pretrained'`` (``True`` when the
        key is absent). The stored value is returned as is, without
        conversion to ``bool``.
    """
    # Explicit weights take precedence over the 'pretrained' setting.
    if config.get('weights') is not None:
        return False
    return config.get('pretrained', True)