Ejemplo n.º 1
0
        def check_out(out):
            """Check a detection model output through ``self.assertExpected``.

            Asserts that ``out`` has exactly one element, maps every tensor in
            it to a compact representation and passes the mapped structure to
            ``self.assertExpected`` with ``prec=0.01`` and
            ``strip_suffix="_" + dev``.

            ``self``, ``name`` and ``dev`` are taken from the enclosing scope.

            Args:
                out: the model output to validate.
            """

            def summarize(tensor):
                # Cast first: the mean of an integral tensor cannot be computed.
                as_double = tensor.to(torch.double)
                return {
                    "mean": torch.mean(as_double),
                    "std": torch.std(as_double),
                }

            def pick_samples(tensor):
                # Tensors with at most `limit` elements are compared in full;
                # larger ones are flattened and sampled at a regular stride.
                total = tensor.numel()
                limit = 20
                if total > limit:
                    stride = total // limit
                    return tensor.flatten()[stride - 1::stride]
                return tensor

            self.assertEqual(len(out), 1)

            # maskrcnn_resnet_50_fpn is numerically unstable across platforms,
            # so for now its results are compared through mean and std only.
            # The mean values are small, hence the large precision value.
            if name == "maskrcnn_resnet50_fpn":
                reduce_fn = summarize
            else:
                reduce_fn = pick_samples

            reduced = map_nested_tensor_object(out, tensor_map_fn=reduce_fn)
            self.assertExpected(reduced, prec=0.01, strip_suffix="_" + dev)
    def _test_detection_model(self, name):
        """Check eager and scripted outputs of one ``models.detection`` model.

        Builds the model registered under ``name`` with 50 classes and no
        pretrained backbone, runs it in eval mode on a single random
        3x300x300 image, compares a compacted form of the output through
        ``self.assertExpected``, then scripts the model and checks that the
        scripted boxes, scores and labels equal the eager ones.

        Args:
            name: key of the model constructor in ``models.detection.__dict__``.
        """
        # Seed before building the model and drawing the input so that the
        # output is reproducible from run to run.
        set_rng_seed(0)
        model = models.detection.__dict__[name](num_classes=50,
                                                pretrained_backbone=False)
        model.eval()
        input_shape = (3, 300, 300)
        x = torch.rand(input_shape)
        model_input = [x]
        out = model(model_input)
        # The model must not replace the tensors in the caller's input list.
        self.assertIs(model_input[0], x)
        self.assertEqual(len(out), 1)

        # Verify the keys before anything below indexes them, so a missing key
        # is reported by name instead of surfacing later as a KeyError.
        for key in ("boxes", "scores", "labels"):
            self.assertIn(key, out[0])

        def subsample_tensor(tensor):
            # Tensors with at most 20 elements are compared in full; larger
            # ones are flattened and sampled at a regular stride.
            num_elems = tensor.numel()
            num_samples = 20
            if num_elems <= num_samples:
                return tensor

            flat_tensor = tensor.flatten()
            ith_index = num_elems // num_samples
            return flat_tensor[ith_index - 1::ith_index]

        def compute_mean_std(tensor):
            # can't compute mean of integral tensor
            tensor = tensor.to(torch.double)
            mean = torch.mean(tensor)
            std = torch.std(tensor)
            return {"mean": mean, "std": std}

        # maskrcnn_resnet_50_fpn numerically unstable across platforms, so for now
        # compare results with mean and std
        if name == "maskrcnn_resnet50_fpn":
            test_value = map_nested_tensor_object(
                out, tensor_map_fn=compute_mean_std)
            # mean values are small, use large prec
            self.assertExpected(test_value, prec=.01)
        else:
            self.assertExpected(map_nested_tensor_object(
                out, tensor_map_fn=subsample_tensor),
                                prec=0.01)

        scripted_model = torch.jit.script(model)
        scripted_model.eval()
        # Element 1 of the scripted result is what gets compared with the
        # eager output (presumably element 0 holds the losses - TODO confirm).
        scripted_out = scripted_model(model_input)[1]
        self.assertEqual(scripted_out[0]["boxes"], out[0]["boxes"])
        self.assertEqual(scripted_out[0]["scores"], out[0]["scores"])
        # labels currently float in script: need to investigate (though same result)
        self.assertEqual(scripted_out[0]["labels"].to(dtype=torch.long),
                         out[0]["labels"])
        # don't check script because we are compiling it here:
        # TODO: refactor tests
        # self.check_script(model, name)
        self.checkModule(model, name, ([x], ))
