def test_parametric_wrapper(self):
    """Check ParametricDqnPredictorWrapper against a manual critic evaluation.

    Builds a FullyConnectedCritic over 4 state features (ids 1-4) and 4 action
    features (ids 5-8), wraps it with state/action preprocessors and the
    predictor wrapper, and verifies that the wrapper reports a single "Q"
    output of shape (1, 1) equal to calling the critic directly on the
    preprocessed input prototype.
    """
    state_norm = {feature_id: _cont_norm() for feature_id in range(1, 5)}
    action_norm = {feature_id: _cont_norm() for feature_id in range(5, 9)}
    state_pp = Preprocessor(state_norm, False)
    action_pp = Preprocessor(action_norm, False)

    critic = models.FullyConnectedCritic(
        state_dim=len(state_norm),
        action_dim=len(action_norm),
        sizes=[16],
        activations=["relu"],
    )
    critic_with_pp = ParametricDqnWithPreprocessor(
        critic,
        state_preprocessor=state_pp,
        action_preprocessor=action_pp,
    )
    predictor = ParametricDqnPredictorWrapper(critic_with_pp)

    prototype = critic_with_pp.input_prototype()
    returned_names, q = predictor(*prototype)
    self.assertEqual(returned_names, ["Q"])
    self.assertEqual(q.shape, (1, 1))

    # Reference: preprocess state and action separately, then run the raw critic.
    state_features = rlt.FeatureData(state_pp(*prototype[0]))
    action_features = rlt.FeatureData(action_pp(*prototype[1]))
    reference = critic(state_features, action_features)
    self.assertTrue((reference == q).all())
def _test_seq2slate_model_with_preprocessor(
    self, model: str, output_arch: Seq2SlateOutputArch
):
    """Check that a TorchScript-compiled Seq2SlateWithPreprocessor matches eager mode.

    The module is compiled with ``torch.jit.trace`` when ``can_be_traced()``
    is true and with ``torch.jit.script`` otherwise. Eager and compiled
    outputs are compared (via ``self.verify_results``) on the input prototype
    and again on a copy resized to 20 candidates, to exercise variable-length
    candidate inputs.

    Args:
        model: Seq2slate architecture to build; only ``"transformer"`` is supported.
        output_arch: Output architecture forwarded to the seq2slate network.

    Raises:
        NotImplementedError: If ``model`` is not a supported architecture.
    """
    # Reject unsupported architectures before building any fixtures.
    if model != "transformer":
        raise NotImplementedError(f"model type {model} is unknown")

    state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
    candidate_normalization_parameters = {
        i: _cont_norm() for i in range(101, 106)
    }
    state_preprocessor = Preprocessor(state_normalization_parameters, False)
    candidate_preprocessor = Preprocessor(candidate_normalization_parameters, False)
    candidate_size = 10
    slate_size = 4

    seq2slate = Seq2SlateTransformerNet(
        state_dim=len(state_normalization_parameters),
        candidate_dim=len(candidate_normalization_parameters),
        num_stacked_layers=2,
        num_heads=2,
        dim_model=10,
        dim_feedforward=10,
        max_src_seq_len=candidate_size,
        max_tgt_seq_len=slate_size,
        output_arch=output_arch,
        temperature=0.5,
    )
    seq2slate_with_preprocessor = Seq2SlateWithPreprocessor(
        seq2slate, state_preprocessor, candidate_preprocessor, greedy=True
    )
    input_prototype = seq2slate_with_preprocessor.input_prototype()

    if seq2slate_with_preprocessor.can_be_traced():
        # Traced with a separately generated prototype, not `input_prototype`.
        seq2slate_with_preprocessor_jit = torch.jit.trace(
            seq2slate_with_preprocessor,
            seq2slate_with_preprocessor.input_prototype(),
        )
    else:
        seq2slate_with_preprocessor_jit = torch.jit.script(seq2slate_with_preprocessor)

    def _assert_jit_matches_eager(inputs):
        """Run the eager and compiled modules on ``inputs`` and compare."""
        expected_output = seq2slate_with_preprocessor(*inputs)
        jit_output = seq2slate_with_preprocessor_jit(*inputs)
        self.verify_results(expected_output, jit_output)

    _assert_jit_matches_eager(input_prototype)
    # Test if the compiled model can handle variable lengths of input.
    resized_prototype = change_cand_size_slate_ranking(input_prototype, 20)
    _assert_jit_matches_eager(resized_prototype)
