def __init__(self,
             trn: str = None,
             dev: str = None,
             tst: str = None,
             sampler_builder: SamplerBuilder = None,
             dependencies: str = None,
             scalar_mix: ScalarMixWithDropoutBuilder = None,
             use_raw_hidden_states=False,
             lr=2e-3,
             separate_optimizer=False,
             punct=False,
             tree=True,
             pad_rel=None,
             apply_constraint=False,
             single_root=True,
             no_zero_head=None,
             n_mlp_arc=500,
             n_mlp_rel=100,
             mlp_dropout=.33,
             mu=.9,
             nu=.9,
             epsilon=1e-12,
             decay=.75,
             decay_steps=5000,
             cls_is_bos=True,
             use_pos=False,
             **kwargs) -> None:
    super().__init__(**merge_locals_kwargs(locals(), kwargs))
    self.vocabs = VocabDict()
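# Hedged sketch (not the library's model code): n_mlp_arc/n_mlp_rel above look like the
# arc/relation MLP sizes of a biaffine (Dozat & Manning 2017 style) dependency parser.
# The class below only illustrates how an arc scorer of that shape is typically wired;
# all names and dimensions here are assumptions for illustration.
import torch
from torch import nn

class BiaffineArcScorerSketch(nn.Module):
    def __init__(self, hidden_size=768, n_mlp_arc=500, mlp_dropout=0.33):
        super().__init__()
        self.mlp_d = nn.Sequential(nn.Linear(hidden_size, n_mlp_arc), nn.ReLU(), nn.Dropout(mlp_dropout))
        self.mlp_h = nn.Sequential(nn.Linear(hidden_size, n_mlp_arc), nn.ReLU(), nn.Dropout(mlp_dropout))
        self.weight = nn.Parameter(torch.zeros(n_mlp_arc + 1, n_mlp_arc))  # +1 row for a bias term

    def forward(self, hidden):                         # hidden: [batch, seq, hidden_size]
        d = self.mlp_d(hidden)                         # dependent representations
        h = self.mlp_h(hidden)                         # head representations
        d = torch.cat([d, torch.ones_like(d[..., :1])], dim=-1)
        return d @ self.weight @ h.transpose(-1, -2)   # [batch, seq (dep), seq (head)] arc scores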
def __init__(self,
             trn: str = None,
             dev: str = None,
             tst: str = None,
             sampler_builder: SamplerBuilder = None,
             dependencies: str = None,
             scalar_mix: ScalarMixWithDropoutBuilder = None,
             use_raw_hidden_states=False,
             lr=None,
             separate_optimizer=False,
             cls_is_bos=True,
             sep_is_eos=True,
             delete=('', ':', '``', "''", '.', '?', '!', '-NONE-', 'TOP', ',', 'S1'),
             equal=(('ADVP', 'PRT'),),
             mbr=True,
             n_mlp_span=500,
             n_mlp_label=100,
             mlp_dropout=.33,
             no_subcategory=True,
             **kwargs) -> None:
    if isinstance(equal, tuple):
        equal = dict(equal)
    super().__init__(**merge_locals_kwargs(locals(), kwargs))
    self.vocabs = VocabDict()
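# Illustrative sketch (not part of the original source): `delete` and `equal` above mirror
# EVALB-style constituency evaluation conventions (labels to ignore and label pairs treated
# as equivalent). `equal` defaults to a tuple of pairs so it stays serializable in the config,
# and `dict(equal)` turns it into the actual mapping; this just demonstrates the conversion.
equal = (('ADVP', 'PRT'),)
if isinstance(equal, tuple):
    equal = dict(equal)
assert equal == {'ADVP': 'PRT'}  # 'PRT' is treated as equivalent to 'ADVP'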
def __init__(self,
             trn: str = None,
             dev: str = None,
             tst: str = None,
             sampler_builder: SamplerBuilder = None,
             dependencies: str = None,
             scalar_mix: ScalarMixWithDropoutBuilder = None,
             use_raw_hidden_states=False,
             lr=1e-3,
             separate_optimizer=False,
             lexical_dropout=0.5,
             dropout=0.2,
             span_width_feature_size=20,
             ffnn_size=150,
             ffnn_depth=2,
             argument_ratio=0.8,
             predicate_ratio=0.4,
             max_arg_width=30,
             mlp_label_size=100,
             enforce_srl_constraint=False,
             use_gold_predicates=False,
             doc_level_offset=True,
             use_biaffine=False,
             loss_reduction='mean',
             with_argument=' ',
             **kwargs) -> None:
    super().__init__(**merge_locals_kwargs(locals(), kwargs))
    self.vocabs = VocabDict()
def __init__(self,
             trn: str = None,
             dev: str = None,
             tst: str = None,
             sampler_builder: SamplerBuilder = None,
             dependencies: str = None,
             scalar_mix: ScalarMixWithDropoutBuilder = None,
             use_raw_hidden_states=False,
             lr=1e-3,
             separate_optimizer=False,
             cls_is_bos=False,
             sep_is_eos=False,
             delimiter=None,
             max_seq_len=None,
             sent_delimiter=None,
             char_level=False,
             hard_constraint=False,
             token_key='token',
             **kwargs) -> None:
    super().__init__(**merge_locals_kwargs(locals(), kwargs))
    self.vocabs = VocabDict()
def __init__(self,
             trn: str = None,
             dev: str = None,
             tst: str = None,
             sampler_builder: SamplerBuilder = None,
             dependencies: str = None,
             scalar_mix: ScalarMixWithDropoutBuilder = None,
             use_raw_hidden_states=False,
             lr=1e-3,
             separate_optimizer=False,
             cls_is_bos=True,
             sep_is_eos=False,
             char2concept_dim=128,
             cnn_filters=((3, 256),),
             concept_char_dim=32,
             concept_dim=300,
             dropout=0.2,
             embed_dim=512,
             eval_every=20,
             ff_embed_dim=1024,
             graph_layers=2,
             inference_layers=4,
             num_heads=8,
             rel_dim=100,
             snt_layers=4,
             unk_rate=0.33,
             vocab_min_freq=5,
             beam_size=8,
             alpha=0.6,
             max_time_step=100,
             amr_version='2.0',
             **kwargs) -> None:
    super().__init__(**merge_locals_kwargs(locals(), kwargs))
    self.vocabs = VocabDict()
    utils_dir = get_resource(get_amr_utils(amr_version))
    self.sense_restore = NodeRestore(NodeUtilities.from_json(utils_dir))
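# Hedged sketch (an assumption, not the library's implementation): the constructors above all
# funnel their arguments through merge_locals_kwargs(locals(), kwargs) so that every
# hyperparameter ends up in the component's config. A minimal stand-in behaving that way
# might look like this; _merge_locals_kwargs_sketch is a hypothetical name.
def _merge_locals_kwargs_sketch(locals_: dict, kwargs: dict,
                                excludes=('self', 'kwargs', '__class__')) -> dict:
    merged = {k: v for k, v in locals_.items() if k not in excludes}
    merged.update(kwargs)  # explicit kwargs win over captured locals
    return merged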
def batchify(data, vocabs: VocabDict, unk_rate=0., device=None, squeeze=False,
             tokenizer: TransformerSequenceTokenizer = None, shuffle_sibling=True,
             levi_graph=False, extra_arc=False, bart=False):
    rel_vocab: VocabWithFrequency = vocabs.rel
    _tok = list_to_tensor(data['token'], vocabs['token'], unk_rate=unk_rate) if 'token' in vocabs else None
    _lem = list_to_tensor(data['lemma'], vocabs['lemma'], unk_rate=unk_rate)
    _pos = list_to_tensor(data['pos'], vocabs['pos'], unk_rate=unk_rate) if 'pos' in vocabs else None
    _ner = list_to_tensor(data['ner'], vocabs['ner'], unk_rate=unk_rate) if 'ner' in vocabs else None
    _word_char = lists_of_string_to_tensor(data['token'], vocabs['word_char']) if 'word_char' in vocabs else None
    local_token2idx = data['token2idx']
    local_idx2token = data['idx2token']
    _cp_seq = list_to_tensor(data['cp_seq'], vocabs['predictable_concept'], local_token2idx)
    _mp_seq = list_to_tensor(data['mp_seq'], vocabs['predictable_concept'], local_token2idx)
    ret = copy(data)
    if 'amr' in data:
        concept, edge = [], []
        for amr in data['amr']:
            if levi_graph == 'kahn':
                concept_i, edge_i = amr.to_levi(rel_vocab.get_frequency, shuffle=shuffle_sibling)
            else:
                concept_i, edge_i, _ = amr.root_centered_sort(rel_vocab.get_frequency, shuffle=shuffle_sibling)
            concept.append(concept_i)
            edge.append(edge_i)
        if levi_graph is True:
            concept_with_rel, edge_with_rel = levi_amr(concept, edge, extra_arc=extra_arc)
            concept = concept_with_rel
            edge = edge_with_rel
        augmented_concept = [[DUM] + x + [END] for x in concept]
        _concept_in = list_to_tensor(augmented_concept, vocabs.get('concept_and_rel', vocabs['concept']),
                                     unk_rate=unk_rate)[:-1]
        _concept_char_in = lists_of_string_to_tensor(augmented_concept, vocabs['concept_char'])[:-1]
        _concept_out = list_to_tensor(augmented_concept, vocabs['predictable_concept'], local_token2idx)[1:]
        out_conc_len, bsz = _concept_out.shape
        _rel = np.full((1 + out_conc_len, bsz, out_conc_len), rel_vocab.pad_idx)
        # v: [<dummy>, concept_0, ..., concept_l, ..., concept_{n-1}, <end>]
        # u: [<dummy>, concept_0, ..., concept_l, ..., concept_{n-1}]
        for bidx, (x, y) in enumerate(zip(edge, concept)):
            for l, _ in enumerate(y):
                if l > 0:
                    # l=1 => pos=l+1=2
                    _rel[l + 1, bidx, 1:l + 1] = rel_vocab.get_idx(NIL)
            for v, u, r in x:
                if levi_graph:
                    r = 1
                else:
                    r = rel_vocab.get_idx(r)
                assert v > u, 'Invalid topological order'
                _rel[v + 1, bidx, u + 1] = r
        ret.update({
            'concept_in': _concept_in,
            'concept_char_in': _concept_char_in,
            'concept_out': _concept_out,
            'rel': _rel
        })
    else:
        augmented_concept = None
    token_length = ret.get('token_length', None)
    if token_length is not None and not isinstance(token_length, torch.Tensor):
        ret['token_length'] = torch.tensor(
            token_length, dtype=torch.long,
            device=device if (isinstance(device, torch.device) or device >= 0) else 'cpu:0')
    ret.update({
        'lem': _lem,
        'tok': _tok,
        'pos': _pos,
        'ner': _ner,
        'word_char': _word_char,
        'copy_seq': np.stack([_cp_seq, _mp_seq], -1),
        'local_token2idx': local_token2idx,
        'local_idx2token': local_idx2token
    })
    if squeeze:
        token_field = make_batch_for_squeeze(data, augmented_concept, tokenizer, device, ret)
    else:
        token_field = 'token'
    subtoken_to_tensor(token_field, ret)
    if bart:
        make_batch_for_bart(augmented_concept, ret, tokenizer, device)
    move_dict_to_device(ret, device)
    return ret
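# Hedged sketch of the relation-target layout that batchify builds above, using toy indices
# rather than the real vocabularies (pad=0, nil=1 and the relation ids are assumed values).
# Row v+1 of _rel stores, per batch item, the relation from concept v to every earlier
# concept u (u < v); positions that can never hold an edge stay at the pad index.
import numpy as np

pad_idx, nil_idx = 0, 1                       # assumed ids for <pad> and NIL
concepts = ['want-01', 'boy', 'go-02']        # one toy graph, batch size 1
edges = [(1, 0, 2), (2, 0, 3)]                # (v, u, relation id), always v > u

n = len(concepts)
# rows: [<dummy>, concept_0, ..., concept_{n-1}, <end>]  -> n + 2
# cols: [<dummy>, concept_0, ..., concept_{n-1}]         -> n + 1
rel = np.full((n + 2, 1, n + 1), pad_idx)
for v in range(1, n):
    rel[v + 1, 0, 1:v + 1] = nil_idx          # default: "no relation" to earlier concepts
for v, u, r in edges:
    assert v > u                              # topological order, as asserted in batchify
    rel[v + 1, 0, u + 1] = r                  # overwrite with the gold relation id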
def load_vocabs(self, save_dir, filename='vocabs.json'):
    if hasattr(self, 'vocabs'):
        self.vocabs = VocabDict()
        self.vocabs.load_vocabs(save_dir, filename, VocabWithFrequency)
class GraphSequenceAbstractMeaningRepresentationParser(TorchComponent): def __init__(self, **kwargs) -> None: """ An AMR parser implementing Cai and Lam (2020) and my unpublished models. Args: **kwargs: """ super().__init__(**kwargs) self.model: GraphSequenceAbstractMeaningRepresentationModel = self.model def build_optimizer(self, trn, epochs, lr, adam_epsilon, weight_decay, warmup_steps, transformer_lr, gradient_accumulation, **kwargs): model = self.model if self.config.squeeze and False: num_training_steps = len(trn) * epochs // gradient_accumulation optimizer, scheduler = build_optimizer_scheduler_with_transformer( model, model.bert_encoder, lr, transformer_lr, num_training_steps, warmup_steps, weight_decay, adam_epsilon) else: weight_decay_params = [] no_weight_decay_params = [] no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] for name, param in model.named_parameters(): if name.endswith('bias') or 'layer_norm' in name or any( nd in name for nd in no_decay): no_weight_decay_params.append(param) else: weight_decay_params.append(param) grouped_params = [{ 'params': weight_decay_params, 'weight_decay': weight_decay }, { 'params': no_weight_decay_params, 'weight_decay': 0. }] optimizer = AdamWeightDecayOptimizer(grouped_params, lr, betas=(0.9, 0.999), eps=adam_epsilon) lr_scale = self.config.lr_scale embed_dim = self.config.embed_dim scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda steps: lr_scale * embed_dim**-0.5 * min( (steps + 1)**-0.5, (steps + 1) * (warmup_steps**-1.5))) return optimizer, scheduler def build_criterion(self, **kwargs): pass def build_metric(self, **kwargs): pass def execute_training_loop(self, trn: PrefetchDataLoader, dev: PrefetchDataLoader, epochs, criterion, optimizer, metric, save_dir, logger: logging.Logger, devices, ratio_width=None, dev_data=None, gradient_accumulation=1, **kwargs): best_epoch, best_metric = 0, -1 timer = CountdownTimer(epochs) history = History() try: for epoch in range(1, epochs + 1): logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]") trn = self.fit_dataloader( trn, criterion, optimizer, metric, logger, ratio_width=ratio_width, gradient_accumulation=gradient_accumulation, history=history, save_dir=save_dir) report = f'{timer.elapsed_human}/{timer.total_time_human}' if epoch % self.config.eval_every == 0 or epoch == epochs: metric = self.evaluate_dataloader(dev, logger, dev_data, ratio_width=ratio_width, save_dir=save_dir, use_fast=True) if metric > best_metric: self.save_weights(save_dir) best_metric = metric best_epoch = epoch report += ' [red]saved[/red]' timer.log(report, ratio_percentage=False, newline=True, ratio=False) if best_epoch and best_epoch != epochs: logger.info( f'Restored the best model with {best_metric} saved {epochs - best_epoch} epochs ago' ) self.load_weights(save_dir) finally: trn.close() dev.close() def fit_dataloader(self, trn: PrefetchDataLoader, criterion, optimizer, metric, logger: logging.Logger, gradient_accumulation=1, ratio_width=None, history=None, save_dir=None, **kwargs): self.model.train() num_training_steps = len( trn) * self.config.epochs // gradient_accumulation shuffle_sibling_steps = self.config.shuffle_sibling_steps if isinstance(shuffle_sibling_steps, float): shuffle_sibling_steps = int(shuffle_sibling_steps * num_training_steps) timer = CountdownTimer( len([ i for i in range(history.num_mini_batches + 1, history.num_mini_batches + len(trn) + 1) if i % gradient_accumulation == 0 ])) total_loss = 0 optimizer, scheduler = optimizer correct_conc, total_conc, correct_rel, 
total_rel = 0, 0, 0, 0 for idx, batch in enumerate(trn): loss = self.compute_loss(batch) if self.config.joint_arc_concept or self.model.squeeze or self.config.bart: loss, (concept_correct, concept_total), rel_out = loss correct_conc += concept_correct total_conc += concept_total if rel_out is not None: rel_correct, rel_total = rel_out correct_rel += rel_correct total_rel += rel_total loss /= gradient_accumulation # loss = loss.sum() # For data parallel loss.backward() total_loss += loss.item() history.num_mini_batches += 1 if history.num_mini_batches % gradient_accumulation == 0: self._step(optimizer, scheduler) metric = '' if self.config.joint_arc_concept or self.model.squeeze or self.model.bart: metric = f' Concept acc: {correct_conc / total_conc:.2%}' if not self.config.levi_graph: metric += f' Relation acc: {correct_rel / total_rel:.2%}' timer.log( f'loss: {total_loss / (timer.current + 1):.4f} lr: {optimizer.param_groups[0]["lr"]:.2e}' + metric, ratio_percentage=None, ratio_width=ratio_width, logger=logger) if history.num_mini_batches // gradient_accumulation == shuffle_sibling_steps: trn.batchify = self.build_batchify(self.device, shuffle=True, shuffle_sibling=False) timer.print( f'Switched to [bold]deterministic order[/bold] after {shuffle_sibling_steps} steps', newline=True) del loss return trn def _step(self, optimizer, scheduler): if self.config.grad_norm: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm) optimizer.step() # model = self.model # print(mean_model(model)) optimizer.zero_grad() scheduler.step() def update_metrics(self, batch: dict, prediction: Union[Dict, List], metrics): if isinstance(prediction, dict): prediction = prediction['prediction'] assert len(prediction) == len(batch['ner']) for pred, gold in zip(prediction, batch['ner']): metrics(set(pred), set(gold)) def compute_loss(self, batch): # debug # gold = torch.load('/home/hhe43/amr_gs/batch.pt', map_location=self.device) # self.debug_assert_batch_equal(batch, gold) # set_seed() # end debug concept_loss, arc_loss, rel_loss, graph_arc_loss = self.model(batch) if self.config.joint_arc_concept or self.config.squeeze or self.config.bart: concept_loss, concept_correct, concept_total = concept_loss if rel_loss is not None: rel_loss, rel_correct, rel_total = rel_loss loss = concept_loss + arc_loss + rel_loss rel_acc = (rel_correct, rel_total) else: loss = concept_loss + arc_loss rel_acc = None return loss, (concept_correct, concept_total), rel_acc loss = concept_loss + arc_loss + rel_loss return loss def debug_assert_batch_equal(self, batch, gold): # assert torch.equal(batch['token_input_ids'], gold['bert_token']) for k, v in gold.items(): pred = batch.get(k, None) if pred is not None: if isinstance(v, torch.Tensor) and not torch.equal(pred, v): assert torch.equal(pred, v), f'{k} not equal' @torch.no_grad() def evaluate_dataloader(self, data: PrefetchDataLoader, logger, input, output=False, ratio_width=None, save_dir=None, use_fast=False, test=False, **kwargs): self.model.eval() pp = PostProcessor(self.vocabs['rel']) if not output: output = os.path.join(save_dir, os.path.basename(input) + '.pred') # Squeeze tokens and concepts into one transformer basically reduces the max num of inputs it can handle parse_data(self.model, pp, data, input, output, max_time_step=80 if self.model.squeeze else 100) # noinspection PyBroadException try: output = post_process(output, amr_version=self.config.get( 'amr_version', '2.0')) scores = smatch_eval(output, input.replace('.features.preproc', ''), 
use_fast=use_fast) except Exception: eprint(f'Evaluation failed due to the following error:') traceback.print_exc() eprint( 'As smatch usually fails on erroneous outputs produced at early epochs, ' 'it might be OK to ignore it. Now `nan` will be returned as the score.' ) scores = F1_(float("nan"), float("nan"), float("nan")) if logger: header = f'{len(data)}/{len(data)}' if not ratio_width: ratio_width = len(header) logger.info(header.rjust(ratio_width) + f' {scores}') if test: data.close() return scores def build_model(self, training=True, **kwargs) -> torch.nn.Module: transformer = self.config.encoder.module() model = GraphSequenceAbstractMeaningRepresentationModel( self.vocabs, **merge_dict(self.config, overwrite=True, encoder=transformer), tokenizer=self.config.encoder.transform()) # self.model = model # self.debug_load() return model def debug_load(self): model = self.model states = torch.load('/home/hhe43/amr_gs/model.pt', map_location=self.device) model.load_state_dict(states, strict=False) def build_dataloader(self, data, batch_size, shuffle=False, device=None, logger: logging.Logger = None, gradient_accumulation=1, batch_max_tokens=None, **kwargs) -> DataLoader: dataset, lens = self.build_dataset(data, logger, training=shuffle) if batch_max_tokens: batch_max_tokens //= gradient_accumulation if not shuffle: batch_max_tokens //= 2 sampler = SortingSampler(lens, batch_size=None, batch_max_tokens=batch_max_tokens, shuffle=shuffle) dataloader = PrefetchDataLoader( DataLoader(batch_sampler=sampler, dataset=dataset, collate_fn=merge_list_of_dict, num_workers=0), batchify=self.build_batchify(device, shuffle)) return dataloader def build_batchify(self, device, shuffle, shuffle_sibling=None): if shuffle_sibling is None: shuffle_sibling = shuffle return functools.partial( batchify, vocabs=self.vocabs, squeeze=self.config.get('squeeze', None), tokenizer=self.config.encoder.transform(), levi_graph=self.config.get('levi_graph', False), bart=self.config.get('bart', False), extra_arc=self.config.get('extra_arc', False), unk_rate=self.config.unk_rate if shuffle else 0, shuffle_sibling=shuffle_sibling, device=device) def build_dataset(self, data, logger: logging.Logger = None, training=True): dataset = AbstractMeaningRepresentationDataset( data, generate_idx=not training) if self.vocabs.mutable: self.build_vocabs(dataset, logger) self.vocabs.lock() self.vocabs.summary(logger) lens = [len(x['token']) + len(x['amr']) for x in dataset] dataset.append_transform( functools.partial(get_concepts, vocab=self.vocabs.predictable_concept, rel_vocab=self.vocabs.rel if self.config.get( 'separate_rel', False) else None)) dataset.append_transform(append_bos) # Tokenization will happen in batchify if not self.config.get('squeeze', None): dataset.append_transform(self.config.encoder.transform()) if isinstance(data, str): dataset.purge_cache() timer = CountdownTimer(len(dataset)) for each in dataset: timer.log( 'Caching samples [blink][yellow]...[/yellow][/blink]') return dataset, lens def build_vocabs(self, dataset, logger: logging.Logger = None, **kwargs): # debug # self.load_vocabs('/home/hhe43/elit/data/model/amr2.0/convert/') # return # collect concepts and relations conc = [] rel = [] predictable_conc = [ ] # concepts that are not able to generate by copying lemmas ('multi-sentence', 'sense-01') tokens = [] lemmas = [] poses = [] ners = [] repeat = 10 levi_graph = self.config.get('levi_graph', False) separate_rel = self.config.separate_rel timer = CountdownTimer(repeat * len(dataset)) for i in range(repeat): # 
run 10 times random sort to get the priorities of different types of edges for sample in dataset: amr, lem, tok, pos, ner = sample['amr'], sample[ 'lemma'], sample['token'], sample['pos'], sample['ner'] if levi_graph == 'kahn': concept, edge = amr.to_levi() else: concept, edge, not_ok = amr.root_centered_sort() if levi_graph is True: concept, edge = linearize(concept, edge, NIL, prefix=REL) lexical_concepts = set() for lemma in lem: lexical_concepts.add(lemma + '_') lexical_concepts.add(lemma) if i == 0: if separate_rel: edge = [(c, ) for c in concept if c.startswith(REL)] concept = [c for c in concept if not c.startswith(REL)] predictable_conc.append( [c for c in concept if c not in lexical_concepts]) conc.append(concept) tokens.append(tok) lemmas.append(lem) poses.append(pos) ners.append(ner) rel.append([e[-1] for e in edge]) timer.log( 'Building vocabs [blink][yellow]...[/yellow][/blink]') # make vocabularies token_vocab, token_char_vocab = make_vocab(tokens, char_level=True) lemma_vocab, lemma_char_vocab = make_vocab(lemmas, char_level=True) pos_vocab = make_vocab(poses) ner_vocab = make_vocab(ners) conc_vocab, conc_char_vocab = make_vocab(conc, char_level=True) predictable_conc_vocab = make_vocab(predictable_conc) num_predictable_conc = sum(len(x) for x in predictable_conc) num_conc = sum(len(x) for x in conc) rel_vocab = make_vocab(rel) logger.info( f'Predictable concept coverage {num_predictable_conc} / {num_conc} = {num_predictable_conc / num_conc:.2%}' ) vocabs = self.vocabs vocab_min_freq = self.config.get('vocab_min_freq', 5) vocabs.token = VocabWithFrequency(token_vocab, vocab_min_freq, specials=[CLS]) vocabs.lemma = VocabWithFrequency(lemma_vocab, vocab_min_freq, specials=[CLS]) vocabs.pos = VocabWithFrequency(pos_vocab, vocab_min_freq, specials=[CLS]) vocabs.ner = VocabWithFrequency(ner_vocab, vocab_min_freq, specials=[CLS]) vocabs.predictable_concept = VocabWithFrequency(predictable_conc_vocab, vocab_min_freq, specials=[DUM, END]) vocabs.concept = VocabWithFrequency(conc_vocab, vocab_min_freq, specials=[DUM, END]) vocabs.rel = VocabWithFrequency(rel_vocab, vocab_min_freq * 10, specials=[NIL]) vocabs.word_char = VocabWithFrequency(token_char_vocab, vocab_min_freq * 20, specials=[CLS, END]) vocabs.concept_char = VocabWithFrequency(conc_char_vocab, vocab_min_freq * 20, specials=[CLS, END]) if separate_rel: vocabs.concept_and_rel = VocabWithFrequency( conc_vocab + rel_vocab, vocab_min_freq, specials=[DUM, END, NIL]) # if levi_graph: # # max = 993 # tokenizer = self.config.encoder.transform() # rel_to_unused = dict() # for i, rel in enumerate(vocabs.rel.idx_to_token): # rel_to_unused[rel] = f'[unused{i + 100}]' # tokenizer.rel_to_unused = rel_to_unused def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs): pass def fit(self, trn_data, dev_data, save_dir, encoder, batch_size=None, batch_max_tokens=17776, epochs=1000, gradient_accumulation=4, char2concept_dim=128, char2word_dim=128, cnn_filters=((3, 256), ), concept_char_dim=32, concept_dim=300, dropout=0.2, embed_dim=512, eval_every=20, ff_embed_dim=1024, graph_layers=2, inference_layers=4, lr_scale=1.0, ner_dim=16, num_heads=8, pos_dim=32, pretrained_file=None, rel_dim=100, snt_layers=4, start_rank=0, unk_rate=0.33, warmup_steps=2000, with_bert=True, word_char_dim=32, word_dim=300, lr=1., transformer_lr=None, adam_epsilon=1e-6, weight_decay=1e-4, grad_norm=1.0, joint_arc_concept=False, joint_rel=False, external_biaffine=False, optimize_every_layer=False, squeeze=False, levi_graph=False, 
separate_rel=False, extra_arc=False, bart=False, shuffle_sibling_steps=50000, vocab_min_freq=5, amr_version='2.0', devices=None, logger=None, seed=None, **kwargs): return super().fit(**merge_locals_kwargs(locals(), kwargs)) def load_vocabs(self, save_dir, filename='vocabs.json'): if hasattr(self, 'vocabs'): self.vocabs = VocabDict() self.vocabs.load_vocabs(save_dir, filename, VocabWithFrequency)
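# Hedged sketch of the inverse-square-root ("Noam") schedule that build_optimizer above
# attaches via LambdaLR when the transformer-specific optimizer path is not used. The dummy
# parameter and the lr_scale / embed_dim / warmup_steps values are illustrative only.
import torch

lr_scale, embed_dim, warmup_steps = 1.0, 512, 2000
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1.0)   # base lr of 1.0 so the lambda value is the lr itself
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda steps: lr_scale * embed_dim ** -0.5 * min((steps + 1) ** -0.5,
                                                     (steps + 1) * warmup_steps ** -1.5))
for step in range(5):
    optimizer.step()
    scheduler.step()
# The lr rises roughly linearly for warmup_steps steps, then decays proportionally to 1/sqrt(step).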
class TorchComponent(Component, ABC): def __init__(self, **kwargs) -> None: super().__init__() self.model: Optional[torch.nn.Module] = None self.config = SerializableDict(**kwargs) self.vocabs = VocabDict() def _capture_config(self, locals_: Dict, exclude=('trn_data', 'dev_data', 'save_dir', 'kwargs', 'self', 'logger', 'verbose', 'dev_batch_size', '__class__', 'devices', 'eval_trn')): """Save arguments to config Args: locals_: Dict: exclude: (Default value = ('trn_data') 'dev_data': 'save_dir': 'kwargs': 'self': 'logger': 'verbose': 'dev_batch_size': '__class__': 'devices'): Returns: """ if 'kwargs' in locals_: locals_.update(locals_['kwargs']) locals_ = dict((k, v) for k, v in locals_.items() if k not in exclude and not k.startswith('_')) self.config.update(locals_) return self.config def save_weights(self, save_dir, filename='model.pt', trainable_only=True, **kwargs): model = self.model_ state_dict = model.state_dict() if trainable_only: trainable_names = set(n for n, p in model.named_parameters() if p.requires_grad) state_dict = dict( (n, p) for n, p in state_dict.items() if n in trainable_names) torch.save(state_dict, os.path.join(save_dir, filename)) def load_weights(self, save_dir, filename='model.pt', **kwargs): save_dir = get_resource(save_dir) filename = os.path.join(save_dir, filename) # flash(f'Loading model: {filename} [blink]...[/blink][/yellow]') self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False) # flash('') def save_config(self, save_dir, filename='config.json'): self._savable_config.save_json(os.path.join(save_dir, filename)) def load_config(self, save_dir, filename='config.json', **kwargs): save_dir = get_resource(save_dir) self.config.load_json(os.path.join(save_dir, filename)) self.config.update(kwargs) # overwrite config loaded from disk for k, v in self.config.items(): if isinstance(v, dict) and 'classpath' in v: self.config[k] = Configurable.from_config(v) self.on_config_ready(**self.config) def save_vocabs(self, save_dir, filename='vocabs.json'): if hasattr(self, 'vocabs'): self.vocabs.save_vocabs(save_dir, filename) def load_vocabs(self, save_dir, filename='vocabs.json'): if hasattr(self, 'vocabs'): self.vocabs = VocabDict() self.vocabs.load_vocabs(save_dir, filename) def save(self, save_dir: str, **kwargs): self.save_config(save_dir) self.save_vocabs(save_dir) self.save_weights(save_dir) def load(self, save_dir: str, devices=None, **kwargs): save_dir = get_resource(save_dir) # flash('Loading config and vocabs [blink][yellow]...[/yellow][/blink]') if devices is None and self.model: devices = self.devices self.load_config(save_dir, **kwargs) self.load_vocabs(save_dir) flash('Building model [blink][yellow]...[/yellow][/blink]') self.model = self.build_model(**merge_dict(self.config, training=False, **kwargs, overwrite=True, inplace=True)) flash('') self.load_weights(save_dir, **kwargs) self.to(devices) self.model.eval() def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, devices=None, logger=None, seed=None, finetune=False, eval_trn=True, _device_placeholder=False, **kwargs): # Common initialization steps config = self._capture_config(locals()) if not logger: logger = self.build_logger('train', save_dir) if not seed: self.config.seed = 233 if isdebugging() else int(time.time()) set_seed(self.config.seed) logger.info(self._savable_config.to_json(sort=True)) if isinstance(devices, list) or devices is None or isinstance( devices, float): flash('[yellow]Querying CUDA devices [blink]...[/blink][/yellow]') devices = -1 if 
isdebugging() else cuda_devices(devices) flash('') # flash(f'Available GPUs: {devices}') if isinstance(devices, list): first_device = (devices[0] if devices else -1) elif isinstance(devices, dict): first_device = next(iter(devices.values())) elif isinstance(devices, int): first_device = devices else: first_device = -1 if _device_placeholder and first_device >= 0: _dummy_placeholder = self._create_dummy_placeholder_on( first_device) if finetune: if isinstance(finetune, str): self.load(finetune, devices=devices) else: self.load(save_dir, devices=devices) logger.info( f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}' f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.' ) self.on_config_ready(**self.config) trn = self.build_dataloader(**merge_dict(config, data=trn_data, batch_size=batch_size, shuffle=True, training=True, device=first_device, logger=logger, vocabs=self.vocabs, overwrite=True)) dev = self.build_dataloader( **merge_dict(config, data=dev_data, batch_size=batch_size, shuffle=False, training=None, device=first_device, logger=logger, vocabs=self.vocabs, overwrite=True)) if dev_data else None if not finetune: flash('[yellow]Building model [blink]...[/blink][/yellow]') self.model = self.build_model(**merge_dict(config, training=True)) flash('') logger.info( f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}' f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.' ) assert self.model, 'build_model is not properly implemented.' _description = repr(self.model) if len(_description.split('\n')) < 10: logger.info(_description) self.save_config(save_dir) self.save_vocabs(save_dir) self.to(devices, logger) if _device_placeholder and first_device >= 0: del _dummy_placeholder criterion = self.build_criterion(**merge_dict(config, trn=trn)) optimizer = self.build_optimizer( **merge_dict(config, trn=trn, criterion=criterion)) metric = self.build_metric(**self.config) if hasattr(trn.dataset, '__len__') and dev and hasattr( dev.dataset, '__len__'): logger.info( f'{len(trn.dataset)}/{len(dev.dataset)} samples in trn/dev set.' 
) trn_size = len(trn) // self.config.get('gradient_accumulation', 1) ratio_width = len(f'{trn_size}/{trn_size}') else: ratio_width = None return self.execute_training_loop(**merge_dict(config, trn=trn, dev=dev, epochs=epochs, criterion=criterion, optimizer=optimizer, metric=metric, logger=logger, save_dir=save_dir, devices=devices, ratio_width=ratio_width, trn_data=trn_data, dev_data=dev_data, eval_trn=eval_trn, overwrite=True)) def build_logger(self, name, save_dir): logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO, fmt="%(message)s") return logger @abstractmethod def build_dataloader(self, data, batch_size, shuffle=False, device=None, logger: logging.Logger = None, **kwargs) -> DataLoader: pass def build_vocabs(self, **kwargs): pass @property def _savable_config(self): def convert(k, v): if hasattr(v, 'config'): v = v.config if isinstance(v, (set, tuple)): v = list(v) return k, v config = SerializableDict( convert(k, v) for k, v in sorted(self.config.items())) config.update({ # 'create_time': now_datetime(), 'classpath': classpath_of(self), 'elit_version': elit.__version__, }) return config @abstractmethod def build_optimizer(self, **kwargs): pass @abstractmethod def build_criterion(self, decoder, **kwargs): pass @abstractmethod def build_metric(self, **kwargs): pass @abstractmethod def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir, logger: logging.Logger, devices, ratio_width=None, **kwargs): pass @abstractmethod def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, **kwargs): pass @abstractmethod def evaluate_dataloader(self, data: DataLoader, criterion: Callable, metric=None, output=False, **kwargs): pass @abstractmethod def build_model(self, training=True, **kwargs) -> torch.nn.Module: raise NotImplementedError def evaluate(self, tst_data, save_dir=None, logger: logging.Logger = None, batch_size=None, output=False, **kwargs): if not self.model: raise RuntimeError('Call fit or load before evaluate.') if isinstance(tst_data, str): tst_data = get_resource(tst_data) filename = os.path.basename(tst_data) else: filename = None if output is True: output = self.generate_prediction_filename( tst_data if isinstance(tst_data, str) else 'test.txt', save_dir) if logger is None: _logger_name = basename_no_ext(filename) if filename else None logger = self.build_logger(_logger_name, save_dir) if not batch_size: batch_size = self.config.get('batch_size', 32) data = self.build_dataloader(**merge_dict(self.config, data=tst_data, batch_size=batch_size, shuffle=False, device=self.devices[0], logger=logger, overwrite=True)) dataset = data while dataset and hasattr(dataset, 'dataset'): dataset = dataset.dataset num_samples = len(dataset) if dataset else None if output and isinstance(dataset, TransformDataset): def add_idx(samples): for idx, sample in enumerate(samples): if sample: sample[IDX] = idx add_idx(dataset.data) if dataset.cache: add_idx(dataset.cache) criterion = self.build_criterion(**self.config) metric = self.build_metric(**self.config) start = time.time() outputs = self.evaluate_dataloader(data, criterion=criterion, filename=filename, output=output, input=tst_data, save_dir=save_dir, test=True, num_samples=num_samples, **merge_dict(self.config, batch_size=batch_size, metric=metric, logger=logger, **kwargs)) elapsed = time.time() - start if logger: if num_samples: logger.info( f'speed: {num_samples / elapsed:.0f} samples/second') else: logger.info(f'speed: {len(data) / 
elapsed:.0f} batches/second') return metric, outputs def generate_prediction_filename(self, tst_data, save_dir): assert isinstance( tst_data, str), 'tst_data has be a str in order to infer the output name' output = os.path.splitext(os.path.basename(tst_data)) output = os.path.join(save_dir, output[0] + '.pred' + output[1]) return output def to(self, devices=Union[int, float, List[int], Dict[str, Union[int, torch.device]]], logger: logging.Logger = None): if devices == -1 or devices == [-1]: devices = [] elif isinstance(devices, (int, float)) or devices is None: devices = cuda_devices(devices) if devices: if logger: logger.info( f'Using GPUs: [on_blue][cyan][bold]{devices}[/bold][/cyan][/on_blue]' ) if isinstance(devices, list): flash( f'Moving model to GPUs {devices} [blink][yellow]...[/yellow][/blink]' ) self.model = self.model.to(devices[0]) if len(devices) > 1 and not isdebugging() and not isinstance( self.model, nn.DataParallel): self.model = self.parallelize(devices) elif isinstance(devices, dict): for name, module in self.model.named_modules(): for regex, device in devices.items(): try: on_device: torch.device = next( module.parameters()).device except StopIteration: continue if on_device == device: continue if isinstance(device, int): if on_device.index == device: continue if re.match(regex, name): if not name: name = '*' flash( f'Moving module [yellow]{name}[/yellow] to [on_yellow][magenta][bold]{device}' f'[/bold][/magenta][/on_yellow]: [red]{regex}[/red]\n' ) module.to(device) else: raise ValueError(f'Unrecognized devices {devices}') flash('') else: if logger: logger.info('Using CPU') def parallelize(self, devices: List[Union[int, torch.device]]): return nn.DataParallel(self.model, device_ids=devices) @property def devices(self): if self.model is None: return None # next(parser.model.parameters()).device if hasattr(self.model, 'device_ids'): return self.model.device_ids device: torch.device = next(self.model.parameters()).device return [device] @property def device(self): devices = self.devices if not devices: return None return devices[0] def on_config_ready(self, **kwargs): pass @property def model_(self) -> nn.Module: """ The actual model when it's wrapped by a `DataParallel` Returns: The "real" model """ if isinstance(self.model, nn.DataParallel): return self.model.module return self.model # noinspection PyMethodOverriding @abstractmethod def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs): pass def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]: batch = merge_list_of_dict(samples) return batch @staticmethod def _create_dummy_placeholder_on(device): if device < 0: device = 'cpu:0' return torch.zeros(16, 16, device=device) @torch.no_grad() def __call__(self, data, batch_size=None, **kwargs): return super().__call__( data, **merge_dict(self.config, overwrite=True, batch_size=batch_size or self.config.get('batch_size', None), **kwargs))
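# Hedged sketch of the trainable_only filtering used by TorchComponent.save_weights above:
# only parameters with requires_grad end up in the checkpoint, so frozen weights (for example,
# a frozen encoder) are not duplicated on disk. The tiny model and path here are illustrative.
import os
import tempfile
import torch
from torch import nn

model = nn.Linear(4, 2)
model.bias.requires_grad_(False)              # pretend this part of the model is frozen
trainable_names = set(n for n, p in model.named_parameters() if p.requires_grad)
state_dict = {n: p for n, p in model.state_dict().items() if n in trainable_names}
torch.save(state_dict, os.path.join(tempfile.gettempdir(), 'model.pt'))
# Loading back with strict=False, as load_weights does above, tolerates the missing frozen keys.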
class GraphAbstractMeaningRepresentationParser(TorchComponent): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.model: GraphAbstractMeaningRepresentationModel = self.model self.sense_restore: NodeRestore = None def build_optimizer(self, trn, epochs, lr, adam_epsilon, weight_decay, warmup_steps, transformer_lr, gradient_accumulation, **kwargs): model = self.model num_training_steps = len(trn) * epochs // gradient_accumulation optimizer, scheduler = build_optimizer_scheduler_with_transformer( model, model.bert_encoder, lr, transformer_lr, num_training_steps, warmup_steps, weight_decay, adam_epsilon) return optimizer, scheduler def build_criterion(self, **kwargs): pass def build_metric(self, **kwargs): pass def execute_training_loop(self, trn: PrefetchDataLoader, dev: PrefetchDataLoader, epochs, criterion, optimizer, metric, save_dir, logger: logging.Logger, devices, ratio_width=None, dev_data=None, gradient_accumulation=1, **kwargs): best_epoch, best_metric = 0, -1 timer = CountdownTimer(epochs) history = History() try: for epoch in range(1, epochs + 1): logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]") trn = self.fit_dataloader( trn, criterion, optimizer, metric, logger, ratio_width=ratio_width, gradient_accumulation=gradient_accumulation, history=history, save_dir=save_dir) report = f'{timer.elapsed_human}/{timer.total_time_human}' if epoch % self.config.eval_every == 0 or epoch == epochs: metric = self.evaluate_dataloader(dev, logger, dev_data, ratio_width=ratio_width, save_dir=save_dir, use_fast=True) if metric > best_metric: self.save_weights(save_dir) best_metric = metric best_epoch = epoch report += ' [red]saved[/red]' timer.log(report, ratio_percentage=False, newline=True, ratio=False) if best_epoch and best_epoch != epochs: logger.info( f'Restored the best model with {best_metric} saved {epochs - best_epoch} epochs ago' ) self.load_weights(save_dir) finally: trn.close() dev.close() def fit_dataloader(self, trn: PrefetchDataLoader, criterion, optimizer, metric, logger: logging.Logger, gradient_accumulation=1, ratio_width=None, history=None, save_dir=None, **kwargs): self.model.train() num_training_steps = len( trn) * self.config.epochs // gradient_accumulation shuffle_sibling_steps = self.config.shuffle_sibling_steps if isinstance(shuffle_sibling_steps, float): shuffle_sibling_steps = int(shuffle_sibling_steps * num_training_steps) timer = CountdownTimer( len([ i for i in range(history.num_mini_batches + 1, history.num_mini_batches + len(trn) + 1) if i % gradient_accumulation == 0 ])) total_loss = 0 optimizer, scheduler = optimizer correct_conc, total_conc, correct_rel, total_rel = 0, 0, 0, 0 for idx, batch in enumerate(trn): loss = self.compute_loss(batch) loss, (concept_correct, concept_total), rel_out = loss correct_conc += concept_correct total_conc += concept_total if rel_out is not None: rel_correct, rel_total = rel_out correct_rel += rel_correct total_rel += rel_total loss /= gradient_accumulation # loss = loss.sum() # For data parallel loss.backward() total_loss += loss.item() history.num_mini_batches += 1 if history.num_mini_batches % gradient_accumulation == 0: self._step(optimizer, scheduler) metric = f' Concept acc: {correct_conc / total_conc:.2%}' metric += f' Relation acc: {correct_rel / total_rel:.2%}' timer.log( f'loss: {total_loss / (timer.current + 1):.4f} lr: {optimizer.param_groups[0]["lr"]:.2e}' + metric, ratio_percentage=None, ratio_width=ratio_width, logger=logger) if history.num_mini_batches // gradient_accumulation == 
shuffle_sibling_steps: trn.batchify = self.build_batchify(self.device, shuffle=True, shuffle_sibling=False) timer.print( f'Switched to [bold]deterministic order[/bold] after {shuffle_sibling_steps} steps', newline=True) del loss return trn def _step(self, optimizer, scheduler): if self.config.grad_norm: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm) optimizer.step() optimizer.zero_grad() scheduler.step() def update_metrics(self, batch: dict, prediction: Union[Dict, List], metrics): if isinstance(prediction, dict): prediction = prediction['prediction'] assert len(prediction) == len(batch['ner']) for pred, gold in zip(prediction, batch['ner']): metrics(set(pred), set(gold)) def compute_loss(self, batch): # debug # gold = torch.load('/home/hhe43/amr_gs/batch.pt', map_location=self.device) # self.debug_assert_batch_equal(batch, gold) # set_seed() # end debug concept_loss, arc_loss, rel_loss, graph_arc_loss = self.model(batch) concept_loss, concept_correct, concept_total = concept_loss if rel_loss is not None: rel_loss, rel_correct, rel_total = rel_loss loss = concept_loss + arc_loss + rel_loss rel_acc = (rel_correct, rel_total) else: loss = concept_loss + arc_loss rel_acc = None return loss, (concept_correct, concept_total), rel_acc def debug_assert_batch_equal(self, batch, gold): # assert torch.equal(batch['token_input_ids'], gold['bert_token']) for k, v in gold.items(): pred = batch.get(k, None) if pred is not None: if isinstance(v, torch.Tensor) and not torch.equal(pred, v): assert torch.equal(pred, v), f'{k} not equal' @torch.no_grad() def evaluate_dataloader(self, data: PrefetchDataLoader, logger, input, output=False, ratio_width=None, save_dir=None, use_fast=False, test=False, metric: SmatchScores = None, model=None, h=None, **kwargs): if not model: model = self.model model.eval() pp = PostProcessor(self.vocabs['rel']) if not save_dir: save_dir = tempdir(str(os.getpid())) if not output: output = os.path.join(save_dir, os.path.basename(input) + '.pred') # Squeeze tokens and concepts into one transformer basically reduces the max num of inputs it can handle parse_data(model, pp, data, input, output, max_time_step=80 if model.squeeze else 100, h=h) # noinspection PyBroadException try: output = post_process(output, amr_version=self.config.get( 'amr_version', '2.0')) scores = smatch_eval(output, input.replace('.features.preproc', ''), use_fast=use_fast) if metric: metric.clear() if isinstance(scores, F1_): metric['Smatch'] = scores else: metric.update(scores) except Exception: eprint(f'Evaluation failed due to the following error:') traceback.print_exc() eprint( 'As smatch usually fails on erroneous outputs produced at early epochs, ' 'it might be OK to ignore it. Now `nan` will be returned as the score.' 
) scores = F1_(float("nan"), float("nan"), float("nan")) if metric: metric.clear() metric['Smatch'] = scores if logger: header = f'{len(data)}/{len(data)}' if not ratio_width: ratio_width = len(header) logger.info(header.rjust(ratio_width) + f' {scores}') if test: data.close() return scores def build_model(self, training=True, **kwargs) -> torch.nn.Module: transformer = self.config.encoder.module() model = GraphAbstractMeaningRepresentationModel( self.vocabs, **merge_dict(self.config, overwrite=True, encoder=transformer), tokenizer=self.config.encoder.transform()) return model def build_dataloader(self, data, batch_size, shuffle=False, device=None, logger: logging.Logger = None, gradient_accumulation=1, batch_max_tokens=None, **kwargs) -> DataLoader: dataset, lens = self.build_dataset( data, logger, training=shuffle, transform=self.config.encoder.transform()) if batch_max_tokens: batch_max_tokens //= gradient_accumulation if not shuffle: batch_max_tokens //= 2 sampler = SortingSampler(lens, batch_size=None, batch_max_tokens=batch_max_tokens, shuffle=shuffle) dataloader = PrefetchDataLoader( DataLoader(batch_sampler=sampler, dataset=dataset, collate_fn=merge_list_of_dict, num_workers=0), batchify=self.build_batchify(device, shuffle), prefetch=10 if isinstance(data, str) else None) return dataloader def build_batchify(self, device, shuffle, shuffle_sibling=None): if shuffle_sibling is None: shuffle_sibling = shuffle tokenizer = self.config.encoder.transform() if self.config.get( 'encoder', None) else None return functools.partial( batchify, vocabs=self.vocabs, squeeze=self.config.get('squeeze', None), tokenizer=tokenizer, levi_graph=self.config.get('levi_graph', False), bart=self.config.get('bart', False), extra_arc=self.config.get('extra_arc', False), unk_rate=self.config.unk_rate if shuffle else 0, shuffle_sibling=shuffle_sibling, device=device) def build_dataset(self, data, logger: logging.Logger = None, training=True, transform=None): dataset = AbstractMeaningRepresentationDataset( data, generate_idx=not training) if self.vocabs.mutable: self.build_vocabs(dataset, logger) self.vocabs.lock() self.vocabs.summary(logger) lens = [ len(x['token']) + len(x['amr'] if 'amr' in x else []) for x in dataset ] dataset.append_transform( functools.partial(get_concepts, vocab=self.vocabs.predictable_concept, rel_vocab=self.vocabs.rel if self.config.get( 'separate_rel', False) else None)) dataset.append_transform(append_bos) if transform: dataset.append_transform(transform) if isinstance(data, str): dataset.purge_cache() timer = CountdownTimer(len(dataset)) for each in dataset: timer.log( 'Caching samples [blink][yellow]...[/yellow][/blink]') return dataset, lens def build_vocabs(self, dataset, logger: logging.Logger = None, **kwargs): # debug # self.load_vocabs('/home/hhe43/elit/data/model/amr2.0/convert/') # return # collect concepts and relations conc = [] rel = [] predictable_conc = [ ] # concepts that are not able to generate by copying lemmas ('multi-sentence', 'sense-01') tokens = [] lemmas = [] poses = [] ners = [] repeat = 10 timer = CountdownTimer(repeat * len(dataset)) for i in range(repeat): # run 10 times random sort to get the priorities of different types of edges for sample in dataset: amr, lem, tok, pos, ner = sample['amr'], sample[ 'lemma'], sample['token'], sample['pos'], sample['ner'] concept, edge, not_ok = amr.root_centered_sort() lexical_concepts = set() for lemma in lem: lexical_concepts.add(lemma + '_') lexical_concepts.add(lemma) if i == 0: predictable_conc.append( [c for c in 
concept if c not in lexical_concepts]) conc.append(concept) tokens.append(tok) lemmas.append(lem) poses.append(pos) ners.append(ner) rel.append([e[-1] for e in edge]) timer.log( 'Building vocabs [blink][yellow]...[/yellow][/blink]') # make vocabularies lemma_vocab, lemma_char_vocab = make_vocab(lemmas, char_level=True) conc_vocab, conc_char_vocab = make_vocab(conc, char_level=True) predictable_conc_vocab = make_vocab(predictable_conc) num_predictable_conc = sum(len(x) for x in predictable_conc) num_conc = sum(len(x) for x in conc) rel_vocab = make_vocab(rel) logger.info( f'Predictable concept coverage {num_predictable_conc} / {num_conc} = {num_predictable_conc / num_conc:.2%}' ) vocabs = self.vocabs vocab_min_freq = self.config.get('vocab_min_freq', 5) vocabs.lemma = VocabWithFrequency(lemma_vocab, vocab_min_freq, specials=[CLS]) vocabs.predictable_concept = VocabWithFrequency(predictable_conc_vocab, vocab_min_freq, specials=[DUM, END]) vocabs.concept = VocabWithFrequency(conc_vocab, vocab_min_freq, specials=[DUM, END]) vocabs.rel = VocabWithFrequency(rel_vocab, vocab_min_freq * 10, specials=[NIL]) vocabs.concept_char = VocabWithFrequency(conc_char_vocab, vocab_min_freq * 20, specials=[CLS, END]) def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs): if not data: return [] flat = self.input_is_flat(data) if flat: data = [data] samples = self.build_samples(data) dataloader = self.build_dataloader(samples, device=self.device, **merge_dict(self.config, overwrite=True, batch_size=batch_size)) pp = PostProcessor(self.vocabs['rel']) results = list(parse_data_(self.model, pp, dataloader)) for i, each in enumerate(results): amr_graph = AMRGraph(each) self.sense_restore.restore_graph(amr_graph) results[i] = amr_graph if flat: return results[0] return results def input_is_flat(self, data: List): return isinstance(data[0], tuple) def build_samples(self, data): samples = [] for each in data: token, lemma = zip(*each) samples.append({'token': list(token), 'lemma': list(lemma)}) return samples def fit(self, trn_data, dev_data, save_dir, encoder, batch_size=None, batch_max_tokens=17776, epochs=1000, gradient_accumulation=4, char2concept_dim=128, cnn_filters=((3, 256), ), concept_char_dim=32, concept_dim=300, dropout=0.2, embed_dim=512, eval_every=20, ff_embed_dim=1024, graph_layers=2, inference_layers=4, num_heads=8, rel_dim=100, snt_layers=4, unk_rate=0.33, warmup_steps=0.1, lr=1e-3, transformer_lr=1e-4, adam_epsilon=1e-6, weight_decay=0, grad_norm=1.0, shuffle_sibling_steps=0.9, vocab_min_freq=5, amr_version='2.0', devices=None, logger=None, seed=None, **kwargs): return super().fit(**merge_locals_kwargs(locals(), kwargs)) def load_vocabs(self, save_dir, filename='vocabs.json'): if hasattr(self, 'vocabs'): self.vocabs = VocabDict() self.vocabs.load_vocabs(save_dir, filename, VocabWithFrequency) def on_config_ready(self, **kwargs): super().on_config_ready(**kwargs) utils_dir = get_resource(get_amr_utils(self.config.amr_version)) self.sense_restore = NodeRestore(NodeUtilities.from_json(utils_dir))
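# Hedged sketch of the input format GraphAbstractMeaningRepresentationParser.predict expects,
# based on build_samples/input_is_flat above: each sentence is a sequence of (token, lemma)
# pairs, and a single flat sentence is wrapped into a batch of one. The sentence and the
# helper name below are made up for illustration.
sentence = [('The', 'the'), ('boy', 'boy'), ('wants', 'want'), ('to', 'to'), ('go', 'go')]

def _build_samples_sketch(data):
    samples = []
    for each in data:
        token, lemma = zip(*each)
        samples.append({'token': list(token), 'lemma': list(lemma)})
    return samples

batch = _build_samples_sketch([sentence])
# -> [{'token': ['The', 'boy', 'wants', 'to', 'go'], 'lemma': ['the', 'boy', 'want', 'to', 'go']}]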