kwargs['shape'] = (1,) v1 = next_creator(**kwargs) kwargs['initial_value'] = v2_value kwargs['shape'] = (1,) v2 = next_creator(**kwargs) return sharded_variable.ShardedVariable([v1, v2]) with variable_scope.variable_creator_scope(sharded_variable_creator): layer = Layer() self.assertLen(layer.trainable_weights, 2) self.assertEqual(layer.trainable_weights[0], [0]) self.assertEqual(layer.trainable_weights[1], [1]) self.assertLen(layer.non_trainable_weights, 2) self.assertEqual(layer.non_trainable_weights[0], [2]) self.assertEqual(layer.non_trainable_weights[1], [3]) self.assertAllEqual(layer.weights, layer.trainable_weights + layer.non_trainable_weights) self.assertAllEqual(layer.trainable_weights, layer.trainable_variables) self.assertAllEqual(layer.weights, layer.variables) checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies) self.assertEqual(checkpoint_deps, set([layer.w, layer.b])) if __name__ == '__main__': v2_compat.enable_v2_behavior() test.main()
parent_package_str=__name__, child_package_str=( 'tensorflow_estimator.python.estimator.api._v2.estimator')) if not hasattr(_current_module, 'estimator'): _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=( 'tensorflow_estimator.python.estimator.api.estimator')) _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow.python.keras.api._v2.keras')) # Enable TF2 behaviors from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top _compat.enable_v2_behavior() # Load all plugin libraries from site-packages/tensorflow-plugins if we are # running under pip. # TODO(gunan): Enable setting an environment variable to define arbitrary plugin # directories. # TODO(gunan): Find a better location for this code snippet. from tensorflow.python.framework import load_library as _ll from tensorflow.python.lib.io import file_io as _fi # Get sitepackages directories for the python installation. _site_packages_dirs = [] _site_packages_dirs += [_site.USER_SITE] _site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p] if 'getsitepackages' in dir(_site):
def setUp(self):
    """Prepare each test: enable v2 behavior and seed the random generators.

    After the base-class `setUp`, this calls `v2_compat.enable_v2_behavior()`
    and passes the module-level `_RANDOM_SEED` to both `np.random.seed` and
    `random_seed.set_random_seed`.
    """
    super(TestDistributionStrategyDnnCorrectness, self).setUp()
    v2_compat.enable_v2_behavior()

    # Both seeding calls receive the same constant.
    seed = _RANDOM_SEED
    np.random.seed(seed)
    random_seed.set_random_seed(seed)
def setUp(self):
    """Prepare each test: enable v2 behavior and disable soft device placement.

    Runs the base-class `setUp`, then `v2_compat.enable_v2_behavior()`, then
    calls `config.set_soft_device_placement(False)`.
    """
    super(GeneratorTest, self).setUp()
    v2_compat.enable_v2_behavior()

    # Soft device placement is explicitly switched off for these tests.
    config.set_soft_device_placement(False)
def __call__(self, resources, test_env, proc_func, args, kwargs,
             use_dill_for_args):
    """The wrapper function that actually gets run in child process(es).

    Deserializes the target function (and optionally its arguments) with
    `dill`, redirects this process's stdout/stderr file descriptors to
    `resources.streaming_pipe_w`, starts a daemon thread running
    `self._message_checking_func`, optionally enables v2 behavior, and then
    runs `proc_func` inside `self._runtime_mode(...)`. The outcome is put on
    `resources.process_status_queue` as a `_ProcessStatusInfo`.

    Args:
      resources: Object providing `barrier`, `streaming_pipe_w` and
        `process_status_queue`; it is stored on `self._resources`.
      test_env: Object providing `task_type`, `task_id`, `v2_enabled` and
        `executing_eagerly`.
      proc_func: dill-serialized callable to run in this process.
      args: Positional arguments for `proc_func`; dill-serialized when
        `use_dill_for_args` is true.
      kwargs: Keyword arguments for `proc_func`; dill-serialized when
        `use_dill_for_args` is true.
      use_dill_for_args: Whether `args` and `kwargs` must be passed through
        `dill.loads` before use.

    Raises:
      Exception: Any exception raised by `proc_func` is re-raised after it
        has been reported on the status queue.
      SystemExit: With code 0 once `proc_func` has returned normally.
    """
    global _barrier

    self._resources = resources
    _barrier = self._resources.barrier
    proc_func = dill.loads(proc_func)
    if use_dill_for_args:
        args = dill.loads(args)
        kwargs = dill.loads(kwargs)

    if faulthandler is not None:
        faulthandler.enable()
        faulthandler.register(signal.SIGTERM, chain=True)

    # All logging should go to stderr to be streamed to the main process.
    logging.set_stderrthreshold(logging.DEBUG)

    # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so
    # print() and logging.*() write directly to `streaming_pipe_w`.
    # Unfortunately since we cannot prepend task_type and task_id information
    # to the streamed logs we will need a thread per subprocess to distinguish
    # where the piece of message is from.
    os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno())
    os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno())

    pid = os.getpid()
    logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid,
                 test_env.task_type, test_env.task_id)

    # The thread will be dedicated to checking messages from the parent
    # process.
    threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._message_checking_func,
        args=(test_env.task_type, test_env.task_id),
        daemon=True).start()

    if test_env.v2_enabled:
        v2_compat.enable_v2_behavior()

    # `info` stays None until the outcome is actually known. A BaseException
    # that is not an Exception (e.g. SystemExit or KeyboardInterrupt raised
    # inside `proc_func`) skips the `except` clause below; without this
    # sentinel the `finally` clause would touch unbound locals and raise
    # UnboundLocalError, hiding the real cause and skipping
    # `_close_streaming()`.
    info = None
    try:
        with self._runtime_mode(test_env.executing_eagerly):
            return_value = proc_func(*args, **kwargs)
            info = _ProcessStatusInfo(
                task_type=test_env.task_type,
                is_successful=True,
                exc_info=None,
                return_value=return_value)

    except Exception:  # pylint: disable=broad-except
        # Capture all exceptions to be reported to parent process.
        info = _ProcessStatusInfo(
            task_type=test_env.task_type,
            is_successful=False,
            exc_info=sys.exc_info(),
            return_value=None)

        # Re-raise the exception in addition to reporting it to the parent
        # process, so that even if `--test_timeout` flag is set and the
        # error doesn't make it to be shown in parent process before bazel's
        # timeout, the log would still show what happens in this subprocess,
        # instead of silently suppressing the error due to early bazel
        # timeout. Raising an error in the subprocess produces stack trace in
        # the log, but the program continues running.
        raise

    finally:
        if info is not None:
            self._resources.process_status_queue.put(info)
        self._close_streaming()

    # Exit with code 0 as it's considered successful exit at this point.
    sys.exit(0)
