Example #1
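Shared context for all of the examples: the cells lean on helpers defined elsewhere in the module. rnn_ops.linear is assumed to be the usual TF1-style affine helper (concatenate its inputs along the feature axis, apply a single weight matrix, add a bias initialized to bias_start), and rnn_ops.layer_norm a standard layer-normalization wrapper. The remaining names can be reconstructed from how the methods use them; the following is a minimal sketch, not the reference implementation:

import collections

import tensorflow as tf

# Plain LSTM state tuple, used by the non-skip LSTM cell below.
LSTMStateTuple = tf.compat.v1.nn.rnn_cell.LSTMStateTuple

# Skip RNN state/output tuples. Field names are inferred from the
# attribute accesses in the examples (.c, .h, .update_prob,
# .cum_update_prob); the output tuples are only ever built positionally
# as (h, update_gate).
SkipGRUStateTuple = collections.namedtuple(
    "SkipGRUStateTuple", ("h", "update_prob", "cum_update_prob"))
SkipGRUOutputTuple = collections.namedtuple(
    "SkipGRUOutputTuple", ("h", "state_gate"))
SkipLSTMStateTuple = collections.namedtuple(
    "SkipLSTMStateTuple", ("c", "h", "update_prob", "cum_update_prob"))
SkipLSTMOutputTuple = collections.namedtuple(
    "SkipLSTMOutputTuple", ("h", "state_gate"))

def _binary_round(x):
    """Round x to {0, 1} in the forward pass while letting gradients pass
    through unchanged (straight-through estimator), which keeps the binary
    update gate trainable. This is the standard trick used by Skip RNN
    implementations."""
    g = tf.compat.v1.get_default_graph()
    with tf.compat.v1.name_scope("BinaryRound"):
        with g.gradient_override_map({"Round": "Identity"}):
            return tf.round(x, name="BinaryRound")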
    def __call__(self, inputs, state, scope=None):
        with tf.compat.v1.variable_scope(scope or type(self).__name__):
            update_prob_prev, cum_update_prob_prev = state[-1].update_prob, state[-1].cum_update_prob
            cell_input = inputs
            state_candidates = []

            # Compute update candidates for all layers
            for idx in range(self._num_layers):
                with tf.compat.v1.variable_scope('layer_%d' % (idx + 1)):
                    if isinstance(state[idx], SkipGRUStateTuple):
                        h_prev = state[idx].h
                    else:
                        h_prev = state[idx]

                    # Parameters of gates are concatenated into one multiply for efficiency.
                    with tf.compat.v1.variable_scope("gates"):
                        concat = rnn_ops.linear([cell_input, h_prev], 2 * self._num_units[idx], bias=True, bias_start=1.0)

                    # r = reset_gate, u = update_gate
                    r, u = tf.split(value=concat, num_or_size_splits=2, axis=1)

                    if self._layer_norm:
                        r = rnn_ops.layer_norm(r, name="r")
                        u = rnn_ops.layer_norm(u, name="u")

                    # Apply non-linearity after layer normalization
                    r = tf.sigmoid(r)
                    u = tf.sigmoid(u)

                    with tf.compat.v1.variable_scope("candidate"):
                        new_c_tilde = self._activation(rnn_ops.linear([cell_input, r * h_prev], self._num_units[idx], True))
                    new_h_tilde = u * h_prev + (1 - u) * new_c_tilde

                    state_candidates.append(new_h_tilde)
                    cell_input = new_h_tilde

            # Compute value for the update prob
            with tf.compat.v1.variable_scope('state_update_prob'):
                new_update_prob_tilde = rnn_ops.linear(state_candidates[-1], 1, True, bias_start=self._update_bias)
                new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde)

            # Compute value for the update gate
            cum_update_prob = cum_update_prob_prev + tf.minimum(update_prob_prev, 1. - cum_update_prob_prev)
            update_gate = _binary_round(cum_update_prob)

            # Apply update gate
            new_states = []
            for idx in range(self._num_layers - 1):
                new_h = update_gate * state_candidates[idx] + (1. - update_gate) * state[idx]
                new_states.append(new_h)
            new_h = update_gate * state_candidates[-1] + (1. - update_gate) * state[-1].h
            new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev
            new_cum_update_prob = update_gate * 0. + (1. - update_gate) * cum_update_prob

            new_states.append(SkipGRUStateTuple(new_h, new_update_prob, new_cum_update_prob))
            new_output = SkipGRUOutputTuple(new_h, update_gate)

            return new_output, new_states
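The skip mechanism at the end of the method is easiest to see in isolation: the cell accumulates the previously predicted update probability into cum_update_prob and performs a full state update whenever the accumulated value rounds to 1, resetting the accumulator afterwards. A tiny pure-Python trace of that logic (illustrative only, with a constant update probability):

update_prob, cum_update_prob = 0.3, 0.0
for t in range(6):
    cum_update_prob = cum_update_prob + min(update_prob, 1. - cum_update_prob)
    update_gate = float(round(cum_update_prob))  # stands in for _binary_round
    if update_gate == 1.:
        cum_update_prob = 0.  # same reset as new_cum_update_prob above
    print(t, update_gate)  # gate fires at t = 1, 3, 5

With update_prob fixed at 0.3 the gate fires on every second step, as soon as the accumulated probability exceeds 0.5; the larger the predicted probability, the more often the state is updated.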
Example #2
    def __call__(self, inputs, state, scope=None):
        with tf.compat.v1.variable_scope(scope or type(self).__name__):
            update_prob_prev, cum_update_prob_prev = state[-1].update_prob, state[-1].cum_update_prob
            cell_input = inputs
            state_candidates = []

            # Compute update candidates for all layers
            for idx in range(self._num_layers):
                with tf.compat.v1.variable_scope('layer_%d' % (idx + 1)):
                    c_prev, h_prev = state[idx].c, state[idx].h

                    # Parameters of gates are concatenated into one multiply for efficiency.
                    concat = rnn_ops.linear([cell_input, h_prev], 4 * self._num_units[idx], True)

                    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
                    i, j, f, o = tf.split(value=concat, num_or_size_splits=4, axis=1)

                    if self._layer_norm:
                        i = rnn_ops.layer_norm(i, name="i")
                        j = rnn_ops.layer_norm(j, name="j")
                        f = rnn_ops.layer_norm(f, name="f")
                        o = rnn_ops.layer_norm(o, name="o")

                    new_c_tilde = (c_prev * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * self._activation(j))
                    new_h_tilde = self._activation(new_c_tilde) * tf.sigmoid(o)

                    state_candidates.append(LSTMStateTuple(new_c_tilde, new_h_tilde))
                    cell_input = new_h_tilde

            # Compute value for the update prob
            with tf.compat.v1.variable_scope('state_update_prob'):
                new_update_prob_tilde = rnn_ops.linear(state_candidates[-1].c, 1, True, bias_start=self._update_bias)
                new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde)

            # Compute value for the update gate
            cum_update_prob = cum_update_prob_prev + tf.minimum(update_prob_prev, 1. - cum_update_prob_prev)
            update_gate = _binary_round(cum_update_prob)

            # Apply update gate
            new_states = []
            for idx in range(self._num_layers - 1):
                new_c = update_gate * state_candidates[idx].c + (1. - update_gate) * state[idx].c
                new_h = update_gate * state_candidates[idx].h + (1. - update_gate) * state[idx].h
                new_states.append(LSTMStateTuple(new_c, new_h))
            new_c = update_gate * state_candidates[-1].c + (1. - update_gate) * state[-1].c
            new_h = update_gate * state_candidates[-1].h + (1. - update_gate) * state[-1].h
            new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev
            new_cum_update_prob = update_gate * 0. + (1. - update_gate) * cum_update_prob

            new_states.append(SkipLSTMStateTuple(new_c, new_h, new_update_prob, new_cum_update_prob))
            new_output = SkipLSTMOutputTuple(new_h, update_gate)

            return new_output, new_states
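Driving one of these cells end to end requires the class around __call__, which is not shown here; the constructor name and signature below are therefore assumptions, and the cell is taken to follow the standard RNNCell contract (state_size, output_size, zero_state). A hypothetical driver:

tf.compat.v1.disable_eager_execution()

inputs = tf.compat.v1.placeholder(tf.float32, [None, None, 40])  # [batch, time, features]
cell = SkipLSTMCell(num_units=[128, 128])  # hypothetical constructor
outputs, final_state = tf.compat.v1.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
# outputs.h stacks the hidden states over time; outputs.state_gate is the
# per-step binary update gate, i.e. which steps were not skipped.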
Example #3
    def __call__(self, inputs, state, scope=None):
        with tf.compat.v1.variable_scope(scope or type(self).__name__):
            h_prev, update_prob_prev, cum_update_prob_prev = state

            # Parameters of gates are concatenated into one multiply for efficiency.
            with tf.variable_scope("gates"):
                concat = rnn_ops.linear([inputs, h_prev],
                                        2 * self._num_units,
                                        bias=True,
                                        bias_start=1.0)

            # r = reset_gate, u = update_gate
            r, u = tf.split(value=concat, num_or_size_splits=2, axis=1)

            if self._layer_norm:
                r = rnn_ops.layer_norm(r, name="r")
                u = rnn_ops.layer_norm(u, name="u")

            # Apply non-linearity after layer normalization
            r = tf.sigmoid(r)
            u = tf.sigmoid(u)

            with tf.variable_scope("candidate"):
                new_c_tilde = self._activation(
                    rnn_ops.linear([inputs, r * h_prev], self._num_units,
                                   True))
            new_h_tilde = u * h_prev + (1 - u) * new_c_tilde

            # Compute value for the update prob
            with tf.compat.v1.variable_scope('state_update_prob'):
                new_update_prob_tilde = rnn_ops.linear(
                    new_h_tilde, 1, True, bias_start=self._update_bias)
                new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde)

            # Compute value for the update gate
            cum_update_prob = cum_update_prob_prev + tf.minimum(
                update_prob_prev, 1. - cum_update_prob_prev)
            update_gate = _binary_round(cum_update_prob)

            # Apply update gate
            new_h = update_gate * new_h_tilde + (1. - update_gate) * h_prev
            new_update_prob = update_gate * new_update_prob_tilde + (
                1. - update_gate) * update_prob_prev
            new_cum_update_prob = update_gate * 0. + (
                1. - update_gate) * cum_update_prob

            new_state = SkipGRUStateTuple(new_h, new_update_prob,
                                          new_cum_update_prob)
            new_output = SkipGRUOutputTuple(new_h, update_gate)

            return new_output, new_state
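Because the emitted gate is strictly binary, the compute budget a trained model actually uses can be read off directly from the outputs. A short sketch, where gates is assumed to be the [batch, time, 1] stack of update gates collected from the output tuples:

updates_per_sequence = tf.reduce_sum(gates, axis=1)  # number of state updates
update_rate = tf.reduce_mean(gates)                  # fraction of steps not skipped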
Example #4
    def __call__(self, inputs, state, scope=None):
        with tf.compat.v1.variable_scope(scope or type(self).__name__):
            c_prev, h_prev, update_prob_prev, cum_update_prob_prev = state

            # Parameters of gates are concatenated into one multiply for efficiency.
            concat = rnn_ops.linear([inputs, h_prev], 4 * self._num_units,
                                    True)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = tf.split(value=concat, num_or_size_splits=4, axis=1)

            if self._layer_norm:
                i = rnn_ops.layer_norm(i, name="i")
                j = rnn_ops.layer_norm(j, name="j")
                f = rnn_ops.layer_norm(f, name="f")
                o = rnn_ops.layer_norm(o, name="o")

            new_c_tilde = (c_prev * tf.sigmoid(f + self._forget_bias) +
                           tf.sigmoid(i) * self._activation(j))
            new_h_tilde = self._activation(new_c_tilde) * tf.sigmoid(o)

            # Compute value for the update prob
            with tf.compat.v1.variable_scope('state_update_prob'):
                new_update_prob_tilde = rnn_ops.linear(
                    new_c_tilde, 1, True, bias_start=self._update_bias)
                new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde)

            # Compute value for the update gate
            cum_update_prob = cum_update_prob_prev + tf.minimum(
                update_prob_prev, 1. - cum_update_prob_prev)
            update_gate = _binary_round(cum_update_prob)

            # Apply update gate
            new_c = update_gate * new_c_tilde + (1. - update_gate) * c_prev
            new_h = update_gate * new_h_tilde + (1. - update_gate) * h_prev
            new_update_prob = update_gate * new_update_prob_tilde + (
                1. - update_gate) * update_prob_prev
            new_cum_update_prob = update_gate * 0. + (
                1. - update_gate) * cum_update_prob

            new_state = SkipLSTMStateTuple(new_c, new_h, new_update_prob,
                                           new_cum_update_prob)
            new_output = SkipLSTMOutputTuple(new_h, update_gate)

            return new_output, new_state
Example #5
    def __call__(self, inputs, state, scope=None):
        """Gated recurrent unit (GRU) with num_units cells."""
        with tf.compat.v1.variable_scope(scope or type(self).__name__):
            with tf.compat.v1.variable_scope("gates"):  # Reset gate and update gate.
                # We start with bias of 1.0 to not reset and not update.
                concat = rnn_ops.linear([inputs, state],
                                        2 * self._num_units,
                                        True,
                                        bias_start=1.0)
                r, u = tf.split(value=concat, num_or_size_splits=2, axis=1)

                if self._layer_norm:
                    r = rnn_ops.layer_norm(r, name="r")
                    u = rnn_ops.layer_norm(u, name="u")

                # Apply non-linearity after layer normalization
                r = tf.sigmoid(r)
                u = tf.sigmoid(u)

            with tf.variable_scope("candidate"):
                c = self._activation(
                    rnn_ops.linear([inputs, r * state], self._num_units, True))
            new_h = u * state + (1 - u) * c
        return new_h, new_h
Example #6
    def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM)."""
        with tf.compat.v1.variable_scope(scope or type(self).__name__):
            c, h = state

            # Parameters of gates are concatenated into one multiply for efficiency.
            concat = rnn_ops.linear([inputs, h], 4 * self._num_units, True)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = tf.split(value=concat, num_or_size_splits=4, axis=1)

            if self._layer_norm:
                i = rnn_ops.layer_norm(i, name="i")
                j = rnn_ops.layer_norm(j, name="j")
                f = rnn_ops.layer_norm(f, name="f")
                o = rnn_ops.layer_norm(o, name="o")

            new_c = (c * tf.sigmoid(f + self._forget_bias) +
                     tf.sigmoid(i) * self._activation(j))
            new_h = self._activation(new_c) * tf.sigmoid(o)

            new_state = LSTMStateTuple(new_c, new_h)
            return new_h, new_state
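All six cells share the same functional contract: called with a batch of inputs and the previous state, they return (output, new_state). A minimal manual step through the plain LSTM cell above (a sketch; the wrapper class name and its constructor are assumptions, since only __call__ is shown):

batch, features, units = 32, 40, 128
x = tf.zeros([batch, features])
state0 = LSTMStateTuple(tf.zeros([batch, units]), tf.zeros([batch, units]))
cell = LayerNormLSTMCell(units)  # hypothetical wrapper class
output, state1 = cell(x, state0)  # output is new_h; state1 is (new_c, new_h)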