예제 #1
0
    def test_fully_customized(self):
        """Evaluate that a ResNet50 configured per residual block initializes and runs.

        ``sparse_params`` is given at block granularity: each ``filtersN`` group
        holds a list with one dict per block (3/4/6/3 blocks for the
        64/128/256/512 groups).

        Only the first 1x1 conv of the first 64-filter block uses
        ``SparseWeightsLayerParams``. Every other layer, including the stem and
        the linear layer, uses ``LayerParams()`` defaults.
        """

        def default_block():
            # Build a new dict of new LayerParams() on every call so that no
            # params object is shared between blocks, exactly as when each
            # block dict is written out literally.
            return dict(
                conv1x1_1=LayerParams(),
                conv3x3_2=LayerParams(),
                conv1x1_3=LayerParams(),
                shortcut=LayerParams(),
            )

        custom_sparse_params = dict(
            stem=LayerParams(),
            filters64=[  # 3 blocks: one customized block + 2 default blocks
                dict(
                    conv1x1_1=SparseWeightsLayerParams(
                        percent_on=0.3,
                        boost_strength=1.2,
                        boost_strength_factor=1.0,
                        local=False,
                        weight_sparsity=0.3,
                    ),
                    conv3x3_2=LayerParams(),
                    conv1x1_3=LayerParams(),
                    shortcut=LayerParams(),
                ),
            ] + [default_block() for _ in range(2)],
            filters128=[default_block() for _ in range(4)],  # 4 blocks
            filters256=[default_block() for _ in range(6)],  # 6 blocks
            filters512=[default_block() for _ in range(3)],  # 3 blocks
            linear=LayerParams(),
        )

        net = ResNet(config=dict(
            depth=50, num_classes=10, sparse_params=custom_sparse_params))
        # Modules accept plain tensors directly. torch.autograd.Variable is
        # deprecated and, since PyTorch 0.4, simply returns a Tensor.
        net(torch.randn(2, 3, 32, 32))

        self.assertIsInstance(net, ResNet, "Loads ResNet50 fully customized")
예제 #2
0
    def test_custom_per_group(self):
        """Evaluate that a ResNet50 configured per filter group initializes and runs.

        Unlike a per-block configuration, each ``filtersN`` entry here is a
        single dict rather than a list of per-block dicts. Presumably ResNet
        applies it to every block of that group (TODO confirm against ResNet).

        ``SparseWeightsLayerParams`` are used for the stem (with its defaults),
        for every layer of the 64-filter group, and for the linear layer. The
        128/256/512 groups use ``LayerParams()`` defaults.
        """

        def default_group():
            # Build a new dict of new LayerParams() on every call so that no
            # params object is shared between groups.
            return dict(
                conv1x1_1=LayerParams(),
                conv3x3_2=LayerParams(),
                conv1x1_3=LayerParams(),
                shortcut=LayerParams(),
            )

        custom_sparse_params = dict(
            stem=SparseWeightsLayerParams(),
            filters64=dict(
                conv1x1_1=SparseWeightsLayerParams(
                    percent_on=0.3,
                    boost_strength=1.2,
                    boost_strength_factor=1.0,
                    local=False,
                    weight_sparsity=0.3,
                ),
                conv3x3_2=SparseWeightsLayerParams(
                    percent_on=0.1,
                    boost_strength=1.2,
                    boost_strength_factor=1.0,
                    local=True,
                    weight_sparsity=0.1,
                ),
                conv1x1_3=SparseWeightsLayerParams(weight_sparsity=0.1),
                shortcut=SparseWeightsLayerParams(percent_on=0.4,
                                                  weight_sparsity=0.4),
            ),
            filters128=default_group(),
            filters256=default_group(),
            filters512=default_group(),
            linear=SparseWeightsLayerParams(weight_sparsity=0.5),
        )

        net = ResNet(config=dict(
            depth=50, num_classes=10, sparse_params=custom_sparse_params))
        # Modules accept plain tensors directly. torch.autograd.Variable is
        # deprecated and, since PyTorch 0.4, simply returns a Tensor.
        net(torch.randn(2, 3, 32, 32))

        self.assertIsInstance(net, ResNet,
                              "Loads ResNet50 customized per group")