def test_train_model_with_serialization(self):
    """Trains a SerializedHalfModel on multi-series signal inputs for 2 steps.

    Checks that `trainer_lib.train` runs end to end with a Box-space
    serializer attached and reports the expected final step count.
    """
    # Discretization of the continuous data is delegated to the serializer.
    series_count = 4
    serializer = space_serializer.BoxSpaceSerializer(
        space=gym.spaces.Box(shape=(series_count,), low=0.0, high=16.0),
        vocab_size=16,
        precision=2,
    )

    def build_model(mode):
        # A tiny TransformerLM keeps the test cheap; the same serializer is
        # reused for both observations and actions.
        transformer = trax_models.TransformerLM(
            mode=mode,
            vocab_size=16,
            d_model=16,
            d_ff=8,
            n_layers=1,
            n_heads=1,
        )
        return serialization_utils.SerializedHalfModel(
            transformer,
            observation_serializer=serializer,
            action_serializer=serializer,
            significance_decay=0.9,
        )

    tmp_dir = self.create_tempdir().full_path
    inputs_fn = functools.partial(
        signal_inputs, seq_len=5, batch_size=64, depth=series_count)
    final_state = trainer_lib.train(
        output_dir=tmp_dir,
        model=build_model,
        inputs=inputs_fn,
        steps=2,
    )
    self.assertEqual(2, final_state.step)
def test_serialized_model_evaluation(self, mock_stdout):
    """Trains a SerializedModel with a SerializedModelEvaluation callback.

    After 10 training steps (evaluation scheduled via `eval_at=5`), asserts
    that `_has_metric('pred_error', mock_stdout)` is true.

    Args:
      mock_stdout: presumably a patched stdout injected by a decorator that is
        not visible here; it is handed to `_has_metric`. TODO confirm.
    """
    precision = 1
    n_tokens = 2
    serializer = space_serializer.BoxSpaceSerializer(
        space=gym.spaces.Box(shape=(), low=0.0, high=1.0),
        vocab_size=n_tokens,
        precision=precision,
    )

    def build_transformer(mode):
        return models.TransformerLM(
            mode=mode,
            vocab_size=n_tokens,
            d_model=2,
            d_ff=4,
            n_layers=1,
            n_heads=1,
        )

    def build_serialized_model(mode):
        return serialization_utils.SerializedModel(
            build_transformer(mode),
            observation_serializer=serializer,
            action_serializer=serializer,
            significance_decay=0.7,
        )

    # The evaluation callback receives an already-built model in 'predict'
    # mode, while training builds its own via build_serialized_model.
    evaluation_callback = functools.partial(
        callbacks.SerializedModelEvaluation,
        model=build_transformer('predict'),
        observation_serializer=serializer,
        action_serializer=serializer,
        eval_at=5,
    )
    tmp_dir = self.create_tempdir().full_path
    inputs_fn = functools.partial(dummy_inputs, seq_len=4, batch_size=64)
    schedule_fn = functools.partial(lr_schedules.constant, 0.01)
    trainer_lib.train(
        output_dir=tmp_dir,
        model=build_serialized_model,
        inputs=inputs_fn,
        lr_schedule_fn=schedule_fn,
        callbacks=[evaluation_callback],
        steps=10,
    )
    self.assertTrue(_has_metric('pred_error', mock_stdout))
def TimeSeriesModel(
    seq_model,
    low=0.0,
    high=1.0,
    precision=2,
    vocab_size=64,
    significance_decay=0.7,
    mode='train',
):
    """Simplified constructor for SerializedModel, for time series prediction.

    Builds a serializer for scalar observations bounded by [low, high] and a
    trivial one-action serializer. It then binds `vocab_size` into
    `seq_model` and delegates to `SerializedModel`.

    Args:
      seq_model: Callable producing the sequence model. It is partially
        applied with `vocab_size=vocab_size` before being passed on.
      low: Lower bound of the scalar observation Box space.
      high: Upper bound of the scalar observation Box space.
      precision: Precision passed to the observation BoxSpaceSerializer.
      vocab_size: Vocabulary size for both the observation serializer and
        `seq_model`.
      significance_decay: Forwarded unchanged to SerializedModel.
      mode: Forwarded unchanged to SerializedModel.

    Returns:
      The result of `SerializedModel(...)` for these arguments.
    """
    # Scalar (shape ()) observations: one value per time step.
    observation_space = gym.spaces.Box(shape=(), low=low, high=high)
    observation_serializer = space_serializer.BoxSpaceSerializer(
        space=observation_space,
        vocab_size=vocab_size,
        precision=precision,
    )
    # SerializedModel requires an action serializer even though a time series
    # has no actions, so supply one over a single-element discrete space.
    # TODO(pkozakowski): Remove this requirement.
    action_serializer = space_serializer.DiscreteSpaceSerializer(
        space=gym.spaces.Discrete(n=1), vocab_size=1)
    bound_seq_model = functools.partial(seq_model, vocab_size=vocab_size)
    return SerializedModel(
        bound_seq_model,
        observation_serializer,
        action_serializer,
        significance_decay,
        mode,
    )