Example #1
def global_function_or_identity(*args, **kwargs):
    if rt_mode.CurrentMode() == rt_mode.NORMAL_MODE:
        # Normal mode: defer to the real decorator factory.
        return flow.global_function(*args, **kwargs)
    else:
        assert rt_mode.CurrentMode() == rt_mode.GLOBAL_MODE
        # Global mode: hand back a no-op decorator so the wrapped
        # function is left untouched.
        identity_decorator = lambda func: func
        return identity_decorator
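A minimal usage sketch (the job function below is hypothetical): because flow.global_function(...) itself returns a decorator, global_function_or_identity can be applied the same way in both modes, and in GLOBAL_MODE the identity decorator simply leaves the function as plain Python.

def _usage_sketch():
    # In NORMAL_MODE this compiles probe_job as an oneflow global function;
    # in GLOBAL_MODE it stays an ordinary Python function.
    @global_function_or_identity()
    def probe_job():
        pass

    return probe_job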
def sync_default_session_if_normal():
    # TODO merge with same function in framework/check_point_v2.py
    if rt_mode.CurrentMode() == rt_mode.NORMAL_MODE:
        flow.sync_default_session()
    else:
        # do nothing
        pass
Example #3
def sync_default_session_if_normal():
    # TODO merge with same function in experimental/interface_op_read_and_write.py
    if rt_mode.CurrentMode() == rt_mode.NORMAL_MODE:
        oneflow.sync_default_session()
    else:
        # do nothing
        pass
Example #4
def GenerateVariableOpConf(
        name,
        shape,
        dtype=None,
        initializer=None,
        regularizer=None,
        trainable=None,
        model_name=None,
        random_seed=None,
        distribute=oneflow_api.distribute.broadcast(),
):
    op_conf = op_conf_util.OperatorConf()
    op_conf.name = name
    op_conf.variable_conf.shape.dim.extend(shape)

    assert dtype is not None
    op_conf.variable_conf.data_type = oneflow_api.deprecated.GetProtoDtype4OfDtype(
        dtype)

    if rt_mode.CurrentMode() == rt_mode.NORMAL_MODE:
        root_path = None
    else:
        # Outside normal mode, consult the current job config for a default
        # snapshot directory to initialize this variable from.
        root_path = (
            compile_context.GetCurJobConfigProto().default_initialize_with_snapshot_path()
        )
        dir_path = os.path.join(root_path, name)
        file_path = os.path.join(dir_path, "out")
    if root_path and os.path.isfile(file_path):
        # A snapshot for this variable exists on disk: load it instead of
        # running the initializer.
        op_conf.variable_conf.initialize_with_snapshot.path = dir_path
        op_conf.variable_conf.initialize_with_snapshot.key = "out"
    else:
        if root_path:
            print("{} not found, will be initialized".format(file_path))
        if initializer is not None:
            op_conf.variable_conf.initializer.CopyFrom(initializer)

    if regularizer is not None:
        op_conf.variable_conf.regularizer.CopyFrom(regularizer)

    if trainable is not None:
        op_conf.variable_conf.trainable = trainable

    if model_name is not None:
        op_conf.variable_conf.model_name = model_name

    if type(distribute) is oneflow_api.distribute.SplitDistribute:
        op_conf.variable_conf.split_axis.value = distribute.axis
    else:
        op_conf.variable_conf.split_axis.ClearField("value")

    if random_seed is not None:
        op_conf.variable_conf.random_seed = random_seed

    op_conf.variable_conf.out = "out"
    return op_conf
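A minimal call sketch, assuming flow is the lazy-mode oneflow namespace and that flow.float32 and flow.random_uniform_initializer() supply the dtype symbol and InitializerConf proto expected here (neither appears in the excerpt above):

# Hypothetical call: build an OperatorConf for a broadcast float32 variable.
var_op_conf = GenerateVariableOpConf(
    name="weight",
    shape=(256, 128),
    dtype=flow.float32,
    initializer=flow.random_uniform_initializer(),
    trainable=True,
)
# The caller is expected to register var_op_conf with the current job
# (e.g. through the surrounding get_variable machinery, not shown here).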
Example #5
def in_device_mode(ctx):
    return rt_mode.CurrentMode() == rt_mode.DEVICE_MODE
Example #6
def in_global_mode(ctx):
    return rt_mode.CurrentMode() == rt_mode.GLOBAL_MODE
Example #7
def in_normal_mode(ctx):
    return rt_mode.CurrentMode() == rt_mode.NORMAL_MODE
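These three predicates ignore their ctx argument and only inspect the current runtime mode, so exactly one of them is true at any time. A minimal check sketch (ctx is a placeholder, since the predicates never read it):

ctx = None
assert in_normal_mode(ctx) or in_global_mode(ctx) or in_device_mode(ctx)
if in_normal_mode(ctx):
    # Mirrors the NORMAL_MODE-only branch of sync_default_session_if_normal above.
    sync_default_session_if_normal()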