示例#1
0
    def testRestoreState(self):
        # Save once at step 10, restore that checkpoint into a fresh graph with
        # a brand-new checkpointer, save again at step 20, and expect the
        # checkpoint state file to list both checkpoints.
        train_dir = os.path.join(self.get_temp_dir(), 'testSaveRestore')
        os.mkdir(train_dir)
        params = base_model.SingleTaskModel.Params(LinearModel.Params())
        params.input = base_input_generator.BaseInputGenerator.Params()

        # First graph: initialize everything and write the step-10 checkpoint.
        with self.session(graph=tf.Graph()) as sess:
            first_model = params.Instantiate()
            self.evaluate(tf.global_variables_initializer())
            first_saver = checkpointer.Checkpointer(train_dir, first_model)
            first_saver.Save(sess, 10)

        # Second graph: restore via a new checkpointer, then save at step 20.
        with self.session(graph=tf.Graph()) as sess:
            second_model = params.Instantiate()
            second_saver = checkpointer.Checkpointer(train_dir, second_model)
            second_saver.RestoreIfNeeded(sess)
            second_saver.Save(sess, 20)

            state = tf.train.get_checkpoint_state(train_dir)
            num_paths = len(state.all_model_checkpoint_paths)
            self.assertEqual(2, num_paths)
示例#2
0
    def testSaveRestore(self, use_custom_saver):
        """Round-trips w, b and the global step through a checkpoint.

        Args:
          use_custom_saver: value written to FLAGS.use_custom_saver for this
            run.
        """
        FLAGS.use_custom_saver = use_custom_saver
        train_dir = os.path.join(self.get_temp_dir(), 'testSaveRestore')
        os.mkdir(train_dir)
        params = base_model.SingleTaskModel.Params(LinearModel.Params())
        params.input = base_input_generator.BaseInputGenerator.Params()

        saved_step = 10
        want_w = [0.38615, 2.975221, -0.852826]
        want_initial_b = 1.418741
        want_saved_b = 1234

        # Writer graph: check the initial values, overwrite b and the global
        # step, then save.
        with self.session(graph=tf.Graph()) as sess:
            model = params.Instantiate()
            self.evaluate(tf.global_variables_initializer())
            task_vars = model.GetTask().vars
            w, b = self.evaluate([task_vars.w, task_vars.b])
            self.assertAllClose(want_w, w)
            self.assertAlmostEqual(want_initial_b, b, places=5)

            saver = checkpointer.Checkpointer(train_dir, model)
            global_step_var = py_utils.GetOrCreateGlobalStepVar()
            self.evaluate(tf.assign(global_step_var, saved_step))
            self.evaluate(tf.assign(task_vars.b, want_saved_b))
            saver.Save(sess, model.global_step)

            w, b = self.evaluate([task_vars.w, task_vars.b])
            self.assertAllClose(want_w, w)
            self.assertEqual(want_saved_b, b)

        index_file = os.path.join(train_dir, 'ckpt-%08d.index' % saved_step)
        self.assertTrue(os.path.isfile(index_file))

        # Reader graph: restoring must bring back the saved values.
        with self.session(graph=tf.Graph()) as sess:
            model = params.Instantiate()
            saver = checkpointer.Checkpointer(train_dir, model)
            saver.RestoreIfNeeded(sess)

            task_vars = model.GetTask().vars
            w, b, global_step = self.evaluate(
                [task_vars.w, task_vars.b, model.global_step])
            self.assertAllClose(want_w, w)
            self.assertEqual(want_saved_b, b)
            self.assertEqual(saved_step, global_step)

            # An explicit Restore succeeds even though RestoreIfNeeded has
            # already initialized the variables.
            saver.Restore(sess)
示例#3
0
    def testRestoreWithoutCheckpointInitializesVars(self):
        # With an empty train_dir, RestoreIfNeeded must fall back to the
        # variable initializers and must not write a step-0 checkpoint.
        train_dir = os.path.join(
            self.get_temp_dir(), 'testRestoreWithoutCheckpointInitializesVars')
        os.mkdir(train_dir)
        params = base_model.SingleTaskModel.Params(LinearModel.Params())
        params.input = base_input_generator.BaseInputGenerator.Params()

        with self.session(graph=tf.Graph()) as sess:
            model = params.Instantiate()
            saver = checkpointer.Checkpointer(train_dir, model)
            task_vars = model.GetTask().vars

            # Nothing has been initialized yet, so reading must fail.
            with self.assertRaises(tf.errors.FailedPreconditionError):
                self.evaluate([task_vars.w, task_vars.b])

            saver.RestoreIfNeeded(sess)
            w, b, step = self.evaluate(
                [task_vars.w, task_vars.b, model.global_step])
            self.assertAllClose([0.38615, 2.975221, -0.852826], w)
            self.assertAlmostEqual(1.418741, b, places=5)
            self.assertEqual(0, step)

        unexpected_index = os.path.join(train_dir, 'ckpt-00000000.index')
        self.assertFalse(os.path.isfile(unexpected_index))
