def _GetOpNodeSignatureSymbol(self, op_attribute):
    """Return the interned symbol for the op-node signature of ``op_attribute``.

    The sbp, mirrored, logical-blob-desc, batch-axis and parallel signatures of
    ``op_attribute`` are copied into a fresh ``OpNodeSignature`` message. Its
    binary serialization is the lookup key in ``symbol_storage``.

    On a miss, a new id is requested via ``_NewSymbolId4OpNodeSignature``. The
    new symbol is then registered under both that id and the serialized key.
    """
    signature = op_attribute_pb.OpNodeSignature()
    # Copy each sub-signature from op_attribute into the new message.
    for field_name in (
        "sbp_signature",
        "mirrored_signature",
        "logical_blob_desc_signature",
        "batch_axis_signature",
        "parallel_signature",
    ):
        getattr(signature, field_name).CopyFrom(getattr(op_attribute, field_name))
    cache_key = signature.SerializeToString()
    if not symbol_storage.HasSymbol4SerializedOpNodeSignature(cache_key):
        new_id = self._NewSymbolId4OpNodeSignature(signature)
        new_symbol = symbol_util.Symbol(new_id, signature)
        symbol_storage.SetSymbol4Id(new_id, new_symbol)
        symbol_storage.SetSymbol4SerializedOpNodeSignature(cache_key, new_symbol)
        return new_symbol
    return symbol_storage.GetSymbol4SerializedOpNodeSignature(cache_key)
def GetJobConfSymbol(self, job_conf):
    """Return the symbol registered for ``job_conf``, creating it if absent.

    ``job_conf`` itself is the lookup key in ``symbol_storage``. On a miss, a
    new id is requested via ``_NewSymbolId4JobConf``. The resulting symbol is
    then registered under both that id and ``job_conf``.
    """
    if not symbol_storage.HasSymbol4JobConf(job_conf):
        new_id = self._NewSymbolId4JobConf(job_conf)
        created = symbol_util.Symbol(new_id, job_conf)
        symbol_storage.SetSymbol4Id(new_id, created)
        symbol_storage.SetSymbol4JobConf(job_conf, created)
        return created
    return symbol_storage.GetSymbol4JobConf(job_conf)
def GetSymbol4String(self, string):
    """Return the interned symbol for ``string``, creating it on first use.

    ``string`` is the lookup key in ``symbol_storage``. On a miss, a new id is
    requested via ``_NewSymbolId4String``. The new symbol is then stored under
    both that id and ``string``.
    """
    if not symbol_storage.HasSymbol4String(string):
        fresh_id = self._NewSymbolId4String(string)
        fresh_symbol = symbol_util.Symbol(fresh_id, string)
        symbol_storage.SetSymbol4Id(fresh_id, fresh_symbol)
        symbol_storage.SetSymbol4String(string, fresh_symbol)
        return fresh_symbol
    return symbol_storage.GetSymbol4String(string)
def _GetOpConfSymbol(self, op_conf):
    """Return the interned symbol for ``op_conf``, creating it on first use.

    The binary serialization of ``op_conf`` is the lookup key in
    ``symbol_storage``. On a miss, a new id is requested via
    ``_NewSymbolId4OpConf``. The new symbol is then stored under both that id
    and the serialized key.
    """
    key = op_conf.SerializeToString()
    if not symbol_storage.HasSymbol4SerializedOpConf(key):
        new_id = self._NewSymbolId4OpConf(op_conf)
        created = symbol_util.Symbol(new_id, op_conf)
        symbol_storage.SetSymbol4Id(new_id, created)
        symbol_storage.SetSymbol4SerializedOpConf(key, created)
        return created
    return symbol_storage.GetSymbol4SerializedOpConf(key)
