Example #1
    def FProp(self, theta, inputs, paddings, class_emb):
        """Apply batch normalization.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [batch, ..., dim].
      paddings: The paddings tensor.  Shaped [batch, ..., 1], with the same rank
        as the input tensor.
      class_emb: The conditioning inputs. Shaped [batch, emb_dim].

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
        if py_utils.testonly_skip_norm_layers():
            return inputs

        p = self.params
        batch = py_utils.GetShape(inputs)[0]
        class_emb = py_utils.HasShape(class_emb, [batch, p.class_emb_dim])
        if not py_utils.use_tpu():
            class_emb = py_utils.with_dependencies([
                py_utils.assert_less_equal(
                    tf.cast(class_emb, tf.int32), 1, name='one_hot_assert1'),
                py_utils.assert_greater_equal(
                    tf.cast(class_emb, tf.int32), 0, name='one_hot_assert2'),
                py_utils.assert_equal(tf.ones([batch], tf.int32),
                                      tf.cast(tf.reduce_sum(class_emb, -1),
                                              tf.int32),
                                      name='one_hot_assert3'),
            ], class_emb)

        with tf.name_scope(p.name):
            norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
                theta, inputs, paddings=paddings, class_emb=class_emb)
            return self._ComputeBN(inputs, paddings, gamma, beta, norm_mean,
                                   norm_variance)
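
A note on the CPU-only checks above: they verify that each row of class_emb is one-hot before it is used for conditioning. A minimal standalone sketch of the same invariant, assuming TF1-style assert ops and a hypothetical helper name:

import tensorflow.compat.v1 as tf

def assert_one_hot_rows(class_emb):
    # Every entry must be 0 or 1, and every row must sum to exactly 1.
    ids = tf.cast(class_emb, tf.int32)
    batch = tf.shape(class_emb)[0]
    checks = [
        tf.assert_less_equal(ids, 1),
        tf.assert_greater_equal(ids, 0),
        tf.assert_equal(tf.ones([batch], tf.int32),
                        tf.cast(tf.reduce_sum(class_emb, -1), tf.int32)),
    ]
    with tf.control_dependencies(checks):
        return tf.identity(class_emb)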
Example #2
    def FProp(self, theta, inputs):
        """Apply projection to inputs.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor.  Shaped [..., input_dims].

    Returns:
      Projected inputs.
    """
        p = self.params
        with tf.name_scope(p.name):
            computation_cost.Add(
                self, 'flops',
                tf.reduce_prod(tf.cast(tf.shape(inputs)[:-1], tf.int64)) *
                tf.cast(
                    symbolic.EvalExpr(symbolic.TENSOR_VALUES, p.input_dims *
                                      p.output_dims), tf.int64) * 2)
            use_tpu = py_utils.use_tpu()
            shape = inputs.shape
            if use_tpu and (shape is not None and shape.rank is not None
                            and shape.rank < 26):
                # Avoids reshape if feasible and uses Einsum.
                if shape.rank == 2:
                    return tf.matmul(inputs, theta.w)
                else:
                    s = ''.join([chr(x) for x in range(97, 123)])  # abc...xyz
                    r = shape.rank
                    return tf.einsum('{0}y,yz->{0}z'.format(s[:r - 1]), inputs,
                                     theta.w)

            input_dim = py_utils.GetShape(inputs)[-1]
            act = tf.matmul(tf.reshape(inputs, [-1, input_dim]), theta.w)
            output_dim = tf.shape(theta.w)[-1]
            act = tf.reshape(
                act, tf.concat([tf.shape(inputs)[:-1], [output_dim]], axis=0))
            return act
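
For reference, the einsum equation built above assigns one lowercase letter per leading dimension and reserves 'y' and 'z' for the projection; requiring shape.rank < 26 keeps the generated subscripts within 'a'..'x' plus the two reserved letters. A small sketch of the construction:

def einsum_equation(rank):
    # 'abc...' covers the leading dims; 'y' is input_dims, 'z' is output_dims.
    letters = ''.join(chr(x) for x in range(97, 123))  # 'a'..'z'
    return '{0}y,yz->{0}z'.format(letters[:rank - 1])

assert einsum_equation(3) == 'aby,yz->abz'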
Example #3
 def infeed_bucket_batch_limit(self):
     """Returns the bucket batch limit for one infeed host."""
     p = self.params
     cluster = self.cluster
     infeed_bucket_batch_limit = [
         b * cluster.num_splits_per_client for b in p.bucket_batch_limit
     ]
     if p.use_per_host_infeed and cluster.num_tpu_hosts > 0:
         if not py_utils.use_tpu():
             raise ValueError(
                 'Scaling to TPU hosts without TPUs. {}'.format(
                     cluster.num_tpu_hosts))
         tf.logging.info(
             'scaling infeed_bucket_batch_limit num_tpu_hosts={}'.format(
                 cluster.num_tpu_hosts))
         infeed_bucket_batch_limit = [
             x // cluster.num_tpu_hosts for x in infeed_bucket_batch_limit
         ]
     tf.logging.info(
         'infeed_bucket_batch_limit={} num_splits_per_client={} bucket_batch_limit={}'
         .format(infeed_bucket_batch_limit, cluster.num_splits_per_client,
                 p.bucket_batch_limit))
     return infeed_bucket_batch_limit
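
As a worked example of the scaling above (hypothetical numbers): each limit is first multiplied by num_splits_per_client, then divided evenly across TPU hosts when per-host infeed is enabled.

bucket_batch_limit = [48]   # p.bucket_batch_limit (hypothetical)
num_splits_per_client = 2
num_tpu_hosts = 4           # per-host infeed enabled
per_host_limit = [
    b * num_splits_per_client // num_tpu_hosts for b in bucket_batch_limit
]
assert per_host_limit == [24]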
Example #4
    def _CreateLayerVariables(self):
        p = self.params
        w_pc = py_utils.WeightParams(
            shape=[self._ids_per_shard, p.embedding_dim],
            init=p.params_init,
            dtype=p.dtype,
            collections=[self.__class__.__name__ + '_vars'])

        embedding_table_vars = []
        for i in range(p.num_tpu_hosts):
            device_name = self.GetDeviceName(i)
            with tf.device(device_name), py_utils.outside_all_rewrites():
                var_name = self.GetVariableName(i)
                self.CreateVariable(var_name, w_pc)
                embedding_var = self.vars[var_name]
                embedding_table_vars.append(embedding_var)
                # Remove from _private_vars / _private_theta to be added later as wm.
                del self._private_vars[var_name]
                del self._private_theta[var_name]

        self._tpu_embedding_collection.AddTableVariables(
            self.table_name, embedding_table_vars)

        if not py_utils.use_tpu():
            # We don't want to add this for TrainerTpu, otherwise the identity
            # reference leads to copying the embedding to the TPU for no reason.
            # However, this is needed for CPU (eval/decode/controller).
            self._private_vars['wm'] = embedding_table_vars
            self._private_theta['wm'] = [
                tf.identity(v) for v in embedding_table_vars
            ]

        # Only trainer and controller need slot variables and load/retrieve ops.
        if not self.do_eval:
            self._load_op_list, self._retrieve_op_list = (
                self.optimizer.CreateSlotVariablesAndOps(
                    embedding_table_vars, self))
Example #5
 def _Moments(self, inputs, group_size):
     """Computes mean and variance over N,H,W dimensions in inputs."""
     counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
         inputs, axes=[0, 1, 2], keep_dims=False)
     self.accumulators.counts.Update(counts)
     self.accumulators.mean_ss.Update(mean_ss)
     self.accumulators.variance_ss.Update(variance_ss)
     if py_utils.use_tpu() and group_size > 1:
         num_shards = tpu_function.get_tpu_context().number_of_shards
         assert num_shards >= group_size
         assert num_shards % group_size == 0
         num_groups = num_shards // group_size
         group_assignment = []
         for g in range(num_groups):
             replica_ids = [g * group_size + i for i in range(group_size)]
             group_assignment.append(replica_ids)
         counts *= group_size
         mean_ss = tf.contrib.tpu.cross_replica_sum(mean_ss,
                                                    group_assignment)
         variance_ss = tf.contrib.tpu.cross_replica_sum(
             variance_ss, group_assignment)
     mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss,
                                              None)
     return mean, variance
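
The group_assignment above partitions replica ids into contiguous groups for the cross-replica sums. A minimal sketch of the same construction with hypothetical shard counts:

def make_group_assignment(num_shards, group_size):
    assert num_shards % group_size == 0
    num_groups = num_shards // group_size
    return [[g * group_size + i for i in range(group_size)]
            for g in range(num_groups)]

assert make_group_assignment(8, 4) == [[0, 1, 2, 3], [4, 5, 6, 7]]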
Example #6
    def _NestedMapFromBatchedOutputs(self, outputs):
        """Create a NestedMap from a tuple of outputs from generic_input_op."""
        batch_size = self.InfeedBatchSize()
        shapes = self.Shape()
        shapes.VLog(0, 'input extractor shape: ')
        flatten_shapes = shapes.Flatten()
        dtypes = self.DType()
        assert dtypes.IsCompatible(shapes), '{} vs. {}'.format(
            dtypes.DebugString(), shapes.DebugString())
        flatten_dtypes = dtypes.FlattenItems()
        assert len(flatten_shapes) == len(outputs), '{} vs. {}'.format(
            len(flatten_shapes), len(outputs))
        assert len(flatten_dtypes) == len(outputs), '{} vs. {}'.format(
            len(flatten_dtypes), len(outputs))

        rets = []
        for (output, (name, dtype), shape) in zip(outputs, flatten_dtypes,
                                                  flatten_shapes):
            assert dtype == output.dtype, '{}: {} vs. {}'.format(
                name, dtype, output.dtype)
            # Pad every output to make shapes fixed according to the corresponding
            # declared shape, since the shapes of outputs are lost through
            # generic_input_op.
            try:
                shape.assert_is_fully_defined()
            except ValueError as e:
                raise ValueError('Invalid shape for %s: %s' % (name, e))
            padded = py_utils.PadOrTrimTo(output,
                                          [batch_size] + shape.as_list())
            rets += [padded]

        rets = shapes.Pack(rets)
        if py_utils.use_tpu():
            # Drops tf.string tensors, which are not supported on TPUs.
            rets = rets.Filter(lambda x: x.dtype != tf.string)
        return rets
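
py_utils.PadOrTrimTo above forces each output back to its declared static shape, since shapes are lost through generic_input_op. A rough single-axis equivalent, using numpy purely for illustration (the real helper handles all axes):

import numpy as np

def pad_or_trim_axis0(x, length):
    # Truncates, or zero-pads, along axis 0 to exactly `length` rows.
    if x.shape[0] >= length:
        return x[:length]
    pad = np.zeros((length - x.shape[0],) + x.shape[1:], dtype=x.dtype)
    return np.concatenate([x, pad], axis=0)

assert pad_or_trim_axis0(np.ones((2, 3)), 4).shape == (4, 3)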
Example #7
    def __init__(self, params):
        super().__init__(params)
        p = self.params
        self._before_layers = []
        self._cells = []

        num_cells = len(p.cell_tpl)
        before_tpl_device = ''
        cell_devices = [''] * num_cells
        if py_utils.use_tpu():
            cluster = self.cluster
            before_tpl_device = cluster.WorkerDeviceInModelSplit(0)
            cell_devices = [
                cluster.WorkerDeviceInModelSplit(i) for i in range(num_cells)
            ]

        for l in p.before_tpl:
            with tf.device(before_tpl_device):
                self.CreateChild(l.name, l)
            self._before_layers.append((l.name, self.children[l.name]))
        for i, l in enumerate(p.cell_tpl):
            with tf.device(cell_devices[i]):
                self.CreateChild(l.name, l)
            self._cells.append((l.name, self.children[l.name]))
Example #8
    def _ProcessMASSInput(self, source_id, src):
        """Perform MASS input processing."""
        if self.do_eval or self.mass_layer is None:
            # At eval time, we copy src to tgt.
            return self._ProcessSingleInput(source_id, src, src)

        _, labels, paddings = self.StringsToIds(tf.reshape(src, [1]),
                                                is_source=True,
                                                key=self._src_tokenizer_key)
        weights = 1 - paddings
        actual_seq_len = tf.cast(tf.reduce_sum(weights, 1), tf.int32)
        src_lang_ids, tgt_lang_ids = self._GetTaskIds(source_id)

        mass_out = self.mass_layer.Mask(labels, weights, actual_seq_len)

        features = py_utils.NestedMap()
        features.src = py_utils.NestedMap()
        features.src.ids = mass_out.src.ids
        features.src.paddings = paddings
        features.src.weights = weights
        features.src.task_ids = tf.cast(features.src.weights,
                                        dtype=tf.int32) * src_lang_ids
        features.src.ids_indicator = weights
        features.tgt = py_utils.NestedMap()
        features.tgt.ids = mass_out.tgt.ids
        features.tgt.labels = mass_out.tgt.labels
        features.tgt.paddings = paddings
        features.tgt.weights = mass_out.tgt.weights
        features.tgt.task_ids = tf.ones_like(features.src.task_ids,
                                             dtype=tf.int32) * tgt_lang_ids
        features.tgt.ids_indicator = weights

        if not py_utils.use_tpu():
            features.src.strs = src
            features.tgt.strs = src
        return features.Transform(tf.squeeze)
Example #9
    def SplitInputBatch(self, num_splits):
        """Splits the current InputBatch into num_splits ways.

    Args:
      num_splits: The number of splits.

    Returns:
      A list of `.NestedMap`. Each `.NestedMap` represents the input
      tensors in one split.
    """
        assert num_splits >= 1

        batch = self.GetPreprocessedInputBatch()
        if num_splits == 1:
            # Special case. No split is needed.
            return [batch]

        assert not py_utils.use_tpu()
        field_split = ig_helper.SplitTensors(batch.Flatten(), num_splits)
        num_fields = len(field_split)
        ret = []
        for j in range(num_splits):
            split_flatten = [field_split[i][j] for i in range(num_fields)]
            split = batch.Pack(split_flatten)
            ret += [split]
        return ret
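
The indexing above transposes a per-field list of splits into a per-split list of fields: field_split[i][j] is field i's slice for split j. A small numpy sketch of the same reshuffle with hypothetical data:

import numpy as np

fields = [np.arange(4), np.arange(4) * 10]   # two flattened batch fields
num_splits = 2
field_split = [np.split(f, num_splits) for f in fields]
splits = [[field_split[i][j] for i in range(len(fields))]
          for j in range(num_splits)]
# splits[0] == [array([0, 1]), array([ 0, 10])]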
Example #10
    def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
        load_op_list = []
        retrieve_op_list = []

        num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
        table_name = tpu_embedding_table.table_name

        for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
            # The slot vars should be on the same device as the table var.
            device_name = tpu_embedding_table.GetDeviceName(host_id)
            with tf.device(device_name), py_utils.outside_all_rewrites():
                # Only the Trainer needs these ops.
                if py_utils.use_tpu():
                    # TPU Embedding load/retrieve ops need to be in the outer graph scope.
                    with tf.init_scope():
                        tf.logging.info('creating load and retrieve ops.')
                        load_parameters_op = (
                            tpu_embedding_lib.tpu_ops.
                            load_tpu_embedding_stochastic_gradient_descent_parameters(
                                parameters=table_var,
                                table_name=table_name,
                                num_shards=num_tpu_hosts,
                                shard_id=host_id))
                        load_op_list.append(load_parameters_op)

                        retrieved_table = (
                            tpu_embedding_lib.tpu_ops.
                            retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
                                table_name=table_name,
                                num_shards=num_tpu_hosts,
                                shard_id=host_id))
                        retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
                            tf.assign(table_var, retrieved_table))
                        retrieve_op_list.append(retrieve_parameters_op)

        return load_op_list, retrieve_op_list
Example #11
  def CreateTpuFeeds(self):
    """Creates the TPU infeed queue from preprocessed batch."""
    p = self.params
    cluster = cluster_factory.Current()
    num_tpu_hosts = cluster.num_tpu_hosts
    assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
    num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

    with py_utils.outside_all_rewrites():
      assert py_utils.use_tpu()
      assert not self._made_tpu_infeed

      shards = tpu_function.get_tpu_context(
      ).number_of_shards // num_infeed_hosts
      input_ops_list = []
      queues = []
      first_batch = None
      for task_id in range(num_infeed_hosts):
        host_device = '/task:{}/device:CPU:0'.format(task_id)
        with tf.device(host_device):
          batch = self.GetPreprocessedInputBatch()
          if first_batch is None:
            first_batch = batch
          flat_batch = batch.FlattenItems()

          shapes, types = [], []
          for k, x in flat_batch:
            assert x.shape.is_fully_defined(), (
                'Shape must be fully defined: %s: %s' % (k, x))
            # TODO(cwhipkey): if it's a string (or other type not supported on
            # TPU), drop it from feeding and on the other end add in an op that
            # fails if used.
            shapes.append(x.shape)
            types.append(x.dtype)
          q = tf.contrib.tpu.InfeedQueue(tuple_types=types, tuple_shapes=shapes)
          queues.append(q)
          assert shards is not None
          q.set_number_of_shards(shards)

          if p.use_per_host_infeed:

            # TODO(ylc/zhifengc): Add this to a policy module and test it.
            def _tpu_ordinal_function(shard_index_in_host):
              device_assignment = py_utils.GetTpuDeviceAssignment()
              if device_assignment:
                # We put both enqueue/dequeue ops at core 0 in each replica.
                replica = device_assignment.lookup_replicas(
                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                return device_assignment.tpu_ordinal(replica=replica)
              else:
                return shard_index_in_host

            input_ops = q.split_inputs_and_generate_enqueue_ops(
                [v for _, v in flat_batch],
                placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                tpu_ordinal_function=_tpu_ordinal_function)
          else:
            input_ops = q.split_inputs_and_generate_enqueue_ops(
                [v for _, v in flat_batch],
                device_assignment=py_utils.GetTpuDeviceAssignment())

          input_ops_list += input_ops
      tf.logging.info('input_ops_list %s', input_ops_list)
      tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    with tf.device(tf.contrib.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
    return first_batch.Pack(tensors)
Example #12
 def _process(source_id, record):
     del source_id
     num = tf.strings.to_number(record, tf.int32)
     if not tf_py_utils.use_tpu():
         num = num * num
     return py_utils.NestedMap(num=num), 1
Example #13
    def _CreateLayerVariables(self):
        super()._CreateLayerVariables()
        p = self.params

        load_op_list = []
        retrieve_op_list = []

        # At the feature level, track which are associated
        # with "sequence embeddings".
        self._sequence_features = {}

        if py_utils.use_tpu():
            num_cores = self.cluster.params.worker.tpus_per_replica
            global_batch_size = (self.params.batch_size *
                                 self.cluster.num_splits_per_client)
            table_to_config_dict = {}
            feature_to_config_dict = {}
            for table in self.tables:
                table_to_config_dict[table.table_name] = table.table_config
                load_op_list += table.load_op_list
                retrieve_op_list += table.retrieve_op_list
                for feature in table.input_keys:
                    if table.max_sequence_length > 0:
                        self._sequence_features[feature] = True
                    feature_to_config_dict[
                        feature] = tpu_embedding_lib.FeatureConfig(
                            table.table_name,
                            max_sequence_length=table.max_sequence_length)
            tf.logging.info('adding load and retrieve ops to collection.')
            tf.add_to_collection(py_utils.TPU_EMBEDDING_LOAD_OPS, load_op_list)
            tf.add_to_collection(py_utils.TPU_EMBEDDING_RETRIEVE_OPS,
                                 retrieve_op_list)

            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            assert len(tpu_embedding_collection) <= 1
            if len(tpu_embedding_collection) == 1:
                tf.logging.info(
                    'TPUEmbedding API singleton already exists, reusing')
                self._tpu_embedding = tpu_embedding_collection[0]
            else:
                mode = tpu_embedding_lib.TRAINING
                device_config = tpu_embedding_lib.DeviceConfig(
                    num_cores=num_cores,
                    num_hosts=self.params.tables[0].num_tpu_hosts,
                    job_name=self.cluster.params.worker.name)
                self._tpu_embedding = tpu_embedding_lib.TPUEmbedding(
                    table_to_config_dict,
                    feature_to_config_dict,
                    global_batch_size,
                    mode,
                    master=None,
                    pipeline_execution_with_tensor_core=(
                        self.params.pipeline_execution_with_tensor_core),
                    partition_strategy=p.partition_strategy,
                    device_config=device_config)
                tf.add_to_collection(py_utils.TPU_EMBEDDING,
                                     self._tpu_embedding)
                tf.add_to_collection(
                    py_utils.TPU_EMBEDDING_GRADIENT_MULTIPLIER_SCHEDULE,
                    self.gradient_multiplier_schedule)
Example #14
  def ComputePredictions(self, theta, source_encs, source_paddings, targets,
                         src_segment_id):
    """Decodes `targets` given encoded source.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      source_encs: source encoding, of shape [time, batch, depth].
      source_paddings: source encoding's padding, of shape [time, batch].
      targets: A dict of string to tensors representing the targets one tries
        to predict. Each tensor in targets is of shape [batch, time].
      src_segment_id: source segment id, of shape [time, batch].

    Returns:
      A Tensor with shape [time, batch, params.softmax.input_dim].
    """
    p = self.params
    time, batch = py_utils.GetShape(source_paddings, 2)
    source_encs = py_utils.HasShape(source_encs, [time, batch, p.source_dim])
    with tf.name_scope(p.name):
      target_ids = tf.transpose(targets.ids)
      target_paddings = py_utils.HasRank(targets.paddings, 2)
      target_paddings = tf.expand_dims(tf.transpose(target_paddings), 2)
      if p.packed_input:
        target_segment_id = tf.expand_dims(tf.transpose(targets.segment_ids), 2)
      else:
        target_segment_id = tf.zeros_like(target_paddings)

      if py_utils.use_tpu():
        emb_device = self.cluster.WorkerDeviceInModelSplit(0)
      else:
        emb_device = ''
      with tf.device(emb_device):
        inputs = self.emb.EmbLookup(theta.emb, target_ids)
        inputs = self.ApplyClipping(theta, inputs)
        summary_utils.histogram('input_emb', inputs)
        inputs = self.ApplyDropout(inputs)
        self._emb_out = inputs

        # Layer 0 intertwines with attention.
        (atten_ctxs, xs, atten_probs, _) = self.frnn_with_atten.FProp(
            theta.frnn_with_atten,
            source_encs,
            source_paddings,
            inputs,
            target_paddings,
            src_segment_id=src_segment_id,
            segment_id=target_segment_id)
        self._AddAttenProbsSummary(source_paddings, targets, [atten_probs])

        atten_ctxs = self.ApplyClipping(theta, atten_ctxs)
        summary_utils.histogram('atten_ctxs', atten_ctxs)

        for i, (layer, layer_theta) in enumerate(zip(self.frnn, theta.frnn)):
          # Forward through Layer-(i + 1) because Layer-0 was handled above.
          ys, _ = layer.FProp(
              layer_theta,
              tf.concat([xs, atten_ctxs], 2),
              target_paddings,
              segment_id=target_segment_id)
          ys = self.ApplyDropout(ys)
          if 1 + i >= p.residual_start:
            xs += ys  # Residual skip
            xs = self.ApplyClipping(theta, xs)
          else:
            xs = ys
          summary_utils.histogram('layer_out_%s' % i, xs)

        if p.feed_attention_context_vec_to_softmax:
          xs = tf.concat([xs, atten_ctxs], 2)

        return xs
Example #15
    def FProp(self, theta):
        """Forward propagation.

    This default `FProp` implementation here supports batch splitting in
    synchronous and asynchronous training when sub-classes implement
    `FPropTower`.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.

    Returns:
      A dict containing metrics pairs. One of the keys should be 'loss' and its
      value should be a (loss, num_predictions) pair.
    """
        p = self.params
        cluster = cluster_factory.Current()

        with tf.name_scope('fprop'), tf.name_scope(p.name):
            all_fprop_metrics = []

            if py_utils.use_tpu():
                batch = self.input_generator.CreateTpuFeeds()
                with tf.name_scope('tower_0_0'):
                    dec_metrics = self.FPropTower(theta, batch)
                all_fprop_metrics.append(dec_metrics)
            else:
                # Splits the input batch on the input device.
                num_splits = cluster.num_splits_per_client
                with tf.device(cluster.input_device):
                    batches = self.input_generator.SplitInputBatch(num_splits)
                    assert num_splits == len(batches)

                # dev_list_per_replica[i][j] is the i-th worker's j-th device.
                dev_list_per_replica = cluster.available_devices.tolist()

                # Asserts invariant of the total number of splits w.r.t.
                # splits per worker.
                splits_per_replica = cluster.num_splits_per_replica
                assert num_splits == splits_per_replica * len(
                    dev_list_per_replica)

                for w_id, w_devs in enumerate(dev_list_per_replica):
                    # Make a local copy of the vars, sharded on this worker's devices.
                    theta_local = py_utils.CreateLocalTheta(theta,
                                                            w_devs,
                                                            label='worker %d' %
                                                            w_id)

                    for s_id in range(splits_per_replica):
                        # s_id-th split for the w_id-th worker.
                        split_id = splits_per_replica * w_id + s_id
                        with py_utils.ModelSplit(split_id):
                            with tf.device(
                                    cluster.WorkerDeviceInModelSplit(0)):
                                with tf.name_scope('tower_%d_%d' %
                                                   (w_id, s_id)):
                                    batch = self.input_generator.PreprocessInputBatch(
                                        batches[split_id])
                                    dec_metrics = self.FPropTower(
                                        theta_local, batch)
                        all_fprop_metrics.append(dec_metrics)

            metrics = py_utils.WeightedAvgOfMetrics(all_fprop_metrics)

        # Adds stats about the input batch.
        metrics['num_samples_in_batch'] = (tf.convert_to_tensor(
            self.input_generator.InputBatchSize()), tf.constant(1.0))
        # Generates summaries.
        for name, (value, weight) in six.iteritems(metrics):
            self.AddEvalMetric(name, value, weight)

        # Loss.
        self._loss, self._num_predicts = metrics['loss']
        self._loss = py_utils.CheckNumerics(self._loss)

        return metrics
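
The nested loops above number towers worker-major, i.e. split_id = splits_per_replica * w_id + s_id. A minimal sketch of the enumeration with hypothetical counts:

splits_per_replica = 2
num_workers = 2
split_ids = {(w, s): splits_per_replica * w + s
             for w in range(num_workers)
             for s in range(splits_per_replica)}
assert split_ids == {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3}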
Example #16
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].

    Returns:
      A NestedMap containing:
        - encoded: The encoded features, either a tensor of shape [time, batch,
            depth], or a list of tensors if is_transparent is set in
            transformer_stack.
        - padding: of shape [time, batch]
        - segment_id: [time, batch] if packed inputs are supported by the model
            (and all layers), or None otherwise.
        - embedded_inputs: [time, batch, depth] embedded inputs tokens without
            positional encodings.
    """

        p = self.params
        with tf.name_scope(p.name):
            src_segment_id = None
            src_segment_pos = None
            input_ids = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2)
            ], input_batch.ids)

            if (not py_utils.use_tpu()
                    and tf.flags.FLAGS.transformer_encoder_truncates_inputs):
                max_seq_length = tf.cast(
                    tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings,
                                                1)), tf.int32)
                paddings = py_utils.with_dependencies([
                    py_utils.assert_equal(
                        tf.constant(True, tf.bool),
                        tf.reduce_all(
                            input_batch.paddings[:, max_seq_length:] > 0.5))
                ], input_batch.paddings)
                input_ids = input_ids[:, :max_seq_length]
                paddings = paddings[:, :max_seq_length]
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids[:, :max_seq_length]
                    src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
            else:
                paddings = input_batch.paddings
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids
                    src_segment_pos = input_batch.segment_pos

            max_time = tf.shape(input_ids)[1]

            # Input token embeddings + positional embeddings
            input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                                  tf.reshape(input_ids, [-1]))
            input_embs = tf.reshape(input_embs,
                                    [-1, max_time, p.token_emb.embedding_dim])
            # [time, batch, dim]
            orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

            if p.packed_input:
                position_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, src_segment_pos)
            else:
                position_embs = self.position_emb.FProp(
                    theta.position_emb, max_time)
                position_embs = tf.reshape(
                    position_embs, [1, max_time, p.token_emb.embedding_dim])
            input_embs += position_embs

            if p.model_dim != p.token_emb.embedding_dim:
                input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

            paddings = tf.transpose(paddings)
            if p.packed_input:
                src_segment_id = tf.transpose(src_segment_id)
            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(input_embs, [1, 0, 2])

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, paddings,
            src_segment_id)
        return py_utils.NestedMap(encoded=encoded,
                                  padding=padding,
                                  segment_id=segment_id,
                                  embedded_inputs=orig_input_embs)
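
The truncation branch above derives max_seq_length from the paddings: it is the longest unpadded prefix in the batch. A small numpy sketch (hypothetical paddings, with 1.0 marking padded positions):

import numpy as np

paddings = np.array([[0., 0., 1.],
                     [0., 1., 1.]])
max_seq_length = int(np.max(np.sum(1.0 - paddings, axis=1)))
assert max_seq_length == 2  # only the first two positions carry tokens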
Example #17
    def Transform(self, dataset):
        """Batches a dataset containing NestedMaps of tensors."""
        p = self.params

        require_sequential_order = p.require_sequential_order or self.do_eval
        seqlen_fn = getattr(self._input_generator, p.seqlen_fn)

        def SetBucketKeys(example):
            example.bucket_keys = seqlen_fn(example)
            return example

        dataset = dataset.map(SetBucketKeys,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE,
                              deterministic=require_sequential_order)

        dataset = dataset.filter(
            lambda x: x.bucket_keys <= p.bucket_upper_bound[-1])

        dataset_structure = py_utils.NestedMap.FromNestedDict(
            tf.data.experimental.get_structure(dataset))

        input_shape_fn = getattr(self._input_generator, p.input_shape_fn)
        padded_shapes = dataset_structure.TransformWithKey(
            lambda k, _: tf.TensorShape(input_shape_fn(k)))
        input_padding_fn = getattr(self._input_generator, p.input_padding_fn)
        padding_values = dataset_structure.TransformWithKey(input_padding_fn)

        dataset_structure.VLog(0, 'dataset_structure:')
        padded_shapes.VLog(0, 'padded_shapes:')

        bucket_batch_limit = [
            batch_utils.scale_split_to_infeed(
                b, self._input_generator.params.use_per_host_infeed)
            for b in p.bucket_batch_limit
        ]
        dataset = dataset.apply(
            tf.data.experimental.bucket_by_sequence_length(
                lambda x: x.bucket_keys,
                # Upper-bound for bucket_by_sequence_length is exclusive, so add 1
                # TODO(jeffreyzhao): There is an off-by-one bug with the upper bound
                # boundary check, so add 2 instead. Remove when fixed.
                [x + 2 for x in p.bucket_upper_bound],
                bucket_batch_limit + [1],
                padded_shapes=padded_shapes,
                padding_values=padding_values,
                pad_to_bucket_boundary=True,
                drop_remainder=py_utils.use_tpu()))

        if py_utils.use_tpu():
            # Set static shapes for TPU.
            if min(bucket_batch_limit) != max(bucket_batch_limit):
                raise ValueError('TPU requires constant batch sizes.')
            else:
                b = bucket_batch_limit[0]

                def SetShape(element):
                    for t in element.Flatten():
                        t.set_shape((b, ) + t.shape[1:])
                    return element

                dataset = dataset.map(
                    SetShape,
                    num_parallel_calls=tf.data.experimental.AUTOTUNE,
                    deterministic=require_sequential_order)

        return dataset
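
Concretely, the boundary and batch-size lists above line up as follows (hypothetical bounds), with the trailing batch size of 1 catching the overflow bucket:

bucket_upper_bound = [10, 20]   # p.bucket_upper_bound (hypothetical)
bucket_batch_limit = [64, 32]   # after infeed scaling
boundaries = [x + 2 for x in bucket_upper_bound]   # [12, 22]; see TODO above
batch_sizes = bucket_batch_limit + [1]             # [64, 32, 1]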
Example #18
    def PostProcess(self, dec_out_dict, dec_metrics_dict):
        p = self.params
        assert 'topk_scores' in dec_out_dict, list(dec_out_dict.keys())
        topk_scores = dec_out_dict['topk_scores']
        topk_decoded = dec_out_dict['topk_decoded']
        transcripts = dec_out_dict['transcripts']
        if not py_utils.use_tpu():
            utt_id = dec_out_dict['utt_id']
            assert len(utt_id) == len(transcripts)
        norm_wer_errors = dec_out_dict['norm_wer_errors']
        norm_wer_words = dec_out_dict['norm_wer_words']
        target_labels = dec_out_dict['target_labels']
        target_paddings = dec_out_dict['target_paddings']
        topk_ids = dec_out_dict['topk_ids']
        topk_lens = dec_out_dict['topk_lens']
        assert len(transcripts) == len(target_labels)
        assert len(transcripts) == len(target_paddings)
        assert len(transcripts) == len(topk_decoded)
        assert len(norm_wer_errors) == len(transcripts)
        assert len(norm_wer_words) == len(transcripts)

        num_samples_in_batch = len(transcripts)
        dec_metrics_dict['num_samples_in_batch'].Update(num_samples_in_batch)

        def GetRefIds(ref_ids, ref_paddings):
            assert len(ref_ids) == len(ref_paddings)
            return_ids = []
            for i in range(len(ref_ids)):
                if ref_paddings[i] == 0:
                    return_ids.append(ref_ids[i])
            return return_ids

        total_norm_wer_errs = norm_wer_errors[:, 0].sum()
        total_norm_wer_words = norm_wer_words[:, 0].sum()

        dec_metrics_dict['norm_wer'].Update(
            total_norm_wer_errs / total_norm_wer_words, total_norm_wer_words)

        for ref_str, hyps in zip(transcripts, topk_decoded):
            filtered_ref = decoder_utils.FilterNoise(ref_str)
            filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
            filtered_hyp = decoder_utils.FilterNoise(hyps[0])
            filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
            dec_metrics_dict['corpus_bleu'].Update(filtered_ref, filtered_hyp)

        total_errs = 0
        total_oracle_errs = 0
        total_ref_words = 0
        total_token_errs = 0
        total_ref_tokens = 0
        total_accurate_sentences = 0
        key_value_pairs = []

        if p.include_auxiliary_metrics:
            for i in range(len(transcripts)):
                ref_str = transcripts[i]
                if not py_utils.use_tpu():
                    tf.logging.info('utt_id: %s', utt_id[i])
                if self.cluster.add_summary:
                    tf.logging.info(
                        '  ref_str: %s',
                        ref_str.decode('utf-8') if p.log_utf8 else ref_str)
                hyps = topk_decoded[i]
                num_hyps_per_beam = len(hyps)
                ref_ids = GetRefIds(target_labels[i], target_paddings[i])
                hyp_index = i * num_hyps_per_beam
                top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
                if self.cluster.add_summary:
                    tf.logging.info('  ref_ids: %s', ref_ids)
                    tf.logging.info('  top_hyp_ids: %s', top_hyp_ids)
                total_ref_tokens += len(ref_ids)
                _, _, _, token_errs = decoder_utils.EditDistanceInIds(
                    ref_ids, top_hyp_ids)
                total_token_errs += token_errs

                filtered_ref = decoder_utils.FilterNoise(ref_str)
                filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
                oracle_errs = norm_wer_errors[i][0]
                for n, (score, hyp_str) in enumerate(zip(topk_scores[i],
                                                         hyps)):
                    if self.cluster.add_summary:
                        tf.logging.info(
                            '  %f: %s', score,
                            hyp_str.decode('utf-8') if p.log_utf8 else hyp_str)
                    filtered_hyp = decoder_utils.FilterNoise(hyp_str)
                    filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
                    ins, subs, dels, errs = decoder_utils.EditDistance(
                        filtered_ref, filtered_hyp)
                    # Note that these numbers are not consistent with what is used to
                    # compute normalized WER.  In particular, these numbers will be
                    # inflated when the transcript contains punctuation.
                    tf.logging.info('  ins: %d, subs: %d, del: %d, total: %d',
                                    ins, subs, dels, errs)
                    # Only aggregate scores of the top hypothesis.
                    if n == 0:
                        total_errs += errs
                        total_ref_words += len(
                            decoder_utils.Tokenize(filtered_ref))
                        if norm_wer_errors[i, n] == 0:
                            total_accurate_sentences += 1
                    oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
                total_oracle_errs += oracle_errs

            dec_metrics_dict['wer'].Update(
                total_errs / max(1., total_ref_words), total_ref_words)
            dec_metrics_dict['oracle_norm_wer'].Update(
                total_oracle_errs / max(1., total_ref_words), total_ref_words)
            dec_metrics_dict['sacc'].Update(
                total_accurate_sentences / len(transcripts), len(transcripts))
            dec_metrics_dict['ter'].Update(
                total_token_errs / max(1., total_ref_tokens), total_ref_tokens)

        return key_value_pairs
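
The max(1., ...) guards above keep the ratios defined when a batch contributes no reference words or tokens. A tiny sketch of the pattern (hypothetical helper):

def safe_rate(errors, denom):
    # Mirrors the max(1., denom) guard used for 'wer' and 'ter' above.
    return errors / max(1., denom)

assert safe_rate(3, 20.) == 0.15
assert safe_rate(0, 0.) == 0.0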
Example #19
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            first_batch = None
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_embedding_input_keys = (
                tpu_embedding.feature_to_config_dict.keys()
                if tpu_embedding is not None else [])

            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    tpu_embedding_features = []
                    for tpu_embedding_input_key in tpu_embedding_input_keys:
                        tpu_embedding_feature = batch.pop(
                            tpu_embedding_input_key)
                        tpu_embedding_features.append(
                            (tpu_embedding_input_key, tpu_embedding_feature))

                    if first_batch is None:
                        first_batch = batch
                    flat_batch = batch.FlattenItems()

                    if tpu_embedding is not None:
                        # One independent dict per core; [{}] * n would alias
                        # a single dict across all cores.
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for tpu_embedding_input_key, tpu_embedding_feature in tpu_embedding_features:
                            tpu_embedding_feature_splitted = tf.split(
                                tpu_embedding_feature, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_embedding_feature_splitted):
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    tf.squeeze(split, axis=[1]))
                                enqueue_dict_per_core[core][
                                    tpu_embedding_input_key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    shapes, types = [], []
                    for k, x in flat_batch:
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                        shapes.append(x.shape)
                        types.append(x.dtype)
                    q = tf.contrib.tpu.InfeedQueue(tuple_types=types,
                                                   tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def _tpu_ordinal_function(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            [v for _, v in flat_batch],
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=_tpu_ordinal_function)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            [v for _, v in flat_batch],
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        with tf.device(tf.compat.v1.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return first_batch.Pack(tensors)
Example #20
  def FProp(self, theta, x, paddings=None, update=False):
    """Computes distances of the given input 'x' to all centroids.

    This implementation applies layer normalization on 'x' internally first,
    and the returned 'dists' is computed using the normalized 'x'.

    Args:
      theta: A `.NestedMap` of weights' values of this layer.
      x: A tensor of shape [B, L, N, H].
      paddings: If not None, a tensor of shape [B, L].
      update: bool, whether to update centroids using x.

    Returns:
      dists: "distances" of the given input 'x' to all centroids.
             Shape [B, L, N, K].
      k_means_loss: the average squared Euclidean distances to the closest
                    centroid, a scalar.
    """
    p = self.params
    if paddings is None:
      paddings = tf.zeros_like(x[:, :, 0, 0])
    # Shape [B, L, 1, 1]
    paddings_4d = paddings[:, :, None, None]

    if p.apply_layer_norm:
      x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

    # Since 'x' is normalized (but theta.means is not), we use the negative dot
    # product to approximate the Euclidean distance here.
    dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

    # For padded positions we update the distances to very large numbers.
    very_large_dists = tf.ones_like(dists) * tf.constant(
        0.1, dtype=dists.dtype) * dists.dtype.max
    paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
    dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

    # Shape [B, L, N, K], the same as 'dists' above.
    nearest_one_hot = tf.one_hot(
        tf.math.argmin(dists, axis=-1),
        p.num_clusters,
        dtype=py_utils.FPropDtype(p))
    # Same shape as the input 'x'.
    nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                                 theta.means)
    diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
    diff = py_utils.ApplyPadding(paddings_4d, diff)
    diff = tf.math.reduce_mean(diff, axis=2)

    # The commitment loss which, when backpropagated through, encourages the
    # 'x' values to commit to their chosen centroids.
    k_means_loss = tf.math.reduce_sum(diff) / tf.math.reduce_sum(1.0 - paddings)
    summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

    # TODO(zhouwk): investigate normalizing theta.means after each update.
    means_norm = tf.norm(theta.means)
    summary_utils.scalar('k_means/centroid_l2_norm/min',
                         tf.math.reduce_min(means_norm))
    summary_utils.scalar('k_means/centroid_l2_norm/mean',
                         tf.math.reduce_mean(means_norm))

    if not update:
      return dists, k_means_loss

    # To update the centroids (self.vars.means), we apply gradient descent on
    # the mini-batch of input 'x', which yields the following:
    #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
    # where x_mean is the average over all the input vectors closest to this
    # centroid.
    #
    # Note that this approach is equivalent to backprop via
    #    loss = tf.math.reduce_mean(
    #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid)))
    # , except that here the learning rate is independently set via 'decay'.

    # Ensure that the padded positions are not used to update the centroids.
    nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

    # Sum away batch and sequence length dimensions to get per cluster count.
    # Shape: [N, K]
    per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
    summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

    # Sum of the input 'x' per each closest centroid.
    sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

    if py_utils.use_tpu():
      per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
      sum_x = tf.tpu.cross_replica_sum(sum_x)

    # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
    # cluster's position will always be 0, hence 'sum_x' in that dimension will
    # be 0.
    new_means = sum_x / tf.maximum(
        tf.constant(1.0, dtype=per_cluster_count.dtype),
        tf.expand_dims(per_cluster_count, axis=-1))

    # We use an exponential moving average. TODO(zhouwk): investigate smoothing
    # this over an exponentially moving averaged per cluster count.
    #
    # Note that we intentionally do not normalize the means after this update
    # as empirically this works better.
    update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                                self.vars.means.dtype)
    return py_utils.with_dependencies(
        [tf.assign_add(self.vars.means, update_means_diff)],
        dists), k_means_loss
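
The final update above is plain exponential moving averaging written as an additive diff, new_mean = mean + (1 - decay) * (x_mean - mean). A scalar sketch with hypothetical values:

decay = 0.99
centroid, x_mean = 1.0, 3.0
new_centroid = centroid + (1.0 - decay) * (x_mean - centroid)
assert abs(new_centroid - 1.02) < 1e-12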
Example #21
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTPUFeeds num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            tf.logging.info('shards {}'.format(shards))

            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            if num_tpu_hosts > 1 and tpu_embedding is not None:
                if not p.use_per_host_infeed:
                    tf.logging.fatal(
                        'TPU Embedding must be used with per_host_infeed with multiple '
                        'TPU host topologies.')
            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if isinstance(batch, py_utils.NestedMap):
                        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                        # Note that when MultiTaskData is used, bucket_keys will be at the
                        # second level of the dictionary.
                        batch = batch.FilterKeyVal(
                            lambda k, _: not k.endswith('bucket_keys'))
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    if p.use_partitioned_infeed_queue:
                        device_assignment = py_utils.GetTpuDeviceAssignment()

                        host_device = device_assignment.host_device(
                            replica=0, job=tf.flags.FLAGS.tf_master)
                        host_id = int(
                            host_device.split('/task:')[1].split('/device:')[0])
                        tf.logging.info('host_id: {} host_device: {}'.format(
                            host_id, host_device))
                        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                            number_of_tuple_elements=len(dtypes),
                            device_assignment=device_assignment,
                            host_id=host_id,
                            input_partition_dims=[[p.num_partitions, 1]
                                                  for _ in dtypes],
                            tuple_types=dtypes,
                            tuple_shapes=shapes)
                    else:
                        q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                                 tuple_shapes=shapes)
                        assert shards is not None
                        q.set_number_of_shards(shards)

                    queues.append(q)
                    tf.logging.info('q=%r', q)

                    if p.use_partitioned_infeed_queue:
                        input_ops = q.generate_enqueue_ops([batch.Flatten()])
                    elif p.use_per_host_infeed:
                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment()
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment())

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        self._tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
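
Note: the enqueue loop above converts each dense id feature into sparse EnqueueData before handing it to the TPU embedding. A minimal standalone sketch of that dense-to-sparse step, assuming -1 is the padding id and using toy values:

import tensorflow as tf

# Toy dense feature for one core's split; -1 marks padding (assumption).
split = tf.constant([[3, 7, -1],
                     [5, -1, -1]])
# Coordinates of the real (non-padding) ids, shape [num_ids, rank].
sample_indices = tf.where(tf.not_equal(split, -1))
# The real ids themselves, gathered in row-major order.
embedding_indices = tf.gather_nd(split, sample_indices)
# sample_indices    -> [[0, 0], [0, 1], [1, 0]]
# embedding_indices -> [3, 7, 5]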
Example #22
0
  def __init__(self,
               cell_fn,
               cell_grad,
               theta,
               state0,
               inputs,
               extras,
               implicit_captures=None,
               unused_acc_state=None):
    """RNN helper class.

    Args:
      cell_fn: A python function, which computes:
         state1, extras = cell_fn(theta, state0, inputs[t, :])
      cell_grad: A python function which computes:
         dtheta, dstate0, dinputs[t, :] = cell_grad(
           theta, state0, inputs[t, :], extras, dstate1)
      theta: weights. A `.NestedMap`.
      state0: initial state. A `.NestedMap`.
      inputs: inputs. A `.NestedMap`.
      extras: A `.NestedMap` of Tensors. The 2nd return value of every
        invocation of cell_fn is a `.NestedMap` with matching keys and shapes
        of this 'extras'.
      implicit_captures: A `.NestedMap` corresponding to implicit captures of
        the cell_fn. If empty/None, implicit captures are either not present
        or disallowed.
      unused_acc_state: If None, we assume every field of acc_state is consumed
        in the following timesteps. If True, none of the acc_state is consumed,
        and we reduce_sum each timestep's new state into a scalar.
        Note, this feature should be used with StackedRecurrent where we send
        out the new state to the other devices.
    """
    self._theta = theta
    self._state = state0
    self._inputs = inputs
    self._cell_fn = cell_fn
    self._cell_grad = cell_grad
    self._extras = extras
    self._implicit_captures = implicit_captures
    self._unused_acc_state = unused_acc_state

    if self._implicit_captures is None:
      self._implicit_captures = _EmptyCaptures()

    # pylint: disable=unbalanced-tuple-unpacking

    # NOTE: TF Function (Fwd, Bak, ForwardLoopBody, BackwardLoopBody,
    # Forward and Backward defined below) simply takes a list of
    # Tensors and returns a list of Tensors. When we pass in a
    # structure (a list of NestedMap of Tensors), we use _Flatten to
    # convert the structure into a list of tensors. Conversely, the
    # following code often uses _Pack to formulate a structure from a
    # list of tensors based on a "template".

    # Wraps cell_fn in a TF Function:
    #    state1 = cell_fn(theta, state0, inputs)
    fwd_sig = [self._theta, self._state, self._inputs]

    compiled = py_utils.use_tpu()
    noinline = not compiled
    dev_t_type = tf.int32 if py_utils.use_tpu() else tf.int64

    @function.Defun(*_Dtypes(fwd_sig))
    def Fwd(*args):
      (theta, state0, inputs) = _Pack(args, fwd_sig)
      state1, extras = self._cell_fn(theta, state0, inputs)
      _AssertIsCompatible(state1, self._state)
      _AssertIsCompatible(extras, self._extras)
      return _Flatten([state1, extras])

    # Wraps cell_fn in a TF Function as a for-loop's body.
    #
    # The loop state is composed of:
    #  t: The loop variable. Timestep id.
    #  dev_t: The loop variable mirrored on the device.
    #  theta: the recurrent net's weights.
    #  state0: the previous recurrent state.
    #  inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
    #  acc_state: Each timestep's computed new state is also stashed into
    #    acc_state.
    #  acc_extras: Each timestep's computed extras is stashed into acc_extras.
    fwdloop_sig = [
        self._theta, self._state, self._inputs, self._state, self._extras
    ]

    @function.Defun(tf.int32, dev_t_type, *_Dtypes(fwdloop_sig))
    def ForwardLoopBody(*args):
      """The body of forward loop."""
      t, dev_t = args[0], args[1]
      (theta, state0, inputs, acc_state, acc_extras) = _Pack(
          args[2:], fwdloop_sig)
      inputs_t = _Index(inputs, t)  # external input at time step t.
      state1, extras = _Pack(
          Fwd(*_Flatten([theta, state0, inputs_t])),
          [self._state, self._extras])
      # Saves state1 and extras in their accumulators.
      if not self._unused_acc_state:
        acc_state = _Update(acc_state, state1, dev_t)
      acc_extras = _Update(acc_extras, extras, dev_t)

      return [tf.add(dev_t, 1)] + _Flatten(
          [theta, state1, inputs, acc_state, acc_extras])

    def Grad(op, *args):
      """The python grad function for the Forward function.

      Flowchart:
      +------------------------------------------------------------+
      |  Backward() DEFUN -> [d_fwd..., acc_extras, dcaptured]     |
      |                          |                                 |
      |                          v                                 |
      |                For(BackwardLoopBody())                     |
      |                          |                                 |
      |                          v                                 |
      |                BackwardLoopBody() DEFUN ->                 |
      |             ..., d_theta, d_state0, d_inputs,              |
      |                 d_acc_state, d_captured                    |
      |                          |                                 |
      |                          v                                 |
      |          Bak(..., inputs[t], extras[t]) DEFUN ->           |
      |       d_theta_t, d_state0, d_inputs_t, d_captured_t        |
      |                          |                                 |
      |                          v                                 |
      |      CellGrad(theta, state0, inputs, extras, d_state1) ->  |
      |               dtheta, dstate0, dinputs, dcaptured          |
      |                                                            |
      +------------------------------------------------------------+

      The key thing is that this function must return a dx value for each of
      the inputs to the Fwd function (theta, state0, inputs, captured...).
      The tricky part is that implicitly captured inputs are carried through
      function boundaries implicitly by the function call as the last
      arguments. When assembling gradients, we must account for these implicit
      captures even though they are not passed explicitly from function to
      function.

      Args:
        op: The forward operation.
        *args: Args to the forward operation (includes implicit captures).
      Returns:
        Tuple of derivatives.
      Raises:
        ValueError: on argument mismatch issues.
      """
      expected_num_inputs = 0
      for nmap in [
          self._theta,
          self._state,
          self._inputs,
          self._extras,
          # Implicit captured tensors always come last
          self._implicit_captures
      ]:
        expected_num_inputs += len(nmap.Flatten())
      if len(op.inputs) != expected_num_inputs:
        if len(op.inputs) > expected_num_inputs:
          raise ValueError(
              ('Too many inputs. The most likely cause is that cell_fn '
               'captures additional tensors: extra inputs %r vs captures %r') %
              (list(op.inputs), list(self._implicit_captures.Flatten())))
        raise ValueError(
            ('Mismatched inputs to cell fn: Found %d vs expected %d: %r'
             '. Implicit captures(%d) = %r') %
            (len(op.inputs), expected_num_inputs, list(op.inputs),
             len(self._implicit_captures.Flatten()), self._implicit_captures))

      # NOTE: tf.gradient backprops None for int32/int64 while zeros
      # for float32/float64. For consistency, we always backprop
      # zeros.
      args = list(args)
      for i, dy in enumerate(args):
        if dy is None:
          args[i] = tf.zeros_like(op.outputs[i])
      (theta, state0, inputs, _, unused_captured) = _Pack(
          [x for x in op.inputs],
          [
              self._theta,
              self._state,
              self._inputs,
              self._extras,
              # Implicit captured tensors always come last
              self._implicit_captures,
          ])
      # acc_state and acc_extras are computed by the Forward pass and
      # needed by the Backward pass.
      acc_state, _, acc_extras = _Pack([x for x in op.outputs],
                                       [self._state, self._state, self._extras])

      # Forward computes acc_state, the final state and
      # acc_extras. tf.gradients gives us their gradients w.r.t. the
      # final loss. Because acc_extras are not exposed by Compute(),
      # it has no gradients w.r.t. the final loss (i.e., by
      # construction, it must be zeros).
      d_acc_state, d_state1, _ = _Pack(args,
                                       [self._state, self._state, self._extras])

      if self._unused_acc_state:
        # XLA While op requires the same shape for the init and carry on values.
        state0 = state0.Transform(tf.reduce_sum)
        d_state1 = d_state1.Transform(tf.reduce_sum)

      return Backward(*_Flatten([
          theta,
          state0,
          inputs,
          acc_state,
          acc_extras,
          d_acc_state,
          d_state1,
      ]))

    # Forward calls ForwardLoopBody n times. Each time computes one
    # time step of the recurrent net.
    forward_sig = [self._theta, self._state, self._inputs, self._extras]

    @function.Defun(
        *_Dtypes(forward_sig), python_grad_func=Grad, noinline=noinline)
    def Forward(*args):
      """Forward pass of the recurrent net."""
      theta, state0, inputs, extras = _Pack(args, forward_sig)

      # The sequence length.
      pad_begin, pad_end = _SeqPaddingLength(inputs)
      slen_dim = _SeqLenDim(inputs)

      # Creates accumulators for state0 and extras.
      if self._unused_acc_state:
        acc_state = _EmptyWithFixShape([slen_dim], state0)
      else:
        acc_state = _EmptyAcc(slen_dim, state0)
      acc_extras = _EmptyAcc(slen_dim, extras)

      if py_utils.use_tpu():
        dev_t = tf.to_int32(pad_begin)
      else:
        dev_t = tf.to_int64(pad_begin)
      run = functional_ops.For(
          start=pad_begin,
          limit=slen_dim - pad_end,
          delta=1,
          inputs=[dev_t] + _Flatten(
              [theta, state0, inputs, acc_state, acc_extras]),
          body=ForwardLoopBody,
          rewrite_with_while=compiled)
      _, state1, _, acc_state, acc_extras = _Pack(
          run[1:],
          [self._theta, self._state, self._inputs, self._state, self._extras])

      return _Flatten([acc_state, state1, acc_extras])

    # The per-step backward computes:
    #    d_theta, d_state0, d_inputs = cell_grad(
    #        theta, state0, inputs, extras, d_state1)
    # where d_state1 is the backprop-ed gradient for state1, and
    # extras is computed by the forward step to facilitate the
    # backward step.
    bak_sig = [
        self._theta,
        self._state,
        self._inputs,
        self._extras,
        self._state,
    ]

    @function.Defun(*_Dtypes(bak_sig))
    def Bak(*args):
      """Backward step."""
      (theta, state0, inputs, extras, d_state1) = _Pack(args, bak_sig)
      (dtheta, dstate0, dinputs, dcaptures) = self._cell_grad(
          theta, state0, inputs, extras, d_state1)
      _AssertIsCompatible(dtheta, self._theta)
      _AssertIsCompatible(dstate0, self._state)
      _AssertIsCompatible(dinputs, self._inputs)
      if dcaptures is None:
        # NOTE: Custom gradient fns can return None if they do not support
        # captured tensors. The return value is reserved for the future when
        # that may be supported.
        dcaptures = _EmptyLike(self._implicit_captures)
      _AssertIsCompatible(dcaptures, self._implicit_captures)

      # Make sure this function didn't capture anything different than the
      # cell_fn when reflected on at the beginning. Must come after the call
      # to cell_grad() which adds to the captured list.
      _AssertSameTensors(function.get_extra_inputs(),
                         self._implicit_captures.Flatten())

      (captured,) = _Pack(function.get_extra_args(), [self._implicit_captures])
      return _Flatten(
          _ConvertNoneGradientToZeros([theta, state0, inputs, captured],
                                      [dtheta, dstate0, dinputs, dcaptures]))

    # Define Defuns used by functional_ops.If in BackwardLoopBody.
    state_if_sig = [self._state, self._state]

    @function.Defun(*_Dtypes(state_if_sig))
    def ReturnOrigState0(*args):
      """Returns original state0 from inputs."""
      (_, orig_state0) = _Pack(args, state_if_sig)
      return orig_state0.Flatten()

    @function.Defun(*_Dtypes(state_if_sig))
    def ReturnAccState(*args):
      """Returns acc_state[t-1] from inputs."""
      (acc_state, _) = _Pack(args, state_if_sig)
      return acc_state.Flatten()

    # Wraps cell_grad gradient function in a TF Function as a
    # for-loop's body for the Backward pass.
    #
    # The loop state is composed of:
    #  t: The loop variable. Timestep id.
    #  state0: the initial state for the entire backward loop.
    #  dev_t: The loop variable mirrored on the device.
    #  theta: the recurrent net's weights.
    #  inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
    #  acc_state: Each timestep's computed new state was stashed into
    #    acc_state by the Forward pass.
    #  acc_extras: Each timestep's computed extras was stashed into
    #    acc_extras by the Forward pass.
    #  d_theta: every timestep's gradient for theta is accumulated (added)
    #      into d_theta.
    #  d_state1: the backprop-ed gradient for the new state computed by
    #      timestep t.
    #  d_inputs: d_inputs[t, :] is populated by the backward time step t.
    #  d_acc_state: the backprop-ed gradient for acc_state.
    #  d_captured: every timestep's gradient for the implicitly captured
    #      tensors is accumulated (added) into d_captured.
    bakloop_sig = [
        self._theta,
        self._state,
        self._inputs,
        self._state,
        self._extras,
        # End of forward params
        self._theta,
        self._state,
        self._inputs,
        self._state,
        self._implicit_captures,
    ]

    @function.Defun(tf.int32, dev_t_type, *_Dtypes(bakloop_sig))
    def BackwardLoopBody(*args):
      """Backward loop body function."""
      t, dev_t = args[0], args[1]
      (
          theta,
          orig_state0,
          inputs,
          acc_state,
          acc_extras,
          # End of forward params
          d_theta,
          d_state1,
          d_inputs,
          d_acc_state,
          d_captured) = (
              _Pack(args[2:], bakloop_sig))

      # The input recurrent state for time step t is previous time step's
      # output, or the original state0 when on time step 0.
      state_from_acc = _Index(acc_state, tf.maximum(0, t - 1))
      state0 = functional_ops.If(
          tf.equal(t, tf.constant(0, tf.int32)),
          _Flatten([state_from_acc, orig_state0]), ReturnOrigState0,
          ReturnAccState)
      state0 = orig_state0.Pack(state0)

      # The external inputs for time step t.
      inputs_t = _Index(inputs, t)
      # The extras for time step t.
      extras_t = _Index(acc_extras, t)

      d_state1 = _Add(_Index(d_acc_state, t), d_state1)
      (d_theta_t, d_state0, d_inputs_t, d_captured_t) = _Pack(
          Bak(*_Flatten([theta, state0, inputs_t, extras_t, d_state1])),
          [self._theta, self._state, self._inputs, self._implicit_captures])

      if self._unused_acc_state:
        # XLA IF op requires the same shape for if and else branches.
        d_state0 = d_state0.Transform(tf.reduce_sum)
      d_theta = _Add(d_theta, d_theta_t)
      d_inputs = _Update(d_inputs, d_inputs_t, dev_t)
      d_captured = _Add(d_captured, d_captured_t)

      # Make sure this function didn't capture anything different than the
      # cell_fn when reflected on at the beginning. Must come after the call
      # to Bak() which adds to the captured list.
      _AssertSameTensors(function.get_extra_inputs(),
                         self._implicit_captures.Flatten())

      return [tf.subtract(dev_t, 1)] + _Flatten([
          theta,
          orig_state0,
          inputs,
          acc_state,
          acc_extras,
          # End of forward params
          d_theta,
          d_state0,
          d_inputs,
          d_acc_state,
          d_captured,
      ])

    # Backward calls BackwardLoopBody n times.  Each time computes the backprop
    # for one time step of the recurrent net.
    backward_sig = [
        self._theta,
        self._state,
        self._inputs,
        self._state,
        self._extras,
        # End of forward params.
        self._state,
        self._state,
    ]

    @function.Defun(*_Dtypes(backward_sig), noinline=noinline)
    def Backward(*args):
      """Backward pass for the recurrent net."""
      # theta, state0, inputs are Forward's inputs.
      # acc_state is the accumulated 1st output of Forward.
      # acc_extras is the accumulated 2nd output of Forward.
      # d_acc_state is the gradient for acc_state.
      # d_state1 is the gradient for the final state computed by Forward.
      (theta, state0, inputs, acc_state, acc_extras, d_acc_state,
       d_state1) = _Pack(args, backward_sig)

      # Accumulators for gradients.
      d_theta = _EmptyLike(theta)
      d_inputs = _EmptyLike(inputs)
      d_captured = _EmptyLike(self._implicit_captures)

      # The sequence length.
      pad_begin, pad_end = _SeqPaddingLength(inputs)
      start = _SeqLenDim(inputs) - pad_end - 1

      if py_utils.use_tpu():
        dev_t = tf.to_int32(start)
      else:
        dev_t = tf.to_int64(start)
      run = functional_ops.For(
          start=start,
          limit=pad_begin - 1,
          delta=-1,
          inputs=[dev_t] + _Flatten([
              theta,
              state0,
              inputs,
              acc_state,
              acc_extras,
              d_theta,
              d_state1,
              d_inputs,
              d_acc_state,
              d_captured,
          ]),
          body=BackwardLoopBody,
          rewrite_with_while=compiled)

      (theta, state0, inputs, acc_state, acc_extras, d_theta, d_state0,
       d_inputs, d_acc_state, d_captured) = _Pack(run[1:], bakloop_sig)

      # Make sure this function didn't capture anything different than the
      # cell_fn when reflected on at the beginning. Must come after the
      # call to BackwardLoopBody, which adds to the captured list.
      _AssertSameTensors(function.get_extra_inputs(),
                         self._implicit_captures.Flatten())

      if self._unused_acc_state:
        # Match the shape of gradient of the init_state.
        d_state0 = self._state.Transform(tf.zeros_like)
      return _Flatten([d_theta, d_state0, d_inputs, acc_extras, d_captured])

    self._forward = Forward
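
Note: for orientation, a toy cell obeying the cell_fn/cell_grad contract documented above might look like the following. The names and NestedMap fields are illustrative, not from the source:

def AccumulatorCell(theta, state0, inputs_t):
  """state1, extras = cell_fn(theta, state0, inputs[t, :])."""
  state1 = py_utils.NestedMap(acc=state0.acc + theta.w * inputs_t.x)
  extras = py_utils.NestedMap()  # nothing cached for the backward pass
  return state1, extras

def AccumulatorCellGrad(theta, state0, inputs_t, extras, dstate1):
  """dtheta, dstate0, dinputs[t, :], dcaptures = cell_grad(...)."""
  dtheta = py_utils.NestedMap(w=dstate1.acc * inputs_t.x)
  dstate0 = py_utils.NestedMap(acc=dstate1.acc)
  dinputs_t = py_utils.NestedMap(x=dstate1.acc * theta.w)
  return dtheta, dstate0, dinputs_t, None  # None: no implicit captures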
Example #23
0
    def ComputeMetrics(self, decoder_outs, input_batch, ids_to_strings_fn):
        """Computes metrics on output from decoder.

    Args:
      decoder_outs: A `BeamSearchDecodeOutput`, a namedtuple containing the
        decode results.
      input_batch:  A `NestedMap` of tensors representing the source, target,
        and other components of the input batch.
      ids_to_strings_fn: a function of (ids, lens) -> strings, where ids has
        shape [batch, length], lens has shape [batch], and strings has shape
        [batch].

    Returns:
      A dict of Tensors containing decoder output and metrics.
    """
        topk = self.GetTopK(decoder_outs, ids_to_strings_fn=ids_to_strings_fn)
        tgt_batch = tf.shape(topk.scores)[0]
        num_hyps_per_beam = tf.shape(topk.scores)[1]
        tgt = input_batch.tgt
        tgt_lens = tf.cast(tf.round(tf.reduce_sum(1.0 - tgt.paddings, 1)),
                           tf.int32)
        tgt_lens = py_utils.HasShape(tgt_lens, [tgt_batch])
        transcripts = ids_to_strings_fn(tgt.labels, tgt_lens - 1)

        # Filter out all isolated '<noise>' tokens.
        noise_pattern = ' <noise> |^<noise> | <noise>$|^<noise>$'
        filtered_refs = tf.strings.regex_replace(transcripts, noise_pattern,
                                                 ' ')
        filtered_hyps = tf.strings.regex_replace(topk.decoded, noise_pattern,
                                                 ' ')
        # Compute translation quality scores for all hyps.
        filtered_refs = tf.tile(tf.reshape(filtered_refs, [-1, 1]),
                                [1, num_hyps_per_beam])
        filtered_hyps = tf.reshape(filtered_hyps, [-1])
        filtered_refs = tf.reshape(filtered_refs, [-1])
        tf.logging.info('filtered_refs=%s', filtered_refs)
        norm_wer_errors, norm_wer_words = self.ComputeNormalizedWER(
            filtered_hyps, filtered_refs, num_hyps_per_beam)

        ret_dict = {
            'target_ids': tgt.ids,
            'target_labels': tgt.labels,
            'target_weights': tgt.weights,
            'target_paddings': tgt.paddings,
            'transcripts': transcripts,
            'topk_decoded': topk.decoded,
            'topk_ids': topk.ids,
            'topk_lens': topk.lens,
            'topk_scores': topk.scores,
            'norm_wer_errors': norm_wer_errors,
            'norm_wer_words': norm_wer_words,
        }

        if not py_utils.use_tpu() and 'sample_ids' in input_batch:
            ret_dict['utt_id'] = input_batch.sample_ids

        ret_dict.update(
            self.AddAdditionalDecoderMetricsToGraph(topk, filtered_hyps,
                                                    filtered_refs, input_batch,
                                                    decoder_outs))
        return ret_dict
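
Note: the isolated-'<noise>' filtering above can be checked in isolation; a small sketch with toy strings:

import tensorflow as tf

noise_pattern = ' <noise> |^<noise> | <noise>$|^<noise>$'
refs = tf.constant([
    '<noise> hello world',  # leading token  -> ' hello world'
    'hello <noise> world',  # isolated token -> 'hello world'
    'hello world <noise>',  # trailing token -> 'hello world '
    '<noise>',              # lone token     -> ' '
])
filtered = tf.strings.regex_replace(refs, noise_pattern, ' ')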
Example #24
0
    def PostProcessDecodeOut(self, dec_out_dict, dec_metrics_dict):
        p = self.params
        assert 'topk_scores' in dec_out_dict, dec_out_dict.keys()
        topk_scores = dec_out_dict['topk_scores']
        topk_decoded = dec_out_dict['topk_decoded']
        transcripts = dec_out_dict['transcripts']
        if not py_utils.use_tpu():
            utt_id = dec_out_dict['utt_id']
            assert len(utt_id) == len(transcripts)
        norm_wer_errors = dec_out_dict['norm_wer_errors']
        norm_wer_words = dec_out_dict['norm_wer_words']
        target_labels = dec_out_dict['target_labels']
        target_paddings = dec_out_dict['target_paddings']
        topk_ids = dec_out_dict['topk_ids']
        topk_lens = dec_out_dict['topk_lens']
        assert len(transcripts) == len(target_labels)
        assert len(transcripts) == len(target_paddings)
        assert len(transcripts) == len(topk_decoded)
        assert (len(topk_ids) == p.decoder.beam_search.num_hyps_per_beam *
                len(transcripts))
        assert len(norm_wer_errors) == len(transcripts)
        assert len(norm_wer_words) == len(transcripts)

        dec_metrics_dict['num_samples_in_batch'].Update(len(transcripts))

        def GetRefIds(ref_ids, ref_paddings):
            assert len(ref_ids) == len(ref_paddings)
            return_ids = []
            for i in range(len(ref_ids)):
                if ref_paddings[i] == 0:
                    return_ids.append(ref_ids[i])
            return return_ids

        total_errs = 0
        total_oracle_errs = 0
        total_ref_words = 0
        total_token_errs = 0
        total_ref_tokens = 0
        total_norm_wer_errs = 0
        total_norm_wer_words = 0
        total_accurate_sentences = 0
        key_value_pairs = []
        for i in range(len(transcripts)):
            ref_str = transcripts[i]
            if not py_utils.use_tpu():
                tf.logging.info('utt_id: %s', utt_id[i])
            tf.logging.info('  ref_str: %s', ref_str)
            hyps = topk_decoded[i]
            ref_ids = GetRefIds(target_labels[i], target_paddings[i])
            hyp_index = i * p.decoder.beam_search.num_hyps_per_beam
            top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
            total_ref_tokens += len(ref_ids)
            _, _, _, token_errs = decoder_utils.EditDistanceInIds(
                ref_ids, top_hyp_ids)
            total_token_errs += token_errs

            assert p.decoder.beam_search.num_hyps_per_beam == len(hyps)
            filtered_ref = decoder_utils.FilterNoise(ref_str)
            filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
            oracle_errs = norm_wer_errors[i][0]
            for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
                tf.logging.info('  %f: %s', score, hyp_str)
                filtered_hyp = decoder_utils.FilterNoise(hyp_str)
                filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
                ins, subs, dels, errs = decoder_utils.EditDistance(
                    filtered_ref, filtered_hyp)
                # Note that these numbers are not consistent with what is used to
                # compute normalized WER.  In particular, these numbers will be inflated
                # when the transcript contains punctuation.
                tf.logging.info('  ins: %d, subs: %d, del: %d, total: %d', ins,
                                subs, dels, errs)
                hyp_norm_wer_errors = norm_wer_errors[i][n]
                hyp_norm_wer_words = norm_wer_words[i][n]
                # Only aggregate scores of the top hypothesis.
                if n == 0:
                    total_errs += errs
                    total_ref_words += len(
                        decoder_utils.Tokenize(filtered_ref))
                    total_norm_wer_errs += hyp_norm_wer_errors
                    if hyp_norm_wer_errors == 0:
                        total_accurate_sentences += 1
                    total_norm_wer_words += hyp_norm_wer_words
                    dec_metrics_dict['corpus_bleu'].Update(
                        filtered_ref, filtered_hyp)
                if hyp_norm_wer_errors < oracle_errs:
                    oracle_errs = hyp_norm_wer_errors
            total_oracle_errs += oracle_errs

        dec_metrics_dict['wer'].Update(total_errs / total_ref_words,
                                       total_ref_words)
        dec_metrics_dict['oracle_norm_wer'].Update(
            total_oracle_errs / total_ref_words, total_ref_words)
        dec_metrics_dict['sacc'].Update(
            total_accurate_sentences / len(transcripts), len(transcripts))
        dec_metrics_dict['norm_wer'].Update(
            total_norm_wer_errs / total_norm_wer_words, total_norm_wer_words)
        dec_metrics_dict['ter'].Update(total_token_errs / total_ref_tokens,
                                       total_ref_tokens)

        # Update any additional metrics.
        dec_metrics_dict = self.UpdateAdditionalMetrics(
            dec_out_dict, dec_metrics_dict)
        return key_value_pairs
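
Note: the oracle-WER accumulation above effectively takes, per utterance, the fewest normalized errors over all hypotheses in the beam. A toy NumPy sketch of that reduction:

import numpy as np

# norm_wer_errors: [batch, num_hyps_per_beam] (toy values).
norm_wer_errors = np.array([[3, 1, 2],
                            [0, 2, 1]])
oracle_errs_per_utt = norm_wer_errors.min(axis=1)  # -> [1, 0]
total_oracle_errs = oracle_errs_per_utt.sum()      # -> 1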
Example #25
0
    def FProp(self, theta, batch, state0=None):
        """Encodes source as represented by 'inputs' and 'paddings'.

    Args:
      theta: A NestedMap object containing weights' values of this
        layer and its children layers.
      batch: A NestedMap with fields:

        - src_inputs - The inputs tensor. It is expected to be of shape [batch,
          time, feature_dim, channels].
        - paddings - The paddings tensor. It is expected to be of shape [batch,
          time].
      state0: Recurrent input state. Not supported/ignored by this encoder.

    Returns:
      A NestedMap containing

      - 'encoded': a feature tensor of shape [time, batch, depth]
      - 'padding': a 0/1 tensor of shape [time, batch]
      - 'state': the updated recurrent state
      - '${layer_type}_${layer_index}': The per-layer encoder output. Each one
        is a NestedMap containing 'encoded' and 'padding' similar to regular
        final outputs, except that 'encoded' from conv or conv_lstm layers are
        of shape [time, batch, depth, channels].
    """
        p = self.params
        inputs, paddings = batch.src_inputs, batch.paddings
        outputs = py_utils.NestedMap()
        with tf.name_scope(p.name):
            # Apply SpecAugment (training only).
            if p.use_specaugment and not self.do_eval:
                inputs, paddings = self.specaugment.FProp(
                    theta.specaugment, inputs, paddings)
            # Add a few extra padded timesteps at the end. This ensures the
            # correctness of the conv layers at the sequence edges.
            if p.pad_steps > 0:
                # inplace_update() is not supported on TPU for now. Since padding is
                # already done in the input_generator, this additional padding could be avoided.
                assert not py_utils.use_tpu()
                inputs_pad = tf.zeros(
                    inplace_ops.inplace_update(tf.shape(inputs), 1,
                                               p.pad_steps), inputs.dtype)
                paddings_pad = tf.ones(
                    inplace_ops.inplace_update(tf.shape(paddings), 1,
                                               p.pad_steps), paddings.dtype)
                inputs = tf.concat([inputs, inputs_pad], 1, name='inputs')
                paddings = tf.concat([paddings, paddings_pad], 1)

            plots = [
                summary_utils.PrepareSequenceForPlot(
                    tf.transpose(inputs, [0, 1, 3, 2]), paddings, 'inputs')
            ]

            conv_out = inputs
            out_padding = paddings
            for i, conv_layer in enumerate(self.conv):
                conv_out, out_padding = conv_layer.FProp(
                    theta.conv[i], conv_out, out_padding)
                if p.extra_per_layer_outputs:
                    conv_out *= (1.0 -
                                 out_padding[:, :, tf.newaxis, tf.newaxis])
                    outputs['conv_%d' % i] = py_utils.NestedMap(
                        encoded=tf.transpose(conv_out,
                                             [1, 0, 2, 3]),  # to [t, b, d, c]
                        padding=tf.transpose(out_padding))
                plots.append(
                    summary_utils.PrepareSequenceForPlot(
                        tf.transpose(conv_out, [0, 1, 3, 2]), out_padding,
                        'conv_%d_out' % i))

            def TransposeFirstTwoDims(t):
                first_dim = tf.shape(t)[0]
                second_dim = tf.shape(t)[1]
                t_new = tf.transpose(
                    tf.reshape(t, [first_dim, second_dim, -1]), [1, 0, 2])
                t_shape_new = tf.concat([[second_dim], [first_dim],
                                         tf.shape(t)[2:]], 0)
                return tf.reshape(t_new, t_shape_new)

            # Now the conv-lstm part.
            conv_lstm_out = conv_out
            conv_lstm_out_padding = out_padding
            for i, (rnn, cnn) in enumerate(
                    zip(self.conv_lstm_rnn, self.conv_lstm_cnn)):
                conv_lstm_in = conv_lstm_out
                # Move time dimension to be the first.
                conv_lstm_in = TransposeFirstTwoDims(conv_lstm_in)
                conv_lstm_in = tf.expand_dims(conv_lstm_in, 2)
                conv_lstm_in_padding = tf.expand_dims(
                    tf.transpose(conv_lstm_out_padding), 2)
                lstm_out = rnn.FProp(theta.conv_lstm_rnn[i], conv_lstm_in,
                                     conv_lstm_in_padding)
                # Move time dimension to be the second.
                cnn_in = TransposeFirstTwoDims(lstm_out)
                cnn_in = tf.squeeze(cnn_in, 2)
                cnn_in_padding = conv_lstm_out_padding
                cnn_out, cnn_out_padding = cnn.FProp(theta.conv_lstm_cnn[i],
                                                     cnn_in, cnn_in_padding)
                conv_lstm_out, conv_lstm_out_padding = cnn_out, cnn_out_padding
                if p.extra_per_layer_outputs:
                    conv_lstm_out *= (
                        1.0 -
                        conv_lstm_out_padding[:, :, tf.newaxis, tf.newaxis])
                    outputs['conv_lstm_%d' % i] = py_utils.NestedMap(
                        encoded=tf.transpose(conv_lstm_out,
                                             [1, 0, 2, 3]),  # to [t, b, d, c]
                        padding=tf.transpose(conv_lstm_out_padding))
                plots.append(
                    summary_utils.PrepareSequenceForPlot(
                        conv_lstm_out, conv_lstm_out_padding,
                        'conv_lstm_%d_out' % i))

            # Need to do a reshape before starting the rnn layers.
            conv_lstm_out = py_utils.HasRank(conv_lstm_out, 4)
            conv_lstm_out_shape = tf.shape(conv_lstm_out)
            new_shape = tf.concat([conv_lstm_out_shape[:2], [-1]], 0)
            conv_lstm_out = tf.reshape(conv_lstm_out, new_shape)
            if self._first_lstm_input_dim_pad:
                conv_lstm_out = tf.pad(
                    conv_lstm_out,
                    [[0, 0], [0, 0], [0, self._first_lstm_input_dim_pad]])

            conv_lstm_out = py_utils.HasShape(
                conv_lstm_out, [-1, -1, self._first_lstm_input_dim])

            # Transpose to move the time dimension to be the first.
            rnn_in = tf.transpose(conv_lstm_out, [1, 0, 2])
            rnn_padding = tf.expand_dims(tf.transpose(conv_lstm_out_padding),
                                         2)
            # rnn_in is of shape [time, batch, depth]
            # rnn_padding is of shape [time, batch, 1]

            # Now the rnn layers.
            num_skips = 0
            for i in range(p.num_lstm_layers):
                rnn_out = self.rnn[i].FProp(theta.rnn[i], rnn_in, rnn_padding)
                residual_index = i - p.residual_start + 1
                if p.residual_start > 0 and residual_index >= 0:
                    if residual_index % p.residual_stride == 0:
                        residual_in = rnn_in
                    if residual_index % p.residual_stride == p.residual_stride - 1:
                        # Highway skip connection.
                        if p.highway_skip:
                            rnn_out = self.highway_skip[num_skips].FProp(
                                theta.highway_skip[num_skips], residual_in,
                                rnn_out)
                            num_skips += 1
                        else:
                            # Residual skip connection.
                            rnn_out += py_utils.HasShape(
                                residual_in, tf.shape(rnn_out))
                if p.project_lstm_output and (i < p.num_lstm_layers - 1):
                    # Projection layers.
                    rnn_out = self.proj[i].FProp(theta.proj[i], rnn_out,
                                                 rnn_padding)
                if i == p.num_lstm_layers - 1:
                    rnn_out *= (1.0 - rnn_padding)
                if p.extra_per_layer_outputs:
                    rnn_out *= (1.0 - rnn_padding)
                    outputs['rnn_%d' % i] = py_utils.NestedMap(
                        encoded=rnn_out, padding=tf.squeeze(rnn_padding, [2]))
                # Stacking layer connection.
                if p.layer_index_before_stacking == i:
                    # Stacking layer expects input tensor shape as [batch, time, feature].
                    # So transpose the tensors before and after the layer.
                    rnn_out, rnn_padding = self.stacking.FProp(
                        tf.transpose(rnn_out, [1, 0, 2]),
                        tf.transpose(rnn_padding, [1, 0, 2]))
                    rnn_out = tf.transpose(rnn_out, [1, 0, 2])
                    rnn_padding = tf.transpose(rnn_padding, [1, 0, 2])

                plots.append(
                    summary_utils.PrepareSequenceForPlot(
                        tf.transpose(rnn_out, [1, 0, 2]),
                        tf.transpose(rnn_padding, [1, 0, 2]),
                        'rnn_%d_out' % i))
                rnn_in = rnn_out
            final_out = rnn_in

            summary_utils.PlotSequenceFeatures(list(reversed(plots)),
                                               'encoder_example',
                                               xlabel='Time')

            outputs['encoded'] = final_out
            outputs['padding'] = tf.squeeze(rnn_padding, [2])
            outputs['state'] = py_utils.NestedMap()
            return outputs
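
Note: the residual/highway bookkeeping in the LSTM stack above opens a block when residual_index % residual_stride == 0 and closes it (adding the skip) when residual_index % residual_stride == residual_stride - 1. A toy trace, assuming residual_start=1 and residual_stride=2 (illustrative config, not from the source):

residual_start, residual_stride = 1, 2
for i in range(4):
    residual_index = i - residual_start + 1
    opens = residual_index % residual_stride == 0
    closes = residual_index % residual_stride == residual_stride - 1
    print(i, residual_index, opens, closes)
# 0 0 True False   <- stash rnn_in as residual_in
# 1 1 False True   <- add the skip (or highway) connection
# 2 2 True False
# 3 3 False True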
Example #26
0
    def PostProcess(self, dec_out_dict, dec_metrics_dict):
        p = self.params
        assert 'topk_scores' in dec_out_dict, list(dec_out_dict.keys())
        topk_scores = dec_out_dict['topk_scores']
        topk_decoded = dec_out_dict['topk_decoded']
        transcripts = dec_out_dict['transcripts']
        if not py_utils.use_tpu():
            utt_id = dec_out_dict['utt_id']
            assert len(utt_id) == len(transcripts)
        norm_wer_errors = dec_out_dict['norm_wer_errors']
        norm_wer_words = dec_out_dict['norm_wer_words']
        target_labels = dec_out_dict['target_labels']
        target_paddings = dec_out_dict['target_paddings']
        topk_ids = dec_out_dict['topk_ids']
        topk_lens = dec_out_dict['topk_lens']
        if 'example_weights' in dec_out_dict:
            example_weights = dec_out_dict['example_weights']
        else:
            example_weights = np.ones([len(transcripts)], np.float32)
        assert len(transcripts) == len(target_labels)
        assert len(transcripts) == len(target_paddings)
        assert len(transcripts) == len(topk_decoded)
        assert len(norm_wer_errors) == len(transcripts)
        assert len(norm_wer_words) == len(transcripts)

        num_samples_in_batch = example_weights.sum()
        dec_metrics_dict['num_samples_in_batch'].Update(num_samples_in_batch)

        def GetRefIds(ref_ids, ref_paddings):
            assert len(ref_ids) == len(ref_paddings)
            return_ids = []
            for i in range(len(ref_ids)):
                if ref_paddings[i] == 0:
                    return_ids.append(ref_ids[i])
            return return_ids

        total_norm_wer_errs = (norm_wer_errors[:, 0] * example_weights).sum()
        total_norm_wer_words = (norm_wer_words[:, 0] * example_weights).sum()

        dec_metrics_dict['norm_wer'].Update(
            total_norm_wer_errs / total_norm_wer_words, total_norm_wer_words)

        filtered_transcripts = []
        filtered_top_hyps = []
        for ref_str, hyps in zip(transcripts, topk_decoded):
            filtered_ref = decoder_utils.FilterNoise(ref_str)
            filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
            filtered_transcripts.append(filtered_ref)
            filtered_hyp = decoder_utils.FilterNoise(hyps[0])
            filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
            filtered_top_hyps.append(filtered_hyp)
            dec_metrics_dict['corpus_bleu'].Update(filtered_ref, filtered_hyp)

        total_errs = 0
        total_oracle_errs = 0
        total_ref_words = 0
        total_token_errs = 0
        total_ref_tokens = 0
        total_accurate_sentences = 0
        key_value_pairs = []

        if p.include_auxiliary_metrics:
            for i in range(len(transcripts)):
                ref_str = transcripts[i]
                if not py_utils.use_tpu():
                    tf.logging.info('utt_id: %s', utt_id[i])
                if self.cluster.add_summary:
                    tf.logging.info(
                        '  ref_str: %s',
                        ref_str.decode('utf-8') if p.log_utf8 else ref_str)
                hyps = topk_decoded[i]
                num_hyps_per_beam = len(hyps)
                ref_ids = GetRefIds(target_labels[i], target_paddings[i])
                hyp_index = i * num_hyps_per_beam
                top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
                if self.cluster.add_summary:
                    tf.logging.info('  ref_ids: %s', ref_ids)
                    tf.logging.info('  top_hyp_ids: %s', top_hyp_ids)
                total_ref_tokens += len(ref_ids)
                _, _, _, token_errs = decoder_utils.EditDistanceInIds(
                    ref_ids, top_hyp_ids)
                total_token_errs += token_errs

                filtered_ref = filtered_transcripts[i]
                oracle_errs = norm_wer_errors[i][0]
                for n, (score, hyp_str) in enumerate(zip(topk_scores[i],
                                                         hyps)):
                    oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
                    if self.cluster.add_summary:
                        tf.logging.info(
                            '  %f: %s', score,
                            hyp_str.decode('utf-8') if p.log_utf8 else hyp_str)
                    # Only aggregate scores of the top hypothesis.
                    if n != 0:
                        continue
                    filtered_hyp = filtered_top_hyps[i]
                    _, _, _, errs = decoder_utils.EditDistance(
                        filtered_ref, filtered_hyp)
                    total_errs += errs
                    total_ref_words += len(
                        decoder_utils.Tokenize(filtered_ref))
                    if norm_wer_errors[i, n] == 0:
                        total_accurate_sentences += 1

                total_oracle_errs += oracle_errs

            dec_metrics_dict['wer'].Update(
                total_errs / max(1., total_ref_words), total_ref_words)
            dec_metrics_dict['oracle_norm_wer'].Update(
                total_oracle_errs / max(1., total_ref_words), total_ref_words)
            dec_metrics_dict['sacc'].Update(
                total_accurate_sentences / len(transcripts), len(transcripts))
            dec_metrics_dict['ter'].Update(
                total_token_errs / max(1., total_ref_tokens), total_ref_tokens)

        return key_value_pairs
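
Note: relative to Example #24, the aggregation above weights each utterance by example_weights (e.g. zero for padded-out examples), and only the top hypothesis (column 0) contributes. A toy NumPy sketch:

import numpy as np

norm_wer_errors = np.array([[2., 1.], [1., 0.], [5., 4.]])
norm_wer_words = np.array([[10., 10.], [8., 8.], [9., 9.]])
example_weights = np.array([1., 1., 0.])  # last example padded out

total_errs = (norm_wer_errors[:, 0] * example_weights).sum()  # -> 3.0
total_words = (norm_wer_words[:, 0] * example_weights).sum()  # -> 18.0
norm_wer = total_errs / total_words                           # ~ 0.167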
Example #27
0
  def Task(self):
    p = feature_neighborhood_model_trans.FeatureNeighborhoodModelTrans.Params()
    if self._share_embeddings:
      output_symbol_path = FLAGS.input_symbols
    else:
      output_symbol_path = FLAGS.output_symbols
    _, p.input_symbols, p.output_symbols = (
        fn.FeatureNeighborhoodInput.ParameterizedConfigs(
            input_symbol_path=FLAGS.input_symbols,
            output_symbol_path=output_symbol_path,
            append_eos=FLAGS.append_eos,
            max_spelling_len=FLAGS.max_spelling_len,
            max_pronunciation_len=FLAGS.max_pronunciation_len,
            max_neighbors=FLAGS.max_neighbors))
    p.input_vocab_size = p.input_symbols.num_symbols()
    p.output_vocab_size = p.output_symbols.num_symbols()
    p.max_neighbors = FLAGS.max_neighbors
    p.max_pronunciation_len = FLAGS.max_pronunciation_len
    p.max_spelling_len = FLAGS.max_spelling_len
    p.start = p.output_symbols.find("<s>")
    p.share_embeddings = self._share_embeddings

    if self._share_embeddings:
      vocab_size = p.input_vocab_size
    else:
      vocab_size = p.output_vocab_size

    p = base_config.SetupTransformerParams(
        p,
        name="feature_neighborhood_with_neighbors",
        vocab_size=vocab_size,
        model_dim=p.embedding_dim,
        hidden_dim=p.enc_units,
        num_heads=self._num_heads,
        num_layers=self._num_layers,
        learning_rate=3.0,
        warmup_steps=40000,
        residual_dropout_prob=self._residual_dropout_prob,
        relu_dropout_prob=self._relu_dropout_prob,
        input_dropout_prob=self._input_dropout_prob,
        atten_dropout_prob=self._atten_dropout_prob,
        label_smoothing_uncertainty=self._label_smoothing_uncertainty)
    if not self._share_embeddings:
      p.encoder.token_emb.vocab_size = p.input_vocab_size
    p.eval.samples_per_summary = 20000
    # TODO(llion): Might need to change the output vocab size to one that can
    # be sharded to run efficiently on TPUs.
    p.decoder.softmax.num_shards = 1
    p.decoder.target_seq_len = p.max_pronunciation_len

    if py_utils.use_tpu():
      p.decoder.beam_search = model_helper.ChangeToBeamSearchTpuHelper(
          p.decoder.beam_search)

    if FLAGS.neigh_use_tpu:
      for pp in [p.encoder, p.decoder]:
        pp.token_emb = model_helper.ChangeToSimpleEmbedding(pp.token_emb)
      p.decoder.softmax = model_helper.ChangeToSimpleSoftmax(p.decoder.softmax)

    p.use_neighbors = self._use_neighbors
    if self._use_neighbors:
      p.spell_encoder = base_config.SetupTransformerEncoder(
          vocab_size=p.input_vocab_size,
          model_dim=p.embedding_dim,
          hidden_dim=p.enc_units,
          num_heads=self._num_heads,
          num_layers=self._num_layers,
          residual_dropout_prob=self._residual_dropout_prob,
          relu_dropout_prob=self._relu_dropout_prob,
          input_dropout_prob=self._input_dropout_prob,
          atten_dropout_prob=self._atten_dropout_prob)
      if self._attention_type != "CONCATAVE":
        p.pron_encoder = base_config.SetupTransformerEncoder(
            vocab_size=p.output_vocab_size,
            model_dim=p.embedding_dim,
            hidden_dim=p.enc_units,
            num_heads=self._num_heads,
            num_layers=self._num_layers,
            residual_dropout_prob=self._residual_dropout_prob,
            relu_dropout_prob=self._relu_dropout_prob,
            input_dropout_prob=self._input_dropout_prob,
            atten_dropout_prob=self._atten_dropout_prob)
      else:
        if not self._share_embeddings:
          raise ValueError("Must share embeddings to concat spelling and pron.")
      if FLAGS.neigh_use_tpu:
        for pp in [p.spell_encoder, p.pron_encoder]:
          if pp:
            pp.token_emb = model_helper.ChangeToSimpleEmbedding(pp.token_emb)

    p.also_shuffle_neighbors = self._also_shuffle_neighbors
    if self._use_neigh_id_emb:
      assert self._use_neighbors
      p.use_neigh_id_emb = True
      if self._attention_type == "CONCAT":
        neigh_id_emb = layers.EmbeddingLayer.Params().Set(
            vocab_size=FLAGS.max_neighbors + 1,  # +1 to include the main input
            embedding_dim=p.embedding_dim,
            max_num_shards=1,
            params_init=py_utils.WeightInit.Gaussian(
                1.0 / maths.sqrt(p.embedding_dim)),
            scale_sqrt_depth=True)
        p.encoder.task_emb = neigh_id_emb
      elif self._attention_type == "AVERAGE":
        neigh_id_emb = layers.EmbeddingLayer.Params().Set(
            vocab_size=FLAGS.max_neighbors,
            embedding_dim=p.embedding_dim,
            max_num_shards=1,
            params_init=py_utils.WeightInit.Gaussian(
                1.0 / maths.sqrt(p.embedding_dim)),
            scale_sqrt_depth=True)
        p.spell_encoder.task_emb = neigh_id_emb
        p.pron_encoder.task_emb = neigh_id_emb

    p.neigh_att_type = self._attention_type
    p.aux_dropout_prob = self._aux_dropout_prob

    return p
Example #28
0
    def GetProjectLastDim(cls, inputs, weight, input_dim, output_dim,
                          proj_obj):
        """Linear projection on the last dim of the input tensor along with pruning.

    This is a TPU efficient implementation to avoid reshaping inputs to Rank-2
    tensor by using Einsum for the compute.

    Args:
      inputs: An input Tensor, the last dimension of which is input_dim.
      weight: A weight matrix with shape [input_dim, output_dim].
      input_dim: An integer or a symbolic dim, the last dimension of the inputs.
      output_dim: An integer or a symbolic dim, the last dimension of the
                  outputs.
      proj_obj: a ProjectionLayer object.

    Returns:
      An output Tensor of the same rank as inputs, the last dimension is
      output_dim.
    """
        theta = proj_obj.theta
        p = proj_obj.params
        input_dim = int(
            symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim)
            else input_dim)
        output_dim = int(
            symbolic.ToStatic(output_dim) if symbolic.IsExpr(output_dim)
            else output_dim)
        if (py_utils.use_tpu() and inputs.shape is not None
                and inputs.shape.rank is not None and inputs.shape.rank < 26):
            # Avoids reshape if feasible and uses Einsum.
            if inputs.shape.rank == 2:
                outputs = tf.matmul(inputs, weight)
            else:
                outputs = cls.GetEinSumResult(inputs, proj_obj)
        else:
            if (p.pruning_hparams_dict['compression_option'] == 9
                    and p.pruning_hparams_dict['compress_input']):
                blocked_inputs = tf.reshape(
                    inputs,
                    py_utils.ToStaticShape(
                        [-1, p.pruning_hparams_dict['input_block_size']]))
                compressed_inputs = tf.reshape(
                    py_utils.Matmul(blocked_inputs, theta.b_matrix_tfvar),
                    py_utils.ToStaticShape([
                        -1, input_dim //
                        p.pruning_hparams_dict['input_compression_factor']
                    ]))
            else:
                compressed_inputs = tf.reshape(
                    inputs, py_utils.ToStaticShape([-1, input_dim]))

            if p.pruning_hparams_dict['compression_option'] == 10:
                if p.pruning_hparams_dict['block_method'] == 'mask':
                    intermediate_result = py_utils.Matmul(
                        compressed_inputs,
                        tf.multiply(theta.c_matrix_tfvar, theta.c_mask_tfvar))
                elif p.pruning_hparams_dict['block_method'] == 'loop':
                    num_blocks = p.pruning_hparams_dict[
                        'block_compression_factor']
                    input_splitted = tf.split(compressed_inputs,
                                              num_blocks,
                                              axis=-1)
                    output_splitted = []
                    for i, input_i in enumerate(input_splitted):
                        output_splitted.append(
                            py_utils.Matmul(input_i,
                                            theta.c_matrix_tfvar[i, :, :]))
                    intermediate_result = tf.concat(output_splitted, axis=-1)
            else:
                intermediate_result = py_utils.Matmul(compressed_inputs,
                                                      theta.c_matrix_tfvar)

            if (p.pruning_hparams_dict['compression_option'] == 9
                    and p.pruning_hparams_dict['compress_output']):
                blocked_intermediate_result = tf.reshape(
                    intermediate_result,
                    py_utils.ToStaticShape([
                        -1, p.pruning_hparams_dict['output_block_size'] //
                        p.pruning_hparams_dict['output_compression_factor']
                    ]))
                outputs = py_utils.Matmul(blocked_intermediate_result,
                                          theta.d_matrix_tfvar)
            else:
                outputs = intermediate_result

            outputs = tf.reshape(
                outputs,
                tf.concat([
                    tf.cast(py_utils.GetShape(inputs)[:-1], tf.int32),
                    py_utils.ToStaticShape([output_dim])
                ],
                          axis=0))

        return outputs
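
Note: the 'loop' block-compression branch above amounts to a block-diagonal projection done as per-block matmuls. A standalone sketch with toy shapes (names illustrative, not from the source):

import tensorflow as tf

num_blocks = 2
inputs = tf.random.normal([4, 8])                # [batch, input_dim]
c_matrix = tf.random.normal([num_blocks, 4, 3])  # [blocks, in/blk, out/blk]

input_split = tf.split(inputs, num_blocks, axis=-1)
output_split = [tf.matmul(x, c_matrix[i]) for i, x in enumerate(input_split)]
outputs = tf.concat(output_split, axis=-1)       # -> shape [4, 6]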
Example #29
0
 def _DecoderDevice(self):
   """Returns the device to run the decoder computation."""
   if py_utils.use_tpu():
     return tf.device(self.cluster.WorkerDeviceInModelSplit(1))
   else:
     return tf.device('')
Example #30
0
 def __init__(self, params):
     params.pad_to_max_seq_length = True
      params.fixed_input_shape = params.fixed_input_shape or py_utils.use_tpu()
     super().__init__(params)