Example #1
0
  def __init__(self, model_path, batch_size, train_dataset, test_dataset):
    """Builds the evaluation graph and restores the latest checkpoint.

    Args:
      model_path: Experiment directory. The latest checkpoint is looked up in
        its 'train' subdirectory, and the experiment parameters are loaded
        from it through `utils.load_and_save_params`.
      batch_size: Batch size used for both the train and test batch sizes.
      train_dataset: Training dataset; stored on the instance.
      test_dataset: Test dataset; stored on the instance.

    Raises:
      ValueError: If no checkpoint is found under `<model_path>/train`.
    """
    self.train_batch_size = batch_size
    self.test_batch_size = batch_size
    self.test_dataset = test_dataset
    self.train_dataset = train_dataset

    checkpoint_dir = os.path.join(model_path, 'train')
    latest_checkpoint = tf.train.latest_checkpoint(
        checkpoint_dir=checkpoint_dir)
    print(latest_checkpoint)
    if latest_checkpoint is None:
      # latest_checkpoint() returns None when the directory has no checkpoint;
      # fail here with a clear message instead of inside os.path.basename.
      raise ValueError('No checkpoint found in %s' % checkpoint_dir)
    # The step is the suffix after the LAST '-', so a checkpoint prefix that
    # itself contains a hyphen is still parsed correctly.
    step = int(os.path.basename(latest_checkpoint).rsplit('-', 1)[-1])
    flags = Namespace(
        utils.load_and_save_params(default_params=dict(), exp_dir=model_path))
    image_size = data_loader.get_image_size(flags.dataset)
    self.flags = flags
    self.step = step

    with tf.Graph().as_default():
      self.tensor_images, self.tensor_labels = placeholder_inputs(
          batch_size=self.train_batch_size,
          image_size=image_size,
          scope='inputs')
      # Any dataset other than cifar10/cifar100 goes through the tinyimagenet
      # augmentation path.
      if flags.dataset == 'cifar10' or flags.dataset == 'cifar100':
        tensor_images_aug = data_loader.augment_cifar(
            self.tensor_images, is_training=False)
      else:
        tensor_images_aug = data_loader.augment_tinyimagenet(
            self.tensor_images, is_training=False)
      model = build_model(flags)
      with tf.variable_scope('Proto_training'):
        self.representation, self.variance = build_feature_extractor_graph(
            inputs=tensor_images_aug,
            flags=flags,
            is_variance=True,
            is_training=False,
            model=model)
      (self.tensor_train_rep, self.tensor_test_rep,
       self.tensor_train_rep_label, self.tensor_test_rep_label,
       self.center) = get_class_center_for_evaluation(
           self.train_batch_size, self.test_batch_size,
           flags.num_classes_total)

      self.prediction, self.acc = make_predictions_for_evaluation(
          self.center, self.tensor_test_rep, self.tensor_test_rep_label,
          self.flags)
      # NOTE(review): `feature_dim` is not defined in this method; presumably
      # a module-level name -- TODO confirm.
      self.tensor_test_variance = tf.placeholder(
          shape=[self.test_batch_size, feature_dim], dtype=tf.float32)
      self.nll, self.confidence = confidence_estimation_and_evaluation(
          self.center, self.tensor_test_rep, self.tensor_test_variance,
          self.tensor_test_rep_label, flags)

      config = tf.ConfigProto(allow_soft_placement=True)
      config.gpu_options.allow_growth = True
      self.sess = tf.Session(config=config)
      # Runs init before loading the weights
      self.sess.run(tf.global_variables_initializer())
      # Loads weights
      saver = tf.train.Saver()
      saver.restore(self.sess, latest_checkpoint)

      # Dumps the graph definition as text. open() does not create missing
      # directories, so `<log_dir>/eval` must already exist.
      log_dir = flags.log_dir
      graphpb_txt = str(tf.get_default_graph().as_graph_def())
      with open(os.path.join(log_dir, 'eval', 'graph.pbtxt'), 'w') as f:
        f.write(graphpb_txt)
