Ejemplo n.º 1
0
def train(model,
          dataloaders: dict,
          criterion,
          optimizer,
          metrics,
          scheduler,
          reconstructor,
          rundir: Union[str, bytes, os.PathLike],
          stopper,
          device: torch.device,
          num_epochs: int = 1000,
          steps_per_epoch: int = 1000,
          steps_per_validation_epoch: int = 1000,
          steps_per_test_epoch: int = 100,
          early_stopping_begins: int = 0,
          max_flow: float = 2.5,
          dali: bool = False,
          fp16: bool = False):
    """ Train a model with per-epoch train/validation loops, speed testing, and checkpointing.

    Each epoch: step the LR scheduler, run one training epoch, run one validation epoch
    (under torch.no_grad), time inference on the test set (falling back to the val set if
    no test set exists), write visualization graphs, save a checkpoint, and query the
    stopping criterion.

    Args:
        model (nn.Module): model to train
        dataloaders (dict): must contain 'train' and 'val' dataloaders; 'test' is optional
            and used only for the inference speed test
        criterion (nn.Module): loss function
        optimizer (torch.optim.Optimizer): optimizer
        metrics: metrics object exposing .update_lr, .fname, and (after a validation
            epoch) .latest_key
        scheduler: LR scheduler with a .name attribute; 'plateau' schedulers are stepped
            with the latest validation key metric, all others are stepped unconditionally
        reconstructor: forwarded to loop_one_epoch (semantics defined there; presumably
            reconstructs frames from predicted flows for the loss -- confirm there)
        rundir (str, bytes, os.PathLike): directory in which per-epoch checkpoints are written
        stopper: stopping criterion; .name selects how it is fed ('early' receives the latest
            validation key metric, 'learning_rate' receives the current minimum LR, anything
            else is simply called once per epoch)
        device (torch.device): device to run on
        num_epochs (int): maximum number of epochs
        steps_per_epoch (int): batches per training epoch
        steps_per_validation_epoch (int): batches per validation epoch
        steps_per_test_epoch (int): batches used for the inference speed test
        early_stopping_begins (int): unused in this function body
        max_flow (float): forwarded to the validation loop only (semantics defined in
            loop_one_epoch -- presumably a flow magnitude cap; confirm there)
        dali (bool): if True, dataloaders are NVIDIA DALI pipelines
        fp16 (bool): if True, train with mixed precision using a GradScaler

    Returns:
        model: the trained model (weights are also saved every epoch via utils.checkpoint)
    """
    # check our inputs
    assert (isinstance(model, nn.Module))
    assert (isinstance(criterion, nn.Module))
    assert (isinstance(optimizer, torch.optim.Optimizer))

    # GradScaler performs loss scaling for mixed-precision training; None disables it
    scaler = None
    if fp16:
        scaler = GradScaler()
    # loop over number of epochs!
    for epoch in trange(0, num_epochs):
        # if our learning rate scheduler plateaus when validation metric saturates, we have to pass our "key metric" for
        # our validation set. Else, just step every epoch.
        # Skipped at epoch 0 because no validation metrics exist yet.
        if scheduler.name == 'plateau' and epoch > 0:
            if hasattr(metrics, 'latest_key'):
                if 'val' in list(metrics.latest_key.keys()):
                    scheduler.step(metrics.latest_key['val'])
        elif epoch > 0:
            scheduler.step()
        # update the learning rate for this epoch
        min_lr = utils.get_minimum_learning_rate(optimizer)
        # store the learning rate for this epoch in our metrics file
        metrics.update_lr(min_lr)

        # loop over our training set!
        model, metrics, _ = loop_one_epoch(dataloaders['train'],
                                           model,
                                           criterion,
                                           optimizer,
                                           metrics,
                                           reconstructor,
                                           steps_per_epoch,
                                           train_mode=True,
                                           device=device,
                                           dali=dali,
                                           fp16=fp16,
                                           scaler=scaler)

        # evaluate on validation set
        with torch.no_grad():
            model, metrics, examples = loop_one_epoch(
                dataloaders['val'],
                model,
                criterion,
                optimizer,
                metrics,
                reconstructor,
                steps_per_validation_epoch,
                train_mode=False,
                device=device,
                max_flow=max_flow,
                dali=dali,
                fp16=fp16,
                scaler=scaler)

            # some training protocols do not have test sets, so just reuse validation set for testing inference speed
            key = 'test' if 'test' in dataloaders.keys() else 'val'
            loader = dataloaders[key]
            # evaluate how fast inference takes, without loss calculation, which for some models can have a significant
            # speed impact
            metrics = speedtest(loader,
                                model,
                                metrics,
                                steps_per_test_epoch,
                                device=device,
                                dali=dali,
                                fp16=fp16)

        # use our metrics file to output graphs for this epoch
        viz.visualize_logger(metrics.fname, examples)

        # save a checkpoint
        utils.checkpoint(model, rundir, epoch)

        # input the latest validation loss to the early stopper
        if stopper.name == 'early':
            should_stop, _ = stopper(metrics.latest_key['val'])
        elif stopper.name == 'learning_rate':
            should_stop = stopper(min_lr)
        else:
            # every epoch, increment stopper
            should_stop = stopper()

        if should_stop:
            log.info('Stopping criterion reached!')
            break
    return model
