def testBuildEmptyOptimizer(self):
  """Checks that building from an Optimizer proto with no fields raises."""
  empty_config = optimizer_pb2.Optimizer()
  # The text proto is blank, so no optimizer field is populated by the merge.
  text_format.Merge(""" """, empty_config)
  with self.assertRaises(ValueError):
    optimizer_builder.build(empty_config)
def testBuildAdamOptimizer(self):
  """Checks that an adam_optimizer config builds a tf.train.AdamOptimizer."""
  optimizer_text_proto = """ adam_optimizer: { learning_rate: { constant_learning_rate { learning_rate: 0.002 } } } use_moving_average: false """
  optimizer_proto = optimizer_pb2.Optimizer()
  text_format.Merge(optimizer_text_proto, optimizer_proto)
  # build() returns a pair; only the optimizer itself is checked here.
  optimizer, _ = optimizer_builder.build(optimizer_proto)
  # assertIsInstance reports the actual type on failure, whereas
  # assertTrue(isinstance(...)) only reports "False is not true".
  self.assertIsInstance(optimizer, tf.train.AdamOptimizer)
def testBuildMomentumOptimizer(self):
  """Checks that a momentum_optimizer config builds a MomentumOptimizer."""
  optimizer_text_proto = """ momentum_optimizer: { learning_rate: { constant_learning_rate { learning_rate: 0.001 } } momentum_optimizer_value: 0.99 } use_moving_average: false """
  optimizer_proto = optimizer_pb2.Optimizer()
  text_format.Merge(optimizer_text_proto, optimizer_proto)
  # build() returns a pair; only the optimizer itself is checked here.
  optimizer, _ = optimizer_builder.build(optimizer_proto)
  # assertIsInstance reports the actual type on failure, whereas
  # assertTrue(isinstance(...)) only reports "False is not true".
  self.assertIsInstance(optimizer, tf.train.MomentumOptimizer)
def testBuildMovingAverageOptimizerWithNonDefaultDecay(self):
  """Checks the moving-average wrapper and its configured decay value.

  With use_moving_average enabled and moving_average_decay set to 0.2, the
  built optimizer must be a tf.contrib.opt.MovingAverageOptimizer whose decay
  equals the configured value.
  """
  optimizer_text_proto = """ adam_optimizer: { learning_rate: { constant_learning_rate { learning_rate: 0.002 } } } use_moving_average: True moving_average_decay: 0.2 """
  optimizer_proto = optimizer_pb2.Optimizer()
  text_format.Merge(optimizer_text_proto, optimizer_proto)
  # build() returns a pair; only the optimizer itself is checked here.
  optimizer, _ = optimizer_builder.build(optimizer_proto)
  # assertIsInstance reports the actual type on failure, whereas
  # assertTrue(isinstance(...)) only reports "False is not true".
  self.assertIsInstance(optimizer, tf.contrib.opt.MovingAverageOptimizer)
  # TODO(rathodv): Find a way to not depend on the private members.
  self.assertAlmostEqual(optimizer._ema._decay, 0.2)
