def _FindOrCreateDelegateBlobObject(
    builder, Fetch, x_blob_object, op_arg_parallel_attr
):
    """Return a blob object carrying ``op_arg_parallel_attr``.

    If ``x_blob_object`` already has the requested parallel attribute it is
    returned unchanged.  Otherwise the blob cache associated with
    ``x_blob_object`` is asked for a delegate blob object with that attribute,
    passing ``Fetch`` through to ``GetCachedDelegateBlobObject``.

    Args:
        builder: Not used by this function; kept for interface compatibility.
        Fetch: Callback forwarded to the blob cache.
        x_blob_object: Source blob object.
        op_arg_parallel_attr: The desired parallel attribute.
    """
    wanted_attr = op_arg_parallel_attr
    # Fast path: no delegate is needed when the attributes already match.
    if x_blob_object.op_arg_parallel_attr == wanted_attr:
        return x_blob_object
    cache = blob_cache_util.FindOrCreateBlobCache(x_blob_object)
    return cache.GetCachedDelegateBlobObject(wanted_attr, Fetch)
def _NumpyMirroredList(self):
    """Return the per-physical-blob numpy values of this mirrored blob.

    The fetch callback handed to the blob cache unpacks ``self.blob_object``
    into physical blob objects inside ``vm_util.LogicalRun``.  It registers
    each one under the name ``"<logical_blob_name>/<index>"`` and reads it
    back through ``eager_blob_util.EagerPhysicalBlob``.  It uses
    ``numpy_list()`` when ``self.is_tensor_list`` is truthy and ``numpy()``
    otherwise.  The resulting list is obtained via the blob cache's
    ``GetCachedNumpyMirroredList``.
    """
    # Holder for the unpack result; filled by the builder callback below.
    unpacked = {}

    def _UnpackToPhysical(builder):
        unpacked["blobs"] = builder.UnpackLogicalBlobToPhysicalBlobs(
            self.blob_object
        )

    def _PhysicalToNumpy(index, phy_blob_object):
        blob_name = "{}/{}".format(self.logical_blob_name, index)
        blob_register.SetObject4BlobName(blob_name, phy_blob_object)
        # Read the flag before constructing the blob, matching the original
        # conditional-expression evaluation order.
        as_tensor_list = self.is_tensor_list
        physical_blob = eager_blob_util.EagerPhysicalBlob(blob_name)
        if as_tensor_list:
            return physical_blob.numpy_list()
        return physical_blob.numpy()

    def _FetchMirroredList(blob_object):
        # ``blob_object`` is ignored; the closure works on ``self.blob_object``.
        vm_util.LogicalRun(_UnpackToPhysical)
        phy_blobs = unpacked.get("blobs", [])
        return [_PhysicalToNumpy(idx, phy) for idx, phy in enumerate(phy_blobs)]

    cache = blob_cache_util.FindOrCreateBlobCache(self.blob_object)
    return cache.GetCachedNumpyMirroredList(_FetchMirroredList)
def _Numpy(self):
    """Return this (non-tensor-list) blob's value as a numpy array.

    The fetch callback boxes the blob onto a single device (``"0:0"``, same
    device tag as the source blob's parallel desc) inside a LogicalRun.  It
    then registers the boxed blob under ``"<logical_blob_name>-consistent"``
    unless that name is already registered, and reads it back with
    ``EagerPhysicalBlob(...).numpy()``.  The value is obtained through the
    blob cache's ``GetCachedNumpy``.

    Raises:
        AssertionError: if ``self.is_tensor_list`` is ``True``.
    """
    assert self.is_tensor_list is not True

    def _FetchConsistentNumpy(blob_object):
        # Holder for the registered name; set inside the builder callback.
        holder = {"name": None}

        def _BoxOntoOneDevice(builder):
            one_device_conf = placement_pb.ParallelConf()
            one_device_conf.device_tag = blob_object.parallel_desc_symbol.device_tag
            one_device_conf.device_name.append("{}:{}".format(0, 0))
            one_device_symbol = builder.GetParallelDescSymbol(one_device_conf)
            target_attr = op_arg_util.OpArgParallelAttribute(
                one_device_symbol,
                blob_object.op_arg_parallel_attr.sbp_parallel,
                blob_object.op_arg_parallel_attr.opt_mirrored_parallel,
            )
            scope = oneflow.scope.placement(
                self.parallel_conf.device_tag, list(self.parallel_conf.device_name)
            )
            with scope:
                boxed_blob_object = boxing_util.BoxingTo(
                    builder, blob_object, target_attr
                )
            registered_name = "{}-consistent".format(self.logical_blob_name)
            holder["name"] = registered_name
            # NOTE(review): if the name is already registered, the freshly
            # boxed object is discarded and the older registration is what
            # gets read below -- confirm this reuse is intended.
            if not blob_register.HasObject4BlobName(registered_name):
                blob_register.SetObject4BlobName(registered_name, boxed_blob_object)

        vm_util.LogicalRun(_BoxOntoOneDevice)
        return eager_blob_util.EagerPhysicalBlob(holder["name"]).numpy()

    cache = blob_cache_util.FindOrCreateBlobCache(self.blob_object)
    return cache.GetCachedNumpy(_FetchConsistentNumpy)
def _GetPhysicalBlobBodyCache(blob_object):
    """Return the body-cache entry for ``blob_object``.

    Looks up (or creates) the blob cache for ``blob_object`` and calls its
    ``GetBodyCache`` with ``_FetchPhysicalBlobBody`` as the fetch callback,
    presumably used on a cache miss.
    """
    return blob_cache_util.FindOrCreateBlobCache(blob_object).GetBodyCache(
        _FetchPhysicalBlobBody
    )
def _GetPhysicalBlobHeaderCache(blob_object):
    """Return the header-cache entry for ``blob_object``.

    Looks up (or creates) the blob cache for ``blob_object`` and calls its
    ``GetHeaderCache`` with ``_FetchBlobHeader`` as the fetch callback,
    presumably used on a cache miss.
    """
    return blob_cache_util.FindOrCreateBlobCache(blob_object).GetHeaderCache(
        _FetchBlobHeader
    )