Exemple #1
0
 def GetInBlobObject(builder, i, bn_in_op2blob_object):
     """Return the boxed blob object for input number ``i``.

     The input bn is ``"in_<i>"``. Its registered blob object is fetched from
     ``bn_in_op2blob_object``. The parallel-desc symbol comes from the symbol id
     that ``parallel_sig`` holds for that bn. Both are handed to
     ``boxing_util.BoxingTo`` together with ``builder``.
     """
     input_bn = "in_%s" % i
     source_blob = bn_in_op2blob_object[input_bn]
     target_attr = op_arg_util.GetOpArgParallelAttribute(
         symbol_storage.GetSymbol4Id(parallel_sig[input_bn]),
         op_attribute,
         input_bn,
     )
     return boxing_util.BoxingTo(builder, source_blob, target_attr)
Exemple #2
0
 def BuildInstruction(builder):
     """Populate output ``out`` with a reference blob object built from input ``in``.

     The work happens inside ``blob_register.BnInOp2BlobObjectScope(op_attribute)``:
     - Compute the parallel attribute for ``out`` on the same parallel-desc
       symbol as the ``in`` blob object.
     - Store the object returned by ``builder.MakeReferenceBlobObject`` under
       the ``"out"`` key of the scoped mapping.
     """
     scope = blob_register.BnInOp2BlobObjectScope(op_attribute)
     with scope as bn2blob:
         src_blob = bn2blob["in"]
         out_attr = op_arg_util.GetOpArgParallelAttribute(
             src_blob.parallel_desc_symbol, op_attribute, "out"
         )
         bn2blob["out"] = builder.MakeReferenceBlobObject(src_blob, out_attr)
Exemple #3
0
    def MakeLazyRefBlobObject(self, interface_op_name):
        """Create a blob object for the single output of an interface op.

        Args:
            interface_op_name: name of the interface op. Its op attribute is
                looked up through the default session.

        Returns:
            A new blob object built from that output's parallel and blob
            attributes. ``self._LazyReference(blob_object, interface_op_name)``
            has already been called on it.
        """
        op_attribute = session_ctx.GetDefaultSession().OpAttribute4InterfaceOpName(
            interface_op_name
        )
        # The interface op is expected to expose exactly one output bn.
        assert len(op_attribute.output_bns) == 1
        obn = op_attribute.output_bns[0]

        sym_id_map = op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id
        desc_sym = symbol_storage.GetSymbol4Id(sym_id_map[obn])
        blob_object = self._NewBlobObject(
            op_arg_util.GetOpArgParallelAttribute(desc_sym, op_attribute, obn),
            op_arg_util.GetOpArgBlobAttribute(op_attribute, obn),
        )
        self._LazyReference(blob_object, interface_op_name)
        return blob_object
Exemple #4
0
    def _GetMut1OperandBlobObjects(
        self, op_attribute, parallel_desc_sym, bn_in_op2blob_object=None
    ):
        """Collect ``(bn symbol, blob object)`` pairs for mut1 operands.

        Two kinds of operand are gathered:
        - Mutable inputs (``is_mutable`` set in their input blob modifier).
          These reuse the blob object already registered in
          ``bn_in_op2blob_object``.
        - Outputs whose ``header_infered_before_compute`` flag is set, plus
          every tmp bn. Each of these gets a newly created blob object, which
          is also written into ``bn_in_op2blob_object``.

        Args:
            op_attribute: op attribute providing the input/output/tmp bns,
                their arg modifiers, and the parallel signature.
            parallel_desc_sym: fallback parallel-desc symbol for output/tmp bns
                that have no entry in the parallel signature.
            bn_in_op2blob_object: mapping from bn to blob object. It is read for
                mutable inputs and updated in place with the new blob objects.
                If omitted, a fresh empty dict is used for this call only.

        Returns:
            A list of ``(bn_symbol, blob_object)`` tuples. Mutable inputs come
            first, in ``input_bns`` order. Outputs and tmp bns follow.
        """
        # Use a fresh dict per call. The former ``{}`` default was evaluated once
        # and shared across calls, and this method writes new blob objects into
        # it, so state leaked between calls that omitted the argument.
        if bn_in_op2blob_object is None:
            bn_in_op2blob_object = {}

        mut1_operand_blob_objects = []
        # Loop-invariant lookup, hoisted out of the input loop.
        ibn2modifier = op_attribute.arg_modifier_signature.ibn2input_blob_modifier
        for ibn in op_attribute.input_bns:
            if not ibn2modifier[ibn].is_mutable:
                continue
            ibn_sym = self.GetSymbol4String(ibn)
            ref_blob_object = bn_in_op2blob_object[ibn]
            mut1_operand_blob_objects.append((ibn_sym, ref_blob_object))

        def GetOutBlobParallelDescSymbol(obn):
            # Prefer the bn-specific symbol from the parallel signature and
            # fall back to the op-level ``parallel_desc_sym``.
            parallel_signature = op_attribute.parallel_signature
            bn2symbol_id = parallel_signature.bn_in_op2parallel_desc_symbol_id
            if obn in bn2symbol_id:
                return symbol_storage.GetSymbol4Id(bn2symbol_id[obn])
            else:
                return parallel_desc_sym

        def OutputBns():
            # Yield outputs flagged ``header_infered_before_compute``, then
            # every tmp bn.
            obn2modifier = op_attribute.arg_modifier_signature.obn2output_blob_modifier
            for obn in op_attribute.output_bns:
                if obn2modifier[obn].header_infered_before_compute:
                    yield obn

            for tmp_bn in op_attribute.tmp_bns:
                yield tmp_bn

        for obn in OutputBns():
            obn_sym = self.GetSymbol4String(obn)
            op_arg_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
                GetOutBlobParallelDescSymbol(obn), op_attribute, obn
            )
            op_arg_blob_attr = op_arg_util.GetOpArgBlobAttribute(op_attribute, obn)
            out_blob_object = self._NewBlobObject(
                op_arg_parallel_attr, op_arg_blob_attr
            )
            bn_in_op2blob_object[obn] = out_blob_object
            mut1_operand_blob_objects.append((obn_sym, out_blob_object))
        return mut1_operand_blob_objects
Exemple #5
0
 def DelegateBlobObject4Ibn(ibn):
     """Return ``get_delegate_blob_object`` for the blob registered under ``ibn``.

     The op-arg parallel attribute for ``ibn`` is built from
     ``op_parallel_desc_sym`` and ``op_attribute``. It is passed along with
     ``bn_in_op2blob_object[ibn]``.
     """
     parallel_attr = op_arg_util.GetOpArgParallelAttribute(
         op_parallel_desc_sym, op_attribute, ibn
     )
     blob_object = bn_in_op2blob_object[ibn]
     return get_delegate_blob_object(blob_object, parallel_attr)