def retrain(state_dict, part=1, num_epochs=5):
    """Rebuild a ConvNet from *state_dict*, mask out already-pruned weights,
    retrain, and return the updated state dict.

    Args:
        state_dict: ordered mapping of parameter name -> tensor, as produced
            by ``nn.Module.state_dict()``. NOTE(review): iteration order is
            assumed to match ``net.parameters()`` order — true for PyTorch
            OrderedDict state dicts, verify if the source of the dict changes.
        part: forwarded to the ``ConvNet`` constructor (project-defined).
        num_epochs: number of retraining epochs.

    Returns:
        The same *state_dict* object, mutated in place with retrained weights.
    """
    # Hyper Parameters
    param = {
        'batch_size': 4,
        'test_batch_size': 50,
        'num_epochs': num_epochs,
        'learning_rate': 0.001,
        'weight_decay': 5e-4,
    }

    # Conv weights are 4-D (out, in, kh, kw); fully-connected weights are 2-D.
    num_cnn_layer = sum(len(v.size()) == 4 for v in state_dict.values())
    num_fc_layer = sum(len(v.size()) == 2 for v in state_dict.values())

    state_key = list(state_dict.keys())

    # cfg = [input width of the first weight layer, then each layer's output
    # width] — i.e. one entry per layer boundary.
    cfg = []
    first = True
    for v in state_dict.values():
        if len(v.size()) in (4, 2):
            if first:
                first = False
                cfg.append(v.size(1))
            cfg.append(v.size(0))

    assert num_cnn_layer + num_fc_layer == len(cfg) - 1

    net = ConvNet(cfg, num_cnn_layer, part)

    # Build binary masks: entries whose absolute weight mass is ~0 were
    # pruned earlier and must stay frozen at zero during retraining.
    masks = []
    for i, p in enumerate(net.parameters()):

        # Copy the pruned weights into the freshly built network.
        p.data = state_dict[state_key[i]]

        if len(p.data.size()) == 4:
            p_np = p.data.cpu().numpy()
            mask = np.ones(p_np.shape, dtype='float32')
            # Per-(out, in) kernel L1 mass; a near-zero value means the whole
            # kh x kw slice was pruned, so zero that slice in the mask.
            # (Vectorized replacement for the original nested j/k loops.)
            kernel_mass = np.abs(p_np).sum(axis=(2, 3))
            mask[kernel_mass < 1e-4] = 0.
            masks.append(mask)

        elif len(p.data.size()) == 2:
            p_np = p.data.cpu().numpy()
            mask = np.ones(p_np.shape, dtype='float32')
            # Element-wise: zero the mask wherever the weight itself is ~0.
            mask[np.abs(p_np) < 1e-4] = 0.
            masks.append(mask)

    net.set_masks(masks)

    ## Retraining
    loader_train, loader_test = load_dataset()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(net.parameters(),
                                    lr=param['learning_rate'],
                                    weight_decay=param['weight_decay'])

    train(net, criterion, optimizer, param, loader_train)

    # Copy the retrained weights back into the caller's state dict.
    for i, p in enumerate(net.parameters()):
        state_dict[state_key[i]] = p.data

    return state_dict
# Example 2: standalone pruning + retraining script
# Driver script: load a pretrained ConvNet, prune it, retrain, and report
# accuracy and per-layer sparsity. Relies on `param`, `test_dataset`,
# `loader_train`, `filter_prune`, `train`, `test`, and `prune_rate` being
# defined earlier in the original file.
loader_test = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=param['test_batch_size'],
                                          shuffle=True)

# Load the pretrained model. map_location='cpu' makes loading work on
# CPU-only machines even for GPU-saved checkpoints; the model is moved
# back to the GPU below when one is available.
net = ConvNet()
net.load_state_dict(torch.load('models/convnet_pretrained.pkl',
                               map_location='cpu'))
if torch.cuda.is_available():
    print('CUDA enabled.')  # fixed typo: was 'CUDA ensabled.'
    net.cuda()
print("--- Pretrained network loaded ---")
test(net, loader_test)

# Prune the weights and apply the resulting binary masks.
masks = filter_prune(net, param['pruning_perc'])
net.set_masks(masks)
print("--- {}% parameters pruned ---".format(param['pruning_perc']))
test(net, loader_test)

# Retraining: fine-tune the surviving weights (masks keep pruned ones at 0).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(net.parameters(),
                                lr=param['learning_rate'],
                                weight_decay=param['weight_decay'])

train(net, criterion, optimizer, param, loader_train)

# Check accuracy and nonzeros weights in each layer
print("--- After retraining ---")
test(net, loader_test)
prune_rate(net)