def _GetMut2OperandBlobObjects(
    self, op_attribute, parallel_desc_sym, bn_in_op2blob_object=None
):
    """Create a blob object for every output not flagged ``header_infered_before_compute``.

    Args:
        op_attribute: Op attribute whose ``output_bns``,
            ``arg_modifier_signature.obn2output_blob_modifier`` and
            ``parallel_signature.bn_in_op2parallel_desc_symbol_id`` are read.
        parallel_desc_sym: Fallback placement symbol, used for an output that
            has no entry in ``bn_in_op2parallel_desc_symbol_id``.
        bn_in_op2blob_object: Optional mapping that receives
            ``obn -> blob object`` for each output created here. When omitted,
            a fresh dict is used for this call only.

    Returns:
        A list of ``(self.GetSymbol4String(obn), blob object)`` tuples, in
        ``output_bns`` order, covering only the outputs that were not skipped.
    """
    # A literal ``{}`` default is created once at definition time and would be
    # shared by every call that omits the argument, so entries from one call
    # would leak into the next. ``is None`` (not truthiness) is used so that a
    # caller-supplied empty mapping is still the one that gets filled.
    if bn_in_op2blob_object is None:
        bn_in_op2blob_object = {}

    def GetOutBlobParallelDescSymbol(obn):
        # Prefer the per-output placement when the signature records one.
        bn2symbol_id = (
            op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id
        )
        if obn in bn2symbol_id:
            return oneflow_api.GetPlacementSymbol(bn2symbol_id[obn])
        return parallel_desc_sym

    # Loop-invariant values: look them up / serialize them once.
    obn2modifier = op_attribute.arg_modifier_signature.obn2output_blob_modifier
    op_attribute_str = str(op_attribute)

    mut2_operand_blob_objects = []
    for obn in op_attribute.output_bns:
        if obn2modifier[obn].header_infered_before_compute:
            continue
        obn_sym = self.GetSymbol4String(obn)
        op_arg_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
            GetOutBlobParallelDescSymbol(obn), op_attribute_str, obn
        )
        op_arg_blob_attr = oneflow_api.GetOpArgBlobAttribute(op_attribute_str, obn)
        out_blob_object = self.NewBlobObject(op_arg_parallel_attr, op_arg_blob_attr)
        bn_in_op2blob_object[obn] = out_blob_object
        mut2_operand_blob_objects.append((obn_sym, out_blob_object))
    return mut2_operand_blob_objects
def MakeLazyRefBlobObject(self, interface_op_name):
    """Create a blob object for an interface op's sole output and lazily reference it.

    The op attribute and parallel conf are fetched from the default session by
    ``interface_op_name``. The op must expose exactly one output blob name.

    Args:
        interface_op_name: Name passed to the session lookups and to
            ``self.LazyReference``.

    Returns:
        The blob object produced by ``self.NewBlobObject``.
    """
    session = session_ctx.GetDefaultSession()
    op_attr = session.OpAttribute4InterfaceOpName(interface_op_name)
    output_bns = op_attr.output_bns
    assert len(output_bns) == 1
    out_bn = output_bns[0]

    conf = session.ParallelConf4LazyInterfaceOpName(interface_op_name)
    if isinstance(conf, oneflow_api.oneflow.core.job.placement.ParallelConf):
        cfg_conf = conf
    else:
        # Not already the cfg ParallelConf type: copy the device tag and the
        # device names into a new cfg object.
        cfg_conf = placement_cfg.ParallelConf()
        cfg_conf.set_device_tag(conf.device_tag)
        for name in conf.device_name:
            cfg_conf.add_device_name(name)

    placement_sym = self.GetParallelDescSymbol(cfg_conf)
    serialized_op_attr = str(op_attr)
    parallel_attr = oneflow_api.GetOpArgParallelAttribute(
        placement_sym, serialized_op_attr, out_bn
    )
    blob_attr = oneflow_api.GetOpArgBlobAttribute(serialized_op_attr, out_bn)

    lazy_blob = self.NewBlobObject(parallel_attr, blob_attr)
    self.LazyReference(lazy_blob, interface_op_name)
    return lazy_blob
