Example #1
0
  def testStepWithInferenceHelperMultilabel(self):
    """Checks one BasicDecoder step driven by a multilabel InferenceHelper.

    The helper draws an independent boolean sample per vocabulary entry from
    the cell output, feeds the samples back (cast to float) as the next input,
    and marks a sequence finished when the `end_token` entry is sampled. The
    test verifies the decoder's static output sizes/dtypes, the static shapes
    of the step results, and that the evaluated `finished` flags and next
    inputs agree with the evaluated sample ids.
    """
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size
    start_token = 0
    end_token = 6

    start_inputs = array_ops.one_hot(
        np.ones(batch_size) * start_token,
        vocabulary_size)

    # The sample function samples independent bernoullis from the logits.
    sample_fn = (
        lambda x: helper_py.bernoulli_sample(logits=x, dtype=dtypes.bool))
    # The next inputs are the sampled boolean labels cast to float (a
    # multi-hot encoding, since several labels may be set at once).
    next_inputs_fn = lambda x: math_ops.cast(x, dtypes.float32)
    # A sequence is finished once the end_token label has been sampled.
    end_fn = lambda sample_ids: sample_ids[:, end_token]

    with self.session(use_gpu=True) as sess:
      with variable_scope.variable_scope(
          "testStepWithInferenceHelper",
          initializer=init_ops.constant_initializer(0.01)):
        cell = rnn_cell.LSTMCell(vocabulary_size)
        helper = helper_py.InferenceHelper(
            sample_fn, sample_shape=[cell_depth], sample_dtype=dtypes.bool,
            start_inputs=start_inputs, end_fn=end_fn,
            next_inputs_fn=next_inputs_fn)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))
        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(cell_depth, cell_depth),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.bool),
            output_dtype)

        (first_finished, first_inputs, first_state) = my_decoder.initialize()
        (step_outputs, step_state, step_next_inputs,
         step_finished) = my_decoder.step(
             constant_op.constant(0), first_inputs, first_state)
        batch_size_t = my_decoder.batch_size

        # Static (graph-construction time) structure and shape checks.
        self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
        self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
        self.assertEqual((batch_size, cell_depth), step_outputs[1].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished
        })

        self.assertEqual(batch_size, sess_results["batch_size"])

        # Sampling is random, so the expectations are derived from the ids
        # that were actually drawn in this run rather than hard-coded.
        sample_ids = sess_results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
        expected_step_finished = sample_ids[:, end_token]
        expected_step_next_inputs = sample_ids.astype(np.float32)
        self.assertAllEqual(expected_step_finished,
                            sess_results["step_finished"])
        self.assertAllEqual(expected_step_next_inputs,
                            sess_results["step_next_inputs"])
Example #2
0
  def testStepWithInferenceHelperCategorical(self):
    """Checks one BasicDecoder step driven by a categorical InferenceHelper.

    The helper draws a single integer id per batch entry from the cell output,
    feeds its one-hot encoding back as the next input, and marks a sequence
    finished when the sampled id equals `end_token`. The test verifies the
    decoder's static output sizes/dtypes, the static shapes of the step
    results, and that the evaluated `finished` flags and next inputs agree
    with the evaluated sample ids.
    """
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size
    start_token = 0
    end_token = 6

    start_inputs = array_ops.one_hot(
        np.ones(batch_size) * start_token,
        vocabulary_size)

    # The sample function samples categorically from the logits.
    sample_fn = lambda x: helper_py.categorical_sample(logits=x)
    # The next inputs are a one-hot encoding of the sampled labels.
    next_inputs_fn = (
        lambda x: array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32))
    # A sequence is finished once end_token itself has been sampled.
    end_fn = lambda sample_ids: math_ops.equal(sample_ids, end_token)

    with self.session(use_gpu=True) as sess:
      with variable_scope.variable_scope(
          "testStepWithInferenceHelper",
          initializer=init_ops.constant_initializer(0.01)):
        cell = rnn_cell.LSTMCell(vocabulary_size)
        helper = helper_py.InferenceHelper(
            sample_fn, sample_shape=(), sample_dtype=dtypes.int32,
            start_inputs=start_inputs, end_fn=end_fn,
            next_inputs_fn=next_inputs_fn)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))
        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(cell_depth,
                                             tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            output_dtype)

        (first_finished, first_inputs, first_state) = my_decoder.initialize()
        (step_outputs, step_state, step_next_inputs,
         step_finished) = my_decoder.step(
             constant_op.constant(0), first_inputs, first_state)
        batch_size_t = my_decoder.batch_size

        # Static (graph-construction time) structure and shape checks.
        self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
        self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished
        })

        self.assertEqual(batch_size, sess_results["batch_size"])

        # Sampling is random, so the expectations are derived from the ids
        # that were actually drawn in this run rather than hard-coded.
        sample_ids = sess_results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
        expected_step_finished = (sample_ids == end_token)
        # Rebuild the one-hot encoding in NumPy: one 1.0 per row, placed at
        # the column given by that row's sampled id.
        expected_step_next_inputs = np.zeros((batch_size, vocabulary_size))
        expected_step_next_inputs[np.arange(batch_size), sample_ids] = 1.0
        self.assertAllEqual(expected_step_finished,
                            sess_results["step_finished"])
        self.assertAllEqual(expected_step_next_inputs,
                            sess_results["step_next_inputs"])