def __call__(self, inputs, state, scope=None):
    with tf.compat.v1.variable_scope(scope or type(self).__name__):
        update_prob_prev, cum_update_prob_prev = state[-1].update_prob, state[-1].cum_update_prob
        cell_input = inputs
        state_candidates = []

        # Compute update candidates for all layers
        for idx in range(self._num_layers):
            with tf.compat.v1.variable_scope('layer_%d' % (idx + 1)):
                if isinstance(state[idx], SkipGRUStateTuple):
                    h_prev = state[idx].h
                else:
                    h_prev = state[idx]

                # Parameters of gates are concatenated into one multiply for efficiency.
                with tf.compat.v1.variable_scope("gates"):
                    concat = rnn_ops.linear([cell_input, h_prev], 2 * self._num_units[idx],
                                            bias=True, bias_start=1.0)

                # r = reset_gate, u = update_gate
                r, u = tf.split(value=concat, num_or_size_splits=2, axis=1)

                if self._layer_norm:
                    r = rnn_ops.layer_norm(r, name="r")
                    u = rnn_ops.layer_norm(u, name="u")

                # Apply non-linearity after layer normalization
                r = tf.sigmoid(r)
                u = tf.sigmoid(u)

                # The candidate is driven by cell_input (the output of the layer
                # below), consistent with the gate computation above.
                with tf.compat.v1.variable_scope("candidate"):
                    new_c_tilde = self._activation(
                        rnn_ops.linear([cell_input, r * h_prev], self._num_units[idx], True))
                new_h_tilde = u * h_prev + (1 - u) * new_c_tilde

                state_candidates.append(new_h_tilde)
                cell_input = new_h_tilde

        # Compute value for the update prob
        with tf.compat.v1.variable_scope('state_update_prob'):
            new_update_prob_tilde = rnn_ops.linear(state_candidates[-1], 1, True,
                                                   bias_start=self._update_bias)
            new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde)

        # Compute value for the update gate
        cum_update_prob = cum_update_prob_prev + tf.minimum(update_prob_prev, 1. - cum_update_prob_prev)
        update_gate = _binary_round(cum_update_prob)

        # Apply update gate
        new_states = []
        for idx in range(self._num_layers - 1):
            new_h = update_gate * state_candidates[idx] + (1. - update_gate) * state[idx]
            new_states.append(new_h)
        new_h = update_gate * state_candidates[-1] + (1. - update_gate) * state[-1].h
        new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev
        new_cum_update_prob = update_gate * 0. + (1. - update_gate) * cum_update_prob
        new_states.append(SkipGRUStateTuple(new_h, new_update_prob, new_cum_update_prob))

        new_output = SkipGRUOutputTuple(new_h, update_gate)

        return new_output, new_states
def __call__(self, inputs, state, scope=None):
    with tf.compat.v1.variable_scope(scope or type(self).__name__):
        update_prob_prev, cum_update_prob_prev = state[-1].update_prob, state[-1].cum_update_prob
        cell_input = inputs
        state_candidates = []

        # Compute update candidates for all layers
        for idx in range(self._num_layers):
            with tf.compat.v1.variable_scope('layer_%d' % (idx + 1)):
                c_prev, h_prev = state[idx].c, state[idx].h

                # Parameters of gates are concatenated into one multiply for efficiency.
                concat = rnn_ops.linear([cell_input, h_prev], 4 * self._num_units[idx], True)

                # i = input_gate, j = new_input, f = forget_gate, o = output_gate
                i, j, f, o = tf.split(value=concat, num_or_size_splits=4, axis=1)

                if self._layer_norm:
                    i = rnn_ops.layer_norm(i, name="i")
                    j = rnn_ops.layer_norm(j, name="j")
                    f = rnn_ops.layer_norm(f, name="f")
                    o = rnn_ops.layer_norm(o, name="o")

                new_c_tilde = (c_prev * tf.sigmoid(f + self._forget_bias) +
                               tf.sigmoid(i) * self._activation(j))
                new_h_tilde = self._activation(new_c_tilde) * tf.sigmoid(o)

                state_candidates.append(LSTMStateTuple(new_c_tilde, new_h_tilde))
                cell_input = new_h_tilde

        # Compute value for the update prob
        with tf.compat.v1.variable_scope('state_update_prob'):
            new_update_prob_tilde = rnn_ops.linear(state_candidates[-1].c, 1, True,
                                                   bias_start=self._update_bias)
            new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde)

        # Compute value for the update gate
        cum_update_prob = cum_update_prob_prev + tf.minimum(update_prob_prev, 1. - cum_update_prob_prev)
        update_gate = _binary_round(cum_update_prob)

        # Apply update gate
        new_states = []
        for idx in range(self._num_layers - 1):
            new_c = update_gate * state_candidates[idx].c + (1. - update_gate) * state[idx].c
            new_h = update_gate * state_candidates[idx].h + (1. - update_gate) * state[idx].h
            new_states.append(LSTMStateTuple(new_c, new_h))
        new_c = update_gate * state_candidates[-1].c + (1. - update_gate) * state[-1].c
        new_h = update_gate * state_candidates[-1].h + (1. - update_gate) * state[-1].h
        new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev
        new_cum_update_prob = update_gate * 0. + (1. - update_gate) * cum_update_prob
        new_states.append(SkipLSTMStateTuple(new_c, new_h, new_update_prob, new_cum_update_prob))

        new_output = SkipLSTMOutputTuple(new_h, update_gate)

        return new_output, new_states
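# All cells in this section call rnn_ops.linear(args, output_size, bias, bias_start=...)
# to project the concatenated [input, previous hidden state] onto the gate
# pre-activations. That helper is not shown here; the sketch below is a minimal,
# assumed implementation (variable names, scoping and initializer choices are
# illustrative, not necessarily the repository's exact code).
import tensorflow as tf

def linear(args, output_size, bias, bias_start=0.0):
    """Apply one affine map W [x; h] + b to a list of 2-D tensors."""
    if not isinstance(args, (list, tuple)):
        args = [args]
    # Total width of the concatenated inputs along the feature axis.
    total_input_size = sum(a.shape.as_list()[-1] for a in args)
    weights = tf.compat.v1.get_variable("weights", [total_input_size, output_size])
    output = tf.matmul(tf.concat(args, axis=1), weights)
    if not bias:
        return output
    biases = tf.compat.v1.get_variable(
        "biases", [output_size],
        initializer=tf.compat.v1.constant_initializer(bias_start))
    return output + biases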
def __call__(self, inputs, state, scope=None):
    with tf.variable_scope(scope or type(self).__name__):
        h_prev, update_prob_prev, cum_update_prob_prev = state

        # Parameters of gates are concatenated into one multiply for efficiency.
        with tf.variable_scope("gates"):
            concat = rnn_ops.linear([inputs, h_prev], 2 * self._num_units, bias=True, bias_start=1.0)

        # r = reset_gate, u = update_gate
        r, u = tf.split(value=concat, num_or_size_splits=2, axis=1)

        if self._layer_norm:
            r = rnn_ops.layer_norm(r, name="r")
            u = rnn_ops.layer_norm(u, name="u")

        # Apply non-linearity after layer normalization
        r = tf.sigmoid(r)
        u = tf.sigmoid(u)

        with tf.variable_scope("candidate"):
            new_c_tilde = self._activation(rnn_ops.linear([inputs, r * h_prev], self._num_units, True))
        new_h_tilde = u * h_prev + (1 - u) * new_c_tilde

        # Compute value for the update prob
        with tf.variable_scope('state_update_prob'):
            new_update_prob_tilde = rnn_ops.linear(new_h_tilde, 1, True, bias_start=self._update_bias)
            new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde)

        # Compute value for the update gate
        cum_update_prob = cum_update_prob_prev + tf.minimum(update_prob_prev, 1. - cum_update_prob_prev)
        update_gate = _binary_round(cum_update_prob)

        # Apply update gate
        new_h = update_gate * new_h_tilde + (1. - update_gate) * h_prev
        new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev
        new_cum_update_prob = update_gate * 0. + (1. - update_gate) * cum_update_prob

        new_state = SkipGRUStateTuple(new_h, new_update_prob, new_cum_update_prob)
        new_output = SkipGRUOutputTuple(new_h, update_gate)

        return new_output, new_state
def __call__(self, inputs, state, scope=None):
    with tf.variable_scope(scope or type(self).__name__):
        c_prev, h_prev, update_prob_prev, cum_update_prob_prev = state

        # Parameters of gates are concatenated into one multiply for efficiency.
        concat = rnn_ops.linear([inputs, h_prev], 4 * self._num_units, True)

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = tf.split(value=concat, num_or_size_splits=4, axis=1)

        if self._layer_norm:
            i = rnn_ops.layer_norm(i, name="i")
            j = rnn_ops.layer_norm(j, name="j")
            f = rnn_ops.layer_norm(f, name="f")
            o = rnn_ops.layer_norm(o, name="o")

        new_c_tilde = (c_prev * tf.sigmoid(f + self._forget_bias) +
                       tf.sigmoid(i) * self._activation(j))
        new_h_tilde = self._activation(new_c_tilde) * tf.sigmoid(o)

        # Compute value for the update prob
        with tf.variable_scope('state_update_prob'):
            new_update_prob_tilde = rnn_ops.linear(new_c_tilde, 1, True, bias_start=self._update_bias)
            new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde)

        # Compute value for the update gate
        cum_update_prob = cum_update_prob_prev + tf.minimum(update_prob_prev, 1. - cum_update_prob_prev)
        update_gate = _binary_round(cum_update_prob)

        # Apply update gate
        new_c = update_gate * new_c_tilde + (1. - update_gate) * c_prev
        new_h = update_gate * new_h_tilde + (1. - update_gate) * h_prev
        new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev
        new_cum_update_prob = update_gate * 0. + (1. - update_gate) * cum_update_prob

        new_state = SkipLSTMStateTuple(new_c, new_h, new_update_prob, new_cum_update_prob)
        new_output = SkipLSTMOutputTuple(new_h, update_gate)

        return new_output, new_state
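# Every skip cell above binarizes cum_update_prob with _binary_round, which is not
# shown in this section. A minimal sketch of such a helper, assuming the
# straight-through estimator used by Skip RNNs: the forward pass rounds the
# cumulative update probability to {0, 1}, while the backward pass treats the
# rounding as the identity so gradients can still reach the update probability.
import tensorflow as tf

def _binary_round(x):
    """Round values in [0, 1] to {0., 1.} with straight-through gradients."""
    graph = tf.compat.v1.get_default_graph()
    with tf.name_scope("BinaryRound"):
        # Reuse the Identity gradient for the Round op, so d(round(x))/dx == 1.
        with graph.gradient_override_map({"Round": "Identity"}):
            return tf.round(x, name="binary_round")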
def __call__(self, inputs, state, scope=None): """Gated recurrent unit (GRU) with num_units cells.""" with tf.variable_scope(scope or type(self).__name__): with tf.variable_scope("gates"): # Reset gate and update gate. # We start with bias of 1.0 to not reset and not update. concat = rnn_ops.linear([inputs, state], 2 * self._num_units, True, bias_start=1.0) r, u = tf.split(value=concat, num_or_size_splits=2, axis=1) if self._layer_norm: r = rnn_ops.layer_norm(r, name="r") u = rnn_ops.layer_norm(u, name="u") # Apply non-linearity after layer normalization r = tf.sigmoid(r) u = tf.sigmoid(u) with tf.variable_scope("candidate"): c = self._activation( rnn_ops.linear([inputs, r * state], self._num_units, True)) new_h = u * state + (1 - u) * c return new_h, new_h
def __call__(self, inputs, state, scope=None): """Long short-term memory cell (LSTM).""" with tf.variable_scope(scope or type(self).__name__): c, h = state # Parameters of gates are concatenated into one multiply for efficiency. concat = rnn_ops.linear([inputs, h], 4 * self._num_units, True) # i = input_gate, j = new_input, f = forget_gate, o = output_gate i, j, f, o = tf.split(value=concat, num_or_size_splits=4, axis=1) if self._layer_norm: i = rnn_ops.layer_norm(i, name="i") j = rnn_ops.layer_norm(j, name="j") f = rnn_ops.layer_norm(f, name="f") o = rnn_ops.layer_norm(o, name="o") new_c = (c * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * self._activation(j)) new_h = self._activation(new_c) * tf.sigmoid(o) new_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h) return new_h, new_state