Example #1
File: train.py Project: zxlzr/TAKG
def main(opt):
    try:
        start_time = time.time()
        train_data_loader, train_bow_loader, valid_data_loader, valid_bow_loader, \
        word2idx, idx2word, vocab, bow_dictionary = load_data_and_vocab(opt, load_train=True)
        opt.bow_vocab_size = len(bow_dictionary)
        load_data_time = time_since(start_time)
        logging.info('Time for loading the data: %.1f' % load_data_time)

        start_time = time.time()
        model = Seq2SeqModel(opt).to(opt.device)
        ntm_model = NTM(opt).to(opt.device)
        optimizer_seq2seq, optimizer_ntm, optimizer_whole = init_optimizers(
            model, ntm_model, opt)

        train_mixture.train_model(model, ntm_model, optimizer_seq2seq,
                                  optimizer_ntm, optimizer_whole,
                                  train_data_loader, valid_data_loader,
                                  bow_dictionary, train_bow_loader,
                                  valid_bow_loader, opt)

        training_time = time_since(start_time)

        logging.info('Time for training: %.1f' % training_time)

    except Exception as e:
        logging.exception("message")
    return
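
Every example below measures elapsed time with `time_since` from `utils.time_log`. The helper itself is not shown in these snippets; the following is a minimal sketch consistent with how it is called (an assumption, the project's version may differ):

import time

def time_since(start_time):
    # Assumed behavior: wall-clock seconds elapsed since start_time,
    # matching calls like time_since(start_time) after time.time().
    return time.time() - start_time
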
Example #2
def evaluate_valid_loss(data_loader, model, opt):
    model.eval()
    evaluation_loss_sum = 0.0
    total_trg_tokens = 0
    n_batch = 0
    loss_compute_time_total = 0.0
    forward_time_total = 0.0

    with torch.no_grad():
        for batch_i, batch in enumerate(data_loader):
            # load one2many dataset
            src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _ = batch
            num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]  # a list of num of targets in each batch, with len=batch_size

            max_num_oov = max([len(oov) for oov in oov_lists])  # max number of oov for each batch

            batch_size = src.size(0)
            n_batch += batch_size

            # move data to GPU if available
            src = src.to(opt.device)
            src_mask = src_mask.to(opt.device)
            trg = trg.to(opt.device)
            trg_mask = trg_mask.to(opt.device)
            src_oov = src_oov.to(opt.device)
            trg_oov = trg_oov.to(opt.device)

            start_time = time.time()
            decoder_dist, attention_dist = model(src, src_lens, trg, src_oov,
                                                 max_num_oov, src_mask,
                                                 num_trgs)
            forward_time = time_since(start_time)
            forward_time_total += forward_time

            start_time = time.time()

            loss = masked_cross_entropy(decoder_dist, trg_oov, trg_mask)
            loss_compute_time = time_since(start_time)
            loss_compute_time_total += loss_compute_time

            evaluation_loss_sum += loss.item()
            total_trg_tokens += sum(trg_lens)

    eval_loss_stat = LossStatistics(evaluation_loss_sum,
                                    total_trg_tokens,
                                    n_batch,
                                    forward_time=forward_time_total,
                                    loss_compute_time=loss_compute_time_total)
    return eval_loss_stat
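
`LossStatistics` itself is not shown in these snippets. Below is a minimal stand-in that accepts the same arguments the examples pass and exposes the per-token statistics such an accumulator would presumably provide (the method names here are assumptions):

import math

class LossStatistics:
    """Minimal stand-in: stores a summed loss and its bookkeeping counters."""
    def __init__(self, loss=0.0, n_tokens=0, n_batch=0, forward_time=0.0,
                 loss_compute_time=0.0, backward_time=0.0):
        self.loss = loss                  # summed cross entropy
        self.n_tokens = n_tokens          # number of target tokens covered
        self.n_batch = n_batch            # number of batches (or samples) seen
        self.forward_time = forward_time
        self.loss_compute_time = loss_compute_time
        self.backward_time = backward_time

    def xent(self):
        # average cross entropy per target token
        return self.loss / self.n_tokens if self.n_tokens > 0 else 0.0

    def ppl(self):
        # perplexity, capped to avoid overflow on divergent losses
        return math.exp(min(self.xent(), 100))
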
Example #3
def main(opt):
    try:
        start_time = time.time()
        train_data_loader, valid_data_loader, word2idx, idx2word, vocab = load_data_and_vocab(opt, load_train=True)
        load_data_time = time_since(start_time)
        logging.info('Time for loading the data: %.1f' % load_data_time)
        start_time = time.time()
        model = init_model(opt)
        optimizer_ml, optimizer_rl, criterion = init_optimizer_criterion(model, opt)
        if opt.train_ml:
            train_ml.train_model(model, optimizer_ml, optimizer_rl, criterion, train_data_loader, valid_data_loader, opt)
        else:
            train_rl.train_model(model, optimizer_ml, optimizer_rl, criterion, train_data_loader, valid_data_loader, opt)
        training_time = time_since(start_time)
        logging.info('Time for training: %.1f' % training_time)
    except Exception as e:
        logging.exception("message")
    return
Example #4
def main(opt):
    start_time = time.time()
    train_bow_loader, valid_bow_loader, word2idx, idx2word, vocab, bow_dictionary \
                                        = load_data_and_vocab(opt, load_train=True)
    opt.bow_vocab_size = len(bow_dictionary)
    load_data_time = time_since(start_time)
    logging.info('Time for loading the data: %.1f' % load_data_time)

    start_time = time.time()
    ntm_model = NTM(opt).to(opt.device)
    optimizer_ntm = init_optimizers(ntm_model, opt)

    train_model.train_model(ntm_model, optimizer_ntm, bow_dictionary,
                            train_bow_loader, valid_bow_loader, opt)

    training_time = time_since(start_time)

    logging.info('Time for training: %.1f' % training_time)

    return
Example #5
def main(opt):
    try:
        start_time = time.time()
        train_data_loader, valid_data_loader, word2idx, idx2word, vocab = load_data_and_vocab(
            opt, load_train=True)
        load_data_time = time_since(start_time)
        logging.info('Time for loading the data: %.1f' % load_data_time)
        start_time = time.time()
        model = init_model(opt)

        optimizer = Adam(params=filter(lambda p: p.requires_grad,
                                       model.parameters()),
                         lr=opt.learning_rate)
        train_model(model, optimizer, train_data_loader, valid_data_loader,
                    opt)

        training_time = time_since(start_time)
        logging.info('Time for training: %.1f' % training_time)
    except Exception as e:
        logging.exception("")
    return
Example #6
def main(opt):
    try:
        start_time = time.time()
        test_data_loader, word2idx, idx2word, vocab = load_data_and_vocab(opt, load_train=False)
        model = init_pretrained_model(opt)
        load_data_time = time_since(start_time)  # measure after the data and model are actually loaded
        logging.info('Time for loading the data and model: %.1f' % load_data_time)
        start_time = time.time()

        predict(test_data_loader, model, opt)

        total_testing_time = time_since(start_time)
        logging.info('Time for a complete testing: %.1f' % total_testing_time)
        print('Time for a complete testing: %.1f' % total_testing_time)
        sys.stdout.flush()

    except Exception as e:
        logging.exception("message")
    return

Example #7
def train_one_batch(batch, model, optimizer, opt, batch_i):
    # load one2many data
    """
    src: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], with oov words replaced by unk idx
    src_lens: a list containing the length of src sequences for each batch, with len=batch
    src_mask: a FloatTensor, [batch, src_seq_len]
    src_oov: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], contains the index of oov words (used by copy)
    trg: a LongTensor, [batch, trg_seq_len], each trg[i] contains the indices of a set of concatenated keyphrases, separated by opt.word2idx[pykp.io.SEP_WORD]
         if opt.delimiter_type == 0, SEP_WORD=<sep>; if opt.delimiter_type == 1, SEP_WORD=<eok>
    trg_lens: a list containing the length of trg sequences for each batch, with len=batch
    trg_mask: a FloatTensor, [batch, trg_seq_len]
    trg_oov: same as trg, but oov words are replaced with temporary indices, e.g. 50000, 50001, etc.
    """
    src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _ = batch
    # a list of num of targets in each batch, with len=batch_size
    num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]

    max_num_oov = max([len(oov) for oov in oov_lists])  # max number of oov for each batch

    # move data to GPU if available
    src = src.to(opt.device)
    src_mask = src_mask.to(opt.device)
    trg = trg.to(opt.device)
    trg_mask = trg_mask.to(opt.device)
    src_oov = src_oov.to(opt.device)
    trg_oov = trg_oov.to(opt.device)

    optimizer.zero_grad()

    start_time = time.time()

    decoder_dist, attention_dist = model(src,
                                         src_lens,
                                         trg,
                                         src_oov,
                                         max_num_oov,
                                         src_mask,
                                         num_trgs=num_trgs)
    forward_time = time_since(start_time)

    start_time = time.time()

    loss = masked_cross_entropy(decoder_dist, trg_oov, trg_mask)

    loss_compute_time = time_since(start_time)

    total_trg_tokens = sum(trg_lens)

    if math.isnan(loss.item()):
        print("Batch i: %d" % batch_i)
        print("src")
        print(src)
        print(src_oov)
        print(src_str_list)
        print(src_lens)
        print(src_mask)
        print("trg")
        print(trg)
        print(trg_oov)
        print(trg_str_2dlist)
        print(trg_lens)
        print(trg_mask)
        print("oov list")
        print(oov_lists)
        print("Decoder")
        print(decoder_dist)
        print(attention_dist)
        raise ValueError("Loss is NaN")

    if opt.loss_normalization == "tokens":  # use number of target tokens to normalize the loss
        normalization = total_trg_tokens
    elif opt.loss_normalization == 'batches':  # use batch_size to normalize the loss
        normalization = src.size(0)
    else:
        raise ValueError('The type of loss normalization is invalid.')

    assert normalization > 0, 'normalization should be a positive number'

    start_time = time.time()
    # back propagation on the normalized loss
    loss.div(normalization).backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        grad_norm_before_clipping = nn.utils.clip_grad_norm_(
            model.parameters(), opt.max_grad_norm)

    optimizer.step()

    # construct a statistic object for the loss
    stat = LossStatistics(loss.item(),
                          total_trg_tokens,
                          n_batch=1,
                          forward_time=forward_time,
                          loss_compute_time=loss_compute_time,
                          backward_time=backward_time)

    return stat, decoder_dist.detach()
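
In this example `masked_cross_entropy` is called with just the distribution, the target, and the mask (later examples pass extra coverage-related arguments). A sketch of the three-argument core, assuming `decoder_dist` holds probabilities rather than logits, as is usual with a copy mechanism:

import torch

def masked_cross_entropy(decoder_dist, target, mask):
    # decoder_dist: [batch, trg_seq_len, vocab_size + max_num_oov], probabilities
    # target:       [batch, trg_seq_len], gold indices (oov words use extended ids)
    # mask:         [batch, trg_seq_len], 1.0 for real tokens, 0.0 for padding
    eps = 1e-12  # avoid log(0) when a gold token receives zero probability
    gold_probs = torch.gather(decoder_dist, 2, target.unsqueeze(2)).squeeze(2)
    loss = -torch.log(gold_probs + eps) * mask
    return loss.sum()  # summed; the caller divides by tokens or batch size
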
Example #8
File: evaluate.py Project: zxlzr/TAKG
def evaluate_loss(data_loader, model, ntm_model, opt):
    model.eval()
    ntm_model.eval()
    evaluation_loss_sum = 0.0
    total_trg_tokens = 0
    n_batch = 0
    loss_compute_time_total = 0.0
    forward_time_total = 0.0
    print("Evaluate loss for %d batches" % len(data_loader))
    with torch.no_grad():
        for batch_i, batch in enumerate(data_loader):
            if not opt.one2many:  # load one2one dataset
                src, src_lens, src_mask, trg, trg_lens, trg_mask, src_oov, trg_oov, oov_lists, src_bow = batch
            else:  # load one2many dataset
                src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _ = batch
                num_trgs = [len(trg_str_list) for trg_str_list in
                            trg_str_2dlist]  # a list of num of targets in each batch, with len=batch_size

            max_num_oov = max([len(oov) for oov in oov_lists])  # max number of oov for each batch

            batch_size = src.size(0)
            n_batch += batch_size

            # move data to GPU if available
            src = src.to(opt.device)
            src_mask = src_mask.to(opt.device)
            trg = trg.to(opt.device)
            trg_mask = trg_mask.to(opt.device)
            src_oov = src_oov.to(opt.device)
            trg_oov = trg_oov.to(opt.device)

            if opt.use_topic_represent:
                src_bow = src_bow.to(opt.device)
                src_bow_norm = F.normalize(src_bow)
                if opt.topic_type == 'z':
                    topic_represent, _, _, _, _ = ntm_model(src_bow_norm)
                else:
                    _, topic_represent, _, _, _ = ntm_model(src_bow_norm)
            else:
                topic_represent = None

            start_time = time.time()

            # one2one setting
            decoder_dist, h_t, attention_dist, encoder_final_state, coverage, _, _, _ \
                = model(src, src_lens, trg, src_oov, max_num_oov, src_mask, topic_represent)

            forward_time = time_since(start_time)
            forward_time_total += forward_time

            start_time = time.time()
            if opt.copy_attention:  # Compute the loss using target with oov words
                loss = masked_cross_entropy(decoder_dist, trg_oov, trg_mask, trg_lens,
                                            opt.coverage_attn, coverage, attention_dist, opt.lambda_coverage,
                                            coverage_loss=False)
            else:  # Compute the loss using target without oov words
                loss = masked_cross_entropy(decoder_dist, trg, trg_mask, trg_lens,
                                            opt.coverage_attn, coverage, attention_dist, opt.lambda_coverage,
                                            coverage_loss=False)
            loss_compute_time = time_since(start_time)
            loss_compute_time_total += loss_compute_time

            evaluation_loss_sum += loss.item()
            total_trg_tokens += sum(trg_lens)

            if (batch_i + 1) % max(1, len(data_loader) // 5) == 0:
                print("Valid: %d/%d batches, current avg loss: %.3f" %
                      ((batch_i + 1), len(data_loader), evaluation_loss_sum / total_trg_tokens))

    eval_loss_stat = LossStatistics(evaluation_loss_sum, total_trg_tokens, n_batch, forward_time=forward_time_total,
                                    loss_compute_time=loss_compute_time_total)
    return eval_loss_stat
Example #9
def evaluate_loss(data_loader, model, opt):
    model.eval()
    evaluation_loss_sum = 0.0
    total_trg_tokens = 0
    n_batch = 0
    loss_compute_time_total = 0.0
    forward_time_total = 0.0

    with torch.no_grad():
        for batch_i, batch in enumerate(data_loader):
            if not opt.one2many:  # load one2one dataset
                src, src_lens, src_mask, trg, trg_lens, trg_mask, src_oov, trg_oov, oov_lists, title, title_oov, title_lens, title_mask = batch
            else:  # load one2many dataset
                src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _, title, title_oov, title_lens, title_mask = batch
                num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]  # a list of num of targets in each batch, with len=batch_size

            max_num_oov = max([len(oov) for oov in oov_lists])  # max number of oov for each batch

            batch_size = src.size(0)
            n_batch += batch_size

            # move data to GPU if available
            src = src.to(opt.device)
            src_mask = src_mask.to(opt.device)
            trg = trg.to(opt.device)
            trg_mask = trg_mask.to(opt.device)
            src_oov = src_oov.to(opt.device)
            trg_oov = trg_oov.to(opt.device)
            if opt.title_guided:
                title = title.to(opt.device)
                title_mask = title_mask.to(opt.device)
                # title_oov = title_oov.to(opt.device)

            start_time = time.time()
            if not opt.one2many:
                decoder_dist, h_t, attention_dist, encoder_final_state, coverage, _, _, _ = model(
                    src,
                    src_lens,
                    trg,
                    src_oov,
                    max_num_oov,
                    src_mask,
                    title=title,
                    title_lens=title_lens,
                    title_mask=title_mask)
            else:
                decoder_dist, h_t, attention_dist, encoder_final_state, coverage, _, _, _ = model(
                    src,
                    src_lens,
                    trg,
                    src_oov,
                    max_num_oov,
                    src_mask,
                    num_trgs,
                    title=title,
                    title_lens=title_lens,
                    title_mask=title_mask)
            forward_time = time_since(start_time)
            forward_time_total += forward_time

            start_time = time.time()
            if opt.copy_attention:  # Compute the loss using target with oov words
                loss = masked_cross_entropy(decoder_dist,
                                            trg_oov,
                                            trg_mask,
                                            trg_lens,
                                            opt.coverage_attn,
                                            coverage,
                                            attention_dist,
                                            opt.lambda_coverage,
                                            coverage_loss=False)
            else:  # Compute the loss using target without oov words
                loss = masked_cross_entropy(decoder_dist,
                                            trg,
                                            trg_mask,
                                            trg_lens,
                                            opt.coverage_attn,
                                            coverage,
                                            attention_dist,
                                            opt.lambda_coverage,
                                            coverage_loss=False)
            loss_compute_time = time_since(start_time)
            loss_compute_time_total += loss_compute_time

            evaluation_loss_sum += loss.item()
            total_trg_tokens += sum(trg_lens)

    eval_loss_stat = LossStatistics(evaluation_loss_sum,
                                    total_trg_tokens,
                                    n_batch,
                                    forward_time=forward_time_total,
                                    loss_compute_time=loss_compute_time_total)
    return eval_loss_stat
Example #10
def train_one_batch(batch, generator, optimizer, opt, lagrangian_params=None):
    """
    src: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], with oov words replaced by unk idx
    src_lens: a list containing the length of src sequences for each batch, with len=batch
    src_mask: a FloatTensor, [batch, src_seq_len]
    src_oov: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], contains the index of oov words (used by copy)
    oov_lists: a list of oov words for each src, 2dlist
    """
    src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_sent_2d_list, _, _, _, _, _ = batch

    # move data to GPU if available
    src = src.to(opt.device)
    src_mask = src_mask.to(opt.device)
    src_oov = src_oov.to(opt.device)

    optimizer.zero_grad()

    batch_size = src.size(0)
    reward_type = opt.reward_type
    sent_level_reward = opt.sent_level_reward
    baseline = opt.baseline
    regularization_type = opt.regularization_type
    regularization_factor = opt.regularization_factor

    if regularization_type == 2:
        entropy_regularize = True
    else:
        entropy_regularize = False

    trg_sent_2d_list_tokenized = []  # each item is a list of target sentences (tokenized) for an input sample
    trg_str_list = []  # each item is the target output sequence (tokenized) for an input sample
    for trg_sent_list in trg_sent_2d_list:
        trg_sent_list = [
            trg_sent.strip().split(' ') for trg_sent in trg_sent_list
        ]
        trg_sent_2d_list_tokenized.append(trg_sent_list)
        trg_str_list.append(list(concat(trg_sent_list)))

    trg_sent_2d_list = trg_sent_2d_list_tokenized  # each item is a list of target sentences (tokenized) for an input sample

    # if use self critical as baseline, greedily decode a sequence from the model
    if baseline == 'self':
        # sample greedy prediction
        generator.model.eval()
        with torch.no_grad():
            greedy_sample_list, _, _, greedy_eos_idx_mask, _, _ = generator.sample(
                src,
                src_lens,
                src_oov,
                src_mask,
                oov_lists,
                greedy=True,
                entropy_regularize=False)
            greedy_str_list = sample_list_to_str_list(greedy_sample_list,
                                                      oov_lists, opt.idx2word,
                                                      opt.vocab_size, io.EOS,
                                                      io.UNK, opt.replace_unk,
                                                      src_str_list)
            greedy_sent_2d_list = []
            for greedy_str in greedy_str_list:
                greedy_sent_list = nltk.tokenize.sent_tokenize(
                    ' '.join(greedy_str))
                greedy_sent_list = [
                    greedy_sent.strip().split(' ')
                    for greedy_sent in greedy_sent_list
                ]
                greedy_sent_2d_list.append(greedy_sent_list)

            # compute reward of greedily decoded sequence, tensor with size [batch_size]
            baseline = compute_batch_reward(greedy_str_list,
                                            greedy_sent_2d_list,
                                            trg_str_list,
                                            trg_sent_2d_list,
                                            batch_size,
                                            reward_type=reward_type,
                                            regularization_factor=0.0,
                                            regularization_type=0,
                                            entropy=None,
                                            device=src.device)
        generator.model.train()

    # sample a sequence from the model
    # sample_list is a list of dict, {"prediction": [], "scores": [], "attention": [], "done": True}, prediction is a list of 0 dim tensors
    # log_selected_token_dist: size: [batch, output_seq_len]

    # sample sequences for multiple times
    sample_batch_size = batch_size * opt.n_sample
    src = src.repeat(opt.n_sample, 1)
    src_lens = src_lens * opt.n_sample
    src_mask = src_mask.repeat(opt.n_sample, 1)
    src_oov = src_oov.repeat(opt.n_sample, 1)
    oov_lists = oov_lists * opt.n_sample
    src_str_list = src_str_list * opt.n_sample
    trg_sent_2d_list = trg_sent_2d_list * opt.n_sample
    trg_str_list = trg_str_list * opt.n_sample
    if opt.baseline != 'none':  # repeat the greedy rewards
        #baseline = np.tile(baseline, opt.n_sample)
        baseline = baseline.repeat(opt.n_sample)  # [sample_batch_size]

    start_time = time.time()
    sample_list, log_selected_token_dist, output_mask, pred_eos_idx_mask, entropy, location_of_eos_for_each_batch = generator.sample(
        src,
        src_lens,
        src_oov,
        src_mask,
        oov_lists,
        greedy=False,
        entropy_regularize=entropy_regularize)
    pred_str_list = sample_list_to_str_list(
        sample_list, oov_lists, opt.idx2word, opt.vocab_size, io.EOS, io.UNK,
        opt.replace_unk, src_str_list
    )  # a list of word list, len(pred_word_2dlist)=sample_batch_size
    sample_time = time_since(start_time)
    max_pred_seq_len = log_selected_token_dist.size(1)

    pred_sent_2d_list = []  # each item is a list of predicted sentences (tokenized) for an input sample, used to compute summary level Rouge-l
    for pred_str in pred_str_list:
        pred_sent_list = nltk.tokenize.sent_tokenize(' '.join(pred_str))
        pred_sent_list = [
            pred_sent.strip().split(' ') for pred_sent in pred_sent_list
        ]
        pred_sent_2d_list.append(pred_sent_list)

    if entropy_regularize:
        entropy_array = entropy.data.cpu().numpy()
    else:
        entropy_array = None

    # compute the reward
    with torch.no_grad():
        if sent_level_reward:
            raise ValueError("Not implemented.")
        else:  # no sentence-level reward shaping
            # only receive reward at the end of whole sequence, tensor: [sample_batch_size]
            cumulative_reward = compute_batch_reward(
                pred_str_list,
                pred_sent_2d_list,
                trg_str_list,
                trg_sent_2d_list,
                sample_batch_size,
                reward_type=reward_type,
                regularization_factor=regularization_factor,
                regularization_type=regularization_type,
                entropy=entropy_array,
                device=src.device)
            # store the sum of cumulative reward (before baseline) for the experiment log
            cumulative_reward_sum = cumulative_reward.detach().sum(0).item()

            if opt.constrained_mdp:
                lagrangian_model, optimizer_lagrangian = lagrangian_params
                cumulative_cost = compute_batch_cost(
                    pred_str_list, pred_sent_2d_list, trg_str_list,
                    trg_sent_2d_list, sample_batch_size, opt.cost_types,
                    src.device)  # [sample_batch_size, num_cost_types]
                #cumulative_cost = torch.from_numpy(cumulative_cost_array).type(torch.FloatTensor).to(src.device)

                # cumulative_cost: [sample_batch_size, len(cost_types)]
                # subtract the regularization term: \lambda \dot C_t
                constraint_regularization = lagrangian_model.compute_regularization(
                    cumulative_cost)  # [sample_batch_size]
                cumulative_reward -= constraint_regularization

            # Subtract the cumulative reward by a baseline if needed
            if opt.baseline != 'none':
                cumulative_reward = cumulative_reward - baseline  # [sample_batch_size]
            # q value estimation for each time step equals to the (baselined) cumulative reward
            q_value_estimate = cumulative_reward.unsqueeze(1).repeat(
                1, max_pred_seq_len)  # [sample_batch_size, max_pred_seq_len]
            #q_value_estimate_array = np.tile(cumulative_reward.reshape([-1, 1]), [1, max_pred_seq_len])  # [batch, max_pred_seq_len]

    #shapped_baselined_reward = torch.gather(shapped_baselined_phrase_reward, dim=1, index=pred_phrase_idx_mask)

    # use the return as the estimation of q_value at each step

    #q_value_estimate = torch.from_numpy(q_value_estimate_array).type(torch.FloatTensor).to(src.device)
    q_value_estimate.requires_grad_(True)
    q_estimate_compute_time = time_since(start_time)

    # compute the policy gradient objective
    pg_loss = compute_pg_loss(log_selected_token_dist, output_mask,
                              q_value_estimate)

    # back propagation to compute the gradient
    if opt.loss_normalization == "samples":  # use number of target tokens to normalize the loss
        normalization = opt.n_sample
    elif opt.loss_normalization == 'batches':  # use batch_size to normalize the loss
        normalization = sample_batch_size
    else:
        normalization = 1
    start_time = time.time()
    pg_loss.div(normalization).backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        grad_norm_before_clipping = nn.utils.clip_grad_norm_(
            generator.model.parameters(), opt.max_grad_norm)

    # take a step of gradient descent
    optimizer.step()

    stat = RewardStatistics(cumulative_reward_sum, pg_loss.item(),
                            sample_batch_size, sample_time,
                            q_estimate_compute_time, backward_time)
    # (final_reward=0.0, pg_loss=0.0, n_batch=0, sample_time=0, q_estimate_compute_time=0, backward_time=0)
    # reward=0.0, pg_loss=0.0, n_batch=0, sample_time=0, q_estimate_compute_time=0, backward_time=0

    if opt.constrained_mdp:
        lagrangian_loss, lagrangian_grad_norm, violate_amount = train_lagrangian_multiplier(
            lagrangian_model, cumulative_cost, optimizer_lagrangian,
            normalization, opt.max_grad_norm)
        lagrangian_stat = LagrangianStatistics(
            lagrangian_loss=lagrangian_loss,
            n_batch=sample_batch_size,
            lagrangian_grad_norm=lagrangian_grad_norm,
            violate_amount=violate_amount)
        stat = (stat, lagrangian_stat)

    return stat, log_selected_token_dist.detach()
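
`compute_pg_loss` is not shown. Under the usual REINFORCE formulation, which the surrounding code suggests, it weights each sampled token's log-probability by the baselined return and masks out padding; a sketch under that assumption:

import torch

def compute_pg_loss(log_selected_token_dist, output_mask, q_value_estimate):
    # log_selected_token_dist: [batch, seq_len], log prob of each sampled token
    # output_mask:             [batch, seq_len], 1.0 up to (and including) EOS
    # q_value_estimate:        [batch, seq_len], baselined return per time step
    # REINFORCE: minimize -E[q * log pi]
    pg_loss = -log_selected_token_dist * q_value_estimate * output_mask
    return pg_loss.sum()
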
Example #11
def train_one_batch(batch, model, ntm_model, optimizer, opt, batch_i):
    # train for one batch
    src, src_lens, src_mask, trg, trg_lens, trg_mask, src_oov, trg_oov, oov_lists, src_bow = batch
    max_num_oov = max([len(oov) for oov in oov_lists])  # max number of oov for each batch

    # move data to GPU if available
    src = src.to(opt.device)
    src_mask = src_mask.to(opt.device)
    trg = trg.to(opt.device)
    trg_mask = trg_mask.to(opt.device)
    src_oov = src_oov.to(opt.device)
    trg_oov = trg_oov.to(opt.device)

    # model.train()
    optimizer.zero_grad()

    if opt.use_topic_represent:
        src_bow = src_bow.to(opt.device)
        src_bow_norm = F.normalize(src_bow)
        if opt.topic_type == 'z':
            topic_represent, _, recon_batch, mu, logvar = ntm_model(src_bow_norm)
        else:
            _, topic_represent, recon_batch, mu, logvar = ntm_model(src_bow_norm)

        if opt.add_two_loss:
            ntm_loss = loss_function(recon_batch, src_bow, mu, logvar)
    else:
        topic_represent = None

    start_time = time.time()

    # for one2one setting
    decoder_dist, h_t, attention_dist, encoder_final_state, coverage, _, _, _ \
        = model(src, src_lens, trg, src_oov, max_num_oov, src_mask, topic_represent)

    forward_time = time_since(start_time)

    start_time = time.time()
    if opt.copy_attention:  # Compute the loss using target with oov words
        loss = masked_cross_entropy(decoder_dist, trg_oov, trg_mask, trg_lens,
                                    opt.coverage_attn, coverage, attention_dist, opt.lambda_coverage, opt.coverage_loss)
    else:  # Compute the loss using target without oov words
        loss = masked_cross_entropy(decoder_dist, trg, trg_mask, trg_lens,
                                    opt.coverage_attn, coverage, attention_dist, opt.lambda_coverage, opt.coverage_loss)

    loss_compute_time = time_since(start_time)

    total_trg_tokens = sum(trg_lens)

    if math.isnan(loss.item()):
        print("Batch i: %d" % batch_i)
        print("src")
        print(src)
        print(src_oov)
        print(src_lens)
        print(src_mask)
        print("trg")
        print(trg)
        print(trg_oov)
        print(trg_lens)
        print(trg_mask)
        print("oov list")
        print(oov_lists)
        print("Decoder")
        print(decoder_dist)
        print(h_t)
        print(attention_dist)
        raise ValueError("Loss is NaN")

    if opt.loss_normalization == "tokens":  # use number of target tokens to normalize the loss
        normalization = total_trg_tokens
    elif opt.loss_normalization == 'batches':  # use batch_size to normalize the loss
        normalization = src.size(0)
    else:
        raise ValueError('The type of loss normalization is invalid.')

    assert normalization > 0, 'normalization should be a positive number'

    start_time = time.time()
    if opt.add_two_loss:
        loss += ntm_loss
    # back propagation on the normalized loss
    loss.div(normalization).backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        grad_norm_before_clipping = nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)

    optimizer.step()

    # construct a statistic object for the loss
    stat = LossStatistics(loss.item(), total_trg_tokens, n_batch=1, forward_time=forward_time,
                          loss_compute_time=loss_compute_time, backward_time=backward_time)

    return stat, decoder_dist.detach()
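
`loss_function` for the NTM is not shown; for a VAE-style neural topic model it is normally the bag-of-words reconstruction term plus the KL divergence to the Gaussian prior. A sketch under that assumption (the project's version may weight or batch-average the terms differently):

import torch

def loss_function(recon_batch, src_bow, mu, logvar):
    # recon_batch: [batch, bow_vocab], reconstructed distribution over BoW terms
    # src_bow:     [batch, bow_vocab], raw bag-of-words counts
    # mu, logvar:  [batch, topic_num], Gaussian posterior parameters
    recon_loss = -torch.sum(src_bow * torch.log(recon_batch + 1e-10))
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld
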
Example #12
def evaluate_loss(data_loader,
                  overall_model,
                  classification_loss_func,
                  opt,
                  print_incon_stats=False):
    overall_model.eval()
    generation_loss_sum = 0.0
    joint_loss_sum = 0.0
    classification_loss_sum = 0.0
    enc_classification_loss_sum = 0.0
    dec_classification_loss_sum = 0.0
    inconsist_loss_sum = 0.0
    total_trg_tokens = 0
    total_num_iterations = 0
    loss_compute_time_total = 0.0
    forward_time_total = 0.0
    enc_rating_preds = None
    dec_rating_preds = None
    incon_loss_preds = None

    with torch.no_grad():
        for batch_i, batch in enumerate(data_loader):
            # src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_sent_2d_list, trg, trg_oov, trg_lens, trg_mask, rating, _ = batch

            # changed by wchen to a dictionary batch
            src = batch['src_tensor']
            src_lens = batch['src_lens']
            src_mask = batch['src_mask']
            src_sent_positions = batch['src_sent_positions']
            src_sent_nums = batch['src_sent_nums']
            src_sent_mask = batch['src_sent_mask']
            src_oov = batch['src_oov_tensor']
            oov_lists = batch['oov_lists']
            src_str_list = batch['src_list_tokenized']
            trg_sent_2d_list = batch['tgt_sent_2d_list']
            trg = batch['tgt_tensor']
            trg_oov = batch['tgt_oov_tensor']
            trg_lens = batch['tgt_lens']
            trg_mask = batch['tgt_mask']
            rating = batch['rating_tensor']
            indices = batch['original_indices']

            max_num_oov = max([len(oov) for oov in oov_lists])  # max number of oov for each batch
            batch_size = src.size(0)
            total_num_iterations += 1
            # move data to GPU if available
            src = src.to(opt.device)
            src_mask = src_mask.to(opt.device)
            src_sent_positions = src_sent_positions.to(opt.device)
            src_sent_mask = src_sent_mask.to(opt.device)
            trg = trg.to(opt.device)
            trg_mask = trg_mask.to(opt.device)
            src_oov = src_oov.to(opt.device)
            trg_oov = trg_oov.to(opt.device)
            rating = rating.to(opt.device)

            start_time = time.time()

            # forward
            if overall_model.model_type == 'hre_max':
                decoder_dist, h_t, seq2seq_attention_dist, encoder_final_state, coverage, classifier_logit, classifier_attention_dist = \
                    overall_model(src, src_lens, trg, src_oov, max_num_oov, src_mask, trg_mask, src_sent_positions, src_sent_nums, src_sent_mask)
            else:
                decoder_dist, h_t, seq2seq_attention_dist, encoder_final_state, coverage, classifier_logit, classifier_attention_dist = overall_model(
                    src, src_lens, trg, src_oov, max_num_oov, src_mask,
                    trg_mask, rating, src_sent_positions, src_sent_nums,
                    src_sent_mask)

            forward_time = time_since(start_time)
            forward_time_total += forward_time

            start_time = time.time()
            if decoder_dist is not None:
                if opt.copy_attention:  # Compute the loss using target with oov words
                    generation_loss = masked_cross_entropy(
                        decoder_dist,
                        trg_oov,
                        trg_mask,
                        trg_lens,
                        opt.coverage_attn,
                        coverage,
                        seq2seq_attention_dist,
                        opt.lambda_coverage,
                        coverage_loss=False)
                else:  # Compute the loss using target without oov words
                    generation_loss = masked_cross_entropy(
                        decoder_dist,
                        trg,
                        trg_mask,
                        trg_lens,
                        opt.coverage_attn,
                        coverage,
                        seq2seq_attention_dist,
                        opt.lambda_coverage,
                        coverage_loss=False)
            else:
                generation_loss = torch.Tensor([0.0]).to(opt.device)

            # normalize generation loss
            num_trg_tokens = sum(trg_lens)
            normalized_generation_loss = generation_loss.div(num_trg_tokens)

            # compute loss of classification
            if print_incon_stats:
                assert isinstance(classifier_logit, tuple)

            if classifier_logit is not None:
                if isinstance(classifier_logit, tuple):
                    # from multi_view_model
                    enc_classifier_logit = classifier_logit[0]
                    dec_classifier_logit = classifier_logit[1]
                    enc_normalized_classification_loss = classification_loss_func(
                        classifier_logit[0],
                        rating)  # normalized by batch size already
                    dec_normalized_classification_loss = classification_loss_func(
                        classifier_logit[1],
                        rating)  # normalized by batch size already
                    # compute loss of inconsistency for the multi view model
                    if opt.inconsistency_loss_type != "None" or print_incon_stats:
                        inconsistency_loss = inconsistency_loss_func(
                            classifier_logit[0], classifier_logit[1],
                            opt.inconsistency_loss_type,
                            opt.detach_dec_incosist_loss)
                    else:
                        inconsistency_loss = torch.Tensor([0.0]).to(opt.device)
                else:
                    enc_classifier_logit = classifier_logit
                    dec_classifier_logit = None
                    enc_normalized_classification_loss = classification_loss_func(
                        classifier_logit,
                        rating)  # normalized by batch size already
                    dec_normalized_classification_loss = torch.Tensor(
                        [0.0]).to(opt.device)
                    inconsistency_loss = torch.Tensor([0.0]).to(opt.device)
            else:
                enc_classifier_logit = None
                dec_classifier_logit = None
                enc_normalized_classification_loss = torch.Tensor([0.0]).to(
                    opt.device)
                dec_normalized_classification_loss = torch.Tensor([0.0]).to(
                    opt.device)
                inconsistency_loss = torch.Tensor([0.0]).to(opt.device)

            total_normalized_classification_loss = opt.class_loss_internal_enc_weight * enc_normalized_classification_loss + \
                                                   opt.class_loss_internal_dec_weight * dec_normalized_classification_loss

            # compute validation performance
            if enc_rating_preds is None and dec_rating_preds is None:
                if opt.ordinal:
                    enc_rating_preds = binary_results_to_rating_preds(
                        enc_classifier_logit.detach().cpu().numpy(
                        )) if enc_classifier_logit is not None else None
                    dec_rating_preds = binary_results_to_rating_preds(
                        dec_classifier_logit.detach().cpu().numpy(
                        )) if dec_classifier_logit is not None else None
                else:
                    enc_rating_preds = enc_classifier_logit.detach().cpu(
                    ).numpy() if enc_classifier_logit is not None else None
                    dec_rating_preds = dec_classifier_logit.detach().cpu(
                    ).numpy() if dec_classifier_logit is not None else None
                    # if print_incon_stats:
                    #     incon_loss_preds = inconsistency_loss.detach().cpu().numpy()
                out_label_ids = rating.detach().cpu().numpy()
            else:
                if opt.ordinal:
                    enc_rating_preds = np.append(
                        enc_rating_preds,
                        binary_results_to_rating_preds(
                            enc_classifier_logit.detach().cpu().numpy()),
                        axis=0) if enc_classifier_logit is not None else None
                    dec_rating_preds = np.append(
                        dec_rating_preds,
                        binary_results_to_rating_preds(
                            dec_classifier_logit.detach().cpu().numpy()),
                        axis=0) if dec_classifier_logit is not None else None
                else:
                    enc_rating_preds = np.append(
                        enc_rating_preds,
                        enc_classifier_logit.detach().cpu().numpy(),
                        axis=0) if enc_classifier_logit is not None else None
                    dec_rating_preds = np.append(
                        dec_rating_preds,
                        dec_classifier_logit.detach().cpu().numpy(),
                        axis=0) if dec_classifier_logit is not None else None
                    # if print_incon_stats:
                    #     incon_loss_preds = np.append(incon_loss_preds, inconsistency_loss.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids,
                                          rating.detach().cpu().numpy(),
                                          axis=0)
            # joint loss
            joint_loss = opt.gen_loss_weight * normalized_generation_loss + opt.class_loss_weight * total_normalized_classification_loss + opt.inconsistency_loss_weight * inconsistency_loss

            loss_compute_time = time_since(start_time)
            loss_compute_time_total += loss_compute_time

            classification_loss_sum += total_normalized_classification_loss.item()
            enc_classification_loss_sum += enc_normalized_classification_loss.item()
            dec_classification_loss_sum += dec_normalized_classification_loss.item()
            inconsist_loss_sum += inconsistency_loss.item()
            joint_loss_sum += joint_loss.item()
            generation_loss_sum += generation_loss.item()
            total_trg_tokens += num_trg_tokens

    if not opt.ordinal:
        # merged preds
        if enc_rating_preds is not None and dec_rating_preds is not None:
            merged_rating_preds = (enc_rating_preds + dec_rating_preds) / 2
            merged_rating_preds = np.argmax(merged_rating_preds, axis=1)
        else:
            merged_rating_preds = None
        enc_rating_preds = np.argmax(
            enc_rating_preds, axis=1) if enc_rating_preds is not None else None
        dec_rating_preds = np.argmax(
            dec_rating_preds, axis=1) if dec_rating_preds is not None else None

    if print_incon_stats:
        inconsistency_statistics(out_label_ids, enc_rating_preds,
                                 dec_rating_preds, merged_rating_preds)

    enc_classification_result = acc_and_macro_f1(
        enc_rating_preds, out_label_ids) if enc_rating_preds is not None else {
            "acc": 0.0,
            "f1": 0.0,
            "acc_and_f1": 0.0
        }
    dec_classification_result = acc_and_macro_f1(
        dec_rating_preds,
        out_label_ids) if dec_rating_preds is not None else None
    loss_stat = JointLossStatistics(joint_loss_sum,
                                    generation_loss_sum,
                                    enc_classification_loss_sum,
                                    dec_classification_loss_sum,
                                    inconsist_loss_sum,
                                    total_num_iterations,
                                    total_trg_tokens,
                                    forward_time=forward_time_total,
                                    loss_compute_time=loss_compute_time_total)
    # joint_loss=0.0, generation_loss=0.0, classification_loss=0.0, n_iterations=0, n_tokens=0, forward_time=0.0, loss_compute_time=0.0, backward_time=0.0

    return loss_stat, (enc_classification_result, dec_classification_result)
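
In the ordinal branch, `binary_results_to_rating_preds` presumably turns a row of per-threshold binary outputs into a single rating by counting the thresholds that are passed. A hypothetical sketch of that conversion:

import numpy as np

def binary_results_to_rating_preds(binary_outputs, threshold=0.5):
    # binary_outputs: [batch, num_classes - 1] sigmoid outputs, one per ordinal
    # threshold ("rating > k"); the predicted rating is the count of passed thresholds
    return (np.asarray(binary_outputs) > threshold).sum(axis=1)
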
Example #13
def main():
    clip = 5
    start_time = time.time()
    train_data_loader, valid_data_loader, word2idx, idx2word, vocab = load_data_and_vocab(
        opt, load_train=True)
    load_data_time = time_since(start_time)
    print(idx2word[5])
    logging.info('Time for loading the data: %.1f' % load_data_time)

    model = Seq2SeqModel(opt)
    #model = model.device()
    #print("The Device is",opt.gpuid)
    #model = model.to(devices)
    model = model.to(devices)

    # model.load_state_dict(torch.load("model/kp20k.ml.one2many.cat.copy.bi-directional.20190628-114655/kp20k.ml.one2many.cat.copy.bi-directional.epoch=2.batch=54573.total_batch=116000.model"))
    model.load_state_dict(
        torch.load(
            "model/kp20k.ml.one2many.cat.copy.bi-directional.20190715-132016/kp20k.ml.one2many.cat.copy.bi-directional.epoch=3.batch=26098.total_batch=108000.model"
        ))
    generator = SequenceGenerator(model,
                                  bos_idx=opt.word2idx[pykp.io.BOS_WORD],
                                  eos_idx=opt.word2idx[pykp.io.EOS_WORD],
                                  pad_idx=opt.word2idx[pykp.io.PAD_WORD],
                                  peos_idx=opt.word2idx[pykp.io.PEOS_WORD],
                                  beam_size=1,
                                  max_sequence_length=opt.max_length,
                                  copy_attn=opt.copy_attention,
                                  coverage_attn=opt.coverage_attn,
                                  review_attn=opt.review_attn,
                                  cuda=opt.gpuid > -1)

    init_perturb_std = opt.init_perturb_std
    final_perturb_std = opt.final_perturb_std
    perturb_decay_factor = opt.perturb_decay_factor
    perturb_decay_mode = opt.perturb_decay_mode

    D_model = Discriminator(opt.vocab_size, embedding_dim, hidden_dim,
                            n_layers, opt.word2idx[pykp.io.PAD_WORD])

    print("The Discriminator statistics are ", D_model)

    if torch.cuda.is_available():
        D_model = D_model.to(devices)

    D_model.train()

    D_optimizer = torch.optim.Adam(D_model.parameters(), lr=0.001)

    print("gdsf")
    total_epochs = 5
    for epoch in range(total_epochs):

        total_batch = 0
        print("Starting with epoch:", epoch)
        for batch_i, batch in enumerate(train_data_loader):
            total_batch += 1
            D_optimizer.zero_grad()

            if perturb_decay_mode == 0:  # do not decay
                perturb_std = init_perturb_std
            elif perturb_decay_mode == 1:  # exponential decay
                perturb_std = final_perturb_std + (
                    init_perturb_std - final_perturb_std) * math.exp(
                        -1. * total_batch * perturb_decay_factor)
            elif perturb_decay_mode == 2:  # steps decay
                perturb_std = init_perturb_std * math.pow(
                    perturb_decay_factor, math.floor((1 + total_batch) / 4000))

            avg_batch_loss, real_r, fake_r = train_one_batch(
                D_model, batch, generator, opt, perturb_std)
            #            print("Currently loss is",avg_batch_loss.item())
            #            print("Currently real loss is",real_r.item())
            #            print("Currently fake loss is",fake_r.item())
            #            state_dfs = D_model.state_dict()
            #            torch.save(state_dfs,"Checkpoint_" + str(epoch) + ".pth.tar")
            #

            if batch_i % 350 == 0:
                print("Currently loss is", avg_batch_loss.item())
                print("Currently real loss is", real_r.item())
                print("Currently fake loss is", fake_r.item())

                print("Saving the file ...............----------->>>>>")
                state_dfs = D_model.state_dict()
                torch.save(
                    state_dfs, "Discriminator_checkpts/D_model_combined" +
                    str(epoch) + ".pth.tar")

            avg_batch_loss.backward()  # compute gradients before clipping
            torch.nn.utils.clip_grad_norm_(D_model.parameters(), clip)
            D_optimizer.step()
            #sys.exit()

            #sys.exit()

        print("Saving the file ...............----------->>>>>")
        state_dfs = D_model.state_dict()
        torch.save(
            state_dfs, "Discriminator_checkpts/D_model_combined" + str(epoch) +
            ".pth.tar")
Example #14
def main(opt):
    try:
        start_time = time.time()

        # construct vocab
        with open(join(opt.data, 'vocab_cnt.pkl'), 'rb') as f:
            wc = pkl.load(f)
        word2idx, idx2word = io.make_vocab(wc, opt.v_size)
        opt.word2idx = word2idx
        opt.idx2word = idx2word

        # construct the rating-memory token tensors, if enabled
        if opt.rating_memory_pred:
            rating_tokens_tensor = []
            for i in range(1, 6):
                vocab_i = os.path.join(
                    opt.data,
                    'rating_{}_vocab_counter_no_stop_word_and_punc.pkl'.format(
                        i))
                vocab_i = pkl.load(open(vocab_i, 'rb'))
                vocab_i = vocab_i.most_common(opt.rating_v_size)
                vocab_tokens_i = [w[0] for w in vocab_i]
                # topk_rating_tokens.append(vocab_tokens_i)
                # we assume all the topk rating tokens are in the predefined vocabulary
                # otherwise, one error will happen
                # This condition also brings convenience for the final copy mechanism
                vocab_tokens_i_tensor = [word2idx[w] for w in vocab_tokens_i]
                vocab_tokens_i_tensor = torch.LongTensor(vocab_tokens_i_tensor)
                rating_tokens_tensor.append(vocab_tokens_i_tensor)
            # [5, rating_v_size]
            rating_tokens_tensor = torch.stack(rating_tokens_tensor, dim=0)
            # save rating_tokens_tensor
            torch.save(rating_tokens_tensor,
                       os.path.join(opt.model_path, 'rating_tokens_tensor.pt'))
        else:
            rating_tokens_tensor = None

        # dump word2idx
        with open(join(opt.model_path, 'vocab.pkl'), 'wb') as f:
            pkl.dump(word2idx, f, pkl.HIGHEST_PROTOCOL)

        # construct loader
        train_data_loader, valid_data_loader, class_weights = build_loader(
            opt.data, opt.batch_size, word2idx, opt.src_max_len,
            opt.trg_max_len, opt.batch_workers, opt.weighted_sampling)
        load_data_time = time_since(start_time)  # measured after the loaders are built
        logging.info('Time for loading the data: %.1f' % load_data_time)

        # construct model
        start_time = time.time()
        overall_model = init_model(opt, rating_tokens_tensor)
        logging.info(overall_model)

        # construct optimizer
        optimizer_ml = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, overall_model.parameters()),
                                        lr=opt.learning_rate)

        # construct loss function
        #print(class_weights)
        #exit()
        if opt.classifier_loss_type == "ordinal_mse":
            train_classification_loss_func = OrdinalMSELoss(opt.num_classes,
                                                            device=opt.device)
            val_classification_loss_func = train_classification_loss_func
        elif opt.classifier_loss_type == "ordinal_xe":
            train_classification_loss_func = OrdinalXELoss(opt.num_classes,
                                                           device=opt.device)
            val_classification_loss_func = train_classification_loss_func
        else:
            if opt.weighted_classifier_loss:
                train_classification_loss_func = nn.NLLLoss(
                    reduction='mean', weight=class_weights)
            else:
                train_classification_loss_func = nn.NLLLoss(reduction='mean')
            val_classification_loss_func = nn.NLLLoss(reduction='mean')

        # train the model
        ml_pipeline.train_model(overall_model, optimizer_ml, train_data_loader,
                                valid_data_loader, opt,
                                train_classification_loss_func,
                                val_classification_loss_func)

        training_time = time_since(start_time)
        logging.info('Model path: {}'.format(opt.model_path))
        logging.info('Time for training: {}'.format(
            datetime.timedelta(seconds=training_time)))

    except Exception as e:
        logging.exception("message")
    return
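
`io.make_vocab` in Example #14 builds the index mappings from the pickled word counter. A plausible minimal version, assuming the usual special tokens occupy the first slots (the real implementation and token names may differ):

PAD_WORD, UNK_WORD, BOS_WORD, EOS_WORD = '<pad>', '<unk>', '<bos>', '<eos>'

def make_vocab(word_counter, vocab_size):
    # reserve the special tokens, then add the vocab_size most frequent words
    word2idx = {PAD_WORD: 0, UNK_WORD: 1, BOS_WORD: 2, EOS_WORD: 3}
    for word, _ in word_counter.most_common(vocab_size - len(word2idx)):
        word2idx[word] = len(word2idx)
    idx2word = {idx: word for word, idx in word2idx.items()}
    return word2idx, idx2word
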
Example #15
import argparse
from pykp.model import Seq2SeqModel
from torch.optim import Adam
import pykp
import train_ml
import train_rl

from utils.time_log import time_since
from utils.data_loader import load_data_and_vocab
from utils.string_helper import convert_list_to_kphs
import time
import numpy as np
import random
from torch import device 
from Discriminator_Softmax import Discriminator

#####################################################################################################
opt = argparse.Namespace(attn_mode='concat', baseline='self', batch_size=32, batch_workers=4, bidirectional=True, bridge='copy', checkpoint_interval=4000, copy_attention=True, copy_input_feeding=False, coverage_attn=False, coverage_loss=False, custom_data_filename_suffix=False, custom_vocab_filename_suffix=False, data='data/kp20k_separated/', data_filename_suffix='', dec_layers=1, decay_method='', decoder_size=300, decoder_type='rnn', delimiter_type=0, delimiter_word='<sep>', device=device(type='cuda', index=2), disable_early_stop_rl=False, dropout=0.1, dynamic_dict=True, early_stop_tolerance=4, enc_layers=1, encoder_size=150, encoder_type='rnn', epochs=20, exp='kp20k.rl.one2many.cat.copy.bi-directional', exp_path='exp/kp20k.rl.one2many.cat.copy.bi-directional.20190701-192604', final_perturb_std=0, fix_word_vecs_dec=False, fix_word_vecs_enc=False, goal_vector_mode=0, goal_vector_size=16, gpuid=1, init_perturb_std=0, input_feeding=False, lambda_coverage=1, lambda_orthogonal=0.03, lambda_target_encoder=0.03, learning_rate=0.001, learning_rate_decay=0.5, learning_rate_decay_rl=False, learning_rate_rl=5e-05, loss_normalization='tokens', manager_mode=1, match_type='exact', max_grad_norm=1, max_length=60, max_sample_length=6, max_unk_words=1000, mc_rollouts=False, model_path='model/kp20k.rl.one2many.cat.copy.bi-directional.20190701-192604', must_teacher_forcing=False, num_predictions=1, num_rollouts=3, one2many=True, one2many_mode=1, optim='adam', orthogonal_loss=False, param_init=0.1, perturb_baseline=False, perturb_decay_factor=0.0001, perturb_decay_mode=1, pre_word_vecs_dec=None, pre_word_vecs_enc=None, pretrained_model='model/kp20k.ml.one2many.cat.copy.bi-directional.20190628-114655/kp20k.ml.one2many.cat.copy.bi-directional.epoch=2.batch=54573.total_batch=116000.model', regularization_factor=0.0, regularization_type=0, remove_src_eos=False, replace_unk=True, report_every=10, review_attn=False, reward_shaping=False, reward_type=7, save_model='model', scheduled_sampling=False, scheduled_sampling_batches=10000, seed=9527, separate_present_absent=True, share_embeddings=True, source_representation_queue_size=128, source_representation_sample_size=32, start_checkpoint_at=2, start_decay_at=8, start_epoch=1, target_encoder_size=64, teacher_forcing_ratio=0, timemark='20190701-192604', title_guided=False, topk='G', train_from='', train_ml=False, train_rl=True, truncated_decoder=0, use_target_encoder=False, vocab='data/kp20k_separated/', vocab_filename_suffix='', vocab_size=50002, warmup_steps=4000, word_vec_size=100, words_min_frequency=0)


hidden_dim = 150
embedding_dim = 200
n_layers = 2 
clip = 5 

def main():
    clip = 5
    start_time = time.time()
    train_data_loader, valid_data_loader, word2idx, idx2word, vocab = load_data_and_vocab(opt, load_train=True)
    load_data_time = time_since(start_time)
    
Example #16
def main(opt):
    #print("agsnf efnghrrqthg")
    clip = 5
    start_time = time.time()
    train_data_loader, valid_data_loader, word2idx, idx2word, vocab = load_data_and_vocab(
        opt, load_train=True)
    load_data_time = time_since(start_time)
    logging.info('Time for loading the data: %.1f' % load_data_time)

    print("______________________ Data Successfully Loaded ______________")
    model = Seq2SeqModel(opt)
    if torch.cuda.is_available():
        model.load_state_dict(torch.load(opt.model_path))
        model = model.to(opt.gpuid)
    else:
        model.load_state_dict(torch.load(opt.model_path, map_location="cpu"))

    print(
        "___________________ Generator Initialised and Loaded _________________________"
    )
    generator = SequenceGenerator(model,
                                  bos_idx=opt.word2idx[pykp.io.BOS_WORD],
                                  eos_idx=opt.word2idx[pykp.io.EOS_WORD],
                                  pad_idx=opt.word2idx[pykp.io.PAD_WORD],
                                  peos_idx=opt.word2idx[pykp.io.PEOS_WORD],
                                  beam_size=1,
                                  max_sequence_length=opt.max_length,
                                  copy_attn=opt.copy_attention,
                                  coverage_attn=opt.coverage_attn,
                                  review_attn=opt.review_attn,
                                  cuda=opt.gpuid > -1)

    init_perturb_std = opt.init_perturb_std
    final_perturb_std = opt.final_perturb_std
    perturb_decay_factor = opt.perturb_decay_factor
    perturb_decay_mode = opt.perturb_decay_mode
    hidden_dim = opt.D_hidden_dim
    embedding_dim = opt.D_embedding_dim
    n_layers = opt.D_layers
    D_model = Discriminator(opt.vocab_size, embedding_dim, hidden_dim,
                            n_layers, opt.word2idx[pykp.io.PAD_WORD])
    print("The Discriminator Description is ", D_model)

    PG_optimizer = torch.optim.Adagrad(model.parameters(),
                                       opt.learning_rate_rl)
    if torch.cuda.is_available():
        D_model.load_state_dict(torch.load(opt.Discriminator_model_path))
        D_model = D_model.to(opt.gpuid)
    else:
        D_model.load_state_dict(
            torch.load(opt.Discriminator_model_path, map_location="cpu"))

    # D_model.load_state_dict(torch.load("Discriminator_checkpts/D_model_combined1.pth.tar"))
    total_epochs = opt.epochs
    for epoch in range(total_epochs):

        total_batch = 0
        print("Starting with epoch:", epoch)
        for batch_i, batch in enumerate(train_data_loader):
            total_batch += 1  # advance the perturbation decay schedule

            model.train()
            PG_optimizer.zero_grad()

            if perturb_decay_mode == 0:  # do not decay
                perturb_std = init_perturb_std
            elif perturb_decay_mode == 1:  # exponential decay
                perturb_std = final_perturb_std + (
                    init_perturb_std - final_perturb_std) * math.exp(
                        -1. * total_batch * perturb_decay_factor)
            elif perturb_decay_mode == 2:  # steps decay
                perturb_std = init_perturb_std * math.pow(
                    perturb_decay_factor, math.floor((1 + total_batch) / 4000))

            avg_rewards = train_one_batch(D_model, batch, generator, opt,
                                          perturb_std)

            avg_rewards.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            PG_optimizer.step()

            if batch_i % 4000 == 0:
                print("Saving the file ...............----------->>>>>")
                print("The avg reward is", -avg_rewards.item())
                state_dfs = model.state_dict()
                torch.save(
                    state_dfs, "RL_Checkpoints/Attention_Generator_" +
                    str(epoch) + ".pth.tar")
Example #17
def main(opt):
    clip = 5
    start_time = time.time()
    train_data_loader, valid_data_loader, word2idx, idx2word, vocab = load_data_and_vocab(
        opt, load_train=True)
    load_data_time = time_since(start_time)
    logging.info('Time for loading the data: %.1f' % load_data_time)

    print(
        "Data Successfully Loaded __.__.__.__.__.__.__.__.__.__.__.__.__.__.")
    model = Seq2SeqModel(opt)

    if torch.cuda.is_available():
        model.load_state_dict(torch.load(opt.model_path))
        model = model.to(opt.gpuid)
    else:
        model.load_state_dict(torch.load(opt.model_path, map_location="cpu"))

    print(
        "___________________ Generator Initialised and Loaded _________________________"
    )
    generator = SequenceGenerator(model,
                                  bos_idx=opt.word2idx[pykp.io.BOS_WORD],
                                  eos_idx=opt.word2idx[pykp.io.EOS_WORD],
                                  pad_idx=opt.word2idx[pykp.io.PAD_WORD],
                                  peos_idx=opt.word2idx[pykp.io.PEOS_WORD],
                                  beam_size=1,
                                  max_sequence_length=opt.max_length,
                                  copy_attn=opt.copy_attention,
                                  coverage_attn=opt.coverage_attn,
                                  review_attn=opt.review_attn,
                                  cuda=opt.gpuid > -1)

    init_perturb_std = opt.init_perturb_std
    final_perturb_std = opt.final_perturb_std
    perturb_decay_factor = opt.perturb_decay_factor
    perturb_decay_mode = opt.perturb_decay_mode
    hidden_dim = opt.D_hidden_dim
    embedding_dim = opt.D_embedding_dim
    n_layers = opt.D_layers
    if torch.cuda.is_available():
        D_model = Discriminator(opt.vocab_size, embedding_dim, hidden_dim,
                                n_layers, opt.word2idx[pykp.io.PAD_WORD],
                                opt.gpuid)
    else:
        D_model = Discriminator(opt.vocab_size, embedding_dim, hidden_dim,
                                n_layers, opt.word2idx[pykp.io.PAD_WORD],
                                "cpu")
    print("The Discriminator Description is ", D_model)
    if opt.pretrained_Discriminator:
        if torch.cuda.is_available():
            D_model.load_state_dict(torch.load(opt.Discriminator_model_path))
            D_model = D_model.to(opt.gpuid)
        else:
            D_model.load_state_dict(
                torch.load(opt.Discriminator_model_path, map_location="cpu"))
    else:
        # no pretrained Discriminator: train from scratch, nothing to load
        if torch.cuda.is_available():
            D_model = D_model.to(opt.gpuid)
    D_optimizer = torch.optim.Adam(D_model.parameters(), opt.learning_rate)
    print("Beginning with training Discriminator")
    print(
        "########################################################################################################"
    )
    total_epochs = 5
    best_valid_loss = float("inf")  # track the best across batches, not per batch
    for epoch in range(total_epochs):
        total_batch = 0
        print("Starting with epoch:", epoch)
        for batch_i, batch in enumerate(train_data_loader):
            total_batch += 1
            D_model.train()
            D_optimizer.zero_grad()

            if perturb_decay_mode == 0:  # do not decay
                perturb_std = init_perturb_std
            elif perturb_decay_mode == 1:  # exponential decay
                perturb_std = final_perturb_std + (
                    init_perturb_std - final_perturb_std) * math.exp(
                        -1. * total_batch * perturb_decay_factor)
            elif perturb_decay_mode == 2:  # steps decay
                perturb_std = init_perturb_std * math.pow(
                    perturb_decay_factor, math.floor((1 + total_batch) / 4000))
            avg_batch_loss, _, _ = train_one_batch(D_model, batch, generator,
                                                   opt, perturb_std)
            avg_batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(D_model.parameters(), clip)

            D_optimizer.step()
            D_model.eval()

            if batch_i % 4000 == 0:
                total = 0
                valid_loss_total, valid_real_total, valid_fake_total = 0, 0, 0
                for batch_j, valid_batch in enumerate(valid_data_loader):
                    total += 1
                    valid_loss, valid_real, valid_fake = train_one_batch(
                        D_model, valid_batch, generator, opt, perturb_std)
                    valid_loss_total += valid_loss.cpu().detach().numpy()
                    valid_real_total += valid_real.cpu().detach().numpy()
                    valid_fake_total += valid_fake.cpu().detach().numpy()
                    D_optimizer.zero_grad()

                print("Currently loss is ", valid_loss_total.item() / total)
                print("Currently real loss is ",
                      valid_real_total.item() / total)
                print("Currently fake loss is ",
                      valid_fake_total.item() / total)

                if best_valid_loss > valid_loss_total.item() / total:
                    print(
                        "Loss Decreases so saving the file ...............----------->>>>>"
                    )
                    state_dfs = D_model.state_dict()
                    torch.save(
                        state_dfs,
                        "Discriminator_checkpts/Attention_Disriminator_" +
                        str(epoch) + ".pth.tar")
                    best_valid_loss = valid_loss_total.item() / total
Example #18
def main():
    #print("agsnf efnghrrqthg")
    print("dfsgf")
    clip = 5
    start_time = time.time()
    train_data_loader, valid_data_loader, word2idx, idx2word, vocab = load_data_and_vocab(
        opt, load_train=True)
    load_data_time = time_since(start_time)
    logging.info('Time for loading the data: %.1f' % load_data_time)

    model = Seq2SeqModel(opt)
    model = model.to("cuda:2")

    #model.load_state_dict(torch.load("model/kp20k.ml.one2many.cat.copy.bi-directional.20190704-170553/kp20k.ml.one2many.cat.copy.bi-directional.epoch=2.batch=264.total_batch=8000.model"))
    # model.load_state_dict(torch.load("Checkpoint_individual_3.pth.tar"))
    model.load_state_dict(
        torch.load(
            "model/kp20k.ml.one2many.cat.copy.bi-directional.20190715-132016/kp20k.ml.one2many.cat.copy.bi-directional.epoch=3.batch=26098.total_batch=108000.model"
        ))
    generator = SequenceGenerator(model,
                                  bos_idx=opt.word2idx[pykp.io.BOS_WORD],
                                  eos_idx=opt.word2idx[pykp.io.EOS_WORD],
                                  pad_idx=opt.word2idx[pykp.io.PAD_WORD],
                                  peos_idx=opt.word2idx[pykp.io.PEOS_WORD],
                                  beam_size=1,
                                  max_sequence_length=opt.max_length,
                                  copy_attn=opt.copy_attention,
                                  coverage_attn=opt.coverage_attn,
                                  review_attn=opt.review_attn,
                                  cuda=opt.gpuid > -1)

    init_perturb_std = opt.init_perturb_std
    final_perturb_std = opt.final_perturb_std
    perturb_decay_factor = opt.perturb_decay_factor
    perturb_decay_mode = opt.perturb_decay_mode

    D_model = Discriminator(opt.vocab_size, embedding_dim, hidden_dim,
                            n_layers, opt.word2idx[pykp.io.PAD_WORD])

    # D_model.load_state_dict(torch.load("Discriminator_checkpts/Checkpoint_Individual_Training_4.pth.tar"))

    PG_optimizer = torch.optim.Adagrad(model.parameters(), 0.00005)

    print("The Discriminator statistics are ", D_model)

    if torch.cuda.is_available():
        D_model = D_model.to("cuda:1")

    total_epochs = 5
    for epoch in range(total_epochs):

        total_batch = 0
        print("Starting with epoch:", epoch)
        for batch_i, batch in enumerate(valid_data_loader):
            total_batch += 1

            PG_optimizer.zero_grad()

            if perturb_decay_mode == 0:  # do not decay
                perturb_std = init_perturb_std
            elif perturb_decay_mode == 1:  # exponential decay
                perturb_std = final_perturb_std + (
                    init_perturb_std - final_perturb_std) * math.exp(
                        -1. * total_batch * perturb_decay_factor)
            elif perturb_decay_mode == 2:  # steps decay
                perturb_std = init_perturb_std * math.pow(
                    perturb_decay_factor, math.floor((1 + total_batch) / 4000))

            avg_rewards = train_one_batch(D_model, batch, generator, opt,
                                          perturb_std)

            avg_rewards.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

            PG_optimizer.step()

            if batch_i % 15 == 0:
                print("The avg reward is", -avg_rewards.item())
            if batch_i % 100 == 0:
                print("Saving the file ...............----------->>>>>")
                print("The avg reward is", -avg_rewards.item())
                state_dfs = model.state_dict()
                torch.save(
                    state_dfs, "RL_Checkpoints/Checkpoint_SeqGAN_" +
                    str(epoch) + ".pth.tar")

        print("Saving the file ...............----------->>>>>")
        state_dfs = model.state_dict()
        torch.save(
            state_dfs,
            "RL_Checkpoints/Checkpoint_SeqGAN_" + str(epoch) + ".pth.tar")
Example #19
def train_one_batch(one2many_batch, generator, optimizer, opt, perturb_std=0):
    src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _, title, title_oov, title_lens, title_mask = one2many_batch
    """
    src: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], with oov words replaced by unk idx
    src_lens: a list containing the length of src sequences for each batch, with len=batch
    src_mask: a FloatTensor, [batch, src_seq_len]
    src_oov: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], contains the index of oov words (used by copy)
    oov_lists: a list of oov words for each src, 2dlist
    """

    one2many = opt.one2many
    one2many_mode = opt.one2many_mode
    if one2many and one2many_mode > 1:
        num_predictions = opt.num_predictions
    else:
        num_predictions = 1

    # move data to GPU if available
    src = src.to(opt.device)
    src_mask = src_mask.to(opt.device)
    src_oov = src_oov.to(opt.device)
    # trg = trg.to(opt.device)
    # trg_mask = trg_mask.to(opt.device)
    # trg_oov = trg_oov.to(opt.device)

    if opt.title_guided:
        title = title.to(opt.device)
        title_mask = title_mask.to(opt.device)
        #title_oov = title_oov.to(opt.device)

    optimizer.zero_grad()

    eos_idx = opt.word2idx[pykp.io.EOS_WORD]
    delimiter_word = opt.delimiter_word
    batch_size = src.size(0)
    topk = opt.topk
    reward_type = opt.reward_type
    reward_shaping = opt.reward_shaping
    baseline = opt.baseline
    match_type = opt.match_type
    regularization_type = opt.regularization_type
    regularization_factor = opt.regularization_factor

    if regularization_type == 2:
        entropy_regularize = True
    else:
        entropy_regularize = False

    if opt.perturb_baseline:
        baseline_perturb_std = perturb_std
    else:
        baseline_perturb_std = 0

    #generator.model.train()

    # sample a sequence from the model
    # sample_list is a list of dict, {"prediction": [], "scores": [], "attention": [], "done": True}, prediction is a list of 0 dim tensors
    # log_selected_token_dist: size: [batch, output_seq_len]
    start_time = time.time()
    sample_list, log_selected_token_dist, output_mask, pred_eos_idx_mask, entropy, location_of_eos_for_each_batch, location_of_peos_for_each_batch = generator.sample(
        src,
        src_lens,
        src_oov,
        src_mask,
        oov_lists,
        opt.max_length,
        greedy=False,
        one2many=one2many,
        one2many_mode=one2many_mode,
        num_predictions=num_predictions,
        perturb_std=perturb_std,
        entropy_regularize=entropy_regularize,
        title=title,
        title_lens=title_lens,
        title_mask=title_mask)
    pred_str_2dlist = sample_list_to_str_2dlist(
        sample_list, oov_lists, opt.idx2word, opt.vocab_size, eos_idx,
        delimiter_word, opt.word2idx[pykp.io.UNK_WORD], opt.replace_unk,
        src_str_list, opt.separate_present_absent, pykp.io.PEOS_WORD)
    sample_time = time_since(start_time)
    max_pred_seq_len = log_selected_token_dist.size(1)

    if entropy_regularize:
        entropy_array = entropy.data.cpu().numpy()
    else:
        entropy_array = None

    # if use self critical as baseline, greedily decode a sequence from the model
    if baseline == 'self':
        generator.model.eval()
        with torch.no_grad():
            start_time = time.time()
            greedy_sample_list, _, _, greedy_eos_idx_mask, _, _, _ = generator.sample(
                src,
                src_lens,
                src_oov,
                src_mask,
                oov_lists,
                opt.max_length,
                greedy=True,
                one2many=one2many,
                one2many_mode=one2many_mode,
                num_predictions=num_predictions,
                perturb_std=baseline_perturb_std,
                title=title,
                title_lens=title_lens,
                title_mask=title_mask)
            greedy_str_2dlist = sample_list_to_str_2dlist(
                greedy_sample_list, oov_lists, opt.idx2word, opt.vocab_size,
                eos_idx, delimiter_word, opt.word2idx[pykp.io.UNK_WORD],
                opt.replace_unk, src_str_list, opt.separate_present_absent,
                pykp.io.PEOS_WORD)
        generator.model.train()
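    # Self-critical baseline: the greedy rollout's reward is subtracted from
    # the sampled rollout's reward below, so the gradient favors samples that
    # beat the model's own greedy decoding (cf. Rennie et al., 2017).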

    # Compute the reward for each predicted keyphrase
    # if using reward shaping, each keyphrase will have its own reward, else, only the last keyphrase will get a reward
    # In addition, we adds a regularization terms to the reward

    if reward_shaping:
        max_num_pred_phrases = max(
            [len(pred_str_list) for pred_str_list in pred_str_2dlist])

        # compute the reward for each phrase, np array with size: [batch_size, num_predictions]
        phrase_reward = compute_phrase_reward(
            pred_str_2dlist, trg_str_2dlist, batch_size, max_num_pred_phrases,
            reward_shaping, reward_type, topk, match_type,
            regularization_factor, regularization_type, entropy_array)
        # store the sum of cumulative reward for the experiment log
        cumulative_reward = phrase_reward[:, -1]
        cumulative_reward_sum = cumulative_reward.sum(0)

        # Subtract reward by a baseline if needed
        if opt.baseline == 'self':
            max_num_greedy_phrases = max([
                len(greedy_str_list) for greedy_str_list in greedy_str_2dlist
            ])
            assert max_num_pred_phrases == max_num_greedy_phrases, "if you use self-critical training with reward shaping, make sure the number of phrases sampled from the policy and that decoded by greedy are the same."
            # use the reward of greedy decoding as baseline
            phrase_baseline = compute_phrase_reward(
                greedy_str_2dlist, trg_str_2dlist, batch_size,
                max_num_greedy_phrases, reward_shaping, reward_type, topk,
                match_type, regularization_factor, regularization_type,
                entropy_array)
            phrase_reward = phrase_reward - phrase_baseline

        # convert each phrase reward to its improvement in reward
        phrase_reward = shape_reward(phrase_reward)

        # convert to reward received at each decoding step
        stepwise_reward = phrase_reward_to_stepwise_reward(
            phrase_reward, pred_eos_idx_mask)
        q_value_estimate_array = np.cumsum(stepwise_reward[:, ::-1],
                                           axis=1)[:, ::-1].copy()
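        # Worked example (ours): if one sequence's stepwise_reward is
        # [0, 0.2, 0, 0.5], the reversed cumulative sum yields return-to-go
        # estimates [0.7, 0.7, 0.5, 0.5], i.e. each step is credited with all
        # reward collected from that step onward.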

    elif opt.separate_present_absent:
        present_absent_reward = compute_present_absent_reward(
            pred_str_2dlist,
            trg_str_2dlist,
            reward_type=reward_type,
            topk=topk,
            match_type=match_type,
            regularization_factor=regularization_factor,
            regularization_type=regularization_type,
            entropy=entropy_array)
        cumulative_reward = present_absent_reward.sum(1)
        cumulative_reward_sum = cumulative_reward.sum(0)
        # Subtract reward by a baseline if needed
        if opt.baseline == 'self':
            present_absent_baseline = compute_present_absent_reward(
                greedy_str_2dlist,
                trg_str_2dlist,
                reward_type=reward_type,
                topk=topk,
                match_type=match_type,
                regularization_factor=regularization_factor,
                regularization_type=regularization_type,
                entropy=entropy_array)
            present_absent_reward = present_absent_reward - present_absent_baseline
        stepwise_reward = present_absent_reward_to_stepwise_reward(
            present_absent_reward, max_pred_seq_len,
            location_of_peos_for_each_batch, location_of_eos_for_each_batch)
        q_value_estimate_array = np.cumsum(stepwise_reward[:, ::-1],
                                           axis=1)[:, ::-1].copy()

    else:  # no reward shaping and no present/absent separation
        # only receive reward at the end of whole sequence, np array: [batch_size]
        cumulative_reward = compute_batch_reward(
            pred_str_2dlist,
            trg_str_2dlist,
            batch_size,
            reward_type=reward_type,
            topk=topk,
            match_type=match_type,
            regularization_factor=regularization_factor,
            regularization_type=regularization_type,
            entropy=entropy_array)
        # store the sum of cumulative reward (before baseline) for the experiment log
        cumulative_reward_sum = cumulative_reward.sum(0)
        # Subtract the cumulative reward by a baseline if needed
        if opt.baseline == 'self':
            baseline = compute_batch_reward(
                greedy_str_2dlist,
                trg_str_2dlist,
                batch_size,
                reward_type=reward_type,
                topk=topk,
                match_type=match_type,
                regularization_factor=regularization_factor,
                regularization_type=regularization_type,
                entropy=entropy_array)
            cumulative_reward = cumulative_reward - baseline
        # q value estimation for each time step equals to the (baselined) cumulative reward
        q_value_estimate_array = np.tile(cumulative_reward.reshape(
            [-1, 1]), [1, max_pred_seq_len])  # [batch, max_pred_seq_len]

    #shapped_baselined_reward = torch.gather(shapped_baselined_phrase_reward, dim=1, index=pred_phrase_idx_mask)

    # use the return as the estimation of q_value at each step

    q_value_estimate = torch.from_numpy(q_value_estimate_array).type(
        torch.FloatTensor).to(src.device)
    q_value_estimate.requires_grad_(True)
    q_estimate_compute_time = time_since(start_time)

    # compute the policy gradient objective
    pg_loss = compute_pg_loss(log_selected_token_dist, output_mask,
                              q_value_estimate)

    # back propagation to compute the gradient
    start_time = time.time()
    pg_loss.backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        grad_norm_before_clipping = nn.utils.clip_grad_norm_(
            generator.model.parameters(), opt.max_grad_norm)

    # take a step of gradient descent
    optimizer.step()

    stat = RewardStatistics(cumulative_reward_sum, pg_loss.item(), batch_size,
                            sample_time, q_estimate_compute_time,
                            backward_time)
    # (final_reward=0.0, pg_loss=0.0, n_batch=0, sample_time=0, q_estimate_compute_time=0, backward_time=0)
    # reward=0.0, pg_loss=0.0, n_batch=0, sample_time=0, q_estimate_compute_time=0, backward_time=0

    return stat, log_selected_token_dist.detach()
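
compute_pg_loss is imported from the repo's loss module and is not shown on
this page. A minimal sketch of the masked policy-gradient objective it
presumably implements (the details are an assumption on our part):

def pg_loss_sketch(log_selected_token_dist, output_mask, q_value_estimate):
    # REINFORCE surrogate: minimize -sum(q * log pi) over non-padding steps
    return -(q_value_estimate * log_selected_token_dist * output_mask).sum()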
Example #20
def evaluate_reward(data_loader, generator, opt):
    """Return the avg. reward in the validation dataset"""
    generator.model.eval()
    final_reward_sum = 0.0
    n_batch = 0
    sample_time_total = 0.0
    topk = opt.topk
    reward_type = opt.reward_type
    #reward_type = 7
    match_type = opt.match_type
    eos_idx = opt.word2idx[pykp.io.EOS_WORD]
    delimiter_word = opt.delimiter_word
    one2many = opt.one2many
    one2many_mode = opt.one2many_mode
    if one2many and one2many_mode > 1:
        num_predictions = opt.num_predictions
    else:
        num_predictions = 1

    with torch.no_grad():
        for batch_i, batch in enumerate(data_loader):
            # load one2many dataset
            src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _, title, title_oov, title_lens, title_mask = batch
            num_trgs = [
                len(trg_str_list) for trg_str_list in trg_str_2dlist
            ]  # a list of num of targets in each batch, with len=batch_size

            batch_size = src.size(0)
            n_batch += batch_size

            # move data to GPU if available
            src = src.to(opt.device)
            src_mask = src_mask.to(opt.device)
            src_oov = src_oov.to(opt.device)
            #trg = trg.to(opt.device)
            #trg_mask = trg_mask.to(opt.device)
            #trg_oov = trg_oov.to(opt.device)
            if opt.title_guided:
                title = title.to(opt.device)
                title_mask = title_mask.to(opt.device)
                # title_oov = title_oov.to(opt.device)

            start_time = time.time()
            # sample a sequence
            # sample_list is a list of dict, {"prediction": [], "scores": [], "attention": [], "done": True}, prediction is a list of 0 dim tensors
            sample_list, log_selected_token_dist, output_mask, pred_idx_mask, _, _, _ = generator.sample(
                src,
                src_lens,
                src_oov,
                src_mask,
                oov_lists,
                opt.max_length,
                greedy=True,
                one2many=one2many,
                one2many_mode=one2many_mode,
                num_predictions=num_predictions,
                perturb_std=0,
                title=title,
                title_lens=title_lens,
                title_mask=title_mask)
            #pred_str_2dlist = sample_list_to_str_2dlist(sample_list, oov_lists, opt.idx2word, opt.vocab_size, eos_idx, delimiter_word)
            pred_str_2dlist = sample_list_to_str_2dlist(
                sample_list, oov_lists, opt.idx2word, opt.vocab_size, eos_idx,
                delimiter_word, opt.word2idx[pykp.io.UNK_WORD],
                opt.replace_unk, src_str_list)
            #print(pred_str_2dlist)
            sample_time = time_since(start_time)
            sample_time_total += sample_time

            final_reward = compute_batch_reward(
                pred_str_2dlist,
                trg_str_2dlist,
                batch_size,
                reward_type,
                topk,
                match_type,
                regularization_factor=0.0)  # np.array, [batch_size]

            final_reward_sum += final_reward.sum(0)

    eval_reward_stat = RewardStatistics(final_reward_sum,
                                        pg_loss=0,
                                        n_batch=n_batch,
                                        sample_time=sample_time_total)

    return eval_reward_stat
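
compute_batch_reward is likewise defined elsewhere in the repo. For
illustration only, an exact-match F1 reward for one batch could look like this
(the function name and details are ours):

def f1_reward(pred_str_2dlist, trg_str_2dlist):
    rewards = []
    for preds, trgs in zip(pred_str_2dlist, trg_str_2dlist):
        pred_set = {tuple(p) for p in preds}  # each keyphrase is a word list
        trg_set = {tuple(t) for t in trgs}
        num_match = len(pred_set & trg_set)
        precision = num_match / len(pred_set) if pred_set else 0.0
        recall = num_match / len(trg_set) if trg_set else 0.0
        f1 = 2 * precision * recall / (precision + recall) if num_match else 0.0
        rewards.append(f1)
    return rewards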
Example #21
def train_one_batch(batch,
                    model,
                    optimizer,
                    opt,
                    batch_i,
                    source_representation_queue=None):
    if not opt.one2many:  # load one2one data
        src, src_lens, src_mask, trg, trg_lens, trg_mask, src_oov, trg_oov, oov_lists, title, title_oov, title_lens, title_mask = batch
        """
        src: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], with oov words replaced by unk idx
        src_lens: a list containing the length of src sequences for each batch, with len=batch
        src_mask: a FloatTensor, [batch, src_seq_len]
        trg: a LongTensor containing the word indices of target sentences, [batch, trg_seq_len]
        trg_lens: a list containing the length of trg sequences for each batch, with len=batch
        trg_mask: a FloatTensor, [batch, trg_seq_len]
        src_oov: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], contains the index of oov words (used by copy)
        trg_oov: a LongTensor containing the word indices of target sentences, [batch, trg_seq_len], contains the index of oov words (used by copy)
        """
    else:  # load one2many data
        src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _, title, title_oov, title_lens, title_mask = batch
        num_trgs = [
            len(trg_str_list) for trg_str_list in trg_str_2dlist
        ]  # a list of num of targets in each batch, with len=batch_size
        """
        trg: LongTensor [batch, trg_seq_len], each target trg[i] contains the indices of a set of concatenated keyphrases, separated by opt.word2idx[pykp.io.SEP_WORD]
             if opt.delimiter_type = 0, SEP_WORD=<sep>, if opt.delimiter_type = 1, SEP_WORD=<eos>
        trg_oov: same as trg, but all unk words are replaced with temporary idx, e.g. 50000, 50001 etc.
        """
    batch_size = src.size(0)
    max_num_oov = max([len(oov) for oov in oov_lists
                       ])  # max number of oov for each batch

    # move data to GPU if available
    src = src.to(opt.device)
    src_mask = src_mask.to(opt.device)
    trg = trg.to(opt.device)
    trg_mask = trg_mask.to(opt.device)
    src_oov = src_oov.to(opt.device)
    trg_oov = trg_oov.to(opt.device)
    if opt.title_guided:
        title = title.to(opt.device)
        title_mask = title_mask.to(opt.device)
        #title_oov = title_oov.to(opt.device)
    # title, title_oov, title_lens, title_mask

    optimizer.zero_grad()

    #if opt.one2many_mode == 0 or opt.one2many_mode == 1:
    start_time = time.time()

    if opt.use_target_encoder:  # Sample encoder representations
        if len(source_representation_queue
               ) < opt.source_representation_sample_size:
            source_representation_samples_2dlist = None
            source_representation_target_list = None
        else:
            source_representation_samples_2dlist = []
            source_representation_target_list = []
            for i in range(batch_size):
                # N encoder representation from the queue
                source_representation_samples_list = source_representation_queue.sample(
                    opt.source_representation_sample_size)
                # insert a place-holder for the ground-truth source representation to a random index
                place_holder_idx = np.random.randint(
                    0, opt.source_representation_sample_size + 1)
                source_representation_samples_list.insert(
                    place_holder_idx, None)  # len=N+1
                # insert the sample list of one batch to the 2d list
                source_representation_samples_2dlist.append(
                    source_representation_samples_list)
                # store the idx of place-holder for that batch
                source_representation_target_list.append(place_holder_idx)
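            # Later in this function, the decoder's delimiter states are
            # trained to pick the ground-truth source representation out of
            # these N+1 candidates (see the masked classification loss below).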
    else:
        source_representation_samples_2dlist = None
        source_representation_target_list = None
        """
        if encoder_representation_samples_2dlist[0] is None and batch_i > math.ceil(
                opt.encoder_representation_sample_size / batch_size):
            # a return value of none indicates we don't have sufficient samples
            # it will only occurs in the first few training steps
            raise ValueError("encoder_representation_samples should not be none at this batch!")
        """

    if not opt.one2many:
        decoder_dist, h_t, attention_dist, encoder_final_state, coverage, delimiter_decoder_states, delimiter_decoder_states_lens, source_classification_dist = model(
            src,
            src_lens,
            trg,
            src_oov,
            max_num_oov,
            src_mask,
            sampled_source_representation_2dlist=
            source_representation_samples_2dlist,
            source_representation_target_list=source_representation_target_list,
            title=title,
            title_lens=title_lens,
            title_mask=title_mask)
    else:
        decoder_dist, h_t, attention_dist, encoder_final_state, coverage, delimiter_decoder_states, delimiter_decoder_states_lens, source_classification_dist = model(
            src,
            src_lens,
            trg,
            src_oov,
            max_num_oov,
            src_mask,
            num_trgs=num_trgs,
            sampled_source_representation_2dlist=
            source_representation_samples_2dlist,
            source_representation_target_list=source_representation_target_list,
            title=title,
            title_lens=title_lens,
            title_mask=title_mask)
    forward_time = time_since(start_time)

    if opt.use_target_encoder:  # Put all the encoder final states to the queue. Need to call detach() first
        # encoder_final_state: [batch, memory_bank_size]
        for i in range(batch_size):
            source_representation_queue.put(encoder_final_state[i, :].detach())

    start_time = time.time()
    if opt.copy_attention:  # Compute the loss using target with oov words
        loss = masked_cross_entropy(
            decoder_dist, trg_oov, trg_mask, trg_lens, opt.coverage_attn,
            coverage, attention_dist, opt.lambda_coverage, opt.coverage_loss,
            delimiter_decoder_states, opt.orthogonal_loss,
            opt.lambda_orthogonal, delimiter_decoder_states_lens)
    else:  # Compute the loss using target without oov words
        loss = masked_cross_entropy(
            decoder_dist, trg, trg_mask, trg_lens, opt.coverage_attn, coverage,
            attention_dist, opt.lambda_coverage, opt.coverage_loss,
            delimiter_decoder_states, opt.orthogonal_loss,
            opt.lambda_orthogonal, delimiter_decoder_states_lens)

    loss_compute_time = time_since(start_time)

    #else:  # opt.one2many_mode == 2
    #    forward_time = 0
    #    loss_compute_time = 0
    #    # TODO: a for loop to accumulate loss for each keyphrase
    #    # TODO: meanwhile, accumulate the forward time and loss_compute time
    #    pass

    total_trg_tokens = sum(trg_lens)

    if math.isnan(loss.item()):
        print("Batch i: %d" % batch_i)
        print("src")
        print(src)
        print(src_oov)
        print(src_str_list)
        print(src_lens)
        print(src_mask)
        print("trg")
        print(trg)
        print(trg_oov)
        print(trg_str_2dlist)
        print(trg_lens)
        print(trg_mask)
        print("oov list")
        print(oov_lists)
        print("Decoder")
        print(decoder_dist)
        print(h_t)
        print(attention_dist)
        raise ValueError("Loss is NaN")

    if opt.loss_normalization == "tokens":  # use number of target tokens to normalize the loss
        normalization = total_trg_tokens
    elif opt.loss_normalization == 'batches':  # use batch_size to normalize the loss
        normalization = src.size(0)
    else:
        raise ValueError('The type of loss normalization is invalid.')

    assert normalization > 0, 'normalization should be a positive number'

    start_time = time.time()
    # back propagation on the normalized loss
    loss.div(normalization).backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        grad_norm_before_clipping = nn.utils.clip_grad_norm_(
            model.parameters(), opt.max_grad_norm)
        # grad_norm_after_clipping = (sum([p.grad.data.norm(2) ** 2 for p in model.parameters() if p.grad is not None])) ** (1.0 / 2)
        # logging.info('clip grad (%f -> %f)' % (grad_norm_before_clipping, grad_norm_after_clipping))

    optimizer.step()

    # Compute target encoder loss
    if opt.use_target_encoder and source_classification_dist is not None:
        start_time = time.time()
        optimizer.zero_grad()
        # convert source_representation_target_list to a LongTensor with size=[batch_size, max_num_delimiters]
        max_num_delimiters = delimiter_decoder_states.size(2)
        source_representation_target = torch.LongTensor(
            source_representation_target_list).to(trg.device)  # [batch_size]
        # expand along the second dimension, since for the target for each delimiter states in the same batch are the same
        source_representation_target = source_representation_target.view(
            -1,
            1).repeat(1,
                      max_num_delimiters)  # [batch_size, max_num_delimiters]
        # mask for source representation classification
        source_representation_target_mask = torch.zeros(
            batch_size, max_num_delimiters).to(trg.device)
        for i in range(batch_size):
            source_representation_target_mask[
                i, :delimiter_decoder_states_lens[i]].fill_(1)
        # compute the masked loss
        loss_te = masked_cross_entropy(source_classification_dist,
                                       source_representation_target,
                                       source_representation_target_mask)
        loss_compute_time += time_since(start_time)
        # back propagation on the normalized loss
        start_time = time.time()
        loss_te.div(normalization).backward()
        backward_time += time_since(start_time)

        if opt.max_grad_norm > 0:
            grad_norm_before_clipping = nn.utils.clip_grad_norm_(
                model.parameters(), opt.max_grad_norm)

        optimizer.step()

    # construct a statistic object for the loss
    stat = LossStatistics(loss.item(),
                          total_trg_tokens,
                          n_batch=1,
                          forward_time=forward_time,
                          loss_compute_time=loss_compute_time,
                          backward_time=backward_time)

    return stat, decoder_dist.detach()
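
masked_cross_entropy above comes from the repo's loss module; the coverage and
orthogonal arguments are extra regularizers layered on top. A minimal sketch of
its core term, ignoring those extras (an assumption on our part):

import torch

def masked_cross_entropy_sketch(decoder_dist, target, mask):
    # decoder_dist: [batch, seq_len, vocab] probabilities (already softmaxed)
    # target: [batch, seq_len] word indices; mask: [batch, seq_len] 0/1 floats
    gathered = decoder_dist.gather(2, target.unsqueeze(2)).squeeze(2)
    log_probs = torch.log(gathered + 1e-12)  # epsilon guards against log(0)
    return -(log_probs * mask).sum()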
Example #22
def evaluate_beam_search(generator,
                         one2many_data_loader,
                         opt,
                         delimiter_word='<sep>'):
    #score_dict_all = defaultdict(list)  # {'precision@5':[],'recall@5':[],'f1_score@5':[],'num_matches@5':[],'precision@10':[],'recall@10':[],'f1score@10':[],'num_matches@10':[]}
    # file for storing the predicted keyphrases
    if opt.pred_file_prefix == "":
        pred_output_file = open(os.path.join(opt.pred_path, "predictions.txt"),
                                "w")
    else:
        pred_output_file = open(
            os.path.join(opt.pred_path,
                         "%s_predictions.txt" % opt.pred_file_prefix), "w")
    # debug
    interval = 1000

    with torch.no_grad():
        start_time = time.time()
        for batch_i, batch in enumerate(one2many_data_loader):
            if (batch_i + 1) % interval == 0:
                print(
                    "Batch %d: Time for running beam search on %d batches : %.1f"
                    % (batch_i + 1, interval, time_since(start_time)))
                sys.stdout.flush()
                start_time = time.time()
            src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, _, _, _, _, original_idx_list, title, title_oov, title_lens, title_mask = batch
            """
            src: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], with oov words replaced by unk idx
            src_lens: a list containing the length of src sequences for each batch, with len=batch
            src_mask: a FloatTensor, [batch, src_seq_len]
            src_oov: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], contains the index of oov words (used by copy)
            oov_lists: a list of oov words for each src, 2dlist
            """
            src = src.to(opt.device)
            src_mask = src_mask.to(opt.device)
            src_oov = src_oov.to(opt.device)
            if opt.title_guided:
                title = title.to(opt.device)
                title_mask = title_mask.to(opt.device)
                # title_oov = title_oov.to(opt.device)

            beam_search_result = generator.beam_search(
                src,
                src_lens,
                src_oov,
                src_mask,
                oov_lists,
                opt.word2idx,
                opt.max_eos_per_output_seq,
                title=title,
                title_lens=title_lens,
                title_mask=title_mask)
            pred_list = preprocess_beam_search_result(
                beam_search_result, opt.idx2word, opt.vocab_size, oov_lists,
                opt.word2idx[pykp.io.EOS_WORD], opt.word2idx[pykp.io.UNK_WORD],
                opt.replace_unk, src_str_list)
            # list of {"sentences": [], "scores": [], "attention": []}

            # recover the original order in the dataset
            seq_pairs = sorted(zip(original_idx_list, src_str_list,
                                   trg_str_2dlist, pred_list, oov_lists),
                               key=lambda p: p[0])
            original_idx_list, src_str_list, trg_str_2dlist, pred_list, oov_lists = zip(
                *seq_pairs)

            # Process every src in the batch
            for src_str, trg_str_list, pred, oov in zip(
                    src_str_list, trg_str_2dlist, pred_list, oov_lists):
                # src_str: a list of words; trg_str: a list of keyphrases, each keyphrase is a list of words
                # pred_seq_list: a list of sequence objects, sorted by scores
                # oov: a list of oov words
                pred_str_list = pred[
                    'sentences']  # predicted sentences from a single src, a list of list of word, with len=[beam_size, out_seq_len], does not include the final <EOS>
                pred_score_list = pred['scores']
                pred_attn_list = pred[
                    'attention']  # a list of FloatTensor[output sequence length, src_len], with len = [n_best]

                if opt.one2many:
                    all_keyphrase_list = [
                    ]  # a list of word list contains all the keyphrases in the top max_n sequences decoded by beam search
                    for word_list in pred_str_list:
                        all_keyphrase_list += split_word_list_by_delimiter(
                            word_list, delimiter_word,
                            opt.separate_present_absent, pykp.io.PEOS_WORD)
                        #not_duplicate_mask = check_duplicate_keyphrases(all_keyphrase_list)
                    #pred_str_list = [word_list for word_list, is_keep in zip(all_keyphrase_list, not_duplicate_mask) if is_keep]
                    pred_str_list = all_keyphrase_list

                # output the predicted keyphrases to a file
                # join the keyphrases with ';', one source document per line
                pred_print_out = ';'.join(
                    ' '.join(word_list) for word_list in pred_str_list) + '\n'
                pred_output_file.write(pred_print_out)

    pred_output_file.close()
    print("done!")
def train_one_batch(batch, overall_model, optimizer, opt, global_step,
                    classification_loss_func, tb_writer):
    # src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_sent_2d_list, trg, trg_oov, trg_lens, trg_mask, rating, _ = batch

    # changed by wchen to a dictionary batch
    src = batch['src_tensor']
    src_lens = batch['src_lens']
    src_mask = batch['src_mask']
    src_sent_positions = batch['src_sent_positions']
    src_sent_nums = batch['src_sent_nums']
    src_sent_mask = batch['src_sent_mask']
    src_oov = batch['src_oov_tensor']
    oov_lists = batch['oov_lists']
    src_str_list = batch['src_list_tokenized']
    trg_sent_2d_list = batch['tgt_sent_2d_list']
    trg = batch['tgt_tensor']
    trg_oov = batch['tgt_oov_tensor']
    trg_lens = batch['tgt_lens']
    trg_mask = batch['tgt_mask']
    rating = batch['rating_tensor']
    indices = batch['original_indices']
    """
    trg: LongTensor [batch, trg_seq_len], each target trg[i] contains the indices of a set of concatenated keyphrases, separated by opt.word2idx[io.SEP_WORD]
                 if opt.delimiter_type = 0, SEP_WORD=<sep>, if opt.delimiter_type = 1, SEP_WORD=<eos>
    trg_oov: same as trg, but all unk words are replaced with temporary idx, e.g. 50000, 50001 etc.
    """
    #seq2seq_model = overall_model['generator']
    #classifier_model = overall_model['classifier']
    batch_size = src.size(0)
    max_num_oov = max([len(oov) for oov in oov_lists
                       ])  # max number of oov for each batch

    # move data to GPU if available
    src = src.to(opt.device)
    src_mask = src_mask.to(opt.device)
    trg = trg.to(opt.device)

    src_sent_positions = src_sent_positions.to(opt.device)
    src_sent_mask = src_sent_mask.to(opt.device)

    trg_mask = trg_mask.to(opt.device)
    src_oov = src_oov.to(opt.device)
    trg_oov = trg_oov.to(opt.device)
    rating = rating.to(opt.device)

    optimizer.zero_grad()

    start_time = time.time()

    # forward
    if overall_model.model_type == 'hre_max':
        decoder_dist, h_t, seq2seq_attention_dist, encoder_final_state, coverage, classifier_logit, classifier_attention_dist = \
            overall_model(src, src_lens, trg, src_oov, max_num_oov, src_mask, trg_mask, src_sent_positions, src_sent_nums, src_sent_mask)
    else:
        decoder_dist, h_t, seq2seq_attention_dist, encoder_final_state, coverage, classifier_logit, classifier_attention_dist = \
            overall_model(src, src_lens, trg, src_oov, max_num_oov, src_mask, trg_mask, rating, src_sent_positions, src_sent_nums, src_sent_mask)

    forward_time = time_since(start_time)

    start_time = time.time()
    # compute loss for generation
    if decoder_dist is not None:
        if opt.copy_attention:  # Compute the loss using target with oov words
            generation_loss = masked_cross_entropy(decoder_dist, trg_oov,
                                                   trg_mask, trg_lens,
                                                   opt.coverage_attn, coverage,
                                                   seq2seq_attention_dist,
                                                   opt.lambda_coverage,
                                                   opt.coverage_loss)
        else:  # Compute the loss using target without oov words
            generation_loss = masked_cross_entropy(decoder_dist, trg, trg_mask,
                                                   trg_lens, opt.coverage_attn,
                                                   coverage,
                                                   seq2seq_attention_dist,
                                                   opt.lambda_coverage,
                                                   opt.coverage_loss)
    else:
        # RnnEncSingleClassifier model
        assert opt.class_loss_internal_enc_weight == 1.0
        assert opt.class_loss_weight == 1.0
        generation_loss = torch.Tensor([0.0]).to(opt.device)

    if math.isnan(generation_loss.item()):
        logging.info("global_step: %d" % global_step)
        logging.info("src")
        logging.info(src)
        logging.info(src_oov)
        logging.info(src_str_list)
        logging.info(src_lens)
        logging.info(src_mask)
        logging.info("trg")
        logging.info(trg)
        logging.info(trg_oov)
        logging.info(trg_sent_2d_list)
        logging.info(trg_lens)
        logging.info(trg_mask)
        logging.info("oov list")
        logging.info(oov_lists)
        logging.info("Decoder")
        logging.info(decoder_dist)
        logging.info(h_t)
        logging.info(seq2seq_attention_dist)
        raise ValueError("Generation loss is NaN")

    # normalize generation loss
    total_trg_tokens = sum(trg_lens)
    if opt.loss_normalization == "tokens":  # use number of target tokens to normalize the loss
        generation_loss_normalization = total_trg_tokens
    elif opt.loss_normalization == 'batches':  # use batch_size to normalize the loss
        generation_loss_normalization = batch_size
    else:
        raise ValueError('The type of loss normalization is invalid.')
    assert generation_loss_normalization > 0, 'normalization should be a positive number'
    normalized_generation_loss = generation_loss.div(
        generation_loss_normalization)

    # compute loss of classification
    if classifier_logit is not None:
        if isinstance(classifier_logit, tuple):
            # from multi_view_model
            enc_normalized_classification_loss = classification_loss_func(
                classifier_logit[0],
                rating)  # normalized by batch size already
            dec_normalized_classification_loss = classification_loss_func(
                classifier_logit[1],
                rating)  # normalized by batch size already
            # compute loss of inconsistency for the multi view model
            if opt.inconsistency_loss_type != "None":
                inconsistency_loss = inconsistency_loss_func(
                    classifier_logit[0], classifier_logit[1],
                    opt.inconsistency_loss_type, opt.detach_dec_incosist_loss)
            else:
                inconsistency_loss = torch.Tensor([0.0]).to(opt.device)
        else:
            enc_normalized_classification_loss = classification_loss_func(
                classifier_logit, rating)  # normalized by batch size already
            dec_normalized_classification_loss = torch.Tensor([0.0]).to(
                opt.device)
            inconsistency_loss = torch.Tensor([0.0]).to(opt.device)
    else:
        enc_normalized_classification_loss = torch.Tensor([0.0]).to(opt.device)
        dec_normalized_classification_loss = torch.Tensor([0.0]).to(opt.device)
        inconsistency_loss = torch.Tensor([0.0]).to(opt.device)

    total_normalized_classification_loss = opt.class_loss_internal_enc_weight * enc_normalized_classification_loss + \
                                           opt.class_loss_internal_dec_weight * dec_normalized_classification_loss

    joint_loss = opt.gen_loss_weight * normalized_generation_loss + opt.class_loss_weight * total_normalized_classification_loss + opt.inconsistency_loss_weight * inconsistency_loss

    loss_compute_time = time_since(start_time)

    start_time = time.time()
    # back propagation on the joint loss
    joint_loss.backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        grad_norm_before_clipping = nn.utils.clip_grad_norm_(
            overall_model.parameters(), opt.max_grad_norm)
        # grad_norm_after_clipping = (sum([p.grad.data.norm(2) ** 2 for p in model.parameters() if p.grad is not None])) ** (1.0 / 2)
        # logging.info('clip grad (%f -> %f)' % (grad_norm_before_clipping, grad_norm_after_clipping))

    optimizer.step()

    # log each loss to tensorboard
    if tb_writer is not None:
        tb_writer.add_scalar('enc_classification_loss',
                             enc_normalized_classification_loss.item(),
                             global_step)
        tb_writer.add_scalar('dec_classification_loss',
                             dec_normalized_classification_loss.item(),
                             global_step)
        tb_writer.add_scalar('inconsistency_loss', inconsistency_loss.item(),
                             global_step)
        tb_writer.add_scalar('total_classification_loss',
                             total_normalized_classification_loss.item(),
                             global_step)
        tb_writer.add_scalar('generation_loss',
                             normalized_generation_loss.item(), global_step)
        tb_writer.add_scalar('joint_loss', joint_loss.item(), global_step)

    # construct a statistic object for the loss
    stat = JointLossStatistics(joint_loss.item(),
                               generation_loss.item(),
                               enc_normalized_classification_loss.item(),
                               dec_normalized_classification_loss.item(),
                               inconsistency_loss.item(),
                               n_iterations=1,
                               n_tokens=total_trg_tokens,
                               forward_time=forward_time,
                               loss_compute_time=loss_compute_time,
                               backward_time=backward_time)

    decoder_dist_out = decoder_dist.detach() if decoder_dist is not None else None
    return stat, decoder_dist_out
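
inconsistency_loss_func is also defined outside this page. Given how the
encoder-side and decoder-side logits are paired above, one plausible variant is
a KL penalty between the two predicted rating distributions; a hedged sketch
(ours, not the repo's):

import torch.nn.functional as F

def inconsistency_loss_sketch(enc_logit, dec_logit, detach_dec=False):
    # penalize disagreement between the two classifiers' rating distributions
    if detach_dec:
        dec_logit = dec_logit.detach()
    enc_log_prob = F.log_softmax(enc_logit, dim=-1)
    dec_prob = F.softmax(dec_logit, dim=-1)
    return F.kl_div(enc_log_prob, dec_prob, reduction='batchmean')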