def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params
  load_op_list = []
  retrieve_op_list = []
  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']
  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      w_ada = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_accumulator),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name = tpu_embedding_table.GetVariableName(host_id) + '/Adagrad'
      tpu_embedding_table.CreateVariable(var_name, w_ada, trainable=False)
      accumulator_var = tpu_embedding_table.vars[var_name]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # Remove the slot vars from the variable list to avoid copying them
        # to TPU (by the tf.cast in tpu_embedding_table.theta).
        # pylint: disable=protected-access
        del tpu_embedding_table._private_vars[var_name]
        del tpu_embedding_table._private_theta[var_name]
        # pylint: enable=protected-access

        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.load_tpu_embedding_adagrad_parameters(
                  parameters=table_var,
                  accumulators=accumulator_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_accumulator = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_adagrad_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(accumulator_var, retrieved_accumulator))
          retrieve_op_list.append(retrieve_parameters_op)
  return load_op_list, retrieve_op_list

def GetNext(self):
  """Returns the next element from the dataset."""
  # Use `init_scope()` to ensure that the datasets and iterators are created
  # outside of the function-building graph. This ensures that these creation
  # operations are not repeated in subsequent `tf.function` calls.
  with tf.init_scope():
    self._InitIterator()
    if py_utils.GetUnitTestSession():
      self.Initialize(py_utils.GetUnitTestSession())
  return self._iterator[self.host_id].get_next()

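The GetNext above relies on `tf.init_scope()` to lift dataset and iterator creation out of the function-building graph. Below is a minimal, self-contained sketch of the same pattern in plain TF2; the `_LiftedIterator` class and its names are illustrative, not Lingvo APIs. The iterator is created once, in the outer eager context during the first trace, and reused by every subsequent `tf.function` call.

import tensorflow as tf


class _LiftedIterator:
  """Illustrative stand-in for a data source; not a Lingvo class."""

  def __init__(self):
    self._iterator = None

  def get_next(self):
    # init_scope() escapes the surrounding FuncGraph, so the dataset and
    # iterator are built in the eager context on the first trace and then
    # reused, rather than being recreated on every trace.
    with tf.init_scope():
      if self._iterator is None:
        self._iterator = iter(tf.data.Dataset.range(10).repeat())
    return self._iterator.get_next()


source = _LiftedIterator()


@tf.function
def step():
  return source.get_next()


print(step().numpy(), step().numpy())  # 0 1: one iterator shared by all calls.
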
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  if self._num_micro_batches == 1:
    return self._opt.apply_gradients(grads_and_vars, global_step)
  global_step = global_step or py_utils.GetOrCreateGlobalStepVar()
  with tf.init_scope():
    self._create_slots([v for (_, v) in grads_and_vars])

  accums = []
  variables = []

  for g, v in grads_and_vars:
    accum = self.get_slot(v, 'grad_accum')
    variables.append(v)
    # pytype: disable=attribute-error
    if isinstance(g, tf.IndexedSlices):
      scaled_grad = tf.IndexedSlices(
          g.values / self._num_micro_batches,
          g.indices,
          dense_shape=g.dense_shape)
    else:
      scaled_grad = g / self._num_micro_batches
    accum_tensor = accum.read_value()
    accums.append(accum.assign(accum_tensor + scaled_grad))
    # pytype: enable=attribute-error

  def _ApplyAndReset():
    normalized_accums = accums
    if self._apply_crs_to_grad:
      normalized_accums = [
          tf.tpu.cross_replica_sum(accum.read_value()) for accum in accums
      ]
    apply_op = self._opt.apply_gradients(
        list(zip(normalized_accums, variables)))
    with tf.control_dependencies([apply_op]):
      zero_op = [tf.assign(accum, tf.zeros_like(accum)) for accum in accums]
    return tf.group(zero_op, tf.assign_add(global_step, 1))

  def _Accum():
    return tf.no_op()

  accum_step = tf.cond(
      tf.equal(
          tf.math.floormod(self._counter + 1, self._num_micro_batches), 0),
      _ApplyAndReset,  # Apply the accumulated gradients and reset.
      _Accum)  # Accumulate gradients.

  with tf.control_dependencies([tf.group(accums)]):
    return tf.group(accum_step, tf.assign_add(self._counter, 1))

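The accumulation logic above scales each micro-batch gradient by 1/N, adds it into a per-variable slot, and only lets the wrapped optimizer apply (then zeros the slots and bumps the global step) on every N-th call. Below is a toy eager-mode sketch of that idea with made-up names (`N`, `accumulate_and_maybe_apply`) and a Keras SGD standing in for the wrapped optimizer; it is not the Lingvo class shown above.

import tensorflow as tf

N = 4  # micro-batches per effective optimizer step (assumed value)
var = tf.Variable(1.0)
accum = tf.Variable(0.0)  # plays the role of the 'grad_accum' slot
counter = tf.Variable(0, dtype=tf.int64)
opt = tf.keras.optimizers.SGD(learning_rate=0.1)


def accumulate_and_maybe_apply(grad):
  accum.assign_add(grad / N)            # scale by 1/num_micro_batches
  counter.assign_add(1)
  if counter % N == 0:                  # same condition as the tf.cond above
    opt.apply_gradients([(accum.read_value(), var)])
    accum.assign(tf.zeros_like(accum))  # reset the slot after applying


for _ in range(8):
  accumulate_and_maybe_apply(tf.constant(1.0))
print(var.numpy())  # ~0.8: two effective updates of 0.1 * mean_grad each
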
def _BuildTpuEmbeddingApi():
  load_op_list = []
  retrieve_op_list = []

  num_cores = self.cluster.params.worker.tpus_per_replica
  global_batch_size = (
      self.params.batch_size * self.cluster.num_splits_per_client)
  table_to_config_dict = {}
  feature_to_config_dict = {}
  for table in self.tables:
    table_to_config_dict[table.table_name] = table.table_config
    load_op_list += table.load_op_list
    retrieve_op_list += table.retrieve_op_list
    for feature in table.input_keys:
      feature_to_config_dict[feature] = tpu_embedding_lib.FeatureConfig(
          table.table_name, max_sequence_length=table.max_sequence_length)

  mode = tpu_embedding_lib.TRAINING
  device_config = tpu_embedding_lib.DeviceConfig(
      num_cores=num_cores,
      num_hosts=self.params.tables[0].num_tpu_hosts,
      job_name=self.cluster.params.worker.name)
  tpu_embedding = tpu_embedding_lib.TPUEmbedding(
      table_to_config_dict,
      feature_to_config_dict,
      global_batch_size,
      mode,
      master=None,
      pipeline_execution_with_tensor_core=(
          self.params.pipeline_execution_with_tensor_core),
      partition_strategy=p.partition_strategy,
      device_config=device_config)

  with tf.init_scope():
    dummy_variables, dummy_variables_init = (
        tpu_embedding_gradient.create_dummy_table_variables(tpu_embedding))
  load_op_list += [dummy_variables_init]

  tf.add_to_collection(py_utils.TPU_EMBEDDING, tpu_embedding)
  tf.add_to_collection(py_utils.TPU_EMBEDDING_DUMMY_VARS, dummy_variables)
  tf.add_to_collection(py_utils.TPU_EMBEDDING_LOAD_OPS, load_op_list)
  tf.add_to_collection(py_utils.TPU_EMBEDDING_RETRIEVE_OPS, retrieve_op_list)

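The collections populated above are what the trainer consumes. A rough, hedged sketch of that consumption in TF1-style graph/session code follows; the session plumbing, step count, and training op are placeholders, and `tf` is assumed to be the TF1-compatible alias Lingvo uses. Load ops push the host-side tables and slot variables into the TPU embedding engine before training; retrieve ops pull them back so checkpoints see the trained values.

from lingvo import compat as tf  # Lingvo's TF alias; assumed to expose TF1 APIs
from lingvo.core import py_utils

load_ops = tf.get_collection(py_utils.TPU_EMBEDDING_LOAD_OPS)
retrieve_ops = tf.get_collection(py_utils.TPU_EMBEDDING_RETRIEVE_OPS)
train_op = tf.no_op()  # stands in for the real training op

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(load_ops)      # push host-side tables and slots to the TPU embedding engine
  for _ in range(100):    # assumed number of steps
    sess.run(train_op)
  sess.run(retrieve_ops)  # pull trained tables and slots back before checkpointing
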
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  load_op_list = []
  retrieve_op_list = []
  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.
              load_tpu_embedding_stochastic_gradient_descent_parameters(
                  parameters=table_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table = (
              tpu_embedding_lib.tpu_ops.
              retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table))
          retrieve_op_list.append(retrieve_parameters_op)
  return load_op_list, retrieve_op_list

def GetNext(self):
  """Override of the root's GetNext to support checking repeat sentinel."""
  # Use `init_scope()` to ensure that the datasets and iterators are created
  # outside of the function-building graph. This ensures that these creation
  # operations are not repeated in subsequent `tf.function` calls.
  with tf.init_scope():
    self._InitIterator()
    if py_utils.GetUnitTestSession():
      self.Initialize(py_utils.GetUnitTestSession())
  batch = self._iterator[self.host_id].get_next()
  # Sentinel check.
  if self._repeat_with_sentinel and not self._repeat_steps:
    assert_op = tf.debugging.assert_none_equal(
        batch[self.params.sentinel_key],
        tf.constant(self.params.sentinel_value),
        summarize=1,
        message='REPEAT_SENTINEL_')
    tf.logging.info('sentinel constant dtype %r',
                    tf.constant(self.params.sentinel_value))
    with tf.control_dependencies([assert_op]):
      # This identity transform will throw tf.errors.InvalidArgumentError
      # if assert_op fails (sentinel_key takes on sentinel_value).
      batch = batch.Transform(tf.identity)
  return batch

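A hedged sketch of how a caller might rely on this sentinel behavior: the dataset repeats indefinitely, but once a batch carries `sentinel_value` in `sentinel_key`, the assert fires and the fetch raises `tf.errors.InvalidArgumentError`, which the loop can treat as end-of-epoch. The `input_generator` object here is assumed to be an instance of the class above, not defined in this section.

import tensorflow as tf

num_batches = 0
try:
  while True:
    batch = input_generator.GetNext()  # input_generator: assumed instance of the class above
    num_batches += 1
except tf.errors.InvalidArgumentError as e:
  # The message contains 'REPEAT_SENTINEL_' per the assert_none_equal above.
  print('Epoch finished after %d batches: %s' % (num_batches, e.message))
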
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params
  load_op_list = []
  retrieve_op_list = []
  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']
  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      accumulator = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_accumulator_value),
          dtype=p.dtype,
          collections=slot_var_collections)
      accumulator_name = (
          tpu_embedding_table.GetVariableName(host_id) + '/Ftrl')
      tpu_embedding_table.CreateVariable(
          accumulator_name, accumulator, trainable=False)
      accumulator_var = tpu_embedding_table.vars[accumulator_name]

      linear = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_linear_value),
          dtype=p.dtype,
          collections=slot_var_collections)
      linear_name = tpu_embedding_table.GetVariableName(host_id) + '/Ftrl_1'
      tpu_embedding_table.CreateVariable(linear_name, linear, trainable=False)
      linear_var = tpu_embedding_table.vars[linear_name]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # Remove the slot vars from the variable list to avoid them being
        # copied to TPU.
        _RemovePrivateVar(tpu_embedding_table, accumulator_name)
        _RemovePrivateVar(tpu_embedding_table, linear_name)

        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.load_tpu_embedding_ftrl_parameters(
                  parameters=table_var,
                  accumulators=accumulator_var,
                  linears=linear_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_accumulator, retrieved_linear = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_ftrl_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(accumulator_var, retrieved_accumulator),
              tf.assign(linear_var, retrieved_linear))
          retrieve_op_list.append(retrieve_parameters_op)
  return load_op_list, retrieve_op_list

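`_RemovePrivateVar` is not defined in this section. Judging from the inline `del` statements in the Adagrad version above, it presumably looks roughly like the following sketch:

def _RemovePrivateVar(layer, var_name):
  """Removes a slot var created by CreateVariable() so it is not copied to TPU via theta."""
  # pylint: disable=protected-access
  del layer._private_vars[var_name]
  del layer._private_theta[var_name]
  # pylint: enable=protected-access
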
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params
  load_op_list = []
  retrieve_op_list = []
  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']
  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      m_adam = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(0.0),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name_m = tpu_embedding_table.GetVariableName(host_id) + '/Adam/m'
      tpu_embedding_table.CreateVariable(var_name_m, m_adam, trainable=False)
      m_var = tpu_embedding_table.vars[var_name_m]

      v_adam = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(0.0),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name_v = tpu_embedding_table.GetVariableName(host_id) + '/Adam/v'
      tpu_embedding_table.CreateVariable(var_name_v, v_adam, trainable=False)
      v_var = tpu_embedding_table.vars[var_name_v]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # Remove the slot vars from the variable list to avoid them being
        # copied to TPU.
        _RemovePrivateVar(tpu_embedding_table, var_name_m)
        _RemovePrivateVar(tpu_embedding_table, var_name_v)

        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.load_tpu_embedding_adam_parameters(
                  parameters=table_var,
                  momenta=m_var,
                  velocities=v_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_m, retrieved_v = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_adam_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(m_var, retrieved_m),
              tf.assign(v_var, retrieved_v))
          retrieve_op_list.append(retrieve_parameters_op)
  return load_op_list, retrieve_op_list

def GenericInputV2Create(processor, **kwargs):
  # pyformat: disable
  """Builds a generic input pipeline with an explicit resource handle.

  The resource handle uniquely identifies each GenericInputV2 dataset. This
  handle is passed to GenericInputV2GetNext to get a batch.

  Example usage:

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.io.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is a scalar convertible to tf.int32.
      # Use 1 if all examples are of the same size.
      bucketing_key = 1
      return example, bucketing_key

    resource, out_types, output_tmpl = GenericInputV2Create(ParseRecord, ...)
    input_batch, ... = GenericInputV2GetNext(resource, out_types, output_tmpl)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  ParseRecord can also take both 'source_id' and 'record' as inputs (the arg
  names must be exactly 'source_id' and 'record'):

    def ParseRecord(source_id, record):
      # Given a tf.int32 source_id and a tf.string record, return a
      # (NestedMap, bucketing key) pair.
      example = py_utils.NestedMap(source_id=source_id, ...)
      ...
      return example, bucketing_key

    resource, out_types, output_tmpl = GenericInputV2Create(
        ParseRecord, file_pattern=..., ...)
    input_batch, bucket_keys = GenericInputV2GetNext(
        resource, out_types, output_tmpl)

  Args:
    processor: a function that takes either a tf.string record or a
      (source_id: tf.int32, record: tf.string) pair as input and returns a
      tuple (output, bucketing_key). `output` must be a NestedMap or a list of
      tensors representing an example. `bucketing_key` must be a scalar
      convertible to a tf.int32 tensor that represents the bucketing key
      (e.g., sequence length for sequence inputs). If `bucketing_key` is a
      negative number, the record is dropped.
    **kwargs: additional keyword args for x_ops.generic_input_v2_create.

  Returns:
    A tuple of (resource, out_types, output_tmpl):

    - resource: a handle that uniquely identifies the created GenericInputV2
      resource.
    - out_types: a list of tensor types representing the types in each batch.
    - output_tmpl: a NestedMap that will be used to pack each batch.
  """
  # pyformat: enable
  proc_fn, out_types, output_tmpl = _ParseProcessor(processor)
  # "Lifts" the resource creation outside of tf.function Graphs (i.e.
  # FuncGraphs). This is necessary when tf.function is retraced, but the
  # same GenericInputV2 resource needs to be used.
  with tf.init_scope():
    return ops.gen_x_ops.generic_input_v2_create(
        processor=proc_fn, out_types=out_types[:-1],
        **kwargs), out_types[:-1], output_tmpl

def GenericInput(processor, **kwargs):
  """Builds a generic input pipeline.

  Example usage:

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.io.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is a scalar convertible to tf.int32.
      # Use 1 if all examples are of the same size.
      bucketing_key = 1
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  ParseRecord can also take both 'source_id' and 'record' as inputs (the arg
  names must be exactly 'source_id' and 'record'):

    def ParseRecord(source_id, record):
      # Given a tf.int32 source_id and a tf.string record, return a
      # (NestedMap, bucketing key) pair.
      example = py_utils.NestedMap(source_id=source_id, ...)
      ...
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)

  Args:
    processor: a function that takes either a tf.string record or a
      (source_id: tf.int32, record: tf.string) pair as input and returns a
      tuple (output, bucketing_key). `output` must be a NestedMap or a list of
      tensors representing an example. `bucketing_key` must be a scalar
      convertible to a tf.int32 tensor that represents the bucketing key
      (e.g., sequence length for sequence inputs). If `bucketing_key` is a
      negative number, the record is dropped.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):

    - outputs: a NestedMap or a tuple of tensors, similar to `processor`'s
      return, except every tensor will have an additional dimension 0 that
      represents the batch dimension.
    - bucket_keys: a tf.int32 vector.

  Raises:
    RuntimeError: If called in pure Eager/tf.function mode without
      `generic_input_v2_key` defined.
  """
  # In TF2 mode, call `GenericInputV2Create` and `GenericInputV2GetNext` for
  # the purpose of migration.
  if py_utils.IsEagerMode():
    if not IsGenericInputV2AllwedInEager() and (
        'allow_eager' not in kwargs or not kwargs['allow_eager']):
      raise RuntimeError(
          'GenericInput is called in tf.function or pure Eager mode. This '
          'means you might be in the process of migrating your code from TF1 '
          'to TF2. GenericInput is generally not safe for pure Eager mode and '
          'a newer version is introduced (GenericInputV2). To enable that, '
          'please add keyword arg `allow_eager=True` when calling '
          'GenericInput. Also, we recommend that you add extra tests for your '
          'own data pipeline in TF2 mode. Refer to b/223271939 for concrete '
          'examples.')
    kwargs.pop('allow_eager', None)

    generic_input_v2_key = kwargs.pop('generic_input_v2_key', None)
    if generic_input_v2_key is None:
      raise RuntimeError(_MISSING_KEY_ERR)
    if generic_input_v2_key in _GENERIC_CACHE_V2:
      resource, out_types, output_tmpl = _GENERIC_CACHE_V2[
          generic_input_v2_key]
    else:
      with tf.init_scope():
        resource, out_types, output_tmpl = GenericInputV2Create(
            processor, **kwargs)
      _GENERIC_CACHE_V2[generic_input_v2_key] = (resource, out_types,
                                                 output_tmpl)
    return GenericInputV2GetNext(resource, out_types, output_tmpl)

  proc_fn, out_types, output_tmpl = _ParseProcessor(processor)
  flat_outputs, bucket_keys = ops.gen_x_ops.generic_input(
      processor=proc_fn, out_types=out_types[:-1], **kwargs)
  tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
  # Pack flat_outputs to outputs.
  outputs = output_tmpl.Pack(flat_outputs).out_values
  if isinstance(outputs, list):
    outputs = tuple(outputs)  # b/124336469
  tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
  return outputs, bucket_keys

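A hedged sketch of the TF2/Eager calling convention implied by the branch above: `allow_eager=True` opts in, and a stable `generic_input_v2_key` lets retraced `tf.function`s reuse the cached GenericInputV2 resource instead of creating a new one. The file pattern, bucket settings, and processor here are illustrative only.

import tensorflow as tf
from lingvo.core.ops import generic_input


def _Process(record):
  # Each example is a list with a single int32 tensor; bucketing_key is 1.
  return [tf.strings.length(record)], 1


@tf.function
def get_batch():
  batch, bucket_keys = generic_input.GenericInput(
      _Process,
      file_pattern='tfrecord:/tmp/data.tfrecord',  # assumed path
      bucket_upper_bound=[1024],
      bucket_batch_limit=[8],
      allow_eager=True,
      generic_input_v2_key='my_pipeline')  # any key stable across retraces
  return batch, bucket_keys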