def __init__(self,
             vocab: Vocabulary,
             embedding_size: int,
             encoder_hidden_size: int,
             encoder_num_layers: int,
             decoder: DecoderNet,
             decoder_type: str = "lstm",
             decoder_num_layers: int = 1,
             share_decoder_params: bool = True,  # only valid when decoder_type == `transformer`
             text_field_embedder: Optional[TextFieldEmbedder] = None,
             start_token: str = "[CLS]",
             end_token: str = "[SEP]",
             index_name: str = "tokens",
             beam_size: int = 4,
             max_turn_len: int = 3,
             min_dec_len: int = 4,
             max_dec_len: int = 30,
             coverage_factor: float = 0.0,
             device: Union[int, str, List[int]] = -1,
             metrics: Optional[List[Metric]] = None,
             valid_metric_keys: Optional[List[str]] = None,
             dropout_rate: float = 0.1,
             seed: int = 42,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: Optional[RegularizerApplicator] = None):
    # Initialize vocab and the regularizer
    Model.__init__(self, vocab, regularizer)

    # ----------- Define the embedding and the encoder ---------------
    # Embeds the input token sequence
    self._text_field_embedder = text_field_embedder
    # Define the encoder
    self.encoder = torch.nn.LSTM(input_size=embedding_size,
                                 hidden_size=encoder_hidden_size,
                                 num_layers=encoder_num_layers,
                                 batch_first=True,
                                 dropout=dropout_rate,
                                 bidirectional=True)
    self.encoder_num_layers = encoder_num_layers
    # The encoder is bidirectional while the decoder is unidirectional,
    # so project the encoder output down to the unidirectional dimension
    self.bi2uni_dec_init_state = torch.nn.Linear(2 * encoder_hidden_size,
                                                 encoder_hidden_size)
    self.encoder_output_dim = encoder_hidden_size

    # ------------- Shared initialization ---------------------
    self.common_init(self.encoder_output_dim, decoder, decoder_type,
                     decoder_num_layers, share_decoder_params, start_token,
                     end_token, index_name, beam_size, min_dec_len,
                     max_dec_len, coverage_factor, device, metrics,
                     valid_metric_keys, seed, initializer)

    # -------------- Encoder-specific initialization ---------------
    # Get the embedding dimension
    embedding_size = self._text_field_embedder.get_output_dim()
    self.turn_embedding = torch.nn.Embedding(max_turn_len, embedding_size)
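# A minimal sketch (not from the original source; shapes and names are
# assumptions) of why `bi2uni_dec_init_state` maps 2 * encoder_hidden_size
# down to encoder_hidden_size: the bidirectional LSTM's final hidden states
# carry a forward and a backward direction that must be merged before they
# can seed a unidirectional decoder.
import torch

encoder_hidden_size, num_layers, batch = 256, 2, 8
encoder = torch.nn.LSTM(input_size=128, hidden_size=encoder_hidden_size,
                        num_layers=num_layers, batch_first=True,
                        bidirectional=True)
proj = torch.nn.Linear(2 * encoder_hidden_size, encoder_hidden_size)

out, (h, c) = encoder(torch.randn(batch, 10, 128))
# h: (num_layers * 2, batch, hidden) -> (num_layers, batch, 2 * hidden)
h = (h.view(num_layers, 2, batch, encoder_hidden_size)
      .transpose(1, 2)
      .reshape(num_layers, batch, 2 * encoder_hidden_size))
dec_h0 = proj(h)  # (num_layers, batch, encoder_hidden_size)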
def __init__(self,
             vocab: Vocabulary,
             encoder: Seq2SeqEncoder,
             decoder: DecoderNet,
             decoder_type: str = "lstm",
             encoder_num_layers: int = 1,
             decoder_num_layers: int = 1,
             share_encoder_params: bool = True,
             share_decoder_params: bool = True,
             text_field_embedder: Optional[TextFieldEmbedder] = None,
             start_token: str = "[CLS]",
             end_token: str = "[SEP]",
             index_name: str = "tokens",
             beam_size: int = 4,
             max_turn_len: int = 3,
             min_dec_len: int = 4,
             max_dec_len: int = 30,
             coverage_factor: float = 0.0,
             device: Union[int, str, List[int]] = -1,
             metrics: Optional[List[Metric]] = None,
             valid_metric_keys: Optional[List[str]] = None,
             seed: int = 42,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: Optional[RegularizerApplicator] = None):
    # Initialize vocab and the regularizer
    Model.__init__(self, vocab, regularizer)

    # ---------- Define the embedding and the encoder -----------------
    # Embeds the input token sequence (usually an instance of Embedding)
    self._text_field_embedder = text_field_embedder
    # Define the encoder
    self.encoder = encoder
    # Get the encoder's output dimension
    self.encoder_output_dim = self.encoder.get_output_dim()

    # ---------- Shared initialization -------------
    self.common_init(self.encoder_output_dim, decoder, decoder_type,
                     decoder_num_layers, share_decoder_params, start_token,
                     end_token, index_name, beam_size, min_dec_len,
                     max_dec_len, coverage_factor, device, metrics,
                     valid_metric_keys, seed, initializer)

    # --------- Encoder-specific initialization -------
    # Get the embedding dimension
    embedding_size = self._text_field_embedder.get_output_dim()
    self.turn_embedding = torch.nn.Embedding(max_turn_len, embedding_size)
    self.encoder_num_layers = encoder_num_layers
    self._share_encoder_params = share_encoder_params
    # If the decoder is an LSTM, attention is used to initialize its
    # initial state; if the encoder is also an LSTM, this is not needed
    if self.params["decoder_type"] == "lstm":
        self.h_query = torch.nn.Parameter(torch.randn([self.encoder_output_dim]),
                                          requires_grad=True)
        self.c_query = torch.nn.Parameter(torch.randn([self.encoder_output_dim]),
                                          requires_grad=True)
        self.init_attention = DotProductAttention()
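# A minimal sketch (an assumption, not the original implementation) of how the
# learned `h_query` / `c_query` vectors and `DotProductAttention` can produce
# the LSTM decoder's initial state from the encoder outputs: each query
# attends over the encoder states, and the weighted sum becomes h0 (or c0).
import torch
from allennlp.modules.attention import DotProductAttention
from allennlp.nn.util import weighted_sum

batch, seq_len, dim = 8, 10, 256
encoder_outputs = torch.randn(batch, seq_len, dim)  # (batch, seq_len, encoder_output_dim)

h_query = torch.nn.Parameter(torch.randn(dim))
init_attention = DotProductAttention()

# Expand the learned query to the batch and attend over the encoder states
# (a real call would also pass the source mask as `matrix_mask`).
weights = init_attention(h_query.unsqueeze(0).expand(batch, -1),
                         encoder_outputs)   # (batch, seq_len)
h0 = weighted_sum(encoder_outputs, weights)  # (batch, encoder_output_dim)
# c0 is computed the same way from `c_query`.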