def build_task_parameters(self):
  """Assign to attributes the task parameters.

  Creates one zero-initialized location vector and one zero-initialized
  log-scale vector per component, each a `tf.Variable` of shape
  `[self.num_dims]`.

  Sets:
    self.locs: list of `self.num_components` variables named `loc_<i>`.
    self.log_scales: list of `self.num_components` variables named
      `log_scale_<i>`.
  """
  self.locs = [
      tf.Variable(tf.zeros([self.num_dims]), name='loc_{}'.format(i))
      for i in range(self.num_components)
  ]
  self.log_scales = [
      tf.Variable(tf.zeros([self.num_dims]), name='log_scale_{}'.format(i))
      for i in range(self.num_components)
  ]
def build_meta_parameters(self):
  """Assign to attributes the meta parameters.

  Each initializer is called with the shape of the variable it seeds.

  Sets:
    self.meta_loc: variable initialized from
      `self.loc_initializer([num_modes, num_dims])`; trainable iff
      `self.trainable_loc`.
    self.meta_log_scale: variable initialized from
      `self.log_scale_initializer([num_modes, num_dims])`; trainable iff
      `self.trainable_scale`.
    self.meta_logits: variable initialized from
      `self.logits_initializer([num_modes])`; trainable iff
      `self.trainable_logits`.
  """
  self.meta_loc = tf.Variable(
      self.loc_initializer([self.num_modes, self.num_dims]),
      trainable=self.trainable_loc,
      name='meta_loc')
  self.meta_log_scale = tf.Variable(
      self.log_scale_initializer([self.num_modes, self.num_dims]),
      trainable=self.trainable_scale,
      name='meta_log_scale')
  self.meta_logits = tf.Variable(
      self.logits_initializer([self.num_modes]),
      trainable=self.trainable_logits,
      name='meta_logits')
def em_loop(
    num_updates,
    e_step,
    m_step,
    variables,
):
  """Run `num_updates` rounds of expectation-maximization on `variables`.

  Each round is one iteration of a `tf.while_loop`:
  `e_step(vars)` returns `(train_predictions, responsibilities)`, which are
  then passed with `vars` to `m_step` to produce the updated variables.

  Args:
    num_updates: number of EM rounds; compared with `<` against an int32 loop
      counter (presumably a Python int or scalar int32 tensor — confirm).
    e_step: callable `vars -> (train_predictions, responsibilities)`.
    m_step: callable
      `(vars, train_predictions, responsibilities) -> updated_vars`; must
      return the same structure as `vars` (a `tf.while_loop` requirement).
    variables: initial values of the variables being optimized.

  Returns:
    The variables after `num_updates` EM rounds.
  """

  def _body(step, preupdate_vars):
    train_predictions_, responsibilities_ = e_step(preupdate_vars)
    updated_vars = m_step(preupdate_vars, train_predictions_, responsibilities_)
    return step + 1, updated_vars

  def _cond(step, *args):
    del args
    return step < num_updates

  # A constant counter is sufficient: unlike a tf.Variable it needs no
  # initializer, is not written to checkpoints, and can be created on every
  # trace of an enclosing tf.function.
  step = tf.constant(0, name='inner_step_counter')
  loop_vars = (step, variables)
  _, updated_vars = tf.while_loop(
      cond=_cond, body=_body, loop_vars=loop_vars, swap_memory=True)
  return updated_vars
def __init__(self, ckpt_dir, save_epoch_freq=1, max_to_keep=3):
  """Record checkpointing options and create the saved-epoch tracker.

  Args:
    ckpt_dir: directory used for checkpoints.
    save_epoch_freq: epoch frequency option, stored as `self.save_epoch_freq`.
    max_to_keep: retention option, stored as `self.max_to_keep`.
  """
  # int64 variable starting at -1 (presumably "no epoch saved yet" — confirm
  # against the code that reads `_ckpt_saved_epoch`).
  initial_saved_epoch = tf.constant(-1, dtype=tf.dtypes.int64)
  self._ckpt_saved_epoch = tf.Variable(
      initial_value=initial_saved_epoch, name='ckpt_saved_epoch')
  self.ckpt_dir = ckpt_dir
  self.max_to_keep = max_to_keep
  self.save_epoch_freq = save_epoch_freq
def _construct_variables():
  """Construct an initialization for task parameters.

  Builds zero-valued variables shaped like the per-mode slices of the meta
  parameters on the enclosing `self`.

  Returns:
    A tuple `(locs, log_scales, logits)` where `locs` and `log_scales` are
    lists of `self.num_modes` variables (one per row of `self.meta_loc` /
    `self.meta_log_scale`), and `logits` is a single variable shaped like
    `self.meta_logits`.
  """
  # tf.unstack along axis 0 yields one tensor per mode and removes only the
  # mode axis. The previous split + axis-less squeeze also collapsed a
  # size-1 parameter dimension, turning [1] vectors into scalars.
  locs = tf.unstack(
      tf.zeros_like(self.meta_loc), num=self.num_modes, axis=0)
  log_scales = tf.unstack(
      tf.zeros_like(self.meta_log_scale), num=self.num_modes, axis=0)
  logits = tf.zeros_like(self.meta_logits)
  # `name` must be given by keyword: the second positional parameter of
  # tf.Variable is `trainable`, so passing the name positionally left the
  # variables unnamed.
  return (
      [tf.Variable(loc, name='loc') for loc in locs],
      [tf.Variable(log_scale, name='log_scale') for log_scale in log_scales],
      tf.Variable(logits, name='logits'),
  )
