# Example 1
def _build_aux_head(net, end_points, num_classes, hparams, scope):
    """Auxiliary head used for all models across all datasets.

    Builds an auxiliary classifier branch off `net` and records its float32
    logits in `end_points` (mutated in place). Returns nothing.
    """
    # Channel-width multiplier for the aux head; 1.0 unless hparams
    # explicitly override it.
    # TODO(huangyp): double check aux_scaling with vrv@.
    aux_scaling = getattr(hparams, 'aux_scaling', 1.0)
    tf.logging.info('aux scaling: {}'.format(aux_scaling))
    with tf.variable_scope(scope, custom_getter=network_utils.bp16_getter):
        x = tf.identity(net)
        with tf.variable_scope('aux_logits'):
            x = slim.avg_pool2d(x, [5, 5], stride=3, padding='VALID')
            x = slim.conv2d(x, int(128 * aux_scaling), [1, 1], scope='proj')
            x = network_utils.batch_norm(x, scope='aux_bn0')
            x = tf.nn.relu(x)
            # Use a kernel spanning the full remaining spatial extent so the
            # VALID conv below collapses the feature map to 1x1. The spatial
            # axes depend on the data format (NHWC vs NCHW).
            spatial_shape = (x.shape[1:3] if hparams.data_format == 'NHWC'
                             else x.shape[2:4])
            x = slim.conv2d(x,
                            int(768 * aux_scaling),
                            spatial_shape,
                            padding='VALID')
            x = network_utils.batch_norm(x, scope='aux_bn1')
            x = tf.nn.relu(x)
            x = tf.contrib.layers.flatten(x)
            x = slim.fully_connected(x, num_classes)
            # A model may have two aux heads; the second one gets a
            # distinct end-point name.
            key = ('aux_logits_2' if 'aux_logits' in end_points
                   else 'aux_logits')
            end_points[key] = tf.cast(x, tf.float32)
# Example 2
def _imagenet_stem(inputs, hparams, stem_cell, filter_scaling_rate):
    """Stem used for models trained on ImageNet.

    Applies a stride-2 VALID conv + batch norm, then `num_stem_cells`
    stride-2 reduction cells. Returns the final tensor and the list of
    cell outputs (starting with [None, conv_output]).
    """
    num_stem_cells = 2

    # Initial reduction conv; presumably 299x299x3 -> 149x149 spatially
    # for standard ImageNet inputs — TODO confirm with callers.
    num_stem_filters = hparams.stem_reduction_size
    with tf.variable_scope('stem', custom_getter=network_utils.bp16_getter):
        net = slim.conv2d(inputs,
                          num_stem_filters, [3, 3],
                          stride=2,
                          padding='VALID',
                          scope='conv0')
        net = network_utils.batch_norm(net, scope='conv0_bn')
        tf.logging.info('imagenet_stem shape: {}'.format(net.shape))

    # Run the reduction cells. Filter scaling starts at
    # rate**-num_stem_cells and is multiplied by the rate after each
    # cell, so it reaches 1.0 once the stem is done.
    cell_outputs = [None, net]
    scaling = 1.0 / (filter_scaling_rate**num_stem_cells)
    for idx in range(num_stem_cells):
        net = stem_cell(net,
                        scope='cell_stem_{}'.format(idx),
                        filter_scaling=scaling,
                        stride=2,
                        prev_layer=cell_outputs[-2],
                        cell_num=idx)
        cell_outputs.append(net)
        scaling *= filter_scaling_rate
        tf.logging.info(
            'imagenet_stem net shape at reduction layer {}: {}'.format(
                idx, net.shape))
    return net, cell_outputs
# Example 3
def _basic_stem(inputs, hparams):
  """Basic stem: a single stride-1 VALID 3x3 conv followed by batch norm.

  Returns the stem output and a two-element cell-output list
  ([None, net]) matching the shape _imagenet_stem returns.
  """
  with tf.variable_scope('stem', custom_getter=network_utils.bp16_getter):
    net = slim.conv2d(inputs,
                      hparams.stem_reduction_size, [3, 3],
                      stride=1,
                      padding='VALID',
                      scope='conv0')
    net = network_utils.batch_norm(net, scope='conv0_bn')
    tf.logging.info('basic_stem shape: {}'.format(net.shape))
  return net, [None, net]