Example 1
def run_epoch(session, m, conll_words, ptb_words, pos, ptb_pos, chunk, ptb_chunk,
              pos_vocab_size, chunk_vocab_size, vocab_size, num_steps, config,
              ptb_batches, ptb_iter, verbose=False, valid=False, model_type='JOINT'):
    """Runs the model on the given data."""
    # =====================================
    # Initialise variables
    # =====================================
    conll_epoch_size = (len(conll_words) // (m.batch_size * num_steps)) + 1
    ptb_epoch_size = (len(ptb_words) // (m.batch_size * m.num_steps)) + 1
    epoch_stats = {
        "comb_loss": 0.0,
        "pos_total_loss": 0.0,
        "chunk_total_loss": 0.0,
        "lm_total_loss": 0.0,
        "iters": 0,
        "accuracy": 0.0,
        "pos_predictions": [],
        "pos_true": [],
        "chunk_predictions": [],
        "chunk_true": [],
        "lm_predictions": [],
        "lm_true": []
    }

    print('creating batches')
    conll_batches = reader.create_batches(conll_words, pos, chunk, m.batch_size,
                                          m.num_steps, pos_vocab_size, chunk_vocab_size,
                                          vocab_size, continuing=True)

    conll_iter = 0

    # =======================================================
    # Define the train batch method
    # -------------------------------------------------------
    # We're then going to use this method to train a batch of
    # each data type in turn
    # ======================================================

    def train_batch(batch, eval_op, model_type, epoch_stats, config, stop_write=False, validation=False):
        (x, y_pos, y_chunk, y_lm) = batch

        if validation:
            joint_loss, pos_int_pred, chunk_int_pred, lm_int_pred, pos_int_true, \
                chunk_int_true, lm_int_true, pos_loss, chunk_loss, lm_loss = \
                session.run([m.joint_loss, m.pos_int_pred,
                             m.chunk_int_pred, m.lm_int_pred, m.pos_int_targ, m.chunk_int_targ,
                             m.lm_int_targ, m.pos_loss, m.chunk_loss, m.lm_loss],
                            {m.input_data: x,
                             m.pos_targets: y_pos,
                             m.chunk_targets: y_chunk,
                             m.lm_targets: y_lm,
                             })

        else:
            joint_loss, _, pos_int_pred, chunk_int_pred, lm_int_pred, pos_int_true, \
                chunk_int_true, lm_int_true, pos_loss, chunk_loss, lm_loss = \
                session.run([m.joint_loss, eval_op, m.pos_int_pred,
                             m.chunk_int_pred, m.lm_int_pred, m.pos_int_targ, m.chunk_int_targ,
                             m.lm_int_targ, m.pos_loss, m.chunk_loss, m.lm_loss],
                            {m.input_data: x,
                             m.pos_targets: y_pos,
                             m.chunk_targets: y_chunk,
                             m.lm_targets: y_lm,
                             })

        epoch_stats["comb_loss"] += joint_loss
        epoch_stats["chunk_total_loss"] += chunk_loss
        epoch_stats["pos_total_loss"] += pos_loss
        epoch_stats["lm_total_loss"] += lm_loss
        epoch_stats["iters"] += 1

        if verbose and (epoch_stats["iters"] % 20 == 0):
            # report the loss for whichever task this batch trained
            if model_type == 'POS':
                cost = pos_loss
            elif model_type == 'CHUNK':
                cost = chunk_loss
            elif model_type == 'LM':
                cost = lm_loss
            else:
                cost = joint_loss
            print("Type: %s, cost: %.3f, step: %d" % (model_type, cost, epoch_stats['iters']))


        if model_type != "LM" and not stop_write:
            pos_int_pred = np.reshape(pos_int_pred, [m.batch_size, m.num_steps])
            pos_int_true = np.reshape(pos_int_true, [m.batch_size, m.num_steps])
            epoch_stats["pos_predictions"].append(pos_int_pred)
            epoch_stats["pos_true"].append(pos_int_true)

            chunk_int_pred = np.reshape(chunk_int_pred, [m.batch_size, m.num_steps])
            chunk_int_true = np.reshape(chunk_int_true, [m.batch_size, m.num_steps])
            epoch_stats["chunk_predictions"].append(chunk_int_pred)
            epoch_stats["chunk_true"].append(chunk_int_true)

            lm_int_pred = np.reshape(lm_int_pred, [m.batch_size, m.num_steps])
            lm_int_true = np.reshape(lm_int_true, [m.batch_size, m.num_steps])
            epoch_stats["lm_predictions"].append(lm_int_pred)
            epoch_stats["lm_true"].append(lm_int_true)

        return epoch_stats

    # ==========================================================
    # Do the epoch
    # ----------------------------------------------------------
    # randomly choose a dataset, and then increment your counter
    # ==========================================================

    if valid:
        eval_op = tf.no_op()
        for i in range(conll_epoch_size):
            epoch_stats = train_batch(next(conll_batches), eval_op, "JOINT",
                                      epoch_stats, config, stop_write=False,
                                      validation=True)
    else:
        print('ptb epoch size: ' + str(ptb_epoch_size))
        print('conll epoch size: ' + str(conll_epoch_size))
        if m.mix_percent < 1:
            while conll_iter < conll_epoch_size:
                if np.random.rand(1) < m.mix_percent:
                    eval_op = m.joint_op
                    epoch_stats = train_batch(next(conll_batches), eval_op, "JOINT",
                                              epoch_stats, config,
                                              stop_write=(conll_iter > conll_epoch_size))
                    conll_iter += 1
                else:
                    eval_op = m.lm_op
                    # language-model-only step on a PTB batch (no predictions kept)
                    epoch_stats = train_batch(next(ptb_batches), eval_op, "LM",
                                              epoch_stats, config)
                    ptb_iter = (ptb_iter + 1) % ptb_epoch_size
        else:
            while conll_iter < conll_epoch_size:
                eval_op = m.joint_op
                epoch_stats = train_batch(next(conll_batches), eval_op, "JOINT",
                                          epoch_stats, config,
                                          stop_write=(conll_iter > conll_epoch_size))
                conll_iter += 1

    return (epoch_stats["comb_loss"] / epoch_stats["iters"]), \
        epoch_stats["pos_predictions"], epoch_stats["chunk_predictions"], \
        epoch_stats["lm_predictions"], epoch_stats["pos_true"], \
        epoch_stats["chunk_true"], epoch_stats["lm_true"], \
        (epoch_stats["pos_total_loss"] / epoch_stats["iters"]), \
        (epoch_stats["chunk_total_loss"] / epoch_stats["iters"]), \
        (epoch_stats["lm_total_loss"] / epoch_stats["iters"]), ptb_iter
Example 2
def run_epoch(session,
              m,
              words,
              pos,
              chunk,
              pos_vocab_size,
              chunk_vocab_size,
              vocab_size,
              num_steps,
              verbose=False,
              valid=False,
              model_type='JOINT'):
    """Runs the model on the given data."""
    epoch_size = ((len(words) // m.batch_size) + 1)
    start_time = time.time()
    comb_loss = 0.0
    pos_total_loss = 0.0
    chunk_total_loss = 0.0
    lm_total_loss = 0.0
    iters = 0
    accuracy = 0.0
    pos_predictions = []
    pos_true = []
    chunk_predictions = []
    chunk_true = []
    lm_predictions = []
    lm_true = []

    for step, (x, y_pos, y_chunk, y_lm) in enumerate(
            reader.create_batches(words, pos, chunk, m.batch_size, m.num_steps,
                                  pos_vocab_size, chunk_vocab_size,
                                  vocab_size)):

        # choose the op to run: the task-specific training op, or a no-op
        # when only evaluating
        if valid:
            eval_op = tf.no_op()
        else:
            eval_op = {'POS': m.pos_op,
                       'CHUNK': m.chunk_op,
                       'LM': m.lm_op}.get(model_type, m.joint_op)


        joint_loss, _, pos_int_pred, chunk_int_pred, lm_int_pred, pos_int_true, \
            chunk_int_true, lm_int_true, pos_loss, chunk_loss, lm_loss = \
            session.run([m.joint_loss, eval_op, m.pos_int_pred,
                         m.chunk_int_pred, m.lm_int_pred, m.pos_int_targ, m.chunk_int_targ,
                         m.lm_int_targ, m.pos_loss, m.chunk_loss, m.lm_loss],
                        {m.input_data: x,
                         m.pos_targets: y_pos,
                         m.chunk_targets: y_chunk,
                         m.lm_targets: y_lm,
                         m.gold_embed: 0})

        comb_loss += joint_loss
        chunk_total_loss += chunk_loss
        pos_total_loss += pos_loss
        lm_total_loss += lm_loss
        iters += 1

        if verbose and step % 10 == 0:
            # report the loss for whichever task is being trained
            if model_type == 'POS':
                cost = pos_loss
            elif model_type == 'CHUNK':
                cost = chunk_loss
            elif model_type == 'LM':
                cost = lm_loss
            else:
                cost = joint_loss
            print("Type: %s, cost: %.3f, step: %d" % (model_type, cost, step))

        pos_int_pred = np.reshape(pos_int_pred, [m.batch_size, m.num_steps])
        pos_int_true = np.reshape(pos_int_true, [m.batch_size, m.num_steps])
        pos_predictions.append(pos_int_pred)
        pos_true.append(pos_int_true)

        chunk_int_pred = np.reshape(chunk_int_pred,
                                    [m.batch_size, m.num_steps])
        chunk_int_true = np.reshape(chunk_int_true,
                                    [m.batch_size, m.num_steps])
        chunk_predictions.append(chunk_int_pred)
        chunk_true.append(chunk_int_true)

        lm_int_pred = np.reshape(lm_int_pred, [m.batch_size, m.num_steps])
        lm_int_true = np.reshape(lm_int_true, [m.batch_size, m.num_steps])
        lm_predictions.append(lm_int_pred)
        lm_true.append(lm_int_true)

    return (comb_loss / iters), pos_predictions, chunk_predictions, lm_predictions, \
        pos_true, chunk_true, lm_true, (pos_total_loss / iters), \
        (chunk_total_loss / iters), (lm_total_loss / iters)
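
Note that run_epoch returns its predictions as lists of [batch_size, num_steps] integer matrices, one per batch, which the caller later flattens back to token level (via reader._res_to_list in Example 3). A small illustrative sketch of that flattening, using made-up data rather than the real reader module:

import numpy as np

batch_size, num_steps = 2, 3
# stand-in for run_epoch's pos_predictions: one matrix per batch
pos_predictions = [np.random.randint(0, 45, (batch_size, num_steps))
                   for _ in range(4)]
# flatten the batch-major matrices into one vector of per-token tag ids
flat = np.concatenate([p.reshape(-1) for p in pos_predictions])
print(flat.shape)  # (24,) == 4 batches * 2 rows * 3 steps
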
Example 3
def main(model_type, dataset_path, ptb_path, save_path,
         num_steps, encoder_size, pos_decoder_size, chunk_decoder_size, dropout,
         batch_size, pos_embedding_size, num_shared_layers, num_private_layers,
         chunk_embedding_size, lm_decoder_size, bidirectional, lstm, write_to_file,
         mix_percent, glove_path, max_epoch, projection_size, num_batches_gold,
         reg_weight, word_embedding_size, embedding_trainable, adam, connections,
         fraction_of_training_data=1, embedding=False, test=False):
    """Main training/evaluation entry point."""
    config = Config(num_steps, encoder_size, pos_decoder_size, chunk_decoder_size,
                    dropout, batch_size, pos_embedding_size, num_shared_layers,
                    num_private_layers, chunk_embedding_size, lm_decoder_size,
                    bidirectional, lstm, mix_percent, max_epoch, reg_weight,
                    word_embedding_size, embedding_trainable, adam,
                    fraction_of_training_data, connections)

    raw_data_path = dataset_path + '/data'
    raw_data = reader.raw_x_y_data(
        raw_data_path, num_steps, ptb_path + '/data', embedding, glove_path)

    words_t, pos_t, chunk_t, words_v, \
        pos_v, chunk_v, word_to_id, pos_to_id, \
        chunk_to_id, words_test, pos_test, chunk_test, \
        words_c, pos_c, chunk_c, words_ptb, pos_ptb, chunk_ptb, word_embedding = raw_data

    num_train_examples = int(np.floor(len(words_t) * fraction_of_training_data))

    words_t = words_t[:num_train_examples]
    pos_t = pos_t[:num_train_examples]
    chunk_t = chunk_t[:num_train_examples]

    num_pos_tags = len(pos_to_id)
    num_chunk_tags = len(chunk_to_id)
    vocab_size = len(word_to_id)
    prev_chunk_F1 = 0.0

    ptb_batches = reader.create_batches(words_ptb, pos_ptb, chunk_ptb, config.batch_size,
                                        config.num_steps, num_pos_tags, num_chunk_tags,
                                        vocab_size, continuing=True)

    ptb_iter = 0

    # Create an empty array to hold [epoch number, F1]
    if not test:
        best_chunk_epoch = [0, 0.0]
        best_pos_epoch = [0, 0.0]
    else:
        best_chunk_epoch = [max_epoch, 0.0]

    print('constructing word embedding')

    if embedding:
        word_embedding = np.float32(word_embedding)
    else:
        word_embedding = np.float32(
            (np.random.rand(vocab_size, config.word_embedding_size) - 0.5) * config.init_scale)

    if not test:
        with tf.Graph().as_default(), tf.Session() as session:
            print('building models')
            initializer = tf.random_uniform_initializer(-config.init_scale,
                                                        config.init_scale)

            # model to train hyperparameters on
            with tf.variable_scope("hyp_model", reuse=None, initializer=initializer):
                m = Shared_Model(is_training=True, config=config, num_pos_tags=num_pos_tags,
                num_chunk_tags=num_chunk_tags, vocab_size=vocab_size,
                word_embedding=word_embedding, projection_size=projection_size)

            with tf.variable_scope("hyp_model", reuse=True, initializer=initializer):
                mValid = Shared_Model(is_training=False, config=config, num_pos_tags=num_pos_tags,
                num_chunk_tags=num_chunk_tags, vocab_size=vocab_size,
                word_embedding=word_embedding, projection_size=projection_size)


            print('initialising variables')

            tf.initialize_all_variables().run()

            print("initialise word vectors")
            session.run(m.embedding_init, {m.embedding_placeholder: word_embedding})
            session.run(mValid.embedding_init, {mValid.embedding_placeholder: word_embedding})

            print('finding best epoch parameter')
            # ====================================
            # Create vectors for training results
            # ====================================

            # Create empty vectors for loss
            train_loss_stats = np.array([])
            train_pos_loss_stats = np.array([])
            train_chunk_loss_stats = np.array([])
            train_lm_loss_stats = np.array([])

            # Create empty vectors for accuracy
            train_pos_stats = np.array([])
            train_chunk_stats = np.array([])

            # ====================================
            # Create vectors for validation results
            # ====================================
            # Create empty vectors for loss
            valid_loss_stats = np.array([])
            valid_pos_loss_stats = np.array([])
            valid_chunk_loss_stats = np.array([])
            valid_lm_loss_stats = np.array([])

            # Create empty vectors for accuracy
            valid_pos_stats = np.array([])
            valid_chunk_stats = np.array([])

            # gold_embed and gold_percent get defaults here because the
            # validation pass below references them even when random_mix is off
            gold_embed = 1
            gold_percent = 1.0

            for i in range(config.max_epoch):
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))

                print("Epoch: %d" % (i + 1))
                if not config.random_mix:
                    if config.ptb:
                        _, _, _, _, _, _, _, _, _, _ = \
                            run_epoch(session, m,
                                      words_ptb, pos_ptb, chunk_ptb,
                                      num_pos_tags, num_chunk_tags, vocab_size, num_steps,
                                      verbose=True, model_type='LM')


                    mean_loss, posp_t, chunkp_t, lmp_t, post_t, chunkt_t, lmt_t, pos_loss, chunk_loss, lm_loss = \
                        run_epoch(session, m,
                                  words_t, pos_t, chunk_t,
                                  num_pos_tags, num_chunk_tags, vocab_size, num_steps,
                                  verbose=True, model_type=model_type)

                else:
                    # anneal the chance of feeding gold (rather than predicted)
                    # connections once the warm-up epochs have passed
                    if i > num_batches_gold:
                        gold_percent = gold_percent * 0.8
                    else:
                        gold_percent = 1
                    gold_embed = 1 if np.random.rand(1) < gold_percent else 0
                    mean_loss, posp_t, chunkp_t, lmp_t, post_t, chunkt_t, lmt_t, pos_loss, chunk_loss, lm_loss, ptb_iter = \
                        run_epoch_random.run_epoch(session, m,
                                  words_t, words_ptb, pos_t, pos_ptb, chunk_t, chunk_ptb,
                                  num_pos_tags, num_chunk_tags, vocab_size, num_steps, gold_embed, config,
                                  ptb_batches, ptb_iter, verbose=True, model_type=model_type)


                print('epoch finished')
                # Save stats for charts
                train_loss_stats = np.append(train_loss_stats, mean_loss)
                train_pos_loss_stats = np.append(train_pos_loss_stats, pos_loss)
                train_chunk_loss_stats = np.append(train_chunk_loss_stats, chunk_loss)
                train_lm_loss_stats = np.append(train_lm_loss_stats, lm_loss)

                # get training predictions as list
                posp_t = reader._res_to_list(posp_t, config.batch_size, num_steps,
                                             pos_to_id, len(words_t), to_str=True)
                chunkp_t = reader._res_to_list(chunkp_t, config.batch_size, num_steps,
                                               chunk_to_id, len(words_t), to_str=True)
                lmp_t = reader._res_to_list(lmp_t, config.batch_size, num_steps,
                                            word_to_id, len(words_t), to_str=True)
                post_t = reader._res_to_list(post_t, config.batch_size, num_steps,
                                             pos_to_id, len(words_t), to_str=True)
                chunkt_t = reader._res_to_list(chunkt_t, config.batch_size, num_steps,
                                               chunk_to_id, len(words_t), to_str=True)
                lmt_t = reader._res_to_list(lmt_t, config.batch_size, num_steps,
                                            word_to_id, len(words_t), to_str=True)

                # find the accuracy
                print('finding accuracy')
                pos_acc = np.sum(posp_t == post_t) / float(len(posp_t))
                chunk_F1 = f1_score(chunkt_t, chunkp_t, average="weighted")

                # add to array
                train_pos_stats = np.append(train_pos_stats, pos_acc)
                train_chunk_stats = np.append(train_chunk_stats, chunk_F1)

                # print for tracking
                print("Pos Training Accuracy After Epoch %d :  %3f" % (i+1, pos_acc))
                print("Chunk Training F1 After Epoch %d : %3f" % (i+1, chunk_F1))

                valid_loss, posp_v, chunkp_v, lmp_v, post_v, chunkt_v, lmt_v, pos_v_loss, chunk_v_loss, lm_v_loss, ptb_iter = \
                    run_epoch_random.run_epoch(session, mValid,
                              words_v, words_ptb, pos_v, pos_ptb, chunk_v, chunk_ptb,
                              num_pos_tags, num_chunk_tags, vocab_size, num_steps, gold_embed, config,
                              ptb_batches, ptb_iter, verbose=True,  model_type=model_type, valid=True)

                # Save loss for charts
                valid_loss_stats = np.append(valid_loss_stats, valid_loss)
                valid_pos_loss_stats = np.append(valid_pos_loss_stats, pos_v_loss)
                valid_chunk_loss_stats = np.append(valid_chunk_loss_stats, chunk_v_loss)
                valid_lm_loss_stats = np.append(valid_lm_loss_stats, lm_v_loss)

                # get predictions as list
                posp_v = reader._res_to_list(posp_v, config.batch_size, num_steps,
                                             pos_to_id, len(words_v), to_str=True)
                chunkp_v = reader._res_to_list(chunkp_v, config.batch_size, num_steps,
                                                chunk_to_id, len(words_v), to_str=True)
                lmp_v = reader._res_to_list(lmp_v, config.batch_size, num_steps,
                                                word_to_id, len(words_v), to_str=True)
                chunkt_v = reader._res_to_list(chunkt_v, config.batch_size, num_steps,
                                                chunk_to_id, len(words_v), to_str=True)
                post_v = reader._res_to_list(post_v, config.batch_size, num_steps,
                                             pos_to_id, len(words_v), to_str=True)
                lmt_v = reader._res_to_list(lmt_v, config.batch_size, num_steps,
                                                word_to_id, len(words_v), to_str=True)

                # find accuracy
                pos_acc = np.sum(posp_v == post_v) / float(len(posp_v))
                chunk_F1 = f1_score(chunkt_v, chunkp_v, average="weighted")


                print("Pos Validation Accuracy After Epoch %d :  %3f" % (i+1, pos_acc))
                print("Chunk Validation F1 After Epoch %d : %3f" % (i+1, chunk_F1))

                # add to stats
                valid_pos_stats = np.append(valid_pos_stats, pos_acc)
                valid_chunk_stats = np.append(valid_chunk_stats, chunk_F1)

                if abs(chunk_F1 - prev_chunk_F1) <= 0.001:
                    config.learning_rate = 0.8 * config.learning_rate
                    print("learning rate updated")

                # update best parameters
                if (chunk_F1 > best_chunk_epoch[1]) or (pos_acc > best_pos_epoch[1]):
                    if pos_acc > best_pos_epoch[1]:
                        best_pos_epoch = [i+1, pos_acc]
                    if chunk_F1 > best_chunk_epoch[1]:
                        best_chunk_epoch = [i+1, chunk_F1]

                    saveload.save(save_path + '/val_model.pkl', session)
                    with open(save_path + '/pos_to_id.pkl', "wb") as file:
                        pickle.dump(pos_to_id, file)
                    with open(save_path + '/chunk_to_id.pkl', "wb") as file:
                        pickle.dump(chunk_to_id, file)
                    print("Model saved in file: %s" % save_path)

                    if not write_to_file:
                        id_to_word = {v: k for k, v in word_to_id.items()}

                        words_t_unrolled = [id_to_word[k] for k in words_t[num_steps-1:]]
                        words_v_unrolled = [id_to_word[k] for k in words_v[num_steps-1:]]

                        # unroll data
                        train_custom = np.hstack((np.array(words_t_unrolled).reshape(-1,1), np.char.upper(post_t), np.char.upper(chunkt_t)))
                        valid_custom = np.hstack((np.array(words_v_unrolled).reshape(-1,1), np.char.upper(post_v), np.char.upper(chunkt_v)))
                        chunk_pred_train = np.concatenate((train_custom, np.char.upper(chunkp_t).reshape(-1,1)), axis=1)
                        chunk_pred_val = np.concatenate((valid_custom, np.char.upper(chunkp_v).reshape(-1,1)), axis=1)
                        pos_pred_train = np.concatenate((train_custom, np.char.upper(posp_t).reshape(-1,1)), axis=1)
                        pos_pred_val = np.concatenate((valid_custom, np.char.upper(posp_v).reshape(-1,1)), axis=1)

                        # write to file
                        np.savetxt(save_path + '/predictions/chunk_pred_train.txt',
                                   chunk_pred_train, fmt='%s')
                        print('writing to ' + save_path + '/predictions/chunk_pred_train.txt')
                        np.savetxt(save_path + '/predictions/chunk_pred_val.txt',
                                   chunk_pred_val, fmt='%s')
                        print('writing to ' + save_path + '/predictions/chunk_pred_val.txt')
                        np.savetxt(save_path + '/predictions/pos_pred_train.txt',
                                   pos_pred_train, fmt='%s')
                        print('writing to ' + save_path + '/predictions/pos_pred_train.txt')
                        np.savetxt(save_path + '/predictions/pos_pred_val.txt',
                                   pos_pred_val, fmt='%s')
                        print('writing to ' + save_path + '/predictions/pos_pred_val.txt')

                        print('Getting Testing Predictions (Valid)')
                        test_loss, posp_test, chunkp_test, lmp_test, post_test, chunkt_test, lmt_test, pos_test_loss, chunk_test_loss, lm_test_loss, ptb_iter = \
                            run_epoch_random.run_epoch(session, mValid,
                                      words_test, words_ptb, pos_test, pos_ptb, chunk_test, chunk_ptb,
                                      num_pos_tags, num_chunk_tags, vocab_size, num_steps, gold_embed, config,
                                      ptb_batches, ptb_iter, verbose=True,  model_type=model_type, valid=True)

                        # get predictions as list
                        posp_test = reader._res_to_list(posp_test, config.batch_size, num_steps,
                                                     pos_to_id, len(words_test), to_str=True)
                        chunkp_test = reader._res_to_list(chunkp_test, config.batch_size, num_steps,
                                                        chunk_to_id, len(words_test), to_str=True)
                        lmp_test = reader._res_to_list(lmp_test, config.batch_size, num_steps,
                                                        word_to_id, len(words_test), to_str=True)
                        chunkt_test = reader._res_to_list(chunkt_test, config.batch_size, num_steps,
                                                        chunk_to_id, len(words_test), to_str=True)
                        post_test = reader._res_to_list(post_test, config.batch_size, num_steps,
                                                     pos_to_id, len(words_test), to_str=True)
                        lmt_test = reader._res_to_list(lmt_test, config.batch_size, num_steps,
                                                        word_to_id, len(words_test), to_str=True)

                        words_test_c = [id_to_word[k] for k in words_test[num_steps-1:]]
                        test_data = np.hstack((np.array(words_test_c).reshape(-1,1), np.char.upper(post_test), np.char.upper(chunkt_test)))

                        # find the accuracy on the test set at this checkpoint
                        print('finding test accuracy')
                        pos_acc_train = np.sum(posp_test == post_test) / float(len(posp_test))
                        chunk_F1_train = f1_score(chunkt_test, chunkp_test, average="weighted")

                        print("POS Test Accuracy: " + str(pos_acc_train))
                        print("Chunk Test F1: " + str(chunk_F1_train))

                        chunk_pred_test = np.concatenate((test_data, np.char.upper(chunkp_test).reshape(-1,1)), axis=1)
                        pos_pred_test = np.concatenate((test_data, np.char.upper(posp_test).reshape(-1,1)), axis=1)

                        np.savetxt(save_path + '/predictions/chunk_pred_test.txt',
                                   chunk_pred_test, fmt='%s')
                        print('writing to ' + save_path + '/predictions/chunk_pred_test.txt')

                        np.savetxt(save_path + '/predictions/pos_pred_test.txt',
                                   pos_pred_test, fmt='%s')
                        print('writing to ' + save_path + '/predictions/pos_pred_test.txt')

                prev_chunk_F1 = chunk_F1

            # Save loss & accuracy plots
            np.savetxt(save_path + '/loss/valid_loss_stats.txt', valid_loss_stats)
            np.savetxt(save_path + '/loss/valid_pos_loss_stats.txt', valid_pos_loss_stats)
            np.savetxt(save_path + '/loss/valid_chunk_loss_stats.txt', valid_chunk_loss_stats)
            np.savetxt(save_path + '/accuracy/valid_pos_stats.txt', valid_pos_stats)
            np.savetxt(save_path + '/accuracy/valid_chunk_stats.txt', valid_chunk_stats)

            np.savetxt(save_path + '/loss/train_loss_stats.txt', train_loss_stats)
            np.savetxt(save_path + '/loss/train_pos_loss_stats.txt', train_pos_loss_stats)
            np.savetxt(save_path + '/loss/train_chunk_loss_stats.txt', train_chunk_loss_stats)
            np.savetxt(save_path + '/accuracy/train_pos_stats.txt', train_pos_stats)
            np.savetxt(save_path + '/accuracy/train_chunk_stats.txt', train_chunk_stats)


    if write_to_file:
            with tf.Graph().as_default(), tf.Session() as session:
                initializer = tf.random_uniform_initializer(-config.init_scale,
                                                            config.init_scale)

                with tf.variable_scope("final_model", reuse=None, initializer=initializer):
                    mTrain = Shared_Model(is_training=True, config=config, num_pos_tags=num_pos_tags,
                    num_chunk_tags=num_chunk_tags, vocab_size=vocab_size,
                    word_embedding=word_embedding, projection_size=projection_size)

                with tf.variable_scope("final_model", reuse=True, initializer=initializer):
                    mTest = Shared_Model(is_training=False, config=config, num_pos_tags=num_pos_tags,
                    num_chunk_tags=num_chunk_tags, vocab_size=vocab_size,
                    word_embedding=word_embedding, projection_size=projection_size)

                print("initialise variables")
                tf.initialize_all_variables().run()
                print("initialise word embeddings")
                session.run(mTrain.embedding_init, {mTrain.embedding_placeholder: word_embedding})
                session.run(mTest.embedding_init, {mTest.embedding_placeholder: word_embedding})




                # Train given epoch parameter
                if not config.random_mix:
                    print('Train Given Best Epoch Parameter :' + str(best_chunk_epoch[0]))
                    for i in range(best_chunk_epoch[0]):
                        print("Epoch: %d" % (i + 1))
                        if not config.ptb:
                            _, _, _, _, _, _, _, _, _, _ = \
                                run_epoch(session, mTrain,
                                          words_ptb, pos_ptb, chunk_ptb,
                                          num_pos_tags, num_chunk_tags, vocab_size, num_steps,
                                          verbose=True, model_type="LM")

                        _, posp_c, chunkp_c, _, _, _, _, _, _, _ = \
                            run_epoch(session, mTrain,
                                      words_c, pos_c, chunk_c,
                                      num_pos_tags, num_chunk_tags, vocab_size, num_steps,
                                      verbose=True, model_type=model_type)

                else:
                    print('Train Given Best Epoch Parameter :' + str(best_chunk_epoch[0]))
                    for i in range(best_chunk_epoch[0]):
                        print("Epoch: %d" % (i + 1))
                        # anneal gold vs predicted connections each epoch,
                        # mirroring the validation training loop above
                        if i > num_batches_gold:
                            gold_percent = gold_percent * 0.8
                        else:
                            gold_percent = 1
                        gold_embed = 1 if np.random.rand(1) < gold_percent else 0
                        _, posp_c, chunkp_c, _, post_c, chunkt_c, _, _, _, _, ptb_iter = \
                            run_epoch_random.run_epoch(session, mTrain,
                                      words_c, words_ptb, pos_c, pos_ptb, chunk_c, chunk_ptb,
                                      num_pos_tags, num_chunk_tags, vocab_size, num_steps, gold_embed, config,
                                      ptb_batches, ptb_iter, verbose=True, model_type=model_type)


                print('Getting Testing Predictions')
                test_loss, posp_test, chunkp_test, lmp_test, post_test, chunkt_test, lmt_test, pos_test_loss, chunk_test_loss, lm_test_loss, ptb_iter = \
                    run_epoch_random.run_epoch(session, mTest,
                              words_test, words_ptb, pos_test, pos_ptb, chunk_test, chunk_ptb,
                              num_pos_tags, num_chunk_tags, vocab_size, num_steps, gold_embed, config,
                              ptb_batches, ptb_iter, verbose=True,  model_type=model_type, valid=True)

                print('Writing Predictions')
                # prediction reshaping
                posp_c = reader._res_to_list(posp_c, config.batch_size, num_steps,
                                             pos_to_id, len(words_c), to_str=True)
                posp_test = reader._res_to_list(posp_test, config.batch_size, num_steps,
                                                pos_to_id, len(words_test), to_str=True)
                chunkp_c = reader._res_to_list(chunkp_c, config.batch_size, num_steps,
                                               chunk_to_id, len(words_c), to_str=True)
                chunkp_test = reader._res_to_list(chunkp_test, config.batch_size, num_steps,
                                                  chunk_to_id, len(words_test), to_str=True)

                post_c = reader._res_to_list(post_c, config.batch_size, num_steps,
                                             pos_to_id, len(words_c), to_str=True)
                post_test = reader._res_to_list(post_test, config.batch_size, num_steps,
                                                pos_to_id, len(words_test), to_str=True)
                chunkt_c = reader._res_to_list(chunkt_c, config.batch_size, num_steps,
                                               chunk_to_id, len(words_c), to_str=True)
                chunkt_test = reader._res_to_list(chunkt_test, config.batch_size, num_steps,
                                                  chunk_to_id, len(words_test), to_str=True)

                # save pickle - save_path + '/saved_variables.pkl'
                print('saving checkpoint')
                saveload.save(save_path + '/fin_model.ckpt', session)

                # invert the vocabulary so ids can be written out as words
                id_to_word = {v: k for k, v in word_to_id.items()}

                words_t = [id_to_word[k] for k in words_t[num_steps-1:]]
                words_v = [id_to_word[k] for k in words_v[num_steps-1:]]
                words_c = [id_to_word[k] for k in words_c[num_steps-1:]]
                words_test = [id_to_word[k] for k in words_test[num_steps-1:]]

                # find the accuracy
                print('finding test accuracy')
                pos_acc = np.sum(posp_test == post_test) / float(len(posp_test))
                chunk_F1 = f1_score(chunkt_test, chunkp_test, average="weighted")

                print("POS Test Accuracy (Both): " + str(pos_acc))
                print("Chunk Test F1(Both): " + str(chunk_F1))

                print("POS Test Accuracy (Train): " + str(pos_acc_train))
                print("Chunk Test F1 (Train): " + str(chunk_F1_train))


                if not test:
                    train_custom = np.hstack((np.array(words_t).reshape(-1,1), np.char.upper(post_t), np.char.upper(chunkt_t)))
                    valid_custom = np.hstack((np.array(words_v).reshape(-1,1), np.char.upper(post_v), np.char.upper(chunkt_v)))
                combined = np.hstack((np.array(words_c).reshape(-1,1), np.char.upper(post_c), np.char.upper(chunkt_c)))
                test_data = np.hstack((np.array(words_test).reshape(-1,1), np.char.upper(post_test), np.char.upper(chunkt_test)))

                print('loaded text')

                if not test:
                    chunk_pred_train = np.concatenate((train_custom, np.char.upper(chunkp_t).reshape(-1,1)), axis=1)
                    chunk_pred_val = np.concatenate((valid_custom, np.char.upper(chunkp_v).reshape(-1,1)), axis=1)
                chunk_pred_c = np.concatenate((combined, np.char.upper(chunkp_c).reshape(-1,1)), axis=1)
                chunk_pred_test = np.concatenate((test_data, np.char.upper(chunkp_test).reshape(-1,1)), axis=1)
                if not test:
                    pos_pred_train = np.concatenate((train_custom, np.char.upper(posp_t).reshape(-1,1)), axis=1)
                    pos_pred_val = np.concatenate((valid_custom, np.char.upper(posp_v).reshape(-1,1)), axis=1)
                pos_pred_c = np.concatenate((combined, np.char.upper(posp_c).reshape(-1,1)), axis=1)
                pos_pred_test = np.concatenate((test_data, np.char.upper(posp_test).reshape(-1,1)), axis=1)

                print('finished concatenating, about to start saving')

                if not test:
                    np.savetxt(save_path + '/predictions/chunk_pred_train.txt',
                               chunk_pred_train, fmt='%s')
                    print('writing to ' + save_path + '/predictions/chunk_pred_train.txt')
                    np.savetxt(save_path + '/predictions/chunk_pred_val.txt',
                               chunk_pred_val, fmt='%s')
                    print('writing to ' + save_path + '/predictions/chunk_pred_val.txt')

                np.savetxt(save_path + '/predictions/chunk_pred_combined.txt',
                           chunk_pred_c, fmt='%s')
                print('writing to ' + save_path + '/predictions/chunk_pred_combined.txt')
                np.savetxt(save_path + '/predictions/chunk_pred_test.txt',
                           chunk_pred_test, fmt='%s')
                print('writing to ' + save_path + '/predictions/chunk_pred_test.txt')

                if test == False:
                    np.savetxt(save_path + '/predictions/pos_pred_train.txt',
                               pos_pred_train, fmt='%s')
                    print('writing to ' + save_path + '/predictions/pos_pred_train.txt')
                    np.savetxt(save_path + '/predictions/pos_pred_val.txt',
                               pos_pred_val, fmt='%s')
                    print('writing to ' + save_path + '/predictions/pos_pred_val.txt')

                np.savetxt(save_path + '/predictions/pos_pred_combined.txt',
                           pos_pred_c, fmt='%s')
                np.savetxt(save_path + '/predictions/pos_pred_test.txt',
                           pos_pred_test, fmt='%s')

    else:
        print('Best Validation F1: ' + str(best_chunk_epoch[1]))
        print('Best Validation Epoch: ' + str(best_chunk_epoch[0]))
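
A closing note on the gold_embed / gold_percent logic in main: it is a scheduled-sampling style anneal, feeding gold connections with probability 1 for the first num_batches_gold epochs and decaying that probability by a factor of 0.8 per epoch afterwards. The sketch below isolates the schedule; gold_schedule is an illustrative name, not a function from the code above.

import numpy as np

def gold_schedule(max_epoch, num_batches_gold, decay=0.8, rng=np.random):
    """Yield gold_embed per epoch: 1 = gold connections, 0 = predicted."""
    gold_percent = 1.0
    for i in range(max_epoch):
        if i > num_batches_gold:
            gold_percent *= decay  # geometric decay after the warm-up epochs
        yield int(rng.rand() < gold_percent)

# e.g. list(gold_schedule(10, 3)) is all 1s for the first few epochs,
# then increasingly likely to contain 0s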