def _build_sync_op(self):
  """Build the sync op.

  Pairs the trainable variables collected under the 'Online' scope with
  those under the 'Target' scope, in collection order, and builds an assign
  op copying each online value into its target counterpart. A counter
  variable is incremented by the first returned op and logged as the
  'Learning/SyncCount' scalar summary.

  Returns:
    A list of ops: the counter increment followed by one assign op per
    (online, target) variable pair.

  Raises:
    ValueError: if the two scopes hold different numbers of trainable
      variables. Pairing by `zip` would otherwise silently drop the extras
      and leave the target network partially synced.
  """
  sync_count = tf.Variable(0, trainable=False)
  sync_ops = [tf.assign_add(sync_count, 1)]
  trainables_online = tf.get_collection(
      tf.GraphKeys.TRAINABLE_VARIABLES, scope='Online')
  trainables_target = tf.get_collection(
      tf.GraphKeys.TRAINABLE_VARIABLES, scope='Target')
  if len(trainables_online) != len(trainables_target):
    raise ValueError(
        'Cannot sync networks: found {} trainable variables under "Online" '
        'but {} under "Target".'.format(
            len(trainables_online), len(trainables_target)))
  for w_online, w_target in zip(trainables_online, trainables_target):
    sync_ops.append(w_target.assign(w_online, use_locking=True))
  tf.summary.scalar('Learning/SyncCount', sync_count)
  return sync_ops
def get_fc_vars_copy_ops(fc_weights, fc_bias, make_copies):
  """Gets copies of the classifier layer variables or returns those variables.

  At meta-test time, a copy is created for each given Variable, and these
  copies are used in place of the original ones.

  Args:
    fc_weights: A Variable for the weights of the fc layer.
    fc_bias: A Variable for the bias of the fc layer.
    make_copies: A bool. Whether to copy the given variables. If not, those
      variables themselves are returned.

  Returns:
    fc_weights: A Variable for the weights of the fc layer. Might be the same
      as the input fc_weights or a copy of it.
    fc_bias: Analogously, a Variable for the bias of the fc layer.
    fc_vars_copy_ops: A (possibly empty) list of operations for assigning
      the value of each of fc_weights and fc_bias to a respective copy
      variable.
  """
  fc_vars_copy_ops = []
  if make_copies:
    with tf.variable_scope('weight_copy'):
      # The copies are LOCAL_VARIABLES so they are not saved/restored with
      # the model. Their zero initial value uses the source dtype
      # (base_dtype drops any `_ref` suffix) so tf.assign also works for
      # non-float32 variables.
      # fc_weights copy
      fc_weights_copy = tf.Variable(
          tf.zeros(fc_weights.shape.as_list(),
                   dtype=fc_weights.dtype.base_dtype),
          collections=[tf.GraphKeys.LOCAL_VARIABLES])
      fc_weights_copy_op = tf.assign(fc_weights_copy, fc_weights)
      fc_vars_copy_ops.append(fc_weights_copy_op)

      # fc_bias copy
      fc_bias_copy = tf.Variable(
          tf.zeros(fc_bias.shape.as_list(), dtype=fc_bias.dtype.base_dtype),
          collections=[tf.GraphKeys.LOCAL_VARIABLES])
      fc_bias_copy_op = tf.assign(fc_bias_copy, fc_bias)
      fc_vars_copy_ops.append(fc_bias_copy_op)

    fc_weights = fc_weights_copy
    fc_bias = fc_bias_copy
  return fc_weights, fc_bias, fc_vars_copy_ops
def build(self, *args, **kwargs):
  """Construct the model graph in stages and mark the model as built.

  Creates the int32 `global_step` counter, then runs the placeholder, graph
  and prediction stages (forwarding `*args`/`**kwargs` to `_graph_op`),
  records the trainable variables, and finally runs the loss, training and
  summary stages.
  """
  self._global_step = tf.Variable(
      0, name="global_step", trainable=False, dtype=tf.int32)
  self._ph_op()
  self._graph_op(*args, **kwargs)
  self._predict_op()
  # Captured before the loss/train/summary stages run, so any variables
  # those stages create are not part of this list.
  self._vars = tf.trainable_variables()
  for stage in (self._loss_op, self._train_op, self._summary_op):
    stage()
  self._built = True
  message = "Built model with scope {}".format(self._scope)
  tf.logging.log(logging.INFO, message)
