Ejemplo n.º 1
0
 def train(self, num_epochs=2):
     """Run the training loop, periodically evaluating on the validation set.

     Initializes all global TensorFlow variables, then for each epoch
     shuffles the training indices (seeded with the epoch number), splits
     them into batches with ``get_batch_index`` and runs one optimizer step
     per batch.  Every ``conf.eval_every_num_update`` updates the model is
     evaluated on ``self.X_valid`` and a progress line is printed.  After
     the last epoch the session is saved via ``self.save_session()``.

     Args:
         num_epochs: Number of passes over the training data (default 2).
     """
     start_time = time.time()
     print('Training and evaluating...')
     self.sess.run(tf.global_variables_initializer())
     num_examples = self.X_train["feature_emb"].shape[0]
     train_idx_shuffle = np.arange(num_examples)
     total_batch = 0  # number of optimizer updates performed so far
     for epoch in range(num_epochs):
         # Seed with the epoch number so each epoch's shuffle is reproducible.
         # Note this resets NumPy's *global* random state.
         np.random.seed(epoch)
         np.random.shuffle(train_idx_shuffle)
         for batch_idx in get_batch_index(train_idx_shuffle, conf.batch_size):
             feed_dict = self._get_feed_dict(self.X_train, batch_idx, training=True)
             # train_op is fetched only to trigger the update; its value is unused.
             loss, _ = self.sess.run((self.loss, self.train_op), feed_dict=feed_dict)
             total_batch += 1
             if total_batch % conf.eval_every_num_update == 0:
                 # evaluate() returns (loss, err, ndcg, ndcg_all) means; only
                 # the loss and NDCG are reported here.
                 loss_mean_valid, _, ndcg_mean_valid, _ = self.evaluate(self.X_valid)
                 print("[epoch-{}, batch-{}] -- Train Loss: {:5f} -- Valid Loss: {:5f} NDCG: {:5f} -- {:5f} s".format(
                         epoch + 1, total_batch, loss, loss_mean_valid, ndcg_mean_valid, time.time() - start_time))
     self.save_session()
Ejemplo n.º 2
0
 def get_batch_data(self, aspects, contexts, labels, context_lens, cost_ws, feature_eng, aspects_ch, contexts_ch,
                    aspects_ch_lens, contexts_ch_lens, batch_size, is_shuffle, keep_prob):
     """Yield ``(feed_dict, batch_len)`` pairs, one per batch, with a progress bar.

     Batch indices come from ``get_batch_index(len(aspects), batch_size,
     is_shuffle)``; every input array is indexed with the same batch index.
     Character-level inputs are fed only when ``self.use_char_emb`` is set,
     and labels only when ``labels`` is non-empty.

     When ``keep_prob < 1`` and ``self.random_del_prob > 0`` (treated as
     training mode), ``int(random_del_prob * max_context_len)`` randomly
     chosen positions per row of the context batch are set to 0 (and the
     matching character-length entries when char embeddings are used).

     Args:
         aspects, contexts, labels, context_lens, cost_ws, feature_eng,
         aspects_ch, contexts_ch, aspects_ch_lens, contexts_ch_lens:
             Per-example inputs, indexed by the batch index.
             NOTE(review): presumably NumPy arrays sharing the same first
             dimension — confirm against the caller.
         batch_size: Number of examples per batch.
         is_shuffle: Passed through to ``get_batch_index``.
         keep_prob: Value fed to ``self.dropout_keep_prob``.

     Yields:
         Tuple of the feed dict for one batch and ``len(index)``.
     """
     # Ceiling division: number of batches expected by the progress bar.
     total = (len(context_lens) + batch_size - 1) // batch_size
     for index in tqdm_notebook(get_batch_index(len(aspects), batch_size, is_shuffle), total=total):
         batch_len = len(index)
         feed_dict = {
             self.aspects: aspects[index],
             self.contexts: contexts[index],
             self.context_lens: context_lens[index],
             self.feature_eng: feature_eng[index],
             self.cost_ws: cost_ws[index],
             self.dropout_keep_prob: keep_prob
         }
         if self.use_char_emb:
             feed_dict[self.aspects_ch] = aspects_ch[index]
             feed_dict[self.contexts_ch] = contexts_ch[index]
             feed_dict[self.aspects_ch_lens] = aspects_ch_lens[index].reshape(-1)
             # Kept 2-D for now so it can be masked below; flattened before yield.
             feed_dict[self.contexts_ch_lens] = contexts_ch_lens[index]
         if len(labels) > 0:
             feed_dict[self.labels] = labels[index]
         if keep_prob < 1 and self.random_del_prob > 0:  # train mode
             cnt_per_row = int(self.random_del_prob * self.max_context_len)
             # Row r appears cnt_per_row times, paired with random columns.
             zero_rows = np.repeat(np.arange(batch_len), cnt_per_row)
             zero_cols = np.random.choice(self.max_context_len, batch_len * cnt_per_row)
             # NOTE(review): assumes the indexed batches are fresh copies
             # (NumPy advanced indexing), so zeroing them in place does not
             # modify the caller's arrays — confirm.
             if self.use_char_emb:
                 feed_dict[self.contexts_ch_lens][zero_rows, zero_cols] = 0
             # The replacement value does not have to be 0; other values could be tried.
             feed_dict[self.contexts][zero_rows, zero_cols] = 0
         if self.use_char_emb:
             feed_dict[self.contexts_ch_lens] = feed_dict[self.contexts_ch_lens].reshape(-1)
         yield feed_dict, batch_len
