Example #1
    def _test_output_bf16(self, model, x):
        modelName = model.__class__.__name__

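        # enable the DNNL backend and the JIT path; keep bf16/fp32 mixing
        # off until the BN folding below has run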
        core.enable_auto_dnnl()
        core.enable_jit()
        core.disable_mix_bf16_fp32()

        model = model.to('dpcpp').eval()
        x = x.to('dpcpp')
        x2 = x.clone()

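        # script a deep copy so the eager-mode reference model stays untouched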
        fused_model = torch.jit.script(copy.deepcopy(model))

        # bn folding workaround; remove it once the underlying issue is
        # fixed (use mixed precision to check)
        core.disable_auto_dnnl()
        fused_model = wrap_cpp_module(
            torch._C._jit_pass_fold_convbn(fused_model._c))
        core.enable_auto_dnnl()

        core.enable_mix_bf16_fp32()
        # prepack the convolution weight; the prepacked weight is a bf16 tensor
        fused_model = wrap_cpp_module(
            core._jit_prepack_conv_weight(fused_model._c))
        with torch.no_grad():
            # bf16, native path
            result = model(x)
            # bf16, jit path
            fresult = fused_model(x2)

        self.assertEqual(fresult, result)
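
A minimal usage sketch for this helper (an assumption, not part of the excerpt): a Conv2d followed by BatchNorm2d exercises the BN-folding step, with the test method assumed to live on the same test class as _test_output_bf16.

    def test_conv_bn_bf16(self):
        # hypothetical test method; Conv2d + BatchNorm2d exercises the
        # conv-bn folding performed inside _test_output_bf16
        m = torch.nn.Sequential(torch.nn.Conv2d(3, 16, kernel_size=3),
                                torch.nn.BatchNorm2d(16))
        self._test_output_bf16(m, torch.rand(4, 3, 32, 32))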
Example #2
    def _test_output(self,
                     model,
                     x,
                     kind_in_graph=None,
                     kind_not_in_graph=None):
        modelName = model.__class__.__name__
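        # `device` is assumed to be defined at module scope (e.g. ipex.DEVICE,
        # as in Example #3); JIT optimization stays off for the first pass so
        # the scripted and traced models run without operator fusion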
        core.disable_jit_opt()
        core.disable_mix_bf16_fp32()

        model = model.to(device).eval()
        x = x.to(device)
        with torch.no_grad():
            result = model(x)

        script_model = torch.jit.script(model)
        script_model.eval()

        trace_model = torch.jit.trace(model, x)
        trace_model.eval()
        with torch.no_grad():
            sresult = script_model(x)
            tresult = trace_model(x)

        self.assertEqual(result, sresult)
        self.assertEqual(result, tresult)

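        # second pass: enable JIT optimization so the fusion passes run, then
        # compare the fused outputs against the eager reference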
        core.enable_jit_opt()
        script_fused_model = torch.jit.script(model)
        trace_fused_model = torch.jit.trace(model, x)
        with torch.no_grad():
            # expected fusions: conv+relu, conv+sum, or conv+sum+relu
            script_graph = script_fused_model.graph_for(x)
            fused_sresult = script_fused_model(x)

            trace_graph = trace_fused_model.graph_for(x)
            fused_tresult = trace_fused_model(x)

        self.assertEqual(result, fused_sresult)
        self.assertEqual(result, fused_tresult)

        # check if the fused node exists in the graph
        if kind_in_graph is not None:
            self.assertTrue(
                any(n.kind() == kind_in_graph for n in script_graph.nodes()))
            self.assertTrue(
                any(n.kind() == kind_in_graph for n in trace_graph.nodes()))

        # check if certain node does not exist in the graph
        if kind_not_in_graph is not None:
            self.assertTrue(
                all(n.kind() != kind_not_in_graph
                    for n in script_graph.nodes()))
            self.assertTrue(
                all(n.kind() != kind_not_in_graph
                    for n in trace_graph.nodes()))
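
A sketch of how this helper might be invoked (the module and node kinds below are illustrative assumptions; the real fused-node names depend on the extension's fusion passes):

    def test_conv_relu_fusion(self):
        # hypothetical conv + relu module, defined locally for this sketch
        class ConvRelu(torch.nn.Module):
            def __init__(self):
                super(ConvRelu, self).__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3)

            def forward(self, x):
                return torch.relu(self.conv(x))

        # 'ipex::conv2d_relu' stands in for whatever node kind the fusion
        # pass actually emits; 'aten::relu' should disappear once fused
        self._test_output(ConvRelu(),
                          torch.rand(32, 3, 64, 64),
                          kind_in_graph='ipex::conv2d_relu',
                          kind_not_in_graph='aten::relu')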
Example #3
    def _test_output_bf16(self,
                          model,
                          x,
                          kind_in_graph=None,
                          kind_not_in_graph=None,
                          prec=None):
        modelName = model.__class__.__name__

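        # enable the DNNL backend, JIT optimization, and bf16/fp32 mixed
        # execution before building the reference and fused models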
        core.enable_auto_dnnl()
        core.enable_jit_opt()
        core.enable_mix_bf16_fp32()

        model = model.to(ipex.DEVICE).eval()
        x = x.to(ipex.DEVICE)
        x2 = x.clone()
        x3 = x.clone()

        script_fused_model = torch.jit.script(copy.deepcopy(model))
        trace_fused_model = torch.jit.trace(copy.deepcopy(model), x3)

        with torch.no_grad():
            # bf16, native path
            result = model(x)
            # bf16, jit script path
            script_graph = script_fused_model.graph_for(x2)
            fused_sresult = script_fused_model(x2)
            # bf16, jit trace path
            trace_graph = trace_fused_model.graph_for(x3)
            fused_tresult = trace_fused_model(x3)

        # disable mix_bf16_fp32 once the computation is done so it does not
        # leak into other test scripts
        core.disable_mix_bf16_fp32()

        self.assertEqual(fused_sresult, result, prec=prec)
        self.assertEqual(fused_tresult, result, prec=prec)

        # check if the fused node exists in the graph
        if kind_in_graph is not None:
            self.assertTrue(
                any(n.kind() == kind_in_graph for n in script_graph.nodes()))
            self.assertTrue(
                any(n.kind() == kind_in_graph for n in trace_graph.nodes()))

        # check if certain node does not exist in the graph
        if kind_not_in_graph is not None:
            self.assertTrue(
                all(n.kind() != kind_not_in_graph
                    for n in script_graph.nodes()))
            self.assertTrue(
                all(n.kind() != kind_not_in_graph
                    for n in trace_graph.nodes()))
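
A usage sketch for the bf16 variant (again an assumption, not part of the excerpt; the node kind is illustrative and prec=0.02 is only a plausible tolerance for bf16 rounding, not a value taken from the source):

    def test_conv_relu_fusion_bf16(self):
        # bf16 arithmetic loosens the comparison, so a tolerance is passed
        # through to assertEqual via the prec argument
        m = torch.nn.Sequential(torch.nn.Conv2d(3, 16, kernel_size=3),
                                torch.nn.ReLU())
        self._test_output_bf16(m,
                               torch.rand(32, 3, 64, 64),
                               kind_in_graph='ipex::conv2d_relu',
                               prec=0.02)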