def __init__(self,
             embeddings_config: EmbeddingsConfig,
             model_config: TransformerSoftmaxModelConfig = gin.REQUIRED):
  """Create the pre/post-processing layers, encoder stack and output bias.

  Args:
    embeddings_config: forwarded to the parent constructor.
    model_config: transformer options; `pre_dropout_rate` is read here.
  """
  super().__init__(embeddings_config)
  self.model_config = model_config
  config = self.model_config

  self.pre_normalization_layer = tf.keras.layers.LayerNormalization(
      epsilon=1e-6)
  self.pre_dropout_layer = tf.keras.layers.Dropout(
      rate=config.pre_dropout_rate)
  self.transformer_layer = StackedTransformerEncodersLayer()

  hidden_units = self.embeddings_layer.config.embeddings_dimension
  self.post_hidden_layer = tf.keras.layers.Dense(
      units=hidden_units,
      kernel_initializer=parameters_factory.get_parameters_initializer(),
  )
  self.post_normalization_layer = tf.keras.layers.LayerNormalization(
      epsilon=1e-6)

  # One bias entry per row of the similarity matrix, initialized to zeros.
  num_rows = tf.shape(self.get_similarity_matrix())[0]
  self.projection_bias = tf.Variable(
      initial_value=tf.zeros_initializer()(shape=(num_rows,)),
      trainable=True,
  )
def optimizer_loop(
    num_updates,
    objective_fn,
    update_fn,
    variables,
    first_order,
    clip_grad_norm,
):
  """Optimization of `objective_fn` for `num_updates` of `variables`.

  Args:
    num_updates: number of optimizer steps; compared with `<` against an
      int32 loop counter.
    objective_fn: objective forwarded to `optimizer_update`.
    update_fn: zero-argument callable returning an `(init, update,
      get_params)` triple. `init` maps each variable to an optimizer state,
      `update` is forwarded to `optimizer_update`, and `get_params` maps a
      state back to parameters.
    variables: iterable of initial parameter values.
    first_order: forwarded unchanged to `optimizer_update`.
    clip_grad_norm: forwarded unchanged to `optimizer_update`.

  Returns:
    A list with `get_params(state)` for each final optimizer state.
  """
  # Optimizer specifics.
  init, update, get_params = update_fn()

  def _body(step, preupdate_vars):
    """Optimization loop body."""
    updated_vars = optimizer_update(
        iterate_collection=preupdate_vars,
        iteration_idx=step,
        objective_fn=objective_fn,
        update_fn=update,
        get_params_fn=get_params,
        first_order=first_order,
        clip_grad_norm=clip_grad_norm,
    )
    return step + 1, updated_vars

  def _cond(step, *args):
    """Optimization truncation condition."""
    del args
    return step < num_updates

  # A constant counter is sufficient: unlike a tf.Variable it needs no
  # initializer, is not written to checkpoints, and can be created on every
  # trace of an enclosing tf.function.
  step = tf.constant(0, name='inner_step_counter')
  loop_vars = (step, [init(var) for var in variables])
  _, updated_vars = tf.while_loop(
      cond=_cond, body=_body, loop_vars=loop_vars, swap_memory=True)
  return [get_params(v) for v in updated_vars]
def get_embeddings_vars_copy_ops(embedding_vars_dict, make_copies):
  """Gets copies of the embedding variables or returns those variables.

  This is useful at meta-test time for MAML and the finetuning baseline. In
  particular, at meta-test time, we don't want to make permanent updates to
  the model's variables, but only modifications that persist in the given
  episode. This can be achieved by creating copies of each variable and
  modifying and using these copies instead of the variables themselves.

  Args:
    embedding_vars_dict: A dict mapping each variable name to the
      corresponding Variable.
    make_copies: A bool. Whether to copy the given variables. If not, those
      variables themselves will be returned. Typically, this is True at
      meta-test time and False at meta-training time.

  Returns:
    embedding_vars_keys: A list of variable names.
    embeddings_vars: A corresponding list of Variables.
    embedding_vars_copy_ops: A (possibly empty) list of operations, each of
      which assigns the value of one of the provided Variables to a new
      Variable which is its copy.
  """
  embedding_vars_keys = []
  embedding_vars = []
  embedding_vars_copy_ops = []
  for name, var in six.iteritems(embedding_vars_dict):
    embedding_vars_keys.append(name)
    if make_copies:
      with tf.variable_scope('weight_copy'):
        shape = var.shape.as_list()
        # LOCAL_VARIABLES keeps the copy out of saved checkpoints. The zero
        # initial value uses the source dtype (base_dtype drops any `_ref`
        # suffix) so tf.assign also works for non-float32 variables.
        var_copy = tf.Variable(
            tf.zeros(shape, dtype=var.dtype.base_dtype),
            collections=[tf.GraphKeys.LOCAL_VARIABLES])
        var_copy_op = tf.assign(var_copy, var)
        embedding_vars_copy_ops.append(var_copy_op)
        embedding_vars.append(var_copy)
    else:
      embedding_vars.append(var)
  return embedding_vars_keys, embedding_vars, embedding_vars_copy_ops
