Example #1
0
def _check_jit_scriptable(nn_module, args, unwrapper=None, eager_out=None):
    """Verify that scripting ``nn_module`` preserves its outputs and survives serialization.

    The module is compiled with ``torch.jit.script`` and its output is compared
    with the eager output (``atol = rtol = 1e-4``). The scripted module is then
    saved to a temporary file, loaded back, and its output is compared with the
    scripted output (``atol = rtol = 3e-4``). Every forward call runs under
    ``torch.no_grad()`` and ``freeze_rng_state()``.

    Args:
        nn_module: eager module to script.
        args: positional arguments forwarded to every forward call.
        unwrapper: optional callable applied to the outputs of the scripted and
            re-imported modules before comparison (never to the eager output).
        eager_out: precomputed eager reference; computed here when ``None``.
    """

    def _round_trip(scripted_module):
        # Save to disk and reload; the module is fully loaded before the
        # temporary directory is removed.
        with TemporaryDirectory() as tmp_dir:
            target = os.path.join(tmp_dir, "script.pt")
            scripted_module.save(target)
            reloaded = torch.jit.load(target)
        return reloaded

    def _forward(module, apply_unwrapper):
        # One inference-only call with the RNG state frozen across the call.
        with torch.no_grad(), freeze_rng_state():
            result = module(*args)
            if apply_unwrapper and unwrapper:
                result = unwrapper(result)
        return result

    scripted = torch.jit.script(nn_module)

    reference = eager_out
    if reference is None:
        reference = _forward(nn_module, apply_unwrapper=False)

    scripted_result = _forward(scripted, apply_unwrapper=True)
    torch.testing.assert_close(reference, scripted_result, atol=1e-4, rtol=1e-4)

    reloaded_result = _forward(_round_trip(scripted), apply_unwrapper=True)
    torch.testing.assert_close(scripted_result, reloaded_result, atol=3e-4, rtol=3e-4)
Example #2
0
    def test_noncontig(self, test_case, module, input):
        """Check that ``module`` gives identical results for contiguous and
        non-contiguous inputs and grad outputs.

        A reference forward/backward pass is run on ``input``. Every
        combination of contiguous / non-contiguous input and grad output is
        then re-run, and its output, backward result and parameter values from
        ``_get_parameters(module)[1]`` are compared with the reference.
        Returns early (nothing is checked) when ``input`` is a 0-dim tensor
        or is a container holding one.

        Args:
            test_case: object providing ``_forward``, ``_backward``,
                ``_zero_grad_parameters``, ``_zero_grad_input``,
                ``_get_parameters`` and ``assertEqual``.
            module: module under test.
            input: a tensor, or an iterable whose tensor elements are checked.
        """
        # Scalars (0-dim tensors) can't be made non-contiguous, so skip them.
        # A tensor input is inspected directly, and only a non-tensor container
        # is scanned element-wise. Iterating a 1-D tensor would yield 0-dim
        # slices and wrongly skip the whole test.
        if isinstance(input, torch.Tensor):
            if input.dim() == 0:
                return
        elif any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)):
            return

        # Reference pass on the original data.
        test_case._zero_grad_parameters(module)
        test_case._zero_grad_input(input)
        with freeze_rng_state():
            output = test_case._forward(module, input)
            grad_output = output.new(output.shape).normal_()
            # NOTE(review): presumably cloned so later forward passes cannot
            # modify the reference output in place — confirm.
            output = output.clone()
            # Deep copies keep the reference results (presumably input and
            # parameter gradients) intact across the zero-grad calls below.
            d_input = deepcopy(
                test_case._backward(module, input, output, grad_output))
            d_param = deepcopy(test_case._get_parameters(module)[1])

        nc_input = self.noncontiguize(input)
        nc_grad_output = self.noncontiguize(grad_output)
        for contig_i, contig_g in product((True, False), repeat=2):
            i = input if contig_i else nc_input
            go = grad_output if contig_g else nc_grad_output
            test_case._zero_grad_parameters(module)
            test_case._zero_grad_input(i)
            with freeze_rng_state():
                out = test_case._forward(module, i)
                grad = test_case._backward(module, i, out, go)

                test_case.assertEqual(out, output)
                # NOTE(review): the positional 1e-4 is interpreted according to
                # the assertEqual signature of the PyTorch version in use (a
                # tolerance in older releases, ``msg`` in newer ones) — confirm.
                test_case.assertEqual(grad, d_input, 1e-4)
                test_case.assertEqual(
                    test_case._get_parameters(module)[1], d_param)
