コード例 #1
0
ファイル: hrnn.py プロジェクト: zengjichuan/DTDMN
    def forward(self, batch_data, mode=TRAIN):
        """Score a batch of (positive, negative) conversation pairs.

        Both sides go through the same hierarchical encoder: ``utt_encoder``
        over the word ids of each turn, then ``ctx_encoder`` over the turn
        vectors. ``predictor`` maps each conversation encoding to a logit.

        :param batch_data: Pack with ``batch_{pos,neg}_utts_seq``,
            ``batch_{pos,neg}_words_lens``, ``batch_{pos,neg}_lens`` and
            ``batch_{pos,neg}_masks``.
        :param mode: ``INFER`` returns ``Pack(pred=...)``, a boolean tensor
            that is True where the positive logit exceeds the negative one.
            Any other mode returns ``Pack(bce_loss=...)`` from
            ``self.pair_bce_loss``.
        """
        # Only the sequence inputs are used; the bag-of-words tensors and the
        # context attention weights were unpacked previously but never read.
        pos_c_outs = self._encode_conversation(batch_data.batch_pos_utts_seq,
                                               batch_data.batch_pos_words_lens,
                                               batch_data.batch_pos_lens,
                                               batch_data.batch_pos_masks)
        neg_c_outs = self._encode_conversation(batch_data.batch_neg_utts_seq,
                                               batch_data.batch_neg_words_lens,
                                               batch_data.batch_neg_lens,
                                               batch_data.batch_neg_masks)

        pos_pred_logit = self.predictor(pos_c_outs)
        neg_pred_logit = self.predictor(neg_c_outs)

        if mode == INFER:
            return Pack(pred=pos_pred_logit > neg_pred_logit)

        bce_loss = self.pair_bce_loss(pos_pred_logit, neg_pred_logit)
        return Pack(bce_loss=bce_loss)

    def _encode_conversation(self, utts_seq, words_lens, lens, masks):
        """Run utt_encoder then ctx_encoder on one side and squeeze dim 0."""
        c_inputs, _, _, _ = self.utt_encoder(utts_seq, words_lens,
                                             return_all=True)
        c_outs, _, _, _ = self.ctx_encoder(c_inputs, lens=lens, masks=masks,
                                           return_all=True)
        return c_outs.squeeze(0)
コード例 #2
0
ファイル: data_loaders.py プロジェクト: zengjichuan/DTDMN
    def permute_dialog(self, data):
        """Expand each dialog into (positive, negative) conversation pairs.

        Every positive conversation of a dialog is paired with every negative
        one. A pair is kept only when the positive side has no more turns than
        the negative side; the negative side is then truncated to the positive
        side's turn count.

        * ``self.name == "train"``: every kept pair is emitted.
        * ``self.name == "test"``: one kept pair per dialog is chosen with
          ``random.choice``; dialogs without a kept pair are skipped.
        * any other name: nothing is emitted.

        :param data: dialogs exposing ``op``, ``title``, ``pos_conv_bow_lst``,
            ``pos_conv_seq_lst``, ``neg_conv_bow_lst`` and ``neg_conv_seq_lst``.
        :return: list of Pack(op, title, pos_utts_bow, pos_utts_seq,
            neg_utts_bow, neg_utts_seq).
        """
        rst = []
        for dialog in data:
            if self.name == "train":
                candidates = self._length_filtered_pairs(dialog)
            elif self.name == "test":
                # Materialise in product order so random.choice sees the same
                # list (and consumes the RNG the same way) as before.
                kept = list(self._length_filtered_pairs(dialog))
                if not kept:
                    continue
                candidates = [random.choice(kept)]
            else:
                continue
            for pos_utts_bow, pos_utts_seq, neg_utts_bow, neg_utts_seq in candidates:
                rst.append(
                    Pack(op=dialog.op,
                         title=dialog.title,
                         pos_utts_bow=pos_utts_bow,
                         pos_utts_seq=pos_utts_seq,
                         neg_utts_bow=neg_utts_bow,
                         neg_utts_seq=neg_utts_seq))

        logger.info("%d conversation pairs after product", len(rst))
        return rst

    @staticmethod
    def _length_filtered_pairs(dialog):
        """Yield (pos_bow, pos_seq, neg_bow, neg_seq) for pairs passing the length filter."""
        pos_convs = zip(dialog.pos_conv_bow_lst, dialog.pos_conv_seq_lst)
        neg_convs = zip(dialog.neg_conv_bow_lst, dialog.neg_conv_seq_lst)
        for (pos_bow, pos_seq), (neg_bow, neg_seq) in itertools.product(
                pos_convs, neg_convs):
            n_turns = len(pos_bow)
            if n_turns <= len(neg_bow):
                yield pos_bow, pos_seq, neg_bow[:n_turns], neg_seq[:n_turns]
コード例 #3
0
ファイル: hrnn.py プロジェクト: zengjichuan/DTDMN
    def epoch_forward(self, batch_data, mask):
        """Run the per-step VAE over every time step of a batch.

        ``batch_data`` is sliced as ``batch_data[:, t, :]``, so dim 1 is the
        time axis. Each step goes through ``qzx_forward`` then ``pxz_forward``;
        the per-step ``x_logit``, ``sample_z``, ``z_mu`` and ``z_logvar`` are
        stacked along a new dim 1. The prediction logit is ``predictor`` applied
        to ``mask_mean`` of the sampled latent sequence.
        """
        step_resps = [
            self.pxz_forward(self.qzx_forward(batch_data[:, step, :]))
            for step in range(batch_data.size(1))
        ]

        # stack(dim=1) == cat of unsqueeze(1) along dim 1
        x_logit_seq, z_sample_seq, z_mu_seq, z_logvar_seq = (
            torch.stack([getattr(resp, field) for resp in step_resps], dim=1)
            for field in ("x_logit", "sample_z", "z_mu", "z_logvar"))

        pred_logit = self.predictor(mask_mean(z_sample_seq, mask))
        return Pack(x_logit_seq=x_logit_seq,
                    z_mu_seq=z_mu_seq,
                    z_logvar_seq=z_logvar_seq,
                    z_sample_seq=z_sample_seq,
                    pred_logit=pred_logit)
コード例 #4
0
    def qzc_forward(self, ctx_utts):
        """Encode context utterances into a Gaussian posterior and sample z.

        :param ctx_utts: input passed unchanged to ``self.ctx_encoder``.
        :return: Pack(sample_z, z_mu, z_logvar); ``z_mu``/``z_logvar`` come from
            ``q_z_mu``/``q_z_logvar`` on the tanh-squashed encoding and
            ``sample_z`` from ``self.reparameterize(z_mu, z_logvar)``.
        """
        # Tensor.tanh() instead of the deprecated torch.nn.functional.tanh.
        ctx_out = self.ctx_encoder(ctx_utts).tanh()
        z_mu = self.q_z_mu(ctx_out)
        z_logvar = self.q_z_logvar(ctx_out)

        sample_z = self.reparameterize(z_mu, z_logvar)
        return Pack(sample_z=sample_z, z_mu=z_mu, z_logvar=z_logvar)
