示例#1
0
                          opset_version=self.opset_version,
                          training=torch.onnx.TrainingMode.TRAINING)
        ort_sess = onnxruntime.InferenceSession(f.getvalue())
        ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
        ort_outs = ort_sess.run(None, ort_inputs)
        assert x != ort_outs[0]


# Per-opset test classes: each ``type(...)`` call builds a ``TestCase``
# subclass carrying all attributes of ``TestUtilityFuns`` with
# ``opset_version`` overridden. Each one is bound to its own module-level
# name so the test runner collects the suite once per opset version.

# opset 10 tests
TestUtilityFuns_opset10 = type(
    str("TestUtilityFuns_opset10"), (TestCase, ),
    dict(TestUtilityFuns.__dict__, opset_version=10))

# opset 11 tests
TestUtilityFuns_opset11 = type(
    str("TestUtilityFuns_opset11"), (TestCase, ),
    dict(TestUtilityFuns.__dict__, opset_version=11))

# opset 12 tests
TestUtilityFuns_opset12 = type(
    str("TestUtilityFuns_opset12"), (TestCase, ),
    dict(TestUtilityFuns.__dict__, opset_version=12))
# NOTE(review): if opset 13 coverage is wanted, add a distinct
# ``TestUtilityFuns_opset13`` entry here (opset_version=13). Re-binding an
# existing name only replaces the earlier class and adds no coverage.

if __name__ == '__main__':
    # Only when this file is executed directly (not imported) do we hand
    # control to the shared test entry point.
    run_tests()
示例#2
0
            assert isinstance(x, torch._C.Value)
            assert isinstance(y[0], torch._C.Value)
            assert isinstance(y[1], torch._C.Value)
            return g.op('Sum', x, y[0], y[1]), (
                g.op('Neg', x), g.op('Neg', y[0]))

        @torch.onnx.symbolic_override_first_arg_based(symb)
        def foo(x, y):
            return x + y[0] + y[1], (-x, -y[0])

        class BigModule(torch.nn.Module):
            def forward(self, x, y):
                return foo(x, y)

        inp = (Variable(torch.FloatTensor([1])),
               (Variable(torch.FloatTensor([2])),
                Variable(torch.FloatTensor([3]))))
        BigModule()(*inp)
        self.assertONNX(BigModule(), inp)


if __name__ == '__main__':
    # ``--onnx-test`` is a flag private to this script. Record whether it was
    # passed (``_onnx_test`` stays a module-level name), then strip it from
    # ``common.UNITTEST_ARGS`` before calling ``run_tests()``.
    # NOTE(review): presumably ``run_tests()`` hands UNITTEST_ARGS to the
    # unittest runner, which would reject the unknown flag; confirm.
    onnx_test_flag = '--onnx-test'
    _onnx_test = onnx_test_flag in common.UNITTEST_ARGS
    if _onnx_test:
        # ``list.remove`` drops only the first match, so loop until every
        # occurrence is gone. Mutating the list in place keeps any other
        # reference to ``common.UNITTEST_ARGS`` consistent.
        while onnx_test_flag in common.UNITTEST_ARGS:
            common.UNITTEST_ARGS.remove(onnx_test_flag)
        # Delete existing ``test_operator_*`` entries under the operator
        # directory (presumably so this run regenerates them fresh).
        pattern = os.path.join(test_onnx_common.pytorch_operator_dir, "test_operator_*")
        for d in glob.glob(pattern):
            shutil.rmtree(d)
    run_tests()