예제 #1
0
def main(_):
    """Entry point: run inference, test mode, or training depending on FLAGS.

    - ``FLAGS.inference_mode``: requires an existing ``FLAGS.checkpoint_dir``;
      when it is missing the problem is reported and ``inference()`` is not
      called.
    - ``FLAGS.test_mode``: delegates to ``test_mode()``.
    - otherwise: prepares ``FLAGS.output_dir`` (creates it when missing, or
      deletes and recreates it when ``FLAGS.reset_output_dir`` is set),
      replaces ``sys.stdout`` with ``Logger(<output_dir>/stdout)``, prints the
      user flags and calls ``train()``.
    """
    print("-" * 80)

    if FLAGS.inference_mode:
        if not os.path.isdir(FLAGS.checkpoint_dir):
            # Report the directory that was actually checked, and stop here:
            # inference cannot proceed without a checkpoint directory.
            print("Path {} does not exist. need checkpoint dir to inference.".format(
                FLAGS.checkpoint_dir))
            return
        inference()
    elif FLAGS.test_mode:
        test_mode()
    else:
        if not os.path.isdir(FLAGS.output_dir):
            print("Path {} does not exist. Creating.".format(FLAGS.output_dir))
            os.makedirs(FLAGS.output_dir)
        elif FLAGS.reset_output_dir:
            print("Path {} exists. Remove and remake.".format(FLAGS.output_dir))
            shutil.rmtree(FLAGS.output_dir)
            os.makedirs(FLAGS.output_dir)

        print("-" * 80)
        log_file = os.path.join(FLAGS.output_dir, "stdout")
        print("Logging to {}".format(log_file))
        sys.stdout = Logger(log_file)

        utils.print_user_flags()
        train()
예제 #2
0
def main(_):
  """Prepare output/logging, then evaluate structures read from a file.

  ``FLAGS.output_dir`` is created when missing, or deleted and recreated when
  ``FLAGS.reset_output_dir`` is set; ``sys.stdout`` is replaced by
  ``Logger(<output_dir>/stdout)`` and the user flags are printed.

  Each non-blank line of ``FLAGS.structure_path`` is evaluated as a Python
  expression to obtain one structure. With an ``Eval`` instance, the first
  structure is evaluated, then the second, then the first again (the repeat
  presumably checks that re-evaluation is reproducible -- TODO confirm).
  Exits via ``sys.exit()`` when no structure path is given.
  """
  print("-" * 80)
  if not os.path.isdir(FLAGS.output_dir):
    print("Path {} does not exist. Creating.".format(FLAGS.output_dir))
    os.makedirs(FLAGS.output_dir)
  elif FLAGS.reset_output_dir:
    print("Path {} exists. Remove and remake.".format(FLAGS.output_dir))
    shutil.rmtree(FLAGS.output_dir)
    os.makedirs(FLAGS.output_dir)

  print("-" * 80)
  log_file = os.path.join(FLAGS.output_dir, "stdout")
  print("Logging to {}".format(log_file))
  sys.stdout = Logger(log_file)

  utils.print_user_flags()
  if not FLAGS.structure_path:
    # sys.exit() rather than the site-injected exit() builtin, which is not
    # guaranteed to exist (e.g. under ``python -S``).
    sys.exit()
  with open(FLAGS.structure_path, 'r') as fp:
    # NOTE(review): eval() executes arbitrary code from the file -- only use
    # with trusted structure files; ast.literal_eval would be safer if every
    # line is a plain literal. Blank lines are skipped (eval("") would raise).
    structures = [eval(line.strip()) for line in fp if line.strip()]

  if len(structures) < 2:
    print("Need at least 2 structures in {}, found {}.".format(
        FLAGS.structure_path, len(structures)))
    return

  eva = Eval()
  for idx in (0, 1, 0):
    eva.eval(structures[idx])
