def main(): os.chdir(os.path.dirname(__file__)) args = get_arguments() constr_weight = get_constraint(args.weight_bits, 'weight') constr_activation = get_constraint(args.activation_bits, 'activation') if args.dataset == 'cifar10': network = resnet20 dataloader = dataloader_cifar else: if args.network == 'resnet18': network = resnet18 elif args.network == 'resnet50': network = resnet50 else: print('Not Support Network Type: %s' % args.network) return dataloader = dataloader_imagenet train_loader = dataloader(args.data_root, split='train', batch_size=args.batch_size) test_loader = dataloader(args.data_root, split='test', batch_size=args.batch_size) net = network(quan_first_last=args.quan_first_last, constr_activation=constr_activation, preactivation=args.preactivation) model_path = os.path.join(args.model_root, args.model_name + '.pth') name_weights_old = torch.load(model_path) name_weights_new = net.state_dict() name_weights_new.update(name_weights_old) net.load_state_dict(name_weights_new) add_lsqmodule(net, constr_weight) print(net) net = net.cuda() net = nn.DataParallel(net, device_ids=range(cuda.device_count())) quan_activation = isinstance(constr_activation, np.ndarray) postfix = '_w' if not quan_activation else '_a' new_model_name = args.prefix + args.model_name + '_lsq' + postfix cache_root = os.path.join('.', 'cache') train_loger = LogHelper(new_model_name, cache_root, quan_activation, args.resume) optimizer, lr_scheduler = get_optimizer(net=net, optimizer=args.optimizer, lr_base=args.learning_rate, weight_decay=args.weight_decay, lr_scheduler=args.lr_scheduler, total_epoch=args.total_epoch, quan_activation=quan_activation) trainer = Trainer(net=net, train_loader=train_loader, test_loader=test_loader, optimizer=optimizer, lr_scheduler=lr_scheduler, model_name=new_model_name, train_loger=train_loger) trainer(total_epoch=args.total_epoch, save_check_point=True, resume=args.resume)
def main(): args = get_arguments() constr_activation = get_constraint(args.activation_bits, 'activation') net = mixnet_s(quan_first=True, quan_last=True, constr_activation=constr_activation, preactivation=False, bw_act=args.activation_bits) test_loader = dataloader_imagenet(args.data_root, split='test', batch_size=args.batch_size) add_lsqmodule(net, bit_width=args.weight_bits) name_weights_old = torch.load(args.model_path) name_weights_new = net.state_dict() name_weights_new.update(name_weights_old) load_checkpoint(net, name_weights_new) criterion = torch.nn.CrossEntropyLoss() score = get_micronet_score(net, args.weight_bits, args.activation_bits) # Calculate accuracy net = net.cuda() quan_perf_epoch = eval_performance(net, test_loader, criterion) accuracy = quan_perf_epoch[1] print("Accuracy:", accuracy) print("Score:", score)
def main(): os.chdir(os.path.dirname(__file__)) args = get_arguments() constr_weight = get_constraint(args.weight_bits, 'weight') constr_activation = get_constraint(args.activation_bits, 'activation') if args.dataset == 'cifar10': network = resnet20 dataloader = dataloader_cifar10 elif args.dataset == 'cifar100': t_net = WRN40_6() state = torch.load( "/prj/neo_lv/user/ybhalgat/LSQ-KD-0911/cifar100_pretrained/wrn40_6.pth" ) t_net.load_state_dict(state) network = WRN40_4 dataloader = dataloader_cifar100 else: if args.network == 'resnet18': network = resnet18 elif args.network == 'resnet50': network = resnet50 elif args.network == 'efficientnet-b0': t_net = EfficientNet.from_pretrained("efficientnet-b1") network = efficientnet_b0 elif args.network == "mixnet_s": t_net = MixNet(net_type=args.teacher) t_net.load_state_dict( torch.load("../imagenet_pretrained/" + args.teacher + ".pth")) network = mixnet_s else: print('Not Support Network Type: %s' % args.network) return dataloader = dataloader_imagenet train_loader = dataloader(args.data_root, split='train', batch_size=args.batch_size) test_loader = dataloader(args.data_root, split='test', batch_size=args.batch_size) net = network(quan_first=args.quan_first, quan_last=args.quan_last, constr_activation=constr_activation, preactivation=args.preactivation, bw_act=args.activation_bits) # net.load_state_dict(name_weights_new, strict=False) if args.cem: ##### CEM vector for 1.5x_W7A7_CEM prefinetuning 72% # cem_input = [7, 7, 7, 7, 7, 6, 7, 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, # 7, 7, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, # 7, 7, 7, 6, 7, 7, 7, 7, 7, 5, 7, 7, 7, 6, 4, 7, 7, 6, # 6, 6, 7, 7, 7, 7, 5, 7, 7, 7, 6, 4, 7, 7, 5, 5, 4, 7, # 7, 6, 5, 5, 7, 5, 7, 5, 5, 3] ##### CEM vector for 1.5x_W7A7_CEM prefinetuning 70% cem_input = [ 7, 7, 7, 7, 7, 5, 7, 7, 7, 7, 7, 7, 7, 7, 6, 6, 7, 7, 6, 7, 6, 7, 7, 7, 7, 6, 7, 7, 7, 7, 6, 7, 7, 7, 7, 4, 7, 6, 7, 5, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 5, 7, 7, 7, 5, 6, 7, 7, 7, 6, 4, 7, 7, 6, 5, 4, 7, 6, 5, 5, 4, 7, 7, 6, 5, 4, 7, 7, 6, 5, 5, 3 ] strategy_path = "/prj/neo_lv/user/ybhalgat/LSQ-implementation/lsq_quantizer/cem_strategy_relaxed.txt" with open(strategy_path) as fp: strategy = fp.readlines() strategy = [x.strip().split(",") for x in strategy] ##### CEM vector for W6A6_CEM # cem_input = [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, # 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, # 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, # 1, 0, 1, 1, 0, 1, 1, 1, 1, 1] strat = {} act_strat = {} for idx, width in enumerate(cem_input): weight_layer_name = strategy[idx][1] act_layer_name = strategy[idx][0] for name, module in net.named_modules(): if name.startswith('module'): name = name[7:] # remove `module.` if name == weight_layer_name: strat[name] = int(cem_input[idx]) if name == act_layer_name: act_strat[name] = int(cem_input[idx]) add_lsqmodule(net, bit_width=args.weight_bits, strategy=strat) for name, module in net.named_modules(): if name in act_strat: if "efficientnet" in args.network: if "_in_act_quant" in name or "first_act" in name or "_head_act_quant0" in name or "_head_act_quant1" in name: temp_constr_act = get_constraint( act_strat[name], 'weight') #symmetric else: temp_constr_act = get_constraint( act_strat[name], 'activation') #asymmetric elif "mixnet" in args.network: if "last_act" in name or "out_act_quant" in name or "first_act" in name: temp_constr_act = get_constraint( act_strat[name], 'weight') #symmetric else: temp_constr_act = get_constraint( act_strat[name], 'activation') #asymmetric module.constraint = temp_constr_act elif args.manual: if args.network == "wrn40_4": strategy = { "block3.layer.0.conv2": 3, "block3.layer.2.conv1": 3, "block3.layer.3.conv1": 3, "block3.layer.4.conv1": 3, "block3.layer.2.conv2": 3, "block3.layer.1.conv2": 3, "block3.layer.3.conv2": 3, "block3.layer.1.conv1": 3, "block3.layer.5.conv1": 2, "block1.layer.1.conv2": 1 } act_strategy = { "block3.layer.0.relu2": 3, "block3.layer.2.relu1": 3, "block3.layer.3.relu1": 3, "block3.layer.4.relu1": 3, "block3.layer.2.relu2": 3, "block3.layer.1.relu2": 3, "block3.layer.3.relu2": 3, "block3.layer.1.relu1": 3, "block3.layer.5.relu1": 2, "block1.layer.1.relu2": 1 } elif args.network == 'efficientnet-b0': strategy = { "_fc": 3, "_conv_head": 5, "_blocks.15._project_conv": 5, "_blocks.15._expand_conv": 4, "_blocks.14._expand_conv": 4, "_blocks.13._expand_conv": 4, "_blocks.12._expand_conv": 4, "_blocks.13._project_conv": 4, "_blocks.14._project_conv": 4, "_blocks.12._project_conv": 5, "_blocks.9._expand_conv": 4, "_blocks.10._expand_conv": 4 } act_strategy = { "_head_act_quant1": 3, "_head_act_quant0": 5, "_blocks.15._pre_proj_activation": 5, "_blocks.15._in_act_quant": 4, "_blocks.14._in_act_quant": 4, "_blocks.13._in_act_quant": 4, "_blocks.12._in_act_quant": 4, "_blocks.13._pre_proj_activation": 4, "_blocks.14._pre_proj_activation": 4, "_blocks.12._pre_proj_activation": 5, "_blocks.9._in_act_quant": 4, "_blocks.10._in_act_quant": 4 } #strategy = {"_fc": 3, # "_conv_head": 4, # "_blocks.15._project_conv": 4, # "_blocks.14._project_conv": 4, # "_blocks.13._project_conv": 3, # "_blocks.13._expand_conv": 4, # "_blocks.12._project_conv": 4, # "_blocks.12._expand_conv": 5, # "_blocks.14._expand_conv": 4, # "_blocks.15._expand_conv": 4, # "_blocks.9._project_conv": 4} # #"_blocks.10._project_conv": 4, # #"_blocks.9._expand_conv": 4, # #"_blocks.10._expand_conv": 4, # #"_blocks.7._expand_conv": 4, # #"_blocks.11._expand_conv": 4} #act_strategy = {"_head_act_quant1": 3, # "_head_act_quant0": 4, # "_blocks.15._pre_proj_activation": 4, # "_blocks.14._pre_proj_activation": 4, # "_blocks.13._pre_proj_activation": 3, # "_blocks.13._in_act_quant": 4, # "_blocks.12._pre_proj_activation": 4, # "_blocks.12._in_act_quant": 5, # "_blocks.14._in_act_quant": 4, # "_blocks.15._in_act_quant": 4, # "_blocks.9._pre_proj_activation": 4} # #"_blocks.10._pre_proj_activation": 4, # #"_blocks.9._in_act_quant": 4, # #"_blocks.10._in_act_quant": 4, # #"_blocks.7._in_act_quant": 4, # #"_blocks.11._in_act_quant": 4} add_lsqmodule(net, bit_width=args.weight_bits, strategy=strategy) for name, module in net.named_modules(): if name in act_strategy: if "_in_act_quant" in name or "first_act" in name or "_head_act_quant0" in name or "_head_act_quant1" in name: temp_constr_act = get_constraint(act_strategy[name], 'weight') #symmetric else: temp_constr_act = get_constraint(act_strategy[name], 'activation') #asymmetric module.constraint = temp_constr_act elif args.haq: if args.network == 'resnet50': strategy = [ 6, 6, 5, 5, 5, 5, 4, 5, 5, 4, 5, 5, 5, 5, 5, 5, 3, 5, 4, 3, 5, 4, 3, 4, 4, 4, 2, 5, 4, 3, 3, 5, 3, 2, 5, 3, 2, 4, 3, 2, 5, 3, 2, 5, 3, 4, 2, 5, 2, 3, 4, 2, 3, 4 ] elif args.network == 'efficientnet-b0': strategy = [ 7, 8, 8, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 7, 6, 7, 6, 7, 6, 5, 6, 5, 6, 4, 5, 6, 5, 6, 4, 4, 5, 4, 5, 2, 3, 4, 3, 4, 2, 3, 4, 4, 7, 5, 2, 4, 2, 5, 5, 2, 4, 2, 5, 5, 2, 4, 2, 5, 5, 2, 4, 3, 3, 2 ] add_lsqmodule(net, strategy=strategy) else: add_lsqmodule(net, bit_width=args.weight_bits) model_path = os.path.join(args.model_root, args.model_name + '.pth.tar') if not os.path.exists(model_path): model_path = model_path[:-4] name_weights_old = torch.load(model_path) name_weights_new = net.state_dict() name_weights_new.update(name_weights_old) load_checkpoint(net, name_weights_new, strict=False) print(net) net = net.cuda() net = nn.DataParallel(net) t_net = t_net.cuda() t_net = nn.DataParallel(t_net) if args.pruned: start_LSQ(net) quan_activation = isinstance(constr_activation, np.ndarray) postfix = '_w' if not quan_activation else '_a' new_model_name = args.prefix + args.model_name + '_lsq' + postfix cache_root = os.path.join('.', 'cache') train_loger = LogHelper(new_model_name, cache_root, quan_activation, args.resume) optimizer, lr_scheduler, optimizer_t, lr_scheduler_t = get_optimizer( s_net=net, t_net=t_net, optimizer=args.optimizer, lr_base=args.learning_rate, weight_decay=args.weight_decay, lr_scheduler=args.lr_scheduler, total_epoch=args.total_epoch, quan_activation=quan_activation, act_lr_factor=args.act_lr_factor, weight_lr_factor=args.weight_lr_factor) trainer = Trainer(net=net, t_net=t_net, train_loader=train_loader, test_loader=test_loader, optimizer=optimizer, optimizer_t=optimizer_t, lr_scheduler=lr_scheduler, lr_scheduler_t=lr_scheduler_t, model_name=new_model_name, train_loger=train_loger, pruned=args.pruned) trainer(total_epoch=args.total_epoch, save_check_point=True, resume=args.resume)
def cem_eval( cem_input, data_root="/nvme/users/tijmen/imagenet/", batch_size=350, strategy_path="/prj/neo_lv/user/ybhalgat/LSQ-implementation/lsq_quantizer/cem_strategy_relaxed.txt", model_path="/prj/neo_lv/scratch/jinwonl/finetune/lsq_quantizer/models/1.5x_W8A8_CEM_re1/efficientnet-b0.pth", activation_bits=7, weight_bits=7): constr_activation = get_constraint(activation_bits, 'activation') network = efficientnet_b0 dataloader = dataloader_imagenet test_loader = dataloader(data_root, split='test', batch_size=batch_size) with open(strategy_path) as fp: strategy = fp.readlines() strategy = [x.strip().split(",") for x in strategy] net = network(quan_first=True, quan_last=True, constr_activation=constr_activation, preactivation=False, bw_act=activation_bits) strat = {} act_strat = {} for idx, flag in enumerate(cem_input): weight_layer_name = strategy[idx][1] act_layer_name = strategy[idx][0] for name, module in net.named_modules(): if name.startswith('module'): name = name[7:] # remove `module.` if name == weight_layer_name: strat[name] = int(cem_input[idx]) if name == act_layer_name: act_strat[name] = int(cem_input[idx]) add_lsqmodule(net, bit_width=weight_bits, strategy=strat) for name, module in net.named_modules(): if name in act_strat: if "_in_act_quant" in name or "first_act" in name or "_head_act_quant0" in name or "_head_act_quant1" in name: temp_constr_act = get_constraint(act_strat[name], 'weight') #symmetric else: temp_constr_act = get_constraint(act_strat[name], 'activation') #asymmetric module.constraint = temp_constr_act net.load_state_dict(torch.load(model_path)) criterion = nn.CrossEntropyLoss() # Calculate score flops_model = add_flops_counting_methods(net) flops_model.eval().start_flops_count() input_res = (3, 224, 224) batch = torch.ones(()).new_empty( (1, *input_res), dtype=next(flops_model.parameters()).dtype, device=next(flops_model.parameters()).device) _ = flops_model(batch) flops_count = flops_model.compute_average_flops_cost( bw_weight=activation_bits, bw_act=weight_bits, strategy=(strat, strat)) params_count = get_model_parameters_number(flops_model, bw_weight=weight_bits, w_strategy=strat) flops_model.stop_flops_count() score = params_count / 6900000.0 + flops_count / 1170000000.0 # Calculate accuracy net = net.cuda() net = nn.DataParallel(net, device_ids=range(cuda.device_count())) quan_perf_epoch = eval_performance(net, test_loader, criterion) accuracy = quan_perf_epoch[1] return accuracy, score
def main(): os.chdir(os.path.dirname(__file__)) args = get_arguments() constr_weight = get_constraint(args.weight_bits, 'weight') constr_activation = get_constraint(args.activation_bits, 'activation') if args.dataset == 'cifar10': network = resnet20 dataloader = dataloader_cifar10 elif args.dataset == 'cifar100': t_net = ResNet(depth=56, num_classes=100) state = torch.load("/prj/neo_lv/user/ybhalgat/LSQ-KD/cifar100_pretrained/resnet56.pth.tar") t_net.load_state_dict(state) network = resnet20 dataloader = dataloader_cifar100 else: if args.network == 'resnet18': network = resnet18 elif args.network == 'resnet50': network = resnet50 elif args.network == 'efficientnet-b0': t_net = EfficientNet.from_pretrained("efficientnet-b3") network = efficientnet_b0 else: print('Not Support Network Type: %s' % args.network) return dataloader = dataloader_imagenet train_loader = dataloader(args.data_root, split='train', batch_size=args.batch_size) test_loader = dataloader(args.data_root, split='test', batch_size=args.batch_size) net = network(quan_first=args.quan_first, quan_last=args.quan_last, constr_activation=constr_activation, preactivation=args.preactivation, bw_act=args.activation_bits) model_path = os.path.join(args.model_root, args.model_name + '.pth.tar') if not os.path.exists(model_path): model_path = model_path[:-4] name_weights_old = torch.load(model_path) name_weights_new = net.state_dict() name_weights_new.update(name_weights_old) load_checkpoint(net, name_weights_new) # net.load_state_dict(name_weights_new, strict=False) if not args.haq: add_lsqmodule(net, bit_width=args.weight_bits) else: if args.network == 'resnet50': strategy = [6, 6, 5, 5, 5, 5, 4, 5, 5, 4, 5, 5, 5, 5, 5, 5, 3, 5, 4, 3, 5, 4, 3, 4, 4, 4, 2, 5, 4, 3, 3, 5, 3, 2, 5, 3, 2, 4, 3, 2, 5, 3, 2, 5, 3, 4, 2, 5, 2, 3, 4, 2, 3, 4] elif args.network == 'efficientnet-b0': strategy = [7, 8, 8, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 7, 6, 7, 6, 7, 6, 5, 6, 5, 6, 4, 5, 6, 5, 6, 4, 4, 5, 4, 5, 2, 3, 4, 3, 4, 2, 3, 4, 4, 7, 5, 2, 4, 2, 5, 5, 2, 4, 2, 5, 5, 2, 4, 2, 5, 5, 2, 4, 3, 3, 2] add_lsqmodule(net, strategy=strategy) print(net) net = net.cuda() net = nn.DataParallel(net, device_ids=range(cuda.device_count())) t_net = t_net.cuda() t_net = nn.DataParallel(t_net, device_ids=range(cuda.device_count())) quan_activation = isinstance(constr_activation, np.ndarray) postfix = '_w' if not quan_activation else '_a' new_model_name = args.prefix + args.model_name + '_lsq' + postfix cache_root = os.path.join('.', 'cache') train_loger = LogHelper(new_model_name, cache_root, quan_activation, args.resume) optimizer, lr_scheduler, optimizer_t = get_optimizer(s_net=net, t_net=t_net, optimizer=args.optimizer, lr_base=args.learning_rate, weight_decay=args.weight_decay, lr_scheduler=args.lr_scheduler, total_epoch=args.total_epoch, quan_activation=quan_activation, act_lr_factor=args.act_lr_factor, weight_lr_factor=args.weight_lr_factor) trainer = Trainer(net=net, t_net=t_net, train_loader=train_loader, test_loader=test_loader, optimizer=optimizer, optimizer_t=optimizer_t, lr_scheduler=lr_scheduler, model_name=new_model_name, train_loger=train_loger) trainer(total_epoch=args.total_epoch, save_check_point=True, resume=args.resume)
def main(): args = get_arguments() constr_activation = get_constraint(args.activation_bits, 'activation') net = WRN40_4(quan_first=False, quan_last=False, constr_activation=constr_activation, preactivation=False, bw_act=args.activation_bits) test_loader = dataloader_cifar100(args.data_root, split='test', batch_size=args.batch_size) add_lsqmodule(net, bit_width=args.weight_bits) if args.cem: strategy = { "block3.layer.0.conv2": 3, "block3.layer.2.conv1": 3, "block3.layer.3.conv1": 3, "block3.layer.4.conv1": 3, "block3.layer.2.conv2": 3, "block3.layer.1.conv2": 3, "block3.layer.3.conv2": 3, "block3.layer.1.conv1": 3, "block3.layer.5.conv1": 2, "block1.layer.1.conv2": 1 } act_strategy = { "block3.layer.0.relu2": 3, "block3.layer.2.relu1": 3, "block3.layer.3.relu1": 3, "block3.layer.4.relu1": 3, "block3.layer.2.relu2": 3, "block3.layer.1.relu2": 3, "block3.layer.3.relu2": 3, "block3.layer.1.relu1": 3, "block3.layer.5.relu1": 2, "block1.layer.1.relu2": 1 } add_lsqmodule(net, bit_width=args.weight_bits, strategy=strategy) for name, module in net.named_modules(): if name in act_strategy: if "_in_act_quant" in name or "first_act" in name or "_head_act_quant0" in name or "_head_act_quant1" in name: temp_constr_act = get_constraint(act_strategy[name], 'weight') #symmetric else: temp_constr_act = get_constraint(act_strategy[name], 'activation') #asymmetric module.constraint = temp_constr_act name_weights_old = torch.load(args.model_path) name_weights_new = net.state_dict() name_weights_new.update(name_weights_old) load_checkpoint(net, name_weights_new) criterion = torch.nn.CrossEntropyLoss() score = get_micronet_score(net, args.weight_bits, args.activation_bits, weight_strategy=strategy, activation_strategy=act_strategy, input_res=(3, 32, 32), baseline_params=36500000, baseline_MAC=10490000000) # Calculate accuracy net = net.cuda() quan_perf_epoch = eval_performance(net, test_loader, criterion) accuracy = quan_perf_epoch[1] print("Accuracy:", accuracy) print("Score:", score)
def main(): args = get_arguments() constr_activation = get_constraint(args.activation_bits, 'activation') net = efficientnet_b0(quan_first=True, quan_last=True, constr_activation=constr_activation, preactivation=False, bw_act=args.activation_bits) test_loader = dataloader_imagenet(args.data_root, split='test', batch_size=args.batch_size) add_lsqmodule(net, bit_width=args.weight_bits) if args.cem: ##### CEM vector for 1.5x_W7A7_CEM cem_input = [ 7, 7, 7, 7, 7, 5, 7, 7, 7, 7, 7, 7, 7, 7, 6, 6, 7, 7, 6, 7, 6, 7, 7, 7, 7, 6, 7, 7, 7, 7, 6, 7, 7, 7, 7, 4, 7, 6, 7, 5, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 5, 7, 7, 7, 5, 6, 7, 7, 7, 6, 4, 7, 7, 6, 5, 4, 7, 6, 5, 5, 4, 7, 7, 6, 5, 4, 7, 7, 6, 5, 5, 3 ] strategy_path = "lsq_quantizer/cem_strategy_relaxed.txt" with open(strategy_path) as fp: strategy = fp.readlines() strategy = [x.strip().split(",") for x in strategy] strat = {} act_strat = {} for idx, width in enumerate(cem_input): weight_layer_name = strategy[idx][1] act_layer_name = strategy[idx][0] for name, module in net.named_modules(): if name.startswith('module'): name = name[7:] # remove `module.` if name == weight_layer_name: strat[name] = int(cem_input[idx]) if name == act_layer_name: act_strat[name] = int(cem_input[idx]) add_lsqmodule(net, bit_width=args.weight_bits, strategy=strat) for name, module in net.named_modules(): if name in act_strat: if "_in_act_quant" in name or "first_act" in name or "_head_act_quant0" in name or "_head_act_quant1" in name: temp_constr_act = get_constraint(act_strat[name], 'weight') #symmetric else: temp_constr_act = get_constraint(act_strat[name], 'activation') #asymmetric module.constraint = temp_constr_act name_weights_old = torch.load(args.model_path) name_weights_new = net.state_dict() name_weights_new.update(name_weights_old) load_checkpoint(net, name_weights_new) score = get_micronet_score(net, args.weight_bits, args.activation_bits, weight_strategy=strat, activation_strategy=act_strat) criterion = torch.nn.CrossEntropyLoss() # Calculate accuracy net = net.cuda() quan_perf_epoch = eval_performance(net, test_loader, criterion) accuracy = quan_perf_epoch[1] print("Accuracy:", accuracy) print("Score:", score)