Beispiel #1
0
 def compute(self,
             batch: BatchData,
             outputs,
             step: int = None) -> Tuple[torch.Tensor, utils.TensorMap]:
     """Compute the state reconstruction + KL loss and its statistics.

     Args:
         batch: batch exposing ``batch_size``, ``max_conv_len``,
             ``conv_lens``, ``state`` (``value``/``lens1``) and
             ``speaker`` (``value``).
         outputs: ``(logit, post, prior)`` triple; only the ``"state"``
             entry of each element is used.
         step: training step used to look up the KL weight; a falsy value
             (``None`` or ``0``) is treated as step 0.

     Returns:
         A pair ``(loss, stats)``: ``loss`` is the mean of the KL-weighted
         loss, ``stats`` maps statistic names to tensors and additionally
         holds one ``<name>-<speaker>`` entry per non-``<unk>`` speaker.
     """
     step = step or 0
     logit, post, prior = outputs
     batch_size = batch.batch_size
     max_conv_len = batch.max_conv_len
     s_logit, zstate_post, zstate_prior = \
         logit["state"], post["state"], prior["state"]
     conv_lens = batch.conv_lens
     conv_mask = utils.mask(conv_lens, max_conv_len)
     # Only finite logits that belong to real (non-padded) turns count.
     state_logit_mask = \
         (((s_logit != float("-inf")) & (s_logit != float("inf")))
          .masked_fill(~conv_mask.unsqueeze(-1), 0))
     kld_state = zstate_post.kl_div(zstate_prior).masked_fill(~conv_mask, 0)
     s_target = utils.to_dense(idx=batch.state.value,
                               lens=batch.state.lens1,
                               max_size=self.num_asv)
     # Padded turns are labelled -1; presumably no speaker index is -1, so
     # they never match a speaker in the per-speaker loop below.
     p_target = batch.speaker.value.masked_fill(~conv_mask, -1)
     state_loss = (self._bce(s_logit, s_target.float()).masked_fill(
         ~state_logit_mask, 0)).sum(-1)
     kld_weight = self.kld_weight.get(step)
     # NOTE(review): `nll` and `loss` keep the per-turn layout, so their
     # `.mean()` appears to average over the padded turn grid, whereas
     # "loss-state" / "kld" below are per-conversation sums. Left as is to
     # keep the returned loss scale unchanged -- confirm this is intended.
     nll = state_loss + kld_state
     loss = state_loss + kld_weight * kld_state
     state_mi = \
         (estimate_mi(zstate_post.view(batch_size * max_conv_len, -1))
          .view(batch_size, max_conv_len).masked_fill(~conv_mask, 0).sum(-1))
     # Reductions shared by several statistics are computed once.
     state_loss_total = state_loss.sum()
     kld_state_total = kld_state.sum()
     num_turns = conv_lens.sum()
     stats = {
         "nll": nll.mean(),
         "state-mi": state_mi.mean(),
         "loss-state": state_loss.sum(-1).mean(),
         "loss-state-turn": state_loss_total / num_turns,
         "loss-state-asv": state_loss_total / state_logit_mask.sum(),
         "kld-weight": torch.tensor(kld_weight),
         "kld-state": kld_state.sum(-1).mean(),
         "kld-state-turn": kld_state_total / num_turns,
         "kld": kld_state.sum(-1).mean()
     }
     for spkr_idx, spkr in self.vocabs.speaker.i2f.items():
         if spkr == "<unk>":
             continue
         spkr_mask = p_target == spkr_idx
         spkr_turns = spkr_mask.sum()
         spkr_state_mask = \
             state_logit_mask.masked_fill(~spkr_mask.unsqueeze(-1), 0)
         spkr_state_loss = state_loss.masked_fill(~spkr_mask, 0).sum()
         spkr_kld_state = kld_state.masked_fill(~spkr_mask, 0).sum()
         spkr_stats = {
             "loss-state": spkr_state_loss / batch_size,
             "loss-state-turn": spkr_state_loss / spkr_turns,
             "loss-state-asv": spkr_state_loss / spkr_state_mask.sum(),
             "kld-state": spkr_kld_state / batch_size,
             "kld-state-turn": spkr_kld_state / spkr_turns,
         }
         stats.update({f"{k}-{spkr}": v for k, v in spkr_stats.items()})
     return loss.mean(), stats
