コード例 #1
0
ファイル: model.py プロジェクト: zzsfornlp/zmsp
 def __init__(self, conf: ZsfpModelConf, vpack: VocabPackage):
     """Assemble the input layer, the sentence encoder and the frame decoder.

     Args:
         conf: model configuration; it is handed to the base initializer and
             then read back from ``self.conf`` for everything below.
         vpack: vocabulary package; passed to the embedder and queried for
             the "evt" and "arg" vocabularies.

     Raises:
         NotImplementedError: if ``input_choice`` is neither "emb" nor "bert".
     """
     super().__init__(conf)
     cfg: ZsfpModelConf = self.conf
     self.vpack = vpack
     # -- input layer: record which of the two supported inputs was chosen
     self.input_choice_emb = (cfg.input_choice == "emb")
     self.input_choice_bert = (cfg.input_choice == "bert")
     # a BERT encoder is needed as the direct input, or when the embedder
     # is configured with a positive BERT dimension
     if self.input_choice_bert:
         use_bert = True
     else:
         use_bert = self.input_choice_emb and cfg.emb_conf.ec_bert.dim > 0
     if use_bert:
         self.berter = BertEncoder(cfg.bert_conf)
     else:
         self.berter = None
     # build the embedder (if chosen) and record the input dim: [bsize, slen, **D**]
     if self.input_choice_emb:
         self.emb = MyEmbdder(cfg.emb_conf, vpack, berter=self.berter)
         self.input_dim = self.emb.get_output_dims()[0]
     else:
         self.emb = None
         if not self.input_choice_bert:
             raise NotImplementedError(
                 f"Error: UNK input choice of {cfg.input_choice}")
         self.input_dim = self.berter.get_output_dims()[0]
     # -- encoder -> [bsize, slen, D']
     self.enc = PlainEncoder(cfg.enc_conf, input_dim=self.input_dim)
     self.enc_dim = self.enc.get_output_dims()[0]
     # -- decoder: the framer consumes the encoder output
     evt_voc = vpack.get_voc("evt")
     arg_voc = vpack.get_voc("arg")
     self.framer = MyFramer(cfg.frame_conf,
                            vocab_evt=evt_voc,
                            vocab_arg=arg_voc,
                            isize=self.enc_dim)
     # -- the attribute is touched only for its effect; per the original
     # note this is what finally builds the optimizer
     _ = self.optims
コード例 #2
0
ファイル: framer.py プロジェクト: zzsfornlp/zmsp
 def __init__(self, conf: MyFramerConf, vocab_evt: SimpleVocab, vocab_arg: SimpleVocab, **kwargs):
     """Build the event extractor, the frame encoder and the argument extractor.

     Args:
         conf: framer configuration; read back from ``self.conf`` after the
             base initializer has run.
         vocab_evt: vocabulary handed to the event extractor (and to the
             argument constrainer, if one is configured).
         vocab_arg: vocabulary handed to the argument extractor (and to the
             argument constrainer, if one is configured).
         **kwargs: forwarded unchanged to the base initializer.
     """
     super().__init__(conf, **kwargs)
     cfg: MyFramerConf = self.conf
     # -- event side: an optional lexicon constraint is loaded from file.
     # NOTE: frames that appear only in the lexicon are deliberately NOT
     # added to ``vocab_evt`` here.
     lex_file = cfg.evt_cons_lex_file
     cons_lex = LexConstrainer.load_from_file(lex_file) if lex_file else None
     self.evt_extractor = ExtractorGetter.make_extractor(
         cfg.evt_conf, vocab_evt, cons_lex=cons_lex, isize=cfg.isize)
     evt_psize = self.evt_extractor.get_output_dims()[0]
     # -- argument side: an optional frame-element constraint
     fe_file = cfg.arg_cons_fe_file
     if fe_file:
         fe_cons = FEConstrainer.load_from_file(fe_file)
         # the last argument is None: currently no other confs for this node
         self.cons_arg = ConstrainerNode(fe_cons, vocab_evt, vocab_arg, None)
     else:
         self.cons_arg = None
     self.fenc = PlainEncoder(cfg.fenc_conf, input_dim=cfg.isize)
     # the event representation size is passed on only if frame input is used
     arg_psize = evt_psize if cfg.arg_use_finput else -1
     self.arg_extractor = ExtractorGetter.make_extractor(
         cfg.arg_conf, vocab_arg, isize=cfg.isize, psize=arg_psize)
