Beispiel #1
0
def GetAllVariables() -> Dict[str, oneflow_api.EagerConsistentBlob]:
    """
    Return every variable of every job in the default session as a dict.

    The default session is synchronized first. The result maps each interface
    op name whose op conf is a `variable_conf` to the blob returned by
    `interface_op_read_and_write.GetEagerInterfaceBlob` for that name.
    """
    oneflow.sync_default_session()

    session = session_ctx.GetDefaultSession()

    def _is_variable(op_name):
        # Only interface ops whose op_type oneof is set to variable_conf count.
        op_conf = session.OpAttribute4InterfaceOpName(op_name).op_conf
        return op_conf.WhichOneof("op_type") == "variable_conf"

    return {
        op_name: interface_op_read_and_write.GetEagerInterfaceBlob(op_name)
        for op_name in session.interface_ops
        if _is_variable(op_name)
    }
Beispiel #2
0
def LoadVariables(
    value_dict: Dict[str, ValueContainer],
    ignore_mismatch: bool = True,
):
    """
    Load the values in `value_dict` into oneflow variables.

    For example, if `value_dict` is {'x': np.ones(x_shape)},
    the value of variable "x" will be all ones afterwards.

    Args:
        value_dict: maps variable names (as returned by GetAllVariables)
            to the values to load into them.
        ignore_mismatch: if True (the default), names in `value_dict` that
            do not belong to any variable are silently skipped.

    Raises:
        RuntimeError: if `ignore_mismatch` is False and some name in
            `value_dict` does not belong to any variable. All names are
            checked before any variable is written, so in that case no
            variable is modified.
    """
    oneflow.sync_default_session()

    all_vars = GetAllVariables()

    # Validate every name up front: raising halfway through the feed loop
    # would leave some variables updated, others not, and skip the Sync below.
    if not ignore_mismatch:
        for name in value_dict:
            if name not in all_vars:
                raise RuntimeError('"{}" is not a variable name'.format(name))

    for name, value in value_dict.items():
        if name not in all_vars:
            continue
        # GetAllVariables already fetched each variable's blob; reuse it
        # instead of looking it up a second time.
        _FeedValueToVariable(all_vars[name], value)
    oneflow_api.eager.Sync()