Ejemplo n.º 1
0
    def _test_output(self,
                     model,
                     x,
                     kind_in_graph=None,
                     kind_not_in_graph=None):
        """Check that scripted and traced copies of ``model`` match eager mode.

        ``model`` is run on ``x`` in eager mode, then through
        ``torch.jit.script`` and ``torch.jit.trace`` twice: once after
        ``core.disable_jit_opt()`` and once after ``core.enable_jit_opt()``.
        Every JIT result must equal the eager result.  The graphs obtained
        after ``core.enable_jit_opt()`` are then inspected for node kinds.

        Note that this helper calls ``core.disable_mix_bf16_fp32()`` on entry
        and leaves ``core.enable_jit_opt()`` in effect on return; neither
        setting is restored.

        Args:
            model: Module under test; moved to the module-level ``device``
                and switched to eval mode.
            x: Example input; moved to ``device`` and used both for tracing
                and for every forward pass.
            kind_in_graph: If not None, a node kind that must be present in
                both the scripted and the traced graph.
            kind_not_in_graph: If not None, a node kind that must be absent
                from both the scripted and the traced graph.
        """
        core.disable_jit_opt()
        core.disable_mix_bf16_fp32()

        model = model.to(device).eval()
        x = x.to(device)
        with torch.no_grad():
            result = model(x)

        # Baseline JIT run, made after core.disable_jit_opt().
        script_model = torch.jit.script(model)
        script_model.eval()

        trace_model = torch.jit.trace(model, x)
        trace_model.eval()
        with torch.no_grad():
            sresult = script_model(x)
            tresult = trace_model(x)

        self.assertEqual(result, sresult)
        self.assertEqual(result, tresult)

        # Second JIT run, made after core.enable_jit_opt(); the models are
        # rebuilt so that they are compiled under the new setting.
        core.enable_jit_opt()
        script_fused_model = torch.jit.script(model)
        trace_fused_model = torch.jit.trace(model, x)
        with torch.no_grad():
            # conv relu fusion, conv sum fusion or conv sum relu fusion
            script_graph = script_fused_model.graph_for(x)
            fused_sresult = script_fused_model(x)

            trace_graph = trace_fused_model.graph_for(x)
            fused_tresult = trace_fused_model(x)

        self.assertEqual(result, fused_sresult)
        self.assertEqual(result, fused_tresult)

        graphs = (("script", script_graph), ("trace", trace_graph))

        # check if the fused node exists in the graph
        if kind_in_graph is not None:
            for label, graph in graphs:
                self.assertIn(
                    kind_in_graph,
                    [node.kind() for node in graph.nodes()],
                    "expected node kind missing from the {} graph".format(
                        label))

        # check if certain node does not exist in the graph
        if kind_not_in_graph is not None:
            for label, graph in graphs:
                self.assertNotIn(
                    kind_not_in_graph,
                    [node.kind() for node in graph.nodes()],
                    "unexpected node kind found in the {} graph".format(
                        label))
