def retrain(state_dict, part=1, num_epochs=5):
    """Rebuild a ConvNet from ``state_dict``, freeze pruned weights, retrain.

    Near-zero weights (magnitude below 1e-4) are treated as pruned: a binary
    mask zeros them and ``net.set_masks`` keeps them at zero during training.

    Args:
        state_dict: ordered mapping of parameter names to tensors; 4-D
            tensors are conv kernels, 2-D tensors are fully-connected
            weight matrices.
        part: partition argument forwarded to ``ConvNet`` (project-defined).
        num_epochs: number of retraining epochs.

    Returns:
        The same ``state_dict`` object, with parameters replaced by the
        retrained tensors.
    """
    # Hyper-parameters for the retraining run.
    param = {
        'batch_size': 4,
        'test_batch_size': 50,
        'num_epochs': num_epochs,
        'learning_rate': 0.001,
        'weight_decay': 5e-4,
    }

    # Count weight tensors by rank: conv kernels are 4-D, fc weights 2-D.
    num_cnn_layer = sum(int(len(v.size()) == 4) for v in state_dict.values())
    num_fc_layer = sum(int(len(v.size()) == 2) for v in state_dict.values())
    state_key = list(state_dict.keys())

    # Reconstruct the layer-size config: input width of the first weight
    # tensor, followed by the output width of every weight tensor.
    cfg = []
    first = True
    for v in state_dict.values():
        if len(v.data.size()) in (4, 2):
            if first:
                first = False
                cfg.append(v.data.size()[1])
            cfg.append(v.data.size()[0])
    assert num_cnn_layer + num_fc_layer == len(cfg) - 1

    net = ConvNet(cfg, num_cnn_layer, part)

    # Copy the (pruned) weights into the new net and derive binary masks.
    # NOTE(review): this assumes net.parameters() iterates in the same order
    # as state_dict's keys -- confirm against ConvNet's definition.
    masks = []
    for i, p in enumerate(net.parameters()):
        p.data = state_dict[state_key[i]]
        if len(p.data.size()) == 4:
            # Conv layer: keep an entire (out, in) kernel slice iff its
            # L1 norm is at least 1e-4, i.e. the filter was not pruned.
            p_np = p.data.cpu().numpy()
            keep = np.abs(p_np).sum(axis=(2, 3)) >= 1e-4
            masks.append(
                np.broadcast_to(keep[:, :, None, None],
                                p_np.shape).astype('float32'))
        elif len(p.data.size()) == 2:
            # Fully-connected layer: elementwise magnitude threshold.
            p_np = p.data.cpu().numpy()
            masks.append((np.abs(p_np) >= 1e-4).astype('float32'))
    net.set_masks(masks)

    # Retraining: masked (pruned) weights stay zero throughout.
    loader_train, loader_test = load_dataset()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(net.parameters(),
                                    lr=param['learning_rate'],
                                    weight_decay=param['weight_decay'])

    train(net, criterion, optimizer, param, loader_train)

    # Write the retrained tensors back into the caller's state dict.
    for i, p in enumerate(net.parameters()):
        state_dict[state_key[i]] = p.data

    return state_dict
loader_test = torch.utils.data.DataLoader(test_dataset, batch_size=param['test_batch_size'], shuffle=True) # Load the pretrained model net = ConvNet() net.load_state_dict(torch.load('models/convnet_pretrained.pkl')) if torch.cuda.is_available(): print('CUDA ensabled.') net.cuda() print("--- Pretrained network loaded ---") test(net, loader_test) # prune the weights masks = filter_prune(net, param['pruning_perc']) net.set_masks(masks) print("--- {}% parameters pruned ---".format(param['pruning_perc'])) test(net, loader_test) # Retraining criterion = nn.CrossEntropyLoss() optimizer = torch.optim.RMSprop(net.parameters(), lr=param['learning_rate'], weight_decay=param['weight_decay']) train(net, criterion, optimizer, param, loader_train) # Check accuracy and nonzeros weights in each layer print("--- After retraining ---") test(net, loader_test) prune_rate(net)