def test_body_matches_heads(self):
        """A (32, 64) head behind a (32, 32) body must raise AssertionError."""
        narrow = Identity2D((32, 32), "stub_32")
        wide = Identity2D((32, 64), "stub_64")

        # Source and body share the (32, 32) stub; the single head is the
        # (32, 64) stub, which is the mismatch expected to be rejected.
        network_args = (
            {"source": narrow},
            narrow,
            {"2": wide},
            {"output": (32, 64)},
            dummy_gpu_preprocessor,
        )
        self.assertRaises(AssertionError, ModularNetwork, *network_args)
    def test_source_nets_match_body(self):
        """A (32, 64) body fed by a (32, 32) source must raise AssertionError."""
        small = Identity2D((32, 32), "stub_32")
        large = Identity2D((32, 64), "stub_64")

        # Only the source net uses the (32, 32) stub; body, head, and the
        # output space all use (32, 64).
        with self.assertRaises(AssertionError):
            ModularNetwork(
                {"source": small},
                large,
                {"2": large},
                {"output": (32, 64)},
                dummy_gpu_preprocessor,
            )
    def test_heads_match_out_shapes(self):
        """A (32, 64) output with only (32, 32) stubs must raise AssertionError."""
        identity = Identity2D((32, 32), "stub_2d")

        # Source, body, and head all share one stub; only the requested
        # output entry (32, 64) differs from it.
        constructor_args = [
            {"source": identity},
            identity,
            {"2": identity},
            {"output": (32, 64)},
            dummy_gpu_preprocessor,
        ]
        self.assertRaises(AssertionError, ModularNetwork, *constructor_args)
    def test_output_has_a_head(self):
        """A (32, 32, 32) output with only a "2" head must raise AssertionError."""
        shared_stub = Identity2D((32, 32), "stub_2d")
        nets_by_source = {"source": shared_stub}
        heads_by_key = {"2": shared_stub}
        # The heads dict holds a single entry keyed "2", while the requested
        # output is the three-element tuple (32, 32, 32).
        requested_outputs = {"output": (32, 32, 32)}

        with self.assertRaises(AssertionError):
            ModularNetwork(
                nets_by_source,
                shared_stub,
                heads_by_key,
                requested_outputs,
                dummy_gpu_preprocessor,
            )
    def test_heads_not_higher_dim_than_body(self):
        """An Identity2D head on an Identity1D body must raise AssertionError."""
        flat = Identity1D((32,), "stub_1d")
        planar = Identity2D((32, 32), "stub_2d")

        # Source and body are the 1D stub; the only head is the 2D stub.
        self.assertRaises(
            AssertionError,
            ModularNetwork,
            {"source": flat},
            flat,
            {"2": planar},
            {"output": (32, 32)},
            dummy_gpu_preprocessor,
        )
