Example #1
def pad(text, start_id=None, end_id=None):
    print('Pad with start_id', start_id, ' end_id', end_id)
    need_start_mark = start_id is not None
    need_end_mark = end_id is not None
    if not need_start_mark and not need_end_mark:
        return text, melt.length(text)

    batch_size = tf.shape(text)[0]
    zero_pad = tf.zeros([batch_size, 1], dtype=text.dtype)

    sequence_length = melt.length(text)

    if not need_start_mark:
        text = tf.concat([text, zero_pad], 1)
    else:
        start_pad = zero_pad + start_id
        if need_end_mark:
            text = tf.concat([start_pad, text, zero_pad], 1)
        else:
            text = tf.concat([start_pad, text], 1)
        sequence_length += 1

    if need_end_mark:
        text = melt.dynamic_append_with_length(
            text, sequence_length, tf.constant(end_id, dtype=text.dtype))
        sequence_length += 1

    return text, sequence_length
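As background for these pad/length examples, here is a minimal sketch of the length helper they rely on, assuming melt.length simply counts the non-zero (non-pad) ids per row; this is an assumption about the helper's behavior, not the library's actual source.

import tensorflow as tf

def length_sketch(text):
    # Assumed behavior of melt.length: number of non-zero ids per row of a [batch, steps] id tensor.
    return tf.reduce_sum(tf.cast(tf.sign(text), tf.int32), axis=1)

# Hypothetical usage of pad() above (start-mark-only path; the end_id path also needs melt.dynamic_append_with_length):
# text = tf.constant([[3, 5, 0, 0], [7, 0, 0, 0]], dtype=tf.int64)
# padded, seq_len = pad(text, start_id=1)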
Example #2
    def call(self, input, training=False):
        x1 = input['query']
        x2 = input['passage']
        length1 = melt.length(x1)
        length2 = melt.length(x2)
        #with tf.device('/cpu:0'):
        x1 = self.embedding(x1)
        x2 = self.embedding(x2)

        x = x1
        batch_size = melt.get_shape(x1, 0)

        num_units = [
            melt.get_shape(x, -1) if layer == 0 else 2 * self.num_units
            for layer in range(self.num_layers)
        ]
        #print('----------------length', tf.reduce_max(length), inputs.comment.shape)
        mask_fws = [
            melt.dropout(tf.ones([batch_size, 1, num_units[layer]],
                                 dtype=tf.float32),
                         keep_prob=self.keep_prob,
                         training=training,
                         mode=None) for layer in range(self.num_layers)
        ]
        mask_bws = [
            melt.dropout(tf.ones([batch_size, 1, num_units[layer]],
                                 dtype=tf.float32),
                         keep_prob=self.keep_prob,
                         training=training,
                         mode=None) for layer in range(self.num_layers)
        ]

        x = self.encode(x1,
                        length1,
                        x2,
                        length2,
                        mask_fws=mask_fws,
                        mask_bws=mask_bws)
        x = self.pooling(x, length1, length2)
        #x = self.pooling(x)

        if FLAGS.use_type:
            x = tf.concat([x, tf.expand_dims(tf.to_float(input['type']), 1)],
                          1)

        if not FLAGS.split_type:
            x = self.logits(x)
        else:
            x1 = self.logits(x)
            x2 = self.logits2(x)
            x = tf.cond(tf.cast(input['type'] == 0, tf.bool), lambda:
                        (x1 + x2) / 2., lambda: x2)

        return x
Example #3
    def call(self, input, training=False):
        x = input['rcontent'] if FLAGS.rcontent else input['content']
        #print(x.shape)
        batch_size = melt.get_shape(x, 0)
        length = melt.length(x)
        #with tf.device('/cpu:0'):
        x = self.embedding(x)

        x = self.encode(x, length, training=training)

        # must mask pooling when eval ? but seems much worse result
        #if not FLAGS.mask_pooling and training:
        if not FLAGS.mask_pooling:
            length = None
        x = self.pooling(x, length)

        if FLAGS.use_type:
            x = tf.concat([x, tf.expand_dims(tf.to_float(input['type']), 1)],
                          1)

        if not FLAGS.split_type:
            x = self.logits(x)
        else:
            x1 = self.logits(x)
            x2 = self.logits2(x)
            x = tf.cond(tf.cast(input['type'] == 0, tf.bool), lambda:
                        (x1 + x2) / 2., lambda: x2)

        return x
Example #4
    def call(self, x, training=False):
        x = x['comment']
        batch_size = melt.get_shape(x, 0)
        length = melt.length(x)
        #with tf.device('/cpu:0'):
        x = self.embedding(x)

        num_units = [
            melt.get_shape(x, -1) if layer == 0 else 2 * self.num_units
            for layer in range(self.num_layers)
        ]
        #print('----------------length', tf.reduce_max(length), inputs.comment.shape)
        mask_fws = [
            melt.dropout(tf.ones([batch_size, 1, num_units[layer]],
                                 dtype=tf.float32),
                         keep_prob=self.keep_prob,
                         training=training,
                         mode=None) for layer in range(self.num_layers)
        ]
        mask_bws = [
            melt.dropout(tf.ones([batch_size, 1, num_units[layer]],
                                 dtype=tf.float32),
                         keep_prob=self.keep_prob,
                         training=training,
                         mode=None) for layer in range(self.num_layers)
        ]
        #x = self.encode(x, length, mask_fws=mask_fws, mask_bws=mask_bws)
        x = self.encode(x)

        x = self.pooling(x, length)
        #x = self.pooling(x)
        x = self.logits(x)
        return x
Example #5
    def compute_sim(self, u, v):
        u_len = melt.length(u)
        v_len = melt.length(v)
        u, v = self.embeddings_lookup(u, v)
        alpha, beta = self.attention_layer(u, v, u_len, v_len)
        u, v = self.comparison_layer(u, v, alpha, beta, u_len, v_len)
        u, v = self.aggregation_layer(u, v)

        if not FLAGS.nli_cosine:
            f = tf.concat([u, v], 1)
            #f = tf.concat((u, v, tf.abs(u-v), u*v), 1)
            score = self.fc_layer(f)
            #score = tf.sigmoid(score)
        else:
            score = melt.element_wise_cosine(u, v)

        return score
Example #6
  def call(self, x):
    xs = tf.split(x, len(self.embs), axis=-1)
    embs = []
    # NOTE: the loop rebinds x, so mt.length(x) below sees the last split slice;
    # this matches the full input only if all slices share the same padding pattern.
    for x, emb in zip(xs, self.embs):
      embs += [emb(x)]

    embs = tf.add_n(embs)

    return self.pooling(embs, mt.length(x))
Example #7
    def call(self, x):
        xs = tf.split(x, len(self.embs), axis=-1)
        embs = []
        for x, emb in zip(xs, self.embs):
            embs += [emb(x)]

        embs = tf.add_n(embs)
        len_ = mt.length(x)
        seqs = self.encoder(embs, len_)
        return self.pooling(seqs, len_)
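A hedged usage sketch for the pattern above, assuming the input packs several id fields of equal width along the last axis so that tf.split yields one [batch, steps] slice per embedding table; the shapes, vocabulary size, and Keras layers below are made up for illustration.

import tensorflow as tf

num_fields, steps, vocab, dim = 2, 4, 100, 8
emb_tables = [tf.keras.layers.Embedding(vocab, dim) for _ in range(num_fields)]

x = tf.zeros([3, num_fields * steps], dtype=tf.int32)          # packed ids: two fields of width `steps`
xs = tf.split(x, num_fields, axis=-1)                           # num_fields slices of shape [3, steps]
summed = tf.add_n([emb(s) for emb, s in zip(emb_tables, xs)])   # [3, steps, dim], as fed to the encoder/pooling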
Example #8
def pad(text, start_id=None, end_id=None, weights=None, end_weight=1.0):
  logging.info('Pad with start_id', start_id, ' end_id', end_id)
  need_start_mark = start_id is not None
  need_end_mark = end_id is not None
  if not need_start_mark and not need_end_mark:
    return text, melt.length(text), weights 
  
  batch_size = tf.shape(text)[0]
  zero_pad = tf.zeros([batch_size, 1], dtype=text.dtype)

  sequence_length = melt.length(text)

  if not need_start_mark:
    text = tf.concat([text, zero_pad], 1)
    if weights is not None:
      weights = tf.concat([weights, tf.ones_like(zero_pad, dtype=tf.float32) * end_weight], 1)
  else:
    start_pad = zero_pad + start_id
    if need_end_mark:
      text = tf.concat([start_pad, text, zero_pad], 1)
      if weights is not None:
        weights = tf.concat([tf.zeros_like(start_pad, dtype=tf.float32), weights, tf.ones_like(zero_pad, dtype=tf.float32) * end_weight], 1)
    else:
      text = tf.concat([start_pad, text], 1)
      if weights is not None:
        weights = tf.concat([tf.zeros_like(start_pad, dtype=tf.float32), weights], 1)
    sequence_length += 1

  if need_end_mark:
    text = melt.dynamic_append_with_length(
        text, 
        sequence_length, 
        tf.constant(end_id, dtype=text.dtype)) 
    if weights is not None:
      weights = melt.dynamic_append_with_length_float32(
        weights, 
        sequence_length, 
        tf.constant(end_weight, dtype=weights.dtype)) 
    sequence_length += 1

  return text, sequence_length, weights
Example #9
    def call(self, x):
        len_ = mt.length(x)  # helper that computes the valid length of x
        if self.emb is not None:
            embs = self.emb(x)
            embs = keras.layers.Dropout(FLAGS.dropout)(embs)
        else:
            embs = x

        xs = self.encoder(embs, len_)

        xs = keras.layers.Dropout(FLAGS.dropout)(xs)
        return self.pooling(xs, len_)
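A minimal sketch of the masked pooling that self.pooling(xs, len_) presumably performs here, assuming a length-masked mean over the time axis; the real pooling layer may instead be max pooling, attention, or something else entirely.

import tensorflow as tf

def masked_mean_pooling(x, length):
    # x: [batch, steps, dim] float tensor; length: [batch] number of valid steps (assumed semantics).
    mask = tf.sequence_mask(length, maxlen=tf.shape(x)[1], dtype=x.dtype)  # [batch, steps]
    mask = tf.expand_dims(mask, -1)                                        # [batch, steps, 1]
    summed = tf.reduce_sum(x * mask, axis=1)
    count = tf.maximum(tf.reduce_sum(mask, axis=1), 1.0)
    return summed / count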
Example #10
  def call(self, input, training=False):
    q = input['query']
    c = input['passage']
    q_len = melt.length(q)
    c_len = melt.length(c)
    q_mask = tf.cast(q, tf.bool)
    q_emb = self.embedding(q)
    c_emb = self.embedding(c)
    
    x = c_emb
    batch_size = melt.get_shape(x, 0)

    num_units = [melt.get_shape(x, -1) if layer == 0 else 2 * self.num_units for layer in range(self.num_layers)]
    mask_fws = [melt.dropout(tf.ones([batch_size, 1, num_units[layer]], dtype=tf.float32), keep_prob=self.keep_prob, training=training, mode=None) for layer in range(self.num_layers)]
    mask_bws = [melt.dropout(tf.ones([batch_size, 1, num_units[layer]], dtype=tf.float32), keep_prob=self.keep_prob, training=training, mode=None) for layer in range(self.num_layers)]
    
    c = self.encode(c_emb, c_len, mask_fws=mask_fws, mask_bws=mask_bws)
    q = self.encode(q_emb, q_len, mask_fws=mask_fws, mask_bws=mask_bws)

    qc_att = self.att_dot_attention(c, q, mask=q_mask, training=training)

    num_units = [melt.get_shape(qc_att, -1) if layer == 0 else 2 * self.num_units for layer in range(self.num_layers)]
    mask_fws = [melt.dropout(tf.ones([batch_size, 1, num_units[layer]], dtype=tf.float32), keep_prob=self.keep_prob, training=training, mode=None) for layer in range(1)]
    mask_bws = [melt.dropout(tf.ones([batch_size, 1, num_units[layer]], dtype=tf.float32), keep_prob=self.keep_prob, training=training, mode=None) for layer in range(1)]
    x = self.att_encode(qc_att, c_len, mask_fws=mask_fws, mask_bws=mask_bws)

    x = self.pooling(x, c_len)

    if FLAGS.use_type:
      x = tf.concat([x, tf.expand_dims(tf.to_float(input['type']), 1)], 1)

    if not FLAGS.split_type:
      x = self.logits(x)
    else:
      x1 = self.logits(x)
      x2 = self.logits2(x)
      x = tf.cond(tf.cast(input['type'] == 0, tf.bool), lambda: (x1 + x2) / 2., lambda: x2)
    
    return x
Example #11
    def gen_train_input(self, inputs, decode_fn):
        #--------------------- train
        logging.info('train_input: %s' % FLAGS.train_input)
        trainset = list_files(FLAGS.train_input)
        logging.info('trainset:{} {}'.format(len(trainset), trainset[:2]))

        assert len(trainset) >= FLAGS.min_records, '%d %d' % (
            len(trainset), FLAGS.min_records)
        if FLAGS.num_records > 0:
            assert len(trainset) == FLAGS.num_records, len(trainset)

        num_records = gezi.read_int_from(FLAGS.num_records_file)
        logging.info('num_records:{}'.format(num_records))
        logging.info('batch_size:{}'.format(FLAGS.batch_size))
        logging.info('FLAGS.num_gpus:{}'.format(FLAGS.num_gpus))
        num_gpus = max(FLAGS.num_gpus, 1)
        num_steps_per_epoch = num_records // (FLAGS.batch_size * num_gpus)
        logging.info('num_steps_per_epoch:{}'.format(num_steps_per_epoch))
        self.num_records = num_records
        self.num_steps_per_epoch = num_steps_per_epoch

        image_name, image_feature, text, text_str = inputs(
            trainset,
            decode_fn=decode_fn,
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            #seed=seed,
            num_threads=FLAGS.num_threads,
            batch_join=FLAGS.batch_join,
            shuffle_files=FLAGS.shuffle_files,
            fix_sequence=FLAGS.fix_sequence,
            num_prefetch_batches=FLAGS.num_prefetch_batches,
            min_after_dequeue=FLAGS.min_after_dequeue,
            name=self.input_train_name)

        if FLAGS.feed_dict:
            self.text_place = text_placeholder('text_place')
            self.text_str = text_str
            text = self.text_place

            self.image_feature_place = image_feature_placeholder(
                'image_feature_place')
            self.image_feature = image_feature
            image_feature = self.image_feature_place

        if FLAGS.monitor_level > 1:
            lengths = melt.length(text)
            melt.scalar_summary("text/batch_min", tf.reduce_min(lengths))
            melt.scalar_summary("text/batch_max", tf.reduce_max(lengths))
            melt.scalar_summary("text/batch_mean", tf.reduce_mean(lengths))
        return (image_name, image_feature, text, text_str), trainset
Example #12
def adjust(features, subset):
    if 'hist_len' not in features:
        try:
            features['hist_len'] = mt.length(features['history'])
        except Exception:
            features['hist_len'] = tf.ones_like(features['did'])

    if FLAGS.max_history:
        for key in features:
            if 'history' in key:
                max_history = FLAGS.max_history
                if 'enti' in key:
                    max_history *= 2
                if not FLAGS.fixed_pad:
                    features[key] = features[key][:, :max_history]
                else:
                    features[key] = mt.pad(features[key], max_history)

    # Note: news-side info is fetched by nid; did is only used as an id feature and may be masked
    features['ori_did'] = features['did']
    features['ori_history'] = features['history']
    if 'impressions' in features:
        features['ori_impressions'] = features['impressions']

    features['did'] = mask_dids(features['did'], features['did_in_train'],
                                subset, FLAGS.test_all_mask)

    features['uid'] = mask_uids(features['uid'], subset == 'train')

    if 'history' in features:
        features['history'] = unk_aug(features['history'], subset == 'train')

    mask_negative_weights(features, subset == 'train')

    vs = gezi.get('vocab_sizes')
    if FLAGS.min_count_unk and FLAGS.min_count:
        features['uid'] = get_id(features['uid'], vs['uid'][1])
        features['did'] = get_id(features['did'], vs['did'][1])
        if FLAGS.mask_history:
            features['history'] = get_id(features['history'], vs['did'][1])
        if 'impressions' in features:
            features['impressions'] = get_id(features['impressions'],
                                             vs['did'][1])

    if vs['uid'][1] < vs['uid'][0]:
        features['uid'] = get_id(features['uid'], vs['uid'][1])

    return features
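A hedged sketch of the fixed-length padding that mt.pad(features[key], max_history) is assumed to perform above: truncate to max_history columns and right-pad with zeros. The helper name pad_to_length below is hypothetical; the real mt.pad may behave differently.

import tensorflow as tf

def pad_to_length(x, max_len):
    # x: [batch, steps] id tensor; returns [batch, max_len] (assumed behavior).
    x = x[:, :max_len]
    pad_amount = max_len - tf.shape(x)[1]
    return tf.pad(x, [[0, 0], [0, pad_amount]])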
Example #13
    def call(self, input, c_len=None, max_c_len=None, training=False):
        assert isinstance(input, dict)
        x = input['content']

        batch_size = melt.get_shape(x, 0)
        if c_len is None or max_c_len is None:
            c_len, max_c_len = melt.length2(x)

        if self.rnn_no_padding:
            logging.info('------------------no padding! train or eval')
            c_len = max_c_len

        x = self.embedding(x)

        if FLAGS.use_char:
            cx = input['char']

            cx = tf.reshape(cx, [batch_size * max_c_len, FLAGS.char_limit])
            chars_len = melt.length(cx)
            cx = self.char_embedding(cx)
            cx = self.char_encode(cx, chars_len, training=training)
            cx = self.char_pooling(cx, chars_len)
            cx = tf.reshape(cx, [batch_size, max_c_len, 2 * self.num_units])

            if self.char_combiner == 'concat':
                x = tf.concat([x, cx], axis=2)
            elif self.char_combiner == 'sfu':
                x = self.char_sfu_combine(x, cx, training=training)

        if FLAGS.use_pos:
            px = input['pos']
            px = self.pos_embedding(px)
            x = tf.concat([x, px], axis=2)

        if FLAGS.use_ner:
            nx = input['ner']
            nx = self.ner_embedding(nx)
            x = tf.concat([x, nx], axis=2)

        x = self.encode(x, c_len, training=training)

        return x
Example #14
def adjust(features, subset):
  if 'hist_len' not in features:
    try:
      features['hist_len'] = melt.length(features['history'])
    except Exception:
      features['hist_len'] = tf.ones_like(features['did'])

  if FLAGS.max_history:
    for key in features:
      if key.startswith('history'):
        max_history = FLAGS.max_history
        if 'entity' in key:
          max_history *= 2
        features[key] = features[key][:,:max_history]

  # Note: news-side info is fetched by nid; did is only used as an id feature and may be masked
  features['ori_did'] = features['did'] 
  features['ori_history'] = features['history']
  if 'impressions' in features:
    features['ori_impressions'] = features['impressions']

  features['did'] = mask_dids(features['did'], features['did_in_train'],
                              subset, FLAGS.test_all_mask)
  features['uid'] = mask_uids(features['uid'], subset=='train')

  try:
    features['history'] = unk_aug(features['history'], subset=='train')
  except Exception:
    pass
  mask_negative_weights(features, subset=='train')

  if FLAGS.min_count_unk and FLAGS.min_count:
    vs = gezi.get('vocab_sizes')
    features['uid'] = get_id(features['uid'], vs['uid'][1])
    features['did'] = get_id(features['did'], vs['did'][1])
    if FLAGS.mask_history:
      features['history'] = get_id(features['history'], vs['did'][1])
    if 'impressions' in features:
      features['impressions'] = get_id(features['impressions'], vs['did'][1])

  return features
Example #15
    def call(self, input, training=False):
        ids = input['index']
        values = input['value']
        fields = input['field']

        # if FLAGS.hidden_size > 50:
        #   with tf.device('/cpu:0'):
        #     x = self.emb(ids)
        # else:
        x = self.emb(ids)
        if FLAGS.field_emb:
            x = K.concatenate([x, self.field_emb(fields)], axis=-1)

        if FLAGS.deep_addval:
            values = K.expand_dims(values, -1)
            x = self.mult([x, values])

        if FLAGS.field_concat:
            num_fields = FLAGS.field_dict_size
            #x = tf.math.unsorted_segment_sum(x, fields, num_fields)
            x = melt.unsorted_segment_sum_emb(x, fields, num_fields)
            # like [512, 100 * 50]
            x = K.reshape(x, [-1, num_fields * self.emb_dim])
        else:
            if FLAGS.pooling == 'allsum':
                x = K.sum(x, 1)
            else:
                assert FLAGS.index_addone, 'can not calc length for like 0,1,2,0,0,0'
                c_len = melt.length(ids)
                x = self.pooling(x, c_len)

        if self.emb_activation:
            x = self.emb_activation(x + self.bias)

        if self.mlp:
            x = self.mlp(x, training=training)

        x = self.dense(x)
        x = K.squeeze(x, -1)
        return x
Example #16
    def call(self, input, training=False):
        self.step += 1
        x = input['content']
        x = self.unk_aug(x, training=training)
        batch_size = melt.get_shape(x, 0)
        c_mask = tf.cast(x, tf.bool)
        # TODO move to __init__
        model = modeling.BertModel(config=self.bert_config,
                                   is_training=training,
                                   input_ids=x,
                                   input_mask=c_mask,
                                   use_one_hot_embeddings=FLAGS.use_tpu)

        if self.step == 0 and self.init_checkpoint:
            self.restore()
        c_len = melt.length(x)

        if FLAGS.encoder_output_method == 'last':
            x = model.get_pooled_output()
        else:
            x = model.get_sequence_output()

        if training:
            x = tf.nn.dropout(x, keep_prob=0.9)

        logging.info('---------------bert_lr_ratio', FLAGS.bert_lr_ratio)
        x = x * FLAGS.bert_lr_ratio + tf.stop_gradient(x) * (
            1 - FLAGS.bert_lr_ratio)

        if FLAGS.transformer_add_rnn:
            assert FLAGS.encoder_output_method != 'last'
            x = self.rnn_encode(x, c_len)

        if FLAGS.encoder_output_method != 'last':
            x = self.pooling(x, c_len)
            x2 = model.get_pooled_output()
            x = tf.concat([x, x2], -1)
        x = self.logits(x)
        x = tf.reshape(x, [batch_size, NUM_ATTRIBUTES, NUM_CLASSES])
        return x
Example #17
    def gen_text_feature(self, text):
        is_training = self.is_training
        batch_size = tf.shape(text)[0]

        zero_pad = tf.zeros([batch_size, 1], dtype=text.dtype)
        text = tf.concat(1, [zero_pad, text, zero_pad])
        sequence_length = melt.length(text) + 1
        text = melt.dynamic_append_with_length(
            text, sequence_length, tf.constant(self.end_id, dtype=text.dtype))
        sequence_length += 1

        state = self.cell.zero_state(batch_size, tf.float32)

        inputs = tf.nn.embedding_lookup(self.emb, text)
        if is_training and FLAGS.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, FLAGS.keep_prob)

        outputs, state = tf.nn.dynamic_rnn(self.cell,
                                           inputs,
                                           initial_state=state,
                                           sequence_length=sequence_length)

        text_feature = melt.dynamic_last_relevant(outputs, sequence_length)
        return text_feature
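A hedged sketch of what melt.dynamic_last_relevant(outputs, sequence_length) presumably does in this example: gather the RNN output at the last valid step of each sequence. The function below is an illustrative stand-in, not the library's actual implementation.

import tensorflow as tf

def last_relevant_sketch(outputs, sequence_length):
    # outputs: [batch, steps, dim]; sequence_length: [batch] (assumed semantics).
    batch = tf.shape(outputs)[0]
    idx = tf.stack([tf.range(batch), tf.cast(sequence_length, tf.int32) - 1], axis=1)
    return tf.gather_nd(outputs, idx)  # [batch, dim]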
Example #18
    def call(self, input, training=False):
        q = input['query']
        c = input['passage']

        # reverse worse
        if FLAGS.cq_reverse:
            q, c = c, q

        #print(input['type'])
        # print('q', q)
        # print('c', c)

        q_len = melt.length(q)
        c_len = melt.length(c)
        q_mask = tf.cast(q, tf.bool)
        c_mask = tf.cast(c, tf.bool)
        q_emb = self.embedding(q)
        c_emb = self.embedding(c)

        x = c_emb
        batch_size = melt.get_shape(x, 0)

        if FLAGS.share_dropout:
            num_units = [
                melt.get_shape(x, -1) if layer == 0 else 2 * self.num_units
                for layer in range(self.num_layers)
            ]
            mask_fws = [
                melt.dropout(tf.ones([batch_size, 1, num_units[layer]],
                                     dtype=tf.float32),
                             keep_prob=self.keep_prob,
                             training=training,
                             mode=None) for layer in range(self.num_layers)
            ]
            mask_bws = [
                melt.dropout(tf.ones([batch_size, 1, num_units[layer]],
                                     dtype=tf.float32),
                             keep_prob=self.keep_prob,
                             training=training,
                             mode=None) for layer in range(self.num_layers)
            ]

            # NOTICE query and passage share same drop out, so same word still has same embedding vector after dropout in query and passage
            c = self.encode(c_emb,
                            c_len,
                            mask_fws=mask_fws,
                            mask_bws=mask_bws,
                            training=training)
            q = self.encode(q_emb,
                            q_len,
                            mask_fws=mask_fws,
                            mask_bws=mask_bws,
                            training=training)
        else:
            c = self.encode(c_emb, c_len, training=training)
            q = self.encode(q_emb, q_len, training=training)

        # helps a lot using qc att, now bidaf att worse..
        for i in range(FLAGS.hop):
            if not FLAGS.use_bidaf_att:
                x = self.att_dot_attentions[i](c,
                                               q,
                                               mask=q_mask,
                                               training=training)
            else:
                x = self.att_dot_attentions[i](c,
                                               q,
                                               c_mask,
                                               q_mask,
                                               training=training)
            if FLAGS.use_att_encode:
                x = self.att_encodes[i](x, c_len, training=training)
            x = self.match_dot_attentions[i](x,
                                             x,
                                             mask=c_mask,
                                             training=training)
            #x = self.match_dot_attentions[i](x, mask=c_mask, training=training)
            x = self.match_encodes[i](x, c_len, training=training)

        x = self.pooling(x, c_len, calc_word_scores=self.debug)

        if FLAGS.use_type:
            x = tf.concat([x, tf.expand_dims(tf.to_float(input['type']), 1)],
                          1)

        # might helps ensemble
        if FLAGS.use_answer_emb:
            x1 = x

            neg = input['candidate_neg']
            pos = input['candidate_pos']
            na = input['candidate_na']
            neg_len = melt.length(neg)
            pos_len = melt.length(pos)
            na_len = melt.length(na)
            neg_emb = self.embedding(neg)
            pos_emb = self.embedding(pos)
            na_emb = self.embedding(na)

            if FLAGS.share_dropout:
                neg = self.encode(neg_emb,
                                  neg_len,
                                  mask_fws=mask_fws,
                                  mask_bws=mask_bws,
                                  training=training)
                pos = self.encode(pos_emb,
                                  pos_len,
                                  mask_fws=mask_fws,
                                  mask_bws=mask_bws,
                                  training=training)
                na = self.encode(na_emb,
                                 na_len,
                                 mask_fws=mask_fws,
                                 mask_bws=mask_bws,
                                 training=training)
            else:
                neg = self.encode(neg_emb, neg_len, training=training)
                pos = self.encode(pos_emb, pos_len, training=training)
                na = self.encode(na_emb, na_len, training=training)

            neg = self.pooling(neg, neg_len)
            pos = self.pooling(pos, pos_len)
            na = self.pooling(na, na_len)

            answer = tf.stack([neg, pos, na], 1)

            # [batch_size, emb_dim]
            x = self.context_dense(x)
            # [batch_size, 3, emb_dim]
            answer = self.answer_dense(answer)
            x = tf.matmul(answer, tf.transpose(tf.expand_dims(x, 1),
                                               [0, 2, 1]))
            x = tf.reshape(x, [batch_size, NUM_CLASSES])

            x = tf.concat([x1, x], -1)

            #return x

        # not help
        if FLAGS.combine_query:
            q = self.pooling(q, q_len)
            x = tf.concat([x, q], -1)

        if not FLAGS.use_label_emb:
            # split logits by type is useful, especially for type1, and improve a lot with type1 only finetune
            if not FLAGS.split_type:
                x = self.logits(x)
            else:
                x1 = self.logits(x)
                x2 = self.logits2(x)
                mask = tf.expand_dims(tf.to_float(tf.equal(input['type'], 0)),
                                      1)
                x = x1 * mask + x2 * (1 - mask)
        else:
            # use label emb seems not help ?
            x = self.label_dense(x)
            # TODO..
            x = melt.dot(x, self.label_embedding(None))

        return x
Example #19
    def compute_seq_loss(self, image_emb, text):
        """
    same ass 7
    but use dynamic rnn
    """
        #notice here must use tf.shape not text.get_shape()[0], because it is dynamic shape, known at runtime
        is_training = self.is_training

        batch_size = tf.shape(text)[0]

        zero_pad = tf.zeros([batch_size, 1], dtype=text.dtype)

        # add a zero before the sentence to avoid always generating "A..."
        # add a zero after the sentence so the end mark cannot exceed the boundary in case the input sentence is long without trailing 0 padding
        text = tf.concat(1, [zero_pad, text, zero_pad])
        #+1 for the first zero
        sequence_length = melt.length(text) + 1
        text = melt.dynamic_append_with_length(
            text, sequence_length, tf.constant(self.end_id, dtype=text.dtype))
        sequence_length += 1

        #@TODO different init state as show in ptb_word_lm
        state = self.cell.zero_state(batch_size, tf.float32)

        self.initial_state = state

        #print('melt.last_dimension(text)', melt.last_dimension(text))

        # [batch_size, num_steps - 1, emb_dim], remove the last column
        # notice: this version of tf does not support text[:, :-1]; @TODO switch to that once tf supports it
        # the current hack is to use last_dimension, which relies on the static shape; dynamic shapes like tf.shape do not work here!
        # last_dimension is static: known at graph construction time, not at runtime
        #inputs = tf.nn.embedding_lookup(self.emb, text[:,:melt.last_dimension(text) - 1]) + self.bemb
        # TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
        #inputs = tf.nn.embedding_lookup(self.emb, text[:,:tf.shape(text)[1] - 1]) + self.bemb
        # see ipynotebook/dynamic_length.npy
        # this works:
        #num_steps = tf.shape(text)[1]
        #inputs = tf.nn.embedding_lookup(self.emb, melt.exclude_last_col(text)) + self.bemb
        inputs = tf.nn.embedding_lookup(
            self.emb, melt.dynamic_exclude_last_col(text)) + self.bemb

        if is_training and FLAGS.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, FLAGS.keep_prob)

        # [batch_size, num_steps, emb_dim]; image_emb is expanded from [batch_size, emb_dim] to
        # [batch_size, 1, emb_dim] before the concat
        inputs = tf.concat(1, [tf.expand_dims(image_emb, 1), inputs])

        outputs, state = tf.nn.dynamic_rnn(self.cell,
                                           inputs,
                                           initial_state=state,
                                           sequence_length=sequence_length)
        self.final_state = state

        # @TODO this version looks much faster than e.g. _compute_seq_loss13,
        # but there are still many unnecessary calculations, like the matmul over all batch_size * num_steps..
        # can we speed this up by not computing the loss where mask[pos] == 0?
        output = tf.reshape(outputs, [-1, self.emb_dim])

        with tf.device('/cpu:0'):
            logits = tf.matmul(
                output, self.embed_word_W
            ) + self.embed_word_b if self.softmax_loss_function is None else output
        targets = text
        mask = tf.cast(tf.sign(text), dtype=tf.float32)

        loss = tf.nn.seq2seq.sequence_loss_by_example(
            [logits], [tf.reshape(targets, [-1])], [tf.reshape(mask, [-1])],
            softmax_loss_function=self.softmax_loss_function)

        # -------- @TODO using the branch below with tf.reduce_mean seems to give a worse loss than the one above with melt.reduce_mean
        # -- if there is no bug, the difference should be per-example (per-step) loss vs per single-step loss
        if (not is_training) or FLAGS.per_example_loss:
            loss = melt.reduce_mean_with_mask(tf.reshape(
                loss, [batch_size, -1]),
                                              mask,
                                              reduction_indices=1,
                                              keep_dims=True)
        else:
            # if this is used the loss will be [batch_size * num_steps, 1], so when using negs, dynamic-length mode cannot be used
            loss = tf.reshape(loss, [-1, 1])

        return loss
Example #20
    def gen_text_feature(self, text, emb):
        inputs = tf.nn.embedding_lookup(emb, text)
        text_feature = self.encoder.encode(inputs, melt.length(text)).final_state
        #print('---------------------', text_feature)
        return text_feature
Example #21
    def call(self, input, training=False):
        q = input['query']
        c = input['passage']

        # reverse worse
        if FLAGS.cq_reverse:
            q, c = c, q

        #print(input['type'])
        # print('q', q)
        # print('c', c)

        q_len = melt.length(q)
        c_len = melt.length(c)
        q_mask = tf.cast(q, tf.bool)
        c_mask = tf.cast(c, tf.bool)

        q_emb = self.embedding(q)
        c_emb = self.embedding(c)

        x = c_emb
        batch_size = melt.get_shape(x, 0)

        if FLAGS.rnn_no_padding:
            logging.info('------------------no padding! train or eval')
            q_len = tf.ones([batch_size], dtype=q.dtype) * tf.cast(
                melt.get_shape(q, -1), q.dtype)
            c_len = tf.ones([batch_size], dtype=c.dtype) * tf.cast(
                melt.get_shape(c, -1), c.dtype)
            q_mask = tf.ones_like(q)
            c_mask = tf.ones_like(c)

        c = self.encode(c_emb, c_len, training=training)
        q = self.encode(q_emb, q_len, training=training)

        # helps a lot using qc att, now bidaf att worse..
        # TODO... FIXME WRONG! must use sfu, because the iterative align gate increases the dim while sfu does not
        x = c
        for i in range(FLAGS.hop):
            if not FLAGS.use_bidaf_att:
                x = self.att_dot_attentions[i](x,
                                               q,
                                               mask=q_mask,
                                               training=training)
            else:
                x = self.att_dot_attentions[i](x,
                                               q,
                                               c_mask,
                                               q_mask,
                                               training=training)
            if FLAGS.use_att_encode:
                x = self.att_encodes[i](x, c_len, training=training)
            #x = self.match_dot_attentions[i](x, x, mask=c_mask, training=training)
            x = self.match_dot_attentions[i](x, mask=c_mask, training=training)
            x = self.match_encodes[i](x, c_len, training=training)

        if FLAGS.mask_pooling:
            x = self.pooling(x, c_len, calc_word_scores=self.debug)
        else:
            x = self.pooling(x, None, calc_word_scores=self.debug)

        if FLAGS.use_type:
            x = tf.concat([x, tf.expand_dims(tf.to_float(input['type']), 1)],
                          1)

        if FLAGS.use_type_emb:
            x = tf.concat([x, self.type_embedding(input['type'])], 1)

        # might helps ensemble
        if FLAGS.use_answer_emb:
            x1 = x

            neg = input['candidate_neg']
            pos = input['candidate_pos']
            na = input['candidate_na']
            neg_len = melt.length(neg)
            pos_len = melt.length(pos)
            na_len = melt.length(na)
            neg_emb = self.embedding(neg)
            pos_emb = self.embedding(pos)
            na_emb = self.embedding(na)

            if FLAGS.share_dropout:
                # NOTE: mask_fws/mask_bws are not built in this version of call();
                # this branch assumes the shared dropout masks were created as in the previous example.
                neg = self.encode(neg_emb,
                                  neg_len,
                                  mask_fws=mask_fws,
                                  mask_bws=mask_bws,
                                  training=training)
                pos = self.encode(pos_emb,
                                  pos_len,
                                  mask_fws=mask_fws,
                                  mask_bws=mask_bws,
                                  training=training)
                na = self.encode(na_emb,
                                 na_len,
                                 mask_fws=mask_fws,
                                 mask_bws=mask_bws,
                                 training=training)
            else:
                neg = self.encode(neg_emb, neg_len, training=training)
                pos = self.encode(pos_emb, pos_len, training=training)
                na = self.encode(na_emb, na_len, training=training)

            neg = self.pooling(neg, neg_len)
            pos = self.pooling(pos, pos_len)
            na = self.pooling(na, na_len)

            answer = tf.stack([neg, pos, na], 1)

            # [batch_size, emb_dim]
            x = self.context_dense(x)
            # [batch_size, 3, emb_dim]
            answer = self.answer_dense(answer)
            x = tf.matmul(answer, tf.transpose(tf.expand_dims(x, 1),
                                               [0, 2, 1]))
            x = tf.reshape(x, [batch_size, NUM_CLASSES])

            x = tf.concat([x1, x], -1)

            #return x

        # not help
        if FLAGS.combine_query:
            q = self.pooling(q, q_len)
            x = tf.concat([x, q], -1)

        if not FLAGS.use_label_emb:
            # split logits by type is useful, especially for type1, and improve a lot with type1 only finetune
            if not FLAGS.split_type:
                x = self.logits(x)
            else:
                x1 = self.logits(x)
                x2 = self.logits2(x)
                mask = tf.expand_dims(tf.to_float(tf.equal(input['type'], 0)),
                                      1)
                x = x1 * mask + x2 * (1 - mask)
        else:
            # use label emb seems not help ?
            x = self.label_dense(x)
            # TODO..
            x = melt.dot(x, self.label_embedding(None))

        return x
Example #22
def monitor_text_length(text):
  lengths = melt.length(text)
  melt.scalar_summary("text/batch_min", tf.reduce_min(lengths))
  melt.scalar_summary("text/batch_max", tf.reduce_max(lengths))
  melt.scalar_summary("text/batch_mean", tf.reduce_mean(lengths))
Example #23
    def call(self, input):
        # TODO tf2 keras seems to auto-append a last dim, so this is needed
        mt.try_squeeze_dim(input)

        if not FLAGS.batch_parse:
            util.adjust(input, self.mode)

        self.embs = []
        self.feats = {}

        bs = mt.get_shape(input['did'], 0)

        def _add(feat, name):
            if _is_ok(name):
                self.feats[name] = feat
                self.embs += [feat]

        def _adds(feats, names):
            for feat, name in zip(feats, names):
                _add(feat, name)

        # --------------------------  user
        if FLAGS.use_uid:
            uemb = self.uemb(input['uid'])
            _add(uemb, 'uid')
        # --------------------------  doc
        if FLAGS.use_did:
            demb = self.demb(input['did'])
            _add(demb, 'did')

        # ---------------------------  context
        if 'history' in input:
            hlen = mt.length(input['history'])
            hlen = tf.math.maximum(hlen, 1)

        if FLAGS.use_time_emb:
            _add(self.hour_emb(input['hour']), 'hour')
            _add(self.weekday_emb(input['weekday']), 'weekday')

        if FLAGS.use_fresh_emb:
            fresh = input['fresh']
            fresh_day = tf.cast(fresh / (3600 * 12), fresh.dtype)
            fresh_hour = tf.cast(fresh / 3600, fresh.dtype)

            _add(self.fresh_day_emb(fresh_day), 'fresh_day')
            _add(self.fresh_hour_emb(fresh_hour), 'fresh_hour')

        if FLAGS.use_position_emb:
            _add(self.position_emb(input['position']), 'position')

        if FLAGS.use_history:
            dids = input['history']
            if FLAGS.his_strategy == 'bst' or FLAGS.his_pooling == 'mhead':
                mask = tf.cast(tf.equal(dids, 0), dids.dtype)
                dids += mask
                hlen = tf.ones_like(hlen) * 50
            hembs = self.demb(dids)

            his_embs = hembs
            his_embs = self.his_encoder(his_embs, hlen)
            self.his_embs = his_embs

            his_emb = self.his_pooling(demb, his_embs, hlen)

            _add(his_emb, 'his_id')

        # --------------- doc info
        doc_feats = gezi.get('doc_feats')
        doc_feat_lens = gezi.get('doc_feat_lens')
        doc = mt.lookup_feats(input['ori_did'], self.doc_lookup, doc_feats,
                              doc_feat_lens)

        cat = tf.squeeze(doc['cat'], -1)
        sub_cat = tf.squeeze(doc['sub_cat'], -1)

        # title_entities = doc['title_entities']
        # title_entity_types = doc['title_entity_types']
        # abstract_entities = doc['abstract_entities']
        # abstract_entity_types = doc['abstract_entity_types']

        title_entities = input['title_entities']
        title_entity_types = input['title_entity_types']
        abstract_entities = input['abstract_entities']
        abstract_entity_types = input['abstract_entity_types']

        # mt.length: not using it would be slow
        # prev_cat_emb = self.cat_emb(cat)
        # prev_scat_emb = self.scat_emb(cat)
        if _is_ok('cat'):
            cat_emb = self.cat_emb(cat)
            scat_emb = self.scat_emb(sub_cat)
            _adds(
                [
                    # prev_cat_emb,
                    # prev_scat_emb,
                    cat_emb,
                    scat_emb,
                ],
                # ['cat', 'sub_cat', 'title_entity_types', 'abstract_entity_types', 'title_entities', 'abstract_entities']
                [
                    # 'prev_cat', 'prev_scat',
                    'cat',
                    'sub_cat'
                ])

        if _is_ok('enti'):
            title_entities = self.entities_encoder(
                tf.concat([title_entities, title_entity_types], -1))
            abstract_entities = self.entities_encoder(
                tf.concat([abstract_entities, abstract_entity_types], -1))

            _adds(
                [
                    # self.pooling(self.entity_emb(title_entities), mt.length(doc['title_entities'])),
                    # self.pooling(self.entity_type_emb(title_entity_types), mt.length(doc['title_entity_types'])),
                    # self.pooling(self.entity_emb(abstract_entities), mt.length(doc['abstract_entities'])),
                    # self.pooling(self.entity_type_emb(abstract_entity_types), mt.length(doc['abstract_entity_types'])),
                    title_entities,
                    abstract_entities
                ],
                ['title_entities', 'abstract_entities'])

            # _adds(
            #     [
            #       self.his_simple_pooling(self.entity_type_emb(input['history_title_entity_types']), mt.length(input['history_title_entity_types'])),
            #       self.his_simple_pooling(self.entity_type_emb(input['history_abstract_entity_types']), mt.length(input['history_abstract_entity_types']))
            #     ],
            #     ['history_title_entity_merge_types', 'history_abstract_entity_merge_types']
            # )
            input['history_title_entities'] = input[
                'history_title_entities'][:, :FLAGS.max_his_title_entities *
                                          FLAGS.max_lookup_history]
            input['history_title_entity_types'] = input[
                'history_title_entity_types'][:, :FLAGS.
                                              max_his_title_entities *
                                              FLAGS.max_lookup_history]
            input['history_abstract_entities'] = input[
                'history_abstract_entities'][:, :FLAGS.max_his_title_entities *
                                             FLAGS.max_lookup_history]
            input['history_abstract_entity_types'] = input[
                'history_abstract_entity_types'][:, :FLAGS.
                                                 max_his_title_entities *
                                                 FLAGS.max_lookup_history]
            _adds([
                self.his_entity_pooling(
                    title_entities,
                    (self.entity_emb(input['history_title_entities']) +
                     self.entity_type_emb(input['history_title_entity_types'])
                     ), mt.length(input['history_title_entities'])),
                self.his_entity_pooling(
                    abstract_entities,
                    (self.entity_emb(input['history_abstract_entities']) +
                     self.entity_type_emb(
                         input['history_abstract_entity_types'])),
                    mt.length(input['history_abstract_entities']))
            ], ['his_title_merge_entities', 'his_abstract_merge_entities'])

            # --------------- history info
        dids = input['ori_history']
        dids = dids[:, :FLAGS.max_lookup_history]
        hlen = mt.length(input['history'])
        hlen = tf.math.maximum(hlen, 1)

        his = mt.lookup_feats(dids, self.doc_lookup, doc_feats, doc_feat_lens)

        his_cats = his['cat']
        his_cats = tf.squeeze(his_cats, -1)
        his_sub_cats = his['sub_cat']
        his_sub_cats = tf.squeeze(his_sub_cats, -1)

        # his_title_entities = his['title_entities']
        # his_title_entity_types = his['title_entity_types']
        # his_abstract_entities = his['abstract_entities']
        # his_abstract_entity_types = his['abstract_entity_types']

        # his_title_entities = self.his_entities_encoder(tf.concat([his_title_entities, his_title_entity_types], -1),
        #                                                tf.math.minimum(hlen, FLAGS.max_titles), title_entities)
        # his_abstract_entities = self.his_entities_encoder(tf.concat([his_abstract_entities, his_abstract_entity_types], -1),
        #                                                   tf.math.minimum(hlen, FLAGS.max_titles), abstract_entities)

        if _is_ok('cat'):
            # FIXME: if flattened directly, mt.length is problematic because everything is 0-padded internally,
            # e.g. 2,3,0,0 1,0,0,0 would drop a lot of information; padding with 1 is one option (v1 did this, at most 1,1)
            # alternatively an encoder could be used
            _adds(
                [
                    self.his_cat_pooling(self.cat_emb(his_cats),
                                         mt.length(his_cats)),
                    self.his_cat_pooling(self.scat_emb(his_sub_cats),
                                         mt.length(his_sub_cats)),
                    ## for cat, din works worse than att (adding it brings no gain); for title, din works better than att, and din is also better for entities
                    # self.his_scat_din_pooling(scat_emb, self.scat_emb(his_sub_cats), mt.length(his_sub_cats)),
                    # his_title_entities,
                    # his_abstract_entities,
                ],
                [
                    'his_cats',
                    'his_sub_cats',
                    #  'history_title_entities', 'history_abstract_entities'
                ])

        if not FLAGS.bert_dir or not FLAGS.bert_only:
            if _is_ok('^cur_title&'):
                cur_title = self.title_encoder(doc['title'])
                his_titles = his['title']
                if FLAGS.max_titles:
                    his_titles = his_titles[:, :FLAGS.max_titles]
                his_title = self.titles_encoder(
                    his_titles, tf.math.minimum(hlen, FLAGS.max_titles),
                    cur_title)
                _adds([cur_title, his_title], ['cur_title', 'his_title'])

            if _is_ok('^abstract&'):
                cur_abstract = self.abstract_encoder(doc['abstract'])
                his_abstracts = his['abstract']
                if FLAGS.max_abstracts:
                    his_abstracts = his_abstracts[:, :FLAGS.max_abstracts]
                his_abstract = self.abstracts_encoder(
                    his_abstracts, tf.math.minimum(hlen, FLAGS.max_abstracts),
                    cur_abstract)
                _adds([cur_abstract, his_abstract],
                      ['cur_abstract', 'his_abstract'])

            if FLAGS.use_body:
                if _is_ok('^body&'):
                    cur_body = self.body_encoder(doc['body'])
                    his_bodies = his['body']
                    if FLAGS.max_bodies:
                        his_bodies = his_bodies[:, :FLAGS.max_bodies]
                    his_body = self.bodies_encoder(
                        his_bodies, tf.math.minimum(hlen, FLAGS.max_bodies),
                        cur_body)
                    _adds([
                        cur_body,
                        his_body,
                    ], ['cur_body', 'his_body'])

        if FLAGS.bert_dir:
            if _is_ok('bert_title'):
                bert_title = self.bert_title_encoder(doc['title_uncased'])
                max_titles = FLAGS.max_bert_titles
                his_bert_title = self.bert_titles_encoder(
                    his['title_uncased'][:, :max_titles],
                    tf.math.minimum(hlen, max_titles), bert_title)
                _adds([
                    bert_title,
                    his_bert_title,
                ], ['bert_title', 'his_bert_title'])
            if _is_ok('bert_abstract') and FLAGS.bert_abstract:
                bert_abstract = self.bert_abstract_encoder(
                    doc['abstract_uncased'])
                max_abstracts = FLAGS.max_bert_abstracts
                his_bert_abstract = self.bert_abstracts_encoder(
                    his['abstract_uncased'][:, :max_abstracts],
                    tf.math.minimum(hlen, max_abstracts), bert_abstract)
                _adds([
                    bert_abstract,
                    his_bert_abstract,
                ], ['bert_abstract', 'his_bert_abstract'])
            if _is_ok('bert_body') and FLAGS.bert_body:
                bert_body = self.bert_body_encoder(doc['body_uncased'])
                max_bodies = FLAGS.max_bert_bodies
                his_bert_body = self.bert_bodies_encoder(
                    his['body_uncased'][:, :max_bodies],
                    tf.math.minimum(hlen, max_bodies), bert_body)
                _adds([
                    bert_body,
                    his_bert_body,
                ], ['bert_body', 'his_bert_body'])

        if FLAGS.use_impression_titles:  # dev +0.4%, but test drops
            his_impression = mt.lookup_feats(input['impressions'],
                                             self.doc_lookup, doc_feats,
                                             doc_feat_lens)
            his_impression_titles = his_impression['title']
            his_impression_title = self.titles_encoder2(
                his_impression_titles, mt.length(input['impressions']),
                cur_title)
            _adds([
                his_impression_title,
            ], ['impression_title'])

        # using impression ids makes dev and test inconsistent, so do not use the ids directly
        if FLAGS.use_impressions:
            _add(self.mean_pooling(self.demb(input['impressions'])),
                 'impressions')

        if FLAGS.use_dense:
            dense_emb = self.deal_dense(input)
            _add(dense_emb, 'dense')

        embs = self.embs
        # logging.info('-----------embs:', len(embs))
        logging.info(self.feats.keys())
        # logging.debug(self.feats)
        embs = [
            x if len(mt.get_shape(x)) == 2 else tf.squeeze(x, 1) for x in embs
        ]
        embs = tf.stack(embs, axis=1)

        if FLAGS.batch_norm:
            embs = self.batch_norm(embs)

        if FLAGS.l2_normalize_before_pooling:
            embs = tf.math.l2_normalize(embs, axis=FLAGS.l2_norm_axis)

        x = self.feat_pooling(embs)

        # if FLAGS.dropout:
        #   x = self.dropout(x)

        if FLAGS.use_dense:
            x = tf.concat([x, dense_emb], axis=1)

        # if FLAGS.use_his_concat:
        #   x = tf.concat([x, his_concat], axis=1)

        x = self.mlp(x)
        self.logit = self.dense(x)

        self.prob = tf.math.sigmoid(self.logit)
        self.impression_id = input['impression_id']
        self.position = input['position']
        self.history_len = input['hist_len']
        self.impression_len = input['impression_len']
        self.input_ = input
        return self.logit
Example #24
    def call(self, input):
        # TODO tf2 keras seems to auto-append a last dim, so this is needed
        melt.try_squeeze_dim(input)

        if not FLAGS.batch_parse:
            util.adjust(input, self.mode)

        # print(input)

        embs = []

        if 'history' in input:
            hlen = melt.length(input['history'])
            hlen = tf.math.maximum(hlen, 1)

        bs = melt.get_shape(input['did'], 0)

        # user
        if FLAGS.use_uid:
            uemb = self.uemb(input['uid'])
            embs += [uemb]

        if FLAGS.use_did:
            demb = self.demb(input['did'])
            embs += [demb]

        if FLAGS.use_time_emb:
            embs += [
                self.hour_emb(input['hour']),
                self.weekday_emb(input['weekday']),
            ]

        if FLAGS.use_fresh_emb:
            fresh = input['fresh']
            fresh_day = tf.cast(fresh / (3600 * 12), fresh.dtype)
            fresh_hour = tf.cast(fresh / 3600, fresh.dtype)
            embs += [
                self.fresh_day_emb(fresh_day),
                self.fresh_hour_emb(fresh_hour)
            ]

        if FLAGS.use_position_emb:
            embs += [self.position_emb(input['position'])]

        if FLAGS.use_news_info and 'cat' in input:
            # print('------entity_emb', self.entity_emb.emb.weights) # check if trainable is fixed in eager mode
            embs += [
                self.cat_emb(input['cat']),
                self.scat_emb(input['sub_cat']),
                self.pooling(self.entity_type_emb(input['title_entity_types']),
                             melt.length(input['title_entity_types'])),
                self.pooling(
                    self.entity_type_emb(input['abstract_entity_types']),
                    melt.length(input['abstract_entity_types'])),
            ]
            if FLAGS.use_entities and 'title_entities' in input:
                embs += [
                    self.pooling(self.entity_emb(input['title_entities']),
                                 melt.length(input['title_entities'])),
                    self.pooling(self.entity_emb(input['abstract_entities']),
                                 melt.length(input['abstract_entities'])),
                ]

        if FLAGS.use_history_info and 'history_cats' in input:
            embs += [
                self.his_simple_pooling(self.cat_emb(input['history_cats']),
                                        melt.length(input['history_cats'])),
                self.his_simple_pooling(
                    self.scat_emb(input['history_sub_cats']),
                    melt.length(input['history_sub_cats'])),
            ]
            if FLAGS.use_history_entities:
                try:
                    embs += [
                        self.his_simple_pooling(
                            self.entity_type_emb(
                                input['history_title_entity_types']),
                            melt.length(input['history_title_entity_types'])),
                        self.his_simple_pooling(
                            self.entity_type_emb(
                                input['history_abstract_entity_types']),
                            melt.length(
                                input['history_abstract_entity_types'])),
                    ]
                    if FLAGS.use_entities and 'title_entities' in input:
                        embs += [
                            self.his_simple_pooling(
                                self.entity_emb(
                                    input['history_title_entities']),
                                melt.length(input['history_title_entities'])),
                            self.his_simple_pooling(
                                self.entity_emb(
                                    input['history_abstract_entities']),
                                melt.length(
                                    input['history_abstract_entities'])),
                        ]
                except Exception:
                    pass

        if FLAGS.use_history and FLAGS.use_did:
            dids = input['history']

            if FLAGS.his_strategy == 'bst' or FLAGS.his_pooling == 'mhead':
                mask = tf.cast(tf.equal(dids, 0), dids.dtype)
                dids += mask
                hlen = tf.ones_like(hlen) * 50
            hembs = self.demb(dids)

            his_embs = hembs
            his_embs = self.his_encoder(his_embs, hlen)
            self.his_embs = his_embs

            his_emb = self.his_pooling(demb, his_embs, hlen)

            embs += [his_emb]

        if FLAGS.use_title:
            cur_title = self.title_encoder(self.title_lookup(input['ori_did']))
            dids = input['ori_history']
            if FLAGS.max_titles:
                dids = dids[:, :FLAGS.max_titles]
            his_title = self.titles_encoder(self.title_lookup(dids), hlen,
                                            cur_title)
            embs += [cur_title, his_title]

        # using impression ids makes dev and test inconsistent, so do not use the ids directly
        if FLAGS.use_impressions:
            embs += [self.mean_pooling(self.demb(input['impressions']))]

        if FLAGS.use_dense:
            dense_emb = self.deal_dense(input)
            embs += [dense_emb]

        # logging.debug('-----------embs:', len(embs))
        embs = tf.stack(embs, axis=1)

        if FLAGS.batch_norm:
            embs = self.batch_norm(embs)

        if FLAGS.l2_normalize_before_pooling:
            embs = tf.math.l2_normalize(embs)

        x = self.feat_pooling(embs)

        if FLAGS.dropout:
            x = self.dropout(x)

        if FLAGS.use_dense:
            x = tf.concat([x, dense_emb], axis=1)

        if FLAGS.use_his_concat:
            # NOTE: his_concat is not built in this snippet; this branch assumes it was produced earlier by the history encoder.
            x = tf.concat([x, his_concat], axis=1)

        x = self.mlp(x)

        self.logit = self.dense(x)
        self.prob = tf.math.sigmoid(self.logit)
        self.impression_id = input['impression_id']
        self.position = input['position']
        self.history_len = input['hist_len']
        self.impression_len = input['impression_len']
        self.input_ = input
        return self.logit
Example #25
    def call(self, x):
        return self.pooling(self.emb(x), mt.length(x))