def setUp(self):
    """Prepare each test: v2 behavior, soft device placement, summary dir.

    Runs the base-class `setUp`, calls `v2_compat.enable_v2_behavior()`, sets
    `soft_device_placement` to True on the object returned by
    `context.context()`, and stores `self.get_temp_dir()` in
    `self.summary_dir`.
    """
    super(AutoOutsideCompilationWithKerasTest, self).setUp()
    v2_compat.enable_v2_behavior()

    ctx = context.context()
    ctx.soft_device_placement = True

    self.summary_dir = self.get_temp_dir()
def _proc_func_wrapper(self, proc_func, task_type, task_id,
                       per_process_cluster_spec, rpc_layer, pipe_w, *arg,
                       **kwargs):
    """The wrapper function that actually gets run in child process(es).

    Announces this process's PID on the subprocess-info queue, points the
    stdout/stderr file descriptors at `pipe_w`, starts the message-checking
    daemon thread, exports `GRPC_FAIL_FAST` and `TF_CONFIG`, optionally
    enables v2 behavior, and finally runs `proc_func(*arg, **kwargs)` inside
    `self._runtime_mode()`. The outcome is handed to `self._finish_process`.

    Args:
      proc_func: Callable to run in this process.
      task_type: Value written to `TF_CONFIG['task']['type']`.
      task_id: Value written to `TF_CONFIG['task']['index']`.
      per_process_cluster_spec: Value written to `TF_CONFIG['cluster']`.
      rpc_layer: If not None, written to `TF_CONFIG['rpc_layer']`.
      pipe_w: Object with a `fileno()` method; stdout and stderr are
        duplicated onto its descriptor.
      *arg: Positional arguments forwarded to `proc_func`.
      **kwargs: Keyword arguments forwarded to `proc_func`.

    Raises:
      Exception: Anything raised by `proc_func`, re-raised after reporting.
    """
    child_pid = os.getpid()
    logging.info('Subprocess with PID %d is now being started.', child_pid)
    self._get_subprocess_info_queue().put(_SubprocessInfo(pid=child_pid))

    # Make sys.stdout and sys.stderr duplicates of `pipe_w` so that print()
    # and logging.*() write straight into it. Because task_type and task_id
    # cannot be prepended to the streamed logs, a thread per subprocess is
    # needed to tell where each piece of output comes from.
    for std_stream in (sys.stdout, sys.stderr):
        os.dup2(pipe_w.fileno(), std_stream.fileno())

    # This thread is dedicated to checking messages from the parent process.
    message_checker = threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._message_checking_func,
        args=(task_type, task_id),
        daemon=True)
    message_checker.start()

    os.environ['GRPC_FAIL_FAST'] = str(self._grpc_fail_fast)

    task_entry = {'type': task_type, 'index': task_id}
    tf_config = {'cluster': per_process_cluster_spec, 'task': task_entry}
    if rpc_layer is not None:
        tf_config['rpc_layer'] = rpc_layer
    os.environ['TF_CONFIG'] = json.dumps(tf_config)

    if self._v2_enabled:
        v2_compat.enable_v2_behavior()

    return_value = None
    try:
        with self._runtime_mode():
            return_value = proc_func(*arg, **kwargs)
    except Exception:  # pylint: disable=broad-except
        # Every exception is captured so it can be reported to the parent.
        failure = _ProcessStatusInfo(
            task_type=task_type,
            is_successful=False,
            exc_info=sys.exc_info())
        self._finish_process(failure, return_value)

        # Re-raise in addition to reporting: if `--test_timeout` is set and
        # the error does not reach the parent before bazel's timeout, the log
        # still shows what happened in this subprocess instead of the error
        # being silently lost. Raising here produces a stack trace in the
        # log, but the program continues running.
        raise
    else:
        success = _ProcessStatusInfo(
            task_type=task_type, is_successful=True, exc_info=None)
        self._finish_process(success, return_value)
