Exemple #1
0
def is_trainable(ctx):
    """Look up the function descriptor that applies to the current job.

    The context must be in global mode (enforced with an assert). When eager
    execution is enabled the default session's current eager global function
    descriptor is returned; otherwise the descriptor registered under the
    name of the job currently being built is returned.

    NOTE(review): despite the ``is_*`` name, the visible body returns whatever
    the session lookup returns rather than an explicit boolean - confirm
    against callers.
    """
    assert in_global_mode(ctx)
    internal = oneflow._oneflow_internal
    if not internal.EagerExecutionEnabled():
        current_job = internal.JobBuildAndInferCtx_GetCurrentJobName()
        return session_ctx.GetDefaultSession().GetFunctionDesc(current_job)
    return session_ctx.GetDefaultSession().CurrentEagerGlobalFunctionDesc()
Exemple #2
0
 def Save(self, save_model_before_graph_complete: bool = True):
     """Finish pending graphs and write the saved model into a new version directory.

     The model is written to ``<saved_model_dir_>/<version_>/`` as a
     checkpoint directory, a binary protobuf file and a text-format
     protobuf file.

     Args:
         save_model_before_graph_complete: When true, the ops of the job
             named after each graph are recorded; otherwise the job named
             ``<graph_name>_after_complete`` is used.

     Raises:
         ValueError: If no model version has been set, or if the directory
             for this version already exists.
     """
     self._check_input_output_name_conflict()

     # Validate the destination before mutating any state: failing later
     # would leave an empty model directory behind and, on a retry, append
     # every job's ops to ``op_list`` a second time.
     if self.version_ is None:
         raise ValueError("model version is not set")
     version_dir = os.path.join(self.saved_model_dir_, str(self.version_))
     if os.path.exists(version_dir):
         raise ValueError(
             'Directory of model "{}" version "{}" already exist.'.format(
                 self.saved_model_dir_, self.version_))

     for graph_builder in self.graph_builders_.values():
         if not graph_builder.finished:
             graph_builder.Finish()

     sess = session_ctx.GetDefaultSession()
     for (graph_name, graph_def) in self.proto.graphs.items():
         job = sess.Job(graph_name if save_model_before_graph_complete else
                        graph_name + "_after_complete")
         graph_def.op_list.extend(list(job.net.op))

     # ``exist_ok`` avoids the check-then-create race of a separate
     # ``os.path.exists`` test. The version directory itself must be new,
     # so it is still created without ``exist_ok``.
     os.makedirs(self.saved_model_dir_, exist_ok=True)
     os.makedirs(version_dir)

     self.proto.version = self.version_
     checkpoint_path = os.path.join(version_dir, self.checkpoint_dir_)
     flow.checkpoint.save(checkpoint_path)
     self.proto.checkpoint_dir = self.checkpoint_dir_

     saved_model_pb_path = os.path.join(version_dir,
                                        self.saved_model_pb_filename_)
     with open(saved_model_pb_path, "wb") as writer:
         writer.write(self.saved_model_proto_.SerializeToString())

     saved_model_pbtxt_path = os.path.join(version_dir,
                                           self.saved_model_pbtxt_filename_)
     with open(saved_model_pbtxt_path, "wt") as writer:
         writer.write(text_format.MessageToString(self.saved_model_proto_))