Ejemplo n.º 2
0
def train_from_cfg(cfg: DictConfig) -> Type[nn.Module]:
    """ train DeepEthogram feature extractors from a configuration object.

    Optionally trains in a curriculum: the spatial stream first, then the flow stream
    (with a frozen flow generator), and finally the fused two-stream model end to end.

    Args:
        cfg (DictConfig): configuration object generated by Hydra

    Returns:
        trained feature extractor (a HiddenTwoStream model)
    """
    rundir = os.getcwd()  # done by hydra

    device = torch.device(
        "cuda:" +
        str(cfg.compute.gpu_id) if torch.cuda.is_available() else "cpu")
    # FIX: compare the device *type* to 'cpu' instead of comparing the torch.device object
    # to a string. torch.device != str comparison is version-dependent and could evaluate
    # unequal for a CPU device, which would then call torch.cuda.set_device on the CPU
    # device and raise. device.type is always the plain string 'cpu' or 'cuda'.
    if device.type != 'cpu':
        torch.cuda.set_device(device)

    flow_generator = build_flow_generator(cfg)
    flow_weights = get_weightfile_from_cfg(cfg, 'flow_generator')
    assert flow_weights is not None, (
        'Must have a valid weightfile for flow generator. Use '
        'deepethogram.flow_generator.train or cfg.reload.latest')
    log.info('loading flow generator from file {}'.format(flow_weights))

    flow_generator = utils.load_weights(flow_generator,
                                        flow_weights,
                                        device=device)
    flow_generator = flow_generator.to(device)

    dataloaders = get_dataloaders_from_cfg(
        cfg,
        model_type='feature_extractor',
        input_images=cfg.feature_extractor.n_flows + 1)

    spatial_classifier, flow_classifier = build_model_from_cfg(
        cfg,
        return_components=True,
        pos=dataloaders['pos'],
        neg=dataloaders['neg'])
    spatial_classifier = spatial_classifier.to(device)

    flow_classifier = flow_classifier.to(device)
    num_classes = len(cfg.project.class_names)

    # persist the train/val/test split so the run is reproducible
    utils.save_dict_to_yaml(dataloaders['split'],
                            os.path.join(rundir, 'split.yaml'))

    criterion = get_criterion(cfg.feature_extractor.final_activation,
                              dataloaders, device)
    steps_per_epoch = dict(cfg.train.steps_per_epoch)
    metrics = get_metrics(
        rundir,
        num_classes=num_classes,
        num_parameters=utils.get_num_parameters(spatial_classifier))

    dali = cfg.compute.dali

    # training in a curriculum goes as follows:
    # first, we train the spatial classifier, which takes still images as input
    # second, we train the flow classifier, which generates optic flow with the flow_generator model and then classifies
    # it. Thirdly, we will train the whole thing end to end
    # Without the curriculum we just train end to end from the start
    if cfg.feature_extractor.curriculum:
        del dataloaders
        # train spatial model, then flow model, then both end-to-end
        dataloaders = get_dataloaders_from_cfg(
            cfg,
            model_type='feature_extractor',
            input_images=cfg.feature_extractor.n_rgb)
        log.info('Num training batches {}, num val: {}'.format(
            len(dataloaders['train']), len(dataloaders['val'])))
        # we'll use this to visualize our data, because it is loaded z-scored. we want it to be in the range [0-1] or
        # [0-255] for visualization, and for that we need to know mean and std
        normalizer = get_normalizer(cfg,
                                    input_images=cfg.feature_extractor.n_rgb)

        # only optimize parameters that require gradients (frozen layers are skipped)
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      spatial_classifier.parameters()),
                               lr=cfg.train.lr,
                               weight_decay=cfg.feature_extractor.weight_decay)

        spatialdir = os.path.join(rundir, 'spatial')
        if not os.path.isdir(spatialdir):
            os.makedirs(spatialdir)
        stopper = get_stopper(cfg)
        # we're using validation loss as our key metric
        scheduler = initialize_scheduler(
            optimizer,
            cfg,
            mode='min',
            reduction_factor=cfg.train.reduction_factor)

        log.info('key metric: {}'.format(metrics.key_metric))
        spatial_classifier = train(
            spatial_classifier,
            dataloaders,
            criterion,
            optimizer,
            metrics,
            scheduler,
            spatialdir,
            stopper,
            device,
            steps_per_epoch,
            final_activation=cfg.feature_extractor.final_activation,
            sequence=False,
            normalizer=normalizer,
            dali=dali)

        log.info('Training flow stream....')
        input_images = cfg.feature_extractor.n_flows + 1
        del dataloaders
        dataloaders = get_dataloaders_from_cfg(cfg,
                                               model_type='feature_extractor',
                                               input_images=input_images)

        normalizer = get_normalizer(cfg, input_images=input_images)
        log.info('Num training batches {}, num val: {}'.format(
            len(dataloaders['train']), len(dataloaders['val'])))
        flowdir = os.path.join(rundir, 'flow')
        if not os.path.isdir(flowdir):
            os.makedirs(flowdir)

        # wrap the flow generator and flow classifier so the training loop sees a single
        # model; only the classifier's parameters are handed to the optimizer below
        flow_generator_and_classifier = FlowOnlyClassifier(
            flow_generator, flow_classifier).to(device)
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      flow_classifier.parameters()),
                               lr=cfg.train.lr,
                               weight_decay=cfg.feature_extractor.weight_decay)

        stopper = get_stopper(cfg)
        # we're using validation loss as our key metric
        scheduler = initialize_scheduler(
            optimizer,
            cfg,
            mode='min',
            reduction_factor=cfg.train.reduction_factor)
        flow_generator_and_classifier = train(
            flow_generator_and_classifier,
            dataloaders,
            criterion,
            optimizer,
            metrics,
            scheduler,
            flowdir,
            stopper,
            device,
            steps_per_epoch,
            final_activation=cfg.feature_extractor.final_activation,
            sequence=False,
            normalizer=normalizer,
            dali=dali)
        flow_classifier = flow_generator_and_classifier.flow_classifier
        # overwrite checkpoint
        utils.checkpoint(flow_classifier, flowdir, stopper.epoch_counter)

    model = HiddenTwoStream(flow_generator,
                            spatial_classifier,
                            flow_classifier,
                            cfg.feature_extractor.arch,
                            fusion_style=cfg.feature_extractor.fusion,
                            num_classes=num_classes).to(device)
    # setting the mode to end-to-end would allow to backprop gradients into the flow generator itself
    # the paper does this, but I don't expect that users would have enough data for this to make sense
    model.set_mode('classifier')
    log.info('Training end to end...')
    input_images = cfg.feature_extractor.n_flows + 1
    dataloaders = get_dataloaders_from_cfg(cfg,
                                           model_type='feature_extractor',
                                           input_images=input_images)
    normalizer = get_normalizer(cfg, input_images=input_images)

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=cfg.train.lr,
                           weight_decay=cfg.feature_extractor.weight_decay)
    stopper = get_stopper(cfg)
    # we're using validation loss as our key metric
    scheduler = initialize_scheduler(
        optimizer,
        cfg,
        mode='min',
        reduction_factor=cfg.train.reduction_factor)
    log.info('Total trainable params: {:,}'.format(
        utils.get_num_parameters(model)))
    model = train(model,
                  dataloaders,
                  criterion,
                  optimizer,
                  metrics,
                  scheduler,
                  rundir,
                  stopper,
                  device,
                  steps_per_epoch,
                  final_activation=cfg.feature_extractor.final_activation,
                  sequence=False,
                  normalizer=normalizer,
                  dali=dali)
    utils.save_hidden_two_stream(model, rundir, dict(cfg),
                                 stopper.epoch_counter)
    return model