Beispiel #2
0
 def compute_accuracy(self,
                      pred: DoublyStacked1DTensor,
                      gold: DoublyStacked1DTensor,
                      turn_mask=None) -> utils.TensorMap:
     """Return conversation-level and turn-level exact-match accuracy.

     Both inputs are densified with ``utils.to_dense`` to a width of
     ``len(self.vocabs.goal_state.asv)``; a turn is correct when its
     dense prediction equals the dense gold in every position.

     Args:
         pred: predicted states.
         gold: reference states.
         turn_mask: optional mask restricting which turns are scored; it
             is always intersected with the mask built from ``pred.lens``.

     Returns:
         ``{"acc": ..., "acc-turn": ...}`` where ``acc`` is the number of
         conversations with every scored turn correct divided by the
         batch size, and ``acc-turn`` is correct turns over scored turns.
     """
     num_convs = pred.size(0)
     width = len(self.vocabs.goal_state.asv)
     dense_pred = utils.to_dense(pred.value, pred.lens1, max_size=width)
     dense_gold = utils.to_dense(gold.value, gold.lens1, max_size=width)
     turn_correct = (dense_pred == dense_gold).all(-1)
     valid_turns = utils.mask(pred.lens, pred.size(1))
     if turn_mask is None:
         turn_mask = torch.ones_like(valid_turns).bool()
     scored = turn_mask & valid_turns
     turn_correct = turn_correct & scored
     # Turns that are not scored must not make a conversation fail.
     conv_correct = (turn_correct | ~scored).all(-1)
     return {
         "acc": conv_correct.sum().float() / num_convs,
         "acc-turn": turn_correct.sum().float() / scored.sum(),
     }
Beispiel #3
0
def main():
    """Read scored records from stdin, rank them per query, print metrics.

    Every non-blank input line is split on single spaces and read as:
    ``parts[0]`` the float score, ``parts[2]`` the query id, ``parts[3:]``
    the tokens handed to ``utils.to_dense``, and the last token a
    ``key=value`` pair whose value is the integer rating.

    Records are grouped by query id, each group is sorted by descending
    score, and the ``click_num``, ``view_deep`` and ``map`` metrics of the
    grouping are printed, one per line.
    """
    qid2list = {}
    for line in sys.stdin:
        # A blank line (e.g. a trailing newline) carries no record.
        if not line.strip():
            continue
        parts = line.split(' ')
        score = float(parts[0])
        rating = int(parts[-1].split('=')[1])
        qid = parts[2]
        vec = utils.to_dense(parts[3:])
        record = utils.Record(qid, rating, vec, score=score)
        qid2list.setdefault(qid, []).append(record)

    for seq in qid2list.values():
        seq.sort(key=lambda x: x.score, reverse=True)

    # Single pre-formatted argument: prints the same text under the
    # Python 2 print statement and the Python 3 print function.
    print('click_num %s' % (click_num(qid2list),))
    print('view_deep %s' % (view_deep(qid2list),))
    # NOTE(review): `map` is called with one argument, so it is presumably
    # a project-level metric function shadowing the builtin -- confirm.
    print('map %s' % (map(qid2list),))
