Пример #1
0
def finalize_serving(model_output, export_config):
  """Adds extra layers based on the provided configuration.

  The step names in `export_config.finalize_method` are applied in order:
    * 'none': only honored as the first step; `model_output` is returned
      unchanged. Anywhere else it is rejected as an unknown step.
    * 'argmax': if it is the last step, `tf.argmax` over axis 3 producing int32
      indices; otherwise `custom_layers.argmax(..., keepdims=True,
      epsilon=1e-3)`. Either way, later resizes switch to nearest-neighbor.
    * 'squeeze': `tf.squeeze` on axis 3.
    * 'resize<N>': resize to N x N (N must be a positive integer).

  Args:
    model_output: Output tensor of the model. Presumably NHWC, since argmax and
      squeeze act on axis 3 -- TODO confirm with callers.
    export_config: Object with a `finalize_method` attribute holding a sequence
      of step names; an empty or None value means "no finalization".

  Returns:
    The tensor obtained by applying the configured steps to `model_output`.

  Raises:
    ValueError: If a step is not recognized, or a 'resize' step does not end in
      a positive integer size.
  """
  finalize_method = export_config.finalize_method
  output_layer = model_output
  if not finalize_method or finalize_method[0] == 'none':
    return output_layer

  # Becomes True once an argmax has been applied; from then on the values are
  # indices and must not be interpolated.
  discrete = False
  num_steps = len(finalize_method)
  for i, step in enumerate(finalize_method):
    if step == 'argmax':
      discrete = True
      if i + 1 == num_steps:
        output_layer = tf.argmax(
            output_layer, axis=3, output_type=tf.dtypes.int32)
      else:
        # TODO(tohaspiridonov): add first_match=False when cl/383951533 submitted
        output_layer = custom_layers.argmax(
            output_layer, keepdims=True, epsilon=1e-3)
    elif step == 'squeeze':
      output_layer = tf.squeeze(output_layer, axis=3)
    else:
      error_message = 'Cannot finalize with ' + step + '.'
      # A valid step is exactly 'resize' followed by the target size.
      resize_params = step.split('resize')
      if len(resize_params) != 2 or resize_params[0]:
        raise ValueError(error_message)
      try:
        resize_to_size = int(resize_params[1])
      except ValueError as err:
        # e.g. 'resize' or 'resizeabc': report it like any other bad step
        # instead of leaking int()'s parse error.
        raise ValueError(error_message) from err
      if resize_to_size <= 0:
        raise ValueError(error_message)
      # Nearest-neighbor keeps argmax indices intact; bilinear would blend
      # neighbouring indices into meaningless intermediate values.
      if discrete:
        resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
      else:
        resize_method = tf.image.ResizeMethod.BILINEAR
      output_layer = tf.image.resize(
          output_layer, [resize_to_size, resize_to_size],
          method=resize_method)
  return output_layer
Пример #2
0
 def test_reference_match(self, shape, input_type, output_type):
     """Checks that custom_layers.argmax matches tf.math.argmax on every axis.

     Args:
       shape: Shape of the random input tensor.
       input_type: dtype used to sample the random input tensor.
       output_type: dtype requested for the returned indices.
     """
     random_inputs = tf.random.uniform(shape=shape,
                                       maxval=10,
                                       dtype=input_type)
     rank = len(shape)
     # Valid axes span [-rank, rank); start at -rank so the most negative
     # spelling of axis 0 is exercised too.
     # NOTE(review): for integer input_type, maxval=10 makes ties along an axis
     # likely, and tf.math.argmax does not guarantee which tied index it
     # returns -- confirm both implementations agree on ties.
     for axis in range(-rank, rank):
         # subTest reports the failing axis instead of stopping at the first.
         with self.subTest(axis=axis):
             control_output = tf.math.argmax(random_inputs,
                                             axis=axis,
                                             output_type=output_type)
             test_output = custom_layers.argmax(random_inputs,
                                                axis=axis,
                                                output_type=output_type)
             self.assertAllEqual(control_output, test_output)