示例#1
0
    def test_recursive(self):
        """Summarize RecursiveNet and check its layer and parameter accounting.

        The entry at index 1 of ``summary_list`` must report "(recursive)"
        for its parameter count, so its parameters are not counted twice.
        """
        report = summary(RecursiveNet(), (64, 28, 28))
        layers = report.summary_list
        # Fetched before the length check: fewer than two entries raises
        # IndexError here rather than failing the length assertion.
        reused_layer = layers[1]

        assert len(layers) == 6, "Should find 6 layers"
        params_label = reused_layer.num_params_to_str()
        assert params_label == "(recursive)", "should not count the second layer again"
        assert report.total_params == 36928
        assert report.trainable_params == 36928
        assert report.total_mult_adds == 173408256
示例#2
0
    def test_recursive() -> None:
        """Check layer and parameter accounting for RecursiveNet.

        ``input_size`` carries a leading dimension of 1 (presumably the batch
        dimension). The entry at index 1 of ``summary_list`` must report
        "(recursive)" so its parameters are not counted a second time.
        """
        stats = summary(RecursiveNet(), input_size=(1, 64, 28, 28))
        entries = stats.summary_list
        # Indexed before the length check, so a too-short list raises
        # IndexError first.
        repeated = entries[1]

        assert len(entries) == 6, "Should find 6 layers"
        label = repeated.num_params_to_str(reached_max_depth=False)
        assert label == "(recursive)", "should not count the second layer again"
        assert stats.total_params == 36928
        assert stats.trainable_params == 36928
        assert stats.total_mult_adds == 173408256
示例#3
0
 def test_model_with_args(self):
     """Smoke test: summary() accepts an extra positional and keyword argument.

     Passes only if the call does not raise. The extra arguments are
     presumably forwarded to the model's forward(); confirm against summary().
     """
     model = RecursiveNet()
     summary(model, (64, 28, 28), "args1", args2="args2")
示例#4
0
 def test_model_with_args() -> None:
     """Smoke test: summary() accepts extra keyword arguments for the model.

     Passes only if the call does not raise. ``args1``/``args2`` are
     presumably forwarded to the model's forward(); confirm against summary().
     """
     model = RecursiveNet()
     shape = (1, 64, 28, 28)
     summary(model, input_size=shape, args1="args1", args2="args2")