示例#4
0
    def testRestore(self):
        # From scratch, RestoreIfNeeded initializes the variables. A later
        # plain Restore must then assert (the variables are already
        # initialized) unless force_reinitialize=True is passed.
        train_dir = os.path.join(self.get_temp_dir(), 'testRestore')
        os.mkdir(train_dir)
        params = base_model.SingleTaskModel.Params(LinearModel.Params())
        params.input = base_input_generator.BaseInputGenerator.Params()

        with self.session(graph=tf.Graph()) as sess:
            model = params.Instantiate()
            saver = checkpointer.Checkpointer(train_dir, model)
            task_vars = model.GetTask().vars

            with self.assertRaises(tf.errors.FailedPreconditionError):
                self.evaluate([task_vars.w, task_vars.b])

            saver.RestoreIfNeeded(sess)
            w, b, step = self.evaluate(
                [task_vars.w, task_vars.b, model.global_step])
            self.assertAllClose([0.38615, 2.975221, -0.852826], w)
            self.assertAlmostEqual(1.418741, b, places=5)
            self.assertEqual(0, step)

            # Initializing from scratch expects uninitialized variables, so
            # a second Restore asserts...
            with self.assertRaises(AssertionError):
                saver.Restore(sess)

            # ...unless re-initialization is forced.
            saver.Restore(sess, force_reinitialize=True)
  def testEagerMultiLearnerCheckpointCompatibility(self):
    """Eager V1 checkpoint keys must match graph-mode checkpoint keys."""
    self.assertTrue(tf.executing_eagerly())
    cfg = model_registry.GetParams('test.LinearModelParams', 'Train')
    eager_mdl = cfg.Instantiate()
    with py_utils.GradientTape(persistent=True):
      eager_mdl.ConstructFPropBPropGraph()

    v1_dir = os.path.join(self.get_temp_dir(), 'eager_v1')
    v2_dir = os.path.join(self.get_temp_dir(), 'eager_v2')
    checkpointer.EagerCheckpointerV1(v1_dir, eager_mdl).Save(gsteps=0)
    checkpointer.EagerCheckpointerV2(v2_dir, eager_mdl).Save(gsteps=0)
    v1_ckpt = os.path.join(v1_dir, 'ckpt_V1', 'ckpt-00000000')
    v2_ckpt = os.path.join(v2_dir, 'ckpt_V2', 'ckpt-0')
    v1_keys = _GetCheckpointKeys(v1_ckpt)
    v2_keys = _GetCheckpointKeys(v2_ckpt)
    # V2 checkpoints are expected to hold exactly two extra entries:
    # _CHECKPOINTABLE_OBJECT_GRAPH and save_counter.
    self.assertEqual(len(v1_keys) + 2, len(v2_keys))  # pylint:disable=g-generic-assert

    # Now build the same model in graph mode and compare its checkpoint keys.
    py_utils.SetEagerMode(False)
    self.assertFalse(tf.executing_eagerly())
    graph_dir = os.path.join(self.get_temp_dir(), 'graph')
    os.mkdir(graph_dir)
    with self.session(graph=tf.Graph()) as sess:
      graph_mdl = cfg.Instantiate()
      for learner in graph_mdl.GetTask().learners:
        learner.optimizer.params.clear_variable_scope = False
      graph_mdl.ConstructFPropBPropGraph()
      sess.run(tf.global_variables_initializer())
      checkpointer.Checkpointer(graph_dir, graph_mdl).Save(sess)
    graph_keys = _GetCheckpointKeys(os.path.join(graph_dir, 'ckpt'))
    self.assertEqual(v1_keys, graph_keys)
示例#6
0
    def testInitRulesDirectory(self):
        """init_from_checkpoint_rules accepts a checkpoint path or a directory.

        Two checkpoints are written to one train dir: b=b1 at step g1 and
        b=b2 at step g2. A fresh model with a second train dir is then
        initialized through init_from_checkpoint_rules.
          * Keyed by the specific step-g1 checkpoint path, it must load b1.
          * Keyed by the first train dir, it must load the latest
            checkpoint, i.e. b2.
        """
        train_dir = os.path.join(self.get_temp_dir(), 'testInitRulesDirectory')
        os.mkdir(train_dir)
        p = base_model.SingleTaskModel.Params(LinearModel.Params())
        p.input = base_input_generator.BaseInputGenerator.Params()
        b1 = 1234
        g1 = 10
        b2 = 12345
        g2 = 100

        with self.session(graph=tf.Graph()) as sess:
            model = p.Instantiate()
            self.evaluate(tf.global_variables_initializer())
            saver = checkpointer.Checkpointer(train_dir, model)
            for step, b_value in ((g1, b1), (g2, b2)):
                self.evaluate(
                    tf.assign(py_utils.GetOrCreateGlobalStepVar(), step))
                self.evaluate(tf.assign(model.GetTask().vars.b, b_value))
                saver.Save(sess, model.global_step)

        train_dir_2 = os.path.join(self.get_temp_dir(),
                                   'testInitRulesDirectory_2')

        # Identity rule: '(.*)' matches every variable name, so ALL variables
        # of the new model are loaded from the same-named variables of the
        # source checkpoint (not just b). Only b is checked below.
        rules = [('(.*)', '%s')]
        init_cases = [
            # A specific checkpoint: the first (step g1) one must be used,
            # not the latest one.
            (os.path.join(train_dir, 'ckpt-%08d' % g1), b1),
            # Just the original train directory, not a specific checkpoint:
            # the latest checkpoint must be used.
            (train_dir, b2),
        ]
        for init_source, expected_b in init_cases:
            p.train.init_from_checkpoint_rules = {init_source: (rules, [])}
            with self.session(graph=tf.Graph()) as sess:
                model = p.Instantiate()
                saver = checkpointer.Checkpointer(train_dir_2, model)
                saver.RestoreIfNeeded(sess)
                new_b = self.evaluate(model.GetTask().vars.b)
                self.assertEqual(expected_b, new_b)
示例#7
0
 def _CreateCheckpointer(self, train_dir, model, init_op=None):
     """Builds the checkpointer for `train_dir`; subclasses may override.

     Which checkpointer is returned:
       * Graph mode: checkpointer.Checkpointer.
       * Eager mode with FLAGS.write_v2_checkpoints set: EagerCheckpointerV2.
       * Eager mode otherwise: EagerCheckpointerV1.
     """
     if not py_utils.IsEagerMode():
         return checkpointer.Checkpointer(train_dir, model, init_op)
     eager_cls = (checkpointer.EagerCheckpointerV2
                  if FLAGS.write_v2_checkpoints else
                  checkpointer.EagerCheckpointerV1)
     return eager_cls(train_dir, model, init_op)
