def test_multi_output_with_unuse_model(self):
    """
    Build a connected graph for a model whose first layer returns a tuple of three tensors of which
    only the first and the third are consumed, then check the resulting ops and products.
    """

    class MultiOutputWithUnuseModel(torch.nn.Module):
        """ Routes two of the three tuple outputs through their own conv and concatenates the results """

        def __init__(self):
            super().__init__()
            self.layer = test_models.TupleOutputModel()
            self.conv1 = torch.nn.Conv2d(2, 4, kernel_size=3, padding=1)
            self.conv2 = torch.nn.Conv2d(6, 4, kernel_size=3, padding=1)

        def forward(self, *inputs):
            # The middle element of the tuple is intentionally dropped.
            first, _, third = self.layer(inputs[0])
            first_out = self.conv1(first)
            third_out = self.conv2(third)
            return torch.cat([first_out, third_out], 1)

    inp_data = torch.rand(1, 3, 8, 8)
    model = MultiOutputWithUnuseModel()
    graph = ConnectedGraph(model, (inp_data, ))

    self.assertEqual(6, len(graph.ordered_ops))

    # Op names are matched by substring: five convolutions expected, no Tuple ops at all.
    ops = graph.get_all_ops()
    conv_op_count = sum(1 for name in ops if 'convolution' in name)
    self.assertEqual(5, conv_op_count)
    tuple_op_count = sum(1 for name in ops if 'Tuple' in name)
    self.assertEqual(0, tuple_op_count)
    self.assertEqual('cat', graph.ordered_ops[-1].type)

    # No product name may mention a Tuple either.
    products = graph.get_all_products()
    tuple_product_count = sum(1 for name in products if 'Tuple' in name)
    self.assertEqual(0, tuple_product_count)

    # Names still waiting to be seen; each one is struck off once its shape has been verified.
    pending = [
        # layer #1 to conv1,conv2
        'convolution_0_to_convolution_3',
        'convolution_2_to_convolution_4',

        # conv1,conv2 to cat
        'convolution_3_to_cat_5',
        'convolution_4_to_cat_5',
    ]
    for name in products:
        if name not in pending:
            continue
        product = products[name]
        self.assertEqual(product.shape, product.producer.output_shape)
        pending.remove(name)

    # Every expected product must have been found in the graph.
    self.assertEqual(0, len(pending))
def test_multi_output_model(self):
    """
    Build a connected graph for test_models.MultiOutputModel and check that it holds seven ops,
    six of them convolutions, ends in a cat, and has no Tuple ops or products.
    """
    model = test_models.MultiOutputModel()
    inp_data = torch.rand(1, 3, 8, 8)
    graph = ConnectedGraph(model, (inp_data, ))

    self.assertEqual(7, len(graph.ordered_ops))

    # Op names are matched by substring.
    ops = graph.get_all_ops()
    conv_op_count = sum(1 for name in ops if 'convolution' in name)
    self.assertEqual(6, conv_op_count)
    tuple_op_count = sum(1 for name in ops if 'Tuple' in name)
    self.assertEqual(0, tuple_op_count)

    # Product names must not mention a Tuple either.
    products = graph.get_all_products()
    tuple_product_count = sum(1 for name in products if 'Tuple' in name)
    self.assertEqual(0, tuple_product_count)

    self.assertEqual('cat', graph.ordered_ops[-1].type)
def test_multi_output_with_shuffled_layers(self):
    """
    Build a connected graph for three stacked tuple-output layers whose outputs are passed on in a
    shuffled order, then check the resulting ops, products and the split product's consumers.
    """

    class MultiOutputShuffledModel(torch.nn.Module):
        """ Three tuple-output layers chained together with the tuple elements reordered in between """

        def __init__(self):
            super().__init__()
            self.layer1 = test_models.ConfigurableTupleOutputModel(
                channels=(1, 2, 3))
            self.layer2 = test_models.ConfigurableTupleOutputModel(
                channels=(2, 3, 1))
            self.layer3 = test_models.ConfigurableTupleOutputModel(
                channels=(3, 1, 2))

        def forward(self, *inputs):
            a1, a2, a3 = self.layer1(inputs[0], inputs[1], inputs[2])
            b2, b3, b1 = self.layer2(a2, a3, a1)
            c3, c1, c2 = self.layer3(b3, b1, b2)
            # a1 is used a second time here, in addition to feeding layer2 above.
            return torch.cat([c1, c2, c3, a1], 1)

    model = MultiOutputShuffledModel()
    input_shapes = [(1, 1, 8, 8), (1, 2, 8, 8), (1, 3, 8, 8)]
    inp_tensor_list = create_rand_tensors_given_shapes(input_shapes)
    graph = ConnectedGraph(model, inp_tensor_list)

    self.assertEqual(10, len(graph.ordered_ops))

    # Op names are matched by substring: nine convolutions expected, no Tuple ops at all.
    ops = graph.get_all_ops()
    conv_op_count = sum(1 for name in ops if 'convolution' in name)
    self.assertEqual(9, conv_op_count)
    tuple_op_count = sum(1 for name in ops if 'Tuple' in name)
    self.assertEqual(0, tuple_op_count)
    self.assertEqual('cat', graph.ordered_ops[-1].type)

    # No product name may mention a Tuple either.
    products = graph.get_all_products()
    tuple_product_count = sum(1 for name in products if 'Tuple' in name)
    self.assertEqual(0, tuple_product_count)

    # Names still waiting to be seen; each one is struck off once its shape has been verified.
    pending = [
        # TODO fix order of products

        # layer #1 to layer #2
        'convolution_0__to__Split_0',
        'convolution_1_to_convolution_3',
        'convolution_2_to_convolution_4',

        # layer #2 to layer #3
        'convolution_3_to_convolution_8',
        'convolution_4_to_convolution_6',
        'convolution_5_to_convolution_7',

        # layer #3, layer#1.conv1 to cat
        'convolution_6_to_cat_9',
        'convolution_7_to_cat_9',
        'convolution_8_to_cat_9',
    ]
    for name in products:
        if name not in pending:
            continue
        product = products[name]
        self.assertEqual(product.shape, product.producer.output_shape)
        pending.remove(name)

    # Every expected product must have been found in the graph.
    self.assertEqual(0, len(pending))

    # The split product is expected to fan out to both convolution_5 and the final cat.
    split_consumers = products['Split_0__to__multiple_ops'].consumers
    self.assertTrue(ops['convolution_5'] in split_consumers)
    self.assertTrue(ops['cat_9'] in split_consumers)