Пример #1
0
	def infer(sentences, **kwargs):
		"""Run SRL inference on a batch of pre-tokenized sentences.

		Args:
			sentences: iterable where each item's second element (sen[1]) is a
				list of word tokens, re-joined with spaces before evaluation.
				NOTE(review): item structure is inferred from the indexing
				below — confirm against the caller.
			**kwargs: unused; accepted for interface compatibility.

		Returns:
			Per-sentence results from model.evaluate_lines — a list of
			(prob, [labels]) tuples; see the explanatory string below.
		"""
		# Loaded for their side effects only; neither value is used below.
		config = load_config(FLAGS.config_file)
		logger = get_logger(FLAGS.log_file)
		
		# Rebuild plain-text sentences from the per-sentence token lists.
		reformed_sentences = [' '.join(sen[1]) for sen in sentences]
		# model / sess / char_to_id / id_to_tag come from the enclosing scope.
		result = model.evaluate_lines(sess, inputs_from_sentences(reformed_sentences, char_to_id, FLAGS.max_char_length), id_to_tag)
		'''
		result = [
			       (0.0, ['ARG0', 'ARG3', '-']),
				   (0.0, ['ARG0', 'ARG1', '-'])
				 ]
		# evaluate_lines 함수는 문장 단위 분석 결과를 내어줍니다.
		# len(result) : 문장의 갯수, 따라서 위 예제는 두 문장의 결과입니다.
		# result[0] : 첫번째 문장의 분석 결과, result[1] : 두번째 문장의 분석 결과.
		
		# 각 문장의 분석 결과는 다시 (prob, [labels])로 구성됩니다.
		# prob에 해당하는 자료는 이번 task에서 사용하지 않습니다. 따라서 그 값이 결과에 영향을 미치지 않습니다.
		# [labels]는 각 어절의 분석 결과를 담고 있습니다. 따라서 다음과 같이 구성됩니다.
		## ['첫번째 어절의 분석 결과', '두번째 어절의 분석 결과', ...]
		# 예를 들면 위 주어진 예제에서 첫번째 문장의 첫번째 어절은 'ARG0'을, 첫번째 문장의 두번째 어절은 'ARG3'을 argument label로 가집니다.

		### 주의사항 ###
		# 모든 어절의 결과를 제출하여야 합니다.
		# 어절의 순서가 지켜져야 합니다. (첫번째 어절부터 순서대로 list 내에 위치하여야 합니다.)
		'''
		return result
Пример #2
0
def evaluate_cli(model, context_embeddings_op, elmo_context, elmo_ids):
    """Tag a fixed set of demo sentences with the NER model and print results.

    Args:
        model: trained tagger exposing evaluate_lines().
        context_embeddings_op, elmo_context, elmo_ids: ELMo graph handles
            forwarded to evaluate_lines().
    """
    # Loaded for side effects / parity with the other CLIs; unused below.
    config = load_config(FLAGS.config_file)
    logger = get_logger(FLAGS.log_file)

    if FLAGS.task == "NER":
        with open(FLAGS.necessary, "rb") as f:
            word_to_id, id_to_word, char_to_id, id_to_char, pumsa_to_id, id_to_pumsa, tag_to_id, id_to_tag, ner_morph_tag = pickle.load(
                f)

    komoran = Komoran()
    results = []
    # NOTE(review): this loop never breaks, so `results` grows and is
    # re-printed forever — presumably leftover from the interactive input()
    # version commented out below. Confirm intended behavior.
    while True:
        # line = input("문장을 입력하세요.:")
        line = [
            "찬민이의 멘탈이 산산조각났습니다.", "진짜 진짜 진짜 맛있는 진짜 라면", "집에 가고 싶읍니다.", "집",
            "가 가 가 가 가 가, 가, 가 ,가, 가 가 가 가 가 가, 가, 가 ,가 !!!!! ."
        ]
        for idx in range(0, len(line), 5):
            # BUG FIX: the loop advances 5 sentences at a time, but the
            # original sliced only [idx:idx + 2], silently dropping
            # sentences 3-5 of every group. Slice the full stride.
            batch = line[idx:idx + 5]
            results.extend(
                model.evaluate_lines(
                    sess, context_embeddings_op, elmo_context, elmo_ids,
                    ner_morph_tag,
                    inputs_from_sentences(komoran, batch, word_to_id,
                                          pumsa_to_id, char_to_id, elmo_dict,
                                          FLAGS.max_char_length,
                                          ner_morph_tag), id_to_tag))
        print(results)
