Example #1
class Evaluate_pg(object):
    def __init__(self, model_file_path, is_word_level, is_combined, alpha):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        # self.batcher = Batcher(config.eval_data_path, self.vocab, mode='eval',
        #                        batch_size=config.batch_size, single_pass=True)
        self.dataset = DailyMailDataset("val", self.vocab)
        # time.sleep(15)
        model_name = os.path.basename(model_file_path)

        self.is_word_level = is_word_level
        self.is_combined = is_combined
        self.alpha = alpha

        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)

        self.model = Model(model_file_path, is_eval=True)

    def compute_policy_grads_using_rewards(self, sentence_rewards,
                                           word_rewards, sentence_losses,
                                           word_losses, word_to_sent_ind):
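        # Three reward schemes, chosen by the flags set in __init__:
        #   combined:   each word loss is weighted by
        #               alpha * word_reward + (1 - alpha) * reward of the
        #               sentence that contains the word
        #   word-level: each word loss is weighted by its word reward
        #   otherwise:  each sentence loss is weighted by its sentence reward
        # word_to_sent_ind[i][j] maps word j of abstract i to its sentence index.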
        if self.is_combined:
            pg_losses = [[(self.alpha * word_reward + (1 - self.alpha) *
                           sentence_rewards[i][word_to_sent_ind[i][j]]) *
                          word_losses[i][j]
                          for j, word_reward in enumerate(abstract_rewards)
                          if j < len(word_to_sent_ind[i])]
                         for i, abstract_rewards in enumerate(word_rewards)]
            pg_losses = [sum(pg) for pg in pg_losses]
        elif self.is_word_level:
            pg_losses = [[
                word_reward * word_losses[i][j]
                for j, word_reward in enumerate(abstract_rewards)
                if j < len(word_to_sent_ind[i])
            ] for i, abstract_rewards in enumerate(word_rewards)]
            pg_losses = [sum(pg) for pg in pg_losses]
        else:
            pg_losses = [[
                rs * sentence_losses[ri][rsi] for rsi, rs in enumerate(r)
            ] for ri, r in enumerate(sentence_rewards)]
            pg_losses = [sum(pg) for pg in pg_losses]
        return pg_losses

    def compute_pg_loss(self, orig, pred, sentence_losses, split_predictions,
                        word_losses, word_to_sent_ind):
        sentence_rewards = None
        word_rewards = None
        # First compute the rewards
        if not self.is_word_level or self.is_combined:
            sentence_rewards = get_sentence_rewards(orig, pred)

        if self.is_word_level or self.is_combined:
            word_rewards = get_word_level_rewards(orig, split_predictions)

        pg_losses = self.compute_policy_grads_using_rewards(
            sentence_rewards=sentence_rewards,
            word_rewards=word_rewards,
            sentence_losses=sentence_losses,
            word_losses=word_losses,
            word_to_sent_ind=word_to_sent_ind)

        return pg_losses

    def compute_batched_loss(self, word_losses, orig, pred):
        orig_sum = []
        new_pred = []
        pred_sum = []
        sentence_losses = []

        # Convert the original sum as one single string per article
        for i in range(len(orig)):
            orig_sum.append(' '.join(map(str, orig[i])))
            new_pred.append([])
            pred_sum.append([])
            sentence_losses.append([])

        batch_sent_indices = []
        for i in range(len(pred)):
            # Split the prediction into sentences on "." and accumulate
            # per-sentence losses and word-to-sentence indices
            sentence = pred[i]
            losses = word_losses[i]
            sentence_indices = []
            count = 0
            while len(sentence) > 0:
                try:
                    idx = sentence.index(".")
                except ValueError:
                    idx = len(sentence)

                sentence_indices.extend([count for _ in range(idx)])

                if count > 0:
                    new_pred[i].append(new_pred[i][count - 1] +
                                       sentence[:idx + 1])
                else:
                    new_pred[i].append(sentence[:idx + 1])

                sentence_losses[i].append(sum(losses[:idx + 1]))

                sentence = sentence[idx + 1:]
                losses = losses[idx + 1:]
                count += 1
            batch_sent_indices.append(sentence_indices)

        for i in range(len(pred)):
            for j in range(len(new_pred[i])):
                pred_sum[i].append(' '.join(map(str, new_pred[i][j])))

        pg_losses = self.compute_pg_loss(orig_sum,
                                         pred_sum,
                                         sentence_losses,
                                         split_predictions=pred,
                                         word_losses=word_losses,
                                         word_to_sent_ind=batch_sent_indices)

        return pg_losses

    def eval_one_batch(self, batch):
        batch_size = batch.batch_size

        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        output_ids = []
        y_t_1 = torch.ones(batch_size, dtype=torch.long) * self.vocab.word2id(
            data.START_DECODING)

        if config.use_gpu:
            y_t_1 = y_t_1.cuda()

        for _ in range(batch_size):
            output_ids.append([])
            step_losses.append([])

        for di in range(min(max_dec_len, config.max_dec_steps)):
            # No teacher forcing here: feed back the model's own greedy prediction
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)  #NLL
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask

            # Move on to the next token
            _, idx = torch.max(final_dist, 1)  # already shape (batch,)
            y_t_1 = idx

            for i, pred in enumerate(y_t_1):
                # Compare against the PAD id, not the PAD token string
                if pred.item() != self.vocab.word2id(data.PAD_TOKEN):
                    output_ids[i].append(pred.item())

            for i, loss in enumerate(step_loss):
                step_losses[i].append(loss)

        # Obtain the original and predicted summaries
        original_abstracts = batch.original_abstracts_sents
        predicted_abstracts = [
            data.outputids2words(ids, self.vocab, None) for ids in output_ids
        ]

        # Compute the batched loss
        batched_losses = self.compute_batched_loss(step_losses,
                                                   original_abstracts,
                                                   predicted_abstracts)
        losses = torch.stack(batched_losses)
        losses = losses / dec_lens_var

        loss = torch.mean(losses)

        return loss.item()

    def run_eval(self, model_dir, train_iter_id):
        dataloader = DataLoader(self.dataset,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=1,
                                collate_fn=create_batch_collate(
                                    self.vocab, config.batch_size))
        running_avg_loss, iter = 0, 0
        start = time.time()
        # batch = self.batcher.next_batch()
        pg_losses = []
        run_avg_losses = []
        for batch in dataloader:
            loss = self.eval_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     iter)
            print("Iteration:", iter, "  loss:", loss, "  Running avg loss:",
                  running_avg_loss)
            iter += 1

            print_interval = 1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start,
                       running_avg_loss))
                start = time.time()

            pg_losses.append(loss)
            run_avg_losses.append(running_avg_loss)

        # Dump val losses
        with open(os.path.join(model_dir,
                               'val_pg_losses_{}.p'.format(train_iter_id)),
                  'wb') as f:
            pickle.dump(pg_losses, f)
        with open(os.path.join(model_dir,
                               'val_run_avg_losses_{}.p'.format(train_iter_id)),
                  'wb') as f:
            pickle.dump(run_avg_losses, f)

        return run_avg_losses
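
The helper calc_running_avg_loss is referenced throughout these examples but never defined in them. A minimal sketch matching Example #1's call site, assuming the usual exponential moving average (the 0.99 decay and the zero-initialization case are assumptions; the other examples also pass a summary writer):

def calc_running_avg_loss(loss, running_avg_loss, step=None, decay=0.99):
    # Exponential moving average of the per-batch loss; the very first
    # call (running_avg_loss == 0) just returns the raw loss.
    if running_avg_loss == 0:
        return loss
    return decay * running_avg_loss + (1 - decay) * loss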
Example #2
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root,
                                 'train_{}'.format(int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iters):
        state = {
            'iter': iters,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_{}_{}'.format(iters, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    # Move optimizer state tensors to the GPU; a new name avoids
                    # shadowing the checkpoint's `state` dict
                    for opt_state in self.optimizer.state.values():
                        for k, v in opt_state.items():
                            if torch.is_tensor(v):
                                opt_state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):

        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        start = time.time()
        while iter < n_iters:
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start, loss))
                start = time.time()
            if iter % 5000 == 0:
                self.save_model(running_avg_loss, iter)
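
save_model writes a state dict whose keys mirror what setup_train reads back, so training can resume from a snapshot. A hypothetical usage sketch (the checkpoint path is illustrative; save_model names files model_<iter>_<timestamp>):

trainer = Train()
# Passing a saved path restores iter, loss, and (without coverage) optimizer state
trainer.trainIters(n_iters=100000,
                   model_file_path='log/train_1600000000/model/model_5000_1600000123')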
Example #3
class BeamSearch(object):
    def __init__(self, model_file_path, model_type="stem", load_batcher=True):

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        if load_batcher:
            self.batcher = Batcher(config.decode_data_path,
                                   self.vocab,
                                   mode='decode',
                                   batch_size=config.beam_size,
                                   single_pass=True)
            time.sleep(15)
        self.model = Model(model_file_path, is_eval=True)
        self.model_type = model_type

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def restore_text(self, text):
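        # Undo the training-time tokenization, depending on model type:
        #   "stem": pieces were joined into words with a " +" marker, so drop it
        #   "gram": BPE pieces use "▁" as the word boundary (SentencePiece style)
        #   anything else: tokens are plain words, just join with spaces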
        if self.model_type == "stem":
            return " ".join(text).replace(" +", "")
        elif self.model_type == "gram":
            return "".join(text).replace(" ", "").replace("▁", " ")
        else:
            return " ".join(text)

    def decode(self):
        lemm = pymystem3.Mystem()
        rouge = RougeCalculator(stopwords=True, lang=LangRU())
        result_rouge = [0] * 6

        batch = self.batcher.next_batch()

        iters = 0
        while batch is not None:
            # Run beam search to get best Hypothesis
            with torch.no_grad():
                best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token produced; keep the full sequence

            original_abstract_sents = batch.original_abstracts_sents[0]
            original_text = batch.original_articles
            article_oov = batch.art_oovs[0] if batch.art_oovs else None

            batch = self.batcher.next_batch()

            original_abstract_sents = self.restore_text(
                original_abstract_sents)
            decoded_words_restore = self.restore_text(decoded_words)
            decoded_words = " ".join(decoded_words)

            print(f"original_abstract : {original_abstract_sents}")
            print(f"original_text : {original_text}")
            print(f"decoded_words : {decoded_words_restore}")
            print(
                f"decoded_words_oov : {show_abs_oovs(decoded_words, self.vocab, article_oov)}"
            )

            cur_rouge = calk_rouge(original_abstract_sents,
                                   [decoded_words_restore], rouge, lemm)

            result_rouge = list(
                map(lambda x: x[0] + x[1], zip(result_rouge, cur_rouge)))
            iters += 1

        print("--" * 100)
        print("RESULT METRICS")
        result_rouge = [i / iters for i in result_rouge]
        print_results(result_rouge)
        print("++++" * 100)

    def beam_search(self, batch):
        # batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, \
        enc_batch_extend_vocab, extra_zeros, \
        c_t_0, coverage_t_0 = get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h, dec_c = dec_h.squeeze(), dec_c.squeeze()

        # Decoder batch preparation: start with beam_size identical hypotheses
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]

        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]

            y_t_1 = torch.tensor(latest_tokens, dtype=torch.long)
            if use_cuda:
                y_t_1 = y_t_1.cuda()

            all_state_h, all_state_c, all_context = [], [], []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)
                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                     torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = [h.coverage for h in beams]
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = \
                self.model.decoder(y_t_1, s_t_1,
                                   encoder_outputs, encoder_feature,
                                   enc_padding_mask, c_t_1,
                                   extra_zeros, enc_batch_extend_vocab,
                                   coverage_t_1, steps)

            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h, dec_c = dec_h.squeeze(), dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # each of the top 2*beam_size tokens
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    if h.latest_token != self.vocab.word2id(
                            data.UNKNOWN_TOKEN):
                        beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)
        return beams_sorted[0]

    def test_calc(self, article):
        example = batcher.Example(article, [], self.vocab)
        batch = batcher.Batch([example for _ in range(config.beam_size)],
                              self.vocab, config.beam_size)

        with torch.no_grad():
            best_summary = self.beam_search(batch)

        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = data.outputids2words(
            output_ids, self.vocab,
            (batch.art_oovs[0] if config.pointer_gen else None))

        article_restore = self.restore_text(
            batch.original_articles[-1].split())
        decoded_words_restore = self.restore_text(decoded_words).replace(
            "[STOP]", "")
        print(f"original_text : {article_restore}")
        print(f"decoded_words : {decoded_words_restore}")
        decoded_words = " ".join(decoded_words)
        print(
            f"decoded_words_oov : {show_abs_oovs(decoded_words, self.vocab, batch.art_oovs[0] if batch.art_oovs else None)}"
        )

    def test(self, mode, bpe_model_path=None):
        while True:
            file_path = input("File path: ").strip()
            file_path = r"C:\Users\lezgy\OneDrive\Рабочий стол\Data_summ\data.txt"
            if file_path == "q":
                break
            try:
                with open(file_path, "r", encoding="utf-8") as r:
                    article = r.read().strip().split("\n")
                    article = " ".join(article)
                    if mode in ["lemm", "stem", "gram", "base"]:
                        article = article.lower()
                        article = word_tokenize(article)
                        article = " ".join(article)
                    print(f"real_text : {article}")

                if mode == "lemm":
                    lemmatizer = mystem.Mystem()
                    article = preprocess_lemm(article, lemmatizer)
                elif mode == "stem":
                    stemmer = RussianStemmer(False)
                    article = preprocess_stemm(article, stemmer)
                elif mode == "gram":
                    token_model = youtokentome.BPE(model=bpe_model_path)
                    article = preprocess_gramm(article, token_model)
                self.test_calc(article)
            except Exception as e:
                print(e)
                print("Failed to read or preprocess the file")
Example #4
class Train(object):
    def __init__(self, train_dir=None, eval_dir=None, vocab=None, vectors=None):
        self.vectors = vectors
        if vocab is None:
            self.vocab = Vocab(config.vocab_path, config.vocab_size)
        else:
            self.vocab = vocab

        print(self.vocab)
        self.batcher_train = Batcher(config.train_data_path, self.vocab, mode='train',
                                     batch_size=config.batch_size, single_pass=False)
        time.sleep(15)
        self.batcher_eval = Batcher(config.eval_data_path, self.vocab, mode='eval',
                                    batch_size=config.batch_size, single_pass=True)
        time.sleep(15)

        cur_time = int(time.time())
        if train_dir is None:
            train_dir = os.path.join(config.log_root, 'train_%d' % (cur_time))
            if not os.path.exists(train_dir):
                os.mkdir(train_dir)

        if eval_dir is None:
            eval_dir = os.path.join(config.log_root, 'eval_%s' % (cur_time))
            if not os.path.exists(eval_dir):
                os.mkdir(eval_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer_train = writer.FileWriter(train_dir)
        self.summary_writer_eval = writer.FileWriter(eval_dir)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path, vectors=self.vectors)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())

        pytorch_total_params = sum(p.numel() for p in params if p.requires_grad)
        print(f"Parameters count: {pytorch_total_params}")

        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        # self.optimizer = adagrad.Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)
        self.optimizer = Adam(params, lr=initial_lr)
        start_iter, start_training_loss, start_eval_loss = 0, 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path, map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_training_loss = state['current_train_loss']
            start_eval_loss = state['current_eval_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    # Move optimizer state tensors to the GPU
                    for opt_state in self.optimizer.state.values():
                        for k, v in opt_state.items():
                            if isinstance(v, torch.Tensor):
                                opt_state[k] = v.cuda()

        self.checkpoint = Checkpoint(self.model,
                                     self.optimizer,
                                     self.model_dir,
                                     start_eval_loss if start_eval_loss != 0 else float("inf"))

        return start_iter, start_training_loss, start_eval_loss

    def model_batch_step(self, batch, is_eval):

        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        step_decoded_idx = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing

            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = \
                self.model.decoder(y_t_1, s_t_1,
                                   encoder_outputs,
                                   encoder_feature,
                                   enc_padding_mask, c_t_1,
                                   extra_zeros,
                                   enc_batch_extend_vocab,
                                   coverage, di)

            if is_eval:
                _, top_idx = final_dist.topk(1)
                step_decoded_idx.append(top_idx)

            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        final_decoded_sentences = None
        if is_eval:
            # Stack per-step top-1 ids: (batch, 1, steps) -> (batch, steps)
            final_decoded_sentences = torch.stack(step_decoded_idx, 2).squeeze(1)

        return loss, final_decoded_sentences

    def train_one_batch(self, batch):
        self.optimizer.zero_grad()
        loss, _ = self.model_batch_step(batch, False)
        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def run_eval(self):

        self.model.eval()
        batch = self.batcher_eval.next_batch()
        iter = 0
        start = time.time()
        running_avg_loss = 0
        with torch.no_grad():
            while batch is not None:
                loss, _ = self.model_batch_step(batch, False)
                loss = loss.item()
                running_avg_loss = calc_running_avg_loss(loss, running_avg_loss)
                batch = self.batcher_eval.next_batch()

                iter += 1
                if iter % config.print_interval == 0:
                    print('Eval steps %d, seconds for %d batch: %.2f , loss: %f' % (
                        iter, config.print_interval, time.time() - start, running_avg_loss))
                    start = time.time()

        return running_avg_loss

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss_train, running_avg_loss_eval = self.setup_train(model_file_path)
        start = time.time()

        loss_train = 0
        while iter < n_iters:

            self.model.train()
            batch = self.batcher_train.next_batch()
            loss_train = self.train_one_batch(batch)
            running_avg_loss_train = calc_and_write_running_avg_loss(loss_train,
                                                                     "running_avg_loss_train",
                                                                     running_avg_loss_train,
                                                                     self.summary_writer_train,
                                                                     iter)
            iter += 1

            if iter % 100 == 0:
                self.summary_writer_train.flush()

            if iter % config.print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f, loss: %f, avg_loss: %f' % (iter, config.print_interval,
                                                                                        time.time() - start,
                                                                                        loss_train,
                                                                                        running_avg_loss_train))
                start = time.time()

            if iter % 5000 == 0:
                running_avg_loss_eval = self.run_eval()
                write_summary("running_avg_loss_eval",
                              running_avg_loss_eval,
                              self.summary_writer_eval,
                              iter)
                self.summary_writer_eval.flush()
                self.checkpoint.check_loss(running_avg_loss_eval, running_avg_loss_train, iter)
                start = time.time()
                self.batcher_eval.start_threads()

            if config.is_coverage and iter % 2000 == 0:
                self.checkpoint.save_model("coverage", running_avg_loss_eval, running_avg_loss_train, iter)
            if iter % 10000 == 0:
                self.checkpoint.save_model("critical", running_avg_loss_eval, running_avg_loss_train, iter)
Example #5
def train():
    target_field = Field(sequential=True,
                         init_token=START_DECODING,
                         eos_token=STOP_DECODING,
                         pad_token=PAD_TOKEN,
                         batch_first=True,
                         include_lengths=True,
                         unk_token=UNKNOWN_TOKEN,
                         lower=True)

    source_field = Field(sequential=True,
                         init_token=SENTENCE_START,
                         eos_token=SENTENCE_END,
                         pad_token=PAD_TOKEN,
                         batch_first=True,
                         include_lengths=True,
                         unk_token=UNKNOWN_TOKEN,
                         lower=True)
    train_path = '../data/incar_alexa/train_public.pickle'
    dev_path = '../data/incar_alexa/dev_public.pickle'
    test_path = '../data/incar_alexa/test_public.pickle'
    path = '../data/cnn_stories_tokenized'
    summary_writer = SummaryWriter(config.summary_path)

    train_src, train_tgt, train_id = load_data(train_path)
    dev_src, dev_tgt, dev_id = load_data(dev_path)
    test_src, test_tgt, test_id = load_data(test_path)
    # train_data = prepare_data_cnn(path)
    # # print(train_data[0])
    # train_src = [dt['src'] for dt in train_data]
    # train_tgt = [dt['tgt'] for dt in train_data]
    # train_id = [dt['id'] for dt in train_data]
    # train_src, test_src, train_tgt, test_tgt = train_test_split(
    #     train_src, train_tgt, test_size=0.15, random_state=123)
    # train_id, test_id = train_test_split(
    #     train_id, test_size=0.15, random_state=123)
    # # print(f"{len(train_src)}, {len(train_tgt)}")
    # train_src, dev_src, train_tgt, dev_tgt = train_test_split(
    #     train_src, train_tgt, test_size=0.15, random_state=123)
    # train_id, dev_id = train_test_split(
    #     train_id, test_size=0.15, random_state=123)

    train_src_preprocessed = [source_field.preprocess(x) for x in train_src]
    dev_src_preprocessed = [source_field.preprocess(x) for x in dev_src]
    test_src_preprocessed = [source_field.preprocess(x) for x in test_src]

    train_tgt_preprocessed = [target_field.preprocess(x) for x in train_tgt]
    dev_tgt_preprocessed = [target_field.preprocess(x) for x in dev_tgt]
    test_tgt_preprocessed = [target_field.preprocess(x) for x in test_tgt]
    # train_src_preprocessed = source_field.apply(lambda x: source_field.preprocess(x))

    vectors = Vectors(
        name='/home/binhna/Downloads/shared_resources/cc.en.300.vec',
        cache='/home/binhna/Downloads/shared_resources/')

    source_field.build_vocab([
        train_src_preprocessed, dev_src_preprocessed, train_tgt_preprocessed,
        dev_tgt_preprocessed
    ],
                             vectors=vectors)
    target_field.build_vocab([
        train_src_preprocessed, dev_src_preprocessed, train_tgt_preprocessed,
        dev_tgt_preprocessed
    ],
                             vectors=vectors)

    train_data = [{
        'src': src,
        'tgt': tgt,
        'id': id
    } for src, tgt, id in zip(train_src, train_tgt, train_id)]
    train_data = Mydataset(data=train_data,
                           fields=(('source', source_field), ('target',
                                                              target_field)))
    dev_data = [{
        'src': src,
        'tgt': tgt,
        'id': id
    } for src, tgt, id in zip(dev_src, dev_tgt, dev_id)]
    dev_data = Mydataset(data=dev_data,
                         fields=(('source', source_field), ('target',
                                                            target_field)))

    test_data = [{
        'src': src,
        'tgt': tgt,
        'id': id
    } for src, tgt, id in zip(test_src, test_tgt, test_id)]
    test_data = Mydataset(data=test_data,
                          fields=(('source', source_field), ('target',
                                                             target_field)))
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    train_iter, test_iter, dev_iter = BucketIterator.splits(
        datasets=(train_data, test_data, dev_data),
        batch_sizes=(config.batch_size, config.batch_size, config.batch_size),
        device=device,
        sort_key=lambda x: len(x.source),
        sort_within_batch=True)

    args = ARGS()
    setattr(args, 'vectors', source_field.vocab.vectors)
    setattr(args, 'vocab_size', len(source_field.vocab.itos))
    setattr(args, 'emb_dim', vectors.dim)
    model = Model(args)

    params = list(model.encoder.parameters()) + list(
        model.decoder.parameters()) + list(model.reduce_state.parameters())
    initial_lr = config.lr_coverage if config.is_coverage else config.lr
    optimizer = Adagrad(params,
                        lr=initial_lr,
                        initial_accumulator_value=config.adagrad_init_acc)

    iter, running_avg_loss = 0, 0
    start = time.time()
    for epoch in range(500):
        print(f"Epoch: {epoch+1}")
        for i, batch in tqdm(enumerate(train_iter), total=len(train_iter)):
            batch_size = batch.batch_size
            # encoder part
            enc_padding_mask = get_mask(batch.source, device)
            enc_batch = batch.source[0]
            enc_lens = batch.source[1]
            encoder_outputs, encoder_feature, encoder_hidden = model.encoder(
                enc_batch, enc_lens)
            s_t_1 = model.reduce_state(encoder_hidden)
            coverage = Variable(torch.zeros(batch.source[0].size())).to(device)
            c_t_1 = Variable(torch.zeros(
                (batch_size, 2 * config.hidden_dim))).to(device)
            extra_zeros, enc_batch_extend_vocab, max_art_oovs = get_extra_features(
                batch.source[0], source_field.vocab)
            extra_zeros = extra_zeros.to(device)
            enc_batch_extend_vocab = enc_batch_extend_vocab.to(device)
            # decoder part: inputs drop the final token; targets are shifted
            # one step ahead so each step predicts the next token
            dec_batch = batch.target[0][:, :-1]
            target_batch = batch.target[0][:, 1:]
            dec_lens_var = batch.target[1]
            dec_padding_mask = get_mask(batch.target, device)
            max_dec_len = max(dec_lens_var)

            step_losses = []
            for di in range(min(max_dec_len, config.max_dec_steps) - 1):
                y_t_1 = dec_batch[:, di]  # Teacher forcing
                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = model.decoder(
                    y_t_1, s_t_1, encoder_outputs, encoder_feature,
                    enc_padding_mask, c_t_1, extra_zeros,
                    enc_batch_extend_vocab, coverage, di)
                target = target_batch[:, di]
                gold_probs = torch.gather(final_dist, 1,
                                          target.unsqueeze(1)).squeeze()
                step_loss = -torch.log(gold_probs + config.eps)
                if config.is_coverage:
                    step_coverage_loss = torch.sum(
                        torch.min(attn_dist, coverage), 1)
                    step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                    coverage = next_coverage

                step_mask = dec_padding_mask[:, di]
                step_loss = step_loss * step_mask
                step_losses.append(step_loss)
            sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
            batch_avg_loss = sum_losses / dec_lens_var
            loss = torch.mean(batch_avg_loss)

            # Clear gradients accumulated from the previous batch
            optimizer.zero_grad()
            loss.backward()

            norm = clip_grad_norm_(model.encoder.parameters(),
                                   config.max_grad_norm)
            clip_grad_norm_(model.decoder.parameters(), config.max_grad_norm)
            clip_grad_norm_(model.reduce_state.parameters(),
                            config.max_grad_norm)

            optimizer.step()

            running_avg_loss = calc_running_avg_loss(loss.item(),
                                                     running_avg_loss,
                                                     summary_writer, iter)
            iter += 1
            summary_writer.flush()
            if iter % 300 == 0:
                save_model(model, optimizer, running_avg_loss, iter,
                           config.model_dir)
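
get_mask is called above with a (padded tensor, lengths) pair, which is what a torchtext Field built with include_lengths=True yields. A minimal sketch under that assumption:

def get_mask(field_batch, device):
    # field_batch is (padded_ids, lengths); return a float mask that is
    # 1.0 over real tokens and 0.0 over padding, shape (batch, max_len).
    padded, lengths = field_batch
    max_len = padded.size(1)
    positions = torch.arange(max_len, device=device).unsqueeze(0)
    return (positions < lengths.to(device).unsqueeze(1)).float()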