Exemplo n.º 1
0
    def test_mock_transformer_lm_rewinds(self):
        """Check that restoring a saved `model.state` replays the same symbols.

        Samples three symbols, snapshots the model state, samples three more,
        then restores the snapshot and verifies that repeating the second
        sampling pass yields the same output again.
        """
        lm = test_utils.MockTransformerLM(
            sequence_fn=arithmetic_sequence, vocab_size=10, mode='predict'
        )

        def draw_three(start_id):
            # Every sampling call in this test shares these settings; only the
            # starting symbol varies.
            return decoding.autoregressive_sample(
                lm,
                start_id=start_id,
                max_length=3,
                eos_id=-1,
                accelerate=False,
            )

        # First pass: the opening three symbols, starting from 0.
        first_chunk = draw_three(0)
        np.testing.assert_array_equal(first_chunk, [[1, 2, 3]])
        snapshot = lm.state

        # Second pass: continue from the last symbol produced so far.
        second_chunk = draw_three(first_chunk[0, -1])
        np.testing.assert_array_equal(second_chunk, [[4, 5, 6]])

        # Roll the model back to the snapshot and repeat the second pass.
        lm.state = snapshot
        replayed_chunk = draw_three(first_chunk[0, -1])
        np.testing.assert_array_equal(replayed_chunk, [[4, 5, 6]])

        # After the rewind the buffers are expected to hold a single 0..5 run.
        lm.assert_prediction_buffers_equal([[0, 1, 2, 3, 4, 5]])
Exemplo n.º 2
0
  def test_srl_eval_reports_zero_error_for_perfect_model(self, precision):
    """Check that the serialized-model evaluation reports zero prediction error.

    The mock model is driven by the same sequence function that generated the
    evaluated trajectory, so the first metric whose name contains
    'pred_error' is asserted to equal 0.

    Args:
      precision: Length of the MultiDiscrete `nvec`; `2 * precision` is passed
        as the first argument of `make_multibonacci_modulo`.
    """
    vocab_size = 100
    n_steps = 5

    sequence_fn = make_multibonacci_modulo(2 * precision, vocab_size)
    space = gym.spaces.MultiDiscrete(nvec=([vocab_size] * precision))
    (observations, actions) = generate_trajectory(sequence_fn, space, n_steps)
    task = make_singleton_eval_task(observations, actions)
    perfect_model = tl_test_utils.MockTransformerLM(
        sequence_fn=sequence_fn, vocab_size=vocab_size, mode='predict'
    )
    # One serializer instance is shared by observations and actions.
    codec = space_serializer.MultiDiscreteSpaceSerializer(space, vocab_size)

    eval_kwargs = dict(
        loop=None,
        eval_task=task,
        model=perfect_model,
        observation_serializer=codec,
        action_serializer=codec,
        context_lengths=(1,),
        horizon_lengths=(4,),
        accelerate_model=False,
    )
    evaluator = callbacks.SerializedModelEvaluation(**eval_kwargs)
    reported = evaluator.evaluate(weights=None)

    # Take the first reported metric whose name mentions 'pred_error'.
    error = next(
        metric_value
        for (metric_name, metric_value) in reported.items()
        if 'pred_error' in metric_name
    )
    assert error == 0
Exemplo n.º 3
0
  def test_srl_eval_feeds_correct_sequence(
      self, context_lengths, horizon_lengths
  ):
    """Check which symbols the evaluation callback feeds into the model.

    After running the evaluation, the mock model's prediction buffers are
    compared against a prefix of the interleaved sequence
    `[0, obs[0], act[0], obs[1], act[1], ...]`.

    Args:
      context_lengths: Passed through to `SerializedModelEvaluation`; its last
        entry contributes to the expected buffer length.
      horizon_lengths: Passed through to `SerializedModelEvaluation`; its last
        entry contributes to the expected buffer length.
    """
    vocab_size = 10
    n_steps = 5

    sequence_fn = make_multibonacci_modulo(2, vocab_size)
    space = gym.spaces.Discrete(n=vocab_size)
    (observations, actions) = generate_trajectory(sequence_fn, space, n_steps)
    task = make_singleton_eval_task(observations, actions)
    mock_lm = tl_test_utils.MockTransformerLM(
        sequence_fn=sequence_fn, vocab_size=vocab_size, mode='predict'
    )
    # One serializer instance is shared by observations and actions.
    codec = space_serializer.DiscreteSpaceSerializer(space, vocab_size)

    eval_kwargs = dict(
        loop=None,
        eval_task=task,
        model=mock_lm,
        observation_serializer=codec,
        action_serializer=codec,
        context_lengths=context_lengths,
        horizon_lengths=horizon_lengths,
        accelerate_model=False,
    )
    callbacks.SerializedModelEvaluation(**eval_kwargs).evaluate(weights=None)

    # Leading 0, then observations at odd positions and actions at even ones.
    interleaved = np.zeros(2 * n_steps + 1)
    interleaved[1::2] = observations
    interleaved[2::2] = actions
    # Two symbols per step over the last context length plus the last horizon.
    n_seen = 2 * (context_lengths[-1] + horizon_lengths[-1])
    mock_lm.assert_prediction_buffers_equal([interleaved[:n_seen]])
Exemplo n.º 4
0
    def test_mock_transformer_lm_decodes_arithmetic_sequence(self):
        """Check sampled output and prediction buffers for the mock LM.

        Samples five symbols starting from 0 and verifies both the returned
        symbols and the contents of the model's prediction buffers.
        """
        lm = test_utils.MockTransformerLM(
            sequence_fn=arithmetic_sequence, vocab_size=10, mode='predict'
        )
        sampled = decoding.autoregressive_sample(
            lm, max_length=5, start_id=0, eos_id=-1, accelerate=False
        )

        # 0..5: the start symbol 0 followed by the five sampled symbols.
        whole_sequence = list(range(6))
        # The sampler's return value omits the leading start symbol.
        expected_output = whole_sequence[1:]
        # The prediction buffers omit the final predicted symbol.
        expected_buffer = whole_sequence[:-1]

        np.testing.assert_array_equal(sampled, [expected_output])
        lm.assert_prediction_buffers_equal([expected_buffer])