コード例 #3
0
 def __init__(self, conf: ExtenderConf, **kwargs):
     """Set up the span getter/setter, the extender encoder and the expander node.

     Args:
         conf: extender configuration; read back from ``self.conf`` after
             the base initializer has run.
         **kwargs: forwarded unchanged to the base initializer.
     """
     super().__init__(conf, **kwargs)
     cfg: ExtenderConf = self.conf
     # both span accessors are created from the same configured span mode
     span_mode = cfg.ext_span_mode
     self.ext_span_getter = Mention.create_span_getter(span_mode)
     self.ext_span_setter = Mention.create_span_setter(span_mode)
     self.eenc = PlainEncoder(cfg.eenc_conf, input_dim=cfg.isize)
     # a pair-input size is passed only when the frame input is used; else -1
     pair_size = cfg.psize if cfg.ext_use_finput else -1
     self.enode = SpanExpanderNode(cfg.econf, isize=cfg.isize, psize=pair_size)
コード例 #4
0
class ExtenderNode(BasicNode):
    """Node that scores left/right span boundaries for already-flattened items.

    ``loss`` trains the boundary scorers against the gold spans read from each
    item's mention; ``predict`` decodes boundaries and writes them back onto
    each item's mention.
    """

    def __init__(self, conf: ExtenderConf, **kwargs):
        """Set up the span getter/setter, the extender encoder and the expander node.

        Args:
            conf: extender configuration; read back from ``self.conf`` after
                the base initializer has run.
            **kwargs: forwarded unchanged to the base initializer.
        """
        super().__init__(conf, **kwargs)
        cfg: ExtenderConf = self.conf
        # both span accessors are created from the same configured span mode
        span_mode = cfg.ext_span_mode
        self.ext_span_getter = Mention.create_span_getter(span_mode)
        self.ext_span_setter = Mention.create_span_setter(span_mode)
        self.eenc = PlainEncoder(cfg.eenc_conf, input_dim=cfg.isize)
        # a pair-input size is passed only when the frame input is used; else -1
        pair_size = cfg.psize if cfg.ext_use_finput else -1
        self.enode = SpanExpanderNode(cfg.econf, isize=cfg.isize, psize=pair_size)

    # [**, slen, D], [**, slen, D']
    def _forward_eenc(self, flt_input_expr: BK.Expr, flt_full_expr: BK.Expr, flt_mask_expr: BK.Expr):
        """Run the extender encoder, optionally adding the full expr to the input first."""
        enc_input = flt_input_expr
        if self.conf.eenc_mix_center:
            enc_input = flt_input_expr + flt_full_expr  # mixing is a plain sum
        return self.eenc.forward(enc_input, mask_expr=flt_mask_expr)

    # --
    # all inputs below are assumed to be flattened already

    # [*], [*, slen, D], [*, D'], [*, slen]; [*]
    def loss(self, flt_items, flt_input_expr, flt_pair_expr, flt_full_expr, flt_mask_expr, flt_extra_weights=None):
        """Compute the combined left/right boundary loss.

        Args:
            flt_items: items whose ``mention`` carries the gold span.
            flt_input_expr: encoder input for each item.
            flt_pair_expr: pair input; used only when ``ext_use_finput`` is set.
            flt_full_expr: expr added to the input when ``eenc_mix_center`` is set.
            flt_mask_expr: mask passed to the encoder and to the scorer.
            flt_extra_weights: optional per-item weights; when given they scale
                both losses and their sum becomes the loss divisor.

        Returns:
            The result of combining the "left" and "right" leaf losses.
        """
        cfg: ExtenderConf = self.conf
        lam = cfg._loss_lambda
        # -- score both boundaries
        enc_t = self._forward_eenc(flt_input_expr, flt_full_expr, flt_mask_expr)  # [*, slen, D]
        pair_t = flt_pair_expr if cfg.ext_use_finput else None
        s_left, s_right = self.enode.score(enc_t, pair_t, flt_mask_expr)  # [*, slen]
        # -- gold boundaries: each span is (widx, wlen)
        gold_spans = [self.ext_span_getter(item.mention) for item in flt_items]
        widx_t = BK.input_idx([span[0] for span in gold_spans])  # [*]
        wlen_t = BK.input_idx([span[1] for span in gold_spans])
        loss_left_t = BK.loss_nll(s_left, widx_t)  # [*]
        # the right boundary is the index of the last word of the span
        loss_right_t = BK.loss_nll(s_right, widx_t+wlen_t-1)  # [*]
        if flt_extra_weights is None:
            # unweighted: the divisor is the item count (sum of that many ones)
            loss_div = BK.constants([len(flt_items)], value=1.).sum()
        else:
            loss_left_t *= flt_extra_weights
            loss_right_t *= flt_extra_weights
            loss_div = flt_extra_weights.sum()  # weighted: the divisor is the weight mass
        leaf_losses = [
            LossHelper.compile_leaf_loss("left", loss_left_t.sum(), loss_div, loss_lambda=lam),
            LossHelper.compile_leaf_loss("right", loss_right_t.sum(), loss_div, loss_lambda=lam),
        ]
        return LossHelper.combine_multiple_losses(leaf_losses)

    # same argument layout as ``loss`` (without the extra weights)
    def predict(self, flt_items, flt_input_expr, flt_pair_expr, flt_full_expr, flt_mask_expr):
        """Decode span boundaries and store them on each item's mention.

        Returns None in every case; with no items nothing is computed at all.
        """
        cfg: ExtenderConf = self.conf
        if len(flt_items) <= 0:
            return None  # nothing to predict for
        # -- score both boundaries
        enc_t = self._forward_eenc(flt_input_expr, flt_full_expr, flt_mask_expr)
        pair_t = flt_pair_expr if cfg.ext_use_finput else None
        s_left, s_right = self.enode.score(enc_t, pair_t, flt_mask_expr)
        # -- decode; the returned max scores are not needed here
        _, left_idxes, right_idxes = SpanExpanderNode.decode_with_scores(s_left, s_right, normalize=True)
        left_arr = BK.get_value(left_idxes)
        right_arr = BK.get_value(right_idxes)
        for item, left, right in zip(flt_items, left_arr, right_arr):
            # convert the inclusive [left, right] pair back to (start, length)
            start = int(left)
            length = int(right+1-left)
            self.ext_span_setter(item.mention, start, length)
