Beispiel #1
0
def functional_rnn(cell,
                   inputs,
                   sequence_length=None,
                   initial_state=None,
                   dtype=None,
                   time_major=False,
                   scope=None,
                   use_tpu=False):
    """Same interface as `tf.nn.dynamic_rnn`.

    The recurrence is delegated to `recurrent.Recurrent`, driven by a
    `_FunctionalRnnCell` built from `cell`, and the raw result is handed to
    `_PostProcessOutput` to obtain the returned output and state.

    Args:
      cell: RNN cell. Its `zero_state(batch_size, dtype)` is called when
        `initial_state` is None.
      inputs: A tensor or nested structure of tensors. When `time_major` is
        false every tensor is transposed with permutation `[1, 0, 2]`, so
        rank 3 is required in that case.
      sequence_length: Optional; forwarded unchanged to `_PostProcessOutput`.
      initial_state: Optional starting state. When None it is created with
        `cell.zero_state`.
      dtype: Passed to `cell.zero_state` when `initial_state` is None.
      time_major: If false, axes 0 and 1 of the inputs are swapped before the
        recurrence. If true, axes 0 and 1 of the post-processed output are
        swapped before returning.
      scope: Variable scope for building the cell wrapper; `'rnn'` if falsy.
      use_tpu: Forwarded to `recurrent.Recurrent`.

    Returns:
      A `(tf_output, tf_state)` pair as produced by `_PostProcessOutput`,
      with the output transposed as described under `time_major`.
    """

    def _swap_first_two_axes(tensor):
        # Exchanges axes 0 and 1 of a rank-3 tensor.
        return array_ops.transpose(tensor, [1, 0, 2])

    # Only input preparation and cell wrapping happen inside the variable
    # scope; the recurrence itself is built after the scope is exited.
    with variable_scope.variable_scope(scope or 'rnn'):
        if not time_major:
            inputs = nest.map_structure(_swap_first_two_axes, inputs)
        leading_input = nest.flatten(inputs)[0]
        # From here on the batch dimension is read from axis 1.
        batch_size = array_ops.shape(leading_input)[1]
        if initial_state is None:
            initial_state = cell.zero_state(batch_size, dtype)
        func_cell = _FunctionalRnnCell(cell, inputs, initial_state)

    recurrent_result = recurrent.Recurrent(
        theta=func_cell.theta,
        state0=func_cell.extended_initial_state,
        inputs=inputs,
        cell_fn=func_cell.cell_step,
        use_tpu=use_tpu)
    accumulated_state, last_state = recurrent_result

    # The static size of axis 0 of the first input is passed along with the
    # (optional) per-example sequence lengths.
    outputs, final_state = _PostProcessOutput(
        accumulated_state, last_state, func_cell, leading_input.shape[0],
        sequence_length)

    if time_major:
        outputs = _swap_first_two_axes(outputs)
    return outputs, final_state
Beispiel #2
0
    def _ParameterizedTestElman(self, seqlen, use_grad):
        """Checks `recurrent.Recurrent` against a statically unrolled Elman net.

        The same recurrence is built twice: once by calling
        `RecurrentTest.Elman` in a Python loop, and once through
        `recurrent.Recurrent`. The accumulated states, the final state and the
        gradients with respect to the weights, bias, initial state and inputs
        must agree between the two.

        Args:
          seqlen: Number of time steps to run.
          use_grad: If true, `RecurrentTest.ElmanGrad` is passed as
            `cell_grad`; otherwise `cell_grad` is None.
        """
        with self.test_session() as sess:
            random_seed.set_random_seed(342462)

            batch = 3
            dims = 4
            theta = _ElmanTheta(w=RecurrentTest.Rand([2 * dims, dims]),
                                b=RecurrentTest.Rand([dims]))
            state0 = _ElmanState(h=RecurrentTest.Rand([batch, dims]))
            inputs = _ElmanInputs(x=RecurrentTest.Rand([seqlen, batch, dims]))

            # Statically unrolled reference.
            s = state0
            out = []
            # `range` works on both Python 2 and 3; `xrange` needs a six import
            # on Python 3.
            for i in range(seqlen):
                inp = _ElmanInputs(x=inputs.x[i, :])
                s, _ = RecurrentTest.Elman(theta, s, inp)
                out.append(s.h)
            acc0, final0 = array_ops.stack(out), s.h
            loss0 = math_ops.reduce_sum(acc0) + math_ops.reduce_sum(final0)
            (dw0, db0, dh0, di0) = gradients_impl.gradients(
                loss0, [theta.w, theta.b, state0.h, inputs.x])

            # The same recurrence expressed through recurrent.Recurrent.
            acc1, final1 = recurrent.Recurrent(
                theta=theta,
                state0=state0,
                inputs=inputs,
                cell_fn=RecurrentTest.Elman,
                cell_grad=RecurrentTest.ElmanGrad if use_grad else None)
            # Use the TestCase helper rather than a bare `assert`, which is
            # stripped under `python -O` and reports less on failure.
            self.assertIsInstance(acc1, _ElmanState)
            self.assertIsInstance(final1, _ElmanState)
            acc1, final1 = acc1.h, final1.h
            loss1 = math_ops.reduce_sum(acc1) + math_ops.reduce_sum(final1)
            (dw1, db1, dh1, di1) = gradients_impl.gradients(
                loss1, [theta.w, theta.b, state0.h, inputs.x])

            # Fetch everything in a single run and compare pairwise.
            (acc0, acc1, final0, final1, dw0, dw1, db0, db1, dh0, dh1, di0,
             di1) = sess.run([
                 acc0, acc1, final0, final1, dw0, dw1, db0, db1, dh0, dh1, di0,
                 di1
             ])
            self.assertAllClose(acc0, acc1)
            self.assertAllClose(final0, final1)
            self.assertAllClose(dw0, dw1)
            self.assertAllClose(db0, db1)
            self.assertAllClose(dh0, dh1)
            self.assertAllClose(di0, di1)