Example #2
0
def train(flags):
  """Training entry point.

  Builds the training graph with two train ops: one minimizing the
  classification loss plus regularization losses, and one minimizing the
  confidence loss over only the trainable variables whose name contains
  'fc_variance'. Then runs the training loop, periodically writing summaries,
  saving checkpoints and calling `eval`.

  Args:
    flags: Experiment configuration. This function mutates it:
      `pretrained_model_dir` is set to `flags.log_dir` and
      `eval_interval_secs` is set to 0.

  Raises:
    ValueError: If `flags.dataset` is not 'cifar10', 'cifar100' or
      'tinyimagenet'.
    NotImplementedError: If `flags.lr_anneal` is not 'const', 'pwc' or 'exp'.
  """
  log_dir = flags.log_dir
  flags.pretrained_model_dir = log_dir
  log_dir = os.path.join(log_dir, 'train')
  flags.eval_interval_secs = 0
  with tf.Graph().as_default():
    global_step = tf.Variable(
        0, trainable=False, name='global_step', dtype=tf.int64)
    global_step_confidence = tf.Variable(
        0, trainable=False, name='global_step_confidence', dtype=tf.int64)

    model = build_model(flags)
    (images_query_pl, labels_query_pl,
     images_support_pl, labels_support_pl) = build_episode_placeholder(flags)

    # Augments the input. An unsupported dataset is rejected here rather than
    # surfacing later as an unbound-variable error.
    if flags.dataset in ('cifar10', 'cifar100'):
      augment_fn = data_loader.augment_cifar
    elif flags.dataset == 'tinyimagenet':
      augment_fn = data_loader.augment_tinyimagenet
    else:
      raise ValueError('Unsupported dataset: %s' % flags.dataset)
    images_query_pl_aug = augment_fn(images_query_pl, is_training=True)
    images_support_pl_aug = augment_fn(images_support_pl, is_training=True)

    logits, logits_z = build_proto_train_graph(
        images_query=images_query_pl_aug,
        images_support=images_support_pl_aug,
        flags=flags,
        is_training=True,
        model=model)
    # Losses and optimizer
    ## Classification loss
    loss_classification = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            logits=logits,
            labels=tf.one_hot(labels_query_pl, flags.num_classes_train)))

    # Confidence loss: cross entropy of `logits_z` restricted to the query
    # examples that `logits` misclassifies.
    _, top_k_indices = tf.nn.top_k(logits, k=1)
    # Squeeze only the trailing k=1 axis so that a size-1 batch dimension is
    # never dropped.
    pred = tf.squeeze(top_k_indices, axis=-1)
    incorrect_mask = tf.math.logical_not(tf.math.equal(pred, labels_query_pl))
    incorrect_logits_z = tf.boolean_mask(logits_z, incorrect_mask)
    incorrect_labels_z = tf.boolean_mask(labels_query_pl, incorrect_mask)
    signal_variance = tf.math.reduce_sum(tf.cast(incorrect_mask, tf.int32))
    loss_variance_incorrect = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            logits=incorrect_logits_z,
            labels=tf.one_hot(incorrect_labels_z, flags.num_classes_train)))
    loss_variance_zero = 0.0
    # Falls back to a zero loss when every query example is classified
    # correctly (the masked tensors are then empty).
    loss_confidence = tf.cond(
        tf.greater(signal_variance, 0), lambda: loss_variance_incorrect,
        lambda: loss_variance_zero)

    regu_losses = tf.losses.get_regularization_losses()
    loss = tf.add_n([loss_classification] + regu_losses)

    # Learning rate
    if flags.lr_anneal == 'const':
      learning_rate = flags.init_learning_rate
    elif flags.lr_anneal == 'pwc':
      learning_rate = get_pwc_learning_rate(global_step, flags)
    elif flags.lr_anneal == 'exp':
      lr_decay_step = flags.number_of_steps // flags.n_lr_decay
      learning_rate = tf.train.exponential_decay(
          flags.init_learning_rate,
          global_step,
          lr_decay_step,
          1.0 / flags.lr_decay_rate,
          staircase=True)
    else:
      raise NotImplementedError('Not implemented')

    # Optimizer
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=0.9)
    optimizer_confidence = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=0.9)

    train_op = contrib_slim.learning.create_train_op(
        total_loss=loss,
        optimizer=optimizer,
        global_step=global_step,
        clip_gradient_norm=flags.clip_gradient_norm)
    # The confidence loss only updates the 'fc_variance' variables.
    variable_variance = [
        v for v in tf.trainable_variables() if 'fc_variance' in v.name
    ]
    train_op_confidence = contrib_slim.learning.create_train_op(
        total_loss=loss_confidence,
        optimizer=optimizer_confidence,
        global_step=global_step_confidence,
        clip_gradient_norm=flags.clip_gradient_norm,
        variables_to_train=variable_variance)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('loss_classification', loss_classification)
    tf.summary.scalar('loss_variance', loss_confidence)
    # tf.add_n rejects an empty list, so report zero when the model registers
    # no regularization losses.
    regu_loss = tf.add_n(regu_losses) if regu_losses else tf.constant(0.0)
    tf.summary.scalar('regu_loss', regu_loss)
    tf.summary.scalar('learning_rate', learning_rate)
    # Merges all summaries except for pretrain
    summary = tf.summary.merge(
        tf.get_collection('summaries', scope='(?!pretrain).*'))

    # Gets datasets
    few_shot_data_train, test_dataset, train_dataset = get_train_datasets(flags)
    # Defines session and logging
    summary_writer_train = tf.summary.FileWriter(log_dir, flush_secs=1)
    saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)
    print(saver.saver_def.filename_tensor_name)
    print(saver.saver_def.restore_op_name)
    supervisor = tf.train.Supervisor(
        logdir=log_dir,
        init_feed_dict=None,
        summary_op=None,
        init_op=tf.global_variables_initializer(),
        summary_writer=summary_writer_train,
        saver=saver,
        global_step=global_step,
        save_summaries_secs=flags.save_summaries_secs,
        save_model_secs=0)

    with supervisor.managed_session() as sess:
      # When the restored global step is non-zero, resume one step past it.
      checkpoint_step = sess.run(global_step)
      if checkpoint_step > 0:
        checkpoint_step += 1
      eval_interval_steps = flags.eval_interval_steps
      for step in range(checkpoint_step, flags.number_of_steps):
        # Samples a few-shot batch; dt_batch measures the sampling time.
        t_batch = time.time()
        (images_query, labels_query,
         images_support, labels_support) = (
             few_shot_data_train.next_few_shot_batch(
                 query_batch_size_per_task=flags.train_batch_size,
                 num_classes_per_task=flags.num_classes_train,
                 num_supports_per_class=flags.num_shots_train,
                 num_tasks=flags.num_tasks_per_batch))
        dt_batch = time.time() - t_batch

        feed_dict = {
            images_query_pl: images_query.astype(dtype=np.float32),
            labels_query_pl: labels_query,
            images_support_pl: images_support.astype(dtype=np.float32),
            labels_support_pl: labels_support
        }

        # Runs both train ops in a single call. The results get their own
        # names so the `loss` / `loss_confidence` tensors are not shadowed.
        t_train = time.time()
        loss_value, _ = sess.run([train_op, train_op_confidence],
                                 feed_dict=feed_dict)
        dt_train = time.time() - t_train

        if step % 100 == 0:
          summary_str = sess.run(summary, feed_dict=feed_dict)
          summary_writer_train.add_summary(summary_str, step)
          summary_writer_train.flush()
          logging.info('step %d, loss : %.4g, dt: %.3gs, dt_batch: %.3gs', step,
                       loss_value, dt_train, dt_batch)

        # Switches to the fine evaluation interval in the second half of
        # training.
        if float(step) / flags.number_of_steps > 0.5:
          eval_interval_steps = flags.eval_interval_fine_steps

        if eval_interval_steps > 0 and step % eval_interval_steps == 0:
          saver.save(sess, os.path.join(log_dir, 'model'), global_step=step)
          # NOTE(review): `eval` is presumably a module-level evaluation
          # function shadowing the builtin -- confirm.
          eval(
              flags=flags,
              train_dataset=train_dataset,
              test_dataset=test_dataset)

        # Stops once `number_of_steps_to_early_stop` steps have run past the
        # halfway point.
        if float(
            step
        ) > 0.5 * flags.number_of_steps + flags.number_of_steps_to_early_stop:
          break