Beispiel #4
0
def e_step(votes_ij, activations_j, mean_j, stdv_j, var_j, spatial_routing_matrix):
  """E-step of EM routing between child capsules (i) and parent capsules (j).

  Recomputes the assignment weights used in routing: the parent capsules (j)
  compete for each child capsule (i). See Hinton et al., "Matrix Capsules
  with EM Routing", for the derivation.

  Author:
    Ashley Gritzman 19/10/2018

  Args:
    votes_ij:
      votes from capsules in layer i to capsules in layer j.
      Conv layer: (N, OH, OW, kh*kw*i, o, 4x4), e.g. (64, 6, 6, 9*8, 32, 16).
      FC layer: the kernel covers the whole spatial extent of layer i and
      layer j is 1x1: (N, 1, 1, child_space*child_space*i, output_classes,
      4x4), e.g. (64, 1, 1, 4*4*16, 5, 16).
    activations_j:
      activations of capsules in layer j (L+1),
      (N, OH, OW, 1, o, 1), e.g. (64, 6, 6, 1, 32, 1).
    mean_j:
      per-channel mean of capsules in layer j (L+1),
      (N, OH, OW, 1, o, n_channels), e.g. (24, 6, 6, 1, 32, 16).
    stdv_j:
      per-channel standard deviation, same layout as mean_j. Not read by
      this function; the computation uses var_j instead.
    var_j:
      per-channel variance, same layout as mean_j.
    spatial_routing_matrix:
      array whose first column sums to the number of child positions per
      parent (kk); it is also passed on to utl.to_sparse,
      utl.softmax_across_parents and utl.to_dense. Presumably a binary
      child-position x parent-position connectivity matrix -- TODO confirm.

  Returns:
    rr:
      assignment weights between capsules in layer i and layer j,
      (N, OH, OW, kh*kw*i, o, 1), e.g. (64, 6, 6, 9*8, 16, 1).
  """

  with tf.variable_scope("e_step"):

    # Log-density of each vote under its parent's per-channel Gaussian,
    # summed over channels. The variance is used directly (AG 26/06/2018:
    # changed stdv_j to var_j).
    squared_dev = tf.square(votes_ij - mean_j, name="num")
    log_p_exponent = - tf.reduce_sum(
      squared_dev / (2 * var_j),
      axis=-1,
      keepdims=True,
      name="o_p_unit0")

    log_p_normaliser = - 0.5 * tf.reduce_sum(
      tf.log(2*np.pi * var_j),
      axis=-1,
      keepdims=True,
      name="o_p_unit2"
    )

    # (24, 6, 6, 288, 32, 1)
    log_p = log_p_exponent + log_p_normaliser
    zz = tf.log(activations_j + FLAGS.epsilon) + log_p

    # AG 13/11/2018: normalise across parents through a sparse layout.
    #----- Start -----#
    static_shape = zz.get_shape().as_list()
    batch_size = static_shape[0]
    parent_space = static_shape[1]
    kh_kw_i = static_shape[3]
    parent_caps = static_shape[4]
    kk = int(np.sum(spatial_routing_matrix[:,0]))
    child_caps = int(kh_kw_i / kk)

    zz = tf.reshape(
        zz,
        [batch_size, parent_space, parent_space, kk, child_caps, parent_caps])

    # An earlier variant did the sparse expansion in un-log space
    # (exp -> to_sparse with filler 0.0 -> log(x + 1e-15)); the log-space
    # version below is the one in use.
    with tf.variable_scope("to_sparse_log"):
      # Pad the sparse tensor with the smallest value in zz, capped at -100.
      sparse_filler = tf.minimum(tf.reduce_min(zz), -100)
      zz_sparse = utl.to_sparse(
          zz,
          spatial_routing_matrix,
          sparse_filler=sparse_filler)

    with tf.variable_scope("softmax_across_parents"):
      rr_sparse = utl.softmax_across_parents(zz_sparse, spatial_routing_matrix)

    with tf.variable_scope("to_dense"):
      rr_dense = utl.to_dense(rr_sparse, spatial_routing_matrix)

    out_shape = [batch_size, parent_space, parent_space, kh_kw_i,
                 parent_caps, 1]
    rr = tf.reshape(rr_dense, out_shape)
    #----- End -----#

    # AG 02/11/2018: no stop-gradient is applied here. Hinton et al. on
    # OpenReview: "The gradient flows through EM algorithm. We do not use
    # stop gradient. A routing of 3 is like a 3 layer network where the
    # weights of layers are shared."
    # https://openreview.net/forum?id=HJWLfGWRb&noteId=S1eo2P1I3Q

    return rr
