def _train_op(self):
    """Creates the discriminator / generator training ops and summaries.

    Sets:
      self.d_train: one plain gradient-descent step (learning rate `self.d_lr`)
        on the trainable variables under `<self.scope>/discriminator`.
      self.g_train: one Adam step (learning rate `self.g_lr`) on the trainable
        variables under `<self.scope>/generator`.
      self.train: `tf.cond` on `self.flag` that runs only the generator update
        when the flag is true and only the discriminator update otherwise.

    Every update op increments `self._global_step`. Gradient norms, both losses
    and the fake/real distance are registered in `self._summary_dict`.
    """
    with tf.name_scope("train_op"):
        # ---- Discriminator: plain SGD. ----
        d_opt = tf.train.GradientDescentOptimizer(self.d_lr)
        d_vars = tf.trainable_variables(self.scope + "/discriminator")
        # NOTE(review): the clipped gradients returned by clip_grads are not
        # applied; only the returned norm is used (for summaries). Confirm this
        # is intended.
        _, d_norm = clip_grads(self.d_loss, d_vars)
        d_grads = d_opt.compute_gradients(self.d_loss, var_list=d_vars)
        # compute_gradients + apply_gradients is what minimize() does.
        self.d_train = d_opt.apply_gradients(
            d_grads, global_step=self._global_step)

        # ---- Generator: Adam. ----
        g_opt = tf.train.AdamOptimizer(self.g_lr)
        g_vars = tf.trainable_variables(self.scope + "/generator")
        _, g_norm = clip_grads(self.g_loss, g_vars)
        g_grads = g_opt.compute_gradients(self.g_loss, var_list=g_vars)
        # Building this standalone update outside any control-flow construct
        # also creates Adam's slot variables outside tf.cond.
        self.g_train = g_opt.apply_gradients(
            g_grads, global_step=self._global_step)

        # Ops created outside tf.cond's branch callables run regardless of the
        # predicate. Returning the pre-built self.g_train / self.d_train from
        # the lambdas would therefore run BOTH updates (and bump the global
        # step twice) on every step. Each branch builds its own update inside
        # its callable so that only the selected one runs. The gradients
        # computed above are reused, and the Adam branch reuses the slots
        # created for self.g_train.
        self.train = tf.cond(
            self.flag,
            lambda: g_opt.apply_gradients(
                g_grads, global_step=self._global_step),
            lambda: d_opt.apply_gradients(
                d_grads, global_step=self._global_step))

        self._summary_dict.update({
            "distance": self._gen_norm(self.x_fake, self.y),
            "g_norm": g_norm,
            "d_norm": d_norm,
            "g_loss": self.g_loss,
            "d_loss": self.d_loss
        })
def build(self, *args, **kwargs):
    """Assembles the complete TF graph for this model.

    First creates a non-trainable int32 `global_step` variable. Then it runs
    the graph-building stages in a fixed order:
      1. placeholders,
      2. the model graph (receives all extra positional/keyword arguments),
      3. prediction ops,
      4. loss,
      5. training ops,
      6. summaries.

    `self._vars` is snapshotted right after the prediction ops exist. Finally
    the model is flagged as built and its scope is logged.
    """
    self._global_step = tf.Variable(
        name="global_step",
        initial_value=0,
        dtype=tf.int32,
        trainable=False,
    )

    self._ph_op()
    self._graph_op(*args, **kwargs)
    self._predict_op()

    # NOTE(review): unscoped - this collects every trainable variable in the
    # default graph, not only the ones under this model's scope.
    self._vars = tf.trainable_variables()

    for build_stage in (self._loss_op, self._train_op, self._summary_op):
        build_stage()

    self._built = True
    message = "Built model with scope {}".format(self._scope)
    tf.logging.log(logging.INFO, message)
