Example #1
    def _Loop(self):
        with self._cluster, tf.container(self._container_id), self._GetSession(
                disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
        ) as sess:
            config_proto = (self._tpu_embedding.config_proto
                            if self._tpu_embedding is not None else None)
            sess.reset(self._tf_master)
            for worker in self._cluster.all_worker_names:
                sess.run(
                    tf.tpu.initialize_system(embedding_config=config_proto,
                                             job=worker))

            # Initialize the variables first, if needed.
            compile_fns = []
            for program in self._programs:
                program.RestoreIfNeeded(sess)
                compile_fns += [program.Compile]

            # Run the compiles in parallel.
            threadpool = multiprocessing.dummy.Pool(len(compile_fns))
            futures = []
            tf.logging.info(
                f'Compiling {len(compile_fns)} programs in parallel.')
            for fn in compile_fns:
                futures += [threadpool.apply_async(fn, args=(sess, ))]
            for future in futures:
                future.wait()

            sess.run(self._initialize_tables)
            sess.run(self._initialize_local_vars)

            sess.run(self._load_ops)
            program_schedule = None
            while True:
                global_step = sess.run(py_utils.GetGlobalStep())
                if self._ShouldStop(sess, global_step):
                    tf.logging.info('Training finished.')
                    if not self._ml_perf_log:
                        self.save_only_checkpointer.Save(sess, global_step)
                        for program in self._programs:
                            program.SaveProgramState(sess, global_step)
                    if program_schedule:
                        tf.logging.info('Shutting down programs.')
                        program_schedule.Shutdown()
                    return

                if not self._ml_perf_log and self.save_only_checkpointer.ShouldSave(
                ):

                    def RunSave(sess, global_step):
                        # Run TPU embedding retrieve ops.
                        # NOTE: this is expensive, so only run it when we're checkpointing.
                        tf.logging.info('Retrieve params.')
                        sess.run(self._retrieve_ops)
                        tf.logging.info('Retrieve params done.')
                        # Save program state first, so it's recoverable after we restore
                        # from checkpoint.
                        for program in self._programs:
                            program.SaveProgramState(sess, global_step)
                        # Save the checkpoints.
                        self.save_only_checkpointer.Save(sess, global_step)

                    if self.save_only_checkpointer.async_checkpointing:
                        tf.logging.info(
                            'Save checkpoint asynchronously AT YOUR OWN RISK.')
                        threadpool = multiprocessing.dummy.Pool(1)
                        saver_future = threadpool.apply_async(
                            RunSave, args=(sess, global_step))
                    else:
                        RunSave(sess, global_step)

                # If a task is explicitly selected, only run the programs associated
                # with that task.
                if self._single_task_mode or self._model_task_name:
                    tf.logging.info('Single task mode: %s',
                                    self._model_task_name)
                    program_schedule = self._program_schedule_dict[
                        self._model_task_name]
                else:
                    # Otherwise, sample a task.
                    model_task = self.task_scheduler.Sample(global_step)
                    tf.logging.info('Sampled %s', model_task)
                    program_schedule = self._program_schedule_dict[model_task]

                done = program_schedule.Run(sess)
                if (not self._ml_perf_log
                        and self.save_only_checkpointer.async_checkpointing):
                    saver_future.wait()

                if done:
                    tf.logging.info('Program schedule told us to stop.\n'
                                    'Shutting down programs.')
                    program_schedule.Shutdown()
                    return
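
The compile fan-out above is just `multiprocessing.dummy.Pool` (a thread pool) plus `apply_async`. A minimal stand-alone sketch of that pattern, with a dummy `Compile` function and a placeholder session standing in for the real program objects (both are assumptions, purely for illustration):

import multiprocessing.dummy
import time

def Compile(sess, program_id):
    # Stand-in for program.Compile(sess); sleeps to simulate compile latency.
    time.sleep(0.1)
    return 'compiled program %d' % program_id

sess = 'fake-session'  # placeholder for the real tf.Session
compile_fns = [lambda s, i=i: Compile(s, i) for i in range(4)]

threadpool = multiprocessing.dummy.Pool(len(compile_fns))
futures = [threadpool.apply_async(fn, args=(sess,)) for fn in compile_fns]
# .get() (unlike .wait() in the snippet above) re-raises any exception
# raised inside a worker thread.
results = [future.get() for future in futures]
threadpool.close()
threadpool.join()
print(results)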
Example #2
  def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
               tf_master, **kwargs):
    """Construct an ExecutorTpu BaseRunner.

    Args:
      train_cfg: SingleTaskModelParams or MultiTaskModelParams
      ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
          if train_cfg is a SingleTaskModelParams, we expect only one entry.
      model_task_name: An override for multi-task models, currently unused.
      logdir:  String path to the log directory to output to.
      tf_master: String path to the master job, e.g. 'local'.
      **kwargs: keyword args to pass through to BaseRunner.
    """
    super().__init__(train_cfg, model_task_name, logdir, tf_master, **kwargs)

    self._cluster_def = self._cluster.worker_cluster_def

    # There is a single Executor task
    assert self._cluster.num_replicas == 1
    data_parallelism = self._cluster.num_splits_per_client

    assert data_parallelism
    num_devices_per_split = self._cluster.num_devices_per_split
    tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                    data_parallelism, num_devices_per_split)

    self.task_scheduler = None
    self._checkpoint_dir = os.path.join(logdir, 'train')

    self._variable_renaming_rules = []

    self._ml_perf = None

    # If this is a multi-task model, grab the params for the TaskScheduler.
    if issubclass(train_cfg.cls, base_model.SingleTaskModel):
      tf.logging.info('single_task_model')
      assert len(ps_params_dict) == 1
      self._model_task_name = list(ps_params_dict.keys())[0]
      self._single_task_mode = True
    elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
      tf.logging.info('multi_task_model')

      if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
        self._variable_renaming_rules = train_cfg.variable_renaming_rules

      if train_cfg.task_schedule is None:
        task_schedule_params = task_scheduler.ConstantScheduler.Params()
        task_schedule_params.task_probs = sorted(
            list(train_cfg.task_probs.IterParams()))
      else:
        task_schedule_params = train_cfg.task_schedule
      self.task_scheduler = task_schedule_params.Instantiate()
      self._single_task_mode = False
    else:
      tf.logging.fatal(
          'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
          train_cfg.cls)

    tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

    self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                     'trainer_params.txt')
    self._program_schedule_dict = {}
    self._programs = []

    for task_string, program_schedule_params in ps_params_dict.items():
      program_schedule_params.logdir = logdir
      program_schedule_params.num_splits_per_client = data_parallelism
      program_schedule_params.task_name = task_string
      ps = program_schedule_params.Instantiate()
      self._program_schedule_dict[task_string] = ps
      tf.logging.info('program_schedule_params: %s',
                      program_schedule_params.ToText())
      self._programs += ps.Programs()
      if program_schedule_params.ml_perf.benchmark_name is not None:
        self._ml_perf = program_schedule_params.ml_perf

    tf.logging.info('num_programs: %d', len(self._programs))

    if self._ml_perf is not None:
      self._ml_perf_log = True
      mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
    else:
      self._ml_perf_log = False

    # BaseRunner legacy
    self.enqueue_ops = None

    @py_utils.RetryOnTransientTfError()
    def _WaitTillInit():
      """Wait until the model is ready."""
      try:
        with self._graph.as_default(), self._GetSession(
            cluster_def=self._cluster_def,
            disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
        ) as sess:
          topology = sess.run(
              tf.tpu.initialize_system(embedding_config=None, job=None))
          device_assignment = device_assignment_lib.device_assignment(
              topology,
              computation_shape=py_utils.ComputationShape(
                  num_devices_per_split),
              num_replicas=data_parallelism)
          py_utils.SetTpuDeviceAssignment(device_assignment)
          tf.logging.info('device_assignment.core_assignment: %s',
                          str(device_assignment.core_assignment))
          tf.logging.info(
              'device_assignment.topology.device_coordinates: %s',
              str(device_assignment.topology.device_coordinates))
      except py_utils.transient_tf_errors as e:
        tf.logging.info('TPU initialization failed: %s', e)
        raise

    if self._ml_perf_log:
      mlp_log.mlperf_print(key='init_start', value=None)
    _WaitTillInit()

    with self._graph.as_default(), tf.container(self._container_id):
      with self._cluster, tf.device(
          self._cluster.job_spec.name if not FLAGS.cluster_placer_in_executor
          else self._cluster.GetPlacer()):
        with py_utils.VariableRenameScope(self._variable_renaming_rules):
          _ = py_utils.GetOrCreateGlobalStepVar()
          for program in self._programs:
            program.BuildTpuSubgraph()
        for program in self._programs:
          program.SetStatusMessageFn(self._SetStatusMessage)
          program.CreateCheckpointer()
        self._initialize_tables = tf.tables_initializer()
        self._initialize_local_vars = tf.local_variables_initializer()

        self.save_only_checkpointer = checkpointer.Checkpointer(
            self._checkpoint_dir,
            model=None,
            train_params=train_cfg.train,
            save_only=True)
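
`_WaitTillInit` above leans on `py_utils.RetryOnTransientTfError()` to re-run TPU initialization until the workers become reachable. A minimal sketch of that style of retry decorator with exponential backoff; the parameter names and the `TransientError` placeholder are assumptions for illustration, not the real py_utils API:

import functools
import time

class TransientError(Exception):
  """Stand-in for transient TF errors such as Unavailable or Aborted."""

def RetryOnTransientError(max_retries=5, initial_delay_sec=1.0, backoff=2.0):
  """Returns a decorator that retries the wrapped function on TransientError."""

  def Decorator(fn):

    @functools.wraps(fn)
    def Wrapper(*args, **kwargs):
      delay = initial_delay_sec
      for attempt in range(max_retries):
        try:
          return fn(*args, **kwargs)
        except TransientError as e:
          if attempt == max_retries - 1:
            raise
          print('Attempt %d failed (%s); retrying in %.1fs' % (attempt, e, delay))
          time.sleep(delay)
          delay *= backoff

    return Wrapper

  return Decorator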
Example #3
    def _Loop(self):
        with self._cluster, tf.container(self._container_id), self._GetSession(
                disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
        ) as sess:
            config_proto = (self._tpu_embedding.config_proto
                            if self._tpu_embedding is not None else None)
            sess.reset(self._tf_master)
            for worker in self._cluster.all_worker_names:
                sess.run(
                    tf.tpu.initialize_system(embedding_config=config_proto,
                                             job=worker))

            # Initialize the variables first, if needed.
            compile_fns = []
            for program in self._programs:
                program.RestoreIfNeeded(sess)
                compile_fns += [program.Compile]

            # Run the compiles in parallel.
            threadpool = multiprocessing.dummy.Pool(len(compile_fns))
            futures = []
            tf.logging.info(
                f'Compiling {len(compile_fns)} programs in parallel.')
            for fn in compile_fns:
                futures += [threadpool.apply_async(fn, args=(sess, ))]
            for future in futures:
                future.wait()

            sess.run(self._initialize_tables)
            sess.run(self._initialize_local_vars)

            sess.run(self._load_ops)
            program_schedule = None
            while True:
                global_step = sess.run(py_utils.GetGlobalStep())
                if self._ShouldStop(sess, global_step):
                    tf.logging.info('Training finished.')
                    if not self._ml_perf_log:
                        self.save_only_checkpointer.Save(sess, global_step)
                    if program_schedule:
                        tf.logging.info('Shutting down programs.')
                        program_schedule.Shutdown()
                    return

                # If a task is explicitly selected, only run the programs associated
                # with that task.
                if self._single_task_mode or self._model_task_name:
                    tf.logging.info('Single task mode: %s',
                                    self._model_task_name)
                    program_schedule = self._program_schedule_dict[
                        self._model_task_name]
                else:
                    # Otherwise, sample a task.
                    model_task = self.task_scheduler.Sample(global_step)
                    tf.logging.info('Sampled %s', model_task)
                    program_schedule = self._program_schedule_dict[model_task]

                done = program_schedule.Run(sess)
                if done:
                    tf.logging.info('Program schedule told us to stop.\n'
                                    'Shutting down programs.')
                    program_schedule.Shutdown()
                    return

                # The `global_step` local variable above came from sess.run and
                # is not a tf variable, so by the time save_only_checkpointer
                # saves here, py_utils.GetGlobalStep() is already ahead of it by
                #   (train_executions_per_eval * train_steps_per_loop)
                # steps, due to program_schedule.Run(sess).
                #
                if not self._ml_perf_log:
                    tf.logging.info('Retrieve params.')
                    sess.run(self._retrieve_ops)
                    tf.logging.info('Retrieve params done.')
                    self.save_only_checkpointer.MaybeSave(
                        sess, py_utils.GetGlobalStep())
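
To make the staleness described in the comment above concrete, here is the arithmetic with made-up values for `train_executions_per_eval` and `train_steps_per_loop` (both assumed purely for illustration):

# Hypothetical values, only to illustrate the comment above.
train_executions_per_eval = 4
train_steps_per_loop = 100

global_step_at_loop_top = 2000  # fetched via sess.run at the top of the loop
steps_run_by_schedule = train_executions_per_eval * train_steps_per_loop
global_step_after_run = global_step_at_loop_top + steps_run_by_schedule

# MaybeSave therefore reads the fresh value (2400) via py_utils.GetGlobalStep()
# rather than the stale local 2000.
assert global_step_after_run == 2400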
Example #4
    def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
                 tf_master, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

    Args:
      train_cfg: SingleTaskModelParams or MultiTaskModelParams
      ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
        if train_cfg is a SingleTaskModelParams, we expect only one entry.
      model_task_name: An override for multi-task models, currently unused.
      logdir:  String path to the log directory to output to.
      tf_master: String path to the master job, e.g. 'local'.
      **kwargs: keyword args to pass through to BaseRunner.
    """
        super().__init__(train_cfg, model_task_name, logdir, tf_master,
                         **kwargs)

        data_parallelism = self._cluster.num_splits_per_client

        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(logdir, 'train')

        self._variable_renaming_rules = []

        self._ml_perf = None

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params(
                )
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'trainer_params.txt')
        if self._ml_perf is not None:
            self._ml_perf_log = True
            mlp_log.mlperf_print(key='benchmark',
                                 value=self._ml_perf.benchmark_name)
        else:
            self._ml_perf_log = False

        # BaseRunner legacy
        self.enqueue_ops = None

        train_cfg = self.params

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit(job=None):
            """Wait until the model is ready."""
            try:
                # tpu.initialize_system() is called with None as embedding_config, as
                # embedding_config is not available yet. Later in _Loop, it is called
                # with the correct embedding_config. Since it cannot be called twice in
                # the same graph with different embedding_config, we use a dummy_graph
                # here.
                dummy_graph = tf.Graph()
                with dummy_graph.as_default():
                    tpu_initialize_system_op = tf.tpu.initialize_system(
                        embedding_config=None, job=job)

                with self._GetSession(graph=dummy_graph) as sess:
                    topology = sess.run(tpu_initialize_system_op)

                if train_cfg.train.tpu_device_order_mode is None:
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=py_utils.ComputationShape(
                            num_devices_per_split, topology),
                        num_replicas=data_parallelism)
                else:
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=py_utils.ComputationShape(
                            num_devices_per_split, topology),
                        num_replicas=data_parallelism,
                        device_order_mode=train_cfg.train.tpu_device_order_mode
                    )
                py_utils.SetTpuDeviceAssignment(device_assignment, job)
                tf.logging.info('device_assignment.core_assignment: %s',
                                str(device_assignment.core_assignment))
                tf.logging.info(
                    'device_assignment.topology.device_coordinates: %s',
                    str(device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        if self._ml_perf_log:
            mlp_log.mlperf_print(key='init_start', value=None)
        if len(self._cluster.all_worker_names) > 1:
            for worker in self._cluster.all_worker_names:
                _WaitTillInit(worker)
        else:
            _WaitTillInit(None)

        shared_model = self._MaybeConstructSharedModel(train_cfg)

        self._program_schedule_dict = {}
        self._programs = []

        for task_string, program_schedule_params in ps_params_dict.items():
            program_schedule_params.logdir = logdir
            program_schedule_params.num_splits_per_client = data_parallelism
            program_schedule_params.task_name = task_string
            # If the model was created above, we'll inject it here as a shared_model.
            ps = program_schedule_params.Instantiate(shared_model=shared_model,
                                                     tf_master=self._tf_master)
            self._program_schedule_dict[task_string] = ps
            tf.logging.info('program_schedule_params: %s',
                            program_schedule_params.ToText())
            self._programs += ps.Programs()
            if program_schedule_params.ml_perf.benchmark_name is not None:
                self._ml_perf = program_schedule_params.ml_perf

        tf.logging.info('num_programs: %d', len(self._programs))

        with self._graph.as_default(), tf.container(self._container_id):
            with self._cluster, tf.device(self._cluster.GetPlacer()):
                with py_utils.VariableRenameScope(
                        self._variable_renaming_rules):
                    _ = py_utils.GetOrCreateGlobalStepVar()
                    for program in self._programs:
                        program.BuildTpuSubgraph()
                        py_utils.ClearTpuSummaryTensors()

                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()
                self._initialize_global_vars = tf.global_variables_initializer(
                )

                for program in self._programs:
                    program.SetStatusMessageFn(self._SetStatusMessage)
                    program.CreateCheckpointer(
                        init_op=self._initialize_global_vars)

                self.save_only_checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    model=None,
                    init_op=self._initialize_global_vars,
                    train_params=train_cfg.train,
                    save_only=True)

            tpu_embedding_collection = (
                tpu_embedding_layers.TpuEmbeddingCollection.Get())
            self._load_ops = tpu_embedding_collection.load_ops
            self._retrieve_ops = tpu_embedding_collection.retrieve_ops
            self._tpu_embedding = tpu_embedding_collection.tpu_embedding
            tf.io.write_graph(self._graph.as_graph_def(), self._checkpoint_dir,
                              'train.pbtxt')
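
The two `device_assignment` branches above differ only in whether `device_order_mode` is forwarded. Below is a sketch of collapsing them into a single call; it is a fragment meant to slot into the constructor above, assuming the same `device_assignment_lib` and `py_utils` names used in the snippet:

def _MakeDeviceAssignment(topology, num_devices_per_split, data_parallelism,
                          device_order_mode=None):
    """Builds the device assignment, passing device_order_mode only if set."""
    kwargs = dict(
        computation_shape=py_utils.ComputationShape(num_devices_per_split,
                                                    topology),
        num_replicas=data_parallelism)
    if device_order_mode is not None:
        kwargs['device_order_mode'] = device_order_mode
    return device_assignment_lib.device_assignment(topology, **kwargs)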
Example #5
    def _LoopEnqueue(self, op, session_override=None):
        """Runs the enqueue op in a loop."""
        p = self.params
        sess = session_override or self._GetSession()

        with tf.container(self._container_id), sess:
            if self._initialize_tables is not None:
                sess.run(self._initialize_tables)
            for task in self._model.tasks:
                task.input.Initialize(sess)
            gsteps = py_utils.GetGlobalStep()
            local_enqueue_steps = 0

            # global_enqueue_steps measures how many global steps' worth of data
            # have already been enqueued. We use this to terminate; note that the
            # enqueue op may hang in session.run if we do not terminate with this
            # check.
            global_enqueue_steps = None

            tf.logging.info(
                'params.train.max_steps: %d, enqueue_max_steps: %d',
                p.train.max_steps, p.train.enqueue_max_steps)
            while True:
                if self._dequeue_thread_complete:
                    tf.logging.info(
                        'LoopEnqueue done since consuming thread is done.')
                    return

                global_step = sess.run(gsteps)
                if global_enqueue_steps is None:
                    global_enqueue_steps = global_step
                if local_enqueue_steps % 1000 == 0:
                    tf.logging.info(
                        'Current global_enqueue_steps: %d, '
                        'local_enqueue_steps: %d, global_step: %d',
                        global_enqueue_steps, local_enqueue_steps, global_step)

                if py_utils.use_tpu():
                    global_steps_with_available_data = int(
                        global_enqueue_steps // p.train.tpu_steps_per_loop *
                        p.train.tpu_steps_per_loop)
                else:
                    global_steps_with_available_data = global_enqueue_steps

                if (self._ShouldStop(sess, global_steps_with_available_data)
                        or self._ShouldStop(sess, global_step)):
                    tf.logging.info('Done. ShouldStop is True.')
                    tf.logging.info('Enqueue loop sleeping')
                    time.sleep(15)
                    continue
                if (p.train.enqueue_max_steps > 0
                        and local_enqueue_steps >= p.train.enqueue_max_steps):
                    tf.logging.info('Done. train.enqueue_max_steps reached.')
                    tf.logging.info('Enqueue loop sleeping')
                    time.sleep(15)
                    continue
                local_enqueue_steps += 1

                # There are tpu_infeed_parallelism parallel threads enqueuing.
                # We account for all of them when updating global_enqueue_steps.
                global_enqueue_steps += p.input.tpu_infeed_parallelism

                # Input data stats generated during training are collected and logged
                # in the input generators. The merged summary op for input data stats
                # merges all the scalar summaries for the stats logged from the input
                # generators. If merged scalar summaries for input data stats are
                # available, write them to the training directory along with processing
                # the TPU infeed op.
                if self._merged_input_data_summary_op is not None:
                    summary_str, _ = sess.run(
                        [self._merged_input_data_summary_op, op])
                    self._WriteInputDataStatSummaries(summary_str,
                                                      global_enqueue_steps)
                else:
                    sess.run([op])
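
The `global_steps_with_available_data` computation above rounds the enqueue count down to a whole multiple of `tpu_steps_per_loop`, since the TPU only consumes infeed in complete loops. A small worked example with made-up values:

tpu_steps_per_loop = 100    # assumed value, for illustration only
global_enqueue_steps = 437  # steps' worth of data enqueued so far

# Only complete TPU loops can consume the enqueued data, so round down to a
# multiple of tpu_steps_per_loop.
global_steps_with_available_data = int(
    global_enqueue_steps // tpu_steps_per_loop * tpu_steps_per_loop)
assert global_steps_with_available_data == 400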
Example #6
  def _Loop(self):
    # Evaler/Controller jobs may find that the trial is infeasible and report
    # done earlier. This is an important check since the trainer may retry
    # indefinitely without it.
    if self._trial.ShouldStop():
      tf.logging.info('Training skipped (trial requested to stop).')
      return
    with tf.container(
        self._container_id), self._cluster, self._GetSession() as sess:
      # This initializes local tables
      sess.run(self._initialize_tables)
      # This initializes local variables.
      sess.run(self._initialize_local_vars)
      global_step = self._WaitUntilInit(sess, self._start_up_delay_steps)

      status_interval_steps = 100
      next_status_step = 1
      eval_metrics = None
      while True:
        if (self._trial.ShouldStopAndMaybeReport(global_step, eval_metrics) or
            self._ShouldStop(sess, global_step)):
          tf.logging.info('Training finished.')
          if self._early_stop:
            time.sleep(300)  # controller hangs if it doesn't finish first
          self._DequeueThreadComplete()
          return

        # If a task is explicitly specified, only train that task.
        if self._model_task_name:
          task = self._model.GetTask(self._model_task_name)
        else:
          # Note: This is a slightly stale global_step value from the previous
          # sess.run() call.
          # For multi-task models, `self._model.task_schedule.cur_probs` will
          # be updated.
          task = self._model.SampleTask(global_step)
          if self._task_probs_summary_writers:
            for index, prob in enumerate(self._model.task_schedule.cur_probs):
              self._SummarizeValue(global_step, 'task_probability', prob,
                                   self._task_probs_summary_writers[index])
            try:
              # Use a distinct loop variable so the sampled `task` above is not
              # clobbered before it is trained below.
              for index, subtask in enumerate(self._model.tasks):
                self._SummarizeValue(global_step, 'task_weight',
                                     sess.run(subtask.vars.task_weight),
                                     self._task_probs_summary_writers[index])
            except AttributeError:
              pass

        (_, eval_metrics, per_example_tensors) = sess.run([
            task.train_op,
            task.eval_metrics,
            task.per_example_tensors,
        ])
        # Explicitly fetch global_step after running train_op.
        # TODO(b/151181934): Investigate this behavior further.
        task_global_step = sess.run(task.global_step)
        task.ProcessFPropResults(sess, task_global_step, eval_metrics,
                                 per_example_tensors)

        global_step = sess.run(self._model.global_step)
        step_rate, example_rate, total_examples = (
            self._step_rate_tracker.ComputeStepRate(
                global_step, eval_metrics['num_samples_in_batch'][0]))
        self._SummarizeValue(global_step, 'global_step/sec', step_rate)
        self._SummarizeValue(global_step, 'examples/sec', example_rate)
        self._SummarizeValue(global_step, 'total_samples', total_examples)

        msg = 'step:%6d, steps/sec: %0.2f, examples/sec: %0.2f' % (
            global_step, step_rate, example_rate)
        for key, (val, _) in sorted(eval_metrics.items()):
          msg += ' %s:%.8g' % (key, val)
          self._SummarizeValue(global_step, key, val)
        if global_step >= next_status_step:
          self._SetStatusMessage(msg)
          self._ExportMetrics(
              # Metrics expects python int, but global_step is numpy.int64.
              global_step=int(global_step),
              step_rate=step_rate,
              example_rate=example_rate)
          next_status_step = global_step + status_interval_steps
        else:
          tf.logging.info(msg)
        self._model.ProcessFPropResults(sess, global_step, eval_metrics,
                                        per_example_tensors)
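
`self._step_rate_tracker.ComputeStepRate(...)` above boils down to the rate of change between consecutive (time, step) observations. A minimal stand-alone sketch of such a tracker; the class name and return convention are assumptions, not lingvo's actual implementation:

import time

class SimpleStepRateTracker:
  """Tracks steps/sec and examples/sec between consecutive observations."""

  def __init__(self):
    self._last_time = time.time()
    self._last_step = 0
    self._total_examples = 0

  def ComputeStepRate(self, global_step, examples_per_step):
    """Returns (steps/sec, examples/sec, total examples seen so far)."""
    now = time.time()
    elapsed = max(now - self._last_time, 1e-6)
    steps = global_step - self._last_step
    self._total_examples += steps * examples_per_step
    step_rate = steps / elapsed
    example_rate = step_rate * examples_per_step
    self._last_time, self._last_step = now, global_step
    return step_rate, example_rate, self._total_examples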
Example #7
    def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
                 tf_master, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

    Args:
      train_cfg: SingleTaskModelParams or MultiTaskModelParams
      ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
          if train_cfg is a SingleTaskModelParams, we expect only one entry.
      model_task_name: An override for multi-task models, currently unused.
      logdir:  String path to the log directory to output to.
      tf_master: String path to the master job, e.g. 'local'.
      **kwargs: keyword args to pass through to BaseRunner.
    """
        super(ExecutorTpu, self).__init__(train_cfg, model_task_name, logdir,
                                          tf_master, **kwargs)
        self._cluster_def = self._cluster.worker_cluster_def

        # There is a single Executor task
        assert self._cluster.num_replicas == 1
        data_parallelism = self._cluster.num_splits_per_client

        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')
            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params(
                )
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._program_schedule_dict = {}
        self._programs = []

        for task_string, program_schedule_params in ps_params_dict.items():
            program_schedule_params.logdir = logdir
            program_schedule_params.num_splits_per_client = data_parallelism
            program_schedule_params.task_name = task_string
            ps = program_schedule_params.Instantiate()
            self._program_schedule_dict[task_string] = ps
            tf.logging.info('program_schedule_params: %s',
                            program_schedule_params.ToText())
            self._programs += ps.Programs()

        tf.logging.info('num_programs: %d', len(self._programs))

        # BaseRunner legacy
        self.enqueue_ops = None

        def ComputationShape(split_size):
            """Decides the computation shape based on the split_size."""
            computation_shape = None
            if split_size == 1:
                computation_shape = [1, 1, 1]
            elif split_size == 2:
                computation_shape = [1, 1, 2]
            elif split_size == 4:
                computation_shape = [1, 2, 2]
            elif split_size == 8:
                computation_shape = [2, 2, 2]
            elif split_size == 16:
                computation_shape = [4, 2, 2]
            else:
                assert False, (
                    'Model parallelism with %d devices is currently not'
                    ' supported.' % split_size)
            assert computation_shape is not None
            return computation_shape

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit():
            """Wait until the model is ready."""
            try:
                with self._GetSession(cluster_def=self._cluster_def) as sess:
                    topology = sess.run(
                        tf.tpu.initialize_system(embedding_config=None,
                                                 job=None))
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=ComputationShape(
                            num_devices_per_split),
                        num_replicas=data_parallelism)
                    py_utils.SetTpuDeviceAssignment(device_assignment)
                    tf.logging.info('device_assignment.core_assignment: %s',
                                    str(device_assignment.core_assignment))
                    tf.logging.info(
                        'device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        _WaitTillInit()

        with self._graph.as_default(), tf.container(self._container_id):
            with self._cluster, tf.device(self._cluster.job_spec.name):
                for program in self._programs:
                    program.BuildTpuSubgraph()
                self.initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()

                self._checkpoint_dir = os.path.join(logdir, 'train')
                self.save_only_checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    model=None,
                    train_params=train_cfg.train,
                    save_only=True)
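
The `ComputationShape` helper above is a fixed mapping from split size to a 3-D TPU mesh shape. The same mapping written as a dict lookup (a sketch producing identical shapes; it raises ValueError instead of asserting):

# Same split_size -> shape mapping as the if/elif chain in ComputationShape.
_SPLIT_SIZE_TO_SHAPE = {
    1: [1, 1, 1],
    2: [1, 1, 2],
    4: [1, 2, 2],
    8: [2, 2, 2],
    16: [4, 2, 2],
}

def ComputationShape(split_size):
    """Decides the computation shape based on the split_size."""
    if split_size not in _SPLIT_SIZE_TO_SHAPE:
        raise ValueError('Model parallelism with %d devices is currently not'
                         ' supported.' % split_size)
    return _SPLIT_SIZE_TO_SHAPE[split_size]

assert ComputationShape(8) == [2, 2, 2]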
Example #8
    def _LoopEnqueue(self, op, session_override=None):
        """Runs the enqueue op in a loop."""
        p = self.params
        sess = session_override or self._GetSession()

        with tf.container(self._container_id), sess:
            if self._initialize_tables is not None:
                sess.run(self._initialize_tables)
            for task in self._model.tasks:
                task.input.Initialize(sess)
            gsteps = py_utils.GetGlobalStep()
            local_enqueue_steps = 0

            # global_enqueue_steps measures how many global steps' worth of data
            # have already been enqueued. We use this to terminate; note that the
            # enqueue op may hang in session.run if we do not terminate with this
            # check.
            global_enqueue_steps = None

            tf.logging.info(
                'params.train.max_steps: %d, enqueue_max_steps: %d',
                p.train.max_steps, p.train.enqueue_max_steps)
            while True:
                if self._dequeue_thread_complete:
                    tf.logging.info(
                        'LoopEnqueue done since consuming thread is done.')
                    return

                global_step = sess.run(gsteps)
                if global_enqueue_steps is None:
                    global_enqueue_steps = global_step
                if local_enqueue_steps % 1000 == 0:
                    tf.logging.info(
                        'Current global_enqueue_steps: %d, '
                        'local_enqueue_steps: %d, global_step: %d',
                        global_enqueue_steps, local_enqueue_steps, global_step)

                if py_utils.use_tpu():
                    global_steps_with_available_data = int(
                        global_enqueue_steps // p.train.tpu_steps_per_loop *
                        p.train.tpu_steps_per_loop)
                else:
                    global_steps_with_available_data = global_enqueue_steps

                if (self._ShouldStop(sess, global_steps_with_available_data)
                        or self._ShouldStop(sess, global_step)):
                    tf.logging.info('Done. ShouldStop is True.')
                    tf.logging.info('Enqueue loop sleeping')
                    time.sleep(15)
                    continue
                if (p.train.enqueue_max_steps > 0
                        and local_enqueue_steps >= p.train.enqueue_max_steps):
                    tf.logging.info('Done. train.enqueue_max_steps reached.')
                    tf.logging.info('Enqueue loop sleeping')
                    time.sleep(15)
                    continue
                local_enqueue_steps += 1

                # There are tpu_infeed_parallelism parallel threads enqueuing.
                # We account for all of them when updating global_enqueue_steps.
                global_enqueue_steps += p.input.tpu_infeed_parallelism

                sess.run([op])
Example #9
    def __init__(self, task_dict, program_schedule_params, model_task_name,
                 logdir, tf_master, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

    Args:
      task_dict: A dict of dataset_name -> task params.
      program_schedule_params: A ProgramSchedule params.
      model_task_name: An override for multi-task models, currently unused.
      logdir:  String path to the log directory to output to.
      tf_master: String path to the master job, e.g. 'local'.
      **kwargs: keyword args to pass through to BaseRunner.
    """
        # TODO(blee): fix this.
        train_params = task_dict['Train']
        super(ExecutorTpu, self).__init__(train_params, model_task_name,
                                          logdir, tf_master, **kwargs)
        self._cluster_def = self._cluster.worker_cluster_def

        # There is a single Executor task
        assert self._cluster.num_replicas == 1
        data_parallelism = self._cluster.num_splits_per_client

        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        # Update run-time params
        program_schedule_params.task_dict = task_dict
        program_schedule_params.logdir = logdir
        program_schedule_params.num_splits_per_client = data_parallelism

        self._programs = []
        self._program_schedule = program_schedule_params.Instantiate()

        tf.logging.info('program_schedule_params: %s',
                        program_schedule_params.ToText())

        self._programs += self._program_schedule.Programs()

        # BaseRunner legacy
        self.enqueue_ops = None

        def ComputationShape(split_size):
            """Decides the computation shape based on the split_size."""
            computation_shape = None
            if split_size == 1:
                computation_shape = [1, 1, 1]
            elif split_size == 2:
                computation_shape = [1, 1, 2]
            elif split_size == 4:
                computation_shape = [1, 2, 2]
            elif split_size == 8:
                computation_shape = [2, 2, 2]
            elif split_size == 16:
                computation_shape = [4, 2, 2]
            else:
                assert False, (
                    'Model parallelism with %d devices is currently not'
                    ' supported.' % split_size)
            assert computation_shape is not None
            return computation_shape

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit():
            """Wait until the model is ready."""
            try:
                with self._GetSession(cluster_def=self._cluster_def) as sess:
                    topology = sess.run(
                        tf.tpu.initialize_system(embedding_config=None,
                                                 job=None))
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=ComputationShape(
                            num_devices_per_split),
                        num_replicas=data_parallelism)
                    py_utils.SetTpuDeviceAssignment(device_assignment)
                    tf.logging.info('device_assignment.core_assignment: %s',
                                    str(device_assignment.core_assignment))
                    tf.logging.info(
                        'device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        _WaitTillInit()

        with self._graph.as_default(), tf.container(self._container_id):
            with self._cluster, tf.device(self._cluster.job_spec.name):
                for program in self._programs:
                    program.BuildTpuSubgraph()
                self.initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()