Example #1
0
def get_basic_sparsity_config(model_size=4,
                              input_sample_size=(1, 1, 4, 4),
                              sparsity_init=0.02,
                              sparsity_target=0.5,
                              sparsity_steps=2,
                              sparsity_training_steps=3):
    """Build a ``Config`` for the ``basic_sparse_conv`` model with ``rb_sparsity``.

    The compression section uses the ``"polynomial"`` schedule. The four
    ``sparsity_*`` arguments are copied verbatim into ``compression.params``,
    and only the ``conv`` layer entry is flagged with ``sparsify: True``.

    Args:
        model_size: Stored under the top-level ``"model_size"`` key.
        input_sample_size: Stored as ``input_info.sample_size``.
        sparsity_init: Stored as ``compression.params.sparsity_init``.
        sparsity_target: Stored as ``compression.params.sparsity_target``.
        sparsity_steps: Stored as ``compression.params.sparsity_steps``.
        sparsity_training_steps: Stored as
            ``compression.params.sparsity_training_steps``.

    Returns:
        A new ``Config`` instance populated with the settings above.
    """
    # Keyword-form dicts keep insertion order, so the resulting key order
    # matches a literal written in the same sequence.
    schedule_params = dict(
        schedule="polynomial",
        sparsity_init=sparsity_init,
        sparsity_target=sparsity_target,
        sparsity_steps=sparsity_steps,
        sparsity_training_steps=sparsity_training_steps,
    )
    compression_section = dict(
        algorithm="rb_sparsity",
        params=schedule_params,
        layers=dict(conv=dict(sparsify=True)),
    )
    settings = dict(
        model="basic_sparse_conv",
        model_size=model_size,
        input_info=dict(sample_size=input_sample_size),
        compression=compression_section,
    )

    result = Config()
    result.update(settings)
    return result
def get_empty_config(model_size=4, input_sample_size=(1, 1, 4, 4)):
    """Build a ``Config`` for ``basic_sparse_conv`` without any compression section.

    Args:
        model_size: Stored under the top-level ``"model_size"`` key.
        input_sample_size: Stored under the flat top-level
            ``"input_sample_size"`` key.

    Returns:
        A new ``Config`` instance holding only the model description.
    """
    # NOTE(review): this uses the flat "input_sample_size" key, whereas
    # get_basic_sparsity_config nests the size under "input_info" ->
    # "sample_size"; confirm which schema the consumer expects.
    model_description = dict(
        model="basic_sparse_conv",
        model_size=model_size,
        input_sample_size=input_sample_size,
    )
    empty = Config()
    empty.update(model_description)
    return empty
Example #3
0
def get_basic_magnitude_sparsity_config(input_sample_size=(1, 1, 4, 4)):
    """Build a ``Config`` for ``basic_sparse_conv`` with ``magnitude_sparsity``.

    The compression section carries an empty ``params`` mapping, so whatever
    defaults the consumer applies are left in effect.

    Args:
        input_sample_size: Stored under the top-level ``"input_sample_size"`` key.

    Returns:
        A new ``Config`` instance populated with the settings above.
    """
    magnitude_section = dict(algorithm="magnitude_sparsity", params={})
    settings = dict(
        model="basic_sparse_conv",
        input_sample_size=input_sample_size,
        compression=magnitude_section,
    )
    result = Config()
    result.update(settings)
    return result
Example #4
0
def get_basic_sparsity_plus_quantization_config(input_sample_size=(1, 1, 4, 4)):
    """Build a ``Config`` that chains ``rb_sparsity`` and ``quantization``.

    ``compression`` is a list with one ``{"algorithm": name}`` entry per
    algorithm, in the order ``rb_sparsity`` then ``quantization``. No
    ``"model"`` key is set here.

    Args:
        input_sample_size: Stored under the top-level ``"input_sample_size"`` key.

    Returns:
        A new ``Config`` instance populated with the settings above.
    """
    algorithm_names = ("rb_sparsity", "quantization")
    compression_chain = [{"algorithm": name} for name in algorithm_names]

    result = Config()
    result.update({
        "input_sample_size": input_sample_size,
        "compression": compression_chain,
    })
    return result
Example #5
0
def get_squeezenet_quantization_config(model_size=32):
    """Build a ``Config`` for ``squeezenet1_1_custom`` with ``quantization``.

    The input sample size is ``(3, 3, model_size, model_size)``. The leading
    two entries are fixed at 3; presumably they are batch and channel
    counts (NOTE(review): confirm). The quantization initializer is given
    ``num_init_steps = 0``.

    Args:
        model_size: Stored as ``"model_size"`` and used for the last two
            entries of ``"input_sample_size"``.

    Returns:
        A new ``Config`` instance populated with the settings above.
    """
    spatial_dims = (model_size, model_size)
    sample_size = (3, 3) + spatial_dims

    quantization_section = dict(
        algorithm="quantization",
        initializer=dict(num_init_steps=0),
    )
    settings = dict(
        model="squeezenet1_1_custom",
        model_size=model_size,
        input_sample_size=sample_size,
        compression=quantization_section,
    )
    result = Config()
    result.update(settings)
    return result
Example #6
0
def get_basic_quantization_config(model_size=4):
    """Build a ``Config`` for ``basic_quant_conv`` with ``quantization``.

    The input sample size is ``(1, 1, model_size, model_size)``. The
    quantization initializer is given ``num_init_steps = 0``, and ``params``
    is left empty.

    Args:
        model_size: Stored as ``"model_size"`` and used for the last two
            entries of ``"input_sample_size"``.

    Returns:
        A new ``Config`` instance populated with the settings above.
    """
    sample_size = (1, 1) + (model_size, model_size)

    quantization_section = dict(
        algorithm="quantization",
        initializer=dict(num_init_steps=0),
        params={},
    )
    settings = dict(
        model="basic_quant_conv",
        model_size=model_size,
        input_sample_size=sample_size,
        compression=quantization_section,
    )
    result = Config()
    result.update(settings)
    return result
Example #7
0
def get_basic_pruning_config(input_sample_size=(1, 1, 4, 4)):
    """Build a ``Config`` for ``pruning_conv_model`` with an empty compression stub.

    Args:
        input_sample_size: Stored as ``input_info.sample_size``.

    Returns:
        A new ``Config`` instance populated with the settings above.
    """
    # NOTE(review): the compression section has no "algorithm" key, only an
    # empty "params" mapping; presumably callers fill in the algorithm
    # themselves. Verify against the call sites.
    settings = dict(
        model="pruning_conv_model",
        input_info=dict(sample_size=input_sample_size),
        compression=dict(params={}),
    )
    result = Config()
    result.update(settings)
    return result
Example #8
0
def test_get_default_weight_decay(algo, ref_weight_decay):
    """Check that ``get_default_weight_decay`` returns ``ref_weight_decay``.

    The config contains only a ``compression.algorithm`` entry set to
    ``algo``. ``algo`` and ``ref_weight_decay`` are presumably supplied by a
    pytest parametrization that is not visible here.
    """
    weight_decay_config = Config()
    weight_decay_config.update({"compression": {"algorithm": algo}})
    actual_weight_decay = get_default_weight_decay(weight_decay_config)
    assert ref_weight_decay == actual_weight_decay