Beispiel #5
0
def test_jda(create_fn=create_vhda, gen_fn=vhda_gen):
    """Overfit a model on the dummy dataset, then check it regenerates it.

    Trains for 300 epochs with Adam on batches of two conversations using
    word, goal, turn and speaker reconstruction losses plus a weighted KL
    term, printing losses and one generated sample after every batch.
    Finally asserts that every generated conversation lexicalizes to the
    same text as its input.
    """
    dataset = create_dummy_dataset()
    dataloader = create_dataloader(
        dataset,
        batch_size=2,
    )
    model = create_fn(dataset)
    trainable = (p for p in model.parameters() if p.requires_grad)
    optimizer = op.Adam(trainable)
    ce = nn.CrossEntropyLoss(ignore_index=-1, reduction="none")
    bce = nn.BCEWithLogitsLoss(reduction="none")
    model.reset_parameters()
    for eidx in range(300):
        # NOTE(review): train mode is only entered once per epoch, while
        # `model.eval()` is called after every batch below -- confirm that
        # later batches of an epoch are meant to run in eval mode.
        model.train()
        for i, batch in enumerate(dataloader):
            optimizer.zero_grad()
            model.inference()
            words, speakers = batch.word, batch.speaker
            goals, goal_lens = batch.goal, batch.goal_lens
            turns, turn_lens = batch.turn, batch.turn_lens
            sent_lens, conv_lens = batch.sent_lens, batch.conv_lens
            batch_size, max_conv_len, max_sent_len = words.size()
            w_logit, p_logit, g_logit, s_logit, info = model(batch.to_dict())
            pad_turns = ~utils.mask(conv_lens)
            pad_words = ~utils.mask(sent_lens.view(-1))

            # Word targets: -1 (ignored by `ce`) on padded turns and words.
            w_target = words.masked_fill(pad_turns.unsqueeze(-1), -1)
            w_target = w_target.view(-1, max_sent_len)
            w_target = w_target.masked_fill(pad_words, -1)
            w_target = w_target.view(batch_size, max_conv_len, -1)
            # Logits at position t are scored against the token at t + 1.
            shifted_logit = w_logit[:, :, :-1].contiguous()
            shifted_target = w_target[:, :, 1:].contiguous()
            recon_loss = ce(
                shifted_logit.view(-1, w_logit.size(-1)),
                shifted_target.view(-1)
            ).view(batch_size, max_conv_len, max_sent_len - 1).sum(-1).sum(-1)

            def multilabel_loss(logits, idx, lens):
                # BCE against the dense target, zeroed on padded turns and
                # summed down to one value per conversation.
                dense = utils.to_dense(idx, lens, logits.size(-1)).float()
                per_item = bce(logits, dense)
                kept = per_item.masked_fill(
                    pad_turns.unsqueeze(-1).unsqueeze(-1), 0)
                return kept.sum(-1).sum(-1).sum(-1)

            goal_loss = multilabel_loss(g_logit, goals, goal_lens)
            turn_loss = multilabel_loss(s_logit, turns, turn_lens)
            speaker_loss = ce(
                p_logit.view(-1, p_logit.size(-1)),
                speakers.masked_fill(pad_turns, -1).view(-1)
            ).view(batch_size, max_conv_len).sum(-1)
            kld_keys = {"sent", "conv", "speaker", "goal", "turn"}
            kld_loss = sum(v for k, v in info.items() if k in kld_keys)
            # NOTE(review): the weight is driven by the in-epoch batch
            # index `i`, not a global step -- confirm this is intended.
            kld_scale = min(0.3, max(0.01, i / 500))
            loss = (recon_loss + goal_loss + turn_loss + speaker_loss +
                    kld_loss * kld_scale)
            print(f"[e{eidx + 1}] "
                  f"loss={loss.mean().item(): 4.4f} "
                  f"recon={recon_loss.mean().item(): 4.4f} "
                  f"goal={goal_loss.mean().item(): 4.4f} "
                  f"turn={turn_loss.mean().item(): 4.4f} "
                  f"speaker={speaker_loss.mean().item(): 4.4f} "
                  f"kld={kld_loss.mean().item(): 4.4f}")
            loss.mean().backward()
            optimizer.step()

            # Show how the first conversation of this batch is regenerated.
            model.eval()
            model.genconv_post()
            batch_gen, info = gen_fn(model)(batch.to_dict())
            print("Input: ")
            print(f"{dataset.processor.lexicalize(batch[0])}")
            print()
            print(f"Predicted (prob={info['logprob'][0].exp().item():.4f}): ")
            print(f"{dataset.processor.lexicalize(batch_gen[0])}")
    model.eval()
    model.genconv_post()
    lexicalize = dataset.processor.lexicalize
    for batch in dataloader:
        batch_gen, _ = gen_fn(model)(batch.to_dict())
        for source, generated in zip(batch, batch_gen):
            x, y = lexicalize(source), lexicalize(generated)
            assert x == y, f"{x}\n!=\n{y}"