Пример #3
0
def evaluate_cli(model):
	"""Interactive loop: read a sentence from stdin and print its tags."""
	cfg = load_config(FLAGS.config_file)
	log = get_logger(FLAGS.log_file)
	# Restore the character/tag vocabularies persisted at training time.
	with open(FLAGS.map_file, "rb") as mapping_file:
		char_to_id, id_to_char, tag_to_id, id_to_tag = pickle.load(mapping_file)
	while True:
		sentence = input("문장을 입력하세요.:")
		model_inputs = inputs_from_sentences([sentence], char_to_id, FLAGS.max_char_length)
		print(model.evaluate_lines(sess, model_inputs, id_to_tag))
Пример #4
0
def evaluate_line():
    """Interactively tag sentences until the user enters 'q'.

    Loads the saved config and vocabulary mappings, restores the model into
    a fresh TensorFlow session, then loops reading one sentence per line
    from stdin and printing the model's prediction.
    """
    config = model_utils.load_config(FLAGS.config_file)
    logger = model_utils.get_logger(FLAGS.log_file)
    with open(FLAGS.map_file, 'rb') as map_f:
        word_to_id, id_to_word, tag_to_id, id_to_tag = pickle.load(map_f)
    # Let GPU memory grow on demand instead of grabbing it all up front.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    with tf.Session(config=session_config) as sess:
        model = model_utils.create(sess, Model, FLAGS.ckpt_path, load_word2vec,
                                   config, id_to_word, logger, FLAGS.train)
        while True:
            line = input('请输入测试句子(输入q退出):')
            if line == 'q':
                return
            prediction = model.evaluate_line(
                sess, data_utils.input_from_line(line, word_to_id), id_to_tag)
            print(prediction)
Пример #5
0
    def infer(sentences, **kwargs):
        """Run ELMo-augmented inference on a batch of tokenized sentences.

        Args:
            sentences: iterable where each item's second element (sen[1]) is
                a list of word tokens, re-joined with spaces.
                NOTE(review): item structure is inferred from the indexing
                below — confirm against the caller.
            **kwargs: unused; accepted for interface compatibility.

        Returns:
            list of per-sentence (prob, [labels]) tuples accumulated from
            model.evaluate_lines; see the explanatory string below.
        """
        config = load_config(FLAGS.config_file)
        logger = get_logger(FLAGS.log_file)
        # The ELMo cache is only loaded when the config enables ELMo.
        if config['elmo']:
            elmo_dict = load_elmo_dict(FLAGS.elmo_dict)
        else:
            elmo_dict = None

        results = []
        komoran = Komoran()
        # Rebuild plain-text sentences from the per-sentence token lists.
        reformed_sentences = [' '.join(sen[1]) for sen in sentences]
        # Sentences are evaluated in chunks of 100 per evaluate_lines call.
        for idx in range(0, len(reformed_sentences), 100):
            reformed_sentence = reformed_sentences[idx:idx + 100]
            results.extend(
                model.evaluate_lines(
                    sess, context_embeddings_op, elmo_context, elmo_ids,
                    ner_morph_tag,
                    inputs_from_sentences(komoran, reformed_sentence,
                                          word_to_id, pumsa_to_id, char_to_id,
                                          elmo_dict, FLAGS.max_char_length,
                                          ner_morph_tag), id_to_tag))
        '''
		result = [
			       (0.0, ['ARG0', 'ARG3', '-']),
				   (0.0, ['ARG0', 'ARG1', '-'])
				 ]
		# evaluate_lines 함수는 문장 단위 분석 결과를 내어줍니다.
		# len(result) : 문장의 갯수, 따라서 위 예제는 두 문장의 결과입니다.
		# result[0] : 첫번째 문장의 분석 결과, result[1] : 두번째 문장의 분석 결과.
		
		# 각 문장의 분석 결과는 다시 (prob, [labels])로 구성됩니다.
		# prob에 해당하는 자료는 이번 task에서 사용하지 않습니다. 에 영향을 미치지 않습니다.
		# [labels]는 각 어절의 분석 결과를 담고 있습니다. 따라서 다음과 같이 구성됩따라서 그 값이 결과니다.
		## ['첫번째 어절의 분석 결과', '두번째 어절의 분석 결과', ...]
		# 예를 들면 위 주어진 예제에서 첫번째 문장의 첫번째 어절은 'ARG0'을, 첫번째 문장의 두번째 어절은 'ARG3'을 argument label로 가집니다.

		### 주의사항 ###
		# 모든 어절의 결과를 제출하여야 합니다.
		# 어절의 순서가 지켜져야 합니다. (첫번째 어절부터 순서대로 list 내에 위치하여야 합니다.)
		''' ''
        return results
