Exemplo n.º 1
0
    def model_fn(features, labels, mode, params):
        """Build the generator (masked-LM) graph and return its tensors.

        Unlike a regular Estimator ``model_fn`` this does not return an
        ``EstimatorSpec``; it returns a plain dict so a caller can combine
        the generator with other sub-models.

        Args:
            features: dict of input tensors. Keys read here are
                ``input_ori_ids``, ``input_mask``, ``segment_ids``,
                ``next_sentence_labels`` and ``input_ids``;
                ``masked_lm_positions`` / ``masked_lm_ids`` /
                ``masked_lm_weights`` are required only when no mask is
                sampled on the fly. ``input_ids`` is overwritten in place
                when a mask is sampled.
            labels: forwarded untouched to the encoder constructor.
            mode: a ``tf.estimator.ModeKeys`` value.
            params: unused; kept for Estimator signature compatibility.

        Returns:
            dict with the loss, trainable variables, encoder, sampled token
            ids, the (possibly replicated) inputs, and the masked-LM / NSP
            diagnostics. ``sampled_binary_mask`` is ``None`` when no mask
            was sampled on the fly.
        """
        model_api = model_zoo(model_config)

        # On-the-fly corruption of the inputs is only done while training.
        # NOTE(review): the original membership test listed ModeKeys.TRAIN
        # twice; if EVAL was meant to be included, extend this condition.
        sampled_binary_mask = None
        if (kargs.get('random_generator', '1') == '1'
                and mode == tf.estimator.ModeKeys.TRAIN):
            hmm_transitions = [
                tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                for hmm_tran_prob in hmm_tran_prob_list
            ]
            [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                model_config,
                features['input_ori_ids'],
                features['input_mask'],
                hmm_transitions,
                mask_probability=0.2,
                replace_probability=0.0,
                original_probability=0.0,
                mask_prior=tf.constant(mask_prior, tf.float32),
                **kargs)

            features['input_ids'] = output_ids
            tf.logging.info("****** do random generator *******")

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=tf.AUTO_REUSE,
             scope=generator_scope_prefix)

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            # NOTE(review): this fallback pairs the BERT position-based head
            # with the ALBERT sequence-mask head while logging "bert";
            # confirm whether pretrain.seq_mask_masked_lm_output was meant.
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

        if sampled_binary_mask is not None:
            # The mask was sampled above: score every position against the
            # uncorrupted ids.
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = seq_masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 features['input_mask'],
                 features['input_ori_ids'],
                 features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table(),
                 scope=generator_scope_prefix)
            masked_lm_ids = features['input_ori_ids']
        else:
            # Pre-computed masks are only looked up on the path that needs
            # them, so inputs without these keys still work when a mask is
            # sampled on the fly.
            masked_lm_positions = features["masked_lm_positions"]
            masked_lm_ids = features["masked_lm_ids"]
            masked_lm_weights = features["masked_lm_weights"]
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table(),
                 scope=generator_scope_prefix)
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        # The NSP loss is multiplied by zero, so it adds nothing to the value
        # of the loss.
        loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

        if kargs.get("resample_discriminator", False):
            # Draw a second, independent corruption of the inputs and run
            # the encoder on it; tokens are then sampled from that pass.
            [output_ids, sampled_binary_mask
             ] = random_input_ids_generation(model_config,
                                             features['input_ori_ids'],
                                             features['input_mask'],
                                             mask_probability=0.2,
                                             replace_probability=0.1,
                                             original_probability=0.1)

            # Shallow copy so the caller-visible features keep their ids.
            resample_features = {key: features[key] for key in features}
            resample_features['input_ids'] = tf.identity(output_ids)
            model_resample = model_api(model_config,
                                       resample_features,
                                       labels,
                                       mode,
                                       target,
                                       reuse=tf.AUTO_REUSE,
                                       **kargs)

            tf.logging.info("**** apply discriminator resample **** ")
        else:
            model_resample = model
            resample_features = features
            tf.logging.info("**** not apply discriminator resample **** ")

        sampled_ids = token_generator(
            model_config,
            model_resample.get_sequence_output(),
            model_resample.get_embedding_table(),
            resample_features['input_ids'],
            resample_features['input_ori_ids'],
            resample_features['input_mask'],
            embedding_projection=model_resample.
            get_embedding_projection_table(),
            scope=generator_scope_prefix,
            mask_method='only_mask',
            use_tpu=kargs.get('use_tpu', True))

        n_gen_samples = model_config.get('gen_sample', 1)
        if n_gen_samples == 1:
            input_ids = features['input_ori_ids']
            input_mask = features['input_mask']
            segment_ids = features['segment_ids']
        else:

            def _replicate(tensor):
                # Append a trailing axis and copy the tensor along it once
                # per generated sample.
                # NOTE(review): tf.ones defaults to float32 while the
                # features are presumably integer tensors; confirm that
                # einsum accepts this dtype mix in the TF version in use.
                expanded = tf.expand_dims(tensor, axis=-1)
                tiled = tf.einsum('abc,cd->abd', expanded,
                                  tf.ones((1, n_gen_samples)))
                return tf.cast(tiled, tf.int32)

            input_ids = _replicate(features['input_ori_ids'])

            input_shape_list = bert_utils.get_shape_list(input_ids,
                                                         expected_rank=3)
            batch_size = input_shape_list[0]
            seq_length = input_shape_list[1]
            gen_sample = input_shape_list[2]

            # Fix: the original reshapes used the undefined name `batch`
            # instead of `batch_size`.
            # NOTE(review): this reshapes a (batch, seq, sample) tensor
            # directly to (batch * sample, seq) without a transpose; verify
            # the resulting row layout is the intended one.
            flat_shape = [batch_size * gen_sample, seq_length]

            sampled_ids = tf.reshape(sampled_ids, flat_shape)
            input_ids = tf.reshape(input_ids, flat_shape)

            input_mask = _replicate(features['input_mask'])
            # Fix: the original looked up the misspelt key 'segmnet_ids'.
            segment_ids = _replicate(features['segment_ids'])

            segment_ids = tf.reshape(segment_ids, flat_shape)
            input_mask = tf.reshape(input_mask, flat_shape)

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        # The LM / NSP heads live under "<prefix>/cls/..." when a generator
        # scope prefix is configured, e.g. "generator/cls/predictions".
        if generator_scope_prefix:
            head_scope = generator_scope_prefix + "/cls/"
        else:
            head_scope = "cls/"

        lm_pretrain_tvars = model_io_fn.get_params(
            head_scope + "predictions",
            not_storage_params=not_storage_params)
        nsp_pretrain_vars = model_io_fn.get_params(
            head_scope + "seq_relationship",
            not_storage_params=not_storage_params)

        if model_config.get('embedding_scope', None) is not None:
            embedding_tvars = model_io_fn.get_params(
                model_config.get('embedding_scope', 'bert') + "/embeddings",
                not_storage_params=not_storage_params)
            pretrained_tvars.extend(embedding_tvars)

        pretrained_tvars.extend(lm_pretrain_tvars)
        pretrained_tvars.extend(nsp_pretrain_vars)
        tvars = pretrained_tvars

        print('==generator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            # The returned scaffold_fn is not part of the result dict.
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu,
                restore_var_name=model_config.get('restore_var_name', []))
        else:
            scaffold_fn = None
        tf.add_to_collection("generator_loss", loss)
        return_dict = {
            "loss": loss,
            "tvars": tvars,
            "model": model,
            "sampled_ids": sampled_ids,  # batch x gen_sample, seq_length
            "sampled_input_ids": input_ids,  # batch x gen_sample, seq_length
            "sampled_input_mask": input_mask,
            "sampled_segment_ids": segment_ids,
            "masked_lm_ids": masked_lm_ids,
            "masked_lm_weights": masked_lm_mask,
            "masked_lm_log_probs": masked_lm_log_probs,
            "masked_lm_example_loss": masked_lm_example_loss,
            "next_sentence_example_loss": nsp_per_example_loss,
            "next_sentence_log_probs": nsp_log_prob,
            "next_sentence_labels": features['next_sentence_labels'],
            "sampled_binary_mask": sampled_binary_mask
        }
        return return_dict