Example #3
0
    def assert_export_import_module(m, args):
        """Assert that ``m`` gives the same results before and after a save/load round trip.

        ``m`` is written with ``torch.jit.save`` into an in-memory buffer and
        read back with ``torch.jit.load``. Both modules are called on ``args``
        under ``freeze_rng_state()``, and the results are compared with
        ``atol = rtol = 3e-4``. If the whole-result comparison raises
        ``ValueError``, the results are compared pairwise instead, skipping
        entries of the original result that are ``None``.
        """
        def _reload(module):
            # Round-trip through memory rather than the filesystem.
            stream = io.BytesIO()
            torch.jit.save(module, stream)
            stream.seek(0)
            return torch.jit.load(stream)

        def _call(module):
            with freeze_rng_state():
                return module(*args)

        restored = _reload(m)
        expected = _call(m)
        actual = _call(restored)

        tolerance = 3e-4
        try:
            torch.testing.assert_close(expected, actual, atol=tolerance, rtol=tolerance)
        except ValueError:
            # Models returning named tuples may contain None fields, which
            # assert_close can't handle; compare the remaining fields one by one.
            for lhs, rhs in zip(expected, actual):
                if lhs is None:
                    continue
                torch.testing.assert_close(lhs, rhs, atol=tolerance, rtol=tolerance)
Example #4
0
def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
    """Check that a nn.Module's results in TorchScript match eager and that it can be exported.

    The check is opt-in. When ``PYTORCH_TEST_WITH_SLOW`` is not ``"1"`` or
    ``skip`` is truthy, a ``RuntimeWarning`` is issued and ``None`` is
    returned without running anything.

    Otherwise the module is scripted, and its output (passed through
    ``unwrapper`` when given) is compared with the eager output
    (``atol = rtol = 1e-4``). The scripted module is then saved to and loaded
    from an in-memory buffer, and its outputs before and after that round trip
    are compared (``atol = rtol = 3e-4``). If that last comparison raises
    ``ValueError``, the outputs are compared pairwise, skipping ``None`` entries.
    """
    slow_tests_enabled = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'
    if skip or not slow_tests_enabled:
        # Make the skipped check visible to whoever runs the test suite.
        warnings.warn(
            "The check_jit_scriptable test for {} was skipped. "
            "This test checks if the module's results in TorchScript "
            "match eager and that it can be exported. To run these "
            "tests make sure you set the environment variable "
            "PYTORCH_TEST_WITH_SLOW=1 and that the test is not "
            "manually skipped.".format(nn_module.__class__.__name__),
            RuntimeWarning)
        return None

    scripted = torch.jit.script(nn_module)

    with freeze_rng_state():
        eager_result = nn_module(*args)
    with freeze_rng_state():
        script_result = scripted(*args)
        if unwrapper:
            script_result = unwrapper(script_result)
    torch.testing.assert_close(eager_result, script_result, atol=1e-4, rtol=1e-4)

    # Round-trip the scripted module through an in-memory buffer.
    buf = io.BytesIO()
    torch.jit.save(scripted, buf)
    buf.seek(0)
    reloaded = torch.jit.load(buf)

    with freeze_rng_state():
        before = scripted(*args)
    with freeze_rng_state():
        after = reloaded(*args)

    tol = 3e-4
    try:
        torch.testing.assert_close(before, after, atol=tol, rtol=tol)
    except ValueError:
        # Named-tuple outputs may contain None, which assert_close can't
        # handle; compare the non-None fields pairwise.
        for lhs, rhs in zip(before, after):
            if lhs is not None:
                torch.testing.assert_close(lhs, rhs, atol=tol, rtol=tol)
