def model_fn(features, labels, mode, config, params):
    """Estimator model function.

    Builds the `TPUEstimatorSpec` for PREDICT, TRAIN and EVAL. Besides its
    arguments the function reads names defined outside of it
    (`model_params`, `use_tpu`, `decode_keys`,
    `include_features_in_predictions`, `train_init_checkpoint`,
    `train_warmup_steps`, `model_dir`).

    Args:
      features: Inputs handed to the prediction function / the model.
      labels: Unused.
      mode: A `tf.estimator.ModeKeys` value.
      config: Unused.
      params: Unused.

    Returns:
      A `TPUEstimatorSpec` for PREDICT, TRAIN or EVAL. Any other mode falls
      off the end of the function and yields None.
    """
    del labels
    del config
    del params

    tf.get_variable_scope().set_initializer(
        tf.variance_scaling_initializer(
            1.0, mode="fan_avg", distribution="uniform"))

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = model_params.estimator_prediction_fn(features)
        if include_features_in_predictions:
            predictions.update(features)

        if decode_keys:

            def decode_host_call(tensor_dict):
                # Decode the raw ids into strings in prediction. The decoded
                # tensors are written straight into `predictions`.
                for decode_key in decode_keys:
                    predictions[decode_key] = public_parsing_ops.decode(
                        tensor_dict[decode_key],
                        model_params.vocab_filename,
                        model_params.encoder_type)
                return tensor_dict

            contrib_tpu.outside_compilation(decode_host_call, predictions)

        return tpu_estimator.TPUEstimatorSpec(
            mode=mode, predictions=predictions)

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    def run_model():
        """Instantiates the model and applies it to the features."""
        return model_params.model()(features, is_training)

    if use_tpu and model_params.use_bfloat16:
        with contrib_tpu.bfloat16_scope():
            loss, outputs = run_model()
    else:
        loss, outputs = run_model()

    # TPU requires outputs to all have a batch dimension and doesn't handle
    # scalars, so every scalar is tiled to a 1-dimensional vector.
    outputs = _tile_scalar_to_batch_size(outputs, model_params.batch_size)

    if mode == tf.estimator.ModeKeys.TRAIN:
        base_lr = model_params.learning_rate
        global_step = tf.train.get_global_step()
        # Inverse square-root decay; rsqrt(10000) == 0.01, so the rate equals
        # `base_lr` until step 10000.
        clamped_step = tf.maximum(tf.to_float(global_step), 10000)
        learning_rate = base_lr / 0.01 * tf.rsqrt(clamped_step)
        if train_init_checkpoint:
            # Linear warm-up, capped by the decayed rate.
            warmup_lr = (
                tf.to_float(global_step + 1) / train_warmup_steps * base_lr)
            learning_rate = tf.minimum(warmup_lr, learning_rate)

        optimizer = adafactor.AdafactorOptimizer(
            learning_rate=learning_rate,
            decay_rate=adafactor.adafactor_decay_rate_pow(0.8),
            beta1=0.0)
        if use_tpu:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
        train_op = optimizer.minimize(loss, global_step=global_step)

        scaffold_fn = _load_vars_from_checkpoint(use_tpu,
                                                 train_init_checkpoint)
        host_call = add_scalars_to_summary(
            model_dir, {"learning_rate": learning_rate})
        return tpu_estimator.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            scaffold_fn=scaffold_fn,
            host_call=host_call)

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = model_params.estimator_eval_metrics_fn(
            features, outputs)
        return tpu_estimator.TPUEstimatorSpec(
            mode=mode, loss=loss, eval_metrics=eval_metrics)
