Ejemplo n.º 1
0
 def __init__(self, symbol_id, scope_proto, parent_scope_symbol=None):
     """Build a scope symbol from ``scope_proto``.

     The proto stores the job-config symbol and the device/host placement
     symbols only by id. Each id is resolved to its symbol object through
     ``oneflow_api`` and cached on the instance.

     Args:
         symbol_id: id forwarded to ``Symbol.__init__``.
         scope_proto: message exposing ``job_desc_symbol_id()``,
             ``device_parallel_desc_symbol_id()`` and
             ``host_parallel_desc_symbol_id()``.
         parent_scope_symbol: optional parent scope symbol, stored as-is.
     """
     Symbol.__init__(self, symbol_id, scope_proto)
     self.parent_scope_symbol_ = parent_scope_symbol

     # Resolve ids in the same order as always: job desc, device placement,
     # host placement.
     job_desc_id = scope_proto.job_desc_symbol_id()
     self.job_desc_symbol_ = oneflow_api.GetJobConfSymbol(job_desc_id)

     device_placement_id = scope_proto.device_parallel_desc_symbol_id()
     self.device_parallel_desc_symbol_ = oneflow_api.GetPlacementSymbol(
         device_placement_id)

     host_placement_id = scope_proto.host_parallel_desc_symbol_id()
     self.host_parallel_desc_symbol_ = oneflow_api.GetPlacementSymbol(
         host_placement_id)

     # Counter initialised to zero; it is not read within this constructor.
     self.auto_increment_id_ = 0
Ejemplo n.º 2
0
 def GetOutBlobParallelDescSymbol(obn):
     """Return the placement symbol to use for output blob name ``obn``.

     If the op's parallel signature records a placement symbol id for
     ``obn``, that id is resolved with ``oneflow_api.GetPlacementSymbol``.
     Otherwise ``parallel_desc_sym`` is returned.
     """
     # ``op_attribute`` and ``parallel_desc_sym`` come from the enclosing
     # scope, which is not shown here.
     symbol_ids = op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id
     if obn not in symbol_ids:
         return parallel_desc_sym
     return oneflow_api.GetPlacementSymbol(symbol_ids[obn])
Ejemplo n.º 3
0
def DistributeConcatOrAdd(op_attribute, parallel_conf, blob_register):
    """Run an instruction packing the op's ``in_<i>`` blobs into its ``out`` blob.

    For every index ``i`` in ``range(len(op_attribute.input_bns))``, the blob
    registered under ``"in_<i>"`` is boxed to the placement recorded for that
    input in the op's parallel signature. The resulting physical blobs are
    packed into one logical blob, which is stored under ``"out"``. The
    instruction is executed via ``oneflow_api.deprecated.LogicalRun``.

    Args:
        op_attribute: op attribute message carrying ``parallel_signature``
            and ``input_bns``.
        parallel_conf: not used by this function's body.
        blob_register: register passed to
            ``blob_register_util.BnInOp2BlobObjectScope``.
    """
    op_placement = oneflow_api.GetPlacementSymbol(
        op_attribute.parallel_signature.op_parallel_desc_symbol_id)
    num_inputs = len(op_attribute.input_bns)
    out_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
        op_placement, str(op_attribute), "out")
    out_blob_attr = oneflow_api.GetOpArgBlobAttribute(str(op_attribute), "out")
    ibn2placement_id = op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id

    def GetInBlobObject(builder, i, bn_in_op2blob_object):
        # Input blob names follow the "in_<index>" pattern.
        ibn = "in_%s" % i
        source_blob_object = bn_in_op2blob_object[ibn]
        target_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
            oneflow_api.GetPlacementSymbol(ibn2placement_id[ibn]),
            str(op_attribute),
            ibn,
        )
        return boxing_util.BoxingTo(builder, source_blob_object,
                                    target_parallel_attr)

    def BuildInstruction(builder):
        scope = blob_register_util.BnInOp2BlobObjectScope(blob_register,
                                                          op_attribute)
        with scope as bn_in_op2blob_object:
            physical_blob_objects = [
                GetInBlobObject(builder, index, bn_in_op2blob_object)
                for index in range(num_inputs)
            ]
            logical_blob_object = builder.PackPhysicalBlobsToLogicalBlob(
                physical_blob_objects, out_parallel_attr, out_blob_attr)
            bn_in_op2blob_object["out"] = logical_blob_object

    oneflow_api.deprecated.LogicalRun(BuildInstruction)
