Exemplo n.º 1
0
def aggregate_task_lm_losses(hparams, problem_hparams, logits, feature_name,
                             feature):
    """LM loss for multiproblems.

    Runs the feature's loss function once per task in
    ``hparams.problem.task_list``. Each run uses
    ``common_layers.weights_multi_problem_all`` for that task's id as the
    weights function. The raw numerators and denominators are accumulated
    across tasks.

    Args:
      hparams: model hparams; reads ``loss``, ``problem.task_list`` and, if
        present, ``vocab_divisor``.
      problem_hparams: problem hparams; reads ``vocab_size`` and ``modality``
        for ``feature_name``.
      logits: model output passed to the loss function.
      feature_name: name of the feature the loss is computed for.
      feature: target values for ``feature_name``.

    Returns:
      A tuple ``(loss_num, loss_den, summaries)``:
        - ``loss_num`` and ``loss_den`` are the summed numerators and
          denominators over all tasks.
        - ``summaries`` is a list of ``[task.name + "_loss", value]`` pairs,
          one per task.
    """
    summaries = []
    vocab_size = problem_hparams.vocab_size[feature_name]
    if vocab_size is not None and hasattr(hparams, "vocab_divisor"):
        # Round the vocabulary size up to a multiple of vocab_divisor.
        vocab_size += (-vocab_size) % hparams.vocab_divisor
    modality = problem_hparams.modality[feature_name]
    loss = hparams.loss.get(feature_name, modalities.get_loss(modality))
    loss_num = 0.
    loss_den = 0.
    for task in hparams.problem.task_list:
        # Loss functions take (logits, targets, hparams, vocab_size,
        # weights_fn). The per-task mask is the weights function here.
        # task_id is bound as a default argument so the closure does not
        # depend on the loop variable's later value.
        loss_num_, loss_den_ = loss(
            logits,
            feature,
            hparams,
            vocab_size,
            lambda x, task_id=task.task_id: (
                common_layers.weights_multi_problem_all(x, task_id)))

        loss_num += loss_num_
        loss_den += loss_den_

        # Per-task normalized loss, only for the summary; the returned
        # numerator/denominator stay unnormalized.
        loss_val = loss_num_ / tf.maximum(1.0, loss_den_)
        summaries.append([task.name + "_loss", loss_val])

    return loss_num, loss_den, summaries
Exemplo n.º 2
0
def create_eager_metrics_for_problem(problem, model_hparams):
    """Build eager-mode metric accumulators for a single problem.

    See create_eager_metrics. The weights function comes from
    ``model_hparams.targets_weights_fn["targets"]`` when that key is set.
    Otherwise the default for the problem's "targets" modality is used.
    """
    problem_metric_fns = problem.eval_metric_fns(model_hparams)
    targets_modality = problem.get_hparams(model_hparams).modality["targets"]
    default_weights_fn = modalities.get_targets_weights_fn(targets_modality)
    chosen_weights_fn = model_hparams.targets_weights_fn.get(
        "targets", default_weights_fn)
    return create_eager_metrics_internal(problem_metric_fns,
                                         weights_fn=chosen_weights_fn)
Exemplo n.º 3
0
    def get_extra_internal_loss(self, extra_raw_gts, extra_gts, extra_pds):
        """Hacky code to get the loss on predicted frames from input frames.

        Recurrent models consume the frames one-by-one. Therefore, if there is
        more than one input frame, those frames also get predicted. T2T only
        calculates loss on the predicted target frames, which means no loss is
        applied to the predicted input frames. This code fixes that.

        Since the model is not aware of the modality, it has to match the
        pre-processing happening in the bottom function, which makes this
        code very hacky. It must match the bottom, top and loss of the
        modalities, otherwise it will calculate the wrong loss.

        Args:
          extra_raw_gts: extra raw ground truth frames.
          extra_gts: extra normalized ground truth frames.
          extra_pds: extra predicted frames.

        Returns:
          Additional reconstruction loss, also logged as the "recon_extra"
          scalar summary.

        Raises:
          ValueError: in case of unknown modality.
        """
        if self._target_modality == modalities.ModalityType.VIDEO_L2_RAW:
            recon_loss = tf.losses.mean_squared_error(extra_gts, extra_pds)
        elif self._target_modality == modalities.ModalityType.VIDEO:
            shape = common_layers.shape_list(extra_pds)
            # Split the last axis into [3, 256]. Presumably this is 3 color
            # channels x 256 intensity classes, which must match the VIDEO
            # modality's top -- confirm.
            updated_shape = shape[:-1] + [3, 256]
            extra_pds = tf.reshape(extra_pds, updated_shape)
            # Merge time and batch
            logits = tf.reshape(extra_pds, [-1] + updated_shape[2:])
            targets = extra_raw_gts
            targets_shape = common_layers.shape_list(targets)
            targets = tf.reshape(targets, [-1] + targets_shape[2:])
            modality = self.hparams.problem_hparams.modality["targets"]
            # The looked-up value is itself the weights function and is handed
            # to padded_cross_entropy as `weights_fn=`. It must not be invoked
            # here (previously it was called on self.hparams).
            targets_weights_fn = self.hparams.targets_weights_fn.get(
                "targets", modalities.get_targets_weights_fn(modality))
            numerator, denominator = common_layers.padded_cross_entropy(
                logits,
                targets,
                self.hparams.label_smoothing,
                cutoff=getattr(self.hparams, "video_modality_loss_cutoff",
                               0.01),
                weights_fn=targets_weights_fn)
            recon_loss = numerator / denominator
        else:
            raise ValueError(
                "internal loss only supports specific modalities.")
        tf.summary.scalar("recon_extra", recon_loss)
        return recon_loss
