Example #1
    def __init__(
        self,
        state_dim: int,
        candidate_dim: int,
        num_stacked_layers: int,
        num_heads: int,
        dim_model: int,
        dim_feedforward: int,
        max_src_seq_len: int,
        max_tgt_seq_len: int,
        output_arch: Seq2SlateOutputArch,
        temperature: float = 1.0,
        state_embed_dim: Optional[int] = None,
    ):
        """
        :param state_dim: state feature dimension
        :param candidate_dim: candidate feature dimension
        :param num_stacked_layers: number of stacked layers in Transformer
        :param num_heads: number of attention heads used in Transformer
        :param dim_model: number of attention dimensions in Transformer
        :param dim_feedforward: number of hidden units in FeedForward layers
            in Transformer
        :param max_src_seq_len: the maximum length of input sequences
        :param max_tgt_seq_len: the maximum length of output sequences
        :param output_arch: determines seq2slate output architecture
        :param temperature: temperature used in decoder sampling
        :param state_embed_dim: embedding dimension of state features.
            By default (if not specified), state_embed_dim = dim_model // 2
        """
        super().__init__()
        self.state_dim = state_dim
        self.candidate_dim = candidate_dim
        self.num_stacked_layers = num_stacked_layers
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_feedforward = dim_feedforward
        self.max_src_seq_len = max_src_seq_len
        self.max_tgt_seq_len = max_tgt_seq_len
        self.output_arch = output_arch
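        # Special symbols reserved in the decoder output vocabulary,
        # used alongside the candidate indices.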
        self._DECODER_START_SYMBOL = DECODER_START_SYMBOL
        self._PADDING_SYMBOL = PADDING_SYMBOL
        self._RANK_MODE = Seq2SlateMode.RANK_MODE.value
        self._PER_SYMBOL_LOG_PROB_DIST_MODE = (
            Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE.value
        )
        self._PER_SEQ_LOG_PROB_MODE = Seq2SlateMode.PER_SEQ_LOG_PROB_MODE.value
        self._DECODE_ONE_STEP_MODE = Seq2SlateMode.DECODE_ONE_STEP_MODE.value
        self._ENCODER_SCORE_MODE = Seq2SlateMode.ENCODER_SCORE_MODE.value
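        # Placeholder tensor for output fields that a given forward mode
        # does not populate.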
        self._OUTPUT_PLACEHOLDER = torch.zeros(1)

        self.encoder = EncoderPyTorch(
            dim_model, num_heads, dim_feedforward, num_stacked_layers
        )
        # Compute score at each encoder step
        self.encoder_scorer = nn.Linear(dim_model, 1)
        self.generator = Generator()
        self.decoder = DecoderPyTorch(
            dim_model, num_heads, dim_feedforward, num_stacked_layers
        )
        self.positional_encoding_decoder = PositionalEncoding(dim_model)

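        # State and candidate embeddings are concatenated later, so their
        # dimensions must sum to dim_model.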
        if state_embed_dim is None:
            state_embed_dim = dim_model // 2
        candidate_embed_dim = dim_model - state_embed_dim
        self.state_embedder = Embedder(state_dim, state_embed_dim)
        self.candidate_embedder = Embedder(candidate_dim, candidate_embed_dim)

        # Initialize parameters with Glorot / fan_avg.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        print_model_info(self)
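
A minimal instantiation sketch for this first variant. The class name Seq2SlateTransformerModel and the enum value Seq2SlateOutputArch.AUTOREGRESSIVE are assumptions inferred from the identifiers in the excerpt, not confirmed by it; note that dim_model must be divisible by num_heads for multi-head attention, and state_embed_dim must be smaller than dim_model so the candidate embedding keeps a positive dimension.

# Hypothetical usage sketch; class and enum names assumed from the excerpt above.
model = Seq2SlateTransformerModel(
    state_dim=8,
    candidate_dim=16,
    num_stacked_layers=2,
    num_heads=4,
    dim_model=64,          # must be divisible by num_heads
    dim_feedforward=128,
    max_src_seq_len=10,
    max_tgt_seq_len=10,
    output_arch=Seq2SlateOutputArch.AUTOREGRESSIVE,
    temperature=1.0,
    state_embed_dim=32,    # candidate_embed_dim becomes 64 - 32 = 32
)
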
Example #2
    def __init__(
        self,
        state_dim: int,
        candidate_dim: int,
        num_stacked_layers: int,
        num_heads: int,
        dim_model: int,
        dim_feedforward: int,
        max_src_seq_len: int,
        max_tgt_seq_len: int,
        output_arch: Seq2SlateOutputArch,
        temperature: float = 1.0,
        state_embed_dim: Optional[int] = None,
    ):
        """
        :param state_dim: state feature dimension
        :param candidate_dim: candidate feature dimension
        :param num_stacked_layers: number of stacked layers in Transformer
        :param num_heads: number of attention heads used in Transformer
        :param dim_model: number of attention dimensions in Transformer
        :param dim_feedforward: number of hidden units in FeedForward layers
            in Transformer
        :param max_src_seq_len: the maximum length of input sequences
        :param max_tgt_seq_len: the maximum length of output sequences
        :param output_arch: determines seq2slate output architecture
        :param temperature: temperature used in decoder sampling
        :param state_embed_dim: embedding dimension of state features.
            By default (if not specified), state_embed_dim = dim_model // 2
        """
        super().__init__()
        self.state_dim = state_dim
        self.candidate_dim = candidate_dim
        self.num_stacked_layers = num_stacked_layers
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_feedforward = dim_feedforward
        self.max_src_seq_len = max_src_seq_len
        self.max_tgt_seq_len = max_tgt_seq_len
        self.output_arch = output_arch
        self._DECODER_START_SYMBOL = DECODER_START_SYMBOL
        self._PADDING_SYMBOL = PADDING_SYMBOL
        self._RANK_MODE = Seq2SlateMode.RANK_MODE
        self._PER_SYMBOL_LOG_PROB_DIST_MODE = (
            Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE
        )
        self._PER_SEQ_LOG_PROB_MODE = Seq2SlateMode.PER_SEQ_LOG_PROB_MODE
        self._DECODE_ONE_STEP_MODE = Seq2SlateMode.DECODE_ONE_STEP_MODE
        self._ENCODER_SCORE_MODE = Seq2SlateMode.ENCODER_SCORE_MODE

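        # Deep-copy the shared attention/feed-forward templates so each
        # stacked encoder layer gets its own parameters.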
        c = copy.deepcopy
        attn = MultiHeadedAttention(num_heads, dim_model)
        ff = PositionwiseFeedForward(dim_model, dim_feedforward)
        self.encoder = Encoder(EncoderLayer(dim_model, c(attn), c(ff)),
                               num_stacked_layers)
        if self.output_arch == Seq2SlateOutputArch.FRECHET_SORT:
            # Compute score at each encoder step
            self.encoder_scorer = nn.Linear(dim_model, 1)
            # The generator needs to know the output vocabulary size.
            # Possible output symbols include the candidate indices, the
            # decoder-start symbol, and the padding symbol, hence
            # max_src_seq_len + 2.
            self.generator = Generator(dim_model, max_src_seq_len + 2,
                                       temperature)
        elif self.output_arch == Seq2SlateOutputArch.AUTOREGRESSIVE:
            self.decoder = DecoderPyTorch(dim_model, num_heads,
                                          dim_feedforward, num_stacked_layers)
            self.positional_encoding_decoder = PositionalEncoding(
                dim_model, max_len=max_tgt_seq_len)
            self.generator = Generator(dim_model, max_src_seq_len + 2,
                                       temperature)
        elif self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
            # Compute score at each encoder step
            self.encoder_scorer = nn.Linear(dim_model, 1)

        if state_embed_dim is None:
            state_embed_dim = dim_model // 2
        candidate_embed_dim = dim_model - state_embed_dim
        self.state_embedder = Embedder(state_dim, state_embed_dim)
        self.candidate_embedder = Embedder(candidate_dim, candidate_embed_dim)

        # Initialize parameters with Glorot / fan_avg.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        print_model_info(self)
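
The c = copy.deepcopy idiom in this second variant is the standard way to stack independent Transformer sub-layers: each layer must own its parameters, so a shared template module is deep-copied rather than reused. A self-contained sketch of the pattern (the helper name clones is illustrative, not from the source):

import copy

import torch.nn as nn

def clones(module: nn.Module, n: int) -> nn.ModuleList:
    # Deep-copy so the n layers train independent weights; appending the
    # same module object n times would make them share parameters.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

layer = nn.Linear(64, 64)
stack = clones(layer, 3)
# Each copy owns distinct storage for its weights.
assert stack[0].weight.data_ptr() != stack[1].weight.data_ptr()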