Ejemplo n.º 4
0
 def GetInBlobObject(builder, ibn, bn_in_op2blob_object):
     """Box the blob registered under ``ibn`` to that input's placement.

     The placement symbol id for ``ibn`` is read from ``parallel_sig``. It is
     turned into an op-arg parallel attribute, and
     ``boxing_util.BoxingTo`` is asked to move the blob there.
     """
     # ``parallel_sig`` and ``op_attribute`` come from the enclosing scope.
     source_blob_object = bn_in_op2blob_object[ibn]
     placement_symbol_id = parallel_sig[ibn]
     target_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
         oneflow_api.GetPlacementSymbol(placement_symbol_id),
         str(op_attribute),
         ibn,
     )
     return boxing_util.BoxingTo(builder, source_blob_object,
                                 target_parallel_attr)
Ejemplo n.º 5
0
def _StatelessCall(
    self,
    stream_tag,
    op_attribute,
    op_parallel_desc_sym=None,
    blob_parallel_desc_sym=None,
    bn_in_op2blob_object=None,
    get_delegate_blob_object=None,
):
    """Emit a ``<stream_tag>.{User,System}StatelessCallOpKernel`` instruction.

    Resolves the op's placement, job-desc, op-conf and op-node-signature
    symbols plus a shared op-kernel object. It then collects the const and
    mutable input operands and the mut1/mut2 operands, and hands everything
    to ``self._StatelessCallOpKernel``.

    Args:
        stream_tag: prefix of the instruction name.
        op_attribute: op attribute message. Its ``op_conf`` must carry
            ``scope_symbol_id``.
        op_parallel_desc_sym: op placement symbol. It is replaced by the
            symbol for ``op_attribute.parallel_signature.op_parallel_desc_symbol_id``
            when that field is set, and must be non-``None`` afterwards.
        blob_parallel_desc_sym: placement symbol forwarded to the mut1/mut2
            operand helpers.
        bn_in_op2blob_object: mapping from blob name in op to blob object.
            Defaults to a fresh empty dict per call. It is passed to the
            mut1/mut2 operand helpers, which may write into it
            (NOTE(review): confirm against those helpers).
        get_delegate_blob_object: required callable
            ``(blob_object, op_arg_parallel_attr) -> blob_object`` used to
            resolve each input operand.

    Raises:
        AssertionError: if ``get_delegate_blob_object`` is not callable, no
            op placement symbol is available, ``op_conf`` lacks
            ``scope_symbol_id``, or the shared op-kernel object's placement
            differs from the op's.
    """
    # Use None instead of a ``{}`` default: a shared default dict would carry
    # entries written by one call into every later call that omits the arg.
    if bn_in_op2blob_object is None:
        bn_in_op2blob_object = {}
    assert callable(get_delegate_blob_object)
    if op_attribute.parallel_signature.HasField("op_parallel_desc_symbol_id"):
        symbol_id = op_attribute.parallel_signature.op_parallel_desc_symbol_id
        op_parallel_desc_sym = oneflow_api.GetPlacementSymbol(symbol_id)
    assert op_parallel_desc_sym is not None

    def DelegateBlobObject4Ibn(ibn):
        # Resolve input ``ibn`` through the caller-supplied delegate, using
        # the op-level placement for its parallel attribute.
        op_arg_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
            op_parallel_desc_sym, str(op_attribute), ibn
        )
        return get_delegate_blob_object(bn_in_op2blob_object[ibn], op_arg_parallel_attr)

    op_conf = op_attribute.op_conf
    assert op_conf.HasField("scope_symbol_id"), op_conf
    scope_symbol = oneflow_api.GetScopeSymbol(op_conf.scope_symbol_id)
    job_desc_sym = scope_symbol.job_desc_symbol
    op_conf_sym = self._GetOpConfSymbol(op_conf)
    op_node_signature_sym = self._GetOpNodeSignatureSymbol(op_attribute)
    opkernel_obj = self.GetSharedOpKernelObject4ParallelConfSymbol(op_parallel_desc_sym)
    # The shared kernel object must live on the same placement as the op.
    assert opkernel_obj.parallel_desc_symbol == op_parallel_desc_sym, (
        str(opkernel_obj.parallel_desc_symbol.parallel_conf),
        str(op_parallel_desc_sym.parallel_conf),
    )
    const_input_operand_blob_objects = self._GetConstInputOperandBlobObjects(
        op_attribute, blob_object4ibn=DelegateBlobObject4Ibn
    )
    mutable_input_operand_blob_objects = self._GetMutableInputOperandBlobObjects(
        op_attribute, blob_object4ibn=DelegateBlobObject4Ibn
    )
    mut1_operand_blob_objects = self._GetMut1OperandBlobObjects(
        op_attribute, blob_parallel_desc_sym, bn_in_op2blob_object=bn_in_op2blob_object,
    )
    mut2_operand_blob_objects = self._GetMut2OperandBlobObjects(
        op_attribute, blob_parallel_desc_sym, bn_in_op2blob_object=bn_in_op2blob_object,
    )
    # User ops and system ops are dispatched to differently named instructions.
    is_user_op = op_attribute.op_conf.HasField("user_conf")
    instruction_prefix = "User" if is_user_op else "System"
    self._StatelessCallOpKernel(
        "%s.%sStatelessCallOpKernel" % (stream_tag, instruction_prefix),
        op_parallel_desc_sym,
        job_desc_sym,
        op_conf_sym,
        op_node_signature_sym,
        opkernel_obj,
        const_input_operand_blob_objects,
        mutable_input_operand_blob_objects,
        mut1_operand_blob_objects,
        mut2_operand_blob_objects,
    )
