コード例 #1
0
 def GetOpArgParallelAttr(
     builder, produced_blob_object, consumer_op_arg_parallel_attr
 ):
     """Return the OpArgParallelAttribute a consumer should see for a produced blob.

     The parallel-desc symbol and the SBP parallel are resolved by the helpers
     ``get_parallel_desc_symbol`` and ``get_sbp_parallel`` (defined elsewhere),
     each called with ``builder``, ``produced_blob_object`` and
     ``consumer_op_arg_parallel_attr``. The ``opt_mirrored_parallel`` field is
     copied unchanged from the producer's ``op_arg_parallel_attr``.
     """
     # Resolved in the same order the constructor receives them positionally.
     desc_symbol = get_parallel_desc_symbol(
         builder, produced_blob_object, consumer_op_arg_parallel_attr
     )
     sbp = get_sbp_parallel(
         builder, produced_blob_object, consumer_op_arg_parallel_attr
     )
     mirrored = produced_blob_object.op_arg_parallel_attr.opt_mirrored_parallel
     return op_arg_util.OpArgParallelAttribute(desc_symbol, sbp, mirrored)
コード例 #2
0
 def BuildAssignInstruction(builder):
     """Box the value blob onto a CPU-tagged variable placement, then assign it.

     ``var_blob_object`` and ``value_blob_object`` are free variables taken
     from the enclosing scope (not visible here). The target placement is
     ``var_blob_object.parallel_desc_symbol`` passed through
     ``boxing_util.TryReplaceDeviceTag`` with ``"cpu"``. The variable's own
     ``sbp_parallel`` and ``opt_mirrored_parallel`` are reused for the target
     attribute. The boxed result is handed to ``boxing_util.Assign`` together
     with ``var_blob_object``.
     """
     # NOTE(review): "Try" in the helper's name suggests the symbol may come
     # back unchanged when the tag cannot be replaced -- confirm in boxing_util.
     cpu_desc_symbol = boxing_util.TryReplaceDeviceTag(
         builder, var_blob_object.parallel_desc_symbol, "cpu"
     )
     var_attr = var_blob_object.op_arg_parallel_attr
     target_attr = op_arg_util.OpArgParallelAttribute(
         cpu_desc_symbol, var_attr.sbp_parallel, var_attr.opt_mirrored_parallel
     )
     boxed_value = boxing_util.BoxingTo(builder, value_blob_object, target_attr)
     boxing_util.Assign(builder, var_blob_object, boxed_value)
コード例 #3
0
ファイル: remote_blob.py プロジェクト: Sodu-Qinming/Oneflow
 def BoxingToSingleDevice(builder):
     parallel_conf = placement_pb.ParallelConf()
     parallel_conf.device_tag = blob_object.parallel_desc_symbol.device_tag
     parallel_conf.device_name.append("{}:{}".format(0, 0))
     tmp_parallel_desc_symbol = builder.GetParallelDescSymbol(parallel_conf)
     tmp_op_arg_parallel_attr = op_arg_util.OpArgParallelAttribute(
         tmp_parallel_desc_symbol,
         blob_object.op_arg_parallel_attr.sbp_parallel,
         blob_object.op_arg_parallel_attr.opt_mirrored_parallel,
     )
     with oneflow.scope.placement(
         self.parallel_conf.device_tag, list(self.parallel_conf.device_name)
     ):
         tmp_blob_object = boxing_util.BoxingTo(
             builder, blob_object, tmp_op_arg_parallel_attr
         )
     nonlocal consistent_blob_name
     consistent_blob_name = "{}-consistent".format(self.logical_blob_name)
     if not blob_register.HasObject4BlobName(consistent_blob_name):
         blob_register.SetObject4BlobName(
             consistent_blob_name, tmp_blob_object
         )