コード例 #1
0
    def expand(fn_name):
      """Yields the preprocessing function registered under `fn_name`.

      Every supported name produces exactly one function. Its parameters are
      read from `FLAGS` and from the free variable `is_training`.

      Args:
        fn_name: name of the preprocessing step.

      Yields:
        The preprocessing function built for `fn_name`.

      Raises:
        ValueError: if `fn_name` is not a supported preprocessing name.
      """
      if fn_name == "plain_preprocess":
        # Identity step: hands its input back untouched.
        pp_fn = lambda x: x
      elif fn_name == "0_to_1":
        pp_fn = get_value_range_preprocess(0, 1)
      elif fn_name == "-1_to_1":
        pp_fn = get_value_range_preprocess(-1, 1)
      elif fn_name == "resize":
        target_size = utils.str2intlist(FLAGS.resize_size, 2)
        # The resize method is only ever randomized while training.
        randomize = is_training and FLAGS.get_flag_value(
            "randomize_resize_method", False)
        pp_fn = get_resize_preprocess(target_size, randomize)
      elif fn_name == "resize_small":
        pp_fn = get_resize_small(FLAGS.smaller_size)
      elif fn_name == "crop":
        pp_fn = get_crop(is_training, utils.str2intlist(FLAGS.crop_size, 2))
      elif fn_name == "central_crop":
        pp_fn = get_crop(False, utils.str2intlist(FLAGS.crop_size, 2))
      elif fn_name == "inception_crop":
        pp_fn = get_inception_crop(is_training)
      elif fn_name == "flip_lr":
        pp_fn = get_random_flip_lr(is_training)
      elif fn_name == "crop_inception_preprocess_patches":
        pp_fn = get_inception_preprocess_patches(
            is_training,
            utils.str2intlist(FLAGS.resize_size, 2),
            FLAGS.num_of_inception_patches)
      elif fn_name == "to_gray":
        gray_probability = FLAGS.get_flag_value("grayscale_probability", 1.0)
        pp_fn = get_to_gray_preprocess(gray_probability)
      elif fn_name == "crop_patches":
        pp_fn = pp_lib.get_crop_patches_fn(
            is_training,
            split_per_side=FLAGS.splits_per_side,
            patch_jitter=FLAGS.get_flag_value("patch_jitter", 0))
      elif fn_name == "standardization":
        pp_fn = get_standardization_preprocess()
      elif fn_name == "rotate":
        pp_fn = get_rotate_preprocess()
      # The names below stand for specific combinations of steps; ideally
      # they would live in the configs instead.
      elif fn_name == "inception_preprocess":
        pp_fn = get_inception_preprocess(
            is_training, utils.str2intlist(FLAGS.resize_size, 2))
      else:
        raise ValueError("Not supported preprocessing %s" % fn_name)

      yield pp_fn
コード例 #2
0
    def __init__(self, update_batchnorm_params=True):
        """Sets up the global-step increment op and the learning rate.

        Args:
          update_batchnorm_params: stored unchanged on the instance as
            `self.update_batchnorm_params`.
        """
        self.update_batchnorm_params = update_batchnorm_params

        train_split = FLAGS.get_flag_value('train_split', 'train')
        steps_per_epoch = datasets.get_count(train_split) // FLAGS.batch_size

        step = tf.train.get_or_create_global_step()
        self.global_step_inc = tf.assign_add(step, 1)

        # `lr_scale_batch_size` is the canonical batch size that the initial
        # learning rate was chosen for. When the actual batch size differs,
        # the learning rate is scaled linearly, so the batch size can be
        # varied without recomputing the learning rate by hand.
        if FLAGS.get_flag_value('lr_scale_batch_size', 0):
            lr_scale = FLAGS.batch_size / float(FLAGS.lr_scale_batch_size)
        else:
            lr_scale = 1.0

        # Without explicit decay epochs, the total epoch count is used as
        # the only entry.
        decay_spec = FLAGS.get_flag_value('decay_epochs', None)
        if decay_spec:
            decay_epochs = utils.str2intlist(decay_spec)
        else:
            decay_epochs = [FLAGS.epochs]

        self.lr = get_lr(
            step,
            base_lr=FLAGS.lr * lr_scale,
            steps_per_epoch=steps_per_epoch,
            decay_epochs=decay_epochs,
            lr_decay_factor=FLAGS.get_flag_value('lr_decay_factor', 0.1),
            warmup_epochs=FLAGS.get_flag_value('warmup_epochs', 0))
