Example #1
0
def model_fn(features, labels, mode, params):
    """Toy Estimator model_fn: a single k x k linear layer used for debugging.

    Multiplies ``features['vals']`` by a trainable ``[k, k]`` weight matrix,
    adds a random scalar offset, and uses the mean of the result as the loss.

    Args:
        features: Feature dict; only ``features['vals']`` is read. It is
            matrix-multiplied with a ``[k, k]`` float32 matrix, so presumably
            it is a float32 tensor whose last dimension is ``k`` — TODO confirm.
        labels: Tensor summed along axis 1 for evaluation logging, or ``None``
            (e.g. in PREDICT mode), in which case no logging hook is built.
        mode: A ``tf.estimator.ModeKeys`` value.
        params: Estimator params; unused (hyperparameters are hard-coded).

    Returns:
        A ``tf.estimator.EstimatorSpec``.
    """
    num_train_steps = 100
    lr = 0.0001
    k = 5
    weights = tf.Variable(
        tf.random_normal_initializer()(shape=[k, k], dtype=tf.float32))
    # Random scalar offset, drawn once with NumPy when the graph is built.
    cvals = tf.constant(np.random.uniform(-1, 1), tf.float32)

    vals_tensor = features['vals']
    z = tf.matmul(vals_tensor, weights) + cvals
    predictions = z
    loss = tf.reduce_mean(z)
    eval_metric_ops = None

    # `labels` may be None; reducing it unconditionally would fail while the
    # graph is being built, so only create the logging hook when it exists.
    evaluation_hooks = None
    if labels is not None:
        ls = tf.reduce_sum(labels, axis=1)
        logging_hook = tf.estimator.LoggingTensorHook(
            {'z': z, 'ls': ls}, every_n_iter=1)
        evaluation_hooks = [logging_hook]

    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=lr,
        num_train_steps=num_train_steps,
        # Warm-up length: 10% of training steps, clamped to [100, 10000].
        num_warmup_steps=min(10000, max(100, int(num_train_steps / 10))),
        use_tpu=False)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions,
        evaluation_hooks=evaluation_hooks,
        eval_metric_ops=eval_metric_ops)
