Exemple #1
0
    def _CreateLayerVariables(self):
        """Creates the per-host embedding shards and their optimizer slots.

        One variable is created for each of `p.num_tpu_hosts` hosts, on the
        device reported by `GetDeviceName`. Every shard is removed from the
        private variable/theta maps; when not running on TPU the shards are
        re-registered together under the single key 'wm'. Unless the layer is
        in eval mode, the optimizer is then asked to create slot variables
        and the load/retrieve ops.
        """
        params = self.params
        shard_params = py_utils.WeightParams(
            shape=[self._ids_per_shard, params.embedding_dim],
            init=params.params_init,
            dtype=params.dtype,
            collections=[self.__class__.__name__ + '_vars'])

        shard_vars = []
        for host_index in range(params.num_tpu_hosts):
            with tf.device(self.GetDeviceName(host_index)), \
                    py_utils.outside_all_rewrites():
                shard_name = self.GetVariableName(host_index)
                self.CreateVariable(shard_name, shard_params)
                shard_vars.append(self.vars[shard_name])
                # Drop the per-shard entries; the shards are exposed again
                # below as one list-valued entry named 'wm'.
                del self._private_vars[shard_name]
                del self._private_theta[shard_name]

        if not py_utils.use_tpu():
            # On TPU (TrainerTpu) the identity reference would copy the
            # embedding to the TPU for no reason, so 'wm' is only registered
            # for CPU jobs (eval/decode/controller).
            self._private_vars['wm'] = shard_vars
            shard_thetas = []
            for shard in shard_vars:
                shard_thetas.append(tf.identity(shard))
            self._private_theta['wm'] = shard_thetas

        # Slot variables and load/retrieve ops are only needed by the trainer
        # and the controller, never in eval mode.
        if self.do_eval:
            return
        slot_ops = self.optimizer.CreateSlotVariablesAndOps(shard_vars, self)
        self._load_op_list, self._retrieve_op_list = slot_ops
  def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
    """Creates Adagrad accumulators and, on TPU, the load/retrieve ops.

    Args:
      table_vars: The embedding table shard variables, one per host.
      tpu_embedding_table: The table layer that owns `table_vars`; the
        accumulator variables are created on it.

    Returns:
      A `(load_op_list, retrieve_op_list)` tuple. Both lists are empty when
      `py_utils.use_tpu()` is false.
    """
    p = self.params

    host_count = tpu_embedding_table.params.num_tpu_hosts
    name_of_table = tpu_embedding_table.table_name
    accumulator_collections = [
        tpu_embedding_table.__class__.__name__ + '_vars'
    ]

    load_ops = []
    retrieve_ops = []
    for shard_id, shard_var in zip(range(host_count), table_vars):
      # Accumulators live on the same device as the shard they belong to.
      with tf.device(tpu_embedding_table.GetDeviceName(shard_id)), \
          py_utils.outside_all_rewrites():
        accumulator_params = py_utils.WeightParams(
            shape=shard_var.shape.as_list(),
            init=py_utils.WeightInit.Constant(p.initial_accumulator),
            dtype=p.dtype,
            collections=accumulator_collections)
        slot_name = tpu_embedding_table.GetVariableName(shard_id) + '/Adagrad'
        tpu_embedding_table.CreateVariable(
            slot_name, accumulator_params, trainable=False)
        accumulator = tpu_embedding_table.vars[slot_name]

        # Everything below is needed by the Trainer only.
        if not py_utils.use_tpu():
          continue

        # Take the slot var out of the variable list so that it is not copied
        # to the TPU (by the tf.cast in tpu_embedding_table.theta).
        # pylint: disable=protected-access
        del tpu_embedding_table._private_vars[slot_name]
        del tpu_embedding_table._private_theta[slot_name]
        # pylint: enable=protected-access

        # Load/retrieve ops have to be created in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          tpu_ops = tpu_embedding_lib.tpu_ops
          load_ops.append(
              tpu_ops.load_tpu_embedding_adagrad_parameters(
                  parameters=shard_var,
                  accumulators=accumulator,
                  table_name=name_of_table,
                  num_shards=host_count,
                  shard_id=shard_id))

          retrieved = tpu_ops.retrieve_tpu_embedding_adagrad_parameters(
              table_name=name_of_table,
              num_shards=host_count,
              shard_id=shard_id)
          new_table, new_accumulator = retrieved
          assign_table = tf.assign(shard_var, new_table)
          assign_accumulator = tf.assign(accumulator, new_accumulator)
          retrieve_ops.append(
              tpu_embedding_lib.control_flow_ops.group(
                  assign_table, assign_accumulator))

    return load_ops, retrieve_ops
