def GetInBlobObject(builder, i, bn_in_op2blob_object):
    """Return the i-th input blob object after boxing it for this op.

    The input is fetched from ``bn_in_op2blob_object`` under the blob name
    ``"in_<i>"``. The target parallel description symbol is resolved from
    ``parallel_sig`` via ``symbol_storage``. It is turned into an op-arg
    parallel attribute and handed, together with the source blob object, to
    ``boxing_util.BoxingTo``, whose result is returned.

    NOTE(review): ``parallel_sig`` and ``op_attribute`` are free variables of
    an enclosing scope that is not visible here.
    """
    input_bn = "in_%s" % i
    # Fetch the source blob first, as the original did, so a missing key
    # fails before any symbol lookup happens.
    source_blob_object = bn_in_op2blob_object[input_bn]
    target_desc_sym = symbol_storage.GetSymbol4Id(parallel_sig[input_bn])
    target_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
        target_desc_sym, op_attribute, input_bn
    )
    return boxing_util.BoxingTo(builder, source_blob_object, target_parallel_attr)
def BuildInstruction(builder):
    """Bind the op's "out" blob name to a reference of its "in" blob object.

    Inside a ``blob_register.BnInOp2BlobObjectScope`` for ``op_attribute``:
    - look up the "in" blob object;
    - build the "out" op-arg parallel attribute from that blob's
      ``parallel_desc_symbol``;
    - store ``builder.MakeReferenceBlobObject(...)`` under "out".

    NOTE(review): ``op_attribute`` is a free variable of an enclosing scope
    that is not visible here.
    """
    with blob_register.BnInOp2BlobObjectScope(op_attribute) as bn2blob:
        source = bn2blob["in"]
        out_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
            source.parallel_desc_symbol, op_attribute, "out"
        )
        bn2blob["out"] = builder.MakeReferenceBlobObject(source, out_parallel_attr)
def MakeLazyRefBlobObject(self, interface_op_name):
    """Create a blob object for an interface op's single output and lazily reference it.

    Steps:
    1. Fetch the op attribute of ``interface_op_name`` from the default session.
    2. Assert that the op has exactly one output blob name.
    3. Build that output's parallel attribute, using the parallel desc symbol
       recorded in the op's parallel signature, and its blob attribute.
    4. Create a new blob object from them and register a lazy reference from
       it to ``interface_op_name``.

    Returns:
        The newly created blob object.
    """
    op_attribute = session_ctx.GetDefaultSession().OpAttribute4InterfaceOpName(
        interface_op_name
    )
    assert len(op_attribute.output_bns) == 1
    output_bn = op_attribute.output_bns[0]

    sym_id_by_bn = op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id
    parallel_desc_sym = symbol_storage.GetSymbol4Id(sym_id_by_bn[output_bn])

    parallel_attr = op_arg_util.GetOpArgParallelAttribute(
        parallel_desc_sym, op_attribute, output_bn
    )
    blob_attr = op_arg_util.GetOpArgBlobAttribute(op_attribute, output_bn)

    new_blob_object = self._NewBlobObject(parallel_attr, blob_attr)
    self._LazyReference(new_blob_object, interface_op_name)
    return new_blob_object