class TestModularNetwork(unittest.TestCase):
    """Construction and forward-pass checks for ``ModularNetwork``.

    The class-level attributes describe one configuration that is expected to
    be accepted; the ``assertRaises`` tests each build a small configuration
    that is expected to be rejected with ``AssertionError``.
    """

    # Example of valid structure
    source_nets = {
        "source_1d": Identity1D((16,), "source_1d"),
        "source_2d": Identity2D((16, 8 * 8), "source_2d"),
        "source_3d": Identity3D((16, 8, 8), "source_3d"),
        "source_4d": Identity4D((16, 8, 8, 8), "source_4d"),
    }
    # NOTE(review): 176 presumably equals 16 + 16 + 16 + 16 * 8 contributed by
    # the four sources above -- confirm against ModularNetwork's merge logic.
    body = Identity3D((176, 8, 8), "body")
    # NOTE(review): 11264 == 176 * 8 * 8 and 64 == 8 * 8, presumably the body
    # output flattened to 1D / 2D -- confirm.
    heads = {
        "1": Identity1D((11264,), "head1d"),
        "2": Identity2D((176, 64), "head2d"),
        "3": Identity3D((176, 8, 8), "head3d"),
    }
    output_space = {
        "output_1d": (16,),
        "output_2d": (16, 8 * 8),
        "output_3d": (16, 8, 8),
    }

    def test_heads_not_higher_dim_than_body(self):
        """An Identity2D head on an Identity1D body must raise AssertionError."""
        stub_1d = Identity1D((32,), "stub_1d")
        stub_2d = Identity2D((32, 32), "stub_2d")

        source_nets = {"source": stub_1d}
        body = stub_1d
        heads = {"2": stub_2d}  # should error
        output_space = {"output": (32, 32)}

        with self.assertRaises(AssertionError):
            ModularNetwork(
                source_nets, body, heads, output_space, dummy_gpu_preprocessor
            )

    def test_source_nets_match_body(self):
        """A (32, 64) body fed by a (32, 32) source must raise AssertionError."""
        stub_32 = Identity2D((32, 32), "stub_32")
        stub_64 = Identity2D((32, 64), "stub_64")

        source_nets = {"source": stub_32}
        body = stub_64  # should error
        heads = {"2": stub_64}
        output_space = {"output": (32, 64)}

        with self.assertRaises(AssertionError):
            ModularNetwork(
                source_nets, body, heads, output_space, dummy_gpu_preprocessor
            )

    def test_body_matches_heads(self):
        """A (32, 64) head behind a (32, 32) body must raise AssertionError."""
        stub_32 = Identity2D((32, 32), "stub_32")
        stub_64 = Identity2D((32, 64), "stub_64")

        source_nets = {"source": stub_32}
        body = stub_32
        heads = {"2": stub_64}  # should error
        output_space = {"output": (32, 64)}

        with self.assertRaises(AssertionError):
            ModularNetwork(
                source_nets, body, heads, output_space, dummy_gpu_preprocessor
            )

    def test_output_has_a_head(self):
        """A (32, 32, 32) output with only a "2" head must raise AssertionError."""
        stub_2d = Identity2D((32, 32), "stub_2d")

        source_nets = {"source": stub_2d}
        body = stub_2d
        heads = {"2": stub_2d}
        output_space = {"output": (32, 32, 32)}  # should error
        with self.assertRaises(AssertionError):
            ModularNetwork(
                source_nets, body, heads, output_space, dummy_gpu_preprocessor
            )

    def test_heads_match_out_shapes(self):
        """A (32, 64) output with only (32, 32) stubs must raise AssertionError."""
        stub_2d = Identity2D((32, 32), "stub_2d")

        source_nets = {"source": stub_2d}
        body = stub_2d
        heads = {"2": stub_2d}
        output_space = {"output": (32, 64)}  # should error
        with self.assertRaises(AssertionError):
            ModularNetwork(
                source_nets, body, heads, output_space, dummy_gpu_preprocessor
            )

    def test_valid_structure(self):
        """Building a network from the class-level example must not raise."""
        try:
            ModularNetwork(
                self.source_nets,
                self.body,
                self.heads,
                self.output_space,
                dummy_gpu_preprocessor,
            )
        except Exception as exc:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit still propagate, and report what actually went wrong.
            self.fail("Unexpected exception: {!r}".format(exc))

    def test_forward(self):
        """A forward pass over zero-filled inputs for every source must not raise."""
        import torch

        BATCH = 32
        # One zero tensor per key in ``source_nets``, each with a leading
        # dimension of size BATCH.
        obs = {
            "source_1d": torch.zeros((BATCH, 16,)),
            "source_2d": torch.zeros((BATCH, 16, 8 * 8)),
            "source_3d": torch.zeros((BATCH, 16, 8, 8)),
            "source_4d": torch.zeros((BATCH, 16, 8, 8, 8)),
        }
        try:
            net = ModularNetwork(
                self.source_nets,
                self.body,
                self.heads,
                self.output_space,
                dummy_gpu_preprocessor,
            )
            # The unpacking also checks that forward returns a two-item result.
            outputs, _ = net.forward(obs, {})
        except Exception as exc:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit still propagate, and report what actually went wrong.
            self.fail("Unexpected exception: {!r}".format(exc))