示例#8
0
    def testSaveOnly(self):
        # A save_only checkpointer can write a checkpoint, but RestoreIfNeeded
        # on it must raise AssertionError.
        train_dir = os.path.join(self.get_temp_dir(), 'testSaveOnly')
        os.mkdir(train_dir)
        params = base_model.SingleTaskModel.Params(LinearModel.Params())
        params.input = base_input_generator.BaseInputGenerator.Params()

        with self.session(graph=tf.Graph()) as sess:
            model = params.Instantiate()
            self.evaluate(tf.global_variables_initializer())
            save_only_ckpt = checkpointer.Checkpointer(
                train_dir, model, save_only=True)
            save_only_ckpt.Save(sess, model.global_step)
            with self.assertRaises(AssertionError):
                save_only_ckpt.RestoreIfNeeded(sess)

        written_index = os.path.join(train_dir, 'ckpt-00000000.index')
        self.assertTrue(os.path.isfile(written_index))
示例#9
0
    def BuildTpuSubgraph(self):
        """Builds the on-device eval loop and stores its ops in `self.tpu_ops`.

        How the loop is built:
          * The per-step eval function is repeated `self._steps_per_loop`
            times on device.
          * The resulting loop is sharded with `tf.tpu.batch_parallel` over
            `self.data_parallelism` shards.
          * `self._model` is assigned inside the traced step function. The
            checkpointer created at the end of this method relies on it
            having been set by then.

        Returns:
          `self.tpu_ops`: a one-element list wrapping the list of index-0
          elements of each `tf.tpu.batch_parallel` output (the finalized eval
          metrics of a single replica).
        """
        tf.logging.info('EvalProgram BuildTpuSubGraph')
        with py_utils.OpportunisticVariableReuseScope(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            data_parallelism = self.data_parallelism

            def TpuEvalStep(*args):
                """Eval a shard of a batch on a single TPU core.

                Instantiates the task model, builds its forward graph and
                records the task's eval metrics. The model is kept in
                `self._model` for use after tracing.

                Args:
                  *args: metrics values from previous steps.

                Returns:
                  Per-step eval metrics.
                """
                self._model = self._task_params.Instantiate()
                self._model.ConstructFPropGraph()
                per_step_eval_metrics = self._eval_metrics.SetMetrics(
                    self._model.GetTask().eval_metrics, args)
                return per_step_eval_metrics

            @tpu_function.on_device_training_loop
            def TpuEval():
                """Runs TpuEvalStep in a device loop; returns final metrics."""
                loop_result = tpu_training_loop.repeat(
                    self._steps_per_loop,
                    TpuEvalStep,
                    inputs=self._eval_metrics.initial_values,
                    name='eval_loop')
                # Final metrics are the avg across self._steps_per_loop steps.
                return self._eval_metrics.FinalizeMetrics(loop_result)

            batch_parallel_res = tf.tpu.batch_parallel(
                TpuEval,
                num_shards=data_parallelism,
                device_assignment=py_utils.GetTpuDeviceAssignment())
            # Get the metric result from a single replica; they are all the
            # same here.
            self.tpu_ops = [[t[0] for t in batch_parallel_res]]
            # self._model was assigned while batch_parallel traced TpuEvalStep.
            self._checkpointer = checkpointer.Checkpointer(
                self._checkpoint_dir, self._model)

            return self.tpu_ops
示例#10
0
文件: program.py 项目: k1eira/lingvo
  def BuildTpuSubgraph(self):
    """Builds the sharded TPU decode computation.

    Steps:
      * The task's Decode() is traced through `tf.tpu.batch_parallel` over
        `self.data_parallelism` shards.
      * The flat outputs are packed back into the NestedMap structure
        recorded during tracing.
      * The packed result is stored in `self.metrics`.
      * A checkpointer for the traced model is created in
        `self._checkpointer`.

    Returns:
      None.
    """
    py_utils.ResetStepSeed()

    def _DecodeFn():
      """Instantiates the task, decodes one input batch, returns flat outputs."""
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model = self._task_params.Instantiate()
        self._model_task = self._model.GetTask()
        input_batch = self._model_task.GetInputBatch()
        metrics_dict = self._model_task.Decode(input_batch)
        # Keep the output structure so the flat batch_parallel results can be
        # packed back into it below.
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        return self.metrics_nm.Flatten()

    batch_parallel_res = tf.tpu.batch_parallel(
        _DecodeFn,
        num_shards=self.data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    # self._model and self.metrics_nm are set while batch_parallel traces
    # _DecodeFn above.
    self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                   self._model)

    self.metrics = py_utils.NestedMap(self.metrics_nm)
    self.metrics = self.metrics.Pack(batch_parallel_res)
    return None
示例#11
0
 def CreateCheckpointer(self):
     """Creates `self._checkpointer` over this program's checkpoint dir/model."""
     ckpt_dir, model = self._checkpoint_dir, self._model
     self._checkpointer = checkpointer.Checkpointer(ckpt_dir, model)
示例#12
0
    def __init__(self, train_cfg, ps_params_dict, *args, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

        Steps performed:
          * In eager mode, connects to the TPU runtime.
          * Initializes the TPU system and the device assignment.
          * Instantiates one program schedule per entry of `ps_params_dict`.
          * Builds every program's TPU subgraph.
          * Creates the checkpointer over the models of the programs to be
            checkpointed: each schedule's train program if it has one,
            otherwise all of its programs.

        Args:
          train_cfg: SingleTaskModelParams or MultiTaskModelParams
          ps_params_dict: A dict of top-level task name -> ProgramSchedule
            params, if train_cfg is a SingleTaskModelParams, we expect only one
            entry.
          *args: List args to pass through to BaseRunner.
          **kwargs: keyword args to pass through to BaseRunner.

        Raises:
          ValueError: if more than one distinct checkpoint_to_load is
            configured across the program schedules.
        """
        if py_utils.IsEagerMode():
            assert tf.executing_eagerly()
            tf.logging.info(f'FLAGS.tf_master: {FLAGS.tf_master}')

            # Connect to the TPU runtime.
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
            tf.config.experimental_connect_to_cluster(resolver)

        super().__init__(train_cfg, *args, **kwargs)

        data_parallelism = self._cluster.num_splits_per_client
        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(self._logdir, 'train')

        self._variable_renaming_rules = []

        # MLPerf logging is enabled when any program schedule names a benchmark.
        # Read it from the params here rather than from the instantiated
        # program schedules: those are only created after TPU initialization
        # below, too late for the 'benchmark' and 'init_start' log lines.
        # (Previously _ml_perf was always None at that point.)
        self._ml_perf = None
        for ps_params in ps_params_dict.values():
            if ps_params.ml_perf.benchmark_name is not None:
                self._ml_perf = ps_params.ml_perf

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params()
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'trainer_params.txt')
        self._WriteToLog(
            text_format.MessageToString(train_cfg.ToProto(), as_utf8=True),
            self._checkpoint_dir, 'trainer_params.pbtxt')
        if self._ml_perf is not None:
            self._ml_perf_log = True
            mlp_log.mlperf_print(key='benchmark',
                                 value=self._ml_perf.benchmark_name)
        else:
            self._ml_perf_log = False

        train_cfg = self.params

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit(job=None):
            """Wait until the model is ready."""
            try:
                if py_utils.IsEagerMode():
                    # `resolver` is only defined (above) in eager mode.
                    topology = tf.tpu.experimental.initialize_tpu_system(
                        resolver)
                else:
                    # tpu.initialize_system() is called with None as embedding_config, as
                    # embedding_config is not available yet. Later in _Loop, it is called
                    # with the correct embedding_config. Since it cannot be called twice
                    # in the same graph with different embedding_config, we use a
                    # dummy_graph here.
                    dummy_graph = tf.Graph()
                    with dummy_graph.as_default():
                        tpu_initialize_system_op = tf.tpu.initialize_system(
                            embedding_config=None, job=job)

                    with self._GetSession(graph=dummy_graph) as sess:
                        topology = sess.run(tpu_initialize_system_op)

                if train_cfg.train.tpu_computation_shape is None:
                    computation_shape = py_utils.ComputationShape(
                        num_devices_per_split, topology)
                else:
                    computation_shape = train_cfg.train.tpu_computation_shape
                    assert num_devices_per_split == np.prod(computation_shape)

                if train_cfg.train.tpu_device_order_mode is None:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism)
                else:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism,
                        device_order_mode=train_cfg.train.tpu_device_order_mode
                    )
                py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
                tf.logging.info('device_assignment.core_assignment: %s',
                                str(self.device_assignment.core_assignment))
                tf.logging.info(
                    'device_assignment.topology.device_coordinates: %s',
                    str(self.device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        if self._ml_perf_log:
            mlp_log.mlperf_print(key='init_start', value=None)
        if len(self._cluster.all_worker_names) > 1:
            for worker in self._cluster.all_worker_names:
                _WaitTillInit(worker)
        else:
            _WaitTillInit(None)

        shared_model = self._MaybeConstructSharedModel(train_cfg)

        self._program_schedule_dict = {}
        self._programs = []
        self._ckpt_programs = []

        self._checkpoint_to_load = None
        with self._cluster:
            # Create the ExponentialMovingAverage singleton shared by all programs, if
            # applicable.
            ema = py_utils.CreateEMAForModel(train_cfg, self._global_step_var)
            for task_string, program_schedule_params in ps_params_dict.items():
                program_schedule_params.logdir = self._logdir
                program_schedule_params.num_splits_per_client = data_parallelism
                program_schedule_params.task_name = task_string
                # If the model was created above, we'll inject it here as a
                # shared_model.
                ps = program_schedule_params.Instantiate(
                    shared_model=shared_model,
                    trial=self._trial,
                    ema=ema,
                    tf_master=self._tf_master)
                self._program_schedule_dict[task_string] = ps
                tf.logging.info('program_schedule_params: %s',
                                program_schedule_params.ToText())
                self._programs += ps.Programs()
                # Only the train program is checkpointed when there is one;
                # otherwise every program's model is.
                if ps.train_program:
                    self._ckpt_programs.append(ps.train_program)
                else:
                    self._ckpt_programs += ps.Programs()
                if ('checkpoint_to_load' in program_schedule_params
                        and program_schedule_params.checkpoint_to_load):
                    if (self._checkpoint_to_load
                            and (self._checkpoint_to_load !=
                                 program_schedule_params.checkpoint_to_load)):
                        raise ValueError(
                            f'Multiple values found for checkpoint_to_load: '
                            f'{self._checkpoint_to_load}, '
                            f'{program_schedule_params.checkpoint_to_load}.')
                    self._checkpoint_to_load = program_schedule_params.checkpoint_to_load

        tf.logging.info('num_programs: %d', len(self._programs))

        # When running in a vizier trainer, the executor reports infeasible runs
        # in case of errors. The programs report metrics and normal completions.
        for program in self._programs:
            if program._should_report_metrics:
                self._should_report_metrics = True

        with self._cluster, tf.container(
                self._container_id), contextlib.ExitStack() as stack:
            if not py_utils.IsEagerMode():
                stack.enter_context(self._graph.as_default())

                if FLAGS.use_tpu_mirrored_vars:
                    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
                        FLAGS.tf_master,
                        job_name=FLAGS.worker_job[len('/job:'):])
                    self._tpu_strategy = tf.distribute.experimental.TPUStrategy(
                        resolver, device_assignment=self.device_assignment)
                    stack.enter_context(self._tpu_strategy.scope())
                    stack.enter_context(
                        tpu_strategy._TPUReplicaContext(self._tpu_strategy))
                else:
                    stack.enter_context(tf.device(self._cluster.GetPlacer()))

            if FLAGS.pdb_on_exception:
                stack.enter_context(pdb_wrapper.catch_post_mortem())
            with py_utils.VariableStore(), py_utils.VariableRenameScope(
                    self._variable_renaming_rules):
                # `BuildTpuSubgraph` has to be called before checkpoint restore, so that
                # the optimizer slot variables are guaranteed to be initialized before
                # they get loaded. Otherwise, the optimizers' slot variables will not
                # be properly loaded when V1 checkpoint is used.
                for program in self._programs:
                    program.BuildTpuSubgraph()
                    py_utils.ClearTpuSummaryTensors()

            if not py_utils.IsEagerMode():
                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()
                self._initialize_global_vars = tf.global_variables_initializer(
                )

            checkpointer_models = [
                program.GetModel() for program in self._ckpt_programs
            ]

            if py_utils.IsEagerMode():
                if FLAGS.use_v2_checkpoints_in_eager:
                    self._checkpointer = checkpointer.EagerCheckpointerV2(
                        self._checkpoint_dir,
                        models=checkpointer_models,
                        init_op=None,
                        train_params=train_cfg.train,
                        save_only=False)
                else:
                    self._checkpointer = checkpointer.EagerCheckpointerV1(
                        self._checkpoint_dir,
                        models=checkpointer_models,
                        init_op=None,
                        train_params=train_cfg.train,
                        save_only=False)
            else:
                self._checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    models=checkpointer_models,
                    init_op=self._initialize_global_vars,
                    train_params=train_cfg.train,
                    save_only=False)

            for program in self._programs:
                program.SetStatusMessageFn(self._SetStatusMessage)

            tpu_embedding_collection = (
                tpu_embedding_layers.TpuEmbeddingCollection.Get())
            self._load_ops = tpu_embedding_collection.load_ops
            self._retrieve_ops = tpu_embedding_collection.retrieve_ops
            self._tpu_embedding = tpu_embedding_collection.tpu_embedding
示例#13
0
    def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
                 tf_master, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

        Steps performed:
          * Initializes the TPU system and the device assignment.
          * Instantiates one program schedule per entry of `ps_params_dict`.
          * Builds each program's TPU subgraph and checkpointer.
          * Creates a save-only checkpointer for `<logdir>/train`.

        Args:
          train_cfg: SingleTaskModelParams or MultiTaskModelParams
          ps_params_dict: A dict of top-level task name -> ProgramSchedule
            params, if train_cfg is a SingleTaskModelParams, we expect only one
            entry.
          model_task_name: An override for multi-task models, currently unused.
          logdir:  String path to the log directory to output to.
          tf_master: String path to the master job, e.g. 'local'.
          **kwargs: keyword args to pass through to BaseRunner.

        Raises:
          ValueError: if the number of devices per split has no supported TPU
            computation shape.
        """
        super().__init__(train_cfg, model_task_name, logdir, tf_master,
                         **kwargs)
        self._cluster_def = self._cluster.worker_cluster_def

        # There is a single Executor task
        assert self._cluster.num_replicas == 1
        data_parallelism = self._cluster.num_splits_per_client

        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(logdir, 'train')

        self._variable_renaming_rules = []

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params()
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'executor_params.txt')
        self._program_schedule_dict = {}
        self._programs = []

        for task_string, program_schedule_params in ps_params_dict.items():
            program_schedule_params.logdir = logdir
            program_schedule_params.num_splits_per_client = data_parallelism
            program_schedule_params.task_name = task_string
            ps = program_schedule_params.Instantiate()
            self._program_schedule_dict[task_string] = ps
            tf.logging.info('program_schedule_params: %s',
                            program_schedule_params.ToText())
            self._programs += ps.Programs()

        tf.logging.info('num_programs: %d', len(self._programs))

        # BaseRunner legacy
        self.enqueue_ops = None

        # TPU computation shape for each supported number of devices per split.
        computation_shapes = {
            1: [1, 1, 1],
            2: [1, 1, 2],
            4: [1, 2, 2],
            8: [2, 2, 2],
            16: [4, 2, 2],
            32: [4, 4, 2],
            64: [4, 8, 2],
            128: [8, 8, 2],
            256: [8, 16, 2],
            512: [16, 16, 2],
        }

        def ComputationShape(split_size):
            """Decides the computation shape based on the split_size.

            Raises:
              ValueError: if `split_size` has no supported computation shape.
                (An explicit raise instead of `assert`, which `python -O`
                strips and would let None reach device_assignment.)
            """
            if split_size not in computation_shapes:
                raise ValueError(
                    'Model parallelism with %d devices is currently not'
                    ' supported.' % split_size)
            # Return a fresh list so callers cannot mutate the shared table.
            return list(computation_shapes[split_size])

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit():
            """Wait until the model is ready."""
            try:
                with self._graph.as_default(), self._GetSession(
                        cluster_def=self._cluster_def) as sess:
                    topology = sess.run(
                        tf.tpu.initialize_system(embedding_config=None,
                                                 job=None))
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=ComputationShape(
                            num_devices_per_split),
                        num_replicas=data_parallelism)
                    py_utils.SetTpuDeviceAssignment(device_assignment)
                    tf.logging.info('device_assignment.core_assignment: %s',
                                    str(device_assignment.core_assignment))
                    tf.logging.info(
                        'device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        _WaitTillInit()

        with self._graph.as_default(), tf.container(self._container_id):
            with self._cluster, tf.device(
                    self._cluster.job_spec.name if not FLAGS.
                    cluster_placer_in_executor else self._cluster.GetPlacer()):
                with py_utils.VariableRenameScope(
                        self._variable_renaming_rules):
                    for program in self._programs:
                        program.BuildTpuSubgraph()
                # Checkpointers need each program's model, which only exists
                # once its subgraph has been built.
                for program in self._programs:
                    program.CreateCheckpointer()
                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()

                self.save_only_checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    model=None,
                    train_params=train_cfg.train,
                    save_only=True)
