Example #1
 def __init__(self,
              trn: str = None,
              dev: str = None,
              tst: str = None,
              sampler_builder: SamplerBuilder = None,
              dependencies: str = None,
              scalar_mix: ScalarMixWithDropoutBuilder = None,
              use_raw_hidden_states=False,
              lr=2e-3, separate_optimizer=False,
              punct=False,
              tree=True,
              pad_rel=None,
              apply_constraint=False,
              single_root=True,
              no_zero_head=None,
              n_mlp_arc=500,
              n_mlp_rel=100,
              mlp_dropout=.33,
              mu=.9,
              nu=.9,
              epsilon=1e-12,
              decay=.75,
              decay_steps=5000,
              cls_is_bos=True,
              use_pos=False,
              **kwargs) -> None:
     super().__init__(**merge_locals_kwargs(locals(), kwargs))
     self.vocabs = VocabDict()
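Example #2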
 def __init__(self,
              trn: str = None,
              dev: str = None,
              tst: str = None,
              sampler_builder: SamplerBuilder = None,
              dependencies: str = None,
              scalar_mix: ScalarMixWithDropoutBuilder = None,
              use_raw_hidden_states=False,
              lr=None,
              separate_optimizer=False,
              cls_is_bos=True,
              sep_is_eos=True,
              delete=('', ':', '``', "''", '.', '?', '!', '-NONE-', 'TOP',
                      ',', 'S1'),
              equal=(('ADVP', 'PRT'), ),
              mbr=True,
              n_mlp_span=500,
              n_mlp_label=100,
              mlp_dropout=.33,
              no_subcategory=True,
              **kwargs) -> None:
     if isinstance(equal, tuple):
         equal = dict(equal)
     super().__init__(**merge_locals_kwargs(locals(), kwargs))
     self.vocabs = VocabDict()
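
The constructors above all funnel their hyperparameters into the component config through merge_locals_kwargs(locals(), kwargs). The snippet below is only a minimal sketch of that idiom, assuming the helper merely flattens the caller's named arguments and **kwargs into one dict; the project's real helper may exclude more keys or behave differently.

def merge_locals_kwargs(locals_: dict, kwargs: dict,
                        excludes=('self', 'kwargs', '__class__')) -> dict:
    # Assumption: keep every explicitly named argument except bookkeeping names,
    # then let **kwargs entries override them.
    merged = {k: v for k, v in locals_.items() if k not in excludes}
    merged.update(kwargs)
    return merged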
Example #3
 def __init__(self,
              trn: str = None,
              dev: str = None,
              tst: str = None,
              sampler_builder: SamplerBuilder = None,
              dependencies: str = None,
              scalar_mix: ScalarMixWithDropoutBuilder = None,
              use_raw_hidden_states=False,
              lr=1e-3,
              separate_optimizer=False,
              lexical_dropout=0.5,
              dropout=0.2,
              span_width_feature_size=20,
              ffnn_size=150,
              ffnn_depth=2,
              argument_ratio=0.8,
              predicate_ratio=0.4,
              max_arg_width=30,
              mlp_label_size=100,
              enforce_srl_constraint=False,
              use_gold_predicates=False,
              doc_level_offset=True,
              use_biaffine=False,
              loss_reduction='mean',
              with_argument=' ',
              **kwargs) -> None:
     super().__init__(**merge_locals_kwargs(locals(), kwargs))
     self.vocabs = VocabDict()
Example #4
 def __init__(self,
              trn: str = None,
              dev: str = None,
              tst: str = None,
              sampler_builder: SamplerBuilder = None,
              dependencies: str = None,
              scalar_mix: ScalarMixWithDropoutBuilder = None,
              use_raw_hidden_states=False,
              lr=1e-3,
              separate_optimizer=False,
              cls_is_bos=False,
              sep_is_eos=False,
              delimiter=None,
              max_seq_len=None,
              sent_delimiter=None,
              char_level=False,
              hard_constraint=False,
              token_key='token',
              **kwargs) -> None:
     super().__init__(**merge_locals_kwargs(locals(), kwargs))
     self.vocabs = VocabDict()
Example #5
 def __init__(self,
              trn: str = None,
              dev: str = None,
              tst: str = None,
              sampler_builder: SamplerBuilder = None,
              dependencies: str = None,
              scalar_mix: ScalarMixWithDropoutBuilder = None,
              use_raw_hidden_states=False,
              lr=1e-3,
              separate_optimizer=False,
              cls_is_bos=True,
              sep_is_eos=False,
              char2concept_dim=128,
              cnn_filters=((3, 256), ),
              concept_char_dim=32,
              concept_dim=300,
              dropout=0.2,
              embed_dim=512,
              eval_every=20,
              ff_embed_dim=1024,
              graph_layers=2,
              inference_layers=4,
              num_heads=8,
              rel_dim=100,
              snt_layers=4,
              unk_rate=0.33,
              vocab_min_freq=5,
              beam_size=8,
              alpha=0.6,
              max_time_step=100,
              amr_version='2.0',
              **kwargs) -> None:
     super().__init__(**merge_locals_kwargs(locals(), kwargs))
     self.vocabs = VocabDict()
     utils_dir = get_resource(get_amr_utils(amr_version))
     self.sense_restore = NodeRestore(NodeUtilities.from_json(utils_dir))
Example #6
def batchify(data,
             vocabs: VocabDict,
             unk_rate=0.,
             device=None,
             squeeze=False,
             tokenizer: TransformerSequenceTokenizer = None,
             shuffle_sibling=True,
             levi_graph=False,
             extra_arc=False,
             bart=False):
    rel_vocab: VocabWithFrequency = vocabs.rel
    _tok = list_to_tensor(data['token'], vocabs['token'],
                          unk_rate=unk_rate) if 'token' in vocabs else None
    _lem = list_to_tensor(data['lemma'], vocabs['lemma'], unk_rate=unk_rate)
    _pos = list_to_tensor(data['pos'], vocabs['pos'],
                          unk_rate=unk_rate) if 'pos' in vocabs else None
    _ner = list_to_tensor(data['ner'], vocabs['ner'],
                          unk_rate=unk_rate) if 'ner' in vocabs else None
    _word_char = lists_of_string_to_tensor(
        data['token'], vocabs['word_char']) if 'word_char' in vocabs else None

    local_token2idx = data['token2idx']
    local_idx2token = data['idx2token']
    _cp_seq = list_to_tensor(data['cp_seq'], vocabs['predictable_concept'],
                             local_token2idx)
    _mp_seq = list_to_tensor(data['mp_seq'], vocabs['predictable_concept'],
                             local_token2idx)

    ret = copy(data)
    if 'amr' in data:
        concept, edge = [], []
        for amr in data['amr']:
            if levi_graph == 'kahn':
                concept_i, edge_i = amr.to_levi(rel_vocab.get_frequency,
                                                shuffle=shuffle_sibling)
            else:
                concept_i, edge_i, _ = amr.root_centered_sort(
                    rel_vocab.get_frequency, shuffle=shuffle_sibling)
            concept.append(concept_i)
            edge.append(edge_i)
        if levi_graph is True:
            concept_with_rel, edge_with_rel = levi_amr(concept,
                                                       edge,
                                                       extra_arc=extra_arc)
            concept = concept_with_rel
            edge = edge_with_rel

        augmented_concept = [[DUM] + x + [END] for x in concept]

        _concept_in = list_to_tensor(augmented_concept,
                                     vocabs.get('concept_and_rel',
                                                vocabs['concept']),
                                     unk_rate=unk_rate)[:-1]
        _concept_char_in = lists_of_string_to_tensor(
            augmented_concept, vocabs['concept_char'])[:-1]
        _concept_out = list_to_tensor(augmented_concept,
                                      vocabs['predictable_concept'],
                                      local_token2idx)[1:]

        out_conc_len, bsz = _concept_out.shape
        _rel = np.full((1 + out_conc_len, bsz, out_conc_len),
                       rel_vocab.pad_idx)
        # v: [<dummy>, concept_0, ..., concept_l, ..., concept_{n-1}, <end>] u: [<dummy>, concept_0, ..., concept_l, ..., concept_{n-1}]

        for bidx, (x, y) in enumerate(zip(edge, concept)):
            for l, _ in enumerate(y):
                if l > 0:
                    # l=1 => pos=l+1=2
                    _rel[l + 1, bidx, 1:l + 1] = rel_vocab.get_idx(NIL)
            for v, u, r in x:
                if levi_graph:
                    r = 1
                else:
                    r = rel_vocab.get_idx(r)
                assert v > u, 'Invalid topological order'
                _rel[v + 1, bidx, u + 1] = r
        ret.update({
            'concept_in': _concept_in,
            'concept_char_in': _concept_char_in,
            'concept_out': _concept_out,
            'rel': _rel
        })
    else:
        augmented_concept = None

    token_length = ret.get('token_length', None)
    if token_length is not None and not isinstance(token_length, torch.Tensor):
        ret['token_length'] = torch.tensor(
            token_length,
            dtype=torch.long,
            device=device if
            (isinstance(device, torch.device) or device >= 0) else 'cpu:0')
    ret.update({
        'lem': _lem,
        'tok': _tok,
        'pos': _pos,
        'ner': _ner,
        'word_char': _word_char,
        'copy_seq': np.stack([_cp_seq, _mp_seq], -1),
        'local_token2idx': local_token2idx,
        'local_idx2token': local_idx2token
    })
    if squeeze:
        token_field = make_batch_for_squeeze(data, augmented_concept,
                                             tokenizer, device, ret)
    else:
        token_field = 'token'
    subtoken_to_tensor(token_field, ret)
    if bart:
        make_batch_for_bart(augmented_concept, ret, tokenizer, device)
    move_dict_to_device(ret, device)

    return ret
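
In practice batchify is not called directly on each raw batch; as build_batchify in Example #8 shows, it is bound once with functools.partial and then applied to every batch coming out of the DataLoader. A hedged usage sketch, with placeholder vocabs and device:

import functools

collate = functools.partial(batchify,
                            vocabs=vocabs,         # a VocabDict built during build_dataset
                            unk_rate=0.33,         # non-zero only while training
                            shuffle_sibling=True,  # random sibling order while training
                            device='cpu')          # placeholder device
# batch = collate(raw_batch)  # raw_batch: dict of lists from merge_list_of_dict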
Example #7
 def load_vocabs(self, save_dir, filename='vocabs.json'):
     if hasattr(self, 'vocabs'):
         self.vocabs = VocabDict()
         self.vocabs.load_vocabs(save_dir, filename, VocabWithFrequency)
Example #8
class GraphSequenceAbstractMeaningRepresentationParser(TorchComponent):
    def __init__(self, **kwargs) -> None:
        """
        An AMR parser implementing Cai and Lam (2020) and my unpublished models.

        Args:
            **kwargs:
        """
        super().__init__(**kwargs)
        self.model: GraphSequenceAbstractMeaningRepresentationModel = self.model

    def build_optimizer(self, trn, epochs, lr, adam_epsilon, weight_decay,
                        warmup_steps, transformer_lr, gradient_accumulation,
                        **kwargs):
        model = self.model
        if self.config.squeeze and False:  # note: 'and False' disables this branch, so the Noam-style schedule below is always used
            num_training_steps = len(trn) * epochs // gradient_accumulation
            optimizer, scheduler = build_optimizer_scheduler_with_transformer(
                model, model.bert_encoder, lr, transformer_lr,
                num_training_steps, warmup_steps, weight_decay, adam_epsilon)
        else:
            weight_decay_params = []
            no_weight_decay_params = []
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            for name, param in model.named_parameters():
                if name.endswith('bias') or 'layer_norm' in name or any(
                        nd in name for nd in no_decay):
                    no_weight_decay_params.append(param)
                else:
                    weight_decay_params.append(param)
            grouped_params = [{
                'params': weight_decay_params,
                'weight_decay': weight_decay
            }, {
                'params': no_weight_decay_params,
                'weight_decay': 0.
            }]
            optimizer = AdamWeightDecayOptimizer(grouped_params,
                                                 lr,
                                                 betas=(0.9, 0.999),
                                                 eps=adam_epsilon)
            lr_scale = self.config.lr_scale
            embed_dim = self.config.embed_dim
            scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer, lambda steps: lr_scale * embed_dim**-0.5 * min(
                    (steps + 1)**-0.5, (steps + 1) * (warmup_steps**-1.5)))
        return optimizer, scheduler

    def build_criterion(self, **kwargs):
        pass

    def build_metric(self, **kwargs):
        pass

    def execute_training_loop(self,
                              trn: PrefetchDataLoader,
                              dev: PrefetchDataLoader,
                              epochs,
                              criterion,
                              optimizer,
                              metric,
                              save_dir,
                              logger: logging.Logger,
                              devices,
                              ratio_width=None,
                              dev_data=None,
                              gradient_accumulation=1,
                              **kwargs):
        best_epoch, best_metric = 0, -1
        timer = CountdownTimer(epochs)
        history = History()
        try:
            for epoch in range(1, epochs + 1):
                logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
                trn = self.fit_dataloader(
                    trn,
                    criterion,
                    optimizer,
                    metric,
                    logger,
                    ratio_width=ratio_width,
                    gradient_accumulation=gradient_accumulation,
                    history=history,
                    save_dir=save_dir)
                report = f'{timer.elapsed_human}/{timer.total_time_human}'
                if epoch % self.config.eval_every == 0 or epoch == epochs:
                    metric = self.evaluate_dataloader(dev,
                                                      logger,
                                                      dev_data,
                                                      ratio_width=ratio_width,
                                                      save_dir=save_dir,
                                                      use_fast=True)
                    if metric > best_metric:
                        self.save_weights(save_dir)
                        best_metric = metric
                        best_epoch = epoch
                        report += ' [red]saved[/red]'
                timer.log(report,
                          ratio_percentage=False,
                          newline=True,
                          ratio=False)
            if best_epoch and best_epoch != epochs:
                logger.info(
                    f'Restored the best model with {best_metric} saved {epochs - best_epoch} epochs ago'
                )
                self.load_weights(save_dir)
        finally:
            trn.close()
            dev.close()

    def fit_dataloader(self,
                       trn: PrefetchDataLoader,
                       criterion,
                       optimizer,
                       metric,
                       logger: logging.Logger,
                       gradient_accumulation=1,
                       ratio_width=None,
                       history=None,
                       save_dir=None,
                       **kwargs):
        self.model.train()
        num_training_steps = len(
            trn) * self.config.epochs // gradient_accumulation
        shuffle_sibling_steps = self.config.shuffle_sibling_steps
        if isinstance(shuffle_sibling_steps, float):
            shuffle_sibling_steps = int(shuffle_sibling_steps *
                                        num_training_steps)
        timer = CountdownTimer(
            len([
                i for i in range(history.num_mini_batches +
                                 1, history.num_mini_batches + len(trn) + 1)
                if i % gradient_accumulation == 0
            ]))
        total_loss = 0
        optimizer, scheduler = optimizer
        correct_conc, total_conc, correct_rel, total_rel = 0, 0, 0, 0
        for idx, batch in enumerate(trn):
            loss = self.compute_loss(batch)
            if self.config.joint_arc_concept or self.model.squeeze or self.config.bart:
                loss, (concept_correct, concept_total), rel_out = loss
                correct_conc += concept_correct
                total_conc += concept_total
                if rel_out is not None:
                    rel_correct, rel_total = rel_out
                    correct_rel += rel_correct
                    total_rel += rel_total
            loss /= gradient_accumulation
            # loss = loss.sum()  # For data parallel
            loss.backward()
            total_loss += loss.item()
            history.num_mini_batches += 1
            if history.num_mini_batches % gradient_accumulation == 0:
                self._step(optimizer, scheduler)
                metric = ''
                if self.config.joint_arc_concept or self.model.squeeze or self.model.bart:
                    metric = f' Concept acc: {correct_conc / total_conc:.2%}'
                    if not self.config.levi_graph:
                        metric += f' Relation acc: {correct_rel / total_rel:.2%}'
                timer.log(
                    f'loss: {total_loss / (timer.current + 1):.4f} lr: {optimizer.param_groups[0]["lr"]:.2e}'
                    + metric,
                    ratio_percentage=None,
                    ratio_width=ratio_width,
                    logger=logger)

                if history.num_mini_batches // gradient_accumulation == shuffle_sibling_steps:
                    trn.batchify = self.build_batchify(self.device,
                                                       shuffle=True,
                                                       shuffle_sibling=False)
                    timer.print(
                        f'Switched to [bold]deterministic order[/bold] after {shuffle_sibling_steps} steps',
                        newline=True)
            del loss
        return trn

    def _step(self, optimizer, scheduler):
        if self.config.grad_norm:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.config.grad_norm)
        optimizer.step()
        # model = self.model
        # print(mean_model(model))
        optimizer.zero_grad()
        scheduler.step()

    def update_metrics(self, batch: dict, prediction: Union[Dict, List],
                       metrics):
        if isinstance(prediction, dict):
            prediction = prediction['prediction']
        assert len(prediction) == len(batch['ner'])
        for pred, gold in zip(prediction, batch['ner']):
            metrics(set(pred), set(gold))

    def compute_loss(self, batch):
        # debug
        # gold = torch.load('/home/hhe43/amr_gs/batch.pt', map_location=self.device)
        # self.debug_assert_batch_equal(batch, gold)
        # set_seed()
        # end debug
        concept_loss, arc_loss, rel_loss, graph_arc_loss = self.model(batch)
        if self.config.joint_arc_concept or self.config.squeeze or self.config.bart:
            concept_loss, concept_correct, concept_total = concept_loss
            if rel_loss is not None:
                rel_loss, rel_correct, rel_total = rel_loss
                loss = concept_loss + arc_loss + rel_loss
                rel_acc = (rel_correct, rel_total)
            else:
                loss = concept_loss + arc_loss
                rel_acc = None
            return loss, (concept_correct, concept_total), rel_acc
        loss = concept_loss + arc_loss + rel_loss
        return loss

    def debug_assert_batch_equal(self, batch, gold):
        # assert torch.equal(batch['token_input_ids'], gold['bert_token'])
        for k, v in gold.items():
            pred = batch.get(k, None)
            if pred is not None:
                if isinstance(v, torch.Tensor) and not torch.equal(pred, v):
                    assert torch.equal(pred, v), f'{k} not equal'

    @torch.no_grad()
    def evaluate_dataloader(self,
                            data: PrefetchDataLoader,
                            logger,
                            input,
                            output=False,
                            ratio_width=None,
                            save_dir=None,
                            use_fast=False,
                            test=False,
                            **kwargs):
        self.model.eval()
        pp = PostProcessor(self.vocabs['rel'])
        if not output:
            output = os.path.join(save_dir, os.path.basename(input) + '.pred')
        # Squeezing tokens and concepts into one transformer basically reduces the max number of inputs it can handle
        parse_data(self.model,
                   pp,
                   data,
                   input,
                   output,
                   max_time_step=80 if self.model.squeeze else 100)
        # noinspection PyBroadException
        try:
            output = post_process(output,
                                  amr_version=self.config.get(
                                      'amr_version', '2.0'))
            scores = smatch_eval(output,
                                 input.replace('.features.preproc', ''),
                                 use_fast=use_fast)
        except Exception:
            eprint('Evaluation failed due to the following error:')
            traceback.print_exc()
            eprint(
                'As smatch usually fails on erroneous outputs produced at early epochs, '
                'it might be OK to ignore it. Now `nan` will be returned as the score.'
            )
            scores = F1_(float("nan"), float("nan"), float("nan"))
        if logger:
            header = f'{len(data)}/{len(data)}'
            if not ratio_width:
                ratio_width = len(header)
            logger.info(header.rjust(ratio_width) + f' {scores}')
        if test:
            data.close()
        return scores

    def build_model(self, training=True, **kwargs) -> torch.nn.Module:
        transformer = self.config.encoder.module()
        model = GraphSequenceAbstractMeaningRepresentationModel(
            self.vocabs,
            **merge_dict(self.config, overwrite=True, encoder=transformer),
            tokenizer=self.config.encoder.transform())
        # self.model = model
        # self.debug_load()
        return model

    def debug_load(self):
        model = self.model
        states = torch.load('/home/hhe43/amr_gs/model.pt',
                            map_location=self.device)
        model.load_state_dict(states, strict=False)

    def build_dataloader(self,
                         data,
                         batch_size,
                         shuffle=False,
                         device=None,
                         logger: logging.Logger = None,
                         gradient_accumulation=1,
                         batch_max_tokens=None,
                         **kwargs) -> DataLoader:
        dataset, lens = self.build_dataset(data, logger, training=shuffle)
        if batch_max_tokens:
            batch_max_tokens //= gradient_accumulation
        if not shuffle:
            batch_max_tokens //= 2
        sampler = SortingSampler(lens,
                                 batch_size=None,
                                 batch_max_tokens=batch_max_tokens,
                                 shuffle=shuffle)
        dataloader = PrefetchDataLoader(
            DataLoader(batch_sampler=sampler,
                       dataset=dataset,
                       collate_fn=merge_list_of_dict,
                       num_workers=0),
            batchify=self.build_batchify(device, shuffle))
        return dataloader

    def build_batchify(self, device, shuffle, shuffle_sibling=None):
        if shuffle_sibling is None:
            shuffle_sibling = shuffle
        return functools.partial(
            batchify,
            vocabs=self.vocabs,
            squeeze=self.config.get('squeeze', None),
            tokenizer=self.config.encoder.transform(),
            levi_graph=self.config.get('levi_graph', False),
            bart=self.config.get('bart', False),
            extra_arc=self.config.get('extra_arc', False),
            unk_rate=self.config.unk_rate if shuffle else 0,
            shuffle_sibling=shuffle_sibling,
            device=device)

    def build_dataset(self,
                      data,
                      logger: logging.Logger = None,
                      training=True):
        dataset = AbstractMeaningRepresentationDataset(
            data, generate_idx=not training)
        if self.vocabs.mutable:
            self.build_vocabs(dataset, logger)
            self.vocabs.lock()
            self.vocabs.summary(logger)
        lens = [len(x['token']) + len(x['amr']) for x in dataset]
        dataset.append_transform(
            functools.partial(get_concepts,
                              vocab=self.vocabs.predictable_concept,
                              rel_vocab=self.vocabs.rel if self.config.get(
                                  'separate_rel', False) else None))
        dataset.append_transform(append_bos)
        # Tokenization will happen in batchify
        if not self.config.get('squeeze', None):
            dataset.append_transform(self.config.encoder.transform())
        if isinstance(data, str):
            dataset.purge_cache()
            timer = CountdownTimer(len(dataset))
            for each in dataset:
                timer.log(
                    'Caching samples [blink][yellow]...[/yellow][/blink]')
        return dataset, lens

    def build_vocabs(self, dataset, logger: logging.Logger = None, **kwargs):
        # debug
        # self.load_vocabs('/home/hhe43/elit/data/model/amr2.0/convert/')
        # return
        # collect concepts and relations
        conc = []
        rel = []
        predictable_conc = []  # concepts that cannot be generated by copying lemmas (e.g., 'multi-sentence', 'sense-01')
        tokens = []
        lemmas = []
        poses = []
        ners = []
        repeat = 10
        levi_graph = self.config.get('levi_graph', False)
        separate_rel = self.config.separate_rel
        timer = CountdownTimer(repeat * len(dataset))
        for i in range(repeat):
            # run random sort 10 times to get the priorities of different types of edges
            for sample in dataset:
                amr, lem, tok, pos, ner = sample['amr'], sample[
                    'lemma'], sample['token'], sample['pos'], sample['ner']
                if levi_graph == 'kahn':
                    concept, edge = amr.to_levi()
                else:
                    concept, edge, not_ok = amr.root_centered_sort()
                if levi_graph is True:
                    concept, edge = linearize(concept, edge, NIL, prefix=REL)
                lexical_concepts = set()
                for lemma in lem:
                    lexical_concepts.add(lemma + '_')
                    lexical_concepts.add(lemma)

                if i == 0:
                    if separate_rel:
                        edge = [(c, ) for c in concept if c.startswith(REL)]
                        concept = [c for c in concept if not c.startswith(REL)]
                    predictable_conc.append(
                        [c for c in concept if c not in lexical_concepts])
                    conc.append(concept)
                    tokens.append(tok)
                    lemmas.append(lem)
                    poses.append(pos)
                    ners.append(ner)
                    rel.append([e[-1] for e in edge])
                timer.log(
                    'Building vocabs [blink][yellow]...[/yellow][/blink]')

        # make vocabularies
        token_vocab, token_char_vocab = make_vocab(tokens, char_level=True)
        lemma_vocab, lemma_char_vocab = make_vocab(lemmas, char_level=True)
        pos_vocab = make_vocab(poses)
        ner_vocab = make_vocab(ners)
        conc_vocab, conc_char_vocab = make_vocab(conc, char_level=True)

        predictable_conc_vocab = make_vocab(predictable_conc)
        num_predictable_conc = sum(len(x) for x in predictable_conc)
        num_conc = sum(len(x) for x in conc)
        rel_vocab = make_vocab(rel)
        logger.info(
            f'Predictable concept coverage {num_predictable_conc} / {num_conc} = {num_predictable_conc / num_conc:.2%}'
        )
        vocabs = self.vocabs
        vocab_min_freq = self.config.get('vocab_min_freq', 5)
        vocabs.token = VocabWithFrequency(token_vocab,
                                          vocab_min_freq,
                                          specials=[CLS])
        vocabs.lemma = VocabWithFrequency(lemma_vocab,
                                          vocab_min_freq,
                                          specials=[CLS])
        vocabs.pos = VocabWithFrequency(pos_vocab,
                                        vocab_min_freq,
                                        specials=[CLS])
        vocabs.ner = VocabWithFrequency(ner_vocab,
                                        vocab_min_freq,
                                        specials=[CLS])
        vocabs.predictable_concept = VocabWithFrequency(predictable_conc_vocab,
                                                        vocab_min_freq,
                                                        specials=[DUM, END])
        vocabs.concept = VocabWithFrequency(conc_vocab,
                                            vocab_min_freq,
                                            specials=[DUM, END])
        vocabs.rel = VocabWithFrequency(rel_vocab,
                                        vocab_min_freq * 10,
                                        specials=[NIL])
        vocabs.word_char = VocabWithFrequency(token_char_vocab,
                                              vocab_min_freq * 20,
                                              specials=[CLS, END])
        vocabs.concept_char = VocabWithFrequency(conc_char_vocab,
                                                 vocab_min_freq * 20,
                                                 specials=[CLS, END])
        if separate_rel:
            vocabs.concept_and_rel = VocabWithFrequency(
                conc_vocab + rel_vocab,
                vocab_min_freq,
                specials=[DUM, END, NIL])
        # if levi_graph:
        #     # max = 993
        #     tokenizer = self.config.encoder.transform()
        #     rel_to_unused = dict()
        #     for i, rel in enumerate(vocabs.rel.idx_to_token):
        #         rel_to_unused[rel] = f'[unused{i + 100}]'
        #     tokenizer.rel_to_unused = rel_to_unused

    def predict(self,
                data: Union[str, List[str]],
                batch_size: int = None,
                **kwargs):
        pass

    def fit(self,
            trn_data,
            dev_data,
            save_dir,
            encoder,
            batch_size=None,
            batch_max_tokens=17776,
            epochs=1000,
            gradient_accumulation=4,
            char2concept_dim=128,
            char2word_dim=128,
            cnn_filters=((3, 256), ),
            concept_char_dim=32,
            concept_dim=300,
            dropout=0.2,
            embed_dim=512,
            eval_every=20,
            ff_embed_dim=1024,
            graph_layers=2,
            inference_layers=4,
            lr_scale=1.0,
            ner_dim=16,
            num_heads=8,
            pos_dim=32,
            pretrained_file=None,
            rel_dim=100,
            snt_layers=4,
            start_rank=0,
            unk_rate=0.33,
            warmup_steps=2000,
            with_bert=True,
            word_char_dim=32,
            word_dim=300,
            lr=1.,
            transformer_lr=None,
            adam_epsilon=1e-6,
            weight_decay=1e-4,
            grad_norm=1.0,
            joint_arc_concept=False,
            joint_rel=False,
            external_biaffine=False,
            optimize_every_layer=False,
            squeeze=False,
            levi_graph=False,
            separate_rel=False,
            extra_arc=False,
            bart=False,
            shuffle_sibling_steps=50000,
            vocab_min_freq=5,
            amr_version='2.0',
            devices=None,
            logger=None,
            seed=None,
            **kwargs):
        return super().fit(**merge_locals_kwargs(locals(), kwargs))

    def load_vocabs(self, save_dir, filename='vocabs.json'):
        if hasattr(self, 'vocabs'):
            self.vocabs = VocabDict()
            self.vocabs.load_vocabs(save_dir, filename, VocabWithFrequency)
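
A hedged driver sketch for this component: training goes through fit with an encoder builder that exposes module() and transform() (as used in build_model and build_batchify above), while evaluation only needs a saved model directory. All paths below are placeholders.

parser = GraphSequenceAbstractMeaningRepresentationParser()
parser.load('data/model/amr2.0')                       # placeholder: dir produced by fit()
parser.evaluate('data/amr/test.txt.features.preproc',  # placeholder test file
                save_dir='data/model/amr2.0',
                output=True)                           # writes a .pred file and reports Smatch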
Example #9
 def __init__(self, **kwargs) -> None:
     super().__init__()
     self.model: Optional[torch.nn.Module] = None
     self.config = SerializableDict(**kwargs)
     self.vocabs = VocabDict()
Example #10
class TorchComponent(Component, ABC):
    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.model: Optional[torch.nn.Module] = None
        self.config = SerializableDict(**kwargs)
        self.vocabs = VocabDict()

    def _capture_config(self,
                        locals_: Dict,
                        exclude=('trn_data', 'dev_data', 'save_dir', 'kwargs',
                                 'self', 'logger', 'verbose', 'dev_batch_size',
                                 '__class__', 'devices', 'eval_trn')):
        """Save arguments to config

        Args:
          locals_: Dict: 
          exclude:  (Default value = ('trn_data')
          'dev_data': 
          'save_dir': 
          'kwargs': 
          'self': 
          'logger': 
          'verbose': 
          'dev_batch_size': 
          '__class__': 
          'devices'): 

        Returns:

        
        """
        if 'kwargs' in locals_:
            locals_.update(locals_['kwargs'])
        locals_ = dict((k, v) for k, v in locals_.items()
                       if k not in exclude and not k.startswith('_'))
        self.config.update(locals_)
        return self.config

    def save_weights(self,
                     save_dir,
                     filename='model.pt',
                     trainable_only=True,
                     **kwargs):
        model = self.model_
        state_dict = model.state_dict()
        if trainable_only:
            trainable_names = set(n for n, p in model.named_parameters()
                                  if p.requires_grad)
            state_dict = dict(
                (n, p) for n, p in state_dict.items() if n in trainable_names)
        torch.save(state_dict, os.path.join(save_dir, filename))

    def load_weights(self, save_dir, filename='model.pt', **kwargs):
        save_dir = get_resource(save_dir)
        filename = os.path.join(save_dir, filename)
        # flash(f'Loading model: {filename} [blink]...[/blink][/yellow]')
        self.model_.load_state_dict(torch.load(filename, map_location='cpu'),
                                    strict=False)
        # flash('')

    def save_config(self, save_dir, filename='config.json'):
        self._savable_config.save_json(os.path.join(save_dir, filename))

    def load_config(self, save_dir, filename='config.json', **kwargs):
        save_dir = get_resource(save_dir)
        self.config.load_json(os.path.join(save_dir, filename))
        self.config.update(kwargs)  # overwrite config loaded from disk
        for k, v in self.config.items():
            if isinstance(v, dict) and 'classpath' in v:
                self.config[k] = Configurable.from_config(v)
        self.on_config_ready(**self.config)

    def save_vocabs(self, save_dir, filename='vocabs.json'):
        if hasattr(self, 'vocabs'):
            self.vocabs.save_vocabs(save_dir, filename)

    def load_vocabs(self, save_dir, filename='vocabs.json'):
        if hasattr(self, 'vocabs'):
            self.vocabs = VocabDict()
            self.vocabs.load_vocabs(save_dir, filename)

    def save(self, save_dir: str, **kwargs):
        self.save_config(save_dir)
        self.save_vocabs(save_dir)
        self.save_weights(save_dir)

    def load(self, save_dir: str, devices=None, **kwargs):
        save_dir = get_resource(save_dir)
        # flash('Loading config and vocabs [blink][yellow]...[/yellow][/blink]')
        if devices is None and self.model:
            devices = self.devices
        self.load_config(save_dir, **kwargs)
        self.load_vocabs(save_dir)
        flash('Building model [blink][yellow]...[/yellow][/blink]')
        self.model = self.build_model(**merge_dict(self.config,
                                                   training=False,
                                                   **kwargs,
                                                   overwrite=True,
                                                   inplace=True))
        flash('')
        self.load_weights(save_dir, **kwargs)
        self.to(devices)
        self.model.eval()

    def fit(self,
            trn_data,
            dev_data,
            save_dir,
            batch_size,
            epochs,
            devices=None,
            logger=None,
            seed=None,
            finetune=False,
            eval_trn=True,
            _device_placeholder=False,
            **kwargs):
        # Common initialization steps
        config = self._capture_config(locals())
        if not logger:
            logger = self.build_logger('train', save_dir)
        if not seed:
            self.config.seed = 233 if isdebugging() else int(time.time())
        set_seed(self.config.seed)
        logger.info(self._savable_config.to_json(sort=True))
        if isinstance(devices, list) or devices is None or isinstance(
                devices, float):
            flash('[yellow]Querying CUDA devices [blink]...[/blink][/yellow]')
            devices = -1 if isdebugging() else cuda_devices(devices)
            flash('')
        # flash(f'Available GPUs: {devices}')
        if isinstance(devices, list):
            first_device = (devices[0] if devices else -1)
        elif isinstance(devices, dict):
            first_device = next(iter(devices.values()))
        elif isinstance(devices, int):
            first_device = devices
        else:
            first_device = -1
        if _device_placeholder and first_device >= 0:
            _dummy_placeholder = self._create_dummy_placeholder_on(
                first_device)
        if finetune:
            if isinstance(finetune, str):
                self.load(finetune, devices=devices)
            else:
                self.load(save_dir, devices=devices)
            logger.info(
                f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
                f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.'
            )
        self.on_config_ready(**self.config)
        trn = self.build_dataloader(**merge_dict(config,
                                                 data=trn_data,
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 training=True,
                                                 device=first_device,
                                                 logger=logger,
                                                 vocabs=self.vocabs,
                                                 overwrite=True))
        dev = self.build_dataloader(
            **merge_dict(config,
                         data=dev_data,
                         batch_size=batch_size,
                         shuffle=False,
                         training=None,
                         device=first_device,
                         logger=logger,
                         vocabs=self.vocabs,
                         overwrite=True)) if dev_data else None
        if not finetune:
            flash('[yellow]Building model [blink]...[/blink][/yellow]')
            self.model = self.build_model(**merge_dict(config, training=True))
            flash('')
            logger.info(
                f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
                f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.'
            )
            assert self.model, 'build_model is not properly implemented.'
        _description = repr(self.model)
        if len(_description.split('\n')) < 10:
            logger.info(_description)
        self.save_config(save_dir)
        self.save_vocabs(save_dir)
        self.to(devices, logger)
        if _device_placeholder and first_device >= 0:
            del _dummy_placeholder
        criterion = self.build_criterion(**merge_dict(config, trn=trn))
        optimizer = self.build_optimizer(
            **merge_dict(config, trn=trn, criterion=criterion))
        metric = self.build_metric(**self.config)
        if hasattr(trn.dataset, '__len__') and dev and hasattr(
                dev.dataset, '__len__'):
            logger.info(
                f'{len(trn.dataset)}/{len(dev.dataset)} samples in trn/dev set.'
            )
            trn_size = len(trn) // self.config.get('gradient_accumulation', 1)
            ratio_width = len(f'{trn_size}/{trn_size}')
        else:
            ratio_width = None
        return self.execute_training_loop(**merge_dict(config,
                                                       trn=trn,
                                                       dev=dev,
                                                       epochs=epochs,
                                                       criterion=criterion,
                                                       optimizer=optimizer,
                                                       metric=metric,
                                                       logger=logger,
                                                       save_dir=save_dir,
                                                       devices=devices,
                                                       ratio_width=ratio_width,
                                                       trn_data=trn_data,
                                                       dev_data=dev_data,
                                                       eval_trn=eval_trn,
                                                       overwrite=True))

    def build_logger(self, name, save_dir):
        logger = init_logger(name=name,
                             root_dir=save_dir,
                             level=logging.INFO,
                             fmt="%(message)s")
        return logger

    @abstractmethod
    def build_dataloader(self,
                         data,
                         batch_size,
                         shuffle=False,
                         device=None,
                         logger: logging.Logger = None,
                         **kwargs) -> DataLoader:
        pass

    def build_vocabs(self, **kwargs):
        pass

    @property
    def _savable_config(self):
        def convert(k, v):
            if hasattr(v, 'config'):
                v = v.config
            if isinstance(v, (set, tuple)):
                v = list(v)
            return k, v

        config = SerializableDict(
            convert(k, v) for k, v in sorted(self.config.items()))
        config.update({
            # 'create_time': now_datetime(),
            'classpath': classpath_of(self),
            'elit_version': elit.__version__,
        })
        return config

    @abstractmethod
    def build_optimizer(self, **kwargs):
        pass

    @abstractmethod
    def build_criterion(self, decoder, **kwargs):
        pass

    @abstractmethod
    def build_metric(self, **kwargs):
        pass

    @abstractmethod
    def execute_training_loop(self,
                              trn: DataLoader,
                              dev: DataLoader,
                              epochs,
                              criterion,
                              optimizer,
                              metric,
                              save_dir,
                              logger: logging.Logger,
                              devices,
                              ratio_width=None,
                              **kwargs):
        pass

    @abstractmethod
    def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric,
                       logger: logging.Logger, **kwargs):
        pass

    @abstractmethod
    def evaluate_dataloader(self,
                            data: DataLoader,
                            criterion: Callable,
                            metric=None,
                            output=False,
                            **kwargs):
        pass

    @abstractmethod
    def build_model(self, training=True, **kwargs) -> torch.nn.Module:
        raise NotImplementedError

    def evaluate(self,
                 tst_data,
                 save_dir=None,
                 logger: logging.Logger = None,
                 batch_size=None,
                 output=False,
                 **kwargs):
        if not self.model:
            raise RuntimeError('Call fit or load before evaluate.')
        if isinstance(tst_data, str):
            tst_data = get_resource(tst_data)
            filename = os.path.basename(tst_data)
        else:
            filename = None
        if output is True:
            output = self.generate_prediction_filename(
                tst_data if isinstance(tst_data, str) else 'test.txt',
                save_dir)
        if logger is None:
            _logger_name = basename_no_ext(filename) if filename else None
            logger = self.build_logger(_logger_name, save_dir)
        if not batch_size:
            batch_size = self.config.get('batch_size', 32)
        data = self.build_dataloader(**merge_dict(self.config,
                                                  data=tst_data,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  device=self.devices[0],
                                                  logger=logger,
                                                  overwrite=True))
        dataset = data
        while dataset and hasattr(dataset, 'dataset'):
            dataset = dataset.dataset
        num_samples = len(dataset) if dataset else None
        if output and isinstance(dataset, TransformDataset):

            def add_idx(samples):
                for idx, sample in enumerate(samples):
                    if sample:
                        sample[IDX] = idx

            add_idx(dataset.data)
            if dataset.cache:
                add_idx(dataset.cache)

        criterion = self.build_criterion(**self.config)
        metric = self.build_metric(**self.config)
        start = time.time()
        outputs = self.evaluate_dataloader(data,
                                           criterion=criterion,
                                           filename=filename,
                                           output=output,
                                           input=tst_data,
                                           save_dir=save_dir,
                                           test=True,
                                           num_samples=num_samples,
                                           **merge_dict(self.config,
                                                        batch_size=batch_size,
                                                        metric=metric,
                                                        logger=logger,
                                                        **kwargs))
        elapsed = time.time() - start
        if logger:
            if num_samples:
                logger.info(
                    f'speed: {num_samples / elapsed:.0f} samples/second')
            else:
                logger.info(f'speed: {len(data) / elapsed:.0f} batches/second')
        return metric, outputs

    def generate_prediction_filename(self, tst_data, save_dir):
        assert isinstance(
            tst_data,
            str), 'tst_data has to be a str in order to infer the output name'
        output = os.path.splitext(os.path.basename(tst_data))
        output = os.path.join(save_dir, output[0] + '.pred' + output[1])
        return output

    def to(self,
           devices: Union[int, float, List[int],
                          Dict[str, Union[int, torch.device]]] = None,
           logger: logging.Logger = None):
        if devices == -1 or devices == [-1]:
            devices = []
        elif isinstance(devices, (int, float)) or devices is None:
            devices = cuda_devices(devices)
        if devices:
            if logger:
                logger.info(
                    f'Using GPUs: [on_blue][cyan][bold]{devices}[/bold][/cyan][/on_blue]'
                )
            if isinstance(devices, list):
                flash(
                    f'Moving model to GPUs {devices} [blink][yellow]...[/yellow][/blink]'
                )
                self.model = self.model.to(devices[0])
                if len(devices) > 1 and not isdebugging() and not isinstance(
                        self.model, nn.DataParallel):
                    self.model = self.parallelize(devices)
            elif isinstance(devices, dict):
                for name, module in self.model.named_modules():
                    for regex, device in devices.items():
                        try:
                            on_device: torch.device = next(
                                module.parameters()).device
                        except StopIteration:
                            continue
                        if on_device == device:
                            continue
                        if isinstance(device, int):
                            if on_device.index == device:
                                continue
                        if re.match(regex, name):
                            if not name:
                                name = '*'
                            flash(
                                f'Moving module [yellow]{name}[/yellow] to [on_yellow][magenta][bold]{device}'
                                f'[/bold][/magenta][/on_yellow]: [red]{regex}[/red]\n'
                            )
                            module.to(device)
            else:
                raise ValueError(f'Unrecognized devices {devices}')
            flash('')
        else:
            if logger:
                logger.info('Using CPU')

    def parallelize(self, devices: List[Union[int, torch.device]]):
        return nn.DataParallel(self.model, device_ids=devices)

    @property
    def devices(self):
        if self.model is None:
            return None
        # next(parser.model.parameters()).device
        if hasattr(self.model, 'device_ids'):
            return self.model.device_ids
        device: torch.device = next(self.model.parameters()).device
        return [device]

    @property
    def device(self):
        devices = self.devices
        if not devices:
            return None
        return devices[0]

    def on_config_ready(self, **kwargs):
        pass

    @property
    def model_(self) -> nn.Module:
        """
        The actual model when it's wrapped by a `DataParallel`

        Returns: The "real" model

        """
        if isinstance(self.model, nn.DataParallel):
            return self.model.module
        return self.model

    # noinspection PyMethodOverriding
    @abstractmethod
    def predict(self,
                data: Union[str, List[str]],
                batch_size: int = None,
                **kwargs):
        pass

    def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = merge_list_of_dict(samples)
        return batch

    @staticmethod
    def _create_dummy_placeholder_on(device):
        if device < 0:
            device = 'cpu:0'
        return torch.zeros(16, 16, device=device)

    @torch.no_grad()
    def __call__(self, data, batch_size=None, **kwargs):
        return super().__call__(
            data,
            **merge_dict(self.config,
                         overwrite=True,
                         batch_size=batch_size
                         or self.config.get('batch_size', None),
                         **kwargs))
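
TorchComponent leaves every task-specific piece abstract, so a concrete component only needs to supply the builders and the two loops. The skeleton below is an illustrative sketch (names invented, bodies elided, imports assumed to match the example above), not a real component from the project.

class MyComponent(TorchComponent):
    def build_dataloader(self, data, batch_size, shuffle=False, device=None,
                         logger: logging.Logger = None, **kwargs) -> DataLoader:
        ...  # wrap `data` in a DataLoader; build and lock vocabs on the training pass

    def build_model(self, training=True, **kwargs) -> torch.nn.Module:
        ...  # construct the torch.nn.Module from self.config and self.vocabs

    def build_optimizer(self, **kwargs):
        ...

    def build_criterion(self, decoder, **kwargs):
        ...

    def build_metric(self, **kwargs):
        ...

    def execute_training_loop(self, trn, dev, epochs, criterion, optimizer,
                              metric, save_dir, logger, devices, **kwargs):
        ...  # per epoch: fit_dataloader on trn, evaluate_dataloader on dev, save best weights

    def fit_dataloader(self, trn, criterion, optimizer, metric, logger, **kwargs):
        ...

    def evaluate_dataloader(self, data, criterion, metric=None, output=False, **kwargs):
        ...

    def predict(self, data, batch_size=None, **kwargs):
        ...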
Example #11
class GraphAbstractMeaningRepresentationParser(TorchComponent):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.model: GraphAbstractMeaningRepresentationModel = self.model
        self.sense_restore: NodeRestore = None

    def build_optimizer(self, trn, epochs, lr, adam_epsilon, weight_decay,
                        warmup_steps, transformer_lr, gradient_accumulation,
                        **kwargs):
        model = self.model
        num_training_steps = len(trn) * epochs // gradient_accumulation
        optimizer, scheduler = build_optimizer_scheduler_with_transformer(
            model, model.bert_encoder, lr, transformer_lr, num_training_steps,
            warmup_steps, weight_decay, adam_epsilon)
        return optimizer, scheduler

    def build_criterion(self, **kwargs):
        pass

    def build_metric(self, **kwargs):
        pass

    def execute_training_loop(self,
                              trn: PrefetchDataLoader,
                              dev: PrefetchDataLoader,
                              epochs,
                              criterion,
                              optimizer,
                              metric,
                              save_dir,
                              logger: logging.Logger,
                              devices,
                              ratio_width=None,
                              dev_data=None,
                              gradient_accumulation=1,
                              **kwargs):
        best_epoch, best_metric = 0, -1
        timer = CountdownTimer(epochs)
        history = History()
        try:
            for epoch in range(1, epochs + 1):
                logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
                trn = self.fit_dataloader(
                    trn,
                    criterion,
                    optimizer,
                    metric,
                    logger,
                    ratio_width=ratio_width,
                    gradient_accumulation=gradient_accumulation,
                    history=history,
                    save_dir=save_dir)
                report = f'{timer.elapsed_human}/{timer.total_time_human}'
                if epoch % self.config.eval_every == 0 or epoch == epochs:
                    metric = self.evaluate_dataloader(dev,
                                                      logger,
                                                      dev_data,
                                                      ratio_width=ratio_width,
                                                      save_dir=save_dir,
                                                      use_fast=True)
                    if metric > best_metric:
                        self.save_weights(save_dir)
                        best_metric = metric
                        best_epoch = epoch
                        report += ' [red]saved[/red]'
                timer.log(report,
                          ratio_percentage=False,
                          newline=True,
                          ratio=False)
            if best_epoch and best_epoch != epochs:
                logger.info(
                    f'Restored the best model with {best_metric} saved {epochs - best_epoch} epochs ago'
                )
                self.load_weights(save_dir)
        finally:
            trn.close()
            dev.close()

    def fit_dataloader(self,
                       trn: PrefetchDataLoader,
                       criterion,
                       optimizer,
                       metric,
                       logger: logging.Logger,
                       gradient_accumulation=1,
                       ratio_width=None,
                       history=None,
                       save_dir=None,
                       **kwargs):
        self.model.train()
        num_training_steps = len(
            trn) * self.config.epochs // gradient_accumulation
        shuffle_sibling_steps = self.config.shuffle_sibling_steps
        if isinstance(shuffle_sibling_steps, float):
            shuffle_sibling_steps = int(shuffle_sibling_steps *
                                        num_training_steps)
        timer = CountdownTimer(
            len([
                i for i in range(history.num_mini_batches +
                                 1, history.num_mini_batches + len(trn) + 1)
                if i % gradient_accumulation == 0
            ]))
        total_loss = 0
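        # The optimizer argument is packed as an (optimizer, lr scheduler) pair.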
        optimizer, scheduler = optimizer
        correct_conc, total_conc, correct_rel, total_rel = 0, 0, 0, 0
        for idx, batch in enumerate(trn):
            loss = self.compute_loss(batch)
            loss, (concept_correct, concept_total), rel_out = loss
            correct_conc += concept_correct
            total_conc += concept_total
            if rel_out is not None:
                rel_correct, rel_total = rel_out
                correct_rel += rel_correct
                total_rel += rel_total
            loss /= gradient_accumulation
            # loss = loss.sum()  # For data parallel
            loss.backward()
            total_loss += loss.item()
            history.num_mini_batches += 1
            if history.num_mini_batches % gradient_accumulation == 0:
                self._step(optimizer, scheduler)
                metric = f' Concept acc: {correct_conc / total_conc:.2%}'
                if total_rel:
                    metric += f' Relation acc: {correct_rel / total_rel:.2%}'
                timer.log(
                    f'loss: {total_loss / (timer.current + 1):.4f} lr: {optimizer.param_groups[0]["lr"]:.2e}'
                    + metric,
                    ratio_percentage=None,
                    ratio_width=ratio_width,
                    logger=logger)

                if history.num_mini_batches // gradient_accumulation == shuffle_sibling_steps:
                    trn.batchify = self.build_batchify(self.device,
                                                       shuffle=True,
                                                       shuffle_sibling=False)
                    timer.print(
                        f'Switched to [bold]deterministic order[/bold] after {shuffle_sibling_steps} steps',
                        newline=True)
            del loss
        return trn

    def _step(self, optimizer, scheduler):
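        # Clip gradients (if configured), update the parameters and advance the lr schedule.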
        if self.config.grad_norm:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.config.grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    def update_metrics(self, batch: dict, prediction: Union[Dict, List],
                       metrics):
        if isinstance(prediction, dict):
            prediction = prediction['prediction']
        assert len(prediction) == len(batch['ner'])
        for pred, gold in zip(prediction, batch['ner']):
            metrics(set(pred), set(gold))

    def compute_loss(self, batch):
        """Run a forward pass and return (loss, (concept_correct, concept_total), rel_accuracy)."""
        concept_loss, arc_loss, rel_loss, graph_arc_loss = self.model(batch)
        concept_loss, concept_correct, concept_total = concept_loss
        if rel_loss is not None:
            rel_loss, rel_correct, rel_total = rel_loss
            loss = concept_loss + arc_loss + rel_loss
            rel_acc = (rel_correct, rel_total)
        else:
            loss = concept_loss + arc_loss
            rel_acc = None
        return loss, (concept_correct, concept_total), rel_acc

    def debug_assert_batch_equal(self, batch, gold):
        # Debugging helper: assert that every tensor in ``batch`` matches its counterpart in ``gold``.
        for k, v in gold.items():
            pred = batch.get(k, None)
            if pred is not None and isinstance(v, torch.Tensor):
                assert torch.equal(pred, v), f'{k} not equal'

    @torch.no_grad()
    def evaluate_dataloader(self,
                            data: PrefetchDataLoader,
                            logger,
                            input,
                            output=False,
                            ratio_width=None,
                            save_dir=None,
                            use_fast=False,
                            test=False,
                            metric: SmatchScores = None,
                            model=None,
                            h=None,
                            **kwargs):
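        """Decode ``data`` with the current (or a given) model, post-process the predictions,
        and score them with Smatch. If Smatch fails, which is common on the erroneous outputs
        of early epochs, ``nan`` F1 scores are returned instead of raising.
        """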
        if not model:
            model = self.model
        model.eval()
        pp = PostProcessor(self.vocabs['rel'])
        if not save_dir:
            save_dir = tempdir(str(os.getpid()))
        if not output:
            output = os.path.join(save_dir, os.path.basename(input) + '.pred')
        # Squeezing tokens and concepts into one transformer reduces the maximum input length it
        # can handle, so use fewer decoding time steps in that case.
        parse_data(model,
                   pp,
                   data,
                   input,
                   output,
                   max_time_step=80 if model.squeeze else 100,
                   h=h)
        # noinspection PyBroadException
        try:
            output = post_process(output,
                                  amr_version=self.config.get(
                                      'amr_version', '2.0'))
            scores = smatch_eval(output,
                                 input.replace('.features.preproc', ''),
                                 use_fast=use_fast)
            if metric:
                metric.clear()
                if isinstance(scores, F1_):
                    metric['Smatch'] = scores
                else:
                    metric.update(scores)
        except Exception:
            eprint('Evaluation failed due to the following error:')
            traceback.print_exc()
            eprint(
                'As smatch usually fails on the erroneous outputs produced in early epochs, '
                'it is often safe to ignore this. `nan` will be returned as the score.'
            )
            scores = F1_(float("nan"), float("nan"), float("nan"))
            if metric:
                metric.clear()
                metric['Smatch'] = scores
        if logger:
            header = f'{len(data)}/{len(data)}'
            if not ratio_width:
                ratio_width = len(header)
            logger.info(header.rjust(ratio_width) + f' {scores}')
        if test:
            data.close()
        return scores

    def build_model(self, training=True, **kwargs) -> torch.nn.Module:
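        # Instantiate the pretrained encoder declared in the config and plug it into the AMR model.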
        transformer = self.config.encoder.module()
        model = GraphAbstractMeaningRepresentationModel(
            self.vocabs,
            **merge_dict(self.config, overwrite=True, encoder=transformer),
            tokenizer=self.config.encoder.transform())
        return model

    def build_dataloader(self,
                         data,
                         batch_size,
                         shuffle=False,
                         device=None,
                         logger: logging.Logger = None,
                         gradient_accumulation=1,
                         batch_max_tokens=None,
                         **kwargs) -> DataLoader:
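        # Sort samples by length and pack them into batches bounded by a token budget, then wrap
        # the loader with prefetching and on-device batchification.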
        dataset, lens = self.build_dataset(
            data,
            logger,
            training=shuffle,
            transform=self.config.encoder.transform())
        if batch_max_tokens:
            batch_max_tokens //= gradient_accumulation
            if not shuffle:
                batch_max_tokens //= 2
        sampler = SortingSampler(lens,
                                 batch_size=None,
                                 batch_max_tokens=batch_max_tokens,
                                 shuffle=shuffle)
        dataloader = PrefetchDataLoader(
            DataLoader(batch_sampler=sampler,
                       dataset=dataset,
                       collate_fn=merge_list_of_dict,
                       num_workers=0),
            batchify=self.build_batchify(device, shuffle),
            prefetch=10 if isinstance(data, str) else None)
        return dataloader

    def build_batchify(self, device, shuffle, shuffle_sibling=None):
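        # Returns a batchify function bound to the vocabs, tokenizer and device. ``shuffle_sibling``
        # defaults to ``shuffle`` and controls whether sibling nodes are shuffled when the graph is
        # linearized (switched off late in training; see fit_dataloader).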
        if shuffle_sibling is None:
            shuffle_sibling = shuffle
        tokenizer = self.config.encoder.transform() if self.config.get(
            'encoder', None) else None
        return functools.partial(
            batchify,
            vocabs=self.vocabs,
            squeeze=self.config.get('squeeze', None),
            tokenizer=tokenizer,
            levi_graph=self.config.get('levi_graph', False),
            bart=self.config.get('bart', False),
            extra_arc=self.config.get('extra_arc', False),
            unk_rate=self.config.unk_rate if shuffle else 0,
            shuffle_sibling=shuffle_sibling,
            device=device)

    def build_dataset(self,
                      data,
                      logger: logging.Logger = None,
                      training=True,
                      transform=None):
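        # Build the dataset, fit the vocabularies on the first pass if they are still mutable,
        # attach the concept/BOS/tokenizer transforms, and pre-compute sample lengths for batching.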
        dataset = AbstractMeaningRepresentationDataset(
            data, generate_idx=not training)
        if self.vocabs.mutable:
            self.build_vocabs(dataset, logger)
            self.vocabs.lock()
            self.vocabs.summary(logger)
        lens = [len(x['token']) + len(x.get('amr', [])) for x in dataset]
        dataset.append_transform(
            functools.partial(get_concepts,
                              vocab=self.vocabs.predictable_concept,
                              rel_vocab=self.vocabs.rel if self.config.get(
                                  'separate_rel', False) else None))
        dataset.append_transform(append_bos)
        if transform:
            dataset.append_transform(transform)
        if isinstance(data, str):
            dataset.purge_cache()
            timer = CountdownTimer(len(dataset))
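            # Iterating once forces every transformed sample to be computed and cached.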
            for each in dataset:
                timer.log(
                    'Caching samples [blink][yellow]...[/yellow][/blink]')
        return dataset, lens

    def build_vocabs(self, dataset, logger: logging.Logger = None, **kwargs):
        # Collect concepts and relations from the training data.
        conc = []
        rel = []
        # concepts that cannot be generated by copying lemmas (e.g. 'multi-sentence', 'sense-01')
        predictable_conc = []
        tokens = []
        lemmas = []
        poses = []
        ners = []
        repeat = 10
        timer = CountdownTimer(repeat * len(dataset))
        for i in range(repeat):
            # run the random sort 10 times to determine the priorities of different types of edges
            for sample in dataset:
                amr, lem, tok, pos, ner = (sample['amr'], sample['lemma'],
                                           sample['token'], sample['pos'],
                                           sample['ner'])
                concept, edge, not_ok = amr.root_centered_sort()
                lexical_concepts = set()
                for lemma in lem:
                    lexical_concepts.add(lemma + '_')
                    lexical_concepts.add(lemma)

                if i == 0:
                    predictable_conc.append(
                        [c for c in concept if c not in lexical_concepts])
                    conc.append(concept)
                    tokens.append(tok)
                    lemmas.append(lem)
                    poses.append(pos)
                    ners.append(ner)
                    rel.append([e[-1] for e in edge])
                timer.log(
                    'Building vocabs [blink][yellow]...[/yellow][/blink]')

        # make vocabularies
        lemma_vocab, lemma_char_vocab = make_vocab(lemmas, char_level=True)
        conc_vocab, conc_char_vocab = make_vocab(conc, char_level=True)

        predictable_conc_vocab = make_vocab(predictable_conc)
        num_predictable_conc = sum(len(x) for x in predictable_conc)
        num_conc = sum(len(x) for x in conc)
        rel_vocab = make_vocab(rel)
        logger.info(
            f'Predictable concept coverage {num_predictable_conc} / {num_conc} = {num_predictable_conc / num_conc:.2%}'
        )
        vocabs = self.vocabs
        vocab_min_freq = self.config.get('vocab_min_freq', 5)
        vocabs.lemma = VocabWithFrequency(lemma_vocab,
                                          vocab_min_freq,
                                          specials=[CLS])
        vocabs.predictable_concept = VocabWithFrequency(predictable_conc_vocab,
                                                        vocab_min_freq,
                                                        specials=[DUM, END])
        vocabs.concept = VocabWithFrequency(conc_vocab,
                                            vocab_min_freq,
                                            specials=[DUM, END])
        vocabs.rel = VocabWithFrequency(rel_vocab,
                                        vocab_min_freq * 10,
                                        specials=[NIL])
        vocabs.concept_char = VocabWithFrequency(conc_char_vocab,
                                                 vocab_min_freq * 20,
                                                 specials=[CLS, END])

    def predict(self,
                data: Union[str, List[str]],
                batch_size: int = None,
                **kwargs):
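        """Parse pre-tokenized sentences into AMR graphs. Each sentence is a sequence of
        ``(token, lemma)`` pairs; passing a single sentence returns a single ``AMRGraph``,
        while a list of sentences returns a list of graphs.
        """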
        if not data:
            return []
        flat = self.input_is_flat(data)
        if flat:
            data = [data]
        samples = self.build_samples(data)
        dataloader = self.build_dataloader(samples,
                                           device=self.device,
                                           **merge_dict(self.config,
                                                        overwrite=True,
                                                        batch_size=batch_size))
        pp = PostProcessor(self.vocabs['rel'])
        results = list(parse_data_(self.model, pp, dataloader))
        for i, each in enumerate(results):
            amr_graph = AMRGraph(each)
            self.sense_restore.restore_graph(amr_graph)
            results[i] = amr_graph
        if flat:
            return results[0]
        return results

    def input_is_flat(self, data: List):
        return isinstance(data[0], tuple)

    def build_samples(self, data):
        samples = []
        for each in data:
            token, lemma = zip(*each)
            samples.append({'token': list(token), 'lemma': list(lemma)})
        return samples

    def fit(self,
            trn_data,
            dev_data,
            save_dir,
            encoder,
            batch_size=None,
            batch_max_tokens=17776,
            epochs=1000,
            gradient_accumulation=4,
            char2concept_dim=128,
            cnn_filters=((3, 256), ),
            concept_char_dim=32,
            concept_dim=300,
            dropout=0.2,
            embed_dim=512,
            eval_every=20,
            ff_embed_dim=1024,
            graph_layers=2,
            inference_layers=4,
            num_heads=8,
            rel_dim=100,
            snt_layers=4,
            unk_rate=0.33,
            warmup_steps=0.1,
            lr=1e-3,
            transformer_lr=1e-4,
            adam_epsilon=1e-6,
            weight_decay=0,
            grad_norm=1.0,
            shuffle_sibling_steps=0.9,
            vocab_min_freq=5,
            amr_version='2.0',
            devices=None,
            logger=None,
            seed=None,
            **kwargs):
        return super().fit(**merge_locals_kwargs(locals(), kwargs))

    def load_vocabs(self, save_dir, filename='vocabs.json'):
        if hasattr(self, 'vocabs'):
            self.vocabs = VocabDict()
            self.vocabs.load_vocabs(save_dir, filename, VocabWithFrequency)

    def on_config_ready(self, **kwargs):
        super().on_config_ready(**kwargs)
        utils_dir = get_resource(get_amr_utils(self.config.amr_version))
        self.sense_restore = NodeRestore(NodeUtilities.from_json(utils_dir))
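
    # A minimal usage sketch. Assumptions: the surrounding class is exposed as ``AMRParser`` and a
    # HanLP-style ``ContextualWordEmbedding`` encoder is available; the paths are placeholders.
    # Adapt all of these to the actual codebase this snippet comes from.
    #
    #   parser = AMRParser()
    #   parser.fit(trn_data='data/amr/train.features.preproc',
    #              dev_data='data/amr/dev.features.preproc',
    #              save_dir='data/model/amr2.0',
    #              encoder=ContextualWordEmbedding('token', 'bert-base-cased'))
    #   graph = parser.predict([('The', 'the'), ('boy', 'boy'), ('wants', 'want'),
    #                           ('to', 'to'), ('go', 'go'), ('.', '.')])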