def setUp(self):
    """Prepare each test by running the base `setUp` and enabling v2 behavior."""
    super(GeneratorTest, self).setUp()

    # Done per test, after the base-class fixture has been set up.
    v2_compat.enable_v2_behavior()
def _proc_func_wrapper(self, proc_func, task_type, task_id,
                       per_process_cluster_spec, rpc_layer, *arg, **kwargs):
    """The wrapper function that actually gets run in child process(es).

    Optionally replaces `sys.stdout`/`sys.stderr` with `_LogCollector`
    objects, starts the message-checking thread, exports `GRPC_FAIL_FAST` and
    `TF_CONFIG`, optionally enables v2 behavior, optionally arms a SIGALRM
    timer, and then runs `proc_func(*arg, **kwargs)` inside
    `self._runtime_mode()`. The outcome is handed to `self._finish_process`.

    Args:
      proc_func: Callable to run in this process.
      task_type: Value written to `TF_CONFIG['task']['type']`.
      task_id: Value written to `TF_CONFIG['task']['index']`.
      per_process_cluster_spec: Value written to `TF_CONFIG['cluster']`.
      rpc_layer: If not None, written to `TF_CONFIG['rpc_layer']`.
      *arg: Positional arguments forwarded to `proc_func`.
      **kwargs: Keyword arguments forwarded to `proc_func`.

    Raises:
      Exception: Anything raised by `proc_func`, re-raised after reporting.
    """
    stdout_collector = None
    stderr_collector = None
    if self._capture_std_stream:
        # TODO(yuefengz): consider a lighter way of capturing std streams.
        stdout_collector = _LogCollector(sys.__stdout__)
        stderr_collector = _LogCollector(sys.__stderr__)
        sys.stdout = stdout_collector
        sys.stderr = stderr_collector

    # This thread is dedicated to checking messages from the parent process.
    message_checker = threading.Thread(
        target=self._message_checking_func,
        args=(task_type, task_id, stdout_collector, stderr_collector))
    message_checker.start()

    os.environ['GRPC_FAIL_FAST'] = str(self._grpc_fail_fast)

    task_entry = {'type': task_type, 'index': task_id}
    tf_config = {'cluster': per_process_cluster_spec, 'task': task_entry}
    if rpc_layer is not None:
        tf_config['rpc_layer'] = rpc_layer
    os.environ['TF_CONFIG'] = json.dumps(tf_config)

    if self._v2_enabled:
        v2_compat.enable_v2_behavior()

    return_value = None

    if self._max_run_time is not None:

        # SIGALRM handler that ends the process once `_max_run_time` seconds
        # have elapsed. Reaching the limit does not necessarily indicate an
        # issue, so the run is reported as successful.
        def _on_alarm(unused_signum, unused_frame):
            timed_out_ok = _ProcessStatusInfo(
                task_type=task_type, is_successful=True, exc_info=None)
            self._finish_process(timed_out_ok, None, stdout_collector,
                                 stderr_collector)
            # `os._exit(0)` is used to more reliably terminate a subprocess.
            os._exit(0)  # pylint: disable=protected-access

        signal.signal(signal.SIGALRM, _on_alarm)
        signal.alarm(self._max_run_time)

    try:
        with self._runtime_mode():
            return_value = proc_func(*arg, **kwargs)
    except Exception:  # pylint: disable=broad-except
        # Every exception is captured so it can be reported to the parent.
        failure = _ProcessStatusInfo(
            task_type=task_type,
            is_successful=False,
            exc_info=sys.exc_info())
        self._finish_process(failure, return_value, stdout_collector,
                             stderr_collector)

        # Re-raise in addition to reporting: if `--test_timeout` is set and
        # the error does not reach the parent before bazel's timeout, the log
        # still shows what happened in this subprocess instead of the error
        # being silently lost. Raising here produces a stack trace in the
        # log, but the program continues running.
        raise
    else:
        success = _ProcessStatusInfo(
            task_type=task_type, is_successful=True, exc_info=None)
        self._finish_process(success, return_value, stdout_collector,
                             stderr_collector)