def __init__(self, symbol_id, scope_proto, parent_scope_symbol=None):
    """Initialize a scope symbol from its serialized scope proto.

    Args:
        symbol_id: id forwarded to ``Symbol.__init__``.
        scope_proto: proto object exposing ``job_desc_symbol_id()``,
            ``device_parallel_desc_symbol_id()`` and
            ``host_parallel_desc_symbol_id()``; each id is resolved to a
            symbol through ``oneflow_api``.
        parent_scope_symbol: optional enclosing scope symbol, stored as-is.
    """
    Symbol.__init__(self, symbol_id, scope_proto)
    self.parent_scope_symbol_ = parent_scope_symbol

    # Resolve the symbols referenced by id inside the proto, keeping the
    # lookup order: job conf, then device placement, then host placement.
    job_desc_id = scope_proto.job_desc_symbol_id()
    self.job_desc_symbol_ = oneflow_api.GetJobConfSymbol(job_desc_id)

    device_placement_id = scope_proto.device_parallel_desc_symbol_id()
    self.device_parallel_desc_symbol_ = oneflow_api.GetPlacementSymbol(
        device_placement_id
    )

    host_placement_id = scope_proto.host_parallel_desc_symbol_id()
    self.host_parallel_desc_symbol_ = oneflow_api.GetPlacementSymbol(
        host_placement_id
    )

    # Counter initialized to zero; presumably consumed by id-generating
    # helpers elsewhere in the class — TODO confirm.
    self.auto_increment_id_ = 0
def GetOutBlobParallelDescSymbol(obn):
    """Return the placement symbol for output blob name ``obn``.

    Uses the per-blob placement id recorded in
    ``op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id``
    when ``obn`` has an entry there. Otherwise falls back to
    ``parallel_desc_sym``. Both names are free variables resolved from the
    enclosing scope.
    """
    bn2symbol_id = op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id
    if obn not in bn2symbol_id:
        return parallel_desc_sym
    return oneflow_api.GetPlacementSymbol(bn2symbol_id[obn])
def DistributeConcatOrAdd(op_attribute, parallel_conf, blob_register):
    """Pack the boxed inputs of a distribute concat/add op into its ``out`` blob.

    For every ``i`` in ``range(len(op_attribute.input_bns))``, the blob bound
    to ``"in_<i>"`` is boxed onto the placement recorded for that input in the
    op's parallel signature. The resulting physical blobs are then packed into
    one logical blob, which is stored under ``"out"``. The work runs inside
    ``oneflow_api.deprecated.LogicalRun``.

    Args:
        op_attribute: op attribute proto of the op being executed.
        parallel_conf: not read by this function; kept for interface
            compatibility.
        blob_register: register passed to
            ``blob_register_util.BnInOp2BlobObjectScope`` to obtain the
            bn -> blob object mapping.
    """
    signature = op_attribute.parallel_signature
    op_parallel_desc_sym = oneflow_api.GetPlacementSymbol(
        signature.op_parallel_desc_symbol_id
    )
    # One input blob per entry of input_bns, named "in_0", "in_1", ...
    num_inputs = len(op_attribute.input_bns)
    out_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
        op_parallel_desc_sym, str(op_attribute), "out"
    )
    out_blob_attr = oneflow_api.GetOpArgBlobAttribute(str(op_attribute), "out")
    bn2placement_id = signature.bn_in_op2parallel_desc_symbol_id

    def BuildInstruction(builder):
        with blob_register_util.BnInOp2BlobObjectScope(
            blob_register, op_attribute
        ) as bn_in_op2blob_object:
            physical_blobs = []
            for idx in range(num_inputs):
                ibn = "in_%s" % idx
                source_blob = bn_in_op2blob_object[ibn]
                in_placement = oneflow_api.GetPlacementSymbol(bn2placement_id[ibn])
                in_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
                    in_placement, str(op_attribute), ibn
                )
                # Move each input onto its own placement before packing.
                physical_blobs.append(
                    boxing_util.BoxingTo(builder, source_blob, in_parallel_attr)
                )
            bn_in_op2blob_object["out"] = builder.PackPhysicalBlobsToLogicalBlob(
                physical_blobs, out_parallel_attr, out_blob_attr
            )

    oneflow_api.deprecated.LogicalRun(BuildInstruction)
