Code Example #1
    def encode_and_pad(self, data_batches, word2id_dictionary):
        #################### Prepare Training data################
        print('Encoding Data...')
        max_sentences = []
        max_length = []
        no_padding_sentences = []
        no_padding_lengths = []
        for index, batch in tqdm(enumerate(data_batches)):
            batch = hF.encode_batch(batch, word2id_dictionary)

            num_sentences = [len(x) for x in batch]
            sentence_lengths = [[len(x) for x in y] for y in batch]
            max_num_sentences = max(num_sentences)
            max_sentences_length = max([max(x) for x in sentence_lengths])

            batch, no_padding_num_sentences = hF.pad_batch_with_sentences(
                batch, max_num_sentences)
            batch, no_padding_sentence_lengths = hF.pad_batch_sequences(
                batch, max_sentences_length)

            max_sentences.append(max_num_sentences)
            max_length.append(max_sentences_length)
            no_padding_sentences.append(no_padding_num_sentences)
            no_padding_lengths.append(no_padding_sentence_lengths)
            data_batches[index] = batch
        ##########################################
        return data_batches, max_sentences, max_length, no_padding_sentences, no_padding_lengths
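A minimal call sketch for the method above. All names here are illustrative assumptions: `prep` stands for an instance of the preprocessing class, `batches` for a list of batches (each a list of documents, each a list of token lists), and `word2id` for the vocabulary map built elsewhere in the project.

# Hypothetical usage of encode_and_pad (names illustrative).
batches, max_sents, max_lens, real_sents, real_lens = prep.encode_and_pad(
    batches, word2id)
# After the call, batch b is rectangular: every document holds max_sents[b]
# sentences of max_lens[b] ids each, while real_sents[b] / real_lens[b] keep
# the unpadded counts needed later for masking.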
Code Example #2
    def encode_and_pad_BERT(self, data_batches, Bert_model_Path, device,
                            bert_layers, bert_dims):
        from pytorch_pretrained_bert import BertTokenizer, BertModel

        # e.g. Bert_model_Path = '../../pytorch-pretrained-BERT/bert_models/uncased_L-12_H-768_A-12/'
        tokenizer = BertTokenizer.from_pretrained(Bert_model_Path)
        model = BertModel.from_pretrained(Bert_model_Path)
        model.eval()
        model.to(device)
        #################### Prepare Training data################
        print('Encoding Data using BERT...')
        max_sentences = []
        no_padding_sentences = []
        for index, batch in tqdm(enumerate(data_batches)):
            batch = hF.encode_batch_BERT(batch, model, tokenizer, device,
                                         bert_layers)
            num_sentences = [len(x) for x in batch]
            max_num_sentences = max(num_sentences)

            batch, no_padding_num_sentences = hF.pad_batch_with_sentences_BERT(
                batch, max_num_sentences, bert_layers, bert_dims)

            max_sentences.append(max_num_sentences)
            no_padding_sentences.append(no_padding_num_sentences)
            data_batches[index] = batch
        ##########################################
        return data_batches, max_sentences, None, no_padding_sentences, None
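A matching sketch for the BERT variant. The device setup, checkpoint path, and layer/dimension values are assumptions; the two `None` slots exist because each sentence is already a fixed-size embedding, so per-token lengths no longer apply.

import torch

# Hypothetical call (path and bert_layers/bert_dims values are illustrative).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batches, max_sents, _, real_sents, _ = prep.encode_and_pad_BERT(
    batches, './bert_models/uncased_L-12_H-768_A-12/', device,
    bert_layers=4, bert_dims=768)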
Code Example #3
    def tokenize(self, data, use_back_translation=False):
        all_comments = []
        all_posts = []
        all_answers = []
        all_human_summaries = []

        all_comments_translated = []
        all_posts_translated = []

        print('Tokenizing Data...')
        for i in tqdm(range(0, len(data))):
            post = [
                x.replace('\n', '').replace('\r', '').strip()
                for x in data[i].initial_post
            ]
            comments = [
                x.replace('\n', '').replace('\r', '').strip()
                for x in data[i].reply_sentences
            ]
            # NOTE: the gold selection is read from reply_sentences, so every
            # answer below comes out 1; this most likely should read the
            # dataset's field of gold-selected summary sentences instead.
            selected_sentences = [
                x.replace('\n', '').replace('\r', '').strip()
                for x in data[i].reply_sentences
            ]

            answers = [1 if x in selected_sentences else 0 for x in comments]
            post = [hF.tokenize_text(x) for x in post]

            comments = [hF.tokenize_text(x) for x in comments]
            human_summary = data[i].human_summary_1

            if use_back_translation is True:
                post_translated = [
                    x.replace('\n', '').replace('\r', '').strip()
                    for x in data[i].initial_post_translated
                ]
                comments_translated = [
                    x.replace('\n', '').replace('\r', '').strip()
                    for x in data[i].reply_sentences_translated
                ]

                # Split once, after cleaning; the original split the post
                # twice, which raised AttributeError on the already-split lists.
                post_translated = [x.split(' ') for x in post_translated]
                all_comments_translated.append(comments_translated)
                all_posts_translated.append(post_translated)

            all_answers.append(answers)
            all_comments.append(comments)
            all_posts.append(post)
            all_human_summaries.append(human_summary)

        if use_back_translation is True:
            return all_posts, all_comments, all_answers, all_human_summaries, all_posts_translated, all_comments_translated
        else:
            return all_posts, all_comments, all_answers, all_human_summaries
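The attribute accesses above imply the following record interface. This class is a hypothetical stand-in reconstructed from this method alone; the real dataset class lives elsewhere in the project.

# Hypothetical stand-in for the records consumed by tokenize().
class ThreadRecord:
    def __init__(self, post, replies, summary,
                 post_translated=None, replies_translated=None):
        self.initial_post = post                  # list of sentence strings
        self.reply_sentences = replies            # list of sentence strings
        self.human_summary_1 = summary            # reference summary string
        # Only read when use_back_translation is True:
        self.initial_post_translated = post_translated
        self.reply_sentences_translated = replies_translated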
Code Example #4
    def pad_batch(self, data_batch):
        num_sentences = [len(x) for x in data_batch]
        sentence_lengths = [[len(x) for x in y] for y in data_batch]
        max_num_sentences = max(num_sentences)
        max_sentences_length = max([max(x) for x in sentence_lengths])

        data_batch, no_padding_num_sentences = hF.pad_batch_with_sentences(
            data_batch, max_num_sentences)
        data_batch, no_padding_sentence_lengths = hF.pad_batch_sequences(
            data_batch, max_sentences_length)

        ##########################################
        return data_batch, max_num_sentences, max_sentences_length, no_padding_num_sentences, no_padding_sentence_lengths
Code Example #5
    def pad_batch_BERT(self, batch, bert_layers, bert_dims):
        num_sentences = [len(x) for x in batch]
        max_num_sentences = max(num_sentences)
        batch, no_padding_num_sentences = hF.pad_batch_with_sentences_BERT(
            batch, max_num_sentences, bert_layers, bert_dims)
        ##########################################
        return batch, max_num_sentences, None, no_padding_num_sentences, None
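Both `*_BERT` helpers return the same five-slot tuple as their word-id counterparts, with `None` in the two length slots, so call sites can unpack uniformly. A sketch, where `padder` is hypothetical and 4 layers x 768 dims are illustrative values for a base uncased BERT:

batch, max_s, max_l, real_s, real_l = padder.pad_batch_BERT(batch, 4, 768)
assert max_l is None and real_l is None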
Code Example #6
    def encode_BERT(self, data, Bert_model_Path, device, bert_layers,
                    batch_size):
        from pytorch_pretrained_bert import BertTokenizer, BertModel
        if not os.path.exists(Bert_model_Path):
            print('BERT model not found. Make sure the path is correct.')
            return
        # e.g. Bert_model_Path = '../../pytorch-pretrained-BERT/bert_models/uncased_L-12_H-768_A-12/'
        tokenizer = BertTokenizer.from_pretrained(Bert_model_Path)
        model = BertModel.from_pretrained(Bert_model_Path)
        model.eval()
        model.to(device)
        #################### Prepare Training data################
        print('Encoding Data using BERT...')
        for j in tqdm(range(0, len(data), batch_size)):
            # Slicing clamps at the list end (see the check below), so the
            # original if/else around the final batch is unnecessary.
            batch = data[j:j + batch_size]
            batch = hF.encode_batch_BERT(batch, model, tokenizer, device,
                                         bert_layers)
            # Write the encoded documents back in place.
            for i, doc in enumerate(batch):
                data[j + i] = doc

        ##########################################
        return data
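The single slice above handles the final short batch automatically because Python slicing clamps at the end of the list. A standalone check of that equivalence, independent of BERT:

# Pure-Python demonstration that the end-of-list if/else is unnecessary.
data = list(range(10))
batch_size = 4
assert data[8:8 + batch_size] == [8, 9]     # short final batch, no IndexError
assert data[8:8 + batch_size] == data[8:]   # identical to the explicit else-branch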
Code Example #7
    def pad(self, data_batches):
        print('Padding Data...')
        max_sentences = []
        max_length = []
        no_padding_sentences = []
        no_padding_lengths = []
        for index, batch in tqdm(enumerate(data_batches)):
            num_sentences = [len(x) for x in batch]
            sentence_lengths = [[len(x) for x in y] for y in batch]
            max_num_sentences = max(num_sentences)
            max_sentences_length = max([max(x) for x in sentence_lengths])

            batch, no_padding_num_sentences = hF.pad_batch_with_sentences(
                batch, max_num_sentences)
            batch, no_padding_sentence_lengths = hF.pad_batch_sequences(
                batch, max_sentences_length)

            max_sentences.append(max_num_sentences)
            max_length.append(max_sentences_length)
            no_padding_sentences.append(no_padding_num_sentences)
            no_padding_lengths.append(no_padding_sentence_lengths)
            data_batches[index] = batch
        ##########################################
        return data_batches, max_sentences, max_length, no_padding_sentences, no_padding_lengths
Code Example #8
    def pad_BERT(self, data_batches, bert_layers, bert_dims):
        print('Padding Data using BERT...')
        max_sentences = []
        no_padding_sentences = []
        for index, batch in tqdm(enumerate(data_batches)):
            num_sentences = [len(x) for x in batch]
            max_num_sentences = max(num_sentences)

            batch, no_padding_num_sentences = hF.pad_batch_with_sentences_BERT(
                batch, max_num_sentences, bert_layers, bert_dims)

            max_sentences.append(max_num_sentences)
            no_padding_sentences.append(no_padding_num_sentences)
            data_batches[index] = batch
        ##########################################
        return data_batches, max_sentences, None, no_padding_sentences, None
Code Example #9
    def encode(self, data, word2id_dictionary):
        # Map every token of every document to its integer id.
        for index, doc in tqdm(enumerate(data)):
            data[index] = hF.encode_document(doc, word2id_dictionary)
        return data
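`encode()` assumes a token-to-id vocabulary that reserves entries for the special tokens the output writers strip later (`<SOS>`, `<EOS>`, `<pad>`). An illustrative, hypothetical mapping; the real vocabulary builder is elsewhere in the project:

# Illustrative vocabulary only.
word2id = {'<pad>': 0, '<SOS>': 1, '<EOS>': 2, 'hello': 3, 'world': 4}
id2word = {v: k for k, v in word2id.items()}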
Code Example #10
def train_batch(model, device, post_batch, comment_batch, answer_batch,
                max_sentences, max_length, no_padding_sentences,
                no_padding_lengths, posts_max_sentences, posts_max_length,
                posts_no_padding_sentences, posts_no_padding_lengths,
                optimizer, criterion, use_bert):

    if use_bert is True:
        tensor_batch = HelpingFunctions.convert_to_tensor(
            comment_batch, device, 'float')
        tensor_post_batch = HelpingFunctions.convert_to_tensor(
            post_batch, device, 'float')
    else:
        tensor_batch = HelpingFunctions.convert_to_tensor(
            comment_batch, device)
        tensor_post_batch = HelpingFunctions.convert_to_tensor(
            post_batch, device)

    # These values are already per-batch in the single-batch helpers, so they
    # pass through unchanged (None simply stays None).
    batch_max_length = max_length
    batch_no_padding_lengths = no_padding_lengths
    batch_posts_max_length = posts_max_length
    batch_posts_no_padding_lengths = posts_no_padding_lengths

    sentence_probabilities = model(tensor_batch, max_sentences,
                                   batch_max_length, no_padding_sentences,
                                   batch_no_padding_lengths, tensor_post_batch,
                                   posts_max_sentences, batch_posts_max_length,
                                   posts_no_padding_sentences,
                                   batch_posts_no_padding_lengths)

    # Pad the gold labels out to max_sentences, then overwrite the padded
    # slots with the model's own detached predictions so those positions
    # contribute zero gradient to the loss.
    targets = copy.deepcopy(answer_batch)
    for i, elem in enumerate(targets):
        while len(targets[i]) < max_sentences:
            targets[i].append(0)
    for i in range(len(sentence_probabilities)):
        for j in range(len(sentence_probabilities[i, :])):
            if j >= no_padding_sentences[i]:
                targets[i][j] = sentence_probabilities[i, j].item()

    # Sum-reduced BCE, normalized by the number of real (unpadded) sentences.
    loss = criterion(sentence_probabilities,
                     torch.FloatTensor(targets).to(device))
    loss = loss / sum(no_padding_sentences)

    optimizer.zero_grad()
    loss.backward()
    clip_grad_norm_(model.parameters(), 2)
    optimizer.step()
    return loss.item()
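Why copying the model's own detached probability into a padded target slot works: with sum-reduced BCE the loss term at such a slot is the entropy of the prediction (not zero), but its gradient with respect to the prediction vanishes, so padded slots stop influencing training. A standalone check:

import torch
import torch.nn.functional as F

p = torch.tensor([0.7], requires_grad=True)
t = p.detach().clone()            # target copied from the prediction itself
loss = F.binary_cross_entropy(p, t, reduction='sum')
loss.backward()
print(round(loss.item(), 3))      # 0.611 -- the entropy term, not zero
print(p.grad)                     # tensor([0.]) -- but the gradient vanishes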
Code Example #11
def val_batch(model, device, post_batch, comment_batch, answer_batch,
              max_sentences, max_length, no_padding_sentences,
              no_padding_lengths, posts_max_sentences, posts_max_length,
              posts_no_padding_sentences, posts_no_padding_lengths, criterion,
              use_bert):
    if use_bert is True:
        tensor_batch = HelpingFunctions.convert_to_tensor(
            comment_batch, device, 'float')
        tensor_post_batch = HelpingFunctions.convert_to_tensor(
            post_batch, device, 'float')
    else:
        tensor_batch = HelpingFunctions.convert_to_tensor(
            comment_batch, device)
        tensor_post_batch = HelpingFunctions.convert_to_tensor(
            post_batch, device)

    # Per-batch values pass through unchanged (None stays None).
    batch_max_length = max_length
    batch_no_padding_lengths = no_padding_lengths
    batch_posts_max_length = posts_max_length
    batch_posts_no_padding_lengths = posts_no_padding_lengths

    sentence_probabilities = model(tensor_batch, max_sentences,
                                   batch_max_length, no_padding_sentences,
                                   batch_no_padding_lengths, tensor_post_batch,
                                   posts_max_sentences, batch_posts_max_length,
                                   posts_no_padding_sentences,
                                   batch_posts_no_padding_lengths)

    # Pad the gold labels, then copy the model's detached predictions into the
    # padded slots so they contribute zero gradient, as in train_batch.
    targets = copy.deepcopy(answer_batch)
    for i, elem in enumerate(targets):
        while len(targets[i]) < max_sentences:
            targets[i].append(0)

    for i in range(len(sentence_probabilities)):
        for j in range(len(sentence_probabilities[i, :])):
            if j >= no_padding_sentences[i]:
                targets[i][j] = sentence_probabilities[i, j].item()

    loss = criterion(sentence_probabilities,
                     torch.FloatTensor(targets).to(device))
    loss = loss / sum(no_padding_sentences)
    return loss.item()
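`val_batch` only needs a forward pass; wrapping the call in `torch.no_grad()` is a suggested usage pattern (not something the original code does) that skips building the autograd graph during validation:

# Suggested call pattern; argument names as in the signature above.
with torch.no_grad():
    vloss = val_batch(model, device, post_batch, comment_batch, answer_batch,
                      max_sentences, max_length, no_padding_sentences,
                      no_padding_lengths, posts_max_sentences, posts_max_length,
                      posts_no_padding_sentences, posts_no_padding_lengths,
                      criterion, use_bert)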
Code Example #12
def summarize(model, device, post_batches, test_comment_batches,
              test_human_summary_batches, sentences_str_batches, max_sentences,
              max_length, no_padding_sentences, no_padding_lengths,
              posts_max_sentences, posts_max_length,
              posts_no_padding_sentences, posts_no_padding_lengths,
              id2word_dic, output_dir, use_bert):
    model.eval()
    predicted_words = []

    for index, batch in tqdm_notebook(enumerate(test_comment_batches)):

        if use_bert is True:
            tensor_batch = HelpingFunctions.convert_to_tensor(
                batch, device, 'float')
            tensor_post_batch = HelpingFunctions.convert_to_tensor(
                post_batches[index], device, 'float')
        else:
            tensor_batch = HelpingFunctions.convert_to_tensor(batch, device)
            tensor_post_batch = HelpingFunctions.convert_to_tensor(
                post_batches[index], device)

        batch_max_length = (max_length[index]
                            if max_length is not None else None)
        batch_no_padding_lengths = (no_padding_lengths[index]
                                    if no_padding_lengths is not None else None)
        batch_posts_max_length = (posts_max_length[index]
                                  if posts_max_length is not None else None)
        batch_posts_no_padding_lengths = (
            posts_no_padding_lengths[index]
            if posts_no_padding_lengths is not None else None)

        sentence_probabilities = model(
            tensor_batch, max_sentences[index], batch_max_length,
            no_padding_sentences[index], batch_no_padding_lengths,
            tensor_post_batch, posts_max_sentences[index],
            batch_posts_max_length, posts_no_padding_sentences[index],
            batch_posts_no_padding_lengths)

        # Zero out probabilities at padded sentence positions.
        for i in range(len(sentence_probabilities)):
            for j in range(len(sentence_probabilities[i, :])):
                if j >= no_padding_sentences[index][i]:
                    sentence_probabilities[i, j] = 0 * sentence_probabilities[i, j]

        sentence_probabilities = sentence_probabilities.tolist()

        for prediction, indices in zip(sentence_probabilities,
                                       sentences_str_batches[index]):
            pre_tokens = []
            for i, val in enumerate(prediction):
                if val > 0.5:
                    pre_tokens.append(indices[i])
            predicted_words.append(pre_tokens)

    index = 1
    # makedirs also creates output_dir itself if it does not exist yet;
    # os.mkdir would fail in that case.
    os.makedirs(output_dir + '/dec/', exist_ok=True)

    for prediction_elem in predicted_words:
        predicted_output = codecs.open(output_dir +
                                       '/dec/{}.dec'.format(index),
                                       'w',
                                       encoding='utf8')

        for sentence in prediction_elem:
            sentence_text = ' '.join(sentence).replace('<SOS>', '').replace(
                '<EOS>', '').replace('<pad>', '').strip()
            predicted_output.write(sentence_text + '\n')
        predicted_output.close()

        index += 1
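The resulting on-disk layout, reconstructed from the writes above (one file per document, one selected sentence per line):

# output_dir/
#   dec/
#     1.dec    <- sentences of document 1 with probability > 0.5
#     2.dec
#     ...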
Code Example #13
def test_epoch(model, device, post_batches, comment_batches, answer_batches,
               human_summary_batches, sentences_str_batches, max_sentences,
               max_length, no_padding_sentences, no_padding_lengths,
               posts_max_sentences, posts_max_length,
               posts_no_padding_sentences, posts_no_padding_lengths,
               id2word_dic, output_dir, use_bert):
    model.eval()

    target_words = []
    predicted_words = []
    human_summaries = []

    pbar = tqdm_notebook(enumerate(comment_batches))
    for index, batch in pbar:
        pbar.set_description("Testing {}/{}".format(index,
                                                    len(comment_batches)))

        if use_bert is True:
            tensor_batch = HelpingFunctions.convert_to_tensor(
                batch, device, 'float')
            tensor_post_batch = HelpingFunctions.convert_to_tensor(
                post_batches[index], device, 'float')
        else:
            tensor_batch = HelpingFunctions.convert_to_tensor(batch, device)
            tensor_post_batch = HelpingFunctions.convert_to_tensor(
                post_batches[index], device)

        batch_max_length = (max_length[index]
                            if max_length is not None else None)
        batch_no_padding_lengths = (no_padding_lengths[index]
                                    if no_padding_lengths is not None else None)
        batch_posts_max_length = (posts_max_length[index]
                                  if posts_max_length is not None else None)
        batch_posts_no_padding_lengths = (
            posts_no_padding_lengths[index]
            if posts_no_padding_lengths is not None else None)

        sentence_probabilities = model(
            tensor_batch, max_sentences[index], batch_max_length,
            no_padding_sentences[index], batch_no_padding_lengths,
            tensor_post_batch, posts_max_sentences[index],
            batch_posts_max_length, posts_no_padding_sentences[index],
            batch_posts_no_padding_lengths)

        sentence_probabilities = sentence_probabilities.tolist()
        targets = answer_batches[index]
        human_summaries += human_summary_batches[index]

        for target, prediction, indices in zip(targets, sentence_probabilities,
                                               sentences_str_batches[index]):
            pre_sentences = []
            tar_sentences = []
            for i, val in enumerate(target):
                if val == 1:
                    tar_sentences.append(indices[i])
                if prediction[i] > 0.5:
                    pre_sentences.append(indices[i])

            target_words.append(tar_sentences)
            predicted_words.append(pre_sentences)

    index = 1
    os.makedirs(output_dir + '/ref/', exist_ok=True)
    os.makedirs(output_dir + '/ref_abs/', exist_ok=True)
    os.makedirs(output_dir + '/dec/', exist_ok=True)

    for target_elem, prediction_elem, human_summary in zip(
            target_words, predicted_words, human_summaries):
        gold_abs_output = codecs.open(output_dir +
                                      '/ref_abs/{}.ref'.format(index),
                                      'w',
                                      encoding='utf8')
        gold_output = codecs.open(output_dir + '/ref/{}.ref'.format(index),
                                  'w',
                                  encoding='utf8')
        predicted_output = codecs.open(output_dir +
                                       '/dec/{}.dec'.format(index),
                                       'w',
                                       encoding='utf8')

        for sentence in target_elem:
            sentence_text = ' '.join(sentence).replace('<SOS>', '').replace(
                '<EOS>', '').replace('<pad>', '').strip()
            gold_output.write(sentence_text + '\n')
        gold_output.close()

        gold_abs_output.write(
            human_summary.replace('<SOS>',
                                  '').replace('<EOS>',
                                              '').replace('<pad>', '').strip())
        gold_abs_output.close()

        for sentence in prediction_elem:
            sentence_text = ' '.join(sentence).replace('<SOS>', '').replace(
                '<EOS>', '').replace('<pad>', '').strip()
            predicted_output.write(sentence_text + '\n')
        predicted_output.close()

        index += 1
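`test_epoch` writes three directory trees, the layout that ROUGE-style evaluation tooling typically compares (reconstructed from the writes above):

# output_dir/
#   ref/       per-document gold extractive sentences   (1.ref, 2.ref, ...)
#   ref_abs/   per-document human abstractive summary   (1.ref, 2.ref, ...)
#   dec/       per-document predicted sentences         (1.dec, 2.dec, ...)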
Code Example #14
def validate_epoch(model, device, post_batches, comment_batches,
                   answer_batches, max_sentences, max_length,
                   no_padding_sentences, no_padding_lengths,
                   posts_max_sentences, posts_max_length,
                   posts_no_padding_sentences, posts_no_padding_lengths,
                   criterion, use_bert):
    model.eval()
    val_loss = 0
    pbar = tqdm_notebook(enumerate(comment_batches))
    for index, batch in pbar:
        pbar.set_description("Validating {}/{}, loss={}".format(
            index, len(comment_batches), round(val_loss)))

        if use_bert is True:
            tensor_batch = HelpingFunctions.convert_to_tensor(
                batch, device, 'float')
            tensor_post_batch = HelpingFunctions.convert_to_tensor(
                post_batches[index], device, 'float')
        else:
            tensor_batch = HelpingFunctions.convert_to_tensor(batch, device)
            tensor_post_batch = HelpingFunctions.convert_to_tensor(
                post_batches[index], device)

        batch_max_length = (max_length[index]
                            if max_length is not None else None)
        batch_no_padding_lengths = (no_padding_lengths[index]
                                    if no_padding_lengths is not None else None)
        batch_posts_max_length = (posts_max_length[index]
                                  if posts_max_length is not None else None)
        batch_posts_no_padding_lengths = (
            posts_no_padding_lengths[index]
            if posts_no_padding_lengths is not None else None)

        sentence_probabilities = model(
            tensor_batch, max_sentences[index], batch_max_length,
            no_padding_sentences[index], batch_no_padding_lengths,
            tensor_post_batch, posts_max_sentences[index],
            batch_posts_max_length, posts_no_padding_sentences[index],
            batch_posts_no_padding_lengths)

        # Zero out probabilities at padded sentence positions, then pad the
        # gold labels to the same width before computing the loss.
        for i in range(len(sentence_probabilities)):
            for j in range(len(sentence_probabilities[i, :])):
                if j >= no_padding_sentences[index][i]:
                    sentence_probabilities[i, j] = 0 * sentence_probabilities[i, j]

        targets = answer_batches[index]
        for i, elem in enumerate(targets):
            while len(targets[i]) < max_sentences[index]:
                targets[i].append(0)
        loss = criterion(sentence_probabilities,
                         torch.FloatTensor(targets).to(device))
        val_loss += loss.item()
    val_loss = val_loss / len(answer_batches)
    return val_loss
Code Example #15
def train_epoch(model, device, post_batches, comment_batches, answer_batches,
                max_sentences, max_length, no_padding_sentences,
                no_padding_lengths, posts_max_sentences, posts_max_length,
                posts_no_padding_sentences, posts_no_padding_lengths,
                optimizer, criterion, use_bert):
    model.train()
    epoch_loss = 0
    pbar = tqdm_notebook(enumerate(comment_batches))
    for index, batch in pbar:
        pbar.set_description("Training {}/{}, loss={}".format(
            index, len(comment_batches), round(epoch_loss)))

        if use_bert is True:
            tensor_batch = HelpingFunctions.convert_to_tensor(
                batch, device, 'float')
            tensor_post_batch = HelpingFunctions.convert_to_tensor(
                post_batches[index], device, 'float')
        else:
            tensor_batch = HelpingFunctions.convert_to_tensor(batch, device)
            tensor_post_batch = HelpingFunctions.convert_to_tensor(
                post_batches[index], device)

        batch_max_length = (max_length[index]
                            if max_length is not None else None)
        batch_no_padding_lengths = (no_padding_lengths[index]
                                    if no_padding_lengths is not None else None)
        batch_posts_max_length = (posts_max_length[index]
                                  if posts_max_length is not None else None)
        batch_posts_no_padding_lengths = (
            posts_no_padding_lengths[index]
            if posts_no_padding_lengths is not None else None)

        sentence_probabilities = model(
            tensor_batch, max_sentences[index], batch_max_length,
            no_padding_sentences[index], batch_no_padding_lengths,
            tensor_post_batch, posts_max_sentences[index],
            batch_posts_max_length, posts_no_padding_sentences[index],
            batch_posts_no_padding_lengths)

        # Zero out probabilities at padded sentence positions, then pad the
        # gold labels to the same width before computing the loss.
        for i in range(len(sentence_probabilities)):
            for j in range(len(sentence_probabilities[i, :])):
                if j >= no_padding_sentences[index][i]:
                    sentence_probabilities[i, j] = 0 * sentence_probabilities[i, j]

        targets = answer_batches[index]
        for i, elem in enumerate(targets):
            while len(targets[i]) < max_sentences[index]:
                targets[i].append(0)

        loss = criterion(sentence_probabilities,
                         torch.FloatTensor(targets).to(device))
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), 2)
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss = epoch_loss / len(comment_batches)
    return epoch_loss
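A hypothetical outer loop wiring the two epoch helpers together. The sum-reduced BCELoss matches the criterion the commented-out line in train_batch suggests; every `val_*` name and the epoch count / learning rate are illustrative assumptions.

import torch

criterion = torch.nn.BCELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(10):
    tr_loss = train_epoch(model, device, post_batches, comment_batches,
                          answer_batches, max_sentences, max_length,
                          no_padding_sentences, no_padding_lengths,
                          posts_max_sentences, posts_max_length,
                          posts_no_padding_sentences, posts_no_padding_lengths,
                          optimizer, criterion, use_bert=False)
    val_loss = validate_epoch(model, device, val_post_batches,
                              val_comment_batches, val_answer_batches,
                              val_max_sentences, val_max_length,
                              val_no_padding_sentences, val_no_padding_lengths,
                              val_posts_max_sentences, val_posts_max_length,
                              val_posts_no_padding_sentences,
                              val_posts_no_padding_lengths, criterion,
                              use_bert=False)
    print('epoch {}: train {:.4f}  val {:.4f}'.format(epoch, tr_loss, val_loss))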
Code Example #16
def test_batch(model, device, post_batch, comment_batch, answer_batch,
               human_summary_batch, sentences_str_batch, max_sentences,
               max_length, no_padding_sentences, no_padding_lengths,
               posts_max_sentences, posts_max_length,
               posts_no_padding_sentences, posts_no_padding_lengths, use_bert):
    if use_bert is True:
        tensor_batch = HelpingFunctions.convert_to_tensor(
            comment_batch, device, 'float')
        tensor_post_batch = HelpingFunctions.convert_to_tensor(
            post_batch, device, 'float')
    else:
        tensor_batch = HelpingFunctions.convert_to_tensor(
            comment_batch, device)
        tensor_post_batch = HelpingFunctions.convert_to_tensor(
            post_batch, device)

    # Per-batch values pass through unchanged (None stays None).
    batch_max_length = max_length
    batch_no_padding_lengths = no_padding_lengths
    batch_posts_max_length = posts_max_length
    batch_posts_no_padding_lengths = posts_no_padding_lengths

    sentence_probabilities = model(tensor_batch, max_sentences,
                                   batch_max_length, no_padding_sentences,
                                   batch_no_padding_lengths, tensor_post_batch,
                                   posts_max_sentences, batch_posts_max_length,
                                   posts_no_padding_sentences,
                                   batch_posts_no_padding_lengths)

    for i in range(len(sentence_probabilities)):
        for j in range(len(sentence_probabilities[i, :])):
            if j >= no_padding_sentences[i]:
                sentence_probabilities[i, j] = 0 * sentence_probabilities[i, j]

    sentence_probabilities = sentence_probabilities.tolist()
    targets = answer_batch
    human_summaries = human_summary_batch

    target_words = []
    predicted_words = []
    for target, prediction, indices in zip(targets, sentence_probabilities,
                                           sentences_str_batch):
        # For each document in the batch, keep the gold-labelled sentences and
        # the sentences predicted above the 0.5 threshold.
        pre_sentences = []
        tar_sentences = []
        for i, val in enumerate(target):
            if val == 1:
                tar_sentences.append(indices[i])
            if prediction[i] > 0.5:
                pre_sentences.append(indices[i])

        target_words.append(tar_sentences)
        predicted_words.append(pre_sentences)

    target_sentences = []
    predicted_sentences = []
    for doc in target_words:
        doc_text = ''
        for sentence in doc:
            sentence_text = ' '.join(sentence).replace('<SOS>', '').replace(
                '<EOS>', '').replace('<pad>', '').strip()
            # Join with a newline to match the predicted-sentence loop below;
            # the original concatenated target sentences with no separator.
            doc_text += sentence_text + '\n'
        target_sentences.append(doc_text)

    for index, human_summary in enumerate(human_summaries):
        human_summaries[index] = human_summary.replace('<SOS>', '').replace(
            '<EOS>', '').replace('<pad>', '').strip()

    for doc in predicted_words:
        doc_text = ''
        for sentence in doc:
            sentence_text = ' '.join(sentence).replace('<SOS>', '').replace(
                '<EOS>', '').replace('<pad>', '').strip()
            doc_text += sentence_text + '\n'
        predicted_sentences.append(doc_text)

    return predicted_sentences, target_sentences, human_summaries
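A per-batch inference sketch using the same variable names as the signature above; the three returned lists are aligned by document index:

preds, refs, abs_refs = test_batch(model, device, post_batch, comment_batch,
                                   answer_batch, human_summary_batch,
                                   sentences_str_batch, max_sentences,
                                   max_length, no_padding_sentences,
                                   no_padding_lengths, posts_max_sentences,
                                   posts_max_length, posts_no_padding_sentences,
                                   posts_no_padding_lengths, use_bert)
# preds[d]:    newline-joined predicted sentences for document d
# refs[d]:     newline-joined gold extractive sentences for document d
# abs_refs[d]: cleaned human abstractive summary for document d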