コード例 #1
0
    def test_multi_output_with_unuse_model(self):
        """ Test multi-output model with Tuple Tensor as intermediate output and with one of tuple tensor not used

        The wrapped TupleOutputModel returns a 3-tuple; only the first and last
        elements are consumed (by conv1 and conv2), and their results are
        concatenated. The test checks the op/product counts, that no op or
        product name mentions 'Tuple', and that every expected product exists
        with a shape equal to its producer's output shape.
        """
        class MultiOutputWithUnuseModel(torch.nn.Module):
            """
            Model with Tuple of Tensors as output with one output tensor unused
            """
            def __init__(self):
                super(MultiOutputWithUnuseModel, self).__init__()
                self.layer = test_models.TupleOutputModel()
                self.conv1 = torch.nn.Conv2d(2, 4, kernel_size=3, padding=1)
                self.conv2 = torch.nn.Conv2d(6, 4, kernel_size=3, padding=1)

            def forward(self, *inputs):
                # The middle tuple element is deliberately discarded.
                x, _, z = self.layer(inputs[0])
                x1 = self.conv1(x)
                z1 = self.conv2(z)
                return torch.cat([x1, z1], 1)

        inp_data = torch.rand(1, 3, 8, 8)
        model = MultiOutputWithUnuseModel()
        conn_graph = ConnectedGraph(model, (inp_data, ))

        op_names = conn_graph.get_all_ops().keys()
        self.assertEqual(6, len(conn_graph.ordered_ops))
        self.assertEqual(
            5, len([op for op in op_names if 'convolution' in op]))
        self.assertEqual(0, len([op for op in op_names if 'Tuple' in op]))
        self.assertEqual('cat', conn_graph.ordered_ops[-1].type)

        products = conn_graph.get_all_products()
        product_names = products.keys()
        self.assertEqual(
            0,
            len([product for product in product_names if 'Tuple' in product]))

        expected_products = [
            # layer #1 to conv1,conv2
            'convolution_0_to_convolution_3',
            'convolution_2_to_convolution_4',

            # conv1,conv2 to cat
            'convolution_3_to_cat_5',
            'convolution_4_to_cat_5'
        ]

        # Check each expected product directly so a failure names the
        # missing product instead of only reporting a leftover count.
        for product_name in expected_products:
            self.assertIn(product_name, product_names)
            product = products[product_name]
            self.assertEqual(product.shape, product.producer.output_shape)
コード例 #2
0
 def test_multi_output_model(self):
     """Check ConnectedGraph on a model whose intermediate outputs are Tuple Tensors.

     Verifies the total ordered-op count, the number of convolution ops,
     that no op name or product name contains 'Tuple', and that the last
     ordered op is a 'cat'.
     """
     model = test_models.MultiOutputModel()
     sample_input = torch.rand(1, 3, 8, 8)
     graph = ConnectedGraph(model, (sample_input, ))

     self.assertEqual(7, len(graph.ordered_ops))

     op_names = graph.get_all_ops().keys()
     conv_op_names = [name for name in op_names if 'convolution' in name]
     self.assertEqual(6, len(conv_op_names))
     tuple_op_names = [name for name in op_names if 'Tuple' in name]
     self.assertEqual(0, len(tuple_op_names))

     tuple_product_names = [
         name for name in graph.get_all_products().keys() if 'Tuple' in name
     ]
     self.assertEqual(0, len(tuple_product_names))

     self.assertEqual('cat', graph.ordered_ops[-1].type)
コード例 #3
0
    def test_multi_output_with_shuffled_layers(self):
        """ Test a multiple layer multi-output model with intermediate Tuple Tensors shuffled

        Three ConfigurableTupleOutputModel layers are chained with their tuple
        outputs permuted between layers. layer1's first output (x1) feeds both
        layer2 and the final cat, so the graph is expected to contain a Split
        (Split_0) whose product fans out to multiple consumers.
        """
        class MultiOutputShuffledModel(torch.nn.Module):
            """
            Model with Tuple of Tensors as output shuffled between layers
            """
            def __init__(self):
                super(MultiOutputShuffledModel, self).__init__()
                self.layer1 = test_models.ConfigurableTupleOutputModel(
                    channels=(1, 2, 3))
                self.layer2 = test_models.ConfigurableTupleOutputModel(
                    channels=(2, 3, 1))
                self.layer3 = test_models.ConfigurableTupleOutputModel(
                    channels=(3, 1, 2))

            def forward(self, *inputs):
                x1, x2, x3 = self.layer1(inputs[0], inputs[1], inputs[2])
                # Tuple elements are permuted between layers on purpose.
                y2, y3, y1 = self.layer2(x2, x3, x1)
                z3, z1, z2 = self.layer3(y3, y1, y2)
                # x1 is reused here, giving it a second consumer.
                return torch.cat([z1, z2, z3, x1], 1)

        model = MultiOutputShuffledModel()
        inp_tensor_list = create_rand_tensors_given_shapes([(1, 1, 8, 8),
                                                            (1, 2, 8, 8),
                                                            (1, 3, 8, 8)])
        conn_graph = ConnectedGraph(model, inp_tensor_list)

        all_ops = conn_graph.get_all_ops()
        op_names = all_ops.keys()
        self.assertEqual(10, len(conn_graph.ordered_ops))
        self.assertEqual(
            9, len([op for op in op_names if 'convolution' in op]))
        self.assertEqual(0, len([op for op in op_names if 'Tuple' in op]))
        self.assertEqual('cat', conn_graph.ordered_ops[-1].type)

        products = conn_graph.get_all_products()
        product_names = products.keys()
        self.assertEqual(
            0,
            len([product for product in product_names if 'Tuple' in product]))

        expected_products = [
            # TODO fix order of products

            # layer #1 to layer #2
            'convolution_0__to__Split_0',
            'convolution_1_to_convolution_3',
            'convolution_2_to_convolution_4',

            # layer #2 to layer #3
            'convolution_3_to_convolution_8',
            'convolution_4_to_convolution_6',
            'convolution_5_to_convolution_7',

            # layer #3, layer#1.conv1 to cat
            'convolution_6_to_cat_9',
            'convolution_7_to_cat_9',
            'convolution_8_to_cat_9'
        ]

        # Check each expected product directly so a failure names the
        # missing product instead of only reporting a leftover count.
        for product_name in expected_products:
            self.assertIn(product_name, product_names)
            product = products[product_name]
            self.assertEqual(product.shape, product.producer.output_shape)

        # The Split's output must reach both layer2's third conv and the cat.
        split_product = products['Split_0__to__multiple_ops']
        self.assertIn(all_ops['convolution_5'], split_product.consumers)
        self.assertIn(all_ops['cat_9'], split_product.consumers)