def train(hparams, num_epoch, tuning):
  """Train the ResNet classifier, evaluate it, and return test accuracy.

  Each epoch trains on the augmented training split, then logs validation
  accuracy and a validation confusion-matrix image to TensorBoard. After
  training, the test split is evaluated once.

  Args:
    hparams: dict; `hparams['HP_BS']` is the training batch size and
      `hparams['HP_LR']` the Adam learning rate.
    num_epoch: number of training epochs (may be 0).
    tuning: when True, checkpoints are neither restored nor saved and the
      Grad-CAM style visualization at the end is skipped.

  Returns:
    `test_accuracy.result()` for the test split.
  """
  log_dir = './results/'
  test_batch_size = 8
  # Load dataset
  training_set, valid_set = make_dataset(BATCH_SIZE=hparams['HP_BS'],
                                         file_name='train_tf_record',
                                         split=True)
  test_set = make_dataset(BATCH_SIZE=test_batch_size,
                          file_name='test_tf_record',
                          split=False)
  class_names = ['NRDR', 'RDR']

  # Model
  model = ResNet()

  # set optimizer
  optimizer = tf.keras.optimizers.Adam(learning_rate=hparams['HP_LR'])

  # set metrics
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
  valid_accuracy = tf.keras.metrics.Accuracy()
  valid_con_mat = ConfusionMatrix(num_class=2)
  test_accuracy = tf.keras.metrics.Accuracy()
  test_con_mat = ConfusionMatrix(num_class=2)

  # Save Checkpoint
  if not tuning:
    ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                               optimizer=optimizer,
                               net=model)
    manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=5)

  # Set up summary writers
  current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
  tb_log_dir = log_dir + current_time + '/train'
  summary_writer = tf.summary.create_file_writer(tb_log_dir)

  # Restore Checkpoint
  if not tuning:
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
      logging.info('Restored from {}'.format(manager.latest_checkpoint))
    else:
      logging.info('Initializing from scratch.')

  @tf.function
  def train_step(train_img, train_label):
    # Optimize the model
    _, grads = grad(model, train_img, train_label)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_pred, _ = model(train_img)
    train_label = tf.expand_dims(train_label, axis=1)
    train_accuracy.update_state(train_label, train_pred)

  for epoch in range(num_epoch):
    begin = time()

    # Training loop
    for train_img, train_label, _ in training_set:
      train_img = data_augmentation(train_img)
      train_step(train_img, train_label)

    with summary_writer.as_default():
      tf.summary.scalar('Train Accuracy', train_accuracy.result(), step=epoch)

    for valid_img, valid_label, _ in valid_set:
      valid_img = tf.cast(valid_img, tf.float32)
      valid_img = valid_img / 255.0
      valid_pred, _ = model(valid_img, training=False)
      valid_pred = tf.cast(tf.argmax(valid_pred, axis=1), dtype=tf.int64)
      valid_con_mat.update_state(valid_label, valid_pred)
      valid_accuracy.update_state(valid_label, valid_pred)

    # Log the confusion matrix as an image summary
    cm_valid = valid_con_mat.result()
    figure = plot_confusion_matrix(cm_valid, class_names=class_names)
    cm_valid_image = plot_to_image(figure)

    with summary_writer.as_default():
      tf.summary.scalar('Valid Accuracy', valid_accuracy.result(), step=epoch)
      tf.summary.image('Valid ConfusionMatrix', cm_valid_image, step=epoch)

    end = time()
    logging.info(
        "Epoch {:d} Training Accuracy: {:.3%} Validation Accuracy: {:.3%} Time:{:.5}s"
        .format(epoch + 1, train_accuracy.result(), valid_accuracy.result(),
                (end - begin)))
    train_accuracy.reset_states()
    valid_accuracy.reset_states()
    valid_con_mat.reset_states()
    if not tuning:
      # Save every 5th epoch counter value; the counter advances every epoch.
      if int(ckpt.step) % 5 == 0:
        save_path = manager.save()
        logging.info('Saved checkpoint for epoch {}: {}'.format(
            int(ckpt.step), save_path))
      ckpt.step.assign_add(1)

  # Step used for the test summaries: the last epoch index. Computed
  # explicitly because the loop variable `epoch` is undefined when
  # num_epoch == 0.
  last_epoch = max(num_epoch - 1, 0)

  for test_img, test_label, _ in test_set:
    test_img = tf.cast(test_img, tf.float32)
    test_img = test_img / 255.0
    test_pred, _ = model(test_img, training=False)
    test_pred = tf.cast(tf.argmax(test_pred, axis=1), dtype=tf.int64)
    test_accuracy.update_state(test_label, test_pred)
    test_con_mat.update_state(test_label, test_pred)

  cm_test = test_con_mat.result()
  # Log the confusion matrix as an image summary
  figure = plot_confusion_matrix(cm_test, class_names=class_names)
  cm_test_image = plot_to_image(figure)
  with summary_writer.as_default():
    tf.summary.scalar('Test Accuracy', test_accuracy.result(), step=last_epoch)
    tf.summary.image('Test ConfusionMatrix', cm_test_image, step=last_epoch)

  logging.info("Trained finished. Final Accuracy in test set: {:.3%}".format(
      test_accuracy.result()))

  # Visualization
  if not tuning:
    # Only the first image of the first test batch is visualized (see break).
    for vis_img, vis_label, vis_name in test_set:
      vis_label = vis_label[0]
      vis_name = vis_name[0]
      vis_img = tf.cast(vis_img[0], tf.float32)
      vis_img = tf.expand_dims(vis_img, axis=0)
      vis_img = vis_img / 255.0
      with tf.GradientTape() as tape:
        vis_pred, conv_output = model(vis_img, training=False)
        pred_label = tf.argmax(vis_pred, axis=-1)
        vis_pred = tf.reduce_max(vis_pred, axis=-1)
      # Gradient of the top score w.r.t. the conv feature map, averaged over
      # axes 1 and 2 to weight each channel.
      grad_1 = tape.gradient(vis_pred, conv_output)
      weight = tf.reduce_mean(grad_1, axis=[1, 2]) / grad_1.shape[1]
      act_map0 = tf.nn.relu(tf.reduce_sum(weight * conv_output, axis=-1))
      act_map0 = tf.squeeze(tf.image.resize(tf.expand_dims(act_map0, axis=-1),
                                            (256, 256),
                                            antialias=True),
                            axis=-1)
      plot_map(vis_img, act_map0, vis_pred, pred_label, vis_label, vis_name)
      break

  return test_accuracy.result()
