예제 #1
0
def _build_nasnet_base(images,
                       normal_cell,
                       reduction_cell,
                       num_classes,
                       hparams,
                       is_training,
                       stem_type,
                       final_endpoint=None):
    """Constructs a NASNet image model."""

    end_points = {}

    def add_and_check_endpoint(endpoint_name, net):
        end_points[endpoint_name] = net
        return final_endpoint and (endpoint_name == final_endpoint)

    # Find where to place the reduction cells or stride normal cells
    reduction_indices = nasnet_utils.calc_reduction_layers(
        hparams.num_cells, hparams.num_reduction_layers)
    stem_cell = reduction_cell

    if stem_type == 'imagenet':
        stem = lambda: _imagenet_stem(images, hparams, stem_cell)
    elif stem_type == 'cifar':
        stem = lambda: _cifar_stem(images, hparams)
    else:
        raise ValueError('Unknown stem_type: ', stem_type)
    net, cell_outputs = stem()
    if add_and_check_endpoint('Stem', net): return net, end_points

    # Setup for building in the auxiliary head.
    aux_head_cell_idxes = []
    if len(reduction_indices) >= 2:
        aux_head_cell_idxes.append(reduction_indices[1] - 1)

    # Run the cells
    filter_scaling = 1.0
    # true_cell_num accounts for the stem cells
    true_cell_num = 2 if stem_type == 'imagenet' else 0
    for cell_num in range(hparams.num_cells):
        stride = 1
        if hparams.skip_reduction_layer_input:
            prev_layer = cell_outputs[-2]
        if cell_num in reduction_indices:
            filter_scaling *= hparams.filter_scaling_rate
            net = reduction_cell(net,
                                 scope='reduction_cell_{}'.format(
                                     reduction_indices.index(cell_num)),
                                 filter_scaling=filter_scaling,
                                 stride=2,
                                 prev_layer=cell_outputs[-2],
                                 cell_num=true_cell_num)
            if add_and_check_endpoint(
                    'Reduction_Cell_{}'.format(
                        reduction_indices.index(cell_num)), net):
                return net, end_points
            true_cell_num += 1
            cell_outputs.append(net)
        if not hparams.skip_reduction_layer_input:
            prev_layer = cell_outputs[-2]
        net = normal_cell(net,
                          scope='cell_{}'.format(cell_num),
                          filter_scaling=filter_scaling,
                          stride=stride,
                          prev_layer=prev_layer,
                          cell_num=true_cell_num)

        if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
            return net, end_points
        true_cell_num += 1
        if (hparams.use_aux_head and cell_num in aux_head_cell_idxes
                and num_classes and is_training):
            aux_net = tf.nn.relu(net)
            _build_aux_head(aux_net,
                            end_points,
                            num_classes,
                            hparams,
                            scope='aux_{}'.format(cell_num))
        cell_outputs.append(net)

    # Final softmax layer
    with tf.variable_scope('final_layer'):
        net = tf.nn.relu(net)
        net = nasnet_utils.global_avg_pool(net)
        if add_and_check_endpoint('global_pool', net) or num_classes is None:
            return net, end_points
        net = slim.dropout(net,
                           hparams.dense_dropout_keep_prob,
                           scope='dropout')
        logits = slim.fully_connected(net, num_classes)

        if add_and_check_endpoint('Logits', logits):
            return net, end_points

        predictions = tf.nn.softmax(logits, name='predictions')
        if add_and_check_endpoint('Predictions', predictions):
            return net, end_points
    return logits, end_points
예제 #2
0
def build_nasnet_base_my(images,
                         normal_cell,
                         reduction_cell,
                         hparams,
                         final_endpoint=None):
    """Constructs a NASNet feature extractor (ImageNet stem, no head).

    Builds the ImageNet stem followed by `hparams.num_cells` normal cells.
    Reduction cells are inserted at the positions returned by
    `nasnet_utils.calc_reduction_layers`. No pooling or classification layer
    is built.

    Besides the per-cell endpoints, two intermediate stem outputs are
    exposed: 'Stem_4' (the stem's cell_outputs[2]) and 'Stem_8'
    (cell_outputs[3]). The names suggest stride-4 / stride-8 feature maps;
    NOTE(review): confirm against `_imagenet_stem`.

    Args:
      images: Input tensor fed to the stem.
      normal_cell: Callable that builds a normal (stride-1) cell.
      reduction_cell: Callable that builds a reduction (stride-2) cell. It is
        also passed to the stem.
      hparams: Hyper-parameter object. This function reads `num_cells`,
        `num_reduction_layers`, `skip_reduction_layer_input` and
        `filter_scaling_rate`.
      final_endpoint: Optional endpoint name at which to stop building. When
        it is reached, the tensor recorded for that endpoint is returned.

    Returns:
      A `(tensor, end_points)` tuple. `tensor` is the output of the last
      cell built (or of the requested endpoint), and `end_points` maps
      endpoint names to tensors.
    """
    end_points = {}

    def add_and_check_endpoint(endpoint_name, net):
        """Records `net` under `endpoint_name`; truthy if building should stop."""
        end_points[endpoint_name] = net
        return final_endpoint and (endpoint_name == final_endpoint)

    # Find where to place the reduction cells or stride normal cells.
    reduction_indices = nasnet_utils.calc_reduction_layers(
        hparams.num_cells, hparams.num_reduction_layers)

    net, cell_outputs = _imagenet_stem(images, hparams, reduction_cell)

    # Route the intermediate stem outputs through add_and_check_endpoint so
    # they can also be requested as final_endpoint.
    stem4 = cell_outputs[2]
    if add_and_check_endpoint('Stem_4', stem4):
        return stem4, end_points
    stem8 = cell_outputs[3]
    if add_and_check_endpoint('Stem_8', stem8):
        return stem8, end_points
    if add_and_check_endpoint('Stem', net):
        return net, end_points

    # Run the cells.
    filter_scaling = 1.0
    # true_cell_num accounts for the stem cells.
    true_cell_num = 2
    for cell_num in range(hparams.num_cells):
        # With skip_reduction_layer_input, the normal cell's `prev_layer` is
        # captured before a reduction cell (if any) is appended to
        # cell_outputs.
        if hparams.skip_reduction_layer_input:
            prev_layer = cell_outputs[-2]
        if cell_num in reduction_indices:
            reduction_idx = reduction_indices.index(cell_num)
            filter_scaling *= hparams.filter_scaling_rate
            net = reduction_cell(net,
                                 scope='reduction_cell_{}'.format(reduction_idx),
                                 filter_scaling=filter_scaling,
                                 stride=2,
                                 prev_layer=cell_outputs[-2],
                                 cell_num=true_cell_num)
            if add_and_check_endpoint(
                    'Reduction_Cell_{}'.format(reduction_idx), net):
                return net, end_points
            true_cell_num += 1
            cell_outputs.append(net)
        if not hparams.skip_reduction_layer_input:
            prev_layer = cell_outputs[-2]
        net = normal_cell(net,
                          scope='cell_{}'.format(cell_num),
                          filter_scaling=filter_scaling,
                          stride=1,
                          prev_layer=prev_layer,
                          cell_num=true_cell_num)

        if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
            return net, end_points
        true_cell_num += 1
        cell_outputs.append(net)

    return net, end_points
