def load(self, load_config: Dict):
    """Defines and executes logic to load checkpoint for fine-tuning.

    :param load_config: config defining parameters related to loading the
        model and optimizer
    :type load_config: Dict
    """
    if load_config['version'] is not None:
        load_dir = join(self.config.paths['OUT_DIR'],
                        load_config['version'], 'checkpoints')
        self.load_path = self.checkpoint.get_saved_checkpoint_path(
            load_dir, load_config['load_best'], load_config['epoch'])

        logging.info(
            color("=> Loading model weights from {}".format(
                self.load_path)))

        if exists(self.load_path):
            checkpoint = torch.load(self.load_path)
            self.network.load_state_dict(checkpoint['network'])

            if load_config['resume_optimizer']:
                logging.info(color("=> Resuming optimizer params"))
                self.optimizer.load_state_dict(checkpoint['optimizer'])

            if load_config['resume_epoch']:
                self.epoch_counter = checkpoint['epoch']
        else:
            sys.exit(
                color(
                    'Checkpoint file does not exist at {}'.format(
                        self.load_path), 'red'))
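# Sketch of the load_config this method expects; all keys below appear in the
# lookups above, but the concrete values shown are purely hypothetical.
example_load_config = {
    'version': 'experiment_v1',  # run whose checkpoints to load (None skips loading)
    'load_best': True,           # pick the best checkpoint instead of a specific epoch
    'epoch': None,               # specific epoch to load when load_best is False
    'resume_optimizer': True,    # also restore the optimizer state
    'resume_epoch': True         # continue the epoch counter from the checkpoint
}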
def update_best_metric(self, epoch_counter: int,
                       epoch_metric_dict: Dict) -> Dict:
    """Updates current best metric values with current epoch metrics

    :param epoch_counter: number of the current epoch
    :type epoch_counter: int
    :param epoch_metric_dict: dict containing the current epoch metrics
    :type epoch_metric_dict: Dict
    :return: dict containing save status incl. filepath to save the model
    """
    assert self.monitor in epoch_metric_dict, \
        'Metric {} not computed'.format(self.monitor)

    old_value = self.best_metric_dict[self.monitor]
    new_value = epoch_metric_dict[self.monitor]

    if self.monitor_mode == 'max':
        improvement_indicator = (new_value > old_value)
    elif self.monitor_mode == 'min':
        improvement_indicator = (new_value < old_value)
    else:
        raise ValueError('monitor_mode should be one of {{min, max}}, '
                         'got {}'.format(self.monitor_mode))

    save_status = {'save': False}

    if improvement_indicator:
        self.best_metric_dict[self.monitor] = new_value
        info = '[{}] {} improved from {} to {}'.format(
            color('Saving best model', 'red'), self.monitor, old_value,
            new_value)
        save_status.update({
            'save': True,
            'path': join(self.ckpt_dir, 'best_ckpt.pth.tar'),
            'info': info
        })
    else:
        info = '[{}] {} did not improve from {}'.format(
            color('Saving regular model', 'red'), self.monitor, old_value)

        # save a regular checkpoint every `self.period` epochs
        if not ((epoch_counter + 1) % self.period):
            save_status.update({
                'save': True,
                'path': join(self.ckpt_dir,
                             '{}_ckpt.pth.tar'.format(epoch_counter)),
                'info': info
            })

    return save_status
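# Illustration only (not part of the class): the periodic-save check above,
# `not ((epoch_counter + 1) % self.period)`, triggers a regular checkpoint
# every `period` epochs. A hypothetical period of 5 is assumed here.
period = 5
for epoch_counter in range(10):
    if not ((epoch_counter + 1) % period):
        print('epoch {}: save {}_ckpt.pth.tar'.format(epoch_counter,
                                                      epoch_counter))
# prints only for epoch_counter 4 and 9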
def save(self, epoch_metric_values: Dict, use_wandb: bool):
    """Saves the model and optimizer states

    :param epoch_metric_values: validation metrics computed for current epoch
    :type epoch_metric_values: Dict
    :param use_wandb: flag to decide whether to log visualizations to wandb
    :type use_wandb: bool
    """
    # update the best metric and obtain save-related metadata
    save_status = self.checkpoint.update_best_metric(
        self.epoch_counter, epoch_metric_values)

    # if the save status indicates the necessity to save, save the model;
    # keeping this part model-class dependent since it can change with models
    if save_status['save']:
        logging.info(color(save_status['info'], 'red'))
        torch.save(
            {
                'network': self.network.get_state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'epoch': self.epoch_counter,
                # 'metrics': epoch_metric_values
            }, save_status['path'])
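# Minimal round-trip sketch, assuming a checkpoint written by `save` exists at
# the hypothetical path below: the saved dict uses the same keys that `load`
# reads back ('network', 'optimizer', 'epoch').
import torch

ckpt_path = 'best_ckpt.pth.tar'  # hypothetical path
checkpoint = torch.load(ckpt_path, map_location='cpu')
print(sorted(checkpoint.keys()))  # expected: ['epoch', 'network', 'optimizer']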
def read_dataset_from_config(data_root: str,
                             dataset_config: DatasetConfigDict) -> dict:
    """Loads and returns the dataset version file corresponding to the
    dataset config.

    :param data_root: directory where data versions reside
    :type data_root: str
    :param dataset_config: dict containing `(name, version, mode)`
        corresponding to a dataset. Here, `name` stands for the name of the
        dataset under the `/data` directory, `version` stands for the version
        of the dataset (stored in `/data/name/processed/versions/`) and
        `mode` stands for the split to be loaded (train/val/test).
    :type dataset_config: DatasetConfigDict
    :returns: dict of values stored in the version file
    """
    version_fpath = join(data_root, dataset_config['name'],
                         'processed/versions',
                         dataset_config['version'] + '.yml')
    print(
        color("=> Loading dataset version file: [{}, {}, {}]".format(
            dataset_config['name'], dataset_config['version'],
            dataset_config['mode'])))

    version_file = read_yml(version_fpath)
    return version_file[dataset_config['mode']]
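# Hypothetical usage sketch: the version file is expected at
# <data_root>/<name>/processed/versions/<version>.yml with one top-level key
# per split, e.g.
#
#   train: {...}
#   val:   {...}
#   test:  {...}
#
# so the call below would return only the requested split. The dataset name
# and version used here are made up.
example_dataset_config = {'name': 'my_dataset', 'version': 'v1.0',
                          'mode': 'train'}
# train_split = read_dataset_from_config('/data', example_dataset_config)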
def load_state_dict(self, state_dict: OrderedDict):
    """Defines helper function to load saved model checkpoint"""
    try:
        self.blocks.load_state_dict(state_dict)
    except RuntimeError:
        logging.info(
            color('state_dict does not match strictly. Trying to correct',
                  'red'))
        state_dict = _correct_state_dict(state_dict,
                                         self.blocks.state_dict())
        self.blocks.load_state_dict(state_dict, strict=False)
def _init_modules(self):
    """Initializes the parameters based on config"""
    if self.init is None:
        return

    logging.info(color('Initializing the parameters'))

    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            self._init_param(m.weight, 'weight')
            if m.bias is not None:
                self._init_param(m.bias, 'bias')
        elif isinstance(m, nn.BatchNorm2d):
            self._init_param(m.weight, 'bn_weight')
            self._init_param(m.bias, 'bn_bias')
def _setup_optimizers(self):
    """Setup optimizers to be used while training"""
    if 'optimizer' not in self.model_config:
        return

    logging.info(color("Setting up the optimizer ..."))

    kwargs = self.model_config['optimizer']['args']
    kwargs.update({'params': self.network.parameters()})
    self.optimizer = optimizer_factory.create(
        self.model_config['optimizer']['name'], **kwargs)

    if 'scheduler' in self.model_config['optimizer']:
        scheduler_config = self.model_config['optimizer']['scheduler']
        scheduler_config['params']['optimizer'] = self.optimizer
        self.scheduler = scheduler_factory.create(
            scheduler_config['name'], **scheduler_config['params'])

        self.update_freq = [scheduler_config['update']]
        if 'value' in scheduler_config:
            self.value_to_track = scheduler_config['value']
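# Sketch of the optimizer section this method expects inside model_config; the
# key structure ('name', 'args', 'scheduler' with 'name'/'params'/'update'/
# 'value') mirrors the lookups above, while the concrete optimizer, scheduler
# and parameter values are only illustrative assumptions.
example_optimizer_config = {
    'name': 'SGD',                                # passed to optimizer_factory.create
    'args': {'lr': 0.001, 'momentum': 0.9},
    'scheduler': {
        'name': 'ReduceLROnPlateau',              # passed to scheduler_factory.create
        'params': {'mode': 'min', 'factor': 0.1},  # 'optimizer' is injected above
        'update': 'epoch',                        # when to step: 'epoch' or 'batch'
        'value': 'loss'                           # value tracked by the scheduler
    }
}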
def _freeze_layers(self):
    """Freeze layers based on config during training"""
    logging.info(color('Freezing specified layers'))
    self.network.freeze_layers()
def _setup_network(self):
    """Setup the network which needs to be trained"""
    logging.info(color("Building the network"))
    self.network = network_factory.create(
        self.network_config['name'],
        **self.network_config['params']).to(self.device)
def __init__(self, config):
    super(BinaryClassificationModel, self).__init__(config)
    logging.info(color('Using loss functions:'))
    logging.info(self.model_config.get('loss'))
def log_epoch_summary(self, mode: str, epoch_losses: dict, metrics: dict,
                      epoch_data: dict, learning_rates: List[Any],
                      batch_losses: defaultdict,
                      instance_losses: defaultdict, use_wandb: bool):
    """Logs the summary of the epoch (losses, metrics and visualizations)

    :param mode: train/val/test mode
    :type mode: str
    :param epoch_losses: losses aggregated over the epoch
    :type epoch_losses: dict
    :param metrics: metrics for the epoch
    :type metrics: dict
    :param epoch_data: dictionary of various values in the epoch
    :type epoch_data: dict
    :param learning_rates: learning rates accumulated per batch over all
        epochs
    :type learning_rates: List[Any]
    :param batch_losses: losses accumulated per batch
    :type batch_losses: defaultdict
    :param instance_losses: losses per instance in the batch
    :type instance_losses: defaultdict
    :param use_wandb: flag to decide whether to log visualizations to wandb
    :type use_wandb: bool
    """
    logging.info(
        color(
            "V: {} | Epoch: {} | {} | Avg. Loss {:.4f}".format(
                self.config.version, self.epoch_counter,
                mode.capitalize(), epoch_losses['loss']), 'green'))

    metric_log = "V: {} | Epoch: {} | {}".format(self.config.version,
                                                 self.epoch_counter,
                                                 mode.capitalize())
    for metric in self.config.metrics_to_track:
        metric_log += ' | {}: {:.4f}'.format(metric, metrics[metric])
    logging.info(color(metric_log, 'green'))

    # update wandb
    if use_wandb:
        self._update_wandb(mode, epoch_losses, metrics, epoch_data,
                           learning_rates, batch_losses)

    if batch_losses is not None:
        # reshape batch losses to the shape of instance losses
        instance_batch_losses = dict()
        for loss_name, loss_value in batch_losses.items():
            loss_value = loss_value.reshape(-1, 1)
            loss_value = np.repeat(loss_value,
                                   self.model_config['batch_size'],
                                   axis=-1).reshape(-1)

            # correct for incomplete last batch
            instance_batch_losses[
                loss_name] = loss_value[:len(epoch_data['items'])]

    # log instance-level epochwise values
    instance_values = {
        'paths': [item.path for item in epoch_data['items']],
        'predictions': epoch_data['predictions'],
        'targets': epoch_data['targets'],
    }

    for key, value in metrics.items():
        instance_values[key] = value

    for loss_name in instance_losses:
        instance_values['instance_loss'] = instance_losses[loss_name]
        if batch_losses is not None:
            instance_values['batch_loss'] = instance_batch_losses[loss_name]

    save_path = join(self.config.log_dir, 'epochwise',
                     '{}/{}.pt'.format(mode, self.epoch_counter))
    makedirs(dirname(save_path), exist_ok=True)
    torch.save(instance_values, save_path)
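# Standalone illustration of the batch-loss reshaping above: each per-batch
# loss is repeated batch_size times and then truncated so it lines up with
# the per-instance entries (a batch size of 4 and 10 instances are assumed).
import numpy as np

batch_size, num_instances = 4, 10
batch_loss = np.array([0.9, 0.7, 0.5])             # one value per batch
per_instance = np.repeat(batch_loss.reshape(-1, 1), batch_size,
                         axis=-1).reshape(-1)      # one value per instance
per_instance = per_instance[:num_instances]        # correct for incomplete last batch
print(per_instance)  # [0.9 0.9 0.9 0.9 0.7 0.7 0.7 0.7 0.5 0.5]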
def get_dataloader(
        cfg: Dict, mode: str,
        batch_size: int,
        num_workers: int = 10,
        shuffle: bool = True,
        drop_last: bool = True) -> Tuple[DataLoader, Dataset]:
    """Creates the DataLoader and Dataset objects

    :param cfg: config specifying the dataloader
    :type cfg: Dict
    :param mode: mode/split to load; one of {'train', 'test', 'val'}
    :type mode: str
    :param batch_size: number of instances in each batch
    :type batch_size: int
    :param num_workers: number of CPU workers to use, defaults to 10
    :type num_workers: int, optional
    :param shuffle: whether to shuffle the data, defaults to True
    :type shuffle: bool, optional
    :param drop_last: whether to drop the last batch if it contains fewer
        samples than the batch size, defaults to True
    :type drop_last: bool, optional
    :returns: a tuple containing the DataLoader and Dataset objects
    """
    logging.info(color('Creating {} DataLoader'.format(mode), 'blue'))

    # define target transform
    target_transform = None
    if 'target_transform' in cfg:
        target_transform = annotation_factory.create(
            cfg['target_transform']['name'],
            **cfg['target_transform']['params'])

    # define signal transform
    signal_transform = None
    if 'signal_transform' in cfg:
        signal_transform = DataProcessor(cfg['signal_transform'][mode])

    # define Dataset object
    dataset_params = cfg['dataset']['params'].get(mode, {})
    dataset_params.update({
        'target_transform': target_transform,
        'signal_transform': signal_transform,
        'mode': mode,
        'data_type': cfg['data_type'],
        'data_root': cfg['root'],
        'dataset_config': cfg['dataset']['config']
    })
    dataset = dataset_factory.create(cfg['dataset']['name'],
                                     **dataset_params)

    # load the entire dataset in one batch
    if batch_size == -1:
        batch_size = len(dataset)

    # define sampler
    sampler_cfg = cfg['sampler'].get(mode, {'name': 'default'})
    sampler_params = sampler_cfg.get('params', {})
    sampler_params.update({
        'dataset': dataset,
        'shuffle': shuffle,
        'target_transform': target_transform
    })
    sampler = sampler_factory.create(sampler_cfg['name'], **sampler_params)

    # define the collate function for accumulating a batch
    collate_fn = partial(eval(cfg['collate_fn']['name']),
                         **cfg['collate_fn'].get('params', {}))

    # define DataLoader object
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        drop_last=drop_last,
        collate_fn=collate_fn,
        pin_memory=True)

    return dataloader, dataset
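# Sketch of the cfg dict this function consumes; the nesting of keys follows
# the lookups above ('dataset', 'sampler', 'collate_fn', optional transform
# sections), but every concrete name and parameter shown is a hypothetical
# placeholder for whatever the corresponding factories actually register.
example_cfg = {
    'root': '/data',
    'data_type': 'image',
    'dataset': {
        'name': 'classification_dataset',           # dataset_factory entry
        'params': {'train': {}, 'val': {}, 'test': {}},
        'config': [...]                              # forwarded as `dataset_config`
    },
    'target_transform': {'name': 'classification', 'params': {}},  # optional
    'signal_transform': {'train': [], 'val': [], 'test': []},      # optional
    'sampler': {'train': {'name': 'default'}},
    'collate_fn': {'name': 'classification_collate', 'params': {}}
}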
def fit(self, debug: bool = False, overfit_batch: bool = False,
        use_wandb: bool = True):
    """Entry point to training the network

    :param debug: test run with an epoch only on the val set without
        training, defaults to False
    :type debug: bool, optional
    :param overfit_batch: whether this run is for overfitting on a batch,
        defaults to False
    :type overfit_batch: bool, optional
    :param use_wandb: flag for whether to log visualizations to wandb,
        defaults to True
    :type use_wandb: bool, optional
    """
    if not debug:
        # if we are overfitting a batch, turn off shuffling for the
        # train set; else, set it to True
        shuffle = not overfit_batch
        train_dataloader, _ = get_dataloader(
            self.data_config, self.config.train_mode,
            self.model_config['batch_size'],
            num_workers=self.config.num_workers,
            shuffle=shuffle,
            drop_last=False)

    # ignore val operations when overfitting on a batch
    if not overfit_batch:
        val_dataloader, _ = get_dataloader(
            self.data_config, self.config.val_mode,
            self.model_config['batch_size'],
            num_workers=self.config.num_workers,
            shuffle=False,
            drop_last=False)
    else:
        logging.info(color('Overfitting a single batch', 'blue'))

    # track gradients and weights in wandb
    if use_wandb:
        self.network.watch()

    for epochID in range(self.model_config['epochs']):
        if not debug:
            # train epoch
            train_results = self.process_epoch(
                train_dataloader, self.config.train_mode, training=True,
                use_wandb=use_wandb, overfit_batch=overfit_batch)

        # ignore val operations when overfitting on a batch
        if not overfit_batch:
            # val epoch
            val_results = self.process_epoch(
                val_dataloader, self.config.val_mode, training=False,
                use_wandb=use_wandb)

            # save best model
            self.save(val_results, use_wandb=use_wandb)

            # update optimizer parameters using schedulers that
            # operate per epoch like ReduceLROnPlateau
            if hasattr(self, 'update_freq') and 'epoch' in self.update_freq:
                logging.info('Running scheduler step')
                self.update_optimizer_params(val_results, 'epoch')

        # increment epoch counter
        self.epoch_counter += 1
def process_epoch(self, data_loader: DataLoader, mode: str = None,
                  training: bool = False, use_wandb: bool = True,
                  log_summary: bool = True, overfit_batch: bool = False):
    """Basic epoch function (used for train/val/test epochs)

    :param data_loader: torch DataLoader for the epoch
    :type data_loader: DataLoader
    :param mode: train/val/test mode, defaults to None
    :type mode: str, optional
    :param training: specifies whether the model should be in training mode;
        if True, the network is set to .train(), else it is set to .eval()
    :type training: bool, defaults to False
    :param use_wandb: whether to log visualizations to wandb
    :type use_wandb: bool, defaults to True
    :param log_summary: whether to log the epoch summary
    :type log_summary: bool, defaults to True
    :param overfit_batch: whether this run is for overfitting on a batch
    :type overfit_batch: bool
    """
    instance_losses = defaultdict(list)
    batch_losses = defaultdict(list)

    epoch_data = defaultdict(list)
    learning_rates = []

    if training:
        training_mode = color('train', 'magenta')
        self.network.train()
    else:
        training_mode = color('eval', 'magenta')
        self.network.eval()

    logging.info('{}: {}'.format(
        color('Setting network training mode:', 'blue'),
        color(training_mode)))

    iterator = tqdm(data_loader, dynamic_ncols=True)

    for batchID, batch in enumerate(iterator):
        # process one batch to compute and return the inputs, predictions,
        # ground truth and item in the batch
        batch_data = self.process_batch(batch)

        # calculate loss per instance in the batch
        _instance_losses = self.calculate_instance_loss(
            predictions=batch_data['predictions'],
            targets=batch_data['targets'],
            mode=mode)

        # calculate loss for the batch
        _batch_losses = self.calculate_batch_loss(_instance_losses)

        if mode is not None:
            # log batch summary
            self.log_batch_summary(iterator, mode, _batch_losses)

        # update network weights in training mode
        if training:
            self.update_network_params(_batch_losses)

        # append instance losses to the list of losses for the epoch
        instance_losses = self._accumulate_losses(instance_losses,
                                                  _instance_losses)

        # append batch losses to the list of losses for the epoch
        batch_losses = self._accumulate_losses(batch_losses, _batch_losses)

        # accumulate learning rate before scheduler step
        self._accumulate_lr(learning_rates)

        # update optimizer parameters using schedulers that operate
        # per batch like CyclicalLearningRate
        if hasattr(self, 'update_freq') and 'batch' in self.update_freq \
                and training:
            self.update_optimizer_params(_batch_losses, 'batch')

        # accumulate predictions, targets and items over the epoch
        for key in batch_data:
            if isinstance(batch_data[key], torch.Tensor):
                batch_data[key] = batch_data[key].detach().cpu()

            # ignore storing inputs
            if key == 'inputs':
                continue

            epoch_data[key].append(batch_data[key])

        # ignore other batches after the first batch if we are
        # overfitting a batch
        if overfit_batch:
            break

    logging.info('Gathering data')
    epoch_data = self._gather_data(epoch_data)

    logging.info('Gathering losses')

    # gather all instance losses
    instance_losses = self._gather_losses(instance_losses)

    # gather all batch losses
    batch_losses = self._gather_losses(batch_losses)

    # accumulate list of batch losses to epoch loss
    epoch_losses = self.calculate_epoch_loss(batch_losses)

    # get parameters for evaluation like the optimal
    # threshold for classification
    eval_params = self.get_eval_params(epoch_data)

    # calculate metrics for the epoch
    logging.info('Computing metrics')
    metrics = self.compute_epoch_metrics(epoch_data['predictions'],
                                         epoch_data['targets'],
                                         **eval_params)

    if log_summary:
        logging.info('Logging epoch summary')

        # log losses, metrics and visualizations
        self.log_epoch_summary(mode, epoch_losses, metrics, epoch_data,
                               learning_rates, batch_losses,
                               instance_losses, use_wandb)

    results = dict()
    results.update(epoch_losses)
    results.update(metrics)

    results['batch_losses'] = batch_losses
    results['instance_losses'] = instance_losses

    for key in epoch_data:
        results[key] = epoch_data[key]

    return results
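# Rough shape of the dict returned by process_epoch (illustrative only): the
# aggregated epoch losses and metrics are merged at the top level, with the
# raw per-batch/per-instance losses and the gathered epoch data alongside.
# example_results = {
#     'loss': ...,               # from calculate_epoch_loss
#     '<metric name>': ...,      # one entry per computed metric
#     'batch_losses': ...,       # gathered per-batch losses
#     'instance_losses': ...,    # gathered per-instance losses
#     'predictions': ..., 'targets': ..., 'items': ...  # from epoch_data
# }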
def evaluate(config, mode, use_wandb, ignore_cache, n_tta):
    """Run the actual evaluation

    :param config: config for the model to evaluate
    :type config: Config
    :param mode: data mode to evaluate on
    :type mode: str
    :param use_wandb: whether to log values to wandb
    :type use_wandb: bool
    :param ignore_cache: whether to ignore cached predictions
    :type ignore_cache: bool
    :param n_tta: number of test-time-augmentation (TTA) runs to average over
    :type n_tta: int
    """
    model = model_factory.create(config.model['name'], **{'config': config})

    logging.info(color(f'Evaluating on mode: {mode}'))

    # reset sampler to default
    config.data['sampler'].update({mode: {'name': 'default'}})

    dataloader, _ = get_dataloader(
        config.data, mode,
        config.model['batch_size'],
        num_workers=config.num_workers,
        shuffle=False,
        drop_last=False)

    # set to eval mode
    model.network.eval()

    all_predictions = []
    for run_index in range(n_tta):
        logging.info(f'TTA run #{run_index + 1}')
        results = model.evaluate(dataloader, mode, use_wandb, ignore_cache,
                                 data_only=True, log_summary=False)
        logging.info(f'AUC = {results["auc-roc"]}')

        # logits
        predictions = results['predictions']

        # convert logits to probabilities with sigmoid
        predictions = torch.sigmoid(predictions)

        # add to list of all predictions across each TTA run
        all_predictions.append(predictions)

    all_predictions = torch.stack(all_predictions, -1)

    # take the mean across the TTA runs
    predictions = all_predictions.mean(-1)

    # calculate the metrics on the TTA predictions
    metrics = model.compute_epoch_metrics(predictions, results['targets'],
                                          as_logits=False)
    print(f'TTA auc: {metrics["auc-roc"]}')

    # get the file names
    names = [splitext(basename(item.path))[0] for item in results['items']]

    # convert to data frame
    data_frame = pd.DataFrame({
        'image_name': names,
        'target': predictions.tolist()
    })

    # save the results
    save_path = join(config.log_dir, 'evaluation', f'{mode}.csv')
    os.makedirs(dirname(save_path), exist_ok=True)
    logging.info(color(f'Saving results to {save_path}'))
    data_frame.to_csv(save_path, index=False)
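# Minimal standalone sketch of the TTA averaging performed above: sigmoid
# probabilities from each run are stacked along a new trailing dimension and
# averaged (3 runs over 4 hypothetical instances assumed here).
import torch

logits_per_run = [torch.randn(4) for _ in range(3)]       # one tensor per TTA run
probs_per_run = [torch.sigmoid(logits) for logits in logits_per_run]
all_predictions = torch.stack(probs_per_run, -1)           # shape: (4, 3)
predictions = all_predictions.mean(-1)                     # shape: (4,)
print(predictions)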