Example #1
def test_grad_with_multiple_inputs():
    if not pytorch_pfn_extras.requires("1.8.0"):
        pytest.skip('skip for PyTorch 1.7 or earlier')

    if pytorch_pfn_extras.requires('1.10.0') and sys.platform == 'win32':
        pytest.skip(
            'ONNX grad test does not work in Windows CI for torch >= 1.10')

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(2, 6, 3)
            self.linear = nn.Linear(32, 20, bias=False)

        def forward(self, x):
            x0 = x * 0.5
            x1 = x * 2.0
            x0.requires_grad_(True)
            x1.requires_grad_(True)
            h = self.conv(torch.cat([x0, x1], dim=1))
            grad_x = grad(
                h,
                (x0, x1),
                retain_graph=True,
                create_graph=True,
            )
            h = self.linear(grad_x[0])
            return h

    model = Net()
    x = torch.ones((1, 1, 32, 32))
    output_dir = _helper(
        model,
        x,
        'grad',
        enable_onnx_checker=False,
        use_pfto=False,
    )

    actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
    named_nodes = {n.name: n for n in actual_onnx.graph.node}
    assert 'Conv_5' in named_nodes
    assert 'Gradient_6' in named_nodes
    assert 'MatMul_8' in named_nodes

    assert [v.name for v in actual_onnx.graph.output] == [
        "v13_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0"
    ]
    assert named_nodes["Concat_4"].input[0] == "Gradient_x_0_0"
    assert named_nodes["Concat_4"].input[1] == "Gradient_x_1_0"
    assert named_nodes["Conv_5"].output[0] == "Gradient_y_0"
Example #2
def test_no_as_output():
    if not pytorch_pfn_extras.requires("1.8.0"):
        pytest.skip('skip for PyTorch 1.7 or earlier')

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(1, 6, 3)
            self.linear = nn.Linear(30, 20, bias=False)

        def forward(self, x):
            h = self.conv(x)
            h = self.linear(h)
            return h

    model = Net()
    x = torch.ones((1, 1, 32, 32))
    output_dir = _helper(model, x, 'as_output')

    actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
    named_nodes = {n.name: n for n in actual_onnx.graph.node}
    assert 'Conv_0' in named_nodes
    assert 'MatMul_2' in named_nodes

    assert len(actual_onnx.graph.output) == 1
Example #3
    def optimize_onnx(self, graph: torch._C.Graph) -> torch._C.Graph:
        if pytorch_pfn_extras.requires("1.9.0"):
            self.run_jit_pass(torch._C._jit_pass_onnx_scalar_type_analysis, graph, self.onnx_lowprecision_cast, self.opset_version)
        else:
            self.run_jit_pass(torch._C._jit_pass_onnx_scalar_type_analysis, graph)

        if self.do_constant_folding and self.opset_version in torch.onnx.constant_folding_opset_versions:
            folded: Dict[str, torch.IValue] = torch._C._jit_pass_onnx_constant_fold(  # type: ignore[attr-defined]
                graph, self.vars, self.opset_version
            )
            # Replace input with constant nodes
            input_table: Dict[str, torch._C.Value] = {i.debugName(): i for i in graph.inputs()}
            for k, t in folded.items():
                c: torch._C.Value = graph.create("onnx::Constant", 1).output()
                assert isinstance(t, torch.Tensor)
                c.node().t_("value", cast(torch.Tensor, t))
                graph.prependNode(c.node())
                # TODO(twata): Determine folded nodes from original graph and document it
                self.node_doc_string[c.node()] = f"Constant folded node: {input_table[k]}"
                input_table[k].replaceAllUsesWith(c)
                c.copyMetadata(input_table[k])
                self.attrs[_unique_id(c)] = ONNXValueID(k)
                self.vars[k] = t
                del input_table[k]
            for _ in range(len(list(graph.inputs())) - len(input_table)):
                graph.eraseInput(len(input_table))
            torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)  # type: ignore[attr-defined]

        if self.onnx_peephole:
            self.run_jit_pass(torch._C._jit_pass_onnx_peephole, graph, self.opset_version, self.fixed_batch_size)

        return graph
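For intuition about the constant-folding branch above, a minimal self-contained sketch using only the public torch.onnx API rather than the internal JIT passes (exact node lists vary across torch versions, so treat the printed result as indicative):

import io

import onnx
import torch

class M(torch.nn.Module):
    def forward(self, x):
        w = torch.ones(3) * 2.0  # depends on no input: a foldable constant
        return x + w

buf = io.BytesIO()
torch.onnx.export(M(), (torch.zeros(3),), buf,
                  do_constant_folding=True, opset_version=11)
graph = onnx.load(io.BytesIO(buf.getvalue())).graph
# Expect an Add (plus possibly a Constant); the Mul was evaluated at export.
print([n.op_type for n in graph.node])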
Example #4
def test_grad_no_export():
    if not pytorch_pfn_extras.requires("1.8.0"):
        pytest.skip('skip for PyTorch 1.7 or earlier')

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(1, 6, 3)
            self.linear = nn.Linear(32, 20, bias=False)

        def forward(self, x):
            x = x * 0.5
            x.requires_grad_(True)
            h = self.conv(x)
            grad_x = grad(
                h,
                (x, ),
                retain_graph=True,
                create_graph=True,
            )[0]
            h = self.linear(grad_x)
            return h

    model = Net()
    x = torch.ones((1, 1, 32, 32))
    y = model(x)
    assert y.shape == (1, 1, 32, 20)
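Outside of an export, grad is expected to behave like autograd with all-ones output gradients; a plain-torch sketch of the shape reasoning behind the assertion above (this equivalence is an assumption, not documented behavior):

import torch

x = torch.ones((1, 1, 32, 32)).requires_grad_(True)
h = torch.nn.Conv2d(1, 6, 3)(x)  # (1, 6, 30, 30)
g = torch.autograd.grad(h, (x,), grad_outputs=torch.ones_like(h),
                        create_graph=True)[0]
assert g.shape == x.shape  # the gradient keeps the input shape
assert torch.nn.Linear(32, 20, bias=False)(g).shape == (1, 1, 32, 20)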
Example #5
def test_annotate():
    if not pytorch_pfn_extras.requires("1.8.0"):
        pytest.skip('skip for PyTorch 1.7 or earlier')

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(1, 6, 3)
            self.conv2 = nn.Conv2d(6, 12, 3)
            self.linear = nn.Linear(28, 10, bias=False)
            self.linear2 = nn.Linear(10, 5, bias=False)

        def forward(self, x):
            with annotate(aaa='a', bbb=['b', 'c']):
                h = self.conv(x)
            h = self.conv2(h)
            with annotate(zzz=99, yyy=[9, 9]):
                h = self.linear(h)
                h = self.linear2(h)
            return h

    model = Net()
    x = torch.ones((1, 1, 32, 32))
    output_dir = _helper(model, x, 'annotate', use_pfto=False)

    actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
    named_nodes = {n.name: n for n in actual_onnx.graph.node}
    assert 'Conv_0' in named_nodes
    assert 'Conv_1' in named_nodes

    assert 'MatMul_3' in named_nodes
    assert 'MatMul_5' in named_nodes

    node_conv_0_attrs = [a.name for a in named_nodes['Conv_0'].attribute]
    assert 'aaa' in node_conv_0_attrs
    assert 'bbb' in node_conv_0_attrs
    assert 'zzz' not in node_conv_0_attrs
    assert 'yyy' not in node_conv_0_attrs
    node_conv_1_attrs = [a.name for a in named_nodes['Conv_1'].attribute]
    assert 'aaa' not in node_conv_1_attrs
    assert 'bbb' not in node_conv_1_attrs
    assert 'zzz' not in node_conv_1_attrs
    assert 'yyy' not in node_conv_1_attrs
    node_matmul_3_attrs = [a.name for a in named_nodes['MatMul_3'].attribute]
    assert 'aaa' not in node_matmul_3_attrs
    assert 'bbb' not in node_matmul_3_attrs
    assert 'zzz' in node_matmul_3_attrs
    assert 'yyy' in node_matmul_3_attrs
    node_matmul_5_attrs = [a.name for a in named_nodes['MatMul_5'].attribute]
    assert 'aaa' not in node_matmul_5_attrs
    assert 'bbb' not in node_matmul_5_attrs
    assert 'zzz' in node_matmul_5_attrs
    assert 'yyy' in node_matmul_5_attrs
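To inspect the annotation payloads rather than just the attribute names, onnx.helper.get_attribute_value can decode them; a short sketch reusing output_dir from the test above (string attributes decode to bytes):

import os
import onnx

m = onnx.load(os.path.join(output_dir, 'model.onnx'))
conv0 = {n.name: n for n in m.graph.node}['Conv_0']
for attr in conv0.attribute:
    if attr.name in ('aaa', 'bbb'):
        print(attr.name, onnx.helper.get_attribute_value(attr))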
Example #6
def _export_util(
    model: torch.nn.Module,
    args: Sequence[Any],
    f: IO,
    **kwargs: Any,
) -> Any:
    """Wrap operator type to export

    Copied from torch.onnx.utils.export, to get output values.
    """
    aten = kwargs.get('aten', False)
    export_raw_ir = kwargs.get('export_raw_ir', False)
    operator_export_type = kwargs.get('operator_export_type', None)

    if aten or export_raw_ir:
        assert operator_export_type is None
        assert aten ^ export_raw_ir
        # Note: OperatorExportTypes.RAW unavailable in PyTorch 1.10+
        operator_export_type = OperatorExportTypes.ONNX_ATEN if\
            aten else OperatorExportTypes.RAW  # type: ignore
    elif operator_export_type is None:
        if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
            operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
        else:
            operator_export_type = OperatorExportTypes.ONNX

    old_model_to_graph = torch.onnx.utils._model_to_graph
    # TODO(ecastill): _model_to_graph shouldn't be directly overridden.
    # This is a temporary workaround until a fix is introduced in PyTorch.
    try:
        torch.onnx.utils._model_to_graph = _model_to_graph_with_value_names
        if pytorch_pfn_extras.requires('1.10.0'):
            checker_error = getattr(torch.onnx, "CheckerError", None)
            if checker_error is None:
                checker_error = torch.onnx.utils.ONNXCheckerError  # type: ignore[attr-defined]
            try:
                enable_onnx_checker = kwargs.pop('enable_onnx_checker', None)
                return torch_export(  # type: ignore[no-untyped-call]
                    model, args, f, **kwargs)
            except checker_error:
                if enable_onnx_checker:
                    raise
        else:
            kwargs['_retain_param_name'] = True
            return torch_export(  # type: ignore[no-untyped-call]
                model, args, f, **kwargs)
    finally:
        torch.onnx.utils._model_to_graph = old_model_to_graph
Example #7
def test_export_testcase_return_output():
    model = nn.Sequential(nn.Linear(5, 10, bias=False))
    x = torch.zeros((2, 5))

    output_dir = _get_output_dir('export_filename')

    if pytorch_pfn_extras.requires("1.6.0"):
        with pytest.warns(UserWarning):
            (out, ) = export_testcase(model, x, output_dir, return_output=True)
    else:
        (out, ) = export_testcase(model, x, output_dir, return_output=True)

    assert os.path.isfile(os.path.join(output_dir, 'model.onnx'))
    expected_out = torch.zeros((2, 10))  # zeros in, bias-free linear: zeros out
    np.testing.assert_allclose(out.detach().cpu().numpy(),
                               expected_out.detach().cpu().numpy())
Example #8
def test_export_testcase_with_unused_input(keep_initializers_as_inputs):
    if not pytorch_pfn_extras.requires("1.7.0"):
        pytest.skip('skip for PyTorch 1.6 or earlier')

    model = NetWithUnusedInput().to('cpu')
    x = torch.zeros((1, 1, 28, 28))
    unused = torch.zeros((1, ))

    # Without input_names
    output_dir = _helper(
        model,
        args=(x, unused),
        d='net_with_unused_input_without_input_names',
        opset_version=11,
        strip_doc_string=False,
        keep_initializers_as_inputs=keep_initializers_as_inputs)
    assert os.path.isdir(output_dir)
    test_data_set_dir = os.path.join(output_dir, 'test_data_set_0')
    assert os.path.exists(os.path.join(test_data_set_dir, 'input_0.pb'))
    assert not os.path.exists(os.path.join(test_data_set_dir, 'input_1.pb'))

    xmodel = onnx.load_model(os.path.join(output_dir, 'model.onnx'))
    assert xmodel.graph.input[0].name == 'input_0'
    assert len(xmodel.graph.input) == 1 or \
        xmodel.graph.input[1].name != 'input_1'

    # With input_names
    output_dir = _helper(
        model,
        args=(x, unused),
        d='net_with_unused_input_with_input_names',
        opset_version=11,
        strip_doc_string=False,
        keep_initializers_as_inputs=keep_initializers_as_inputs,
        input_names=['x', 'unused'])
    assert os.path.isdir(output_dir)
    test_data_set_dir = os.path.join(output_dir, 'test_data_set_0')
    assert os.path.exists(os.path.join(test_data_set_dir, 'input_0.pb'))
    assert not os.path.exists(os.path.join(test_data_set_dir, 'input_1.pb'))

    xmodel = onnx.load_model(os.path.join(output_dir, 'model.onnx'))
    assert xmodel.graph.input[0].name == 'x'
    assert len(xmodel.graph.input) == 1 or \
        xmodel.graph.input[1].name != 'unused'
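The pruning this test relies on can be reproduced with a stock export: inputs the traced model never touches are dropped from the ONNX graph. A self-contained sketch (the module is a made-up stand-in for NetWithUnusedInput, and the pruning behavior can vary with the torch version):

import io

import onnx
import torch

class WithUnused(torch.nn.Module):
    def forward(self, x, unused):
        return x * 2

buf = io.BytesIO()
torch.onnx.export(WithUnused(), (torch.zeros(2), torch.zeros(1)), buf,
                  input_names=['x', 'unused'], opset_version=11)
graph = onnx.load(io.BytesIO(buf.getvalue())).graph
print([i.name for i in graph.input])  # expected: ['x'], matching the asserts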
Example #9
def test_as_output_no_export():
    if not pytorch_pfn_extras.requires("1.8.0"):
        pytest.skip('skip for PyTorch 1.7 or earlier')

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(1, 6, 3)
            self.linear = nn.Linear(30, 20, bias=False)

        def forward(self, x):
            h = self.conv(x)
            h = as_output("h", h)
            h = self.linear(h)
            return h

    model = Net()
    x = torch.ones((1, 1, 32, 32))
    y = model(x)
    assert y.shape == (1, 6, 30, 20)
Example #10
def _export(
        model: torch.nn.Module,
        args: Sequence[Any],
        strip_large_tensor_data: bool = False,
        large_tensor_threshold: int = LARGE_TENSOR_DATA_THRESHOLD,
        use_pfto: bool = False,
        **kwargs: Any,
) -> Tuple[onnx.ModelProto, Any]:
    model.zero_grad()
    bytesio = io.BytesIO()
    opset_ver = kwargs.get('opset_version', None)
    if opset_ver is None:
        opset_ver = _default_onnx_opset_version
        kwargs['opset_version'] = opset_ver
    if use_pfto or not pytorch_pfn_extras.requires('1.10.0'):
        strip_doc_string = kwargs.get('strip_doc_string', True)
        kwargs['strip_doc_string'] = False
    else:
        strip_doc_string = kwargs.pop('strip_doc_string', True)
        kwargs['verbose'] = True
    with init_annotate(model, opset_ver) as ann, \
            as_output.trace(model) as (model, outputs), \
            grad.init_grad_state():
        if use_pfto:
            outs = pfto_export(
                model, args, bytesio, **kwargs)
        else:
            outs = _export_util(
                model, args, bytesio, **kwargs)
        onnx_graph = onnx.load(io.BytesIO(bytesio.getvalue()))
        onnx_graph = ann.set_annotate(onnx_graph)
        onnx_graph = ann.reorg_anchor(onnx_graph)
        outputs.add_outputs_to_model(onnx_graph)
        if strip_doc_string:
            for node in onnx_graph.graph.node:
                node.doc_string = b''

    if strip_large_tensor_data:
        _strip_large_initializer_raw_data(onnx_graph, large_tensor_threshold)

    return onnx_graph, outs
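A minimal usage sketch of _export as defined above, assuming the module-level helpers it uses (init_annotate, as_output, grad, pfto_export, _export_util) are in scope as in pytorch_pfn_extras.onnx:

model = torch.nn.Linear(5, 10, bias=False)
onnx_graph, outs = _export(model, (torch.zeros(2, 5),), opset_version=11)
onnx.save(onnx_graph, 'model.onnx')  # proto already annotated and stripped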
Example #11
class TestDistributedDataParallel:
    def test_save_load(self):
        module = MyModule()
        with_ddp = DistributedDataParallel(module)
        assert module.state_dict().keys() == with_ddp.state_dict().keys()
        module.load_state_dict(with_ddp.state_dict())
        assert np.array_equal(module.state_dict()["param0"],
                              with_ddp.state_dict()["param0"])
        assert np.array_equal(module.state_dict()["param1"],
                              with_ddp.state_dict()["param1"])
        assert np.array_equal(module.state_dict()["buffer"],
                              with_ddp.state_dict()["buffer"])

    @pytest.mark.parametrize('device_type', _device_types())
    def test_sync_init_params(self, device_type):
        module0 = MyModule()
        module0.param0.data = torch.tensor([1.])
        r0, r1 = _launch(inputs=[torch.tensor([1.]),
                                 torch.tensor([2.])],
                         modules=[module0, MyModule()],
                         device_type=device_type)
        assert r0[0].item() == 1
        assert r1[0].item() == 2
        assert r0[1]["param0"].item() == 1.0
        assert r1[1]["param0"].item() == 1.0

    @pytest.mark.parametrize('device_type', _device_types())
    def test_all_reduce(self, device_type):
        r0, r1 = _launch(inputs=[torch.tensor([1.]),
                                 torch.tensor([2.])],
                         device_type=device_type)
        assert r0[0].item() == -1
        assert r1[0].item() == -2
        assert r0[2]["module.param0"].item() == 1.5
        assert r1[2]["module.param0"].item() == 1.5
        assert r0[2]["module.param1"] is None
        assert r1[2]["module.param1"] is None

    @pytest.mark.parametrize('device_type', _device_types())
    def test_specific_reduce(self, device_type):
        r0, r1 = _launch(inputs=[torch.tensor([1.]),
                                 torch.tensor([2.])],
                         args={"reduce_function": Collectives._to_zero},
                         device_type=device_type)
        assert r0[2]["module.param0"].item() == 0.0
        assert r1[2]["module.param0"].item() == 0.0

    @pytest.mark.parametrize('device_type', _device_types())
    def test_nosync_buffer(self, device_type):
        r0, r1 = _launch(inputs=[torch.tensor([1.]),
                                 torch.tensor([2.])],
                         args={"broadcast_buffers": False},
                         device_type=device_type)
        assert r0[0].item() == -1
        assert r1[0].item() == -2
        assert r0[1]["buffer"].item() == 1
        assert r1[1]["buffer"].item() == 2

    @pytest.mark.parametrize('device_type', _device_types())
    def test_sync_buffer(self, device_type):
        r0, r1 = _launch(inputs=[torch.tensor([1.]),
                                 torch.tensor([2.])],
                         args={"broadcast_buffers": True},
                         device_type=device_type)
        assert r0[0].item() == -1
        assert r1[0].item() == -2
        assert r0[1]["buffer"].item() == 1
        assert r1[1]["buffer"].item() == 1

    @pytest.mark.parametrize('device_type', _device_types())
    def test_specific_broadcast(self, device_type):
        r0, r1 = _launch(inputs=[torch.tensor([1.]),
                                 torch.tensor([2.])],
                         args={
                             "broadcast_function": Collectives._to_zero,
                             "broadcast_buffers": True
                         },
                         device_type=device_type)
        assert r0[1]["buffer"].item() == 0.0
        assert r1[1]["buffer"].item() == 0.0

    @pytest.mark.parametrize('device_type', _device_types())
    def test_define_by_run(self, device_type):
        r0, r1 = _launch(inputs=[torch.tensor([1.]),
                                 torch.tensor([-1])],
                         device_type=device_type)
        assert r0[0].item() == -1
        assert r1[0].item() == 1
        assert r0[2]["module.param0"].item() == 0.5
        assert r1[2]["module.param0"].item() == 0.5
        assert r0[2]["module.param1"].item() == 0.5
        assert r1[2]["module.param1"].item() == 0.5

    @pytest.mark.parametrize('device_type', _device_types())
    def test_no_sync(self, device_type):
        r0, r1 = _launch(inputs=[torch.tensor([1.]),
                                 torch.tensor([2.])],
                         step=Steps._step_with_no_sync,
                         device_type=device_type)
        assert r0[0].item() == -1
        assert r1[0].item() == -2
        assert r0[2]["module.param0"].item() == 1
        assert r1[2]["module.param0"].item() == 2
        assert r0[2]["module.param1"] is None
        assert r1[2]["module.param1"] is None

    @pytest.mark.parametrize('device_type', _device_types())
    def test_hook(self, device_type):
        r0, r1 = _launch(inputs=[torch.tensor([1.]),
                                 torch.tensor([2.])],
                         step=Steps._step_with_hook,
                         device_type=device_type)
        assert r0[0].item() == -1
        assert r1[0].item() == -2
        assert r0[2]["module.param0"].item() == 0
        assert r1[2]["module.param0"].item() == 0
        assert r0[2]["module.param1"].item() == 0
        assert r1[2]["module.param1"].item() == 0

    @pytest.mark.parametrize('device_type', _device_types())
    @pytest.mark.skipif(
        not pytorch_pfn_extras.requires("1.6.0"),
        reason="Variable._execution_engine.queue_callback does not work "
        "with checkpointing when torch < 1.6.0")
    def test_checkpoint(self, device_type):
        r0, r1 = _launch(
            inputs=[torch.tensor([[1.]]),
                    torch.tensor([[2.]])],
            modules=[MyModuleWithCheckpoint(),
                     MyModuleWithCheckpoint()],
            step=Steps._step_with_hook,
            device_type=device_type)
        grad0 = r0[2]
        grad1 = r1[2]
        for key in grad0.keys():
            assert np.array_equal(grad0[key].cpu().numpy(),
                                  grad1[key].cpu().numpy())
Example #12
import os

import pytest
import torch

import pytorch_pfn_extras as ppe

_profiler_available = (os.name != 'nt' or ppe.requires("1.9"))


@pytest.mark.skipif(not _profiler_available,
                    reason="profiler is not available")
@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_record(device):
    if not torch.cuda.is_available() and device == 'cuda':
        pytest.skip()
    model = torch.nn.Linear(30, 40)
    model.to(device)
    x = torch.arange(30, dtype=torch.float32).to(device)

    with torch.profiler.profile() as prof:
        with ppe.profiler.record('my_tag_1'):
            model(x)

    keys = [event.key for event in prof.key_averages()]
    assert 'my_tag_1' in keys
    assert 'aten::linear' in keys


Example #13
def test_export_testcase_strip_large_tensor_data():
    if not pytorch_pfn_extras.requires("1.6.0"):
        pytest.skip('skip for PyTorch 1.5 or earlier')

    model = Net().to('cpu')
    x = torch.zeros((1, 1, 28, 28))

    output_dir = _helper(model,
                         x,
                         'mnist_stripped_tensor_data',
                         output_grad=True,
                         strip_large_tensor_data=True,
                         metadata=True)

    assert os.path.isdir(output_dir)
    assert os.path.isfile(os.path.join(output_dir, 'meta.json'))
    assert os.path.isfile(os.path.join(output_dir, 'model.onnx'))
    test_data_set_dir = os.path.join(output_dir, 'test_data_set_0')
    assert os.path.isfile(os.path.join(test_data_set_dir, 'input_0.pb'))
    assert os.path.isfile(os.path.join(test_data_set_dir, 'output_0.pb'))

    for i in range(8):
        assert os.path.isfile(
            os.path.join(test_data_set_dir, 'gradient_{}.pb'.format(i)))
    assert not os.path.isfile(os.path.join(test_data_set_dir, 'gradient_8.pb'))

    with open(os.path.join(output_dir, 'meta.json')) as metaf:
        metaj = json.load(metaf)
        assert metaj['strip_large_tensor_data']

    def is_stripped_with_check(tensor):
        if is_large_tensor(tensor, LARGE_TENSOR_DATA_THRESHOLD):
            assert tensor.data_location == onnx.TensorProto.EXTERNAL
            assert tensor.external_data[0].key == 'location'
            meta = json.loads(tensor.external_data[0].value)
            assert meta['type'] == 'stripped'
            assert type(meta['average']) == float
            assert type(meta['variance']) == float
            return True
        assert len(tensor.external_data) == 0
        return False

    onnx_model = onnx.load(os.path.join(output_dir, 'model.onnx'),
                           load_external_data=False)

    check_stripped = [
        is_stripped_with_check(init) for init in onnx_model.graph.initializer
    ]
    # This test exercises stripping, so the export must produce at least
    # one stripped initializer; otherwise the test fails.
    assert any(check_stripped)

    for pb_filepath in ('input_0.pb', 'output_0.pb'):
        with open(os.path.join(test_data_set_dir, pb_filepath), 'rb') as f:
            tensor = onnx.TensorProto()
            tensor.ParseFromString(f.read())
            is_stripped_with_check(tensor)

    # check re-load stripped onnx
    _strip_large_tensor_tool_impl(os.path.join(output_dir, 'model.onnx'),
                                  os.path.join(output_dir, 'model_re.onnx'),
                                  LARGE_TENSOR_DATA_THRESHOLD)
    assert os.path.isfile(os.path.join(output_dir, 'model_re.onnx'))
    # loading check
    onnx.load(os.path.join(output_dir, 'model_re.onnx'),
              load_external_data=False)

    # unstrip test
    unstrip_output_dir = _get_output_dir('mnist_unstripped_tensor_data')
    unstrip(output_dir, out_path=unstrip_output_dir)

    def is_unstripped_with_check(tensor):
        if is_large_tensor(tensor, LARGE_TENSOR_DATA_THRESHOLD):
            assert tensor.data_location == onnx.TensorProto.DEFAULT
            return True
        return False

    onnx_paths = Path(unstrip_output_dir).glob('*.onnx')
    check_unstripped = []
    for onnx_path in onnx_paths:
        onnx_model = onnx.load(onnx_path, load_external_data=False)
        check_unstripped.extend([
            is_unstripped_with_check(init)
            for init in onnx_model.graph.initializer
        ])
    assert len(check_unstripped) > 0
    assert any(check_unstripped)

    pb_paths = Path(unstrip_output_dir).glob('**/*.pb')
    checked = False
    for pb_path in pb_paths:
        tensor = onnx.TensorProto()
        with open(pb_path, 'rb') as f:
            tensor.ParseFromString(f.read())
        checked = is_unstripped_with_check(tensor) or checked
    assert checked, 'at least one tensor should have been unstripped'

    with open(os.path.join(unstrip_output_dir, 'meta.json')) as metaf:
        metaj = json.load(metaf)
        assert not metaj['strip_large_tensor_data']
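For reference, the threshold check used throughout this test can be thought of as a raw-byte-size comparison; a hypothetical re-implementation for illustration only (the real is_large_tensor ships with pytorch_pfn_extras and may differ):

import onnx

def looks_large(tensor: onnx.TensorProto, threshold: int) -> bool:
    # Hypothetical: treat a tensor as large when its raw payload exceeds
    # the threshold in bytes (typed fields may be used instead of raw_data).
    return len(tensor.raw_data) > threshold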
Example #14
    trainer = engine.create_trainer(
        model_with_loss,
        {"0": optimizer0, "1": optimizer1},
        20,
        device=device,
        evaluator=evaluator,
        extensions=extensions,
        out_dir=path,
        logic=ppe.handler.CodeBlockLogic())
    trainer.run(data, data)


@pytest.mark.skipif(
    os.name == 'nt' and not ppe.requires("1.9"),
    reason='torch.profiler.profile is not supported.',
)
def test_trainer_profile():
    device = 'cpu'
    model = MyModel()
    model_with_loss = MyModelWithLossDictOutput(model)
    ppe.to(model_with_loss, device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    data = torch.utils.data.DataLoader([{
        'x': torch.rand(20, ),
        't': torch.rand(10, )
    } for i in range(10)])
    extensions = _make_extensions()

    evaluator = engine.create_evaluator(model_with_loss, device=device)