Code Example #1
def joint_softmax_classifier(mode, hparams, model_config, inputs, targets,
                             num_labels, tokens_to_keep, joint_maps,
                             transition_params):

    with tf.name_scope('joint_softmax_classifier'):

        # todo pass this as initial proj dim (which is optional)
        projection_dim = model_config['predicate_pred_mlp_size']

        with tf.variable_scope('MLP'):
            mlp = nn_utils.MLP(inputs,
                               projection_dim,
                               keep_prob=hparams.mlp_dropout,
                               n_splits=1)
        with tf.variable_scope('Classifier'):
            logits = nn_utils.MLP(mlp,
                                  num_labels,
                                  keep_prob=hparams.mlp_dropout,
                                  n_splits=1)

        # todo implement this
        if transition_params is not None:
            print(
                'Transition params not yet supported in joint_softmax_classifier'
            )
            exit(1)

        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=targets)

        cross_entropy *= tokens_to_keep
        loss = tf.reduce_sum(cross_entropy) / tf.reduce_sum(tokens_to_keep)

        predictions = tf.cast(tf.argmax(logits, axis=-1), tf.int32)

        output = {
            'loss': loss,
            'predictions': predictions,
            'scores': logits,
            'probabilities': tf.nn.softmax(logits, -1)
        }

        # now get separate-task scores and predictions for each of the maps we've passed through
        separate_output = get_separate_scores_preds_from_joint(
            output, joint_maps, num_labels)
        combined_output = {**output, **separate_output}

        return combined_output
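
Note: the masked loss above is worth seeing in isolation. A minimal NumPy sketch (not from the project) of the same normalization: per-token losses are zeroed at padding positions, then averaged over real tokens only.

import numpy as np

per_token_loss = np.array([[0.5, 1.0, 2.0],
                           [0.3, 0.7, 0.0]])  # batch_size x seq_len
tokens_to_keep = np.array([[1., 1., 1.],
                           [1., 1., 0.]])     # 0 marks a padding token

masked = per_token_loss * tokens_to_keep
loss = masked.sum() / tokens_to_keep.sum()    # mean over the 5 real tokens
print(loss)  # ~0.9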
Code Example #2
File: output_fns.py Project: schilama/Fake-News
def softmax_classifier(mode, hparams, model_config, inputs, targets,
                       num_labels, tokens_to_keep, transition_params):

    with tf.name_scope('softmax_classifier'):

        # # todo pass this as initial proj dim (which is optional)
        # projection_dim = model_config['predicate_pred_mlp_size']
        #
        # with tf.variable_scope('MLP'):
        #   mlp = nn_utils.MLP(inputs, projection_dim, keep_prob=hparams.mlp_dropout, n_splits=1)
        with tf.variable_scope('Classifier'):
            logits = nn_utils.MLP(inputs,
                                  num_labels,
                                  keep_prob=hparams.mlp_dropout,
                                  n_splits=1)

        # todo implement this
        if transition_params is not None:
            print('Transition params not yet supported in softmax_classifier')
            exit(1)

        # cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=targets)
        targets_onehot = tf.one_hot(indices=targets, depth=num_labels, axis=-1)
        loss = tf.losses.softmax_cross_entropy(
            logits=tf.reshape(logits, [-1, num_labels]),
            onehot_labels=tf.reshape(targets_onehot, [-1, num_labels]),
            weights=tf.reshape(tokens_to_keep, [-1]),
            label_smoothing=hparams.label_smoothing,
            reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

        predictions = tf.cast(tf.argmax(logits, axis=-1), tf.int32)

        output = {'loss': loss, 'predictions': predictions, 'scores': logits}

    return output
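
Note: tf.losses.softmax_cross_entropy applies label smoothing before computing the loss. A minimal NumPy sketch (assuming the standard smoothing formula, onehot * (1 - eps) + eps / num_classes) of how the one-hot targets above are softened:

import numpy as np

num_labels = 4
epsilon = 0.1                   # hparams.label_smoothing
onehot = np.eye(num_labels)[2]  # hard target for class 2
smoothed = onehot * (1.0 - epsilon) + epsilon / num_labels
print(smoothed)                 # [0.025 0.025 0.925 0.025]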
Code Example #3
File: output_fns.py Project: schilama/Fake-News
def fakenews_maxpool(mode, hparams, model_config, inputs, targets, num_labels,
                     tokens_to_keep, transition_params):
    # Apply masking to inputs (shape = batch_size * seq_len * hidden_dim)
    inputs *= tf.expand_dims(tokens_to_keep, axis=-1)
    # TODO: Test performance when we mask pad tokens with constants.VERY_SMALL instead of 0
    # targets has shape batch_size * seq_len and every token carries the same
    # stance label, so we just need the first one
    actual_targets = targets[:, 0]

    with tf.variable_scope('fakenews_maxpool'):
        out_maxpool = tf.reduce_max(inputs, axis=1)  # batch_size * hidden_dim
        # TODO: Is dropout necessary here?
        logits = nn_utils.MLP(out_maxpool,
                              num_labels,
                              keep_prob=hparams.mlp_dropout,
                              n_splits=1)  # batch_size, num_classes
        loss = tf.losses.sparse_softmax_cross_entropy(logits=logits,
                                                      labels=actual_targets)
        predictions = tf.argmax(logits, -1)

        output = {
            'loss': loss,
            'scores': logits,
            'predictions': predictions,
        }

    return output
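
Note: the TODO above flags a real caveat with zero-masking before max-pooling. A minimal NumPy sketch (with an assumed large negative value standing in for constants.VERY_SMALL) of why: if every real feature is negative, a zeroed pad position wins the max, while a very-small mask preserves the true maximum.

import numpy as np

feats = np.array([-0.5, -1.2, 0.0])  # last position is padding
mask = np.array([1., 1., 0.])

zero_masked = feats * mask
print(zero_masked.max())             # 0.0 -- the pad token wins

VERY_SMALL = -1e9                    # assumed stand-in for constants.VERY_SMALL
small_masked = np.where(mask > 0, feats, VERY_SMALL)
print(small_masked.max())            # -0.5 -- the true max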
Code Example #4
File: output_fns.py Project: schilama/Fake-News
def parse_bilinear(mode, hparams, model_config, inputs, targets, num_labels,
                   tokens_to_keep, transition_params):
    class_mlp_size = model_config['class_mlp_size']
    attn_mlp_size = model_config['attn_mlp_size']

    if transition_params is not None:
        print('Transition params not supported in parse_bilinear')
        exit(1)

    with tf.variable_scope('parse_bilinear'):
        with tf.variable_scope('MLP'):
            dep_mlp, head_mlp = nn_utils.MLP(inputs,
                                             class_mlp_size + attn_mlp_size,
                                             n_splits=2,
                                             keep_prob=hparams.mlp_dropout)
            dep_arc_mlp, dep_rel_mlp = (dep_mlp[:, :, :attn_mlp_size],
                                        dep_mlp[:, :, attn_mlp_size:])
            head_arc_mlp, head_rel_mlp = (head_mlp[:, :, :attn_mlp_size],
                                          head_mlp[:, :, attn_mlp_size:])

        with tf.variable_scope('Arcs'):
            # batch_size x batch_seq_len x batch_seq_len
            arc_logits = nn_utils.bilinear_classifier(dep_arc_mlp,
                                                      head_arc_mlp,
                                                      hparams.bilinear_dropout)

        num_tokens = tf.reduce_sum(tokens_to_keep)

        predictions = tf.argmax(arc_logits, -1)
        probabilities = tf.nn.softmax(arc_logits)

        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=arc_logits, labels=targets)
        loss = tf.reduce_sum(cross_entropy * tokens_to_keep) / num_tokens

        output = {
            'loss': loss,
            'predictions': predictions,
            'probabilities': probabilities,
            'scores': arc_logits,
            'dep_rel_mlp': dep_rel_mlp,
            'head_rel_mlp': head_rel_mlp
        }

    return output
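
Note: nn_utils.bilinear_classifier is not shown here, so as a rough sketch (a plain bilinear form, ignoring any bias terms and the dropout the real classifier applies), the arc scores pair every dependent token with every candidate head:

import numpy as np

batch, seq_len, dim = 2, 5, 8
dep = np.random.randn(batch, seq_len, dim)   # dep_arc_mlp
head = np.random.randn(batch, seq_len, dim)  # head_arc_mlp
U = np.random.randn(dim, dim)                # learned bilinear weights

# arc_logits[b, i, j] = dep[b, i] @ U @ head[b, j]
arc_logits = np.einsum('bid,de,bje->bij', dep, U, head)
print(arc_logits.shape)                      # (2, 5, 5): batch x seq_len x seq_len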
Code Example #5
File: model.py Project: ChristoMartin/LISA
    def model_fn(self, features, mode):

        # todo can estimators handle dropout for us or do we need to do it on our own?
        hparams = self.hparams(mode)
        tf.logging.log(tf.logging.INFO, "Running in {} mode.".format(mode))

        with tf.variable_scope("LISA", reuse=tf.AUTO_REUSE):
            # features = tf.Print(features, [features, tf.shape(features)], 'input features')
            batch_shape = tf.shape(features)
            batch_size = batch_shape[0]
            batch_seq_len = batch_shape[1]
            layer_config = self.model_config['layers']
            sa_hidden_size = layer_config['head_dim'] * layer_config['num_heads']

            feats = {
                f: features[:, :, idx]
                for f, idx in self.feature_idx_map.items()
            }

            # todo this assumes that word_type is always passed in
            words = feats['word_type']
            # print("debug <input features>:", features)

            # for masking out padding tokens
            tokens_to_keep = tf.where(tf.equal(words, constants.PAD_VALUE),
                                      tf.zeros([batch_size, batch_seq_len]),
                                      tf.ones([batch_size, batch_seq_len]))

            # Extract named features from monolithic "features" input
            feats = {
                f: tf.multiply(tf.cast(tokens_to_keep, tf.int32), v)
                for f, v in feats.items()
            }
            # feats = {f: tf.Print(feats[f], [feats[f]]) for f in feats.keys()}
            # print("<debug features>: ",feats)
            # print("debug <model_config>:", self.model_config)

            # Extract named labels from monolithic "features" input, and mask them
            # todo fix masking -- is it even necessary?
            labels = {}
            for l, idx in self.label_idx_map.items():
                # print("debug <label_idx_map idx>: ", idx)
                these_labels = (features[:, :, idx[0]:idx[0] + 1]
                                if idx[1] != -1 else features[:, :, idx[0]:])
                these_labels_masked = tf.multiply(
                    these_labels,
                    tf.cast(tf.expand_dims(tokens_to_keep, -1), tf.int32))
                # check if we need to mask another dimension
                # these_labels_masked_print = tf.Print(these_labels_masked, [tf.shape(these_labels_masked), these_labels_masked, idx],
                #                                'thses labels masked')
                if idx[1] == -1:
                    last_dim = tf.shape(these_labels)[2]
                    this_mask = tf.where(
                        tf.equal(these_labels_masked, constants.PAD_VALUE),
                        tf.zeros([batch_size, batch_seq_len, last_dim],
                                 dtype=tf.int32),
                        tf.ones([batch_size, batch_seq_len, last_dim],
                                dtype=tf.int32))
                    these_labels_masked = tf.multiply(these_labels_masked,
                                                      this_mask)
                else:
                    these_labels_masked = tf.squeeze(
                        these_labels_masked,
                        -1,
                        name='these_labels_masked_squeezing')
                    # these_labels_masked = tf.Print(these_labels_masked, [tf.shape(these_labels_masked), these_labels_masked], 'thses labels masked after squeezed')
                labels[l] = these_labels_masked
                # labels = [tf.Print(l, [l]) for l in labels]

            # labels = [tf.Print("label_l", l)for l in labels]
            # load transition parameters
            transition_stats = util.load_transition_params(
                self.task_config, self.vocab, hparams.train_with_crf)

            # Create embeddings tables, loading pre-trained if specified

            embeddings = {}
            for embedding_name, embedding_map in self.model_config[
                    'embeddings'].items():
                embedding_dim = embedding_map['embedding_dim']
                if 'pretrained_embeddings' in embedding_map:
                    input_pretrained_embeddings = embedding_map[
                        'pretrained_embeddings']
                    include_oov = True
                    embedding_table = self.get_embedding_table(
                        embedding_name,
                        embedding_dim,
                        include_oov,
                        pretrained_fname=input_pretrained_embeddings,
                        cwr_ood=hparams.cwr_ood)
                else:
                    num_embeddings = self.vocab.vocab_names_sizes[embedding_name]
                    # include_oov was only set in the pretrained branch above; look it
                    # up per-vocab here so this branch is self-contained
                    include_oov = self.vocab.oovs[embedding_name]
                    embedding_table = self.get_embedding_table(
                        embedding_name,
                        embedding_dim,
                        include_oov,
                        num_embeddings=num_embeddings)
                # embedding_table = tf.Print(embedding_table, [tf.shape(embedding_table), embedding_table])
                embeddings[embedding_name] = embedding_table

                tf.logging.log(tf.logging.INFO,
                               "Created embeddings for '%s'." % embedding_name)
                tf.logging.log(tf.logging.INFO, embeddings[embedding_name])

            inputs_list = []
            gp_embs = []
            with tf.device("CPU:0"):
                if hparams.cwr != "None" or hparams.glove_300d:
                    if hparams.cwr != "None":
                        cached_cwr_embeddings = tf.get_variable(
                            "cwr_embedding",
                            shape=self.cwr_embedding.shape,
                            trainable=False)

                    def init_fn(scaffold, sess):
                        if hparams.cwr != "None":
                            sess.run(
                                cached_cwr_embeddings.initializer, {
                                    cached_cwr_embeddings.initial_value:
                                    self.cwr_embedding
                                })

                    scaffold = tf.train.Scaffold(init_fn=init_fn)

                for input_name, input_transformation_name in self.model_config[
                        'inputs'].items():
                    # print("debug <actual inputs>:", input_name, input_transformation_name)
                    input_values = feats[input_name]
                    # input_values = tf.Print(input_values, ["input value under {}".format(input_name), input_values, tf.shape(input_values)])
                    if input_transformation_name == "cached_embeddings":
                        # if hparams.cwr_ood:
                        #   ROOT_emb = tf.get_variable("root_emb", shape=[1, 3072], initializer=tf.random_normal_initializer(), trainable=False)
                        #   cached_cwr_embeddings_oov = tf.concat([cached_cwr_embeddings, ROOT_emb], axis=0)
                        #   input_embedding_lookup = tf.nn.embedding_lookup(cached_cwr_embeddings_oov, input_values)
                        # else:
                        input_embedding_lookup = tf.nn.embedding_lookup(
                            cached_cwr_embeddings, input_values)
                        with tf.variable_scope("cwr_assembly"):
                            num_layers = 3  #input_embedding_lookup.get_shape()[2]
                            weight = tf.get_variable("cwr_weight",
                                                     shape=[num_layers])
                            scale = tf.get_variable("cwr_scale", shape=[])
                            input_embedding_lookup = scale * tf.math.reduce_sum(
                                tf.split(input_embedding_lookup,
                                         axis=-1,
                                         num_or_size_splits=num_layers) *
                                tf.reshape(tf.nn.softmax(weight),
                                           shape=[num_layers, 1, 1, 1]),
                                axis=0)
                    elif input_transformation_name == "bert_embeddings":
                        input_embedding_lookup = tf.nn.embedding_lookup(
                            cached_cwr_embeddings, input_values)
                    elif input_transformation_name == "embeddings":
                        print("embeddings", input_name, embeddings[input_name])
                        input_embedding_lookup = tf.nn.embedding_lookup(
                            embeddings[input_name], input_values)
                    elif input_transformation_name == "predicate":
                        print("embeddings", input_name, embeddings[input_name])
                        input_embedding_lookup = tf.nn.embedding_lookup(
                            embeddings[input_name], input_values)
                        gp_embs.append(input_embedding_lookup)
                        continue
                    else:
                        print("unknown input transformation {}".format(
                            input_transformation_name))
                        raise NotImplementedError
                    # input_embedding_lookup = tf.Print(input_embedding_lookup, ["input embedding under {}".format(input_name), input_embedding_lookup])
                    inputs_list.append(input_embedding_lookup)
                    tf.logging.log(tf.logging.INFO,
                                   "Added %s to inputs list." % input_name)
            # TODO a mere workaround with one element concat
            current_input = tf.concat(inputs_list, axis=-1)

            ## <guard: condition to enter sentence features>
            ## suppose the dim is of (B, S, H)
            if hparams.input_project_layer_norm:
                current_input = tf.contrib.layers.layer_norm(current_input)

            sentence_feature = tf.reduce_sum(
                current_input * tf.expand_dims(tokens_to_keep, -1), axis=1)
            # divide by the real-token count to get the mean of all embeddings
            sentence_feature /= tf.expand_dims(
                tf.reduce_sum(tokens_to_keep, axis=1), -1)
            feats['sentence_feature'] = sentence_feature
            ## <guard: condition to enter sentence features>
            current_input = tf.nn.dropout(current_input, hparams.input_dropout)
            if len(gp_embs) > 0:
                current_input = tf.concat([current_input, gp_embs[0]], axis=-1)

            with tf.variable_scope('project_input'):
                current_input = nn_utils.MLP(current_input,
                                             sa_hidden_size,
                                             n_splits=1)

            # current_input = tf.Print(current_input, [tf.shape(current_input)], "input shape")

            predictions = {}
            eval_metric_ops = {}
            export_outputs = {}
            loss = tf.constant(0.)
            items_to_log = {}

            num_layers = max(self.task_config.keys()) + 1
            tf.logging.log(
                tf.logging.INFO,
                "Creating transformer model with %d layers" % num_layers)
            with tf.variable_scope('transformer'):
                current_input = transformer.add_timing_signal_1d(current_input)
                for i in range(num_layers):
                    # print("debug: <constructing {}-th layer>".format(i))
                    with tf.variable_scope('layer%d' % i):

                        # first list holds hard-injection attns, the second holds discounting attns
                        special_attn = [[], []]
                        special_values = []
                        if i in self.attention_config:

                            this_layer_attn_config = self.attention_config[i]
                            # print("debug: <layer_{} config>: ".format(i), this_layer_attn_config)

                            print(
                                "debug <attention configuration>@{}: ".format(
                                    i), this_layer_attn_config)

                            for attn_fn_item in this_layer_attn_config.keys():
                                for attn_fn, attn_fn_map in this_layer_attn_config[
                                        attn_fn_item].items():
                                    # print("debug <attn_fn, attn_fn_map>: ", attn_fn, ' ', attn_fn_map)
                                    if 'length' in attn_fn_map or hparams.use_hparams_headcounts:
                                        hc = (hparams.__dict__['{}_headcount'.format(attn_fn)]
                                              if hparams.use_hparams_headcounts
                                              else attn_fn_map['length'])
                                        tf.logging.log(
                                            tf.logging.INFO,
                                            "{} is using {} attention mode with {} heads".format(
                                                attn_fn_item,
                                                hparams.__dict__['{}_injection'.format(attn_fn)],
                                                hc))
                                        # separate scopes keep each special attention head's variables distinct
                                        for idx in range(hc):
                                            with tf.variable_scope('{}_{}'.format(attn_fn_item, idx)):
                                                attention_fn_params = attention_fns.get_params(
                                                    mode, attn_fn_map, predictions, feats, labels,
                                                    hparams, self.model_config, tokens_to_keep)
                                                # dispatch inside the scope so any variables it creates
                                                # are per-head, mirroring the else branch below
                                                this_special_attn, special_attn_weight = attention_fns.dispatch(
                                                    attn_fn_map['name'])(**attention_fn_params)
                                            # todo patches everywhere!
                                            # this_special_attn = tf.Print(this_special_attn, [this_special_attn])
                                            if special_attn_weight is not None and hparams.output_attention_weight:
                                                # fresh loop variable: reusing i here would clobber the layer index
                                                for w in range(special_attn_weight.get_shape()[0]):
                                                    items_to_log["{}_{}_weight_{}".format(
                                                        attn_fn, idx, w)] = special_attn_weight[w]
                                            if hparams.__dict__['{}_injection'.format(attn_fn)] == 'injection':
                                                special_attn[0].append(this_special_attn)
                                            elif hparams.__dict__['{}_injection'.format(attn_fn)] == 'discounting':
                                                special_attn[1].append(this_special_attn)
                                            else:
                                                tf.logging.log(
                                                    tf.logging.ERROR,
                                                    "The specified injection method {} has not been implemented".format(
                                                        attn_fn_map['injection_method']))
                                                raise NotImplementedError
                                            # print(special_attn)
                                    else:
                                        with tf.variable_scope(
                                                '{}'.format(attn_fn)):
                                            attention_fn_params = attention_fns.get_params(
                                                mode, attn_fn_map, predictions,
                                                feats, labels, hparams,
                                                self.model_config,
                                                tokens_to_keep)
                                            this_special_attn, _ = attention_fns.dispatch(
                                                attn_fn_map['name'])(
                                                    **attention_fn_params)
                                        if hparams.__dict__['{}_injection'.format(attn_fn)] == 'injection':
                                            special_attn[0].append(this_special_attn)
                                        elif hparams.__dict__['{}_injection'.format(attn_fn)] == 'discounting':
                                            special_attn[1].append(this_special_attn)
                                        else:
                                            tf.logging.log(
                                                tf.logging.ERROR,
                                                "The specified injection method {} has not been implemented".format(
                                                    attn_fn_map['injection_method']))
                                            raise NotImplementedError
                                # print("debug <layer_{} special attention>: ".format(i), special_attn )

                            if 'value_fns' in this_layer_attn_config:
                                tf.logging.log(
                                    tf.logging.ERROR,
                                    "special value section has been dropped temporarily"
                                )
                                raise NotImplementedError
                                # note: the loop below is unreachable until value_fns support returns
                                for value_fn, value_fn_map in this_layer_attn_config[
                                        'value_fns'].items():
                                    value_fn_params = value_fns.get_params(
                                        mode, value_fn_map, predictions, feats,
                                        labels, embeddings)
                                    this_special_values = value_fns.dispatch(
                                        value_fn_map['name'])(
                                            **value_fn_params)
                                    special_values.append(this_special_values)
                                # print("debug <layer_{} special values>: ".format(i), special_values)
                            if hparams.attn_debug:
                                print(special_attn)
                                # special_attn[1][1] = tf.Print(special_attn[1][1], [special_attn[1][0], special_attn[1][1]], "debug_check equal attn")
                                # assert_op = tf.assert_none_equal(special_attn[1][0], special_attn[1][1])
                                # tf.logging.log(tf.logging.INFO, "attention behavior is identical")
                        current_input = transformer.transformer(
                            current_input,
                            tokens_to_keep,
                            layer_config['head_dim'],
                            layer_config['num_heads'],
                            hparams.attn_dropout,
                            hparams.ff_dropout,
                            hparams.prepost_dropout,
                            layer_config['ff_hidden_size'],
                            special_attn,
                            special_values,
                            special_attention_mode=hparams.special_attention_mode)
                        # current_input = tf.Print(current_input, [tf.shape(current_input)], "LISA input after transformer")
                        if i in self.task_config:

                            # if normalization is done in layer_preprocess, then it should also be done
                            # on the output, since the output can grow very large, being the sum of
                            # a whole stack of unnormalized layer outputs.
                            current_input = nn_utils.layer_norm(current_input)

                            # todo test a list of tasks for each layer
                            for task, task_map in self.task_config[i].items():
                                # print("debug <task>: ", task)
                                # print("debug <task map>:" , task_map)
                                task_labels = labels[task]
                                # task_labels = tf.Print(task_labels, [task_labels] , 'task_label'.format(task))
                                task_vocab_size = (self.vocab.vocab_names_sizes[task]
                                                   if task in self.vocab.vocab_names_sizes else -1)

                                # Set up CRF / Viterbi transition params if specified
                                # to share transition parameters across tasks, change this scope
                                with tf.variable_scope("crf"):
                                    # transition_stats_file = task_map['transition_stats'] if 'transition_stats' in task_map else None
                                    task_transition_stats = (transition_stats[task]
                                                             if task in transition_stats else None)

                                    # create transition parameters if training or decoding with crf/viterbi
                                    task_crf = 'crf' in task_map and task_map['crf']
                                    task_viterbi_decode = (task_crf or
                                                           ('viterbi' in task_map and task_map['viterbi']))
                                    transition_params = None
                                    if task_viterbi_decode or task_crf:
                                        # print("loading transition params", self.not_load_transition)
                                        if hparams.train_with_crf:
                                            transition_params = tf.get_variable(
                                                "transitions",
                                                [task_vocab_size, task_vocab_size])
                                        else:
                                            tf.logging.log(
                                                tf.logging.INFO,
                                                "Use default transition param")
                                            tf.logging.log(
                                                tf.logging.INFO,
                                                "transition parameters:{}".
                                                format(task_transition_stats))
                                            transition_params = tf.constant(
                                                task_transition_stats, dtype=tf.float32)
                                            # alternative: tf.get_variable("transitions", [task_vocab_size, task_vocab_size],
                                            #   initializer=tf.constant_initializer(task_transition_stats) if not self.not_load_transition else tf.constant_initializer(0),
                                            #   trainable=task_crf)
                                        # if mode != ModeKeys.TRAIN:
                                        #   transition_params = tf.Print(transition_params, [tf.get_variable("transitions", [task_vocab_size, task_vocab_size])],
                                        #                                                "optimized transition?")
                                        # transition_params = tf.cond(tf.equal(mode, ModeKeys.TRAIN),
                                        #                               lambda: transition_params,
                                        #                               lambda: tf.Print(transition_params, [
                                        #                                 tf.get_variable("transitions", [task_vocab_size, task_vocab_size])],
                                        #                                                "optimized transition?"))
                                        # transition_params =
                                        train_or_decode_str = "training" if task_crf else "decoding"
                                        tf.logging.log(
                                            tf.logging.INFO,
                                            "Created transition params for %s %s"
                                            % (train_or_decode_str, task))

                                output_fn_params = output_fns.get_params(
                                    mode, self.model_config,
                                    task_map['output_fn'], predictions, feats,
                                    labels, current_input, task_labels,
                                    task_vocab_size,
                                    self.vocab.joint_label_lookup_maps,
                                    tokens_to_keep, transition_params, hparams)
                                # print("debug <dispatch into {}>".format(task_map['output_fn']['name']))
                                task_outputs = output_fns.dispatch(
                                    task_map['output_fn']['name'])(
                                        **output_fn_params)
                                # print("debug <task_outputs>: ", task_outputs)
                                # want task_outputs to have:
                                # - predictions
                                # - loss
                                # - scores
                                # - probabilities
                                predictions[task] = task_outputs

                                # do the evaluation
                                for eval_name, eval_map in task_map[
                                        'eval_fns'].items():
                                    eval_fn_params = evaluation_fns.get_params(
                                        task_outputs, eval_map, predictions,
                                        feats, labels, task_labels,
                                        self.vocab.reverse_maps,
                                        tokens_to_keep)
                                    if eval_name == 'parse_eval' and hparams.using_input_with_root:
                                        eval_fn_params['has_root_token'] = True
                                    eval_result = evaluation_fns.dispatch(
                                        eval_map['name'])(**eval_fn_params)
                                    eval_metric_ops[eval_name] = eval_result

                                # get the individual task loss and apply penalty
                                this_task_loss = task_outputs['loss'] * task_map['penalty']

                                # log this task's loss
                                items_to_log['%s_loss' % task] = this_task_loss

                                # log any sub-losses as well
                                for key in task_outputs:
                                    if key.startswith('loss'):
                                        items_to_log['{}_{}'.format(task, key)] = task_outputs[key]

                                # add this loss to the overall loss being minimized
                                # this_task_loss = tf.Print(this_task_loss, [this_task_loss], '{}_{}'.format(task, key))
                                loss += this_task_loss

                                # print("debug <accumulated loss>: ", loss)
                            # break # only take one loss

            # set up moving average variables
            assign_moving_averages_dep = tf.no_op()
            if hparams.moving_average_decay > 0.:
                moving_averager = tf.train.ExponentialMovingAverage(
                    hparams.moving_average_decay,
                    zero_debias=True,
                    num_updates=tf.train.get_global_step())
                moving_average_op = moving_averager.apply(
                    train_utils.get_vars_for_moving_average(
                        hparams.average_norms))

                tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                                     moving_average_op)

                # use moving averages of variables if evaluating
                assign_moving_averages_dep = tf.cond(
                    tf.equal(mode, ModeKeys.TRAIN),
                    lambda: tf.no_op(),
                    lambda: nn_utils.set_vars_to_moving_average(moving_averager))
            # print("debug <finishing setting up moving avg variables>")

            with tf.control_dependencies([assign_moving_averages_dep]):

                items_to_log['loss'] = loss
                # print("debug <final loss>: ", loss)
                # get learning rate w/ decay
                # todo dirty workaround
                if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
                    this_step_lr = train_utils.learning_rate(
                        hparams, tf.train.get_global_step())
                    items_to_log['lr'] = this_step_lr
                    # print("debug <items to log>: ", items_to_log)
                    # print("debug <eval_metric_content>: ", eval_metric_ops)

                    if hparams.optimizer == "lazyadam":
                        optimizer = LazyAdamOptimizer(
                            learning_rate=this_step_lr,
                            beta1=hparams.beta1,
                            beta2=hparams.beta2,
                            epsilon=hparams.epsilon,
                            use_nesterov=hparams.use_nesterov)
                    elif hparams.optimizer == "adam":
                        optimizer = tf.train.AdamOptimizer(
                            learning_rate=this_step_lr,
                            beta1=hparams.beta1,
                            beta2=hparams.beta2,
                            epsilon=hparams.epsilon)
                    else:
                        raise NotImplementedError(
                            "The specified optimizer is not implemented")
                    # loss = tf.Print(loss, [loss], "loss")
                    # # loss_no_nan = tf.cond(tf.reduce_any(tf.is_nan(loss)), lambda: tf.zeros_like(loss), lambda: loss)
                    # # loss_no_nan = tf.where(tf.is_nan(loss), tf.zeros_like(loss), loss)
                    # loss_no_nan = tf.where(tf.math.is_nan(loss), tf.zeros_like(loss), loss)
                    # loss_no_nan_printed = tf.Print(loss_no_nan, [loss_no_nan], "no nan loss")
                    # grad_and_var = optimizer.compute_gradients(loss_no_nan_printed)

                    # loss = tf.where(tf.math.is_nan(loss), tf.zeros_like(loss), loss)
                    # loss = tf.Print(loss, [loss], "loss")
                    grad_and_var = optimizer.compute_gradients(loss)

                    gradients, variables = zip(*grad_and_var)
                    # gradients_without_nan = [tf.cond(tf.reduce_any(tf.is_nan(item)), lambda: tf.zeros_like(item), item)for item in gradients]
                    # gradients_without_nan = gradients
                    gradients, gn = tf.clip_by_global_norm(
                        gradients, hparams.gradient_clip_norm)
                    # print([g is None for g in gradients])
                    # gn = gn[0]
                    zero_clipped_gradients = [
                        tf.clip_by_value(g, 0., 0.) if g is not None else g
                        for g in gradients
                    ]
                    # if the clipped global norm came out inf/nan, fall back to the
                    # zeroed gradients so this update is effectively skipped
                    gradients_prev_inf_norm = [
                        tf.cond(tf.logical_or(tf.math.is_inf(gn), tf.math.is_nan(gn)),
                                lambda: g_zeros,
                                lambda: g) if g is not None else None
                        for g_zeros, g in zip(zero_clipped_gradients, gradients)
                    ]

                    # gn = tf.Print(gn, [gn], "global norm")
                    with tf.control_dependencies([gn]):
                        train_op = optimizer.apply_gradients(
                            zip(gradients_prev_inf_norm, variables),
                            global_step=tf.train.get_global_step())

                    # if hparams.debug and mode == tf.estimator.ModeKeys.TRAIN:
                    #   gradients_to_print = [gradients[variables.index(var)] for var in nn_utils.gradient_to_watch]
                    #   print(gradients_to_print)
                    #   gradients[0] = tf.Print(gradients[0], gradients_to_print, "gradient for dependency label strength")

                    # train_op = optimizer.apply_gradients(zip(gradients, variables), global_step=tf.train.get_global_step())

                    # export_outputs = {'predict_output': tf.estimator.export.PredictOutput({'scores': scores, 'preds': preds})}

                    logging_hook = tf.train.LoggingTensorHook(items_to_log,
                                                              every_n_iter=100)

                    histogram_summary = [
                        summary
                        for name, summary in nn_utils.histogram_output.items()
                    ]
                    summary_hook = tf.train.SummarySaverHook(
                        save_steps=500,
                        summary_op=[
                            tf.summary.scalar(k, items_to_log[k])
                            for k in items_to_log.keys()
                        ] + histogram_summary)

                flat_predictions = {
                    "%s_%s" % (k1, k2): v2
                    for k1, v1 in predictions.items() for k2, v2 in v1.items()
                }
                # print("debug <flat predictions>:", flat_predictions)
                export_outputs = {
                    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    tf.estimator.export.PredictOutput(flat_predictions)
                }

                tf.logging.log(
                    tf.logging.INFO,
                    "Created model with %d trainable parameters" %
                    tf_utils.get_num_trainable_parameters())
                # if hparams.cwr!= 'None':
                #   with tf.Session() as sess:
                #     sess.run(tf.global_variables_initializer(), feed_dict={cached_cwr_embeddings_placeholder: self.cwr_embedding})

                if mode == tf.estimator.ModeKeys.TRAIN:
                    return tf.estimator.EstimatorSpec(
                        mode,
                        flat_predictions,
                        loss,
                        train_op,
                        eval_metric_ops,
                        training_hooks=[logging_hook, summary_hook],
                        export_outputs=export_outputs,
                        scaffold=scaffold if hparams.cwr != 'None' else None)
                elif mode == tf.estimator.ModeKeys.EVAL:
                    return tf.estimator.EstimatorSpec(
                        mode,
                        flat_predictions,
                        loss,
                        train_op,
                        eval_metric_ops,
                        training_hooks=[logging_hook],
                        export_outputs=export_outputs,
                        scaffold=scaffold if hparams.cwr != 'None' else None)
                elif mode == tf.estimator.ModeKeys.PREDICT:
                    return tf.estimator.EstimatorSpec(
                        mode,
                        flat_predictions,
                        loss,
                        tf.no_op(),
                        eval_metric_ops,
                        export_outputs=export_outputs,
                        scaffold=scaffold if hparams.cwr != 'None' else None)
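
Note: the gradient guard near the end of model_fn is easy to miss. A minimal NumPy sketch (not from the project) of the same idea: when the global norm is inf or NaN, every gradient is swapped for zeros so the bad step is skipped instead of written into the weights.

import numpy as np

grads = [np.array([3.0, 4.0]), np.array([np.inf])]
gn = np.sqrt(sum(np.sum(g ** 2) for g in grads))  # global norm, inf here

safe_grads = ([np.zeros_like(g) for g in grads]   # skip this update entirely
              if np.isinf(gn) or np.isnan(gn) else grads)
print(gn, safe_grads)  # inf [array([0., 0.]), array([0.])]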
Code Example #6
File: output_fns.py Project: schilama/Fake-News
def srl_bilinear(mode, hparams, model_config, inputs, targets, num_labels,
                 tokens_to_keep, predicate_preds_train, predicate_preds_eval,
                 predicate_targets, transition_params):
    '''
    :param inputs: Tensor with dims: [batch_size, batch_seq_len, hidden_size]
    :param predicate_preds_train, predicate_preds_eval: Tensors of predictions from the predicates layer with dims: [batch_size, batch_seq_len]
    :param targets: Tensor of SRL labels with dims: [batch_size, batch_seq_len, batch_num_predicates]
    :param tokens_to_keep: mask of non-padding tokens with dims: [batch_size, batch_seq_len]
    :param transition_params: [num_labels x num_labels] transition parameters, if doing Viterbi decoding
    :return:
    '''

    with tf.name_scope('srl_bilinear'):

        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        batch_seq_len = input_shape[1]

        predicate_mlp_size = model_config['predicate_mlp_size']
        role_mlp_size = model_config['role_mlp_size']

        predicate_preds = predicate_preds_train if mode == tf.estimator.ModeKeys.TRAIN else predicate_preds_eval

        # (1) project into predicate, role representations
        with tf.variable_scope('MLP'):
            predicate_role_mlp = nn_utils.MLP(inputs,
                                              predicate_mlp_size +
                                              role_mlp_size,
                                              keep_prob=hparams.mlp_dropout)
            predicate_mlp, role_mlp = predicate_role_mlp[:, :, :predicate_mlp_size], \
                                      predicate_role_mlp[:, :, predicate_mlp_size:]

        # (2) feed through bilinear to obtain scores
        with tf.variable_scope('Bilinear'):
            # gather just the predicates
            # gathered_predicates: num_predicates_in_batch x 1 x predicate_mlp_size
            # role mlp: batch x seq_len x role_mlp_size
            # gathered roles: need a (batch_seq_len x role_mlp_size) role representation for each predicate,
            # i.e. a (num_predicates_in_batch x batch_seq_len x role_mlp_size) tensor
            predicate_gather_indices = tf.where(tf.equal(predicate_preds, 1))
            gathered_predicates = tf.expand_dims(
                tf.gather_nd(predicate_mlp, predicate_gather_indices), 1)
            tiled_roles = tf.reshape(
                tf.tile(role_mlp, [1, batch_seq_len, 1]),
                [batch_size, batch_seq_len, batch_seq_len, role_mlp_size])
            gathered_roles = tf.gather_nd(tiled_roles,
                                          predicate_gather_indices)

            # now multiply them together to get (num_predicates_in_batch x batch_seq_len x num_srl_classes) tensor of scores
            srl_logits = nn_utils.bilinear_classifier_nary(
                gathered_predicates, gathered_roles, num_labels,
                hparams.bilinear_dropout)
            srl_logits_transposed = tf.transpose(srl_logits, [0, 2, 1])

        # (3) compute loss

        # need to repeat each of these once for each target in the sentence
        mask_tiled = tf.reshape(tf.tile(tokens_to_keep, [1, batch_seq_len]),
                                [batch_size, batch_seq_len, batch_seq_len])
        mask = tf.gather_nd(mask_tiled, tf.where(tf.equal(predicate_preds, 1)))

        # now we have k sets of targets for the k frames
        # (p1) f1 f2 f3
        # (p2) f1 f2 f3

        # get all the tags for each token (which is the predicate for a frame), structuring
        # targets as follows (assuming p1 and p2 are predicates for f1 and f3, respectively):
        # (p1) f1 f1 f1
        # (p2) f3 f3 f3
        srl_targets_transposed = tf.transpose(targets, [0, 2, 1])

        gold_predicate_counts = tf.reduce_sum(predicate_targets, -1)
        srl_targets_indices = tf.where(
            tf.sequence_mask(tf.reshape(gold_predicate_counts, [-1])))

        # num_predicates_in_batch x seq_len
        srl_targets_gold_predicates = tf.gather_nd(srl_targets_transposed,
                                                   srl_targets_indices)

        predicted_predicate_counts = tf.reduce_sum(predicate_preds, -1)
        srl_targets_pred_indices = tf.where(
            tf.sequence_mask(tf.reshape(predicted_predicate_counts, [-1])))
        srl_targets_predicted_predicates = tf.gather_nd(
            srl_targets_transposed, srl_targets_pred_indices)

        # num_predicates_in_batch x seq_len
        predictions = tf.cast(tf.argmax(srl_logits_transposed, axis=-1),
                              tf.int32)

        seq_lens = tf.cast(tf.reduce_sum(mask, 1), tf.int32)

        if transition_params is not None and (mode == ModeKeys.PREDICT
                                              or mode == ModeKeys.EVAL):
            predictions, score = tf.contrib.crf.crf_decode(
                srl_logits_transposed, transition_params, seq_lens)

        if transition_params is not None and mode == ModeKeys.TRAIN and tf_utils.is_trainable(
                transition_params):
            # flat_seq_lens = tf.reshape(tf.tile(seq_lens, [1, bucket_size]), tf.stack([batch_size * bucket_size]))
            log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
                srl_logits_transposed, srl_targets_predicted_predicates,
                seq_lens, transition_params)
            loss = tf.reduce_mean(-log_likelihood)
        else:
            srl_targets_onehot = tf.one_hot(
                indices=srl_targets_predicted_predicates,
                depth=num_labels,
                axis=-1)
            loss = tf.losses.softmax_cross_entropy(
                logits=tf.reshape(srl_logits_transposed, [-1, num_labels]),
                onehot_labels=tf.reshape(srl_targets_onehot, [-1, num_labels]),
                weights=tf.reshape(mask, [-1]),
                label_smoothing=hparams.label_smoothing,
                reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

        output = {
            'loss': loss,
            'predictions': predictions,
            'scores': srl_logits_transposed,
            'targets': srl_targets_gold_predicates,
        }

        return output
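
Note: the predicate gathering in step (2) above is the trickiest indexing in this function. A minimal NumPy sketch (np.argwhere standing in for tf.where, plain fancy indexing for tf.gather_nd):

import numpy as np

predicate_preds = np.array([[0, 1, 0],
                            [1, 0, 1]])         # batch_size x seq_len
predicate_mlp = np.arange(12).reshape(2, 3, 2)  # batch x seq_len x mlp_size

idx = np.argwhere(predicate_preds == 1)         # (batch, token) coordinates
gathered = predicate_mlp[idx[:, 0], idx[:, 1]]  # num_predicates_in_batch x mlp_size
print(gathered)  # rows for positions (0, 1), (1, 0), (1, 2)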
Code Example #7
File: model.py Project: yiyouls/LISA
    def model_fn(self, features, mode):

        # todo can estimators handle dropout for us or do we need to do it on our own?
        hparams = self.hparams(mode)

        with tf.variable_scope("LISA", reuse=tf.AUTO_REUSE):

            batch_shape = tf.shape(features)
            batch_size = batch_shape[0]
            batch_seq_len = batch_shape[1]
            layer_config = self.model_config['layers']
            sa_hidden_size = layer_config['head_dim'] * layer_config['num_heads']

            feats = {
                f: features[:, :, idx]
                for f, idx in self.feature_idx_map.items()
            }

            # todo this assumes that word_type is always passed in
            words = feats['word_type']

            # for masking out padding tokens
            tokens_to_keep = tf.where(tf.equal(words, constants.PAD_VALUE),
                                      tf.zeros([batch_size, batch_seq_len]),
                                      tf.ones([batch_size, batch_seq_len]))

            # Extract named features from monolithic "features" input
            feats = {
                f: tf.multiply(tf.cast(tokens_to_keep, tf.int32), v)
                for f, v in feats.items()
            }

            # Extract named labels from monolithic "features" input, and mask them
            # todo fix masking -- is it even necessary?
            labels = {}
            for l, idx in self.label_idx_map.items():
                these_labels = (features[:, :, idx[0]:idx[1]]
                                if idx[1] != -1 else features[:, :, idx[0]:])
                these_labels_masked = tf.multiply(
                    these_labels,
                    tf.cast(tf.expand_dims(tokens_to_keep, -1), tf.int32))
                # check if we need to mask another dimension
                if idx[1] == -1:
                    last_dim = tf.shape(these_labels)[2]
                    this_mask = tf.where(
                        tf.equal(these_labels_masked, constants.PAD_VALUE),
                        tf.zeros([batch_size, batch_seq_len, last_dim],
                                 dtype=tf.int32),
                        tf.ones([batch_size, batch_seq_len, last_dim],
                                dtype=tf.int32))
                    these_labels_masked = tf.multiply(these_labels_masked,
                                                      this_mask)
                else:
                    these_labels_masked = tf.squeeze(these_labels_masked, -1)
                labels[l] = these_labels_masked

            # load transition parameters
            transition_stats = util.load_transition_params(
                self.task_config, self.vocab)

            # Create embeddings tables, loading pre-trained if specified
            embeddings = {}
            for embedding_name, embedding_map in self.model_config[
                    'embeddings'].items():
                embedding_dim = embedding_map['embedding_dim']
                if 'pretrained_embeddings' in embedding_map:
                    input_pretrained_embeddings = embedding_map[
                        'pretrained_embeddings']
                    include_oov = True
                    embedding_table = self.get_embedding_table(
                        embedding_name,
                        embedding_dim,
                        include_oov,
                        pretrained_fname=input_pretrained_embeddings)
                else:
                    num_embeddings = self.vocab.vocab_names_sizes[embedding_name]
                    include_oov = self.vocab.oovs[embedding_name]
                    embedding_table = self.get_embedding_table(
                        embedding_name,
                        embedding_dim,
                        include_oov,
                        num_embeddings=num_embeddings)
                embeddings[embedding_name] = embedding_table
                tf.logging.log(tf.logging.INFO,
                               "Created embeddings for '%s'." % embedding_name)

            # Set up model inputs
            inputs_list = []
            for input_name in self.model_config['inputs']:
                input_values = feats[input_name]
                input_embedding_lookup = tf.nn.embedding_lookup(
                    embeddings[input_name], input_values)
                inputs_list.append(input_embedding_lookup)
                tf.logging.log(tf.logging.INFO,
                               "Added %s to inputs list." % input_name)
            current_input = tf.concat(inputs_list, axis=2)
            current_input = tf.nn.dropout(current_input, hparams.input_dropout)

            with tf.variable_scope('project_input'):
                current_input = nn_utils.MLP(current_input,
                                             sa_hidden_size,
                                             n_splits=1)

            predictions = {}
            eval_metric_ops = {}
            export_outputs = {}
            loss = tf.constant(0.)
            items_to_log = {}

            num_layers = max(self.task_config.keys()) + 1
            tf.logging.log(
                tf.logging.INFO,
                "Creating transformer model with %d layers" % num_layers)
            with tf.variable_scope('transformer'):
                current_input = transformer.add_timing_signal_1d(current_input)
                for i in range(num_layers):
                    with tf.variable_scope('layer%d' % i):

                        special_attn = []
                        special_values = []
                        if i in self.attention_config:

                            this_layer_attn_config = self.attention_config[i]

                            if 'attention_fns' in this_layer_attn_config:
                                for attn_fn, attn_fn_map in this_layer_attn_config[
                                        'attention_fns'].items():
                                    attention_fn_params = attention_fns.get_params(
                                        mode, attn_fn_map, predictions, feats,
                                        labels)
                                    this_special_attn = attention_fns.dispatch(
                                        attn_fn_map['name'])(
                                            **attention_fn_params)
                                    special_attn.append(this_special_attn)

                            if 'value_fns' in this_layer_attn_config:
                                for value_fn, value_fn_map in this_layer_attn_config[
                                        'value_fns'].items():
                                    value_fn_params = value_fns.get_params(
                                        mode, value_fn_map, predictions, feats,
                                        labels, embeddings)
                                    this_special_values = value_fns.dispatch(
                                        value_fn_map['name'])(
                                            **value_fn_params)
                                    special_values.append(this_special_values)

                        # One full transformer encoder layer: multi-head
                        # self-attention (with any special heads / values
                        # spliced in), followed by a position-wise
                        # feed-forward block.
                        current_input = transformer.transformer(
                            current_input, tokens_to_keep,
                            layer_config['head_dim'],
                            layer_config['num_heads'], hparams.attn_dropout,
                            hparams.ff_dropout, hparams.prepost_dropout,
                            layer_config['ff_hidden_size'], special_attn,
                            special_values)
                        if i in self.task_config:
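                            # One or more supervised tasks are attached at
                            # this layer, giving intermediate-layer
                            # supervision in the multi-task stack.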

                            # if normalization is done in layer_preprocess, then it should also be done
                            # on the output, since the output can grow very large, being the sum of
                            # a whole stack of unnormalized layer outputs.
                            current_input = nn_utils.layer_norm(current_input)

                            # todo test a list of tasks for each layer
                            for task, task_map in self.task_config[i].items():
                                task_labels = labels[task]
                                task_vocab_size = (
                                    self.vocab.vocab_names_sizes[task]
                                    if task in self.vocab.vocab_names_sizes
                                    else -1)

                                # Set up CRF / Viterbi transition params if specified
                                # To share transition parameters across tasks,
                                # change the variable scope name here.
                                with tf.variable_scope("crf"):
                                    task_transition_stats = (
                                        transition_stats[task]
                                        if task in transition_stats else None)

                                    # Create transition parameters if training
                                    # a CRF or decoding with Viterbi. They are
                                    # trainable only in the CRF case; for pure
                                    # Viterbi decoding they stay fixed at the
                                    # loaded statistics.
                                    task_crf = 'crf' in task_map and task_map['crf']
                                    task_viterbi_decode = (
                                        task_crf or ('viterbi' in task_map
                                                     and task_map['viterbi']))
                                    transition_params = None
                                    if task_viterbi_decode:
                                        transition_params = tf.get_variable(
                                            "transitions",
                                            [task_vocab_size, task_vocab_size],
                                            initializer=tf.constant_initializer(
                                                task_transition_stats),
                                            trainable=task_crf)
                                        train_or_decode_str = (
                                            "training" if task_crf else "decoding")
                                        tf.logging.log(
                                            tf.logging.INFO,
                                            "Created transition params for %s %s" %
                                            (train_or_decode_str, task))

                                # Dispatch to the configured output head by
                                # name (e.g. softmax_classifier or
                                # joint_softmax_classifier in output_fns);
                                # get_params assembles its arguments from the
                                # current model state.
                                output_fn_params = output_fns.get_params(
                                    mode, self.model_config,
                                    task_map['output_fn'], predictions, feats,
                                    labels, current_input, task_labels,
                                    task_vocab_size,
                                    self.vocab.joint_label_lookup_maps,
                                    tokens_to_keep, transition_params, hparams)
                                task_outputs = output_fns.dispatch(
                                    task_map['output_fn']['name'])(
                                        **output_fn_params)

                                # Each output_fn should return predictions,
                                # loss, scores, and probabilities. Storing the
                                # full dict in `predictions` also lets
                                # attention_fns / value_fns at later layers
                                # condition on this task's outputs.
                                predictions[task] = task_outputs

                                # do the evaluation
                                for eval_name, eval_map in task_map[
                                        'eval_fns'].items():
                                    eval_fn_params = evaluation_fns.get_params(
                                        task_outputs, eval_map, predictions,
                                        feats, labels, task_labels,
                                        self.vocab.reverse_maps,
                                        tokens_to_keep)
                                    eval_result = evaluation_fns.dispatch(
                                        eval_map['name'])(**eval_fn_params)
                                    eval_metric_ops[eval_name] = eval_result
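                                    # Note: metrics are keyed by eval_name
                                    # alone, so eval names must be unique
                                    # across tasks or later results overwrite
                                    # earlier ones.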

                                # get the individual task loss and apply penalty
                                this_task_loss = task_outputs[
                                    'loss'] * task_map['penalty']

                                # log this task's loss
                                items_to_log['%s_loss' % task] = this_task_loss

                                # add this loss to the overall loss being minimized
                                loss += this_task_loss

            # set up moving average variables
            assign_moving_averages_dep = tf.no_op()
            if hparams.moving_average_decay > 0.:
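                # zero_debias corrects the EMA's cold-start bias, and passing
                # num_updates ramps the effective decay up as
                # min(decay, (1 + step) / (10 + step)) early in training.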
                moving_averager = tf.train.ExponentialMovingAverage(
                    hparams.moving_average_decay,
                    zero_debias=True,
                    num_updates=tf.train.get_global_step())
                moving_average_op = moving_averager.apply(
                    train_utils.get_vars_for_moving_average(
                        hparams.average_norms))

                tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                                     moving_average_op)

                # use moving averages of variables if evaluating
                assign_moving_averages_dep = tf.cond(
                    tf.equal(mode, ModeKeys.TRAIN),
                    lambda: tf.no_op(),
                    lambda: nn_utils.set_vars_to_moving_average(moving_averager))

            with tf.control_dependencies([assign_moving_averages_dep]):

                items_to_log['loss'] = loss

                # get learning rate w/ decay
                this_step_lr = train_utils.learning_rate(
                    hparams, tf.train.get_global_step())
                items_to_log['lr'] = this_step_lr

                # Alternative optimizer:
                # optimizer = tf.contrib.opt.NadamOptimizer(
                #     learning_rate=this_step_lr, beta1=hparams.beta1,
                #     beta2=hparams.beta2, epsilon=hparams.epsilon)
                optimizer = LazyAdamOptimizer(
                    learning_rate=this_step_lr,
                    beta1=hparams.beta1,
                    beta2=hparams.beta2,
                    epsilon=hparams.epsilon,
                    use_nesterov=hparams.use_nesterov)
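                # LazyAdam only updates optimizer slots for embedding rows
                # that actually receive gradients, which is cheaper with
                # large embedding tables. The use_nesterov flag suggests this
                # is a project-local variant rather than stock
                # tf.contrib.opt.LazyAdamOptimizer.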
                gradients, variables = zip(*optimizer.compute_gradients(loss))
                gradients, _ = tf.clip_by_global_norm(
                    gradients, hparams.gradient_clip_norm)
                train_op = optimizer.apply_gradients(
                    zip(gradients, variables),
                    global_step=tf.train.get_global_step())

                logging_hook = tf.train.LoggingTensorHook(items_to_log,
                                                          every_n_iter=20)

                # need to flatten the dict of predictions to make Estimator happy
                flat_predictions = {
                    "%s_%s" % (k1, k2): v2
                    for k1, v1 in predictions.items() for k2, v2 in v1.items()
                }

                # Register the flattened predictions under the default serving
                # signature so the exported SavedModel can be served as-is.
                export_outputs = {
                    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    tf.estimator.export.PredictOutput(flat_predictions)
                }

                tf.logging.log(
                    tf.logging.INFO,
                    "Created model with %d trainable parameters" %
                    tf_utils.get_num_trainable_parameters())

                return tf.estimator.EstimatorSpec(
                    mode,
                    predictions=flat_predictions,
                    loss=loss,
                    train_op=train_op,
                    eval_metric_ops=eval_metric_ops,
                    training_hooks=[logging_hook],
                    export_outputs=export_outputs)