def log_run(split: str, epoch: int, writer: tf.summary.SummaryWriter,
            label_names: Sequence[str], metrics: MutableMapping[str, float],
            heaps: Mapping[str, Mapping[int, List[HeapItem]]],
            cm: np.ndarray) -> None:
    """Logs the outputs (metrics, confusion matrix, tp/fp/fn images) from a
    single epoch run to Tensorboard.

    Args:
        metrics: dict, keys already prefixed with {split}/
    """
    per_class_recall = recall_from_confusion_matrix(cm, label_names)
    metrics.update(prefix_all_keys(per_class_recall, f'{split}/label_recall/'))

    # log metrics
    for metric, value in metrics.items():
        tf.summary.scalar(metric, value, epoch)

    # log confusion matrix
    cm_fig = plot_utils.plot_confusion_matrix(cm, classes=label_names,
                                              normalize=True)
    cm_fig_img = tf.convert_to_tensor(fig_to_img(cm_fig)[np.newaxis, ...])
    tf.summary.image(f'confusion_matrix/{split}', cm_fig_img, step=epoch)

    # log tp/fp/fn images
    for heap_type, heap_dict in heaps.items():
        log_images_with_confidence(heap_dict, label_names, epoch=epoch,
                                   tag=f'{split}/{heap_type}')

    writer.flush()
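# recall_from_confusion_matrix is used by both log_run variants but is defined
# elsewhere. A minimal illustrative sketch of the assumed behavior, not the
# project's implementation: per-label recall from a confusion matrix whose rows
# are true labels and columns are predicted labels.
def recall_from_confusion_matrix(cm: np.ndarray, label_names: Sequence[str]
                                 ) -> Dict[str, float]:
    """Sketch: recall for label i = cm[i, i] / cm[i].sum(), assuming rows=true."""
    with np.errstate(divide='ignore', invalid='ignore'):
        recalls = np.diag(cm) / cm.sum(axis=1)
    return {name: float(r) for name, r in zip(label_names, recalls)}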
def log_run(split: str, epoch: int, writer: tensorboard.SummaryWriter,
            label_names: Sequence[str], metrics: MutableMapping[str, float],
            heaps: Optional[Mapping[str, Mapping[int, list[HeapItem]]]],
            cm: np.ndarray) -> None:
    """Logs the outputs (metrics, confusion matrix, tp/fp/fn images) from a
    single epoch run to Tensorboard.

    Args:
        metrics: dict, keys already prefixed with {split}/
    """
    per_label_recall = recall_from_confusion_matrix(cm, label_names)
    metrics.update(prefix_all_keys(per_label_recall, f'{split}/label_recall/'))

    # log metrics
    for metric, value in metrics.items():
        writer.add_scalar(metric, value, epoch)

    # log confusion matrix
    cm_fig = plot_utils.plot_confusion_matrix(cm, classes=label_names,
                                              normalize=True)
    cm_fig_img = fig_to_img(cm_fig)
    writer.add_image(tag=f'confusion_matrix/{split}', img_tensor=cm_fig_img,
                     global_step=epoch, dataformats='HWC')

    # log tp/fp/fn images
    if heaps is not None:
        for heap_type, heap_dict in heaps.items():
            log_images_with_confidence(writer, heap_dict, label_names,
                                       epoch=epoch, tag=f'{split}/{heap_type}')

    writer.flush()
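# prefix_all_keys is also defined elsewhere. A minimal sketch of the assumed
# behavior, consistent with the docstring note "keys already prefixed with
# {split}/" (the real helper may differ):
def prefix_all_keys(d: Mapping[str, float], prefix: str) -> dict[str, float]:
    """Sketch: returns a new dict with `prefix` prepended to every key."""
    return {f'{prefix}{k}': v for k, v in d.items()}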
def main(dataset_dir: str, cropped_images_dir: str, multilabel: bool,
         model_name: str, pretrained: bool | str, finetune: int,
         label_weighted: bool, weight_by_detection_conf: bool | str,
         epochs: int, batch_size: int, lr: float, weight_decay: float,
         num_workers: int, logdir: str, log_extreme_examples: int,
         seed: Optional[int] = None) -> None:
    """Main function."""
    # input validation
    assert os.path.exists(dataset_dir)
    assert os.path.exists(cropped_images_dir)
    if isinstance(weight_by_detection_conf, str):
        assert os.path.exists(weight_by_detection_conf)
    if isinstance(pretrained, str):
        assert os.path.exists(pretrained)

    # set seed
    seed = np.random.randint(10_000) if seed is None else seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # create logdir and save params
    params = dict(locals())  # make a copy
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # '20200722_110816'
    logdir = os.path.join(logdir, timestamp)
    os.makedirs(logdir, exist_ok=True)
    print('Created logdir:', logdir)
    params_json_path = os.path.join(logdir, 'params.json')
    with open(params_json_path, 'w') as f:
        json.dump(params, f, indent=1)

    if 'efficientnet' in model_name:
        img_size = efficientnet.EfficientNet.get_image_size(model_name)
    else:
        img_size = 224

    # create dataloaders and log the index_to_label mapping
    print('Creating dataloaders')
    loaders, label_names = create_dataloaders(
        dataset_csv_path=os.path.join(dataset_dir, 'classification_ds.csv'),
        label_index_json_path=os.path.join(dataset_dir, 'label_index.json'),
        splits_json_path=os.path.join(dataset_dir, 'splits.json'),
        cropped_images_dir=cropped_images_dir,
        img_size=img_size,
        multilabel=multilabel,
        label_weighted=label_weighted,
        weight_by_detection_conf=weight_by_detection_conf,
        batch_size=batch_size,
        num_workers=num_workers,
        augment_train=True)
    writer = tensorboard.SummaryWriter(logdir)

    # create model
    model = build_model(model_name, num_classes=len(label_names),
                        pretrained=pretrained, finetune=finetune > 0)
    model, device = prep_device(model)

    # define loss function and optimizer
    loss_fn: torch.nn.Module
    if multilabel:
        loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none').to(device)
    else:
        loss_fn = torch.nn.CrossEntropyLoss(reduction='none').to(device)

    # using EfficientNet training defaults
    # - batch norm momentum: 0.99
    # - optimizer: RMSProp, decay 0.9 and momentum 0.9
    # - epochs: 350
    # - learning rate: 0.256, decays by 0.97 every 2.4 epochs
    # - weight decay: 1e-5
    optimizer: torch.optim.Optimizer
    if 'efficientnet' in model_name:
        optimizer = torch.optim.RMSprop(model.parameters(), lr, alpha=0.9,
                                        momentum=0.9,
                                        weight_decay=weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer, step_size=1, gamma=0.97**(1 / 2.4))
    else:  # resnet
        optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9,
                                    weight_decay=weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer, step_size=8, gamma=0.1)  # lower every 8 epochs

    best_epoch_metrics: dict[str, float] = {}
    for epoch in range(epochs):
        print(f'Epoch: {epoch}')
        writer.add_scalar('lr', lr_scheduler.get_last_lr()[0], epoch)

        if epoch > 0 and finetune == epoch:
            print('Turning off fine-tune!')
            set_finetune(model, model_name, finetune=False)

        print('- train:')
        train_metrics, train_heaps, train_cm = run_epoch(
            model, loader=loaders['train'], weighted=False, device=device,
            loss_fn=loss_fn, finetune=finetune > epoch, optimizer=optimizer,
            k_extreme=log_extreme_examples)
        train_metrics = prefix_all_keys(train_metrics, prefix='train/')
        log_run('train', epoch, writer, label_names, metrics=train_metrics,
                heaps=train_heaps, cm=train_cm)
        del train_heaps

        print('- val:')
        val_metrics, val_heaps, val_cm = run_epoch(
            model, loader=loaders['val'], weighted=label_weighted,
            device=device, loss_fn=loss_fn, k_extreme=log_extreme_examples)
        val_metrics = prefix_all_keys(val_metrics, prefix='val/')
        log_run('val', epoch, writer, label_names, metrics=val_metrics,
                heaps=val_heaps, cm=val_cm)
        del val_heaps

        lr_scheduler.step()  # decrease the learning rate

        if val_metrics['val/acc_top1'] > best_epoch_metrics.get('val/acc_top1', 0):  # pylint: disable=line-too-long
            filename = os.path.join(logdir, f'ckpt_{epoch}.pt')
            print(f'New best model! Saving checkpoint to {filename}')
            state = {
                'epoch': epoch,
                'model': getattr(model, 'module', model).state_dict(),
                'val/acc': val_metrics['val/acc_top1'],
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, filename)
            best_epoch_metrics.update(train_metrics)
            best_epoch_metrics.update(val_metrics)
            best_epoch_metrics['epoch'] = epoch

            print('- test:')
            test_metrics, test_heaps, test_cm = run_epoch(
                model, loader=loaders['test'], weighted=label_weighted,
                device=device, loss_fn=loss_fn, k_extreme=log_extreme_examples)
            test_metrics = prefix_all_keys(test_metrics, prefix='test/')
            log_run('test', epoch, writer, label_names, metrics=test_metrics,
                    heaps=test_heaps, cm=test_cm)
            del test_heaps

        # stop training after 8 epochs without improvement
        if epoch >= best_epoch_metrics['epoch'] + 8:
            break

    hparams_dict = {
        'model_name': model_name,
        'multilabel': multilabel,
        'finetune': finetune,
        'batch_size': batch_size,
        'epochs': epochs
    }
    metric_dict = prefix_all_keys(best_epoch_metrics, prefix='hparam/')
    writer.add_hparams(hparam_dict=hparams_dict, metric_dict=metric_dict)
    writer.close()

    # do a complete evaluation run
    best_epoch = best_epoch_metrics['epoch']
    evaluate_model.main(
        params_json_path=params_json_path,
        ckpt_path=os.path.join(logdir, f'ckpt_{best_epoch}.pt'),
        output_dir=logdir, splits=evaluate_model.SPLITS)
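# prep_device is assumed in main() above but not defined in this excerpt. A
# minimal sketch of the assumed behavior (move the model to GPU if available,
# wrapping it in DataParallel when several GPUs are visible, which is consistent
# with the getattr(model, 'module', model) checkpointing above); the project's
# actual helper may differ:
def prep_device(model: torch.nn.Module) -> tuple[torch.nn.Module, torch.device]:
    """Sketch: returns (model moved to device, device)."""
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    return model.to(device), device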
def main(dataset_dir: str, cropped_images_dir: str, multilabel: bool,
         model_name: str, pretrained: bool, finetune: int,
         label_weighted: bool, weight_by_detection_conf: Union[bool, str],
         epochs: int, batch_size: int, lr: float, weight_decay: float,
         seed: Optional[int] = None, logdir: str = '',
         cache_splits: Sequence[str] = ()) -> None:
    """Main function."""
    # input validation
    assert os.path.exists(dataset_dir)
    assert os.path.exists(cropped_images_dir)
    if isinstance(weight_by_detection_conf, str):
        assert os.path.exists(weight_by_detection_conf)

    # set seed
    seed = np.random.randint(10_000) if seed is None else seed
    np.random.seed(seed)
    tf.random.set_seed(seed)

    # create logdir and save params
    params = dict(locals())  # make a copy
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # '20200722_110816'
    logdir = os.path.join(logdir, timestamp)
    os.makedirs(logdir, exist_ok=True)
    print('Created logdir:', logdir)
    with open(os.path.join(logdir, 'params.json'), 'w') as f:
        json.dump(params, f, indent=1)

    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    img_size = EFFICIENTNET_MODELS[model_name]['img_size']

    # create dataloaders and log the index_to_label mapping
    loaders, label_names = create_dataloaders(
        dataset_csv_path=os.path.join(dataset_dir, 'classification_ds.csv'),
        label_index_json_path=os.path.join(dataset_dir, 'label_index.json'),
        splits_json_path=os.path.join(dataset_dir, 'splits.json'),
        cropped_images_dir=cropped_images_dir,
        img_size=img_size,
        multilabel=multilabel,
        label_weighted=label_weighted,
        weight_by_detection_conf=weight_by_detection_conf,
        batch_size=batch_size,
        augment_train=True,
        cache_splits=cache_splits)

    writer = tf.summary.create_file_writer(logdir)
    writer.set_as_default()

    model = build_model(model_name, num_classes=len(label_names),
                        img_size=img_size, pretrained=pretrained,
                        finetune=finetune > 0)

    # define loss function and optimizer
    loss_fn: tf.keras.losses.Loss
    if multilabel:
        loss_fn = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    else:
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    # using EfficientNet training defaults
    # - batch norm momentum: 0.99
    # - optimizer: RMSProp, decay 0.9 and momentum 0.9
    # - epochs: 350
    # - learning rate: 0.256, decays by 0.97 every 2.4 epochs
    # - weight decay: 1e-5
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        lr, decay_steps=1, decay_rate=0.97, staircase=True)
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr, rho=0.9,
                                            momentum=0.9)

    best_epoch_metrics: Dict[str, float] = {}
    for epoch in range(epochs):
        print(f'Epoch: {epoch}')
        optimizer.learning_rate = lr_schedule(epoch)
        tf.summary.scalar('lr', optimizer.learning_rate, epoch)

        if epoch > 0 and finetune == epoch:
            print('Turning off fine-tune!')
            model.base_model.trainable = True

        print('- train:')
        # TODO: change weighted to False if oversampling minority classes
        train_metrics, train_heaps, train_cm = run_epoch(
            model, loader=loaders['train'], weighted=label_weighted,
            loss_fn=loss_fn, weight_decay=weight_decay, optimizer=optimizer,
            finetune=finetune > epoch, return_extreme_images=True)
        train_metrics = prefix_all_keys(train_metrics, prefix='train/')
        log_run('train', epoch, writer, label_names, metrics=train_metrics,
                heaps=train_heaps, cm=train_cm)

        print('- val:')
        val_metrics, val_heaps, val_cm = run_epoch(
            model, loader=loaders['val'], weighted=label_weighted,
            loss_fn=loss_fn, return_extreme_images=True)
        val_metrics = prefix_all_keys(val_metrics, prefix='val/')
        log_run('val', epoch, writer, label_names, metrics=val_metrics,
                heaps=val_heaps, cm=val_cm)

        if val_metrics['val/acc_top1'] > best_epoch_metrics.get('val/acc_top1', 0):  # pylint: disable=line-too-long
            filename = os.path.join(logdir, f'ckpt_{epoch}.h5')
            print(f'New best model! Saving checkpoint to {filename}')
            model.save(filename)
            best_epoch_metrics.update(train_metrics)
            best_epoch_metrics.update(val_metrics)
            best_epoch_metrics['epoch'] = epoch

            print('- test:')
            test_metrics, test_heaps, test_cm = run_epoch(
                model, loader=loaders['test'], weighted=label_weighted,
                loss_fn=loss_fn, return_extreme_images=True)
            test_metrics = prefix_all_keys(test_metrics, prefix='test/')
            log_run('test', epoch, writer, label_names, metrics=test_metrics,
                    heaps=test_heaps, cm=test_cm)

        # stop training after 8 epochs without improvement
        if epoch >= best_epoch_metrics['epoch'] + 8:
            break

    hparams_dict = {
        'model_name': model_name,
        'multilabel': multilabel,
        'finetune': finetune,
        'batch_size': batch_size,
        'epochs': epochs
    }
    hp.hparams(hparams_dict)
    writer.close()
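# The command-line wrapper for this TensorFlow variant is not included in this
# excerpt. Below is one hypothetical, minimal way to expose main() from a shell;
# the flag names and default values are illustrative only and the real script's
# CLI may differ.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train a species classifier.')
    parser.add_argument('dataset_dir')
    parser.add_argument('cropped_images_dir')
    parser.add_argument('--model-name', default='efficientnet-b0')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight-decay', type=float, default=1e-5)
    parser.add_argument('--logdir', default='run')
    args = parser.parse_args()

    main(dataset_dir=args.dataset_dir,
         cropped_images_dir=args.cropped_images_dir,
         multilabel=False,
         model_name=args.model_name,
         pretrained=True,
         finetune=0,
         label_weighted=False,
         weight_by_detection_conf=False,
         epochs=args.epochs,
         batch_size=args.batch_size,
         lr=args.lr,
         weight_decay=args.weight_decay,
         logdir=args.logdir)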