Пример #6
0
	def load(dir_path, *args):
		"""Restore vocabulary mappings and model weights from *dir_path*.

		Populates the module-level char_to_id / id_to_tag mappings from the
		pickled map file, then restores the latest TensorFlow checkpoint
		into the ambient session `sess`.

		Args:
			dir_path: directory containing the map file and checkpoints.
			*args: unused; accepted for interface compatibility.

		Raises:
			RuntimeError: if no checkpoint is found in dir_path.
		"""
		global char_to_id
		global id_to_tag

		# Loaded for side effects / parity with siblings; unused below.
		config = load_config(FLAGS.config_file)
		logger = get_logger(FLAGS.log_file)
		tf.get_variable_scope().reuse_variables()

		with open(os.path.join(dir_path, FLAGS.map_file), "rb") as f:
			char_to_id, _, __, id_to_tag = pickle.load(f)

		saver = tf.train.Saver()
		ckpt = tf.train.get_checkpoint_state(dir_path)
		if ckpt and ckpt.model_checkpoint_path:
			checkpoint = os.path.basename(ckpt.model_checkpoint_path)
			saver.restore(sess, os.path.join(dir_path, checkpoint))
		else:
			# BUG FIX: the original did `raise NotImplemented(...)` —
			# NotImplemented is a constant, not an exception class, so that
			# line itself raised TypeError. Raise a real exception instead.
			raise RuntimeError('No checkpoint found!')
		print('model loaded!')
Пример #7
0
def train():
    """Full NER training pipeline: load data, build vocab, train, evaluate.

    Side effects: writes the map file, config file, logs, and checkpoints
    at the paths configured via FLAGS.
    """
    # Load datasets.
    train_sentences = data_loader.load_sentences(FLAGS.train_file)
    dev_sentences = data_loader.load_sentences(FLAGS.dev_file)
    test_sentences = data_loader.load_sentences(FLAGS.test_file)

    # Convert the tag scheme (e.g. BIO -> BIOES).
    data_loader.update_tag_scheme(train_sentences, FLAGS.tag_schema)
    data_loader.update_tag_scheme(dev_sentences, FLAGS.tag_schema)
    data_loader.update_tag_scheme(test_sentences, FLAGS.tag_schema)

    # Build the word and tag mappings, or reload them if already persisted.
    if not os.path.isfile(FLAGS.map_file):
        if FLAGS.pre_emb:
            # Extend the training vocabulary with pretrained-embedding
            # words that appear in the test sentences.
            dico_words_train = data_loader.word_mapping(train_sentences)[0]
            dico_word, word_to_id, id_to_word = data_utils.augment_with_pretrained(
                dico_words_train.copy(), FLAGS.emb_file,
                list(
                    itertools.chain.from_iterable([[w[0] for w in s]
                                                   for s in test_sentences])))
        else:
            _, word_to_id, id_to_word = data_loader.word_mapping(
                train_sentences)
        _, tag_to_id, id_to_tag = data_loader.tag_mapping(train_sentences)
        with open(FLAGS.map_file, 'wb') as f:
            pickle.dump([word_to_id, id_to_word, tag_to_id, id_to_tag], f)
    else:
        with open(FLAGS.map_file, 'rb') as f:
            word_to_id, id_to_word, tag_to_id, id_to_tag = pickle.load(f)

    # Index the sentences into model-ready features.
    train_data = data_loader.prepare_dataset(train_sentences, word_to_id,
                                             tag_to_id)
    dev_data = data_loader.prepare_dataset(dev_sentences, word_to_id,
                                           tag_to_id)
    test_data = data_loader.prepare_dataset(test_sentences, word_to_id,
                                            tag_to_id)

    # Batch the data.
    train_manager = data_utils.BatchManager(train_data, FLAGS.batch_size)
    dev_manager = data_utils.BatchManager(dev_data, FLAGS.batch_size)
    test_manager = data_utils.BatchManager(test_data, FLAGS.batch_size)

    # Create any missing output directories.
    model_utils.make_path(FLAGS)

    # Load the existing config file, or build and persist a fresh one.
    if os.path.isfile(FLAGS.config_file):
        config = model_utils.load_config(FLAGS.config_file)
    else:
        config = model_utils.config_model(FLAGS, word_to_id, tag_to_id)
        model_utils.save_config(config, FLAGS.config_file)

    # Set up the logger.
    log_path = os.path.join('log', FLAGS.log_file)
    logger = model_utils.get_logger(log_path)
    model_utils.print_config(config, logger)

    # Allow GPU memory to grow on demand.
    tf_config = tf.ConfigProto(allow_soft_placement=True)
    tf_config.gpu_options.allow_growth = True

    step_per_epoch = train_manager.len_data
    with tf.Session(config=tf_config) as sess:
        model = model_utils.create(sess, Model, FLAGS.ckpt_path, load_word2vec,
                                   config, id_to_word, logger)
        logger.info('开始训练')
        loss = []
        start = time.time()
        # Fixed budget of 100 epochs.
        for i in range(100):
            for batch in train_manager.iter_batch(shuffle=True):
                step, batch_loss = model.run_step(sess, True, batch)
                loss.append(batch_loss)
                # NOTE(review): "setps_chech" reads like a typo for
                # "steps_check"; it must match the flag's definition.
                if step % FLAGS.setps_chech == 0:
                    iteration = step // step_per_epoch + 1
                    logger.info(
                        "iteration{}: step{}/{}, NER loss:{:>9.6f}".format(
                            iteration, step % step_per_epoch, step_per_epoch,
                            np.mean(loss)))
                    loss = []
            # Evaluate on dev each epoch; checkpoint when it's the best yet.
            best = evaluate(sess, model, 'dev', dev_manager, id_to_tag, logger)

            if best:
                model_utils.save_model(sess, model, FLAGS.ckpt_path, logger)
            evaluate(sess, model, 'test', test_manager, id_to_tag, logger)
        t = time.time() - start
        logger.info('cost time: %f' % t)
