def grad_fn(dy):
    """Compute gradients one query at a time inside a `tf.while_loop`.

    Each iteration slices a single query out of `query_queries` /
    `query_values`, recomputes `self._get_dist` for that slice, and
    differentiates it with respect to the slice and the support tensors. The
    loop runs with `parallel_iterations=1`.

    Args:
      dy: Upstream gradient; column `idx` of `dy` is used as `grad_ys` for the
        query at position `idx`.

    Returns:
      A tuple `(query_queries_grad, query_values_grad, support_keys_grad,
      support_values_grad, None)`.
    """
    keys_snapshot = tf.identity(support_keys)
    values_snapshot = tf.identity(support_values)
    grad_dtype = dy.dtype

    def _empty_rows(reference):
        """Zeros shaped like one row of `reference`, sliced to `zero_dim` rows."""
        # NOTE(review): `zero_dim` comes from the enclosing scope; presumably
        # it is 0 so the accumulator starts with no rows -- confirm at caller.
        single_row = tf.zeros(tf.shape(reference)[1:], dtype=grad_dtype)
        return single_row[tf.newaxis, :][:zero_dim]

    # Loop state: (query index, per-query grads so far, summed support grads).
    start_state = (
        0,
        _empty_rows(query_queries),
        _empty_rows(query_values),
        tf.zeros(tf.shape(keys_snapshot), dtype=grad_dtype),
        tf.zeros(tf.shape(values_snapshot), dtype=grad_dtype),
    )

    def _has_more_queries(step, *unused_accumulators):
        """Continue while `step` is below the number of queries."""
        return step < tf.shape(query_queries)[0]

    def _accumulate(step, queries_acc, values_acc, keys_acc, support_acc):
        """Compute gradients for a single query and fold them into the state."""
        one_query = query_queries[step:step + 1]
        one_value = query_values[step:step + 1]
        dist = self._get_dist(one_query, one_value, keys_snapshot,
                              values_snapshot, labels)
        d_query, d_value, d_keys, d_support = tf.gradients(
            dist,
            [one_query, one_value, keys_snapshot, values_snapshot],
            grad_ys=dy[:, step:step + 1])
        # Query gradients are stacked row by row; support gradients are summed.
        queries_acc = tf.concat([queries_acc, d_query], axis=0)
        values_acc = tf.concat([values_acc, d_value], axis=0)
        keys_acc = keys_acc + d_keys
        support_acc = support_acc + d_support
        return (step + 1, queries_acc, values_acc, keys_acc, support_acc)

    loop_result = tf.while_loop(
        _has_more_queries, _accumulate, start_state, parallel_iterations=1)
    # Drop the loop counter and append `None` as the final gradient entry.
    return loop_result[1:] + (None,)