def main(_):
  """Train a model on ``FLAGS.data_path``, then evaluate it on the test split.

  ``FLAGS.output_dir`` is created when missing, or deleted and recreated when
  ``FLAGS.reset_output_dir`` is set. Training runs for ``hparams.train_steps``
  steps inside a ``SingularMonitoredSession`` whose ``checkpoint_dir`` is
  ``FLAGS.output_dir``. Every ``FLAGS.log_every`` steps the step, training
  loss and validation accuracy (printed as a count over
  ``hparams.eval_batch_size``) are logged. Finally ``ops["test_acc"]`` is
  summed over as many whole eval batches as fit in an assumed 10000-example
  test split, and the total is reported against the number of examples that
  were actually evaluated.
  """
  print("-" * 80)
  if not os.path.isdir(FLAGS.output_dir):
    print("Path {0} does not exist. Creating.".format(FLAGS.output_dir))
    os.makedirs(FLAGS.output_dir)
  elif FLAGS.reset_output_dir:
    print("Path {0} exists. Remove and remake.".format(FLAGS.output_dir))
    shutil.rmtree(FLAGS.output_dir)
    os.makedirs(FLAGS.output_dir)

  print_user_flags()

  hparams = Hparams()
  images, labels = read_data(FLAGS.data_path)

  g = tf.Graph()
  with g.as_default():
    ops = get_ops(images, labels)

    # count model variables (previously computed but never reported)
    tf_variables = tf.trainable_variables()
    num_params = count_model_params(tf_variables)
    print("Model has {0} trainable params".format(num_params))

    print("-" * 80)
    print("Starting session")
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.train.SingularMonitoredSession(
        config=config, checkpoint_dir=FLAGS.output_dir) as sess:

      # training loop
      print("-" * 80)
      print("Starting training")
      for step in range(1, hparams.train_steps + 1):
        sess.run(ops["train_op"])
        if step % FLAGS.log_every == 0:
          global_step, train_loss, valid_acc = sess.run([
            ops["global_step"],
            ops["train_loss"],
            ops["valid_acc"],
          ])
          log_string = ""
          log_string += "step={0:<6d}".format(step)
          log_string += " loss={0:<5.2f}".format(train_loss)
          log_string += " val_acc={0:<3d}/{1:<3d}".format(
            valid_acc, hparams.eval_batch_size)
          print(log_string)
          sys.stdout.flush()

      # final test: only whole batches are run, so the denominator is the
      # number of examples actually evaluated, not a fixed 10000.
      num_test_examples = 10000  # assumed test-split size (hard-coded before)
      num_test_batches = num_test_examples // hparams.eval_batch_size
      num_evaluated = num_test_batches * hparams.eval_batch_size
      print("-" * 80)
      print("Training done. Eval on TEST set")
      num_corrects = 0
      for _ in range(num_test_batches):
        num_corrects += sess.run(ops["test_acc"])
      print("test_accuracy: {0:>5d}/{1}".format(num_corrects, num_evaluated))
예제 #4
0
def main(_):
    """Run ``evaluate()`` if the checkpoint directory exists.

    When ``FLAGS.checkpoint_dir`` is not a directory, only a diagnostic is
    printed. Otherwise its location is reported, the user flags are printed
    and ``evaluate()`` is called.
    """
    print("-" * 80)

    ckpt_dir = FLAGS.checkpoint_dir
    if os.path.isdir(ckpt_dir):
        print("checkpoint exists. at {}".format(ckpt_dir))
        print("-" * 80)
        utils.print_user_flags()
        evaluate()
        return

    print("Path {} does not exist. Can't find checkpoint.".format(ckpt_dir))
예제 #5
0
def main(_):
    """Set up the output directory and stdout logging, then train.

    ``FLAGS.output_dir`` is created when missing; when it already exists and
    ``FLAGS.reset_output_dir`` is set it is deleted and recreated. Then
    ``sys.stdout`` is replaced by ``Logger(<output_dir>/stdout)``, the user
    flags are printed and ``train(mode="train")`` is called.
    """
    rule = "-" * 80
    print(rule)

    out_dir = FLAGS.output_dir
    if os.path.isdir(out_dir):
        if FLAGS.reset_output_dir:
            print("Path {} exists. Remove and remake.".format(out_dir))
            shutil.rmtree(out_dir)
            os.makedirs(out_dir)
    else:
        print("Path {} does not exist. Creating.".format(out_dir))
        os.makedirs(out_dir)

    print(rule)
    stdout_path = os.path.join(out_dir, "stdout")
    print("Logging to {}".format(stdout_path))
    sys.stdout = Logger(stdout_path)

    utils.print_user_flags()
    train(mode="train")