Пример #8
0
# NOTE(review): this fragment arrived with its indentation stripped — the
# `if`/`else` bodies and the `with` body were dedented to column 0, which is
# a SyntaxError. Reconstructed below with conventional indentation; the
# statements themselves are unchanged.

# Load the existing config file, or build and persist a fresh one.
if os.path.isfile(FLAGS.config_file):
    config = model_utils.load_config(FLAGS.config_file)
else:
    config = model_utils.config_model(FLAGS, word_to_id, tag_to_id)
    model_utils.save_config(config, FLAGS.config_file)

# Set up the logger.
log_path = os.path.join("log", FLAGS.log_file)
logger = model_utils.get_logger(log_path)
model_utils.print_config(config, logger)

# Allow GPU memory to grow on demand.
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True

steps_per_epoch = train_manager.len_data

with tf.Session(config=tf_config) as sess:
    model = model_utils.create(sess, Model, FLAGS.ckpt_path, load_word2vec,
                               config, id_to_word, logger)
def main():
    """Entry point: build a LanguageModel and train/eval/generate per FLAGS.

    FLAGS.mode selects the action:
      - "train":    epoch loop with checkpointing and optional early stopping.
      - "eval":     compute validation loss over the full dataset.
      - "generate": sample FLAGS.num_samples continuations of FLAGS.seed_text.
    """
    # Parse command-line flags.
    parser = model_utils.get_parser()
    FLAGS, unparsed = parser.parse_known_args()
    # Set up model_dir (named after the model), optionally wiping history.
    if FLAGS.model_name is None:
        model_name = "LanguageModel"
    else:
        model_name = FLAGS.model_name

    model_dir = os.path.abspath(FLAGS.base_dir) + '/{}/'.format(model_name)
    if not os.path.exists(model_dir):
        model_utils.setup_model_dir(model_dir, create_base=True)
    if FLAGS.no_restore:
        model_utils.remove_history(model_dir)
        model_utils.setup_model_dir(model_dir, create_base=False)
    # Start logging.
    logger = model_utils.get_logger(model_name, model_dir)
    logger.info("Started constructing {}".format(model_name))
    logger.info("Parsed args {}".format(FLAGS))
    if FLAGS.no_restore:
        logger.info('Not restoring, deleted history.')

    # Get the dataset and tokenizer.
    logger.info("Getting dataset {}".format(FLAGS.dataset_name))
    full_dataset, tokenizer, size = data.make_dataset(FLAGS.dataset_name,
                                                      FLAGS.dataset_type,
                                                      FLAGS.data_dir,
                                                      FLAGS.seq_length)
    # Create the model and optimizer.
    hparams = create_hparams(FLAGS.hparams)
    lm = LanguageModel(tokenizer.vocab_size, hparams.embedding_dim,
                       hparams.rnn_size, hparams.use_cudnn)
    optimizer = tf.train.AdamOptimizer(hparams.lr, hparams.beta1,
                                       hparams.beta2, hparams.epsilon)
    # BUG FIX: tf.Variable's second positional parameter is `trainable`, so
    # the original tf.Variable(1, 'epoch_count') passed the intended name
    # there (a non-empty string is truthy, matching the default True) and
    # left the variable unnamed. Pass name= explicitly.
    epoch_count = tf.Variable(1, name='epoch_count')
    global_step = tf.train.get_or_create_global_step()
    logger.info("Model created")
    # Checkpointing of optimizer, model, and counters.
    checkpoint_dir = os.path.abspath(model_dir + 'ckpts/' + FLAGS.run_name)
    logger.info("Checkpoints at {}".format(checkpoint_dir))
    checkpoint_prefix = checkpoint_dir + '/ckpt'
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     lm=lm,
                                     epoch_count=epoch_count,
                                     global_step=global_step)
    if not FLAGS.no_restore:
        if FLAGS.load_checkpoint is not None:
            load_checkpoint = FLAGS.load_checkpoint
        else:
            load_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
            logger.info("Loading latest checkpoint...")
        logger.info("Loading checkpoint {}".format(load_checkpoint))
        checkpoint.restore(load_checkpoint)

    # Create summary writer.
    summary_dir = model_dir + 'log/' + FLAGS.run_name + '/'
    summary_writer = tf.contrib.summary.create_file_writer(summary_dir,
                                                           flush_millis=1000)

    # Training
    if FLAGS.mode == "train":
        logger.info("Beginning training...")
        device = '/gpu:0' if not FLAGS.no_gpu else '/cpu:0'
        # Split, batch, and prefetch the training/validation datasets.
        logger.info("Full dataset size: {}".format(int(size)))
        logger.info("Train dataset size: {}".format(
            int(size * FLAGS.use_frac * FLAGS.train_frac)))
        train_dataset, valid_dataset = model_utils.split(
            full_dataset, size, FLAGS.use_frac, FLAGS.train_frac)
        train_dataset = train_dataset.batch(FLAGS.batch_size,
                                            drop_remainder=True)
        valid_dataset = valid_dataset.batch(FLAGS.batch_size,
                                            drop_remainder=True)
        train_dataset = (
            tf.data.experimental.prefetch_to_device(device)(train_dataset))
        valid_dataset = (
            tf.data.experimental.prefetch_to_device(device)(valid_dataset))
        # Train loop. (An unused train_losses accumulator was removed.)
        val_losses = []
        patience_count = 0
        for epoch in range(FLAGS.epochs):
            cur_epoch = epoch_count.numpy() + epoch
            logger.info("Starting epoch {}...".format(cur_epoch))
            start = time.time()
            with summary_writer.as_default():
                train_loss = lm.train(train_dataset, optimizer, global_step,
                                      FLAGS.log_interval)
                logger.info("Epoch {} complete: train loss = {:0.03f}".format(
                    cur_epoch, train_loss))
                logger.info("Validating...")
                val_loss = lm.evaluate(valid_dataset)
                logger.info("Validation loss = {:0.03f}".format(val_loss))
            time_elapsed = time.time() - start
            logger.info("Took {:0.01f} seconds".format(time_elapsed))
            # Checkpoint / early stopping.
            if FLAGS.early_stopping:
                # BUG FIX: the original never appended to val_losses, so
                # "not val_losses" was always true — every epoch
                # checkpointed and early stopping could never fire. Record
                # each loss and reset patience on improvement.
                improved = (not val_losses
                            or val_loss < min(val_losses) - FLAGS.es_delta)
                val_losses.append(val_loss)
                if improved:
                    logger.info("Checkpointing...")
                    checkpoint.save(checkpoint_prefix)
                    patience_count = 0
                elif patience_count + 1 > FLAGS.patience:
                    logger.info("Early stopping reached")
                    break
                else:
                    patience_count += 1
            else:
                logger.info("Checkpointing...")
                checkpoint.save(checkpoint_prefix)

    elif FLAGS.mode == "eval":
        logger.info("Beginning evaluation...")
        device = '/gpu:0' if not FLAGS.no_gpu else '/cpu:0'
        with summary_writer.as_default():
            val_loss = lm.evaluate(full_dataset)
            logger.info("Validation loss: {:0.02f}".format(val_loss))

    elif FLAGS.mode == "generate":
        # Sample token-by-token, feeding each sampled token back into the LM.
        logger.info("Generating samples...")
        for _ in range(FLAGS.num_samples):
            tokens = tokenizer.tokenize(FLAGS.seed_text)
            # NOTE(review): np.int16 caps token ids at 32767 — confirm the
            # vocabulary fits, otherwise ids overflow silently.
            inp = tf.constant(np.array(tokens, dtype=np.int16))
            inp = tf.expand_dims(inp, 0)
            _, state = lm.call_with_state(inp[:, 0:-1])  # Setup state
            cur_token = tokens[-1]
            done = False
            while not done:
                inp = tf.constant(np.array([cur_token], dtype=np.int16))
                inp = tf.expand_dims(inp, 0)
                logits, state = lm.call_with_state(inp, state)
                logits = tf.squeeze(logits, 0)
                # Temperature scaling before sampling.
                logits = logits / FLAGS.temperature
                cur_token = tf.multinomial(logits, num_samples=1)[-1,
                                                                  0].numpy()
                tokens.append(cur_token)
                if len(tokens) > FLAGS.sample_length:
                    done = True
            logger.info("{}".format(tokenizer.untokenize(tokens)))
            lm.recurrent.reset_states()
