def testSymbolModalityTargets(self):
     """Symbol top + loss across CPU shards: checks logits shape and scalar loss.

     Logits must come back as (batch, length, height, 1, vocab_size) and the
     normalized training loss must be a scalar.
     """
     batch_size, num_datashards = 10, 5
     length, height = 6, 7
     hidden_size, vocab_size = 9, 11
     symbol = modalities.ModalityType.SYMBOL

     model_hparams = common_hparams.basic_params1()
     model_hparams.hidden_size = hidden_size
     model_hparams.mode = tf.estimator.ModeKeys.TRAIN

     body_shape = (batch_size, length, height, hidden_size)
     targets_shape = (batch_size, length, height, 1)
     body_output = np.random.randint(100, size=body_shape)
     targets = np.random.randint(vocab_size, size=targets_shape)

     devices = ["/device:CPU:0"] * num_datashards
     data_parallelism = expert_utils.Parallelism(devices)
     sharded_body_output = tf.split(tf.to_float(body_output), num_datashards)
     sharded_targets = tf.split(targets, num_datashards)

     sharded_logits = data_parallelism(
         modalities.get_top(symbol), sharded_body_output, sharded_targets,
         model_hparams, vocab_size)
     sharded_loss_num, sharded_loss_den = data_parallelism(
         modalities.get_loss(symbol), sharded_logits, sharded_targets,
         model_hparams, vocab_size, modalities.get_weights_fn(symbol))

     # Normalize the summed numerator by the summed denominator (floored at 1).
     total_num = tf.add_n(sharded_loss_num)
     total_den = tf.add_n(sharded_loss_den)
     train_loss = total_num / tf.maximum(1.0, total_den)
     logits = tf.concat(sharded_logits, 0)

     self.evaluate(tf.global_variables_initializer())
     logits_value, loss_value = self.evaluate((logits, train_loss))
     self.assertEqual(logits_value.shape,
                      (batch_size, length, height, 1, vocab_size))
     self.assertEqual(loss_value.shape, ())
# Example #2 (0)
def aggregate_task_lm_losses(hparams, problem_hparams, logits, feature_name,
                             feature):
    """LM loss for multiproblems.

    Sums the loss of every task in ``hparams.problem.task_list``. Each task's
    loss is masked with ``common_layers.weights_multi_problem_all`` for that
    task's id.

    Args:
      hparams: model hparams; must expose ``problem.task_list`` and ``loss``,
        and may expose ``vocab_divisor``.
      problem_hparams: problem hparams with ``vocab_size`` and ``modality``
        dicts keyed by feature name.
      logits: model output for ``feature_name``.
      feature_name: name of the feature being scored.
      feature: ground-truth tensor for ``feature_name``.

    Returns:
      A tuple ``(loss_num, loss_den, summaries)``:
        - ``loss_num`` and ``loss_den`` are the per-task numerators and
          denominators summed over all tasks.
        - ``summaries`` is a list of ``[name, value]`` pairs, one per task,
          holding that task's normalized loss.
    """
    summaries = []
    vocab_size = problem_hparams.vocab_size[feature_name]
    if vocab_size is not None and hasattr(hparams, "vocab_divisor"):
        # Round vocab_size up to a multiple of vocab_divisor.
        vocab_size += (-vocab_size) % hparams.vocab_divisor
    modality = problem_hparams.modality[feature_name]
    loss = hparams.loss.get(feature_name, modalities.get_loss(modality))
    loss_num = 0.
    loss_den = 0.
    for task in hparams.problem.task_list:
        # Bind the task id now rather than closing over the loop variable.
        task_weights_fn = (
            lambda x, task_id=task.task_id:
            common_layers.weights_multi_problem_all(x, task_id))
        # Loss functions take (top_out, targets, model_hparams, vocab_size,
        # weights_fn); the per-task mask is passed as the weights_fn.
        # NOTE(review): hparams.weights_fn overrides are not combined with the
        # task mask here -- confirm that is intended.
        loss_num_, loss_den_ = loss(logits, feature, hparams, vocab_size,
                                    task_weights_fn)

        loss_num += loss_num_
        loss_den += loss_den_

        loss_val = loss_num_ / tf.maximum(1.0, loss_den_)
        summaries.append([task.name + "_loss", loss_val])

    return loss_num, loss_den, summaries
