Пример #1
0
def get_executor(use_gpu=True):
    """Build a TVM executor for the online TSM MobileNetV2 Jester model.

    On first use the pretrained checkpoint is downloaded into the current
    working directory. It is then loaded into a 27-class ``MobileNetV2``,
    which is passed together with example inputs to ``torch2executor``.

    Args:
        use_gpu: if True, compile for ``'cuda -model=tx2'``; otherwise compile
            for an AArch64 Linux ARM CPU via LLVM.

    Returns:
        Whatever ``torch2executor`` returns for the compiled module.
    """
    ckpt_path = "mobilenetv2_jester_online.pth.tar"
    torch_module = MobileNetV2(n_class=27)

    if not os.path.exists(ckpt_path):  # checkpoint not downloaded
        print('Downloading PyTorch checkpoint...')
        import urllib.request
        url = 'https://hanlab.mit.edu/projects/tsm/models/mobilenetv2_jester_online.pth.tar'
        # Fetch into a temporary name and rename only once complete. An
        # interrupted download must not leave a truncated file behind,
        # because the exists() check above would accept it on the next run.
        partial_path = ckpt_path + '.part'
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, ckpt_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)

    # map_location='cpu' lets a checkpoint saved from GPU tensors be read on
    # a CPU-only host. load_state_dict copies the values into the module's
    # own parameters either way.
    torch_module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

    # Example inputs: a (1, 3, 224, 224) frame followed by zero-initialised
    # temporal-shift buffers. The shapes presumably must match the buffers
    # MobileNetV2's forward consumes; confirm against the model definition.
    shift_buffer_shapes = [
        [1, 3, 56, 56],
        [1, 4, 28, 28], [1, 4, 28, 28],
        [1, 8, 14, 14], [1, 8, 14, 14], [1, 8, 14, 14],
        [1, 12, 14, 14], [1, 12, 14, 14],
        [1, 20, 7, 7], [1, 20, 7, 7],
    ]
    torch_inputs = (torch.rand(1, 3, 224, 224),
                    *(torch.zeros(shape) for shape in shift_buffer_shapes))

    if use_gpu:
        target = 'cuda -model=tx2'
    else:
        # '-mcpu=cortex-a57' may be appended to tune for a specific core.
        target = 'llvm -device=arm_cpu -target=aarch64-linux-gnu'
    return torch2executor(torch_module, torch_inputs, target)
Пример #2
0
def get_executor():
    """Build a CUDA TVM executor for a 4-class MobileNetV2 from ``ckpt.best.pth.tar``.

    The checkpoint is loaded non-strictly. Any keys that fail to match are
    reported with a warning instead of being ignored silently.

    Returns:
        Whatever ``torch2executor`` returns for the compiled module.
    """
    torch_module = MobileNetV2(n_class=4)
    # strict=False tolerates key mismatches, but skipping them silently can
    # leave the network with its random initial weights. That happens, for
    # example, when the file is a {'state_dict': ...} wrapper rather than a
    # bare state dict. Surface anything that did not line up.
    incompatible = torch_module.load_state_dict(
        torch.load("ckpt.best.pth.tar", map_location='cpu'), strict=False)
    missing = list(incompatible.missing_keys)
    unexpected = list(incompatible.unexpected_keys)
    if missing or unexpected:
        import warnings
        warnings.warn(
            f"ckpt.best.pth.tar: {len(missing)} missing key(s) "
            f"(e.g. {missing[:3]}), {len(unexpected)} unexpected key(s) "
            f"(e.g. {unexpected[:3]})")

    # Example inputs: a (1, 3, 224, 224) frame followed by zero-initialised
    # temporal-shift buffers.
    shift_buffer_shapes = [
        [1, 3, 56, 56],
        [1, 4, 28, 28], [1, 4, 28, 28],
        [1, 8, 14, 14], [1, 8, 14, 14], [1, 8, 14, 14],
        [1, 12, 14, 14], [1, 12, 14, 14],
        [1, 20, 7, 7], [1, 20, 7, 7],
    ]
    torch_inputs = (torch.rand(1, 3, 224, 224),
                    *(torch.zeros(shape) for shape in shift_buffer_shapes))
    return torch2executor(torch_module, torch_inputs, target='cuda')
