Ejemplo n.º 1
0
def round_shift_weights(model, shift_base=2, clone=False):
    """Round the learnable parameters of every DeepShift layer in ``model``.

    Walks the module tree recursively. For each module whose exact type is
    one of the DeepShift layers:

    * ``deepshift.modules.LinearShift`` / ``Conv2dShift``: ``shift`` is
      rounded to the nearest integer and ``sign`` is replaced by
      ``sign.round().sign()``.
    * ``deepshift.modules_q.LinearShiftQ`` / ``Conv2dShiftQ``: ``weight`` is
      replaced by ``utils.round_power_of_2(weight, shift_base)``.

    In both cases a non-``None`` ``bias`` is replaced by
    ``utils.round_to_fixed(bias, fraction_bits=16, integer_bits=16)``.

    Args:
        model: module (tree) to process.
        shift_base: forwarded to ``utils.round_power_of_2`` for the
            ``*ShiftQ`` layers.
        clone: if true, deep-copy ``model`` first and modify only the copy.

    Returns:
        The processed model: the copy when ``clone`` is true, otherwise
        ``model`` itself (modified in place).
    """
    if clone:
        model = copy.deepcopy(model)

    for name, module in reversed(model._modules.items()):
        # ``_modules`` may contain ``None`` placeholders (e.g. registered via
        # ``add_module(name, None)``); they hold nothing to round and would
        # crash on ``.children()``.
        if module is None:
            continue

        if next(module.children(), None) is not None:
            # Container: recurse. No ``clone`` here -- the whole tree was
            # already copied above if requested.
            model._modules[name] = round_shift_weights(model=module,
                                                       shift_base=shift_base)

        # Exact-type matching (not isinstance) is intentional, to keep the
        # original dispatch rules.
        module_type = type(module)
        if module_type in (deepshift.modules.LinearShift,
                           deepshift.modules.Conv2dShift):
            module.shift.data = module.shift.round()
            module.sign.data = module.sign.round().sign()
        elif module_type in (deepshift.modules_q.LinearShiftQ,
                             deepshift.modules_q.Conv2dShiftQ):
            module.weight.data = utils.round_power_of_2(
                module.weight, shift_base)
        else:
            continue

        # Shared by both DeepShift layer families.
        if module.bias is not None:
            module.bias.data = utils.round_to_fixed(module.bias,
                                                    fraction_bits=16,
                                                    integer_bits=16)

    return model
Ejemplo n.º 2
0
def round_shift_weights(model, clone=False):
    """Put ``model`` in eval mode and round its DeepShift layer parameters.

    Walks the module tree recursively. For each module whose exact type is
    one of the DeepShift layers:

    * ``deepshift.modules.LinearShift`` / ``Conv2dShift``: ``shift`` is
      rounded to the nearest integer and ``sign`` is replaced by
      ``sign.round().sign()``.
    * ``deepshift.modules_q.LinearShiftQ`` / ``Conv2dShiftQ``: ``weight`` is
      replaced by ``utils.round_power_of_2(weight)``.

    In both cases a non-``None`` ``bias`` is replaced by
    ``utils.round_to_fixed(bias, fraction=16, integer=16)``.

    Args:
        model: module (tree) to process.
        clone: if true, deep-copy ``model`` first and modify only the copy.

    Returns:
        The processed model, in eval mode: the copy when ``clone`` is true,
        otherwise ``model`` itself (modified in place).
    """
    if clone:
        model = copy.deepcopy(model)
    # Switch to eval only after cloning, so that clone=True leaves the
    # caller's model (including its train/eval flag) untouched.
    model.eval()

    for name, module in reversed(model._modules.items()):
        # ``_modules`` may contain ``None`` placeholders (e.g. registered via
        # ``add_module(name, None)``); they hold nothing to round and would
        # crash on ``.children()``.
        if module is None:
            continue

        if next(module.children(), None) is not None:
            # Container: recurse. No ``clone`` here -- the whole tree was
            # already copied above if requested.
            model._modules[name] = round_shift_weights(model=module)

        # Exact-type matching (not isinstance) is intentional, to keep the
        # original dispatch rules.
        module_type = type(module)
        if module_type in (deepshift.modules.LinearShift,
                           deepshift.modules.Conv2dShift):
            module.shift.data = module.shift.round()
            module.sign.data = module.sign.round().sign()
        elif module_type in (deepshift.modules_q.LinearShiftQ,
                             deepshift.modules_q.Conv2dShiftQ):
            module.weight.data = utils.round_power_of_2(module.weight)
        else:
            continue

        # Shared by both DeepShift layer families.
        if module.bias is not None:
            module.bias.data = utils.round_to_fixed(module.bias, fraction=16, integer=16)

    return model