def GetScopeSymbol(self, scope_proto, parent_scope_symbol=None):
    """Return the ScopeSymbol for ``scope_proto``, creating it on first use.

    The binary serialization of ``scope_proto`` is the cache key in
    ``symbol_storage``. A new symbol id is requested from
    ``_NewSymbolId4Scope`` only on a cache miss. The new ``ScopeSymbol`` is
    then registered under both that id and the serialized key.

    Args:
        scope_proto: the scope message to intern.
        parent_scope_symbol: passed through to ``scope_util.ScopeSymbol``
            when a new symbol is created; unused on a cache hit.

    Returns:
        The cached symbol, or the newly created and registered one.
    """
    # Probe the cache BEFORE allocating an id, as the other Get*Symbol
    # helpers do. Allocating first would consume a new id (and whatever work
    # _NewSymbolId4Scope performs) even when the cached symbol is returned.
    # NOTE(review): if _NewSymbolId4Scope writes into scope_proto, serializing
    # after it would also make the cache key differ on every call — confirm.
    serialized_scope_proto = scope_proto.SerializeToString()
    if symbol_storage.HasSymbol4SerializedScopeProto(serialized_scope_proto):
        return symbol_storage.GetSymbol4SerializedScopeProto(serialized_scope_proto)
    symbol_id = self._NewSymbolId4Scope(scope_proto)
    symbol = scope_util.ScopeSymbol(symbol_id, scope_proto, parent_scope_symbol)
    symbol_storage.SetSymbol4Id(symbol_id, symbol)
    symbol_storage.SetSymbol4SerializedScopeProto(serialized_scope_proto, symbol)
    return symbol
def AddScopeToStorage(scope_symbol_id, scope_proto_str):
    """Register a ScopeSymbol built from a text-format ``ScopeProto`` string.

    Does nothing if ``scope_proto_str`` is already registered. Otherwise the
    string is parsed into a ``ScopeProto`` and its parent symbol is fetched by
    ``parent_scope_symbol_id``. The new ``ScopeSymbol`` is then stored under
    both ``scope_symbol_id`` and ``scope_proto_str``.
    """
    # NOTE(review): the key here is the text-format string, while lookups
    # elsewhere may key scopes by binary SerializeToString() — confirm the two
    # key spaces are not expected to coincide.
    if not symbol_storage.HasSymbol4SerializedScopeProto(scope_proto_str):
        parsed = text_format.Parse(scope_proto_str, scope_pb.ScopeProto())
        parent = symbol_storage.GetSymbol4Id(parsed.parent_scope_symbol_id)
        created = scope_symbol.ScopeSymbol(scope_symbol_id, parsed, parent)
        symbol_storage.SetSymbol4Id(scope_symbol_id, created)
        symbol_storage.SetSymbol4SerializedScopeProto(scope_proto_str, created)
def AddScopeToStorage(scope_symbol_id, scope_proto):
    """Register a ScopeSymbol for an already-built ``scope_proto``.

    ``str(scope_proto)`` is the deduplication key; if a symbol is already
    stored under it, nothing happens. Otherwise the parent symbol is fetched
    via ``scope_proto.parent_scope_symbol_id()``. The new symbol is then
    stored under both ``scope_symbol_id`` and that string key.
    """
    key = str(scope_proto)
    if not symbol_storage.HasSymbol4SerializedScopeProto(key):
        parent_id = scope_proto.parent_scope_symbol_id()
        parent = symbol_storage.GetSymbol4Id(parent_id)
        created = scope_symbol.ScopeSymbol(scope_symbol_id, scope_proto, parent)
        symbol_storage.SetSymbol4Id(scope_symbol_id, created)
        symbol_storage.SetSymbol4SerializedScopeProto(key, created)
def GetParallelDescSymbol(self, parallel_conf):
    """Return the ParallelDescSymbol for ``parallel_conf``, creating it once.

    The binary serialization of ``parallel_conf`` is the cache key. On a miss,
    a new id is requested via ``_NewSymbolId4ParallelConf``. A
    ``ParallelDescSymbol`` carrying ``parallel_conf.device_tag`` is then
    stored under both that id and the serialized key.
    """
    tag = parallel_conf.device_tag
    cache_key = parallel_conf.SerializeToString()
    if not symbol_storage.HasSymbol4SerializedParallelConf(cache_key):
        new_id = self._NewSymbolId4ParallelConf(parallel_conf)
        created = symbol_util.ParallelDescSymbol(new_id, parallel_conf, tag)
        symbol_storage.SetSymbol4Id(new_id, created)
        symbol_storage.SetSymbol4SerializedParallelConf(cache_key, created)
        return created
    return symbol_storage.GetSymbol4SerializedParallelConf(cache_key)