예제 #6
0
파일: evaluation.py 프로젝트: yyht/D-VAE
def Eval_NN():
  """Prepare the output directory and stdout logging, then return an ``Eval``.

  ``FLAGS.output_dir`` is created when missing, or deleted and recreated when
  ``FLAGS.reset_output_dir`` is set. ``sys.stdout`` is then replaced by
  ``Logger(<output_dir>/stdout)`` and the user flags are printed.

  Returns:
    A freshly constructed ``Eval()`` instance.
  """
  print("-" * 80)
  if not os.path.isdir(FLAGS.output_dir):
    print("Path {} does not exist. Creating.".format(FLAGS.output_dir))
    os.makedirs(FLAGS.output_dir)
  elif FLAGS.reset_output_dir:
    print("Path {} exists. Remove and remake.".format(FLAGS.output_dir))
    shutil.rmtree(FLAGS.output_dir)
    os.makedirs(FLAGS.output_dir)

  print("-" * 80)
  log_file = os.path.join(FLAGS.output_dir, "stdout")
  print("Logging to {}".format(log_file))
  sys.stdout = Logger(log_file)

  utils.print_user_flags()

  eva = Eval()
  return eva
예제 #7
0
def main(_):
    """Prepare output/logging, then run an architecture search or a fixed-arc run.

    ``FLAGS.output_dir`` is created when missing, or deleted and recreated when
    ``FLAGS.reset_output_dir`` is set. ``sys.stdout`` is replaced by
    ``Logger(<output_dir>/stdout)`` and the user flags are printed.

    When ``FLAGS.child_fixed_arc`` is None, ``train(i)`` is called for each
    layer count ``i`` from ``FLAGS.search_from`` to ``FLAGS.child_num_layers``
    (inclusive). Every key/value pair of the returned mapping is appended to
    ``<output_dir>/models.csv`` as a ``num_layers, accuracy, models_arc`` row,
    and the file is flushed after each layer count. Otherwise
    ``train(FLAGS.child_num_layers)`` is called once and its result discarded.
    """
    print("-" * 80)
    if not os.path.isdir(FLAGS.output_dir):
        print("Path {} does not exist. Creating.".format(FLAGS.output_dir))
        os.makedirs(FLAGS.output_dir)
    elif FLAGS.reset_output_dir:
        print("Path {} exists. Remove and remake.".format(FLAGS.output_dir))
        shutil.rmtree(FLAGS.output_dir)
        os.makedirs(FLAGS.output_dir)

    print("-" * 80)
    log_file = os.path.join(FLAGS.output_dir, "stdout")
    print("Logging to {}".format(log_file))
    sys.stdout = Logger(log_file)

    utils.print_user_flags()
    model_file = os.path.join(FLAGS.output_dir, "models.csv")

    if FLAGS.child_fixed_arc is not None:
        train(FLAGS.child_num_layers)
        return

    # The file is appended to, so only write the header when starting a fresh
    # (missing or empty) file; otherwise every re-run would insert a duplicate
    # header row in the middle of the existing data.
    write_header = (not os.path.exists(model_file)
                    or os.path.getsize(model_file) == 0)
    headers = ['num_layers', 'accuracy', 'models_arc']
    with open(model_file, 'a') as f:
        writer = csv.DictWriter(f,
                                headers,
                                delimiter=',',
                                lineterminator='\n')
        if write_header:
            writer.writeheader()
        for i in range(FLAGS.search_from, FLAGS.child_num_layers + 1):
            tf.compat.v1.logging.info(
                "Searching with constraint, num_layers: %d", i)
            map_task = train(i)
            # Keys go to the 'accuracy' column, values to 'models_arc'.
            for accuracy, arc in map_task.items():
                writer.writerow({
                    'num_layers': i,
                    'accuracy': accuracy,
                    'models_arc': arc
                })
            # Persist results after each layer count so partial searches survive.
            f.flush()
예제 #8
0
파일: main.py 프로젝트: Vincentcent1/enas
def _prepare_log_file(output_dir, name):
  """Return ``output_dir/name``, creating it as an empty file if it is missing.

  Opening in append mode creates the file without truncating existing content,
  and works on every platform (``os.mknod`` is Unix-only).
  """
  log_file = os.path.join(output_dir, name)
  if not os.path.exists(log_file):
    open(log_file, "a").close()
  return log_file


def _switch_log_file(log_file):
  """Point the installed stdout ``Logger``'s ``log`` handle at ``log_file``.

  The previous handle (if any) is closed after the swap so that each logging
  stage does not leak an open file.
  """
  previous = getattr(sys.stdout, "log", None)
  sys.stdout.log = open(log_file, "a")
  if previous is not None:
    previous.close()


