Ejemplo n.º 1
0
    def test_train_model_with_serialization(self):
        # Runs two training steps of a serialized Transformer LM and checks
        # that the returned trainer state reports step 2.
        series_count = 4

        # The serializer takes care of discretizing the data; one Box
        # dimension per time series.
        observation_space = gym.spaces.Box(
            shape=(series_count,), low=0.0, high=16.0)
        serializer = space_serializer.BoxSpaceSerializer(
            space=observation_space, vocab_size=16, precision=2)

        def model(mode):
            # The same serializer is used for observations and actions.
            language_model = trax_models.TransformerLM(
                mode=mode,
                vocab_size=16,
                d_model=16,
                d_ff=8,
                n_layers=1,
                n_heads=1,
            )
            return serialization_utils.SerializedHalfModel(
                language_model,
                observation_serializer=serializer,
                action_serializer=serializer,
                significance_decay=0.9,
            )

        train_dir = self.create_tempdir().full_path
        make_inputs = functools.partial(
            signal_inputs,
            seq_len=5,
            batch_size=64,
            depth=series_count,
        )
        final_state = trainer_lib.train(
            output_dir=train_dir,
            model=model,
            inputs=make_inputs,
            steps=2,
        )
        self.assertEqual(2, final_state.step)
Ejemplo n.º 2
0
    def test_serialized_model_evaluation(self, mock_stdout):
        # Trains a tiny serialized Transformer LM for ten steps with a
        # SerializedModelEvaluation callback attached, then checks through
        # _has_metric that 'pred_error' shows up in mock_stdout.
        vocab_size = 2
        serializer = space_serializer.BoxSpaceSerializer(
            space=gym.spaces.Box(shape=(), low=0.0, high=1.0),
            vocab_size=vocab_size,
            precision=1,
        )
        # Observations and actions share one serializer everywhere below.
        serializer_kwargs = {
            'observation_serializer': serializer,
            'action_serializer': serializer,
        }

        def bare_model(mode):
            # Unwrapped language model; also handed to the callback directly.
            return models.TransformerLM(
                mode=mode,
                vocab_size=vocab_size,
                d_model=2,
                d_ff=4,
                n_layers=1,
                n_heads=1,
            )

        def model(mode):
            # Language model wrapped in the serialization machinery.
            return serialization_utils.SerializedModel(
                bare_model(mode),
                significance_decay=0.7,
                **serializer_kwargs
            )

        # The callback receives an already-built model in 'predict' mode.
        predict_model = bare_model('predict')
        eval_callback = functools.partial(
            callbacks.SerializedModelEvaluation,
            model=predict_model,
            eval_at=5,
            **serializer_kwargs
        )

        train_dir = self.create_tempdir().full_path
        train_inputs = functools.partial(
            dummy_inputs, seq_len=4, batch_size=64)
        constant_lr = functools.partial(lr_schedules.constant, 0.01)
        trainer_lib.train(
            output_dir=train_dir,
            model=model,
            inputs=train_inputs,
            lr_schedule_fn=constant_lr,
            callbacks=[eval_callback],
            steps=10,
        )

        metric_reported = _has_metric('pred_error', mock_stdout)
        self.assertTrue(metric_reported)
Ejemplo n.º 3
0
def TimeSeriesModel(
    seq_model,
    low=0.0,
    high=1.0,
    precision=2,
    vocab_size=64,
    significance_decay=0.7,
    mode='train',
):
    """Build a SerializedModel for scalar time series prediction.

    Args:
        seq_model: Callable that is partially applied with `vocab_size` and
            then handed to `SerializedModel` as its first argument.
        low: Lower bound of the scalar Box observation space.
        high: Upper bound of the scalar Box observation space.
        precision: `precision` argument of the observation serializer.
        vocab_size: Vocabulary size given to both the observation serializer
            and `seq_model`.
        significance_decay: Forwarded unchanged to `SerializedModel`.
        mode: Forwarded unchanged to `SerializedModel`.

    Returns:
        The result of the `SerializedModel` call.
    """
    # Observations are a scalar time series, hence the shape-() Box space.
    scalar_space = gym.spaces.Box(shape=(), low=low, high=high)
    observation_serializer = space_serializer.BoxSpaceSerializer(
        space=scalar_space, vocab_size=vocab_size, precision=precision)

    # Some actions must be provided, so use a one-element discrete space as
    # a placeholder.
    # TODO(pkozakowski): Remove this requirement.
    placeholder_action_space = gym.spaces.Discrete(n=1)
    action_serializer = space_serializer.DiscreteSpaceSerializer(
        space=placeholder_action_space, vocab_size=1)

    seq_model_with_vocab = functools.partial(seq_model, vocab_size=vocab_size)
    return SerializedModel(
        seq_model_with_vocab,
        observation_serializer,
        action_serializer,
        significance_decay,
        mode,
    )