Example #1
    def test(self):
        self.model.eval()
        batch_loss_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            input_sentences = [
                sent for conv in input_conversations for sent in conv]
            target_sentences = [
                sent for conv in target_conversations for sent in conv]
            input_sentence_length = [
                l for len_list in sentence_length for l in len_list[:-1]]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]]
            input_conversation_length = [l - 1 for l in conversation_length]

            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(
                    torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

            sentence_logits = self.model(
                input_sentences,
                input_sentence_length,
                input_conversation_length,
                target_sentences)

            batch_loss, n_words = masked_cross_entropy(
                sentence_logits,
                target_sentences,
                target_sentence_length)

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            n_total_words += n_words.item()

        epoch_loss = np.sum(batch_loss_history) / n_total_words

        print(f'Number of words: {n_total_words}')
        print(f'Bits per word: {epoch_loss:.3f}')
        word_perplexity = np.exp(epoch_loss)

        print_str = f'Word perplexity: {word_perplexity:.3f}\n'
        print(print_str)

        return word_perplexity
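Every example in this list relies on a masked_cross_entropy helper that is not shown. For the perplexity arithmetic to make sense (summed batch losses divided by the total word count, then exponentiated), it has to return the summed negative log-likelihood over non-padded target tokens together with the number of tokens it counted. A minimal sketch of such a helper, assuming logits shaped [num_sentences, max_len, vocab_size] and length-based masking; this is an illustration, not the repository's exact implementation:

import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, target, length):
    """Summed token NLL over non-padded positions, plus the token count.

    logits: [num_sentences, max_len, vocab_size]
    target: [num_sentences, max_len], padded
    length: [num_sentences] number of valid tokens per sentence
    """
    vocab_size = logits.size(-1)
    # Per-token negative log-likelihood, no reduction.
    nll = F.cross_entropy(logits.view(-1, vocab_size), target.view(-1),
                          reduction='none').view(target.size())
    # Zero out positions beyond each sentence's length.
    max_len = target.size(1)
    mask = (torch.arange(max_len, device=target.device).unsqueeze(0)
            < length.unsqueeze(1)).float()
    loss = (nll * mask).sum()
    n_words = mask.sum()
    return loss, n_words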
Example #2
    def evaluate(self):
        self.model.eval()
        batch_loss_history = []
        recon_loss_history = []
        kl_div_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths
            target_conversations = [conv[1:] for conv in conversations]

            utterances = [utter for conv in conversations for utter in conv]
            target_utterances = [
                utter for conv in target_conversations for utter in conv
            ]
            utterance_length = [
                l for len_list in sentence_length for l in len_list
            ]
            target_utterance_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            input_conversation_length = [
                conv_len - 1 for conv_len in conversation_length
            ]

            with torch.no_grad():
                utterances = to_var(torch.LongTensor(utterances))
                utterance_length = to_var(torch.LongTensor(utterance_length))
                target_utterances = to_var(torch.LongTensor(target_utterances))
                target_utterance_length = to_var(
                    torch.LongTensor(target_utterance_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

            sentence_logits, kl_div = self.model(utterances, utterance_length,
                                                 input_conversation_length,
                                                 target_utterances)

            recon_loss, n_words = masked_cross_entropy(
                sentence_logits, target_utterances, target_utterance_length)
            batch_loss = recon_loss + kl_div

            batch_loss_history.append(batch_loss.item())
            recon_loss_history.append(recon_loss.item())
            kl_div_history.append(kl_div.item())
            n_total_words += n_words.item()

        epoch_loss = np.sum(batch_loss_history) / n_total_words
        epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
        epoch_kl_div = np.sum(kl_div_history) / n_total_words

        print_str = f'Validation loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
        print(print_str)
        print('\n')

        return epoch_loss
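to_var is another helper these examples assume; it appears to simply wrap a tensor and move it to the GPU when one is available (some call sites later in this list also pass eval=True, presumably a leftover from the old Variable/volatile API). A minimal sketch under that assumption:

import torch

def to_var(x, eval=False):
    """Move a tensor to the GPU if available; eval is kept only for old call sites."""
    if torch.cuda.is_available():
        x = x.cuda()
    return x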
Example #3
    def test(self):
        self.model.eval()
        batch_loss_history = []
        n_total_words = 0
        for batch_i, (conversations, convs_length, utterances_length) in \
                enumerate(tqdm(self.eval_data_loader, ncols=80)):
            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            input_utterances = [
                utter for conv in input_conversations for utter in conv
            ]
            target_utterances = [
                utter for conv in target_conversations for utter in conv
            ]
            input_utterance_length = [
                l for len_list in utterances_length for l in len_list[:-1]
            ]
            target_utterance_length = [
                l for len_list in utterances_length for l in len_list[1:]
            ]
            input_conversation_length = [
                conv_len - 1 for conv_len in convs_length
            ]

            with torch.no_grad():
                input_utterances = to_var(torch.LongTensor(input_utterances))
                target_utterances = to_var(torch.LongTensor(target_utterances))
                input_utterance_length = to_var(
                    torch.LongTensor(input_utterance_length))
                target_utterance_length = to_var(
                    torch.LongTensor(target_utterance_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

            utterances_logits = self.model(input_utterances,
                                           input_utterance_length,
                                           input_conversation_length,
                                           target_utterances)

            batch_loss, n_words = masked_cross_entropy(
                utterances_logits, target_utterances, target_utterance_length)

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            n_total_words += n_words.item()

        epoch_loss = np.sum(batch_loss_history) / n_total_words

        print(f'Number of words: {n_total_words}')
        print(f'Bits per word: {epoch_loss:.3f}')
        word_perplexity = np.exp(epoch_loss)
        print(f'Word perplexity: {word_perplexity:.3f}\n')

        return word_perplexity
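All of these loops prepare their inputs the same way: a conversation of N utterances yields N-1 (context, response) pairs by dropping the last utterance on the input side and the first on the target side, and the per-conversation lists are then flattened across the batch. A small worked example of that reshaping, with made-up token ids:

# Two conversations of 3 and 2 utterances (each utterance a padded list of token ids).
conversations = [
    [[5, 6, 0], [7, 8, 9], [10, 0, 0]],
    [[11, 12, 0], [13, 14, 15]],
]
utterances_length = [[2, 3, 1], [2, 3]]
convs_length = [3, 2]

input_conversations = [conv[:-1] for conv in conversations]
target_conversations = [conv[1:] for conv in conversations]

input_utterances = [u for conv in input_conversations for u in conv]
target_utterances = [u for conv in target_conversations for u in conv]
input_utterance_length = [l for ls in utterances_length for l in ls[:-1]]
target_utterance_length = [l for ls in utterances_length for l in ls[1:]]
input_conversation_length = [l - 1 for l in convs_length]

print(input_utterances)           # [[5, 6, 0], [7, 8, 9], [11, 12, 0]]
print(target_utterances)          # [[7, 8, 9], [10, 0, 0], [13, 14, 15]]
print(input_conversation_length)  # [2, 1]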
Example #4
    def importance_sample(self):
        ''' Perform importance sampling to get a tighter bound.
        '''
        self.model.eval()
        weight_history = []
        n_total_words = 0
        kl_div_history = []
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            sentences = [sent for conv in conversations for sent in conv]
            input_conversation_length = [l - 1 for l in conversation_length]
            target_sentences = [
                sent for conv in target_conversations for sent in conv
            ]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            sentence_length = [
                l for len_list in sentence_length for l in len_list
            ]

            # n_words += sum([len([word for word in sent if word != PAD_ID]) for sent in target_sentences])
            with torch.no_grad():
                sentences = to_var(torch.LongTensor(sentences))
                sentence_length = to_var(torch.LongTensor(sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))

            # treat whole batch as one data sample
            weights = []
            for j in range(self.config.importance_sample):
                sentence_logits, kl_div, log_p_z, log_q_zx = self.model(
                    sentences, sentence_length, input_conversation_length,
                    target_sentences)

                recon_loss, n_words = masked_cross_entropy(
                    sentence_logits, target_sentences, target_sentence_length)

                log_w = (-recon_loss.sum() + log_p_z - log_q_zx).data
                weights.append(log_w)
                if j == 0:
                    n_total_words += n_words.item()
                    kl_div_history.append(kl_div.item())

            # weights: [n_samples]
            weights = torch.stack(weights, 0)
            m = np.floor(weights.max())
            weights = np.log(torch.exp(weights - m).type(get_type()).sum())
            weights = m + weights - np.log(self.config.importance_sample)
            weight_history.append(weights)

        print(f'Number of words: {n_total_words}')
        bits_per_word = -np.sum(weight_history) / n_total_words
        print(f'Bits per word: {bits_per_word:.3f}')
        word_perplexity = np.exp(bits_per_word)

        epoch_kl_div = np.sum(kl_div_history) / n_total_words

        print_str = f'Word perplexity upper bound using {self.config.importance_sample} importance samples: {word_perplexity:.3f}, kl_div: {epoch_kl_div:.3f}\n'
        print(print_str)

        return word_perplexity
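The block that stacks and averages the weights is a numerically stable mean of the importance weights in log space: log(mean_k exp(w_k)) is computed as m + log(sum_k exp(w_k - m)) - log K, where m is (the floor of) the largest weight and K is config.importance_sample. The same quantity can be written more compactly with torch.logsumexp; a small equivalent sketch:

import math
import torch

def log_mean_exp(log_weights):
    """Stable log of the mean of exp(log_weights), for a 1-D tensor of K weights."""
    k = log_weights.size(0)
    return torch.logsumexp(log_weights, dim=0) - math.log(k)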
Example #5
    def evaluate(self):
        self.model.eval()
        batch_loss_history = []
        recon_loss_history = []
        kl_div_history = []
        bow_loss_history = []
        bleu_history = []
        sequences_history = []
        levenshteins_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            sentences = [sent for conv in conversations for sent in conv]
            input_conversation_length = [l - 1 for l in conversation_length]
            target_sentences = [
                sent for conv in target_conversations for sent in conv
            ]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            sentence_length = [
                l for len_list in sentence_length for l in len_list
            ]

            with torch.no_grad():
                sentences = to_var(torch.LongTensor(sentences))
                sentence_length = to_var(torch.LongTensor(sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))

            if batch_i == 0:
                input_conversations = [conv[:-1] for conv in conversations]
                input_sentences = [
                    sent for conv in input_conversations for sent in conv
                ]
                with torch.no_grad():
                    input_sentences = to_var(torch.LongTensor(input_sentences))
                scores = self.generate_sentence(sentences, sentence_length,
                                                input_conversation_length,
                                                input_sentences,
                                                target_sentences)

                bleu_history += scores["bleus"]
                sequences_history += scores["sequences"]
                levenshteins_history += scores["levenshteins"]

            sentence_logits, kl_div, _, _ = self.model(
                sentences, sentence_length, input_conversation_length,
                target_sentences)

            recon_loss, n_words = masked_cross_entropy(sentence_logits,
                                                       target_sentences,
                                                       target_sentence_length)

            batch_loss = recon_loss + kl_div
            if self.config.bow:
                bow_loss = self.model.compute_bow_loss(target_conversations)
                bow_loss_history.append(bow_loss.item())

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            recon_loss_history.append(recon_loss.item())
            kl_div_history.append(kl_div.item())
            n_total_words += n_words.item()

        epoch_loss = np.sum(batch_loss_history) / n_total_words
        epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
        epoch_kl_div = np.sum(kl_div_history) / n_total_words

        print_str = f'Validation loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
        if bow_loss_history:
            epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
            print_str += f', bow_loss = {epoch_bow_loss:.3f}'
        print(print_str)
        print('\n')

        self.average_bleu = sum(bleu_history) / len(bleu_history)
        self.average_sequences = sum(sequences_history) / len(
            sequences_history)
        self.average_levenshteins = sum(levenshteins_history) / len(
            levenshteins_history)
        print("Average scores:")
        print(" -> bleu:", self.average_bleu)
        print(" -> sequence:", self.average_sequences)
        print(" -> levenshteins:", self.average_levenshteins)

        return epoch_loss
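generate_sentence is assumed here to return per-response scores, including scores["bleus"]; its implementation is not shown. One plausible way to produce such per-sentence BLEU scores uses NLTK's smoothed sentence-level BLEU (an assumption for illustration, not necessarily what this repository does):

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def sentence_bleu_score(reference_tokens, hypothesis_tokens):
    """Sentence-level BLEU between one reference and one hypothesis (lists of words)."""
    smoother = SmoothingFunction().method1  # avoids zero scores on short outputs
    return sentence_bleu([reference_tokens], hypothesis_tokens,
                         smoothing_function=smoother)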
Example #6
    def train(self):
        epoch_loss_history = []
        kl_mult = 0.0
        conv_kl_mult = 0.0
        for epoch_i in range(self.epoch_i, self.config.n_epoch):
            self.epoch_i = epoch_i
            batch_loss_history = []
            recon_loss_history = []
            kl_div_history = []
            kl_div_sent_history = []
            kl_div_conv_history = []
            bow_loss_history = []
            self.model.train()
            n_total_words = 0

            # self.evaluate()

            for batch_i, (conversations, conversation_length, sentence_length) \
                    in enumerate(tqdm(self.train_data_loader, ncols=80)):
                # conversations: (batch_size) list of conversations
                #   conversation: list of sentences
                #   sentence: list of tokens
                # conversation_length: list of int
                # sentence_length: (batch_size) list of conversation list of sentence_lengths

                target_conversations = [conv[1:] for conv in conversations]

                # flatten input and target conversations
                sentences = [sent for conv in conversations for sent in conv]
                input_conversation_length = [
                    l - 1 for l in conversation_length
                ]
                target_sentences = [
                    sent for conv in target_conversations for sent in conv
                ]
                target_sentence_length = [
                    l for len_list in sentence_length for l in len_list[1:]
                ]
                sentence_length = [
                    l for len_list in sentence_length for l in len_list
                ]

                sentences = to_var(torch.LongTensor(sentences))
                sentence_length = to_var(torch.LongTensor(sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))

                # reset gradient
                self.optimizer.zero_grad()

                sentence_logits, kl_div, _, _ = self.model(
                    sentences, sentence_length, input_conversation_length,
                    target_sentences)

                recon_loss, n_words = masked_cross_entropy(
                    sentence_logits, target_sentences, target_sentence_length)

                batch_loss = recon_loss + kl_mult * kl_div
                batch_loss_history.append(batch_loss.item())
                recon_loss_history.append(recon_loss.item())
                kl_div_history.append(kl_div.item())
                n_total_words += n_words.item()

                if self.config.bow:
                    bow_loss = self.model.compute_bow_loss(
                        target_conversations)
                    batch_loss += bow_loss
                    bow_loss_history.append(bow_loss.item())

                assert not isnan(batch_loss.item())

                if batch_i % self.config.print_every == 0:
                    print_str = f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item() / n_words.item():.3f}, recon = {recon_loss.item() / n_words.item():.3f}, kl_div = {kl_div.item() / n_words.item():.3f}'
                    if self.config.bow:
                        print_str += f', bow_loss = {bow_loss.item() / n_words.item():.3f}'
                    tqdm.write(print_str)

                # Back-propagation
                batch_loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)

                # Run optimizer
                self.optimizer.step()
                kl_mult = min(kl_mult + 1.0 / self.config.kl_annealing_iter,
                              1.0)

            epoch_loss = np.sum(batch_loss_history) / n_total_words
            epoch_loss_history.append(epoch_loss)

            epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
            epoch_kl_div = np.sum(kl_div_history) / n_total_words

            self.kl_mult = kl_mult
            self.epoch_loss = epoch_loss
            self.epoch_recon_loss = epoch_recon_loss
            self.epoch_kl_div = epoch_kl_div

            print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
            if bow_loss_history:
                self.epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
                print_str += f', bow_loss = {self.epoch_bow_loss:.3f}'
            print(print_str)

            if epoch_i % self.config.save_every_epoch == 0:
                self.save_model(epoch_i + 1)

            print('\n<Validation>...')
            self.validation_loss = self.evaluate()

            if epoch_i % self.config.plot_every_epoch == 0:
                self.write_summary(epoch_i)

        return epoch_loss_history
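The kl_mult update at the end of each iteration implements linear KL annealing: the KL term is weighted by min(1, step / kl_annealing_iter), so training starts close to a plain reconstruction objective and reaches the full ELBO after kl_annealing_iter updates. The schedule in isolation (the default value below is hypothetical):

def kl_weight(step, kl_annealing_iter=25000):
    """Linear KL annealing weight, clipped at 1.0."""
    return min(step / kl_annealing_iter, 1.0)

# At training step t the objective used above is:
#   batch_loss = recon_loss + kl_weight(t) * kl_div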
Example #7
    def train(self):
        epoch_loss_history = []
        for epoch_i in range(self.epoch_i, self.config.n_epoch):
            self.epoch_i = epoch_i
            batch_loss_history = []
            self.model.train()
            n_total_words = 0
            for batch_i, (conversations, conversation_length,
                          sentence_length) in enumerate(
                              tqdm(self.train_data_loader, ncols=80)):
                # conversations: (batch_size) list of conversations
                #   conversation: list of sentences
                #   sentence: list of tokens
                # conversation_length: list of int
                # sentence_length: (batch_size) list of conversation list of sentence_lengths

                input_conversations = [conv[:-1] for conv in conversations]
                target_conversations = [conv[1:] for conv in conversations]

                # flatten input and target conversations
                input_sentences = [
                    sent for conv in input_conversations for sent in conv
                ]
                target_sentences = [
                    sent for conv in target_conversations for sent in conv
                ]
                input_sentence_length = [
                    l for len_list in sentence_length for l in len_list[:-1]
                ]
                target_sentence_length = [
                    l for len_list in sentence_length for l in len_list[1:]
                ]
                input_conversation_length = [
                    l - 1 for l in conversation_length
                ]

                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(
                    torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

                # reset gradient
                self.optimizer.zero_grad()

                sentence_logits = self.model(input_sentences,
                                             input_sentence_length,
                                             input_conversation_length,
                                             target_sentences,
                                             decode=False)

                batch_loss, n_words = masked_cross_entropy(
                    sentence_logits, target_sentences, target_sentence_length)

                assert not isnan(batch_loss.item())
                batch_loss_history.append(batch_loss.item())
                n_total_words += n_words.item()

                if batch_i % self.config.print_every == 0:
                    tqdm.write(
                        f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item() / n_words.item():.3f}'
                    )

                # Back-propagation
                batch_loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)

                # Run optimizer
                self.optimizer.step()

            epoch_loss = np.sum(batch_loss_history) / n_total_words
            epoch_loss_history.append(epoch_loss)
            self.epoch_loss = epoch_loss

            print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}'
            print(print_str)

            if epoch_i % self.config.save_every_epoch == 0:
                self.save_model(epoch_i + 1)

            print('\n<Validation>...')
            self.validation_loss = self.evaluate()

            if epoch_i % self.config.plot_every_epoch == 0:
                self.write_summary(epoch_i)

        self.save_model(self.config.n_epoch)

        return epoch_loss_history
Example #8
    def test(self):
        self.model.eval()
        batch_loss_history = []
        n_total_words = 0
        n_sentences = 0
        f1_total = []
        for batch_i, (conversations, conversation_length,
                      sentence_length) in enumerate(
                          tqdm(self.test_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            input_sentences = [
                sent for conv in input_conversations for sent in conv
            ]
            target_sentences = [
                sent for conv in target_conversations for sent in conv
            ]
            input_sentence_length = [
                l for len_list in sentence_length for l in len_list[:-1]
            ]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            input_conversation_length = [l - 1 for l in conversation_length]

            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(
                    torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

            if batch_i == 0:
                self.generate_sentence(input_sentences, input_sentence_length,
                                       input_conversation_length,
                                       target_sentences)

            generated_sentences = self.generate_conversations_with_gold_responses(
                input_sentences, input_sentence_length,
                input_conversation_length, target_sentences)
            conv_f1 = 0
            for target_sent, output_sent in zip(target_sentences,
                                                generated_sentences):
                target_sent = self.vocab.decode(target_sent)
                output_sent = self.vocab.decode(output_sent)
                f1 = metrics.f1_score(output_sent, target_sent)
                conv_f1 += f1
            conv_f1 = conv_f1 / target_sentences.shape[0]

            sentence_logits = self.model(input_sentences,
                                         input_sentence_length,
                                         input_conversation_length,
                                         target_sentences)

            batch_loss, n_words = masked_cross_entropy(sentence_logits,
                                                       target_sentences,
                                                       target_sentence_length)

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            n_total_words += n_words.item()
            f1_total.append(conv_f1)
            n_sentences += target_sentences.shape[0]

        epoch_loss = np.sum(batch_loss_history) / n_total_words
        f1_average = np.sum(f1_total) / n_sentences

        print(f'Number of words: {n_total_words}')
        print(f'Bits per word: {epoch_loss:.3f}')
        word_perplexity = np.exp(epoch_loss)

        return word_perplexity, f1_average
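metrics.f1_score is applied to two decoded strings; its implementation is not shown. A common choice for this kind of dialogue evaluation is token-level F1 (the harmonic mean of unigram precision and recall), sketched below under that assumption:

from collections import Counter

def f1_score(prediction, ground_truth):
    """Token-overlap F1 between two whitespace-tokenized strings."""
    pred_tokens = prediction.split()
    gold_tokens = ground_truth.split()
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)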
Example #9
    def evaluate(self):
        self.model.eval()
        batch_freq_loss_history = []
        n_total_freq_words = 0
        batch_rare_loss_history = []
        n_total_rare_words = 0
        for batch_i, (conversations, conversation_length,
                      sentence_length) in enumerate(
                          tqdm(self.eval_freq_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            input_sentences = [
                sent for conv in input_conversations for sent in conv
            ]
            target_sentences = [
                sent for conv in target_conversations for sent in conv
            ]
            input_sentence_length = [
                l for len_list in sentence_length for l in len_list[:-1]
            ]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            input_conversation_length = [l - 1 for l in conversation_length]

            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(
                    torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))
            '''
            if batch_i == 0:
                self.generate_sentence(input_sentences,
                                       input_sentence_length,
                                       input_conversation_length,
                                       target_sentences)
            '''
            sentence_logits = self.model(input_sentences,
                                         input_sentence_length,
                                         input_conversation_length,
                                         target_sentences)

            batch_loss, n_words = masked_cross_entropy(sentence_logits,
                                                       target_sentences,
                                                       target_sentence_length)

            assert not isnan(batch_loss.item())
            batch_freq_loss_history.append(batch_loss.item())
            n_total_freq_words += n_words.item()

        epoch_freq_loss = np.sum(batch_freq_loss_history) / n_total_freq_words

        print_str = f'Validation freq loss: {epoch_freq_loss:.3f}\n'
        print(print_str)

        for batch_i, (conversations, conversation_length,
                      sentence_length) in enumerate(
                          tqdm(self.eval_rare_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            input_sentences = [
                sent for conv in input_conversations for sent in conv
            ]
            target_sentences = [
                sent for conv in target_conversations for sent in conv
            ]
            input_sentence_length = [
                l for len_list in sentence_length for l in len_list[:-1]
            ]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            input_conversation_length = [l - 1 for l in conversation_length]

            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(
                    torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))
            '''
            if batch_i == 0:
                self.generate_sentence(input_sentences,
                                       input_sentence_length,
                                       input_conversation_length,
                                       target_sentences)
            '''
            sentence_logits = self.model(input_sentences,
                                         input_sentence_length,
                                         input_conversation_length,
                                         target_sentences)

            batch_loss, n_words = masked_cross_entropy(sentence_logits,
                                                       target_sentences,
                                                       target_sentence_length)

            assert not isnan(batch_loss.item())
            batch_rare_loss_history.append(batch_loss.item())
            n_total_rare_words += n_words.item()

        epoch_rare_loss = np.sum(batch_rare_loss_history) / n_total_rare_words

        print_str = f'Validation rare loss: {epoch_rare_loss:.3f}\n'
        print(print_str)

        return epoch_freq_loss, epoch_rare_loss
Example #10
    def train(self):
        epoch_loss_history = list()
        min_validation_loss = sys.float_info.max
        patience_cnt = self.config.patience

        for epoch_i in range(self.epoch_i, self.config.n_epoch):
            self.epoch_i = epoch_i
            batch_loss_history = list()
            self.model.train()
            n_total_words = 0
            for batch_i, (conversations, convs_length, utterances_length) in \
                    enumerate(tqdm(self.train_data_loader, ncols=80)):
                # conversations: [batch_size, max_conv_len, max_utter_len] list of conversations
                #   A conversation: [max_conv_len, max_utter_len] list of utterances
                #   An utterance: [max_utter_len] list of tokens
                # convs_length: [batch_size] list of integers
                # utterances_length: [batch_size, max_conv_len] per-conversation lists of utterance lengths

                input_conversations = [conv[:-1] for conv in conversations]
                target_conversations = [conv[1:] for conv in conversations]

                input_utterances = [
                    utter for conv in input_conversations for utter in conv
                ]
                target_utterances = [
                    utter for conv in target_conversations for utter in conv
                ]
                input_utterance_length = [
                    l for len_list in utterances_length for l in len_list[:-1]
                ]
                target_utterance_length = [
                    l for len_list in utterances_length for l in len_list[1:]
                ]
                input_conversation_length = [
                    conv_len - 1 for conv_len in convs_length
                ]

                input_utterances = to_var(torch.LongTensor(input_utterances))
                target_utterances = to_var(torch.LongTensor(target_utterances))
                input_utterance_length = to_var(
                    torch.LongTensor(input_utterance_length))
                target_utterance_length = to_var(
                    torch.LongTensor(target_utterance_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

                self.optimizer.zero_grad()

                utterances_logits = self.model(input_utterances,
                                               input_utterance_length,
                                               input_conversation_length,
                                               target_utterances,
                                               decode=False)

                batch_loss, n_words = masked_cross_entropy(
                    utterances_logits, target_utterances,
                    target_utterance_length)

                assert not isnan(batch_loss.item())
                batch_loss_history.append(batch_loss.item())
                n_total_words += n_words.item()

                if batch_i % self.config.print_every == 0:
                    tqdm.write(
                        f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item() / n_words.item():.3f}'
                    )

                batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.optimizer.step()

            epoch_loss = np.sum(batch_loss_history) / n_total_words
            epoch_loss_history.append(epoch_loss)
            self.epoch_loss = epoch_loss

            print(f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}')

            if epoch_i % self.config.save_every_epoch == 0:
                self.save_model(epoch_i + 1)

            print('\n<Validation>...')
            self.validation_loss = self.evaluate()

            if epoch_i % self.config.plot_every_epoch == 0:
                self.write_summary(epoch_i)

            if min_validation_loss > self.validation_loss:
                min_validation_loss = self.validation_loss
            else:
                patience_cnt -= 1
                self.save_model(epoch_i)

            if patience_cnt < 0:
                print(f'\nEarly stop at {epoch_i}')
                self.save_model(epoch_i)
                return epoch_loss_history

        self.save_model(self.config.n_epoch)

        return epoch_loss_history
Example #11
    def evaluate(self):
        self.model.eval()
        batch_loss_history = []
        recon_loss_history = []
        kl_div_history = []
        bow_loss_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            sentences = [sent for conv in conversations for sent in conv]
            input_conversation_length = [l - 1 for l in conversation_length]
            target_sentences = [
                sent for conv in target_conversations for sent in conv
            ]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            sentence_length = [
                l for len_list in sentence_length for l in len_list
            ]

            sentences = to_var(torch.LongTensor(sentences), eval=True)
            sentence_length = to_var(torch.LongTensor(sentence_length),
                                     eval=True)
            input_conversation_length = to_var(
                torch.LongTensor(input_conversation_length), eval=True)
            target_sentences = to_var(torch.LongTensor(target_sentences),
                                      eval=True)
            target_sentence_length = to_var(
                torch.LongTensor(target_sentence_length), eval=True)

            if batch_i == 0:
                self.generate_sentence(sentences, sentence_length,
                                       input_conversation_length,
                                       target_sentences)

            sentence_logits, kl_div, _, _ = self.model(
                sentences, sentence_length, input_conversation_length,
                target_sentences)

            recon_loss, n_words = masked_cross_entropy(sentence_logits,
                                                       target_sentences,
                                                       target_sentence_length)

            batch_loss = recon_loss + kl_div
            if self.config.bow:
                bow_loss = self.model.compute_bow_loss(target_conversations)
                bow_loss_history.append(bow_loss.item())

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            recon_loss_history.append(recon_loss.item())
            kl_div_history.append(kl_div.item())
            n_total_words += n_words.item()

        epoch_loss = np.sum(batch_loss_history) / n_total_words
        epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
        epoch_kl_div = np.sum(kl_div_history) / n_total_words

        print_str = f'Validation loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
        if bow_loss_history:
            epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
            print_str += f', bow_loss = {epoch_bow_loss:.3f}'
        print(print_str)
        print('\n')

        return epoch_loss
Example #12
    def generate_file(self):
        batch_loss_history = []
        recon_loss_history = []
        kl_div_history = []
        bow_loss_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            sentences = [sent for conv in conversations for sent in conv]
            input_conversation_length = [l - 1 for l in conversation_length]
            target_sentences = [
                sent for conv in target_conversations for sent in conv
            ]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            sentence_length = [
                l for len_list in sentence_length for l in len_list
            ]

            with torch.no_grad():
                sentences = to_var(torch.LongTensor(sentences))
                sentence_length = to_var(torch.LongTensor(sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))

            input_conversations = [conv[:-1] for conv in conversations]
            input_sentences = [
                sent for conv in input_conversations for sent in conv
            ]
            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
            self.generate_sentence(sentences, sentence_length,
                                   input_conversation_length, input_sentences,
                                   target_sentences)

            sentence_logits, kl_div, _, _ = self.model(
                sentences, sentence_length, input_conversation_length,
                target_sentences)

            recon_loss, n_words = masked_cross_entropy(sentence_logits,
                                                       target_sentences,
                                                       target_sentence_length)

            batch_loss = recon_loss + kl_div
            if self.config.bow:
                bow_loss = self.model_eval.compute_bow_loss(
                    target_conversations)
                bow_loss_history.append(bow_loss.item())

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            recon_loss_history.append(recon_loss.item())
            kl_div_history.append(kl_div.item())
            n_total_words += n_words.item()
            # release cached GPU memory
            torch.cuda.empty_cache()

        epoch_loss = np.sum(batch_loss_history) / n_total_words
        epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
        epoch_kl_div = np.sum(kl_div_history) / n_total_words

        print_str = f'test loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
        with open(self.config.kl_log_dir, 'a+') as fout:
            fout.write(print_str)
            fout.write('\n')

        if bow_loss_history:
            epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
            print_str += f', bow_loss = {epoch_bow_loss:.3f}'
        print(print_str)
        print('\n')
        # self.model_eval = None
        torch.cuda.empty_cache()

        return epoch_loss
Example #13
    def evaluate(self):
        # optionally evaluate with a separate copy of the model (disabled block below)
        self.model.eval()
        """
        self.model_eval = getattr(models, self.config.model)(self.config)
        if torch.cuda.is_available():
            self.model_eval = self.model_eval.cuda(3)
        ckpt_path = os.path.join(self.config.save_path, f'{self.epoch_i+1}.pkl')
        self.model_eval.load_state_dict(torch.load(ckpt_path))
        self.model_eval.eval()
        """

        batch_loss_history = []
        recon_loss_history = []
        kl_div_history = []
        bow_loss_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            sentences = [sent for conv in conversations for sent in conv]
            input_conversation_length = [l - 1 for l in conversation_length]
            target_sentences = [
                sent for conv in target_conversations for sent in conv
            ]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            sentence_length = [
                l for len_list in sentence_length for l in len_list
            ]

            with torch.no_grad():
                sentences = to_var(torch.LongTensor(sentences))
                sentence_length = to_var(torch.LongTensor(sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))

            if batch_i == -1:
                input_conversations = [conv[:-1] for conv in conversations]
                input_sentences = [
                    sent for conv in input_conversations for sent in conv
                ]
                with torch.no_grad():
                    input_sentences = to_var(torch.LongTensor(input_sentences))
                self.generate_sentence(sentences, sentence_length,
                                       input_conversation_length,
                                       input_sentences, target_sentences)

            sentence_logits, kl_div, _, _ = self.model(
                sentences, sentence_length, input_conversation_length,
                target_sentences)

            recon_loss, n_words = masked_cross_entropy(sentence_logits,
                                                       target_sentences,
                                                       target_sentence_length)

            batch_loss = recon_loss + kl_div
            if self.config.bow:
                bow_loss = self.model.compute_bow_loss(target_conversations)
                bow_loss_history.append(bow_loss.item())

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            recon_loss_history.append(recon_loss.item())
            kl_div_history.append(kl_div.item())
            n_total_words += n_words.item()
            # release cached GPU memory
            torch.cuda.empty_cache()

        epoch_loss = np.sum(batch_loss_history) / n_total_words
        epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
        epoch_kl_div = np.sum(kl_div_history) / n_total_words

        print_str = f'Validation loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
        with open(self.config.kl_log_dir, 'a+') as fout:
            fout.write(print_str)
            fout.write('\n')

        if bow_loss_history:
            epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
            print_str += f', bow_loss = {epoch_bow_loss:.3f}'
        print(print_str)
        print('\n')
        # self.model_eval = None
        torch.cuda.empty_cache()

        return epoch_loss
Example #14
    def train(self):
        epoch_loss_history = list()
        kl_mult = 0.0
        min_validation_loss = sys.float_info.max
        patience_cnt = self.config.patience

        for epoch_i in range(self.epoch_i, self.config.n_epoch):
            self.epoch_i = epoch_i
            batch_loss_history = list()
            recon_loss_history = []
            kl_div_history = []
            bow_loss_history = []
            self.model.train()
            n_total_words = 0

            for batch_i, (conversations, conversation_length, sentence_length) \
                    in enumerate(tqdm(self.train_data_loader, ncols=80)):
                # conversations: (batch_size) list of conversations
                #   conversation: list of sentences
                #   sentence: list of tokens
                # conversation_length: list of int
                # sentence_length: (batch_size) list of conversation list of sentence_lengths

                target_conversations = [conv[1:] for conv in conversations]

                utterances = [
                    utter for conv in conversations for utter in conv
                ]
                target_utterances = [
                    utter for conv in target_conversations for utter in conv
                ]
                utterance_length = [
                    l for len_list in sentence_length for l in len_list
                ]
                target_utterance_length = [
                    l for len_list in sentence_length for l in len_list[1:]
                ]
                input_conversation_length = [
                    conv_len - 1 for conv_len in conversation_length
                ]

                utterances = to_var(torch.LongTensor(utterances))
                utterance_length = to_var(torch.LongTensor(utterance_length))
                target_utterances = to_var(torch.LongTensor(target_utterances))
                target_utterance_length = to_var(
                    torch.LongTensor(target_utterance_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

                self.optimizer.zero_grad()

                utterances_logits, kl_div = self.model(
                    utterances,
                    utterance_length,
                    input_conversation_length,
                    target_utterances,
                    decode=False)

                recon_loss, n_words = masked_cross_entropy(
                    utterances_logits, target_utterances,
                    target_utterance_length)

                batch_loss = recon_loss + kl_mult * kl_div
                batch_loss_history.append(batch_loss.item())
                recon_loss_history.append(recon_loss.item())
                kl_div_history.append(kl_div.item())
                n_total_words += n_words.item()

                if batch_i % self.config.print_every == 0:
                    print_str = f'Epoch: {epoch_i + 1}, iter {batch_i}: ' \
                                f'loss = {batch_loss.item() / n_words.item():.3f}, ' \
                                f'recon = {recon_loss.item() / n_words.item():.3f}, ' \
                                f'kl_div = {kl_div.item() / n_words.item():.3f}'
                    tqdm.write(print_str)

                batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.optimizer.step()
                kl_mult = min(kl_mult + 1.0 / self.config.kl_annealing_iter,
                              1.0)

            epoch_loss = np.sum(batch_loss_history) / n_total_words
            epoch_loss_history.append(epoch_loss)

            epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
            epoch_kl_div = np.sum(kl_div_history) / n_total_words

            self.kl_mult = kl_mult
            self.epoch_loss = epoch_loss
            self.epoch_recon_loss = epoch_recon_loss
            self.epoch_kl_div = epoch_kl_div

            print_str = f'Epoch {epoch_i + 1} loss average: {epoch_loss:.3f}, ' \
                        f'recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
            if bow_loss_history:
                self.epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
                print_str += f', bow_loss = {self.epoch_bow_loss:.3f}'
            print(print_str)

            if epoch_i % self.config.save_every_epoch == 0:
                self.save_model(epoch_i + 1)

            print('\n<Validation>...')
            self.validation_loss = self.evaluate()

            if epoch_i % self.config.plot_every_epoch == 0:
                self.write_summary(epoch_i)

            if min_validation_loss > self.validation_loss:
                min_validation_loss = self.validation_loss
            else:
                patience_cnt -= 1
                self.save_model(epoch_i)

            if patience_cnt < 0:
                print(f'\nEarly stop at {epoch_i}')
                self.save_model(epoch_i)
                return epoch_loss_history

        self.save_model(self.config.n_epoch)

        return epoch_loss_history
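The training loops above also call save_model and write_summary, which are not part of these excerpts. A minimal sketch of what such solver methods might look like, assuming checkpoints are saved as state dicts under config.save_path (as the disabled block in Example #13 suggests) and that scalar summaries go to TensorBoard; the attribute and tag names are assumptions:

    # assumes: import os, torch; from torch.utils.tensorboard import SummaryWriter

    def save_model(self, epoch):
        """Save the model parameters as <save_path>/<epoch>.pkl."""
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        torch.save(self.model.state_dict(), ckpt_path)

    def write_summary(self, epoch_i):
        """Log the per-epoch losses tracked above to TensorBoard."""
        if not hasattr(self, 'writer'):
            self.writer = SummaryWriter(log_dir=self.config.save_path)
        self.writer.add_scalar('train/loss', self.epoch_loss, epoch_i)
        self.writer.add_scalar('validation/loss', self.validation_loss, epoch_i)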