Exemple #1
0
def flow_generator_inference(cfg):
    """Run a trained flow generator over a set of videos and save the optic flow as movies.

    For every directory in ``cfg.inference.directory_list``, the RGB movie named in that directory's record is
    passed to ``extract_movie`` together with the loaded flow generator. The output is written next to the input
    movie as ``<movie without extension>_flows.mp4``; anything already present at that path is deleted first.

    Parameters
    ----------
    cfg : DictConfig
        OmegaConf configuration. Fields read here: ``inference.directory_list``,
        ``inference.use_loaded_model_cfg``, ``feature_extractor.arch``, ``feature_extractor.n_flows``,
        ``flow_generator``, ``augs``, ``compute.gpu_id``, ``project.data_path`` (only when
        ``directory_list == 'all'``) and, optionally, ``sequence.latent_name``.

    Raises
    ------
    ValueError
        If ``directory_list`` is None, empty, or of an unsupported type.
    AssertionError
        If an entry is not a directory, a record has no RGB movie, the flow generator weight file is missing,
        or ``feature_extractor.n_flows + 1 != flow_generator.n_rgb``.
    """
    # presumably resolves project paths (e.g. "models" -> "full/path/to/models"); see projects.setup_run
    cfg = projects.setup_run(cfg)
    log.info('args: {}'.format(' '.join(sys.argv)))
    log.info('configuration used in inference: ')
    log.info(OmegaConf.to_yaml(cfg))

    # fall back to the feature extractor architecture when no latent name is configured
    has_latent_name = ('sequence' in cfg.keys()
                       and 'latent_name' in cfg.sequence.keys()
                       and cfg.sequence.latent_name is not None)
    latent_name = cfg.sequence.latent_name if has_latent_name else cfg.feature_extractor.arch
    log.info('Latent name used in HDF5 file: {}'.format(latent_name))

    # figure out which videos to run inference on
    directory_list = cfg.inference.directory_list
    if directory_list is None or len(directory_list) == 0:
        raise ValueError('must pass list of directories from command line. '
                         'Ex: directory_list=[path_to_dir1,path_to_dir2]')
    if isinstance(directory_list, str):
        if directory_list == 'all':
            # the special value 'all' selects every sub-directory of the project's data path
            directory_list = utils.get_subfiles(cfg.project.data_path, 'directory')
        else:
            directory_list = [directory_list]
    elif isinstance(directory_list, ListConfig):
        directory_list = OmegaConf.to_container(directory_list)
    elif not isinstance(directory_list, list):
        raise ValueError(
            'unknown value for directory list: {}'.format(directory_list))

    # video files are found through the record of each input directory
    rgb = []
    for directory in directory_list:
        assert os.path.isdir(directory), 'Not a directory: {}'.format(
            directory)
        record = projects.get_record_from_subdir(directory)
        assert record['rgb'] is not None, 'No rgb movie in record for directory: {}'.format(directory)
        rgb.append(record['rgb'])

    assert cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb, 'Flow generator inputs must be one greater ' \
                                                                          'than feature extractor num flows '
    # set up gpu augmentation
    input_images = cfg.feature_extractor.n_flows + 1
    mode = '3d' if '3d' in cfg.feature_extractor.arch.lower() else '2d'
    # get the validation transforms. should have resizing, etc
    cpu_transform = get_cpu_transforms(cfg.augs)['val']
    gpu_transform = get_gpu_transforms(cfg.augs, mode)
    log.info('gpu_transform: {}'.format(gpu_transform))

    flow_generator_weights = projects.get_weightfile_from_cfg(
        cfg, 'flow_generator')
    # check for None explicitly: os.path.isfile(None) would raise an uninformative TypeError
    assert flow_generator_weights is not None, 'No weight file found for flow generator'
    assert os.path.isfile(flow_generator_weights), 'Not a file: {}'.format(flow_generator_weights)
    run_files = get_run_files_from_weights(flow_generator_weights,
                                           'opticalflow')
    if cfg.inference.use_loaded_model_cfg:
        # values from the configuration saved with the trained model take precedence over the current ones
        loaded_cfg = OmegaConf.load(run_files['config_file'])
        cfg.flow_generator = OmegaConf.merge(cfg.flow_generator, loaded_cfg.flow_generator)
        # we don't want to use the weights that the trained model was initialized with, but the weights after
        # training. therefore, overwrite the loaded configuration with the current weights
        cfg.flow_generator.weights = flow_generator_weights
    log.info('model loaded')
    model = build_flow_generator(cfg)
    # weights are loaded on the cpu first and the model is moved to the configured gpu afterwards
    model = utils.load_weights(model, flow_generator_weights, device='cpu')
    device = 'cuda:{}'.format(cfg.compute.gpu_id)
    model = model.to(device)

    # output settings are fixed here and forwarded unchanged to extract_movie
    movie_format = 'ffmpeg'
    maxval = 5
    polar = True
    save_rgb_side_by_side = True
    # 'directory' output gets no extension; any unlisted format falls back to .avi
    extension = {'directory': '', 'hdf5': '.h5', 'ffmpeg': '.mp4'}.get(movie_format, '.avi')
    for movie in tqdm(rgb):
        out_video = os.path.splitext(movie)[0] + '_flows' + extension
        # remove stale output from a previous run
        if os.path.isdir(out_video):
            shutil.rmtree(out_video)
        elif os.path.isfile(out_video):
            os.remove(out_video)

        extract_movie(movie,
                      out_video,
                      model,
                      device,
                      cpu_transform,
                      gpu_transform,
                      mean_by_channels=cfg.augs.normalization.mean,
                      num_workers=1,
                      num_rgb=input_images,
                      maxval=maxval,
                      polar=polar,
                      movie_format=movie_format,
                      save_rgb_side_by_side=save_rgb_side_by_side)