def test_discrete_wrapper(self):
    """Check DiscreteDqnPredictorWrapper against a manual Q-network evaluation.

    A FullyConnectedDQN over 4 float state features (ids 1-4, named
    ``feat_<id>``) with 2 actions is wrapped with a state preprocessor. The
    wrapper must echo the action names and return Q-values of shape (1, 2)
    equal to running the raw network on the preprocessed input prototype.
    """
    feature_ids = range(1, 5)
    state_norm = {fid: _cont_norm() for fid in feature_ids}
    state_pp = Preprocessor(state_norm, False)
    num_actions = 2
    q_network = models.FullyConnectedDQN(
        state_dim=len(state_norm),
        action_dim=num_actions,
        sizes=[16],
        activations=["relu"],
    )

    float_infos = [
        rlt.FloatFeatureInfo(feature_id=fid, name=f"feat_{fid}") for fid in feature_ids
    ]
    feature_config = rlt.ModelFeatureConfig(float_feature_infos=float_infos)
    q_network_with_pp = DiscreteDqnWithPreprocessor(
        q_network, state_pp, feature_config
    )
    action_names = ["L", "R"]
    predictor = DiscreteDqnPredictorWrapper(
        q_network_with_pp, action_names, feature_config
    )

    state_proto = q_network_with_pp.input_prototype()[0]
    returned_names, q_values = predictor(state_proto)
    self.assertEqual(action_names, returned_names)
    self.assertEqual(q_values.shape, (1, 2))

    preprocessed = state_pp(*state_proto.float_features_with_presence)
    reference = q_network(rlt.FeatureData(preprocessed))
    self.assertTrue((reference == q_values).all())
def test_actor_wrapper(self):
    """Check ActorPredictorWrapper against preprocess -> actor -> postprocess.

    Uses 4 float state features (ids 1-4) and 4 action features (ids 101-104).
    The wrapped actor must return an action of shape (1, 4) equal to feeding
    the preprocessed input prototype through the raw actor and postprocessing
    its ``.action`` output.
    """
    state_norm = {fid: _cont_norm() for fid in range(1, 5)}
    action_norm = {fid: _cont_action_norm() for fid in range(101, 105)}
    state_pp = Preprocessor(state_norm, False)
    action_postprocessor = Postprocessor(action_norm, False)
    num_actions = len(action_norm)

    # Test with FullyConnectedActor to make behavior deterministic
    policy = models.FullyConnectedActor(
        state_dim=len(state_norm),
        action_dim=num_actions,
        sizes=[16],
        activations=["relu"],
    )
    policy_with_pp = ActorWithPreprocessor(policy, state_pp, action_postprocessor)
    predictor = ActorPredictorWrapper(policy_with_pp)

    prototype = policy_with_pp.input_prototype()
    predicted_action, _log_prob = predictor(*prototype)
    self.assertEqual(predicted_action.shape, (1, num_actions))

    actor_output = policy(rlt.FeatureData(state_pp(*prototype[0])))
    reference = action_postprocessor(actor_output.action)
    self.assertTrue((reference == predicted_action).all())
def test_discrete_wrapper_with_id_list(self):
    """Check DiscreteDqnPredictorWrapper on a model with an id-list feature.

    The state has 4 float features (ids 1-4) plus one id-list feature "A"
    (id 10, mapping ids [0, 1, 2]) embedded via EmbeddingBagConcat ahead of a
    FullyConnectedDQN with 2 actions. The wrapper's Q-values must have shape
    (1, 2) and match running the raw network on the preprocessed float
    features plus the prototype's id-list features keyed by feature name.
    """
    state_norm = {fid: _cont_norm() for fid in range(1, 5)}
    state_pp = Preprocessor(state_norm, False)
    num_actions = 2

    float_infos = [
        rlt.FloatFeatureInfo(name=str(fid), feature_id=fid) for fid in range(1, 5)
    ]
    id_list_configs = [
        rlt.IdListFeatureConfig(name="A", feature_id=10, id_mapping_name="A_mapping")
    ]
    id_mappings = {"A_mapping": rlt.IdMapping(ids=[0, 1, 2])}
    feature_config = rlt.ModelFeatureConfig(
        float_feature_infos=float_infos,
        id_list_feature_configs=id_list_configs,
        id_mapping_config=id_mappings,
    )

    embedder = models.EmbeddingBagConcat(
        state_dim=len(state_norm),
        model_feature_config=feature_config,
        embedding_dim=8,
    )
    to_tensor = rlt.TensorFeatureData()
    q_head = models.FullyConnectedDQN(
        embedder.output_dim,
        action_dim=num_actions,
        sizes=[16],
        activations=["relu"],
    )
    q_network = models.Sequential(embedder, to_tensor, q_head)

    q_network_with_pp = DiscreteDqnWithPreprocessor(
        q_network, state_pp, feature_config
    )
    action_names = ["L", "R"]
    predictor = DiscreteDqnPredictorWrapper(
        q_network_with_pp, action_names, feature_config
    )

    state_proto = q_network_with_pp.input_prototype()[0]
    returned_names, q_values = predictor(state_proto)
    self.assertEqual(action_names, returned_names)
    self.assertEqual(q_values.shape, (1, 2))

    # Re-key the prototype's id-list features from feature id to feature name.
    id_to_name = {
        cfg.feature_id: cfg.name for cfg in feature_config.id_list_feature_configs
    }
    named_id_lists = {
        id_to_name[fid]: value for fid, value in state_proto.id_list_features.items()
    }
    presence = state_proto.float_features_with_presence
    reference = q_network(
        rlt.FeatureData(
            float_features=state_pp(*presence),
            id_list_features=named_id_lists,
        )
    )
    self.assertTrue((reference == q_values).all())