Exemple #3
0
def make_new_block_scope(prev_scope, block):
    """Derive a scope for ``block`` from ``prev_scope``.

    The derived scope carries the block's pipeline stage hint and activation
    checkpointing setting (each only when it is not ``None``) and has its op
    name prefixes replaced by ``block.name_prefix + block.name``.

    Args:
        prev_scope: Scope to derive from; must not be ``None``.
        block: Block supplying ``config``, ``name_prefix`` and ``name``;
            must not be ``None``.

    Returns:
        The scope produced by the builder's ``BuildScopeByProtoSetter``.
    """
    assert prev_scope is not None
    assert block is not None

    config = block.config
    scope_attrs = {}
    if config.stage_id is not None:
        scope_attrs["pipeline_stage_id_hint"] = config.stage_id
    if config.activation_checkpointing is not None:
        scope_attrs["checkpointing"] = config.activation_checkpointing

    default_vals = session_context.GetDefaultSession().scope_attr_name2default_val

    def _apply_block_settings(proto):
        # Copy every collected attribute into the scope proto.
        for key, value in scope_attrs.items():
            assert key in default_vals
            target = proto.mutable_attr_name2attr_value()[key]
            attr_util.SetAttrValue(target, value, default_vals[key])
        # Replace (not extend) the op name prefixes with this block's name.
        proto.clear_scope_op_name_prefixes()
        proto.add_scope_op_name_prefixes(block.name_prefix + block.name)

    derived_scope = None

    def _build(builder):
        nonlocal derived_scope
        derived_scope = builder.BuildScopeByProtoSetter(
            prev_scope, _apply_block_settings
        )
        assert derived_scope is not None

    oneflow._oneflow_internal.deprecated.PhysicalRun(_build)
    oneflow._oneflow_internal.eager.Sync()
    return derived_scope
 def test_case2(self):
     """Check that the default session is a MultiClientSession that initializes.

     After ``TryInit`` the session status is expected to equal
     ``Status.INITED``.
     """
     print("test_case2")
     self.assertTrue(flow.env.is_multi_client())
     default_session = session_ctx.GetDefaultSession()
     self.assertTrue(isinstance(default_session, MultiClientSession))
     default_session.TryInit()
     self.assertEqual(default_session.status, default_session.Status.INITED)
Exemple #5
0
    def __getattr__(
        self, attr_name: str
    ) -> Callable[[Optional[Union[bool, int, float, str]]], None]:
        """Return a setter for the function config flag called ``attr_name``.

        The flag must be one of the names in the default session's
        ``function_flag_name2default_val``; the type of its default value
        decides which type the returned setter accepts.

        Args:
            attr_name: Name of the config flag.

        Returns:
            A callable taking the new flag value and storing it in
            ``self.function_desc.job_config_proto.flag_name2flag_value``.
            For boolean flags the value may be omitted and defaults to
            ``True``.

        Raises:
            AttributeError: If ``attr_name`` is not a known config flag.
        """
        # ``__getattr__`` only runs after normal lookup failed, so being asked
        # for ``function_desc`` means it is not set yet (e.g. on an instance
        # created without ``__init__`` by copy/pickle). Reading
        # ``self.function_desc`` below would then recurse without end.
        if attr_name == "function_desc":
            raise AttributeError(attr_name)
        name2default = session_ctx.GetDefaultSession(
        ).function_flag_name2default_val
        # Unknown names must raise AttributeError (not AssertionError) so that
        # ``hasattr``, ``getattr(obj, name, default)`` and the ``copy`` module
        # keep working on this object.
        if attr_name not in name2default:
            raise AttributeError("%s has no config flag `%s'" %
                                 (type(self).__name__, attr_name))
        flag_name2flag_value = self.function_desc.job_config_proto.flag_name2flag_value
        default_val = name2default[attr_name]

        def FunctionConfigSetter(
                attr_value: Optional[Union[bool, int, float,
                                           str]] = None) -> None:
            # Exact ``type(...) is`` checks are intentional: ``bool`` is a
            # subclass of ``int`` and must not be accepted for int64 flags.
            if default_val.HasField("at_bool"):
                if attr_value is None:
                    attr_value = True
                assert type(attr_value) is bool
                flag_name2flag_value[attr_name].at_bool = attr_value
            elif default_val.HasField("at_int64"):
                assert type(attr_value) is int
                flag_name2flag_value[attr_name].at_int64 = attr_value
            elif default_val.HasField("at_double"):
                assert type(attr_value) is float
                flag_name2flag_value[attr_name].at_double = attr_value
            elif default_val.HasField("at_string"):
                assert type(attr_value) is str
                flag_name2flag_value[attr_name].at_string = attr_value
            else:
                raise NotImplementedError(
                    "config_flag `%s' with type %s is not supported" %
                    (attr_name, type(attr_value)))

        return FunctionConfigSetter
