def generate_model(opt):
    """Build the network named by ``opt.model``.

    Calls the ``generate_model`` function of the matching model module
    (``resnet``, ``resnet2p1d``, ``wide_resnet``, ``resnext``,
    ``pre_act_resnet`` or ``densenet``) with keyword arguments read from
    ``opt``.

    Args:
        opt: Options object. ``opt.model`` selects the architecture. Every
            architecture reads ``model_depth``, ``n_classes``,
            ``n_input_channels``, ``conv1_t_size``, ``conv1_t_stride`` and
            ``no_max_pool``. All but ``densenet`` also read
            ``resnet_shortcut``; ``resnet`` and ``resnet2p1d`` read
            ``resnet_widen_factor``, ``wideresnet`` reads ``wide_resnet_k``
            and ``resnext`` reads ``resnext_cardinality``.

    Returns:
        Whatever the selected module's ``generate_model`` returns.

    Raises:
        ValueError: If ``opt.model`` is not a supported architecture name.
    """
    supported = ('resnet', 'resnet2p1d', 'preresnet', 'wideresnet',
                 'resnext', 'densenet')
    name = opt.model
    # Explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an unknown name fall through every
    # branch and only fail later on an unbound local.
    if name not in supported:
        raise ValueError('unsupported model {!r}; expected one of {}'.format(
            name, ', '.join(supported)))

    # Keyword arguments forwarded to every architecture.
    kwargs = dict(model_depth=opt.model_depth,
                  n_classes=opt.n_classes,
                  n_input_channels=opt.n_input_channels,
                  conv1_t_size=opt.conv1_t_size,
                  conv1_t_stride=opt.conv1_t_stride,
                  no_max_pool=opt.no_max_pool)

    # DenseNet is the only family that is not given a shortcut type.
    if name == 'densenet':
        return densenet.generate_model(**kwargs)

    kwargs['shortcut_type'] = opt.resnet_shortcut

    if name == 'resnet':
        return resnet.generate_model(widen_factor=opt.resnet_widen_factor,
                                     **kwargs)
    if name == 'resnet2p1d':
        return resnet2p1d.generate_model(
            widen_factor=opt.resnet_widen_factor, **kwargs)
    if name == 'wideresnet':
        return wide_resnet.generate_model(k=opt.wide_resnet_k, **kwargs)
    if name == 'resnext':
        return resnext.generate_model(cardinality=opt.resnext_cardinality,
                                      **kwargs)

    # Only 'preresnet' is left after the checks above.
    return pre_act_resnet.generate_model(**kwargs)
Ejemplo n.º 2
0
def get_model(model_name, encoder, num_classes):
    """Construct the model selected by ``model_name`` and ``encoder``.

    ``'lstm'``, ``'cnn'`` and ``'transformer'`` build ``LSTM``, ``CNN`` and
    ``Transformer`` from ``(encoder, num_classes)``. ``'cnn3d'`` picks a 3D
    backbone by substring match on ``encoder``. For the ResNet-style and
    DenseNet families the depth is the integer suffix after the family name
    (e.g. ``'resnext50'`` gives depth 50); a bare ``'resnet'`` means depth 34.

    Args:
        model_name: One of ``'lstm'``, ``'cnn'``, ``'transformer'``,
            ``'cnn3d'``.
        encoder: Encoder identifier. For ``'cnn3d'`` it must contain
            ``'efficientnet'``, ``'resnet2p1d'``, ``'resnet'``, ``'resnext'``
            or ``'densenet'``.
        num_classes: Number of output classes passed to the model.

    Returns:
        The constructed model object.

    Raises:
        ValueError: If ``model_name`` or ``encoder`` is not recognised, or
            if the depth suffix of ``encoder`` is not an integer.
    """
    if model_name == 'lstm':
        return LSTM(encoder, num_classes)
    if model_name == 'cnn':
        return CNN(encoder, num_classes)
    if model_name == 'transformer':
        return Transformer(encoder, num_classes)
    if model_name != 'cnn3d':
        raise ValueError('unknown model_name: {!r}'.format(model_name))

    if 'efficientnet' in encoder:
        return EfficientNet3D.from_name(
            encoder,
            override_params={'num_classes': num_classes},
            in_channels=1)

    def depth_of(family):
        # The depth is whatever follows the family name, e.g. 'resnet18'.
        return int(encoder.split(family)[-1])

    # Input/stem settings shared by every generate_model call below.
    stem = dict(n_input_channels=1,
                conv1_t_size=7,
                conv1_t_stride=1,
                no_max_pool=False)

    # 'resnet2p1d' contains 'resnet', so it has to be tested first;
    # otherwise the plain-ResNet branch would take it and fail parsing
    # '2p1d<depth>' as an integer.
    if 'resnet2p1d' in encoder:
        return resnet2p1d.generate_model(model_depth=depth_of('resnet2p1d'),
                                         n_classes=num_classes,
                                         shortcut_type='B',
                                         widen_factor=1.,
                                         **stem)
    if 'resnet' in encoder:
        # A bare 'resnet' defaults to the 34-layer variant.
        depth = 34 if encoder == 'resnet' else depth_of('resnet')
        return resnet.generate_model(model_depth=depth,
                                     n_classes=num_classes,
                                     shortcut_type='B',
                                     widen_factor=1.,
                                     **stem)
    if 'resnext' in encoder:
        return resnext.generate_model(model_depth=depth_of('resnext'),
                                      n_classes=num_classes,
                                      cardinality=32,
                                      shortcut_type='B',
                                      **stem)
    if 'densenet' in encoder:
        # NOTE(review): this call uses ``num_classes`` while the other
        # families use ``n_classes`` -- confirm against
        # densenet.generate_model's accepted keywords.
        return densenet.generate_model(model_depth=depth_of('densenet'),
                                       num_classes=num_classes,
                                       **stem)

    raise ValueError('unknown cnn3d encoder: {!r}'.format(encoder))
Ejemplo n.º 3
0
                                   model_depth=34,
                                   max_target=30,
                                   grayscale=False,
                                   aug=True)

    # Pass your defaults to wandb.init
    run = wandb.init(config=hyperparameter_defaults)
    # run = wandb.init(project="speedchallenge")
    config = wandb.config

    # Init network
    # model = RNN_LSTM(input_size, hidden_size, num_layers, num_classes).to(device)

    if config.grayscale:
        model = generate_model(model_depth=config.model_depth,
                               n_classes=1,
                               n_input_channels=1)

        tfms = transforms.Compose([
            transforms.Grayscale(),
        ])
    else:
        model = generate_model(model_depth=config.model_depth,
                               n_classes=1,
                               n_input_channels=3)
        tfms = None

    trainset = VideoFrameDataset(os.path.join("data", "train"),
                                 int(config.sequence_length),
                                 5,
                                 skip_frames=int(config.skip_frames),