예제 #1
0
def get_model(model, dataset, classify=True):
    """Build the network selected by ``model``.

    Args:
        model: Name of the architecture to construct. Recognised names are
            ``vgg11/13/16/19``, ``cyvgg11/13/16/19``, ``resnet20/32/44/56``
            and ``cyresnet20/32/44/56``. A non-string value (for example an
            already constructed network) is returned unchanged.
        dataset: Forwarded as the ``dataset`` keyword to the constructor.
        classify: Forwarded as the ``classify`` keyword to the VGG and CyVGG
            constructors only; the ResNet-family constructors are called
            without it.

    Returns:
        The object returned by the selected constructor, or ``model`` itself
        when it is not a string.

    Raises:
        ValueError: If ``model`` is a string that names no known
            architecture. Previously such a name was silently returned
            as-is, which only surfaced later as an unrelated error.
    """
    if not isinstance(model, str):
        return model

    # The lambdas keep the lookup lazy: only the module of the requested
    # family is touched, exactly as with the original chain of tests.
    builders = {
        # VGG models
        'vgg11': lambda: vgg.vgg11_bn(dataset=dataset, classify=classify),
        'vgg13': lambda: vgg.vgg13_bn(dataset=dataset, classify=classify),
        'vgg16': lambda: vgg.vgg16_bn(dataset=dataset, classify=classify),
        'vgg19': lambda: vgg.vgg19_bn(dataset=dataset, classify=classify),
        # CyVGG models
        'cyvgg11': lambda: cyvgg.cyvgg11_bn(dataset=dataset, classify=classify),
        'cyvgg13': lambda: cyvgg.cyvgg13_bn(dataset=dataset, classify=classify),
        'cyvgg16': lambda: cyvgg.cyvgg16_bn(dataset=dataset, classify=classify),
        'cyvgg19': lambda: cyvgg.cyvgg19_bn(dataset=dataset, classify=classify),
        # Resnet models
        'resnet20': lambda: resnet.resnet20(dataset=dataset),
        'resnet32': lambda: resnet.resnet32(dataset=dataset),
        'resnet44': lambda: resnet.resnet44(dataset=dataset),
        'resnet56': lambda: resnet.resnet56(dataset=dataset),
        # CyResnet models
        'cyresnet20': lambda: cyresnet.cyresnet20(dataset=dataset),
        'cyresnet32': lambda: cyresnet.cyresnet32(dataset=dataset),
        'cyresnet44': lambda: cyresnet.cyresnet44(dataset=dataset),
        'cyresnet56': lambda: cyresnet.cyresnet56(dataset=dataset),
    }

    if model not in builders:
        raise ValueError(
            "unknown model name {!r}; expected one of: {}".format(
                model, ", ".join(sorted(builders))))

    return builders[model]()
예제 #2
0
from tools.model_trainer import ModelTrainer
from tools.common_tools import *
from config.config import cfg
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Script entry point: builds the CIFAR-10 validation loader, restores a
# checkpoint into a VGG16 model, then passes the model through replace_conv
# with GhostModule.
if __name__ == '__main__':

    test_dir = os.path.join(BASE_DIR, "..", "data", "cifar10_test")
    # NOTE: path_checkpoint is assigned twice; the second assignment wins, so
    # only the vgg-cifar checkpoint is ever loaded. The first line is kept as
    # a manual toggle for the resnet run.
    path_checkpoint = os.path.join(BASE_DIR, "..", "..", "results/03-02_20-22/checkpoint_best.pkl")  # resnet-image
    path_checkpoint = os.path.join(BASE_DIR, "..", "..", "results/03-06_16-54/checkpoint_best.pkl")  # vgg-cifar
    valid_data = CifarDataset(data_dir=test_dir, transform=cfg.transforms_valid)
    valid_loader = DataLoader(dataset=valid_data, batch_size=cfg.valid_bs, num_workers=cfg.workers)
    # Relative path: resolved against the current working directory, unlike
    # the BASE_DIR-anchored paths above.
    log_dir = "../../results"

    # NOTE: the resnet56 instance is built and immediately replaced; VGG16 is
    # the model actually used below (must match the checkpoint chosen above).
    model = resnet56()
    model = VGG("VGG16")


    # map_location="cpu" loads the stored tensors onto the CPU.
    # encoding='iso-8859-1' is forwarded to the unpickler -- presumably needed
    # because the checkpoint was written under Python 2; TODO confirm.
    check_p = torch.load(path_checkpoint, map_location="cpu", encoding='iso-8859-1')
    pretrain_dict = check_p["model_state_dict"]
    print("best acc: {} in epoch:{}".format(check_p["best_acc"], check_p["epoch"]))
    # state_dict_to_cpu is a project helper not visible here; presumably it
    # normalises the state dict so it can be loaded into a CPU model.
    state_dict_cpu = state_dict_to_cpu(pretrain_dict)
    model.load_state_dict(state_dict_cpu)

    # Convert the loaded network with replace_conv/GhostModule. With the
    # settings above this is the VGG16 model (arc="vgg16"), not a resnet.
    # The meaning of pretrain=False is defined in replace_conv (not visible).
    model = replace_conv(model, GhostModule, arc="vgg16", pretrain=False)
    # Alternative call, disabled:
    # model = replace_conv(model, GhostModule, pretrain=True)

    # Manual switch consulted by the multi-GPU check that follows.
    Isparallel = False
    if Isparallel and torch.cuda.is_available():
