Example #1
from contextlib import contextmanager


# DEFAULT_BACKEND is defined in the surrounding test module.
@contextmanager
def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    # Snapshot the current global torch_glow settings so they can be restored.
    old_fp16 = torch_glow.get_convert_to_fp16()
    old_clip = torch_glow.get_clip_fp16()
    old_convert_fused = torch_glow.get_convert_fused_to_fp16()
    old_backend = torch_glow.getGlowBackendName()
    old_blocklist = torch_glow.getFusionBlacklist()
    old_fusion = torch_glow.getFusionPassEnabled()
    try:
        if fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()
        if fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
        if blocklist is None:
            torch_glow.clearFusionBlacklist()
        else:
            torch_glow.setFusionBlacklist(list(blocklist))
        torch_glow.setGlowBackend(backend)
        yield
    finally:
        # Restore the settings captured on entry.
        if old_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
        if old_clip:
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_clip_fp16()
        if old_convert_fused:
            torch_glow.enable_convert_fused_to_fp16()
        else:
            torch_glow.disable_convert_fused_to_fp16()
        if old_fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()
        torch_glow.setGlowBackend(old_backend)
        torch_glow.setFusionBlacklist(old_blocklist)
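
A minimal usage sketch of the helper above (hypothetical: `model` and `inputs` are placeholders, and torch / torch_glow are assumed importable):

# Hypothetical usage, not from the original page:
with ephemeral_torchglow_settings(fp16=True, fusion=True, backend="Interpreter"):
    traced = torch.jit.trace(model, inputs)  # fusion pass applies while tracing/running
    out = traced(inputs)
# On exit, the previously captured global torch_glow settings are restored.
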
Example #2
    def test_save_preprocessed_module(self):
        with torch.no_grad():
            x = torch.randn([1, 4, 4, 4], dtype=torch.float32)
            model = Bar()
            model.eval()
            model = torch.jit.trace(model, x)

            spec = torch_glow.CompilationSpec()
            spec.get_settings().set_glow_backend("Interpreter")

            compilation_group = torch_glow.CompilationGroup()
            spec.compilation_groups_append(compilation_group)

            compilation_group.input_sets_append(
                torch_glow.input_specs_from_tensors([x]))

            torch_glow.disableFusionPass()
            torch_glow.enable_convert_to_fp16()
            glow_mod = torch_glow.to_glow(model, spec)

            reloaded = utils.save_and_reload_model(glow_mod)

            # to_glow wraps the preprocessed PyTorch module; unwrap it so its
            # graph can be inspected after save/reload.
            wrappername = "__loweredModule__"
            attrname = "__processed_module"
            wp = getattr(reloaded._c, wrappername)
            pp = getattr(wp, attrname)
            pt_model = torch.jit._recursive.wrap_cpp_module(pp)
            graph = pt_model.graph_for(x)
            # The quantized conv should still be present after save/reload.
            found = False
            for node in graph.nodes():
                if node.kind() == "quantized::conv2d":
                    found = True
                    break

            assert found
Example #3
    def test_batchnorm_relu_basic(self):
        """
        Basic test of the PyTorch 3D batchnorm + ReLU node on Glow.
        """

        class SimpleQuantizedBatchNormRelu(nn.Module):
            def __init__(self, w, b, m, v):
                super(SimpleQuantizedBatchNormRelu, self).__init__()
                self.bn = torch.nn.BatchNorm3d(4)
                self.relu = torch.nn.ReLU()
                self.bn.weight = torch.nn.Parameter(w)
                self.bn.bias = torch.nn.Parameter(b)
                self.bn.running_mean = m
                self.bn.running_var = v
                self.q = QuantStub()
                self.dq = DeQuantStub()

            def forward(self, x):
                qx = self.q(x)
                qy = self.bn(qx)
                qy_relu = self.relu(qy)
                y = self.dq(qy_relu)
                return y

        C = 4
        weight = torch.ones(C) + torch.rand(C) * 0.001
        bias = torch.rand(C) * 0.0001
        running_mean = torch.zeros(C)
        running_var = torch.ones(C)

        inputs = torch.randn((10, C, 2, 3, 4), requires_grad=False)
        model = SimpleQuantizedBatchNormRelu(weight, bias, running_mean, running_var)
        model.eval()
        model.qconfig = my_qconfig
        modules_to_fuse = [["bn", "relu"]]
        fuse_modules(model, modules_to_fuse, inplace=True)
        prepare(model, inplace=True)
        model.forward(inputs)
        convert(model, inplace=True)

        torch_glow.enable_convert_to_fp16()
        # Quantization differs between PyTorch and Glow, so the tolerance must
        # be set large enough. Batchnorm can introduce errors of up to ~1e-2
        # in rare cases; to keep this test from being flaky, atol is set to 0.1.
        jitVsGlow(
            model,
            inputs,
            expected_fused_ops={"quantized::batch_norm3d_relu"},
            atol=1e-1,
        )
Example #4
File: utils.py  Project: wwwyiwenchen/glow
def scriptVsGlow(
    f,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    if black_list is None:
        black_list = []
    with torch.no_grad():

        torch_res = f(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)

        if use_fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()

        if backend_name:
            torch_glow.setGlowBackend(backend_name)
        else:
            torch_glow.setGlowBackend("Interpreter")

        glow_trace = torch.jit.script(f)
        glow_res = glow_trace(*inputs)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        # Explicitly reset settings so static state does not carry over to later runs
        torch_glow.disableFusionPass()
        torch_glow.disable_convert_to_fp16()
        torch_glow.disable_convert_fused_to_fp16()
        torch_glow.disable_clip_fp16()
        torch_glow.setGlowBackend("Interpreter")

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
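
A hedged sketch of a scriptVsGlow call site (the function, tolerances, and expected op set below are illustrative assumptions, not from this page):

# Illustrative only: script a tiny function and compare it against Glow.
def simple_add(a, b):
    return a + b

x = torch.randn(4)
y = torch.randn(4)
scriptVsGlow(simple_add, 1e-6, 1e-6, x, y, expected_fused_ops={"aten::add"})
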
Example #5
    def test_batchnorm_with_weights(self):
        """
        Test of the PyTorch 3D batchnorm node with weights and biases on Glow.
        """

        class SimpleQuantizedBatchNorm(nn.Module):
            def __init__(
                self, C, weight, bias, running_mean, running_var, scale, zero_point
            ):
                super(SimpleQuantizedBatchNorm, self).__init__()
                self.qconfig = my_qconfig
                self.batchnorm = nn.quantized.BatchNorm3d(C)
                self.batchnorm.scale = scale
                self.batchnorm.zero_point = zero_point
                self.batchnorm.weight = torch.nn.Parameter(weight)
                self.batchnorm.bias = torch.nn.Parameter(bias)
                self.batchnorm.running_mean = running_mean
                self.batchnorm.running_var = running_var
                self.relu = torch.nn.ReLU()
                self.dq = torch.nn.quantized.DeQuantize()

            def forward(self, x):
                return self.dq(self.relu(self.batchnorm(x)))

        C = 7
        in_scale = out_scale = 0.0047
        in_zero_point = out_zero_point = -7
        weight = torch.ones(C) + torch.rand(C) * 0.001
        bias = torch.rand(C) * 0.0001
        running_mean = torch.zeros(C)
        running_var = torch.ones(C)

        inputs = torch.randn(6, C, 4, 33, 42)
        inputs = torch.quantize_per_tensor(
            inputs, scale=in_scale, zero_point=in_zero_point, dtype=torch.qint8
        )
        model = SimpleQuantizedBatchNorm(
            C, weight, bias, running_mean, running_var, out_scale, out_zero_point
        )
        model.eval()

        torch_glow.enable_convert_to_fp16()
        jitVsGlow(model, inputs, expected_fused_ops={"quantized::batch_norm3d"})
Example #6
    def test_serialization(self):
        with torch.no_grad():
            x = torch.randn([1, 4, 4, 4], dtype=torch.float32)
            y = torch.randn([1, 4, 4, 4], dtype=torch.float32)
            model = Bar()
            model = torch.jit.trace(model, (x, y))

            spec = torch_glow.CompilationSpec()
            spec_settings = spec.get_settings()
            spec_settings.set_glow_backend("NNPI")
            # Enable serialization in this spec
            spec_settings.set_enable_serialize(True)

            compilation_group = torch_glow.CompilationGroup()
            compilation_group_settings = compilation_group.get_settings()
            compilation_group_settings.set_replication_count(1)
            compilation_group_settings.backend_specific_opts_insert(
                "NNPI_IceCores", "1")

            compilation_group.input_sets_append(
                torch_glow.input_specs_from_tensors([x, y]))

            spec.compilation_groups_append(compilation_group)
            torch_glow.disableFusionPass()
            torch_glow.enable_convert_to_fp16()

            # Enable global serialization, then compile (serialize)
            # the model and save it
            torch_glow.enable_dump_serialized_model()
            glow_mod = torch_glow.to_glow(model, spec)
            res1 = glow_mod(x, y)
            torch.jit.save(glow_mod, "/tmp/serialize_to_glow.pt")

            # Enable global deserialization and disable serialization,
            # then load (deserialize) the model into loaded_glow_mod
            torch_glow.enable_deserialize()
            torch_glow.disable_dump_serialized_model()
            loaded_glow_mod = torch.jit.load("/tmp/serialize_to_glow.pt")
            res2 = loaded_glow_mod(x, y)
            assert torch.allclose(res1, res2, rtol=1e-5, atol=1e-5)
Example #7
    def test_batchnorm_basic(self):
        """
        Basic test of the PyTorch 3D batchnorm node on Glow.
        """

        class SimpleQuantizedBatchNorm(nn.Module):
            def __init__(self, C, running_mean, running_var, scale, zero_point):
                super(SimpleQuantizedBatchNorm, self).__init__()
                self.qconfig = my_qconfig
                self.batchnorm = nn.quantized.BatchNorm3d(C)
                self.batchnorm.scale = scale
                self.batchnorm.zero_point = zero_point
                self.batchnorm.running_mean = running_mean
                self.batchnorm.running_var = running_var
                self.relu = torch.nn.ReLU()
                self.dq = torch.nn.quantized.DeQuantize()

            def forward(self, x):
                return self.dq(self.relu(self.batchnorm(x)))

        C = 4
        in_scale = out_scale = 0.004
        in_zero_point = out_zero_point = 4
        running_mean = torch.zeros(C)
        running_var = torch.ones(C)

        inputs = torch.randn((5, C, 6, 32, 73), requires_grad=False)
        inputs = torch.quantize_per_tensor(
            inputs, scale=in_scale, zero_point=in_zero_point, dtype=torch.qint8
        )
        model = SimpleQuantizedBatchNorm(
            C, running_mean, running_var, out_scale, out_zero_point
        )
        model.eval()

        torch_glow.enable_convert_to_fp16()
        jitVsGlow(model, inputs, expected_fused_ops={"quantized::batch_norm3d"})
Example #8
File: utils.py  Project: wwwyiwenchen/glow
def traceVsGlow(
    f_torch,
    f_glow,
    check_trace,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    if black_list is None:
        black_list = []
    with torch.no_grad():
        torch_glow.disableFusionPass()

        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)

        if use_fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()

        if backend_name:
            torch_glow.setGlowBackend(backend_name)
        else:
            torch_glow.setGlowBackend("Interpreter")

        glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
        glow_res = glow_trace(*inputs)

        # check that there are no Glow nodes in the torch graph
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        # Explicitly reset settings so static state does not carry over to later runs
        torch_glow.disableFusionPass()
        torch_glow.disable_convert_to_fp16()
        torch_glow.disable_convert_fused_to_fp16()
        torch_glow.disable_clip_fp16()
        torch_glow.setGlowBackend("Interpreter")

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
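
traceVsGlow takes separate callables for the eager-PyTorch and Glow runs (often the same module) plus a check_trace flag. A hedged, illustrative call (module choice, tolerances, and the expected op set are assumptions):

# Illustrative only: trace the same module for both runs and compare results.
m = torch.nn.ReLU()
t = torch.randn(2, 3)
traceVsGlow(m, m, True, 1e-6, 1e-6, t, expected_fused_ops={"aten::relu"})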