def log_run(split: str, epoch: int, writer: tf.summary.SummaryWriter,
            label_names: Sequence[str], metrics: MutableMapping[str, float],
            heaps: Mapping[str, Mapping[int, List[HeapItem]]],
            cm: np.ndarray) -> None:
    """Logs the outputs (metrics, confusion matrix, tp/fp/fn images) from a
    single epoch run to Tensorboard.

    Args:
        metrics: dict, keys already prefixed with {split}/
    """
    per_class_recall = recall_from_confusion_matrix(cm, label_names)
    metrics.update(prefix_all_keys(per_class_recall, f'{split}/label_recall/'))

    # log metrics
    for metric, value in metrics.items():
        tf.summary.scalar(metric, value, epoch)

    # log confusion matrix
    cm_fig = plot_utils.plot_confusion_matrix(cm,
                                              classes=label_names,
                                              normalize=True)
    cm_fig_img = tf.convert_to_tensor(fig_to_img(cm_fig)[np.newaxis, ...])
    tf.summary.image(f'confusion_matrix/{split}', cm_fig_img, step=epoch)

    # log tp/fp/fn images
    for heap_type, heap_dict in heaps.items():
        log_images_with_confidence(heap_dict,
                                   label_names,
                                   epoch=epoch,
                                   tag=f'{split}/{heap_type}')
    writer.flush()
def log_run(split: str, epoch: int, writer: tensorboard.SummaryWriter,
            label_names: Sequence[str], metrics: MutableMapping[str, float],
            heaps: Optional[Mapping[str, Mapping[int, list[HeapItem]]]],
            cm: np.ndarray) -> None:
    """Logs the outputs (metrics, confusion matrix, tp/fp/fn images) from a
    single epoch run to Tensorboard.

    Args:
        metrics: dict, keys already prefixed with {split}/
    """
    per_label_recall = recall_from_confusion_matrix(cm, label_names)
    metrics.update(prefix_all_keys(per_label_recall, f'{split}/label_recall/'))

    # log metrics
    for metric, value in metrics.items():
        writer.add_scalar(metric, value, epoch)

    # log confusion matrix
    cm_fig = plot_utils.plot_confusion_matrix(cm, classes=label_names,
                                              normalize=True)
    cm_fig_img = fig_to_img(cm_fig)
    writer.add_image(tag=f'confusion_matrix/{split}', img_tensor=cm_fig_img,
                     global_step=epoch, dataformats='HWC')

    # log tp/fp/fn images
    if heaps is not None:
        for heap_type, heap_dict in heaps.items():
            log_images_with_confidence(writer, heap_dict, label_names,
                                       epoch=epoch, tag=f'{split}/{heap_type}')
    writer.flush()
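# Illustrative call to log_run with assumed input shapes (the names and values
# below are hypothetical, not part of the original module):
#   metrics = {'val/loss': 0.31, 'val/acc_top1': 0.87}  # keys prefixed with '{split}/'
#   heaps = {'tp': {...}, 'fp': {...}, 'fn': {...}}     # heap type -> {label index -> list of HeapItem}
#   cm = np.zeros((len(label_names), len(label_names)))  # confusion matrix over label_names
#   log_run('val', epoch, writer, label_names, metrics=metrics, heaps=heaps, cm=cm)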
def main(dataset_dir: str,
         cropped_images_dir: str,
         multilabel: bool,
         model_name: str,
         pretrained: bool | str,
         finetune: int,
         label_weighted: bool,
         weight_by_detection_conf: bool | str,
         epochs: int,
         batch_size: int,
         lr: float,
         weight_decay: float,
         num_workers: int,
         logdir: str,
         log_extreme_examples: int,
         seed: Optional[int] = None) -> None:
    """Main function."""
    # input validation
    assert os.path.exists(dataset_dir)
    assert os.path.exists(cropped_images_dir)
    if isinstance(weight_by_detection_conf, str):
        assert os.path.exists(weight_by_detection_conf)
    if isinstance(pretrained, str):
        assert os.path.exists(pretrained)

    # set seed
    seed = np.random.randint(10_000) if seed is None else seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # create logdir and save params
    params = dict(locals())  # make a copy
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # '20200722_110816'
    logdir = os.path.join(logdir, timestamp)
    os.makedirs(logdir, exist_ok=True)
    print('Created logdir:', logdir)
    params_json_path = os.path.join(logdir, 'params.json')
    with open(params_json_path, 'w') as f:
        json.dump(params, f, indent=1)

    if 'efficientnet' in model_name:
        img_size = efficientnet.EfficientNet.get_image_size(model_name)
    else:
        img_size = 224

    # create dataloaders and log the index_to_label mapping
    print('Creating dataloaders')
    loaders, label_names = create_dataloaders(
        dataset_csv_path=os.path.join(dataset_dir, 'classification_ds.csv'),
        label_index_json_path=os.path.join(dataset_dir, 'label_index.json'),
        splits_json_path=os.path.join(dataset_dir, 'splits.json'),
        cropped_images_dir=cropped_images_dir,
        img_size=img_size,
        multilabel=multilabel,
        label_weighted=label_weighted,
        weight_by_detection_conf=weight_by_detection_conf,
        batch_size=batch_size,
        num_workers=num_workers,
        augment_train=True)

    writer = tensorboard.SummaryWriter(logdir)

    # create model
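    # finetune > 0 builds the model in fine-tune mode; the training loop below
    # keeps it in that mode for the first `finetune` epochs, then turns it off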
    model = build_model(model_name,
                        num_classes=len(label_names),
                        pretrained=pretrained,
                        finetune=finetune > 0)
    model, device = prep_device(model)

    # define loss function and optimizer
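    # reduction='none' keeps per-example losses, presumably so that run_epoch
    # can apply per-example weights before averaging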
    loss_fn: torch.nn.Module
    if multilabel:
        loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none').to(device)
    else:
        loss_fn = torch.nn.CrossEntropyLoss(reduction='none').to(device)

    # using EfficientNet training defaults
    # - batch norm momentum: 0.99
    # - optimizer: RMSProp, decay 0.9 and momentum 0.9
    # - epochs: 350
    # - learning rate: 0.256, decays by 0.97 every 2.4 epochs
    # - weight decay: 1e-5
    optimizer: torch.optim.Optimizer
    if 'efficientnet' in model_name:
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr,
                                        alpha=0.9,
                                        momentum=0.9,
                                        weight_decay=weight_decay)
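        # a per-epoch gamma of 0.97 ** (1 / 2.4) compounds to the EfficientNet
        # default of multiplying the learning rate by 0.97 every 2.4 epochs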
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                       step_size=1,
                                                       gamma=0.97**(1 / 2.4))
    else:  # resnet
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr,
                                    momentum=0.9,
                                    weight_decay=weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer, step_size=8,
            gamma=0.1)  # decay the learning rate by 10x every 8 epochs

    best_epoch_metrics: dict[str, float] = {}
    for epoch in range(epochs):
        print(f'Epoch: {epoch}')
        writer.add_scalar('lr', lr_scheduler.get_last_lr()[0], epoch)

        if epoch > 0 and finetune == epoch:
            print('Turning off fine-tune!')
            set_finetune(model, model_name, finetune=False)

        print('- train:')
        train_metrics, train_heaps, train_cm = run_epoch(
            model,
            loader=loaders['train'],
            weighted=False,
            device=device,
            loss_fn=loss_fn,
            finetune=finetune > epoch,
            optimizer=optimizer,
            k_extreme=log_extreme_examples)
        train_metrics = prefix_all_keys(train_metrics, prefix='train/')
        log_run('train',
                epoch,
                writer,
                label_names,
                metrics=train_metrics,
                heaps=train_heaps,
                cm=train_cm)
        del train_heaps

        print('- val:')
        val_metrics, val_heaps, val_cm = run_epoch(
            model,
            loader=loaders['val'],
            weighted=label_weighted,
            device=device,
            loss_fn=loss_fn,
            k_extreme=log_extreme_examples)
        val_metrics = prefix_all_keys(val_metrics, prefix='val/')
        log_run('val',
                epoch,
                writer,
                label_names,
                metrics=val_metrics,
                heaps=val_heaps,
                cm=val_cm)
        del val_heaps

        lr_scheduler.step()  # decrease the learning rate

        if val_metrics['val/acc_top1'] > best_epoch_metrics.get('val/acc_top1', 0):  # pylint: disable=line-too-long
            filename = os.path.join(logdir, f'ckpt_{epoch}.pt')
            print(f'New best model! Saving checkpoint to {filename}')
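            # getattr(model, 'module', model) unwraps torch.nn.DataParallel
            # (if the model was wrapped) so checkpoint keys have no 'module.' prefix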
            state = {
                'epoch': epoch,
                'model': getattr(model, 'module', model).state_dict(),
                'val/acc': val_metrics['val/acc_top1'],
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, filename)
            best_epoch_metrics.update(train_metrics)
            best_epoch_metrics.update(val_metrics)
            best_epoch_metrics['epoch'] = epoch

            print('- test:')
            test_metrics, test_heaps, test_cm = run_epoch(
                model,
                loader=loaders['test'],
                weighted=label_weighted,
                device=device,
                loss_fn=loss_fn,
                k_extreme=log_extreme_examples)
            test_metrics = prefix_all_keys(test_metrics, prefix='test/')
            log_run('test',
                    epoch,
                    writer,
                    label_names,
                    metrics=test_metrics,
                    heaps=test_heaps,
                    cm=test_cm)
            del test_heaps

        # stop training after 8 epochs without improvement
        if epoch >= best_epoch_metrics['epoch'] + 8:
            break

    hparams_dict = {
        'model_name': model_name,
        'multilabel': multilabel,
        'finetune': finetune,
        'batch_size': batch_size,
        'epochs': epochs
    }
    metric_dict = prefix_all_keys(best_epoch_metrics, prefix='hparam/')
    writer.add_hparams(hparam_dict=hparams_dict, metric_dict=metric_dict)
    writer.close()

    # do a complete evaluation run
    best_epoch = best_epoch_metrics['epoch']
    evaluate_model.main(params_json_path=params_json_path,
                        ckpt_path=os.path.join(logdir,
                                               f'ckpt_{best_epoch}.pt'),
                        output_dir=logdir,
                        splits=evaluate_model.SPLITS)
def main(dataset_dir: str,
         cropped_images_dir: str,
         multilabel: bool,
         model_name: str,
         pretrained: bool,
         finetune: int,
         label_weighted: bool,
         weight_by_detection_conf: Union[bool, str],
         epochs: int,
         batch_size: int,
         lr: float,
         weight_decay: float,
         seed: Optional[int] = None,
         logdir: str = '',
         cache_splits: Sequence[str] = ()) -> None:
    """Main function."""
    # input validation
    assert os.path.exists(dataset_dir)
    assert os.path.exists(cropped_images_dir)
    if isinstance(weight_by_detection_conf, str):
        assert os.path.exists(weight_by_detection_conf)

    # set seed
    seed = np.random.randint(10_000) if seed is None else seed
    np.random.seed(seed)
    tf.random.set_seed(seed)

    # create logdir and save params
    params = dict(locals())  # make a copy
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # '20200722_110816'
    logdir = os.path.join(logdir, timestamp)
    os.makedirs(logdir, exist_ok=True)
    print('Created logdir:', logdir)
    with open(os.path.join(logdir, 'params.json'), 'w') as f:
        json.dump(params, f, indent=1)

    # allow TensorFlow to grow GPU memory allocation on demand instead of
    # reserving all GPU memory up front
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    img_size = EFFICIENTNET_MODELS[model_name]['img_size']

    # create dataloaders and log the index_to_label mapping
    loaders, label_names = create_dataloaders(
        dataset_csv_path=os.path.join(dataset_dir, 'classification_ds.csv'),
        label_index_json_path=os.path.join(dataset_dir, 'label_index.json'),
        splits_json_path=os.path.join(dataset_dir, 'splits.json'),
        cropped_images_dir=cropped_images_dir,
        img_size=img_size,
        multilabel=multilabel,
        label_weighted=label_weighted,
        weight_by_detection_conf=weight_by_detection_conf,
        batch_size=batch_size,
        augment_train=True,
        cache_splits=cache_splits)

    writer = tf.summary.create_file_writer(logdir)
    writer.set_as_default()

    model = build_model(model_name,
                        num_classes=len(label_names),
                        img_size=img_size,
                        pretrained=pretrained,
                        finetune=finetune > 0)

    # define loss function and optimizer
    loss_fn: tf.keras.losses.Loss
    if multilabel:
        loss_fn = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    else:
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    # using EfficientNet training defaults
    # - batch norm momentum: 0.99
    # - optimizer: RMSProp, decay 0.9 and momentum 0.9
    # - epochs: 350
    # - learning rate: 0.256, decays by 0.97 every 2.4 epochs
    # - weight decay: 1e-5
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        lr, decay_steps=1, decay_rate=0.97, staircase=True)
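    # with decay_steps=1 and staircase=True, lr_schedule(epoch) = lr * 0.97**epoch;
    # the schedule is applied manually at the start of each epoch below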
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr,
                                            rho=0.9,
                                            momentum=0.9)

    best_epoch_metrics: Dict[str, float] = {}
    for epoch in range(epochs):
        print(f'Epoch: {epoch}')
        optimizer.learning_rate = lr_schedule(epoch)
        tf.summary.scalar('lr', optimizer.learning_rate, epoch)

        if epoch > 0 and finetune == epoch:
            print('Turning off fine-tune!')
            model.base_model.trainable = True

        print('- train:')
        # TODO: change weighted to False if oversampling minority classes
        train_metrics, train_heaps, train_cm = run_epoch(
            model,
            loader=loaders['train'],
            weighted=label_weighted,
            loss_fn=loss_fn,
            weight_decay=weight_decay,
            optimizer=optimizer,
            finetune=finetune > epoch,
            return_extreme_images=True)
        train_metrics = prefix_all_keys(train_metrics, prefix='train/')
        log_run('train',
                epoch,
                writer,
                label_names,
                metrics=train_metrics,
                heaps=train_heaps,
                cm=train_cm)

        print('- val:')
        val_metrics, val_heaps, val_cm = run_epoch(model,
                                                   loader=loaders['val'],
                                                   weighted=label_weighted,
                                                   loss_fn=loss_fn,
                                                   return_extreme_images=True)
        val_metrics = prefix_all_keys(val_metrics, prefix='val/')
        log_run('val',
                epoch,
                writer,
                label_names,
                metrics=val_metrics,
                heaps=val_heaps,
                cm=val_cm)

        if val_metrics['val/acc_top1'] > best_epoch_metrics.get('val/acc_top1', 0):  # pylint: disable=line-too-long
            filename = os.path.join(logdir, f'ckpt_{epoch}.h5')
            print(f'New best model! Saving checkpoint to {filename}')
            model.save(filename)
            best_epoch_metrics.update(train_metrics)
            best_epoch_metrics.update(val_metrics)
            best_epoch_metrics['epoch'] = epoch

            print('- test:')
            test_metrics, test_heaps, test_cm = run_epoch(
                model,
                loader=loaders['test'],
                weighted=label_weighted,
                loss_fn=loss_fn,
                return_extreme_images=True)
            test_metrics = prefix_all_keys(test_metrics, prefix='test/')
            log_run('test',
                    epoch,
                    writer,
                    label_names,
                    metrics=test_metrics,
                    heaps=test_heaps,
                    cm=test_cm)

        # stop training after 8 epochs without improvement
        if epoch >= best_epoch_metrics['epoch'] + 8:
            break

    hparams_dict = {
        'model_name': model_name,
        'multilabel': multilabel,
        'finetune': finetune,
        'batch_size': batch_size,
        'epochs': epochs
    }
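    # log hyperparameters via the TensorBoard HParams plugin (presumably
    # tensorboard.plugins.hparams.api imported as `hp`)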
    hp.hparams(hparams_dict)
    writer.close()