def resnet_model_fn_w_pruning(features, labels, mode, params):
    """The model_fn for ResNet-50 with pruning.

    Args:
      features: A float32 batch of images, or a dict holding the batch under
        the key 'feature'.
      labels: A int32 batch of labels.
      mode: Specifies whether training or evaluation.
      params: Dictionary of parameters passed to the model. The keys
        'pruning_method', 'use_tpu', 'log_alpha_threshold' and (outside of
        PREDICT) 'output_dir' are read.

    Returns:
      A TPUEstimatorSpec for the model (a plain EstimatorSpec in PREDICT
      mode).

    Raises:
      ValueError: If FLAGS.precision is neither 'bfloat16' nor 'float32', or
        if the 'scratch' pruning method is used without FLAGS.load_mask_dir.
    """
    width = 1. if FLAGS.width <= 0 else FLAGS.width
    if isinstance(features, dict):
        features = features['feature']

    channels_first = FLAGS.data_format == 'channels_first'
    if channels_first:
        assert not FLAGS.transpose_input  # channels_first only for GPU

    if FLAGS.transpose_input and mode != tf_estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance. The [1, 1, 3]
    # constants broadcast over the last axis, so this has to run while the
    # channels are still the last axis, i.e. before the NCHW transpose.
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

    if channels_first:
        features = tf.transpose(features, [0, 3, 1, 2])  # NHWC to NCHW

    pruning_method = params['pruning_method']
    use_tpu = params['use_tpu']
    log_alpha_threshold = params['log_alpha_threshold']

    def build_network():
        """Construct the network in the graph."""
        # We need to construct the model with the pruning masks, but they
        # won't be updated if we're doing scratch training.
        model_pruning_method = pruning_method
        if pruning_method == 'scratch':
            model_pruning_method = 'threshold'

        network = resnet_model.resnet_v1_(
            resnet_depth=FLAGS.resnet_depth,
            num_classes=FLAGS.num_label_classes,
            pruning_method=model_pruning_method,
            init_method=FLAGS.init_method,
            width=width,
            prune_first_layer=FLAGS.prune_first_layer,
            prune_last_layer=FLAGS.prune_last_layer,
            data_format=FLAGS.data_format,
            end_sparsity=FLAGS.end_sparsity,
            clip_log_alpha=FLAGS.clip_log_alpha,
            log_alpha_threshold=log_alpha_threshold,
            weight_decay=FLAGS.weight_decay)
        return network(
            inputs=features,
            is_training=(mode == tf_estimator.ModeKeys.TRAIN))

    if FLAGS.precision == 'bfloat16':
        with contrib_tpu.bfloat16_scope():
            logits = build_network()
        logits = tf.cast(logits, tf.float32)
    elif FLAGS.precision == 'float32':
        logits = build_network()
    else:
        # Without this branch `logits` would be unbound and the failure would
        # surface later as a NameError.
        raise ValueError(
            'Unsupported precision: {}'.format(FLAGS.precision))

    if mode == tf_estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf_estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf_estimator.export.PredictOutput(predictions)
            })

    output_dir = params['output_dir']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2
    # regularization.
    one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)

    # Make sure we reuse the same label smoothing parameter if we're doing
    # scratch / lottery ticket experiments.
    label_smoothing = FLAGS.label_smoothing
    if FLAGS.pruning_method == 'scratch':
        # Validate before parsing the path; otherwise an empty flag fails on
        # the index lookup below and this error is never reached.
        if not FLAGS.load_mask_dir:
            raise ValueError(
                'Must supply a mask directory to use scratch method')
        # NOTE(review): the smoothing value is read from the 16th
        # '/'-separated component of the mask directory, which presumably
        # matches one specific experiment directory layout -- confirm.
        label_smoothing = float(FLAGS.load_mask_dir.split('/')[15])

    loss = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
        label_smoothing=label_smoothing)
    # Add regularization loss term.
    loss += tf.losses.get_regularization_loss()

    # Method-specific sparsity regularizers; both take the same arguments.
    if pruning_method == 'variational_dropout':
        sparsity_loss_fn = utils.variational_dropout_dkl_loss
    elif pruning_method == 'l0_regularization':
        sparsity_loss_fn = utils.l0_regularization_loss
    else:
        sparsity_loss_fn = None

    if sparsity_loss_fn is not None:
        reg_loss = sparsity_loss_fn(
            reg_scalar=FLAGS.reg_scalar,
            start_reg_ramp_up=FLAGS.sparsity_begin_step,
            end_reg_ramp_up=FLAGS.sparsity_end_step,
            warm_up=FLAGS.is_warm_up,
            use_tpu=use_tpu)
        loss += reg_loss
        tf.losses.add_loss(reg_loss, loss_collection=tf.GraphKeys.LOSSES)

    host_call = None
    if mode == tf_estimator.ModeKeys.TRAIN:
        host_call, train_op = train_function(pruning_method, loss,
                                             output_dir, use_tpu)
    else:
        train_op = None

    eval_metrics = None
    if mode == tf_estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Calculate eval metrics."""
            logging.info('In metric function')
            metrics = {}
            predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5)
            metrics['eval_accuracy'] = tf.metrics.accuracy(
                labels=labels, predictions=predictions)
            return metrics

        def vd_metric_fn(labels, logits, global_sparsity):
            """Eval metrics plus the mean global sparsity."""
            metrics = metric_fn(labels, logits)
            metrics['global_sparsity'] = tf.metrics.mean(global_sparsity)
            return metrics

        tensors = [labels, logits]
        metric_function = metric_fn

        if FLAGS.pruning_method == 'variational_dropout':
            # The sparsity value is multiplied by a [batch_size, 1] tensor of
            # ones so that it is passed to the metric function with a batch
            # dimension.
            batch_size = labels.shape[0]
            ones = tf.ones([batch_size, 1])
            mask_metrics = utils.add_vd_pruning_summaries(
                threshold=FLAGS.log_alpha_threshold)
            tensors.append(mask_metrics['global_sparsity'] * ones)
            metric_function = vd_metric_fn

        eval_metrics = (metric_function, tensors)

    # Define a custom scaffold function to enable initializing the mask from
    # an already trained checkpoint.
    def initialize_mask_from_ckpt(ckpt_path):
        """Load mask from an existing checkpoint."""
        model_dir = FLAGS.output_dir
        already_has_ckpt = model_dir and tf.train.latest_checkpoint(
            model_dir) is not None
        if already_has_ckpt:
            tf.logging.info(
                'Training already started on this model, not loading masks '
                'from previously trained model')
            return

        reader = tf.train.NewCheckpointReader(ckpt_path)
        # A set keeps the per-variable membership test below O(1).
        mask_names = {
            name for name in reader.get_variable_to_shape_map()
            if name.endswith('mask')
        }

        variable_map = {}
        for var in tf.global_variables():
            var_name = var.name.split(':')[0]
            if var_name in mask_names:
                tf.logging.info('Loading mask variable from checkpoint: %s',
                                var_name)
                variable_map[var_name] = var
            elif 'mask' in var_name:
                tf.logging.info(
                    'Cannot find mask variable in checkpoint, skipping: %s',
                    var_name)
        tf.train.init_from_checkpoint(ckpt_path, variable_map)

    def initialize_parameters_from_ckpt(ckpt_path):
        """Load parameters from an existing checkpoint."""
        model_dir = FLAGS.output_dir
        already_has_ckpt = model_dir and tf.train.latest_checkpoint(
            model_dir) is not None
        if already_has_ckpt:
            tf.logging.info(
                'Training already started on this model, not loading '
                'parameters from previously trained model')
            return

        reader = tf.train.NewCheckpointReader(ckpt_path)
        param_names = {
            name for name in reader.get_variable_to_shape_map()
            if not name.endswith('mask')
        }

        variable_map = {}
        for var in tf.global_variables():
            var_name = var.name.split(':')[0]
            if var_name in param_names:
                tf.logging.info(
                    'Loading parameter variable from checkpoint: %s',
                    var_name)
                variable_map[var_name] = var
            elif 'mask' not in var_name:
                tf.logging.info(
                    'Cannot find parameter variable in checkpoint, skipping: '
                    '%s', var_name)
        tf.train.init_from_checkpoint(ckpt_path, variable_map)

    if FLAGS.pruning_method == 'scratch':
        # FLAGS.load_mask_dir was validated above.
        def scaffold_fn():
            initialize_mask_from_ckpt(FLAGS.load_mask_dir)
            if FLAGS.initial_value_checkpoint:
                initialize_parameters_from_ckpt(
                    FLAGS.initial_value_checkpoint)
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    return contrib_tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        host_call=host_call,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
def resnet_model_fn_w_pruning(features, labels, mode, params):
    """The model_fn for image classification networks with pruning.

    The network is selected by FLAGS.model_architecture ('mobilenet_v1',
    'mobilenet_v2', 'resnet' or a name starting with 'vgg').

    Args:
      features: A float32 batch of images, or a dict holding the batch under
        the key 'feature'.
      labels: A int32 batch of labels.
      mode: Specifies whether training or evaluation.
      params: Dictionary of parameters passed to the model. The keys
        'training_method', 'use_tpu' and (outside of PREDICT) 'output_dir'
        are read.

    Returns:
      A TPUEstimatorSpec for the model (a plain EstimatorSpec in PREDICT
      mode).

    Raises:
      ValueError: If FLAGS.model_architecture or FLAGS.precision is not
        supported.
    """
    width = 1. if FLAGS.width <= 0 else FLAGS.width
    if isinstance(features, dict):
        features = features['feature']

    channels_first = FLAGS.data_format == 'channels_first'
    if channels_first:
        assert not FLAGS.transpose_input  # channels_first only for GPU

    if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance. The [1, 1, 3]
    # constants broadcast over the last axis, so this has to run while the
    # channels are still the last axis, i.e. before the NCHW transpose.
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

    if channels_first:
        features = tf.transpose(features, [0, 3, 1, 2])  # NHWC to NCHW

    training_method = params['training_method']
    use_tpu = params['use_tpu']

    def build_network():
        """Construct the network in the graph."""
        architecture = FLAGS.model_architecture
        if architecture == 'mobilenet_v2':
            network_func = functools.partial(
                mobilenetv2_model.mobilenet_v2,
                expansion_factor=FLAGS.expansion_factor)
        elif architecture == 'mobilenet_v1':
            network_func = functools.partial(mobilenetv1_model.mobilenet_v1)
        elif architecture == 'resnet':
            prune_first_layer = FLAGS.first_layer_sparsity != 0.
            network_func = functools.partial(
                resnet_model.resnet_v1_,
                resnet_depth=FLAGS.resnet_depth,
                init_method=FLAGS.init_method,
                end_sparsity=FLAGS.end_sparsity,
                prune_first_layer=prune_first_layer)
        elif architecture.startswith('vgg'):
            network_func = functools.partial(
                vgg.vgg,
                vgg_type=architecture,
                init_method=FLAGS.init_method,
                end_sparsity=FLAGS.end_sparsity)
        else:
            # Report the flag that was actually tested above; reading a
            # misspelled flag name here would raise AttributeError instead.
            raise ValueError('Unknown architecture ' + architecture)

        prune_last_layer = FLAGS.last_layer_sparsity != 0.
        network = network_func(
            num_classes=FLAGS.num_label_classes,
            # TODO remove the pruning_method option.
            pruning_method='threshold',
            width=width,
            prune_last_layer=prune_last_layer,
            data_format=FLAGS.data_format,
            weight_decay=FLAGS.weight_decay)

        # FLAGS.use_batch_statistics forces the network into training mode
        # regardless of the estimator mode.
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        if FLAGS.use_batch_statistics:
            is_training = True
        return network(inputs=features, is_training=is_training)

    if FLAGS.precision == 'bfloat16':
        with contrib_tpu.bfloat16_scope():
            logits = build_network()
        logits = tf.cast(logits, tf.float32)
    elif FLAGS.precision == 'float32':
        logits = build_network()
    else:
        # Without this branch `logits` would be unbound and the failure would
        # surface later as a NameError.
        raise ValueError(
            'Unsupported precision: {}'.format(FLAGS.precision))

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    output_dir = params['output_dir']

    # Calculate loss, which includes softmax cross entropy and L2
    # regularization.
    one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)

    # Make sure we reuse the same label smoothing parameter if we're doing
    # scratch / lottery ticket experiments.
    label_smoothing = FLAGS.label_smoothing
    if FLAGS.training_method == 'scratch' and FLAGS.load_mask_dir:
        # NOTE(review): the smoothing value is read from the 16th
        # '/'-separated component of the mask directory (after dropping
        # '/scratch'), which presumably matches one specific experiment
        # directory layout -- confirm.
        scratch_stripped = FLAGS.load_mask_dir.replace('/scratch', '')
        label_smoothing = float(scratch_stripped.split('/')[15])
        tf.logging.info('LABEL SMOOTHING USED: %.2f', label_smoothing)

    cross_loss = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
        label_smoothing=label_smoothing)
    # Add regularization loss term.
    reg_loss = tf.losses.get_regularization_loss()
    loss = cross_loss + reg_loss

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        host_call, train_op = train_function(training_method, loss,
                                             cross_loss, reg_loss,
                                             output_dir, use_tpu)
    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits, cross_loss, reg_loss):
            """Calculate eval metrics."""
            logging.info('In metric function')
            metrics = {}
            predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5)
            metrics['cross_loss'] = tf.metrics.mean(cross_loss)
            metrics['reg_loss'] = tf.metrics.mean(reg_loss)
            metrics['eval_accuracy'] = tf.metrics.accuracy(
                labels=labels, predictions=predictions)

            # If evaluating once lets also calculate sparsities.
            if FLAGS.mode == 'eval_once':
                sparsity_summaries = utils.mask_summaries(pruning.get_masks())
                # We call mean on a scalar to create tensor, update_op pairs.
                sparsity_summaries = {
                    k: tf.metrics.mean(v)
                    for k, v in sparsity_summaries.items()
                }
                metrics.update(sparsity_summaries)
            return metrics

        # The two losses are broadcast to the shape of `labels` before being
        # handed to the metric function.
        tensors = [
            labels,
            logits,
            tf.broadcast_to(cross_loss, tf.shape(labels)),
            tf.broadcast_to(reg_loss, tf.shape(labels))
        ]
        eval_metrics = (metric_fn, tensors)

    masks_may_be_initialized = (
        FLAGS.training_method not in NO_MASK_INIT_METHODS)

    if FLAGS.load_mask_dir and masks_may_be_initialized:

        def scaffold_fn():
            """For initialization, passed to the estimator."""
            utils.initialize_parameters_from_ckpt(
                FLAGS.load_mask_dir, FLAGS.output_dir, MASK_SUFFIX)
            if FLAGS.initial_value_checkpoint:
                utils.initialize_parameters_from_ckpt(
                    FLAGS.initial_value_checkpoint, FLAGS.output_dir,
                    PARAM_SUFFIXES)
            return tf.train.Scaffold()

    elif FLAGS.mask_init_method and masks_may_be_initialized:

        def scaffold_fn():
            """For initialization, passed to the estimator."""
            if FLAGS.initial_value_checkpoint:
                utils.initialize_parameters_from_ckpt(
                    FLAGS.initial_value_checkpoint, FLAGS.output_dir,
                    PARAM_SUFFIXES)
            all_masks = pruning.get_masks()
            assigner = sparse_utils.get_mask_init_fn(
                all_masks,
                FLAGS.mask_init_method,
                FLAGS.end_sparsity,
                CUSTOM_SPARSITY_MAP,
                erk_power_scale=FLAGS.erk_power_scale)

            def init_fn(scaffold, session):
                """A callable for restoring variable from a checkpoint."""
                del scaffold  # Unused.
                session.run(assigner)

            return tf.train.Scaffold(init_fn=init_fn)

    else:
        assert FLAGS.training_method in NO_MASK_INIT_METHODS
        scaffold_fn = None
        tf.logging.info('No mask is set, starting dense.')

    return contrib_tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        host_call=host_call,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    Besides its arguments the function reads names defined outside of it
    (among them `use_tpu`, `train_config`, `eval_config`,
    `eval_input_config`, `configs`, `hparams`, `detection_model_fn` and
    `postprocess_on_cpu`).

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations. A `TPUEstimatorSpec` is returned instead when running
        on TPU in any mode other than EVAL.
    """
    params = params or {}
    mode_keys = tf.estimator.ModeKeys
    total_loss = None
    train_op = None
    detections = None
    export_outputs = None
    is_training = mode == mode_keys.TRAIN

    # Make sure to set the Keras learning phase. True during training,
    # False for inference.
    tf.keras.backend.set_learning_phase(is_training)

    # Set policy for mixed-precision training with Keras-based models.
    if use_tpu and train_config.use_bfloat16:
        from tensorflow.python.keras.engine import base_layer_utils  # pylint: disable=g-import-not-at-top
        # Enable v2 behavior, as `mixed_bfloat16` is only supported in TF 2.0.
        base_layer_utils.enable_v2_dtype_behavior()
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(
            'mixed_bfloat16')

    detection_model = detection_model_fn(
        is_training=is_training, add_summaries=(not use_tpu))
    scaffold_fn = None

    if mode == mode_keys.TRAIN:
        labels = unstack_batch(
            labels,
            unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
    elif mode == mode_keys.EVAL:
        # For evaling on train data, it is necessary to check whether
        # groundtruth must be unpadded.
        groundtruth_boxes = labels[fields.InputDataFields.groundtruth_boxes]
        boxes_shape = groundtruth_boxes.get_shape().as_list()
        unpad_groundtruth_tensors = (
            boxes_shape[1] is not None and not use_tpu)
        labels = unstack_batch(
            labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

    if mode in (mode_keys.TRAIN, mode_keys.EVAL):
        provide_groundtruth(detection_model, labels)

    preprocessed_images = features[fields.InputDataFields.image]
    side_inputs = detection_model.get_side_inputs(features)
    true_image_shape = features[fields.InputDataFields.true_image_shape]

    def run_predict():
        """Runs the detection model's forward pass."""
        return detection_model.predict(
            preprocessed_images, true_image_shape, **side_inputs)

    if use_tpu and train_config.use_bfloat16:
        with contrib_tpu.bfloat16_scope():
            prediction_dict = run_predict()
        prediction_dict = ops.bfloat16_to_float32_nested(prediction_dict)
    else:
        prediction_dict = run_predict()

    def postprocess_wrapper(args):
        return detection_model.postprocess(args[0], args[1])

    if mode in (mode_keys.EVAL, mode_keys.PREDICT):
        postprocess_args = (prediction_dict, true_image_shape)
        if use_tpu and postprocess_on_cpu:
            detections = contrib_tpu.outside_compilation(
                postprocess_wrapper, postprocess_args)
        else:
            detections = postprocess_wrapper(postprocess_args)

    if mode == mode_keys.TRAIN:
        load_pretrained = hparams.load_pretrained if hparams else False
        if train_config.fine_tune_checkpoint and load_pretrained:
            if not train_config.fine_tune_checkpoint_type:
                # train_config.from_detection_checkpoint field is deprecated.
                # For backward compatibility, set
                # train_config.fine_tune_checkpoint_type based on
                # train_config.from_detection_checkpoint.
                if train_config.from_detection_checkpoint:
                    train_config.fine_tune_checkpoint_type = 'detection'
                else:
                    train_config.fine_tune_checkpoint_type = 'classification'

            assignment_map = detection_model.restore_map(
                fine_tune_checkpoint_type=(
                    train_config.fine_tune_checkpoint_type),
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    assignment_map,
                    train_config.fine_tune_checkpoint,
                    include_global_step=False))

            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(
                        train_config.fine_tune_checkpoint, available_var_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(
                    train_config.fine_tune_checkpoint, available_var_map)

    if mode in (mode_keys.TRAIN, mode_keys.EVAL):
        if mode == mode_keys.EVAL and eval_config.use_dummy_loss_in_eval:
            total_loss = tf.constant(1.0)
            losses_dict = {'Loss/total_loss': total_loss}
        else:
            losses_dict = detection_model.loss(prediction_dict,
                                               true_image_shape)
            losses = list(losses_dict.values())
            if train_config.add_regularization_loss:
                regularization_losses = (
                    detection_model.regularization_losses())
                if use_tpu and train_config.use_bfloat16:
                    regularization_losses = ops.bfloat16_to_float32_nested(
                        regularization_losses)
                if regularization_losses:
                    regularization_loss = tf.add_n(
                        regularization_losses, name='regularization_loss')
                    losses.append(regularization_loss)
                    losses_dict['Loss/regularization_loss'] = (
                        regularization_loss)
            total_loss = tf.add_n(losses, name='total_loss')
            losses_dict['Loss/total_loss'] = total_loss

        if 'graph_rewriter_config' in configs:
            graph_rewriter_fn = graph_rewriter_builder.build(
                configs['graph_rewriter_config'], is_training=is_training)
            graph_rewriter_fn()

        # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode
        # once we can write learning rate summaries on TPU without host
        # calls.
        global_step = tf.train.get_or_create_global_step()
        training_optimizer, optimizer_summary_vars = optimizer_builder.build(
            train_config.optimizer)

    if mode == mode_keys.TRAIN:
        if use_tpu:
            training_optimizer = contrib_tpu.CrossShardOptimizer(
                training_optimizer)

        # Optionally freeze some layers by setting their gradients to be
        # zero.
        include_variables = train_config.update_trainable_variables or None
        exclude_variables = train_config.freeze_variables or None
        trainable_variables = contrib_framework.filter_variables(
            tf.trainable_variables(),
            include_patterns=include_variables,
            exclude_patterns=exclude_variables)

        if train_config.gradient_clipping_by_norm > 0:
            clip_gradients_value = train_config.gradient_clipping_by_norm
        else:
            clip_gradients_value = None

        if not use_tpu:
            for summary_var in optimizer_summary_vars:
                tf.summary.scalar(summary_var.op.name, summary_var)

        if train_config.summarize_gradients:
            summaries = [
                'gradients', 'gradient_norm', 'global_gradient_norm'
            ]
        elif use_tpu:
            summaries = []
        else:
            summaries = None

        train_op = contrib_layers.optimize_loss(
            loss=total_loss,
            global_step=global_step,
            learning_rate=None,
            clip_gradients=clip_gradients_value,
            optimizer=training_optimizer,
            update_ops=detection_model.updates(),
            variables=trainable_variables,
            summaries=summaries,
            name='')  # Preventing scope prefix on all variables.

    if mode == mode_keys.PREDICT:
        exported_output = exporter_lib.add_output_tensor_nodes(detections)
        export_outputs = {
            tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
                tf.estimator.export.PredictOutput(exported_output)
        }

    eval_metric_ops = None
    scaffold = None
    if mode == mode_keys.EVAL:
        class_agnostic = (
            fields.DetectionResultFields.detection_classes not in detections)
        groundtruth = _prepare_groundtruth_for_eval(
            detection_model, class_agnostic,
            eval_input_config.max_number_of_boxes)

        use_original_images = (
            fields.InputDataFields.original_image in features)
        if use_original_images:
            eval_images = features[fields.InputDataFields.original_image]
            true_image_shapes = tf.slice(true_image_shape, [0, 0], [-1, 3])
            original_image_spatial_shapes = features[
                fields.InputDataFields.original_image_spatial_shape]
        else:
            eval_images = features[fields.InputDataFields.image]
            true_image_shapes = None
            original_image_spatial_shapes = None

        eval_dict = eval_util.result_dict_for_batched_example(
            eval_images,
            features[inputs.HASH_KEY],
            detections,
            groundtruth,
            class_agnostic=class_agnostic,
            scale_to_absolute=True,
            original_image_spatial_shapes=original_image_spatial_shapes,
            true_image_shapes=true_image_shapes)

        additional_channels_key = (
            fields.InputDataFields.image_additional_channels)
        if additional_channels_key in features:
            eval_dict[additional_channels_key] = features[
                additional_channels_key]

        if class_agnostic:
            category_index = (
                label_map_util.create_class_agnostic_category_index())
        else:
            category_index = (
                label_map_util.create_category_index_from_labelmap(
                    eval_input_config.label_map_path))

        vis_metric_ops = None
        if not use_tpu and use_original_images:
            eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
                category_index,
                max_examples_to_draw=eval_config.num_visualizations,
                max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
                min_score_thresh=eval_config.min_score_threshold,
                use_normalized_coordinates=False)
            vis_metric_ops = (
                eval_metric_op_vis.get_estimator_eval_metric_ops(eval_dict))

        # Eval metrics on a single example.
        eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
            eval_config, list(category_index.values()), eval_dict)
        for loss_key, loss_tensor in losses_dict.items():
            eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
        for summary_var in optimizer_summary_vars:
            eval_metric_ops[summary_var.op.name] = (summary_var, tf.no_op())
        if vis_metric_ops is not None:
            eval_metric_ops.update(vis_metric_ops)
        eval_metric_ops = {
            str(key): value for key, value in eval_metric_ops.items()
        }

        if eval_config.use_moving_averages:
            variable_averages = tf.train.ExponentialMovingAverage(0.0)
            variables_to_restore = variable_averages.variables_to_restore()
            keep_checkpoint_every_n_hours = (
                train_config.keep_checkpoint_every_n_hours)
            saver = tf.train.Saver(
                variables_to_restore,
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
            scaffold = tf.train.Scaffold(saver=saver)

    # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
    if use_tpu and mode != mode_keys.EVAL:
        return contrib_tpu.TPUEstimatorSpec(
            mode=mode,
            scaffold_fn=scaffold_fn,
            predictions=detections,
            loss=total_loss,
            train_op=train_op,
            eval_metrics=eval_metric_ops,
            export_outputs=export_outputs)

    if scaffold is None:
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            sharded=True,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        scaffold = tf.train.Scaffold(saver=saver)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=detections,
        loss=total_loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs,
        scaffold=scaffold)