Ejemplo n.º 3
0
def round_shift_weights(model, clone=False, weight_bits=5, act_integer_bits=16, act_fraction_bits=16):
    """Round the learnable parameters of every DeepShift layer in ``model``.

    Walks the module tree recursively. For each module whose exact type is
    one of the DeepShift layers:

    * ``deepshift.modules.LinearShift`` / ``Conv2dShift``: ``shift`` is
      rounded to the nearest integer and ``sign`` is replaced by
      ``sign.round().sign()``.
    * ``deepshift.modules_q.LinearShiftQ`` / ``Conv2dShiftQ``: ``weight`` is
      first passed through ``utils.clampabs`` with bounds
      ``2**shift_range[0]`` and ``2**shift_range[1]``, then through
      ``utils.round_power_of_2``.

    In both cases a non-``None`` ``bias`` is replaced by
    ``utils.round_to_fixed`` using ``act_integer_bits`` / ``act_fraction_bits``.

    Args:
        model: module (tree) to process.
        clone: if true, deep-copy ``model`` first and modify only the copy.
        weight_bits: only forwarded to the recursive calls; not otherwise
            used by this function.
        act_integer_bits: integer bits passed to ``utils.round_to_fixed``.
        act_fraction_bits: fraction bits passed to ``utils.round_to_fixed``.

    Returns:
        The processed model: the copy when ``clone`` is true, otherwise
        ``model`` itself (modified in place).
    """
    if clone:
        model = copy.deepcopy(model)

    for name, module in reversed(model._modules.items()):
        # ``_modules`` may contain ``None`` placeholders (e.g. registered via
        # ``add_module(name, None)``); they hold nothing to round and would
        # crash on ``.children()``.
        if module is None:
            continue

        if next(module.children(), None) is not None:
            # Container: recurse. No ``clone`` here -- the whole tree was
            # already copied above if requested.
            model._modules[name] = round_shift_weights(
                model=module,
                weight_bits=weight_bits,
                act_integer_bits=act_integer_bits,
                act_fraction_bits=act_fraction_bits,
            )

        # Exact-type matching (not isinstance) is intentional, to keep the
        # original dispatch rules.
        module_type = type(module)
        if module_type in (deepshift.modules.LinearShift,
                           deepshift.modules.Conv2dShift):
            module.shift.data = module.shift.round()
            module.sign.data = module.sign.round().sign()
        elif module_type in (deepshift.modules_q.LinearShiftQ,
                             deepshift.modules_q.Conv2dShiftQ):
            # NOTE(review): presumably clamps |weight| into
            # [2**shift_range[0], 2**shift_range[1]] before rounding --
            # confirm against utils.clampabs.
            module.weight.data = utils.clampabs(module.weight.data,
                                                2**module.shift_range[0],
                                                2**module.shift_range[1])
            module.weight.data = utils.round_power_of_2(module.weight)
        else:
            continue

        # Shared by both DeepShift layer families.
        if module.bias is not None:
            module.bias.data = utils.round_to_fixed(module.bias,
                                                    integer_bits=act_integer_bits,
                                                    fraction_bits=act_fraction_bits)

    return model
Ejemplo n.º 4
0
 def forward(ctx, input, z):
     """Return ``utils.round_power_of_2(input, z[0], z[1])``.

     The first two elements of ``z`` are passed as the extra positional
     arguments of ``utils.round_power_of_2`` (their meaning is defined
     there). ``ctx`` is not used.
     """
     first, second = z[0], z[1]
     return utils.round_power_of_2(input, first, second)
Ejemplo n.º 5
0
 def forward(ctx, input, stochastic=False):
     """Return ``utils.round_power_of_2(input, stochastic)``.

     ``stochastic`` is forwarded positionally as the second argument of
     ``utils.round_power_of_2``. ``ctx`` is not used.
     """
     rounder = utils.round_power_of_2
     return rounder(input, stochastic)
Ejemplo n.º 6
0
 def forward(ctx, input):
     """Return ``utils.round_power_of_2(input)``; ``ctx`` is not used."""
     rounded = utils.round_power_of_2(input)
     return rounded