Code Example #1
File: model.py Project: draplater/hrg-parser
    def train(self, trees, data_train=None):
        print_logger = PrintLogger()
        pool = Pool(self.options.concurrent_count)
        # log progress every print_per sentences (smallest batch multiple above 100)
        print_per = (100 // self.options.batch_size + 1) * self.options.batch_size
        self.sent_embeddings.rnn.set_dropout(self.options.lstm_dropout)
        for sentence_idx, batch_idx, batch_trees in split_to_batches(
                trees, self.options.batch_size):
            if sentence_idx % print_per == 0 and sentence_idx != 0:
                print_logger.print(sentence_idx)
            sessions = [self.training_session(tree, print_logger, pool)
                        for tree in batch_trees]

            # run the batch in two sub-batches: forward one half's expressions,
            # then spawn that half's decoders before moving to the other half
            batch_size_2 = int(math.ceil(len(sessions) / 2) + 0.5)
            assert batch_size_2 * 2 >= len(sessions)
            for _, _, sub_sessions in split_to_batches(
                    sessions, batch_size_2):
                # stage1: generate all expressions and forward
                exprs = [i for session in sub_sessions for i in next(session)]
                if exprs:
                    dn.forward(exprs)
                # stage2: spawn all decoders
                for session in sub_sessions:
                    next(session)
            # stage3: collect and average all losses
            loss = dn.esum([next(session) for session in sessions]) / len(sessions)

            # update
            print_logger.total_loss += loss.scalar_value()
            loss.backward()
            self.optimizer.update()
            dn.renew_cg()
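
These snippets share a few helpers from the hrg-parser codebase: `dn` is DyNet (`import dynet as dn`), `PrintLogger` tracks a running loss total, and `Pool` / `ThreadPool` are presumably `multiprocessing`'s process and thread pools. The repository's `split_to_batches` is not shown on this page; the sketch below is a plausible reconstruction inferred purely from how the loops consume its `(sentence_idx, batch_idx, batch)` triples, not the actual implementation.

    def split_to_batches(items, batch_size):
        # Sketch only: yields (start_index, batch_index, batch) for consecutive
        # slices of `items`, matching how the training loops unpack it.
        batch, start, batch_idx = [], 0, 0
        for idx, item in enumerate(items):
            if not batch:
                start = idx
            batch.append(item)
            if len(batch) == batch_size:
                yield start, batch_idx, batch
                batch, batch_idx = [], batch_idx + 1
        if batch:  # final partial batch
            yield start, batch_idx, batch

Under this reading, `sentence_idx` is the index of the first sentence in each batch, which is why the progress check `sentence_idx % print_per == 0` fires roughly once every `print_per` sentences.
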
Code Example #2
    def train_gen(self, graphs, update=True, extra=None):
        """
        :type graphs: list[graph_utils.Graph]
        """
        self.logger = PrintLogger()
        self.network.sent_embedding.rnn.set_dropout(self.options.lstm_dropout)
        print_per = (100 // self.options.batch_size + 1) * self.options.batch_size

        if extra is not None:
            for sentence_idx, batch_idx, batch_sentences in split_to_batches(
                    extra, self.options.batch_size):
                if sentence_idx % print_per == 0 and sentence_idx != 0:
                    self.logger.print(sentence_idx)
                sessions = [self.training_session(sentence, self.logger, loose_var=self.options.loose)
                            for sentence in batch_sentences]
                # stage1: forward each session's sentence-level expression
                all_exprs = [next(i) for i in sessions]
                if all_exprs:
                    dn.forward(all_exprs)
                # stage2: forward the label-scoring expressions
                all_labels_exprs = [j for i in sessions for j in next(i)]
                if all_labels_exprs:
                    dn.forward(all_labels_exprs)
                # stage3: average the per-sentence losses
                loss = sum(next(i) for i in sessions) / len(sessions)
                self.logger.total_loss_value += loss.value()
                if update:
                    loss.backward()
                    self.trainer.update()
                    dn.renew_cg()
                    sessions.clear()

        for sentence_idx, batch_idx, batch_sentences in split_to_batches(
                graphs, self.options.batch_size):
            if sentence_idx % print_per == 0 and sentence_idx != 0:
                self.logger.print(sentence_idx)
            sessions = [self.training_session(sentence, self.logger)
                        for sentence in batch_sentences]
            all_exprs = [next(i) for i in sessions]
            if all_exprs:
                dn.forward(all_exprs)
            all_labels_exprs = [j for i in sessions for j in next(i)]
            if all_labels_exprs:
                dn.forward(all_labels_exprs)
            loss = sum(next(i) for i in sessions) / len(sessions)
            self.logger.total_loss_value += loss.value()
            if update:
                loss.backward()
                self.trainer.update()
                dn.renew_cg()
                sessions.clear()
            yield (loss if not update else None)
Code Example #3
    def train_gen(self, sentences, update=True):
        print_logger = PrintLogger()
        pool = Pool(self.options.concurrent_count)
        self.network.sent_embedding.rnn.set_dropout(self.options.lstm_dropout)
        print_per = (100 // self.options.batch_size + 1) * self.options.batch_size

        for sentence_idx, batch_idx, batch_sentences in split_to_batches(
                sentences, self.options.batch_size):
            if sentence_idx % print_per == 0 and sentence_idx != 0:
                print_logger.print(sentence_idx)
            sessions = [self.training_session(sentence, print_logger, pool)
                        for sentence in batch_sentences]
            all_exprs = [next(i) for i in sessions]
            if all_exprs:
                dn.forward(all_exprs)
            # spawn decoders
            for i in sessions:
                next(i)
            all_labels_exprs = [j for i in sessions for j in next(i)]
            if all_labels_exprs:
                dn.forward(all_labels_exprs)
            loss = sum(next(i) for i in sessions) / len(sessions)
            print_logger.total_loss_value += loss.value()
            if update:
                loss.backward()
                self.optimizer.update()
                dn.renew_cg()
            yield (loss if not update else None)
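
The `training_session` / `predict_session` generators are not shown on this page, but every loop above drives them the same way: each `next()` call advances a session one stage, and the caller gathers the yielded DyNet expressions from the whole batch into a single `dn.forward` per stage. The toy below imitates that four-stage protocol (expressions, decoder spawn, label expressions, loss) with plain Python values; all names and yielded values are illustrative, not the repository's API.

    # Toy driver for the four-stage session protocol used above.
    def toy_session(x):
        yield [x * 2]              # stage1: expressions for the first forward
        yield                      # stage2: side effect only (spawn a decoder)
        yield [x + 1, x + 2]       # stage3: label expressions, second forward
        yield float(3 * x + 3)     # stage4: the per-sentence loss

    sessions = [toy_session(x) for x in range(4)]
    stage1 = [e for s in sessions for e in next(s)]   # batched "forward" input
    for s in sessions:                                # spawn decoders
        next(s)
    stage3 = [e for s in sessions for e in next(s)]
    loss = sum(next(s) for s in sessions) / len(sessions)
    print(loss)  # 7.5
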
Code Example #4
File: leaftagger.py Project: draplater/hrg-parser
    def predict(self, trees):
        self.sent_embeddings.rnn.disable_dropout()
        for sentence_idx, batch_idx, batch_trees in split_to_batches(
                trees, self.options.batch_size):
            sessions = [self.predict_session(tree) for tree in batch_trees]
            exprs = [i for session in sessions for i in next(session)]
            dn.forward(exprs)
            for session in sessions:
                yield next(session)
            dn.renew_cg()
Code Example #5
    def predict(self, graphs):
        self.network.sent_embedding.rnn.disable_dropout()
        for sentence_idx, batch_idx, batch_sentences in split_to_batches(
                graphs, self.options.batch_size):
            sessions = [self.predict_session(sentence)
                        for sentence in batch_sentences]
            all_exprs = [next(i) for i in sessions]
            if all_exprs:
                dn.forward(all_exprs)
            all_labels_exprs = [j for i in sessions for j in next(i)]
            if all_labels_exprs:
                dn.forward(all_labels_exprs)
            for i in sessions:
                yield next(i)
            dn.renew_cg()
Code Example #6
    def train(self, sentences):
        self.span_ebd_network.rnn.set_dropout(self.options.lstm_dropout)
        self.span_ebd_network.init_special()
        pool = ThreadPool(self.options.concurrent_count)
        print_logger = PrintLogger()
        print_per = (100 // self.options.batch_size +
                     1) * self.options.batch_size

        for sentence_idx, batch_idx, batch_trees in split_to_batches(
                sentences, self.options.batch_size):
            if sentence_idx != 0 and sentence_idx % print_per == 0:
                print_logger.print(sentence_idx)

            self.span_ebd_network.init_special()
            sessions = [
                self.train_session(tree, print_logger, pool, decoder)
                for tree, decoder in zip(batch_trees, self.decoders)
            ]

            batch_size_2 = int(math.ceil(len(sessions) / 2) + 0.5)
            assert batch_size_2 * 2 >= len(sessions)
            for _, _, sub_sessions in split_to_batches(sessions, batch_size_2):
                # stage1: generate all expressions and forward
                expressions = [j for i in sub_sessions for j in next(i)]
                dn.forward(expressions)
                # stage2: spawn all decoders
                for session in sub_sessions:
                    next(session)
            # stage3: get all losses
            loss = dn.esum([next(session) for session in sessions])
            loss /= len(sessions)
            print_logger.total_loss += loss.value()
            loss.backward()
            self.optimizer.update()
            dn.renew_cg()
            self.span_ebd_network.init_special()
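
The `batch_size_2` line in examples #1 and #6 just rounds `len(sessions) / 2` up, so the two sub-batches together cover every session. In Python 3, `math.ceil` already returns an int, so the surrounding `int(... + 0.5)` changes nothing; a quick check:

    import math

    for n in range(1, 9):
        half = int(math.ceil(n / 2) + 0.5)
        assert half * 2 >= n               # the assert from the training loops
        assert half == math.ceil(n / 2)    # the int(... + 0.5) wrapper is a no-op
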
Code Example #7
File: hrg_parser.py Project: draplater/hrg-parser
    def predict(self, trees, return_derivation=False):
        print_logger = PrintLogger()
        for sent_idx, batch_idx, batch_trees in split_to_batches(
                trees, self.options.batch_size):
            sessions = [self.training_session(tree, print_logger)
                        for tree in batch_trees]
            exprs = [expr for session in sessions for expr in next(session)]
            dn.forward(exprs)
            for tree, session in zip(batch_trees, sessions):
                final_beam_item = next(session)
                graph = final_beam_item.sub_graph.graph
                if return_derivation:
                    yield tree.extra["ID"], graph, list(self.construct_derivation(final_beam_item))
                else:
                    yield tree.extra["ID"], graph
            dn.renew_cg()
Code Example #8
File: hrg_parser.py Project: draplater/hrg-parser
    def train(self, trees):
        print_logger = PrintLogger()
        print_per = (100 // self.options.batch_size + 1) * self.options.batch_size
        for sent_idx, batch_idx, batch_trees in split_to_batches(
                trees, self.options.batch_size):
            if sent_idx % print_per == 0 and sent_idx != 0:
                print_logger.print(sent_idx)
            sessions = [self.training_session(tree, print_logger,
                                              self.derivations[tree.extra["ID"]])
                        for tree in batch_trees]
            exprs = [expr for session in sessions for expr in next(session)]
            dn.forward(exprs)
            loss = sum(next(session) for session in sessions) / len(sessions)
            print_logger.total_loss += loss.value()
            loss.backward()
            self.optimizer.update()
            dn.renew_cg()
Code Example #9
    def predict(self, sentences):
        self.network.sent_embedding.rnn.disable_dropout()
        pool = Pool(self.options.concurrent_count)
        for sentence_idx, batch_idx, batch_sentences in split_to_batches(
                sentences, self.options.test_batch_size):
            sessions = [self.predict_session(sentence, pool)
                        for sentence in batch_sentences]
            all_exprs = [next(i) for i in sessions]
            if all_exprs:
                dn.forward(all_exprs)
            # spawn decoders
            for i in sessions:
                next(i)
            all_labels_exprs = [j for i in sessions for j in next(i)]
            if all_labels_exprs:
                dn.forward(all_labels_exprs)
            for i in sessions:
                yield next(i)
            dn.renew_cg()
Code Example #10
    def predict(self, trees):
        self.span_ebd_network.rnn.disable_dropout()
        pool = ThreadPool(self.options.concurrent_count)
        for sentence_idx, batch_idx, batch_trees in split_to_batches(
                trees, len(self.decoders)):
            self.span_ebd_network.init_special()
            sessions = [
                self.predict_session(tree, pool, decoder)
                for tree, decoder in zip(batch_trees, self.decoders)
            ]
            # stage1: generate all expressions and forward
            expressions = [j for i in sessions for j in next(i)]
            dn.forward(expressions)
            # stage2: spawn all decoders
            for session in sessions:
                next(session)
            # stage3: get all results
            for session in sessions:
                yield next(session)
            dn.renew_cg()
Code Example #11
File: leaftagger.py Project: draplater/hrg-parser
    def train(self, trees, data_train=None):
        print_logger = PrintLogger()
        print_per = (100 // self.options.batch_size +
                     1) * self.options.batch_size
        self.sent_embeddings.rnn.set_dropout(self.options.lstm_dropout)
        for sentence_idx, batch_idx, batch_trees in split_to_batches(
                trees, self.options.batch_size):
            if sentence_idx % print_per == 0 and sentence_idx != 0:
                print_logger.print(sentence_idx)
            sessions = [
                self.training_session(tree, print_logger)
                for tree in batch_trees
            ]
            exprs = [i for session in sessions for i in next(session)]
            dn.forward(exprs)
            loss = dn.esum([next(session)
                            for session in sessions]) / len(sessions)

            # update
            print_logger.total_loss += loss.scalar_value()
            loss.backward()
            self.optimizer.update()
            dn.renew_cg()