Ejemplo n.º 3
0
def train(model: Type[nn.Module],
          dataloaders: dict,
          criterion,
          optimizer,
          metrics,
          scheduler,
          rundir: Union[str, bytes, os.PathLike],
          stopper,
          device: torch.device,
          steps_per_epoch: dict,
          final_activation: str = 'sigmoid',
          sequence: bool = False,
          class_names: list = None,
          normalizer=None,
          dali: bool = False):
    """ Train feature extractor models

    Args:
        model (nn.Module): feature extractor (can also be a component, like the spatial stream or flow stream)
        dataloaders (dict): dictionary with PyTorch dataloader objects (see dataloaders.py)
        criterion (nn.Module): loss function
        optimizer (torch.optim): optimizer (SGD, SGDM, ADAM, etc)
        metrics (Metrics): metrics object for computing metrics and saving to disk (see metrics.py)
        scheduler (_LRScheduler): learning rate scheduler (see schedulers.py)
        rundir (str, os.PathLike): run directory for saving weights
        stopper (Stopper): object that stops training (see stoppers.py)
        device (str, torch.device): gpu device
        steps_per_epoch (dict): keys ['train', 'val', 'test']: number of steps in each "epoch"
        final_activation (str): either sigmoid or softmax
        sequence (bool): if True, assumes sequence inputs of shape N,K,T
        class_names (list): unused
        normalizer (Normalizer): normalizer object, used for un-zscoring images for visualization purposes
        dali (bool): if True, dataloaders are NVIDIA DALI pipelines

    Returns:
        model: a trained model
    """
    # check our inputs
    assert (isinstance(model, nn.Module))
    assert (isinstance(criterion, nn.Module))
    assert (isinstance(optimizer, torch.optim.Optimizer))

    # loop over number of epochs!
    for epoch in trange(0, stopper.num_epochs):
        # if our learning rate scheduler plateaus when validation metric saturates, we have to pass our "key metric" for
        # our validation set. Else, just step every epoch.
        # Skipped at epoch 0 because no validation metrics exist yet.
        if scheduler.name == 'plateau' and epoch > 0:
            if hasattr(metrics, 'latest_key'):
                if 'val' in list(metrics.latest_key.keys()):
                    scheduler.step(metrics.latest_key['val'])
        elif epoch > 0:
            scheduler.step()
        # update the learning rate for this epoch
        min_lr = utils.get_minimum_learning_rate(optimizer)
        # store the learning rate for this epoch in our metrics file
        metrics.update_lr(min_lr)

        # loop over our training set!
        metrics, _ = loop_one_epoch(dataloaders['train'],
                                    model,
                                    criterion,
                                    optimizer,
                                    metrics,
                                    final_activation,
                                    steps_per_epoch['train'],
                                    train_mode=True,
                                    device=device,
                                    dali=dali)

        # evaluate on validation set
        with torch.no_grad():
            metrics, examples = loop_one_epoch(dataloaders['val'],
                                               model,
                                               criterion,
                                               optimizer,
                                               metrics,
                                               final_activation,
                                               steps_per_epoch['val'],
                                               train_mode=False,
                                               sequence=sequence,
                                               device=device,
                                               normalizer=normalizer,
                                               dali=dali)

            # some training protocols do not have test sets, so just reuse validation set for testing inference speed
            key = 'test' if 'test' in dataloaders.keys() else 'val'
            loader = dataloaders[key]
            # evaluate how fast inference takes, without loss calculation, which for some models can have a significant
            # speed impact
            metrics = speedtest(loader,
                                model,
                                metrics,
                                steps_per_epoch['test'],
                                device=device,
                                dali=dali)

        # use our metrics file to output graphs for this epoch
        # (pass None instead of an empty example list so viz can skip example panels)
        viz.visualize_logger(metrics.fname,
                             examples if len(examples) > 0 else None)

        # save a checkpoint
        utils.checkpoint(model, rundir, epoch)

        # input the latest validation loss to the early stopper
        if stopper.name == 'early':
            should_stop, _ = stopper(metrics.latest_key['val'])
        elif stopper.name == 'learning_rate':
            should_stop = stopper(min_lr)
        else:
            # unlike the flow-generator train loop, an unrecognized stopper is an error here
            raise ValueError('Please select a stopping type')

        if should_stop:
            log.info('Stopping criterion reached!')
            break

    return model
Ejemplo n.º 4
0
 def checkpoint(self, pl_module):
     """Save a checkpoint of the wrapped model for the module's current epoch."""
     # Hydra sets the cwd to the run directory, so the checkpoint lands with the run's outputs.
     model = pl_module.model
     epoch = pl_module.current_epoch
     utils.checkpoint(model, os.getcwd(), epoch)