def BuildInitialScope(
    instruction_builder,
    session_id,
    job_conf,
    device_tag,
    machine_device_ids,
    is_mirrored,
):
    """Build a root ``ScopeProto`` and register it via ``instruction_builder``.

    The scope records the session id and the symbol ids of the job config,
    a device-side parallel desc (``device_tag``), and a host-side parallel
    desc (always ``"cpu"``), both over ``machine_device_ids``.

    Args:
        instruction_builder: object providing ``GetJobConfSymbol``,
            ``GetParallelDescSymbol`` and ``GetScopeSymbol``.
        session_id: value stored in ``ScopeProto.session_id``.
        job_conf: job config handed to ``GetJobConfSymbol``.
        device_tag: device tag for the device-side parallel conf.
        machine_device_ids: placement passed to ``MakeParallelConf``.
        is_mirrored: if truthy, mark the scope's ``mirrored_parallel`` field
            as present; otherwise clear it.

    Returns:
        Whatever ``instruction_builder.GetScopeSymbol(proto, None)`` returns.
    """

    def _parallel_desc_symbol_id(tag):
        # Build a parallel conf for `tag` and return its registered symbol id.
        conf = MakeParallelConf(tag, machine_device_ids)
        return instruction_builder.GetParallelDescSymbol(conf).symbol_id

    proto = scope_pb.ScopeProto()
    proto.session_id = session_id
    proto.job_desc_symbol_id = instruction_builder.GetJobConfSymbol(job_conf).symbol_id
    # Device-side desc is resolved before the host-side ("cpu") desc.
    proto.device_parallel_desc_symbol_id = _parallel_desc_symbol_id(device_tag)
    proto.host_parallel_desc_symbol_id = _parallel_desc_symbol_id("cpu")

    mirrored_conf = proto.opt_mirrored_parallel_conf
    if not is_mirrored:
        mirrored_conf.ClearField("mirrored_parallel")
    else:
        # SetInParent marks the (empty) sub-message as present.
        mirrored_conf.mirrored_parallel.SetInParent()

    return instruction_builder.GetScopeSymbol(proto, None)
def AddScopeToStorage(scope_symbol_id, scope_proto_str):
    """Parse a text-format scope proto and register it in ``symbol_storage``.

    Does nothing if a symbol is already stored for ``scope_proto_str``.
    Otherwise it builds a ``ScopeSymbol`` whose parent is looked up by the
    proto's ``parent_scope_symbol_id``. The new symbol is indexed both by
    ``scope_symbol_id`` and by the serialized proto string.

    Args:
        scope_symbol_id: id under which the new symbol is stored.
        scope_proto_str: text-format ``ScopeProto`` string.
    """
    already_stored = symbol_storage.HasSymbol4SerializedScopeProto(scope_proto_str)
    if not already_stored:
        parsed = text_format.Parse(scope_proto_str, scope_pb.ScopeProto())
        parent = symbol_storage.GetSymbol4Id(parsed.parent_scope_symbol_id)
        new_symbol = scope_symbol.ScopeSymbol(scope_symbol_id, parsed, parent)
        # Index the same symbol by id and by its serialized form.
        symbol_storage.SetSymbol4Id(scope_symbol_id, new_symbol)
        symbol_storage.SetSymbol4SerializedScopeProto(scope_proto_str, new_symbol)
def scope_proto_str_setter(serialized_scope_proto: str):
    """Return a modified copy of a text-format ``ScopeProto``.

    Uses the enclosing-scope names ``attr_dict``, ``name2default`` and
    ``block`` (defined outside this function). The returned proto has:

    * each ``attr_dict`` entry written into ``attr_name2attr_value``;
    * ``scope_op_name_prefixes`` replaced by the single entry
      ``block.name_prefix + block.name``;
    * ``module_name`` set to that same string when ``block`` is a
      ``oneflow.nn.graph.block.ModuleBlock``.

    Args:
        serialized_scope_proto: text-format ``ScopeProto``.

    Returns:
        The updated proto, re-serialized to text format.
    """
    proto = text_format.Parse(serialized_scope_proto, scope_pb2_util.ScopeProto())

    # Every attribute being set must have a registered default.
    for name, value in attr_dict.items():
        assert name in name2default
        attr_util.SetProtoAttrValue(
            proto.attr_name2attr_value[name],
            value,
            name2default[name],
        )

    full_name = block.name_prefix + block.name

    # Drop any existing prefixes and use this block's full name instead.
    proto.ClearField("scope_op_name_prefixes")
    proto.scope_op_name_prefixes.append(full_name)

    if isinstance(block, oneflow.nn.graph.block.ModuleBlock):
        proto.module_name = full_name

    return str(text_format.MessageToString(proto))
def _CloneScopeProto(self):
    """Return a new ``ScopeProto`` holding a copy of ``self.data``."""
    cloned = scope_pb.ScopeProto()
    cloned.CopyFrom(self.data)
    return cloned
def scope_to_proto(scope):
    """Parse ``scope._proto_str`` (text format) into a new ``ScopeProto``.

    Args:
        scope: object exposing a text-format proto string as ``_proto_str``.

    Returns:
        The parsed ``ScopeProto`` message.
    """
    serialized = scope._proto_str
    target = scope_pb2_util.ScopeProto()
    return text_format.Parse(serialized, target)
def _CloneScopeProto(self):
    """Return a copy of ``self.data`` as a new ``ScopeProto`` with ``symbol_id`` cleared."""
    cloned = scope_pb.ScopeProto()
    cloned.CopyFrom(self.data)
    # NOTE(review): presumably cleared so the copy can be registered as a new
    # symbol rather than carrying the source's id. Confirm with callers.
    cloned.ClearField("symbol_id")
    return cloned