Code Example #1
 def testSymbolModalityTargets(self):
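     # Random body outputs and integer targets are sharded across five copies
     # of the same CPU device, then run through the default SYMBOL modality's
     # top and loss functions.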
     batch_size = 10
     num_datashards = 5
     length = 6
     height = 7
     hidden_size = 9
     vocab_size = 11
     model_hparams = common_hparams.basic_params1()
     model_hparams.hidden_size = hidden_size
     model_hparams.mode = tf.estimator.ModeKeys.TRAIN
     body_output = np.random.randint(100,
                                     size=(batch_size, length, height,
                                           hidden_size))
     targets = np.random.randint(vocab_size,
                                 size=(batch_size, length, height, 1))
     data_parallelism = expert_utils.Parallelism(["/device:CPU:0"] *
                                                 num_datashards)
     sharded_body_output = tf.split(tf.to_float(body_output),
                                    num_datashards)
     sharded_targets = tf.split(targets, num_datashards)
     sharded_logits = data_parallelism(
         modalities.get_top(modalities.ModalityType.SYMBOL),
         sharded_body_output, sharded_targets, model_hparams, vocab_size)
     sharded_loss_num, sharded_loss_den = data_parallelism(
         modalities.get_loss(modalities.ModalityType.SYMBOL),
         sharded_logits, sharded_targets, model_hparams, vocab_size,
         modalities.get_weights_fn(modalities.ModalityType.SYMBOL))
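     # Fold the per-shard loss numerators and denominators into a single
     # scalar training loss.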
     train_loss = (tf.add_n(sharded_loss_num) /
                   tf.maximum(1.0, tf.add_n(sharded_loss_den)))
     logits = tf.concat(sharded_logits, 0)
     self.evaluate(tf.global_variables_initializer())
     res1, res2 = self.evaluate((logits, train_loss))
     self.assertEqual(res1.shape,
                      (batch_size, length, height, 1, vocab_size))
     self.assertEqual(res2.shape, ())
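
The same SYMBOL top/loss pair can be exercised without sharding. A minimal
single-device sketch, reusing the names and call signatures from the test above
(body_output, targets, model_hparams and vocab_size are assumed to be set up as
in the test):

top_fn = modalities.get_top(modalities.ModalityType.SYMBOL)
loss_fn = modalities.get_loss(modalities.ModalityType.SYMBOL)
weights_fn = modalities.get_weights_fn(modalities.ModalityType.SYMBOL)
logits = top_fn(tf.to_float(body_output), targets, model_hparams, vocab_size)
loss_num, loss_den = loss_fn(logits, targets, model_hparams, vocab_size,
                             weights_fn)
loss = loss_num / tf.maximum(1.0, loss_den)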
Code Example #2
def aggregate_task_lm_losses(hparams, problem_hparams, logits, feature_name,
                             feature):
    """LM loss for multiproblems."""
    summaries = []
    vocab_size = problem_hparams.vocab_size[feature_name]
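    # Pad the vocabulary size up to a multiple of hparams.vocab_divisor, e.g. a
    # vocab of 11 with vocab_divisor=8 is padded to 16.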
    if vocab_size is not None and hasattr(hparams, "vocab_divisor"):
        vocab_size += (-vocab_size) % hparams.vocab_divisor
    modality = problem_hparams.modality[feature_name]
    loss = hparams.loss.get(feature_name, modalities.get_loss(modality))
    weights_fn = hparams.weights_fn.get(feature_name,
                                        modalities.get_weights_fn(modality))
    loss_num = 0.
    loss_den = 0.
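    # Evaluate the loss once per task, with each call restricted (via
    # weights_multi_problem_all) to the examples belonging to that task.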
    for task in hparams.problem.task_list:
        loss_num_, loss_den_ = loss(
            logits,
            feature,
            lambda x: common_layers.weights_multi_problem_all(x, task.task_id),  # pylint: disable=cell-var-from-loop
            hparams,
            vocab_size,
            weights_fn)

        loss_num += loss_num_
        loss_den += loss_den_

        loss_val = loss_num_ / tf.maximum(1.0, loss_den_)
        summaries.append([task.name + "_loss", loss_val])

    return loss_num, loss_den, summaries
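
On the caller side, the returned numerator/denominator pair is usually folded
into a scalar the same way loss_val is computed inside the loop above. A hedged
sketch (the summary writing is illustrative, not part of the function):

loss_num, loss_den, summaries = aggregate_task_lm_losses(
    hparams, problem_hparams, logits, feature_name, feature)
lm_loss = loss_num / tf.maximum(1.0, loss_den)
for name, value in summaries:
    tf.summary.scalar(name, value)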
Code Example #3
 def testGetForAllModalities(self):
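   # Every registered modality type must expose all six default functions.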
   for modality in modalities.ModalityType.get_choices():
     bottom = modalities.get_bottom(modality)
     loss = modalities.get_loss(modality)
     name = modalities.get_name(modality)
     targets_bottom = modalities.get_targets_bottom(modality)
     top = modalities.get_top(modality)
     weights_fn = modalities.get_weights_fn(modality)
     self.assertIsNotNone(bottom,
                          msg="{} has no default bottom".format(modality))
     self.assertIsNotNone(loss, msg="{} has no default loss".format(modality))
     self.assertIsNotNone(name, msg="{} has no default name".format(modality))
     self.assertIsNotNone(
         targets_bottom,
         msg="{} has no default targets_bottom".format(modality))
     self.assertIsNotNone(top, msg="{} has no default top".format(modality))
     self.assertIsNotNone(weights_fn,
                          msg="{} has no default weights_fn".format(modality))
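
The same getters can be used to build a simple lookup table of per-modality
defaults, which is essentially what the test above iterates over. A minimal
sketch:

default_fns = {}
for modality in modalities.ModalityType.get_choices():
    default_fns[modality] = {
        "bottom": modalities.get_bottom(modality),
        "targets_bottom": modalities.get_targets_bottom(modality),
        "top": modalities.get_top(modality),
        "loss": modalities.get_loss(modality),
        "name": modalities.get_name(modality),
        "weights_fn": modalities.get_weights_fn(modality),
    }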
Code Example #4
def aggregate_task_losses(hparams, problem_hparams, logits, feature_name,
                          feature):
    """Multiproblem loss function."""

    # If no reweighting, we want the default loss to mimic the LM loss.
    if not hparams.multiproblem_reweight_label_loss:
        return aggregate_task_lm_losses(hparams=hparams,
                                        problem_hparams=problem_hparams,
                                        logits=logits,
                                        feature_name=feature_name,
                                        feature=feature)

    summaries = []
    main_task_id = hparams.problem.task_list[0].task_id
    vocab_size = problem_hparams.vocab_size[feature_name]
    if vocab_size is not None and hasattr(hparams, "vocab_divisor"):
        vocab_size += (-vocab_size) % hparams.vocab_divisor
    modality = problem_hparams.modality[feature_name]
    loss = hparams.loss.get(feature_name, modalities.get_loss(modality))
    weights_fn = hparams.weights_fn.get(feature_name,
                                        modalities.get_weights_fn(modality))
    # Primary task loss
    loss_num, loss_den = loss(
        logits, feature,
        lambda x: common_layers.weights_multi_problem_all(x, main_task_id),
        hparams, vocab_size, weights_fn)

    loss_val = loss_num / tf.maximum(1.0, loss_den)
    summaries.append([hparams.problem.task_list[0].name + "_loss", loss_val])

    # Since the losses may undergo rescaling, they cannot be kept as separate
    # numerators and denominators. Set the denominators to 1 in order to
    # facilitate loss averaging.
    loss_num = loss_val
    loss_den = tf.minimum(tf.convert_to_tensor(1, dtype=tf.float32), loss_den)

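    # Each auxiliary task contributes a sequence (LM) loss plus either a label
    # loss (classification tasks) or a target-sequence loss; the two parts are
    # mixed using multiproblem_label_weight.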
    for task in hparams.problem.task_list[1:]:
        # Loss only from the input sequence -- the auxiliary LM loss.
        seq_loss_num, seq_loss_den = loss(
            logits,
            feature,
            lambda x: common_layers.weights_multi_problem_input(
                x, task.task_id),  # pylint: disable=cell-var-from-loop
            hparams,
            vocab_size)
        seq_loss_num *= problem_hparams.loss_multiplier

        # Unscaled sequence loss.
        seq_loss = seq_loss_num / tf.maximum(1.0, seq_loss_den)
        summaries.append([task.name + "_seq_loss", seq_loss])

        if hasattr(task, "num_classes"):
            # Loss only from the classification label.
            label_loss_num, label_loss_den = loss(
                logits,
                feature,
                lambda x: common_layers.weights_multi_problem(x, task.task_id),  # pylint: disable=cell-var-from-loop
                hparams,
                vocab_size)
            label_loss_num *= problem_hparams.loss_multiplier

            # Unscaled classification label loss.
            label_loss = label_loss_num / tf.maximum(1.0, label_loss_den)
            summaries.append([task.name + "_label_loss", label_loss])

            # Scaling.
            if hparams.multiproblem_reweight_label_loss:
                label_loss *= hparams.multiproblem_label_weight
                seq_loss *= (1 - hparams.multiproblem_label_weight)

            # This is the training loss for the optimizer after scaling.
            task_loss_val = seq_loss + label_loss

            loss_den_ = label_loss_den

        else:
            # Loss only from the target sequence.
            target_loss_num, target_loss_den = loss(
                logits,
                feature,
                lambda x: common_layers.weights_multi_problem(x, task.task_id),  # pylint: disable=cell-var-from-loop
                hparams,
                vocab_size)
            target_loss_num *= problem_hparams.loss_multiplier

            # Unscaled target sequence loss.
            target_loss = target_loss_num / tf.maximum(1.0, target_loss_den)
            summaries.append([task.name + "_target_loss", target_loss])

            # Scaling.
            if hparams.multiproblem_reweight_label_loss:
                target_loss *= hparams.multiproblem_label_weight
                seq_loss *= (1 - hparams.multiproblem_label_weight)

            # This is the training loss for the optimizer after all the scaling.
            task_loss_val = seq_loss + target_loss

            loss_den_ = target_loss_den

        summaries.append([task.name + "_loss", task_loss_val])
        # Adding 1 to the loss den for each task leads to averaging task losses.
        # TODO(urvashik): Fix combination with other task losses - weighted
        # average based on the number of examples from that task.
        loss_num += task_loss_val
        loss_den += tf.minimum(tf.convert_to_tensor(1, dtype=tf.float32),
                               loss_den_)

    return loss_num, loss_den, summaries
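
A hedged sketch of how a caller might consume the result; the "targets" feature
name and the summary writing are illustrative assumptions, not part of the
function:

loss_num, loss_den, summaries = aggregate_task_losses(
    hparams, problem_hparams, logits, "targets", features["targets"])
train_loss = loss_num / tf.maximum(1.0, loss_den)
for name, value in summaries:
    tf.summary.scalar(name, value)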