Example #1
    def test_masked_sum_nd(self):
        tf.reset_default_graph()

        data = tf.placeholder(tf.float32, shape=[None, None, None])
        mask = tf.placeholder(tf.float32, shape=[None, None])
        masked_sums = utils.masked_sum_nd(data, mask)

        with self.test_session() as sess:
            result = sess.run(masked_sums,
                              feed_dict={
                                  data: [[[1, 2], [3, 4], [5, 6]],
                                         [[7, 8], [9, 10], [11, 12]]],
                                  mask: [[1, 0, 1], [0, 1, 0]]
                              })
            self.assertAllClose(result, [[[6, 8]], [[9, 10]]])
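For reference, a minimal NumPy sketch of the behaviour this test exercises; it assumes masked_sum_nd zeroes out masked entries, sums over dim=1, and keeps the reduced dimension (masked_sum_nd_sketch is a hypothetical stand-in, not the library function):

import numpy as np

def masked_sum_nd_sketch(data, mask, dim=1):
  # Hypothetical re-implementation: broadcast the mask over the last axis,
  # zero out masked rows, sum over `dim`, and keep that dimension.
  return np.sum(data * np.expand_dims(mask, axis=-1), axis=dim, keepdims=True)

data = np.array([[[1, 2], [3, 4], [5, 6]],
                 [[7, 8], [9, 10], [11, 12]]], dtype=np.float32)
mask = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.float32)
print(masked_sum_nd_sketch(data, mask))  # [[[6. 8.]], [[9. 10.]]]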
Example #2
  def encode(self, feature, length, scope=None):
    """Encodes sequence features into representation.

    Args:
      feature: A [batch, max_sequence_len, dims] float tensor.
      length: A [batch] int tensor.
      scope: Name of the variable scope.

    Returns:
      A [batch, dims] float tensor.
    """
    options = self._model_proto
    is_training = self._is_training

    mask = tf.sequence_mask(
        length, maxlen=utils.get_tensor_shape(feature)[1], dtype=tf.float32)

    # Compute attention distribution.

    node = feature
    for i in range(options.hidden_layers):
      node = tf.contrib.layers.fully_connected(
          inputs=node,
          num_outputs=feature.get_shape()[-1].value,
          scope=scope + '/hidden_{}'.format(i))
    logits = tf.contrib.layers.fully_connected(
        inputs=node, num_outputs=1, activation_fn=None, scope=scope)

    probas = utils.masked_softmax(
        data=logits, mask=tf.expand_dims(mask, axis=-1), dim=1)
    feature = utils.masked_sum_nd(data=feature * probas, mask=mask, dim=1)

    # Summary.

    #tf.summary.histogram('attn/probas/' + scope, probas)
    #tf.summary.histogram('attn/logits/' + scope, logits)

    return tf.squeeze(feature, axis=1)
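As a rough guide to what this encoder computes, below is a minimal NumPy sketch of the attention-pooling step; it assumes masked_softmax drives masked positions toward zero probability and masked_sum_nd sums over the time dimension (attention_pool_sketch and its argument layout are assumptions, not part of the original code):

import numpy as np

def attention_pool_sketch(feature, logits, mask):
  # feature: [batch, max_len, dims], logits: [batch, max_len, 1], mask: [batch, max_len].
  mask = np.expand_dims(mask, axis=-1)                  # [batch, max_len, 1]
  logits = logits - 1e9 * (1.0 - mask)                  # suppress padded steps
  probas = np.exp(logits - logits.max(axis=1, keepdims=True))
  probas = probas / probas.sum(axis=1, keepdims=True)   # attention distribution over time
  return np.sum(feature * probas * mask, axis=1)        # [batch, dims]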
Example #3
    def build_loss(self, predictions, examples, **kwargs):
        """Build tf graph to compute loss.

    Args:
      predictions: dict of prediction results keyed by name.
      examples: dict of inputs keyed by name.

    Returns:
      loss_dict: dict of loss tensors keyed by name.
    """
        options = self._model_proto

        loss_dict = {}

        with tf.name_scope('losses'):

            # Extract image-level labels.

            if not options.caption_as_label:
                labels = self._extract_class_label(
                    class_texts=examples[InputDataFields.object_texts],
                    vocabulary_list=self._vocabulary_list)
            else:
                labels = self._extract_class_label(
                    class_texts=slim.flatten(
                        examples[InputDataFields.caption_strings]),
                    vocabulary_list=self._vocabulary_list)

            # A prediction model from caption to class

            # Loss of the multi-instance detection network.

            midn_class_logits = predictions[NOD2Predictions.midn_class_logits]
            losses = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=labels, logits=midn_class_logits)

            # Hard-negative mining.

            if options.midn_loss_negative_mining == nod2_model_pb2.NOD2Model.NONE:
                if options.classification_loss_use_sum:
                    assert False
                    loss_dict['midn_cross_entropy_loss'] = tf.multiply(
                        tf.reduce_mean(tf.reduce_sum(losses, axis=-1)),
                        options.midn_loss_weight)
                else:
                    if options.caption_as_label:
                        loss_masks = tf.to_float(
                            tf.reduce_any(labels > 0, axis=-1))
                        loss_dict['midn_cross_entropy_loss'] = tf.multiply(
                            tf.squeeze(
                                utils.masked_avg(tf.reduce_mean(losses,
                                                                axis=-1),
                                                 mask=loss_masks,
                                                 dim=0)),
                            options.midn_loss_weight)
                    else:
                        loss_dict['midn_cross_entropy_loss'] = tf.multiply(
                            tf.reduce_mean(losses), options.midn_loss_weight)
            elif options.midn_loss_negative_mining == nod2_model_pb2.NOD2Model.HARDEST:
                assert False
                loss_masks = self._midn_loss_mine_hardest_negative(
                    labels, losses)
                loss_dict['midn_cross_entropy_loss'] = tf.reduce_mean(
                    utils.masked_avg(data=losses, mask=loss_masks, dim=1))
            else:
                raise ValueError('Invalid negative mining method.')

            # Losses of the online instance classifier refinement network.

            (num_proposals,
             proposals) = (predictions[DetectionResultFields.num_proposals],
                           predictions[DetectionResultFields.proposal_boxes])
            batch, max_num_proposals, _ = utils.get_tensor_shape(proposals)

            proposal_scores_0 = predictions[
                NOD2Predictions.oicr_proposal_scores + '_at_0']
            if options.oicr_use_proba_r_given_c:
                proposal_scores_0 = predictions[
                    NOD2Predictions.midn_proba_r_given_c]

            proposal_scores_0 = tf.concat(
                [tf.fill([batch, max_num_proposals, 1], 0.0), proposal_scores_0],
                axis=-1)

            global_step = tf.train.get_or_create_global_step()
            oicr_loss_mask = tf.cast(global_step > options.oicr_start_step,
                                     tf.float32)

            for i in range(options.oicr_iterations):
                proposal_scores_1 = predictions[
                    NOD2Predictions.oicr_proposal_scores +
                    '_at_{}'.format(i + 1)]
                oicr_cross_entropy_loss_at_i = model_utils.calc_oicr_loss(
                    labels,
                    num_proposals,
                    proposals,
                    tf.stop_gradient(proposal_scores_0),
                    proposal_scores_1,
                    scope='oicr_{}'.format(i + 1),
                    iou_threshold=options.oicr_iou_threshold)
                loss_dict['oicr_cross_entropy_loss_at_{}'.format(
                    i + 1)] = tf.multiply(
                        oicr_loss_mask * oicr_cross_entropy_loss_at_i,
                        options.oicr_loss_weight)

                proposal_scores_0 = tf.nn.softmax(proposal_scores_1, axis=-1)

            # Min-entropy loss.

            mask = tf.sequence_mask(num_proposals,
                                    maxlen=max_num_proposals,
                                    dtype=tf.float32)
            proba_r_given_c = predictions[NOD2Predictions.midn_proba_r_given_c]
            losses = tf.log(proba_r_given_c + _EPSILON)
            losses = tf.squeeze(utils.masked_sum_nd(data=losses,
                                                    mask=mask,
                                                    dim=1),
                                axis=1)
            min_entropy_loss = tf.reduce_mean(
                tf.reduce_sum(losses * labels, axis=1))
            min_entropy_loss = tf.multiply(min_entropy_loss,
                                           options.min_entropy_loss_weight)

            max_proba = tf.reduce_mean(
                utils.masked_maximum(data=proba_r_given_c,
                                     mask=tf.expand_dims(mask, -1),
                                     dim=1))
            tf.losses.add_loss(min_entropy_loss)

        tf.summary.scalar('loss/min_entropy_loss', min_entropy_loss)
        tf.summary.scalar('loss/max_proba', max_proba)

        return loss_dict
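For readability, the min-entropy term assembled at the end of this loss can be restated compactly. The sketch below is a hypothetical NumPy restatement (before the min_entropy_loss_weight multiplier), assuming masked_sum_nd sums the log-probabilities over the proposal dimension:

import numpy as np

def min_entropy_loss_sketch(proba_r_given_c, mask, labels, eps=1e-8):
  # proba_r_given_c: [batch, max_proposals, classes], mask: [batch, max_proposals],
  # labels: [batch, classes] multi-hot.
  log_p = np.log(proba_r_given_c + eps)                 # [batch, R, C]
  per_class = np.sum(log_p * mask[:, :, None], axis=1)  # masked sum over proposals
  return np.mean(np.sum(per_class * labels, axis=1))    # average over the batch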
Example #4
  def build_prediction(self, examples, **kwargs):
    """Builds tf graph for prediction.

    Args:
      examples: dict of input tensors keyed by name.
      **kwargs: additional keyword arguments, e.g. the specific prediction task.

    Returns:
      predictions: dict of prediction results keyed by name.
    """
    options = self._model_proto
    is_training = self._is_training

    # Image CNN features.

    inputs = examples[InputDataFields.image]
    image_features = model_utils.calc_cnn_feature(
        inputs, options.cnn_options, is_training=is_training)

    with slim.arg_scope(
        build_hyperparams(options.image_fc_hyperparams, is_training)):
      image_features = slim.fully_connected(
          image_features,
          num_outputs=options.shared_dims,
          activation_fn=None,
          scope='image')

    # Text Global-Average-Pooling features.

    (image_id, num_captions, caption_strings,
     caption_lengths) = (examples[InputDataFields.image_id],
                         examples[InputDataFields.num_captions],
                         examples[InputDataFields.caption_strings],
                         examples[InputDataFields.caption_lengths])
    image_id = tf.string_to_number(image_id, out_type=tf.int64)

    (image_ids_gathered, caption_strings_gathered,
     caption_lengths_gathered) = model_utils.gather_in_batch_captions(
         image_id, num_captions, caption_strings, caption_lengths)

    (caption_token_ids_gathered,
     caption_features_gathered) = self._extract_text_feature(
         caption_strings_gathered,
         caption_lengths_gathered,
         vocabulary_list=self._open_vocabulary_list,
         initial_embedding=self._open_vocabulary_initial_embedding,
         embedding_dims=options.embedding_dims,
         trainable=options.train_word_embedding,
         max_norm=None)

    with slim.arg_scope(
        build_hyperparams(options.text_fc_hyperparams, is_training)):
      if visual_w2v_model_pb2.VisualW2vModel.ATT == options.text_feature_extractor:
        attn = slim.fully_connected(
            caption_features_gathered,
            num_outputs=1,
            activation_fn=None,
            scope='caption_attn')
        attn = tf.squeeze(attn, axis=-1)
      caption_features_gathered = slim.fully_connected(
          caption_features_gathered,
          num_outputs=options.shared_dims,
          activation_fn=None,
          scope='caption')

    oov = len(self._open_vocabulary_list)
    caption_masks_gathered = tf.logical_not(
        tf.equal(caption_token_ids_gathered, oov))
    caption_masks_gathered = tf.to_float(caption_masks_gathered)

    if visual_w2v_model_pb2.VisualW2vModel.GAP == options.text_feature_extractor:
      caption_features_gathered = utils.masked_avg_nd(
          data=caption_features_gathered, mask=caption_masks_gathered, dim=1)
      caption_features_gathered = tf.squeeze(caption_features_gathered, axis=1)
    elif visual_w2v_model_pb2.VisualW2vModel.ATT == options.text_feature_extractor:
      attn = utils.masked_softmax(attn, mask=caption_masks_gathered, dim=-1)
      caption_features_gathered = tf.multiply(
          tf.expand_dims(attn, axis=-1), caption_features_gathered)
      caption_features_gathered = utils.masked_sum_nd(
          caption_features_gathered, mask=caption_masks_gathered, dim=1)
      caption_features_gathered = tf.squeeze(caption_features_gathered, axis=1)
    else:
      raise ValueError('Invalid text feature extractor.')

    # Export token embeddings.

    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      _, token_embeddings = self._encode_tokens(
          tokens=tf.constant(self._open_vocabulary_list),
          embedding_dims=options.embedding_dims,
          vocabulary_list=self._open_vocabulary_list,
          initial_embedding=self._open_vocabulary_initial_embedding,
          trainable=options.train_word_embedding)
      with slim.arg_scope(
          build_hyperparams(options.text_fc_hyperparams, is_training)):
        token_embeddings = slim.fully_connected(
            token_embeddings,
            num_outputs=options.shared_dims,
            activation_fn=None,
            scope='caption')
    var_to_assign = tf.get_variable(
        name='weights_proj',
        shape=[len(self._open_vocabulary_list), options.shared_dims])
    var_to_assign = tf.assign(var_to_assign, token_embeddings)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, var_to_assign)

    tf.summary.histogram('token_embedding_proj', token_embeddings)

    # Compute similarity.

    similarity = model_utils.calc_pairwise_similarity(
        feature_a=image_features,
        feature_b=caption_features_gathered,
        l2_normalize=True,
        dropout_keep_prob=options.cross_modal_dropout_keep_prob,
        is_training=is_training)

    predictions = {
        VisualW2vPredictions.image_id: image_id,
        VisualW2vPredictions.image_ids_gathered: image_ids_gathered,
        VisualW2vPredictions.similarity: similarity,
        VisualW2vPredictions.word2vec: var_to_assign,
    }
    return predictions
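The GAP branch above relies on utils.masked_avg_nd to average token features over the valid (in-vocabulary) positions. A minimal sketch of what that helper likely computes, with masked_avg_nd_sketch as a hypothetical stand-in:

import numpy as np

def masked_avg_nd_sketch(data, mask, dim=1, eps=1e-8):
  # Sum the unmasked entries over `dim` and divide by the number of valid entries.
  mask = np.expand_dims(mask, axis=-1)
  total = np.sum(data * mask, axis=dim, keepdims=True)
  count = np.sum(mask, axis=dim, keepdims=True)
  return total / np.maximum(count, eps)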
Example #5
    def build_loss(self, predictions, examples, **kwargs):
        """Build tf graph to compute loss.

    Args:
      predictions: dict of prediction results keyed by name.
      examples: dict of inputs keyed by name.

    Returns:
      loss_dict: dict of loss tensors keyed by name.
    """
        options = self._model_proto

        loss_dict = {}

        with tf.name_scope('losses'):

            # Extract image-level labels.

            labels = self._extract_class_label(
                class_texts=slim.flatten(predictions[
                    NOD3Predictions.training_only_caption_strings]),
                vocabulary_list=self._vocabulary_list)

            # A prediction model from caption to class

            # Loss of the multi-instance detection network.

            midn_class_logits = predictions[NOD3Predictions.midn_class_logits]
            losses = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=labels, logits=midn_class_logits)

            # Hard-negative mining.

            if options.midn_loss_negative_mining == nod3_model_pb2.NOD3Model.NONE:
                if options.classification_loss_use_sum:
                    assert False
                    loss_dict['midn_cross_entropy_loss'] = tf.multiply(
                        tf.reduce_mean(tf.reduce_sum(losses, axis=-1)),
                        options.midn_loss_weight)
                else:
                    if options.caption_as_label:
                        loss_masks = tf.to_float(
                            tf.reduce_any(labels > 0, axis=-1))
                        loss_dict['midn_cross_entropy_loss'] = tf.multiply(
                            tf.squeeze(
                                utils.masked_avg(tf.reduce_mean(losses,
                                                                axis=-1),
                                                 mask=loss_masks,
                                                 dim=0)),
                            options.midn_loss_weight)
                    else:
                        loss_dict['midn_cross_entropy_loss'] = tf.multiply(
                            tf.reduce_mean(losses), options.midn_loss_weight)
            elif options.midn_loss_negative_mining == nod3_model_pb2.NOD3Model.HARDEST:
                assert False
                loss_masks = self._midn_loss_mine_hardest_negative(
                    labels, losses)
                loss_dict['midn_cross_entropy_loss'] = tf.reduce_mean(
                    utils.masked_avg(data=losses, mask=loss_masks, dim=1))
            else:
                raise ValueError('Invalid negative mining method.')

            # Triplet loss
            if options.triplet_loss_weight > 0:
                (image_id, image_ids_gathered,
                 similarity) = (predictions[NOD3Predictions.image_id],
                                predictions[NOD3Predictions.image_id],
                                predictions[NOD3Predictions.similarity])

                distance = 1.0 - similarity
                pos_mask = tf.cast(
                    tf.equal(tf.expand_dims(image_id, axis=1),
                             tf.expand_dims(image_ids_gathered, axis=0)),
                    tf.float32)
                neg_mask = 1.0 - pos_mask
                distance_ap = utils.masked_maximum(distance, pos_mask)

                if options.triplet_loss_use_semihard:

                    # Use the semihard.

                    # negatives_outside: smallest D_an where D_an > D_ap.

                    mask = tf.cast(tf.greater(distance, distance_ap),
                                   tf.float32)
                    mask = mask * neg_mask
                    negatives_outside = utils.masked_minimum(distance, mask)

                    # negatives_inside: largest D_an.

                    negatives_inside = utils.masked_maximum(distance, neg_mask)

                    # distance_an: the semihard negatives.

                    mask_condition = tf.greater(
                        tf.reduce_sum(mask, axis=1, keepdims=True), 0.0)

                    distance_an = tf.where(mask_condition, negatives_outside,
                                           negatives_inside)

                else:

                    # Use the hardest.

                    distance_an = utils.masked_minimum(distance, neg_mask)

                losses = tf.maximum(
                    distance_ap - distance_an + options.triplet_loss_margin, 0)

                num_loss_examples = tf.count_nonzero(losses, dtype=tf.float32)
                triplet_loss = tf.reduce_mean(losses)

                loss_dict['triplet_loss'] = tf.multiply(
                    triplet_loss, options.triplet_loss_weight)

            # Losses of the online instance classifier refinement network.

            (num_proposals,
             proposals) = (predictions[DetectionResultFields.num_proposals],
                           predictions[DetectionResultFields.proposal_boxes])
            batch, max_num_proposals, _ = utils.get_tensor_shape(proposals)

            proposal_scores_0 = predictions[
                NOD3Predictions.oicr_proposal_scores + '_at_0']
            if options.oicr_use_proba_r_given_c:
                proposal_scores_0 = predictions[
                    NOD3Predictions.midn_proba_r_given_c]

            proposal_scores_0 = tf.concat(
                [tf.fill([batch, max_num_proposals, 1], 0.0), proposal_scores_0],
                axis=-1)

            global_step = tf.train.get_or_create_global_step()
            oicr_loss_mask = tf.cast(global_step > options.oicr_start_step,
                                     tf.float32)

            for i in range(options.oicr_iterations):
                proposal_scores_1 = predictions[
                    NOD3Predictions.oicr_proposal_scores +
                    '_at_{}'.format(i + 1)]
                oicr_cross_entropy_loss_at_i = model_utils.calc_oicr_loss(
                    labels,
                    num_proposals,
                    proposals,
                    tf.stop_gradient(proposal_scores_0),
                    proposal_scores_1,
                    scope='oicr_{}'.format(i + 1),
                    iou_threshold=options.oicr_iou_threshold)
                loss_dict['oicr_cross_entropy_loss_at_{}'.format(
                    i + 1)] = tf.multiply(
                        oicr_loss_mask * oicr_cross_entropy_loss_at_i,
                        options.oicr_loss_weight)

                proposal_scores_0 = tf.nn.softmax(proposal_scores_1, axis=-1)

            # Min-entropy loss.

            mask = tf.sequence_mask(num_proposals,
                                    maxlen=max_num_proposals,
                                    dtype=tf.float32)
            proba_r_given_c = predictions[NOD3Predictions.midn_proba_r_given_c]
            losses = tf.log(proba_r_given_c + _EPSILON)
            losses = tf.squeeze(utils.masked_sum_nd(data=losses,
                                                    mask=mask,
                                                    dim=1),
                                axis=1)
            min_entropy_loss = tf.reduce_mean(
                tf.reduce_sum(losses * labels, axis=1))
            min_entropy_loss = tf.multiply(min_entropy_loss,
                                           options.min_entropy_loss_weight)

            max_proba = tf.reduce_mean(
                utils.masked_maximum(data=proba_r_given_c,
                                     mask=tf.expand_dims(mask, -1),
                                     dim=1))
            tf.losses.add_loss(min_entropy_loss)

        if options.triplet_loss_weight > 0:
            tf.summary.scalar('loss/num_loss_examples', num_loss_examples)
        tf.summary.scalar('loss/min_entropy_loss', min_entropy_loss)
        tf.summary.scalar('loss/max_proba', max_proba)

        return loss_dict
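The semihard branch of the triplet loss above picks, for each anchor, the closest negative that is still farther than the hardest positive, and falls back to the farthest negative when no such negative exists. A hypothetical NumPy sketch of that selection, assuming masked_maximum/masked_minimum reduce over the last dimension with keepdims:

import numpy as np

def semihard_negative_sketch(distance, pos_mask, neg_mask):
  # distance, pos_mask, neg_mask: [batch, batch] float matrices.
  big = distance.max() + 1.0
  d_ap = np.max(np.where(pos_mask > 0, distance, -big), axis=1, keepdims=True)
  harder = (distance > d_ap) * neg_mask                  # negatives with D_an > D_ap
  outside = np.min(np.where(harder > 0, distance, big), axis=1, keepdims=True)
  inside = np.max(np.where(neg_mask > 0, distance, -big), axis=1, keepdims=True)
  return np.where(harder.sum(axis=1, keepdims=True) > 0, outside, inside)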
Example #6
    def build_loss(self, predictions, examples, **kwargs):
        """Build tf graph to compute loss.

    Args:
      predictions: dict of prediction results keyed by name.
      examples: dict of inputs keyed by name.

    Returns:
      loss_dict: dict of loss tensors keyed by name.
    """
        options = self._model_proto

        loss_dict = {}

        with tf.name_scope('losses'):

            # Extract image-level labels.

            assert options.caption_as_label

            vocabulary_list = self._vocabulary_list
            mapping = {
                'traffic light': 'stoplight',
                'fire hydrant': 'hydrant',
                'stop sign': 'sign',
                'parking meter': 'meter',
                'sports ball': 'ball',
                'baseball bat': 'bat',
                'baseball glove': 'glove',
                'tennis racket': 'racket',
                'wine glass': 'wineglass',
                'hot dog': 'hotdog',
                'potted plant': 'plant',
                'dining table': 'table',
                'cell phone': 'cellphone',
                'teddy bear': 'teddy',
                'hair drier': 'hairdryer',
            }
            vocabulary_list = [
                mapping.get(cls, cls) for cls in vocabulary_list
            ]

            labels_gt = self._extract_class_label(
                class_texts=slim.flatten(
                    examples[InputDataFields.caption_strings]),
                vocabulary_list=vocabulary_list)

            examples[NOD4Predictions.debug_groundtruth_labels] = labels_gt
            if options.label_strategem == nod4_model_pb2.NOD4Model.EXACTLY_MATCH:
                labels = labels_gt
            elif options.label_strategem == nod4_model_pb2.NOD4Model.W2V_SYNONYM_MATCH:
                labels_ps = self._extract_pseudo_label(
                    texts=slim.flatten(
                        examples[InputDataFields.caption_strings]),
                    vocabulary_list=vocabulary_list,
                    open_vocabulary_list=self._open_vocabulary_list,
                    embedding_dims=options.embedding_dims)
                select_op = tf.reduce_any(labels_gt > 0, axis=-1)
                labels = tf.where(select_op, labels_gt, labels_ps)
                labels_ps = tf.where(select_op, tf.zeros_like(labels_ps),
                                     labels_ps)
                examples[NOD4Predictions.debug_pseudo_labels] = labels_ps
            else:
                raise ValueError('Invalid label strategy')

            # Loss of the multi-instance detection network.

            midn_class_logits = predictions[NOD4Predictions.midn_class_logits]
            losses = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=labels, logits=midn_class_logits)

            # Hard-negative mining.

            if options.midn_loss_negative_mining == nod4_model_pb2.NOD4Model.NONE:
                if options.classification_loss_use_sum:
                    assert False
                    loss_dict['midn_cross_entropy_loss'] = tf.multiply(
                        tf.reduce_mean(tf.reduce_sum(losses, axis=-1)),
                        options.midn_loss_weight)
                else:
                    if options.caption_as_label:
                        loss_masks = tf.to_float(
                            tf.reduce_any(labels > 0, axis=-1))
                        loss_dict['midn_cross_entropy_loss'] = tf.multiply(
                            tf.squeeze(
                                utils.masked_avg(tf.reduce_mean(losses,
                                                                axis=-1),
                                                 mask=loss_masks,
                                                 dim=0)),
                            options.midn_loss_weight)
                    else:
                        loss_dict['midn_cross_entropy_loss'] = tf.multiply(
                            tf.reduce_mean(losses), options.midn_loss_weight)
            elif options.midn_loss_negative_mining == nod4_model_pb2.NOD4Model.HARDEST:
                assert False
                loss_masks = self._midn_loss_mine_hardest_negative(
                    labels, losses)
                loss_dict['midn_cross_entropy_loss'] = tf.reduce_mean(
                    utils.masked_avg(data=losses, mask=loss_masks, dim=1))
            else:
                raise ValueError('Invalid negative mining method.')

            # Losses of the online instance classifier refinement network.

            (num_proposals,
             proposals) = (predictions[DetectionResultFields.num_proposals],
                           predictions[DetectionResultFields.proposal_boxes])
            batch, max_num_proposals, _ = utils.get_tensor_shape(proposals)

            proposal_scores_0 = predictions[
                NOD4Predictions.oicr_proposal_scores + '_at_0']
            if options.oicr_use_proba_r_given_c:
                proposal_scores_0 = predictions[
                    NOD4Predictions.midn_proba_r_given_c]

            proposal_scores_0 = tf.concat(
                [tf.fill([batch, max_num_proposals, 1], 0.0), proposal_scores_0],
                axis=-1)

            global_step = tf.train.get_or_create_global_step()
            oicr_loss_mask = tf.cast(global_step > options.oicr_start_step,
                                     tf.float32)

            for i in range(options.oicr_iterations):
                proposal_scores_1 = predictions[
                    NOD4Predictions.oicr_proposal_scores +
                    '_at_{}'.format(i + 1)]
                oicr_cross_entropy_loss_at_i = model_utils.calc_oicr_loss(
                    labels,
                    num_proposals,
                    proposals,
                    tf.stop_gradient(proposal_scores_0),
                    proposal_scores_1,
                    scope='oicr_{}'.format(i + 1),
                    iou_threshold=options.oicr_iou_threshold)
                loss_dict['oicr_cross_entropy_loss_at_{}'.format(
                    i + 1)] = tf.multiply(
                        oicr_loss_mask * oicr_cross_entropy_loss_at_i,
                        options.oicr_loss_weight)

                proposal_scores_0 = tf.nn.softmax(proposal_scores_1, axis=-1)

            # Min-entropy loss.

            mask = tf.sequence_mask(num_proposals,
                                    maxlen=max_num_proposals,
                                    dtype=tf.float32)
            proba_r_given_c = predictions[NOD4Predictions.midn_proba_r_given_c]
            losses = tf.log(proba_r_given_c + _EPSILON)
            losses = tf.squeeze(utils.masked_sum_nd(data=losses,
                                                    mask=mask,
                                                    dim=1),
                                axis=1)
            min_entropy_loss = tf.reduce_mean(
                tf.reduce_sum(losses * labels, axis=1))
            min_entropy_loss = tf.multiply(min_entropy_loss,
                                           options.min_entropy_loss_weight)

            max_proba = tf.reduce_mean(
                utils.masked_maximum(data=proba_r_given_c,
                                     mask=tf.expand_dims(mask, -1),
                                     dim=1))
            tf.losses.add_loss(min_entropy_loss)

        tf.summary.scalar('loss/min_entropy_loss', min_entropy_loss)
        tf.summary.scalar('loss/max_proba', max_proba)

        return loss_dict
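The label strategy in this variant prefers exact caption matches and only falls back to word2vec pseudo labels for examples where no class matched. A small hypothetical NumPy sketch of that merge rule (merge_labels_sketch is an illustration, not part of the model):

import numpy as np

def merge_labels_sketch(labels_gt, labels_ps):
  # labels_gt, labels_ps: [batch, num_classes] multi-hot float arrays.
  has_gt = np.any(labels_gt > 0, axis=-1, keepdims=True)   # [batch, 1]
  labels = np.where(has_gt, labels_gt, labels_ps)          # exact matches win
  labels_ps = np.where(has_gt, np.zeros_like(labels_ps), labels_ps)  # pseudo labels kept only otherwise
  return labels, labels_ps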