コード例 #5
0
ファイル: main.py プロジェクト: justinchiu/hmmlm-jax
def eval_loop(
    args,
    V,
    iter,
    model,
):
    """Accumulate evidence and ELBO of ``model`` over ``iter`` without gradients.

    The model is put in eval mode (and ``noise_scale`` zeroed when present)
    once, before the first batch. With ``args.iterator == "bptt"`` the
    ``lpz``/``last_states`` returned by ``model.score`` for one batch are fed
    into the next, the same way ``train_loop`` threads them; for any other
    iterator both are reset to None per batch.

    :return: (Pack(evidence=..., elbo=...), n) where ``n`` is the total token
        count reported by ``get_mask_lengths``.
    """
    total_ll = 0
    total_elbo = 0
    n = 0
    lpz, last_states = None, None
    with th.no_grad():
        # Loop-invariant setup, previously repeated for every batch.
        model.train(False)
        if hasattr(model, "noise_scale"):
            model.noise_scale = 0
        for batch in iter:
            mask, lengths, n_tokens = get_mask_lengths(batch.text, V)
            if args.iterator != "bptt":
                lpz, last_states = None, None
            # Keep the returned end states (previously discarded as `_`), so
            # bptt carry-over passes lpz together with its matching
            # last_states instead of last_states=None.
            losses, lpz, last_states = model.score(
                batch.text,
                lpz=lpz,
                last_states=last_states,
                mask=mask,
                lengths=lengths,
            )
            total_ll += losses.evidence.detach()
            if losses.elbo is not None:
                total_elbo += losses.elbo.detach()
            n += n_tokens
    return Pack(evidence=total_ll, elbo=total_elbo), n
コード例 #6
0
ファイル: hrnn.py プロジェクト: zengjichuan/DTDMN
    def qzx_forward(self, batch_utt):
        """Sample z from the Gaussian posterior of an utterance batch.

        ``batch_utt`` is encoded by ``self.x_encoder`` and squashed with tanh.
        ``q_z_mu``/``q_z_logvar`` give the posterior parameters and
        ``self.reparameterize`` draws the sample.

        :return: Pack(sample_z, z_mu, z_logvar)
        """
        hidden = torch.tanh(self.x_encoder(batch_utt))
        mu = self.q_z_mu(hidden)
        logvar = self.q_z_logvar(hidden)
        return Pack(sample_z=self.reparameterize(mu, logvar),
                    z_mu=mu,
                    z_logvar=logvar)
コード例 #7
0
ファイル: data_loaders.py プロジェクト: zengjichuan/DTDMN
 def permute_dialog(self, data):
     """Expand each dialog into every (positive, negative) conversation pair.

     All positive conversations of a dialog are paired with all negative
     ones. In each pair the side with more turns is truncated to the other
     side's turn count, so both sides have equal length.

     :param data: dialogs exposing ``pos_conv_bow_lst``, ``pos_conv_seq_lst``,
         ``neg_conv_bow_lst`` and ``neg_conv_seq_lst``.
     :return: list of Pack(pos_utts_bow, pos_utts_seq, neg_utts_bow,
         neg_utts_seq).
     """
     rst = []
     for dialog in data:
         pos_convs = zip(dialog.pos_conv_bow_lst, dialog.pos_conv_seq_lst)
         neg_convs = zip(dialog.neg_conv_bow_lst, dialog.neg_conv_seq_lst)
         # Iterate the product directly; there is no need to build a list.
         for (pos_utts_bow, pos_utts_seq), (neg_utts_bow, neg_utts_seq) in \
                 itertools.product(pos_convs, neg_convs):
             # length filter, here we truncate the longer one
             pos_turns_len = len(pos_utts_bow)
             neg_turns_len = len(neg_utts_bow)
             if pos_turns_len <= neg_turns_len:
                 neg_utts_bow = neg_utts_bow[:pos_turns_len]
                 neg_utts_seq = neg_utts_seq[:pos_turns_len]
             else:
                 pos_utts_bow = pos_utts_bow[:neg_turns_len]
                 pos_utts_seq = pos_utts_seq[:neg_turns_len]
             rst.append(
                 Pack(pos_utts_bow=pos_utts_bow,
                      pos_utts_seq=pos_utts_seq,
                      neg_utts_bow=neg_utts_bow,
                      neg_utts_seq=neg_utts_seq))
     logger.info("%d conversation pairs after product", len(rst))
     return rst
コード例 #8
0
ファイル: dtdmn.py プロジェクト: zengjichuan/DTDMN
    def forward(self, w_corr, utt_emb, pre_mem, wr_corr=None):
        """
        Write an utterance into memory, then read from the updated memory.

        Write: an erase vector e = sigmoid(e_conn(utt_emb)) and an add vector
        a = tanh(a_conn(utt_emb)) are spread over the memory slots by the
        correlation weights w: mem = pre_mem * (1 - w e^T) + w a^T.
        Read: a weighted sum of the updated memory rows, using ``wr_corr``
        when it is given and ``w_corr`` otherwise.

        :param w_corr: correlation (write) weight, viewed as (batch_size, mem_size)
        :param utt_emb: utterance embedding (batch_size, emb_size)
        :param pre_mem: previous memory; must broadcast against
            (batch_size, mem_size, mem_state_size), e.g. (mem_size, mem_state_size)
        :param wr_corr: optional read weight (batch_size, mem_size); defaults to w_corr
        :return: Pack(mem=(batch_size, mem_size, mem_state_size),
                      read_content=(batch_size, mem_state_size))
        """
        w_col = w_corr.view(-1, self.mem_size, 1)

        # write process (torch.sigmoid replaces the deprecated F.sigmoid)
        e_f = torch.sigmoid(self.e_conn(utt_emb))  # (batch_size, mem_state_size)
        e_w = 1 - torch.bmm(w_col, e_f.view(-1, 1, self.mem_state_size))  # (batch_size, mem_size, mem_state_size)
        a_f = torch.tanh(self.a_conn(utt_emb))  # (batch_size, mem_state_size)
        a_w = torch.bmm(w_col, a_f.view(-1, 1, self.mem_state_size))
        mem = pre_mem * e_w + a_w

        # read process
        read_w = w_corr if wr_corr is None else wr_corr
        read_content = torch.bmm(read_w.view(-1, 1, self.mem_size),
                                 mem).view(-1, self.mem_state_size)
        return Pack(mem=mem, read_content=read_content)
コード例 #9
0
ファイル: dtdmn.py プロジェクト: zengjichuan/DTDMN
 def get_batch(self, data_feed):
     """
     Convert a numpy data feed into model tensors.

     Each ``data_feed.<field>`` is converted with ``self.np2var`` and stored
     under ``batch_<field>``. Word-id sequences, turn lengths and word lengths
     use LONG; bag-of-words vectors and masks use FLOAT.

     :param data_feed: Pack carrying the pos_/neg_ fields listed below.
     :return: Pack with the same fields, each prefixed by ``batch_``.
     """
     field_types = (
         ("pos_utts_seq", LONG),
         ("neg_utts_seq", LONG),
         ("pos_utts_bow", FLOAT),
         ("neg_utts_bow", FLOAT),
         ("pos_masks", FLOAT),
         ("neg_masks", FLOAT),
         ("pos_lens", LONG),
         ("neg_lens", LONG),
         ("pos_words_lens", LONG),
         ("neg_words_lens", LONG),
     )
     return Pack(**{
         "batch_" + field: self.np2var(getattr(data_feed, field), dtype)
         for field, dtype in field_types
     })
