Example #1
0
def slim_speedup(masks_file, model_checkpoint):
    """Benchmark inference of a pruned VGG-19 with masks vs. real speedup.

    Depending on the module-level ``use_mask`` flag, either applies the
    pruning masks in place (``apply_compression_results``) or physically
    shrinks the network with ``ModelSpeedup``, then times 32 forward
    passes on a dummy CIFAR-sized batch and prints the elapsed time.

    Args:
        masks_file: path to the pruning masks produced by the pruner.
        model_checkpoint: currently unused; kept for interface
            compatibility with callers.
    """
    # Fall back to CPU when CUDA is unavailable, matching the device
    # selection used by model_inference() in this file.
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = VGG(depth=19)
    model.to(device)
    model.eval()

    dummy_input = torch.randn(64, 3, 32, 32).to(device)
    # Inference only: disable autograd so the timing loop does not pay
    # for (and accumulate) gradient bookkeeping.
    with torch.no_grad():
        if use_mask:
            apply_compression_results(model, masks_file)
            start = time.time()
            for _ in range(32):
                model(dummy_input)
            print('mask elapsed time: ', time.time() - start)
        else:
            # ModelSpeedup rewrites the graph so masked channels are
            # actually removed, shrinking the real compute.
            m_speedup = ModelSpeedup(model, dummy_input, masks_file)
            m_speedup.speedup_model()
            start = time.time()
            for _ in range(32):
                model(dummy_input)
            print('speedup elapsed time: ', time.time() - start)
Example #2
0
def model_inference(config):
    """Time inference of a pruned model with masks and/or real speedup.

    Reads the model name, masks file and input shape from ``config``,
    then — controlled by the module-level ``use_mask`` / ``use_speedup``
    / ``compare_results`` flags — times 32 forward passes with the masks
    applied, with the physically shrunk model, and optionally checks
    that the two paths produce the same outputs.

    Args:
        config: dict with keys ``'masks_file'``, ``'model_name'``
            (one of ``'vgg16'``, ``'vgg19'``, ``'lenet'``) and
            ``'input_shape'``.

    Raises:
        ValueError: if ``config['model_name']`` is not supported.
        RuntimeError: if the mask and speedup outputs disagree, or if
            comparison is requested without both paths having run.
    """
    masks_file = config['masks_file']
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    model_name = config['model_name']
    if model_name == 'vgg16':
        model = VGG(depth=16)
    elif model_name == 'vgg19':
        model = VGG(depth=19)
    elif model_name == 'lenet':
        model = LeNet()
    else:
        # Fail fast with a clear message instead of an unbound-name
        # NameError at model.to(device) below.
        raise ValueError('unsupported model_name: ' + repr(model_name))

    model.to(device)
    model.eval()

    dummy_input = torch.randn(config['input_shape']).to(device)
    use_mask_out = use_speedup_out = None
    # must run use_mask before use_speedup because use_speedup modify the model
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        if use_mask:
            apply_compression_results(model, masks_file, device)
            start = time.time()
            for _ in range(32):
                use_mask_out = model(dummy_input)
            print('elapsed time when use mask: ', time.time() - start)
        if use_speedup:
            m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
            m_speedup.speedup_model()
            start = time.time()
            for _ in range(32):
                use_speedup_out = model(dummy_input)
            print('elapsed time when use speedup: ', time.time() - start)
    if compare_results:
        # Guard against comparing when one of the paths never ran;
        # torch.allclose(None, ...) would otherwise raise an opaque
        # TypeError.
        if use_mask_out is None or use_speedup_out is None:
            raise RuntimeError(
                'both use_mask and use_speedup must run before comparing')
        if torch.allclose(use_mask_out, use_speedup_out, atol=1e-07):
            print('the outputs from use_mask and use_speedup are the same')
        else:
            raise RuntimeError(
                'the outputs from use_mask and use_speedup are different')