def distill(self,
            teacher: str,
            trn_data,
            dev_data,
            save_dir,
            batch_size=None,
            epochs=None,
            kd_criterion='kd_ce_loss',
            temperature_scheduler='flsw',
            devices=None,
            logger=None,
            seed=None,
            **kwargs):
    """Train this (student) component by distilling knowledge from a teacher component.

    Args:
        teacher: Identifier/path of the teacher, resolved by :meth:`build_teacher`.
        trn_data: Training set.
        dev_data: Development set.
        save_dir: The directory to save the trained student component.
        batch_size: Number of samples per batch; falls back to the teacher's ``batch_size`` config when ``None``.
        epochs: Number of epochs; falls back to the teacher's ``epochs`` config when ``None``.
        kd_criterion: Knowledge-distillation loss; a ``str`` is wrapped into a
            :class:`KnowledgeDistillationLoss`.
        temperature_scheduler: Temperature schedule; a ``str`` is resolved via
            :meth:`TemperatureScheduler.from_name`.
        devices: Devices to run on; defaults to all CUDA devices reported by ``cuda_devices()``.
        logger: Any :class:`logging.Logger` instance.
        seed: Random seed to reproduce this training.
        **kwargs: Extra hyperparameters merged over the teacher's config.

    Returns:
        Whatever ``super().fit`` returns (presumably the best metric — confirm against the base class).
    """
    devices = devices or cuda_devices()
    # Resolve string shortcuts into concrete objects.
    if isinstance(kd_criterion, str):
        kd_criterion = KnowledgeDistillationLoss(kd_criterion)
    if isinstance(temperature_scheduler, str):
        temperature_scheduler = TemperatureScheduler.from_name(temperature_scheduler)
    teacher = self.build_teacher(teacher, devices=devices)
    # The student shares the teacher's vocabularies so their output spaces line up.
    self.vocabs = teacher.vocabs
    # Copy so mutating the student's config cannot alter the teacher's.
    config = copy(teacher.config)
    # Inherit training hyperparameters from the teacher unless explicitly overridden.
    batch_size = batch_size or config.get('batch_size', None)
    epochs = epochs or config.get('epochs', None)
    config.update(kwargs)
    # NOTE: merge_locals_kwargs(locals(), ...) captures every local by NAME — the local
    # variable names in this method are part of the config contract; do not rename them.
    return super().fit(**merge_locals_kwargs(locals(), config, excludes=('self', 'kwargs', '__class__', 'config')))
def to(self,
       devices: Union[int, float, List[int], Dict[str, Union[int, torch.device]]] = None,
       logger: logging.Logger = None,
       verbose=HANLP_VERBOSE):
    """Move this component to devices.

    Args:
        devices: Target devices.
        logger: Logger for printing progress report, as copying a model from CPU to GPU can takes several seconds.
        verbose: ``True`` to print progress when logger is None.
    """
    # Normalize the specification: -1 (or [-1]) means CPU; a scalar/None is expanded
    # into a concrete device list via cuda_devices().
    if devices == -1 or devices == [-1]:
        devices = []
    elif devices is None or isinstance(devices, (int, float)):
        devices = cuda_devices(devices)
    if not devices:
        # Empty device list: stay on CPU.
        if logger:
            logger.info('Using [red]CPU[/red]')
        return
    if logger:
        logger.info(
            f'Using GPUs: [on_blue][cyan][bold]{devices}[/bold][/cyan][/on_blue]'
        )
    if isinstance(devices, list):
        # Whole-model placement: primary device first, then optional data parallelism.
        if verbose:
            flash(
                f'Moving model to GPUs {devices} [blink][yellow]...[/yellow][/blink]'
            )
        self.model = self.model.to(devices[0])
        wants_parallel = len(devices) > 1 and not isdebugging()
        if wants_parallel and not isinstance(self.model, nn.DataParallel):
            self.model = self.parallelize(devices)
    elif isinstance(devices, dict):
        # Per-module placement: each dict entry maps a name regex to a target device.
        for name, module in self.model.named_modules():
            for regex, device in devices.items():
                try:
                    on_device: torch.device = next(
                        module.parameters()).device
                except StopIteration:
                    # Parameter-less module: nothing to move.
                    continue
                if on_device == device:
                    continue
                # An int target may already match by CUDA index alone.
                if isinstance(device, int) and on_device.index == device:
                    continue
                if re.match(regex, name):
                    if not name:
                        name = '*'
                    flash(
                        f'Moving module [yellow]{name}[/yellow] to [on_yellow][magenta][bold]{device}'
                        f'[/bold][/magenta][/on_yellow]: [red]{regex}[/red]\n'
                    )
                    module.to(device)
    else:
        raise ValueError(f'Unrecognized devices {devices}')
    if verbose:
        flash('')
