def model_fn(features, mode, params): '''The model_fn to be used with TPUEstimator. Args: features: `Tensor` of batched images. labels: `Tensor` of labels for the data samples mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}` params: `dict` of parameters passed to the model from the TPUEstimator, `params['batch_size']` is always provided and should be used as the effective batch size. Returns: A `TPUEstimatorSpec` for the model ''' def preprocess_image(image): # In most cases, the default data format NCHW instead of NHWC should be # used for a significant performance boost on GPU. NHWC should be used # only if the network needs to be run on CPU since the pooling operations # are only supported on NHWC. TPU uses XLA compiler to figure out best layout. if FLAGS.data_format == 'channels_first': assert not FLAGS.transpose_input # channels_first only for GPU image = tf.transpose(image, [0, 3, 1, 2]) if FLAGS.transpose_input and mode == tf.estimator.ModeKeys.TRAIN: image = tf.transpose(image, [3, 0, 1, 2]) # HWCN to NHWC return image def normalize_image(image): # Normalize the image to zero mean and unit variance. if FLAGS.data_format == 'channels_first': stats_shape = [3, 1, 1] else: stats_shape = [1, 1, 3] mean, std = task_info.get_mean_std(FLAGS.task_name) image -= tf.constant(mean, shape=stats_shape, dtype=image.dtype) image /= tf.constant(std, shape=stats_shape, dtype=image.dtype) return image image = features['image'] image = preprocess_image(image) image_shape = image.get_shape().as_list() tf.logging.info('image shape: {}'.format(image_shape)) is_training = (mode == tf.estimator.ModeKeys.TRAIN) if mode != tf.estimator.ModeKeys.PREDICT: labels = features['label'] else: labels = None # If necessary, in the model_fn, use params['batch_size'] instead the batch # size flags (--train_batch_size or --eval_batch_size). batch_size = params['batch_size'] # pylint: disable=unused-variable if FLAGS.unlabel_ratio and is_training: unl_bsz = features['unl_probs'].shape[0] else: unl_bsz = 0 lab_bsz = image.shape[0] - unl_bsz assert lab_bsz == batch_size metric_dict = {} global_step = tf.train.get_global_step() has_moving_average_decay = (FLAGS.moving_average_decay > 0) # This is essential, if using a keras-derived model. 
tf.keras.backend.set_learning_phase(is_training) tf.logging.info('Using open-source implementation.') override_params = {} if FLAGS.dropout_rate is not None: override_params['dropout_rate'] = FLAGS.dropout_rate if FLAGS.stochastic_depth_rate is not None: override_params['stochastic_depth_rate'] = FLAGS.stochastic_depth_rate if FLAGS.data_format: override_params['data_format'] = FLAGS.data_format if FLAGS.num_label_classes: override_params['num_classes'] = FLAGS.num_label_classes if FLAGS.depth_coefficient: override_params['depth_coefficient'] = FLAGS.depth_coefficient if FLAGS.width_coefficient: override_params['width_coefficient'] = FLAGS.width_coefficient def build_model(scope=None, reuse=tf.AUTO_REUSE, model_name=None, model_is_training=None, input_image=None, use_adv_bn=False, is_teacher=False): model_name = model_name or FLAGS.model_name if model_is_training is None: model_is_training = is_training if input_image is None: input_image = image input_image = normalize_image(input_image) scope_model_name = model_name if scope: scope = scope + '/' else: scope = '' with tf.variable_scope(scope + scope_model_name, reuse=reuse): if model_name.startswith('efficientnet'): logits, _ = efficientnet_builder.build_model( input_image, model_name=model_name, training=model_is_training, override_params=override_params, model_dir=FLAGS.model_dir, use_adv_bn=use_adv_bn, is_teacher=is_teacher) else: assert False, 'model {} not implemented'.format(model_name) return logits if params['use_bfloat16']: with tf.tpu.bfloat16_scope(): logits = tf.cast(build_model(), tf.float32) else: logits = build_model() if FLAGS.teacher_model_name: teacher_image = preprocess_image(features['teacher_image']) if params['use_bfloat16']: with tf.tpu.bfloat16_scope(): teacher_logits = tf.cast( build_model(scope='teacher_model', model_name=FLAGS.teacher_model_name, model_is_training=False, input_image=teacher_image, is_teacher=True), tf.float32) else: teacher_logits = build_model(scope='teacher_model', model_name=FLAGS.teacher_model_name, model_is_training=False, input_image=teacher_image, is_teacher=True) teacher_logits = tf.stop_gradient(teacher_logits) if FLAGS.teacher_softmax_temp != -1: teacher_prob = tf.nn.softmax(teacher_logits / FLAGS.teacher_softmax_temp) else: teacher_prob = None teacher_one_hot_pred = tf.argmax(teacher_logits, axis=1, output_type=labels.dtype) if mode == tf.estimator.ModeKeys.PREDICT: if has_moving_average_decay: ema = tf.train.ExponentialMovingAverage( decay=FLAGS.moving_average_decay) ema_vars = utils.get_all_variable() restore_vars_dict = ema.variables_to_restore(ema_vars) tf.logging.info( 'restored variables:\n%s', json.dumps(sorted(restore_vars_dict.keys()), indent=4)) predictions = { 'classes': tf.argmax(logits, axis=1), 'probabilities': tf.nn.softmax(logits, name='softmax_tensor') } return tf.estimator.tpu.TPUEstimatorSpec( mode=mode, predictions=predictions, scaffold_fn=functools.partial(_scaffold_fn, restore_vars_dict=restore_vars_dict) if has_moving_average_decay else None) if has_moving_average_decay: ema_step = global_step ema = tf.train.ExponentialMovingAverage( decay=FLAGS.moving_average_decay, num_updates=ema_step) ema_vars = utils.get_all_variable() lab_labels = labels[:lab_bsz] lab_logits = logits[:lab_bsz] lab_pred = tf.argmax(lab_logits, axis=-1, output_type=labels.dtype) lab_prob = tf.nn.softmax(lab_logits) lab_acc = tf.to_float(tf.equal(lab_pred, lab_labels)) metric_dict['lab/acc'] = tf.reduce_mean(lab_acc) metric_dict['lab/pred_prob'] = tf.reduce_mean( tf.reduce_max(lab_prob, 
axis=-1)) one_hot_labels = tf.one_hot(lab_labels, FLAGS.num_label_classes) if FLAGS.unlabel_ratio: unl_labels = labels[lab_bsz:] unl_logits = logits[lab_bsz:] unl_pred = tf.argmax(unl_logits, axis=-1, output_type=labels.dtype) unl_prob = tf.nn.softmax(unl_logits) unl_acc = tf.to_float(tf.equal(unl_pred, unl_labels)) metric_dict['unl/acc_to_dump'] = tf.reduce_mean(unl_acc) metric_dict['unl/pred_prob'] = tf.reduce_mean( tf.reduce_max(unl_prob, axis=-1)) # compute lab_loss one_hot_labels = tf.one_hot(lab_labels, FLAGS.num_label_classes) lab_loss = tf.losses.softmax_cross_entropy( logits=lab_logits, onehot_labels=one_hot_labels, label_smoothing=FLAGS.label_smoothing, reduction=tf.losses.Reduction.NONE) if FLAGS.label_data_sample_prob != 1: # mask out part of the labeled data random_mask = tf.floor( FLAGS.label_data_sample_prob + tf.random_uniform(tf.shape(lab_loss), dtype=lab_loss.dtype)) lab_loss = tf.reduce_mean(lab_loss * random_mask) else: lab_loss = tf.reduce_mean(lab_loss) metric_dict['lab/loss'] = lab_loss if FLAGS.unlabel_ratio: if FLAGS.teacher_softmax_temp == -1: # Hard labels # Get one-hot labels if FLAGS.teacher_model_name: ext_teacher_pred = teacher_one_hot_pred[lab_bsz:] one_hot_labels = tf.one_hot(ext_teacher_pred, FLAGS.num_label_classes) else: one_hot_labels = tf.one_hot(unl_labels, FLAGS.num_label_classes) # Compute cross entropy unl_loss = tf.losses.softmax_cross_entropy( logits=unl_logits, onehot_labels=one_hot_labels, label_smoothing=FLAGS.label_smoothing) else: # Soft labels # Get teacher prob if FLAGS.teacher_model_name: unl_teacher_prob = teacher_prob[lab_bsz:] else: scaled_prob = tf.pow(features['unl_probs'], 1 / FLAGS.teacher_softmax_temp) unl_teacher_prob = scaled_prob / tf.reduce_sum( scaled_prob, axis=-1, keepdims=True) metric_dict['unl/target_prob'] = tf.reduce_mean( tf.reduce_max(unl_teacher_prob, axis=-1)) unl_loss = cross_entropy(unl_teacher_prob, unl_logits, return_mean=True) metric_dict['ext/loss'] = unl_loss else: unl_loss = 0 real_lab_bsz = tf.to_float(lab_bsz) * FLAGS.label_data_sample_prob real_unl_bsz = batch_size * FLAGS.label_data_sample_prob * FLAGS.unlabel_ratio data_loss = lab_loss * real_lab_bsz + unl_loss * real_unl_bsz data_loss = data_loss / real_lab_bsz # Add weight decay to the loss for non-batch-normalization variables. loss = data_loss + FLAGS.weight_decay * tf.add_n([ tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'batch_normalization' not in v.name ]) metric_dict['train/data_loss'] = data_loss metric_dict['train/loss'] = loss host_call = None restore_vars_dict = None if is_training: # Compute the current epoch and associated learning rate from global_step. current_epoch = (tf.cast(global_step, tf.float32) / params['steps_per_epoch']) real_train_batch_size = FLAGS.train_batch_size real_train_batch_size *= FLAGS.label_data_sample_prob scaled_lr = FLAGS.base_learning_rate * (real_train_batch_size / 256.0) if FLAGS.final_base_lr: # total number of training epochs total_epochs = FLAGS.train_steps * FLAGS.train_batch_size * 1. 
/ FLAGS.num_train_images - 5 decay_times = math.log(FLAGS.final_base_lr / FLAGS.base_learning_rate) / math.log(0.97) decay_epochs = total_epochs / decay_times tf.logging.info( 'setting decay_epochs to {:.2f}'.format(decay_epochs) + '\n' * 3) else: decay_epochs = 2.4 * FLAGS.train_ratio learning_rate = utils.build_learning_rate( scaled_lr, global_step, params['steps_per_epoch'], decay_epochs=decay_epochs, start_from_step=FLAGS.train_steps - FLAGS.train_last_step_num, warmup_epochs=5, ) metric_dict['train/lr'] = learning_rate metric_dict['train/epoch'] = current_epoch optimizer = utils.build_optimizer(learning_rate) if FLAGS.use_tpu: # When using TPU, wrap the optimizer with CrossShardOptimizer which # handles synchronization details between different TPU cores. To the # user, this should look like regular synchronous training. optimizer = tf.tpu.CrossShardOptimizer(optimizer) # Batch normalization requires UPDATE_OPS to be added as a dependency to # the train operation. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) tvars = tf.trainable_variables() g_vars = [] tvars = sorted(tvars, key=lambda var: var.name) for var in tvars: if 'teacher_model' not in var.name: g_vars += [var] with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss, global_step, var_list=g_vars) if has_moving_average_decay: with tf.control_dependencies([train_op]): train_op = ema.apply(ema_vars) if not FLAGS.skip_host_call: host_call = utils.construct_scalar_host_call(metric_dict) scaffold_fn = None if FLAGS.teacher_model_name or FLAGS.init_model: scaffold_fn = utils.init_from_ckpt(scaffold_fn) else: train_op = None if has_moving_average_decay: # Load moving average variables for eval. restore_vars_dict = ema.variables_to_restore(ema_vars) eval_metrics = None if mode == tf.estimator.ModeKeys.EVAL: scaffold_fn = functools.partial(_scaffold_fn, restore_vars_dict=restore_vars_dict ) if has_moving_average_decay else None def metric_fn(labels, logits): '''Evaluation metric function. Evaluates accuracy. This function is executed on the CPU and should not directly reference any Tensors in the rest of the `model_fn`. To pass Tensors from the model to the `metric_fn`, provide as part of the `eval_metrics`. See https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec for more information. Arguments should match the list of `Tensor` objects passed as the second element in the tuple passed to `eval_metrics`. Args: labels: `Tensor` with shape `[batch]`. logits: `Tensor` with shape `[batch, num_classes]`. Returns: A dict of the metrics to return from evaluation. ''' predictions = tf.argmax(logits, axis=1) top_1_accuracy = tf.metrics.accuracy(labels, predictions) in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32) top_5_accuracy = tf.metrics.mean(in_top_5) result_dict = { 'top_1_accuracy': top_1_accuracy, 'top_5_accuracy': top_5_accuracy, } return result_dict eval_metrics = (metric_fn, [labels, logits]) num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) tf.logging.info('number of trainable parameters: {}'.format(num_params)) return tf.estimator.tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op, host_call=host_call, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)
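# Illustrative usage (not from the original source): a minimal sketch of how
# a model_fn like the one above is typically handed to TPUEstimator. The flag
# names and the extra `params` entries ('use_bfloat16', 'steps_per_epoch') are
# assumptions inferred from how the model_fn reads them, not a verbatim copy
# of the training script.
import tensorflow.compat.v1 as tf


def make_estimator(model_fn, flags):
  run_config = tf.estimator.tpu.RunConfig(
      master=flags.master,
      model_dir=flags.model_dir,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=flags.iterations_per_loop))
  return tf.estimator.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=flags.use_tpu,
      config=run_config,
      train_batch_size=flags.train_batch_size,
      eval_batch_size=flags.eval_batch_size,
      # TPUEstimator injects params['batch_size'] itself; the remaining keys
      # are custom entries the model_fn above expects.
      params={
          'use_bfloat16': flags.use_bfloat16,
          'steps_per_epoch': flags.num_train_images / flags.train_batch_size,
      })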
def model_fn(features, labels, mode, params): """The model_fn to be used with TPUEstimator. Args: features: `Tensor` of batched images. labels: `Tensor` of one hot labels for the data samples mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}` params: `dict` of parameters passed to the model from the TPUEstimator, `params['batch_size']` is always provided and should be used as the effective batch size. Returns: A `TPUEstimatorSpec` for the model """ if isinstance(features, dict): features = features['feature'] # In most cases, the default data format NCHW instead of NHWC should be # used for a significant performance boost on GPU. NHWC should be used # only if the network needs to be run on CPU since the pooling operations # are only supported on NHWC. TPU uses XLA compiler to figure out best layout. if FLAGS.data_format == 'channels_first': assert not FLAGS.transpose_input # channels_first only for GPU features = tf.transpose(features, [0, 3, 1, 2]) stats_shape = [3, 1, 1] else: stats_shape = [1, 1, 3] if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT: features = tf.transpose(features, [3, 0, 1, 2]) # HWCN to NHWC is_training = (mode == tf.estimator.ModeKeys.TRAIN) has_moving_average_decay = (FLAGS.moving_average_decay > 0) # This is essential, if using a keras-derived model. tf.keras.backend.set_learning_phase(is_training) logging.info('Using open-source implementation.') override_params = {} if FLAGS.batch_norm_momentum is not None: override_params['batch_norm_momentum'] = FLAGS.batch_norm_momentum if FLAGS.batch_norm_epsilon is not None: override_params['batch_norm_epsilon'] = FLAGS.batch_norm_epsilon if FLAGS.dropout_rate is not None: override_params['dropout_rate'] = FLAGS.dropout_rate if FLAGS.survival_prob is not None: override_params['survival_prob'] = FLAGS.survival_prob if FLAGS.data_format: override_params['data_format'] = FLAGS.data_format if FLAGS.num_label_classes: override_params['num_classes'] = FLAGS.num_label_classes if FLAGS.depth_coefficient: override_params['depth_coefficient'] = FLAGS.depth_coefficient if FLAGS.width_coefficient: override_params['width_coefficient'] = FLAGS.width_coefficient def normalize_features(features, mean_rgb, stddev_rgb): """Normalize the image given the means and stddevs.""" features -= tf.constant(mean_rgb, shape=stats_shape, dtype=features.dtype) features /= tf.constant(stddev_rgb, shape=stats_shape, dtype=features.dtype) return features def build_model(): """Build model using the model_name given through the command line.""" model_builder = model_builder_factory.get_model_builder( FLAGS.model_name) normalized_features = normalize_features(features, model_builder.MEAN_RGB, model_builder.STDDEV_RGB) logits, _ = model_builder.build_model(normalized_features, model_name=FLAGS.model_name, training=is_training, override_params=override_params, model_dir=FLAGS.model_dir) return logits if params['use_bfloat16']: with tf.tpu.bfloat16_scope(): logits = tf.cast(build_model(), tf.float32) else: logits = build_model() if mode == tf.estimator.ModeKeys.PREDICT: predictions = { 'classes': tf.argmax(logits, axis=1), 'probabilities': tf.nn.softmax(logits, name='softmax_tensor') } return tf.estimator.EstimatorSpec( mode=mode, predictions=predictions, export_outputs={ 'classify': tf.estimator.export.PredictOutput(predictions) }) # If necessary, in the model_fn, use params['batch_size'] instead the batch # size flags (--train_batch_size or --eval_batch_size). 
batch_size = params['batch_size'] # pylint: disable=unused-variable # Calculate loss, which includes softmax cross entropy and L2 regularization. cross_entropy = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels, label_smoothing=FLAGS.label_smoothing) # Add weight decay to the loss for non-batch-normalization variables. loss = cross_entropy + FLAGS.weight_decay * tf.add_n([ tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'batch_normalization' not in v.name ]) global_step = tf.train.get_global_step() if has_moving_average_decay: ema = tf.train.ExponentialMovingAverage( decay=FLAGS.moving_average_decay, num_updates=global_step) ema_vars = utils.get_ema_vars() host_call = None restore_vars_dict = None if is_training: # Compute the current epoch and associated learning rate from global_step. current_epoch = (tf.cast(global_step, tf.float32) / params['steps_per_epoch']) scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0) logging.info('base_learning_rate = %f', FLAGS.base_learning_rate) learning_rate = utils.build_learning_rate( scaled_lr, global_step, params['steps_per_epoch'], decay_epochs=FLAGS.lr_decay_epoch) optimizer = utils.build_optimizer(learning_rate) if FLAGS.use_tpu: # When using TPU, wrap the optimizer with CrossShardOptimizer which # handles synchronization details between different TPU cores. To the # user, this should look like regular synchronous training. optimizer = tf.tpu.CrossShardOptimizer(optimizer) # Batch normalization requires UPDATE_OPS to be added as a dependency to # the train operation. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss, global_step) if has_moving_average_decay: with tf.control_dependencies([train_op]): train_op = ema.apply(ema_vars) if not FLAGS.skip_host_call: def host_call_fn(gs, lr, ce): """Training host call. Creates scalar summaries for training metrics. This function is executed on the CPU and should not directly reference any Tensors in the rest of the `model_fn`. To pass Tensors from the model to the `metric_fn`, provide as part of the `host_call`. See https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec for more information. Arguments should match the list of `Tensor` objects passed as the second element in the tuple passed to `host_call`. Args: gs: `Tensor with shape `[batch]` for the global_step lr: `Tensor` with shape `[batch]` for the learning_rate. ce: `Tensor` with shape `[batch]` for the current_epoch. Returns: List of summary ops to run on the CPU host. """ gs = gs[0] # Host call fns are executed FLAGS.iterations_per_loop times after one # TPU loop is finished, setting max_queue value to the same as number of # iterations will make the summary writer only flush the data to storage # once per loop. with tf2.summary.create_file_writer( FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default(): with tf2.summary.record_if(True): tf2.summary.scalar('learning_rate', lr[0], step=gs) tf2.summary.scalar('current_epoch', ce[0], step=gs) return tf.summary.all_v2_summary_ops() # To log the loss, current learning rate, and epoch for Tensorboard, the # summary op needs to be run on the host CPU via host_call. host_call # expects [batch_size, ...] Tensors, thus reshape to introduce a batch # dimension. These Tensors are implicitly concatenated to # [params['batch_size']]. 
gs_t = tf.reshape(global_step, [1]) lr_t = tf.reshape(learning_rate, [1]) ce_t = tf.reshape(current_epoch, [1]) host_call = (host_call_fn, [gs_t, lr_t, ce_t]) else: train_op = None if has_moving_average_decay: # Load moving average variables for eval. restore_vars_dict = ema.variables_to_restore(ema_vars) eval_metrics = None if mode == tf.estimator.ModeKeys.EVAL: def metric_fn(labels, logits): """Evaluation metric function. Evaluates accuracy. This function is executed on the CPU and should not directly reference any Tensors in the rest of the `model_fn`. To pass Tensors from the model to the `metric_fn`, provide as part of the `eval_metrics`. See https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec for more information. Arguments should match the list of `Tensor` objects passed as the second element in the tuple passed to `eval_metrics`. Args: labels: `Tensor` with shape `[batch, num_classes]`. logits: `Tensor` with shape `[batch, num_classes]`. Returns: A dict of the metrics to return from evaluation. """ labels = tf.argmax(labels, axis=1) predictions = tf.argmax(logits, axis=1) top_1_accuracy = tf.metrics.accuracy(labels, predictions) in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32) top_5_accuracy = tf.metrics.mean(in_top_5) return { 'top_1_accuracy': top_1_accuracy, 'top_5_accuracy': top_5_accuracy, } eval_metrics = (metric_fn, [labels, logits]) num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) logging.info('number of trainable parameters: %d', num_params) def _scaffold_fn(): saver = tf.train.Saver(restore_vars_dict) return tf.train.Scaffold(saver=saver) if has_moving_average_decay and not is_training: # Only apply scaffold for eval jobs. scaffold_fn = _scaffold_fn else: scaffold_fn = None return tf.estimator.tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op, host_call=host_call, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)
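# Illustrative note (not from the original source): what label_smoothing does
# in the tf.losses.softmax_cross_entropy call above. The one-hot target is
# replaced by onehot * (1 - eps) + eps / num_classes before the cross entropy
# is taken. A tiny NumPy re-derivation:
import numpy as np


def smoothed_targets(onehot_labels, eps):
  num_classes = onehot_labels.shape[-1]
  return onehot_labels * (1.0 - eps) + eps / num_classes


print(smoothed_targets(np.array([[0., 0., 1., 0.]]), eps=0.1))
# -> [[0.025 0.025 0.925 0.025]]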
def mnasnet_model_fn(features, labels, mode, params): """The model_fn for MnasNet to be used with TPUEstimator. Args: features: `Tensor` of batched images. labels: `Tensor` of labels for the data samples mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}` params: `dict` of parameters passed to the model from the TPUEstimator, `params['batch_size']` is always provided and should be used as the effective batch size. Returns: A `TPUEstimatorSpec` for the model """ is_training = (mode == tf.estimator.ModeKeys.TRAIN) # This is essential, if using a keras-derived model. K.set_learning_phase(is_training) if isinstance(features, dict): features = features['feature'] if mode == tf.estimator.ModeKeys.PREDICT: # Adds an identify node to help TFLite export. features = tf.identity(features, 'float_image_input') # In most cases, the default data format NCHW instead of NHWC should be # used for a significant performance boost on GPU. NHWC should be used # only if the network needs to be run on CPU since the pooling operations # are only supported on NHWC. TPU uses XLA compiler to figure out best layout. if params['data_format'] == 'channels_first': assert not params['transpose_input'] # channels_first only for GPU features = tf.transpose(features, [0, 3, 1, 2]) stats_shape = [3, 1, 1] else: stats_shape = [1, 1, 3] if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT: features = tf.transpose(features, [3, 0, 1, 2]) # HWCN to NHWC # Normalize the image to zero mean and unit variance. features -= tf.constant(imagenet_input.MEAN_RGB, shape=stats_shape, dtype=features.dtype) features /= tf.constant(imagenet_input.STDDEV_RGB, shape=stats_shape, dtype=features.dtype) has_moving_average_decay = (params['moving_average_decay'] > 0) tf.logging.info('Using open-source implementation for MnasNet definition.') override_params = {} if params['batch_norm_momentum']: override_params['batch_norm_momentum'] = params['batch_norm_momentum'] if params['batch_norm_epsilon']: override_params['batch_norm_epsilon'] = params['batch_norm_epsilon'] if params['dropout_rate']: override_params['dropout_rate'] = params['dropout_rate'] if params['data_format']: override_params['data_format'] = params['data_format'] if params['num_label_classes']: override_params['num_classes'] = params['num_label_classes'] if params['depth_multiplier']: override_params['depth_multiplier'] = params['depth_multiplier'] if params['depth_divisor']: override_params['depth_divisor'] = params['depth_divisor'] if params['min_depth']: override_params['min_depth'] = params['min_depth'] override_params['use_keras'] = params['use_keras'] if params['precision'] == 'bfloat16': with tf.contrib.tpu.bfloat16_scope(): logits, _ = mnasnet_models.build_mnasnet_model( features, model_name=params['model_name'], training=is_training, override_params=override_params) logits = tf.cast(logits, tf.float32) else: # params['precision'] == 'float32' logits, _ = mnasnet_models.build_mnasnet_model( features, model_name=params['model_name'], training=is_training, override_params=override_params) if params['quantized_training']: if is_training: tf.logging.info('Adding fake quantization ops for training.') tf.contrib.quantize.create_training_graph( quant_delay=int(params['steps_per_epoch'] * FLAGS.quantization_delay_epochs)) else: tf.logging.info('Adding fake quantization ops for evaluation.') tf.contrib.quantize.create_eval_graph() if mode == tf.estimator.ModeKeys.PREDICT: scaffold_fn = None if FLAGS.export_moving_average: # If the model is trained with moving 
average decay, to match evaluation # metrics, we need to export the model using moving average variables. restore_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir) variables_to_restore = get_pretrained_variables_to_restore( restore_checkpoint, load_moving_average=True) tf.logging.info('Restoring from the latest checkpoint: %s', restore_checkpoint) tf.logging.info(str(variables_to_restore)) def restore_scaffold(): saver = tf.train.Saver(variables_to_restore) return tf.train.Scaffold(saver=saver) scaffold_fn = restore_scaffold predictions = { 'classes': tf.argmax(logits, axis=1), 'probabilities': tf.nn.softmax(logits, name='softmax_tensor') } return tf.contrib.tpu.TPUEstimatorSpec( mode=mode, predictions=predictions, export_outputs={ 'classify': tf.estimator.export.PredictOutput(predictions) }, scaffold_fn=scaffold_fn) # If necessary, in the model_fn, use params['batch_size'] instead the batch # size flags (--train_batch_size or --eval_batch_size). batch_size = params['batch_size'] # pylint: disable=unused-variable # Calculate loss, which includes softmax cross entropy and L2 regularization. one_hot_labels = tf.one_hot(labels, params['num_label_classes']) cross_entropy = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=one_hot_labels, label_smoothing=params['label_smoothing']) # Add weight decay to the loss for non-batch-normalization variables. loss = cross_entropy + params['weight_decay'] * tf.add_n([ tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'batch_normalization' not in v.name ]) global_step = tf.train.get_global_step() if has_moving_average_decay: ema = tf.train.ExponentialMovingAverage( decay=params['moving_average_decay'], num_updates=global_step) ema_vars = utils.get_ema_vars() host_call = None if is_training: # Compute the current epoch and associated learning rate from global_step. current_epoch = (tf.cast(global_step, tf.float32) / params['steps_per_epoch']) scaled_lr = params['base_learning_rate'] * (params['train_batch_size'] / 256.0) # pylint: disable=line-too-long learning_rate = utils.build_learning_rate(scaled_lr, global_step, params['steps_per_epoch']) optimizer = utils.build_optimizer(learning_rate) if params['use_tpu']: # When using TPU, wrap the optimizer with CrossShardOptimizer which # handles synchronization details between different TPU cores. To the # user, this should look like regular synchronous training. optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) # Batch normalization requires UPDATE_OPS to be added as a dependency to # the train operation. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss, global_step) if has_moving_average_decay: with tf.control_dependencies([train_op]): train_op = ema.apply(ema_vars) if not params['skip_host_call']: def host_call_fn(gs, loss, lr, ce): """Training host call. Creates scalar summaries for training metrics. This function is executed on the CPU and should not directly reference any Tensors in the rest of the `model_fn`. To pass Tensors from the model to the `metric_fn`, provide as part of the `host_call`. See https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec for more information. Arguments should match the list of `Tensor` objects passed as the second element in the tuple passed to `host_call`. Args: gs: `Tensor with shape `[batch]` for the global_step loss: `Tensor` with shape `[batch]` for the training loss. lr: `Tensor` with shape `[batch]` for the learning_rate. 
ce: `Tensor` with shape `[batch]` for the current_epoch. Returns: List of summary ops to run on the CPU host. """ gs = gs[0] # Host call fns are executed params['iterations_per_loop'] times after # one TPU loop is finished, setting max_queue value to the same as # number of iterations will make the summary writer only flush the # data to storage once per loop. with tf.contrib.summary.create_file_writer( FLAGS.model_dir, max_queue=params['iterations_per_loop']).as_default(): with tf.contrib.summary.always_record_summaries(): tf.contrib.summary.scalar('loss', loss[0], step=gs) tf.contrib.summary.scalar('learning_rate', lr[0], step=gs) tf.contrib.summary.scalar('current_epoch', ce[0], step=gs) return tf.contrib.summary.all_summary_ops() # To log the loss, current learning rate, and epoch for Tensorboard, the # summary op needs to be run on the host CPU via host_call. host_call # expects [batch_size, ...] Tensors, thus reshape to introduce a batch # dimension. These Tensors are implicitly concatenated to # [params['batch_size']]. gs_t = tf.reshape(global_step, [1]) loss_t = tf.reshape(loss, [1]) lr_t = tf.reshape(learning_rate, [1]) ce_t = tf.reshape(current_epoch, [1]) host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t]) else: train_op = None eval_metrics = None if mode == tf.estimator.ModeKeys.EVAL: def metric_fn(labels, logits): """Evaluation metric function. Evaluates accuracy. This function is executed on the CPU and should not directly reference any Tensors in the rest of the `model_fn`. To pass Tensors from the model to the `metric_fn`, provide as part of the `eval_metrics`. See https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec for more information. Arguments should match the list of `Tensor` objects passed as the second element in the tuple passed to `eval_metrics`. Args: labels: `Tensor` with shape `[batch]`. logits: `Tensor` with shape `[batch, num_classes]`. Returns: A dict of the metrics to return from evaluation. """ predictions = tf.argmax(logits, axis=1) top_1_accuracy = tf.metrics.accuracy(labels, predictions) in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32) top_5_accuracy = tf.metrics.mean(in_top_5) return { 'top_1_accuracy': top_1_accuracy, 'top_5_accuracy': top_5_accuracy, } eval_metrics = (metric_fn, [labels, logits]) num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) tf.logging.info('number of trainable parameters: {}'.format(num_params)) # Prepares scaffold_fn if needed. scaffold_fn = None if is_training and FLAGS.init_checkpoint: variables_to_restore = get_pretrained_variables_to_restore( FLAGS.init_checkpoint, has_moving_average_decay) tf.logging.info('Initializing from pretrained checkpoint: %s', FLAGS.init_checkpoint) if FLAGS.use_tpu: def init_scaffold(): tf.train.init_from_checkpoint(FLAGS.init_checkpoint, variables_to_restore) return tf.train.Scaffold() scaffold_fn = init_scaffold else: tf.train.init_from_checkpoint(FLAGS.init_checkpoint, variables_to_restore) restore_vars_dict = None if not is_training and has_moving_average_decay: # Load moving average variables for eval. restore_vars_dict = ema.variables_to_restore(ema_vars) def eval_scaffold(): saver = tf.train.Saver(restore_vars_dict) return tf.train.Scaffold(saver=saver) scaffold_fn = eval_scaffold return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op, host_call=host_call, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)
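# Illustrative note (not from the original source): the eval-time EMA restore
# used by eval_scaffold above, shown in isolation. variables_to_restore() maps
# each variable to its '<name>/ExponentialMovingAverage' shadow slot, so a
# Saver built from that dict loads the averaged weights into the live
# variables. Assumes TF1 graph mode, like the surrounding code; the variable
# name 'w' is illustrative only.
import tensorflow.compat.v1 as tf


def ema_eval_saver(decay=0.9999):
  w = tf.get_variable('w', shape=[3], initializer=tf.zeros_initializer())
  ema = tf.train.ExponentialMovingAverage(decay=decay)
  maintain_op = ema.apply([w])                  # run after each train step
  restore_dict = ema.variables_to_restore([w])  # {'w/ExponentialMovingAverage': w}
  return maintain_op, tf.train.Saver(restore_dict)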
def final_model_fn(features, labels, mode, params):
  """The model_fn for ConvNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
  if isinstance(features, dict):
    features = features['feature']

  # In most cases, the default data format NCHW instead of NHWC should be
  # used for a significant performance boost on GPU/TPU. NHWC should be used
  # only if the network needs to be run on CPU since the pooling operations
  # are only supported on NHWC.
  if FLAGS.data_format == 'channels_first':
    if not FLAGS.transpose_input:  # channels_first only for GPU
      raise ValueError('The option transpose_input is set to False')
    features = tf.transpose(features, [0, 3, 1, 2])

  if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance.
  features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
  features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  has_moving_average_decay = (FLAGS.moving_average_decay > 0)

  # This is essential, if using a keras-derived model.
  K.set_learning_phase(is_training)
  tf.logging.info('Using open-source implementation for MnasNet definition.')

  # Override params when necessary
  override_params = utils.get_override_params_dict(FLAGS)

  logits, _ = models.build_model(features,
                                 model_name=FLAGS.model_name,
                                 training=is_training,
                                 override_params=override_params)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  # If necessary, in the model_fn, use params['batch_size'] instead the batch
  # size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']  # pylint: disable=unused-variable

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=FLAGS.label_smoothing)

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  global_step = tf.train.get_global_step()
  if has_moving_average_decay:
    ema = tf.train.ExponentialMovingAverage(
        decay=FLAGS.moving_average_decay, num_updates=global_step)
    ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars')
    for v in tf.global_variables():
      # We maintain mva for batch norm moving mean and variance as well.
      if 'moving_mean' in v.name or 'moving_variance' in v.name:
        ema_vars.append(v)
    ema_vars = list(set(ema_vars))

  host_call = None
  restore_vars_dict = None
  if is_training:
    # Compute the current epoch and associated learning rate from global_step.
    current_epoch = (
        tf.cast(global_step, tf.float32) / params['steps_per_epoch'])
    scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
    learning_rate = utils.build_learning_rate(scaled_lr, global_step,
                                              params['steps_per_epoch'])
    optimizer = utils.build_optimizer(learning_rate)
    if FLAGS.use_tpu:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    if has_moving_average_decay:
      with tf.control_dependencies([train_op]):
        train_op = ema.apply(ema_vars)

    if not FLAGS.skip_host_call:
      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      gs_t = tf.reshape(global_step, [1])
      loss_t = tf.reshape(loss, [1])
      lr_t = tf.reshape(learning_rate, [1])
      ce_t = tf.reshape(current_epoch, [1])
      host_call = (train_host_call_fn, [gs_t, loss_t, lr_t, ce_t])
  else:
    train_op = None
    if has_moving_average_decay:
      # Load moving average variables for eval.
      restore_vars_dict = ema.variables_to_restore(ema_vars)

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    eval_metrics = (eval_metric_fn, [labels, logits])

  num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info('number of trainable parameters: {}'.format(num_params))

  def _scaffold_fn():
    saver = tf.train.Saver(restore_vars_dict)
    return tf.train.Scaffold(saver=saver)

  return tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics,
      scaffold_fn=_scaffold_fn if has_moving_average_decay else None)
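# Illustrative sketch (not from the original source): utils.build_learning_rate
# is not shown in this file, so this only sketches the schedule its call sites
# imply: the base LR is linearly scaled by batch_size/256, warmed up for a few
# epochs, then decayed exponentially. The 0.97-every-2.4-epochs and 5-warmup-
# epoch constants echo values that appear elsewhere in this file, but treat
# the whole function as an assumption, not the repository's implementation.
import tensorflow.compat.v1 as tf


def build_learning_rate_sketch(initial_lr, global_step, steps_per_epoch,
                               decay_epochs=2.4, decay_factor=0.97,
                               warmup_epochs=5):
  decayed_lr = tf.train.exponential_decay(
      initial_lr, global_step, int(decay_epochs * steps_per_epoch),
      decay_factor, staircase=True)
  warmup_steps = int(warmup_epochs * steps_per_epoch)
  warmup_lr = (initial_lr * tf.cast(global_step, tf.float32) /
               tf.cast(warmup_steps, tf.float32))
  return tf.where(global_step < warmup_steps, warmup_lr, decayed_lr)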
def model_fn(features, labels, mode, params):
  """The model_fn to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
  if isinstance(features, dict):
    features = features['feature']

  stats_shape = [1, 1, 3]

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  has_moving_average_decay = (FLAGS.moving_average_decay > 0)
  # This is essential, if using a keras-derived model.
  tf.keras.backend.set_learning_phase(is_training)
  tf.logging.info('Using open-source implementation.')
  override_params = {}
  if FLAGS.batch_norm_momentum is not None:
    override_params['batch_norm_momentum'] = FLAGS.batch_norm_momentum
  if FLAGS.batch_norm_epsilon is not None:
    override_params['batch_norm_epsilon'] = FLAGS.batch_norm_epsilon
  if FLAGS.dropout_rate is not None:
    override_params['dropout_rate'] = FLAGS.dropout_rate
  if FLAGS.drop_connect_rate is not None:
    override_params['drop_connect_rate'] = FLAGS.drop_connect_rate
  if FLAGS.num_label_classes:
    override_params['num_classes'] = FLAGS.num_label_classes
  if FLAGS.depth_coefficient:
    override_params['depth_coefficient'] = FLAGS.depth_coefficient
  if FLAGS.width_coefficient:
    override_params['width_coefficient'] = FLAGS.width_coefficient

  def normalize_features(features, mean_rgb, stddev_rgb):
    """Normalize the image given the means and stddevs."""
    features -= tf.constant(mean_rgb, shape=stats_shape, dtype=features.dtype)
    features /= tf.constant(stddev_rgb, shape=stats_shape, dtype=features.dtype)
    return features

  def build_model():
    """Build model using the model_name given through the command line."""
    if FLAGS.model_name.startswith('efficientnet'):
      model_builder = efficientnet_builder
    else:
      raise ValueError('Model must be an efficientnet-b* model.')
    normalized_features = normalize_features(features, model_builder.MEAN_RGB,
                                             model_builder.STDDEV_RGB)
    logits, _ = model_builder.build_model(normalized_features,
                                          model_name=FLAGS.model_name,
                                          training=is_training,
                                          override_params=override_params,
                                          model_dir=FLAGS.model_dir)
    return logits

  logits = build_model()

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=FLAGS.label_smoothing)

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  global_step = tf.train.get_global_step()
  if has_moving_average_decay:
    ema = tf.train.ExponentialMovingAverage(
        decay=FLAGS.moving_average_decay, num_updates=global_step)
    ema_vars = utils.get_ema_vars()

  train_op = None
  restore_vars_dict = None
  training_hooks = []
  if is_training:
    # Compute the current epoch and associated learning rate from global_step.
    current_epoch = (
        tf.cast(global_step, tf.float32) / params['steps_per_epoch'])

    scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
    learning_rate = utils.build_learning_rate(scaled_lr, global_step,
                                              params['steps_per_epoch'])
    optimizer = utils.build_optimizer(learning_rate, optimizer_name='adam')

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    if has_moving_average_decay:
      with tf.control_dependencies([train_op]):
        train_op = ema.apply(ema_vars)

    predictions = tf.argmax(logits, axis=1)
    top1_accuracy = tf.metrics.accuracy(labels, predictions)
    logging_hook = tf.train.LoggingTensorHook(
        {
            'loss': loss,
            'accuracy': top1_accuracy[1],
            'step': global_step
        },
        every_n_iter=1)
    training_hooks.append(logging_hook)

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    predictions = tf.argmax(logits, axis=1)
    top1_accuracy = tf.metrics.accuracy(labels, predictions)
    eval_metrics = {'val_accuracy': top1_accuracy}

  num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info('number of trainable parameters: {}'.format(num_params))

  scaffold = None
  if has_moving_average_decay and not is_training:
    # Only apply scaffold for eval jobs. Build the Saver from the EMA shadow
    # variables so evaluation actually uses the moving-average weights.
    restore_vars_dict = ema.variables_to_restore(ema_vars)
    saver = tf.train.Saver(restore_vars_dict)
    scaffold = tf.train.Scaffold(saver=saver)

  return tf.estimator.EstimatorSpec(mode=mode,
                                    loss=loss,
                                    train_op=train_op,
                                    training_hooks=training_hooks,
                                    eval_metric_ops=eval_metrics,
                                    scaffold=scaffold)
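# Illustrative usage (not from the original source): unlike the TPU variants,
# the function above returns a plain EstimatorSpec, so it can be driven with
# tf.estimator.Estimator directly. The input_fn names and flag names are
# assumptions; only params['steps_per_epoch'] is something the model_fn above
# actually reads.
import tensorflow.compat.v1 as tf


def train_and_eval(model_fn, train_input_fn, eval_input_fn, flags):
  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=flags.model_dir,
      params={
          'steps_per_epoch': flags.num_train_images / flags.train_batch_size,
      })
  train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                      max_steps=flags.train_steps)
  eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)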
def model_fn(features, labels, mode, params=None): """The model_fn to be used with TPUEstimator. Args: features: `Tensor` of batched images. labels: `Tensor` of one hot labels for the data samples mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}` Returns: A `TPUEstimatorSpec` for the model """ if isinstance(features, dict): features = features["feature"] # In most cases, the default data format NCHW instead of NHWC should be # used for a significant performance boost on GPU. NHWC should be used # only if the network needs to be run on CPU since the pooling operations # are only supported on NHWC. TPU uses XLA compiler to figure out best layout. if context.get_hparam("data_format") == "channels_first": assert not context.get_hparam("transpose_input") # channels_first only for GPU features = tf.transpose(features, [0, 3, 1, 2]) stats_shape = [3, 1, 1] else: stats_shape = [1, 1, 3] #if context.get_hparam("transpose_input") and mode != tf.estimator.ModeKeys.PREDICT: # features = tf.transpose(features, [3, 0, 1, 2]) # HWCN to NHWC is_training = mode == tf.estimator.ModeKeys.TRAIN has_moving_average_decay = context.get_hparam("moving_average_decay") > 0 # This is essential, if using a keras-derived model. tf.keras.backend.set_learning_phase(is_training) logging.info("Using open-source implementation.") override_params = {} #if context.get_hparam("batch_norm_momentum") is not None: # override_params["batch_norm_momentum"] = context.get_hparam("batch_norm_momentum") #if context.get_hparam("batch_norm_epsilon") is not None: # override_params["batch_norm_epsilon"] = context.get_hparam("batch_norm_epsilon") # if context.get_hparam("dropout_rate") is not None: # override_params["dropout_rate"] = context.get_hparam("dropout_rate") # if context.get_hparam("survival_prob") is not None: # override_params["survival_prob"] = context.get_hparam("survival_prob") # if context.get_hparam("data_format"): # override_params["data_format"] = context.get_hparam("data_format") # if context.get_hparam("num_label_classes"): # override_params["num_classes"] = context.get_hparam("num_label_classes") # if context.get_hparam("depth_coefficient"): # override_params["depth_coefficient"] = context.get_hparam("depth_coefficient") # if context.get_hparam("width_coefficient"): # override_params["width_coefficient"] = context.get_hparam("width_coefficient") def normalize_features(features, mean_rgb, stddev_rgb): """Normalize the image given the means and stddevs.""" features -= tf.constant(mean_rgb, shape=stats_shape, dtype=features.dtype) features /= tf.constant(stddev_rgb, shape=stats_shape, dtype=features.dtype) return features def build_model(): """Build model using the model_name given through the command line.""" model_builder = model_builder_factory.get_model_builder( context.get_hparam("model_name"), ) normalized_features = normalize_features( features, model_builder.MEAN_RGB, model_builder.STDDEV_RGB ) logits, _ = model_builder.build_model( normalized_features, model_name=context.get_hparam("model_name"), training=is_training, override_params=override_params, #model_dir=context.get_hparam("model_dir"), ) return logits logits = build_model() # Calculate loss, which includes softmax cross entropy and L2 regularization. cross_entropy = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels, label_smoothing=context.get_hparam("label_smoothing") ) # Add weight decay to the loss for non-batch-normalization variables. 
loss = cross_entropy + context.get_hparam("weight_decay") * tf.add_n( [ tf.nn.l2_loss(v) for v in tf.trainable_variables() if "batch_normalization" not in v.name ] ) global_step = tf.train.get_global_step() if has_moving_average_decay: ema = tf.train.ExponentialMovingAverage( decay=context.get_hparam("moving_average_decay"), num_updates=global_step ) ema_vars = utils.get_ema_vars() restore_vars_dict = None train_op = None if is_training: # Compute the current epoch and associated learning rate from global_step. current_epoch = tf.cast(global_step, tf.float32) / context.get_hparam("steps_per_epoch") scaled_lr = context.get_hparam("base_learning_rate") * (context.get_hparam("train_batch_size") / 256.0) logging.info("base_learning_rate = %f", context.get_hparam("base_learning_rate")) learning_rate = utils.build_learning_rate( scaled_lr, global_step, context.get_hparam("steps_per_epoch"), ) optimizer = utils.build_optimizer(context, learning_rate) # Batch normalization requires UPDATE_OPS to be added as a dependency to # the train operation. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss, global_step) if has_moving_average_decay: with tf.control_dependencies([train_op]): train_op = ema.apply(ema_vars) if has_moving_average_decay: # Load moving average variables for eval. restore_vars_dict = ema.variables_to_restore(ema_vars) eval_metrics = None if mode == tf.estimator.ModeKeys.EVAL: def metric_fn(labels, logits): """Evaluation metric function. Evaluates accuracy. This function is executed on the CPU and should not directly reference any Tensors in the rest of the `model_fn`. To pass Tensors from the model to the `metric_fn`, provide as part of the `eval_metrics`. See https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec for more information. Arguments should match the list of `Tensor` objects passed as the second element in the tuple passed to `eval_metrics`. Args: labels: `Tensor` with shape `[batch, num_classes]`. logits: `Tensor` with shape `[batch, num_classes]`. Returns: A dict of the metrics to return from evaluation. """ labels = tf.argmax(labels, axis=1) predictions = tf.argmax(logits, axis=1) top_1_accuracy = tf.metrics.accuracy(labels, predictions) in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32) top_5_accuracy = tf.metrics.mean(in_top_5) return { "top_1_accuracy": top_1_accuracy, "top_5_accuracy": top_5_accuracy, } eval_metrics = metric_fn(labels, logits) num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) logging.info("number of trainable parameters: %d", num_params) return tf.estimator.EstimatorSpec( mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metrics, )
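# Illustrative note (not from the original source): the top-5 metric in
# metric_fn relies on tf.nn.in_top_k, which reports per example whether the
# true label is among the k largest logits. A tiny, purely illustrative check
# (assumes TF1 graph mode, like the rest of the file):
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
logits = tf.constant([[0.1, 0.9, 0.3, 0.2, 0.6, 0.5, 0.4, 0.0],
                      [0.1, 0.9, 0.3, 0.2, 0.6, 0.5, 0.4, 0.0]])
labels = tf.constant([2, 3])  # 0.3 is the 5th-largest logit, 0.2 only the 6th
in_top_5 = tf.nn.in_top_k(predictions=logits, targets=labels, k=5)
with tf.Session() as sess:
  print(sess.run(in_top_5))  # [ True False]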
def build_model(): logits, _ = efficientnet_builder.build_model( features, model_name=FLAGS.model_name, training=is_training, override_params=override_params, model_dir=FLAGS.model_dir) return logits if params['use_bfloat16']: with tf.contrib.tpu.bfloat16_scope(): logits = tf.cast(build_model(), tf.float32) else: logits = build_model() if mode == tf.estimator.ModeKeys.PREDICT: predictions = { 'classes': tf.argmax(logits, axis=1), 'probabilities': tf.nn.softmax(logits, name='softmax_tensor') } return tf.estimator.EstimatorSpec( mode=mode, predictions=predictions, export_outputs={ 'classify': tf.estimator.export.PredictOutput(predictions) }) # If necessary, in the model_fn, use params['batch_size'] instead the batch # size flags (--train_batch_size or --eval_batch_size). batch_size = params['batch_size'] # pylint: disable=unused-variable # Calculate loss, which includes softmax cross entropy and L2 regularization. one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes) cross_entropy = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=one_hot_labels, label_smoothing=FLAGS.label_smoothing) # Add weight decay to the loss for non-batch-normalization variables. loss = cross_entropy + FLAGS.weight_decay * tf.add_n([ tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'batch_normalization' not in v.name ]) global_step = tf.train.get_global_step() if has_moving_average_decay: ema = tf.train.ExponentialMovingAverage( decay=FLAGS.moving_average_decay, num_updates=global_step) ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars') for v in tf.global_variables(): # We maintain mva for batch norm moving mean and variance as well. if 'moving_mean' in v.name or 'moving_variance' in v.name: ema_vars.append(v) ema_vars = list(set(ema_vars)) host_call = None restore_vars_dict = None if is_training: # Compute the current epoch and associated learning rate from global_step. current_epoch = (tf.cast(global_step, tf.float32) / params['steps_per_epoch']) scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0) learning_rate = utils.build_learning_rate(scaled_lr, global_step, params['steps_per_epoch']) optimizer = utils.build_optimizer(learning_rate) if FLAGS.use_tpu: # When using TPU, wrap the optimizer with CrossShardOptimizer which # handles synchronization details between different TPU cores. To the # user, this should look like regular synchronous training. optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) # Batch normalization requires UPDATE_OPS to be added as a dependency to # the train operation. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss, global_step) if has_moving_average_decay: with tf.control_dependencies([train_op]): train_op = ema.apply(ema_vars) if not FLAGS.skip_host_call: def host_call_fn(gs, lr, ce): """Training host call. Creates scalar summaries for training metrics. This function is executed on the CPU and should not directly reference any Tensors in the rest of the `model_fn`. To pass Tensors from the model to the `metric_fn`, provide as part of the `host_call`. See https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec for more information. Arguments should match the list of `Tensor` objects passed as the second element in the tuple passed to `host_call`. Args: gs: `Tensor with shape `[batch]` for the global_step lr: `Tensor` with shape `[batch]` for the learning_rate. ce: `Tensor` with shape `[batch]` for the current_epoch. 
Returns: List of summary ops to run on the CPU host. """ gs = gs[0] # Host call fns are executed FLAGS.iterations_per_loop times after one # TPU loop is finished, setting max_queue value to the same as number of # iterations will make the summary writer only flush the data to storage # once per loop. with tf.contrib.summary.create_file_writer( FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default(): with tf.contrib.summary.always_record_summaries(): tf.contrib.summary.scalar('learning_rate', lr[0], step=gs) tf.contrib.summary.scalar('current_epoch', ce[0], step=gs) return tf.contrib.summary.all_summary_ops() # To log the loss, current learning rate, and epoch for Tensorboard, the # summary op needs to be run on the host CPU via host_call. host_call # expects [batch_size, ...] Tensors, thus reshape to introduce a batch # dimension. These Tensors are implicitly concatenated to # [params['batch_size']]. gs_t = tf.reshape(global_step, [1]) lr_t = tf.reshape(learning_rate, [1]) ce_t = tf.reshape(current_epoch, [1]) host_call = (host_call_fn, [gs_t, lr_t, ce_t]) else: train_op = None if has_moving_average_decay: # Load moving average variables for eval. restore_vars_dict = ema.variables_to_restore(ema_vars) eval_metrics = None if mode == tf.estimator.ModeKeys.EVAL: def metric_fn(labels, logits): """Evaluation metric function. Evaluates accuracy. This function is executed on the CPU and should not directly reference any Tensors in the rest of the `model_fn`. To pass Tensors from the model to the `metric_fn`, provide as part of the `eval_metrics`. See https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec for more information. Arguments should match the list of `Tensor` objects passed as the second element in the tuple passed to `eval_metrics`. Args: labels: `Tensor` with shape `[batch]`. logits: `Tensor` with shape `[batch, num_classes]`. Returns: A dict of the metrics to return from evaluation. """ predictions = tf.argmax(logits, axis=1) top_1_accuracy = tf.metrics.accuracy(labels, predictions) in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32) top_5_accuracy = tf.metrics.mean(in_top_5) return { 'top_1_accuracy': top_1_accuracy, 'top_5_accuracy': top_5_accuracy, } eval_metrics = (metric_fn, [labels, logits]) num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) tf.logging.info('number of trainable parameters: {}'.format(num_params)) def _scaffold_fn(): saver = tf.train.Saver(restore_vars_dict) return tf.train.Scaffold(saver=saver) return tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=loss, train_op=train_op, host_call=host_call, eval_metrics=eval_metrics, scaffold_fn=_scaffold_fn if has_moving_average_decay else None)
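# Illustrative note (not from the original source): why num_updates=global_step
# is passed to ExponentialMovingAverage in the model_fns above. Per the TF
# documentation, the effective decay becomes
# min(decay, (1 + num_updates) / (10 + num_updates)), so the average adapts
# quickly early in training and approaches the configured decay later on.
def effective_ema_decay(decay, num_updates):
  return min(decay, (1.0 + num_updates) / (10.0 + num_updates))


print(effective_ema_decay(0.9999, 100))        # ~0.918 early in training
print(effective_ema_decay(0.9999, 1_000_000))  # 0.9999 late in training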