Ejemplo n.º 1
0
    def _test_seq2slate_model_with_preprocessor(
            self, model: str, output_arch: Seq2SlateOutputArch):
        """Check that a compiled Seq2SlateWithPreprocessor matches eager mode.

        Builds a small seq2slate network of type ``model`` and wraps it with
        state/candidate preprocessors. It then compiles the wrapper with
        ``torch.jit.trace`` when ``can_be_traced()`` is true, and with
        ``torch.jit.script`` otherwise. Eager and compiled outputs must agree
        (via ``self.verify_results``) on the input prototype and on a copy
        resized to a larger number of candidates.

        Raises:
            NotImplementedError: if ``model`` is not ``"transformer"``.
        """
        state_norm = {idx: _cont_norm() for idx in range(1, 5)}
        candidate_norm = {idx: _cont_norm() for idx in range(101, 106)}
        state_preprocessor = Preprocessor(state_norm, False)
        candidate_preprocessor = Preprocessor(candidate_norm, False)
        candidate_size, slate_size = 10, 4

        if model != "transformer":
            raise NotImplementedError(f"model type {model} is unknown")
        net = Seq2SlateTransformerNet(
            state_dim=len(state_norm),
            candidate_dim=len(candidate_norm),
            num_stacked_layers=2,
            num_heads=2,
            dim_model=10,
            dim_feedforward=10,
            max_src_seq_len=candidate_size,
            max_tgt_seq_len=slate_size,
            output_arch=output_arch,
            temperature=0.5,
        )

        wrapped = Seq2SlateWithPreprocessor(
            net, state_preprocessor, candidate_preprocessor, greedy=True)
        inputs = wrapped.input_prototype()

        # Prefer tracing when the wrapper supports it; otherwise script it.
        if not wrapped.can_be_traced():
            compiled = torch.jit.script(wrapped)
        else:
            compiled = torch.jit.trace(wrapped, wrapped.input_prototype())

        self.verify_results(wrapped(*inputs), compiled(*inputs))

        # The compiled (traced or scripted) module must also accept a
        # candidate count different from the one it was built/traced with.
        resized_inputs = change_cand_size_slate_ranking(
            inputs, 2 * candidate_size)
        self.verify_results(wrapped(*resized_inputs), compiled(*resized_inputs))
