Ejemplo n.º 1
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Runs a `BertModel` over the input features, feeds its pooled output
        (with dropout applied while training) to `get_prediction_structure`,
        and wraps the resulting logits in the spec produced by
        `rank_predict_estimator_spec`.

        Args:
            features: mapping that must contain "input_ids", "input_mask" and
                "segment_ids". Unless `modeling_opt` is "multi_label_hinge",
                the optional keys "data_id", "input_ids2" and "data_ids" are
                forwarded to the predictions when present.
            labels: unused.
            mode: a `tf.estimator.ModeKeys` value.
            params: unused.

        Returns:
            Whatever `rank_predict_estimator_spec` returns for the logits.
        """
        tf_logging.info("model_fn_sero_classification")
        log_features(features)
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model = BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        pooled_output = model.get_pooled_output()
        if is_training:
            pooled_output = dropout(pooled_output, 0.1)

        logits = get_prediction_structure(modeling_opt, pooled_output)

        # Checkpoint restoration: map trainable variables onto the checkpoint.
        tvars = tf.compat.v1.trainable_variables()
        assignment_fn = assignment_map.get_bert_assignment_map
        initialized_variable_names, init_fn = get_init_fn(
            tvars, train_config.init_checkpoint, assignment_fn)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        # Both modeling options share the same base predictions; only the
        # non-hinge option also passes through the optional identifier inputs.
        predictions = {
            "input_ids": input_ids,
            "logits": logits,
        }
        if modeling_opt != "multi_label_hinge":
            for input_name in ("data_id", "input_ids2", "data_ids"):
                if input_name in features:
                    predictions[input_name] = features[input_name]
        return rank_predict_estimator_spec(logits, mode, scaffold_fn, predictions)
Ejemplo n.º 2
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Splits the inputs into stacked windows with `split_and_append_sep`,
        runs them through `model_class` under the "sero" variable scope, and
        wraps the logits computed from the pooled output in the spec produced
        by `rank_predict_estimator_spec`.

        Args:
            features: mapping that must contain "input_ids", "input_mask" and
                "segment_ids".
            labels: unused.
            mode: a `tf.estimator.ModeKeys` value.
            params: unused.

        Returns:
            Whatever `rank_predict_estimator_spec` returns for the logits.
        """
        tf_logging.info("model_fn_sero_ranking_predict")
        log_features(features)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        # use_context is all ones, one entry per example in the batch.
        batch_size, _ = get_shape_list(input_mask)
        use_context = tf.ones([batch_size, 1], tf.int32)

        (stacked_input_ids,
         stacked_input_mask,
         stacked_segment_ids) = split_and_append_sep(
             input_ids, input_mask, segment_ids,
             config.total_sequence_length, config.window_size, CLS_ID, EOW_ID)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        with tf.compat.v1.variable_scope("sero"):
            model = model_class(config, is_training,
                                train_config.use_one_hot_embeddings)
            model.network_stacked(stacked_input_ids, stacked_input_mask,
                                  stacked_segment_ids, use_context)

        pooled_output = model.get_pooled_output()
        logits = get_prediction_structure(config.loss, pooled_output)

        # Checkpoint restoration: map trainable variables onto the checkpoint.
        tvars = tf.compat.v1.trainable_variables()
        assignment_fn = assignment_map.assignment_map_v2_to_v2
        initialized_variable_names, init_fn = get_init_fn(
            tvars, train_config.init_checkpoint, assignment_fn)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)
        return rank_predict_estimator_spec(logits, mode, scaffold_fn)
