Example #1
    def GetParamsForDataset(self, job_name, dataset_name):
        """Returns params for job `job_name` on the dataset `dataset_name`."""
        # Get the current cluster and update its params from flags.
        cluster = cluster_factory.Current()
        self.UpdateClusterParamsFromFlags(cluster.params, job_name)
        with cluster_factory.Cluster(cluster.params):
            try:
                cfg = self.model_registry.GetParams(self._model_name,
                                                    dataset_name)
            except base_model_params.DatasetError as e:
                dataset_name_retry = dataset_name.title()
                tf.logging.warning(
                    'Exception configuring dataset %s, retrying as %s: %s',
                    dataset_name, dataset_name_retry, e)
                cfg = self.model_registry.GetParams(self._model_name,
                                                    dataset_name_retry)
                tf.logging.warning('Succeeded after retrying as %s.' %
                                   dataset_name_retry)
        cfg.cluster = cluster.params

        # Updates a few params based on flags.
        if FLAGS.enqueue_max_steps is not None:
            cfg.train.enqueue_max_steps = FLAGS.enqueue_max_steps
        if FLAGS.saver_max_to_keep is not None:
            cfg.train.save_max_to_keep = FLAGS.saver_max_to_keep
        if FLAGS.saver_keep_checkpoint_every_n_hours is not None:
            cfg.train.save_keep_checkpoint_every_n_hours = FLAGS.saver_keep_checkpoint_every_n_hours
        return cfg
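The fallback above retries the lookup after title-casing the dataset name ('train'.title() == 'Train'). A minimal self-contained sketch of the same retry pattern, using a hypothetical registry dict and get_params helper rather than Lingvo's model_registry:

def get_params(registry, dataset_name):
    """Looks up dataset params, retrying with a title-cased name on a miss."""
    try:
        return registry[dataset_name]
    except KeyError:
        # 'train'.title() == 'Train', 'dev'.title() == 'Dev', etc.
        return registry[dataset_name.title()]

registry = {'Train': {'batch_size': 8}, 'Dev': {'batch_size': 1}}
assert get_params(registry, 'train') == {'batch_size': 8}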
Example #2
def AddLaserAndCamera(params):
    """Adds laser and camera extractors."""
    cluster = cluster_factory.Current()
    job = cluster.job
    if job != 'decoder':
        return params

    extractor_params = list(dict(params.extractors.IterParams()).values())
    extractor_classes = [p.cls for p in extractor_params]

    # Add images if not present.
    if kitti_input_generator.KITTIImageExtractor not in extractor_classes:
        params.extractors.Define(
            'images', kitti_input_generator.KITTIImageExtractor.Params(), '')

    # Add raw lasers if not present.
    if kitti_input_generator.KITTILaserExtractor not in extractor_classes:
        labels = None
        for p in extractor_params:
            if p.cls == kitti_input_generator.KITTILabelExtractor:
                labels = p
        if labels is None:
            labels = kitti_input_generator.KITTILabelExtractor.Params()
        params.extractors.Define(
            'lasers', kitti_input_generator.KITTILaserExtractor.Params(labels),
            '')

    return params
Example #3
    def _configure_input(self, p, split):
        p = super(StarNetPedFused, self)._configure_input(p, split)

        job_type = cluster_factory.Current().job

        if job_type.startswith('trainer') and not self.RUN_LOCALLY:
            p.file_buffer_size = 48
            p.file_parallelism = 48
            p.num_batcher_threads = 48

        # Fuses select_centers and gather_features into one sampler.
        p.preprocessors.Define(
            'sampler',
            input_preprocessors.SparseSampler.Params().Set(
                center_selector='farthest',
                neighbor_sampler='uniform',
                num_centers=p.preprocessors.select_centers.num_cell_centers,
                keep_z_range=(0.09522381, 1.720825),
                num_neighbors=p.preprocessors.gather_features.
                num_points_per_cell,
                max_distance=2.75), '')
        p.preprocessors.Delete('select_centers')
        p.preprocessors.Delete('gather_features')
        p.preprocessors_order.remove('select_centers')
        p.preprocessors_order[p.preprocessors_order.index(
            'gather_features')] = 'sampler'
        return p
Example #4
 def _DecoderDevice(self):
     """Returns the device to run the decoder computation."""
     if py_utils.use_tpu():
         cluster = cluster_factory.Current()
         return tf.device(cluster.WorkerDeviceInModelSplit(1))
     else:
         return tf.device('')
Example #5
def _FromGlobal(field_name, allow_override_from_cluster=False):
  """Get 'field_name' from a global configuration object.

  Currently the global configuration object used is FLAGS, but this may
  change to Cluster() or an equivalent stack-scoped config object.

  Args:
    field_name: The string field name to look up.
    allow_override_from_cluster: Allow the Cluster() to override FLAGS.

  Returns:
    The value associated with the global configuration string 'field_name'.
  """

  if allow_override_from_cluster:
    cluster = cluster_factory.Current()
    if field_name in cluster.params:
      params_value = cluster.params.Get(field_name)
      # Return the value in the cluster params if it is not None
      if params_value is not None:
        return params_value

  # Now check the FLAGS object for backwards compatibility.
  #
  # If not explicitly set, get the field from the FLAGS object.  If FLAGS
  # have not been parsed yet, the default value of the flag will be used.
  return FLAGS[field_name].value
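A minimal sketch of the precedence implemented above, using plain dicts in place of the Cluster params and FLAGS objects (the field names and values are hypothetical): the cluster value wins only when the key exists and is not None, otherwise the flag's (possibly default) value is returned.

def from_global(field_name, cluster_params, flag_values,
                allow_override_from_cluster=False):
    if allow_override_from_cluster and field_name in cluster_params:
        value = cluster_params[field_name]
        if value is not None:
            return value
    return flag_values[field_name]

assert from_global('xla_device', {'xla_device': 'tpu'}, {'xla_device': 'cpu'},
                   allow_override_from_cluster=True) == 'tpu'
assert from_global('xla_device', {'xla_device': None}, {'xla_device': 'cpu'},
                   allow_override_from_cluster=True) == 'cpu'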
Example #6
    def testTextPackedInputBatchSize(self, use_per_host_infeed,
                                     packing_factor):
        p = cluster_factory.Current().params.Copy()
        p.job = 'trainer'
        p.worker.tpus_per_replica = 8
        p.worker.num_tpu_hosts = 16
        p.worker.devices_per_split = 2
        cluster = p.Instantiate()

        with cluster, mock.patch('lingvo.core.py_utils.use_tpu',
                                 return_value=True):
            p = input_generator.TextPackedInput.Params()
            p.use_per_host_infeed = use_per_host_infeed
            p.file_random_seed = 0
            p.file_pattern = 'tfrecord:' + test_helper.test_src_dir_path(
                'tasks/mt/testdata/en_fr.tfrecord')
            p.pad_to_max_seq_length = True
            p.tokenizer = tokenizers.AsciiTokenizer.Params()
            p.input_file_type = 'sentence_proto'
            p.source_max_length = 32
            p.target_max_length = 32
            p.bucket_batch_limit = [128]
            p.packing_factor = packing_factor
            with self.session() as sess:
                inp = p.Instantiate()
                # GlobalBatchSize is batch_size (128) * num_splits_per_client (4).
                # num_splits_per_client is 4, because num_splits_per_replica is 4.
                # num_splits_per_replica is 4 because that's tpus_per_replica
                # divided by devices_per_split.
                expected_global_batch_size = (
                    p.bucket_batch_limit[0] //
                    cluster.params.worker.devices_per_split *
                    cluster.params.worker.tpus_per_replica)
                if p.packing_factor is not None:
                    expected_global_batch_size = np.math.floor(
                        expected_global_batch_size * p.packing_factor)

                expected_infeed_batch_size = expected_global_batch_size
                if use_per_host_infeed:
                    expected_infeed_batch_size = (
                        expected_global_batch_size //
                        cluster.params.worker.num_tpu_hosts)

                expected_packed_infeed_batch_size = expected_infeed_batch_size
                if p.packing_factor is not None:
                    expected_packed_infeed_batch_size = np.math.floor(
                        expected_infeed_batch_size / p.packing_factor)

                self.assertEqual(expected_global_batch_size,
                                 inp.GlobalBatchSize())
                self.assertEqual(expected_infeed_batch_size,
                                 inp.InfeedBatchSize())

                batch_tensor = inp.GetPreprocessedInputBatch()
                for k, x in batch_tensor.FlattenItems():
                    self.assertTrue(x.shape.is_fully_defined(), k)
                batch = sess.run(batch_tensor)
                self.assertEqual(batch.src.ids.shape,
                                 (expected_packed_infeed_batch_size, 32))
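The expected sizes in this test follow directly from the cluster parameters set at the top. A worked instance of the arithmetic, assuming use_per_host_infeed=True and packing_factor=2.0 (both are illustrative choices, not values fixed by the test):

import math

bucket_batch_limit = 128
devices_per_split = 2
tpus_per_replica = 8
num_tpu_hosts = 16
packing_factor = 2.0

global_batch = bucket_batch_limit // devices_per_split * tpus_per_replica  # 128 // 2 * 8 = 512
global_batch = math.floor(global_batch * packing_factor)                   # 1024
infeed_batch = global_batch // num_tpu_hosts                               # 64 per host
packed_infeed_batch = math.floor(infeed_batch / packing_factor)            # 32 packed rows

assert (global_batch, infeed_batch, packed_infeed_batch) == (1024, 64, 32)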
Example #7
def _GetTrainingStatistics(train_input_p):
    """Get training statistics, including total batch size and steps per epoch."""
    cluster = cluster_factory.Current()
    # E.g., this is 1 for a single GPU, 8 for a 2x2 TPU, 32 for a 4x4 TPU,
    # or 0 if no training job is launched.
    total_num_cores = cluster.total_worker_devices
    total_batch_size = max(train_input_p.batch_size * total_num_cores, 1)
    steps_per_epoch = float(train_input_p.num_samples) / total_batch_size
    return py_utils.NestedMap(total_num_cores=total_num_cores,
                              total_batch_size=total_batch_size,
                              steps_per_epoch=steps_per_epoch)
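A worked instance of the statistics above, with assumed values: a per-core batch size of 2, a 4x4 TPU with 32 worker cores, and a 158,081-example training set (all three numbers are illustrative only).

batch_size = 2
total_num_cores = 32
num_samples = 158081

total_batch_size = max(batch_size * total_num_cores, 1)   # 64
steps_per_epoch = float(num_samples) / total_batch_size   # ~2470.0

assert total_batch_size == 64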
Example #8
 def scaled_bucket_batch_limit(self):
   p = self.params
   if not hasattr(self, '_scaled_bucket_batch_limit'):
     cluster = cluster_factory.Current()
     self._scaled_bucket_batch_limit = [
         b * cluster.num_splits_per_client for b in p.bucket_batch_limit
     ]
     if p.use_per_host_infeed and cluster.num_tpu_hosts > 0:
       self._scaled_bucket_batch_limit = [
           x // cluster.num_tpu_hosts for x in self._scaled_bucket_batch_limit
       ]
   return self._scaled_bucket_batch_limit
Example #9
  def InputBatchSize(self):
    """Returns the batch size for the current step."""
    p = self.params
    cluster = cluster_factory.Current()

    # If use_per_host_infeed, each input op is only responsible
    # for generating a subset of the whole batch.
    batch_per_input = p.batch_size * cluster.num_splits_per_client
    if p.use_per_host_infeed and cluster.num_tpu_hosts > 0:
      tf.logging.info('batch_size %d cluster.num_tpu_hosts %d', batch_per_input,
                      cluster.num_tpu_hosts)
      batch_per_input //= cluster.num_tpu_hosts
    tf.logging.info('batch_per_input: %d', batch_per_input)
    return batch_per_input
Example #10
 def _GetSaver(self):
   """Returns a saver."""
   do_eval = cluster_factory.Current().do_eval
   if not self._save_only and self._model.ema and do_eval:
     tf.logging.info('Using EMA for evaluation.')
     return tf.train.Saver(
         self._model.ema.variables_to_restore(self._model.variables_for_ema))
   return tf.train.Saver(
       sharded=True,
       max_to_keep=self._train_params.save_max_to_keep,
       keep_checkpoint_every_n_hours=(
           self._train_params.save_keep_checkpoint_every_n_hours),
       pad_step_number=True,  # %08d
       write_version=tf.train.SaverDef.V2)
Example #11
    def __init__(self, model_name, split, run_preprocessors):
        self._model_name = model_name
        self._split = split
        self._run_preprocessors = run_preprocessors
        self._sess = None

        # Create a cluster configuration assuming evaluation; the input pipelines
        # need to know the cluster job type to set up the outputs correctly.
        cluster = cluster_factory.Current()
        cluster.params.job = 'evaler'
        cluster.params.mode = 'sync'
        cluster.params.task = 0
        cluster.params.evaler.replicas = 1
        self._cluster = cluster_factory.Cluster(cluster.params)
Example #12
    def GetExecutorParams(self):
        """Get the params needed to instantiate the ExecutorTpu.

    Returns:
       Tuple (dict, params):

         - ps_params_dict: high_level task_name -> ProgramScheduleParams
         - train_cfg: Either a SingleTaskModelParams or MultiTaskModelParams.
    """
        cluster = cluster_factory.Current()
        self.UpdateClusterParamsFromFlags(cluster.params, 'executor_tpu')
        ps_params_dict, train_cfg = executor.GetExecutorParams(
            self._model_name, cluster.params, self.model_registry)

        return ps_params_dict, train_cfg
Example #13
 def _MaybeOverwriteModelVariablesWithEMA(self):
     """Overwrite model variables with EMA shadow variables in eval mode."""
     do_eval = cluster_factory.Current().do_eval
     if not self._save_only and do_eval:
         for model in self._models:
             if not model.ema:
                 continue
             tf.logging.info('Using EMA for evaluation.')
             # TODO(jiaweix): this implementation will load both the model variables
             # and EMA variables. As a result the memory usage will be higher than
             # the eval jobs in TF1 mode.
             ema = model.ema
             cur_vars = model.GetVariablesDict()
             for v in cur_vars.values():
                 shadow_v = ema.average(v)
                 if shadow_v is not None:
                     v.assign(shadow_v)
Example #14
 def _GetSaver(self):
     """Returns a saver."""
     do_eval = cluster_factory.Current().do_eval
     variables_to_restore = {}
     if not self._save_only and do_eval:
         for model in self._models:
             if model.ema:
                 tf.logging.info('Using EMA for evaluation.')
                 variables_to_restore.update(
                     model.ema.variables_to_restore(
                         model.variables_for_ema))
     if not variables_to_restore:
         variables_to_restore = None
     return SaverWrapper(self._train_dir,
                         self._train_params,
                         variables_to_restore_dict=variables_to_restore,
                         async_save=self.async_checkpointing)
Example #15
def ConcatenateAcrossReplicas(tensors,
                              tpu_cores=None,
                              axis=0,
                              stop_cross_gradients=False):
    """Concatenates one or more local tensors across all TPU cores.

  Input `tensors` may be in any format supported by `tf.nest.flatten` (single
  Tensor, dict, etc.), or a `NestedMap`.

  In order to avoid having to pass a TPU core ID into the infeed, this
  implementation produces a different rotation of the concatenation for each
  core.  For example, core 0 will be arranged as [0, 1, 2, ...], whereas core 3
  will be arranged as [3, 4, ..., 0, 1, 2].

  If called from a non-TPU context, this function returns the `tensors`
  unchanged.

  Args:
    tensors: The local tensor or tensors to concatenate across cores.
    tpu_cores: The total number of TPU cores. If not set, the number of cores is
      inferred from `cluster_factory.Current()`.
    axis: The axis to concatenate.
    stop_cross_gradients: If true, stop gradients on cross-replica slices.

  Returns:
    The tensor(s) concatenated across all replicas.
  """
    if not py_utils.use_tpu():
        return tensors

    if tpu_cores is None:
        cluster = cluster_factory.Current()
        tpu_cores = cluster.tpus_per_replica * cluster.num_replicas
        assert tpu_cores, 'Unable to determine number of TPU cores from cluster.'

    concat_fn = functools.partial(CrossReplicaConcat,
                                  tpu_cores=tpu_cores,
                                  axis=axis,
                                  stop_cross_gradients=stop_cross_gradients)
    concatenated_tensors = [concat_fn(t) for t in tf.nest.flatten(tensors)]

    if isinstance(tensors, py_utils.NestedMap):
        return tensors.Pack(concatenated_tensors)
    else:
        return tf.nest.pack_sequence_as(tensors, concatenated_tensors)
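The per-core rotation described in the docstring can be illustrated without TPUs: every core ends up with the same concatenated data, rotated so that its own shard comes first. This standalone NumPy sketch only illustrates that layout; it is not Lingvo's CrossReplicaConcat:

import numpy as np

tpu_cores = 4
per_core = [np.array([i * 10, i * 10 + 1]) for i in range(tpu_cores)]  # local shards

def rotated_concat(core_id):
    order = [(core_id + k) % tpu_cores for k in range(tpu_cores)]
    return np.concatenate([per_core[i] for i in order], axis=0)

print(rotated_concat(0))  # [ 0  1 10 11 20 21 30 31]
print(rotated_concat(3))  # [30 31  0  1 10 11 20 21]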
Example #16
 def _configure_input(self, p):
     """Base function managing the delegation of job specific input configs."""
     self._configure_generic_input(p)
     cluster = cluster_factory.Current()
     job = cluster.job
     if job.startswith('trainer'):
         self._configure_trainer_input(p)
     elif job.startswith('decoder'):
         self._configure_decoder_input(p)
     elif job.startswith('evaler'):
         self._configure_evaler_input(p)
     else:
         tf.logging.info('There are no input configuration changes for '
                         'job {}.'.format(job))
     if self.RUN_LOCALLY:
         p.num_batcher_threads = 1
         p.file_buffer_size = 1
         p.file_parallelism = 1
Example #17
def scale_split_to_infeed(split_batch_size, use_per_host_infeed):
    """Obtains an infeed batch size from a split batch size and cluster configs.

  Args:
    split_batch_size: int: Per-split batch size.
    use_per_host_infeed: bool: Whether to use an individual infeed for each
      host.

  Returns:
    int: Per-infeed batch size.
  """
    cluster = cluster_factory.Current()
    global_batch_size = split_batch_size * cluster.num_splits_per_client
    # If use_per_host_infeed, each input op is only responsible
    # for generating a subset of the whole batch.
    if use_per_host_infeed and cluster.num_tpu_hosts > 0:
        return global_batch_size // cluster.num_tpu_hosts
    else:
        return global_batch_size
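A small numeric sketch of the scaling above, assuming 4 splits per client and 16 TPU hosts (hypothetical values):

split_batch_size = 8
num_splits_per_client = 4
num_tpu_hosts = 16

global_batch_size = split_batch_size * num_splits_per_client  # 32
per_host_infeed = global_batch_size // num_tpu_hosts          # 2, when use_per_host_infeed
single_infeed = global_batch_size                             # 32 otherwise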
Example #18
def scale_infeed_to_global(infeed_batch_size, use_per_host_infeed):
    """Obtains a global batch size from an infeed batch size and cluster configs.

  Args:
    infeed_batch_size: int: Per-infeed batch size.
    use_per_host_infeed: bool: Whether to use an individual infeed for each
      host.

  Returns:
    int: Global batch size.
  """
    cluster = cluster_factory.Current()
    if use_per_host_infeed and cluster.num_tpu_hosts > 0:
        if not py_utils.use_tpu():
            raise ValueError('Scaling to TPU hosts without TPUs. {}'.format(
                cluster.num_tpu_hosts))
        return infeed_batch_size * cluster.num_tpu_hosts
    else:
        return infeed_batch_size
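This function inverts the per-host division done in scale_split_to_infeed, so the two compose cleanly whenever the host count divides the global batch size evenly. A standalone round-trip sketch (hypothetical re-implementations that ignore the TPU check):

num_tpu_hosts = 16

def to_infeed(global_batch):
    return global_batch // num_tpu_hosts

def to_global(infeed_batch):
    return infeed_batch * num_tpu_hosts

assert to_global(to_infeed(512)) == 512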
Example #19
 def __init__(self, p):
     if self._VALIDATE_BATCH_SIZE_NONE and p.batch_size is not None:
         raise ValueError(
             'LingvoInputAdaptor does not support p.batch_size. '
             'Please specify batch size on p.input, e.g. with '
             'p.input.bucket_batch_limit = [4] or '
             'p.input.args.batch=4, depending on the Lingvo input '
             f'used. Currently: p.batch_size={p.batch_size}, '
             'it must be None.')
     super().__init__(p)
     self._cluster = copy.deepcopy(cluster_factory.Current())
     # For Lingvo's Cluster context that may impact the behavior of this input
     # generator, we always set use_tpu to True, and optionally set do_eval
     # for non-training data when configured to do so. All other Cluster params
     # use the default value.
     self._cluster.params.xla_device = 'tpu'
     self._cluster.params.enable_asserts = False
     # This indirectly sets cluster.require_sequential_input_order as well.
     self._cluster.params.do_eval = (not p.is_training
                                     and p.cluster_do_eval)
     self._initialize()
Example #20
    def __init__(self, params):
        super(LinearRampupExponentialDecayScaledByNumSplitSchedule,
              self).__init__(params)

        p = self.params

        # We always compute lr schedule from the trainer's perspective.
        # Also note that this schedule makes sense to sync training only.
        if p.num_splits:
            splits = p.num_splits
        else:
            # Infer num_splits from cluster.
            cluster_params = cluster_factory.Current().params.Copy()
            cluster_params.task = 0
            assert cluster_params.mode == 'sync'
            cluster_params.job = 'trainer_client'
            my_cluster = cluster_params.cls(cluster_params)
            splits = my_cluster.num_splits_per_client

        warmup_end = p.warmup * splits
        decay_start = max(warmup_end + 1.0, p.decay_start / splits)
        peak = 1.0 * splits
        tf.logging.info('Peak lr: %f', peak)
        decay_end = max(decay_start + 1.0, p.decay_end / splits)
        schedules = [
            LinearLearningRateSchedule.Params().Set(start=(0., p.warmup_init),
                                                    limit=(warmup_end, peak)),
            LinearLearningRateSchedule.Params().Set(start=(warmup_end, peak),
                                                    limit=(decay_start, peak)),
            ExponentialLearningRateSchedule.Params().Set(start=(decay_start,
                                                                peak),
                                                         limit=(decay_end,
                                                                p.min)),
            LinearLearningRateSchedule.Params().Set(start=(0, p.max),
                                                    limit=(decay_end, p.max)),
        ]
        self.CreateChild(
            'combine',
            CombinedMinimumLearningRateSchedule.Params().Set(
                schedules=schedules))
Example #21
    def __init__(self, params):
        super(LinearRampupPiecewiseConstantSchedule, self).__init__(params)

        p = self.params
        assert len(p.boundaries) >= 2 and len(p.boundaries) == len(p.lrs)
        # We always compute lr schedule from the trainer's perspective.
        # Also note that this schedule makes sense to sync training only.
        if p.num_splits:
            splits = p.num_splits
        else:
            # Infer num_splits from cluster.
            cluster_params = cluster_factory.Current().params.Copy()
            cluster_params.task = 0
            assert cluster_params.mode == 'sync'
            cluster_params.job = 'trainer_client'
            my_cluster = cluster_params.cls(cluster_params)
            splits = my_cluster.num_splits_per_client

        assert splits >= 1
        splits = float(splits)
        boundaries = [step / splits for step in p.boundaries]
        lrs = [step * splits for step in p.lrs]

        tf.logging.info('splits: {}\n boundaries: {}\n lrs: {} '.format(
            splits, boundaries, lrs))

        schedules = [
            LinearLearningRateSchedule.Params().Set(start=(0., 0.),
                                                    limit=(boundaries[0],
                                                           lrs[0])),
            PiecewiseConstantLearningRateSchedule.Params().Set(
                boundaries=boundaries, values=[1e8] + lrs)
        ]
        self.CreateChild(
            'combine',
            CombinedMinimumLearningRateSchedule.Params().Set(
                schedules=schedules))
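The rescaling above compresses the schedule along the step axis and stretches it along the learning-rate axis: with N synchronous splits the trainer consumes the data in roughly 1/N as many steps, so boundaries are divided by N and rates multiplied by N. A numeric sketch with an assumed 8-way split (the boundary and rate values are illustrative):

splits = 8.0
boundaries = [40000, 80000, 120000]
lrs = [1.0, 0.1, 0.01]

scaled_boundaries = [step / splits for step in boundaries]  # [5000.0, 10000.0, 15000.0]
scaled_lrs = [lr * splits for lr in lrs]                    # [8.0, 0.8, 0.08]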
Example #22
 def cluster(self):
     """Returns the current cluster configuration."""
     return cluster_factory.Current()
Example #23
def _ShouldAddSummary():
    return cluster_factory.Current().add_summary
Example #24
    def FProp(self, theta):
        """Forward propagation.

    This default `FProp` implementation here supports batch splitting in
    synchronous and asynchronous training when sub-classes implement
    `FPropTower`.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.

    Returns:
      A dict containing metrics pairs. One of the keys should be 'loss' and its
      value should be a (loss, num_predictions) pair.
    """
        p = self.params
        cluster = cluster_factory.Current()

        with tf.name_scope('fprop'), tf.name_scope(p.name):
            all_fprop_metrics = []

            if py_utils.use_tpu():
                batch = self.input_generator.CreateTpuFeeds()
                with tf.name_scope('tower_0_0'):
                    dec_metrics = self.FPropTower(theta, batch)
                all_fprop_metrics.append(dec_metrics)
            else:
                # Splits the input batch on the input device.
                num_splits = cluster.num_splits_per_client
                with tf.device(cluster.input_device):
                    batches = self.input_generator.SplitInputBatch(num_splits)
                    assert num_splits == len(batches)

                # dev_list_per_replica[i][j] is the i-th worker's j-th device.
                dev_list_per_replica = cluster.available_devices.tolist()

                # Asserts invariant of the total number of splits w.r.t.
                # splits per worker.
                splits_per_replica = cluster.num_splits_per_replica
                assert num_splits == splits_per_replica * len(
                    dev_list_per_replica)

                for w_id, w_devs in enumerate(dev_list_per_replica):
                    # Make local copy of the vars, shard on devices for this worker.
                    theta_local = py_utils.CreateLocalTheta(theta,
                                                            w_devs,
                                                            label='worker %d' %
                                                            w_id)

                    for s_id in range(splits_per_replica):
                        # s_id-th split for the w_id-th worker.
                        split_id = splits_per_replica * w_id + s_id
                        with py_utils.ModelSplit(split_id):
                            with tf.device(
                                    cluster.WorkerDeviceInModelSplit(0)):
                                with tf.name_scope('tower_%d_%d' %
                                                   (w_id, s_id)):
                                    batch = self.input_generator.PreprocessInputBatch(
                                        batches[split_id])
                                    dec_metrics = self.FPropTower(
                                        theta_local, batch)
                        all_fprop_metrics.append(dec_metrics)

            metrics = py_utils.WeightedAvgOfMetrics(all_fprop_metrics)

        # Adds stats about the input batch.
        metrics['num_samples_in_batch'] = (tf.convert_to_tensor(
            self.input_generator.InputBatchSize()), tf.constant(1.0))
        # Generates summaries.
        for name, (value, weight) in six.iteritems(metrics):
            self.AddEvalMetric(name, value, weight)

        # Loss.
        self._loss, self._num_predicts = metrics['loss']
        self._loss = py_utils.CheckNumerics(self._loss)

        return metrics
Example #25
    def __init__(self, params):
        assert issubclass(params.cls, BaseTask)
        super(BaseTask, self).__init__(params)

        p = self.params

        if p.input:
            # TODO(zhifengc): Consider a simpler way to ensure the input
            # generator stops after one epoch.
            if p.is_eval and p.eval:
                seq_inp = issubclass(
                    p.input.cls,
                    base_input_generator.BaseSequenceInputGenerator)
                if p.input.num_samples == 0:
                    # Dataset size is unknown. Computes eval summary based on num_samples.
                    assert p.eval.samples_per_summary > 0
                elif (p.eval.samples_per_summary == 0) or (
                        p.input.num_samples < p.eval.samples_per_summary):
                    # If we know the dataset size and we want to evaluate the full
                    # set, we need to coordinate the input generator to flush out
                    # all samples so the evaler and decoder compute metrics on the
                    # whole set for each summary step.
                    if seq_inp:
                        p.input.flush_every_n = p.input.num_samples
                    p.eval.samples_per_summary = p.input.num_samples
                if seq_inp and p.input.num_batcher_threads > 1:
                    tf.logging.warning(
                        'input.num_batcher_threads > 1 inside eval mode.  '
                        'The input generator may not iterate over exactly '
                        'one epoch per run')

            cluster = cluster_factory.Current()
            with tf.device(
                    cluster.input_device), py_utils.outside_all_rewrites():
                self.CreateChild('input', p.input)

        self._var_grads = None
        self._encoder = None
        self._online_encoder = None
        self._decoder = None

        self._total_examples = None
        self._total_nans_and_infs = None
        self._loss = None
        self._num_predictions = None
        self._train_op = None
        self._eval_metrics = {}
        self._trainer_verbose_tensors = {}

        # Create the gradient mask.
        self._per_input_gradient_mask = None
        self._shared_global_step = py_utils.GetOrCreateGlobalStep()
        tp = p.train
        if tp:
            if tp.task_global_step:
                self._task_global_step = py_utils.CreateTaskGlobalStep(
                    p, p.name)
                self._global_step = self._task_global_step
            else:
                self._task_global_step = None
                self._global_step = self._shared_global_step
            if tp.grad_norm_tracker:
                with tf.variable_scope(p.name):
                    self.CreateChild('grad_norm_tracker', tp.grad_norm_tracker)

            self.CreateChild('lr_schedule', tp.lr_schedule)
        self._UpdateVnConfig()
Example #26
  def ComputePredictions(self, theta, source_encs, source_paddings, targets,
                         src_segment_id):
    """Decodes `targets` given encoded source.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      source_encs: source encoding, of shape [time, batch, depth].
      source_paddings: source encoding's padding, of shape [time, batch].
      targets: A dict of string to tensors representing the targets one tries to
        predict. Each tensor in targets is of shape [batch, time].
      src_segment_id: source segment id, of shape [time, batch].

    Returns:
      A Tensor with shape [time, batch, params.softmax.input_dim].
    """
    p = self.params
    time, batch = py_utils.GetShape(source_paddings, 2)
    source_encs = py_utils.HasShape(source_encs, [time, batch, p.source_dim])
    with tf.name_scope(p.name):
      target_ids = tf.transpose(targets.ids)
      target_paddings = py_utils.HasRank(targets.paddings, 2)
      target_paddings = tf.expand_dims(tf.transpose(target_paddings), 2)
      if p.packed_input:
        target_segment_id = tf.expand_dims(tf.transpose(targets.segment_ids), 2)
      else:
        target_segment_id = tf.zeros_like(target_paddings)

      if py_utils.use_tpu():
        emb_device = cluster_factory.Current().WorkerDeviceInModelSplit(0)
      else:
        emb_device = ''
      with tf.device(emb_device):
        inputs = self.emb.EmbLookup(theta.emb, target_ids)
        inputs = self.ApplyClipping(theta, inputs)
        summary_utils.histogram(p, 'input_emb', inputs)
        inputs = self.ApplyDropout(inputs)
        self._emb_out = inputs

        # Layer 0 intertwines with attention.
        (atten_ctxs, xs, atten_probs, _) = self.frnn_with_atten.FProp(
            theta.frnn_with_atten,
            source_encs,
            source_paddings,
            inputs,
            target_paddings,
            src_segment_id=src_segment_id,
            segment_id=target_segment_id)
        if p.add_summary:
          self._AddAttenProbsSummary(source_paddings, targets, [atten_probs])

        atten_ctxs = self.ApplyClipping(theta, atten_ctxs)
        summary_utils.histogram(p, 'atten_ctxs', atten_ctxs)

        for i, (layer, layer_theta) in enumerate(zip(self.frnn, theta.frnn)):
          # Forward through Layer-(i + 1) because Layer-0 handled before.
          ys, _ = layer.FProp(
              layer_theta,
              tf.concat([xs, atten_ctxs], 2),
              target_paddings,
              segment_id=target_segment_id)
          ys = self.ApplyDropout(ys)
          if 1 + i >= p.residual_start:
            xs += ys  # Residual skip
            xs = self.ApplyClipping(theta, xs)
          else:
            xs = ys
          summary_utils.histogram(p, 'layer_out_%s' % i, xs)

        if p.feed_attention_context_vec_to_softmax:
          xs = tf.concat([xs, atten_ctxs], 2)

        return xs
Example #27
  def CreateTpuFeeds(self):
    """Creates the TPU infeed queue from preprocessed batch."""
    p = self.params
    cluster = cluster_factory.Current()
    num_tpu_hosts = cluster.num_tpu_hosts
    assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
    num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

    with py_utils.outside_all_rewrites():
      assert py_utils.use_tpu()
      assert not self._made_tpu_infeed

      shards = tpu_function.get_tpu_context(
      ).number_of_shards // num_infeed_hosts
      input_ops_list = []
      queues = []
      first_batch = None
      for task_id in range(num_infeed_hosts):
        host_device = '/task:{}/device:CPU:0'.format(task_id)
        with tf.device(host_device):
          batch = self.GetPreprocessedInputBatch()
          if first_batch is None:
            first_batch = batch
          flat_batch = batch.FlattenItems()

          shapes, types = [], []
          for k, x in flat_batch:
            assert x.shape.is_fully_defined(), (
                'Shape must be fully defined: %s: %s' % (k, x))
            # TODO(cwhipkey): if it's a string (or other type not supported on
            # TPU), drop it from feeding and on the other end add in an op that
            # fails if used.
            shapes.append(x.shape)
            types.append(x.dtype)
          q = tf.contrib.tpu.InfeedQueue(tuple_types=types, tuple_shapes=shapes)
          queues.append(q)
          assert shards is not None
          q.set_number_of_shards(shards)

          if p.use_per_host_infeed:

            # TODO(ylc/zhifengc): Add this to a policy module and test it.
            def _tpu_ordinal_function(shard_index_in_host):
              device_assignment = py_utils.GetTpuDeviceAssignment()
              if device_assignment:
                # We put both enqueue/dequeue ops at core 0 in each replica.
                replica = device_assignment.lookup_replicas(
                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                return device_assignment.tpu_ordinal(replica=replica)
              else:
                return shard_index_in_host

            input_ops = q.split_inputs_and_generate_enqueue_ops(
                [v for _, v in flat_batch],
                placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                tpu_ordinal_function=_tpu_ordinal_function)
          else:
            input_ops = q.split_inputs_and_generate_enqueue_ops(
                [v for _, v in flat_batch],
                device_assignment=py_utils.GetTpuDeviceAssignment())

          input_ops_list += input_ops
      tf.logging.info('input_ops_list %s', input_ops_list)
      tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    with tf.device(tf.contrib.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
    return first_batch.Pack(tensors)
Example #28
    def Task(self):
        metadata = waymo_metadata.WaymoMetadata()
        num_classes = len(metadata.ClassNames())
        p = starnet.ModelV2.Params(
            num_classes,
            num_anchor_bboxes_offsets=self.NUM_ANCHOR_BBOX_OFFSETS,
            num_anchor_bboxes_rotations=self.NUM_ANCHOR_BBOX_ROTATIONS,
            num_anchor_bboxes_dimensions=self.NUM_ANCHOR_BBOX_DIMENSIONS,
            num_laser_features=3)

        # Update the Point Cloud Featurizer architecture
        starnet_builder = starnet.Builder()
        starnet_builder.linear_params_init = (
            py_utils.WeightInit.KaimingUniformFanInRelu())

        gin_layers = [[
            self.GIN_HIDDEN_DIMS * 2, self.GIN_HIDDEN_DIMS * 4,
            self.GIN_HIDDEN_DIMS
        ]] * self.NUM_GIN_LAYERS  # pyformat: disable

        p.cell_featurizer = starnet_builder.GINFeaturizerV2(
            'feat',
            num_laser_features=3,
            fc_dims=self.GIN_HIDDEN_DIMS,
            mlp_dims=gin_layers,
            fc_use_bn=False)
        p.cell_feature_dims = self.GIN_HIDDEN_DIMS * (self.NUM_GIN_LAYERS + 1)

        p.output_decoder = waymo_decoder.WaymoOpenDatasetDecoder.Params()
        p.max_nms_boxes = 512
        p.use_oriented_per_class_nms = True

        # Note: Sub-classes need to set nms_iou_threshold and nms_score_threshold
        # appropriately.
        p.nms_iou_threshold = [0.0] * num_classes

        # TODO(jngiam): 1.1 for untrained classes is needed to avoid an issue
        # with boxutils error.
        p.nms_score_threshold = [1.1] * num_classes

        p.name = 'starnet'
        tp = p.train
        tp.optimizer = optimizer.Adam.Params()
        tp.clip_gradient_norm_to_value = 5

        ep = p.eval

        # Train set uses a smaller decoding set, so we can
        # safely eval over the entire input.
        ep.samples_per_summary = 0

        # To be tuned.
        p.train.l2_regularizer_weight = 1e-8

        cluster = cluster_factory.Current()
        train_cluster_p = cluster.params.Copy()
        train_cluster_p.job = 'trainer_client'
        train_cluster_p.mode = 'sync'

        # When running a decoding only job, there are no trainer workers, so we set
        # worker replicas to 1 as a dummy value.
        if train_cluster_p.worker.replicas <= 0:
            train_cluster_p.worker.replicas = 1

        # Set learning rate and schedule.
        with cluster_factory.Cluster(train_cluster_p):
            train_input_p = self.Train()

        # Adapted from V1 tuning.
        tp.ema_decay = 0.99
        # TODO(b/148537111): consider setting this to True.
        tp.ema_decay_moving_vars = False
        tp.learning_rate = 0.001
        lr_util.SetExponentialLR(train_p=tp,
                                 train_input_p=train_input_p,
                                 exp_start_epoch=5,
                                 total_epoch=75)

        p.dimension_loss_weight = .3
        p.location_loss_weight = 3.
        p.loss_weight_classification = 1.
        p.loss_weight_localization = 3.
        p.rotation_loss_weight = 0.3

        return p
Example #29
    def _configure_input(self, p, split):
        p.file_pattern_prefix = _WAYMO_BASE

        job_type = cluster_factory.Current().job

        max_num_points = int(64 * 2650 * 1.5)
        p.preprocessors = hyperparams.Params()
        p.preprocessors.Define(
            'filter_nlz_points',
            waymo_open_input_generator.FilterNLZPoints.Params(), '')
        # TODO(bencaine): Change this to filter based on difficulty instead
        p.preprocessors.Define(
            'filter_groundtruth',
            input_preprocessors.FilterGroundTruthByNumPoints.Params(), '')
        p.preprocessors.Define('viz_copy',
                               input_preprocessors.CreateDecoderCopy.Params(),
                               '')
        p.preprocessors.Define(
            'select_centers',
            input_preprocessors.SparseCenterSelector.Params(), '')
        p.preprocessors.Define(
            'gather_features',
            input_preprocessors.SparseCellGatherFeatures.Params(), '')
        p.preprocessors.Define('tile_anchors',
                               input_preprocessors.TileAnchorBBoxes.Params(),
                               '')
        p.preprocessors.Define('assign_anchors',
                               input_preprocessors.AnchorAssignment.Params(),
                               '')
        p.preprocessors.Define(
            'pad_lasers',
            input_preprocessors.PadLaserFeatures.Params().Set(
                max_num_points=max_num_points), '')

        p.preprocessors.viz_copy.pad_lasers.max_num_points = max_num_points
        p.preprocessors.filter_groundtruth.min_num_points = self.GT_MIN_NUM_POINTS

        p.preprocessors.select_centers.num_cell_centers = 1024
        p.preprocessors.gather_features.num_points_per_cell = self.NUM_POINTS_PER_CELL
        p.preprocessors.gather_features.sample_neighbors_uniformly = True
        p.preprocessors.gather_features.max_distance = 2.75

        p.preprocessors.assign_anchors.foreground_assignment_threshold = 0.6
        p.preprocessors.assign_anchors.background_assignment_threshold = 0.45

        p.preprocessors_order = [
            'filter_nlz_points',
            'filter_groundtruth',
            'viz_copy',
            'select_centers',
            'gather_features',
            'tile_anchors',
            'assign_anchors',
            'pad_lasers',
        ]

        # Apply car anchor box settings.
        tile_anchors_p = p.preprocessors.tile_anchors
        self.AnchorBoxSettings.Update(p.preprocessors.tile_anchors)
        num_anchor_configs = self.AnchorBoxSettings.NumAnchors()

        assert len(tile_anchors_p.anchor_box_dimensions) == num_anchor_configs
        assert len(tile_anchors_p.anchor_box_rotations) == num_anchor_configs
        assert len(tile_anchors_p.anchor_box_offsets) == num_anchor_configs

        # If this is not the decoder job (e.g., this is trainer), turn off
        # image decoding, do not count points, and do not make visualization copies.
        if job_type != 'decoder':
            p.preprocessors_order.remove('viz_copy')
            # Do not need laser points during training for current V2 model. This
            # reduces amount of data sent over during training.
            p.preprocessors.pad_lasers.max_num_points = 0

        p.file_buffer_size = 32
        p.file_parallelism = 8
        p.num_batcher_threads = 8
        if self.RUN_LOCALLY:
            p.num_batcher_threads = 1
            p.file_buffer_size = 1
            p.file_parallelism = 1

        if job_type.startswith('trainer'):
            p.batch_size = 2
        else:
            p.batch_size = 4
            p.file_buffer_size = 64
            p.file_parallelism = 16
            p.num_batcher_threads = 16
        return p
Example #30
 def WithInputTargets():
     ret = copy.deepcopy(cluster_factory.Current())
     ret.params.input.targets = 'a,b'
     ret.params.input.replicas = 2
     return ret