Ejemplo n.º 1
0
def main(_):
    """Build the CNN train/validation models and run the mode chosen by FLAGS.

    The positional argument is ignored (presumably the ``tf.app.run``
    entry-point convention -- TODO confirm).

    Everything happens inside a fresh default graph and a single session.
    After initialization exactly one mode runs:
    ``FLAGS.trace`` -> ``trace_runtime``; otherwise ``FLAGS.test`` -> ``test``;
    otherwise ``train``.
    """
    graph = tf.Graph()
    with graph.as_default():
        train_data, test_data, word_embed = base_reader.inputs()
        m_train, m_valid = cnn_model.build_train_valid_model(
            word_embed, train_data, test_data)

        # Saver tag encodes the epoch count and word-embedding dimension.
        saver_tag = 'cnn-%d-%d' % (FLAGS.num_epochs, FLAGS.word_dim)
        m_train.set_saver(saver_tag)

        # Global initializer is created first, then the local one (needed by
        # the input file queue); keep this order so op names stay stable.
        global_init = tf.global_variables_initializer()
        local_init = tf.local_variables_initializer()
        init_op = tf.group(global_init, local_init)

        # Grow GPU memory on demand instead of reserving it all up front.
        # An alternative is capping usage at 90% of GPU memory via
        # gpu_options.per_process_gpu_memory_fraction = 0.9.
        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True

        with tf.Session(config=session_config) as sess:
            sess.run(init_op)
            print('=' * 80)

            if FLAGS.trace:
                trace_runtime(sess, m_train)
                return
            if FLAGS.test:
                test(sess, m_valid)
                return
            train(sess, m_train, m_valid)
def main(_):
	"""Build the CRNN train/validation models and run the mode chosen by FLAGS.

	The positional argument is ignored (presumably the ``tf.app.run``
	entry-point convention -- TODO confirm).

	Inside a fresh default graph and a single session this initializes all
	global and local variables, writes the session graph to ``logs/`` for
	TensorBoard, then runs exactly one mode:
	``FLAGS.trace`` -> ``trace_runtime``; otherwise ``FLAGS.test`` -> ``test``;
	otherwise ``train``. The summary writer is always closed afterwards.
	"""
	with tf.Graph().as_default():
		train_data, test_data, word_embed = base_reader.inputs()

		m_train, m_valid = crnn_model.build_train_valid_model(
			word_embed, train_data, test_data)

		# Saver tag encodes the epoch count and word-embedding dimension.
		m_train.set_saver('crnnatt-%d-%d' % (FLAGS.num_epochs, FLAGS.word_dim))

		init_op = tf.group(tf.global_variables_initializer(),
						   tf.local_variables_initializer())  # for file queue

		# Grow GPU memory on demand instead of reserving it all up front.
		config = tf.ConfigProto()
		config.gpu_options.allow_growth = True

		with tf.Session(config=config) as sess:
			sess.run(init_op)
			print('=' * 80)
			# Write the graph for TensorBoard. The writer must be closed so
			# buffered events are flushed and the file handle is released,
			# even when the selected mode raises or finishes quickly.
			writer = tf.summary.FileWriter("logs/", sess.graph)
			try:
				if FLAGS.trace:
					trace_runtime(sess, m_train)
				elif FLAGS.test:
					test(sess, m_valid)
				else:
					train(sess, m_train, m_valid)
			finally:
				writer.close()
Ejemplo n.º 3
0
def main(_):
  """Build the model selected by ``FLAGS.model`` and run the chosen mode.

  The positional argument is ignored (presumably the ``tf.app.run``
  entry-point convention -- TODO confirm).

  ``FLAGS.model`` must be ``'cnn'`` or ``'mtl'``; the saver is tagged with
  that model name. After initializing all global and local variables in a
  single session, exactly one mode runs:
  ``FLAGS.trace`` -> ``trace_runtime``; otherwise ``FLAGS.test`` -> ``test``;
  otherwise ``train``.

  Raises:
    ValueError: if ``FLAGS.model`` is not a supported model name
      (previously this surfaced as an opaque UnboundLocalError).
  """
  with tf.Graph().as_default():
    train_data, test_data, word_embed = base_reader.inputs()

    if FLAGS.model == 'cnn':
      m_train, m_valid = cnn_model.build_train_valid_model(
          word_embed, train_data, test_data)
    elif FLAGS.model == 'mtl':
      m_train, m_valid = mtl_model.build_train_valid_model(
          word_embed, train_data, test_data)
    else:
      # Without this branch m_train/m_valid would be unbound below.
      raise ValueError(
          "Unknown FLAGS.model %r; expected 'cnn' or 'mtl'" % (FLAGS.model,))

    m_train.set_saver(FLAGS.model)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())  # for file queue

    # Grow GPU memory on demand instead of reserving it all up front.
    # An alternative is capping usage at 90% of GPU memory via
    # gpu_options.per_process_gpu_memory_fraction = 0.9.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
      sess.run(init_op)
      print('='*80)

      if FLAGS.trace:
        trace_runtime(sess, m_train)
      elif FLAGS.test:
        test(sess, m_valid)
      else:
        train(sess, m_train, m_valid)