def forward_pass(self, data):
    """Computes the query logits for the given episode `data`.

    Builds the graph in three stages:
      1. Sets `self.film_selector` according to `self.film_init`.
      2. If `self.num_steps` is non-zero, builds `self.num_steps` chained
         optimization steps (via `self._get_train_op`). They cover the
         trainable variables whose name contains '_for_film_learner'. Before
         the first step, those variables and the optimizer's own variables
         (`self.opt.variables()`) are re-initialized.
      3. Computes the query logits, ordered after the last step when there is
         one.

    Args:
      data: The episode. `data.support_images` is read here; the object is
        also passed to `self._compute_losses` and `self._get_train_op`.

    Returns:
      The 'query_logits' entry of
      `self._compute_losses(data, compute_on_query=True)`, passed through
      `tf.identity`.
    """
    # NOTE(review): `self.film_selector` is not read in this method; it is
    # presumably consumed by `self._compute_losses` / `self._get_train_op`
    # (confirm). There is no `else` branch, so any other `film_init` value
    # leaves `self.film_selector` unchanged (or unset).
    if self.film_init == 'scratch':
        self.film_selector = None
    elif self.film_init == 'imagenet':
        # Note: this makes the assumption that the first set of learned FiLM
        # parameters corresponds to the ImageNet dataset. Otherwise, the
        # following line should be changed appropriately.
        self.film_selector = 0
    elif self.film_init in ['blender', 'blender_hard']:
        # Logits from the dataset classifier on the support images; the
        # selector is built over their last axis.
        dataset_logits = functional_backbones.dataset_classifier(
            data.support_images)
        if self.film_init == 'blender_hard':
            # Select only the argmax entry.
            self.film_selector = tf.one_hot(
                tf.math.argmax(dataset_logits, axis=-1),
                depth=tf.shape(dataset_logits)[1])
        else:
            # Take a convex combination.
            self.film_selector = tf.nn.softmax(dataset_logits, axis=-1)

    if self.num_steps:
        # Initial forward pass, required for the `unused_op` below and for placing
        # variables in tf.trainable_variables() for the below block to pick up.
        loss = self._compute_losses(data, compute_on_query=False)['loss']

        # Pick out the variables to optimize.
        self.opt_vars = []
        for var in tf.trainable_variables():
            if '_for_film_learner' in var.name:
                self.opt_vars.append(var)
        tf.logging.info('FiLMLearner will optimize vars: {}'.format(
            self.opt_vars))

    # Each iteration adds one optimization step to the graph. The step runs
    # only after the previous step's ops, via the control dependencies below.
    for i in range(self.num_steps):
        if i == 0:
            # Re-initialize the variables to optimize for the new episode, to ensure
            # the FiLM parameters aren't re-used across tasks of a given dataset.
            vars_reset = tf.variables_initializer(var_list=self.opt_vars)
            # Adam related variables are created when minimize() is called.
            # We create an unused op here to put all adam variables under
            # the 'adam_opt' namescope and create a reset op to reinitialize
            # these variables before the first finetune step.
            with tf.variable_scope('adam_opt', reuse=tf.AUTO_REUSE):
                unused_op = self.opt.minimize(loss, var_list=self.opt_vars)
            adam_reset = tf.variables_initializer(self.opt.variables())
            # The first step is ordered after both resets and the initial loss.
            with tf.control_dependencies([vars_reset, adam_reset, loss] +
                                         self.opt_vars):
                print_op = tf.no_op()
                if self.debug_log:
                    # Assumes at least one variable was selected in opt_vars.
                    print_op = tf.print([
                        'step: %d' % i, self.opt_vars[0][0], 'loss:', loss
                    ], summarize=-1)
                with tf.control_dependencies([print_op]):
                    # Get the train op.
                    results = self._get_train_op(data)
                    (train_op, loss, query_loss, acc,
                     query_acc) = (results['train_op'], results['loss'],
                                   results['query_loss'], results['acc'],
                                   results['query_acc'])
        else:
            # Later steps are ordered after the previous step's train op and
            # outputs. The query metrics are added as dependencies only when
            # debug logging is on (the list is repeated int(debug_log) times).
            with tf.control_dependencies([train_op, loss, acc] +
                                         self.opt_vars +
                                         [query_loss, query_acc] *
                                         int(self.debug_log)):
                print_op = tf.no_op()
                if self.debug_log:
                    print_list = [
                        '################',
                        'step: %d' % i,
                        self.opt_vars[0][0],
                        'support loss:',
                        loss,
                        'query loss:',
                        query_loss,
                        'support acc:',
                        acc,
                        'query acc:',
                        query_acc,
                    ]
                    print_op = tf.print(print_list)
                with tf.control_dependencies([print_op]):
                    # Get the train op (the loss is returned just for printing).
                    results = self._get_train_op(data)
                    (train_op, loss, query_loss, acc,
                     query_acc) = (results['train_op'], results['loss'],
                                   results['query_loss'], results['acc'],
                                   results['query_acc'])

    # Training is now over, compute the final query logits.
    # With no steps there is no train_op to wait for.
    dependency_list = [] if not self.num_steps else [train_op] + self.opt_vars
    with tf.control_dependencies(dependency_list):
        results = self._compute_losses(data, compute_on_query=True)
        (loss, query_loss, query_logits, acc,
         query_acc) = (results['loss'], results['query_loss'],
                       results['query_logits'], results['acc'],
                       results['query_acc'])
        print_op = tf.no_op()
        if self.debug_log:
            print_op = tf.print([
                'Done training',
                'support loss:',
                loss,
                'query loss:',
                query_loss,
                'support acc:',
                acc,
                'query acc:',
                query_acc,
            ])
    # The identity ties the returned logits to the (optional) debug print.
    with tf.control_dependencies([print_op]):
        query_logits = tf.identity(query_logits)
    return query_logits
