Пример #1
0
    def _GetMut2OperandBlobObjects(self,
                                   op_attribute,
                                   parallel_desc_sym,
                                   bn_in_op2blob_object=None):
        """Create blob objects for outputs not header-inferred before compute.

        For every name in ``op_attribute.output_bns`` whose output blob
        modifier does NOT set ``header_infered_before_compute``, a new blob
        object is created via ``self._NewBlobObject`` and recorded.

        Args:
            op_attribute: op attribute providing ``output_bns``, the
                ``arg_modifier_signature`` and the ``parallel_signature``.
            parallel_desc_sym: parallel-desc symbol used for any output that
                has no entry in
                ``op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id``.
            bn_in_op2blob_object: optional mapping updated in place with
                ``obn -> blob object`` for every blob object created here.
                When omitted, a fresh dict is used for this call only.

        Returns:
            list of ``(obn_symbol, blob_object)`` tuples in ``output_bns``
            order, skipping outputs that are header-inferred before compute.
        """
        if bn_in_op2blob_object is None:
            # A literal `{}` default is evaluated once and shared by every
            # call, so entries written below would leak between calls.
            bn_in_op2blob_object = {}
        mut2_operand_blob_objects = []

        # Loop-invariant lookups, hoisted out of the per-output loop.
        bn2symbol_id = op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id
        obn2modifier = op_attribute.arg_modifier_signature.obn2output_blob_modifier

        def GetOutBlobParallelDescSymbol(obn):
            # Prefer the per-blob placement recorded in the parallel signature;
            # otherwise fall back to the symbol supplied by the caller.
            if obn in bn2symbol_id:
                return symbol_storage.GetSymbol4Id(bn2symbol_id[obn])
            return parallel_desc_sym

        for obn in op_attribute.output_bns:
            if obn2modifier[obn].header_infered_before_compute:
                continue
            obn_sym = self.GetSymbol4String(obn)
            op_arg_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
                GetOutBlobParallelDescSymbol(obn), op_attribute, obn)
            op_arg_blob_attr = op_arg_util.GetOpArgBlobAttribute(
                op_attribute, obn)
            out_blob_object = self._NewBlobObject(op_arg_parallel_attr,
                                                  op_arg_blob_attr)
            bn_in_op2blob_object[obn] = out_blob_object
            mut2_operand_blob_objects.append((obn_sym, out_blob_object))
        return mut2_operand_blob_objects
Пример #2
0
def DistributeConcatOrAdd(op_attribute, parallel_conf, blob_register):
    """Pack the boxed ``in_<i>`` blobs of *op_attribute* into its ``out`` blob.

    A single instruction is run through ``vm_util.LogicalRun``. Inside it,
    each input ``in_0 .. in_{k-1}`` (``k == len(op_attribute.input_bns)``) is
    boxed to the parallel attribute that the op's parallel signature records
    for that input. The resulting physical blob objects are then packed into
    one logical blob object, which is stored under ``"out"``.

    Args:
        op_attribute: op attribute carrying the parallel signature and the
            ``in_<i>`` / ``out`` blob names.
        parallel_conf: accepted for interface compatibility; not read here.
        blob_register: supplies ``BnInOp2BlobObjectScope``, which is used to
            read the input blob objects and to record the ``out`` blob object.
    """
    signature = op_attribute.parallel_signature
    out_parallel_desc_sym = symbol_storage.GetSymbol4Id(
        signature.op_parallel_desc_symbol_id)
    num_inputs = len(op_attribute.input_bns)
    out_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
        out_parallel_desc_sym, op_attribute, "out")
    out_blob_attr = op_arg_util.GetOpArgBlobAttribute(op_attribute, "out")
    bn2parallel_desc_sym_id = signature.bn_in_op2parallel_desc_symbol_id

    def BoxInput(builder, index, bn2blob_object):
        # Move the registered input blob onto the placement that the
        # signature assigns to this particular input.
        ibn = "in_%s" % index
        src_blob_object = bn2blob_object[ibn]
        in_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
            symbol_storage.GetSymbol4Id(bn2parallel_desc_sym_id[ibn]),
            op_attribute, ibn)
        return boxing_util.BoxingTo(builder, src_blob_object, in_parallel_attr)

    def BuildInstruction(builder):
        with blob_register.BnInOp2BlobObjectScope(
                op_attribute) as bn2blob_object:
            physical_blob_objects = []
            for index in range(num_inputs):
                physical_blob_objects.append(
                    BoxInput(builder, index, bn2blob_object))
            bn2blob_object["out"] = builder.PackPhysicalBlobsToLogicalBlob(
                physical_blob_objects, out_parallel_attr, out_blob_attr)

    vm_util.LogicalRun(BuildInstruction)
