Example #1
def get_bitops_fullp():
    """Return the per-layer bitops of a full-precision (32-bit) MobileNetV2.

    A fresh ``mobilenet_v2(QuantOps)`` is placed on ``device``. Its weight and
    activation quantizers are initialized at 32 bits from ``train_loader``
    while the model is in train mode. The model is then switched to eval mode
    and ``get_bitops(model_, 'list')`` is returned.

    Relies on the module-level globals ``mobilenet_v2``, ``QuantOps``,
    ``device``, ``train_loader`` and ``get_bitops``.
    """
    # No dummy input tensor is needed here: get_bitops works on the model
    # itself, and hard-coding `.cuda()` would break CPU-only runs.
    model_ = mobilenet_v2(QuantOps).to(device)
    model_.train()
    QuantOps.initialize(model_, train_loader, 32, weight=True)
    QuantOps.initialize(model_, train_loader, 32, act=True)
    model_.eval()
    return get_bitops(model_, 'list')
Example #2
def get_bitops_total():
    """Return the total bitops of a full-precision (32-bit) MobileNetV2.

    A fresh ``mobilenet_v2(QuantOps)`` is placed on ``device`` and put in eval
    mode. Its weight and activation quantizers are initialized at 32 bits from
    ``train_loader``. The model is then run once on a random dummy batch of
    shape ``[1, 3, S, S]``, where ``S`` is 32 for CIFAR-10/100 and 224
    otherwise.

    The model's forward returns a pair; its second element is returned here.
    Relies on the module-level globals ``mobilenet_v2``, ``QuantOps``,
    ``device``, ``args``, ``train_loader`` and ``torch``.
    """
    model_ = mobilenet_v2(QuantOps).to(device)
    size = 32 if args.dataset in ["cifar100", "cifar10"] else 224
    # Put the dummy input on the same device as the model. The previous
    # `.cuda()` caused a device mismatch whenever `device` was not CUDA.
    dummy_input = torch.randn([1, 3, size, size]).to(device)
    model_.eval()
    QuantOps.initialize(model_, train_loader, 32, weight=True)
    QuantOps.initialize(model_, train_loader, 32, act=True)
    # This is a one-shot measurement pass, so don't keep an autograd graph.
    with torch.no_grad():
        _, bitops = model_(dummy_input)
    return bitops
Example #3
    bitops_first = bitops_list[0]
    bitops_last = bitops_list[-1]

# Target bitops budget. Interior layers are scaled by both the weight and the
# activation bit-width ratios. The first layer is scaled by the weight ratio
# only, and the last layer by the activation ratio only. Scaling all of
# bitops_total by both ratios would discount those two layers by an extra
# factor.
bitops_target = (
    (bitops_total - bitops_first - bitops_last)
    * (args.w_target_bit / 32.)
    * (args.a_target_bit / 32.)
    + bitops_first * (args.w_target_bit / 32.)
    + bitops_last * (args.a_target_bit / 32.)
)

logging.info(f'bitops_total : {int(bitops_total):d}')
logging.info(f'bitops_target: {int(bitops_target):d}')
# Normalize to a built-in float. float() returns an equal value when the
# input already is one, so no type check is needed.
bitops_target = float(bitops_target)

# Make model
if args.model == "mobilenetv2":
    model = mobilenet_v2(QuantOps)
    ckpt_path = "./checkpoint/mobilenet_v2-b0353104.pth"
    if not os.path.isfile(ckpt_path):
        import torch.hub

        # Fetch the pretrained MobileNetV2 weights. Unlike shelling out to
        # `wget`, this raises on failure. It writes to a temporary file that is
        # only moved into place after a successful download, so a failed run
        # never leaves a truncated checkpoint behind. The file-name suffix is
        # the sha256 prefix and is verified.
        os.makedirs("./checkpoint", exist_ok=True)
        torch.hub.download_url_to_file(
            "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
            ckpt_path,
            hash_prefix="b0353104",
        )
    # strict=False: keys missing from, or extra to, the checkpoint are
    # silently ignored. Presumably the quantized model has quantizer
    # parameters the float checkpoint lacks; TODO confirm.
    # map_location="cpu" lets the load succeed without a GPU. The model is
    # moved to `device` below.
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"),
                          strict=False)
    print("pretrained weight is loaded.")
else:
    raise NotImplementedError
model = model.to(device)


# optimizer -> for further coding (from PROFIT)
def get_optimizer(params, train_weight, train_quant, train_bnbias,
                  train_w_theta, train_a_theta):