def main(_):
    """Build the CNN train/valid models and run the mode selected by FLAGS.

    Everything is created inside a fresh default graph. After the variables
    are initialised, exactly one action runs: runtime tracing when
    ``FLAGS.trace`` is set, evaluation when ``FLAGS.test`` is set, and
    training otherwise.

    Args:
        _: Unused.
    """
    with tf.Graph().as_default():
        train_batch, test_batch, embeddings = base_reader.inputs()
        m_train, m_valid = cnn_model.build_train_valid_model(
            embeddings, train_batch, test_batch)

        # The saver name encodes the epoch count and the word dimension.
        m_train.set_saver('cnn-%d-%d' % (FLAGS.num_epochs, FLAGS.word_dim))

        # Local variables are initialised alongside the global ones
        # (original note: needed "for file queue").
        global_init = tf.global_variables_initializer()
        local_init = tf.local_variables_initializer()
        init_op = tf.group(global_init, local_init)

        # Let the GPU allocation grow on demand. A fixed cap could be set
        # instead via gpu_options.per_process_gpu_memory_fraction
        # (e.g. 0.9 to take 90% of the GPU memory).
        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True

        with tf.Session(config=session_config) as sess:
            sess.run(init_op)
            print('=' * 80)

            if FLAGS.trace:
                trace_runtime(sess, m_train)
                return
            if FLAGS.test:
                test(sess, m_valid)
                return
            train(sess, m_train, m_valid)
def main(_):
    """Build the CRNN train/valid models and run the mode selected by FLAGS.

    Everything is created inside a fresh default graph. After the variables
    are initialised, the session graph is written to ``logs/`` and exactly
    one action runs: runtime tracing when ``FLAGS.trace`` is set, evaluation
    when ``FLAGS.test`` is set, and training otherwise.

    Args:
        _: Unused.
    """
    with tf.Graph().as_default():
        train_data, test_data, word_embed = base_reader.inputs()
        m_train, m_valid = crnn_model.build_train_valid_model(
            word_embed, train_data, test_data)
        # The saver name encodes the epoch count and the word dimension.
        m_train.set_saver(
            'crnnatt-%d-%d' % (FLAGS.num_epochs, FLAGS.word_dim))

        # Local variables are initialised alongside the global ones
        # (original note: needed "for file queue").
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        config = tf.ConfigProto()
        # Let the GPU allocation grow on demand.
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            sess.run(init_op)
            print('=' * 80)

            # Dump the graph for inspection. The writer is closed in a
            # finally block so the event file is flushed and released even
            # if the selected action raises.
            writer = tf.summary.FileWriter("logs/", sess.graph)
            try:
                if FLAGS.trace:
                    trace_runtime(sess, m_train)
                elif FLAGS.test:
                    test(sess, m_valid)
                else:
                    train(sess, m_train, m_valid)
            finally:
                writer.close()
def main(_):
    """Build the model named by ``FLAGS.model`` and run the selected mode.

    Everything is created inside a fresh default graph. ``FLAGS.model``
    picks the model module (``'cnn'`` or ``'mtl'``) and is also used as the
    saver name. After the variables are initialised, exactly one action
    runs: runtime tracing when ``FLAGS.trace`` is set, evaluation when
    ``FLAGS.test`` is set, and training otherwise.

    Args:
        _: Unused.

    Raises:
        ValueError: If ``FLAGS.model`` is neither ``'cnn'`` nor ``'mtl'``.
    """
    with tf.Graph().as_default():
        train_data, test_data, word_embed = base_reader.inputs()

        if FLAGS.model == 'cnn':
            m_train, m_valid = cnn_model.build_train_valid_model(
                word_embed, train_data, test_data)
        elif FLAGS.model == 'mtl':
            m_train, m_valid = mtl_model.build_train_valid_model(
                word_embed, train_data, test_data)
        else:
            # Without this branch an unknown model name only surfaced later
            # as an UnboundLocalError on m_train.
            raise ValueError(
                "Unsupported FLAGS.model %r; expected 'cnn' or 'mtl'"
                % (FLAGS.model,))
        m_train.set_saver(FLAGS.model)

        # Local variables are initialised alongside the global ones
        # (original note: needed "for file queue").
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        config = tf.ConfigProto()
        # Let the GPU allocation grow on demand. A fixed cap could be set
        # instead via gpu_options.per_process_gpu_memory_fraction
        # (e.g. 0.9 to take 90% of the GPU memory).
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            sess.run(init_op)
            print('='*80)
            if FLAGS.trace:
                trace_runtime(sess, m_train)
            elif FLAGS.test:
                test(sess, m_valid)
            else:
                train(sess, m_train, m_valid)