Exemple #6
0
def api_scope_config(**kwargs):
    """Create a scope context in which the given scope attributes are set.

    Args:
        **kwargs: Scope attribute names mapped to their Python values. Each
            name must be present in the default session's
            ``scope_attr_name2default_val`` (checked with an assert when the
            scope proto is populated).

    Returns:
        A ``ScopeContext`` wrapping the scope built from the previous scope
        with the attributes applied.
    """
    name2default = session_ctx.GetDefaultSession().scope_attr_name2default_val

    def SetScopeProto(scope_proto):
        for (attr_name, py_value) in kwargs.items():
            assert attr_name in name2default
            attr_util.SetAttrValue(
                scope_proto.mutable_attr_name2attr_value()[attr_name],
                py_value,
                name2default[attr_name],
            )

    # The default session was already fetched above; the former second
    # lookup only fed an unused local and has been dropped.
    scope = MakeScope(lambda old_scope, builder: builder.
                      BuildScopeByProtoSetter(old_scope, SetScopeProto))
    return ScopeContext(scope)
Exemple #7
0
 def test_feed_input_tensor(test_case):
     """Feed an eager tensor through a FeedInputOpExpr inside a lazy job.

     Opens the job ``cc_test_input_op_expr_job`` in lazy mode, dispatches the
     feed-input op on a 1x1x10x10 tensor and checks that the result keeps the
     shape and is reported as lazy and local.
     """
     test_case.assertTrue(
         oneflow.framework.env_util.HasAllMultiClientEnvVars())
     eager_input = flow.Tensor(1, 1, 10, 10)
     flow.nn.init.uniform_(eager_input, a=-1.0, b=1.0)
     default_session = session_ctx.GetDefaultSession()
     test_case.assertTrue(isinstance(default_session, MultiClientSession))
     default_session.TryInit()
     with oneflow._oneflow_internal.lazy_mode.guard(True):
         internal = oneflow._oneflow_internal
         internal.JobBuildAndInferCtx_Open("cc_test_input_op_expr_job")

         # Minimal predict job configuration for the opened job.
         job_conf = internal.oneflow.core.job.job_conf.JobConfigProto()
         job_conf.set_job_name("cc_test_input_op_expr_job")
         job_conf.mutable_predict_conf()
         c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)

         # Build the feed-input op expression.
         feed_op_name = "cc_Input_0"
         feed_conf = internal.oneflow.core.operator.op_conf.FeedInputOpConf()
         feed_conf.set_in_0("EagerTensorInput")
         feed_conf.set_out_0("out_0")
         feed_op = internal.one.FeedInputOpExpr(
             feed_op_name, feed_conf, ["in_0"], ["out_0"])
         # Not passed to the dispatch call below; construction is retained
         # as-is.
         attrs = internal.MutableCfgAttrMap()

         lazy_out = _C.dispatch_feed_input(feed_op, eager_input)
         test_case.assertEqual(lazy_out.shape, (1, 1, 10, 10))
         test_case.assertTrue(lazy_out.is_lazy)
         test_case.assertTrue(lazy_out.is_local)
         internal.JobBuildAndInferCtx_Close()
 def test_feed_var_tensor(test_case):
     """Feed an eager tensor through a FeedVariableOpExpr inside a lazy job.

     Opens the job ``cc_test_variable_op_expr_job`` in lazy mode, dispatches
     the feed-variable op (with ``l2=0``) on a 1x1x10x10 tensor and checks
     that the result keeps the shape and is reported as lazy and local.
     """
     test_case.assertTrue(
         oneflow.framework.env_util.HasAllMultiClientEnvVars())
     eager_input = flow.Tensor(1, 1, 10, 10)
     flow.nn.init.uniform_(eager_input, a=-1.0, b=1.0)
     default_session = session_ctx.GetDefaultSession()
     test_case.assertTrue(isinstance(default_session, MultiClientSession))
     default_session.TryInit()
     with oneflow._oneflow_internal.lazy_mode.guard(True):
         oneflow._oneflow_internal.JobBuildAndInferCtx_Open(
             "cc_test_variable_op_expr_job")

         # Minimal predict job configuration for the opened job.
         job_conf = oneflow.core.job.job_conf_pb2.JobConfigProto()
         job_conf.job_name = "cc_test_variable_op_expr_job"
         job_conf.predict_conf.SetInParent()
         c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)

         # The op expression takes its conf serialized in text format.
         variable_op_name = "cc_Variable_0"
         variable_conf = oneflow.core.operator.op_conf_pb2.FeedVariableOpConf()
         variable_conf.in_0 = "EagerTensorInput"
         variable_conf.out_0 = "out_0"
         serialized_conf = text_format.MessageToString(variable_conf)
         variable_op = oneflow._oneflow_internal.one.FeedVariableOpExpr(
             variable_op_name, serialized_conf, ["in_0"], ["out_0"])

         lazy_out = _C.dispatch_feed_variable(variable_op, eager_input, l2=0)
         test_case.assertEqual(lazy_out.shape, (1, 1, 10, 10))
         test_case.assertTrue(lazy_out.is_lazy)
         test_case.assertTrue(lazy_out.is_local)
         oneflow._oneflow_internal.JobBuildAndInferCtx_Close()