def compute_logits(self, data):
    """Computes the class logits for the episode.

    Finetunes a freshly re-initialized fc layer on the support set for
    `self.num_finetune_steps` steps. When `self.finetune_all_layers` is set,
    temporary copies of the embedding variables are finetuned as well. The
    method then returns the query-set logits under the finetuned parameters.
    With zero finetuning steps, the query logits use the fc layer's current
    variable values.

    Args:
      data: A `meta_dataset.providers.Episode`.

    Returns:
      The query set logits as a [num_query_images, way] matrix.
    """
    # ------------------------ Finetuning -------------------------------
    # Possibly make copies of embedding variables, if they will get modified.
    # This is for making temporary-only updates to the embedding network
    # which will not persist after the end of the episode.
    make_copies = self.finetune_all_layers

    # TODO(eringrant): Reduce the number of times the embedding function graph
    # is built with the same input.
    support_embeddings_params_moments = self.embedding_fn(
        data.support_images, self.is_training)
    support_embeddings = support_embeddings_params_moments['embeddings']
    support_embeddings_var_dict = support_embeddings_params_moments['params']

    (embedding_vars_keys, embedding_vars,
     embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
         support_embeddings_var_dict, make_copies)
    embedding_vars_copy_op = tf.group(*embedding_vars_copy_ops)

    def embed_with_current_params(images):
        """Embeds `images` with the (possibly copied) embedding variables."""
        return self.embedding_fn(
            images,
            self.is_training,
            params=collections.OrderedDict(
                zip(embedding_vars_keys, embedding_vars)),
            reuse=True)['embeddings']

    def way_logits(embeddings):
        """Applies the fc layer and keeps the first `data.way` columns."""
        return self._fc_layer(embeddings)[:, 0:data.way]

    # Compute the initial training loss (only for printing purposes). This
    # line is also needed for adding the fc variables to the graph so that the
    # tf.trainable_variables() loop below detects them.
    logits = way_logits(support_embeddings)
    finetune_loss = self.compute_loss(
        onehot_labels=data.onehot_support_labels,
        predictions=logits,
    )

    # Decide which variables to finetune.
    fc_vars, vars_to_finetune = [], []
    for var in tf.trainable_variables():
        if 'fc_finetune' in var.name:
            fc_vars.append(var)
            vars_to_finetune.append(var)
    if self.finetune_all_layers:
        vars_to_finetune.extend(embedding_vars)
    logging.info('Finetuning will optimize variables: %s', vars_to_finetune)

    # Stays None when no finetuning step is built (num_finetune_steps <= 0).
    finetune_op = None
    for i in range(self.num_finetune_steps):
        if i == 0:
            # Randomly initialize the fc layer.
            fc_reset = tf.variables_initializer(var_list=fc_vars)
            # Adam related variables are created when minimize() is called.
            # We create an unused op here to put all adam variables under
            # the 'adam_opt' namescope and create a reset op to reinitialize
            # these variables before the first finetune step.
            adam_reset = tf.no_op()
            if self.finetune_with_adam:
                with tf.variable_scope('adam_opt'):
                    unused_op = self.finetune_opt.minimize(
                        finetune_loss, var_list=vars_to_finetune)
                adam_reset = tf.variables_initializer(
                    self.finetune_opt.variables())
            # The first step waits for the resets and the embedding copies.
            step_deps = ([fc_reset, adam_reset, finetune_loss,
                          embedding_vars_copy_op] + vars_to_finetune)
        else:
            # Later steps are ordered after the previous step's update.
            step_deps = [finetune_op, finetune_loss] + vars_to_finetune

        with tf.control_dependencies(step_deps):
            print_op = tf.no_op()
            if self.debug_log:
                print_items = [
                    'step: %d' % i, vars_to_finetune[0][0, 0], 'loss:',
                    finetune_loss
                ]
                if i > 0:
                    # `logits` / `query_logits` come from the previous step.
                    print_items += [
                        'accuracy:',
                        self.compute_accuracy(
                            labels=data.onehot_support_labels,
                            predictions=logits),
                        'query accuracy:',
                        self.compute_accuracy(
                            labels=data.onehot_query_labels,
                            predictions=query_logits),
                    ]
                print_op = tf.print(print_items)

            with tf.control_dependencies([print_op]):
                # Get the operation for finetuning.
                # (The logits and loss are returned just for printing).
                logits, finetune_loss, finetune_op = self._get_finetune_op(
                    data, embedding_vars_keys, embedding_vars, vars_to_finetune,
                    support_embeddings if not self.finetune_all_layers else None)

                if self.debug_log:
                    # Test logits are computed only for printing logs.
                    query_logits = way_logits(
                        embed_with_current_params(data.query_images))

    # Finetuning is now over, compute the query performance using the updated
    # fc layer, and possibly the updated embedding network.
    if finetune_op is None:
        # No finetuning step exists to wait for. Still make sure any embedding
        # copies are populated before they are read below.
        final_deps = [embedding_vars_copy_op]
    else:
        final_deps = [finetune_op] + vars_to_finetune
    with tf.control_dependencies(final_deps):
        query_embeddings = embed_with_current_params(data.query_images)
        query_logits = way_logits(query_embeddings)

        if self.debug_log:
            # The train logits are computed only for printing.
            logits = way_logits(embed_with_current_params(data.support_images))

        print_op = tf.no_op()
        if self.debug_log:
            print_op = tf.print([
                'accuracy:',
                self.compute_accuracy(
                    labels=data.onehot_support_labels, predictions=logits),
                'query accuracy:',
                self.compute_accuracy(
                    labels=data.onehot_query_labels, predictions=query_logits),
            ])
    with tf.control_dependencies([print_op]):
        query_logits = way_logits(query_embeddings)

    return query_logits
