Example #1
File: span_rank.py Project: emorynlp/elit
    def fit(self,
            trn_data,
            dev_data,
            save_dir,
            embed,
            context_layer,
            batch_size=40,
            batch_max_tokens=700,
            lexical_dropout=0.5,
            dropout=0.2,
            span_width_feature_size=20,
            ffnn_size=150,
            ffnn_depth=2,
            argument_ratio=0.8,
            predicate_ratio=0.4,
            max_arg_width=30,
            mlp_label_size=100,
            enforce_srl_constraint=False,
            use_gold_predicates=False,
            doc_level_offset=True,
            use_biaffine=False,
            lr=1e-3,
            transformer_lr=1e-5,
            adam_epsilon=1e-6,
            weight_decay=0.01,
            warmup_steps=0.1,
            grad_norm=5.0,
            gradient_accumulation=1,
            loss_reduction='sum',
            devices=None,
            logger=None,
            seed=None,
            **kwargs):

        return super().fit(**merge_locals_kwargs(locals(), kwargs))
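All of these examples funnel their hyperparameters through merge_locals_kwargs, which flattens the signature's named defaults plus any **kwargs overrides into a single config dict for the base class. The helper's real implementation lives in the elit/HanLP utilities; what follows is only a minimal sketch of the behavior these call sites rely on, with the signature and default excludes inferred from the calls in these examples (an assumption, not the library's actual code):

    def merge_locals_kwargs(locals_dict: dict, kwargs: dict,
                            excludes=('self', 'kwargs', '__class__')) -> dict:
        # Keep every named parameter of the caller except bookkeeping names;
        # '__class__' appears in locals() whenever the method calls super().
        merged = {k: v for k, v in locals_dict.items() if k not in excludes}
        # Fold in the caller's explicit **kwargs (or an external config dict).
        merged.update(kwargs)
        return merged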
Example #2
 def fit(
         self,
         trn_data,
         dev_data,
         save_dir,
         text_a_key=None,
         text_b_key=None,
         label_key=None,
         transformer=None,
         max_seq_length=512,
         truncate_long_sequences=True,
         # hidden_dropout_prob=0.0,
         lr=5e-5,
         transformer_lr=None,
         adam_epsilon=1e-6,
         weight_decay=0,
         warmup_steps=0.1,
         batch_size=32,
         batch_max_tokens=None,
         epochs=3,
         logger=None,
         # transform=None,
         devices: Union[float, int, List[int]] = None,
         **kwargs):
     return super().fit(**merge_locals_kwargs(locals(), kwargs))
Example #3
 def __init__(self,
              trn: str = None,
              dev: str = None,
              tst: str = None,
              sampler_builder: SamplerBuilder = None,
              dependencies: str = None,
              scalar_mix: ScalarMixWithDropoutBuilder = None,
              use_raw_hidden_states=False,
              lr=2e-3, separate_optimizer=False,
              punct=False,
              tree=True,
              pad_rel=None,
              apply_constraint=False,
              single_root=True,
              no_zero_head=None,
              n_mlp_arc=500,
              n_mlp_rel=100,
              mlp_dropout=.33,
              mu=.9,
              nu=.9,
              epsilon=1e-12,
              decay=.75,
              decay_steps=5000,
              cls_is_bos=True,
              use_pos=False,
              **kwargs) -> None:
     super().__init__(**merge_locals_kwargs(locals(), kwargs))
     self.vocabs = VocabDict()
Example #4
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         epochs=5,
         append_after_sentence=None,
         eos_chars=None,
         eos_char_min_freq=200,
         eos_char_is_punct=True,
         char_min_freq=None,
         window_size=5,
         batch_size=32,
         lr=0.001,
         grad_norm=None,
         loss_reduction='sum',
         embedding_size=128,
         rnn_type: str = 'LSTM',
         rnn_size=256,
         rnn_layers=1,
         rnn_bidirectional=False,
         dropout=0.2,
         devices=None,
         logger=None,
         seed=None,
         **kwargs):
     return super().fit(**merge_locals_kwargs(locals(), kwargs))
Example #5
 def distill(self,
             teacher: str,
             trn_data,
             dev_data,
             save_dir,
             batch_size=None,
             epochs=None,
             kd_criterion='kd_ce_loss',
             temperature_scheduler='flsw',
             devices=None,
             logger=None,
             seed=None,
             **kwargs):
     devices = devices or cuda_devices()
     if isinstance(kd_criterion, str):
         kd_criterion = KnowledgeDistillationLoss(kd_criterion)
     if isinstance(temperature_scheduler, str):
         temperature_scheduler = TemperatureScheduler.from_name(temperature_scheduler)
     teacher = self.build_teacher(teacher, devices=devices)
     self.vocabs = teacher.vocabs
     config = copy(teacher.config)
     batch_size = batch_size or config.get('batch_size', None)
     epochs = epochs or config.get('epochs', None)
     config.update(kwargs)
     return super().fit(**merge_locals_kwargs(locals(),
                                              config,
                                              excludes=('self', 'kwargs', '__class__', 'config')))
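Here the teacher's config is merged in as a second positional dict, and 'config' joins the excludes so the dict itself is not also passed through as a keyword. A toy run of that flattening, with stand-in keys and the merge semantics of the sketch above (again an assumption; the keys are chosen not to overlap, so the result holds under either merge order):

    def toy_distill(batch_size, epochs, config):
        # Mimics merge_locals_kwargs(locals(), config, excludes=(..., 'config'))
        merged = {k: v for k, v in locals().items() if k != 'config'}
        merged.update(config)
        return merged

    cfg = {'temperature': 2.0, 'kd_criterion': 'kd_ce_loss'}
    print(toy_distill(batch_size=32, epochs=None, config=cfg))
    # {'batch_size': 32, 'epochs': None, 'temperature': 2.0, 'kd_criterion': 'kd_ce_loss'}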
Example #6
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         embed: Embedding,
         context_layer,
         sampler='sorting',
         n_buckets=32,
         batch_size=50,
         lexical_dropout=0.5,
         ffnn_size=150,
         is_flat_ner=True,
         doc_level_offset=True,
         lr=1e-3,
         transformer_lr=1e-5,
         adam_epsilon=1e-6,
         weight_decay=0.01,
         warmup_steps=0.1,
         grad_norm=5.0,
         epochs=50,
         loss_reduction='sum',
         gradient_accumulation=1,
         ret_tokens=True,
         tagset=None,
         sampler_builder=None,
         devices=None,
         logger=None,
         seed=None,
         **kwargs
         ):
     return super().fit(**merge_locals_kwargs(locals(), kwargs))
Example #7
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         batch_size=50,
         epochs=100,
         embed=100,
         rnn_input=None,
         rnn_hidden=256,
         drop=0.5,
         lr=0.001,
         patience=10,
         crf=True,
         optimizer='adam',
         token_key='token',
         tagging_scheme=None,
         anneal_factor: float = 0.5,
         delimiter=None,
         anneal_patience=2,
         devices=None,
         token_delimiter=None,
         logger=None,
         verbose=True,
         **kwargs):
     return super().fit(**merge_locals_kwargs(locals(), kwargs))
Example #8
 def __init__(self,
              trn: str = None,
              dev: str = None,
              tst: str = None,
              sampler_builder: SamplerBuilder = None,
              dependencies: str = None,
              scalar_mix: ScalarMixWithDropoutBuilder = None,
              use_raw_hidden_states=False,
              lr=None,
              separate_optimizer=False,
              cls_is_bos=True,
              sep_is_eos=True,
              delete=('', ':', '``', "''", '.', '?', '!', '-NONE-', 'TOP',
                      ',', 'S1'),
              equal=(('ADVP', 'PRT'), ),
              mbr=True,
              n_mlp_span=500,
              n_mlp_label=100,
              mlp_dropout=.33,
              no_subcategory=True,
              **kwargs) -> None:
     if isinstance(equal, tuple):
         equal = dict(equal)
     super().__init__(**merge_locals_kwargs(locals(), kwargs))
     self.vocabs = VocabDict()
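The default for ``equal`` is a tuple of pairs precisely so that ``dict(equal)`` can turn it into a label-equivalence mapping before the merge; for the default value above:

    >>> dict((('ADVP', 'PRT'),))
    {'ADVP': 'PRT'}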
Example #9
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         transformer=None,
         lr=5e-5,
         transformer_lr=None,
         adam_epsilon=1e-8,
         weight_decay=0,
         warmup_steps=0.1,
         batch_size=32,
         gradient_accumulation=1,
         grad_norm=5.0,
         transformer_grad_norm=None,
         average_subwords=False,
         scalar_mix: Union[ScalarMixWithDropoutBuilder, int] = None,
         word_dropout=None,
         hidden_dropout=None,
         max_sequence_length=None,
         ret_raw_hidden_states=False,
         batch_max_tokens=None,
         epochs=3,
         logger=None,
         devices: Union[float, int, List[int]] = None,
         **kwargs):
     return super().fit(**merge_locals_kwargs(locals(), kwargs))
Example #10
 def __init__(self,
              trn: str = None,
              dev: str = None,
              tst: str = None,
              sampler_builder: SamplerBuilder = None,
              dependencies: str = None,
              scalar_mix: ScalarMixWithDropoutBuilder = None,
              use_raw_hidden_states=False,
              lr=1e-3,
              separate_optimizer=False,
              lexical_dropout=0.5,
              dropout=0.2,
              span_width_feature_size=20,
              ffnn_size=150,
              ffnn_depth=2,
              argument_ratio=0.8,
              predicate_ratio=0.4,
              max_arg_width=30,
              mlp_label_size=100,
              enforce_srl_constraint=False,
              use_gold_predicates=False,
              doc_level_offset=True,
              use_biaffine=False,
              loss_reduction='mean',
              with_argument=' ',
              **kwargs) -> None:
     super().__init__(**merge_locals_kwargs(locals(), kwargs))
     self.vocabs = VocabDict()
Example #11
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         encoder,
         lr=5e-5,
         transformer_lr=None,
         adam_epsilon=1e-8,
         weight_decay=0,
         warmup_steps=0.1,
         grad_norm=1.0,
         n_mlp_span=500,
         n_mlp_label=100,
         mlp_dropout=.33,
         batch_size=None,
         batch_max_tokens=5000,
         gradient_accumulation=1,
         epochs=30,
         patience=0.5,
         mbr=True,
         sampler_builder=None,
         delete=('', ':', '``', "''", '.', '?', '!', '-NONE-', 'TOP', ',',
                 'S1'),
         equal=(('ADVP', 'PRT'), ),
         no_subcategory=True,
         eval_trn=True,
         transform=None,
         devices=None,
         logger=None,
         seed=None,
         **kwargs):
     if isinstance(equal, tuple):
         equal = dict(equal)
     return super().fit(**merge_locals_kwargs(locals(), kwargs))
Example #12
 def fit(
         self,
         encoder: Embedding,
         tasks: Dict[str, Task],
         save_dir,
         epochs,
         patience=0.5,
         lr=1e-3,
         encoder_lr=5e-5,
         adam_epsilon=1e-8,
         weight_decay=0.0,
         warmup_steps=0.1,
         gradient_accumulation=1,
         grad_norm=5.0,
         encoder_grad_norm=None,
         decoder_grad_norm=None,
         tau: float = 0.8,
         transform=None,
         # prune: Callable = None,
         eval_trn=True,
         prefetch=None,
         tasks_need_custom_eval=None,
         _device_placeholder=False,
         devices=None,
         logger=None,
         seed=None,
         **kwargs):
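      # Placeholder locals so trn_data/dev_data/batch_size still appear as keys in the merged config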
     trn_data, dev_data, batch_size = 'trn', 'dev', None
     task_names = list(tasks.keys())
     return super().fit(
         **merge_locals_kwargs(locals(),
                               kwargs,
                               excludes=('self', 'kwargs', '__class__',
                                         'tasks')), **tasks)
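Because 'tasks' sits in the excludes and is then re-spread with **tasks, each task reaches super().fit() as a keyword argument named after the task itself. A toy illustration of that spreading (all names hypothetical):

    def show_fit(**kwargs):
        # Return just the keyword names the callee would see.
        return sorted(kwargs)

    tasks = {'ner': object(), 'pos': object()}
    print(show_fit(save_dir='/tmp/mtl', **tasks))
    # ['ner', 'pos', 'save_dir']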
Example #13
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         encoder,
         batch_size=None,
         batch_max_tokens=17776,
         epochs=1000,
         gradient_accumulation=4,
         char2concept_dim=128,
         char2word_dim=128,
         cnn_filters=((3, 256), ),
         concept_char_dim=32,
         concept_dim=300,
         dropout=0.2,
         embed_dim=512,
         eval_every=20,
         ff_embed_dim=1024,
         graph_layers=2,
         inference_layers=4,
         lr_scale=1.0,
         ner_dim=16,
         num_heads=8,
         pos_dim=32,
         pretrained_file=None,
         rel_dim=100,
         snt_layers=4,
         start_rank=0,
         unk_rate=0.33,
         warmup_steps=2000,
         with_bert=True,
         word_char_dim=32,
         word_dim=300,
         lr=1.,
         transformer_lr=None,
         adam_epsilon=1e-6,
         weight_decay=1e-4,
         grad_norm=1.0,
         joint_arc_concept=False,
         joint_rel=False,
         external_biaffine=False,
         optimize_every_layer=False,
         squeeze=False,
         levi_graph=False,
         separate_rel=False,
         extra_arc=False,
         bart=False,
         shuffle_sibling_steps=50000,
         vocab_min_freq=5,
         amr_version='2.0',
         devices=None,
         logger=None,
         seed=None,
         **kwargs):
     return super().fit(**merge_locals_kwargs(locals(), kwargs))
Example #14
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         feat=None,
         n_embed=100,
         pretrained_embed=None,
         transformer=None,
         average_subwords=False,
         word_dropout: float = 0.2,
         transformer_hidden_dropout=None,
         layer_dropout=0,
         mix_embedding: int = None,
         embed_dropout=.33,
         n_lstm_hidden=400,
         n_lstm_layers=3,
         hidden_dropout=.33,
         n_mlp_arc=500,
         n_mlp_rel=100,
         mlp_dropout=.33,
         arc_dropout=None,
         rel_dropout=None,
         arc_loss_interpolation=0.4,
         lr=2e-3,
         transformer_lr=5e-5,
         mu=.9,
         nu=.9,
         epsilon=1e-12,
         clip=5.0,
         decay=.75,
         decay_steps=5000,
         weight_decay=0,
         warmup_steps=0.1,
         separate_optimizer=True,
         patience=100,
         batch_size=None,
         sampler_builder=None,
         lowercase=False,
         epochs=50000,
         apply_constraint=False,
         single_root=None,
         no_zero_head=None,
         punct=False,
         min_freq=2,
         logger=None,
         verbose=True,
         unk=UNK,
         pad_rel=None,
         max_sequence_length=512,
         gradient_accumulation=1,
         devices: Union[float, int, List[int]] = None,
         transform=None,
         **kwargs):
     return super().fit(**merge_locals_kwargs(locals(), kwargs))
Example #15
 def distill(self,
             teacher: str,
             trn_data,
             dev_data,
             save_dir,
             transformer: str,
             batch_size=None,
             temperature_scheduler='flsw',
             epochs=None,
             devices=None,
             logger=None,
             seed=None,
             **kwargs):
     return super().distill(**merge_locals_kwargs(locals(), kwargs))
Example #16
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         feat=None,
         n_embed=100,
         pretrained_embed=None,
         transformer=None,
         average_subwords=False,
         word_dropout: float = 0.2,
         transformer_hidden_dropout=None,
         layer_dropout=0,
         scalar_mix: int = None,
         embed_dropout=.33,
         n_lstm_hidden=400,
         n_lstm_layers=3,
         hidden_dropout=.33,
         n_mlp_arc=500,
         n_mlp_rel=100,
         mlp_dropout=.33,
         lr=2e-3,
         transformer_lr=5e-5,
         mu=.9,
         nu=.9,
         epsilon=1e-12,
         clip=5.0,
         decay=.75,
         decay_steps=5000,
         patience=100,
         batch_size=None,
         sampler_builder=None,
         lowercase=False,
         epochs=50000,
         tree=False,
         punct=False,
         min_freq=2,
         apply_constraint=True,
         joint=False,
         no_cycle=False,
         root=None,
         logger=None,
         verbose=True,
         unk=UNK,
         pad_rel=None,
         max_sequence_length=512,
         devices: Union[float, int, List[int]] = None,
         transform=None,
         **kwargs):
     return super().fit(**merge_locals_kwargs(locals(), kwargs))
Example #17
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         embed,
         n_mlp_arc=500,
         n_mlp_rel=100,
         n_mlp_sib=100,
         mlp_dropout=.33,
         lr=2e-3,
         transformer_lr=5e-5,
         mu=.9,
         nu=.9,
         epsilon=1e-12,
         grad_norm=5.0,
         decay=.75,
         decay_steps=5000,
         weight_decay=0,
         warmup_steps=0.1,
         separate_optimizer=True,
         patience=100,
         lowercase=False,
         epochs=50000,
         tree=False,
         proj=True,
         mbr=True,
         partial=False,
         punct=False,
         min_freq=2,
         logger=None,
         verbose=True,
         unk=UNK,
         max_sequence_length=512,
         batch_size=None,
         sampler_builder=None,
         gradient_accumulation=1,
         devices: Union[float, int, List[int]] = None,
         transform=None,
         eval_trn=False,
         bos='\0',
         **kwargs):
     return super().fit(**merge_locals_kwargs(locals(), kwargs))
Example #18
 def __init__(self,
              trn: str = None,
              dev: str = None,
              tst: str = None,
              sampler_builder: SamplerBuilder = None,
              dependencies: str = None,
              scalar_mix: ScalarMixWithDropoutBuilder = None,
              use_raw_hidden_states=False,
              lr=1e-3,
              separate_optimizer=False,
              cls_is_bos=False,
              sep_is_eos=False,
              delimiter=None,
              max_seq_len=None,
              sent_delimiter=None,
              char_level=False,
              hard_constraint=False,
              token_key='token',
              **kwargs) -> None:
     super().__init__(**merge_locals_kwargs(locals(), kwargs))
     self.vocabs = VocabDict()
Example #19
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         transformer=None,
         mask_prob=0.15,
         projection=None,
         average_subwords=False,
         transformer_hidden_dropout=None,
         layer_dropout=0,
         mix_embedding: int = None,
         embed_dropout=.33,
         n_mlp_arc=500,
         n_mlp_rel=100,
         mlp_dropout=.33,
         lr=2e-3,
         transformer_lr=5e-5,
         mu=.9,
         nu=.9,
         epsilon=1e-12,
         clip=5.0,
         decay=.75,
         decay_steps=5000,
         patience=100,
         sampler='kmeans',
         n_buckets=32,
         batch_max_tokens=5000,
         batch_size=None,
         epochs=50000,
         tree=False,
         punct=False,
         logger=None,
         verbose=True,
         max_sequence_length=512,
         devices: Union[float, int, List[int]] = None,
         transform=None,
         **kwargs):
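      # Calls TorchComponent.fit directly rather than super().fit, skipping any intermediate override in the MRO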
     return TorchComponent.fit(self,
                               **merge_locals_kwargs(locals(), kwargs))
Example #20
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         transformer,
         average_subwords=False,
         word_dropout: float = 0.2,
         hidden_dropout=None,
         layer_dropout=0,
         scalar_mix=None,
         mix_embedding: int = 0,
         grad_norm=5.0,
         transformer_grad_norm=None,
         lr=5e-5,
         transformer_lr=None,
         transformer_layers=None,
         gradient_accumulation=1,
         adam_epsilon=1e-6,
         weight_decay=0,
         warmup_steps=0.1,
         secondary_encoder=None,
         crf=False,
         reduction='sum',
         batch_size=32,
         sampler_builder: SamplerBuilder = None,
         epochs=3,
         patience=5,
         token_key=None,
         delimiter=None,
         max_seq_len=None,
         sent_delimiter=None,
         char_level=False,
         hard_constraint=False,
         transform=None,
         logger=None,
         devices: Union[float, int, List[int]] = None,
         **kwargs):
     return super().fit(**merge_locals_kwargs(locals(), kwargs))
Example #21
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         encoder,
         batch_size=None,
         batch_max_tokens=17776,
         epochs=1000,
         gradient_accumulation=4,
         char2concept_dim=128,
         cnn_filters=((3, 256), ),
         concept_char_dim=32,
         concept_dim=300,
         dropout=0.2,
         embed_dim=512,
         eval_every=20,
         ff_embed_dim=1024,
         graph_layers=2,
         inference_layers=4,
         num_heads=8,
         rel_dim=100,
         snt_layers=4,
         unk_rate=0.33,
         warmup_steps=0.1,
         lr=1e-3,
         transformer_lr=1e-4,
         adam_epsilon=1e-6,
         weight_decay=0,
         grad_norm=1.0,
         shuffle_sibling_steps=0.9,
         vocab_min_freq=5,
         amr_version='2.0',
         devices=None,
         logger=None,
         seed=None,
         **kwargs):
     return super().fit(**merge_locals_kwargs(locals(), kwargs))
Example #22
 def __init__(self,
              trn: str = None,
              dev: str = None,
              tst: str = None,
              sampler_builder: SamplerBuilder = None,
              dependencies: str = None,
              scalar_mix: ScalarMixWithDropoutBuilder = None,
              use_raw_hidden_states=False,
              lr=1e-3,
              separate_optimizer=False,
              cls_is_bos=True,
              sep_is_eos=False,
              char2concept_dim=128,
              cnn_filters=((3, 256), ),
              concept_char_dim=32,
              concept_dim=300,
              dropout=0.2,
              embed_dim=512,
              eval_every=20,
              ff_embed_dim=1024,
              graph_layers=2,
              inference_layers=4,
              num_heads=8,
              rel_dim=100,
              snt_layers=4,
              unk_rate=0.33,
              vocab_min_freq=5,
              beam_size=8,
              alpha=0.6,
              max_time_step=100,
              amr_version='2.0',
              **kwargs) -> None:
     super().__init__(**merge_locals_kwargs(locals(), kwargs))
     self.vocabs = VocabDict()
     utils_dir = get_resource(get_amr_utils(amr_version))
     self.sense_restore = NodeRestore(NodeUtilities.from_json(utils_dir))
Example #23
    def fit(self,
            trn_data,
            dev_data,
            save_dir,
            batch_size=32,
            epochs=30,
            transformer='facebook/bart-base',
            lr=5e-05,
            grad_norm=2.5,
            weight_decay=0.004,
            warmup_steps=1,
            dropout=0.25,
            attention_dropout=0.0,
            pred_min=5,
            eval_after=0.5,
            collapse_name_ops=False,
            use_pointer_tokens=True,
            raw_graph=False,
            gradient_accumulation=1,
            recategorization_tokens=('PERSON', 'COUNTRY', 'QUANTITY',
                                     'ORGANIZATION', 'DATE_ATTRS',
                                     'NATIONALITY', 'LOCATION', 'ENTITY',
                                     'CITY', 'MISC', 'ORDINAL_ENTITY',
                                     'IDEOLOGY', 'RELIGION',
                                     'STATE_OR_PROVINCE', 'URL',
                                     'CAUSE_OF_DEATH', 'O', 'TITLE', 'DATE',
                                     'NUMBER', 'HANDLE', 'SCORE_ENTITY',
                                     'DURATION', 'ORDINAL', 'MONEY', 'SET',
                                     'CRIMINAL_CHARGE', '_1', '_2', '_3', '_4',
                                     '_5', '_6', '_7', '_8', '_9', '_10',
                                     '_11', '_12', '_13', '_14', '_15'),
            additional_tokens=(
                'date-entity', 'government-organization', 'temporal-quantity',
                'amr-unknown', 'multi-sentence', 'political-party',
                'monetary-quantity', 'ordinal-entity', 'religious-group',
                'percentage-entity', 'world-region', 'url-entity',
                'political-movement', 'et-cetera', 'at-least', 'mass-quantity',
                'have-org-role-91', 'have-rel-role-91', 'include-91',
                'have-concession-91', 'have-condition-91', 'be-located-at-91',
                'rate-entity-91', 'instead-of-91', 'hyperlink-91',
                'request-confirmation-91', 'have-purpose-91',
                'be-temporally-at-91', 'regardless-91', 'have-polarity-91',
                'byline-91', 'have-manner-91', 'have-part-91', 'have-quant-91',
                'publication-91', 'be-from-91', 'have-mod-91',
                'have-frequency-91', 'score-on-scale-91', 'have-li-91',
                'be-compared-to-91', 'be-destined-for-91', 'course-91',
                'have-subevent-91', 'street-address-91', 'have-extent-91',
                'statistical-test-91', 'have-instrument-91', 'have-name-91',
                'be-polite-91', '-00', '-01', '-02', '-03', '-04', '-05',
                '-06', '-07', '-08', '-09', '-10', '-11', '-12', '-13', '-14',
                '-15', '-16', '-17', '-18', '-19', '-20', '-21', '-22', '-23',
                '-24', '-25', '-26', '-27', '-28', '-29', '-30', '-31', '-32',
                '-33', '-34', '-35', '-36', '-37', '-38', '-39', '-40', '-41',
                '-42', '-43', '-44', '-45', '-46', '-47', '-48', '-49', '-50',
                '-51', '-52', '-53', '-54', '-55', '-56', '-57', '-58', '-59',
                '-60', '-61', '-62', '-63', '-64', '-65', '-66', '-67', '-68',
                '-69', '-70', '-71', '-72', '-73', '-74', '-75', '-76', '-77',
                '-78', '-79', '-80', '-81', '-82', '-83', '-84', '-85', '-86',
                '-87', '-88', '-89', '-90', '-91', '-92', '-93', '-94', '-95',
                '-96', '-97', '-98', '-of'),
            devices=None,
            logger=None,
            seed=None,
            finetune: Union[bool, str] = False,
            eval_trn=True,
            _device_placeholder=False,
            **kwargs):
        """

        Args:
            trn_data:
            dev_data:
            save_dir:
            batch_size:
            epochs:
            transformer:
            lr:
            grad_norm:
            weight_decay:
            warmup_steps:
            dropout:
            attention_dropout:
            pred_min:
            eval_after:
            collapse_name_ops: ``True`` to merge name ops.
            use_pointer_tokens: ``True`` to use pointer tokens to represent variables.
            raw_graph: ``True`` to use the raw graph as input and skip all pre/post-processing steps.
            gradient_accumulation:
            recategorization_tokens: Tokens used in re-categorization. They are added to the tokenizer as well, so do
                not also put them into ``additional_tokens``.
            additional_tokens: Tokens to be added to the tokenizer vocab.
            devices:
            logger:
            seed:
            finetune:
            eval_trn:
            _device_placeholder:
            **kwargs:

        Returns:

        """
        return super().fit(**merge_locals_kwargs(locals(), kwargs))
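Taken together, the pattern is consistent across these examples: each component declares its hyperparameters as named defaults in the signature, and the base class receives them as one flat config dict. A self-contained toy of that shape, reusing the merge_locals_kwargs sketch from Example #1 (it mirrors the examples' structure, not the library's actual classes):

    class Base:
        def fit(self, **config):
            self.config = config  # the base class sees one flat dict
            return config

    class ToyTagger(Base):
        def fit(self, trn_data, dev_data, save_dir, lr=1e-3, epochs=3, **kwargs):
            return super().fit(**merge_locals_kwargs(locals(), kwargs))

    tagger = ToyTagger()
    print(tagger.fit('trn.tsv', 'dev.tsv', '/tmp/model', epochs=10, seed=42))
    # {'trn_data': 'trn.tsv', 'dev_data': 'dev.tsv', 'save_dir': '/tmp/model',
    #  'lr': 0.001, 'epochs': 10, 'seed': 42}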