# Example #1
import os

import numpy as np
import tensorflow as tf

# NOTE: Vocab, Dataset, and plot_figure are project-local helpers; the module
# paths below are assumptions, adjust them to the actual repository layout.
from vocab import Vocab
from dataset import Dataset
from plot_utils import plot_figure
class emTrainer(object):
    '''Word-embedding trainer that aggregates context features with dynamic routing (TensorFlow 1.x).'''
    def __init__(self, configs):
        '''Store the config, then build the full graph and session.'''
        self.config = configs
        self.training = True
        self.build()

    def __str__(self):
        '''Short identifier used in logs.'''
        return 'emtrainer'

    def get_vocabs(self):
        '''Build the vocabulary from the training file.'''
        self.vocab = Vocab(self.config.train_file, self.config)

    def get_dataset(self):
        '''Create train/dev/test datasets over the shared vocabulary.'''
        self.train_dataset = Dataset(self.config.train_file, self.vocab,
                                     self.config)
        self.dev_dataset = Dataset(self.config.dev_file, self.vocab,
                                   self.config)
        self.test_dataset = Dataset(self.config.test_file, self.vocab,
                                    self.config)

    def get_placeholder(self):
        '''Create the minibatch input placeholders.'''
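        # Placeholder shapes:
        #   window_features: [batch_size, 2*window_size] context-word ids
        #   true_target:     [batch_size]                positive target ids
        #   sample_targets:  [batch_size, sample_num]    negative-sample ids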
        self.window_features = tf.placeholder(dtype=tf.int32,
                                              shape=[
                                                  self.config.batch_size,
                                                  2 * self.config.window_size,
                                              ],
                                              name='Feature_Inputs')
        self.true_target = tf.placeholder(dtype=tf.int32,
                                          shape=[
                                              self.config.batch_size,
                                          ],
                                          name='True_Target')
        self.sample_targets = tf.placeholder(
            dtype=tf.int32,
            shape=[self.config.batch_size, self.config.sample_num],
            name='Sample_Targets')

    def get_embeddings(self):
        '''Look up embedding (EMB) and target-weight (WGT) vectors for the batch.'''
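        # Resulting shapes (assuming Vocab._looking_up returns one vector of
        # size vec_dim per id):
        #   self.inputs       -> [batch_size, 2*window_size, vec_dim]
        #   self.true_word    -> [batch_size, vec_dim]
        #   self.sample_words -> sample_num tensors of shape [batch_size, vec_dim]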
        self.inputs = self.vocab._looking_up(self.window_features, type='EMB')
        self.true_word = self.vocab._looking_up(self.true_target, type='WGT')
        self.sample_words = [
            self.vocab._looking_up(self.sample_targets[:, idx], type='WGT')
            for idx in range(self.config.sample_num)
        ]

    def get_init_features(self):
        '''Element-wise features between the context inputs and each target.'''
        self.true_word_feature = self.get_element_wise_feature(
            self.inputs, self.true_word)
        self.sample_words_features = [
            self.get_element_wise_feature(self.inputs, single_sample)
            for single_sample in self.sample_words
        ]

    def get_element_wise_feature(self, _inputs, _targets):
        '''Broadcast-multiply each context vector with a [batch, vec_dim] target.'''
        # _inputs is [batch, 2*window_size, vec_dim] and _targets is
        # [batch, vec_dim]; moving the window axis to the front lets
        # tf.multiply broadcast the target across the window positions,
        # after which the result is transposed back.
        with tf.variable_scope('element_wise_multiply'):
            _inputs = tf.transpose(_inputs, [1, 0, 2])
            element_wise_features = tf.multiply(_inputs, _targets)
            element_wise_features_ = tf.transpose(element_wise_features,
                                                  [1, 0, 2])
            return element_wise_features_

    def get_allocate_true_samples(self):
        '''Compute loss and correctness for the true target and the negative samples.'''
        with tf.variable_scope('true_labels', reuse=tf.AUTO_REUSE):
            tr_loss, tr_correct, tr_predicts = self.get_logits_loss(
                self.true_word_feature, self.true_word, label=1)

        with tf.variable_scope('sample_labels', reuse=tf.AUTO_REUSE):
            sa_loss, sa_correct, sa_predicts = [], [], []
            for single_sample, sample_tar in zip(self.sample_words_features,
                                                 self.sample_words):
                sloss, scorrect, spredict = self.get_logits_loss(single_sample,
                                                                 sample_tar,
                                                                 label=0)
                sa_loss.append(sloss)
                sa_correct.append(scorrect)
                sa_predicts.append(spredict)
            sa_loss = tf.reduce_mean(sa_loss)
            sa_correct = tf.reduce_mean(sa_correct)

        self.trloss = tr_loss
        self.saloss = sa_loss
        self.trcorrect = tr_correct
        self.sacorrect = sa_correct
        self.predict = tr_predicts
        self.loss = self.trloss + self.config.lamb * self.saloss
        return

    def get_logits_loss(self, init_features, targets, label=1):
        '''Route features, score them against the target, and build loss/accuracy.'''
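        # Pipeline: route the window features, sum-pool over the window axis,
        # score against the target (plain sum or dot product), squash with a
        # sigmoid, and threshold at config.margin to get hard predictions.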
        routed_features = self.get_compute_graph(init_features)
        if self.config.use_targeted_attention:
            pass  # attention hook, intentionally left unimplemented
        # routed_features has shape [batch_size, 2*window_size, vec_dim]
        final_features = tf.reduce_sum(
            routed_features, axis=1,
            keep_dims=True)  # final_features ==> [batch_size, 1, vec_dim]
        if self.config.use_reduce_sum:
            logits = tf.reduce_sum(final_features, axis=2, keep_dims=False)
            # logits ==> [batch_size, 1]
        elif self.config.use_matmul:
            # Dot product of the pooled feature with the target vector.
            Weight_tar = tf.reshape(
                targets, [self.config.batch_size, self.config.vec_dim, 1])
            logits = tf.matmul(final_features, Weight_tar)
            # logits ==> [batch_size, 1, 1]
        else:
            raise ValueError('set config.use_reduce_sum or config.use_matmul')
        logits_ = tf.sigmoid(logits)  # shape [batch_size, 1]
        logits_ = tf.reshape(logits_, [
            self.config.batch_size,
        ])  #shape ==> [batch_size, ]
        predict = tf.to_int32(tf.greater(logits_, self.config.margin))

        predict_ = predict if label == 1 else 1 - predict  # shape ==> [batch_size, ]
        correct = tf.reduce_sum(predict_)  # shape ==> [int]
        loss = 1 - logits_ if label == 1 else logits_
        loss = tf.reduce_mean(tf.square(loss))  # shape ==> [float]

        return loss, correct, predict

    def margin_loss(self, logits, label=1):
        '''Two-sided margin loss. Unused by build(); completed here as a
        sketch in the style of the capsule-network margin loss.'''
        max_l = tf.square(tf.maximum(0., self.config.m_plus - logits))
        max_r = tf.square(tf.maximum(0., logits - self.config.m_minus))
        assert max_l.get_shape() == [
            self.config.batch_size,
        ]
        # Positives are penalised for logits below m_plus, negatives for
        # logits above m_minus.
        loss = max_l if label == 1 else max_r
        return tf.reduce_mean(loss)

    def get_compute_graph(self, inputs):
        '''Initialise the routing logits to zero and run the routing process.'''
        # inputs: [batch_size, 2*window_size, vec_dim]; the 2*window_size axis
        # plays the role of the input-capsule axis.
        B_ij = tf.constant(
            np.zeros([self.config.batch_size, 2 * self.config.window_size, 1],
                     dtype=np.float32))
        feature_routed = self.routing_process(inputs, B_ij)
        return feature_routed

    def routing_process(self, inputs_, B_xx):
        '''Aggregate context features with dynamic routing over routing_times iterations.'''
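        # Each iteration turns the routing logits B_xx into coupling
        # coefficients C_ij (softmax over the window axis), re-weights the
        # inputs with them, and adds the resulting per-position response back
        # into B_xx; gradients flow through the inputs only on the final
        # iteration.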
        # inputs_ has shape [batch_size, 2*window_size, vec_dim]; routing
        # integrates over the window (input-capsule) axis to decide how much
        # information from each context feature is used.
        inputs_stopped = tf.stop_gradient(inputs_, name='Input_stop_gradient')
        for routing_iter in range(self.config.routing_times):
            C_ij = tf.nn.softmax(B_xx, dim=1)
            if routing_iter < self.config.routing_times - 1:
                S_j = tf.multiply(C_ij, inputs_stopped)
                V_j = tf.reduce_sum(S_j, axis=2, keep_dims=True)
                if self.config.use_squash:
                    V_j = self.squash(V_j, 2)
                B_xx += V_j

            elif routing_iter == self.config.routing_times - 1:
                S_j = tf.multiply(C_ij, inputs_)
                V_j = S_j
                #V_j = tf.reduce_sum(S_j, axis=1, keep_dims=False)
                # left un-reduced: later steps pool the features themselves
                if self.config.use_squash:
                    V_j = self.squash(V_j, 2)
                return V_j

    def squash(self, inputs, dim=2):
        '''Squash non-linearity hook; intentionally skipped in this example.'''
        raise NotImplementedError('squash is not implemented here')
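    # For reference, a sketch of the standard capsule squash this hook would
    # implement (Sabour et al., 2017); untested here:
    #
    #   def squash(self, inputs, dim=2):
    #       sq_norm = tf.reduce_sum(tf.square(inputs), axis=dim, keep_dims=True)
    #       return (sq_norm / (1. + sq_norm)) * inputs / tf.sqrt(sq_norm + 1e-9)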

    def build(self):
        self.get_vocabs()
        self.get_dataset()
        self.read_analogies()
        self.get_placeholder()
        self._predict_op()
        self.get_embeddings()
        self.get_init_features()
        self.get_allocate_true_samples()
        self.get_train_op(self.config.lr_method, self.config.lr, self.loss,
                          self.config.clip)
        self.init_session_op()

    def init_session_op(self):
        '''Create a session with on-demand GPU memory, then init variables and saver.'''
        config_ = tf.ConfigProto()
        config_.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config_)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver()

    def save_session_op(self):
        '''Save the current session to config.model_dir.'''
        if not os.path.exists(self.config.model_dir):
            os.makedirs(self.config.model_dir)
        self.saver.save(self.sess, self.config.model_dir)

    def restore_session(self, modeldir):
        '''Restore a saved session from modeldir.'''
        print 'INFO: Reloading the model from %s' % modeldir
        self.saver.restore(self.sess, modeldir)

    def close_session(self):
        '''Close the TF session.'''
        self.sess.close()

    def get_train_op(self, lmethod, lrate, loss, clip):
        '''Build the train op for the chosen optimizer, with optional gradient clipping.'''
        lr_method = lmethod.lower()
        with tf.variable_scope('gt_train_op', reuse=tf.AUTO_REUSE):
            if lr_method == 'adam':
                optimizer = tf.train.AdamOptimizer(lrate)
            elif lr_method == 'adagrad':
                optimizer = tf.train.AdagradOptimizer(lrate)
            elif lr_method == 'sgd':
                optimizer = tf.train.GradientDescentOptimizer(lrate)
            elif lr_method == 'rmsprop':
                optimizer = tf.train.RMSPropOptimizer(lrate)
            else:
                raise NotImplementedError(
                    "Unknown method {}".format(lr_method))

            if clip > 0:
                grads, vs = zip(*optimizer.compute_gradients(loss))
                grads, gnorm = tf.clip_by_global_norm(grads, clip)
                self.train_op = optimizer.apply_gradients(zip(grads, vs))
            else:
                self.train_op = optimizer.minimize(loss)

    def run_train_epoch(self, epoch_num):
        '''Run one training epoch and report dev-set results.'''
        print 'INFO: start training epoch %d' % epoch_num
        for fedict, _ in self.train_dataset.minibatchs():
            feed_dict = {}
            feed_dict[self.window_features] = fedict[self.train_dataset.inputs]
            feed_dict[self.true_target] = fedict[self.train_dataset.trueTa]
            feed_dict[self.sample_targets] = fedict[
                self.train_dataset.sampleTa]
            _, loss, tr_loss, sa_loss, tr_correct, sa_correct = self.sess.run(
                [
                    self.train_op, self.loss, self.trloss, self.saloss,
                    self.trcorrect, self.sacorrect
                ],
                feed_dict=feed_dict)
            #print 'In Epoch %d:'%epoch_num, 'train_loss: %.4f: '%loss
        self._save_emb()
        self.count_dist()
        results = self.run_evaluate(self.dev_dataset)
        msg = " - ".join(
            ["{} {:04.4f}".format(k, v) for k, v in results.items()])
        print "#######################" \
                "####   Dev result   ###" \
                "#######################"
        print msg
        return results["f1"]

    def run_evaluate(self, dataset):
        '''Compute accuracy, precision, recall and F1 over a dataset.'''
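        # Bookkeeping: correct_preds counts true targets predicted positive,
        # total_preds counts gold positives (batch_size per batch), and
        # total_correct counts everything predicted positive across both the
        # true and sampled streams (tr_correct plus the negatives that were
        # mis-classified as positive).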
        accs = []
        correct_preds, total_correct, total_preds = 0., 0., 0.
        for fed, _ in dataset.minibatchs():
            feed_dict = {}
            feed_dict[self.window_features] = fed[dataset.inputs]
            feed_dict[self.true_target] = fed[dataset.trueTa]
            feed_dict[self.sample_targets] = fed[dataset.sampleTa]
            tr_correct, sa_correct = self.sess.run(
                [self.trcorrect, self.sacorrect], feed_dict=feed_dict)
            correct_preds += tr_correct
            total_preds += self.config.batch_size
            total_correct += tr_correct + self.config.batch_size - sa_correct
            accs.append(
                (tr_correct + sa_correct) / float(2 * self.config.batch_size))
        p = correct_preds / total_preds if correct_preds > 0 else 0
        r = correct_preds / total_correct if correct_preds > 0 else 0
        f1 = 2.0 * p * r / (p + r) if correct_preds > 0 else 0
        acc = np.mean(accs)
        return {'acc': 100 * acc, 'f1': 100 * f1,
                'precision': 100 * p, 'recall': 100 * r}

    def train(self):
        best_score = 0
        nepoc_no_imprv = 0
        for epoch in range(self.config.nepochs):
            f1_value = self.run_train_epoch(epoch)
            self.config.lr *= self.config.lr_decay
            if f1_value > best_score:
                best_score = f1_value
                nepoc_no_imprv = 0
                self.save_session_op()
                print '\nINFO: new best score'

            else:
                nepoc_no_imprv += 1
                if nepoc_no_imprv >= self.config.early_stop:
                    print 'INFO: early stopping after %d epochs without improvement' % nepoc_no_imprv
                    break
        return best_score

    def evaluate(self):
        '''Evaluate on the test set and print the results.'''
        print 'INFO: Testing on test set\n'
        results = self.run_evaluate(self.test_dataset)
        msg = " - ".join(
            ["{} {:04.2f}".format(k, v) for k, v in results.items()])
        print "Dev result\n"
        print msg

    def _predict(self, analogy):
        '''Return top-4 predicted answer ids for a batch of (a, b, c) analogy triples.'''
        idx, = self.sess.run(
            [self._ana_pred_idx], {
                self._ana_a: analogy[:, 0],
                self._ana_b: analogy[:, 1],
                self._ana_c: analogy[:, 2],
            })
        return idx

    def _predict_op(self):
        '''Build the analogy and nearby-word evaluation sub-graph.'''
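        # Standard word2vec analogy setup: rows of the embedding table are
        # L2-normalised, the query is c + (b - a), and a matmul against the
        # normalised table gives cosine similarities, from which the top-4
        # candidates are taken.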
        analogy_a = tf.placeholder(dtype=tf.int32)
        analogy_b = tf.placeholder(dtype=tf.int32)
        analogy_c = tf.placeholder(dtype=tf.int32)

        nemd = tf.nn.l2_normalize(self.vocab.embvec, 1)

        a_emb = tf.gather(nemd, analogy_a)
        b_emb = tf.gather(nemd, analogy_b)
        c_emb = tf.gather(nemd, analogy_c)

        target = c_emb + (b_emb - a_emb)

        dist = tf.matmul(target, nemd, transpose_b=True)

        _, pred_idx = tf.nn.top_k(dist, 4)

        nearby_word = tf.placeholder(dtype=tf.int32)
        nearby_emb = tf.gather(nemd, nearby_word)

        nearby_dist = tf.matmul(nearby_emb, nemd, transpose_b=True)
        nearby_val, nearby_idx = tf.nn.top_k(nearby_dist,
                                             min(1000, self.vocab.size))

        self._ana_a = analogy_a
        self._ana_b = analogy_b
        self._ana_c = analogy_c
        self._nearby_word = nearby_word
        self._ana_pred_idx = pred_idx
        self._nearby_val = nearby_val
        self._nearby_idx = nearby_idx

    def count_dist(self):
        '''Compute analogy accuracy over the loaded question set.'''
        correct = 0
        try:
            total = self._analogy_questions.shape[0]
        except AttributeError as e:
            raise AttributeError('Need to read the analogy data\n')
        start = 0

        while start < total:
            limit = start + 2500
            sub = self._analogy_questions[start:limit, :]
            idx = self._predict(sub)
            start = limit
            for question in xrange(sub.shape[0]):
                for j in xrange(4):
                    if idx[question, j] == sub[question, 3]:
                        correct += 1
                        break
                    elif idx[question, j] in sub[question, :3]:
                        continue
                    else:
                        break
        print("Eval %4d/%d accuracy = %4.1f%%" %
              (correct, total, correct * 100.0 / total))

    def analogy(self, wd0, wd1, wd2):
        '''Solve the analogy wd0 : wd1 :: wd2 : ? and print the predicted word.'''
        wid = self.vocab.word2id([wd0, wd1, wd2])
        idx = self._predict(np.array([wid]))
        for c in [self.vocab.id2word(i) for i in idx[0, :]]:
            if c not in [wd0, wd1, wd2]:
                print c
                return
        print 'UNKNOWN WORDS'

    def nearby(self, words, num=20):
        '''print the nearby words'''

        idxs = np.array([self.vocab.word2id(wd_) for wd_ in words])
        vals, idx = self.sess.run([self._nearby_val, self._nearby_idx],
                                  {self._nearby_word: idxs})
        for i in xrange(len(words)):
            print("\n%s\n=====================================" % (words[i]))
            for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]):
                print("%-20s %6.4f" % (self.vocab.id2word(neighbor), distance))

    def _save_emb(self):
        '''Fetch the learned matrices and plot the embeddings.'''
        final_emb_vec = self.vocab.embvec.eval(session=self.sess)
        final_wgt_vec = self.vocab.wtvec.eval(session=self.sess)  # fetched for inspection; unused below
        dicts = self.vocab.id2str
        plot_figure(dicts, final_emb_vec)
        return

    def read_analogies(self):
        '''Load the analogy question file, skipping comments and OOV questions.'''
        questions = []
        questions_skipped = 0
        unk_id = self.vocab.word2id('<unk>')
        with open(self.config.eval_data, "rb") as analogy_f:
            for line in analogy_f:
                if line.startswith(b":"):  # Skip comments.
                    continue
                words = line.strip().lower().split(b" ")
                ids = self.vocab.word2id(words)
                if unk_id in ids or len(ids) != 4:
                    questions_skipped += 1
                else:
                    questions.append(np.array(ids))
        print("Eval analogy file: ", self.config.eval_data)
        print("Questions: ", len(questions))
        print("Skipped: ", questions_skipped)
        self._analogy_questions = np.array(questions, dtype=np.int32)
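
# Minimal usage sketch, assuming a `configs` object that exposes the
# attributes referenced above (train_file, dev_file, test_file, eval_data,
# batch_size, window_size, sample_num, vec_dim, lr, nepochs, ...):
#
#   trainer = emTrainer(configs)
#   best_f1 = trainer.train()   # train with early stopping on dev F1
#   trainer.evaluate()          # report test-set metrics
#   trainer.nearby(['king'])    # inspect neighbours of a word
#   trainer.close_session()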