Beispiel #1
0
def make_model(graph, **kwargs):
    """Build a ModelProto that holds a copy of ``graph``.

    Args:
        graph: Graph message that is copied into ``model.graph`` via
            ``CopyFrom``.
        **kwargs: Extra attributes; each ``name=value`` pair is assigned onto
            the model with ``setattr`` after the defaults below are applied,
            so a keyword may overwrite them.

    Returns:
        The newly created ModelProto.
    """
    proto = ModelProto()
    # Assign ir_version explicitly so the model records the IR version it was
    # generated from.
    proto.ir_version = IR_VERSION
    proto.graph.CopyFrom(graph)

    for field_name, field_value in kwargs.items():
        setattr(proto, field_name, field_value)
    return proto
Beispiel #2
0
    def test_load(self):
        """Check that a serialized model loads back equal to the original.

        The same serialized bytes are loaded three ways: directly from the
        byte string, from an object with a ``read`` method, and from a file
        name on disk.
        """
        # Create a model proto.
        model = ModelProto()
        model.ir_version = IR_VERSION
        model_string = model.SerializeToString()

        # Test if input is string
        self.assertEqual(model, onnx.load_from_string(model_string))

        # Test if input has a read function
        self.assertEqual(model, onnx.load(io.BytesIO(model_string)))

        # Test if input is a file name
        # delete=False keeps the file on disk after it is closed, so it can be
        # loaded again by name; the finally clause removes it even when the
        # write, the load, or the assertion fails.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            with f:
                f.write(model_string)
            self.assertEqual(model, onnx.load(f.name))
        finally:
            os.remove(f.name)
Beispiel #3
0
 def test_version_exists(self):
     """Check that ir_version is unset at creation and survives a round trip."""
     proto = ModelProto()
     # A freshly constructed model must not report ir_version as set.
     version_set_initially = proto.HasField('ir_version')
     self.assertFalse(version_set_initially)

     # Assign the version explicitly so the model is annotated with the
     # IR version of the running ONNX.
     proto.ir_version = IR_VERSION

     # Serialize, then parse the bytes back into the same message object.
     proto.ParseFromString(proto.SerializeToString())

     version_set_after = proto.HasField('ir_version')
     self.assertTrue(version_set_after)
     # The value read back must match what was assigned.
     self.assertEqual(proto.ir_version, IR_VERSION)