def testBuildRMSPropOptimizer(self):
  """Checks that an rms_prop_optimizer config builds an RMSPropOptimizer."""
  optimizer_text_proto = """ rms_prop_optimizer: { learning_rate: { exponential_decay_learning_rate { initial_learning_rate: 0.004 decay_steps: 800720 decay_factor: 0.95 } } momentum_optimizer_value: 0.9 decay: 0.9 epsilon: 1.0 } use_moving_average: false """
  optimizer_proto = optimizer_pb2.Optimizer()
  text_format.Merge(optimizer_text_proto, optimizer_proto)
  # build() returns a pair; only the optimizer itself is checked here.
  optimizer, _ = optimizer_builder.build(optimizer_proto)
  # assertIsInstance reports the actual type on failure, whereas
  # assertTrue(isinstance(...)) only reports "False is not true".
  self.assertIsInstance(optimizer, tf.train.RMSPropOptimizer)
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir, graph_hook_fn=None):
  """Training function for detection models.

  Builds the training graph (input queue, model clones, optimizer, gradient
  post-processing, summaries, checkpoint restore) and hands the resulting
  train tensor to slim.learning.train.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
      losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the inference graph
      is built (before optimization). This is helpful to perform additional
      changes to the training graph such as adding FakeQuant ops. The function
      should modify the default graph.

  Raises:
    ValueError: If num_clones is not 1 while train_config.sync_replicas is
      true.
  """
  # This instance is only used below to obtain the checkpoint restore map; the
  # clones build their own models through create_model_fn.
  detection_model = create_model_fn()

  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options
  ]

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    if num_clones != 1 and train_config.sync_replicas:
      # The two literals are concatenated (no comma between them) so that the
      # exception carries one message string rather than a 2-tuple of args.
      raise ValueError('In Synchronous SGD mode num_clones must '
                       'be 1. Found num_clones: {}'.format(num_clones))

    # The configured batch size is split across clones, and additionally
    # across the aggregated replicas when training synchronously.
    batch_size = train_config.batch_size // num_clones
    if train_config.sync_replicas:
      batch_size //= train_config.replicas_to_aggregate

    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(
          batch_size, create_tensor_dict_fn,
          train_config.batch_queue_capacity,
          train_config.num_batch_queue_threads,
          train_config.prefetch_queue_capacity, data_augmentation_options)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set()

    model_fn = functools.partial(_create_losses,
                                 create_model_fn=create_model_fn,
                                 train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        tf.summary.scalar(var.op.name, var, family='LearningRate')

    # sync_optimizer stays None unless synchronous training is requested; it
    # is forwarded to slim.learning.train at the end.
    sync_optimizer = None
    if train_config.sync_replicas:
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=worker_replicas)
      sync_optimizer = training_optimizer

    with tf.device(deploy_config.optimizer_device()):
      # NOTE(review): presumably None lets optimize_clones collect the
      # regularization losses itself while [] contributes none -- confirm
      # against model_deploy.optimize_clones.
      regularization_losses = (None if train_config.add_regularization_loss
                               else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, training_optimizer,
          regularization_losses=regularization_losses)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally freeze some layers by setting their gradients to be zero.
      if train_config.freeze_variables:
        grads_and_vars = variables_helper.freeze_gradients_matching_regex(
            grads_and_vars, train_config.freeze_variables)

      # Optionally clip gradients
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Create gradient updates.
      grad_updates = training_optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)

      # Evaluating train_tensor forces every update op (including the gradient
      # application) to run first.
      update_op = tf.group(*update_ops, name='update_barrier')
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(
          tf.summary.histogram('ModelVars/' + model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(
          tf.summary.scalar('Losses/' + loss_tensor.op.name, loss_tensor))
    global_summaries.add(
        tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      if not train_config.fine_tune_checkpoint_type:
        # train_config.from_detection_checkpoint field is deprecated. For
        # backward compatibility, fine_tune_checkpoint_type is set based on
        # from_detection_checkpoint.
        if train_config.from_detection_checkpoint:
          train_config.fine_tune_checkpoint_type = 'detection'
        else:
          train_config.fine_tune_checkpoint_type = 'classification'
      var_map = detection_model.restore_map(
          fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
          load_all_detection_checkpoint_vars=(
              train_config.load_all_detection_checkpoint_vars))
      available_var_map = (
          variables_helper.get_variables_available_in_checkpoint(
              var_map, train_config.fine_tune_checkpoint,
              include_global_step=False))

      init_saver = tf.train.Saver(available_var_map)

      def initializer_fn(sess):
        init_saver.restore(sess, train_config.fine_tune_checkpoint)

      init_fn = initializer_fn

    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        # A num_steps of 0 is forwarded as None rather than as a step count.
        number_of_steps=(train_config.num_steps
                         if train_config.num_steps else None),
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)
def train_loop(
    hparams,
    pipeline_config_path,
    model_dir,
    config_override=None,
    train_steps=None,
    use_tpu=False,
    save_final_config=False,
    export_to_tpu=None,
    checkpoint_every_n=1000, **kwargs):
  """Trains a model using eager + functions.

  This method:
    1. Processes the pipeline configs
    2. (Optionally) saves the as-run config
    3. Builds the model & optimizer
    4. Gets the training input data
    5. Loads a fine-tuning detection or classification checkpoint if requested
    6. Loops over the train data, executing distributed training steps inside
       tf.functions.
    7. Checkpoints the model every `checkpoint_every_n` training steps.
    8. Logs the training metrics as TensorBoard summaries.

  Args:
    hparams: A `HParams`.
    pipeline_config_path: A path to a pipeline config file.
    model_dir:
      The directory to save checkpoints and summaries to.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override the config from `pipeline_config_path`.
    train_steps: Number of training steps. If None, the number of training
      steps is set from the `TrainConfig` proto.
    use_tpu: Boolean, whether training and evaluation should run on TPU.
    save_final_config: Whether to save final config (obtained after applying
      overrides) to `model_dir`.
    export_to_tpu: When use_tpu and export_to_tpu are true,
      `export_savedmodel()` exports a metagraph for serving on TPU besides the
      one on CPU. If export_to_tpu is not provided, we will look for it in
      hparams too.
    checkpoint_every_n:
      Checkpoint every n training steps.
    **kwargs: Additional keyword arguments for configuration override.

  Raises:
    ValueError: If `train_steps` is None and `train_config.num_steps` is 0, so
      that no step count is available for the training loop.
  """
  ## Parse the configs
  get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
      'get_configs_from_pipeline_file']
  merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
      'merge_external_params_with_configs']
  create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
      'create_pipeline_proto_from_configs']

  configs = get_configs_from_pipeline_file(
      pipeline_config_path, config_override=config_override)
  kwargs.update({
      'train_steps': train_steps,
      'use_bfloat16': configs['train_config'].use_bfloat16 and use_tpu
  })
  configs = merge_external_params_with_configs(
      configs, hparams, kwargs_dict=kwargs)
  model_config = configs['model']
  train_config = configs['train_config']
  train_input_config = configs['train_input_config']

  unpad_groundtruth_tensors = train_config.unpad_groundtruth_tensors
  use_bfloat16 = train_config.use_bfloat16
  add_regularization_loss = train_config.add_regularization_loss
  clip_gradients_value = None
  if train_config.gradient_clipping_by_norm > 0:
    clip_gradients_value = train_config.gradient_clipping_by_norm

  # update train_steps from config but only when non-zero value is provided
  if train_steps is None and train_config.num_steps != 0:
    train_steps = train_config.num_steps

  # The loop below iterates range(train_steps); fail here with a clear message
  # instead of an opaque TypeError after the model has already been built.
  if train_steps is None:
    raise ValueError(
        'train_steps was not provided and train_config.num_steps is 0; '
        'cannot determine how many training steps to run.')

  # Read export_to_tpu from hparams if not passed.
  if export_to_tpu is None:
    export_to_tpu = hparams.get('export_to_tpu', False)
  tf.logging.info(
      'train_loop: use_tpu %s, export_to_tpu %s', use_tpu,
      export_to_tpu)

  # Parse the checkpoint fine tuning configs
  if hparams.load_pretrained:
    fine_tune_checkpoint_path = train_config.fine_tune_checkpoint
  else:
    fine_tune_checkpoint_path = None
  load_all_detection_checkpoint_vars = (
      train_config.load_all_detection_checkpoint_vars)
  # TODO(kaftan) (or anyone else): move this piece of config munging to
  ## utils/config_util.py
  if not train_config.fine_tune_checkpoint_type:
    # train_config.from_detection_checkpoint field is deprecated. For
    # backward compatibility, set train_config.fine_tune_checkpoint_type
    # based on train_config.from_detection_checkpoint.
    if train_config.from_detection_checkpoint:
      train_config.fine_tune_checkpoint_type = 'detection'
    else:
      train_config.fine_tune_checkpoint_type = 'classification'
  fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type

  # Write the as-run pipeline config to disk.
  if save_final_config:
    pipeline_config_final = create_pipeline_proto_from_configs(configs)
    config_util.save_pipeline_config(pipeline_config_final, model_dir)

  # TODO(kaftan): Either make strategy a parameter of this method, or
  ## grab it w/ Distribution strategy's get_scope
  # Build the model, optimizer, and training input
  strategy = tf.compat.v2.distribute.MirroredStrategy()
  with strategy.scope():
    detection_model = model_builder.build(
        model_config=model_config, is_training=True)

    # Create the inputs.
    train_input = inputs.train_input(
        train_config=train_config,
        train_input_config=train_input_config,
        model_config=model_config,
        model=detection_model)

    # repeat() is applied before distributing so the step loop below is
    # bounded by train_steps rather than by the dataset.
    train_input = strategy.experimental_distribute_dataset(
        train_input.repeat())

    global_step = tf.compat.v2.Variable(
        0, trainable=False, dtype=tf.compat.v2.dtypes.int64)
    optimizer, (learning_rate,) = optimizer_builder.build(
        train_config.optimizer, global_step=global_step)

    # Normalise the learning rate to a zero-argument callable so the train
    # step can always call it.
    if callable(learning_rate):
      learning_rate_fn = learning_rate
    else:
      learning_rate_fn = lambda: learning_rate

  ## Train the model
  summary_writer = tf.compat.v2.summary.create_file_writer(model_dir + '/train')
  with summary_writer.as_default():
    with strategy.scope():
      # Load a fine-tuning checkpoint.
      if fine_tune_checkpoint_path:
        load_fine_tune_checkpoint(detection_model, fine_tune_checkpoint_path,
                                  fine_tune_checkpoint_type,
                                  load_all_detection_checkpoint_vars,
                                  train_input,
                                  unpad_groundtruth_tensors, use_tpu,
                                  use_bfloat16)

      ckpt = tf.compat.v2.train.Checkpoint(
          step=global_step, model=detection_model)
      manager = tf.compat.v2.train.CheckpointManager(
          ckpt, model_dir, max_to_keep=7)
      ## Maybe re-enable checkpoint restoration depending on how it works:
      # ckpt.restore(manager.latest_checkpoint)

      def train_step_fn(features, labels):
        return eager_train_step(
            detection_model,
            features,
            labels,
            unpad_groundtruth_tensors,
            optimizer,
            learning_rate=learning_rate_fn(),
            use_bfloat16=use_bfloat16,
            add_regularization_loss=add_regularization_loss,
            clip_gradients_value=clip_gradients_value,
            use_tpu=use_tpu,
            global_step=global_step,
            num_replicas=strategy.num_replicas_in_sync)

      @tf.function
      def _dist_train_step(data_iterator):
        """A distributed train step."""
        features, labels = data_iterator.next()
        per_replica_losses = strategy.experimental_run_v2(
            train_step_fn, args=(features, labels,))
        # TODO(anjalisridhar): explore if it is safe to remove the
        ## num_replicas scaling of the loss and switch this to a ReduceOp.Mean
        mean_loss = strategy.reduce(
            tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        return mean_loss

      train_input_iter = iter(train_input)
      for _ in range(train_steps):
        start_time = time.time()

        loss = _dist_train_step(train_input_iter)
        global_step.assign_add(1)
        end_time = time.time()
        tf.compat.v2.summary.scalar(
            'steps_per_sec', 1.0 / (end_time - start_time),
            step=global_step)
        # TODO(kaftan): Remove this print after it is no longer helpful for
        ## debugging.
        tf.print('Finished step', global_step, end_time, loss)
        if int(global_step.value().numpy()) % checkpoint_every_n == 0:
          manager.save()
def model_fn(features, labels, mode, params=None):
  """Constructs the object detection model.

  Depending on `mode`, this builds predictions, losses, the training op,
  evaluation metrics and export outputs, and packs them into an estimator
  spec. Configuration (train/eval configs, TPU flags, hparams) is read from
  the enclosing scope.

  Args:
    features: Dictionary of feature tensors, returned from `input_fn`.
    labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
      otherwise None.
    mode: Mode key from tf.estimator.ModeKeys.
    params: Parameter dictionary passed from the estimator.

  Returns:
    An `EstimatorSpec` that encapsulates the model and its serving
      configurations. A `TPUEstimatorSpec` is returned instead when running
      on TPU in a mode other than EVAL.
  """
  params = params or {}
  mode_keys = tf.estimator.ModeKeys
  total_loss = None
  train_op = None
  detections = None
  export_outputs = None
  is_training = mode == mode_keys.TRAIN

  # Make sure to set the Keras learning phase. True during training,
  # False for inference.
  tf.keras.backend.set_learning_phase(is_training)
  detection_model = detection_model_fn(is_training=is_training,
                                       add_summaries=(not use_tpu))
  scaffold_fn = None

  if mode == mode_keys.TRAIN:
    labels = unstack_batch(
        labels,
        unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
  elif mode == mode_keys.EVAL:
    # For evaling on train data, it is necessary to check whether groundtruth
    # must be unpadded.
    groundtruth_boxes = labels[fields.InputDataFields.groundtruth_boxes]
    boxes_shape = groundtruth_boxes.get_shape().as_list()
    unpad_groundtruth = boxes_shape[1] is not None and not use_tpu
    labels = unstack_batch(
        labels, unpad_groundtruth_tensors=unpad_groundtruth)

  if mode in (mode_keys.TRAIN, mode_keys.EVAL):
    provide_groundtruth(detection_model, labels)

  preprocessed_images = features[fields.InputDataFields.image]
  # Looked up once here; it is needed for prediction, postprocessing, loss
  # computation and evaluation below.
  true_image_shape = features[fields.InputDataFields.true_image_shape]
  if use_tpu and train_config.use_bfloat16:
    with tf.contrib.tpu.bfloat16_scope():
      prediction_dict = detection_model.predict(preprocessed_images,
                                                true_image_shape)
      prediction_dict = ops.bfloat16_to_float32_nested(prediction_dict)
  else:
    prediction_dict = detection_model.predict(preprocessed_images,
                                              true_image_shape)

  def postprocess_wrapper(args):
    # Takes a single tuple argument so it can be handed to
    # outside_compilation as well as called directly.
    return detection_model.postprocess(args[0], args[1])

  if mode in (mode_keys.EVAL, mode_keys.PREDICT):
    postprocess_args = (prediction_dict, true_image_shape)
    if use_tpu and postprocess_on_cpu:
      detections = tf.contrib.tpu.outside_compilation(
          postprocess_wrapper, postprocess_args)
    else:
      detections = postprocess_wrapper(postprocess_args)

  if mode == mode_keys.TRAIN:
    if train_config.fine_tune_checkpoint and hparams.load_pretrained:
      if not train_config.fine_tune_checkpoint_type:
        # train_config.from_detection_checkpoint field is deprecated. For
        # backward compatibility, set train_config.fine_tune_checkpoint_type
        # based on train_config.from_detection_checkpoint.
        if train_config.from_detection_checkpoint:
          train_config.fine_tune_checkpoint_type = 'detection'
        else:
          train_config.fine_tune_checkpoint_type = 'classification'
      asg_map = detection_model.restore_map(
          fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
          load_all_detection_checkpoint_vars=(
              train_config.load_all_detection_checkpoint_vars))
      available_var_map = (
          variables_helper.get_variables_available_in_checkpoint(
              asg_map,
              train_config.fine_tune_checkpoint,
              include_global_step=False))
      if use_tpu:

        # On TPU the checkpoint initialisation is deferred into a scaffold
        # function that is passed to the TPUEstimatorSpec.
        def tpu_scaffold():
          tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                        available_var_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                      available_var_map)

  if mode in (mode_keys.TRAIN, mode_keys.EVAL):
    losses_dict = detection_model.loss(prediction_dict, true_image_shape)
    losses = list(losses_dict.values())
    if train_config.add_regularization_loss:
      regularization_losses = detection_model.regularization_losses()
      if use_tpu and train_config.use_bfloat16:
        regularization_losses = ops.bfloat16_to_float32_nested(
            regularization_losses)
      if regularization_losses:
        regularization_loss = tf.add_n(regularization_losses,
                                       name='regularization_loss')
        losses.append(regularization_loss)
        losses_dict['Loss/regularization_loss'] = regularization_loss
    total_loss = tf.add_n(losses, name='total_loss')
    losses_dict['Loss/total_loss'] = total_loss

    if 'graph_rewriter_config' in configs:
      graph_rewriter_fn = graph_rewriter_builder.build(
          configs['graph_rewriter_config'], is_training=is_training)
      graph_rewriter_fn()

    # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
    # can write learning rate summaries on TPU without host calls.
    global_step = tf.train.get_or_create_global_step()
    training_optimizer, optimizer_summary_vars = optimizer_builder.build(
        train_config.optimizer)

  if mode == mode_keys.TRAIN:
    if use_tpu:
      training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
          training_optimizer)

    # Optionally freeze some layers by setting their gradients to be zero.
    # Empty pattern lists are passed as None.
    include_variables = train_config.update_trainable_variables or None
    exclude_variables = train_config.freeze_variables or None
    trainable_variables = tf.contrib.framework.filter_variables(
        tf.trainable_variables(),
        include_patterns=include_variables,
        exclude_patterns=exclude_variables)

    # Clipping is only enabled for a strictly positive configured norm.
    clip_gradients_value = (
        train_config.gradient_clipping_by_norm
        if train_config.gradient_clipping_by_norm > 0 else None)

    if not use_tpu:
      for summary_var in optimizer_summary_vars:
        tf.summary.scalar(summary_var.op.name, summary_var)
    if train_config.summarize_gradients:
      summaries = ['gradients', 'gradient_norm', 'global_gradient_norm']
    elif use_tpu:
      summaries = []
    else:
      summaries = None
    train_op = tf.contrib.layers.optimize_loss(
        loss=total_loss,
        global_step=global_step,
        learning_rate=None,
        clip_gradients=clip_gradients_value,
        optimizer=training_optimizer,
        update_ops=detection_model.updates(),
        variables=trainable_variables,
        summaries=summaries,
        name='')  # Preventing scope prefix on all variables.

  if mode == mode_keys.PREDICT:
    exported_output = exporter_lib.add_output_tensor_nodes(detections)
    export_outputs = {
        tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
            tf.estimator.export.PredictOutput(exported_output)
    }

  eval_metric_ops = None
  scaffold = None
  if mode == mode_keys.EVAL:
    class_agnostic = (
        fields.DetectionResultFields.detection_classes not in detections)
    groundtruth = _prepare_groundtruth_for_eval(
        detection_model, class_agnostic,
        eval_input_config.max_number_of_boxes)
    use_original_images = fields.InputDataFields.original_image in features
    if use_original_images:
      eval_images = features[fields.InputDataFields.original_image]
      true_image_shapes = tf.slice(true_image_shape, [0, 0], [-1, 3])
      original_image_spatial_shapes = features[
          fields.InputDataFields.original_image_spatial_shape]
    else:
      eval_images = features[fields.InputDataFields.image]
      true_image_shapes = None
      original_image_spatial_shapes = None

    eval_dict = eval_util.result_dict_for_batched_example(
        eval_images,
        features[inputs.HASH_KEY],
        detections,
        groundtruth,
        class_agnostic=class_agnostic,
        scale_to_absolute=True,
        original_image_spatial_shapes=original_image_spatial_shapes,
        true_image_shapes=true_image_shapes)

    if class_agnostic:
      category_index = label_map_util.create_class_agnostic_category_index()
    else:
      category_index = label_map_util.create_category_index_from_labelmap(
          eval_input_config.label_map_path)
    vis_metric_ops = None
    if not use_tpu and use_original_images:
      eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
          category_index,
          max_examples_to_draw=eval_config.num_visualizations,
          max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
          min_score_thresh=eval_config.min_score_threshold,
          use_normalized_coordinates=False)
      vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
          eval_dict)

    # Eval metrics on a single example.
    eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
        eval_config, list(category_index.values()), eval_dict)
    for loss_key, loss_tensor in losses_dict.items():
      eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
    for summary_var in optimizer_summary_vars:
      # Paired with a no-op as the update op.
      eval_metric_ops[summary_var.op.name] = (summary_var, tf.no_op())
    if vis_metric_ops is not None:
      eval_metric_ops.update(vis_metric_ops)
    # Metric names are coerced to str keys.
    eval_metric_ops = {
        str(metric_name): metric_op
        for metric_name, metric_op in eval_metric_ops.items()
    }

    if eval_config.use_moving_averages:
      variable_averages = tf.train.ExponentialMovingAverage(0.0)
      variables_to_restore = variable_averages.variables_to_restore()
      saver = tf.train.Saver(
          variables_to_restore,
          keep_checkpoint_every_n_hours=(
              train_config.keep_checkpoint_every_n_hours))
      scaffold = tf.train.Scaffold(saver=saver)

  # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
  if use_tpu and mode != mode_keys.EVAL:
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        scaffold_fn=scaffold_fn,
        predictions=detections,
        loss=total_loss,
        train_op=train_op,
        eval_metrics=eval_metric_ops,
        export_outputs=export_outputs)

  if scaffold is None:
    saver = tf.train.Saver(
        sharded=True,
        keep_checkpoint_every_n_hours=(
            train_config.keep_checkpoint_every_n_hours),
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    scaffold = tf.train.Scaffold(saver=saver)
  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=detections,
      loss=total_loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
      export_outputs=export_outputs,
      scaffold=scaffold)