def linear_classifier(embeddings, num_classes, cosine_classifier,
                      cosine_logits_multiplier, use_weight_norm, weight_decay):
  """Forward pass through a linear classifier, or possibly a cosine classifier.

  Args:
    embeddings: A Tensor of size [batch size, embedding dim].
    num_classes: An integer; the dimension of the classification.
    cosine_classifier: A bool. If true, a cosine classifier is used, which does
      not require a bias.
    cosine_logits_multiplier: A float. Only used if cosine_classifier is True,
      and multiplies the resulting logits.
    use_weight_norm: A bool. Whether weight norm was used. If so, then if using
      cosine classifier, normalize only the embeddings but not the weights.
    weight_decay: A float; the scalar multiple on the L2 regularization of the
      weight matrix.

  Returns:
    logits: A Tensor of size [batch size, num outputs].
  """
  embedding_dims = embeddings.get_shape().as_list()[-1]

  if use_weight_norm:
    # A variable to keep track of whether the initialization has already
    # happened.
    data_dependent_init_done = tf.get_variable(
        'data_dependent_init_done',
        initializer=0,
        dtype=tf.int32,
        trainable=False)

    w_fc = tf.get_variable(
        'w_fc', [embedding_dims, num_classes],
        initializer=tf.random_normal_initializer(0, 0.05),
        trainable=True)
    # This init is temporary as it needs to be done in a data-dependent way.
    # It will be overwritten during the first forward pass through this layer.
    g = tf.get_variable(
        'g',
        dtype=tf.float32,
        initializer=tf.ones([num_classes]),
        trainable=True)
    b_fc = None
    if not cosine_classifier:
      # Also initialize a bias.
      b_fc = tf.get_variable(
          'b_fc', initializer=tf.zeros([num_classes]), trainable=True)

    def _do_data_dependent_init():
      """Returns ops for the data-dependent init of g and maybe b_fc."""
      w_fc_normalized = tf.nn.l2_normalize(w_fc.read_value(), [0])
      output_init = tf.matmul(embeddings, w_fc_normalized)
      mean_init, var_init = tf.nn.moments(output_init, [0])
      # Data-dependent init values.
      g_init_value = 1. / tf.sqrt(var_init + 1e-10)
      ops = [tf.assign(g, g_init_value)]
      if not cosine_classifier:
        # Also initialize a bias in a data-dependent way.
        b_fc_init_value = -mean_init * g_init_value
        ops.append(tf.assign(b_fc, b_fc_init_value))
      # Mark that the data-dependent initialization is done to prevent it from
      # happening again in the future.
      ops.append(tf.assign(data_dependent_init_done, 1))
      return tf.group(*ops)

    # Possibly perform data-dependent init (if it hasn't been done already).
    init_op = tf.cond(
        tf.equal(data_dependent_init_done, 0), _do_data_dependent_init,
        tf.no_op)

    with tf.control_dependencies([init_op]):
      # Apply weight normalization.
      w_fc *= g / tf.sqrt(tf.reduce_sum(tf.square(w_fc), [0]))
      # Forward pass through the layer defined by w_fc and b_fc.
      logits = linear_classifier_forward_pass(embeddings, w_fc, b_fc,
                                              cosine_classifier,
                                              cosine_logits_multiplier, True)
  else:
    # No weight norm.
    w_fc = functional_backbones.weight_variable(
        [embedding_dims, num_classes], weight_decay=weight_decay)
    b_fc = None
    if not cosine_classifier:
      # Also initialize a bias.
      b_fc = functional_backbones.bias_variable([num_classes])
    # Forward pass through the layer defined by w_fc and b_fc.
    logits = linear_classifier_forward_pass(embeddings, w_fc, b_fc,
                                            cosine_classifier,
                                            cosine_logits_multiplier, False)
  return logits
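

# NOTE: Illustrative sketch, not part of the original module. It shows one
# plausible way to attach `linear_classifier` as a cosine-classifier head on a
# batch of embeddings. The scope name, the logits multiplier value, and the
# weight-decay value below are assumptions, not values taken from the source.
def _example_cosine_classifier_head(embeddings, num_classes):
  """Builds an (assumed) cosine-classifier head on top of `embeddings`."""
  with tf.variable_scope('example_classifier_head', reuse=tf.AUTO_REUSE):
    return linear_classifier(
        embeddings=embeddings,
        num_classes=num_classes,
        cosine_classifier=True,  # Cosine path: no bias variable is created.
        cosine_logits_multiplier=10.0,  # Assumed scaling of the cosine logits.
        use_weight_norm=False,  # Takes the plain weight_variable branch above.
        weight_decay=0.01)  # Assumed L2 coefficient on the weight matrix.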
def forward_pass(self, data):
  """Computes the test logits of MAML.

  Args:
    data: A `meta_dataset.providers.Episode` containing the data for the
      episode.

  Returns:
    The output logits for the query data in this episode.
  """
  # Have to use one-hot labels since sparse softmax doesn't allow
  # second derivatives.
  support_embeddings_ = self.embedding_fn(
      data.support_images, self.is_training, reuse=tf.AUTO_REUSE)
  support_embeddings = support_embeddings_['embeddings']
  embedding_vars_dict = support_embeddings_['params']

  # TODO(eringrant): Refactor to make use of
  # `functional_backbones.linear_classifier`, which allows Gin-configuration.
  with tf.variable_scope('linear_classifier', reuse=tf.AUTO_REUSE):
    embedding_depth = support_embeddings.shape.as_list()[-1]
    fc_weights = functional_backbones.weight_variable(
        [embedding_depth, self.logit_dim],
        weight_decay=self.classifier_weight_decay)
    fc_bias = functional_backbones.bias_variable([self.logit_dim])

  # A list of variable names, a list of corresponding Variables, and a list
  # of operations (possibly empty) that creates a copy of each Variable.
  (embedding_vars_keys, embedding_vars,
   embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
       embedding_vars_dict, make_copies=not self.is_training)

  # A Variable for the weights of the fc layer, a Variable for the bias of the
  # fc layer, and a list of operations (possibly empty) that copies them.
  (fc_weights, fc_bias, fc_vars_copy_ops) = get_fc_vars_copy_ops(
      fc_weights, fc_bias, make_copies=not self.is_training)

  fc_vars = [fc_weights, fc_bias]
  num_embedding_vars = len(embedding_vars)
  num_fc_vars = len(fc_vars)

  def _cond(step, *args):
    del args
    num_steps = self.num_update_steps
    if not self.is_training:
      num_steps += self.additional_evaluation_update_steps
    return step < num_steps

  def _body(step, *args):
    """The inner update loop body."""
    updated_embedding_vars = args[0:num_embedding_vars]
    updated_fc_vars = args[num_embedding_vars:num_embedding_vars +
                           num_fc_vars]
    support_embeddings = self.embedding_fn(
        data.support_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, updated_embedding_vars)),
        reuse=True)['embeddings']

    updated_fc_weights, updated_fc_bias = updated_fc_vars
    support_logits = tf.matmul(support_embeddings,
                               updated_fc_weights) + updated_fc_bias

    support_logits = support_logits[:, 0:data.way]
    loss = tf.losses.softmax_cross_entropy(data.onehot_support_labels,
                                           support_logits)

    print_op = tf.no_op()
    if self.debug_log:
      print_op = tf.print(['step: ', step, updated_fc_bias[0], 'loss:', loss])

    with tf.control_dependencies([print_op]):
      updated_embedding_vars = gradient_descent_step(
          loss, updated_embedding_vars, self.first_order,
          self.adapt_batch_norm, self.alpha, False)['updated_vars']
      updated_fc_vars = gradient_descent_step(loss, updated_fc_vars,
                                              self.first_order,
                                              self.adapt_batch_norm,
                                              self.alpha,
                                              False)['updated_vars']

      step = step + 1
    return tuple([step] + list(updated_embedding_vars) +
                 list(updated_fc_vars))

  # MAML meta updates using query set examples from an episode.
  if self.zero_fc_layer:
    # To account for variable class sizes, we initialize the output
    # weights to zero. See if truncated normal initialization will help.
    zero_weights_op = tf.assign(fc_weights, tf.zeros_like(fc_weights))
    zero_bias_op = tf.assign(fc_bias, tf.zeros_like(fc_bias))
    fc_vars_init_ops = [zero_weights_op, zero_bias_op]
  else:
    fc_vars_init_ops = fc_vars_copy_ops

  if self.proto_maml_fc_layer_init:
    support_embeddings = self.embedding_fn(
        data.support_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, embedding_vars)),
        reuse=True)['embeddings']

    prototypes = metric_learners.compute_prototypes(
        support_embeddings, data.onehot_support_labels)
    pmaml_fc_weights = self.proto_maml_fc_weights(
        prototypes, zero_pad_to_max_way=True)
    pmaml_fc_bias = self.proto_maml_fc_bias(
        prototypes, zero_pad_to_max_way=True)
    fc_vars = [pmaml_fc_weights, pmaml_fc_bias]

  # These control dependencies assign the value of each variable to a new copy
  # variable that corresponds to it. This is required at test time for
  # initializing the copies as they are used in place of the original vars.
  with tf.control_dependencies(fc_vars_init_ops + embedding_vars_copy_ops):
    # Make step a local variable as we don't want to save and restore it.
    step = tf.Variable(
        0,
        trainable=False,
        name='inner_step_counter',
        collections=[tf.GraphKeys.LOCAL_VARIABLES])
    loop_vars = [step] + embedding_vars + fc_vars
    step_and_all_updated_vars = tf.while_loop(
        _cond, _body, loop_vars, swap_memory=True)
    step = step_and_all_updated_vars[0]
    all_updated_vars = step_and_all_updated_vars[1:]
    updated_embedding_vars = all_updated_vars[0:num_embedding_vars]
    updated_fc_weights, updated_fc_bias = all_updated_vars[
        num_embedding_vars:num_embedding_vars + num_fc_vars]

  # Forward pass the training images with the updated weights in order to
  # compute the means and variances, to use for the query's batch norm.
  support_set_moments = None
  if not self.transductive_batch_norm:
    support_set_moments = self.embedding_fn(
        data.support_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, updated_embedding_vars)),
        reuse=True)['moments']

  query_embeddings = self.embedding_fn(
      data.query_images,
      self.is_training,
      params=collections.OrderedDict(
          zip(embedding_vars_keys, updated_embedding_vars)),
      moments=support_set_moments,  # Use support set stats for batch norm.
      reuse=True,
      backprop_through_moments=self.backprop_through_moments)['embeddings']

  query_logits = (tf.matmul(query_embeddings, updated_fc_weights) +
                  updated_fc_bias)[:, 0:data.way]

  return query_logits
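

# NOTE: Illustrative sketch, not part of the original class. Assuming `learner`
# is an instance exposing `forward_pass` and `episode` is a
# `meta_dataset.providers.Episode` that also carries `onehot_query_labels`
# (mirroring the `onehot_support_labels` used above -- an assumption here),
# the query logits could be turned into an episode loss and accuracy roughly
# as follows.
def _example_episode_metrics(learner, episode):
  """Computes an (assumed) query-set loss and accuracy from `forward_pass`."""
  query_logits = learner.forward_pass(episode)
  loss = tf.losses.softmax_cross_entropy(episode.onehot_query_labels,
                                         query_logits)
  predictions = tf.argmax(query_logits, axis=-1)
  labels = tf.argmax(episode.onehot_query_labels, axis=-1)
  accuracy = tf.reduce_mean(tf.cast(tf.equal(predictions, labels), tf.float32))
  return loss, accuracy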