Ejemplo n.º 3
0
 def get_batch_data(self, aspects, contexts, labels, aspect_lens, context_lens, batch_size, is_shuffle, keep_prob):
     """Yield ``(feed_dict, batch_len)`` for each batch of the given data.

     Batch indices are produced by ``get_batch_index(len(aspects),
     batch_size, is_shuffle)``; each input is indexed with the batch index
     and mapped to the corresponding attribute on ``self``.  ``keep_prob``
     is fed unchanged to ``self.dropout_keep_prob``.
     """
     for batch_idx in get_batch_index(len(aspects), batch_size, is_shuffle):
         # Pair every feed key with its full input, then slice all at once.
         per_example = (
             (self.aspects, aspects),
             (self.contexts, contexts),
             (self.labels, labels),
             (self.aspect_lens, aspect_lens),
             (self.context_lens, context_lens),
         )
         feed = {key: values[batch_idx] for key, values in per_example}
         feed[self.dropout_keep_prob] = keep_prob
         yield feed, len(batch_idx)
Ejemplo n.º 4
0
 def get_batch_data(self, sentences, aspects, sentence_lens, sentence_locs, labels, batch_size, is_shuffle, keep_prob):
     """Yield ``(feed_dict, batch_len)`` for each batch of the given data.

     Batch indices are produced by ``get_batch_index(len(sentences),
     batch_size, is_shuffle)``; each input is indexed with the batch index
     and mapped to the corresponding attribute on ``self``.  ``keep_prob``
     is fed unchanged to ``self.dropout_keep_prob``.
     """
     for batch_idx in get_batch_index(len(sentences), batch_size, is_shuffle):
         # Pair every feed key with its full input, then slice all at once.
         per_example = (
             (self.sentences, sentences),
             (self.aspects, aspects),
             (self.sentence_lens, sentence_lens),
             (self.sentence_locs, sentence_locs),
             (self.labels, labels),
         )
         feed = {key: values[batch_idx] for key, values in per_example}
         feed[self.dropout_keep_prob] = keep_prob
         yield feed, len(batch_idx)
Ejemplo n.º 5
0
 def get_batch_data(self, sentences, aspects, sentence_lens, sentence_locs,
                    labels, batch_size, is_shuffle, keep_prob):
     """Generate one ``(feed_dict, batch_len)`` tuple per batch.

     ``get_batch_index(len(sentences), batch_size, is_shuffle)`` supplies
     the index for each batch.  Sentences, aspects, sentence lengths,
     sentence locations and labels are all indexed with it, and
     ``keep_prob`` is passed through as the dropout keep probability.
     """
     batch_indices = get_batch_index(len(sentences), batch_size, is_shuffle)
     for selection in batch_indices:
         batch_feed = dict([
             (self.sentences, sentences[selection]),
             (self.aspects, aspects[selection]),
             (self.sentence_lens, sentence_lens[selection]),
             (self.sentence_locs, sentence_locs[selection]),
             (self.labels, labels[selection]),
             (self.dropout_keep_prob, keep_prob),
         ])
         batch_len = len(selection)
         yield batch_feed, batch_len
