Esempio n. 1
0
    def _MaybeConstructSharedModel(self, train_cfg):
        """Returns a single shared model instance for multi-task training, if any.

        When train_cfg describes a MultiTaskModel and its share_model_object
        parameter is set, the model is instantiated exactly once here, so that
        the MultiTaskSubModel created for each task can reuse it instead of each
        task building its own copy.

        Args:
          train_cfg: The params for a SingleTaskModel or MultiTaskModel.

        Returns:
          The instantiated model when train_cfg.cls is a MultiTaskModel subclass
          and train_cfg.share_model_object is truthy; None otherwise.
        """
        wants_shared_model = (
            issubclass(train_cfg.cls, base_model.MultiTaskModel) and
            train_cfg.share_model_object)
        if not wants_shared_model:
            return None

        with contextlib.ExitStack() as scopes:
            # Entered in the same order as a nested `with` statement would enter
            # them; ExitStack unwinds them in reverse order.
            scopes.enter_context(self._cluster)
            scopes.enter_context(tf.container(self._container_id))
            if not py_utils.IsEagerMode():
                # Graph mode: build into this runner's graph, on the devices
                # chosen by the cluster's placer.
                scopes.enter_context(self._graph.as_default())
                scopes.enter_context(tf.device(self._cluster.GetPlacer()))
            scopes.enter_context(py_utils.VariableStore())
            scopes.enter_context(
                py_utils.VariableRenameScope(self._variable_renaming_rules))
            py_utils.GetOrCreateGlobalStepVar()
            return train_cfg.Instantiate()
Esempio n. 2
0
 def setUp(self):
     """Sets up per-test TF state: variable store, global step and cluster.

     The VariableStore and the test cluster are entered here and stay active
     for the whole test. Both are exited through addCleanup (in reverse order
     of entry) once the test finishes.
     """
     super().setUp()
     with contextlib.ExitStack() as stack:
         stack.enter_context(py_utils.VariableStore())
         # pop_all() moves the open context onto the test's cleanup list, so it
         # stays entered after this `with` block ends.
         self.addCleanup(stack.pop_all().close)
     # Ensure the global_step variable is created in the default graph.
     py_utils.GetOrCreateGlobalStepVar()
     cluster = cluster_factory.SetRequireSequentialInputOrder(True)
     cluster.params.in_unit_test = True
     # Enter the cluster for the duration of the test and register the matching
     # exit, so every cluster __enter__() is paired with an __exit__().
     with contextlib.ExitStack() as stack:
         stack.enter_context(cluster)
         self.addCleanup(stack.pop_all().close)
Esempio n. 3
0
    def Wrapper(self, *args, **kwargs):
        """Decorator wrapper fn: runs the wrapped layer __init__ with bookkeeping.

        Behavior visible in this wrapper:
          * If no layer is currently on the layer stack, a
            py_utils.VariableStore() is entered for the duration of this call.
          * If `self` is already on top of the stack (e.g. a super().__init__
            chain), `func` is simply called and nothing else is done.
          * Otherwise `self` is pushed onto the stack, and `func` runs (inside
            self._SelfVariableScope when args[0] is layer params). Then
            self._CreateLayerVariables() runs, children and vars/theta are
            verified, and, if an enclosing layer exists, `self` is registered as
            its child via _AutoAddChild. `self` is always popped afterwards.

        Args:
          self: The layer object being constructed.
          *args: Positional args forwarded to `func`; args[0] is checked with
            IsLayerParams.
          **kwargs: Keyword args forwarded to `func`.

        Returns:
          None; the return value of `func` is discarded.
        """
        stack = _LAYER_STACK.stack

        with contextlib.ExitStack() as context_stack:
            # Only the outermost layer construction opens a VariableStore.
            if not stack:
                context_stack.enter_context(py_utils.VariableStore())

            if stack and stack[-1] is self:
                # Short circuit if called multiple times (eg. super() chain).
                func(self, *args, **kwargs)
                return

            # Push back self (the current layer) to the stack.
            # Remember the depth so the finally block can check it is restored.
            stack_size = len(stack)
            stack.append(self)
            try:
                # Calls the layer's real __init__ method.
                # pylint: disable=protected-access
                with contextlib.ExitStack() as context_stack2:
                    # The layer's own variable scope is entered only when the
                    # first positional arg is layer params.
                    if args and IsLayerParams(args[0]):
                        context_stack2.enter_context(
                            self._SelfVariableScope(args[0],
                                                    enter_name_scope=False))
                    func(self, *args, **kwargs)
                    self._CreateLayerVariables()
                # NOTE(review): presumably blocks creating further children once
                # __init__ has finished -- confirm against the layer base class.
                self._disable_create_child = True
                self._VerifyChildren()
                self._VerifyVarsAndTheta()
                # pylint: enable=protected-access
                if len(stack) > 1:
                    # Records the fact stack[-2] just created a sub-layer self.
                    stack[-2]._AutoAddChild(self)  # pylint: disable=protected-access
            finally:
                # Pop out self (the current layer).
                # NOTE(review): if either assert fails while an exception from the
                # try body is propagating, the AssertionError replaces it.
                assert stack[-1] is self
                stack.pop()
                assert len(stack) == stack_size