def _load_pickle(path):
  """Load and return one pickled object from ``path`` (opened in binary mode)."""
  with open(path, "rb") as fp:
    return cPickle.load(fp)


def main(_):
  """Prepare output/logs, load pickled crops and labels in three stages, train.

  ``FLAGS.output_dir`` is created when missing, or deleted and recreated when
  ``FLAGS.reset_output_dir`` is set. Crops and labels are loaded from pickle
  files under a hard-coded data directory and added to an ``AutoTrain``
  instance in three batches. Stage output is logged to ``stdout1``, ``stdout2``
  and ``stdout3`` under ``FLAGS.output_dir``. Only the final stage calls
  ``train``; the intermediate calls are left commented out.
  """
  # Machine-specific location of the pickled data; adjust per environment.
  data_dir = "/home/yuwei/projects/vincent/pickleRick"

  # Prepare directory
  print("-" * 80)
  if not os.path.isdir(FLAGS.output_dir):
    print("Path {} does not exist. Creating.".format(FLAGS.output_dir))
    os.makedirs(FLAGS.output_dir)
  elif FLAGS.reset_output_dir:
    print("Path {} exists. Remove and remake.".format(FLAGS.output_dir))
    shutil.rmtree(FLAGS.output_dir)
    os.makedirs(FLAGS.output_dir)

  # Stage 1: redirect stdout to stdout1 ----------------------------------------
  print("-" * 80)
  log_file = _prepare_log_file(FLAGS.output_dir, "stdout1")
  print("Logging to {}".format(log_file))
  sys.stdout = Logger(log_file)

  utils.print_user_flags()

  print('Reserving gpu memory...')
  # Presumably makes TF claim GPU memory up front (per the message above);
  # the session object itself is not kept.
  tf.Session()

  print('Loading pickled file...')
  allCrops1 = _load_pickle(os.path.join(data_dir, "allCrops1.pkl"))
  allCrops2 = _load_pickle(os.path.join(data_dir, "allCrops2.pkl"))

  # labels.pkl holds six pickled objects stored back-to-back; read them in order.
  with open(os.path.join(data_dir, "labels.pkl"), "rb") as p_crop:
    labels1 = cPickle.load(p_crop)
    labels2 = cPickle.load(p_crop)
    labels3 = cPickle.load(p_crop)
    labels1_Brio1 = cPickle.load(p_crop)
    labels1_Brio2 = cPickle.load(p_crop)
    labels2_Brio1 = cPickle.load(p_crop)

  # Prepare and divide data
  autoTrainNN = AutoTrain()
  combined_1 = zip(allCrops1 + allCrops2, np.concatenate((labels1, labels2)))
  autoTrainNN.addLabelledData(combined_1)
  # train(autoTrainNN)

  # Stage 2: switch the log file to stdout2 ------------------------------------
  print("-" * 80)
  log_file = _prepare_log_file(FLAGS.output_dir, "stdout2")
  print("Logging to {}".format(log_file))
  _switch_log_file(log_file)

  print('Loading pickled file...')
  allCrops3 = _load_pickle(os.path.join(data_dir, "allCrops3.pkl"))
  allCrops1_Brio1 = _load_pickle(os.path.join(data_dir, "brio1", "allCrops1.pkl"))

  combined_2 = zip(allCrops3 + allCrops1_Brio1, np.concatenate((labels3, labels1_Brio1)))
  autoTrainNN.addLabelledData(combined_2)
  # train(autoTrainNN)

  # Stage 3: switch the log file to stdout3 ------------------------------------
  print("-" * 80)
  log_file = _prepare_log_file(FLAGS.output_dir, "stdout3")
  print("Logging to {}".format(log_file))
  _switch_log_file(log_file)

  utils.print_user_flags()

  print('Loading pickled file...')
  allCrops1_Brio2 = _load_pickle(os.path.join(data_dir, "brio2", "allCrops1.pkl"))
  allCrops2_Brio1 = _load_pickle(os.path.join(data_dir, "brio1", "allCrops2.pkl"))

  combined_3 = zip(allCrops1_Brio2 + allCrops2_Brio1, np.concatenate((labels1_Brio2, labels2_Brio1)))
  autoTrainNN.addLabelledData(combined_3)
  train(autoTrainNN)