def model_fn(self, features, labels, mode, config=None, params=None):
    """Estimator model_fn.

    Note, this function overwrites the model_fn of the wrapped t2r_model
    since it replaces specifications with their TPU corresponding calls and
    introduces additional casting conversion after the specification has
    been verified.

    Args:
      features: This is the first item returned from the input_fn and parsed
        by tensorspec_utils.validate_and_pack. A spec_structure which
        fulfills the requirements of the self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed
        by tensorspec_utils.validate_and_pack. A spec_structure which
        fulfills the requirements of the self.get_label_specification.
      mode: (ModeKeys) Specifies if this is training, evaluation or
        prediction.
      config: (Optional tf.estimator.RunConfig or contrib_tpu.RunConfig) Will
        receive what is passed to Estimator in config parameter, or the
        default config (tf.estimator.RunConfig). Allows updating things in
        your model_fn based on configuration such as num_ps_replicas, or
        model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Raises:
      ValueError: If the mode key is not supported, not in [PREDICT, TRAIN,
        EVAL], or if create_export_outputs_fn / model_train_fn of the wrapped
        model return a value of an unexpected type.

    Returns:
      A TPUEstimatorSpec.
    """
    mode_keys = tf.estimator.ModeKeys

    features = tensorspec_utils.validate_and_pack(
        expected_spec=self.get_feature_specification(mode),
        actual_tensors_or_spec=features,
        ignore_batch=True)
    if labels:
        labels = tensorspec_utils.validate_and_pack(
            expected_spec=self.get_label_specification(mode),
            actual_tensors_or_spec=labels,
            ignore_batch=True)

    # In order to support both TPU and CPU for inference, tensors
    # with dtype=bfloat16 will be casted to float32.
    # Note, despite casting the benefit of bfloat16 are still maintained
    # for TPUs since this operation is a noop on this platform.
    # See http://shortn/_TTg3ZyATRo for rationale.
    if not self._train_in_bfloat16 or mode in (mode_keys.PREDICT,
                                               mode_keys.EVAL):
        features = tensorspec_utils.cast_bfloat16_to_float32(features)
        if labels is not None:
            labels = tensorspec_utils.cast_bfloat16_to_float32(labels)

    if self._train_in_bfloat16 and mode == mode_keys.TRAIN:
        with contrib_tpu.bfloat16_scope():
            inference_outputs = self._t2r_model.inference_network_fn(
                features, labels, mode, config, params)
    else:
        inference_outputs = self._t2r_model.inference_network_fn(
            features, labels, mode, config, params)

    # The inference network may return (outputs, update_ops).
    update_ops = None
    if isinstance(inference_outputs, tuple):
        update_ops = inference_outputs[1]
        inference_outputs = inference_outputs[0]

    if mode == mode_keys.PREDICT:
        model_fn_results = self._t2r_model.create_export_outputs_fn(
            features, inference_outputs, mode, config, params)
        export_outputs = None
        if isinstance(model_fn_results, tuple):
            predictions = model_fn_results[0]
            export_outputs = model_fn_results[1]
        elif isinstance(model_fn_results, dict):
            export_outputs = {}
            # A single prediction is additionally exported as a regression
            # output under its own name.
            if len(model_fn_results) == 1:
                name, output = list(model_fn_results.items())[0]
                export_outputs[name] = (
                    tf.estimator.export.RegressionOutput(output))
            export_outputs[tf.saved_model.signature_constants
                           .DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                               tf.estimator.export.PredictOutput(
                                   model_fn_results))
            predictions = model_fn_results
        else:
            raise ValueError(
                'The create_export_outputs_fn should return a '
                'tuple(predictions, export_outputs) or predictions.')

        return contrib_tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs=export_outputs)

    train_fn_result = self._t2r_model.model_train_fn(
        features, labels, inference_outputs, mode, config, params)
    if isinstance(train_fn_result, tf.Tensor):
        train_loss = train_fn_result
        train_outputs = {}
    elif isinstance(train_fn_result, tuple):
        train_loss = train_fn_result[0]
        train_outputs = train_fn_result[1]
    else:
        raise ValueError('The model_train_fn should return a '
                         'tuple(loss, train_outputs) or loss.')

    if mode == mode_keys.TRAIN:
        # Create the tf.train.Optimizer.
        optimizer = get_cross_shard_optimizer(
            self._t2r_model.create_optimizer(params))
        train_op = self._t2r_model.create_train_op(
            train_loss, optimizer, update_ops, train_outputs)

        self._t2r_model.add_summaries(features, labels, inference_outputs,
                                      train_loss, train_outputs, mode,
                                      config, params)

        # For TPUs the init has to happen in a scaffold function. Since the
        # model already contains one implementation which is internal to the
        # model this call is simply wrapped.
        # No new variables are allowed to be added, otherwise
        # we would not initialize these variables.
        # Note, this feature is only available for train to bootstrap a model
        # (partially) from a different model. As soon as this checkpoint is
        # written all other modes will use the local checkpoint within
        # model_dir.
        def create_scaffold_fn():
            """Creates a scaffold instance."""
            self._t2r_model.maybe_init_from_checkpoint()
            # Return the value of the property first since it might be
            # changed.
            scaffold_fn = self._t2r_model.scaffold_fn
            scaffold = scaffold_fn()
            # In order to export asynchronously the saver has to be
            # registered in the graph collection. The scaffold function might
            # register a saver already which is why it is checked here and a
            # saver is only added if none has been added.
            if not tf.get_collection(tf.GraphKeys.SAVERS):
                # TODO(T2R_CONTRIBUTORS): Switch to using gin config for all
                # saver params.
                keep_checkpoint_every_n_hours = None
                max_to_keep = None
                if config is not None:
                    keep_checkpoint_every_n_hours = (
                        config.keep_checkpoint_every_n_hours)
                    max_to_keep = config.keep_checkpoint_max
                saver = abstract_model.gin_configurable_saver(
                    keep_checkpoint_every_n_hours=(
                        keep_checkpoint_every_n_hours),
                    max_to_keep=max_to_keep,
                )
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            return scaffold

        training_hooks = []
        # EstimatorSpec has training_chief_hooks, but TPUEstimatorSpec does
        # not, so we have to use training_hooks here and check is_chief.
        if config and config.is_chief:  # pytype: disable=attribute-error
            training_hooks.append(
                gin_utils.GinConfigSaverHook(
                    config.model_dir, summarize_config=True))

        return contrib_tpu.TPUEstimatorSpec(
            mode=mode,
            loss=train_loss,
            train_op=train_op,
            training_hooks=training_hooks,
            scaffold_fn=create_scaffold_fn)

    if mode == mode_keys.EVAL:
        self._t2r_model.add_summaries(features, labels, inference_outputs,
                                      train_loss, train_outputs, mode,
                                      config, params)
        eval_metrics = self._t2r_model.model_eval_fn(
            features, labels, inference_outputs, train_loss, train_outputs,
            mode, config, params)
        evaluation_hooks = self._t2r_model.get_eval_hooks(config, params)
        if config and config.is_chief:  # pytype: disable=attribute-error
            # `params` defaults to None, so guard the lookup instead of
            # failing with an AttributeError.
            eval_name = (params or {}).get('eval_name', 'eval')
            evaluation_hooks.append(
                gin_utils.GinConfigSaverHook(
                    os.path.join(config.model_dir, eval_name),
                    summarize_config=True))
        return contrib_tpu.TPUEstimatorSpec(
            mode=mode,
            loss=train_loss,
            eval_metrics=eval_metrics,
            evaluation_hooks=evaluation_hooks)

    raise ValueError('The mode {} is not supported yet.'.format(mode))