Ejemplo n.º 3
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Builds an `APR` model over the input features, attaches a 3-way
        `Classification` task to its pooled output and returns a
        `TPUEstimatorSpec` matching `mode`.

        Args:
            features: mapping that must contain "input_ids", "input_mask" and
                "segment_ids"; it is also handed to `Classification`.
            labels: unused.
            mode: a `tf.estimator.ModeKeys` value.
            params: unused.

        Returns:
            A `TPUEstimatorSpec` for TRAIN, EVAL or PREDICT, or None when
            `mode` is none of those.
        """
        log_features(features)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model = APR(input_ids, input_mask, segment_ids, is_training,
                    train_config.use_one_hot_embeddings, bert_config,
                    ssdr_config, dict_run_config.def_per_batch,
                    dict_run_config.inner_batch_size,
                    dict_run_config.max_def_length)

        task = Classification(3, features, model.get_pooled_output(),
                              is_training)
        loss = task.loss

        # Checkpoint restoration: map trainable variables onto the checkpoint.
        tvars = tf.compat.v1.trainable_variables()
        assignment_fn = tlm.training.assignment_map.get_assignment_map_as_is
        initialized_variable_names, init_fn = get_init_fn(
            tvars, train_config.init_checkpoint, assignment_fn)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        output_spec = None
        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        if mode == tf.estimator.ModeKeys.TRAIN:
            if ssdr_config.compare_attrib_value_safe("use_two_lr", True):
                tf_logging.info("Using two lr for each parts")
                train_op = create_optimizer_with_separate_lr(
                    loss, train_config)
            else:
                tf_logging.info("Using single lr ")
                train_op = optimization.create_optimizer_from_config(
                    loss, train_config)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            # The spec must carry the estimator mode, not the model object.
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=task.eval_metrics(),
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           predictions={"loss": task.loss_arr},
                                           scaffold_fn=scaffold_fn)
        return output_spec
Ejemplo n.º 4
0
def checkpoint_init(assignment_fn, train_config):
    """Prepare checkpoint initialization for the current trainable variables.

    Collects the graph's trainable variables, asks `get_init_fn` for an
    initializer based on `train_config.init_checkpoint` and `assignment_fn`,
    wraps it with `get_tpu_scaffold_or_init`, and logs the variable
    assignments.

    Args:
        assignment_fn: passed through to `get_init_fn`.
        train_config: object providing `init_checkpoint` and `use_tpu`.

    Returns:
        The value returned by `get_tpu_scaffold_or_init`.
    """
    trainable_vars = tf.compat.v1.trainable_variables()
    init_names, init_fn = get_init_fn(
        trainable_vars,
        train_config.init_checkpoint,
        assignment_fn,
    )
    scaffold = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(trainable_vars, init_names)
    return scaffold
Ejemplo n.º 5
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Jointly builds a masked-LM head and a 3-way classification head on top
        of the model returned by `sharing_model_factory`. It also computes
        `h_overlap`, the absolute product of the masked-LM-loss gradients and
        the class-1/2 probability gradients with respect to every layer
        output, which is exposed in PREDICT mode.

        Args:
            features: mapping that must contain "input_ids", "input_mask" and
                "segment_ids". When "nli_input_ids" is absent the LM inputs
                are reused for the classification side and "label_ids" is
                set to all ones (this mutates `features`).
            labels: unused.
            mode: a `tf.estimator.ModeKeys` value.
            params: unused.

        Returns:
            A `TPUEstimatorSpec` for the given mode.
        """
        tf_logging.info("model_fn_nli_lm")
        log_features(features)

        input_ids = features["input_ids"]  # [batch_size, seq_length]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        batch_size, seq_max = get_shape_list2(input_ids)
        if "nli_input_ids" in features:
            nli_input_ids = features["nli_input_ids"]  # [batch_size, seq_length]
            nli_input_mask = features["nli_input_mask"]
            nli_segment_ids = features["nli_segment_ids"]
        else:
            nli_input_ids = input_ids
            nli_input_mask = input_mask
            nli_segment_ids = segment_ids
            features["label_ids"] = tf.ones([batch_size], tf.int32)

        # Fix the seed in PREDICT mode so the random masking is reproducible.
        if mode == tf.estimator.ModeKeys.PREDICT:
            tf.random.set_seed(0)
            seed = 0
        else:
            seed = None

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        tf_logging.info("Doing dynamic masking (random)")

        (masked_input_ids, masked_lm_positions,
         masked_lm_ids, masked_lm_weights) = random_masking(
             input_ids, input_mask,
             train_config.max_predictions_per_seq, MASK_ID, seed)

        sharing_model = sharing_model_factory(
            config, train_config.use_one_hot_embeddings, is_training,
            masked_input_ids, input_mask, segment_ids, nli_input_ids,
            nli_input_mask, nli_segment_ids)

        sequence_output_lm = sharing_model.lm_sequence_output()
        nli_feature = sharing_model.get_tt_feature()

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = get_masked_lm_output(
             config, sequence_output_lm, sharing_model.get_embedding_table(),
             masked_lm_positions, masked_lm_ids, masked_lm_weights)

        masked_lm_log_probs = tf.reshape(masked_lm_log_probs, [batch_size, -1])

        masked_lm_per_inst_loss = tf.reshape(masked_lm_example_loss,
                                             [batch_size, -1])

        task = Classification(3, features, nli_feature, is_training)
        nli_loss = task.loss

        task_prob = tf.nn.softmax(task.logits, axis=-1)
        arg_like = task_prob[:, 1] + task_prob[:, 2]

        # Gradients of both objectives with respect to every layer output.
        layer_outputs = sharing_model.model.all_layer_outputs
        lm_grads = tf.gradients(ys=masked_lm_loss, xs=layer_outputs)
        nli_grads = tf.gradients(ys=arg_like, xs=layer_outputs)
        overlaps = []
        for lm_grad, nli_grad in zip(lm_grads, nli_grads):
            if lm_grad is not None and nli_grad is not None:
                # Each gradient is viewed as [2 * batch_size, seq_max, -1]:
                # the first batch_size rows are used for the LM side and the
                # remaining rows for the classification side.
                # NOTE(review): presumably the shared model concatenates the
                # LM and NLI inputs along the batch axis - confirm.
                lm_part = tf.reshape(
                    lm_grad, [batch_size * 2, seq_max, -1])[:batch_size]
                # NOTE(review): relies on broadcasting against the
                # [batch_size, -1] per-instance loss - confirm shapes.
                lm_part = lm_part / masked_lm_per_inst_loss
                nli_part = tf.reshape(
                    nli_grad, [batch_size * 2, seq_max, -1])[batch_size:]
                overlaps.append(tf.abs(lm_part * nli_part))
        h_overlap = tf.stack(overlaps, axis=1)
        h_overlap = tf.reduce_sum(h_overlap, axis=2)

        loss = combine_loss_fn(masked_lm_loss, nli_loss)

        # Checkpoint restoration: map trainable variables onto the checkpoint.
        tvars = tf.compat.v1.trainable_variables()
        assignment_fn = get_bert_assignment_map
        initialized_variable_names, init_fn = get_init_fn(
            tvars, train_config.init_checkpoint, assignment_fn)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer_from_config(
                loss, train_config)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (metric_fn_lm, [
                masked_lm_example_loss,
                masked_lm_log_probs,
                masked_lm_ids,
                masked_lm_weights,
            ])
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "input_ids": input_ids,
                "masked_input_ids": masked_input_ids,
                "masked_lm_ids": masked_lm_ids,
                "masked_lm_example_loss": masked_lm_example_loss,
                "masked_lm_positions": masked_lm_positions,
                "masked_lm_log_probs": masked_lm_log_probs,
                "h_overlap": h_overlap,
            }
            output_spec = TPUEstimatorSpec(mode=mode,
                                           predictions=predictions,
                                           scaffold_fn=scaffold_fn)

        return output_spec
