def _build_aux_head(net, end_points, num_classes, hparams, scope):
  """Auxiliary head used for all models across all datasets.

  Builds an auxiliary classifier on top of `net` and stores its logits
  (cast to float32) into `end_points` under the key 'aux_logits', or
  'aux_logits_2' if that key is already taken.

  Args:
    net: feature-map tensor the auxiliary head is attached to.
    end_points: dict of named endpoint tensors; updated in place.
    num_classes: number of classes for the final fully-connected layer.
    hparams: hyperparameters. `aux_scaling` (optional) multiplies the
      head's channel counts; `data_format` selects 'NHWC' vs NCHW axes.
    scope: name of the variable scope wrapping the head.
  """
  # Optional channel-width multiplier; 1.0 when hparams lacks the field.
  # TODO(huangyp): double check aux_scaling with vrv@.
  aux_scaling = getattr(hparams, 'aux_scaling', 1.0)
  tf.logging.info('aux scaling: {}'.format(aux_scaling))

  with tf.variable_scope(scope, custom_getter=network_utils.bp16_getter):
    aux_logits = tf.identity(net)
    with tf.variable_scope('aux_logits'):
      aux_logits = slim.avg_pool2d(
          aux_logits, [5, 5], stride=3, padding='VALID')
      aux_logits = slim.conv2d(
          aux_logits, int(128 * aux_scaling), [1, 1], scope='proj')
      aux_logits = network_utils.batch_norm(aux_logits, scope='aux_bn0')
      aux_logits = tf.nn.relu(aux_logits)
      # Spatial extent of the feature map before the final layer; using it
      # as the kernel size collapses the map to 1x1.
      if hparams.data_format == 'NHWC':
        spatial_shape = aux_logits.shape[1:3]
      else:
        spatial_shape = aux_logits.shape[2:4]
      aux_logits = slim.conv2d(
          aux_logits, int(768 * aux_scaling), spatial_shape, padding='VALID')
      aux_logits = network_utils.batch_norm(aux_logits, scope='aux_bn1')
      aux_logits = tf.nn.relu(aux_logits)
      aux_logits = tf.contrib.layers.flatten(aux_logits)
      aux_logits = slim.fully_connected(aux_logits, num_classes)
      end_point_name = (
          'aux_logits' if 'aux_logits' not in end_points else 'aux_logits_2')
      end_points[end_point_name] = tf.cast(aux_logits, tf.float32)
def _imagenet_stem(inputs, hparams, stem_cell, filter_scaling_rate):
  """Stem used for models trained on ImageNet.

  Applies a stride-2 3x3 convolution followed by two stride-2 reduction
  cells, logging the feature-map shape after each stage.

  Args:
    inputs: input image tensor.
    hparams: hyperparameters; `stem_reduction_size` gives the stem conv's
      filter count.
    stem_cell: callable that builds one reduction cell.
    filter_scaling_rate: multiplier applied to the filter scaling after
      each reduction cell.

  Returns:
    Tuple of (final stem tensor, list of cell outputs so far, beginning
    with [None, stem conv output]).
  """
  num_stem_cells = 2

  # 149 x 149 x 32
  num_stem_filters = hparams.stem_reduction_size
  with tf.variable_scope('stem', custom_getter=network_utils.bp16_getter):
    net = slim.conv2d(
        inputs, num_stem_filters, [3, 3],
        stride=2, scope='conv0', padding='VALID')
    net = network_utils.batch_norm(net, scope='conv0_bn')
  tf.logging.info('imagenet_stem shape: {}'.format(net.shape))

  # Run the reduction cells
  cell_outputs = [None, net]
  # Start below 1.0 so that after num_stem_cells multiplications the
  # scaling reaches 1.0 at the first post-stem cell.
  scaling = 1.0 / (filter_scaling_rate**num_stem_cells)
  for idx in range(num_stem_cells):
    net = stem_cell(
        net,
        scope='cell_stem_{}'.format(idx),
        filter_scaling=scaling,
        stride=2,
        prev_layer=cell_outputs[-2],
        cell_num=idx)
    cell_outputs.append(net)
    scaling *= filter_scaling_rate
    tf.logging.info(
        'imagenet_stem net shape at reduction layer {}: {}'.format(
            idx, net.shape))
  return net, cell_outputs
def _basic_stem(inputs, hparams):
  """Simple stem: a single stride-1 3x3 conv plus batch norm.

  Args:
    inputs: input image tensor.
    hparams: hyperparameters; `stem_reduction_size` gives the conv's
      filter count.

  Returns:
    Tuple of (stem output tensor, [None, stem output]) matching the
    cell-output list shape produced by the other stems.
  """
  num_stem_filters = hparams.stem_reduction_size
  with tf.variable_scope('stem', custom_getter=network_utils.bp16_getter):
    net = slim.conv2d(
        inputs, num_stem_filters, [3, 3],
        stride=1, scope='conv0', padding='VALID')
    net = network_utils.batch_norm(net, scope='conv0_bn')
  tf.logging.info('basic_stem shape: {}'.format(net.shape))
  return net, [None, net]