Example #1
 def _BeamSearchDecode(self, input_batch):
     p = self.params
     with tf.name_scope('fprop'), tf.name_scope(p.name):
         encoder_outputs = self.enc.FPropDefaultTheta(input_batch.src)
         encoder_outputs = self.dec.AddExtraDecodingInfo(
             encoder_outputs, input_batch.tgt)
         decoder_outs = self.dec.BeamSearchDecode(encoder_outputs)
         return self._ProcessBeamSearchDecodeOut(input_batch,
                                                 encoder_outputs,
                                                 decoder_outs)
Example #2
 def FProp(self, theta, *args):
     r"""Applies lambda(x, \*kwargs) for every non-None arg."""
     del theta
     p = self.params
     with tf.name_scope(p.name):
         ret = [None if x is None else p.fn(x, **p.kwargs) for x in args]
         return tuple(ret) if len(ret) > 1 else ret[0]
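The pattern above maps p.fn over every non-None positional argument and unwraps single results. A minimal standalone sketch of the same idea (apply_fn_to_args is a hypothetical helper, not part of the library):

import tensorflow as tf

def apply_fn_to_args(fn, *args, **kwargs):
    # Apply fn to every non-None arg; return a single value if only one arg.
    ret = [None if x is None else fn(x, **kwargs) for x in args]
    return tuple(ret) if len(ret) > 1 else ret[0]

doubled, _ = apply_fn_to_args(lambda x, scale: x * scale,
                              tf.constant([1.0, 2.0]), None, scale=2.0)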
Example #3
    def FProp(self, theta, *args):
        """FProp through multiple devices in the split.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: A tuple of Tensors (one or more). Every tensor's first dimension is
        the same (the batch dimension).

    Returns:
      The sub layer's output.
    """
        p = self.params
        with tf.name_scope(p.name):
            assert all(isinstance(x, tf.Tensor) for x in args)
            cluster = self.cluster
            num = cluster.num_devices_per_split
            if num == 1:
                return self.sub.FProp(theta.sub, *args)
            inps = py_utils.SplitRecursively(list(args), num, axis=0)
            outs = []
            for i, xs in enumerate(inps):
                device = cluster.WorkerDeviceInModelSplit(i)
                tf.logging.info('%d on device %s', i, device)
                with tf.device(device):
                    ys = self.sub.FProp(theta.sub, *xs)
                    if isinstance(ys, tuple):
                        outs += [list(ys)]
                    else:
                        outs += [ys]  # ys is a single tensor
            ret = py_utils.ConcatRecursively(outs, axis=0)
            if isinstance(ret, list):
                return tuple(ret)
            else:
                return ret  # ret is a single tensor
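SplitRecursively/ConcatRecursively shard every argument along the batch dimension. A minimal sketch of the same split-compute-concat pattern with plain TF ops, assuming a single tensor input and a toy per-shard function:

import tensorflow as tf

def split_fprop(x, num_shards, shard_fn):
    shards = tf.split(x, num_shards, axis=0)   # split the batch dimension
    outs = [shard_fn(s) for s in shards]       # each shard may run on its own device
    return tf.concat(outs, axis=0)             # reassemble along the batch dimension

y = split_fprop(tf.ones([8, 3]), num_shards=2, shard_fn=lambda s: s * 2.0)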
Example #4
 def FProp(self, theta, *args):
     p = self.params
     with tf.name_scope(p.name):
         args = _ToTuple(self.body.FProp(theta.body, *args))
         for fetch in p.fetches:
             args += (self.body.GetDescendant(fetch).activation, )
         return args
Example #5
 def Step(recurrent_theta, state0, inputs):
     """Computes one decoder step."""
     del inputs
     with tf.name_scope('single_sampler_step'):
         # Compute logits and states.
         bs_result, bs_state1 = pre_step_callback(
             recurrent_theta.theta,
             recurrent_theta.encoder_outputs,
             tf.expand_dims(state0.ids, 1),  # [batch, 1].
             state0.bs_state,
             num_hyps_per_beam=1)
         batch = tf.shape(bs_result.log_probs)[0]
         state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
         state1.logits = bs_result.log_probs
         # Sample ids from logits. [batch].
         state1.ids = tf.reshape(
             tf.random.stateless_categorical(
                 state1.logits / p.temperature,
                 num_samples=1,
                 seed=tf.stack(
                     [recurrent_theta.random_seed, state0.timestep]),
                 dtype=state0.ids.dtype,
                 name='sample_next_id'), [batch])
         if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
             state1.ids = tf.where(
                 tf.math.logical_and(
                     bs_result.is_last_chunk,
                     tf.equal(state1.ids, p.target_eoc_id)),
                 tf.fill(tf.shape(state1.ids), p.target_eos_id),
                 state1.ids)
         state1.bs_state = post_step_callback(
             recurrent_theta.theta, recurrent_theta.encoder_outputs,
             state1.ids, bs_state1)
     return state1, py_utils.NestedMap()
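The sampling above relies on tf.random.stateless_categorical, whose output is a pure function of the logits and the [run_seed, timestep] seed pair. A small hedged sketch (the values below are purely illustrative):

import tensorflow as tf

logits = tf.math.log([[0.1, 0.6, 0.3]])        # [batch=1, vocab=3]
run_seed, timestep = 1234, 0
ids = tf.random.stateless_categorical(
    logits, num_samples=1, seed=tf.stack([run_seed, timestep]))
# Re-running with the same logits and seed yields the same ids.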
Example #6
def fast_gather(values,
                ids,
                ids_size,
                max_value=None,
                axis=0,
                batch_major_state=True):
    """Fast implementation of gather on TPUs.

  Args:
    values: Values to gather from.
    ids: ids (rows to gather)
    ids_size: id space size.
    max_value: Optional hint on the maximum value for int32 inputs that allows
      speeding up the gather operation.
    axis: axis to gather on. Defaults to 0 (rows).
    batch_major_state: Whether the values to gather from use batch major or not.
      Defaults to True.

  Returns:
    Gathered values.
  Raises:
    ValueError: if values is of type int64.
  """
    values = tf.convert_to_tensor(values)
    ids = tf.convert_to_tensor(ids)
    with tf.name_scope("fast_gather"):
        return _Gatherer(ids, ids_size)(values,
                                        max_value=max_value,
                                        axis=axis,
                                        batch_major_state=batch_major_state)
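_Gatherer is not shown here, but one common TPU-friendly way to implement such a gather (an assumption, not necessarily what _Gatherer does) is to express it as a one-hot matmul:

import tensorflow as tf

def one_hot_gather(values, ids, ids_size):
    # [num_ids, ids_size] selection matrix times [ids_size, depth] values.
    one_hot = tf.one_hot(ids, ids_size, dtype=values.dtype)
    return tf.matmul(one_hot, values)

rows = one_hot_gather(tf.constant([[1., 2.], [3., 4.], [5., 6.]]),
                      ids=tf.constant([2, 0]), ids_size=3)   # rows 2 and 0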
Example #7
    def PrepareExternalInputs(self, theta, external_inputs):
        """Prepares external inputs for each sub-step.

    The external_inputs parameter of this method is selected according to each
    sub-step's external_signature and then processed by that sub-step's
    PrepareExternalInputs method.

    Args:
      theta: variables used by sub-steps.
      external_inputs: A NestedMap of [n_batch, ...] tensors.

    Returns:
      A NestedMap of prepared inputs, where the keys are the names of
        each sub-step.
    """
        graph_tensors = builder_layers.GraphTensors()
        graph_tensors.StoreTensor('external_inputs', external_inputs)
        prepared_inputs = py_utils.NestedMap()
        with tf.name_scope(self.params.name):
            for seq in self._seq:
                if seq.external_signature:
                    template = py_utils.NestedMap(
                        inputs=seq.external_signature.inputs)
                    packed = template.Transform(graph_tensors.GetTensor)
                    seq_external_inputs = packed.inputs[0]
                    prepared_inputs[seq.name] = seq.step.PrepareExternalInputs(
                        theta[seq.name], seq_external_inputs)
                else:
                    prepared_inputs[seq.name] = py_utils.NestedMap()
        return prepared_inputs
Example #8
    def FProp(self, theta, *args):
        p = self.params
        # Collects all variable key and values into sets.
        theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars,
                                            p.repeat)

        def _ArgsToState(arg_list):
            """Returns a NestedMap from a list of FProp args."""
            state = py_utils.NestedMap()
            # Maintains a mapping from arg_idx to tensor. States cannot contain
            # None tensors.
            for idx in range(len(args)):
                if arg_list[idx] is not None:
                    state['_s{}'.format(idx)] = arg_list[idx]
            return state

        def _StateToArgs(state):
            """Returns a list of FProp args from a NestedMap."""
            arg_list = []
            for idx in range(len(args)):
                attr = '_s{}'.format(idx)
                arg_list.append(state[attr] if attr in state else None)
                if arg_list[-1] is not None:
                    arg_list[-1].set_shape(args[idx].shape)
            return arg_list

        def _CellFn(unused_theta, state0, theta_i):
            """Recurrent cell function wrapper of body.FProp."""
            # Retrieves FProp arguments from state and sets shapes.
            fprop_inputs = _StateToArgs(state0)

            # Sets shapes for theta_i as well.
            for dst, src in zip(theta_i.Flatten(), theta_stack.Flatten()):
                if src is not None:
                    dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

            # Runs the actual body.FProp.
            fprop_outputs = self.body.FProp(theta_i, *fprop_inputs)
            fprop_outputs = _ToTuple(fprop_outputs)
            assert len(fprop_outputs) == len(fprop_inputs)

            # Passes fprop outputs to the next layer through state.
            state1 = _ArgsToState(fprop_outputs)
            return state1, py_utils.NestedMap()

        with tf.name_scope(p.name):
            # Add FProp arg list to state0.
            state0 = _ArgsToState(args)
            # Runs body.FProp k times using Recurrent where k = p.repeat (dim 0 of theta_stack).
            _, state1 = recurrent.Recurrent(
                theta=py_utils.NestedMap(),
                state0=state0,
                inputs=theta_stack,  # Pass cell_fn theta through inputs.
                cell_fn=_CellFn)

            # Retrieves fprop outputs from state1 and sets shapes.
            output_tensors = _StateToArgs(state1)
            return output_tensors[0] if len(args) == 1 else tuple(
                output_tensors)
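The _ArgsToState/_StateToArgs helpers simply round-trip the positional FProp arguments through a map keyed by position, dropping None entries. A sketch of the same round trip with a plain dict standing in for NestedMap:

def args_to_state(arg_list):
    return {'_s{}'.format(i): a for i, a in enumerate(arg_list) if a is not None}

def state_to_args(state, num_args):
    return [state.get('_s{}'.format(i)) for i in range(num_args)]

assert state_to_args(args_to_state([1, None, 3]), 3) == [1, None, 3]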
Example #9
 def FProp(self, theta, x):
     tf.logging.vlog(1, 'layer %s', self.params.name)
     with tf.name_scope(self.params.name):
         for (name, ch) in self._seq:
             th = theta[name]
             tf.logging.vlog(1, '  call %s %s %s', ch.params.name, ch, x)
             x = ch.FProp(th, x)
         return x
Example #10
    def FProp(self, theta, input_batch):
        """Encodes source as represented by `inputs` and `paddings`.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      input_batch: A `.NestedMap` with fields:
        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].

    Returns:
      A NestedMap containing:

      - encoded: The encoded features, a tensor of shape [time, batch, depth]
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
        (and all layers), or None otherwise.
    """
        p = self.params
        src_segment_id = None
        with tf.name_scope(p.name):
            # Now the rnn layers.
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings))
            ], tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            self._emb_out = xs
            ps = paddings
            # When cc_schedule is specified, make sure lstm_tpl is QuantizedLSTMCell
            # with the same cc_schedule so that the RNN layer output is within
            # clipping range.
            xs = self.rnn[0].FProp(theta.rnn[0], xs, ps)
            xs = self.dropout.FProp(theta.dropout, xs)
            for i in range(1, p.num_lstm_layers):
                layer = self.rnn[i]
                ys, _ = layer.FProp(theta.rnn[i], xs, ps)
                ys = self.dropout.FProp(theta.dropout, ys)
                if hasattr(layer.params, 'cell'):
                    layer_params = layer.params.cell
                else:
                    layer_params = layer.params
                if layer_params.num_input_nodes == layer_params.num_output_nodes:
                    xs += ys  # Residual skip
                    xs = self.ApplyClipping(theta, xs)
                else:
                    # When cc_schedule is specified, make sure lstm_tpl is
                    # QuantizedLSTMCell with the same cc_schedule so that the RNN layer
                    # output is within clipping range.
                    xs = ys
            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=src_segment_id)
Example #11
 def FProp(self, theta, current_step):
     p = self.params
     assert p.total_steps > 0
     assert p.initial_value > p.final_value
     with tf.name_scope(p.name):
         decay_gap = p.initial_value - p.final_value
         return p.final_value + 0.5 * decay_gap * (1 + tf.cos(
             math.pi *
             tf.minimum(1.0,
                        tf.cast(current_step, tf.float32) / p.total_steps)))
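Written outside the layer, the schedule above is an ordinary cosine decay from initial_value down to final_value over total_steps. A standalone sketch with illustrative defaults:

import math
import tensorflow as tf

def cosine_decay(step, initial_value=1.0, final_value=0.1, total_steps=1000):
    gap = initial_value - final_value
    frac = tf.minimum(1.0, tf.cast(step, tf.float32) / total_steps)
    return final_value + 0.5 * gap * (1.0 + tf.cos(math.pi * frac))

# cosine_decay(0) == initial_value; cosine_decay(total_steps) == final_value.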
Example #12
    def FProp(self, theta, inputs, paddings):
        """Apply convolution to inputs.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. It is expected to be of shape [batch, time,
        frequency, channel]. The time dimension corresponds to the height
        dimension as in images and the frequency dimension corresponds to the
        width dimension as in images.
      paddings: The paddings tensor, expected to be of shape [batch, time].

    Returns:
      outputs, out_paddings pair.
    """
        p = self.params
        with tf.name_scope(p.name):
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]),
                py_utils.assert_shape_match(
                    tf.shape(inputs),
                    tf.concat([
                        tf.shape(paddings),
                        [-1, symbolic.ToStatic(self.input_channels)]
                    ], 0))
            ], inputs)

            def _ApplyPadding(tensor_in, padding_in):
                padding_expanded = tf.expand_dims(
                    tf.expand_dims(padding_in, -1), -1)
                return tensor_in * (1.0 - padding_expanded)

            # Zeroing out padded inputs.
            inputs = _ApplyPadding(inputs, paddings)

            # Apply conv on 'inputs'.
            out = self._ApplyConv(theta, inputs)

            if p.partial_conv:
                out = self._RescaleBoundary(out, paddings)
            # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1.
            # But there are likely no real problems. Trying to set it gives an error:
            # pooling with SAME padding is not implemented for dilation_rate > 1.
            # NOTE: we use window=p.filter_stride[0] to be compatible with legacy
            # implementation.  Consider updating it to be the actual shape.
            conv_padding = ComputeConvOutputPadding(paddings,
                                                    window=p.filter_stride[0],
                                                    stride=p.filter_stride[0])
            # Assuming padded nodes will be properly zeroed out if necessary by
            # subsequent layers.
            # out = _ApplyPadding(out, conv_padding)
            out = py_utils.HasShape(
                out, symbolic.ToStatic(self.OutShape(tf.shape(inputs))))
            return out, conv_padding
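ComputeConvOutputPadding itself is not shown; one simple way to derive such an output padding (an illustrative assumption, not necessarily the library's implementation) is to max-pool the input paddings with the same window and stride as the convolution:

import tensorflow as tf

def strided_output_padding(paddings, window, stride):
    # paddings: [batch, time] with 1.0 marking padded frames.
    p = tf.expand_dims(paddings, -1)                                  # [batch, time, 1]
    pooled = tf.nn.pool(p, [window], 'MAX', strides=[stride], padding='SAME')
    return tf.squeeze(pooled, -1)                                     # [batch, out_time]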
Example #13
    def FProp(self, theta, *args):
        r"""Applies a function (p.fn) on args.

    Args:
      theta: Unused.
      *args: A tuple of Tensors (one or more).

    Returns:
      fn(\*args).
    """
        with tf.name_scope(self.params.name):
            return self.params.fn(*args)
Example #14
    def FProp(self, theta, input_batch):
        p = self.params
        with tf.name_scope(p.name):
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings))
            ], tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
            if p.packed_input:
                src_segment_id = tf.expand_dims(
                    tf.transpose(input_batch.segment_ids), 2)
            else:
                src_segment_id = None
            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            summary_utils.histogram('input_emb', xs)
            xs = self.dropout.FProp(theta.dropout, xs)
            ps = paddings
            # Now the rnn layers.
            outputs_list = []
            for i in range(0, p.num_lstm_layers):
                layer = self.rnn[i]
                ys = layer.FProp(theta.rnn[i],
                                 xs,
                                 ps,
                                 segment_id=src_segment_id)
                ys = self.dropout.FProp(theta.dropout, ys)
                if i >= p.residual_start:
                    xs += ys  # Residual skip
                    xs = self.ApplyClipping(theta, xs)
                else:
                    xs = ys
                outputs_list.append(xs)
                summary_utils.histogram('layer_out_%s' % i, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   outputs_list)

            if p.lstm_cell_size * 2 != p.encoder_out_dim:
                # Project to the right depth.
                xs = self.final_proj.FProp(theta.final_proj, xs, ps)
                summary_utils.histogram('final_proj_out', xs)

            if src_segment_id is not None:
                src_segment_id = tf.squeeze(src_segment_id, [2])

            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=src_segment_id)
Example #15
    def FProp(self, theta, inputs):
        """Adds bias to inputs.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor.  Shaped [..., dims].

    Returns:
      Inputs plus bias.
    """
        with tf.name_scope(self.params.name):
            return inputs + theta.b
Example #16
 def _Traverse(layer):
     """Adds accumulators to layer and its descendant layers."""
     if isinstance(layer, (list, tuple)):
         for layer_i in layer:
             _Traverse(layer_i)
         return
     with tf.name_scope(layer.params.name):
         for cost_metric_name in COST_METRICS:
             dtype = COST_METRICS[cost_metric_name]
             layer.RegisterAccumulator(
                 cost_metric_name,
                 bn_layers.AddingAccumulator(shape=[], dtype=dtype))
         for _, child in sorted(layer.children.items()):
             _Traverse(child)
Example #17
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients."""

  for name, (var, grad) in vs_gs.FlattenItems():
    name = py_utils.SanitizeScopeKey(name)
    with tf.device(var.device), tf.name_scope(name + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        var = tf.gather(var, grad.indices)
        grad = grad.values
      if var.dtype.is_complex:
        var = tf.abs(var)
        grad = tf.abs(grad)

      histogram('var_hist/' + name, var)
      histogram('grad_hist/' + name, grad)
Example #18
    def FProp(self, theta, *args):
        p = self.params

        with tf.name_scope(p.name):
            # Computes sub layers in parallel.
            outputs = []
            for (name, ch) in self._seq:
                th = theta[name]
                out = ch.FProp(th, *args)
                if isinstance(out, (list, tuple)):
                    outputs.append(tuple(out))
                else:
                    outputs.append((out, ))
            rets = p.merge(outputs)
            return rets if len(rets) > 1 else rets[0]
Example #19
    def FProp(self, theta, input_batch, state0=None):
        p = self.params
        src_segment_id = None
        with tf.name_scope(p.name):
            # Reshape to [t, b]
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings))
            ], tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

            # Setup streaming states.
            if not state0:
                state0 = self.zero_state(theta, tf.shape(inputs)[1])
            state1 = py_utils.NestedMap(rnn=[None] * p.num_lstm_layers)

            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            summary_utils.histogram('input_emb', xs)
            xs = self.dropout.FProp(theta.dropout, xs)
            ps = paddings
            # Now the rnn layers.
            outputs_list = []
            for i in range(0, p.num_lstm_layers):
                layer = self.rnn[i]
                ys, state1.rnn[i] = layer.FProp(theta.rnn[i],
                                                xs,
                                                ps,
                                                state0=state0.rnn[i])
                ys = self.dropout.FProp(theta.dropout, ys)
                if i >= p.residual_start:
                    xs += ys  # Residual skip
                    xs = self.ApplyClipping(theta, xs)
                else:
                    xs = ys
                outputs_list.append(xs)
                summary_utils.histogram('layer_out_%s' % i, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   outputs_list)

            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=src_segment_id,
                                      state=state1)
Example #20
 def FProp(self, theta, *args):
     p = self.params
     with tf.name_scope(p.name) as name_scope:
         for i, arg in enumerate(args):
             if not isinstance(arg, tf.Tensor):
                 tf.logging.info(
                     'FProp non-Tensor input in {}: arg_{} arg = {}'.format(
                         name_scope, i, arg))
             else:
                 tf.logging.info(
                     'FProp inputs in {}: arg_{} shape = {} dtype = {}'.
                     format(name_scope, i, arg.shape, arg.dtype.name))
     if len(args) == 1:
         return args[0]
     else:
         return args
Example #21
  def FProp(self, theta, inputs, paddings=None):
    """Apply batch normalization.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].
      paddings: The paddings tensor.  Shaped [..., 1], with the same rank as the
        input tensor.

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
    p = self.params
    if paddings is None:
      paddings = self._GetDefaultPaddings(inputs)
    with tf.name_scope(p.name):
      norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
          theta, inputs, paddings)
      with tf.control_dependencies([
          py_utils.assert_greater_equal(norm_variance,
                                        tf.zeros_like(norm_variance)),
          py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                      tf.shape(norm_mean)),
          py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                      tf.shape(norm_variance)),
      ]):
        if p.use_fused_batch_norm_for_eval and self.do_eval:
          bn_output, _, _ = nn.fused_batch_norm(
              inputs,
              gamma,
              beta,
              norm_mean,
              norm_variance,
              self._epsilon,
              is_training=False)
        else:
          bn_output = tf.nn.batch_normalization(inputs, norm_mean,
                                                norm_variance, beta, gamma,
                                                self._epsilon)

        if p.set_padded_output_to_zero:
          bn_output *= 1.0 - paddings

      return bn_output
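At its core, the normalization applied above is y = gamma * (x - mean) / sqrt(variance + epsilon) + beta. A minimal sketch with moments computed on the fly (shapes and epsilon are illustrative):

import tensorflow as tf

x = tf.random.normal([4, 8])
mean, variance = tf.nn.moments(x, axes=[0])
y = tf.nn.batch_normalization(x, mean, variance,
                              offset=tf.zeros([8]), scale=tf.ones([8]),
                              variance_epsilon=1e-3)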
Example #22
    def ZeroState(self, theta, prepared_inputs, batch_size):
        """Creates a zero state NestedMap for this step.

    Args:
      theta: variables used by sub-steps.
      prepared_inputs: Output from a call to PrepareExternalInputs.
      batch_size: The number of items in the batch that FProp will process.

    Returns:
      A NestedMap of ZeroState results for each sub-step.
    """
        state0 = py_utils.NestedMap()
        with tf.name_scope(self.params.name):
            for seq in self._seq:
                state0[seq.name] = seq.step.ZeroState(
                    theta[seq.name], prepared_inputs[seq.name], batch_size)
        return state0
Example #23
def TraverseLayer(layer, fn):
    """Traverses the layer tree and invokes fn(node) on each node.

  Args:
    layer: a BaseLayer.
    fn: a function of (layer, layer_theta) -> None.
  """
    if isinstance(layer, (list, tuple)):
        for layer_i in layer:
            TraverseLayer(layer_i, fn)
        return

    with tf.name_scope(layer.params.name):
        fn(layer)
        # Traverse all children in alphabetical order.
        for _, child in sorted(layer.children.items()):
            TraverseLayer(child, fn)
Example #24
    def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
        """A single inference step for this step graph.

    Args:
      theta: variables used by sub-steps.
      prepared_inputs: A NestedMap containing external_inputs that were
        pre-processed by the PrepareExternalInputs method of each sub-step. The
        keys are the names of the sub-steps.
      step_inputs: A NestedMap of [batch, ...] tensors. The structure of this
        depends on the graph implementation.
      padding: A 0/1 float tensor of shape [batch_size]; 1.0 means that this
        batch element is empty in this step.
      state0: A NestedMap of state variables produced by either ZeroState or a
        previous invocation of this FProp step. The keys are the names of the
        sub-steps.

    Returns:
      (output, state1), both of which are NestedMaps.
      output is implementation-dependent and is defined by the output_signature
      parameter.
      state1 is a NestedMap where the keys are names of sub-steps and the values
      are state outputs from their FProp methods.
    """
        p = self.params
        graph_tensors = builder_layers.GraphTensors()
        graph_tensors.StoreTensor('prepared_inputs', prepared_inputs)
        graph_tensors.StoreTensor('step_inputs', step_inputs)
        state1 = py_utils.NestedMap()
        with tf.name_scope(p.name):
            for seq in self._seq:
                tf.logging.vlog(1, 'GraphStep: call %s', seq.name)
                external = None
                if seq.external_signature:
                    external = prepared_inputs[seq.name]
                template = py_utils.NestedMap(inputs=seq.signature.inputs)
                packed = template.Transform(graph_tensors.GetTensor)
                input_args = packed.inputs[0]
                out, seq_state1 = seq.step.FProp(theta[seq.name], external,
                                                 input_args, padding,
                                                 state0[seq.name])
                graph_tensors.StoreTensor(seq.signature.outputs[0], out)
                state1[seq.name] = seq_state1
        template = py_utils.NestedMap(inputs=self.output_signature.inputs)
        output_tensors = template.Transform(graph_tensors.GetTensor).inputs[0]
        return output_tensors, state1
Example #25
    def FProp(self, theta, inputs, *args):
        p = self.params
        with tf.name_scope(p.name) as scope:
            expert_dist = self._GetExpertDist(theta, inputs, *args)
            if not self.do_eval:
                summary_utils.histogram('soft_cond_{}'.format(scope),
                                        expert_dist)

            # Excludes non-variable extra_theta like global_step.
            var_set = set([key for key, _ in self.body.vars.FlattenItems()])
            values = []
            for key, value in theta.body.FlattenItems():
                if key in var_set and value is not None:
                    # Weighted average for all variables created in the body layer.
                    value = tf.einsum('i,i...->...', expert_dist, value)
                values.append(value)
            weighted_theta = theta.body.Pack(values)
            return self.body.FProp(weighted_theta, inputs, *args)
Example #26
    def FProp(self, theta, current_step):
        p = self.params
        with tf.name_scope(p.name):

            steps = self._best_step
            best_step = steps[0]
            last_step = steps[1]

            ref_step = tf.maximum(self._ref_step, best_step)
            f = self._cur_factor

            # Decay if no improvement within window.
            new_factor = tf.where(last_step - ref_step < p.window, f,
                                  tf.maximum(p.min_factor, f * p.decay))
            # Update ref_step if we decayed.
            new_step = tf.where(tf.equal(new_factor, f), ref_step, last_step)
            update_step = tf.assign(self._ref_step, new_step)
            with tf.control_dependencies([update_step]):
                return tf.assign(self._cur_factor, new_factor)
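Numerically, the update above decays the factor whenever no new best step has appeared within `window` steps, and resets the reference step only when it does decay. A plain-Python walk-through with illustrative numbers:

window, decay, min_factor, factor = 1000, 0.5, 0.01, 1.0
best_step, last_step, ref_step = 3000, 5500, 3000

ref_step = max(ref_step, best_step)
if last_step - ref_step >= window:             # no improvement for a full window
    factor = max(min_factor, factor * decay)   # 1.0 -> 0.5
    ref_step = last_step                       # wait another full window before decaying again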
Example #27
 def FProp(self, theta, *args):
     p = self.params
     with tf.name_scope(p.name):
         tf.logging.vlog(1, 'layer %s', self.params.name)
         if p.repeat <= 1:
             for (name, ch) in self._seq:
                 th = theta[name]
                 args = _ToTuple(args)
                 tf.logging.vlog(1, 'SequentialLayer: call %s %s %d %s',
                                 ch.params.name, ch, len(args), str(args))
                 args = ch.FProp(th, *args)
         else:
             for (ch, th) in zip(self.rep, theta.rep):
                 args = _ToTuple(args)
                 tf.logging.vlog(1, '  call %s %s %d %s', ch.params.name,
                                 ch, len(args), str(args))
                 args = ch.FProp(th, *args)
         args = _ToTuple(args)
         return args[0] if len(args) == 1 else args
Example #28
    def FProp(self, theta, inputs):
        """Apply projection to inputs.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor.  Shaped [..., input_dims].

    Returns:
      Projected inputs.
    """
        p = self.params
        with tf.name_scope(p.name):
            computation_cost.Add(
                self, 'flops',
                tf.reduce_prod(tf.cast(tf.shape(inputs)[:-1], tf.int64)) *
                tf.cast(symbolic.ToTensor(p.input_dims * p.output_dims),
                        tf.int64) * 2)
            return py_utils.ProjectLastDim(inputs, theta.w, p.input_dims,
                                           p.output_dims)
Example #29
def reorder_tensor(reorder_mode,
                   values,
                   num_shards,
                   shard_size,
                   max_value=None,
                   axis=0):
    """Reorder tensor based on the mode passed in.

  This method reorders rows or cols (based on `axis`) of the tensor passed in
  from one sharding mode to another sharding mode. This method uses matmul for
  reordering to be efficient on TPUs.

  Args:
    reorder_mode: Either mod_to_div or div_to_mod
    values: Tensor to reorder
    num_shards: Number of shards.
    shard_size: Size of each shard.
    max_value: If dtype is tf.int32 and we know the maximum of the values, the
      reorder can be implemented efficiently as matmuls.
    axis: axis to gather on. Defaults to 0 (rows).

  Returns:
    A tensor of the same shape as values but with rows (or the chosen axis) reordered.
  """
    values = tf.convert_to_tensor(values)
    with tf.name_scope("reorder_tensor_" + reorder_mode):
        num_ids = num_shards * shard_size
        # Elements to gather.
        seq_ids = tf.range(num_ids)
        if reorder_mode == "mod_to_div":
            local_ids = seq_ids // shard_size
            shard_ids = seq_ids % shard_size
            ids = local_ids + shard_ids * num_shards
        elif reorder_mode == "div_to_mod":
            shard_ids = seq_ids % num_shards
            local_ids = seq_ids // num_shards
            ids = local_ids + shard_ids * shard_size
        else:
            raise NotImplementedError(
                "Reorder mode: {} not implemented.".format(reorder_mode))
        return fast_gather(values, ids, num_ids, max_value, axis=axis)
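A quick numeric check of the mod_to_div index math above with num_shards=2 and shard_size=3 (row s of the output is taken from row ids[s] of the input):

num_shards, shard_size = 2, 3
ids = [(s // shard_size) + (s % shard_size) * num_shards
       for s in range(num_shards * shard_size)]
print(ids)  # [0, 2, 4, 1, 3, 5]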
Example #30
def GenerateStepSeedPair(p, unused_global_step=None, op_seed=None):
  """Override py_utils.GenerateStepSeedPair to use GetOverWriteGlobalStep."""
  seed_dtype = tf.int32 if py_utils.use_tpu() else tf.int64
  if p.is_inference and p.random_seed is None:
    # Unlike tf.random*, stateless random ops are completely determined by the
    # passed-in seeds. This means at inference time the same inputs will produce
    # the same outputs, even if the model is supposed to have randomness such as
    # dropout during inference. We inject additional randomness only during
    # inference if the graph is exported with random_seed=None as a workaround.
    return tf.random.uniform([2], maxval=seed_dtype.max, dtype=seed_dtype)

  with tf.name_scope('op_seed') as scope:
    global_step = tf.cast(GetOverWriteGlobalStep(), seed_dtype)
    step_seed = tf.cast(py_utils.GenerateSeedFromName(scope), seed_dtype)
    seeds = tf.stack([global_step, step_seed])

    if p.random_seed is not None:
      seeds += p.random_seed
    if op_seed is not None:
      seeds += op_seed
    return seeds
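A seed pair like the one returned above can feed any stateless random op, making every draw a pure function of (seed, shape). A hedged usage sketch with made-up seed values:

import tensorflow as tf

seeds = tf.stack([tf.constant(7, tf.int64), tf.constant(3, tf.int64)])
noise = tf.random.stateless_uniform([4], seed=seeds)
# Calling again with the same seeds returns the same noise.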