def _WatcherHandler(handler_uuid, of_blob_ptr):
    """Forward a watched blob to the Python handler registered under ``handler_uuid``.

    The handler entry is looked up in the default session's
    ``uuid2watch_handler`` map as a ``(blob_watched, handler)`` pair. The raw
    blob behind ``of_blob_ptr`` is copied into an ndarray and wrapped in a
    ``LocalBlob``. The handler is then invoked with that blob after
    ``oft_util.TransformWatchedBlob`` has adapted it for the handler.

    Args:
        handler_uuid: key of the registered watch handler.
        of_blob_ptr: pointer/handle accepted by ``ofblob.OfBlob``.
    """
    watch_registry = session_ctx.GetDefaultSession().uuid2watch_handler
    assert handler_uuid in watch_registry
    watched_blob, user_handler = watch_registry[handler_uuid]
    assert callable(user_handler)

    array = ofblob.OfBlob(of_blob_ptr).CopyToNdarray()
    # Dynamic-ness is taken from the watched blob's metadata, not from the copy.
    wrapped = local_blob_util.LocalBlob(array, watched_blob.is_dynamic)
    user_handler(oft_util.TransformWatchedBlob(wrapped, user_handler))
def result(self):
    """Return the pulled mirrored blob as a ``LocalBlob``, building it on first use.

    The first call gathers ``result.numpy()`` from every entry in
    ``self.sub_pullers_``. When there is more than one array, they are
    concatenated along axis 0 and a warning is printed. The merged array is
    wrapped with the ``is_dynamic`` flag of ``self.mirrored_blob_`` and cached
    in ``self.local_mirrored_blob_``. Later calls return the cached object.
    """
    if self.local_mirrored_blob_ is None:
        arrays = [sub_puller.result.numpy() for sub_puller in self.sub_pullers_]
        # TODO(chengcheng): check list length = 1 in single client. fix after multi-client
        if len(arrays) > 1:
            print("WARNING: return tensor list will concat as axis = 0.")
            merged = np.concatenate(arrays, axis=0)
        else:
            merged = arrays[0]
        self.local_mirrored_blob_ = local_blob_util.LocalBlob(
            merged, self.mirrored_blob_.is_dynamic
        )
    return self.local_mirrored_blob_
def HandlerParallelIdAndLocalBlob(parallel_id, local_blob):
    """Collect one sub-blob per parallel id; once all have arrived, merge and dispatch.

    Each call records ``local_blob`` under ``parallel_id`` in
    ``parallel_id2consistent_local_blob``, which is captured from the
    enclosing scope. Until ``len_sub_remote_blobs`` entries are present the
    function returns early. When the last one arrives, the sub-blobs are
    ordered by parallel id and concatenated along axis 0 if there is more
    than one. The result is wrapped in a ``LocalBlob`` carrying
    ``blob_watched.is_dynamic`` and passed to ``handler`` via
    ``oft_util.TransformWatchedBlob``.

    Args:
        parallel_id: id of the part being delivered. Assumed to range over
            ``0 .. len_sub_remote_blobs - 1`` (TODO confirm against caller).
        local_blob: the sub-blob received for ``parallel_id``.
    """
    assert parallel_id not in parallel_id2consistent_local_blob
    parallel_id2consistent_local_blob[parallel_id] = local_blob
    if len(parallel_id2consistent_local_blob) != len_sub_remote_blobs:
        return
    # NOTE(review): the map is not cleared here, so a second delivery for the
    # same parallel_id would trip the assert above. Confirm the enclosing
    # scope rebuilds it for every round.

    # Order the parts by parallel id. Index with the loop variable ``i``:
    # indexing with the ``parallel_id`` argument (the old code) repeated the
    # last-arrived blob in every slot.
    local_numpy_list = [
        parallel_id2consistent_local_blob[i].numpy()
        for i in range(len_sub_remote_blobs)
    ]
    local_numpy = local_numpy_list[0]
    if len(local_numpy_list) > 1:
        print("WARNING: watch return tensor list will concat as axis = 0.")
        local_numpy = np.concatenate(local_numpy_list, axis=0)
    merged_blob = local_blob_util.LocalBlob(local_numpy, blob_watched.is_dynamic)
    handler(oft_util.TransformWatchedBlob(merged_blob, handler))
def PullCallback(of_blob):
    """Store the pulled blob on ``self.result_`` as a ``LocalBlob``, then call ``pull_cb``.

    ``self`` and ``pull_cb`` are free variables captured from the enclosing
    scope. The blob's contents are copied out via ``CopyToNdarray``, and the
    dynamic flag comes from ``self.consistent_blob_``.
    """
    pulled_array = of_blob.CopyToNdarray()
    is_dynamic = self.consistent_blob_.is_dynamic
    self.result_ = local_blob_util.LocalBlob(pulled_array, is_dynamic)
    # Signal completion only after the result has been stored.
    pull_cb()