Code example #1
    def fit(
            self,
            train_interactions: interactions.MatchInteraction,
            verbose=True,  # print evaluation metrics during training
            val_interactions: interactions.MatchInteraction = None,
            test_interactions: interactions.MatchInteraction = None):
        """
        Fit the model.
        Parameters
        ----------
        train_iteractions: :class:`matchzoo.DataPack` The input sequence dataset.
        val_interactions: :class:`matchzoo.DataPack`
        test_interactions: :class:`matchzoo.DataPack`
        """
        self._initialize()
        best_ce, best_epoch = sys.maxsize, 0
        test_results_dict = None
        iteration_counter = 0
        count_patience_epochs = 0

        for epoch_num in range(self._n_iter):
            # Resample and reshuffle the training instances at the start of each epoch
            self._net.train(True)
            query_ids, left_contents, left_lengths, \
                doc_ids, right_contents, target_contents, right_lengths = \
                self._sampler.get_instances(train_interactions)

            queries, query_content, query_lengths, \
                docs, doc_content, target_contents, doc_lengths = \
                my_utils.shuffle(query_ids, left_contents, left_lengths,
                                 doc_ids, right_contents, target_contents,
                                 right_lengths)
            epoch_loss, total_pairs = 0.0, 0
            t1 = time.time()
            for (minibatch_num, (batch_query, batch_query_content, batch_query_len,
                 batch_doc, batch_doc_content, batch_doc_target, batch_docs_lens)) \
                    in enumerate(my_utils.minibatch(queries, query_content, query_lengths,
                                                    docs, doc_content, target_contents, doc_lengths,
                                                    batch_size=self._batch_size)):
                batch_query = my_utils.gpu(torch.from_numpy(batch_query),
                                           self._use_cuda)
                batch_query_content = my_utils.gpu(
                    torch.from_numpy(batch_query_content), self._use_cuda)
                # batch_query_len = my_utils.gpu(torch.from_numpy(batch_query_len), self._use_cuda)
                batch_doc = my_utils.gpu(torch.from_numpy(batch_doc),
                                         self._use_cuda)
                batch_doc_content = my_utils.gpu(
                    torch.from_numpy(batch_doc_content), self._use_cuda)
                batch_doc_target = my_utils.gpu(
                    torch.from_numpy(batch_doc_target), self._use_cuda)
                # batch_docs_lens = my_utils.gpu(torch.from_numpy(batch_docs_lens), self._use_cuda)

                total_pairs += batch_query.size(0)  # (batch_size)
                self._optimizer.zero_grad()
                loss = self._get_loss(batch_query, batch_query_content,
                                      batch_doc, batch_doc_content,
                                      batch_query_len, batch_docs_lens,
                                      batch_doc_target)
                epoch_loss += loss.item()
                iteration_counter += 1
                TensorboardWrapper.mywriter().add_scalar(
                    "loss/minibatch_loss", loss.item(), iteration_counter)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self._net.parameters(),
                                               self._clip)
                self._optimizer.step()
            epoch_loss /= float(total_pairs)
            TensorboardWrapper.mywriter().add_scalar("loss/epoch_loss_avg",
                                                     epoch_loss, epoch_num)
            t2 = time.time()
            epoch_train_time = t2 - t1
            if verbose:  # validation after each epoch
                t1 = time.time()
                result_val = self.evaluate(val_interactions)
                val_ce = result_val["cross_entropy"]
                t2 = time.time()
                validation_time = t2 - t1

                TensorboardWrapper.mywriter().add_scalar(
                    "cross_entropy/val_ce", val_ce, epoch_num)
                FileHandler.myprint(
                    '|Epoch %03d | Train time: %04.1f(s) | Train loss: %.3f'
                    '| Val loss = %.5f | Validation time: %04.1f(s)' %
                    (epoch_num, epoch_train_time, epoch_loss, val_ce,
                     validation_time))

                if val_ce < best_ce:
                    count_patience_epochs = 0
                    with open(self.saved_model, "wb") as f:
                        torch.save(self._net.state_dict(), f)
                    best_ce, best_epoch = val_ce, epoch_num
                else:
                    count_patience_epochs += 1
                if self._early_stopping_patience and \
                        count_patience_epochs > self._early_stopping_patience:
                    FileHandler.myprint(
                        "Early stopped: no improvement in %s epochs"
                        % count_patience_epochs)
                    break

            if np.isnan(epoch_loss) or epoch_loss == 0.0:
                raise ValueError(
                    'Degenerate epoch loss: {}'.format(epoch_loss))
        FileHandler.myprint("Closing tensorboard")
        TensorboardWrapper.mywriter().close()
        FileHandler.myprint(
            'Best result: | vad cross_entropy = %.5f | epoch = %d' %
            (best_ce, best_epoch))
        if test_results_dict is not None:
            FileHandler.myprint_details(
                json.dumps(test_results_dict, sort_keys=True, indent=2))
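
The fit loop above follows a common pattern: resample and shuffle the training instances each epoch, run minibatch SGD with gradient clipping, validate after every epoch, checkpoint whenever validation cross-entropy improves, and stop early once a patience window passes with no improvement. As a reference point, here is a minimal, self-contained sketch of that same pattern in plain PyTorch, with the project-specific helpers (my_utils, TensorboardWrapper, FileHandler) stripped out; every name in it is illustrative rather than taken from the code base above.

    import copy
    import torch

    def train_with_early_stopping(net, optimizer, loss_fn, train_loader,
                                  val_loader, n_epochs=100, patience=5,
                                  clip=1.0):
        best_val, best_state, bad_epochs = float("inf"), None, 0
        for epoch in range(n_epochs):
            net.train()
            for x, y in train_loader:  # shuffling is delegated to the loader
                optimizer.zero_grad()
                loss = loss_fn(net(x), y)
                loss.backward()
                # clip gradients, as the loop above does with self._clip
                torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
                optimizer.step()
            net.eval()
            with torch.no_grad():  # validate after every epoch
                val = sum(loss_fn(net(x), y).item() for x, y in val_loader)
            if val < best_val:  # checkpoint on the best validation loss
                best_val, bad_epochs = val, 0
                best_state = copy.deepcopy(net.state_dict())
            else:
                bad_epochs += 1
                if bad_epochs > patience:  # early stopping
                    break
        if best_state is not None:
            net.load_state_dict(best_state)  # restore the best checkpoint
        return best_val

Restoring the best state_dict at the end mirrors what the code above achieves by writing the checkpoint to self.saved_model whenever the validation loss improves.
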
Code example #2
    def fit(
            self,
            train_interactions: interactions.MatchInteraction,
            verbose=True,  # print evaluation metrics during training
            topN=10,
            val_interactions: interactions.MatchInteraction = None,
            test_interactions: interactions.MatchInteraction = None):
        """
        Fit the model.
        Parameters
        ----------
        train_iteractions: :class:`matchzoo.DataPack` The input sequence dataset.
        val_interactions: :class:`matchzoo.DataPack`
        test_interactions: :class:`matchzoo.DataPack`
        """
        self._initialize(train_interactions)
        best_hit, best_ndcg, best_epoch = 0, 0, 0
        test_results_dict = None
        iteration_counter = 0
        count_patience_epochs = 0

        for epoch_num in range(self._n_iter):

            # Resample negatives and reshuffle the training instances each epoch
            self._net.train(True)
            query_ids, left_contents, left_lengths, \
                doc_ids, right_contents, right_lengths, \
                neg_docs_ids, neg_docs_contents, neg_docs_lens = \
                self._sampler.get_train_instances(train_interactions,
                                                  self._num_negative_samples)

            queries, query_content, query_lengths, \
                docs, doc_content, doc_lengths, \
                neg_docs, neg_docs_contents, neg_docs_lens = \
                my_utils.shuffle(query_ids, left_contents, left_lengths,
                                 doc_ids, right_contents, right_lengths,
                                 neg_docs_ids, neg_docs_contents, neg_docs_lens)
            epoch_loss, total_pairs = 0.0, 0
            t1 = time.time()
            for (minibatch_num,
                (batch_query, batch_query_content, batch_query_len,
                 batch_doc, batch_doc_content, batch_docs_lens,
                 batch_neg_docs, batch_neg_doc_content, batch_neg_docs_lens)) \
                    in enumerate(my_utils.minibatch(queries, query_content, query_lengths,
                                                    docs, doc_content, doc_lengths,
                                                    neg_docs, neg_docs_contents, neg_docs_lens,
                                                    batch_size=self._batch_size)):
                # Per-term IDF weights for the query, if a TF-IDF table is available
                query_idfs = None
                query_idf_dict = TFIDF.get_term_idf()
                if len(query_idf_dict) != 0:
                    query_idfs = [[
                        query_idf_dict.get(int(word_idx), 0.0)
                        for word_idx in row
                    ] for row in batch_query_content]
                    query_idfs = torch_utils.gpu(
                        torch.from_numpy(np.array(query_idfs)).float(),
                        self._use_cuda)

                batch_query = my_utils.gpu(torch.from_numpy(batch_query),
                                           self._use_cuda)
                batch_query_content = my_utils.gpu(
                    torch.from_numpy(batch_query_content), self._use_cuda)
                batch_doc = my_utils.gpu(torch.from_numpy(batch_doc),
                                         self._use_cuda)
                batch_doc_content = my_utils.gpu(
                    torch.from_numpy(batch_doc_content), self._use_cuda)
                batch_neg_doc_content = my_utils.gpu(
                    torch.from_numpy(batch_neg_doc_content), self._use_cuda)
                # use the actual batch size: the last minibatch may be smaller
                total_pairs += batch_query.size(0) * self._num_negative_samples

                self._optimizer.zero_grad()
                if self._loss in ["bpr", "hinge", "pce", "bce"]:
                    loss = self._get_multiple_negative_predictions_normal(
                        batch_query,
                        batch_query_content,
                        batch_doc,
                        batch_doc_content,
                        batch_neg_docs,
                        batch_neg_doc_content,
                        batch_query_len,
                        batch_docs_lens,
                        batch_neg_docs_lens,
                        self._num_negative_samples,
                        query_idf=query_idfs)
                else:
                    raise ValueError("Unsupported loss: %s" % self._loss)
                epoch_loss += loss.item()
                iteration_counter += 1
                TensorboardWrapper.mywriter().add_scalar(
                    "loss/minibatch_loss", loss.item(), iteration_counter)
                loss.backward()
                self._optimizer.step()
            epoch_loss /= float(total_pairs)
            TensorboardWrapper.mywriter().add_scalar("loss/epoch_loss_avg",
                                                     epoch_loss, epoch_num)
            t2 = time.time()
            epoch_train_time = t2 - t1
            if verbose:  # validation after each epoch
                t1 = time.time()
                assert len(val_interactions.unique_queries_test) in \
                    KeyWordSettings.QueryCountVal, \
                    len(val_interactions.unique_queries_test)
                result_val = self.evaluate(val_interactions, topN)
                hits = result_val["hits"]
                ndcg = result_val["ndcg"]
                t2 = time.time()
                validation_time = t2 - t1

                if epoch_num and epoch_num % self._testing_epochs == 0:
                    t1 = time.time()
                    assert len(test_interactions.unique_queries_test) in \
                        KeyWordSettings.QueryCountTest
                    result_test = self.evaluate(test_interactions, topN)
                    hits_test = result_test["hits"]
                    ndcg_test = result_test["ndcg"]
                    t2 = time.time()
                    testing_time = t2 - t1
                    TensorboardWrapper.mywriter().add_scalar(
                        "hit/hit_test", hits_test, epoch_num)
                    TensorboardWrapper.mywriter().add_scalar(
                        "ndcg/ndcg_test", ndcg_test, epoch_num)
                    FileHandler.myprint(
                        '|Epoch %03d | Test hits@%d = %.5f | Test ndcg@%d = %.5f | Testing time: %04.1f(s)'
                        % (epoch_num, topN, hits_test, topN, ndcg_test,
                           testing_time))

                TensorboardWrapper.mywriter().add_scalar(
                    "hit/hits_val", hits, epoch_num)
                TensorboardWrapper.mywriter().add_scalar(
                    "ndcg/ndcg_val", ndcg, epoch_num)
                FileHandler.myprint(
                    '|Epoch %03d | Train time: %04.1f(s) | Train loss: %.3f'
                    '| Vad hits@%d = %.5f | Vad ndcg@%d = %.5f | Validation time: %04.1f(s)'
                    % (epoch_num, epoch_train_time, epoch_loss, topN, hits,
                       topN, ndcg, validation_time))

                if hits > best_hit or (hits == best_hit and ndcg > best_ndcg):
                    count_patience_epochs = 0
                    with open(self.saved_model, "wb") as f:
                        torch.save(self._net.state_dict(), f)
                    best_hit, best_ndcg, best_epoch = hits, ndcg, epoch_num
                else:
                    count_patience_epochs += 1
                if self._early_stopping_patience and \
                        count_patience_epochs > self._early_stopping_patience:
                    FileHandler.myprint(
                        "Early stopped: no improvement in %s epochs"
                        % count_patience_epochs)
                    break

            if np.isnan(epoch_loss) or epoch_loss == 0.0:
                raise ValueError(
                    'Degenerate epoch loss: {}'.format(epoch_loss))
        FileHandler.myprint("Closing tensorboard")
        TensorboardWrapper.mywriter().close()
        FileHandler.myprint(
            'Best result: | vad hits@%d = %.5f | vad ndcg@%d = %.5f | epoch = %d'
            % (topN, best_hit, topN, best_ndcg, best_epoch))
        if test_results_dict is not None:
            FileHandler.myprint_details(
                json.dumps(test_results_dict, sort_keys=True, indent=2))
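
Example #2 swaps the pointwise cross-entropy objective of example #1 for a ranking objective over self._num_negative_samples sampled negatives per positive pair, with self._loss selecting among "bpr", "hinge", "pce" and "bce". The helper _get_multiple_negative_predictions_normal is not shown in this snippet; purely as an illustration of the idea, a BPR- or hinge-style pairwise loss over k negatives could be computed as below, assuming the network produces one relevance score per (query, document) pair. The function is a hypothetical stand-in, not the project's implementation.

    import torch
    import torch.nn.functional as F

    def pairwise_loss(pos_scores, neg_scores, kind="bpr", margin=1.0):
        # pos_scores: (batch,) scores of the positive documents
        # neg_scores: (batch, k) scores of the k sampled negatives
        pos = pos_scores.unsqueeze(1).expand_as(neg_scores)  # (batch, k)
        if kind == "bpr":
            # BPR: maximize the log-sigmoid of the positive-negative margin
            return -F.logsigmoid(pos - neg_scores).mean()
        if kind == "hinge":
            # hinge: penalize negatives scoring within `margin` of the positive
            return torch.clamp(margin - (pos - neg_scores), min=0.0).mean()
        raise ValueError("Unsupported loss: %s" % kind)

Note that the loop above divides the accumulated epoch_loss by total_pairs (batch size times the number of negatives), which suggests the project's loss is summed rather than averaged per minibatch; the sketch returns a mean purely for readability.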