Exemplo n.º 4
0
 def testSymbolModalityTargetsFactored(self):
   """Factored SYMBOL logits gain a vocab axis; the train loss is a scalar."""
   batch_size, num_datashards, length = 10, 5, 6
   height, hidden_size, vocab_size = 7, 9, 11

   hparams = common_hparams.basic_params1()
   hparams.factored_logits = True
   hparams.hidden_size = hidden_size
   hparams.mode = tf.estimator.ModeKeys.TRAIN

   # Random integer activations and label ids; only shapes are asserted.
   body_output = np.random.randint(
       100, size=(batch_size, length, height, hidden_size))
   targets = np.random.randint(
       vocab_size, size=(batch_size, length, height, 1))
   parallelism = expert_utils.Parallelism(
       ["/device:CPU:0"] * num_datashards)

   with self.test_session() as sess:
     body_shards = tf.split(tf.to_float(body_output), num_datashards)
     target_shards = tf.split(targets, num_datashards)
     logit_shards = parallelism(
         modalities.get_top(modalities.ModalityType.SYMBOL),
         body_shards, target_shards, hparams, vocab_size)
     # NOTE(review): this mixes the function-based get_top above with the
     # class attribute SymbolModality.loss -- confirm both refer to the same
     # modality implementation.
     loss_num_shards, loss_den_shards = parallelism(
         modalities.SymbolModality.loss,
         logit_shards, target_shards, hparams, vocab_size,
         modalities.get_targets_weights_fn(modalities.ModalityType.SYMBOL))
     total_num = tf.add_n(loss_num_shards)
     total_den = tf.add_n(loss_den_shards)
     train_loss = total_num / tf.maximum(1.0, total_den)
     all_logits = tf.concat(logit_shards, 0)
     sess.run(tf.global_variables_initializer())
     logits_value, loss_value = sess.run((all_logits, train_loss))

   self.assertEqual(
       logits_value.shape, (batch_size, length, height, 1, vocab_size))
   self.assertEqual(loss_value.shape, ())