def get_executor():
    """Build a CUDA TVM executor for the online TSM MobileNetV2 Jester model.

    On first use the pretrained checkpoint is downloaded into the current
    working directory. It is then loaded into a 27-class ``MobileNetV2``,
    which is passed together with example inputs to ``torch2executor``.

    Returns:
        Whatever ``torch2executor`` returns for the compiled module.
    """
    ckpt_path = "mobilenetv2_jester_online.pth.tar"
    torch_module = MobileNetV2(n_class=27)

    if not os.path.exists(ckpt_path):  # checkpoint not downloaded
        print('Downloading PyTorch checkpoint...')
        import urllib.request
        url = 'https://hanlab.mit.edu/projects/tsm/models/mobilenetv2_jester_online.pth.tar'
        # Fetch into a temporary name and rename only once complete. A
        # truncated file left by an interrupted download would otherwise
        # pass the exists() check on every later run.
        partial_path = ckpt_path + '.part'
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, ckpt_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)

    # map_location='cpu' avoids requiring CUDA just to deserialize a
    # checkpoint saved from GPU tensors.
    torch_module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

    # Example inputs: a (1, 3, 224, 224) frame followed by zero-initialised
    # temporal-shift buffers.
    shift_buffer_shapes = [
        [1, 3, 56, 56],
        [1, 4, 28, 28], [1, 4, 28, 28],
        [1, 8, 14, 14], [1, 8, 14, 14], [1, 8, 14, 14],
        [1, 12, 14, 14], [1, 12, 14, 14],
        [1, 20, 7, 7], [1, 20, 7, 7],
    ]
    torch_inputs = (torch.rand(1, 3, 224, 224),
                    *(torch.zeros(shape) for shape in shift_buffer_shapes))
    return torch2executor(torch_module, torch_inputs, target='cuda')
Пример #4
0
def get_executor(use_gpu=True):
    """Build a TVM executor for the online TSM MobileNetV2 Jester model.

    On first use the pretrained checkpoint is downloaded into the current
    working directory. It is then loaded into a 27-class ``MobileNetV2``,
    which is passed together with example inputs to ``torch2executor``.

    Args:
        use_gpu: if True, compile for ``'cuda'``; otherwise compile for a
            32-bit ARMv7 hard-float Linux CPU (Cortex-A72) via LLVM.

    Returns:
        Whatever ``torch2executor`` returns for the compiled module.
    """
    ckpt_path = "mobilenetv2_jester_online.pth.tar"
    torch_module = MobileNetV2(n_class=27)

    if not os.path.exists(ckpt_path):  # checkpoint not downloaded
        print('Downloading PyTorch checkpoint...')
        import urllib.request
        url = 'https://file.lzhu.me/projects/tsm/models/mobilenetv2_jester_online.pth.tar'
        # Fetch into a temporary name and rename only once complete. A
        # truncated file left by an interrupted download would otherwise
        # pass the exists() check on every later run.
        partial_path = ckpt_path + '.part'
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, ckpt_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)

    # map_location='cpu' lets a GPU-saved checkpoint load on the CPU-only
    # (use_gpu=False) path.
    torch_module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

    # Example inputs: a (1, 3, 224, 224) frame followed by zero-initialised
    # temporal-shift buffers.
    shift_buffer_shapes = [
        [1, 3, 56, 56],
        [1, 4, 28, 28], [1, 4, 28, 28],
        [1, 8, 14, 14], [1, 8, 14, 14], [1, 8, 14, 14],
        [1, 12, 14, 14], [1, 12, 14, 14],
        [1, 20, 7, 7], [1, 20, 7, 7],
    ]
    torch_inputs = (torch.rand(1, 3, 224, 224),
                    *(torch.zeros(shape) for shape in shift_buffer_shapes))

    if use_gpu:
        target = 'cuda'
    else:
        target = 'llvm -mcpu=cortex-a72 -target=armv7l-linux-gnueabihf'
    return torch2executor(torch_module, torch_inputs, target)
Пример #5
0
def get_executor():
    """Build a CUDA TVM executor from the TSN-format ``mobilenetv2_jester.pth.tar``.

    The checkpoint's ``'state_dict'`` keys are renamed to the plain
    ``MobileNetV2`` layout before being loaded into a 27-class model.
    """
    from collections import OrderedDict

    def rename(key):
        # Drop the first 7 characters (presumably the 'module.' prefix added
        # by DataParallel; confirm against the checkpoint).
        key = key[7:]
        if 'new_fc' in key:
            return key.replace('new_fc', 'classifier')
        # Remove every 'net.' (a no-op when absent), then drop the next
        # 11 characters (presumably 'base_model.'; confirm).
        return key.replace('net.', '')[11:]

    model = MobileNetV2(n_class=27)
    checkpoint_weights = torch.load('mobilenetv2_jester.pth.tar')['state_dict']
    renamed = OrderedDict(
        (rename(key), tensor) for key, tensor in checkpoint_weights.items())
    model.load_state_dict(renamed)

    frame = torch.rand(1, 3, 224, 224)
    buffers = [torch.zeros(shape) for shape in (
        [1, 3, 56, 56],
        [1, 4, 28, 28], [1, 4, 28, 28],
        [1, 8, 14, 14], [1, 8, 14, 14], [1, 8, 14, 14],
        [1, 12, 14, 14], [1, 12, 14, 14],
        [1, 20, 7, 7], [1, 20, 7, 7],
    )]
    inputs = (frame, *buffers)
    return torch2executor(model, inputs, target='cuda')
