def enqueue_ops_fn(idx):
  """Enqueue ops function for one host."""
  with tf.device(device):
    sharded_inputs = []
    start_idx = 0
    # Only even-numbered hosts (one per infeed worker) enqueue data on their
    # first core; every other core receives zero tensors of the same shape.
    if host_id in range(0, self.hparams.num_infeed_workers * 2, 2):
      core_id = tf.constant(
          host_id * self.hparams.num_shards_per_host,
          shape=[1],
          dtype=tf.int32)
      if self.hparams.use_synthetic_data:
        features = output
      else:

        def true_fn():
          return iterator.get_next()

        def false_fn():
          return {
              k: tf.zeros_like(self.feature_structure["features"][k])
              for k in self.feature_structure["features"]
          }

        # Pull real data only on the step assigned to this infeed worker;
        # otherwise feed zeros.
        features = tf.cond(
            tf.equal(idx % self.hparams.num_infeed_workers, host_id // 2),
            true_fn, false_fn)
      sharded_inputs.append(
          data_nest.flatten({
              "features": features,
              "core_id": core_id
          }))
      start_idx = 1
    # Remaining cores on this host are padded with zero features.
    for i in range(start_idx, self.hparams.num_shards_per_host):
      sharded_inputs.append(
          data_nest.flatten({
              "features": {
                  k: tf.zeros_like(self.feature_structure["features"][k])
                  for k in self.feature_structure["features"]
              },
              "core_id":
                  tf.constant(
                      host_id * self.hparams.num_shards_per_host + i,
                      shape=[1],
                      dtype=tf.int32)
          }))
    infeed = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(sharded_inputs[0]))
    self.infeed_queue.append(infeed)

    def tpu_ordinal_fn(shard_index_in_host):
      return shard_index_in_host % self.hparams.num_shards_per_host

    return infeed.generate_enqueue_ops(
        sharded_inputs, tpu_ordinal_function=tpu_ordinal_fn)
def lstm_cell_grad(theta, state0, inputs, extras, dstate1):
  """Gradient function for lstm_cell."""
  padding = inputs["padding"] if (inputs is not None and
                                  "padding" in inputs) else None
  state1 = nest.flatten(lstm_cell_split(extras, state0["c"], padding))
  dstate1 = nest.flatten(dstate1)
  grad = tf.gradients(state1, [extras], dstate1)[0]
  dtheta = {"bias": tf.reduce_sum(grad, 0)}
  dinputs = {"rnn": grad}
  dstate = {"c": tf.gradients(state1, state0["c"], dstate1)[0]}
  dstate["h"] = tf.matmul(grad, tf.transpose(theta["kernel"]))
  if padding is not None:
    dinputs["padding"] = padding
  return dtheta, dstate, dinputs
def attention_cell_grad(theta, state0, unused_inputs, extras, dstate1):
  """Gradient function for attention_cell."""
  new_lstm_state = lstm_cell_split(extras, state0["c"], None)
  new_states = attention(theta, new_lstm_state)
  del new_states["alignments"]
  y = nest.flatten(new_states)
  x = [extras, state0["c"]] + nest.flatten(theta)
  dy = nest.flatten(dstate1)
  g = tf.gradients(y, x, dy)
  dtheta = nest.pack_sequence_as(theta, g[2:])
  grad, dstate_c = g[:2]
  dtheta["bias"] = tf.reduce_sum(grad, 0)
  datten = tf.matmul(grad, tf.transpose(theta["attention_kernel"]))
  dstate_h = tf.matmul(grad, tf.transpose(theta["kernel"]))
  dstate = {
      "h": dstate_h,
      "c": dstate_c,
      "attention": datten,
  }
  return dtheta, dstate, {"rnn": grad}
def enqueue_ops_fn():
  """Enqueue ops function for one host."""
  per_host_sharded_inputs = []
  control_deps = []
  for _ in range(self.hparams.num_shards_per_host):
    with tf.control_dependencies(control_deps):
      features = iterator.get_next()
    self.eval_feature_structure["features"] = features
    flattened_inputs = data_nest.flatten(self.eval_feature_structure)
    control_deps.extend(flattened_inputs)
    per_host_sharded_inputs.append(flattened_inputs)

  infeed = tpu_feed.InfeedQueue(
      number_of_tuple_elements=len(per_host_sharded_inputs[0]))
  self.eval_infeed_queue.append(infeed)

  def tpu_ordinal_fn(shard_index_in_host):
    return shard_index_in_host % self.hparams.num_shards_per_host

  return infeed.generate_enqueue_ops(
      per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_fn)
def build_rnn(orig_theta,
              state0,
              orig_inputs,
              cell_fn,
              cell_grad,
              max_length,
              reverse=False):
  """Helper function to build an RNN."""
  max_time, batch_size = orig_inputs["rnn"].shape.as_list()[:2]
  skipped_theta = ["kernel", "attention_kernel", "memory_kernel", "seq_mask"]
  skipped_state = ["alignments"]

  @tf.custom_gradient
  def _rnn(*inp):
    """Function that drives RNN with early stop."""
    inputs = nest.pack_sequence_as(orig_inputs, inp[0:len(orig_inputs)])
    theta = nest.pack_sequence_as(orig_theta, inp[len(orig_inputs):])

    def _cell_fn(theta, state0, acc_state, acc_gate, i):
      """RNN cell function."""
      input_slice = {k: tf.gather(inputs[k], i) for k in inputs}
      state1, gate = cell_fn(theta, state0, input_slice)
      # Write this step's state and gate values into time-major accumulators
      # so the backward pass can replay them without re-running the cell.
      for k in state0:
        if k not in skipped_state:
          acc_state[k] = tf.stop_gradient(
              inplace_ops.alias_inplace_update(acc_state[k], i, state1[k]))
      acc_gate = tf.stop_gradient(
          inplace_ops.alias_inplace_update(acc_gate, i, gate))
      return theta, state1, acc_state, acc_gate, i - 1 if reverse else i + 1

    def _should_continue(i, is_backward=False):
      if is_backward:
        return i < max_length - 1 if reverse else i > 0
      else:
        return i >= 0 if reverse else i < max_length

    acc_state = {
        k: tf.zeros([max_time, batch_size, state0["c"].shape[-1]],
                    state0["c"].dtype)
        for k in state0
        if k not in skipped_state
    }
    acc_state, acc_gate = tf.while_loop(
        lambda theta, state0, acc_state, acc_gate, i: _should_continue(i),
        _cell_fn, [
            theta, state0, acc_state,
            tf.zeros_like(inputs["rnn"]),
            max_length - 1 if reverse else tf.zeros([], tf.int32)
        ])[2:4]
    ret = {"h": acc_state["h"]}
    if "attention" in acc_state:
      ret["attention"] = acc_state["attention"]

    def _cell_grad_fn_with_state0(state0, theta, dy, dstate1, dtheta, dinput,
                                  i):
      """Gradient cell function."""
      state0 = {
          k: tf.stop_gradient(state0[k])
          for k in state0
          if k not in skipped_state
      }
      theta = {k: tf.stop_gradient(theta[k]) for k in theta}
      if "padding" in inputs:
        inputs_slice = {"padding": tf.gather(inputs["padding"], i)}
      else:
        inputs_slice = None
      gate = tf.gather(acc_gate, i)
      for k in dy:
        dstate1[k] = dstate1[k] + tf.gather(dy[k], i)
      dt, dstate, di = cell_grad(theta, state0, inputs_slice, gate, dstate1)
      dtheta = {k: dtheta[k] + dt[k] for k in dtheta if k not in skipped_theta}
      dinput = {
          k: inplace_ops.alias_inplace_update(dinput[k], i, di[k]) for k in di
      }
      return theta, dy, dstate, dtheta, dinput, i + 1 if reverse else i - 1

    def _cell_grad_fn(theta, dy, dstate1, dtheta, dinput, i):
      """Gradient cell function wrapper."""
      return _cell_grad_fn_with_state0(
          {
              k: tf.gather(acc_state[k], i + 1 if reverse else i - 1)
              for k in acc_state
          }, theta, dy, dstate1, dtheta, dinput, i)

    def grad(*dy):
      """Gradient function for build_rnn."""
      dy = nest.pack_sequence_as(ret, dy)

      def _continue(unused_theta, unused_dy, unused_dstate1, unused_dtheta,
                    unused_dinput, i):
        return _should_continue(i, True)

      # Walk the time steps in the opposite order of the forward pass,
      # accumulating parameter and per-step input gradients.
      dstate1, dtheta, dinput = tf.while_loop(_continue, _cell_grad_fn, [
          theta, dy, {
              k: tf.zeros_like(state0[k])
              for k in state0
              if k not in skipped_state
          },
          {k: tf.zeros_like(theta[k]) for k in theta if k not in skipped_theta},
          {k: tf.zeros_like(inputs[k]) for k in inputs},
          tf.zeros([], tf.int32) if reverse else max_length - 1,
      ])[2:5]
      # Handle the boundary step outside the loop, where the previous state is
      # the original initial state rather than an accumulated one.
      dtheta, dinput = _cell_grad_fn_with_state0(
          state0, theta, dy, dstate1, dtheta, dinput,
          max_length - 1 if reverse else tf.zeros([], dtype=tf.int32))[3:5]
      # The kernel gradients are computed with a single matmul over all time
      # steps instead of being accumulated inside the loop.
      state0_h = tf.reshape(acc_state["h"], [-1, theta["kernel"].shape[0]])
      state0_atten = tf.reshape(
          acc_state["attention"],
          [-1, theta["attention_kernel"].shape[0]
          ]) if "attention_kernel" in theta else None
      grad = tf.reshape(dinput["rnn"], [-1, theta["kernel"].shape[1]])
      if reverse:
        state0_h = tf.split(state0_h, [batch_size, -1])[1]
        grad = tf.split(grad, [-1, batch_size])[0]
      else:
        if state0_atten is not None:
          state0_atten = tf.split(state0_atten, [-1, batch_size])[0]
        state0_h = tf.split(state0_h, [-1, batch_size])[0]
        grad = tf.split(grad, [batch_size, -1])[1]
      if state0_atten is not None:
        dtheta["attention_kernel"] = tf.matmul(
            tf.transpose(state0_atten), grad)
      dtheta["kernel"] = tf.matmul(tf.transpose(state0_h), grad)
      if "memory_kernel" in orig_theta:
        dtheta["memory_kernel"] = tf.zeros_like(orig_theta["memory_kernel"])
        dtheta["seq_mask"] = tf.zeros_like(orig_theta["seq_mask"])
      return dinput, dtheta

    return ret, grad

  return dict(
      _rnn(*(tuple(nest.flatten(orig_inputs)) +
             tuple(nest.flatten(orig_theta)))))
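# --- Usage sketch (illustrative only, not part of this module) --------------
# A minimal sketch of how a forward cell and its paired gradient function
# might be handed to build_rnn. It assumes an `lstm_cell` forward function
# returning (state1, gate), as build_rnn's `_cell_fn` expects, and a shape
# layout implied by lstm_cell_grad (kernel: [hidden, 4 * hidden], with
# inputs["rnn"] pre-projected to the gate width). Everything below except
# build_rnn and lstm_cell_grad is a hypothetical placeholder.
max_time, batch, hidden = 10, 4, 32
example_theta = {
    "kernel": tf.zeros([hidden, 4 * hidden]),
    "bias": tf.zeros([4 * hidden]),
}
example_state0 = {
    "c": tf.zeros([batch, hidden]),
    "h": tf.zeros([batch, hidden]),
}
example_inputs = {
    "rnn": tf.zeros([max_time, batch, 4 * hidden]),
    "padding": tf.zeros([max_time, batch]),
}
example_outputs = build_rnn(
    example_theta,
    example_state0,
    example_inputs,
    cell_fn=lstm_cell,        # assumed forward counterpart of lstm_cell_grad
    cell_grad=lstm_cell_grad,
    max_length=max_time,
    reverse=False)
example_hidden = example_outputs["h"]  # [max_time, batch, hidden] states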