コード例 #3
0
def serving_input_fn():
  """Builds the serving input receiver around a single float32 placeholder.

  The placeholder shape comes from `FLAGS.serving_input_shape` and it is
  registered under the key `FLAGS.serving_input_key`.

  Returns:
    A `tf.estimator.export.ServingInputReceiver` whose features and receiver
    tensors are the same one-entry dict.
  """
  placeholder = tf.placeholder(
      dtype=tf.float32,
      shape=utils.str2intlist(FLAGS.serving_input_shape))
  inputs = {FLAGS.serving_input_key: placeholder}
  return tf.estimator.export.ServingInputReceiver(
      features=inputs, receiver_tensors=inputs)
コード例 #4
0
def serving_input_fn():
    """Builds the serving input receiver around a single float32 placeholder.

    The shape is taken from the `serving_input_shape` flag (default
    'None,None,None,3') and the feature key from the `serving_input_key`
    flag (default 'image').

    Returns:
      A `tf.estimator.export.ServingInputReceiver` whose features and
      receiver tensors are the same one-entry dict.
    """
    shape_spec = FLAGS.get_flag_value('serving_input_shape',
                                      'None,None,None,3')
    placeholder = tf.placeholder(dtype=tf.float32,
                                 shape=utils.str2intlist(shape_spec))
    inputs = {FLAGS.get_flag_value('serving_input_key', 'image'): placeholder}
    return tf.estimator.export.ServingInputReceiver(
        features=inputs, receiver_tensors=inputs)
コード例 #5
0
ファイル: trainer.py プロジェクト: ZhangYH0502/TS-SSL
    def create_model_fn(is_training):
        """Runs `model_fn` on an image placeholder and registers hub signatures.

        Args:
          is_training: forwarded unchanged to `model_fn`.
        """
        image_ph = tf.placeholder(
            dtype=tf.float32,
            shape=utils.str2intlist(FLAGS.serving_input_shape))

        # Only the image input is supplied to the model here.
        end_points, predictions = model_fn(image_ph, is_training)

        # The default signature exposes the predictions; the signature named
        # 'representation' exposes all endpoints.
        hub.add_signature(outputs=predictions, inputs={'image': image_ph})
        hub.add_signature(name='representation',
                          outputs=end_points,
                          inputs={'image': image_ph})
コード例 #6
0
def get_resize_preprocess(fn_args, is_training):
  """Returns a preprocessing function that resizes `data["image"]`.

  `fn_args` holds the target size and, optionally, the literal string
  "randomize_method" at any position, so that a call can look like
  `resize(256, randomize_method)` or `resize(randomize_method, 256, 128)`.
  When the marker is present and `is_training` is truthy, each invocation
  picks one of four resize methods at random. Otherwise
  `tf.image.resize_images` is used with its default method.

  Args:
    fn_args: list of arguments given to the resize step. It is not modified.
    is_training: whether the pipeline is being built for training.

  Returns:
    A function that takes a dict with an "image" entry, replaces that entry
    with the resized image and returns the dict.
  """
  # Work on a copy so the caller's argument list is left untouched.
  size_args = list(fn_args)
  try:
    size_args.remove("randomize_method")
    randomize_resize_method = is_training
  except ValueError:
    randomize_resize_method = False
  im_size = utils.str2intlist(size_args, 2)

  # (resize method, align_corners) pairs to choose from when randomizing.
  # NOTE: use align_corners=False for AREA resize, but True for the others.
  # See https://github.com/tensorflow/tensorflow/issues/6720
  methods = [
      (tf.image.ResizeMethod.BILINEAR, True),
      (tf.image.ResizeMethod.NEAREST_NEIGHBOR, True),
      (tf.image.ResizeMethod.BICUBIC, True),
      (tf.image.ResizeMethod.AREA, False),
  ]

  def _resize(image, method, align_corners):
    """Returns a thunk that resizes `image` with the given method."""

    def _process():
      # The resized images are of type float32 and might fall outside of the
      # range [0, 255].
      return tf.cast(
          tf.image.resize_images(
              image, im_size, method, align_corners=align_corners),
          dtype=tf.float32)

    return _process

  def _resize_pp(data):
    image = data["image"]

    if randomize_resize_method:
      # `maxval` of `tf.random_uniform` is exclusive, so it has to equal the
      # number of methods; a smaller value would make the last method
      # unreachable.
      r = tf.random_uniform([], 0, len(methods), dtype=tf.int32)
      resized = tf.case({
          tf.equal(r, tf.cast(index, r.dtype)):
              _resize(image, method, align_corners)
          for index, (method, align_corners) in enumerate(methods)
      })
    else:
      resized = tf.image.resize_images(image, im_size)
    data["image"] = resized
    return data

  return _resize_pp