def learn_metric(self, verbose=False):
  """Approximate the bisimulation metric by learning.

  Builds the training, target-sync and evaluation ops, then runs
  `self.num_iterations` training steps. After each step, the learned
  distances between grid-cell centers are compared with `self.bisim_metric`,
  and the errors are logged to TensorBoard. Every 100 steps the errors are
  recorded and a checkpoint is saved. The results are stored in
  `self.statistics['learned']`.

  Args:
    verbose: bool, whether to print verbose messages.
  """
  summary_writer = tf.summary.FileWriter(self.base_dir)
  global_step = tf.Variable(0, trainable=False)
  inc_global_step_op = tf.assign_add(global_step, 1)
  bisim_horizon = 0.0
  bisim_horizon_discount_value = 1.0
  if self.use_decayed_learning_rate:
    learning_rate = tf.train.exponential_decay(self.starting_learning_rate,
                                               global_step,
                                               self.num_iterations,
                                               self.learning_rate_decay,
                                               staircase=self.staircase)
  else:
    learning_rate = self.starting_learning_rate
  tf.summary.scalar('Learning/LearningRate', learning_rate)
  optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                     epsilon=self.epsilon)
  train_op = self._build_train_op(optimizer)
  sync_op = self._build_sync_op()
  eval_op = tf.stop_gradient(self._build_eval_metric())
  eval_states = []
  # Build the evaluation tensor.
  for state in range(self.num_states):
    row, col = self.inverse_index_states[state]
    # We make the evaluation states at the center of each grid cell.
    eval_states.append([row + 0.5, col + 0.5])
  eval_states = np.array(eval_states, dtype=np.float64)
  normalized_bisim_metric = (
      self.bisim_metric / np.linalg.norm(self.bisim_metric))
  metric_errors = []
  average_metric_errors = []
  normalized_metric_errors = []
  average_normalized_metric_errors = []
  saver = tf.train.Saver(max_to_keep=3)
  try:
    with tf.Session() as sess:
      summary_writer.add_graph(graph=tf.get_default_graph())
      sess.run(tf.global_variables_initializer())
      merged_summaries = tf.summary.merge_all()
      for i in range(self.num_iterations):
        sampled_states = np.random.randint(self.num_states,
                                           size=(self.batch_size,))
        sampled_actions = np.random.randint(4, size=(self.batch_size,))
        if self.add_noise:
          sampled_noise = np.clip(
              np.random.normal(0, 0.1, size=(self.batch_size, 2)),
              -0.3, 0.3)
        sampled_action_names = [self.actions[x] for x in sampled_actions]
        next_states = [self.next_states[a][s]
                       for s, a in zip(sampled_states, sampled_action_names)]
        rewards = np.array([self.rewards[a][s]
                            for s, a in zip(sampled_states,
                                            sampled_action_names)])
        states = np.array(
            [self.inverse_index_states[x] for x in sampled_states])
        next_states = np.array([self.inverse_index_states[x]
                                for x in next_states])
        states = states.astype(np.float64)
        states += 0.5  # Place points in center of grid.
        next_states = next_states.astype(np.float64)
        next_states += 0.5
        if self.add_noise:
          states += sampled_noise
          next_states += sampled_noise

        _, summary = sess.run(
            [train_op, merged_summaries],
            feed_dict={self.s1_ph: states,
                       self.s2_ph: next_states,
                       self.action_ph: sampled_actions,
                       self.rewards_ph: rewards,
                       self.bisim_horizon_ph: bisim_horizon,
                       self.eval_states_ph: eval_states})
        summary_writer.add_summary(summary, i)
        if self.double_period_halfway and i > self.num_iterations / 2.:
          self.target_update_period *= 2
          self.double_period_halfway = False
        if i % self.target_update_period == 0:
          bisim_horizon = 1.0 - bisim_horizon_discount_value
          bisim_horizon_discount_value *= self.bisim_horizon_discount
          sess.run(sync_op)
        # Now compute difference with exact metric.
        self.learned_distance = sess.run(
            eval_op, feed_dict={self.eval_states_ph: eval_states})
        self.learned_distance = np.reshape(self.learned_distance,
                                           (self.num_states, self.num_states))
        metric_difference = np.max(
            abs(self.learned_distance - self.bisim_metric))
        average_metric_difference = np.mean(
            abs(self.learned_distance - self.bisim_metric))
        normalized_learned_distance = (
            self.learned_distance / np.linalg.norm(self.learned_distance))
        normalized_metric_difference = np.max(
            abs(normalized_learned_distance - normalized_bisim_metric))
        average_normalized_metric_difference = np.mean(
            abs(normalized_learned_distance - normalized_bisim_metric))
        error_summary = tf.Summary(value=[
            tf.Summary.Value(tag='Approx/Error',
                             simple_value=metric_difference),
            tf.Summary.Value(tag='Approx/AvgError',
                             simple_value=average_metric_difference),
            tf.Summary.Value(tag='Approx/NormalizedError',
                             simple_value=normalized_metric_difference),
            tf.Summary.Value(tag='Approx/AvgNormalizedError',
                             simple_value=average_normalized_metric_difference),
        ])
        summary_writer.add_summary(error_summary, i)
        sess.run(inc_global_step_op)
        if i % 100 == 0:
          # Collect statistics every 100 steps.
          metric_errors.append(metric_difference)
          average_metric_errors.append(average_metric_difference)
          normalized_metric_errors.append(normalized_metric_difference)
          average_normalized_metric_errors.append(
              average_normalized_metric_difference)
          saver.save(sess, os.path.join(self.base_dir, 'tf_ckpt'),
                     global_step=i)
        if self.debug and i % 100 == 0:
          self.pretty_print_metric(metric_type='learned')
          print('Iteration: {}'.format(i))
          print('Metric difference: {}'.format(metric_difference))
          print('Normalized metric difference: {}'.format(
              normalized_metric_difference))
      if self.add_noise:
        # Finally, if we have noise, we draw a bunch of samples to get
        # estimates of the distances between states.
        sampled_distances = {}
        for _ in range(self.total_final_samples):
          eval_states = []
          for state in range(self.num_states):
            row, col = self.inverse_index_states[state]
            # We make the evaluation states at the center of each grid cell.
            eval_states.append([row + 0.5, col + 0.5])
          eval_states = np.array(eval_states, dtype=np.float64)
          eval_noise = np.clip(
              np.random.normal(0, 0.1, size=(self.num_states, 2)),
              -0.3, 0.3)
          eval_states += eval_noise
          distance_samples = sess.run(
              eval_op, feed_dict={self.eval_states_ph: eval_states})
          distance_samples = np.reshape(distance_samples,
                                        (self.num_states, self.num_states))
          for s1 in range(self.num_states):
            for s2 in range(self.num_states):
              sampled_distances[(tuple(eval_states[s1]),
                                 tuple(eval_states[s2]))] = (
                                     distance_samples[s1, s2])
      else:
        # Otherwise we just use the last evaluation metric.
        sampled_distances = self.learned_distance
  finally:
    # The FileWriter buffers events; close it so everything written above
    # reaches disk even if training stops early.
    summary_writer.close()
  learned_statistics = {
      'num_iterations': self.num_iterations,
      'metric_errors': metric_errors,
      'average_metric_errors': average_metric_errors,
      'normalized_metric_errors': normalized_metric_errors,
      'average_normalized_metric_errors': average_normalized_metric_errors,
      'learned_distances': sampled_distances,
  }
  self.statistics['learned'] = learned_statistics
  if verbose:
    self.pretty_print_metric(metric_type='learned')