Exemplo n.º 5
0
def create_evaluation_metrics(problems, model_hparams):
    """Creates the evaluation metrics for the model.

    Args:
      problems: List of Problem instances.
      model_hparams: a set of hparams.

    Returns:
      dict<metric name, metric function>. The metric functions have signature
      (Tensor predictions, features, labels) -> (metric Tensor, update op),
      where features is a dict with keys {targets}. Metric names have the form
      "metrics-<problem name or overload name>/<target name>/<metric>".

    Raises:
      ValueError: if the metrics specified by a problem are not recognized
        (i.e. are not defined in the Metrics enum). NOTE(review): no such
        check is made in this function itself; presumably it happens inside
        the problems' `eval_metric_fns` -- verify.
    """
    def reduce_dimensions(predictions, labels):
        """Reduce dimensions for high-dimensional predictions and labels."""
        # We will treat first dimensions as batch. One example are video frames.
        if len(predictions.get_shape()) > 5:
            predictions_shape = common_layers.shape_list(predictions)
            predictions = tf.reshape(predictions, [
                predictions_shape[0], predictions_shape[1], -1,
                predictions_shape[-1]
            ])
            labels_shape = common_layers.shape_list(labels)
            labels = tf.reshape(labels, [labels_shape[0], labels_shape[1], -1])
        return predictions, labels

    def accepts_features(metric_fn):
        """Returns True if `metric_fn` takes a `features` arg or **kwargs."""
        # inspect.getargspec was removed in Python 3.11. It also raised on
        # keyword-only args and annotations before that. Prefer
        # getfullargspec, and fall back to getargspec only on Python 2.
        getfullargspec = getattr(inspect, "getfullargspec", None)
        if getfullargspec is not None:
            spec = getfullargspec(metric_fn)
            return ("features" in spec.args or "features" in spec.kwonlyargs
                    or spec.varkw is not None)
        spec = inspect.getargspec(metric_fn)  # pylint: disable=deprecated-method
        return "features" in spec.args or spec.keywords is not None

    def make_problem_specific_metric_fn(metric_fn, weights_fn):
        """Create a metric fn averaging `metric_fn` scores with tf.metrics.mean."""
        # The signature does not change, so inspect it once here rather than
        # on every invocation of the wrapper.
        pass_features = accepts_features(metric_fn)

        def problem_metric_fn(predictions, features, labels):
            """Metric fn."""
            # Send along the entire features dict if the metric fn has the kwarg
            # "features".
            kwargs = {"features": features} if pass_features else {}

            predictions, labels = reduce_dimensions(predictions, labels)

            scores, weights = metric_fn(predictions,
                                        labels,
                                        weights_fn=weights_fn,
                                        **kwargs)
            return tf.metrics.mean(scores, weights)

        return problem_metric_fn

    def make_image_wrapped_metric_fn(metric_fn):
        """Metric fn without tf.metrics.mean."""
        def image_wrapped_metric_fn(predictions,
                                    features,
                                    labels,
                                    weights_fn=common_layers.weights_all):
            del weights_fn
            del features
            predictions, labels = reduce_dimensions(predictions, labels)
            return metric_fn(predictions, labels, model_hparams)

        return image_wrapped_metric_fn

    def weights_fn_for_mp(problem_task_id):
        """Weights fn selecting only positions of the given multiproblem task."""
        return lambda x: common_layers.weights_multi_problem(
            x, problem_task_id)

    overload_eval_metric_name = getattr(
        model_hparams, "overload_eval_metric_name", None)
    eval_metrics = dict()
    for problem_instance in problems:
        problem_name = problem_instance.name
        if problem_instance.was_reversed:
            problem_name += "_rev"

        multi_problem = hasattr(model_hparams.problem, "task_list")
        # Multiproblems share the metrics of the top-level multiproblem.
        if multi_problem:
            metrics = model_hparams.problem.eval_metric_fns(model_hparams)
        else:
            metrics = problem_instance.eval_metric_fns(model_hparams)

        if len(problems) == 1 and overload_eval_metric_name:
            metric_prefix = overload_eval_metric_name
        else:
            metric_prefix = problem_name

        tm = problem_instance.get_hparams(model_hparams).modality["targets"]
        if not isinstance(tm, dict):
            tm = {"targets": tm}

        for target_name, modality in six.iteritems(tm):
            if multi_problem:
                weights_fn = weights_fn_for_mp(problem_instance.task_id)
            else:
                # Prefer an override for this specific target, then the
                # generic "targets" override, then the modality default.
                # For single-target problems target_name == "targets".
                overrides = model_hparams.targets_weights_fn
                default_weights_fn = modalities.get_targets_weights_fn(
                    modality)
                weights_fn = overrides.get(
                    target_name, overrides.get("targets", default_weights_fn))

            for metric, metric_fn in six.iteritems(metrics):
                metric_name = "metrics-%s/%s/%s" % (metric_prefix,
                                                    target_name, metric)
                if metric == Metrics.IMAGE_SUMMARY:
                    eval_metrics[metric_name] = make_image_wrapped_metric_fn(
                        metric_fn)
                else:
                    eval_metrics[
                        metric_name] = make_problem_specific_metric_fn(
                            metric_fn, weights_fn)

    return eval_metrics