Esempio n. 4
0
    def __init__(self, train_cfg, ps_params_dict, *args, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

        Connects to the TPU runtime (eager mode), initializes the TPU system and
        the device assignment, instantiates one ProgramSchedule per entry of
        ps_params_dict, builds each program's TPU subgraph and creates the
        checkpointer.

        Args:
          train_cfg: SingleTaskModelParams or MultiTaskModelParams
          ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
            if train_cfg is a SingleTaskModelParams, we expect only one entry.
          *args: List args to pass through to BaseRunner.
          **kwargs: keyword args to pass through to BaseRunner.

        Raises:
          ValueError: if two program schedules specify different non-empty
            checkpoint_to_load values.
        """
        if py_utils.IsEagerMode():
            assert tf.executing_eagerly()
            tf.logging.info(f'FLAGS.tf_master: {FLAGS.tf_master}')

            # Connect to the TPU runtime.
            # The slice drops a leading '/job:' (assumed prefix) from the flag.
            # NOTE: `resolver` is only bound in eager mode; _WaitTillInit below
            # only reads it on its eager-mode path.
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
            tf.config.experimental_connect_to_cluster(resolver)

        super().__init__(train_cfg, *args, **kwargs)

        data_parallelism = self._cluster.num_splits_per_client
        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(self._logdir, 'train')

        # Only populated for RegExSharedVariableModel multi-task models below.
        self._variable_renaming_rules = []

        # Assigned later from a program schedule's ml_perf params, if any.
        self._ml_perf = None

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            # Without an explicit schedule, default to a ConstantScheduler over
            # the model's task_probs.
            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params(
                )
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        # Persist the trainer params both as text and as a text-format proto.
        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'trainer_params.txt')
        self._WriteToLog(
            text_format.MessageToString(train_cfg.ToProto(), as_utf8=True),
            self._checkpoint_dir, 'trainer_params.pbtxt')
        # NOTE(review): self._ml_perf was set to None above and is only assigned
        # later in the program-schedule loop, so this branch looks unreachable
        # and _ml_perf_log is always False at this point -- confirm intended.
        if self._ml_perf is not None:
            self._ml_perf_log = True
            mlp_log.mlperf_print(key='benchmark',
                                 value=self._ml_perf.benchmark_name)
        else:
            self._ml_perf_log = False

        # From here on, train_cfg refers to self.params (presumably set by
        # BaseRunner from the original train_cfg -- confirm).
        train_cfg = self.params

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit(job=None):
            """Initializes the TPU system for `job` and sets the device assignment.

            Retried on transient TF errors by the decorator; such errors are
            logged here and re-raised.
            """
            try:
                if py_utils.IsEagerMode():
                    topology = tf.tpu.experimental.initialize_tpu_system(
                        resolver)
                else:
                    # tpu.initialize_system() is called with None as embedding_config, as
                    # embedding_config is not available yet. Later in _Loop, it is called
                    # with the correct embedding_config. Since it cannot be called twice
                    # in the same graph with different embedding_config, we use a
                    # dummy_graph here.
                    dummy_graph = tf.Graph()
                    with dummy_graph.as_default():
                        tpu_initialize_system_op = tf.tpu.initialize_system(
                            embedding_config=None, job=job)

                    with self._GetSession(graph=dummy_graph) as sess:
                        topology = sess.run(tpu_initialize_system_op)

                # Derive the per-replica computation shape from the topology
                # unless it is given explicitly in the train params.
                if train_cfg.train.tpu_computation_shape is None:
                    computation_shape = py_utils.ComputationShape(
                        num_devices_per_split, topology)
                else:
                    computation_shape = train_cfg.train.tpu_computation_shape
                    assert num_devices_per_split == np.prod(computation_shape)

                # device_order_mode is only passed when configured.
                if train_cfg.train.tpu_device_order_mode is None:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism)
                else:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism,
                        device_order_mode=train_cfg.train.tpu_device_order_mode
                    )
                py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
                tf.logging.info('device_assignment.core_assignment: %s',
                                str(self.device_assignment.core_assignment))
                tf.logging.info(
                    'device_assignment.topology.device_coordinates: %s',
                    str(self.device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        if self._ml_perf_log:
            mlp_log.mlperf_print(key='init_start', value=None)
        # With several workers, initialize the TPU system once per worker name.
        if len(self._cluster.all_worker_names) > 1:
            for worker in self._cluster.all_worker_names:
                _WaitTillInit(worker)
        else:
            _WaitTillInit(None)

        # None unless train_cfg is a MultiTaskModel with share_model_object set.
        shared_model = self._MaybeConstructSharedModel(train_cfg)

        self._program_schedule_dict = {}
        self._programs = []
        # Programs whose models are handed to the checkpointer below.
        self._ckpt_programs = []

        self._checkpoint_to_load = None
        with self._cluster:
            # Create the ExponentialMovingAverage singleton shared by all programs, if
            # applicable.
            ema = py_utils.CreateEMAForModel(train_cfg, self._global_step_var)
            for task_string, program_schedule_params in ps_params_dict.items():
                program_schedule_params.logdir = self._logdir
                program_schedule_params.num_splits_per_client = data_parallelism
                program_schedule_params.task_name = task_string
                # If the model was created above, we'll inject it here as a
                # shared_model.
                ps = program_schedule_params.Instantiate(
                    shared_model=shared_model,
                    trial=self._trial,
                    ema=ema,
                    tf_master=self._tf_master)
                self._program_schedule_dict[task_string] = ps
                tf.logging.info('program_schedule_params: %s',
                                program_schedule_params.ToText())
                self._programs += ps.Programs()
                # Checkpoint only the train program when the schedule has one;
                # otherwise checkpoint all of its programs.
                if ps.train_program:
                    self._ckpt_programs.append(ps.train_program)
                else:
                    self._ckpt_programs += ps.Programs()
                if program_schedule_params.ml_perf.benchmark_name is not None:
                    self._ml_perf = program_schedule_params.ml_perf
                # All schedules that set checkpoint_to_load must agree on it.
                if ('checkpoint_to_load' in program_schedule_params
                        and program_schedule_params.checkpoint_to_load):
                    if (self._checkpoint_to_load
                            and (self._checkpoint_to_load !=
                                 program_schedule_params.checkpoint_to_load)):
                        raise ValueError(
                            f'Multiple values found for checkpoint_to_load: '
                            f'{self._checkpoint_to_load}, '
                            f'{program_schedule_params.checkpoint_to_load}.')
                    self._checkpoint_to_load = program_schedule_params.checkpoint_to_load

        tf.logging.info('num_programs: %d', len(self._programs))

        # When running in a vizier trainer, the executor reports infeasiable runs
        # in case of errors. The programs report metrics and normal completions.
        for program in self._programs:
            if program._should_report_metrics:
                self._should_report_metrics = True

        with self._cluster, tf.container(
                self._container_id), contextlib.ExitStack() as stack:
            if not py_utils.IsEagerMode():
                stack.enter_context(self._graph.as_default())

                # Graph mode: either build under a TPUStrategy scope (mirrored
                # variables) or place ops with the cluster's placer.
                if FLAGS.use_tpu_mirrored_vars:
                    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
                        FLAGS.tf_master,
                        job_name=FLAGS.worker_job[len('/job:'):])
                    self._tpu_strategy = tf.distribute.experimental.TPUStrategy(
                        resolver, device_assignment=self.device_assignment)
                    stack.enter_context(self._tpu_strategy.scope())
                    stack.enter_context(
                        tpu_strategy._TPUReplicaContext(self._tpu_strategy))
                else:
                    stack.enter_context(tf.device(self._cluster.GetPlacer()))

            if FLAGS.pdb_on_exception:
                stack.enter_context(pdb_wrapper.catch_post_mortem())
            with py_utils.VariableStore(), py_utils.VariableRenameScope(
                    self._variable_renaming_rules):
                # `BuildTpuSubgraph` has to be called before checkpoint restore, so that
                # the optimizer slot variables are guaranteed to be initialized before
                # they get loaded. Otherwise, the optimizers' slot variables will not
                # be properly loaded when V1 checkpoint is used.
                for program in self._programs:
                    program.BuildTpuSubgraph()
                    py_utils.ClearTpuSummaryTensors()

            # Initializer ops exist only in graph mode.
            if not py_utils.IsEagerMode():
                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()
                self._initialize_global_vars = tf.global_variables_initializer(
                )

            checkpointer_models = [
                program.GetModel() for program in self._ckpt_programs
            ]

            # Eager mode picks the V2 or V1 eager checkpointer via a flag; graph
            # mode uses Checkpointer with the global-variables initializer.
            if py_utils.IsEagerMode():
                if FLAGS.use_v2_checkpoints_in_eager:
                    self._checkpointer = checkpointer.EagerCheckpointerV2(
                        self._checkpoint_dir,
                        models=checkpointer_models,
                        init_op=None,
                        train_params=train_cfg.train,
                        save_only=False)
                else:
                    self._checkpointer = checkpointer.EagerCheckpointerV1(
                        self._checkpoint_dir,
                        models=checkpointer_models,
                        init_op=None,
                        train_params=train_cfg.train,
                        save_only=False)
            else:
                self._checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    models=checkpointer_models,
                    init_op=self._initialize_global_vars,
                    train_params=train_cfg.train,
                    save_only=False)

            for program in self._programs:
                program.SetStatusMessageFn(self._SetStatusMessage)

            # Cache the TPU embedding load/retrieve ops collected while building
            # the subgraphs above.
            tpu_embedding_collection = (
                tpu_embedding_layers.TpuEmbeddingCollection.Get())
            self._load_ops = tpu_embedding_collection.load_ops
            self._retrieve_ops = tpu_embedding_collection.retrieve_ops
            self._tpu_embedding = tpu_embedding_collection.tpu_embedding