Beispiel #6
0
def test_dense_sparse():
    """Round trip: ``to_sparse`` followed by ``to_dense`` restores the input."""
    original = torch.randint(0, 2, (3, 4, 5)).byte()
    sparse_parts = utils.to_sparse(original)
    restored = utils.to_dense(*sparse_parts)
    assert (original == restored).all()
import sys
import utils

# Records read from stdin, grouped by query id.
qid2list = {}

# Command line: <type> <do_bounce> <seq_len>
ARG_type = sys.argv[1]
ARG_do_bounce = sys.argv[2]
ARG_seq_len = int(sys.argv[3])

# Each line is "<score> <rating> <qid> <feature tokens...>", split on
# single spaces.
for line in sys.stdin:
    fields = line.split(' ')
    score, rating, qid = float(fields[0]), int(fields[1]), fields[2]
    vec = utils.to_dense(fields[3:])
    record = utils.Record(qid, rating, vec, score=score)
    qid2list.setdefault(qid, []).append(record)

for qid, seq in qid2list.items():
    # Queries with fewer than ARG_seq_len records are skipped.
    if len(seq) < ARG_seq_len:
        continue
    # Rank by descending score (in place) and keep the top ARG_seq_len.
    seq.sort(key=lambda x: x.score, reverse=True)
    if len(seq) > ARG_seq_len:
        seq = seq[0:ARG_seq_len]

    fb_seq = utils.gen_scan_seq(seq)