def detailed_forward_pass(self, data):
    """Returns all information from a forward pass of the `OptimizationLearner`.

    The method builds the episodic initialization ops, optionally runs an
    expectation-maximization inner loop and a gradient-based inner loop on the
    task parameters, and then evaluates the support and query sets both before
    and after adaptation.

    Args:
      data: A `meta_dataset.providers.Episode` containing the data for the
        episode.

    Returns:
      An `Adaptation` holding the pre-/post-adaptation `ForwardPass` results on
      the support and query sets, the objective closures (`objective_fn`,
      `support_module_objective_fn`, `query_module_objective_fn`), the
      `forward_pass_fn` helper, and the initial and final variable mappings.
    """
    # Loop initialization.
    init_loop_variables = self.task_parameters
    init_loop_variable_refs = [
        v.experimental_ref() for v in init_loop_variables
    ]

    # Construct ops for data-dependent episodic initialization.
    episodic_init_ops = self.episodic_init_ops(
        labels=data.support_labels,
        embeddings=self.embedding_fn(data.support_images, training=True),
        task_parameters=init_loop_variables,
    )

    def _forward_pass(iteration_idx_, variables_mapping_, images_,
                      onehot_labels_):
        """Helper function to compute the outputs of a forward pass."""
        with self.embedding_fn.reparameterize(variables_mapping_):
            # TODO(eringrant): Implement non-transductive batch normalization
            # (i.e., pass the support set statistics through the query set
            # forward pass.
            embeddings_ = self.embedding_fn(images_, training=True)

        # TODO(eringrant): `head_fn` is an attribute of the subclass.
        with self.head_fn.reparameterize(variables_mapping_):
            # Keep only the first `data.way` columns of the head output.
            predictions_ = self.head_fn(embeddings_)[:, :data.way]

        accuracy_ = tf.reduce_mean(input_tensor=self.compute_accuracy(
            onehot_labels=onehot_labels_, predictions=predictions_))

        inner_objective_ = self.inner_objective(
            onehot_labels=onehot_labels_,
            predictions=predictions_,
            iteration_idx=iteration_idx_)

        outer_objective_ = self.outer_objective(
            onehot_labels=onehot_labels_,
            predictions=predictions_,
        )

        return ForwardPass(
            embeddings=embeddings_,
            predictions=predictions_,
            inner_objective_value=inner_objective_,
            outer_objective_value=outer_objective_,
            accuracy=accuracy_,
        )

    def _objective_fn(loop_variables_, iteration_idx_):
        """Evaluate the support set objective given `loop_variables_`."""
        # Map each initial variable reference to its current loop value.
        loop_variables_mapping_ = dict(
            zip(init_loop_variable_refs, loop_variables_))

        adaptation_support_results = _forward_pass(
            iteration_idx_=iteration_idx_,
            variables_mapping_=loop_variables_mapping_,
            images_=data.support_images,
            onehot_labels_=data.onehot_support_labels)

        return adaptation_support_results.inner_objective_value

    def _e_step(loop_variables_):
        """Evaluate expectations given `loop_variables_`."""
        # Map each initial variable reference to its current loop value.
        loop_variables_dict_ = dict(
            zip(init_loop_variable_refs, loop_variables_))

        with self.embedding_fn.reparameterize(loop_variables_dict_):
            # TODO(eringrant): training to True for normalization with batch
            # stats. Figure out the appropriate way to pass this around.
            train_embeddings_ = self.embedding_fn(
                data.train_images, training=True)

        class_embeddings_ = learner_base.class_specific_data(
            data.onehot_train_labels, train_embeddings_, self.logit_dim)

        def _compute_responsibilities(examples_, class_idx):
            train_predictions_ = tf.squeeze(
                self.head_fn(
                    embeddings=examples_,
                    components=True,
                    class_idx=[class_idx]),
                axis=1)
            return tf.nn.softmax(train_predictions_, axis=-1)

        with self.head_fn.reparameterize(loop_variables_dict_):
            class_responsibilities_ = [
                _compute_responsibilities(embeddings_, class_idx=i)
                for i, embeddings_ in enumerate(class_embeddings_)
            ]

        return class_embeddings_, class_responsibilities_

    def _m_step(preupdate_vars, all_embeddings_, all_responsibilities_):
        """Compute parameter estimates from embeddings and responsibilities."""
        means, log_scales, logits = zip(
            *map(reparameterizable_distributions.fit_gaussian_mixture,
                 all_embeddings_, all_responsibilities_,
                 itertools.repeat(self.head_fn.damping)))

        def flatten(x):
            return list(itertools.chain.from_iterable(x))

        means = flatten(means)
        log_scales = flatten(log_scales)
        logits = flatten(logits)

        # Parameters that are not being estimated are marked with `None` so
        # that they fall back to their pre-update value below.
        if not self.head_fn.estimate_loc:
            means = [None] * len(means)
        if not self.head_fn.estimate_scale:
            log_scales = [None] * len(log_scales)
        if not self.head_fn.estimate_logits:
            logits = [None] * len(logits)

        updated_vars = means + log_scales + logits

        # Replace constant variables.
        # TODO(eringrant): This interface differs from just excluding these
        # variables from `task_variables`.
        # TODO(eringrant): This assumes an ordering of mean, log_scales,
        # mixing_logits.
        return [
            preupdate_var if updated_var is None else updated_var
            for preupdate_var, updated_var in zip(preupdate_vars,
                                                  updated_vars)
        ]

    # Loop body.
    with tf.control_dependencies(episodic_init_ops):
        # Bind `loop_variables` up front so that every adaptation path below
        # has a defined starting point.
        loop_variables = init_loop_variables

        # Inner loop of expectation maximization. `num_em_steps` is optional
        # on the learner, hence the builtin `getattr` with a default of 0.
        num_em_steps = getattr(self, 'num_em_steps', 0)
        if num_em_steps > 0:
            loop_variables = em_loop(
                num_updates=num_em_steps,
                e_step=_e_step,
                m_step=_m_step,
                variables=loop_variables)

        # Inner loop of gradient-based optimization.
        num_optimizer_steps = (
            self.num_update_steps +
            (self.additional_evaluation_update_steps
             if not self.is_training else 0))
        if num_optimizer_steps > 0:
            # When no EM steps ran, `loop_variables` is still
            # `init_loop_variables`; otherwise optimization continues from
            # the EM estimates.
            # pylint: disable=no-value-for-parameter
            loop_variables = optimizer_loop(
                num_updates=num_optimizer_steps,
                objective_fn=_objective_fn,
                update_fn=self.update_fn,
                variables=loop_variables,
                first_order=self.first_order,
                clip_grad_norm=self.clip_grad_norm,
            )
            # pylint: enable=no-value-for-parameter

        # If no inner loop adaptation is performed, ensure the episodic
        # initialization is still part of the graph via a control dependency.
        if num_optimizer_steps + num_em_steps == 0:
            loop_variables = [tf.identity(v) for v in init_loop_variables]

        # Whatever path was taken, its result is the adapted parameter set.
        final_loop_variables = loop_variables

    # Get variable references to use when remapping the loop_variables.
    init_loop_variables_mapping = dict(
        zip(init_loop_variable_refs, init_loop_variables))
    final_loop_variables_mapping = dict(
        zip(init_loop_variable_refs, final_loop_variables))

    # Collect statistics about the inner optimization.
    with tf.compat.v1.name_scope('pre-adaptation'):
        with tf.compat.v1.name_scope('support'):
            pre_adaptation_support_results = _forward_pass(
                iteration_idx_=0,
                variables_mapping_=init_loop_variables_mapping,
                images_=data.support_images,
                onehot_labels_=data.onehot_support_labels)

        with tf.compat.v1.name_scope('query'):
            pre_adaptation_query_results = _forward_pass(
                iteration_idx_=0,
                variables_mapping_=init_loop_variables_mapping,
                images_=data.query_images,
                onehot_labels_=data.onehot_query_labels)

    with tf.compat.v1.name_scope('post-adaptation'):
        with tf.compat.v1.name_scope('support'):
            post_adaptation_support_results = _forward_pass(
                iteration_idx_=num_optimizer_steps,
                variables_mapping_=final_loop_variables_mapping,
                images_=data.support_images,
                onehot_labels_=data.onehot_support_labels,
            )

        with tf.compat.v1.name_scope('query'):
            post_adaptation_query_results = _forward_pass(
                iteration_idx_=num_optimizer_steps,
                variables_mapping_=final_loop_variables_mapping,
                images_=data.query_images,
                onehot_labels_=data.onehot_query_labels,
            )

    def _support_module_objective_fn(module_variables_, module_variable_refs_):
        """Evaluate the support set objective given `module_variables_`."""
        # Use the values of the parameters at convergence as the default value.
        variables_mapping_ = final_loop_variables_mapping.copy()

        # Loop over and replace the module-specific variables.
        for module_variable_ref, module_variable in zip(
                module_variable_refs_, module_variables_):
            variables_mapping_[module_variable_ref] = module_variable

        adaptation_support_results = _forward_pass(
            iteration_idx_=num_optimizer_steps,
            variables_mapping_=variables_mapping_,
            images_=data.support_images,
            onehot_labels_=data.onehot_support_labels,
        )

        return adaptation_support_results.inner_objective_value

    def _query_module_objective_fn(module_variables_, module_variable_refs_):
        """Evaluate the query set objective given `module_variables_`."""
        # Use the values of the parameters at convergence as the default value.
        variables_mapping_ = final_loop_variables_mapping.copy()

        # Loop over and replace the module-specific variables.
        for module_variable_ref, module_variable in zip(
                module_variable_refs_, module_variables_):
            variables_mapping_[module_variable_ref] = module_variable

        adaptation_query_results = _forward_pass(
            iteration_idx_=num_optimizer_steps,
            variables_mapping_=variables_mapping_,
            images_=data.query_images,
            onehot_labels_=data.onehot_query_labels)

        return adaptation_query_results.inner_objective_value

    return Adaptation(
        pre_adaptation_support_results=pre_adaptation_support_results,
        post_adaptation_support_results=post_adaptation_support_results,
        pre_adaptation_query_results=pre_adaptation_query_results,
        post_adaptation_query_results=post_adaptation_query_results,
        objective_fn=_objective_fn,
        support_module_objective_fn=_support_module_objective_fn,
        query_module_objective_fn=_query_module_objective_fn,
        forward_pass_fn=_forward_pass,
        init_loop_variables_mapping=init_loop_variables_mapping,
        final_loop_variables_mapping=final_loop_variables_mapping,
    )
