def sample(self, reinitialize=False): """Samples weights for the network, displaying training and test errors.""" # we store training loss for sanity check training_loss_h, test_loss_h = [], [] # Keep only one over thinning_interval sample num_iters = self.num_samples * self.thinning_interval # Saving the weights of the last layer num_ll_weights = int((self.dim_input + 1) * self.num_classes) sampled_weights = np.zeros((self.num_samples, num_ll_weights)) # Random initialization of the weights if needed. if reinitialize: self.sess.run(tf.variables_initializer(self.ll_vars)) # sampling init_t = time.time() print('-----------------------------------------------------') print('Starting sampling of the Bayesian Neural Network by ' + six.ensure_str(self.sampler)) for i in np.arange(0, num_iters): batch_x, batch_y = self.next_batch() feed_dict = {self.x: batch_x, self.y: batch_y} self.sess.run([self.train_op], feed_dict=feed_dict) if (i + 1) % self.thinning_interval == 0: summary, loss_v, ll_vars_v = self.sess.run( [self.all_summaries, self.loss, self.ll_vars_concat], feed_dict=feed_dict) training_loss_h.append(loss_v) self.writer.add_summary(summary, i + 1) self.writer.flush() sampled_weights[ i // self.thinning_interval, :] = ll_vars_v.flatten('F') feed_dict_test = {self.x: self.x_test, self.y: self.y_test} test_loss_v = self.sess.run([self.loss], feed_dict=feed_dict_test) test_loss_h.append(test_loss_v) msg = ('{} steps. Loss = {}. \t Test Loss = {}.').format( str(i + 1), str(loss_v), str(test_loss_v)) print(msg) self.saver.save(self.sess, os.path.join(self.working_dir, 'weights/saved-last-weights'), global_step=i + 1, write_meta_graph=False) print('-----------------------------------------------------') print('Training complete after {} seconds.'.format(time.time() - init_t)) print('-----------------------------------------------------') self.writer.close() return training_loss_h, test_loss_h, sampled_weights
def sample(self, reinitialize=False): """Samples weights after training, for bootstrap we only need one sample.""" # we store training loss for sanity check training_loss_h, test_loss_h = [], [] # we first need to train the model to convergence (on bootstrapped data) num_training_iters = self.num_training_iters # saving the weights of the last layer num_ll_weights = int((self.dim_input + 1) * self.num_classes) sampled_weights = np.zeros((1, num_ll_weights)) # random initialization of the weights if needed. if reinitialize: self.sess.run(tf.variables_initializer(self.ll_vars)) # sampling init_t = time.time() print('-----------------------------------------------------') print('Starting sampling of the Bootstrapped Neural Network.') # we first train the model using bootstrap for i in np.arange(num_training_iters): # train step batch_x, batch_y = self.next_batch() feed_dict = {self.x: batch_x, self.y: batch_y} _, summary, loss_v, = self.sess.run( [self.train_op, self.all_summaries, self.loss], feed_dict=feed_dict) self.writer.add_summary(summary, i) self.writer.flush() training_loss_h.append(loss_v) # test error every 100 iters if i % 100 == 0: feed_dict = {self.x: self.x_test, self.y: self.y_test} test_loss_v = self.sess.run([self.loss], feed_dict=feed_dict) test_loss_h.append(test_loss_v) msg = ('{} steps. Loss = {}. \t Test Loss = {}.').format( str(i), str(loss_v), str(test_loss_v)) print(msg) # finally, we store the last model ll_vars_v = self.sess.run(self.ll_vars_concat, feed_dict=feed_dict) sampled_weights[0, :] = ll_vars_v.flatten('F') self.saver.save(self.sess, os.path.join( self.working_dir, 'weights/saved-boot{}-last-weights'.format( self.worker_id)), global_step=0, write_meta_graph=False) print('-----------------------------------------------------') print('Training complete after {} seconds.'.format(time.time() - init_t)) print('-----------------------------------------------------') self.writer.close() return training_loss_h, test_loss_h, sampled_weights
def forward_pass(self, data):
  """Computes the query logits for the given episode `data`."""
  if self.film_init == 'scratch':
    self.film_selector = None
  elif self.film_init == 'imagenet':
    # Note: this makes the assumption that the first set of learned FiLM
    # parameters corresponds to the ImageNet dataset. Otherwise, the
    # following line should be changed appropriately.
    self.film_selector = 0
  elif self.film_init in ['blender', 'blender_hard']:
    dataset_logits = functional_backbones.dataset_classifier(
        data.support_images)
    if self.film_init == 'blender_hard':
      # Select only the argmax entry.
      self.film_selector = tf.one_hot(
          tf.math.argmax(dataset_logits, axis=-1),
          depth=tf.shape(dataset_logits)[1])
    else:
      # Take a convex combination.
      self.film_selector = tf.nn.softmax(dataset_logits, axis=-1)

  if self.num_steps:
    # Initial forward pass, required for the `unused_op` below and for
    # placing variables in tf.trainable_variables() for the below block to
    # pick up.
    loss = self._compute_losses(data, compute_on_query=False)['loss']

    # Pick out the variables to optimize.
    self.opt_vars = []
    for var in tf.trainable_variables():
      if '_for_film_learner' in var.name:
        self.opt_vars.append(var)
    tf.logging.info('FiLMLearner will optimize vars: {}'.format(
        self.opt_vars))

    for i in range(self.num_steps):
      if i == 0:
        # Re-initialize the variables to optimize for the new episode, to
        # ensure the FiLM parameters aren't re-used across tasks of a given
        # dataset.
        vars_reset = tf.variables_initializer(var_list=self.opt_vars)
        # Adam-related variables are created when minimize() is called.
        # We create an unused op here to put all Adam variables under
        # the 'adam_opt' name scope and create a reset op to reinitialize
        # these variables before the first finetune step.
        with tf.variable_scope('adam_opt', reuse=tf.AUTO_REUSE):
          unused_op = self.opt.minimize(loss, var_list=self.opt_vars)
        adam_reset = tf.variables_initializer(self.opt.variables())

        with tf.control_dependencies([vars_reset, adam_reset, loss] +
                                     self.opt_vars):
          print_op = tf.no_op()
          if self.debug_log:
            print_op = tf.print(
                ['step: %d' % i, self.opt_vars[0][0], 'loss:', loss],
                summarize=-1)

          with tf.control_dependencies([print_op]):
            # Get the train op.
            results = self._get_train_op(data)
            (train_op, loss, query_loss, acc,
             query_acc) = (results['train_op'], results['loss'],
                           results['query_loss'], results['acc'],
                           results['query_acc'])
      else:
        with tf.control_dependencies([train_op, loss, acc] + self.opt_vars +
                                     [query_loss, query_acc] *
                                     int(self.debug_log)):
          print_op = tf.no_op()
          if self.debug_log:
            print_list = [
                '################',
                'step: %d' % i,
                self.opt_vars[0][0],
                'support loss:',
                loss,
                'query loss:',
                query_loss,
                'support acc:',
                acc,
                'query acc:',
                query_acc,
            ]
            print_op = tf.print(print_list)

          with tf.control_dependencies([print_op]):
            # Get the train op (the loss is returned just for printing).
            results = self._get_train_op(data)
            (train_op, loss, query_loss, acc,
             query_acc) = (results['train_op'], results['loss'],
                           results['query_loss'], results['acc'],
                           results['query_acc'])

  # Training is now over; compute the final query logits.
  dependency_list = [] if not self.num_steps else [train_op] + self.opt_vars
  with tf.control_dependencies(dependency_list):
    results = self._compute_losses(data, compute_on_query=True)
    (loss, query_loss, query_logits, acc,
     query_acc) = (results['loss'], results['query_loss'],
                   results['query_logits'], results['acc'],
                   results['query_acc'])

    print_op = tf.no_op()
    if self.debug_log:
      print_op = tf.print([
          'Done training',
          'support loss:',
          loss,
          'query loss:',
          query_loss,
          'support acc:',
          acc,
          'query acc:',
          query_acc,
      ])
    with tf.control_dependencies([print_op]):
      query_logits = tf.identity(query_logits)

  return query_logits
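# The loop above relies on a TF 1.x graph-mode idiom: ops created inside a
# tf.control_dependencies([...]) context only run after their dependencies,
# so chaining each step's train op on the previous one serializes all
# `num_steps` updates within a single session.run() call. A minimal,
# self-contained sketch of that idiom (illustrative only, independent of the
# learner API above):
import tensorflow as tf

w = tf.get_variable('toy_w', initializer=5.0)
loss = tf.square(w)
opt = tf.train.GradientDescentOptimizer(0.1)

train_op = tf.no_op()
for _ in range(3):
  with tf.control_dependencies([train_op]):
    # Ops created by this minimize() call depend on the previous train_op,
    # so the three gradient steps are forced to run in order.
    train_op = opt.minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)  # Runs all three chained gradient steps in one call.
  print(sess.run(w))  # w has been updated three times.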
def compute_logits(self, data):
  """Computes the class logits for the episode.

  Args:
    data: A `meta_dataset.providers.Episode`.

  Returns:
    The query set logits as a [num_query_images, way] matrix.

  Raises:
    ValueError: Distance must be one of l2 or cosine.
  """
  # ------------------------ Finetuning -------------------------------
  # Possibly make copies of embedding variables, if they will get modified.
  # This is for making temporary-only updates to the embedding network
  # which will not persist after the end of the episode.
  make_copies = self.finetune_all_layers

  # TODO(eringrant): Reduce the number of times the embedding function graph
  # is built with the same input.
  support_embeddings_params_moments = self.embedding_fn(
      data.support_images, self.is_training)
  support_embeddings = support_embeddings_params_moments['embeddings']
  support_embeddings_var_dict = support_embeddings_params_moments['params']

  (embedding_vars_keys, embedding_vars,
   embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
       support_embeddings_var_dict, make_copies)
  embedding_vars_copy_op = tf.group(*embedding_vars_copy_ops)

  # Compute the initial training loss (only for printing purposes). This
  # line is also needed for adding the fc variables to the graph so that the
  # tf.all_variables() line below detects them.
  logits = self._fc_layer(support_embeddings)[:, 0:data.way]
  finetune_loss = self.compute_loss(
      onehot_labels=data.onehot_support_labels,
      predictions=logits,
  )

  # Decide which variables to finetune.
  fc_vars, vars_to_finetune = [], []
  for var in tf.trainable_variables():
    if 'fc_finetune' in var.name:
      fc_vars.append(var)
      vars_to_finetune.append(var)
  if self.finetune_all_layers:
    vars_to_finetune.extend(embedding_vars)
  logging.info('Finetuning will optimize variables: %s', vars_to_finetune)

  for i in range(self.num_finetune_steps):
    if i == 0:
      # Randomly initialize the fc layer.
      fc_reset = tf.variables_initializer(var_list=fc_vars)
      # Adam-related variables are created when minimize() is called.
      # We create an unused op here to put all Adam variables under
      # the 'adam_opt' name scope and create a reset op to reinitialize
      # these variables before the first finetune step.
      adam_reset = tf.no_op()
      if self.finetune_with_adam:
        with tf.variable_scope('adam_opt'):
          unused_op = self.finetune_opt.minimize(
              finetune_loss, var_list=vars_to_finetune)
        adam_reset = tf.variables_initializer(self.finetune_opt.variables())

      with tf.control_dependencies(
          [fc_reset, adam_reset, finetune_loss, embedding_vars_copy_op] +
          vars_to_finetune):
        print_op = tf.no_op()
        if self.debug_log:
          print_op = tf.print([
              'step: %d' % i, vars_to_finetune[0][0, 0], 'loss:',
              finetune_loss
          ])

        with tf.control_dependencies([print_op]):
          # Get the operation for finetuning.
          # (The logits and loss are returned just for printing.)
          logits, finetune_loss, finetune_op = self._get_finetune_op(
              data, embedding_vars_keys, embedding_vars, vars_to_finetune,
              support_embeddings if not self.finetune_all_layers else None)

          if self.debug_log:
            # Test logits are computed only for printing logs.
            query_embeddings = self.embedding_fn(
                data.query_images,
                self.is_training,
                params=collections.OrderedDict(
                    zip(embedding_vars_keys, embedding_vars)),
                reuse=True)['embeddings']
            query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]
    else:
      with tf.control_dependencies([finetune_op, finetune_loss] +
                                   vars_to_finetune):
        print_op = tf.no_op()
        if self.debug_log:
          print_op = tf.print([
              'step: %d' % i,
              vars_to_finetune[0][0, 0],
              'loss:',
              finetune_loss,
              'accuracy:',
              self.compute_accuracy(
                  labels=data.onehot_support_labels, predictions=logits),
              'query accuracy:',
              self.compute_accuracy(
                  labels=data.onehot_query_labels, predictions=query_logits),
          ])

        with tf.control_dependencies([print_op]):
          # Get the operation for finetuning.
          # (The logits and loss are returned just for printing.)
          logits, finetune_loss, finetune_op = self._get_finetune_op(
              data, embedding_vars_keys, embedding_vars, vars_to_finetune,
              support_embeddings if not self.finetune_all_layers else None)

          if self.debug_log:
            # Test logits are computed only for printing logs.
            query_embeddings = self.embedding_fn(
                data.query_images,
                self.is_training,
                params=collections.OrderedDict(
                    zip(embedding_vars_keys, embedding_vars)),
                reuse=True)['embeddings']
            query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

  # Finetuning is now over, compute the query performance using the updated
  # fc layer, and possibly the updated embedding network.
  with tf.control_dependencies([finetune_op] + vars_to_finetune):
    query_embeddings = self.embedding_fn(
        data.query_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, embedding_vars)),
        reuse=True)['embeddings']
    query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

    if self.debug_log:
      # The train logits are computed only for printing.
      support_embeddings = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, embedding_vars)),
          reuse=True)['embeddings']
      logits = self._fc_layer(support_embeddings)[:, 0:data.way]

    print_op = tf.no_op()
    if self.debug_log:
      print_op = tf.print([
          'accuracy:',
          self.compute_accuracy(
              labels=data.onehot_support_labels, predictions=logits),
          'query accuracy:',
          self.compute_accuracy(
              labels=data.onehot_query_labels, predictions=query_logits),
      ])
    with tf.control_dependencies([print_op]):
      query_logits = self._fc_layer(query_embeddings)[:, 0:data.way]

  return query_logits
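# `get_embeddings_vars_copy_ops` above supports episode-local updates: when
# `finetune_all_layers` is set, the embedding variables are copied, the
# copies are finetuned via an explicit var_list, and the originals survive
# the episode. A minimal, self-contained sketch of that pattern (TF 1.x,
# illustrative names only):
import tensorflow as tf

original = tf.get_variable('embed_w', initializer=tf.ones([3, 3]))
episode_copy = tf.get_variable(
    'embed_w_copy', initializer=tf.zeros([3, 3]), trainable=False)
copy_op = episode_copy.assign(original)  # Refresh the copy at episode start.

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(copy_op)
  # ...finetune `episode_copy` here (passed explicitly as the var_list);
  # `original` is untouched, so the next episode starts from the same
  # pretrained weights.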
def initialize_session(self):
  """Initializes a tf Session."""
  if ENABLE_TF_OPTIMIZATIONS:
    self.sess = tf.Session()
  else:
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
        arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        remapping=rewriter_config_pb2.RewriterConfig.OFF,
        shape_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        function_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
        loop_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        memory_optimization=rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
    graph_options = tf.GraphOptions(rewrite_options=rewriter_config)
    session_config = tf.ConfigProto(graph_options=graph_options)
    self.sess = tf.Session(config=session_config)

  # Restore or initialize the variables.
  self.sess.run(tf.global_variables_initializer())
  self.sess.run(tf.local_variables_initializer())
  if self.learner_config.checkpoint_for_eval:
    # Requested a specific checkpoint.
    self.saver.restore(self.sess, self.learner_config.checkpoint_for_eval)
    tf.logging.info('Restored checkpoint: %s' %
                    self.learner_config.checkpoint_for_eval)
  else:
    # Continue from the latest checkpoint if one exists.
    # This handles fault-tolerance.
    latest_checkpoint = None
    if self.checkpoint_dir is not None:
      latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_dir)
    if latest_checkpoint:
      self.saver.restore(self.sess, latest_checkpoint)
      tf.logging.info('Restored checkpoint: %s' % latest_checkpoint)
    else:
      tf.logging.info('No previous checkpoint.')
      self.sess.run(tf.global_variables_initializer())
      self.sess.run(tf.local_variables_initializer())

  # For episodic models, potentially use pretrained weights at the start of
  # training. If this happens it will overwrite the embedding weights, but
  # taking care to not restore the Adam parameters.
  if self.learner_config.pretrained_checkpoint and not self.sess.run(
      tf.train.get_global_step()):
    self.saver.restore(self.sess, self.learner_config.pretrained_checkpoint)
    tf.logging.info('Restored checkpoint: %s' %
                    self.learner_config.pretrained_checkpoint)
    # We only want the embedding weights of the checkpoint we just restored.
    # So we re-initialize everything that's not an embedding weight. Also,
    # since this episodic finetuning procedure is a different optimization
    # problem than the original training of the baseline whose embedding
    # weights are re-used, we do not reload ADAM's variables and instead
    # learn them from scratch.
    vars_to_reinit, embedding_var_names, vars_to_reinit_names = [], [], []
    for var in tf.global_variables():
      if (any(keyword in var.name for keyword in EMBEDDING_KEYWORDS) and
          'adam' not in var.name.lower()):
        embedding_var_names.append(var.name)
        continue
      vars_to_reinit.append(var)
      vars_to_reinit_names.append(var.name)
    tf.logging.info('Initializing all variables except for %s.' %
                    embedding_var_names)
    self.sess.run(tf.variables_initializer(vars_to_reinit))
    tf.logging.info('Re-initialized vars %s.' % vars_to_reinit_names)