Exemplo n.º 6
0
def aggregate_task_losses(hparams, problem_hparams, logits, feature_name,
                          feature):
    """Multiproblem loss function.

    If ``hparams.multiproblem_reweight_label_loss`` is false, this defers
    entirely to ``aggregate_task_lm_losses``. Otherwise:
      - The primary task (first entry of ``hparams.problem.task_list``)
        contributes its normalized loss.
      - Every other task contributes a mix of two losses:
          * its input-sequence (LM) loss, weighted by
            ``1 - multiproblem_label_weight``;
          * its label loss (tasks with ``num_classes``) or target-sequence
            loss (other tasks), weighted by ``multiproblem_label_weight``.
        Both losses are multiplied by ``problem_hparams.loss_multiplier``.

    Args:
      hparams: model hparams (loss config, task list, reweighting options).
      problem_hparams: problem hparams; reads ``vocab_size``, ``modality``
        and ``loss_multiplier``.
      logits: model output passed to the loss function.
      feature_name: name of the feature the loss is computed for.
      feature: target values for ``feature_name``.

    Returns:
      A tuple ``(loss_num, loss_den, summaries)``. ``summaries`` is a list of
      ``[name, value]`` pairs.
    """

    # If no reweighting, we want the default loss to mimic the LM loss.
    if not hparams.multiproblem_reweight_label_loss:
        return aggregate_task_lm_losses(hparams=hparams,
                                        problem_hparams=problem_hparams,
                                        logits=logits,
                                        feature_name=feature_name,
                                        feature=feature)

    summaries = []
    task_list = hparams.problem.task_list
    main_task_id = task_list[0].task_id
    vocab_size = problem_hparams.vocab_size[feature_name]
    if vocab_size is not None and hasattr(hparams, "vocab_divisor"):
        # Round the vocabulary size up to a multiple of vocab_divisor.
        vocab_size += (-vocab_size) % hparams.vocab_divisor
    modality = problem_hparams.modality[feature_name]
    loss = hparams.loss.get(feature_name, modalities.get_loss(modality))

    def _masked_loss(weights_fn):
        """Runs `loss` over all logits/targets using the given weights_fn."""
        # Loss functions take (logits, targets, hparams, vocab_size,
        # weights_fn); the task mask is passed as the weights function.
        return loss(logits, feature, hparams, vocab_size, weights_fn)

    # Primary task loss
    loss_num, loss_den = _masked_loss(
        lambda x: common_layers.weights_multi_problem_all(x, main_task_id))

    loss_val = loss_num / tf.maximum(1.0, loss_den)
    summaries.append([task_list[0].name + "_loss", loss_val])

    # Since the losses may undergo rescaling, they cannot exist as separate
    # numerators and denominators. Set the denominators to 1 in order to
    # facilitate loss averaging.
    loss_num = loss_val
    loss_den = tf.minimum(tf.convert_to_tensor(1, dtype=tf.float32), loss_den)

    for task in task_list[1:]:
        # task_id is bound as a lambda default so the closures never depend
        # on the loop variable's later value.
        task_id = task.task_id

        # Loss only from the input sequence -- the auxiliary LM loss.
        seq_loss_num, seq_loss_den = _masked_loss(
            lambda x, tid=task_id: common_layers.weights_multi_problem_input(
                x, tid))
        seq_loss_num *= problem_hparams.loss_multiplier

        # Unscaled sequence loss.
        seq_loss = seq_loss_num / tf.maximum(1.0, seq_loss_den)
        summaries.append([task.name + "_seq_loss", seq_loss])

        # Classification tasks get a loss on the label; other tasks get a
        # loss on the target sequence. Both use weights_multi_problem and
        # differ only in the summary name.
        if hasattr(task, "num_classes"):
            aux_summary_suffix = "_label_loss"
        else:
            aux_summary_suffix = "_target_loss"
        aux_loss_num, aux_loss_den = _masked_loss(
            lambda x, tid=task_id: common_layers.weights_multi_problem(x, tid))
        aux_loss_num *= problem_hparams.loss_multiplier

        # Unscaled label / target sequence loss.
        aux_loss = aux_loss_num / tf.maximum(1.0, aux_loss_den)
        summaries.append([task.name + aux_summary_suffix, aux_loss])

        # Scaling. Reweighting is always enabled at this point because of the
        # early return at the top of the function.
        label_weight = hparams.multiproblem_label_weight
        # This is the training loss for the optimizer after all the scaling.
        task_loss_val = (seq_loss * (1 - label_weight) +
                         aux_loss * label_weight)

        summaries.append([task.name + "_loss", task_loss_val])
        # Adding 1 to the loss den for each task leads to averaging task losses.
        # TODO(urvashik): Fix combination with other task losses - weighted
        # average based on the number of examples from that task.
        loss_num += task_loss_val
        loss_den += tf.minimum(tf.convert_to_tensor(1, dtype=tf.float32),
                               aux_loss_den)

    return loss_num, loss_den, summaries