Ejemplo n.º 2
0
    def _test_output_int8(self,
                          model,
                          x,
                          kind_in_graph=None,
                          kind_not_in_graph=None,
                          prec=None):
        """Compare int8 results of the eager, scripted and traced ``model``.

        Each of the three variants is calibrated with ``int8_calibration``
        into ``configure.json`` and then run inside
        ``ipex.AutoMixPrecision`` with an ``ipex.AmpConf(torch.int8, ...)``
        built from that file.  The scripted and traced results must equal the
        eager result within ``prec``.  ``configure.json`` is removed before
        returning, including when calibration or inference raises.

        ``core.enable_auto_dnnl()`` and ``core.enable_jit_opt()`` are called
        on entry and are not reverted.

        Args:
            model: Module under test; moved to ``ipex.DEVICE`` and switched to
                eval mode.  The JIT variants are built from deep copies.
            x: Example input; moved to ``ipex.DEVICE``.  Separate clones are
                fed to the scripted and traced models.
            kind_in_graph: If not None, a node kind that must be present in
                both the scripted and the traced graph.
            kind_not_in_graph: If not None, a node kind that must be absent
                from both the scripted and the traced graph.
            prec: Precision forwarded to ``self.assertEqual``.
        """
        config_path = "configure.json"

        core.enable_auto_dnnl()
        core.enable_jit_opt()
        model = model.to(ipex.DEVICE).eval()
        x = x.to(ipex.DEVICE)
        x2 = x.clone()
        x3 = x.clone()

        script_fused_model = torch.jit.script(copy.deepcopy(model))
        trace_fused_model = torch.jit.trace(copy.deepcopy(model), x3)

        try:
            with torch.no_grad():
                # int8, native path
                int8_calibration(model, [x], config_path)
                int8_conf = ipex.AmpConf(torch.int8, config_path)
                with ipex.AutoMixPrecision(int8_conf):
                    result = model(x)
                # int8, jit script path
                script_graph = script_fused_model.graph_for(x2)
                int8_calibration(script_fused_model, [x2], config_path)
                int8_conf = ipex.AmpConf(torch.int8, config_path)
                with ipex.AutoMixPrecision(int8_conf):
                    fused_sresult = script_fused_model(x2)
                # int8, jit trace path
                trace_graph = trace_fused_model.graph_for(x3)
                int8_calibration(trace_fused_model, [x3], config_path)
                int8_conf = ipex.AmpConf(torch.int8, config_path)
                with ipex.AutoMixPrecision(int8_conf):
                    fused_tresult = trace_fused_model(x3)
        finally:
            # Always clean up the calibration file so a failing run does not
            # leave it behind.  The existence check keeps a missing file from
            # masking the original exception.
            if os.path.exists(config_path):
                os.remove(config_path)

        # Pass prec by keyword, as _test_output_bf16 does.
        self.assertEqual(fused_sresult, result, prec=prec)
        self.assertEqual(fused_tresult, result, prec=prec)

        graphs = (("script", script_graph), ("trace", trace_graph))

        # check if the fused node exists in the graph
        if kind_in_graph is not None:
            for label, graph in graphs:
                self.assertIn(
                    kind_in_graph,
                    [node.kind() for node in graph.nodes()],
                    "expected node kind missing from the {} graph".format(
                        label))

        # check if certain node does not exist in the graph
        if kind_not_in_graph is not None:
            for label, graph in graphs:
                self.assertNotIn(
                    kind_not_in_graph,
                    [node.kind() for node in graph.nodes()],
                    "unexpected node kind found in the {} graph".format(
                        label))
Ejemplo n.º 3
0
    def _test_output_bf16(self,
                          model,
                          x,
                          kind_in_graph=None,
                          kind_not_in_graph=None,
                          prec=None):
        """Compare bf16 results of the eager, scripted and traced ``model``.

        ``core.enable_mix_bf16_fp32()`` is switched on for the duration of the
        model preparation and the three forward passes, and is always switched
        off again afterwards (even on failure) so that later tests are not
        affected.  The scripted and traced results must equal the eager result
        within ``prec``.

        ``core.enable_auto_dnnl()`` and ``core.enable_jit_opt()`` are called
        on entry and are not reverted.

        Args:
            model: Module under test; moved to ``ipex.DEVICE`` and switched to
                eval mode.  The JIT variants are built from deep copies.
            x: Example input; moved to ``ipex.DEVICE``.  Separate clones are
                fed to the scripted and traced models.
            kind_in_graph: If not None, a node kind that must be present in
                both the scripted and the traced graph.
            kind_not_in_graph: If not None, a node kind that must be absent
                from both the scripted and the traced graph.
            prec: Precision forwarded to ``self.assertEqual``.
        """
        core.enable_auto_dnnl()
        core.enable_jit_opt()
        core.enable_mix_bf16_fp32()

        try:
            model = model.to(ipex.DEVICE).eval()
            x = x.to(ipex.DEVICE)
            x2 = x.clone()
            x3 = x.clone()

            script_fused_model = torch.jit.script(copy.deepcopy(model))
            trace_fused_model = torch.jit.trace(copy.deepcopy(model), x3)

            with torch.no_grad():
                # bf16, native path
                result = model(x)
                # bf16, jit script path
                script_graph = script_fused_model.graph_for(x2)
                fused_sresult = script_fused_model(x2)
                # bf16, jit trace path
                trace_graph = trace_fused_model.graph_for(x3)
                fused_tresult = trace_fused_model(x3)
        finally:
            # disable mix_bf16_fp32 when the calculation is done
            # to avoid affecting other scripts; done in ``finally`` so an
            # exception above cannot leave the mode enabled.
            core.disable_mix_bf16_fp32()

        self.assertEqual(fused_sresult, result, prec=prec)
        self.assertEqual(fused_tresult, result, prec=prec)

        graphs = (("script", script_graph), ("trace", trace_graph))

        # check if the fused node exists in the graph
        if kind_in_graph is not None:
            for label, graph in graphs:
                self.assertIn(
                    kind_in_graph,
                    [node.kind() for node in graph.nodes()],
                    "expected node kind missing from the {} graph".format(
                        label))

        # check if certain node does not exist in the graph
        if kind_not_in_graph is not None:
            for label, graph in graphs:
                self.assertNotIn(
                    kind_not_in_graph,
                    [node.kind() for node in graph.nodes()],
                    "unexpected node kind found in the {} graph".format(
                        label))