Example #1
    def __init__(self, options, train_graphs=None, restore_model_and_saveable=None, train_extra=None):
        random.seed(1)

        self.decoder = graph_decoders[options.decoder](options)
        self.test_decoder = graph_decoders[options.test_decoder](options) \
            if options.test_decoder is not None \
            else self.decoder
        self.cost_augment = cost_augmentors[options.cost_augment](options)

        self.options = options
        if "func" in options:
            del options.func

        self.labelsFlag = options.labelsFlag

        if restore_model_and_saveable:
            self.model, self.network = restore_model_and_saveable
        else:
            self.model = dn.Model()
            self.trainer = nn.trainers[options.trainer](self.model)
            if train_extra is not None:
                statistics = Statistics.from_sentences(train_graphs + train_extra)
            else:
                statistics = Statistics.from_sentences(train_graphs)
            self.network = EdgeEvaluationNetwork(self.model, statistics, options)
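
A note on the recurring `del options.func` idiom: `options` is an argparse.Namespace, and `func` is typically the sub-command callback attached by the driver script, so it is dropped to keep the namespace picklable when the model is saved. A minimal sketch (the `func` value here is illustrative):

from argparse import Namespace

options = Namespace(decoder="arcfactor", func=lambda args: None)  # func: an unpicklable callback
if "func" in options:  # Namespace supports membership tests
    del options.func
print(options)  # Namespace(decoder='arcfactor')
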
Example #2
    def __init__(self, model, statistics, options):
        super(EdgeDoubleEvaluationNetwork, self).__init__(model)
        self.share_embedding = options.share_embedding
        self.share_lstm = options.share_lstm
        self.graph_network = EdgeEvaluationNetwork(model, statistics[0],
                                                   options)
        self.tree_network = EdgeEvaluationNetwork(model, statistics[1],
                                                  options)
        self.set_share_parameter(model)
Example #3
    def add_parser_arguments(cls, arg_parser):
        super(MaxSubTreeParser, cls).add_parser_arguments(arg_parser)
        group = arg_parser.add_argument_group(cls.__name__)
        group.add_argument("--optimizer", type=str, dest="optimizer", default="adam", choices=nn.trainers.keys())
        group.add_argument("--decoder", type=str, dest="decoder", default="eisner", choices=decoders)
        group.add_argument("--cost-augment", action="store_true", dest="cost_augment", default=True)
        group.add_argument("--batch-size", type=int, dest="batch_size", default=cls.default_batch_size)
        group.add_argument("--model-format", dest="model_format", choices=nn.model_formats, default="pickle")
        EdgeEvaluationNetwork.add_parser_arguments(arg_parser)
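
These `add_parser_arguments` classmethods compose down the class hierarchy: call `super()` first, register the class's own argument group, then delegate to the network class. A self-contained sketch of the same pattern, using illustrative stand-in classes rather than the project's real ones:

import argparse

class BaseParser(object):
    @classmethod
    def add_parser_arguments(cls, arg_parser):
        group = arg_parser.add_argument_group(cls.__name__ + " (base)")
        group.add_argument("--epochs", type=int, dest="epochs", default=10)

class DemoTreeParser(BaseParser):
    default_batch_size = 2

    @classmethod
    def add_parser_arguments(cls, arg_parser):
        super(DemoTreeParser, cls).add_parser_arguments(arg_parser)
        group = arg_parser.add_argument_group(cls.__name__)
        group.add_argument("--decoder", type=str, dest="decoder", default="eisner")
        group.add_argument("--batch-size", type=int, dest="batch_size",
                           default=cls.default_batch_size)

arg_parser = argparse.ArgumentParser()
DemoTreeParser.add_parser_arguments(arg_parser)
options = arg_parser.parse_args(["--batch-size", "8"])
print(options.epochs, options.decoder, options.batch_size)  # 10 eisner 8
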
Example #4
    def __init__(self, options, train_sentences=None):
        self.model = dn.Model()
        self.statistics = Statistics.from_sentences(train_sentences)
        self.container = nn.Container(self.model)
        self.network = EdgeEvaluationNetwork(self.container, self.statistics, options)
        self.optimizer = nn.get_optimizer(self.model, options)
        self.decoder = decoders[options.decoder]
        self.label_decoder = label_decoders[options.label_decoder]
        self.labelsFlag = options.labelsFlag
        self.options = options
        if "func" in options:
            del options.func
Example #5
    def add_parser_arguments(cls, arg_parser):
        """:type arg_parser: argparse.ArgumentParser"""
        EdgeEvaluationNetwork.add_parser_arguments(arg_parser)
        group = arg_parser.add_argument_group(cls.__name__)
        group.add_argument("--not-share-embedding",
                           action="store_false",
                           dest="share_embedding",
                           default=True)
        group.add_argument("--not-share-lstm",
                           action="store_false",
                           dest="share_lstm",
                           default=True)
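
The `--not-share-*` flags use action="store_false": `share_embedding` and `share_lstm` default to True, and passing the flag switches sharing off. A quick demonstration:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--not-share-embedding", action="store_false",
               dest="share_embedding", default=True)

print(p.parse_args([]).share_embedding)                         # True
print(p.parse_args(["--not-share-embedding"]).share_embedding)  # False
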
Example #6
    def __init__(self,
                 options,
                 train_sentences=None,
                 restore_file=None,
                 statistics=None):
        self.model = dn.Model()

        random.seed(1)
        self.trainer = dn.AdamTrainer(self.model)

        self.activation = activations[options.activation]
        # self.decoder = decoders[options.decoder]

        self.labelsFlag = options.labelsFlag
        self.costaugFlag = options.cost_augment
        self.options = options

        if "func" in options:
            del options.func

        if restore_file:
            self.container, = dn.load(restore_file, self.model)
            networks = list(self.container.components)
            self.network = networks.pop(0)
            self.statistics = statistics
            self.has_emptys = len(statistics.emptys) > 0
            if self.has_emptys:
                self.network_for_emptys = networks.pop(0)
            if self.options.use_2nd:
                self.network3 = networks.pop(0)
                if self.has_emptys:
                    self.network3_for_emptys_mid = networks.pop(0)
                    self.network3_for_emptys_out = networks.pop(0)
            assert not networks
        else:
            self.container = nn.Container(self.model)
            self.statistics = statistics = StatisticsWithEmpty.from_sentences(
                train_sentences)
            self.has_emptys = len(statistics.emptys) > 0
            self.network = EdgeEvaluationNetwork(self.container, statistics,
                                                 options)
            if self.has_emptys:
                self.network_for_emptys = EdgeEvaluation(
                    self.container, options)
            if options.use_2nd:
                self.network3 = EdgeSiblingEvaluation(self.container, options)
                if self.has_emptys:
                    self.network3_for_emptys_mid = EdgeSiblingEvaluation(
                        self.container, options)
                    self.network3_for_emptys_out = EdgeSiblingEvaluation(
                        self.container, options)
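
When this constructor restores from a file, it pops components off the container in exactly the order the save path appended them, and `assert not networks` guards against a mismatch between the saved layout and the current options. The order-sensitive pattern in isolation (plain strings stand in for the network objects):

networks = ["edge", "edge_empty", "sib", "sib_mid", "sib_out"]  # as saved
network = networks.pop(0)
has_emptys, use_2nd = True, True
if has_emptys:
    network_for_emptys = networks.pop(0)
if use_2nd:
    network3 = networks.pop(0)
    if has_emptys:
        network3_for_emptys_mid = networks.pop(0)
        network3_for_emptys_out = networks.pop(0)
assert not networks  # every saved component must be consumed
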
Example #7
    def add_parser_arguments(cls, arg_parser):
        super(MaxSubTreeParser, cls).add_parser_arguments(arg_parser)

        group = arg_parser.add_argument_group(cls.__name__)
        group.add_argument("--decoder",
                           type=str,
                           dest="decoder",
                           default="eisner2nd",
                           choices=["eisner2nd"])
        group.add_argument("--cost-augment",
                           action="store_true",
                           dest="cost_augment",
                           default=True)

        EdgeEvaluationNetwork.add_parser_arguments(arg_parser)
Example #8
    def add_parser_arguments(cls, arg_parser):
        """:type arg_parser: argparse.ArgumentParser"""
        super(MaxSubGraphParser, cls).add_parser_arguments(arg_parser)

        group = arg_parser.add_argument_group(cls.__name__)
        group.add_argument("--trainer", type=str, dest="trainer", default="adam", choices=nn.trainers.keys())
        group.add_argument("--cost-augment", type=str, dest="cost_augment", default="hamming", choices=cost_augmentors)
        group.add_argument("--decoder", type=str, dest="decoder", default="arcfactor", choices=graph_decoders.keys())
        group.add_argument("--predict-decoder", type=str, dest="test_decoder", default=None)
        group.add_argument("--hamming-a", type=float, dest="hamming_a", default=0.4)
        group.add_argument("--hamming-b", type=float, dest="hamming_b", default=0.6)
        group.add_argument("--vine-arc-length", type=int, dest="vine_arc_length", default=20)
        group.add_argument("--basic-costaug-decrease", type=int, dest="basic_costaug_decrease", default=1)
        group.add_argument("--loose-value", type=float, dest="loose", default=-1)
        group.add_argument("--delta", type=float, dest="delta", default=1)

        EdgeEvaluationNetwork.add_parser_arguments(arg_parser)
Example #9
    def __init__(self, options, train_sentences=None, restore_file=None):
        self.model = dn.Model()
        if restore_file:
            old_options, self.network = nn.model_load_helper(None, restore_file, self.model)
            if options is not None:
                old_options.__dict__.update(options.__dict__)
                options = old_options
        else:
            statistics = Statistics.from_sentences(train_sentences)
            self.network = EdgeEvaluationNetwork(self.model, statistics, options)

        self.optimizer = nn.get_optimizer(self.model, options)
        self.decoder = decoders[options.decoder]
        self.labelsFlag = options.labelsFlag
        self.options = options

        if "func" in options:
            del options.func
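
On restore, the options saved with the model serve as the baseline and anything passed on the current command line overrides them via `__dict__.update`. The merge in isolation (values are illustrative):

from argparse import Namespace

old_options = Namespace(decoder="eisner", batch_size=32, labelsFlag=True)  # from the model file
options = Namespace(batch_size=8)  # from the current command line
old_options.__dict__.update(options.__dict__)
options = old_options
print(options)  # Namespace(batch_size=8, decoder='eisner', labelsFlag=True)
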
Example #10
    def add_parser_arguments(cls, arg_parser):
        super(MaxSubTreeWithEmptyParser, cls).add_parser_arguments(arg_parser)

        group = arg_parser.add_argument_group(cls.__name__)
        group.add_argument("--optimizer",
                           type=str,
                           default="adam",
                           choices=nn.trainers.keys())
        group.add_argument("--decoder",
                           type=str,
                           dest="decoder",
                           default="eisner2nd",
                           choices=["eisner2nd"])
        group.add_argument("--disable-cost-augment",
                           action="store_false",
                           dest="cost_augment",
                           default=True)
        group.add_argument("--enable-2nd",
                           action="store_true",
                           dest="use_2nd",
                           default=False)

        EdgeEvaluationNetwork.add_parser_arguments(arg_parser)
Example #11
    def __init__(self,
                 options,
                 train_sentences=None,
                 restore_file=None,
                 statistics=None):
        self.model = dn.Model()

        random.seed(1)
        self.optimizer = nn.get_optimizer(self.model, options)

        self.activation = activations[options.activation]
        # self.decoder = decoders[options.decoder]

        self.labelsFlag = options.labelsFlag
        self.costaugFlag = options.cost_augment
        self.options = options

        if "func" in options:
            del options.func

        self.container = nn.Container(self.model)
        self.statistics = statistics = StatisticsWithEmpty.from_sentences(
            train_sentences)
        self.has_emptys = len(statistics.emptys) > 0
        self.network = EdgeEvaluationNetwork(self.container, statistics,
                                             options)
        if self.has_emptys:
            self.network_for_emptys = EdgeEvaluation(self.container, options)
            self.label_network_for_emptys = ToEmptyLabelEvaluation(
                self.container, self.statistics.emptys, options)
        if options.use_2nd:
            self.network3 = EdgeSiblingEvaluation(self.container, options)
            if self.has_emptys:
                self.network3_for_emptys_mid = EdgeSiblingEvaluation(
                    self.container, options)
                self.network3_for_emptys_out = EdgeSiblingEvaluation(
                    self.container, options)
Example #12
    def __init__(self, options, train_sentences=None, restore_file=None):
        self.model = dn.Model()
        random.seed(1)
        self.trainer = dn.AdamTrainer(self.model)

        self.activation = activations[options.activation]
        self.decoder = decoders[options.decoder]

        self.labelsFlag = options.labelsFlag
        self.costaugFlag = options.cost_augment
        self.options = options

        if "func" in options:
            del options.func

        if restore_file:
            self.container, = dn.load(restore_file, self.model)
            self.network, self.network3 = self.container.components
        else:
            self.container = nn.Container(self.model)
            statistics = Statistics.from_sentences(train_sentences)
            self.network = EdgeEvaluationNetwork(self.container, statistics,
                                                 options)
            self.network3 = EdgeSiblingEvaluation(self.container, options)
Example #13
class MaxSubGraphParser(GraphParserBase):
    @classmethod
    def add_parser_arguments(cls, arg_parser):
        """:type arg_parser: argparse.ArgumentParser"""
        super(MaxSubGraphParser, cls).add_parser_arguments(arg_parser)

        group = arg_parser.add_argument_group(cls.__name__)
        group.add_argument("--trainer", type=str, dest="trainer", default="adam", choices=nn.trainers.keys())
        group.add_argument("--cost-augment", type=str, dest="cost_augment", default="hamming", choices=cost_augmentors)
        group.add_argument("--decoder", type=str, dest="decoder", default="arcfactor", choices=graph_decoders.keys())
        group.add_argument("--predict-decoder", type=str, dest="test_decoder", default=None)
        group.add_argument("--hamming-a", type=float, dest="hamming_a", default=0.4)
        group.add_argument("--hamming-b", type=float, dest="hamming_b", default=0.6)
        group.add_argument("--vine-arc-length", type=int, dest="vine_arc_length", default=20)
        group.add_argument("--basic-costaug-decrease", type=int, dest="basic_costaug_decrease", default=1)
        group.add_argument("--loose-value", type=float, dest="loose", default=-1)
        group.add_argument("--delta", type=float, dest="delta", default=1)

        EdgeEvaluationNetwork.add_parser_arguments(arg_parser)

    @classmethod
    def add_common_arguments(cls, arg_parser):
        super(MaxSubGraphParser, cls).add_common_arguments(arg_parser)
        group = arg_parser.add_argument_group(cls.__name__ + " (common)")
        group.add_argument("--batch-size", type=int, dest="batch_size", default=32)

    def __init__(self, options, train_graphs=None, restore_model_and_saveable=None, train_extra=None):
        random.seed(1)

        self.decoder = graph_decoders[options.decoder](options)
        self.test_decoder = graph_decoders[options.test_decoder](options) \
            if options.test_decoder is not None \
            else self.decoder
        self.cost_augment = cost_augmentors[options.cost_augment](options)

        self.options = options
        if "func" in options:
            del options.func

        self.labelsFlag = options.labelsFlag

        if restore_model_and_saveable:
            self.model, self.network = restore_model_and_saveable
        else:
            self.model = dn.Model()
            self.trainer = nn.trainers[options.trainer](self.model)
            if train_extra is not None:
                statistics = Statistics.from_sentences(train_graphs + train_extra)
            else:
                statistics = Statistics.from_sentences(train_graphs)
            self.network = EdgeEvaluationNetwork(self.model, statistics, options)

    def predict_session(self, sentence):
        """
        step 1: yield all edge expressions
        step 2: yield all label expressions
        step 3: yield result graph
        """
        lstm_output = self.network.get_lstm_output(sentence)
        length = len(sentence)
        raw_exprs = self.network.edge_eval.get_complete_raw_exprs(lstm_output)
        yield raw_exprs

        scores = self.network.edge_eval.raw_exprs_to_scores(raw_exprs, length)
        output_graph = self.test_decoder(scores)

        edges = []
        for source_id in range(len(sentence)):
            for target_id in range(len(sentence)):
                if target_id == 0:  # skip edges pointing to the root
                    continue
                if output_graph[source_id][target_id]:
                    edges.append(graph_utils.Edge(source_id, "X", target_id))

        if self.labelsFlag:
            labeled_edges = []
            labels_exprs = list(self.network.label_eval.get_label_scores(lstm_output, edges))
            yield labels_exprs
            for edge, r_scores_expr in zip(edges, labels_exprs):
                r_scores = r_scores_expr.value()
                label_index = max(((l, scr) for l, scr in enumerate(r_scores)), key=itemgetter(1))[0]
                label = self.network.irels[label_index]
                labeled_edges.append(graph_utils.Edge(edge.source, label, edge.target))
            edges = labeled_edges
        else:
            yield []

        result = sentence.replaced_edges(edges)
        yield result

    def predict(self, graphs):
        self.network.sent_embedding.rnn.disable_dropout()
        for sentence_idx, batch_idx, batch_sentences in split_to_batches(
                graphs, self.options.batch_size):
            sessions = [self.predict_session(sentence)
                        for sentence in batch_sentences]
            all_exprs = [next(i) for i in sessions]
            if all_exprs:
                dn.forward(all_exprs)
            all_labels_exprs = [j for i in sessions for j in next(i)]
            if all_labels_exprs:
                dn.forward(all_labels_exprs)
            for i in sessions:
                yield next(i)
            dn.renew_cg()


    def training_session(self, sentence, print_logger, loose_var=-1):
        """
        step 1: yield all edge expressions
        step 2: yield all label expressions
        step 3: yield loss
        """
        lstm_output = self.network.get_lstm_output(sentence)
        length = len(sentence)
        raw_exprs = self.network.edge_eval.get_complete_raw_exprs(lstm_output)
        yield raw_exprs

        scores = self.network.edge_eval.raw_exprs_to_scores(raw_exprs, length)
        exprs = self.network.edge_eval.raw_exprs_to_exprs(raw_exprs, length)

        self.cost_augment(scores, sentence)

        output_graph = self.decoder(scores)
        gold_graph = sentence.to_matrix()

        label_loss = dn.scalarInput(0.0)
        if self.labelsFlag:
            edges = list(sentence.generate_edges())
            labels_exprs = list(self.network.label_eval.get_label_scores(lstm_output, edges))
            yield labels_exprs
            for edge, r_scores_expr \
                    in zip(edges, labels_exprs):
                head, label, modifier = edge
                r_scores = r_scores_expr.value()
                gold_label_index = self.network.rels[label]
                wrong_label_index = max(((l, scr)
                                         for l, scr in enumerate(r_scores)
                                         if l != gold_label_index), key=itemgetter(1))[0]
                # if loose_var is set, loosen the label update by using the configured delta margin
                delta = self.options.delta if loose_var > 0 else 1
                if r_scores[gold_label_index] < r_scores[wrong_label_index] + delta:
                    label_loss += r_scores_expr[wrong_label_index] - r_scores_expr[gold_label_index] + 1
        else:
            yield []

        edge_loss = dn.scalarInput(0.0)
        for source_id in range(len(sentence)):
            for target_id in range(len(sentence)):
                gold_exist = gold_graph[source_id][target_id]
                output_exist = output_graph[source_id][target_id]
                if gold_exist and output_exist:
                    print_logger.total_gold_edge += 1
                    print_logger.total_predict_edge += 1
                    print_logger.correct_predict_edge += 1
                    print_logger.recalled_gold_edge += 1
                elif not gold_exist and not output_exist:
                    pass
                elif gold_exist and not output_exist:
                    print_logger.total_gold_edge += 1
                    if loose_var > 0 and scores[source_id][target_id] > -loose_var: #-0.1:
                        pass
                    else:
                        edge_loss -= exprs[source_id][target_id]
                elif not gold_exist and output_exist:
                    print_logger.total_predict_edge += 1
                    if loose_var > 0 and scores[source_id][target_id] < loose_var: #0.05:
                        pass
                    else:
                        edge_loss += exprs[source_id][target_id]
                else:
                    raise SystemError()

        loss_shift = self.cost_augment.get_loss_shift(output_graph, gold_graph)
        loss = label_loss + edge_loss + loss_shift
        yield loss

    def train_gen(self, graphs, update=True, extra=None):
        """
        :type graphs: list[graph_utils.Graph]
        """
        self.logger = PrintLogger()
        self.network.sent_embedding.rnn.set_dropout(self.options.lstm_dropout)
        print_per = (100 // self.options.batch_size + 1) * self.options.batch_size

        if extra is not None:
            for sentence_idx, batch_idx, batch_sentences in split_to_batches(
                    extra, self.options.batch_size):
                if sentence_idx % print_per == 0 and sentence_idx != 0:
                    self.logger.print(sentence_idx)
                sessions = [self.training_session(sentence, self.logger, loose_var=self.options.loose)
                            for sentence in batch_sentences]
                all_exprs = [next(i) for i in sessions]
                if all_exprs:
                    dn.forward(all_exprs)
                all_labels_exprs = [j for i in sessions for j in next(i)]
                if all_labels_exprs:
                    dn.forward(all_labels_exprs)
                loss = sum(next(i) for i in sessions) / len(sessions)
                self.logger.total_loss_value += loss.value()
                if update:
                    loss.backward()
                    self.trainer.update()
                    dn.renew_cg()
                    sessions.clear()

        for sentence_idx, batch_idx, batch_sentences in split_to_batches(
                graphs, self.options.batch_size):
            if sentence_idx % print_per == 0 and sentence_idx != 0:
                self.logger.print(sentence_idx)
            sessions = [self.training_session(sentence, self.logger)
                        for sentence in batch_sentences]
            all_exprs = [next(i) for i in sessions]
            if all_exprs:
                dn.forward(all_exprs)
            all_labels_exprs = [j for i in sessions for j in next(i)]
            if all_labels_exprs:
                dn.forward(all_labels_exprs)
            loss = sum(next(i) for i in sessions) / len(sessions)
            self.logger.total_loss_value += loss.value()
            if update:
                loss.backward()
                self.trainer.update()
                dn.renew_cg()
                sessions.clear()
            yield (loss if not update else None)

    def train(self, graphs, extra=None):
        for _ in self.train_gen(graphs, extra=extra):
            pass

    def save(self, prefix):
        nn.model_save_helper("pickle", prefix, self.network, self.options)

    @classmethod
    def load(cls, prefix, new_options=None):
        """
        :param prefix: model file name prefix
        :type prefix: str
        :rtype: MaxSubGraphParser
        """
        model = dn.Model()
        options, savable = nn.model_load_helper(None, prefix, model)
        if new_options is not None:
            options.__dict__.update(new_options.__dict__)
        ret = cls(options, None, (model, savable))
        return ret
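
`predict_session`/`training_session` and their drivers implement a generator-based batching protocol: every session yields its edge expressions, the driver runs one batched `dn.forward` over all of them, resumes each session to collect label expressions for a second batched forward, and resumes once more for the result (or loss). A DyNet-free sketch of the same driver pattern, with all names illustrative:

def session(item):
    yield "edges(%s)" % item      # step 1: work for the first batched forward
    yield ["label(%s)" % item]    # step 2: work for the second batched forward
    yield "parsed(%s)" % item     # step 3: the final result

def run_batch(items):
    sessions = [session(i) for i in items]
    edge_batch = [next(s) for s in sessions]              # one forward over all edge exprs
    label_batch = [w for s in sessions for w in next(s)]  # one forward over all label exprs
    return [next(s) for s in sessions]

print(run_batch(["a", "b"]))  # ['parsed(a)', 'parsed(b)']
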
Example #14
class MaxSubTreeWithEmptyParser(TreeWithEmptyTrainer):
    @classmethod
    def add_parser_arguments(cls, arg_parser):
        super(MaxSubTreeWithEmptyParser, cls).add_parser_arguments(arg_parser)

        group = arg_parser.add_argument_group(cls.__name__)
        group.add_argument("--decoder",
                           type=str,
                           dest="decoder",
                           default="eisner2nd",
                           choices=["eisner2nd"])
        group.add_argument("--disable-cost-augment",
                           action="store_false",
                           dest="cost_augment",
                           default=True)
        group.add_argument("--enable-2nd",
                           action="store_true",
                           dest="use_2nd",
                           default=False)

        EdgeEvaluationNetwork.add_parser_arguments(arg_parser)

    def __init__(self,
                 options,
                 train_sentences=None,
                 restore_file=None,
                 statistics=None):
        self.model = dn.Model()

        random.seed(1)
        self.trainer = dn.AdamTrainer(self.model)

        self.activation = activations[options.activation]
        # self.decoder = decoders[options.decoder]

        self.labelsFlag = options.labelsFlag
        self.costaugFlag = options.cost_augment
        self.options = options

        if "func" in options:
            del options.func

        if restore_file:
            self.container, = dn.load(restore_file, self.model)
            networks = list(self.container.components)
            self.network = networks.pop(0)
            self.statistics = statistics
            self.has_emptys = len(statistics.emptys) > 0
            if self.has_emptys:
                self.network_for_emptys = networks.pop(0)
            if self.options.use_2nd:
                self.network3 = networks.pop(0)
                if self.has_emptys:
                    self.network3_for_emptys_mid = networks.pop(0)
                    self.network3_for_emptys_out = networks.pop(0)
            assert not networks
        else:
            self.container = nn.Container(self.model)
            self.statistics = statistics = StatisticsWithEmpty.from_sentences(
                train_sentences)
            self.has_emptys = len(statistics.emptys) > 0
            self.network = EdgeEvaluationNetwork(self.container, statistics,
                                                 options)
            if self.has_emptys:
                self.network_for_emptys = EdgeEvaluation(
                    self.container, options)
            if options.use_2nd:
                self.network3 = EdgeSiblingEvaluation(self.container, options)
                if self.has_emptys:
                    self.network3_for_emptys_mid = EdgeSiblingEvaluation(
                        self.container, options)
                    self.network3_for_emptys_out = EdgeSiblingEvaluation(
                        self.container, options)

    def predict(self, sentences):
        """:type sentences: list[SentenceWithEmpty]"""
        for iSentence, sentence in enumerate(sentences):
            if len(sentence) >= MAX_SENT_SIZE_EMPTY - 1:
                logger.info("sent too long...")
                heads = [0 for _ in range(len(sentence))]
                labels = [None for _ in range(len(sentence))]
                yield sentence.to_simple_sentence(heads, labels, set())
                continue

            lstm_output = self.network.get_lstm_output(sentence)
            scores, exprs = self.network.get_complete_scores(lstm_output)
            labels = [None for _ in range(len(sentence))]
            if self.has_emptys:
                scores_ec, exprs_ec = self.network_for_emptys.get_complete_scores(
                    lstm_output)
            else:
                scores_ec = -np.ones(
                    (len(lstm_output), len(lstm_output)), dtype=np.float64)

            scores_all = np.stack([scores, scores_ec])
            if self.options.use_2nd:
                exprs2nd2, scores2nd2, exprs2nd3, scores2nd3 = self.network3.get_complete_scores(
                    lstm_output)
                if self.has_emptys:
                    _, _, exprs_mid, scores_mid = self.network3_for_emptys_mid.get_complete_scores(
                        lstm_output, False)
                    exprs2nd2_ec, scores2nd2_ec, exprs_out, scores_out = self.network3_for_emptys_out.get_complete_scores(
                        lstm_output)
                    scores_all_ec = np.stack([scores2nd2, scores2nd2_ec])
                    heads, emptys = emptyeisner2nd(scores_all, scores_all_ec,
                                                   scores2nd3, scores_mid,
                                                   scores_out)
                else:
                    heads = eisner2nd(scores, scores2nd2, scores2nd3)
                    emptys = set()
            else:
                heads, emptys = empty_eisner_greedy(scores_all)

            if self.labelsFlag:
                edges = [(head, "_", modifier)
                         for modifier, head in enumerate(heads[1:], 1)]
                for edge, scores_expr in \
                        zip(edges, self.network.get_label_scores(lstm_output, edges)):
                    head, _, modifier = edge
                    labels_scores = scores_expr.value()
                    labels[modifier] = \
                        self.network.irels[max(enumerate(labels_scores), key=itemgetter(1))[0]]

            dn.renew_cg()

            result = sentence.to_simple_sentence(heads, labels, emptys)
            if self.options.output_scores:
                # extract full edge scores
                if self.has_emptys:
                    scores_list = [
                        scores, scores_ec, scores2nd2, scores2nd3,
                        scores2nd2_ec, scores_mid, scores_out
                    ]
                else:
                    scores_list = [
                        scores, None, scores2nd2, scores2nd3, None, None, None
                    ]
                result.comment = [
                    base64.b64encode(pickle.dumps(scores_list)).decode()
                ]
            yield result

    # noinspection PyUnboundLocalVariable
    def train_gen(self, sentences, update=True):
        eloss = 0.0
        mloss = 0.0
        eerrors = 0
        etotal = 0
        start = time.time()

        errs = []
        lerrs = []

        empty_correct_count = 0
        empty_gold_total = 0.00000001  # tiny epsilon: keeps the running P/R from dividing by zero
        empty_pred_total = 0.00000001

        for sent_idx, sentence in enumerate(sentences):
            if sent_idx % 100 == 0 and sent_idx != 0:
                logger.info(
                    "Processing sentence number: %d, Loss: %.2f,"
                    "Empty-R: %.2f, Empty-P: %.2f, Errors: %.2f, Time: %.2f",
                    sent_idx, eloss / etotal,
                    empty_correct_count / empty_gold_total,
                    empty_correct_count / empty_pred_total,
                    (float(eerrors)) / etotal,
                    time.time() - start)
                start = time.time()
                eerrors = 0
                eloss = 0.0
                etotal = 0
                empty_gold_total = 0.00000001
                empty_pred_total = 0.00000001
                empty_correct_count = 0

            if len(sentence) >= MAX_SENT_SIZE_EMPTY - 1:
                logger.info("sent too long...")
                continue

            empty_gold_total += len(sentence.empty_nodes)
            lstm_output = self.network.get_lstm_output(sentence)
            scores, exprs = self.network.get_complete_scores(lstm_output)
            if self.has_emptys:
                scores_ec, exprs_ec = self.network_for_emptys.get_complete_scores(
                    lstm_output)
            else:
                scores_ec = -np.ones(
                    (len(lstm_output), len(lstm_output)), dtype=np.float64)

            gold = [entry.parent_id for entry in sentence]
            scores_all = np.stack([scores, scores_ec])

            if self.options.use_2nd:
                exprs2nd2, scores2nd2, exprs2nd3, scores2nd3 = self.network3.get_complete_scores(
                    lstm_output)
                if self.has_emptys:
                    _, _, exprs_mid, scores_mid = self.network3_for_emptys_mid.get_complete_scores(
                        lstm_output, False)
                    exprs2nd2_ec, scores2nd2_ec, exprs_out, scores_out = \
                        self.network3_for_emptys_out.get_complete_scores(lstm_output)
                    scores_all_ec = np.stack([scores2nd2, scores2nd2_ec])
                    heads, emptys = emptyeisner2nd(
                        scores_all, scores_all_ec, scores2nd3, scores_mid,
                        scores_out, sentence if self.costaugFlag else None)
                else:
                    heads = eisner2nd(scores, scores2nd2, scores2nd3,
                                      gold if self.costaugFlag else None)
                    emptys = set()
            else:
                heads, emptys = empty_eisner_greedy(
                    scores_all, sentence if self.costaugFlag else None)
            empty_pred_total += len(emptys)

            if self.labelsFlag:
                edges = [(head, "_", modifier)
                         for modifier, head in enumerate(gold[1:], 1)]
                for edge, r_scores_expr in \
                        zip(edges, self.network.get_label_scores(lstm_output, edges)):
                    head, _, modifier = edge
                    r_scores = r_scores_expr.value()
                    gold_label_index = self.network.rels[
                        sentence[modifier].relation]
                    wrong_label_index = max(((l, scr)
                                             for l, scr in enumerate(r_scores)
                                             if l != gold_label_index),
                                            key=itemgetter(1))[0]
                    if r_scores[gold_label_index] < r_scores[
                            wrong_label_index] + 1:
                        lerrs.append(r_scores_expr[wrong_label_index] -
                                     r_scores_expr[gold_label_index])

            e = sum([1 for h, g in zip(heads[1:], gold[1:]) if h != g])
            eerrors += e
            if e > 0:
                loss = [(exprs[h][i] - exprs[g][i])
                        for i, (h, g) in enumerate(zip(heads, gold))
                        if h != g]  # * (1.0/float(e))
                eloss += e
                mloss += e
                errs.extend(loss)

            if self.has_emptys:
                empty_incorrect = sentence.empty_nodes.symmetric_difference(
                    emptys)
                empty_correct = sentence.empty_nodes.intersection(emptys)
                empty_correct_count += len(empty_correct)
                for i in empty_incorrect:
                    if i not in sentence.empty_nodes:
                        errs.append(exprs_ec[i.head][i.position])
                    else:
                        errs.append(-exprs_ec[i.head][i.position])

            if self.options.use_2nd:
                # noinspection PyUnboundLocalVariable
                biarcs2, biarcs3 = arcs_with_empty_to_biarcs(heads, emptys)
                biarcs2_gold, biarcs3_gold = arcs_with_empty_to_biarcs(
                    gold, sentence.empty_nodes)
                for s, t in biarcs2.symmetric_difference(biarcs2_gold):
                    sign = 1 if (s, t) in biarcs2 else -1
                    if isinstance(t, int):
                        errs.append(exprs2nd2[s, t] * sign)
                    elif isinstance(t, EdgeToEmpty):
                        errs.append(exprs2nd2_ec[s, t.position] * sign)

                for s, m, t in biarcs3.symmetric_difference(biarcs3_gold):
                    sign = 1 if (s, m, t) in biarcs3 else -1
                    if isinstance(m, int) and isinstance(t, int):
                        errs.append(exprs2nd3[s, m, t] * sign)
                    elif isinstance(m, EdgeToEmpty) and isinstance(t, int):
                        errs.append(exprs_mid[s, m.position, t] * sign)
                    elif isinstance(m, int) and isinstance(t, EdgeToEmpty):
                        errs.append(exprs_out[s, m, t.position] * sign)
                    else:
                        raise TypeError

            etotal += len(sentence)

            if errs or lerrs:
                loss = dn.esum(errs + lerrs)  # * (1.0/(float(len(errs))))
            else:
                loss = dn.scalarInput(0.0)
            loss_value = loss.scalar_value()
            errs = []
            lerrs = []
            if loss_value != 0.0:
                if update:
                    loss.backward()
                    self.trainer.update()
                    dn.renew_cg()
            yield (loss if not update else None)

    def train(self, sentences):
        for _ in self.train_gen(sentences):
            pass

    def save(self, prefix):
        with open(prefix + ".options", "wb") as f:
            pickle.dump((self.options, self.statistics), f)
        # noinspection PyArgumentList
        dn.save(prefix, [self.container])

    @classmethod
    def load(cls, prefix, new_options=None):
        """
        :param prefix: model file name prefix
        :type prefix: str
        :rtype: MaxSubTreeWithEmptyParser
        """
        with open(prefix + ".options", "rb") as f:
            options, statistics = pickle.load(f)
        if new_options is not None:
            options.__dict__.update(new_options.__dict__)
        ret = cls(options, None, prefix, statistics)
        return ret
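
The empty-node update in `train_gen` above is a structured-perceptron step over sets: the symmetric difference between gold and predicted empty nodes yields exactly the false positives (their score expression enters the loss with a plus sign, pushing the score down) and the false negatives (minus sign, pushing it up). The sign logic in isolation, with tuples standing in for the EdgeToEmpty nodes:

gold_emptys = {(2, 3), (5, 1)}   # illustrative (head, position) pairs
pred_emptys = {(2, 3), (4, 2)}

for node in gold_emptys.symmetric_difference(pred_emptys):
    if node not in gold_emptys:
        print(node, "-> +score in loss (false positive)")
    else:
        print(node, "-> -score in loss (false negative)")
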
Example #15
class MaxSubTreeParser(TreeParserBase):
    @classmethod
    def add_parser_arguments(cls, arg_parser):
        super(MaxSubTreeParser, cls).add_parser_arguments(arg_parser)

        group = arg_parser.add_argument_group(cls.__name__)
        group.add_argument("--decoder",
                           type=str,
                           dest="decoder",
                           default="eisner2nd",
                           choices=["eisner2nd"])
        group.add_argument("--cost-augment",
                           action="store_true",
                           dest="cost_augment",
                           default=True)

        EdgeEvaluationNetwork.add_parser_arguments(arg_parser)

    def __init__(self, options, train_sentences=None, restore_file=None):
        self.model = dn.Model()
        random.seed(1)
        self.trainer = dn.AdamTrainer(self.model)

        self.activation = activations[options.activation]
        self.decoder = decoders[options.decoder]

        self.labelsFlag = options.labelsFlag
        self.costaugFlag = options.cost_augment
        self.options = options

        if "func" in options:
            del options.func

        if restore_file:
            self.container, = dn.load(restore_file, self.model)
            self.network, self.network3 = self.container.components
        else:
            self.container = nn.Container(self.model)
            statistics = Statistics.from_sentences(train_sentences)
            self.network = EdgeEvaluationNetwork(self.container, statistics,
                                                 options)
            self.network3 = EdgeSiblingEvaluation(self.container, options)

    def predict(self, sentences):
        def convert_node(node, heads, labels):
            return CoNLLUNode(node.id, node.form, node.lemma, node.cpos,
                              node.pos, node.feats, heads[node.id],
                              labels[node.id], "_", "_")

        for iSentence, sentence in enumerate(sentences):
            if len(sentence) >= MAX_SENT_SIZE_NOEMPTY - 1:
                logger.info("sent too long...")
                heads = [0 for _ in range(len(sentence))]
                heads[0] = -1
                labels = [None for _ in range(len(sentence))]
                yield CoNLLUSentence(
                    convert_node(node, heads, labels) for node in sentence
                    if node.id > 0)
                continue
            lstm_output = self.network.get_lstm_output(sentence)
            scores, exprs = self.network.get_complete_scores(lstm_output)
            exprs2nd2, scores2nd2, exprs2nd3, scores2nd3 = self.network3.get_complete_scores(
                lstm_output)
            heads = self.decoder(scores, scores2nd2, scores2nd3)
            labels = [None for _ in range(len(sentence))]

            if self.labelsFlag:
                edges = [(head, "_", modifier)
                         for modifier, head in enumerate(heads[1:], 1)]
                for edge, scores_expr in \
                        zip(edges, self.network.get_label_scores(lstm_output, edges)):
                    head, _, modifier = edge
                    scores = scores_expr.value()
                    labels[modifier] = \
                        self.network.irels[max(enumerate(scores), key=itemgetter(1))[0]]

            dn.renew_cg()

            yield CoNLLUSentence(
                convert_node(node, heads, labels) for node in sentence
                if node.id > 0)

    def train_gen(self, sentences, update=True):
        eloss = 0.0
        mloss = 0.0
        eerrors = 0
        etotal = 0
        start = time.time()

        errs = []
        lerrs = []

        for sent_idx, sentence in enumerate(sentences):
            if len(sentence) >= MAX_SENT_SIZE_NOEMPTY - 1:
                logger.info("sent too long...")
                continue

            if sent_idx % 100 == 0 and sent_idx != 0:
                logger.info(
                    'Processing sentence number: %d, Loss: %.2f, Errors: %.2f, Time: %.2f',
                    sent_idx, eloss / etotal, (float(eerrors)) / etotal,
                    time.time() - start)
                start = time.time()
                eerrors = 0
                eloss = 0.0
                etotal = 0

            lstm_output = self.network.get_lstm_output(sentence)
            scores, exprs = self.network.get_complete_scores(lstm_output)
            exprs2nd2, scores2nd2, exprs2nd3, scores2nd3 = self.network3.get_complete_scores(
                lstm_output)

            gold = [entry.parent_id for entry in sentence]
            heads = self.decoder(scores, scores2nd2, scores2nd3,
                                 gold if self.costaugFlag else None)

            if self.labelsFlag:
                edges = [(head, "_", modifier)
                         for modifier, head in enumerate(gold[1:], 1)]
                for edge, r_scores_expr in \
                        zip(edges, self.network.get_label_scores(lstm_output, edges)):
                    head, _, modifier = edge
                    r_scores = r_scores_expr.value()
                    gold_label_index = self.network.rels[
                        sentence[modifier].relation]
                    wrong_label_index = max(((l, scr)
                                             for l, scr in enumerate(r_scores)
                                             if l != gold_label_index),
                                            key=itemgetter(1))[0]
                    if r_scores[gold_label_index] < r_scores[
                            wrong_label_index] + 1:
                        lerrs.append(r_scores_expr[wrong_label_index] -
                                     r_scores_expr[gold_label_index])

            e = sum([1 for h, g in zip(heads[1:], gold[1:]) if h != g])
            eerrors += e
            if e > 0:
                loss = [(exprs[h][i] - exprs[g][i])
                        for i, (h, g) in enumerate(zip(heads, gold))
                        if h != g]  # * (1.0/float(e))
                eloss += e
                mloss += e
                errs.extend(loss)

            heads2nd2, heads2nd3 = arcs_to_biarcs(heads)
            gold2nd2, gold2nd3 = arcs_to_biarcs(gold)
            for i, k in heads2nd2:
                errs.append(exprs2nd2[i, k])
            for i, k in gold2nd2:
                errs.append(-exprs2nd2[i, k])
            for i, j, k in heads2nd3:
                errs.append(exprs2nd3[i, j, k])
            for i, j, k in gold2nd3:
                errs.append(-exprs2nd3[i, j, k])

            etotal += len(sentence)

            if errs or lerrs:
                loss = dn.esum(errs + lerrs)  # * (1.0/(float(len(errs))))
            else:
                loss = dn.scalarInput(0.0)
            loss_value = loss.scalar_value()
            errs = []
            lerrs = []
            if loss_value != 0.0:
                if update:
                    loss.backward()
                    self.trainer.update()
                    dn.renew_cg()
            yield (loss if not update else None)

    def train(self, sentences):
        for _ in self.train_gen(sentences):
            pass

    def save(self, prefix):
        with open(prefix + ".options", "wb") as f:
            pickle.dump(self.options, f)
        # noinspection PyArgumentList
        dn.save(prefix, [self.container])  # the restore path in __init__ expects the whole container

    @classmethod
    def load(cls, prefix, new_options=None):
        """
        :param prefix: model file name prefix
        :type prefix: str
        :rtype: MaxSubTreeParser
        """
        with open(prefix + ".options", "rb") as f:  # pickle requires binary mode
            options = pickle.load(f)
        if new_options is not None:
            options.__dict__.update(new_options.__dict__)
        ret = cls(options, None, prefix)
        return ret
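
The label update above is a margin-1 hinge: pick the best-scoring wrong label and add an error term only when it comes within 1 of the gold label's score. The selection step in isolation, with illustrative scores:

from operator import itemgetter

r_scores = [0.2, 1.3, 1.1, 0.4]  # candidate label scores
gold_label_index = 2

# highest-scoring label other than the gold one
wrong_label_index = max(((l, scr) for l, scr in enumerate(r_scores)
                         if l != gold_label_index), key=itemgetter(1))[0]

if r_scores[gold_label_index] < r_scores[wrong_label_index] + 1:
    print("margin violated by label", wrong_label_index)  # margin violated by label 1
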
Example #16
class MaxSubTreeParser(TreeParserBase):
    # smaller batch size is better for this task
    default_batch_size = 2
    default_test_batch_size = 64

    @classmethod
    def add_parser_arguments(cls, arg_parser):
        super(MaxSubTreeParser, cls).add_parser_arguments(arg_parser)
        group = arg_parser.add_argument_group(cls.__name__)
        group.add_argument("--optimizer", type=str, dest="optimizer", default="adam", choices=nn.trainers.keys())
        group.add_argument("--decoder", type=str, dest="decoder", default="eisner", choices=decoders)
        group.add_argument("--cost-augment", action="store_true", dest="cost_augment", default=True)
        group.add_argument("--batch-size", type=int, dest="batch_size", default=cls.default_batch_size)
        group.add_argument("--model-format", dest="model_format", choices=nn.model_formats, default="pickle")
        EdgeEvaluationNetwork.add_parser_arguments(arg_parser)

    @classmethod
    def add_predict_arguments(cls, arg_parser):
        super(MaxSubTreeParser, cls).add_predict_arguments(arg_parser)
        group = arg_parser.add_argument_group(cls.__name__)
        group.add_argument("--model-format", dest="model_format", choices=nn.model_formats, default=None)

    @classmethod
    def add_common_arguments(cls, arg_parser):
        super(MaxSubTreeParser, cls).add_common_arguments(arg_parser)
        group = arg_parser.add_argument_group(cls.__name__ + " (common)")
        group.add_argument("--test-batch-size", type=int, dest="test_batch_size", default=cls.default_test_batch_size)
        group.add_argument("--concurrent-count", type=int, dest="concurrent_count", default=5)

    def __init__(self, options, train_sentences=None, restore_file=None):
        self.model = dn.Model()
        if restore_file:
            old_options, self.network = nn.model_load_helper(None, restore_file, self.model)
            if options is not None:
                old_options.__dict__.update(options.__dict__)
                options = old_options
        else:
            statistics = Statistics.from_sentences(train_sentences)
            self.network = EdgeEvaluationNetwork(self.model, statistics, options)

        self.optimizer = nn.get_optimizer(self.model, options)
        self.decoder = decoders[options.decoder]
        self.labelsFlag = options.labelsFlag
        self.options = options

        if "func" in options:
            del options.func

    def predict_session(self, sentence, pool):
        lstm_output = self.network.get_lstm_output(sentence)
        length = len(sentence)
        raw_exprs = self.network.edge_eval.get_complete_raw_exprs(lstm_output)
        yield raw_exprs

        scores = self.network.edge_eval.raw_exprs_to_scores(raw_exprs, length)

        heads_future = pool.apply_async(self.decoder, (scores, ))
        yield None
        heads = heads_future.get()
        labels = [None for _ in range(len(sentence))]

        if self.labelsFlag:
            edges = [(head, "_", modifier) for modifier, head in enumerate(heads[1:], 1)]
            labels_exprs = list(self.network.label_eval.get_label_scores(lstm_output, edges))
            yield labels_exprs
            for edge, scores_expr in zip(edges, labels_exprs):
                head, _, modifier = edge
                edges_scores = scores_expr.value()
                labels[modifier] = \
                    self.network.irels[max(enumerate(edges_scores), key=itemgetter(1))[0]]
        else:
            yield []

        def convert_node(node):
            # noinspection PyArgumentList
            return CoNLLUNode(node.id, node.form, node.lemma, node.cpos,
                              node.pos, node.feats,
                              heads[node.id], labels[node.id],
                              "_", "_")

        result = CoNLLUSentence(convert_node(node) for node in sentence if node.id > 0)
        if self.options.output_scores:
            # extract full edge scores
            edges = [(head, "_", modifier) for head in range(len(sentence))
                     for modifier in range(len(sentence))]
            edges_scores_all = np.array(list(i.value() for i in self.network.get_label_scores(lstm_output, edges)))
            edges_scores_all = edges_scores_all.reshape((len(sentence), len(sentence), len(self.network.rels)))
            result.comment = [base64.b64encode(pickle.dumps(scores)).decode()]
            result.comment.append(base64.b64encode(pickle.dumps(edges_scores_all)).decode()
                                  if self.labelsFlag else "No labels")
        yield result

    def predict(self, sentences):
        self.network.sent_embedding.rnn.disable_dropout()
        pool = Pool(self.options.concurrent_count)
        for sentence_idx, batch_idx, batch_sentences in split_to_batches(
                sentences, self.options.test_batch_size):
            sessions = [self.predict_session(sentence, pool)
                        for sentence in batch_sentences]
            all_exprs = [next(i) for i in sessions]
            if all_exprs:
                dn.forward(all_exprs)
            # spawn decoders
            for i in sessions:
                next(i)
            all_labels_exprs = [j for i in sessions for j in next(i)]
            if all_labels_exprs:
                dn.forward(all_labels_exprs)
            for i in sessions:
                yield next(i)
            dn.renew_cg()

    def training_session(self, sentence, print_logger, pool):
        lstm_output = self.network.get_lstm_output(sentence)
        length = len(sentence)
        raw_exprs = self.network.edge_eval.get_complete_raw_exprs(lstm_output)
        yield raw_exprs

        scores = self.network.edge_eval.raw_exprs_to_scores(raw_exprs, length)
        exprs = self.network.edge_eval.raw_exprs_to_exprs(raw_exprs, length)

        gold = [entry.parent_id for entry in sentence]
        heads_future = pool.apply_async(self.decoder,
                                        (scores, gold if self.options.cost_augment else None))
        yield None
        heads = heads_future.get()

        label_loss = dn.scalarInput(0.0)
        if self.labelsFlag:
            edges = [(head, "_", modifier) for modifier, head in enumerate(gold[1:], 1)]
            label_exprs = list(self.network.get_label_scores(lstm_output, edges))
            yield label_exprs
            for edge, r_scores_expr in zip(edges, label_exprs):
                head, _, modifier = edge
                r_scores = r_scores_expr.value()
                gold_label_index = self.network.rels[sentence[modifier].relation]
                wrong_label_index = max(((l, scr) for l, scr in enumerate(r_scores)
                                         if l != gold_label_index), key=itemgetter(1))[0]
                if r_scores[gold_label_index] < r_scores[wrong_label_index] + 1:
                    label_loss += r_scores_expr[wrong_label_index] - r_scores_expr[gold_label_index] + 1
        else:
            yield []

        head_exprs = [(exprs[h][i] - exprs[g][i] + 1)
                      for i, (h, g) in enumerate(zip(heads, gold)) if
                      h != g]
        print_logger.correct_edge += len(sentence) - len(head_exprs)
        print_logger.total_edge += len(sentence)
        head_loss = dn.esum(head_exprs) if head_exprs else dn.scalarInput(0.0)
        yield label_loss + head_loss

    def train_gen(self, sentences, update=True):
        print_logger = PrintLogger()
        pool = Pool(self.options.concurrent_count)
        self.network.sent_embedding.rnn.set_dropout(self.options.lstm_dropout)
        print_per = (100 // self.options.batch_size + 1) * self.options.batch_size

        for sentence_idx, batch_idx, batch_sentences in split_to_batches(
                sentences, self.options.batch_size):
            if sentence_idx % print_per == 0 and sentence_idx != 0:
                print_logger.print(sentence_idx)
            sessions = [self.training_session(sentence, print_logger, pool)
                        for sentence in batch_sentences]
            all_exprs = [next(i) for i in sessions]
            if all_exprs:
                dn.forward(all_exprs)
            # spawn decoders
            for i in sessions:
                next(i)
            all_labels_exprs = [j for i in sessions for j in next(i)]
            if all_labels_exprs:
                dn.forward(all_labels_exprs)
            loss = sum(next(i) for i in sessions) / len(sessions)
            print_logger.total_loss_value += loss.value()
            if update:
                loss.backward()
                self.optimizer.update()
                dn.renew_cg()
            yield (loss if not update else None)

    def train(self, sentences):
        for _ in self.train_gen(sentences):
            pass

    def save(self, prefix):
        with open(prefix + ".options", "wb") as f:
            pickle.dump(self.options, f)
        # noinspection PyArgumentList
        dn.save(prefix, [self.network])

    @classmethod
    def load(cls,
             prefix,  # type: str
             new_options=None):
        """
        :param prefix: model file name prefix
        :rtype: MaxSubTreeParser
        """
        ret = cls(new_options, None, prefix)
        return ret
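
`predict_session` in this variant adds a phase to the session protocol: after the batched edge forward, every session submits its score matrix to a multiprocessing.Pool and yields None; once all decoders are in flight, the driver resumes each session, which blocks on `heads_future.get()` only for its own result. A minimal sketch of that spawn-then-collect overlap, where `decode` is an illustrative stand-in for the real decoder:

from multiprocessing import Pool

def decode(scores):
    # stand-in decoder: pick the best-scoring head for every modifier
    n = len(scores)
    return [max(range(n), key=lambda h: scores[h][m]) for m in range(n)]

if __name__ == "__main__":
    score_matrices = [[[0.1, 0.9], [0.8, 0.2]],
                      [[0.5, 0.4], [0.3, 0.6]]]
    pool = Pool(2)
    # spawn all decoders first so they run concurrently, then collect
    futures = [pool.apply_async(decode, (scores,)) for scores in score_matrices]
    print([f.get() for f in futures])  # [[1, 0], [0, 1]]
    pool.close()
    pool.join()
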