예제 #1
0
    def test_lstm_out(capsys):
        """Compare the verbose LSTMNet summary with the stored golden output.

        ``summary`` is invoked with ``verbose=2``, ``branching=False`` and four
        20-character-wide columns; ``verify_output`` then checks what was
        captured through ``capsys`` against ``unit_test/test_output/lstm.out``.
        """
        expected_path = "unit_test/test_output/lstm.out"
        columns = ("kernel_size", "output_size", "num_params", "mult_adds")
        model = LSTMNet()

        summary(
            model,
            (100, ),
            dtypes=[torch.long],
            branching=False,
            verbose=2,
            col_width=20,
            col_names=columns,
        )

        verify_output(capsys, expected_path)
예제 #2
0
    def test_lstm_out(self, capsys):
        """Check the verbose (``verbose=2``) LSTMNet summary against a golden file.

        The output captured through ``capsys`` is compared by ``verify_output``
        with ``unit_test/test_output/lstm.out``. On Python < 3.7 a mismatch is
        downgraded to a warning instead of failing the test.
        """
        summary(
            LSTMNet(),
            (100, ),
            dtypes=[torch.long],
            branching=False,
            verbose=2,
            col_width=20,
            col_names=("kernel_size", "output_size", "num_params",
                       "mult_adds"),
        )

        try:
            verify_output(capsys, "unit_test/test_output/lstm.out")
        except AssertionError:
            # Dict insertion order is only guaranteed from Python 3.7 on, so a
            # mismatch is tolerated (with a warning) on older interpreters only.
            if sys.version_info >= (3, 7):
                raise
            warnings.warn(
                "LSTM verbose output is not deterministic because dictionaries are not "
                "necessarily ordered in versions before Python 3.7.")
예제 #3
0
    def test_lstm_out(capsys: pytest.CaptureFixture[str]) -> None:
        """Check the verbose (``verbose=2``) LSTMNet summary against a golden file.

        The model is summarised with ``input_size=(1, 100)`` and long inputs,
        and the output captured through ``capsys`` is compared by
        ``verify_output`` with ``unit_test/test_output/lstm.out``. On
        Python < 3.7 a mismatch is downgraded to a warning instead of failing.
        """
        summary(
            LSTMNet(),
            input_size=(1, 100),
            dtypes=[torch.long],
            verbose=2,
            col_width=20,
            col_names=("kernel_size", "output_size", "num_params",
                       "mult_adds"),
        )

        try:
            verify_output(capsys, "unit_test/test_output/lstm.out")
        except AssertionError:
            # Dict insertion order is only guaranteed from Python 3.7 on, so a
            # mismatch is tolerated (with a warning) on older interpreters only.
            if sys.version_info >= (3, 7):
                raise
            warnings.warn(
                "LSTM verbose output is not deterministic because dictionaries "
                "are not necessarily ordered in versions before Python 3.7."
            )
예제 #4
0
    def test_lstm(self):
        """Summarise LSTMNet on a ``(100,)`` long input; expect 3 summary entries."""
        entries = summary(LSTMNet(), (100,), dtypes=[torch.long]).summary_list
        assert len(entries) == 3, "Should find 3 layers"
예제 #5
0
    def test_lstm_custom_batch_size() -> None:
        """Summarise LSTMNet with ``batch_dim=1``; expect 3 summary entries.

        ``batch_size`` is deliberately not passed to ``summary``.
        NOTE(review): presumably ``batch_dim=1`` tells ``summary`` where to
        insert the batch axis for the ``(100,)`` input — confirm.
        """
        info = summary(LSTMNet(), (100,), dtypes=[torch.long], batch_dim=1)
        found = len(info.summary_list)
        assert found == 3, "Should find 3 layers"
예제 #6
0
    def test_lstm() -> None:
        """Summarise LSTMNet with ``input_size=(100, 1)``; expect 3 summary entries."""
        model_info = summary(
            LSTMNet(),
            input_size=(100, 1),
            dtypes=[torch.long],
        )
        layer_count = len(model_info.summary_list)
        assert layer_count == 3, "Should find 3 layers"