コード例 #7
0
    def expand(fn_name, args):
      """Yields the preprocessing function registered under `fn_name`.

      Every supported name produces exactly one function. Its parameters are
      taken from `args` and from the free variable `is_training`.

      Args:
        fn_name: name of the preprocessing step.
        args: sequence of arguments given to the step.

      Yields:
        The preprocessing function built for `fn_name`.

      Raises:
        ValueError: if `fn_name` is not a supported preprocessing name.
      """
      if fn_name == "plain_preprocess":
        # Identity step: hands its input back untouched.
        pp_fn = lambda x: x
      elif fn_name == "0_to_1":
        pp_fn = get_value_range_preprocess(0, 1)
      elif fn_name == "-1_to_1":
        pp_fn = get_value_range_preprocess(-1, 1)
      elif fn_name == "value_range":
        pp_fn = get_value_range_preprocess(*map(float, args))
      elif fn_name == "resize":
        pp_fn = get_resize_preprocess(args, is_training)
      elif fn_name == "resize_small":
        pp_fn = get_resize_small(int(args[0]))
      elif fn_name == "crop":
        pp_fn = get_crop(is_training, utils.str2intlist(args, 2))
      elif fn_name == "central_crop":
        pp_fn = get_crop(False, utils.str2intlist(args, 2))
      elif fn_name == "multi_crop":
        pp_fn = get_multi_crop(utils.str2intlist(args, 2))
      elif fn_name == "inception_crop":
        pp_fn = get_inception_crop(is_training)
      elif fn_name == "flip_lr":
        pp_fn = get_random_flip_lr(is_training)
      elif fn_name == "hsvnoise":
        # TODO(lbeyer): expose the parameters? Or maybe just a scale parameter?
        pp_fn = get_hsvnoise_preprocess(*args)
      elif fn_name == "crop_inception_preprocess_patches":
        # First argument is the patch count, the remaining ones the size.
        num_patches = int(args[0])
        patch_size = utils.str2intlist(args[1:], 2)
        pp_fn = get_inception_preprocess_patches(
            is_training, patch_size, num_patches)
      elif fn_name == "crop_inception_patches":
        num_patches = int(args[0])
        patch_size = utils.str2intlist(args[1:], 2)
        pp_fn = get_inception_crop_patches(patch_size, num_patches)
      elif fn_name == "to_gray":
        pp_fn = get_to_gray_preprocess(float(get(args, 0, 1.0)))
      elif fn_name == "standardize":
        pp_fn = get_standardize_preprocess()
      elif fn_name == "rotate":
        pp_fn = get_rotate_preprocess()
      elif fn_name == "copy_label":
        pp_fn = get_copy_label_preprocess(get(args, 0, "copy_label"))
      # The names below stand for specific combinations of steps; ideally
      # they would live in the configs instead.
      elif fn_name == "inception_preprocess":
        pp_fn = get_inception_preprocess(
            is_training, utils.str2intlist(args, 2))
      else:
        raise ValueError("Not supported preprocessing %s" % fn_name)

      yield pp_fn