示例#14
0
文件: program.py 项目: k1eira/lingvo
 def _CreateCheckpointer(self, train_dir, model):
   """Build the checkpointer used by this program.

   Args:
     train_dir: Directory passed to `checkpointer.Checkpointer` as its first
       positional argument.
     model: Model object passed to `checkpointer.Checkpointer` as its second
       positional argument.

   Returns:
     The newly constructed `checkpointer.Checkpointer` instance.
   """
   new_checkpointer = checkpointer.Checkpointer(train_dir, model)
   return new_checkpointer
示例#15
0
    def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
                 tf_master, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

        Initializes the TPU system (once per worker when the cluster lists more
        than one worker), instantiates every ProgramSchedule in
        `ps_params_dict`, builds the programs' TPU subgraphs in `self._graph`,
        and creates the per-program checkpointers plus a save-only
        checkpointer rooted at `<logdir>/train`.

        Args:
          train_cfg: SingleTaskModelParams or MultiTaskModelParams
          ps_params_dict: A dict of top-level task name -> ProgramSchedule
            params, if train_cfg is a SingleTaskModelParams, we expect only one
            entry.
          model_task_name: An override for multi-task models, currently unused.
          logdir:  String path to the log directory to output to.
          tf_master: String path to the master job, e.g. 'local'.
          **kwargs: keyword args to pass through to BaseRunner.
        """
        super().__init__(train_cfg, model_task_name, logdir, tf_master,
                         **kwargs)

        data_parallelism = self._cluster.num_splits_per_client

        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(logdir, 'train')

        self._variable_renaming_rules = []

        # Resolve the MLPerf config up front: it must be known before
        # `self._ml_perf_log` is decided below. It used to be assigned only
        # while instantiating the program schedules further down, so the check
        # always saw None and MLPerf logging could never be enabled. As before,
        # the last schedule that names a benchmark wins.
        self._ml_perf = None
        for program_schedule_params in ps_params_dict.values():
            if program_schedule_params.ml_perf.benchmark_name is not None:
                self._ml_perf = program_schedule_params.ml_perf

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params(
                )
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'trainer_params.txt')
        if self._ml_perf is not None:
            self._ml_perf_log = True
            mlp_log.mlperf_print(key='benchmark',
                                 value=self._ml_perf.benchmark_name)
        else:
            self._ml_perf_log = False

        # BaseRunner legacy
        self.enqueue_ops = None

        # From here on use the params object held by this runner.
        train_cfg = self.params

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit(job=None):
            """Initialize the TPU system for `job` and set the device assignment.

            Transient TF errors are logged and re-raised, presumably so that
            the RetryOnTransientTfError decorator can retry the call.
            """
            try:
                # tpu.initialize_system() is called with None as embedding_config, as
                # embedding_config is not available yet. Later in _Loop, it is called
                # with the correct embedding_config. Since it cannot be called twice in
                # the same graph with different embedding_config, we use a dummy_graph
                # here.
                dummy_graph = tf.Graph()
                with dummy_graph.as_default():
                    tpu_initialize_system_op = tf.tpu.initialize_system(
                        embedding_config=None, job=job)

                with self._GetSession(graph=dummy_graph) as sess:
                    topology = sess.run(tpu_initialize_system_op)

                if train_cfg.train.tpu_device_order_mode is None:
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=py_utils.ComputationShape(
                            num_devices_per_split, topology),
                        num_replicas=data_parallelism)
                else:
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=py_utils.ComputationShape(
                            num_devices_per_split, topology),
                        num_replicas=data_parallelism,
                        device_order_mode=train_cfg.train.tpu_device_order_mode
                    )
                py_utils.SetTpuDeviceAssignment(device_assignment, job)
                tf.logging.info('device_assignment.core_assignment: %s',
                                str(device_assignment.core_assignment))
                tf.logging.info(
                    'device_assignment.topology.device_coordinates: %s',
                    str(device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        if self._ml_perf_log:
            mlp_log.mlperf_print(key='init_start', value=None)
        # With several workers, initialize each one by job name; otherwise
        # initialize without a job.
        if len(self._cluster.all_worker_names) > 1:
            for worker in self._cluster.all_worker_names:
                _WaitTillInit(worker)
        else:
            _WaitTillInit(None)

        shared_model = self._MaybeConstructSharedModel(train_cfg)

        self._program_schedule_dict = {}
        self._programs = []

        for task_string, program_schedule_params in ps_params_dict.items():
            program_schedule_params.logdir = logdir
            program_schedule_params.num_splits_per_client = data_parallelism
            program_schedule_params.task_name = task_string
            # If the model was created above, we'll inject it here as a shared_model.
            ps = program_schedule_params.Instantiate(shared_model=shared_model,
                                                     tf_master=self._tf_master)
            self._program_schedule_dict[task_string] = ps
            tf.logging.info('program_schedule_params: %s',
                            program_schedule_params.ToText())
            self._programs += ps.Programs()

        tf.logging.info('num_programs: %d', len(self._programs))

        with self._graph.as_default(), tf.container(self._container_id):
            with self._cluster, tf.device(self._cluster.GetPlacer()):
                with py_utils.VariableRenameScope(
                        self._variable_renaming_rules):
                    _ = py_utils.GetOrCreateGlobalStepVar()
                    for program in self._programs:
                        program.BuildTpuSubgraph()
                        py_utils.ClearTpuSummaryTensors()

                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()
                self._initialize_global_vars = tf.global_variables_initializer(
                )

                for program in self._programs:
                    program.SetStatusMessageFn(self._SetStatusMessage)
                    program.CreateCheckpointer(
                        init_op=self._initialize_global_vars)

                self.save_only_checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    model=None,
                    init_op=self._initialize_global_vars,
                    train_params=train_cfg.train,
                    save_only=True)

            self._load_ops = tf.get_collection(py_utils.TPU_EMBEDDING_LOAD_OPS)
            self._retrieve_ops = tf.get_collection(
                py_utils.TPU_EMBEDDING_RETRIEVE_OPS)
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            self._tpu_embedding = (tpu_embedding_collection[0]
                                   if tpu_embedding_collection else None)
            tf.io.write_graph(self._graph.as_graph_def(), self._checkpoint_dir,
                              'train.pbtxt')
示例#16
0
 def _CreateCheckpointer(self, train_dir, model):
   """Wrapper method for override purposes.

   Subclasses may override this hook to customize checkpointer construction.

   Args:
     train_dir: Directory passed positionally to `checkpointer.Checkpointer`.
     model: Model passed positionally to `checkpointer.Checkpointer`.

   Returns:
     A new `checkpointer.Checkpointer`.
   """
   result = checkpointer.Checkpointer(train_dir, model)
   return result
示例#17
0
 def CreateCheckpointer(self, init_op=None):
   """Create this program's checkpointer and store it on `self._checkpointer`.

   Args:
     init_op: Optional op forwarded to `checkpointer.Checkpointer` as its
       `init_op` keyword argument; defaults to None.
   """
   ckpt_dir = self._checkpoint_dir
   model = self._model
   self._checkpointer = checkpointer.Checkpointer(ckpt_dir, model,
                                                  init_op=init_op)
示例#18
0
    def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
                 tf_master, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

        Initializes the TPU system through a session on the worker cluster
        definition, instantiates every ProgramSchedule in `ps_params_dict`,
        builds the programs' TPU subgraphs in `self._graph`, and creates the
        per-program checkpointers plus a save-only checkpointer rooted at
        `<logdir>/train`.

        Args:
          train_cfg: SingleTaskModelParams or MultiTaskModelParams
          ps_params_dict: A dict of top-level task name -> ProgramSchedule
            params, if train_cfg is a SingleTaskModelParams, we expect only one
            entry.
          model_task_name: An override for multi-task models, currently unused.
          logdir:  String path to the log directory to output to.
          tf_master: String path to the master job, e.g. 'local'.
          **kwargs: keyword args to pass through to BaseRunner.
        """
        super().__init__(train_cfg, model_task_name, logdir, tf_master,
                         **kwargs)

        self._cluster_def = self._cluster.worker_cluster_def

        # There is a single Executor task
        assert self._cluster.num_replicas == 1
        data_parallelism = self._cluster.num_splits_per_client

        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(logdir, 'train')

        self._variable_renaming_rules = []

        # Resolve the MLPerf config up front: it must be known before
        # `self._ml_perf_log` is decided below. It used to be assigned only
        # while instantiating the program schedules further down, so the check
        # always saw None and MLPerf logging could never be enabled. As before,
        # the last schedule that names a benchmark wins.
        self._ml_perf = None
        for program_schedule_params in ps_params_dict.values():
            if program_schedule_params.ml_perf.benchmark_name is not None:
                self._ml_perf = program_schedule_params.ml_perf

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params(
                )
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'trainer_params.txt')
        if self._ml_perf is not None:
            self._ml_perf_log = True
            mlp_log.mlperf_print(key='benchmark',
                                 value=self._ml_perf.benchmark_name)
        else:
            self._ml_perf_log = False

        # BaseRunner legacy
        self.enqueue_ops = None

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit():
            """Initialize the TPU system and set the global device assignment.

            Runs `tf.tpu.initialize_system` in `self._graph` through a session
            on the worker cluster definition. Transient TF errors are logged and
            re-raised, presumably so that the RetryOnTransientTfError decorator
            can retry the call.
            """
            try:
                with self._graph.as_default(), self._GetSession(
                        cluster_def=self._cluster_def,
                        disable_meta_optimizer=FLAGS.
                        disable_meta_optimizer_in_executor) as sess:
                    topology = sess.run(
                        tf.tpu.initialize_system(embedding_config=None,
                                                 job=None))
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=py_utils.ComputationShape(
                            num_devices_per_split, topology),
                        num_replicas=data_parallelism)
                    py_utils.SetTpuDeviceAssignment(device_assignment)
                    tf.logging.info('device_assignment.core_assignment: %s',
                                    str(device_assignment.core_assignment))
                    tf.logging.info(
                        'device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        if self._ml_perf_log:
            mlp_log.mlperf_print(key='init_start', value=None)
        _WaitTillInit()

        # From here on use the params object held by this runner.
        train_cfg = self.params
        shared_model = self._MaybeConstructSharedModel(train_cfg)

        self._program_schedule_dict = {}
        self._programs = []

        for task_string, program_schedule_params in ps_params_dict.items():
            program_schedule_params.logdir = logdir
            program_schedule_params.num_splits_per_client = data_parallelism
            program_schedule_params.task_name = task_string
            # If the model was created above, we'll inject it here as a shared_model.
            ps = program_schedule_params.Instantiate(shared_model=shared_model)
            self._program_schedule_dict[task_string] = ps
            tf.logging.info('program_schedule_params: %s',
                            program_schedule_params.ToText())
            self._programs += ps.Programs()

        tf.logging.info('num_programs: %d', len(self._programs))

        with self._graph.as_default(), tf.container(self._container_id):
            # Place ops on the job name unless the cluster placer is
            # requested via FLAGS.cluster_placer_in_executor.
            with self._cluster, tf.device(
                    self._cluster.job_spec.name if not FLAGS.
                    cluster_placer_in_executor else self._cluster.GetPlacer()):
                with py_utils.VariableRenameScope(
                        self._variable_renaming_rules):
                    _ = py_utils.GetOrCreateGlobalStepVar()
                    for program in self._programs:
                        program.BuildTpuSubgraph()
                        py_utils.ClearTpuSummaryTensors()
                for program in self._programs:
                    program.SetStatusMessageFn(self._SetStatusMessage)
                    program.CreateCheckpointer()
                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()

                self.save_only_checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    model=None,
                    train_params=train_cfg.train,
                    save_only=True)