Пример #10
0
def train():
    """NER training pipeline: load data, build mappings, train, evaluate.

    Side effects: writes the map file, config file, logs, and checkpoints
    at the paths configured via FLAGS.
    """
    # Load datasets.
    train_sentences = data_loader.load_sentences(FLAGS.train_file)
    dev_sentences = data_loader.load_sentences(FLAGS.dev_file)
    test_sentences = data_loader.load_sentences(FLAGS.test_file)

    # Convert the tag scheme: BIO -> BIOES.
    data_loader.update_tag_scheme(train_sentences, FLAGS.tag_schema)
    data_loader.update_tag_scheme(test_sentences, FLAGS.tag_schema)
    data_loader.update_tag_scheme(dev_sentences, FLAGS.tag_schema)

    # Build the word and tag mappings, or reload them if already persisted.
    if not os.path.isfile(FLAGS.map_file):
        if FLAGS.pre_emb:
            # Extend the training vocabulary with pretrained-embedding
            # words that appear in the test sentences.
            dico_words_train = data_loader.word_mapping(train_sentences)[0]
            dico_word, word_to_id, id_to_word = data_utils.augment_with_pretrained(
                dico_words_train.copy(),
                FLAGS.emb_file,
                list(
                    itertools.chain.from_iterable(
                        [[w[0] for w in s] for s in test_sentences]
                    )
                )
            )
        else:
            _, word_to_id, id_to_word = data_loader.word_mapping(train_sentences)

        _, tag_to_id, id_to_tag = data_loader.tag_mapping(train_sentences)

        with open(FLAGS.map_file, "wb") as f:
            pickle.dump([word_to_id, id_to_word, tag_to_id, id_to_tag], f)
    else:
        with open(FLAGS.map_file, 'rb') as f:
            word_to_id, id_to_word, tag_to_id, id_to_tag = pickle.load(f)

    # Index the sentences into model-ready features.
    train_data = data_loader.prepare_dataset(
        train_sentences, word_to_id, tag_to_id
    )

    dev_data = data_loader.prepare_dataset(
        dev_sentences, word_to_id, tag_to_id
    )

    test_data = data_loader.prepare_dataset(
        test_sentences, word_to_id, tag_to_id
    )

    # Batch the data.
    train_manager = data_utils.BatchManager(train_data, FLAGS.batch_size)
    dev_manager = data_utils.BatchManager(dev_data, FLAGS.batch_size)
    test_manager = data_utils.BatchManager(test_data, FLAGS.batch_size)

    print('train_data_num %i, dev_data_num %i, test_data_num %i' % (len(train_data), len(dev_data), len(test_data)))

    # Create any missing output directories.
    model_utils.make_path(FLAGS)

    # Load the existing config file, or build and persist a fresh one.
    if os.path.isfile(FLAGS.config_file):
        config = model_utils.load_config(FLAGS.config_file)
    else:
        config = model_utils.config_model(FLAGS, word_to_id, tag_to_id)
        model_utils.save_config(config, FLAGS.config_file)

    # Set up the logger.
    log_path = os.path.join("log", FLAGS.log_file)
    logger = model_utils.get_logger(log_path)
    model_utils.print_config(config, logger)

    # Allow GPU memory to grow on demand.
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    steps_per_epoch =train_manager.len_data
    with tf.Session(config = tf_config) as sess:
        model = model_utils.create(sess, Model, FLAGS.ckpt_path, load_word2vec, config, id_to_word, logger)
        logger.info("开始训练")
        loss = []
        # Fixed budget of 100 epochs.
        for i in range(100):
            for batch in train_manager.iter_batch(shuffle=True):
                step, batch_loss = model.run_step(sess, True, batch)
                loss.append(batch_loss)
                # NOTE(review): "setps_chech" reads like a typo for
                # "steps_check" (and "iterstion" for "iteration"); the flag
                # name must match wherever FLAGS is defined.
                if step % FLAGS.setps_chech== 0:
                    iterstion = step // steps_per_epoch + 1
                    logger.info("iteration:{} step{}/{},NER loss:{:>9.6f}".format(iterstion, step%steps_per_epoch, steps_per_epoch, np.mean(loss)))
                    loss = []

            # Evaluate on dev each epoch; checkpoint when it's the best yet.
            best = evaluate(sess,model,"dev", dev_manager, id_to_tag, logger)

            if best:
                model_utils.save_model(sess, model, FLAGS.ckpt_path, logger)
            evaluate(sess, model, "test", test_manager, id_to_tag, logger)
