Example #1
    def __init__(self,
                 batch_size,
                 hidden_size,
                 emb_size,
                 field_size,
                 pos_size,
                 source_vocab,
                 field_vocab,
                 position_vocab,
                 target_vocab,
                 field_concat,
                 position_concat,
                 fgate_enc,
                 dual_att,
                 encoder_add_pos,
                 decoder_add_pos,
                 learning_rate,
                 scope_name,
                 name,
                 start_token=2,
                 stop_token=2,
                 max_length=150):
        '''
        batch_size, hidden_size, emb_size, field_size, pos_size: sizes of the batch, the hidden layer, and the word/field/position embeddings
        source_vocab, target_vocab, field_vocab, position_vocab: vocabulary sizes of the encoder words, decoder words, field types, and positions
        field_concat, position_concat: whether to concatenate the field/position embeddings to the word embeddings of the encoder inputs
        fgate_enc, dual_att: whether to use field gating / dual attention
        encoder_add_pos, decoder_add_pos: whether to add position embeddings to the field-gating encoder / to the decoder with dual attention
        '''
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.emb_size = emb_size
        self.field_size = field_size
        self.pos_size = pos_size
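        # derived input widths: the encoder word embedding is widened by the field
        # embedding when field_concat is set, and by the two position embeddings when
        # position_concat is set (e.g. emb_size=400, field_size=50, pos_size=5 with
        # both flags on gives uni_size = 400 + 50 + 2 * 5 = 460)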
        self.uni_size = emb_size if not field_concat else emb_size + field_size
        self.uni_size = self.uni_size if not position_concat else self.uni_size + 2 * pos_size
        self.field_encoder_size = field_size if not encoder_add_pos else field_size + 2 * pos_size
        self.field_attention_size = field_size if not decoder_add_pos else field_size + 2 * pos_size
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab
        self.field_vocab = field_vocab
        self.position_vocab = position_vocab
        self.grad_clip = 5.0
        self.start_token = start_token
        self.stop_token = stop_token
        self.max_length = max_length
        self.scope_name = scope_name
        self.name = name
        self.field_concat = field_concat
        self.position_concat = position_concat
        self.fgate_enc = fgate_enc
        self.dual_att = dual_att
        self.encoder_add_pos = encoder_add_pos
        self.decoder_add_pos = decoder_add_pos

        self.units = {}
        self.params = {}

        self.encoder_input = tf.placeholder(tf.int32, [None, None])
        self.encoder_field = tf.placeholder(tf.int32, [None, None])
        self.encoder_pos = tf.placeholder(tf.int32, [None, None])
        self.encoder_rpos = tf.placeholder(tf.int32, [None, None])
        self.decoder_input = tf.placeholder(tf.int32, [None, None])
        self.encoder_len = tf.placeholder(tf.int32, [None])
        self.decoder_len = tf.placeholder(tf.int32, [None])
        self.decoder_output = tf.placeholder(tf.int32, [None, None])
        self.enc_mask = tf.sign(tf.to_float(self.encoder_pos))
        with tf.variable_scope(scope_name):
            if self.fgate_enc:
                print('field-gated encoder LSTM')
                self.enc_lstm = fgateLstmUnit(self.hidden_size, self.uni_size,
                                              self.field_encoder_size,
                                              'encoder_select')
            else:
                print('normal encoder LSTM')
                self.enc_lstm = LstmUnit(self.hidden_size, self.uni_size,
                                         'encoder_lstm')
            self.dec_lstm = LstmUnit(self.hidden_size, self.emb_size,
                                     'decoder_lstm')
            self.dec_out = OutputUnit(self.hidden_size, self.target_vocab,
                                      'decoder_output')

        self.units.update({
            'encoder_lstm': self.enc_lstm,
            'decoder_lstm': self.dec_lstm,
            'decoder_output': self.dec_out
        })

        # ======================================== embeddings ======================================== #
        with tf.device('/cpu:0'):
            with tf.variable_scope(scope_name):
                self.embedding = tf.get_variable(
                    'embedding', [self.source_vocab, self.emb_size])
                self.encoder_embed = tf.nn.embedding_lookup(
                    self.embedding, self.encoder_input)
                self.decoder_embed = tf.nn.embedding_lookup(
                    self.embedding, self.decoder_input)
                if self.field_concat or self.fgate_enc or self.encoder_add_pos or self.decoder_add_pos:
                    self.fembedding = tf.get_variable(
                        'fembedding', [self.field_vocab, self.field_size])
                    self.field_embed = tf.nn.embedding_lookup(
                        self.fembedding, self.encoder_field)
                    self.field_pos_embed = self.field_embed
                    if self.field_concat:
                        self.encoder_embed = tf.concat(
                            [self.encoder_embed, self.field_embed], 2)
                if self.position_concat or self.encoder_add_pos or self.decoder_add_pos:
                    self.pembedding = tf.get_variable(
                        'pembedding', [self.position_vocab, self.pos_size])
                    self.rembedding = tf.get_variable(
                        'rembedding', [self.position_vocab, self.pos_size])
                    self.pos_embed = tf.nn.embedding_lookup(
                        self.pembedding, self.encoder_pos)
                    self.rpos_embed = tf.nn.embedding_lookup(
                        self.rembedding, self.encoder_rpos)
                    if self.position_concat:
                        self.encoder_embed = tf.concat([
                            self.encoder_embed, self.pos_embed, self.rpos_embed
                        ], 2)
                        self.field_pos_embed = tf.concat([
                            self.field_embed, self.pos_embed, self.rpos_embed
                        ], 2)
                    elif self.encoder_add_pos or self.decoder_add_pos:
                        self.field_pos_embed = tf.concat([
                            self.field_embed, self.pos_embed, self.rpos_embed
                        ], 2)

        if self.field_concat or self.fgate_enc:
            self.params.update({'fembedding': self.fembedding})
        if self.position_concat or self.encoder_add_pos or self.decoder_add_pos:
            self.params.update({'pembedding': self.pembedding})
            self.params.update({'rembedding': self.rembedding})
        self.params.update({'embedding': self.embedding})

        # ======================================== encoder ======================================== #
        if self.fgate_enc:
            print('field-gated encoder used')
            en_outputs, en_state = self.fgate_encoder(self.encoder_embed,
                                                      self.field_pos_embed,
                                                      self.encoder_len)
        else:
            print('normal encoder used')
            en_outputs, en_state = self.encoder(self.encoder_embed,
                                                self.encoder_len)
        # ======================================== decoder ======================================== #

        if self.dual_att:
            print('dual attention mechanism used')
            with tf.variable_scope(scope_name):
                self.att_layer = dualAttentionWrapper(
                    self.hidden_size, self.hidden_size,
                    self.field_attention_size, en_outputs,
                    self.field_pos_embed, "attention")
                self.units.update({'attention': self.att_layer})
        else:
            print("normal attention used")
            with tf.variable_scope(scope_name):
                self.att_layer = AttentionWrapper(self.hidden_size,
                                                  self.hidden_size, en_outputs,
                                                  "attention")
                self.units.update({'attention': self.att_layer})

        # decoder for training
        de_outputs, de_state = self.decoder_t(en_state, self.decoder_embed,
                                              self.decoder_len)
        # decoder for testing
        self.g_tokens, self.atts = self.decoder_g(en_state)
        # self.beam_seqs, self.beam_probs, self.cand_seqs, self.cand_probs = self.decoder_beam(en_state, beam_size)

        # per-token cross entropy; the sign mask below zeroes out positions whose
        # target id is 0 (padding)
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=de_outputs, labels=self.decoder_output)
        mask = tf.sign(tf.to_float(self.decoder_output))
        losses = mask * losses
        self.mean_loss = tf.reduce_mean(losses)

        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.mean_loss, tvars),
                                          self.grad_clip)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))
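
A minimal instantiation sketch for the constructor above. The class name SeqUnit, the module it is imported from, and all hyperparameter values are assumptions for illustration only (the excerpt shows only the __init__ body); TensorFlow 1.x is assumed.

import tensorflow as tf
from seq_unit import SeqUnit  # hypothetical module/class name, not shown in the excerpt

model = SeqUnit(batch_size=32,
                hidden_size=500,
                emb_size=400,
                field_size=50,
                pos_size=5,
                source_vocab=20003,
                field_vocab=1480,
                position_vocab=31,
                target_vocab=20003,
                field_concat=True,
                position_concat=True,
                fgate_enc=True,
                dual_att=True,
                encoder_add_pos=True,
                decoder_add_pos=True,
                learning_rate=0.0003,
                scope_name='seq2seq',
                name='seq2seq')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # feed the encoder/decoder placeholders and run model.train_op / model.mean_loss here
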
Example #2
    def __init__(self,
                 batch_size,
                 hidden_size,
                 emb_size,
                 field_size,
                 pos_size,
                 source_vocab,
                 field_vocab,
                 position_vocab,
                 target_vocab,
                 field_concat,
                 position_concat,
                 fgate_enc,
                 dual_att,
                 encoder_add_pos,
                 decoder_add_pos,
                 learning_rate,
                 scope_name,
                 name,
                 use_coverage,
                 coverage_penalty,
                 fieldid2word,
                 copy_gate_penalty,
                 use_copy_gate,
                 gpt_hparams,
                 vocab_ind,
                 empty_token=28920,
                 stop_token=50256,
                 max_length=85):
        '''
        batch_size, hidden_size, emb_size, field_size, pos_size: sizes of the batch, the hidden layer, and the word/field/position embeddings
        source_vocab, target_vocab, field_vocab, position_vocab: vocabulary sizes of the encoder words, decoder words, field types, and positions
        field_concat, position_concat: whether to concatenate the field/position embeddings to the word embeddings of the encoder inputs
        fgate_enc, dual_att: whether to use field gating / dual attention
        encoder_add_pos, decoder_add_pos: whether to add position embeddings to the field-gating encoder / to the decoder with dual attention

        empty_token=28920 and stop_token=50256 are indices in the original full GPT vocabulary
        '''

        # data options
        self.empty_token = empty_token
        self.stop_token = stop_token
        self.max_length = max_length
        self.start_token = empty_token
        self.select_ind = vocab_ind
        self.fieldid2word = fieldid2word

        # model hyperparams
        self.gpt_hparams = gpt_hparams
        self.hidden_size = self.gpt_hparams.n_embd

        # model architecture options
        self.use_coverage = use_coverage
        self.coverage_penalty = coverage_penalty
        self.use_copy_gate = use_copy_gate
        self.copy_gate_penalty = copy_gate_penalty
        self.fgate_enc = fgate_enc
        self.dual_att = dual_att
        self.scope_name = scope_name
        self.name = name

        # embedding sizes
        self.emb_size = self.gpt_hparams.n_embd  # word embedding size
        self.field_size = field_size  # field embedding size
        self.pos_size = pos_size  # position embedding size
        self.field_concat = field_concat
        self.position_concat = position_concat
        self.encoder_add_pos = encoder_add_pos
        self.decoder_add_pos = decoder_add_pos
        self.uni_size = self.emb_size if not field_concat else self.emb_size + field_size
        self.uni_size = self.uni_size if not position_concat else self.uni_size + 2 * pos_size
        self.field_encoder_size = field_size if not encoder_add_pos else field_size + 2 * pos_size
        self.field_attention_size = field_size if not decoder_add_pos else field_size + 2 * pos_size
        self.dec_input_size = self.emb_size + field_size + 2 * pos_size  # FIXME not conditioned?

        # source and target vocabulary sizes, field and position vocabulary sizes
        self.source_vocab = self.gpt_hparams.n_vocab
        self.target_vocab = self.gpt_hparams.n_vocab
        self.field_vocab = field_vocab
        self.position_vocab = position_vocab

        # training options
        self.grad_clip = 5.0

        self.units = {}
        self.params = {}

        self.define_input_placeholders()

        self.define_encoder_unit()

        context_outputs = self.define_decoder_arch()

        # get GPT embeddings
        self.lookup_all_embeddings()

        self.define_encoder_arch()

        # attention and copy layers
        if self.dual_att:
            print('dual attention mechanism used')
            with tf.variable_scope(scope_name):
                self.att_layer = dualAttentionWrapper(
                    self.dec_input_size, self.hidden_size, self.hidden_size,
                    self.field_attention_size, "attention")
                self.units.update({'attention': self.att_layer})
        else:
            print("normal attention used")
            with tf.variable_scope(scope_name):
                self.att_layer = AttentionWrapper(self.hidden_size,
                                                  self.hidden_size,
                                                  self.en_outputs, "attention")
                self.units.update({'attention': self.att_layer})

        # loss functions
        # positions where field values are present (decoder_pos_input > 0)
        self.copy_gate_mask = tf.cast(
            tf.greater(self.decoder_pos_input,
                       tf.zeros_like(self.decoder_pos_input)), tf.float32)
        self.copy_gate_mask = tf.concat([
            self.copy_gate_mask,
            tf.zeros([tf.shape(self.encoder_input)[0], 1], tf.float32)
        ], 1)
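        # the concat appends one all-zero column per example, so the mask is one time
        # step longer than decoder_pos_input (presumably to cover the final stop token)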

        # decoder for training
        # initial values (last-step logits, presents and hidden state) used to start GPT generation
        logits0 = context_outputs['logits'][:, -1, :]
        dist0 = tf.nn.softmax(logits0)  # start token
        x0 = tf.cast(tf.argmax(dist0, 1), tf.int32)
        past0 = context_outputs['presents']
        hidden0 = context_outputs['hidden'][:, -1, :]

        de_outputs, _, self.de_conv_loss, self.copy_gate_loss = self.decoder_t(
            self.decoder_input, self.decoder_len, x0, past0, hidden0)

        # decoder for testing
        self.g_tokens, self.atts = self.decoder_g(x0, past0, hidden0)

        ### enc-dec loss
        self.decoder_output_one_hot = tf.one_hot(indices=self.decoder_output,
                                                 depth=self.target_vocab,
                                                 axis=-1)

        # mask over the decoder positions, with one extra step for the EOS token
        dec_shape_len = tf.shape(self.decoder_output)[1]
        batch_nums = tf.range(0, dec_shape_len)
        batch_nums = tf.expand_dims(batch_nums, 0)
        batch_nums = tf.tile(batch_nums, [self.batch_size, 1])
        decoder_len_com = tf.expand_dims(self.decoder_len, 1)
        decoder_len_com = tf.tile(decoder_len_com, [1, dec_shape_len])
        mask = tf.cast(tf.less_equal(batch_nums, decoder_len_com), tf.float32)

        # main cross-entropy loss: negative log-likelihood of the reference tokens
        losses = -tf.reduce_sum(
            self.decoder_output_one_hot * tf.log(de_outputs + 1e-6), 2)
        losses = mask * losses

        # summed rather than averaged (faster than the original reduce_mean)
        self.mean_loss = tf.reduce_sum(losses)

        self.de_conv_loss *= self.coverage_penalty

        self.copy_gate_loss = self.copy_gate_penalty * tf.reduce_sum(
            self.copy_gate_loss)

        if self.use_copy_gate:
            self.mean_loss += self.copy_gate_loss

        if self.use_coverage:
            self.mean_loss += self.de_conv_loss

        train_params = tf.trainable_variables()

        # train enc-dec
        with tf.variable_scope(scope_name):
            self.global_step = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)
            self.grads, _ = tf.clip_by_global_norm(
                tf.gradients(self.mean_loss,
                             train_params,
                             colocate_gradients_with_ops=True), self.grad_clip)

            # accumulate gradient
            self.opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
            self.acc_gradients = list(
                map(
                    lambda param: tf.get_variable(param.name.split(":")[0],
                                                  param.get_shape(),
                                                  param.dtype,
                                                  tf.constant_initializer(0.0),
                                                  trainable=False),
                    train_params))

            # accumulator variables for the loss terms, reset together with the accumulated gradients
            self._loss = tf.get_variable("acc_loss", (),
                                         tf.float32,
                                         tf.constant_initializer(0.0),
                                         trainable=False)
            self._cov_loss = tf.get_variable("acc_cov_loss", (),
                                             tf.float32,
                                             tf.constant_initializer(0.0),
                                             trainable=False)
            self._gate_loss = tf.get_variable("acc_gate_loss", (),
                                              tf.float32,
                                              tf.constant_initializer(0.0),
                                              trainable=False)

            self.accumulate_gradients()

            # train update
            self.update = self.opt.apply_gradients(
                zip(list(map(lambda v: v.value(), self.acc_gradients)),
                    train_params),
                global_step=self.global_step)

            # collect all values to reset after updating with accumulated gradient
            self.reset = list(
                map(lambda param: param.initializer, self.acc_gradients))
            self.reset.append(self._loss.initializer)
            self.reset.append(self._cov_loss.initializer)
            self.reset.append(self._gate_loss.initializer)

            self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
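
A sketch of how the accumulated-gradient update above might be driven at training time. The names model.accumulate_op, make_feed_dict, next_accumulation_batches, and num_updates are hypothetical stand-ins for code not shown in the excerpt; only model.update and model.reset come from the constructor itself.

import tensorflow as tf

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(num_updates):  # hypothetical number of parameter updates
        # accumulate gradients and loss terms over several mini-batches
        for batch in next_accumulation_batches():  # hypothetical mini-batch iterator
            sess.run(model.accumulate_op, feed_dict=make_feed_dict(batch))
        # apply the accumulated gradients once, then clear the gradient and loss accumulators
        sess.run(model.update)
        sess.run(model.reset)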