Example #1
0
 def testGlobalAvgPool(self):
   """Checks global_avg_pool output shape for both NHWC and NCHW inputs.

   The (5, 10, 20, 10) input is shaped so that both layouts are expected
   to pool down to a [5, 10] result.
   """
   placeholder = tf.placeholder(tf.float32, (5, 10, 20, 10))
   for fmt in ('NHWC', 'NCHW'):
     pooled = network_utils.global_avg_pool(placeholder, fmt)
     self.assertEqual(pooled.shape, [5, 10])
Example #2
0
def _build_network_base(images, normal_cell, reduction_cell, num_classes,
                        hparams, is_training):
    """Constructs an AmoebaNet image model.

    Builds the stem, then ``hparams.num_cells`` normal (stride-1) cells, with
    a stride-2 reduction cell inserted before each cell index returned by
    ``network_utils.calc_reduction_layers``. It then builds an optional
    auxiliary head and a final ReLU -> global average pool -> dropout ->
    fully-connected layer.

    Args:
      images: Input image tensor. It is cast to bfloat16 when both the
        ``use_bp16`` and ``use_tpu`` hparams are truthy.
      normal_cell: Callable that builds a normal cell (called with stride=1).
      reduction_cell: Callable that builds a reduction cell (called with
        stride=2). It is also passed to the stem as the stem cell.
      num_classes: Number of outputs of the final fully-connected layer. The
        auxiliary head is only built when this is truthy.
      hparams: Hyperparameter object. This function reads ``num_cells``,
        ``num_reduction_layers``, ``use_aux_head`` and
        ``dense_dropout_keep_prob``, plus ``use_bp16`` / ``use_tpu`` via
        ``.get``. It is also forwarded to the stem and aux-head builders.
      is_training: The auxiliary head is only built when this is truthy.

    Returns:
      A ``(logits, end_points)`` tuple. ``logits`` is cast to float32.
      ``end_points`` holds ``'logits'``, ``'predictions'`` (softmax of the
      logits), and whatever ``_build_aux_head`` adds.
    """
    # NOTE(review): the key is spelled 'use_bp16' (bfloat16); it must match
    # the hparam name the caller sets — confirm before renaming.
    if hparams.get('use_bp16', False) and hparams.get('use_tpu', False):
        images = tf.cast(images, dtype=tf.bfloat16)
    end_points = {}
    filter_scaling_rate = 2
    # Find where to place the reduction cells or stride normal cells
    reduction_indices = network_utils.calc_reduction_layers(
        hparams.num_cells, hparams.num_reduction_layers)
    # The reduction cell doubles as the stem cell.
    stem_cell = reduction_cell

    net, cell_outputs = _imagenet_stem(images, hparams, stem_cell,
                                       filter_scaling_rate)

    # Setup for building in the auxiliary head: it is attached after the
    # normal cell whose index immediately precedes the second reduction cell.
    aux_head_cell_idxes = []
    if len(reduction_indices) >= 2:
        aux_head_cell_idxes.append(reduction_indices[1] - 1)

    # Run the cells
    filter_scaling = 1.0
    # true_cell_num accounts for the stem cells
    true_cell_num = 2
    for cell_num in range(hparams.num_cells):
        tf.logging.info('Current cell num: {}'.format(true_cell_num))
        stride = 1

        if cell_num in reduction_indices:
            filter_scaling *= filter_scaling_rate
            prev_layer = cell_outputs[-2]
            net = reduction_cell(net,
                                 scope='reduction_cell_{}'.format(
                                     reduction_indices.index(cell_num)),
                                 filter_scaling=filter_scaling,
                                 stride=2,
                                 prev_layer=prev_layer,
                                 cell_num=true_cell_num)
            true_cell_num += 1
            cell_outputs.append(net)

        # Read prev_layer only after any reduction cell above has appended its
        # output, so the normal cell sees the two most recent outputs.
        prev_layer = cell_outputs[-2]
        net = normal_cell(net,
                          scope='cell_{}'.format(cell_num),
                          filter_scaling=filter_scaling,
                          stride=stride,
                          prev_layer=prev_layer,
                          cell_num=true_cell_num)
        true_cell_num += 1
        if (hparams.use_aux_head and cell_num in aux_head_cell_idxes
                and num_classes and is_training):
            aux_net = tf.nn.relu(net)
            _build_aux_head(aux_net,
                            end_points,
                            num_classes,
                            hparams,
                            scope='aux_{}'.format(cell_num))
        cell_outputs.append(net)
        tf.logging.info('net shape at layer {}: {}'.format(
            cell_num, net.shape))

    # Final softmax layer
    with tf.variable_scope('final_layer',
                           custom_getter=network_utils.bp16_getter):
        net = tf.nn.relu(net)
        net = network_utils.global_avg_pool(net)
        net = slim.dropout(net,
                           hparams.dense_dropout_keep_prob,
                           scope='dropout')
        # NOTE(review): slim.fully_connected defaults to a ReLU activation;
        # presumably an enclosing arg_scope sets activation_fn=None for the
        # logits — confirm. A falsy num_classes is not guarded here, unlike
        # the aux-head condition above.
        logits = slim.fully_connected(net, num_classes)
    logits = tf.cast(logits, tf.float32)
    predictions = tf.nn.softmax(logits, name='predictions')
    end_points['logits'] = logits
    end_points['predictions'] = predictions
    return logits, end_points