def test_flatten_basic(self):
    """Round-trip a nested list/tuple/dict through ``flatten_to_tuple``.

    Checks the flat leaf order, that the schema rebuilds an object equal to the
    input, and that re-flattening the rebuilt object yields an equal schema
    (which exercises the schema's ``__eq__``).
    """
    nested = [3, ([5, 6], {"name": [7, 9], "name2": 3})]
    flat, schema = flatten_to_tuple(nested)
    self.assertEqual(flat, (3, 5, 6, 7, 9, 3))

    rebuilt = schema(flat)
    self.assertEqual(rebuilt, nested)

    # Schema equality check: flattening the rebuilt object must describe the
    # same structure as the original schema.
    self.assertEqual(schema, flatten_to_tuple(rebuilt)[1])
def test_flatten_instances_boxes(self):
    """Flatten an ``Instances`` with a ``Boxes`` field and rebuild it from the schema.

    After the plain-Python prefix ``(3, 5, 6)``, the flattened leaves must be the
    very same objects (checked with ``assertIs``) as the box tensor, the mask
    tensor and the image size of the ``Instances``. The object rebuilt by the
    schema must then match the original instances.
    """
    inst = Instances(
        torch.tensor([5, 8]), pred_masks=torch.tensor([3]), pred_boxes=Boxes(torch.ones((1, 4)))
    )
    obj = [3, ([5, 6], inst)]
    res, schema = flatten_to_tuple(obj)
    self.assertEqual(res[:3], (3, 5, 6))

    expected_leaves = (inst.pred_boxes.tensor, inst.pred_masks, inst.image_size)
    # zip() stops at the shorter input, so pin the total length first; without
    # this, a missing or extra flattened leaf would slip past the identity checks.
    self.assertEqual(len(res), 3 + len(expected_leaves))
    for r, expected in zip(res[3:], expected_leaves):
        self.assertIs(r, expected)

    new_obj = schema(res)
    assert_instances_allclose(new_obj[1][1], inst, rtol=0.0, size_as_tensor=True)
def forward(self, *input_args):
    """Flatten the inputs, run the traced model, and restructure its outputs.

    ``input_args`` is flattened with ``flatten_to_tuple``; the input schema it
    also returns is not needed here. The flat values are passed positionally to
    ``self.traced_model``. Its flat result is turned back into a structured
    object by calling ``self.outputs_schema`` on it.
    """
    flat_inputs = flatten_to_tuple(input_args)[0]
    return self.outputs_schema(self.traced_model(*flat_inputs))
def forward(self, image):
    """Run ``inference_func`` on ``self[0]`` and return its outputs flattened.

    ``inference_func`` is not defined in this method; presumably it is captured
    from an enclosing scope (TODO confirm). The schema describing the output
    structure is stored on ``self.schema`` the first time this runs. Later calls
    leave the stored schema untouched.
    """
    wrapped = self[0]
    flat_outputs, out_schema = flatten_to_tuple(inference_func(wrapped, image))
    # Keep only the first schema seen so callers can rebuild outputs later.
    if not hasattr(self, "schema"):
        self.schema = out_schema
    return flat_outputs