Example #1
# These examples rely on torch plus the project's SAsimulate2/SAsimulate3
# modules and the bit2float/float2bit helpers.
import torch

def weight_map(weights, mapped_float, mapped_binary, error_rate, indices, device):
    shape = weights.shape
    weights_flat = weights.view(-1)
    # Layers with 16 or fewer weights are too small to remap; return them unchanged.
    if weights_flat.numel() <= 16:
        return weights
    # Creating masks for all weights in one layer
    mask0_binary, mask1_binary = SAsimulate3.create_mask(shape, error_rate=error_rate)
    # Reshape the fault masks to (blocks, 16 weights, 32 bits).
    mask0_binary = mask0_binary.view(-1, 16, 32)
    mask1_binary = mask1_binary.view(-1, 16, 32)
    new_weight_binary = torch.empty(
        mapped_binary.shape, device=device, dtype=torch.int8
    )
    # Inject the same physical fault pattern into each of the 16 candidate mappings.
    for i in range(16):
        new_weight_binary[:, :, i, :] = SAsimulate3.make_SA(
            mapped_binary[:, :, i, :], mask0_binary, mask1_binary
        )
    # Convert the faulty bit patterns back to floats in two halves to limit peak memory.
    half_shape = new_weight_binary.shape[0] // 2
    new_weight = torch.empty(half_shape * 2, 16, 16, device=device)
    new_weight[0:half_shape, ...] = bit2float(
        new_weight_binary[0:half_shape, ...], num_e_bits=8, num_m_bits=23, bias=127.0
    )
    new_weight[half_shape:, ...] = bit2float(
        new_weight_binary[half_shape : 2 * half_shape, ...],
        num_e_bits=8,
        num_m_bits=23,
        bias=127.0,
    )

    dev_map = torch.abs(mapped_float - new_weight)  # deviation of each candidate from the ideal weights
    dev_sum_map = torch.sum(dev_map, dim=1)
    min_dev, best_map = torch.min(dev_sum_map, dim=1)  # best mapping per block
    best_map3d = best_map.unsqueeze(1).repeat(1, 16).unsqueeze(1)
    best_map_16 = torch.gather(new_weight, dim=1, index=best_map3d).squeeze(1)  # select the best of the 16 cases
    idx_map = torch.index_select(indices, dim=0, index=best_map)
    weight_remap = torch.gather(best_map_16, dim=1, index=idx_map)  # undo the permutation of the chosen mapping
    new_weights = weight_remap.view(shape)

    return new_weights
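
The selection step above relies on torch.gather with a per-block index expanded across the 16 weight positions. A minimal self-contained sketch of just that step, with hypothetical toy sizes (3 blocks, 4 candidates) in place of the real layer dimensions:

import torch

# Toy stand-in for new_weight: (blocks, candidates, weights_per_block).
new_weight = torch.arange(3 * 4 * 16, dtype=torch.float32).view(3, 4, 16)
best_map = torch.tensor([2, 0, 3])  # chosen candidate per block
best_map3d = best_map.unsqueeze(1).repeat(1, 16).unsqueeze(1)  # (3, 1, 16)
best_rows = torch.gather(new_weight, dim=1, index=best_map3d).squeeze(1)  # (3, 16)
print(best_rows.shape)  # torch.Size([3, 16])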
Example #2
def make_SA(weights, mask, mask1):
    # Inputs must already be flat 1-D bit tensors.
    assert weights.shape == weights.view(-1).shape
    assert mask.shape == mask.view(-1).shape
    assert mask1.shape == mask1.view(-1).shape
    weights = weights.view(-1)
    ## Inject errors
    output = ((weights + mask) > 0.).float()   # OR with mask: masked bits forced to 1
    output = ((output - mask1) > 0.).float()   # AND NOT mask1: masked bits forced to 0
    # Regroup into 32-bit words and decode back to floats.
    output = output.view(output.numel() // 32, 32)
    float_tensor = bit2float(output, num_e_bits=8, num_m_bits=23, bias=127.)
    return float_tensor
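
Since the two thresholded lines above amount to bitwise OR and AND-NOT on 0/1 tensors, their effect is easiest to see on a toy bit vector. A minimal sketch with hypothetical four-bit masks (not the project's create_mask output):

import torch

bits = torch.tensor([0., 1., 1., 0.])
mask = torch.tensor([1., 0., 0., 0.])   # force bit 0 to 1
mask1 = torch.tensor([0., 0., 1., 0.])  # force bit 2 to 0
out = ((bits + mask) > 0.).float()      # -> tensor([1., 1., 1., 0.])
out = ((out - mask1) > 0.).float()      # -> tensor([1., 1., 0., 0.])
print(out)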
Example #3
def make_SA2(weights, mask, mask1):
    # Unlike make_SA, this variant encodes the float weights to bits itself.
    assert weights.shape == weights.view(-1).shape
    assert mask.shape == mask.view(-1).shape
    assert mask1.shape == mask1.view(-1).shape
    conv_binary = float2bit(weights, num_e_bits=8, num_m_bits=23, bias=127.)
    shape = conv_binary.shape
    conv_binary = conv_binary.view(-1)
    ## Inject errors
    output = ((conv_binary + mask) > 0.).float()  # OR with mask: masked bits forced to 1
    output = ((output - mask1) > 0.).float()      # AND NOT mask1: masked bits forced to 0
    output = output.view(shape)
    float_tensor = bit2float(output, num_e_bits=8, num_m_bits=23, bias=127.)
    return float_tensor
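
For intuition about what a single stuck-at bit does to a decoded weight, here is a self-contained sketch using torch's bit-level reinterpretation rather than the project's float2bit/bit2float helpers:

import torch

w = torch.tensor([1.0, -2.0], dtype=torch.float32)
raw = w.view(torch.int32)                       # reinterpret the float bits as int32
faulty = (raw | (1 << 22)).view(torch.float32)  # mantissa MSB stuck at 1
print(faulty)                                   # tensor([ 1.5000, -3.0000])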
Example #4
def ECC_method(state_dict, total_param, error_total, device):
    with torch.no_grad():
        for name, param in state_dict.items():
            if "weight" not in name:
                continue
            shape = param.data.shape
            # Share the global error budget across layers by parameter count.
            error_layer = (param.numel() / total_param) * error_total
            param_binary = float2bit(
                param.data, num_e_bits=8, num_m_bits=23, bias=127.0
            )
            mask, mask1 = SAsimulate2.create_mask(param_binary, error_layer)
            output = SAsimulate2.make_SA_ECC(param.data.view(-1), mask, mask1)
            # Correct the injected faults with the ECC scheme, then decode.
            correct_binary = ECC(output, param_binary)
            float_tensor = bit2float(
                correct_binary, num_e_bits=8, num_m_bits=23, bias=127.0
            )
            param.data = float_tensor.view(shape)
    return state_dict
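
The line error_layer = (param.numel() / total_param) * error_total spreads a global error budget across layers in proportion to their size. A minimal sketch on a hypothetical two-layer model, with total_param computed the same way it presumably is at the call site:

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))
state_dict = model.state_dict()
total_param = sum(p.numel() for n, p in state_dict.items() if "weight" in n)
error_total = 100
for name, param in state_dict.items():
    if "weight" in name:
        error_layer = (param.numel() / total_param) * error_total
        print(f"{name}: {error_layer:.1f} of {error_total} errors")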
Example #5
def method0(state_dict, total_param, error_total, device):
    with torch.no_grad():
        for name, param in state_dict.items():
            if "weight" not in name:
                continue
            shape = param.data.shape
            print(name, shape)
            # Share the global error budget across layers by parameter count.
            error_layer = (param.numel() / total_param) * error_total
            # Boolean bit representation keeps memory usage low.
            param_binary = float2bit(
                param, num_e_bits=8, num_m_bits=23, bias=127.0
            ) > 0.
            mask, mask1 = SAsimulate3.create_mask_bool(shape, error_layer)
            output = SAsimulate3.make_SA_bool(param_binary.view(-1), mask, mask1)
            output = bit2float(output.view(param_binary.shape).type(torch.int8))
            param.data = output.view(shape)
            torch.cuda.empty_cache()
    return state_dict
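
make_SA_bool works on torch.bool tensors, which avoids the float thresholding entirely. A plausible boolean analogue of the injection step, labeled as an assumption since the project's implementation isn't shown here:

import torch

# Assumed semantics: OR forces bits to 1, AND-NOT forces bits to 0.
bits = torch.tensor([True, False, True, False])
mask = torch.tensor([False, True, False, False])   # positions forced to 1
mask1 = torch.tensor([True, False, False, False])  # positions forced to 0
out = (bits | mask) & ~mask1
print(out)  # tensor([False,  True,  True, False])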
Example #6
def weight_map2(weights, mapped_float, mapped_binary, error_rate, indices, device):
    shape = weights.shape
    weights_flat = weights.view(-1)
    # Layers with 16 or fewer weights are too small to remap; return them unchanged.
    if weights_flat.numel() <= 16:
        return weights

    # Creating masks for all weights in one layer
    mask0_binary, mask1_binary = SAsimulate3.create_mask_bool(shape, error_rate=error_rate)
    # Reshape the fault masks to (blocks, 16 weights, 32 bits).
    mask0_binary = mask0_binary.view(-1, 16, 32)
    mask1_binary = mask1_binary.view(-1, 16, 32)

    # Append the bitwise complement of every block, doubling the candidate mappings to 32.
    flip_mapped = ~mapped_binary
    mapped_binary = torch.cat((mapped_binary, flip_mapped), dim=1)
    new_weight_binary = torch.empty(
        mapped_binary.shape, device=device, dtype=torch.bool
    )

    # Inject the same physical fault pattern into each of the 32 candidate mappings.
    for i in range(32):
        new_weight_binary[:, i, :, :] = SAsimulate3.make_SA_bool(
            mapped_binary[:, i, :, :], mask0_binary, mask1_binary
        )
    # Invert the complemented copies back to the original bit sense.
    new_weight_binary[:, 16:32, :, :] = ~new_weight_binary[:, 16:32, :, :]

    # Decode candidates one at a time to limit peak memory.
    new_weight = torch.empty(new_weight_binary.shape[0], 32, 16, device=device)
    for idx in range(32):
        new_binary = new_weight_binary[:, idx, ...]
        new_weight[:, idx, :] = bit2float(new_binary.type(torch.int8))

    # Duplicate the reference weights to match the doubled candidate dimension.
    mapped_float = torch.cat((mapped_float, mapped_float), dim=1)

    dev_map = torch.abs(mapped_float - new_weight)  # deviation of each candidate from the ideal weights
    dev_sum_map = torch.sum(dev_map, dim=2)
    min_dev, best_map = torch.min(dev_sum_map, dim=1)  # best mapping per block
    best_map3d = best_map.unsqueeze(1).repeat(1, 16).unsqueeze(1)
    best_map_16 = torch.gather(new_weight, dim=1, index=best_map3d).squeeze(1)  # select the best of the 32 cases
    idx_map = torch.index_select(indices, dim=0, index=best_map)
    weight_remap = torch.gather(best_map_16, dim=1, index=idx_map)  # undo the permutation of the chosen mapping
    new_weights = weight_remap.view(shape)

    return new_weights
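
The key difference from weight_map is the complement trick at the top: each block is also stored bit-inverted, doubling the candidate mappings from 16 to 32. A toy self-contained sketch of why that helps, concatenating along dim 0 for simplicity where weight_map2 uses the candidate dimension, with hypothetical sizes (4 blocks of 8 bits):

import torch

# A physical bit stuck at 0 reads back as 0 in the direct copy but as 1 in the
# restored inverted copy, so whichever copy deviates less can be kept per block.
bits = torch.randint(0, 2, (4, 8), dtype=torch.bool)
candidates = torch.cat((bits, ~bits), dim=0)   # originals, then complements
mask1 = torch.zeros_like(candidates)
mask1[:, 3] = True                             # bit 3 stuck at 0 in every word
faulty = candidates & ~mask1                   # inject the stuck-at-0 fault
faulty[4:] = ~faulty[4:]                       # invert the complemented copies back
# For every block, exactly one of the two copies still matches the original bits.
print((faulty[:4] == bits).all(dim=1), (faulty[4:] == bits).all(dim=1))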