Exemplo n.º 1
0
import collections
import contextlib
import heapq
import re
import lingvo.compat as tf
from lingvo.core import hyperparams
from lingvo.core import nested_map
from lingvo.core import thread_local_utils
import numpy as np

# Immutable record describing the infeed host currently being worked on:
#   infeed_host_index: index of that infeed host.
#   num_infeed_hosts: total number of infeed hosts.
InfeedContext = collections.namedtuple(
    'InfeedContext', ('infeed_host_index', 'num_infeed_hosts'))

# Stack of InfeedContext entries: InfeedContextScope pushes/pops them and
# GetInfeedContext reads the innermost one. Presumably one stack per thread,
# given thread_local_utils.ThreadLocalStack — confirm against that helper.
_INFEED_CONTEXT_STACK = thread_local_utils.ThreadLocalStack()


@contextlib.contextmanager
def InfeedContextScope(infeed_host_index, num_infeed_hosts):
    """Makes an InfeedContext current for the duration of a `with` block.

    Args:
      infeed_host_index: index of the infeed host being worked on.
      num_infeed_hosts: total number of infeed hosts.

    Yields:
      None. Inside the block, GetInfeedContext() returns the pushed context.
    """
    context = InfeedContext(infeed_host_index, num_infeed_hosts)
    _INFEED_CONTEXT_STACK.stack.append(context)
    try:
        yield
    finally:
        # Always pop, even if the body raised, so the stack stays balanced.
        _INFEED_CONTEXT_STACK.stack.pop()


def GetInfeedContext():
    """Returns the innermost active InfeedContext.

    Returns:
      The InfeedContext most recently pushed by InfeedContextScope and not yet
      popped, or InfeedContext(infeed_host_index=0, num_infeed_hosts=1) when no
      scope is active.
    """
    contexts = _INFEED_CONTEXT_STACK.stack
    if not contexts:
        # No enclosing scope: default to host 0 out of 1.
        return InfeedContext(infeed_host_index=0, num_infeed_hosts=1)
    return contexts[-1]
Exemplo n.º 2
0
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for scatter updates."""

import contextlib
import lingvo.compat as tf
from lingvo.core import py_utils
from lingvo.core import thread_local_utils

# Stack of in-place-update settings: SetInplaceUpdate pushes/pops entries and
# UseInplaceUpdate consults it. Presumably one stack per thread, given
# thread_local_utils.ThreadLocalStack — confirm against that helper.
_global_inplace_update_stack = thread_local_utils.ThreadLocalStack()


@contextlib.contextmanager
def SetInplaceUpdate(inplace_update):
    """Pushes an in-place update setting for the duration of a `with` block.

    Args:
      inplace_update: the setting to push onto the in-place update stack while
        the block is active.

    Yields:
      None.
    """
    holder = _global_inplace_update_stack
    holder.stack.append(inplace_update)
    try:
        yield
    finally:
        # Pop even if the body raised, so the setting never outlives the block.
        holder.stack.pop()


def UseInplaceUpdate():
    """Returns the current in-place update setting for scatter updates.

    Returns:
      The value passed to the innermost active SetInplaceUpdate scope, or True
      when no such scope is active.
    """
    if not _global_inplace_update_stack.stack:
        # TODO(rpang): set the default value to False in a follow-up CL.
        return True
    # The innermost active SetInplaceUpdate scope determines the setting.
    return _global_inplace_update_stack.stack[-1]
Exemplo n.º 3
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Specification of a training cluster."""

import heapq
import lingvo.compat as tf
from lingvo.core import hyperparams
from lingvo.core import nested_map
from lingvo.core import thread_local_utils
import numpy as np

# Stack holder whose users are not visible in this excerpt (presumably the
# cluster-scoping helpers built around _Cluster — TODO confirm). Likely one
# stack per thread, given thread_local_utils.ThreadLocalStack.
_CLUSTER_STACK = thread_local_utils.ThreadLocalStack()


class _Cluster:
    """The whole training cluster from a single task's point of view."""
    @classmethod
    def _JobSpec(cls, replicas):
        """Construct a job spec param with the given number of replicas."""
        p = hyperparams.Params()
        # By default, we use /job:localhost so that most of tests can just
        # work out of the box. trainer.py will then set job names accordingly.
        p.Define('name', '/job:localhost',
                 'TensorFlow job spec, e.g., /job:trainer, /job:ps')
        p.Define('replicas', replicas, 'The number of tasks of a job.')
        p.Define(
            'targets', '', 'The target network address(es) to which we can '