Esempio n. 1
0
    def test_multiple_input_tensor_list(self):
        """Summarize a two-input model whose tensors are supplied as one list.

        A float tensor and a long tensor, each of shape (1, 300), are passed
        together in a single list. The summary is expected to report one
        input size per tensor.
        """
        float_input = torch.randn(1, 300)
        long_input = torch.randn(1, 300).long()
        net = MultipleInputNetDifferentDtypes()

        stats = summary(net, [float_input, long_input])

        expected_sizes = [torch.Size([1, 300]), torch.Size([1, 300])]
        assert stats.input_size == expected_sizes
    def test_multiple_input_tensor_args() -> None:
        """Summarize a two-input model whose tensors are separate arguments.

        The float tensor and the long tensor (both shape (1, 300)) are passed
        positionally after the model. The reported input size is expected to
        be a list holding a single (1, 300) entry.
        """
        float_input = torch.randn(1, 300)
        long_input = torch.randn(1, 300).long()
        net = MultipleInputNetDifferentDtypes()

        stats = summary(net, float_input, long_input)

        expected_sizes = [torch.Size([1, 300])]
        assert stats.input_size == expected_sizes
Esempio n. 3
0
    def test_dict_out(capsys: pytest.CaptureFixture[str]) -> None:
        """Summarize a model fed with a dict of named input tensors.

        The tensors are passed as ``input_data`` keyed by "x1" and "x2"; the
        capture fixture and the reference path are then handed to
        ``verify_output`` (presumably to compare the captured text against
        that file - confirm against the helper's definition).
        """
        # TODO: expand this test to handle intermediate dict layers.
        net = MultipleInputNetDifferentDtypes()
        named_inputs = {
            "x1": torch.randn(1, 300),
            "x2": torch.randn(1, 300).long(),
        }

        summary(net, input_data=named_inputs)

        expected_output_path = "tests/test_output/dict_input.out"
        verify_output(capsys, expected_output_path)
Esempio n. 4
0
    def test_multiple_input_types(self):
        """Check parameter counts when every input shape has its own dtype.

        Two (1, 300) input shapes are paired with a float and a long tensor
        type; the CUDA variants are used when a GPU is available, the CPU
        variants otherwise.
        """
        net = MultipleInputNetDifferentDtypes()
        first_shape = (1, 300)
        second_shape = (1, 300)

        # Pick the namespace once; torch.cuda attributes are only touched
        # when CUDA is actually available.
        tensor_types = torch.cuda if torch.cuda.is_available() else torch
        dtypes = [tensor_types.FloatTensor, tensor_types.LongTensor]

        results = summary(net, [first_shape, second_shape], dtypes=dtypes)

        expected_param_count = 31120
        assert results.total_params == expected_param_count
        assert results.trainable_params == expected_param_count
Esempio n. 5
0
    def test_multiple_input_types() -> None:
        """Check parameter counts when input sizes are paired with dtypes.

        The same (1, 300) size is supplied twice via ``input_size`` together
        with a float and a long tensor type (CUDA variants when a GPU is
        available, CPU variants otherwise).
        """
        net = MultipleInputNetDifferentDtypes()
        shape = (1, 300)
        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            dtypes = [
                torch.cuda.FloatTensor,  # type: ignore[attr-defined]
                torch.cuda.LongTensor,  # type: ignore[attr-defined]
            ]
        else:
            dtypes = [torch.FloatTensor, torch.LongTensor]
        shapes = [shape, shape]

        results = summary(net, input_size=shapes, dtypes=dtypes)

        expected_param_count = 31120
        assert results.total_params == expected_param_count
        assert results.trainable_params == expected_param_count