def GetAllVariables() -> Dict[str, oneflow_api.EagerConsistentBlob]:
    """Collect the variables of all jobs in the default session.

    Only interface ops whose op_conf is a ``variable_conf`` are included.

    Returns:
        A dict mapping each variable op name to the blob returned by
        ``interface_op_read_and_write.GetEagerInterfaceBlob`` for it.
    """
    oneflow.sync_default_session()
    session = session_ctx.GetDefaultSession()

    def _is_variable(op_name):
        # Interface ops also include non-variable ops; keep only variables.
        op_conf = session.OpAttribute4InterfaceOpName(op_name).op_conf
        return op_conf.WhichOneof("op_type") == "variable_conf"

    return {
        op_name: interface_op_read_and_write.GetEagerInterfaceBlob(op_name)
        for op_name in session.interface_ops
        if _is_variable(op_name)
    }
def LoadVariables(
    value_dict: Dict[str, ValueContainer], ignore_mismatch: bool = True,
):
    """Load the values in `value_dict` into oneflow variables.

    For example, if `value_dict` is ``{'x': np.ones(x_shape)}``, the value of
    variable "x" will be all ones.

    Args:
        value_dict: maps variable names to the values to load into them.
        ignore_mismatch: if True (the default), names in `value_dict` that do
            not belong to any variable are skipped. If False, every name is
            checked before anything is loaded, and a RuntimeError is raised
            for the first name that is not a variable. In that case no
            variable is modified.

    Raises:
        RuntimeError: if `ignore_mismatch` is False and some name in
            `value_dict` is not a variable name.
    """
    oneflow.sync_default_session()
    all_vars = GetAllVariables()

    if not ignore_mismatch:
        # Validate all names up front so a bad name cannot leave the variables
        # half-loaded (and the final eager Sync skipped) by raising mid-loop.
        for name in value_dict:
            if name not in all_vars:
                raise RuntimeError('"{}" is not a variable name'.format(name))

    for name, value in value_dict.items():
        if name not in all_vars:
            # Only reachable when ignore_mismatch is True: skip unknown names.
            continue
        # GetAllVariables already fetched this blob via GetEagerInterfaceBlob;
        # reuse it instead of fetching it a second time.
        _FeedValueToVariable(all_vars[name], value)

    oneflow_api.eager.Sync()