Exemple #3
0
  def _CreateLayerVariables(self):
    """Creates or reuses the table shards, then the slot vars and ops.

    Table variables, slot variables and load/retrieve ops are all recorded in
    `self._tpu_embedding_collection` keyed by `self.table_name`, so that they
    exist only once even when several tasks/programs build this table.
    """
    p = self.params
    collection = self._tpu_embedding_collection

    known_tables = collection.table_variables
    if self.table_name not in known_tables:
      shard_params = py_utils.WeightParams(
          shape=[self._ids_per_shard, p.embedding_dim],
          init=p.params_init,
          dtype=p.dtype,
          collections=[self.__class__.__name__ + '_vars'])

      shard_vars = []
      for host_index in range(p.num_tpu_hosts):
        with tf.device(self.GetDeviceName(host_index)), \
            py_utils.outside_all_rewrites():
          shard_name = self.GetVariableName(host_index)
          self.CreateVariable(shard_name, shard_params)
          shard_vars.append(self.vars[shard_name])
          # The shard is registered again below as part of 'wm', so its
          # individual private entries are removed here.
          _RemovePrivateVar(self, shard_name)

      collection.AddTableVariables(self.table_name, shard_vars)
    else:
      # The singleton table variables already exist; share them.
      shard_vars = known_tables[self.table_name]

    if not _ShouldUseTpu(p):
      # For TrainerTpu the identity reference would copy the embedding to the
      # TPU for no reason, so 'wm' is only added for CPU
      # (eval/decode/controller).
      self._private_vars['wm'] = shard_vars
      shard_thetas = []
      for shard in shard_vars:
        shard_thetas.append(tf.identity(shard))
      self._private_theta['wm'] = shard_thetas

    # There must be only one copy of the slot variables and load/retrieve ops
    # in the graph, shared by all tasks/programs. If another program or task
    # already created them, there is nothing left to do.
    if self.table_name in collection.load_ops:
      return
    assert self.table_name not in collection.retrieve_ops

    # Slot variables are needed by the trainer and by the controller (for
    # checkpointing); load/retrieve ops by the trainer only.
    if self.do_eval or p.is_inference:
      return
    slot_ops = self.optimizer.CreateSlotVariablesAndOps(shard_vars, self)
    load_ops, retrieve_ops = slot_ops
    collection.AddLoadRetrieveOps(self.table_name, load_ops, retrieve_ops)
