Пример #1
0
def test_input_info_specification_from_config(mocker, input_info_test_struct):
    """Check that the config's ``input_info`` entry shapes the model call.

    ``input_info_test_struct[0]`` is written into the quantization config under
    ``"input_info"`` and ``input_info_test_struct[1]`` holds the reference
    ``ModelInputInfo`` objects. After compression, the arguments of the last
    call made to the stub handed to ``MockModel`` (presumably invoked from the
    model's forward — confirm in ``MockModel``) must match the references:
    positional arguments by count and order, keyword arguments by name, and
    every argument by shape and dtype.
    """
    forward_stub = mocker.stub()
    model = MockModel(forward_stub)

    config_entry = input_info_test_struct[0]
    expected_infos = input_info_test_struct[1]  # type: List[ModelInputInfo]
    config = get_basic_quantization_config("symmetric")
    config["input_info"] = config_entry

    # Return values are unused; the unpacking only asserts a 2-item result.
    _compressed_model, _algo = create_compressed_model_and_algo_for_test(model, config)

    # mock's call_args is an (args, kwargs) pair for the most recent call.
    actual_args, actual_kwargs = forward_stub.call_args

    # References without a keyword describe positional args, in order;
    # the rest are looked up by their keyword name.
    expected_positional = [info for info in expected_infos if info.keyword is None]
    expected_by_keyword = {
        info.keyword: info for info in expected_infos if info.keyword is not None
    }

    def _assert_matches(arg: torch.Tensor, ref_arg_info: ModelInputInfo):
        # Shape is compared as a list, so the reference shape is expected to be a list.
        assert list(arg.shape) == ref_arg_info.shape
        assert arg.dtype == ref_arg_info.type

    assert len(actual_args) == len(expected_positional)
    assert len(actual_kwargs) == len(expected_by_keyword)
    assert set(actual_kwargs) == set(expected_by_keyword)

    # Lengths were asserted equal above, so zip pairs every positional argument.
    for actual, expected in zip(actual_args, expected_positional):
        _assert_matches(actual, expected)

    for keyword, actual in actual_kwargs.items():
        _assert_matches(actual, expected_by_keyword[keyword])
Пример #2
0
def test_context_independence(model_name, model_builder, input_size, _case_config):
    """Compress two models with one shared config and check each graph.

    Models are built by ``model_builder[0]()`` and ``model_builder[1]()`` and
    compressed one after the other using the same quantization config. Each
    compressed model is then passed to ``check_model_graph`` together with the
    matching ``model_name[i]`` and ``_case_config.graph_dir`` (presumably a
    directory of reference graphs — confirm in ``check_model_graph``).
    """
    # NOTE(review): only input_size[0] is used, and the same config object is
    # reused for both models; if compression mutates the config, the second
    # model sees the mutated version — confirm this is intended.
    quant_config = get_basic_quantization_config(
        _case_config.quant_type, input_sample_sizes=input_size[0]
    )

    # Build and compress strictly in order: first model 0, then model 1.
    compressed = [
        create_compressed_model_and_algo_for_test(model_builder[idx](), quant_config)[0]
        for idx in (0, 1)
    ]

    for idx, model in enumerate(compressed):
        check_model_graph(model, model_name[idx], _case_config.graph_dir)