Example #1
import tensorflow as tf


def flatten(input, axis=1, end_axis=-1):
  """
  Caffe-style flatten.

  Args:
    input: An N-D tensor.
    axis: The first axis to flatten: all preceding axes are retained in the
      output. May be negative to index from the end (e.g., -1 for the last
      axis).
    end_axis: The last axis to flatten: all following axes are retained in the
      output. May be negative to index from the end (e.g., the default -1 for
      the last axis).
  Returns:
      An M-D tensor, where M = N - (end_axis - axis).
  """
  input_shape = tf.shape(input)
  input_rank = tf.shape(input_shape)[0]
  # Normalize negative axis indices against the dynamic rank.
  if axis < 0:
    axis = input_rank + axis
  if end_axis < 0:
    end_axis = input_rank + end_axis
  output_shape = []
  if axis != 0:
    output_shape.append(input_shape[:axis])
  # Collapse axes [axis, end_axis] into a single dimension.
  output_shape.append([tf.reduce_prod(input_shape[axis:end_axis + 1])])
  if end_axis + 1 != input_rank:
    output_shape.append(input_shape[end_axis + 1:])
  output_shape = tf.concat(output_shape, axis=0)
  output = tf.reshape(input, output_shape)
  return output
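
A quick check of the Caffe-style semantics (eager mode, hypothetical shapes): with the defaults axis=1 and end_axis=-1, every axis after the batch dimension is collapsed into one.

x = tf.zeros([2, 3, 4, 5])
print(flatten(x).shape)                      # (2, 60): axes 1..3 collapsed
print(flatten(x, axis=1, end_axis=2).shape)  # (2, 12, 5): only axes 1..2 collapsed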
Example #2
    def build(self, input_shape):
        # Shape of f_phi's output, excluding the batch dimension.
        intermediate_shape = self.f_phi.compute_output_shape(input_shape)[1:]
        # Dense projection with one unit per element of the intermediate shape.
        self.f_theta_dense = tf.keras.layers.Dense(
            tf.reduce_prod(intermediate_shape), activation="relu")
        # Restore the flat Dense output to the intermediate shape.
        self.reshape_theta = tf.keras.layers.Reshape(intermediate_shape,
                                                     name="ReshapeTheta")

        # Reshape back to the input shape, minus the batch dimension.
        self.reshape_output = tf.keras.layers.Reshape(input_shape[1:],
                                                      name="ReshapeOutput")

        super().build(input_shape)
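
The pattern above pairs tf.reduce_prod with Reshape: the Dense layer gets exactly as many units as the product of the target shape's dimensions, so its flat output can always be reshaped back. A minimal standalone sketch, assuming a hypothetical (4, 4, 32) intermediate shape:

import tensorflow as tf

intermediate_shape = (4, 4, 32)                                    # hypothetical
dense = tf.keras.layers.Dense(tf.reduce_prod(intermediate_shape))  # 4*4*32 = 512 units
reshape = tf.keras.layers.Reshape(intermediate_shape)
y = reshape(dense(tf.zeros([2, 64])))                              # -> shape (2, 4, 4, 32)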
Example #3
    def __init__(self,
                 deconv_layer_params,
                 start_decoding_size=7,
                 start_decoding_filters=8,
                 padding="same",
                 preprocess_fc_layer_params=None,
                 activation_fn=tf.nn.relu,
                 output_activation_fn=tf.nn.tanh,
                 name="image_decoding_network"):
        """
        Initialize the layers for decoding a latent vector into an image.

        The user is responsible for calculating the output size given
        `start_decoding_size` and `deconv_layer_params`, and for making sure
        that the size matches the expectation. The output size is calculated as:
            if padding == "same",  then H = H1 * strides
            if padding == "valid", then H = (H1 - 1) * strides + HF
        where H = output size, H1 = input size, HF = height of kernel.

        Args:
            deconv_layer_params (list[tuple]): a non-empty list of elements
                (num_filters, kernel_size, strides).
            start_decoding_size (int): the initial size we'd like to have for
                the feature map
            start_decoding_filters (int): the initial number of filters we'd like
                to have for the feature map. Note that given this value and
                `start_decoding_size`, we always first project an input latent
                vector into a vector of an appropriate length so that it can be
                reshaped into (`start_decoding_size`, `start_decoding_size`,
                `start_decoding_filters`).
            padding (str): "same" or "valid", see tf.keras.layers.Conv2DTranspose
            preprocess_fc_layer_params (tuple[int]): a tuple of fc layer units.
                These fc layers are used for preprocessing the latent vector before
                transposed convolutions.
            activation_fn (tf.nn.activation): activation for hidden layers
            output_activation_fn (tf.nn.activation): activation for the output
                layer. Usually our image inputs are normalized to [0, 1] or [-1, 1],
                so this function should be tf.nn.sigmoid or tf.nn.tanh.
            name (str): network name
        """
        super().__init__(name=name)

        assert isinstance(deconv_layer_params, list)
        assert len(deconv_layer_params) > 0

        self._preprocess_fc_layers = []
        if preprocess_fc_layer_params is not None:
            for size in preprocess_fc_layer_params:
                self._preprocess_fc_layers.append(
                    tf.keras.layers.Dense(size, activation=activation_fn))

        # We always assume "channels_last" !
        self._start_decoding_shape = [
            start_decoding_size, start_decoding_size, start_decoding_filters
        ]

        self._preprocess_fc_layers.append(
            tf.keras.layers.Dense(
                tf.reduce_prod(self._start_decoding_shape),
                activation=activation_fn))

        self._deconv_layers = []
        for i, (filters, kernel_size,
                strides) in enumerate(deconv_layer_params):
            act_fn = activation_fn
            if i == len(deconv_layer_params) - 1:
                act_fn = output_activation_fn
            self._deconv_layers.append(
                tf.keras.layers.Conv2DTranspose(
                    padding=padding,
                    filters=filters,
                    kernel_size=kernel_size,
                    strides=strides,
                    activation=act_fn))
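
A worked size check using the docstring's formula (the class name and parameters below are hypothetical): with padding="same", each deconv layer multiplies the spatial size by its stride, so three stride-2 layers grow start_decoding_size=7 to 7 * 2 * 2 * 2 = 56.

decoder = ImageDecodingNetwork(  # hypothetical class name and parameters
    deconv_layer_params=[(64, 3, 2), (32, 3, 2), (3, 3, 2)],
    start_decoding_size=7,
    start_decoding_filters=8)
# padding="same": H = H1 * strides per layer, i.e. 7 -> 14 -> 28 -> 56;
# the last layer's 3 filters make the output a (56, 56, 3) image.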
Example #4
import collections

import tensorflow.compat.v1 as tf  # TF1-style API: variable_scope, get_variable, etc.


def bn(x,
       params=None,
       moments=None,
       backprop_through_moments=True,
       use_ema=False,
       is_training=True,
       ema_epsilon=.9):
  """Batch normalization.

  The usage should be as follows: If x is the support images, moments should be
  None so that they are computed from the support set examples. On the other
  hand, if x is the query images, the moments argument should be used in order
  to pass in the mean and var that were computed from the support set.

  Args:
    x: inputs.
    params: None or a dict containing the values of the offset and scale params.
    moments: None or a dict containing the values of the mean and var to use for
      batch normalization.
    backprop_through_moments: Whether to allow gradients to flow through the
      given support set moments. Only applies to non-transductive batch norm.
    use_ema: apply moving averages of batch norm statistics, or update them,
      depending on whether we are training or testing.  Note that passing
      moments will override this setting and result in neither updating nor
      using ema statistics.  This is important to make sure that episodic
      learners don't update ema statistics a second time when processing
      queries.
    is_training: if use_ema=True, this determines whether to apply the moving
      averages, or update them.
    ema_epsilon: if updating moving averages, use this value for the
      exponential moving averages.

  Returns:
    output: The result of applying batch normalization to the input.
    params: The updated params.
    moments: The updated moments.
  """
  params_keys, params_vars, moments_keys, moments_vars = [], [], [], []

  with tf.variable_scope('batch_norm'):
    scope_name = tf.get_variable_scope().name

    if use_ema:
      ema_shape = [1, 1, 1, x.get_shape().as_list()[-1]]
      mean_ema = tf.get_variable(
          'mean_ema',
          shape=ema_shape,
          initializer=tf.initializers.zeros(),
          trainable=False)
      var_ema = tf.get_variable(
          'var_ema',
          shape=ema_shape,
          initializer=tf.initializers.ones(),
          trainable=False)

    if moments is not None:
      if backprop_through_moments:
        mean = moments[scope_name + '/mean']
        var = moments[scope_name + '/var']
      else:
        # This variant does not yield good results.
        mean = tf.stop_gradient(moments[scope_name + '/mean'])
        var = tf.stop_gradient(moments[scope_name + '/var'])
    elif use_ema and not is_training:
      mean = mean_ema
      var = var_ema
    else:
      # If not provided, compute the mean and var of the current batch.

      replica_ctx = tf.distribute.get_replica_context()
      if replica_ctx:
        # from third_party/tensorflow/python/keras/layers/normalization_v2.py
        axes = list(range(len(x.shape) - 1))
        local_sum = tf.reduce_sum(x, axis=axes, keepdims=True)
        local_squared_sum = tf.reduce_sum(
            tf.square(x), axis=axes, keepdims=True)
        batch_size = tf.cast(tf.shape(x)[0], tf.float32)
        x_sum, x_squared_sum, global_batch_size = (
            replica_ctx.all_reduce('sum',
                                   [local_sum, local_squared_sum, batch_size]))

        # Divide by the total element count per channel: the global batch
        # size times the spatial dimensions.
        axes_vals = [(tf.shape(x))[i] for i in range(1, len(axes))]
        multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
        multiplier = multiplier * global_batch_size

        mean = x_sum / multiplier
        x_squared_mean = x_squared_sum / multiplier
        # var = E(x^2) - E(x)^2
        var = x_squared_mean - tf.square(mean)
      else:
        mean, var = tf.nn.moments(
            x, axes=list(range(len(x.shape) - 1)), keep_dims=True)

    # Only update ema's if training and we computed the moments in the current
    # call.  Note: at test time for episodic learners, ema's may be passed
    # from the support set to the query set, even if it's not really needed.
    if use_ema and is_training and moments is None:
      replica_ctx = tf.distribute.get_replica_context()
      mean_upd = tf.assign(mean_ema,
                           mean_ema * ema_epsilon + mean * (1.0 - ema_epsilon))
      var_upd = tf.assign(var_ema,
                          var_ema * ema_epsilon + var * (1.0 - ema_epsilon))
      updates = tf.group([mean_upd, var_upd])
      if replica_ctx:
        tf.add_to_collection(
            tf.GraphKeys.UPDATE_OPS,
            tf.cond(
                tf.equal(replica_ctx.replica_id_in_sync_group, 0),
                lambda: updates, tf.no_op))
      else:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, updates)

    moments_keys += [scope_name + '/mean']
    moments_vars += [mean]
    moments_keys += [scope_name + '/var']
    moments_vars += [var]

    if params is None:
      offset = tf.get_variable(
          'offset',
          shape=mean.get_shape().as_list(),
          initializer=tf.initializers.zeros())
      scale = tf.get_variable(
          'scale',
          shape=var.get_shape().as_list(),
          initializer=tf.initializers.ones())
    else:
      offset = params[scope_name + '/offset']
      scale = params[scope_name + '/scale']

    params_keys += [scope_name + '/offset']
    params_vars += [offset]
    params_keys += [scope_name + '/scale']
    params_vars += [scale]

    output = tf.nn.batch_normalization(x, mean, var, offset, scale, 0.00001)

    params = collections.OrderedDict(zip(params_keys, params_vars))
    moments = collections.OrderedDict(zip(moments_keys, moments_vars))

    return output, params, moments
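
A sketch of the episodic usage the docstring describes (tensor names and the enclosing scope are hypothetical): moments are computed once on the support set, then passed in so the query batch is normalized with the same statistics.

with tf.variable_scope('embed', reuse=tf.AUTO_REUSE):
  # Support set: moments=None, so mean/var come from this batch.
  support_out, params, moments = bn(support_images)
with tf.variable_scope('embed', reuse=tf.AUTO_REUSE):
  # Query set: reuse the support-set statistics and params.
  query_out, _, _ = bn(query_images, params=params, moments=moments)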