Ejemplo n.º 1
0
    def test_save_and_load(self):
        """Round-trip a compiled graph through ``flow.save`` and the MLIR loader.

        Compiles ``InferGraph`` on a CPU placeholder input, saves it under
        ``saved_model/<graph.name>``, reloads ``model.mlir`` via
        ``LoadSerializedJobFromIR`` and checks that the reloaded job holds the
        same ops (matched by name) as ``graph._forward_job_proto``.
        """
        placement_arg = {
            "placement": flow.placement("cuda", ranks=[0]),
            "sbp": flow.sbp.broadcast,
        }
        graph = InferGraph(placement_arg)
        image_placeholder = flow.empty(
            (1, 3, 224, 224),
            dtype=flow.float32,
            placement=flow.placement("cpu", ranks=[0]),
            sbp=flow.sbp.broadcast,
        )
        graph._compile(image_placeholder)
        saved_path = os.path.join("saved_model", graph.name)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(saved_path, exist_ok=True)
        flow.save(graph, saved_path)

        saved_ir_path = os.path.join(saved_path, "model.mlir")
        serialized_job = oneflow._oneflow_internal.nn.graph.LoadSerializedJobFromIR(
            saved_ir_path
        )
        job = job_pb.Job()
        job.ParseFromString(serialized_job)

        def by_name(op):
            return op.name

        # Compare in op-name order rather than emission order.
        loaded_ops = sorted(job.net.op, key=by_name)
        expected_ops = sorted(graph._forward_job_proto.net.op, key=by_name)

        # zip() below stops at the shorter list. Compare the name lists first
        # so that a missing or extra op fails the test instead of being skipped.
        self.assertEqual(
            [op.name for op in loaded_ops],
            [op.name for op in expected_ops],
        )

        for op, op_ in zip(loaded_ops, expected_ops):
            # TODO: convert loc in MLIR
            # Until `loc` survives the MLIR round trip, clear it on the expected
            # op before comparing. NOTE(review): this mutates the message inside
            # graph._forward_job_proto in place, as before.
            op_.ClearField("loc")
            self.assertEqual(op, op_)
    def test_save_and_load(self):
        """Save a compiled ``InferGraph`` and reload its MLIR as a ``job_pb.Job``.

        Compiles the graph on a CPU placeholder input, saves it under
        ``saved_model/<graph.name>``, and parses ``model.mlir`` back via
        ``LoadSerializedJobFromIR``. It then checks that the reloaded job has
        the same op names as the graph's forward job.

        NOTE(review): if another ``test_save_and_load`` is defined in the same
        class, the later definition replaces the earlier one and only one of
        them runs. Confirm and rename one if so.
        """
        placement_arg = {
            "placement": flow.placement("cuda", ranks=[0]),
            "sbp": flow.sbp.broadcast,
        }
        graph = InferGraph(placement_arg)
        image_placeholder = flow.empty(
            (1, 3, 224, 224),
            dtype=flow.float32,
            placement=flow.placement("cpu", ranks=[0]),
            sbp=flow.sbp.broadcast,
        )
        graph._compile(image_placeholder)
        saved_path = os.path.join("saved_model", graph.name)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(saved_path, exist_ok=True)
        flow.save(graph, saved_path)

        saved_ir_path = os.path.join(saved_path, "model.mlir")
        serialized_job = oneflow._oneflow_internal.nn.graph.LoadSerializedJobFromIR(
            saved_ir_path
        )
        job = job_pb.Job()
        job.ParseFromString(serialized_job)

        # Previously the loaded job was never checked, so the test only proved
        # that nothing raised. Verify that the round trip kept every op.
        self.assertEqual(
            sorted(op.name for op in job.net.op),
            sorted(op.name for op in graph._forward_job_proto.net.op),
        )
Ejemplo n.º 3
0
def GetCurrentJob():
    """Return the current job decoded into a new ``job_pb.Job`` message.

    The runtime hands back the job as serialized bytes. They are decoded with
    ``ParseFromString``, i.e. as binary protobuf wire format.
    """
    job = job_pb.Job()
    job.ParseFromString(oneflow._oneflow_internal.GetSerializedCurrentJob())
    return job
Ejemplo n.º 4
0
def GetCurrentJob():
    """Return the current job decoded into a new ``job_pb.Job`` message.

    This variant treats the runtime's serialized job as protobuf text format
    and fills the message with ``text_format.Parse``.
    """
    job = job_pb.Job()
    # text_format.Parse fills `job` in place and returns that same object.
    text_format.Parse(oneflow._oneflow_internal.GetSerializedCurrentJob(), job)
    return job