Example 1
  def testLossSingleWeights(self):
    """Verify that _loss_single() honors its optional 'weights' argument."""
    with tf.Graph().as_default():
      with self.test_session() as sess:
        n_batch, seq_len, n_vocab = 2, 16, 3

        hp = tf.contrib.training.HParams(
            label_smoothing=0.0,
            shared_embedding_and_softmax_weights=False)

        p_hp = tf.contrib.training.HParams(loss_multiplier=1.0)
        p_hp.modality = {}

        model = t2t_model.T2TModel(hp, problem_hparams=p_hp)
        logits = tf.zeros((n_batch, seq_len, 1, 1, n_vocab))
        target_modality = modality.Modality(hp)
        feature = tf.ones((n_batch, seq_len, 1, 1))

        # An all-zero weight tensor must produce a zero loss.
        zero_w = tf.zeros((n_batch, seq_len))
        num, denom = model._loss_single(
            logits, target_modality, feature, weights=zero_w)
        self.assertAllClose(tf.zeros_like(num), sess.run(num))
        self.assertAllClose(tf.zeros_like(denom), sess.run(denom))

        # All-one weights must produce a strictly positive loss, with the
        # denominator counting every weighted position.
        one_w = tf.ones((n_batch, seq_len))
        num, denom = model._loss_single(
            logits, target_modality, feature, weights=one_w)
        self.assertAllLess(0.0, sess.run(num))
        self.assertAllClose(n_batch * seq_len, sess.run(denom))
  def testLossSingleWeights(self):
    """Verify that _loss_single() honors its optional 'weights' argument."""
    with tf.Graph().as_default():
      with self.test_session() as sess:
        n_batch, seq_len, n_vocab = 2, 16, 3

        hp = hparam.HParams(
            prepend_mode="none",
            loss={},
            weights_fn={},
            label_smoothing=0.0,
            shared_embedding_and_softmax_weights=False)

        p_hp = problem_hparams.TestProblem(
            n_vocab, n_vocab).get_hparams(hp)

        model = t2t_model.T2TModel(hp, problem_hparams=p_hp)
        logits = tf.zeros((n_batch, seq_len, 1, 1, n_vocab))
        feature = tf.ones((n_batch, seq_len, 1, 1))

        # An all-zero weight tensor must produce a zero loss.
        zero_w = tf.zeros((n_batch, seq_len))
        num, denom = model._loss_single(
            logits, "targets", feature, weights=zero_w)
        self.assertAllClose(tf.zeros_like(num), sess.run(num))
        self.assertAllClose(tf.zeros_like(denom), sess.run(denom))

        # All-one weights must produce a strictly positive loss, with the
        # denominator counting every weighted position.
        one_w = tf.ones((n_batch, seq_len))
        num, denom = model._loss_single(
            logits, "targets", feature, weights=one_w)
        self.assertAllLess(0.0, sess.run(num))
        self.assertAllClose(n_batch * seq_len, sess.run(denom))
 def testSummarizeLosses(self):
   """_summarize_losses() returns None and emits one summary per loss."""
   with tf.Graph().as_default():
     model = t2t_model.T2TModel(tf.contrib.training.HParams())
     losses = {"training": tf.random_normal([]),
               "extra": tf.random_normal([])}
     outputs = model._summarize_losses(losses)
     # Bug fix: the second positional argument of assertIsNone is the
     # failure *message*, so the original `assertIsNone(outputs, None)`
     # passed a pointless msg=None. The intent is a plain None check.
     self.assertIsNone(outputs)
     self.assertLen(tf.get_collection(tf.GraphKeys.SUMMARIES, scope="losses"),
                    len(losses))