コード例 #8
0
def train_and_eval():
  """Trains a network on (self) supervised data, or evaluates checkpoints.

  With `FLAGS.run_eval` set, every checkpoint showing up in the model
  directory is evaluated on `FLAGS.val_split` and exported as a hub module.
  Afterwards the latest checkpoint is evaluated once more on the validation
  split and, if `FLAGS.test_split` is set, on the test split. Without
  `FLAGS.run_eval` the estimator is trained on `FLAGS.train_split`.

  Returns:
    In eval mode, the result of the final validation `estimator.evaluate`
    call; in training mode, whatever `estimator.train` returns.
  """
  model_dir = FLAGS.get_flag_value("checkpoint", FLAGS.workdir)
  tf.gfile.MakeDirs(model_dir)

  # NOTE(review): the cluster is gated on the truthiness of `tpu_name` while
  # `use_tpu` below checks `is not None`; the two differ for an empty-string
  # flag value -- confirm against the flag's default.
  cluster = None
  if FLAGS.tpu_name:
    cluster = TPUClusterResolver(tpu=[FLAGS.tpu_name])

  run_config = RunConfig(
      model_dir=model_dir,
      tf_random_seed=FLAGS.random_seed,
      cluster=cluster,
      keep_checkpoint_max=None,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=TPUConfig(iterations_per_loop=TPU_ITERATIONS_PER_LOOP))

  # Optionally resume from a stored checkpoint.
  warm_start = None
  if FLAGS.path_to_initial_ckpt:
    # Wrapping the pattern in a list is important for loading all the
    # variables from the GLOBAL_VARIABLES collection; see the
    # `vars_to_warm_start` section of
    # https://www.tensorflow.org/api_docs/python/tf/estimator/WarmStartSettings  # pylint: disable=line-too-long
    warm_start = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=FLAGS.path_to_initial_ckpt,
        vars_to_warm_start=[FLAGS.vars_to_restore])

  # The global batch sizes are passed to the TPU estimator, which passes the
  # local batch size along in the model_fn's `params` argument dict.
  estimator = TPUEstimator(
      model_fn=semi_supervised.get_model(FLAGS.task),
      model_dir=model_dir,
      config=run_config,
      use_tpu=FLAGS.tpu_name is not None,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.get_flag_value("eval_batch_size", FLAGS.batch_size),
      warm_start_from=warm_start)

  if not FLAGS.run_eval:
    train_input_fn = functools.partial(
        datasets.get_data,
        split_name=FLAGS.train_split,
        preprocessing=FLAGS.preprocessing,
        is_training=True,
        num_epochs=None,  # read data indefinitely for training
        drop_remainder=True)

    # The number of steps is computed here and handed to the Estimator as
    # `max_steps`, instead of relying on the Dataset's iterator running out
    # after a number of epochs. This allows "fractional" epochs, which are
    # used by regression tests. (And TPUEstimator needs it anyway.)
    num_examples = datasets.get_count(FLAGS.train_split)
    if FLAGS.num_supervised_examples:
      num_examples = FLAGS.num_supervised_examples
    # Depending on whether the last batch is dropped each epoch or only at
    # the very end, this should be ordered differently for rounding.
    updates_per_epoch = num_examples // FLAGS.batch_size
    last_epoch = utils.str2intlist(FLAGS.schedule, strict_int=False)[-1]
    train_steps = int(math.ceil(last_epoch * updates_per_epoch))
    tf.logging.info("train_steps: %d", train_steps)

    return estimator.train(train_input_fn, max_steps=train_steps)

  def _eval_input_fn_and_steps(split):
    """Returns the eval input function and step count for `split`."""
    input_fn = functools.partial(
        datasets.get_data,
        split_name=split,
        preprocessing=FLAGS.get_flag_value("preprocessing_eval",
                                           FLAGS.preprocessing),
        is_training=False,
        shuffle=False,
        num_epochs=1,
        drop_remainder=True)
    # Contrary to what the documentation claims, `train` and `evaluate` NEED
    # to have `max_steps` and/or `steps` set and cannot make use of the
    # iterator's end-of-input exception, hence the explicit step count.
    steps = datasets.get_count(split) // FLAGS.get_flag_value(
        "eval_batch_size", FLAGS.batch_size)
    return input_fn, steps

  val_input_fn, val_steps = _eval_input_fn_and_steps(FLAGS.val_split)
  tf.logging.info("val_steps: %d", val_steps)

  for ckpt in checkpoints_iterator(
      estimator.model_dir, timeout=FLAGS.eval_timeout_mins * 60):
    estimator.evaluate(
        checkpoint_path=ckpt, input_fn=val_input_fn, steps=val_steps)

    exporter = hub.LatestModuleExporter("hub", serving_input_fn)
    exporter.export(estimator, os.path.join(model_dir, "export/hub"), ckpt)

    # This check lives here rather than in `checkpoints_iterator`'s
    # `timeout_fn` param, because that would wait forever on failed trainers
    # which will never create this file.
    if tf.gfile.Exists(os.path.join(FLAGS.workdir, "TRAINING_IS_DONE")):
      break

  # Evaluate the latest checkpoint on the validation set.
  val_results = estimator.evaluate(input_fn=val_input_fn, steps=val_steps)
  tf.logging.info(val_results)

  # Optionally evaluate the latest checkpoint on the test set.
  if FLAGS.test_split:
    test_input_fn, test_steps = _eval_input_fn_and_steps(FLAGS.test_split)
    test_results = estimator.evaluate(input_fn=test_input_fn, steps=test_steps)
    tf.logging.info(test_results)
  return val_results