コード例 #5
0
ファイル: model.py プロジェクト: zzsfornlp/zmsp
class ZsfpModel(BaseModel):
    """Sentence-level model: input layer -> plain encoder -> frame decoder."""

    def __init__(self, conf: ZsfpModelConf, vpack: VocabPackage):
        """Assemble the input layer, the sentence encoder and the frame decoder.

        Args:
            conf: model configuration; it is handed to the base initializer and
                then read back from ``self.conf`` for everything below.
            vpack: vocabulary package; passed to the embedder and queried for
                the "evt" and "arg" vocabularies.

        Raises:
            NotImplementedError: if ``input_choice`` is neither "emb" nor "bert".
        """
        super().__init__(conf)
        cfg: ZsfpModelConf = self.conf
        self.vpack = vpack
        # -- input layer: record which of the two supported inputs was chosen
        self.input_choice_emb = (cfg.input_choice == "emb")
        self.input_choice_bert = (cfg.input_choice == "bert")
        # a BERT encoder is needed as the direct input, or when the embedder
        # is configured with a positive BERT dimension
        if self.input_choice_bert:
            use_bert = True
        else:
            use_bert = self.input_choice_emb and cfg.emb_conf.ec_bert.dim > 0
        if use_bert:
            self.berter = BertEncoder(cfg.bert_conf)
        else:
            self.berter = None
        # build the embedder (if chosen) and record the input dim: [bsize, slen, **D**]
        if self.input_choice_emb:
            self.emb = MyEmbdder(cfg.emb_conf, vpack, berter=self.berter)
            self.input_dim = self.emb.get_output_dims()[0]
        else:
            self.emb = None
            if not self.input_choice_bert:
                raise NotImplementedError(
                    f"Error: UNK input choice of {cfg.input_choice}")
            self.input_dim = self.berter.get_output_dims()[0]
        # -- encoder -> [bsize, slen, D']
        self.enc = PlainEncoder(cfg.enc_conf, input_dim=self.input_dim)
        self.enc_dim = self.enc.get_output_dims()[0]
        # -- decoder: the framer consumes the encoder output
        evt_voc = vpack.get_voc("evt")
        arg_voc = vpack.get_voc("arg")
        self.framer = MyFramer(cfg.frame_conf,
                               vocab_evt=evt_voc,
                               vocab_arg=arg_voc,
                               isize=self.enc_dim)
        # -- the attribute is touched only for its effect; per the original
        # note this is what finally builds the optimizer
        _ = self.optims

    # helpers: produce (mask, input expr) from either input layer
    def _input_emb(self, insts: List[Sent]):
        """Embed the sentences with the embedder; returns (mask, embeddings)."""
        embed_inputs = self.emb.run_inputs(insts)
        mask_expr, emb_expr = self.emb.run_embeds(embed_inputs)  # [bs, slen, ?]
        return mask_expr, emb_expr

    def _input_bert(self, insts: List[Sent]):
        """Encode the sentences with BERT; returns (mask, bert output)."""
        bert_batch = self.berter.create_input_batch_from_sents(insts)
        # the mask is derived from the sentence lengths
        sent_lens = [len(sent) for sent in insts]
        mask_expr = BK.input_real(DataPadder.lengths2mask(sent_lens))
        bert_expr = self.berter.forward(bert_batch)
        return mask_expr, bert_expr

    def _emb_and_enc(self, insts: List[Sent]):
        """Run the chosen input layer and the encoder.

        Returns:
            A ``(mask_expr, emb_expr, enc_expr)`` triple.

        Raises:
            NotImplementedError: if no supported input choice is active.
        """
        # pick the input layer
        if self.input_choice_emb:
            run_input = self._input_emb
        elif self.input_choice_bert:
            run_input = self._input_bert
        else:
            raise NotImplementedError()
        mask_expr, emb_expr = run_input(insts)
        # encode on top of the input representation
        enc_expr = self.enc.forward(emb_expr, mask_expr=mask_expr)
        return mask_expr, emb_expr, enc_expr  # [bs, slen, *]

    # =====
    def loss_on_batch(self,
                      annotated_insts: List,
                      loss_factor=1.,
                      training=True,
                      **kwargs):
        """Compute the framer loss for one batch and, when training, backpropagate it.

        Args:
            annotated_insts: instances from which sentences are yielded.
            loss_factor: factor handed to the backward call.
            training: passed to ``refresh_batch``; also gates the backward pass.
            **kwargs: accepted but not used here.

        Returns:
            An info dict with "inst", "sent" and "fb" counts plus the entries
            reported by ``collect_loss``.
        """
        self.refresh_batch(training)
        sents: List[Sent] = list(yield_sents(annotated_insts))
        # forward: input + encoder, then the framer loss
        mask_expr, _, enc_expr = self._emb_and_enc(sents)
        frame_loss = self.framer.loss(sents, enc_expr, mask_expr)
        # collect statistics and the final loss
        info = {"inst": len(annotated_insts), "sent": len(sents), "fb": 1}
        final_loss, loss_info = self.collect_loss([frame_loss])
        info.update(loss_info)
        if not training:
            return info
        assert final_loss.requires_grad
        BK.backward(final_loss, loss_factor)
        return info

    def predict_on_batch(self, insts: List, **kwargs):
        """Run prediction over all sentences of the instances, bucket by bucket.

        Predictions are produced by ``framer.predict`` on each bucket; this
        method itself only returns an info dict with "inst" and "sent" counts.
        """
        cfg: ZsfpModelConf = self.conf
        self.refresh_batch(False)
        sents: List[Sent] = list(yield_sents(insts))
        with BK.no_grad_env():
            # re-batch the sentences (the input may be whole documents):
            # every sentence counts as size 1 and sorting is by length
            buckets = BatchHelper.group_buckets(
                sents,
                thresh_diff=cfg.decode_sent_thresh_diff,
                thresh_all=cfg.decode_sent_thresh_batch,
                size_f=lambda x: 1,
                sort_key=len)
            for bucket in buckets:
                mask_expr, _, enc_expr = self._emb_and_enc(bucket)
                self.framer.predict(bucket, enc_expr, mask_expr)
        return {"inst": len(insts), "sent": len(sents)}
