コード例 #1
0
def generate_model_from_state_dict(state_dict, specs):
    """ Build the network described by 'specs' and load 'state_dict' into it.

    After loading, a prune with both rates set to 0 and reset=False applies the
    pruning masks to the weights without modifying the masks themselves.
    Returns the resulting net. """
    architecture = (specs.net, specs.dataset, specs.plan_conv, specs.plan_fc)
    model = Net(*architecture)
    model.load_state_dict(state_dict)

    # Zero pruning rates + reset=False: apply the masks, keep them unchanged.
    model.prune_net(0., 0., reset=False)
    return model
コード例 #2
0
 def test_sparsity_report_after_single_prune_lenet_300_100(self):
     """ Should prune each layer with the given pruning rate, except for the last layer (half fc pruning-rate).
     total_weights = (28*28*300) + (300*100) + (100*10) = 266200
     sparsity = ((28*28*300)*0.9 + (300*100)*0.9 + (100*10)*0.95) / 266200 ~ 0.9002

     The report entries are compared against the fraction of weights expected to
     remain (overall value first, then one entry per layer). """
     net = Net(NetNames.LENET,
               DatasetNames.MNIST,
               plan_conv=[],
               plan_fc=[300, 100])
     net.prune_net(prune_rate_conv=0.0, prune_rate_fc=0.1)
     # The exact overall value is 239630/266200 = 0.90019..., so 0.9002 is only an
     # approximation; compare with a tolerance (4 decimals, matching the precision
     # stated above) instead of exact float equality.
     np.testing.assert_almost_equal(np.array([0.9002, 0.9, 0.9, 0.95]),
                                    net.sparsity_report(),
                                    decimal=4)
コード例 #3
0
 def generate_randomly_reinitialized_net(specs, state_dict):
     """ Build a net from 'state_dict' and randomly reinitialize its weights.

     The returned net has the same masks as the net specified by 'state_dict';
     its weights are re-drawn by applying 'gaussian_glorot' and then stored as
     the net's initial weights before a zero-rate prune. """
     assert isinstance(specs, ExperimentSpecs), \
         f"'specs' needs to be ExperimentSpecs, but is {type(specs)}."

     fresh_net = Net(specs.net, specs.dataset, specs.plan_conv, specs.plan_fc)
     fresh_net.load_state_dict(state_dict)

     # Re-draw the weights, then record them as the net's initial weights.
     fresh_net.apply(gaussian_glorot)
     fresh_net.store_initial_weights()

     # Zero-rate prune; presumably re-applies the loaded masks — TODO confirm.
     fresh_net.prune_net(0.0, 0.0)
     return fresh_net
コード例 #4
0
 def test_sparsity_report_after_single_prune_conv2(self):
     """ Should prune each layer with the given pruning rate, except for the last layer (half fc pruning-rate).
     total_weights = conv+fc = 38592+4262400 = 4300992
     sparsity = (38592*0.9 + (16*16*64*256 + 256*256)*0.8 + (256*10)*0.9) / 4300992 ~ 0.8010

     NOTE(review): the counts above (38592 conv weights, a 16*16 feature map)
     look like they assume a 3-channel 32x32 input, while this net is built for
     MNIST — confirm which input the numbers are meant to describe. """
     expected_report = np.array([0.801, 0.9, 0.9, 0.8, 0.8, 0.9])

     net = Net(NetNames.CONV, DatasetNames.MNIST,
               plan_conv=[64, 64, 'M'], plan_fc=[256, 256])
     net.prune_net(prune_rate_conv=0.1, prune_rate_fc=0.2)

     # Tolerant comparison: the overall value is only approximately 0.801.
     np.testing.assert_almost_equal(expected_report,
                                    net.sparsity_report(),
                                    decimal=3)
コード例 #5
0
    def test_get_trained_instance(self):
        """ The pruned and trained network should return a trained copy of itself.

        Shifts the first conv and fc weights, prunes without resetting, and checks
        that get_new_instance(reset_weight=False) yields a net with the same name,
        dataset and sparsity report whose first conv/fc weights equal the original
        weights multiplied by their masks. """
        net = Net(NetNames.CONV,
                  DatasetNames.CIFAR10,
                  plan_conv=[2, 'M'],
                  plan_fc=[2])
        # In-place edits of parameters that require grad are rejected by autograd;
        # doing the manual shift under no_grad is safe whether or not they do.
        with torch.no_grad():
            net.conv[0].weight.add_(0.5)
            net.fc[0].weight.add_(0.5)
        net.prune_net(prune_rate_conv=0.0, prune_rate_fc=0.1, reset=False)

        new_net = net.get_new_instance(reset_weight=False)

        np.testing.assert_array_equal(net.sparsity_report(),
                                      new_net.sparsity_report())
        self.assertEqual(NetNames.CONV, new_net.net_name)
        self.assertEqual(DatasetNames.CIFAR10, new_net.dataset_name)
        # The copy should hold the masked weights of the original net.
        self.assertTrue(
            torch.equal(new_net.conv[0].weight,
                        net.conv[0].weight.mul(net.conv[0].weight_mask)))
        self.assertTrue(
            torch.equal(new_net.fc[0].weight,
                        net.fc[0].weight.mul(net.fc[0].weight_mask)))