def test_sparsity_report_initial_weights(self):
    """
    A freshly initialized convolutional network should not be pruned at all.

    The sparsity report is expected to contain six entries that are all
    equal to one (presumably the overall value followed by one value per
    layer -- confirm against Net.sparsity_report).
    """
    untouched_net = Net(
        NetNames.CONV,
        DatasetNames.CIFAR10,
        plan_conv=[8, 'M', 16, 'A'],
        plan_fc=[32, 16],
    )

    expected_report = np.ones(6, dtype=float)
    actual_report = untouched_net.sparsity_report()

    np.testing.assert_array_equal(expected_report, actual_report)
def test_sparsity_report_after_single_prune_lenet_300_100(self):
    """
    Should prune each layer with the given pruning rate, except for the
    last layer (half fc pruning-rate).

    total_weights = (28*28*300) + (300*100) + (100*10) = 266200
    sparsity = ((28*28*300)*0.9 + (300*100)*0.9 + (100*10)*0.95) / 266200
             = 239630 / 266200 ~ 0.9002

    The overall sparsity is a computed ratio (0.90018...), so the report is
    compared to four decimal places instead of with exact float equality.
    This holds whether or not sparsity_report() rounds its values.
    """
    net = Net(NetNames.LENET, DatasetNames.MNIST, plan_conv=[], plan_fc=[300, 100])
    net.prune_net(prune_rate_conv=0.0, prune_rate_fc=0.1)

    expected_report = np.array([0.9002, 0.9, 0.9, 0.95])
    # decimal=4 matches the precision of the documented expectation above.
    np.testing.assert_almost_equal(expected_report, net.sparsity_report(), decimal=4)
def test_sparsity_report_after_single_prune_conv2(self):
    """
    After one pruning round every layer should be pruned with its given
    rate, except for the last layer, which uses half of the fc pruning-rate.

    total_weights = conv + fc = 38592 + 4262400 = 4300992
    sparsity = (38592*0.9 + (16*16*64*256 + 256*256)*0.8 + (256*10)*0.9) / 4300992
             ~ 0.8010

    NOTE(review): the figures above (38592 conv weights, 16*16*64 inputs to
    the first fc layer) look like they were derived for a 3-channel 32x32
    input, while this test builds the net for MNIST -- confirm. The report
    is only compared to three decimal places.
    """
    pruned_net = Net(
        NetNames.CONV,
        DatasetNames.MNIST,
        plan_conv=[64, 64, 'M'],
        plan_fc=[256, 256],
    )
    pruned_net.prune_net(prune_rate_conv=0.1, prune_rate_fc=0.2)

    expected_report = np.array([0.801, 0.9, 0.9, 0.8, 0.8, 0.9])
    np.testing.assert_almost_equal(expected_report, pruned_net.sparsity_report(), decimal=3)
def test_get_trained_instance(self):
    """
    A new instance requested without weight reset should mirror the pruned
    source network.

    Verifies that the copy reports the same sparsity, carries the same net
    and dataset names, and that the weights of its first conv and fc layers
    equal the source weights multiplied by their pruning masks.
    """
    source = Net(NetNames.CONV, DatasetNames.CIFAR10, plan_conv=[2, 'M'], plan_fc=[2])
    # Shift the first conv and fc weights in place before pruning.
    source.conv[0].weight.add_(0.5)
    source.fc[0].weight.add_(0.5)
    source.prune_net(prune_rate_conv=0.0, prune_rate_fc=0.1, reset=False)

    clone = source.get_new_instance(reset_weight=False)

    np.testing.assert_array_equal(source.sparsity_report(), clone.sparsity_report())
    self.assertEqual(NetNames.CONV, clone.net_name)
    self.assertEqual(DatasetNames.CIFAR10, clone.dataset_name)

    # Conv layer is checked first, then the fc layer.
    layer_pairs = (
        (clone.conv[0], source.conv[0]),
        (clone.fc[0], source.fc[0]),
    )
    for copied_layer, source_layer in layer_pairs:
        masked_weight = source_layer.weight.mul(source_layer.weight_mask)
        self.assertIs(torch.equal(copied_layer.weight, masked_weight), True)