コード例 #9
0
def creates_estimator_model(images, labels, perms, num_classes, mode):
    """Creates the EstimatorSpec for the patch based self supervised models.

    Args:
      images: images
      labels: self supervised labels (class indices)
      perms: patch permutations
      num_classes: number of different permutations
      mode: model's mode: training, eval or prediction

    Returns:
      EstimatorSpec
    """
    print('   +++ Mode: %s, images: %s, labels: %s' % (mode, images, labels))

    # Collapse all leading dimensions, keeping only the last three.
    trailing_dims = images.get_shape().as_list()[-3:]
    images = tf.reshape(images, shape=[-1] + trailing_dims)

    if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        # Prediction: wrap the model into a hub module, register it for
        # export and return its outputs as the predictions.
        serving_shape = utils.str2intlist(
            FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3'))

        def _placeholder_fn():
            return tf.placeholder(shape=serving_shape, dtype=tf.float32)

        module_fn = functools.partial(apply_model,
                                      image_fn=_placeholder_fn,
                                      num_outputs=num_classes,
                                      perms=perms,
                                      make_signature=True)
        tags_and_args = [
            (utils.TAGS_IS_TRAINING, {'is_training': True}),
            (set(), {'is_training': False}),
        ]
        module_spec = hub.create_module_spec(module_fn,
                                             tags_and_args,
                                             drop_collections=['summaries'])
        module = hub.Module(module_spec, trainable=False, tags=set())
        hub.register_module_for_export(module, export_name='module')
        return make_estimator(mode, predictions=module(images))

    def _images_fn():
        return images

    with tf.variable_scope('module'):
        logits = apply_model(
            image_fn=_images_fn,
            is_training=(mode == tf.estimator.ModeKeys.TRAIN),
            num_outputs=num_classes,
            perms=perms,
            make_signature=False)

    # Mean cross-entropy loss over the batch.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits))

    def _metric_fn(labels, logits):
        """Returns the accuracy metric of the argmax predictions."""
        predicted = tf.argmax(logits, axis=-1)
        return {
            'accuracy': tf.metrics.accuracy(labels=labels,
                                            predictions=predicted)
        }

    return make_estimator(mode, loss, (_metric_fn, [labels, logits]), logits)
コード例 #10
0
def model_fn(data, mode):
    """Produces a loss for the rotation task.

    Args:
      data: Dict of inputs containing, among others, "image" and "label."
      mode: model's mode: training, eval or prediction

    Returns:
      EstimatorSpec
    """
    num_angles = 4
    images = data['image']

    if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        # Prediction: wrap the model into a hub module, register it for
        # export and return its outputs as the predictions.
        serving_shape = utils.str2intlist(
            FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3'))

        def _placeholder_fn():
            return tf.placeholder(shape=serving_shape, dtype=tf.float32)

        module_fn = functools.partial(apply_model,
                                      image_fn=_placeholder_fn,
                                      num_outputs=num_angles,
                                      make_signature=True)
        tags_and_args = [
            (utils.TAGS_IS_TRAINING, {'is_training': True}),
            (set(), {'is_training': False}),
        ]
        module_spec = hub.create_module_spec(module_fn, tags_and_args)
        module = hub.Module(module_spec, trainable=False, tags=set())
        hub.register_module_for_export(module, export_name='module')
        return trainer.make_estimator(mode, predictions=module(images))

    # Collapse all leading dimensions, keeping only the last three.
    trailing_dims = images.get_shape().as_list()[-3:]
    images = tf.reshape(images, [-1] + trailing_dims)

    def _images_fn():
        return images

    with tf.variable_scope('module'):
        logits = apply_model(
            image_fn=_images_fn,
            is_training=(mode == tf.estimator.ModeKeys.TRAIN),
            num_outputs=num_angles,
            make_signature=False)

    labels = tf.reshape(data['label'], [-1])

    # Mean cross-entropy loss over the batch.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits))

    def _metric_fn(labels, logits):
        """Returns the accuracy metric of the argmax predictions."""
        predicted = tf.argmax(logits, axis=-1)
        return {
            'accuracy': tf.metrics.accuracy(labels=labels,
                                            predictions=predicted)
        }

    return trainer.make_estimator(
        mode, loss, (_metric_fn, [labels, logits]), logits)