Ejemplo n.º 6
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Jointly builds a masked-LM head and a 3-way classification head on top
        of the model returned by `sharing_model_factory`. It also computes an
        `overlap_score` via `shared_gradient_fine_grained`, which is exposed
        in PREDICT mode.

        Args:
            features: mapping that must contain "input_ids", "input_mask" and
                "segment_ids". When "nli_input_ids" is absent the LM inputs
                are reused for the classification side and "label_ids" is
                set to all ones (this mutates `features`).
            labels: unused.
            mode: a `tf.estimator.ModeKeys` value.
            params: unused.

        Returns:
            A `TPUEstimatorSpec` for the given mode.
        """
        tf_logging.info("model_fn_nli_lm")
        log_features(features)

        input_ids = features["input_ids"]  # [batch_size, seq_length]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        batch_size, _ = get_shape_list2(input_ids)
        if "nli_input_ids" in features:
            nli_input_ids = features["nli_input_ids"]  # [batch_size, seq_length]
            nli_input_mask = features["nli_input_mask"]
            nli_segment_ids = features["nli_segment_ids"]
        else:
            nli_input_ids = input_ids
            nli_input_mask = input_mask
            nli_segment_ids = segment_ids
            features["label_ids"] = tf.ones([batch_size], tf.int32)

        # Fix the seed in PREDICT mode so the random masking is reproducible.
        if mode == tf.estimator.ModeKeys.PREDICT:
            tf.random.set_seed(0)
            seed = 0
        else:
            seed = None

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        tf_logging.info("Doing dynamic masking (random)")

        (masked_input_ids, masked_lm_positions,
         masked_lm_ids, masked_lm_weights) = random_masking(
             input_ids, input_mask,
             train_config.max_predictions_per_seq, MASK_ID, seed)

        sharing_model = sharing_model_factory(
            config, train_config.use_one_hot_embeddings, is_training,
            masked_input_ids, input_mask, segment_ids, nli_input_ids,
            nli_input_mask, nli_segment_ids)

        sequence_output_lm = sharing_model.lm_sequence_output()
        nli_feature = sharing_model.get_tt_feature()

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = get_masked_lm_output(
             config, sequence_output_lm, sharing_model.get_embedding_table(),
             masked_lm_positions, masked_lm_ids, masked_lm_weights)

        masked_lm_log_probs = tf.reshape(masked_lm_log_probs, [batch_size, -1])

        # "top_guess" is exported as the reshaped log-probabilities themselves.
        top_guess = masked_lm_log_probs

        task = Classification(3, features, nli_feature, is_training)
        nli_loss = task.loss

        overlap_score = shared_gradient_fine_grained(
            masked_lm_example_loss, task.logits,
            train_config.max_predictions_per_seq)
        loss = combine_loss_fn(masked_lm_loss, nli_loss)

        # Checkpoint restoration: map trainable variables onto the checkpoint.
        tvars = tf.compat.v1.trainable_variables()
        assignment_fn = get_bert_assignment_map
        initialized_variable_names, init_fn = get_init_fn(
            tvars, train_config.init_checkpoint, assignment_fn)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer_from_config(
                loss, train_config)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (metric_fn_lm, [
                masked_lm_example_loss,
                masked_lm_log_probs,
                masked_lm_ids,
                masked_lm_weights,
            ])
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "input_ids": input_ids,
                "masked_input_ids": masked_input_ids,
                "masked_lm_ids": masked_lm_ids,
                "masked_lm_example_loss": masked_lm_example_loss,
                "masked_lm_positions": masked_lm_positions,
                "masked_lm_log_probs": masked_lm_log_probs,
                "overlap_score": overlap_score,
                "top_guess": top_guess,
            }
            output_spec = TPUEstimatorSpec(mode=mode,
                                           predictions=predictions,
                                           scaffold_fn=scaffold_fn)

        return output_spec
Ejemplo n.º 7
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Prepends the "word" tokens to the raw input sequence, applies random
        masking, and trains/evaluates a `BertModel` with a masked-LM head.

        Args:
            features: mapping that must contain "input_ids", "input_mask",
                "segment_ids" and "word".
            labels: unused.
            mode: a `tf.estimator.ModeKeys` value.
            params: unused.

        Returns:
            A `TPUEstimatorSpec` for the given mode.
        """
        tf_logging.info("model_fn_apr_lm")
        log_features(features)

        raw_input_ids = features["input_ids"]  # [batch_size, seq_length]
        raw_input_mask = features["input_mask"]
        raw_segment_ids = features["segment_ids"]

        # The word tokens get a mask derived from non-zero ids and an
        # all-ones segment id.
        word_tokens = features["word"]
        word_input_mask = tf.cast(tf.not_equal(word_tokens, 0), tf.int32)
        word_segment_ids = tf.ones_like(word_tokens, tf.int32)

        # Fix the seed in PREDICT mode so the random masking is reproducible.
        if mode == tf.estimator.ModeKeys.PREDICT:
            tf.random.set_seed(0)
            seed = 0
        else:
            seed = None

        input_ids = tf.concat([word_tokens, raw_input_ids], axis=1)
        input_mask = tf.concat([word_input_mask, raw_input_mask], axis=1)
        segment_ids = tf.concat([word_segment_ids, raw_segment_ids], axis=1)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        tf_logging.info("Using masked_input_ids")
        (masked_input_ids, masked_lm_positions,
         masked_lm_ids, masked_lm_weights) = random_masking(
             input_ids, input_mask,
             train_config.max_predictions_per_seq, MASK_ID, seed)

        model = BertModel(
            config=config,
            is_training=is_training,
            input_ids=masked_input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = get_masked_lm_output(
             config, model.get_sequence_output(), model.get_embedding_table(),
             masked_lm_positions, masked_lm_ids, masked_lm_weights)

        loss = masked_lm_loss

        # Checkpoint restoration: map trainable variables onto the checkpoint.
        tvars = tf.compat.v1.trainable_variables()
        assignment_fn = tlm.training.assignment_map.get_bert_assignment_map
        initialized_variable_names, init_fn = get_init_fn(
            tvars, train_config.init_checkpoint, assignment_fn)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_logging.info("Using single lr ")
            train_op = optimization.create_optimizer_from_config(
                loss, train_config)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (metric_fn_lm, [
                masked_lm_example_loss,
                masked_lm_log_probs,
                masked_lm_ids,
                masked_lm_weights,
            ])
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "input_ids": input_ids,
                "masked_input_ids": masked_input_ids,
                "masked_lm_ids": masked_lm_ids,
                "masked_lm_example_loss": masked_lm_example_loss,
                "masked_lm_positions": masked_lm_positions,
            }
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           predictions=predictions,
                                           scaffold_fn=scaffold_fn)

        return output_spec
