Example #1
 def __init__(self, pc: BK.ParamCollection, jconf: JPosConf, pos_vocab):
     super().__init__(pc, None, None)
     self.jpos_stacking = jconf.jpos_stacking
     self.jpos_multitask = jconf.jpos_multitask
     self.jpos_lambda = jconf.jpos_lambda
     self.jpos_decode = jconf.jpos_decode
     # encoder0
     jconf.jpos_enc._input_dim = jconf._input_dim
     self.enc = self.add_sub_node("enc0", MyEncoder(self.pc,
                                                    jconf.jpos_enc))
     self.enc_output_dim = self.enc.get_output_dims()[0]
     # output
     # todo(warn): here, include some other things for convenience
     num_labels = len(pos_vocab)
     self.pred = self.add_sub_node(
         "pred",
         Affine(self.pc,
                self.enc_output_dim,
                num_labels,
                init_rop=NoDropRop()))
     # further stacking (if not, then simply multi-task learning)
     if jconf.jpos_stacking:
         self.pos_weights = self.add_param(
             "w", (num_labels, self.enc_output_dim))  # [n, dim] to be added
     else:
         self.pos_weights = None
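
Note: the `pos_weights` parameter above is shaped `[num_labels, enc_output_dim]` and commented "to be added". A minimal sketch (plain PyTorch, illustrative only, not the zmsp forward code) of how such a jpos-stacking step could fold the soft POS prediction back into the encoder states:

import torch

def jpos_stack_sketch(enc_out: torch.Tensor,     # [batch, seq, enc_dim]
                      pos_scores: torch.Tensor,  # [batch, seq, num_labels]
                      pos_weights: torch.Tensor  # [num_labels, enc_dim]
                      ) -> torch.Tensor:
    # soft POS distribution -> expected label embedding, added onto the encoding
    pos_probs = torch.softmax(pos_scores, dim=-1)           # [batch, seq, num_labels]
    return enc_out + torch.matmul(pos_probs, pos_weights)   # [batch, seq, enc_dim]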
Example #2
File: dochint.py Project: ValentinaPy/zmsp
 def __init__(self, pc, dh_conf: DocHintConf):
     super().__init__(pc, None, None)
     self.conf: DocHintConf = dh_conf
     # =====
     self.input_dim, self.output_dim = dh_conf._input_dim, dh_conf._output_dim
     # 1. doc encoding
     self.conf.enc_doc_conf._input_dim = self.input_dim
     self.enc_doc = self.add_sub_node(
         "enc_d", MyEncoder(self.pc, self.conf.enc_doc_conf))
     self.enc_output_dim = self.enc_doc.get_output_dims()[0]
     # 2. keyword/keysent based doc hints (key/value)
     katt_conf = dh_conf.katt_conf
     self.kw_att = self.add_sub_node(
         "kw",
         AttentionNode.get_att_node(katt_conf.type, pc, self.input_dim,
                                    self.enc_output_dim, self.input_dim,
                                    katt_conf))
     self.ks_att = self.add_sub_node(
         "ks",
         AttentionNode.get_att_node(katt_conf.type, pc, self.input_dim,
                                    self.enc_output_dim, self.input_dim,
                                    katt_conf))
     # word model (load from outside)
     self.keyword_model = None
     if self.conf.kconf.load_file:
         self.keyword_model = KeyWordModel.load(self.conf.kconf.load_file,
                                                self.conf.kconf)
     # 3. combine
     final_input_dims = [self.enc_output_dim] + [self.input_dim] * (
         int(dh_conf.use_keyword) + int(dh_conf.use_keysent))
     self.final_layer = self.add_sub_node(
         "fl", Affine(pc, final_input_dims, self.output_dim, act="elu"))
Example #3
File: enc.py Project: ValentinaPy/zmsp
 def __init__(self, pc: BK.ParamCollection, conf: FpEncConf,
              vpack: VocabPackage):
     super().__init__(pc, None, None)
     self.conf = conf
     # ===== Vocab =====
     self.word_vocab = vpack.get_voc("word")
     self.char_vocab = vpack.get_voc("char")
     self.pos_vocab = vpack.get_voc("pos")
     # avoid no params error
     self._tmp_v = self.add_param("nope", (1, ))
     # ===== Model =====
     # embedding
     self.emb = self.add_sub_node("emb",
                                  MyEmbedder(self.pc, conf.emb_conf, vpack))
     self.emb_output_dim = self.emb.get_output_dims()[0]
     # bert
     self.bert = self.add_sub_node("bert", Berter2(self.pc, conf.bert_conf))
     self.bert_output_dim = self.bert.get_output_dims()[0]
     # make sure there are inputs
     assert self.emb_output_dim > 0 or self.bert_output_dim > 0
     # middle?
     if conf.middle_dim > 0:
         self.middle_node = self.add_sub_node(
             "mid",
             Affine(self.pc,
                    self.emb_output_dim + self.bert_output_dim,
                    conf.middle_dim,
                    act="elu"))
         self.enc_input_dim = conf.middle_dim
     else:
         self.middle_node = None
         self.enc_input_dim = self.emb_output_dim + self.bert_output_dim  # concat the two parts (if needed)
     # encoder?
     # todo(note): feed compute-on-the-fly hp
     conf.enc_conf._input_dim = self.enc_input_dim
     self.enc = self.add_sub_node("enc", MyEncoder(self.pc, conf.enc_conf))
     self.enc_output_dim = self.enc.get_output_dims()[0]
     # ===== Input Specification =====
     # inputs (word, char, pos) and vocabulary
     self.need_word = self.emb.has_word
     self.need_char = self.emb.has_char
     # todo(warn): currently only allow extra fields for POS
     self.need_pos = False
     if len(self.emb.extra_names) > 0:
         assert len(
             self.emb.extra_names) == 1 and self.emb.extra_names[0] == "pos"
         self.need_pos = True
     #
     self.word_padder = DataPadder(2,
                                   pad_vals=self.word_vocab.pad,
                                   mask_range=2)
     self.char_padder = DataPadder(3,
                                   pad_lens=(0, 0, conf.char_max_length),
                                   pad_vals=self.char_vocab.pad)
     self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad)
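
Note: the DataPadder calls above turn ragged index lists into fixed-shape arrays, with a mask in the 2-d case. Assuming the obvious semantics, a tiny numpy sketch of that padding step (not zmsp's actual DataPadder):

import numpy as np

def pad_2d_sketch(seqs, pad_val=0):
    # right-pad every sequence to the batch max length and mark the real positions
    max_len = max(len(s) for s in seqs)
    arr = np.full((len(seqs), max_len), pad_val, dtype=np.int64)
    mask = np.zeros((len(seqs), max_len), dtype=np.float32)
    for i, s in enumerate(seqs):
        arr[i, :len(s)] = s
        mask[i, :len(s)] = 1.0
    return arr, mask

# e.g. pad_2d_sketch([[3, 4, 5], [7]]) -> arrays of shape (2, 3); the second mask row is [1, 0, 0]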
Example #4
File: head.py Project: ValentinaPy/zmsp
 def __init__(self, pc, conf: NodeExtractorConfHead, vocab: HLabelVocab,
              extract_type: str):
     super().__init__(pc, conf, vocab, extract_type)
     # node selector
     conf.sel_conf._input_dim = conf._input_dim  # make dims fit
     self.sel: NodeSelector = self.add_sub_node(
         "sel", NodeSelector(pc, conf.sel_conf))
     # encoding
     self.dmxnn = conf.dmxnn
     self.posi_embed = self.add_sub_node(
         "pe", RelPosiEmbedding(pc, conf.posi_dim, max=conf.posi_cut))
     if self.dmxnn:
         conf.e_enc._input_dim = conf._input_dim + conf.posi_dim
     else:
         conf.e_enc._input_dim = conf._input_dim
     self.e_encoder = self.add_sub_node("ee", MyEncoder(pc, conf.e_enc))
     e_enc_dim = self.e_encoder.get_output_dims()[0]
     # decoding
     # todo(note): dropout after pooling; todo(+N): cannot go to previous layers if there are no encoders
     self.special_drop = self.add_sub_node("sd", Dropout(pc, (e_enc_dim, )))
     self.use_lab_f = conf.use_lab_f
     self.lab_f_use_lexi = conf.lab_f_use_lexi
     if self.use_lab_f:
         lab_f_input_dims = [e_enc_dim] * 3 if self.dmxnn else [e_enc_dim]
         if self.lab_f_use_lexi:
             lab_f_input_dims.append(conf._lexi_dim)
         self.lab_f = self.add_sub_node(
             "lab",
             Affine(pc,
                    lab_f_input_dims,
                    conf.lab_conf.n_dim,
                    act=conf.lab_f_act))
     else:
         self.lab_f = lambda x: x[0]  # only use the first one
     # secondary type
     self.use_secondary_type = conf.use_secondary_type
     if self.use_secondary_type:
         # todo(note): re-use vocab; or totally reuse the predictor?
         if conf.sectype_reuse_hl:
             self.hl2: HLabelNode = self.hl
         else:
             new_lab_conf = deepcopy(conf.lab_conf)
             new_lab_conf.zero_nil = False  # todo(note): not zero_nil here!
             self.hl2: HLabelNode = self.add_sub_node(
                 "hl", HLabelNode(pc, new_lab_conf, vocab))
         # enc+t1 -> t2
         self.t1tot2 = self.add_sub_node(
             "1to2", Embedding(pc, self.hl_output_size,
                               conf.lab_conf.n_dim))
     else:
         self.hl2 = None
         self.t1tot2 = None
Example #5
 def __init__(self, pc: BK.ParamCollection, bconf: BTConf,
              vpack: VocabPackage):
     super().__init__(pc, None, None)
     self.bconf = bconf
     # ===== Vocab =====
     self.word_vocab = vpack.get_voc("word")
     self.char_vocab = vpack.get_voc("char")
     self.pos_vocab = vpack.get_voc("pos")
     # ===== Model =====
     # embedding
     self.emb = self.add_sub_node(
         "emb", MyEmbedder(self.pc, bconf.emb_conf, vpack))
     emb_output_dim = self.emb.get_output_dims()[0]
     # encoder0 for jpos
     # todo(note): will do nothing if not use_jpos
     bconf.jpos_conf._input_dim = emb_output_dim
     self.jpos_enc = self.add_sub_node(
         "enc0", JPosModule(self.pc, bconf.jpos_conf, self.pos_vocab))
     enc0_output_dim = self.jpos_enc.get_output_dims()[0]
     # encoder
     # todo(0): feed compute-on-the-fly hp
     bconf.enc_conf._input_dim = enc0_output_dim
     self.enc = self.add_sub_node("enc", MyEncoder(self.pc, bconf.enc_conf))
     self.enc_output_dim = self.enc.get_output_dims()[0]
     # ===== Input Specification =====
     # inputs (word, char, pos) and vocabulary
     self.need_word = self.emb.has_word
     self.need_char = self.emb.has_char
     # todo(warn): currently only allow extra fields for POS
     self.need_pos = False
     if len(self.emb.extra_names) > 0:
         assert len(
             self.emb.extra_names) == 1 and self.emb.extra_names[0] == "pos"
         self.need_pos = True
     # todo(warn): currently only allow one aux field
     self.need_aux = False
     if len(self.emb.dim_auxes) > 0:
         assert len(self.emb.dim_auxes) == 1
         self.need_aux = True
     #
     self.word_padder = DataPadder(2,
                                   pad_vals=self.word_vocab.pad,
                                   mask_range=2)
     self.char_padder = DataPadder(3,
                                   pad_lens=(0, 0, bconf.char_max_length),
                                   pad_vals=self.char_vocab.pad)
     self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad)
     #
     self.random_sample_stream = Random.stream(Random.random_sample)
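
Note: `Random.stream(Random.random_sample)` above is consumed elsewhere as an endless source of samples. A guess at the pattern using only the standard-library `random` module (not zmsp's Random class):

import random

def stream_sketch(sample_f):
    # endless generator: one fresh sample per next() call
    while True:
        yield sample_f()

random_sample_stream = stream_sketch(random.random)
# next(random_sample_stream) -> a float in [0, 1), e.g. for sampling-based instance skipping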
Example #6
 def __init__(self, pc: BK.ParamCollection, conf: M3EncConf, tconf,
              vpack: VocabPackage):
     super().__init__(pc, conf, tconf, vpack)
     #
     self.conf = conf
     # ----- bert
     # modify bert_conf for other input
     BERT_OTHER_VSIZE = 50  # todo(+N): this should be enough for small inputs!
     conf.bert_conf.bert2_other_input_names = conf.bert_other_inputs
     conf.bert_conf.bert2_other_input_vsizes = [BERT_OTHER_VSIZE] * len(
         conf.bert_other_inputs)
     self.berter = self.add_sub_node("bert", Berter2(pc, conf.bert_conf))
     # -----
     # index fake sent
     self.index_helper = IndexerHelper(vpack)
     # extra encoder over bert?
     self.bert_dim, self.bert_fold = self.berter.get_output_dims()
     conf.m3_enc_conf._input_dim = self.bert_dim
     self.m3_encs = [
         self.add_sub_node("m3e", MyEncoder(pc, conf.m3_enc_conf))
         for _ in range(self.bert_fold)
     ]
     self.m3_enc_out_dim = self.m3_encs[0].get_output_dims()[0]
     # skip m3_enc?
     self.m3_enc_is_empty = all(len(z.layers) == 0 for z in self.m3_encs)
     if self.m3_enc_is_empty:
         assert all(z.get_output_dims()[0] == self.bert_dim
                    for z in self.m3_encs)
         zlog("For m3_enc, we will skip it since it is empty!!")
     # dep as basic?
     if conf.m2e_use_basic_dep:
         MAX_LABEL_NUM = 200  # this should be enough
         self.dep_label_emb = self.add_sub_node(
             "dlab",
             Embedding(self.pc,
                       MAX_LABEL_NUM,
                       conf.dep_label_dim,
                       name="dlab"))
         self.dep_layer = self.add_sub_node(
             "dep",
             TaskSpecAdp(pc, [(self.m3_enc_out_dim, self.bert_fold), None],
                         [conf.dep_label_dim], conf.dep_output_dim))
     else:
         self.dep_label_emb = self.dep_layer = None
     self.dep_padder = DataPadder(
         2, pad_vals=0)  # 0 for both head-idx and label
Example #7
File: model.py Project: ValentinaPy/zmsp
 def __init__(self, pc: BK.ParamCollection, bconf: BTConf, tconf: 'BaseTrainingConf', vpack: VocabPackage):
     super().__init__(pc, None, None)
     self.bconf = bconf
     # ===== Vocab =====
     self.word_vocab = vpack.get_voc("word")
     self.char_vocab = vpack.get_voc("char")
     self.lemma_vocab = vpack.get_voc("lemma")
     self.upos_vocab = vpack.get_voc("upos")
     self.ulabel_vocab = vpack.get_voc("ulabel")
     # ===== Model =====
     # embedding
     self.emb = self.add_sub_node("emb", MyEmbedder(self.pc, bconf.emb_conf, vpack))
     emb_output_dim = self.emb.get_output_dims()[0]
     self.emb_output_dim = emb_output_dim
     # doc hint
     self.use_doc_hint = bconf.use_doc_hint
     self.dh_combine_method = bconf.dh_combine_method
     if self.use_doc_hint:
         assert len(bconf.emb_conf.dim_auxes) > 0
         # todo(note): currently use the concat of them if input multiple layers
         bconf.dh_conf._input_dim = bconf.emb_conf.dim_auxes[0]  # same as input bert dim
         bconf.dh_conf._output_dim = emb_output_dim  # same as emb_output_dim
         self.dh_node = self.add_sub_node("dh", DocHintModule(pc, bconf.dh_conf))
     else:
         self.dh_node = None
     # encoders
     # shared
     # todo(note): feed compute-on-the-fly hp
     bconf.enc_conf._input_dim = emb_output_dim
     self.enc = self.add_sub_node("enc", MyEncoder(self.pc, bconf.enc_conf))
     tmp_enc_output_dim = self.enc.get_output_dims()[0]
     # privates
     bconf.enc_ef_conf._input_dim = tmp_enc_output_dim
     self.enc_ef = self.add_sub_node("enc_ef", MyEncoder(self.pc, bconf.enc_ef_conf))
     self.enc_ef_output_dim = self.enc_ef.get_output_dims()[0]
     bconf.enc_evt_conf._input_dim = tmp_enc_output_dim
     self.enc_evt = self.add_sub_node("enc_evt", MyEncoder(self.pc, bconf.enc_evt_conf))
     self.enc_evt_output_dim = self.enc_evt.get_output_dims()[0]
     # ===== Input Specification =====
     # inputs (word, lemma, char, upos, ulabel) and vocabulary
     self.need_word = self.emb.has_word
     self.need_char = self.emb.has_char
     # extra fields
     # todo(warn): need to
     self.need_lemma = False
     self.need_upos = False
     self.need_ulabel = False
     for one_extra_name in self.emb.extra_names:
         if one_extra_name == "lemma":
             self.need_lemma = True
         elif one_extra_name == "upos":
             self.need_upos = True
         elif one_extra_name == "ulabel":
             self.need_ulabel = True
         else:
             raise NotImplementedError("UNK extra input name: " + one_extra_name)
     # todo(warn): currently only allow one aux field
     self.need_aux = False
     if len(self.emb.dim_auxes) > 0:
         assert len(self.emb.dim_auxes) == 1
         self.need_aux = True
     # padders
     self.word_padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2)
     self.char_padder = DataPadder(3, pad_lens=(0, 0, bconf.char_max_length), pad_vals=self.char_vocab.pad)
     self.lemma_padder = DataPadder(2, pad_vals=self.lemma_vocab.pad)
     self.upos_padder = DataPadder(2, pad_vals=self.upos_vocab.pad)
     self.ulabel_padder = DataPadder(2, pad_vals=self.ulabel_vocab.pad)
     #
     self.random_sample_stream = Random.stream(Random.random_sample)
     self.train_skip_noevt_rate = tconf.train_skip_noevt_rate
     self.train_skip_length = tconf.train_skip_length
     self.train_min_length = tconf.train_min_length
     self.test_min_length = tconf.test_min_length
     self.test_skip_noevt_rate = tconf.test_skip_noevt_rate
     self.train_sent_based = tconf.train_sent_based
     #
     assert not self.train_sent_based, "The basic model should not use this sent-level mode!"
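
Note: Example #7 stacks one shared encoder and two private ones (`enc_ef`, `enc_evt`, presumably the entity-filler and event branches) on top of the embedder. A stripped-down sketch of that shared-then-private layout in plain PyTorch (not the zmsp classes):

import torch.nn as nn

class SharedPrivateSketch(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.shared = nn.LSTM(dim, dim, batch_first=True)   # shared over the whole input
        self.enc_ef = nn.LSTM(dim, dim, batch_first=True)   # private branch 1
        self.enc_evt = nn.LSTM(dim, dim, batch_first=True)  # private branch 2

    def forward(self, emb):                    # emb: [batch, seq, dim]
        shared_out, _ = self.shared(emb)
        ef_out, _ = self.enc_ef(shared_out)    # task-specific view for one decoder
        evt_out, _ = self.enc_evt(shared_out)  # task-specific view for the other
        return ef_out, evt_out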
Example #8
File: mtl.py Project: ValentinaPy/zmsp
 def __init__(self, conf: MtlMlmModelConf, vpack: VocabPackage):
     super().__init__(conf)
     # for easier checking
     self.word_vocab = vpack.get_voc("word")
     # components
     self.embedder = self.add_node("emb", EmbedderNode(self.pc, conf.emb_conf, vpack))
     self.inputter = Inputter(self.embedder, vpack)  # not a node
     self.emb_out_dim = self.embedder.get_output_dims()[0]
     self.enc_attn_count = conf.default_attn_count
     if conf.enc_choice == "vrec":
         self.encoder = self.add_component("enc", VRecEncoder(self.pc, self.emb_out_dim, conf.venc_conf))
         self.enc_attn_count = self.encoder.attn_count
     elif conf.enc_choice == "original":
         conf.oenc_conf._input_dim = self.emb_out_dim
         self.encoder = self.add_node("enc", MyEncoder(self.pc, conf.oenc_conf))
     else:
         raise NotImplementedError()
     zlog(f"Finished building model's encoder {self.encoder}, all size is {self.encoder.count_allsize_parameters()}")
     self.enc_out_dim = self.encoder.get_output_dims()[0]
     # --
     conf.rprep_conf._rprep_vr_conf.matt_conf.head_count = self.enc_attn_count  # make head-count agree
     self.rpreper = self.add_node("rprep", RPrepNode(self.pc, self.enc_out_dim, conf.rprep_conf))
     # --
     self.lambda_agree = self.add_scheduled_value(ScheduledValue(f"agr:lambda", conf.lambda_agree))
     self.agree_loss_f = EntropyHelper.get_method(conf.agree_loss_f)
     # --
     self.masklm = self.add_component("mlm", MaskLMNode(self.pc, self.enc_out_dim, conf.mlm_conf, self.inputter))
     self.plainlm = self.add_component("plm", PlainLMNode(self.pc, self.enc_out_dim, conf.plm_conf, self.inputter))
     # todo(note): here we use attn as dim_pair, do not use pair if not using vrec!!
     self.orderpr = self.add_component("orp", OrderPredNode(
         self.pc, self.enc_out_dim, self.enc_attn_count, conf.orp_conf, self.inputter))
     # =====
     # pre-training pre-load point!!
     if conf.load_pretrain_model_name:
         zlog(f"At preload_pretrain point: Loading from {conf.load_pretrain_model_name}")
         self.pc.load(conf.load_pretrain_model_name, strict=False)
     # =====
     self.dpar = self.add_component("dpar", DparG1Decoder(
         self.pc, self.enc_out_dim, self.enc_attn_count, conf.dpar_conf, self.inputter))
     self.upos = self.add_component("upos", SeqLabNode(
         self.pc, "pos", self.enc_out_dim, self.conf.upos_conf, self.inputter))
     if conf.do_ner:
         if conf.ner_use_crf:
             self.ner = self.add_component("ner", SeqCrfNode(
                 self.pc, "ner", self.enc_out_dim, self.conf.ner_conf, self.inputter))
         else:
             self.ner = self.add_component("ner", SeqLabNode(
                 self.pc, "ner", self.enc_out_dim, self.conf.ner_conf, self.inputter))
     else:
         self.ner = None
     # for pairwise reprs (no trainable params here!)
     self.rel_dist_embed = self.add_node("oremb", PosiEmbedding2(self.pc, n_dim=self.enc_attn_count, max_val=100))
     self._prepr_f_attn_sum = lambda cache, rdist: BK.stack(cache.list_attn, 0).sum(0) if len(cache.list_attn) > 0 else None
     self._prepr_f_attn_avg = lambda cache, rdist: BK.stack(cache.list_attn, 0).mean(0) if len(cache.list_attn) > 0 else None
     self._prepr_f_attn_max = lambda cache, rdist: BK.stack(cache.list_attn, 0).max(0)[0] if len(cache.list_attn) > 0 else None
     self._prepr_f_attn_last = lambda cache, rdist: cache.list_attn[-1] if len(cache.list_attn) > 0 else None
     self._prepr_f_rdist = lambda cache, rdist: self._get_rel_dist_embed(rdist, False)
     self._prepr_f_rdist_abs = lambda cache, rdist: self._get_rel_dist_embed(rdist, True)
     self.prepr_f = getattr(self, "_prepr_f_"+conf.prepr_choice)  # shortcut
     # --
     self.testing_rand_gen = Random.create_sep_generator(conf.testing_rand_gen_seed)  # especial gen for testing
     # =====
     if conf.orp_loss_special:
         self.orderpr.add_node_special(self.masklm)
     # =====
     # extra one!!
     self.aug_word2 = self.aug_encoder = self.aug_mixturer = None
     if conf.aug_word2:
         self.aug_word2 = self.add_node("aug2", AugWord2Node(self.pc, conf.emb_conf, vpack,
                                                             "word2", conf.aug_word2_dim, self.emb_out_dim))
         if conf.aug_word2_aug_encoder:
             assert conf.enc_choice == "vrec"
             self.aug_detach_drop = self.add_node("dd", Dropout(self.pc, (self.enc_out_dim,), fix_rate=conf.aug_detach_dropout))
             self.aug_encoder = self.add_component("Aenc", VRecEncoder(self.pc, self.emb_out_dim, conf.venc_conf))
             self.aug_mixturer = self.add_node("Amix", BertFeaturesWeightLayer(self.pc, conf.aug_detach_numlayer))
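
Note: the `self.prepr_f = getattr(self, "_prepr_f_" + conf.prepr_choice)` line above is plain attribute dispatch: the config string selects one of the pre-bound callables. The same pattern in isolation (made-up names):

class DispatchSketch:
    def _combine_sum(self, xs):
        return sum(xs)

    def _combine_max(self, xs):
        return max(xs)

    def __init__(self, choice: str):
        # raises AttributeError immediately if the config names an unknown choice
        self.combine = getattr(self, "_combine_" + choice)

assert DispatchSketch("sum").combine([1, 2, 3]) == 6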
Example #9
def main():
    np.random.seed(1234)
    NUM_POS = 10
    # build vocabs
    reader = TextReader("./test_utils.py")
    vb_word = VocabBuilder("w")
    vb_char = VocabBuilder("c")
    for one in reader:
        vb_word.feed_stream(one.tokens)
        vb_char.feed_stream((c for w in one.tokens for c in w))
    voc_word = vb_word.finish()
    voc_char = vb_char.finish()
    voc_pos = VocabBuilder.build_from_stream(range(NUM_POS), name="pos")
    vpack = VocabPackage({
        "word": voc_word,
        "char": voc_char,
        "pos": voc_pos
    }, {"word": None})
    # build model
    pc = BK.ParamCollection()
    conf_emb = EmbedConf().init_from_kwargs(init_words_from_pretrain=False,
                                            dim_char=10,
                                            dim_posi=10,
                                            emb_proj_dim=400,
                                            dim_extras="50",
                                            extra_names="pos")
    conf_emb.do_validate()
    mod_emb = MyEmbedder(pc, conf_emb, vpack)
    conf_enc = EncConf().init_from_kwargs(enc_rnn_type="lstm2",
                                          enc_cnn_layer=1,
                                          enc_att_layer=1)
    conf_enc._input_dim = mod_emb.get_output_dims()[0]
    mod_enc = MyEncoder(pc, conf_enc)
    enc_output_dim = mod_enc.get_output_dims()[0]
    mod_scorer = BiAffineScorer(pc, enc_output_dim, enc_output_dim, 10)
    # build data
    word_padder = DataPadder(2, pad_lens=(0, 50), mask_range=2)
    char_padder = DataPadder(3, pad_lens=(0, 50, 20))
    word_idxes = []
    char_idxes = []
    pos_idxes = []
    for toks in reader:
        one_words = []
        one_chars = []
        for w in toks.tokens:
            one_words.append(voc_word.get_else_unk(w))
            one_chars.append([voc_char.get_else_unk(c) for c in w])
        word_idxes.append(one_words)
        char_idxes.append(one_chars)
        pos_idxes.append(
            np.random.randint(voc_pos.trg_len(), size=len(one_words)) +
            1)  # pred->trg
    word_arr, word_mask_arr = word_padder.pad(word_idxes)
    pos_arr, _ = word_padder.pad(pos_idxes)
    char_arr, _ = char_padder.pad(char_idxes)
    #
    # run
    rop = layers.RefreshOptions(hdrop=0.2, gdrop=0.2, fix_drop=True)
    for _ in range(5):
        mod_emb.refresh(rop)
        mod_enc.refresh(rop)
        mod_scorer.refresh(rop)
        #
        expr_emb = mod_emb(word_arr, char_arr, [pos_arr])
        zlog(BK.get_shape(expr_emb))
        expr_enc = mod_enc(expr_emb, word_mask_arr)
        zlog(BK.get_shape(expr_enc))
        #
        mask_expr = BK.input_real(word_mask_arr)
        score0 = mod_scorer.paired_score(expr_enc, expr_enc, mask_expr,
                                         mask_expr)
        score1 = mod_scorer.plain_score(expr_enc.unsqueeze(-2),
                                        expr_enc.unsqueeze(-3),
                                        mask_expr.unsqueeze(-1),
                                        mask_expr.unsqueeze(-2))
        #
        zmiss = float(BK.avg(score0 - score1))
        assert zmiss < 0.0001
    zlog("OK")
    pass
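
Note: Example #9 is a self-contained smoke test; presumably it is run with the usual entry-point guard:

if __name__ == "__main__":
    main()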