Example #1
 def __init__(self,
              dataset,
              batch_size=32,
              shuffle=False,
              sampler=None,
              batch_sampler=None,
              num_workers=0,
              collate_fn=None,
              pin_memory=False,
              drop_last=False,
              timeout=0,
              worker_init_fn=None,
              multiprocessing_context=None,
              pad: dict = None,
              vocabs: VocabDict = None,
              device=None,
              **kwargs):
     if device == -1:
         device = None
     if collate_fn is None:
         collate_fn = self.collate_fn
     if num_workers is None:
         if isdebugging():
             num_workers = 0
         else:
             num_workers = 2
     if batch_sampler is None:
         assert batch_size, 'batch_size has to be specified when batch_sampler is None'
     else:
          # batch_size, shuffle and drop_last are mutually exclusive with batch_sampler
          batch_size = 1
         shuffle = None
         drop_last = None
     # noinspection PyArgumentList
     super(PadSequenceDataLoader,
           self).__init__(dataset=dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          sampler=sampler,
                          batch_sampler=batch_sampler,
                          num_workers=num_workers,
                          collate_fn=collate_fn,
                          pin_memory=pin_memory,
                          drop_last=drop_last,
                          timeout=timeout,
                          worker_init_fn=worker_init_fn,
                          multiprocessing_context=multiprocessing_context,
                          **kwargs)
     self.vocabs = vocabs
     if isinstance(dataset, TransformDataset) and dataset.transform:
         transform = dataset.transform
         if not isinstance(transform, TransformList):
             # wrap a single transform in a list so the scan below still sees it
             transform = [transform]
         for each in transform:
             if isinstance(each, EmbeddingNamedTransform):
                 if pad is None:
                     pad = {}
                 if each.dst not in pad:
                     pad[each.dst] = 0
     self.pad = pad
     self.device = device
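
A minimal usage sketch for the loader above; the toy dataset, the field name and the pad value are made up, and the exact padding behaviour of the default collate_fn is assumed rather than documented here:

 # Hypothetical usage: pad variable-length 'token_id' sequences into equal-length batches.
 dataset = [{'token_id': [1, 2, 3]}, {'token_id': [4, 5]}, {'token_id': [6]}]
 loader = PadSequenceDataLoader(dataset,
                                batch_size=2,
                                pad={'token_id': 0},  # assumed: field name -> padding value
                                device=None)          # None: tensors are not moved to a device
 for batch in loader:
     print(batch['token_id'])
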
Example #2
 def to(self,
        devices: Union[int, float, List[int],
                       Dict[str, Union[int, torch.device]]] = None,
        logger: logging.Logger = None):
     if devices == -1 or devices == [-1]:
         devices = []
     elif isinstance(devices, (int, float)) or devices is None:
         devices = cuda_devices(devices)
     if devices:
         if logger:
             logger.info(
                 f'Using GPUs: [on_blue][cyan][bold]{devices}[/bold][/cyan][/on_blue]'
             )
         if isinstance(devices, list):
             flash(
                 f'Moving model to GPUs {devices} [blink][yellow]...[/yellow][/blink]'
             )
             self.model = self.model.to(devices[0])
             if len(devices) > 1 and not isdebugging() and not isinstance(
                     self.model, nn.DataParallel):
                 self.model = self.parallelize(devices)
         elif isinstance(devices, dict):
             for name, module in self.model.named_modules():
                 for regex, device in devices.items():
                     try:
                         on_device: torch.device = next(
                             module.parameters()).device
                     except StopIteration:
                         continue
                     if on_device == device:
                         continue
                     if isinstance(device, int):
                         if on_device.index == device:
                             continue
                     if re.match(regex, name):
                         if not name:
                             name = '*'
                         flash(
                             f'Moving module [yellow]{name}[/yellow] to [on_yellow][magenta][bold]{device}'
                             f'[/bold][/magenta][/on_yellow]: [red]{regex}[/red]\n'
                         )
                         module.to(device)
         else:
             raise ValueError(f'Unrecognized devices {devices}')
         flash('')
     else:
         if logger:
             logger.info('Using CPU')
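
The dict branch above moves sub-modules to devices by matching a regular expression against each qualified module name; a hedged call sketch, where the component instance and the module-name patterns are hypothetical:

 # Hypothetical: place the encoder on GPU 0 and the decoder on GPU 1 via module-name regexes.
 component.to(devices={r'encoder.*': 0, r'decoder.*': 1},
              logger=logging.getLogger(__name__))
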
Example #3
    def __init__(self, dataloader, prefetch=10, batchify=None) -> None:
        """
        PrefetchDataLoader only works in spawn mode with the following initialization code:

        if __name__ == '__main__':
            import torch

            torch.multiprocessing.set_start_method('spawn')

        These two lines must be placed inside the `if __name__ == '__main__':` block.

        Args:
            dataloader: The DataLoader to be wrapped and prefetched from.
            prefetch: Number of batches to prefetch; disabled (set to None) when a debugger is attached.
            batchify: An optional callable for assembling the prefetched samples into batches.
        """
        super().__init__(dataset=dataloader)
        self._batchify = batchify
        self.prefetch = None if isdebugging() else prefetch
        if self.prefetch:
            self._fire_process(dataloader, prefetch)
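
Per the docstring above, an entry point that drives this loader in spawn mode might look as follows; make_dataloader is a placeholder for whatever builds the wrapped DataLoader:

 if __name__ == '__main__':
     import torch
     torch.multiprocessing.set_start_method('spawn')
     # Wrap an existing DataLoader so batches are prepared ahead of time in a background process.
     prefetched = PrefetchDataLoader(make_dataloader(), prefetch=10)
     for batch in prefetched:
         ...
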
Example #4
 def build_dataloader(self,
                      data,
                      shuffle,
                      device,
                      training=False,
                      logger=None,
                      gradient_accumulation=1,
                      sampler_builder=None,
                      batch_size=None,
                      **kwargs) -> DataLoader:
     dataset = self.build_dataset(data)
     if self.vocabs.mutable:
         self.build_vocabs(dataset, logger, self.config.transformer)
     transformer_tokenizer = self.transformer_tokenizer
     if transformer_tokenizer:
         dataset.transform.append(self.build_tokenizer_transform())
     dataset.append_transform(FieldLength('token', 'sent_length'))
     if isinstance(data, str):
         dataset.purge_cache()
     if len(dataset) > 1000 and isinstance(data, str):
         timer = CountdownTimer(len(dataset))
         self.cache_dataset(dataset, timer, training, logger)
     if self.config.transformer:
         lens = [len(sample['input_ids']) for sample in dataset]
     else:
         lens = [sample['sent_length'] for sample in dataset]
     if sampler_builder:
         sampler = sampler_builder.build(lens, shuffle,
                                         gradient_accumulation)
     else:
         sampler = None
     loader = PadSequenceDataLoader(dataset=dataset,
                                    batch_sampler=sampler,
                                    batch_size=batch_size,
                                    num_workers=0 if isdebugging() else 2,
                                    pad=self.get_pad_dict(),
                                    device=device,
                                    vocabs=self.vocabs)
     return loader
Example #5
 def __init__(self,
              dataset,
              batch_size=32,
              shuffle=False,
              sampler=None,
              batch_sampler=None,
              num_workers=None,
              collate_fn=None,
              pin_memory=False,
              drop_last=False,
              timeout=0,
              worker_init_fn=None,
              multiprocessing_context=None,
              device=None,
              **kwargs):
     if batch_sampler is not None:
         batch_size = 1
     if num_workers is None:
         if isdebugging():
             num_workers = 0
         else:
             num_workers = 2
     # noinspection PyArgumentList
     super(DeviceDataLoader,
           self).__init__(dataset=dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          sampler=sampler,
                          batch_sampler=batch_sampler,
                          num_workers=num_workers,
                          collate_fn=collate_fn,
                          pin_memory=pin_memory,
                          drop_last=drop_last,
                          timeout=timeout,
                          worker_init_fn=worker_init_fn,
                          multiprocessing_context=multiprocessing_context,
                          **kwargs)
     self.device = device
Example #6
 def build_dataloader(self,
                      data,
                      batch_size,
                      shuffle,
                      device,
                      text_a_key,
                      text_b_key,
                      label_key,
                      logger: logging.Logger = None,
                      sorting=True,
                      **kwargs) -> DataLoader:
     if not batch_size:
         batch_size = self.config.batch_size
     dataset = self.build_dataset(data)
     dataset.append_transform(self.vocabs)
     if self.vocabs.mutable:
         if not any([text_a_key, text_b_key]):
             if len(dataset.headers) == 2:
                 self.config.text_a_key = dataset.headers[0]
                 self.config.label_key = dataset.headers[1]
             elif len(dataset.headers) >= 3:
                 self.config.text_a_key, self.config.text_b_key, self.config.label_key = \
                     dataset.headers[0], dataset.headers[1], dataset.headers[-1]
             else:
                 raise ValueError('Wrong dataset format')
             report = {'text_a_key', 'text_b_key', 'label_key'}
             report = dict((k, self.config[k]) for k in report)
             report = [f'{k}={v}' for k, v in report.items() if v]
             report = ', '.join(report)
             logger.info(
                 f'Guess [bold][blue]{report}[/blue][/bold] according to the headers of training dataset: '
                 f'[blue]{dataset}[/blue]')
         self.build_vocabs(dataset, logger)
         dataset.purge_cache()
     # if self.config.transform:
     #     dataset.append_transform(self.config.transform)
     dataset.append_transform(
         TransformerTextTokenizer(
             tokenizer=self.transformer_tokenizer,
             text_a_key=self.config.text_a_key,
             text_b_key=self.config.text_b_key,
             max_seq_length=self.config.max_seq_length,
             truncate_long_sequences=self.config.truncate_long_sequences,
             output_key=''))
     batch_sampler = None
     if sorting and not isdebugging():
         if dataset.cache and len(dataset) > 1000:
             timer = CountdownTimer(len(dataset))
             lens = []
             for idx, sample in enumerate(dataset):
                 lens.append(len(sample['input_ids']))
                 timer.log(
                     'Pre-processing and caching dataset [blink][yellow]...[/yellow][/blink]',
                     ratio_percentage=None)
         else:
             lens = [len(sample['input_ids']) for sample in dataset]
         batch_sampler = SortingSampler(
             lens,
             batch_size=batch_size,
             shuffle=shuffle,
             batch_max_tokens=self.config.batch_max_tokens)
     return PadSequenceDataLoader(dataset,
                                  batch_size,
                                  shuffle,
                                  batch_sampler=batch_sampler,
                                  device=device,
                                  collate_fn=self.collate_fn)
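
The SortingSampler above groups samples of similar length so that each batch needs little padding; a small sketch of the same pattern, reusing dataset and device from the method above and inventing the lengths:

 # Hypothetical: six samples with these token counts get bucketed into length-homogeneous batches.
 lens = [5, 7, 6, 30, 31, 29]
 batch_sampler = SortingSampler(lens, batch_size=3, shuffle=False)
 loader = PadSequenceDataLoader(dataset, batch_sampler=batch_sampler, device=device)
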
Example #7
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         batch_size,
         epochs,
         devices=None,
         logger=None,
         seed=None,
         finetune=False,
         eval_trn=True,
         _device_placeholder=False,
         **kwargs):
     # Common initialization steps
     config = self._capture_config(locals())
     if not logger:
         logger = self.build_logger('train', save_dir)
     if not seed:
         self.config.seed = 233 if isdebugging() else int(time.time())
     set_seed(self.config.seed)
     logger.info(self._savable_config.to_json(sort=True))
     if isinstance(devices, list) or devices is None or isinstance(
             devices, float):
         flash('[yellow]Querying CUDA devices [blink]...[/blink][/yellow]')
         devices = -1 if isdebugging() else cuda_devices(devices)
         flash('')
     # flash(f'Available GPUs: {devices}')
     if isinstance(devices, list):
         first_device = (devices[0] if devices else -1)
     elif isinstance(devices, dict):
         first_device = next(iter(devices.values()))
     elif isinstance(devices, int):
         first_device = devices
     else:
         first_device = -1
     if _device_placeholder and first_device >= 0:
         _dummy_placeholder = self._create_dummy_placeholder_on(
             first_device)
     if finetune:
         if isinstance(finetune, str):
             self.load(finetune, devices=devices)
         else:
             self.load(save_dir, devices=devices)
         logger.info(
             f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
             f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.'
         )
     self.on_config_ready(**self.config)
     trn = self.build_dataloader(**merge_dict(config,
                                              data=trn_data,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              training=True,
                                              device=first_device,
                                              logger=logger,
                                              vocabs=self.vocabs,
                                              overwrite=True))
     dev = self.build_dataloader(
         **merge_dict(config,
                      data=dev_data,
                      batch_size=batch_size,
                      shuffle=False,
                      training=None,
                      device=first_device,
                      logger=logger,
                      vocabs=self.vocabs,
                      overwrite=True)) if dev_data else None
     if not finetune:
         flash('[yellow]Building model [blink]...[/blink][/yellow]')
         self.model = self.build_model(**merge_dict(config, training=True))
         flash('')
         logger.info(
             f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
             f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.'
         )
         assert self.model, 'build_model is not properly implemented.'
     _description = repr(self.model)
     if len(_description.split('\n')) < 10:
         logger.info(_description)
     self.save_config(save_dir)
     self.save_vocabs(save_dir)
     self.to(devices, logger)
     if _device_placeholder and first_device >= 0:
         del _dummy_placeholder
     criterion = self.build_criterion(**merge_dict(config, trn=trn))
     optimizer = self.build_optimizer(
         **merge_dict(config, trn=trn, criterion=criterion))
     metric = self.build_metric(**self.config)
     if hasattr(trn.dataset, '__len__') and dev and hasattr(
             dev.dataset, '__len__'):
         logger.info(
             f'{len(trn.dataset)}/{len(dev.dataset)} samples in trn/dev set.'
         )
         trn_size = len(trn) // self.config.get('gradient_accumulation', 1)
         ratio_width = len(f'{trn_size}/{trn_size}')
     else:
         ratio_width = None
     return self.execute_training_loop(**merge_dict(config,
                                                    trn=trn,
                                                    dev=dev,
                                                    epochs=epochs,
                                                    criterion=criterion,
                                                    optimizer=optimizer,
                                                    metric=metric,
                                                    logger=logger,
                                                    save_dir=save_dir,
                                                    devices=devices,
                                                    ratio_width=ratio_width,
                                                    trn_data=trn_data,
                                                    dev_data=dev_data,
                                                    eval_trn=eval_trn,
                                                    overwrite=True))
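
A hedged call sketch for the training entry point above; the component instance, the file paths and the hyper-parameters are placeholders:

 component.fit(trn_data='data/train.tsv',  # placeholder paths
               dev_data='data/dev.tsv',
               save_dir='output/model',
               batch_size=32,
               epochs=3,
               devices=-1)                 # -1 keeps the model on CPU (see the handling above)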