Exemple #4
0
    def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
        """Creates the SGD load/retrieve ops for every table shard.

        No slot variables are created by this method; it only builds ops, and
        only when `py_utils.use_tpu()` is true.

        Args:
          table_vars: The embedding table shard variables, one per host.
          tpu_embedding_table: The table layer that owns `table_vars`.

        Returns:
          A `(load_op_list, retrieve_op_list)` tuple; both lists are empty
          when not running on TPU.
        """
        host_count = tpu_embedding_table.params.num_tpu_hosts
        name_of_table = tpu_embedding_table.table_name

        load_ops = []
        retrieve_ops = []
        for shard_id, shard_var in zip(range(host_count), table_vars):
            # Build the ops on the same device as the table shard.
            with tf.device(tpu_embedding_table.GetDeviceName(shard_id)), \
                    py_utils.outside_all_rewrites():
                # The ops are needed by the Trainer only.
                if not py_utils.use_tpu():
                    continue

                # Load/retrieve ops have to live in the outer graph scope.
                with tf.init_scope():
                    tf.logging.info('creating load and retrieve ops.')
                    tpu_ops = tpu_embedding_lib.tpu_ops
                    load_fn = (
                        tpu_ops.
                        load_tpu_embedding_stochastic_gradient_descent_parameters
                    )
                    load_ops.append(
                        load_fn(parameters=shard_var,
                                table_name=name_of_table,
                                num_shards=host_count,
                                shard_id=shard_id))

                    retrieve_fn = (
                        tpu_embedding_lib.tpu_ops.
                        retrieve_tpu_embedding_stochastic_gradient_descent_parameters
                    )
                    new_table = retrieve_fn(table_name=name_of_table,
                                            num_shards=host_count,
                                            shard_id=shard_id)
                    assign_table = tf.assign(shard_var, new_table)
                    retrieve_ops.append(
                        tpu_embedding_lib.control_flow_ops.group(assign_table))

        return load_ops, retrieve_ops
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch.

        For every infeed host (all TPU hosts when `p.use_per_host_infeed` is
        set, otherwise a single host) this fetches a preprocessed batch on that
        host's CPU, optionally builds TPU-embedding enqueue ops from it, and
        creates an infeed queue plus its enqueue ops. All enqueue ops are
        grouped into one op, which is stored in `self._tpu_infeed_op` and added
        `p.tpu_infeed_parallelism` times to the `py_utils.ENQUEUE_OPS`
        collection. `self._made_tpu_infeed` is set to True.

        Returns:
          The result of `batch.Pack()` applied to the tensors dequeued from the
          first queue on TPU core 0, where `batch` is the batch produced by the
          last iteration of the per-host loop.
        """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTPUFeeds num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        # NOTE(review): this assert runs after num_tpu_hosts has already been
        # used as a divisor above.
        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        # Per-host infeed is rejected (logged at FATAL level) when one split
        # spans more devices than a single host has cores.
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            # The infeed may only be built once per instance.
            assert not self._made_tpu_infeed

            # Number of shards each infeed host is responsible for.
            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            tf.logging.info('shards {}'.format(shards))

            input_ops_list = []
            queues = []
            # Only the first entry of the TPU_EMBEDDING collection is used.
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            if num_tpu_hosts > 1 and tpu_embedding is not None:
                if not p.use_per_host_infeed:
                    tf.logging.fatal(
                        'TPU Embedding must be used with per_host_infeed with multiple '
                        'TPU host topologies.')
            # Batch keys that are fed through the TPU embedding enqueue ops.
            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if isinstance(batch, py_utils.NestedMap):
                        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                        # Note that when MultiTaskData is used, bucket_keys will be at the
                        # second level of the dictionary.
                        batch = batch.FilterKeyVal(
                            lambda k, _: not k.endswith('bucket_keys'))
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        # One independent dict per core, filled feature by
                        # feature below.
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        # From here on the local num_cores_per_host holds the
                        # embedding's value rather than the cluster-derived
                        # one computed at the top of the method.
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            # Split the feature into one piece per core.
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                # Entries equal to -1 are treated as padding
                                # and dropped.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    # Every tensor that is fed must have a static shape.
                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    if p.use_partitioned_infeed_queue:
                        device_assignment = py_utils.GetTpuDeviceAssignment()

                        # host_device is rebound here to the host of replica 0
                        # and the host id is parsed out of that device string.
                        host_device = device_assignment.host_device(
                            replica=0, job=tf.flags.FLAGS.tf_master)
                        host_id = int(
                            host_device.split('/task:')[1].split('/device:')
                            [0])
                        tf.logging.info('host_id: {} host_device: {}'.format(
                            host_id, host_device))
                        # Every tuple element gets the same partition dims,
                        # [p.num_partitions, 1].
                        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                            number_of_tuple_elements=len(dtypes),
                            device_assignment=device_assignment,
                            host_id=host_id,
                            input_partition_dims=[[p.num_partitions, 1]
                                                  for _ in dtypes],
                            tuple_types=dtypes,
                            tuple_shapes=shapes)
                    else:
                        q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                                 tuple_shapes=shapes)
                        assert shards is not None
                        q.set_number_of_shards(shards)

                    queues.append(q)
                    tf.logging.info('q=%r', q)

                    # Three ways of producing enqueue ops: partitioned queue,
                    # per-host infeed, or a split driven by the device
                    # assignment.
                    if p.use_partitioned_infeed_queue:
                        input_ops = q.generate_enqueue_ops([batch.Flatten()])
                    elif p.use_per_host_infeed:
                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            # Returns the TPU ordinal for a shard of the
                            # current host; reads task_id from the enclosing
                            # loop at call time.
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            # A single op that runs every enqueue op built above.
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        self._tpu_infeed_op = tpu_infeed_op

        # Dequeue from the first queue only, on TPU core 0, and restore the
        # batch structure around the dequeued tensors.
        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Exemple #6
0
  def __init__(self, params):
    """Initializes the task from `params`.

    Adjusts the eval-related input params, creates the 'input' child on the
    cluster's input device, resets the bookkeeping attributes, picks the
    global step, and (when `p.train` is set) creates the grad-norm tracker,
    'lr_schedule' and 'optimizer' children.

    Args:
      params: The task params; `params.cls` must be a subclass of `BaseTask`.
    """
    assert issubclass(params.cls, BaseTask)
    super(BaseTask, self).__init__(params)

    p = self.params

    if p.input:
      # TODO(zhifengc): Consider a simpler way to ensure the input
      # generator stops after one epoch.
      if p.is_eval and p.eval:
        reads_from_files = issubclass(
            p.input.cls, base_input_generator.BaseInputGeneratorFromFiles)
        num_samples = p.input.num_samples
        if num_samples == 0:
          # The dataset size is unknown, so the eval summary has to be
          # computed over an explicitly configured number of samples.
          assert p.eval.samples_per_summary > 0
        else:
          per_summary = p.eval.samples_per_summary
          if per_summary == 0 or num_samples < per_summary:
            # The dataset size is known and the whole set is to be
            # evaluated: make the input generator flush out all samples so
            # that the evaler and decoder compute metrics on the whole set
            # for each summary step.
            if reads_from_files:
              p.input.flush_every_n = num_samples
            p.eval.samples_per_summary = num_samples
        if reads_from_files and p.input.num_batcher_threads > 1:
          tf.logging.warning('input.num_batcher_threads > 1 inside eval mode.  '
                             'The input generator may not iterate over exactly '
                             'one epoch per run')

      input_device = self.cluster.input_device
      with tf.device(input_device), py_utils.outside_all_rewrites():
        self.CreateChild('input', p.input)

    # Sub-networks and gradients, populated later.
    self._var_grads = None
    self._encoder = None
    self._online_encoder = None
    self._decoder = None

    # Training/eval bookkeeping, populated later.
    self._total_examples = None
    self._total_nans_and_infs = None
    self._loss = None
    self._num_predictions = None
    self._train_op = None
    self._eval_metrics = {}
    self._trainer_verbose_tensors = {}

    # Create the gradient mask,
    self._per_input_gradient_mask = None
    self._shared_global_step = py_utils.GetOrCreateGlobalStep()

    train_params = p.train
    if train_params:
      # A task either owns its own global step or uses the shared one.
      has_own_step = train_params.task_global_step
      if has_own_step:
        self._task_global_step = CreateTaskGlobalStep(p, p.name)
      else:
        self._task_global_step = None
      if has_own_step:
        self._global_step = self._task_global_step
      else:
        self._global_step = self._shared_global_step

      if train_params.grad_norm_tracker:
        with tf.variable_scope(p.name):
          self.CreateChild('grad_norm_tracker', train_params.grad_norm_tracker)

      self.CreateChild('lr_schedule', train_params.lr_schedule)
      self.CreateChild('optimizer', train_params.optimizer)
    self._UpdateVnConfig()
Exemple #7
0
  def CreateTpuFeeds(self):
    """Builds the TPU infeed enqueue ops and returns the dequeued batch.

    One infeed queue is created per infeed host (every TPU host when
    `p.use_per_host_infeed` is set, otherwise a single host). All enqueue ops
    are grouped into one op, which is added `p.tpu_infeed_parallism` times to
    the `py_utils.ENQUEUE_OPS` collection.

    Returns:
      The first host's batch structure, packed with the tensors dequeued from
      the first queue on TPU core 0.
    """
    p = self.params
    num_tpu_hosts = cluster_factory.Current().num_tpu_hosts
    assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
    if p.use_per_host_infeed:
      num_infeed_hosts = num_tpu_hosts
    else:
      num_infeed_hosts = 1

    with py_utils.outside_all_rewrites():
      assert py_utils.use_tpu()
      assert not self._made_tpu_infeed

      tpu_context = tpu_function.get_tpu_context()
      shards_per_host = tpu_context.number_of_shards // num_infeed_hosts
      all_enqueue_ops = []
      infeed_queues = []
      template_batch = None
      for task_id in range(num_infeed_hosts):
        host_device = '/task:{}/device:CPU:0'.format(task_id)
        with tf.device(host_device):
          host_batch = self.GetPreprocessedInputBatch()
          if template_batch is None:
            # The first host's batch provides the structure for Pack() below.
            template_batch = host_batch
          named_tensors = host_batch.FlattenItems()

          # Every tensor that is fed must have a static shape.
          for key, tensor in named_tensors:
            assert tensor.shape.is_fully_defined(), (
                'Shape must be fully defined: %s: %s' % (key, tensor))
            # TODO(cwhipkey): if it's a string (or other type not supported on
            # TPU), drop it from feeding and on the other end add in an op that
            # fails if used.
          tensor_shapes = [tensor.shape for _, tensor in named_tensors]
          tensor_dtypes = [tensor.dtype for _, tensor in named_tensors]

          infeed_queue = tf.contrib.tpu.InfeedQueue(
              tuple_types=tensor_dtypes, tuple_shapes=tensor_shapes)
          infeed_queues.append(infeed_queue)
          assert shards_per_host is not None
          infeed_queue.set_number_of_shards(shards_per_host)

          tensors_to_feed = [tensor for _, tensor in named_tensors]
          if not p.use_per_host_infeed:
            host_ops = infeed_queue.split_inputs_and_generate_enqueue_ops(
                tensors_to_feed,
                device_assignment=py_utils.GetTpuDeviceAssignment())
          else:

            # TODO(ylc/zhifengc): Add this to a policy module and test it.
            def _ordinal_for_shard(shard_index_in_host):
              assignment = py_utils.GetTpuDeviceAssignment()
              if not assignment:
                return shard_index_in_host
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replicas = assignment.lookup_replicas(task_id, 0)  # pylint: disable=cell-var-from-loop
              return assignment.tpu_ordinal(
                  replica=replicas[shard_index_in_host])

            host_ops = infeed_queue.split_inputs_and_generate_enqueue_ops(
                tensors_to_feed,
                placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                tpu_ordinal_function=_ordinal_for_shard)

          all_enqueue_ops += host_ops
      tf.logging.info('input_ops_list %s', all_enqueue_ops)
      tpu_infeed_op = tf.group(*all_enqueue_ops)
    self._made_tpu_infeed = True
    # Registering the op several times lets trainer.py drive the infeed with
    # multiple threads.
    for _ in range(p.tpu_infeed_parallism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    with tf.device(tf.contrib.tpu.core(0)):
      dequeued = infeed_queues[0].generate_dequeue_op()
    return template_batch.Pack(dequeued)
Exemple #8
0
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch.

        For every infeed host (all TPU hosts when `p.use_per_host_infeed` is
        set, otherwise a single host) this fetches a preprocessed batch on that
        host's CPU, removes the TPU-embedding features from it and turns them
        into embedding enqueue ops, and feeds the remaining tensors through an
        infeed queue. All enqueue ops are grouped into one op, which is added
        `p.tpu_infeed_parallism` times to the `py_utils.ENQUEUE_OPS`
        collection. `self._made_tpu_infeed` is set to True.

        Returns:
          The first host's batch structure (without the embedding features),
          packed with the tensors dequeued from the first queue on TPU core 0.
        """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        # Per-host infeed is rejected (logged at FATAL level) when one split
        # spans more devices than a single host has cores.
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            # The infeed may only be built once per instance.
            assert not self._made_tpu_infeed

            # Number of shards each infeed host is responsible for.
            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            first_batch = None
            # Only the first entry of the TPU_EMBEDDING collection is used.
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            # Batch keys that are fed through the TPU embedding enqueue ops.
            tpu_embedding_input_keys = (
                tpu_embedding.feature_to_config_dict.keys()
                if tpu_embedding is not None else [])

            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    # Embedding features are popped out of the batch before it
                    # is flattened, so they do not also go through the regular
                    # infeed queue.
                    tpu_embedding_features = []
                    for tpu_embedding_input_key in tpu_embedding_input_keys:
                        tpu_embedding_feature = batch.pop(
                            tpu_embedding_input_key)
                        tpu_embedding_features.append(
                            (tpu_embedding_input_key, tpu_embedding_feature))

                    if first_batch is None:
                        # The first host's batch provides the structure for
                        # Pack() at the end.
                        first_batch = batch
                    flat_batch = batch.FlattenItems()

                    if tpu_embedding is not None:
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        # Each core needs its own dict. `[{}] * n` would put
                        # the same dict object in every slot, so each core
                        # would end up with the last core's data.
                        enqueue_dict_per_core = [
                            {} for _ in range(num_cores_per_host)
                        ]
                        for tpu_embedding_input_key, tpu_embedding_feature in tpu_embedding_features:
                            # Split the feature into one piece per core.
                            tpu_embedding_feature_splitted = tf.split(
                                tpu_embedding_feature, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_embedding_feature_splitted):
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    tf.squeeze(split, axis=[1]))
                                enqueue_dict_per_core[core][
                                    tpu_embedding_input_key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    # Every tensor that is fed must have a static shape.
                    shapes, types = [], []
                    for k, x in flat_batch:
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                        shapes.append(x.shape)
                        types.append(x.dtype)
                    q = tf.contrib.tpu.InfeedQueue(tuple_types=types,
                                                   tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def _tpu_ordinal_function(shard_index_in_host):
                            # Returns the TPU ordinal for a shard of the
                            # current host; reads task_id from the enclosing
                            # loop at call time.
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            [v for _, v in flat_batch],
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=_tpu_ordinal_function)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            [v for _, v in flat_batch],
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            # A single op that runs every enqueue op built above.
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        # Dequeue from the first queue only, on TPU core 0.
        with tf.device(tf.compat.v1.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return first_batch.Pack(tensors)
Exemple #9
0
  def __init__(self, params):
    """Initializes the task from `params`.

    Adjusts the eval-related input params, creates the 'input' child from the
    params returned by `self.cluster.PlaceInput`, resets the bookkeeping
    attributes, resolves the global step variable, and creates the 'learners'
    children for the jobs that need them.

    Args:
      params: The task params; `params.cls` must be a subclass of `BaseTask`.

    Raises:
      ValueError: If more than one task global step is found for `p.name`.
    """
    assert issubclass(params.cls, BaseTask)
    # The global step has to exist before the base constructor runs.
    py_utils.GetOrCreateGlobalStepVar()
    super(BaseTask, self).__init__(params)

    p = self.params

    if p.input:
      # TODO(zhifengc): Consider a simpler way to ensure the input
      # generator stops after one epoch.
      if p.is_eval and p.eval:
        reads_from_files = issubclass(
            p.input.cls, base_input_generator.BaseInputGeneratorFromFiles)
        num_samples = p.input.num_samples
        if num_samples == 0:
          # The dataset size is unknown, so the eval summary has to be
          # computed over an explicitly configured number of samples.
          assert p.eval.samples_per_summary > 0
        else:
          per_summary = p.eval.samples_per_summary
          if per_summary == 0 or num_samples < per_summary:
            # The dataset size is known and the whole set is to be
            # evaluated: make the input generator flush out all samples so
            # that the evaler and decoder compute metrics on the whole set
            # for each summary step.
            if reads_from_files:
              p.input.flush_every_n = num_samples
            p.eval.samples_per_summary = num_samples
        if reads_from_files and p.input.num_batcher_threads > 1:
          tf.logging.warning('input.num_batcher_threads > 1 inside eval mode.  '
                             'The input generator may not iterate over exactly '
                             'one epoch per run')
      tf.logging.info('input_params: %s', p.input)
      placed_input_params = self.cluster.PlaceInput(p.input)
      with py_utils.outside_all_rewrites():
        self.CreateChild('input', placed_input_params)

    # Sub-networks, populated later.
    self._encoder = None
    self._online_encoder = None
    self._decoder = None

    # Training/eval bookkeeping, populated later.
    self._loss = None
    self._num_predictions = None
    self._train_op = None
    self._eval_metrics = {}
    self._per_example = {}
    self._trainer_verbose_tensors = {}

    # Create the gradient mask,
    self._per_input_gradient_mask = None

    # Use the task-specific global step if exactly one is registered for this
    # task; otherwise fall back to the shared global step variable.
    task_steps = tf.get_collection('TASK_GLOBAL_STEP',
                                   '^%s_global_step' % p.name)
    if len(task_steps) > 1:
      raise ValueError('Found multiple task_global_step for task %s' % p.name)
    if len(task_steps) == 1:
      self._global_step_var = task_steps[0]
    else:
      self._global_step_var = py_utils.GetOrCreateGlobalStepVar()
    self._global_step = tf.identity(
        self._global_step_var, name='global_step_tensor')

    train_params = p.train
    # p.train can be None if this task is the teacher/student task in a
    # DistillationTask.
    learner_jobs = ('worker', 'trainer', 'trainer_client', 'controller',
                    'executor_tpu')
    if train_params and self.cluster.job in learner_jobs:
      self._SetLearnerFromLegacyParams(train_params)
      learner = train_params.learner
      if learner is not None:
        # CreateChildren always receives a sequence of learner params.
        if not isinstance(learner, (list, tuple)):
          learner = [learner]
        self.CreateChildren('learners', learner)
    self._UpdateVnConfig()
Exemple #10
0
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch.

        One `InfeedQueue` is built per infeed host: `cluster.num_tpu_hosts`
        of them when `p.use_per_host_infeed` is set, otherwise a single one.
        For every infeed host the preprocessed input batch is fetched on that
        host's CPU device and split into enqueue ops. If a TPU embedding is
        registered in the `py_utils.TPU_EMBEDDING` collection, its enqueue
        ops are generated from the batch as well.

        Side effects:
          * Sets `self._made_tpu_infeed` to True.
          * Stores the grouped enqueue op in `self.tpu_infeed_op`.
          * Adds the grouped enqueue op to the `py_utils.ENQUEUE_OPS`
            collection `p.tpu_infeed_parallelism` times.

        Returns:
          The batch built for the last infeed host, re-packed (via
          `batch.Pack`) with the tensors dequeued from the first queue on
          TPU core 0.

        Raises:
          AssertionError: if `cluster.num_tpu_hosts` is not positive, if not
            running on TPU, if the infeed has already been made, or if any
            tensor in the batch does not have a fully defined shape.
        """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        # Validate before dividing, so that a bad host count fails with a
        # descriptive message rather than a bare ZeroDivisionError.
        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            # Each infeed host feeds an equal share of the shards.
            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            # Use the first registered TPU embedding, if there is one.
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if 'bucket_keys' in batch:
                        # Hack: bucket_keys are not needed on TPU.
                        del batch['bucket_keys']
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        # Kept in its own local so that the cluster-derived
                        # num_cores_per_host above is not overwritten.
                        emb_cores_per_host = tpu_embedding.num_cores_per_host
                        enqueue_dict_per_core = [
                            {} for _ in range(emb_cores_per_host)
                        ]
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, emb_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption that
                                # -1 is the padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            """Maps a shard index on this host to a TPU ordinal."""
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        # For executor-driven multiple programs, we need more fine-grained
        # access rather than using a single global graph collection.
        self.tpu_infeed_op = tpu_infeed_op

        # Only the first queue is dequeued; `batch` is the one built for the
        # last infeed host and is used here to re-pack the flat tensors.
        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Exemple #11
0
  def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
    """Creates the FTRL slot variables and the TPU load/retrieve ops.

    For every table shard two non-trainable slot variables are created on the
    shard's device: an accumulator named `<shard variable name>/Ftrl` and a
    linear term named `<shard variable name>/Ftrl_1`. When running on TPU the
    slot variables are additionally removed from the table's variable list
    and ops that load / retrieve the shard together with its slots are built.

    Args:
      table_vars: The embedding table shard variables; they are paired in
        order with host ids `0 .. num_tpu_hosts - 1`.
      tpu_embedding_table: The embedding table object that owns the shards.
        The slot variables are created through it.

    Returns:
      A `(load_op_list, retrieve_op_list)` tuple. Both lists stay empty when
      not running on TPU.
    """
    p = self.params
    table = tpu_embedding_table
    num_shards = table.params.num_tpu_hosts
    emb_table_name = table.table_name
    collections = [table.__class__.__name__ + '_vars']

    def _MakeSlot(shard_id, shard_var, init_value, suffix):
      """Creates one non-trainable slot var; returns (name, variable)."""
      slot_params = py_utils.WeightParams(
          shape=shard_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(init_value),
          dtype=p.dtype,
          collections=collections)
      slot_name = table.GetVariableName(shard_id) + suffix
      table.CreateVariable(slot_name, slot_params, trainable=False)
      return slot_name, table.vars[slot_name]

    load_ops = []
    retrieve_ops = []
    for shard_id, shard_var in zip(range(num_shards), table_vars):
      # Slot variables are placed on the same device as their table shard.
      with tf.device(
          table.GetDeviceName(shard_id)), py_utils.outside_all_rewrites():
        acc_name, acc_var = _MakeSlot(
            shard_id, shard_var, p.initial_accumulator_value, '/Ftrl')
        lin_name, lin_var = _MakeSlot(
            shard_id, shard_var, p.initial_linear_value, '/Ftrl_1')

        if not py_utils.use_tpu():
          # Only the Trainer needs the load/retrieve ops below.
          continue

        # Drop the slot vars from the variable list so that they are not
        # copied to TPU.
        _RemovePrivateVar(table, acc_name)
        _RemovePrivateVar(table, lin_name)

        # TPU Embedding load/retrieve ops must live in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          tpu_ops = tpu_embedding_lib.tpu_ops
          load_ops.append(
              tpu_ops.load_tpu_embedding_ftrl_parameters(
                  parameters=shard_var,
                  accumulators=acc_var,
                  linears=lin_var,
                  table_name=emb_table_name,
                  num_shards=num_shards,
                  shard_id=shard_id))

          new_table, new_acc, new_lin = (
              tpu_embedding_lib.tpu_ops.retrieve_tpu_embedding_ftrl_parameters(
                  table_name=emb_table_name,
                  num_shards=num_shards,
                  shard_id=shard_id))
          # Write the retrieved values back into the host-side variables.
          assign_ops = [
              tf.assign(shard_var, new_table),
              tf.assign(acc_var, new_acc),
              tf.assign(lin_var, new_lin),
          ]
          retrieve_ops.append(
              tpu_embedding_lib.control_flow_ops.group(*assign_ops))

    return load_ops, retrieve_ops
Exemple #12
0
  def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
    """Creates the Adam slot variables and the TPU load/retrieve ops.

    For every table shard two non-trainable, zero-initialized slot variables
    are created on the shard's device: `<shard variable name>/Adam/m` and
    `<shard variable name>/Adam/v`. When running on TPU the slot variables
    are additionally removed from the table's variable list and ops that
    load / retrieve the shard together with its slots are built.

    Args:
      table_vars: The embedding table shard variables; they are paired in
        order with host ids `0 .. num_tpu_hosts - 1`.
      tpu_embedding_table: The embedding table object that owns the shards.
        The slot variables are created through it.

    Returns:
      A `(load_op_list, retrieve_op_list)` tuple. Both lists stay empty when
      not running on TPU.
    """
    p = self.params
    table = tpu_embedding_table
    num_shards = table.params.num_tpu_hosts
    emb_table_name = table.table_name
    collections = [table.__class__.__name__ + '_vars']

    def _MakeZeroSlot(shard_id, shard_var, suffix):
      """Creates one zero-filled, non-trainable slot; returns (name, var)."""
      slot_params = py_utils.WeightParams(
          shape=shard_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(0.0),
          dtype=p.dtype,
          collections=collections)
      slot_name = table.GetVariableName(shard_id) + suffix
      table.CreateVariable(slot_name, slot_params, trainable=False)
      return slot_name, table.vars[slot_name]

    load_ops = []
    retrieve_ops = []
    for shard_id, shard_var in zip(range(num_shards), table_vars):
      # Slot variables are placed on the same device as their table shard.
      with tf.device(
          table.GetDeviceName(shard_id)), py_utils.outside_all_rewrites():
        m_name, m_slot = _MakeZeroSlot(shard_id, shard_var, '/Adam/m')
        v_name, v_slot = _MakeZeroSlot(shard_id, shard_var, '/Adam/v')

        if not py_utils.use_tpu():
          # Only the Trainer needs the load/retrieve ops below.
          continue

        # Drop the slot vars from the variable list so that they are not
        # copied to TPU.
        _RemovePrivateVar(table, m_name)
        _RemovePrivateVar(table, v_name)

        # TPU Embedding load/retrieve ops must live in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          tpu_ops = tpu_embedding_lib.tpu_ops
          load_ops.append(
              tpu_ops.load_tpu_embedding_adam_parameters(
                  parameters=shard_var,
                  momenta=m_slot,
                  velocities=v_slot,
                  table_name=emb_table_name,
                  num_shards=num_shards,
                  shard_id=shard_id))

          new_table, new_m, new_v = (
              tpu_embedding_lib.tpu_ops.retrieve_tpu_embedding_adam_parameters(
                  table_name=emb_table_name,
                  num_shards=num_shards,
                  shard_id=shard_id))
          # Write the retrieved values back into the host-side variables.
          assign_ops = [
              tf.assign(shard_var, new_table),
              tf.assign(m_slot, new_m),
              tf.assign(v_slot, new_v),
          ]
          retrieve_ops.append(
              tpu_embedding_lib.control_flow_ops.group(*assign_ops))

    return load_ops, retrieve_ops