def distill(self,
            teacher: str,
            trn_data,
            dev_data,
            save_dir,
            batch_size=None,
            epochs=None,
            kd_criterion='kd_ce_loss',
            temperature_scheduler='flsw',
            devices=None,
            logger=None,
            seed=None,
            **kwargs):
    devices = devices or cuda_devices()
    # Resolve string shorthands into the actual loss / temperature-scheduler objects.
    if isinstance(kd_criterion, str):
        kd_criterion = KnowledgeDistillationLoss(kd_criterion)
    if isinstance(temperature_scheduler, str):
        temperature_scheduler = TemperatureScheduler.from_name(temperature_scheduler)
    # Load the teacher and inherit its vocabularies and config as defaults for the student.
    teacher = self.build_teacher(teacher, devices=devices)
    self.vocabs = teacher.vocabs
    config = copy(teacher.config)
    batch_size = batch_size or config.get('batch_size', None)
    epochs = epochs or config.get('epochs', None)
    config.update(kwargs)
    return super().fit(**merge_locals_kwargs(locals(),
                                             config,
                                             excludes=('self', 'kwargs', '__class__', 'config')))
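# A minimal usage sketch of knowledge distillation (the component class and paths below are hypothetical;
# any distillable component exposing distill() would do):
#
#   student = SomeTaggerComponent()
#   student.distill(teacher='data/model/teacher',
#                   trn_data='data/trn.tsv',
#                   dev_data='data/dev.tsv',
#                   save_dir='data/model/student',
#                   kd_criterion='kd_ce_loss',
#                   temperature_scheduler='flsw')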
def to(self,
       devices: Union[int, float, List[int], Dict[str, Union[int, torch.device]]] = None,
       logger: logging.Logger = None):
    """
    For inference, the model can only be moved to the CPU or a single GPU.
    """
    if devices == -1 or devices == [-1]:
        devices = []
    elif isinstance(devices, (int, float)) or devices is None:
        devices = cuda_devices(devices)
    if isinstance(devices, list) and len(devices) > 1:
        raise ValueError(f'Invalid devices {devices}; at most one GPU can be accepted for inference')
    super(CoreferenceResolver, self).to(devices=devices, logger=logger)
    self.torch_device = torch.device('cpu' if len(devices) == 0 else f'cuda:{devices[0]}')
    self.model.device = self.torch_device
    self.model.to(self.torch_device)  # Double check; if the model is already on this device, this is a no-op
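# Usage sketch (assumes a loaded CoreferenceResolver instance named `coref`):
#
#   coref.to(-1)      # run inference on CPU
#   coref.to(0)       # run inference on GPU 0
#   coref.to([0, 1])  # raises ValueError: at most one GPU is accepted for inference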
def to(self,
       devices: Union[int, float, List[int], Dict[str, Union[int, torch.device]]] = None,
       logger: logging.Logger = None):
    if devices == -1 or devices == [-1]:
        devices = []
    elif isinstance(devices, (int, float)) or devices is None:
        devices = cuda_devices(devices)
    if devices:
        if logger:
            logger.info(f'Using GPUs: [on_blue][cyan][bold]{devices}[/bold][/cyan][/on_blue]')
        if isinstance(devices, list):
            # A list of GPU ids: place the model on the first one, then wrap it for data parallelism if needed.
            flash(f'Moving model to GPUs {devices} [blink][yellow]...[/yellow][/blink]')
            self.model = self.model.to(devices[0])
            if len(devices) > 1 and not isdebugging() and not isinstance(self.model, nn.DataParallel):
                self.model = self.parallelize(devices)
        elif isinstance(devices, dict):
            # A dict maps module-name regexes to devices, so different submodules can live on different devices.
            for name, module in self.model.named_modules():
                for regex, device in devices.items():
                    try:
                        on_device: torch.device = next(module.parameters()).device
                    except StopIteration:
                        continue
                    if on_device == device:
                        continue
                    if isinstance(device, int):
                        if on_device.index == device:
                            continue
                    if re.match(regex, name):
                        if not name:
                            name = '*'
                        flash(f'Moving module [yellow]{name}[/yellow] to [on_yellow][magenta][bold]{device}'
                              f'[/bold][/magenta][/on_yellow]: [red]{regex}[/red]\n')
                        module.to(device)
        else:
            raise ValueError(f'Unrecognized devices {devices}')
        flash('')
    else:
        if logger:
            logger.info('Using CPU')
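# Usage sketch of the three accepted device specs (GPU ids and module-name regexes below are hypothetical):
#
#   component.to(-1)                            # CPU
#   component.to([0, 1])                        # data parallelism across GPUs 0 and 1
#   component.to({'encoder': 0, 'decoder': 1})  # regexes over named_modules(): encoder on GPU 0, decoder on GPU 1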
def fit(self,
        trn_data,
        dev_data,
        save_dir,
        batch_size,
        epochs,
        devices=None,
        logger=None,
        seed=None,
        finetune=False,
        eval_trn=True,
        _device_placeholder=False,
        **kwargs):
    # Common initialization steps
    config = self._capture_config(locals())
    if not logger:
        logger = self.build_logger('train', save_dir)
    # Fix the random seed; use a constant seed when debugging for reproducibility.
    if not seed:
        self.config.seed = 233 if isdebugging() else int(time.time())
    set_seed(self.config.seed)
    logger.info(self._savable_config.to_json(sort=True))
    if isinstance(devices, list) or devices is None or isinstance(devices, float):
        flash('[yellow]Querying CUDA devices [blink]...[/blink][/yellow]')
        devices = -1 if isdebugging() else cuda_devices(devices)
        flash('')
    # flash(f'Available GPUs: {devices}')
    if isinstance(devices, list):
        first_device = (devices[0] if devices else -1)
    elif isinstance(devices, dict):
        first_device = next(iter(devices.values()))
    elif isinstance(devices, int):
        first_device = devices
    else:
        first_device = -1
    if _device_placeholder and first_device >= 0:
        _dummy_placeholder = self._create_dummy_placeholder_on(first_device)
    # When finetuning, load an existing checkpoint instead of building a fresh model.
    if finetune:
        if isinstance(finetune, str):
            self.load(finetune, devices=devices)
        else:
            self.load(save_dir, devices=devices)
        logger.info(
            f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
            f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
    self.on_config_ready(**self.config)
    trn = self.build_dataloader(**merge_dict(config, data=trn_data, batch_size=batch_size, shuffle=True,
                                             training=True, device=first_device, logger=logger,
                                             vocabs=self.vocabs, overwrite=True))
    dev = self.build_dataloader(**merge_dict(config, data=dev_data, batch_size=batch_size, shuffle=False,
                                             training=None, device=first_device, logger=logger,
                                             vocabs=self.vocabs, overwrite=True)) if dev_data else None
    if not finetune:
        flash('[yellow]Building model [blink]...[/blink][/yellow]')
        self.model = self.build_model(**merge_dict(config, training=True))
        flash('')
        logger.info(
            f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
            f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
    assert self.model, 'build_model is not properly implemented.'
    _description = repr(self.model)
    if len(_description.split('\n')) < 10:
        logger.info(_description)
    self.save_config(save_dir)
    self.save_vocabs(save_dir)
    self.to(devices, logger)
    if _device_placeholder and first_device >= 0:
        del _dummy_placeholder
    criterion = self.build_criterion(**merge_dict(config, trn=trn))
    optimizer = self.build_optimizer(**merge_dict(config, trn=trn, criterion=criterion))
    metric = self.build_metric(**self.config)
    if hasattr(trn.dataset, '__len__') and dev and hasattr(dev.dataset, '__len__'):
        logger.info(f'{len(trn.dataset)}/{len(dev.dataset)} samples in trn/dev set.')
        trn_size = len(trn) // self.config.get('gradient_accumulation', 1)
        ratio_width = len(f'{trn_size}/{trn_size}')
    else:
        ratio_width = None
    return self.execute_training_loop(**merge_dict(config, trn=trn, dev=dev, epochs=epochs,
                                                   criterion=criterion, optimizer=optimizer, metric=metric,
                                                   logger=logger, save_dir=save_dir, devices=devices,
                                                   ratio_width=ratio_width, trn_data=trn_data,
                                                   dev_data=dev_data, eval_trn=eval_trn, overwrite=True))
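# Usage sketch of a training run (hypothetical component class, data paths and hyperparameters;
# extra keyword arguments are captured into the config and forwarded to build_dataloader/build_model/etc.):
#
#   tagger = SomeTaggerComponent()
#   tagger.fit(trn_data='data/trn.tsv',
#              dev_data='data/dev.tsv',
#              save_dir='data/model/tagger',
#              batch_size=32,
#              epochs=10,
#              devices=[0],
#              lr=1e-3)  # hypothetical extra hyperparameter picked up via **kwargs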