Exemple #9
0
 def HasAttr(self, attr_name):
     """Tell whether ``attr_name`` is available on the job config.

     ``"flag_name2flag_value"`` itself is never reported. Any other name is
     available when it is a key of the proto's ``flag_name2flag_value`` map
     or when the proto's ``HasField`` reports it.
     """
     if attr_name == "flag_name2flag_value":
         return False
     # NOTE(review): the looked-up defaults are not used below; the lookup is
     # kept so the default session is still accessed exactly as before.
     flag_defaults = session_ctx.GetDefaultSession().function_flag_name2default_val
     job_conf = self.job_config_proto
     return (attr_name in job_conf.flag_name2flag_value
             or job_conf.HasField(attr_name))
Exemple #10
0
def _set_attr_to_resource(attr_name, attr_value):
    """Set one resource attribute on the default session.

    Before the session reaches ``Status.INITED`` the value is written straight
    into ``config_proto.resource``. Once it is initialized, a fresh
    ``Resource`` carrying only this attribute is handed to
    ``update_resource_eagerly`` instead.
    """
    session = session_ctx.GetDefaultSession()
    if session.status_ != session.Status.INITED:
        setattr(session.config_proto.resource, attr_name, attr_value)
        return
    resource_patch = resource_util.Resource()
    setattr(resource_patch, attr_name, attr_value)
    session.update_resource_eagerly(resource_patch)
Exemple #11
0
 def HasAttr(self, attr_name):
     """Tell whether ``attr_name`` is available on the job config.

     ``"flag_name2flag_value"`` itself is never reported. Any other name is
     available when it is contained in the result of the proto's
     ``flag_name2flag_value()`` accessor or when the proto's
     ``has_<attr_name>()`` method reports it.
     """
     if attr_name == "flag_name2flag_value":
         return False
     # NOTE(review): the looked-up defaults are not used below; the lookup is
     # kept so the default session is still accessed exactly as before.
     flag_defaults = session_ctx.GetDefaultSession().function_flag_name2default_val
     job_conf = self.job_config_proto
     if attr_name in job_conf.flag_name2flag_value():
         return True
     presence_check = getattr(job_conf, "has_" + attr_name)
     return presence_check()
