Exemple #1
0
def _GetMut2OperandBlobObjects(
    self, op_attribute, parallel_desc_sym, bn_in_op2blob_object=None
):
    """Create blob objects for the outputs that are not header-inferred early.

    Every name in ``op_attribute.output_bns`` whose output modifier has
    ``header_infered_before_compute`` set is skipped. For each remaining
    output a new blob object is created through ``self.NewBlobObject``.

    Args:
        op_attribute: Op description providing ``output_bns``,
            ``parallel_signature`` and ``arg_modifier_signature``.
        parallel_desc_sym: Placement symbol used for an output that has no
            entry in the op's ``bn_in_op2parallel_desc_symbol_id`` mapping.
        bn_in_op2blob_object: Optional mapping that receives
            ``obn -> blob object`` for every blob created here. When omitted
            (or ``None``) a fresh private dict is used for this call.

    Returns:
        A list of ``(symbol, blob_object)`` tuples, one per created blob, in
        ``output_bns`` order; ``symbol`` comes from ``self.GetSymbol4String``.
    """
    # A literal `{}` default is created once and shared by every call; since
    # this function writes into the mapping, entries (and the blob objects
    # they reference) would leak from one call into the next.
    if bn_in_op2blob_object is None:
        bn_in_op2blob_object = {}

    # Loop-invariant lookups, resolved once instead of once per output.
    bn2symbol_id = op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id
    obn2modifier = op_attribute.arg_modifier_signature.obn2output_blob_modifier
    op_attribute_str = str(op_attribute)

    def GetOutBlobParallelDescSymbol(obn):
        # Prefer the per-output placement; fall back to the caller's symbol.
        if obn in bn2symbol_id:
            return oneflow_api.GetPlacementSymbol(bn2symbol_id[obn])
        return parallel_desc_sym

    mut2_operand_blob_objects = []
    for obn in op_attribute.output_bns:
        if obn2modifier[obn].header_infered_before_compute:
            continue
        obn_sym = self.GetSymbol4String(obn)
        op_arg_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
            GetOutBlobParallelDescSymbol(obn), op_attribute_str, obn
        )
        op_arg_blob_attr = oneflow_api.GetOpArgBlobAttribute(op_attribute_str, obn)
        out_blob_object = self.NewBlobObject(op_arg_parallel_attr, op_arg_blob_attr)
        bn_in_op2blob_object[obn] = out_blob_object
        mut2_operand_blob_objects.append((obn_sym, out_blob_object))
    return mut2_operand_blob_objects
Exemple #2
0
def MakeLazyRefBlobObject(self, interface_op_name):
    """Build a blob object for an interface op's single output and lazily reference it.

    The op attribute and parallel configuration are read from the default
    session. A parallel conf that is not already an instance of
    ``oneflow_api.oneflow.core.job.placement.ParallelConf`` is copied
    (device tag and device names) into a new ``placement_cfg.ParallelConf``.

    Args:
        interface_op_name: Name of the interface op to look up in the session.

    Returns:
        The blob object created by ``self.NewBlobObject`` after it has been
        passed to ``self.LazyReference`` together with ``interface_op_name``.

    Raises:
        AssertionError: If the op does not have exactly one output name.
    """
    session = session_ctx.GetDefaultSession()
    attribute = session.OpAttribute4InterfaceOpName(interface_op_name)
    assert len(attribute.output_bns) == 1
    out_name = attribute.output_bns[0]

    raw_conf = session.ParallelConf4LazyInterfaceOpName(interface_op_name)
    if isinstance(raw_conf, oneflow_api.oneflow.core.job.placement.ParallelConf):
        placement_conf = raw_conf
    else:
        # Copy the fields over into the cfg-style ParallelConf object.
        placement_conf = placement_cfg.ParallelConf()
        placement_conf.set_device_tag(raw_conf.device_tag)
        for name in raw_conf.device_name:
            placement_conf.add_device_name(name)
    placement_sym = self.GetParallelDescSymbol(placement_conf)

    parallel_attr = oneflow_api.GetOpArgParallelAttribute(
        placement_sym, str(attribute), out_name
    )
    blob_attr = oneflow_api.GetOpArgBlobAttribute(str(attribute), out_name)

    lazy_blob = self.NewBlobObject(parallel_attr, blob_attr)
    self.LazyReference(lazy_blob, interface_op_name)
    return lazy_blob
