def testSymbolModalityTargets(self):
  batch_size = 10
  num_datashards = 5
  length = 6
  height = 7
  hidden_size = 9
  vocab_size = 11
  model_hparams = common_hparams.basic_params1()
  model_hparams.hidden_size = hidden_size
  model_hparams.mode = tf.estimator.ModeKeys.TRAIN
  body_output = np.random.randint(
      100, size=(batch_size, length, height, hidden_size))
  targets = np.random.randint(
      vocab_size, size=(batch_size, length, height, 1))
  data_parallelism = expert_utils.Parallelism(
      ["/device:CPU:0"] * num_datashards)
  sharded_body_output = tf.split(tf.to_float(body_output), num_datashards)
  sharded_targets = tf.split(targets, num_datashards)
  sharded_logits = data_parallelism(
      modalities.get_top(modalities.ModalityType.SYMBOL),
      sharded_body_output,
      sharded_targets,
      model_hparams,
      vocab_size)
  sharded_loss_num, sharded_loss_den = data_parallelism(
      modalities.get_loss(modalities.ModalityType.SYMBOL),
      sharded_logits,
      sharded_targets,
      model_hparams,
      vocab_size,
      modalities.get_weights_fn(modalities.ModalityType.SYMBOL))
  train_loss = (tf.add_n(sharded_loss_num) /
                tf.maximum(1.0, tf.add_n(sharded_loss_den)))
  logits = tf.concat(sharded_logits, 0)
  self.evaluate(tf.global_variables_initializer())
  res1, res2 = self.evaluate((logits, train_loss))
  self.assertEqual(res1.shape, (batch_size, length, height, 1, vocab_size))
  self.assertEqual(res2.shape, ())
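# A minimal, self-contained sketch (not part of the library) of the
# (numerator, denominator) loss convention exercised in the test above:
# each loss function returns a weighted sum of per-token losses (the
# numerator) and the total token weight (the denominator). Summing
# numerators and denominators before dividing yields a token-weighted
# mean, whereas averaging per-shard means would over-weight shards with
# fewer target tokens.
def _combined_loss_sketch(shard_loss_nums, shard_loss_dens):
  total_num = sum(shard_loss_nums)
  total_den = max(1.0, sum(shard_loss_dens))  # mirrors tf.maximum(1.0, ...)
  return total_num / total_den

# Shard A: summed loss 8.0 over 8 tokens (mean 1.0); shard B: 4.0 over
# 2 tokens (mean 2.0). The token-weighted mean is 1.2, not the naive 1.5.
assert _combined_loss_sketch([8.0, 4.0], [8.0, 2.0]) == 1.2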
def aggregate_task_lm_losses(hparams,
                             problem_hparams,
                             logits,
                             feature_name,
                             feature):
  """LM loss for multiproblems."""
  summaries = []
  vocab_size = problem_hparams.vocab_size[feature_name]
  if vocab_size is not None and hasattr(hparams, "vocab_divisor"):
    vocab_size += (-vocab_size) % hparams.vocab_divisor
  modality = problem_hparams.modality[feature_name]
  loss = hparams.loss.get(feature_name, modalities.get_loss(modality))
  weights_fn = hparams.weights_fn.get(
      feature_name, modalities.get_weights_fn(modality))
  loss_num = 0.
  loss_den = 0.
  for task in hparams.problem.task_list:
    loss_num_, loss_den_ = loss(
        logits, feature,
        lambda x: common_layers.weights_multi_problem_all(x, task.task_id),  # pylint: disable=cell-var-from-loop
        hparams, vocab_size, weights_fn)
    loss_num += loss_num_
    loss_den += loss_den_

    loss_val = loss_num_ / tf.maximum(1.0, loss_den_)
    summaries.append([task.name + "_loss", loss_val])

  return loss_num, loss_den, summaries
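# A minimal sketch (illustrative only) of the vocab padding used above and in
# aggregate_task_losses: `vocab_size += (-vocab_size) % hparams.vocab_divisor`
# rounds the vocabulary size up to the next multiple of vocab_divisor.
def _pad_vocab_sketch(vocab_size, vocab_divisor):
  return vocab_size + (-vocab_size) % vocab_divisor

assert _pad_vocab_sketch(11, 8) == 16  # 11 rounds up to the next multiple of 8.
assert _pad_vocab_sketch(16, 8) == 16  # Already a multiple; unchanged.
assert _pad_vocab_sketch(11, 1) == 11  # A divisor of 1 is a no-op.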
def testGetForAllModalities(self):
  for modality in modalities.ModalityType.get_choices():
    bottom = modalities.get_bottom(modality)
    loss = modalities.get_loss(modality)
    name = modalities.get_name(modality)
    targets_bottom = modalities.get_targets_bottom(modality)
    top = modalities.get_top(modality)
    weights_fn = modalities.get_weights_fn(modality)
    self.assertIsNotNone(bottom,
                         msg="{} has no default bottom".format(modality))
    self.assertIsNotNone(loss, msg="{} has no default loss".format(modality))
    self.assertIsNotNone(name, msg="{} has no default name".format(modality))
    self.assertIsNotNone(
        targets_bottom,
        msg="{} has no default targets_bottom".format(modality))
    self.assertIsNotNone(top, msg="{} has no default top".format(modality))
    self.assertIsNotNone(
        weights_fn, msg="{} has no default weights_fn".format(modality))
def aggregate_task_losses(hparams,
                          problem_hparams,
                          logits,
                          feature_name,
                          feature):
  """Multiproblem loss function."""
  # If no reweighting, we want the default loss to mimic the LM loss.
  if not hparams.multiproblem_reweight_label_loss:
    return aggregate_task_lm_losses(hparams=hparams,
                                    problem_hparams=problem_hparams,
                                    logits=logits,
                                    feature_name=feature_name,
                                    feature=feature)

  summaries = []
  main_task_id = hparams.problem.task_list[0].task_id
  vocab_size = problem_hparams.vocab_size[feature_name]
  if vocab_size is not None and hasattr(hparams, "vocab_divisor"):
    vocab_size += (-vocab_size) % hparams.vocab_divisor
  modality = problem_hparams.modality[feature_name]
  loss = hparams.loss.get(feature_name, modalities.get_loss(modality))
  weights_fn = hparams.weights_fn.get(
      feature_name, modalities.get_weights_fn(modality))

  # Primary task loss.
  loss_num, loss_den = loss(
      logits, feature,
      lambda x: common_layers.weights_multi_problem_all(x, main_task_id),
      hparams, vocab_size, weights_fn)

  loss_val = loss_num / tf.maximum(1.0, loss_den)
  summaries.append([hparams.problem.task_list[0].name + "_loss", loss_val])

  # Since the losses may undergo rescaling, they cannot exist as separate
  # numerators and denominators. Set the denominators to 1 in order to
  # facilitate loss averaging.
  loss_num = loss_val
  loss_den = tf.minimum(tf.convert_to_tensor(1, dtype=tf.float32), loss_den)

  for task in hparams.problem.task_list[1:]:
    # Loss only from the input sequence -- the auxiliary LM loss.
    seq_loss_num, seq_loss_den = loss(
        logits, feature,
        lambda x: common_layers.weights_multi_problem_input(x, task.task_id),  # pylint: disable=cell-var-from-loop
        hparams, vocab_size)
    seq_loss_num *= problem_hparams.loss_multiplier

    # Unscaled sequence loss.
    seq_loss = seq_loss_num / tf.maximum(1.0, seq_loss_den)
    summaries.append([task.name + "_seq_loss", seq_loss])

    if hasattr(task, "num_classes"):
      # Loss only from the classification label.
      label_loss_num, label_loss_den = loss(
          logits, feature,
          lambda x: common_layers.weights_multi_problem(x, task.task_id),  # pylint: disable=cell-var-from-loop
          hparams, vocab_size)
      label_loss_num *= problem_hparams.loss_multiplier

      # Unscaled classification label loss.
      label_loss = label_loss_num / tf.maximum(1.0, label_loss_den)
      summaries.append([task.name + "_label_loss", label_loss])

      # Scaling.
      if hparams.multiproblem_reweight_label_loss:
        label_loss *= hparams.multiproblem_label_weight
        seq_loss *= (1 - hparams.multiproblem_label_weight)

      # This is the training loss for the optimizer after scaling.
      task_loss_val = seq_loss + label_loss
      loss_den_ = label_loss_den
    else:
      # Loss only from the target sequence.
      target_loss_num, target_loss_den = loss(
          logits, feature,
          lambda x: common_layers.weights_multi_problem(x, task.task_id),  # pylint: disable=cell-var-from-loop
          hparams, vocab_size)
      target_loss_num *= problem_hparams.loss_multiplier

      # Unscaled target sequence loss.
      target_loss = target_loss_num / tf.maximum(1.0, target_loss_den)
      summaries.append([task.name + "_target_loss", target_loss])

      # Scaling.
      if hparams.multiproblem_reweight_label_loss:
        target_loss *= hparams.multiproblem_label_weight
        seq_loss *= (1 - hparams.multiproblem_label_weight)

      # This is the training loss for the optimizer after all the scaling.
      task_loss_val = seq_loss + target_loss
      loss_den_ = target_loss_den

    summaries.append([task.name + "_loss", task_loss_val])
    # Adding 1 to the loss den for each task leads to averaging task losses.
    # TODO(urvashik): Fix combination with other task losses - weighted
    # average based on the number of examples from that task.
    loss_num += task_loss_val
    loss_den += tf.minimum(tf.convert_to_tensor(1, dtype=tf.float32),
                           loss_den_)

  return loss_num, loss_den, summaries
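# A minimal sketch (made-up numbers) of the reweighting performed above when
# hparams.multiproblem_reweight_label_loss is set: each auxiliary task's
# training loss is a convex combination of its input-sequence LM loss and its
# label (or target-sequence) loss, weighted by multiproblem_label_weight.
# Denominators are clamped to at most 1 (the tf.minimum above) so that summing
# the (numerator, denominator) pairs across tasks averages the task losses.
def _reweighted_task_loss_sketch(seq_loss, label_loss, label_weight):
  return (1 - label_weight) * seq_loss + label_weight * label_loss

# With label_weight = 0.75 the label loss dominates the combination:
# 0.25 * 2.0 + 0.75 * 0.5 = 0.875.
assert _reweighted_task_loss_sketch(2.0, 0.5, 0.75) == 0.875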