Code example #1
            def enqueue_ops_fn(idx):
                """Enqueue ops function for one host.."""
                with tf.device(device):
                    sharded_inputs = []
                    start_idx = 0
                    # Only even-numbered hosts below 2 * num_infeed_workers
                    # act as infeed workers and feed real data on the host's
                    # first core.
                    if host_id in range(0, self.hparams.num_infeed_workers * 2,
                                        2):
                        core_id = tf.constant(host_id *
                                              self.hparams.num_shards_per_host,
                                              shape=[1],
                                              dtype=tf.int32)
                        if self.hparams.use_synthetic_data:
                            features = output
                        else:

                            def true_fn():
                                return iterator.get_next()

                            def false_fn():
                                return {
                                    k: tf.zeros_like(
                                        self.feature_structure["features"][k])
                                    for k in self.feature_structure["features"]
                                }

                            # Round-robin: this host reads a real batch only
                            # when idx is its turn; otherwise it enqueues
                            # zeros so all hosts stay in lockstep.
                            features = tf.cond(
                                tf.equal(idx % self.hparams.num_infeed_workers,
                                         host_id // 2), true_fn, false_fn)
                        sharded_inputs.append(
                            data_nest.flatten({
                                "features": features,
                                "core_id": core_id
                            }))
                        start_idx = 1
                    # All remaining cores on this host get all-zero features.
                    for i in range(start_idx,
                                   self.hparams.num_shards_per_host):
                        sharded_inputs.append(
                            data_nest.flatten({
                                "features": {
                                    k: tf.zeros_like(
                                        self.feature_structure["features"][k])
                                    for k in self.feature_structure["features"]
                                },
                                "core_id":
                                tf.constant(
                                    host_id * self.hparams.num_shards_per_host
                                    + i,
                                    shape=[1],
                                    dtype=tf.int32)
                            }))
                infeed = tpu_feed.InfeedQueue(
                    number_of_tuple_elements=len(sharded_inputs[0]))
                self.infeed_queue.append(infeed)

                def tpu_ordinal_fn(shard_index_in_host):
                    return shard_index_in_host % self.hparams.num_shards_per_host

                return infeed.generate_enqueue_ops(
                    sharded_inputs, tpu_ordinal_function=tpu_ordinal_fn)
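
Note: the tf.cond above implements a round-robin: on step idx, only the host whose turn it is reads a real batch from its iterator, while every other host enqueues zeros of the same structure so the infeed stays in lockstep. Below is a minimal, self-contained sketch of that pattern in TF 1.x graph mode; num_workers, worker_id, and the [2, 3] feature shape are illustrative, not taken from the original.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

num_workers = 4   # hypothetical number of infeed workers
worker_id = 1     # hypothetical id of this worker
step = tf.placeholder(tf.int32, shape=[])

dataset = tf.data.Dataset.from_tensors(tf.ones([2, 3])).repeat()
iterator = tf.data.make_one_shot_iterator(dataset)

def true_fn():
    # This worker's turn: read a real batch.
    return iterator.get_next()

def false_fn():
    # Not this worker's turn: enqueue zeros of the same shape/dtype.
    return tf.zeros([2, 3])

features = tf.cond(tf.equal(step % num_workers, worker_id), true_fn, false_fn)

with tf.Session() as sess:
    for i in range(4):
        print(i, sess.run(features, {step: i}).sum())  # 0.0 except on step 1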
Code example #2
def lstm_cell_grad(theta, state0, inputs, extras, dstate1):
    """Gradient function for lstm_cell."""
    padding = inputs["padding"] if (inputs is not None
                                    and "padding" in inputs) else None
    state1 = nest.flatten(lstm_cell_split(extras, state0["c"], padding))
    dstate1 = nest.flatten(dstate1)
    # Gradient w.r.t. the pre-activation gates; everything else below is
    # derived from it.
    grad = tf.gradients(state1, [extras], dstate1)[0]
    dtheta = {"bias": tf.reduce_sum(grad, 0)}
    dinputs = {"rnn": grad}
    dstate = {"c": tf.gradients(state1, state0["c"], dstate1)[0]}
    # h enters only through the matmul with the kernel in the caller, so its
    # gradient is recovered manually as grad @ kernel^T.
    dstate["h"] = tf.matmul(grad, tf.transpose(theta["kernel"]))
    if padding is not None:
        dinputs["padding"] = padding
    return dtheta, dstate, dinputs
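
lstm_cell_grad differentiates through lstm_cell_split, which is not shown in this snippet. For orientation, here is one plausible lstm_cell_split consistent with the gradient above: extras is assumed to be the pre-activation gate tensor (the matmul of [x; h] with theta["kernel"] plus the bias happens in the caller, which is why dstate["h"] is recovered manually via kernel^T), and the gate split order is an assumption. A sketch, not the original implementation:

import tensorflow.compat.v1 as tf

def lstm_cell_split(gates, c0, padding):
    # Hypothetical forward half: `gates` is the pre-activation tensor,
    # assumed to be matmul([x; h], kernel) + bias computed by the caller.
    i, f, g, o = tf.split(gates, 4, axis=-1)  # gate order is an assumption
    c1 = tf.sigmoid(f) * c0 + tf.sigmoid(i) * tf.tanh(g)
    h1 = tf.sigmoid(o) * tf.tanh(c1)
    if padding is not None:
        # Carry the old cell state through padded timesteps.
        c1 = c1 * (1.0 - padding) + c0 * padding
        h1 = h1 * (1.0 - padding)
    return {"c": c1, "h": h1}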
Code example #3
def attention_cell_grad(theta, state0, unused_inputs, extras, dstate1):
  """Gradient function for attention_cell."""
  new_lstm_state = lstm_cell_split(extras, state0["c"], None)
  new_states = attention(theta, new_lstm_state)
  del new_states["alignments"]

  y = nest.flatten(new_states)
  x = [extras, state0["c"]] + nest.flatten(theta)
  dy = nest.flatten(dstate1)
  g = tf.gradients(y, x, dy)
  # tf.gradients returns one gradient per entry of x, so the first two
  # entries are d(extras) and d(state0["c"]); the rest map back to theta.
  dtheta = nest.pack_sequence_as(theta, g[2:])
  grad, dstate_c = g[:2]

  dtheta["bias"] = tf.reduce_sum(grad, 0)

  datten = tf.matmul(grad, tf.transpose(theta["attention_kernel"]))
  dstate_h = tf.matmul(grad, tf.transpose(theta["kernel"]))

  dstate = {
      "h": dstate_h,
      "c": dstate_c,
      "attention": datten,
  }
  return dtheta, dstate, {"rnn": grad}
Code example #4
                def enqueue_ops_fn():
                    """Enqueue ops function for one host."""
                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(self.hparams.num_shards_per_host):
                        # Chain control dependencies so each shard's
                        # get_next runs after the previous shard's tensors,
                        # giving a deterministic read order.
                        with tf.control_dependencies(control_deps):
                            features = iterator.get_next()
                        self.eval_feature_structure["features"] = features
                        flattened_inputs = data_nest.flatten(
                            self.eval_feature_structure)
                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

                    infeed = tpu_feed.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    self.eval_infeed_queue.append(infeed)

                    def tpu_ordinal_fn(shard_index_in_host):
                        return shard_index_in_host % self.hparams.num_shards_per_host

                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=tpu_ordinal_fn)
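
The control_deps list chains each shard's get_next behind all previously flattened tensors, so the per-shard reads happen in a fixed order within a single session step. A minimal sketch of that serialization trick with a toy dataset (shapes and shard count are illustrative):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

dataset = tf.data.Dataset.range(8).batch(2)
iterator = tf.data.make_one_shot_iterator(dataset)

batches = []
control_deps = []
for _ in range(3):  # e.g. three shards on this host
    with tf.control_dependencies(control_deps):
        batch = iterator.get_next()
    control_deps.append(batch)
    batches.append(batch)

with tf.Session() as sess:
    print(sess.run(batches))  # [array([0, 1]), array([2, 3]), array([4, 5])]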
Code example #5
def build_rnn(orig_theta,
              state0,
              orig_inputs,
              cell_fn,
              cell_grad,
              max_length,
              reverse=False):
  """Helper function to build an RNN."""
  max_time, batch_size = orig_inputs["rnn"].shape.as_list()[:2]
  skipped_theta = ["kernel", "attention_kernel", "memory_kernel", "seq_mask"]
  skipped_state = ["alignments"]

  @tf.custom_gradient
  def _rnn(*inp):
    """Function that drives RNN with early stop."""
    inputs = nest.pack_sequence_as(orig_inputs, inp[0:len(orig_inputs)])
    theta = nest.pack_sequence_as(orig_theta, inp[len(orig_inputs):])

    def _cell_fn(theta, state0, acc_state, acc_gate, i):
      """RNN cell function."""
      input_slice = {k: tf.gather(inputs[k], i) for k in inputs}
      state1, gate = cell_fn(theta, state0, input_slice)
      for k in state0:
        if k not in skipped_state:
          # Checkpoint each state into a preallocated buffer; stop_gradient
          # keeps autodiff from tracing through the accumulators.
          acc_state[k] = tf.stop_gradient(
              inplace_ops.alias_inplace_update(acc_state[k], i, state1[k]))
      acc_gate = tf.stop_gradient(
          inplace_ops.alias_inplace_update(acc_gate, i, gate))
      return theta, state1, acc_state, acc_gate, i - 1 if reverse else i + 1

    def _should_continue(i, is_backward=False):
      if is_backward:
        return i < max_length - 1 if reverse else i > 0
      else:
        return i >= 0 if reverse else i < max_length

    acc_state = {
        k: tf.zeros([max_time, batch_size, state0["c"].shape[-1]],
                    state0["c"].dtype) for k in state0 if k not in skipped_state
    }
    acc_state, acc_gate = tf.while_loop(
        lambda theta, state0, acc_state, acc_gate, i: _should_continue(i),
        _cell_fn, [
            theta, state0, acc_state,
            tf.zeros_like(inputs["rnn"]),
            max_length - 1 if reverse else tf.zeros([], tf.int32)
        ])[2:4]
    ret = {"h": acc_state["h"]}
    if "attention" in acc_state:
      ret["attention"] = acc_state["attention"]

    def _cell_grad_fn_with_state0(state0, theta, dy, dstate1, dtheta, dinput,
                                  i):
      """Gradient cell function."""
      state0 = {
          k: tf.stop_gradient(state0[k])
          for k in state0
          if k not in skipped_state
      }
      theta = {k: tf.stop_gradient(theta[k]) for k in theta}
      if "padding" in inputs:
        inputs_slice = {"padding": tf.gather(inputs["padding"], i)}
      else:
        inputs_slice = None
      gate = tf.gather(acc_gate, i)
      for k in dy:
        dstate1[k] = dstate1[k] + tf.gather(dy[k], i)
      dt, dstate, di = cell_grad(theta, state0, inputs_slice, gate, dstate1)
      dtheta = {k: dtheta[k] + dt[k] for k in dtheta if k not in skipped_theta}
      dinput = {
          k: inplace_ops.alias_inplace_update(dinput[k], i, di[k]) for k in di
      }
      return theta, dy, dstate, dtheta, dinput, i + 1 if reverse else i - 1

    def _cell_grad_fn(theta, dy, dstate1, dtheta, dinput, i):
      """Gradient cell function wrapper."""
      # The state fed into step i was produced at the neighboring step
      # (i + 1 when running reversed, i - 1 otherwise).
      return _cell_grad_fn_with_state0(
          {
              k: tf.gather(acc_state[k], i + 1 if reverse else i - 1)
              for k in acc_state
          }, theta, dy, dstate1, dtheta, dinput, i)

    def grad(*dy):
      """Gradient function for build_rnn."""
      dy = nest.pack_sequence_as(ret, dy)

      def _continue(unused_theta, unused_dy, unused_dstate1, unused_dtheta,
                    unused_dinput, i):
        return _should_continue(i, True)

      dstate1, dtheta, dinput = tf.while_loop(_continue, _cell_grad_fn, [
          theta,
          dy,
          {
              k: tf.zeros_like(state0[k])
              for k in state0
              if k not in skipped_state
          },
          {k: tf.zeros_like(theta[k]) for k in theta if k not in skipped_theta},
          {k: tf.zeros_like(inputs[k]) for k in inputs},
          tf.zeros([], tf.int32) if reverse else max_length - 1,
      ])[2:5]
      # The boundary step consumed the true initial state0, which is not in
      # acc_state, so it is handled once outside the loop.
      dtheta, dinput = _cell_grad_fn_with_state0(
          state0, theta, dy, dstate1, dtheta, dinput,
          max_length - 1 if reverse else tf.zeros([], dtype=tf.int32))[3:5]
      state0_h = tf.reshape(acc_state["h"], [-1, theta["kernel"].shape[0]])
      state0_atten = tf.reshape(acc_state["attention"],
                                [-1, theta["attention_kernel"].shape[0]
                                ]) if "attention_kernel" in theta else None
      grad = tf.reshape(dinput["rnn"], [-1, theta["kernel"].shape[1]])
      if reverse:
        state0_h = tf.split(state0_h, [batch_size, -1])[1]
        grad = tf.split(grad, [-1, batch_size])[0]
      else:
        if state0_atten is not None:
          state0_atten = tf.split(state0_atten, [-1, batch_size])[0]
        state0_h = tf.split(state0_h, [-1, batch_size])[0]
        grad = tf.split(grad, [batch_size, -1])[1]

      if state0_atten is not None:
        dtheta["attention_kernel"] = tf.matmul(tf.transpose(state0_atten), grad)
      dtheta["kernel"] = tf.matmul(tf.transpose(state0_h), grad)

      if "memory_kernel" in orig_theta:
        dtheta["memory_kernel"] = tf.zeros_like(orig_theta["memory_kernel"])
        dtheta["seq_mask"] = tf.zeros_like(orig_theta["seq_mask"])
      return dinput, dtheta

    return ret, grad

  return dict(
      _rnn(*(tuple(nest.flatten(orig_inputs)) +
             tuple(nest.flatten(orig_theta)))))
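
build_rnn pairs a forward tf.while_loop that checkpoints states into preallocated buffers with a hand-written backward tf.while_loop inside a tf.custom_gradient. The following is a toy, self-contained version of that skeleton: a cumulative-sum "RNN" whose backward pass is its own while_loop, using the public TensorArray API instead of the internal inplace_ops. A sketch of the pattern, not the original code:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

@tf.custom_gradient
def cumsum_rnn(x):
    # Running sum over time as a while_loop "RNN" with a hand-rolled
    # backward while_loop, mirroring the skeleton of build_rnn.
    max_time = x.shape.as_list()[0]

    def fwd_body(i, state, ta):
        state = state + x[i]
        return i + 1, state, ta.write(i, state)  # checkpoint the state

    _, _, ta = tf.while_loop(
        lambda i, state, ta: i < max_time, fwd_body,
        [0, tf.zeros_like(x[0]), tf.TensorArray(x.dtype, size=max_time)])
    y = ta.stack()

    def grad(dy):
        def bwd_body(i, carry, ta):
            carry = carry + dy[i]  # accumulate, like dstate1 above
            return i - 1, carry, ta.write(i, carry)

        _, _, dta = tf.while_loop(
            lambda i, carry, ta: i >= 0, bwd_body,
            [max_time - 1, tf.zeros_like(dy[0]),
             tf.TensorArray(dy.dtype, size=max_time)])
        return dta.stack()

    return y, grad

x = tf.constant([1.0, 2.0, 3.0])
y = cumsum_rnn(x)
dx = tf.gradients(tf.reduce_sum(y), x)[0]
with tf.Session() as sess:
    print(sess.run([y, dx]))  # [1. 3. 6.] and [3. 2. 1.]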