Example #1
0
 def BoxingToSingleDevice(builder):
     parallel_conf = placement_cfg.ParallelConf()
     parallel_conf.set_device_tag(
         blob_object.parallel_desc_symbol.device_tag)
     parallel_conf.add_device_name("{}:{}".format(0, 0))
     tmp_parallel_desc_symbol = builder.GetParallelDescSymbol(
         parallel_conf)
     tmp_op_arg_parallel_attr = oneflow_api.OpArgParallelAttribute(
         tmp_parallel_desc_symbol,
         str(blob_object.op_arg_parallel_attr.sbp_parallel),
         str(blob_object.op_arg_parallel_attr.opt_mirrored_parallel
             ),
     )
     with oneflow.scope.placement(
             self.parallel_conf.device_tag(),
             list(self.parallel_conf.device_name()),
     ):
         tmp_blob_object = boxing_util.BoxingTo(
             builder, blob_object, tmp_op_arg_parallel_attr)
     nonlocal consistent_blob_name
     consistent_blob_name = "{}-consistent".format(
         self.logical_blob_name)
     if not blob_register.HasObject4BlobName(consistent_blob_name):
         blob_register.SetObject4BlobName(consistent_blob_name,
                                          tmp_blob_object)
Example #2
0
 def GetOpArgParallelAttr(builder, produced_blob_object,
                          consumer_op_arg_parallel_attr):
     """Build the op-arg parallel attribute for a produced blob.

     The parallel-desc symbol and the SBP parallel come from the
     ``get_parallel_desc_symbol`` / ``get_sbp_parallel`` callables in the
     enclosing scope. Each is called with
     ``(builder, produced_blob_object, consumer_op_arg_parallel_attr)``.
     The mirrored-parallel setting is copied unchanged from
     ``produced_blob_object``.

     Args:
         builder: instruction builder forwarded to both callables.
         produced_blob_object: the blob whose attribute is being derived.
         consumer_op_arg_parallel_attr: the consumer-side attribute forwarded
             to both callables.

     Returns:
         A new ``oneflow_api.OpArgParallelAttribute``. The SBP and mirrored
         parts are passed as their ``str()`` forms.
     """
     parallel_desc_symbol = get_parallel_desc_symbol(
         builder, produced_blob_object, consumer_op_arg_parallel_attr)
     sbp_parallel_str = str(
         get_sbp_parallel(builder, produced_blob_object,
                          consumer_op_arg_parallel_attr))
     mirrored_parallel_str = str(
         produced_blob_object.op_arg_parallel_attr.opt_mirrored_parallel)
     return oneflow_api.OpArgParallelAttribute(
         parallel_desc_symbol,
         sbp_parallel_str,
         mirrored_parallel_str,
     )
Example #3
0
 def BuildAssignInstruction(builder):
     """Box ``value_blob_object`` to match ``var_blob_object`` and assign it.

     The target parallel desc is obtained from
     ``boxing_util.TryReplaceDeviceTag`` applied to the variable's parallel
     desc with ``"cpu"``. It presumably yields a cpu-tagged variant when a
     replacement is possible; this is not confirmed here. The SBP and
     mirrored settings are copied from ``var_blob_object``. The value is
     boxed to that attribute, and the boxed result is assigned into the
     variable.

     Args:
         builder: instruction builder used for the replacement, the boxing
             and the assignment.
     """
     cpu_parallel_desc_symbol = boxing_util.TryReplaceDeviceTag(
         builder, var_blob_object.parallel_desc_symbol, "cpu")

     var_attr = var_blob_object.op_arg_parallel_attr
     target_attr = oneflow_api.OpArgParallelAttribute(
         cpu_parallel_desc_symbol,
         str(var_attr.sbp_parallel),
         str(var_attr.opt_mirrored_parallel),
     )

     boxed_value = boxing_util.BoxingTo(builder, value_blob_object,
                                        target_attr)
     boxing_util.Assign(builder, var_blob_object, boxed_value)