Example #1
0
    def _StatelessCall(
        self,
        stream_tag,
        op_attribute,
        op_parallel_desc_sym=None,
        blob_parallel_desc_sym=None,
        bn_in_op2blob_object=None,
        get_delegate_blob_object=None,
    ):
        """Emit a stateless op-kernel call instruction for ``op_attribute``.

        Resolves the symbols the instruction needs (op parallel desc, job
        desc, op conf, op node signature and the shared op-kernel object).
        Gathers the const/mutable input operands and the mut1/mut2 operand
        blob objects. Passes everything to ``self._StatelessCallOpKernel``
        under the instruction name ``"<stream_tag>.UserStatelessCallOpKernel"``
        when the op conf has ``user_conf``, and
        ``"<stream_tag>.SystemStatelessCallOpKernel"`` otherwise.

        Args:
            stream_tag: prefix of the emitted instruction name.
            op_attribute: attribute proto of the op being called; its
                ``op_conf`` must have ``scope_symbol_id`` set.
            op_parallel_desc_sym: op parallel-desc symbol. It is replaced by
                the symbol registered under
                ``op_attribute.parallel_signature.op_parallel_desc_symbol_id``
                when that field is set, and must end up non-None.
            blob_parallel_desc_sym: forwarded to the mut1/mut2 operand helpers.
            bn_in_op2blob_object: mapping from blob name in op to blob object.
                Input blob objects are read from it, and it is handed to the
                mut1/mut2 operand helpers. A fresh dict is used when omitted.
            get_delegate_blob_object: required callable taking
                ``(blob_object, op_arg_parallel_attr)``. Its result is used as
                the blob object for that input.

        Raises:
            AssertionError: if ``get_delegate_blob_object`` is not callable,
                no op parallel-desc symbol is available, or the op conf lacks
                ``scope_symbol_id``.
        """
        assert callable(get_delegate_blob_object)
        # A literal ``{}`` default would be one dict shared by every call that
        # omits this argument. The dict is handed to the mut1/mut2 helpers
        # below, so anything they store would leak into later calls. Use a
        # fresh mapping per call instead.
        if bn_in_op2blob_object is None:
            bn_in_op2blob_object = {}
        if op_attribute.parallel_signature.HasField(
                "op_parallel_desc_symbol_id"):
            symbol_id = op_attribute.parallel_signature.op_parallel_desc_symbol_id
            op_parallel_desc_sym = symbol_storage.GetSymbol4Id(symbol_id)
        assert op_parallel_desc_sym is not None

        def DelegateBlobObject4Ibn(ibn):
            # Let the caller substitute the blob object fed to input ``ibn``,
            # given the parallel attribute that input expects.
            op_arg_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
                op_parallel_desc_sym, op_attribute, ibn)
            return get_delegate_blob_object(bn_in_op2blob_object[ibn],
                                            op_arg_parallel_attr)

        op_conf = op_attribute.op_conf
        assert op_conf.HasField("scope_symbol_id"), op_conf
        scope_symbol = symbol_storage.GetSymbol4Id(op_conf.scope_symbol_id)
        job_desc_sym = scope_symbol.job_desc_symbol
        op_conf_sym = self._GetOpConfSymbol(op_conf)
        op_node_signature_sym = self._GetOpNodeSignatureSymbol(op_attribute)
        opkernel_obj = self.GetSharedOpKernelObject4ParallelConfSymbol(
            op_parallel_desc_sym)
        const_input_operand_blob_objects = self._GetConstInputOperandBlobObjects(
            op_attribute, blob_object4ibn=DelegateBlobObject4Ibn)
        mutable_input_operand_blob_objects = self._GetMutableInputOperandBlobObjects(
            op_attribute, blob_object4ibn=DelegateBlobObject4Ibn)
        mut1_operand_blob_objects = self._GetMut1OperandBlobObjects(
            op_attribute,
            blob_parallel_desc_sym,
            bn_in_op2blob_object=bn_in_op2blob_object,
        )
        mut2_operand_blob_objects = self._GetMut2OperandBlobObjects(
            op_attribute,
            blob_parallel_desc_sym,
            bn_in_op2blob_object=bn_in_op2blob_object,
        )
        # User ops and system ops are dispatched to different instructions.
        is_user_op = op_conf.HasField("user_conf")
        instruction_prefix = "User" if is_user_op else "System"
        self._StatelessCallOpKernel(
            "%s.%sStatelessCallOpKernel" % (stream_tag, instruction_prefix),
            op_parallel_desc_sym,
            job_desc_sym,
            op_conf_sym,
            op_node_signature_sym,
            opkernel_obj,
            const_input_operand_blob_objects,
            mutable_input_operand_blob_objects,
            mut1_operand_blob_objects,
            mut2_operand_blob_objects,
        )