def do_test(model, bailout, print_diff):
    """Exercise a saved TorchScript model under the profiling executor and log mismatches.

    Logging is configured with ``logging.basicConfig`` to write DEBUG records
    to ``jit_<model with '/' replaced by '-'><unix time>.log``. Inside
    ``enable_profiling_mode()``, the model is loaded and called (with no
    arguments) three times under ``freeze_rng_state()``. The runs are logged
    as "profiled", "profiled2" and "optimized", and their outputs are compared
    with ``test_allclose``. Then either the single bailout ``bailout`` (when
    truthy) or every bailout ``0 .. num_bailouts - 1`` is requested in turn.
    The model is re-run after each request and compared with the optimized
    output.

    Args:
        model: path of the serialized TorchScript model.
        bailout: index of one bailout to trigger, or a falsy value to trigger
            all of them (so index 0 alone cannot be selected).
        print_diff: when truthy, the mismatching outputs are logged as well.
    """
    log_name = 'jit_' + model.replace('/', '-') + str(int(time.time())) + '.log'
    logging.basicConfig(filename=log_name, filemode='w', level=logging.DEBUG)

    def _run_frozen(module):
        # Call the module with the RNG state restored afterwards.
        with freeze_rng_state():
            return module()

    def _report(lhs, rhs, lhs_fmt, rhs_fmt, message, *message_args):
        # Log an error (and optionally both outputs) when lhs and rhs differ.
        if test_allclose(lhs, rhs):
            return
        logging.error(message, *message_args)
        if print_diff:
            logging.error(lhs_fmt, str(lhs))
            logging.error(rhs_fmt, str(rhs))

    with enable_profiling_mode():
        logging.info("loading profiled %s", model)
        jm = torch.jit.load(model)

        logging.info("running profiled %s", model)
        po = _run_frozen(jm)
        logging.info("running profiled2 %s", model)
        po2 = _run_frozen(jm)
        _report(po, po2, "po : %s", "po2 : %s",
                "profiled and profiled2 outputs aren't equal")

        logging.info("running optimized %s", model)
        jo = _run_frozen(jm)
        _report(po, jo, "po : %s", "jo : %s",
                "profiled and optimized outputs aren't equal")

        plan = get_plan(jm)
        num_bailouts = plan.code.num_bailouts()
        logging.info("number of bailouts: %d", num_bailouts)

        targets = [bailout] if bailout else range(num_bailouts)
        for idx in targets:
            logging.info("triggering bailout %d ", idx)
            plan.code.request_bailout(idx)
            bo = _run_frozen(jm)
            _report(bo, jo, "bo : %s", "jo : %s",
                    "bailout %d and optimized outputs aren't equal", idx)