Ejemplo n.º 3
0
        def check_out(out):
            """Validate a detection output, falling back to a scores-only check.

            Asserts that ``out`` has exactly one element, compacts every tensor
            in it and passes the result to ``self.assertExpected``. If that
            raises ``AssertionError``, the expected file is loaded and only the
            ``"scores"`` entry of the first element is compared.

            ``self`` and ``dev`` are taken from the enclosing scope.

            Args:
                out: the model output to validate.

            Returns:
                True when the full comparison passed, False when only the
                partial (scores) comparison was performed.
            """
            self.assertEqual(len(out), 1)

            def compute_mean_std(tensor):
                # Cast first: the mean of an integral tensor cannot be computed.
                as_double = tensor.to(torch.double)
                return {
                    "mean": torch.mean(as_double),
                    "std": torch.std(as_double),
                }

            def subsample_tensor(tensor):
                # Sample along the first dimension only; tensors with at most
                # `limit` entries there are returned untouched.
                count = tensor.size(0)
                limit = 20
                if count > limit:
                    stride = count // limit
                    return tensor[stride - 1::stride]
                return tensor

            def compact(tensor):
                # Product of every dimension after the first.
                per_sample = functools.reduce(operator.mul,
                                              tensor.size()[1:], 1)
                shrink = (compute_mean_std
                          if per_sample > 30 else subsample_tensor)
                return shrink(tensor)

            output = map_nested_tensor_object(out, tensor_map_fn=compact)
            prec = 0.01
            strip_suffix = f"_{dev}"
            try:
                # Asserting the entire output comes first: it is the best way
                # to validate results and it also handles the cases where a
                # new expected result needs to be created.
                self.assertExpected(output,
                                    prec=prec,
                                    strip_suffix=strip_suffix)
            except AssertionError:
                # Detection models are flaky because of the unstable sort in
                # NMS. When matching across all outputs fails, use the same
                # approach as NMSTester.test_nms_cuda to see whether duplicate
                # scores caused it.
                expected = torch.load(
                    self._get_expected_file(strip_suffix=strip_suffix))
                self.assertEqual(output[0]["scores"],
                                 expected[0]["scores"],
                                 prec=prec)

                # Note: Fmassa proposed turning off NMS by adapting the
                # threshold and then using the Hungarian algorithm as in DETR
                # to find the best match between output and expected boxes and
                # eliminate some of the flakiness. Worth exploring.
                return False  # Partial validation performed
            else:
                return True  # Full validation performed
Ejemplo n.º 4
0
    def _test_detection_model(self, name):
        """Check the eager output of one ``models.detection`` model.

        Builds the model registered under ``name`` with 50 classes and no
        pretrained backbone, passes it to ``self.check_script``, runs it in
        eval mode on a single random 3x300x300 image and compares a compacted
        form of the output through ``self.assertExpected``. Finally verifies
        that the output dict exposes ``boxes``, ``scores`` and ``labels``.

        Args:
            name: key of the model constructor in ``models.detection.__dict__``.
        """
        # Seed before building the model and drawing the input so that the
        # output is reproducible from run to run.
        set_rng_seed(0)
        model = models.detection.__dict__[name](num_classes=50,
                                                pretrained_backbone=False)
        self.check_script(model, name)
        model.eval()
        input_shape = (3, 300, 300)
        x = torch.rand(input_shape)
        model_input = [x]
        out = model(model_input)
        # The model must not replace the tensors in the caller's input list.
        self.assertIs(model_input[0], x)
        self.assertEqual(len(out), 1)

        def subsample_tensor(tensor):
            # Tensors with at most 20 elements are compared in full; larger
            # ones are flattened and sampled at a regular stride.
            num_elems = tensor.numel()
            num_samples = 20
            if num_elems <= num_samples:
                return tensor

            flat_tensor = tensor.flatten()
            ith_index = num_elems // num_samples
            return flat_tensor[ith_index - 1::ith_index]

        def compute_mean_std(tensor):
            # can't compute mean of integral tensor
            tensor = tensor.to(torch.double)
            mean = torch.mean(tensor)
            std = torch.std(tensor)
            return {"mean": mean, "std": std}

        # maskrcnn_resnet_50_fpn numerically unstable across platforms, so for now
        # compare results with mean and std
        if name == "maskrcnn_resnet50_fpn":
            test_value = map_nested_tensor_object(
                out, tensor_map_fn=compute_mean_std)
            # mean values are small, use large rtol
            self.assertExpected(test_value, rtol=.01, atol=.01)
        else:
            self.assertExpected(
                map_nested_tensor_object(out, tensor_map_fn=subsample_tensor))

        # assertIn names the missing key and shows the container on failure,
        # unlike assertTrue on a membership expression.
        for key in ("boxes", "scores", "labels"):
            self.assertIn(key, out[0])