Example #2
0
def model_fn(features, labels, mode, params):
  """Model function for retrieval-augmented masked-LM pre-training.

  Embeds the query and every candidate with hub "projected" embedders, scores
  candidates by inner product, runs the BERT "mlm" signature over each joint
  (query, candidate) input, and minimises the negative log-likelihood of the
  masked tokens marginalised over candidates (averaged over unmasked slots).

  Args:
    features: Dict of input tensors. Reads "query_inputs",
      "candidate_inputs" and "joint_inputs" (objects exposing `token_ids`,
      `mask` and `segment_ids`), "mlm_targets", "mlm_positions",
      "mlm_mask", and, outside PREDICT mode, "export_timestamp".
    labels: Unused.
    mode: A `tf.estimator.ModeKeys` value.
    params: Dict of hyperparameters. Reads "bert_hub_module_handle",
      "embedder_hub_module_handle", "share_embedders", "num_candidates",
      "num_train_steps", "learning_rate" and "use_tpu".

  Returns:
    A `tf.estimator.tpu.TPUEstimatorSpec` when params["use_tpu"] is true,
    otherwise a `tf.estimator.EstimatorSpec`. Predictions are an empty dict.
  """
  del labels

  is_training = mode == tf.estimator.ModeKeys.TRAIN
  num_candidates = params["num_candidates"]

  # ==============================
  # Input features
  # ==============================
  # [batch_size, query_seq_len]
  query_inputs = features["query_inputs"]

  # [batch_size, num_candidates, candidate_seq_len]
  candidate_inputs = features["candidate_inputs"]

  # [batch_size, num_candidates, query_seq_len + candidate_seq_len]
  joint_inputs = features["joint_inputs"]

  # [batch_size, num_masks]
  mlm_targets = features["mlm_targets"]
  mlm_positions = features["mlm_positions"]
  mlm_mask = features["mlm_mask"]

  # ==============================
  # Create modules.
  # ==============================
  bert_module = hub.Module(
      spec=params["bert_hub_module_handle"],
      name="locbert",
      tags={"train"} if is_training else {},
      trainable=True)
  hub.register_module_for_export(bert_module, "locbert")

  embedder_module = hub.Module(
      spec=params["embedder_hub_module_handle"],
      name="embedder",
      tags={"train"} if is_training else {},
      trainable=True)
  hub.register_module_for_export(embedder_module, "embedder")

  if params["share_embedders"]:
    query_embedder_module = embedder_module
  else:
    # NOTE(review): this reuses the scope name "embedder"; kept unchanged so
    # variable names in existing checkpoints do not move — confirm that
    # hub.Module uniquifies the scope as intended.
    query_embedder_module = hub.Module(
        spec=params["embedder_hub_module_handle"],
        name="embedder",
        tags={"train"} if is_training else {},
        trainable=True)
    # Register the separate query embedder (not the candidate-side
    # `embedder_module`) so "query_embedder" exports the query-side weights.
    hub.register_module_for_export(query_embedder_module, "query_embedder")

  # ==============================
  # Retrieve.
  # ==============================
  # [batch_size, projected_size]
  query_emb = query_embedder_module(
      inputs=dict(
          input_ids=query_inputs.token_ids,
          input_mask=query_inputs.mask,
          segment_ids=query_inputs.segment_ids),
      signature="projected")

  # [batch_size * num_candidates, candidate_seq_len]
  flat_candidate_inputs, unflatten = flatten_bert_inputs(candidate_inputs)

  # [batch_size * num_candidates, projected_size]
  flat_candidate_emb = embedder_module(
      inputs=dict(
          input_ids=flat_candidate_inputs.token_ids,
          input_mask=flat_candidate_inputs.mask,
          segment_ids=flat_candidate_inputs.segment_ids),
      signature="projected")

  # [batch_size, num_candidates, projected_size]
  unflattened_candidate_emb = unflatten(flat_candidate_emb)

  # Inner-product retrieval score between each query and its candidates.
  # [batch_size, num_candidates]
  retrieval_score = tf.einsum("BD,BND->BN", query_emb,
                              unflattened_candidate_emb)

  # ==============================
  # Read.
  # ==============================
  # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
  flat_joint_inputs, _ = flatten_bert_inputs(joint_inputs)

  # Every candidate shares its query's mask positions.
  # [batch_size * num_candidates, num_masks]
  flat_mlm_positions, _ = tensor_utils.flatten(
      tf.tile(tf.expand_dims(mlm_positions, 1), [1, num_candidates, 1]))

  batch_size, num_masks = tensor_utils.shape(mlm_targets)

  # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
  flat_joint_bert_outputs = bert_module(
      inputs=dict(
          input_ids=flat_joint_inputs.token_ids,
          input_mask=flat_joint_inputs.mask,
          segment_ids=flat_joint_inputs.segment_ids,
          mlm_positions=flat_mlm_positions),
      signature="mlm",
      as_dict=True)

  # [batch_size, num_candidates]
  candidate_score = retrieval_score

  # log p(candidate | query).
  # [batch_size, num_candidates]
  candidate_log_probs = tf.math.log_softmax(candidate_score)

  # ==============================
  # Compute marginal log-likelihood.
  # ==============================
  # [batch_size * num_candidates, num_masks, vocab_size]
  flat_mlm_logits = flat_joint_bert_outputs["mlm_logits"]

  # [batch_size, num_candidates, num_masks, vocab_size]
  mlm_logits = tf.reshape(
      flat_mlm_logits, [batch_size, num_candidates, num_masks, -1])
  mlm_log_probs = tf.math.log_softmax(mlm_logits)

  # [batch_size, num_candidates, num_masks]
  tiled_mlm_targets = tf.tile(
      tf.expand_dims(mlm_targets, 1), [1, num_candidates, 1])

  # [batch_size, num_candidates, num_masks, 1]
  tiled_mlm_targets = tf.expand_dims(tiled_mlm_targets, -1)

  # Log-probability of the gold token at each masked position.
  # [batch_size, num_candidates, num_masks, 1]
  gold_log_probs = tf.batch_gather(mlm_log_probs, tiled_mlm_targets)

  # [batch_size, num_candidates, num_masks]
  gold_log_probs = tf.squeeze(gold_log_probs, -1)

  # log p(candidate | query) + log p(gold token | query, candidate).
  # [batch_size, num_candidates, num_masks]
  joint_gold_log_probs = (
      tf.expand_dims(candidate_log_probs, -1) + gold_log_probs)

  # Marginalise over candidates.
  # [batch_size, num_masks]
  marginal_gold_log_probs = tf.reduce_logsumexp(joint_gold_log_probs, 1)

  # [batch_size, num_masks]
  float_mlm_mask = tf.cast(mlm_mask, tf.float32)

  # Mean negative marginal log-likelihood over unmasked slots (0 if none).
  # []
  loss = -tf.div_no_nan(
      tf.reduce_sum(marginal_gold_log_probs * float_mlm_mask),
      tf.reduce_sum(float_mlm_mask))

  # ==============================
  # Optimization
  # ==============================
  # Warm-up: 10% of training steps, clamped to [100, 10000].
  num_warmup_steps = min(10000, max(100, int(params["num_train_steps"] / 10)))
  train_op = optimization.create_optimizer(
      loss=loss,
      init_lr=params["learning_rate"],
      num_train_steps=params["num_train_steps"],
      num_warmup_steps=num_warmup_steps,
      use_tpu=params["use_tpu"])

  # ==============================
  # Evaluation
  # ==============================
  eval_metric_ops = None if params["use_tpu"] else dict()
  if mode != tf.estimator.ModeKeys.PREDICT:
    # Gain of the marginal log-likelihood over using only candidate 0.
    # [batch_size, num_masks]
    retrieval_utility = marginal_gold_log_probs - gold_log_probs[:, 0]
    retrieval_utility *= float_mlm_mask

    # []
    retrieval_utility = tf.div_no_nan(
        tf.reduce_sum(retrieval_utility), tf.reduce_sum(float_mlm_mask))
    add_mean_metric("retrieval_utility", retrieval_utility, eval_metric_ops)

    # 1.0 where an export timestamp is present (> 0), else 0.0.
    has_timestamp = tf.cast(
        tf.greater(features["export_timestamp"], 0), tf.float64)
    off_policy_delay_secs = (
        tf.timestamp() - tf.cast(features["export_timestamp"], tf.float64))
    off_policy_delay_mins = off_policy_delay_secs / 60.0
    # Zero out the delay for examples without an export timestamp.
    off_policy_delay_mins *= has_timestamp

    add_mean_metric("off_policy_delay_mins", off_policy_delay_mins,
                    eval_metric_ops)

  # Create empty predictions to avoid errors when running in prediction mode.
  predictions = dict()

  if params["use_tpu"]:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions)
  else:
    if eval_metric_ops is not None:
      # Make sure the eval metrics are updated during training so that we get
      # quick feedback from tensorboard summaries when debugging locally.
      with tf.control_dependencies([u for _, u in eval_metric_ops.values()]):
        loss = tf.identity(loss)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        predictions=predictions)
