Example #1
def train(model: nn.Module,
          dataloaders: dict,
          criterion,
          optimizer,
          metrics,
          scheduler,
          rundir: Union[str, bytes, os.PathLike],
          stopper,
          device: torch.device,
          steps_per_epoch: dict,
          final_activation: str = 'sigmoid',
          sequence: bool = False,
          class_names: list = None,
          normalizer=None,
          dali: bool = False):
    """ Train feature extractor models

    Args:
        model (nn.Module): feature extractor (can also be a component, like the spatial stream or flow stream)
        dataloaders (dict): dictionary with PyTorch dataloader objects (see dataloaders.py)
        criterion (nn.Module): loss function
        optimizer (torch.optim.Optimizer): optimizer (SGD, SGD with momentum, Adam, etc.)
        metrics (Metrics): metrics object for computing metrics and saving to disk (see metrics.py)
        scheduler (_LRScheduler): learning rate scheduler (see schedulers.py)
        rundir (str, os.PathLike): run directory for saving weights
        stopper (Stopper): object that stops training (see stoppers.py)
        device (str, torch.device): gpu device
        steps_per_epoch (dict): keys ['train', 'val', 'test']: number of steps in each "epoch"
        final_activation (str): either sigmoid or softmax
        sequence (bool): if True, assumes sequence inputs of shape N,K,T
        class_names (list): unused
        normalizer (Normalizer): normalizer object, used for un-zscoring images for visualization purposes
        dali (bool): if True, expects NVIDIA DALI pipelines instead of PyTorch dataloaders

    Returns:
        model: a trained model
    """
    # check our inputs
    assert isinstance(model, nn.Module)
    assert isinstance(criterion, nn.Module)
    assert isinstance(optimizer, torch.optim.Optimizer)

    # loop over number of epochs!
    for epoch in trange(0, stopper.num_epochs):
        # a 'plateau' scheduler steps on the validation key metric once it is
        # available; all other schedulers simply step once per epoch
        if scheduler.name == 'plateau' and epoch > 0:
            if hasattr(metrics, 'latest_key') and 'val' in metrics.latest_key:
                scheduler.step(metrics.latest_key['val'])
        elif epoch > 0:
            scheduler.step()
        # read the smallest current learning rate across parameter groups
        min_lr = utils.get_minimum_learning_rate(optimizer)
        # store the learning rate for this epoch in our metrics file
        metrics.update_lr(min_lr)

        # loop over our training set!
        metrics, _ = loop_one_epoch(dataloaders['train'],
                                    model,
                                    criterion,
                                    optimizer,
                                    metrics,
                                    final_activation,
                                    steps_per_epoch['train'],
                                    train_mode=True,
                                    device=device,
                                    dali=dali)

        # evaluate on validation set
        with torch.no_grad():
            metrics, examples = loop_one_epoch(dataloaders['val'],
                                               model,
                                               criterion,
                                               optimizer,
                                               metrics,
                                               final_activation,
                                               steps_per_epoch['val'],
                                               train_mode=False,
                                               sequence=sequence,
                                               device=device,
                                               normalizer=normalizer,
                                               dali=dali)

            # some training protocols do not have a test set, so reuse the
            # validation set for benchmarking inference speed
            key = 'test' if 'test' in dataloaders else 'val'
            loader = dataloaders[key]
            # measure inference speed without the loss computation, which for
            # some models adds significant overhead
            metrics = speedtest(loader,
                                model,
                                metrics,
                                steps_per_epoch['test'],
                                device=device,
                                dali=dali)

        # use our metrics file to output graphs for this epoch
        viz.visualize_logger(metrics.fname,
                             examples if len(examples) > 0 else None)

        # save a checkpoint
        utils.checkpoint(model, rundir, epoch)
        # pass the latest validation metric (or learning rate) to the stopper
        if stopper.name == 'early':
            should_stop, _ = stopper(metrics.latest_key['val'])
        elif stopper.name == 'learning_rate':
            should_stop = stopper(min_lr)
        else:
            raise ValueError(f'unknown stopper type: {stopper.name}')

        if should_stop:
            log.info('Stopping criterion reached!')
            break

    return model
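
Note that train() duck-types its scheduler and stopper rather than requiring concrete classes: the scheduler needs a .name attribute, and the stopper needs .name, .num_epochs, and to be callable. Below is a minimal sketch of conforming objects built from stock PyTorch parts; the names are illustrative only, and the project ships its own implementations (see schedulers.py and stoppers.py):

import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# ReduceLROnPlateau has no .name attribute, so attach one for the branch
# at the top of the epoch loop
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3)
scheduler.name = 'plateau'

class EarlyStopper:
    """Minimal early stopper: halts after `patience` epochs without improvement."""
    name = 'early'

    def __init__(self, num_epochs: int = 100, patience: int = 5):
        self.num_epochs = num_epochs  # read by trange() in the loop above
        self.patience = patience
        self.best = float('inf')
        self.bad_epochs = 0

    def __call__(self, val_metric: float):
        # train() unpacks two return values: (should_stop, improved)
        improved = val_metric < self.best
        if improved:
            self.best, self.bad_epochs = val_metric, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience, improved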
Example #2
def train(model: nn.Module,
          dataloaders: dict,
          criterion,
          optimizer,
          metrics,
          scheduler,
          reconstructor,
          rundir: Union[str, bytes, os.PathLike],
          stopper,
          device: torch.device,
          num_epochs: int = 1000,
          steps_per_epoch: int = 1000,
          steps_per_validation_epoch: int = 1000,
          steps_per_test_epoch: int = 100,
          early_stopping_begins: int = 0,
          max_flow: float = 2.5,
          dali: bool = False,
          fp16: bool = False):
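    """Train optic flow generator models.

    Mirrors the feature extractor training loop above, but takes a
    reconstructor used to compute losses from predicted flows, and supports
    mixed-precision (fp16) training via a GradScaler.
    """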
    # check our inputs
    assert isinstance(model, nn.Module)
    assert isinstance(criterion, nn.Module)
    assert isinstance(optimizer, torch.optim.Optimizer)

    scaler = None
    if fp16:
        scaler = GradScaler()
    # loop over number of epochs!
    for epoch in trange(0, num_epochs):
        # a 'plateau' scheduler steps on the validation key metric once it is
        # available; all other schedulers simply step once per epoch
        if scheduler.name == 'plateau' and epoch > 0:
            if hasattr(metrics, 'latest_key') and 'val' in metrics.latest_key:
                scheduler.step(metrics.latest_key['val'])
        elif epoch > 0:
            scheduler.step()
        # read the smallest current learning rate across parameter groups
        min_lr = utils.get_minimum_learning_rate(optimizer)
        # store the learning rate for this epoch in our metrics file
        metrics.update_lr(min_lr)

        # loop over our training set!
        model, metrics, _ = loop_one_epoch(dataloaders['train'],
                                           model,
                                           criterion,
                                           optimizer,
                                           metrics,
                                           reconstructor,
                                           steps_per_epoch,
                                           train_mode=True,
                                           device=device,
                                           dali=dali,
                                           fp16=fp16,
                                           scaler=scaler)

        # evaluate on validation set
        with torch.no_grad():
            model, metrics, examples = loop_one_epoch(
                dataloaders['val'],
                model,
                criterion,
                optimizer,
                metrics,
                reconstructor,
                steps_per_validation_epoch,
                train_mode=False,
                device=device,
                max_flow=max_flow,
                dali=dali,
                fp16=fp16,
                scaler=scaler)

            # some training protocols do not have a test set, so reuse the
            # validation set for benchmarking inference speed
            key = 'test' if 'test' in dataloaders else 'val'
            loader = dataloaders[key]
            # measure inference speed without the loss computation, which for
            # some models adds significant overhead
            metrics = speedtest(loader,
                                model,
                                metrics,
                                steps_per_test_epoch,
                                device=device,
                                dali=dali,
                                fp16=fp16)

        # use our metrics file to output graphs for this epoch
        viz.visualize_logger(metrics.fname, examples)

        # save a checkpoint
        utils.checkpoint(model, rundir, epoch)

        # pass the latest validation metric (or learning rate) to the stopper
        if stopper.name == 'early':
            should_stop, _ = stopper(metrics.latest_key['val'])
        elif stopper.name == 'learning_rate':
            should_stop = stopper(min_lr)
        else:
            # any other stopper type just counts epochs; step it once per epoch
            should_stop = stopper()

        if should_stop:
            log.info('Stopping criterion reached!')
            break
    return model
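
loop_one_epoch is not shown in this excerpt; when fp16=True, the GradScaler created above is typically consumed inside the training step. Below is a sketch of the standard torch.cuda.amp pattern (fp16_train_step is a hypothetical name, not the project's actual implementation):

from torch.cuda.amp import autocast

def fp16_train_step(model, batch, labels, criterion, optimizer, scaler):
    optimizer.zero_grad()
    with autocast():                   # forward pass in mixed precision
        outputs = model(batch)
        loss = criterion(outputs, labels)
    scaler.scale(loss).backward()      # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)             # unscales gradients, skips step on inf/nan
    scaler.update()                    # adjust the scale factor for the next step
    return loss.item()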
Example #3
    def on_train_epoch_end(self, trainer, pl_module, outputs):
        # at the end of each training epoch, log the smallest learning rate
        # across parameter groups into the metrics buffer
        pl_module.metrics.buffer.append(
            'train',
            {'lr': utils.get_minimum_learning_rate(pl_module.optimizer)})
        _ = log_metrics(pl_module, 'train')
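
This hook is a PyTorch Lightning Callback method; the outputs argument matches the hook's signature in older Lightning releases (later versions drop it). A minimal registration sketch follows; the class name MetricsCallback and the lightning_module variable are hypothetical, and the LightningModule is assumed to expose .metrics and .optimizer as used above:

import pytorch_lightning as pl

class MetricsCallback(pl.Callback):
    # the on_train_epoch_end hook shown above goes here
    ...

trainer = pl.Trainer(max_epochs=10, callbacks=[MetricsCallback()])
# trainer.fit(lightning_module)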