import tensorflow as tf

from tensor2tensor.layers import common_layers


def aggregate_task_losses(hparams, problem_hparams, logits, target_modality,
                          feature):
    """Multiproblem loss function."""
    summaries = []
    main_task_id = hparams.problem.task_list[0].task_id
    # Primary task loss
    loss_num, loss_den = target_modality.loss(
        logits,
        feature,
        weights_fn=lambda x: common_layers.weights_multi_problem_all(
            x, main_task_id))

    loss_val = loss_num / tf.maximum(1.0, loss_den)
    summaries.append([hparams.problem.task_list[0].name + "_loss", loss_val])

    # Since the losses may undergo rescaling, they cannot exist as separate
    # numerators and denominators. Set the denominators to 1 in order to
    # facilitate loss averaging.
    loss_num = loss_val
    loss_den = tf.minimum(tf.convert_to_tensor(1, dtype=tf.float32), loss_den)
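    # Illustrative numbers (not from the source): with 100 weighted target
    # positions and a summed cross-entropy of 230.0, loss_val = 2.3; after this
    # step loss_num = 2.3 and loss_den = 1.0, i.e. one averaging unit for the
    # primary task.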

    for task in hparams.problem.task_list[1:]:
        # Loss only from the input sequence -- the auxiliary LM loss.
        seq_loss_num, seq_loss_den = target_modality.loss(
            logits,
            feature,
            weights_fn=lambda x: common_layers.weights_multi_problem_input(
                x, task.task_id))  # pylint: disable=cell-var-from-loop
        seq_loss_num *= problem_hparams.loss_multiplier

        # Unscaled sequence loss.
        seq_loss = seq_loss_num / tf.maximum(1.0, seq_loss_den)
        summaries.append([task.name + "_seq_loss", seq_loss])

        if hasattr(task, "num_classes"):
            # Loss only from the classification label.
            label_loss_num, label_loss_den = target_modality.loss(
                logits,
                feature,
                weights_fn=lambda x: common_layers.weights_multi_problem(
                    x, task.task_id))  # pylint: disable=cell-var-from-loop
            label_loss_num *= problem_hparams.loss_multiplier

            # Unscaled classification label loss.
            label_loss = label_loss_num / tf.maximum(1.0, label_loss_den)
            summaries.append([task.name + "_label_loss", label_loss])

            # Scaling.
            if hparams.multiproblem_reweight_label_loss:
                label_loss *= hparams.multiproblem_label_weight
                seq_loss *= (1 - hparams.multiproblem_label_weight)

            if hparams.multiproblem_class_loss_multiplier:
                label_loss *= hparams.multiproblem_class_loss_multiplier
                summaries.append(
                    [task.name + "_scaled_label_loss", label_loss])

            # This is the training loss for the optimizer after all the scaling.
            task_loss_val = seq_loss + label_loss

            loss_den_ = label_loss_den

        else:
            # Loss only from the target sequence.
            target_loss_num, target_loss_den = target_modality.loss(
                logits,
                feature,
                weights_fn=lambda x: common_layers.weights_multi_problem(
                    x, task.task_id))  # pylint: disable=cell-var-from-loop
            target_loss_num *= problem_hparams.loss_multiplier

            # Unscaled target sequence loss.
            target_loss = target_loss_num / tf.maximum(1.0, target_loss_den)
            summaries.append([task.name + "_target_loss", target_loss])

            # Scaling.
            if hparams.multiproblem_reweight_label_loss:
                target_loss *= hparams.multiproblem_label_weight
                seq_loss *= (1 - hparams.multiproblem_label_weight)

            # This is the training loss for the optimizer after all the scaling.
            task_loss_val = seq_loss + target_loss

            loss_den_ = target_loss_den

        summaries.append([task.name + "_loss", task_loss_val])
        # Adding 1 to the loss den for each task leads to averaging task losses.
        # TODO(urvashik): Fix combination with other task losses - weighted
        # average based on the number of examples from that task.
        loss_num += task_loss_val
        loss_den += tf.minimum(tf.convert_to_tensor(1, dtype=tf.float32),
                               loss_den_)

    return loss_num, loss_den, summaries
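The caller is expected to form the final training loss as loss_num / tf.maximum(1.0, loss_den). Because each task adds its already-normalized scalar loss to the numerator and (at most) 1.0 to the denominator, that division is a uniform average over tasks. Below is a minimal, self-contained numeric sketch of this accumulation scheme; the numbers are illustrative, not from the source.

import tensorflow as tf

# Per-task mean losses, as produced by loss_num / max(1, loss_den) above.
task_losses = [tf.constant(2.3), tf.constant(0.7), tf.constant(1.5)]

loss_num = tf.constant(0.0)
loss_den = tf.constant(0.0)
for task_loss in task_losses:
    loss_num += task_loss
    loss_den += 1.0  # one unit per task, mirroring tf.minimum(1.0, loss_den_)

mean_loss = loss_num / tf.maximum(1.0, loss_den)  # (2.3 + 0.7 + 1.5) / 3 = 1.5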
Example #3
def weights_fn_for_mp(problem_task_id):
    return lambda x: common_layers.weights_multi_problem(x, problem_task_id)
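weights_fn_for_mp pins problem_task_id at call time. Defining the lambda directly inside a loop would instead capture the loop variable by reference, which is the late-binding pitfall the # pylint: disable=cell-var-from-loop comments in aggregate_task_losses acknowledge. A small, self-contained illustration in plain Python (names are illustrative):

# Late binding: every lambda ends up seeing the final loop value.
late = [lambda x: (x, task_id) for task_id in range(3)]
print([fn(0)[1] for fn in late])   # [2, 2, 2]

# Factory binding: each closure keeps the task_id it was built with.
def make_fn(task_id):
    return lambda x: (x, task_id)

bound = [make_fn(task_id) for task_id in range(3)]
print([fn(0)[1] for fn in bound])  # [0, 1, 2]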
def aggregate_task_losses(hparams, problem_hparams, logits, target_modality,
                          feature):
    """Multiproblem loss function."""
    summaries = []
    main_task_id = hparams.problem.task_list[0].task_id
    # Primary task loss
    loss_num, loss_den = target_modality.loss(
        logits,
        feature,
        weights_fn=lambda x: common_layers.weights_multi_problem_all(
            x, main_task_id))

    loss_val = loss_num / tf.maximum(1.0, loss_den)
    summaries.append([hparams.problem.task_list[0].name + "_loss", loss_val])

    for task in hparams.problem.task_list[1:]:
        if hasattr(task, "num_classes"):
            task_loss_num_seq, task_loss_den_seq = target_modality.loss(
                logits,
                feature,
                weights_fn=lambda x: common_layers.weights_multi_problem_input(
                    x, task.task_id))  # pylint: disable=cell-var-from-loop
            task_loss_num_seq *= problem_hparams.loss_multiplier

            task_loss_num_label, task_loss_den_label = target_modality.loss(
                logits,
                feature,
                weights_fn=lambda x: common_layers.weights_multi_problem(
                    x, task.task_id))  # pylint: disable=cell-var-from-loop
            task_loss_num_label *= problem_hparams.loss_multiplier

            if hparams.multiproblem_reweight_label_loss:
                # Convex combination of the auxiliary LM (sequence) loss and
                # the classification label loss.
                task_loss_num = ((1 - hparams.multiproblem_label_weight) *
                                 task_loss_num_seq)
                task_loss_num += (hparams.multiproblem_label_weight *
                                  task_loss_num_label)
            elif hparams.multiproblem_class_loss_multiplier > 0:
                # Keep the full sequence loss and scale up the label loss.
                task_loss_num = task_loss_num_seq
                task_loss_num += (hparams.multiproblem_class_loss_multiplier *
                                  task_loss_num_label)
            else:
                # Unweighted sum of the two losses.
                task_loss_num = task_loss_num_seq + task_loss_num_label

            task_loss_den = task_loss_den_seq + task_loss_den_label

            # Log the unscaled versions of the losses to tensorboard.
            task_loss_val = (task_loss_num_seq +
                             task_loss_num_label) / tf.maximum(
                                 1.0, task_loss_den)
            summaries.append([task.name + "_loss", task_loss_val])

            task_loss_val_label = task_loss_num_label / tf.maximum(
                1.0, task_loss_den_label)
            summaries.append(
                [task.name + "_only_label_loss", task_loss_val_label])

            loss_num += task_loss_num
            loss_den += task_loss_den

        else:
            raise ValueError(
                "Non-classification secondary tasks are not supported.")

    return loss_num, loss_den, summaries
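Unlike the first version above, this variant accumulates raw numerators and denominators, so the caller's final loss_num / loss_den weights each task by its total target weight (roughly, its token count) rather than averaging per-task means uniformly. A quick numeric sketch of the difference (illustrative numbers only):

import tensorflow as tf

num_a, den_a = tf.constant(230.0), tf.constant(100.0)  # task A mean: 2.3
num_b, den_b = tf.constant(7.0), tf.constant(10.0)     # task B mean: 0.7

uniform_mean = (num_a / den_a + num_b / den_b) / 2.0   # 1.5 (first version)
weighted_mean = (num_a + num_b) / (den_a + den_b)      # ~2.15 (this version)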