Ejemplo n.º 2
0
    def _test_seq2slate_wrapper(self, model: str,
                                output_arch: Seq2SlateOutputArch):
        """Compare Seq2SlatePredictorWrapper with a direct seq2slate call.

        Builds a small seq2slate network of type ``model`` and wraps it in
        ``Seq2SlateWithPreprocessor`` and ``Seq2SlatePredictorWrapper``. It
        then re-derives the expected ranking by preprocessing the input
        prototype by hand and calling the network in ``RANK_MODE`` over all
        ``max_src_seq_len`` candidates. The per-sequence probabilities and the
        ranked indices (shifted by -2) must match the wrapper's output.

        Raises:
            NotImplementedError: if ``model`` is not ``"transformer"``.
        """
        state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
        candidate_normalization_parameters = {
            i: _cont_norm()
            for i in range(101, 106)
        }
        state_preprocessor = Preprocessor(state_normalization_parameters,
                                          False)
        candidate_preprocessor = Preprocessor(
            candidate_normalization_parameters, False)
        candidate_size = 10
        slate_size = 4

        if model == "transformer":
            seq2slate = Seq2SlateTransformerNet(
                state_dim=len(state_normalization_parameters),
                candidate_dim=len(candidate_normalization_parameters),
                num_stacked_layers=2,
                num_heads=2,
                dim_model=10,
                dim_feedforward=10,
                max_src_seq_len=candidate_size,
                max_tgt_seq_len=slate_size,
                output_arch=output_arch,
                temperature=0.5,
            )
        else:
            raise NotImplementedError(f"model type {model} is unknown")

        seq2slate_with_preprocessor = Seq2SlateWithPreprocessor(
            seq2slate, state_preprocessor, candidate_preprocessor, greedy=True)
        wrapper = Seq2SlatePredictorWrapper(seq2slate_with_preprocessor)

        (
            state_input_prototype,
            candidate_input_prototype,
        ) = seq2slate_with_preprocessor.input_prototype()
        ret_val = wrapper(state_input_prototype, candidate_input_prototype)

        # The reshapes below hard-code a batch of one example.
        # NOTE(review): assumes input_prototype() yields batch size 1 -- confirm
        # if the prototype changes.
        batch_size = 1
        src_seq_len = seq2slate.max_src_seq_len
        num_candidate_features = len(candidate_normalization_parameters)

        preprocessed_state = state_preprocessor(state_input_prototype[0],
                                                state_input_prototype[1])
        # Flatten candidates to 2-D for the preprocessor, then restore the
        # (batch, candidates, features) layout.
        preprocessed_candidates = candidate_preprocessor(
            candidate_input_prototype[0].view(
                batch_size * src_seq_len, num_candidate_features),
            candidate_input_prototype[1].view(
                batch_size * src_seq_len, num_candidate_features),
        ).view(batch_size, src_seq_len, -1)
        # All-ones source mask: no candidate pair is masked out.
        src_src_mask = torch.ones(batch_size, src_seq_len, src_seq_len)
        ranking_input = rlt.PreprocessedRankingInput.from_tensors(
            state=preprocessed_state,
            src_seq=preprocessed_candidates,
            src_src_mask=src_src_mask,
        )
        expected_output = seq2slate(
            ranking_input,
            mode=Seq2SlateMode.RANK_MODE,
            tgt_seq_len=src_seq_len,
            greedy=True,
        )
        ranked_per_seq_probs = expected_output.ranked_per_seq_probs
        # -2 to offset padding symbol and decoder start symbol; done
        # out-of-place so the model output tensor is not mutated.
        ranked_tgt_out_idx = expected_output.ranked_tgt_out_idx - 2

        # Element-wise comparison so the assertion stays valid for any batch
        # size (bool() of a multi-element tensor raises).
        self.assertTrue(torch.all(torch.eq(ret_val[0], ranked_per_seq_probs)))
        self.assertTrue(torch.all(torch.eq(ret_val[1], ranked_tgt_out_idx)))
Ejemplo n.º 3
0
    def _test_seq2slate_wrapper(self, model: str,
                                output_arch: Seq2SlateOutputArch):
        """Compare Seq2SlatePredictorWrapper with a direct seq2slate call.

        Builds a small seq2slate network of type ``model`` and wraps it in
        ``Seq2SlateWithPreprocessor`` and ``Seq2SlatePredictorWrapper``. It
        then checks, via ``self.validate_seq2slate_output``, that the wrapper
        agrees with calling the network in ``RANK_MODE`` on the same
        preprocessed input. The check is run on the input prototype and again
        on a prototype resized to a random candidate count in
        ``[candidate_size + 1, 2 * candidate_size]``.

        Raises:
            NotImplementedError: if ``model`` is not ``"transformer"``.
        """
        state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
        candidate_normalization_parameters = {
            i: _cont_norm()
            for i in range(101, 106)
        }
        state_preprocessor = Preprocessor(state_normalization_parameters,
                                          False)
        candidate_preprocessor = Preprocessor(
            candidate_normalization_parameters, False)
        candidate_size = 10
        slate_size = 4

        if model == "transformer":
            seq2slate = Seq2SlateTransformerNet(
                state_dim=len(state_normalization_parameters),
                candidate_dim=len(candidate_normalization_parameters),
                num_stacked_layers=2,
                num_heads=2,
                dim_model=10,
                dim_feedforward=10,
                max_src_seq_len=candidate_size,
                max_tgt_seq_len=slate_size,
                output_arch=output_arch,
                temperature=0.5,
            )
        else:
            raise NotImplementedError(f"model type {model} is unknown")

        seq2slate_with_preprocessor = Seq2SlateWithPreprocessor(
            seq2slate, state_preprocessor, candidate_preprocessor, greedy=True)
        wrapper = Seq2SlatePredictorWrapper(seq2slate_with_preprocessor)

        def check_wrapper(state_input, candidate_input, tgt_seq_len):
            # Run the wrapper, rebuild the same ranking directly from the
            # preprocessed inputs, and validate that the two agree.
            wrapper_output = wrapper(state_input, candidate_input)
            ranking_input = seq2slate_input_prototype_to_ranking_input(
                state_input,
                candidate_input,
                state_preprocessor,
                candidate_preprocessor,
            )
            expected_output = seq2slate(
                ranking_input,
                mode=Seq2SlateMode.RANK_MODE,
                tgt_seq_len=tgt_seq_len,
                greedy=True,
            )
            self.validate_seq2slate_output(expected_output, wrapper_output)

        (
            state_input_prototype,
            candidate_input_prototype,
        ) = seq2slate_with_preprocessor.input_prototype()
        check_wrapper(state_input_prototype, candidate_input_prototype,
                      candidate_size)

        # Test Seq2SlatePredictorWrapper can handle variable lengths of inputs.
        # NOTE(review): the length is drawn from the unseeded global RNG, so a
        # failure here may not reproduce on rerun.
        random_length = random.randint(candidate_size + 1, candidate_size * 2)
        (
            state_input_prototype,
            candidate_input_prototype,
        ) = change_cand_size_slate_ranking(
            seq2slate_with_preprocessor.input_prototype(), random_length)
        check_wrapper(state_input_prototype, candidate_input_prototype,
                      random_length)
