Example 1
        def grad(*dy):
            """Gradient function for build_rnn."""
            dy = nest.pack_sequence_as(ret, dy)

            def _continue(unused_theta, unused_dy, unused_dstate1,
                          unused_dtheta, unused_dinput, i):
                return _should_continue(i, True)

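            # Backward pass: run the gradient cell over the sequence in the
            # opposite order of the forward loop, accumulating dtheta and the
            # per-step dinput.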
            dstate1, dtheta, dinput = tf.while_loop(_continue, _cell_grad_fn, [
                theta,
                dy,
                {
                    k: tf.zeros_like(state0[k])
                    for k in state0 if k not in skipped_state
                },
                {
                    k: tf.zeros_like(theta[k])
                    for k in theta if k not in skipped_theta
                },
                {k: tf.zeros_like(inputs[k])
                 for k in inputs},
                tf.zeros([], tf.int32) if reverse else max_length - 1,
            ])[2:5]
            dtheta, dinput = _cell_grad_fn_with_state0(
                state0, theta, dy, dstate1, dtheta, dinput,
                max_length - 1 if reverse else tf.zeros([], dtype=tf.int32))[3:5]
            state0_h = tf.reshape(acc_state["h"],
                                  [-1, theta["kernel"].shape[0]])
            state0_atten = (
                tf.reshape(acc_state["attention"],
                           [-1, theta["attention_kernel"].shape[0]])
                if "attention_kernel" in theta else None)
            grad = tf.reshape(dinput["rnn"], [-1, theta["kernel"].shape[1]])
            if reverse:
                state0_h = tf.split(state0_h, [batch_size, -1])[1]
                grad = tf.split(grad, [-1, batch_size])[0]
            else:
                if state0_atten is not None:
                    state0_atten = tf.split(state0_atten, [-1, batch_size])[0]
                state0_h = tf.split(state0_h, [-1, batch_size])[0]
                grad = tf.split(grad, [batch_size, -1])[1]

            if state0_atten is not None:
                dtheta["attention_kernel"] = tf.matmul(
                    tf.transpose(state0_atten), grad)
            dtheta["kernel"] = tf.matmul(tf.transpose(state0_h), grad)

            if "memory_kernel" in orig_theta:
                dtheta["memory_kernel"] = tf.zeros_like(
                    orig_theta["memory_kernel"])
                dtheta["seq_mask"] = tf.zeros_like(orig_theta["seq_mask"])
            return dinput, dtheta

        def tpu_eval_step():
            """Generate the TPU graph."""
            values = self.eval_infeed_queue[0].generate_dequeue_op(
                tpu_device=0)
            unflattened_inputs = data_nest.pack_sequence_as(
                self.eval_feature_structure, values)
            features = unflattened_inputs["features"]
            estimator_spec = model_fn(features, None,
                                      tf.estimator.ModeKeys.PREDICT, params)
            for k, v in six.iteritems(estimator_spec.predictions):
                self.outfeed_names.append(k)
                self.outfeed_tensors.append(v)

            with tf.device(
                    device_for_tpu_core(get_host(self.resolver,
                                                 self.hparams))):
                outfeed_enqueue_ops = tpu_ops.outfeed_enqueue_tuple(
                    self.outfeed_tensors)
            with tf.control_dependencies([outfeed_enqueue_ops]):
                return tf.no_op()

        def tpu_train_step(loss):
            """Generate the TPU graph."""
            del loss
            values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
            unflattened_inputs = data_nest.pack_sequence_as(
                self.feature_structure, values)
            features = unflattened_inputs["features"]
            core_id = unflattened_inputs["core_id"]
            new_features = {}
            for k in features:
                s = features[k].shape.as_list()
                s = [self.hparams.num_shards,
                     s[0] // self.hparams.num_shards] + s[1:]
                new_features[k] = tf.squeeze(
                    tf.gather(
                        tf.reshape(tpu_ops.cross_replica_sum(features[k]), s),
                        core_id), [0])

            estimator_spec = model_fn(new_features, None,
                                      tf.estimator.ModeKeys.TRAIN, params)
            loss, train_op = estimator_spec.loss, estimator_spec.train_op
            with tf.control_dependencies([train_op]):
                return tf.identity(loss)
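
In tpu_train_step above, each core ends up holding the full host batch after cross_replica_sum and then slices out its own shard by core_id. Below is a minimal NumPy sketch of just that reshape-and-gather indexing; the collective itself and the infeed plumbing are omitted, and the shapes are made up for illustration.

import numpy as np

num_shards = 4      # stand-in for self.hparams.num_shards
per_core = 2        # per-core batch size
full_batch = np.arange(num_shards * per_core * 3).reshape(num_shards * per_core, 3)

core_id = 2         # stand-in for the core_id scalar read from the infeed queue
# tf.reshape(..., s): split the host batch into one shard per core.
sharded = full_batch.reshape(num_shards, per_core, 3)
# tf.gather(..., core_id) + tf.squeeze(..., [0]): keep only this core's shard.
my_shard = sharded[core_id]
print(my_shard.shape)   # (2, 3)
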
Example 4
def attention_cell_grad(theta, state0, unused_inputs, extras, dstate1):
  """Gradient function for attention_cell."""
  new_lstm_state = lstm_cell_split(extras, state0["c"], None)
  new_states = attention(theta, new_lstm_state)
  del new_states["alignments"]

  y = nest.flatten(new_states)
  x = [extras, state0["c"]] + nest.flatten(theta)
  dy = nest.flatten(dstate1)
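  # tf.gradients(y, x, dy) computes the vector-Jacobian product of the
  # recomputed cell outputs w.r.t. [extras, state0["c"]] + theta, seeded with
  # the incoming state gradients dstate1.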
  g = tf.gradients(y, x, dy)
  dtheta = nest.pack_sequence_as(theta, g[2:])
  grad, dstate_c = g[:2]

  dtheta["bias"] = tf.reduce_sum(grad, 0)

  datten = tf.matmul(grad, tf.transpose(theta["attention_kernel"]))
  dstate_h = tf.matmul(grad, tf.transpose(theta["kernel"]))

  dstate = {
      "h": dstate_h,
      "c": dstate_c,
      "attention": datten,
  }
  return dtheta, dstate, {"rnn": grad}
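
attention_cell_grad recomputes the cell's outputs and lets tf.gradients back-propagate dstate1 through them, i.e. it evaluates a vector-Jacobian product. Here is a minimal, self-contained sketch of that pattern in TF1-style graph mode; the small tensors stand in for extras, state0 and theta and are not the real cell.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.constant([[1.0, 2.0]])      # stand-in for a recomputation input
w = tf.constant([[3.0], [4.0]])    # stand-in for one kernel in theta
y = tf.matmul(x, w)                # recomputed forward output
dy = tf.constant([[0.5]])          # incoming gradient (plays the role of dstate1)

# tf.gradients(ys, xs, grad_ys) returns the vector-Jacobian product,
# i.e. dy back-propagated through y to each tensor in xs.
dx, dw = tf.gradients([y], [x, w], [dy])
with tf.Session() as sess:
    print(sess.run([dx, dw]))      # dx = dy @ w.T, dw = x.T @ dy
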
Example 5
  def _rnn(*inp):
    """Function that drives RNN with early stop."""
    inputs = nest.pack_sequence_as(orig_inputs, inp[0:len(orig_inputs)])
    theta = nest.pack_sequence_as(orig_theta, inp[len(orig_inputs):])

    def _cell_fn(theta, state0, acc_state, acc_gate, i):
      """RNN cell function."""
      input_slice = {k: tf.gather(inputs[k], i) for k in inputs}
      state1, gate = cell_fn(theta, state0, input_slice)
      for k in state0:
        if k not in skipped_state:
          acc_state[k] = tf.stop_gradient(
              inplace_ops.alias_inplace_update(acc_state[k], i, state1[k]))
      acc_gate = tf.stop_gradient(
          inplace_ops.alias_inplace_update(acc_gate, i, gate))
      return theta, state1, acc_state, acc_gate, i - 1 if reverse else i + 1

    def _should_continue(i, is_backward=False):
      if is_backward:
        return i < max_length - 1 if reverse else i > 0
      else:
        return i >= 0 if reverse else i < max_length

    acc_state = {
        k: tf.zeros([max_time, batch_size, state0["c"].shape[-1]],
                    state0["c"].dtype) for k in state0 if k not in skipped_state
    }
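    # Forward pass: step cell_fn over time, accumulating per-step states
    # (acc_state) and gates (acc_gate) for the hand-written backward pass below.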
    acc_state, acc_gate = tf.while_loop(
        lambda theta, state0, acc_state, acc_gate, i: _should_continue(i),
        _cell_fn, [
            theta, state0, acc_state,
            tf.zeros_like(inputs["rnn"]),
            max_length - 1 if reverse else tf.zeros([], tf.int32)
        ])[2:4]
    ret = {"h": acc_state["h"]}
    if "attention" in acc_state:
      ret["attention"] = acc_state["attention"]

    def _cell_grad_fn_with_state0(state0, theta, dy, dstate1, dtheta, dinput,
                                  i):
      """Gradient cell function."""
      state0 = {
          k: tf.stop_gradient(state0[k])
          for k in state0
          if k not in skipped_state
      }
      theta = {k: tf.stop_gradient(theta[k]) for k in theta}
      if "padding" in inputs:
        inputs_slice = {"padding": tf.gather(inputs["padding"], i)}
      else:
        inputs_slice = None
      gate = tf.gather(acc_gate, i)
      for k in dy:
        dstate1[k] = dstate1[k] + tf.gather(dy[k], i)
      dt, dstate, di = cell_grad(theta, state0, inputs_slice, gate, dstate1)
      dtheta = {k: dtheta[k] + dt[k] for k in dtheta if k not in skipped_theta}
      dinput = {
          k: inplace_ops.alias_inplace_update(dinput[k], i, di[k]) for k in di
      }
      return theta, dy, dstate, dtheta, dinput, i + 1 if reverse else i - 1

    def _cell_grad_fn(theta, dy, dstate1, dtheta, dinput, i):
      """Gradient cell function wrapper."""
      return _cell_grad_fn_with_state0(
          {
              k: tf.gather(acc_state[k], i + 1 if reverse else i - 1)
              for k in acc_state
          }, theta, dy, dstate1, dtheta, dinput, i)

    def grad(*dy):
      """Gradient function for build_rnn."""
      dy = nest.pack_sequence_as(ret, dy)

      def _continue(unused_theta, unused_dy, unused_dstate1, unused_dtheta,
                    unused_dinput, i):
        return _should_continue(i, True)

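      # Backward pass: walk the time steps in the opposite order of the
      # forward loop, accumulating dtheta and the per-step dinput.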
      dstate1, dtheta, dinput = tf.while_loop(_continue, _cell_grad_fn, [
          theta,
          dy,
          {
              k: tf.zeros_like(state0[k])
              for k in state0
              if k not in skipped_state
          },
          {k: tf.zeros_like(theta[k]) for k in theta if k not in skipped_theta},
          {k: tf.zeros_like(inputs[k]) for k in inputs},
          tf.zeros([], tf.int32) if reverse else max_length - 1,
      ])[2:5]
      dtheta, dinput = _cell_grad_fn_with_state0(
          state0, theta, dy, dstate1, dtheta, dinput,
          max_length - 1 if reverse else tf.zeros([], dtype=tf.int32))[3:5]
      state0_h = tf.reshape(acc_state["h"], [-1, theta["kernel"].shape[0]])
      state0_atten = (
          tf.reshape(acc_state["attention"],
                     [-1, theta["attention_kernel"].shape[0]])
          if "attention_kernel" in theta else None)
      grad = tf.reshape(dinput["rnn"], [-1, theta["kernel"].shape[1]])
      if reverse:
        state0_h = tf.split(state0_h, [batch_size, -1])[1]
        grad = tf.split(grad, [-1, batch_size])[0]
      else:
        if state0_atten is not None:
          state0_atten = tf.split(state0_atten, [-1, batch_size])[0]
        state0_h = tf.split(state0_h, [-1, batch_size])[0]
        grad = tf.split(grad, [batch_size, -1])[1]

      if state0_atten is not None:
        dtheta["attention_kernel"] = tf.matmul(tf.transpose(state0_atten), grad)
      dtheta["kernel"] = tf.matmul(tf.transpose(state0_h), grad)

      if "memory_kernel" in orig_theta:
        dtheta["memory_kernel"] = tf.zeros_like(orig_theta["memory_kernel"])
        dtheta["seq_mask"] = tf.zeros_like(orig_theta["seq_mask"])
      return dinput, dtheta

    return ret, grad
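
_rnn returns an (outputs, grad) pair, which matches the convention expected by TensorFlow's custom-gradient machinery; the excerpt does not show how that pair is registered, but presumably something like tf.custom_gradient wires grad in as the backward pass. Below is a minimal sketch of that pattern on a toy function, not the original wiring.

import tensorflow as tf

@tf.custom_gradient
def square_with_custom_grad(x):
    y = x * x

    def grad(dy):
        # Hand-written backward pass, analogous to _rnn's grad(*dy);
        # here it simply matches d(x**2)/dx = 2 * x.
        return dy * 2.0 * x

    return y, grad

x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = square_with_custom_grad(x)
print(tape.gradient(y, x))   # 6.0
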