Example 1
    def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger = None,
                         sampler_builder: SamplerBuilder = None, gradient_accumulation=1,
                         extra_embeddings: Embedding = None, transform=None, **kwargs) -> DataLoader:
        if isinstance(data, TransformableDataset):
            dataset = data
        else:
            # Forward only the dataset-related keys from the component config.
            args = dict((k, self.config.get(k, None)) for k in
                        ['delimiter', 'max_seq_len', 'sent_delimiter', 'char_level', 'hard_constraint'])
            dataset = self.build_dataset(data, **args)
        if self.config.token_key is None:
            # Fall back to the first field of the first sample as the token key.
            self.config.token_key = next(iter(dataset[0]))
            if logger:  # logger defaults to None, so guard before logging
                logger.info(
                    f'Guess [bold][blue]token_key={self.config.token_key}[/blue][/bold] according to the '
                    f'training dataset: [blue]{dataset}[/blue]')
        if transform:
            dataset.append_transform(transform)
        if extra_embeddings:
            dataset.append_transform(extra_embeddings.transform(self.vocabs))
        dataset.append_transform(self.tokenizer_transform)
        dataset.append_transform(self.last_transform())
        if not isinstance(data, list):
            dataset.purge_cache()
        if self.vocabs.mutable:
            self.build_vocabs(dataset, logger)
        if sampler_builder is not None:
            # Bucket samples by subword sequence length; skip gradient
            # accumulation when not shuffling (i.e. during evaluation).
            sampler = sampler_builder.build([len(x[f'{self.config.token_key}_input_ids']) for x in dataset], shuffle,
                                            gradient_accumulation=gradient_accumulation if shuffle else 1)
        else:
            sampler = None
        return PadSequenceDataLoader(dataset, batch_size, shuffle, device=device, batch_sampler=sampler)
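For context, here is a minimal, self-contained sketch of the pad-and-batch idea that PadSequenceDataLoader implements above, written in plain PyTorch. Everything in it (the toy samples and the pad_collate helper) is illustrative, not HanLP API:

    import torch
    from torch.utils.data import DataLoader

    # Toy samples with variable-length 'input_ids', standing in for the dataset above.
    samples = [{'input_ids': torch.tensor(ids)} for ids in ([1, 2, 3], [4, 5], [6])]

    def pad_collate(batch):
        # Pad every 'input_ids' tensor to the longest sequence in the batch.
        lens = [len(s['input_ids']) for s in batch]
        padded = torch.zeros(len(batch), max(lens), dtype=torch.long)
        for i, s in enumerate(batch):
            padded[i, :lens[i]] = s['input_ids']
        return {'input_ids': padded, 'lens': torch.tensor(lens)}

    loader = DataLoader(samples, batch_size=2, collate_fn=pad_collate)
    for batch in loader:
        print(batch['input_ids'].shape)  # torch.Size([2, 3]) then torch.Size([1, 1])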
Example 2
    def build_model(self, training=True, extra_embeddings: Embedding = None, **kwargs) -> torch.nn.Module:
        model = TransformerTaggingModel(
            self.build_transformer(training=training),
            len(self.vocabs.tag),
            self.config.crf,
            self.config.get('secondary_encoder', None),
            # Instantiate the optional extra embedding layer against the vocabs.
            extra_embeddings=extra_embeddings.module(self.vocabs) if extra_embeddings else None,
        )
        return model
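The model above is essentially an encoder followed by a per-token tag projection (plus an optional CRF). A generic, runnable sketch of that pattern, with every name and size invented for illustration:

    import torch

    class TinyTaggingModel(torch.nn.Module):
        def __init__(self, hidden_size: int, n_tags: int):
            super().__init__()
            self.encoder = torch.nn.LSTM(hidden_size, hidden_size, batch_first=True)
            self.classifier = torch.nn.Linear(hidden_size, n_tags)

        def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
            hidden, _ = self.encoder(embeddings)
            return self.classifier(hidden)  # per-token tag scores

    scores = TinyTaggingModel(8, n_tags=5)(torch.randn(2, 4, 8))
    print(scores.shape)  # torch.Size([2, 4, 5])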
Example 3
    def build_model(self, embed: Embedding, encoder, training, **kwargs) -> torch.nn.Module:
        # noinspection PyCallByClass
        model = SpanBIOSemanticRoleLabelingModel(
            embed.module(training=training, vocabs=self.vocabs),
            encoder,
            len(self.vocabs.srl),
            self.config.n_mlp_rel,
            self.config.mlp_dropout,
            self.config.crf,
        )
        return model
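Span-based SRL models of this kind typically score (predicate, argument) token pairs with MLP heads feeding a biaffine layer; n_mlp_rel above sizes those MLPs. A minimal sketch of a biaffine pairwise scorer, assuming nothing about the real model beyond that general pattern:

    import torch

    class Biaffine(torch.nn.Module):
        def __init__(self, n_in: int, n_out: int):
            super().__init__()
            # +1 on both inputs adds bias terms, a common biaffine trick.
            self.weight = torch.nn.Parameter(torch.zeros(n_out, n_in + 1, n_in + 1))

        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            ones = torch.ones(*x.shape[:-1], 1)
            x = torch.cat([x, ones], dim=-1)
            y = torch.cat([y, ones], dim=-1)
            # [batch, n_out, seq, seq]: one score per label per token pair.
            return torch.einsum('bxi,oij,byj->boxy', x, self.weight, y)

    scores = Biaffine(4, n_out=3)(torch.randn(2, 5, 4), torch.randn(2, 5, 4))
    print(scores.shape)  # torch.Size([2, 3, 5, 5])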
Example 4
    def build_dataloader(self,
                         data,
                         shuffle,
                         device,
                         embed: Embedding,
                         training=False,
                         logger=None,
                         gradient_accumulation=1,
                         sampler_builder=None,
                         batch_size=None,
                         bos='\0',
                         **kwargs) -> DataLoader:
        # Prepend a BOS placeholder so every sentence has a root position.
        first_transform = TransformList(functools.partial(append_bos, bos=bos))
        embed_transform = embed.transform(vocabs=self.vocabs)
        transformer_transform = self._get_transformer_transform_from_transforms(embed_transform)
        if embed_transform:
            if transformer_transform and isinstance(embed_transform, TransformList):
                # Defer the transformer transform until the vocabs are built.
                embed_transform.remove(transformer_transform)

            first_transform.append(embed_transform)
        dataset = self.build_dataset(data, first_transform=first_transform)
        if self.config.get('transform', None):
            dataset.append_transform(self.config.transform)

        if self.vocabs.mutable:
            self.build_vocabs(dataset, logger, self._transformer_trainable())
        if transformer_transform and isinstance(embed_transform, TransformList):
            # Re-attach it now that the vocabs are frozen.
            embed_transform.append(transformer_transform)

        dataset.append_transform(FieldLength('token', 'sent_length'))
        if isinstance(data, str):
            dataset.purge_cache()
        if len(dataset) > 1000 and isinstance(data, str):
            timer = CountdownTimer(len(dataset))
            self.cache_dataset(dataset, timer, training, logger)
        if sampler_builder:
            lens = [sample['sent_length'] for sample in dataset]
            sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
        else:
            sampler = None
        loader = PadSequenceDataLoader(dataset=dataset,
                                       batch_sampler=sampler,
                                       batch_size=batch_size,
                                       pad=self.get_pad_dict(),
                                       device=device,
                                       vocabs=self.vocabs)
        return loader
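The sampler_builder above exists so that batches group sentences of similar length and waste less padding. A toy, dependency-free sketch of that idea (HanLP's actual samplers are more sophisticated):

    # Pretend per-sample 'sent_length' values, as collected into `lens` above.
    lengths = [3, 9, 2, 8, 4, 10]

    def length_bucketed_batches(lengths, batch_size):
        # Sort indices by length so each batch pads to a similar size.
        order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

    print(length_bucketed_batches(lengths, batch_size=2))  # [[2, 0], [4, 3], [1, 5]]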
Example 5
    def build_model(self,
                    embed: Embedding,
                    encoder,
                    n_mlp_arc,
                    n_mlp_rel,
                    mlp_dropout,
                    n_mlp_sib,
                    training=True,
                    **kwargs) -> torch.nn.Module:
        model = DependencyModel(embed=embed.module(vocabs=self.vocabs),
                                encoder=encoder,
                                decoder=TreeCRFDecoder(
                                    encoder.get_output_dim(), n_mlp_arc,
                                    n_mlp_sib, n_mlp_rel, mlp_dropout,
                                    len(self.vocabs['rel'])))
        return model
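For contrast with the TreeCRFDecoder above, which scores all candidate trees jointly, here is the naive baseline it improves on: picking each token's head by independent argmax over arc scores. The values are random for illustration, and nothing guarantees the result is a valid tree:

    import torch

    arc_scores = torch.randn(1, 5, 5)  # [batch, dependent, head]
    heads = arc_scores.argmax(dim=-1)  # greedy head per token, no tree constraint
    print(heads)  # e.g. tensor([[3, 0, 4, 1, 2]])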