예제 #1
0
 def GetPhyBlobNumpy(i, phy_blob_object):
     """Register a physical blob object under an indexed name and read it as numpy.

     The object is stored in ``blob_register`` under the name
     "<self.logical_blob_name>/<i>". An ``EagerPhysicalBlob`` is then built for
     that name and converted to numpy.

     Args:
         i: Index appended to ``self.logical_blob_name`` to form the physical name.
         phy_blob_object: Blob object to register under that name.

     Returns:
         ``numpy_list()`` of the physical blob when ``self.is_tensor_list`` is
         truthy, otherwise its ``numpy()``.
     """
     # ``self`` is not a parameter: it is captured from the enclosing scope.
     physical_name = "{}/{}".format(self.logical_blob_name, i)
     blob_register.SetObject4BlobName(physical_name, phy_blob_object)
     # Read the flag before building the blob, matching the original evaluation order.
     as_list = self.is_tensor_list
     physical_blob = eager_blob_util.EagerPhysicalBlob(physical_name)
     return physical_blob.numpy_list() if as_list else physical_blob.numpy()
예제 #2
0
        def FetchBlobNumpy(blob_object):
            """Box ``blob_object`` onto a single device and return it as numpy.

            A callback passed to ``vm_util.LogicalRun`` does the following:
            it builds a placement with device name "0:0" and the same device tag
            as ``blob_object``, then boxes the blob onto it under the placement
            scope given by ``self.parallel_conf``. The SBP and mirrored
            attributes are kept. The result is registered as
            "<self.logical_blob_name>-consistent" if that name is not registered
            yet. The blob registered under that name is then read back via
            ``EagerPhysicalBlob(...).numpy()``.

            Args:
                blob_object: Source blob object. Its ``parallel_desc_symbol`` and
                    ``op_arg_parallel_attr`` are read.

            Returns:
                The ``numpy()`` result of the blob registered under the
                consistent name.
            """
            # Set by the callback; remains None if the callback is never invoked.
            consistent_blob_name = None

            def BoxingToSingleDevice(builder):
                nonlocal consistent_blob_name

                # Target placement: device "0:0" with the source blob's device tag.
                single_device_conf = placement_pb.ParallelConf()
                single_device_conf.device_tag = blob_object.parallel_desc_symbol.device_tag
                single_device_conf.device_name.append("{}:{}".format(0, 0))
                single_device_symbol = builder.GetParallelDescSymbol(single_device_conf)

                # Same SBP / mirrored attributes as the source, on the new placement.
                target_attr = op_arg_util.OpArgParallelAttribute(
                    single_device_symbol,
                    blob_object.op_arg_parallel_attr.sbp_parallel,
                    blob_object.op_arg_parallel_attr.opt_mirrored_parallel,
                )

                placement_scope = oneflow.scope.placement(
                    self.parallel_conf.device_tag, list(self.parallel_conf.device_name)
                )
                with placement_scope:
                    boxed_blob_object = boxing_util.BoxingTo(
                        builder, blob_object, target_attr
                    )

                consistent_blob_name = "{}-consistent".format(self.logical_blob_name)
                # NOTE(review): if this name is already registered, the freshly
                # boxed object is discarded and the earlier registration is what
                # gets read back below. Confirm the registry is cleared elsewhere;
                # otherwise repeated fetches may return data from the first one.
                already_registered = blob_register.HasObject4BlobName(
                    consistent_blob_name
                )
                if not already_registered:
                    blob_register.SetObject4BlobName(
                        consistent_blob_name, boxed_blob_object
                    )

            vm_util.LogicalRun(BoxingToSingleDevice)
            fetched_blob = eager_blob_util.EagerPhysicalBlob(consistent_blob_name)
            return fetched_blob.numpy()