Ejemplo n.º 1
0
def test_can_sparsify_embedding(algo):
    """Check that sparsity algorithm ``algo`` can compress an embedding-only model.

    Builds an NNCF config with a long-typed ``[1, 10]`` input and an initial
    sparsity level of 0.5, compresses ``EmbeddingOnlyModel`` with it, and then
    queries the compression controller's statistics, which must not raise.
    """
    initial_sparsity = 0.5
    nncf_config = NNCFConfig.from_dict({
        "input_info": {"sample_size": [1, 10], "type": "long"},
        "compression": {"algorithm": algo, "sparsity_init": initial_sparsity},
    })
    _, compression_ctrl = create_compressed_model_and_algo_for_test(
        EmbeddingOnlyModel(), nncf_config)

    # Should pass: collecting statistics must not raise for an embedding-only model.
    compression_ctrl.statistics()
Ejemplo n.º 2
0
def nncf_config_with_default_init_args_(mocker):
    """Return ``CONFIG_WITH_ALL_INIT_TYPES`` with default init args registered.

    The init data comes from a single-sample, non-shuffled ``DataLoader`` over
    ``OnesDatasetMock``. The criterion is a ``mocker`` stub whose
    ``batch_size`` attribute is set to 1.
    """
    base_config = NNCFConfig.from_dict(CONFIG_WITH_ALL_INIT_TYPES)

    ones_dataset = OnesDatasetMock(INPUT_SAMPLE_SIZE[1:])
    # num_workers=0 is a workaround for PyTorch MultiprocessingDataLoader issues.
    init_loader = DataLoader(ones_dataset, batch_size=1, num_workers=0, shuffle=False)

    criterion_stub = mocker.stub()
    criterion_stub.batch_size = 1

    return register_default_init_args(base_config, init_loader, criterion_stub)
Ejemplo n.º 3
0
def create_sample_config(args, parser) -> SampleConfig:
    """Build a ``SampleConfig`` from ``args.config`` and attach its NNCF config.

    The JSON file at ``args.config`` is loaded twice: once as the
    ``SampleConfig`` (then overridden by command-line ``args``), and once raw
    to build the ``NNCFConfig``. If the sample config carries a non-None
    ``target_device``, that key is moved out of the sample config and into
    the raw NNCF dict before conversion.
    """
    result = SampleConfig.from_json(args.config)
    result.update_from_args(args, parser)

    resolved_path = Path(args.config).resolve()
    with safe_open(resolved_path) as json_file:
        raw_nncf_dict = json.load(json_file)

    # Move target_device from the sample config into the NNCF config.
    if result.get("target_device") is not None:
        raw_nncf_dict["target_device"] = result.pop("target_device")

    result.nncf_config = NNCFConfig.from_dict(raw_nncf_dict)
    return result
Ejemplo n.º 4
0
 def add_range_init(config):
     """Enable range initialization for every quantization algorithm in ``config``.

     Each entry under ``config['compression']`` whose ``algorithm`` is
     ``'quantization'`` gets an ``initializer['range']`` section with
     ``num_init_samples: 1``, merged into any existing ``initializer`` dict.
     If at least one such entry exists, the config is converted to an
     ``NNCFConfig``. A ``QuantizationRangeInitArgs`` wrapping a mock data
     loader is then registered on it.

     Args:
         config: Config mapping with a ``'compression'`` entry that is either
             a single algorithm dict or a list of them. Quantization entries
             are modified in place.

     Returns:
         The ``NNCFConfig`` with range-init args registered when a
         quantization algorithm is present; otherwise ``config`` itself,
         untouched.
     """
     compressions = config['compression']
     # The compression section may be one algorithm dict or a list of them;
     # iterating a bare dict would walk its keys instead of the algorithms.
     if isinstance(compressions, dict):
         compressions = [compressions]

     has_quantization = False
     for compression in compressions:
         if compression['algorithm'] == 'quantization':
             compression.setdefault('initializer', {}).update(
                 {'range': {'num_init_samples': 1}})
             has_quantization = True

     if not has_quantization:
         return config

     # Convert and register once, after every quantization entry has been
     # updated. Doing this inside the loop would rebuild the config per entry
     # and could drop edits made to later entries.
     data_loader = create_ones_mock_dataloader(config)
     config = NNCFConfig.from_dict(config)
     config.register_extra_structs([
         QuantizationRangeInitArgs(wrap_dataloader_for_init(data_loader))
     ])
     return config