Example #1
    def __init__(self, args, sess):
        self.env_act = Environment(args, 'act')
        # self.net_act = DeepQLearner(args, 'act', 'channels_first')
        self.net_act = DeepQLearner(args, sess, 'act')  # for tensorflow
        self.env_arg = Environment(args, 'arg')
        # self.net_arg = DeepQLearner(args, 'arg', 'channels_first')
        self.net_arg = DeepQLearner(args, sess, 'arg')  # for tensorflow
        self.num_words = args.num_words
        self.context_len = args.context_len
        self.gamma = args.gamma
        self.uncertainty_mode = 'cml'  # or 'diff'
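
For context, a minimal construction sketch (not part of the original example): it assumes args comes from an argparse setup like the one in Example #2, and it mirrors the TF1 session pattern used in Examples #4 and #6.

import tensorflow as tf

# Hypothetical driver: build one shared TF1 session and hand it to Agent.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    agent = Agent(args, sess)  # builds the act/arg env-net pairs internally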
Example #2
def EADQN_main(table, num, weights_dir):  # formerly (actionDBs, num)
    import argparse
    import sys
    import time
    import tensorflow as tf

    from Environment import Environment
    from ReplayMemory import ReplayMemory
    from EADQN import DeepQLearner
    from Agent import Agent
    parser = argparse.ArgumentParser()

    envarg = parser.add_argument_group('Environment')
    envarg.add_argument("--model_dir",
                        default="/home/fengwf/Documents/",
                        help="")
    envarg.add_argument("--vec_model", default='mymodel5-5-50', help="")
    envarg.add_argument("--vec_length", type=int, default=50, help="")
    envarg.add_argument("--actionDB", default='tag_actions', help="")
    envarg.add_argument("--max_text_num", default='64', help="")
    envarg.add_argument("--reward_assign",
                        default='2.0 1.0 -1.0 -2.0',
                        help="")
    envarg.add_argument("--action_rate", type=float, default=0.15, help="")
    envarg.add_argument("--penal_radix", type=float, default=5.0, help="")
    envarg.add_argument("--action_label", type=int, default=2, help="")
    envarg.add_argument("--non_action_label", type=int, default=1, help="")
    envarg.add_argument("--long_text_flag", type=int, default=1, help="")

    memarg = parser.add_argument_group('Replay memory')
    memarg.add_argument("--replay_size", type=int, default=100000, help="")
    memarg.add_argument("--channel", type=int, default=1, help="")
    memarg.add_argument("--positive_rate", type=float, default=0.75, help="")
    memarg.add_argument("--priority", default=1, help="")
    memarg.add_argument("--reward_bound", type=float, default=0, help="")

    netarg = parser.add_argument_group('Deep Q-learning network')
    netarg.add_argument("--num_actions", type=int, default=1000, help="")
    netarg.add_argument("--words_num", type=int, default=500, help="")
    netarg.add_argument("--wordvec", type=int, default=100, help="")
    netarg.add_argument("--learning_rate", type=float, default=0.0025, help="")
    netarg.add_argument("--momentum", type=float, default=0.1, help="")
    netarg.add_argument("--epsilon", type=float, default=1e-6, help="")
    netarg.add_argument("--decay_rate", type=float, default=0.88, help="")
    netarg.add_argument("--discount_rate", type=float, default=0.9, help="")
    netarg.add_argument("--batch_size", type=int, default=8, help="")
    netarg.add_argument("--target_output", type=int, default=2, help="")

    antarg = parser.add_argument_group('Agent')
    antarg.add_argument("--exploration_rate_start",
                        type=float,
                        default=1,
                        help="")
    antarg.add_argument("--exploration_rate_end",
                        type=float,
                        default=0.1,
                        help="")
    antarg.add_argument("--exploration_decay_steps",
                        type=int,
                        default=1000,
                        help="")
    antarg.add_argument("--exploration_rate_test",
                        type=float,
                        default=0.0,
                        help="")
    antarg.add_argument("--train_frequency", type=int, default=1, help="")
    antarg.add_argument("--train_repeat", type=int, default=1, help="")
    antarg.add_argument("--target_steps", type=int, default=5, help="")
    antarg.add_argument("--random_play", default=0, help="")

    mainarg = parser.add_argument_group('Main loop')
    mainarg.add_argument("--result_dir", default="test_result", help="")
    mainarg.add_argument("--train_steps", type=int, default=0, help="")
    mainarg.add_argument("--test_one", type=int, default=1, help="")
    mainarg.add_argument("--text_dir", default='', help="")
    mainarg.add_argument("--test", type=int, default=1, help="")
    mainarg.add_argument("--test_text_num", type=int, default=8, help="")
    mainarg.add_argument("--epochs", type=int, default=2, help="")
    mainarg.add_argument("--start_epoch", type=int, default=0, help="")
    mainarg.add_argument("--home_dir", default="./", help="")
    mainarg.add_argument("--load_weights", default="", help="")
    mainarg.add_argument("--save_weights_prefix", default="", help="")
    mainarg.add_argument("--computer_id", type=int, default=1, help="")
    mainarg.add_argument("--gpu_rate", type=float, default=0.2, help="")
    mainarg.add_argument("--cnn_format", default='NCHW', help="")

    args = parser.parse_args()
    tables_num = len(args.actionDB.split())
    args.load_weights = weights_dir
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_rate)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        net = DeepQLearner(args, sess)
        env = Environment(args)
        mem = ReplayMemory(args.replay_size, args)
        agent = Agent(env, mem, net, args)
        words = []
        states = []

        if args.load_weights:
            print 'Loading weights from %s...' % args.load_weights
            net.load_weights(args.home_dir + args.load_weights)  # load last trained weights

        if args.test_one and args.load_weights:
            '''
            for i,ad in enumerate(actionDBs):
                tmp_w = []
                tmp_s = []
                for j in range(num[i]):
                    print 'table = %s,  text_num = %d'%(actionDBs[i],j)
                    ws, act_seq, st = agent.test_one_db(actionDBs[i], j)
                    tmp_w.append(ws)
                    tmp_s.append(st)
                    #print '\nStates: %s\n'%str(st)
                    #print '\nWords: %s\n'%str(ws)
                    #print '\n\nAction_squence: %s\n'%str(act_seq)
                words.append(tmp_w)
                states.append(tmp_s)
            '''
            tmp_w = []
            tmp_s = []
            for j in range(num):
                #print 'table = %s,  text_num = %d'%(table,j)
                ws, act_seq, st = agent.test_one_db(table, j)
                tmp_w.append(ws)
                tmp_s.append(st)
            words = tmp_w
            states = tmp_s
            print 'len(words) = %d,  len(states) = %d' % (len(words),
                                                          len(states))
        return words, states
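
A hypothetical invocation sketch; the table name, text count and weights path below are placeholders, not values taken from the original code.

if __name__ == '__main__':
    # EADQN_main returns, per text, the tagged words and their states.
    words, states = EADQN_main('tag_actions', 8, 'weights/example.prm')
    print 'tagged %d texts' % len(words)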
Example #3
def main(args):
    print('Current time is: %s' % get_time())
    print('Starting at main...')
    result = {'rec': [], 'pre': [], 'f1': [], 'rw': []}

    start = time.time()

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    set_session(tf.compat.v1.Session(config=config))  # global Keras session

    env_act = Environment(args, args.agent_mode)
    net_act = DeepQLearner(args, args.agent_mode, 'channels_last')
    mem_act = ReplayMemory(args, args.agent_mode)
    agent = Agent(env_act, mem_act, net_act, args)  # agent takes the environment, memory, model and args

    # loop over epochs
    epoch_result = {'rec': [0.0], 'pre': [0.0], 'f1': [0.0], 'rw': [0.0]}
    training_result = {'rec': [], 'pre': [], 'f1': [], 'loss': [], 'rw': []}
    test_result = {'rec': [], 'pre': [], 'f1': [], 'loss': [], 'rw': []}
    log_epoch = 0


    # If we are loading weights, we skip training (no exploration schedule is needed;
    # exploration rate start = end = 0.1) and only evaluate on the test set.
    if args.load_weights:
        print('Loading weights ...')
        filename = 'weights/%s_%s_%s.h5' % (args.domain, args.agent_mode, args.contextual_embedding)
        net_act.load_weights(filename)
        #accuracy on test set
        with open("%s.txt" % (args.result_dir + 'testset'), 'w') as outfile:
            rec, pre, f1, rw = agent.test(args.test_steps, outfile, test_flag=True)
            outfile.write('\n\n Test f1 value: {}, recall : {}, precision : {}, reward: {} \n'.format(f1, rec,pre,rw ))
            print('\n\n Test f1 value: {}, recall : {}, precision : {}, reward: {} \n'.format(f1, rec,pre,rw ))

    if not args.load_weights:
        with open("%s.txt" % (args.result_dir), 'w') as outfile:
            print('\n Arguments:')
            outfile.write('\n Arguments:\n')
            for k, v in sorted(args.__dict__.items(), key=lambda x: x[0]):
                print('{}: {}'.format(k, v))
                outfile.write('{}: {}\n'.format(k, v))
            print('\n')
            outfile.write('\n')

            # do training

            for epoch in tqdm(range(args.start_epoch, args.start_epoch + args.epochs)):
                num_test = -1
                env_act.train_epoch_end_flag = False
                while not env_act.train_epoch_end_flag:  # until all training documents are covered
                    # training
                    num_test += 1
                    restart_init = (num_test == 0)
                    # args.train_episodes (50 by default) caps the episodes per training call
                    tmp_result = agent.train(args.train_steps, args.train_episodes, restart_init)
                    for k in training_result:
                        training_result[k].extend(tmp_result[k])

                    rec, pre, f1, rw = agent.test(args.valid_steps, outfile) # not testing; actually validation

                    if f1 > max(epoch_result['f1']):
                        if args.save_weights:
                            filename = 'weights/%s_%s_%s.h5' % (args.domain, args.agent_mode, args.contextual_embedding)
                            net_act.save_weights(filename)

                        epoch_result['f1'].append(f1)
                        epoch_result['rec'].append(rec)
                        epoch_result['pre'].append(pre)
                        epoch_result['rw'].append(rw)
                        log_epoch = epoch
                        outfile.write('\n\n Best f1 value: {}  best epoch: {}\n'.format(epoch_result, log_epoch))
                        print('\n\n Best f1 value: {}  best epoch: {}\n'.format(epoch_result, log_epoch))

                # if no improvement after args.stop_epoch_gap, break
                # EARLY STOPPING
                if epoch - log_epoch >= args.stop_epoch_gap:
                    outfile.write('\n\nBest f1 value: {}  best epoch: {}\n'.format(epoch_result, log_epoch))
                    print('\nepoch: %d  result_dir: %s' % (epoch, args.result_dir))
                    print('-----Early stopping, no improvement after %d epochs-----\n' % args.stop_epoch_gap)
                    break

            # if args.save_replay: #0 by default
            #     mem_act.save(args.save_replay_name, args.save_replay_size)

            filename = '%s_training_process.pdf' % (args.result_dir)
            plot_results(epoch_result, args.domain, filename)
            outfile.write('\n\n training process:\n{}\n\n'.format(epoch_result))

            best_ind = epoch_result['f1'].index(max(epoch_result['f1']))
            for k in epoch_result:
                result[k].append(epoch_result[k][best_ind])
                outfile.write('{}: {}\n'.format(k, result[k]))
                print(('{}: {}\n'.format(k, result[k])))
            avg_f1 = sum(result['f1']) / len(result['f1'])
            avg_rw = sum(result['rw']) / len(result['rw'])
            outfile.write('\nAvg f1: {}  Avg reward: {}\n'.format(avg_f1, avg_rw))
            print('\nAvg f1: {}  Avg reward: {}\n'.format(avg_f1, avg_rw))

            tf.compat.v1.reset_default_graph()
        end = time.time()
        print('Total time cost: %ds' % (end - start))
        print('Current time is: %s\n' % get_time())
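
A minimal entry-point sketch for this main(); it lists only the flags that main() itself reads, and every default below is an illustrative placeholder (Environment, ReplayMemory and DeepQLearner will need additional fields not shown here).

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--agent_mode', default='act')
    parser.add_argument('--domain', default='cooking')
    parser.add_argument('--contextual_embedding', default='glove')
    parser.add_argument('--load_weights', type=int, default=0)
    parser.add_argument('--save_weights', type=int, default=1)
    parser.add_argument('--result_dir', default='results/demo')
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--stop_epoch_gap', type=int, default=5)
    parser.add_argument('--train_steps', type=int, default=0)
    parser.add_argument('--train_episodes', type=int, default=50)
    parser.add_argument('--valid_steps', type=int, default=0)
    parser.add_argument('--test_steps', type=int, default=0)
    main(parser.parse_args())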
Example #4
localtime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print 'Current time is:', localtime
print 'Starting at main.py...'

# use for investigating the influence of tag length
'''
f = open(args.home_dir + args.result_dir + "_train.txt",'w')
f1 = open(args.home_dir + args.result_dir + "_test.txt",'w')
f.write(str(args)+'\n')
f.write('\nCurrent time is: %s'%localtime)
f.write('\nStarting at main.py...')
'''
# Initialize environment, replay memory, deep Q-net and agent
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_rate)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    net = DeepQLearner(args, sess)
    env = Environment(args)

    temp_size = env.train_steps * args.epochs + env.test_steps
    if temp_size > 100000:
        temp_size = 100000
    args.replay_size = temp_size
    args.train_steps = env.train_steps
    assert args.replay_size > 0

    mem = ReplayMemory(args.replay_size, args)
    agent = Agent(env, mem, net, args)

    #argsDict = args.__dict__
    #for eachArg in argsDict.keys():
    #    print eachArg, argsDict[eachArg]
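
The replay-size computation above caps the memory at 100000 transitions; a minimal equivalent helper for clarity (the cap is the value hard-coded above, the sample numbers in the comment are made up).

def capped_replay_size(train_steps, epochs, test_steps, cap=100000):
    # e.g. 30000 train steps * 4 epochs + 5000 test steps = 125000, capped to 100000
    return min(train_steps * epochs + test_steps, cap)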
Example #5
class Agent(object):
    """
    RL Agent for online Active Learning
    """
    def __init__(self, args, sess):
        self.env_act = Environment(args, 'act')
        # self.net_act = DeepQLearner(args, 'act', 'channels_first')
        self.net_act = DeepQLearner(args, sess, 'act')  # for tensorflow
        self.env_arg = Environment(args, 'arg')
        # self.net_arg = DeepQLearner(args, 'arg', 'channels_first')
        self.net_arg = DeepQLearner(args, sess, 'arg')  # for tensorflow
        self.num_words = args.num_words
        self.context_len = args.context_len
        self.gamma = args.gamma
        self.uncertainty_mode = 'cml'  # or 'diff'

    def load_data(self):
        """
        Load all unlabeled texts.
        PS: the file 'home_and_garden_500_words_with_title.pkl' contains more than 15k 
            unlabeled texts from wikihow home and garden category.
        """
        print('Loading texts from data/home_and_garden_500_words_with_title.pkl ...')
        self.texts = load_pkl('data/home_and_garden_500_words_with_title.pkl')
        self.label2text = text_classification()
        self.history_texts = []
        self.sort_ind = 1 if self.uncertainty_mode == 'diff' else 0
        self.category = 0  # category of the currently chosen text
        self.max_category = len(self.label2text) - 1

    def choose_unlabeled_texts(self, num_texts, dialog=None):
        """
        Apply Active Learning. 
        Choose texts from each class and sort them by cumulative reward.
        """
        chosen_texts = []
        while len(chosen_texts) < num_texts:
            # text_ind = np.random.randint(len(self.texts))
            # text = self.texts[text_ind]
            text_ind = random.sample(self.label2text[self.category], 1)[0]
            if text_ind in self.history_texts:  # or len(text['title'].split()) < 2:
                continue

            # print('textID: {:<10}  category: {}'.format(text_ind, self.category))
            # traverse all categories, choose texts from each category
            self.category = self.category + 1 if self.category < self.max_category else 0
            # predict Q-values, compute cumulative reward
            text = self.texts[text_ind]
            sents, word2sent, R_t = self.predict(text['sent'])
            # deduced from R_t = r_t + gamma * R_{t+1}
            r_t = R_t[:-1] - self.gamma * R_t[1:]
            cml_rwd = sum(r_t) + self.gamma * R_t[-1]
            # difference between predicted and real cml_rwd
            delta_r = abs(R_t[0] - cml_rwd)
            text['sents'] = sents
            text['reward'] = (cml_rwd, delta_r)
            text['r_t'] = r_t  # len(r_t) = len(words) - 1
            text['text_ind'] = text_ind
            text['word2sent'] = word2sent
            chosen_texts.append(text)
            if dialog:
                dialog.Update(
                    len(chosen_texts),
                    'Progress: %d/%d' % (len(chosen_texts), num_texts))
        # sort the texts by cumulative reward
        sorted_texts = sorted(chosen_texts,
                              key=lambda x: x['reward'][self.sort_ind])
        # for t in sorted_texts:
        #     print(t['text_ind'], t['reward'][self.sort_ind])
        # print('\n')
        return sorted_texts

    def predict(self, text):
        """
        Call EASDRL model to generate output actions for an input text
        e.g. text = ['Cook the rice the day before.', 'Use leftover rice.']
        """
        self.env_act.init_predict_act_text(text)
        # act_seq = []
        sents = []
        for i in range(len(self.env_act.current_text['sents'])):
            if i > 0:
                last_sent = self.env_act.current_text['sents'][i - 1]
                # last_pos = self.env_act.current_text['sent_pos'][i - 1]
            else:
                last_sent = []
                # last_pos = []
            this_sent = self.env_act.current_text['sents'][i]
            # this_pos = self.env_act.current_text['sent_pos'][i]
            sents.append({
                'last_sent': last_sent,
                'this_sent': this_sent,
                'acts': []
                # 'last_pos': last_pos, 'this_pos': this_pos
            })
        word2sent = self.env_act.current_text['word2sent']
        # ipdb.set_trace()
        R_t = []
        for i in range(self.num_words):
            state_act = self.env_act.getState()
            qvalues_act = self.net_act.predict(state_act)
            R_t.append(max(qvalues_act[0]))
            action_act = np.argmax(qvalues_act[0])
            self.env_act.act_online(action_act, i)
            if action_act == 1:
                last_sent, this_sent = self.env_arg.init_predict_arg_text(
                    i, self.env_act.current_text)
                for j in range(self.context_len):
                    state_arg = self.env_arg.getState()
                    qvalues_arg = self.net_arg.predict(state_arg)
                    action_arg = np.argmax(qvalues_arg[0])
                    self.env_arg.act_online(action_arg, j)
                    if self.env_arg.terminal_flag:
                        break
                # act_name = self.env_act.current_text['tokens'][i]
                # act_arg = [act_name]
                act_idx = i
                obj_idxs = []
                sent_words = self.env_arg.current_text['tokens']
                tmp_num = min(len(sent_words), self.context_len)
                for j in range(tmp_num):
                    if self.env_arg.state[j, -1] == 2:
                        #act_arg.append(sent_words[j])
                        if j == len(sent_words) - 1:
                            j = -1
                        obj_idxs.append(j)
                if len(obj_idxs) == 0:
                    # act_arg.append(sent_words[-1])
                    obj_idxs.append(-1)
                # ipdb.set_trace()
                si, ai = self.env_act.current_text['word2sent'][i]
                ai += len(sents[si]['last_sent'])
                sents[si]['acts'].append({
                    'act_idx': ai,
                    'obj_idxs': [obj_idxs, []],
                    'act_type': 1,
                    'related_acts': []
                })
                # act_seq.append(act_arg)
            if self.env_act.terminal_flag:
                break
        # for k, v in act_seq.iteritems():
        #     print(k, v)
        # ipdb.set_trace()
        return sents, word2sent, np.array(R_t)
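
A small numeric sketch (with made-up Q-values) of the reward recovery used in choose_unlabeled_texts: from R_t = r_t + gamma * R_{t+1} it follows that r_t = R_t - gamma * R_{t+1}, which is what the slicing below computes.

import numpy as np

gamma = 0.9
R = np.array([5.0, 4.2, 3.1, 2.0])  # per-step max Q-values, illustrative only
r = R[:-1] - gamma * R[1:]          # implied one-step rewards
cml_rwd = r.sum() + gamma * R[-1]   # recovered cumulative reward
delta_r = abs(R[0] - cml_rwd)       # gap between predicted and recovered value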
Example #6
def main(args):
    """
    main function, build, train, validate, save and load model
    """
    start = time.time()
    print('Current time is: %s' % get_time())
    print('Starting at main...')
    # store k-fold cross-validation results, including recall, precision, f1 and average reward
    fold_result = {'rec': [], 'pre': [], 'f1': [], 'rw': []}

    # one can continue to train model from the start_fold rather than fold 0
    for fi in range(args.start_fold, args.end_fold):
        fold_start = time.time()
        args.fold_id = fi
        if args.fold_id == args.start_fold:
            # Initialize environment and replay memory
            env_act = Environment(args, args.agent_mode)
            mem_act = ReplayMemory(args, args.agent_mode)
        else:
            env_act.get_fold_data(args.fold_id)
            mem_act.reset()
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_fraction)
        # set_session(tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))) # for keras
        with tf.Session(config=tf.ConfigProto(
                gpu_options=gpu_options)) as sess:
            # ipdb.set_trace()
            # Initialize deep_q_net and agent
            net_act = DeepQLearner(args, sess, args.agent_mode)
            agent = Agent(env_act, mem_act, net_act, args)

            # loop over epochs
            epoch_result = {
                'rec': [0.0],
                'pre': [0.0],
                'f1': [0.0],
                'rw': [0.0]
            }
            training_result = {
                'rec': [],
                'pre': [],
                'f1': [],
                'loss': [],
                'rw': []
            }
            log_epoch = 0
            with open("%s_fold%d.txt" % (args.result_dir, args.fold_id),
                      'w') as outfile:
                # print all args to the screen and outfile
                print_args(args, outfile)

                if args.load_weights:
                    print('Loading weights ...')
                    filename = 'weights/%s_%s_%d_fold%d.h5' % (
                        args.domain, args.agent_mode, args.k_fold,
                        args.fold_id)
                    net_act.load_weights(filename)

                for epoch in range(args.start_epoch,
                                   args.start_epoch + args.epochs):
                    # test the model every args.train_episodes or at the end of an epoch
                    num_test = -1
                    env_act.train_epoch_end_flag = False
                    while not env_act.train_epoch_end_flag:
                        # training
                        num_test += 1
                        restart_init = (num_test == 0)
                        tmp_result = agent.train(args.train_steps,
                                                 args.train_episodes,
                                                 restart_init)
                        for k in training_result:
                            training_result[k].extend(tmp_result[k])
                        # testing
                        rec, pre, f1, rw = agent.test(args.valid_steps,
                                                      outfile)

                        if f1 > max(epoch_result['f1']):
                            if args.save_weights:
                                filename = 'weights/%s_%s_%d_fold%d.h5' % (
                                    args.domain, args.agent_mode, args.k_fold,
                                    args.fold_id)
                                net_act.save_weights(filename)

                            epoch_result['f1'].append(f1)
                            epoch_result['rec'].append(rec)
                            epoch_result['pre'].append(pre)
                            epoch_result['rw'].append(rw)
                            log_epoch = epoch
                            outfile.write(
                                '\n\n Best f1 score: {}  best epoch: {}\n'.
                                format(epoch_result, log_epoch))
                            print('\n\n Best f1 score: {}  best epoch: {}\n'.
                                  format(epoch_result, log_epoch))

                    # if no improvement after args.stop_epoch_gap, break
                    if epoch - log_epoch >= args.stop_epoch_gap:
                        outfile.write(
                            '\n\nBest f1 score: {}  best epoch: {}\n'.format(
                                epoch_result, log_epoch))
                        print('\nepoch: %d  result_dir: %s' %
                              (epoch, args.result_dir))
                        print(
                            '-----Early stopping, no improvement after %d epochs-----\n'
                            % args.stop_epoch_gap)
                        break
                if args.save_replay:
                    mem_act.save(args.save_replay_name, args.save_replay_size)

                # plot the training process results if you want
                # filename = '%s_fold%d_training_process.pdf'%(args.result_dir, args.fold_id)
                # plot_results(epoch_result, args.domain, filename)
                # outfile.write('\n\n training process:\n{}\n\n'.format(epoch_result))

                # find out the best f1 score in the current fold, add it to fold_result
                best_ind = epoch_result['f1'].index(max(epoch_result['f1']))
                for k in epoch_result:
                    fold_result[k].append(epoch_result[k][best_ind])
                    outfile.write('{}: {}\n'.format(k, fold_result[k]))
                    print(('{}: {}\n'.format(k, fold_result[k])))
                # compute the average f1 and average reward of all fold results up to now
                avg_f1 = sum(fold_result['f1']) / len(fold_result['f1'])
                avg_rw = sum(fold_result['rw']) / len(fold_result['rw'])
                outfile.write('\nAvg f1: {}  Avg reward: {}\n'.format(
                    avg_f1, avg_rw))
                print('\nAvg f1: {}  Avg reward: {}\n'.format(avg_f1, avg_rw))

                fold_end = time.time()
                print('Total time cost of fold %d is: %ds' %
                      (args.fold_id, fold_end - fold_start))
                outfile.write('\nTotal time cost of fold %d is: %ds\n' %
                              (args.fold_id, fold_end - fold_start))

        tf.reset_default_graph()
    end = time.time()
    print('Total time cost: %ds' % (end - start))
    print('Current time is: %s\n' % get_time())
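
print_args is called above but not defined in this example; a minimal sketch consistent with its call site, mirroring the inline argument dump from Example #3.

def print_args(args, outfile):
    # Write every parsed argument to both the screen and the result file.
    for k, v in sorted(args.__dict__.items(), key=lambda x: x[0]):
        print('{}: {}'.format(k, v))
        outfile.write('{}: {}\n'.format(k, v))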
Example #7
def main(args):
    if args.load_weights:
        args.exploration_decay_steps = 10

    start = time.time()
    localtime = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(time.time()))
    print 'Current time is:',localtime
    print 'Starting at main.py...'

    # use for investigating the influence of tag length
    '''
    f = open(args.home_dir + args.result_dir + "_train.txt",'w')
    f1 = open(args.home_dir + args.result_dir + "_test.txt",'w')
    f.write(str(args)+'\n')
    f.write('\nCurrent time is: %s'%localtime)
    f.write('\nStarting at main.py...')
    '''
    # Initialize environment, replay memory, deep Q-net and agent
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_rate)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        net = DeepQLearner(args, sess)
        env = Environment(args)

        temp_size = env.train_steps * args.epochs + env.test_steps
        if temp_size > 100000:
            temp_size = 100000
        args.replay_size = temp_size
        args.train_steps = env.train_steps
        assert args.replay_size > 0

        mem = ReplayMemory(args.replay_size, args)
        agent = Agent(env, mem, net, args)

        print '\n',args,'\n'
        
        if args.load_weights:
            print 'Loading weights from %s...' % args.load_weights
            net.load_weights(args.home_dir + args.load_weights)  # load last trained weights


        if args.test_one and args.load_weights:
            ws, act_seq, st = agent.test_one(args.text_dir)
            #f0.write('\nText_vec: %s'%str(env.text_vec))
            print '\nStates: %s\n' % str(st)
            print '\nWords: %s\n' % str(ws)
            print '\n\nAction_sequence: %s\n' % str(act_seq)

        else:
            # loop over epochs
            for epoch in xrange(args.start_epoch, args.epochs):
                #print '\n----------epoch: %d----------'%(epoch+1)
                epoch_start = time.time()
                f = open(args.home_dir + args.result_dir + "_train"+ str(epoch) + ".txt",'w')
                f1 = open(args.home_dir + args.result_dir + "_test"+ str(epoch) + ".txt",'w')

                f.write(str(args)+'\n')
                f.write('\nCurrent time is: %s'%localtime)
                f.write('\nStarting at main.py...')
                #print 'env.train_steps: %d'%env.train_steps
                #print 'env.test_steps: %d'%env.test_steps
                #assert 1==0
                if args.train_steps > 0:
                    #agent.train(args.train_steps, epoch)
                    if epoch == args.start_epoch:
                        env.train_init()
                    agent.train(args.train_steps, epoch)
                    if args.save_weights_prefix:
                        filename = args.home_dir + args.save_weights_prefix + "_%d.prm" % (epoch + 1)
                        net.save_weights(filename)

                    cnt = 0
                    ras = 0
                    tas = 0
                    tta = 0

                    for i in range(env.size):  # len(env.saved_text_vec)
                        text_vec_tags = env.saved_text_vec[i,:,-1]
                        state_tags = env.saved_states[i,:,-1]
                        sum_tags = sum(text_vec_tags)
                        if not sum_tags:
                            break
                        count = 0
                        right_actions = 0
                        tag_actions = 0
                        total_actions = 0
                        total_words = args.num_actions/2
                        temp_words = env.saved_text_length[i]
                        if temp_words > total_words:
                            temp_words = total_words

                        #print "text_vec_tags",text_vec_tags
                        #print 'state_tags',state_tags
                        for t in text_vec_tags:
                            if t == args.action_label:
                                total_actions += 1

                        f.write('\n\nText:'+str(i))
                        f.write('\ntotal words: %d\n'%temp_words)
                        print '\ntotal words: %d\n'%temp_words
                        #f.write('\nsaved_text_vec:\n')
                        #f.write(str(env.saved_text_vec[i,:,-1]))
                        #f.write('\nsaved_states:\n')
                        #f.write(str(env.saved_states[i,:,-1]))

                        for s in xrange(temp_words):
                            if state_tags[s] == 0:
                                count += 1
                            elif state_tags[s] == args.action_label:
                                tag_actions += 1
                                if text_vec_tags[s] == state_tags[s]:
                                    right_actions += 1

                        cnt += count
                        ras += right_actions
                        tta += tag_actions
                        tas += total_actions
                        if total_actions > 0:
                            recall = float(right_actions)/total_actions
                        else:
                            recall = 0
                        if tag_actions > 0:
                            precision = float(right_actions)/tag_actions
                        else:
                            precision = 0
                        rp = recall + precision
                        if rp > 0:
                            F_value = (2.0*recall*precision)/(recall+precision)
                        else:
                            F_value = 0
                        f.write('\nWords left: %d'%count)
                        f.write('\nActions: %d'%total_actions)
                        f.write('\nRight_actions: %d'%right_actions)
                        f.write('\nTag_actions: %d'%tag_actions)
                        f.write('\nActions_recall: %f'%recall)
                        f.write('\nActions_precision: %f'%precision)
                        f.write('\nF_measure: %f'%F_value)
                        print '\nText: %d'%i
                        print '\nWords left: %d'%count
                        print 'Actions: %d'%total_actions
                        print 'Right_actions: %d'%right_actions
                        print 'Tag_actions: %d'%tag_actions
                        print 'Actions_recall: %f'%recall
                        print 'Actions_precision: %f'%precision
                        print 'F_measure: %f'%F_value

                    if tas > 0:
                        average_recall = float(ras)/tas
                    else:
                        average_recall = 0
                    if tta > 0:
                        average_precision = float(ras)/tta
                    else:
                        average_precision = 0
                    arp = average_recall + average_precision
                    if arp > 0:
                        ave_F_value = (2*average_recall*average_precision)/(average_recall+average_precision)
                    else:
                        ave_F_value = 0
                    f.write('\nTotal words left: %d'%cnt)
                    f.write('\nTotal actions: %d'%tas)
                    f.write('\nTotal right_actions: %d'%ras)
                    f.write('\nTotal tag_actions: %d'%tta)
                    f.write('\nAverage_actions_recall: %f'%average_recall)
                    f.write('\nAverage_actions_precision: %f'%average_precision)
                    f.write('\nAverage_F_measure: %f'%ave_F_value)
                    print '\nTotal words left: %d'%cnt
                    print 'Total actions: %d'%tas
                    print 'Total right_actions: %d'%ras
                    print 'Total tag_actions: %d'%tta
                    print 'Average_actions_recall: %f'%average_recall
                    print 'Average_actions_precision: %f'%average_precision
                    print 'Average_F_measure: %f'%ave_F_value


                if args.test:
                    f1.write('test_texts: %s\ttexts_num: %d\n'%(str(env.test_text_name), args.test_text_num))
                    agent.test(args.words_num, env.test_steps/args.words_num, f1)

                epoch_end = time.time()
                print 'Total time cost of epoch %d is: %ds'%(epoch, epoch_end-epoch_start)
                f.write('\nTotal time cost of epoch %d is: %ds\n'%(epoch, epoch_end-epoch_start))
                f1.write('\nTotal time cost of epoch %d is: %ds\n'%(epoch, epoch_end-epoch_start))

                f.close()
                f1.close()

        end = time.time()
        print 'Total time cost: %ds'%(end-start)
        localtime = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(time.time()))
        print 'Current time is: %s'%localtime
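
The per-text and aggregate recall/precision/F computations above repeat the same zero-guarded pattern; a hypothetical helper that factors it out.

def prf(right_actions, total_actions, tag_actions):
    # Zero-guarded recall, precision and F-measure, as computed inline above.
    recall = float(right_actions) / total_actions if total_actions > 0 else 0.0
    precision = float(right_actions) / tag_actions if tag_actions > 0 else 0.0
    rp = recall + precision
    f_value = (2.0 * recall * precision) / rp if rp > 0 else 0.0
    return recall, precision, f_value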