Ejemplo n.º 1
0
def _FindOrCreateDelegateBlobObject(
    builder, Fetch, x_blob_object, op_arg_parallel_attr
):
    """Return a blob object that carries ``op_arg_parallel_attr``.

    If ``x_blob_object`` already has an equal ``op_arg_parallel_attr`` it is
    returned as-is. Otherwise the blob's cache (from
    ``blob_cache_util.FindOrCreateBlobCache``) is asked for a delegate blob
    object keyed by ``op_arg_parallel_attr``, with ``Fetch`` forwarded as the
    callback (presumably used to build the delegate on a cache miss).

    Args:
        builder: Not referenced in this function; kept for call compatibility.
        Fetch: Callback forwarded to ``GetCachedDelegateBlobObject``.
        x_blob_object: Source blob object exposing ``op_arg_parallel_attr``.
        op_arg_parallel_attr: The parallel attribute the caller requires.

    Returns:
        ``x_blob_object`` itself, or the delegate returned by the blob cache.
    """
    already_matches = x_blob_object.op_arg_parallel_attr == op_arg_parallel_attr
    if not already_matches:
        # Attributes differ: route through the per-blob cache for a delegate.
        cache = blob_cache_util.FindOrCreateBlobCache(x_blob_object)
        return cache.GetCachedDelegateBlobObject(op_arg_parallel_attr, Fetch)
    return x_blob_object
Ejemplo n.º 2
0
    def _NumpyMirroredList(self):
        """Fetch this mirrored blob as a list of per-physical-blob numpy values.

        Inside a ``vm_util.LogicalRun`` callback, ``self.blob_object`` is
        unpacked into physical blob objects. Each one is registered with
        ``blob_register`` under ``"<logical_blob_name>/<index>"`` and read back
        through ``eager_blob_util.EagerPhysicalBlob``. Tensor-list blobs use
        ``numpy_list()``; other blobs use ``numpy()``. The whole fetch is routed
        through the blob cache's ``GetCachedNumpyMirroredList``.
        """
        unpacked = []

        def _Unpack(builder):
            nonlocal unpacked
            unpacked = builder.UnpackLogicalBlobToPhysicalBlobs(self.blob_object)

        def _ReadOne(index, phy_obj):
            phy_name = "{}/{}".format(self.logical_blob_name, index)
            # EagerPhysicalBlob is built from the name alone, so the object is
            # registered under that name first (presumably how it is resolved).
            blob_register.SetObject4BlobName(phy_name, phy_obj)
            as_list = self.is_tensor_list
            phy_blob = eager_blob_util.EagerPhysicalBlob(phy_name)
            if as_list:
                return phy_blob.numpy_list()
            return phy_blob.numpy()

        def _Fetch(blob_object):
            # NOTE(review): the blob_object argument is not used; unpacking
            # always reads self.blob_object.
            vm_util.LogicalRun(_Unpack)
            values = []
            for index, phy_obj in enumerate(unpacked):
                values.append(_ReadOne(index, phy_obj))
            return values

        cache = blob_cache_util.FindOrCreateBlobCache(self.blob_object)
        return cache.GetCachedNumpyMirroredList(_Fetch)
Ejemplo n.º 3
0
    def _Numpy(self):
        """Fetch this blob as a single numpy array.

        Inside a ``vm_util.LogicalRun`` callback, the blob object is boxed onto
        a one-device placement (device name ``"0:0"`` with the blob's own
        device tag). It keeps the source's ``sbp_parallel`` and
        ``opt_mirrored_parallel``. The boxed object is registered under
        ``"<logical_blob_name>-consistent"`` unless that name is already
        registered, and the value is read via
        ``eager_blob_util.EagerPhysicalBlob(name).numpy()``. The fetch is
        routed through the blob cache's ``GetCachedNumpy``.

        Tensor-list blobs are rejected by an ``assert``.
        """
        assert self.is_tensor_list is not True

        def _FetchAsNumpy(blob_object):
            target_name = None

            def _BoxToSingleDevice(builder):
                nonlocal target_name
                single_conf = placement_pb.ParallelConf()
                single_conf.device_tag = blob_object.parallel_desc_symbol.device_tag
                single_conf.device_name.append("{}:{}".format(0, 0))
                single_desc = builder.GetParallelDescSymbol(single_conf)
                src_attr = blob_object.op_arg_parallel_attr
                single_attr = op_arg_util.OpArgParallelAttribute(
                    single_desc,
                    src_attr.sbp_parallel,
                    src_attr.opt_mirrored_parallel,
                )
                # The boxing runs under this blob's own placement scope
                # (self.parallel_conf), not the temporary single-device one.
                with oneflow.scope.placement(
                    self.parallel_conf.device_tag, list(self.parallel_conf.device_name)
                ):
                    boxed = boxing_util.BoxingTo(builder, blob_object, single_attr)
                target_name = "{}-consistent".format(self.logical_blob_name)
                # NOTE(review): an existing registration under this name is kept
                # and the freshly boxed object is dropped; confirm that a reused
                # logical_blob_name cannot yield a stale value here.
                if not blob_register.HasObject4BlobName(target_name):
                    blob_register.SetObject4BlobName(target_name, boxed)

            vm_util.LogicalRun(_BoxToSingleDevice)
            return eager_blob_util.EagerPhysicalBlob(target_name).numpy()

        cache = blob_cache_util.FindOrCreateBlobCache(self.blob_object)
        return cache.GetCachedNumpy(_FetchAsNumpy)
Ejemplo n.º 4
0
def _GetPhysicalBlobBodyCache(blob_object):
    """Return the body cache for ``blob_object``.

    Finds (or creates) the blob's cache via
    ``blob_cache_util.FindOrCreateBlobCache``, then calls its ``GetBodyCache``
    with ``_FetchPhysicalBlobBody`` as the fetch callback (presumably used
    only when no body is cached yet).
    """
    return blob_cache_util.FindOrCreateBlobCache(blob_object).GetBodyCache(
        _FetchPhysicalBlobBody
    )
Ejemplo n.º 5
0
def _GetPhysicalBlobHeaderCache(blob_object):
    """Return the header cache for ``blob_object``.

    Finds (or creates) the blob's cache via
    ``blob_cache_util.FindOrCreateBlobCache``, then calls its
    ``GetHeaderCache`` with ``_FetchBlobHeader`` as the fetch callback
    (presumably used only when no header is cached yet).
    """
    return blob_cache_util.FindOrCreateBlobCache(blob_object).GetHeaderCache(
        _FetchBlobHeader
    )