コード例 #1
0
  def _save_checkpoint_from_mock_model(self,
                                       checkpoint_path,
                                       use_moving_averages,
                                       quantize=False,
                                       num_channels=3):
    """Builds a FakeModel graph in a fresh tf.Graph and saves a checkpoint.

    Args:
      checkpoint_path: Path passed to `tf.train.Saver.save`.
      use_moving_averages: If true, `ExponentialMovingAverage(0.0).apply()` is
        called in the graph before the saver is created.
      quantize: If true, a non-training graph rewriter configured with a
        quantization delay of 500000 is built and invoked on the graph.
      num_channels: Size of the last dimension of the float32 input
        placeholder, whose shape is [1, 10, 10, num_channels].
    """
    graph = tf.Graph()
    with graph.as_default():
      model = FakeModel()
      image_placeholder = tf.placeholder(
          tf.float32, shape=[1, 10, 10, num_channels])
      model.predict(image_placeholder, true_image_shapes=None)
      if use_moving_averages:
        tf.train.ExponentialMovingAverage(0.0).apply()
      tf.train.get_or_create_global_step()
      if quantize:
        rewriter_config = graph_rewriter_pb2.GraphRewriter()
        rewriter_config.quantization.delay = 500000
        # Build the rewriter and apply it to the default graph right away.
        graph_rewriter_builder.build(rewriter_config, is_training=False)()

      checkpoint_saver = tf.train.Saver()
      init_op = tf.global_variables_initializer()
      with self.test_session() as session:
        session.run(init_op)
        checkpoint_saver.save(session, checkpoint_path)
コード例 #2
0
def main(unused_argv):
    """Runs evaluation of a detection model as configured by FLAGS.

    Loads the configs either from a single pipeline file or from separate
    model/eval/input files, copies the config file(s) into `FLAGS.eval_dir`,
    and then hands everything to `evaluator.evaluate`.

    Args:
      unused_argv: Ignored.
    """
    assert FLAGS.checkpoint_dir, '`checkpoint_dir` is missing.'
    assert FLAGS.eval_dir, '`eval_dir` is missing.'
    tf.gfile.MakeDirs(FLAGS.eval_dir)

    # Pick the config source and remember which files to mirror into eval_dir
    # as (destination file name, source path) pairs.
    if FLAGS.pipeline_config_path:
        configs = config_util.get_configs_from_pipeline_file(
            FLAGS.pipeline_config_path)
        config_copies = [('pipeline.config', FLAGS.pipeline_config_path)]
    else:
        configs = config_util.get_configs_from_multiple_files(
            model_config_path=FLAGS.model_config_path,
            eval_config_path=FLAGS.eval_config_path,
            eval_input_config_path=FLAGS.input_config_path)
        config_copies = [
            ('model.config', FLAGS.model_config_path),
            ('eval.config', FLAGS.eval_config_path),
            ('input.config', FLAGS.input_config_path),
        ]
    for target_name, source_path in config_copies:
        tf.gfile.Copy(source_path,
                      os.path.join(FLAGS.eval_dir, target_name),
                      overwrite=True)

    model_config = configs['model']
    eval_config = configs['eval_config']
    input_config = configs['eval_input_config']
    if FLAGS.eval_training_data:
        # Evaluate on the training input instead of the eval input.
        input_config = configs['train_input_config']

    model_fn = functools.partial(model_builder.build,
                                 model_config=model_config,
                                 is_training=False)

    def _next_input_dict(reader_config):
        dataset = dataset_builder.build(reader_config)
        iterator = dataset_builder.make_initializable_iterator(dataset)
        return iterator.get_next()

    create_input_dict_fn = functools.partial(_next_input_dict, input_config)

    label_map = label_map_util.load_labelmap(input_config.label_map_path)
    max_num_classes = max(entry.id for entry in label_map.item)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes)

    if FLAGS.run_once:
        eval_config.max_evals = 1

    # The graph hook is only built when the configs carry a rewriter section.
    graph_rewriter_fn = None
    if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=False)

    evaluator.evaluate(create_input_dict_fn,
                       model_fn,
                       eval_config,
                       categories,
                       FLAGS.checkpoint_dir,
                       FLAGS.eval_dir,
                       graph_hook_fn=graph_rewriter_fn)