def GetInBlobObject(builder, ibn, bn_in_op2blob_object):
    """Box the blob bound to input ``ibn`` onto that input's own placement.

    The target placement id is ``parallel_sig[ibn]``, and the op-arg
    parallel attribute is built from ``op_attribute``. Both are free
    variables from the enclosing scope.

    Returns:
        Whatever ``boxing_util.BoxingTo`` returns for the boxed blob.
    """
    source_blob = bn_in_op2blob_object[ibn]
    placement_id = parallel_sig[ibn]
    target_attr = oneflow_api.GetOpArgParallelAttribute(
        oneflow_api.GetPlacementSymbol(placement_id), str(op_attribute), ibn
    )
    return boxing_util.BoxingTo(builder, source_blob, target_attr)
def _StatelessCall(
    self,
    stream_tag,
    op_attribute,
    op_parallel_desc_sym=None,
    blob_parallel_desc_sym=None,
    bn_in_op2blob_object=None,
    get_delegate_blob_object=None,
):
    """Emit a ``<stream_tag>.{User,System}StatelessCallOpKernel`` instruction.

    Args:
        stream_tag: prefix of the instruction name, formatted as
            ``"%s.%sStatelessCallOpKernel" % (stream_tag, prefix)``.
        op_attribute: op attribute proto. Its ``op_conf`` must have
            ``scope_symbol_id`` set.
        op_parallel_desc_sym: op placement symbol. It is replaced by the one
            recorded in ``op_attribute.parallel_signature`` when that
            signature has ``op_parallel_desc_symbol_id``, and it must be
            non-None afterwards.
        blob_parallel_desc_sym: placement passed to the mut1/mut2 operand
            helpers.
        bn_in_op2blob_object: mapping from blob name to blob object. Input
            blobs are looked up here, and the mapping is forwarded to the
            mut1/mut2 helpers. Defaults to a new empty dict on each call.
        get_delegate_blob_object: callable
            ``(blob_object, op_arg_parallel_attr) -> blob_object`` applied
            to every const/mutable input blob.

    Raises:
        AssertionError: if ``get_delegate_blob_object`` is not callable, no
            op placement symbol is available, ``op_conf`` lacks
            ``scope_symbol_id``, or the shared kernel object's placement
            differs from the op placement.
    """
    assert callable(get_delegate_blob_object)
    if bn_in_op2blob_object is None:
        # Fresh dict per call. The mapping is handed to the mut1/mut2
        # helpers, which may write into it, so a shared ``{}`` default could
        # leak entries from one call into the next.
        bn_in_op2blob_object = {}
    if op_attribute.parallel_signature.HasField("op_parallel_desc_symbol_id"):
        symbol_id = op_attribute.parallel_signature.op_parallel_desc_symbol_id
        op_parallel_desc_sym = oneflow_api.GetPlacementSymbol(symbol_id)
    assert op_parallel_desc_sym is not None

    def DelegateBlobObject4Ibn(ibn):
        # Resolve an input blob through the caller-supplied delegate, using
        # the op-level placement for that input.
        op_arg_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
            op_parallel_desc_sym, str(op_attribute), ibn
        )
        return get_delegate_blob_object(
            bn_in_op2blob_object[ibn], op_arg_parallel_attr
        )

    op_conf = op_attribute.op_conf
    assert op_conf.HasField("scope_symbol_id"), op_conf
    scope_symbol = oneflow_api.GetScopeSymbol(op_conf.scope_symbol_id)
    job_desc_sym = scope_symbol.job_desc_symbol
    op_conf_sym = self._GetOpConfSymbol(op_conf)
    op_node_signature_sym = self._GetOpNodeSignatureSymbol(op_attribute)
    opkernel_obj = self.GetSharedOpKernelObject4ParallelConfSymbol(
        op_parallel_desc_sym
    )
    # The shared kernel object must live on the same placement as the op.
    assert opkernel_obj.parallel_desc_symbol == op_parallel_desc_sym, (
        str(opkernel_obj.parallel_desc_symbol.parallel_conf),
        str(op_parallel_desc_sym.parallel_conf),
    )
    const_input_operand_blob_objects = self._GetConstInputOperandBlobObjects(
        op_attribute, blob_object4ibn=DelegateBlobObject4Ibn
    )
    mutable_input_operand_blob_objects = self._GetMutableInputOperandBlobObjects(
        op_attribute, blob_object4ibn=DelegateBlobObject4Ibn
    )
    mut1_operand_blob_objects = self._GetMut1OperandBlobObjects(
        op_attribute,
        blob_parallel_desc_sym,
        bn_in_op2blob_object=bn_in_op2blob_object,
    )
    mut2_operand_blob_objects = self._GetMut2OperandBlobObjects(
        op_attribute,
        blob_parallel_desc_sym,
        bn_in_op2blob_object=bn_in_op2blob_object,
    )
    is_user_op = op_conf.HasField("user_conf")
    instruction_prefix = "User" if is_user_op else "System"
    self._StatelessCallOpKernel(
        "%s.%sStatelessCallOpKernel" % (stream_tag, instruction_prefix),
        op_parallel_desc_sym,
        job_desc_sym,
        op_conf_sym,
        op_node_signature_sym,
        opkernel_obj,
        const_input_operand_blob_objects,
        mutable_input_operand_blob_objects,
        mut1_operand_blob_objects,
        mut2_operand_blob_objects,
    )