Пример #3
0
    def MakeLazyRefBlobObject(self, interface_op_name):
        """Create a lazily-referenced blob object for an interface op's output.

        Looks up the op attribute of *interface_op_name* in the default
        session. That attribute must have exactly one output. A blob object is
        built for that output, using the parallel-desc symbol recorded for it
        in the op's parallel signature, and is registered via
        ``self._LazyReference``.

        Args:
            interface_op_name: name of the interface op whose output is wrapped.

        Returns:
            The newly created blob object.
        """
        op_attribute = session_ctx.GetDefaultSession(
        ).OpAttribute4InterfaceOpName(interface_op_name)
        output_bns = op_attribute.output_bns
        assert len(output_bns) == 1
        obn = output_bns[0]

        bn2sym_id = op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id
        parallel_desc_sym = symbol_storage.GetSymbol4Id(bn2sym_id[obn])
        parallel_attr = op_arg_util.GetOpArgParallelAttribute(
            parallel_desc_sym, op_attribute, obn)
        blob_attr = op_arg_util.GetOpArgBlobAttribute(op_attribute, obn)

        lazy_blob_object = self._NewBlobObject(parallel_attr, blob_attr)
        self._LazyReference(lazy_blob_object, interface_op_name)
        return lazy_blob_object
Пример #4
0
    def _GetMut1OperandBlobObjects(
        self, op_attribute, parallel_desc_sym, bn_in_op2blob_object=None
    ):
        """Collect the mut1 operand ``(symbol, blob_object)`` pairs of an op.

        The returned list contains, in order:

        * one ``(ibn_symbol, blob_object)`` pair for every input whose input
          blob modifier has ``is_mutable`` set, reusing the blob object already
          stored in *bn_in_op2blob_object*;
        * one ``(bn_symbol, new_blob_object)`` pair for every output whose
          output blob modifier has ``header_infered_before_compute`` set,
          followed by every name in ``op_attribute.tmp_bns``. Each new blob
          object is created via ``self._NewBlobObject`` and is also written into
          *bn_in_op2blob_object*.

        Args:
            op_attribute: op attribute providing the input/output/tmp blob
                names, the ``arg_modifier_signature`` and the
                ``parallel_signature``.
            parallel_desc_sym: parallel-desc symbol used for any output/tmp
                blob that has no entry in
                ``op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id``.
            bn_in_op2blob_object: mapping from blob name to blob object. It is
                read for mutable inputs and updated in place for created blobs.
                When omitted, a fresh dict is used for this call only.

        Returns:
            list of ``(symbol, blob_object)`` tuples.

        Raises:
            KeyError: presumably, when a mutable input has no entry in
                *bn_in_op2blob_object* (plain-dict lookup) — confirm for
                non-dict mappings.
        """
        if bn_in_op2blob_object is None:
            # A literal `{}` default is evaluated once and shared by every
            # call, so blob objects written below would leak between calls.
            bn_in_op2blob_object = {}
        mut1_operand_blob_objects = []

        ibn2modifier = op_attribute.arg_modifier_signature.ibn2input_blob_modifier
        for ibn in op_attribute.input_bns:
            if not ibn2modifier[ibn].is_mutable:
                continue
            ibn_sym = self.GetSymbol4String(ibn)
            ref_blob_object = bn_in_op2blob_object[ibn]
            mut1_operand_blob_objects.append((ibn_sym, ref_blob_object))

        bn2symbol_id = op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id

        def GetOutBlobParallelDescSymbol(obn):
            # Prefer the per-blob placement recorded in the parallel signature;
            # otherwise fall back to the symbol supplied by the caller.
            if obn in bn2symbol_id:
                return symbol_storage.GetSymbol4Id(bn2symbol_id[obn])
            return parallel_desc_sym

        def OutputBns():
            # Outputs whose header is inferred before compute, then all tmp blobs.
            obn2modifier = op_attribute.arg_modifier_signature.obn2output_blob_modifier
            for obn in op_attribute.output_bns:
                if obn2modifier[obn].header_infered_before_compute:
                    yield obn
            for tmp_bn in op_attribute.tmp_bns:
                yield tmp_bn

        for obn in OutputBns():
            obn_sym = self.GetSymbol4String(obn)
            op_arg_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
                GetOutBlobParallelDescSymbol(obn), op_attribute, obn
            )
            op_arg_blob_attr = op_arg_util.GetOpArgBlobAttribute(op_attribute, obn)
            out_blob_object = self._NewBlobObject(
                op_arg_parallel_attr, op_arg_blob_attr
            )
            bn_in_op2blob_object[obn] = out_blob_object
            mut1_operand_blob_objects.append((obn_sym, out_blob_object))
        return mut1_operand_blob_objects