コード例 #3
0
 def testQuantizationBuilderSetsUpCorrectEvalArguments(self):
     """Checks the arguments used by a non-training quantization rewriter.

     Builds a rewriter from a proto with quantization delay 10, runs it with
     `create_eval_graph` and `summarize_collection` mocked out, and asserts
     that the former received the default graph as `input_graph` and the
     latter was called with 'quant_vars'.
     """
     rewriter_proto = graph_rewriter_pb2.GraphRewriter()
     rewriter_proto.quantization.delay = 10
     with mock.patch.object(
             tf.contrib.quantize, 'create_eval_graph') as eval_graph_mock, \
             mock.patch.object(
                 tf.contrib.layers,
                 'summarize_collection') as summarize_mock:
         rewrite_fn = graph_rewriter_builder.build(
             rewriter_proto, is_training=False)
         rewrite_fn()
         # call_args is an (args, kwargs) pair; only the kwargs matter here.
         call_kwargs = eval_graph_mock.call_args[1]
         self.assertEqual(call_kwargs['input_graph'], tf.get_default_graph())
         summarize_mock.assert_called_with('quant_vars')
コード例 #4
0
ファイル: exporter.py プロジェクト: suvratjain1995/SecureIt
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           write_inference_graph=False):
    """Exports inference graph for the model specified in the pipeline config.

    After exporting, `pipeline_config.eval_config.use_moving_averages` is set
    to False on the passed-in proto and the config is saved to
    `output_directory`.

    Args:
      input_type: Type of input for the graph. Can be one of ['image_tensor',
        'encoded_image_string_tensor', 'tf_example'].
      pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto. Modified
        in place (see above).
      trained_checkpoint_prefix: Path to the trained checkpoint file.
      output_directory: Path to write outputs.
      input_shape: Sets a fixed shape for an `image_tensor` input. If not
        specified, will default to [None, None, None, 3].
      output_collection_name: Name of collection to add output tensors to.
        If None, does not add output tensors to a collection.
      additional_output_tensor_names: list of additional output
        tensors to include in the frozen graph.
      write_inference_graph: If true, writes inference graph to disk.
    """
    eval_config = pipeline_config.eval_config
    detection_model = model_builder.build(pipeline_config.model,
                                          is_training=False)

    # Only build a graph hook when the pipeline declares a graph rewriter.
    graph_hook = (
        graph_rewriter_builder.build(pipeline_config.graph_rewriter,
                                     is_training=False)
        if pipeline_config.HasField('graph_rewriter') else None)

    _export_inference_graph(input_type,
                            detection_model,
                            eval_config.use_moving_averages,
                            trained_checkpoint_prefix,
                            output_directory,
                            additional_output_tensor_names,
                            input_shape,
                            output_collection_name,
                            graph_hook_fn=graph_hook,
                            write_inference_graph=write_inference_graph)

    # The saved pipeline config always records use_moving_averages as False.
    eval_config.use_moving_averages = False
    config_util.save_pipeline_config(pipeline_config, output_directory)