예제 #3
0
def _build_pnasnet_base(images,
                        normal_cell,
                        num_classes,
                        hparams,
                        is_training,
                        final_endpoint=None):
    """Constructs a PNASNet image model.

    Builds the ImageNet stem, then `hparams.num_cells` cells, all built with
    `normal_cell`. Cells at the positions returned by
    `nasnet_utils.calc_reduction_layers` use stride 2 and an increased filter
    scaling. Finally it builds an optional classification head: ReLU ->
    global average pool -> dropout -> fully connected -> softmax.

    Args:
      images: Input tensor fed to the stem.
      normal_cell: Callable that builds every cell (and is passed to the stem).
      num_classes: Number of output classes. If falsy (None or 0), the
        pooled features are returned and no classification layer is built.
      hparams: Hyper-parameter object. This function reads `num_cells`,
        `num_reduction_layers`, `filter_scaling_rate`, `use_aux_head` and
        `dense_dropout_keep_prob`.
      is_training: In this function, only used to gate construction of the
        auxiliary head.
      final_endpoint: Optional endpoint name at which to stop building. When
        it is reached, the tensor recorded for that endpoint is returned.

    Returns:
      A `(tensor, end_points)` tuple. `tensor` is the logits when the whole
      network is built; otherwise it is the tensor of the endpoint where
      building stopped, or the pooled features when `num_classes` is falsy.
      `end_points` maps endpoint names to tensors.
    """
    end_points = {}

    def add_and_check_endpoint(endpoint_name, net):
        """Records `net` under `endpoint_name`; truthy if building should stop."""
        end_points[endpoint_name] = net
        return final_endpoint and (endpoint_name == final_endpoint)

    # Find where to place the reduction cells or stride normal cells.
    reduction_indices = nasnet_utils.calc_reduction_layers(
        hparams.num_cells, hparams.num_reduction_layers)

    # pylint: disable=protected-access
    net, cell_outputs = nasnet._imagenet_stem(images, hparams, normal_cell)
    # pylint: enable=protected-access
    if add_and_check_endpoint('Stem', net):
        return net, end_points

    # The auxiliary head hangs off the cell just before the second
    # reduction position.
    aux_head_cell_idxes = []
    if len(reduction_indices) >= 2:
        aux_head_cell_idxes.append(reduction_indices[1] - 1)

    # Run the cells.
    filter_scaling = 1.0
    # true_cell_num accounts for the stem cells.
    true_cell_num = 2
    for cell_num in range(hparams.num_cells):
        is_reduction = cell_num in reduction_indices
        stride = 2 if is_reduction else 1
        if is_reduction:
            filter_scaling *= hparams.filter_scaling_rate
        prev_layer = cell_outputs[-2]
        net = normal_cell(net,
                          scope='cell_{}'.format(cell_num),
                          filter_scaling=filter_scaling,
                          stride=stride,
                          prev_layer=prev_layer,
                          cell_num=true_cell_num)
        if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
            return net, end_points
        true_cell_num += 1
        cell_outputs.append(net)

        if (hparams.use_aux_head and cell_num in aux_head_cell_idxes
                and num_classes and is_training):
            aux_net = tf.nn.relu(net)
            # pylint: disable=protected-access
            nasnet._build_aux_head(aux_net,
                                   end_points,
                                   num_classes,
                                   hparams,
                                   scope='aux_{}'.format(cell_num))
            # pylint: enable=protected-access

    # Final softmax layer.
    with tf.variable_scope('final_layer'):
        net = tf.nn.relu(net)
        net = nasnet_utils.global_avg_pool(net)
        if add_and_check_endpoint('global_pool', net) or not num_classes:
            return net, end_points
        net = slim.dropout(net,
                           hparams.dense_dropout_keep_prob,
                           scope='dropout')
        logits = slim.fully_connected(net, num_classes)

        # Return the requested endpoint's own tensor, not the pre-logits
        # features.
        if add_and_check_endpoint('Logits', logits):
            return logits, end_points

        predictions = tf.nn.softmax(logits, name='predictions')
        if add_and_check_endpoint('Predictions', predictions):
            return predictions, end_points
    return logits, end_points