Beispiel #8
0
 def compute(self,
             batch: BatchData,
             outputs,
             step: int = None) -> Tuple[torch.Tensor, utils.TensorMap]:
     """Compute the pure reconstruction loss and its statistics.

     The loss is the per-conversation sum of the sentence, state, goal and
     speaker losses; no KL term is involved and ``step`` is not read.

     Args:
         batch: batch exposing sizes, ``conv_lens``, ``sent``, ``goal``,
             ``state`` and ``speaker``.
         outputs: ``(logit, post, prior)`` triple; only ``logit`` is used,
             through its ``"sent"``, ``"speaker"``, ``"goal"`` and
             ``"state"`` entries.
         step: accepted for interface compatibility.

     Returns:
         ``(loss, stats)``: the batch mean of the reconstruction loss and
         a map of statistic names to tensors, including one
         ``<name>-<speaker>`` entry per non-``<unk>`` speaker.
     """
     logit, _, _ = outputs
     batch_size = batch.batch_size
     max_conv_len = batch.max_conv_len
     max_sent_len = batch.max_sent_len
     w_logit, p_logit, g_logit, s_logit = \
         (logit[k] for k in ("sent", "speaker", "goal", "state"))
     conv_lens = batch.conv_lens
     conv_mask = utils.mask(conv_lens, max_conv_len)
     # Sentences on padded turns count as empty.
     sent_lens = batch.sent.lens1.masked_fill(~conv_mask, 0)
     sent_mask = utils.mask(sent_lens, max_sent_len)

     def valid_logits(x):
         # True where the logit is finite-signed and the turn is real.
         not_inf = (x != float("-inf")) & (x != float("inf"))
         return not_inf.masked_fill(~conv_mask.unsqueeze(-1), 0)

     def dense_target(field):
         return utils.to_dense(idx=field.value,
                               lens=field.lens1,
                               max_size=self.num_asv)

     def multilabel_loss(x, target, keep):
         raw = self._bce(x, target.float())
         return raw.masked_fill(~keep, 0).sum(-1)

     goal_logit_mask = valid_logits(g_logit)
     state_logit_mask = valid_logits(s_logit)
     # Word targets: padding becomes -1 and the first token is dropped so
     # that logits at position t line up with the token at t + 1.
     w_target = batch.sent.value.masked_fill(~sent_mask, -1)
     w_target = w_target.view(-1, max_sent_len)[..., 1:]
     g_target = dense_target(batch.goal)
     s_target = dense_target(batch.state)
     p_target = batch.speaker.value.masked_fill(~conv_mask, -1)
     goal_loss = multilabel_loss(g_logit, g_target, goal_logit_mask)
     state_loss = multilabel_loss(s_logit, s_target, state_logit_mask)
     spkr_loss = self._ce(p_logit.view(-1, self.vocabs.num_speakers),
                          p_target.view(-1)).view(batch_size, max_conv_len)
     word_logit = w_logit[:, :, :-1].contiguous()
     sent_loss = self._ce(
         word_logit.view(-1, len(self.vocabs.word)),
         w_target.contiguous().view(-1))
     sent_loss = sent_loss.view(batch_size, max_conv_len, -1).sum(-1)
     loss_recon = (sent_loss.sum(-1) + state_loss.sum(-1) +
                   goal_loss.sum(-1) + spkr_loss.sum(-1))

     # Batch-wide totals and normalisers shared by the statistics below.
     num_turns = conv_lens.sum()
     num_words = sent_lens.sum()
     sent_total = sent_loss.sum()
     goal_total = goal_loss.sum()
     state_total = state_loss.sum()
     spkr_total = spkr_loss.sum()
     stats = {
         "nll": loss_recon.mean(),
         "loss": loss_recon.mean(),
         "loss-recon": loss_recon.mean(),
         "loss-sent": sent_loss.sum(-1).mean(),
         "loss-sent-turn": sent_total / num_turns,
         "loss-sent-word": sent_total / num_words,
         "ppl-turn": (sent_total / num_turns).exp(),
         "ppl-word": (sent_total / num_words).exp(),
         "loss-goal": goal_loss.sum(-1).mean(),
         "loss-goal-turn": goal_total / num_turns,
         "loss-goal-asv": goal_total / goal_logit_mask.sum(),
         "loss-state": state_loss.sum(-1).mean(),
         "loss-state-turn": state_total / num_turns,
         "loss-state-asv": state_total / state_logit_mask.sum(),
         "loss-spkr": spkr_loss.sum(-1).mean(),
         "loss-spkr-turn": spkr_total / num_turns
     }
     for spkr_idx, spkr in self.vocabs.speaker.i2f.items():
         if spkr == "<unk>":
             continue
         mine = p_target == spkr_idx
         others = ~mine
         my_turns = mine.sum()
         my_words = sent_lens.masked_fill(others, 0).sum()
         my_goal_asv = goal_logit_mask.masked_fill(
             others.unsqueeze(-1), 0).sum()
         my_state_asv = state_logit_mask.masked_fill(
             others.unsqueeze(-1), 0).sum()
         my_sent = sent_loss.masked_fill(others, 0).sum()
         my_goal = goal_loss.masked_fill(others, 0).sum()
         my_state = state_loss.masked_fill(others, 0).sum()
         my_spkr = spkr_loss.masked_fill(others, 0).sum()
         spkr_stats = {
             "loss-sent": my_sent / batch_size,
             "loss-sent-turn": my_sent / my_turns,
             "loss-sent-word": my_sent / my_words,
             "ppl-turn": (my_sent / my_turns).exp(),
             "ppl-word": (my_sent / my_words).exp(),
             "loss-goal": my_goal / batch_size,
             "loss-goal-turn": my_goal / my_turns,
             "loss-goal-asv": my_goal / my_goal_asv,
             "loss-state": my_state / batch_size,
             "loss-state-turn": my_state / my_turns,
             "loss-state-asv": my_state / my_state_asv,
             "loss-spkr": my_spkr / batch_size,
             "loss-spkr-turn": my_spkr / my_turns
         }
         for key, value in spkr_stats.items():
             stats[f"{key}-{spkr}"] = value
     return loss_recon.mean(), stats
