Example no. 1
0
        model_abc = CLabcv2(ranks_up)
        model_abc.set_weights(model_base.raw_weights(ranks_up))
        model_abc.load_state_dict(model_base.state_dict(), strict=False)
        model = model_abc
    else:
        ranks_up = model_base.get_ranks()
        model = CLabcv2(ranks_up)
        model.load_state_dict(torch.load(model_path))

    modelu, modelz = CLabcv2(ranks_up), CLabcv2(ranks_up)
elif args.lr_special:
    if not args.lr_initialize:
        model_base.load_state_dict(torch.load(model_path))
        weights_list, ranks_up = model_base.svd_global_lowrank_weights(k=args.prune_ratio)
        print(ranks_up)
        model = CLlr(ranks_up)
        # print([weight.shape for weight in weights_list])
        # print(weights_list)
        model.set_weights(weights_list)
        model.load_state_dict(model_base.state_dict(), strict=False)
        # descripter = "{}_proj_{}_nnz_{}_quant_{}_bits_{}_".format(defend_name, prune_name, args.prune_ratio, quantize_name, args.quantize_bits) 
        with open(os.path.join(args.savedir, args.model_name[0:-4] + ".npy"), "rb") as filestream:
            pickle.dump(ranks_up, filestream, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        with open(os.path.join(args.loaddir, args.model_name[0:-4] + ".npy"), "rb") as filestream:
            ranks_up = pickle.load(filestream)
        model = CLlr(ranks_up)
        model.load_state_dict(torch.load(model_path))
    modelu, modelz = CLlr(ranks_up), CLlr(ranks_up)

else:
        model = model_abc
    else:
        ranks_up = model_base.get_ranks()
        sparse_model = CLabcv2(ranks_up)
        quant_model = CLabcv2(ranks_up)
        sparse_model.load_state_dict(torch.load(sparse_model_path))
        quant_model.load_state_dict(torch.load(quant_model_path))

    modelu, modelz = CLabcv2(ranks_up), CLabcv2(ranks_up)
elif args.lr_special:
    if not args.lr_initialize:
        # Derive the low-rank model from the dense base: load the base
        # checkpoint, factorize with a global SVD at pruning ratio k, and
        # persist the per-layer ranks for later reconstruction.
        model_base.load_state_dict(torch.load(model_path))
        weights_list, ranks_up = model_base.svd_global_lowrank_weights(
            k=args.prune_ratio)
        print(ranks_up)
        model = CLlr(ranks_up)
        model.set_weights(weights_list)
        model.load_state_dict(model_base.state_dict(), strict=False)
        # BUG FIX: pickle.dump() writes to the stream, so the mode must be
        # "wb", not "rb" ("rb" raises io.UnsupportedOperation on dump).
        with open(os.path.join(args.savedir, args.model_name[0:-4] + ".npy"),
                  "wb") as filestream:
            pickle.dump(ranks_up, filestream, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        with open(os.path.join(args.loaddir, args.model_name[0:-4] + ".npy"),
                  "rb") as filestream:
            ranks_up = pickle.load(filestream)
        sparse_model = CLlr(ranks_up)
        quant_model = CLlr(ranks_up)
        sparse_model.load_state_dict(torch.load(sparse_model_path))
Example no. 3
0
        model_abc.set_weights(model_base.raw_weights(ranks_up))
        model_abc.load_state_dict(model_base.state_dict(), strict=False)
        model = model_abc
    else:
        ranks_up = model_base.get_ranks()
        model = CLabcv2(ranks_up)
        model.load_state_dict(torch.load(model_path))

    modelu, modelz = CLabcv2(ranks_up), CLabcv2(ranks_up)
elif args.lr_special:
    if not args.lr_initialize:
        # Build the low-rank model from the dense base: load the base
        # checkpoint, factorize its weights with a global SVD at pruning
        # ratio k, then persist the resulting per-layer ranks so the model
        # can be reconstructed later.
        model_base.load_state_dict(torch.load(model_path))
        weights_list, ranks_up = model_base.svd_global_lowrank_weights(
            k=args.prune_ratio)
        print(ranks_up)
        model = CLlr(ranks_up)
        # print([weight.shape for weight in weights_list])
        # print(weights_list)
        model.set_weights(weights_list)
        model.load_state_dict(model_base.state_dict(), strict=False)
        # descripter = "{}_proj_{}_nnz_{}_quant_{}_bits_{}_".format(defend_name, prune_name, args.prune_ratio, quantize_name, args.quantize_bits)
        # NOTE(review): despite the ".npy"-style name, the file holds a
        # pickle, not a NumPy array; "descripter" presumably encodes the
        # prune/quant settings per the commented format string above.
        with open(os.path.join(args.savedir, "param_" + descripter + "lr.npy"),
                  "wb") as filestream:
            pickle.dump(ranks_up, filestream, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        # Reload the previously pickled ranks (keyed by model_name here,
        # not by descripter — see the commented alternative), rebuild the
        # low-rank architecture, then load its trained weights.
        # with open(os.path.join(args.loaddir, "param_" + descripter + "lr.npy"), "rb") as filestream:
        with open(os.path.join(args.loaddir, args.model_name[0:-4] + ".npy"),
                  "rb") as filestream:
            ranks_up = pickle.load(filestream)
        model = CLlr(ranks_up)
        model.load_state_dict(torch.load(model_path))
Example no. 4
0
    # Restrict CUDA to the user-selected GPU(s) before the model is moved
    # onto the device.
    os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu

    # model:
    # Select which parameter tensors carry the trainable weights for the
    # chosen variant: dense -> "weight"; ABC factorization -> A/B/C;
    # low-rank -> A/B (lr_special takes precedence over abc_special).
    model_base = CLdense()
    weight_name = ["weight"] if not args.abc_special else ["weightA", "weightB", "weightC"]
    weight_name = ["weightA", "weightB"] if args.lr_special else weight_name

    model_path = os.path.join(args.loaddir, args.prefix_name + args.model_name)

    # Instantiate the architecture matching the saved checkpoint, then load it.
    if args.abc_special:
        ranks_up = model_base.get_ranks()
        model = CLabcv2(ranks_up)
    elif args.lr_special:        
        # The per-layer ranks were pickled next to the checkpoint (the
        # ".npy" file holds a pickle, not a NumPy array) — restore them to
        # rebuild the low-rank model.
        with open(os.path.join(args.loaddir, args.model_name[0:-4] + ".npy"), "rb") as filestream:
            ranks_up = pickle.load(filestream)
        model = CLlr(ranks_up)
    else:
        model = model_base
    model.load_state_dict(torch.load(model_path))

    model.cuda()

    #
    # Map the --prune_algo flag to the pruning/projection routine.
    if args.prune_algo == "l0proj":
        prune_algo = l0proj
    elif args.prune_algo is None:
        prune_algo = None
    elif args.prune_algo == "baseline":
        # "baseline" reuses the same l0 projection as "l0proj".
        prune_algo = l0proj
    elif args.prune_algo == "model_size_prune":
        prune_algo = pt.prune_admm_ms