Пример #6
0
import torch

from mobilenet_v2_tsm import MobileNetV2

# Export the online TSM MobileNetV2 (27 classes) to ONNX.

dummy_input = torch.randn(1, 3, 224, 224)
# Read the checkpoint from disk once. map_location='cpu' keeps it usable on
# hosts without CUDA; the export below runs on CPU tensors.
state_dict = torch.load(r'./mobilenetv2_jester_online.pth.tar', map_location='cpu')

torch_module = MobileNetV2(n_class=27)
torch_module.load_state_dict(state_dict)
torch_module.eval()

shift_buffer = [torch.zeros([1, 3, 56, 56]),
                torch.zeros([1, 4, 28, 28]),
                torch.zeros([1, 4, 28, 28]),
                torch.zeros([1, 8, 14, 14]),
                torch.zeros([1, 8, 14, 14]),
                torch.zeros([1, 8, 14, 14]),
                torch.zeros([1, 12, 14, 14]),
                torch.zeros([1, 12, 14, 14]),
                torch.zeros([1, 20, 7, 7]),
                torch.zeros([1, 20, 7, 7])]
# The model needs the shift buffers, so its inputs are dummy_input plus
# *shift_buffer. They must be packed into a single tuple to be passed as
# one "args" argument.
# opset_version=10 is also required; this workaround comes from:
# https://github.com/pytorch/fairseq/issues/1669#issuecomment-798972533
torch.onnx.export(torch_module, (dummy_input, *shift_buffer), "mobilenet_v2_tsm.onnx", opset_version=10, verbose=True)
Пример #7
0
def _convert_tsn_state_dict(state_dict):
    """Rename TSN training-checkpoint keys to plain ``MobileNetV2`` keys.

    For every key: drop the first dotted component, then remove every
    ``'.net'`` and ``'base_model.'`` substring. Finally, rename
    ``'new_fc.weight'``/``'new_fc.bias'`` to
    ``'classifier.weight'``/``'classifier.bias'``.

    Builds a new dict instead of popping/re-inserting keys while iterating,
    which raises RuntimeError on modern CPython.

    Raises:
        KeyError: if the converted dict has no ``'new_fc.weight'`` or
            ``'new_fc.bias'`` entry.
    """
    converted = {}
    for key, value in state_dict.items():
        key = '.'.join(key.split('.')[1:])
        key = key.replace('.net', '').replace('base_model.', '')
        converted[key] = value
    converted['classifier.weight'] = converted.pop('new_fc.weight')
    converted['classifier.bias'] = converted.pop('new_fc.bias')
    return converted


def get_executor(use_gpu=True):
    """Build a TVM executor for a 2-class MobileNetV2 from ``ckpt.best.pth.tar``.

    The checkpoint's ``'state_dict'`` entry, saved by a TSN wrapper, is
    converted to the plain ``MobileNetV2`` key layout before loading.

    Args:
        use_gpu: if True, compile for ``'cuda'``; otherwise compile for a
            32-bit ARMv7 hard-float Linux CPU (Cortex-A72) via LLVM.

    Returns:
        Whatever ``torch2executor`` returns for the compiled module.
    """
    torch_module = MobileNetV2(n_class=2)
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host.
    checkpoint = torch.load("ckpt.best.pth.tar", map_location='cpu')
    torch_module.load_state_dict(_convert_tsn_state_dict(checkpoint['state_dict']))

    # Example inputs: a (1, 3, 224, 224) frame followed by zero-initialised
    # temporal-shift buffers.
    shift_buffer_shapes = [
        [1, 3, 56, 56],
        [1, 4, 28, 28], [1, 4, 28, 28],
        [1, 8, 14, 14], [1, 8, 14, 14], [1, 8, 14, 14],
        [1, 12, 14, 14], [1, 12, 14, 14],
        [1, 20, 7, 7], [1, 20, 7, 7],
    ]
    torch_inputs = (torch.rand(1, 3, 224, 224),
                    *(torch.zeros(shape) for shape in shift_buffer_shapes))

    if use_gpu:
        target = 'cuda'
    else:
        target = 'llvm -mcpu=cortex-a72 -target=armv7l-linux-gnueabihf'
    return torch2executor(torch_module, torch_inputs, target)