def __get_decoder(self, input_layer, src_encoder_output, mutual_tar_src_mask,
                  fact_encoder_output, mutual_tar_fact_mask):
    print('This is in Decoder...')
    # Mask out padded target positions for decoder self-attention.
    self_padding_mask = PaddingMaskLayer(name='decoder_self_padding_mask',
                                         src_len=self.args.tar_seq_length,
                                         pad_id=self.pad_id)(input_layer)
    # Look-ahead (sequence) mask so each position only attends to earlier positions.
    seq_mask = SequenceMaskLayer()(input_layer)
    self_attn_mask = Add()([self_padding_mask, seq_mask])
    # A summed value greater than 1 means the position is unmasked (not padded)
    # in both self_padding_mask and seq_mask.
    self_attn_mask = Lambda(lambda x: K.cast(K.greater(x, 1), dtype='int32'),
                            name='add_padding_seq_mask')(self_attn_mask)

    next_step_input, self.decoder_embedding_matrix = self.decoder_embedding_layer(input_layer)
    next_step_input = self.decoder_coord_embedding_layer(next_step_input, step=0)

    for i in range(self.args.transformer_depth):
        decoder_block = TEDDecoderBlock(
            name='transformer_decoder' + str(i),
            num_heads=self.args.num_heads,
            fact_number=self.args.fact_number,
            residual_dropout=self.transformer_dropout,
            attention_dropout=self.transformer_dropout,
            activation='relu',
            vanilla_wiring=True)  # use vanilla Transformer wiring instead of Universal Transformer
        next_step_input = decoder_block([next_step_input, self_attn_mask,
                                         src_encoder_output, mutual_tar_src_mask,
                                         fact_encoder_output, mutual_tar_fact_mask])
    return next_step_input
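# A minimal NumPy sketch (not part of this repo) of the mask combination used in
# __get_decoder above. It assumes PaddingMaskLayer and SequenceMaskLayer both emit
# 0/1 integer masks in which 1 marks an attendable position; the actual layers here
# may use different shapes or broadcasting.
import numpy as np

def combine_masks_sketch(token_ids, pad_id=0):
    """Positions stay attendable only where the padding and look-ahead masks agree."""
    token_ids = np.asarray(token_ids)
    seq_len = len(token_ids)
    # 1 where the attended (key) position is not padding, repeated for every query position.
    padding_mask = np.tile((token_ids != pad_id).astype(np.int32), (seq_len, 1))
    # Lower-triangular look-ahead mask: query i may only attend to keys 0..i.
    seq_mask = np.tril(np.ones((seq_len, seq_len), dtype=np.int32))
    # Same trick as the Lambda layer above: add the masks and keep entries greater than 1.
    return ((padding_mask + seq_mask) > 1).astype(np.int32)

# combine_masks_sketch([5, 7, 9, 0, 0]) keeps only the lower triangle of the first
# three (non-pad) columns.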
def get_model(self, pad_id):
    self.pad_id = pad_id
    inp_src = Input(name='src_input', shape=(self.args.src_seq_length,), dtype='int32')
    inp_tar = Input(name='answer_input', shape=(self.args.tar_seq_length,), dtype='int32')
    # shape: (batch_size, fact_number, src_seq_length)
    inp_facts = Input(name='facts_input',
                      shape=(self.args.fact_number, self.args.src_seq_length,),
                      dtype='int32')

    # Add a singleton axis so the source goes through the same encoder interface
    # as the facts: (batch_size, 1, src_seq_length).
    inp_src_expand = Lambda(lambda x: K.expand_dims(x, axis=1))(inp_src)
    src_encoder_output = self.__get_encoder(inp_src_expand, 'src')
    fact_encoder_output = self.__get_encoder(inp_facts, 'fact')

    # Padding masks for the decoder's cross-attention over the source and the facts.
    mutual_tar_src_mask = PaddingMaskLayer(name='mutual_tar_src_mask',
                                           src_len=self.args.tar_seq_length,
                                           pad_id=self.pad_id)(inp_src_expand)
    mutual_tar_fact_mask = PaddingMaskLayer(name='mutual_tar_fact_mask',
                                            src_len=self.args.tar_seq_length,
                                            pad_id=self.pad_id)(inp_facts)

    inp_tar_expand = Lambda(lambda x: K.expand_dims(x, axis=1))(inp_tar)
    decoder_output = self.__get_decoder(inp_tar_expand,
                                        src_encoder_output, mutual_tar_src_mask,
                                        fact_encoder_output, mutual_tar_fact_mask)
    decoder_output = Reshape((self.args.tar_seq_length, self.args.embedding_dim,))(decoder_output)

    # build model part
    word_predictions = self.output_softmax_layer(
        self.output_layer([decoder_output, self.decoder_embedding_matrix]))
    print('word_predictions: ', word_predictions)

    model = Model(inputs=[inp_src, inp_tar, inp_facts], outputs=[word_predictions])
    return model
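# A rough NumPy sketch (an assumption, not the repo's actual layer) of what the
# output_layer / output_softmax_layer pair above computes if output_layer ties the
# output projection to the decoder embedding matrix, i.e. multiplies decoder states
# by the transposed embedding matrix before the softmax. The real layer may add a
# bias or differ in detail.
import numpy as np

def tied_projection_sketch(decoder_output, embedding_matrix):
    """decoder_output: (tar_seq_length, embedding_dim); embedding_matrix: (vocab, embedding_dim).
    Returns per-position probabilities over the vocabulary: (tar_seq_length, vocab)."""
    logits = decoder_output @ embedding_matrix.T
    logits = logits - logits.max(axis=-1, keepdims=True)   # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=-1, keepdims=True)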
def __get_query_encoder(self, input_layer, pad_id, _name):
    # print('This is Query Encoder...')
    self_attn_mask = PaddingMaskLayer(src_len=self.args.src_seq_length,
                                      pad_id=pad_id)(input_layer)
    next_step_input, _ = self.query_embedding_layer(input_layer)
    next_step_input = self.query_coord_embedding_layer(next_step_input, step=0)
    for i in range(self.args.transformer_depth):
        next_step_input = self.query_encoder_blocks[i]([next_step_input, self_attn_mask])
    return next_step_input
def __get_encoder(self, input_layer):
    print('This is in Encoder...')
    self_attn_mask = PaddingMaskLayer(name='encoder_self_padding_mask',
                                      src_len=self.args.src_seq_length,
                                      pad_id=self.pad_id)(input_layer)
    next_step_input, _ = self.encoder_embedding_layer(input_layer)
    next_step_input = self.encoder_coord_embedding_layer(next_step_input, step=0)
    for i in range(self.args.transformer_depth):
        encoder_block = TransformerEncoderBlock(
            name='transformer_encoder' + str(i),
            num_heads=self.args.num_heads,
            residual_dropout=self.transformer_dropout,
            attention_dropout=self.transformer_dropout,
            activation='relu',
            vanilla_wiring=True)  # use vanilla Transformer wiring instead of Universal Transformer
        next_step_input = encoder_block([next_step_input, self_attn_mask])
    return next_step_input
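# Note the wiring difference between the two encoder builders: __get_encoder creates
# a fresh TransformerEncoderBlock per depth on every call, whereas __get_query_encoder
# indexes into self.query_encoder_blocks, so repeated calls reuse the same weights.
# A minimal tf.keras sketch of that distinction, with Dense standing in for a
# Transformer block (shapes here are illustrative only, not from this repo):
from tensorflow.keras.layers import Dense, Input

x1 = Input(shape=(16,))
x2 = Input(shape=(16,))

shared_block = Dense(16, name='shared_block')        # built once, reused -> shared weights
y1_shared, y2_shared = shared_block(x1), shared_block(x2)

y1_fresh = Dense(16, name='fresh_block_a')(x1)       # new instance per call -> separate weights
y2_fresh = Dense(16, name='fresh_block_b')(x2)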
def get_model(self, pad_id):
    self.pad_id = pad_id
    inp_src = Input(name='src_input', shape=(None,), dtype='int32')
    inp_answer = Input(name='answer_input', shape=(None,), dtype='int32')

    encoder_output = self.__get_encoder(inp_src)
    # Padding mask for the decoder's cross-attention over the source sequence.
    mutual_attn_mask = PaddingMaskLayer(name='decoder_mutual_padding_mask',
                                        src_len=self.args.tar_seq_length,
                                        pad_id=self.pad_id)(inp_src)
    decoder_output = self.__get_decoder(inp_answer, encoder_output, mutual_attn_mask)

    # build model part
    word_predictions = self.output_softmax_layer(
        self.output_layer([decoder_output, self.decoder_embedding_matrix]))

    model = Model(inputs=[inp_src, inp_answer], outputs=[word_predictions])
    return model
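# A small NumPy sketch (an assumption about PaddingMaskLayer, not its actual code) of
# the cross-attention mask built above. Its src_len argument is set to the *target*
# length, which is consistent with a mask that lets every decoder (query) position
# attend to every non-pad source (key) position.
import numpy as np

def cross_attention_padding_mask_sketch(src_ids, tar_len, pad_id=0):
    """Returns a (tar_len, src_len) 0/1 mask: 1 where the source token is not padding."""
    key_ok = (np.asarray(src_ids) != pad_id).astype(np.int32)   # (src_len,)
    return np.tile(key_ok, (tar_len, 1))

# cross_attention_padding_mask_sketch([4, 8, 15, 0, 0], tar_len=3) ->
# three identical rows [1, 1, 1, 0, 0].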