if args.defend_algo is not None: model_path = os.path.join(args.loaddir, args.defend_algo + "_densepretrain.pth") else: model_path = os.path.join(args.loaddir, '_densepretrain.pth') else: sparse_model_path = os.path.join( args.loaddir, "sparse_" + descripter + args.model_name) quant_model_path = os.path.join( args.loaddir, "quant_" + descripter + args.model_name) if args.abc_special: if not args.abc_initialize: model_base.load_state_dict(torch.load(model_path)) ranks_up = model_base.get_ranks() model_abc = CLabcv2(ranks_up) model_abc.set_weights(model_base.raw_weights(ranks_up)) model_abc.load_state_dict(model_base.state_dict(), strict=False) model = model_abc else: ranks_up = model_base.get_ranks() sparse_model = CLabcv2(ranks_up) quant_model = CLabcv2(ranks_up) sparse_model.load_state_dict(torch.load(sparse_model_path)) quant_model.load_state_dict(torch.load(quant_model_path)) modelu, modelz = CLabcv2(ranks_up), CLabcv2(ranks_up) elif args.lr_special: if not args.lr_initialize: model_base.load_state_dict(torch.load(model_path)) weights_list, ranks_up = model_base.svd_global_lowrank_weights(
pass else: if args.model_name is None: if args.defend_algo is not None: model_path = os.path.join(args.loaddir, args.defend_algo + "_densepretrain.pth") else: model_path = os.path.join(args.loaddir, '_densepretrain.pth') else: model_path = os.path.join(args.loaddir, args.model_name) if args.abc_special: if not args.abc_initialize: model_base.load_state_dict(torch.load(model_path)) ranks_up = model_base.get_ranks() model_abc = CLabcv2(ranks_up) model_abc.set_weights(model_base.raw_weights(ranks_up)) model_abc.load_state_dict(model_base.state_dict(), strict=False) model = model_abc else: ranks_up = model_base.get_ranks() model = CLabcv2(ranks_up) model.load_state_dict(torch.load(model_path)) modelu, modelz = CLabcv2(ranks_up), CLabcv2(ranks_up) elif args.lr_special: if not args.lr_initialize: model_base.load_state_dict(torch.load(model_path)) weights_list, ranks_up = model_base.svd_global_lowrank_weights( k=args.prune_ratio) print(ranks_up)
parser.add_argument("-e", "--exp_logger", default=None,
                    help="exp results stored to")
args = parser.parse_args()

# Restrict CUDA to the requested device(s). os.environ only accepts str values,
# and leaving --gpu unset must not crash the script.
if args.gpu is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

# model:
model_base = CLdense()
# Names of the weight tensors to operate on for the selected parameterization:
# "weight" by default, "weightA/B/C" with --abc_special, "weightA/B" with
# --lr_special (lr_special takes precedence when both flags are set).
weight_name = (["weight"] if not args.abc_special
               else ["weightA", "weightB", "weightC"])
weight_name = ["weightA", "weightB"] if args.lr_special else weight_name

model_path = os.path.join(args.loaddir, args.prefix_name + args.model_name)
if args.abc_special:
    # Layer ranks for the ABC model are taken from the dense base model.
    ranks_up = model_base.get_ranks()
    model = CLabcv2(ranks_up)
elif args.lr_special:
    # Per-layer ranks live next to the checkpoint as "<model_name stem>.npy"
    # (note: no prefix_name, and the file is read with pickle despite the
    # .npy extension). splitext handles extensions of any length.
    # NOTE(review): pickle.load can execute arbitrary code - only load rank
    # files from a trusted directory.
    ranks_path = os.path.join(
        args.loaddir, os.path.splitext(args.model_name)[0] + ".npy")
    with open(ranks_path, "rb") as filestream:
        ranks_up = pickle.load(filestream)
    model = CLlr(ranks_up)
else:
    model = model_base

# Load onto CPU first so the checkpoint does not depend on the device it was
# saved from; .cuda() then moves the whole model to the visible GPU.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.cuda()

# Select the pruning routine requested by --prune_algo.
if args.prune_algo == "l0proj":
    prune_algo = l0proj
elif args.prune_algo is None:
    prune_algo = None