Beispiel #9
0
 def compute(self,
             batch: BatchData,
             outputs,
             step: int = None) -> Tuple[torch.Tensor, utils.TensorMap]:
     """Compute the variational training loss and its statistics.

     The reconstruction part sums the sentence, state and speaker losses
     per conversation; the KL part sums the conversation, sentence, state
     and speaker KL divergences. How the two are combined depends on
     ``self.enable_kl`` and ``self.kl_mode`` (``"kl-mi"`` subtracts the
     conversation MI estimate from the KL term, ``"kl-mi+"`` subtracts all
     four MI estimates, anything else uses the plain KL term).

     Args:
         batch: batch exposing sizes, ``conv_lens``, ``sent``, ``state``
             and ``speaker``.
         outputs: ``(logit, post, prior)`` triple; ``logit`` is read at
             ``"sent"``, ``"speaker"``, ``"state"`` and ``post``/``prior``
             at ``"conv"``, ``"state"``, ``"sent"``, ``"speaker"``.
         step: training step used to look up the KL weight; a falsy value
             (``None`` or ``0``) is treated as step 0.

     Returns:
         ``(loss, stats)``: the batch mean of the training loss and a map
         of statistic names to tensors, including one
         ``<name>-<speaker>`` entry per non-``<unk>`` speaker.
     """
     step = step or 0
     logit, post, prior = outputs
     batch_size = batch.batch_size
     max_conv_len = batch.max_conv_len
     max_sent_len = batch.max_sent_len
     w_logit, p_logit, s_logit = \
         (logit[k] for k in ("sent", "speaker", "state"))
     zconv_post, zstate_post, zsent_post, zspkr_post = \
         (post[k] for k in ("conv", "state", "sent", "speaker"))
     # The conversation-level prior is unpacked but not used: `kld_conv`
     # below calls `kl_div()` without an argument.
     zconv_prior, zstate_prior, zsent_prior, zspkr_prior = \
         (prior[k] for k in ("conv", "state", "sent", "speaker"))
     conv_lens, sent_lens = batch.conv_lens, batch.sent.lens1
     conv_mask = utils.mask(conv_lens, max_conv_len)
     # Sentences on padded turns count as empty.
     sent_lens = sent_lens.masked_fill(~conv_mask, 0)
     sent_mask = utils.mask(sent_lens, max_sent_len)
     # Only finite logits that belong to real (non-padded) turns count.
     state_logit_mask = \
         (((s_logit != float("-inf")) & (s_logit != float("inf")))
          .masked_fill(~conv_mask.unsqueeze(-1), 0))
     kld_conv = zconv_post.kl_div()
     kld_state = zstate_post.kl_div(zstate_prior).masked_fill(~conv_mask, 0)
     kld_sent = zsent_post.kl_div(zsent_prior).masked_fill(~conv_mask, 0)
     kld_spkr = zspkr_post.kl_div(zspkr_prior).masked_fill(~conv_mask, 0)
     # Word targets: padding becomes -1 and the first token is dropped so
     # that logits at position t line up with the token at t + 1.
     w_target = (batch.sent.value.masked_fill(~sent_mask, -1)
                 .view(-1, max_sent_len))[..., 1:]
     s_target = utils.to_dense(idx=batch.state.value,
                               lens=batch.state.lens1,
                               max_size=self.num_asv)
     p_target = batch.speaker.value.masked_fill(~conv_mask, -1)
     state_loss = (self._bce(s_logit, s_target.float()).masked_fill(
         ~state_logit_mask, 0)).sum(-1)
     spkr_loss = self._ce(p_logit.view(-1, self.vocabs.num_speakers),
                          p_target.view(-1)).view(batch_size, max_conv_len)
     sent_loss = self._ce(
         w_logit[:, :, :-1].contiguous().view(-1, len(self.vocabs.word)),
         w_target.contiguous().view(-1)).view(batch_size, max_conv_len,
                                              -1).sum(-1)
     kld_weight = self.kld_weight.get(step)
     loss_kld = (kld_conv + kld_sent.sum(-1) + kld_state.sum(-1) +
                 kld_spkr.sum(-1))
     loss_recon = (sent_loss.sum(-1) + state_loss.sum(-1) +
                   spkr_loss.sum(-1))
     nll = loss_recon + loss_kld

     def turn_level_mi(z_post):
         # MI estimate over flattened turns, zeroed on padded turns and
         # summed per conversation.
         flat = estimate_mi(z_post.view(batch_size * max_conv_len, -1))
         return (flat.view(batch_size, max_conv_len)
                 .masked_fill(~conv_mask, 0).sum(-1))

     conv_mi = estimate_mi(zconv_post)
     sent_mi = turn_level_mi(zsent_post)
     spkr_mi = turn_level_mi(zspkr_post)
     state_mi = turn_level_mi(zstate_post)
     if self.enable_kl:
         if self.kl_mode == "kl-mi":
             loss = loss_recon + kld_weight * (loss_kld - conv_mi)
         elif self.kl_mode == "kl-mi+":
             loss = loss_recon + kld_weight * (loss_kld - conv_mi -
                                               sent_mi - spkr_mi - state_mi)
         else:
             loss = loss_recon + kld_weight * loss_kld
     else:
         loss = loss_recon
     # Reductions shared by several statistics are computed once.
     num_turns = conv_lens.sum()
     num_words = sent_lens.sum()
     sent_loss_total = sent_loss.sum()
     state_loss_total = state_loss.sum()
     spkr_loss_total = spkr_loss.sum()
     stats = {
         "nll": nll.mean(),
         "conv-mi": conv_mi.mean(),
         "sent-mi": sent_mi.mean(),
         "state-mi": state_mi.mean(),
         "spkr-mi": spkr_mi.mean(),
         "loss": loss.mean(),
         "loss-recon": loss_recon.mean(),
         "loss-sent": sent_loss.sum(-1).mean(),
         "loss-sent-turn": sent_loss_total / num_turns,
         "loss-sent-word": sent_loss_total / num_words,
         "ppl-turn": (sent_loss_total / num_turns).exp(),
         "ppl-word": (sent_loss_total / num_words).exp(),
         "loss-state": state_loss.sum(-1).mean(),
         "loss-state-turn": state_loss_total / num_turns,
         "loss-state-asv": state_loss_total / state_logit_mask.sum(),
         "loss-spkr": spkr_loss.sum(-1).mean(),
         "loss-spkr-turn": spkr_loss_total / num_turns,
         "kld-weight": torch.tensor(kld_weight),
         "kld-sent": kld_sent.sum(-1).mean(),
         "kld-sent-turn": kld_sent.sum() / num_turns,
         "kld-conv": kld_conv.sum(-1).mean(),
         "kld-state": kld_state.sum(-1).mean(),
         "kld-state-turn": kld_state.sum() / num_turns,
         "kld-spkr": kld_spkr.sum(-1).mean(),
         "kld-spkr-turn": kld_spkr.sum() / num_turns,
         "kld": loss_kld.mean()
     }
     for spkr_idx, spkr in self.vocabs.speaker.i2f.items():
         if spkr == "<unk>":
             continue
         spkr_mask = p_target == spkr_idx
         spkr_turns = spkr_mask.sum()
         spkr_words = sent_lens.masked_fill(~spkr_mask, 0).sum()
         spkr_state_mask = \
             state_logit_mask.masked_fill(~spkr_mask.unsqueeze(-1), 0)
         spkr_sent_loss = sent_loss.masked_fill(~spkr_mask, 0).sum()
         spkr_state_loss = state_loss.masked_fill(~spkr_mask, 0).sum()
         spkr_spkr_loss = spkr_loss.masked_fill(~spkr_mask, 0).sum()
         spkr_kld_sent = kld_sent.masked_fill(~spkr_mask, 0).sum()
         spkr_kld_state = kld_state.masked_fill(~spkr_mask, 0).sum()
         spkr_kld_spkr = kld_spkr.masked_fill(~spkr_mask, 0).sum()
         spkr_stats = {
             "loss-sent": spkr_sent_loss / batch_size,
             "loss-sent-turn": spkr_sent_loss / spkr_turns,
             "loss-sent-word": spkr_sent_loss / spkr_words,
             "ppl-turn": (spkr_sent_loss / spkr_turns).exp(),
             "ppl-word": (spkr_sent_loss / spkr_words).exp(),
             "loss-state": spkr_state_loss / batch_size,
             "loss-state-turn": spkr_state_loss / spkr_turns,
             "loss-state-asv": spkr_state_loss / spkr_state_mask.sum(),
             "loss-spkr": spkr_spkr_loss / batch_size,
             "loss-spkr-turn": spkr_spkr_loss / spkr_turns,
             "kld-sent": spkr_kld_sent / batch_size,
             "kld-sent-turn": spkr_kld_sent / spkr_turns,
             "kld-state": spkr_kld_state / batch_size,
             "kld-state-turn": spkr_kld_state / spkr_turns,
             "kld-spkr": spkr_kld_spkr / batch_size,
             "kld-spkr-turn": spkr_kld_spkr / spkr_turns,
         }
         stats.update({f"{k}-{spkr}": v for k, v in spkr_stats.items()})
     return loss.mean(), stats