def build_nasnet_large(images, num_classes,
                       is_training=True,
                       final_endpoint=None,
                       config=None,
                       current_step=None):
    """Construct the NASNet Large network for the ImageNet dataset.

    Args:
      images: Input image tensor. It is permuted with axes [0, 3, 1, 2] when
        the hparams request 'NCHW', so it is presumably supplied in NHWC
        layout (confirm against callers).
      num_classes: Forwarded unchanged to `_build_nasnet_base`.
      is_training: Passed to `_update_hparams`, to the dropout / drop-path /
        batch-norm arg scope, and to `_build_nasnet_base`.
      final_endpoint: Forwarded unchanged to `_build_nasnet_base`.
      config: Optional hparams object. When given, a deep copy is taken
        before `_update_hparams` is applied to it; when None,
        `large_imagenet_config()` supplies the hparams.
      current_step: Forwarded unchanged to `_build_nasnet_base`.

    Returns:
      Whatever `_build_nasnet_base` returns for the 'imagenet' stem.
    """
    if config is None:
        hparams = large_imagenet_config()
    else:
        hparams = copy.deepcopy(config)
    _update_hparams(hparams, is_training)

    data_format = hparams.data_format
    if tf.test.is_gpu_available() and data_format == 'NHWC':
        tf.logging.info('A GPU is available on the machine, consider using NCHW '
                        'data format for increased speed on GPU.')

    if data_format == 'NCHW':
        images = tf.transpose(images, [0, 3, 1, 2])

    # Cell count = configured cells + 2 reduction cells + 2 ImageNet stem
    # cells.
    total_num_cells = hparams.num_cells + 2 + 2

    # The normal and reduction cells are built from identical arguments.
    cell_args = (
        hparams.num_conv_filters,
        hparams.drop_path_keep_prob,
        total_num_cells,
        hparams.total_training_steps,
        hparams.use_bounded_activation,
    )
    normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
    reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

    # Ops that receive `is_training` through the arg scope.
    training_scoped_ops = [slim.dropout, nasnet_utils.drop_path,
                           slim.batch_norm]
    # Ops that receive `data_format` through the arg scope.
    layout_scoped_ops = [slim.avg_pool2d,
                         slim.max_pool2d,
                         slim.conv2d,
                         slim.batch_norm,
                         slim.separable_conv2d,
                         nasnet_utils.factorized_reduction,
                         nasnet_utils.global_avg_pool,
                         nasnet_utils.get_channel_index,
                         nasnet_utils.get_channel_dim]

    with arg_scope(training_scoped_ops, is_training=is_training), \
            arg_scope(layout_scoped_ops, data_format=hparams.data_format):
        return _build_nasnet_base(images,
                                  normal_cell=normal_cell,
                                  reduction_cell=reduction_cell,
                                  num_classes=num_classes,
                                  hparams=hparams,
                                  is_training=is_training,
                                  stem_type='imagenet',
                                  final_endpoint=final_endpoint,
                                  current_step=current_step)
def build_nasnet_cifar(
        images, num_classes, is_training=True):
    """Construct the NASNet network for the Cifar dataset.

    Args:
      images: Input image tensor. It is permuted with axes [0, 3, 1, 2] when
        the hparams request 'NCHW', so it is presumably supplied in NHWC
        layout (confirm against callers).
      num_classes: Forwarded unchanged to `_build_nasnet_base`.
      is_training: Passed to `_cifar_config`, to the dropout / drop-path /
        batch-norm arg scope, and to `_build_nasnet_base`.

    Returns:
      Whatever `_build_nasnet_base` returns for the 'cifar' stem.
    """
    hparams = _cifar_config(is_training=is_training)

    data_format = hparams.data_format
    if tf.test.is_gpu_available() and data_format == 'NHWC':
        tf.logging.info('A GPU is available on the machine, consider using NCHW '
                        'data format for increased speed on GPU.')

    if data_format == 'NCHW':
        images = tf.transpose(images, [0, 3, 1, 2])

    # Cell count = configured cells + 2 reduction cells.
    total_num_cells = hparams.num_cells + 2

    # The normal and reduction cells are built from identical arguments.
    cell_args = (
        hparams.num_conv_filters,
        hparams.drop_path_keep_prob,
        total_num_cells,
        hparams.total_training_steps,
    )
    normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
    reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

    # Ops that receive `is_training` through the arg scope.
    training_scoped_ops = [slim.dropout, nasnet_utils.drop_path,
                           slim.batch_norm]
    # Ops that receive `data_format` through the arg scope.
    layout_scoped_ops = [slim.avg_pool2d,
                         slim.max_pool2d,
                         slim.conv2d,
                         slim.batch_norm,
                         slim.separable_conv2d,
                         nasnet_utils.factorized_reduction,
                         nasnet_utils.global_avg_pool,
                         nasnet_utils.get_channel_index,
                         nasnet_utils.get_channel_dim]

    with arg_scope(training_scoped_ops, is_training=is_training), \
            arg_scope(layout_scoped_ops, data_format=hparams.data_format):
        return _build_nasnet_base(images,
                                  normal_cell=normal_cell,
                                  reduction_cell=reduction_cell,
                                  num_classes=num_classes,
                                  hparams=hparams,
                                  is_training=is_training,
                                  stem_type='cifar')