# Example #3 (0)
def create_eager_metrics_for_problem(problem, model_hparams):
    """Builds eager metrics for one problem (see create_eager_metrics).

    The weights function is taken from ``model_hparams.weights_fn["targets"]``
    when present. Otherwise the targets modality's default is used.
    """
    metric_fns = problem.eval_metric_fns(model_hparams)
    target_modality = problem.get_hparams(model_hparams).modality["targets"]
    default_weights_fn = modalities.get_weights_fn(target_modality)
    chosen_weights_fn = model_hparams.weights_fn.get("targets",
                                                     default_weights_fn)
    return create_eager_metrics_internal(metric_fns,
                                         weights_fn=chosen_weights_fn)
# Example #4 (0)
    def get_extra_internal_loss(self, extra_raw_gts, extra_gts, extra_pds):
        """Hacky code to get the loss on predicted frames from input frames.

        Recurrent models consume the frames one-by-one. Therefore if there is
        more than one input frame, those frames also get predicted. T2T only
        calculates loss on the predicted target frames, which means the loss
        is not applied to the predicted input frames. This code fixes that.

        The model is not aware of the modality, so this code has to match the
        pre-processing in the bottom function. That makes it very hacky. It
        must match the bottom, top and loss of the modalities; otherwise it
        will calculate the wrong loss.

        Args:
          extra_raw_gts: extra raw ground truth frames.
          extra_gts: extra normalized ground truth frames.
          extra_pds: extra predicted frames.

        Returns:
          Additional reconstruction loss.

        Raises:
          ValueError: in case of unknown loss transformation.
        """
        # TODO(trandustin): This logic should be moved elsewhere.
        if self.hparams.loss.get("targets") == modalities.video_l2_raw_loss:
            recon_loss = tf.losses.mean_squared_error(extra_gts, extra_pds)
        elif "targets" not in self.hparams.loss:
            shape = common_layers.shape_list(extra_pds)
            # Split the last axis into [3, 256]. Presumably 3 channels x 256
            # intensity logits -- TODO confirm against the modality's top.
            updated_shape = shape[:-1] + [3, 256]
            extra_pds = tf.reshape(extra_pds, updated_shape)
            # Merge time and batch
            logits = tf.reshape(extra_pds, [-1] + updated_shape[2:])
            targets = extra_raw_gts
            targets_shape = common_layers.shape_list(targets)
            targets = tf.reshape(targets, [-1] + targets_shape[2:])
            targets_weights_fn = self.hparams.weights_fn.get(
                "targets", modalities.get_weights_fn(self._target_modality))
            numerator, denominator = common_layers.padded_cross_entropy(
                logits,
                targets,
                self.hparams.label_smoothing,
                cutoff=getattr(self.hparams, "video_modality_loss_cutoff",
                               0.01),
                weights_fn=targets_weights_fn)
            # Floor the denominator at 1 so that an all-zero weight mask
            # (denominator == 0) yields 0 instead of NaN.
            recon_loss = numerator / tf.maximum(1.0, denominator)
        else:
            raise ValueError(
                "internal loss only supports specific hparams.loss.")
        tf.summary.scalar("recon_extra", recon_loss)
        return recon_loss
 def testGetForAllModalities(self):
   """Every modality type resolves a non-None default for each accessor."""
   accessors = (
       ("bottom", modalities.get_bottom),
       ("loss", modalities.get_loss),
       ("name", modalities.get_name),
       ("targets_bottom", modalities.get_targets_bottom),
       ("top", modalities.get_top),
       ("weights_fn", modalities.get_weights_fn),
   )
   for modality in modalities.ModalityType.get_choices():
     # Resolve every accessor for this modality before asserting on any.
     resolved = [(label, getter(modality)) for label, getter in accessors]
     for label, value in resolved:
       self.assertIsNotNone(
           value, msg="{} has no default {}".format(modality, label))
