def __init__(self, args, src_field, tgt_field):
        """Build the standard Seq2Seq model: embeddings, encoder, bridge, decoder.

        Args:
            args: hyper-parameter namespace (vocab sizes, embed_size, bridge
                type, copy flags, ...).
            src_field: source-side field (unused here; kept for a uniform
                constructor signature).
            tgt_field: target-side field; its vocab supplies the special-token
                indices cached at the end of this constructor.
        """
        super(Seq2Seq, self).__init__()
        logger.info('[MODEL] Preparing the Standard Seq2Seq model')
        self.beam_length_penalize = args.beam_length_penalize

        # Build Embedding
        if args.share_embedding:
            # A single table shared by encoder input and decoder input/output;
            # this only makes sense when source and target share one vocab.
            assert args.share_vocab
            self.src_embed = rnn_helper.build_embedding(
                args.tgt_vocab_size_with_offsets, args.embed_size)
            self.dec_embed = self.src_embed
            enc_embed = self.src_embed
            dec_embed = self.src_embed
        else:
            self.src_embed = rnn_helper.build_embedding(
                args.src_vocab_size, args.embed_size)
            self.dec_embed = rnn_helper.build_embedding(
                args.tgt_vocab_size_with_offsets, args.embed_size)
            # BUG FIX: `enc_embed` was never assigned on this branch, so the
            # RNNEncoder construction below raised NameError whenever
            # share_embedding was off.
            enc_embed = self.src_embed
            dec_embed = self.dec_embed

        # Build Encoder
        encoder = RNNEncoder(hparams=args, embed=enc_embed)
        self.encoder = encoder

        # Build Bridge (maps final encoder state to the decoder's initial state)
        if args.bridge == 'general':
            bridge = LinearBridge(hparams=args)
        elif args.bridge == 'none':
            bridge = None
        else:
            raise NotImplementedError()
        self.bridge = bridge

        # Build Decoder
        decoder = RNNDecoder(hparams=args, embed=dec_embed)
        self.decoder = decoder

        # Build Copy Attention
        self.copy_mode = False
        self.copy_coverage = False
        if args.copy:
            self.copy_mode = True
            # Coverage penalty is enabled only when its weight is positive.
            if args.copy_coverage > 0.0:
                self.copy_coverage = True

        # Special Tokens (indices looked up from the target vocabulary)
        self.tgt_sos_idx = tgt_field.vocab.stoi['<sos>']
        self.tgt_eos_idx = tgt_field.vocab.stoi['<eos>']
        self.tgt_unk_idx = tgt_field.vocab.stoi['<unk>']
        self.tgt_pad_idx = tgt_field.vocab.stoi['<pad>']
 def __init__(self, args, src_embed, concat=True, dropout=0.5):
     """Hierarchical LSTM table encoder: a word-level LSTM whose outputs feed
     an attribute-level LSTM, plus a positional embedding and a projection
     that halves the doubled hidden size.

     Args:
         args: hyper-parameter namespace (hidden_size, field_key_embed_size, ...).
         src_embed: source word embedding passed through to the base class.
         concat: if True, the field-key embedding is concatenated onto each
             word vector before the word-level LSTM.
         dropout: accepted for interface compatibility (not stored here).
     """
     super(LSTMFieldTableEncoder, self).__init__(args, src_embed)
     # This encoder only supports the plain input layout / memory-bank format.
     assert args.field_input_tags == "none"
     assert args.infobox_memory_bank_format == 'fwk_fwv_fk'
     self.concat = concat
     self.hidden_size = args.hidden_size
     self.field_size = args.field_key_embed_size
     # field_word_embed_size is assumed to be set by the base-class
     # constructor above — TODO confirm.
     self.input_size = self.field_word_embed_size
     word_input_size = (self.input_size + self.field_size) if concat else self.input_size
     self.word_encoder = LSTM(word_input_size, self.hidden_size)
     self.attribute_encoder = LSTM(self.hidden_size, self.hidden_size)
     self.positional_embedding = build_embedding(500, 5)
     self.output_projection = nn.Linear(self.hidden_size * 2, self.hidden_size, bias=False)