Ejemplo n.º 6
0
def NewOpKernelObject(self, op_conf):
    """Create an op-kernel object for ``op_conf`` and wrap it.

    Resolves the op's scope symbol, op-conf symbol and placement symbol, and
    asks ``self._NewOpKernelObject`` for a new object id. The id is returned
    wrapped in ``OpKernelObject`` together with ``self.object_releaser()``.

    Raises:
        AssertionError: if ``op_conf`` has no ``scope_symbol_id``.
    """
    assert op_conf.HasField("scope_symbol_id")
    scope_symbol = oneflow_api.GetScopeSymbol(op_conf.scope_symbol_id)
    op_conf_symbol = self._GetOpConfSymbol(op_conf)
    placement_symbol = oneflow_api.GetPlacementSymbol(
        c_api_util.GetOpParallelSymbolId(op_conf)
    )
    kernel_object_id = self._NewOpKernelObject(
        placement_symbol, scope_symbol.job_desc_symbol, op_conf_symbol
    )
    releaser = self.object_releaser()
    return OpKernelObject(kernel_object_id, op_conf, releaser)
Ejemplo n.º 7
0
def NewOpKernelObject(self, op_conf):
    """Create an op-kernel object for ``op_conf``, returned as ``oneflow_api.OpKernelObject``.

    ``op_conf`` is serialized with ``str`` and re-parsed into a cfg op conf
    via ``oneflow_api.deprecated.MakeOpConfByString``. The cfg op conf is
    used both for the op-conf symbol and for the returned wrapper.

    Raises:
        AssertionError: if ``op_conf`` has no ``scope_symbol_id``.
    """
    assert op_conf.HasField("scope_symbol_id")
    scope_symbol = oneflow_api.GetScopeSymbol(op_conf.scope_symbol_id)
    serialized_op_conf = str(op_conf)
    cfg_op_conf = oneflow_api.deprecated.MakeOpConfByString(serialized_op_conf)
    op_conf_symbol = self.GetOpConfSymbol(cfg_op_conf)
    placement_symbol = oneflow_api.GetPlacementSymbol(
        c_api_util.GetOpParallelSymbolId(op_conf)
    )
    kernel_object_id = self._NewOpKernelObject(
        placement_symbol, scope_symbol.job_desc_symbol, op_conf_symbol
    )
    releaser = self.object_releaser()
    return oneflow_api.OpKernelObject(kernel_object_id, cfg_op_conf, releaser)
Ejemplo n.º 8
0
def _GetOpParallelSymbol(op_conf):
    """Return the placement symbol for ``op_conf``.

    The symbol id comes from ``c_api_util.GetOpParallelSymbolId`` and is
    resolved with ``oneflow_api.GetPlacementSymbol``.

    Raises:
        AssertionError: if ``op_conf`` has no ``scope_symbol_id``.
    """
    assert op_conf.HasField("scope_symbol_id")
    return oneflow_api.GetPlacementSymbol(c_api_util.GetOpParallelSymbolId(op_conf))