Ejemplo n.º 4
0
    def test_seq2slate_wrapper(self):
        """Compare Seq2SlatePredictorWrapper with a direct seq2slate call.

        Builds a small transformer seq2slate network and wraps it in
        ``Seq2SlateWithPreprocessor`` and ``Seq2SlatePredictorWrapper``. It
        then re-derives the expected ranking by preprocessing the input
        prototype by hand and calling the network in ``RANK_MODE`` for
        ``max_tgt_seq_len`` steps. The expected per-sequence probability is
        the product of the probabilities of the chosen items. It and the
        ranked indices (shifted by -2) must match the wrapper's output.
        """
        state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
        candidate_normalization_parameters = {
            i: _cont_norm()
            for i in range(101, 106)
        }
        state_preprocessor = Preprocessor(state_normalization_parameters,
                                          False)
        candidate_preprocessor = Preprocessor(
            candidate_normalization_parameters, False)
        seq2slate = Seq2SlateTransformerNet(
            state_dim=len(state_normalization_parameters),
            candidate_dim=len(candidate_normalization_parameters),
            num_stacked_layers=2,
            num_heads=2,
            dim_model=10,
            dim_feedforward=10,
            max_src_seq_len=10,
            max_tgt_seq_len=4,
            encoder_only=False,
        )
        seq2slate_with_preprocessor = Seq2SlateWithPreprocessor(
            seq2slate, state_preprocessor, candidate_preprocessor, greedy=True)
        wrapper = Seq2SlatePredictorWrapper(seq2slate_with_preprocessor)

        state_input_prototype, candidate_input_prototype = (
            seq2slate_with_preprocessor.input_prototype())
        ret_val = wrapper(state_input_prototype, candidate_input_prototype)

        # The reshapes below hard-code a batch of one example.
        # NOTE(review): assumes input_prototype() yields batch size 1 -- confirm
        # if the prototype changes.
        batch_size = 1
        src_seq_len = seq2slate.max_src_seq_len
        num_candidate_features = len(candidate_normalization_parameters)

        preprocessed_state = state_preprocessor(state_input_prototype[0],
                                                state_input_prototype[1])
        # Flatten candidates to 2-D for the preprocessor, then restore the
        # (batch, candidates, features) layout.
        preprocessed_candidates = candidate_preprocessor(
            candidate_input_prototype[0].view(
                batch_size * src_seq_len, num_candidate_features),
            candidate_input_prototype[1].view(
                batch_size * src_seq_len, num_candidate_features),
        ).view(batch_size, src_seq_len, -1)
        # All-ones source mask: no candidate pair is masked out.
        src_src_mask = torch.ones(batch_size, src_seq_len, src_seq_len)
        ranking_input = rlt.PreprocessedRankingInput.from_tensors(
            state=preprocessed_state,
            src_seq=preprocessed_candidates,
            src_src_mask=src_src_mask,
        )
        expected_output = seq2slate(
            ranking_input,
            mode=Seq2SlateMode.RANK_MODE,
            tgt_seq_len=seq2slate.max_tgt_seq_len,
            greedy=True,
        )
        ranked_tgt_out_probs = expected_output.ranked_tgt_out_probs
        ranked_tgt_out_idx = expected_output.ranked_tgt_out_idx
        # Pick the probability of the chosen item at every decoding step and
        # multiply them per sequence. Only the trailing gather dimension is
        # squeezed; a bare squeeze() would also drop a size-1 batch dimension.
        ranked_per_seq_probs = torch.prod(
            torch.gather(ranked_tgt_out_probs, 2,
                         ranked_tgt_out_idx.unsqueeze(-1)).squeeze(-1),
            -1,
        )
        # -2 to offset padding symbol and decoder start symbol; done
        # out-of-place so the model output tensor is not mutated.
        ranked_tgt_out_idx = ranked_tgt_out_idx - 2

        # Element-wise comparison so the assertion stays valid for any batch
        # size (bool() of a multi-element tensor raises).
        self.assertTrue(torch.all(torch.eq(ret_val[0], ranked_per_seq_probs)))
        self.assertTrue(torch.all(torch.eq(ret_val[1], ranked_tgt_out_idx)))