コード例 #5
0
def export_tflite_graph(pipeline_config, trained_checkpoint_prefix, output_dir,
                        add_postprocessing_op, max_detections,
                        max_classes_per_detection):
  """Exports a tflite compatible graph and anchors for ssd detection model.

  Anchors are written to a tensor and tflite compatible graph
  is written to output_dir/tflite_graph.pb (binary) and
  output_dir/tflite_graph.pbtxt (text).

  Args:
    pipeline_config: a pipeline.proto object containing the configuration for
      SSD model to export.
    trained_checkpoint_prefix: a file prefix for the checkpoint containing the
      trained parameters of the SSD model.
    output_dir: A directory to write the tflite graph and anchor file to.
    add_postprocessing_op: If add_postprocessing_op is true: frozen graph adds a
      TFLite_Detection_PostProcess custom op
    max_detections: Maximum number of detections (boxes) to show
    max_classes_per_detection: Number of classes to display per detection

  Raises:
    ValueError: if the pipeline config contains a model other than ssd, or if
      the model's image resizer is not a fixed_shape_resizer.
  """
  tf.gfile.MakeDirs(output_dir)
  if pipeline_config.model.WhichOneof('model') != 'ssd':
    raise ValueError('Only ssd models are supported in tflite. '
                     'Found {} in config'.format(
                         pipeline_config.model.WhichOneof('model')))

  num_classes = pipeline_config.model.ssd.num_classes
  # NOTE(review): the thresholds and scale values below are single-element
  # set literals, not scalars. They are passed as-is to
  # append_postprocessing_op, which presumably unwraps them -- confirm against
  # that helper before changing their type.
  nms_score_threshold = {
      pipeline_config.model.ssd.post_processing.batch_non_max_suppression.
      score_threshold
  }
  nms_iou_threshold = {
      pipeline_config.model.ssd.post_processing.batch_non_max_suppression.
      iou_threshold
  }
  scale_values = {}
  scale_values['y_scale'] = {
      pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.y_scale
  }
  scale_values['x_scale'] = {
      pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.x_scale
  }
  scale_values['h_scale'] = {
      pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.height_scale
  }
  scale_values['w_scale'] = {
      pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.width_scale
  }

  image_resizer_config = pipeline_config.model.ssd.image_resizer
  image_resizer = image_resizer_config.WhichOneof('image_resizer_oneof')
  num_channels = _DEFAULT_NUM_CHANNELS
  if image_resizer == 'fixed_shape_resizer':
    height = image_resizer_config.fixed_shape_resizer.height
    width = image_resizer_config.fixed_shape_resizer.width
    if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
      num_channels = 1
    shape = [1, height, width, num_channels]
  else:
    # The trailing space after 'fixed_shape_resizer' keeps the two adjacent
    # literals from running together in the rendered message.
    raise ValueError(
        'Only fixed_shape_resizer '
        'is supported with tflite. Found {}'.format(image_resizer))

  image = tf.placeholder(
      tf.float32, shape=shape, name='normalized_input_image_tensor')

  detection_model = model_builder.build(
      pipeline_config.model, is_training=False)
  predicted_tensors = detection_model.predict(image, true_image_shapes=None)
  # The score conversion occurs before the post-processing custom op
  _, score_conversion_fn = post_processing_builder.build(
      pipeline_config.model.ssd.post_processing)
  class_predictions = score_conversion_fn(
      predicted_tensors['class_predictions_with_background'])

  with tf.name_scope('raw_outputs'):
    # 'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4]
    #  containing the encoded box predictions. Note that these are raw
    #  predictions and no Non-Max suppression is applied on them and
    #  no decode center size boxes is applied to them.
    tf.identity(predicted_tensors['box_encodings'], name='box_encodings')
    # 'raw_outputs/class_predictions': a float32 tensor of shape
    #  [1, num_anchors, num_classes] containing the class scores for each anchor
    #  after applying score conversion.
    tf.identity(class_predictions, name='class_predictions')
  # 'anchors': a float32 tensor of shape
  #   [4, num_anchors] containing the anchors as a constant node.
  tf.identity(
      get_const_center_size_encoded_anchors(predicted_tensors['anchors']),
      name='anchors')

  # Add global step to the graph, so we know the training step number when we
  # evaluate the model.
  tf.train.get_or_create_global_step()

  # graph rewriter
  if pipeline_config.HasField('graph_rewriter'):
    graph_rewriter_config = pipeline_config.graph_rewriter
    graph_rewriter_fn = graph_rewriter_builder.build(
        graph_rewriter_config, is_training=False)
    graph_rewriter_fn()

  # freeze the graph
  saver_kwargs = {}
  if pipeline_config.eval_config.use_moving_averages:
    saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
    # NOTE(review): this temporary file is never closed explicitly, so its
    # cleanup is left to garbage collection; consider closing it once the
    # graph has been frozen.
    moving_average_checkpoint = tempfile.NamedTemporaryFile()
    exporter.replace_variable_values_with_moving_averages(
        tf.get_default_graph(), trained_checkpoint_prefix,
        moving_average_checkpoint.name)
    checkpoint_to_use = moving_average_checkpoint.name
  else:
    checkpoint_to_use = trained_checkpoint_prefix

  saver = tf.train.Saver(**saver_kwargs)
  input_saver_def = saver.as_saver_def()
  frozen_graph_def = exporter.freeze_graph_with_def_protos(
      input_graph_def=tf.get_default_graph().as_graph_def(),
      input_saver_def=input_saver_def,
      input_checkpoint=checkpoint_to_use,
      output_node_names=','.join([
          'raw_outputs/box_encodings', 'raw_outputs/class_predictions',
          'anchors'
      ]),
      restore_op_name='save/restore_all',
      filename_tensor_name='save/Const:0',
      clear_devices=True,
      output_graph='',
      initializer_nodes='')

  # Add new operation to do post processing in a custom op (TF Lite only)
  if add_postprocessing_op:
    transformed_graph_def = append_postprocessing_op(
        frozen_graph_def, max_detections, max_classes_per_detection,
        nms_score_threshold, nms_iou_threshold, num_classes, scale_values)
  else:
    # Return frozen without adding post-processing custom op
    transformed_graph_def = frozen_graph_def

  # Write the same graph twice: serialized binary and human-readable text.
  binary_graph = os.path.join(output_dir, 'tflite_graph.pb')
  with tf.gfile.GFile(binary_graph, 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())
  txt_graph = os.path.join(output_dir, 'tflite_graph.pbtxt')
  with tf.gfile.GFile(txt_graph, 'w') as f:
    f.write(str(transformed_graph_def))
