Example 1
def evaluate_valid_loss(data_loader, model, opt):
    model.eval()
    evaluation_loss_sum = 0.0
    total_trg_tokens = 0
    n_batch = 0
    loss_compute_time_total = 0.0
    forward_time_total = 0.0

    with torch.no_grad():
        for batch_i, batch in enumerate(data_loader):
            # load the one2many dataset
            src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _ = batch
            # number of target keyphrases for each example, len=batch_size
            num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]

            max_num_oov = max(len(oov) for oov in oov_lists)  # max number of OOV words in this batch

            batch_size = src.size(0)
            n_batch += batch_size  # despite the name, this accumulates the number of examples seen

            # move data to GPU if available
            src = src.to(opt.device)
            src_mask = src_mask.to(opt.device)
            trg = trg.to(opt.device)
            trg_mask = trg_mask.to(opt.device)
            src_oov = src_oov.to(opt.device)
            trg_oov = trg_oov.to(opt.device)

            start_time = time.time()
            decoder_dist, attention_dist = model(src, src_lens, trg, src_oov,
                                                 max_num_oov, src_mask,
                                                 num_trgs)
            forward_time = time_since(start_time)
            forward_time_total += forward_time

            start_time = time.time()

            loss = masked_cross_entropy(decoder_dist, trg_oov, trg_mask)
            loss_compute_time = time_since(start_time)
            loss_compute_time_total += loss_compute_time

            evaluation_loss_sum += loss.item()
            total_trg_tokens += sum(trg_lens)

    eval_loss_stat = LossStatistics(evaluation_loss_sum,
                                    total_trg_tokens,
                                    n_batch,
                                    forward_time=forward_time_total,
                                    loss_compute_time=loss_compute_time_total)
    return eval_loss_stat
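
The examples in this listing rely on project helpers that are not shown (time_since, masked_cross_entropy, LossStatistics). Below is a minimal sketch of time_since and of the basic three-argument masked_cross_entropy that Example 1 assumes; it treats the decoder distribution as probabilities over the (possibly extended) vocabulary, and the actual project implementations may differ.

import time

import torch


def time_since(start_time):
    # elapsed wall-clock seconds since start_time (a time.time() timestamp);
    # a hypothetical reconstruction of the project's helper
    return time.time() - start_time


def masked_cross_entropy(class_dist, target, mask):
    # class_dist: [batch, trg_seq_len, vocab_size], probabilities at each step
    # target:     [batch, trg_seq_len], gold word indices
    # mask:       [batch, trg_seq_len], 1.0 for real tokens, 0.0 for padding
    # gather the predicted probability of each gold token: [batch, trg_seq_len]
    gold_probs = class_dist.gather(2, target.unsqueeze(2)).squeeze(2)
    # negative log-likelihood per token, masked to ignore padding
    token_nll = -torch.log(gold_probs + 1e-12) * mask
    return token_nll.sum()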
Example 2
def evaluate_loss(data_loader, model, opt):
    model.eval()
    evaluation_loss_sum = 0.0
    total_trg_tokens = 0
    n_batch = 0
    loss_compute_time_total = 0.0
    forward_time_total = 0.0

    with torch.no_grad():
        for batch_i, batch in enumerate(data_loader):
            if not opt.one2many:  # load one2one dataset
                src, src_lens, src_mask, trg, trg_lens, trg_mask, src_oov, trg_oov, oov_lists, title, title_oov, title_lens, title_mask = batch
            else:  # load one2many dataset
                src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _, title, title_oov, title_lens, title_mask = batch
                # number of target keyphrases for each example, len=batch_size
                num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]

            max_num_oov = max(len(oov) for oov in oov_lists)  # max number of OOV words in this batch

            batch_size = src.size(0)
            n_batch += batch_size  # despite the name, this accumulates the number of examples seen

            # move data to GPU if available
            src = src.to(opt.device)
            src_mask = src_mask.to(opt.device)
            trg = trg.to(opt.device)
            trg_mask = trg_mask.to(opt.device)
            src_oov = src_oov.to(opt.device)
            trg_oov = trg_oov.to(opt.device)
            if opt.title_guided:
                title = title.to(opt.device)
                title_mask = title_mask.to(opt.device)
                # title_oov = title_oov.to(opt.device)

            start_time = time.time()
            if not opt.one2many:
                decoder_dist, h_t, attention_dist, encoder_final_state, coverage, _, _, _ = model(
                    src,
                    src_lens,
                    trg,
                    src_oov,
                    max_num_oov,
                    src_mask,
                    title=title,
                    title_lens=title_lens,
                    title_mask=title_mask)
            else:
                decoder_dist, h_t, attention_dist, encoder_final_state, coverage, _, _, _ = model(
                    src,
                    src_lens,
                    trg,
                    src_oov,
                    max_num_oov,
                    src_mask,
                    num_trgs,
                    title=title,
                    title_lens=title_lens,
                    title_mask=title_mask)
            forward_time = time_since(start_time)
            forward_time_total += forward_time

            start_time = time.time()
            if opt.copy_attention:  # Compute the loss using target with oov words
                loss = masked_cross_entropy(decoder_dist,
                                            trg_oov,
                                            trg_mask,
                                            trg_lens,
                                            opt.coverage_attn,
                                            coverage,
                                            attention_dist,
                                            opt.lambda_coverage,
                                            coverage_loss=False)
            else:  # Compute the loss using target without oov words
                loss = masked_cross_entropy(decoder_dist,
                                            trg,
                                            trg_mask,
                                            trg_lens,
                                            opt.coverage_attn,
                                            coverage,
                                            attention_dist,
                                            opt.lambda_coverage,
                                            coverage_loss=False)
            loss_compute_time = time_since(start_time)
            loss_compute_time_total += loss_compute_time

            evaluation_loss_sum += loss.item()
            total_trg_tokens += sum(trg_lens)

    eval_loss_stat = LossStatistics(evaluation_loss_sum,
                                    total_trg_tokens,
                                    n_batch,
                                    forward_time=forward_time_total,
                                    loss_compute_time=loss_compute_time_total)
    return eval_loss_stat
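
LossStatistics is also project-specific. A minimal sketch consistent with how it is constructed above; the xent method and the exact field names are assumptions, not the project's actual class.

class LossStatistics:
    # accumulates the summed loss, token count, and timing info for reporting
    def __init__(self, loss=0.0, n_tokens=0, n_batch=0, forward_time=0.0,
                 loss_compute_time=0.0, backward_time=0.0):
        self.loss = loss
        self.n_tokens = n_tokens
        self.n_batch = n_batch
        self.forward_time = forward_time
        self.loss_compute_time = loss_compute_time
        self.backward_time = backward_time

    def xent(self):
        # average cross-entropy per target token (assumed reporting helper)
        return self.loss / self.n_tokens if self.n_tokens > 0 else 0.0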
Example 3
def train_one_batch(batch, model, optimizer, opt, batch_i):
    # load one2many data
    """
    src: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], with oov words replaced by unk idx
    src_lens: a list containing the length of src sequences for each batch, with len=batch
    src_mask: a FloatTensor, [batch, src_seq_len]
    src_oov: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], contains the index of oov words (used by copy)
    trg: LongTensor [batch, trg_seq_len], each target trg[i] contains the indices of a set of concatenated keyphrases, separated by opt.word2idx[pykp.io.SEP_WORD]
         if opt.delimiter_type = 0, SEP_WORD=<sep>, if opt.delimiter_type = 1, SEP_WORD=<eok>
    trg_lens: a list containing the length of trg sequences for each batch, with len=batch
    trg_mask: a FloatTensor, [batch, trg_seq_len]
    trg_oov: same as trg_oov, but all unk words are replaced with temporary idx, e.g. 50000, 50001 etc.
    """
    src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _ = batch
    # number of target keyphrases for each example, len=batch_size
    num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]

    max_num_oov = max(len(oov) for oov in oov_lists)  # max number of OOV words in this batch

    # move data to GPU if available
    src = src.to(opt.device)
    src_mask = src_mask.to(opt.device)
    trg = trg.to(opt.device)
    trg_mask = trg_mask.to(opt.device)
    src_oov = src_oov.to(opt.device)
    trg_oov = trg_oov.to(opt.device)

    optimizer.zero_grad()

    start_time = time.time()

    decoder_dist, attention_dist = model(src,
                                         src_lens,
                                         trg,
                                         src_oov,
                                         max_num_oov,
                                         src_mask,
                                         num_trgs=num_trgs)
    forward_time = time_since(start_time)

    start_time = time.time()

    loss = masked_cross_entropy(decoder_dist, trg_oov, trg_mask)

    loss_compute_time = time_since(start_time)

    total_trg_tokens = sum(trg_lens)

    if math.isnan(loss.item()):
        print("Batch i: %d" % batch_i)
        print("src")
        print(src)
        print(src_oov)
        print(src_str_list)
        print(src_lens)
        print(src_mask)
        print("trg")
        print(trg)
        print(trg_oov)
        print(trg_str_2dlist)
        print(trg_lens)
        print(trg_mask)
        print("oov list")
        print(oov_lists)
        print("Decoder")
        print(decoder_dist)
        print(attention_dist)
        raise ValueError("Loss is NaN")

    if opt.loss_normalization == "tokens":  # use number of target tokens to normalize the loss
        normalization = total_trg_tokens
    elif opt.loss_normalization == 'batches':  # use batch_size to normalize the loss
        normalization = src.size(0)
    else:
        raise ValueError('The type of loss normalization is invalid.')

    assert normalization > 0, 'normalization should be a positive number'

    start_time = time.time()
    # back propagation on the normalized loss
    loss.div(normalization).backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        grad_norm_before_clipping = nn.utils.clip_grad_norm_(
            model.parameters(), opt.max_grad_norm)

    optimizer.step()

    # construct a statistic object for the loss
    stat = LossStatistics(loss.item(),
                          total_trg_tokens,
                          n_batch=1,
                          forward_time=forward_time,
                          loss_compute_time=loss_compute_time,
                          backward_time=backward_time)

    return stat, decoder_dist.detach()
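
For context, a hedged sketch of how train_one_batch might be driven from an epoch loop. train_one_epoch, report_every, and the logging are illustrative assumptions (not the project's actual trainer), and the stat field names follow the LossStatistics sketch above.

def train_one_epoch(train_data_loader, model, optimizer, opt, report_every=100):
    # one pass over the training data, reporting a running per-token loss
    model.train()
    running_loss, running_tokens = 0.0, 0
    for batch_i, batch in enumerate(train_data_loader):
        stat, _ = train_one_batch(batch, model, optimizer, opt, batch_i)
        running_loss += stat.loss
        running_tokens += stat.n_tokens
        if (batch_i + 1) % report_every == 0:
            print("batch %d: avg loss per token = %.4f" %
                  (batch_i + 1, running_loss / max(running_tokens, 1)))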
Example 4
def evaluate_loss(data_loader, model, ntm_model, opt):
    model.eval()
    ntm_model.eval()
    evaluation_loss_sum = 0.0
    total_trg_tokens = 0
    n_batch = 0
    loss_compute_time_total = 0.0
    forward_time_total = 0.0
    print("Evaluate loss for %d batches" % len(data_loader))
    with torch.no_grad():
        for batch_i, batch in enumerate(data_loader):
            if not opt.one2many:  # load one2one dataset
                src, src_lens, src_mask, trg, trg_lens, trg_mask, src_oov, trg_oov, oov_lists, src_bow = batch
            else:  # load one2many dataset
                src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _ = batch
                # number of target keyphrases for each example, len=batch_size
                num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]

            max_num_oov = max(len(oov) for oov in oov_lists)  # max number of OOV words in this batch

            batch_size = src.size(0)
            n_batch += batch_size  # despite the name, this accumulates the number of examples seen

            # move data to GPU if available
            src = src.to(opt.device)
            src_mask = src_mask.to(opt.device)
            trg = trg.to(opt.device)
            trg_mask = trg_mask.to(opt.device)
            src_oov = src_oov.to(opt.device)
            trg_oov = trg_oov.to(opt.device)

            if opt.use_topic_represent:
                src_bow = src_bow.to(opt.device)
                src_bow_norm = F.normalize(src_bow)
                if opt.topic_type == 'z':
                    topic_represent, _, _, _, _ = ntm_model(src_bow_norm)
                else:
                    _, topic_represent, _, _, _ = ntm_model(src_bow_norm)
            else:
                topic_represent = None

            start_time = time.time()

            # the model is always called in the one2one setting here; num_trgs is not passed even when opt.one2many is set
            decoder_dist, h_t, attention_dist, encoder_final_state, coverage, _, _, _ \
                = model(src, src_lens, trg, src_oov, max_num_oov, src_mask, topic_represent)

            forward_time = time_since(start_time)
            forward_time_total += forward_time

            start_time = time.time()
            if opt.copy_attention:  # Compute the loss using target with oov words
                loss = masked_cross_entropy(decoder_dist, trg_oov, trg_mask, trg_lens,
                                            opt.coverage_attn, coverage, attention_dist, opt.lambda_coverage,
                                            coverage_loss=False)
            else:  # Compute the loss using target without oov words
                loss = masked_cross_entropy(decoder_dist, trg, trg_mask, trg_lens,
                                            opt.coverage_attn, coverage, attention_dist, opt.lambda_coverage,
                                            coverage_loss=False)
            loss_compute_time = time_since(start_time)
            loss_compute_time_total += loss_compute_time

            evaluation_loss_sum += loss.item()
            total_trg_tokens += sum(trg_lens)

            if (batch_i + 1) % max(len(data_loader) // 5, 1) == 0:
                print("Eval: %d/%d batches, current avg loss: %.3f" %
                      ((batch_i + 1), len(data_loader), evaluation_loss_sum / total_trg_tokens))

    eval_loss_stat = LossStatistics(evaluation_loss_sum, total_trg_tokens, n_batch, forward_time=forward_time_total,
                                    loss_compute_time=loss_compute_time_total)
    return eval_loss_stat
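
Examples 2 and 4 call masked_cross_entropy with extra coverage-related arguments. A hedged sketch of that extended signature, assuming a See et al. (2017)-style coverage penalty; the argument order follows the calls above, but the real implementation may differ.

import torch


def masked_cross_entropy(class_dist, target, mask, trg_lens=None,
                         coverage_attn=False, coverage=None,
                         attention_dist=None, lambda_coverage=0.0,
                         coverage_loss=False):
    # NLL term, identical to the basic sketch after Example 1
    gold_probs = class_dist.gather(2, target.unsqueeze(2)).squeeze(2)
    loss = (-torch.log(gold_probs + 1e-12) * mask).sum()
    if coverage_attn and coverage_loss:
        # coverage penalty: sum over the source dimension of
        # min(attention, coverage), masked like the NLL term
        # attention_dist, coverage: [batch, trg_seq_len, src_seq_len]
        cov_penalty = (torch.min(attention_dist, coverage).sum(2) * mask).sum()
        loss = loss + lambda_coverage * cov_penalty
    return loss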
Example 5
def train_one_batch(batch,
                    model,
                    optimizer,
                    opt,
                    batch_i,
                    source_representation_queue=None):
    if not opt.one2many:  # load one2one data
        src, src_lens, src_mask, trg, trg_lens, trg_mask, src_oov, trg_oov, oov_lists, title, title_oov, title_lens, title_mask = batch
        """
        src: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], with oov words replaced by unk idx
        src_lens: a list containing the length of src sequences for each batch, with len=batch
        src_mask: a FloatTensor, [batch, src_seq_len]
        trg: a LongTensor containing the word indices of target sentences, [batch, trg_seq_len]
        trg_lens: a list containing the length of trg sequences for each batch, with len=batch
        trg_mask: a FloatTensor, [batch, trg_seq_len]
        src_oov: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], contains the index of oov words (used by copy)
        trg_oov: a LongTensor containing the word indices of target sentences, [batch, src_seq_len], contains the index of oov words (used by copy)
        """
    else:  # load one2many data
        src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _, title, title_oov, title_lens, title_mask = batch
        # number of target keyphrases for each example, len=batch_size
        num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]
        """
        trg: a LongTensor, [batch, trg_seq_len]; each trg[i] contains the indices of a set of concatenated keyphrases, separated by opt.word2idx[pykp.io.SEP_WORD]
             if opt.delimiter_type == 0, SEP_WORD=<sep>; if opt.delimiter_type == 1, SEP_WORD=<eos>
        trg_oov: same as trg, but OOV words are replaced with temporary indices, e.g. 50000, 50001, etc.
        """
    batch_size = src.size(0)
    max_num_oov = max(len(oov) for oov in oov_lists)  # max number of OOV words in this batch

    # move data to GPU if available
    src = src.to(opt.device)
    src_mask = src_mask.to(opt.device)
    trg = trg.to(opt.device)
    trg_mask = trg_mask.to(opt.device)
    src_oov = src_oov.to(opt.device)
    trg_oov = trg_oov.to(opt.device)
    if opt.title_guided:
        title = title.to(opt.device)
        title_mask = title_mask.to(opt.device)
        #title_oov = title_oov.to(opt.device)
    # title, title_oov, title_lens, title_mask

    optimizer.zero_grad()

    #if opt.one2many_mode == 0 or opt.one2many_mode == 1:
    start_time = time.time()

    if opt.use_target_encoder:  # sample encoder representations from the queue
        if len(source_representation_queue) < opt.source_representation_sample_size:
            # not enough samples collected yet; only happens in the first few steps
            source_representation_samples_2dlist = None
            source_representation_target_list = None
        else:
            source_representation_samples_2dlist = []
            source_representation_target_list = []
            for i in range(batch_size):
                # draw N encoder representations from the queue
                source_representation_samples_list = source_representation_queue.sample(
                    opt.source_representation_sample_size)
                # insert a placeholder for the ground-truth source representation at a random index
                place_holder_idx = np.random.randint(
                    0, opt.source_representation_sample_size + 1)
                source_representation_samples_list.insert(place_holder_idx, None)  # len=N+1
                # append this example's sample list to the 2d list
                source_representation_samples_2dlist.append(source_representation_samples_list)
                # store the placeholder index for this example
                source_representation_target_list.append(place_holder_idx)
    else:
        source_representation_samples_2dlist = None
        source_representation_target_list = None
        """
        if encoder_representation_samples_2dlist[0] is None and batch_i > math.ceil(
                opt.encoder_representation_sample_size / batch_size):
            # a return value of none indicates we don't have sufficient samples
            # it will only occurs in the first few training steps
            raise ValueError("encoder_representation_samples should not be none at this batch!")
        """

    if not opt.one2many:
        decoder_dist, h_t, attention_dist, encoder_final_state, coverage, delimiter_decoder_states, delimiter_decoder_states_lens, source_classification_dist = model(
            src,
            src_lens,
            trg,
            src_oov,
            max_num_oov,
            src_mask,
            sampled_source_representation_2dlist=source_representation_samples_2dlist,
            source_representation_target_list=source_representation_target_list,
            title=title,
            title_lens=title_lens,
            title_mask=title_mask)
    else:
        decoder_dist, h_t, attention_dist, encoder_final_state, coverage, delimiter_decoder_states, delimiter_decoder_states_lens, source_classification_dist = model(
            src,
            src_lens,
            trg,
            src_oov,
            max_num_oov,
            src_mask,
            num_trgs=num_trgs,
            sampled_source_representation_2dlist=source_representation_samples_2dlist,
            source_representation_target_list=source_representation_target_list,
            title=title,
            title_lens=title_lens,
            title_mask=title_mask)
    forward_time = time_since(start_time)

    if opt.use_target_encoder:
        # push each detached encoder final state onto the queue
        # encoder_final_state: [batch, memory_bank_size]
        for i in range(batch_size):
            source_representation_queue.put(encoder_final_state[i, :].detach())

    start_time = time.time()
    if opt.copy_attention:  # Compute the loss using target with oov words
        loss = masked_cross_entropy(
            decoder_dist, trg_oov, trg_mask, trg_lens, opt.coverage_attn,
            coverage, attention_dist, opt.lambda_coverage, opt.coverage_loss,
            delimiter_decoder_states, opt.orthogonal_loss,
            opt.lambda_orthogonal, delimiter_decoder_states_lens)
    else:  # Compute the loss using target without oov words
        loss = masked_cross_entropy(
            decoder_dist, trg, trg_mask, trg_lens, opt.coverage_attn, coverage,
            attention_dist, opt.lambda_coverage, opt.coverage_loss,
            delimiter_decoder_states, opt.orthogonal_loss,
            opt.lambda_orthogonal, delimiter_decoder_states_lens)

    loss_compute_time = time_since(start_time)

    #else:  # opt.one2many_mode == 2
    #    forward_time = 0
    #    loss_compute_time = 0
    #    # TODO: a for loop to accumulate loss for each keyphrase
    #    # TODO: meanwhile, accumulate the forward time and loss_compute time
    #    pass

    total_trg_tokens = sum(trg_lens)

    if math.isnan(loss.item()):
        print("Batch i: %d" % batch_i)
        print("src")
        print(src)
        print(src_oov)
        print(src_str_list)
        print(src_lens)
        print(src_mask)
        print("trg")
        print(trg)
        print(trg_oov)
        print(trg_str_2dlist)
        print(trg_lens)
        print(trg_mask)
        print("oov list")
        print(oov_lists)
        print("Decoder")
        print(decoder_dist)
        print(h_t)
        print(attention_dist)
        raise ValueError("Loss is NaN")

    if opt.loss_normalization == "tokens":  # use number of target tokens to normalize the loss
        normalization = total_trg_tokens
    elif opt.loss_normalization == 'batches':  # use batch_size to normalize the loss
        normalization = src.size(0)
    else:
        raise ValueError('The type of loss normalization is invalid.')

    assert normalization > 0, 'normalization should be a positive number'

    start_time = time.time()
    # back propagation on the normalized loss
    loss.div(normalization).backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        grad_norm_before_clipping = nn.utils.clip_grad_norm_(
            model.parameters(), opt.max_grad_norm)
        # grad_norm_after_clipping = (sum([p.grad.data.norm(2) ** 2 for p in model.parameters() if p.grad is not None])) ** (1.0 / 2)
        # logging.info('clip grad (%f -> %f)' % (grad_norm_before_clipping, grad_norm_after_clipping))

    optimizer.step()

    # Compute target encoder loss
    if opt.use_target_encoder and source_classification_dist is not None:
        start_time = time.time()
        optimizer.zero_grad()
        # convert source_representation_target_list to a LongTensor of size [batch_size, max_num_delimiters]
        max_num_delimiters = delimiter_decoder_states.size(2)
        source_representation_target = torch.LongTensor(
            source_representation_target_list).to(trg.device)  # [batch_size]
        # expand along the second dimension: the target is the same for every delimiter state of a given example
        source_representation_target = source_representation_target.view(-1, 1).repeat(
            1, max_num_delimiters)  # [batch_size, max_num_delimiters]
        # mask for the source representation classification
        source_representation_target_mask = torch.zeros(
            batch_size, max_num_delimiters).to(trg.device)
        for i in range(batch_size):
            source_representation_target_mask[i, :delimiter_decoder_states_lens[i]].fill_(1)
        # compute the masked loss
        loss_te = masked_cross_entropy(source_classification_dist,
                                       source_representation_target,
                                       source_representation_target_mask)
        loss_compute_time += time_since(start_time)
        # back propagation on the normalized loss
        start_time = time.time()
        loss_te.div(normalization).backward()
        backward_time += time_since(start_time)

        if opt.max_grad_norm > 0:
            grad_norm_before_clipping = nn.utils.clip_grad_norm_(
                model.parameters(), opt.max_grad_norm)

        optimizer.step()

    # construct a statistic object for the loss
    stat = LossStatistics(loss.item(),
                          total_trg_tokens,
                          n_batch=1,
                          forward_time=forward_time,
                          loss_compute_time=loss_compute_time,
                          backward_time=backward_time)

    return stat, decoder_dist.detach()
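
Example 5 assumes a source_representation_queue supporting put, sample, and len(). A minimal fixed-capacity sketch; the class name, deque backing, and eviction policy are assumptions about the project's actual queue.

import random
from collections import deque


class SourceRepresentationQueue:
    # fixed-capacity FIFO of detached encoder states with random sampling
    def __init__(self, capacity):
        self.queue = deque(maxlen=capacity)  # oldest entries are evicted first

    def put(self, tensor):
        self.queue.append(tensor)

    def sample(self, sample_size):
        # random sample without replacement; callers check len() beforehand
        return random.sample(list(self.queue), sample_size)

    def __len__(self):
        return len(self.queue)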
Example 6
def train_one_batch(batch, model, ntm_model, optimizer, opt, batch_i):
    # train for one batch
    src, src_lens, src_mask, trg, trg_lens, trg_mask, src_oov, trg_oov, oov_lists, src_bow = batch
    max_num_oov = max(len(oov) for oov in oov_lists)  # max number of OOV words in this batch

    # move data to GPU if available
    src = src.to(opt.device)
    src_mask = src_mask.to(opt.device)
    trg = trg.to(opt.device)
    trg_mask = trg_mask.to(opt.device)
    src_oov = src_oov.to(opt.device)
    trg_oov = trg_oov.to(opt.device)

    # model.train()
    optimizer.zero_grad()

    if opt.use_topic_represent:
        src_bow = src_bow.to(opt.device)
        src_bow_norm = F.normalize(src_bow)
        if opt.topic_type == 'z':
            topic_represent, _, recon_batch, mu, logvar = ntm_model(src_bow_norm)
        else:
            _, topic_represent, recon_batch, mu, logvar = ntm_model(src_bow_norm)

        if opt.add_two_loss:
            # joint training: the NTM's VAE loss is added to the seq2seq loss later (requires opt.use_topic_represent)
            ntm_loss = loss_function(recon_batch, src_bow, mu, logvar)
    else:
        topic_represent = None

    start_time = time.time()

    # for one2one setting
    decoder_dist, h_t, attention_dist, encoder_final_state, coverage, _, _, _ \
        = model(src, src_lens, trg, src_oov, max_num_oov, src_mask, topic_represent)

    forward_time = time_since(start_time)

    start_time = time.time()
    if opt.copy_attention:  # Compute the loss using target with oov words
        loss = masked_cross_entropy(decoder_dist, trg_oov, trg_mask, trg_lens,
                                    opt.coverage_attn, coverage, attention_dist, opt.lambda_coverage, opt.coverage_loss)
    else:  # Compute the loss using target without oov words
        loss = masked_cross_entropy(decoder_dist, trg, trg_mask, trg_lens,
                                    opt.coverage_attn, coverage, attention_dist, opt.lambda_coverage, opt.coverage_loss)

    loss_compute_time = time_since(start_time)

    total_trg_tokens = sum(trg_lens)

    if math.isnan(loss.item()):
        print("Batch i: %d" % batch_i)
        print("src")
        print(src)
        print(src_oov)
        print(src_lens)
        print(src_mask)
        print("trg")
        print(trg)
        print(trg_oov)
        print(trg_lens)
        print(trg_mask)
        print("oov list")
        print(oov_lists)
        print("Decoder")
        print(decoder_dist)
        print(h_t)
        print(attention_dist)
        raise ValueError("Loss is NaN")

    if opt.loss_normalization == "tokens":  # use number of target tokens to normalize the loss
        normalization = total_trg_tokens
    elif opt.loss_normalization == 'batches':  # use batch_size to normalize the loss
        normalization = src.size(0)
    else:
        raise ValueError('The type of loss normalization is invalid.')

    assert normalization > 0, 'normalization should be a positive number'

    start_time = time.time()
    if opt.add_two_loss:
        loss += ntm_loss
    # back propagation on the normalized loss
    loss.div(normalization).backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        grad_norm_before_clipping = nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)

    optimizer.step()

    # construct a statistic object for the loss
    stat = LossStatistics(loss.item(), total_trg_tokens, n_batch=1, forward_time=forward_time,
                          loss_compute_time=loss_compute_time, backward_time=backward_time)

    return stat, decoder_dist.detach()
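
loss_function in Example 6 is the neural topic model's VAE objective. A standard reconstruction-plus-KL sketch, assuming recon_batch holds log-probabilities over the bag-of-words vocabulary; the project's version may normalize or weight the terms differently.

import torch


def loss_function(recon_batch, src_bow, mu, logvar):
    # reconstruction term: negative log-likelihood of the bag-of-words counts
    # under the decoded log-probabilities (recon_batch assumed log-softmaxed)
    recon_loss = -(src_bow * recon_batch).sum()
    # KL divergence between q(z|x) = N(mu, sigma^2) and the standard normal prior
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld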