def __init__(self, model_path, batch_size, train_dataset, test_dataset):
  """Builds the evaluation graph and restores the latest checkpoint from model_path."""
  self.train_batch_size = batch_size
  self.test_batch_size = batch_size
  self.test_dataset = test_dataset
  self.train_dataset = train_dataset

  latest_checkpoint = tf.train.latest_checkpoint(
      checkpoint_dir=os.path.join(model_path, 'train'))
  print(latest_checkpoint)
  step = int(os.path.basename(latest_checkpoint).split('-')[1])

  flags = Namespace(
      utils.load_and_save_params(default_params=dict(), exp_dir=model_path))
  image_size = data_loader.get_image_size(flags.dataset)
  self.flags = flags

  with tf.Graph().as_default():
    self.tensor_images, self.tensor_labels = placeholder_inputs(
        batch_size=self.train_batch_size,
        image_size=image_size,
        scope='inputs')
    if flags.dataset == 'cifar10' or flags.dataset == 'cifar100':
      tensor_images_aug = data_loader.augment_cifar(
          self.tensor_images, is_training=False)
    else:
      tensor_images_aug = data_loader.augment_tinyimagenet(
          self.tensor_images, is_training=False)

    model = build_model(flags)
    with tf.variable_scope('Proto_training'):
      self.representation, self.variance = build_feature_extractor_graph(
          inputs=tensor_images_aug,
          flags=flags,
          is_variance=True,
          is_training=False,
          model=model)

    # Class centers (prototypes), predictions, and confidence estimation.
    self.tensor_train_rep, self.tensor_test_rep, \
        self.tensor_train_rep_label, self.tensor_test_rep_label, \
        self.center = get_class_center_for_evaluation(
            self.train_batch_size, self.test_batch_size,
            flags.num_classes_total)
    self.prediction, self.acc = make_predictions_for_evaluation(
        self.center, self.tensor_test_rep, self.tensor_test_rep_label,
        self.flags)
    self.tensor_test_variance = tf.placeholder(
        shape=[self.test_batch_size, feature_dim], dtype=tf.float32)
    self.nll, self.confidence = confidence_estimation_and_evaluation(
        self.center, self.tensor_test_rep, self.tensor_test_variance,
        self.tensor_test_rep_label, flags)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    self.sess = tf.Session(config=config)

    # Runs init before loading the weights.
    self.sess.run(tf.global_variables_initializer())
    # Loads weights.
    saver = tf.train.Saver()
    saver.restore(self.sess, latest_checkpoint)

    self.flags = flags
    self.step = step
    log_dir = flags.log_dir
    graphpb_txt = str(tf.get_default_graph().as_graph_def())
    with open(os.path.join(log_dir, 'eval', 'graph.pbtxt'), 'w') as f:
      f.write(graphpb_txt)
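# Illustrative only (not part of the original evaluation code): a simplified
# NumPy version of the prototype logic that get_class_center_for_evaluation and
# make_predictions_for_evaluation implement in-graph above. Each class center
# is the mean of the support (train) representations of that class, and a test
# representation is assigned to the nearest center. The squared-Euclidean
# metric, shapes, and function name below are assumptions for illustration;
# the in-graph ops may differ (e.g. batching, similarity measure). Assumes
# NumPy is imported as `np`, as elsewhere in this module, and that every class
# index in [0, num_classes) appears at least once in train_labels.
def _illustrative_center_prediction(train_rep, train_labels, test_rep,
                                    num_classes):
  """Predicts test labels by nearest class center (prototype) in NumPy."""
  # Class centers: mean representation per class, [num_classes, feature_dim].
  centers = np.stack(
      [train_rep[train_labels == c].mean(axis=0) for c in range(num_classes)])
  # Squared Euclidean distance of every test point to every center.
  dists = ((test_rep[:, None, :] - centers[None, :, :])**2).sum(axis=-1)
  return dists.argmin(axis=1)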
def train(flags):
  """Training entry point."""
  log_dir = flags.log_dir
  flags.pretrained_model_dir = log_dir
  log_dir = os.path.join(log_dir, 'train')
  flags.eval_interval_secs = 0
  with tf.Graph().as_default():
    global_step = tf.Variable(
        0, trainable=False, name='global_step', dtype=tf.int64)
    global_step_confidence = tf.Variable(
        0, trainable=False, name='global_step_confidence', dtype=tf.int64)

    model = build_model(flags)
    images_query_pl, labels_query_pl, \
        images_support_pl, labels_support_pl = \
        build_episode_placeholder(flags)

    # Augments the input.
    if flags.dataset == 'cifar10' or flags.dataset == 'cifar100':
      images_query_pl_aug = data_loader.augment_cifar(
          images_query_pl, is_training=True)
      images_support_pl_aug = data_loader.augment_cifar(
          images_support_pl, is_training=True)
    elif flags.dataset == 'tinyimagenet':
      images_query_pl_aug = data_loader.augment_tinyimagenet(
          images_query_pl, is_training=True)
      images_support_pl_aug = data_loader.augment_tinyimagenet(
          images_support_pl, is_training=True)

    logits, logits_z = build_proto_train_graph(
        images_query=images_query_pl_aug,
        images_support=images_support_pl_aug,
        flags=flags,
        is_training=True,
        model=model)

    # Losses and optimizer
    ## Classification loss
    loss_classification = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            logits=logits,
            labels=tf.one_hot(labels_query_pl, flags.num_classes_train)))

    # Confidence loss
    _, top_k_indices = tf.nn.top_k(logits, k=1)
    pred = tf.squeeze(top_k_indices)
    incorrect_mask = tf.math.logical_not(tf.math.equal(pred, labels_query_pl))
    incorrect_logits_z = tf.boolean_mask(logits_z, incorrect_mask)
    incorrect_labels_z = tf.boolean_mask(labels_query_pl, incorrect_mask)
    signal_variance = tf.math.reduce_sum(tf.cast(incorrect_mask, tf.int32))
    loss_variance_incorrect = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            logits=incorrect_logits_z,
            labels=tf.one_hot(incorrect_labels_z, flags.num_classes_train)))
    loss_variance_zero = 0.0
    loss_confidence = tf.cond(
        tf.greater(signal_variance, 0),
        lambda: loss_variance_incorrect,
        lambda: loss_variance_zero)

    regu_losses = tf.losses.get_regularization_losses()
    loss = tf.add_n([loss_classification] + regu_losses)

    # Learning rate
    if flags.lr_anneal == 'const':
      learning_rate = flags.init_learning_rate
    elif flags.lr_anneal == 'pwc':
      learning_rate = get_pwc_learning_rate(global_step, flags)
    elif flags.lr_anneal == 'exp':
      lr_decay_step = flags.number_of_steps // flags.n_lr_decay
      learning_rate = tf.train.exponential_decay(
          flags.init_learning_rate,
          global_step,
          lr_decay_step,
          1.0 / flags.lr_decay_rate,
          staircase=True)
    else:
      raise Exception('Not implemented')

    # Optimizer
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=0.9)
    optimizer_confidence = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=0.9)

    train_op = contrib_slim.learning.create_train_op(
        total_loss=loss,
        optimizer=optimizer,
        global_step=global_step,
        clip_gradient_norm=flags.clip_gradient_norm)
    variable_variance = []
    for v in tf.trainable_variables():
      if 'fc_variance' in v.name:
        variable_variance.append(v)
    train_op_confidence = contrib_slim.learning.create_train_op(
        total_loss=loss_confidence,
        optimizer=optimizer_confidence,
        global_step=global_step_confidence,
        clip_gradient_norm=flags.clip_gradient_norm,
        variables_to_train=variable_variance)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('loss_classification', loss_classification)
    tf.summary.scalar('loss_variance', loss_confidence)
    tf.summary.scalar('regu_loss', tf.add_n(regu_losses))
    tf.summary.scalar('learning_rate', learning_rate)
    # Merges all summaries except for pretrain
    summary = tf.summary.merge(
        tf.get_collection('summaries', scope='(?!pretrain).*'))

    # Gets datasets
    few_shot_data_train, test_dataset, train_dataset = get_train_datasets(flags)
    # Defines session and logging
    summary_writer_train = tf.summary.FileWriter(log_dir, flush_secs=1)
    saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)
    print(saver.saver_def.filename_tensor_name)
    print(saver.saver_def.restore_op_name)
    # pylint: disable=unused-variable
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    supervisor = tf.train.Supervisor(
        logdir=log_dir,
        init_feed_dict=None,
        summary_op=None,
        init_op=tf.global_variables_initializer(),
        summary_writer=summary_writer_train,
        saver=saver,
        global_step=global_step,
        save_summaries_secs=flags.save_summaries_secs,
        save_model_secs=0)

    with supervisor.managed_session() as sess:
      checkpoint_step = sess.run(global_step)
      if checkpoint_step > 0:
        checkpoint_step += 1

      eval_interval_steps = flags.eval_interval_steps
      for step in range(checkpoint_step, flags.number_of_steps):
        # Computes the classification loss using a batch of data.
        images_query, labels_query, \
            images_support, labels_support = \
            few_shot_data_train.next_few_shot_batch(
                query_batch_size_per_task=flags.train_batch_size,
                num_classes_per_task=flags.num_classes_train,
                num_supports_per_class=flags.num_shots_train,
                num_tasks=flags.num_tasks_per_batch)

        feed_dict = {
            images_query_pl: images_query.astype(dtype=np.float32),
            labels_query_pl: labels_query,
            images_support_pl: images_support.astype(dtype=np.float32),
            labels_support_pl: labels_support
        }

        t_batch = time.time()
        dt_batch = time.time() - t_batch

        t_train = time.time()
        loss, loss_confidence = sess.run(
            [train_op, train_op_confidence], feed_dict=feed_dict)
        dt_train = time.time() - t_train

        if step % 100 == 0:
          summary_str = sess.run(summary, feed_dict=feed_dict)
          summary_writer_train.add_summary(summary_str, step)
          summary_writer_train.flush()
          logging.info('step %d, loss : %.4g, dt: %.3gs, dt_batch: %.3gs',
                       step, loss, dt_train, dt_batch)

        if float(step) / flags.number_of_steps > 0.5:
          eval_interval_steps = flags.eval_interval_fine_steps

        if eval_interval_steps > 0 and step % eval_interval_steps == 0:
          saver.save(sess, os.path.join(log_dir, 'model'), global_step=step)
          eval(
              flags=flags,
              train_dataset=train_dataset,
              test_dataset=test_dataset)

        if float(step) > (0.5 * flags.number_of_steps +
                          flags.number_of_steps_to_early_stop):
          break
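# Illustrative only (not part of the original training code): a NumPy sketch of
# the confidence-loss masking used in train() above. Only the queries that the
# classification head misclassifies contribute to the loss that trains the
# variance (confidence) head; when every prediction is correct the loss is
# zero, mirroring the tf.cond branch on signal_variance. The function name and
# shapes below are assumptions for illustration. Assumes NumPy is imported as
# `np`, as elsewhere in this module.
def _illustrative_confidence_loss(logits, logits_z, labels):
  """Softmax cross-entropy on the variance head over misclassified queries."""
  num_classes = logits.shape[1]
  pred = logits.argmax(axis=1)
  incorrect_mask = pred != labels
  if not incorrect_mask.any():
    return 0.0  # No misclassified queries: mirrors the zero branch of tf.cond.
  # Log-softmax of the variance-head logits on the misclassified queries only.
  z = logits_z[incorrect_mask]
  z = z - z.max(axis=1, keepdims=True)
  log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
  one_hot = np.eye(num_classes)[labels[incorrect_mask]]
  return float(-(one_hot * log_probs).sum(axis=1).mean())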