def _test_seq2slate_wrapper(self, model: str, output_arch: Seq2SlateOutputArch):
    """Check that Seq2SlatePredictorWrapper matches the raw net's RANK_MODE output.

    Builds a small seq2slate net plus state/candidate preprocessors, wraps
    them, and compares the wrapper's output against a direct forward pass —
    first on the prototype-sized input, then on a randomly enlarged candidate
    set to confirm the wrapper handles variable-length inputs.

    Args:
        model: model family to build; only "transformer" is supported.
        output_arch: output architecture passed through to the net.

    Raises:
        NotImplementedError: if ``model`` is not "transformer".
    """
    state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
    candidate_normalization_parameters = {
        i: _cont_norm() for i in range(101, 106)
    }
    state_preprocessor = Preprocessor(state_normalization_parameters, False)
    candidate_preprocessor = Preprocessor(candidate_normalization_parameters, False)
    candidate_size = 10
    slate_size = 4

    # Guard clause: the original `seq2slate = None` sentinel was dead code,
    # since the else-branch always raised.
    if model != "transformer":
        raise NotImplementedError(f"model type {model} is unknown")
    seq2slate = Seq2SlateTransformerNet(
        state_dim=len(state_normalization_parameters),
        candidate_dim=len(candidate_normalization_parameters),
        num_stacked_layers=2,
        num_heads=2,
        dim_model=10,
        dim_feedforward=10,
        max_src_seq_len=candidate_size,
        max_tgt_seq_len=slate_size,
        output_arch=output_arch,
        temperature=0.5,
    )

    seq2slate_with_preprocessor = Seq2SlateWithPreprocessor(
        seq2slate, state_preprocessor, candidate_preprocessor, greedy=True
    )
    wrapper = Seq2SlatePredictorWrapper(seq2slate_with_preprocessor)

    def _rank_and_validate(state_proto, candidate_proto, tgt_seq_len):
        # Run the wrapper, then reproduce the expected result with a direct
        # RANK_MODE forward pass on the same (preprocessed) input.
        wrapper_output = wrapper(state_proto, candidate_proto)
        ranking_input = seq2slate_input_prototype_to_ranking_input(
            state_proto,
            candidate_proto,
            state_preprocessor,
            candidate_preprocessor,
        )
        expected_output = seq2slate(
            ranking_input,
            mode=Seq2SlateMode.RANK_MODE,
            tgt_seq_len=tgt_seq_len,
            greedy=True,
        )
        self.validate_seq2slate_output(expected_output, wrapper_output)

    # Fixed-size input: rank the full prototype candidate set.
    (
        state_input_prototype,
        candidate_input_prototype,
    ) = seq2slate_with_preprocessor.input_prototype()
    _rank_and_validate(state_input_prototype, candidate_input_prototype, candidate_size)

    # Test Seq2SlatePredictorWrapper can handle variable lengths of inputs
    random_length = random.randint(candidate_size + 1, candidate_size * 2)
    (
        state_input_prototype,
        candidate_input_prototype,
    ) = change_cand_size_slate_ranking(
        seq2slate_with_preprocessor.input_prototype(), random_length
    )
    _rank_and_validate(state_input_prototype, candidate_input_prototype, random_length)
