Example #1

import tensorflow.compat.v1 as tf  # TF1-style tf.* API used throughout.

# Project-local helper modules referenced below. Their exact import paths
# are not shown in this example, so they are left as placeholders:
# from <project> import (adam_weight_decay, bert_utils, constants,
#                        decode_utils, embeddings, load_from_checkpoint,
#                        metrics, tpu_utils, transformer)

def _sequence_correct(labels, predictions):
    """Computes per-example sequence accuracy (1.0 iff all tokens match)."""
    # Trim the start symbol from the targets so they align with the
    # predictions, which do not emit it.
    target_decode_steps = decode_utils.decode_steps_from_labels(
        labels, trim_start_symbol=True)
    predicted_decode_steps = decode_utils.decode_steps_from_predictions(
        predictions)

    decode_utils.assert_shapes_match(target_decode_steps,
                                     predicted_decode_steps)

    equal_tokens = decode_utils.compare_decode_steps(target_decode_steps,
                                                     predicted_decode_steps)
    # Account for the start symbol trimmed from the targets above.
    target_len = tf.to_int32(labels[constants.TARGET_LEN_KEY]) - 1
    loss_mask = tf.sequence_mask(lengths=target_len,
                                 maxlen=tf.to_int32(tf.shape(equal_tokens)[1]))
    # Mark padded positions as equal so that padding cannot affect the
    # per-example reduction below.
    equal_tokens = tf.logical_or(equal_tokens, tf.logical_not(loss_mask))
    all_equal = tf.cast(tf.reduce_all(equal_tokens, 1), tf.float32)
    return all_equal
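
# A standalone sketch of the masking trick used above, with made-up values:
# positions past the target length are forced to True, so only real decode
# steps can make tf.reduce_all return False.
_equal_example = tf.constant([[True, False, True], [True, True, False]])
_mask_example = tf.sequence_mask([3, 2], maxlen=3)  # Row 1 pads its last step.
_all_equal_example = tf.reduce_all(
    tf.logical_or(_equal_example, tf.logical_not(_mask_example)), 1)
# _all_equal_example evaluates to [False, True]: row 0 has a real mismatch,
# while row 1's only mismatch sits in padding and so still counts as correct.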


def build_model_fn(model_config, output_vocab_filepath, use_tpu, beam_size,
                   clean_output_vocab_path=None):
    """Builds a model_fn closure over the given configuration.

    NOTE: This enclosing signature is reconstructed from the free variables
    the inner function references (model_config, output_vocab_filepath,
    use_tpu, beam_size, clean_output_vocab_path); the original wrapper is
    not shown in this example.
    """

    def model_fn(features, labels, mode, params=None):
        """Model function for use with tf.estimator.Estimator."""
        del params  # unused. model_fn is batch-size agnostic.

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        pretrained_variable_names = None
        scaffold_fn = None

        bert_config = bert_utils.get_bert_config(
            model_config.model_parameters.pretrained_bert_dir,
            reinitialize_type_embeddings=model_config.model_parameters.
            use_segment_ids)

        input_embeddings = embeddings.get_input_embeddings(
            model_config,
            bert_config,
            features,
            is_training,
            use_one_hot_embeddings=use_tpu)

        source_len = tf.to_int32(features[constants.SOURCE_LEN_KEY])

        output_embeddings_table = embeddings.get_output_vocab_embeddings_table(
            model_config, output_vocab_filepath)

        output_vocab_size = embeddings.get_output_vocab_size(
            output_vocab_filepath)

        clean_output_mask = None
        if clean_output_vocab_path:
            clean_output_mask_list = embeddings.get_clean_output_mask(
                output_vocab_filepath, clean_output_vocab_path)
            clean_output_mask = tf.convert_to_tensor(clean_output_mask_list)

        # For inference, just compute the predictions and return immediately.
        if mode == tf.estimator.ModeKeys.PREDICT:
            predictions = transformer.infer(
                model_config,
                input_embeddings,
                source_len,
                output_vocab_size,
                output_embeddings_table,
                mode,
                input_copy_mask=features[constants.COPIABLE_INPUT_KEY],
                clean_output_mask=clean_output_mask,
                beam_size=beam_size)

            if use_tpu:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode=mode,
                    predictions=predictions,
                    scaffold_fn=scaffold_fn)
            else:
                return tf.estimator.EstimatorSpec(mode=mode,
                                                  predictions=predictions)

        with tpu_utils.rewire_summary_calls(use_tpu):
            # Get training predictions.
            train_decode_steps = decode_utils.decode_steps_from_labels(
                labels, trim_end_symbol=True)
            logits, predictions = transformer.train(
                model_config,
                input_embeddings,
                source_len,
                output_vocab_size,
                output_embeddings_table,
                train_decode_steps,
                mode,
                input_copy_mask=features[constants.COPIABLE_INPUT_KEY])

            # Calculate loss.
            weights = labels[constants.WEIGHT_KEY]
            loss_decode_steps = decode_utils.decode_steps_from_labels(
                labels, trim_start_symbol=True)

            # Account for removed start symbol.
            target_len = tf.to_int32(labels[constants.TARGET_LEN_KEY])
            target_len -= 1

            batch_loss = _compute_loss(logits, loss_decode_steps, target_len,
                                       weights, output_vocab_size,
                                       model_config)

            if mode == tf.estimator.ModeKeys.TRAIN:
                # Initialize encoder variables from the pre-trained BERT
                # checkpoint before constructing the training op.
                pretrained_variable_names, scaffold_fn = (
                    load_from_checkpoint.init_model_from_checkpoint(
                        model_config.model_parameters.pretrained_bert_dir,
                        use_tpu=use_tpu,
                        checkpoint_file="bert_model.ckpt",
                        reinitialize_type_embeddings=model_config.
                        model_parameters.use_segment_ids))
                train_op = adam_weight_decay.build_train_op_with_pretraining(
                    batch_loss, model_config, pretrained_variable_names,
                    use_tpu)

                if use_tpu:
                    return tf.estimator.tpu.TPUEstimatorSpec(
                        mode=mode,
                        loss=batch_loss,
                        train_op=train_op,
                        scaffold_fn=scaffold_fn)
                else:
                    return tf.estimator.EstimatorSpec(mode=mode,
                                                      loss=batch_loss,
                                                      train_op=train_op)

        # Eval mode: return the loss together with the evaluation metrics.
        eval_metrics = metrics.create_metrics_ops(labels=labels,
                                                  predictions=predictions)

        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=batch_loss,
                                          eval_metric_ops=eval_metrics)

    return model_fn
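
For context, here is one way the returned closure could be wired up. This is
a hypothetical sketch: build_model_fn's signature is reconstructed above, and
every value below (model_config, the paths, the beam size, train_input_fn) is
a placeholder rather than something taken from this example.

model_fn = build_model_fn(
    model_config,
    output_vocab_filepath="/path/to/output_vocab.txt",
    use_tpu=False,
    beam_size=1,
    clean_output_vocab_path=None)
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="/tmp/model")
estimator.train(input_fn=train_input_fn, max_steps=100000)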