def NewOpKernelObject(self, op_conf):
    """Create an op kernel object for ``op_conf`` and wrap it.

    The kernel object is created on the placement returned by
    ``c_api_util.GetOpParallelSymbolId(op_conf)``, using the job desc symbol
    of ``op_conf``'s scope.

    Returns:
        ``OpKernelObject(object_id, op_conf, self.object_releaser())``.

    Raises:
        AssertionError: if ``op_conf`` has no ``scope_symbol_id``.
    """
    assert op_conf.HasField("scope_symbol_id")
    scope_symbol = oneflow_api.GetScopeSymbol(op_conf.scope_symbol_id)
    conf_symbol = self._GetOpConfSymbol(op_conf)
    placement_symbol = oneflow_api.GetPlacementSymbol(
        c_api_util.GetOpParallelSymbolId(op_conf)
    )
    kernel_object_id = self._NewOpKernelObject(
        placement_symbol, scope_symbol.job_desc_symbol, conf_symbol
    )
    releaser = self.object_releaser()
    return OpKernelObject(kernel_object_id, op_conf, releaser)
def NewOpKernelObject(self, op_conf):
    """Create an op kernel object for ``op_conf`` (cfg-based variant).

    ``op_conf`` is converted to a cfg op conf via its text form. The conf
    symbol is registered from that cfg object, while the placement id is
    still derived from the original proto ``op_conf``.

    Returns:
        ``oneflow_api.OpKernelObject`` built from the new object id, the cfg
        op conf, and ``self.object_releaser()``.

    Raises:
        AssertionError: if ``op_conf`` has no ``scope_symbol_id``.
    """
    assert op_conf.HasField("scope_symbol_id")
    scope_symbol = oneflow_api.GetScopeSymbol(op_conf.scope_symbol_id)
    op_conf_text = str(op_conf)
    cfg_conf = oneflow_api.deprecated.MakeOpConfByString(op_conf_text)
    conf_symbol = self.GetOpConfSymbol(cfg_conf)
    placement_id = c_api_util.GetOpParallelSymbolId(op_conf)
    placement_symbol = oneflow_api.GetPlacementSymbol(placement_id)
    kernel_object_id = self._NewOpKernelObject(
        placement_symbol, scope_symbol.job_desc_symbol, conf_symbol
    )
    releaser = self.object_releaser()
    return oneflow_api.OpKernelObject(kernel_object_id, cfg_conf, releaser)
def _GetOpParallelSymbol(op_conf):
    """Return the placement symbol for ``op_conf``.

    The placement id is obtained from ``c_api_util.GetOpParallelSymbolId``
    and resolved through ``oneflow_api.GetPlacementSymbol``.

    Raises:
        AssertionError: if ``op_conf`` has no ``scope_symbol_id``.
    """
    assert op_conf.HasField("scope_symbol_id")
    return oneflow_api.GetPlacementSymbol(
        c_api_util.GetOpParallelSymbolId(op_conf)
    )