Example #1
    def sample_cell(self, mu, norm, kappa):
        """

        :param mu: z_dir (batchsz, lat_dim) . ALREADY normed.
        :param norm: z_norm (batchsz, lat_dim).
        :param kappa: scalar
        :return:
        """
        """vMF sampler in pytorch.
        http://stats.stackexchange.com/questions/156729/sampling-from-von-mises-fisher-distribution-in-python
        Args:
            mu (Tensor): of shape (batch_size, 2*word_dim)
            kappa (Float): controls dispersion. kappa of zero is no dispersion.
        """
        batch_sz, lat_dim = mu.size()
        # add uniform noise to the latent norm
        norm_with_noise = self.add_norm_noise_batch(norm, self.norm_eps)
        # sample the vMF weight w for the whole batch
        w = self._sample_weight_batch(kappa, lat_dim, batch_sz)
        w = w.unsqueeze(1)
        w_var = GVar(w * torch.ones(batch_sz, lat_dim))
        v = self._sample_ortho_batch(mu, lat_dim)
        scale_factr = torch.sqrt(
            GVar(torch.ones(batch_sz, lat_dim)) - torch.pow(w_var, 2))
        orth_term = v * scale_factr
        muscale = mu * w_var
        sampled_vec = (orth_term + muscale) * norm_with_noise

        return sampled_vec.unsqueeze(0)
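
The decomposition above relies on a simple identity: if mu and v are unit vectors with v orthogonal to mu, then w * mu + sqrt(1 - w^2) * v is again a unit vector, so the sample stays on the sphere until it is rescaled by the noised norm. A minimal standalone check of that identity (plain PyTorch, independent of GVar and the class above):

import torch

torch.manual_seed(0)
mu = torch.randn(4, 8)
mu = mu / mu.norm(dim=1, keepdim=True)          # unit direction, one per row
v = torch.randn(4, 8)
v = v - (v * mu).sum(dim=1, keepdim=True) * mu  # remove the component along mu
v = v / v.norm(dim=1, keepdim=True)             # unit tangent vector
w = torch.rand(4, 1) * 2 - 1                    # any weight in [-1, 1]
sample = w * mu + torch.sqrt(1 - w ** 2) * v
print(sample.norm(dim=1))                       # ~1.0 for every row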
Example #2
 def init_hidden(self, bsz):
     weight = next(self.parameters()).data
     if self.rnn_type == 'LSTM':
         return (GVar(weight.new(self.nlayers, bsz, self.nhid).zero_()),
                 GVar(weight.new(self.nlayers, bsz, self.nhid).zero_()))
     else:
         return GVar(weight.new(self.nlayers, bsz, self.nhid).zero_())
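
A minimal usage sketch for init_hidden, assuming GVar simply wraps a tensor (a no-op on PyTorch >= 0.4): the LSTM branch returns an (h0, c0) pair of shape (nlayers, bsz, nhid), which is exactly what torch.nn.LSTM expects as its initial state.

import torch

nlayers, bsz, nhid, ninp = 2, 16, 300, 100
rnn = torch.nn.LSTM(ninp, nhid, nlayers)
h0 = torch.zeros(nlayers, bsz, nhid)
c0 = torch.zeros(nlayers, bsz, nhid)
x = torch.randn(35, bsz, ninp)                  # (seq_len, batch, input_size)
out, (hn, cn) = rnn(x, (h0, c0))
print(out.size())                               # torch.Size([35, 16, 300])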
Example #3
    def evaluate(self, dev_batches):
        self.learner.eval()
        print("Test start")
        acc_loss = 0
        acc_accuracy = 0
        all_cnt = 0
        cnt = 0
        random.shuffle(dev_batches)
        for idx, batch in enumerate(dev_batches):
            bit, vec = batch
            bit = GVar(bit)
            vec = GVar(vec)
            # print(bit)
            loss, pred = self.learner(vec, bit)
            _, argmax = torch.max(pred, dim=1)
            # evaluation only: no backward pass or optimizer step here

            argmax = argmax.data
            bit = bit.data
            for jdx, num in enumerate(argmax):
                gt = bit[jdx]
                all_cnt += 1
                if gt == num:
                    acc_accuracy += 1

            acc_loss += loss.item()
            cnt += 1
        # print("===============test===============")
        # print(acc_loss / cnt)
        print("Loss {}  \tAccuracy {}".format(acc_loss / cnt,
                                              acc_accuracy / all_cnt))

        return float(acc_accuracy / all_cnt)
Example #4
    def analysis_evaluation(self):
        self.logger.info("Start Analyzing ...")
        start_time = time.time()
        test_batches = self.data.test
        self.logger.info("Total {} batches to analyze".format(len(test_batches)))
        acc_loss = 0
        acc_kl_loss = 0
        acc_aux_loss = 0
        acc_avg_cos = 0
        acc_avg_norm = 0

        batch_cnt = 0
        all_cnt = 0
        cnt = 0
        sample_bag = []
        try:
            for idx, batch in enumerate(test_batches):
                if idx % 10 == 0:
                    print("Idx: {}".format(idx))
                seq_len, batch_sz = batch.size()
                if self.data.condition:
                    seq_len -= 1
                    bit = batch[0, :]
                    batch = batch[1:, :]
                    bit = GVar(bit)
                else:
                    bit = None
                feed = self.data.get_feed(batch)

                if self.args.swap > 0.00001:
                    feed = swap_by_batch(feed, self.args.swap)
                if self.args.replace > 0.00001:
                    feed = replace_by_batch(feed, self.args.replace, self.model.ntoken)

                target = GVar(batch)

                recon_loss, kld, aux_loss, tup, vecs, decoded = self.model(feed, target, bit)
                # target: seq_len, batchsz
                # decoded: seq_len, batchsz, dict_sz
                # tup: 'mean' 'logvar' for Gaussian
                #         'mu' for vMF
                # vecs
                bag = self.analyze_batch(target, kld, tup, vecs, decoded)
                sample_bag += bag
                acc_loss += recon_loss.data * seq_len * batch_sz
                acc_kl_loss += torch.sum(kld).data
                acc_aux_loss += torch.sum(aux_loss).data
                acc_avg_cos += tup['avg_cos'].data
                acc_avg_norm += tup['avg_norm'].data
                cnt += 1
                batch_cnt += batch_sz
                all_cnt += batch_sz * seq_len
        except KeyboardInterrupt:
            print("early stop")
        self.write_samples(sample_bag)
        cur_loss = acc_loss.item() / all_cnt
        cur_kl = acc_kl_loss.item() / all_cnt
        cur_real_loss = cur_loss + cur_kl
        return cur_loss, cur_kl, cur_real_loss
Example #5
    def evaluate(self, args, model, dev_batches):

        # Turn on evaluation mode, which disables dropout.
        model.eval()
        model.FLAG_train = False

        acc_loss = 0
        acc_kl_loss = 0
        acc_aux_loss = 0
        acc_avg_cos = 0
        acc_avg_norm = 0

        batch_cnt = 0
        all_cnt = 0
        cnt = 0
        start_time = time.time()

        for idx, batch in enumerate(dev_batches):

            seq_len, batch_sz = batch.size()
            if self.data.condition:
                seq_len -= 1
                bit = batch[0, :]
                batch = batch[1:, :]
                bit = GVar(bit)
            else:
                bit = None
            feed = self.data.get_feed(batch)

            if self.args.swap > 0.00001:
                feed = swap_by_batch(feed, self.args.swap)
            if self.args.replace > 0.00001:
                feed = replace_by_batch(feed, self.args.replace,
                                        self.model.ntoken)

            target = GVar(batch)

            recon_loss, kld, aux_loss, tup, vecs, _ = model(feed, target, bit)

            acc_loss += recon_loss.data * seq_len * batch_sz
            acc_kl_loss += torch.sum(kld).data
            acc_aux_loss += torch.sum(aux_loss).data
            acc_avg_cos += tup['avg_cos'].data
            acc_avg_norm += tup['avg_norm'].data
            cnt += 1
            batch_cnt += batch_sz
            all_cnt += batch_sz * seq_len

        cur_loss = acc_loss.item() / all_cnt
        cur_kl = acc_kl_loss.item() / all_cnt
        cur_aux_loss = acc_aux_loss.item() / all_cnt
        cur_avg_cos = acc_avg_cos.item() / cnt
        cur_avg_norm = acc_avg_norm.item() / cnt
        cur_real_loss = cur_loss + cur_kl

        # Runner.log_eval(print_ppl)
        # print('loss {:5.2f} | KL {:5.2f} | ppl {:8.2f}'.format(            cur_loss, cur_kl, math.exp(print_ppl)))
        return cur_loss, cur_kl, cur_real_loss
Example #6
    def __init__(self, hid_dim, lat_dim, kappa=1):
        super().__init__()
        self.hid_dim = hid_dim
        self.lat_dim = lat_dim
        self.kappa = kappa
        # self.func_kappa = torch.nn.Linear(hid_dim, lat_dim)
        self.func_mu = torch.nn.Linear(hid_dim, lat_dim)

        self.kld = GVar(torch.from_numpy(vMF._vmf_kld(kappa, lat_dim)).float())
        print('KLD: {}'.format(self.kld.data[0]))
Example #7
    def evaluate(self, args, model, corpus_dev, corpus_dev_cnt, dev_batches):
        """
        Standard evaluation function on dev or test set.
        :param args:
        :param model:
        :param dev_batches:
        :return:
        """

        # Turn on evaluation mode, which disables dropout.
        model.eval()

        acc_loss = 0
        acc_kl_loss = 0
        acc_real_loss = 0
        word_cnt = 0
        doc_cnt = 0
        start_time = time.time()
        ntokens = self.data.vocab_size

        for idx, batch in enumerate(dev_batches):
            data_batch, count_batch = self.data.fetch_data(
                corpus_dev, corpus_dev_cnt, batch, ntokens)

            data_batch = GVar(torch.FloatTensor(data_batch))

            recon_loss, kld, aux_loss, tup, vecs = model(data_batch)

            count_batch = GVar(torch.FloatTensor(count_batch))
            # real_loss = torch.div((recon_loss + kld).data, count_batch)
            doc_num = len(count_batch)
            # remove nan
            # for n in real_loss:
            #     if n == n:
            #         acc_real_loss += n
            # acc_real_ppl += torch.sum(real_ppl)

            acc_loss += torch.sum(recon_loss).item()  #
            acc_kl_loss += torch.sum(kld).item()
            count_batch = count_batch + 1e-12

            word_cnt += torch.sum(count_batch).item()
            doc_cnt += doc_num

        # word ppl
        cur_loss = acc_loss / word_cnt  # word loss
        cur_kl = acc_kl_loss / word_cnt
        # cur_real_loss = acc_real_loss / doc_cnt
        cur_real_loss = cur_loss + cur_kl
        elapsed = time.time() - start_time

        # Runner.log_eval(print_ppl)
        # print('loss {:5.2f} | KL {:5.2f} | ppl {:8.2f}'.format(            cur_loss, cur_kl, math.exp(print_ppl)))
        return cur_loss, cur_kl, cur_real_loss
Example #8
 def dropword(self, emb, drop_rate=0.3):
     """
     Mix the ground truth word with UNK.
     If drop rate = 1, no ground truth info is used. (Fly mode)
     :param emb:
     :param drop_rate: 0 - no drop; 1 - full drop, all UNK
     :return: mixed embedding
     """
     UNKs = GVar(torch.ones(emb.size()[0], emb.size()[1]).long() * 2)  # token id 2 is UNK
     UNKs = self.emb(UNKs)
     # print(UNKs, emb)
     masks = numpy.random.binomial(1, drop_rate, size=(emb.size()[0], emb.size()[1]))
     masks = GVar(torch.FloatTensor(masks)).unsqueeze(2).expand_as(UNKs)
     emb = emb * (1 - masks) + UNKs * masks
     return emb
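
The same mixing idea in a standalone form (a sketch with a hypothetical embedding table, assuming token id 2 is UNK as elsewhere in this codebase): positions where the Bernoulli mask is 1 are replaced by the UNK embedding, so drop_rate=0 keeps the ground truth and drop_rate=1 feeds only UNK.

import torch

torch.manual_seed(0)
vocab_sz, emb_dim, seq_len, bsz = 50, 8, 5, 3
emb_table = torch.nn.Embedding(vocab_sz, emb_dim)
tokens = torch.randint(0, vocab_sz, (seq_len, bsz))
emb = emb_table(tokens)                                        # (seq_len, bsz, emb_dim)
unk = emb_table(torch.full((seq_len, bsz), 2, dtype=torch.long))
mask = torch.bernoulli(torch.full((seq_len, bsz, 1), 0.3))     # drop_rate = 0.3
mixed = emb * (1 - mask) + unk * mask                          # broadcast over emb_dim
print(mixed.size())                                            # torch.Size([5, 3, 8])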
Example #9
    def __init__(self, hid_dim, lat_dim, kappa=1):
        """
        von Mises-Fisher distribution class with batch support and manual tuning kappa value.
        Implementation follows description of my paper and Guu's.
        """

        super().__init__()
        self.hid_dim = hid_dim
        self.lat_dim = lat_dim
        self.kappa = kappa
        # self.func_kappa = torch.nn.Linear(hid_dim, lat_dim)
        self.func_mu = torch.nn.Linear(hid_dim, lat_dim)

        self.kld = GVar(torch.from_numpy(vMF._vmf_kld(kappa, lat_dim)).float())
        print('KLD: {}'.format(self.kld.data[0]))
Example #10
    def forward(self, x):
        batch_sz = x.size()[0]

        linear_x = self.enc_vec(x)
        linear_x = self.dropout(linear_x)
        active_x = self.active(linear_x)
        linear_x_2 = self.enc_vec_2(active_x)

        tup, kld, vecs = self.dist.build_bow_rep(linear_x_2, self.n_sample)
        # vecs: n_samples, batch_sz, lat_dim

        if 'redundant_norm' in tup:
            aux_loss = tup['redundant_norm'].view(batch_sz)
        else:
            aux_loss = GVar(torch.zeros(batch_sz))

        # stat
        avg_cos = BowVAE.check_dispersion(vecs)
        avg_norm = torch.mean(tup['norm'])
        tup['avg_cos'] = avg_cos
        tup['avg_norm'] = avg_norm

        flatten_vecs = vecs.view(self.n_sample * batch_sz, self.n_lat)
        flatten_vecs = self.dec_act(self.dec_linear(flatten_vecs))
        logit = self.dropout(self.out(flatten_vecs))
        logit = torch.nn.functional.log_softmax(logit, dim=1)
        logit = logit.view(self.n_sample, batch_sz, self.vocab_size)
        flatten_x = x.unsqueeze(0).expand(self.n_sample, batch_sz,
                                          self.vocab_size)
        error = torch.mul(flatten_x, logit)
        error = torch.mean(error, dim=0)

        recon_loss = -torch.sum(error, dim=1, keepdim=False)

        return recon_loss, kld, aux_loss, tup, vecs
Example #11
    def sample_cell(self, mu, norm, kappa):
        batch_sz, lat_dim = mu.size()
        # mu = GVar(mu)
        mu = mu / torch.norm(mu, p=2, dim=1, keepdim=True)
        w = self._sample_weight_batch(kappa, lat_dim, batch_sz)
        w = w.unsqueeze(1)

        # batch version
        w_var = GVar(w * torch.ones(batch_sz, lat_dim).to(device))
        v = self._sample_ortho_batch(mu, lat_dim)
        scale_factr = torch.sqrt(
            GVar(torch.ones(batch_sz, lat_dim).to(device)) - torch.pow(w_var, 2))
        orth_term = v * scale_factr
        muscale = mu * w_var
        sampled_vec = orth_term + muscale

        return sampled_vec.unsqueeze(0).to(device)
Example #12
 def _sample_orthonormal_to(self, mu, dim):
     """Sample point on sphere orthogonal to mu.
     """
     v = GVar(torch.randn(dim))
     rescale_value = mu.dot(v) / mu.norm()
     proj_mu_v = mu * rescale_value.expand(dim)
     ortho = v - proj_mu_v
     ortho_norm = torch.norm(ortho)
     return ortho / ortho_norm.expand_as(ortho)
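
A quick standalone check of the property this helper relies on (plain tensors, no GVar): with a unit-norm mu, as the vMF sampler always provides, removing the component along mu and renormalizing gives a unit vector orthogonal to mu.

import torch

torch.manual_seed(0)
dim = 16
mu = torch.randn(dim)
mu = mu / mu.norm()                 # the sampler passes an already-normalized mu
v = torch.randn(dim)
ortho = v - mu * mu.dot(v)          # strip the component along mu
ortho = ortho / ortho.norm()
print(ortho.norm().item())          # ~1.0
print(mu.dot(ortho).item())         # ~0.0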
Example #13
 def add_norm_noise(self, munorm, eps):
     """
     KL loss is - log(maxvalue/eps)
     cut at maxvalue-eps, and add [0,eps] noise.
     """
     # if np.random.rand()<0.05:
     #     print(munorm[0])
     trand = torch.rand(1).expand(munorm.size()) * eps
     return munorm + GVar(trand)
Example #14
    def forward(self, inp, target, bit=None):
        """
        Forward with ground truth (maybe mixed with UNK) as input.
        :param inp:  seq_len, batch_sz
        :param target: seq_len, batch_sz
        :param bit: 1, batch_sz
        :return:
        """
        seq_len, batch_sz = inp.size()
        emb = self.drop(self.emb(inp))

        if self.input_cd_bow > 1:
            bow = self.enc_bow(emb)
        else:
            bow = None
        if self.input_cd_bit > 1:
            bit = self.enc_bit(bit)
        else:
            bit = None

        h = self.forward_enc(emb, bit)
        tup, kld, vecs = self.forward_build_lat(h, self.args.nsample)  # batchsz, lat dim

        if 'redundant_norm' in tup:
            aux_loss = tup['redundant_norm'].view(batch_sz)
        else:
            aux_loss = GVar(torch.zeros(batch_sz))
        if 'norm' not in tup:
            tup['norm'] = GVar(torch.zeros(batch_sz))
        # stat
        avg_cos = check_dispersion(vecs)
        tup['avg_cos'] = avg_cos

        avg_norm = torch.mean(tup['norm'])
        tup['avg_norm'] = avg_norm

        vec = torch.mean(vecs, dim=0)

        decoded = self.forward_decode_ground(emb, vec, bit, bow)  # (seq_len, batch, dict sz)

        flatten_decoded = decoded.view(-1, self.ntoken)
        flatten_target = target.view(-1)
        loss = self.criterion(flatten_decoded, flatten_target)
        return loss, kld, aux_loss, tup, vecs, decoded
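
For the final loss, decoded is flattened to (seq_len * batch_sz, ntoken) and target to (seq_len * batch_sz,). A small sketch of that flattening with a token-level criterion (NLLLoss over log-probabilities is an assumption here; the actual self.criterion is defined elsewhere in the model):

import torch

seq_len, bsz, ntoken = 4, 2, 10
decoded = torch.log_softmax(torch.randn(seq_len, bsz, ntoken), dim=-1)
target = torch.randint(0, ntoken, (seq_len, bsz))
criterion = torch.nn.NLLLoss()
loss = criterion(decoded.view(-1, ntoken), target.view(-1))
print(loss.item())                  # average negative log-likelihood per token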
Example #15
    def play_eval(self, args, model, train_batches, epo, epo_start_time,
                  glob_iter):
        # reveal the relation between latent space and length and loss
        # reveal the distribution of latent space
        model.eval()
        model.FLAG_train = False
        start_time = time.time()

        acc_loss = 0
        acc_kl_loss = 0
        acc_aux_loss = 0
        acc_avg_cos = 0
        acc_avg_norm = 0

        batch_cnt = 0
        all_cnt = 0
        cnt = 0
        random.shuffle(train_batches)

        if self.args.dist == 'nor':
            vs = visual_gauss(self.data.dictionary)
        elif self.args.dist == 'vmf':
            vs = visual_vmf(self.data.dictionary)

        for idx, batch in enumerate(train_batches):
            seq_len, batch_sz = batch.size()
            feed = self.data.get_feed(batch)

            glob_iter += 1

            target = GVar(batch)

            recon_loss, kld, aux_loss, tup, vecs = model(feed, target)

            acc_loss += recon_loss.data * seq_len * batch_sz
            acc_kl_loss += torch.sum(kld).data
            acc_aux_loss += torch.sum(aux_loss).data
            acc_avg_cos += tup['avg_cos'].data
            acc_avg_norm += tup['avg_norm'].data

            cnt += 1
            batch_cnt += batch_sz
            all_cnt += batch_sz * seq_len

            vs.add_batch(target.data, tup, kld.data)

        cur_loss = acc_loss.item() / all_cnt
        cur_kl = acc_kl_loss.item() / all_cnt
        cur_aux_loss = acc_aux_loss.item() / all_cnt
        cur_avg_cos = acc_avg_cos.item() / cnt
        cur_avg_norm = acc_avg_norm.item() / cnt
        cur_real_loss = cur_loss + cur_kl
        Runner.log_instant(None, self.args, glob_iter, epo, start_time,
                           cur_avg_cos, cur_avg_norm, cur_loss, cur_kl,
                           cur_aux_loss, cur_real_loss)
        vs.write_log()
Example #16
    def play_eval(self, args, model, train_batches, epo, epo_start_time, glob_iter):
        # reveal the relation between latent space and length and loss
        # reveal the distribution of latent space
        model.eval()
        start_time = time.time()
        acc_loss = 0
        acc_kl_loss = 0
        acc_real_loss = 0

        word_cnt = 0
        doc_cnt = 0

        random.shuffle(train_batches)

        if self.args.dist == 'nor':
            vs = visual_gauss()
        elif self.args.dist == 'vmf':
            vs = visual_vmf()

        for idx, batch in enumerate(train_batches):
            # seq_len, batch_sz = batch.size()
            data_batch, count_batch = DataNg.fetch_data(
                self.data.test[0], self.data.test[1], batch)

            data_batch = GVar(torch.FloatTensor(data_batch))

            recon_loss, kld, total_loss, tup, vecs = model(data_batch)

            vs.add_batch(data_batch, tup, kld.data, vecs)

            count_batch = torch.FloatTensor(count_batch).cuda()
            real_loss = torch.div((recon_loss + kld).data, count_batch)
            doc_num = len(count_batch)
            # remove nan
            for n in real_loss:
                if n == n:
                    acc_real_loss += n
            # acc_real_ppl += torch.sum(real_ppl)

            acc_loss += torch.sum(recon_loss).item()  #
            acc_kl_loss += torch.sum(kld).item()
            count_batch = count_batch + 1e-12

            word_cnt += torch.sum(count_batch)
            doc_cnt += doc_num

        cur_loss = acc_loss / word_cnt  # word loss
        cur_kl = acc_kl_loss / word_cnt
        # cur_real_loss = acc_real_loss / doc_cnt
        cur_real_loss = cur_loss + cur_kl

        Runner.log_instant(None, self.args, glob_iter, epo, start_time,
                           cur_loss, cur_kl, cur_real_loss)
        vs.write_log()
Example #17
    def evaluate(self, dev_batches):
        self.learner.eval()
        print("Test start")
        acc_loss = 0
        cnt = 0
        random.shuffle(dev_batches)
        for idx, batch in enumerate(dev_batches):
            self.optim.zero_grad()
            seq_len, batch_sz = batch.size()
            if self.data.condition:
                seq_len -= 1

                if self.model.input_cd_bit > 1:
                    bit = batch[0, :]
                    bit = GVar(bit)
                else:
                    bit = None
                batch = batch[1:, :]
            else:
                bit = None
            feed = self.data.get_feed(batch)

            seq_len, batch_sz = feed.size()
            emb = self.model.drop(self.model.emb(feed))

            if self.model.input_cd_bit > 1:
                bit = self.model.enc_bit(bit)
            else:
                bit = None

            h = self.model.forward_enc(emb, bit)
            tup, kld, vecs = self.model.forward_build_lat(
                h)  # batchsz, lat dim
            if self.model.dist_type == 'vmf':
                code = tup['mu']
            elif self.model.dist_type == 'nor':
                code = tup['mean']
            else:
                raise NotImplementedError
            emb = torch.mean(emb, dim=0)
            if self.c2b:
                loss = self.learner(code, emb)
            else:
                loss = self.learner(code, emb)
            acc_loss += loss.item()
            cnt += 1
            if idx % 400 == 0 and idx > 0:
                acc_loss = 0
                cnt = 0
        # print("===============test===============")
        # print(acc_loss / cnt)
        print(acc_loss / cnt)
        return float(acc_loss / cnt)
Example #18
    def train_epo(self, train_batches):
        self.learner.train()
        print("Epo start")
        acc_loss = 0
        cnt = 0

        random.shuffle(train_batches)
        for idx, batch in enumerate(train_batches):
            self.optim.zero_grad()
            seq_len, batch_sz = batch.size()
            if self.data.condition:
                seq_len -= 1

                if self.model.input_cd_bit > 1:
                    bit = batch[0, :]
                    bit = GVar(bit)
                else:
                    bit = None
                batch = batch[1:, :]
            else:
                bit = None
            feed = self.data.get_feed(batch)

            seq_len, batch_sz = feed.size()
            emb = self.model.drop(self.model.emb(feed))

            if self.model.input_cd_bit > 1:
                bit = self.model.enc_bit(bit)
            else:
                bit = None

            h = self.model.forward_enc(emb, bit)
            tup, kld, vecs = self.model.forward_build_lat(
                h)  # batchsz, lat dim
            if self.model.dist_type == 'vmf':
                code = tup['mu']
            elif self.model.dist_type == 'nor':
                code = tup['mean']
            else:
                raise NotImplementedError
            emb = torch.mean(emb, dim=0)
            if self.c2b:
                loss = self.learner(code, emb)
            else:
                loss = self.learner(code, emb)
            loss.backward()
            self.optim.step()
            acc_loss += loss.item()
            cnt += 1
            if idx % 400 == 0 and (idx > 0):
                print("Training {}".format(acc_loss / cnt))
                acc_loss = 0
                cnt = 0
Example #19
 def get_feed(data_patch):
     """
     Given data patch, get the corresponding input of that data patch.
     Given: [A, B, C, D]
     Return: [SOS, A, B, C]
     :param data_patch:
     :return:
     """
     # seq, batch
     bsz = data_patch.size()[1]
     sos = torch.LongTensor(1, bsz).fill_(1)
     input_data = GVar(torch.cat((sos, data_patch[:-1])))
     return input_data
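
A small sketch of the shift get_feed performs, assuming token id 1 is SOS as in the snippet: each column [A, B, C, D] becomes [SOS, A, B, C], so the decoder input at step t is the gold token from step t-1.

import torch

data_patch = torch.tensor([[11, 21],
                           [12, 22],
                           [13, 23],
                           [14, 24]])            # (seq_len=4, batch=2)
bsz = data_patch.size(1)
sos = torch.ones(1, bsz, dtype=torch.long)       # SOS id assumed to be 1
inp = torch.cat((sos, data_patch[:-1]))
print(inp)
# tensor([[ 1,  1],
#         [11, 21],
#         [12, 22],
#         [13, 23]])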
Example #20
    def __init__(self, hid_dim, lat_dim, kappa=1, norm_max=2, norm_func=True):
        super().__init__()
        self.hid_dim = hid_dim
        self.lat_dim = lat_dim
        self.kappa = kappa
        # self.func_kappa = torch.nn.Linear(hid_dim, lat_dim)
        self.func_mu = torch.nn.Linear(hid_dim, lat_dim)
        self.func_norm = torch.nn.Linear(hid_dim, 1)

        # self.noise_scaler = kappa
        self.norm_eps = 1
        self.norm_max = norm_max
        self.norm_clip = torch.nn.Hardtanh(0.00001,
                                           self.norm_max - self.norm_eps)

        self.norm_func = norm_func

        # KLD accounts for both VMF and uniform parts
        kld_value = unif_vMF._vmf_kld(kappa, lat_dim) \
                    + unif_vMF._uniform_kld(0., self.norm_eps, 0., self.norm_max)
        self.kld = GVar(torch.from_numpy(np.array([kld_value])).float())
        print('KLD: {}'.format(self.kld.data[0]))
Example #21
    def _sample_ortho_batch(self, mu, dim):
        """

        :param mu: Variable, [batch size, latent dim]
        :param dim: scala. =latent dim
        :return:
        """
        _batch_sz, _lat_dim = mu.size()
        assert _lat_dim == dim
        squeezed_mu = mu.unsqueeze(1)

        v = GVar(torch.randn(_batch_sz, dim, 1))  # TODO random

        # v = GVar(torch.linspace(-1, 1, steps=dim))
        # v = v.expand(_batch_sz, dim).unsqueeze(2)

        rescale_val = torch.bmm(squeezed_mu, v).squeeze(2)
        proj_mu_v = mu * rescale_val
        ortho = v.squeeze() - proj_mu_v
        ortho_norm = torch.norm(ortho, p=2, dim=1, keepdim=True)
        y = ortho / ortho_norm
        return y
Example #22
    def sample_cell(self, mu, norm, kappa):
        batch_sz, lat_dim = mu.size()
        result = []
        sampled_vecs = GVar(torch.FloatTensor(batch_sz, lat_dim))
        for b in range(batch_sz):
            this_mu = mu[b]
            # kappa = np.linalg.norm(this_theta)
            this_mu = this_mu / torch.norm(this_mu, p=2)

            w = self._sample_weight(kappa, lat_dim)
            w_var = GVar(w * torch.ones(lat_dim))

            v = self._sample_orthonormal_to(this_mu, lat_dim)

            scale_factr = torch.sqrt(GVar(torch.ones(lat_dim)) - torch.pow(w_var, 2))
            orth_term = v * scale_factr
            muscale = this_mu * w_var
            sampled_vec = orth_term + muscale
            sampled_vecs[b] = sampled_vec
            # sampled_vec = torch.FloatTensor(sampled_vec)
            # result.append(sampled_vec)

        return sampled_vecs.unsqueeze(0)
Example #23
    def analysis_evaluation_order_and_importance(self):
        """
        Measure the change of cos sim given different encoding sequence
        :return:
        """
        self.logger.info("Start Analyzing ... Picking up 100 batches to analyze")
        start_time = time.time()
        test_batches = self.data.test
        random.shuffle(test_batches)
        test_batches = test_batches[:100]
        self.logger.info("Total {} batches to analyze".format(len(test_batches)))
        acc_loss = 0
        acc_kl_loss = 0

        batch_cnt = 0
        all_cnt = 0
        cnt = 0
        sample_bag = []
        try:
            for idx, batch in enumerate(test_batches):
                if idx % 10 == 0:
                    print("Now Idx: {}".format(idx))
                seq_len, batch_sz = batch.size()
                if self.data.condition:
                    seq_len -= 1
                    bit = batch[0, :]
                    batch = batch[1:, :]
                    bit = GVar(bit)
                else:
                    bit = None
                feed = self.data.get_feed(batch)

                if self.args.swap > 0.0001:
                    bag = self.analysis_eval_order(feed, batch, bit)
                elif self.args.replace > 0.0001:
                    bag = self.analysis_eval_word_importance(feed, batch, bit)
                else:
                    print("Maybe Wrong mode?")
                    raise NotImplementedError

                sample_bag.append(bag)
        except KeyboardInterrupt:
            print("early stop")
        if self.args.swap > 0.0001:
            return self.unpack_bag_order(sample_bag)
        elif self.args.replace > 0.0001:
            return self.unpack_bag_word_importance(sample_bag)
        else:
            raise NotImplementedError
Example #24
    def forward_build_lat(self, hidden, nsample=3):
        """

        :param hidden: batch_sz, nhid
        :return: tup, kld [batch_sz], out [nsamples, batch_sz, lat_dim]
        """
        if self.args.dist in ('nor', 'vmf', 'unifvmf', 'vmf_diff', 'sph'):
            # every distribution module exposes the same build_bow_rep interface
            tup, kld, out = self.dist.build_bow_rep(hidden, nsample)
        elif self.args.dist == 'zero':
            out = GVar(torch.zeros(1, hidden.size()[0], self.lat_dim))
            tup = {}
            kld = GVar(torch.zeros(1))
        else:
            raise NotImplementedError
        return tup, kld, out
Example #25
    def train_epo(self, train_batches):
        self.learner.train()
        print("Epo start")
        acc_loss = 0
        acc_accuracy = 0
        all_cnt = 0
        cnt = 0

        random.shuffle(train_batches)

        for idx, batch in enumerate(train_batches):
            self.optim.zero_grad()
            bit, vec = batch
            bit = GVar(bit)
            vec = GVar(vec)
            # print(bit)
            loss, pred = self.learner(vec, bit)
            _, argmax = torch.max(pred, dim=1)
            loss.backward()
            self.optim.step()

            argmax = argmax.data
            bit = bit.data
            for jdx, num in enumerate(argmax):
                gt = bit[jdx]
                all_cnt += 1
                if gt == num:
                    acc_accuracy += 1

            acc_loss += loss.item()
            cnt += 1
            if idx % 400 == 0:
                print("Loss {}  \tAccuracy {}".format(acc_loss / cnt,
                                                      acc_accuracy / all_cnt))
                acc_loss = 0
                cnt = 0
Example #26
    def analysis_eval_order(self, feed, batch, bit):
        assert 0.33 > self.args.swap > 0.0001
        origin_feed = feed.clone()

        feed_1x = swap_by_batch(feed.clone(), self.args.swap)
        feed_2x = swap_by_batch(feed.clone(), self.args.swap * 2)
        feed_3x = swap_by_batch(feed.clone(), self.args.swap * 3)
        feed_4x = swap_by_batch(feed.clone(), self.args.swap * 4)
        feed_5x = swap_by_batch(feed.clone(), self.args.swap * 5)
        feed_6x = swap_by_batch(feed.clone(), self.args.swap * 6)
        target = GVar(batch)

        # recon_loss, kld, aux_loss, tup, vecs, decoded = self.model(feed, target, bit)
        original_recon_loss, kld, _, original_tup, original_vecs, _ = self.model(origin_feed, target, bit)
        if 'Distnor' in self.instance_name:
            key_name = "mean"
        elif 'vmf' in self.instance_name:
            key_name = "mu"
        else:
            raise NotImplementedError

        original_mu = original_tup[key_name]
        recon_loss_1x, _, _, tup_1x, vecs_1x, _ = self.model(feed_1x, target, bit)
        recon_loss_2x, _, _, tup_2x, vecs_2x, _ = self.model(feed_2x, target, bit)
        recon_loss_3x, _, _, tup_3x, vecs_3x, _ = self.model(feed_3x, target, bit)
        recon_loss_4x, _, _, tup_4x, vecs_4x, _ = self.model(feed_4x, target, bit)
        recon_loss_5x, _, _, tup_5x, vecs_5x, _ = self.model(feed_5x, target, bit)
        recon_loss_6x, _, _, tup_6x, vecs_6x, _ = self.model(feed_6x, target, bit)

        # target: seq_len, batchsz
        # decoded: seq_len, batchsz, dict_sz
        # tup: 'mean' 'logvar' for Gaussian
        #         'mu' for vMF
        # vecs
        # cos_1x = self.analyze_batch_order(original_vecs, vecs_1x).data
        # cos_2x = self.analyze_batch_order(original_vecs, vecs_2x).data
        # cos_3x = self.analyze_batch_order(original_vecs, vecs_3x).data
        cos_1x = torch.mean(cos(original_mu, tup_1x[key_name])).data
        cos_2x = torch.mean(cos(original_mu, tup_2x[key_name])).data
        cos_3x = torch.mean(cos(original_mu, tup_3x[key_name])).data
        cos_4x = torch.mean(cos(original_mu, tup_4x[key_name])).data
        cos_5x = torch.mean(cos(original_mu, tup_5x[key_name])).data
        cos_6x = torch.mean(cos(original_mu, tup_6x[key_name])).data
        # print(cos_1x, cos_2x, cos_3x)
        return [
            [original_recon_loss.data, recon_loss_1x.data, recon_loss_2x.data,
             recon_loss_3x.data, recon_loss_4x.data, recon_loss_5x.data,
             recon_loss_6x.data],
            [cos_1x, cos_2x, cos_3x, cos_4x, cos_5x, cos_6x]]
Example #27
    def analysis_eval_word_importance(self, feed, batch, bit):
        """
        Given a sentence, replace a certain word by UNK and see how lat code change from the origin one.
        :param feed:
        :param batch:
        :param bit:
        :return:
        """
        seq_len, batch_sz = batch.size()
        target = GVar(batch)
        origin_feed = feed.clone()
        original_recon_loss, kld, _, original_tup, original_vecs, _ = self.model(origin_feed, target, bit)
        # original_vecs = torch.mean(original_vecs, dim=0).unsqueeze(2)
        original_mu = original_tup['mu']
        # table_of_code = torch.FloatTensor(seq_len, batch_sz )
        table_of_mu = torch.FloatTensor(seq_len, batch_sz)
        for t in range(seq_len):
            cur_feed = feed.clone()
            cur_feed[t, :] = 2
            cur_recon, _, _, cur_tup, cur_vec, _ = self.model(cur_feed, target, bit)

            cur_mu = cur_tup['mu']
            # cur_vec = torch.mean(cur_vec, dim=0).unsqueeze(2)
            # x = cos(original_vecs, cur_vec)
            # x= x.squeeze()
            y = cos(original_mu, cur_mu)
            y = y.squeeze()

            # table_of_code[t,:] = x.data
            table_of_mu[t, :] = y.data
        bag = []
        for b in range(batch_sz):
            weight = table_of_mu[:, b]
            word_ids = feed[:, b]
            words = self.ids_to_words(word_ids.data.tolist())
            seq_of_words = words.split(" ")
            s = ""
            for t in range(seq_len):
                if weight[t] < 0.98:
                    s += "*" + seq_of_words[t] + "* "
                else:
                    s += seq_of_words[t] + " "
            bag.append(s)
        return bag
Example #28
class vMF(torch.nn.Module):
    def __init__(self, hid_dim, lat_dim, kappa=1):
        super().__init__()
        self.hid_dim = hid_dim
        self.lat_dim = lat_dim
        self.kappa = kappa
        # self.func_kappa = torch.nn.Linear(hid_dim, lat_dim)
        self.func_mu = torch.nn.Linear(hid_dim, lat_dim)

        self.kld = GVar(torch.from_numpy(vMF._vmf_kld(kappa, lat_dim)).float())
        print('KLD: {}'.format(self.kld.data[0]))

    def estimate_param(self, latent_code):
        ret_dict = {}
        ret_dict['kappa'] = self.kappa

        # Only compute mu, use mu/mu_norm as mu,
        #  use 1 as norm, use diff(mu_norm, 1) as redundant_norm
        mu = self.func_mu(latent_code)

        norm = torch.norm(mu, 2, 1, keepdim=True)
        mu_norm_sq_diff_from_one = torch.pow(torch.add(norm, -1), 2)
        redundant_norm = torch.sum(mu_norm_sq_diff_from_one, dim=1, keepdim=True)
        ret_dict['norm'] = torch.ones_like(mu)
        ret_dict['redundant_norm'] = redundant_norm

        mu = mu / torch.norm(mu, p=2, dim=1, keepdim=True)
        ret_dict['mu'] = mu

        return ret_dict

    def compute_KLD(self, tup, batch_sz):
        return self.kld.expand(batch_sz)

    @staticmethod
    def _vmf_kld(k, d):
        tmp = (k * ((sp.iv(d / 2.0 + 1.0, k) + sp.iv(d / 2.0, k) * d / (2.0 * k)) / sp.iv(d / 2.0, k) - d / (2.0 * k)) \
               + d * np.log(k) / 2.0 - np.log(sp.iv(d / 2.0, k)) \
               - sp.loggamma(d / 2 + 1) - d * np.log(2) / 2).real
        if tmp != tmp:  # NaN guard (e.g., Bessel function overflow for large kappa or dim)
            exit()
        return np.array([tmp])

    def build_bow_rep(self, lat_code, n_sample):
        batch_sz = lat_code.size()[0]
        tup = self.estimate_param(latent_code=lat_code)
        mu = tup['mu']
        norm = tup['norm']
        kappa = tup['kappa']

        kld = self.compute_KLD(tup, batch_sz)
        vecs = []
        if n_sample == 1:
            return tup, kld, self.sample_cell(mu, norm, kappa)
        for n in range(n_sample):
            sample = self.sample_cell(mu, norm, kappa)
            vecs.append(sample)
        vecs = torch.cat(vecs, dim=0)
        return tup, kld, vecs

    def sample_cell(self, mu, norm, kappa):
        batch_sz, lat_dim = mu.size()
        result = []
        sampled_vecs = GVar(torch.FloatTensor(batch_sz, lat_dim))
        for b in range(batch_sz):
            this_mu = mu[b]
            # kappa = np.linalg.norm(this_theta)
            this_mu = this_mu / torch.norm(this_mu, p=2)

            w = self._sample_weight(kappa, lat_dim)
            w_var = GVar(w * torch.ones(lat_dim))

            v = self._sample_orthonormal_to(this_mu, lat_dim)

            scale_factr = torch.sqrt(GVar(torch.ones(lat_dim)) - torch.pow(w_var, 2))
            orth_term = v * scale_factr
            muscale = this_mu * w_var
            sampled_vec = orth_term + muscale
            sampled_vecs[b] = sampled_vec
            # sampled_vec = torch.FloatTensor(sampled_vec)
            # result.append(sampled_vec)

        return sampled_vecs.unsqueeze(0)

    def _sample_weight(self, kappa, dim):
        """Rejection sampling scheme for sampling distance from center on
        surface of the sphere.
        """
        dim = dim - 1  # since S^{n-1}
        b = dim / (np.sqrt(4. * kappa ** 2 + dim ** 2) + 2 * kappa)  # b= 1/(sqrt(4.* kdiv**2 + 1) + 2 * kdiv)
        x = (1. - b) / (1. + b)
        c = kappa * x + dim * np.log(1 - x ** 2)  # dim * (kdiv *x + np.log(1-x**2))

        while True:
            z = np.random.beta(dim / 2., dim / 2.)  # concentrates towards 0.5 as d-> inf
            w = (1. - (1. + b) * z) / (1. - (1. - b) * z)
            u = np.random.uniform(low=0, high=1)
            if kappa * w + dim * np.log(1. - x * w) - c >= np.log(
                    u):  # thresh is dim *(kdiv * (w-x) + log(1-x*w) -log(1-x**2))
                return w

    def _sample_orthonormal_to(self, mu, dim):
        """Sample point on sphere orthogonal to mu.
        """
        v = GVar(torch.randn(dim))
        rescale_value = mu.dot(v) / mu.norm()
        proj_mu_v = mu * rescale_value.expand(dim)
        ortho = v - proj_mu_v
        ortho_norm = torch.norm(ortho)
        return ortho / ortho_norm.expand_as(ortho)
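
A usage sketch for the vMF module above, assuming the class is defined in the same script with numpy as np and scipy.special as sp imported, and treating GVar as the identity wrapper it effectively is on PyTorch >= 0.4:

import numpy as np
import scipy.special as sp
import torch

GVar = lambda x: x      # stand-in for the repo's Variable wrapper (identity on PyTorch >= 0.4)

dist = vMF(hid_dim=64, lat_dim=32, kappa=50)     # prints the constant KLD for this (kappa, dim)
hidden = torch.randn(8, 64)                      # (batch_sz, hid_dim)
tup, kld, vecs = dist.build_bow_rep(hidden, n_sample=3)
print(vecs.size())                               # torch.Size([3, 8, 32])
print(vecs.norm(dim=2))                          # ~1.0 everywhere: samples lie on the unit sphere
print(kld.size())                                # torch.Size([8])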
Example #29
    def train_epo(self, args, model, train_batches, epo, epo_start_time):
        model.train()
        start_time = time.time()

        if self.args.optim == 'sgd':
            self.optim = torch.optim.SGD(model.parameters(),
                                         lr=self.args.cur_lr)
        else:
            raise NotImplementedError
        acc_loss = 0
        acc_kl_loss = 0
        acc_aux_loss = 0
        acc_avg_cos = 0
        acc_avg_norm = 0
        # acc_real_loss = 0

        word_cnt = 0
        doc_cnt = 0
        cnt = 0
        random.shuffle(train_batches)
        for idx, batch in enumerate(train_batches):
            self.optim.zero_grad()

            self.glob_iter += 1
            data_batch, count_batch = DataNg.fetch_data(
                self.data.train[0], self.data.train[1], batch,
                self.data.vocab_size)

            model.zero_grad()

            data_batch = GVar(torch.FloatTensor(data_batch))

            recon_loss, kld, aux_loss, tup, vecs = model(data_batch)
            # print("Recon: {}\t KL: {}".format(recon_loss,kld))
            # total_loss = torch.mean(recon_loss + kld * args.kl_weight)
            total_loss = torch.mean(recon_loss + kld * args.kl_weight +
                                    aux_loss * args.aux_weight)
            total_loss.backward()

            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            self.optim.step()

            count_batch = GVar(torch.FloatTensor(count_batch))
            doc_num = len(count_batch)

            # real_loss = torch.div((recon_loss + kld).data, count_batch)
            # acc_real_loss += torch.sum(real_loss)

            acc_loss += torch.sum(recon_loss).item()
            acc_kl_loss += torch.sum(kld).item()
            acc_aux_loss += torch.sum(aux_loss).item()
            acc_avg_cos += tup['avg_cos'].item()
            acc_avg_norm += tup['avg_norm'].item()
            cnt += 1

            count_batch = count_batch + 1e-12
            word_cnt += torch.sum(count_batch).item()
            doc_cnt += doc_num

            if idx % args.log_interval == 0 and idx > 0:
                cur_loss = acc_loss / word_cnt  # word loss
                cur_kl = acc_kl_loss / word_cnt
                cur_aux_loss = acc_aux_loss / word_cnt
                cur_avg_cos = acc_avg_cos / cnt
                cur_avg_norm = acc_avg_norm / cnt
                # cur_real_loss = acc_real_loss / doc_cnt
                cur_real_loss = cur_loss + cur_kl

                # if cur_kl < 0.14 or cur_kl > 1.2:
                #     raise KeyboardInterrupt

                Runner.log_instant(self.writer, self.args, self.glob_iter, epo,
                                   start_time, cur_avg_cos, cur_avg_norm,
                                   cur_loss, cur_kl, cur_aux_loss,
                                   cur_real_loss)
                acc_loss = 0
                acc_kl_loss = 0
                acc_aux_loss = 0
                acc_avg_cos = 0
                acc_avg_norm = 0
                word_cnt = 0
                doc_cnt = 0
                cnt = 0
            if idx % (3 * args.log_interval) == 0 and idx > 0:
                with torch.no_grad():
                    self.eval_interface()
Example #30
    def train_epo(self, args, model, train_batches, epo, epo_start_time,
                  glob_iter):
        model.train()
        model.FLAG_train = True
        start_time = time.time()

        if self.args.optim == 'sgd':
            self.optim = torch.optim.SGD(model.parameters(),
                                         lr=self.args.cur_lr)

        acc_loss = 0
        acc_kl_loss = 0
        acc_aux_loss = 0
        acc_avg_cos = 0
        acc_avg_norm = 0

        batch_cnt = 0
        all_cnt = 0
        cnt = 0

        random.shuffle(train_batches)
        for idx, batch in enumerate(train_batches):
            self.optim.zero_grad()
            seq_len, batch_sz = batch.size()
            if self.data.condition:
                seq_len -= 1

                if self.model.input_cd_bit > 1:
                    bit = batch[0, :]
                    bit = GVar(bit)
                else:
                    bit = None
                batch = batch[1:, :]
            else:
                bit = None
            feed = self.data.get_feed(batch)

            if self.args.swap > 0.00001:
                feed = swap_by_batch(feed, self.args.swap)
            if self.args.replace > 0.00001:
                feed = replace_by_batch(feed, self.args.replace,
                                        self.model.ntoken)

            self.glob_iter += 1

            target = GVar(batch)

            recon_loss, kld, aux_loss, tup, vecs, _ = model(feed, target, bit)
            total_loss = recon_loss * seq_len + torch.mean(
                kld) * self.args.kl_weight + torch.mean(
                    aux_loss) * args.aux_weight

            total_loss.backward()

            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            # torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)        # Upgrade to pytorch 0.4.1
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.clip,
                                           norm_type=2)

            self.optim.step()

            acc_loss += recon_loss.data * seq_len * batch_sz
            acc_kl_loss += torch.sum(kld).data
            acc_aux_loss += torch.sum(aux_loss).data
            acc_avg_cos += tup['avg_cos'].data
            acc_avg_norm += tup['avg_norm'].data

            cnt += 1
            batch_cnt += batch_sz
            all_cnt += batch_sz * seq_len
            if idx % args.log_interval == 0 and idx > 0:
                cur_loss = acc_loss.item() / all_cnt
                cur_kl = acc_kl_loss.item() / all_cnt
                # if cur_kl < 0.03:
                #     raise KeyboardInterrupt
                # if cur_kl > 0.7:
                #     raise KeyboardInterrupt
                cur_aux_loss = acc_aux_loss.item() / all_cnt
                cur_avg_cos = acc_avg_cos.item() / cnt
                cur_avg_norm = acc_avg_norm.item() / cnt
                cur_real_loss = cur_loss + cur_kl
                Runner.log_instant(self.writer, self.args, self.glob_iter, epo,
                                   start_time, cur_avg_cos, cur_avg_norm,
                                   cur_loss, cur_kl, cur_aux_loss,
                                   cur_real_loss)