Example #1
0
def main(argv=None):
    """Configure a TF session, create the log directory, and persist params."""
    # pylint: disable=unused-argument
    # pylint: disable=unused-variable
    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.per_process_gpu_memory_fraction = 1.0
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config)

    default_params = get_arguments()
    log_dir = default_params.log_dir
    log_path = pathlib.Path(log_dir)
    if not log_path.exists():
        log_path.mkdir(parents=True)

    # Round-trips the params through JSON storage so a saved config can be
    # recovered as a Namespace later.
    flags = Namespace(
        utils.load_and_save_params(vars(default_params), log_dir))
Example #2
0
def main(argv=None):
  """Entry point: configures TF, prepares the experiment dir, and trains."""
  # pylint: disable=unused-argument
  # pylint: disable=unused-variable
  tf_config = tf.ConfigProto(allow_soft_placement=True)
  tf_config.gpu_options.per_process_gpu_memory_fraction = 1.0
  tf_config.gpu_options.allow_growth = True
  sess = tf.Session(config=tf_config)

  # Parses command-line parameters.
  default_params = get_arguments()

  # Ensures the experiment directory exists.
  log_dir = default_params.log_dir
  exp_dir = pathlib.Path(log_dir)
  if not exp_dir.exists():
    exp_dir.mkdir(parents=True)

  # Persists the parameters (overwriting any existing file) and rebuilds a
  # Namespace from the stored values, then runs training and evaluation.
  flags = Namespace(utils.load_and_save_params(
      vars(default_params), log_dir, ignore_existing=True))
  train(flags=flags)
Example #3
0
  def __init__(self, model_path, batch_size, train_dataset, test_dataset):
    """Builds the evaluation graph and restores the latest training checkpoint.

    Args:
      model_path: Directory of a trained experiment; checkpoints are read from
        its 'train' subdirectory and the saved params JSON from its root.
      batch_size: Batch size used for BOTH the train and test placeholders.
      train_dataset: Dataset identifier/handle for the support (train) split.
      test_dataset: Dataset identifier/handle for the query (test) split.
    """
    self.train_batch_size = batch_size
    self.test_batch_size = batch_size
    self.test_dataset = test_dataset
    self.train_dataset = train_dataset

    # Most recent checkpoint under <model_path>/train; the global step is
    # parsed from its filename (expects the '<name>-<step>' checkpoint format).
    latest_checkpoint = tf.train.latest_checkpoint(
        checkpoint_dir=os.path.join(model_path, 'train'))
    print(latest_checkpoint)
    step = int(os.path.basename(latest_checkpoint).split('-')[1])
    # Reloads the flags the model was trained with from the experiment dir.
    flags = Namespace(
        utils.load_and_save_params(default_params=dict(), exp_dir=model_path))
    image_size = data_loader.get_image_size(flags.dataset)
    self.flags = flags

    # Builds the whole evaluation graph in a fresh default graph.
    with tf.Graph().as_default():
      self.tensor_images, self.tensor_labels = placeholder_inputs(
          batch_size=self.train_batch_size,
          image_size=image_size,
          scope='inputs')
      # Evaluation-time preprocessing only (is_training=False → no random aug).
      if flags.dataset == 'cifar10' or flags.dataset == 'cifar100':
        tensor_images_aug = data_loader.augment_cifar(
            self.tensor_images, is_training=False)
      else:
        tensor_images_aug = data_loader.augment_tinyimagenet(
            self.tensor_images, is_training=False)
      model = build_model(flags)
      # Scope must match the training-time scope so checkpoint vars resolve.
      with tf.variable_scope('Proto_training'):
        self.representation, self.variance = build_feature_extractor_graph(
            inputs=tensor_images_aug,
            flags=flags,
            is_variance=True,
            is_training=False,
            model=model)
      # Splits embeddings into support/query and computes per-class centers.
      self.tensor_train_rep, self.tensor_test_rep, \
      self.tensor_train_rep_label, self.tensor_test_rep_label,\
      self.center = get_class_center_for_evaluation(
          self.train_batch_size, self.test_batch_size, flags.num_classes_total)

      self.prediction, self.acc \
        = make_predictions_for_evaluation(self.center,
                                          self.tensor_test_rep,
                                          self.tensor_test_rep_label,
                                          self.flags)
      # NOTE(review): `feature_dim` is not defined in this method — presumably
      # a module-level constant matching the extractor's output width; verify.
      self.tensor_test_variance = tf.placeholder(
          shape=[self.test_batch_size, feature_dim], dtype=tf.float32)
      self.nll, self.confidence = confidence_estimation_and_evaluation(
          self.center, self.tensor_test_rep, self.tensor_test_variance,
          self.tensor_test_rep_label, flags)

      config = tf.ConfigProto(allow_soft_placement=True)
      config.gpu_options.allow_growth = True
      self.sess = tf.Session(config=config)
      # Runs init before loading the weights
      self.sess.run(tf.global_variables_initializer())
      # Loads weights
      saver = tf.train.Saver()
      saver.restore(self.sess, latest_checkpoint)
      self.flags = flags
      self.step = step
      log_dir = flags.log_dir
      # Dumps the graph definition for inspection/debugging; assumes the
      # <log_dir>/eval directory already exists — TODO confirm.
      graphpb_txt = str(tf.get_default_graph().as_graph_def())
      with open(os.path.join(log_dir, 'eval', 'graph.pbtxt'), 'w') as f:
        f.write(graphpb_txt)