def _model_fn(features, labels, mode, params, model, use_tpu_estimator_spec,
              variable_filter_fn=None):
    """Model definition for the RetinaNet model based on ResNet.

    Args:
      features: the input image tensor with shape
        [batch_size, height, width, 3]. The height and width are fixed and
        equal. In PREDICT mode this may instead be a dict holding the images
        under 'inputs', plus optional 'image_info' and 'labels' entries.
      labels: the input labels in a dictionary. The labels include class
        targets and box targets which are dense label maps. The labels are
        generated from get_input_fn function in dataloader.py
      mode: the mode of TPUEstimator/Estimator including TRAIN, EVAL, and
        PREDICT.
      params: the dictionary defines hyperparameters of model. The default
        settings are in default_hparams function in this file.
      model: the RetinaNet model outputs class logits and box regression
        outputs.
      use_tpu_estimator_spec: Whether to use TPUEstimatorSpec or
        EstimatorSpec for the TRAIN/EVAL result.
      variable_filter_fn: the filter function that takes trainable_variables
        and returns the variable list after applying the filter rule.

    Returns:
      In PREDICT mode, always a `TPUEstimatorSpec` holding the predictions.
      Otherwise a `TPUEstimatorSpec` when `use_tpu_estimator_spec` is true,
      else a `tf.estimator.EstimatorSpec` carrying only loss and train_op.
    """
    # In predict mode features is a dict with input as value of the 'inputs'.
    image_info = None
    if (mode == tf.estimator.ModeKeys.PREDICT and isinstance(features, dict)
            and 'inputs' in features):
        # Both entries are optional: the code below already copes with
        # `image_info is None`, so a missing key must not raise KeyError.
        image_info = features.get('image_info')
        labels = features.get('labels')
        features = features['inputs']

    def _model_outputs():
        return model(
            features,
            min_level=params['min_level'],
            max_level=params['max_level'],
            num_classes=params['num_classes'],
            num_anchors=len(params['aspect_ratios'] * params['num_scales']),
            resnet_depth=params['resnet_depth'],
            is_training_bn=params['is_training_bn'])

    if params['use_bfloat16']:
        with contrib_tpu.bfloat16_scope():
            cls_outputs, box_outputs = _model_outputs()
            # Cast every level's outputs back to float32.
            for level in cls_outputs.keys():
                cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
                box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
    else:
        cls_outputs, box_outputs = _model_outputs()

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        # Postprocess on host; memory layout for NMS on TPU is very
        # inefficient.
        def _predict_postprocess_wrapper(args):
            return _predict_postprocess(*args)

        predictions = contrib_tpu.outside_compilation(
            _predict_postprocess_wrapper,
            (cls_outputs, box_outputs, labels, params))

        # Include resizing information on prediction output to help bbox
        # drawing.
        if image_info is not None:
            predictions.update({
                'image_info': tf.identity(image_info, 'ImageInfo'),
            })

        return contrib_tpu.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT, predictions=predictions)

    # Load pretrained model from checkpoint.
    if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            tf.train.init_from_checkpoint(
                params['resnet_checkpoint'], {
                    '/': 'resnet%s/' % params['resnet_depth'],
                })
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    # Set up training loss and learning rate.
    update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_global_step()
    learning_rate = learning_rate_schedule(
        params['adjusted_learning_rate'], params['lr_warmup_init'],
        params['lr_warmup_step'], params['first_lr_drop_step'],
        params['second_lr_drop_step'], global_step)
    # cls_loss and box_loss are for logging. only total_loss is optimized.
    total_loss, cls_loss, box_loss = detection_loss(
        cls_outputs, box_outputs, labels, params)
    # L2 weight decay over all trainable variables except batch-norm ones.
    total_loss += _WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(
            learning_rate, momentum=params['momentum'])
        if params['use_tpu']:
            optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
        elif params['auto_mixed_precision']:
            optimizer = (
                tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimizer))

        # Batch norm requires `update_ops` to be executed alongside
        # `train_op`.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if variable_filter_fn:
            var_list = variable_filter_fn(
                tf.trainable_variables(), params['resnet_depth'])
        else:
            var_list = None

        minimize_op = optimizer.minimize(
            total_loss, global_step, var_list=var_list)
        train_op = tf.group(minimize_op, update_ops)
    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(**kwargs):
            """Returns a dictionary that has the evaluation metrics."""
            batch_size = params['batch_size']
            eval_anchors = anchors.Anchors(
                params['min_level'], params['max_level'],
                params['num_scales'], params['aspect_ratios'],
                params['anchor_scale'], params['image_size'])
            anchor_labeler = anchors.AnchorLabeler(
                eval_anchors, params['num_classes'])
            cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
            box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])
            coco_metrics = coco_metric_fn(
                batch_size, anchor_labeler, params['val_json_file'],
                **kwargs)

            # Add metrics to output.
            output_metrics = {
                'cls_loss': cls_loss,
                'box_loss': box_loss,
            }
            output_metrics.update(coco_metrics)
            return output_metrics

        # The scalar losses are tiled to shape [batch_size, 1] before being
        # handed to `metric_fn`.
        cls_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(cls_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])
        box_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(box_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])
        metric_fn_inputs = {
            'cls_loss_repeat': cls_loss_repeat,
            'box_loss_repeat': box_loss_repeat,
            'source_ids': labels['source_ids'],
            'groundtruth_data': labels['groundtruth_data'],
            'image_scales': labels['image_scales'],
        }
        add_metric_fn_inputs(params, cls_outputs, box_outputs,
                             metric_fn_inputs)
        eval_metrics = (metric_fn, metric_fn_inputs)

    if use_tpu_estimator_spec:
        return contrib_tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)

    # NOTE: on this path neither `eval_metrics` nor `scaffold_fn` is
    # forwarded to the returned spec.
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=total_loss,
        # TODO(rostam): Fix bug to get scaffold working.
        # scaffold=scaffold_fn(),
        train_op=train_op)
