def build_episode_placeholder(flags):
  """Creates the query and support input placeholders for one episode batch.

  Args:
    flags: Namespace with `dataset`, `num_tasks_per_batch`,
      `train_batch_size`, `num_classes_train` and `num_shots_train`.

  Returns:
    A tuple (images_query_pl, labels_query_pl, images_support_pl,
    labels_support_pl) of input placeholders.
  """
  image_size = data_loader.get_image_size(flags.dataset)
  # Query set: one fixed-size batch per task.
  query_batch = flags.num_tasks_per_batch * flags.train_batch_size
  # Support set: num_classes x num_shots examples per task.
  support_batch = (flags.num_tasks_per_batch * flags.num_classes_train *
                   flags.num_shots_train)
  images_query_pl, labels_query_pl = placeholder_inputs(
      batch_size=query_batch, image_size=image_size, scope='inputs/query')
  images_support_pl, labels_support_pl = placeholder_inputs(
      batch_size=support_batch, image_size=image_size, scope='inputs/support')
  return (images_query_pl, labels_query_pl,
          images_support_pl, labels_support_pl)
def __init__(self, model_path, batch_size, train_dataset, test_dataset):
  """Restores a trained model checkpoint and builds the evaluation graph.

  Loads the experiment params saved under `model_path`, rebuilds the
  feature-extractor / prototype graph for evaluation (no training ops),
  restores the latest checkpoint into a fresh session, and dumps the
  graph definition to `<log_dir>/eval/graph.pbtxt`.

  Args:
    model_path: Experiment directory; checkpoints are read from its
      'train' subdirectory and params via `utils.load_and_save_params`.
    batch_size: Batch size used for both the train and test placeholders.
    train_dataset: Dataset identifier used to compute class centers.
    test_dataset: Dataset identifier used for prediction/confidence eval.

  Raises:
    ValueError: If no checkpoint is found under `model_path`/train.
  """
  self.train_batch_size = batch_size
  self.test_batch_size = batch_size
  self.test_dataset = test_dataset
  self.train_dataset = train_dataset

  checkpoint_dir = os.path.join(model_path, 'train')
  latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_dir)
  if latest_checkpoint is None:
    # tf.train.latest_checkpoint returns None when nothing is found; fail
    # early with a clear message instead of a TypeError further down.
    raise ValueError('No checkpoint found in %s' % checkpoint_dir)
  print(latest_checkpoint)
  # The global step is encoded in the checkpoint name, e.g. 'model.ckpt-1234'.
  step = int(os.path.basename(latest_checkpoint).split('-')[1])
  flags = Namespace(
      utils.load_and_save_params(default_params=dict(), exp_dir=model_path))
  image_size = data_loader.get_image_size(flags.dataset)
  self.flags = flags

  with tf.Graph().as_default():
    self.tensor_images, self.tensor_labels = placeholder_inputs(
        batch_size=self.train_batch_size,
        image_size=image_size,
        scope='inputs')
    # Evaluation-time preprocessing only (is_training=False).
    if flags.dataset == 'cifar10' or flags.dataset == 'cifar100':
      tensor_images_aug = data_loader.augment_cifar(
          self.tensor_images, is_training=False)
    else:
      tensor_images_aug = data_loader.augment_tinyimagenet(
          self.tensor_images, is_training=False)
    model = build_model(flags)
    with tf.variable_scope('Proto_training'):
      self.representation, self.variance = build_feature_extractor_graph(
          inputs=tensor_images_aug,
          flags=flags,
          is_variance=True,
          is_training=False,
          model=model)
    self.tensor_train_rep, self.tensor_test_rep, \
        self.tensor_train_rep_label, self.tensor_test_rep_label, \
        self.center = get_class_center_for_evaluation(
            self.train_batch_size, self.test_batch_size,
            flags.num_classes_total)
    self.prediction, self.acc = make_predictions_for_evaluation(
        self.center, self.tensor_test_rep, self.tensor_test_rep_label,
        self.flags)
    # NOTE(review): `feature_dim` is a module-level name — presumably the
    # extractor's output width; confirm at its definition site.
    self.tensor_test_variance = tf.placeholder(
        shape=[self.test_batch_size, feature_dim], dtype=tf.float32)
    self.nll, self.confidence = confidence_estimation_and_evaluation(
        self.center, self.tensor_test_rep, self.tensor_test_variance,
        self.tensor_test_rep_label, flags)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    self.sess = tf.Session(config=config)
    # Variables must be initialized before Saver.restore overwrites them
    # with the checkpoint values.
    self.sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(self.sess, latest_checkpoint)

    self.step = step
    log_dir = flags.log_dir
    # Dump the graph definition for offline inspection; done inside the
    # graph context so tf.get_default_graph() captures the built graph.
    graphpb_txt = str(tf.get_default_graph().as_graph_def())
    with open(os.path.join(log_dir, 'eval', 'graph.pbtxt'), 'w') as f:
      f.write(graphpb_txt)