def get_basic_sparsity_config(model_size=4, input_sample_size=(1, 1, 4, 4),
                              sparsity_init=0.02, sparsity_target=0.5,
                              sparsity_steps=2, sparsity_training_steps=3):
    """Build a ``Config`` requesting ``rb_sparsity`` on the ``basic_sparse_conv`` model.

    The compression section uses the ``"polynomial"`` schedule, whose
    parameters are taken verbatim from the arguments. Only the ``conv`` layer
    is flagged with ``sparsify: True``.

    Args:
        model_size: Stored under the top-level ``"model_size"`` key.
        input_sample_size: Stored under ``"input_info" -> "sample_size"``.
        sparsity_init: Stored as ``params.sparsity_init``.
        sparsity_target: Stored as ``params.sparsity_target``.
        sparsity_steps: Stored as ``params.sparsity_steps``.
        sparsity_training_steps: Stored as ``params.sparsity_training_steps``.

    Returns:
        A new ``Config`` populated via a single ``update`` call.
    """
    # Schedule parameters forwarded unchanged from the caller.
    schedule_params = {
        "schedule": "polynomial",
        "sparsity_init": sparsity_init,
        "sparsity_target": sparsity_target,
        "sparsity_steps": sparsity_steps,
        "sparsity_training_steps": sparsity_training_steps,
    }
    compression_section = {
        "algorithm": "rb_sparsity",
        "params": schedule_params,
        "layers": {"conv": {"sparsify": True}},
    }
    top_level = {
        "model": "basic_sparse_conv",
        "model_size": model_size,
        "input_info": {"sample_size": input_sample_size},
        "compression": compression_section,
    }

    sparsity_cfg = Config()
    sparsity_cfg.update(top_level)
    return sparsity_cfg
def get_empty_config(model_size=4, input_sample_size=(1, 1, 4, 4)):
    """Build a ``Config`` for ``basic_sparse_conv`` with no ``compression`` section.

    Args:
        model_size: Stored under the ``"model_size"`` key.
        input_sample_size: Stored under the flat ``"input_sample_size"`` key.

    Returns:
        A new ``Config`` holding only the model name, size and input shape.
    """
    settings = {
        "model": "basic_sparse_conv",
        "model_size": model_size,
        "input_sample_size": input_sample_size,
    }
    empty_cfg = Config()
    empty_cfg.update(settings)
    return empty_cfg
def get_basic_magnitude_sparsity_config(input_sample_size=(1, 1, 4, 4)):
    """Build a ``Config`` requesting ``magnitude_sparsity`` on ``basic_sparse_conv``.

    The algorithm's ``params`` mapping is left empty.

    Args:
        input_sample_size: Stored under the flat ``"input_sample_size"`` key.

    Returns:
        A new ``Config`` populated via a single ``update`` call.
    """
    magnitude_section = {
        "algorithm": "magnitude_sparsity",
        "params": {},
    }
    result = Config()
    result.update({
        "model": "basic_sparse_conv",
        "input_sample_size": input_sample_size,
        "compression": magnitude_section,
    })
    return result
def get_basic_sparsity_plus_quantization_config(input_sample_size=(1, 1, 4, 4)):
    """Build a ``Config`` that chains ``rb_sparsity`` and ``quantization``.

    ``"compression"`` is a list of algorithm entries, in that order, each
    carrying only its ``"algorithm"`` name. No ``"model"`` key is set.

    Args:
        input_sample_size: Stored under the flat ``"input_sample_size"`` key.

    Returns:
        A new ``Config`` populated via a single ``update`` call.
    """
    # Order matters: the list is stored exactly as built here.
    algorithm_chain = [
        {"algorithm": algo_name}
        for algo_name in ("rb_sparsity", "quantization")
    ]
    combined_cfg = Config()
    combined_cfg.update({
        "input_sample_size": input_sample_size,
        "compression": algorithm_chain,
    })
    return combined_cfg
def get_squeezenet_quantization_config(model_size=32):
    """Build a ``Config`` requesting ``quantization`` on ``squeezenet1_1_custom``.

    The input shape is ``(3, 3, model_size, model_size)``. The quantization
    initializer's ``num_init_steps`` is set to 0.

    Args:
        model_size: Stored under ``"model_size"`` and used for the last two
            dimensions of the input shape.

    Returns:
        A new ``Config`` populated via a single ``update`` call.
    """
    sample_shape = (3, 3, model_size, model_size)
    quant_section = {
        "algorithm": "quantization",
        "initializer": {"num_init_steps": 0},
    }
    squeezenet_cfg = Config()
    squeezenet_cfg.update({
        "model": "squeezenet1_1_custom",
        "model_size": model_size,
        "input_sample_size": sample_shape,
        "compression": quant_section,
    })
    return squeezenet_cfg
def get_basic_quantization_config(model_size=4):
    """Build a ``Config`` requesting ``quantization`` on ``basic_quant_conv``.

    The input shape is ``(1, 1, model_size, model_size)``. The initializer's
    ``num_init_steps`` is 0 and ``params`` is empty.

    Args:
        model_size: Stored under ``"model_size"`` and used for the last two
            dimensions of the input shape.

    Returns:
        A new ``Config`` populated via a single ``update`` call.
    """
    sample_shape = (1, 1, model_size, model_size)
    quant_section = {
        "algorithm": "quantization",
        "initializer": {"num_init_steps": 0},
        "params": {},
    }
    quant_cfg = Config()
    quant_cfg.update({
        "model": "basic_quant_conv",
        "model_size": model_size,
        "input_sample_size": sample_shape,
        "compression": quant_section,
    })
    return quant_cfg
def get_basic_pruning_config(input_sample_size=(1, 1, 4, 4)):
    """Build a ``Config`` skeleton for the ``pruning_conv_model`` test model.

    The ``"compression"`` section contains only an empty ``"params"`` mapping
    and deliberately no ``"algorithm"`` key. NOTE(review): presumably callers
    fill in the pruning algorithm name and params — confirm against call sites.

    Args:
        input_sample_size: Stored under ``"input_info" -> "sample_size"``.

    Returns:
        A new ``Config`` populated via a single ``update`` call.
    """
    input_section = {"sample_size": input_sample_size}
    pruning_cfg = Config()
    pruning_cfg.update({
        "model": "pruning_conv_model",
        "input_info": input_section,
        "compression": {"params": {}},
    })
    return pruning_cfg
def test_get_default_weight_decay(algo, ref_weight_decay):
    """Check that ``get_default_weight_decay`` returns ``ref_weight_decay`` for ``algo``.

    The config handed to ``get_default_weight_decay`` contains only
    ``{"compression": {"algorithm": algo}}``. ``algo`` and
    ``ref_weight_decay`` are presumably supplied by a pytest parametrization
    not visible here — confirm against the surrounding test module.
    """
    algo_cfg = Config()
    algo_cfg.update({"compression": {"algorithm": algo}})
    computed = get_default_weight_decay(algo_cfg)
    assert ref_weight_decay == computed