Esempio n. 1
0
def main(argv=None):
  """Restore the latest checkpoint and write per-image scores for the test set.

  Builds the score ops via ``_step`` on the ``'test'`` split of a
  testing-mode ``Config``. It restores the most recent checkpoint found in
  ``config.LOG_DIR``, or initialises all global variables when none exists.
  It then runs ``len(dataset_test.name_list) // config.BATCH_SIZE`` batches
  and writes one comma-separated line per image to
  ``<LOG_DIR>/test/score.txt``::

      <image name>,<M>,<s>,<b>,<C>,<T>

  Each score is formatted with three decimal places.

  Args:
    argv: Unused by this function. NOTE(review): presumably accepted so it can
      serve as a ``tf.app.run``-style entry point; confirm with the caller.
  """
  import os  # local import keeps this block self-contained

  # Configurations
  config = Config(gpu='1',
                  root_dir='./data/test/',
                  root_dir_val=None,
                  mode='testing')
  config.BATCH_SIZE = 1

  # Get images and labels.
  dataset_test = Dataset(config, 'test')

  # Build the ops that yield the scores and image names for one batch.
  # The third argument is presumably an is-training flag (False here);
  # TODO confirm against _step's definition.
  _M, _s, _b, _C, _T, _imname = _step(config, dataset_test, False)

  # Add ops to save and restore all the variables.
  saver = tf.train.Saver(max_to_keep=50,)
  with tf.Session(config=config.GPU_CONFIG) as sess:
    # Restore the model
    ckpt = tf.train.get_checkpoint_state(config.LOG_DIR)
    if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)
      # The epoch number is taken to be the text after the last '-' in the
      # checkpoint file name.
      ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
      last_epoch = ckpt_name.split('-')[-1]
      print('**********************************************************')
      print('Restore from Epoch '+str(last_epoch))
      print('**********************************************************')
    else:
      # NOTE(review): with no checkpoint, the test set is scored with freshly
      # initialised (untrained) weights. Consider failing loudly instead.
      init = tf.initializers.global_variables()
      last_epoch = 0
      sess.run(init)
      print('**********************************************************')
      print('Train from scratch.')
      print('**********************************************************')

    # Floor division means a trailing partial batch is not evaluated. With
    # BATCH_SIZE == 1, every image is covered.
    step_per_epoch = len(dataset_test.name_list) // config.BATCH_SIZE

    # open(..., 'w') does not create missing parent directories. Make sure
    # <LOG_DIR>/test exists so the run does not crash after all the setup.
    out_dir = os.path.join(config.LOG_DIR, 'test')
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'score.txt'), 'w') as f:
      for step in range(step_per_epoch):
        M, s, b, C, T, imname = sess.run([_M, _s, _b, _C, _T, _imname])
        # save the score: one "name,M,s,b,C,T" line per image in the batch
        for i in range(config.BATCH_SIZE):
          _name = imname[i].decode('UTF-8')
          scores = ['{0:.3f}'.format(v[i]) for v in (M, s, b, C, T)]
          _line = ','.join([_name] + scores)
          f.write(_line + '\n')
          # end='\r' redraws the progress line in place on the terminal.
          print(str(step+1)+'/'+str(step_per_epoch)+':'+_line, end='\r')
    print("\n")
Esempio n. 2
0
                       "/home/umit/xDataset/deepFake-dat/Train_Live_Much_3",
                       "/home/umit/xDataset/deepFake-dat/Train_Fake_Much_4",
                       "/home/umit/xDataset/deepFake-dat/Train_Live_Much_4",
                       "/home/umit/xDataset/deepFake-dat/Train_Fake_Much_5",
                       "/home/umit/xDataset/deepFake-dat/Train_Live_Much_5",
                       "/home/umit/xDataset/deepFake-dat/Train_Fake_Much_6"
                       "/home/umit/xDataset/deepFake-dat/Train_Live_Much_6",
                       "/home/umit/xDataset/deepFake-dat/Train_Fake_Much_7"]
    

    config.LOG_DIR = './log/model'
    config.MODE = 'training'
    config.STEPS_PER_EPOCH = 2000
    config.MAX_EPOCH = 1000
    config.LEARNING_RATE = 0.00001 #0.00005 #0.0001 #0.0005 #0.001
    config.BATCH_SIZE = 20
    # Validation
    config.DATA_DIR_VAL = ["/home/umit/xDataset/deepFake-dat/Train_Fake_Much_1",
                           "/home/umit/xDataset/deepFake-dat/Train_Live_Few_1"]
    config.STEPS_PER_EPOCH_VAL = 500
   
    config.display()

    # Get images and labels.
    dataset_train = Dataset(config,'train')
    #dataset_validation = Dataset(config,'validation')
    
    # Build a Graph
    model = Model(config)

    # # Train the model