def forward_pass(self, data):
    """Computes the query logits for the given episode `data`.

    Sets `self.film_selector` according to `self.film_init`, then chains
    `self.num_steps` training steps through control dependencies, and finally
    computes the losses with `compute_on_query=True` once the last step has
    run.

    Args:
      data: The episode; `data.support_images` is read here and `data` is
        forwarded to `self._compute_losses` and `self._get_train_op`.

    Returns:
      The query logits tensor, wrapped in `tf.identity` so that it depends on
      the final (optional) debug print op.
    """
    init_mode = self.film_init
    if init_mode == 'scratch':
        self.film_selector = None
    elif init_mode == 'imagenet':
        # Note: this makes the assumption that the first set of learned FiLM
        # parameters corresponds to the ImageNet dataset. Otherwise, the
        # following line should be changed appropriately.
        self.film_selector = 0
    elif init_mode in ['blender', 'blender_hard']:
        dataset_logits = functional_backbones.dataset_classifier(
            data.support_images)
        if init_mode == 'blender_hard':
            # Hard selection: one-hot on the argmax entry only.
            self.film_selector = tf.one_hot(
                tf.math.argmax(dataset_logits, axis=-1),
                depth=tf.shape(dataset_logits)[1])
        else:
            # Soft selection: a convex combination via softmax.
            self.film_selector = tf.nn.softmax(dataset_logits, axis=-1)

    if self.num_steps:
        # Initial forward pass, required for the `unused_op` below and for
        # placing variables in tf.trainable_variables() for the selection
        # below to pick up.
        loss = self._compute_losses(data, compute_on_query=False)['loss']

        # Pick out the variables to optimize.
        self.opt_vars = [
            candidate for candidate in tf.trainable_variables()
            if '_for_film_learner' in candidate.name
        ]
        tf.logging.info('FiLMLearner will optimize vars: {}'.format(
            self.opt_vars))

    def _train_step_outputs():
        """Builds one train op and returns the tensors tracked across steps."""
        step_results = self._get_train_op(data)
        return (step_results['train_op'], step_results['loss'],
                step_results['query_loss'], step_results['acc'],
                step_results['query_acc'])

    for i in range(self.num_steps):
        if i == 0:
            # Re-initialize the variables to optimize for the new episode, to
            # ensure the FiLM parameters aren't re-used across tasks of a
            # given dataset.
            vars_reset = tf.variables_initializer(var_list=self.opt_vars)
            # Adam related variables are created when minimize() is called.
            # We create an unused op here to put all Adam variables under the
            # 'adam_opt' namescope and create a reset op to reinitialize these
            # variables before the first finetune step.
            with tf.variable_scope('adam_opt', reuse=tf.AUTO_REUSE):
                unused_op = self.opt.minimize(loss, var_list=self.opt_vars)
            adam_reset = tf.variables_initializer(self.opt.variables())

            step_deps = [vars_reset, adam_reset, loss] + self.opt_vars
            with tf.control_dependencies(step_deps):
                print_op = tf.no_op()
                if self.debug_log:
                    print_op = tf.print([
                        'step: %d' % i, self.opt_vars[0][0], 'loss:', loss
                    ], summarize=-1)
        else:
            # Later steps wait on the previous step's train op and metrics.
            step_deps = ([train_op, loss, acc] + self.opt_vars +
                         [query_loss, query_acc] * int(self.debug_log))
            with tf.control_dependencies(step_deps):
                print_op = tf.no_op()
                if self.debug_log:
                    print_list = [
                        '################',
                        'step: %d' % i,
                        self.opt_vars[0][0],
                        'support loss:',
                        loss,
                        'query loss:',
                        query_loss,
                        'support acc:',
                        acc,
                        'query acc:',
                        query_acc,
                    ]
                    print_op = tf.print(print_list)

        with tf.control_dependencies([print_op]):
            # Get the train op (the loss is returned just for printing).
            train_op, loss, query_loss, acc, query_acc = _train_step_outputs()

    # Training is now over, compute the final query logits.
    if self.num_steps:
        final_deps = [train_op] + self.opt_vars
    else:
        final_deps = []
    with tf.control_dependencies(final_deps):
        final_results = self._compute_losses(data, compute_on_query=True)
        loss = final_results['loss']
        query_loss = final_results['query_loss']
        query_logits = final_results['query_logits']
        acc = final_results['acc']
        query_acc = final_results['query_acc']

        print_op = tf.no_op()
        if self.debug_log:
            print_op = tf.print([
                'Done training',
                'support loss:',
                loss,
                'query loss:',
                query_loss,
                'support acc:',
                acc,
                'query acc:',
                query_acc,
            ])
        with tf.control_dependencies([print_op]):
            query_logits = tf.identity(query_logits)

    return query_logits