Exemple #2
0
def build_model_from_cfg(
        cfg: DictConfig,
        return_components: bool = False,
        pos: np.ndarray = None,
        neg: np.ndarray = None) -> Union[Type[nn.Module], tuple]:
    """ Assemble the feature extractor described by a configuration object.

    Parameters
    ----------
    cfg: DictConfig
        configuration, e.g. from Hydra command line
    return_components: bool
        if True, only the spatial and flow classifiers are built and returned; the flow generator is not loaded
    pos: np.ndarray
        forwarded to `get_cnn`. Documented by the original author as the number of positive examples in the
        dataset, used for initializing biases in the final layer
    neg: np.ndarray
        forwarded to `get_cnn`. Documented by the original author as the number of negative examples in the
        dataset, used for initializing biases in the final layer

    Returns
    -------
    if `return_components`:
        spatial_classifier, flow_classifier: nn.Module, nn.Module
            cnns for classifying rgb images and optic flows
    else:
        hidden two stream model: nn.Module
            HiddenTwoStream model, set to 'classifier' mode
    """
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(cfg.compute.gpu_id))
    else:
        device = torch.device('cpu')

    weightfile = get_weightfile_from_cfg(cfg, 'feature_extractor')
    n_classes = len(cfg.project.class_names)
    is_3d = '3d' in cfg.feature_extractor.arch
    # without a feature extractor checkpoint, get_cnn is asked to reload imagenet weights instead
    from_imagenet = weightfile is None

    def make_stream(n_channels, component):
        """Build one CNN stream and, if a checkpoint exists, load the named component into it."""
        stream = get_cnn(cfg.feature_extractor.arch,
                         in_channels=n_channels,
                         dropout_p=cfg.feature_extractor.dropout_p,
                         num_classes=n_classes,
                         reload_imagenet=from_imagenet,
                         pos=pos,
                         neg=neg)
        if weightfile is None:
            return stream
        return utils.load_feature_extractor_components(stream,
                                                       weightfile,
                                                       component,
                                                       device=device)

    # 2d architectures stack the rgb frames along the channel axis: 3 channels per image
    rgb_channels = 3 if is_3d else cfg.feature_extractor.n_rgb * 3
    if cfg.feature_extractor.arch == 'resnet3d_34':
        assert weightfile is not None, 'Must specify path to resnet3d weights!'
    spatial_classifier = make_stream(rgb_channels, 'spatial')

    # each optic flow contributes 2 channels
    flow_channels = 2 if is_3d else cfg.feature_extractor.n_flows * 2
    flow_classifier = make_stream(flow_channels, 'flow')

    if return_components:
        return spatial_classifier, flow_classifier

    flow_generator = build_flow_generator(cfg)
    flow_weights = get_weightfile_from_cfg(cfg, 'flow_generator')
    assert flow_weights is not None, (
        'Must have a valid weightfile for flow generator. Use '
        'deepethogram.flow_generator.train or cfg.reload.latest')
    flow_generator = utils.load_weights(flow_generator,
                                        flow_weights,
                                        device=device)

    two_stream = HiddenTwoStream(flow_generator,
                                 spatial_classifier,
                                 flow_classifier,
                                 cfg.feature_extractor.arch,
                                 fusion_style=cfg.feature_extractor.fusion,
                                 num_classes=n_classes)
    two_stream.set_mode('classifier')
    return two_stream