Example 4
    def model_fn(
        features, labels, mode, params
    ):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Builds a `bert_to_bert` model, computes its loss with tensor2tensor's
        `_loss_single`, optionally initializes variables from a BERT
        checkpoint, and returns the `TPUEstimatorSpec` for the given mode.
        `bert_config`, `model_hparams`, `init_checkpoint`, `use_tpu`,
        `learning_rate`, `num_train_steps` and `num_warmup_steps` are
        captured from the enclosing scope.

        Args:
            features: dict of input Tensors; reads 'input_ids', 'input_mask',
                'segment_ids' and 'y' (the target ids).
            labels: unused (targets travel in `features['y']`).
            mode: a `tf.estimator.ModeKeys` value; only TRAIN and EVAL are
                supported.
            params: unused TPUEstimator params dict.

        Returns:
            A `tf.contrib.tpu.TPUEstimatorSpec`.

        Raises:
            ValueError: if `mode` is neither TRAIN nor EVAL.
        """

        tf.logging.info('*** Features ***')
        for name in sorted(features.keys()):
            tf.logging.info(
                '  name = %s, shape = %s' % (name, features[name].shape)
            )

        # A T2TModel is built only to reuse its _loss_single() computation.
        vocab_size = bert_config.vocab_size
        ph = problem_hparams.TestProblem(vocab_size, vocab_size).get_hparams(
            model_hparams
        )

        model_t2t = t2t_model.T2TModel(model_hparams, problem_hparams = ph)

        input_ids = features['input_ids']
        input_mask = features['input_mask']
        segment_ids = features['segment_ids']
        y = features['y']

        is_training = mode == tf.estimator.ModeKeys.TRAIN

        model = bert_to_bert.Model(
            bert_config = bert_config,
            training = is_training,
            input_ids = input_ids,
            input_mask = input_mask,
            token_type_ids = segment_ids,
            Y = y,
        )
        o = model.training_logits
        # Target lengths = count of non-zero ids per row (assumes id 0 is
        # padding -- TODO confirm against the vocabulary).
        Y_seq_len = tf.count_nonzero(y, 1, dtype = tf.int32)
        masks = tf.sequence_mask(Y_seq_len, tf.shape(y)[1], dtype = tf.float32)
        # Expand logits/targets to the 5-D (batch, length, 1, 1, ...) layout
        # that _loss_single consumes.
        logits = tf.expand_dims(tf.expand_dims(o, axis = 2), axis = 2)
        feature = tf.expand_dims(tf.expand_dims(y, axis = 2), axis = 2)
        loss_num, loss_denom = model_t2t._loss_single(
            logits, 'targets', feature, weights = masks
        )
        # Weighted mean loss over the masked (in-length) positions.
        total_loss = loss_num / loss_denom
        y_t = tf.argmax(o, axis = 2)
        y_t = tf.cast(y_t, tf.int32)
        # NOTE(review): tf.boolean_mask normally requires a bool mask, but
        # `masks` is float32 here -- confirm this actually runs as intended.
        prediction = tf.boolean_mask(y_t, masks)
        mask_label = tf.boolean_mask(y, masks)
        correct_pred = tf.equal(prediction, mask_label)
        correct_index = tf.cast(correct_pred, tf.float32)  # NOTE(review): unused.
        total_accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        tvars = tf.trainable_variables()

        # Optionally map trainable variables onto an existing checkpoint.
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (
                assignment_map,
                initialized_variable_names,
            ) = modeling.get_assignment_map_from_checkpoint(
                tvars, init_checkpoint
            )
            if use_tpu:

                def tpu_scaffold():
                    # NOTE(review): on TPU the restore is deferred into the
                    # Scaffold -- presumably so it executes on the worker.
                    tf.train.init_from_checkpoint(
                        init_checkpoint, assignment_map
                    )
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info('**** Trainable Variables ****')
        for var in tvars:
            init_string = ''
            if var.name in initialized_variable_names:
                init_string = ', *INIT_FROM_CKPT*'
            tf.logging.info(
                '  name = %s, shape = %s%s', var.name, var.shape, init_string
            )

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                use_tpu,
            )

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode = mode,
                loss = total_loss,
                train_op = train_op,
                scaffold_fn = scaffold_fn,
            )
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(loss, accuracy):
                # Forward the scalar loss/accuracy unchanged as eval metrics.
                return {'total_loss': loss, 'total_accuracy': accuracy}

            eval_metrics = (metric_fn, [total_loss, total_accuracy])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode = mode,
                loss = total_loss,
                eval_metrics = eval_metrics,
                scaffold_fn = scaffold_fn,
            )
        else:
            raise ValueError(
                'Only TRAIN and EVAL modes are supported: %s' % (mode)
            )

        return output_spec
Example 5
from tensor2tensor.utils import test_utils

# Hyperparameters for the tensor2tensor model.
# NOTE(review): loss/weights_fn are empty dicts -- presumably so per-problem
# defaults apply; confirm against the tensor2tensor version in use.
model_hparams = hparam.HParams(
    prepend_mode = 'none',
    loss = {},
    weights_fn = {},
    label_smoothing = 0.0,
    shared_embedding_and_softmax_weights = False,
)

# Vocabulary size used for both inputs and targets of the test problem.
# Presumably the multilingual BERT wordpiece vocab size -- TODO confirm it
# matches the checkpoint's bert_config.vocab_size.
vocab_size = 119547
ph = problem_hparams.TestProblem(vocab_size, vocab_size).get_hparams(
    model_hparams
)

# Module-level T2TModel; used elsewhere for its _loss_single() computation.
model_t2t = t2t_model.T2TModel(model_hparams, problem_hparams = ph)

flags = tf.flags

FLAGS = flags.FLAGS

# Glob (or comma-separated list) of training TFRecord files.
flags.DEFINE_string(
    'input_file',
    'multilanguagebert-train-*.tfrecord',
    'Input TF example files (can be a glob or comma separated).',
)

flags.DEFINE_string(
    'test_file',
    'multilanguagebert-test-*.tfrecord',
    'Input TF example files (can be a glob or comma separated).',