Example #1
  def testNestedCellFn(self):
    """Tests when cell_fn calls another function."""

    @function.Defun(tf.float32)
    def RandWithCoeff(coeff):
      return coeff * tf.random_uniform(shape=[], dtype=coeff.dtype)

    def Rand(theta, state, inputs):
      del theta
      next_state = py_utils.NestedMap()
      next_state.value = state.value + RandWithCoeff(inputs.coeff)
      return next_state, py_utils.NestedMap()

    @function.Defun(tf.float32)
    def Coeff(coeff):
      return coeff * 2

    def Deterministic(theta, state, inputs):
      del theta
      next_state = py_utils.NestedMap()
      next_state.value = state.value + Coeff(inputs.coeff)
      return next_state, py_utils.NestedMap()

    with self.session():
      theta = py_utils.NestedMap()
      state = py_utils.NestedMap()
      state.value = tf.constant(0.0)
      inputs = py_utils.NestedMap()
      inputs.coeff = tf.constant([1., 2., 3.])

      recurrent.Recurrent(
          theta, state, inputs, Deterministic, check_stateful_ops=True)
      with self.assertRaisesRegexp(ValueError, 'stateful.*RandWithCoeff'):
        recurrent.Recurrent(theta, state, inputs, Rand, check_stateful_ops=True)
Example #2
    def testSymbolToTensorMap(self):
        """Tests that cell_fn can rely on the contextual symbol-to-tensor map."""

        x = symbolic.Symbol('x')
        y = symbolic.Symbol('y')

        def PlusWXT(theta, state, inputs):
            """state.value += theta.w * x * inputs.t."""
            next_state = py_utils.NestedMap()
            x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
            next_state.value = state.value + theta.w * x_tensor * inputs.t
            return next_state, py_utils.NestedMap()

        def PlusWXTGrad(theta, state0, inputs, extras, dstate1):
            """Gradient function for PlusWXT."""
            del state0, extras
            x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
            dtheta = py_utils.NestedMap(w=dstate1.value * x_tensor * inputs.t)
            dstate0 = py_utils.NestedMap(value=dstate1.value)
            dinputs = py_utils.NestedMap(t=dstate1.value * theta.w * x_tensor)
            return dtheta, dstate0, dinputs, None

        with self.session() as sess:
            theta = py_utils.NestedMap(w=tf.constant(1., name='w'))
            state0 = py_utils.NestedMap(value=tf.constant(0., name='value'))
            inputs = py_utils.NestedMap(t=tf.constant([1., 2., 3.], name='t'))

            # With automatic cell_grad.
            with symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES, {
                    x: tf.constant(7., name='x7'),
                    y: 8
            }):
                x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
                _, state1 = recurrent.Recurrent(theta, state0, inputs, PlusWXT)
                dw = tf.gradients(ys=[state1.value], xs=[theta.w])[0]
                dx = tf.gradients(ys=[state1.value], xs=[x_tensor])[0]
                final_value, x_val, dx_val, dw_val = sess.run(
                    [state1.value, x_tensor, dx, dw])
            self.assertEqual(x_val, 7)
            self.assertEqual(final_value, x_val * (1. + 2. + 3.))
            self.assertEqual(dw_val, x_val * (1. + 2. + 3.))
            self.assertEqual(dx_val, (1. + 2. + 3.))

            # With manual cell_grad.
            with symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES,
                                           {x: tf.constant(5., name='x5')}):
                x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
                _, state1 = recurrent.Recurrent(theta,
                                                state0,
                                                inputs,
                                                PlusWXT,
                                                cell_grad=PlusWXTGrad)
                dw = tf.gradients(ys=[state1.value], xs=[theta.w])[0]
                dx = tf.gradients(ys=[state1.value], xs=[x_tensor])[0]
                final_value, x_val, dx_val, dw_val = sess.run(
                    [state1.value, x_tensor, dx, dw])
            self.assertEqual(x_val, 5)
            self.assertEqual(final_value, x_val * (1. + 2. + 3.))
            self.assertEqual(dw_val, x_val * (1. + 2. + 3.))
            self.assertEqual(dx_val, (1. + 2. + 3.))
Example #3
    def testCaptureDisallowed(self):

        with self.session() as unused_sess:

            theta = py_utils.NestedMap()
            theta.x = tf.constant(2.0)
            state0 = py_utils.NestedMap()
            state0.value = tf.constant(0.0)
            inputs = py_utils.NestedMap()
            inputs.coeff = tf.constant([1., 2., 3.])
            captured = tf.constant(1.2, name='captured_const')

            def CellFn(theta, state, inputs):
                next_state = py_utils.NestedMap()
                # Captured is pulled from outside of the function and captured via the
                # internal Defun.
                next_state.value = state.value + inputs.coeff * captured * theta.x
                return next_state, py_utils.NestedMap()

            # Run real fprop implementation.
            with self.assertRaisesRegexp(AssertionError,
                                         'implicit capture is disabled'):
                unused_real_acc, unused_real_staten = recurrent.Recurrent(
                    theta,
                    state0,
                    inputs,
                    CellFn,
                    allow_implicit_capture=False)
Example #4
  def FProp(self, theta, *args):
    """Runs p.repeat copies of self.body.FProp independently.

    Args:
      theta: Layer model parameters. The shape of each variable in theta is
        always [p.repeat, ...]. And the i-th slice theta[i] becomes theta of the
        i-th copy of self.body.
      *args: Input arguments. The shape of each tensor in args is always
        [p.repeat, ...]. And the list [arg[i] for arg in args] becomes inputs
        to the i-th copy of self.body.FProp.

    Returns:
      The accumulated output_tensors. Each tensor t in the return has the shape
      [p.repeat, ...] and the tuple (t[i] for t in output_tensors) is the
      return tuple of the i-th self.body.FProp.
    """
    p = self.params
    for arg in args:
      if arg is not None:
        arg = py_utils.HasShape(arg, [p.repeat], ndims=1)

    theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)
    inputs = py_utils.NestedMap(theta=theta_stack, args=list(args))
    # Infer out_shapes from FPropMeta.
    out_shapes = self._InferOutShapes(args)

    def _CellFn(unused_theta, unused_state0, inputs):
      """Recurrent cell function wrapper of body.FProp."""
      # Sets shapes for both theta and inputs to self.body.FProp.
      for dst, src in zip(inputs.args + inputs.theta.Flatten(),
                          list(args) + theta_stack.Flatten()):
        if src is not None:
          dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

      # Runs the actual body.FProp
      fprop_outputs = self.body.FProp(inputs.theta, *inputs.args)
      fprop_outputs = _ToTuple(fprop_outputs)
      assert len(fprop_outputs) == len(out_shapes)
      # Passes fprop outputs to the next layer through state.
      state1 = py_utils.NestedMap(outputs=list(fprop_outputs))
      return state1, py_utils.NestedMap()

    with tf.name_scope(p.name):
      # Initiate state0 with inferred output shapes.
      state0 = py_utils.NestedMap(
          outputs=[tf.zeros(shape, args[0].dtype) for shape in out_shapes])
      # Runs body.FProp p.repeat times using Recurrent.
      acc_states, _ = recurrent.Recurrent(
          theta=py_utils.NestedMap(),
          state0=state0,
          inputs=inputs,
          cell_fn=_CellFn)

      # Retrieves fprop outputs from the accumulated states and sets shapes.
      output_tensors = tuple(acc_states.outputs)
      for out_idx in range(len(output_tensors)):
        output_tensors[out_idx].set_shape(
            tf.TensorShape([p.repeat] + out_shapes[out_idx].as_list()))

      return output_tensors[0] if len(args) == 1 else tuple(output_tensors)
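# A small, hypothetical NumPy illustration of the shape contract described in
# the FProp docstring above: every theta variable and every input arg carries a
# leading p.repeat dimension, and slice i feeds the i-th copy of self.body.
# All names below are made up for illustration.
import numpy as np

repeat = 3  # Stands in for p.repeat.

# Each of the `repeat` body copies owns a weight vector of shape [4]; stacking
# them yields the layout FProp expects: stacked_w[i] is the theta of copy i.
stacked_w = np.stack([np.full([4], float(i)) for i in range(repeat)])

# Input args follow the same convention: arg0[i] is the input to copy i.
arg0 = np.arange(repeat * 2, dtype=np.float32).reshape([repeat, 2])

assert stacked_w.shape[0] == repeat and arg0.shape[0] == repeat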
Example #5
    def testRecurrent(self):
        """Run on GPU locally with --test_output=streamed --config=cuda."""
        def Sum(theta, state, inputs):
            next_state = py_utils.NestedMap()
            v = tf.reduce_sum(tf.one_hot(inputs.one_hot, depth=2) * theta.x,
                              axis=0)
            next_state.sum = state.sum + v
            return next_state, py_utils.NestedMap()

        with self.session(use_gpu=True) as sess:

            theta = py_utils.NestedMap()
            theta.x = tf.constant([-1.0, 2.0])
            state = py_utils.NestedMap()
            state.sum = tf.constant(0.0)
            inputs = py_utils.NestedMap()
            inputs.one_hot = tf.constant([0, 1, 1], dtype=tf.int32)

            # sum = -1 + 2 + 2
            ret = recurrent.Recurrent(theta, state, inputs, Sum)

            acc, state = sess.run(ret)
            self.assertAllClose(acc.sum, [-1., 1., 3.])
            self.assertAllClose(state.sum, 3.)

            y = ret[1].sum
            dx, d_inputs = tf.gradients(ys=[y], xs=[theta.x, inputs.one_hot])
            tf.logging.info('d(inputs) = %s', d_inputs)
            dx_val = sess.run(dx)
            self.assertAllClose(dx_val, [1, 2])
Example #6
    def testCapture(self):

        with self.session():

            theta = py_utils.NestedMap()
            theta.x = tf.constant(2.0)
            state0 = py_utils.NestedMap()
            state0.value = tf.constant(0.0)
            inputs = py_utils.NestedMap()
            inputs.coeff = tf.constant([1., 2., 3.])
            captured = tf.constant(1.2, name='captured_const')

            def CellFn(theta, state, inputs):
                next_state = py_utils.NestedMap()
                # Captured is pulled from outside of the function and captured via the
                # internal Defun.
                next_state.value = state.value + inputs.coeff * captured * theta.x
                return next_state, py_utils.NestedMap()

            # Run static fprop reference implementation.
            ref_acc, ref_staten = _ReferenceStaticUnroll(
                theta, state0, inputs, CellFn)
            ref_acc_values, ref_staten_values = self.evaluate(
                (ref_acc, ref_staten))
            print('Ref fprop: acc =', ref_acc, ', stateN =', ref_staten)
            print('Ref fprop values: acc =', ref_acc_values, ', stateN =',
                  ref_staten_values)
            self.assertAllClose(ref_acc_values.value, [2.4, 7.2, 14.4])
            self.assertAllClose(ref_staten_values.value, 14.4)

            # Run real fprop implementation.
            real_acc, real_staten = recurrent.Recurrent(
                theta, state0, inputs, CellFn, allow_implicit_capture=True)
            real_acc_values, real_staten_values = self.evaluate(
                (real_acc, real_staten))
            print('Real fprop: acc =', real_acc, ', stateN =', real_staten)
            print('Real fprop values: acc =', real_acc_values, ', stateN =',
                  real_staten_values)
            self.assertAllClose(ref_acc_values.value, real_acc_values.value)
            self.assertAllClose(ref_staten_values.value,
                                real_staten_values.value)

            # BProp real vs ref of stateN.
            ref_dx, ref_dcaptured = tf.gradients(ys=[ref_staten.value],
                                                 xs=[theta.x, captured])
            ref_dx_values, ref_dcaptured_values = self.evaluate(
                [ref_dx, ref_dcaptured])
            real_dx, real_dcaptured = tf.gradients(ys=[real_staten.value],
                                                   xs=[theta.x, captured])
            real_dx_values, real_dcaptured_values = self.evaluate(
                [real_dx, real_dcaptured])
            print('Ref Dstate/[dx,dcaptured] =', ref_dx_values, ', ',
                  ref_dcaptured_values)
            print('Real Dstate/[dx,dcaptured] =', real_dx_values, ', ',
                  real_dcaptured_values)
            self.assertAllClose(ref_dx_values, 7.2)
            self.assertAllClose(ref_dcaptured_values, 12.0)
            self.assertAllClose(ref_dx_values, real_dx_values)
            self.assertAllClose(ref_dcaptured_values, real_dcaptured_values)
Example #7
    def FProp(self, theta, *args):
        p = self.params
        # Collects all variable key and values into sets.
        theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars,
                                            p.repeat)

        def _ArgsToState(arg_list):
            """Returns a NestedMap from a list of FProp args."""
            state = py_utils.NestedMap()
            # Maintains a mapping from arg_idx to tensor. The state cannot
            # contain None tensors.
            for idx in range(len(args)):
                if arg_list[idx] is not None:
                    state['_s{}'.format(idx)] = arg_list[idx]
            return state

        def _StateToArgs(state):
            """Returns a list of FProp args from a NestedMap."""
            arg_list = []
            for idx in range(len(args)):
                attr = '_s{}'.format(idx)
                arg_list.append(state[attr] if attr in state else None)
                if arg_list[-1] is not None:
                    arg_list[-1].set_shape(args[idx].shape)
            return arg_list

        def _CellFn(unused_theta, state0, theta_i):
            """Recurrent cell function wrapper of body.FProp."""
            # Retrieves fprop arguments from state and sets shapes.
            fprop_inputs = _StateToArgs(state0)

            # Sets shapes for theta_i as well.
            for dst, src in zip(theta_i.Flatten(), theta_stack.Flatten()):
                if src is not None:
                    dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

            # Runs the actual body.FProp
            fprop_outputs = self.body.FProp(theta_i, *fprop_inputs)
            fprop_outputs = _ToTuple(fprop_outputs)
            assert len(fprop_outputs) == len(fprop_inputs)

            # Passes fprop outputs to the next layer through state.
            state1 = _ArgsToState(fprop_outputs)
            return state1, py_utils.NestedMap()

        with tf.name_scope(p.name):
            # Add FProp arg list to state0.
            state0 = _ArgsToState(args)
            # Runs body.FProp k times using Recurrent where k = dim 0 of theta_stack.
            _, state1 = recurrent.Recurrent(
                theta=py_utils.NestedMap(),
                state0=state0,
                inputs=theta_stack,  # Pass cell_fn theta through inputs.
                cell_fn=_CellFn)

            # Retrieves fprop outputs from state1 and sets shapes.
            output_tensors = _StateToArgs(state1)
            return output_tensors[0] if len(args) == 1 else tuple(
                output_tensors)
Example #8
    def _testElmanHelper(self, seqlen, use_grad, stop_fn=None):
        with self.session() as sess:
            tf.set_random_seed(342462)

            batch = 3
            dims = 4
            theta = py_utils.NestedMap()
            theta.w = self.Rand([2 * dims, dims])
            theta.b = self.Rand([dims])
            state0 = py_utils.NestedMap()
            state0.h = self.Rand([batch, dims])
            inputs = py_utils.NestedMap()
            inputs.x = self.Rand([seqlen, batch, dims])

            # Static unrolled.
            s = state0
            out = []
            for i in range(seqlen):
                inp = py_utils.NestedMap()
                inp.x = inputs.x[i, :]
                s, _ = self.Elman(theta, s, inp)
                out += [s.h]
                if stop_fn and stop_fn(i + 1, theta, s):
                    out += [
                        tf.zeros_like(out[-1]) for _ in range(seqlen - i - 1)
                    ]
                    break
            acc0, final0 = tf.stack(out), s.h
            loss0 = tf.reduce_sum(acc0) + tf.reduce_sum(final0)
            (dw0, db0, dh0,
             di0) = tf.gradients(loss0, [theta.w, theta.b, state0.h, inputs.x])

            # Uses the Recurrent() library.
            acc1, final1 = recurrent.Recurrent(
                theta=theta,
                state0=state0,
                inputs=inputs,
                cell_fn=self.Elman,
                cell_grad=self.ElmanGrad if use_grad else None,
                stop_fn=stop_fn)
            acc1, final1 = acc1.h, final1.h
            loss1 = tf.reduce_sum(acc1) + tf.reduce_sum(final1)
            (dw1, db1, dh1,
             di1) = tf.gradients(loss1, [theta.w, theta.b, state0.h, inputs.x])

            # Fetches a bunch of values and compare them.
            (acc0, acc1, final0, final1, dw0, dw1, db0, db1, dh0, dh1, di0,
             di1) = sess.run([
                 acc0, acc1, final0, final1, dw0, dw1, db0, db1, dh0, dh1, di0,
                 di1
             ])
            self.assertAllClose(acc0, acc1)
            self.assertAllClose(final0, final1)
            self.assertAllClose(dw0, dw1)
            self.assertAllClose(db0, db1)
            self.assertAllClose(dh0, dh1)
            self.assertAllClose(di0, di1)
Example #9
    def testRnnStepRecurrent(self):
        with self.session(use_gpu=False) as sess:
            p = rnn_steps.RnnStep.Params()
            p.name = 'rnn_step'
            p.cell.name = 'rnn_step_cell'
            p.cell.output_nonlinearity = True
            p.cell.params_init = py_utils.WeightInit.Uniform(1.24, 429891685)
            p.cell.bias_init = py_utils.WeightInit.Uniform(1.24, 429891685)
            p.cell.vn.global_vn = False
            p.cell.vn.per_step_vn = False
            p.cell.num_input_nodes = 2
            p.cell.num_output_nodes = 2
            p.cell.inputs_arity = 2

            rnn_step = p.Instantiate()
            external = tf.constant([[3]], tf.float32)
            packed = rnn_step.PrepareExternalInputs(rnn_step.theta, external)
            padding = tf.constant([0.0], dtype=tf.float32)

            def cell_fn(theta, state0, inputs):
                output, state1 = rnn_step.FProp(theta, state0.packed, inputs,
                                                state0.padding, state0.state)
                return py_utils.NestedMap(
                    output=output,
                    state=state1,
                    packed=state0.packed,
                    padding=state0.padding), py_utils.NestedMap()

            _, state2 = recurrent.Recurrent(
                theta=rnn_step.theta,
                state0=py_utils.NestedMap(
                    state=rnn_step.ZeroState(rnn_step.theta, packed, 1),
                    output=py_utils.NestedMap(output=tf.zeros([1, 2],
                                                              tf.float32),
                                              extra=py_utils.NestedMap(),
                                              padding=padding),
                    padding=padding,
                    packed=packed),
                cell_fn=cell_fn,
                inputs=py_utils.NestedMap(
                    inputs=[tf.constant([[[4]], [[4]]], tf.float32)]))

            tf.global_variables_initializer().run()
            state2 = sess.run(state2)
            self.assertAllClose(state2.state.m, [[-0.32659757, 0.87739915]])
            self.assertAllClose(state2.state.c, [[-1.9628618, 1.4194499]])
            self.assertAllClose(state2.output.output,
                                [[-0.32659757, 0.87739915]])
Example #10
    def testStopFnNotTriggeredBeforeEOS(self):

        with self.session() as sess:

            def StopFn(t, unused_theta, unused_state):
                # The input sequence is only length 4, so this is never true.
                # However, the Recurrent call should still terminate after iteration 4.
                return t >= 5

            theta = py_utils.NestedMap()
            theta.x = tf.constant(2.0)
            state = py_utils.NestedMap()
            state.value = tf.constant(0.0)
            state.x_power = tf.constant(1.0)
            inputs = py_utils.NestedMap()
            inputs.coeff = tf.constant([1., 2., 3., 4.])

            # x = 2
            # 1 + 2*x + 3*x^2 + 4*x^3
            ret = recurrent.Recurrent(theta,
                                      state,
                                      inputs,
                                      _Poly,
                                      stop_fn=StopFn)

            acc, state = sess.run(ret)
            self.assertAllClose([1., 5., 17., 49.], acc.value)
            self.assertAllClose([2., 4., 8., 16.], acc.x_power)
            self.assertAllClose(49., state.value)
            self.assertAllClose(16., state.x_power)

            y = ret[1].value
            dx, d_coeff = tf.gradients(ys=[y], xs=[theta.x, inputs.coeff])
            dx_val, d_coeff_val = sess.run([dx, d_coeff])

            # 2 + 6*x + 12*x^2
            self.assertAllClose(62., dx_val)
            self.assertAllClose([1., 2., 4., 8.], d_coeff_val)

            # acc = [1, 1+2x, 1+2x+3x^2, 1+2x+3x^2+4x^3]
            # sum(acc) = 4 + 6x + 6x^2 + 4x^3
            acc = ret[0].value
            dx, d_coeff = tf.gradients(ys=[tf.reduce_sum(acc)],
                                       xs=[theta.x, inputs.coeff])
            dx_val, d_coeff_val = sess.run([dx, d_coeff])
            # 6 + 12*x + 12*x^2
            self.assertAllClose(78., dx_val)
            self.assertAllClose([4., 6., 8., 8.], d_coeff_val)
Example #11
  def FProp(self, theta, prepared_inputs, inputs, padding, state0, **kwargs):
    """Runs a Step layer over multiple timesteps using Recurrent.

    Args:
      theta: A NestedMap containing weights' values of this layer and its
        children layers.
      prepared_inputs: External inputs returned by Step.PrepareExternalInputs().
      inputs: A NestedMap of inputs of shape [time, batch_size, dim].
      padding: A 0/1 float tensor of shape [time, batch_size]; 1.0 means that
        this batch element is empty in this step.
      state0: A NestedMap containing the initial recurrent state.
      **kwargs: Additional kwargs to pass to Recurrent.

    Returns:
      A tuple (outputs, state1).

      - outputs: A NestedMap containing the accumulated outputs of all steps,
        containing Tensors shaped [time, batch_size, dim].
      - state1: A NestedMap containing the accumulated recurrent states,
        containing Tensors shaped [time, batch_size, dim].
    """

    def RnnStep(recurrent_theta, recurrent_state0, recurrent_inputs):
      """Compute a single timestep."""
      output, state1 = self.step.FProp(
          theta=recurrent_theta.theta,
          prepared_inputs=recurrent_theta.prepared_inputs,
          step_inputs=recurrent_inputs.inputs,
          padding=recurrent_inputs.padding,
          state0=recurrent_state0.state)
      recurrent_state1 = py_utils.NestedMap(output=output, state=state1)
      return recurrent_state1, py_utils.NestedMap()

    # In order to pass Step outputs through Recurrent, they need to be
    # included as part of state.
    output0, _ = self.step.FProp(theta.step, prepared_inputs,
                                 inputs.Transform(lambda x: x[0]), padding[0],
                                 state0)

    accumulated_states, _ = recurrent.Recurrent(
        theta=py_utils.NestedMap(
            theta=theta.step, prepared_inputs=prepared_inputs),
        state0=py_utils.NestedMap(output=output0, state=state0),
        inputs=py_utils.NestedMap(inputs=inputs, padding=padding),
        cell_fn=RnnStep,
        **kwargs)

    return accumulated_states.output, accumulated_states.state
Example #12
    def testStateBasedStopFn(self):

        with self.session() as sess:

            def StopFn(unused_t, unused_theta, state):
                # This stops after 3 iterations.
                return state.value >= 15.

            theta = py_utils.NestedMap()
            theta.x = tf.constant(2.0)
            state = py_utils.NestedMap()
            state.value = tf.constant(0.0)
            state.x_power = tf.constant(1.0)
            inputs = py_utils.NestedMap()
            inputs.coeff = tf.constant([1., 2., 3., 4.])

            # x = 2
            # 1 + 2*x + 3*x^2
            ret = recurrent.Recurrent(theta,
                                      state,
                                      inputs,
                                      _Poly,
                                      stop_fn=StopFn)

            acc, state = sess.run(ret)
            self.assertAllClose([1., 5., 17., 0.], acc.value)
            self.assertAllClose([2., 4., 8., 0.], acc.x_power)
            self.assertAllClose(17., state.value)
            self.assertAllClose(8., state.x_power)

            y = ret[1].value
            dx, d_coeff = tf.gradients(ys=[y], xs=[theta.x, inputs.coeff])
            dx_val, d_coeff_val = sess.run([dx, d_coeff])

            # 2 + 6*x
            self.assertAllClose(14., dx_val)
            self.assertAllClose([1., 2., 4., 0.], d_coeff_val)

            # acc = [1, 1+2x, 1+2x+3x^2]
            # sum(acc) = 3 + 4x + 3x^2
            acc = ret[0].value
            dx, d_coeff = tf.gradients(ys=[tf.reduce_sum(acc)],
                                       xs=[theta.x, inputs.coeff])
            dx_val, d_coeff_val = sess.run([dx, d_coeff])
            # 4 + 6*x
            self.assertAllClose(16., dx_val)
            self.assertAllClose([3., 4., 4., 0.], d_coeff_val)
Example #13
  def testBasic(self):

    with self.session() as sess:

      theta = py_utils.NestedMap()
      theta.x = tf.constant(2.0)
      state = py_utils.NestedMap()
      state.value = tf.constant(0.0)
      state.x_power = tf.constant(1.0)
      inputs = py_utils.NestedMap()
      inputs.coeff = tf.constant([1., 2., 3.])

      def Poly(theta, state, inputs):
        next_state = py_utils.NestedMap()
        next_state.value = state.value + inputs.coeff * state.x_power
        next_state.x_power = state.x_power * theta.x
        return next_state, py_utils.NestedMap()

      # x = 2
      # 1 + 2*x + 3*x^2
      ret = recurrent.Recurrent(theta, state, inputs, Poly)

      acc, state = sess.run(ret)
      self.assertAllClose(acc.value, [1., 5., 17.])
      self.assertAllClose(acc.x_power, [2., 4., 8.])
      self.assertAllClose(state.value, 17.)
      self.assertAllClose(state.x_power, 8.)

      y = ret[1].value
      dx, d_coeff = tf.gradients(ys=[y], xs=[theta.x, inputs.coeff])
      dx_val, d_coeff_val = sess.run([dx, d_coeff])

      # 2 + 6*x
      self.assertAllClose(dx_val, 14.)
      self.assertAllClose(d_coeff_val, [1., 2., 4.])

      # acc = [1, 1+2x, 1+2x+3x^2]
      # sum(acc) = 3 + 4x + 3x^2
      acc = ret[0].value
      dx, d_coeff = tf.gradients(
          ys=[tf.reduce_sum(acc)], xs=[theta.x, inputs.coeff])
      dx_val, d_coeff_val = sess.run([dx, d_coeff])
      # 4 + 6*x
      self.assertAllClose(dx_val, 16.)
      self.assertAllClose(d_coeff_val, [3., 4., 4.])
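# A plain-Python sanity check (illustrative only, not part of the original
# test) of the hard-coded expectations in testBasic above, with x = 2 and
# coeff = [1, 2, 3]. The recurrence is:
#   value_t   = value_{t-1} + coeff_t * x_power_{t-1}
#   x_power_t = x_power_{t-1} * x
x, c = 2.0, [1.0, 2.0, 3.0]
acc_value = [c[0], c[0] + c[1] * x, c[0] + c[1] * x + c[2] * x**2]
assert acc_value == [1.0, 5.0, 17.0]
# Final value = c1 + c2*x + c3*x^2, so d/dx = c2 + 2*c3*x = 2 + 6x = 14 and
# d/dcoeff = [1, x, x^2] = [1, 2, 4].
assert c[1] + 2 * c[2] * x == 14.0
# sum(acc) = 3*c1 + 2*c2*x + c3*x^2, so d/dx = 2*c2 + 2*c3*x = 4 + 6x = 16 and
# d/dcoeff = [3, 2x, x^2] = [3, 4, 4].
assert 2 * c[1] + 2 * c[2] * x == 16.0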
Example #14
  def FProp(self, theta, in_nmap):
    p = self.params
    if not p.remat:
      return self._FProp(theta, in_nmap)

    def CellFn(theta, state0, unused_inputs):
      out_nmap = self._FProp(theta, state0)
      return out_nmap, py_utils.NestedMap()

    _, state1 = recurrent.Recurrent(
        theta=theta,
        state0=in_nmap,
        inputs=py_utils.NestedMap(
            inputs=tf.zeros([1, 0])),  # A dummy input of shape [T, ?].
        cell_fn=CellFn,
        allow_implicit_capture=p.allow_implicit_capture)

    return state1
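# A minimal standalone sketch of the rematerialization trick used above: the
# dummy input has leading dimension 1, so Recurrent invokes the cell exactly
# once and the wrapped forward pass is recomputed during backprop rather than
# keeping its activations. `RematWrap` and `fn` are hypothetical names; only
# py_utils and recurrent from lingvo.core (and lingvo's TF1-style `tf`) are
# assumed.
from lingvo import compat as tf
from lingvo.core import py_utils
from lingvo.core import recurrent


def RematWrap(fn, nmap_in):
  """Runs `fn(nmap_in)` once under Recurrent for gradient rematerialization."""

  def CellFn(unused_theta, state0, unused_inputs):
    # `fn` must map a NestedMap of tensors to a NestedMap of the same structure.
    return fn(state0), py_utils.NestedMap()

  _, state1 = recurrent.Recurrent(
      theta=py_utils.NestedMap(),
      state0=nmap_in,
      inputs=py_utils.NestedMap(inputs=tf.zeros([1, 0])),  # T = 1: one step.
      cell_fn=CellFn,
      allow_implicit_capture=True)
  return state1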
Example #15
  def testStatefulCellFn(self):

    def Rand(theta, state, inputs):
      del theta
      next_state = py_utils.NestedMap()
      next_state.value = (
          state.value +
          inputs.coeff * tf.random_uniform(shape=[], dtype=state.value.dtype))
      return next_state, py_utils.NestedMap()

    with self.session():
      theta = py_utils.NestedMap()
      state = py_utils.NestedMap()
      state.value = tf.constant(0.0)
      inputs = py_utils.NestedMap()
      inputs.coeff = tf.constant([1., 2., 3.])

      with self.assertRaisesRegexp(ValueError, 'stateful.*random_uniform'):
        recurrent.Recurrent(theta, state, inputs, Rand, check_stateful_ops=True)
Example #16
  def FProp(self, theta, inputs, paddings):
    p = self.params
    if not p.remat:
      return self._FProp(theta, inputs, paddings)

    def CellFn(theta, state0, unused_inputs):
      outs, out_paddings = self._FProp(theta, state0.inputs, state0.paddings)
      return py_utils.NestedMap(
          inputs=outs, paddings=out_paddings), py_utils.NestedMap()

    state0 = py_utils.NestedMap(inputs=inputs, paddings=paddings)
    _, state1 = recurrent.Recurrent(
        theta=theta,
        state0=state0,
        inputs=py_utils.NestedMap(
            inputs=tf.zeros([1, 0])),  # A dummy input of shape [T, ?].
        cell_fn=CellFn,
        allow_implicit_capture=p.allow_implicit_capture)

    return state1.inputs, state1.paddings
Example #17
  def EmbLookup(self, theta, ids, state0):
    """Use this layer like a regular EmbeddingLayer (ids -> embeddings).

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      ids: A tensor of token ids with dimensions [t, b].
      state0: A NestedMap containing the state of previous tokens.
      - prev_ids: A Tensor containing the n previous token ids. [batch,
        num_prev_tokens]. Each row is the token ids at t-1, ..., t-n.
      - embeddings: The output embeddings of the previous step.

    Returns:
      Embeddings, time major [t, b, d], and the next state.

    """

    def _FPropWrapper(theta, state0, inputs):
      embedding, state1 = self.FProp(
          theta,
          None,
          inputs,
          None,
          state0,
      )
      state1.embedding = embedding.output
      return state1, py_utils.NestedMap()

    steps_input = py_utils.NestedMap(inputs=[tf.cast(ids, tf.float32)])
    state1, _ = recurrent.Recurrent(
        theta=theta,
        state0=state0,
        inputs=steps_input,
        cell_fn=_FPropWrapper,
    )
    # for training, we want the output across all timesteps
    sequence_of_embeddings = state1.embedding
    # for inference, we want just the last timestep's state, not all states
    state1.embedding = state1.embedding[-1]
    state1.prev_ids = state1.prev_ids[-1]
    return sequence_of_embeddings, state1
Example #18
    def testDropoutInRecurrent(self, graph_seed):
        with self.session() as sess:
            if graph_seed:
                tf.random.set_seed(12345)
            l = lingvo_layers.DeterministicDropoutLayer.Params().Set(
                name='dropout', keep_prob=0.7).Instantiate()
            # Input variable.
            w = tf.get_variable('w',
                                shape=[9, 20],
                                initializer=tf.ones_initializer())
            sess.run(tf.global_variables_initializer())
            prev_sum = np.sum(np.isclose(sess.run(w), 0.0))

            def Step(theta, state0, unused_inputs):
                w = l.FProp(theta.l, state0.w)
                state1 = py_utils.NestedMap(w=w)
                return state1, py_utils.NestedMap()

            acc, final = recurrent.Recurrent(
                theta=py_utils.NestedMap(l=l.theta),
                state0=py_utils.NestedMap(w=w),
                inputs=py_utils.NestedMap(x=tf.zeros([4])),
                cell_fn=Step)

            acc_w = sess.run(acc.w)
            self.assertLen(acc_w, 4)
            for acc_w_i in acc_w:
                next_sum = np.sum(np.isclose(acc_w_i, 0.0))
                self.assertGreater(next_sum, prev_sum)
                prev_sum = next_sum

            # Construct loss function such that gradients = final activation.
            loss = tf.reduce_sum(final.w)
            grads = py_utils.ComputeGradients(loss, py_utils.NestedMap(w=w))
            w_val, grads_val = sess.run([final.w, grads.w.grad])
            self.assertAllClose(w_val, grads_val)
Example #19
    def _Repeat(self,
                theta,
                fn,
                *,
                common_input=None,
                layerwise_inputs=None,
                iterative_input_0=None,
                layerwise_output_0=None):
        """Invokes 'fn' for each sub-layer.

    Args:
      theta: layer parameters.
      fn: A function with args (theta, common_input, layerwise_input, iterative)
        returning a NestedMap(layerwise_output=..., iterative=...).
      common_input: a NestedMap with common inputs for every sub-layer.
      layerwise_inputs: a nested tensor with separate inputs for each sub-layer,
        where each leaf value T is a tensor of shape [p.repeat, ...] and T[i,
        ...] represents layerwise inputs to the i'th sub-layer.
      iterative_input_0: a nested tensor for the iterative input of the 0'th
        sub-layer.
      layerwise_output_0: a nested tensor representing the layerwise output of
        the first sub layer. The actual tensor values are not used. Only the
        structure and tensor shapes are. In particular, if layerwise_output_0 is
        None, calls fn(...) to compute the output.

    Returns:
      A NestedMap with:
      - iterative: a nested tensor with the same structure as iterative_input_0
          representing the iterative output of the last sub-layer.
      - layerwise: a nested tensor where each leaf value T is a tensor of shape
          [p.repeat, ...] and T[i, ...] represents layerwise output from the
          i'th sub-layer.
    """
        p = self.params

        if common_input is None:
            common_input = py_utils.NestedMap()
        if layerwise_inputs is None:
            layerwise_inputs = py_utils.NestedMap()
        if iterative_input_0 is None:
            iterative_input_0 = py_utils.NestedMap()

        if p.per_layer_vars:
            all_iters = [theta['body_iter_%05d' % i] for i in range(p.repeat)]
            theta_stack = py_utils.NestedMap(body=tf.nest.map_structure(
                lambda *t: tf.stack(list(t)), *all_iters))
        else:
            theta_stack = self._MaybeStackExtraTheta(theta)
        theta_0 = tf.nest.map_structure(lambda t: t[0, ...], theta_stack)
        layerwise_input_0 = tf.nest.map_structure(lambda t: t[0, ...],
                                                  layerwise_inputs)

        if layerwise_output_0 is None:
            # Run 'fn' for sublayer 0 to get a template for layerwise output.
            layerwise_output_0 = fn(
                theta_0,
                common_input=common_input,
                layerwise_input=layerwise_input_0,
                iterative=iterative_input_0).layerwise_output

        # Split 'common_input' to static vs. tensor values.
        static_common_input = tf.nest.map_structure(
            lambda x: None if isinstance(x, tf.Tensor) else x, common_input)
        dummy_tensor = tf.zeros([], dtype=p.dtype)
        tensor_common_input = tf.nest.map_structure(
            lambda x: x
            if isinstance(x, tf.Tensor) else dummy_tensor, common_input)

        def _ToRecurState(*, iterative, layerwise_output):
            """Returns a recurrent state."""
            # Make sure each value in the NestedMap is a tensor.
            recur_state = py_utils.NestedMap(iterative=iterative,
                                             layerwise_output=layerwise_output)
            for key, value in recur_state.FlattenItems():
                if not isinstance(value,
                                  tf.Tensor) or value.shape.rank is None:
                    raise ValueError(
                        'Each value in the recurrent state must be a tensor '
                        f'with a known rank: {key}={value}')
            return recur_state

        def _CellFn(recur_theta, recur_state0, recur_input_i):
            """Recurrent cell function wrapper of body.FProp."""
            # Retrieves fprop arguments from state and sets shapes.
            theta_i = _CopyShapes(theta_0, recur_input_i.theta)
            # Use non-tensor values from 'static_common_input' if available.
            merged_common_input = tf.nest.map_structure(
                (lambda x, s, t: t if isinstance(x, tf.Tensor) else s),
                common_input, static_common_input, recur_theta)

            # Call 'fn'.
            fn_results = fn(theta_i,
                            common_input=merged_common_input,
                            layerwise_input=_CopyShapes(
                                layerwise_input_0, recur_input_i.layerwise),
                            iterative=_CopyShapes(iterative_input_0,
                                                  recur_state0.iterative))

            # Pack results.
            return _ToRecurState(**fn_results), py_utils.NestedMap()

        with tf.name_scope('repeat'):
            recur_state0 = _ToRecurState(iterative=iterative_input_0,
                                         layerwise_output=layerwise_output_0)
            # Runs _CellFn (which invokes fn) p.repeat times using Recurrent,
            # where p.repeat = dim 0 of theta_stack.
            acc_states, final_recur_state = recurrent.Recurrent(
                theta=tensor_common_input,
                state0=recur_state0,
                inputs=py_utils.NestedMap(theta=theta_stack,
                                          layerwise=layerwise_inputs),
                cell_fn=_CellFn,
                allow_implicit_capture=p.allow_implicit_capture)
            # Retrieves iterative outputs from final_recur_state and sets shapes.
            iterative_output = _CopyShapes(iterative_input_0,
                                           final_recur_state.iterative)
            # Set shapes of layer-wise outputs according to layerwise_output_0.
            layerwise_outputs = acc_states.layerwise_output
            tf.nest.map_structure(
                lambda t_stack, t_0: t_stack.set_shape(
                    [p.repeat] + t_0.shape.as_list()), layerwise_outputs,
                layerwise_output_0)
            return py_utils.NestedMap(iterative=iterative_output,
                                      layerwise=layerwise_outputs)
Example #20
  def Sample(self, decoder_theta, encoder_outputs, random_seed,
             init_state_callback, pre_step_callback, post_step_callback):
    """Samples target sequences, one target sequence per source sequence.

    (Please see beam_search_helper.py for description of decoder callbacks.)

    Args:
      decoder_theta: A NestedMap object containing weights' values of the
        decoder layer and its children layers, to be passed to decoder
        callbacks.
      encoder_outputs: the outputs of the encoder, to be passed to callbacks.
      random_seed: a scalar int32 tensor representing the random seed.
      init_state_callback: decoder._InitBeamSearchStateCallback.
      pre_step_callback: decoder._PreBeamSearchStepCallback.
      post_step_callback: decoder._PostBeamSearchStepCallback.

    Returns:
      A NestedMap containing the following tensors:
      - 'logits': [batch, max_target_length, vocab_size], representing the
        distribution from which target sequences are sampled.
      - 'ids': [batch, max_target_length] of int32, representing the target
        sequence ids, not including target_sos_id, but maybe ending with
        target_eos_id if end-of-sequence is reached before target_seq_len.
      - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
        a padded timestep.
    """
    p = self.params
    assert p.temperature > 0
    # 'recurrent_theta' represents all cross-timestep information used by the
    # recurrent loop below, including layer theta and encoder outputs.
    recurrent_theta = py_utils.NestedMap(
        theta=decoder_theta,
        random_seed=random_seed,
        encoder_outputs=encoder_outputs)
    bs_result, bs_state = init_state_callback(
        recurrent_theta.theta, encoder_outputs, num_hyps_per_beam=1)
    batch = tf.shape(bs_result.log_probs)[0]
    recurrent_state0 = py_utils.NestedMap(
        timestep=tf.zeros(shape=[], dtype=tf.int32),
        logits=bs_result.log_probs,
        # Start with target_sos_id.
        ids=tf.fill([batch], tf.to_int32(p.target_sos_id)),
        bs_state=bs_state)
    inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch]))

    def Step(recurrent_theta, state0, inputs):
      """Computes one decoder step."""
      del inputs
      with tf.name_scope('single_sampler_step'):
        # Compute logits and states.
        bs_result, bs_state1 = pre_step_callback(
            recurrent_theta.theta,
            recurrent_theta.encoder_outputs,
            tf.expand_dims(state0.ids, 1),  # [batch, 1].
            state0.bs_state,
            num_hyps_per_beam=1)
        batch = tf.shape(bs_result.log_probs)[0]
        state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
        state1.logits = bs_result.log_probs
        # Sample ids from logits. [batch].
        state1.ids = tf.reshape(
            tf.random.stateless_multinomial(
                state1.logits / p.temperature,
                num_samples=1,
                seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
                output_dtype=state0.ids.dtype,
                name='sample_next_id'), [batch])
        if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
          state1.ids = tf.where(
              tf.logical_and(bs_result.is_last_chunk,
                             tf.equal(state1.ids, p.target_eoc_id)),
              tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
        state1.bs_state = post_step_callback(recurrent_theta.theta,
                                             recurrent_theta.encoder_outputs,
                                             state1.ids, bs_state1)
      return state1, py_utils.NestedMap()

    accumulated_states, _ = recurrent.Recurrent(recurrent_theta,
                                                recurrent_state0, inputs, Step)
    result = py_utils.NestedMap(
        logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
        ids=tf.transpose(accumulated_states.ids))
    result.paddings = tf.cast(
        _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
    # Force ids to be eos_id if the timestep is padded.
    result.ids = tf.where(
        tf.equal(result.paddings, 0), result.ids,
        tf.fill(tf.shape(result.ids), p.target_eos_id))
    static_batch_size = bs_result.log_probs.shape[0]
    result.ids.set_shape([static_batch_size, p.target_seq_len])
    result.paddings.set_shape([static_batch_size, p.target_seq_len])
    return result
Example #21
  def FProp(self, theta, *args):
    p = self.params

    if self.do_eval and p.unrolled_in_eval:
      return self._unrolled_fprop(theta, *args)

    if p.per_layer_vars:
      all_iters = [theta['body_iter_%05d' % i] for i in range(p.repeat)]
      theta_stack = tf.nest.map_structure(lambda *t: tf.stack(list(t)),
                                          *all_iters)
    else:
      theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)

    def _ArgsToState(arg_list):
      """Returns a NestedMap from a list of FProp args."""
      state = py_utils.NestedMap()
      # Maintains a mapping from arg_idx to tensor. The state cannot contain
      # None tensors.
      for idx in range(len(args)):
        if isinstance(arg_list[idx], py_utils.NestedMap):
          # Make sure each value in the NestedMap is a tensor.
          if not all(isinstance(t, tf.Tensor) for t in arg_list[idx].Flatten()):
            raise ValueError(
                'Each value in the input NestedMap must be a tensor.')
        if arg_list[idx] is not None:
          state['_s{}'.format(idx)] = arg_list[idx]
      return state

    def _StateToArgs(state):
      """Returns a list of FProp args from a NestedMap."""
      arg_list = []
      for idx in range(len(args)):
        attr = '_s{}'.format(idx)
        arg_list.append(state[attr] if attr in state else None)
        if isinstance(arg_list[-1], py_utils.NestedMap):
          assert isinstance(args[idx], py_utils.NestedMap)
          py_utils.SetShapes(arg_list[-1], args[idx])
        elif isinstance(arg_list[-1], tf.Tensor):
          if arg_list[-1] is not None:
            arg_list[-1].set_shape(args[idx].shape)
      return arg_list

    def _CellFn(unused_theta, state0, theta_i):
      """Recurrent cell function wrapper of body.FProp."""
      # Retrieves fprop arguments from state and sets shapes.
      fprop_inputs = _StateToArgs(state0)

      # Sets shapes for theta_i as well.
      for dst, src in zip(theta_i.Flatten(), theta_stack.Flatten()):
        if src is not None:
          dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

      # Runs the actual body.FProp
      fprop_outputs = self._body.FProp(theta_i, *fprop_inputs)
      fprop_outputs = _ToTuple(fprop_outputs)
      assert len(fprop_outputs) == len(fprop_inputs)

      # Passes fprop outputs to the next layer through state.
      state1 = _ArgsToState(fprop_outputs)
      return state1, py_utils.NestedMap()

    with tf.name_scope(p.name):
      # Add FProp arg list to state0.
      state0 = _ArgsToState(args)
      # Runs body.FProp k times using Recurrent where k = dim 0 of theta_stack.
      _, state1 = recurrent.Recurrent(
          theta=py_utils.NestedMap(),
          state0=state0,
          inputs=theta_stack,  # Pass cell_fn theta through inputs.
          cell_fn=_CellFn,
          allow_implicit_capture=p.allow_implicit_capture)

      # Retrieves fprop outputs from state1 and sets shapes.
      output_tensors = _StateToArgs(state1)
      return output_tensors[0] if len(args) == 1 else tuple(output_tensors)
Example #22
    def FProp(self, theta, inputs, paddings, state0=None, segment_id=None):
        """Computes LSTM forward pass.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: A single tensor or a tuple of tensors with cardinality equal to
        rnn_cell.inputs_arity. For every input tensor, the first dimension is
        assumed to be time, second dimension batch, and third dimension depth.
      paddings: A tensor. First dim is time, second dim is batch, and third dim
        is expected to be 1.
      state0: If not None, the initial rnn state in a `.NestedMap`. Defaults to
        the cell's zero-state.
      segment_id: A tensor to support packed inputs. First dim is time, second
        dim is batch, and third dim is expected to be 1.

    Returns:
      A tuple (act, state): act is a tensor of [time, batch, dims], and state
      is the final recurrent state.
    """
        p = self.params
        assert isinstance(self.cell, rnn_cell.RNNCell)

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        # Slicing wm to wm_{i,h} outside the loop to get 20% speedup over regular
        # LSTM baseline.
        # Keeping slicing within the loop gives only < 3% speedup.
        cell_theta = theta.cell.copy()
        num_input_nodes = p.cell.num_input_nodes
        cell_theta['wm_i'] = cell_theta.wm[:num_input_nodes, :]
        cell_theta['wm_h'] = cell_theta.wm[num_input_nodes:, :]
        tf.logging.vlog(1, 'cell_theta: %r', cell_theta)
        if p.packed_input:
            assert segment_id is not None
            reset_mask = rnn_layers.GeneratePackedInputResetMask(
                segment_id, is_reverse=False)
            reset_mask = py_utils.HasShape(reset_mask, tf.shape(paddings))
        else:
            reset_mask = tf.zeros_like(paddings)

        if p.reverse:
            inputs = [tf.reverse(x, [0]) for x in inputs]
            paddings = tf.reverse(paddings, [0])
            reset_mask = tf.reverse(reset_mask, [0])

        if not state0:
            batch_size = py_utils.GetShape(paddings)[1]
            state0 = self.cell.zero_state(cell_theta, batch_size)

        # [T, B, H]
        proj_inputs = self.cell.ProjectInputSequence(
            cell_theta, py_utils.NestedMap(act=inputs))
        proj_inputs = py_utils.NestedMap(proj_inputs=proj_inputs,
                                         padding=paddings,
                                         reset_mask=reset_mask)

        acc_state, final_state = recurrent.Recurrent(
            theta=cell_theta,
            state0=state0,
            inputs=proj_inputs,
            cell_fn=self.cell.FPropWithProjectedInput,
            cell_type=self.cell.layer_type,
            accumulator_layer=self,
            allow_implicit_capture=p.allow_implicit_capture)

        act = self.cell.GetOutput(acc_state)
        if p.reverse:
            act = tf.reverse(act, [0])
        return act, final_state
Example #23
    def _BuildStackedRecurrentElman(self, seqlen, trailing_pad_len, batch,
                                    dims, layers):
        tf.set_random_seed(342462)
        np.random.seed(32540)

        seqlen += trailing_pad_len
        dtype = tf.float64

        def CreateTheta():
            return py_utils.NestedMap(
                w=tf.constant(np.random.uniform(0, 0.2, (2 * dims, dims)),
                              dtype=dtype),
                b=tf.constant(np.random.uniform(0, 0.2, (dims, )),
                              dtype=dtype))

        def CreateState0():
            return py_utils.NestedMap(h=tf.constant(np.random.uniform(
                0, 0.2, (batch, dims)),
                                                    dtype=dtype),
                                      padding=tf.constant([[0]] * batch,
                                                          dtype=dtype))

        devices = ['/cpu:0'] * layers
        cell_fns = [self.Elman] * layers
        cell_grads = [self.ElmanGrad] * layers
        cell_outs = [self.ElmanOut] * layers
        cell_out_grads = [self.ElmanOutGrad] * layers
        thetas = [CreateTheta() for _ in range(layers)]
        init_states = [CreateState0() for _ in range(layers)]
        padding = np.zeros((seqlen, batch, 1))
        padding[-trailing_pad_len:, :, :] = 1.
        padding[-trailing_pad_len - 3:-trailing_pad_len - 1, :, :] = 1.
        inputs = py_utils.NestedMap(x=tf.constant(np.random.uniform(
            0, 0.2, (seqlen, batch, dims)),
                                                  dtype=dtype),
                                    padding=tf.constant(padding, dtype=dtype))
        output, _ = recurrent.StackedRecurrent(devices=devices,
                                               cell_fns=cell_fns,
                                               cell_grads=cell_grads,
                                               cell_outs=cell_outs,
                                               cell_out_grads=cell_out_grads,
                                               thetas=thetas,
                                               init_states=init_states,
                                               inputs=inputs)
        o = output.x
        if 'padding' in inputs:
            o *= (1 - inputs.padding)
        loss = tf.reduce_sum(tf.square(o))

        xs = recurrent.Flatten(thetas + [py_utils.NestedMap(x=inputs.x)])
        dxs = tf.gradients(ys=loss, xs=xs)

        # Reference implementation using Recurrent().
        ref = inputs
        for i in range(layers):
            ref = self.ElmanOut(
                recurrent.Recurrent(cell_fn=cell_fns[i],
                                    cell_grad=cell_grads[i],
                                    theta=thetas[i],
                                    state0=init_states[i],
                                    inputs=ref)[0])
        return ref.x, output.x, loss, xs, dxs
Example #24
    def Sample(self,
               decoder_theta,
               encoder_outputs,
               random_seed,
               init_state_callback,
               pre_step_callback,
               post_step_callback,
               init_step_ids=None):
        """Samples target sequences, one target sequence per source sequence.

    (Please see beam_search_helper.py for description of decoder callbacks.)

    Args:
      decoder_theta: A NestedMap object containing weights' values of the
        decoder layer and its children layers, to be passed to decoder
        callbacks.
      encoder_outputs: the outputs of the encoder, to be passed to callbacks.
      random_seed: a scalar int32 tensor representing the random seed.
      init_state_callback: decoder._InitBeamSearchStateCallback.
      pre_step_callback: decoder._PreBeamSearchStepCallback.
      post_step_callback: decoder._PostBeamSearchStepCallback.
      init_step_ids: [batch], optional init step ids, default to SOS.

    Returns:
      A NestedMap containing the following tensors

      - 'logits': [batch, max_target_length, vocab_size], representing the
        distribution from which target sequences are sampled.
      - 'ids': [batch, max_target_length] of int32, representing the target
        sequence ids, not including target_sos_id, but maybe ending with
        target_eos_id if end-of-sequence is reached before target_seq_len.
      - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
        a padded timestep.
    """
        p = self.params
        assert p.temperature > 0
        assert p.top_k >= 0
        assert p.num_hyps_per_beam >= 1
        if getattr(encoder_outputs, 'segment_id', 1) is None:
            # Remove None values, which are not supported by recurrent.
            del encoder_outputs['segment_id']
        # init_state_callback may modify 'encoder_outputs', e.g., by inserting
        # 'packed_src'.
        bs_result, bs_state = init_state_callback(decoder_theta,
                                                  encoder_outputs,
                                                  p.num_hyps_per_beam)
        # 'recurrent_theta' represents all cross-timestep information used by the
        # recurrent loop below, including layer theta and encoder outputs.
        recurrent_theta = py_utils.NestedMap(random_seed=random_seed,
                                             encoder_outputs=encoder_outputs)
        batch = tf.shape(bs_result.log_probs)[0]
        recurrent_state0 = py_utils.NestedMap(
            timestep=tf.zeros(shape=[], dtype=tf.int32),
            logits=bs_result.log_probs,
            # Start with target_sos_id.
            ids=init_step_ids if init_step_ids is not None else tf.fill(
                [batch], tf.cast(p.target_sos_id, tf.int32)),
            bs_state=bs_state)

        if p.use_recurrent:
            inputs = py_utils.NestedMap(
                dummy=tf.zeros([p.target_seq_len, batch]))
        else:
            inputs = py_utils.NestedMap(
                ids=tf.TensorArray(dtype=tf.int32, size=p.target_seq_len),
                logits=tf.TensorArray(dtype=bs_result.log_probs.dtype,
                                      size=p.target_seq_len),
            )

        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step."""
            if p.use_recurrent:
                del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    decoder_theta,
                    recurrent_theta.encoder_outputs,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    num_hyps_per_beam=p.num_hyps_per_beam)
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs

                if p.top_k > 0:
                    topk_logits, topk_ids = tf.math.top_k(state1.logits,
                                                          k=p.top_k)
                    sample_logits = tf.nn.log_softmax(
                        topk_logits) if p.top_k_renormalize else topk_logits
                else:
                    sample_logits = state1.logits

                # Sample ids from logits. [batch].
                ids = tf.reshape(
                    tf.random.stateless_categorical(
                        sample_logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                state1.ids = tf.gather(topk_ids, ids, axis=1,
                                       batch_dims=1) if p.top_k > 0 else ids

                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.math.logical_and(
                            bs_result.is_last_chunk,
                            tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    decoder_theta, recurrent_theta.encoder_outputs, state1.ids,
                    bs_state1)
            if p.use_recurrent:
                return state1, py_utils.NestedMap()
            else:
                inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
                inputs.logits = inputs.logits.write(state0.timestep,
                                                    state1.logits)
                return (recurrent_theta, state1, inputs)

        if p.use_recurrent:

            def StopFn(t, theta, state):
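                """Per-example stop signal: True once EOS has been sampled."""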
                del t, theta  # Unused: this stop function only uses the state ids.
                return tf.equal(state.ids, p.target_eos_id)
        else:

            def StopFn(recurrent_theta, state, inputs):
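                """Continue looping while any example has not sampled EOS."""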
                del recurrent_theta, inputs
                return tf.logical_not(
                    tf.reduce_all(tf.equal(state.ids, p.target_eos_id)))

        if p.use_stop_fn:
            stop_fn = StopFn
        else:
            stop_fn = None

        if p.use_recurrent:
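            # Step closes over decoder_theta and the callbacks, so implicit
            # capture must be allowed inside the functional recurrence.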
            accumulated_states, _ = recurrent.Recurrent(
                recurrent_theta,
                recurrent_state0,
                inputs,
                Step,
                stop_fn=stop_fn,
                allow_implicit_capture=True)
        else:
            loop_vars = (recurrent_theta, recurrent_state0, inputs)
            (_, _, accumulated_states) = tf.while_loop(
                StopFn,
                Step,
                loop_vars=loop_vars,
                shape_invariants=_GetShapes(loop_vars, none_shapes=True),
                back_prop=False,
                maximum_iterations=p.target_seq_len)
            accumulated_states.ids = accumulated_states.ids.stack()
            accumulated_states.logits = accumulated_states.logits.stack()

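        # Accumulated tensors are time-major; transpose to [batch, time, ...].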
        result = py_utils.NestedMap(
            logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
            ids=tf.transpose(accumulated_states.ids))
        result.paddings = tf.cast(
            _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
        # Force ids to be eos_id if the timestep is padded.
        result.ids = tf.where(tf.equal(result.paddings, 0), result.ids,
                              tf.fill(tf.shape(result.ids), p.target_eos_id))
        static_batch_size = bs_result.log_probs.shape[0]
        result.ids.set_shape([static_batch_size, p.target_seq_len])
        result.paddings.set_shape([static_batch_size, p.target_seq_len])
        return result
Example #25
    def testStatefulEmbeddingStep(self):
        with self.session(use_gpu=False):
            tf.random.set_seed(398847392)
            p = embedding_steps.StatefulEmbeddingStep.Params().Set(
                name='emb_step',
                num_prev_tokens=1,
                include_current_token=False,
                target_sos_id=1,
                embedding_dim=3,
            )
            p.emb.Set(
                vocab_size=10,
                embedding_dim=3,
                max_num_shards=1,
                params_init=py_utils.WeightInit.Gaussian(0.01),
            )
            p.emb.vn.global_vn = False
            p.emb.vn.per_step_vn = False
            emb = p.Instantiate()

            # Verify that nothing bad happens when these methods are called.
            packed = emb.PrepareExternalInputs(None, None)
            state0 = emb.ZeroState(emb.theta, packed, 2)

            # Test FProp of the unit
            out1, state1 = emb.FProp(
                emb.theta, packed,
                py_utils.NestedMap(inputs=[tf.constant([4, 3], tf.int32)]),
                tf.constant([0.0], dtype=tf.float32), state0)
            self.evaluate(tf.global_variables_initializer())
            out1, state1 = self.evaluate([out1, state1])

            self.assertAllEqual(state1.prev_ids, np.array([[4], [3]]))
            self.assertAllClose(
                out1.output,
                np.array([[-0.00740041, -0.00746862, 0.00093992],
                          [-0.00740041, -0.00746862, 0.00093992]]))

            # Test FProp and BProp when integrated with Recurrent()
            def _FProp(theta, state0, inputs):
                embedding, state1 = emb.FProp(
                    theta,
                    None,
                    inputs,
                    None,
                    state0,
                )
                state1.embedding = embedding.output
                return state1, py_utils.NestedMap()

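            # Five timesteps of ids (stored as floats) for a batch of two;
            # Recurrent() scans over the leading (time) dimension.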
            inputs = py_utils.NestedMap(inputs=[
                tf.constant([[1., 2.], [3., 2.], [0., 1.], [2., 3.], [3., 0.]])
            ])
            acc, _ = recurrent.Recurrent(
                emb.theta,
                state0,
                inputs,
                _FProp,
            )
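            # tf.gradients sums over all elements of the normalized embedding
            # tensor, since the ys argument is not a scalar.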
            loss = tf.math.l2_normalize(acc.embedding)
            grad = tf.gradients(loss, emb.emb.theta.wm[0])
            self.evaluate(tf.global_variables_initializer())
            acc_, _, grad_ = self.evaluate([acc, emb.emb.theta.wm[0], grad])
            prev_ids_expected = np.array([
                [[1.], [2.]],
                [[3.], [2.]],
                [[0.], [1.]],
                [[2.], [3.]],
                [[3.], [0.]],
            ])
            grad_expected = np.array([[21.952698, 20.50312, 19.037958],
                                      [79.622116, 72.15271, 106.34329],
                                      [41.631985, 70.19292, 75.52608],
                                      [53.644493, 36.28507, 36.64856],
                                      [0., 0., 0.], [0., 0., 0.],
                                      [0., 0., 0.], [0., 0., 0.],
                                      [0., 0., 0.], [0., 0., 0.]])
            self.assertAllClose(acc_.prev_ids, prev_ids_expected)
            self.assertAllClose(grad_[0], grad_expected)
Example #26
    def testBasicWithAccumulator(self):

        with self.session() as sess:

            p = _SampleAccumulatorLayer.Params()
            p.name = 'sample'
            accum_layer = _SampleAccumulatorLayer(p)
            accum_obj = accum_layer.accumulators[accum_layer.accumulator_name]

            theta = py_utils.NestedMap()
            theta.x = tf.constant(2.0)
            state = py_utils.NestedMap()
            state.value = tf.constant(0.0)
            state.x_power = tf.constant(1.0)
            inputs = py_utils.NestedMap()
            inputs.coeff = tf.constant([1., 2., 3.])

            def _CellFn(theta, state, inputs):
                print('TEST ACCUM WITHIN CellFn = ', accum_obj.GetValue())
                accum_obj.Update(inputs.coeff)
                return _Poly(theta, state, inputs)

            # By doing one accumulate prior to recurrent, we ensure that incoming
            # recurrent state is preserved.
            accum_obj.Update(10.)

            # x = 2
            # 1 + 2*x + 3*x^2
            ret = recurrent.Recurrent(theta,
                                      state,
                                      inputs,
                                      _CellFn,
                                      accumulator_layer=accum_layer)
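            # accumulator_layer lets Recurrent() carry the layer's accumulator
            # values through the loop; _CellFn updates it each step.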

            # Verify bprop.
            y = ret[1].value
            dx, d_coeff = tf.gradients(ys=[y], xs=[theta.x, inputs.coeff])
            dx_val, d_coeff_val = sess.run([dx, d_coeff])

            # 2 + 6*x
            self.assertAllClose(dx_val, 14.)
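            # d y / d coeff = [1, x, x^2] = [1., 2., 4.] at x = 2.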
            self.assertAllClose(d_coeff_val, [1., 2., 4.])

            # acc = [1, 1+2x, 1+2x+3x^2]
            # sum(acc) = 3 + 4x + 3x^2
            acc = ret[0].value
            dx, d_coeff = tf.gradients(ys=[tf.reduce_sum(acc)],
                                       xs=[theta.x, inputs.coeff])
            dx_val, d_coeff_val = sess.run([dx, d_coeff])
            # 4 + 6*x
            self.assertAllClose(dx_val, 16.)
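            # d sum(acc) / d coeff = [3, 2*x, x^2] = [3., 4., 4.] at x = 2.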
            self.assertAllClose(d_coeff_val, [3., 4., 4.])

            # Verify fprop.
            (acc, state), accum_obj_value = sess.run(
                (ret, accum_obj.GetValue()))

            # Verify that accumulators don't change fprop results.
            self.assertAllClose(acc.value, [1., 5., 17.])
            self.assertAllClose(acc.x_power, [2., 4., 8.])
            self.assertAllClose(state.value, 17.)
            self.assertAllClose(state.x_power, 8.)

            # Verify accumulator (should be 10 (initial increment) + 1 + 2 + 3).
            self.assertEqual(0, accum_obj._disable_count)
            self.assertAllClose([accum_obj_value], [16.0])