def _GetMut1OperandBlobObjects(
    self, op_attribute, parallel_desc_sym, bn_in_op2blob_object=None
):
    """Collect ``(bn symbol, blob object)`` pairs for an op's mutable operands.

    Two kinds of blob names are collected:

    * Every input blob name whose ``ibn2input_blob_modifier`` entry has
      ``is_mutable`` set. It is paired with the blob object already stored
      in ``bn_in_op2blob_object``.
    * Every output blob name whose ``obn2output_blob_modifier`` entry has
      ``header_infered_before_compute`` set, plus every tmp blob name. For
      each, a new blob object is created via ``self._NewBlobObject`` and
      written into ``bn_in_op2blob_object``.

    Args:
        op_attribute: Op attribute providing ``input_bns``, ``output_bns``,
            ``tmp_bns``, ``arg_modifier_signature``, ``parallel_signature``
            and ``arg_signature``.
        parallel_desc_sym: Fallback parallel desc symbol for output/tmp blob
            names that have no entry in
            ``parallel_signature.bn_in_op2parallel_desc_symbol_id``.
        bn_in_op2blob_object: Mapping from blob name to blob object. Mutable
            inputs are read from it and newly created output/tmp blob
            objects are stored into it. Defaults to a fresh empty dict.

    Returns:
        A list of ``(self.GetSymbol4String(bn), blob_object)`` tuples: the
        mutable inputs first, then the outputs/tmp blobs.

    Raises:
        KeyError: If a mutable input blob name is missing from
            ``bn_in_op2blob_object``.
    """
    # The previous ``={}`` default was a single dict shared by every call.
    # Because new blob objects are stored into it below, state leaked
    # between calls. Use a fresh dict per call instead.
    if bn_in_op2blob_object is None:
        bn_in_op2blob_object = {}

    mut1_operand_blob_objects = []

    ibn2modifier = op_attribute.arg_modifier_signature.ibn2input_blob_modifier
    for ibn in op_attribute.input_bns:
        if not ibn2modifier[ibn].is_mutable:
            continue
        ibn_sym = self.GetSymbol4String(ibn)
        ref_blob_object = bn_in_op2blob_object[ibn]
        mut1_operand_blob_objects.append((ibn_sym, ref_blob_object))

    def GetOutBlobParallelDescSymbol(obn):
        # Prefer the per-blob symbol from the op's parallel signature; fall
        # back to the caller-supplied symbol otherwise.
        parallel_signature = op_attribute.parallel_signature
        bn2symbol_id = parallel_signature.bn_in_op2parallel_desc_symbol_id
        if obn in bn2symbol_id:
            return symbol_storage.GetSymbol4Id(bn2symbol_id[obn])
        else:
            return parallel_desc_sym

    def OutputBns():
        # Yield outputs whose header is inferred before compute, then all
        # tmp blob names.
        obn2modifier = op_attribute.arg_modifier_signature.obn2output_blob_modifier
        for obn in op_attribute.output_bns:
            if obn2modifier[obn].header_infered_before_compute:
                yield obn
        for tmp_bn in op_attribute.tmp_bns:
            yield tmp_bn

    for obn in OutputBns():
        obn_sym = self.GetSymbol4String(obn)
        op_arg_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
            GetOutBlobParallelDescSymbol(obn), op_attribute, obn
        )
        op_arg_blob_attr = op_arg_util.GetOpArgBlobAttribute(op_attribute, obn)
        out_blob_object = self._NewBlobObject(
            op_arg_parallel_attr, op_arg_blob_attr
        )
        # NOTE(review): ``lbi`` is never used. The lookup is kept only in case
        # something depends on its side effects (a KeyError for an unknown
        # bn, or entry auto-insertion if this is a protobuf map).
        # Confirm, then remove.
        lbi = op_attribute.arg_signature.bn_in_op2lbi[obn]
        bn_in_op2blob_object[obn] = out_blob_object
        mut1_operand_blob_objects.append((obn_sym, out_blob_object))

    return mut1_operand_blob_objects
def DelegateBlobObject4Ibn(ibn):
    """Return the delegate blob object for input blob name ``ibn``.

    Builds the op-arg parallel attribute of ``ibn`` from
    ``op_parallel_desc_sym`` and ``op_attribute``. It then passes the blob
    object stored under ``ibn``, together with that attribute, to
    ``get_delegate_blob_object`` and returns the result.

    NOTE(review): ``op_parallel_desc_sym``, ``op_attribute``,
    ``bn_in_op2blob_object`` and ``get_delegate_blob_object`` are free
    variables of an enclosing scope that is not visible here.
    """
    # Build the attribute before the dict lookup, matching the original order.
    ibn_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
        op_parallel_desc_sym, op_attribute, ibn
    )
    stored_blob_object = bn_in_op2blob_object[ibn]
    return get_delegate_blob_object(stored_blob_object, ibn_parallel_attr)