Example #3
0
def model_fn(features, labels, mode, params):
    """Retrieve-then-read model_fn for multi-label type classification.

    Retrieves ``retriever_beam_size`` blocks per query, concatenates the query
    tokens with each retrieved block, encodes each pair with the reader hub
    module, and scores ``n_types`` labels per pair with a sigmoid layer. Label
    log-probabilities are marginalised over the retrieved blocks, using the
    log-softmax of the retrieval logits as the prior over blocks.

    Args:
        features: Dict of input tensors. Reads 'tok_id_seq_batch' (converted
            with ``to_tensor()``), 'tok_id_seqs_repeat', 'batch_id' and
            'text_ids'.
        labels: Multi-hot label tensor, presumably float32 of shape
            [batch_size, n_types] — TODO confirm. Unused in PREDICT mode.
        mode: A ``tf.estimator.ModeKeys`` value.
        params: Dict of hyperparameters; see the reads at the top of the body.

    Returns:
        A ``tf.estimator.EstimatorSpec`` whose predictions contain 'probs',
        'text_ids' and 'block_ids'. ``loss``, ``train_op`` and
        ``eval_metric_ops`` are None in PREDICT mode.
    """
    embedder_module_path = params['embedder_module_path']
    reader_module_path = params['reader_module_path']
    lr = params['lr']
    num_train_steps = params['num_train_steps']
    max_seq_len = params['max_seq_len']
    bert_dim = params['bert_dim']
    n_types = params['n_types']
    sep_tok_id = params['sep_tok_id']
    retriever_beam_size = params['retriever_beam_size']
    block_records_path = params['block_records_path']
    num_block_records = params['num_block_records']
    train_log_steps = params['train_log_steps']
    eval_log_steps = params['eval_log_steps']

    tok_id_seq_batch_tensor = features['tok_id_seq_batch'].to_tensor()
    # Token id 0 (padding) gets mask 0; every other position gets 1.
    input_mask = 1 - tf.cast(tf.equal(tok_id_seq_batch_tensor, tf.constant(0)),
                             tf.int32)

    # `pre_load_data` is defined elsewhere in the module (not visible here).
    block_labels_np = pre_load_data.get('labels', None)
    block_labels = tf.constant(
        block_labels_np, tf.int32) if block_labels_np is not None else None

    with tf.device("/cpu:0"):
        scaffold, retrieved_block_ids, retrieved_blocks, zx_logits = retrieve(
            tok_id_seq_batch_tensor, input_mask, embedder_module_path, mode,
            block_records_path, retriever_beam_size, num_block_records)

    # Debug-only tensors: labels of the blocks retrieved for the first query
    # in the batch. Skipped when no block labels were pre-loaded, because a
    # None value cannot be logged by LoggingTensorHook.
    debug_label_tensors = {}
    if block_labels is not None:
        retrieved_labels = tf.gather(block_labels, retrieved_block_ids[0])
        retrieved_label_vecs = get_one_hot_label_vecs(retrieved_labels,
                                                      n_types)
        debug_label_tensors = {
            'tmp': retrieved_label_vecs,
            'tmp1': retrieved_labels,
        }

    tokenizer, _ = bert_utils.get_tf_tokenizer(reader_module_path)
    block_tok_id_seqs = tokenizer.tokenize(retrieved_blocks)
    # Merge ragged dims 2 and 3 (presumably words -> word pieces) into a single
    # token axis, then pad to a dense int32 tensor.
    block_tok_id_seqs = tf.cast(
        block_tok_id_seqs.merge_dims(2, 3).to_tensor(), tf.int32)
    blocks_max_seq_len = tf.shape(block_tok_id_seqs)[-1]
    # Flatten the leading dims so each row is one retrieved block.
    block_tok_id_seqs_flat = tf.reshape(block_tok_id_seqs,
                                        (-1, blocks_max_seq_len))

    tok_id_seqs_repeat = features['tok_id_seqs_repeat']
    q_doc_tok_id_seqs = tf.concat((tok_id_seqs_repeat, block_tok_id_seqs_flat),
                                  axis=1).to_tensor()
    # Keep max_seq_len - 1 tokens, leaving one slot for pad_sep_to_tensor
    # (presumably appends sep_tok_id — confirm).
    q_doc_tok_id_seqs = q_doc_tok_id_seqs[:, :max_seq_len - 1]
    q_doc_tok_id_seqs = pad_sep_to_tensor(q_doc_tok_id_seqs, sep_tok_id)

    q_doc_input_mask = 1 - tf.cast(
        tf.equal(q_doc_tok_id_seqs, tf.constant(0, dtype=tf.int32)), tf.int32)

    reader_module = hub.Module(
        reader_module_path,
        tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
        trainable=True)

    concat_outputs = reader_module(
        dict(
            input_ids=q_doc_tok_id_seqs,
            input_mask=q_doc_input_mask,
            segment_ids=tf.zeros_like(q_doc_tok_id_seqs)),
        signature="tokens",
        as_dict=True)
    # Output at the first position (presumably [CLS]) of each query-block pair.
    qd_reps = concat_outputs['sequence_output'][:, 0, :]

    dense_weights = tf.Variable(initial_value=np.random.uniform(
        -0.1, 0.1, (bert_dim, n_types)),
                                trainable=True,
                                dtype=tf.float32)

    # Prior over retrieved blocks: log-softmax of the retrieval logits.
    zx_logits = tf.reshape(zx_logits, (-1, retriever_beam_size))
    log_softmax_zx_logits = tf.nn.log_softmax(zx_logits, axis=1)

    # Per-block, per-type log sigmoid scores.
    yzx_logits = tf.matmul(qd_reps, dense_weights)
    yzx_logits = tf.reshape(yzx_logits, (-1, retriever_beam_size, n_types))
    log_sig_yzx_logits = tf.math.log_sigmoid(yzx_logits)
    # Marginalise over blocks (axis 1) in log space.
    z_log_probs = log_sig_yzx_logits + tf.expand_dims(log_softmax_zx_logits, 2)
    log_probs = tf.reduce_logsumexp(z_log_probs, axis=1)
    # log(1 - p) from log p.
    log_neg_probs = tfp.math.log1mexp(log_probs)

    probs = tf.exp(log_probs)

    loss = None
    eval_metric_ops = None
    train_op = None
    if mode != tf.estimator.ModeKeys.PREDICT:
        # Binary cross-entropy summed over types, averaged over the batch.
        loss_samples = -tf.reduce_sum(
            labels * log_probs + (1 - labels) * log_neg_probs, axis=1)
        loss = tf.reduce_mean(loss_samples)

        train_op = optimization.create_optimizer(
            loss=loss,
            init_lr=lr,
            num_train_steps=num_train_steps,
            num_warmup_steps=min(10000, max(100, int(num_train_steps / 10))),
            use_tpu=False)

        # Small constant keeps the ratios below finite for empty rows.
        small_constant = tf.constant(0.00001)
        # A type is predicted positive when its probability exceeds 0.5.
        pos_preds = tf.cast(tf.less(tf.constant(0.5), probs), tf.float32)
        n_pred_pos = tf.reduce_sum(pos_preds, axis=1) + small_constant
        n_true_pos = tf.reduce_sum(labels, axis=1) + small_constant
        n_corrects = tf.reduce_sum(pos_preds * labels, axis=1)
        precision = tf.reduce_mean(n_corrects / n_pred_pos)
        recall = tf.reduce_mean(n_corrects / n_true_pos)

        p_mean, p_op = tf.compat.v1.metrics.mean(precision)
        r_mean, r_op = tf.compat.v1.metrics.mean(recall)
        f1 = 2 * p_mean * r_mean / (p_mean + r_mean + small_constant)

        eval_metric_ops = {
            'precision': (p_mean, p_op),
            'recall': (r_mean, r_op),
            'f1': (f1, tf.group(p_op, r_op))
        }

    train_log_tensors = {
        'batch_id': features['batch_id'],
        'loss': loss,
        'ws': tf.reduce_sum(dense_weights),
        'yzx_logits': tf.reduce_sum(yzx_logits),
    }
    train_log_tensors.update(debug_label_tensors)
    train_logging_hook = tf.estimator.LoggingTensorHook(
        train_log_tensors, every_n_iter=train_log_steps)
    logging_hook = tf.estimator.LoggingTensorHook(
        {
            'batch_id': features['batch_id'],
            'loss': loss,
            'tmp': tf.reduce_sum(dense_weights),
        },
        every_n_iter=eval_log_steps)

    predictions = {
        'probs': probs,
        'text_ids': features['text_ids'],
        'block_ids': retrieved_block_ids
    }

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions,
        training_hooks=[train_logging_hook],
        evaluation_hooks=[logging_hook],
        eval_metric_ops=eval_metric_ops,
        scaffold=scaffold)