Example #6
0
def _check_jit_scriptable(nn_module,
                          args,
                          unwrapper=None,
                          skip=False,
                          eager_out=None):
    """Check that a nn.Module's results in TorchScript match eager and that it can be exported.

    The check only runs when the environment variable ``PYTORCH_TEST_WITH_SLOW``
    is ``"1"`` and ``skip`` is falsy; otherwise a ``RuntimeWarning`` is issued
    and ``None`` is returned.

    Args:
        nn_module: eager module to script with ``torch.jit.script``.
        args: positional arguments passed to every forward call.
        unwrapper: optional callable applied to the outputs of the scripted and
            re-imported modules before comparison (not to the eager output).
        skip: force the check to be skipped.
        eager_out: precomputed eager output; computed here when ``None``.
    """
    def get_export_import_copy(m):
        """Save and load a TorchScript model"""
        # The model is fully loaded before the temporary directory is removed.
        with TemporaryDirectory() as dir:
            path = os.path.join(dir, "script.pt")
            m.save(path)
            imported = torch.jit.load(path)
        return imported

    TEST_WITH_SLOW = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1"
    if not TEST_WITH_SLOW or skip:
        # TorchScript is not enabled, skip these tests
        msg = (
            f"The check_jit_scriptable test for {nn_module.__class__.__name__} was skipped. "
            "This test checks if the module's results in TorchScript "
            "match eager and that it can be exported. To run these "
            "tests make sure you set the environment variable "
            "PYTORCH_TEST_WITH_SLOW=1 and that the test is not "
            "manually skipped.")
        warnings.warn(msg, RuntimeWarning)
        return None

    sm = torch.jit.script(nn_module)

    if eager_out is None:
        # The eager reference is needed whether or not an unwrapper is given.
        # Gating it on ``unwrapper`` left eager_out as None, so the comparison
        # below always failed for modules without an unwrapper.
        with torch.no_grad(), freeze_rng_state():
            eager_out = nn_module(*args)

    with torch.no_grad(), freeze_rng_state():
        script_out = sm(*args)
        if unwrapper:
            script_out = unwrapper(script_out)

    torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4)

    m_import = get_export_import_copy(sm)
    with torch.no_grad(), freeze_rng_state():
        imported_script_out = m_import(*args)
        if unwrapper:
            imported_script_out = unwrapper(imported_script_out)

    torch.testing.assert_close(script_out,
                               imported_script_out,
                               atol=3e-4,
                               rtol=3e-4)
Example #7
0
def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
    """Check that a nn.Module's results in TorchScript match eager and that it can be exported.

    Skipped (``RuntimeWarning`` issued, ``None`` returned) unless
    ``PYTORCH_TEST_WITH_SLOW`` is ``"1"`` and ``skip`` is falsy. When run, the
    scripted output (optionally passed through ``unwrapper``) is compared with
    the eager output using ``atol = rtol = 1e-4``. The scripted module's
    outputs before and after an in-memory ``torch.jit.save`` /
    ``torch.jit.load`` round trip are compared using ``atol = rtol = 3e-4``.
    """
    run_slow = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1"
    if skip or not run_slow:
        warnings.warn(
            f"The check_jit_scriptable test for {nn_module.__class__.__name__} was skipped. "
            "This test checks if the module's results in TorchScript "
            "match eager and that it can be exported. To run these "
            "tests make sure you set the environment variable "
            "PYTORCH_TEST_WITH_SLOW=1 and that the test is not "
            "manually skipped.",
            RuntimeWarning,
        )
        return None

    def _frozen_call(module):
        # Evaluate with the RNG state restored afterwards.
        with freeze_rng_state():
            return module(*args)

    scripted = torch.jit.script(nn_module)

    reference = _frozen_call(nn_module)
    with freeze_rng_state():
        candidate = scripted(*args)
        if unwrapper:
            candidate = unwrapper(candidate)
    torch.testing.assert_close(reference, candidate, atol=1e-4, rtol=1e-4)

    # Serialize into memory and load back.
    blob = io.BytesIO()
    torch.jit.save(scripted, blob)
    blob.seek(0)
    restored = torch.jit.load(blob)

    original_out = _frozen_call(scripted)
    restored_out = _frozen_call(restored)
    tolerance = 3e-4
    torch.testing.assert_close(original_out, restored_out, atol=tolerance, rtol=tolerance)
Example #8
0
    def checkModule(self, nn_module, args):
        """Script ``nn_module`` and check it against eager mode.

        The eager and scripted modules are each called on ``args`` under
        ``freeze_rng_state()``, and their outputs are compared with
        ``self.assertEqual``. The scripted module is then passed to
        ``self.assertExportImportModule``.

        Returns:
            The scripted module.
        """
        scripted = torch.jit.script(nn_module)

        outputs = []
        for fn in (nn_module, scripted):
            with freeze_rng_state():
                outputs.append(fn(*args))
        eager_out, script_out = outputs

        self.assertEqual(eager_out, script_out)
        self.assertExportImportModule(scripted, args)
        return scripted