Exemple #3
0
def feature_extractor_train(cfg: DictConfig) -> nn.Module:
    """Trains feature extractor models from a configuration.

    Loads the flow generator, builds the spatial and flow classifiers, and fits them with trainers obtained from
    `get_trainer_from_cfg`. If `cfg.feature_extractor.curriculum` is set, training happens in three stages
    (spatial classifier, then flow classifier, then the full HiddenTwoStream model); otherwise only the last
    stage runs. The dataset split is saved to `split.yaml` in `cfg.run.dir`.

    Parameters
    ----------
    cfg : DictConfig
        Configuration, e.g. that returned by deepethogram.configuration.make_feature_extractor_train_cfg

    Returns
    -------
    nn.Module
        The HiddenTwoStream model that was fit in the final stage

    Raises
    ------
    AssertionError
        If no weight file for the flow generator can be found
    """
    # presumably changes the project paths from relative to absolute -- see projects.setup_run
    cfg = projects.setup_run(cfg)

    log.info('args: {}'.format(' '.join(sys.argv)))
    # disable struct mode so that the configuration can be edited
    OmegaConf.set_struct(cfg, False)
    # SHOULD NEVER MODIFY / MAKE ASSIGNMENTS TO THE CFG OBJECT AFTER RIGHT HERE!
    # NOTE(review): cfg.compute.batch_size and cfg.train.lr are nevertheless re-assigned further down
    log.info('configuration used ~~~~~')
    log.info(OmegaConf.to_yaml(cfg))

    # we build flow generator independently because you might want to load it from a different location
    flow_generator = build_flow_generator(cfg)
    flow_weights = projects.get_weightfile_from_cfg(cfg, 'flow_generator')
    assert flow_weights is not None, (
        'Must have a valid weightfile for flow generator. Use '
        'deepethogram.flow_generator.train or cfg.reload.latest')
    log.info('loading flow generator from file {}'.format(flow_weights))

    flow_generator = utils.load_weights(flow_generator, flow_weights)

    # only data_info is needed here (its 'pos', 'neg' and 'split' entries); the datasets are discarded
    _, data_info = get_datasets_from_cfg(
        cfg,
        model_type='feature_extractor',
        input_images=cfg.feature_extractor.n_flows + 1)

    model_parts = build_model_from_cfg(cfg,
                                       pos=data_info['pos'],
                                       neg=data_info['neg'])
    # the flow generator returned here is ignored in favor of the one loaded above; `model` is rebuilt below
    _, spatial_classifier, flow_classifier, fusion, model = model_parts

    num_classes = len(cfg.project.class_names)

    utils.save_dict_to_yaml(data_info['split'],
                            os.path.join(cfg.run.dir, 'split.yaml'))

    # the same metrics object is shared by every training stage
    metrics = get_metrics(
        cfg.run.dir,
        num_classes=num_classes,
        num_parameters=utils.get_num_parameters(spatial_classifier),
        key_metric='f1_class_mean_nobg',
        num_workers=cfg.compute.metrics_workers)

    # cfg.compute.batch_size will be changed by the automatic batch size finder, possibly. store here so that
    # with each step of the curriculum, we can auto-tune it
    original_batch_size = cfg.compute.batch_size
    original_lr = cfg.train.lr

    # training in a curriculum goes as follows:
    # first, we train the spatial classifier, which takes still images as input
    # second, we train the flow classifier, which generates optic flow with the flow_generator model and then classifies
    # it. Thirdly, we will train the whole thing end to end
    # Without the curriculum we just train end to end from the start
    if cfg.feature_extractor.curriculum:
        # stage 1: spatial classifier, which only needs n_rgb input images
        datasets, data_info = get_datasets_from_cfg(
            cfg,
            model_type='feature_extractor',
            input_images=cfg.feature_extractor.n_rgb)
        stopper = get_stopper(cfg)

        criterion = get_criterion(cfg, spatial_classifier, data_info)

        lightning_module = HiddenTwoStreamLightning(spatial_classifier, cfg,
                                                    datasets, metrics,
                                                    criterion)
        # per the original author, creating the trainer may change the configuration's batch size and learning
        # rate if they are set to auto
        # https://pytorch-lightning.readthedocs.io/en/latest/lr_finder.html?highlight=auto%20scale%20learning%20rate
        # re-creating the lightning module afterwards was tried and abandoned (kept commented out for reference):
        # lightning_module = HiddenTwoStreamLightning(spatial_classifier, cfg, datasets, metrics, criterion)
        trainer = get_trainer_from_cfg(cfg, lightning_module, stopper)
        trainer.fit(lightning_module)

        # free RAM. note: this doesn't do much
        log.info('free ram')
        del datasets, lightning_module, trainer, stopper, data_info
        torch.cuda.empty_cache()
        gc.collect()

        # stage 2: flow classifier, which needs n_flows + 1 input images
        datasets, data_info = get_datasets_from_cfg(
            cfg,
            model_type='feature_extractor',
            input_images=cfg.feature_extractor.n_flows + 1)
        # re-initialize stopper so that it doesn't think we need to stop due to the previous model
        stopper = get_stopper(cfg)
        # restore the values stored before stage 1, so that they can be auto-tuned again for this stage
        cfg.compute.batch_size = original_batch_size
        cfg.train.lr = original_lr

        # this class will freeze the flow generator
        flow_generator_and_classifier = FlowOnlyClassifier(
            flow_generator, flow_classifier)
        criterion = get_criterion(cfg, flow_generator_and_classifier,
                                  data_info)
        lightning_module = HiddenTwoStreamLightning(
            flow_generator_and_classifier, cfg, datasets, metrics, criterion)
        trainer = get_trainer_from_cfg(cfg, lightning_module, stopper)
        trainer.fit(lightning_module)

        del datasets, lightning_module, trainer, stopper, data_info
        torch.cuda.empty_cache()
        gc.collect()

    torch.cuda.empty_cache()
    gc.collect()

    # final stage: the full two-stream model, with or without the curriculum
    model = HiddenTwoStream(flow_generator, spatial_classifier,
                            flow_classifier, fusion,
                            cfg.feature_extractor.arch)
    model.set_mode('classifier')
    datasets, data_info = get_datasets_from_cfg(
        cfg,
        model_type='feature_extractor',
        input_images=cfg.feature_extractor.n_flows + 1)
    criterion = get_criterion(cfg, model, data_info)
    stopper = get_stopper(cfg)
    # restore the values stored at the top before the trainer for this stage is created
    cfg.compute.batch_size = original_batch_size
    cfg.train.lr = original_lr

    # uncomment to debug NaNs / infs in the backward pass. will slow down training
    # torch.autograd.set_detect_anomaly(True)

    lightning_module = HiddenTwoStreamLightning(model, cfg, datasets, metrics,
                                                criterion)

    trainer = get_trainer_from_cfg(cfg, lightning_module, stopper)
    trainer.fit(lightning_module)
    # trainer.test(model=lightning_module)
    return model
