def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
    p = self.params

    load_op_list = []
    retrieve_op_list = []

    num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
    table_name = tpu_embedding_table.table_name
    slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

    for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
      # The slot vars should be on the same device as the table var.
      device_name = tpu_embedding_table.GetDeviceName(host_id)
      with tf.device(device_name), py_utils.outside_all_rewrites():
        w_ada = py_utils.WeightParams(
            shape=table_var.shape.as_list(),
            init=py_utils.WeightInit.Constant(p.initial_accumulator),
            dtype=p.dtype,
            collections=slot_var_collections)
        var_name = tpu_embedding_table.GetVariableName(host_id) + '/Adagrad'
        tpu_embedding_table.CreateVariable(var_name, w_ada, trainable=False)
        accumulator_var = tpu_embedding_table.vars[var_name]

        # Only the Trainer needs these ops.
        if py_utils.use_tpu():
          # Remove the slot vars from the variable list to avoid copying them
          # to TPU (by the tf.cast in tpu_embedding_table.theta).
          # pylint: disable=protected-access
          del tpu_embedding_table._private_vars[var_name]
          del tpu_embedding_table._private_theta[var_name]
          # pylint: enable=protected-access

          # TPU Embedding load/retrieve ops need to be in the outer graph scope.
          with tf.init_scope():
            tf.logging.info('creating load and retrieve ops.')
            load_parameters_op = (
                tpu_embedding_lib.tpu_ops.load_tpu_embedding_adagrad_parameters(
                    parameters=table_var,
                    accumulators=accumulator_var,
                    table_name=table_name,
                    num_shards=num_tpu_hosts,
                    shard_id=host_id))
            load_op_list.append(load_parameters_op)

            retrieved_table, retrieved_accumulator = (
                tpu_embedding_lib.tpu_ops
                .retrieve_tpu_embedding_adagrad_parameters(
                    table_name=table_name,
                    num_shards=num_tpu_hosts,
                    shard_id=host_id))
            retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
                tf.assign(table_var, retrieved_table),
                tf.assign(accumulator_var, retrieved_accumulator))
            retrieve_op_list.append(retrieve_parameters_op)

    return load_op_list, retrieve_op_list
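
A minimal usage sketch (not part of the library; it assumes the two op lists returned above plus an ordinary tf.train.Saver) of how these ops are typically driven: the load ops push the host-side table and accumulator variables into the TPU embedding engine after initialization or checkpoint restore, and the retrieve ops copy the trained values back to the host before saving.

import tensorflow.compat.v1 as tf  # hypothetical driver code, TF1 graph mode


def RunWithTpuEmbedding(load_ops, retrieve_ops, saver, ckpt_path):
  """Pushes tables onto the TPU before training, pulls them back before saving."""
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(load_ops)      # host variables -> TPU embedding memory
    # ... run training steps that use the TPU embedding lookups ...
    sess.run(retrieve_ops)  # TPU embedding memory -> host variables
    saver.save(sess, ckpt_path)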
Example #2
 def GetNext(self):
   """Returns the next element from the dataset."""
   # Use `init_scope()` to ensure that the datasets and iterators are created
   # outside of the function-building graph. This ensures that these creation
   # operations are not repeated in subsequent `tf.function` calls.
   with tf.init_scope():
     self._InitIterator()
   if py_utils.GetUnitTestSession():
     self.Initialize(py_utils.GetUnitTestSession())
   return self._iterator[self.host_id].get_next()
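
The comment above is the key point of this snippet: tf.init_scope lifts the dataset and iterator construction out of the tf.function FuncGraph, so the pipeline is built once in the outer (eager) context instead of on every trace. A standalone sketch of the same pattern (illustrative class, not lingvo code):

import tensorflow as tf  # standalone illustration of the init_scope pattern


class LazyIterator:
  """Builds its tf.data iterator lazily, in the outer eager context."""

  def __init__(self):
    self._iterator = None

  def _InitIterator(self):
    if self._iterator is None:
      self._iterator = iter(tf.data.Dataset.range(5).repeat())

  @tf.function
  def GetNext(self):
    # init_scope exits the function-building graph, so the iterator is
    # created eagerly once and merely captured by later traces.
    with tf.init_scope():
      self._InitIterator()
    return next(self._iterator)

Repeated calls to GetNext on one instance then yield 0, 1, 2, ... without ever rebuilding the dataset.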
Example #3
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        if self._num_micro_batches == 1:
            return self._opt.apply_gradients(grads_and_vars, global_step)
        global_step = global_step or py_utils.GetOrCreateGlobalStepVar()
        with tf.init_scope():
            self._create_slots([v for (_, v) in grads_and_vars])

        accums = []
        variables = []

        for g, v in grads_and_vars:
            accum = self.get_slot(v, 'grad_accum')
            variables.append(v)
            # pytype: disable=attribute-error
            if isinstance(g, tf.IndexedSlices):
                scaled_grad = tf.IndexedSlices(g.values /
                                               self._num_micro_batches,
                                               g.indices,
                                               dense_shape=g.dense_shape)
            else:
                scaled_grad = g / self._num_micro_batches
            accum_tensor = accum.read_value()
            accums.append(accum.assign(accum_tensor + scaled_grad))
            # pytype: enable=attribute-error

        def _ApplyAndReset():
            normalized_accums = accums
            if self._apply_crs_to_grad:
                normalized_accums = [
                    tf.tpu.cross_replica_sum(accum.read_value())
                    for accum in accums
                ]
            apply_op = self._opt.apply_gradients(
                list(zip(normalized_accums, variables)))
            with tf.control_dependencies([apply_op]):
                zero_op = [
                    tf.assign(accum, tf.zeros_like(accum)) for accum in accums
                ]
            return tf.group(zero_op, tf.assign_add(global_step, 1))

        def _Accum():
            return tf.no_op()

        accum_step = tf.cond(
            tf.equal(
                tf.math.floormod(self._counter + 1, self._num_micro_batches),
                0),
            _ApplyAndReset,  # Apply the accumulated gradients and reset.
            _Accum)  # Accumulate gradients.

        with tf.control_dependencies([tf.group(accums)]):
            return tf.group(accum_step, tf.assign_add(self._counter, 1))
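
The wrapper above averages gradients over _num_micro_batches steps before handing them to the inner optimizer: each step adds g / num_micro_batches to a slot accumulator, and every N-th step applies the accumulated value and zeroes the slots. A tiny standalone check of the arithmetic (NumPy only, not lingvo code):

import numpy as np

per_example_grads = np.arange(8.0)          # pretend per-example gradients
num_micro_batches = 4
micro_batches = np.split(per_example_grads, num_micro_batches)

# Each micro-batch step contributes its mean gradient scaled by
# 1 / num_micro_batches, so after a full cycle the accumulator equals the
# mean gradient over the whole (large) batch.
accum = 0.0
for g in micro_batches:
  accum += g.mean() / num_micro_batches

assert np.isclose(accum, per_example_grads.mean())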
Example #4
        def _BuildTpuEmbeddingApi():
            load_op_list = []
            retrieve_op_list = []

            num_cores = self.cluster.params.worker.tpus_per_replica
            global_batch_size = (self.params.batch_size *
                                 self.cluster.num_splits_per_client)
            table_to_config_dict = {}
            feature_to_config_dict = {}
            for table in self.tables:
                table_to_config_dict[table.table_name] = table.table_config
                load_op_list += table.load_op_list
                retrieve_op_list += table.retrieve_op_list
                for feature in table.input_keys:
                    feature_to_config_dict[
                        feature] = tpu_embedding_lib.FeatureConfig(
                            table.table_name,
                            max_sequence_length=table.max_sequence_length)

            mode = tpu_embedding_lib.TRAINING
            device_config = tpu_embedding_lib.DeviceConfig(
                num_cores=num_cores,
                num_hosts=self.params.tables[0].num_tpu_hosts,
                job_name=self.cluster.params.worker.name)
            tpu_embedding = tpu_embedding_lib.TPUEmbedding(
                table_to_config_dict,
                feature_to_config_dict,
                global_batch_size,
                mode,
                master=None,
                pipeline_execution_with_tensor_core=(
                    self.params.pipeline_execution_with_tensor_core),
                partition_strategy=self.params.partition_strategy,
                device_config=device_config)

            with tf.init_scope():
                dummy_variables, dummy_variables_init = (
                    tpu_embedding_gradient.create_dummy_table_variables(
                        tpu_embedding))
            load_op_list += [dummy_variables_init]

            tf.add_to_collection(py_utils.TPU_EMBEDDING, tpu_embedding)
            tf.add_to_collection(py_utils.TPU_EMBEDDING_DUMMY_VARS,
                                 dummy_variables)
            tf.add_to_collection(py_utils.TPU_EMBEDDING_LOAD_OPS, load_op_list)
            tf.add_to_collection(py_utils.TPU_EMBEDDING_RETRIEVE_OPS,
                                 retrieve_op_list)
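
Everything built here is registered in TF1 graph collections, so other parts of the program can pick the objects up without holding a direct reference to this layer. A hedged sketch of the consumer side (assuming the py_utils collection-name constants used above):

import tensorflow.compat.v1 as tf
from lingvo.core import py_utils

# Each add_to_collection call above appended one entry, so get_collection
# returns a list of the registered objects / op lists.
tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
load_ops = tf.get_collection(py_utils.TPU_EMBEDDING_LOAD_OPS)
retrieve_ops = tf.get_collection(py_utils.TPU_EMBEDDING_RETRIEVE_OPS)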
Example #5
    def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
        load_op_list = []
        retrieve_op_list = []

        num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
        table_name = tpu_embedding_table.table_name

        for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
            # The slot vars should be on the same device as the table var.
            device_name = tpu_embedding_table.GetDeviceName(host_id)
            with tf.device(device_name), py_utils.outside_all_rewrites():
                # Only the Trainer needs these ops.
                if py_utils.use_tpu():
                    # TPU Embedding load/retrieve ops need to be in the outer graph scope.
                    with tf.init_scope():
                        tf.logging.info('creating load and retrieve ops.')
                        load_parameters_op = (
                            tpu_embedding_lib.tpu_ops.
                            load_tpu_embedding_stochastic_gradient_descent_parameters(
                                parameters=table_var,
                                table_name=table_name,
                                num_shards=num_tpu_hosts,
                                shard_id=host_id))
                        load_op_list.append(load_parameters_op)

                        retrieved_table = (
                            tpu_embedding_lib.tpu_ops.
                            retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
                                table_name=table_name,
                                num_shards=num_tpu_hosts,
                                shard_id=host_id))
                        retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
                            tf.assign(table_var, retrieved_table))
                        retrieve_op_list.append(retrieve_parameters_op)

        return load_op_list, retrieve_op_list
Example #6
 def GetNext(self):
   """Override of the root's GetNext to support checking repeat sentinel."""
   # Use `init_scope()` to ensure that the datasets and iterators are created
   # outside of the function-building graph. This ensures that these creation
   # operations are not repeated in subsequent `tf.function` calls.
   with tf.init_scope():
     self._InitIterator()
   if py_utils.GetUnitTestSession():
     self.Initialize(py_utils.GetUnitTestSession())
   batch = self._iterator[self.host_id].get_next()
   # Sentinel check.
   if self._repeat_with_sentinel and not self._repeat_steps:
     assert_op = tf.debugging.assert_none_equal(
         batch[self.params.sentinel_key],
         tf.constant(self.params.sentinel_value),
         summarize=1,
         message='REPEAT_SENTINEL_')
     tf.logging.info('sentinel constant dtype %r',
                     tf.constant(self.params.sentinel_value))
     with tf.control_dependencies([assert_op]):
       # This identity transform will throw tf.errors.InvalidArgumentError
       # if assert_op fails (sentinel_key takes on sentinel_value).
       batch = batch.Transform(tf.identity)
   return batch
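
The sentinel mechanism relies on tf.debugging.assert_none_equal: when the dataset emits its end-of-epoch sentinel batch, the assertion (and the identity op gated on it) raises tf.errors.InvalidArgumentError, which the caller can catch to detect the epoch boundary. A standalone sketch of just that check (illustrative values, not lingvo code):

import tensorflow as tf

SENTINEL = tf.constant(-1.0)


def CheckNotSentinel(value):
  """Returns `value`, or raises InvalidArgumentError if it hits the sentinel."""
  assert_op = tf.debugging.assert_none_equal(
      value, SENTINEL, summarize=1, message='REPEAT_SENTINEL_')
  with tf.control_dependencies([assert_op]):
    return tf.identity(value)


try:
  CheckNotSentinel(tf.constant([-1.0]))  # contains the sentinel value
except tf.errors.InvalidArgumentError:
  print('epoch finished')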
Example #7
  def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
    p = self.params

    load_op_list = []
    retrieve_op_list = []

    num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
    table_name = tpu_embedding_table.table_name
    slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

    for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
      # The slot vars should be on the same device as the table var.
      device_name = tpu_embedding_table.GetDeviceName(host_id)
      with tf.device(device_name), py_utils.outside_all_rewrites():
        accumulator = py_utils.WeightParams(
            shape=table_var.shape.as_list(),
            init=py_utils.WeightInit.Constant(p.initial_accumulator_value),
            dtype=p.dtype,
            collections=slot_var_collections)
        accumulator_name = (
            tpu_embedding_table.GetVariableName(host_id) + '/Ftrl')
        tpu_embedding_table.CreateVariable(
            accumulator_name, accumulator, trainable=False)
        accumulator_var = tpu_embedding_table.vars[accumulator_name]

        linear = py_utils.WeightParams(
            shape=table_var.shape.as_list(),
            init=py_utils.WeightInit.Constant(p.initial_linear_value),
            dtype=p.dtype,
            collections=slot_var_collections)
        linear_name = tpu_embedding_table.GetVariableName(host_id) + '/Ftrl_1'
        tpu_embedding_table.CreateVariable(linear_name, linear, trainable=False)
        linear_var = tpu_embedding_table.vars[linear_name]

        # Only the Trainer needs these ops.
        if py_utils.use_tpu():
          # Remove the slot vars from the variable list to avoid them being
          # copied to TPU.
          _RemovePrivateVar(tpu_embedding_table, accumulator_name)
          _RemovePrivateVar(tpu_embedding_table, linear_name)

          # TPU Embedding load/retrieve ops need to be in the outer graph scope.
          with tf.init_scope():
            tf.logging.info('creating load and retrieve ops.')
            load_parameters_op = (
                tpu_embedding_lib.tpu_ops.load_tpu_embedding_ftrl_parameters(
                    parameters=table_var,
                    accumulators=accumulator_var,
                    linears=linear_var,
                    table_name=table_name,
                    num_shards=num_tpu_hosts,
                    shard_id=host_id))
            load_op_list.append(load_parameters_op)

            retrieved_table, retrieved_accumulator, retrieved_linear = (
                tpu_embedding_lib.tpu_ops
                .retrieve_tpu_embedding_ftrl_parameters(
                    table_name=table_name,
                    num_shards=num_tpu_hosts,
                    shard_id=host_id))
            retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
                tf.assign(table_var, retrieved_table),
                tf.assign(accumulator_var, retrieved_accumulator),
                tf.assign(linear_var, retrieved_linear))
            retrieve_op_list.append(retrieve_parameters_op)

    return load_op_list, retrieve_op_list
Example #8
  def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
    p = self.params

    load_op_list = []
    retrieve_op_list = []

    num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
    table_name = tpu_embedding_table.table_name
    slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

    for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
      # The slot vars should be on the same device as the table var.
      device_name = tpu_embedding_table.GetDeviceName(host_id)
      with tf.device(device_name), py_utils.outside_all_rewrites():
        m_adam = py_utils.WeightParams(
            shape=table_var.shape.as_list(),
            init=py_utils.WeightInit.Constant(0.0),
            dtype=p.dtype,
            collections=slot_var_collections)
        var_name_m = tpu_embedding_table.GetVariableName(host_id) + '/Adam/m'
        tpu_embedding_table.CreateVariable(var_name_m, m_adam, trainable=False)
        m_var = tpu_embedding_table.vars[var_name_m]

        v_adam = py_utils.WeightParams(
            shape=table_var.shape.as_list(),
            init=py_utils.WeightInit.Constant(0.0),
            dtype=p.dtype,
            collections=slot_var_collections)
        var_name_v = tpu_embedding_table.GetVariableName(host_id) + '/Adam/v'
        tpu_embedding_table.CreateVariable(var_name_v, v_adam, trainable=False)
        v_var = tpu_embedding_table.vars[var_name_v]

        # Only the Trainer needs these ops.
        if py_utils.use_tpu():
          # Remove the slot vars from the variable list to avoid them being
          # copied to TPU.
          _RemovePrivateVar(tpu_embedding_table, var_name_m)
          _RemovePrivateVar(tpu_embedding_table, var_name_v)

          # TPU Embedding load/retrieve ops need to be in the outer graph scope.
          with tf.init_scope():
            tf.logging.info('creating load and retrieve ops.')
            load_parameters_op = (
                tpu_embedding_lib.tpu_ops.load_tpu_embedding_adam_parameters(
                    parameters=table_var,
                    momenta=m_var,
                    velocities=v_var,
                    table_name=table_name,
                    num_shards=num_tpu_hosts,
                    shard_id=host_id))
            load_op_list.append(load_parameters_op)

            retrieved_table, retrieved_m, retrieved_v = (
                tpu_embedding_lib.tpu_ops
                .retrieve_tpu_embedding_adam_parameters(
                    table_name=table_name,
                    num_shards=num_tpu_hosts,
                    shard_id=host_id))
            retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
                tf.assign(table_var, retrieved_table),
                tf.assign(m_var, retrieved_m), tf.assign(v_var, retrieved_v))
            retrieve_op_list.append(retrieve_parameters_op)

    return load_op_list, retrieve_op_list
Example #9
def GenericInputV2Create(processor, **kwargs):
    # pyformat: disable
    """Builds a generic input pipeline with an explicit resource handle.

  The resource handle uniquely identifies each GenericInputV2 dataset. This
  handle is passed to the GenericInputV2GetNext method to get a batch.

  Example usage:

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.io.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is a scalar convertible to tf.int32.
      # Use 1 if all examples are of the same size.
      bucketing_key = 1
      return example, bucketing_key

    resource, out_types, output_tmpl = GenericInputV2Create(ParseRecord, ...)
    input_batch, ... = GenericInputV2GetNext(resource, out_types, output_tmpl)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  ParseRecord can also take both 'source_id' and 'record' as inputs (the arg
  names must be exactly 'source_id' and 'record'):

    def ParseRecord(source_id, record):
      # Given a tf.int32 source_id and a tf.string record, return a (NestedMap,
      # bucketing key) pair.
      example = py_utils.NestedMap(source_id=source_id, ...)
      ...
      return example, bucketing_key

    resource, out_types, output_tmpl = GenericInputV2Create(
        ParseRecord, file_pattern=..., ...)
    input_batch, bucket_keys = GenericInputV2GetNext(
        resource, out_types, output_tmpl)

  Args:
    processor: a function that takes either a tf.string record or a
      (source_id: tf.int32, record: tf.string) pair as input and returns a
      tuple (output, bucketing_key). `output` must be a NestedMap or a list of
      tensors representing an example. `bucketing_key` must be a scalar
      convertible to a tf.int32 tensor that represents the bucketing key
      (e.g., sequence length for sequence inputs). If `bucketing_key` is a
      negative number, the record is dropped.
    **kwargs: additional keyword args for x_ops.generic_input_v2_create.

  Returns:
    A tuple of (resource, out_types, output_tmpl):

    - resource: a handle that uniquely identifies the created GenericInputV2
        resource.
    - out_types: a list of tensor types representing the types in each batch.
    - output_tmpl: a NestedMap that will be used to pack each batch.
  """
    # pyformat: enable
    proc_fn, out_types, output_tmpl = _ParseProcessor(processor)
    # "Lifts" the resource creation outside of tf.function Graphs (i.e.
    # FuncGraphs). This is necessary when tf.function is retraced, but the
    # same GenericInputV2 resource needs to be used.
    with tf.init_scope():
        return ops.gen_x_ops.generic_input_v2_create(
            processor=proc_fn, out_types=out_types[:-1],
            **kwargs), out_types[:-1], output_tmpl
Example #10
def GenericInput(processor, **kwargs):
    """Builds a generic input pipeline.

  Example usage:

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.io.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is a scalar convertible to tf.int32.
      # Use 1 if all examples are of the same size.
      bucketing_key = 1
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  ParseRecord can also take both 'source_id' and 'record' as inputs (the arg
  names must be exactly 'source_id' and 'record'):

    def ParseRecord(source_id, record):
      # Given a tf.int32 source_id and a tf.string record, return a (NestedMap,
      # bucketing key) pair.
      example = py_utils.NestedMap(source_id=source_id, ...)
      ...
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)

  Args:
    processor: a function that takes either a tf.string record or a
      (source_id: tf.int32, record: tf.string) pair as input and returns a tuple
      (output, bucketing_key). `output` must be a NestedMap or a list of tensors
      representing an example. `bucketing_key` must be a scalar convertible to
      a tf.int32 tensor that represents the bucketing key (e.g., sequence
      length for sequence inputs). If `bucketing_key` is a negative number,
      the record is dropped.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):

    - outputs: a NestedMap or a tuple of tensors, similar to `processor`'s
      return, except every tensor will have an additional dimension 0 that
      represents the batch dimension.
    - bucket_keys: a tf.int32 vector.

  Raises:
    RuntimeError: If called in pure Eager/tf.function mode without
      `generic_input_v2_key` defined.
  """
    # In TF2 mode, call `GenericInputV2Create` and `GenericInputV2GetNext` for
    # the purpose of migration.
    if py_utils.IsEagerMode():
        if not IsGenericInputV2AllwedInEager() and (
                'allow_eager' not in kwargs or not kwargs['allow_eager']):
            raise RuntimeError(
                'GenericInput is called in tf.function or pure Eager mode. This means'
                ' you might be in the process of migrating your code from TF1 to TF2.'
                ' GenericInput is generally not safe for pure Eager mode and a newer '
                'version is introduced (GenericInputV2). To enable that, please '
                'add keyword arg `allow_eager=True` when calling GenericInput. '
                'Also, we recommend that you add extra tests for your own data '
                'pipeline in TF2 mode. Refer to b/223271939 for concrete examples.'
            )

        kwargs.pop('allow_eager', None)
        generic_input_v2_key = kwargs.pop('generic_input_v2_key', None)
        if generic_input_v2_key is None:
            raise RuntimeError(_MISSING_KEY_ERR)

        if generic_input_v2_key in _GENERIC_CACHE_V2:
            resource, out_types, output_tmpl = _GENERIC_CACHE_V2[
                generic_input_v2_key]
        else:
            with tf.init_scope():
                resource, out_types, output_tmpl = GenericInputV2Create(
                    processor, **kwargs)

        _GENERIC_CACHE_V2[generic_input_v2_key] = (resource, out_types,
                                                   output_tmpl)
        return GenericInputV2GetNext(resource, out_types, output_tmpl)

    proc_fn, out_types, output_tmpl = _ParseProcessor(processor)
    flat_outputs, bucket_keys = ops.gen_x_ops.generic_input(
        processor=proc_fn, out_types=out_types[:-1], **kwargs)
    tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
    # Pack flat_outputs to outputs.
    outputs = output_tmpl.Pack(flat_outputs).out_values
    if isinstance(outputs, list):
        outputs = tuple(outputs)  # b/124336469
    tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
    return outputs, bucket_keys
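
The non-Eager path repacks the op's flat output tensors into the caller's structure via output_tmpl.Pack(...). A short sketch of that round-trip on a NestedMap (assuming lingvo's py_utils.NestedMap packs flat values back in the template's flattened key order, as used above):

from lingvo.core import py_utils

tmpl = py_utils.NestedMap(
    out_values=py_utils.NestedMap(field1=None, field2=None))
# Pack consumes the flat list in the same order Flatten() would produce.
packed = tmpl.Pack(['a', 'b'])
outputs = packed.out_values   # NestedMap(field1='a', field2='b')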