Beispiel #3
0
    def testBasic(self):
        """Evaluates a polynomial with `recurrent.Recurrent` and checks grads.

        With x = 2 and coefficients [1, 2, 3] the recurrence accumulates
        1 + 2*x + 3*x^2 one term per step. The test checks the accumulated and
        final states, then the gradients of both the final value and the sum
        of accumulated values with respect to x and the coefficients.
        """
        # pylint:disable=invalid-name
        _PolyState = collections.namedtuple('PolyState', ('value', 'x_power'))
        # Trailing commas make these real one-element tuples; `('x')` is just
        # the string 'x' and only works because namedtuple also accepts a
        # string of field names.
        _PolyTheta = collections.namedtuple('PolyTheta', ('x',))
        _PolyInputs = collections.namedtuple('PolyInputs', ('coeff',))

        # pylint:enable=invalid-name

        def Poly(theta, state, inputs):
            """Adds one term to the running value and advances the power of x.

            Returns the next state and an empty list as the second value.
            """
            next_state = _PolyState(value=state.value +
                                    inputs.coeff * state.x_power,
                                    x_power=state.x_power * theta.x)
            return next_state, []

        with self.test_session() as sess:
            theta = _PolyTheta(x=array_ops.constant(2.0))
            state = _PolyState(value=array_ops.constant(0.0),
                               x_power=array_ops.constant(1.0))
            inputs = _PolyInputs(coeff=array_ops.constant([1., 2., 3.]))

            # x = 2
            # 1 + 2*x + 3*x^2
            ret = recurrent.Recurrent(theta, state, inputs, Poly)

            # Separate names for fetched values so the tensor-valued `state`
            # above is not shadowed.
            acc_val, state_val = sess.run(ret)
            self.assertAllClose(acc_val.value, [1., 5., 17.])
            self.assertAllClose(acc_val.x_power, [2., 4., 8.])
            self.assertAllClose(state_val.value, 17.)
            self.assertAllClose(state_val.x_power, 8.)

            # Gradient of the final value: y = 1 + 2*x + 3*x^2.
            y = ret[1].value
            dx, d_coeff = gradients_impl.gradients(ys=[y],
                                                   xs=[theta.x, inputs.coeff])
            dx_val, d_coeff_val = sess.run([dx, d_coeff])

            # dy/dx = 2 + 6*x; dy/dcoeff = [1, x, x^2]
            self.assertAllClose(dx_val, 14.)
            self.assertAllClose(d_coeff_val, [1., 2., 4.])

            # acc = [1, 1+2x, 1+2x+3x^2]
            # sum(acc) = 3 + 4x + 3x^2
            acc_tensor = ret[0].value
            dx, d_coeff = gradients_impl.gradients(
                ys=[math_ops.reduce_sum(acc_tensor)],
                xs=[theta.x, inputs.coeff])
            dx_val, d_coeff_val = sess.run([dx, d_coeff])
            # d/dx = 4 + 6*x; d/dcoeff = [3, 2x, x^2]
            self.assertAllClose(dx_val, 16.)
            self.assertAllClose(d_coeff_val, [3., 4., 4.])