Exemple #4
0
def train_from_cfg(cfg: DictConfig) -> nn.Module:
    """ Train DeepEthogram feature extractors from a configuration object.

    The flow generator is loaded from its weight file. If ``cfg.feature_extractor.curriculum`` is set, the
    spatial classifier is trained first (checkpoints in ``<rundir>/spatial``), then the flow classifier
    (``<rundir>/flow``), and finally the whole HiddenTwoStream model in 'classifier' mode. Without the
    curriculum only the final stage runs. The run directory is the current working directory.

    Args:
        cfg (DictConfig): configuration object generated by Hydra

    Returns:
        trained HiddenTwoStream feature extractor

    Raises:
        AssertionError: if no weight file for the flow generator can be found
    """
    rundir = os.getcwd()  # done by hydra

    device = torch.device(
        'cuda:' + str(cfg.compute.gpu_id) if torch.cuda.is_available() else 'cpu')
    # a torch.device never compares equal to a plain string, so `device != 'cpu'` is always True and would call
    # torch.cuda.set_device with a cpu device on machines without a gpu. Inspect the device type instead
    if device.type == 'cuda':
        torch.cuda.set_device(device)

    def new_optimizer(module):
        """Adam over the trainable parameters of `module`, with the configured lr and weight decay."""
        trainable = filter(lambda p: p.requires_grad, module.parameters())
        return optim.Adam(trainable,
                          lr=cfg.train.lr,
                          weight_decay=cfg.feature_extractor.weight_decay)

    def new_scheduler(optimizer):
        """Scheduler in 'min' mode: we're using validation loss as our key metric."""
        return initialize_scheduler(
            optimizer,
            cfg,
            mode='min',
            reduction_factor=cfg.train.reduction_factor)

    flow_generator = build_flow_generator(cfg)
    flow_weights = get_weightfile_from_cfg(cfg, 'flow_generator')
    assert flow_weights is not None, (
        'Must have a valid weightfile for flow generator. Use '
        'deepethogram.flow_generator.train or cfg.reload.latest')
    log.info('loading flow generator from file {}'.format(flow_weights))

    flow_generator = utils.load_weights(flow_generator,
                                        flow_weights,
                                        device=device)
    flow_generator = flow_generator.to(device)

    # these dataloaders supply the 'pos', 'neg' and 'split' entries; they are rebuilt before each training stage
    dataloaders = get_dataloaders_from_cfg(
        cfg,
        model_type='feature_extractor',
        input_images=cfg.feature_extractor.n_flows + 1)

    spatial_classifier, flow_classifier = build_model_from_cfg(
        cfg,
        return_components=True,
        pos=dataloaders['pos'],
        neg=dataloaders['neg'])
    spatial_classifier = spatial_classifier.to(device)
    flow_classifier = flow_classifier.to(device)
    num_classes = len(cfg.project.class_names)

    utils.save_dict_to_yaml(dataloaders['split'],
                            os.path.join(rundir, 'split.yaml'))

    # criterion, metrics and steps_per_epoch are shared by all training stages
    criterion = get_criterion(cfg.feature_extractor.final_activation,
                              dataloaders, device)
    steps_per_epoch = dict(cfg.train.steps_per_epoch)
    metrics = get_metrics(
        rundir,
        num_classes=num_classes,
        num_parameters=utils.get_num_parameters(spatial_classifier))

    dali = cfg.compute.dali

    # training in a curriculum goes as follows:
    # first, we train the spatial classifier, which takes still images as input
    # second, we train the flow classifier, which generates optic flow with the flow_generator model and then classifies
    # it. Thirdly, we will train the whole thing end to end
    # Without the curriculum we just train end to end from the start
    if cfg.feature_extractor.curriculum:
        # drop the old loaders before building new ones
        del dataloaders
        # stage 1: spatial classifier, which needs n_rgb input images
        dataloaders = get_dataloaders_from_cfg(
            cfg,
            model_type='feature_extractor',
            input_images=cfg.feature_extractor.n_rgb)
        log.info('Num training batches {}, num val: {}'.format(
            len(dataloaders['train']), len(dataloaders['val'])))
        # we'll use this to visualize our data, because it is loaded z-scored. we want it to be in the range [0-1] or
        # [0-255] for visualization, and for that we need to know mean and std
        normalizer = get_normalizer(cfg,
                                    input_images=cfg.feature_extractor.n_rgb)

        optimizer = new_optimizer(spatial_classifier)

        spatialdir = os.path.join(rundir, 'spatial')
        os.makedirs(spatialdir, exist_ok=True)
        stopper = get_stopper(cfg)
        scheduler = new_scheduler(optimizer)

        log.info('key metric: {}'.format(metrics.key_metric))
        spatial_classifier = train(
            spatial_classifier,
            dataloaders,
            criterion,
            optimizer,
            metrics,
            scheduler,
            spatialdir,
            stopper,
            device,
            steps_per_epoch,
            final_activation=cfg.feature_extractor.final_activation,
            sequence=False,
            normalizer=normalizer,
            dali=dali)

        # stage 2: flow classifier on top of the flow generator, which needs n_flows + 1 input images
        log.info('Training flow stream....')
        input_images = cfg.feature_extractor.n_flows + 1
        del dataloaders
        dataloaders = get_dataloaders_from_cfg(cfg,
                                               model_type='feature_extractor',
                                               input_images=input_images)

        normalizer = get_normalizer(cfg, input_images=input_images)
        log.info('Num training batches {}, num val: {}'.format(
            len(dataloaders['train']), len(dataloaders['val'])))
        flowdir = os.path.join(rundir, 'flow')
        os.makedirs(flowdir, exist_ok=True)

        flow_generator_and_classifier = FlowOnlyClassifier(
            flow_generator, flow_classifier).to(device)
        # only the flow classifier's parameters are handed to the optimizer, not the flow generator's
        optimizer = new_optimizer(flow_classifier)

        stopper = get_stopper(cfg)
        scheduler = new_scheduler(optimizer)
        flow_generator_and_classifier = train(
            flow_generator_and_classifier,
            dataloaders,
            criterion,
            optimizer,
            metrics,
            scheduler,
            flowdir,
            stopper,
            device,
            steps_per_epoch,
            final_activation=cfg.feature_extractor.final_activation,
            sequence=False,
            normalizer=normalizer,
            dali=dali)
        flow_classifier = flow_generator_and_classifier.flow_classifier
        # overwrite checkpoint
        utils.checkpoint(flow_classifier, flowdir, stopper.epoch_counter)

    model = HiddenTwoStream(flow_generator,
                            spatial_classifier,
                            flow_classifier,
                            cfg.feature_extractor.arch,
                            fusion_style=cfg.feature_extractor.fusion,
                            num_classes=num_classes).to(device)
    # setting the mode to end-to-end would allow to backprop gradients into the flow generator itself
    # the paper does this, but I don't expect that users would have enough data for this to make sense
    model.set_mode('classifier')
    log.info('Training end to end...')
    input_images = cfg.feature_extractor.n_flows + 1
    dataloaders = get_dataloaders_from_cfg(cfg,
                                           model_type='feature_extractor',
                                           input_images=input_images)
    normalizer = get_normalizer(cfg, input_images=input_images)

    optimizer = new_optimizer(model)
    stopper = get_stopper(cfg)
    scheduler = new_scheduler(optimizer)
    log.info('Total trainable params: {:,}'.format(
        utils.get_num_parameters(model)))
    model = train(model,
                  dataloaders,
                  criterion,
                  optimizer,
                  metrics,
                  scheduler,
                  rundir,
                  stopper,
                  device,
                  steps_per_epoch,
                  final_activation=cfg.feature_extractor.final_activation,
                  sequence=False,
                  normalizer=normalizer,
                  dali=dali)
    utils.save_hidden_two_stream(model, rundir, dict(cfg),
                                 stopper.epoch_counter)
    return model
