import collections
import contextlib
import heapq
import re

import lingvo.compat as tf
from lingvo.core import hyperparams
from lingvo.core import nested_map
from lingvo.core import thread_local_utils
import numpy as np

# Immutable record describing the infeed host currently being processed:
# its index and the total number of infeed hosts.
InfeedContext = collections.namedtuple(
    'InfeedContext', ['infeed_host_index', 'num_infeed_hosts'])

# Stack of active InfeedContext records; the innermost scope is at the end.
# NOTE(review): presumably per-thread, judging by the helper's name -- confirm
# against thread_local_utils.ThreadLocalStack.
_INFEED_CONTEXT_STACK = thread_local_utils.ThreadLocalStack()


@contextlib.contextmanager
def InfeedContextScope(infeed_host_index, num_infeed_hosts):
  """Makes the given infeed host the current one for the `with` body.

  An InfeedContext built from the arguments is pushed on entry and popped on
  exit, including when the body raises.

  Args:
    infeed_host_index: Value stored as `InfeedContext.infeed_host_index`.
    num_infeed_hosts: Value stored as `InfeedContext.num_infeed_hosts`.

  Yields:
    None.
  """
  context = InfeedContext(infeed_host_index, num_infeed_hosts)
  _INFEED_CONTEXT_STACK.stack.append(context)
  try:
    yield
  finally:
    # Always unwind so a failing body cannot leave a stale context behind.
    _INFEED_CONTEXT_STACK.stack.pop()


def GetInfeedContext():
  """Returns the innermost active InfeedContext.

  Returns:
    The context pushed by the closest enclosing InfeedContextScope, or
    `InfeedContext(infeed_host_index=0, num_infeed_hosts=1)` when no scope is
    active.
  """
  active = _INFEED_CONTEXT_STACK.stack
  if active:
    return active[-1]
  return InfeedContext(infeed_host_index=0, num_infeed_hosts=1)
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for scatter updates."""

import contextlib

import lingvo.compat as tf
from lingvo.core import py_utils
from lingvo.core import thread_local_utils

# Stack of `inplace_update` settings; the innermost SetInplaceUpdate scope is
# at the end.
_global_inplace_update_stack = thread_local_utils.ThreadLocalStack()


@contextlib.contextmanager
def SetInplaceUpdate(inplace_update):
  """Overrides the in-place update setting for the duration of a `with` body.

  The value is pushed on entry and popped on exit, including when the body
  raises, so scopes nest and unwind correctly.

  Args:
    inplace_update: The setting UseInplaceUpdate() should report while this
      scope is the innermost one.

  Yields:
    None.
  """
  _global_inplace_update_stack.stack.append(inplace_update)
  try:
    yield
  finally:
    _global_inplace_update_stack.stack.pop()


def UseInplaceUpdate():
  """Returns the in-place update setting currently in effect.

  Returns:
    The value passed to the innermost active SetInplaceUpdate scope, or True
    when no scope is active.
  """
  if not _global_inplace_update_stack.stack:
    # TODO(rpang): set the default value to False in a follow-up CL.
    return True
  # Report the innermost override. Without this return the function would fall
  # through and yield None, so SetInplaceUpdate's value would never be seen.
  return _global_inplace_update_stack.stack[-1]
# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Specification of a training cluster.""" import heapq import lingvo.compat as tf from lingvo.core import hyperparams from lingvo.core import nested_map from lingvo.core import thread_local_utils import numpy as np _CLUSTER_STACK = thread_local_utils.ThreadLocalStack() class _Cluster: """The whole training cluster from a single task's point of view.""" @classmethod def _JobSpec(cls, replicas): """Construct a job spec param with the given number of replicas.""" p = hyperparams.Params() # By default, we use /job:localhost so that most of tests can just # work out of the box. trainer.py will then set job names accordingly. p.Define('name', '/job:localhost', 'TensorFlow job spec, e.g., /job:trainer, /job:ps') p.Define('replicas', replicas, 'The number of tasks of a job.') p.Define( 'targets', '', 'The target network address(es) to which we can '