# Beispiel #3
# 0
    def __init__(self, method, dropout, vocab_size, word_embed_size, embed_size, down_scale=1.00):
        """Sub-token (character-level) encoder that pools sub-token embeddings
        into a word-level representation.

        Args:
            method: pooling strategy, 'mean' or 'cnn'.
            dropout: dropout rate handed to the pooling encoder.
            vocab_size: size of the sub-token vocabulary.
            word_embed_size: dimensionality of the sub-token embedding table.
            embed_size: target output dimensionality of the projection.
            down_scale: multiplier applied to embed_size to shrink the
                embedding before pooling (1.0 = no scaling).

        Raises:
            NotImplementedError: if `method` is not 'mean' or 'cnn'.
        """
        super(T2STokenEncoder, self).__init__()
        self.sub_embed = rnn_helper.build_embedding(vocab_size, word_embed_size)
        # Optionally down-project the sub-token embedding before pooling.
        self.scaled_embed_size = int(embed_size * down_scale)
        if self.scaled_embed_size != word_embed_size:
            self.down_scale_projection = torch.nn.Linear(word_embed_size, self.scaled_embed_size)
        else:
            self.down_scale_projection = None

        if method == 'mean':
            self.encoder = TokenMeanEncoder(dropout=dropout)
            self.output_projection = torch.nn.Linear(word_embed_size + self.scaled_embed_size,
                                                     embed_size)
        elif method == 'cnn':
            # Filter spec (width, count): 75 + 100 + 25 = 200 feature maps,
            # which is why the projection input below adds 200.
            self.encoder = TokenCNNEncoder(embedding_size=self.scaled_embed_size,
                                           filters=[(1, 75), (2, 100), (3, 25)], dropout=dropout)
            self.output_projection = torch.nn.Linear(word_embed_size + 200,
                                                     embed_size)
        else:
            # BUG FIX: an unknown method previously fell through silently,
            # leaving self.encoder / self.output_projection unset and causing
            # a confusing AttributeError at call time. Fail fast instead,
            # consistent with the other option chains in this file.
            raise NotImplementedError()
    def __init__(self, args, src_field, tgt_field):
        """Build the table-aware InfoSeq2Seq model.

        Depending on args.task_mode this wires up: word / sub-word / POS-tag
        embeddings, an infobox field encoder, a dialogue query encoder with
        one of many bridge variants, a table-aware decoder, and the two copy
        mechanisms (src-copy and field-copy).

        Args:
            args: hyper-parameter namespace.
            src_field: source-side field (unused here; uniform constructor
                signature).
            tgt_field: target-side field; supplies the special-token indices
                cached at the end of this constructor.
        """
        super(InfoSeq2Seq, self).__init__()

        logger.info('[MODEL] Preparing the InfoSeq2Seq model')
        self.drop_out = nn.Dropout(args.dropout)
        self.beam_length_penalize = args.beam_length_penalize
        self.hidden_size = args.hidden_size
        self.teach_force_rate = args.teach_force_rate
        # In table2text mode only table-to-text generation is performed;
        # text2text mode likewise does only text-to-text generation.
        self.table2text_mode = False
        self.text2text_mode = False
        if args.task_mode == 'text2text':
            self.text2text_mode = True
        elif args.task_mode == 'table2text':
            self.table2text_mode = True

        self.enable_query_encoder = True
        self.enable_field_encoder = True

        self.enabled_char_encoders = set(args.char_encoders.split(','))

        # Each single-task mode disables the encoder it does not need and
        # restricts the bridge/copy options accordingly.
        if self.table2text_mode:
            self.enable_query_encoder = False
            assert args.copy is False
            assert args.bridge == 'none' or args.bridge == 'general'
        if self.text2text_mode:
            self.enable_field_encoder = False
            assert args.field_copy is False
            assert args.bridge == 'none' or args.bridge == 'general'

        # Word embedding for representing dialogue words
        self.embed_size = args.embed_size
        if args.share_embedding:
            self.src_embed = rnn_helper.build_embedding(
                args.tgt_vocab_size_with_offsets, args.embed_size)
            self.dec_embed = self.src_embed
            enc_embed = self.src_embed
            dec_embed = self.src_embed
        else:
            self.src_embed = rnn_helper.build_embedding(
                args.src_vocab_size, args.embed_size)
            self.dec_embed = rnn_helper.build_embedding(
                args.tgt_vocab_size_with_offsets, args.embed_size)
            dec_embed = self.dec_embed

        # Optional sub-word (character) embedding for the source side.
        if 'src' in self.enabled_char_encoders:
            self.sub_src_embed = rnn_helper.build_embedding(
                args.sub_src_vocab_size, args.embed_size)

        # Whether to use POS tags of the dialogue (query) words.
        if args.add_pos_tag_embedding and self.enable_query_encoder:
            self.add_pos_tag_embedding = True
            self.src_pos_tag_embed_size = args.field_tag_embedding_size
            self.src_pos_tag_embed = rnn_helper.build_embedding(
                args.src_tag_vocab_size, self.src_pos_tag_embed_size)
        else:
            self.src_pos_tag_embed_size = 0
            self.add_pos_tag_embedding = False
            self.src_pos_tag_embed = None

        # Build Field Encoder
        if self.enable_field_encoder:
            if args.field_encoder == 'lstm':
                field_encoder = FieldTableEncoder(args=args,
                                                  src_embed=self.src_embed)
            elif args.field_encoder == 'transformer':
                field_encoder = TransformerFieldTableEncoder(
                    args=args, src_embed=self.src_embed)
            elif args.field_encoder == 'hierarchical_lstm':
                field_encoder = LSTMFieldTableEncoder(args=args,
                                                      src_embed=self.src_embed)
            elif args.field_encoder == 'hierarchical_field':
                field_encoder = HierarchicalFieldTableEncoder(
                    args=args, src_embed=self.src_embed)
            elif args.field_encoder == 'hierarchical_intra_field':
                field_encoder = HierarchicalIntraFieldTableEncoder(
                    args=args, src_embed=self.src_embed)
            elif args.field_encoder == 'hierarchical_infobox':
                field_encoder = HierarchicalInfoboxEncoder(
                    args=args, src_embed=self.src_embed)
            else:
                # BUG FIX: an unknown field_encoder previously fell through
                # and crashed with NameError on `field_encoder` below; fail
                # fast like the bridge chains do.
                raise NotImplementedError()
            self.field_equivalent_input_size = field_encoder.get_field_equivalent_input_size(
            )
            self.field_encoder = field_encoder
        else:
            self.field_equivalent_input_size = None
            self.field_encoder = None

        # Build Query Encoder + Bridge (only outside table2text mode).
        if self.enable_query_encoder:
            encoder = RNNEncoder(hparams=args,
                                 embed=enc_embed,
                                 tag_embed=self.src_pos_tag_embed)
            self.encoder = encoder
            # Bidirectional encoders double the state size; project it back.
            if args.enc_birnn:
                self.birnn_down_scale = torch.nn.Linear(
                    args.hidden_size * 2, args.hidden_size, False)
            else:
                self.birnn_down_scale = None
            # Build Bridge — the many variants below differ in whether they
            # use attention and whether a posterior (training-time) bridge
            # accompanies the prior one.
            self.bridge_mode = args.bridge
            if args.bridge == 'general' or args.bridge == 'fusion':
                bridge = LinearBridge(args.bridge, args.rnn_type,
                                      args.hidden_size, args.enc_layers,
                                      args.dropout, args.enc_birnn)
            elif args.bridge == 'attention' or args.bridge == 'field_attention' or args.bridge == 'fusion_attention':
                bridge = AttentionBridge(args.bridge, args.hidden_size,
                                         args.dropout, args.enc_birnn)
            elif args.bridge == 'clue_attention' or args.bridge == 'clue_attention2':
                bridge = AttentionBridge('attention', args.hidden_size,
                                         args.dropout, args.enc_birnn)
                self.posterior_bridge = AttentionBridge(
                    'attention', args.hidden_size, args.dropout,
                    args.enc_birnn)
                if args.bridge == 'clue_attention' or args.bridge == 'clue_attention2':
                    self.clue_query_projection = torch.nn.Linear(
                        (int(args.enc_birnn) + 1) * args.hidden_size * 2,
                        args.hidden_size * (int(args.enc_birnn) + 1))
            elif args.bridge == 'post_field_attention' or args.bridge == 'clue_field_attention' \
                    or args.bridge == 'clue_field_attention2':
                bridge = AttentionBridge('field_attention', args.hidden_size,
                                         args.dropout, args.enc_birnn)
                self.posterior_bridge = AttentionBridge(
                    'field_attention', args.hidden_size, args.dropout,
                    args.enc_birnn)
                if args.bridge == 'clue_field_attention' or args.bridge == 'clue_field_attention2':
                    self.clue_query_projection = torch.nn.Linear(
                        (int(args.enc_birnn) + 1) * args.hidden_size * 2,
                        args.hidden_size * (int(args.enc_birnn) + 1))
            elif args.bridge == 'post_attn_fusion':
                bridge = AttentionBridge('fusion_attention', args.hidden_size,
                                         args.dropout, args.enc_birnn)
                self.posterior_bridge = LinearBridge('general', args.rnn_type,
                                                     args.hidden_size,
                                                     args.enc_layers,
                                                     args.dropout)
            elif args.bridge == 'post_fusion' or args.bridge == 'post_fusion2':
                bridge = LinearBridge('fusion', args.rnn_type,
                                      args.hidden_size, args.enc_layers,
                                      args.dropout, args.enc_birnn)
                post_bridge = LinearBridge('general', args.rnn_type,
                                           args.hidden_size, args.enc_layers,
                                           args.dropout, args.enc_birnn)
                self.posterior_bridge = post_bridge
            elif args.bridge == 'none':
                bridge = None
            else:
                raise NotImplementedError()
            self.bridge = bridge
        else:
            self.encoder = None
            assert args.bridge == 'none', 'table2text mode does not support bridge'

        # Build Decoder
        decoder = TableAwareDecoder(hparams=args, embed=dec_embed)
        self.decoder = decoder

        # Build SRC-Copy Attention (copy tokens from the dialogue source).
        self.src_copy_mode = False
        self.copy_coverage = False
        if args.copy:
            assert self.table2text_mode is False
            self.src_copy_mode = True
            self.max_copy_token_num = args.max_copy_token_num
            # State-to-Input Projection: the copied token's input embedding
            # may be augmented with its POS-tag embedding and/or the encoder
            # state at the copied position, then projected back to embed_size.
            self.add_state_to_copy_token = args.add_state_to_copy_token
            copy_state_input_dim = args.embed_size
            if self.add_pos_tag_embedding:
                copy_state_input_dim += self.src_pos_tag_embed_size
            if self.add_state_to_copy_token:
                copy_state_input_dim += self.hidden_size
            if copy_state_input_dim > args.embed_size:
                self.copied_state_to_input_embed_projection = nn.Linear(
                    copy_state_input_dim, args.embed_size, bias=False)
            else:
                self.copied_state_to_input_embed_projection = None
            # Coverage penalty only when its weight is positive.
            if args.copy_coverage > 0.0:
                self.copy_coverage = True

        # Build Field-SRC Copy (copy tokens from the infobox fields).
        self.field_copy_mode = False
        if args.field_copy:
            self.field_copy_mode = True
            self.max_kw_pairs_num = args.max_kw_pairs_num
            # State-to-Input Projection
            self.field_state_to_input_embed_projection = nn.Linear(
                self.field_equivalent_input_size + self.hidden_size,
                args.embed_size,
                bias=False)

        # Special Tokens (indices looked up from the target vocabulary)
        self.tgt_sos_idx = tgt_field.vocab.stoi['<sos>']
        self.tgt_eos_idx = tgt_field.vocab.stoi['<eos>']
        self.tgt_unk_idx = tgt_field.vocab.stoi['<unk>']
        self.tgt_pad_idx = tgt_field.vocab.stoi['<pad>']
    def __init__(self, args, src_embed):
        """Build the embedding stack for infobox (table) inputs.

        Creates embeddings for field keys, field words, optional sub-word
        (character) encoders, local/global positional tags, and POS tags,
        accumulating the total positional-embedding width in
        self.field_all_pos_embed_size as each tag family is enabled.

        Args:
            args: hyper-parameter namespace controlling which embeddings are
                built and their sizes.
            src_embed: the dialogue-side word embedding, reused for field
                words when no separate field-word vocabulary is configured.
        """
        super(InfoboxTableEncoder, self).__init__()
        self.infobox_mode = 'standard'
        self.src_embed = src_embed
        self.field_key_embed_size = args.field_key_embed_size
        # Running total of all enabled positional-embedding widths.
        self.field_all_pos_embed_size = 0
        self.field_input_tags = set(args.field_input_tags.split(','))
        # Every requested tag must be one of the recognized positional inputs.
        assert len(self.field_input_tags - VALID_FIELD_POS_INPUTS) == 0

        self.enabled_char_encoders = set(args.char_encoders.split(','))

        # Word embedding for representing infobox words
        self.field_key_embed = rnn_helper.build_embedding(
            args.field_vocab_size, args.field_key_embed_size)
        # Optional character-level encoder over field keys.
        if 'field_key' in self.enabled_char_encoders:

            self.sub_field_key_char_encoder = char_encoders.T2STokenEncoder(
                args.char_encoder_type,
                dropout=args.dropout,
                vocab_size=args.sub_field_vocab_size,
                word_embed_size=args.field_key_embed_size,
                embed_size=args.field_key_embed_size,
            )
        else:
            self.sub_field_key_char_encoder = None

        # Optional character-level encoder over field words.
        if 'field_word' in self.enabled_char_encoders:
            self.sub_field_word_char_encoder = char_encoders.T2STokenEncoder(
                args.char_encoder_type,
                dropout=args.dropout,
                vocab_size=args.sub_field_word_vocab_size,
                word_embed_size=args.embed_size,
                embed_size=args.embed_size)
        else:
            self.sub_field_word_embed = None
            self.sub_field_word_char_encoder = None

        # Local Positions (position of a word within its own field;
        # fw/bw presumably mean forward/backward counting — TODO confirm).
        if 'local_pos_fw' in self.field_input_tags:
            self.local_pos_fw_embed = rnn_helper.build_embedding(
                args.max_field_intra_word_num,
                args.field_position_embedding_size)
            self.field_all_pos_embed_size += args.field_position_embedding_size
        else:
            self.local_pos_fw_embed = None

        if 'local_pos_bw' in self.field_input_tags:
            self.local_pos_bw_embed = rnn_helper.build_embedding(
                args.max_field_intra_word_num,
                args.field_position_embedding_size)
            self.field_all_pos_embed_size += args.field_position_embedding_size
        else:
            self.local_pos_bw_embed = None

        # Global Positions (position of the key-value / key-word pair within
        # the whole infobox).
        if 'field_kv_pos' in self.field_input_tags:
            self.field_kv_pos_embed = rnn_helper.build_embedding(
                args.max_kv_pairs_num, args.field_position_embedding_size)
            self.field_all_pos_embed_size += args.field_position_embedding_size
        else:
            self.field_kv_pos_embed = None

        if 'field_kw_pos' in self.field_input_tags:
            self.field_kw_pos_embed = rnn_helper.build_embedding(
                args.max_kw_pairs_num, args.field_position_embedding_size)
            self.field_all_pos_embed_size += args.field_position_embedding_size
        else:
            self.field_kw_pos_embed = None

        # Embedding for POS tags of field words.
        if args.field_tag_usage == 'general':
            self.field_tag_embed = rnn_helper.build_embedding(
                args.field_pos_tag_vocab_size, args.field_tag_embedding_size)
            self.field_tag_embed_size = args.field_tag_embedding_size
        else:
            self.field_tag_embed = None
            self.field_tag_embed_size = 0

        # Whether field words use both the field-word embedding and the
        # standard word embedding at the same time (dual embedding).
        self.dual_field_word_embedding = args.dual_field_word_embedding
        if args.field_word_vocab_path != 'none':
            logger.info('[Model] Use a separate field word embedding ')
            self.field_word_embedding = rnn_helper.build_embedding(
                args.field_word_vocab_size, args.embed_size)
            self.field_word_embed_size = args.embed_size
            if args.dual_field_word_embedding:
                # Dual mode concatenates two embeddings, doubling the width.
                self.field_word_embed_size = args.embed_size * 2
        else:
            assert args.dual_field_word_embedding is False, 'requires a separate field vocab'
            self.field_word_embedding = self.src_embed
            self.field_word_embed_size = args.embed_size

        # Dual attention projects the key+position+tag features to hidden_size.
        self.dual_attn = args.dual_attn != 'none'
        if self.dual_attn:
            self.dual_attn_projection = nn.Linear(
                self.field_key_embed_size + self.field_all_pos_embed_size +
                self.field_tag_embed_size,
                args.hidden_size,
                bias=False)