Exemple #12
0
    def __init__(self):
        """
        Set up the internal bookkeeping state of a Graph. Subclasses MUST call it from their own ``__init__``.

        For example:

        .. code-block:: python

            >>> import oneflow as flow
            >>> class SubclassGraph(flow.nn.Graph):
            ...     def __init__(self):
            ...         super().__init__() # MUST be called
            ...         # Then define the graph attributes
            ...     def build(self):
            ...         pass

        """
        self._generate_name()

        # Configuration and registered components.
        self.config = GraphConfig()
        self._blocks = OrderedDict()
        self._opts = list()
        self._verbose = False
        self._grad_scaler = None
        self._variables_conf = OrderedDict()
        self._additional_variable_tobe_loaded = OrderedDict()

        # Compilation state; a new graph starts uncompiled and in local view.
        self._is_compiled = False
        self._is_global_view = False
        # Job proto of the forward graph only.
        self._forward_job_proto = None
        # Job proto of the forward, backward and optimized graph.
        self._full_job_proto = None
        self._args_repr = list()
        self._outs_repr = list()

        # Debug settings.
        self._debug = False
        self._debug_min_s_level = 2
        self._debug_max_v_level = 0
        self._debug_max_py_stack_depth = 2

        # Output buffering.
        self._outputs_buffer_size = 2
        self._cur_index_of_ouputs_buffer = 0

        # Bind to the default session, which must be a MultiClientSession,
        # and create the underlying C++ graph object on it.
        default_session = session_ctx.GetDefaultSession()
        self._session = default_session
        assert type(default_session) is MultiClientSession
        default_session.TryInit()
        self._c_nn_graph = oneflow._oneflow_internal.nn.graph.CNNGraph(
            self._name, default_session._session_ctx)
Exemple #13
0
            def build(self, x):
                """Assert the graph-building environment, then return ``x`` unchanged.

                Checks that lazy mode is enabled, that the default session is
                a MultiClientSession whose id matches the current scope's
                session id, and that the job being built is named after this
                graph.
                """
                test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), True)
                import oneflow.framework.session_context as session_ctx
                from oneflow.framework.multi_client_session import MultiClientSession

                default_session = session_ctx.GetDefaultSession()
                test_case.assertEqual(type(default_session), MultiClientSession)
                import oneflow.framework.scope_util as scope_util

                current_scope_proto = graph_build_util.scope_to_proto(
                    scope_util.current_scope()
                )
                test_case.assertEqual(
                    default_session.id, current_scope_proto.session_id
                )
                active_job_name = (
                    oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
                )
                test_case.assertEqual(active_job_name, self.name)
                return x
def make_new_block_scope(prev_scope, block):
    """Derive a scope for ``block`` from ``prev_scope`` via a serialized proto.

    The builder hands the setter a text-format ``ScopeProto``; the setter
    parses it, applies the block's pipeline stage hint and activation
    checkpointing setting (each only when it is not ``None``), replaces the
    op name prefixes with ``block.name_prefix + block.name``, records the same
    string as module name for ``ModuleBlock`` instances, and returns the proto
    serialized back to text.

    Args:
        prev_scope: Scope to derive from; must not be ``None``.
        block: Block supplying ``config``, ``name_prefix`` and ``name``;
            must not be ``None``.

    Returns:
        The scope produced by the builder's ``BuildScopeByProtoStrSetter``.
    """
    assert prev_scope is not None
    assert block is not None

    config = block.config
    scope_attrs = {}
    if config.stage_id is not None:
        scope_attrs["pipeline_stage_id_hint"] = config.stage_id
    if config.activation_checkpointing is not None:
        scope_attrs["checkpointing"] = config.activation_checkpointing

    default_vals = session_context.GetDefaultSession().scope_attr_name2default_val

    def _rewrite_scope_text(serialized_scope_proto: str):
        proto = text_format.Parse(serialized_scope_proto,
                                  scope_pb2_util.ScopeProto())
        # Copy every collected attribute into the scope proto.
        for key, value in scope_attrs.items():
            assert key in default_vals
            target = proto.attr_name2attr_value[key]
            attr_util.SetProtoAttrValue(target, value, default_vals[key])
        # Replace (not extend) the op name prefixes with this block's name.
        proto.ClearField("scope_op_name_prefixes")
        proto.scope_op_name_prefixes.append(block.name_prefix + block.name)
        # Only module blocks carry a module name.
        if isinstance(block, oneflow.nn.graph.block.ModuleBlock):
            proto.module_name = block.name_prefix + block.name

        return str(text_format.MessageToString(proto))

    derived_scope = None

    def _build(builder):
        nonlocal derived_scope
        derived_scope = builder.BuildScopeByProtoStrSetter(
            prev_scope, _rewrite_scope_text)
        assert derived_scope is not None

    oneflow._oneflow_internal.deprecated.PhysicalRun(_build)
    oneflow._oneflow_internal.eager.Sync()
    return derived_scope
