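# NOTE: a sketched import header for self-containment; only stdlib and common
# third-party modules that this file clearly uses are listed. Repo-local
# helpers (Experiment, build_parser, fast_load, printc, SpanModel, make_batch,
# pad_to_max, show_stats, adjust_passages, apply_alignment, question_type,
# two_way_frequency, metric_max_over_ground_truths, exact_match_score,
# f1_score, batch_bleu_score, batch_meteor_score, batch_rouge_score,
# get_ans_string_single_post_pad_search_updated, prep_all_baidu,
# compute_baidu_metrics, generate_baidu_test, evaluate_nqa2) are assumed to
# be imported from elsewhere in the repo.
import json
import os
import random

import numpy as np
import tensorflow as tf
from tqdm import tqdm
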
class SpanExperiment(Experiment):
    '''
    Runs span-based QA experiments
    (e.g., NewsQA, TriviaQA, SQuAD).
    '''

    def __init__(self):
        ''' Initializes a Span Experiment Class
        '''
        super(SpanExperiment, self).__init__()
        parser = build_parser()
        self._build_char_index()
        self.args = parser.parse_args()

        self.model_name = self.args.rnn_type
        self.has_test = True
        # self.hierarchical = False
        self.patience = 0
        self.end_format = False
        self.return_idx = False

        # Use RNN for char-embed for now
        self.args.char_enc = 'RNN'
        self.args.cnn_size = self.args.rnn_size

        """ Dataset-specific settings
        """

        if('Squad' in self.args.dataset):
            self.has_test = False
            self.end_format = True
            self.args.align_spans = True
            self.args.use_lower = 0
        if('NewsQA' in self.args.dataset):
            self.end_format = True
            self.args.align_spans = True
            self.args.use_lower = 0
        if('Quasar' in self.args.dataset):
            self.end_format = True
            self.args.align_spans = True
            self.args.use_lower = 0
        if('TriviaQA' in self.args.dataset):
            self.end_format = True
            self.args.align_spans = True
            self.has_test = False
            self.args.use_lower = 0
        if('NarrativeQA' in self.args.dataset):
            self.end_format = True
            self.args.align_spans = False
            self.has_test = True
            self.args.use_lower = 1
        if('SearchQA' in self.args.dataset):
            self.end_format = True
            self.args.align_spans = False
            self.has_test = True
            self.args.use_lower = 1

        printc("====================================", 'green')
        printc("[Start] Training Span Prediction task", 'green')
        printc('[{} Dataset]'.format(self.args.dataset), 'green')
        printc("Tensorflow {}".format(tf.__version__), 'green')

        self._setup()
        self.sdmax = None
        self.f_len = None
        self.num_choice = 0
        self.query_max = 0
        self.test_set2, self.dev_set2 = None, None
        self.num_features = 0
        self.num_global_features = 0

        print("Loading environment...")

        try:
            self.env = fast_load(
                    './datasets/{}/env.gz'.format(
                            self.args.dataset))
        except (IOError, OSError):
            print("Can't find gz file, loading plain JSON instead...")
            self.env = fast_load('./datasets/{}/env.json'.format(
                                    self.args.dataset))
        if(self.args.add_features!=""):
            print("Loading features...")
            self.feats = fast_load(
                    './datasets/{}/feats.gz'.format(
                            self.args.dataset))
            if('EM' in self.args.add_features):
                self.num_features += 1
        else:
            self.feats = {'train':None, 'test':None, 'dev':None,
                            'dev2':None, 'test2': None}

        # Build Word Index and Inverse Index
        print("Word Index={}".format(len(self.env['word_index'])))
        self.index_word = {idx: word for word, idx
                        in self.env['word_index'].items()}
        self.word_index = self.env['word_index']
        

        self.mdl = SpanModel(len(self.env['word_index']),
                        self.args,
                        char_size=len(self.char_index),
                        sdmax=self.sdmax,
                        f_len=self.f_len,
                        num_features=self.num_features,
                        num_global_features=self.num_global_features
                            )
        self._setup_tf()

        if('Baidu' in self.args.dataset):
            self = prep_all_baidu(self)
        else:
            self.dev_set, self.dev_eval = self._prepare_set(
                                        self.env['dev'],
                                        set_type='dev',
                                        features=self.feats['dev'])
            self.train_set, self.train_eval = self._prepare_set(
                                        self.env['train'],
                                        set_type='train',
                                        features=self.feats['train'])

            print('Train={} Dev={}'.format(len(self.train_set),
                                        len(self.dev_set)))

            if('dev2' in self.env):
                self.dev_set2, self.dev_eval2 = self._prepare_set(
                                        self.env['dev2'],
                                        set_type='dev',
                                        features=self.feats['dev2'])
            else:
                self.dev_set2, self.dev_eval2 = None, None

        if(self.has_test):
            self.test_set, self.test_eval = self._prepare_set(self.env['test'],
                                        set_type='test',
                                        features=self.feats['test'])
            print("Test={}".format(len(list(self.test_set))))

        print("Loaded environment")
        print("Vocab size={}".format(len(self.env['word_index'])))

        self._make_dir()

        # Primary metric to use to align dev and test sets
        if('NarrativeQA' in self.args.dataset):
            self.eval_primary = 'Rouge'
            self.show_metrics = ['Bleu1','Bleu4','Meteor','Rouge']
        else:
            self.eval_primary = 'EM'
            self.show_metrics = ['EM','F1']

    def compute_metrics(self, spans, passage, passage_str, labels, qids,
                        maxspan=15, questions=None, set_type='',
                        align_spans=None, spans2=None):
        """ Compute EM and F1 Metrics
        """

        all_em, all_f1 = [], []
        all_em2, all_f12 = [], []

        if(type(passage_str[0]) is list):
            # Passages arrived as token ids; map them back to word strings
            def ids2words(passage):
                return [self.index_word[x] for x in passage if x>0]

            passage_str = [ids2words(x) for x in passage]
            passage_str = [' '.join(x) for x in passage_str]

        max_span = self.args.max_span

        assert(len(spans)==len(passage)==len(passage_str)==len(labels))

        predict_dict, ans_dict = {}, {}

        error_rate = 0

        if(align_spans is not None):
            assert(len(align_spans)==len(spans))

        for i in tqdm(range(len(passage)), desc='Evaluating'):
            _span = spans[i]
            if(spans2 is not None):
                _spans2 = spans2[i]
            else:
                _spans2 = None
            _passage_str = passage_str[i]
            ans_start = np.array(_span[0]).reshape((-1)).tolist()
            ans_end = np.array(_span[1]).reshape((-1)).tolist()
            _qids = qids[i]
            _passage = passage[i]
            if(len(labels[i])==0):
                continue

            _passage_words = _passage_str.split(' ')[:self.args.tmax]

            if(self.args.dataset=='SearchQA'):
                # Unigram only
                if(len(labels[i][0].split())==1):
                    # Unigram sample
                    ans, _ = get_ans_string_single_post_pad_search_updated(
                                                    _passage_str,
                                                    _passage_words,
                                                    ans_start, ans_end,
                                                    maxspan=1
                                                    )
                    _em = metric_max_over_ground_truths(exact_match_score,
                                                    ans, labels[i])
                    all_em2.append(_em)
                else:
                    # Ngram sample
                    ans, _ = get_ans_string_single_post_pad_search_updated(
                                                    _passage_str,
                                                    _passage_words,
                                                    ans_start, ans_end,
                                                    maxspan=3
                                                    )
                    _f1 = metric_max_over_ground_truths(f1_score, ans,
                                                            labels[i])
                    all_f12.append(_f1)

            if(align_spans is not None):
                _align_spans = align_spans[i]
            else:
                _align_spans = None

            ans, error = get_ans_string_single_post_pad_search_updated(
                                                _passage_str,
                                                _passage_words,
                                                ans_start, ans_end,
                                                maxspan=max_span,
                                                align_spans=_align_spans,
                                                spans2=_spans2,
                                                return_idx=self.return_idx
                                                )

            error_rate += error

            predict_dict[str(_qids)] = [ans]
            ans_dict[str(_qids)] = [x for x in labels[i]]

            _em = metric_max_over_ground_truths(exact_match_score,
                                                    ans, labels[i])
            _f1 = metric_max_over_ground_truths(f1_score, ans,
                                                    labels[i])
            all_em.append(_em)
            all_f1.append(_f1)

        # Merge dicts
        merge_dict = {}
        for key, value in predict_dict.items():
            _ans = ans_dict[key]
            merge_dict[key] = [value, _ans]

        
        print("errors={} out of {}".format(error_rate, len(passage)))
        try:
            with open(self.out_dir+'./{}.pred_ans.json'.format(set_type), 'w+') as f:
                json.dump(merge_dict, f, indent=4, ensure_ascii=False)	
        except:
            print("Can't find write due to some reason..")

        if(self.args.dataset=='SearchQA'):
            metric1 = 100 * np.mean(all_em2)
            metric2 = 100 * np.mean(all_f12)
            self.write_to_file('[2nd Eval] EM={} F1={}'.format(100 * np.mean(all_em),
                                    100 * np.mean(all_f1)))
            return [metric1, metric2]
        elif('NarrativeQA' in self.args.dataset):
            bleu = batch_bleu_score(ans_dict, predict_dict, n=4)
            metric1 = 100 * bleu[0]
            metric2 = 100 * bleu[3]
            # metric2 = batch_bleu_score(ans_dict, predict_dict, n=4)
            metric3 = 100 * batch_meteor_score(ans_dict, predict_dict)
            metric4 = 100 * batch_rouge_score(ans_dict, predict_dict)
            return [metric1, metric2, metric3, metric4]
        elif('Baidu' in self.args.dataset):
            bleu = batch_bleu_score(ans_dict, predict_dict, n=4)
            bleu4 = 100 * bleu[3]
            metric4 = 100 * batch_rouge_score(ans_dict, predict_dict)
            return [bleu4, metric4]
        else:
            metric1 = 100 * np.mean(all_em)
            metric2 = 100 * np.mean(all_f1)
            return [metric1, metric2]

    def evaluate(self, epoch, data, original_data, name='', set_type=''):
        """ Main evaluation function
        """
        if('NarrativeQA2' in self.args.dataset):
            metrics = evaluate_nqa2(self, epoch, data, original_data, name=name,
                                    set_type=set_type)
            return metrics
        # Prediction loop over the evaluation set
        losses = []

        batch_size = int(self.args.batch_size/self.args.test_bsz_div)

        num_batches = int(len(data) / batch_size)
        all_start = []
        all_end = []

        ground_truth = [x[1] for x in original_data]
        passages_str = [x[0] for x in original_data]
        if(self.args.align_spans==1):
            # Spans exist
            print("Found align spans, using them")
            align_spans = [x[3] for x in original_data]
        else:
            align_spans = None
        qids = [x[2] for x in original_data]

        passages = [x[0] for x in data]
        all_yp1 = []
        all_yp2 = []

        for i in tqdm(range(0, num_batches + 1), desc='predicting'):
            batch = make_batch(data, batch_size, i)
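            # `make_batch` (repo helper, assumed behavior): returns the i-th
            # `batch_size`-sized slice of `data`, or None when it is empty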
            if(batch is None):
                continue
            feed_dict = self.mdl.get_feed_dict_v2(
                                        self.mdl.feed_holder,
                                        batch,
                                        mode='testing')

            loss, p = self.sess.run(
                    [self.mdl.cost, self.mdl.predict_op],
                    feed_dict)

            yp1, yp2 = self.sess.run([self.mdl.yp1, self.mdl.yp2],
                                feed_dict)

            start_p = p[0]
            end_p = p[1]
            all_start += [x for x in start_p]
            all_end += [x for x in end_p]
            all_yp1 += yp1.tolist()
            all_yp2 += yp2.tolist()
            losses.append(loss)

        assert(len(all_start)==len(all_end))
        assert(len(all_start)==len(data))
        spans = list(zip(all_start, all_end))

        if('Baidu' in self.args.dataset):
            original = passages_str
            metrics = compute_baidu_metrics(self, spans, original,
                                        ground_truth, qids, self.pmax,
                                        set_type=set_type)
        else:
            metrics = self.compute_metrics(spans, passages,
                                    passages_str, ground_truth, qids,
                                    questions=[x[1] for x in original_data],
                                    set_type=set_type,
                                    align_spans=align_spans,
                                    spans2=list(zip(all_yp1, all_yp2)))
        self.write_to_file("[{}] Loss={}".format(
                            name, np.mean(losses)))
        if('NarrativeQA' in self.args.dataset):
            self._register_eval_score(epoch, set_type, 'Bleu1', metrics[0])
            self._register_eval_score(epoch, set_type, 'Bleu4', metrics[1])
            self._register_eval_score(epoch, set_type, 'Meteor', metrics[2])
            self._register_eval_score(epoch, set_type, 'Rouge', metrics[3])
        elif('Baidu' in self.args.dataset):
            self._register_eval_score(epoch, set_type, 'Bleu4', metrics[0])
            self._register_eval_score(epoch, set_type, 'Rouge', metrics[1])
        else:
            self._register_eval_score(epoch, set_type, 'F1', metrics[1])
            self._register_eval_score(epoch, set_type, 'EM', metrics[0])
        return metrics

    def get_predictions(self, epoch, data, name='', set_type=''):
        """ Same as evaluate but do not compute metrics
        """
        num_batches = int(len(data) / self.args.batch_size)
        accuracies = 0
        all_p = []

        for i in range(0, num_batches + 1):
            batch = make_batch(data,
                    self.args.batch_size, i)
            if(batch is None):
                continue
            feed_dict = self.mdl.get_feed_dict_v2(self.mdl.feed_holder,
                                        batch,
                                        mode='testing')
            p = self.sess.run([self.mdl.predictions], feed_dict)
            all_p += [x for x in p[0]]
        assert(len(all_p)==len(data))
        return all_p

    def _write_predictions(self, preds, set_type):
        with open(os.path.join(self.out_dir,
                '{}_pred.txt'.format(set_type)), 'w+') as f:
            for p in preds:
                f.write(str(p) + '\n')

    def _prepare_set(self, data,
                    set_type='', features=None, ans_features=None):
        """ Prepares set.
        Takes in raw processed data (env.gz) and loads them into
        arrays for passing into feed_dict.
        """
        if(set_type=='train' and self.args.adjust==0):
            print("Removing all samples with ptr more than {}".format(
                                                    self.args.smax))
            print("Original Samples={}".format(len(data)))
            if(self.end_format==1):
                data = [x for x in data if x[3]<self.args.smax]
            else:
                data = [x for x in data if x[2]+x[3]<self.args.smax]
            print("Reduced Samples={}".format(len(data)))

        if(self.args.align_spans==1):
            eval_data = [[x[6], x[5],x[4],x[7]] for x in data]
            asp = [x[7] for x in data]
            show_stats('align spans', [len(x) for x in asp])

        else:
            eval_data = [[x[0],x[5],x[4]] for x in data]

        self.char_pad_token = [0 for i in range(self.args.char_max)]

        print("Preparing {}".format(set_type))

        def w2i(w):
            try:
                return self.word_index[w]
            except KeyError:
                # index 1 is assumed to be the OOV/unknown-word id
                return 1

        def tokenize(s, vmax, pad=True):
            if(self.args.use_lower):
                s = [x.lower() for x in s]
            s = s[:vmax]
            tokens = [w2i(x) for x in s]
            if(pad):
                tokens = pad_to_max(tokens, vmax)
            return tokens
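
        # e.g. (sketch, with hypothetical word ids and vmax=4):
        #   tokenize(['red', 'cat'], 4) -> [12, 47, 0, 0]
        # where OOV words map to 1 and pad_to_max right-pads with 0s.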

        def hierarchical_tokenize(s, vmax):
            return [[tokenize(x, vmax) for x in y] for y in s]

        def clip_len(s, vmax):
            if(s>vmax):
                return vmax
            else:
                return s

        if(set_type=='train'):
            smax, qmax = self.args.smax, self.args.qmax
        else:
            smax, qmax = self.args.tmax, self.args.qmax
        q = [x[1] for x in data]
        q = [x.split(' ') for x in q]
        qlen_raw = [len(x) for x in q]

        qlen = [clip_len(x, qmax) for x in qlen_raw]

        q = [tokenize(x, qmax) for x in tqdm(q, desc='tokenizing qns')]

        start = [x[2] for x in data]

        passages = [x[0] for x in data]
        _passages = [x.split(' ') for x in passages]
        if(self.args.adjust==1 and set_type=='train'):
            print("Adjusting passages")
            _passages, start, align = adjust_passages(_passages, start,
                                    self.args.smax,
                                    span=int(self.args.smax/2))
        dlen_raw = [len(x) for x in _passages]
        dlen = [clip_len(x, smax) for x in dlen_raw]
        docs = [tokenize(x, smax) for x in tqdm(_passages,
                                        desc='tokenizing docs')]

        if(self.args.align_spans==1):
            show_stats('pointer', [x[3] for x in data])
            end = [x[3] for x in data]
            start = [min(x, smax-1) for x in start]
            end = [min(x, smax-1) for x in end]
        else:
            show_stats('pointer', [x[2]+x[3]-1 for x in data])
            label_len = [x[3] for x in data]
            end = np.array(start) + np.array(label_len)
            start = [min(x, smax-1) for x in start]
            end = [min(x-1, smax-1) for x in end]
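            # e.g. start=5, label_len=3 -> inclusive end index 7; both
            # pointers are then clipped to the maximum passage length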

        print("================================")
        printc('Showing passage stats {}'.format(set_type),'cyan')
        show_stats('passage', dlen_raw)
        show_stats('question', qlen_raw)
        print("=================================")

        output = [docs, dlen, q, qlen, start, end]

        self.mdl.register_index_map(0, 'doc_inputs')
        self.mdl.register_index_map(1, 'doc_len')
        self.mdl.register_index_map(2, 'query_inputs')
        self.mdl.register_index_map(3, 'query_len')
        self.mdl.register_index_map(4, 'start_label')
        self.mdl.register_index_map(5, 'end_label')

        if('CHAR' in self.args.rnn_type):
            # print("Preparing Chars...")
            char_index = self.char_index
            char_pad_token = self.char_pad_token
            char_max = self.args.char_max

            def char_idx(x, idx):
                try:
                    return idx[x]
                except KeyError:
                    # 1 is assumed to be the unknown-character id
                    return 1

            def char_ids(d, smax):
                txt = d.split(' ')[:smax]
                _txt = [[char_idx(y, char_index) for y in x] for x in txt]
                _txt = [pad_to_max(x, char_max) for x in _txt]
                _txt = pad_to_max(_txt, smax,
                        pad_token=char_pad_token)
                return _txt
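
            # e.g. (sketch): with smax=3 and char_max=2, char_ids("hi yo", 3)
            # returns a 3x2 grid of character ids, zero-padded to shape
            # (smax, char_max)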

            queries = [x[1] for x in data]
            qc = [char_ids(x, self.args.qmax) for x in tqdm(queries,
                                                    desc='Prep char query')]
            pc = [char_ids(x, smax) for x in tqdm(passages,
                                                    desc="Prep char doc")]
            qc = np.array(qc).reshape((-1,
                    self.args.qmax, self.args.char_max))
            pc = np.array(pc).reshape((-1,
                    smax, self.args.char_max))

            print("Constructed Char Inputs")
            print(pc.shape)
            print(qc.shape)
            self.mdl.register_index_map(len(output),
                            'doc_char_inputs')
            output.append(pc)
            self.mdl.register_index_map(len(output),
                            'query_char_inputs')
            output.append(qc)

        if(self.args.add_features!=''):
            if('EM' in self.args.add_features):
                qem = [x[1] for x in features]
                pem = [x[0] for x in features]
                if(self.args.adjust==1 and set_type=='train'):
                    # `align` comes from adjust_passages, so it is the
                    # passage features that need re-alignment here
                    pem = apply_alignment(pem, align)
                pem = [pad_to_max(x, smax) for x in pem]
                qem = [pad_to_max(x, qmax) for x in qem]
                pem = np.array(pem).reshape((-1, smax, 1))
                qem = np.array(qem).reshape((-1, qmax, 1))
                self.mdl.register_index_map(len(output), 'doc_feats')
                output.append(pem)
                self.mdl.register_index_map(len(output), 'query_feats')
                output.append(qem)
            if('QT' in self.args.add_features):
                # Question type features
                qt = [question_type(x[1]) for x in data]
                qt = np.array(qt).reshape((-1, 1))
                self.mdl.register_index_map(len(output), 'qt_feats')
                output.append(qt)
            if('FQ' in self.args.add_features):
                freq = [two_way_frequency(x[1], x[0]) for x in data]
                freq_q = [x[0] for x in freq]
                freq_c = [x[1] for x in freq]
                freq_q = [pad_to_max(x, qmax) for x in freq_q]
                freq_c = [pad_to_max(x, smax) for x in freq_c]
                freq_q = np.array(freq_q).reshape((-1, qmax, 1))
                freq_c = np.array(freq_c).reshape((-1, smax, 1))
                self.mdl.register_index_map(len(output), 'doc_fq')
                output.append(freq_c)
                self.mdl.register_index_map(len(output), 'query_fq')
                output.append(freq_q)

        new_data = zip(*output)
        return list(new_data), list(eval_data)

    def train(self):
        """ Main Training Function
        """

        print("Starting training")
        data = list(self.train_set)
        lr = self.args.lr
        for epoch in range(1, self.args.epochs + 1):
            self.write_to_file('===================================')
            losses = []

            if(self.args.shuffle==1):
                random.shuffle(data)

            num_batches = int(len(data) / self.args.batch_size)
            self.sess.run(tf.assign(self.mdl.is_train, self.mdl.true))
            for i in tqdm(range(0, num_batches + 1)):
                batch = make_batch(data,
                        self.args.batch_size, i)
                if(batch is None):
                    continue
                if(self.args.num_gpu>1):
                    # Multi-GPU feed dict
                    feed_dict = {}
                    gpu_bsz = int(len(batch) / self.args.num_gpu)
                    for gid, feed_holder in enumerate(self.mdl.multi_feed_dicts):
                        mbatch = make_batch(batch, gpu_bsz, gid)
                        fd = self.mdl.get_feed_dict_v2(feed_holder,
                                                    mbatch,
                                                    mode='training', lr=lr)
                        feed_dict.update(fd)
                else:
                    feed_dict = self.mdl.get_feed_dict_v2(
                                                self.mdl.feed_holder,
                                                batch,
                                                mode='training', lr=lr)

                _, loss = self.sess.run(
                    [self.mdl.train_op, self.mdl.cost],
                    feed_dict)

                losses.append(loss)

                if(self.args.tensorboard):
                    # NOTE: assumes the model exposes a merged summary op
                    # (hypothetical attribute name `merged_summary`)
                    summary = self.sess.run(self.mdl.merged_summary,
                                            feed_dict)
                    self.train_writer.add_summary(
                        summary, epoch * num_batches + i)

            self.write_to_file("[{}] [{}] Epoch [{}] Loss={}".format(
                 self.args.dataset, self.model_name, epoch, np.mean(losses)))
            self.write_to_file('[smax={}] [rnn={}] [lr={}] [f={}] [cove={}]'.format(
                                        self.args.smax,
                                        self.args.rnn_size,
                                        lr, self.args.add_features,
                                        self.args.use_cove
                                        ))
            print("[GPU={}]".format(self.args.gpu))

            self.sess.run(tf.assign(self.mdl.is_train,
                                  self.mdl.false))
            lr = self._run_evaluation(epoch, lr)


    def _run_evaluation(self, epoch, lr):
        """ Run Evaluation on test set
        """

        dev_metrics = self.evaluate(epoch, self.dev_set, self.dev_eval,
                                    name='dev',
                                    set_type='dev')
        self._show_metrics(epoch, self.eval_dev,
                                    self.show_metrics,
                                        name='Dev')

        best_epoch, _ = self._select_test_by_dev(epoch, self.eval_dev,
                                    None,
                                    no_test=True,
                                    name='dev')

        if(self.dev_set2 is not None):
            dev_metrics2 = self.evaluate(epoch, self.dev_set2, self.dev_eval2,
                                        name='dev2',
                                        set_type='dev2')
            self._show_metrics(epoch, self.eval_dev2,
                                        self.show_metrics,
                                            name='Dev2')
            best_epoch, _ = self._select_test_by_dev(epoch, self.eval_dev2,
                                        None,
                                        no_test=True,
                                        name='dev2')

        if(self.args.dev_lr>0 and best_epoch!=epoch):
            self.patience += 1
            print('Patience={}'.format(self.patience))
            if(self.patience>=self.args.patience):
                print("Reducing LR by {} times".format(self.args.dev_lr))
                lr = lr / self.args.dev_lr
                print("LR={}".format(lr))
                self.patience = 0

        """ Evaluation on Test Set Locally
        All other datasets which require submission should use test_eval==0
        """
        if(self.has_test==0):
            if('Baidu' in self.args.dataset and epoch==best_epoch):
                print("Best epoch! Writing predictions to file!")
                generate_baidu_test(self, epoch, self.test_set, self.test_eval,
                                            name='test',
                                            set_type='test')
            return lr

        test_metrics = self.evaluate(epoch, self.test_set,
                                    self.test_eval,
                                    name='test',
                                    set_type='test')
        self._show_metrics(epoch, self.eval_test,
                                    self.show_metrics,
                                        name='Test')
        _, max_ep, best_ep = self._select_test_by_dev(epoch, self.eval_dev,
                                    self.eval_test,
                                    no_test=False, name='test')
        return lr
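

# Minimal entry-point sketch (an assumption; the repo may provide its own
# runner script). Hyperparameters come from the build_parser() CLI args.
if __name__ == '__main__':
    exp = SpanExperiment()
    exp.train()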