Example #1
 def expand(self, ags: List[BfsAgenda]):
     # flatten things out
     flattened_states = []
     for ag in ags:
         flattened_states.extend(ag.beam)
         flattened_states.extend(ag.gbeam)
     # read/write cache to change status
     cur_cache = self.cache
     # update reprs and scores
     EfState.set_running_bidxes(flattened_states)
     if cur_cache.step > 0:
         bidxes = [s.prev.running_bidx for s in flattened_states]
         # extend the previous cache to the current bsize (along dimension 0, the batch dim)
         cur_cache.arange_cache(bidxes)
         # update caches and scores
         # todo(+N): for certain modes, some calculations are not needed
         cur_cache.update_cache(flattened_states)
     cur_cache.update_step()
     # get new masks and final scores
     scoring_mask_ct = cur_cache.scoring_mask_ct
     for sidx, state in enumerate(flattened_states):
         state.update_cands_mask(
             scoring_mask_ct[sidx])  # inplace mask update
     scoring_mask_ct *= cur_cache.scoring_fixed_mask_ct  # apply the fixed masks
     scoring_mask_device = BK.to_device(scoring_mask_ct)
     cur_arc_scores = cur_cache.get_arc_scores(
         self.mw_arc) + Constants.REAL_PRAC_MIN * (1. - scoring_mask_device)
     # todo(+N): possible normalization for the scores
     return flattened_states, cur_arc_scores, scoring_mask_ct
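
The last step above implements additive masking: invalid candidates get a very large negative constant added to their scores, so they effectively vanish after a softmax. Below is a minimal standalone sketch of that pattern in plain PyTorch; the constant REAL_PRAC_MIN and the helper name are assumptions standing in for Constants.REAL_PRAC_MIN, not part of the original codebase.

 import torch

 REAL_PRAC_MIN = -1e8  # stand-in for Constants.REAL_PRAC_MIN

 def mask_scores(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
     # keep valid positions unchanged; push masked-out positions toward -inf
     return scores + REAL_PRAC_MIN * (1. - mask)

 scores = torch.randn(2, 5)  # [bsize, num_cands]
 mask = torch.tensor([[1., 1., 0., 1., 0.],
                      [1., 0., 1., 1., 1.]])
 probs = torch.softmax(mask_scores(scores, mask), dim=-1)  # masked candidates get ~0 probability
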
Example #2
 def __init__(self,
              pc: BK.ParamCollection,
              mod: BK.Module,
              output_dims=None):
     super().__init__(
         pc, f"{self.__class__.__name__}:{mod.__class__.__name__}", None)
     # -----
     self.mod = mod
     BK.to_device(self.mod)  # move to target device
     self.output_dims = output_dims
     # collect parameters
     prefix_name = self.pc.nnc_name(self.name, True) + "/"
     named_params = self.pc.param_add_external(prefix_name, mod)
     # add to self.params
     for one_name, one_param in named_params:
         assert one_name not in self.params
         self.params[one_name] = one_param
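
The wrapper above registers an external module's parameters under a prefixed name so the surrounding ParamCollection can track them. A rough equivalent of that bookkeeping in plain PyTorch might look like this (the ExternalModuleWrapper class and its prefix argument are hypothetical, not part of the BK API):

 import torch
 from torch import nn

 class ExternalModuleWrapper(nn.Module):
     def __init__(self, mod: nn.Module, prefix: str):
         super().__init__()
         self.mod = mod  # registering it as a sub-module also moves it with .to(device)
         self.params = {}
         for name, param in mod.named_parameters():
             full_name = prefix + name.replace(".", "/")
             assert full_name not in self.params  # no duplicate registrations
             self.params[full_name] = param

 wrapper = ExternalModuleWrapper(nn.Linear(16, 8), prefix="wrapped/")
 print(sorted(wrapper.params))  # ['wrapped/bias', 'wrapped/weight']
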
Example #3
 def init_cache(self, enc_repr, enc_mask_arr, insts, g1_pack):
     # init caches and scores, [orig_bsize, max_slen, D]
     self.enc_repr = enc_repr
     self.scoring_fixed_mask_ct = self._init_fixed_mask(enc_mask_arr)
     # init other masks
     self.scoring_mask_ct = BK.copy(self.scoring_fixed_mask_ct)
     full_shape = BK.get_shape(self.scoring_mask_ct)
     # init oracle masks
     oracle_mask_ct = BK.constants(full_shape,
                                   value=0.,
                                   device=BK.CPU_DEVICE)
     # label=0 means nothing, but it is still needed to avoid index errors (a dummy oracle for wrong/no-oracle states)
     oracle_label_ct = BK.constants(full_shape,
                                    value=0,
                                    dtype=BK.int64,
                                    device=BK.CPU_DEVICE)
     for i, inst in enumerate(insts):
         EfOracler.init_oracle_mask(inst, oracle_mask_ct[i],
                                    oracle_label_ct[i])
     self.oracle_mask_t = BK.to_device(oracle_mask_ct)
     self.oracle_mask_ct = oracle_mask_ct
     self.oracle_label_t = BK.to_device(oracle_label_ct)
     # scoring cache
     self.scoring_cache.init_cache(enc_repr, g1_pack)
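
init_cache fills the oracle tensors row by row on the CPU and only then ships them to the target device, keeping both copies around. A self-contained sketch of that CPU-fill-then-transfer pattern is below; the helper name, the gold-head encoding, and the [bsize, slen, slen] mask shape are assumptions for illustration.

 import torch

 def init_oracle_masks(gold_heads, max_slen, device="cpu"):
     # fill per-instance oracle arcs on CPU (cheap scalar assignment), then move once
     bsize = len(gold_heads)
     oracle_mask_c = torch.zeros(bsize, max_slen, max_slen)
     for i, heads in enumerate(gold_heads):
         for child, head in enumerate(heads):
             if head >= 0:  # -1 marks the root / no gold head in this sketch
                 oracle_mask_c[i, head, child] = 1.
     return oracle_mask_c.to(device), oracle_mask_c  # (device copy, CPU copy)

 oracle_mask_t, oracle_mask_ct = init_oracle_masks([[-1, 0, 1], [-1, 2, 0]], max_slen=3)
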
Example #4
 def arange_cache(self, bidxes):
     new_bsize = len(bidxes)
     # if the idxes are already fine, then no need to select
     if not Helper.check_is_range(bidxes, self.cur_bsize):
         # mask is on CPU to make assigning easier
         bidxes_ct = BK.input_idx(bidxes, BK.CPU_DEVICE)
         self.scoring_fixed_mask_ct = self.scoring_fixed_mask_ct.index_select(
             0, bidxes_ct)
         self.scoring_mask_ct = self.scoring_mask_ct.index_select(
             0, bidxes_ct)
         self.oracle_mask_ct = self.oracle_mask_ct.index_select(
             0, bidxes_ct)
         # other things are all on target-device (possibly GPU)
         bidxes_device = BK.to_device(bidxes_ct)
         self.enc_repr = self.enc_repr.index_select(0, bidxes_device)
         self.scoring_cache.arange_cache(bidxes_device)
         # oracles
         self.oracle_mask_t = self.oracle_mask_t.index_select(
             0, bidxes_device)
         self.oracle_label_t = self.oracle_label_t.index_select(
             0, bidxes_device)
         # update bsize
         self.update_bsize(new_bsize)
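
arange_cache realigns every cached tensor to a new batch ordering with index_select, skipping the work when the indices are already the identity range. A generic version of that pattern (the arange_cache helper and the dict-of-tensors cache below are assumptions, not the original class) could look like this:

 import torch

 def arange_cache(cache, bidxes, device="cpu"):
     any_tensor = next(iter(cache.values()))
     if bidxes == list(range(any_tensor.shape[0])):
         return cache  # already aligned, nothing to select
     idx_cpu = torch.tensor(bidxes, dtype=torch.long)
     idx_dev = idx_cpu.to(device)
     # select along dim 0 (the batch dim), keeping each tensor on its own device
     return {name: t.index_select(0, idx_cpu if t.device.type == "cpu" else idx_dev)
             for name, t in cache.items()}

 cache = {"enc_repr": torch.randn(4, 7, 16), "mask": torch.ones(4, 7)}
 cache = arange_cache(cache, [0, 0, 2, 3, 3])  # states 0 and 3 were each expanded twice
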
Example #5
 def __init__(self, pc: BK.ParamCollection, bconf: Berter2Conf):
     super().__init__(pc, None, None)
     self.bconf = bconf
     self.model_name = bconf.bert2_model
     zlog(
         f"Loading pre-trained bert model for Berter2 of {self.model_name}")
     # Load pretrained model/tokenizer
     self.tokenizer = BertTokenizer.from_pretrained(
         self.model_name,
         do_lower_case=bconf.bert2_lower_case,
         cache_dir=(bconf.bert2_cache_dir or None))
     self.model = BertModel.from_pretrained(
         self.model_name,
         output_hidden_states=True,
         cache_dir=(bconf.bert2_cache_dir or None))
     zlog(f"Load done, move to default device {BK.DEFAULT_DEVICE}")
     BK.to_device(self.model)
     # =====
     # zero padding embeddings?
     if bconf.bert2_zero_pademb:
         with BK.no_grad_env():
             # todo(warn): specific!!
             zlog("Unusual operation: make bert's padding embedding (idx0) zero!!")
             self.model.embeddings.word_embeddings.weight[0].fill_(0.)
     # =====
     # check trainable ones and add parameters
     # todo(+N): this part is specific and peeks into the lib internals; it can break in future versions!!
     # the layer indices are [1 (embed)] + [N (enc)], i.e., layer 0 is the output of the embeddings
     self.hidden_size = self.model.config.hidden_size
     self.num_bert_layers = len(
         self.model.encoder.layer) + 1  # +1 for embeddings
     self.output_layers = [
         i if i >= 0 else (self.num_bert_layers + i)
         for i in bconf.bert2_output_layers
     ]
     self.layer_is_output = [False] * self.num_bert_layers
     for i in self.output_layers:
         self.layer_is_output[i] = True
     # the highest used layer
     self.output_max_layer = max(
         self.output_layers) if len(self.output_layers) > 0 else -1
     # from max-layer down
     self.trainable_layers = list(range(self.output_max_layer, -1,
                                        -1))[:bconf.bert2_trainable_layers]
     # the lowest trainable layer
     self.trainable_min_layer = min(self.trainable_layers) if len(
         self.trainable_layers) > 0 else (self.output_max_layer + 1)
     zlog(f"Build Berter2: {self}")
     # add parameters
     prefix_name = self.pc.nnc_name(self.name, True) + "/"
     for layer_idx in self.trainable_layers:
         if layer_idx == 0:  # add the embedding layer
             infix_name = "embed"
             named_params = self.pc.param_add_external(
                 prefix_name + infix_name, self.model.embeddings)
         else:
             # here we need the original encoder index, i.e. layer_idx - 1 (layer 0 is the embeddings)
             infix_name = "enc" + str(layer_idx)
             named_params = self.pc.param_add_external(
                 prefix_name + infix_name,
                 self.model.encoder.layer[layer_idx - 1])
         # add to self.params
         for one_name, one_param in named_params:
             assert f"{infix_name}_{one_name}" not in self.params
             self.params[f"{infix_name}_{one_name}"] = one_param
     # for dropout/mask input
     self.random_sample_stream = Random.stream(Random.random_sample)
     # =====
     # for other inputs; todo(note): still, 0 means all-zero embedding
     self.other_embeds = [
         self.add_sub_node(
             "OE", Embedding(self.pc,
                             vsize,
                             self.hidden_size,
                             fix_row0=True))
         for vsize in bconf.bert2_other_input_vsizes
     ]
     # =====
     # for output
     if bconf.bert2_output_mode == "layered":
         self.output_f = lambda x: x
         self.output_dims = (
             self.hidden_size,
             len(self.output_layers),
         )
     elif bconf.bert2_output_mode == "concat":
         self.output_f = lambda x: x.view(BK.get_shape(x)[:-2] + [-1])  # combine the last two dims
         self.output_dims = (self.hidden_size * len(self.output_layers), )
     elif bconf.bert2_output_mode == "weighted":
         self.output_f = self.add_sub_node(
             "wb", BertFeaturesWeightLayer(pc, len(self.output_layers)))
         self.output_dims = (self.hidden_size, )
     else:
         raise NotImplementedError(
             f"UNK mode for bert2 output: {bconf.bert2_output_mode}")