示例#19
0
  def testBatchNormLayer(self):
    """Checks EMA tracking and checkpoint save/restore for a BatchNormLayer.

    Builds a SingleTaskModel whose task encoder is a 1-dim BatchNormLayer with
    `ema_decay=0.9` and `ema_decay_moving_vars=True`. It then:
      1. checks that every encoder variable has an EMA shadow;
      2. checks the raw and EMA values after one post-train step;
      3. saves a checkpoint and checks its file name;
      4. restores it in training mode, where raw and EMA values both come back;
      5. restores it in eval mode, where the variables themselves take the EMA
         values.
    """
    p = base_model.SingleTaskModel.Params()
    p.task = self.TestParams(layers.BatchNormLayer.Params().Set(dim=1))
    p.task.train.ema_decay = 0.9
    # Also track EMAs of the moving statistics (e.g. moving_mean, checked below).
    p.task.train.ema_decay_moving_vars = True
    model = p.Instantiate()
    self.assertIsNotNone(model.ema)
    task = model._task
    # A no-op train op suffices here. ApplyExponentialMovingAverage presumably
    # builds the EMA update ops that are run below via `task._post_train_ops`.
    task._train_op = tf.no_op()
    task.ApplyExponentialMovingAverage(model.ema)

    layer = task.encoder
    # Each of the layer's 4 variables must have an EMA shadow.
    self.assertLen(layer.vars, 4)
    for var in layer.vars.Flatten():
      self.assertIsNotNone(model.ema.average(var), msg=var.name)
    beta = layer.vars.beta
    mean = layer.vars.moving_mean

    global_step = 100
    beta_1 = np.asarray([.2])
    mean_1 = np.asarray([.03])
    # Expected EMA after a single update with decay 0.9: (1 - 0.9) * value.
    # NOTE(review): this assumes the shadows start at 0 — confirm against the
    # layer's initializers.
    beta_1_ema = beta_1 * .1
    mean_1_ema = mean_1 * .1
    with self.session() as sess:
      # Test EMA values.
      sess.run(tf.global_variables_initializer())
      sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), global_step))
      sess.run(tf.assign(beta, beta_1))
      sess.run(tf.assign(mean, mean_1))
      # Presumably applies one round of EMA updates.
      sess.run(task._post_train_ops)

      self.assertAllClose([beta_1, beta_1_ema, mean_1, mean_1_ema],
                          sess.run([
                              beta,
                              model.ema.average(beta), mean,
                              model.ema.average(mean)
                          ]))

      # Test checkpointer.
      train_dir = os.path.join(self.get_temp_dir(), 'testSaveRestore')
      os.mkdir(train_dir)
      saver = checkpointer.Checkpointer(train_dir, model)
      saver.Save(sess, model.global_step)

      # The checkpoint file name encodes the global step assigned above.
      self.assertTrue(
          os.path.isfile(
              os.path.join(train_dir, 'ckpt-%08d.index' % global_step)))

    # Restore from ckpt in training mode.
    # Rebuild the model in a fresh graph. Both the raw values and the EMA
    # shadows should match what was saved.
    with self.session(graph=tf.Graph()) as sess:
      model = p.Instantiate()
      self.assertIsNotNone(model.ema)
      task = model._task
      task._train_op = tf.no_op()
      task.ApplyExponentialMovingAverage(model.ema)
      layer = task.encoder
      for var in layer.vars.Flatten():
        self.assertIsNotNone(model.ema.average(var), msg=var.name)
      beta = layer.vars.beta
      mean = layer.vars.moving_mean

      saver = checkpointer.Checkpointer(train_dir, model)
      saver.RestoreIfNeeded(sess)

      self.assertAllClose([beta_1, beta_1_ema, mean_1, mean_1_ema],
                          sess.run([
                              beta,
                              model.ema.average(beta), mean,
                              model.ema.average(mean)
                          ]))

    # Restore from ckpt in eval mode.
    with self.session(graph=tf.Graph()) as sess, self.SetEval(True):
      model = p.Instantiate()
      self.assertIsNotNone(model.ema)
      task = model._task
      # NOTE(review): the EMA wiring used in training mode is deliberately
      # disabled below for eval mode. The final assertion expects the restore
      # to load the EMA values directly into the variables.
      # task._train_op = tf.no_op()
      # task.ApplyExponentialMovingAverage(model.ema)
      layer = task.encoder
      # for var in layer.vars.Flatten():
      #   self.assertIsNotNone(model.ema.average(var), msg=var.name)
      beta = layer.vars.beta
      mean = layer.vars.moving_mean

      saver = checkpointer.Checkpointer(train_dir, model)
      saver.RestoreIfNeeded(sess)

      # Both beta and mean should use the EMA value.
      self.assertAllClose([beta_1_ema, mean_1_ema], sess.run([beta, mean]))