コード例 #10
0
ファイル: corpora.py プロジェクト: zengjichuan/DTDMN
    def _to_id_corpus_bow(self, data, vocab):
        """Convert raw dialogs into bag-of-words id form.

        Every turn is converted with ``self._sent2id_bow``. Turns that convert
        to an empty result are dropped, and conversations left without turns
        are dropped. A dialog is kept only if it still has at least one
        positive and one negative conversation.

        :param data: iterable of dicts with ``"pos_conv_lst"`` and
            ``"neg_conv_lst"`` (lists of conversations, each a list of turns).
        :param vocab: vocabulary passed through to ``_sent2id_bow``.
        :return: (dialogs, msg_cnt, word_cnt), where each dialog is
            Pack(pos_conv_bow_lst, neg_conv_bow_lst). The counters include
            turns of dialogs that were later discarded; ``word_cnt`` sums the
            second element of every pair in each kept turn.
        """
        results = []
        word_cnt = 0
        msg_cnt = 0

        for dialog in data:
            pos_convs, pos_msgs, pos_words = self._convs_to_bow_ids(
                dialog["pos_conv_lst"], vocab)
            neg_convs, neg_msgs, neg_words = self._convs_to_bow_ids(
                dialog["neg_conv_lst"], vocab)
            msg_cnt += pos_msgs + neg_msgs
            word_cnt += pos_words + neg_words
            if pos_convs and neg_convs:
                results.append(Pack(pos_conv_bow_lst=pos_convs,
                                    neg_conv_bow_lst=neg_convs))
        print("Load bow with %d msgs, %d words" % (msg_cnt, word_cnt))
        return results, msg_cnt, word_cnt

    def _convs_to_bow_ids(self, convs, vocab):
        """Convert one list of conversations; return (id_convs, msg_cnt, word_cnt)."""
        id_convs = []
        msg_cnt = 0
        word_cnt = 0
        for turns in convs:
            new_turns = []
            for turn in turns:
                id_turn = self._sent2id_bow(turn, vocab)
                if id_turn:  # filter empty utt
                    new_turns.append(id_turn)
                    word_cnt += np.sum([cnt for _, cnt in id_turn])
                    msg_cnt += 1
            if new_turns:
                id_convs.append(new_turns)
        return id_convs, msg_cnt, word_cnt
コード例 #11
0
ファイル: corpora.py プロジェクト: zengjichuan/DTDMN
 def get_corpus(self):
     """Return the train and test splits converted by ``self._to_id_corpus``.

     Both splits are converted with ``self.vocab_seq`` and ``self.vocab_bow``.
     Each split value is exactly what ``_to_id_corpus`` returns.

     :return: Pack(train, test, vocab_size=len(self.vocab_bow))
     """
     id_train, id_test = (
         self._to_id_corpus(corpus, self.vocab_seq, self.vocab_bow)
         for corpus in (self.train_corpus, self.test_corpus))
     return Pack(train=id_train, test=id_test, vocab_size=len(self.vocab_bow))
コード例 #12
0
ファイル: corpora.py プロジェクト: zengjichuan/DTDMN
 def get_corpus_bow(self, keep_stopwords=True):
     """Return the train and test splits in bag-of-words id form.

     :param keep_stopwords: use ``self.vocab_bow`` when True, otherwise
         ``self.vocab_bow_non_stopwords``.
     :return: Pack(train, test, vocab_size) where both splits come from
         ``self._to_id_corpus_bow`` and ``vocab_size`` is the chosen
         vocabulary's length.
     """
     vocab = self.vocab_bow if keep_stopwords else self.vocab_bow_non_stopwords
     id_train, id_test = (self._to_id_corpus_bow(corpus, vocab)
                          for corpus in (self.train_corpus, self.test_corpus))
     return Pack(train=id_train, test=id_test, vocab_size=len(vocab))
コード例 #13
0
    def qdx_forward(self, tar_utts):
        """Sample discrete latent codes d for target utterances.

        The ``x_encoder`` output is viewed as rows of ``config.d`` logits,
        repeated ``config.d_size`` times along a new leading dim and passed to
        ``cat_connector``. The second argument is 1.0, presumably a sampling
        temperature (confirm). The samples are averaged over the repeats, and
        the max ids are reshaped to ``(d_size, -1)`` and transposed.

        :return: Pack(qd_logits, sample_d, d_ids)
        """
        d_size = self.config.d_size
        logits = self.x_encoder(tar_utts).view(-1, self.config.d)
        repeated = logits.repeat(d_size, 1, 1)
        samples, max_ids = self.cat_connector(repeated,
                                              1.0,
                                              self.use_gpu,
                                              return_max_id=True)
        return Pack(qd_logits=logits,
                    sample_d=samples.mean(0),
                    d_ids=max_ids.view(d_size, -1).transpose(0, 1))
コード例 #14
0
ファイル: dtdmn.py プロジェクト: zengjichuan/DTDMN
 def forward(self, bow_data):
     """Encode bag-of-words input into a Gaussian latent and decode it.

     ``x_encoder`` output is tanh-squashed, and ``q_z_mu``/``q_z_logvar`` give
     the posterior. A reparameterised sample goes through ``generator``;
     ``x_decoder`` maps the generated vector back to ``x_logit``.

     :return: Pack(z_mu, z_logvar, z_softmax, x_logit, z_gen), where
         ``z_softmax`` is the softmax of ``z_gen`` over its last dimension.
     """
     hidden = torch.tanh(self.x_encoder(bow_data))
     z_mu = self.q_z_mu(hidden)
     z_logvar = self.q_z_logvar(hidden)
     z_gen = self.generator(self.reparameterize(z_mu, z_logvar))
     z_softmax = F.softmax(z_gen, dim=-1)  # last dimension
     x_logit = self.x_decoder(z_gen)
     return Pack(z_mu=z_mu,
                 z_logvar=z_logvar,
                 z_softmax=z_softmax,
                 x_logit=x_logit,
                 z_gen=z_gen)
コード例 #15
0
    def __init__(self, sp_id, sp_secret):
        """Build a CMPP_CONNECT_REQ request from SP credentials.

        ``self.auth_source`` is ``utils.get_md5_digest`` over
        ``sp_id + 9 zero bytes + sp_secret + timestamp``, with every text part
        UTF-8 encoded. The message body is
        ``sp_id + auth_source + version + timestamp``, where the version is
        packed with ``Pack.get_unsigned_char_data(CMPP_VERSION)`` and the
        timestamp string is packed as an integer with
        ``Pack.get_unsigned_long_data``.
        """
        version_bytes = Pack.get_unsigned_char_data(CMPP_VERSION)

        id_bytes = sp_id.encode('utf-8')
        secret_bytes = sp_secret.encode("utf-8")
        timestamp = utils.get_string_time()

        zero_pad = 9 * b'\x00'
        digest_input = (id_bytes + zero_pad + secret_bytes +
                        timestamp.encode("utf-8"))
        self.auth_source = utils.get_md5_digest(digest_input)

        packed_time = Pack.get_unsigned_long_data(int(timestamp))
        body = id_bytes + self.auth_source + version_bytes + packed_time
        RequestInstance.__init__(self, CMPP_CONNECT_REQ, body)