Ejemplo n.º 5
0
    def test_seq2slate_scriptable(self):
        """Check that seq2slate models survive TorchScript compilation.

        Scripts a raw ``Seq2SlateTransformerModel``. It then traces a
        ``Seq2SlateWithPreprocessor`` around a ``Seq2SlateTransformerNet``
        twice: first with the eager model as the net's inner ``seq2slate``,
        then with the scripted one (a mix of script + trace). Before each
        trace, the wrapper is run once eagerly on its input prototype.
        """
        greedy_serving = True
        # Identical hyper-parameters for the raw model and for the net.
        model_kwargs = {
            "state_dim": 2,
            "candidate_dim": 3,
            "num_stacked_layers": 2,
            "num_heads": 2,
            "dim_model": 128,
            "dim_feedforward": 128,
            "max_src_seq_len": 8,
            "max_tgt_seq_len": 8,
            "output_arch": Seq2SlateOutputArch.AUTOREGRESSIVE,
            "temperature": 1.0,
        }

        # test the raw Seq2Slate model is script-able
        eager_model = Seq2SlateTransformerModel(**model_kwargs)
        scripted_model = torch.jit.script(eager_model)

        seq2slate_net = Seq2SlateTransformerNet(**model_kwargs)

        def passthrough_normalization(feature_ids):
            # Every listed dense feature is passed through unchanged.
            return NormalizationData(
                dense_normalization_parameters={
                    feature_id: NormalizationParameters(
                        feature_type=DO_NOT_PREPROCESS)
                    for feature_id in feature_ids
                })

        state_norm_data = passthrough_normalization((0, 1))
        candidate_norm_data = passthrough_normalization((5, 6, 7))
        state_preprocessor = Preprocessor(
            state_norm_data.dense_normalization_parameters, False)
        candidate_preprocessor = Preprocessor(
            candidate_norm_data.dense_normalization_parameters, False)

        # First pass: plain trace. Second pass: trace over a scripted inner
        # model (mix of script + trace).
        for inner_model in (eager_model, scripted_model):
            seq2slate_net.seq2slate = inner_model
            with_preprocessor = Seq2SlateWithPreprocessor(
                seq2slate_net.eval(),
                state_preprocessor,
                candidate_preprocessor,
                greedy_serving,
            )
            with_preprocessor(*with_preprocessor.input_prototype())
            torch.jit.trace(with_preprocessor,
                            with_preprocessor.input_prototype())