def get_updated_global_step(self):
    """Returns the global step tensor, read after `self.train_op` has run.

    The read is wrapped in `tf.identity` under a control dependency on
    `self.train_op`, so evaluating the result also triggers the train op.
    """
    step_tensor_deps = [self.train_op]
    with tf.control_dependencies(step_tensor_deps):
        return tf.identity(tf.train.get_global_step())
def get_train_op(loss,
                 learning_rate=0.001,
                 lr_decay_steps=10000,
                 lr_decay_rate=0.98,
                 gradient_clip_norm=3.0,
                 use_tpu=True,
                 variables=None):
    """Get training operation with gradient clipping and learning rate decay.

    Distilled from tf.contrib.layers.optimize_loss().

    Args:
      loss: Scalar tensor of the loss function.
      learning_rate: Scalar initial learning rate.
      lr_decay_steps: Exponential decay timescale.
      lr_decay_rate: Exponential decay magnitude.
      gradient_clip_norm: Global norm by which to scale gradients. Clipping is
        applied only when this is a Python `float`; any other type skips it.
      use_tpu: Use tpu for training (wraps the optimizer in
        `tf.tpu.CrossShardOptimizer`).
      variables: List of variables to optimize. tf.trainable_variables() if
        None.

    Returns:
      train_op: An identity of the loss tensor that carries a control
        dependency on the gradient update, so evaluating it runs one
        iteration of training.
    """
    step_counter = tf.train.get_or_create_global_step()

    with tf.variable_scope('training', values=[loss, step_counter]):
        # Make sure update ops run before computing loss.
        pending_updates = list(set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))
        with tf.control_dependencies(pending_updates):
            loss = tf.identity(loss)

        # Learning rate variable, decayed in staircase fashion.
        base_lr = tf.get_variable(
            'learning_rate', [],
            trainable=False,
            initializer=tf.constant_initializer(learning_rate))
        decayed_lr = tf.train.exponential_decay(
            base_lr,
            step_counter,
            decay_steps=lr_decay_steps,
            decay_rate=lr_decay_rate,
            staircase=True)

        # Optimizer.
        optimizer = tf.train.AdamOptimizer(decayed_lr)
        if use_tpu:
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        # All trainable variables, if specific variables are not specified.
        target_variables = (
            tf.trainable_variables() if variables is None else variables)

        # Compute gradients.
        gradients = optimizer.compute_gradients(
            loss, target_variables, colocate_gradients_with_ops=False)

        # Optionally clip gradients by global norm.
        if isinstance(gradient_clip_norm, float):
            gradients = _clip_gradients_by_norm(gradients, gradient_clip_norm)

        # Create gradient updates.
        apply_updates = optimizer.apply_gradients(
            gradients, global_step=step_counter, name='train')

        # Ensure the returned tensor computes the gradient updates.
        with tf.control_dependencies([apply_updates]):
            return tf.identity(loss)