    def sharded_variable_creator(next_creator, **kwargs):
      # Split the (2,)-shaped initial value into two (1,)-shaped shards.
      v1_value = kwargs['initial_value']()[0:1]
      v2_value = kwargs['initial_value']()[1:]

      kwargs['initial_value'] = v1_value
      kwargs['shape'] = (1,)
      v1 = next_creator(**kwargs)

      kwargs['initial_value'] = v2_value
      kwargs['shape'] = (1,)
      v2 = next_creator(**kwargs)

      return sharded_variable.ShardedVariable([v1, v2])

    with variable_scope.variable_creator_scope(sharded_variable_creator):
      layer = Layer()

    self.assertLen(layer.trainable_weights, 2)
    self.assertEqual(layer.trainable_weights[0], [0])
    self.assertEqual(layer.trainable_weights[1], [1])
    self.assertLen(layer.non_trainable_weights, 2)
    self.assertEqual(layer.non_trainable_weights[0], [2])
    self.assertEqual(layer.non_trainable_weights[1], [3])
    self.assertAllEqual(layer.weights,
                        layer.trainable_weights + layer.non_trainable_weights)
    self.assertAllEqual(layer.trainable_weights, layer.trainable_variables)
    self.assertAllEqual(layer.weights, layer.variables)

    checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies)
    self.assertEqual(checkpoint_deps, set([layer.w, layer.b]))
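
For reference, here is a `Layer` definition consistent with the assertions above; the class body is not part of this snippet, so the shapes and initial values are inferred from the expected weights (and `base_layer` is assumed to be the test's Keras base-layer import):

class Layer(base_layer.Layer):

  def __init__(self):
    super(Layer, self).__init__()
    # Two (2,)-shaped weights, which the creator above splits into two
    # (1,)-shaped shards each: w -> [0], [1] (trainable); b -> [2], [3]
    # (non-trainable).
    self.w = self.add_weight(
        shape=(2,), initializer=lambda shape, dtype: [0, 1], trainable=True)
    self.b = self.add_weight(
        shape=(2,), initializer=lambda shape, dtype: [2, 3], trainable=False)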


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()
_component_api_helper.package_hook(
    parent_package_str=__name__,
    child_package_str=(
        'tensorflow_estimator.python.estimator.api._v2.estimator'))

if not hasattr(_current_module, 'estimator'):
  _component_api_helper.package_hook(
      parent_package_str=__name__,
      child_package_str=(
          'tensorflow_estimator.python.estimator.api.estimator'))
_component_api_helper.package_hook(
    parent_package_str=__name__,
    child_package_str=('tensorflow.python.keras.api._v2.keras'))

# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat  # pylint: disable=g-import-not-at-top
_compat.enable_v2_behavior()
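
As a standalone illustration (separate from this file), the user-visible effect of enabling v2 behavior is that eager execution becomes the default, so ops evaluate immediately:

import tensorflow as tf

tf.compat.v1.enable_v2_behavior()  # the same switch flipped above
assert tf.executing_eagerly()      # ops now run eagerly, no Session needed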


# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
# directories.
# TODO(gunan): Find a better location for this code snippet.
from tensorflow.python.framework import load_library as _ll
from tensorflow.python.lib.io import file_io as _fi

# Get site-packages directories for the python installation.
_site_packages_dirs = []
_site_packages_dirs += [_site.USER_SITE]
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
if 'getsitepackages' in dir(_site):
  _site_packages_dirs += _site.getsitepackages()
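
The snippet is cut off here; based on the comment above, the collected directories then feed a loop that loads any `tensorflow-plugins` libraries. A rough sketch (assuming `os` is imported as `_os` elsewhere in this file; exact details may differ):

for _s in _site_packages_dirs:
  _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
  if _fi.file_exists(_plugin_dir):
    _ll.load_library(_plugin_dir)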
Example n. 3
def setUp(self):
  super(TestDistributionStrategyDnnCorrectness, self).setUp()
  v2_compat.enable_v2_behavior()
  np.random.seed(_RANDOM_SEED)
  random_seed.set_random_seed(_RANDOM_SEED)

def setUp(self):
  super(GeneratorTest, self).setUp()
  v2_compat.enable_v2_behavior()
  config.set_soft_device_placement(False)
    def __call__(self, resources, test_env, proc_func, args, kwargs,
                 use_dill_for_args):
        """The wrapper function that actually gets run in child process(es)."""

        global _barrier

        self._resources = resources
        _barrier = self._resources.barrier
        proc_func = dill.loads(proc_func)
        if use_dill_for_args:
            args = dill.loads(args)
            kwargs = dill.loads(kwargs)

        if faulthandler is not None:
            faulthandler.enable()
            faulthandler.register(signal.SIGTERM, chain=True)

        # All logging should go to stderr to be streamed to the main process.
        logging.set_stderrthreshold(logging.DEBUG)

        # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so
        # print() and logging.*() write directly to `streaming_pipe_w`.
        # Unfortunately, since we cannot prepend task_type and task_id information
        # to the streamed logs, we need a thread per subprocess to distinguish
        # which subprocess each piece of a message came from.
        os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno())
        os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno())

        pid = os.getpid()
        logging.info('Subprocess with PID %d (%s, %d) is now being started.',
                     pid, test_env.task_type, test_env.task_id)

        # The thread will be dedicated to checking messages from the parent process.
        threading.Thread(  # pylint: disable=unexpected-keyword-arg
            target=self._message_checking_func,
            args=(test_env.task_type, test_env.task_id),
            daemon=True).start()

        if test_env.v2_enabled:
            v2_compat.enable_v2_behavior()

        try:
            with self._runtime_mode(test_env.executing_eagerly):
                return_value = proc_func(*args, **kwargs)
                is_successful = True
                exc_info = None

        except Exception:  # pylint: disable=broad-except
            # Capture all exceptions to be reported to parent process.
            return_value = None
            is_successful = False
            exc_info = sys.exc_info()

            # Re-raise the exception in addition to reporting it to the parent
            # process, so that even if the `--test_timeout` flag is set and the
            # error isn't surfaced in the parent process before bazel's timeout,
            # the log still shows what happened in this subprocess rather than
            # silently suppressing the error due to the early bazel timeout.
            # Raising the error in the subprocess produces a stack trace in the
            # log, but the program continues running.
            raise

        finally:
            info = _ProcessStatusInfo(task_type=test_env.task_type,
                                      is_successful=is_successful,
                                      exc_info=exc_info,
                                      return_value=return_value)
            self._resources.process_status_queue.put(info)
            self._close_streaming()

        # Exit with code 0, as this point is considered a successful exit.
        sys.exit(0)
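
Not shown in this example: the parent side serializes `proc_func` (and, when `use_dill_for_args` is set, its arguments) with `dill` before spawning, since the standard pickler rejects closures and lambdas. A hypothetical sketch of that step (`_serialize_payload` is illustrative, not the runner's actual helper):

import dill

def _serialize_payload(fn, args, kwargs, use_dill_for_args=True):
  # dill can pickle closures, lambdas and locally defined functions,
  # which plain pickle cannot.
  fn_blob = dill.dumps(fn)
  if use_dill_for_args:
    args, kwargs = dill.dumps(args), dill.dumps(kwargs)
  return fn_blob, args, kwargs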
Example n. 6
def setUp(self):
  super(AutoOutsideCompilationWithKerasTest, self).setUp()
  v2_compat.enable_v2_behavior()
  context.context().soft_device_placement = True
  self.summary_dir = self.get_temp_dir()
    def _proc_func_wrapper(self, proc_func, task_type, task_id,
                           per_process_cluster_spec, rpc_layer, pipe_w, *arg,
                           **kwargs):
        """The wrapper function that actually gets run in child process(es)."""

        pid = os.getpid()
        logging.info('Subprocess with PID %d is now being started.', pid)
        self._get_subprocess_info_queue().put(_SubprocessInfo(pid=pid))

        # Assign sys.stdout and sys.stderr as duplicates of `pipe_w` so print()
        # and logging.*() write directly to `pipe_w`. Unfortunately, since we
        # cannot prepend task_type and task_id information to the streamed logs,
        # we need a thread per subprocess to distinguish which subprocess each
        # piece of a message came from.
        os.dup2(pipe_w.fileno(), sys.stdout.fileno())
        os.dup2(pipe_w.fileno(), sys.stderr.fileno())

        # The thread will be dedicated to checking messages from the parent process.
        threading.Thread(  # pylint: disable=unexpected-keyword-arg
            target=self._message_checking_func,
            args=(task_type, task_id),
            daemon=True).start()

        os.environ['GRPC_FAIL_FAST'] = str(self._grpc_fail_fast)
        tf_config_dict = {
            'cluster': per_process_cluster_spec,
            'task': {
                'type': task_type,
                'index': task_id,
            },
        }
        if rpc_layer is not None:
            tf_config_dict['rpc_layer'] = rpc_layer
        os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)
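        # For illustration, with a hypothetical two-worker cluster the
        # serialized TF_CONFIG would look like:
        #   {"cluster": {"worker": ["localhost:12345", "localhost:23456"]},
        #    "task": {"type": "worker", "index": 0},
        #    "rpc_layer": "grpc"}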

        if self._v2_enabled:
            v2_compat.enable_v2_behavior()

        return_value = None

        try:
            with self._runtime_mode():
                return_value = proc_func(*arg, **kwargs)

        except Exception:  # pylint: disable=broad-except
            # Capture all exceptions to be reported to parent process.
            self._finish_process(
                _ProcessStatusInfo(task_type=task_type,
                                   is_successful=False,
                                   exc_info=sys.exc_info()), return_value)

            # Re-raise the exception in addition to reporting it to the parent
            # process, so that even if the `--test_timeout` flag is set and the
            # error isn't surfaced in the parent process before bazel's timeout,
            # the log still shows what happened in this subprocess rather than
            # silently suppressing the error due to the early bazel timeout.
            # Raising the error in the subprocess produces a stack trace in the
            # log, but the program continues running.
            raise

        self._finish_process(
            _ProcessStatusInfo(task_type=task_type,
                               is_successful=True,
                               exc_info=None), return_value)
Example n. 8
def setUp(self):
  super(GeneratorTest, self).setUp()
  v2_compat.enable_v2_behavior()
  def _proc_func_wrapper(self, proc_func, task_type, task_id,
                         per_process_cluster_spec, rpc_layer, *arg, **kwargs):
    """The wrapper function that actually gets run in child process(es)."""

    if self._capture_std_stream:
      # TODO(yuefengz): consider a lighter way of capturing std streams.
      stdout_collector = _LogCollector(sys.__stdout__)
      stderr_collector = _LogCollector(sys.__stderr__)
      sys.stdout = stdout_collector
      sys.stderr = stderr_collector
    else:
      stdout_collector = None
      stderr_collector = None

    # The thread will be dedicated to checking messages from the parent process.
    threading.Thread(
        target=self._message_checking_func,
        args=(task_type, task_id, stdout_collector, stderr_collector)).start()

    os.environ['GRPC_FAIL_FAST'] = str(self._grpc_fail_fast)
    tf_config_dict = {
        'cluster': per_process_cluster_spec,
        'task': {
            'type': task_type,
            'index': task_id,
        },
    }
    if rpc_layer is not None:
      tf_config_dict['rpc_layer'] = rpc_layer
    os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)

    if self._v2_enabled:
      v2_compat.enable_v2_behavior()

    return_value = None

    if self._max_run_time is not None:
      # Register a SIGALRM handler to exit the process once it has run for
      # `max_run_time` seconds. Reaching this limit doesn't necessarily
      # indicate an issue.
      def handler(signum, frame):
        del signum, frame
        self._finish_process(
            _ProcessStatusInfo(
                task_type=task_type, is_successful=True, exc_info=None), None,
            stdout_collector, stderr_collector)
        # `os._exit(0)` is used to more reliably terminate a subprocess.
        os._exit(0)  # pylint: disable=protected-access

      signal.signal(signal.SIGALRM, handler)
      signal.alarm(self._max_run_time)

    try:
      with self._runtime_mode():
        return_value = proc_func(*arg, **kwargs)
    except Exception:  # pylint: disable=broad-except
      # Capture all exceptions to be reported to parent process.
      self._finish_process(
          _ProcessStatusInfo(
              task_type=task_type, is_successful=False,
              exc_info=sys.exc_info()), return_value, stdout_collector,
          stderr_collector)

      # Re-raise the exception in addition to reporting it to the parent
      # process, so that even if the `--test_timeout` flag is set and the
      # error isn't surfaced in the parent process before bazel's timeout,
      # the log still shows what happened in this subprocess rather than
      # silently suppressing the error due to the early bazel timeout.
      # Raising the error in the subprocess produces a stack trace in the
      # log, but the program continues running.
      raise

    self._finish_process(
        _ProcessStatusInfo(
            task_type=task_type, is_successful=True, exc_info=None),
        return_value, stdout_collector, stderr_collector)
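
The `_ProcessStatusInfo` container used throughout these wrappers is not shown in the snippets. A plausible definition inferred from the call sites (only the first wrapper passes a `return_value` field, so it is given a default here):

import collections

_ProcessStatusInfo = collections.namedtuple(
    '_ProcessStatusInfo',
    ['task_type', 'is_successful', 'exc_info', 'return_value'],
    defaults=[None])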