def GetParamsForDataset(self, job_name, dataset_name):
  """Returns params for job `job_name` on the dataset `dataset_name`."""
  # Get the current cluster and update its params from flags.
  cluster = cluster_factory.Current()
  self.UpdateClusterParamsFromFlags(cluster.params, job_name)
  with cluster_factory.Cluster(cluster.params):
    try:
      cfg = self.model_registry.GetParams(self._model_name, dataset_name)
    except base_model_params.DatasetError as e:
      dataset_name_retry = dataset_name.title()
      tf.logging.warning(
          'Exception configuring dataset %s, retrying as %s: %s',
          dataset_name, dataset_name_retry, e)
      cfg = self.model_registry.GetParams(self._model_name,
                                          dataset_name_retry)
      tf.logging.warning('Succeeded after retrying as %s.' %
                         dataset_name_retry)
  cfg.cluster = cluster.params

  # Updates a few params based on flags.
  if FLAGS.enqueue_max_steps is not None:
    cfg.train.enqueue_max_steps = FLAGS.enqueue_max_steps
  if FLAGS.saver_max_to_keep is not None:
    cfg.train.save_max_to_keep = FLAGS.saver_max_to_keep
  if FLAGS.saver_keep_checkpoint_every_n_hours is not None:
    cfg.train.save_keep_checkpoint_every_n_hours = (
        FLAGS.saver_keep_checkpoint_every_n_hours)
  return cfg
def AddLaserAndCamera(params):
  """Adds laser and camera extractors."""
  cluster = cluster_factory.Current()
  job = cluster.job
  if job != 'decoder':
    return params

  extractor_params = list(dict(params.extractors.IterParams()).values())
  extractor_classes = [p.cls for p in extractor_params]

  # Add images if not present.
  if kitti_input_generator.KITTIImageExtractor not in extractor_classes:
    params.extractors.Define(
        'images', kitti_input_generator.KITTIImageExtractor.Params(), '')

  # Add raw lasers if not present.
  if kitti_input_generator.KITTILaserExtractor not in extractor_classes:
    labels = None
    for p in extractor_params:
      if p.cls == kitti_input_generator.KITTILabelExtractor:
        labels = p
    if labels is None:
      labels = kitti_input_generator.KITTILabelExtractor.Params()
    params.extractors.Define(
        'lasers', kitti_input_generator.KITTILaserExtractor.Params(labels),
        '')

  return params
def _configure_input(self, p, split):
  p = super(StarNetPedFused, self)._configure_input(p, split)
  job_type = cluster_factory.Current().job

  if job_type.startswith('trainer') and not self.RUN_LOCALLY:
    p.file_buffer_size = 48
    p.file_parallelism = 48
    p.num_batcher_threads = 48

  # Fuses select_centers and gather_features into one sampler.
  p.preprocessors.Define(
      'sampler',
      input_preprocessors.SparseSampler.Params().Set(
          center_selector='farthest',
          neighbor_sampler='uniform',
          num_centers=p.preprocessors.select_centers.num_cell_centers,
          keep_z_range=(0.09522381, 1.720825),
          num_neighbors=p.preprocessors.gather_features.num_points_per_cell,
          max_distance=2.75), '')

  p.preprocessors.Delete('select_centers')
  p.preprocessors.Delete('gather_features')
  p.preprocessors_order.remove('select_centers')
  p.preprocessors_order[p.preprocessors_order.index(
      'gather_features')] = 'sampler'
  return p
def _DecoderDevice(self):
  """Returns the device to run the decoder computation."""
  if py_utils.use_tpu():
    cluster = cluster_factory.Current()
    return tf.device(cluster.WorkerDeviceInModelSplit(1))
  else:
    return tf.device('')
def _FromGlobal(field_name, allow_override_from_cluster=False):
  """Get 'field_name' from a global configuration object.

  Currently the global configuration object used is FLAGS, but this may
  change to Cluster() or an equivalent stack-scoped config object.

  Args:
    field_name: The string field name to look up.
    allow_override_from_cluster: Allow the Cluster() to override FLAGS.

  Returns:
    The value associated with the global configuration string 'field_name'.
  """
  if allow_override_from_cluster:
    cluster = cluster_factory.Current()
    if field_name in cluster.params:
      params_value = cluster.params.Get(field_name)
      # Return the value in the cluster params if it is not None.
      if params_value is not None:
        return params_value

  # Now check the FLAGS object for backwards compatibility.
  #
  # If not explicitly set, get the field from the FLAGS object. If FLAGS
  # have not been parsed yet, the default value of the flag will be used.
  return FLAGS[field_name].value
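# A minimal, self-contained sketch of the precedence that _FromGlobal
# implements: a cluster-level value (when overriding is allowed and the value
# is set) wins over the flag value. The helper and the dictionaries below are
# hypothetical stand-ins for cluster.params and FLAGS, not part of the
# Lingvo API.
def _from_global_sketch(field_name, cluster_params, flag_values,
                        allow_override_from_cluster=False):
  """Returns the cluster value if allowed and not None, else the flag value."""
  if allow_override_from_cluster and field_name in cluster_params:
    value = cluster_params[field_name]
    if value is not None:
      return value
  return flag_values[field_name]

# The cluster overrides the flag only when its value is actually set.
assert _from_global_sketch(
    'xla_device', {'xla_device': 'tpu'}, {'xla_device': 'cpu'},
    allow_override_from_cluster=True) == 'tpu'
assert _from_global_sketch(
    'xla_device', {'xla_device': None}, {'xla_device': 'cpu'},
    allow_override_from_cluster=True) == 'cpu'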
def testTextPackedInputBatchSize(self, use_per_host_infeed, packing_factor):
  p = cluster_factory.Current().params.Copy()
  p.job = 'trainer'
  p.worker.tpus_per_replica = 8
  p.worker.num_tpu_hosts = 16
  p.worker.devices_per_split = 2
  cluster = p.Instantiate()

  with cluster, mock.patch(
      'lingvo.core.py_utils.use_tpu', return_value=True):
    p = input_generator.TextPackedInput.Params()
    p.use_per_host_infeed = use_per_host_infeed
    p.file_random_seed = 0
    p.file_pattern = 'tfrecord:' + test_helper.test_src_dir_path(
        'tasks/mt/testdata/en_fr.tfrecord')
    p.pad_to_max_seq_length = True
    p.tokenizer = tokenizers.AsciiTokenizer.Params()
    p.input_file_type = 'sentence_proto'
    p.source_max_length = 32
    p.target_max_length = 32
    p.bucket_batch_limit = [128]
    p.packing_factor = packing_factor
    with self.session() as sess:
      inp = p.Instantiate()
      # GlobalBatchSize is batch_size (128) * num_splits_per_client (4).
      # num_splits_per_client is 4, because num_splits_per_replica is 4.
      # num_splits_per_replica is 4 because that's tpus_per_replica
      # divided by devices_per_split.
      expected_global_batch_size = (
          p.bucket_batch_limit[0] //
          cluster.params.worker.devices_per_split *
          cluster.params.worker.tpus_per_replica)
      if p.packing_factor is not None:
        expected_global_batch_size = np.math.floor(
            expected_global_batch_size * p.packing_factor)

      expected_infeed_batch_size = expected_global_batch_size
      if use_per_host_infeed:
        expected_infeed_batch_size = (
            expected_global_batch_size //
            cluster.params.worker.num_tpu_hosts)

      expected_packed_infeed_batch_size = expected_infeed_batch_size
      if p.packing_factor is not None:
        expected_packed_infeed_batch_size = np.math.floor(
            expected_infeed_batch_size / p.packing_factor)

      self.assertEqual(expected_global_batch_size, inp.GlobalBatchSize())
      self.assertEqual(expected_infeed_batch_size, inp.InfeedBatchSize())

      batch_tensor = inp.GetPreprocessedInputBatch()
      for k, x in batch_tensor.FlattenItems():
        self.assertTrue(x.shape.is_fully_defined(), k)
      batch = sess.run(batch_tensor)
      self.assertEqual(batch.src.ids.shape,
                       (expected_packed_infeed_batch_size, 32))
def _GetTrainingStatistics(train_input_p):
  """Get training statistics, including total batch size and steps per epoch."""
  cluster = cluster_factory.Current()
  # E.g., this is 1 for a single GPU, 8 for a 2x2 TPU, 32 for a 4x4 TPU,
  # or 0 if no training job is launched.
  total_num_cores = cluster.total_worker_devices
  total_batch_size = max(train_input_p.batch_size * total_num_cores, 1)
  steps_per_epoch = float(train_input_p.num_samples) / total_batch_size
  return py_utils.NestedMap(
      total_num_cores=total_num_cores,
      total_batch_size=total_batch_size,
      steps_per_epoch=steps_per_epoch)
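# Illustrative arithmetic only; the numbers below are hypothetical and are not
# read from a real cluster or input config. On a 4x4 TPU slice the cluster
# reports 32 worker cores, so a per-core batch size of 2 and 100,000 training
# samples give:
per_core_batch_size = 2
total_num_cores = 32
num_samples = 100000

total_batch_size = max(per_core_batch_size * total_num_cores, 1)  # 64
steps_per_epoch = float(num_samples) / total_batch_size  # 1562.5
print(total_batch_size, steps_per_epoch)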
def scaled_bucket_batch_limit(self):
  p = self.params
  if not hasattr(self, '_scaled_bucket_batch_limit'):
    cluster = cluster_factory.Current()
    self._scaled_bucket_batch_limit = [
        b * cluster.num_splits_per_client for b in p.bucket_batch_limit
    ]
    if p.use_per_host_infeed and cluster.num_tpu_hosts > 0:
      self._scaled_bucket_batch_limit = [
          x // cluster.num_tpu_hosts for x in self._scaled_bucket_batch_limit
      ]
  return self._scaled_bucket_batch_limit
def InputBatchSize(self):
  """Returns the batch size for the current step."""
  p = self.params
  cluster = cluster_factory.Current()

  # If use_per_host_infeed, each input op is only responsible
  # for generating a subset of the whole batch.
  batch_per_input = p.batch_size * cluster.num_splits_per_client
  if p.use_per_host_infeed and cluster.num_tpu_hosts > 0:
    tf.logging.info('batch_size %d cluster.num_tpu_hosts %d',
                    batch_per_input, cluster.num_tpu_hosts)
    batch_per_input //= cluster.num_tpu_hosts
  tf.logging.info('batch_per_input: %d', batch_per_input)
  return batch_per_input
def _GetSaver(self):
  """Returns a saver."""
  do_eval = cluster_factory.Current().do_eval
  if not self._save_only and self._model.ema and do_eval:
    tf.logging.info('Using EMA for evaluation.')
    return tf.train.Saver(
        self._model.ema.variables_to_restore(self._model.variables_for_ema))
  return tf.train.Saver(
      sharded=True,
      max_to_keep=self._train_params.save_max_to_keep,
      keep_checkpoint_every_n_hours=(
          self._train_params.save_keep_checkpoint_every_n_hours),
      pad_step_number=True,  # %08d
      write_version=tf.train.SaverDef.V2)
def __init__(self, model_name, split, run_preprocessors):
  self._model_name = model_name
  self._split = split
  self._run_preprocessors = run_preprocessors
  self._sess = None

  # Create a cluster configuration assuming evaluation; the input pipelines
  # need to know the cluster job type to set up the outputs correctly.
  cluster = cluster_factory.Current()
  cluster.params.job = 'evaler'
  cluster.params.mode = 'sync'
  cluster.params.task = 0
  cluster.params.evaler.replicas = 1
  self._cluster = cluster_factory.Cluster(cluster.params)
def GetExecutorParams(self):
  """Get the params needed to instantiate the ExecutorTpu.

  Returns:
    Tuple (dict, params):

    - ps_params_dict: high_level task_name -> ProgramScheduleParams
    - train_cfg: Either a SingleTaskModelParams or MultiTaskModelParams.
  """
  cluster = cluster_factory.Current()
  self.UpdateClusterParamsFromFlags(cluster.params, 'executor_tpu')
  ps_params_dict, train_cfg = executor.GetExecutorParams(
      self._model_name, cluster.params, self.model_registry)
  return ps_params_dict, train_cfg
def _MaybeOverwriteModelVariablesWithEMA(self):
  """Overwrite model variables with EMA shadow variables in eval mode."""
  do_eval = cluster_factory.Current().do_eval
  if not self._save_only and do_eval:
    for model in self._models:
      if not model.ema:
        continue
      tf.logging.info('Using EMA for evaluation.')
      # TODO(jiaweix): this implementation will load both the model variables
      # and EMA variables. As a result the memory usage will be higher than
      # the eval jobs in TF1 mode.
      ema = model.ema
      cur_vars = model.GetVariablesDict()
      for v in cur_vars.values():
        shadow_v = ema.average(v)
        if shadow_v is not None:
          v.assign(shadow_v)
def _GetSaver(self):
  """Returns a saver."""
  do_eval = cluster_factory.Current().do_eval
  variables_to_restore = {}
  if not self._save_only and do_eval:
    for model in self._models:
      if model.ema:
        tf.logging.info('Using EMA for evaluation.')
        variables_to_restore.update(
            model.ema.variables_to_restore(model.variables_for_ema))
  if not variables_to_restore:
    variables_to_restore = None
  return SaverWrapper(
      self._train_dir,
      self._train_params,
      variables_to_restore_dict=variables_to_restore,
      async_save=self.async_checkpointing)
def ConcatenateAcrossReplicas(tensors,
                              tpu_cores=None,
                              axis=0,
                              stop_cross_gradients=False):
  """Concatenates one or more local tensors across all TPU cores.

  Input `tensors` may be in any format supported by `tf.nest.flatten` (single
  Tensor, dict, etc.), or a `NestedMap`.

  In order to avoid having to pass a TPU core ID into the infeed, this
  implementation produces a different rotation of the concatenation for each
  core. For example, core 0 will be arranged as [0, 1, 2, ...], whereas core 3
  will be arranged as [3, 4, ..., 0, 1, 2].

  If called from a non-TPU context, this function returns the `tensors`
  unchanged.

  Args:
    tensors: The local tensor or tensors to concatenate across cores.
    tpu_cores: The total number of TPU cores. If not set, the number of cores
      is inferred from `cluster_factory.Current()`.
    axis: The axis to concatenate.
    stop_cross_gradients: If true, stop gradients on cross-replica slices.

  Returns:
    The tensor(s) concatenated across all replicas.
  """
  if not py_utils.use_tpu():
    return tensors

  if tpu_cores is None:
    cluster = cluster_factory.Current()
    tpu_cores = cluster.tpus_per_replica * cluster.num_replicas
    assert tpu_cores, 'Unable to determine number of TPU cores from cluster.'

  concat_fn = functools.partial(
      CrossReplicaConcat,
      tpu_cores=tpu_cores,
      axis=axis,
      stop_cross_gradients=stop_cross_gradients)
  concatenated_tensors = [concat_fn(t) for t in tf.nest.flatten(tensors)]

  if isinstance(tensors, py_utils.NestedMap):
    return tensors.Pack(concatenated_tensors)
  else:
    return tf.nest.pack_sequence_as(tensors, concatenated_tensors)
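# A minimal, self-contained sketch of the per-core rotation described in the
# docstring above, using plain Python lists rather than the real TPU
# cross-replica op. The helper and the per-core values are hypothetical
# placeholders, not part of the Lingvo API.
def _rotated_concat_sketch(per_core_slices, core_id):
  """Concatenates all slices, starting the ordering at `core_id`'s own slice."""
  return per_core_slices[core_id:] + per_core_slices[:core_id]

per_core_slices = [[0], [1], [2], [3]]  # 4 hypothetical cores, one slice each.
assert _rotated_concat_sketch(per_core_slices, 0) == [[0], [1], [2], [3]]
assert _rotated_concat_sketch(per_core_slices, 3) == [[3], [0], [1], [2]]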
def _configure_input(self, p):
  """Base function managing the delegation of job specific input configs."""
  self._configure_generic_input(p)

  cluster = cluster_factory.Current()
  job = cluster.job
  if job.startswith('trainer'):
    self._configure_trainer_input(p)
  elif job.startswith('decoder'):
    self._configure_decoder_input(p)
  elif job.startswith('evaler'):
    self._configure_evaler_input(p)
  else:
    tf.logging.info('There are no input configuration changes for '
                    'job {}.'.format(job))

  if self.RUN_LOCALLY:
    p.num_batcher_threads = 1
    p.file_buffer_size = 1
    p.file_parallelism = 1
def scale_split_to_infeed(split_batch_size, use_per_host_infeed):
  """Obtains an infeed batch size from a split batch size and cluster configs.

  Args:
    split_batch_size: int: Per-split batch size.
    use_per_host_infeed: bool: Whether to use an individual infeed for each
      host.

  Returns:
    int: Per-infeed batch size.
  """
  cluster = cluster_factory.Current()
  global_batch_size = split_batch_size * cluster.num_splits_per_client
  # If use_per_host_infeed, each input op is only responsible
  # for generating a subset of the whole batch.
  if use_per_host_infeed and cluster.num_tpu_hosts > 0:
    return global_batch_size // cluster.num_tpu_hosts
  else:
    return global_batch_size
def scale_infeed_to_global(infeed_batch_size, use_per_host_infeed):
  """Obtains a global batch size from an infeed batch size and cluster configs.

  Args:
    infeed_batch_size: int: Per-infeed batch size.
    use_per_host_infeed: bool: Whether to use an individual infeed for each
      host.

  Returns:
    int: Global batch size.
  """
  cluster = cluster_factory.Current()
  if use_per_host_infeed and cluster.num_tpu_hosts > 0:
    if not py_utils.use_tpu():
      raise ValueError('Scaling to TPU hosts without TPUs. {}'.format(
          cluster.num_tpu_hosts))
    return infeed_batch_size * cluster.num_tpu_hosts
  else:
    return infeed_batch_size
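# Illustrative arithmetic for the two scaling helpers above, with hypothetical
# cluster values (num_splits_per_client=8, num_tpu_hosts=2, per-host infeed
# enabled); it does not call into a real Cluster object.
split_batch_size = 4
num_splits_per_client = 8
num_tpu_hosts = 2

global_batch_size = split_batch_size * num_splits_per_client  # 32
infeed_batch_size = global_batch_size // num_tpu_hosts  # 16 per host
# scale_infeed_to_global inverts the per-host division:
assert infeed_batch_size * num_tpu_hosts == global_batch_size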
def __init__(self, p):
  if self._VALIDATE_BATCH_SIZE_NONE and p.batch_size is not None:
    raise ValueError(
        'LingvoInputAdaptor does not support p.batch_size. '
        'Please specify batch size on p.input, e.g. with '
        'p.input.bucket_batch_limit = [4] or '
        'p.input.args.batch=4, depending on the Lingvo input '
        f'used. Currently: p.batch_size={p.batch_size}, '
        'it must be None.')
  super().__init__(p)
  self._cluster = copy.deepcopy(cluster_factory.Current())
  # For Lingvo's Cluster context that may impact the behavior of this input
  # generator, we always set use_tpu to True, and optionally set do_eval
  # for non-training data when configured to do so. All other Cluster params
  # use the default value.
  self._cluster.params.xla_device = 'tpu'
  self._cluster.params.enable_asserts = False
  # This indirectly sets cluster.require_sequential_input_order as well.
  self._cluster.params.do_eval = (not p.is_training and p.cluster_do_eval)
  self._initialize()
def __init__(self, params):
  super(LinearRampupExponentialDecayScaledByNumSplitSchedule,
        self).__init__(params)
  p = self.params

  # We always compute lr schedule from the trainer's perspective.
  # Also note that this schedule makes sense for sync training only.
  if p.num_splits:
    splits = p.num_splits
  else:
    # Infer num_splits from cluster.
    cluster_params = cluster_factory.Current().params.Copy()
    cluster_params.task = 0
    assert cluster_params.mode == 'sync'
    cluster_params.job = 'trainer_client'
    my_cluster = cluster_params.cls(cluster_params)
    splits = my_cluster.num_splits_per_client

  warmup_end = p.warmup * splits
  decay_start = max(warmup_end + 1.0, p.decay_start / splits)
  peak = 1.0 * splits
  tf.logging.info('Peak lr: %f', peak)
  decay_end = max(decay_start + 1.0, p.decay_end / splits)
  schedules = [
      LinearLearningRateSchedule.Params().Set(
          start=(0., p.warmup_init), limit=(warmup_end, peak)),
      LinearLearningRateSchedule.Params().Set(
          start=(warmup_end, peak), limit=(decay_start, peak)),
      ExponentialLearningRateSchedule.Params().Set(
          start=(decay_start, peak), limit=(decay_end, p.min)),
      LinearLearningRateSchedule.Params().Set(
          start=(0, p.max), limit=(decay_end, p.max)),
  ]
  self.CreateChild(
      'combine',
      CombinedMinimumLearningRateSchedule.Params().Set(schedules=schedules))
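# Illustrative arithmetic only, using hypothetical schedule params
# (warmup=150, decay_start=6000, decay_end=60000) and a cluster that yields
# splits=8; these numbers are not taken from any real configuration. The
# ramp/decay boundaries shrink and the peak scale grows with the number of
# splits, as in the constructor above.
splits = 8.0
warmup, decay_start_p, decay_end_p = 150.0, 6000.0, 60000.0

warmup_end = warmup * splits  # 1200.0
decay_start = max(warmup_end + 1.0, decay_start_p / splits)  # 1201.0
peak = 1.0 * splits  # relative lr scale of 8.0
decay_end = max(decay_start + 1.0, decay_end_p / splits)  # 7500.0
print(warmup_end, decay_start, peak, decay_end)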
def __init__(self, params):
  super(LinearRampupPiecewiseConstantSchedule, self).__init__(params)
  p = self.params
  assert len(p.boundaries) >= 2 and len(p.boundaries) == len(p.lrs)

  # We always compute lr schedule from the trainer's perspective.
  # Also note that this schedule makes sense for sync training only.
  if p.num_splits:
    splits = p.num_splits
  else:
    # Infer num_splits from cluster.
    cluster_params = cluster_factory.Current().params.Copy()
    cluster_params.task = 0
    assert cluster_params.mode == 'sync'
    cluster_params.job = 'trainer_client'
    my_cluster = cluster_params.cls(cluster_params)
    splits = my_cluster.num_splits_per_client
  assert splits >= 1
  splits = float(splits)

  boundaries = [step / splits for step in p.boundaries]
  lrs = [step * splits for step in p.lrs]

  tf.logging.info('splits: {}\n boundaries: {}\n lrs: {} '.format(
      splits, boundaries, lrs))

  schedules = [
      LinearLearningRateSchedule.Params().Set(
          start=(0., 0.), limit=(boundaries[0], lrs[0])),
      PiecewiseConstantLearningRateSchedule.Params().Set(
          boundaries=boundaries, values=[1e8] + lrs)
  ]
  self.CreateChild(
      'combine',
      CombinedMinimumLearningRateSchedule.Params().Set(schedules=schedules))
def cluster(self):
  """Returns the current cluster configuration."""
  return cluster_factory.Current()
def _ShouldAddSummary():
  return cluster_factory.Current().add_summary
def FProp(self, theta):
  """Forward propagation.

  This default `FProp` implementation here supports batch splitting in
  synchronous and asynchronous training when sub-classes implement
  `FPropTower`.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.

  Returns:
    A dict containing metrics pairs. One of the keys should be 'loss' and its
    value should be a (loss, num_predictions) pair.
  """
  p = self.params
  cluster = cluster_factory.Current()

  with tf.name_scope('fprop'), tf.name_scope(p.name):
    all_fprop_metrics = []

    if py_utils.use_tpu():
      batch = self.input_generator.CreateTpuFeeds()
      with tf.name_scope('tower_0_0'):
        dec_metrics = self.FPropTower(theta, batch)
      all_fprop_metrics.append(dec_metrics)
    else:
      # Splits the input batch on the input device.
      num_splits = cluster.num_splits_per_client
      with tf.device(cluster.input_device):
        batches = self.input_generator.SplitInputBatch(num_splits)
        assert num_splits == len(batches)

      # dev_list_per_replica[i][j] is the i-th worker's j-th device.
      dev_list_per_replica = cluster.available_devices.tolist()

      # Asserts invariant of the total number of splits w.r.t. splits per
      # worker.
      splits_per_replica = cluster.num_splits_per_replica
      assert num_splits == splits_per_replica * len(dev_list_per_replica)

      for w_id, w_devs in enumerate(dev_list_per_replica):
        # Make local copy of the vars, shard on devices for this worker.
        theta_local = py_utils.CreateLocalTheta(
            theta, w_devs, label='worker %d' % w_id)

        for s_id in range(splits_per_replica):
          # s_id-th split for the w_id-th worker.
          split_id = splits_per_replica * w_id + s_id
          with py_utils.ModelSplit(split_id):
            with tf.device(cluster.WorkerDeviceInModelSplit(0)):
              with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
                batch = self.input_generator.PreprocessInputBatch(
                    batches[split_id])
                dec_metrics = self.FPropTower(theta_local, batch)
          all_fprop_metrics.append(dec_metrics)

    metrics = py_utils.WeightedAvgOfMetrics(all_fprop_metrics)

  # Adds stats about the input batch.
  metrics['num_samples_in_batch'] = (tf.convert_to_tensor(
      self.input_generator.InputBatchSize()), tf.constant(1.0))
  # Generates summaries.
  for name, (value, weight) in six.iteritems(metrics):
    self.AddEvalMetric(name, value, weight)
  # Loss.
  self._loss, self._num_predicts = metrics['loss']
  self._loss = py_utils.CheckNumerics(self._loss)
  return metrics
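# A small standalone sketch of the split indexing used in the non-TPU branch
# above: with hypothetical values splits_per_replica=2 and 3 workers, splits
# are enumerated worker-major, i.e. split_id = splits_per_replica * w_id + s_id.
splits_per_replica = 2
num_workers = 3
split_ids = [
    splits_per_replica * w_id + s_id
    for w_id in range(num_workers)
    for s_id in range(splits_per_replica)
]
assert split_ids == [0, 1, 2, 3, 4, 5]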
def __init__(self, params):
  assert issubclass(params.cls, BaseTask)
  super(BaseTask, self).__init__(params)

  p = self.params

  if p.input:
    # TODO(zhifengc): Consider a simpler way to ensure the input
    # generator stops after one epoch.
    if p.is_eval and p.eval:
      seq_inp = issubclass(p.input.cls,
                           base_input_generator.BaseSequenceInputGenerator)
      if p.input.num_samples == 0:
        # Dataset size is unknown. Computes eval summary based on
        # num_samples.
        assert p.eval.samples_per_summary > 0
      elif (p.eval.samples_per_summary == 0) or (
          p.input.num_samples < p.eval.samples_per_summary):
        # If we know the dataset size and we want to evaluate the full
        # set, we need to coordinate the input generator to flush out
        # all samples so the evaler and decoder compute metrics on the
        # whole set for each summary step.
        if seq_inp:
          p.input.flush_every_n = p.input.num_samples
        p.eval.samples_per_summary = p.input.num_samples
      if seq_inp and p.input.num_batcher_threads > 1:
        tf.logging.warning('input.num_batcher_threads > 1 inside eval mode. '
                           'The input generator may not iterate over exactly '
                           'one epoch per run')

    cluster = cluster_factory.Current()
    with tf.device(cluster.input_device), py_utils.outside_all_rewrites():
      self.CreateChild('input', p.input)

  self._var_grads = None
  self._encoder = None
  self._online_encoder = None
  self._decoder = None

  self._total_examples = None
  self._total_nans_and_infs = None
  self._loss = None
  self._num_predictions = None
  self._train_op = None
  self._eval_metrics = {}
  self._trainer_verbose_tensors = {}

  # Create the gradient mask.
  self._per_input_gradient_mask = None
  self._shared_global_step = py_utils.GetOrCreateGlobalStep()
  tp = p.train
  if tp:
    if tp.task_global_step:
      self._task_global_step = py_utils.CreateTaskGlobalStep(p, p.name)
      self._global_step = self._task_global_step
    else:
      self._task_global_step = None
      self._global_step = self._shared_global_step
    if tp.grad_norm_tracker:
      with tf.variable_scope(p.name):
        self.CreateChild('grad_norm_tracker', tp.grad_norm_tracker)
    self.CreateChild('lr_schedule', tp.lr_schedule)
  self._UpdateVnConfig()
def ComputePredictions(self, theta, source_encs, source_paddings, targets,
                       src_segment_id):
  """Decodes `targets` given encoded source.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    source_encs: source encoding, of shape [time, batch, depth].
    source_paddings: source encoding's padding, of shape [time, batch].
    targets: A dict of string to tensors representing the targets one tries
      to predict. Each tensor in targets is of shape [batch, time].
    src_segment_id: source segment id, of shape [time, batch].

  Returns:
    A Tensor with shape [time, batch, params.softmax.input_dim].
  """
  p = self.params
  time, batch = py_utils.GetShape(source_paddings, 2)
  source_encs = py_utils.HasShape(source_encs, [time, batch, p.source_dim])
  with tf.name_scope(p.name):
    target_ids = tf.transpose(targets.ids)
    target_paddings = py_utils.HasRank(targets.paddings, 2)
    target_paddings = tf.expand_dims(tf.transpose(target_paddings), 2)
    if p.packed_input:
      target_segment_id = tf.expand_dims(
          tf.transpose(targets.segment_ids), 2)
    else:
      target_segment_id = tf.zeros_like(target_paddings)

    if py_utils.use_tpu():
      emb_device = cluster_factory.Current().WorkerDeviceInModelSplit(0)
    else:
      emb_device = ''
    with tf.device(emb_device):
      inputs = self.emb.EmbLookup(theta.emb, target_ids)
      inputs = self.ApplyClipping(theta, inputs)
      summary_utils.histogram(p, 'input_emb', inputs)
      inputs = self.ApplyDropout(inputs)
      self._emb_out = inputs

      # Layer 0 intertwines with attention.
      (atten_ctxs, xs, atten_probs, _) = self.frnn_with_atten.FProp(
          theta.frnn_with_atten,
          source_encs,
          source_paddings,
          inputs,
          target_paddings,
          src_segment_id=src_segment_id,
          segment_id=target_segment_id)

      if p.add_summary:
        self._AddAttenProbsSummary(source_paddings, targets, [atten_probs])

      atten_ctxs = self.ApplyClipping(theta, atten_ctxs)
      summary_utils.histogram(p, 'atten_ctxs', atten_ctxs)

      for i, (layer, layer_theta) in enumerate(zip(self.frnn, theta.frnn)):
        # Forward through Layer-(i + 1) because Layer-0 handled before.
        ys, _ = layer.FProp(
            layer_theta,
            tf.concat([xs, atten_ctxs], 2),
            target_paddings,
            segment_id=target_segment_id)
        ys = self.ApplyDropout(ys)
        if 1 + i >= p.residual_start:
          xs += ys  # Residual skip
          xs = self.ApplyClipping(theta, xs)
        else:
          xs = ys
        summary_utils.histogram(p, 'layer_out_%s' % i, xs)

      if p.feed_attention_context_vec_to_softmax:
        xs = tf.concat([xs, atten_ctxs], 2)

      return xs
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch."""
  p = self.params
  cluster = cluster_factory.Current()
  num_tpu_hosts = cluster.num_tpu_hosts
  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = tpu_function.get_tpu_context(
    ).number_of_shards // num_infeed_hosts
    input_ops_list = []
    queues = []
    first_batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if first_batch is None:
          first_batch = batch
        flat_batch = batch.FlattenItems()

        shapes, types = [], []
        for k, x in flat_batch:
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
          shapes.append(x.shape)
          types.append(x.dtype)
        q = tf.contrib.tpu.InfeedQueue(tuple_types=types, tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)

        if p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def _tpu_ordinal_function(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              [v for _, v in flat_batch],
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=_tpu_ordinal_function)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              [v for _, v in flat_batch],
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
  self._made_tpu_infeed = True
  # Let trainer.py use multiple threads to drive the infeed op.
  for _ in range(p.tpu_infeed_parallism):
    tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

  with tf.device(tf.contrib.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  return first_batch.Pack(tensors)
def Task(self):
  metadata = waymo_metadata.WaymoMetadata()
  num_classes = len(metadata.ClassNames())
  p = starnet.ModelV2.Params(
      num_classes,
      num_anchor_bboxes_offsets=self.NUM_ANCHOR_BBOX_OFFSETS,
      num_anchor_bboxes_rotations=self.NUM_ANCHOR_BBOX_ROTATIONS,
      num_anchor_bboxes_dimensions=self.NUM_ANCHOR_BBOX_DIMENSIONS,
      num_laser_features=3)

  # Update the Point Cloud Featurizer architecture.
  starnet_builder = starnet.Builder()
  starnet_builder.linear_params_init = (
      py_utils.WeightInit.KaimingUniformFanInRelu())

  gin_layers = [
      [self.GIN_HIDDEN_DIMS * 2, self.GIN_HIDDEN_DIMS * 4,
       self.GIN_HIDDEN_DIMS]
  ] * self.NUM_GIN_LAYERS  # pyformat: disable

  p.cell_featurizer = starnet_builder.GINFeaturizerV2(
      'feat',
      num_laser_features=3,
      fc_dims=self.GIN_HIDDEN_DIMS,
      mlp_dims=gin_layers,
      fc_use_bn=False)
  p.cell_feature_dims = self.GIN_HIDDEN_DIMS * (self.NUM_GIN_LAYERS + 1)

  p.output_decoder = waymo_decoder.WaymoOpenDatasetDecoder.Params()
  p.max_nms_boxes = 512
  p.use_oriented_per_class_nms = True

  # Note: Sub-classes need to set nms_iou_threshold and nms_score_threshold
  # appropriately.
  p.nms_iou_threshold = [0.0] * num_classes

  # TODO(jngiam): 1.1 for untrained classes is needed to avoid an issue
  # with boxutils error.
  p.nms_score_threshold = [1.1] * num_classes

  p.name = 'starnet'

  tp = p.train
  tp.optimizer = optimizer.Adam.Params()
  tp.clip_gradient_norm_to_value = 5

  ep = p.eval
  # Train set uses a smaller decoding set, so we can
  # safely eval over the entire input.
  ep.samples_per_summary = 0

  # To be tuned.
  p.train.l2_regularizer_weight = 1e-8

  cluster = cluster_factory.Current()
  train_cluster_p = cluster.params.Copy()
  train_cluster_p.job = 'trainer_client'
  train_cluster_p.mode = 'sync'

  # When running a decoding only job, there are no trainer workers, so we set
  # worker replicas to 1 as a dummy value.
  if train_cluster_p.worker.replicas <= 0:
    train_cluster_p.worker.replicas = 1

  # Set learning rate and schedule.
  with cluster_factory.Cluster(train_cluster_p):
    train_input_p = self.Train()

  # Adapted from V1 tuning.
  tp.ema_decay = 0.99
  # TODO(b/148537111): consider setting this to True.
  tp.ema_decay_moving_vars = False
  tp.learning_rate = 0.001
  lr_util.SetExponentialLR(
      train_p=tp,
      train_input_p=train_input_p,
      exp_start_epoch=5,
      total_epoch=75)

  p.dimension_loss_weight = .3
  p.location_loss_weight = 3.
  p.loss_weight_classification = 1.
  p.loss_weight_localization = 3.
  p.rotation_loss_weight = 0.3

  return p
def _configure_input(self, p, split):
  p.file_pattern_prefix = _WAYMO_BASE

  job_type = cluster_factory.Current().job
  max_num_points = int(64 * 2650 * 1.5)

  p.preprocessors = hyperparams.Params()
  p.preprocessors.Define(
      'filter_nlz_points',
      waymo_open_input_generator.FilterNLZPoints.Params(), '')
  # TODO(bencaine): Change this to filter based on difficulty instead
  p.preprocessors.Define(
      'filter_groundtruth',
      input_preprocessors.FilterGroundTruthByNumPoints.Params(), '')
  p.preprocessors.Define('viz_copy',
                         input_preprocessors.CreateDecoderCopy.Params(), '')
  p.preprocessors.Define('select_centers',
                         input_preprocessors.SparseCenterSelector.Params(),
                         '')
  p.preprocessors.Define(
      'gather_features',
      input_preprocessors.SparseCellGatherFeatures.Params(), '')
  p.preprocessors.Define('tile_anchors',
                         input_preprocessors.TileAnchorBBoxes.Params(), '')
  p.preprocessors.Define('assign_anchors',
                         input_preprocessors.AnchorAssignment.Params(), '')
  p.preprocessors.Define(
      'pad_lasers',
      input_preprocessors.PadLaserFeatures.Params().Set(
          max_num_points=max_num_points), '')

  p.preprocessors.viz_copy.pad_lasers.max_num_points = max_num_points
  p.preprocessors.filter_groundtruth.min_num_points = self.GT_MIN_NUM_POINTS
  p.preprocessors.select_centers.num_cell_centers = 1024
  p.preprocessors.gather_features.num_points_per_cell = (
      self.NUM_POINTS_PER_CELL)
  p.preprocessors.gather_features.sample_neighbors_uniformly = True
  p.preprocessors.gather_features.max_distance = 2.75
  p.preprocessors.assign_anchors.foreground_assignment_threshold = 0.6
  p.preprocessors.assign_anchors.background_assignment_threshold = 0.45

  p.preprocessors_order = [
      'filter_nlz_points',
      'filter_groundtruth',
      'viz_copy',
      'select_centers',
      'gather_features',
      'tile_anchors',
      'assign_anchors',
      'pad_lasers',
  ]

  # Apply car anchor box settings.
  tile_anchors_p = p.preprocessors.tile_anchors
  self.AnchorBoxSettings.Update(p.preprocessors.tile_anchors)
  num_anchor_configs = self.AnchorBoxSettings.NumAnchors()

  assert len(tile_anchors_p.anchor_box_dimensions) == num_anchor_configs
  assert len(tile_anchors_p.anchor_box_rotations) == num_anchor_configs
  assert len(tile_anchors_p.anchor_box_offsets) == num_anchor_configs

  # If this is not the decoder job (e.g., this is trainer), turn off
  # image decoding, do not count points, and do not make visualization
  # copies.
  if job_type != 'decoder':
    p.preprocessors_order.remove('viz_copy')
    # Do not need laser points during training for current V2 model. This
    # reduces amount of data sent over during training.
    p.preprocessors.pad_lasers.max_num_points = 0

  p.file_buffer_size = 32
  p.file_parallelism = 8
  p.num_batcher_threads = 8
  if self.RUN_LOCALLY:
    p.num_batcher_threads = 1
    p.file_buffer_size = 1
    p.file_parallelism = 1

  if job_type.startswith('trainer'):
    p.batch_size = 2
  else:
    p.batch_size = 4
    p.file_buffer_size = 64
    p.file_parallelism = 16
    p.num_batcher_threads = 16
  return p
def WithInputTargets():
  ret = copy.deepcopy(cluster_factory.Current())
  ret.params.input.targets = 'a,b'
  ret.params.input.replicas = 2
  return ret