Example #4
0
def model_fn(features, labels, mode, params):
    """Open-retrieval QA model_fn: retrieve evidence blocks, then read them.

    Args:
        features: Feature dict forwarded to `retrieve` and `read`.
        labels: Answers tensor (presumably strings, given the empty-string
            default below), or None. None is replaced by a single empty
            answer so the graph can still be built.
        mode: A `tf.estimator.ModeKeys` value.
        params: Dict of hyperparameters. Reads "reader_beam_size",
            "retriever_beam_size" (outside PREDICT mode), "learning_rate" and
            "num_train_steps"; also forwarded to the helpers.

    Returns:
        A `tf.estimator.EstimatorSpec`. `loss`, `train_op` and
        `eval_metric_ops` are None in PREDICT mode.

    Raises:
        ValueError: If reader_beam_size exceeds the retriever beam size.
    """
    if labels is None:
        labels = tf.constant([""])

    reader_beam_size = params["reader_beam_size"]
    if mode == tf.estimator.ModeKeys.PREDICT:
        # At prediction time only the blocks the reader will see are retrieved.
        retriever_beam_size = reader_beam_size
    else:
        retriever_beam_size = params["retriever_beam_size"]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if reader_beam_size > retriever_beam_size:
        raise ValueError(
            "reader_beam_size (%d) must not exceed retriever_beam_size (%d)" %
            (reader_beam_size, retriever_beam_size))

    with tf.device("/cpu:0"):
        retriever_outputs = retrieve(features=features,
                                     retriever_beam_size=retriever_beam_size,
                                     mode=mode,
                                     params=params)

    with tf.variable_scope("reader"):
        # The reader only sees the top reader_beam_size retrieved blocks.
        reader_outputs = read(
            features=features,
            retriever_logits=retriever_outputs.logits[:reader_beam_size],
            blocks=retriever_outputs.blocks[:reader_beam_size],
            mode=mode,
            params=params,
            labels=labels)

    predictions = get_predictions(reader_outputs, params)

    if mode == tf.estimator.ModeKeys.PREDICT:
        loss = None
        train_op = None
        eval_metric_ops = None
    else:
        # [retriever_beam_size]
        retriever_correct = orqa_ops.has_answer(
            blocks=retriever_outputs.blocks, answers=labels)

        # [reader_beam_size, num_candidates]
        reader_correct = compute_correct_candidates(
            candidate_starts=reader_outputs.candidate_starts,
            candidate_ends=reader_outputs.candidate_ends,
            gold_starts=reader_outputs.gold_starts,
            gold_ends=reader_outputs.gold_ends)

        eval_metric_ops = compute_eval_metrics(
            labels=labels,
            predictions=predictions,
            retriever_correct=retriever_correct,
            reader_correct=reader_correct)

        # []
        loss = compute_loss(retriever_logits=retriever_outputs.logits,
                            retriever_correct=retriever_correct,
                            reader_logits=reader_outputs.logits,
                            reader_correct=reader_correct)

        train_op = optimization.create_optimizer(
            loss=loss,
            init_lr=params["learning_rate"],
            num_train_steps=params["num_train_steps"],
            # Warm-up: 10% of training steps, clamped to [100, 10000].
            num_warmup_steps=min(10000,
                                 max(100,
                                     int(params["num_train_steps"] / 10))),
            use_tpu=False)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      predictions=predictions,
                                      eval_metric_ops=eval_metric_ops)