Example 1
				def metric_fn(per_example_loss, label_ids, logits, is_real_example):
					"""Compute Matthew's correlations for STS-B."""
					predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
					# https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
					tp, tp_op = tf.metrics.true_positives(
							labels=label_ids, predictions=predictions, weights=is_real_example)
					tn, tn_op = tf.metrics.true_negatives(
							labels=label_ids, predictions=predictions, weights=is_real_example)
					fp, fp_op = tf.metrics.false_positives(
							labels=label_ids, predictions=predictions, weights=is_real_example)
					fn, fn_op = tf.metrics.false_negatives(
							labels=label_ids, predictions=predictions, weights=is_real_example)

					# Compute Matthew's correlation
					mcc = tf.div_no_nan(
							tp * tn - fp * fn,
							tf.pow((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn), 0.5))

					# Compute accuracy
					accuracy = tf.metrics.accuracy(
							labels=label_ids, predictions=predictions,
							weights=is_real_example)

					loss = tf.metrics.mean(
							values=per_example_loss,
							weights=is_real_example)

					return {"matthew_corr": (mcc, tf.group(tp_op, tn_op, fp_op, fn_op)),
									"eval_accuracy": accuracy, "eval_loss": loss,}
Example 2
def f1_score(answer_ids, prediction_ids, vocab):
    """Compute F1 score between answer tokens and prediction tokens.

  Args:
    answer_ids: <int32> [batch_size, answer_length]
    prediction_ids: <int32> [batch_size, prediction_length]
    vocab: Instance of text_utils.Vocab.

  Returns:
    score: <float32> [batch_size] tensor of [0.0, 1.0].
  """
    # Order insensitive, so we just create a vocabulary sized bit tensor where
    # the vocabulary items that are not to be counted are masked out.
    vocab_size = len(vocab)
    remove_ids = list(_get_normalized_set(vocab))
    remove_mask = tf.expand_dims(tf.one_hot(remove_ids, vocab_size), 0)
    remove_mask = tf.reduce_sum(remove_mask, axis=1)
    remove_mask = tf.cast(tf.equal(remove_mask, 0), tf.float32)

    # [batch_size, vocab_size]
    answer_ids = tf.reduce_sum(tf.one_hot(answer_ids, vocab_size), axis=1)
    answer_ids *= remove_mask

    # [batch_size, vocab_size]
    prediction_ids = tf.reduce_sum(tf.one_hot(prediction_ids, vocab_size),
                                   axis=1)
    prediction_ids *= remove_mask

    # Compute multiset intersection, and count the size.
    intersection = tf.minimum(prediction_ids, answer_ids)
    intersection = tf.reduce_sum(intersection, axis=1)

    # Compute F1 score:
    #   Re(A, B) = |A \cap B| / |B|
    #   Pr(A, B) = |A \cap B| / |A|
    #   F1(A, B) = 2 * (Pr * Re) / (Pr + Re)
    recall = tf.div_no_nan(intersection, tf.reduce_sum(answer_ids, axis=1))
    precision = tf.div_no_nan(intersection,
                              tf.reduce_sum(prediction_ids, axis=1))
    score = 2 * tf.div_no_nan(precision * recall, precision + recall)

    return score
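
Example 2 converts token-id sequences into bag-of-token count vectors (one_hot followed by reduce_sum), takes the element-wise minimum as a multiset intersection, and guards the precision/recall/F1 divisions with tf.div_no_nan. A simplified sketch of the same computation, assuming TensorFlow 1.x and omitting the normalization mask built from text_utils.Vocab and _get_normalized_set:

import tensorflow as tf  # assumes TensorFlow 1.x

vocab_size = 6
answer = tf.constant([[1, 2, 2]])      # one example: answer tokens
prediction = tf.constant([[2, 2, 5]])  # one example: predicted tokens

answer_counts = tf.reduce_sum(tf.one_hot(answer, vocab_size), axis=1)
prediction_counts = tf.reduce_sum(tf.one_hot(prediction, vocab_size), axis=1)

# Multiset intersection size, then token-level precision/recall/F1.
intersection = tf.reduce_sum(
    tf.minimum(answer_counts, prediction_counts), axis=1)
recall = tf.div_no_nan(intersection, tf.reduce_sum(answer_counts, axis=1))
precision = tf.div_no_nan(intersection,
                          tf.reduce_sum(prediction_counts, axis=1))
f1 = 2 * tf.div_no_nan(precision * recall, precision + recall)

with tf.Session() as sess:
    print(sess.run(f1))  # [0.6666667]: two shared tokens, |A| = |B| = 3
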
Example 3
    def _build_model(self):
        self.graph_built = True
        tf.set_random_seed(self.seed)
        self.user_indices = tf.placeholder(tf.int32, shape=[None])
        self.item_indices = tf.placeholder(tf.int32, shape=[None])
        self.user_interacted_seq = tf.placeholder(
            tf.int32, shape=[None, self.interaction_num])
        self.user_interacted_len = tf.placeholder(tf.float32, shape=[None])
        self.labels = tf.placeholder(tf.float32, shape=[None])
        self.is_training = tf.placeholder_with_default(False, shape=[])
        self.concat_embed = []

        user_features = tf.get_variable(
            name="user_features",
            shape=[self.n_users + 1, self.embed_size],
            initializer=tf_truncated_normal(0.0, 0.01),
            regularizer=self.reg)
        item_features = tf.get_variable(
            name="item_features",
            shape=[self.n_items + 1, self.embed_size],
            initializer=tf_truncated_normal(0.0, 0.01),
            regularizer=self.reg)
        user_embed = tf.nn.embedding_lookup(user_features, self.user_indices)
        item_embed = tf.nn.embedding_lookup(item_features, self.item_indices)

        # Unknown items are padded with a zero vector.
        zero_padding_op = tf.scatter_update(
            item_features, self.n_items,
            tf.zeros([self.embed_size], dtype=tf.float32))
        with tf.control_dependencies([zero_padding_op]):
            multi_item_embed = tf.nn.embedding_lookup(
                item_features, self.user_interacted_seq)  # B * seq * K
        pooled_embed = tf.div_no_nan(
            tf.reduce_sum(multi_item_embed, axis=1),
            tf.expand_dims(tf.sqrt(self.user_interacted_len), axis=1))
        self.concat_embed.extend([user_embed, item_embed, pooled_embed])

        if self.sparse:
            self._build_sparse()
        if self.dense:
            self._build_dense()

        concat_embed = tf.concat(self.concat_embed, axis=1)
        mlp_layer = dense_nn(concat_embed,
                             self.hidden_units,
                             use_bn=self.use_bn,
                             dropout_rate=self.dropout_rate,
                             is_training=self.is_training)
        self.output = tf.reshape(tf.layers.dense(inputs=mlp_layer, units=1),
                                 [-1])
        count_params()
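
In Example 3, pooled_embed divides the summed history embeddings by sqrt(history length), and tf.div_no_nan keeps users with an empty interaction history at an all-zero vector instead of NaN. A toy sketch of just that pooling step, assuming TensorFlow 1.x and made-up embedding values:

import tensorflow as tf  # assumes TensorFlow 1.x

# Two users, histories padded to length 3 with zero vectors.
multi_item_embed = tf.constant(
    [[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]],   # two real items + one padding slot
     [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]])  # empty history
interacted_len = tf.constant([2.0, 0.0])

pooled = tf.div_no_nan(
    tf.reduce_sum(multi_item_embed, axis=1),
    tf.expand_dims(tf.sqrt(interacted_len), axis=1))

with tf.Session() as sess:
    # [[2.83, 4.24], [0., 0.]]: the empty history yields zeros, not NaN.
    print(sess.run(pooled))
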
Example 4
    def _attention_unit(self, queries, keys, keys_len):
        if self.use_tf_attention:
            query_masks = tf.cast(
                tf.ones_like(tf.reshape(self.user_interacted_len, [-1, 1])),
                dtype=tf.bool
            )
            key_masks = tf.sequence_mask(
                self.user_interacted_len, self.max_seq_len
            )
            queries = tf.expand_dims(queries, axis=1)
            attention = tf.keras.layers.Attention(use_scale=False)
            pooled_outputs = attention(inputs=[queries, keys],
                                       mask=[query_masks, key_masks])
            return pooled_outputs
        else:
            # queries: B * K, keys: B * seq * K
            queries = tf.expand_dims(queries, axis=1)
            # B * seq * K
            queries = tf.tile(queries, [1, self.max_seq_len, 1])
            queries_keys_cross = tf.concat(
                [queries, keys, queries - keys, queries * keys], axis=2)
            mlp_layer = dense_nn(queries_keys_cross, (16,), use_bn=False,
                                 activation=tf.nn.sigmoid, name="attention")
            # B * seq * 1
            mlp_layer = tf.layers.dense(mlp_layer, units=1, activation=None)
            # attention_weights = tf.transpose(mlp_layer, [0, 2, 1])
            attention_weights = tf.layers.flatten(mlp_layer)

            key_masks = tf.sequence_mask(keys_len, self.max_seq_len)
            paddings = tf.ones_like(attention_weights) * (-2**32 + 1)
            attention_scores = tf.where(key_masks, attention_weights, paddings)
            attention_scores = tf.div_no_nan(
                attention_scores,
                tf.sqrt(
                    tf.cast(keys.get_shape().as_list()[-1], tf.float32)
                )
            )
            # B * 1 * seq
            attention_scores = tf.expand_dims(
                tf.nn.softmax(attention_scores), 1)
            # B * 1 * K
            pooled_outputs = attention_scores @ keys
            return pooled_outputs
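
The non-Keras branch of Example 4 masks padded key positions with a large negative constant before the softmax so they receive near-zero attention weight. A minimal standalone sketch of that masking step, assuming TensorFlow 1.x and toy logits:

import tensorflow as tf  # assumes TensorFlow 1.x

max_seq_len = 4
scores = tf.constant([[0.5, 1.5, -0.3, 0.9]])  # raw attention logits, batch of 1
keys_len = tf.constant([2])                    # only the first two keys are real

key_masks = tf.sequence_mask(keys_len, max_seq_len)
paddings = tf.ones_like(scores) * (-2**32 + 1)
masked_scores = tf.where(key_masks, scores, paddings)
weights = tf.nn.softmax(masked_scores)

with tf.Session() as sess:
    # Padded positions get ~0 probability; the real positions sum to ~1.
    print(sess.run(weights))
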
Example 5
def spherical_normalization(x, rectify=True):
    """Apply area weights and normalization to spherical distributions.

  The area-weighted sum of pixel values over the spherical input will be one.

  Args:
    x: [BATCH, HEIGHT, WIDTH, CHANNELS] spherical raw distributions.
    rectify: apply softplus to the input x if true.

  Returns:
    [BATCH, HEIGHT, WIDTH, CHANNELS] normalized distributions.
  """
    with tf.name_scope(None, 'spherical_normalization', [x]):
        # Apply softplus to make the input non-negative.
        shape = x.shape.as_list()
        height = shape[1]
        if rectify:
            x = tf.nn.softplus(x)
        weighted = x * equirectangular_area_weights(height)
        # Return shape [BATCH, HEIGHT, WIDTH, CHANNELS].
        return tf.div_no_nan(
            x, tf.reduce_sum(weighted, axis=[1, 2], keepdims=True))
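
The normalization in Example 5 divides by the area-weighted sum, so the returned map integrates to one under the same area weights, and tf.div_no_nan covers an all-zero input. The sketch below checks that property with a hypothetical stand-in for equirectangular_area_weights (a sin-of-colatitude weighting; the real helper is not shown in the source), assuming TensorFlow 1.x:

import numpy as np
import tensorflow as tf  # assumes TensorFlow 1.x

def fake_area_weights(height):
    # Hypothetical stand-in: sin(colatitude) weights, shaped [1, HEIGHT, 1, 1].
    theta = (np.arange(height) + 0.5) * np.pi / height
    return tf.constant(np.sin(theta), tf.float32)[
        tf.newaxis, :, tf.newaxis, tf.newaxis]

x = tf.nn.softplus(tf.random.normal([1, 8, 16, 1]))
weights = fake_area_weights(8)
normalized = tf.div_no_nan(
    x, tf.reduce_sum(x * weights, axis=[1, 2], keepdims=True))

with tf.Session() as sess:
    # The area-weighted sum of the normalized map is ~1.0.
    print(sess.run(tf.reduce_sum(normalized * weights, axis=[1, 2])))
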
Example 6
def model_fn(features, labels, mode, params):
  """Model function."""
  del labels

  # ==============================
  # Input features
  # ==============================
  # [batch_size, query_seq_len]
  query_inputs = features["query_inputs"]

  # [batch_size, num_candidates, candidate_seq_len]
  candidate_inputs = features["candidate_inputs"]

  # [batch_size, num_candidates, query_seq_len + candidate_seq_len]
  joint_inputs = features["joint_inputs"]

  # [batch_size, num_masks]
  mlm_targets = features["mlm_targets"]
  mlm_positions = features["mlm_positions"]
  mlm_mask = features["mlm_mask"]

  # ==============================
  # Create modules.
  # ==============================
  bert_module = hub.Module(
      spec=params["bert_hub_module_handle"],
      name="locbert",
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)
  hub.register_module_for_export(bert_module, "locbert")

  embedder_module = hub.Module(
      spec=params["embedder_hub_module_handle"],
      name="embedder",
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)
  hub.register_module_for_export(embedder_module, "embedder")

  if params["share_embedders"]:
    query_embedder_module = embedder_module
  else:
    query_embedder_module = hub.Module(
        spec=params["embedder_hub_module_handle"],
        name="embedder",
        tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
        trainable=True)
    hub.register_module_for_export(embedder_module, "query_embedder")

  # ==============================
  # Retrieve.
  # ==============================
  # [batch_size, projected_size]
  query_emb = query_embedder_module(
      inputs=dict(
          input_ids=query_inputs.token_ids,
          input_mask=query_inputs.mask,
          segment_ids=query_inputs.segment_ids),
      signature="projected")

  # [batch_size * num_candidates, candidate_seq_len]
  flat_candidate_inputs, unflatten = flatten_bert_inputs(
      candidate_inputs)

  # [batch_size * num_candidates, projected_size]
  flat_candidate_emb = embedder_module(
      inputs=dict(
          input_ids=flat_candidate_inputs.token_ids,
          input_mask=flat_candidate_inputs.mask,
          segment_ids=flat_candidate_inputs.segment_ids),
      signature="projected")

  # [batch_size, num_candidates, projected_size]
  unflattened_candidate_emb = unflatten(flat_candidate_emb)

  # [batch_size, num_candidates]
  retrieval_score = tf.einsum("BD,BND->BN", query_emb,
                              unflattened_candidate_emb)

  # ==============================
  # Read.
  # ==============================
  # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
  flat_joint_inputs, unflatten = flatten_bert_inputs(joint_inputs)

  # [batch_size * num_candidates, num_masks]
  flat_mlm_positions, _ = tensor_utils.flatten(
      tf.tile(
          tf.expand_dims(mlm_positions, 1), [1, params["num_candidates"], 1]))

  batch_size, num_masks = tensor_utils.shape(mlm_targets)

  # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
  flat_joint_bert_outputs = bert_module(
      inputs=dict(
          input_ids=flat_joint_inputs.token_ids,
          input_mask=flat_joint_inputs.mask,
          segment_ids=flat_joint_inputs.segment_ids,
          mlm_positions=flat_mlm_positions),
      signature="mlm",
      as_dict=True)

  # [batch_size, num_candidates]
  candidate_score = retrieval_score

  # [batch_size, num_candidates]
  candidate_log_probs = tf.math.log_softmax(candidate_score)

  # ==============================
  # Compute marginal log-likelihood.
  # ==============================
  # [batch_size * num_candidates, num_masks]
  flat_mlm_logits = flat_joint_bert_outputs["mlm_logits"]

  # [batch_size, num_candidates, num_masks, vocab_size]
  mlm_logits = tf.reshape(
      flat_mlm_logits, [batch_size, params["num_candidates"], num_masks, -1])
  mlm_log_probs = tf.math.log_softmax(mlm_logits)

  # [batch_size, num_candidates, num_masks]
  tiled_mlm_targets = tf.tile(
      tf.expand_dims(mlm_targets, 1), [1, params["num_candidates"], 1])

  # [batch_size, num_candidates, num_masks, 1]
  tiled_mlm_targets = tf.expand_dims(tiled_mlm_targets, -1)

  # [batch_size, num_candidates, num_masks, 1]
  gold_log_probs = tf.batch_gather(mlm_log_probs, tiled_mlm_targets)

  # [batch_size, num_candidates, num_masks]
  gold_log_probs = tf.squeeze(gold_log_probs, -1)

  # [batch_size, num_candidates, num_masks]
  joint_gold_log_probs = (
      tf.expand_dims(candidate_log_probs, -1) + gold_log_probs)

  # [batch_size, num_masks]
  marginal_gold_log_probs = tf.reduce_logsumexp(joint_gold_log_probs, 1)

  # [batch_size, num_masks]
  float_mlm_mask = tf.cast(mlm_mask, tf.float32)

  # []
  loss = -tf.div_no_nan(
      tf.reduce_sum(marginal_gold_log_probs * float_mlm_mask),
      tf.reduce_sum(float_mlm_mask))

  # ==============================
  # Optimization
  # ==============================
  num_warmup_steps = min(10000, max(100, int(params["num_train_steps"] / 10)))
  train_op = optimization.create_optimizer(
      loss=loss,
      init_lr=params["learning_rate"],
      num_train_steps=params["num_train_steps"],
      num_warmup_steps=num_warmup_steps,
      use_tpu=params["use_tpu"])

  # ==============================
  # Evaluation
  # ==============================
  eval_metric_ops = None if params["use_tpu"] else dict()
  if mode != tf.estimator.ModeKeys.PREDICT:
    # [batch_size, num_masks]
    retrieval_utility = marginal_gold_log_probs - gold_log_probs[:, 0]
    retrieval_utility *= tf.cast(features["mlm_mask"], tf.float32)

    # []
    retrieval_utility = tf.div_no_nan(
        tf.reduce_sum(retrieval_utility), tf.reduce_sum(float_mlm_mask))
    add_mean_metric("retrieval_utility", retrieval_utility, eval_metric_ops)

    has_timestamp = tf.cast(
        tf.greater(features["export_timestamp"], 0), tf.float64)
    off_policy_delay_secs = (
        tf.timestamp() - tf.cast(features["export_timestamp"], tf.float64))
    off_policy_delay_mins = off_policy_delay_secs / 60.0
    off_policy_delay_mins *= tf.cast(has_timestamp, tf.float64)

    add_mean_metric("off_policy_delay_mins", off_policy_delay_mins,
                    eval_metric_ops)

  # Create empty predictions to avoid errors when running in prediction mode.
  predictions = dict()

  if params["use_tpu"]:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions)
  else:
    if eval_metric_ops is not None:
      # Make sure the eval metrics are updated during training so that we get
      # quick feedback from tensorboard summaries when debugging locally.
      with tf.control_dependencies([u for _, u in eval_metric_ops.values()]):
        loss = tf.identity(loss)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        predictions=predictions)
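
The loss in Example 6 is a masked mean of the marginal gold log-probabilities; tf.div_no_nan prevents a NaN when a batch contains no unmasked MLM positions. A toy sketch of the masked-mean pattern, assuming TensorFlow 1.x and made-up values:

import tensorflow as tf  # assumes TensorFlow 1.x

marginal_gold_log_probs = tf.constant([[-0.5, -1.0, -2.0],
                                       [-0.2, -3.0, -4.0]])
mlm_mask = tf.constant([[1.0, 1.0, 0.0],   # last position is padding
                        [1.0, 0.0, 0.0]])

loss = -tf.div_no_nan(
    tf.reduce_sum(marginal_gold_log_probs * mlm_mask),
    tf.reduce_sum(mlm_mask))

with tf.Session() as sess:
    print(sess.run(loss))  # (0.5 + 1.0 + 0.2) / 3 ~= 0.567
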
Example 7
                def metric_fn(per_example_loss, label_ids, logits,
                              is_real_example):
                    """Compute Matthew's correlations for COLA."""
                    predictions = tf.argmax(logits,
                                            axis=-1,
                                            output_type=tf.int32)
                    # https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
                    tp, tp_op = tf.metrics.true_positives(
                        labels=label_ids,
                        predictions=predictions,
                        weights=is_real_example)
                    tn, tn_op = tf.metrics.true_negatives(
                        labels=label_ids,
                        predictions=predictions,
                        weights=is_real_example)
                    fp, fp_op = tf.metrics.false_positives(
                        labels=label_ids,
                        predictions=predictions,
                        weights=is_real_example)
                    fn, fn_op = tf.metrics.false_negatives(
                        labels=label_ids,
                        predictions=predictions,
                        weights=is_real_example)

                    # Compute precision, recall, and F1 score.
                    # Added for BioAlbert.
                    precision = tf_metrics.precision(label_ids,
                                                     predictions,
                                                     num_labels, [1, 2],
                                                     average="micro")
                    recall = tf_metrics.recall(label_ids,
                                               predictions,
                                               num_labels, [1, 2],
                                               average="micro")
                    f1 = tf_metrics.f1(label_ids,
                                       predictions,
                                       num_labels, [1, 2],
                                       average="micro")

                    # Compute Matthew's correlation
                    mcc = tf.div_no_nan(
                        tp * tn - fp * fn,
                        tf.pow((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn),
                               0.5))

                    # Compute accuracy
                    accuracy = tf.metrics.accuracy(labels=label_ids,
                                                   predictions=predictions,
                                                   weights=is_real_example)

                    loss = tf.metrics.mean(values=per_example_loss,
                                           weights=is_real_example)

                    return {
                        "matthew_corr":
                        (mcc, tf.group(tp_op, tn_op, fp_op, fn_op)),
                        "accuracy": accuracy,
                        "eval_loss": loss,
                        "precision": precision,
                        "recall": recall,
                        "f1_score": f1,
                    }
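
Each value in the dict returned by metric_fn is a (value_tensor, update_op) pair, which is the contract of the Estimator's eval_metric_ops; grouping the four count update ops lets one evaluation step refresh every counter behind the MCC. A small sketch of that streaming-metric pattern outside an Estimator, assuming TensorFlow 1.x (note the local-variables initializer that Estimators normally run for you):

import tensorflow as tf  # assumes TensorFlow 1.x

labels = tf.constant([1, 0, 1, 1])
predictions = tf.constant([1, 0, 0, 1])

tp, tp_op = tf.metrics.true_positives(labels=labels, predictions=predictions)
fn, fn_op = tf.metrics.false_negatives(labels=labels, predictions=predictions)
update_op = tf.group(tp_op, fn_op)  # one op that refreshes both counters

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # streaming metrics use locals
    sess.run(update_op)
    print(sess.run([tp, fn]))  # [2.0, 1.0]
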
Example 8
def model_fn(features, labels, mode, params, vocab):
    """Model function that satisfies the Estimator API.

  Args:
    features: Dictionary of model input tensors.
    labels: Unused.
    mode: A tf.estimator.ModeKeys value.
    params: Dictionary of model parameters.
    vocab: A utils.text_utils.Vocab instance.

  Returns:
    spec: A tf.estimator.TPUEstimatorSpec.
  """
    del labels

    # ----------------------------------------------------------------------------
    # INITIALIZATION.
    # ----------------------------------------------------------------------------

    # Update model config from the pre-trained checkpoint.
    model = transformer_utils.TransformerModel(
        config=transformer_utils.TransformerConfig.from_dict(params),
        is_training=(mode == tf_estimator.ModeKeys.TRAIN))

    # Initialize QA model.
    rc_model = hub.Module(params["rc_model"])

    # image_features: [batch_size, num_regions, feature_size]
    # image_positions: [batch_size, num_regions]
    # image_mask: [batch_size, num_regions]
    image_features = features["object_features"].features
    image_positions = features["object_features"].positions
    image_mask = features["object_features"].mask

    # Expand mask by 1 to account for the leading [IMG] token.
    # [batch_size, num_regions + 1]
    batch_size = tensor_utils.shape(image_mask, 0)
    input_mask = tf.pad(image_mask, [[0, 0], [1, 0]], constant_values=1)

    # Encode the image and store the cached transformer values.
    # [batch_size, num_regions + 1, num_layers, num_heads, head_size]
    _, input_cache = model.compute_image_transformer(
        input_ids=tf.fill([batch_size, 1], vocab.t2i(vocab.IMG)),
        input_image=image_features,
        input_image_mask=input_mask,
        input_positions=image_positions)

    # ----------------------------------------------------------------------------
    # TRAINING
    # ----------------------------------------------------------------------------

    if mode == tf_estimator.ModeKeys.TRAIN:
        # MIXER-style training objective consists of two parts:
        #   1) Policy gradient on rewarded rollouts.
        #   2) MLE regularization on references.
        # The full loss is L_total = L_pg + L_mle.

        # Step 1: Policy gradient.
        # Compute and score policy rollouts (multiple per image).
        rollouts = reward_utils.compute_rollouts(model=model,
                                                 rc_model=rc_model,
                                                 features=features,
                                                 encoder_cache=input_cache,
                                                 encoder_cache_mask=input_mask,
                                                 vocab=vocab,
                                                 params=params)

        # Using a self-critical baseline, R'(y) = R(y) - b, where b = R(y*) and
        # y* = argmax p(y|x), sample a single rollout with non-zero reward.
        rollout, reward = reward_utils.sample_from_rollouts(
            rollouts=rollouts,
            baseline=rollouts.rewards[params["reward"]][:, 0],
            reward_type=params["reward"])

        # Compute the probability of the rollout (differentiable for backprop).
        # [batch_size, decode_length, input_length + decode_length]
        rollout_attention_mask = transformer_utils.compute_attention_mask(
            token_mask=rollout.mask[:, :-1], input_mask=input_mask)

        # [batch_size, decode_length, vocab_size]
        rollout_emb, _ = model.compute_transformer(
            input_ids=rollout.token_ids[:, :-1],
            input_segment_id=rollout.segment_ids[:, :-1],
            input_positions=rollout.positions[:, :-1],
            attention_mask=rollout_attention_mask,
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE)

        # [batch_size, decode_length, vocab_size]
        rollout_logits = model.compute_logits(rollout_emb, reuse=tf.AUTO_REUSE)

        # Compute the RL loss, -R(y) * log p(y|x)
        # Some elements in this batch are MLE-only; mask them out of the loss.
        rollout_mask = tf.cast(rollout.mask[:, 1:], tf.float32)
        pg_mask = tf.equal(features["input_type"], datasets.DatasetTypes.VQA)
        rollout_mask *= tf.expand_dims(tf.cast(pg_mask, tf.float32), 1)
        rl_loss = tf.losses.sparse_softmax_cross_entropy(
            labels=rollout.token_ids[:, 1:],
            logits=rollout_logits,
            weights=tf.expand_dims(reward, 1) * rollout_mask,
            reduction=tf.losses.Reduction.SUM)
        rl_loss = tf.math.divide_no_nan(rl_loss, tf.reduce_sum(rollout_mask))

        # Step 2: MLE on references.
        # [batch_size, decode_length, input_length + decode_length]
        reference_attention_mask = transformer_utils.compute_attention_mask(
            token_mask=features["token_inputs"].mask, input_mask=input_mask)

        # [batch_size, decode_length, hidden_size]
        target_emb, _ = model.compute_transformer(
            input_ids=features["token_inputs"].token_ids,
            input_segment_id=features["token_inputs"].segment_ids,
            input_positions=features["token_inputs"].positions,
            attention_mask=reference_attention_mask,
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE)

        # [batch_size, decode_length, vocab_size]
        target_logits = model.compute_logits(target_emb, reuse=tf.AUTO_REUSE)

        # Compute the MLE objective (cross-entropy loss).
        weights = features["token_outputs"].mask
        ref_mask = tf.equal(features["input_type"],
                            datasets.DatasetTypes.REFERENCE)
        weights *= tf.expand_dims(tf.cast(ref_mask, tf.int32), 1)
        reference_loss = tf.losses.sparse_softmax_cross_entropy(
            labels=features["token_outputs"].token_ids,
            logits=target_logits,
            weights=weights)

        # Add both losses together.
        loss = rl_loss + reference_loss

        # BERT-style optimization with linear warmup.
        train_op = optimization.create_optimizer(
            loss=loss,
            init_lr=params["learning_rate"],
            num_train_steps=params["num_train_steps"],
            num_warmup_steps=params["num_warmup_steps"],
            use_tpu=params.get("use_tpu"))

        # Book-keeping.
        summaries = tpu_summaries.TpuSummaries(params["model_dir"])
        summaries.scalar("loss", loss)

        # Check what percentage of examples have non-zero reward.
        total_vqa = tf.reduce_sum(tf.cast(pg_mask, tf.float32))
        nonzero = tf.cast(tf.not_equal(reward, 0), tf.float32)
        nonzero *= tf.cast(pg_mask, tf.float32)
        total_nonzero = tf.reduce_sum(nonzero)
        summaries.scalar("density", tf.div_no_nan(total_nonzero, total_vqa))

        # Total (non-normalized) reward.
        reward = rollouts.rewards[params["reward"]][:, 0]
        reward *= tf.cast(pg_mask, tf.float32)
        total_reward = tf.reduce_sum(reward)
        summaries.scalar("reward", tf.div_no_nan(total_reward, total_vqa))
        host_call = summaries.get_host_call()
    else:
        loss = None
        train_op = None
        host_call = None

    # ----------------------------------------------------------------------------
    # TESTING.
    # ----------------------------------------------------------------------------

    if mode == tf_estimator.ModeKeys.PREDICT:
        decode_output = transformer_utils.beam_search_decode(
            model=model,
            encoder_cache=input_cache,
            encoder_cache_mask=input_mask,
            start_id=vocab.t2i(vocab.CLS),
            stop_id=vocab.t2i(vocab.SEP),
            segment_id=0,
            num_steps=params["decode_length"],
            beam_size=params["beam_size"],
            alpha=params["beam_length_penalty"],
            reuse=tf.AUTO_REUSE)
        predictions = dict(image_id=features.get("image_id", -1),
                           question_id=features.get("question_id", -1),
                           token_ids=decode_output.token_ids[:, :, 1:])
    else:
        predictions = None

    # ----------------------------------------------------------------------------
    # WARM-START.
    # ----------------------------------------------------------------------------

    # Initialize from pretrained model.
    def scaffold_fn():
        """Init op run on host."""
        checkpoint = params["base_model"]
        if params["warm_start_path"]:
            checkpoint = params["warm_start_path"]
        if checkpoint:
            checkpoint_utils.init_from_checkpoint(checkpoint)
        return tf.train.Scaffold()

    return tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
    )
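
The training branch above relies on reward_utils for sampling and scoring rollouts; conceptually, the self-critical baseline subtracts the greedy rollout's reward from each sampled rollout's reward, so only rollouts that beat the greedy decode get a positive learning signal. A toy sketch of that centering step, assuming TensorFlow 1.x (the reward values are made up; in the source this happens inside reward_utils.sample_from_rollouts):

import tensorflow as tf  # assumes TensorFlow 1.x

# Rewards for three rollouts per example; column 0 is the greedy rollout,
# which serves as the self-critical baseline b = R(y*).
rollout_rewards = tf.constant([[0.6, 1.0, 0.2],
                               [0.3, 0.3, 0.9]])
baseline = rollout_rewards[:, 0]
centered = rollout_rewards - tf.expand_dims(baseline, 1)

with tf.Session() as sess:
    print(sess.run(centered))  # positive only where a rollout beats the greedy one
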
Example 9
def _stochastic_prob(prob_mean, prob_std):
    """Sample a probability by perturbing prob_mean in logit space."""
    prob_mean_mod = tf.math.log(tf.div_no_nan(prob_mean, 1 - prob_mean))
    # sample_random_normal = tf.random.normal([], prob_mean_mod, prob_std, seed=1)
    sample_random_normal = tf.random.normal([], prob_mean_mod, prob_std)
    prob = tf.divide(1, 1 + tf.math.exp(-sample_random_normal))
    return prob
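
Example 9 perturbs a mean probability in logit space and maps the noisy sample back to (0, 1). A hedged usage sketch, assuming TensorFlow 1.x; tf.sigmoid is used here as an equivalent of the manual 1 / (1 + exp(-x)):

import tensorflow as tf  # assumes TensorFlow 1.x

p = tf.constant(0.8)
logit = tf.math.log(tf.div_no_nan(p, 1 - p))    # log-odds of 0.8 ~= 1.386
noisy_logit = tf.random.normal([], logit, 0.5)  # Gaussian jitter in logit space
noisy_p = tf.sigmoid(noisy_logit)               # back to a probability

with tf.Session() as sess:
    print(sess.run(noisy_p))  # a random probability centered near 0.8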