Example #1
    def call(self, x, y, mask, training=False):
        self.step += 1
        x_ = x

        x = dropout(x, keep_prob=self.keep_prob, training=training)
        y = dropout(y, keep_prob=self.keep_prob, training=training)

        if self.step == 0:
            if not self.identity:
                self.linear = layers.Dense(melt.get_shape(x, -1),
                                           activation=tf.nn.relu)
            else:
                self.linear = None

        # NOTICE shared linear!
        if self.linear is not None:
            x = self.linear(x)
            y = self.linear(y)

        scores = tf.matmul(x, tf.transpose(y, [0, 2, 1]))

        if mask is not None:
            JX = melt.get_shape(x, 1)
            mask = tf.tile(tf.expand_dims(mask, axis=1), [1, JX, 1])
            scores = softmax_mask(scores, mask)

        alpha = tf.nn.softmax(scores)
        self.alpha = alpha

        y = tf.matmul(alpha, y)

        if self.combine is None:
            return y
        else:
            return self.combine(x_, y, training=training)
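For reference, a minimal NumPy sketch of the attention mechanics in Example #1 above: scores = x·yᵀ, masked positions pushed to a large negative value, a softmax over the last axis, then a weighted sum of y. The shapes are illustrative, and the shared Dense projection and dropout are deliberately omitted.

import numpy as np

def dot_attention_sketch(x, y, y_mask):
    # x: [batch, len_x, dim], y: [batch, len_y, dim], y_mask: [batch, len_y] (True = keep)
    scores = x @ np.transpose(y, (0, 2, 1))               # [batch, len_x, len_y]
    scores = np.where(y_mask[:, None, :], scores, -1e30)  # softmax_mask analogue
    scores = scores - scores.max(axis=-1, keepdims=True)
    alpha = np.exp(scores)
    alpha = alpha / alpha.sum(axis=-1, keepdims=True)     # softmax over the last axis
    return alpha @ y                                      # attended summary of y

x = np.random.rand(2, 4, 8)
y = np.random.rand(2, 5, 8)
y_mask = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=bool)
print(dot_attention_sketch(x, y, y_mask).shape)           # (2, 4, 8)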
Example #2
    def call(self, x, training=False):
        x = x['comment']
        batch_size = melt.get_shape(x, 0)
        length = melt.length(x)
        #with tf.device('/cpu:0'):
        x = self.embedding(x)

        num_units = [
            melt.get_shape(x, -1) if layer == 0 else 2 * self.num_units
            for layer in range(self.num_layers)
        ]
        #print('----------------length', tf.reduce_max(length), inputs.comment.shape)
        mask_fws = [
            melt.dropout(tf.ones([batch_size, 1, num_units[layer]],
                                 dtype=tf.float32),
                         keep_prob=self.keep_prob,
                         training=training,
                         mode=None) for layer in range(self.num_layers)
        ]
        mask_bws = [
            melt.dropout(tf.ones([batch_size, 1, num_units[layer]],
                                 dtype=tf.float32),
                         keep_prob=self.keep_prob,
                         training=training,
                         mode=None) for layer in range(self.num_layers)
        ]
        #x = self.encode(x, length, mask_fws=mask_fws, mask_bws=mask_bws)
        x = self.encode(x)

        x = self.pooling(x, length)
        #x = self.pooling(x)
        x = self.logits(x)
        return x
Example #3
 def call(self, x, y, training=False):
     self.step += 1
     if melt.get_shape(x, -1) != melt.get_shape(y, -1):
         if self.step == 0:
             self.dense = layers.Dense(melt.get_shape(x, -1),
                                       activation=None,
                                       name='sfu_dense')
         y = self.dense(y)
     return self.sfu(x, [y, x * y, x - y], training=training)
Example #4
    def call(self, input, training=False):
        x1 = input['query']
        x2 = input['passage']
        length1 = melt.length(x1)
        length2 = melt.length(x2)
        #with tf.device('/cpu:0'):
        x1 = self.embedding(x1)
        x2 = self.embedding(x2)

        x = x1
        batch_size = melt.get_shape(x1, 0)

        num_units = [
            melt.get_shape(x, -1) if layer == 0 else 2 * self.num_units
            for layer in range(self.num_layers)
        ]
        #print('----------------length', tf.reduce_max(length), inputs.comment.shape)
        mask_fws = [
            melt.dropout(tf.ones([batch_size, 1, num_units[layer]],
                                 dtype=tf.float32),
                         keep_prob=self.keep_prob,
                         training=training,
                         mode=None) for layer in range(self.num_layers)
        ]
        mask_bws = [
            melt.dropout(tf.ones([batch_size, 1, num_units[layer]],
                                 dtype=tf.float32),
                         keep_prob=self.keep_prob,
                         training=training,
                         mode=None) for layer in range(self.num_layers)
        ]

        x = self.encode(x1,
                        length1,
                        x2,
                        length2,
                        mask_fws=mask_fws,
                        mask_bws=mask_bws)
        x = self.pooling(x, length1, length2)
        #x = self.pooling(x)

        if FLAGS.use_type:
            x = tf.concat([x, tf.expand_dims(tf.to_float(input['type']), 1)],
                          1)

        if not FLAGS.split_type:
            x = self.logits(x)
        else:
            x1 = self.logits(x)
            x2 = self.logits2(x)
            x = tf.cond(tf.cast(input['type'] == 0, tf.bool), lambda:
                        (x1 + x2) / 2., lambda: x2)

        return x
Example #5
 def call(self, x):
     self.step += 1
     if self.step == 0:
         self.dense = layers.Dense(
             melt.get_shape(x, -1) * self.ratio, self.activation,
             self.use_bais)
     return self.dense(x)
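Example #5 builds its Dense sub-layer lazily on the first call; presumably self.step starts at -1 in __init__ so that step == 0 on the first invocation. A minimal tf.keras sketch of the same deferred-construction pattern (the class name LazyProjection and the default ratio are illustrative, not part of the original code):

import tensorflow as tf

class LazyProjection(tf.keras.layers.Layer):
    """Create the Dense sub-layer on the first call, sized from the input's last dim."""

    def __init__(self, ratio=2, activation=None, **kwargs):
        super(LazyProjection, self).__init__(**kwargs)
        self.ratio = ratio
        self.activation = activation
        self.step = -1
        self.dense = None

    def call(self, x):
        self.step += 1
        if self.step == 0:
            units = int(x.shape[-1]) * self.ratio
            self.dense = tf.keras.layers.Dense(units, activation=self.activation)
        return self.dense(x)

# y = LazyProjection(ratio=2)(tf.zeros([3, 4]))   # -> shape [3, 8]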
Example #6
  def encode(self, seq, seq_len=None, output_method='all'):
    with tf.variable_scope(self.scope):
      num_filters = self.num_units
      seqs = [seq]
      batch_size = melt.get_batch_size(seq)
     
      kernel_sizes = [3, 5, 7, 9, 11, 13]
      #kernel_sizes = [3] * 7
      assert self.num_layers <= len(kernel_sizes)

      for layer in range(self.num_layers):
        input_size_ = melt.get_shape(seq, -1) if layer == 0 else num_filters
        seq = melt.dropout(seq, self.keep_prob, self.is_train)
        seq = tf.layers.conv1d(seqs[-1], num_filters, kernel_size=kernel_sizes[layer], padding='same', activation=tf.nn.relu)
        # mask = melt.dropout(tf.ones([batch_size, 1, input_size_], dtype=tf.float32),
        #                   keep_prob=self.keep_prob, is_train=self.is_train, mode=None)
        #seq = tf.layers.conv1d(seqs[-1] * mask, num_filters, kernel_size=3, padding='same', activation=tf.nn.relu)
        #seq = tf.layers.conv1d(seqs[-1] * mask, num_filters, kernel_size=kernel_sizes[layer], padding='same', activation=tf.nn.relu)
        
        # if self.is_train and self.keep_prob < 1:
        #   seq = tf.nn.dropout(seq, self.keep_prob)
        #seq = melt.layers.batch_norm(seq, self.is_train, name='layer_%d' % layer)
        seqs.append(seq)
      
      outputs = tf.concat(seqs[1:], 2)
      # do not do any dropout in the convnet, just dropout outside
      # if self.is_train and self.keep_prob < 1:
      #   outputs = tf.nn.dropout(outputs, self.keep_prob)

      # compatible with rnn that returns state
      return melt.rnn.encode_outputs(outputs, seq_len, output_method)
Example #7
    def call(self, input, training=False):
        x = input['rcontent'] if FLAGS.rcontent else input['content']
        #print(x.shape)
        batch_size = melt.get_shape(x, 0)
        length = melt.length(x)
        #with tf.device('/cpu:0'):
        x = self.embedding(x)

        x = self.encode(x, length, training=training)

        # must we mask pooling at eval time? but that seems to give much worse results
        #if not FLAGS.mask_pooling and training:
        if not FLAGS.mask_pooling:
            length = None
        x = self.pooling(x, length)

        if FLAGS.use_type:
            x = tf.concat([x, tf.expand_dims(tf.to_float(input['type']), 1)],
                          1)

        if not FLAGS.split_type:
            x = self.logits(x)
        else:
            x1 = self.logits(x)
            x2 = self.logits2(x)
            x = tf.cond(tf.cast(input['type'] == 0, tf.bool), lambda:
                        (x1 + x2) / 2., lambda: x2)

        return x
Example #8
 def aug(x, x_mask):
     # print('---------------do unk aug')
     # print('ori_x....', x)
     if x_mask is None:
         x_mask = x > 0
     x_mask = tf.to_int64(x_mask)
     ratio = tf.random_uniform([
         1,
     ], 0, FLAGS.unk_aug_max_ratio)
     mask = tf.random_uniform(
         [melt.get_shape(x, 0),
          melt.get_shape(x, 1)]) > ratio
     mask = tf.to_int64(mask)
     rmask = FLAGS.unk_id * (1 - mask)
     x = (x * mask + rmask) * x_mask
     #print('aug_x....', x)
     return x
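A minimal NumPy sketch of the UNK augmentation above: draw one replacement ratio, sample a per-position keep mask, overwrite dropped ids with the unk id, and keep padding (id 0) at zero. The unk_id and max_ratio arguments stand in for FLAGS.unk_id and FLAGS.unk_aug_max_ratio.

import numpy as np

def unk_aug_sketch(x, unk_id=1, max_ratio=0.1, rng=np.random):
    # x: int ids, 0 = padding; replace a random fraction of non-pad tokens with unk_id
    x_mask = (x > 0).astype(np.int64)
    ratio = rng.uniform(0, max_ratio)
    keep = (rng.uniform(size=x.shape) > ratio).astype(np.int64)
    return (x * keep + unk_id * (1 - keep)) * x_mask

x = np.array([[3, 7, 9, 0, 0], [4, 4, 2, 8, 0]], dtype=np.int64)
print(unk_aug_sketch(x, unk_id=1, max_ratio=0.5))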
Example #9
def roc_auc_scores(y_pred, y_true):
    num_classes = melt.get_shape(y_pred, -1)
    y_preds = tf.split(y_pred, num_classes, axis=1)
    y_trues = tf.split(y_true, num_classes, axis=1)
    losses = []
    for y_pred, y_true in zip(y_preds, y_trues):
        losses.append(roc_auc_score(y_pred, y_true))
        #losses.append(art_loss(y_pred, y_true))
    return tf.stack(losses)
Example #10
    def call(self, input, training=False):
        x = input['content']
        x = self.unk_aug(x, training=training)

        c_mask = tf.cast(x, tf.bool)
        batch_size = melt.get_shape(x, 0)
        c_len, max_c_len = melt.length2(x)

        if FLAGS.rnn_no_padding:
            logging.info('------------------no padding! train or eval')
            #c_len = tf.ones([batch_size], dtype=x.dtype) * tf.cast(melt.get_shape(x, -1), x.dtype)
            c_len = max_c_len

        x = self.encode(input, c_len, max_c_len, training=training)

        # not help
        if self.hier_encode is not None:
            x = self.hier_encode(x, c_len)

        # yes just using label emb..
        label_emb = self.label_embedding(None)
        label_seq = tf.tile(tf.expand_dims(label_emb, 0), [batch_size, 1, 1])
        label_seq = self.label_encode(label_seq,
                                      tf.ones([batch_size], dtype=tf.int64) *
                                      self.label_emb_height,
                                      training=training)

        for i in range(FLAGS.hop):
            x = self.att_dot_attentions[i](
                x,
                label_seq,
                mask=tf.ones([batch_size, self.label_emb_height], tf.bool),
                training=training)
            x = self.att_encodes[i](x, c_len, training=training)
            x = self.match_dot_attentions[i](x, mask=c_mask, training=training)
            x = self.match_encodes[i](x, c_len, training=training)

        x = self.pooling(x, c_len, calc_word_scores=self.debug)
        #x = self.pooling(x)

        # not help much
        if self.dense is not None:
            x = self.dense(x)
            x = self.dropout(x, training=training)

        if not FLAGS.use_label_emb:
            x = self.logits(x)
        else:
            x = self.label_dense(x)
            # TODO..
            x = melt.dot(x, self.label_embedding(None))

        x = tf.reshape(x, [batch_size, NUM_ATTRIBUTES, self.num_classes])

        return x
Example #11
 def call(self, x, y, training=False):
     self.step += 1
     #with tf.variable_scope(self.scope):
     res = tf.concat([x, y], axis=2)
     dim = melt.get_shape(res, -1)
     d_res = dropout(res, keep_prob=self.keep_prob, training=training)
     if self.step == 0:
         self.dense = layers.Dense(dim,
                                   use_bias=False,
                                   activation=tf.nn.sigmoid)
     gate = self.dense(d_res)
     return res * gate
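A minimal NumPy sketch of the gating in Example #11 above: a sigmoid gate is computed from the concatenation [x, y] and rescales that same concatenation. The random weight matrix stands in for the bias-free layers.Dense with sigmoid activation, and dropout is omitted.

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

x = np.random.rand(2, 4, 3)
y = np.random.rand(2, 4, 5)
res = np.concatenate([x, y], axis=2)               # [batch, len, 8]
w = np.random.randn(res.shape[-1], res.shape[-1])  # stand-in for the Dense weights
gate = sigmoid(res @ w)
print((res * gate).shape)                          # (2, 4, 8)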
Example #12
    def call(self, x, mask, training=False):
        self.step += 1
        x_ = x
        x = dropout(x, keep_prob=self.keep_prob, training=training)

        if self.step == 0:
            if not self.identity:
                self.linear = layers.Dense(melt.get_shape(x, -1),
                                           activation=tf.nn.relu)
            else:
                self.linear = None

        # NOTICE shared linear!
        if self.linear is not None:
            x = self.linear(x)

        scores = tf.matmul(x, tf.transpose(x, [0, 2, 1]))

        #  x = tf.constant([[[1,2,3], [4,5,6],[7,8,9]],[[1,2,3],[4,5,6],[7,8,9]]], dtype=tf.float32) # shape=(2, 3, 3)
        #  z = tf.matrix_set_diag(x, tf.zeros([2, 3]))
        if not self.diag:
            # TODO better dim
            dim0 = melt.get_shape(scores, 0)
            dim1 = melt.get_shape(scores, 1)
            scores = tf.matrix_set_diag(scores, tf.zeros([dim0, dim1]))

        if mask is not None:
            JX = melt.get_shape(x, 1)
            mask = tf.tile(tf.expand_dims(mask, axis=1), [1, JX, 1])
            scores = softmax_mask(scores, mask)

        alpha = tf.nn.softmax(scores)
        self.alpha = alpha

        x = tf.matmul(alpha, x)

        if self.combine is None:
            return x
        else:
            return self.combine(x_, x, training=training)
Example #13
  def call(self, input, training=False):
    q = input['query']
    c = input['passage']
    q_len = melt.length(q)
    c_len = melt.length(c)
    q_mask = tf.cast(q, tf.bool)
    q_emb = self.embedding(q)
    c_emb = self.embedding(c)
    
    x = c_emb
    batch_size = melt.get_shape(x, 0)

    num_units = [melt.get_shape(x, -1) if layer == 0 else 2 * self.num_units for layer in range(self.num_layers)]
    mask_fws = [melt.dropout(tf.ones([batch_size, 1, num_units[layer]], dtype=tf.float32), keep_prob=self.keep_prob, training=training, mode=None) for layer in range(self.num_layers)]
    mask_bws = [melt.dropout(tf.ones([batch_size, 1, num_units[layer]], dtype=tf.float32), keep_prob=self.keep_prob, training=training, mode=None) for layer in range(self.num_layers)]
    
    c = self.encode(c_emb, c_len, mask_fws=mask_fws, mask_bws=mask_bws)
    q = self.encode(q_emb, q_len, mask_fws=mask_fws, mask_bws=mask_bws)

    qc_att = self.att_dot_attention(c, q, mask=q_mask, training=training)

    num_units = [melt.get_shape(qc_att, -1) if layer == 0 else 2 * self.num_units for layer in range(self.num_layers)]
    mask_fws = [melt.dropout(tf.ones([batch_size, 1, num_units[layer]], dtype=tf.float32), keep_prob=self.keep_prob, training=training, mode=None) for layer in range(1)]
    mask_bws = [melt.dropout(tf.ones([batch_size, 1, num_units[layer]], dtype=tf.float32), keep_prob=self.keep_prob, training=training, mode=None) for layer in range(1)]
    x = self.att_encode(qc_att, c_len, mask_fws=mask_fws, mask_bws=mask_bws)

    x = self.pooling(x, c_len)

    if FLAGS.use_type:
      x = tf.concat([x, tf.expand_dims(tf.to_float(input['type']), 1)], 1)

    if not FLAGS.split_type:
      x = self.logits(x)
    else:
      x1 = self.logits(x)
      x2 = self.logits2(x)
      x = tf.cond(tf.cast(input['type'] == 0, tf.bool), lambda: (x1 + x2) / 2., lambda: x2)
    
    return x
Example #14
    def call(self, input, c_len=None, max_c_len=None, training=False):
        self.step += 1
        x = input['content'] if isinstance(input, dict) else input
        batch_size = melt.get_shape(x, 0)
        model = modeling.BertModel(
            config=self.bert_config,
            is_training=training,
            input_ids=x,
            input_mask=(x > 0) if c_len is not None else None)

        if self.step == 0 and self.init_checkpoint:
            self.restore()
        x = model.get_sequence_output()
        return x
Example #15
 def call(self, x, training=False):
     self.step += 1
     if self.step == 0:
         n_state = melt.get_shape(x, -1)
         self.g = self.add_variable("g",
                                    shape=[n_state],
                                    initializer=tf.constant_initializer(1))
         self.b = self.add_variable("b",
                                    shape=[n_state],
                                    initializer=tf.constant_initializer(1))
     e, axis = self.e, self.axis
     u = tf.reduce_mean(x, axis=axis, keepdims=True)
     s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True)
     x = (x - u) * tf.rsqrt(s + e)
     x = x * self.g + self.b
     return x
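A minimal NumPy sketch of the normalization in Example #15 above, assuming self.axis is the last dimension and self.e is a small epsilon such as 1e-5 (note the original initializes both g and b to 1):

import numpy as np

def layer_norm_sketch(x, g, b, axis=-1, eps=1e-5):
    u = x.mean(axis=axis, keepdims=True)
    s = np.square(x - u).mean(axis=axis, keepdims=True)
    return (x - u) / np.sqrt(s + eps) * g + b

x = np.random.rand(2, 3, 8)
print(layer_norm_sketch(x, g=np.ones(8), b=np.zeros(8)).shape)  # (2, 3, 8)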
Example #16
 def call(self, outputs, sequence_length=None, axis=1):
     self.step += 1
     if self.step == 0 and self.dense is None:
         self.dense = layers.Dense(melt.get_shape(outputs, -1),
                                   activation=self.activation)
     x = self.dense(outputs)
     logits = self.logits(x)
     alphas = tf.nn.softmax(
         logits) if sequence_length is None else melt.masked_softmax(
             logits, sequence_length)
     encoding = tf.reduce_sum(outputs * alphas, 1)
     # [batch_size, sequence_length, 1] -> [batch_size, sequence_length]
     self.alphas = tf.squeeze(alphas, -1)
     #self.alphas = alphas
     tf.add_to_collection('self_attention', self.alphas)
     return encoding
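A minimal NumPy sketch of the attention pooling in Example #16 above: per-step logits are turned into softmax weights over the time axis and used for a weighted sum. The random weight matrix stands in for the dense and logits layers, and sequence masking is omitted.

import numpy as np

outputs = np.random.rand(2, 5, 8)                  # [batch, time, dim]
w = np.random.randn(8, 1)                          # stand-in for dense + logits layers
logits = outputs @ w                               # [batch, time, 1]
alphas = np.exp(logits - logits.max(axis=1, keepdims=True))
alphas = alphas / alphas.sum(axis=1, keepdims=True)
encoding = (outputs * alphas).sum(axis=1)          # [batch, dim]
print(encoding.shape, alphas.squeeze(-1).shape)    # (2, 8) (2, 5)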
Example #17
  def encode(self, seq, seq_len=None, output_method='all'):
    with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
      if self.use_position_encoding:
        hidden_size = melt.get_shape(seq, -1)
        # Scale embedding by the sqrt of the hidden size
        seq *= hidden_size ** 0.5

        # Create binary array of size [batch_size, length]
        # where 1 = not padding, 0 = padding
        padding = tf.to_float(tf.sequence_mask(seq_len))

        # Set all padding embedding values to 0
        seq *= tf.expand_dims(padding, -1)

        pos_encoding = model_utils.get_position_encoding(
            tf.shape(seq)[1], tf.shape(seq)[-1])
        seq = seq + pos_encoding

      num_filters = self.num_filters
      seqs = [seq]
      #batch_size = melt.get_batch_size(seq)
     
      #kernel_sizes = [3, 5, 7, 9, 11, 13]
      kernel_sizes = [3] * 7
      assert self.num_layers <= len(kernel_sizes)

      for layer in range(self.num_layers):
        #input_size_ = melt.get_shape(seq, -1) if layer == 0 else num_filters
        seq = melt.dropout(seq, self.keep_prob, self.is_train)
        seq = tf.layers.conv1d(seqs[-1], num_filters, kernel_size=kernel_sizes[layer], padding='same', activation=tf.nn.relu)
        # mask = melt.dropout(tf.ones([batch_size, 1, input_size_], dtype=tf.float32),
        #                   keep_prob=self.keep_prob, is_train=self.is_train, mode=None)
        #seq = tf.layers.conv1d(seqs[-1] * mask, num_filters, kernel_size=3, padding='same', activation=tf.nn.relu)
        #seq = tf.layers.conv1d(seqs[-1] * mask, num_filters, kernel_size=kernel_sizes[layer], padding='same', activation=tf.nn.relu)
        
        # if self.is_train and self.keep_prob < 1:
        #   seq = tf.nn.dropout(seq, self.keep_prob)
        #seq = melt.layers.batch_norm(seq, self.is_train, name='layer_%d' % layer)
        seqs.append(seq)
      
      outputs = tf.concat(seqs[1:], 2)
      # do not do any dropout in the convnet, just dropout outside
      # if self.is_train and self.keep_prob < 1:
      #   outputs = tf.nn.dropout(outputs, self.keep_prob)

      # compatible with rnn that returns state
      return melt.rnn.encode_outputs(outputs, seq_len, output_method)
Example #18
    def call(self, input, c_len=None, max_c_len=None, training=False):
        assert isinstance(input, dict)
        x = input['content']

        batch_size = melt.get_shape(x, 0)
        if c_len is None or max_c_len is None:
            c_len, max_c_len = melt.length2(x)

        if self.rnn_no_padding:
            logging.info('------------------no padding! train or eval')
            c_len = max_c_len

        x = self.embedding(x)

        if FLAGS.use_char:
            cx = input['char']

            cx = tf.reshape(cx, [batch_size * max_c_len, FLAGS.char_limit])
            chars_len = melt.length(cx)
            cx = self.char_embedding(cx)
            cx = self.char_encode(cx, chars_len, training=training)
            cx = self.char_pooling(cx, chars_len)
            cx = tf.reshape(cx, [batch_size, max_c_len, 2 * self.num_units])

            if self.char_combiner == 'concat':
                x = tf.concat([x, cx], axis=2)
            elif self.char_combiner == 'sfu':
                x = self.char_sfu_combine(x, cx, training=training)

        if FLAGS.use_pos:
            px = input['pos']
            px = self.pos_embedding(px)
            x = tf.concat([x, px], axis=2)

        if FLAGS.use_ner:
            nx = input['ner']
            nx = self.ner_embedding(nx)
            x = tf.concat([x, nx], axis=2)

        x = self.encode(x, c_len, training=training)

        return x
Example #19
 def call(self, x, fusions, training=False):
     self.step += 1
     assert len(fusions) > 0
     vectors = tf.concat(
         [x] + fusions, axis=-1
     )  # size = [batch_size, ..., input_dim * (len(fusion_vectors) + 1)]
     dim = melt.get_shape(x, -1)
     dv = dropout(vectors, keep_prob=self.keep_prob, training=training)
     if self.step == 0:
         self.composition_dense = layers.Dense(dim,
                                               use_bias=True,
                                               activation=tf.nn.tanh,
                                               name='compostion_dense')
         self.gate_dense = layers.Dense(dim,
                                        use_bias=True,
                                        activation=tf.nn.sigmoid,
                                        name='gate_dense')
     r = self.composition_dense(dv)
     g = self.gate_dense(dv)
     return g * r + (1 - g) * x
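A minimal NumPy sketch of the semantic fusion step in Example #19 above: a tanh composition r and a sigmoid gate g are both computed from the concatenation of x and its fusion vectors, and the output interpolates between r and x. Biases and dropout are omitted, and the weight matrices are random stand-ins for the two Dense layers.

import numpy as np

def sfu_sketch(x, fusions, wc, wg):
    v = np.concatenate([x] + fusions, axis=-1)
    r = np.tanh(v @ wc)                            # composition_dense (bias omitted)
    g = 1.0 / (1.0 + np.exp(-(v @ wg)))            # gate_dense (bias omitted)
    return g * r + (1 - g) * x

dim = 4
x = np.random.rand(2, 3, dim)
fusions = [x * x, x - x]                           # e.g. [y, x * y, x - y] in Example #3
wc = np.random.randn(dim * 3, dim)
wg = np.random.randn(dim * 3, dim)
print(sfu_sketch(x, fusions, wc, wg).shape)        # (2, 3, 4)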
Example #20
    def call(self, input, training=False):
        self.step += 1
        x = input['content']
        x = self.unk_aug(x, training=training)
        batch_size = melt.get_shape(x, 0)
        c_mask = tf.cast(x, tf.bool)
        # TODO move to __init__
        model = modeling.BertModel(config=self.bert_config,
                                   is_training=training,
                                   input_ids=x,
                                   input_mask=c_mask,
                                   use_one_hot_embeddings=FLAGS.use_tpu)

        if self.step == 0 and self.init_checkpoint:
            self.restore()
        c_len = melt.length(x)

        if FLAGS.encoder_output_method == 'last':
            x = model.get_pooled_output()
        else:
            x = model.get_sequence_output()

        if training:
            x = tf.nn.dropout(x, keep_prob=0.9)

        logging.info('---------------bert_lr_ratio', FLAGS.bert_lr_ratio)
        x = x * FLAGS.bert_lr_ratio + tf.stop_gradient(x) * (
            1 - FLAGS.bert_lr_ratio)

        if FLAGS.transformer_add_rnn:
            assert FLAGS.encoder_output_method != 'last'
            x = self.rnn_encode(x, c_len)

        if FLAGS.encoder_output_method != 'last':
            x = self.pooling(x, c_len)
            x2 = model.get_pooled_output()
            x = tf.concat([x, x2], -1)
        x = self.logits(x)
        x = tf.reshape(x, [batch_size, NUM_ATTRIBUTES, NUM_CLASSES])
        return x
Example #21
def focal_loss(target_tensor,
               prediction_tensor,
               weights=None,
               alpha=0.25,
               gamma=2):
    r"""Compute focal loss for predictions.
        Multi-label focal loss formula:
            FL = -alpha * (z-p)^gamma * log(p) - (1-alpha) * p^gamma * log(1-p)
            where alpha = 0.25, gamma = 2, p = sigmoid(x), z = target_tensor.
    Args:
     prediction_tensor: A float tensor of shape [batch_size, num_anchors,
        num_classes] representing the predicted logits for each class
     target_tensor: A float tensor of shape [batch_size, num_anchors,
        num_classes] representing one-hot encoded classification targets
     weights: A float tensor of shape [batch_size, num_anchors]
     alpha: A scalar tensor for focal loss alpha hyper-parameter
     gamma: A scalar tensor for focal loss gamma hyper-parameter
    Returns:
        loss: A (scalar) tensor representing the value of the loss function
    """
    sigmoid_p = tf.nn.sigmoid(prediction_tensor)
    zeros = array_ops.zeros_like(sigmoid_p, dtype=sigmoid_p.dtype)

    #target_tensor = tf.to_float(target_tensor)
    target_tensor = tf.one_hot(target_tensor,
                               melt.get_shape(prediction_tensor, -1))

    # For positive predictions, only the front part of the loss matters; the back part is 0:
    # target_tensor > zeros <=> z = 1, so the positive coefficient is z - p.
    pos_p_sub = array_ops.where(target_tensor > zeros,
                                target_tensor - sigmoid_p, zeros)

    # For negative predictions, only the back part of the loss matters; the front part is 0:
    # target_tensor > zeros <=> z = 1, so the negative coefficient is 0.
    neg_p_sub = array_ops.where(target_tensor > zeros, zeros, sigmoid_p)
    per_entry_cross_ent = - alpha * (pos_p_sub ** gamma) * tf.log(tf.clip_by_value(sigmoid_p, 1e-8, 1.0)) \
                          - (1 - alpha) * (neg_p_sub ** gamma) * tf.log(tf.clip_by_value(1.0 - sigmoid_p, 1e-8, 1.0))
    #return tf.reduce_sum(per_entry_cross_ent)
    return tf.reduce_mean(per_entry_cross_ent)
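A minimal NumPy sketch of the focal loss above, assuming integer class targets that are one-hot encoded against the last dimension of the logits (alpha = 0.25, gamma = 2 as in the docstring):

import numpy as np

def focal_loss_sketch(targets, logits, alpha=0.25, gamma=2):
    p = 1.0 / (1.0 + np.exp(-logits))              # sigmoid_p
    z = np.eye(logits.shape[-1])[targets]          # one-hot targets, like tf.one_hot
    pos = np.where(z > 0, z - p, 0.0)              # positive coefficient: z - p
    neg = np.where(z > 0, 0.0, p)                  # negative coefficient: p
    eps = 1e-8
    per_entry = (-alpha * pos ** gamma * np.log(np.clip(p, eps, 1.0))
                 - (1 - alpha) * neg ** gamma * np.log(np.clip(1.0 - p, eps, 1.0)))
    return per_entry.mean()

logits = np.random.randn(4, 5)
targets = np.array([0, 2, 4, 1])
print(focal_loss_sketch(targets, logits))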
Example #22
  def call(self, seq, seq_len=None, masks=None, 
           output_method=OutputMethod.all, 
           training=False):
    if self.use_position_encoding:
      hidden_size = melt.get_shape(seq, -1)
      # Scale embedding by the sqrt of the hidden size
      seq *= hidden_size ** 0.5

      # Create binary array of size [batch_size, length]
      # where 1 = not padding, 0 = padding
      padding = tf.to_float(tf.sequence_mask(seq_len))

      # Set all padding embedding values to 0
      seq *= tf.expand_dims(padding, -1)

      pos_encoding = model_utils.get_position_encoding(
          tf.shape(seq)[1], tf.shape(seq)[-1])
      seq = seq + pos_encoding

    num_filters = self.num_filters
    seqs = [seq]
    #batch_size = melt.get_batch_size(seq)

    for layer in range(self.num_layers):
      if masks is None:
        seq_ = melt.dropout(seq, self.keep_prob, training)
      else:
        seq_ = seq * masks[layer]
      seq = self.conv1ds[layer](seq_)
      seqs.append(seq)
    
    outputs = tf.concat(seqs[1:], 2)
    # do not do any dropout in the convnet, just dropout outside
    # if self.is_train and self.keep_prob < 1:
    #   outputs = tf.nn.dropout(outputs, self.keep_prob)

    # compatible with rnn that returns state
    return melt.rnn.encode_outputs(outputs, seq_len, output_method)
Example #23
    def encode(self,
               inputs,
               seq_len,
               emb=None,
               concat_layers=True,
               output_method=OutputMethod.all):
        if emb is not None:
            inputs = tf.nn.embedding_lookup(emb, inputs)

        outputs = [inputs]
        keep_prob = self.keep_prob
        num_units = self.num_units
        is_train = self.is_train

        with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
            for layer in range(self.num_layers):
                input_size_ = melt.get_shape(
                    inputs, -1) if layer == 0 else 2 * num_units
                batch_size = melt.get_batch_size(inputs)
                with tf.variable_scope("fw_{}".format(layer)):
                    gru_fw = tf.contrib.rnn.GRUCell(num_units)
                    if not self.share_dropout:
                        mask_fw = dropout(tf.ones([batch_size, 1, input_size_],
                                                  dtype=tf.float32),
                                          keep_prob=keep_prob,
                                          is_train=is_train,
                                          mode=self.dropout_mode)
                    else:
                        if self.dropout_mask_fw[layer] is None:
                            mask_fw = dropout(
                                tf.ones([batch_size, 1, input_size_],
                                        dtype=tf.float32),
                                keep_prob=keep_prob,
                                is_train=is_train,
                                mode=self.dropout_mode)
                            self.dropout_mask_fw[layer] = mask_fw
                        else:
                            mask_fw = self.dropout_mask_fw[layer]
                    if self.train_init_state:
                        if self.init_fw[layer] is None:
                            self.init_fw[layer] = tf.tile(
                                tf.get_variable("init_state", [1, num_units],
                                                tf.float32,
                                                tf.zeros_initializer()),
                                [batch_size, 1])
                    out_fw, state = tf.nn.dynamic_rnn(
                        gru_fw,
                        outputs[-1] * mask_fw,
                        seq_len,
                        initial_state=self.init_fw[layer],
                        dtype=tf.float32)
                with tf.variable_scope("bw_{}".format(layer)):
                    gru_bw = tf.contrib.rnn.GRUCell(num_units)
                    if not self.share_dropout:
                        mask_bw = dropout(tf.ones([batch_size, 1, input_size_],
                                                  dtype=tf.float32),
                                          keep_prob=keep_prob,
                                          is_train=is_train,
                                          mode=self.dropout_mode)
                    else:
                        if self.dropout_mask_bw[layer] is None:
                            mask_bw = dropout(
                                tf.ones([batch_size, 1, input_size_],
                                        dtype=tf.float32),
                                keep_prob=keep_prob,
                                is_train=is_train,
                                mode=self.dropout_mode)
                            self.dropout_mask_bw[layer] = mask_bw
                        else:
                            mask_bw = self.dropout_mask_bw[layer]
                    if self.train_init_state:
                        if self.init_bw[layer] is None:
                            self.init_bw[layer] = tf.tile(
                                tf.get_variable("init_state", [1, num_units],
                                                tf.float32,
                                                tf.zeros_initializer()),
                                [batch_size, 1])
                    inputs_bw = tf.reverse_sequence(outputs[-1] * mask_bw,
                                                    seq_lengths=seq_len,
                                                    seq_dim=1,
                                                    batch_dim=0)
                    out_bw, _ = tf.nn.dynamic_rnn(
                        gru_bw,
                        inputs_bw,
                        seq_len,
                        initial_state=self.init_bw[layer],
                        dtype=tf.float32)
                    out_bw = tf.reverse_sequence(out_bw,
                                                 seq_lengths=seq_len,
                                                 seq_dim=1,
                                                 batch_dim=0)
                outputs.append(tf.concat([out_fw, out_bw], axis=2))

        if concat_layers:
            res = tf.concat(outputs[1:], axis=2)
        else:
            res = outputs[-1]
        res = encode_outputs(res, seq_len, output_method=output_method)
        return res
Example #24
    def call(self,
             inputs,
             sequence_length,
             inputs2,
             sequence_length2,
             mask_fws,
             mask_bws,
             concat_layers=True,
             output_method=OutputMethod.all,
             training=False):

        outputs = [inputs]
        outputs2 = [inputs2]

        keep_prob = self.keep_prob
        num_units = self.num_units
        batch_size = melt.get_batch_size(inputs)

        for layer in range(self.num_layers):
            input_size_ = melt.get_shape(inputs,
                                         -1) if layer == 0 else 2 * num_units

            gru_fw, gru_bw = self.gru_fws[layer], self.gru_bws[layer]

            if self.train_init_state:
                init_fw = self.init_fw_layer(layer, batch_size)
            else:
                init_fw = None

            mask_fw = mask_fws[layer]
            out_fw, state_fw = gru_fw(outputs[-1] * mask_fw, init_fw)
            out_fw2, state_fw2 = gru_fw(outputs2[-1] * mask_fw, state_fw)

            mask_bw = mask_bws[layer]
            inputs_bw = tf.reverse_sequence(
                outputs[-1] * mask_bw,
                seq_lengths=sequence_length,
                seq_axis=1,
                batch_axis=0)
            inputs_bw2 = tf.reverse_sequence(
                outputs2[-1] * mask_bw,
                seq_lengths=sequence_length2,
                seq_axis=1,
                batch_axis=0)

            if self.train_init_state:
                init_bw = self.init_bw_layer(layer, batch_size)
            else:
                init_bw = None

            out_bw, state_bw = gru_bw(inputs_bw, init_bw)
            out_bw2, state_bw2 = gru_bw(inputs_bw2, state_bw)

            outputs.append(tf.concat([out_fw, out_bw], axis=2))
            outputs2.append(tf.concat([out_fw2, out_bw2], axis=2))

        if concat_layers:
            res = tf.concat(outputs[1:], axis=2)
            res2 = tf.concat(outputs2[1:], axis=2)
        else:
            res = outputs[-1]
            res2 = outputs2[-1]

        res = tf.concat([res, res2], axis=1)

        res = encode_outputs(res,
                             output_method=output_method,
                             sequence_length=sequence_length)

        self.state = (state_fw2, state_bw2)
        return res
Example #25
    def call(self, input):
        # TODO: tf2 keras seems to auto-append a last dim, so this is needed
        melt.try_squeeze_dim(input)

        if not FLAGS.batch_parse:
            util.adjust(input, self.mode)

        # print(input)

        embs = []

        if 'history' in input:
            hlen = melt.length(input['history'])
            hlen = tf.math.maximum(hlen, 1)

        bs = melt.get_shape(input['did'], 0)

        # user
        if FLAGS.use_uid:
            uemb = self.uemb(input['uid'])
            embs += [uemb]

        if FLAGS.use_did:
            demb = self.demb(input['did'])
            embs += [demb]

        if FLAGS.use_time_emb:
            embs += [
                self.hour_emb(input['hour']),
                self.weekday_emb(input['weekday']),
            ]

        if FLAGS.use_fresh_emb:
            fresh = input['fresh']
            fresh_day = tf.cast(fresh / (3600 * 12), fresh.dtype)
            fresh_hour = tf.cast(fresh / 3600, fresh.dtype)
            embs += [
                self.fresh_day_emb(fresh_day),
                self.fresh_hour_emb(fresh_hour)
            ]

        if FLAGS.use_position_emb:
            embs += [self.position_emb(input['position'])]

        if FLAGS.use_news_info and 'cat' in input:
            # print('------entity_emb', self.entity_emb.emb.weights) # check if trainable is fixed in eager mode
            embs += [
                self.cat_emb(input['cat']),
                self.scat_emb(input['sub_cat']),
                self.pooling(self.entity_type_emb(input['title_entity_types']),
                             melt.length(input['title_entity_types'])),
                self.pooling(
                    self.entity_type_emb(input['abstract_entity_types']),
                    melt.length(input['abstract_entity_types'])),
            ]
            if FLAGS.use_entities and 'title_entities' in input:
                embs += [
                    self.pooling(self.entity_emb(input['title_entities']),
                                 melt.length(input['title_entities'])),
                    self.pooling(self.entity_emb(input['abstract_entities']),
                                 melt.length(input['abstract_entities'])),
                ]

        if FLAGS.use_history_info and 'history_cats' in input:
            embs += [
                self.his_simple_pooling(self.cat_emb(input['history_cats']),
                                        melt.length(input['history_cats'])),
                self.his_simple_pooling(
                    self.scat_emb(input['history_sub_cats']),
                    melt.length(input['history_sub_cats'])),
            ]
            if FLAGS.use_history_entities:
                try:
                    embs += [
                        self.his_simple_pooling(
                            self.entity_type_emb(
                                input['history_title_entity_types']),
                            melt.length(input['history_title_entity_types'])),
                        self.his_simple_pooling(
                            self.entity_type_emb(
                                input['history_abstract_entity_types']),
                            melt.length(
                                input['history_abstract_entity_types'])),
                    ]
                    if FLAGS.use_entities and 'title_entities' in input:
                        embs += [
                            self.his_simple_pooling(
                                self.entity_emb(
                                    input['history_title_entities']),
                                melt.length(input['history_title_entities'])),
                            self.his_simple_pooling(
                                self.entity_emb(
                                    input['history_abstract_entities']),
                                melt.length(
                                    input['history_abstract_entities'])),
                        ]
                except Exception:
                    pass

        if FLAGS.use_history and FLAGS.use_did:
            dids = input['history']

            if FLAGS.his_strategy == 'bst' or FLAGS.his_pooling == 'mhead':
                mask = tf.cast(tf.equal(dids, 0), dids.dtype)
                dids += mask
                hlen = tf.ones_like(hlen) * 50
            hembs = self.demb(dids)

            his_embs = hembs
            his_embs = self.his_encoder(his_embs, hlen)
            self.his_embs = his_embs

            his_emb = self.his_pooling(demb, his_embs, hlen)

            embs += [his_emb]

        if FLAGS.use_title:
            cur_title = self.title_encoder(self.title_lookup(input['ori_did']))
            dids = input['ori_history']
            if FLAGS.max_titles:
                dids = dids[:, :FLAGS.max_titles]
            his_title = self.titles_encoder(self.title_lookup(dids), hlen,
                                            cur_title)
            embs += [cur_title, his_title]

        # using the impression id makes dev and test inconsistent, so do not use the id directly
        if FLAGS.use_impressions:
            embs += [self.mean_pooling(self.demb(input['impressions']))]

        if FLAGS.use_dense:
            dense_emb = self.deal_dense(input)
            embs += [dense_emb]

        # logging.debug('-----------embs:', len(embs))
        embs = tf.stack(embs, axis=1)

        if FLAGS.batch_norm:
            embs = self.batch_norm(embs)

        if FLAGS.l2_normalize_before_pooling:
            embs = tf.math.l2_normalize(embs)

        x = self.feat_pooling(embs)

        if FLAGS.dropout:
            x = self.dropout(x)

        if FLAGS.use_dense:
            x = tf.concat([x, dense_emb], axis=1)

        if FLAGS.use_his_concat:
            x = tf.concat([x, his_concat], axis=1)

        x = self.mlp(x)

        self.logit = self.dense(x)
        self.prob = tf.math.sigmoid(self.logit)
        self.impression_id = input['impression_id']
        self.position = input['position']
        self.history_len = input['hist_len']
        self.impression_len = input['impression_len']
        self.input_ = input
        return self.logit
Example #26
    def call(self,
             x,
             sequence_length=None,
             mask_fws=None,
             mask_bws=None,
             concat_layers=None,
             output_method=None,
             training=False):

        concat_layers = concat_layers or self.concat_layers
        output_method = output_method or self.output_method

        if self.residual_connect:
            x = self.residual_linear(x)

        outputs = [x]

        #states = []
        keep_prob = self.keep_prob
        num_units = self.num_units
        batch_size = melt.get_batch_size(x)

        if sequence_length is None:
            len_ = melt.get_shape(x, 1)
            sequence_length = tf.ones([
                batch_size,
            ], dtype=tf.int64) * len_

        for layer in range(self.num_layers):
            input_size_ = melt.get_shape(x,
                                         -1) if layer == 0 else 2 * num_units

            gru_fw, gru_bw = self.gru_fws[layer], self.gru_bws[layer]

            if self.train_init_state:
                #init_fw = tf.tile(self.init_fw[layer], [batch_size, 1])
                #init_fw = tf.tile(self.init_fw_layer(layer), [batch_size, 1])
                init_fw = self.init_fw_layer(layer, batch_size)
                if self.cell == 'lstm':
                    init_fw = (init_fw, self.init_fw2_layer(layer, batch_size))
            else:
                init_fw = None

            if self.recurrent_dropout:
                if mask_fws is not None:
                    mask_fw = mask_fws[layer]
                else:
                    if not self.share_dropout:
                        mask_fw = dropout(tf.ones([batch_size, 1, input_size_],
                                                  dtype=tf.float32),
                                          keep_prob=keep_prob,
                                          training=training,
                                          mode=None)
                    else:
                        if self.dropout_mask_fw[layer] is None or (
                                tf.executing_eagerly() and batch_size !=
                                self.dropout_mask_fw[layer].shape[0]):
                            mask_fw = dropout(
                                tf.ones([batch_size, 1, input_size_],
                                        dtype=tf.float32),
                                keep_prob=keep_prob,
                                training=training,
                                mode=None)
                            self.dropout_mask_fw[layer] = mask_fw
                        else:
                            mask_fw = self.dropout_mask_fw[layer]

                inputs_fw = outputs[-1] * mask_fw
            else:
                inputs_fw = dropout(outputs[-1],
                                    keep_prob=keep_prob,
                                    training=training,
                                    mode=None)

            # https://stackoverflow.com/questions/48233400/lstm-initial-state-from-dense-layer
            # gru and lstm differ: lstm needs a tuple of states as its input state
            if self.cell == 'gru':
                out_fw, state_fw = gru_fw(inputs_fw, init_fw)
            else:
                out_fw, state_fw1, state_fw2 = gru_fw(inputs_fw, init_fw)
                state_fw = (state_fw1, state_fw2)

            if self.train_init_state:
                #init_bw = tf.tile(self.init_bw[layer], [batch_size, 1])
                #init_bw = tf.tile(self.init_bw_layer(layer), [batch_size, 1])
                init_bw = self.init_bw_layer(layer, batch_size)
                if self.cell == 'lstm':
                    init_bw = (init_bw, self.init_bw2_layer(layer, batch_size))
            else:
                init_bw = None

            if mask_bws is not None:
                mask_bw = mask_bws[layer]
            else:
                if not self.share_dropout:
                    mask_bw = dropout(tf.ones([batch_size, 1, input_size_],
                                              dtype=tf.float32),
                                      keep_prob=keep_prob,
                                      training=training,
                                      mode=None)
                else:
                    if self.dropout_mask_bw[layer] is None or (
                            tf.executing_eagerly() and batch_size !=
                            self.dropout_mask_bw[layer].shape[0]):
                        mask_bw = dropout(tf.ones([batch_size, 1, input_size_],
                                                  dtype=tf.float32),
                                          keep_prob=keep_prob,
                                          training=training,
                                          mode=None)
                        self.dropout_mask_bw[layer] = mask_bw
                    else:
                        mask_bw = self.dropout_mask_bw[layer]

            if self.recurrent_dropout:
                inputs_bw = outputs[-1] * mask_bw
            else:
                if self.bw_dropout:
                    inputs_bw = dropout(outputs[-1],
                                        keep_prob=keep_prob,
                                        training=training,
                                        mode=None)
                else:
                    inputs_bw = inputs_fw

            inputs_bw = tf.reverse_sequence(inputs_bw,
                                            seq_lengths=sequence_length,
                                            seq_axis=1,
                                            batch_axis=0)

            if self.cell == 'gru':
                out_bw, state_bw = gru_bw(inputs_bw, init_bw)
            else:
                out_bw, state_bw1, state_bw2 = gru_bw(inputs_bw, init_bw)
                state_bw = (state_bw1, state_bw2)

            out_bw = tf.reverse_sequence(out_bw,
                                         seq_lengths=sequence_length,
                                         seq_axis=1,
                                         batch_axis=0)

            outputs.append(tf.concat([out_fw, out_bw], axis=2))
            if self.residual_connect:
                outputs[-1] = self.batch_norm(outputs[-2] + outputs[-1])

        if concat_layers:
            res = tf.concat(outputs[1:], axis=2)
        else:
            res = outputs[-1]

        res = encode_outputs(res,
                             output_method=output_method,
                             sequence_length=sequence_length)

        self.state = (state_fw, state_bw)
        if not self.return_state:
            return res
        else:
            return res, self.state
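A minimal NumPy sketch of the shared recurrent-dropout mask used throughout Example #26 above: one [batch, 1, dim] mask is sampled per layer and reused at every timestep, assuming melt.dropout with mode=None behaves like standard inverted dropout.

import numpy as np

def recurrent_dropout_mask(batch_size, dim, keep_prob, rng=np.random):
    # one Bernoulli mask per example, broadcast over every timestep of the layer
    mask = (rng.uniform(size=(batch_size, 1, dim)) < keep_prob).astype(np.float32)
    return mask / keep_prob                        # inverted-dropout scaling

x = np.random.rand(2, 5, 4).astype(np.float32)     # [batch, time, dim]
mask = recurrent_dropout_mask(2, 4, keep_prob=0.8)
print((x * mask).shape)                            # (2, 5, 4)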
Example #27
 def call(self, outputs, sequence_length=None, axis=1):
     x = melt.top_k_pooling(outputs, self.top_k, sequence_length,
                            axis).values
     return tf.reshape(x, [-1, melt.get_shape(outputs, -1) * self.top_k])
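A minimal NumPy sketch of the top-k pooling in Example #27 above: keep the k largest values per feature over the time axis and flatten to [batch, dim * k]. The exact value layout of melt.top_k_pooling is an assumption here, and sequence masking is omitted.

import numpy as np

def top_k_pooling_sketch(outputs, top_k):
    # outputs: [batch, time, dim]; keep the k largest values per feature over time
    vals = np.sort(outputs, axis=1)[:, ::-1, :][:, :top_k, :]      # [batch, k, dim]
    return vals.transpose(0, 2, 1).reshape(outputs.shape[0], -1)   # [batch, dim * k]

x = np.random.rand(2, 6, 4)
print(top_k_pooling_sketch(x, 2).shape)            # (2, 8)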
Example #28
  def take_step(self, i, prev, state):
    print('-------------i', i)
    if self.output_fn is not None:
      #[batch_size * beam_size, num_units] -> [batch_size * beam_size, num_classes]
      try:
        output = self.output_fn(prev)
      except Exception:
        output = self.output_fn(prev, state)
    else:
      output = prev

    self.output = output

    # [batch_size * beam_size, num_classes], here we use log softmax
    if self.need_softmax:
      logprobs = tf.nn.log_softmax(output)
    else:
      logprobs = tf.log(tf.maximum(output, 1e-12))
    
    if self.num_classes is None:
      self.num_classes = tf.shape(logprobs)[1]

    #->[batch_size, beam_size, num_classes]
    logprobs_batched = tf.reshape(logprobs,
                                  [-1, self.beam_size, self.num_classes])
    logprobs_batched.set_shape((None, self.beam_size, None))
    
    # Note: masking out entries to -inf plays poorly with top_k, so just subtract out a large number.
    nondone_mask = tf.reshape(
        tf.cast(
          tf.equal(tf.range(self.num_classes), self.done_token),
          tf.float32) * -1e18,
        [1, 1, self.num_classes])

    if self.past_logprobs is None:
      #[batch_size, beam_size, num_classes] -> [batch_size, num_classes]
      #-> past_logprobs[batch_size, beam_size], indices[batch_size, beam_size]
      self.past_logprobs, indices = tf.nn.top_k(
          (logprobs_batched + nondone_mask)[:, 0, :],
          self.beam_size)
      step_logprobs = self.past_logprobs
    else:
      #logprobs_batched [batch_size, beam_size, num_classes] -> [batch_size, beam_size, num_classes]  
      #past_logprobs    [batch_size, beam_size] -> [batch_size, beam_size, 1]
      step_logprobs_batched = logprobs_batched
      logprobs_batched = logprobs_batched + tf.expand_dims(self.past_logprobs, 2)


      #get [batch_size, beam_size] each
      self.past_logprobs, indices = tf.nn.top_k(
          #[batch_size, beam_size * num_classes]
          tf.reshape(logprobs_batched + nondone_mask, 
                     [-1, self.beam_size * self.num_classes]),
          self.beam_size)  

      #get current step logprobs [batch_size, beam_size]
      step_logprobs = tf.gather_nd(tf.reshape(step_logprobs_batched, 
                                              [-1, self.beam_size * self.num_classes]), 
                                   melt.batch_values_to_indices(indices))

    # For continuing to the next symbols [batch_size, beam_size]
    symbols = indices % self.num_classes
    # from which beam it comes  [batch_size, beam_size]
    parent_refs = indices // self.num_classes
    
    if self.past_symbols is None:
      # here when i == 1; at i == 0 take_step is not called, one rnn() step produces the output used here for i == 1
      # no need to gather state here since the initial state of every beam is the same
      #[batch_size, beam_size] -> [batch_size, beam_size, 1]
      self.past_symbols = tf.expand_dims(symbols, 2)
      self.past_step_logprobs = tf.expand_dims(step_logprobs, 2)
    else:
      # NOTE: outputing a zero-length sequence is not supported for simplicity reasons
      #hasky/jupter/tensorflow/beam-search2.ipynb below for mergeing path
      #here when i >= 2
      # tf.reshape(
      #           (tf.range(3 * 5) // 5) * 5,
      #           [3, 5]
      #       ).eval()
      # array([[ 0,  0,  0,  0,  0],
      #        [ 5,  5,  5,  5,  5],
      #        [10, 10, 10, 10, 10]], dtype=int32)
      parent_refs_offsets = tf.reshape(
          (tf.range(self.batch_size * self.beam_size) 
           // self.beam_size) * self.beam_size,
          [self.batch_size, self.beam_size])
      
      #self.past_symbols [batch_size, beam_size, i - 1] -> past_symbols_batch_major [batch_size * beam_size, i - 1]
      past_symbols_batch_major = tf.reshape(self.past_symbols, [-1, i-1])

      past_step_logprobs_batch_major = tf.reshape(self.past_step_logprobs, [-1, i - 1])
     
      #[batch_size, beam_size]
      past_indices = parent_refs + parent_refs_offsets 
      #-> [batch_size, beam_size, i - 1]  
      beam_past_symbols = tf.gather(past_symbols_batch_major,            #[batch_size * beam_size, i - 1]
                                    past_indices                         #[batch_size, beam_size]
                                    )

      beam_past_step_logprobs = tf.gather(past_step_logprobs_batch_major, past_indices)

      #we must also choose corresponding past state as new start
      past_indices = tf.reshape(past_indices, [-1])

      # TODO: tf.TensorArray is not supported right now, so alignment_history cannot be used in attention_wrapper
      def try_gather(x, indices):
        #if isinstance(x, tf.Tensor) and x.shape.ndims >= 2:
        assert isinstance(x, tf.Tensor)
        if x.shape.ndims >= 2:
          return tf.gather(x, indices)
        else:
          return x

      state = nest.map_structure(lambda x: try_gather(x, past_indices), state)

      if hasattr(state, 'alignments'):
        attention_size = melt.get_shape(state.alignments, -1)
        alignments = tf.reshape(state.alignments, [-1, self.beam_size, attention_size])
        print('alignments', alignments)

      if not self.fast_greedy:
        #[batch_size, beam_size, max_len]
        path = tf.concat([self.past_symbols, 
                          tf.ones_like(tf.expand_dims(symbols, 2)) * self.done_token,
                          tf.tile(tf.ones_like(tf.expand_dims(symbols, 2)) * self.pad_token, 
                          [1, 1, self.max_len - i])], 2)

        step_logprobs_path = tf.concat([self.past_step_logprobs, 
                                        tf.expand_dims(step_logprobs_batched[:, :, self.done_token], 2),
                                        tf.tile(tf.ones_like(tf.expand_dims(step_logprobs, 2)) * -float('inf'), 
                                                [1, 1, self.max_len - i])], 2)

        #[batch_size, 1, beam_size, max_len]
        path = tf.expand_dims(path, 1)
        step_logprobs_path = tf.expand_dims(step_logprobs_path, 1)
        self.paths_list.append(path)
        self.step_logprobs_list.append(step_logprobs_path)

      # [batch_size * beam_size, i - 1] -> [batch_size, beam_size, i]: the best beam_size paths up to step i
      self.past_symbols = tf.concat([beam_past_symbols, tf.expand_dims(symbols, 2)], 2)
      self.past_step_logprobs = tf.concat([beam_past_step_logprobs, tf.expand_dims(step_logprobs, 2)], 2)

      # For finishing the beam 
      #[batch_size, beam_size]
      logprobs_done = logprobs_batched[:, :, self.done_token]
      if not self.fast_greedy:
        self.logprobs_list.append(logprobs_done / i ** self.length_normalization_factor)
      else:
        done_parent_refs = tf.cast(tf.argmax(logprobs_done, 1), tf.int32)
        done_parent_refs_offsets = tf.range(self.batch_size) * self.beam_size

        done_past_symbols = tf.gather(past_symbols_batch_major,
                                      done_parent_refs + done_parent_refs_offsets)

        #[batch_size, max_len]
        symbols_done = tf.concat([done_past_symbols,
                                     tf.ones_like(done_past_symbols[:,0:1]) * self.done_token,
                                     tf.tile(tf.zeros_like(done_past_symbols[:,0:1]),
                                             [1, self.max_len - i])
                                    ], 1)

        #[batch_size, beam_size] -> [batch_size,]
        logprobs_done_max = tf.reduce_max(logprobs_done, 1)
      
        if self.length_normalization_factor > 0:
          logprobs_done_max /= i ** self.length_normalization_factor

        #[batch_size, max_len]
        self.finished_beams = tf.where(logprobs_done_max > self.logprobs_finished_beams,
                                       symbols_done,
                                       self.finished_beams)

        self.logprobs_finished_beams = tf.maximum(logprobs_done_max, self.logprobs_finished_beams)

    #->[batch_size * beam_size,]
    symbols_flat = tf.reshape(symbols, [-1])

    self.final_state = state 
    return symbols_flat, state 
Example #29
 def _get_aligments(self, state):
   attention_size = melt.get_shape(state.alignments, -1)
   alignments = tf.reshape(state.alignments, [-1, self.beam_size, attention_size])
   return alignments
Example #30
def argtopk_pooling(outputs, top_k, sequence_length=None, axis=1):
    x = top_k_pooling(outputs, top_k, sequence_length, axis).indices
    #return tf.reshape(x, [melt.get_shape(outputs, 0), -1])
    return tf.reshape(x, [-1, melt.get_shape(outputs, -1) * top_k])