def fit(self,
        trn_data,
        dev_data,
        save_dir,
        batch_size,
        epochs,
        devices=None,
        logger=None,
        seed=None,
        finetune: Union[bool, str] = False,
        eval_trn=True,
        _device_placeholder=False,
        **kwargs):
    """Fit to data, triggers the training procedure. For training set and dev set, they shall be local
    or remote files.

    Args:
        trn_data: Training set.
        dev_data: Development set.
        save_dir: The directory to save trained component.
        batch_size: The number of samples in a batch.
        epochs: Number of epochs.
        devices: Devices this component will live on.
        logger: Any :class:`logging.Logger` instance.
        seed: Random seed to reproduce this training.
        finetune: ``True`` to load from ``save_dir`` instead of creating a randomly initialized component.
            ``str`` to specify a different ``save_dir`` to load from.
        eval_trn: Evaluate training set after each update. This can slow down the training but provides a
            quick diagnostic for debugging.
        _device_placeholder: ``True`` to create a placeholder tensor which triggers PyTorch to occupy
            devices so other components won't take these devices as first choices.
        **kwargs: Hyperparameters used by sub-classes.

    Returns:
        Any results sub-classes would like to return. Usually the best metrics on training set.
    """
    # Common initialization steps
    # NOTE: _capture_config reads locals() — parameter and local names up to this point are
    # part of the captured config; do not rename them.
    config = self._capture_config(locals())
    if not logger:
        logger = self.build_logger('train', save_dir)
    if not seed:
        # Fixed seed under a debugger for reproducibility; wall-clock seed otherwise.
        self.config.seed = 233 if isdebugging() else int(time.time())
    set_seed(self.config.seed)
    logger.info(self._savable_config.to_json(sort=True))
    if isinstance(devices, list) or devices is None or isinstance(devices, float):
        flash('[yellow]Querying CUDA devices [blink]...[/blink][/yellow]')
        # Under a debugger, force CPU (-1) to keep stepping fast and deterministic.
        devices = -1 if isdebugging() else cuda_devices(devices)
        flash('')
    # flash(f'Available GPUs: {devices}')
    # Pick the primary device used for dataloaders and the placeholder tensor.
    if isinstance(devices, list):
        first_device = (devices[0] if devices else -1)
    elif isinstance(devices, dict):
        first_device = next(iter(devices.values()))
    elif isinstance(devices, int):
        first_device = devices
    else:
        first_device = -1
    if _device_placeholder and first_device >= 0:
        # Keep a tensor alive on the device so other processes see it as occupied.
        _dummy_placeholder = self._create_dummy_placeholder_on(first_device)
    if finetune:
        # Load an existing model instead of building a fresh one; a str finetune
        # value names an alternative directory to load from.
        if isinstance(finetune, str):
            self.load(finetune, devices=devices)
        else:
            self.load(save_dir, devices=devices)
        logger.info(
            f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
            f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
    self.on_config_ready(**self.config)
    trn = self.build_dataloader(**merge_dict(config, data=trn_data, batch_size=batch_size, shuffle=True,
                                             training=True, device=first_device, logger=logger,
                                             vocabs=self.vocabs, overwrite=True))
    # training=None (not False) for dev — presumably signals "evaluation mode" to
    # build_dataloader; confirm against subclass implementations.
    dev = self.build_dataloader(**merge_dict(config, data=dev_data, batch_size=batch_size, shuffle=False,
                                             training=None, device=first_device, logger=logger,
                                             vocabs=self.vocabs, overwrite=True)) if dev_data else None
    if not finetune:
        flash('[yellow]Building model [blink]...[/blink][/yellow]')
        self.model = self.build_model(**merge_dict(config, training=True))
        flash('')
        logger.info(f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
                    f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
        assert self.model, 'build_model is not properly implemented.'
    _description = repr(self.model)
    # Only log small model reprs; huge module trees would flood the log.
    if len(_description.split('\n')) < 10:
        logger.info(_description)
    self.save_config(save_dir)
    self.save_vocabs(save_dir)
    self.to(devices, logger)
    if _device_placeholder and first_device >= 0:
        # Release the placeholder now that the real model occupies the device.
        del _dummy_placeholder
    criterion = self.build_criterion(**merge_dict(config, trn=trn))
    optimizer = self.build_optimizer(**merge_dict(config, trn=trn, criterion=criterion))
    metric = self.build_metric(**self.config)
    if hasattr(trn.dataset, '__len__') and dev and hasattr(dev.dataset, '__len__'):
        logger.info(f'{len(trn.dataset)}/{len(dev.dataset)} samples in trn/dev set.')
        # Width of the "step/total" progress field, accounting for gradient accumulation.
        trn_size = len(trn) // self.config.get('gradient_accumulation', 1)
        ratio_width = len(f'{trn_size}/{trn_size}')
    else:
        ratio_width = None
    return self.execute_training_loop(**merge_dict(config, trn=trn, dev=dev, epochs=epochs,
                                                   criterion=criterion, optimizer=optimizer, metric=metric,
                                                   logger=logger, save_dir=save_dir, devices=devices,
                                                   ratio_width=ratio_width, trn_data=trn_data,
                                                   dev_data=dev_data, eval_trn=eval_trn, overwrite=True))