예제 #1
0
 def get_logits():
     """Build the target classification head and return its outputs.

     Takes no arguments: ``FLAGS``, ``model``, ``feature``, ``labels``,
     ``mode``, ``params``, ``bird_num_classes`` and ``tf`` are all resolved
     from the enclosing/module scope (presumably this is nested inside a
     model_fn -- confirm against the caller).

     Returns:
       A ``(logits, aux_logits, end_points)`` tuple. ``logits`` is the output
       of a dense layer with ``bird_num_classes`` units applied to the
       backbone output, created under variable scope ``'target_CLS'`` with
       layer name ``'final_dense_dst'``. ``aux_logits`` and ``end_points``
       are always ``None`` for the only supported backbone (``'resnet'``).

     Raises:
       ValueError: If ``FLAGS.model_type`` is not ``'resnet'``.
     """
     end_points, aux_logits = None, None
     if FLAGS.model_type == 'resnet':
         # NOTE(review): judging by the name this is a pooled feature tensor;
         # confirm against model.resnet_v1_model.
         avg_pool = model.resnet_v1_model(feature, labels, mode, params)
     else:
         # Raise explicitly instead of `assert False`: asserts are stripped
         # under `python -O`, which would leave `avg_pool` unbound and fail
         # later with a confusing UnboundLocalError.
         raise ValueError(
             'Unsupported model_type: {}'.format(FLAGS.model_type))
     name = 'final_dense_dst'
     with tf.variable_scope('target_CLS'):
         logits = tf.layers.dense(
             inputs=avg_pool,
             units=bird_num_classes,
             kernel_initializer=tf.random_normal_initializer(
                 stddev=.01),
             name=name)
         # Auxiliary head. `end_points` is never reassigned above, so this
         # branch is currently unreachable; presumably it is kept for a
         # backbone that supplies an 'AuxLogits_Pool' end point.
         if end_points is not None:
             aux_pool = end_points['AuxLogits_Pool']
             aux_logits = tf.layers.dense(
                 inputs=aux_pool,
                 units=bird_num_classes,
                 # Smaller init scale than the main head (.001 vs .01).
                 kernel_initializer=tf.random_normal_initializer(
                     stddev=.001),
                 name='Aux{}'.format(name))
     return logits, aux_logits, end_points
예제 #2
0
def get_logits(feature, params, mode):
    """Run the ResNet backbone and return its output without end points.

    Args:
      feature: Input forwarded to ``model.resnet_v1_model``.
      params: Hyperparameters forwarded to ``model.resnet_v1_model``.
      mode: Mode value forwarded to ``model.resnet_v1_model``.

    Returns:
      A two-tuple: the value produced by ``model.resnet_v1_model`` and
      ``None`` (this path exposes no end points).
    """
    # NOTE(review): the callee receives (feature, mode, params), a different
    # positional order than this function's own (feature, params, mode);
    # confirm this matches resnet_v1_model's signature.
    backbone_out = model.resnet_v1_model(feature, mode, params)
    return backbone_out, None