コード例 #6
0
  def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    Relies on names defined outside this function: `detection_model_fn`,
    `use_tpu`, `train_config`, `eval_config`, `eval_input_config`, `hparams`
    and `configs`.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations. A `TPUEstimatorSpec` is returned instead when
        `use_tpu` is set and the mode is not EVAL.
    """
    params = params or {}
    total_loss, train_op, detections, export_outputs = None, None, None, None
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Make sure to set the Keras learning phase. True during training,
    # False for inference.
    tf.keras.backend.set_learning_phase(is_training)
    detection_model = detection_model_fn(is_training=is_training,
                                         add_summaries=(not use_tpu))
    scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
      labels = unstack_batch(
          labels,
          unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
    elif mode == tf.estimator.ModeKeys.EVAL:
      # For evaling on train data, it is necessary to check whether groundtruth
      # must be unpadded.
      boxes_shape = (
          labels[fields.InputDataFields.groundtruth_boxes].get_shape()
          .as_list())
      unpad_groundtruth_tensors = boxes_shape[1] is not None
      labels = unstack_batch(
          labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
      gt_classes_list = labels[fields.InputDataFields.groundtruth_classes]
      gt_masks_list = None
      if fields.InputDataFields.groundtruth_instance_masks in labels:
        gt_masks_list = labels[
            fields.InputDataFields.groundtruth_instance_masks]
      gt_keypoints_list = None
      if fields.InputDataFields.groundtruth_keypoints in labels:
        gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
      # Default to None like the other optional groundtruth fields; otherwise
      # the name would be unbound when the key is missing from `labels`.
      gt_is_crowd_list = None
      if fields.InputDataFields.groundtruth_is_crowd in labels:
        gt_is_crowd_list = labels[fields.InputDataFields.groundtruth_is_crowd]
      detection_model.provide_groundtruth(
          groundtruth_boxes_list=gt_boxes_list,
          groundtruth_classes_list=gt_classes_list,
          groundtruth_masks_list=gt_masks_list,
          groundtruth_keypoints_list=gt_keypoints_list,
          groundtruth_weights_list=labels[
              fields.InputDataFields.groundtruth_weights],
          groundtruth_is_crowd_list=gt_is_crowd_list)

    preprocessed_images = features[fields.InputDataFields.image]
    # NOTE(review): confirm the detection model exposes `observe`; verify
    # against the model interface returned by `detection_model_fn`.
    prediction_dict = detection_model.observe(
        preprocessed_images, features[fields.InputDataFields.true_image_shape])
    if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
      detections = detection_model.postprocess(
          prediction_dict, features[fields.InputDataFields.true_image_shape])

    if mode == tf.estimator.ModeKeys.TRAIN:
      if train_config.fine_tune_checkpoint and hparams.load_pretrained:
        if not train_config.fine_tune_checkpoint_type:
          # train_config.from_detection_checkpoint field is deprecated. For
          # backward compatibility, set train_config.fine_tune_checkpoint_type
          # based on train_config.from_detection_checkpoint.
          if train_config.from_detection_checkpoint:
            train_config.fine_tune_checkpoint_type = 'detection'
          else:
            train_config.fine_tune_checkpoint_type = 'classification'
        asg_map = detection_model.restore_map(
            fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
            load_all_detection_checkpoint_vars=(
                train_config.load_all_detection_checkpoint_vars))
        available_var_map = (
            variables_helper.get_variables_available_in_checkpoint(
                asg_map, train_config.fine_tune_checkpoint,
                include_global_step=False))
        if use_tpu:
          # On TPU the checkpoint initialization is deferred into the
          # scaffold function handed to TPUEstimatorSpec.
          def tpu_scaffold():
            tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                          available_var_map)
            return tf.train.Scaffold()
          scaffold_fn = tpu_scaffold
        else:
          tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                        available_var_map)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      losses_dict = detection_model.loss(
          prediction_dict, features[fields.InputDataFields.true_image_shape])
      losses = list(losses_dict.values())
      if train_config.add_regularization_loss:
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
        if regularization_losses:
          regularization_loss = tf.add_n(regularization_losses,
                                         name='regularization_loss')
          losses.append(regularization_loss)
          losses_dict['Loss/regularization_loss'] = regularization_loss
      total_loss = tf.add_n(losses, name='total_loss')
      losses_dict['Loss/total_loss'] = total_loss

      if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=is_training)
        graph_rewriter_fn()

      # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
      # can write learning rate summaries on TPU without host calls.
      global_step = tf.train.get_or_create_global_step()
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
            training_optimizer)

      # Optionally freeze some layers by setting their gradients to be zero.
      trainable_variables = None
      if train_config.freeze_variables:
        trainable_variables = tf.contrib.framework.filter_variables(
            tf.trainable_variables(),
            exclude_patterns=train_config.freeze_variables)

      clip_gradients_value = None
      if train_config.gradient_clipping_by_norm > 0:
        clip_gradients_value = train_config.gradient_clipping_by_norm

      if not use_tpu:
        for var in optimizer_summary_vars:
          tf.summary.scalar(var.op.name, var)
      summaries = [] if use_tpu else None
      train_op = tf.contrib.layers.optimize_loss(
          loss=total_loss,
          global_step=global_step,
          learning_rate=None,
          clip_gradients=clip_gradients_value,
          optimizer=training_optimizer,
          variables=trainable_variables,
          summaries=summaries,
          name='')  # Preventing scope prefix on all variables.

    if mode == tf.estimator.ModeKeys.PREDICT:
      export_outputs = {
          tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
              tf.estimator.export.PredictOutput(detections)
      }

    eval_metric_ops = None
    scaffold = None
    if mode == tf.estimator.ModeKeys.EVAL:
      class_agnostic = (fields.DetectionResultFields.detection_classes
                        not in detections)
      groundtruth = _prepare_groundtruth_for_eval(
          detection_model, class_agnostic)
      use_original_images = fields.InputDataFields.original_image in features
      eval_images = (
          features[fields.InputDataFields.original_image] if use_original_images
          else features[fields.InputDataFields.image])
      eval_dict = eval_util.result_dict_for_single_example(
          eval_images[0:1],
          features[inputs.HASH_KEY][0],
          detections,
          groundtruth,
          class_agnostic=class_agnostic,
          scale_to_absolute=True)

      if class_agnostic:
        category_index = label_map_util.create_class_agnostic_category_index()
      else:
        category_index = label_map_util.create_category_index_from_labelmap(
            eval_input_config.label_map_path)
      img_summary = None
      if not use_tpu and use_original_images:
        detection_and_groundtruth = (
            vis_utils.draw_side_by_side_evaluation_image(
                eval_dict, category_index, max_boxes_to_draw=20,
                min_score_thresh=0.2,
                use_normalized_coordinates=False))
        img_summary = tf.summary.image('Detections_Left_Groundtruth_Right',
                                       detection_and_groundtruth)

      # Eval metrics on a single example.
      eval_metrics = eval_config.metrics_set
      if not eval_metrics:
        eval_metrics = ['coco_detection_metrics']
      eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
          eval_metrics,
          category_index.values(),
          eval_dict,
          include_metrics_per_category=eval_config.include_metrics_per_category)
      for loss_key, loss_tensor in losses_dict.items():
        eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
      for var in optimizer_summary_vars:
        eval_metric_ops[var.op.name] = (var, tf.no_op())
      if img_summary is not None:
        eval_metric_ops['Detections_Left_Groundtruth_Right'] = (
            img_summary, tf.no_op())
      # Normalize every metric key to `str`. `.items()` is used because it
      # exists on both Python 2 and Python 3 dicts, unlike `.iteritems()`.
      eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

      if eval_config.use_moving_averages:
        variable_averages = tf.train.ExponentialMovingAverage(0.0)
        variables_to_restore = variable_averages.variables_to_restore()
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            variables_to_restore,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
        scaffold = tf.train.Scaffold(saver=saver)

    # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
    if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          scaffold_fn=scaffold_fn,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metrics=eval_metric_ops,
          export_outputs=export_outputs)
    else:
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs,
          scaffold=scaffold)
コード例 #7
0
def main(_):
    """Runs training of a detection model as configured by FLAGS and TF_CONFIG.

    Loads the configs from a single pipeline file or from separate
    model/train/input files, mirrors the config file(s) into `FLAGS.train_dir`
    when `FLAGS.task` is 0, reads the cluster layout from the `TF_CONFIG`
    environment variable, and finally calls `trainer.train`.

    Raises:
      ValueError: if more than one worker replica is configured without any
        'ps' task.
    """
    assert FLAGS.train_dir, '`train_dir` is missing.'
    is_first_task = FLAGS.task == 0
    if is_first_task:
        tf.gfile.MakeDirs(FLAGS.train_dir)

    # Load the configs and note which files to copy as
    # (destination file name, source path) pairs.
    if FLAGS.pipeline_config_path:
        configs = config_util.get_configs_from_pipeline_file(
            FLAGS.pipeline_config_path)
        config_copies = [('pipeline.config', FLAGS.pipeline_config_path)]
    else:
        configs = config_util.get_configs_from_multiple_files(
            model_config_path=FLAGS.model_config_path,
            train_config_path=FLAGS.train_config_path,
            train_input_config_path=FLAGS.input_config_path)
        config_copies = [
            ('model.config', FLAGS.model_config_path),
            ('train.config', FLAGS.train_config_path),
            ('input.config', FLAGS.input_config_path),
        ]
    if is_first_task:
        for target_name, source_path in config_copies:
            tf.gfile.Copy(source_path,
                          os.path.join(FLAGS.train_dir, target_name),
                          overwrite=True)

    model_config = configs['model']
    train_config = configs['train_config']
    input_config = configs['train_input_config']

    model_fn = functools.partial(model_builder.build,
                                 model_config=model_config,
                                 is_training=True)

    def _next_input_dict(reader_config):
        dataset = dataset_builder.build(reader_config)
        iterator = dataset_builder.make_initializable_iterator(dataset)
        return iterator.get_next()

    create_input_dict_fn = functools.partial(_next_input_dict, input_config)

    env = json.loads(os.environ.get('TF_CONFIG', '{}'))
    cluster_data = env.get('cluster', None)
    cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
    task_data = env.get('task', None) or {'type': 'master', 'index': 0}
    # Turn the task dict into an object so its entries read as attributes.
    task_info = type('TaskSpec', (object, ), task_data)

    # Parameters for a single worker.
    ps_tasks = 0
    worker_replicas = 1
    worker_job_name = 'lonely_worker'
    task = 0
    is_chief = True
    master = ''

    if cluster_data:
        if 'worker' in cluster_data:
            # Number of total worker replicas include "worker"s and the
            # "master".
            worker_replicas = len(cluster_data['worker']) + 1
        if 'ps' in cluster_data:
            ps_tasks = len(cluster_data['ps'])

    if worker_replicas > 1 and ps_tasks < 1:
        raise ValueError(
            'At least 1 ps task is needed for distributed training.')

    if worker_replicas >= 1 and ps_tasks > 0:
        # Set up distributed training.
        server = tf.train.Server(tf.train.ClusterSpec(cluster),
                                 protocol='grpc',
                                 job_name=task_info.type,
                                 task_index=task_info.index)
        if task_info.type == 'ps':
            # Parameter servers only serve; they never reach the trainer.
            server.join()
            return

        worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
        task = task_info.index
        is_chief = (task_info.type == 'master')
        master = server.target

    # The graph hook is only built when the configs carry a rewriter section.
    graph_rewriter_fn = None
    if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=True)

    trainer.train(create_input_dict_fn,
                  model_fn,
                  train_config,
                  master,
                  task,
                  FLAGS.num_clones,
                  worker_replicas,
                  FLAGS.clone_on_cpu,
                  ps_tasks,
                  worker_job_name,
                  is_chief,
                  FLAGS.train_dir,
                  graph_hook_fn=graph_rewriter_fn)