def get_train_op(loss,
                 learning_rate=0.001,
                 lr_decay_steps=10000,
                 lr_decay_rate=0.98,
                 gradient_clip_norm=3.0,
                 use_tpu=True,
                 variables=None):
    """Get training operation with gradient clipping and learning rate decay.

    Distilled from tf.contrib.layers.optimize_loss().

    Args:
      loss: Scalar tensor of the loss function.
      learning_rate: Scalar initial learning rate.
      lr_decay_steps: Exponential decay timescale.
      lr_decay_rate: Exponential decay magnitude.
      gradient_clip_norm: Global norm by which to scale gradients. Any real
        number (int, float or numpy scalar) enables clipping. Any other value
        (e.g. None) disables it. Booleans are not treated as norms.
      use_tpu: Use tpu for training.
      variables: List of variables to optimize. tf.trainable_variables() if None.

    Returns:
      train_op: Operation that runs one iteration of training. It evaluates to
        the (update-op-gated) loss after the gradient update has been applied.
    """
    # Local import keeps this block self-contained; only used for the
    # clip-norm type check below.
    import numbers

    global_step = tf.train.get_or_create_global_step()

    with tf.variable_scope('training', values=[loss, global_step]):
        # Make sure update ops run before computing loss.
        update_ops = list(set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))
        with tf.control_dependencies(update_ops):
            loss = tf.identity(loss)

        # Learning rate variable, with decay.
        learning_rate_decay_fn = functools.partial(
            tf.train.exponential_decay,
            decay_steps=lr_decay_steps,
            decay_rate=lr_decay_rate,
            staircase=True)
        lr = tf.get_variable(
            'learning_rate', [],
            trainable=False,
            initializer=tf.constant_initializer(learning_rate))
        lr = learning_rate_decay_fn(lr, global_step)

        # Optimizer.
        opt = tf.train.AdamOptimizer(lr)
        if use_tpu:
            opt = tf.tpu.CrossShardOptimizer(opt)

        # All trainable variables, if specific variables are not specified.
        if variables is None:
            variables = tf.trainable_variables()

        # Compute gradients.
        gradients = opt.compute_gradients(
            loss, variables, colocate_gradients_with_ops=False)

        # Optionally clip gradients by global norm. Accept any real number so
        # that an int (e.g. 5) or a numpy scalar does not silently skip
        # clipping. bool is an int subclass, so it is excluded explicitly.
        clip_requested = (isinstance(gradient_clip_norm, numbers.Real) and
                          not isinstance(gradient_clip_norm, bool))
        if clip_requested:
            gradients = _clip_gradients_by_norm(gradients,
                                                float(gradient_clip_norm))

        # Create gradient updates.
        grad_updates = opt.apply_gradients(
            gradients, global_step=global_step, name='train')

        # Ensure the train_op computes grad_updates.
        with tf.control_dependencies([grad_updates]):
            train_op = tf.identity(loss)

        return train_op