    def analysis_evaluation(self):
        self.logger.info("Start Analyzing ...")
        start_time = time.time()
        test_batches = self.data.test
        self.logger.info("Total {} batches to analyze".format(len(test_batches)))
        acc_loss = 0
        acc_kl_loss = 0
        acc_aux_loss = 0
        acc_avg_cos = 0
        acc_avg_norm = 0

        batch_cnt = 0
        all_cnt = 0
        cnt = 0
        sample_bag = []
        try:
            for idx, batch in enumerate(test_batches):
                if idx % 10 == 0:
                    self.logger.info("Analyzing batch {}".format(idx))
                seq_len, batch_sz = batch.size()
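                # In conditional mode the first row of each batch carries the
                # condition label; strip it off and pass it separately as `bit`.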
                if self.data.condition:
                    seq_len -= 1
                    bit = batch[0, :]
                    batch = batch[1:, :]
                    bit = GVar(bit)
                else:
                    bit = None
                feed = self.data.get_feed(batch)

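                # Optional input corruption for the denoising setting: `swap`
                # perturbs word order, `replace` substitutes random token ids.
                # The reconstruction target below stays uncorrupted.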
                if self.args.swap > 0.00001:
                    feed = swap_by_batch(feed, self.args.swap)
                if self.args.replace > 0.00001:
                    feed = replace_by_batch(feed, self.args.replace, self.model.ntoken)

                target = GVar(batch)

                recon_loss, kld, aux_loss, tup, vecs, decoded = self.model(feed, target, bit)
                # target:  (seq_len, batch_sz)
                # decoded: (seq_len, batch_sz, dict_sz)
                # tup:     posterior parameters ('mean'/'logvar' for Gaussian,
                #          'mu' for vMF)
                # vecs:    latent codes sampled from the posterior
                bag = self.analyze_batch(target, kld, tup, vecs, decoded)
                sample_bag += bag
                acc_loss += recon_loss.data * seq_len * batch_sz
                acc_kl_loss += torch.sum(kld).data
                acc_aux_loss += torch.sum(aux_loss).data
                acc_avg_cos += tup['avg_cos'].data
                acc_avg_norm += tup['avg_norm'].data
                cnt += 1
                batch_cnt += batch_sz
                all_cnt += batch_sz * seq_len
        except KeyboardInterrupt:
            print("early stop")
        self.write_samples(sample_bag)
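        # Normalize the summed losses to per-token values; the "real" loss is
        # reconstruction plus KL, i.e. the per-token (negative) ELBO.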
        cur_loss = acc_loss.item() / all_cnt
        cur_kl = acc_kl_loss.item() / all_cnt
        cur_real_loss = cur_loss + cur_kl
        return cur_loss, cur_kl, cur_real_loss

    def evaluate(self, args, model, dev_batches):

        # Turn on evaluation mode, which disables dropout.
        model.eval()
        model.FLAG_train = False
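        # NOTE: on PyTorch >= 0.4 this loop would normally run under
        # `with torch.no_grad():`; this code instead follows the older idiom
        # of reading `.data` and never calling backward() here.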

        acc_loss = 0
        acc_kl_loss = 0
        acc_aux_loss = 0
        acc_avg_cos = 0
        acc_avg_norm = 0

        batch_cnt = 0
        all_cnt = 0
        cnt = 0
        start_time = time.time()

        for idx, batch in enumerate(dev_batches):

            seq_len, batch_sz = batch.size()
            if self.data.condition:
                seq_len -= 1
                bit = batch[0, :]
                batch = batch[1:, :]
                bit = GVar(bit)
            else:
                bit = None
            feed = self.data.get_feed(batch)

            if self.args.swap > 0.00001:
                feed = swap_by_batch(feed, self.args.swap)
            if self.args.replace > 0.00001:
                feed = replace_by_batch(feed, self.args.replace,
                                        self.model.ntoken)

            target = GVar(batch)

            recon_loss, kld, aux_loss, tup, vecs, _ = model(feed, target, bit)

            acc_loss += recon_loss.data * seq_len * batch_sz
            acc_kl_loss += torch.sum(kld).data
            acc_aux_loss += torch.sum(aux_loss).data
            acc_avg_cos += tup['avg_cos'].data
            acc_avg_norm += tup['avg_norm'].data
            cnt += 1
            batch_cnt += batch_sz
            all_cnt += batch_sz * seq_len

        cur_loss = acc_loss.item() / all_cnt
        cur_kl = acc_kl_loss.item() / all_cnt
        cur_aux_loss = acc_aux_loss.item() / all_cnt
        cur_avg_cos = acc_avg_cos.item() / cnt
        cur_avg_norm = acc_avg_norm.item() / cnt
        cur_real_loss = cur_loss + cur_kl

        # Perplexity, if needed, can be recovered as math.exp(cur_real_loss).
        return cur_loss, cur_kl, cur_real_loss

    def analysis_eval_order(self, feed, batch, bit):
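        # The swap rate is swept at 1x..6x below, so the base rate must be
        # small for the higher multiples to stay meaningful.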
        assert 0.33 > self.args.swap > 0.0001
        origin_feed = feed.clone()

        # Build progressively noisier copies of the input by applying the swap
        # corruption at 1x..6x the base swap rate.
        noised_feeds = [swap_by_batch(feed.clone(), self.args.swap * mult)
                        for mult in range(1, 7)]
        target = GVar(batch)

        original_recon_loss, kld, _, original_tup, original_vecs, _ = self.model(origin_feed, target, bit)
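        # The key of the posterior location parameter depends on the family:
        # the Gaussian posterior exposes 'mean', the vMF posterior 'mu'.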
        if 'Distnor' in self.instance_name:
            key_name = "mean"
        elif 'vmf' in self.instance_name:
            key_name = "mu"
        else:
            raise NotImplementedError

        original_mu = original_tup[key_name]
        # Run the model on each noised copy and record (a) the reconstruction
        # loss and (b) the cosine similarity between the original posterior
        # location and the noised one.
        recon_losses = [original_recon_loss.data]
        cos_sims = []
        for noised_feed in noised_feeds:
            noised_recon_loss, _, _, noised_tup, _, _ = self.model(noised_feed, target, bit)
            recon_losses.append(noised_recon_loss.data)
            cos_sims.append(torch.mean(cos(original_mu, noised_tup[key_name])).data)
        return [recon_losses, cos_sims]

    def train_epo(self, args, model, train_batches, epo, epo_start_time,
                  glob_iter):
        model.train()
        model.FLAG_train = True
        start_time = time.time()

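        # Re-create the SGD optimizer each epoch so that any learning-rate
        # decay applied to self.args.cur_lr between epochs takes effect.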
        if self.args.optim == 'sgd':
            self.optim = torch.optim.SGD(model.parameters(),
                                         lr=self.args.cur_lr)

        acc_loss = 0
        acc_kl_loss = 0
        acc_aux_loss = 0
        acc_avg_cos = 0
        acc_avg_norm = 0

        batch_cnt = 0
        all_cnt = 0
        cnt = 0

        random.shuffle(train_batches)
        for idx, batch in enumerate(train_batches):
            self.optim.zero_grad()
            seq_len, batch_sz = batch.size()
            if self.data.condition:
                seq_len -= 1

                if self.model.input_cd_bit > 1:
                    bit = batch[0, :]
                    bit = GVar(bit)
                else:
                    bit = None
                batch = batch[1:, :]
            else:
                bit = None
            feed = self.data.get_feed(batch)

            if self.args.swap > 0.00001:
                feed = swap_by_batch(feed, self.args.swap)
            if self.args.replace > 0.00001:
                feed = replace_by_batch(feed, self.args.replace,
                                        self.model.ntoken)

            self.glob_iter += 1

            target = GVar(batch)

            recon_loss, kld, aux_loss, tup, vecs, _ = model(feed, target, bit)
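            # Objective: per-token reconstruction loss rescaled to the
            # sequence level, plus weighted KL and auxiliary terms.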
            total_loss = (recon_loss * seq_len
                          + torch.mean(kld) * self.args.kl_weight
                          + torch.mean(aux_loss) * args.aux_weight)

            total_loss.backward()

            # `clip_grad_norm_` (the in-place variant, PyTorch >= 0.4) helps
            # prevent the exploding-gradient problem in RNNs / LSTMs.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip,
                                           norm_type=2)

            self.optim.step()

            acc_loss += recon_loss.data * seq_len * batch_sz
            acc_kl_loss += torch.sum(kld).data
            acc_aux_loss += torch.sum(aux_loss).data
            acc_avg_cos += tup['avg_cos'].data
            acc_avg_norm += tup['avg_norm'].data

            cnt += 1
            batch_cnt += batch_sz
            all_cnt += batch_sz * seq_len
            if idx % args.log_interval == 0 and idx > 0:
                cur_loss = acc_loss.item() / all_cnt
                cur_kl = acc_kl_loss.item() / all_cnt
                # (Earlier runs aborted training here when the per-token KL
                # drifted outside roughly [0.03, 0.7].)
                cur_aux_loss = acc_aux_loss.item() / all_cnt
                cur_avg_cos = acc_avg_cos.item() / cnt
                cur_avg_norm = acc_avg_norm.item() / cnt
                cur_real_loss = cur_loss + cur_kl
                Runner.log_instant(self.writer, self.args, self.glob_iter, epo,
                                   start_time, cur_avg_cos, cur_avg_norm,
                                   cur_loss, cur_kl, cur_aux_loss,
                                   cur_real_loss)