Exemple #15
0
 def __getattr__(self, attr_name):
     """Resolve ``attr_name`` against the job config.

     Names listed in the default session's ``function_flag_name2default_val``
     are config flags: the value stored in the proto's
     ``flag_name2flag_value`` map is returned, falling back to the flag's
     default. Any other name is read from ``job_config_proto`` itself,
     provided ``HasField`` reports it.

     Raises:
         AttributeError: If the name is ``flag_name2flag_value`` or
             ``job_config_proto``, or is not a field reported by ``HasField``
             on the job config proto.
         NotImplementedError: If the flag value holds none of the supported
             ``at_bool`` / ``at_int64`` / ``at_double`` / ``at_string``
             fields.
     """
     # ``__getattr__`` only runs after normal lookup failed. Being asked for
     # ``job_config_proto`` therefore means it is not set (e.g. on an instance
     # created without ``__init__`` by copy/pickle), and reading it below
     # would recurse without end. ``flag_name2flag_value`` is never exposed,
     # matching ``HasAttr`` which reports it as absent.
     if attr_name in ("job_config_proto", "flag_name2flag_value"):
         raise AttributeError(attr_name)
     job_config_proto = self.job_config_proto
     flag_name2flag_value = job_config_proto.flag_name2flag_value
     name2default = session_ctx.GetDefaultSession().function_flag_name2default_val
     if attr_name not in name2default:
         # Missing attributes must surface as AttributeError so ``hasattr``
         # and ``getattr(obj, name, default)`` behave normally. ``HasField``
         # raises ValueError for names that are not fields of the message.
         try:
             is_set = job_config_proto.HasField(attr_name)
         except ValueError as err:
             raise AttributeError(
                 "job config has no attribute `%s'" % attr_name) from err
         if not is_set:
             raise AttributeError(
                 "job config attribute `%s' is not set" % attr_name)
         return getattr(job_config_proto, attr_name)
     # An explicitly configured value overrides the flag's default.
     attr_value = name2default[attr_name]
     if attr_name in flag_name2flag_value:
         attr_value = flag_name2flag_value[attr_name]
     if attr_value.HasField("at_bool"):
         return attr_value.at_bool
     elif attr_value.HasField("at_int64"):
         return attr_value.at_int64
     elif attr_value.HasField("at_double"):
         return attr_value.at_double
     elif attr_value.HasField("at_string"):
         return attr_value.at_string
     else:
         raise NotImplementedError(
             "config flag `%s' holds an unsupported value type" % attr_name)
Exemple #16
0
def nccl_fusion_reduce_scatter(val):
    """Store ``val`` as ``nccl_fusion_reduce_scatter`` in the default session.

    The value is written to the session's ``collective_boxing_conf`` and must
    be exactly of type ``bool`` (enforced with an assert).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    boxing_conf = session.config_proto.resource.collective_boxing_conf
    boxing_conf.nccl_fusion_reduce_scatter = val
Exemple #17
0
def load_library(val):
    """Append ``val`` to the default session's ``load_lib_path`` list.

    ``val`` must be exactly of type ``str`` (enforced with an assert before
    the session is looked up).
    """
    assert type(val) is str
    library_paths = session_ctx.GetDefaultSession().config_proto.load_lib_path
    library_paths.append(val)
Exemple #18
0
def enable_tensor_float_32_compute(val=True):
    """Store ``val`` as ``enable_tensor_float_32_compute`` in the default session.

    ``val`` must be exactly of type ``bool`` (enforced with an assert). When
    it is ``False`` the environment variable
    ``ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION`` is additionally set to ``"0"``.

    NOTE(review): passing ``True`` leaves the environment variable untouched,
    so a previous ``False`` call is not undone there - confirm this is
    intended.
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    resource = session.config_proto.resource
    resource.enable_tensor_float_32_compute = val
    if val:
        return
    os.environ["ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION"] = "0"
