def generate_model(opt):  # call the selected architecture's generate_model to build the model
    assert opt.model in [
        'resnet', 'resnet2p1d', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    if opt.model == 'resnet':
        model = resnet.generate_model(model_depth=opt.model_depth,
                                      n_classes=opt.n_classes,
                                      n_input_channels=opt.n_input_channels,
                                      shortcut_type=opt.resnet_shortcut,
                                      conv1_t_size=opt.conv1_t_size,
                                      conv1_t_stride=opt.conv1_t_stride,
                                      no_max_pool=opt.no_max_pool,
                                      widen_factor=opt.resnet_widen_factor)
    elif opt.model == 'resnet2p1d':
        model = resnet2p1d.generate_model(model_depth=opt.model_depth,
                                          n_classes=opt.n_classes,
                                          n_input_channels=opt.n_input_channels,
                                          shortcut_type=opt.resnet_shortcut,
                                          conv1_t_size=opt.conv1_t_size,
                                          conv1_t_stride=opt.conv1_t_stride,
                                          no_max_pool=opt.no_max_pool,
                                          widen_factor=opt.resnet_widen_factor)
    elif opt.model == 'wideresnet':
        model = wide_resnet.generate_model(
            model_depth=opt.model_depth,
            k=opt.wide_resnet_k,
            n_classes=opt.n_classes,
            n_input_channels=opt.n_input_channels,
            shortcut_type=opt.resnet_shortcut,
            conv1_t_size=opt.conv1_t_size,
            conv1_t_stride=opt.conv1_t_stride,
            no_max_pool=opt.no_max_pool)
    elif opt.model == 'resnext':
        model = resnext.generate_model(model_depth=opt.model_depth,
                                       cardinality=opt.resnext_cardinality,
                                       n_classes=opt.n_classes,
                                       n_input_channels=opt.n_input_channels,
                                       shortcut_type=opt.resnet_shortcut,
                                       conv1_t_size=opt.conv1_t_size,
                                       conv1_t_stride=opt.conv1_t_stride,
                                       no_max_pool=opt.no_max_pool)
    elif opt.model == 'preresnet':
        model = pre_act_resnet.generate_model(
            model_depth=opt.model_depth,
            n_classes=opt.n_classes,
            n_input_channels=opt.n_input_channels,
            shortcut_type=opt.resnet_shortcut,
            conv1_t_size=opt.conv1_t_size,
            conv1_t_stride=opt.conv1_t_stride,
            no_max_pool=opt.no_max_pool)
    elif opt.model == 'densenet':
        model = densenet.generate_model(model_depth=opt.model_depth,
                                        n_classes=opt.n_classes,
                                        n_input_channels=opt.n_input_channels,
                                        conv1_t_size=opt.conv1_t_size,
                                        conv1_t_stride=opt.conv1_t_stride,
                                        no_max_pool=opt.no_max_pool)

    return model
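
A minimal usage sketch (not from the original source; field values are illustrative): the function only reads attributes off opt, so any argparse-style namespace with these fields works.

from argparse import Namespace

# Hypothetical options for a 3D ResNet-50 video classifier.
opt = Namespace(model='resnet', model_depth=50, n_classes=400,
                n_input_channels=3, resnet_shortcut='B',
                conv1_t_size=7, conv1_t_stride=1,
                no_max_pool=False, resnet_widen_factor=1.0)
model = generate_model(opt)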
Example #2
def get_model(model_name, encoder, num_classes):
    if model_name == 'lstm':
        return LSTM(encoder, num_classes)
    elif model_name == 'cnn':
        return CNN(encoder, num_classes)
    elif model_name == 'transformer':
        return Transformer(encoder, num_classes)
    elif model_name == 'cnn3d':
        if 'efficientnet' in encoder:
            return EfficientNet3D.from_name(
                encoder,
                override_params={'num_classes': num_classes},
                in_channels=1)
        elif 'resnet2p1d' in encoder:
            # Checked before the plain 'resnet' branch: 'resnet' is a
            # substring of 'resnet2p1d', so the order matters here.
            return resnet2p1d.generate_model(
                model_depth=int(encoder.split('resnet2p1d')[-1]),
                n_classes=num_classes,
                n_input_channels=1,
                shortcut_type='B',
                conv1_t_size=7,
                conv1_t_stride=1,
                no_max_pool=False,
                widen_factor=1.)
        elif 'resnext' in encoder:
            return resnext.generate_model(
                model_depth=int(encoder.split('resnext')[-1]),
                n_classes=num_classes,
                n_input_channels=1,
                cardinality=32,
                shortcut_type='B',
                conv1_t_size=7,
                conv1_t_stride=1,
                no_max_pool=False)
        elif 'resnet' in encoder:
            if encoder == 'resnet':
                encoder = 'resnet34'  # default to ResNet-34 when no depth is given
            return resnet.generate_model(
                model_depth=int(encoder.split('resnet')[-1]),
                n_classes=num_classes,
                n_input_channels=1,
                shortcut_type='B',
                conv1_t_size=7,
                conv1_t_stride=1,
                no_max_pool=False,
                widen_factor=1.)
        elif 'densenet' in encoder:
            return densenet.generate_model(model_depth=int(
                encoder.split('densenet')[-1]),
                                           num_classes=num_classes,
                                           n_input_channels=1,
                                           conv1_t_size=7,
                                           conv1_t_stride=1,
                                           no_max_pool=False)
        else:
            # A bare `raise` outside an except block is itself an error;
            # raise an informative exception instead.
            raise ValueError(f'Unsupported cnn3d encoder: {encoder}')
    else:
        raise ValueError(f'Unsupported model name: {model_name}')
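
A hedged usage sketch (illustrative, not from the source): on the 'cnn3d' path the depth is parsed out of the encoder string, so 'resnet50' becomes model_depth=50.

# Hypothetical call: single-channel 3D ResNet-50 with two output classes.
model = get_model('cnn3d', 'resnet50', num_classes=2)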
Example #3
File: utils.py  Project: Cardiobotics/SUEF
def create_and_load_model_old(cfg):
    tags = []
    map_location = None  # default device mapping; only remap under DDP
    if cfg.performance.ddp:
        map_location = {'cuda:%d' % 0: 'cuda:%d' % cfg.rank}
    if cfg.model.name == 'ccnn':
        tags.append('CNN')
        model = custom_cnn.CNN()
    elif cfg.model.name == 'resnext':
        tags.append('ResNeXt')
        model = resnext.generate_model(model_depth=cfg.model.model_depth,
                                       cardinality=cfg.model.cardinality,
                                       n_classes=cfg.model.n_classes,
                                       n_input_channels=cfg.model.n_input_channels,
                                       shortcut_type=cfg.model.shortcut_type,
                                       conv1_t_size=cfg.model.conv1_t_size,
                                       conv1_t_stride=cfg.model.conv1_t_stride)
        model.load_state_dict(torch.load(cfg.model.pre_trained_checkpoint))
    elif cfg.model.name == 'i3d':
        tags.append('I3D')
        if cfg.training.continue_training:
            checkpoint = cfg.model.best_model
        else:
            checkpoint = cfg.model.pre_trained_checkpoint
        if cfg.data.type == 'img':
            tags.append('spatial')
            model = i3d_bert.inception_model(checkpoint, cfg.model.n_classes, cfg.model.n_input_channels,
                                             cfg.model.pre_n_classes, cfg.model.pre_n_input_channels)
        elif cfg.data.type == 'flow':
            tags.append('temporal')
            tags.append('TVL1')
            model = i3d_bert.inception_model_flow(checkpoint, cfg.model.n_classes, cfg.model.n_input_channels,
                                                  cfg.model.pre_n_classes, cfg.model.pre_n_input_channels)
    elif cfg.model.name == 'i3d_bert':
        tags.append('I3D')
        tags.append('BERT')

        if cfg.training.continue_training:
            state_dict = torch.load(cfg.model.best_model, map_location=map_location)['model']
            if cfg.data.type == 'img':
                tags.append('spatial')
                model = i3d_bert.rgb_I3D64f_bert2_FRMB('', cfg.model.length, cfg.model.n_classes,
                                                       cfg.model.n_input_channels, cfg.model.pre_n_classes,
                                                       cfg.model.pre_n_input_channels)
            if cfg.data.type == 'flow':
                tags.append('temporal')
                tags.append('TVL1')
                model = i3d_bert.flow_I3D64f_bert2_FRMB('', cfg.model.length, cfg.model.n_classes,
                                                        cfg.model.n_input_channels, cfg.model.pre_n_classes,
                                                        cfg.model.pre_n_input_channels)
            if cfg.data.type == 'multi-stream':
                tags.append('multi-stream')
                tags.append('TVL1')
                if cfg.model.shared_weights:
                    tags.append('shared-weights')
                    model_img, model_flow = create_two_stream_models(cfg, '', '')
                    model = multi_stream.MultiStreamShared(
                        model_img, model_flow,
                        len(state_dict['Linear_layer.weight'][0]),
                        cfg.model.n_classes)
                    model.load_state_dict(state_dict)
                    if len(state_dict['Linear_layer.weight'][0]) != len(cfg.data.allowed_views) * 2:
                        model.replace_fc(len(cfg.data.allowed_views) * 2)
                else:
                    model_dict = {}
                    for view in cfg.data.allowed_views:
                        m_img_name = 'model_img_' + str(view)
                        m_flow_name = 'model_flow_' + str(view)
                        model_img, model_flow = create_two_stream_models(cfg, '', '')
                        model_dict[m_img_name] = model_img
                        model_dict[m_flow_name] = model_flow
                    model = multi_stream.MultiStream(model_dict)
        
        else:
            if cfg.data.type == 'img':
                tags.append('spatial')
                model = i3d_bert.rgb_I3D64f_bert2_FRMB(cfg.model.pre_trained_checkpoint, cfg.model.length,
                                                       cfg.model.n_classes, cfg.model.n_input_channels,
                                                       cfg.model.pre_n_classes, cfg.model.pre_n_input_channels)
            if cfg.data.type == 'flow':
                tags.append('temporal')
                tags.append('TVL1')
                model = i3d_bert.flow_I3D64f_bert2_FRMB(cfg.model.pre_trained_checkpoint, cfg.model.length,
                                                        cfg.model.n_classes, cfg.model.n_input_channels,
                                                        cfg.model.pre_n_classes, cfg.model.pre_n_input_channels)
            if cfg.data.type == 'multi-stream':
                tags.append('multi-stream')
                tags.append('TVL1')
                if cfg.model.shared_weights:
                    tags.append('shared-weights')
                    model_img, model_flow = create_two_stream_models(cfg, cfg.model.pre_trained_checkpoint_img,
                                                                     cfg.model.pre_trained_checkpoint_flow)
                    model = multi_stream.MultiStreamShared(model_img, model_flow, len(cfg.data.allowed_views)*2,
                                                           cfg.model.n_classes)
                else:
                    model_dict = {}
                    for view in cfg.data.allowed_views:
                        m_img_name = 'model_img_' + str(view)
                        m_flow_name = 'model_flow_' + str(view)
                        model_img, model_flow = create_two_stream_models(cfg, cfg.model.pre_trained_checkpoint_img,
                                                                         cfg.model.pre_trained_checkpoint_flow)
                        model_dict[m_img_name] = model_img
                        model_dict[m_flow_name] = model_flow
                    model = multi_stream.MultiStream(model_dict, cfg.model.n_classes)
                    
    return model, tags
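
A minimal sketch of driving this function (assumed, not from the source): cfg behaves like an OmegaConf/Hydra config, so the simplest branch, 'ccnn', which needs no checkpoint, can be exercised like this.

from omegaconf import OmegaConf

# Hypothetical config hitting the 'ccnn' branch.
cfg = OmegaConf.create({
    'performance': {'ddp': False},
    'model': {'name': 'ccnn'},
})
model, tags = create_and_load_model_old(cfg)
print(tags)  # ['CNN']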
Example #4
File: utils.py  Project: Cardiobotics/SUEF
def create_and_load_model(cfg):
    tags = []
    map_location = None  # default device mapping; only remap under DDP
    if cfg.performance.ddp:
        map_location = {'cuda:%d' % 0: 'cuda:%d' % cfg.rank}
    if cfg.model.name == 'ccnn':
        tags.append('CNN')
        model = custom_cnn.CNN()
    elif cfg.model.name == 'resnext':
        tags.append('ResNeXt')
        model = resnext.generate_model(model_depth=cfg.model.model_depth,
                                       cardinality=cfg.model.cardinality,
                                       n_classes=cfg.model.n_classes,
                                       n_input_channels=cfg.model.n_input_channels,
                                       shortcut_type=cfg.model.shortcut_type,
                                       conv1_t_size=cfg.model.conv1_t_size,
                                       conv1_t_stride=cfg.model.conv1_t_stride)
        model.load_state_dict(torch.load(cfg.model.pre_trained_checkpoint))
    elif cfg.model.name == 'i3d':
        tags.append('I3D')
        if cfg.data.type == 'img':
            tags.append('spatial')
            if cfg.training.continue_training:
                checkpoint = cfg.model.best_model
            else:
                checkpoint = cfg.model.pre_trained_checkpoint
            model = i3d_bert.inception_model(checkpoint, cfg.model.n_classes, cfg.model.n_input_channels,
                                             cfg.model.pre_n_classes, cfg.model.pre_n_input_channels)
        elif cfg.data.type == 'flow':
            tags.append('temporal')
            tags.append('TVL1')
            if cfg.training.continue_training:
                checkpoint = cfg.model.best_model
            else:
                checkpoint = cfg.model.pre_trained_checkpoint
            model = i3d_bert.inception_model_flow(checkpoint, cfg.model.n_classes, cfg.model.n_input_channels,
                                                  cfg.model.pre_n_classes, cfg.model.pre_n_input_channels)
        elif cfg.data.type == 'multi-stream':
            tags.append('multi-stream')
            tags.append('TVL1')
            if cfg.model.shared_weights:
                tags.append('shared-weights')
                if cfg.training.continue_training:
                    state_dict = torch.load(cfg.model.best_model, map_location=map_location)['model']
                    model_img, model_flow = create_two_stream_models(cfg, '', '', bert=False)
                    model = multi_stream.MultiStreamShared(model_img, model_flow, len(cfg.data.allowed_views) * 2,
                                                           cfg.model.n_classes)
                    img_state_dict = OrderedDict({k: state_dict[k] for k in state_dict.keys() if k.find('bert') == -1})
                    model.load_state_dict(img_state_dict)
                    model.replace_mixed_5c()
                    model.replace_fc_submodels(1, inp_dim=1024)
                else:
                    model_img = i3d_bert.Inception3D_Maxpool(cfg.model.pre_trained_checkpoint_img, cfg.model.n_classes,
                                                             cfg.model.n_input_channels_img, cfg.model.pre_n_classes,
                                                             cfg.model.pre_n_input_channels_img)
                    model_flow = i3d_bert.Inception3D_Maxpool(cfg.model.pre_trained_checkpoint_flow, cfg.model.n_classes,
                                                              cfg.model.n_input_channels_flow, cfg.model.pre_n_classes,
                                                              cfg.model.pre_n_input_channels_flow)
                    model = multi_stream.MultiStreamShared(model_img, model_flow, len(cfg.data.allowed_views) * 2,
                                                           cfg.model.n_classes)
                if cfg.optimizer.loss_function == 'all-threshold':
                    model.thresholds = torch.nn.Parameter(torch.tensor(range(10)).float(), requires_grad=True)
    elif cfg.model.name == 'i3d_bert':
        tags.append('I3D')
        tags.append('BERT')
        if cfg.training.continue_training:
            state_dict = torch.load(cfg.model.best_model, map_location=map_location)['model']
            if cfg.data.type == 'img':
                tags.append('spatial')
                model = i3d_bert.rgb_I3D64f_bert2_FRMB('', cfg.model.length, cfg.model.n_classes,
                                                       cfg.model.n_input_channels, cfg.model.pre_n_classes,
                                                       cfg.model.pre_n_input_channels)
            if cfg.data.type == 'flow':
                tags.append('temporal')
                tags.append('TVL1')
                model = i3d_bert.flow_I3D64f_bert2_FRMB('', cfg.model.length, cfg.model.n_classes,
                                                        cfg.model.n_input_channels, cfg.model.pre_n_classes,
                                                        cfg.model.pre_n_input_channels)
            if cfg.data.type == 'multi-stream':
                tags.append('multi-stream')
                tags.append('TVL1')
                if cfg.model.shared_weights:
                    tags.append('shared-weights')
                    model_img, model_flow = create_two_stream_models(cfg, '', '')
                    model = multi_stream.MultiStreamShared(
                        model_img, model_flow,
                        len(state_dict['Linear_layer.weight'][0]) // cfg.model.pre_n_classes,
                        cfg.model.pre_n_classes)
                    model.load_state_dict(state_dict)
                    if (len(state_dict['Linear_layer.weight'][0]) // cfg.model.pre_n_classes
                            != len(cfg.data.allowed_views) * 2
                            or cfg.model.pre_n_classes != cfg.model.n_classes):
                        model.replace_fc(len(cfg.data.allowed_views) * 2, cfg.model.n_classes)
                    if cfg.optimizer.loss_function == 'all-threshold':
                        model.thresholds = torch.nn.Parameter(torch.tensor(range(10)).float(), requires_grad=True)
                else:
                    model_dict = {}
                    for view in cfg.data.allowed_views:
                        m_img_name = 'model_img_' + str(view)
                        m_flow_name = 'model_flow_' + str(view)
                        model_img, model_flow = create_two_stream_models(cfg, '', '')
                        model_dict[m_img_name] = model_img
                        model_dict[m_flow_name] = model_flow
                    model = multi_stream.MultiStream(model_dict)
            if cfg.data.type == 'no-flow':
                tags.append('no-flow')
                tags.append('multi-stream')
                if cfg.model.shared_weights:
                    tags.append('shared-weights')
                    model_img = i3d_bert.rgb_I3D64f_bert2_FRMB('', cfg.model.length,
                                                               cfg.model.n_classes, cfg.model.n_input_channels_img,
                                                               cfg.model.pre_n_classes,
                                                               cfg.model.pre_n_input_channels_img)
                    model = multi_stream.MSNoFlowShared(model_img, len(cfg.data.allowed_views)*2, cfg.model.pre_n_classes)
                    img_state_dict = OrderedDict({k: state_dict[k] for k in state_dict.keys() if
                                                  (k[0:9] == 'Model_img' or k[0:12] == 'Linear_layer')})
                    model.load_state_dict(img_state_dict)
                    model.replace_fc(len(cfg.data.allowed_views), cfg.model.n_classes)
                    if cfg.optimizer.loss_function == 'all-threshold':
                        model.thresholds = torch.nn.Parameter(torch.tensor(range(10)).float(), requires_grad=True)
                else:
                    model_dict = {}
                    for view in cfg.data.allowed_views:
                        m_img_name = 'model_img_' + str(view)
                        model_img = i3d_bert.rgb_I3D64f_bert2_FRMB('', cfg.model.length,
                                                                   cfg.model.n_classes, cfg.model.n_input_channels_img,
                                                                   cfg.model.pre_n_classes,
                                                                   cfg.model.pre_n_input_channels_img)
                        model_dict[m_img_name] = model_img
                    model = multi_stream.MultiStream(model_dict)
                    model.load_state_dict(state_dict)
        
        else:
            if cfg.data.type == 'img':
                tags.append('spatial')
                model = i3d_bert.rgb_I3D64f_bert2_FRMB(cfg.model.pre_trained_checkpoint, cfg.model.length,
                                                       cfg.model.n_classes, cfg.model.n_input_channels,
                                                       cfg.model.pre_n_classes, cfg.model.pre_n_input_channels)
            if cfg.data.type == 'flow':
                tags.append('temporal')
                tags.append('TVL1')
                model = i3d_bert.flow_I3D64f_bert2_FRMB(cfg.model.pre_trained_checkpoint, cfg.model.length,
                                                        cfg.model.n_classes, cfg.model.n_input_channels,
                                                        cfg.model.pre_n_classes, cfg.model.pre_n_input_channels)
            if cfg.data.type == 'multi-stream':
                tags.append('multi-stream')
                tags.append('TVL1')
                if cfg.model.shared_weights:
                    tags.append('shared-weights')
                    model_img, model_flow = create_two_stream_models(cfg, cfg.model.pre_trained_checkpoint_img,
                                                                     cfg.model.pre_trained_checkpoint_flow)
                    model = multi_stream.MultiStreamShared(model_img, model_flow, len(cfg.data.allowed_views) * 2,
                                                           cfg.model.n_classes)
                    if cfg.optimizer.loss_function == 'all-threshold':
                        model.thresholds = torch.nn.Parameter(torch.tensor(range(10)).float(), requires_grad=True)
                else:
                    model_dict = {}
                    for view in cfg.data.allowed_views:
                        m_img_name = 'model_img_' + str(view)
                        m_flow_name = 'model_flow_' + str(view)
                        model_img, model_flow = create_two_stream_models(cfg, cfg.model.pre_trained_checkpoint_img,
                                                                         cfg.model.pre_trained_checkpoint_flow)
                        model_dict[m_img_name] = model_img
                        model_dict[m_flow_name] = model_flow
                    model = multi_stream.MultiStream(model_dict, cfg.model.n_classes)
            if cfg.data.type == 'no-flow':
                tags.append('no-flow')
                tags.append('multi-stream')
                if cfg.model.shared_weights:
                    tags.append('shared-weights')
                    model_img = i3d_bert.rgb_I3D64f_bert2_FRMB(cfg.model.pre_trained_checkpoint_img, cfg.model.length,
                                                               cfg.model.n_classes, cfg.model.n_input_channels_img,
                                                               cfg.model.pre_n_classes,
                                                               cfg.model.pre_n_input_channels_img)
                    model = multi_stream.MSNoFlowShared(model_img, len(cfg.data.allowed_views), cfg.model.n_classes)
                    if cfg.optimizer.loss_function == 'all-threshold':
                        model.thresholds = torch.nn.Parameter(torch.tensor(range(10)).float(), requires_grad=True)
                else:
                    model_dict = {}
                    for view in cfg.data.allowed_views:
                        m_img_name = 'model_img_' + str(view)
                        model_img = i3d_bert.rgb_I3D64f_bert2_FRMB(cfg.model.pre_trained_checkpoint_img,
                                                                   cfg.model.length,
                                                                   cfg.model.n_classes, cfg.model.n_input_channels_img,
                                                                   cfg.model.pre_n_classes,
                                                                   cfg.model.pre_n_input_channels_img)
                        model_dict[m_img_name] = model_img
                    model = multi_stream.MultiStream(model_dict, cfg.model.n_classes)
            
    return model, tags
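
For comparison, a hedged sketch of the 'resnext' branch (values illustrative; the checkpoint path is hypothetical and must point at weights matching this architecture):

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'performance': {'ddp': False},
    'model': {'name': 'resnext', 'model_depth': 101, 'cardinality': 32,
              'n_classes': 1, 'n_input_channels': 1, 'shortcut_type': 'B',
              'conv1_t_size': 7, 'conv1_t_stride': 1,
              'pre_trained_checkpoint': 'weights/resnext101.pth'},
})
model, tags = create_and_load_model(cfg)  # also loads the checkpoint weights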
Example #5
def main(cfg: DictConfig) -> None:

    assert cfg.model.name in [
        'ccnn', 'resnext', 'i3d', 'i3d_bert', 'i3d_bert_2stream'
    ]
    assert cfg.data.type in ['img', 'flow', 'multi-stream', 'no-flow']

    tags = []

    if cfg.model.name == 'ccnn':
        tags.append('CNN')
        model = custom_cnn.CNN()
    elif cfg.model.name == 'resnext':
        tags.append('ResNeXt')
        model = resnext.generate_model(
            model_depth=cfg.model.model_depth,
            cardinality=cfg.model.cardinality,
            n_classes=cfg.model.n_classes,
            n_input_channels=cfg.model.n_input_channels,
            shortcut_type=cfg.model.shortcut_type,
            conv1_t_size=cfg.model.conv1_t_size,
            conv1_t_stride=cfg.model.conv1_t_stride)
        model.load_state_dict(torch.load(cfg.model.pre_trained_checkpoint))
    elif cfg.model.name == 'i3d':
        tags.append('I3D')
        if cfg.data.type == 'img':
            tags.append('spatial')
            if cfg.training.continue_training:
                checkpoint = cfg.model.best_model
            else:
                checkpoint = cfg.model.pre_trained_checkpoint
            model = i3d_bert.inception_model(checkpoint, cfg.model.n_classes,
                                             cfg.model.n_input_channels,
                                             cfg.model.pre_n_classes,
                                             cfg.model.pre_n_input_channels)
        elif cfg.data.type == 'flow':
            tags.append('temporal')
            tags.append('TVL1')
            if cfg.training.continue_training:
                checkpoint = cfg.model.best_model
            else:
                checkpoint = cfg.model.pre_trained_checkpoint
            model = i3d_bert.inception_model_flow(
                checkpoint, cfg.model.n_classes, cfg.model.n_input_channels,
                cfg.model.pre_n_classes, cfg.model.pre_n_input_channels)
        elif cfg.data.type == 'multi-stream':
            tags.append('multi-stream')
            tags.append('TVL1')
            if cfg.model.shared_weights:
                tags.append('shared-weights')
                model_img = i3d_bert.Inception3D_Maxpool(
                    cfg.model.pre_trained_checkpoint_img, cfg.model.n_classes,
                    cfg.model.n_input_channels_img, cfg.model.pre_n_classes,
                    cfg.model.pre_n_input_channels_img)
                model_flow = i3d_bert.Inception3D_Maxpool(
                    cfg.model.pre_trained_checkpoint_flow, cfg.model.n_classes,
                    cfg.model.n_input_channels_flow, cfg.model.pre_n_classes,
                    cfg.model.pre_n_input_channels_flow)
                model = multi_stream.MultiStreamShared(
                    model_img, model_flow,
                    len(cfg.data.allowed_views) * 2, cfg.model.n_classes)
                if cfg.optimizer.loss_function == 'all-threshold':
                    model.thresholds = torch.nn.Parameter(
                        torch.tensor(range(10)).float(), requires_grad=True)
    elif cfg.model.name == 'i3d_bert':
        tags.append('I3D')
        tags.append('BERT')
        if cfg.training.continue_training:
            state_dict = torch.load(cfg.model.best_model)['model']
            if cfg.data.type == 'img':
                tags.append('spatial')
                model = i3d_bert.rgb_I3D64f_bert2_FRMB(
                    '', cfg.model.length, cfg.model.n_classes,
                    cfg.model.n_input_channels, cfg.model.pre_n_classes,
                    cfg.model.pre_n_input_channels)
            if cfg.data.type == 'flow':
                tags.append('temporal')
                tags.append('TVL1')
                model = i3d_bert.flow_I3D64f_bert2_FRMB(
                    '', cfg.model.length, cfg.model.n_classes,
                    cfg.model.n_input_channels, cfg.model.pre_n_classes,
                    cfg.model.pre_n_input_channels)
            if cfg.data.type == 'multi-stream':
                tags.append('multi-stream')
                tags.append('TVL1')
                if cfg.model.shared_weights:
                    tags.append('shared-weights')
                    model_img, model_flow = create_two_stream_models(
                        cfg, '', '')
                    model = multi_stream.MultiStreamShared(
                        model_img, model_flow,
                        len(state_dict['Linear_layer.weight'][0]) //
                        cfg.model.pre_n_classes, cfg.model.pre_n_classes)
                    model.load_state_dict(state_dict)
                    if (len(state_dict['Linear_layer.weight'][0]) //
                            cfg.model.pre_n_classes
                            != len(cfg.data.allowed_views) * 2
                            or cfg.model.pre_n_classes != cfg.model.n_classes):
                        model.replace_fc(len(cfg.data.allowed_views) * 2,
                                         cfg.model.n_classes)
                    if cfg.optimizer.loss_function == 'all-threshold':
                        model.thresholds = torch.nn.Parameter(
                            torch.tensor(range(10)).float(),
                            requires_grad=True)
                else:
                    model_dict = {}
                    for view in cfg.data.allowed_views:
                        m_img_name = 'model_img_' + str(view)
                        m_flow_name = 'model_flow_' + str(view)
                        model_img, model_flow = create_two_stream_models(
                            cfg, '', '')
                        model_dict[m_img_name] = model_img
                        model_dict[m_flow_name] = model_flow
                    model = multi_stream.MultiStream(model_dict)
            if cfg.data.type == 'no-flow':
                tags.append('no-flow')
                tags.append('multi-stream')
                if cfg.model.shared_weights:
                    tags.append('shared-weights')
                    model_img = i3d_bert.rgb_I3D64f_bert2_FRMB(
                        '', cfg.model.length, cfg.model.n_classes,
                        cfg.model.n_input_channels_img,
                        cfg.model.pre_n_classes,
                        cfg.model.pre_n_input_channels_img)
                    model = multi_stream.MSNoFlowShared(
                        model_img,
                        len(cfg.data.allowed_views) * 2,
                        cfg.model.pre_n_classes)
                    img_state_dict = OrderedDict({
                        k: state_dict[k]
                        for k in state_dict.keys()
                        if (k[0:9] == 'Model_img' or k[0:12] == 'Linear_layer')
                    })
                    model.load_state_dict(img_state_dict)
                    model.replace_fc(len(cfg.data.allowed_views),
                                     cfg.model.n_classes)
                    if cfg.optimizer.loss_function == 'all-threshold':
                        model.thresholds = torch.nn.Parameter(
                            torch.tensor(range(10)).float(),
                            requires_grad=True)
                else:
                    model_dict = {}
                    for view in cfg.data.allowed_views:
                        m_img_name = 'model_img_' + str(view)
                        model_img = i3d_bert.rgb_I3D64f_bert2_FRMB(
                            '', cfg.model.length, cfg.model.n_classes,
                            cfg.model.n_input_channels_img,
                            cfg.model.pre_n_classes,
                            cfg.model.pre_n_input_channels_img)
                        model_dict[m_img_name] = model_img
                    model = multi_stream.MultiStream(model_dict)
                    model.load_state_dict(state_dict)
        else:
            if cfg.data.type == 'img':
                tags.append('spatial')
                model = i3d_bert.rgb_I3D64f_bert2_FRMB(
                    cfg.model.pre_trained_checkpoint, cfg.model.length,
                    cfg.model.n_classes, cfg.model.n_input_channels,
                    cfg.model.pre_n_classes, cfg.model.pre_n_input_channels)
            if cfg.data.type == 'flow':
                tags.append('temporal')
                tags.append('TVL1')
                model = i3d_bert.flow_I3D64f_bert2_FRMB(
                    cfg.model.pre_trained_checkpoint, cfg.model.length,
                    cfg.model.n_classes, cfg.model.n_input_channels,
                    cfg.model.pre_n_classes, cfg.model.pre_n_input_channels)
            if cfg.data.type == 'multi-stream':
                tags.append('multi-stream')
                tags.append('TVL1')
                if cfg.model.shared_weights:
                    tags.append('shared-weights')
                    model_img, model_flow = create_two_stream_models(
                        cfg, cfg.model.pre_trained_checkpoint_img,
                        cfg.model.pre_trained_checkpoint_flow)
                    model = multi_stream.MultiStreamShared(
                        model_img, model_flow,
                        len(cfg.data.allowed_views) * 2, cfg.model.n_classes)
                    if cfg.optimizer.loss_function == 'all-threshold':
                        model.thresholds = torch.nn.Parameter(
                            torch.tensor(range(10)).float(),
                            requires_grad=True)
                else:
                    model_dict = {}
                    for view in cfg.data.allowed_views:
                        m_img_name = 'model_img_' + str(view)
                        m_flow_name = 'model_flow_' + str(view)
                        model_img, model_flow = create_two_stream_models(
                            cfg, cfg.model.pre_trained_checkpoint_img,
                            cfg.model.pre_trained_checkpoint_flow)
                        model_dict[m_img_name] = model_img
                        model_dict[m_flow_name] = model_flow
                    model = multi_stream.MultiStream(model_dict,
                                                     cfg.model.n_classes)
            if cfg.data.type == 'no-flow':
                tags.append('no-flow')
                tags.append('multi-stream')
                if cfg.model.shared_weights:
                    tags.append('shared-weights')
                    model_img = i3d_bert.rgb_I3D64f_bert2_FRMB(
                        cfg.model.pre_trained_checkpoint_img, cfg.model.length,
                        cfg.model.n_classes, cfg.model.n_input_channels_img,
                        cfg.model.pre_n_classes,
                        cfg.model.pre_n_input_channels_img)
                    model = multi_stream.MSNoFlowShared(
                        model_img, len(cfg.data.allowed_views),
                        cfg.model.n_classes)
                    if cfg.optimizer.loss_function == 'all-threshold':
                        model.thresholds = torch.nn.Parameter(
                            torch.tensor(range(10)).float(),
                            requires_grad=True)
                else:
                    model_dict = {}
                    for view in cfg.data.allowed_views:
                        m_img_name = 'model_img_' + str(view)
                        model_img = i3d_bert.rgb_I3D64f_bert2_FRMB(
                            cfg.model.pre_trained_checkpoint_img,
                            cfg.model.length, cfg.model.n_classes,
                            cfg.model.n_input_channels_img,
                            cfg.model.pre_n_classes,
                            cfg.model.pre_n_input_channels_img)
                        model_dict[m_img_name] = model_img
                    model = multi_stream.MultiStream(model_dict,
                                                     cfg.model.n_classes)

    train_data_set, val_data_set = create_data_sets(cfg)

    train_data_loader, val_data_loader = create_data_loaders(
        cfg, train_data_set, val_data_set)

    experiment = None
    if cfg.logging.logging_enabled:
        experiment_params = {
            **dict(cfg.data_loader),
            **dict(cfg.transforms),
            **dict(cfg.augmentations),
            **dict(cfg.performance),
            **dict(cfg.training),
            **dict(cfg.optimizer),
            **dict(cfg.model),
            **dict(cfg.evaluation),
            'target_file': cfg.data.train_targets,
            'data_stream': cfg.data.type,
            'view': cfg.data.name,
            'train_dataset_size': len(train_data_loader.dataset),
            'val_dataset_size': len(val_data_loader.dataset)
        }
        experiment = neptune.init(project=cfg.logging.project_name,
                                  name=cfg.logging.experiment_name,
                                  tags=tags)
        experiment['parameters'] = experiment_params

    if not os.path.exists(cfg.training.checkpoint_save_path):
        os.makedirs(cfg.training.checkpoint_save_path)

    train_and_validate(model,
                       train_data_loader,
                       val_data_loader,
                       cfg,
                       experiment=experiment)
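
The DictConfig signature suggests a Hydra entry point; a hedged sketch of the likely wiring (the decorator arguments are assumptions, not taken from the source):

import hydra
from omegaconf import DictConfig

@hydra.main(config_path='config', config_name='config')
def main(cfg: DictConfig) -> None:
    ...  # body as above

if __name__ == '__main__':
    main()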