def _test_seq2slate_wrapper(self, model: str, output_arch: Seq2SlateOutputArch):
    """Check Seq2SlatePredictorWrapper against a direct RANK_MODE call.

    The wrapper output is validated (via ``self.validate_seq2slate_output``)
    against ``seq2slate(..., mode=RANK_MODE, greedy=True)`` run on the same
    preprocessed input, first on the default input prototype with
    ``tgt_seq_len=10`` and then on a prototype resized to a random candidate
    count in [11, 20], to exercise variable-length inputs.

    Args:
        model: Seq2slate architecture to build; only ``"transformer"`` is supported.
        output_arch: Output architecture forwarded to the seq2slate network.

    Raises:
        NotImplementedError: If ``model`` is not a supported architecture.
    """
    # Reject unsupported architectures before building any fixtures.
    if model != "transformer":
        raise NotImplementedError(f"model type {model} is unknown")

    state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
    candidate_normalization_parameters = {
        i: _cont_norm() for i in range(101, 106)
    }
    state_preprocessor = Preprocessor(state_normalization_parameters, False)
    candidate_preprocessor = Preprocessor(candidate_normalization_parameters, False)
    candidate_size = 10
    slate_size = 4

    seq2slate = Seq2SlateTransformerNet(
        state_dim=len(state_normalization_parameters),
        candidate_dim=len(candidate_normalization_parameters),
        num_stacked_layers=2,
        num_heads=2,
        dim_model=10,
        dim_feedforward=10,
        max_src_seq_len=candidate_size,
        max_tgt_seq_len=slate_size,
        output_arch=output_arch,
        temperature=0.5,
    )
    seq2slate_with_preprocessor = Seq2SlateWithPreprocessor(
        seq2slate, state_preprocessor, candidate_preprocessor, greedy=True
    )
    wrapper = Seq2SlatePredictorWrapper(seq2slate_with_preprocessor)

    def _check_against_direct_rank(
        state_input_prototype, candidate_input_prototype, tgt_seq_len
    ):
        """Validate the wrapper against ``seq2slate`` in greedy RANK_MODE."""
        wrapper_output = wrapper(state_input_prototype, candidate_input_prototype)
        ranking_input = seq2slate_input_prototype_to_ranking_input(
            state_input_prototype,
            candidate_input_prototype,
            state_preprocessor,
            candidate_preprocessor,
        )
        expected_output = seq2slate(
            ranking_input,
            mode=Seq2SlateMode.RANK_MODE,
            tgt_seq_len=tgt_seq_len,
            greedy=True,
        )
        self.validate_seq2slate_output(expected_output, wrapper_output)

    (
        state_input_prototype,
        candidate_input_prototype,
    ) = seq2slate_with_preprocessor.input_prototype()
    _check_against_direct_rank(
        state_input_prototype, candidate_input_prototype, candidate_size
    )

    # Test Seq2SlatePredictorWrapper can handle variable lengths of inputs.
    # NOTE(review): uses the global `random` state; a failure reproduces only
    # if the suite seeds it.
    random_length = random.randint(candidate_size + 1, candidate_size * 2)
    (
        state_input_prototype,
        candidate_input_prototype,
    ) = change_cand_size_slate_ranking(
        seq2slate_with_preprocessor.input_prototype(), random_length
    )
    _check_against_direct_rank(
        state_input_prototype, candidate_input_prototype, random_length
    )