Exemplo n.º 2
0
    def model_fn(features, labels, mode, params):
        """Build a masked-LM TPUEstimatorSpec for TRAIN or EVAL.

        Args:
            features: dict of input tensors. When the enclosing ``target``
                is non-empty the ids are read from ``input_ids_<target>``
                and copied to the un-suffixed keys; otherwise ``input_ids``
                is used. A missing input mask is derived from the padding
                id ``kargs['[PAD]']`` (default 0) and missing segment ids
                default to zeros. ``masked_lm_positions`` /
                ``masked_lm_ids`` / ``masked_lm_weights`` are required when
                no mask is sampled on the fly. The dict is modified in
                place.
            labels: forwarded untouched to the encoder constructor.
            mode: ``tf.estimator.ModeKeys.TRAIN`` or ``EVAL``.
            params: unused; kept for Estimator signature compatibility.

        Returns:
            A ``tf.contrib.tpu.TPUEstimatorSpec`` for the requested mode.

        Raises:
            NotImplementedError: for any mode other than TRAIN or EVAL.
        """
        model_api = model_zoo(model_config)

        # Resolve the feature keys once. The original formatted
        # 'input_ids_{}' even for an empty target, and could reference an
        # unbound `input_mask` when only the segment ids were missing.
        key_suffix = '_{}'.format(target) if target else ''
        ids_key = 'input_ids' + key_suffix
        mask_key = 'input_mask' + key_suffix
        segment_key = 'segment_ids' + key_suffix

        if mask_key not in features:
            features[mask_key] = tf.cast(
                tf.not_equal(features[ids_key], kargs.get('[PAD]', 0)),
                tf.int32)
        if segment_key not in features:
            features[segment_key] = tf.zeros_like(features[mask_key])

        if target:
            # Expose the target-specific tensors under the generic names
            # used below.
            features['input_ori_ids'] = features[ids_key]
            features['input_mask'] = features[mask_key]
            features['segment_ids'] = features[segment_key]
            features['input_ids'] = features[ids_key]

        input_ori_ids = features.get('input_ori_ids', None)

        # Inputs are corrupted on the fly only while training and only when
        # the uncorrupted ids are available.
        sampled_binary_mask = None
        if (mode == tf.estimator.ModeKeys.TRAIN
                and input_ori_ids is not None):
            hmm_transitions = [
                tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                for hmm_tran_prob in hmm_tran_prob_list
            ]
            [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                model_config,
                features['input_ori_ids'],
                features['input_mask'],
                hmm_transitions,
                mask_probability=0.2,
                replace_probability=0.1,
                original_probability=0.1,
                mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
                **kargs)

            features['input_ids'] = output_ids
            tf.logging.info(
                "***** Running random sample input generation *****")

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            # NOTE(review): this fallback pairs the BERT position-based head
            # with the ALBERT sequence-mask head while logging "bert";
            # confirm whether pretrain.seq_mask_masked_lm_output was meant.
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

        if sampled_binary_mask is not None:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = seq_masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 features['input_mask'],
                 features['input_ori_ids'],
                 features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
            masked_lm_ids = input_ori_ids
        else:
            masked_lm_positions = features["masked_lm_positions"]
            masked_lm_ids = features["masked_lm_ids"]
            masked_lm_weights = features["masked_lm_weights"]

            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)

        if load_pretrained == "yes":
            scaffold_fn = model_io_fn.load_pretrained(
                pretrained_tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=1)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            # The train op is created under a control dependency on the
            # collected UPDATE_OPS.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    tvars,
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=opt_config.use_tpu)

                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)

                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            # Fix: this model_fn never builds a next-sentence head, so the
            # original reference to the NSP tensors raised NameError in EVAL
            # mode. Only masked-LM metrics are reported.
            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights):
                """Compute weighted masked-LM accuracy and mean loss."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss
                }

            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_mask
            ])

            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)

            return estimator_spec
        else:
            raise NotImplementedError()
Exemplo n.º 3
0
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        if target:
            features['input_ori_ids'] = features['input_ids_{}'.format(target)]
            features['input_ids'] = features['input_ids_{}'.format(target)]
        sequence_mask = tf.cast(
            tf.not_equal(features['input_ori_ids'], kargs.get('[PAD]', 0)),
            tf.int32)
        features['input_mask'] = sequence_mask

        seq_features = {}
        for key in features:
            seq_features[key] = features[key]
        if 'input_ori_ids' in features:
            seq_features['input_ids'] = features["input_ori_ids"]
        else:
            features['input_ori_ids'] = seq_features['input_ids']

        not_equal = tf.cast(
            tf.not_equal(features["input_ori_ids"],
                         tf.zeros_like(features["input_ori_ids"])), tf.int32)
        not_equal = tf.reduce_sum(not_equal, axis=-1)
        loss_mask = tf.cast(tf.not_equal(not_equal, tf.zeros_like(not_equal)),
                            tf.float32)

        if not kargs.get('use_tpu', False):
            tf.summary.scalar('loss_mask', tf.reduce_sum(loss_mask))

        casual_flag = model_config.get('is_casual', True)
        tf.logging.info("***** is casual flag *****", str(casual_flag))

        if not casual_flag:
            [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                model_config,
                features['input_ori_ids'],
                features['input_mask'], [
                    tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                    for hmm_tran_prob in hmm_tran_prob_list
                ],
                mask_probability=0.02,
                replace_probability=0.01,
                original_probability=0.01,
                mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
                **kargs)
            tf.logging.info("***** apply random sampling *****")
            seq_features['input_ids'] = output_ids

        model = model_api(model_config,
                          seq_features,
                          labels,
                          mode,
                          "",
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        # if mode == tf.estimator.ModeKeys.TRAIN:
        if kargs.get('mask_type', 'left2right') == 'left2right':
            tf.logging.info("***** using left2right mask and loss *****")
            sequence_mask = tf.to_float(
                tf.not_equal(features['input_ori_ids'][:, 1:],
                             kargs.get('[PAD]', 0)))
        elif kargs.get('mask_type', 'left2right') == 'seq2seq':
            tf.logging.info("***** using seq2seq mask and loss *****")
            sequence_mask = tf.to_float(features['segment_ids'][:, 1:])
            if not kargs.get('use_tpu', False):
                tf.summary.scalar("loss mask", tf.reduce_mean(sequence_mask))

        # batch x seq_length
        if casual_flag:
            print(model.get_sequence_output_logits().get_shape(),
                  "===logits shape===")
            seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=features['input_ori_ids'][:, 1:],
                logits=model.get_sequence_output_logits()[:, :-1])

            per_example_loss = tf.reduce_sum(
                seq_loss * sequence_mask,
                axis=-1) / (tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
            loss = tf.reduce_mean(per_example_loss)

            if model_config.get("cnn_type",
                                "dgcnn") in ['bi_dgcnn', 'bi_light_dgcnn']:
                seq_backward_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=features['input_ori_ids'][:, :-1],
                    logits=model.get_sequence_backward_output_logits()[:, 1:])

                per_backward_example_loss = tf.reduce_sum(
                    seq_backward_loss * sequence_mask,
                    axis=-1) / (tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
                backward_loss = tf.reduce_mean(per_backward_example_loss)
                loss += backward_loss
                tf.logging.info("***** using backward loss *****")
        else:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = pretrain.seq_mask_masked_lm_output(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 seq_features['input_mask'],
                 seq_features['input_ori_ids'],
                 seq_features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
            loss = masked_lm_loss
            tf.logging.info("***** using masked lm loss *****")
        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                pretrained_tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
                tf.logging.info(
                    "***** using tpu with tpu-captiable optimizer *****")
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0
                tf.logging.info(
                    "***** using gpu with gpu-captiable optimizer *****")

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    tvars,
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

                # train_metric_dict = train_metric(features['input_ori_ids'],
                # 								model.get_sequence_output_logits(),
                # 								seq_features,
                # 								**kargs)

                # if not kargs.get('use_tpu', False):
                # 	for key in train_metric_dict:
                # 		tf.summary.scalar(key, train_metric_dict[key])
                # 	tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)
                # 	tf.logging.info("***** logging metric *****")
                # 	tf.summary.scalar("causal_attenion_mask_length", tf.reduce_sum(sequence_mask))
                # tf.summary.scalar("bi_attenion_mask_length", tf.reduce_sum(model.bi_attention_mask))

                if kargs.get('use_tpu', False):
                    estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                        mode=mode,
                        loss=loss,
                        train_op=train_op,
                        scaffold_fn=scaffold_fn)
                else:
                    estimator_spec = tf.estimator.EstimatorSpec(
                        mode=mode, loss=loss, train_op=train_op)

                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            gpu_eval_metrics = eval_metric(features['input_ori_ids'],
                                           model.get_sequence_output_logits(),
                                           seq_features, **kargs)
            tpu_eval_metrics = (eval_metric, [
                features['input_ori_ids'],
                model.get_sequence_output_logits(), seq_features,
                kargs.get('mask_type', 'left2right')
            ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            if kargs.get('predict_type',
                         'sample_sequence') == 'sample_sequence':
                results = bert_seq_sample_utils.sample_sequence(
                    model_api,
                    model_config,
                    mode,
                    features,
                    target="",
                    start_token=kargs.get("start_token_id", 101),
                    batch_size=None,
                    context=features.get("context", None),
                    temperature=kargs.get("sample_temp", 1.0),
                    n_samples=kargs.get("n_samples", 1),
                    top_k=0,
                    end_token=kargs.get("end_token_id", 102),
                    greedy_or_sample="greedy",
                    gumbel_temp=0.01,
                    estimator="stop_gradient",
                    back_prop=True,
                    swap_memory=True,
                    seq_type=kargs.get("seq_type", "seq2seq"),
                    mask_type=kargs.get("mask_type", "seq2seq"),
                    attention_type=kargs.get('attention_type',
                                             'normal_attention'))
                # stop_gradient output:
                # samples, mask_sequence, presents, logits, final

                sampled_token = results['samples']
                sampled_token_logits = results['logits']
                mask_sequence = results['mask_sequence']

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'token': sampled_token,
                        "logits": sampled_token_logits,
                        "mask_sequence": mask_sequence
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'token':
                            sampled_token,
                            "logits":
                            sampled_token_logits,
                            "mask_sequence":
                            mask_sequence
                        })
                    })

                return estimator_spec

            elif kargs.get('predict_type',
                           'sample_sequence') == 'infer_inputs':

                sequence_mask = tf.to_float(
                    tf.not_equal(features['input_ids'][:, 1:],
                                 kargs.get('[PAD]', 0)))

                if kargs.get('mask_type', 'left2right') == 'left2right':
                    tf.logging.info(
                        "***** using left2right mask and loss *****")
                    sequence_mask = tf.to_float(
                        tf.not_equal(features['input_ori_ids'][:, 1:],
                                     kargs.get('[PAD]', 0)))
                elif kargs.get('mask_type', 'left2right') == 'seq2seq':
                    tf.logging.info("***** using seq2seq mask and loss *****")
                    sequence_mask = tf.to_float(features['segment_ids'][:, 1:])
                    if not kargs.get('use_tpu', False):
                        tf.summary.scalar("loss mask",
                                          tf.reduce_mean(sequence_mask))

                output_logits = model.get_sequence_output_logits()[:, :-1]
                # output_logits = tf.nn.log_softmax(output_logits, axis=-1)

                output_id_logits = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=features['input_ids'][:, 1:], logits=output_logits)

                per_example_perplexity = tf.reduce_sum(output_id_logits *
                                                       sequence_mask,
                                                       axis=-1)  # batch
                per_example_perplexity /= tf.reduce_sum(sequence_mask,
                                                        axis=-1)  # batch

                perplexity = tf.exp(per_example_perplexity)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'token': features['input_ids'][:, 1:],
                        "logits": output_id_logits,
                        'perplexity': perplexity,
                        # "all_logits":output_logits
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'token':
                            features['input_ids'][:, 1:],
                            "logits":
                            output_id_logits,
                            'perplexity':
                            perplexity,
                            # "all_logits":output_logits
                        })
                    })

                return estimator_spec
        else:
            raise NotImplementedError()
Exemplo n.º 4
0
    def model_fn(features, labels, mode, params):
        """Build the masked-LM graph and return its tensors as a dict.

        Relies on names from the enclosing scope that are not defined here
        (``model_config``, ``model_io_config``, ``target``, ``kargs``,
        ``hmm_tran_prob_list``, ``mask_prior``, ``load_pretrained``,
        ``init_checkpoint``, ``exclude_scope``, ``not_storage_params``).

        Args:
            features: mapping of input tensors. It is mutated in place: when
                ``target`` is truthy, ``input_ori_ids``, ``input_mask``,
                ``segment_ids`` and ``input_ids`` are (re)derived from
                ``features['input_ids_<target>']``; in TRAIN mode
                ``input_ids`` is replaced by the sampled ids.
            labels: forwarded unchanged to the model constructor.
            mode: a ``tf.estimator.ModeKeys`` value.
            params: unused by this function.

        Returns:
            dict with keys ``loss``, ``logits``, ``masked_lm_example_loss``,
            ``tvars``, ``model``, ``masked_lm_mask``, ``output_ids`` and
            ``masked_lm_ids``. ``output_ids`` is ``None`` whenever no inputs
            were sampled (non-TRAIN mode, or no ``input_ori_ids`` feature).
        """
        model_api = model_zoo(model_config)

        if target:
            # Only read the target-suffixed feature when a target is set, so
            # an empty target does not require an 'input_ids_' feature.
            target_ids = features['input_ids_{}'.format(target)]
            input_mask = tf.cast(
                tf.not_equal(target_ids, kargs.get('[PAD]', 0)), tf.int32)
            features['input_ori_ids'] = target_ids
            features['input_mask'] = input_mask
            features['segment_ids'] = tf.zeros_like(input_mask)
            features['input_ids'] = target_ids

        input_ori_ids = features.get('input_ori_ids', None)

        # Bind both names up front: they are read unconditionally further
        # down (``output_ids`` in the returned dict), but are only produced
        # by the sampling branch below.
        output_ids = None
        sampled_binary_mask = None

        if mode == tf.estimator.ModeKeys.TRAIN and input_ori_ids is not None:
            [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                model_config,
                features['input_ori_ids'],
                features['input_mask'], [
                    tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                    for hmm_tran_prob in hmm_tran_prob_list
                ],
                mask_probability=0.1,
                replace_probability=0.1,
                original_probability=0.1,
                mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
                **kargs)

            # The model consumes the sampled ids instead of the originals.
            features['input_ids'] = output_ids
            tf.logging.info(
                "***** Running random sample input generation *****")

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          "",
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            # NOTE(review): this fallback pairs the BERT positional head with
            # the ALBERT sequence-mask head while logging "bert". Looks like
            # a copy-paste slip, but behaviour is kept -- confirm intent.
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

        if sampled_binary_mask is not None:
            # Sampled path: loss is computed against the original ids using
            # the sampled binary mask.
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = seq_masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 features['input_mask'],
                 features['input_ori_ids'],
                 features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
            masked_lm_ids = input_ori_ids
        else:
            # Pre-masked path: positions/ids/weights come from the features.
            masked_lm_positions = features["masked_lm_positions"]
            masked_lm_ids = features["masked_lm_ids"]
            masked_lm_weights = features["masked_lm_weights"]
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        # Encoder variables plus the LM prediction head variables.
        tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)
        tvars.extend(
            model_io_fn.get_params("cls/predictions",
                                   not_storage_params=not_storage_params))

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            # The result is not part of the returned dict; the call is kept
            # because load_pretrained may set up checkpoint restoration as a
            # side effect -- TODO confirm.
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu,
                restore_var_name=model_config.get('restore_var_name', []))
        else:
            scaffold_fn = None

        return_dict = {
            "loss": loss,
            "logits": masked_lm_log_probs,
            "masked_lm_example_loss": masked_lm_example_loss,
            "tvars": tvars,
            "model": model,
            "masked_lm_mask": masked_lm_mask,
            "output_ids": output_ids,
            "masked_lm_ids": masked_lm_ids
        }
        return return_dict