def forward_pass(self, data):
  """Computes the test logits of MAML.

  Runs the MAML inner loop on the support set and then classifies the query
  set with the adapted parameters:
    1. Embeds the support images and fetches the embedding variables and the
       linear-classifier variables.
    2. When `self.is_training` is False, replaces all of these variables with
       local copies (via `get_embeddings_vars_copy_ops` /
       `get_fc_vars_copy_ops`), so the adaptation does not touch the
       originals.
    3. Optionally zeroes the fc layer (`self.zero_fc_layer`) or initializes
       it from support-set prototypes (`self.proto_maml_fc_layer_init`).
    4. Runs `self.num_update_steps` gradient steps in a `tf.while_loop`
       (plus `self.additional_evaluation_update_steps` when not training).
    5. Embeds the query images with the adapted embedding parameters and
       applies the adapted fc layer. Unless `self.transductive_batch_norm`
       is set, batch-norm moments come from the support set.

  Args:
    data: A `meta_dataset.providers.Episode` containing the data for the
      episode.

  Returns:
    The output logits for the query data in this episode, restricted to the
    first `data.way` columns.
  """
  # Have to use one-hot labels since sparse softmax doesn't allow
  # second derivatives.
  support_embeddings_ = self.embedding_fn(
      data.support_images, self.is_training, reuse=tf.AUTO_REUSE)
  support_embeddings = support_embeddings_['embeddings']
  embedding_vars_dict = support_embeddings_['params']

  # TODO(eringrant): Refactor to make use of
  # `functional_backbones.linear_classifier`, which allows Gin-configuration.
  with tf.variable_scope('linear_classifier', reuse=tf.AUTO_REUSE):
    embedding_depth = support_embeddings.shape.as_list()[-1]
    fc_weights = functional_backbones.weight_variable(
        [embedding_depth, self.logit_dim],
        weight_decay=self.classifier_weight_decay)
    fc_bias = functional_backbones.bias_variable([self.logit_dim])

  # A list of variable names, a list of corresponding Variables, and a list
  # of operations (possibly empty) that creates a copy of each Variable.
  (embedding_vars_keys, embedding_vars,
   embedding_vars_copy_ops) = get_embeddings_vars_copy_ops(
       embedding_vars_dict, make_copies=not self.is_training)

  # A Variable for the weights of the fc layer, a Variable for the bias of the
  # fc layer, and a list of operations (possibly empty) that copies them.
  (fc_weights, fc_bias, fc_vars_copy_ops) = get_fc_vars_copy_ops(
      fc_weights, fc_bias, make_copies=not self.is_training)

  fc_vars = [fc_weights, fc_bias]
  num_embedding_vars = len(embedding_vars)
  num_fc_vars = len(fc_vars)

  def _cond(step, *args):
    # Evaluation runs extra inner steps on top of `num_update_steps`.
    del args
    num_steps = self.num_update_steps
    if not self.is_training:
      num_steps += self.additional_evaluation_update_steps
    return step < num_steps

  def _body(step, *args):
    """The inner update loop body."""
    # `args` is the flat loop state: embedding vars first, then fc vars.
    updated_embedding_vars = args[0:num_embedding_vars]
    updated_fc_vars = args[num_embedding_vars:num_embedding_vars +
                           num_fc_vars]
    support_embeddings = self.embedding_fn(
        data.support_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, updated_embedding_vars)),
        reuse=True)['embeddings']

    updated_fc_weights, updated_fc_bias = updated_fc_vars
    support_logits = tf.matmul(support_embeddings,
                               updated_fc_weights) + updated_fc_bias

    # Only the first `data.way` logits correspond to classes in this episode.
    support_logits = support_logits[:, 0:data.way]
    loss = tf.losses.softmax_cross_entropy(data.onehot_support_labels,
                                           support_logits)

    print_op = tf.no_op()
    if self.debug_log:
      print_op = tf.print(['step: ', step, updated_fc_bias[0], 'loss:', loss])

    # The control dependency forces the (optional) debug print to run before
    # the update ops of this iteration.
    with tf.control_dependencies([print_op]):
      updated_embedding_vars = gradient_descent_step(
          loss, updated_embedding_vars, self.first_order,
          self.adapt_batch_norm, self.alpha, False)['updated_vars']
      updated_fc_vars = gradient_descent_step(loss, updated_fc_vars,
                                              self.first_order,
                                              self.adapt_batch_norm,
                                              self.alpha,
                                              False)['updated_vars']

      step = step + 1
    return tuple([step] + list(updated_embedding_vars) +
                 list(updated_fc_vars))

  # MAML meta updates using query set examples from an episode.
  if self.zero_fc_layer:
    # To account for variable class sizes, we initialize the output
    # weights to zero. See if truncated normal initialization will help.
    zero_weights_op = tf.assign(fc_weights, tf.zeros_like(fc_weights))
    zero_bias_op = tf.assign(fc_bias, tf.zeros_like(fc_bias))
    fc_vars_init_ops = [zero_weights_op, zero_bias_op]
  else:
    fc_vars_init_ops = fc_vars_copy_ops

  if self.proto_maml_fc_layer_init:
    # NOTE(review): when not training, `embedding_vars` are the copy
    # variables, and this read sits outside the
    # `tf.control_dependencies(... embedding_vars_copy_ops)` block below.
    # Nothing here visibly forces the copy ops to run first. Confirm the
    # copies are populated before these prototypes are computed.
    support_embeddings = self.embedding_fn(
        data.support_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, embedding_vars)),
        reuse=True)['embeddings']

    prototypes = metric_learners.compute_prototypes(
        support_embeddings, data.onehot_support_labels)
    pmaml_fc_weights = self.proto_maml_fc_weights(
        prototypes, zero_pad_to_max_way=True)
    pmaml_fc_bias = self.proto_maml_fc_bias(
        prototypes, zero_pad_to_max_way=True)
    # The prototype-derived tensors replace the fc variables as the initial
    # inner-loop state.
    fc_vars = [pmaml_fc_weights, pmaml_fc_bias]

  # These control dependencies assign the value of each variable to a new copy
  # variable that corresponds to it. This is required at test time for
  # initilizing the copies as they are used in place of the original vars.
  with tf.control_dependencies(fc_vars_init_ops + embedding_vars_copy_ops):
    # Make step a local variable as we don't want to save and restore it.
    step = tf.Variable(
        0,
        trainable=False,
        name='inner_step_counter',
        collections=[tf.GraphKeys.LOCAL_VARIABLES])
    loop_vars = [step] + embedding_vars + fc_vars
    step_and_all_updated_vars = tf.while_loop(
        _cond, _body, loop_vars, swap_memory=True)
    step = step_and_all_updated_vars[0]
    all_updated_vars = step_and_all_updated_vars[1:]
    updated_embedding_vars = all_updated_vars[0:num_embedding_vars]
    updated_fc_weights, updated_fc_bias = all_updated_vars[
        num_embedding_vars:num_embedding_vars + num_fc_vars]

    # Forward pass the training images with the updated weights in order to
    # compute the means and variances, to use for the query's batch norm.
    support_set_moments = None
    if not self.transductive_batch_norm:
      support_set_moments = self.embedding_fn(
          data.support_images,
          self.is_training,
          params=collections.OrderedDict(
              zip(embedding_vars_keys, updated_embedding_vars)),
          reuse=True)['moments']

    query_embeddings = self.embedding_fn(
        data.query_images,
        self.is_training,
        params=collections.OrderedDict(
            zip(embedding_vars_keys, updated_embedding_vars)),
        moments=support_set_moments,  # Use support set stats for batch norm.
        reuse=True,
        backprop_through_moments=self.backprop_through_moments)['embeddings']

    query_logits = (tf.matmul(query_embeddings, updated_fc_weights) +
                    updated_fc_bias)[:, 0:data.way]

  return query_logits
