Example #1
    def testDeviceListMultiReplicaNoSyncSgd(self):
        p = cluster_factory.Cluster.Params()
        p.mode = 'async'
        p.job = 'trainer'
        p.task = 1
        p.worker.replicas = 2
        p.worker.gpus_per_replica = 2
        c = cluster_factory.Cluster(p)
        gpu_devices = c.available_devices
        expected_gpu_devices = [[
            cluster.MakeDeviceString(job_name='/job:localhost',
                                     task_id=1,
                                     device_name='GPU',
                                     device_id=0),
            cluster.MakeDeviceString(job_name='/job:localhost',
                                     task_id=1,
                                     device_name='GPU',
                                     device_id=1),
        ]]
        self.assertAllEqual(gpu_devices, expected_gpu_devices)

        # Compute the total number of worker devices for a multi
        # replica setup.
        self.assertEqual(4, c.total_worker_devices)

        # Even when the job is different, we still look at the worker
        # information.
        p.job = 'controller'
        p.task = 0
        c = cluster_factory.Cluster(p)
        self.assertEqual(4, c.total_worker_devices)
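
For reference, the device strings asserted throughout these examples follow TensorFlow's '/job:<name>/replica:<r>/task:<t>/device:<type>:<id>' convention. Below is a minimal sketch of that format; it is an assumption inferred from these tests, since the exact MakeDeviceString signature varies across lingvo versions (some take an explicit replica_id):

def make_device_string(job_name, task_id, device_name, device_id, replica_id=0):
    # Assumed format, mirroring TensorFlow's standard device naming.
    return '%s/replica:%d/task:%d/device:%s:%d' % (
        job_name, replica_id, task_id, device_name, device_id)

# make_device_string('/job:localhost', 1, 'GPU', 0)
# -> '/job:localhost/replica:0/task:1/device:GPU:0'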
Example #2
    def __init__(self, sess, batch_size):
        self.sess = sess
        self.batch_size = batch_size

        tf.set_random_seed(1234)
        params = model_registry.GetParams('asr.librispeech.Librispeech960Wpm',
                                          'Test')
        params.random_seed = 1234
        params.is_eval = True
        params.cluster.worker.gpus_per_replica = 1
        cluster = cluster_factory.Cluster(params.cluster)
        with cluster, tf.device(cluster.GetPlacer()):
            model = params.cls(params)

            # placeholders
            self.input_tf = tf.placeholder(tf.float32,
                                           shape=[batch_size, None],
                                           name='qq_input')
            self.tgt_tf = tf.placeholder(tf.string)
            self.sample_rate_tf = tf.placeholder(tf.int32,
                                                 name='qq_sample_rate')
            self.maxlen = tf.placeholder(np.int32)

            # generate the inputs that are needed for the lingvo model
            self.features = create_features(self.input_tf, self.sample_rate_tf)
            self.inputs = create_inputs(model, self.features, self.tgt_tf,
                                        self.batch_size)

            task = model.GetTask()
            metrics = task.FPropDefaultTheta(self.inputs)
            self.decoded = task.Decode(self.inputs)
Example #3
 def testDefaultParams(self):
   p = cluster_factory.Cluster.Params()
   c = cluster_factory.Cluster(p)
   self.assertFalse(c.add_summary)
   g = tf.Graph()
   vs = []
   with g.as_default():
     with tf.device(c.GetPlacer()):
       for i in range(10):
         vs.append(tf.get_variable('x%d' % i, (10, 10, 10)))
       sum_all = tf.add_n(vs)
   for v in vs:
     self.assertEqual(
         v.device,
         c._MakeDeviceString(
             job_name='/job:localhost',
             task_id=0,
             device_name='CPU',
             device_id=0))
   self.assertEqual(
       sum_all.device,
       c._MakeDeviceString(
           job_name='/job:localhost',
           task_id=0,
           device_name='CPU',
           device_id=0))
Example #4
 def testNoPS(self):
     p = cluster_factory.Cluster.Params()
     p.worker.name = '/job:trainer'
     p.worker.replicas = 1
     p.ps.name = '/job:trainer'
     p.ps.replicas = 1
     c = cluster_factory.Cluster(p)
     g = tf.Graph()
     vs = []
     with g.as_default():
         with tf.device(c.GetPlacer()):
             for i in range(10):
                 vs.append(tf.get_variable('x%d' % i, (10, 10, 10)))
             sum_all = tf.add_n(vs)
     for v in vs:
         self.assertEqual(
             v.device,
             cluster.MakeDeviceString(job_name='/job:trainer',
                                      task_id=0,
                                      device_name='CPU',
                                      device_id=0))
     self.assertEqual(
         sum_all.device,
         cluster.MakeDeviceString(job_name='/job:trainer',
                                  task_id=0,
                                  device_name='CPU',
                                  device_id=0))
Example #5
    def __init__(self,
                 params,
                 model_task_name,
                 logdir,
                 tf_master,
                 trial=base_trial.NoOpTrial()):
        """Construct a new BaseRunner.

        Args:
          params: Params object containing model configuration.
          model_task_name: String name of the task this runner should execute
            for multitask models only. See flag for details.
          logdir: String path to the log directory to output to.
          tf_master: String path to the master job, e.g. 'local'.
          trial: An optional hyperparameter trial. Used by Vizier studies.
        """
        p = params.Copy()
        # Set in subclasses.
        self._job_name = ''

        self._params = trial.OverrideModelParams(p)
        tf.logging.info('=' * 60)
        for line in self.params.ToText().split('\n'):
            tf.logging.info('%s', line)
        tf.logging.info('=' * 60)

        self._logdir = logdir
        self._tf_master = tf_master
        self._model_task_name = model_task_name
        self._trial = trial
        # If the runner is conducting a Vizier trial, scope all the variables
        # (e.g., global_step) by the trial id so that we do not share states across
        # trials.
        self._container_id = self._trial.Name()
        self._should_report_metrics = False

        # To terminate a runner early, we set max_steps here, which triggers
        # the appropriate ShouldStop behavior in the threads. Vizier uses this
        # to early-stop a trial.
        self._max_steps = None

        self.params.cluster.logdir = logdir
        self._cluster = cluster_factory.Cluster(self.params.cluster)
        self._train_dir = os.path.join(self._logdir, 'train')
        tf.io.gfile.makedirs(self._train_dir)
        self._graph = tf.Graph()
        self._summary_writer = None
        self._initialize_tables = None
        self._dequeue_thread_complete = False

        early_stop.MetricHistory.SetLogdirInMetricHistories(p, logdir)
        self._early_stop = None
        if p.train.early_stop and p.train.early_stop.window:
            self._early_stop = early_stop.EarlyStop(p.train.early_stop)
            with self._graph.as_default():
                self._early_stop.FProp(None)

        self._init_input_ops = []

        self._SetStatusMessage('Starting ...')
Example #6
 def testDeviceListMultiReplicaSyncSgd(self):
     p = cluster_factory.Cluster.Params()
     p.mode = 'sync'
     p.job = 'trainer_client'
     p.worker.name = '/job:localhost'
     p.worker.replicas = 2
     p.worker.gpus_per_replica = 2
     c = cluster_factory.Cluster(p)
     gpu_devices = c.available_devices
     expected_gpu_devices = [
         [
             cluster.MakeDeviceString(job_name='/job:localhost',
                                      task_id=0,
                                      device_name='GPU',
                                      device_id=0),
             cluster.MakeDeviceString(job_name='/job:localhost',
                                      task_id=0,
                                      device_name='GPU',
                                      device_id=1),
         ],
         [
             cluster.MakeDeviceString(job_name='/job:localhost',
                                      task_id=1,
                                      device_name='GPU',
                                      device_id=0),
             cluster.MakeDeviceString(job_name='/job:localhost',
                                      task_id=1,
                                      device_name='GPU',
                                      device_id=1),
         ]
     ]
     self.assertAllEqual(gpu_devices, expected_gpu_devices)
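
Contrast this with Example #1: in 'sync' mode the trainer_client enumerates the devices of every worker replica (a 2x2 grid here), while in 'async' mode each trainer sees only its own replica's row. A toy sketch of that enumeration, offered as a simplifying assumption rather than lingvo's actual implementation:

def available_device_ids(mode, task_id, replicas, gpus_per_replica):
    # Sync jobs see all replicas' devices; async jobs see only their own task's.
    tasks = range(replicas) if mode == 'sync' else [task_id]
    return [[(t, gpu) for gpu in range(gpus_per_replica)] for t in tasks]

# available_device_ids('sync', 0, 2, 2)  -> [[(0, 0), (0, 1)], [(1, 0), (1, 1)]]
# available_device_ids('async', 1, 2, 2) -> [[(1, 0), (1, 1)]]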
Example #7
    def GetParamsForDataset(self, job_name, dataset_name):
        """Returns params for job `job_name` on the dataset `dataset_name`."""
        # Get the current cluster and update its params from flags.
        cluster = cluster_factory.Current()
        self.UpdateClusterParamsFromFlags(cluster.params, job_name)
        with cluster_factory.Cluster(cluster.params):
            try:
                cfg = self.model_registry.GetParams(self._model_name,
                                                    dataset_name)
            except base_model_params.DatasetError as e:
                dataset_name_retry = dataset_name.title()
                tf.logging.warning(
                    'Exception configuring dataset %s, retrying as %s: %s',
                    dataset_name, dataset_name_retry, e)
                cfg = self.model_registry.GetParams(self._model_name,
                                                    dataset_name_retry)
                tf.logging.warning('Succeeded after retrying as %s.' %
                                   dataset_name_retry)
        cfg.cluster = cluster.params

        # Updates a few params based on flags.
        if FLAGS.enqueue_max_steps is not None:
            cfg.train.enqueue_max_steps = FLAGS.enqueue_max_steps
        if FLAGS.saver_max_to_keep is not None:
            cfg.train.save_max_to_keep = FLAGS.saver_max_to_keep
        if FLAGS.saver_keep_checkpoint_every_n_hours is not None:
            cfg.train.save_keep_checkpoint_every_n_hours = FLAGS.saver_keep_checkpoint_every_n_hours
        return cfg
Example #8
 def testPSWithGPUs(self):
     p = cluster_factory.Cluster.Params()
     p.worker.name = '/job:trainer'
     p.worker.replicas = 1
     p.ps.name = '/job:ps'
     p.ps.replicas = 4
     p.ps.gpus_per_replica = 2
     c = cluster_factory.Cluster(p)
     g = tf.Graph()
     vs = []
     with g.as_default():
         with tf.device(c.GetPlacer()):
             for i in range(10):
                 vs.append(tf.get_variable('x%d' % i, (10, 10, 10)))
             sum_all = tf.add_n(vs)
     for i, v in enumerate(vs):
         self.assertEqual(
             v.device,
             cluster.MakeDeviceString(job_name='/job:ps',
                                       task_id=(i // 2) % 4,
                                      device_name='GPU',
                                      device_id=i % 2))
     self.assertEqual(
         sum_all.device,
         cluster.MakeDeviceString(job_name='/job:trainer',
                                  task_id=0,
                                  device_name='CPU',
                                  device_id=0))
Example #9
 def testDefaultParamsWithDynamicShape(self):
     p = cluster_factory.Cluster.Params()
     c = cluster_factory.Cluster(p)
     g = tf.Graph()
     vs = []
     with g.as_default():
         with tf.device(c.GetPlacer()):
             for i in range(10):
                 dyn_shape = tf.constant([2], dtype=tf.int32)
                 dyn_shape = tf.placeholder_with_default(dyn_shape,
                                                         shape=[None])
                 v = tf.get_variable('x%d_wb/var' % i,
                                     initializer=tf.random.uniform(
                                         dyn_shape, dtype=tf.float64),
                                     validate_shape=False)
                 vs.append(v)
             sum_all = tf.add_n(vs)
     for v in vs:
         self.assertEqual(
             v.device,
             cluster.MakeDeviceString(job_name='/job:localhost',
                                      task_id=0,
                                      device_name='CPU',
                                      device_id=0))
     self.assertEqual(
         sum_all.device,
         cluster.MakeDeviceString(job_name='/job:localhost',
                                  task_id=0,
                                  device_name='CPU',
                                  device_id=0))
Example #10
 def testInputTargets(self):
   p = cluster_factory.Cluster.Params()
   p.input.name = '/job:input'
   p.input.replicas = 2
   p.input.targets = '10.100.1.1:10001,10.100.1.2:10002'
   c = cluster_factory.Cluster(p)
   self.assertEqual(c.input_targets, ['10.100.1.1:10001', '10.100.1.2:10002'])
Example #11
def GetExecutorParams(model_name, cluster_params, model_registry):
  """Get the params needed to instantiate the Executor.

  Args:
    model_name: A model name registered in the ModelRegistry.
    cluster_params: A cluster hyperparams object.
    model_registry: A ModelRegistry object.

  Returns:
    A tuple (dict, Params):

    - ps_params_dict: High-level task name -> ProgramScheduleParams
    - train_cfg: A SingleTaskModelParams or MultiTaskModelParams.
  """

  ps_params_dict = {}
  with cluster_factory.Cluster(cluster_params):
    ps_cfg = model_registry.GetProgramSchedule(model_name)
    train_cfg = model_registry.GetParams(model_name, 'Train')
    train_cfg.cluster = cluster_params

    if issubclass(train_cfg.cls, base_model.MultiTaskModel):
      multi_task_train_cfg = train_cfg
      # Create SingleTaskModelParams from a MultiTaskModelParams.
      for k, _ in multi_task_train_cfg.task_params.IterParams():
        single_task_params = base_model.SingleTaskModel.Params()
        single_task_params.cluster = multi_task_train_cfg.cluster
        single_task_params.input = multi_task_train_cfg.input.Get(k)
        single_task_params.task = multi_task_train_cfg.task_params.Get(k)
        single_task_params.train = single_task_params.task.train
        if k not in ps_cfg.program_schedule_dict:
          tf.logging.fatal(
              'Could not find %s in ps_cfg.program_schedule_dict: %s', k,
              ps_cfg)
        program_schedule_params = ps_cfg.program_schedule_dict[k]

        program_schedule_params.task_dict = {'Train': single_task_params}

        for eval_dataset_name in program_schedule_params.dataset_names:
          multi_task_eval_cfg = model_registry.GetParams(
              model_name, eval_dataset_name)
          eval_task_params = base_model.SingleTaskModel.Params()
          eval_task_params.cluster = single_task_params.cluster
          eval_task_params.input = multi_task_eval_cfg.input.Get(k)
          eval_task_params.task = multi_task_eval_cfg.task_params.Get(k)
          program_schedule_params.task_dict[
              eval_dataset_name] = eval_task_params
        ps_params_dict[k] = program_schedule_params
    else:
      program_schedule_params = ps_cfg
      program_schedule_params.task_dict = {'Train': train_cfg}
      for eval_dataset_name in program_schedule_params.dataset_names:
        task_eval_params = model_registry.GetParams(model_name,
                                                    eval_dataset_name)
        task_eval_params.cluster = train_cfg.cluster
        program_schedule_params.task_dict[eval_dataset_name] = task_eval_params
      ps_params_dict[''] = program_schedule_params

  return ps_params_dict, train_cfg
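
A hypothetical call site for GetExecutorParams; the model name below is illustrative only, not a model registered in the source:

cluster_params = cluster_factory.Cluster.Params()
ps_params_dict, train_cfg = GetExecutorParams(
    'my_project.MyModel', cluster_params, model_registry)
for task_name, ps_params in ps_params_dict.items():
    print(task_name, list(ps_params.task_dict.keys()))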
Example #12
 def InspectModel(self):
   """Prints out model analysis for the model."""
   p = self.GetParamsForDataset('controller', 'Train')
   p.cluster.mode = 'sync'
   c = cluster_factory.Cluster(p.cluster)
   with tf.Graph().as_default(), c, tf.device(c.GetPlacer()):
     analysis, _ = _ModelAnalysis(p.cls(p))
   print(analysis)
Example #13
 def testInputDevice(self):
   p = cluster_factory.Cluster.Params()
   p.mode = 'sync'
   p.job = 'trainer_client'
   p.input.name = '/job:input'
   p.input.replicas = 1
   c = cluster_factory.Cluster(p)
   input_device = c.input_device
   expected_device = c._MakeDeviceString(
       job_name='/job:input', task_id=0, device_name='CPU', device_id=0)
   self.assertEqual(input_device, expected_device)
Example #14
  def __init__(self, decoder_type, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.params.cluster.do_eval = True
    self._cluster = cluster_factory.Cluster(self.params.cluster)

    self._decoder_dir = os.path.join(self._logdir, f'decoder_{decoder_type}')
    if self._model_task_name:
      self._decoder_dir += '_' + self._model_task_name
    tf.io.gfile.makedirs(self._decoder_dir)
    self._summary_writer = tf.compat.v2.summary.create_file_writer(
        self._decoder_dir)
Example #15
    def __init__(self, decoder_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._job_name = 'decoder_' + decoder_type
        self.params.cluster.do_eval = True
        self._cluster = cluster_factory.Cluster(self.params.cluster)
        self._decoder_dir = GetDecoderDir(self._logdir, self._job_name,
                                          self._model_task_name)
        tf.io.gfile.makedirs(self._decoder_dir)

        self._decode_path = None
        # Multitask params doesn't have 'task'.
        if 'task' in self.params:
            self._decode_path = checkpointer.GetSpecificCheckpoint(
                self.params.task.eval.load_checkpoint_from)

        self._should_report_metrics = self._job_name.startswith(
            self._cluster.reporting_job)

        with self._graph.as_default(), tf.container(self._container_id):
            self._summary_writer = self._CreateSummaryWriter(self._decoder_dir)
            self._CreateTF2SummaryWriter(self._decoder_dir)
            with self._cluster, tf.device(
                    self._cluster.GetPlacer()), self._TF2SummaryContext():
                self._model = self.params.Instantiate()
                self._params = self._model.params
                self._task = self._model.GetTask(self._model_task_name)
                # Note, different graphs are being constructed for different model
                # tasks, which may result in different node names being chosen.
                # Obviously, variable names have to stay the same between train and
                # decode.
                cluster = self._cluster
                with tf.device(cluster.input_device):
                    input_batch = (
                        self._task.input_generator.GetPreprocessedInputBatch())

                self._dec_output = self._task.Decode(input_batch)
                self._summary_op = tf.summary.merge_all()
                self.checkpointer = self._CreateCheckpointer(
                    self._train_dir, self._model)
            self._CreateTF2SummaryOps()
            self._initialize_tables = tf.tables_initializer()
            self._initialize_local_vars = tf.local_variables_initializer()
            # No queues are allowed for decoder models.
            self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
            assert not self.enqueue_ops

        # Saves the graph def.
        self._WriteToLog(self.params.ToText(), self._decoder_dir, 'params.txt')
        if self.params.cluster.task == 0:
            tf.io.write_graph(self._graph.as_graph_def(), self._decoder_dir,
                              '%s.pbtxt' % self._job_name)
Example #16
 def testWorkerDeviceInModelSplitSync(self):
   p = cluster_factory.Cluster.Params()
   p.mode = 'sync'
   p.job = 'trainer_client'
   p.worker.name = '/job:trainer'
   p.worker.replicas = 4
   p.worker.gpus_per_replica = 4
   p.worker.devices_per_split = 2
   c = cluster_factory.Cluster(p)
   with py_utils.ModelSplit(1):
     d = c.WorkerDeviceInModelSplit(1)
   expected_device = c._MakeDeviceString(
       job_name='/job:trainer', task_id=0, device_name='GPU', device_id=3)
   self.assertEqual(expected_device, d)
Example #17
    def __init__(self, model_name, split, run_preprocessors):
        self._model_name = model_name
        self._split = split
        self._run_preprocessors = run_preprocessors
        self._sess = None

        # Create a cluster configuration assuming evaluation; the input pipelines
        # need to know the cluster job type to set up the outputs correctly.
        cluster = cluster_factory.Current()
        cluster.params.job = 'evaler'
        cluster.params.mode = 'sync'
        cluster.params.task = 0
        cluster.params.evaler.replicas = 1
        self._cluster = cluster_factory.Cluster(cluster.params)
Example #18
 def testWorkerDeviceInModelSplit(self):
   p = cluster_factory.Cluster.Params()
   p.mode = 'async'
   p.job = 'trainer'
   p.task = 3
   p.worker.name = '/job:trainer'
   p.worker.replicas = 4
   p.worker.gpus_per_replica = 4
   p.worker.devices_per_split = 2
   with cluster_factory.Cluster(p):
     with cluster_factory.SetModelSplit(1) as c:
       d = c.WorkerDeviceInModelSplit(1)
       expected_device = c._MakeDeviceString(
           job_name='/job:trainer', task_id=3, device_name='GPU', device_id=3)
   self.assertEqual(expected_device, d)
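
The expected device id above follows from simple split arithmetic: with devices_per_split = 2, model split 1 owns GPUs 2 and 3, so index 1 within that split maps to GPU 3. A sketch of the assumed mapping, inferred from Examples #16 and #18 rather than taken from lingvo's source:

def worker_gpu_in_model_split(split_id, index, devices_per_split):
    # Splits own contiguous device ranges: with devices_per_split == 2,
    # split 0 -> GPUs {0, 1}, split 1 -> GPUs {2, 3}, and so on.
    return split_id * devices_per_split + index

assert worker_gpu_in_model_split(1, 1, 2) == 3  # matches device_id=3 above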
Example #19
    def __init__(self,
                 params,
                 model_task_name,
                 logdir,
                 tf_master,
                 trial=base_trial.NoOpTrial()):
        """Construct a new BaseRunner.

        Args:
          params: Params object containing model configuration.
          model_task_name: String name of the task this runner should execute
            for multitask models only. See flag for details.
          logdir: String path to the log directory to output to.
          tf_master: String path to the master job, e.g. 'local'.
          trial: An optional hyperparameter trial. Used by Vizier studies.
        """
        p = params.Copy()
        p.add_summary = FLAGS.add_summary

        self._params = trial.OverrideModelParams(p)
        tf.logging.info('=' * 60)
        for line in self.params.ToText().split('\n'):
            tf.logging.info('%s', line)
        tf.logging.info('=' * 60)

        self._logdir = logdir
        self._tf_master = tf_master
        self._model_task_name = model_task_name
        self._trial = trial
        # If the runner is conducting a Vizier trial, scope all the variables
        # (e.g., global_step) by the trial id so that we do not share states across
        # trials.
        self._container_id = self._trial.Name()

        self._cluster = cluster_factory.Cluster(self.params.cluster)
        self._train_dir = os.path.join(self._logdir, 'train')
        self._graph = tf.Graph()
        self._summary_writer = None
        self.initialize_tables = None

        early_stop.MetricHistory.SetLogdirInMetricHistories(p, logdir)
        self._early_stop = None
        if p.train.early_stop and p.train.early_stop.window:
            self._early_stop = early_stop.EarlyStop(p.train.early_stop)
            with self._graph.as_default():
                self._early_stop.FProp(None)

        self._SetStatusMessage('Starting ...')
Example #20
 def testDeviceListOneReplicaGpu(self):
     p = cluster_factory.Cluster.Params()
     p.mode = 'async'
     p.job = 'trainer'
     p.worker.gpus_per_replica = 2
     c = cluster_factory.Cluster(p)
     gpu_devices = c.available_devices
     expected_gpu_devices = [[
         c._MakeDeviceString(job_name='/job:localhost',
                             task_id=0,
                             device_name='GPU',
                             device_id=0),
         c._MakeDeviceString(job_name='/job:localhost',
                             task_id=0,
                             device_name='GPU',
                             device_id=1),
     ]]
     self.assertAllEqual(gpu_devices, expected_gpu_devices)
Example #21
 def testPSRandomSize(self):
   p = cluster_factory.Cluster.Params()
   p.worker.name = '/job:trainer'
   p.ps.name = '/job:ps'
   p.ps.replicas = 10
   c = cluster_factory.Cluster(p)
   g = tf.Graph()
   vs = []
   np.random.seed(301)
   with g.as_default():
     with tf.device(c.GetPlacer()):
       # Creates 200 variables with different sizes.
       for i in range(200):
         if i % 13:
           size = np.random.randint(10000)
         elif i % 7:
           size = np.random.randint(100)
         else:
           size = np.random.randint(10)
          vs.append(tf.get_variable('x%d' % i, shape=(size,)))
       sum_all = tf.add_n([tf.reduce_sum(x) for x in vs])
   # Computes the total size of variables placed on each device.
   total_size = {}  # device name -> size
   for v in vs:
     size = tf.TensorShape(v.op.get_attr('shape')).num_elements()
     if v.device in total_size:
       total_size[v.device] += size
     else:
       total_size[v.device] = size
   for (device, allocated) in zip(
       sorted(total_size),
       [91701, 91361, 90346, 88738, 87240, 89265, 91944, 92472, 88051, 95053]):
     self.assertEqual(total_size[device], allocated)
   self.assertEqual(
       sum_all.device,
       cluster.MakeDeviceString(
           job_name='/job:trainer',
           replica_id=0,
           task_id=0,
           device_name='CPU',
           device_id=0))
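
The per-device totals asserted above come out roughly equal (around 90k elements each), consistent with a placer that greedily assigns each variable to the parameter server holding the least data so far. A minimal sketch of such a greedy strategy, offered as an assumption about the balancing policy rather than lingvo's actual placer:

def place_variables(sizes, num_ps):
    # Greedy least-loaded placement: each variable goes to the PS replica
    # with the smallest total size so far.
    loads = [0] * num_ps
    assignment = []
    for size in sizes:
        ps = min(range(num_ps), key=lambda i: loads[i])
        loads[ps] += size
        assignment.append(ps)
    return assignment, loads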
Example #22
 def testDeviceListOneReplicaCpu(self):
   p = cluster_factory.Cluster.Params()
   p.mode = 'async'
   p.job = 'trainer'
   p.worker.cpus_per_replica = 2
   c = cluster_factory.Cluster(p)
   cpu_devices = c.available_devices
   expected_cpu_devices = [[
       cluster.MakeDeviceString(
           job_name='/job:localhost',
           replica_id=0,
           task_id=0,
           device_name='CPU',
           device_id=0),
       cluster.MakeDeviceString(
           job_name='/job:localhost',
           replica_id=0,
           task_id=0,
           device_name='CPU',
           device_id=1),
   ]]
   print(expected_cpu_devices)
   self.assertAllEqual(cpu_devices, expected_cpu_devices)
Example #23
    def InspectModel(self):
        """Prints out model analysis for the model."""
        FLAGS.mode = 'sync'
        p = self.GetParamsForDataset('controller', 'Train')
        c = cluster_factory.Cluster(p.cluster)
        model_part_regex = FLAGS.inspect_model_part_regex
        part_pattern = None
        if model_part_regex:
            part_pattern = {}
            for pat_str in model_part_regex:
                first_colon = pat_str.find(':')
                if first_colon < 0:
                    msg = f'Cannot understand --inspect_model_part_regex={pat_str}.'
                    raise ValueError(msg)
                name = pat_str[:first_colon]
                pattern = pat_str[first_colon + 1:]
                part_pattern[name] = pattern

        with tf.Graph().as_default(), c, tf.device(c.GetPlacer()):
            analysis, _ = summary_utils.ModelAnalysis(
                p.Instantiate(),
                topn=FLAGS.inspect_model_topn,
                part_pattern=part_pattern)
        print(analysis)
Example #24
def main(argv):
    data = np.loadtxt(FLAGS.input, dtype=str, delimiter=",")
    # calculate the number of loops to run the test
    num = len(data[0])
    batch_size = FLAGS.batch_size
    num_loops = num // batch_size
    assert num % batch_size == 0

    with tf.device("/gpu:0"):
        tf.set_random_seed(1234)
        tfconf = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=tfconf) as sess:
            params = model_registry.GetParams(
                "asr.librispeech.Librispeech960Wpm", "Test")
            params.cluster.worker.gpus_per_replica = 1
            cluster = cluster_factory.Cluster(params.cluster)
            with cluster, tf.device(cluster.GetPlacer()):
                params.vn.global_vn = False
                params.random_seed = 1234
                params.is_eval = True
                model = params.cls(params)
                task = model.GetTask()
                saver = tf.train.Saver()
                saver.restore(sess, FLAGS.checkpoint)

                # define the placeholders
                input_tf = tf.placeholder(tf.float32, shape=[batch_size, None])
                tgt_tf = tf.placeholder(tf.string)
                sample_rate_tf = tf.placeholder(tf.int32)
                mask_tf = tf.placeholder(tf.float32,
                                         shape=[batch_size, None, 80])
                rir_tf = tf.placeholder(tf.float32)
                lengths = tf.placeholder(
                    np.int32,
                    shape=[
                        batch_size,
                    ],
                )
                maxlen = tf.placeholder(np.int32)
                mask = tf.placeholder(dtype=np.float32,
                                      shape=[batch_size, None])

                # generate the features and inputs
                new_input = (create_speech_rir(input_tf, rir_tf, lengths,
                                               maxlen, batch_size) * mask)
                features = create_features(new_input, sample_rate_tf, mask_tf)
                shape = tf.shape(features)
                inputs = create_inputs(model, features, tgt_tf, batch_size,
                                       mask_tf)

                # loss
                metrics = task.FPropDefaultTheta(inputs)
                loss = tf.get_collection("per_loss")[0]

                # prediction
                decoded_outputs = task.Decode(inputs)
                dec_metrics_dict = task.CreateDecoderMetrics()

                success_rates = []
                for num_room in range(FLAGS.num_test_rooms):
                    correct = 0
                    rir = Readrir(num_room)

                    for l in range(num_loops):
                        data_sub = data[:, l * batch_size:(l + 1) * batch_size]
                        (
                            audios_np,
                            sample_rate,
                            tgt_np,
                            mask_freq,
                            lengths_np,
                            max_len,
                            masks,
                        ) = Read_input(data_sub, batch_size)

                        feed_dict = {
                            input_tf: audios_np,
                            sample_rate_tf: sample_rate,
                            tgt_tf: tgt_np,
                            mask_tf: mask_freq,
                            rir_tf: rir,
                            lengths: lengths_np,
                            maxlen: max_len,
                            mask: masks,
                        }

                        losses = sess.run(loss, feed_dict)
                        predictions = sess.run(decoded_outputs, feed_dict)

                        task.PostProcessDecodeOut(predictions,
                                                  dec_metrics_dict)
                        wer_value = dec_metrics_dict["wer"].value * 100.0

                        for i in range(batch_size):
                            print("example: {}, loss_ce: {}".format(
                                l * batch_size + i, losses[i]))
                            print("pred:{}".format(
                                predictions["topk_decoded"][i, 0]))
                            print("targ:{}".format(tgt_np[i].lower()))
                            print("true: {}".format(data_sub[1, i].lower()))

                            if predictions["topk_decoded"][
                                    i, 0] == tgt_np[i].lower():
                                correct += 1

                        print("--------------------------------")
                        print("Now, the WER is: {0:.2f}%".format(wer_value))

                    print("num of examples succeed for room {}: {}".format(
                        num_room, correct))
                    success_rate = correct / float(num) * 100
                    print("success rate for room {}: {}%".format(
                        num_room, success_rate))

                    success_rates.append(success_rate)
                success_ave = float(sum(success_rates)) / len(success_rates)
                print("success rate overall: {}%".format(success_ave))
Example #25
def GetExecutorParams(model_name, cluster_params, model_registry):
    """Get the params needed to instantiate the Executor.

    Args:
      model_name: A model name registered in the ModelRegistry.
      cluster_params: A cluster hyperparams object.
      model_registry: A ModelRegistry object.

    Returns:
      A tuple (dict, Params):

      - ps_params_dict: High-level task name -> ProgramScheduleParams
      - train_cfg: A SingleTaskModelParams or MultiTaskModelParams.

    Raises:
      ValueError: If the model params are invalid.
    """

    ps_params_dict = {}
    with cluster_factory.Cluster(cluster_params):
        ps_cfg = model_registry.GetProgramSchedule(model_name)
        train_cfg = model_registry.GetParams(model_name, 'Train')
        train_cfg.cluster = cluster_params

        # Remove misleading train params
        train_cfg = UnsetUnusedTrainParams(train_cfg)

        if issubclass(train_cfg.cls, base_model.MultiTaskModel):
            multi_task_train_cfg = train_cfg

            if multi_task_train_cfg.train.ema_decay > 0:
                # Check that all subtasks use the same ema settings.
                for task_name, task_params in (
                        multi_task_train_cfg.task_params.IterParams()):
                    for field in ['ema_decay', 'ema_decay_moving_vars']:
                        if (task_params.train.Get(field) !=
                                multi_task_train_cfg.train.Get(field)):
                            raise ValueError(
                                'Params did not match for field %s in task %s'
                                % (field, task_name))

            for k, _ in multi_task_train_cfg.task_params.IterParams():
                if multi_task_train_cfg.share_model_object:
                    # Create MultiTaskSubModel params from a MultiTaskModelParams.
                    train_task_params = base_model.MultiTaskSubModel.Params()
                    train_task_params.task_name = k
                    train_task_params.input = multi_task_train_cfg.input.Get(
                        k).Copy()
                else:
                    train_task_params = base_model.SingleTaskModel.Params()
                    train_task_params.task = multi_task_train_cfg.task_params.Get(
                        k)
                    train_task_params.input = multi_task_train_cfg.input.Get(k)
                train_task_params.name = k + '_executor_train_task'
                train_task_params.cluster = multi_task_train_cfg.cluster
                train_task_params.train = multi_task_train_cfg.task_params.Get(
                    k).train

                if k not in ps_cfg.program_schedule_dict:
                    tf.logging.fatal(
                        'Could not find %s in ps_cfg.program_schedule_dict: %s',
                        k, ps_cfg)
                program_schedule_params = ps_cfg.program_schedule_dict[k]

                program_schedule_params.task_dict = {
                    'Train': train_task_params
                }

                for eval_dataset_name in program_schedule_params.dataset_names:
                    multi_task_eval_cfg = model_registry.GetParams(
                        model_name, eval_dataset_name)
                    if multi_task_train_cfg.share_model_object:
                        eval_task_params = base_model.MultiTaskSubModel.Params(
                        )
                        eval_task_params.task_name = k
                        eval_task_params.input = multi_task_eval_cfg.input.Get(
                            k).Copy()
                    else:
                        eval_task_params = base_model.SingleTaskModel.Params()
                        eval_task_params.task = multi_task_eval_cfg.task_params.Get(
                            k)
                        eval_task_params.input = multi_task_eval_cfg.input.Get(
                            k)
                    eval_task_params.name = (k + '_' + eval_dataset_name +
                                             '_executor_eval_task')
                    eval_task_params.cluster = multi_task_eval_cfg.cluster
                    eval_task_params = UnsetUnusedTrainParams(eval_task_params)

                    program_schedule_params.task_dict[
                        eval_dataset_name] = eval_task_params
                ps_params_dict[k] = program_schedule_params
        else:
            program_schedule_params = ps_cfg
            program_schedule_params.task_dict = {'Train': train_cfg}
            for eval_dataset_name in program_schedule_params.dataset_names:
                task_eval_params = model_registry.GetParams(
                    model_name, eval_dataset_name)
                task_eval_params = UnsetUnusedTrainParams(task_eval_params)
                program_schedule_params.task_dict[
                    eval_dataset_name] = task_eval_params

            ps_params_dict[''] = program_schedule_params

    return ps_params_dict, train_cfg
Example #26
    def __init__(self,
                 sess,
                 batch_size=1,
                 lr_step1=100,
                 lr_step2=0.1,
                 num_iter_step1=1000,
                 num_iter_step2=4000,
                 th=None,
                 psd_max_ori=None):

        self.sess = sess
        self.num_iter_step1 = num_iter_step1
        self.num_iter_step2 = num_iter_step2
        self.batch_size = batch_size
        self.lr_step1 = lr_step1
        #self.lr_step2 = lr_step2

        tf.set_random_seed(1234)
        params = model_registry.GetParams('asr.librispeech.Librispeech960Wpm',
                                          'Test')
        params.random_seed = 1234
        params.is_eval = True
        params.cluster.worker.gpus_per_replica = 1
        cluster = cluster_factory.Cluster(params.cluster)
        with cluster, tf.device(cluster.GetPlacer()):
            model = params.cls(params)
            self.delta_large = tf.Variable(np.zeros((batch_size, 223200),
                                                    dtype=np.float32),
                                           name='qq_delta')

            # placeholders
            self.input_tf = tf.placeholder(tf.float32,
                                           shape=[batch_size, None],
                                           name='qq_input')
            self.tgt_tf = tf.placeholder(tf.string)
            self.sample_rate_tf = tf.placeholder(tf.int32,
                                                 name='qq_sample_rate')
            self.th = tf.placeholder(tf.float32,
                                     shape=[batch_size, None, None],
                                     name='qq_th')
            self.psd_max_ori = tf.placeholder(tf.float32,
                                              shape=[batch_size],
                                              name='qq_psd')
            self.mask = tf.placeholder(dtype=np.float32,
                                       shape=[batch_size, None],
                                       name='qq_mask')
            self.mask_freq = tf.placeholder(dtype=np.float32,
                                            shape=[batch_size, None, 80])
            #noise = tf.random_normal(self.new_input.shape, stddev=2)
            self.noise = tf.placeholder(np.float32,
                                        shape=[batch_size, None],
                                        name="qq_noise")
            self.maxlen = tf.placeholder(np.int32)
            self.lr_step2 = tf.placeholder(np.float32)

            # variable
            self.rescale = tf.Variable(np.ones((batch_size, 1),
                                               dtype=np.float32),
                                       name='qq_rescale')
            self.alpha = tf.Variable(np.ones(
                (batch_size), dtype=np.float32) * 0.05,
                                     name='qq_alpha')

            # extract the delta
            self.delta = tf.slice(tf.identity(self.delta_large), [0, 0],
                                  [batch_size, self.maxlen])
            self.apply_delta = tf.clip_by_value(self.delta, -2000,
                                                2000) * self.rescale
            self.new_input = self.apply_delta * self.mask + self.input_tf
            #pass_in = tf.clip_by_value(self.new_input, -2**15, 2**15-1)
            self.pass_in = tf.clip_by_value(self.new_input + self.noise,
                                            -2**15, 2**15 - 1)

            # generate the inputs that are needed for the lingvo model
            self.features = create_features(self.pass_in, self.sample_rate_tf,
                                            self.mask_freq)
            self.inputs = create_inputs(model, self.features, self.tgt_tf,
                                        self.batch_size, self.mask_freq)

            task = model.GetTask()
            metrics = task.FPropDefaultTheta(self.inputs)
            # self.celoss with the shape (batch_size)
            self.celoss = tf.get_collection("per_loss")[0]
            self.decoded = task.Decode(self.inputs)

        # compute the loss for masking threshold
        self.loss_th_list = []
        self.transform = Transform()
        for i in range(self.batch_size):
            logits_delta = self.transform((self.apply_delta[i, :]),
                                          (self.psd_max_ori)[i])
            loss_th = tf.reduce_mean(tf.nn.relu(logits_delta - (self.th)[i]))
            loss_th = tf.expand_dims(loss_th, axis=0)
            self.loss_th_list.append(loss_th)
        self.loss_th = tf.concat(self.loss_th_list, axis=0)

        self.optimizer1 = tf.train.AdamOptimizer(self.lr_step1)
        self.optimizer2 = tf.train.AdamOptimizer(self.lr_step2)

        grad1, var1 = self.optimizer1.compute_gradients(
            self.celoss, [self.delta_large])[0]
        grad21, var21 = self.optimizer2.compute_gradients(
            self.celoss, [self.delta_large])[0]
        grad22, var22 = self.optimizer2.compute_gradients(
            self.alpha * self.loss_th, [self.delta_large])[0]

        self.train1 = self.optimizer1.apply_gradients([(tf.sign(grad1), var1)])
        self.train21 = self.optimizer2.apply_gradients([(grad21, var21)])
        self.train22 = self.optimizer2.apply_gradients([(grad22, var22)])
        self.train2 = tf.group(self.train21, self.train22)
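
Note how the two optimizers above implement a two-stage attack: stage 1 applies only the sign of the cross-entropy gradient to the perturbation (an FGSM-style update), while stage 2 jointly minimizes the cross-entropy loss and the alpha-weighted masking-threshold loss. A standalone restatement of the stage-1 update with a stand-in loss, for illustration only:

import numpy as np
import tensorflow as tf

delta = tf.Variable(np.zeros((1, 10), dtype=np.float32))
ce_loss = tf.reduce_sum(tf.square(delta - 1.0))  # stand-in for the CE loss
opt = tf.train.AdamOptimizer(100.0)
grad, var = opt.compute_gradients(ce_loss, [delta])[0]
# Applying tf.sign(grad) instead of grad yields a signed-gradient step.
train_op = opt.apply_gradients([(tf.sign(grad), var)])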
Example #27
def main(argv):
    data = np.loadtxt(FLAGS.input, dtype=str, delimiter=",")
    # calculate the number of loops to run the test
    num = len(data[0])
    batch_size = FLAGS.batch_size
    num_loops = num // batch_size
    assert num % batch_size == 0

    with tf.device("/gpu:0"):
        tf.set_random_seed(1234)
        tfconf = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=tfconf) as sess:
            params = model_registry.GetParams(
                'asr.librispeech.Librispeech960Wpm', 'Test')
            params.cluster.worker.gpus_per_replica = 1
            cluster = cluster_factory.Cluster(params.cluster)
            with cluster, tf.device(cluster.GetPlacer()):
                params.vn.global_vn = False
                params.random_seed = 1234
                params.is_eval = True
                model = params.cls(params)
                task = model.GetTask()
                saver = tf.train.Saver()
                saver.restore(sess, FLAGS.checkpoint)

                # define the placeholders
                input_tf = tf.placeholder(tf.float32, shape=[batch_size, None])
                tgt_tf = tf.placeholder(tf.string)
                sample_rate_tf = tf.placeholder(tf.int32)
                mask_tf = tf.placeholder(tf.float32,
                                         shape=[batch_size, None, 80])

                # generate the features and inputs
                features = create_features(input_tf, sample_rate_tf, mask_tf)
                shape = tf.shape(features)
                inputs = create_inputs(model, features, tgt_tf, batch_size,
                                       mask_tf)

                # loss
                metrics = task.FPropDefaultTheta(inputs)
                loss = tf.get_collection("per_loss")[0]

                # prediction
                decoded_outputs = task.Decode(inputs)
                dec_metrics_dict = task.CreateDecoderMetrics()

                correct = 0
                for l in range(num_loops):
                    data_sub = data[:, l * batch_size:(l + 1) * batch_size]
                    audios_np, sample_rate, tgt_np, mask_freq = Read_input(
                        data_sub, batch_size)
                    feed_dict = {
                        input_tf: audios_np,
                        sample_rate_tf: sample_rate,
                        tgt_tf: tgt_np,
                        mask_tf: mask_freq
                    }

                    losses = sess.run(loss, feed_dict)
                    predictions = sess.run(decoded_outputs, feed_dict)

                    task.PostProcessDecodeOut(predictions, dec_metrics_dict)
                    wer_value = dec_metrics_dict['wer'].value * 100.

                    for i in range(batch_size):
                        print("pred:{}".format(predictions['topk_decoded'][i,
                                                                           0]))
                        print("targ:{}".format(tgt_np[i].lower()))
                        print("true: {}".format(data_sub[1, i].lower()))

                        if predictions['topk_decoded'][i,
                                                       0] == tgt_np[i].lower():
                            correct += 1
                            print("------------------------------")
                            print("example {} succeeds".format(i))

                    print("Now, the WER is: {0:.2f}%".format(wer_value))
                print("num of examples succeed: {}".format(correct))
                print("success rate: {}%".format(correct / float(num) * 100))
Example #28
    def __init__(
        self,
        sess,
        batch_size=1,
        lr_stage1=100,
        lr_stage2=0.1,
        num_iter_stage1=1000,
        num_iter_stage2=4000,
        th=None,
        psd_max_ori=None,
    ):

        self.sess = sess
        self.num_iter_stage1 = num_iter_stage1
        self.num_iter_stage2 = num_iter_stage2
        self.batch_size = batch_size
        self.lr_stage1 = lr_stage1
        self.lr_stage2 = lr_stage2

        tf.set_random_seed(1234)
        params = model_registry.GetParams("asr.librispeech.Librispeech960Wpm",
                                          "Test")
        params.random_seed = 1234
        params.is_eval = True
        params.cluster.worker.gpus_per_replica = 1
        cluster = cluster_factory.Cluster(params.cluster)
        with cluster, tf.device(cluster.GetPlacer()):
            model = params.cls(params)
            self.delta_large = tf.Variable(
                np.zeros((batch_size, FLAGS.max_length_dataset),
                         dtype=np.float32),
                name="qq_delta",
            )

            # placeholders
            self.input_tf = tf.placeholder(tf.float32,
                                           shape=[batch_size, None],
                                           name="qq_input")
            self.tgt_tf = tf.placeholder(tf.string)
            self.rir = tf.placeholder(tf.float32)

            self.sample_rate_tf = tf.placeholder(tf.int32,
                                                 name="qq_sample_rate")
            self.mask = tf.placeholder(dtype=np.float32,
                                       shape=[batch_size, None],
                                       name="qq_mask")
            self.mask_freq = tf.placeholder(dtype=np.float32,
                                            shape=[batch_size, None, 80])
            self.noise = tf.placeholder(np.float32,
                                        shape=[batch_size, None],
                                        name="qq_noise")
            self.maxlen = tf.placeholder(np.int32)
            self.lr = tf.placeholder(np.float32)
            self.lengths = tf.placeholder(
                np.int32,
                shape=[
                    batch_size,
                ],
            )

            # variable
            self.rescale = tf.Variable(
                np.ones(
                    (batch_size, 1), dtype=np.float32) * FLAGS.initial_bound,
                name="qq_rescale",
            )

            # extract the delta
            self.delta = tf.slice(tf.identity(self.delta_large), [0, 0],
                                  [batch_size, self.maxlen])
            self.apply_delta = tf.clip_by_value(self.delta, -self.rescale,
                                                self.rescale)
            self.before_rir = tf.clip_by_value(
                self.apply_delta * self.mask + self.input_tf, -(2**15),
                2**15 - 1)
            self.new_input = (create_speech_rir(
                self.before_rir,
                self.rir,
                self.lengths,
                self.maxlen,
                self.batch_size,
            ) * self.mask)
            self.pass_in = tf.clip_by_value(self.new_input + self.noise,
                                            -(2**15), 2**15 - 1)

            # generate the inputs that are needed for the lingvo model
            self.features = create_features(self.pass_in, self.sample_rate_tf,
                                            self.mask_freq)
            self.inputs = create_inputs(model, self.features, self.tgt_tf,
                                        self.batch_size, self.mask_freq)

            task = model.GetTask()
            metrics = task.FPropDefaultTheta(self.inputs)

            # self.celoss with the shape (batch_size)
            self.celoss = tf.get_collection("per_loss")[0]
            self.decoded = task.Decode(self.inputs)

        self.optimizer1 = tf.train.AdamOptimizer(self.lr)
        grad1, var1 = self.optimizer1.compute_gradients(
            self.celoss, [self.delta_large])[0]
        self.train1 = self.optimizer1.apply_gradients([(tf.sign(grad1), var1)])
Example #29
    def Task(self):
        metadata = waymo_metadata.WaymoMetadata()
        num_classes = len(metadata.ClassNames())
        p = starnet.ModelV2.Params(
            num_classes,
            num_anchor_bboxes_offsets=self.NUM_ANCHOR_BBOX_OFFSETS,
            num_anchor_bboxes_rotations=self.NUM_ANCHOR_BBOX_ROTATIONS,
            num_anchor_bboxes_dimensions=self.NUM_ANCHOR_BBOX_DIMENSIONS,
            num_laser_features=3)

        # Update the Point Cloud Featurizer architecture
        starnet_builder = starnet.Builder()
        starnet_builder.linear_params_init = (
            py_utils.WeightInit.KaimingUniformFanInRelu())

        gin_layers = [[
            self.GIN_HIDDEN_DIMS * 2, self.GIN_HIDDEN_DIMS * 4,
            self.GIN_HIDDEN_DIMS
        ]] * self.NUM_GIN_LAYERS  # pyformat: disable

        p.cell_featurizer = starnet_builder.GINFeaturizerV2(
            'feat',
            num_laser_features=3,
            fc_dims=self.GIN_HIDDEN_DIMS,
            mlp_dims=gin_layers,
            fc_use_bn=False)
        p.cell_feature_dims = self.GIN_HIDDEN_DIMS * (self.NUM_GIN_LAYERS + 1)

        p.output_decoder = waymo_decoder.WaymoOpenDatasetDecoder.Params()
        p.max_nms_boxes = 512
        p.use_oriented_per_class_nms = True

        # Note: Sub-classes need to set nms_iou_threshold and nms_score_threshold
        # appropriately.
        p.nms_iou_threshold = [0.0] * num_classes

        # TODO(jngiam): 1.1 for untrained classes is needed to avoid an issue
        # with boxutils error.
        p.nms_score_threshold = [1.1] * num_classes

        p.name = 'starnet'
        tp = p.train
        tp.optimizer = optimizer.Adam.Params()
        tp.clip_gradient_norm_to_value = 5

        ep = p.eval

        # Train set uses a smaller decoding set, so we can
        # safely eval over the entire input.
        ep.samples_per_summary = 0

        # To be tuned.
        p.train.l2_regularizer_weight = 1e-8

        cluster = cluster_factory.Current()
        train_cluster_p = cluster.params.Copy()
        train_cluster_p.job = 'trainer_client'
        train_cluster_p.mode = 'sync'

        # When running a decoding only job, there are no trainer workers, so we set
        # worker replicas to 1 as a dummy value.
        if train_cluster_p.worker.replicas <= 0:
            train_cluster_p.worker.replicas = 1

        # Set learning rate and schedule.
        with cluster_factory.Cluster(train_cluster_p):
            train_input_p = self.Train()

        # Adapted from V1 tuning.
        tp.ema_decay = 0.99
        # TODO(b/148537111): consider setting this to True.
        tp.ema_decay_moving_vars = False
        tp.learning_rate = 0.001
        lr_util.SetExponentialLR(train_p=tp,
                                 train_input_p=train_input_p,
                                 exp_start_epoch=5,
                                 total_epoch=75)

        p.dimension_loss_weight = .3
        p.location_loss_weight = 3.
        p.loss_weight_classification = 1.
        p.loss_weight_localization = 3.
        p.rotation_loss_weight = 0.3

        return p
Example #30
def GetExecutorParams(model_name, cluster_params, model_registry):
    """Get the params needed to instantiate the Executor.

    Args:
      model_name: A model name registered in the ModelRegistry.
      cluster_params: A cluster hyperparams object.
      model_registry: A ModelRegistry object.

    Returns:
      A tuple (dict, Params):

      - ps_params_dict: High-level task name -> ProgramScheduleParams
      - train_cfg: A SingleTaskModelParams or MultiTaskModelParams.

    Raises:
      ValueError: If the model params are invalid.
    """

    ps_params_dict = {}
    with cluster_factory.Cluster(cluster_params):
        ps_cfg = model_registry.GetProgramSchedule(model_name)
        train_cfg = model_registry.GetParams(model_name, 'Train')
        train_cfg.cluster = cluster_params

        # Remove misleading train params
        train_cfg = UnsetUnusedTrainParams(train_cfg)

        if issubclass(train_cfg.cls, base_model.MultiTaskModel):
            multi_task_train_cfg = train_cfg

            for k, _ in multi_task_train_cfg.task_params.IterParams():
                if multi_task_train_cfg.share_model_object:
                    # Create MultiTaskSubModel params from a MultiTaskModelParams.
                    train_task_params = base_model.MultiTaskSubModel.Params()
                    train_task_params.task_name = k
                    train_task_params.input = multi_task_train_cfg.input.Get(
                        k).Copy()
                else:
                    task = multi_task_train_cfg.task_params.Get(k)
                    train_task_params = base_model.SingleTaskModel.Params(task)
                    train_task_params.input = multi_task_train_cfg.input.Get(k)
                train_task_params.name = k + '_executor_train_task'
                train_task_params.cluster = multi_task_train_cfg.cluster
                train_task_params.train = multi_task_train_cfg.task_params.Get(
                    k).train

                if k not in ps_cfg.program_schedule_dict:
                    tf.logging.fatal(
                        'Could not find %s in ps_cfg.program_schedule_dict: %s',
                        k, ps_cfg)
                # Add Copy in case a user is sharing the same ProgramSchedule params
                # instance across different tasks.
                program_schedule_params = ps_cfg.program_schedule_dict[k].Copy(
                )

                program_schedule_params.task_dict = {
                    'Train': train_task_params
                }

                for eval_dataset_name in program_schedule_params.dataset_names:
                    multi_task_eval_cfg = model_registry.GetParams(
                        model_name, eval_dataset_name)
                    multi_task_eval_cfg.cluster = cluster_params
                    if multi_task_train_cfg.share_model_object:
                        eval_task_params = base_model.MultiTaskSubModel.Params(
                        )
                        eval_task_params.task_name = k
                        eval_task_params.input = multi_task_eval_cfg.input.Get(
                            k).Copy()
                    else:
                        task = multi_task_eval_cfg.task_params.Get(k)
                        eval_task_params = base_model.SingleTaskModel.Params(
                            task)
                        eval_task_params.input = multi_task_eval_cfg.input.Get(
                            k)
                    eval_task_params.name = (k + '_' + eval_dataset_name +
                                             '_executor_eval_task')
                    eval_task_params.cluster = multi_task_eval_cfg.cluster
                    eval_task_params = UnsetUnusedTrainParams(eval_task_params)

                    program_schedule_params.task_dict[
                        eval_dataset_name] = eval_task_params
                ps_params_dict[k] = program_schedule_params
        else:
            program_schedule_params = ps_cfg
            program_schedule_params.task_dict = {'Train': train_cfg}
            for eval_dataset_name in program_schedule_params.dataset_names:
                task_eval_params = model_registry.GetParams(
                    model_name, eval_dataset_name)
                task_eval_params.cluster = cluster_params
                task_eval_params = UnsetUnusedTrainParams(task_eval_params)
                program_schedule_params.task_dict[
                    eval_dataset_name] = task_eval_params

            ps_params_dict[''] = program_schedule_params

    return ps_params_dict, train_cfg