Exemple #19
0
def enable_mem_chain_merge(val=True):
    """Store ``val`` as ``enable_mem_chain_merge`` in the default session.

    The value is written to the session's resource config and must be exactly
    of type ``bool`` (enforced with an assert).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    resource = session.config_proto.resource
    resource.enable_mem_chain_merge = val
Exemple #20
0
def enable_fusion(val=True):
    """Store ``val`` as ``enable_fusion`` in the default session.

    The value is written to the session's ``collective_boxing_conf`` and must
    be exactly of type ``bool`` (enforced with an assert).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    boxing_conf = session.config_proto.resource.collective_boxing_conf
    boxing_conf.enable_fusion = val
Exemple #21
0
def num_callback_threads(val):
    """Store ``val`` as ``num_callback_threads`` in the default session.

    The value is written to the session's ``collective_boxing_conf`` and must
    be exactly of type ``int`` (enforced with an assert).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    boxing_conf = session.config_proto.resource.collective_boxing_conf
    boxing_conf.num_callback_threads = val
Exemple #22
0
def enable_model_io_v2(val):
    """Store ``val`` as ``enable_model_io_v2`` in the default session.

    The value is written to the session's resource config and must be exactly
    of type ``bool`` (enforced with an assert).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    resource = session.config_proto.resource
    resource.enable_model_io_v2 = val
Exemple #23
0
def nccl_fusion_broadcast(val):
    """Store ``val`` as ``nccl_fusion_broadcast`` in the default session.

    The value is written to the session's ``collective_boxing_conf`` and must
    be exactly of type ``bool`` (enforced with an assert).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    boxing_conf = session.config_proto.resource.collective_boxing_conf
    boxing_conf.nccl_fusion_broadcast = val
Exemple #24
0
def api_legacy_model_io_enabled():
    """Return ``enable_legacy_model_io`` from the default session's resource config."""
    resource = session_ctx.GetDefaultSession().config_proto.resource
    return resource.enable_legacy_model_io
Exemple #25
0
def enable_cudnn_fused_normalization_add_relu(val):
    """Store ``val`` as ``enable_cudnn_fused_normalization_add_relu`` in the default session.

    The value is written to the session's ``cudnn_conf`` and must be exactly
    of type ``bool`` (enforced with an assert).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    cudnn_conf = session.config_proto.resource.cudnn_conf
    cudnn_conf.enable_cudnn_fused_normalization_add_relu = val
Exemple #26
0
def nccl_enable_mixed_fusion(val):
    """Store ``val`` as ``nccl_enable_mixed_fusion`` in the default session.

    The value is written to the session's ``collective_boxing_conf`` and must
    be exactly of type ``bool`` (enforced with an assert).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    boxing_conf = session.config_proto.resource.collective_boxing_conf
    boxing_conf.nccl_enable_mixed_fusion = val
Exemple #27
0
def nccl_fusion_max_ops(val):
    """Store ``val`` as ``nccl_fusion_max_ops`` in the default session.

    The value is written to the session's ``collective_boxing_conf`` and must
    be exactly of type ``int`` (enforced with an assert).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    boxing_conf = session.config_proto.resource.collective_boxing_conf
    boxing_conf.nccl_fusion_max_ops = val
Exemple #28
0
def compute_thread_pool_size(val):
    """Store ``val`` as ``compute_thread_pool_size`` in the default session.

    The value is written to the session's resource config and must be exactly
    of type ``int`` (enforced with an assert).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    resource = session.config_proto.resource
    resource.compute_thread_pool_size = val
Exemple #29
0
def machine_num(val):
    """Store ``val`` as ``machine_num`` in the default session.

    The value is written to the session's resource config and must be exactly
    of type ``int`` (enforced with an assert).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    resource = session.config_proto.resource
    resource.machine_num = val
Exemple #30
0
def reserved_device_mem_mbyte(val):
    """Store ``val`` as ``reserved_device_mem_mbyte`` in the default session.

    The value is written to the session's resource config and must be exactly
    of type ``int`` (enforced with an assert).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    resource = session.config_proto.resource
    resource.reserved_device_mem_mbyte = val