import collections

import tensorflow as tf


def flatten(input, axis=1, end_axis=-1):
    """
    Caffe-style flatten.

    Args:
        input: An N-D tensor.
        axis: The first axis to flatten: all preceding axes are retained in the
            output. May be negative to index from the end (e.g., -1 for the
            last axis).
        end_axis: The last axis to flatten: all following axes are retained in
            the output. May be negative to index from the end (e.g., the
            default -1 for the last axis).

    Returns:
        An M-D tensor where M = N - (end_axis - axis)
    """
    input_shape = tf.shape(input)
    input_rank = tf.shape(input_shape)[0]
    if axis < 0:
        axis = input_rank + axis
    if end_axis < 0:
        end_axis = input_rank + end_axis
    output_shape = []
    if axis != 0:
        output_shape.append(input_shape[:axis])
    output_shape.append([tf.reduce_prod(input_shape[axis:end_axis + 1])])
    if end_axis + 1 != input_rank:
        output_shape.append(input_shape[end_axis + 1:])
    output_shape = tf.concat(output_shape, axis=0)
    output = tf.reshape(input, output_shape)
    return output
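# Usage sketch (illustrative, not part of the original source; assumes eager
# mode): flatten() collapses the spatial and channel axes of an NHWC batch
# into a single feature axis, like a Caffe Flatten layer.
def _flatten_example():
    x = tf.random.uniform([4, 8, 8, 3])    # (N, H, W, C)
    y = flatten(x, axis=1, end_axis=-1)    # shape (4, 8 * 8 * 3) == (4, 192)
    return y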
def build(self, input_shape):
    # A Dense layer whose width matches the flattened f_phi output shape, plus
    # Reshape layers that restore the intermediate shape and the original
    # input shape.
    intermediate_shape = self.f_phi.compute_output_shape(input_shape)[1:]
    self.f_theta_dense = tf.keras.layers.Dense(
        tf.reduce_prod(intermediate_shape), activation="relu")
    self.reshape_theta = tf.keras.layers.Reshape(
        intermediate_shape, name="ReshapeTheta")
    self.reshape_output = tf.keras.layers.Reshape(
        input_shape[1:], name="ReshapeOutput")
    super().build(input_shape)
def __init__(self,
             deconv_layer_params,
             start_decoding_size=7,
             start_decoding_filters=8,
             padding="same",
             preprocess_fc_layer_params=None,
             activation_fn=tf.nn.relu,
             output_activation_fn=tf.nn.tanh,
             name="image_decoding_network"):
    """
    Initialize the layers for decoding a latent vector into an image.

    The user is responsible for calculating the output size given
    `start_decoding_size` and `deconv_layer_params`, and for making sure that
    the size matches the expectation. How to calculate the output size:
        if padding == "same", then H = H1 * strides
        if padding == "valid", then H = (H1 - 1) * strides + HF
    where H = output size, H1 = input size, and HF = height of the kernel.

    Args:
        deconv_layer_params (list[tuple]): a non-empty list of elements
            (num_filters, kernel_size, strides).
        start_decoding_size (int): the initial size we'd like to have for the
            feature map.
        start_decoding_filters (int): the initial number of filters we'd like
            to have for the feature map. Note that given this value and
            `start_decoding_size`, we always first project an input latent
            vector into a vector of an appropriate length so that it can be
            reshaped into (`start_decoding_size`, `start_decoding_size`,
            `start_decoding_filters`).
        padding (str): "same" or "valid", see tf.keras.layers.Conv2DTranspose.
        preprocess_fc_layer_params (tuple[int]): a list of fc layer units.
            These fc layers are used for preprocessing the latent vector before
            transposed convolutions.
        activation_fn (tf.nn.activation): activation for hidden layers.
        output_activation_fn (tf.nn.activation): activation for the output
            layer. Usually our image inputs are normalized to [0, 1] or
            [-1, 1], so this function should be tf.nn.sigmoid or tf.nn.tanh.
        name (str): network name.
    """
    super().__init__(name=name)
    assert isinstance(deconv_layer_params, list)
    assert len(deconv_layer_params) > 0

    self._preprocess_fc_layers = []
    if preprocess_fc_layer_params is not None:
        for size in preprocess_fc_layer_params:
            self._preprocess_fc_layers.append(
                tf.keras.layers.Dense(size, activation=activation_fn))

    # We always assume "channels_last"!
    self._start_decoding_shape = [
        start_decoding_size, start_decoding_size, start_decoding_filters
    ]
    self._preprocess_fc_layers.append(
        tf.keras.layers.Dense(
            tf.reduce_prod(self._start_decoding_shape),
            activation=activation_fn))

    self._deconv_layers = []
    for i, (filters, kernel_size, strides) in enumerate(deconv_layer_params):
        act_fn = activation_fn
        if i == len(deconv_layer_params) - 1:
            act_fn = output_activation_fn
        self._deconv_layers.append(
            tf.keras.layers.Conv2DTranspose(
                padding=padding,
                filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                activation=act_fn))
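# Worked example (illustrative, not from the original source): with
# padding="same", every transposed conv multiplies the spatial size by its
# stride, so the user can size the output before constructing the network.
def _example_decoded_image_size():
    deconv_layer_params = [(64, 3, 2), (32, 3, 2), (3, 3, 2)]  # hypothetical
    size = 7  # start_decoding_size
    for _num_filters, _kernel_size, strides in deconv_layer_params:
        size *= strides  # "same" padding: H = H1 * strides
    return size  # 7 * 2 * 2 * 2 == 56, i.e. a 56x56 decoded image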
def bn(x,
       params=None,
       moments=None,
       backprop_through_moments=True,
       use_ema=False,
       is_training=True,
       ema_epsilon=.9):
    """Batch normalization.

    The usage should be as follows: if x is the support images, moments should
    be None so that they are computed from the support set examples. On the
    other hand, if x is the query images, the moments argument should be used
    in order to pass in the mean and var that were computed from the support
    set.

    Args:
        x: inputs.
        params: None or a dict containing the values of the offset and scale
            params.
        moments: None or a dict containing the values of the mean and var to
            use for batch normalization.
        backprop_through_moments: Whether to allow gradients to flow through
            the given support set moments. Only applies to non-transductive
            batch norm.
        use_ema: apply moving averages of batch norm statistics, or update
            them, depending on whether we are training or testing. Note that
            passing moments will override this setting, and result in neither
            updating nor using ema statistics. This is important to make sure
            that episodic learners don't update ema statistics a second time
            when processing queries.
        is_training: if use_ema=True, this determines whether to apply the
            moving averages, or update them.
        ema_epsilon: if updating moving averages, use this value for the
            exponential moving averages.

    Returns:
        output: The result of applying batch normalization to the input.
        params: The updated params.
        moments: The updated moments.
    """
    params_keys, params_vars, moments_keys, moments_vars = [], [], [], []
    with tf.variable_scope('batch_norm'):
        scope_name = tf.get_variable_scope().name
        if use_ema:
            ema_shape = [1, 1, 1, x.get_shape().as_list()[-1]]
            mean_ema = tf.get_variable(
                'mean_ema',
                shape=ema_shape,
                initializer=tf.initializers.zeros(),
                trainable=False)
            var_ema = tf.get_variable(
                'var_ema',
                shape=ema_shape,
                initializer=tf.initializers.ones(),
                trainable=False)
        if moments is not None:
            if backprop_through_moments:
                mean = moments[scope_name + '/mean']
                var = moments[scope_name + '/var']
            else:
                # This variant does not yield good results.
                mean = tf.stop_gradient(moments[scope_name + '/mean'])
                var = tf.stop_gradient(moments[scope_name + '/var'])
        elif use_ema and not is_training:
            mean = mean_ema
            var = var_ema
        else:
            # If not provided, compute the mean and var of the current batch.
            replica_ctx = tf.distribute.get_replica_context()
            if replica_ctx:
                # from third_party/tensorflow/python/keras/layers/normalization_v2.py
                axes = list(range(len(x.shape) - 1))
                local_sum = tf.reduce_sum(x, axis=axes, keepdims=True)
                local_squared_sum = tf.reduce_sum(
                    tf.square(x), axis=axes, keepdims=True)
                batch_size = tf.cast(tf.shape(x)[0], tf.float32)
                x_sum, x_squared_sum, global_batch_size = (
                    replica_ctx.all_reduce(
                        'sum', [local_sum, local_squared_sum, batch_size]))

                axes_vals = [(tf.shape(x))[i] for i in range(1, len(axes))]
                multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
                multiplier = multiplier * global_batch_size

                mean = x_sum / multiplier
                x_squared_mean = x_squared_sum / multiplier
                # var = E(x^2) - E(x)^2
                var = x_squared_mean - tf.square(mean)
            else:
                mean, var = tf.nn.moments(
                    x, axes=list(range(len(x.shape) - 1)), keep_dims=True)
        # Only update ema's if training and we computed the moments in the
        # current call. Note: at test time for episodic learners, ema's may be
        # passed from the support set to the query set, even if it's not
        # really needed.
        if use_ema and is_training and moments is None:
            replica_ctx = tf.distribute.get_replica_context()
            mean_upd = tf.assign(
                mean_ema, mean_ema * ema_epsilon + mean * (1.0 - ema_epsilon))
            var_upd = tf.assign(
                var_ema, var_ema * ema_epsilon + var * (1.0 - ema_epsilon))
            updates = tf.group([mean_upd, var_upd])
            if replica_ctx:
                # In a distributed setting, only the first replica in the sync
                # group pushes the EMA update to UPDATE_OPS.
                tf.add_to_collection(
                    tf.GraphKeys.UPDATE_OPS,
                    tf.cond(
                        tf.equal(replica_ctx.replica_id_in_sync_group, 0),
                        lambda: updates, tf.no_op))
            else:
                tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, updates)

        moments_keys += [scope_name + '/mean']
        moments_vars += [mean]
        moments_keys += [scope_name + '/var']
        moments_vars += [var]

        if params is None:
            offset = tf.get_variable(
                'offset',
                shape=mean.get_shape().as_list(),
                initializer=tf.initializers.zeros())
            scale = tf.get_variable(
                'scale',
                shape=var.get_shape().as_list(),
                initializer=tf.initializers.ones())
        else:
            offset = params[scope_name + '/offset']
            scale = params[scope_name + '/scale']
        params_keys += [scope_name + '/offset']
        params_vars += [offset]
        params_keys += [scope_name + '/scale']
        params_vars += [scale]

        output = tf.nn.batch_normalization(x, mean, var, offset, scale, 0.00001)

        params = collections.OrderedDict(zip(params_keys, params_vars))
        moments = collections.OrderedDict(zip(moments_keys, moments_vars))
    return output, params, moments
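# Usage sketch (illustrative, not from the original source; assumes the
# TF1-style graph mode the function above is written for, and hypothetical
# support/query tensors). Following the docstring: compute moments from the
# support set, then reuse both the offset/scale params and the support
# moments for the query set.
def _example_episodic_bn(support_images, query_images):
    # Support pass: moments=None, so mean/var are computed from the support set.
    support_out, bn_params, support_moments = bn(support_images, is_training=True)
    # Query pass: reuse the learned offset/scale and the support-set moments.
    query_out, _, _ = bn(
        query_images, params=bn_params, moments=support_moments)
    return support_out, query_out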