Example #9
0
    def assert_export_import_module(m, args):
        """Assert that ``m`` gives the same results after a save/load round trip.

        ``m`` is saved with ``m.save`` into a temporary directory and loaded
        back with ``torch.jit.load``. Both modules are called on ``args`` under
        ``torch.no_grad()`` and ``freeze_rng_state()``, and the results are
        compared with ``atol = rtol = 3e-4``.
        """
        with TemporaryDirectory() as tmp:
            checkpoint = os.path.join(tmp, "script.pt")
            m.save(checkpoint)
            restored = torch.jit.load(checkpoint)

        outputs = []
        for module in (m, restored):
            with torch.no_grad(), freeze_rng_state():
                outputs.append(module(*args))

        tol = 3e-4
        torch.testing.assert_close(outputs[0], outputs[1], atol=tol, rtol=tol)
Example #10
0
    def assert_export_import_module(m, args):
        """Assert that ``m`` gives the same results after an in-memory save/load.

        ``m`` is written with ``torch.jit.save`` to a ``BytesIO`` buffer and
        read back with ``torch.jit.load``. Both modules are called on ``args``
        under ``freeze_rng_state()``, and the results are compared with
        ``atol = rtol = 3e-4``.
        """
        buf = io.BytesIO()
        torch.jit.save(m, buf)
        buf.seek(0)
        copy = torch.jit.load(buf)

        def _call(module):
            with freeze_rng_state():
                return module(*args)

        before = _call(m)
        after = _call(copy)
        tolerance = 3e-4
        torch.testing.assert_close(before, after, atol=tolerance, rtol=tolerance)
def do_legacy_test(model, print_diff):
    """Run a saved TorchScript model twice with the legacy executor and log mismatches.

    Disables the profiling executor and profiling mode through the
    ``torch._C`` setters; this function does not restore them. Logging is
    configured to write DEBUG records to
    ``jit_<model with '/' replaced by '-'><unix time>.log``. The model is
    loaded with ``torch.jit.load``, put in eval mode, and called (with no
    arguments) twice under ``freeze_rng_state()``. If ``test_allclose``
    reports the two outputs differ, an error is logged, and with
    ``print_diff`` both outputs are logged too.

    Args:
        model: path of the serialized TorchScript model.
        print_diff: when truthy, mismatching outputs are logged as well.
    """
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)
    logging.basicConfig(filename='jit_' + model.replace('/', '-') +
                        str(int(time.time())) + '.log',
                        filemode='w',
                        level=logging.DEBUG)
    logging.info("loading %s", model)
    jm = torch.jit.load(model)
    logging.info("evaling %s", model)
    jm.eval()
    logging.info("running legacy %s", model)
    with freeze_rng_state():
        po = jm()
    # Second run: label it "legacy2" to match the mismatch message below
    # (the original log line duplicated the first run's label).
    logging.info("running legacy2 %s", model)
    with freeze_rng_state():
        po2 = jm()
    if not test_allclose(po, po2):
        logging.error("legacy and legacy2 outputs aren't equal")
        if print_diff:
            logging.error("po : %s", str(po))
            logging.error("po2 : %s", str(po2))
Example #12
0
 def runAndSaveRNG(self, func, inputs, kwargs=None):
     """Call ``func(*inputs, **kwargs)`` inside ``freeze_rng_state()`` and return the result.

     A falsy ``kwargs`` (``None`` or an empty mapping) is treated as no keyword
     arguments.
     """
     call_kwargs = {} if not kwargs else kwargs
     with freeze_rng_state():
         return func(*inputs, **call_kwargs)