Ejemplo n.º 6
0
 def get_batch(aspects, contexts, labels, aspect_lens, context_lens, aspect_lex, context_lex,
               batch_size, is_shuffle, keep_prob):
     """Yield ``(feed_dict, batch_len)`` for each batch of the given data.

     Every per-example input is first converted with ``np.array``.  Batch
     indices then come from ``get_batch_index(len(aspects), batch_size,
     is_shuffle)``, and each converted array is indexed with the batch index
     and keyed by the matching module-level ``input_*`` name.  ``keep_prob``
     is fed unchanged under ``dropout_keep_prob``.
     """
     # Convert up front, in the same order as the assignments they replace.
     (aspects, contexts, labels, aspect_lens,
      context_lens, context_lex, aspect_lex) = [
         np.array(column)
         for column in (aspects, contexts, labels, aspect_lens,
                        context_lens, context_lex, aspect_lex)
     ]
     for batch_idx in get_batch_index(len(aspects), batch_size, is_shuffle):
         per_example = (
             (input_aspects, aspects),
             (input_contexts, contexts),
             (input_labels, labels),
             (input_aspect_lens, aspect_lens),
             (input_context_lens, context_lens),
             (input_context_lex, context_lex),
             (input_aspect_lex, aspect_lex),
         )
         feed = {key: values[batch_idx] for key, values in per_example}
         feed[dropout_keep_prob] = keep_prob
         yield feed, len(batch_idx)
Ejemplo n.º 7
0
 def get_batch_data(self, aspects, contexts, labels, aspect_lens, context_lens, aspect_lex, context_lex, batch_size, is_shuffle,
                    keep_prob):
     """Generate one ``(feed_dict, batch_len)`` tuple per batch.

     All per-example inputs are converted with ``np.array`` before batching.
     ``get_batch_index(len(aspects), batch_size, is_shuffle)`` supplies the
     index for each batch; the lexicon inputs are fed to
     ``self.context_lex_embedding`` / ``self.aspect_lex_embedding`` and
     ``keep_prob`` to ``self.dropout_keep_prob``.
     """
     # Convert up front, in the same order as the assignments they replace.
     (aspects, contexts, labels, aspect_lens,
      context_lens, context_lex, aspect_lex) = [
         np.array(column)
         for column in (aspects, contexts, labels, aspect_lens,
                        context_lens, context_lex, aspect_lex)
     ]
     for selection in get_batch_index(len(aspects), batch_size, is_shuffle):
         batch_feed = dict([
             (self.aspects, aspects[selection]),
             (self.contexts, contexts[selection]),
             (self.labels, labels[selection]),
             (self.aspect_lens, aspect_lens[selection]),
             (self.context_lens, context_lens[selection]),
             (self.context_lex_embedding, context_lex[selection]),
             (self.aspect_lex_embedding, aspect_lex[selection]),
             (self.dropout_keep_prob, keep_prob),
         ])
         yield batch_feed, len(selection)
Ejemplo n.º 8
0
 def get_batch_data(self, aspects, contexts, labels, aspect_lens,
                    context_lens, batch_size, is_shuffle, keep_prob):
     """Yield ``(feed_dict, batch_len)`` pairs, one per batch, with a progress bar.

     Batch indices come from ``get_batch_index(len(aspects), batch_size,
     is_shuffle)``; every input array is indexed with the same batch index.
     ``self.labels`` is fed only when ``labels`` is non-empty, so the same
     generator can be used for unlabeled data.

     Args:
         aspects, contexts, labels, aspect_lens, context_lens:
             Per-example inputs, indexed by the batch index.
         batch_size: Number of examples per batch.
         is_shuffle: Passed through to ``get_batch_index``.
         keep_prob: Value fed to ``self.dropout_keep_prob``.

     Yields:
         Tuple of the feed dict for one batch and ``len(index)``.
     """
     # Ceiling division, so the progress-bar total is not one too high when
     # the data size is an exact multiple of batch_size.
     total = (len(context_lens) + batch_size - 1) // batch_size
     # len() rather than truthiness: labels may be a NumPy array.
     has_labels = len(labels) > 0
     for index in tqdm_notebook(
             get_batch_index(len(aspects), batch_size, is_shuffle),
             total=total):
         feed_dict = {
             self.aspects: aspects[index],
             self.contexts: contexts[index],
             self.aspect_lens: aspect_lens[index],
             self.context_lens: context_lens[index],
             self.dropout_keep_prob: keep_prob,
         }
         if has_labels:
             feed_dict[self.labels] = labels[index]
         yield feed_dict, len(index)