def get_pretrained_variables_to_restore(checkpoint_path,
                                        load_moving_average=False):
  """Gets variables_to_restore mapping from a pretrained checkpoint.

  Args:
    checkpoint_path: String. Path of checkpoint.
    load_moving_average: Boolean, whether to load moving average variables in
      place of the regular variables.

  Returns:
    Mapping of variables to restore.
  """
  checkpoint_reader = tf.train.load_checkpoint(checkpoint_path)
  variable_shape_map = checkpoint_reader.get_variable_to_shape_map()

  variables_to_restore = {}
  ema_vars = mnas_utils.get_ema_vars()
  for v in tf.global_variables():
    # Skip variables if they are in excluded scopes.
    is_excluded = False
    for scope in ['global_step', 'ExponentialMovingAverage']:
      if scope in v.op.name:
        is_excluded = True
        break
    if is_excluded:
      tf.logging.info('Exclude [%s] from loading from checkpoint.', v.op.name)
      continue

    variable_name_ckpt = v.op.name
    if load_moving_average and v in ema_vars:
      # To load moving average variables into the non-moving version for
      # fine-tuning, map the variables here manually.
      variable_name_ckpt = v.op.name + '/ExponentialMovingAverage'

    if variable_name_ckpt not in variable_shape_map:
      tf.logging.info(
          'Skip init [%s] from [%s] as it is not in the checkpoint',
          v.op.name, variable_name_ckpt)
      continue

    variables_to_restore[variable_name_ckpt] = v
    tf.logging.info('Init variable [%s] from [%s] in ckpt', v.op.name,
                    variable_name_ckpt)
  return variables_to_restore
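

# Illustrative sketch only (not part of the original module): one way to use
# the mapping above to restore exponential-moving-average weights into the
# live (non-EMA) variables, e.g. before fine-tuning. The helper name
# _example_restore_ema_weights and its arguments are hypothetical;
# build_model_fn below applies the same pattern through a Scaffold.
def _example_restore_ema_weights(sess, checkpoint_path):
  """Hypothetical helper: loads EMA values from `checkpoint_path`."""
  variables_to_restore = get_pretrained_variables_to_restore(
      checkpoint_path, load_moving_average=True)
  # Saver accepts a {checkpoint_name: variable} dict, so EMA values in the
  # checkpoint are written into the corresponding non-EMA variables.
  saver = tf.train.Saver(variables_to_restore)
  saver.restore(sess, checkpoint_path)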
def build_model_fn(features, labels, mode, params):
  """The model_fn for MnasNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples.
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model.
  """
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  # This is essential if using a keras-derived model.
  tf.keras.backend.set_learning_phase(is_training)

  if isinstance(features, dict):
    features = features['feature']

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Adds an identity node to help TFLite export.
    features = tf.identity(features, 'float_image_input')

  # In most cases, the default data format NCHW instead of NHWC should be
  # used for a significant performance boost on GPU. NHWC should be used
  # only if the network needs to be run on CPU, since the pooling operations
  # are only supported on NHWC. TPU uses the XLA compiler to figure out the
  # best layout.
  if params['data_format'] == 'channels_first':
    assert not params['transpose_input']  # channels_first only for GPU
    features = tf.transpose(features, [0, 3, 1, 2])
    stats_shape = [3, 1, 1]
  else:
    stats_shape = [1, 1, 3]

  if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance.
  features -= tf.constant(
      imagenet_input.MEAN_RGB, shape=stats_shape, dtype=features.dtype)
  features /= tf.constant(
      imagenet_input.STDDEV_RGB, shape=stats_shape, dtype=features.dtype)

  has_moving_average_decay = (params['moving_average_decay'] > 0)

  tf.logging.info('Using open-source implementation for MnasNet definition.')
  override_params = {}
  if params['batch_norm_momentum']:
    override_params['batch_norm_momentum'] = params['batch_norm_momentum']
  if params['batch_norm_epsilon']:
    override_params['batch_norm_epsilon'] = params['batch_norm_epsilon']
  if params['dropout_rate']:
    override_params['dropout_rate'] = params['dropout_rate']
  if params['data_format']:
    override_params['data_format'] = params['data_format']
  if params['num_label_classes']:
    override_params['num_classes'] = params['num_label_classes']
  if params['depth_multiplier']:
    override_params['depth_multiplier'] = params['depth_multiplier']
  if params['depth_divisor']:
    override_params['depth_divisor'] = params['depth_divisor']
  if params['min_depth']:
    override_params['min_depth'] = params['min_depth']
  override_params['use_keras'] = params['use_keras']

  def _build_model(model_name):
    """Builds the model for the given model name."""
    if model_name.startswith('mnasnet'):
      return mnasnet_models.build_mnasnet_model(
          features,
          model_name=model_name,
          training=is_training,
          override_params=override_params)
    elif model_name.startswith('mixnet'):
      return mixnet_builder.build_model(
          features,
          model_name=model_name,
          training=is_training,
          override_params=override_params)
    else:
      raise ValueError('Unknown model name {}'.format(model_name))

  if params['precision'] == 'bfloat16':
    with tf.tpu.bfloat16_scope():
      logits, _ = _build_model(params['model_name'])
    logits = tf.cast(logits, tf.float32)
  else:  # params['precision'] == 'float32'
    logits, _ = _build_model(params['model_name'])

  if params['quantized_training']:
    try:
      from tensorflow.contrib import quantize  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      logging.exception(
          'Quantized training is not supported in TensorFlow 2.x')
      raise e
    if is_training:
      tf.logging.info('Adding fake quantization ops for training.')
      quantize.create_training_graph(
          quant_delay=int(params['steps_per_epoch'] *
                          FLAGS.quantization_delay_epochs))
    else:
      tf.logging.info('Adding fake quantization ops for evaluation.')
      quantize.create_eval_graph()

  if mode == tf.estimator.ModeKeys.PREDICT:
    scaffold_fn = None
    if FLAGS.export_moving_average:
      # If the model is trained with moving average decay, to match evaluation
      # metrics, we need to export the model using moving average variables.
      restore_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
      variables_to_restore = get_pretrained_variables_to_restore(
          restore_checkpoint, load_moving_average=True)
      tf.logging.info('Restoring from the latest checkpoint: %s',
                      restore_checkpoint)
      tf.logging.info(str(variables_to_restore))

      def restore_scaffold():
        saver = tf.train.Saver(variables_to_restore)
        return tf.train.Scaffold(saver=saver)

      scaffold_fn = restore_scaffold

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        },
        scaffold_fn=scaffold_fn)

  # If necessary, in the model_fn, use params['batch_size'] instead of the
  # batch size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']  # pylint: disable=unused-variable

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, params['num_label_classes'])
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=params['label_smoothing'])

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = cross_entropy + params['weight_decay'] * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  global_step = tf.train.get_global_step()
  if has_moving_average_decay:
    ema = tf.train.ExponentialMovingAverage(
        decay=params['moving_average_decay'], num_updates=global_step)
    ema_vars = mnas_utils.get_ema_vars()

  host_call = None
  if is_training:
    # Compute the current epoch and associated learning rate from global_step.
    current_epoch = (
        tf.cast(global_step, tf.float32) / params['steps_per_epoch'])

    scaled_lr = params['base_learning_rate'] * (
        params['train_batch_size'] / 256.0)
    learning_rate = mnas_utils.build_learning_rate(scaled_lr, global_step,
                                                   params['steps_per_epoch'])
    optimizer = mnas_utils.build_optimizer(learning_rate)
    if params['use_tpu']:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    if params['add_summaries']:
      summary_writer = tf2.summary.create_file_writer(
          FLAGS.model_dir, max_queue=params['iterations_per_loop'])
      with summary_writer.as_default():
        should_record = tf.equal(
            global_step % params['iterations_per_loop'], 0)
        with tf2.summary.record_if(should_record):
          tf2.summary.scalar('loss', loss, step=global_step)
          tf2.summary.scalar('learning_rate', learning_rate, step=global_step)
          tf2.summary.scalar('current_epoch', current_epoch, step=global_step)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops + tf.summary.all_v2_summary_ops()):
      train_op = optimizer.minimize(loss, global_step)

    if has_moving_average_decay:
      with tf.control_dependencies([train_op]):
        train_op = ema.apply(ema_vars)

  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(labels, logits):
      """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the
      model to the `metric_fn`, provide them as part of the `eval_metrics`.
      See
      https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)

      return {
          'top_1_accuracy': top_1_accuracy,
          'top_5_accuracy': top_5_accuracy,
      }

    eval_metrics = (metric_fn, [labels, logits])

  num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info('Number of trainable parameters: {}'.format(num_params))

  # Prepares scaffold_fn if needed.
  scaffold_fn = None
  if is_training and FLAGS.init_checkpoint:
    variables_to_restore = get_pretrained_variables_to_restore(
        FLAGS.init_checkpoint, has_moving_average_decay)
    tf.logging.info('Initializing from pretrained checkpoint: %s',
                    FLAGS.init_checkpoint)
    if FLAGS.use_tpu:

      def init_scaffold():
        tf.train.init_from_checkpoint(FLAGS.init_checkpoint,
                                      variables_to_restore)
        return tf.train.Scaffold()

      scaffold_fn = init_scaffold
    else:
      tf.train.init_from_checkpoint(FLAGS.init_checkpoint,
                                    variables_to_restore)

  restore_vars_dict = None
  if not is_training and has_moving_average_decay:
    # Load moving average variables for eval.
    restore_vars_dict = ema.variables_to_restore(ema_vars)

    def eval_scaffold():
      saver = tf.train.Saver(restore_vars_dict)
      return tf.train.Scaffold(saver=saver)

    scaffold_fn = eval_scaffold

  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics,
      scaffold_fn=scaffold_fn)
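

# Illustrative sketch only (not in the original file): one way build_model_fn
# could be wired into a TPUEstimator. The helper name _example_build_estimator
# is hypothetical, and `params` is assumed to carry the keys read by
# build_model_fn above (data_format, precision, model_name, batch norm and
# learning-rate settings, etc.); the real driver builds it from command-line
# flags. A TPU run would additionally pass a cluster resolver to RunConfig.
def _example_build_estimator(model_dir, params, use_tpu=False):
  """Hypothetical helper: constructs a TPUEstimator around build_model_fn."""
  run_config = tf.estimator.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=params['iterations_per_loop']))
  return tf.estimator.tpu.TPUEstimator(
      use_tpu=use_tpu,
      model_fn=build_model_fn,
      config=run_config,
      train_batch_size=params['train_batch_size'],
      eval_batch_size=params['eval_batch_size'],
      params=params)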