예제 #3
0
    # Datasets for training and validation, each with its own transform
    # pipeline taken from the project config.
    train_data = CifarDataset(data_dir=train_dir,
                              transform=cfg.transforms_train)
    valid_data = CifarDataset(data_dir=test_dir,
                              transform=cfg.transforms_valid)

    # Build the DataLoaders (only the training loader shuffles).
    train_loader = DataLoader(dataset=train_data,
                              batch_size=cfg.train_bs,
                              shuffle=True,
                              num_workers=cfg.workers)
    valid_loader = DataLoader(dataset=valid_data,
                              batch_size=cfg.valid_bs,
                              num_workers=cfg.workers)

    # ------------------------------------ step 2/5 : define the networks ------------------------------------
    # Teacher: a resnet56 whose weights are restored from path_checkpoint.
    teacher = resnet56()
    # check_p = torch.load(path_checkpoint, map_location="cpu")
    # encoding='iso-8859-1' is forwarded to the unpickler -- presumably needed
    # for a checkpoint written under Python 2; TODO confirm.
    check_p = torch.load(path_checkpoint,
                         map_location="cpu",
                         encoding='iso-8859-1')
    pretrain_dict = check_p["model_state_dict"]
    # state_dict_to_cpu is a project helper not visible here.
    state_dict_cpu = state_dict_to_cpu(pretrain_dict)
    teacher.load_state_dict(state_dict_cpu)

    # Student: a freshly constructed resnet56; no weights are loaded here.
    student = resnet56()

    # Accumulators for captured outputs: t_map is filled by t_fmap_hook below;
    # s_map is presumably its student-side counterpart (not visible here).
    t_map, s_map = [], []

    def t_fmap_hook(m, i, o):
        """Append ``o`` to the enclosing ``t_map`` list; ``m`` and ``i`` are unused.

        NOTE(review): the (m, i, o) signature looks like a PyTorch forward
        hook (module, input, output), presumably registered on the teacher --
        the registration is not visible here, confirm. ``o`` is stored as-is,
        without detaching or copying.
        """
        t_map.append(o)
예제 #4
0
"""
import os
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(BASE_DIR, '..'))
from torchstat import stat
from models.ghost_net import GhostModule
from models.resnet import resnet56
from models.vgg import VGG
from tools.common_tools import replace_conv

# Script entry point: report torchstat statistics for resnet56, for the model
# returned by replace_conv(..., GhostModule), and optionally for VGG16.
if __name__ == '__main__':

    # Input size handed to stat(); presumably (channels, height, width) of a
    # CIFAR-sized image -- TODO confirm.
    img_shape = (3, 32, 32)

    # NOTE: this rebinds the imported name `resnet56` from the constructor to
    # the instance, so the constructor cannot be called again below.
    resnet56 = resnet56()
    stat(resnet56, img_shape)  # https://github.com/Swall0w/torchstat
    print("↑↑↑↑ is resnet56")
    print("\n" * 10)

    # NOTE(review): if replace_conv modifies its argument in place,
    # ghost_resnet56 and resnet56 refer to the same converted model -- confirm.
    ghost_resnet56 = replace_conv(resnet56, GhostModule, arc="resnet56")
    stat(ghost_resnet56, img_shape)
    print("↑↑↑↑ is ghost_resnet56")

    # Manual toggle: set to 1 to also report statistics for VGG16.
    vgg = 0
    # vgg = 1
    if vgg:
        vgg16 = VGG("VGG16")
        stat(vgg16, img_shape)
        print("↑↑↑↑ is vgg16")
        print("\n" * 10)