def DistributeConcatOrAdd(op_attribute, parallel_conf, blob_register):
    """Box each ``in_<i>`` input and register the packed result under ``"out"``.

    Args:
        op_attribute: Op attribute providing ``input_bns`` and the
            ``parallel_signature`` used to resolve placement symbols.
        parallel_conf: Accepted for interface compatibility; not read here.
        blob_register: Register handed to
            ``blob_register_util.BnInOp2BlobObjectScope``.
    """
    parallel_signature = op_attribute.parallel_signature
    out_placement_sym = oneflow_api.GetPlacementSymbol(
        parallel_signature.op_parallel_desc_symbol_id
    )
    input_count = len(op_attribute.input_bns)
    serialized_op_attr = str(op_attribute)
    out_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
        out_placement_sym, serialized_op_attr, "out"
    )
    out_blob_attr = oneflow_api.GetOpArgBlobAttribute(serialized_op_attr, "out")
    ibn2symbol_id = parallel_signature.bn_in_op2parallel_desc_symbol_id

    def BoxInput(builder, index, bn2blob):
        # Inputs are keyed "in_0", "in_1", ... in both the scope mapping and
        # the parallel signature.
        ibn = "in_%s" % index
        source_blob = bn2blob[ibn]
        in_placement_sym = oneflow_api.GetPlacementSymbol(ibn2symbol_id[ibn])
        in_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
            in_placement_sym, str(op_attribute), ibn
        )
        return boxing_util.BoxingTo(builder, source_blob, in_parallel_attr)

    def BuildInstruction(builder):
        scope = blob_register_util.BnInOp2BlobObjectScope(blob_register, op_attribute)
        with scope as bn2blob:
            boxed_inputs = []
            for index in range(input_count):
                boxed_inputs.append(BoxInput(builder, index, bn2blob))
            bn2blob["out"] = builder.PackPhysicalBlobsToLogicalBlob(
                boxed_inputs, out_parallel_attr, out_blob_attr
            )

    oneflow_api.deprecated.LogicalRun(BuildInstruction)
def GetInBlobObject(builder, ibn, bn_in_op2blob_object):
    """Return the blob object stored under ``ibn`` after boxing it for that input.

    ``parallel_sig`` and ``op_attribute`` are free names resolved outside this
    function.

    Args:
        builder: Forwarded to ``boxing_util.BoxingTo``.
        ibn: Input blob name; key into both ``bn_in_op2blob_object`` and
            ``parallel_sig``.
        bn_in_op2blob_object: Mapping from blob name to blob object.
    """
    source_blob = bn_in_op2blob_object[ibn]
    target_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
        oneflow_api.GetPlacementSymbol(parallel_sig[ibn]),
        str(op_attribute),
        ibn,
    )
    return boxing_util.BoxingTo(builder, source_blob, target_parallel_attr)
def BuildInstruction(builder):
    """Register under ``"out"`` a reference blob object made from the ``"in"`` blob.

    ``blob_register`` and ``op_attribute`` are free names resolved outside this
    function. The output's parallel attribute is built from the input blob's
    own ``parallel_desc_symbol``.

    Args:
        builder: Object providing ``MakeReferenceBlobObject``.
    """
    scope = blob_register_util.BnInOp2BlobObjectScope(blob_register, op_attribute)
    with scope as bn2blob:
        source_blob = bn2blob["in"]
        out_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
            source_blob.parallel_desc_symbol, str(op_attribute), "out"
        )
        bn2blob["out"] = builder.MakeReferenceBlobObject(
            source_blob, out_parallel_attr
        )
def DelegateBlobObject4Ibn(ibn):
    """Return the delegate blob object for input ``ibn``.

    ``op_parallel_desc_sym``, ``op_attribute``, ``bn_in_op2blob_object`` and
    ``get_delegate_blob_object`` are free names resolved outside this function.

    Args:
        ibn: Input blob name; key into ``bn_in_op2blob_object``.
    """
    # Build the parallel attribute before the mapping lookup, as the original
    # did, so a failure surfaces in the same order.
    target_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
        op_parallel_desc_sym, str(op_attribute), ibn
    )
    source_blob = bn_in_op2blob_object[ibn]
    return get_delegate_blob_object(source_blob, target_parallel_attr)