def expand(self, ags: List[BfsAgenda]):
    # flatten things out
    flattened_states = []
    for ag in ags:
        flattened_states.extend(ag.beam)
        flattened_states.extend(ag.gbeam)
    # read/write cache to change status
    cur_cache = self.cache
    # update reprs and scores
    EfState.set_running_bidxes(flattened_states)
    if cur_cache.step > 0:
        bidxes = [s.prev.running_bidx for s in flattened_states]
        # extend previous cache to the current bsize (dimension 0 at batch-dim)
        cur_cache.arange_cache(bidxes)
    # update caches and scores
    # todo(+N): for certain modes, some calculations are not needed
    cur_cache.update_cache(flattened_states)
    cur_cache.update_step()
    # get new masks and final scores
    scoring_mask_ct = cur_cache.scoring_mask_ct
    for sidx, state in enumerate(flattened_states):
        state.update_cands_mask(scoring_mask_ct[sidx])  # in-place mask update
    scoring_mask_ct *= cur_cache.scoring_fixed_mask_ct  # apply the fixed masks
    scoring_mask_device = BK.to_device(scoring_mask_ct)
    cur_arc_scores = cur_cache.get_arc_scores(self.mw_arc) \
        + Constants.REAL_PRAC_MIN * (1. - scoring_mask_device)
    # todo(+N): possible normalization for the scores
    return flattened_states, cur_arc_scores, scoring_mask_ct
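# -----
# Illustration (not part of the module above): a minimal, standalone sketch of the
# masking trick used at the end of `expand`. Invalid candidates get an effectively
# -inf score by adding a very negative constant wherever mask==0, so they can never
# win an argmax/top-k. The constant name and shapes are illustrative assumptions.
import torch

REAL_PRAC_MIN = -1e12  # practical stand-in for -inf that stays finite in softmax

def mask_scores(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # scores: [bsize, slen, slen]; mask: same shape, 1.=valid, 0.=invalid
    return scores + REAL_PRAC_MIN * (1. - mask)

if __name__ == "__main__":
    s = torch.randn(2, 4, 4)
    m = torch.ones(2, 4, 4)
    m[:, :, 0] = 0.  # e.g., forbid attaching to position 0
    best = mask_scores(s, m).view(2, -1).argmax(-1)  # never picks a masked cell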
def __init__(self, pc: BK.ParamCollection, mod: BK.Module, output_dims=None):
    super().__init__(pc, f"{self.__class__.__name__}:{mod.__class__.__name__}", None)
    # -----
    self.mod = mod
    BK.to_device(self.mod)  # move to target device
    self.output_dims = output_dims
    # collect parameters
    prefix_name = self.pc.nnc_name(self.name, True) + "/"
    named_params = self.pc.param_add_external(prefix_name, mod)
    # add to self.params
    for one_name, one_param in named_params:
        assert one_name not in self.params
        self.params[one_name] = one_param
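# -----
# Illustration: a sketch of what "collect parameters from an external module" amounts
# to in plain PyTorch, assuming `param_add_external` behaves roughly like iterating
# `named_parameters()` with a name prefix. This is an assumption about the BK backend;
# the real helper may additionally register the params with the ParamCollection.
import torch

def collect_external_params(prefix: str, mod: torch.nn.Module):
    # yield (prefixed_name, parameter) pairs for every trainable parameter of mod
    for name, param in mod.named_parameters():
        yield prefix + name.replace(".", "/"), param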
def init_cache(self, enc_repr, enc_mask_arr, insts, g1_pack):
    # init caches and scores, [orig_bsize, max_slen, D]
    self.enc_repr = enc_repr
    self.scoring_fixed_mask_ct = self._init_fixed_mask(enc_mask_arr)
    # init other masks
    self.scoring_mask_ct = BK.copy(self.scoring_fixed_mask_ct)
    full_shape = BK.get_shape(self.scoring_mask_ct)
    # init oracle masks
    oracle_mask_ct = BK.constants(full_shape, value=0., device=BK.CPU_DEVICE)
    # label=0 means nothing, but it is still needed to avoid index errors
    # (a dummy oracle for wrong/no-oracle states)
    oracle_label_ct = BK.constants(full_shape, value=0, dtype=BK.int64, device=BK.CPU_DEVICE)
    for i, inst in enumerate(insts):
        EfOracler.init_oracle_mask(inst, oracle_mask_ct[i], oracle_label_ct[i])
    self.oracle_mask_t = BK.to_device(oracle_mask_ct)
    self.oracle_mask_ct = oracle_mask_ct
    self.oracle_label_t = BK.to_device(oracle_label_ct)
    # scoring cache
    self.scoring_cache.init_cache(enc_repr, g1_pack)
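# -----
# Design note: the oracle masks are built on CPU first (the `*_ct` tensors) because
# per-instance, in-place row assignment is cheap there, and only then copied to the
# compute device (the `*_t` tensors) for scoring. A minimal sketch of the same
# pattern in plain PyTorch; shapes and the row-filling step are hypothetical:
import torch

def build_masks(bsize: int, slen: int, device: torch.device):
    mask_cpu = torch.zeros(bsize, slen, slen)  # CPU copy: easy row-wise writes
    for i in range(bsize):
        mask_cpu[i, 0, :] = 1.  # stand-in for per-instance oracle-mask init
    # keep both: the CPU copy for later updates, the device copy for scoring
    return mask_cpu, mask_cpu.to(device)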
def arange_cache(self, bidxes):
    new_bsize = len(bidxes)
    # if the idxes are already fine, then no need to select
    if not Helper.check_is_range(bidxes, self.cur_bsize):
        # mask is on CPU to make assigning easier
        bidxes_ct = BK.input_idx(bidxes, BK.CPU_DEVICE)
        self.scoring_fixed_mask_ct = self.scoring_fixed_mask_ct.index_select(0, bidxes_ct)
        self.scoring_mask_ct = self.scoring_mask_ct.index_select(0, bidxes_ct)
        self.oracle_mask_ct = self.oracle_mask_ct.index_select(0, bidxes_ct)
        # other things are all on target-device (possibly GPU)
        bidxes_device = BK.to_device(bidxes_ct)
        self.enc_repr = self.enc_repr.index_select(0, bidxes_device)
        self.scoring_cache.arange_cache(bidxes_device)
        # oracles
        self.oracle_mask_t = self.oracle_mask_t.index_select(0, bidxes_device)
        self.oracle_label_t = self.oracle_label_t.index_select(0, bidxes_device)
    # update bsize
    self.update_bsize(new_bsize)
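# -----
# Illustration: the core of `arange_cache` is the standard beam-search step of
# reordering the batch dimension. Each surviving state records the batch row of its
# parent (`running_bidx`), and every cached tensor is gathered along dim 0 with those
# indices. A standalone torch sketch under that reading (names are illustrative); as
# above, the gather can be skipped entirely when bidxes is already 0..bsize-1:
import torch

def reorder_batch(cache: dict, bidxes: list, device: torch.device) -> dict:
    idx = torch.as_tensor(bidxes, dtype=torch.long, device=device)
    return {name: t.index_select(0, idx) for name, t in cache.items()}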
def __init__(self, pc: BK.ParamCollection, bconf: Berter2Conf):
    super().__init__(pc, None, None)
    self.bconf = bconf
    self.model_name = bconf.bert2_model
    zlog(f"Loading pre-trained bert model for Berter2 of {self.model_name}")
    # load pretrained model/tokenizer
    self.tokenizer = BertTokenizer.from_pretrained(
        self.model_name, do_lower_case=bconf.bert2_lower_case,
        cache_dir=None if (not bconf.bert2_cache_dir) else bconf.bert2_cache_dir)
    self.model = BertModel.from_pretrained(
        self.model_name, output_hidden_states=True,
        cache_dir=None if (not bconf.bert2_cache_dir) else bconf.bert2_cache_dir)
    zlog(f"Load done, move to default device {BK.DEFAULT_DEVICE}")
    BK.to_device(self.model)
    # =====
    # zero padding embeddings?
    if bconf.bert2_zero_pademb:
        with BK.no_grad_env():
            # todo(warn): specific!!
            zlog(f"Unusual operation: make bert's padding embedding (idx0) zero!!")
            self.model.embeddings.word_embeddings.weight[0].fill_(0.)
    # =====
    # check trainable ones and add parameters
    # todo(+N): this part is specific, looks into the lib, and can break in future versions!!
    # the layer idxes are [1(embed)] + [N(enc)], that is, layer0 is the output of embeddings
    self.hidden_size = self.model.config.hidden_size
    self.num_bert_layers = len(self.model.encoder.layer) + 1  # +1 for embeddings
    self.output_layers = [i if i >= 0 else (self.num_bert_layers + i)
                          for i in bconf.bert2_output_layers]
    self.layer_is_output = [False] * self.num_bert_layers
    for i in self.output_layers:
        self.layer_is_output[i] = True
    # the highest used layer
    self.output_max_layer = max(self.output_layers) if len(self.output_layers) > 0 else -1
    # from max-layer down
    self.trainable_layers = list(range(self.output_max_layer, -1, -1))[:bconf.bert2_trainable_layers]
    # the lowest trainable layer
    self.trainable_min_layer = min(self.trainable_layers) if len(self.trainable_layers) > 0 \
        else (self.output_max_layer + 1)
    zlog(f"Build Berter2: {self}")
    # add parameters
    prefix_name = self.pc.nnc_name(self.name, True) + "/"
    for layer_idx in self.trainable_layers:
        if layer_idx == 0:  # add the embedding layer
            infix_name = "embed"
            named_params = self.pc.param_add_external(prefix_name + infix_name, self.model.embeddings)
        else:  # here we should use the original (-1) index
            infix_name = "enc" + str(layer_idx)
            named_params = self.pc.param_add_external(
                prefix_name + infix_name, self.model.encoder.layer[layer_idx - 1])
        # add to self.params
        for one_name, one_param in named_params:
            assert f"{infix_name}_{one_name}" not in self.params
            self.params[f"{infix_name}_{one_name}"] = one_param
    # for dropout/mask input
    self.random_sample_stream = Random.stream(Random.random_sample)
    # =====
    # for other inputs; todo(note): still, 0 means all-zero embedding
    self.other_embeds = [
        self.add_sub_node("OE", Embedding(self.pc, vsize, self.hidden_size, fix_row0=True))
        for vsize in bconf.bert2_other_input_vsizes]
    # =====
    # for output
    if bconf.bert2_output_mode == "layered":
        self.output_f = lambda x: x
        self.output_dims = (self.hidden_size, len(self.output_layers),)
    elif bconf.bert2_output_mode == "concat":
        self.output_f = lambda x: x.view(BK.get_shape(x)[:-2] + [-1])  # combine the last two dims
        self.output_dims = (self.hidden_size * len(self.output_layers),)
    elif bconf.bert2_output_mode == "weighted":
        self.output_f = self.add_sub_node("wb", BertFeaturesWeightLayer(pc, len(self.output_layers)))
        self.output_dims = (self.hidden_size,)
    else:
        raise NotImplementedError(f"UNK mode for bert2 output: {bconf.bert2_output_mode}")
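# -----
# Illustration: a minimal sketch of the transformers-side contract Berter2 relies on.
# With `output_hidden_states=True`, the model returns one hidden state per encoder
# layer *plus* the embedding output, so there are `num_hidden_layers + 1` of them and
# index 0 is the embedding layer, matching `num_bert_layers` above. Exact return
# types vary across transformers versions; this sketch follows the recent API and
# downloads pretrained weights on first run.
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", output_hidden_states=True)
inputs = tokenizer("a short example", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
hidden_states = outputs.hidden_states  # tuple of length config.num_hidden_layers + 1
embeddings_out = hidden_states[0]      # "layer 0" = embedding output, as in Berter2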