Пример #11
0
	trans = SRL_Model.trans.eval(session=sess)

	for _, batch in enumerate(unlabeled_manager.iter_batch(shuffle=False)):
		sentence = batch[-1]
		lengths, scores = SRL_Model.run_step(sess, ELMo_context, ELMo_ids, False, batch)
		batch_paths = SRL_Model.decode(scores, lengths, trans)
		for i in range(len(sentence)):
			string = sentence[i] + "|||"
			pred = [idx2label[int(x)] for x in batch_paths[i][:lengths[i]]]
			sentences.append(string + " ".join(pred))

	write_tag(sentences)

if __name__ == "__main__":
	log_path = os.path.join("log", config.log_file)
	logger = get_logger(log_path)

	tf_config = tf.ConfigProto()
	tf_config.gpu_options.allow_growth = True
	sess = tf.Session(config=tf_config)

	#load or make vocab
	if os.path.isfile(config.necessary):
		with open(config.necessary, 'rb') as f:
			word2idx, pumsa2idx, lemma2idx, char2idx, label2idx, idx2label = pickle.load(f)
	else:
		word2idx, pumsa2idx, lemma2idx, char2idx, label2idx, idx2label = get_necessary()

	if config.pretrained_embeddings:
		word_embedding_matrix = load_word_embedding_matrix(word2idx)
	else:
Пример #12
0
def main():
    # Parse
    parser = model_utils.get_parser()
    FLAGS, unparsed = parser.parse_known_args()
    # Setup model_dir
    if FLAGS.model_name is None:
        model_name = "UniRNN"
    else:
        model_name = FLAGS.model_name

    model_dir = os.path.abspath(FLAGS.base_dir) + '/{}/'.format(model_name)
    if not os.path.exists(model_dir):
        model_utils.setup_model_dir(model_dir, create_base=True)
    if FLAGS.no_restore:
        model_utils.remove_history(model_dir)
        model_utils.setup_model_dir(model_dir, create_base=False)
    # Start logging
    logger = model_utils.get_logger(model_name, model_dir)
    logger.info("Started constructing {}".format(model_name))
    logger.info("Parsed args {}".format(FLAGS))
    if FLAGS.no_restore:
        logger.info('Not restoring, deleted history.')

    # Get Dataset
    logger.info("Getting dataset {}".format(FLAGS.dataset_name))
    with open('/home/gray/code/seqgan-opinion-spam/data/happydb/happytok.pkl',
              'rb') as f:
        tokenizer = pickle.load(f)
    #full_dataset, tokenizer, size = data.make_dataset(FLAGS.dataset_name,
    #                                                  FLAGS.dataset_type,
    #                                                  FLAGS.data_dir,
    #                                                  FLAGS.seq_length)
    # Create model
    hparams = create_hparams(FLAGS.hparams)
    model = UniRNN(tokenizer.vocab_size, hparams.embedding_dim,
                   hparams.rnn_size, FLAGS.n_classes, hparams.use_vat,
                   hparams.adv_size, hparams.use_cudnn)
    model_optimizer = tf.train.AdamOptimizer(hparams.lr, hparams.beta1,
                                             hparams.beta2, hparams.epsilon)
    epoch_count = tf.Variable(1, 'epoch_count')
    global_step = tf.train.get_or_create_global_step()
    logger.info("Model created")
    # Create checkpointing
    checkpoint_dir = os.path.abspath(model_dir + 'ckpts/' + FLAGS.run_name)
    logger.info("Checkpoints at {}".format(checkpoint_dir))
    checkpoint_prefix = checkpoint_dir + '/ckpt'
    checkpoint = tf.train.Checkpoint(model_optimizer=model_optimizer,
                                     model=model,
                                     epoch_count=epoch_count,
                                     global_step=global_step)
    if not FLAGS.no_restore:
        if not FLAGS.load_checkpoint is None:
            load_checkpoint = FLAGS.load_checkpoint
        else:
            load_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
            logger.info("Loading latest checkpoint...")
        logger.info("Loading checkpoint {}".format(load_checkpoint))
        checkpoint.restore(load_checkpoint)

    # Upload pretrained embeddings and RNN if using
    lm = LanguageModel.LanguageModel(tokenizer.vocab_size,
                                     hparams.embedding_dim, hparams.rnn_size,
                                     hparams.use_cudnn)
    if not FLAGS.pretrained_lm_dir is None:
        logger.info('Loading pretrained LM at {}'.format(
            FLAGS.pretrained_lm_dir))
        # TODO Tensorflow is parsing doubled slashes...
        pretrain_ckpt_path = os.path.normpath(FLAGS.pretrained_lm_dir)
        print(pretrain_ckpt_path)
        lm_checkpoint = tf.train.Checkpoint(lm=lm)
        path = tf.train.latest_checkpoint(pretrain_ckpt_path)
        checkpoint.restore(path)
        logger.info('Loaded {}'.format(path))
        model.embedding.set_weights(lm.embedding.get_weights())
        model.recurrent.set_weights(lm.recurrent.get_weights())

    # Create summary writer
    summary_dir = model_dir + 'log/' + FLAGS.run_name + '/'
    summary_writer = tf.contrib.summary.create_file_writer(summary_dir,
                                                           flush_millis=1000)

    # Training
    if FLAGS.mode == "train":
        logger.info("Beginning training...")
        device = '/gpu:0' if not FLAGS.no_gpu else '/cpu:0'
        # Get training Dataset
        #logger.info("Full dataset size: {}".format(int(size)))
        #logger.info("Train dataset size: {}".format(
        #    int(size * FLAGS.use_frac * FLAGS.train_frac)))
        #TODO fix this
        trainX = np.load(
            '/home/gray/code/seqgan-opinion-spam/data/happydb/moments_tr_data.npy'
        )
        trainY = np.load(
            '/home/gray/code/seqgan-opinion-spam/data/happydb/moments_tr_labels.npy'
        )
        trainY = np.expand_dims(trainY, 1)
        valX = np.load(
            '/home/gray/code/seqgan-opinion-spam/data/happydb/moments_val_data.npy'
        )
        valY = np.load(
            '/home/gray/code/seqgan-opinion-spam/data/happydb/moments_val_labels.npy'
        )
        valY = np.expand_dims(valY, 1)
        train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))
        valid_dataset = tf.data.Dataset.from_tensor_slices((valX, valY))

        #train_dataset, valid_dataset = model_utils.split(full_dataset,
        #                                                 size,
        #                                                 FLAGS.use_frac,
        #                                                 FLAGS.train_frac)
        train_dataset = train_dataset.batch(FLAGS.batch_size,
                                            drop_remainder=True)
        valid_dataset = valid_dataset.batch(FLAGS.batch_size,
                                            drop_remainder=True)
        #train_dataset = (tf.data.experimental
        #                    .prefetch_to_device(device)(train_dataset))
        #valid_dataset = (tf.data.experimental
        #                    .prefetch_to_device(device)(valid_dataset))
        # Train loop
        train_losses = []
        val_losses = []
        min_val_loss = 1e8
        patience_count = 0
        for epoch in range(FLAGS.epochs):
            cur_epoch = epoch_count.numpy() + epoch
            logger.info("Starting epoch {}...".format(cur_epoch))
            start = time.time()
            with summary_writer.as_default():
                train_loss = model.train(train_dataset, model_optimizer,
                                         global_step, FLAGS.log_interval)
                logger.info("Epoch {} complete: train loss = {:0.03f}".format(
                    cur_epoch, train_loss))
                logger.info("Validating...")
                val_loss = model.evaluate(valid_dataset)
                logger.info("Validation loss = {:0.03f}".format(val_loss))
            time_elapsed = time.time() - start
            logger.info("Took {:0.01f} seconds".format(time_elapsed))
            # Checkpoint
            if FLAGS.early_stopping:
                if val_loss < (min_val_loss - FLAGS.es_delta):
                    logger.info("Checkpointing...")
                    checkpoint.save(checkpoint_prefix)
                    patience_count = 0
                    min_val_loss = val_loss
                elif patience_count + 1 > FLAGS.patience:
                    logger.info("Early stopping reached")
                    break
                else:
                    patience_count += 1
            else:
                logger.info("Checkpointing...")
                checkpoint.save(checkpoint_prefix)

    elif FLAGS.mode == "eval":
        logger.info("Beginning evaluation...")
        device = '/gpu:0' if not FLAGS.no_gpu else '/cpu:0'
        with summary_writer.as_default():
            val_loss = model.evaluate(full_dataset)
            logger.info("Validation loss: {:0.02f}".format(val_loss))