def evaluation(model_class=None,
               input_fn=None,
               num_quantitative_examples=1000,
               num_qualitative_examples=50):
  """A function that build the model and eval quali.

  Polls `FLAGS.ckpt_dir` forever. Whenever the latest checkpoint belongs to
  an epoch that has not been evaluated yet, it is restored, and the model is
  run on `num_quantitative_examples` validation examples, with results
  reported through the custom TensorBoard callback.

  Args:
    model_class: zero-argument callable that builds the model.
    input_fn: callable `(is_training, batch_size) -> dataset`. It is called
      once with `is_training=False, batch_size=1`.
    num_quantitative_examples: number of validation batches evaluated per
      checkpoint.
    num_qualitative_examples: forwarded to the TensorBoard callback.
  """
  tensorboard_callback = callback_utils.CustomTensorBoard(
      log_dir=FLAGS.eval_dir,
      batch_update_freq=1,
      split=FLAGS.split,
      num_qualitative_examples=num_qualitative_examples,
      num_steps_per_epoch=FLAGS.num_steps_per_epoch)
  model = model_class()
  checkpoint = tf.train.Checkpoint(
      model=model,
      ckpt_saved_epoch=tf.Variable(initial_value=-1, dtype=tf.int64))
  val_inputs = input_fn(is_training=False, batch_size=1)
  num_evaluated_epoch = -1

  while True:
    ckpt_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir)
    if ckpt_path:
      # The epoch number is parsed from the checkpoint file name suffix
      # ('.../ckpt-<N>').
      ckpt_num_of_epoch = int(ckpt_path.split('/')[-1].split('-')[-1])
      if num_evaluated_epoch == ckpt_num_of_epoch:
        logging.info(
            'Found old epoch %d ckpt, skip and will check later.',
            num_evaluated_epoch)
        time.sleep(30)
        continue
      try:
        logging.info('Restoring new checkpoint[epoch:%d] at %s',
                     ckpt_num_of_epoch, ckpt_path)
        checkpoint.restore(ckpt_path)
      except tf.errors.NotFoundError:
        logging.info(
            'Restoring from checkpoint has failed. Maybe file missing. '
            'Try again now.')
        time.sleep(3)
        continue
    else:
      logging.info(
          'No checkpoint found at %s, will check again 10 s later..',
          FLAGS.ckpt_dir)
      time.sleep(10)
      continue

    tensorboard_callback.set_epoch_number(ckpt_num_of_epoch)
    logging.info('Start qualitative eval for %d steps...',
                 num_quantitative_examples)
    try:
      # TODO(huangrui): there is still possibility of crash due to
      # not found ckpt files.
      model._predict_counter.assign(0)  # pylint: disable=protected-access
      tensorboard_callback.set_model(model)
      tensorboard_callback.on_predict_begin()
      for i, inputs in enumerate(
          val_inputs.take(num_quantitative_examples), start=1):
        tensorboard_callback.on_predict_batch_begin(batch=i)
        outputs = model(inputs, training=False)
        model._predict_counter.assign_add(1)  # pylint: disable=protected-access
        tensorboard_callback.on_predict_batch_end(batch=i,
                                                  logs={
                                                      'outputs': outputs,
                                                      'inputs': inputs
                                                  })
        if i % FLAGS.num_steps_per_log == 0:
          logging.info('eval progress %d / %d...', i,
                       num_quantitative_examples)
      tensorboard_callback.on_predict_end()

      num_evaluated_epoch = ckpt_num_of_epoch
      logging.info('Finished eval for epoch %d, sleeping for :%d s...',
                   num_evaluated_epoch, 100)
      time.sleep(100)
    except tf.errors.NotFoundError:
      logging.info(
          'Restoring from checkpoint has failed. Maybe file missing. '
          'Try again now.')
      # Back off like the restore handler above instead of retrying in a
      # tight loop while checkpoint files are still missing.
      time.sleep(3)
      continue