def _test_seq2slate_wrapper(self, model: str, output_arch: Seq2SlateOutputArch):
    """Check Seq2SlatePredictorWrapper against a hand-built preprocessing path.

    Runs the wrapper on its input prototype, then reproduces the expected
    result by preprocessing the same prototype manually and calling the raw
    net in RANK_MODE; the per-sequence probability and the (offset-corrected)
    ranked indices must match the wrapper's return values.

    Raises:
        NotImplementedError: if ``model`` is not "transformer".
    """
    state_norm = {i: _cont_norm() for i in range(1, 5)}
    candidate_norm = {i: _cont_norm() for i in range(101, 106)}
    state_preprocessor = Preprocessor(state_norm, False)
    candidate_preprocessor = Preprocessor(candidate_norm, False)

    seq2slate = None
    if model == "transformer":
        seq2slate = Seq2SlateTransformerNet(
            state_dim=len(state_norm),
            candidate_dim=len(candidate_norm),
            num_stacked_layers=2,
            num_heads=2,
            dim_model=10,
            dim_feedforward=10,
            max_src_seq_len=10,
            max_tgt_seq_len=4,
            output_arch=output_arch,
            temperature=0.5,
        )
    else:
        raise NotImplementedError(f"model type {model} is unknown")

    seq2slate_with_preprocessor = Seq2SlateWithPreprocessor(
        seq2slate, state_preprocessor, candidate_preprocessor, greedy=True
    )
    wrapper = Seq2SlatePredictorWrapper(seq2slate_with_preprocessor)
    state_proto, candidate_proto = seq2slate_with_preprocessor.input_prototype()
    ret_val = wrapper(state_proto, candidate_proto)

    # Rebuild the wrapper's preprocessing by hand (batch size 1).
    src_len = seq2slate.max_src_seq_len
    cand_dim = len(candidate_norm)
    preprocessed_state = state_preprocessor(state_proto[0], state_proto[1])
    preprocessed_candidates = candidate_preprocessor(
        candidate_proto[0].view(src_len, cand_dim),
        candidate_proto[1].view(src_len, cand_dim),
    ).view(1, src_len, -1)
    src_src_mask = torch.ones(1, src_len, src_len)
    ranking_input = rlt.PreprocessedRankingInput.from_tensors(
        state=preprocessed_state,
        src_seq=preprocessed_candidates,
        src_src_mask=src_src_mask,
    )
    expected_output = seq2slate(
        ranking_input,
        mode=Seq2SlateMode.RANK_MODE,
        tgt_seq_len=src_len,
        greedy=True,
    )
    ranked_per_seq_probs = expected_output.ranked_per_seq_probs
    ranked_tgt_out_idx = expected_output.ranked_tgt_out_idx
    # -2 to offset padding symbol and decoder start symbol
    ranked_tgt_out_idx -= 2
    self.assertTrue(ranked_per_seq_probs == ret_val[0])
    self.assertTrue(torch.all(torch.eq(ret_val[1], ranked_tgt_out_idx)))
def test_seq2slate_wrapper(self):
    """End-to-end check of Seq2SlatePredictorWrapper on a tiny transformer.

    Computes the expected values by preprocessing the input prototype
    manually and running the net in RANK_MODE, folding the per-step
    probabilities into a per-slate probability, then asserts the wrapper
    returns the same probability and (offset-corrected) ranked indices.
    """
    state_norm = {i: _cont_norm() for i in range(1, 5)}
    candidate_norm = {i: _cont_norm() for i in range(101, 106)}
    state_preprocessor = Preprocessor(state_norm, False)
    candidate_preprocessor = Preprocessor(candidate_norm, False)
    seq2slate = Seq2SlateTransformerNet(
        state_dim=len(state_norm),
        candidate_dim=len(candidate_norm),
        num_stacked_layers=2,
        num_heads=2,
        dim_model=10,
        dim_feedforward=10,
        max_src_seq_len=10,
        max_tgt_seq_len=4,
        encoder_only=False,
    )
    seq2slate_with_preprocessor = Seq2SlateWithPreprocessor(
        seq2slate, state_preprocessor, candidate_preprocessor, greedy=True
    )
    wrapper = Seq2SlatePredictorWrapper(seq2slate_with_preprocessor)
    state_proto, candidate_proto = seq2slate_with_preprocessor.input_prototype()
    ret_val = wrapper(state_proto, candidate_proto)

    # Manually reproduce the wrapper's preprocessing (batch size 1).
    src_len = seq2slate.max_src_seq_len
    cand_dim = len(candidate_norm)
    preprocessed_state = state_preprocessor(state_proto[0], state_proto[1])
    preprocessed_candidates = candidate_preprocessor(
        candidate_proto[0].view(src_len, cand_dim),
        candidate_proto[1].view(src_len, cand_dim),
    ).view(1, src_len, -1)
    src_src_mask = torch.ones(1, src_len, src_len)
    ranking_input = rlt.PreprocessedRankingInput.from_tensors(
        state=preprocessed_state,
        src_seq=preprocessed_candidates,
        src_src_mask=src_src_mask,
    )
    expected_output = seq2slate(
        ranking_input,
        mode=Seq2SlateMode.RANK_MODE,
        tgt_seq_len=seq2slate.max_tgt_seq_len,
        greedy=True,
    )
    ranked_tgt_out_probs = expected_output.ranked_tgt_out_probs
    ranked_tgt_out_idx = expected_output.ranked_tgt_out_idx
    # Per-slate probability = product of each step's chosen-item probability.
    ranked_tgt_out_probs = torch.prod(
        torch.gather(
            ranked_tgt_out_probs, 2, ranked_tgt_out_idx.unsqueeze(-1)
        ).squeeze(),
        -1,
    )
    # -2 to offset padding symbol and decoder start symbol
    ranked_tgt_out_idx -= 2
    self.assertTrue(ranked_tgt_out_probs == ret_val[0])
    self.assertTrue(torch.all(torch.eq(ret_val[1], ranked_tgt_out_idx)))