def model_fn(features, labels, mode):
    """Definition for ResNet model.

    Builds the network, then the loss plus either the evaluation metrics
    (EVAL) or the training op (TRAIN). PREDICT is explicitly unsupported.

    Args:
      features: the input image tensor. When `FLAGS.transpose_input` is set
        it is transposed with permutation [3, 0, 1, 2] first; afterwards the
        channel axis is last (the normalization constants have shape
        [1, 1, 3] and the ResNet is built with `channels_last`).
      labels: the sparse class labels fed to the softmax cross entropy.
      mode: one of the `tf.estimator.ModeKeys` values.

    Returns:
      A `tf.estimator.EstimatorSpec` with loss and `eval_metric_ops` in EVAL
      mode, or loss and `train_op` in TRAIN mode.

    Raises:
      AssertionError: if `mode` is PREDICT, which is not implemented
        correctly, or (via `assert`) if `mode` is not TRAIN after the EVAL
        and PREDICT cases have been handled.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    if FLAGS.transpose_input:
        # Double-transpose trick
        features = tf.transpose(features, [3, 0, 1, 2])

    # Normalize the image to zero mean and unit variance.
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

    def create_model():
        """Create the model and compute the logits."""
        if FLAGS.use_keras_model:
            model = tf.keras.applications.resnet50.ResNet50(
                include_top=True,
                weights=None,
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=_NUM_CLASSES)
            return model(features, training=is_training)
        model = resnet_model.resnet_v1(
            resnet_depth=_RESNET_DEPTH,
            num_classes=_NUM_CLASSES,
            data_format='channels_last')
        return model(inputs=features, is_training=is_training)

    if FLAGS.precision == 'bfloat16':
        with contrib_tpu.bfloat16_scope():
            logits = create_model()
    else:
        logits = create_model()

    logits = tf.cast(logits, tf.float32)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # An explicit raise instead of `assert False`: asserts are stripped
        # under `python -O`, which would silently return a prediction spec
        # that is known to be wrong. The exception type and message are
        # unchanged for existing callers.
        raise AssertionError('Not implemented correctly right now!')

    cross_entropy = tf.losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits)
    # L2 weight decay over all trainable variables except batch-norm ones.
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    if mode == tf.estimator.ModeKeys.EVAL:
        predictions = tf.argmax(logits, axis=1)
        top_1_accuracy = tf.metrics.accuracy(labels, predictions)
        # TODO(priyag): Add this back when in_top_k is supported on TPU.
        # in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
        # top_5_accuracy = tf.metrics.mean(in_top_5)
        eval_metric_ops = {
            'top_1_accuracy': top_1_accuracy,
            # 'top_5_accuracy': top_5_accuracy,
        }
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=eval_metric_ops)

    assert mode == tf.estimator.ModeKeys.TRAIN

    global_step = tf.train.get_or_create_global_step()
    batches_per_epoch = (
        _NUM_TRAIN_IMAGES / (FLAGS.train_batch_size * FLAGS.num_cores))
    current_epoch = tf.cast(global_step, tf.float32) / batches_per_epoch
    learning_rate = learning_rate_schedule(current_epoch)

    if FLAGS.optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    else:
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=_MOMENTUM,
            use_nesterov=True)

    # Ops in the UPDATE_OPS collection must run together with the
    # optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
    """Model definition for the segmentation model built on ResNet.

    Args:
      features: The input images tensor with shape
        [batch_size, height, width, 3]. The height and width are fixed and
        equal.
      labels: The input labels in a tensor with the same shape as input
        images.
      mode: The mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
      params: The dictionary defines hyperparameters of model. The default
        settings are in default_hparams function in this file.
      model: The FPN segmentation model outputs class logits.
      variable_filter_fn: the filter function that takes trainable_variables
        and returns the variable list after applying the filter rule.

    Returns:
      A `tf.estimator.EstimatorSpec` holding the predictions in PREDICT
      mode; otherwise a `TPUEstimatorSpec` to run training or evaluation.
    """

    def _model_outputs():
        return model(
            features,
            min_level=params['min_level'],
            max_level=params['max_level'],
            num_classes=params['num_classes'],
            resnet_depth=params['resnet_depth'],
            is_training_bn=params['is_training_bn'])

    if params['use_bfloat16']:
        with contrib_tpu.bfloat16_scope():
            cls_outputs = _model_outputs()
            cls_outputs = tf.cast(cls_outputs, tf.float32)
    else:
        cls_outputs = _model_outputs()

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {'image': features, 'cls_outputs': cls_outputs}
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Load pretrained model from checkpoint.
    if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            tf.train.init_from_checkpoint(
                params['resnet_checkpoint'], {
                    '/': 'resnet%s/' % params['resnet_depth'],
                })
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    # Set up training loss and learning rate.
    retinanet_model.update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_global_step()
    learning_rate = retinanet_model.learning_rate_schedule(
        params['adjusted_learning_rate'], params['lr_warmup_init'],
        params['lr_warmup_step'], params['first_lr_drop_step'],
        params['second_lr_drop_step'], global_step)

    cls_loss = _segmentation_loss(cls_outputs, labels, params)
    # L2 weight decay over all trainable variables except batch-norm ones.
    weight_decay_loss = params['weight_decay'] * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])
    # Add L2 regularization loss
    total_loss = cls_loss + weight_decay_loss

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(
            learning_rate, momentum=params['momentum'])
        if params['use_tpu']:
            optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

        # Batch norm requires update_ops to be added as a train_op
        # dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if variable_filter_fn:
            var_list = variable_filter_fn(
                tf.trainable_variables(), params['resnet_depth'])
        else:
            var_list = None

        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(
                total_loss, global_step, var_list=var_list)
    else:
        train_op = None

    # Evaluation only works on GPU/CPU host and batch_size=1
    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        batch_size = params['batch_size']

        def metric_fn(**kwargs):
            """Creates metric_fn for TPUEstimatorSpec."""
            cls_loss_metric = tf.metrics.mean(kwargs['cls_loss_repeat'])
            total_loss_metric = tf.metrics.mean(kwargs['total_loss_repeat'])

            # Resize the logits to the spatial size of the labels, then
            # take the per-pixel argmax over the class axis.
            logits = tf.image.resize_bilinear(
                kwargs['prediction'],
                tf.shape(kwargs['labels'])[1:3],
                align_corners=True)
            predictions_with_shape = tf.argmax(
                logits, 3, output_type=tf.int32)
            flat_predictions = tf.reshape(predictions_with_shape, shape=[-1])

            flat_labels = tf.reshape(kwargs['labels'], shape=[-1])
            # Background class is considered as a class. Not ignored.
            weights = tf.to_float(
                tf.not_equal(flat_labels, params['ignore_label']))

            # Set ignore_label regions to label 0, because metrics.mean_iou
            # requires range of labels = [0, dataset.num_classes).
            # Note the ignore_label regions are not evaluated since the
            # corresponding regions contain weights = 0.
            flat_labels = tf.where(
                tf.equal(flat_labels, params['ignore_label']),
                tf.zeros_like(flat_labels), flat_labels)

            return {
                'total_loss': total_loss_metric,
                'cls_loss': cls_loss_metric,
                # mean_iou's signature is (labels, predictions, ...); pass
                # by keyword so the two tensors cannot be transposed.
                'miou': tf.metrics.mean_iou(
                    labels=flat_labels,
                    predictions=flat_predictions,
                    num_classes=params['num_classes'],
                    weights=weights),
            }

        # The scalar losses are tiled to shape [batch_size, 1] before being
        # handed to `metric_fn`.
        cls_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(cls_loss, 0), [
                batch_size,
            ]), [batch_size, 1])
        total_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(total_loss, 0), [
                batch_size,
            ]), [batch_size, 1])

        metric_fn_inputs = {
            'cls_loss_repeat': cls_loss_repeat,
            'total_loss_repeat': total_loss_repeat,
            'prediction': cls_outputs,
            'labels': labels,
        }
        eval_metrics = (metric_fn, metric_fn_inputs)

    return contrib_tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn,
    )
def model_fn(features, labels, mode, params):
    """Build the Transformer graph and return the estimator spec for `mode`.

    Args:
      features: the model inputs, passed to the Transformer as its inputs.
      labels: the targets; per the prediction note below these are None in
        PREDICT mode.
      mode: one of the `tf.estimator.ModeKeys` values.
      params: hyperparameter dict; this function reads "use_tpu",
        "label_smoothing", "vocab_size" and (TPU training) "model_dir".

    Returns:
      A `tf.estimator.EstimatorSpec`, or a `tf.contrib.tpu.TPUEstimatorSpec`
      for EVAL/TRAIN when `params["use_tpu"]` is set.

    Raises:
      NotImplementedError: in PREDICT mode when `params["use_tpu"]` is set.
    """
    mode_keys = tf.estimator.ModeKeys

    with tf.variable_scope("model"):
        # Instantiate the network and compute the output logits.
        with bfloat16_scope():
            is_training = mode == mode_keys.TRAIN
            network = transformer.Transformer(params, is_training)
            logits = network(features, labels)

        # Prediction: there are no targets, so the logits themselves are
        # the model output.
        if mode == mode_keys.PREDICT:
            if params["use_tpu"]:
                raise NotImplementedError(
                    "Prediction is not yet supported on TPUs.")
            export_outputs = {
                "translate": tf.estimator.export.PredictOutput(logits),
            }
            return tf.estimator.EstimatorSpec(
                mode_keys.PREDICT,
                predictions=logits,
                export_outputs=export_outputs)

        # XLA (TPU) needs a concrete logits shape because the logits travel
        # back to the host CPU for metric evaluation and [?, ?, vocab_size]
        # is too vague. The leading two dimensions of the logits are known
        # to equal those of the targets. The loss itself is unaffected,
        # since padded_cross_entropy_loss resolves the shape on the TPU.
        static_shape = labels.shape.as_list() + logits.shape.as_list()[2:]
        logits.set_shape(static_shape)

        # Per-token cross entropy over the non-padding target tokens,
        # normalized by the total weight.
        token_xent, token_weights = metrics.padded_cross_entropy_loss(
            logits, labels, params["label_smoothing"], params["vocab_size"])
        loss = tf.reduce_sum(token_xent) / tf.reduce_sum(token_weights)

        # Expose the loss under a fixed name for the logging hook.
        tf.identity(loss, "cross_entropy")

        if mode != mode_keys.EVAL:
            train_op, metric_dict = get_train_op_and_metrics(loss, params)

            # Epochs can be long; this surfaces intermediate progress in
            # TensorBoard.
            metric_dict["minibatch_loss"] = loss

            if not params["use_tpu"]:
                record_scalars(metric_dict)
                return tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, train_op=train_op)

            host_call = tpu_util.construct_scalar_host_call(
                metric_dict=metric_dict,
                model_dir=params["model_dir"],
                prefix="training/")
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op,
                host_call=host_call)

        # Evaluation.
        if not params["use_tpu"]:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                predictions={"predictions": logits},
                eval_metric_ops=metrics.get_eval_metrics(
                    logits, labels, params))

        # Host-call functions may only take tensors as arguments, so
        # `params` is captured by closure to keep the metric function
        # TPUEstimator compliant.
        def metric_fn(logits, labels):
            return metrics.get_eval_metrics(logits, labels, params=params)

        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            predictions={"predictions": logits},
            eval_metrics=(metric_fn, [logits, labels]))