Example #1
0
    def test_torchscripting_using_trace(self):
        """
        Run the shared hook check against a TorchscriptHook built with its
        default arguments (the trace variant of save_torchscript).
        """
        output_dir = f"{self.base_dir}/torchscript_end_test/"
        task_config = get_test_task_config()

        # only the output folder is passed, so the hook keeps its default
        # (trace) export setting
        self.execute_hook(task_config, output_dir, TorchscriptHook(output_dir))
Example #2
0
    def test_torchscripting_using_script(self):
        """
        Test that the save_torchscript function works as expected with script

        ``ResNet.wrapper_cls`` is overridden with ``None`` only while this test
        runs and is put back afterwards, so the override cannot leak into
        tests that execute later in the same process.
        """
        # function-scope import keeps this method self-contained
        from unittest import mock

        config = get_test_task_config()
        torchscript_folder = self.base_dir + "/torchscript_end_test/"

        # Setting wrapper_cls to None to make ResNet model torchscriptable.
        # patch.object undoes the override on exit, including when the hook
        # run raises; create=True keeps the plain-assignment behaviour in case
        # the attribute is not defined on the class beforehand.
        with mock.patch.object(ResNet, "wrapper_cls", None, create=True):
            # create a torchscript hook using script
            torchscript_hook = TorchscriptHook(torchscript_folder, use_trace=False)
            self.execute_hook(config, torchscript_folder, torchscript_hook)
Example #3
0
    def test_torchscripting(self):
        """
        Drive TorchscriptHook's on_start / on_end callbacks on a prepared task,
        load the torchscript file from the hook's folder, and check that it
        produces the same output as the task's model for a random input.
        """
        task = build_task(get_test_task_config())
        task.prepare()

        export_dir = f"{self.base_dir}/torchscript_end_test/"
        torchscript_hook = TorchscriptHook(export_dir)

        # make the output directory, then check that on_start runs
        os.mkdir(export_dir)
        torchscript_hook.on_start(task)

        # set task.train before triggering the end-of-run callback
        task.train = True
        torchscript_hook.on_end(task)

        # read the torchscript file back from the hook's folder
        scripted_model = torch.jit.load(
            f"{torchscript_hook.torchscript_folder}/{TORCHSCRIPT_FILE}"
        )

        # feed one random sample through both models and compare the results
        reference_model = task.model
        with torch.no_grad():
            batch_size = 1
            sample_shape = (batch_size,) + reference_model.input_shape
            sample = torch.randn(sample_shape, dtype=torch.float)
            if torch.cuda.is_available():
                sample = sample.cuda()
            expected_out = reference_model(sample)
            scripted_out = scripted_model(sample)
            self.assertTrue(torch.allclose(expected_out, scripted_out))