Exemple #3
0
def DistributeConcatOrAdd(op_attribute, parallel_conf, blob_register):
    """Box every ``in_<i>`` input and pack the results into the logical ``"out"`` blob.

    Each input registered under ``in_0 .. in_{n-1}`` (``n`` being the number
    of ``op_attribute.input_bns``) is boxed via ``boxing_util.BoxingTo`` to
    the placement recorded for it in the op's parallel signature. The boxed
    blobs are combined with ``builder.PackPhysicalBlobsToLogicalBlob`` and
    stored under ``"out"``. The work is submitted through
    ``oneflow_api.deprecated.LogicalRun``.

    Args:
        op_attribute: Op description providing ``parallel_signature`` and
            ``input_bns``.
        parallel_conf: Accepted for interface compatibility; not read here.
        blob_register: Register handed to
            ``blob_register_util.BnInOp2BlobObjectScope``.
    """
    signature = op_attribute.parallel_signature
    out_placement_sym = oneflow_api.GetPlacementSymbol(
        signature.op_parallel_desc_symbol_id
    )
    input_count = len(op_attribute.input_bns)
    out_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
        out_placement_sym, str(op_attribute), "out"
    )
    out_blob_attr = oneflow_api.GetOpArgBlobAttribute(str(op_attribute), "out")
    bn2symbol_id = signature.bn_in_op2parallel_desc_symbol_id

    def BoxInput(builder, index, bn2blob):
        # Box the index-th input blob to the placement recorded for it.
        name = "in_%s" % index
        source_blob = bn2blob[name]
        placement_sym = oneflow_api.GetPlacementSymbol(bn2symbol_id[name])
        target_attr = oneflow_api.GetOpArgParallelAttribute(
            placement_sym, str(op_attribute), name
        )
        return boxing_util.BoxingTo(builder, source_blob, target_attr)

    def BuildInstruction(builder):
        scope = blob_register_util.BnInOp2BlobObjectScope(
            blob_register, op_attribute
        )
        with scope as bn2blob:
            boxed_inputs = []
            for index in range(input_count):
                boxed_inputs.append(BoxInput(builder, index, bn2blob))
            bn2blob["out"] = builder.PackPhysicalBlobsToLogicalBlob(
                boxed_inputs, out_parallel_attr, out_blob_attr
            )

    oneflow_api.deprecated.LogicalRun(BuildInstruction)
Exemple #4
0
 def GetInBlobObject(builder, ibn, bn_in_op2blob_object):
     """Return the blob registered under ``ibn`` after boxing it.

     The target parallel attribute is built from the placement symbol id
     stored at ``parallel_sig[ibn]``; ``parallel_sig`` and ``op_attribute``
     come from the enclosing scope (not visible in this fragment).
     """
     source_blob = bn_in_op2blob_object[ibn]
     placement_sym = oneflow_api.GetPlacementSymbol(parallel_sig[ibn])
     target_attr = oneflow_api.GetOpArgParallelAttribute(
         placement_sym, str(op_attribute), ibn
     )
     return boxing_util.BoxingTo(builder, source_blob, target_attr)
Exemple #5
0
 def BuildInstruction(builder):
     """Register ``"out"`` as a reference blob object made from the ``"in"`` blob.

     The output's parallel attribute reuses the input blob's
     ``parallel_desc_symbol``. ``blob_register`` and ``op_attribute`` come
     from the enclosing scope (not visible in this fragment).
     """
     scope = blob_register_util.BnInOp2BlobObjectScope(
         blob_register, op_attribute
     )
     with scope as bn2blob:
         source_blob = bn2blob["in"]
         out_parallel_attr = oneflow_api.GetOpArgParallelAttribute(
             source_blob.parallel_desc_symbol, str(op_attribute), "out"
         )
         bn2blob["out"] = builder.MakeReferenceBlobObject(
             source_blob, out_parallel_attr
         )
Exemple #6
0
 def DelegateBlobObject4Ibn(ibn):
     """Return the delegate blob object for input ``ibn``.

     Builds the parallel attribute for ``ibn`` from ``op_parallel_desc_sym``
     and passes it, with the blob registered under ``ibn``, to
     ``get_delegate_blob_object``. All of these names come from the
     enclosing scope (not visible in this fragment).
     """
     # The attribute is built before the blob lookup, as in the call order
     # of the surrounding code.
     target_attr = oneflow_api.GetOpArgParallelAttribute(
         op_parallel_desc_sym, str(op_attribute), ibn
     )
     source_blob = bn_in_op2blob_object[ibn]
     return get_delegate_blob_object(source_blob, target_attr)