コード例 #16
0
 def flatten_dialog(self, data, window_size):
     """Turn every turn of every dialog into a (context window, target) sample.

     For turn ``i`` the window starts at ``max(0, i - window_size // 2)`` and
     ends at ``min(len(dialog), start + window_size)`` (exclusive). The target
     is a shallow copy of the turn; context turns are not copied.

     :return: list of Pack(context=[turns], target=turn_copy)
     """
     samples = []
     for turns in data:
         n_turns = len(turns)
         for idx, turn in enumerate(turns):
             start = max(0, idx - window_size // 2)
             stop = min(n_turns, start + window_size)
             samples.append(Pack(context=list(turns[start:stop]),
                                 target=copy.copy(turn)))
     return samples
コード例 #17
0
    def compute_loss(
        self,
        log_potentials,
        mask,
        lengths,
        keep_counts=False,
    ):
        """Compute evidence and ELBO for a batch with forward-backward.

        :param log_potentials: passed (as a clone) to ``self.fb_train`` or
            ``self.fb_test``, chosen by ``self.training``, and multiplied by
            the exponentiated ``log_m`` for the ELBO.
        :param mask: boolean mask; when ``self.eval_shorter`` its first
            column is dropped (``mask[:, 1:]``). The same view is given to the
            forward-backward and used to select ELBO terms.
        :param lengths: per-sequence lengths; the final alpha is read at
            ``lengths - 1`` (``eval_shorter``) or at ``lengths``.
        :param keep_counts: unused; kept for interface compatibility.
        :return: (Pack(elbo, evidence, loss=elbo), alpha_T.log_softmax(-1))
        """
        # The former `blah` debug path (strux comparison, writing test.npy,
        # pdb.set_trace) was unreachable and has been removed.
        N = lengths.shape[0]
        fb = self.fb_train if self.training else self.fb_test
        eff_mask = mask[:, 1:] if self.eval_shorter else mask
        log_m, alphas = fb(log_potentials.clone(), mask=eff_mask)

        idx = th.arange(N, device=self.device)
        alpha_T = alphas[lengths - 1 if self.eval_shorter else lengths, idx]
        evidence = alpha_T.logsumexp(-1).sum()
        # exp_() converts log_m in place (presumably log-marginals -> marginals).
        elbo = (log_m.exp_() * log_potentials)[eff_mask]
        # x == x is False only for NaN: drop NaN terms before summing.
        elbo = elbo[elbo == elbo].sum()

        return Pack(
            elbo=elbo,
            evidence=evidence,
            loss=elbo,
        ), alpha_T.log_softmax(-1)
コード例 #18
0
ファイル: dtdmn.py プロジェクト: zengjichuan/DTDMN
 def forward(self, bow_data):
     """Encode bag-of-words input into discrete latent codes and decode it.

     ``x_encoder`` output is viewed as rows of ``disc_num`` logits, repeated
     ``disc_size`` times and sampled with ``cat_connector``. The second
     argument is 1.0, presumably a temperature (confirm). The averaged
     samples go through ``generator``, and ``x_decoder`` maps the result to
     ``x_logit``.

     :return: Pack(qy_logit, y_ids, y_softmax, x_logit, y_gen), where
         ``y_softmax`` is the softmax of ``y_gen`` over its last dimension.
     """
     logits = self.x_encoder(bow_data).view(-1, self.disc_num)
     samples, max_ids = self.cat_connector(logits.repeat(self.disc_size, 1, 1),
                                           1.0,
                                           self.use_gpu,
                                           return_max_id=True)
     mean_sample = samples.mean(0)
     y_ids = max_ids.view(self.disc_size, -1).transpose(0, 1)
     y_gen = self.generator(mean_sample)
     y_softmax = F.softmax(y_gen, dim=-1)  # last dimension
     x_logit = self.x_decoder(y_gen)
     return Pack(qy_logit=logits,
                 y_ids=y_ids,
                 y_softmax=y_softmax,
                 x_logit=x_logit,
                 y_gen=y_gen)
コード例 #19
0
ファイル: corpora.py プロジェクト: zengjichuan/DTDMN
    def _to_id_corpus(self, data, vocab_seq, vocab_bow):
        """Convert raw dialogs into parallel word-id sequence and bag-of-words form.

        ``title`` and ``op`` are converted with ``_sent2id_seq``. Each
        conversation turn is converted with both ``_sent2id_seq`` and
        ``_sent2id_bow`` and kept only if both results are non-empty.
        Conversations left without turns are dropped, and a dialog is kept only
        if it still has at least one positive and one negative conversation.

        :param data: iterable of dicts with ``"title"``, ``"op"``,
            ``"pos_conv_lst"`` and ``"neg_conv_lst"``.
        :return: (dialogs, msg_cnt, word_cnt). ``word_cnt`` sums the sequence
            lengths of kept turns; both counters include turns of dialogs that
            were later discarded.
        """
        results = []
        word_cnt = 0
        msg_cnt = 0

        for dialog in data:
            title = self._sent2id_seq(dialog["title"], vocab_seq)
            op = self._sent2id_seq(dialog["op"], vocab_seq)
            pos_seq, pos_bow, pos_msgs, pos_words = self._convs_to_seq_bow_ids(
                dialog["pos_conv_lst"], vocab_seq, vocab_bow)
            neg_seq, neg_bow, neg_msgs, neg_words = self._convs_to_seq_bow_ids(
                dialog["neg_conv_lst"], vocab_seq, vocab_bow)
            msg_cnt += pos_msgs + neg_msgs
            word_cnt += pos_words + neg_words
            if pos_bow and neg_bow:
                results.append(Pack(title=title,
                                    op=op,
                                    pos_conv_seq_lst=pos_seq,
                                    pos_conv_bow_lst=pos_bow,
                                    neg_conv_seq_lst=neg_seq,
                                    neg_conv_bow_lst=neg_bow))
        print("Load seq with %d msgs, %d words" % (msg_cnt, word_cnt))
        return results, msg_cnt, word_cnt

    def _convs_to_seq_bow_ids(self, convs, vocab_seq, vocab_bow):
        """Convert one list of conversations; return (seq_convs, bow_convs, msg_cnt, word_cnt)."""
        seq_convs, bow_convs = [], []
        msg_cnt = 0
        word_cnt = 0
        for turns in convs:
            new_turns_seq, new_turns_bow = [], []
            for turn in turns:
                id_turn_seq = self._sent2id_seq(turn, vocab_seq)
                id_turn_bow = self._sent2id_bow(turn, vocab_bow)
                if id_turn_seq and id_turn_bow:  # filter empty utt
                    new_turns_bow.append(id_turn_bow)
                    new_turns_seq.append(id_turn_seq)
                    word_cnt += len(id_turn_seq)
                    msg_cnt += 1
            if new_turns_seq and new_turns_bow:
                seq_convs.append(new_turns_seq)
                bow_convs.append(new_turns_bow)
        return seq_convs, bow_convs, msg_cnt, word_cnt
コード例 #20
0
ファイル: main.py プロジェクト: justinchiu/hmmlm-jax
def cached_eval_loop(
    args,
    V,
    iter,
    model,
):
    """Evaluate ``model`` on ``iter`` with parameters computed once up front.

    ``model.compute_parameters`` runs a single time and its
    ``(start, transition, emission)`` are reused for every batch via
    ``model.clamp`` (when ``model.eval_shorter``) or ``model.clamp2``. With
    ``args.iterator == "bptt"`` the start distribution of each later batch is
    rebuilt from the previous batch's ``lpz`` and last states.

    :return: (Pack(evidence=..., elbo=...), n) with ``n`` the total token
        count from ``get_mask_lengths``.
    """
    total_ll = 0
    total_elbo = 0
    n = 0
    with th.no_grad():
        model.train(False)
        # Zero the noise BEFORE computing the cached parameters. It used to be
        # zeroed only inside the batch loop, i.e. after compute_parameters had
        # already run with whatever noise_scale was left over (assumes
        # compute_parameters reads noise_scale - confirm).
        if hasattr(model, "noise_scale"):
            model.noise_scale = 0
        start, transition, emission = model.compute_parameters(
            model.word2state)
        word2state = model.word2state
        lpz = None
        # Bound up front: it is only updated when word2state is not None, and
        # reading it unbound raised UnboundLocalError in bptt mode.
        last_states = None
        for batch in iter:
            text = batch.text

            mask, lengths, n_tokens = get_mask_lengths(text, V)
            N, _ = text.shape

            if (lpz is not None and last_states is not None
                    and args.iterator == "bptt"):
                start = (lpz[:, :, None] +
                         transition[last_states, :]).logsumexp(1)

            log_potentials = (model.clamp(
                text, start, transition, emission,
                word2state) if model.eval_shorter else model.clamp2(
                    text, start, transition, emission, word2state))
            losses, lpz = model.compute_loss(log_potentials, mask, lengths)

            if word2state is not None:
                idx = th.arange(N, device=model.device)
                last_words = text[idx, lengths - 1]
                last_states = model.word2state[last_words]

            total_ll += losses.evidence.detach()
            if losses.elbo is not None:
                total_elbo += losses.elbo.detach()
            n += n_tokens
    return Pack(evidence=total_ll, elbo=total_elbo), n
コード例 #21
0
ファイル: factoredhmmlm.py プロジェクト: sustcsonglin/hmm-lm
    def score(
        self,
        text,
        lpz=None,
        last_states=None,
        mask=None,
        lengths=None,
    ):
        """Score a batch and return (losses, final-state log-probs, end states).

        In training mode, ``train_states_per_word`` states per row of
        ``self.cluster2state`` are picked by Gumbel-noise top-k. In eval mode
        ``states`` is None. Forward-backward runs under ``th.no_grad()`` on a
        detached clone, so gradients reach ``log_potentials`` only through the
        ELBO product.

        :param text: token ids, unpacked as (N, T).
        :param lpz: carried over from the previous batch, or None.
        :param last_states: carried over from the previous batch, or None.
        :param mask: token mask; ``mask[:, 1:]`` selects the ELBO terms.
        :param lengths: per-sequence lengths; alphas at ``lengths - 1`` are final.
        :return: (Pack(elbo, evidence, loss=elbo), alpha_T.log_softmax(-1),
            end_states), where ``end_states`` is the row of the cluster->state
            table for each sequence's last word's cluster.
        """
        N, T = text.shape
        if self.training:
            I = (th.distributions.Gumbel(self.zero, self.one).sample(
                self.cluster2state.shape).squeeze(-1).topk(
                    self.train_states_per_word, dim=-1).indices)
            states = self.cluster2state.gather(1, I).view(-1)
        else:
            states = None
        if self.timing:
            startpot = timep.time()
        log_potentials = self.log_potentials(
            text,
            states,
            lpz,
            last_states,
        )
        if self.timing:
            print(f"log pot: {timep.time() - startpot}")
        fb = self.fb_train if self.training else self.fb_test
        with th.no_grad():
            log_m, alphas = fb(log_potentials.detach().clone(), mask=mask)
        idx = th.arange(N, device=self.device)
        alpha_T = alphas[lengths - 1, idx]
        evidence = alpha_T.logsumexp(-1).sum()
        elbo = (log_m.exp_() * log_potentials)[mask[:, 1:]].sum()

        last_words = text[idx, lengths - 1]
        # In eval mode no states were sampled (states is None), so
        # states.view() raised AttributeError. Fall back to the full
        # cluster -> state table.
        if states is None:
            c2s = self.cluster2state
        else:
            c2s = states.view(self.config.num_clusters, -1)
        end_states = c2s[self.word2cluster[last_words]]

        return Pack(
            elbo=elbo,
            evidence=evidence,
            loss=elbo,
        ), alpha_T.log_softmax(-1), end_states
コード例 #22
0
    def _prepare_batch(self, selected_index):
        """Build a bag-of-words batch of (context window, target) samples.

        :param selected_index: indices into ``self.data``. Each row has
            ``context`` (turns with ``.utt``) and ``target`` (a turn with
            ``.utt``, ``.hashtag`` and ``.meta``).
        :return: Pack with
            contexts      int32 (batch, max_context_len, vocab_size), zero-padded
            context_lens  (batch,) number of context turns per row
            targets       int32 (batch, vocab_size)
            targets_lens  (batch,) ``len(target.utt)`` per row
            metas, hashtags  lists taken from the targets
            An empty ``selected_index`` yields empty arrays instead of failing
            in ``np.max``.
        """
        rows = [self.data[idx] for idx in selected_index]
        context_lens = [len(row.context) for row in rows]

        vec_context = np.zeros(
            (len(rows), max(context_lens, default=0), self.vocab_size),
            dtype=np.int32)
        vec_targets = np.zeros((len(rows), self.vocab_size), dtype=np.int32)

        for b_id, row in enumerate(rows):
            vec_targets[b_id, :] = self._bow2vec(row.target.utt,
                                                 self.vocab_size)
            # Fill the context rows directly; the old np.empty + fill(0)
            # float temporary was copied into the same slots anyway.
            for i, turn in enumerate(row.context):
                vec_context[b_id, i, :] = self._bow2vec(turn.utt,
                                                        self.vocab_size)

        return Pack(contexts=vec_context,
                    context_lens=np.array(context_lens),
                    targets=vec_targets,
                    targets_lens=np.array([len(row.target.utt) for row in rows]),
                    metas=[row.target.meta for row in rows],
                    hashtags=[row.target.hashtag for row in rows])
コード例 #23
0
ファイル: factoredhmmlm.py プロジェクト: sustcsonglin/hmm-lm
    def compute_loss(
        self,
        log_potentials,
        mask,
        lengths,
        keep_counts=False,
    ):
        """Compute evidence and ELBO for a batch with forward-backward.

        ``self.fb_train`` or ``self.fb_test`` (chosen by ``self.training``)
        runs on a clone of ``log_potentials``. The final alphas are read at
        ``lengths - 1``; the ELBO sums exponentiated ``log_m`` times
        ``log_potentials`` over ``mask[:, 1:]``.

        :param keep_counts: unused; kept for interface compatibility.
        :return: (Pack(elbo, evidence, loss=elbo), alpha_T.log_softmax(-1))
        """
        batch_size = lengths.shape[0]
        forward_backward = self.fb_train if self.training else self.fb_test
        log_m, alphas = forward_backward(log_potentials.clone(), mask=mask)

        batch_idx = th.arange(batch_size, device=self.device)
        final_alphas = alphas[lengths - 1, batch_idx]
        evidence = final_alphas.logsumexp(-1).sum()
        weighted = log_m.exp_() * log_potentials
        elbo = weighted[mask[:, 1:]].sum()

        losses = Pack(elbo=elbo, evidence=evidence, loss=elbo)
        return losses, final_alphas.log_softmax(-1)
コード例 #24
0
def get_pack(set_number, pack_number):
    """Render the annotation page for pack ``pack_number`` of set ``set_number``.

    Requires ``session['username']``; otherwise a link to /login is returned.
    Returns a 404 response when no matching pack exists. The pack's tweets are
    loaded in numeric id order. A tweet whose ``replies_to`` (or, failing
    that, ``quoting_id``) refers to an already loaded tweet is attached to
    that tweet's ``replies``; all other tweets are top level.
    """
    try:
        username = session['username']
    except KeyError:
        # The bold tag was previously written as a stray closing </b>.
        return ("You are not logged in <br><a href = '/login'><b>"
                "click here to log in</b></a>")

    # Get pack information
    cursor.execute(
        "select id_pack, cnt, s.set_nb, pos, s.size from pack p join sets s "
        "on s.set_nb = p.set_nb where p.set_nb == ? AND p.pos == ?",
        [set_number, pack_number])
    p = cursor.fetchone()
    if p is None:
        # fetchone() returns None for an unknown pack; Pack(*None) raised
        # TypeError (HTTP 500).
        return "Pack not found", 404
    pack = Pack(*p)

    # Get tweets from this pack
    cursor.execute(
        "SELECT * FROM tweet t WHERE t.pack == ?"
        " ORDER BY CAST(id_tweet AS INTEGER)", [pack.id])
    tweets = []
    tweets_hash = {}

    # Manage replies and quotes
    for t in cursor.fetchall():
        tweet = Tweet(*t)
        tweets_hash[tweet.id] = tweet
        if tweet.replies_to in tweets_hash:
            tweets_hash[tweet.replies_to].replies.append(tweet)
        elif tweet.quoting_id in tweets_hash:
            tweets_hash[tweet.quoting_id].replies.append(tweet)
        else:
            tweets.append(tweet)

    logging.debug([t.text for t in tweets])

    return render_template('annotate_pack.html',
                           tweets=tweets,
                           username=username,
                           pack=pack)
コード例 #25
0
def annotate_pack(set_number, pack_number):
    """Store the submitted annotations for a pack and redirect to the next pack.

    Reads ``annotation``, ``tweet_id``, ``tweet_annotation_prev`` and
    ``tweet_annotation_src`` from the form and upserts them for the logged-in
    user, then commits. It redirects to the next pack of the set, or to
    ``/next_set/<set_nb>`` after the set's last pack. It returns the login
    prompt when no user is logged in, and a 404 when the pack does not exist.
    """
    # Same login guard as get_pack: previously a missing session user raised
    # KeyError (HTTP 500) before anything was stored.
    try:
        username = session['username']
    except KeyError:
        return ("You are not logged in <br><a href = '/login'><b>"
                "click here to log in</b></a>")

    cursor.execute(
        "select id_pack, cnt, s.set_nb, pos, s.size from pack p join sets s "
        "on s.set_nb = p.set_nb where p.set_nb == ? AND p.pos == ?",
        [set_number, pack_number])
    p = cursor.fetchone()
    if p is None:
        # fetchone() returns None for an unknown pack; Pack(*None) raised
        # TypeError (HTTP 500).
        return "Pack not found", 404
    pack = Pack(*p)

    logging.info(f"pack annotation: {request.form['annotation']}")
    cursor.execute("INSERT OR REPLACE INTO pack_annotation VALUES (?,?,?)",
                   (pack.id, username, request.form['annotation']))
    # NOTE(review): the first tweet_id entry is skipped; presumably it has no
    # matching annotation fields in the form - confirm against the template.
    ids = request.form.getlist('tweet_id')[1:]
    annotations_prev = request.form.getlist('tweet_annotation_prev')
    annotations_src = request.form.getlist('tweet_annotation_src')

    logging.debug(ids)
    logging.debug(annotations_prev)
    logging.debug(annotations_src)

    for t_id, annotation_prev, annotation_src in zip(ids, annotations_prev,
                                                     annotations_src):
        logging.debug(
            f"INSERT OR REPLACE INTO tweet_annotation VALUES "
            f"({t_id}, {username}, {annotation_prev}, {annotation_src})"
        )

        cursor.execute(
            "INSERT OR REPLACE INTO tweet_annotation VALUES (?,?,?,?)",
            (t_id, username, annotation_prev, annotation_src))
    database.commit()

    # Next pack
    if pack.pos == pack.set_size:
        return redirect(f'/next_set/{pack.set_nb}')
    else:
        return redirect(f'/set/{pack.set_nb}/pack/{pack.pos + 1}')
コード例 #26
0
ファイル: data_loaders.py プロジェクト: zengjichuan/DTDMN
    def _prepare_batch(self, selected_index):
        """Pad a batch of (positive, negative) conversation pairs into arrays.

        :param selected_index: indices into ``self.data``. Each row carries
            ``pos_utts_seq``/``pos_utts_bow`` and ``neg_utts_seq``/
            ``neg_utts_bow`` (one entry per turn).
        :return: Pack with, for each side ``pos``/``neg``:
            ``*_utts_seq``    int32 (batch, max_turns, max_utt_size) word ids;
                              each turn is truncated to ``self.max_utt_size``
            ``*_utts_bow``    int32 (batch, max_turns, vocab_size) from ``_bow2vec``
            ``*_masks``       int32 (batch, max_turns), 1 for real turns
            ``*_lens``        (batch,) number of turns
            ``*_words_lens``  int32 (batch, max_turns) words per turn, capped at
                              ``self.max_utt_size``
        """
        rows = [self.data[idx] for idx in selected_index]
        (pos_seq, pos_bow, pos_masks, pos_lens,
         pos_words_lens) = self._vectorize_convs(
             [row.pos_utts_seq for row in rows],
             [row.pos_utts_bow for row in rows])
        (neg_seq, neg_bow, neg_masks, neg_lens,
         neg_words_lens) = self._vectorize_convs(
             [row.neg_utts_seq for row in rows],
             [row.neg_utts_bow for row in rows])

        return Pack(pos_utts_seq=pos_seq,
                    neg_utts_seq=neg_seq,
                    pos_utts_bow=pos_bow,
                    neg_utts_bow=neg_bow,
                    pos_masks=pos_masks,
                    neg_masks=neg_masks,
                    pos_lens=pos_lens,
                    neg_lens=neg_lens,
                    pos_words_lens=pos_words_lens,
                    neg_words_lens=neg_words_lens)

    def _vectorize_convs(self, utts_seq_lst, utts_bow_lst):
        """Pad one side of the batch; return (seq, bow, masks, lens, words_lens)."""
        turn_lens = [len(turns_seq) for turns_seq in utts_seq_lst]
        batch_size, max_turns = len(turn_lens), max(turn_lens, default=0)

        seq = np.zeros((batch_size, max_turns, self.max_utt_size), dtype=np.int32)
        bow = np.zeros((batch_size, max_turns, self.vocab_size), dtype=np.int32)
        masks = np.zeros((batch_size, max_turns), dtype=np.int32)
        words_lens = np.zeros((batch_size, max_turns), dtype=np.int32)

        for b_id, (turns_seq, turns_bow) in enumerate(
                zip(utts_seq_lst, utts_bow_lst)):
            n_turns = turn_lens[b_id]
            masks[b_id, :n_turns] = 1
            # Cap at max_utt_size: the ids are truncated to that size below,
            # and an uncapped length pointed past the stored sequence.
            words_lens[b_id, :n_turns] = [
                min(len(turn), self.max_utt_size) for turn in turns_seq
            ]
            for i, (turn_seq, turn_bow) in enumerate(zip(turns_seq, turns_bow)):
                kept = turn_seq[:self.max_utt_size]
                seq[b_id, i, :len(kept)] = kept
                bow[b_id, i, :] = self._bow2vec(turn_bow, self.vocab_size)

        return seq, bow, masks, np.array(turn_lens), words_lens
コード例 #27
0
ファイル: main.py プロジェクト: justinchiu/hmmlm-jax
def train_loop(
    args,
    V,
    iter,
    model,
    parameters,
    optimizer,
    scheduler,
    valid_iter=None,
    verbose=False,
):
    """Run one training epoch over ``iter`` and return accumulated losses.

    For every batch this zeroes gradients, scores the batch with
    ``model.score`` (carrying ``lpz``/``last_states`` from one batch to the
    next unless ``args.iterator == "bucket"``), backpropagates the negated
    per-token ``losses.loss``, clips gradients to ``args.clip`` and takes an
    optimizer step.

    When ``valid_iter`` is given, a validation pass is run roughly
    ``args.num_checks - 1`` times, evenly spaced through the epoch. The
    remaining check happens at the end of the epoch, outside this loop.
    Each validation pass reports, calls ``update_best_valid`` and steps the
    scheduler on the validation evidence.

    Args:
        args: run configuration namespace (reads ``num_checks``, ``model``,
            ``iterator``, ``clip``, ``schedule``, ``report_every``,
            ``num_classes``, ``name``, ``log_counts``, ``keep_counts``).
        V: vocabulary, forwarded to ``get_mask_lengths`` and the eval loop.
        iter: training batch iterator; must support ``len()`` when in-loop
            validation checks are enabled.
        model: model exposing ``score``, ``train`` and ``timing``.
        parameters: parameters whose gradients are clipped.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler. It is stepped per batch when
            ``args.schedule not in valid_schedules``, and always stepped
            with the validation evidence after each check.
        valid_iter: optional validation iterator.
        verbose: if True, report running training stats every
            ``args.report_every`` batches.

    Returns:
        ``(Pack(evidence=total_ll, elbo=total_elbo), n)``. ``n`` is the
        total of the ``n_tokens`` values reported by ``get_mask_lengths``.
    """
    global WANDB_STEP

    total_ll = 0
    total_elbo = 0
    n = 0

    # A check is also performed at the end of the epoch outside this loop,
    # so only ``num_checks - 1`` checks are scheduled in here.
    num_inner_checks = args.num_checks - 1
    if num_inner_checks > 0:
        # Clamp to at least 1. With fewer batches than checks, the floor
        # division yields 0, and ``i % checkpoint`` would raise
        # ZeroDivisionError.
        checkpoint = max(1, len(iter) // num_inner_checks)
    else:
        # num_checks <= 1: only the end-of-epoch check is wanted. This
        # used to divide by zero.
        checkpoint = None

    with th.enable_grad():
        lpz = None
        last_states = None
        for i, batch in enumerate(iter):
            model.train(True)
            WANDB_STEP += 1
            optimizer.zero_grad()

            # LSTM models read the ``textp1`` field instead of ``text``.
            text = batch.textp1 if "lstm" in args.model else batch.text
            if args.iterator == "bucket":
                # Carried-over state is only kept across batches for
                # non-bucket iterators (e.g. bptt).
                lpz = None
                last_states = None

            mask, lengths, n_tokens = get_mask_lengths(text, V)
            if model.timing:
                start_forward = timep.time()

            losses, lpz, last_states = model.score(
                text,
                lpz=lpz,
                last_states=last_states,
                mask=mask,
                lengths=lengths,
            )

            if model.timing:
                print(f"forward time: {timep.time() - start_forward}")
            total_ll += losses.evidence.detach()
            if losses.elbo is not None:
                total_elbo += losses.elbo.detach()
            n += n_tokens

            # ``losses.loss`` is an objective to maximize; minimize its
            # negated per-token value.
            loss = -losses.loss / n_tokens
            if model.timing:
                start_backward = timep.time()
            loss.backward()
            if model.timing:
                print(f"backward time: {timep.time() - start_backward}")
            clip_grad_norm_(parameters, args.clip)
            if args.schedule not in valid_schedules:
                # Step the scheduler before the optimizer so the first
                # update uses step = 1 (this is how HuggingFace does it).
                scheduler.step()
            optimizer.step()

            if verbose and i % args.report_every == args.report_every - 1:
                report(
                    Pack(evidence=total_ll, elbo=total_elbo),
                    n,
                    f"Train batch {i}",
                )

            if (
                valid_iter is not None
                and checkpoint is not None
                and i % checkpoint == checkpoint - 1
            ):
                v_start_time = time.time()
                if args.model in ("mshmm", "factoredhmm"):
                    # Very large class counts use the mixed cached evaluator.
                    if args.num_classes > 2**15:
                        eval_fn = mixed_cached_eval_loop
                    else:
                        eval_fn = cached_eval_loop
                elif args.model == "hmm":
                    eval_fn = cached_eval_loop
                else:
                    eval_fn = eval_loop
                valid_losses, valid_n = eval_fn(
                    args,
                    V,
                    valid_iter,
                    model,
                )
                report(valid_losses, valid_n, "Valid eval", v_start_time)

                update_best_valid(valid_losses, valid_n, model, optimizer,
                                  scheduler, args.name)

                # Metric-driven schedulers step on the validation evidence
                # (presumably a plateau-style scheduler; confirm config).
                scheduler.step(valid_losses.evidence)

                if args.log_counts > 0 and args.keep_counts > 0:
                    # TODO: factor out. Column-normalized counts with the
                    # first 4 columns dropped. The wandb logging that used
                    # these statistics is disabled, so the tensors are
                    # computed and released immediately.
                    counts = (model.counts /
                              model.counts.sum(0, keepdim=True))[:, 4:]
                    c, v = counts.shape
                    cg2 = counts > 1e-2
                    del cg2
                    del counts

    return Pack(evidence=total_ll, elbo=total_elbo), n
コード例 #28
0
 def message(self):
     """Return the complete frame: length | command_id | sequence_no | body.

     Each header field is packed with ``Pack.get_unsigned_long_data``. The
     leading length counts the whole frame, including the length field
     itself. That field is assumed to be 4 bytes wide — NOTE(review):
     confirm against ``Pack.get_unsigned_long_data``.
     """
     command_field = Pack.get_unsigned_long_data(self.command_id)
     sequence_field = Pack.get_unsigned_long_data(self.sequence_no)
     length_field_size = 4
     frame_length = (length_field_size
                     + len(command_field)
                     + len(sequence_field)
                     + len(self._message_body))
     length_field = Pack.get_unsigned_long_data(frame_length)
     return length_field + command_field + sequence_field + self._message_body
コード例 #29
0
    def score(
        self,
        text,
        lpz=None,
        last_states=None,
        mask=None,
        lengths=None,
    ):
        """Score a batch of sequences with forward-backward.

        In training mode, ``self.train_states_per_word`` states per cluster
        are sampled from ``self.cluster2state`` via Gumbel noise plus top-k.
        In eval mode ``states`` is None, which is passed to
        ``log_potentials`` (assumed to mean "use all states").

        Args:
            text: token ids of shape ``(N, T)``.
            lpz: state log-distribution carried over from the previous
                batch (the second value this method returns), or None.
            last_states: state ids that ``lpz`` refers to (the third value
                this method returns), or None.
            mask: boolean position mask. It is sliced ``mask[:, 1:]`` when
                ``self.train_shorter`` is set and used as-is otherwise.
            lengths: per-sequence lengths. Presumably shape ``(N,)``; they
                index ``alphas`` and ``text`` together with ``arange(N)``.

        Returns:
            A 3-tuple:
            ``Pack(elbo=..., evidence=..., loss=elbo)``, summed over the
            batch; ``alpha_T.log_softmax(-1)``, the normalized final
            forward scores; and ``end_states``, the states available to
            each sequence's last word.
        """
        N, T = text.shape
        if self.training:
            # Gumbel-perturbed top-k picks a random subset of state slots per
            # cluster; gather maps the chosen slots to state ids.
            I = (th.distributions.Gumbel(self.zero, self.one).sample(
                self.cluster2state.shape).squeeze(-1).topk(
                    self.train_states_per_word, dim=-1).indices)
            states = self.cluster2state.gather(1, I).view(-1)
        else:
            states = None
        if self.timing:
            startpot = timep.time()
        log_potentials = self.log_potentials(
            text,
            states,
            lpz,
            last_states,
        )
        if self.timing:
            print(f"log pot: {timep.time() - startpot}")
        fb = self.fb_train if self.training else self.fb_test
        idx = th.arange(N, device=self.device)

        if self.train_shorter:
            # Original variant: potentials cover positions 1: of the mask.
            with th.no_grad():
                log_m, alphas = fb(log_potentials.detach().float(),
                                   mask=mask[:, 1:])
            alpha_T = alphas[lengths - 1, idx]
            evidence = alpha_T.logsumexp(-1).sum()
            elbo = (log_m.exp() * log_potentials)[mask[:, 1:]].sum()
        else:
            # Larger variant: potentials are aligned with the full mask.
            with th.no_grad():
                log_m, alphas = fb(log_potentials.detach().clone(), mask=mask)
            alpha_T = alphas[lengths, idx]
            evidence = alpha_T.logsumexp(-1).sum()
            elbo = (log_m.exp() * log_potentials)[mask]
            # 0 * -inf entries produce NaN; drop them before summing.
            elbo = elbo[~elbo.isnan()].sum()

        # Bookkeeping for BPTT: the states available to each last word.
        last_words = text[idx, lengths - 1]
        if states is not None:
            c2s = states.view(self.config.num_clusters, -1)
        else:
            # Eval mode samples no subset, so the full cluster -> state
            # table applies (previously crashed: None has no ``.view``).
            c2s = self.cluster2state
        end_states = c2s[self.word2cluster[last_words]]

        return Pack(
            elbo=elbo,
            evidence=evidence,
            loss=elbo,
        ), alpha_T.log_softmax(-1), end_states
コード例 #30
0
    def __init__(self,
                 msg_src,
                 msg_content,
                 src_id='1064899103013',
                 dest_terminal_id=(
                     '8613900000000',
                 ),
                 pk_total=1,
                 pk_number=1,
                 registered_delivery=0,
                 msg_level=0,
                 service_id='MI',
                 fee_usertype=2,
                 fee_terminal_id="",
                 fee_terminal_type=0,
                 tp_pid=0,
                 tp_udhi=0,
                 msg_fmt=8,
                 feetype='01',
                 feecode='000000',
                 valid_time=17 * '\x00',
                 at_time=17 * '\x00',
                 dest_terminal_type=0,
                 linkid=''):
        """Serialize a submit request body and initialize a CMPP_SUBMIT_REQ.

        Integer fields are packed with ``Pack.get_unsigned_char_data``.
        Fixed-width text fields are UTF-8 encoded and NUL-padded to their
        slot width: ``service_id`` 10 bytes, ``fee_terminal_id`` 32,
        ``src_id`` 21, each entry of ``dest_terminal_id`` 32, and
        ``linkid`` 20. ``msg_src``, ``feetype``, ``feecode``, ``valid_time``
        and ``at_time`` are UTF-8 encoded unchanged, so the caller must
        supply them at their exact wire width. ``msg_content`` is always
        encoded as UTF-16-BE, whatever ``msg_fmt`` says.

        Args:
            msg_src: source field, encoded as given.
            msg_content: message text. It may be at most 140 bytes once
                encoded as UTF-16-BE (70 BMP characters).
            dest_terminal_id: sized iterable of destination numbers, at
                most 100. The default is a tuple to avoid a shared mutable
                default; lists work as before.
            Remaining arguments are serialized into the body in the order
            shown in ``_message_body`` below.

        Raises:
            ValueError: if the encoded content exceeds 140 bytes, if more
                than 100 destinations are given, or if a fixed-width text
                field does not fit its slot.
        """
        def _fixed_width(text, width):
            # UTF-8 encode ``text`` and NUL-pad it to exactly ``width``
            # bytes. Padding after encoding keeps the slot width right for
            # non-ASCII input. An overlong value would shift every later
            # field in the frame, so it is rejected.
            raw = text.encode('utf-8')
            if len(raw) > width:
                raise ValueError(
                    "%r is longer than %d bytes" % (text, width))
            return raw.ljust(width, b'\x00')

        # Validate the encoded length: 140 bytes == 70 UCS-2 characters. The
        # old ``len(msg_content) >= 70`` wrongly rejected exactly 70 and let
        # non-BMP text (4 bytes per char) overflow the 1-byte length field.
        _msg_content = msg_content.encode('utf-16-be')
        if len(_msg_content) > 140:
            raise ValueError("msg_content more than 70 words")
        if len(dest_terminal_id) > 100:
            raise ValueError("single submit more than 100 phone numbers")

        _msg_id = 8 * b'\x00'
        _pk_total = Pack.get_unsigned_char_data(pk_total)
        _pk_number = Pack.get_unsigned_char_data(pk_number)
        _registered_delivery = Pack.get_unsigned_char_data(registered_delivery)
        _msg_level = Pack.get_unsigned_char_data(msg_level)
        _service_id = _fixed_width(service_id, 10)
        _fee_usertype = Pack.get_unsigned_char_data(fee_usertype)
        _fee_terminal_id = _fixed_width(fee_terminal_id, 32)
        _fee_terminal_type = Pack.get_unsigned_char_data(fee_terminal_type)
        _tp_pid = Pack.get_unsigned_char_data(tp_pid)
        _tp_udhi = Pack.get_unsigned_char_data(tp_udhi)
        _msg_fmt = Pack.get_unsigned_char_data(msg_fmt)
        _msg_src = msg_src.encode('utf-8')
        _feetype = feetype.encode('utf-8')
        _feecode = feecode.encode('utf-8')
        _valid_time = valid_time.encode('utf-8')
        _at_time = at_time.encode('utf-8')
        _src_id = _fixed_width(src_id, 21)
        _destusr_tl = Pack.get_unsigned_char_data(len(dest_terminal_id))
        _dest_terminal_id = b"".join(
            _fixed_width(msisdn, 32) for msisdn in dest_terminal_id)
        _dest_terminal_type = Pack.get_unsigned_char_data(dest_terminal_type)
        _msg_length = Pack.get_unsigned_char_data(len(_msg_content))
        _linkid = _fixed_width(linkid, 20)
        _message_body = b"".join((
            _msg_id, _pk_total, _pk_number, _registered_delivery, _msg_level,
            _service_id, _fee_usertype, _fee_terminal_id, _fee_terminal_type,
            _tp_pid, _tp_udhi, _msg_fmt, _msg_src, _feetype, _feecode,
            _valid_time, _at_time, _src_id, _destusr_tl, _dest_terminal_id,
            _dest_terminal_type, _msg_length, _msg_content, _linkid,
        ))

        RequestInstance.__init__(self, CMPP_SUBMIT_REQ, _message_body)