def get_bitops_fullp():
    """Return ``get_bitops(..., 'list')`` for a freshly built MobileNetV2.

    Builds a new ``mobilenet_v2(QuantOps)`` on ``device``. It then initializes
    the weight quantizers and the activation quantizers with
    ``QuantOps.initialize(model_, train_loader, 32, ...)``; presumably 32 bits
    means full precision (TODO confirm). Finally it switches the model to eval
    mode and queries ``get_bitops`` in ``'list'`` mode.

    No forward pass is run here, so no dummy input tensor is created. A
    hard-coded ``.cuda()`` tensor would also break CPU-only runs.

    Returns:
        Whatever ``get_bitops(model_, 'list')`` returns. Presumably this is a
        per-layer BitOPs list; verify against ``get_bitops``.
    """
    model_ = mobilenet_v2(QuantOps).to(device)

    # NOTE(review): quantizers are initialized in train mode here, whereas
    # get_bitops_total initializes in eval mode -- confirm which is intended.
    model_.train()
    QuantOps.initialize(model_, train_loader, 32, weight=True)
    QuantOps.initialize(model_, train_loader, 32, act=True)
    model_.eval()
    return get_bitops(model_, 'list')
def get_bitops_total():
    """Return the BitOPs reported by one forward pass of a fresh MobileNetV2.

    Builds a new ``mobilenet_v2(QuantOps)`` on ``device`` and switches it to
    eval mode. It then initializes weight and activation quantizers with
    ``QuantOps.initialize(model_, train_loader, 32, ...)`` and runs a single
    forward pass on a random dummy image. The model's forward returns a pair
    whose second element is taken as the BitOPs count.

    Returns:
        The second element of the model's forward output. Presumably this is
        a scalar tensor or number; TODO confirm.
    """
    model_ = mobilenet_v2(QuantOps).to(device)

    # CIFAR datasets use 32x32 inputs; all other datasets use 224x224.
    side = 32 if args.dataset in ["cifar100", "cifar10"] else 224
    # Create on CPU, then move to the model's device. A hard-coded .cuda()
    # mismatches when `device` is the CPU or a non-default GPU.
    dummy = torch.randn([1, 3, side, side]).to(device)

    model_.eval()
    QuantOps.initialize(model_, train_loader, 32, weight=True)
    QuantOps.initialize(model_, train_loader, 32, act=True)

    # Only the BitOPs count is needed, so don't build an autograd graph.
    with torch.no_grad():
        _, bitops = model_(dummy)
    return bitops
bitops_first = bitops_list[0] bitops_last = bitops_list[-1] bitops_target = ((bitops_total - bitops_first - bitops_last) * (args.w_target_bit/32.) * (args.a_target_bit/32.) +\ (bitops_first * (args.w_target_bit/32.)) +\ (bitops_last * (args.a_target_bit/32.))) logging.info(f'bitops_total : {int(bitops_total):d}') logging.info(f'bitops_target: {int(bitops_target):d}') #logging.info(f'bitops_wrong : {int(bitops_total * (args.w_target_bit/32.) * (args.a_target_bit/32.)):d}') if type(bitops_target) != float: bitops_target = float(bitops_target) # Make model if args.model == "mobilenetv2": model = mobilenet_v2(QuantOps) if not os.path.isfile("./checkpoint/mobilenet_v2-b0353104.pth"): os.system( "wget -P ./checkpoint https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" ) model.load_state_dict(torch.load("./checkpoint/mobilenet_v2-b0353104.pth"), False) print("pretrained weight is loaded.") else: raise NotImplementedError model = model.to(device) # optimizer -> for further coding (from PROFIT) def get_optimizer(params, train_weight, train_quant, train_bnbias, train_w_theta, train_a_theta):