Example #2
0
 def __init__(self, symbol_id, scope_proto, parent_scope_symbol=None):
     """Initialize the symbol from ``scope_proto``.

     After the base ``Symbol`` initialization, stores ``parent_scope_symbol``.
     Then resolves, through ``symbol_storage``, the job-desc, device
     parallel-desc and host parallel-desc symbols whose ids are recorded in
     ``scope_proto``.
     """
     Symbol.__init__(self, symbol_id, scope_proto)
     self.parent_scope_symbol_ = parent_scope_symbol
     lookup = symbol_storage.GetSymbol4Id
     self.job_desc_symbol_ = lookup(scope_proto.job_desc_symbol_id)
     self.device_parallel_desc_symbol_ = lookup(
         scope_proto.device_parallel_desc_symbol_id
     )
     self.host_parallel_desc_symbol_ = lookup(
         scope_proto.host_parallel_desc_symbol_id
     )
     # Per-instance counter, starting from zero.
     self.auto_increment_id_ = 0
Example #3
0
 def NewOpKernelObject(self, op_conf):
     """Create a new op-kernel object for ``op_conf``.

     ``op_conf`` must have ``scope_symbol_id`` set. The op's parallel-desc
     symbol, its scope's job-desc symbol and the op-conf symbol are passed to
     ``self._NewOpKernelObject``. The returned object id is wrapped in an
     ``OpKernelObject`` together with ``op_conf`` and ``self.release_object_``.
     """
     assert op_conf.HasField("scope_symbol_id")
     scope_sym = symbol_storage.GetSymbol4Id(op_conf.scope_symbol_id)
     conf_sym = self._GetOpConfSymbol(op_conf)
     parallel_sym_id = c_api_util.GetOpParallelSymbolId(op_conf)
     parallel_sym = symbol_storage.GetSymbol4Id(parallel_sym_id)
     job_desc_sym = scope_sym.job_desc_symbol
     kernel_object_id = self._NewOpKernelObject(
         parallel_sym, job_desc_sym, conf_sym
     )
     return OpKernelObject(kernel_object_id, op_conf, self.release_object_)
Example #4
0
 def GetOutBlobParallelDescSymbol(obn):
     """Return the parallel-desc symbol to use for output blob ``obn``.

     If ``op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id``
     has an entry for ``obn``, the symbol registered under that id is
     returned. Otherwise ``parallel_desc_sym`` is returned. Both
     ``op_attribute`` and ``parallel_desc_sym`` come from the enclosing scope.
     """
     symbol_ids = op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id
     if obn not in symbol_ids:
         return parallel_desc_sym
     return symbol_storage.GetSymbol4Id(symbol_ids[obn])
Example #5
0
def DistributeConcatOrAdd(op_attribute, parallel_conf, blob_register):
    """Pack the boxed ``in_<i>`` blobs of an op into its logical ``out`` blob.

    For every input index ``i`` in ``range(len(op_attribute.input_bns))``:
    the blob object bound to ``"in_<i>"`` is boxed onto the parallel attribute
    recorded for that input in the op's parallel signature. The resulting
    blobs are then packed with ``builder.PackPhysicalBlobsToLogicalBlob`` into
    ``"out"``, using ``"out"``'s parallel and blob attributes. The instruction
    is run through ``vm_util.LogicalRun``.

    Note: ``parallel_conf`` is accepted but not used here.
    """
    signature = op_attribute.parallel_signature
    op_parallel_desc_sym = symbol_storage.GetSymbol4Id(
        signature.op_parallel_desc_symbol_id)
    num_inputs = len(op_attribute.input_bns)
    out_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
        op_parallel_desc_sym, op_attribute, "out")
    out_blob_attr = op_arg_util.GetOpArgBlobAttribute(op_attribute, "out")
    ibn2symbol_id = signature.bn_in_op2parallel_desc_symbol_id

    def BoxInput(builder, index, bn_in_op2blob_object):
        # Move the blob bound to ``in_<index>`` onto that input's parallel desc.
        ibn = "in_%s" % index
        source_blob = bn_in_op2blob_object[ibn]
        input_parallel_desc_sym = symbol_storage.GetSymbol4Id(
            ibn2symbol_id[ibn])
        input_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
            input_parallel_desc_sym, op_attribute, ibn)
        return boxing_util.BoxingTo(builder, source_blob, input_parallel_attr)

    def BuildInstruction(builder):
        with blob_register.BnInOp2BlobObjectScope(
                op_attribute) as bn_in_op2blob_object:
            physical_blobs = [
                BoxInput(builder, index, bn_in_op2blob_object)
                for index in range(num_inputs)
            ]
            bn_in_op2blob_object["out"] = builder.PackPhysicalBlobsToLogicalBlob(
                physical_blobs, out_parallel_attr, out_blob_attr)

    vm_util.LogicalRun(BuildInstruction)
