Example #1
def compute_idx(pages_path_in, pages_path_out, vocab):


    f = h5py.File(pages_path_in, 'r')

    if prm.att_doc and prm.att_segment_type == 'sentence':
        nltk.download('punkt')
        tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

    os.remove(pages_path_out) if os.path.exists(pages_path_out) else None

    # Save to HDF5
    fout = h5py.File(pages_path_out,'a')

    if prm.att_doc:
        shape = (f['text'].shape[0],prm.max_segs_doc,prm.max_words)
    else:
        shape=(f['text'].shape[0],prm.max_words)

    idxs = fout.create_dataset('idx', shape=shape, dtype=np.int32)
    mask = fout.create_dataset('mask', shape=(f['text'].shape[0],), dtype=np.float32)

    i = 0
    for text in f['text']:
        st = time.time()

        if prm.att_doc:
            if prm.att_segment_type.lower() == 'section' or prm.att_segment_type.lower() == 'subsection':
                segs = ['']
                for line in text.split('\n'):
                    if prm.att_segment_type == 'section':
                        line = line.replace('===', '')
                    if line.strip().startswith('==') and line.strip().endswith('=='):
                        segs.append('')
                    segs[-1] += line.lower() + '\n'
            elif prm.att_segment_type.lower() == 'sentence':
                segs = tokenizer.tokenize(text.lower().decode('ascii', 'ignore'))
            else:
                raise ValueError('Not a valid value for the attention segment type (att_segment_type) parameter. Valid options are "section", "subsection" or "sentence".')

            segs = segs[:prm.max_segs_doc]
            idxs_, _ = utils.text2idx2(segs, vocab, prm.max_words)
            idxs[i,:len(idxs_),:] = idxs_
            mask[i] = len(idxs_)
        else:
            idx, _ = utils.text2idx2([text.lower()], vocab, prm.max_words)
            idxs[i,:] = idx[0]
        i += 1

        #if i > 3000:
        #    break

        print 'processing article', i, 'time', time.time()-st

    f.close()
    fout.close()
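
A note on the helper these examples revolve around: utils.text2idx2(texts, vocab, max_words) is not shown on this page. Judging from how its return values are used above, it appears to map a list of texts to a fixed-width matrix of vocabulary indices plus the corresponding token lists. A minimal sketch under that assumption (the padding/unknown-word conventions below are illustrative placeholders, not the project's actual values):

import numpy as np

def text2idx2_sketch(texts, vocab, dim):
    # texts: list of strings; vocab: dict mapping word -> integer index.
    # Returns a (len(texts), dim) int32 index matrix and the token lists.
    idxs = -np.ones((len(texts), dim), dtype=np.int32)    # -1 used as padding (assumption)
    words_out = []
    for i, text in enumerate(texts):
        words = text.split()[:dim]
        for j, w in enumerate(words):
            idxs[i, j] = vocab[w] if w in vocab else -1   # unknown words -> -1 (assumption)
        words_out.append(words)
    return idxs, words_out
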
Example #2
    def reset(self):
        """
        Reset the environment and set up a new episode.

        Returns:
            initial state of the reset environment.
        """
        print '********************reset**********'
        t0 = time()
        qi, qi_i, qi_lst, D_gt_id, D_gt_title = self.get_samples(
            sample_num=self.batch_size,
            max_words_input=self.search.max_words_input)
        current_query_text = qi_lst
        current_query_code = qi_i
        self.action = current_query_code
        print 'reset current query text:', current_query_text
        print 'reset current query code:', current_query_code
        print 'reset action code:', self.action
        self.D_gt_id = D_gt_id
        self.D_gt_title = D_gt_title
        metrics, D_i_, D_id_, D_gt_m_ = self.search.perform(
            current_query_code, D_gt_id, self.is_train, current_query_text)
        expanded_query_text = current_query_text
        metric_idx = self.search.metrics_map[self.rewardtype.upper()]
        reward = metrics[:, metric_idx]
        expanded_query_text = self.process_expanded_words(expanded_query_text)
        print ' reset expanded text:', expanded_query_text
        expanded_query_code, expandes_lst_ = utils.text2idx2(
            expanded_query_text, self.vocab, self.cfg['search']['max_terms'])
        #expanded_query_code, terminal,reward =  self.execute(self.action)
        print 'reward', reward
        self.state = expanded_query_code
        self.counsteps = 0
        return self.state
Example #3
    def execute(self, actions):
        print 'execute'
        done = False
        reformulated_query = actions
        current_queries = self.current_queries
        D_gt_id = self.D_gt_id
        metrics, D_i_, D_id_, D_gt_m_ = self.search.perform(
            reformulated_query, D_gt_id, self.is_train, current_queries)
        # print "D_id_", D_id_
        # i = 3
        # print "ALALALA ", [self.search.engine.id_title_map[d_id] for d_id in D_id_[i]]
        text = [[self.search.engine.id_title_map[d_id] for d_id in D_id_[i]]
                for i in range(D_id_.shape[0])]
        actions = current_queries
        metric_idx = self.search.metrics_map[self.reward.upper()]
        reward = metrics[metric_idx]
        if (len(actions) == 0):  # or self.counsteps > 10):
            done = True
        state = [
            utils.text2idx2(t, self.vocab, dim=self.search.max_words_input)[0]
            for t in text
        ]
        # reward = 1.0
        metric_idx = self.search.metrics_map[self.reward.upper()]
        reward = metrics[:, metric_idx].sum()
        return state, done, reward
Example #4
    def get_samples(self,
                    input_queries,
                    target_docs,
                    vocab,
                    index,
                    engine,
                    max_words_input=200):
        qi = [utils.clean(input_queries[t].lower()) for t in index]
        D_gt_title = [target_docs[t] for t in index]

        D_gt_id_lst = []
        for j, t in enumerate(index):
            #print("j",j)
            D_gt_id_lst.append([])
            for title in D_gt_title[j]:
                #print("title", title)
                if title in engine.title_id_map:
                    D_gt_id_lst[-1].append(engine.title_id_map[title])
                #else:
                #    print 'ground-truth doc not in index:', title

        D_gt_id = utils.lst2matrix(D_gt_id_lst)
        qi_i, qi_lst_ = utils.text2idx2(qi, vocab, max_words_input)

        qi_lst = []
        for qii_lst in qi_lst_:
            # append empty strings, so the list size becomes <dim>.
            qi_lst.append(qii_lst +
                          max(0, max_words_input - len(qii_lst)) * [''])
        return qi, qi_i, qi_lst, D_gt_id, D_gt_title
Example #5
    def get_samples(self, sample_num=1, max_words_input=200):
        if sample_num <= 0:
            sample_num = 1
        train_index = utils.get_one_random_batch_idx(len(self.qi), sample_num)

        input_queries = self.qi
        target_docs = self.dt
        vocab = self.vocab
        engine = self.search.engine
        qi = [utils.clean(input_queries[t].lower()) for t in train_index]
        D_gt_title = [target_docs[t] for t in train_index]

        D_gt_id_lst = []
        for j, t in enumerate(train_index):
            #print("j",j)
            D_gt_id_lst.append([])
            for title in D_gt_title[j]:
                #print("title", title)
                if title in engine.title_id_map:
                    D_gt_id_lst[-1].append(engine.title_id_map[title])
                #else:
                #    print 'ground-truth doc not in train_index:', title

        D_gt_id = utils.lst2matrix(D_gt_id_lst)

        qi_i, qi_lst_ = utils.text2idx2(qi, vocab, max_words_input)
        #print("qi_i", qi_i)
        #print("qi_lst_", qi_lst_)

        qi_lst = []
        for qii_lst in qi_lst_:
            # append empty strings, so the list size becomes <dim>.
            qi_lst.append(qii_lst +
                          max(0, max_words_input - len(qii_lst)) * [''])
        return qi, qi_i, qi_lst, D_gt_id, D_gt_title
Example #6
    def execute(self, actions):
        done = False
        # reformulated_query, current_queries, D_gt_id = action
        # print(actions)
        reformulated_query = actions
        # print(reformulated_query)
        current_queries = self.current_queries
        D_gt_id = self.D_gt_id
        metrics, D_i_, D_id_, D_gt_m_ = self.search.perform(
            reformulated_query, D_gt_id, self.is_train, current_queries)
        print "D_id_", D_id_
        i = 3
        print "ALALALA ", [
            self.search.engine.id_title_map[d_id] for d_id in D_id_[i]
        ]
        text = [[self.search.engine.id_title_map[d_id] for d_id in D_id_[i]]
                for i in range(D_id_.shape[0])]
        actions = current_queries
        metric_idx = self.search.metrics_map[self.reward.upper()]
        reward = metrics[metric_idx]
        if (len(actions) == 0):  # or self.counsteps > 10):
            done = True
        # return [text, actions], reward, done, {}  # text: candidates returned by search, actions: previous query. Combined they provide the state.
        # return text, reward, done, {}
        state = [
            utils.text2idx2(t, self.vocab, dim=self.search.max_words_input)[0]
            for t in text
        ]
        reward = 1.0
        return state, done, reward
Example #7
    def add_doc(self, doc_id, title, txt, add_terms):

        doc = Document()
        txt = utils.clean(txt)

        if add_terms:
            if prm.top_tfidf > 0:
                words_idx = []
                words, _ = utils.top_tfidf(txt.lower(), self.idf,
                                           prm.top_tfidf, prm.min_term_freq)

                if len(words) == 0:
                    words.append('unk')

                for w in words:
                    if w in self.vocab:
                        words_idx.append(self.vocab[w])
                    else:
                        words_idx.append(-1)  # unknown words.

            else:
                txt_ = txt.lower()
                words_idx, words = utils.text2idx2([txt_], self.vocab,
                                                   prm.max_terms_per_doc)
                words_idx = words_idx[0]
                words = words[0]

        doc.add(Field("id", str(doc_id), self.t1))
        doc.add(Field("title", title, self.t1))
        doc.add(Field("text", txt, self.t2))
        if add_terms:
            doc.add(Field("word_idx", ' '.join(map(str, words_idx)), self.t3))
            doc.add(Field("word", '<&>'.join(words), self.t3))
        self.writer.addDocument(doc)
Example #8
    def execute(self, actions):
        """
        Executes action, observes next state(s) and reward.

        Args:
            actions: Actions to execute.

        Returns:
            (Dict of) next state(s), boolean indicating terminal, and reward signal.
        """
        done = False
        #reformulated_query, current_queries, D_gt_id = actions

        self.counsteps += 1
        print "************execute query reformulation iteration:", self.counsteps
        #print 'current states........', self.state
        #query_text = utils.idx2text2(self.state, self.vocabinv)

        n, m = np.shape(actions)
        query_index = -2 * np.ones((n, m), dtype='int')
        for i in range(len(actions)):
            actionids = actions[i]
            for j in range(len(actionids)):
                query_index[i, j] = self.state[i, actionids[j]]

        print 'current_actions = ', np.shape(query_index), query_index
        '''action_status=utils.is_emptyaction(query_index)
        if sum(action_status)<self.batch_size:
            self.state=self.reset()
            query_index=self.action
            
        print 'current_actions = ', np.shape(query_index), query_index'''
        query_text = utils.idx2text2(query_index, self.vocabinv)
        print 'current actions = ', query_text
        D_gt_id = self.D_gt_id
        metrics, D_i_, D_id_, D_gt_m_ = self.search.perform(
            query_index, D_gt_id, self.is_train, query_text)

        #self.D_gt_id=D_id_

        #print 'current_queries (after calling search.perform) = ',query_text
        i = 0
        expanded_query_text = query_text
        metric_idx = self.search.metrics_map[self.rewardtype.upper()]
        reward = metrics[:, metric_idx]

        if self.counsteps > self.cfg['reformulation']['max_steps']:
            done = True
        print 'expanded query text,:', expanded_query_text
        expanded_query_text = self.process_expanded_words(expanded_query_text)
        #print ' process_expanded_words:', expanded_query_text
        expanded_i, expandes_lst_ = utils.text2idx2(
            expanded_query_text, self.vocab, self.cfg['search']['max_terms'])
        #self.state= expanded_i
        #print 'execution__new state id after expansion',expanded_i
        terminal = done
        return self.state, terminal, sum(reward)
Example #9
    def add_doc(self, doc_id, title, txt, add_terms):

        doc = Document()
        txt = utils.clean(txt)

        if add_terms:
            txt_ = txt.lower()
            words_idx, words = utils.text2idx2([txt_], self.vocab,
                                               prm.max_terms_per_doc)
            words_idx = words_idx[0]
            words = words[0]

        doc.add(Field("id", str(doc_id), self.t1))
        doc.add(Field("title", title, self.t1))
        doc.add(Field("text", txt, self.t2))
        if add_terms:
            doc.add(Field("word_idx", ' '.join(map(str, words_idx)), self.t3))
            doc.add(Field("word", '<&>'.join(words), self.t3))
        self.writer.addDocument(doc)
Example #10
def get_samples(input_queries, target_docs, index, options):
    qi = [utils.clean(input_queries[t].lower()) for t in index]
    D_gt_title = [target_docs[t] for t in index]

    D_gt_id_lst = []
    for j, t in enumerate(index):
        D_gt_id_lst.append([])
        for title in D_gt_title[j]:
            if title in options['engine'].title_id_map:
                D_gt_id_lst[-1].append(options['engine'].title_id_map[title])
            else:
                print 'ground-truth doc not in index:', title

    D_gt_id = lst2matrix(D_gt_id_lst)
    
    qi_i, qi_lst_ = utils.text2idx2(qi, options['vocab'], prm.max_words_input)
    
    qi_lst = []
    for qii_lst in qi_lst_:
        # append empty strings, so the list size becomes <dim>.
        qi_lst.append(qii_lst + max(0, prm.max_words_input - len(qii_lst)) * [''])

    return qi, qi_i, qi_lst, D_gt_id, D_gt_title
Example #11
def compute_idx(pages_path_in, pages_path_out, vocab):

    f = h5py.File(pages_path_in, 'r')

    if prm.att_doc and prm.att_segment_type == 'sentence':
        nltk.download('punkt')
        tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

    os.remove(pages_path_out) if os.path.exists(pages_path_out) else None

    # Save to HDF5
    fout = h5py.File(pages_path_out, 'a')

    if prm.att_doc:
        shape = (f['text'].shape[0], prm.max_segs_doc, prm.max_words)
    else:
        shape = (f['text'].shape[0], prm.max_words)

    idxs = fout.create_dataset('idx', shape=shape, dtype=np.int32)
    mask = fout.create_dataset('mask',
                               shape=(f['text'].shape[0], ),
                               dtype=np.float32)

    i = 0
    for text in f['text']:
        st = time.time()

        if prm.att_doc:
            if prm.att_segment_type.lower(
            ) == 'section' or prm.att_segment_type.lower() == 'subsection':
                segs = ['']
                for line in text.split('\n'):
                    if prm.att_segment_type == 'section':
                        line = line.replace('===', '')
                    if line.strip().startswith('==') and line.strip().endswith(
                            '=='):
                        segs.append('')
                    segs[-1] += line.lower() + '\n'
            elif prm.att_segment_type.lower() == 'sentence':
                segs = tokenizer.tokenize(text.lower().decode(
                    'ascii', 'ignore'))
            elif prm.att_segment_type.lower() == 'word':
                segs = wordpunct_tokenize(text.decode('ascii', 'ignore'))
            else:
                raise ValueError(
                    'Not a valid value for the attention segment type (att_segment_type) parameter. Valid options are "section", "subsection", "sentence", or "word".'
                )

            segs = segs[:prm.max_segs_doc]
            idxs_, _ = utils.text2idx2(segs, vocab, prm.max_words)
            idxs[i, :len(idxs_), :] = idxs_
            mask[i] = len(idxs_)
        else:
            idx, _ = utils.text2idx2([text.lower()], vocab, prm.max_words)
            idxs[i, :] = idx[0]
        i += 1

        #if i > 3000:
        #    break

        print 'processing article', i, 'time', time.time() - st

    f.close()
    fout.close()
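
The 'idx' and 'mask' datasets written by compute_idx can be read back with h5py. A short sketch, with a hypothetical file name; the shapes follow the create_dataset calls above:

import h5py
import numpy as np

# Hypothetical output path: whatever was passed as pages_path_out to compute_idx.
with h5py.File('pages_idx.hdf5', 'r') as f:
    idxs = f['idx']       # (n_docs, prm.max_segs_doc, prm.max_words) if prm.att_doc, else (n_docs, prm.max_words)
    mask = f['mask'][:]   # in att_doc mode: number of segments actually filled per document
    first_doc = np.array(idxs[0])   # load one document's word indices into memory
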
Example #12
File: run.py  Project: jxwuyi/WebNav
def train_lstm():

    optimizer = adam  # only adam is supported for now.
    options = locals().copy()
    with open(prm.outpath, "a") as fout:
        fout.write("parameters:" + str(options) + str(prm.__dict__))

    print "loading dictionary..."
    vocab = utils.load_vocab(prm.vocab_path, prm.n_words)
    options['vocab'] = vocab

    options['vocabinv'] = {}
    for k,v in vocab.items():
        options['vocabinv'][v] = k

    print 'Loading data...'
    options['wiki'] = wiki.Wiki(prm.pages_path)
    options['wikiemb'] = wiki_emb.WikiEmb(prm.pages_emb_path)

    #load Q&A Wiki dataset
    qpp = qp.QP(prm.qp_path)
    q_train, q_valid, q_test = qpp.get_queries()
    a_train, a_valid, a_test = qpp.get_paths()

    print 'Building model'
    # This creates the initial parameters as np ndarrays.
    # Dict name (string) -> np ndarray
    params, exclude_params = init_params()

    if prm.wordemb_path:
        print 'loading pre-trained weights for word embeddings'
        params = load_wemb(params, vocab)
        options['W'] = params['W']

    if prm.reload_model:
        load_params(prm.reload_model, params)

    params_next = OrderedDict()
    if prm.learning.lower() == 'q_learning' and prm.update_freq > 0:
        # copy params to params_next
        for kk, kv in params.items():
            params_next[kk] = kv.copy()

    # This creates Theano Shared Variables from the parameters.
    # Dict name (string) -> Theano Tensor Shared Variable
    # params and tparams hold different copies of the weights.
    tparams = init_tparams(params)

    if prm.update_freq > 0:
        tparams_next = init_tparams(params_next)
    else:
        tparams_next = None
  
    if prm.learning.lower() == 'reinforce':
        R_mean = theano.shared(0.71*np.ones((1,)), name='R_mean')
        R_std = theano.shared(np.ones((1,)), name='R_std')
        baseline_vars = {'R_mean': R_mean, 'R_std': R_std}
    else:
        baseline_vars = {}

    iin, out, updates, is_train, sup, max_hops, k_beam, mixer, f_pred, consider_constant \
            = build_model(tparams, tparams_next, baseline_vars, options)

    #get only parameters that are not in the exclude_params list
    tparams_ = OrderedDict([(kk, vv) for kk, vv in tparams.iteritems() if kk not in exclude_params])

    grads = tensor.grad(out[0], wrt=itemlist(tparams_), consider_constant=consider_constant)

    lr = tensor.scalar(name='lr')
    f_grad_shared, f_update = optimizer(lr, tparams_, grads, iin, out, updates)

    print 'Optimization'

    if prm.train_size == -1:
        train_size = len(q_train)
    else:
        train_size = prm.train_size

    if prm.valid_size == -1:
        valid_size = len(q_valid)
    else:
        valid_size = prm.valid_size

    if prm.test_size == -1:
        test_size = len(q_test)
    else:
        test_size = prm.test_size

    with open(prm.outpath, "a") as fout:
        fout.write("\n%d train examples" % len(q_train)) 
    with open(prm.outpath, "a") as fout:
        fout.write("\n%d valid examples" % len(q_valid)) 
    with open(prm.outpath, "a") as fout:
        fout.write("\n%d test examples" % len(q_test))

    history_errs = []
    best_p = None

    if prm.validFreq == -1:
        validFreq = len(q_train) / prm.batch_size_train
    else:
        validFreq = prm.validFreq

    if prm.saveFreq == -1:
        saveFreq = len(q_train) / prm.batch_size_train
    else:
        saveFreq = prm.saveFreq

    uidx = 0  # the number of updates done
    estop = False  # early stop
    start_time = time.time()
    
    experience = deque(maxlen=prm.replay_mem_size) # experience replay memory as circular buffer.
    experience_r = deque(maxlen=prm.replay_mem_size) # reward of each entry in the replay memory.

    try:
        for eidx in xrange(prm.max_epochs):
            n_samples = 0

            # Get new shuffled index for the training set.
            kf = get_minibatches_idx(len(q_train), prm.batch_size_train, shuffle=True)

            for _, train_index in kf:
                st = time.time()

                uidx += 1
                is_train.set_value(1.)
                max_hops.set_value(prm.max_hops_train) # select training dataset
                k_beam.set_value(1) # Training does not use beam search
                
                # Select the random examples for this minibatch
                queries = [q_train[t].lower() for t in train_index]
                actions = [a_train[t] for t in train_index]
                
                if prm.learning.lower() == 'supervised':
                    sup.set_value(1.) # select supervised mode
                else:
                    sup.set_value(0.)

                # Get correct actions (supervision signal)
                acts_p =  get_acts(actions, prm.max_hops_train, k_beam=1)

                # MIXER
                if prm.mixer > 0 and prm.learning.lower() == 'reinforce':
                    mixer.set_value(max(0, prm.max_hops_train - uidx // prm.mixer))
                else:
                    if prm.learning.lower() == 'supervised':
                        mixer.set_value(prm.max_hops_train+1)
                    else:
                        mixer.set_value(0)

                root_pages = get_root_pages(actions)                
                
                # Get the BoW for the queries.
                q_i, q_m = utils.text2idx2(queries, vocab, prm.max_words_query*prm.n_consec)
                n_samples += len(queries)
                
                if uidx > 1 and prm.learning.lower() == 'q_learning':
                    # Randomly select experiences and convert them to numpy arrays.
                    idxs = np.random.choice(np.arange(len(experience)), size=len(queries))
                    rvs = []
                    for j in range(len(experience[idxs[0]])):
                        rv = []
                        for idx in idxs:
                            rv.append(experience[idx][j])

                        rvs.append(np.asarray(rv))
                else:
                    rvs = [np.zeros((len(queries),prm.max_words_query*prm.n_consec),dtype=np.float32), # rs_q
                           np.zeros((len(queries),prm.max_words_query*prm.n_consec),dtype=np.float32), # rs_q_m
                           np.zeros((len(queries),prm.max_hops_train+1),dtype=np.int32), # rl_idx
                           np.zeros((len(queries),prm.max_hops_train+1),dtype=np.float32), # rt
                           np.zeros((len(queries),prm.max_hops_train+1),dtype=np.float32) # rr
                          ]

                cost, R, l_idx, pages_idx, best_page_idx, best_answer, mask, dist \
                        = f_grad_shared(q_i, q_m, root_pages, acts_p, uidx, *rvs)
                f_update(prm.lrate)

                if prm.learning.lower() == 'q_learning': 
                    # update weights of the next_q_val network.
                    if (prm.update_freq > 0 and uidx % prm.update_freq == 0) or (uidx == prm.replay_start):
                        for tk, tv in tparams.items():
                            if tk in tparams_next:
                                tparams_next[tk].set_value(tv.get_value().copy())

                # Only update memory after freeze_mem or before replay_start.
                if (uidx < prm.replay_start or uidx > prm.freeze_mem) and prm.learning.lower() == 'q_learning':
                    # Update Replay Memory.
                    t = np.zeros((len(queries), prm.max_hops_train+1))
                    rR = np.zeros((len(queries), prm.max_hops_train+1))

                    for i in range(len(queries)):
                        j = np.minimum(mask[i].sum(), prm.max_hops_train)
                        # If the agent chooses to stop or the episode ends,
                        # the reward will be the reward obtained with the chosen document.
                        rR[i,j] = R[i]
                        t[i,j] = 1.
                        
                        add = True
                        if prm.selective_mem >= 0 and uidx > 1:
                            # Selective memory: keep the percentage of memories
                            # with reward=1 approximately equal to <selective_mem>.
                            pr = float(np.asarray(experience_r).sum()) / max(1., float(len(experience_r)))
                            if (pr < prm.selective_mem) ^ (rR[i,j] == 1.): # xor
                                add = False

                        if add:
                            experience.append([q_i[i], q_m[i], l_idx[i], t[i], rR[i]])
                            experience_r.append(rR[i])

                if np.isnan(cost) or np.isinf(cost):
                    print 'NaN detected'
                    return 1., 1., 1.
    
                #if uidx % 100 == 0:
                #    vis_att(pages_idx[:,-1], queries[-1], alpha[:,-1,:], uidx, options)

                if np.mod(uidx, prm.dispFreq) == 0:
                    with open(prm.outpath, "a") as fout:
                        fout.write("\n\nQuery: " + queries[-1].replace("\n"," "))
                        fout.write('\nBest Answer: ' + utils.idx2text(best_answer[-1], options['vocabinv']))
                        fout.write('\nBest page: ' + options['wiki'].get_article_title(best_page_idx[-1]))

                        for i, page_idx in enumerate(pages_idx[:,-1]):
                            fout.write('\niteration: ' +str(i) + " page idx " + str(page_idx) + ' title: ' + options['wiki'].get_article_title(page_idx))
                       
                        fout.write('\nEpoch '+ str(eidx) + ' Update '+ str(uidx) + ' Cost ' + str(cost) + \
                                   ' Reward Mean ' + str(R.mean()) + ' Reward Max ' + str(R.max()) + \
                                   ' Reward Min ' + str(R.min()) + \
                                   ' Q-Value Max (avg per sample) ' + str(dist.max(2).mean()) + \
                                   ' Q-Value Mean ' + str(dist.mean()))
                        #fout.write("\nCost Supervised: " + str(cost_sup))
                        #fout.write("\nCost RL: " + str(cost_RL))

                        fout.write("\nTime per Minibatch Update: " + str(time.time() - st))
                       

                if prm.saveto and np.mod(uidx, saveFreq) == 0:
                    print 'Saving...',

                    if best_p is not None:
                        params = best_p
                    else:
                        params = unzip(tparams)
                    np.savez(prm.saveto, history_errs=history_errs, **params)
                    pkl.dump(options, open('%s.pkl' % prm.saveto, 'wb'), -1)
                    print 'Done'

                if np.mod(uidx, validFreq) == 0 or uidx == 1:
                    if prm.visited_pages_path:
                        shuffle = False
                    else:
                        shuffle = True
                    kf_train = get_minibatches_idx(len(q_train), prm.batch_size_pred, shuffle=shuffle, max_samples=train_size)
                    kf_valid = get_minibatches_idx(len(q_valid), prm.batch_size_pred, shuffle=shuffle, max_samples=valid_size)
                    kf_test = get_minibatches_idx(len(q_test), prm.batch_size_pred, shuffle=shuffle, max_samples=test_size)

                    is_train.set_value(0.)
                    sup.set_value(0.) # supervised mode off
                    mixer.set_value(0) # no supervision
                    max_hops.set_value(prm.max_hops_pred)
                    k_beam.set_value(prm.k)

                    with open(prm.outpath, 'a') as fout:
                        fout.write('\n\nComputing Error Training Set')
                    train_err, train_R, train_accp, visited_pages_train = pred_error(f_pred, q_train, a_train, options, kf_train)

                    with open(prm.outpath, 'a') as fout:
                        fout.write('\n\nComputing Error Validation Set')
                    valid_err, valid_R, valid_accp, visited_pages_valid = pred_error(f_pred, q_valid, a_valid, options, kf_valid)

                    with open(prm.outpath, 'a') as fout:
                        fout.write('\n\nComputing Error Test Set')
                    test_err, test_R, test_accp, visited_pages_test = pred_error(f_pred, q_test, a_test, options, kf_test)

                    if prm.visited_pages_path:
                        pkl.dump([visited_pages_train, visited_pages_valid, visited_pages_test], open(prm.visited_pages_path, 'wb'))

                    history_errs.append([valid_err[-1], test_err[-1]])

                    if (uidx == 0 or
                        valid_err[-1] <= np.array(history_errs)[:,0].min()):

                        best_p = unzip(tparams)
                        bad_counter = 0

                    with open(prm.outpath, "a") as fout:
                        fout.write('\n[{per hop}, Avg] Train err ' + str(train_err) + '  Valid err ' + str(valid_err) + '  Test err ' + str(test_err))
                        fout.write('\n[{per hop}, Avg] Train R ' + str(train_R) + '  Valid R ' + str(valid_R) + '  Test R ' + str(test_R))
                        fout.write('\nAccuracy Page Actions   Train ' + str(train_accp) + '  Valid ' + str(valid_accp) + '  Test ' + str(test_accp))

                    if (len(history_errs) > prm.patience and
                        valid_err[-1] >= np.array(history_errs)[:-prm.patience,
                                                               0].min()):
                        bad_counter += 1
                        if bad_counter > prm.patience:
                            print 'Early Stop!'
                            estop = True
                            break

            with open(prm.outpath, "a") as fout:
                fout.write('\nSeen %d samples' % n_samples)

            if estop:
                break

    except KeyboardInterrupt:
        print "Training interupted"

    end_time = time.time()
    if best_p is not None:
        zipp(best_p, tparams)
    else:
        best_p = unzip(tparams)

    is_train.set_value(0.)
    sup.set_value(0.) # supervised mode off
    mixer.set_value(0) # no supervision
    max_hops.set_value(prm.max_hops_pred)
    k_beam.set_value(prm.k)

    kf_train_sorted = get_minibatches_idx(len(q_train), prm.batch_size_train)

    train_err, train_R, train_accp, visited_pages_train = pred_error(f_pred, q_train, a_train, options, kf_train_sorted)
    valid_err, valid_R, valid_accp, visited_pages_valid = pred_error(f_pred, q_valid, a_valid, options, kf_valid)
    test_err, test_R, test_accp, visited_pages_test = pred_error(f_pred, q_test, a_test, options, kf_test)

    with open(prm.outpath, "a") as fout:
        fout.write('\n[{per hop}, Avg] Train err ' + str(train_err) + '  Valid err ' + str(valid_err) + '  Test err ' + str(test_err))
        fout.write('\n[{per hop}, Avg] Train R ' + str(train_R) + '  Valid R ' + str(valid_R) + '  Test R ' + str(test_R))
        fout.write('\nAccuracy Page Actions   Train ' + str(train_accp) + '  Valid ' + str(valid_accp) + '  Test ' + str(test_accp))

    if prm.saveto:
        np.savez(prm.saveto, train_err=train_err,
                    valid_err=valid_err, test_err=test_err,
                    history_errs=history_errs, **best_p)
    with open(prm.outpath, "a") as fout:
        fout.write('\nThe code run for %d epochs, with %f sec/epochs' % ((eidx + 1), (end_time - start_time) / (1. * (eidx + 1))))
    with open(prm.outpath, "a") as fout:
        fout.write('\nTraining took %.1fs' % (end_time - start_time))
    return train_err, valid_err, test_err
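
train_lstm above keeps its experience replay memory as a bounded deque and, on each update, samples random entries and stacks each field into a numpy array (the rvs construction). A self-contained sketch of that pattern; the buffer size and entry layout are placeholders mirroring prm.replay_mem_size and the [q_i, q_m, l_idx, t, rR] entries used above:

import numpy as np
from collections import deque

replay_mem_size = 10000                      # assumed value standing in for prm.replay_mem_size
experience = deque(maxlen=replay_mem_size)   # circular buffer: oldest entries are dropped automatically

def remember(entry):
    # entry: a list of per-sample arrays, e.g. [q_i, q_m, l_idx, t, rR].
    experience.append(entry)

def sample_batch(batch_size):
    # Randomly pick stored experiences and stack each field across the batch.
    idxs = np.random.choice(np.arange(len(experience)), size=batch_size)
    rvs = []
    for j in range(len(experience[idxs[0]])):
        rvs.append(np.asarray([experience[idx][j] for idx in idxs]))
    return rvs
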
Example #13
File: run.py  Project: jxwuyi/WebNav
def pred_error(f_pred, queries, actions, options, iterator, verbose=False):
    """
    Just compute the error.
    f_pred: Theano function computing the prediction
    """

    valid_acc = np.zeros((prm.max_hops_train + 2), dtype=np.float32)
    valid_R = np.zeros((prm.max_hops_train + 2), dtype=np.float32)
    n = np.zeros((prm.max_hops_train + 2), dtype=np.float32)
    acts_pc = 0.
    acts_pt = 0.
    uidx = -1
    visited_pages = []            

    for _, valid_index in iterator:
        q_i, q_m = utils.text2idx2([queries[t].lower() for t in valid_index], options['vocab'], prm.max_words_query*prm.n_consec)
        acts = [actions[t] for t in valid_index]

        #fake acts that won't be used in the prediction
        acts_p = -np.ones((prm.max_hops_pred+1, len(q_i) * prm.k), dtype=np.float32)
        
        root_pages = get_root_pages(acts)

        best_answer, best_page_idx, R, pages_idx = f_pred(q_i, q_m, root_pages, acts_p, uidx)

        pages_idx_ = np.swapaxes(pages_idx,0,1)
        pages_idx_ = pages_idx_.reshape(pages_idx_.shape[0],-1)

        #get pages visited:
        for page_idx in pages_idx_:
            visited_pages.append([])
            for idx in page_idx:
                if idx != -1:
                    visited_pages[-1].append(idx)

        R_binary = np.ones_like(R)
        R_binary[R<1.0] = 0.0
        n[-1] += len(valid_index)
        valid_R[-1] += R.sum()
        valid_acc[-1] += R_binary.sum()
        
        # get correct page-actions.
        acts_p = get_acts(acts, prm.max_hops_pred, prm.k)

        pages_idx = pages_idx.reshape((pages_idx.shape[0],-1))

        # Check how many page actions the model got right.
        mask_pc = np.logical_or((pages_idx != -1.0), (acts_p != -1.0)).astype('float32')
        acts_pc += ((pages_idx == acts_p).astype('float32') * mask_pc).sum()
        acts_pt += mask_pc.sum() #total number of actions

        # compute accuracy per hop
        for i in range(prm.max_hops_train+1):
            n_hops = (acts_p != -1.0).astype('float32').sum(0)
            n_hops = n_hops.reshape((-1, prm.k))[:,0]  # beam search uses only the first n_samples actions
            ih = (n_hops==i)
            valid_R[i] += R[ih].sum()
            valid_acc[i] += R_binary[ih].sum()
            n[i] += ih.astype('float32').sum()

        with open(prm.outpath, 'a') as fout:
            fout.write("\n\nQuery: " + queries[valid_index[-1]].replace("\n"," "))
            nh = (acts_p[:,-1] != -1.0).astype('int32').sum()
            if nh == 0:
                fout.write('\nCorrect Path: ' + options['wiki'].get_article_title(int(root_pages[-1])))
            else:
                path = ''
                for a in acts_p[:nh, -1]:
                    path += ' -> ' + options['wiki'].get_article_title(int(a))
                fout.write('\nCorrect Path: ' + path)

            fout.write('\nNumber of hops: ' + str(int(nh)))
            fout.write('\nBest answer: ' + utils.idx2text(best_answer[-1], options['vocabinv']))
            fout.write('\nBest page: ' + options['wiki'].get_article_title(best_page_idx[-1]))
            for i, pageidx in enumerate(pages_idx[:,-1]):
                fout.write('\niteration: ' +str(i) + " page idx " + str(pageidx) + ' title '+ options['wiki'].get_article_title(pageidx))

        uidx -= 1
        
    valid_R = valid_R / n
    valid_err = 1 - valid_acc / n
    acts_pc = acts_pc / acts_pt

    return valid_err, valid_R, acts_pc, visited_pages