Example #1
    def test_SegmentMean(self):
        t = tf.segment_mean(self.random(4, 2, 3), np.array([0, 1, 1, 2]))
        self.check(t)
Example #2
    def train_model(self, vocabulary_size, window_size, data, instances,
                    labels, context, doc):
        batch_size = 256
        context_window = 2 * window_size
        embedding_size = 50  # Dimension of the embedding vector.
        softmax_width = embedding_size  # +embedding_size2+embedding_size3
        num_sampled = 5  # Number of negative examples to sample.
        sum_ids = np.repeat(np.arange(batch_size), context_window)
        len_docs = len(data)

        # Define the training graph structure
        graph = tf.Graph()

        with graph.as_default():
            train_word_dataset = tf.placeholder(
                tf.int32, shape=[batch_size * context_window])
            train_doc_dataset = tf.placeholder(tf.int32, shape=[batch_size])
            train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])

            segment_ids = tf.constant(sum_ids, dtype=tf.int32)

            word_embeddings = tf.Variable(
                tf.random_uniform([vocabulary_size, embedding_size], -1.0,
                                  1.0))
            word_embeddings = tf.concat(
                [word_embeddings,
                 tf.zeros((1, embedding_size))], 0)
            doc_embeddings = tf.Variable(
                tf.random_uniform([len_docs, embedding_size], -1.0, 1.0))

            softmax_weights = tf.Variable(
                tf.truncated_normal([vocabulary_size, softmax_width],
                                    stddev=1.0 / np.sqrt(embedding_size)))
            softmax_biases = tf.Variable(tf.zeros([vocabulary_size]))

            embed_words = tf.segment_mean(
                tf.nn.embedding_lookup(word_embeddings, train_word_dataset),
                segment_ids)
            embed_docs = tf.nn.embedding_lookup(doc_embeddings,
                                                train_doc_dataset)
            embed = (embed_words + embed_docs) / 2.0

            loss = tf.reduce_mean(
                tf.nn.nce_loss(softmax_weights, softmax_biases, train_labels,
                               embed, num_sampled, vocabulary_size))

            optimizer = tf.train.AdagradOptimizer(0.5).minimize(loss)

            norm = tf.sqrt(
                tf.reduce_sum(tf.square(doc_embeddings), 1, keep_dims=True))
            normalized_doc_embeddings = doc_embeddings / norm

        num_steps = 10000
        step_delta = int(num_steps / 20)

        # Train the network
        with tf.Session(graph=graph) as session:
            tf.global_variables_initializer().run()
            print('Initialized')
            average_loss = 0
            for step in range(num_steps):
                batch_labels, batch_word_data, batch_doc_data = self.generate_batch(
                    batch_size, instances, labels, context, doc)
                feed_dict = {
                    train_word_dataset: np.squeeze(batch_word_data),
                    train_doc_dataset: np.squeeze(batch_doc_data),
                    train_labels: batch_labels
                }
                _, l = session.run([optimizer, loss], feed_dict=feed_dict)
                average_loss += l
                if step % step_delta == 0:
                    if step > 0:
                        average_loss = average_loss / step_delta
                    print('Average loss at step %d: %f' % (step, average_loss))
                    average_loss = 0

            final_word_embeddings = word_embeddings.eval()
            final_word_embeddings_out = softmax_weights.eval()
            final_doc_embeddings = normalized_doc_embeddings.eval()

        return final_doc_embeddings, final_word_embeddings, final_word_embeddings_out
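
The sum_ids construction above is what lets tf.segment_mean collapse each example's context-word embeddings into a single row. A minimal NumPy-only sketch of the indexing, with a toy batch of 3 examples and window_size=1 (so context_window=2); the sizes here are illustrative only:

import numpy as np

batch_size, context_window = 3, 2
sum_ids = np.repeat(np.arange(batch_size), context_window)
# sum_ids == [0, 0, 1, 1, 2, 2]: both context words of example i carry
# segment id i, so segment_mean reduces the (batch_size * context_window)
# embedding rows down to batch_size averaged rows.
print(sum_ids)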
Example #3
# Segmentation Examples
import tensorflow as tf
sess = tf.InteractiveSession()
seg_ids = tf.constant([0, 1, 1, 2, 2])  # Group indexes: 0 | 1,2 | 3,4

tens1 = tf.constant([[2, 5, 3, -5],
                     [0, 3, -2, 5],
                     [4, 3, 5, 3],
                     [6, 1, 4, 0],
                     [6, 1, 4, 0]])  # A sample constant matrix

tf.segment_sum(tens1, seg_ids).eval()    # Sum within each segment
tf.segment_prod(tens1, seg_ids).eval()   # Product within each segment
tf.segment_min(tens1, seg_ids).eval()    # Minimum within each segment
tf.segment_max(tens1, seg_ids).eval()    # Maximum within each segment
tf.segment_mean(tens1, seg_ids).eval()   # Mean within each segment
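
For reference, the grouping above (row 0 alone, rows 1-2 together, rows 3-4 together) gives the following values for two of the ops; these are worked out by hand from the matrix above, not captured session output:

# tf.segment_sum(tens1, seg_ids) evaluates to
# [[ 2,  5,  3, -5],
#  [ 4,  6,  3,  8],
#  [12,  2,  8,  0]]
#
# tf.segment_max(tens1, seg_ids) evaluates to
# [[ 2,  5,  3, -5],
#  [ 4,  3,  5,  5],
#  [ 6,  1,  4,  0]]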
Example #5
    def create_model(self):
        g = tf.Graph()
        with g.as_default():
            # Define model variables
            var_linear = tf.get_variable(
                'linear', [self.feature_dim, 1],
                initializer=tf.random_uniform_initializer(
                    -self.args.init_mean, self.args.init_mean))

            var_emb_factors = tf.get_variable(
                'emb_factors', [self.feature_dim, self.args.num_dims],
                initializer=tf.random_uniform_initializer(
                    -self.args.init_mean, self.args.init_mean))

            # Sparse placeholders
            pl_user_list = tf.placeholder(tf.int64,
                                          shape=[None],
                                          name='pos_list')

            pl_pos_indices = tf.placeholder(tf.int64,
                                            shape=[None, 2],
                                            name='pos_indices')
            pl_pos_values = tf.placeholder(tf.float32,
                                           shape=[None],
                                           name='pos_values')
            pl_pos_shape = tf.placeholder(tf.int64,
                                          shape=[2],
                                          name='pos_shape')

            pl_neg_indices = tf.placeholder(tf.int64,
                                            shape=[None, 2],
                                            name='neg_indices')
            pl_neg_values = tf.placeholder(tf.float32,
                                           shape=[None],
                                           name='neg_values')
            pl_neg_shape = tf.placeholder(tf.int64,
                                          shape=[2],
                                          name='neg_shape')

            placeholders = {
                'pl_user_list': pl_user_list,
                'pl_pos_indices': pl_pos_indices,
                'pl_pos_values': pl_pos_values,
                'pl_pos_shape': pl_pos_shape,
                'pl_neg_indices': pl_neg_indices,
                'pl_neg_values': pl_neg_values,
                'pl_neg_shape': pl_neg_shape
            }

            # Input positive features, shape = (batch_size * feature_dim)
            sparse_pos_feats = tf.SparseTensor(pl_pos_indices, pl_pos_values,
                                               pl_pos_shape)

            # Input negative features, shape = (batch_size * feature_dim)
            sparse_neg_feats = tf.SparseTensor(pl_neg_indices, pl_neg_values,
                                               pl_neg_shape)

            pos_preds, neg_preds = self.get_preds(var_linear, var_emb_factors,
                                                  sparse_pos_feats,
                                                  sparse_neg_feats)

            l2_reg = tf.add_n([
                self.args.linear_reg * tf.reduce_sum(tf.square(var_linear)),
                self.args.emb_reg * tf.reduce_sum(tf.square(var_emb_factors)),
            ])

            # BPR training op (add 1e-10 to help numerical stability)
            bprloss_op = tf.reduce_sum(
                tf.log(1e-10 + tf.sigmoid(pos_preds - neg_preds))) - l2_reg
            bprloss_op = -bprloss_op

            global_step = tf.Variable(0, trainable=False)
            learning_rate = tf.train.exponential_decay(
                self.args.starting_lr,
                global_step,
                self.args.lr_decay_freq,
                self.args.lr_decay_factor,
                staircase=False)
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            train_op = optimizer.minimize(bprloss_op, global_step=global_step)

            # AUC
            binary_ranks = tf.to_float((pos_preds - neg_preds) > 0)
            auc_per_user = tf.segment_mean(binary_ranks, pl_user_list)
            auc_op = tf.divide(
                tf.reduce_sum(auc_per_user),
                tf.to_float(tf.size(tf.unique(pl_user_list)[0])))

        self.var_linear = var_linear
        self.var_emb_factors = var_emb_factors
        return (g, bprloss_op, optimizer, train_op, auc_op, l2_reg,
                placeholders)
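
A side note on the AUC block: tf.segment_mean expects sorted segment ids, so pl_user_list must list users in contiguous, non-decreasing blocks, one entry per (positive, negative) pair. A NumPy-only sketch of the same per-user averaging on toy data (all values here are invented for illustration):

import numpy as np

user_list = np.array([0, 0, 0, 1, 1])           # sorted user id per (pos, neg) pair
binary_ranks = np.array([1., 0., 1., 1., 1.])   # 1.0 where pos_pred > neg_pred
auc_per_user = np.array([binary_ranks[user_list == u].mean()
                         for u in np.unique(user_list)])
auc = auc_per_user.sum() / len(np.unique(user_list))
print(auc)  # 0.833...: (2/3 for user 0 + 1.0 for user 1) / 2 users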
Example #6
    def call(self, inputs):
        # Note that I is unused, because the layer cannot be used in graph
        # batch mode.
        if len(inputs) == 3:
            X, A, I = inputs
        else:
            X, A = inputs
            I = None

        # Check if the layer is operating in batch mode (X and A have rank 3)
        batch_mode = K.ndim(A) == 3

        # Optionally compute hidden layer
        if self.h is None:
            Hid = X
        else:
            Hid = K.dot(X, self.kernel_in)
            if self.use_bias:
                Hid = K.bias_add(Hid, self.bias_in)
            if self.activation is not None:
                Hid = self.activation(Hid)

        # Compute cluster assignment matrix
        S = K.dot(Hid, self.kernel_out)
        if self.use_bias:
            S = K.bias_add(S, self.bias_out)
        S = activations.softmax(
            S, axis=-1)  # Apply softmax to get cluster assignments

        # MinCut regularization
        A_pooled = ops.matmul_AT_B_A(S, A)
        num = tf.trace(A_pooled)

        D = ops.degree_matrix(A)
        den = tf.trace(ops.matmul_AT_B_A(S, D))
        cut_loss = -(num / den)
        if batch_mode:
            cut_loss = K.mean(cut_loss)
        self.add_loss(cut_loss)

        # Orthogonality regularization
        SS = ops.matmul_AT_B(S, S)
        I_S = tf.eye(self.k)
        ortho_loss = tf.norm(SS / tf.norm(SS, axis=(-1, -2)) -
                             I_S / tf.norm(I_S),
                             axis=(-1, -2))
        if batch_mode:
            ortho_loss = K.mean(ortho_loss)
        self.add_loss(ortho_loss)

        # Pooling
        X_pooled = ops.matmul_AT_B(S, X)
        A_pooled = tf.linalg.set_diag(A_pooled, tf.zeros(
            K.shape(A_pooled)[:-1]))  # Remove diagonal
        A_pooled = ops.normalize_A(A_pooled)

        output = [X_pooled, A_pooled]

        if I is not None:
            I_mean = tf.segment_mean(I, I)
            I_pooled = ops.tf_repeat_1d(I_mean, tf.ones_like(I_mean) * self.k)
            output.append(I_pooled)

        if self.return_mask:
            output.append(S)

        return output
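
A short note on the tf.segment_mean(I, I) idiom above: because I stores the graph index of every node in the disjoint union, the segment mean of I over itself collapses it to one entry per graph, and ops.tf_repeat_1d then repeats each entry self.k times to form the index vector of the pooled graphs. A NumPy-only sketch of the same bookkeeping (sizes are made up for illustration):

import numpy as np

I = np.array([0, 0, 0, 1, 1])     # node -> graph index for a batch of two graphs
k = 2                             # clusters kept per graph
I_mean = np.array([I[I == g].mean() for g in np.unique(I)]).astype(I.dtype)
I_pooled = np.repeat(I_mean, k)   # [0, 0, 1, 1]: k pooled nodes per graph
print(I_pooled)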
Example #7
    def call(self, inputs):
        # Note that I is unused, because the layer cannot be used in graph
        # batch mode.
        if len(inputs) == 3:
            X, A, I = inputs
        else:
            X, A = inputs
            I = None

        N = K.shape(A)[-1]
        # Check if the layer is operating in batch mode (X and A have rank 3)
        batch_mode = K.ndim(A) == 3

        # Get normalized adjacency
        if K.is_sparse(A):
            I_ = tf.sparse.eye(N, dtype=A.dtype)
            A_ = tf.sparse.add(A, I_)
        else:
            I_ = tf.eye(N, dtype=A.dtype)
            A_ = A + I_
        fltr = ops.normalize_A(A_)

        # Node embeddings
        Z = K.dot(X, self.kernel_emb)
        Z = ops.filter_dot(fltr, Z)
        if self.activation is not None:
            Z = self.activation(Z)

        # Compute cluster assignment matrix
        S = K.dot(X, self.kernel_pool)
        S = ops.filter_dot(fltr, S)
        S = activations.softmax(S, axis=-1)  # softmax applied row-wise

        # Link prediction loss
        S_gram = ops.matmul_A_BT(S, S)
        if K.is_sparse(A):
            LP_loss = tf.sparse.add(
                A, -S_gram)  # A/tf.norm(A) - S_gram/tf.norm(S_gram)
        else:
            LP_loss = A - S_gram
        LP_loss = tf.norm(LP_loss, axis=(-1, -2))
        if batch_mode:
            LP_loss = K.mean(LP_loss)
        self.add_loss(LP_loss)

        # Entropy loss
        entr = tf.negative(
            tf.reduce_sum(tf.multiply(S, K.log(S + K.epsilon())), axis=-1))
        entr_loss = K.mean(entr, axis=-1)
        if batch_mode:
            entr_loss = K.mean(entr_loss)
        self.add_loss(entr_loss)

        # Pooling
        X_pooled = ops.matmul_AT_B(S, Z)
        A_pooled = ops.matmul_AT_B_A(S, A)

        output = [X_pooled, A_pooled]

        if I is not None:
            I_mean = tf.segment_mean(I, I)
            I_pooled = ops.tf_repeat_1d(I_mean, tf.ones_like(I_mean) * self.k)
            output.append(I_pooled)

        if self.return_mask:
            output.append(S)

        return output
Example #8
    word_embeddings = tf.Variable(
        tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
    word_embeddings = tf.concat(
        [word_embeddings, tf.zeros((1, embedding_size))], 0)
    doc_embeddings = tf.Variable(
        tf.random_uniform([len_docs, embedding_size], -1.0, 1.0))

    softmax_weights = tf.Variable(
        tf.truncated_normal([vocabulary_size, softmax_width],
                            stddev=1.0 / np.sqrt(embedding_size)))
    softmax_biases = tf.Variable(tf.zeros([vocabulary_size]))

    # Model.
    # Look up embeddings for inputs.
    embed_words = tf.segment_mean(
        tf.nn.embedding_lookup(word_embeddings, train_word_dataset),
        segment_ids)
    embed_docs = tf.nn.embedding_lookup(doc_embeddings, train_doc_dataset)
    embed = (embed_words + embed_docs) / 2.0  #+embed_hash+embed_users

    # Compute the softmax loss, using a sample of the negative labels each time.
    loss = tf.reduce_mean(
        tf.nn.nce_loss(softmax_weights, softmax_biases, train_labels, embed,
                       num_sampled, vocabulary_size))

    # Optimizer.
    optimizer = tf.train.AdagradOptimizer(0.5).minimize(loss)

    norm = tf.sqrt(tf.reduce_sum(tf.square(doc_embeddings), 1, keep_dims=True))
    normalized_doc_embeddings = doc_embeddings / norm
Example #9
    def __init__(self):
        with tf.variable_scope('knet'):
            with tf.variable_scope('preprocessing'):
                # generate useful box transformations (once)
                self.dets_boxdata = self._xyxy_to_boxdata(self.dets)
                self.gt_boxdata = self._xyxy_to_boxdata(self.gt_boxes)

                # overlaps
                self.det_anno_iou = self._iou(self.dets_boxdata,
                                              self.gt_boxdata, self.gt_crowd)
                self.det_det_iou = self._iou(self.dets_boxdata,
                                             self.dets_boxdata)
                if self.multiclass:
                    # set overlaps of detection and annotations to 0 if they
                    # have different classes, so they don't get matched in the
                    # loss
                    print('doing multiclass NMS')
                    same_class = tf.equal(
                        tf.reshape(self.det_classes, [-1, 1]),
                        tf.reshape(self.gt_classes, [1, -1]))
                    zeros = tf.zeros_like(self.det_anno_iou)
                    self.det_anno_iou = tf.where(same_class,
                                                 self.det_anno_iou, zeros)
                else:
                    print('doing single class NMS')

                # find neighbors
                self.neighbor_pair_idxs = tf.where(
                    tf.greater_equal(self.det_det_iou,
                                     cfg.gnet.neighbor_thresh))
                pair_c_idxs = self.neighbor_pair_idxs[:, 0]
                pair_n_idxs = self.neighbor_pair_idxs[:, 1]

                # generate handcrafted pairwise features
                self.num_dets = tf.shape(self.dets)[0]
                pw_feats = self._geometry_feats(pair_c_idxs, pair_n_idxs)

            # pw_feats -> K
            with tf.variable_scope('K'):
                num_fc = 3
                feats = pw_feats
                dim = cfg.knet.pairfeat_dim
                for i in range(1, num_fc + 1):
                    feats = tf.contrib.layers.fully_connected(
                        inputs=feats,
                        num_outputs=dim,
                        activation_fn=tf.nn.relu,
                        weights_initializer=weights_init,
                        weights_regularizer=weight_reg,
                        biases_initializer=biases_init,
                        scope='fc{}'.format(i))
                dim = cfg.knet.feat_dim * cfg.knet.feat_dim
                feats = tf.contrib.layers.fully_connected(
                    inputs=feats,
                    num_outputs=dim,
                    activation_fn=None,
                    weights_initializer=weights_init,
                    weights_regularizer=weight_reg,
                    biases_initializer=biases_init,
                    scope='fc{}'.format(num_fc + 1))
                K = tf.segment_mean(feats, pair_c_idxs, name='mean')

            # imfeats -> f(x)
            with tf.variable_scope('imfeats'):
                self.imfeats, stride, self._ignore_prefixes = get_resnet(
                    self.image)
                self.det_imfeats = crop_windows(self.imfeats,
                                                self.dets_boxdata, stride)
                self.det_imfeats = tf.contrib.layers.flatten(self.det_imfeats)

                feats = tf.contrib.layers.fully_connected(
                    inputs=self.det_imfeats,
                    num_outputs=cfg.knet.feat_dim,
                    activation_fn=tf.nn.relu,
                    weights_initializer=weights_init,
                    weights_regularizer=weight_reg,
                    biases_initializer=biases_init,
                    scope='reduce')

            f_new = tf.matmul(K, feats)
            with tf.variable_scope('predict'):
                self.prediction = tf.contrib.layers.fully_connected(
                    inputs=f_new,
                    num_outputs=1,
                    activation_fn=None,
                    weights_initializer=weights_init,
                    weights_regularizer=weight_reg,
                    biases_initializer=biases_init,
                    scope='reduce')

        with tf.variable_scope('loss'):
            self.loss()

        # collect trainable variables
        tvars = tf.trainable_variables()
        self.trainable_variables = [
            var for var in tvars
            if (var.name.startswith('gnet') or var.name.startswith('resnet'))
            and not any(
                var.name.startswith(pref) for pref in self._ignore_prefixes)
        ]
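
For the K block above, note that pair_c_idxs comes from tf.where over the IoU matrix, so it is already sorted, which is what the (sorted) tf.segment_mean op requires; the op averages the pairwise features of each detection over all of its neighbors. A NumPy-only sketch with toy numbers (feature values are invented):

import numpy as np

pair_c_idxs = np.array([0, 0, 1])                  # "current" detection of each neighbor pair
pw_feats = np.array([[1., 3.], [3., 5.], [2., 2.]])
K = np.stack([pw_feats[pair_c_idxs == c].mean(axis=0)
              for c in np.unique(pair_c_idxs)])
# K == [[2., 4.], [2., 2.]]: one averaged pairwise-context vector per detection
print(K)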
Example #11
def segment_mean(value):
    shape = image_util.get_shape(value)
    test_batch_size_per_gpu = (FLAGS.batch_size // FLAGS.num_test_crops) // FLAGS.num_gpus
    value = tf.segment_mean(
        value,
        np.repeat(np.arange(test_batch_size_per_gpu), FLAGS.num_test_crops))
    value.set_shape((None,) + shape[1:])
    return value
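
Here the repeated indices average the num_test_crops predictions of each image back into one row per image. A toy NumPy version of the grouping (2 images, 3 crops each; sizes and values invented for illustration):

import numpy as np

num_images, num_test_crops = 2, 3
crop_values = np.arange(6, dtype=np.float32)             # one prediction per crop
seg = np.repeat(np.arange(num_images), num_test_crops)   # [0, 0, 0, 1, 1, 1]
per_image = np.array([crop_values[seg == i].mean() for i in range(num_images)])
print(per_image)  # [1., 4.]: each image's crops averaged into a single entry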
Example #12
def model_fn(features, labels, mode, params):
    is_training = mode == ModeKeys.TRAIN
    model_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    num_labels, learning_rate, use_tpu, use_one_hot_embeddings = (
        model_config.num_labels, model_config.learning_rate,
        model_config.use_tpu, model_config.use_one_hot_embeddings)
    init_checkpoint = (model_config.init_checkpoint
                       and model_config.init_checkpoint.split(':')[-1])

    steps_per_epoch = model_config.num_training_examples / model_config.batch_size
    num_train_steps = int(steps_per_epoch * model_config.epochs)
    num_warmup_steps = int(steps_per_epoch * model_config.warmup_proportion)

    assert model_config.num_queries == 2
    text_a = features['sentence']
    text_b = features['entity']
    label = features.get('label')

    with tf.device('/device:CPU:0'):
        to_dense = functools.partial(custom_bert.bert_sparse_to_dense,
                                     text_a,
                                     seq_length=model_config.seq_length,
                                     vocab_file_path=FLAGS.vocab_file,
                                     discard_text_b=True)
        input_ids, input_mask, segment_ids = map(tf.to_int32, to_dense(text_a))
        concat_int = tf.concat([input_ids, input_mask, segment_ids], axis=1)
        # unique along the batch dimension.
        unique_int, segment_idx = unique_2d(concat_int)
        input_ids, input_mask, segment_ids = tf.split(unique_int,
                                                      axis=1,
                                                      num_or_size_splits=3)
    model = modeling.BertModel(config=model_config,
                               is_training=is_training,
                               input_ids=input_ids,
                               input_mask=input_mask,
                               token_type_ids=segment_ids,
                               use_one_hot_embeddings=use_one_hot_embeddings)

    sequence_output = tf.gather(model.get_sequence_output(), segment_idx)
    orig_width = sequence_output.shape[-1].value
    hidden_size = orig_width
    expt_flags = dict(
        t.split(':')
        for t in filter(None, model_config.experimental_flags.split(';')))
    # Only the first label column is used; later pword columns are treated as
    # features. label > -1 serves as the context feature mask.
    pword_context_aggregator = expt_flags.get('pword_context_aggregator')
    needle_embedding_aggregator = expt_flags.get('needle_embedding_aggregator')
    output_dim = int(expt_flags.get('output_dim', 1))  # 2 means using softmax
    if pword_context_aggregator:
        hidden_size += orig_width
    if eval(expt_flags.get('concat_first_embedding', 'False')):
        hidden_size += orig_width
    output_weights, output_biases = [], []
    with tf.variable_scope("loss"):
        output_layers = list(
            map(int,
                filter(None,
                       expt_flags.get('output_layers', '').split(','))))
        output_layers = [hidden_size] + output_layers + [output_dim]
        for i, (a, b) in enumerate(zip(output_layers[:-1], output_layers[1:])):
            suffix = '' if i == 0 else '_%d' % i
            output_weights.append(
                tf.get_variable(
                    'output_weights' + suffix, [b, a],
                    initializer=tf.truncated_normal_initializer(stddev=0.02)))
            output_biases.append(
                tf.get_variable('output_bias' + suffix, [b],
                                initializer=tf.zeros_initializer()))

    def mlp(net, weights, biases):
        for i, (w, b) in enumerate(zip(weights, biases)):
            dropout_rate = float(expt_flags.get('mlp_dropout_rate', 0.0))
            if dropout_rate > 0.0 and is_training:
                net = modeling.dropout(net, dropout_rate)
            if eval(expt_flags.get('mlp_layer_norm', 'False')):
                net = modeling.layer_norm(net)
            net = tf.nn.bias_add(tf.matmul(net, w, transpose_b=True), b)
            if i < len(weights) - 1:
                net = modeling.gelu(net)
        return net

    output_layers = []

    # batch x num needles
    batch_idx, needle_idx, start_pos, needle_widths = map(
        tf.to_int32, (sparse_sequence_match(text_a, text_b)))
    # Even with preprocessing, some sentences may not have any matched entity.
    if eval(expt_flags.get('fill_missing', 'False')):
        batch_idx, needle_idx, start_pos = fill_missing(
            batch_idx, needle_idx, start_pos, text_a, num_labels)

    # batch x num needles x needle width
    batch_idx2, needle_idx2, needle_pos = flatten_needle(
        batch_idx, needle_idx, start_pos, needle_widths)
    # batch x needle idx x haystack pos
    # sequence output: batch x sequence length x embedding dim.
    indices = tf.stack([batch_idx2, needle_pos], axis=1)
    # (batch x num needles x needle width) x embedding_dim
    embeddings = tf.gather_nd(sequence_output, indices)
    needle_idx3 = bash_uniq_with_counts([batch_idx2, needle_idx2])[1]

    # (batch x num_needles) x embedding_dim
    output_aggregator = expt_flags.get('output_aggregator', 'segment_mean')
    output_layer = aggregate_embedding(embeddings,
                                       needle_idx3,
                                       output_aggregator,
                                       name='aggregated_entity_embedding',
                                       aux={
                                           'needle_pos': needle_pos,
                                           'sequence_output': sequence_output,
                                           'batch_idx2': batch_idx2,
                                           'is_training': is_training
                                       },
                                       config=model_config)
    if eval(expt_flags.get('concat_first_embedding', 'False')):
        pooled_layer = tf.gather(model.get_pooled_output(), segment_idx)
        if not output_aggregator.startswith('transformer'):
            pooled_layer = tf.gather(pooled_layer, to_vec(batch_idx))
        output_layer = tf.concat([output_layer, pooled_layer], axis=1)
    if needle_embedding_aggregator:
        assert not output_aggregator.startswith('transformer'), (
            'transformer aggregator already aggregates needles in the same row!'
        )
        output_layer = aggregate_embedding(output_layer, to_vec(batch_idx),
                                           needle_embedding_aggregator)
    if is_training and float(expt_flags.get('mlp_dropout_rate', 0.0)) <= 0.0:
        # I.e., 0.1 dropout
        output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

    logits2 = mlp(output_layer, output_weights, output_biases)
    if needle_embedding_aggregator or output_aggregator.startswith(
            'transformer'):
        logits = tf.identity(logits2, name='mean_logits')
    else:
        logits = tf.segment_mean(tf.reshape(logits2, [-1, output_dim]),
                                 to_vec(batch_idx),
                                 name='mean_logits')
    loss, per_example_loss = None, None
    if mode != ModeKeys.PREDICT:
        if output_dim == 1:
            per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=to_vec(labels[:, 0]), logits=to_vec(logits))
        else:  # pairwise softmax
            assert output_dim == 2
            dense_log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(tf.to_int32(labels[:, 0]),
                                        depth=output_dim,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * dense_log_probs,
                                              axis=-1)
            dense_probs = tf.maximum(
                1e-8, tf.minimum(1.0 - 1e-8, tf.exp(dense_log_probs[:, 1])))
            logits = -tf.log(1. / dense_probs - 1.)
        per_query_loss = tf.reduce_sum(per_example_loss, axis=-1)
        loss = tf.reduce_mean(per_query_loss)

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    if init_checkpoint:
        (assignment_map, initialized_variable_names
         ) = modeling.get_assignment_map_from_checkpoint(
             tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
        tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                        init_string)
    output_spec = None
    if mode == ModeKeys.TRAIN:
        train_op = optimization.create_optimizer(loss, learning_rate,
                                                 num_train_steps,
                                                 num_warmup_steps, use_tpu)
        output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                 loss=loss,
                                                 train_op=train_op)
    elif mode == ModeKeys.EVAL:

        def metric_fn(per_example_loss, labels, logits):
            metrics = {}

            def update_metrics(labels, logits, metrics, scope=None):
                probabilities = to_vec(tf.sigmoid(logits))
                predicted_classes = tf.to_int32(probabilities > 0.5)
                tmp = binary_classification_metrics(
                    to_vec(labels), {
                        'predicted_classes': predicted_classes,
                        'probabilities': probabilities,
                        'logits': to_vec(logits)
                    })
                for k, v in tmp.items():
                    metrics['%s/%s' % (scope, k) if scope else k] = v

            update_metrics(labels, logits, metrics)
            for i in range(num_labels):
                update_metrics(labels[:, i], logits[:, i], metrics,
                               'column_%d' % i)
            return metrics

        eval_metric_ops = metric_fn(per_example_loss, labels, logits)
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
    else:  # PREDICT
        predictions = {
            "probabilities": tf.sigmoid(logits),
            "logits": logits,
        }
        output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                 predictions=predictions)
    return output_spec
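
As in the earlier examples, the final tf.segment_mean call averages the per-entity-match logits back into one logit row per input sentence, using the flattened batch_idx as the (sorted) segment ids. A toy NumPy version of that reduction (output_dim = 1, values invented):

import numpy as np

batch_idx = np.array([0, 0, 1])                 # entity match -> sentence index
logits2 = np.array([[0.2], [0.6], [1.0]])       # one logit per entity match
logits = np.stack([logits2[batch_idx == b].mean(axis=0)
                   for b in np.unique(batch_idx)])
print(logits)  # [[0.4], [1.0]]: per-sentence mean over its entity matches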