Example #6
0
 def GetInBlobObject(builder, ibn, bn_in_op2blob_object):
     """Box the blob object bound to input ``ibn`` onto that input's parallel attribute.

     The target parallel-desc symbol id is read from ``parallel_sig[ibn]``.
     ``parallel_sig`` and ``op_attribute`` come from the enclosing scope.
     """
     source_blob = bn_in_op2blob_object[ibn]
     target_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
         symbol_storage.GetSymbol4Id(parallel_sig[ibn]), op_attribute, ibn
     )
     return boxing_util.BoxingTo(builder, source_blob, target_parallel_attr)
def AddScopeToStorage(scope_symbol_id, scope_proto_str):
    """Register a scope symbol built from a text-format ``ScopeProto`` string.

    Does nothing if a symbol is already stored for ``scope_proto_str``.
    Otherwise:
    - parses the string into a ``ScopeProto``;
    - resolves its parent scope symbol;
    - builds a ``ScopeSymbol``;
    - stores it under both ``scope_symbol_id`` and ``scope_proto_str``.
    """
    if not symbol_storage.HasSymbol4SerializedScopeProto(scope_proto_str):
        scope_proto = text_format.Parse(scope_proto_str, scope_pb.ScopeProto())
        parent = symbol_storage.GetSymbol4Id(scope_proto.parent_scope_symbol_id)
        new_symbol = scope_symbol.ScopeSymbol(scope_symbol_id, scope_proto, parent)
        symbol_storage.SetSymbol4Id(scope_symbol_id, new_symbol)
        symbol_storage.SetSymbol4SerializedScopeProto(scope_proto_str, new_symbol)
Example #8
0
def AddScopeToStorage(scope_symbol_id, scope_proto):
    """Register a scope symbol for an already-built ``scope_proto`` object.

    ``str(scope_proto)`` is the deduplication key: if a symbol is already
    stored for it, nothing happens. Otherwise a ``ScopeSymbol`` is created
    with the parent symbol looked up via
    ``scope_proto.parent_scope_symbol_id()``. It is stored under both
    ``scope_symbol_id`` and the serialized string.
    """
    serialized = str(scope_proto)
    if not symbol_storage.HasSymbol4SerializedScopeProto(serialized):
        parent_id = scope_proto.parent_scope_symbol_id()
        parent = symbol_storage.GetSymbol4Id(parent_id)
        new_symbol = scope_symbol.ScopeSymbol(scope_symbol_id, scope_proto, parent)
        symbol_storage.SetSymbol4Id(scope_symbol_id, new_symbol)
        symbol_storage.SetSymbol4SerializedScopeProto(serialized, new_symbol)
Example #9
0
    def MakeLazyRefBlobObject(self, interface_op_name):
        """Create a blob object lazily referencing interface op ``interface_op_name``.

        The op attribute is fetched from the default session and must have
        exactly one output. A new blob object is built from that output's
        parallel and blob attributes and linked to the interface op via
        ``self._LazyReference``. The new blob object is returned.
        """
        session = session_ctx.GetDefaultSession()
        op_attribute = session.OpAttribute4InterfaceOpName(interface_op_name)
        output_bns = op_attribute.output_bns
        assert len(output_bns) == 1
        (obn,) = output_bns

        symbol_ids = op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id
        parallel_desc_sym = symbol_storage.GetSymbol4Id(symbol_ids[obn])
        parallel_attr = op_arg_util.GetOpArgParallelAttribute(
            parallel_desc_sym, op_attribute, obn)
        blob_attr = op_arg_util.GetOpArgBlobAttribute(op_attribute, obn)

        lazy_blob = self._NewBlobObject(parallel_attr, blob_attr)
        self._LazyReference(lazy_blob, interface_op_name)
        return lazy_blob
Example #10
0
def _GetOpParallelSymbol(op_conf):
    """Look up the op parallel-desc symbol for ``op_conf``.

    ``op_conf`` must have ``scope_symbol_id`` set. The symbol id is obtained
    from ``c_api_util.GetOpParallelSymbolId``.
    """
    assert op_conf.HasField("scope_symbol_id")
    return symbol_storage.GetSymbol4Id(c_api_util.GetOpParallelSymbolId(op_conf))
Example #11
0
def _GetScopeSymbol(op_conf):
    """Return the symbol registered under ``op_conf.scope_symbol_id``.

    ``op_conf`` must have ``scope_symbol_id`` set.
    """
    assert op_conf.HasField("scope_symbol_id")
    scope_id = op_conf.scope_symbol_id
    return symbol_storage.GetSymbol4Id(scope_id)