Example #1
 def distill(self,
             teacher: str,
             trn_data,
             dev_data,
             save_dir,
             batch_size=None,
             epochs=None,
             kd_criterion='kd_ce_loss',
             temperature_scheduler='flsw',
             devices=None,
             logger=None,
             seed=None,
             **kwargs):
     devices = devices or cuda_devices()
     # Resolve string shorthands into concrete loss/scheduler objects.
     if isinstance(kd_criterion, str):
         kd_criterion = KnowledgeDistillationLoss(kd_criterion)
     if isinstance(temperature_scheduler, str):
         temperature_scheduler = TemperatureScheduler.from_name(temperature_scheduler)
     teacher = self.build_teacher(teacher, devices=devices)
     # The student reuses the teacher's vocabularies and a copy of its config.
     self.vocabs = teacher.vocabs
     config = copy(teacher.config)
     batch_size = batch_size or config.get('batch_size', None)
     epochs = epochs or config.get('epochs', None)
     config.update(kwargs)
     return super().fit(**merge_locals_kwargs(locals(),
                                              config,
                                              excludes=('self', 'kwargs', '__class__', 'config')))
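
A minimal usage sketch for the distillation entry point above, assuming a HanLP-style component whose superclass implements fit(); the StudentTagger class and every path below are hypothetical:

student = StudentTagger()  # hypothetical subclass inheriting distill()
student.distill(teacher='path/to/teacher/checkpoint',  # resolved by self.build_teacher
                trn_data='path/to/train.tsv',
                dev_data='path/to/dev.tsv',
                save_dir='path/to/student',
                kd_criterion='kd_ce_loss',     # string wrapped into KnowledgeDistillationLoss
                temperature_scheduler='flsw')  # resolved via TemperatureScheduler.from_name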
Example #2
    def to(self, devices: Union[int, float, List[int], Dict[str, Union[int, torch.device]]] = None,
           logger: logging.Logger = None):
        """ For inference, can only move to CPU or one GPU. """
        if devices == -1 or devices == [-1]:
            devices = []
        elif isinstance(devices, (int, float)) or devices is None:
            devices = cuda_devices(devices)
        if isinstance(devices, list) and len(devices) > 1:
            raise ValueError(f'Invalid devices {devices}; at most one GPU is accepted for inference')
        
        super(CoreferenceResolver, self).to(devices=devices, logger=logger)

        self.torch_device = torch.device('cpu' if len(devices) == 0 else f'cuda:{devices[0]}')
        self.model.device = self.torch_device
        self.model.to(self.torch_device)  # Idempotent: a no-op if the model is already on the target device
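
A usage sketch for the inference-time override above; the load() call and checkpoint path are assumptions, while the to() behavior follows directly from the snippet:

resolver = CoreferenceResolver()
resolver.load('path/to/coref/checkpoint')  # hypothetical checkpoint path
resolver.to(-1)      # -1 (or [-1]) keeps inference on the CPU
resolver.to([0])     # a single-element list pins inference to cuda:0
resolver.to([0, 1])  # raises ValueError: at most one GPU is accepted for inference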
Example #3
 def to(self,
        devices: Union[int, float, List[int],
                       Dict[str, Union[int, torch.device]]] = None,
        logger: logging.Logger = None):
     if devices == -1 or devices == [-1]:
         devices = []
     elif isinstance(devices, (int, float)) or devices is None:
         devices = cuda_devices(devices)
     if devices:
         if logger:
             logger.info(
                 f'Using GPUs: [on_blue][cyan][bold]{devices}[/bold][/cyan][/on_blue]'
             )
         if isinstance(devices, list):
             flash(
                 f'Moving model to GPUs {devices} [blink][yellow]...[/yellow][/blink]'
             )
             self.model = self.model.to(devices[0])
             if len(devices) > 1 and not isdebugging() and not isinstance(
                     self.model, nn.DataParallel):
                 self.model = self.parallelize(devices)
         elif isinstance(devices, dict):
             # Per-module placement: each key is a regex matched against module names.
             for name, module in self.model.named_modules():
                 for regex, device in devices.items():
                     try:
                         on_device: torch.device = next(
                             module.parameters()).device
                     except StopIteration:
                         continue
                     if on_device == device:
                         continue
                     if isinstance(device, int):
                         if on_device.index == device:
                             continue
                     if re.match(regex, name):
                         if not name:
                             name = '*'  # The root module's name is empty; display it as '*'
                         flash(
                             f'Moving module [yellow]{name}[/yellow] to [on_yellow][magenta][bold]{device}'
                             f'[/bold][/magenta][/on_yellow]: [red]{regex}[/red]\n'
                         )
                         module.to(device)
         else:
             raise ValueError(f'Unrecognized devices {devices}')
         flash('')
     else:
         if logger:
             logger.info('Using CPU')
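
The accepted forms of devices, sketched against a hypothetical component instance; note that an int or float is first resolved through cuda_devices(), whose selection policy is not shown in this snippet, and the keys in the dict form are regexes matched with re.match:

component.to([0])                           # list: model moved to cuda:0
component.to([0, 1])                        # multi-GPU: cuda:0 plus DataParallel via self.parallelize
component.to({'encoder': 0, 'decoder': 1})  # dict: regex over module names -> target device
component.to(-1)                            # -1 (or [-1]) keeps everything on the CPU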
Example #4
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         batch_size,
         epochs,
         devices=None,
         logger=None,
         seed=None,
         finetune=False,
         eval_trn=True,
         _device_placeholder=False,
         **kwargs):
     # Common initialization steps
     config = self._capture_config(locals())
     if not logger:
         logger = self.build_logger('train', save_dir)
     if not seed:
         # Fixed seed while debugging for reproducibility; otherwise derive one from wall-clock time.
         self.config.seed = 233 if isdebugging() else int(time.time())
     set_seed(self.config.seed)
     logger.info(self._savable_config.to_json(sort=True))
     if isinstance(devices, list) or devices is None or isinstance(
             devices, float):
         flash('[yellow]Querying CUDA devices [blink]...[/blink][/yellow]')
         devices = -1 if isdebugging() else cuda_devices(devices)
         flash('')
     # flash(f'Available GPUs: {devices}')
     if isinstance(devices, list):
         first_device = (devices[0] if devices else -1)
     elif isinstance(devices, dict):
         first_device = next(iter(devices.values()))
     elif isinstance(devices, int):
         first_device = devices
     else:
         first_device = -1
     if _device_placeholder and first_device >= 0:
         # Create a small tensor on the target GPU up front, effectively reserving it during setup.
         _dummy_placeholder = self._create_dummy_placeholder_on(
             first_device)
     if finetune:
         if isinstance(finetune, str):
             self.load(finetune, devices=devices)
         else:
             self.load(save_dir, devices=devices)
         logger.info(
             f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
             f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.'
         )
     self.on_config_ready(**self.config)
     trn = self.build_dataloader(**merge_dict(config,
                                              data=trn_data,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              training=True,
                                              device=first_device,
                                              logger=logger,
                                              vocabs=self.vocabs,
                                              overwrite=True))
     dev = self.build_dataloader(
         **merge_dict(config,
                      data=dev_data,
                      batch_size=batch_size,
                      shuffle=False,
                      training=None,
                      device=first_device,
                      logger=logger,
                      vocabs=self.vocabs,
                      overwrite=True)) if dev_data else None
     if not finetune:
         flash('[yellow]Building model [blink]...[/blink][/yellow]')
         self.model = self.build_model(**merge_dict(config, training=True))
         flash('')
         logger.info(
             f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
             f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.'
         )
         assert self.model, 'build_model is not properly implemented.'
     _description = repr(self.model)
     if len(_description.split('\n')) < 10:
         logger.info(_description)
     self.save_config(save_dir)
     self.save_vocabs(save_dir)
     self.to(devices, logger)
     if _device_placeholder and first_device >= 0:
         del _dummy_placeholder
     criterion = self.build_criterion(**merge_dict(config, trn=trn))
     optimizer = self.build_optimizer(
         **merge_dict(config, trn=trn, criterion=criterion))
     metric = self.build_metric(**self.config)
     if hasattr(trn.dataset, '__len__') and dev and hasattr(
             dev.dataset, '__len__'):
         logger.info(
             f'{len(trn.dataset)}/{len(dev.dataset)} samples in trn/dev set.'
         )
         trn_size = len(trn) // self.config.get('gradient_accumulation', 1)
         ratio_width = len(f'{trn_size}/{trn_size}')
     else:
         ratio_width = None
     return self.execute_training_loop(**merge_dict(config,
                                                    trn=trn,
                                                    dev=dev,
                                                    epochs=epochs,
                                                    criterion=criterion,
                                                    optimizer=optimizer,
                                                    metric=metric,
                                                    logger=logger,
                                                    save_dir=save_dir,
                                                    devices=devices,
                                                    ratio_width=ratio_width,
                                                    trn_data=trn_data,
                                                    dev_data=dev_data,
                                                    eval_trn=eval_trn,
                                                    overwrite=True))
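
A call sketch for the training entry point above, assuming a subclass that implements the build_* hooks (build_dataloader, build_model, build_criterion, build_optimizer, build_metric); the MyTagger name and all paths are hypothetical:

tagger = MyTagger()
tagger.fit(trn_data='path/to/train.tsv',
           dev_data='path/to/dev.tsv',
           save_dir='path/to/model',
           batch_size=32,
           epochs=10,
           devices=None,                   # None queries available CUDA devices (CPU when debugging)
           finetune='path/to/pretrained',  # a path loads that checkpoint; other truthy values load save_dir
           eval_trn=True)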