def _MaybeConstructSharedModel(self, train_cfg):
  """Construct a single shared copy of the model if this is a MultiTaskModel.

  If the share_model_object parameter is set, for MultiTaskModels, we create a
  MultiTaskSubModel for each task, but construct the model only once.

  Args:
    train_cfg: The params for a SingleTaskModel or MultiTaskModel.

  Returns:
    A MultiTaskModel, if train_cfg is a MultiTaskModel params object.
  """
  if not issubclass(train_cfg.cls, base_model.MultiTaskModel):
    return None
  if not train_cfg.share_model_object:
    return None

  with self._cluster, tf.container(
      self._container_id), contextlib.ExitStack() as stack:
    if not py_utils.IsEagerMode():
      stack.enter_context(self._graph.as_default())
      stack.enter_context(tf.device(self._cluster.GetPlacer()))
    with py_utils.VariableStore(), py_utils.VariableRenameScope(
        self._variable_renaming_rules):
      py_utils.GetOrCreateGlobalStepVar()
      shared_model = train_cfg.Instantiate()

  return shared_model
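# --- Illustrative sketch, not part of the original source ---
# _MaybeConstructSharedModel above leans on contextlib.ExitStack to enter the
# graph/device scopes only when running in graph mode. The self-contained toy
# below shows the same conditional-context idiom; `make_scope` and
# `eager_mode` are hypothetical stand-ins, not Lingvo APIs.
import contextlib


@contextlib.contextmanager
def make_scope(name):
  """A toy context manager standing in for graph/device scopes."""
  print(f'enter {name}')
  try:
    yield name
  finally:
    print(f'exit {name}')


def build(eager_mode):
  """Enters extra scopes only in graph mode, mirroring the pattern above."""
  with contextlib.ExitStack() as stack:
    if not eager_mode:
      stack.enter_context(make_scope('graph'))
      stack.enter_context(make_scope('device'))
    # Whatever was conditionally entered is closed when the `with` exits.
    return 'model'


build(eager_mode=False)  # enters and exits both scopes
build(eager_mode=True)   # enters none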
def setUp(self):
  super().setUp()
  with contextlib.ExitStack() as stack:
    stack.enter_context(py_utils.VariableStore())
    self.addCleanup(stack.pop_all().close)
  # Ensure the global_step variable is created in the default graph.
  py_utils.GetOrCreateGlobalStepVar()
  cluster = cluster_factory.SetRequireSequentialInputOrder(True)
  cluster.params.in_unit_test = True
  cluster.__enter__()
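# --- Illustrative sketch, not part of the original source ---
# The setUp above uses the ExitStack.pop_all() idiom: contexts entered during
# setUp stay open for the whole test and are closed by addCleanup, even if
# setUp fails partway through. A self-contained version with a toy context
# manager (`toy_store` is hypothetical, standing in for VariableStore):
import contextlib
import unittest


@contextlib.contextmanager
def toy_store(log):
  log.append('open')
  try:
    yield
  finally:
    log.append('close')


class PopAllIdiomTest(unittest.TestCase):

  def setUp(self):
    super().setUp()
    self.log = []
    with contextlib.ExitStack() as stack:
      stack.enter_context(toy_store(self.log))
      # pop_all() transfers ownership, so leaving this `with` does NOT close
      # the store; addCleanup closes it after the test method finishes.
      self.addCleanup(stack.pop_all().close)

  def test_store_is_open_during_test(self):
    self.assertEqual(self.log, ['open'])


if __name__ == '__main__':
  unittest.main()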
def Wrapper(self, *args, **kwargs):
  """Decorator wrapper fn."""
  stack = _LAYER_STACK.stack
  with contextlib.ExitStack() as context_stack:
    if not stack:
      context_stack.enter_context(py_utils.VariableStore())
    if stack and stack[-1] is self:
      # Short circuit if called multiple times (eg. super() chain).
      func(self, *args, **kwargs)
      return
    # Push back self (the current layer) to the stack.
    stack_size = len(stack)
    stack.append(self)
    try:
      # Calls the layer's real __init__ method.
      # pylint: disable=protected-access
      with contextlib.ExitStack() as context_stack2:
        if args and IsLayerParams(args[0]):
          context_stack2.enter_context(
              self._SelfVariableScope(args[0], enter_name_scope=False))
        func(self, *args, **kwargs)
        self._CreateLayerVariables()
      self._disable_create_child = True
      self._VerifyChildren()
      self._VerifyVarsAndTheta()
      # pylint: enable=protected-access
      if len(stack) > 1:
        # Records the fact stack[-2] just created a sub-layer self.
        stack[-2]._AutoAddChild(self)  # pylint: disable=protected-access
    finally:
      # Pop out self (the current layer).
      assert stack[-1] is self
      stack.pop()
      assert len(stack) == stack_size
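# --- Illustrative sketch, not part of the original source ---
# The wrapper above maintains a per-thread stack of layers whose __init__ is
# currently running, so that when a child layer is constructed inside its
# parent's __init__, the parent (stack[-2]) can auto-register the child. A
# stripped-down, self-contained version of that pattern (`Layer` and
# `initializer` here are toy names, not the Lingvo implementation):
import functools
import threading

_STACK = threading.local()


def initializer(func):
  """Decorator for __init__ that tracks parent/child construction order."""

  @functools.wraps(func)
  def Wrapper(self, *args, **kwargs):
    stack = getattr(_STACK, 'stack', None)
    if stack is None:
      stack = _STACK.stack = []
    if stack and stack[-1] is self:
      # Short circuit re-entrant calls (e.g. a super().__init__ chain).
      return func(self, *args, **kwargs)
    stack.append(self)
    try:
      func(self, *args, **kwargs)
      if len(stack) > 1:
        stack[-2].children.append(self)  # the parent records the new child
    finally:
      assert stack.pop() is self

  return Wrapper


class Layer:

  @initializer
  def __init__(self, name, child_names=()):
    self.name = name
    self.children = []
    for child_name in child_names:
      Layer(child_name)  # auto-added to self.children by the decorator


root = Layer('root', child_names=('a', 'b'))
assert [child.name for child in root.children] == ['a', 'b']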
def __init__(self, train_cfg, ps_params_dict, *args, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams.
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params.
      If train_cfg is a SingleTaskModelParams, we expect only one entry.
    *args: List args to pass through to BaseRunner.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  if py_utils.IsEagerMode():
    assert tf.executing_eagerly()
    tf.logging.info(f'FLAGS.tf_master: {FLAGS.tf_master}')

    # Connect to the TPU runtime.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
    tf.config.experimental_connect_to_cluster(resolver)

  super().__init__(train_cfg, *args, **kwargs)

  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None
  self._checkpoint_dir = os.path.join(self._logdir, 'train')

  self._variable_renaming_rules = []

  self._ml_perf = None

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')
    if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
      self._variable_renaming_rules = train_cfg.variable_renaming_rules

    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          list(train_cfg.task_probs.IterParams()))
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)

  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                   'trainer_params.txt')
  self._WriteToLog(
      text_format.MessageToString(train_cfg.ToProto(), as_utf8=True),
      self._checkpoint_dir, 'trainer_params.pbtxt')
  if self._ml_perf is not None:
    self._ml_perf_log = True
    mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
  else:
    self._ml_perf_log = False

  train_cfg = self.params

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit(job=None):
    """Wait until the model is ready."""
    try:
      if py_utils.IsEagerMode():
        topology = tf.tpu.experimental.initialize_tpu_system(resolver)
      else:
        # tpu.initialize_system() is called with None as embedding_config, as
        # embedding_config is not available yet. Later in _Loop, it is called
        # with the correct embedding_config. Since it cannot be called twice
        # in the same graph with different embedding_config, we use a
        # dummy_graph here.
        dummy_graph = tf.Graph()
        with dummy_graph.as_default():
          tpu_initialize_system_op = tf.tpu.initialize_system(
              embedding_config=None, job=job)

        with self._GetSession(graph=dummy_graph) as sess:
          topology = sess.run(tpu_initialize_system_op)

      if train_cfg.train.tpu_computation_shape is None:
        computation_shape = py_utils.ComputationShape(
            num_devices_per_split, topology)
      else:
        computation_shape = train_cfg.train.tpu_computation_shape
      assert num_devices_per_split == np.prod(computation_shape)

      if train_cfg.train.tpu_device_order_mode is None:
        self.device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=computation_shape,
            num_replicas=data_parallelism)
      else:
        self.device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=computation_shape,
            num_replicas=data_parallelism,
            device_order_mode=train_cfg.train.tpu_device_order_mode)
      py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
      tf.logging.info('device_assignment.core_assignment: %s',
                      str(self.device_assignment.core_assignment))
      tf.logging.info('device_assignment.topology.device_coordinates: %s',
                      str(self.device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  if self._ml_perf_log:
    mlp_log.mlperf_print(key='init_start', value=None)

  if len(self._cluster.all_worker_names) > 1:
    for worker in self._cluster.all_worker_names:
      _WaitTillInit(worker)
  else:
    _WaitTillInit(None)

  shared_model = self._MaybeConstructSharedModel(train_cfg)

  self._program_schedule_dict = {}
  self._programs = []
  self._ckpt_programs = []

  self._checkpoint_to_load = None
  with self._cluster:
    # Create the ExponentialMovingAverage singleton shared by all programs,
    # if applicable.
    ema = py_utils.CreateEMAForModel(train_cfg, self._global_step_var)
    for task_string, program_schedule_params in ps_params_dict.items():
      program_schedule_params.logdir = self._logdir
      program_schedule_params.num_splits_per_client = data_parallelism
      program_schedule_params.task_name = task_string
      # If the model was created above, we'll inject it here as a
      # shared_model.
      ps = program_schedule_params.Instantiate(
          shared_model=shared_model,
          trial=self._trial,
          ema=ema,
          tf_master=self._tf_master)
      self._program_schedule_dict[task_string] = ps
      tf.logging.info('program_schedule_params: %s',
                      program_schedule_params.ToText())
      self._programs += ps.Programs()

      if ps.train_program:
        self._ckpt_programs.append(ps.train_program)
      else:
        self._ckpt_programs += ps.Programs()

      if program_schedule_params.ml_perf.benchmark_name is not None:
        self._ml_perf = program_schedule_params.ml_perf
      if ('checkpoint_to_load' in program_schedule_params and
          program_schedule_params.checkpoint_to_load):
        if (self._checkpoint_to_load and
            (self._checkpoint_to_load !=
             program_schedule_params.checkpoint_to_load)):
          raise ValueError(f'Multiple values found for checkpoint_to_load: '
                           f'{self._checkpoint_to_load}, '
                           f'{program_schedule_params.checkpoint_to_load}.')
        self._checkpoint_to_load = program_schedule_params.checkpoint_to_load

  tf.logging.info('num_programs: %d', len(self._programs))

  # When running in a vizier trainer, the executor reports infeasible runs
  # in case of errors. The programs report metrics and normal completions.
  for program in self._programs:
    if program._should_report_metrics:
      self._should_report_metrics = True

  with self._cluster, tf.container(
      self._container_id), contextlib.ExitStack() as stack:
    if not py_utils.IsEagerMode():
      stack.enter_context(self._graph.as_default())

      if FLAGS.use_tpu_mirrored_vars:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
        self._tpu_strategy = tf.distribute.experimental.TPUStrategy(
            resolver, device_assignment=self.device_assignment)
        stack.enter_context(self._tpu_strategy.scope())
        stack.enter_context(
            tpu_strategy._TPUReplicaContext(self._tpu_strategy))
      else:
        stack.enter_context(tf.device(self._cluster.GetPlacer()))

    if FLAGS.pdb_on_exception:
      stack.enter_context(pdb_wrapper.catch_post_mortem())
    with py_utils.VariableStore(), py_utils.VariableRenameScope(
        self._variable_renaming_rules):
      # `BuildTpuSubgraph` has to be called before checkpoint restore, so that
      # the optimizer slot variables are guaranteed to be initialized before
      # they get loaded. Otherwise, the optimizers' slot variables will not
      # be properly loaded when V1 checkpoint is used.
      for program in self._programs:
        program.BuildTpuSubgraph()
        py_utils.ClearTpuSummaryTensors()

    if not py_utils.IsEagerMode():
      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()
      self._initialize_global_vars = tf.global_variables_initializer()

    checkpointer_models = [
        program.GetModel() for program in self._ckpt_programs
    ]

    if py_utils.IsEagerMode():
      if FLAGS.use_v2_checkpoints_in_eager:
        self._checkpointer = checkpointer.EagerCheckpointerV2(
            self._checkpoint_dir,
            models=checkpointer_models,
            init_op=None,
            train_params=train_cfg.train,
            save_only=False)
      else:
        self._checkpointer = checkpointer.EagerCheckpointerV1(
            self._checkpoint_dir,
            models=checkpointer_models,
            init_op=None,
            train_params=train_cfg.train,
            save_only=False)
    else:
      self._checkpointer = checkpointer.Checkpointer(
          self._checkpoint_dir,
          models=checkpointer_models,
          init_op=self._initialize_global_vars,
          train_params=train_cfg.train,
          save_only=False)

    for program in self._programs:
      program.SetStatusMessageFn(self._SetStatusMessage)

    tpu_embedding_collection = (
        tpu_embedding_layers.TpuEmbeddingCollection.Get())
    self._load_ops = tpu_embedding_collection.load_ops
    self._retrieve_ops = tpu_embedding_collection.retrieve_ops
    self._tpu_embedding = tpu_embedding_collection.tpu_embedding
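# --- Illustrative sketch, not part of the original source ---
# __init__ above insists that when several program schedules set
# `checkpoint_to_load`, they all agree on a single path. The self-contained
# helper below mirrors that consistency check with a plain dict of
# hypothetical schedule names to optional checkpoint paths:
def ResolveCheckpointToLoad(checkpoints_by_schedule):
  """Returns the single agreed-upon checkpoint path, or None if unset."""
  checkpoint_to_load = None
  for name, path in checkpoints_by_schedule.items():
    if not path:
      continue
    if checkpoint_to_load and checkpoint_to_load != path:
      raise ValueError(f'Multiple values found for checkpoint_to_load: '
                       f'{checkpoint_to_load}, {path} (schedule {name}).')
    checkpoint_to_load = path
  return checkpoint_to_load


assert ResolveCheckpointToLoad({'train': None, 'eval': None}) is None
assert ResolveCheckpointToLoad({
    'train': '/ckpt/model-100',
    'eval': '/ckpt/model-100',
}) == '/ckpt/model-100'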