# Example #6 (0)
def aggregate_task_losses(hparams, problem_hparams, logits, feature_name,
                          feature):
    """Multiproblem loss function.

    The first task in ``hparams.problem.task_list`` is the primary task. Its
    loss is masked with ``common_layers.weights_multi_problem_all``.

    Every other task contributes a reweighted mix of two losses:
      - an auxiliary LM loss over its input sequence
        (``weights_multi_problem_input``);
      - a loss masked with ``weights_multi_problem``, reported as the label
        loss for tasks with ``num_classes`` and as the target loss otherwise.

    The mix is::

        task_loss = (1 - w) * seq_loss + w * label_or_target_loss

    where ``w = hparams.multiproblem_label_weight``.

    When ``hparams.multiproblem_reweight_label_loss`` is false, this defers to
    ``aggregate_task_lm_losses``.

    Args:
      hparams: model hparams (``problem.task_list``, ``loss``,
        ``multiproblem_reweight_label_loss``, ``multiproblem_label_weight`` and
        optionally ``vocab_divisor``).
      problem_hparams: problem hparams with ``vocab_size``, ``modality`` and
        ``loss_multiplier``.
      logits: model output for ``feature_name``.
      feature_name: name of the feature being scored.
      feature: ground-truth tensor for ``feature_name``.

    Returns:
      A tuple ``(loss_num, loss_den, summaries)``:
        - each task adds an already-normalized loss to ``loss_num`` and at
          most 1 to ``loss_den``, so ``loss_num / loss_den`` averages the task
          losses;
        - ``summaries`` is a list of ``[name, value]`` pairs.
    """
    # If no reweighting, we want the default loss to mimic the LM loss.
    if not hparams.multiproblem_reweight_label_loss:
        return aggregate_task_lm_losses(hparams=hparams,
                                        problem_hparams=problem_hparams,
                                        logits=logits,
                                        feature_name=feature_name,
                                        feature=feature)

    summaries = []
    main_task = hparams.problem.task_list[0]
    main_task_id = main_task.task_id
    vocab_size = problem_hparams.vocab_size[feature_name]
    if vocab_size is not None and hasattr(hparams, "vocab_divisor"):
        # Round vocab_size up to a multiple of vocab_divisor.
        vocab_size += (-vocab_size) % hparams.vocab_divisor
    modality = problem_hparams.modality[feature_name]
    loss = hparams.loss.get(feature_name, modalities.get_loss(modality))
    one = tf.convert_to_tensor(1, dtype=tf.float32)

    def masked_loss(weights_fn):
        """Returns (num, den) of `loss` using `weights_fn` as the token mask."""
        # Loss functions take (top_out, targets, model_hparams, vocab_size,
        # weights_fn).
        return loss(logits, feature, hparams, vocab_size, weights_fn)

    # Primary task loss
    loss_num, loss_den = masked_loss(
        lambda x: common_layers.weights_multi_problem_all(x, main_task_id))

    loss_val = loss_num / tf.maximum(1.0, loss_den)
    summaries.append([main_task.name + "_loss", loss_val])

    # Since the losses may undergo rescaling, they cannot exist as separate
    # numerators and denominators. Set the denominators to 1 in order to
    # facilitate loss averaging.
    loss_num = loss_val
    loss_den = tf.minimum(one, loss_den)

    for task in hparams.problem.task_list[1:]:
        task_id = task.task_id

        # Loss only from the input sequence -- the auxiliary LM loss.
        # `tid` is bound eagerly so the mask does not depend on the loop var.
        seq_loss_num, seq_loss_den = masked_loss(
            lambda x, tid=task_id: common_layers.weights_multi_problem_input(
                x, tid))
        seq_loss_num *= problem_hparams.loss_multiplier

        # Unscaled sequence loss.
        seq_loss = seq_loss_num / tf.maximum(1.0, seq_loss_den)
        summaries.append([task.name + "_seq_loss", seq_loss])

        # Loss from the classification label (tasks with num_classes) or from
        # the target sequence (all other tasks). Both use the same mask; only
        # the summary name differs.
        if hasattr(task, "num_classes"):
            summary_suffix = "_label_loss"
        else:
            summary_suffix = "_target_loss"
        out_loss_num, out_loss_den = masked_loss(
            lambda x, tid=task_id: common_layers.weights_multi_problem(x, tid))
        out_loss_num *= problem_hparams.loss_multiplier

        # Unscaled label / target loss.
        out_loss = out_loss_num / tf.maximum(1.0, out_loss_den)
        summaries.append([task.name + summary_suffix, out_loss])

        # Scaling. multiproblem_reweight_label_loss is known to be true here
        # (the false case returned early above).
        label_weight = hparams.multiproblem_label_weight
        out_loss *= label_weight
        seq_loss *= (1 - label_weight)

        # This is the training loss for the optimizer after all the scaling.
        task_loss_val = seq_loss + out_loss
        summaries.append([task.name + "_loss", task_loss_val])

        # Adding (at most) 1 to the loss den for each task leads to averaging
        # task losses.
        # TODO(urvashik): Fix combination with other task losses - weighted
        # average based on the number of examples from that task.
        loss_num += task_loss_val
        loss_den += tf.minimum(one, out_loss_den)

    return loss_num, loss_den, summaries
