コード例 #1
0
ファイル: watcher.py プロジェクト: zhenlin-work/oneflow
def _WatcherHandler(handler_uuid, of_blob_ptr):
    """Dispatch a watched blob's data to the handler registered under ``handler_uuid``.

    The (watched blob, handler) pair is looked up in the default session's
    ``uuid2watch_handler`` mapping. The blob behind ``of_blob_ptr`` is copied
    into an ndarray and wrapped in a ``LocalBlob`` that carries the watched
    blob's ``is_dynamic`` flag. The handler is then called with the result of
    ``oft_util.TransformWatchedBlob`` on that local blob.

    Args:
        handler_uuid: key into ``uuid2watch_handler``; asserted to be present.
        of_blob_ptr: pointer passed to ``ofblob.OfBlob`` to access the data.
    """
    registry = session_ctx.GetDefaultSession().uuid2watch_handler
    assert handler_uuid in registry
    watched, callback = registry[handler_uuid]
    assert callable(callback)
    # Copy the blob contents out before wrapping them.
    copied = ofblob.OfBlob(of_blob_ptr).CopyToNdarray()
    wrapped = local_blob_util.LocalBlob(copied, watched.is_dynamic)
    callback(oft_util.TransformWatchedBlob(wrapped, callback))
コード例 #2
0
 def result(self):
     """Return the pulled value as a ``LocalBlob``, building and caching it on first use.

     If ``self.local_mirrored_blob_`` is already set, it is returned directly.
     Otherwise the ``result`` of each entry in ``self.sub_pullers_`` is
     converted with ``numpy()``:
       - a single piece is used as-is;
       - several pieces are concatenated along axis 0, with a printed warning.
     The array is wrapped in a ``LocalBlob`` that carries
     ``self.mirrored_blob_.is_dynamic``. It is stored on
     ``self.local_mirrored_blob_`` and returned.

     Raises:
         IndexError: if ``self.sub_pullers_`` is empty and nothing is cached yet.
     """
     if self.local_mirrored_blob_ is None:
         pieces = [puller.result.numpy() for puller in self.sub_pullers_]
         merged = pieces[0]
         # TODO(chengcheng): check list length = 1 in single client. fix after multi-client
         if len(pieces) > 1:
             print("WARNING: return tensor list will concat as axis = 0.")
             merged = np.concatenate(pieces, axis=0)
         self.local_mirrored_blob_ = local_blob_util.LocalBlob(
             merged, self.mirrored_blob_.is_dynamic)
     return self.local_mirrored_blob_
コード例 #3
0
 def HandlerParallelIdAndLocalBlob(parallel_id, local_blob):
     """Record one piece of a watched blob; call ``handler`` once every piece has arrived.

     ``local_blob`` is stored under ``parallel_id`` in the enclosing
     ``parallel_id2consistent_local_blob`` dict. Until that dict holds
     ``len_sub_remote_blobs`` entries, this returns without doing anything
     else.

     Once all pieces are present:
       - they are taken in ascending ``parallel_id`` order;
       - each is converted with ``numpy()``;
       - they are concatenated along axis 0 if there is more than one
         (a warning is printed);
       - the result is wrapped in a ``LocalBlob`` carrying
         ``blob_watched.is_dynamic`` and passed to ``handler`` via
         ``oft_util.TransformWatchedBlob``.

     Args:
         parallel_id: key for this piece; asserted not to be present yet.
             Presumably ints 0..len_sub_remote_blobs-1 — TODO confirm with caller.
         local_blob: piece object exposing ``numpy()``.
     """
     assert parallel_id not in parallel_id2consistent_local_blob
     parallel_id2consistent_local_blob[parallel_id] = local_blob
     if len(parallel_id2consistent_local_blob) != len_sub_remote_blobs:
         # Other pieces are still outstanding.
         return
     # NOTE(review): entries are never removed here. If this closure is
     # invoked for another round with the same ids, the assert above will
     # fire — confirm the enclosing scope creates a fresh dict per round.
     #
     # Index by each stored id (not the current ``parallel_id``) so every
     # piece contributes, in parallel-id order. Convert each piece exactly once.
     local_numpy_list = [
         parallel_id2consistent_local_blob[pid].numpy()
         for pid in sorted(parallel_id2consistent_local_blob)
     ]
     local_numpy = local_numpy_list[0]
     if len(local_numpy_list) > 1:
         print("WARNING: watch return tensor list will concat as axis = 0.")
         local_numpy = np.concatenate(local_numpy_list, axis=0)
     merged_blob = local_blob_util.LocalBlob(local_numpy,
                                             blob_watched.is_dynamic)
     handler(oft_util.TransformWatchedBlob(merged_blob, handler))
コード例 #4
0
 def PullCallback(of_blob):
     """Store the pulled blob's data on ``self.result_`` and then invoke ``pull_cb``.

     The data is copied from ``of_blob`` into an ndarray and wrapped in a
     ``LocalBlob`` that carries ``self.consistent_blob_.is_dynamic``.
     """
     pulled = of_blob.CopyToNdarray()
     dynamic = self.consistent_blob_.is_dynamic
     self.result_ = local_blob_util.LocalBlob(pulled, dynamic)
     # Notify the waiter only after the result has been stored.
     pull_cb()