コード例 #6
0
ファイル: framer.py プロジェクト: zzsfornlp/zmsp
class MyFramer(BasicNode):
    """Frame decoder: an event extractor followed by an argument extractor.

    The argument extractor runs on a frame-specific re-encoding of the
    sentence (``fenc``) and may be restricted by an external constraint.
    """

    def __init__(self, conf: MyFramerConf, vocab_evt: SimpleVocab, vocab_arg: SimpleVocab, **kwargs):
        """Build the event extractor, the frame encoder and the argument extractor.

        Args:
            conf: framer configuration; read back from ``self.conf`` after the
                base initializer has run.
            vocab_evt: vocabulary handed to the event extractor (and to the
                argument constrainer, if one is configured).
            vocab_arg: vocabulary handed to the argument extractor (and to the
                argument constrainer, if one is configured).
            **kwargs: forwarded unchanged to the base initializer.
        """
        super().__init__(conf, **kwargs)
        cfg: MyFramerConf = self.conf
        # -- event side: an optional lexicon constraint is loaded from file.
        # NOTE: frames that appear only in the lexicon are deliberately NOT
        # added to ``vocab_evt`` here.
        lex_file = cfg.evt_cons_lex_file
        cons_lex = LexConstrainer.load_from_file(lex_file) if lex_file else None
        self.evt_extractor = ExtractorGetter.make_extractor(
            cfg.evt_conf, vocab_evt, cons_lex=cons_lex, isize=cfg.isize)
        evt_psize = self.evt_extractor.get_output_dims()[0]
        # -- argument side: an optional frame-element constraint
        fe_file = cfg.arg_cons_fe_file
        if fe_file:
            fe_cons = FEConstrainer.load_from_file(fe_file)
            # the last argument is None: currently no other confs for this node
            self.cons_arg = ConstrainerNode(fe_cons, vocab_evt, vocab_arg, None)
        else:
            self.cons_arg = None
        self.fenc = PlainEncoder(cfg.fenc_conf, input_dim=cfg.isize)
        # the event representation size is passed on only if frame input is used
        arg_psize = evt_psize if cfg.arg_use_finput else -1
        self.arg_extractor = ExtractorGetter.make_extractor(
            cfg.arg_conf, vocab_arg, isize=cfg.isize, psize=arg_psize)

    # helper for the argument constraint
    def _get_arg_external_extra_score(self, flt_items):
        """Return an extra score term from the argument constraint, or None.

        With no constraint configured the result is None. Otherwise the
        constraint is looked up by each item's ``label_idx`` (0 for a None
        item) and entries outside the valid mask receive
        ``Constants.REAL_PRAC_MIN``.
        """
        if self.cons_arg is None:
            return None
        evt_idxes = [0 if item is None else item.label_idx for item in flt_items]
        valid_masks = self.cons_arg.lookup(BK.input_idx(evt_idxes))  # [*, L]
        penalty = Constants.REAL_PRAC_MIN * (1. - valid_masks)  # [*, L]
        # insert a singleton axis so the term can broadcast later: [bs, 1, L]
        return penalty.unsqueeze(-2)

    # [**, slen, D], [**, slen, D']
    def _forward_fenc(self, flt_input_expr: BK.Expr, flt_full_expr: BK.Expr, flt_mask_expr: BK.Expr):
        """Run the frame encoder, optionally adding the full expr to the input first."""
        enc_input = flt_input_expr
        if self.conf.fenc_mix_frame:
            enc_input = flt_input_expr + flt_full_expr  # mixing is a plain sum
        return self.fenc.forward(enc_input, mask_expr=flt_mask_expr)

    # [*, slen, D], [*, slen]
    def loss(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr):
        """Combine the event and argument losses that are switched on.

        A component contributes only when its weight (``loss_evt`` /
        ``loss_arg``) is positive.

        Returns:
            The result of combining the collected component losses.
        """
        cfg: MyFramerConf = self.conf
        component_losses = []
        # -- event loss
        evt_res = None
        if cfg.loss_evt > 0.:
            evt_loss, evt_res = self.evt_extractor.loss(insts, input_expr, mask_expr)
            component_losses.append(
                LossHelper.compile_component_loss("evt", [evt_loss], loss_lambda=cfg.loss_evt))
        # -- argument loss
        if cfg.loss_arg > 0.:
            if evt_res is None:
                # no event loss was computed: obtain the flattened frames directly
                evt_res = self.evt_extractor.lookup_flatten(insts, input_expr, mask_expr)
            # flattened so that dim0 indexes frames
            flt_items, flt_sidx, flt_expr, flt_full_expr = evt_res
            flt_input_expr = input_expr[flt_sidx]
            flt_mask_expr = mask_expr[flt_sidx]
            flt_fenc_expr = self._forward_fenc(flt_input_expr, flt_full_expr, flt_mask_expr)  # [**, slen, D]
            pair_expr = flt_expr if cfg.arg_use_finput else None
            extra_score = self._get_arg_external_extra_score(flt_items)
            arg_loss, _ = self.arg_extractor.loss(
                flt_items, flt_fenc_expr, flt_mask_expr, pair_expr=pair_expr,
                external_extra_score=extra_score)
            component_losses.append(
                LossHelper.compile_component_loss("arg", [arg_loss], loss_lambda=cfg.loss_arg))
        return LossHelper.combine_multiple_losses(component_losses)

    # [*, slen, D], [*, slen]
    def predict(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr):
        """Predict events and/or arguments according to ``pred_evt`` / ``pred_arg``.

        Always returns None; the extractors are invoked for their effects.
        """
        cfg: MyFramerConf = self.conf
        # -- events
        evt_res = self.evt_extractor.predict(insts, input_expr, mask_expr) if cfg.pred_evt else None
        if not cfg.pred_arg:
            return
        # -- arguments
        # todo(+N): existing frames with the argument tag are deleted first
        arg_ftag = self.arg_extractor.conf.arg_ftag
        for sent in insts:
            sent.delete_frames(arg_ftag)
        if evt_res is None:
            # events were not predicted in this call: look them up instead
            evt_res = self.evt_extractor.lookup_flatten(insts, input_expr, mask_expr)
        # flattened so that dim0 indexes frames
        flt_items, flt_sidx, flt_expr, flt_full_expr = evt_res
        if len(flt_items) <= 0:
            return  # running with zero frames can be erroneous, so stop here
        flt_input_expr = input_expr[flt_sidx]
        flt_mask_expr = mask_expr[flt_sidx]
        flt_fenc_expr = self._forward_fenc(flt_input_expr, flt_full_expr, flt_mask_expr)  # [**, slen, D]
        pair_expr = flt_expr if cfg.arg_use_finput else None
        extra_score = self._get_arg_external_extra_score(flt_items)
        self.arg_extractor.predict(
            flt_items, flt_fenc_expr, flt_mask_expr, pair_expr=pair_expr,
            external_extra_score=extra_score)