Beispiel #1
0
def BuildInitialScope(
    instruction_builder,
    session_id,
    job_conf,
    device_tag,
    machine_device_ids,
    is_mirrored,
):
    """Build the initial ``ScopeProto`` and return its scope symbol.

    The proto references three symbols obtained from ``instruction_builder``:
    the job-conf symbol for ``job_conf``, a parallel-desc symbol for
    ``device_tag``, and a parallel-desc symbol for ``"cpu"``. Both parallel
    confs are built over the same ``machine_device_ids``.

    Args:
        instruction_builder: Object providing ``GetJobConfSymbol``,
            ``GetParallelDescSymbol`` and ``GetScopeSymbol``.
        session_id: Stored in ``ScopeProto.session_id``.
        job_conf: Passed to ``GetJobConfSymbol``; its symbol id becomes
            ``job_desc_symbol_id``.
        device_tag: Device tag for the device-side parallel conf.
        machine_device_ids: Placement passed to ``MakeParallelConf`` for both
            the device-side and the host (``"cpu"``) parallel conf.
        is_mirrored: If truthy, ``mirrored_parallel`` is set on
            ``opt_mirrored_parallel_conf``; otherwise that field is cleared.

    Returns:
        The result of ``instruction_builder.GetScopeSymbol`` called with the
        finished proto and ``None`` as the second argument.
    """

    def parallel_desc_symbol_id(tag):
        # Build a parallel conf for `tag` and return the id of the
        # parallel-desc symbol the builder hands back for it.
        conf = MakeParallelConf(tag, machine_device_ids)
        return instruction_builder.GetParallelDescSymbol(conf).symbol_id

    proto = scope_pb.ScopeProto()
    proto.session_id = session_id
    job_conf_symbol = instruction_builder.GetJobConfSymbol(job_conf)
    proto.job_desc_symbol_id = job_conf_symbol.symbol_id
    proto.device_parallel_desc_symbol_id = parallel_desc_symbol_id(device_tag)
    proto.host_parallel_desc_symbol_id = parallel_desc_symbol_id("cpu")

    mirrored_conf = proto.opt_mirrored_parallel_conf
    if not is_mirrored:
        mirrored_conf.ClearField("mirrored_parallel")
    else:
        # SetInParent marks the (empty) submessage as present.
        mirrored_conf.mirrored_parallel.SetInParent()

    return instruction_builder.GetScopeSymbol(proto, None)
def AddScopeToStorage(scope_symbol_id, scope_proto_str):
    """Register a ``ScopeSymbol`` for a text-format ``ScopeProto``.

    Does nothing if ``symbol_storage.HasSymbol4SerializedScopeProto`` reports
    that ``scope_proto_str`` is already known. Otherwise it:

    1. parses ``scope_proto_str`` into a ``ScopeProto``;
    2. looks up the parent scope symbol by ``parent_scope_symbol_id``;
    3. builds a ``ScopeSymbol``;
    4. stores the symbol in ``symbol_storage`` under both
       ``scope_symbol_id`` and the serialized string.

    Args:
        scope_symbol_id: Id under which the new symbol is stored.
        scope_proto_str: A ``ScopeProto`` in protobuf text format (parsed with
            ``text_format.Parse``). Also used as the dedup/lookup key.
    """
    # Dedup on the serialized text: an identical proto is registered only once.
    if symbol_storage.HasSymbol4SerializedScopeProto(scope_proto_str):
        return
    scope_proto = text_format.Parse(scope_proto_str, scope_pb.ScopeProto())
    # NOTE(review): what happens when parent_scope_symbol_id is unset or
    # unknown depends on symbol_storage.GetSymbol4Id; confirm.
    parent_scope_symbol = symbol_storage.GetSymbol4Id(
        scope_proto.parent_scope_symbol_id
    )
    symbol = scope_symbol.ScopeSymbol(scope_symbol_id, scope_proto, parent_scope_symbol)
    # Index the same symbol both by id and by its serialized proto.
    symbol_storage.SetSymbol4Id(scope_symbol_id, symbol)
    symbol_storage.SetSymbol4SerializedScopeProto(scope_proto_str, symbol)
    # NOTE(review): the nested function below is defined but never called or
    # returned, so it is dead code inside AddScopeToStorage.
    # It reads `attr_dict`, `name2default` and `block`, which are not defined
    # anywhere in this function. It also uses `scope_pb2_util` rather than
    # `scope_pb`. It looks like a separate snippet pasted from another
    # enclosing scope (it checks for `oneflow.nn.graph.block.ModuleBlock`).
    # Confirm where it belongs and move it there.
    def scope_proto_str_setter(serialized_scope_proto: str):
        """Apply block-specific settings to a text-format ``ScopeProto``.

        Parses ``serialized_scope_proto`` and then:

        - sets each ``attr_dict`` entry in ``attr_name2attr_value`` via
          ``attr_util.SetProtoAttrValue``, passing the matching
          ``name2default`` entry;
        - replaces ``scope_op_name_prefixes`` with the single entry
          ``block.name_prefix + block.name``;
        - sets ``module_name`` when ``block`` is a ``ModuleBlock``.

        Returns the modified proto re-serialized in text format.
        """
        scope_proto = text_format.Parse(serialized_scope_proto,
                                        scope_pb2_util.ScopeProto())
        # set attrs; every attr name must have a declared default.
        # NOTE: `assert` is stripped under `python -O`, so this check is
        # not enforced in optimized runs.
        for attr_name, py_value in attr_dict.items():
            assert attr_name in name2default
            attr_util.SetProtoAttrValue(
                scope_proto.attr_name2attr_value[attr_name],
                py_value,
                name2default[attr_name],
            )
        # replace any existing op-name prefixes with this block's full name
        # (ClearField first, so the append leaves exactly one entry)
        scope_proto.ClearField("scope_op_name_prefixes")
        scope_proto.scope_op_name_prefixes.append(block.name_prefix +
                                                  block.name)
        # set module name (module blocks only)
        if isinstance(block, oneflow.nn.graph.block.ModuleBlock):
            scope_proto.module_name = block.name_prefix + block.name

        return str(text_format.MessageToString(scope_proto))
Beispiel #4
0
 def _CloneScopeProto(self):
     """Return a new ``ScopeProto`` holding a copy of ``self.data``.

     ``CopyFrom`` copies field values into the fresh message, so mutating
     the result does not affect ``self.data``. ``self.data`` is expected to
     be a ``ScopeProto``, because ``CopyFrom`` requires matching message types.
     """
     duplicate = scope_pb.ScopeProto()
     duplicate.CopyFrom(self.data)
     return duplicate
Beispiel #5
0
def scope_to_proto(scope):
    """Parse ``scope._proto_str`` (protobuf text format) into a ``ScopeProto``.

    A fresh ``scope_pb2_util.ScopeProto`` is created for every call, so the
    returned message is not shared with ``scope``.
    """
    serialized = scope._proto_str
    target = scope_pb2_util.ScopeProto()
    return text_format.Parse(serialized, target)
Beispiel #6
0
 def _CloneScopeProto(self):
     """Return a copy of ``self.data`` with its ``symbol_id`` field cleared.

     The copy is a fresh ``ScopeProto`` filled via ``CopyFrom``, so changes
     to it leave ``self.data`` untouched.
     """
     duplicate = scope_pb.ScopeProto()
     duplicate.CopyFrom(self.data)
     # Drop the source scope's id from the clone; presumably the clone is
     # meant to be registered as a new symbol with its own id (confirm).
     duplicate.ClearField("symbol_id")
     return duplicate