Exemple #5
0
def build_model_from_cfg(cfg: DictConfig,
                         pos: np.ndarray = None,
                         neg: np.ndarray = None,
                         num_classes: int = None) -> tuple:
    """ Assemble every part of the feature extractor described by a configuration object.

    All weights are loaded onto the cpu.

    Parameters
    ----------
    cfg: DictConfig
        configuration, e.g. from Hydra command line
    pos: np.ndarray
        forwarded to `get_cnn`. Documented by the original author as the number of positive examples in the
        dataset, used for initializing biases in the final layer
    neg: np.ndarray
        forwarded to `get_cnn`. Documented by the original author as the number of negative examples in the
        dataset, used for initializing biases in the final layer
    num_classes: int
        number of output classes. If None, the length of `cfg.project.class_names` is used

    Returns
    -------
    flow_generator, spatial_classifier, flow_classifier, fusion, model: tuple
        the individual components, followed by the HiddenTwoStream model built from them and set to
        'classifier' mode
    """
    device = 'cpu'
    weightfile = projects.get_weightfile_from_cfg(cfg, 'feature_extractor')
    if num_classes is None:
        num_classes = len(cfg.project.class_names)

    log.info(
        'feature extractor weightfile: {}'.format(weightfile))

    def restore(module, component):
        """Load the named component from the feature extractor checkpoint, if there is one."""
        if weightfile is None:
            return module
        return utils.load_feature_extractor_components(module,
                                                       weightfile,
                                                       component,
                                                       device=device)

    def make_stream(n_channels, component):
        """Build one CNN stream with the given number of input channels and restore its weights."""
        stream = get_cnn(cfg.feature_extractor.arch,
                         in_channels=n_channels,
                         dropout_p=cfg.feature_extractor.dropout_p,
                         num_classes=num_classes,
                         reload_imagenet=from_imagenet,
                         pos=pos,
                         neg=neg,
                         final_bn=cfg.feature_extractor.final_bn)
        return restore(stream, component)

    is_3d = '3d' in cfg.feature_extractor.arch
    # 2d architectures stack the rgb frames along the channel axis: 3 channels per image
    rgb_channels = 3 if is_3d else cfg.feature_extractor.n_rgb * 3
    # without a feature extractor checkpoint, get_cnn is asked to reload imagenet weights instead
    from_imagenet = weightfile is None
    if cfg.feature_extractor.arch == 'resnet3d_34':
        assert weightfile is not None, 'Must specify path to resnet3d weights!'
    spatial_classifier = make_stream(rgb_channels, 'spatial')

    # each optic flow contributes 2 channels
    flow_channels = 2 if is_3d else cfg.feature_extractor.n_flows * 2
    flow_classifier = make_stream(flow_channels, 'flow')

    flow_generator = build_flow_generator(cfg)
    flow_weights = projects.get_weightfile_from_cfg(cfg, 'flow_generator')
    assert flow_weights is not None, (
        'Must have a valid weightfile for flow generator. Use '
        'deepethogram.flow_generator.train or cfg.reload.latest')
    flow_generator = utils.load_weights(flow_generator,
                                        flow_weights,
                                        device=device)

    # build_fusion_layer hands back both classifiers along with the fusion module
    spatial_classifier, flow_classifier, fusion = build_fusion_layer(
        spatial_classifier, flow_classifier, cfg.feature_extractor.fusion,
        num_classes)
    fusion = restore(fusion, 'fusion')

    two_stream = HiddenTwoStream(flow_generator, spatial_classifier,
                                 flow_classifier, fusion,
                                 cfg.feature_extractor.arch)
    two_stream.set_mode('classifier')

    return flow_generator, spatial_classifier, flow_classifier, fusion, two_stream