def _do_data_dependent_init():
  """Returns ops for the data-dependent init of g and maybe b_fc."""
  w_fc_normalized = tf.nn.l2_normalize(w_fc.read_value(), [0])
  output_init = tf.matmul(embeddings, w_fc_normalized)
  mean_init, var_init = tf.nn.moments(output_init, [0])
  # Data-dependent init values.
  g_init_value = 1. / tf.sqrt(var_init + 1e-10)
  ops = [tf.assign(g, g_init_value)]
  if not cosine_classifier:
    # Also initialize a bias in a data-dependent way.
    b_fc_init_value = -mean_init * g_init_value
    ops.append(tf.assign(b_fc, b_fc_init_value))
  # Mark that the data-dependent initialization is done to prevent it from
  # happening again in the future.
  ops.append(tf.assign(data_dependent_init_done, 1))
  return tf.group(*ops)
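# Usage sketch (an assumption about the enclosing code, not part of this
# module): the closure above is typically gated on the
# `data_dependent_init_done` flag so the init runs only once, e.g.
#
#   init_op = tf.cond(
#       tf.equal(data_dependent_init_done, 0),
#       _do_data_dependent_init,
#       tf.no_op)
#   with tf.control_dependencies([init_op]):
#     logits = g * tf.matmul(embeddings, tf.nn.l2_normalize(w_fc, [0]))
#
# where `data_dependent_init_done` is an integer Variable used as a flag.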
def get_fc_vars_copy_ops(fc_weights, fc_bias, make_copies):
  """Gets copies of the classifier layer variables or returns those variables.

  At meta-test time, a copy is created for the given Variables, and these
  copies will be used in place of the original ones.

  Args:
    fc_weights: A Variable for the weights of the fc layer.
    fc_bias: A Variable for the bias of the fc layer.
    make_copies: A bool. Whether to copy the given variables. If not, those
      variables themselves are returned.

  Returns:
    fc_weights: A Variable for the weights of the fc layer. Might be the same
      as the input fc_weights or a copy of it.
    fc_bias: Analogously, a Variable for the bias of the fc layer.
    fc_vars_copy_ops: A (possibly empty) list of operations for assigning the
      value of each of fc_weights and fc_bias to a respective copy variable.
  """
  fc_vars_copy_ops = []
  if make_copies:
    with tf.variable_scope('weight_copy'):
      # fc_weights copy.
      fc_weights_copy = tf.Variable(
          tf.zeros(fc_weights.shape.as_list()),
          collections=[tf.GraphKeys.LOCAL_VARIABLES])
      fc_weights_copy_op = tf.assign(fc_weights_copy, fc_weights)
      fc_vars_copy_ops.append(fc_weights_copy_op)

      # fc_bias copy.
      fc_bias_copy = tf.Variable(
          tf.zeros(fc_bias.shape.as_list()),
          collections=[tf.GraphKeys.LOCAL_VARIABLES])
      fc_bias_copy_op = tf.assign(fc_bias_copy, fc_bias)
      fc_vars_copy_ops.append(fc_bias_copy_op)

      fc_weights = fc_weights_copy
      fc_bias = fc_bias_copy
  return fc_weights, fc_bias, fc_vars_copy_ops
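# Minimal usage sketch (hypothetical shapes and names, assuming a TF1 graph
# and session):
#
#   fc_weights = tf.Variable(tf.zeros([64, 10]), name='fc_w')
#   fc_bias = tf.Variable(tf.zeros([10]), name='fc_b')
#   fc_weights, fc_bias, copy_ops = get_fc_vars_copy_ops(
#       fc_weights, fc_bias, make_copies=True)
#   # At meta-test time, run `copy_ops` first so the local copies hold the
#   # trained values before any episode-specific finetuning modifies them.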
def _apply_grads(variables, grads):
  """Applies gradients using SGD on a list of variables."""
  v_new, update_ops = [], []
  for (v, dv) in zip(variables, grads):
    if (not allow_grads_to_batch_norm_vars and
        ('offset' in v.name or 'scale' in v.name)):
      updated_value = v  # No update.
    else:
      updated_value = v - learning_rate * dv  # Gradient descent update.
      if get_update_ops:
        update_ops.append(tf.assign(v, updated_value))
    v_new.append(updated_value)
  return v_new, update_ops
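# Sketch of this closure's contract (names such as `learning_rate`,
# `allow_grads_to_batch_norm_vars` and `get_update_ops` come from the
# enclosing function and are assumptions here):
#
#   grads = tf.gradients(loss, variables)
#   v_new, update_ops = _apply_grads(variables, grads)
#   # `v_new` holds the post-SGD tensors, which remain differentiable and can
#   # be backpropagated through (as in MAML); `update_ops` assigns them back
#   # into the Variables when `get_update_ops` is True.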
def get_embeddings_vars_copy_ops(embedding_vars_dict, make_copies):
  """Gets copies of the embedding variables or returns those variables.

  This is useful at meta-test time for MAML and the finetuning baseline.
  In particular, at meta-test time, we don't want to make permanent updates
  to the model's variables, but only modifications that persist in the given
  episode. This can be achieved by creating copies of each variable and
  modifying and using these copies instead of the variables themselves.

  Args:
    embedding_vars_dict: A dict mapping each variable name to the
      corresponding Variable.
    make_copies: A bool. Whether to copy the given variables. If not, those
      variables themselves will be returned. Typically, this is True at
      meta-test time and False at meta-training time.

  Returns:
    embedding_vars_keys: A list of variable names.
    embedding_vars: A corresponding list of Variables.
    embedding_vars_copy_ops: A (possibly empty) list of operations, each of
      which assigns the value of one of the provided Variables to a new
      Variable which is its copy.
  """
  embedding_vars_keys = []
  embedding_vars = []
  embedding_vars_copy_ops = []
  for name, var in six.iteritems(embedding_vars_dict):
    embedding_vars_keys.append(name)
    if make_copies:
      with tf.variable_scope('weight_copy'):
        shape = var.shape.as_list()
        var_copy = tf.Variable(
            tf.zeros(shape), collections=[tf.GraphKeys.LOCAL_VARIABLES])
        var_copy_op = tf.assign(var_copy, var)
        embedding_vars_copy_ops.append(var_copy_op)
      embedding_vars.append(var_copy)
    else:
      embedding_vars.append(var)
  return embedding_vars_keys, embedding_vars, embedding_vars_copy_ops
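# Usage sketch (hypothetical variable names, assuming a TF1 graph):
#
#   embedding_vars_dict = {'conv1/kernel': kernel_var, 'conv1/bias': bias_var}
#   keys, variables, copy_ops = get_embeddings_vars_copy_ops(
#       embedding_vars_dict, make_copies=True)
#   with tf.control_dependencies(copy_ops):
#     # Downstream ops now use the local copies in `variables`, so
#     # episode-specific updates never touch the original model variables.
#     ...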
def forward_pass(self, data):
  """Computes the test logits of MAML.

  Args:
    data: A `meta_dataset.providers.Episode` containing the data for the
      episode.

  Returns:
    The output logits for the query data in this episode.
  """
  # Have to use one-hot labels since sparse softmax doesn't allow
  # second derivatives.
  support_embeddings_ = self.embedding_fn(
      data.support_images, self.is_training, reuse=tf.AUTO_REUSE)
  support_embeddings = support_embeddings_['embeddings']
  embedding_vars_dict = support_embeddings_['params']

  # TODO(eringrant): Refactor to make use of
  # `functional_backbones.linear_classifier`, which allows Gin-configuration.
  with tf.variable_scope('linear_classifier', reuse=tf.AUTO_REUSE):
    embedding_depth = support_embeddings.shape.as_list()[-1]
    fc_weights = functional_backbones.weight_variable(
        [embedding_depth, self.logit_dim],
        weight_decay=self.classifier_weight_decay)
    fc_bias = functional_backbones.bias_variable([self.logit_dim])

  # A list of variable names, a list of corresponding Variables, and a list
  # of operations (possibly empty) that creates a copy of each Variable.
  (embedding_vars_keys, embedding_vars,
   embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
       embedding_vars_dict, make_copies=not self.is_training)

  # A Variable for the weights of the fc layer, a Variable for the bias of
  # the fc layer, and a list of operations (possibly empty) that copies them.
  (fc_weights, fc_bias, fc_vars_copy_ops) = get_fc_vars_copy_ops(
      fc_weights, fc_bias, make_copies=not self.is_training)

  fc_vars = [fc_weights, fc_bias]
  num_embedding_vars = len(embedding_vars)
  num_fc_vars = len(fc_vars)

  def _cond(step, *args):
    del args
    num_steps = self.num_update_steps
    if not self.is_training:
      num_steps += self.additional_evaluation_update_steps
    return step < num_steps

  def _body(step, *args):
    """The inner update loop body."""
    updated_embedding_vars = args[0:num_embedding_vars]
    updated_fc_vars = args[num_embedding_vars:num_embedding_vars +
                           num_fc_vars]
    support_embeddings = self.embedding_fn(
        data.support_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, updated_embedding_vars)),
        reuse=True)['embeddings']

    updated_fc_weights, updated_fc_bias = updated_fc_vars
    support_logits = tf.matmul(support_embeddings,
                               updated_fc_weights) + updated_fc_bias

    support_logits = support_logits[:, 0:data.way]
    loss = tf.losses.softmax_cross_entropy(data.onehot_support_labels,
                                           support_logits)

    print_op = tf.no_op()
    if self.debug_log:
      print_op = tf.print(['step: ', step, updated_fc_bias[0], 'loss:', loss])

    with tf.control_dependencies([print_op]):
      updated_embedding_vars = gradient_descent_step(
          loss, updated_embedding_vars, self.first_order,
          self.adapt_batch_norm, self.alpha, False)['updated_vars']
      updated_fc_vars = gradient_descent_step(
          loss, updated_fc_vars, self.first_order, self.adapt_batch_norm,
          self.alpha, False)['updated_vars']

      step = step + 1
    return tuple([step] + list(updated_embedding_vars) +
                 list(updated_fc_vars))

  # MAML meta updates using query set examples from an episode.
  if self.zero_fc_layer:
    # To account for variable class sizes, we initialize the output
    # weights to zero. See if truncated normal initialization will help.
    zero_weights_op = tf.assign(fc_weights, tf.zeros_like(fc_weights))
    zero_bias_op = tf.assign(fc_bias, tf.zeros_like(fc_bias))
    fc_vars_init_ops = [zero_weights_op, zero_bias_op]
  else:
    fc_vars_init_ops = fc_vars_copy_ops

  if self.proto_maml_fc_layer_init:
    support_embeddings = self.embedding_fn(
        data.support_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, embedding_vars)),
        reuse=True)['embeddings']

    prototypes = metric_learners.compute_prototypes(
        support_embeddings, data.onehot_support_labels)
    pmaml_fc_weights = self.proto_maml_fc_weights(
        prototypes, zero_pad_to_max_way=True)
    pmaml_fc_bias = self.proto_maml_fc_bias(
        prototypes, zero_pad_to_max_way=True)
    fc_vars = [pmaml_fc_weights, pmaml_fc_bias]

  # These control dependencies assign the value of each variable to a new
  # copy variable that corresponds to it. This is required at test time for
  # initializing the copies as they are used in place of the original vars.
  with tf.control_dependencies(fc_vars_init_ops + embedding_vars_copy_ops):
    # Make step a local variable as we don't want to save and restore it.
    step = tf.Variable(
        0,
        trainable=False,
        name='inner_step_counter',
        collections=[tf.GraphKeys.LOCAL_VARIABLES])
    loop_vars = [step] + embedding_vars + fc_vars
    step_and_all_updated_vars = tf.while_loop(
        _cond, _body, loop_vars, swap_memory=True)
    step = step_and_all_updated_vars[0]
    all_updated_vars = step_and_all_updated_vars[1:]
    updated_embedding_vars = all_updated_vars[0:num_embedding_vars]
    updated_fc_weights, updated_fc_bias = all_updated_vars[
        num_embedding_vars:num_embedding_vars + num_fc_vars]

  # Forward pass the training images with the updated weights in order to
  # compute the means and variances, to use for the query's batch norm.
  support_set_moments = None
  if not self.transductive_batch_norm:
    support_set_moments = self.embedding_fn(
        data.support_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, updated_embedding_vars)),
        reuse=True)['moments']

  query_embeddings = self.embedding_fn(
      data.query_images,
      self.is_training,
      params=collections.OrderedDict(
          zip(embedding_vars_keys, updated_embedding_vars)),
      moments=support_set_moments,  # Use support set stats for batch norm.
      reuse=True,
      backprop_through_moments=self.backprop_through_moments)['embeddings']

  query_logits = (tf.matmul(query_embeddings, updated_fc_weights) +
                  updated_fc_bias)[:, 0:data.way]

  return query_logits
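# Call sketch for `forward_pass` (assumes a constructed MAML learner and an
# episode provider; the names below are illustrative, not defined here):
#
#   episode = data_provider.get_episode()      # hypothetical provider
#   query_logits = maml_learner.forward_pass(episode)
#   meta_loss = tf.losses.softmax_cross_entropy(
#       episode.onehot_query_labels, query_logits)
#   # The outer (meta) optimizer minimizes `meta_loss`, differentiating
#   # through the inner-loop updates unless `first_order` is enabled.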
def bn(x,
       params=None,
       moments=None,
       backprop_through_moments=True,
       use_ema=False,
       is_training=True,
       ema_epsilon=.9):
  """Batch normalization.

  The usage should be as follows: If x is the support images, moments should
  be None so that they are computed from the support set examples. On the
  other hand, if x is the query images, the moments argument should be used
  in order to pass in the mean and var that were computed from the support
  set.

  Args:
    x: inputs.
    params: None or a dict containing the values of the offset and scale
      params.
    moments: None or a dict containing the values of the mean and var to use
      for batch normalization.
    backprop_through_moments: Whether to allow gradients to flow through the
      given support set moments. Only applies to non-transductive batch norm.
    use_ema: apply moving averages of batch norm statistics, or update them,
      depending on whether we are training or testing. Note that passing
      moments will override this setting, and result in neither updating nor
      using ema statistics. This is important to make sure that episodic
      learners don't update ema statistics a second time when processing
      queries.
    is_training: if use_ema=True, this determines whether to apply the moving
      averages, or update them.
    ema_epsilon: if updating moving averages, use this value for the
      exponential moving averages.

  Returns:
    output: The result of applying batch normalization to the input.
    params: The updated params.
    moments: The updated moments.
  """
  params_keys, params_vars, moments_keys, moments_vars = [], [], [], []

  with tf.variable_scope('batch_norm'):
    scope_name = tf.get_variable_scope().name

    if use_ema:
      ema_shape = [1, 1, 1, x.get_shape().as_list()[-1]]
      mean_ema = tf.get_variable(
          'mean_ema',
          shape=ema_shape,
          initializer=tf.initializers.zeros(),
          trainable=False)
      var_ema = tf.get_variable(
          'var_ema',
          shape=ema_shape,
          initializer=tf.initializers.ones(),
          trainable=False)

    if moments is not None:
      if backprop_through_moments:
        mean = moments[scope_name + '/mean']
        var = moments[scope_name + '/var']
      else:
        # This variant does not yield good results.
        mean = tf.stop_gradient(moments[scope_name + '/mean'])
        var = tf.stop_gradient(moments[scope_name + '/var'])
    elif use_ema and not is_training:
      mean = mean_ema
      var = var_ema
    else:
      # If not provided, compute the mean and var of the current batch.
      replica_ctx = tf.distribute.get_replica_context()
      if replica_ctx:
        # From third_party/tensorflow/python/keras/layers/normalization_v2.py.
        axes = list(range(len(x.shape) - 1))
        local_sum = tf.reduce_sum(x, axis=axes, keepdims=True)
        local_squared_sum = tf.reduce_sum(
            tf.square(x), axis=axes, keepdims=True)
        batch_size = tf.cast(tf.shape(x)[0], tf.float32)
        x_sum, x_squared_sum, global_batch_size = (
            replica_ctx.all_reduce('sum',
                                   [local_sum, local_squared_sum, batch_size]))

        axes_vals = [(tf.shape(x))[i] for i in range(1, len(axes))]
        multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
        multiplier = multiplier * global_batch_size

        mean = x_sum / multiplier
        x_squared_mean = x_squared_sum / multiplier
        # var = E(x^2) - E(x)^2
        var = x_squared_mean - tf.square(mean)
      else:
        mean, var = tf.nn.moments(
            x, axes=list(range(len(x.shape) - 1)), keep_dims=True)

    # Only update ema's if training and we computed the moments in the current
    # call. Note: at test time for episodic learners, ema's may be passed
    # from the support set to the query set, even if it's not really needed.
    if use_ema and is_training and moments is None:
      replica_ctx = tf.distribute.get_replica_context()
      mean_upd = tf.assign(mean_ema,
                           mean_ema * ema_epsilon + mean * (1.0 - ema_epsilon))
      var_upd = tf.assign(var_ema,
                          var_ema * ema_epsilon + var * (1.0 - ema_epsilon))
      updates = tf.group([mean_upd, var_upd])
      if replica_ctx:
        tf.add_to_collection(
            tf.GraphKeys.UPDATE_OPS,
            tf.cond(
                tf.equal(replica_ctx.replica_id_in_sync_group, 0),
                lambda: updates, tf.no_op))
      else:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, updates)

    moments_keys += [scope_name + '/mean']
    moments_vars += [mean]
    moments_keys += [scope_name + '/var']
    moments_vars += [var]

    if params is None:
      offset = tf.get_variable(
          'offset',
          shape=mean.get_shape().as_list(),
          initializer=tf.initializers.zeros())
      scale = tf.get_variable(
          'scale',
          shape=var.get_shape().as_list(),
          initializer=tf.initializers.ones())
    else:
      offset = params[scope_name + '/offset']
      scale = params[scope_name + '/scale']

    params_keys += [scope_name + '/offset']
    params_vars += [offset]
    params_keys += [scope_name + '/scale']
    params_vars += [scale]

    output = tf.nn.batch_normalization(x, mean, var, offset, scale, 0.00001)

    params = collections.OrderedDict(zip(params_keys, params_vars))
    moments = collections.OrderedDict(zip(moments_keys, moments_vars))

    return output, params, moments
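# Usage sketch for the support/query convention described in the docstring
# (the tensor names are hypothetical):
#
#   support_out, bn_params, support_moments = bn(
#       support_feats, params=None, moments=None, is_training=True)
#   query_out, _, _ = bn(
#       query_feats, params=bn_params, moments=support_moments,
#       backprop_through_moments=True, is_training=False)
#   # The query pass normalizes with the support-set statistics, which is the
#   # non-transductive episodic batch-norm setting.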