# Example #7 (0)
def create_evaluation_metrics(problems, model_hparams):
    """Creates the evaluation metrics for the model.

    Args:
      problems: List of Problem instances.
      model_hparams: a set of hparams.

    Returns:
      dict<metric name, metric function>.

      Metric names have the form "metrics-<scope>/<target name>/<metric>".
      The scope is ``model_hparams.overload_eval_metric_name`` when there is
      exactly one problem and that attribute is set. Otherwise it is the
      problem name, with "_rev" appended for reversed problems.

      The metric functions have signature
      (predictions, features, labels) -> result of the metric. For
      non-image metrics that result is ``tf.metrics.mean(scores, weights)``,
      i.e. (metric Tensor, update op).
    """
    def reduce_dimensions(predictions, labels):
        """Reduce dimensions for high-dimensional predictions and labels."""
        # We will treat first dimensions as batch. One example are video frames.
        if len(predictions.get_shape()) > 5:
            predictions_shape = common_layers.shape_list(predictions)
            predictions = tf.reshape(predictions, [
                predictions_shape[0], predictions_shape[1], -1,
                predictions_shape[-1]
            ])
            labels_shape = common_layers.shape_list(labels)
            labels = tf.reshape(labels, [labels_shape[0], labels_shape[1], -1])
        return predictions, labels

    def accepts_features(fn):
        """Returns True if `fn` takes a "features" argument or **kwargs."""
        getfullargspec = getattr(inspect, "getfullargspec", None)
        if getfullargspec is None:
            # Python 2 only has getargspec (which was removed in Python 3.11).
            args, _, keywords, _ = inspect.getargspec(fn)
            return ("features" in args) or bool(keywords)
        spec = getfullargspec(fn)
        return ("features" in spec.args or "features" in spec.kwonlyargs
                or spec.varkw is not None)

    def make_problem_specific_metric_fn(metric_fn, weights_fn):
        """Create a metric fn."""
        def problem_metric_fn(predictions, features, labels):
            """Metric fn."""
            # Send along the entire features dict if the metric fn has the kwarg
            # "features".
            kwargs = {}
            if accepts_features(metric_fn):
                kwargs["features"] = features

            predictions, labels = reduce_dimensions(predictions, labels)

            scores, weights = metric_fn(predictions,
                                        labels,
                                        weights_fn=weights_fn,
                                        **kwargs)
            return tf.metrics.mean(scores, weights)

        return problem_metric_fn

    def make_image_wrapped_metric_fn(metric_fn):
        """Metric fn without tf.metrics.mean."""
        def image_wrapped_metric_fn(predictions,
                                    features,
                                    labels,
                                    weights_fn=common_layers.weights_all):
            del weights_fn
            del features
            predictions, labels = reduce_dimensions(predictions, labels)
            return metric_fn(predictions, labels, model_hparams)

        return image_wrapped_metric_fn

    def weights_fn_for_mp(problem_task_id):
        return lambda x: common_layers.weights_multi_problem(
            x, problem_task_id)

    # Loop-invariant: read the optional name override once.
    overload_eval_metric_name = getattr(
        model_hparams, "overload_eval_metric_name", None)

    eval_metrics = {}
    for problem_instance in problems:
        problem_name = problem_instance.name
        if problem_instance.was_reversed:
            problem_name += "_rev"
        metrics = problem_instance.eval_metric_fns(model_hparams)
        is_multi_problem = hasattr(model_hparams.problem, "task_list")
        if is_multi_problem:
            metrics = model_hparams.problem.eval_metric_fns(model_hparams)

        if len(problems) == 1 and overload_eval_metric_name:
            metric_scope = overload_eval_metric_name
        else:
            metric_scope = problem_name

        tm = problem_instance.get_hparams(model_hparams).modality["targets"]
        if not isinstance(tm, dict):
            tm = {"targets": tm}

        for target_name, modality in six.iteritems(tm):
            weights_fn = model_hparams.weights_fn.get(
                "targets", modalities.get_weights_fn(modality))
            if is_multi_problem:
                weights_fn = weights_fn_for_mp(problem_instance.task_id)

            for metric, metric_fn in six.iteritems(metrics):
                metric_name = "metrics-%s/%s/%s" % (metric_scope,
                                                    target_name, metric)
                if metric == Metrics.IMAGE_SUMMARY:
                    eval_metrics[metric_name] = make_image_wrapped_metric_fn(
                        metric_fn)
                else:
                    eval_metrics[
                        metric_name] = make_problem_specific_metric_fn(
                            metric_fn, weights_fn)

    return eval_metrics
    def _loss_single_iw(self, logits, feature_name, feature, weights=None):
        """Computes the loss for one feature, optionally with explicit weights.

        Args:
          logits: model output for ``feature_name``.
          feature_name: key into the problem hparams' ``modality`` and
            ``vocab_size`` dicts.
          feature: ground-truth tensor for ``feature_name``.
          weights: optional explicit weights. They are multiplied with the
            modality weights, after being expanded along minor dimensions when
            their rank is lower.

        Returns:
          A tuple ``(loss_num, loss_den)``. If there are no problem hparams, a
          warning is logged and constants ``(0., 1.)`` are returned.

        Raises:
          ValueError: if ``weights`` is given but the loss function has no
            ``weights_fn`` argument.
          NotImplementedError: if ``weights`` is given in the multitask
            setting.
        """
        # NOTE(review): a previous comment here said logits are cast back to
        # float32 for bfloat16 models before the loss is computed, but this
        # method performs no such cast -- confirm whether one is needed.
        no_problem_err_str = (
            "The default implementation of %s requires that the "
            "model be used with a Problem. If using a Problem, augment the "
            "hparams object with trainer_lib.add_problem_hparams. If not, "
            "override %s.")
        no_problem_err = (lambda method_name: no_problem_err_str %
                          (method_name, method_name))
        if not self._problem_hparams:
            t2t_model.log_warn(no_problem_err("loss"))
            return (tf.constant(0., dtype=tf.float32),
                    tf.constant(1., dtype=tf.float32))

        # Calculate loss contribution.
        modality = self._problem_hparams.modality[feature_name]
        vocab_size = self._problem_hparams.vocab_size[feature_name]
        if vocab_size is not None and hasattr(self._hparams, "vocab_divisor"):
            # Round vocab_size up to a multiple of vocab_divisor.
            vocab_size += (-vocab_size) % self._hparams.vocab_divisor
        # NOTE(review): per-feature overrides in hparams.loss are not consulted;
        # ops.generic_loss is always used.
        loss = ops.generic_loss
        targets_weights_fn = self._hparams.weights_fn.get(
            "targets", modalities.get_weights_fn(modality))
        if weights is None:
            loss_num, loss_den = loss(logits,
                                      feature,
                                      self._hparams,
                                      vocab_size,
                                      weights_fn=targets_weights_fn)
        else:

            def weights_fn(labels):
                """Per-token weights for loss."""
                # Use target_weights_fn() given by modality as well as explicitly
                # given weights.
                modality_weights = targets_weights_fn(labels)

                # Broadcast 'weights' along minor dimensions (TF's default is major).
                explicit_weights = weights
                if len(explicit_weights.shape) < len(modality_weights.shape):
                    explicit_weights = common_layers.expand_squeeze_to_nd(
                        weights, modality_weights.shape.ndims)

                return explicit_weights * modality_weights

            # Ensure that the loss function supports a "weights_fn" argument.
            # If it doesn't and "weights" is specified, raise an exception.
            # inspect.getargspec was removed in Python 3.11; prefer
            # getfullargspec (which also reports keyword-only arguments).
            getfullargspec = getattr(inspect, "getfullargspec", None)
            if getfullargspec is not None:
                spec = getfullargspec(loss)
                argument_names = spec.args + spec.kwonlyargs
            else:  # Python 2 only has getargspec.
                argument_names = inspect.getargspec(loss).args
            if "weights_fn" not in argument_names:
                raise ValueError(
                    "Explicit 'weights' given but default loss for modality doesn't "
                    "support 'weights_fn' keyword argument: %s.loss(%s)." %
                    (modality, ", ".join(argument_names)))

            loss_num, loss_den = loss(logits,
                                      feature,
                                      self._hparams,
                                      vocab_size,
                                      weights_fn=weights_fn)

        loss_num *= self._problem_hparams.loss_multiplier

        if hasattr(self.hparams, "problem") and hasattr(
                self.hparams.problem, "task_list"):
            if weights is not None:
                raise NotImplementedError("weights not yet implemented in "
                                          "multitask setting.")
            # NOTE(review): in the multitask case the loss computed above is
            # discarded and replaced by the aggregated multi-task loss.
            loss_num, loss_den, summaries = multi_problem.aggregate_task_losses(
                self.hparams, self._problem_hparams, logits, feature_name,
                feature)

            for key, val in summaries:
                tf.summary.scalar(key, val)

        return loss_num, loss_den