Ejemplo n.º 8
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Builds a masked-LM objective on top of a stacked-window model selected
        by the enclosing `modeling` setting. When `prediction_op` is
        "gradient_to_long_context", the predictions are per-example mean
        absolute gradients of the LM loss with respect to each upper-module
        input, split at `config.window_size` into "main" and "context" parts.

        Args:
            features: mapping that must contain "input_ids", "input_mask" and
                "segment_ids", plus "use_context" when `modeling` contains
                "sero".
            labels: unused.
            mode: a `tf.estimator.ModeKeys` value.
            params: unused.

        Returns:
            A `TPUEstimatorSpec` for the given mode.

        Raises:
            AssertionError: if `modeling` neither contains "sero" nor equals
                "bert".
            ValueError: if no model class is defined for `modeling`.
        """
        tf_logging.info("model_fn_sero_lm")
        log_features(features)

        input_ids = features["input_ids"]  # [batch_size, seq_length]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        is_sero_modeling = "sero" in modeling
        if is_sero_modeling:
            use_context = features["use_context"]
        elif modeling == "bert":
            batch_size, _ = get_shape_list(input_mask)
            use_context = tf.ones([batch_size, 1], tf.int32)
        else:
            assert False

        # Fix the seed in PREDICT mode so the random masking is reproducible.
        if mode == tf.estimator.ModeKeys.PREDICT:
            tf.random.set_seed(0)
            seed = 0
        else:
            seed = None

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        tf_logging.info("Using masked_input_ids")
        if is_sero_modeling:
            (stacked_input_ids,
             stacked_input_mask,
             stacked_segment_ids) = split_and_append_sep(
                 input_ids, input_mask, segment_ids,
                 config.total_sequence_length, config.window_size,
                 CLS_ID, EOW_ID)
            # Masking operates on 2D tensors, so flatten the stacked windows.
            input_ids_2d = r3to2(stacked_input_ids)
            input_mask_2d = r3to2(stacked_input_mask)
        elif modeling == "bert":
            stacked_input_ids, stacked_input_mask, stacked_segment_ids = (
                input_ids, input_mask, segment_ids)
            input_ids_2d = stacked_input_ids
            input_mask_2d = stacked_input_mask
        else:
            assert False

        tf_logging.info("Doing dynamic masking (random)")

        # TODO make stacked_input_ids 2D and recover
        (masked_input_ids_2d, masked_lm_positions_2d,
         masked_lm_ids_2d, masked_lm_weights_2d) = random_masking(
             input_ids_2d, input_mask_2d,
             train_config.max_predictions_per_seq, MASK_ID, seed, [EOW_ID])

        if is_sero_modeling:
            # Restore the stacked shape after masking.
            masked_input_ids = tf.reshape(masked_input_ids_2d,
                                          stacked_input_ids.shape)
        elif modeling == "bert":
            # Insert a window axis of size one at position 1.
            masked_input_ids = tf.expand_dims(masked_input_ids_2d, 1)
            stacked_input_mask = tf.expand_dims(stacked_input_mask, 1)
            stacked_segment_ids = tf.expand_dims(stacked_segment_ids, 1)
        else:
            assert False

        if modeling == "sero":
            model_class = SeroDelta
        elif modeling == "sero_epsilon":
            model_class = SeroEpsilon
        else:
            # Without this branch any other value (including "bert") would
            # reach the constructor call below with `model_class` unbound.
            raise ValueError(
                "No model class is defined for modeling={!r}".format(modeling))

        with tf.compat.v1.variable_scope("sero"):
            model = model_class(config, is_training,
                                train_config.use_one_hot_embeddings)
            sequence_output_3d = model.network_stacked(masked_input_ids,
                                                       stacked_input_mask,
                                                       stacked_segment_ids,
                                                       use_context)
        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = get_masked_lm_output(
             config, sequence_output_3d, model.get_embedding_table(),
             masked_lm_positions_2d, masked_lm_ids_2d, masked_lm_weights_2d)

        predictions = None
        if prediction_op == "gradient_to_long_context":
            predictions = {}
            for idx, input_tensor in enumerate(model.upper_module_inputs):
                g = tf.abs(tf.gradients(ys=masked_lm_loss, xs=input_tensor)[0])
                # Split along axis 1 at window_size, then average each part
                # over the last two axes to get one value per example.
                main_g = g[:, :config.window_size, :]
                context_g = g[:, config.window_size:, :]
                main_g = tf.reduce_mean(tf.reduce_mean(main_g, axis=2), axis=1)
                context_g = tf.reduce_mean(tf.reduce_mean(context_g, axis=2),
                                           axis=1)
                predictions['main_g_{}'.format(idx)] = main_g
                predictions['context_g_{}'.format(idx)] = context_g

        loss = masked_lm_loss

        # Checkpoint restoration: map trainable variables onto the checkpoint.
        tvars = tf.compat.v1.trainable_variables()
        if train_config.init_checkpoint:
            assignment_fn = get_assignment_map_from_checkpoint_type(
                train_config.checkpoint_type, config.lower_layers)
        else:
            assignment_fn = None
        initialized_variable_names, init_fn = get_init_fn(
            tvars, train_config.init_checkpoint, assignment_fn)
        log_var_assignments(tvars, initialized_variable_names)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer_from_config(
                loss, train_config)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           training_hooks=[OomReportingHook()],
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            # The spec must carry the estimator mode, not the model object.
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=None,
                                           scaffold_fn=scaffold_fn)
        else:
            if predictions is None:
                predictions = {
                    "input_ids": input_ids,
                    "masked_input_ids": masked_input_ids,
                    "masked_lm_ids": masked_lm_ids_2d,
                    "masked_lm_example_loss": masked_lm_example_loss,
                    "masked_lm_positions": masked_lm_positions_2d,
                }
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           predictions=predictions,
                                           scaffold_fn=scaffold_fn)

        return output_spec
Ejemplo n.º 9
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Runs the Sero model chosen by the enclosing `modeling` setting over
        the inputs (with a window axis of size one inserted at position 1),
        pools the first token through a tanh dense layer, and attaches a
        3-way `Classification` task.

        Args:
            features: mapping that must contain "input_ids", "input_mask" and
                "segment_ids"; it is also handed to `Classification`.
            labels: unused.
            mode: a `tf.estimator.ModeKeys` value.
            params: unused.

        Returns:
            A `TPUEstimatorSpec` for the given mode.

        Raises:
            AssertionError: if `modeling` is neither "sero" nor
                "sero_epsilon".
        """
        tf_logging.info("model_fn_sero_classification")
        log_features(features)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        # use_context is all ones, one entry per example in the batch.
        batch_size, _ = get_shape_list(input_mask)
        use_context = tf.ones([batch_size, 1], tf.int32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        if modeling == "sero":
            model_class = SeroDelta
            print("Using SeroDelta")
        elif modeling == "sero_epsilon":
            model_class = SeroEpsilon
            print("Using SeroEpsilon")
        else:
            assert False

        with tf.compat.v1.variable_scope("sero"):
            model = model_class(config, is_training,
                                train_config.use_one_hot_embeddings)
            # NOTE: `input_ids` is rebound to the expanded tensor here, so the
            # PREDICT output below reports the expanded version.
            input_ids = tf.expand_dims(input_ids, 1)
            input_mask = tf.expand_dims(input_mask, 1)
            segment_ids = tf.expand_dims(segment_ids, 1)
            sequence_output = model.network_stacked(input_ids, input_mask,
                                                    segment_ids, use_context)

        # Pool by taking the first position along axis 1 and projecting it.
        first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1)
        pooled_output = tf.keras.layers.Dense(
            config.hidden_size,
            activation=tf.keras.activations.tanh,
            kernel_initializer=create_initializer(
                config.initializer_range))(first_token_tensor)

        if "bias_loss" in special_flags:
            loss_weighting = reweight_zero
        else:
            loss_weighting = None

        task = Classification(3, features, pooled_output, is_training,
                              loss_weighting)
        loss = task.loss

        # Checkpoint restoration: map trainable variables onto the checkpoint.
        tvars = tf.compat.v1.trainable_variables()
        assignment_fn = assignment_map.assignment_map_v2_to_v2
        initialized_variable_names, init_fn = get_init_fn(
            tvars, train_config.init_checkpoint, assignment_fn)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_logging.info("Using single lr ")
            train_op = optimization.create_optimizer_from_config(
                loss, train_config)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            # The spec must carry the estimator mode, not the model object.
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=task.eval_metrics(),
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            predictions = {"input_ids": input_ids, "logits": task.logits}
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           predictions=predictions,
                                           scaffold_fn=scaffold_fn)
        return output_spec